normalizationConstants returns all constants as a DecisionTreeFactor

release/4.3a0
Frank Dellaert 2023-01-12 10:49:30 -08:00
parent 618ac28f2c
commit 34a9aef6f3
3 changed files with 58 additions and 14 deletions

View File

@ -170,21 +170,41 @@ KeyVector GaussianMixture::continuousParents() const {
}
/* ************************************************************************* */
boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
const VectorValues &frontals) const {
// Check that values has all frontals
for (auto &&kv : frontals) {
if (frontals.find(kv.first) == frontals.end()) {
throw std::runtime_error("GaussianMixture: frontals missing factor key.");
boost::shared_ptr<DecisionTreeFactor> GaussianMixture::normalizationConstants()
const {
DecisionTree<Key, double> constants(
conditionals_, [&](const GaussianConditional::shared_ptr &conditional) {
return conditional->normalizationConstant();
});
// If all constants the same, return nullptr:
if (constants.nrLeaves() == 1) return nullptr;
return boost::make_shared<DecisionTreeFactor>(discreteKeys(), constants);
}
/* ************************************************************************* */
bool GaussianMixture::allFrontalsGiven(const VectorValues &given) const {
for (auto &&kv : given) {
if (given.find(kv.first) == given.end()) {
return false;
}
}
return true;
}
/* ************************************************************************* */
boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
const VectorValues &given) const {
if (!allFrontalsGiven(given)) {
throw std::runtime_error(
"GaussianMixture::likelihood: given values are missing some frontals.");
}
const DiscreteKeys discreteParentKeys = discreteKeys();
const KeyVector continuousParentKeys = continuousParents();
const GaussianMixtureFactor::Factors likelihoods(
conditionals_, [&](const GaussianConditional::shared_ptr &conditional) {
return GaussianMixtureFactor::FactorAndConstant{
conditional->likelihood(frontals),
conditional->likelihood(given),
conditional->logNormalizationConstant()};
});
return boost::make_shared<GaussianMixtureFactor>(
@ -285,8 +305,7 @@ AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
return 1e50;
}
};
DecisionTree<Key, double> errorTree(conditionals_, errorFunc);
return errorTree;
return DecisionTree<Key, double>(conditionals_, errorFunc);
}
/* *******************************************************************************/

View File

@ -155,10 +155,16 @@ class GTSAM_EXPORT GaussianMixture
/// Returns the continuous keys among the parents.
KeyVector continuousParents() const;
// Create a likelihood factor for a Gaussian mixture, return a Mixture factor
// on the parents.
/// Return a discrete factor with possibly varying normalization constants.
/// If there is no variation, return nullptr.
boost::shared_ptr<DecisionTreeFactor> normalizationConstants() const;
/**
* Create a likelihood factor for a Gaussian mixture, return a Mixture factor
* on the parents.
*/
boost::shared_ptr<GaussianMixtureFactor> likelihood(
const VectorValues &frontals) const;
const VectorValues &given) const;
/// Getter for the underlying Conditionals DecisionTree
const Conditionals &conditionals() const;
@ -233,6 +239,9 @@ class GTSAM_EXPORT GaussianMixture
/// @}
private:
/// Check whether `given` has values for all frontal keys.
bool allFrontalsGiven(const VectorValues &given) const;
/** Serialization function */
friend class boost::serialization::access;
template <class Archive>

View File

@ -106,13 +106,16 @@ TEST(GaussianMixture, Error) {
conditional1 = boost::make_shared<GaussianConditional>(X(1), d2, R2,
X(2), S2, model);
// Create decision tree
// Create Gaussian Mixture.
DiscreteKey m1(M(1), 2);
GaussianMixture::Conditionals conditionals(
{m1},
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
GaussianMixture mixture({X(1)}, {X(2)}, {m1}, conditionals);
// Check that normalizationConstants returns nullptr, as all constants equal.
CHECK(!mixture.normalizationConstants());
VectorValues values;
values.insert(X(1), Vector2::Ones());
values.insert(X(2), Vector2::Zero());
@ -163,6 +166,19 @@ TEST(GaussianMixture, ContinuousParents) {
EXPECT(continuousParentKeys[0] == X(0));
}
/* ************************************************************************* */
/// Check we can create a DecisionTreeFactor with all normalization constants.
TEST(GaussianMixture, NormalizationConstants) {
const GaussianMixture gm = createSimpleGaussianMixture();
const auto factor = gm.normalizationConstants();
// Test with 1D Gaussian normalization constants for sigma 0.5 and 3:
auto c = [](double sigma) { return 1.0 / (sqrt(2 * M_PI) * sigma); };
const DecisionTreeFactor expected({M(0), 2}, {c(0.5), c(3)});
EXPECT(assert_equal(expected, *factor));
}
/* ************************************************************************* */
/// Check that likelihood returns a mixture factor on the parents.
TEST(GaussianMixture, Likelihood) {
@ -186,7 +202,7 @@ TEST(GaussianMixture, Likelihood) {
conditional->logNormalizationConstant()};
});
const GaussianMixtureFactor expected({X(0)}, {mode}, factors);
EXPECT(assert_equal(*factor, expected));
EXPECT(assert_equal(expected, *factor));
}
/* ************************************************************************* */