Sampling from GaussianConditional
parent
bc233fc967
commit
f9e6282a2c
|
@ -18,6 +18,7 @@
|
||||||
#include <gtsam/linear/linearExceptions.h>
|
#include <gtsam/linear/linearExceptions.h>
|
||||||
#include <gtsam/linear/GaussianConditional.h>
|
#include <gtsam/linear/GaussianConditional.h>
|
||||||
#include <gtsam/linear/VectorValues.h>
|
#include <gtsam/linear/VectorValues.h>
|
||||||
|
#include <gtsam/linear/Sampler.h>
|
||||||
|
|
||||||
#include <boost/format.hpp>
|
#include <boost/format.hpp>
|
||||||
#ifdef __GNUC__
|
#ifdef __GNUC__
|
||||||
|
@ -220,7 +221,35 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
|
VectorValues GaussianConditional::sample(
|
||||||
|
const VectorValues& parentsValues) const {
|
||||||
|
if (nrFrontals() != 1) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"GaussianConditional::sample can only be called on single variable "
|
||||||
|
"conditionals");
|
||||||
|
}
|
||||||
|
if (!model_) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"GaussianConditional::sample can only be called if a diagonal noise "
|
||||||
|
"model was specified at construction.");
|
||||||
|
}
|
||||||
|
VectorValues solution = solve(parentsValues);
|
||||||
|
Sampler sampler(model_);
|
||||||
|
Key key = firstFrontalKey();
|
||||||
|
solution[key] += sampler.sample();
|
||||||
|
return solution;
|
||||||
|
}
|
||||||
|
|
||||||
|
VectorValues GaussianConditional::sample() const {
|
||||||
|
if (nrParents() != 0)
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"sample() can only be invoked on no-parent prior");
|
||||||
|
VectorValues values;
|
||||||
|
return sample(values);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||||
void GTSAM_DEPRECATED
|
void GTSAM_DEPRECATED
|
||||||
GaussianConditional::scaleFrontalsBySigma(VectorValues& gy) const {
|
GaussianConditional::scaleFrontalsBySigma(VectorValues& gy) const {
|
||||||
|
|
|
@ -44,6 +44,9 @@ namespace gtsam {
|
||||||
typedef JacobianFactor BaseFactor; ///< Typedef to our factor base class
|
typedef JacobianFactor BaseFactor; ///< Typedef to our factor base class
|
||||||
typedef Conditional<BaseFactor, This> BaseConditional; ///< Typedef to our conditional base class
|
typedef Conditional<BaseFactor, This> BaseConditional; ///< Typedef to our conditional base class
|
||||||
|
|
||||||
|
/// @name Constructors
|
||||||
|
/// @{
|
||||||
|
|
||||||
/** default constructor needed for serialization */
|
/** default constructor needed for serialization */
|
||||||
GaussianConditional() {}
|
GaussianConditional() {}
|
||||||
|
|
||||||
|
@ -99,6 +102,10 @@ namespace gtsam {
|
||||||
template<typename ITERATOR>
|
template<typename ITERATOR>
|
||||||
static shared_ptr Combine(ITERATOR firstConditional, ITERATOR lastConditional);
|
static shared_ptr Combine(ITERATOR firstConditional, ITERATOR lastConditional);
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
/// @name Testable
|
||||||
|
/// @{
|
||||||
|
|
||||||
/** print */
|
/** print */
|
||||||
void print(const std::string& = "GaussianConditional",
|
void print(const std::string& = "GaussianConditional",
|
||||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||||
|
@ -106,6 +113,10 @@ namespace gtsam {
|
||||||
/** equals function */
|
/** equals function */
|
||||||
bool equals(const GaussianFactor&cg, double tol = 1e-9) const override;
|
bool equals(const GaussianFactor&cg, double tol = 1e-9) const override;
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
/// @name Standard Interface
|
||||||
|
/// @{
|
||||||
|
|
||||||
/** Return a view of the upper-triangular R block of the conditional */
|
/** Return a view of the upper-triangular R block of the conditional */
|
||||||
constABlock R() const { return Ab_.range(0, nrFrontals()); }
|
constABlock R() const { return Ab_.range(0, nrFrontals()); }
|
||||||
|
|
||||||
|
@ -138,10 +149,25 @@ namespace gtsam {
|
||||||
/** Performs transpose backsubstition in place on values */
|
/** Performs transpose backsubstition in place on values */
|
||||||
void solveTransposeInPlace(VectorValues& gy) const;
|
void solveTransposeInPlace(VectorValues& gy) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* sample
|
||||||
|
* @param parentsValues Known values of the parents
|
||||||
|
* @return sample from conditional
|
||||||
|
*/
|
||||||
|
VectorValues sample(const VectorValues& parentsValues) const;
|
||||||
|
|
||||||
|
/// Zero parent version.
|
||||||
|
VectorValues sample() const;
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
|
||||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||||
|
/// @name Deprecated
|
||||||
|
/// @{
|
||||||
/** Scale the values in \c gy according to the sigmas for the frontal variables in this
|
/** Scale the values in \c gy according to the sigmas for the frontal variables in this
|
||||||
* conditional. */
|
* conditional. */
|
||||||
void GTSAM_DEPRECATED scaleFrontalsBySigma(VectorValues& gy) const;
|
void GTSAM_DEPRECATED scaleFrontalsBySigma(VectorValues& gy) const;
|
||||||
|
/// @}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
#include <gtsam/inference/Symbol.h>
|
#include <gtsam/inference/Symbol.h>
|
||||||
#include <gtsam/linear/JacobianFactor.h>
|
#include <gtsam/linear/JacobianFactor.h>
|
||||||
#include <gtsam/linear/GaussianConditional.h>
|
#include <gtsam/linear/GaussianConditional.h>
|
||||||
|
#include <gtsam/linear/GaussianDensity.h>
|
||||||
#include <gtsam/linear/GaussianBayesNet.h>
|
#include <gtsam/linear/GaussianBayesNet.h>
|
||||||
|
|
||||||
#include <boost/assign/std/list.hpp>
|
#include <boost/assign/std/list.hpp>
|
||||||
|
@ -339,6 +340,28 @@ TEST(GaussianConditional, FromMeanAndStddev) {
|
||||||
EXPECT_DOUBLES_EQUAL(expected2, conditional2.error(values), 1e-9);
|
EXPECT_DOUBLES_EQUAL(expected2, conditional2.error(values), 1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Test sampling
|
||||||
|
TEST(GaussianConditional, sample) {
|
||||||
|
Matrix A1 = (Matrix(2, 2) << 1., 2., 3., 4.).finished();
|
||||||
|
const Vector2 b(20, 40), x1(3, 4);
|
||||||
|
const double sigma = 0.01;
|
||||||
|
|
||||||
|
auto density = GaussianDensity::FromMeanAndStddev(X(0), b, sigma);
|
||||||
|
auto actual1 = density.sample();
|
||||||
|
EXPECT_LONGS_EQUAL(1, actual1.size());
|
||||||
|
EXPECT(assert_equal(b, actual1[X(0)], 50 * sigma));
|
||||||
|
|
||||||
|
VectorValues values;
|
||||||
|
values.insert(X(1), x1);
|
||||||
|
|
||||||
|
auto conditional =
|
||||||
|
GaussianConditional::FromMeanAndStddev(X(0), A1, X(1), b, sigma);
|
||||||
|
auto actual2 = conditional.sample(values);
|
||||||
|
EXPECT_LONGS_EQUAL(1, actual2.size());
|
||||||
|
EXPECT(assert_equal(A1 * x1 + b, actual2[X(0)], 50 * sigma));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
Loading…
Reference in New Issue