Fix SumFrontals test

release/4.3a0
Frank Dellaert 2023-01-01 11:43:52 -05:00
parent b09495376b
commit 4cb03b303b
2 changed files with 11 additions and 7 deletions

View File

@ -35,7 +35,7 @@ const DiscreteKey mode{M(0), 2};
* Create a tiny two variable hybrid model which represents * Create a tiny two variable hybrid model which represents
* the generative probability P(z, x, n) = P(z | x, n)P(x)P(n). * the generative probability P(z, x, n) = P(z | x, n)P(x)P(n).
*/ */
static HybridBayesNet createHybridBayesNet(int num_measurements = 1) { HybridBayesNet createHybridBayesNet(int num_measurements = 1) {
// Create hybrid Bayes net. // Create hybrid Bayes net.
HybridBayesNet bayesNet; HybridBayesNet bayesNet;
@ -60,8 +60,8 @@ static HybridBayesNet createHybridBayesNet(int num_measurements = 1) {
return bayesNet; return bayesNet;
} }
static HybridGaussianFactorGraph convertBayesNet(const HybridBayesNet& bayesNet, HybridGaussianFactorGraph convertBayesNet(const HybridBayesNet& bayesNet,
const HybridValues& sample) { const HybridValues& sample) {
HybridGaussianFactorGraph fg; HybridGaussianFactorGraph fg;
int num_measurements = bayesNet.size() - 2; int num_measurements = bayesNet.size() - 2;
for (int i = 0; i < num_measurements; i++) { for (int i = 0; i < num_measurements; i++) {
@ -74,7 +74,7 @@ static HybridGaussianFactorGraph convertBayesNet(const HybridBayesNet& bayesNet,
return fg; return fg;
} }
static HybridGaussianFactorGraph createHybridGaussianFactorGraph( HybridGaussianFactorGraph createHybridGaussianFactorGraph(
int num_measurements = 1) { int num_measurements = 1) {
auto bayesNet = createHybridBayesNet(num_measurements); auto bayesNet = createHybridBayesNet(num_measurements);
auto sample = bayesNet.sample(); auto sample = bayesNet.sample();

View File

@ -636,10 +636,14 @@ TEST(HybridGaussianFactorGraph, SumFrontals) {
// Expected decision tree with two factor graphs: // Expected decision tree with two factor graphs:
// f(x0;mode=0)P(x0) and f(x0;mode=1)P(x0) // f(x0;mode=0)P(x0) and f(x0;mode=1)P(x0)
GaussianMixture::Sum expected{ GaussianMixture::Sum expected{
M(0), GaussianFactorGraph(std::vector<GF>{mixture->factor(d0), prior}), M(0),
GaussianFactorGraph(std::vector<GF>{mixture->factor(d1), prior})}; {GaussianFactorGraph(std::vector<GF>{mixture->factor(d0), prior}),
-0.225791 /* regression */},
{GaussianFactorGraph(std::vector<GF>{mixture->factor(d1), prior}),
-2.01755 /* regression */}};
EXPECT(assert_equal(expected(d0), sum(d0))); EXPECT(assert_equal(expected(d0), sum(d0), 1e-5));
EXPECT(assert_equal(expected(d1), sum(d1), 1e-5));
} }
/* ************************************************************************* */ /* ************************************************************************* */