Sampling from GaussianConditional

release/4.3a0
Frank Dellaert 2022-02-06 17:31:13 -05:00
parent bc233fc967
commit f9e6282a2c
3 changed files with 79 additions and 1 deletions

View File

@ -18,6 +18,7 @@
#include <gtsam/linear/linearExceptions.h>
#include <gtsam/linear/GaussianConditional.h>
#include <gtsam/linear/VectorValues.h>
#include <gtsam/linear/Sampler.h>
#include <boost/format.hpp>
#ifdef __GNUC__
@ -220,7 +221,35 @@ namespace gtsam {
}
}
/* ************************************************************************* */
/* ************************************************************************ */
VectorValues GaussianConditional::sample(
const VectorValues& parentsValues) const {
if (nrFrontals() != 1) {
throw std::invalid_argument(
"GaussianConditional::sample can only be called on single variable "
"conditionals");
}
if (!model_) {
throw std::invalid_argument(
"GaussianConditional::sample can only be called if a diagonal noise "
"model was specified at construction.");
}
VectorValues solution = solve(parentsValues);
Sampler sampler(model_);
Key key = firstFrontalKey();
solution[key] += sampler.sample();
return solution;
}
VectorValues GaussianConditional::sample() const {
if (nrParents() != 0)
throw std::invalid_argument(
"sample() can only be invoked on no-parent prior");
VectorValues values;
return sample(values);
}
/* ************************************************************************ */
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
void GTSAM_DEPRECATED
GaussianConditional::scaleFrontalsBySigma(VectorValues& gy) const {

View File

@ -44,6 +44,9 @@ namespace gtsam {
typedef JacobianFactor BaseFactor; ///< Typedef to our factor base class
typedef Conditional<BaseFactor, This> BaseConditional; ///< Typedef to our conditional base class
/// @name Constructors
/// @{
/** default constructor needed for serialization */
GaussianConditional() {}
@ -99,6 +102,10 @@ namespace gtsam {
template<typename ITERATOR>
static shared_ptr Combine(ITERATOR firstConditional, ITERATOR lastConditional);
/// @}
/// @name Testable
/// @{
/** print */
void print(const std::string& = "GaussianConditional",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
@ -106,6 +113,10 @@ namespace gtsam {
/** equals function */
bool equals(const GaussianFactor&cg, double tol = 1e-9) const override;
/// @}
/// @name Standard Interface
/// @{
/** Return a view of the upper-triangular R block of the conditional */
constABlock R() const { return Ab_.range(0, nrFrontals()); }
@ -138,10 +149,25 @@ namespace gtsam {
/** Performs transpose backsubstition in place on values */
void solveTransposeInPlace(VectorValues& gy) const;
/**
* sample
* @param parentsValues Known values of the parents
* @return sample from conditional
*/
VectorValues sample(const VectorValues& parentsValues) const;
/// Zero parent version.
VectorValues sample() const;
/// @}
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
/// @name Deprecated
/// @{
/** Scale the values in \c gy according to the sigmas for the frontal variables in this
* conditional. */
void GTSAM_DEPRECATED scaleFrontalsBySigma(VectorValues& gy) const;
/// @}
#endif
private:

View File

@ -23,6 +23,7 @@
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/JacobianFactor.h>
#include <gtsam/linear/GaussianConditional.h>
#include <gtsam/linear/GaussianDensity.h>
#include <gtsam/linear/GaussianBayesNet.h>
#include <boost/assign/std/list.hpp>
@ -339,6 +340,28 @@ TEST(GaussianConditional, FromMeanAndStddev) {
EXPECT_DOUBLES_EQUAL(expected2, conditional2.error(values), 1e-9);
}
/* ************************************************************************* */
// Test sampling
TEST(GaussianConditional, sample) {
Matrix A1 = (Matrix(2, 2) << 1., 2., 3., 4.).finished();
const Vector2 b(20, 40), x1(3, 4);
const double sigma = 0.01;
auto density = GaussianDensity::FromMeanAndStddev(X(0), b, sigma);
auto actual1 = density.sample();
EXPECT_LONGS_EQUAL(1, actual1.size());
EXPECT(assert_equal(b, actual1[X(0)], 50 * sigma));
VectorValues values;
values.insert(X(1), x1);
auto conditional =
GaussianConditional::FromMeanAndStddev(X(0), A1, X(1), b, sigma);
auto actual2 = conditional.sample(values);
EXPECT_LONGS_EQUAL(1, actual2.size());
EXPECT(assert_equal(A1 * x1 + b, actual2[X(0)], 50 * sigma));
}
/* ************************************************************************* */
int main() {
TestResult tr;