/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file testHybridBayesNet.cpp
* @brief Unit tests for HybridBayesNet
* @author Varun Agrawal
* @author Fan Jiang
* @author Frank Dellaert
* @date December 2021
*/
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
#include "Switching.h"
#include "TinyHybridExample.h"
// Include for test suite
#include <CppUnitLite/TestHarness.h>
using namespace std;
using namespace gtsam;
using noiseModel::Isotropic;
using symbol_shorthand::M;
using symbol_shorthand::X;
using symbol_shorthand::Z;
static const Key asiaKey = 0;
static const DiscreteKey Asia(asiaKey, 2);
/* ****************************************************************************/
// Test creation of a pure discrete Bayes net.
TEST(HybridBayesNet, Creation) {
HybridBayesNet bayesNet;
bayesNet.emplace_back(new DiscreteConditional(Asia, "99/1"));
DiscreteConditional expected(Asia, "99/1");
CHECK(bayesNet.at(0)->asDiscrete());
EXPECT(assert_equal(expected, *bayesNet.at(0)->asDiscrete()));
}
/* ****************************************************************************/
// Test adding a Bayes net to another one.
TEST(HybridBayesNet, Add) {
HybridBayesNet bayesNet;
bayesNet.emplace_back(new DiscreteConditional(Asia, "99/1"));
HybridBayesNet other;
other.add(bayesNet);
EXPECT(bayesNet.equals(other));
}
/* ****************************************************************************/
// Test evaluate for a pure discrete Bayes net P(Asia).
TEST(HybridBayesNet, EvaluatePureDiscrete) {
HybridBayesNet bayesNet;
bayesNet.emplace_back(new DiscreteConditional(Asia, "4/6"));
HybridValues values;
values.insert(asiaKey, 0);
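// "4/6" normalizes to P(Asia=0) = 0.4 and P(Asia=1) = 0.6, so evaluating at Asia=0 should give 0.4.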
EXPECT_DOUBLES_EQUAL(0.4, bayesNet.evaluate(values), 1e-9);
}
/* ****************************************************************************/
// Test creation of a tiny hybrid Bayes net.
TEST(HybridBayesNet, Tiny) {
auto bn = tiny::createHybridBayesNet();
EXPECT_LONGS_EQUAL(3, bn.size());
const VectorValues measurements{{Z(0), Vector1(5.0)}};
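// Convert to a factor graph, conditioning on the given measurement for Z(0).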
auto fg = bn.toFactorGraph(measurements);
EXPECT_LONGS_EQUAL(4, fg.size());
}
/* ****************************************************************************/
// Test evaluate for a hybrid Bayes net P(X0|X1) P(X1|Asia) P(Asia).
TEST(HybridBayesNet, evaluateHybrid) {
const auto continuousConditional = GaussianConditional::sharedMeanAndStddev(
X(0), 2 * I_1x1, X(1), Vector1(-4.0), 5.0);
const SharedDiagonal model0 = noiseModel::Diagonal::Sigmas(Vector1(2.0)),
model1 = noiseModel::Diagonal::Sigmas(Vector1(3.0));
const auto conditional0 = boost::make_shared<GaussianConditional>(
X(1), Vector1::Constant(5), I_1x1, model0),
conditional1 = boost::make_shared<GaussianConditional>(
X(1), Vector1::Constant(2), I_1x1, model1);
// Create hybrid Bayes net.
HybridBayesNet bayesNet;
bayesNet.push_back(continuousConditional);
bayesNet.emplace_back(
new GaussianMixture({X(1)}, {}, {Asia}, {conditional0, conditional1}));
bayesNet.emplace_back(new DiscreteConditional(Asia, "99/1"));
// Create values at which to evaluate.
HybridValues values;
values.insert(asiaKey, 0);
values.insert(X(0), Vector1(-6));
values.insert(X(1), Vector1(1));
const double conditionalProbability =
continuousConditional->evaluate(values.continuous());
const double mixtureProbability = conditional0->evaluate(values.continuous());
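// With Asia=0 the mixture selects conditional0, and P(Asia=0) = 0.99,
// so the joint density is the product of the three terms below.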
EXPECT_DOUBLES_EQUAL(conditionalProbability * mixtureProbability * 0.99,
bayesNet.evaluate(values), 1e-9);
}
/* ****************************************************************************/
// Test choosing a Gaussian Bayes net given an assignment of the discrete modes
TEST(HybridBayesNet, Choose) {
Switching s(4);
Ordering ordering;
for (auto&& kvp : s.linearizationPoint) {
ordering += kvp.key;
}
HybridBayesNet::shared_ptr hybridBayesNet;
HybridGaussianFactorGraph::shared_ptr remainingFactorGraph;
std::tie(hybridBayesNet, remainingFactorGraph) =
s.linearizedFactorGraph.eliminatePartialSequential(ordering);
DiscreteValues assignment;
assignment[M(0)] = 1;
assignment[M(1)] = 1;
assignment[M(2)] = 0;
GaussianBayesNet gbn = hybridBayesNet->choose(assignment);
EXPECT_LONGS_EQUAL(4, gbn.size());
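// Each conditional in the chosen Bayes net should match the mixture component selected by the assignment.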
EXPECT(assert_equal(*(*hybridBayesNet->at(0)->asMixture())(assignment),
*gbn.at(0)));
EXPECT(assert_equal(*(*hybridBayesNet->at(1)->asMixture())(assignment),
*gbn.at(1)));
EXPECT(assert_equal(*(*hybridBayesNet->at(2)->asMixture())(assignment),
*gbn.at(2)));
EXPECT(assert_equal(*(*hybridBayesNet->at(3)->asMixture())(assignment),
*gbn.at(3)));
}
/* ****************************************************************************/
// Test Bayes net optimize for a fixed discrete assignment
TEST(HybridBayesNet, OptimizeAssignment) {
Switching s(4);
Ordering ordering;
for (auto&& kvp : s.linearizationPoint) {
ordering += kvp.key;
}
HybridBayesNet::shared_ptr hybridBayesNet;
HybridGaussianFactorGraph::shared_ptr remainingFactorGraph;
std::tie(hybridBayesNet, remainingFactorGraph) =
s.linearizedFactorGraph.eliminatePartialSequential(ordering);
DiscreteValues assignment;
assignment[M(0)] = 1;
assignment[M(1)] = 1;
assignment[M(2)] = 1;
VectorValues delta = hybridBayesNet->optimize(assignment);
// The linearization point is X(k) = k + 1, e.g. X(0) = 1, X(1) = 2,
// but the factors specify X(k) = k, so every delta should be -1.
VectorValues expected_delta;
expected_delta.insert(make_pair(X(0), -Vector1::Ones()));
expected_delta.insert(make_pair(X(1), -Vector1::Ones()));
expected_delta.insert(make_pair(X(2), -Vector1::Ones()));
expected_delta.insert(make_pair(X(3), -Vector1::Ones()));
EXPECT(assert_equal(expected_delta, delta));
}
/* ****************************************************************************/
// Test Bayes net optimize
TEST(HybridBayesNet, Optimize) {
Switching s(4, 1.0, 0.1, {0, 1, 2, 3}, "1/1 1/1");
HybridBayesNet::shared_ptr hybridBayesNet =
s.linearizedFactorGraph.eliminateSequential();
HybridValues delta = hybridBayesNet->optimize();
// NOTE: with measurements {0, 1, 2, 3} and uniform mode priors ("1/1 1/1"), the MAP assignment is the true sequence 111.
DiscreteValues expectedAssignment;
expectedAssignment[M(0)] = 1;
expectedAssignment[M(1)] = 1;
expectedAssignment[M(2)] = 1;
EXPECT(assert_equal(expectedAssignment, delta.discrete()));
VectorValues expectedValues;
expectedValues.insert(X(0), -Vector1::Ones());
expectedValues.insert(X(1), -Vector1::Ones());
expectedValues.insert(X(2), -Vector1::Ones());
expectedValues.insert(X(3), -Vector1::Ones());
EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
}
/* ****************************************************************************/
// Test Bayes net logProbability
TEST(HybridBayesNet, logProbability) {
Switching s(3);
HybridBayesNet::shared_ptr hybridBayesNet =
s.linearizedFactorGraph.eliminateSequential();
EXPECT_LONGS_EQUAL(5, hybridBayesNet->size());
HybridValues delta = hybridBayesNet->optimize();
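// logProbability(VectorValues) returns an AlgebraicDecisionTree over the modes M(0) and M(1), with one log-probability value per discrete assignment.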
auto error_tree = hybridBayesNet->logProbability(delta.continuous());
std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
std::vector<double> leaves = {4.1609374, 4.1706942, 4.141568, 4.1609374};
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
// regression
EXPECT(assert_equal(expected_error, error_tree, 1e-6));
// logProbability on pruned Bayes net
auto prunedBayesNet = hybridBayesNet->prune(2);
auto pruned_error_tree = prunedBayesNet.logProbability(delta.continuous());
std::vector<double> pruned_leaves = {2e50, 4.1706942, 2e50, 4.1609374};
AlgebraicDecisionTree<Key> expected_pruned_error(discrete_keys,
pruned_leaves);
// regression
EXPECT(assert_equal(expected_pruned_error, pruned_error_tree, 1e-6));
// Verify logProbability computation and check for specific logProbability
// value
const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}};
const HybridValues hybridValues{delta.continuous(), discrete_values};
double logProbability = 0;
logProbability +=
hybridBayesNet->at(0)->asMixture()->logProbability(hybridValues);
logProbability +=
hybridBayesNet->at(1)->asMixture()->logProbability(hybridValues);
logProbability +=
hybridBayesNet->at(2)->asMixture()->logProbability(hybridValues);
// TODO(dellaert): the discrete errors are not added in logProbability tree!
EXPECT_DOUBLES_EQUAL(logProbability, error_tree(discrete_values), 1e-9);
EXPECT_DOUBLES_EQUAL(logProbability, pruned_error_tree(discrete_values),
1e-9);
logProbability +=
hybridBayesNet->at(3)->asDiscrete()->logProbability(discrete_values);
logProbability +=
hybridBayesNet->at(4)->asDiscrete()->logProbability(discrete_values);
EXPECT_DOUBLES_EQUAL(logProbability,
hybridBayesNet->logProbability(hybridValues), 1e-9);
}
/* ****************************************************************************/
// Test Bayes net pruning
TEST(HybridBayesNet, Prune) {
Switching s(4);
HybridBayesNet::shared_ptr hybridBayesNet =
s.linearizedFactorGraph.eliminateSequential();
HybridValues delta = hybridBayesNet->optimize();
auto prunedBayesNet = hybridBayesNet->prune(2);
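// Pruning to the 2 most probable leaves should not change the MAP solution, since the best assignment survives the prune.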
HybridValues pruned_delta = prunedBayesNet.optimize();
EXPECT(assert_equal(delta.discrete(), pruned_delta.discrete()));
EXPECT(assert_equal(delta.continuous(), pruned_delta.continuous()));
}
/* ****************************************************************************/
// Test Bayes net updateDiscreteConditionals
TEST(HybridBayesNet, UpdateDiscreteConditionals) {
Switching s(4);
HybridBayesNet::shared_ptr hybridBayesNet =
s.linearizedFactorGraph.eliminateSequential();
size_t maxNrLeaves = 3;
auto discreteConditionals = hybridBayesNet->discreteConditionals();
const DecisionTreeFactor::shared_ptr prunedDecisionTree =
boost::make_shared<DecisionTreeFactor>(
discreteConditionals->prune(maxNrLeaves));
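// prune(maxNrLeaves) keeps the maxNrLeaves most probable assignments and zeroes out the rest; here the zeroed assignments collapse into 2 leaves.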
EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/,
prunedDecisionTree->nrLeaves());
auto original_discrete_conditionals = *(hybridBayesNet->at(4)->asDiscrete());
// Prune!
hybridBayesNet->prune(maxNrLeaves);
// Functor to verify values against the original_discrete_conditionals
auto checker = [&](const Assignment<Key>& assignment,
double probability) -> double {
// Typecast so we can use the assignment to look up the probability value.
DiscreteValues choices(assignment);
if (prunedDecisionTree->operator()(choices) == 0) {
EXPECT_DOUBLES_EQUAL(0.0, probability, 1e-9);
} else {
EXPECT_DOUBLES_EQUAL(original_discrete_conditionals(choices), probability,
1e-9);
}
return 0.0;
};
// Get the pruned discrete conditionals as an AlgebraicDecisionTree
auto pruned_discrete_conditionals = hybridBayesNet->at(4)->asDiscrete();
auto discrete_conditional_tree =
boost::dynamic_pointer_cast<DecisionTreeFactor::ADT>(
pruned_discrete_conditionals);
// The checker functor verifies the values for us.
discrete_conditional_tree->apply(checker);
}
/* ****************************************************************************/
// Test HybridBayesNet sampling.
TEST(HybridBayesNet, Sampling) {
HybridNonlinearFactorGraph nfg;
auto noise_model = noiseModel::Diagonal::Sigmas(Vector1(1.0));
auto zero_motion =
boost::make_shared<BetweenFactor<double>>(X(0), X(1), 0, noise_model);
auto one_motion =
boost::make_shared<BetweenFactor<double>>(X(0), X(1), 1, noise_model);
std::vector<NonlinearFactor::shared_ptr> factors = {zero_motion, one_motion};
nfg.emplace_shared<PriorFactor<double>>(X(0), 0.0, noise_model);
nfg.emplace_shared<MixtureFactor>(
KeyVector{X(0), X(1)}, DiscreteKeys{DiscreteKey(M(0), 2)}, factors);
DiscreteKey mode(M(0), 2);
nfg.emplace_shared<DiscreteDistribution>(mode, "1/1");
Values initial;
double z0 = 0.0, z1 = 1.0;
initial.insert<double>(X(0), z0);
initial.insert<double>(X(1), z1);
// Create the factor graph from the nonlinear factor graph.
HybridGaussianFactorGraph::shared_ptr fg = nfg.linearize(initial);
// Eliminate into BN
HybridBayesNet::shared_ptr bn = fg->eliminateSequential();
// Set up sampling
std::mt19937_64 gen(11);
// Initialize containers for computing the mean values.
vector<double> discrete_samples;
VectorValues average_continuous;
size_t num_samples = 1000;
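// Each sample is drawn by ancestral sampling: the discrete mode first, then the continuous variables from the corresponding Gaussian conditionals.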
for (size_t i = 0; i < num_samples; i++) {
// Sample
HybridValues sample = bn->sample(&gen);
discrete_samples.push_back(sample.discrete().at(M(0)));
if (i == 0) {
average_continuous.insert(sample.continuous());
} else {
average_continuous += sample.continuous();
}
}
EXPECT_LONGS_EQUAL(2, average_continuous.size());
EXPECT_LONGS_EQUAL(num_samples, discrete_samples.size());
// Regressions don't work across platforms :-(
// // regression for specific RNG seed
// double discrete_sum =
// std::accumulate(discrete_samples.begin(), discrete_samples.end(),
// decltype(discrete_samples)::value_type(0));
// EXPECT_DOUBLES_EQUAL(0.477, discrete_sum / num_samples, 1e-9);
// VectorValues expected;
// expected.insert({X(0), Vector1(-0.0131207162712)});
// expected.insert({X(1), Vector1(-0.499026377568)});
// // regression for specific RNG seed
// EXPECT(assert_equal(expected, average_continuous.scale(1.0 /
// num_samples)));
}
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */