Fixed toFactorGraph and added test to verify
parent
3a446d7008
commit
202a5a3264
|
@ -347,8 +347,6 @@ HybridGaussianFactorGraph HybridBayesNet::toFactorGraph(
|
||||||
fg.push_back(gc->likelihood(measurements));
|
fg.push_back(gc->likelihood(measurements));
|
||||||
} else if (auto gm = conditional->asMixture()) {
|
} else if (auto gm = conditional->asMixture()) {
|
||||||
fg.push_back(gm->likelihood(measurements));
|
fg.push_back(gm->likelihood(measurements));
|
||||||
const auto constantsFactor = gm->normalizationConstants();
|
|
||||||
if (constantsFactor) fg.push_back(constantsFactor);
|
|
||||||
} else {
|
} else {
|
||||||
throw std::runtime_error("Unknown conditional type");
|
throw std::runtime_error("Unknown conditional type");
|
||||||
}
|
}
|
||||||
|
|
|
@ -77,9 +77,17 @@ TEST(HybridBayesNet, Tiny) {
|
||||||
auto bn = tiny::createHybridBayesNet();
|
auto bn = tiny::createHybridBayesNet();
|
||||||
EXPECT_LONGS_EQUAL(3, bn.size());
|
EXPECT_LONGS_EQUAL(3, bn.size());
|
||||||
|
|
||||||
const VectorValues measurements{{Z(0), Vector1(5.0)}};
|
const VectorValues vv{{Z(0), Vector1(5.0)}, {X(0), Vector1(5.0)}};
|
||||||
auto fg = bn.toFactorGraph(measurements);
|
auto fg = bn.toFactorGraph(vv);
|
||||||
EXPECT_LONGS_EQUAL(4, fg.size());
|
EXPECT_LONGS_EQUAL(3, fg.size());
|
||||||
|
|
||||||
|
// Check that the ratio of probPrime to evaluate is the same for all modes.
|
||||||
|
std::vector<double> ratio(2);
|
||||||
|
for (size_t mode : {0, 1}) {
|
||||||
|
const HybridValues hv{vv, {{M(0), mode}}};
|
||||||
|
ratio[mode] = std::exp(-fg.error(hv)) / bn.evaluate(hv);
|
||||||
|
}
|
||||||
|
EXPECT_DOUBLES_EQUAL(ratio[0], ratio[1], 1e-8);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
|
|
Loading…
Reference in New Issue