Sampling from GaussianConditional
parent
bc233fc967
commit
f9e6282a2c
|
@ -18,6 +18,7 @@
|
|||
#include <gtsam/linear/linearExceptions.h>
|
||||
#include <gtsam/linear/GaussianConditional.h>
|
||||
#include <gtsam/linear/VectorValues.h>
|
||||
#include <gtsam/linear/Sampler.h>
|
||||
|
||||
#include <boost/format.hpp>
|
||||
#ifdef __GNUC__
|
||||
|
@ -220,7 +221,35 @@ namespace gtsam {
|
|||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
VectorValues GaussianConditional::sample(
|
||||
const VectorValues& parentsValues) const {
|
||||
if (nrFrontals() != 1) {
|
||||
throw std::invalid_argument(
|
||||
"GaussianConditional::sample can only be called on single variable "
|
||||
"conditionals");
|
||||
}
|
||||
if (!model_) {
|
||||
throw std::invalid_argument(
|
||||
"GaussianConditional::sample can only be called if a diagonal noise "
|
||||
"model was specified at construction.");
|
||||
}
|
||||
VectorValues solution = solve(parentsValues);
|
||||
Sampler sampler(model_);
|
||||
Key key = firstFrontalKey();
|
||||
solution[key] += sampler.sample();
|
||||
return solution;
|
||||
}
|
||||
|
||||
VectorValues GaussianConditional::sample() const {
|
||||
if (nrParents() != 0)
|
||||
throw std::invalid_argument(
|
||||
"sample() can only be invoked on no-parent prior");
|
||||
VectorValues values;
|
||||
return sample(values);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||
void GTSAM_DEPRECATED
|
||||
GaussianConditional::scaleFrontalsBySigma(VectorValues& gy) const {
|
||||
|
|
|
@ -44,6 +44,9 @@ namespace gtsam {
|
|||
typedef JacobianFactor BaseFactor; ///< Typedef to our factor base class
|
||||
typedef Conditional<BaseFactor, This> BaseConditional; ///< Typedef to our conditional base class
|
||||
|
||||
/// @name Constructors
|
||||
/// @{
|
||||
|
||||
/** default constructor needed for serialization */
|
||||
GaussianConditional() {}
|
||||
|
||||
|
@ -99,6 +102,10 @@ namespace gtsam {
|
|||
template<typename ITERATOR>
|
||||
static shared_ptr Combine(ITERATOR firstConditional, ITERATOR lastConditional);
|
||||
|
||||
/// @}
|
||||
/// @name Testable
|
||||
/// @{
|
||||
|
||||
/** print */
|
||||
void print(const std::string& = "GaussianConditional",
|
||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||
|
@ -106,6 +113,10 @@ namespace gtsam {
|
|||
/** equals function */
|
||||
bool equals(const GaussianFactor&cg, double tol = 1e-9) const override;
|
||||
|
||||
/// @}
|
||||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
/** Return a view of the upper-triangular R block of the conditional */
|
||||
constABlock R() const { return Ab_.range(0, nrFrontals()); }
|
||||
|
||||
|
@ -138,10 +149,25 @@ namespace gtsam {
|
|||
/** Performs transpose backsubstition in place on values */
|
||||
void solveTransposeInPlace(VectorValues& gy) const;
|
||||
|
||||
/**
|
||||
* sample
|
||||
* @param parentsValues Known values of the parents
|
||||
* @return sample from conditional
|
||||
*/
|
||||
VectorValues sample(const VectorValues& parentsValues) const;
|
||||
|
||||
/// Zero parent version.
|
||||
VectorValues sample() const;
|
||||
|
||||
/// @}
|
||||
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||
/// @name Deprecated
|
||||
/// @{
|
||||
/** Scale the values in \c gy according to the sigmas for the frontal variables in this
|
||||
* conditional. */
|
||||
void GTSAM_DEPRECATED scaleFrontalsBySigma(VectorValues& gy) const;
|
||||
/// @}
|
||||
#endif
|
||||
|
||||
private:
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include <gtsam/inference/Symbol.h>
|
||||
#include <gtsam/linear/JacobianFactor.h>
|
||||
#include <gtsam/linear/GaussianConditional.h>
|
||||
#include <gtsam/linear/GaussianDensity.h>
|
||||
#include <gtsam/linear/GaussianBayesNet.h>
|
||||
|
||||
#include <boost/assign/std/list.hpp>
|
||||
|
@ -339,6 +340,28 @@ TEST(GaussianConditional, FromMeanAndStddev) {
|
|||
EXPECT_DOUBLES_EQUAL(expected2, conditional2.error(values), 1e-9);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Test sampling
|
||||
TEST(GaussianConditional, sample) {
|
||||
Matrix A1 = (Matrix(2, 2) << 1., 2., 3., 4.).finished();
|
||||
const Vector2 b(20, 40), x1(3, 4);
|
||||
const double sigma = 0.01;
|
||||
|
||||
auto density = GaussianDensity::FromMeanAndStddev(X(0), b, sigma);
|
||||
auto actual1 = density.sample();
|
||||
EXPECT_LONGS_EQUAL(1, actual1.size());
|
||||
EXPECT(assert_equal(b, actual1[X(0)], 50 * sigma));
|
||||
|
||||
VectorValues values;
|
||||
values.insert(X(1), x1);
|
||||
|
||||
auto conditional =
|
||||
GaussianConditional::FromMeanAndStddev(X(0), A1, X(1), b, sigma);
|
||||
auto actual2 = conditional.sample(values);
|
||||
EXPECT_LONGS_EQUAL(1, actual2.size());
|
||||
EXPECT(assert_equal(A1 * x1 + b, actual2[X(0)], 50 * sigma));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
Loading…
Reference in New Issue