normalizationConstants returns all constants as a DecisionTreeFactor

parent 618ac28f2c
commit 34a9aef6f3
@@ -170,21 +170,41 @@ KeyVector GaussianMixture::continuousParents() const {
 }
 
 /* ************************************************************************* */
-boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
-    const VectorValues &frontals) const {
-  // Check that values has all frontals
-  for (auto &&kv : frontals) {
-    if (frontals.find(kv.first) == frontals.end()) {
-      throw std::runtime_error("GaussianMixture: frontals missing factor key.");
-    }
-  }
+boost::shared_ptr<DecisionTreeFactor> GaussianMixture::normalizationConstants()
+    const {
+  DecisionTree<Key, double> constants(
+      conditionals_, [&](const GaussianConditional::shared_ptr &conditional) {
+        return conditional->normalizationConstant();
+      });
+  // If all constants the same, return nullptr:
+  if (constants.nrLeaves() == 1) return nullptr;
+  return boost::make_shared<DecisionTreeFactor>(discreteKeys(), constants);
+}
+
+/* ************************************************************************* */
+bool GaussianMixture::allFrontalsGiven(const VectorValues &given) const {
+  for (auto &&kv : given) {
+    if (given.find(kv.first) == given.end()) {
+      return false;
+    }
+  }
+  return true;
+}
+
+/* ************************************************************************* */
+boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
+    const VectorValues &given) const {
+  if (!allFrontalsGiven(given)) {
+    throw std::runtime_error(
+        "GaussianMixture::likelihood: given values are missing some frontals.");
+  }
 
   const DiscreteKeys discreteParentKeys = discreteKeys();
   const KeyVector continuousParentKeys = continuousParents();
   const GaussianMixtureFactor::Factors likelihoods(
       conditionals_, [&](const GaussianConditional::shared_ptr &conditional) {
        return GaussianMixtureFactor::FactorAndConstant{
-            conditional->likelihood(frontals),
+            conditional->likelihood(given),
             conditional->logNormalizationConstant()};
       });
   return boost::make_shared<GaussianMixtureFactor>(
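For context, a minimal usage sketch of the two methods touched by this hunk. It is not part of the commit: it mirrors the fixture in the test changes below, and the concrete matrices, dimensions, and noise sigmas are illustrative assumptions.

// Sketch only: build a two-component GaussianMixture P(x1 | x2, m1) with
// differing sigmas and exercise normalizationConstants() and likelihood().
#include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/GaussianConditional.h>
#include <boost/make_shared.hpp>
#include <vector>

using namespace gtsam;
using symbol_shorthand::M;
using symbol_shorthand::X;

int main() {
  // Two conditionals that differ only in their noise sigma, so their
  // normalization constants differ across the discrete mode m1.
  const Matrix2 R = Matrix2::Identity(), S = Matrix2::Identity();
  const Vector2 d = Vector2::Zero();
  auto conditional0 = boost::make_shared<GaussianConditional>(
      X(1), d, R, X(2), S, noiseModel::Isotropic::Sigma(2, 0.5));
  auto conditional1 = boost::make_shared<GaussianConditional>(
      X(1), d, R, X(2), S, noiseModel::Isotropic::Sigma(2, 3.0));

  const DiscreteKey m1(M(1), 2);
  GaussianMixture::Conditionals conditionals(
      {m1}, std::vector<GaussianConditional::shared_ptr>{conditional0,
                                                         conditional1});
  const GaussianMixture mixture({X(1)}, {X(2)}, {m1}, conditionals);

  // Non-null here because the constants vary with the mode; it would be
  // nullptr if both components used the same sigma.
  const auto constants = mixture.normalizationConstants();

  // likelihood() requires values for every frontal key, X(1) in this sketch.
  VectorValues given;
  given.insert(X(1), Vector2::Ones());
  const auto likelihoodFactor = mixture.likelihood(given);

  return constants && likelihoodFactor ? 0 : 1;
}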
@@ -285,8 +305,7 @@ AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
       return 1e50;
     }
   };
-  DecisionTree<Key, double> errorTree(conditionals_, errorFunc);
-  return errorTree;
+  return DecisionTree<Key, double>(conditionals_, errorFunc);
 }
 
 /* *******************************************************************************/
@@ -155,10 +155,16 @@ class GTSAM_EXPORT GaussianMixture
   /// Returns the continuous keys among the parents.
   KeyVector continuousParents() const;
 
-  // Create a likelihood factor for a Gaussian mixture, return a Mixture factor
-  // on the parents.
+  /// Return a discrete factor with possibly varying normalization constants.
+  /// If there is no variation, return nullptr.
+  boost::shared_ptr<DecisionTreeFactor> normalizationConstants() const;
+
+  /**
+   * Create a likelihood factor for a Gaussian mixture, return a Mixture factor
+   * on the parents.
+   */
   boost::shared_ptr<GaussianMixtureFactor> likelihood(
-      const VectorValues &frontals) const;
+      const VectorValues &given) const;
 
   /// Getter for the underlying Conditionals DecisionTree
   const Conditionals &conditionals() const;
@@ -233,6 +239,9 @@ class GTSAM_EXPORT GaussianMixture
   /// @}
 
  private:
+  /// Check whether `given` has values for all frontal keys.
+  bool allFrontalsGiven(const VectorValues &given) const;
+
   /** Serialization function */
   friend class boost::serialization::access;
   template <class Archive>
@@ -106,13 +106,16 @@ TEST(GaussianMixture, Error) {
   conditional1 = boost::make_shared<GaussianConditional>(X(1), d2, R2,
                                                          X(2), S2, model);
 
-  // Create decision tree
+  // Create Gaussian Mixture.
   DiscreteKey m1(M(1), 2);
   GaussianMixture::Conditionals conditionals(
       {m1},
       vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
   GaussianMixture mixture({X(1)}, {X(2)}, {m1}, conditionals);
 
+  // Check that normalizationConstants returns nullptr, as all constants equal.
+  CHECK(!mixture.normalizationConstants());
+
   VectorValues values;
   values.insert(X(1), Vector2::Ones());
   values.insert(X(2), Vector2::Zero());
@@ -163,6 +166,19 @@ TEST(GaussianMixture, ContinuousParents) {
   EXPECT(continuousParentKeys[0] == X(0));
 }
 
+/* ************************************************************************* */
+/// Check we can create a DecisionTreeFactor with all normalization constants.
+TEST(GaussianMixture, NormalizationConstants) {
+  const GaussianMixture gm = createSimpleGaussianMixture();
+
+  const auto factor = gm.normalizationConstants();
+
+  // Test with 1D Gaussian normalization constants for sigma 0.5 and 3:
+  auto c = [](double sigma) { return 1.0 / (sqrt(2 * M_PI) * sigma); };
+  const DecisionTreeFactor expected({M(0), 2}, {c(0.5), c(3)});
+  EXPECT(assert_equal(expected, *factor));
+}
+
 /* ************************************************************************* */
 /// Check that likelihood returns a mixture factor on the parents.
 TEST(GaussianMixture, Likelihood) {
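The expected leaves in the new NormalizationConstants test follow from the 1D Gaussian normalization constant 1 / (sqrt(2*pi) * sigma). A standalone sketch (not part of the commit) that reproduces the two numbers:

// Sketch only: evaluate the constants used as expected leaves in
// TEST(GaussianMixture, NormalizationConstants) above.
// Note: M_PI may require _USE_MATH_DEFINES on some compilers.
#include <cmath>
#include <cstdio>

int main() {
  auto c = [](double sigma) { return 1.0 / (std::sqrt(2 * M_PI) * sigma); };
  std::printf("c(0.5) = %.4f\n", c(0.5));  // ~0.7979
  std::printf("c(3.0) = %.4f\n", c(3.0));  // ~0.1330
  return 0;
}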
@@ -186,7 +202,7 @@ TEST(GaussianMixture, Likelihood) {
             conditional->logNormalizationConstant()};
       });
   const GaussianMixtureFactor expected({X(0)}, {mode}, factors);
-  EXPECT(assert_equal(*factor, expected));
+  EXPECT(assert_equal(expected, *factor));
 }
 
 /* ************************************************************************* */