diff --git a/gtsam/linear/GaussianConditional.cpp b/gtsam/linear/GaussianConditional.cpp index 21787ba7d..3aaffe6dc 100644 --- a/gtsam/linear/GaussianConditional.cpp +++ b/gtsam/linear/GaussianConditional.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #ifdef __GNUC__ @@ -220,7 +221,35 @@ namespace gtsam { } } - /* ************************************************************************* */ + /* ************************************************************************ */ + VectorValues GaussianConditional::sample( + const VectorValues& parentsValues) const { + if (nrFrontals() != 1) { + throw std::invalid_argument( + "GaussianConditional::sample can only be called on single variable " + "conditionals"); + } + if (!model_) { + throw std::invalid_argument( + "GaussianConditional::sample can only be called if a diagonal noise " + "model was specified at construction."); + } + VectorValues solution = solve(parentsValues); + Sampler sampler(model_); + Key key = firstFrontalKey(); + solution[key] += sampler.sample(); + return solution; + } + + VectorValues GaussianConditional::sample() const { + if (nrParents() != 0) + throw std::invalid_argument( + "sample() can only be invoked on no-parent prior"); + VectorValues values; + return sample(values); + } + + /* ************************************************************************ */ #ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 void GTSAM_DEPRECATED GaussianConditional::scaleFrontalsBySigma(VectorValues& gy) const { diff --git a/gtsam/linear/GaussianConditional.h b/gtsam/linear/GaussianConditional.h index 12d85d98f..e44da195b 100644 --- a/gtsam/linear/GaussianConditional.h +++ b/gtsam/linear/GaussianConditional.h @@ -44,6 +44,9 @@ namespace gtsam { typedef JacobianFactor BaseFactor; ///< Typedef to our factor base class typedef Conditional BaseConditional; ///< Typedef to our conditional base class + /// @name Constructors + /// @{ + /** default constructor needed for serialization */ GaussianConditional() {} @@ -99,6 +102,10 @@ namespace gtsam { template static shared_ptr Combine(ITERATOR firstConditional, ITERATOR lastConditional); + /// @} + /// @name Testable + /// @{ + /** print */ void print(const std::string& = "GaussianConditional", const KeyFormatter& formatter = DefaultKeyFormatter) const override; @@ -106,6 +113,10 @@ namespace gtsam { /** equals function */ bool equals(const GaussianFactor&cg, double tol = 1e-9) const override; + /// @} + /// @name Standard Interface + /// @{ + /** Return a view of the upper-triangular R block of the conditional */ constABlock R() const { return Ab_.range(0, nrFrontals()); } @@ -138,10 +149,25 @@ namespace gtsam { /** Performs transpose backsubstition in place on values */ void solveTransposeInPlace(VectorValues& gy) const; + /** + * sample + * @param parentsValues Known values of the parents + * @return sample from conditional + */ + VectorValues sample(const VectorValues& parentsValues) const; + + /// Zero parent version. + VectorValues sample() const; + + /// @} + #ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 + /// @name Deprecated + /// @{ /** Scale the values in \c gy according to the sigmas for the frontal variables in this * conditional. */ void GTSAM_DEPRECATED scaleFrontalsBySigma(VectorValues& gy) const; + /// @} #endif private: diff --git a/gtsam/linear/tests/testGaussianConditional.cpp b/gtsam/linear/tests/testGaussianConditional.cpp index 4483066b4..ae9a2d94b 100644 --- a/gtsam/linear/tests/testGaussianConditional.cpp +++ b/gtsam/linear/tests/testGaussianConditional.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -339,6 +340,28 @@ TEST(GaussianConditional, FromMeanAndStddev) { EXPECT_DOUBLES_EQUAL(expected2, conditional2.error(values), 1e-9); } +/* ************************************************************************* */ +// Test sampling +TEST(GaussianConditional, sample) { + Matrix A1 = (Matrix(2, 2) << 1., 2., 3., 4.).finished(); + const Vector2 b(20, 40), x1(3, 4); + const double sigma = 0.01; + + auto density = GaussianDensity::FromMeanAndStddev(X(0), b, sigma); + auto actual1 = density.sample(); + EXPECT_LONGS_EQUAL(1, actual1.size()); + EXPECT(assert_equal(b, actual1[X(0)], 50 * sigma)); + + VectorValues values; + values.insert(X(1), x1); + + auto conditional = + GaussianConditional::FromMeanAndStddev(X(0), A1, X(1), b, sigma); + auto actual2 = conditional.sample(values); + EXPECT_LONGS_EQUAL(1, actual2.size()); + EXPECT(assert_equal(A1 * x1 + b, actual2[X(0)], 50 * sigma)); +} + /* ************************************************************************* */ int main() { TestResult tr;