release/4.3a0
Frank Dellaert 2022-12-31 18:27:13 -05:00
parent 4023e719ad
commit c8008cbb7c
2 changed files with 39 additions and 2 deletions

View File

@ -17,6 +17,7 @@
*/
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/inference/Symbol.h>
#pragma once
@ -59,5 +60,26 @@ static HybridBayesNet createHybridBayesNet(int num_measurements = 1) {
return bayesNet;
}
static HybridGaussianFactorGraph convertBayesNet(const HybridBayesNet& bayesNet,
const HybridValues& sample) {
HybridGaussianFactorGraph fg;
int num_measurements = bayesNet.size() - 2;
for (int i = 0; i < num_measurements; i++) {
auto conditional = bayesNet.atMixture(i);
auto factor = conditional->likelihood(sample.continuousSubset({Z(i)}));
fg.push_back(factor);
}
fg.push_back(bayesNet.atGaussian(num_measurements));
fg.push_back(bayesNet.atDiscrete(num_measurements + 1));
return fg;
}
static HybridGaussianFactorGraph createHybridGaussianFactorGraph(
int num_measurements = 1) {
auto bayesNet = createHybridBayesNet(num_measurements);
auto sample = bayesNet.sample();
return convertBayesNet(bayesNet, sample);
}
} // namespace tiny
} // namespace gtsam

View File

@ -47,6 +47,7 @@
#include <vector>
#include "Switching.h"
#include "TinyHybridExample.h"
using namespace std;
using namespace gtsam;
@ -133,8 +134,8 @@ TEST(HybridGaussianFactorGraph, eliminateFullSequentialEqualChance) {
auto dc = result->at(2)->asDiscrete();
DiscreteValues dv;
dv[M(1)] = 0;
// regression
EXPECT_DOUBLES_EQUAL(8.5730017810851127, dc->operator()(dv), 1e-3);
// Regression test
EXPECT_DOUBLES_EQUAL(0.62245933120185448, dc->operator()(dv), 1e-3);
}
/* ************************************************************************* */
@ -613,6 +614,20 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) {
EXPECT(assert_equal(expected_probs, probs, 1e-7));
}
/* ****************************************************************************/
// Test creation of a tiny hybrid Bayes net.
TEST(HybridBayesNet, Tiny) {
auto fg = tiny::createHybridGaussianFactorGraph();
EXPECT_LONGS_EQUAL(3, fg.size());
}
/* ****************************************************************************/
// // Test summing frontals
// TEST(HybridGaussianFactorGraph, SumFrontals) {
// HybridGaussianFactorGraph fg;
// fg.
// }
/* ************************************************************************* */
int main() {
TestResult tr;