Merge pull request #1542 from borglab/decisiontree-improvements
commit
b86696a00c
|
|
@ -19,6 +19,7 @@ option(GTSAM_FORCE_STATIC_LIB "Force gtsam to be a static library,
|
||||||
option(GTSAM_USE_QUATERNIONS "Enable/Disable using an internal Quaternion representation for rotations instead of rotation matrices. If enable, Rot3::EXPMAP is enforced by default." OFF)
|
option(GTSAM_USE_QUATERNIONS "Enable/Disable using an internal Quaternion representation for rotations instead of rotation matrices. If enable, Rot3::EXPMAP is enforced by default." OFF)
|
||||||
option(GTSAM_POSE3_EXPMAP "Enable/Disable using Pose3::EXPMAP as the default mode. If disabled, Pose3::FIRST_ORDER will be used." ON)
|
option(GTSAM_POSE3_EXPMAP "Enable/Disable using Pose3::EXPMAP as the default mode. If disabled, Pose3::FIRST_ORDER will be used." ON)
|
||||||
option(GTSAM_ROT3_EXPMAP "Ignore if GTSAM_USE_QUATERNIONS is OFF (Rot3::EXPMAP by default). Otherwise, enable Rot3::EXPMAP, or if disabled, use Rot3::CAYLEY." ON)
|
option(GTSAM_ROT3_EXPMAP "Ignore if GTSAM_USE_QUATERNIONS is OFF (Rot3::EXPMAP by default). Otherwise, enable Rot3::EXPMAP, or if disabled, use Rot3::CAYLEY." ON)
|
||||||
|
option(GTSAM_DT_MERGING "Enable/Disable merging of equal leaf nodes in DecisionTrees. This leads to significant speed up and memory savings." ON)
|
||||||
option(GTSAM_ENABLE_CONSISTENCY_CHECKS "Enable/Disable expensive consistency checks" OFF)
|
option(GTSAM_ENABLE_CONSISTENCY_CHECKS "Enable/Disable expensive consistency checks" OFF)
|
||||||
option(GTSAM_ENABLE_MEMORY_SANITIZER "Enable/Disable memory sanitizer" OFF)
|
option(GTSAM_ENABLE_MEMORY_SANITIZER "Enable/Disable memory sanitizer" OFF)
|
||||||
option(GTSAM_WITH_TBB "Use Intel Threaded Building Blocks (TBB) if available" ON)
|
option(GTSAM_WITH_TBB "Use Intel Threaded Building Blocks (TBB) if available" ON)
|
||||||
|
|
|
||||||
|
|
@ -90,6 +90,7 @@ print_enabled_config(${GTSAM_ENABLE_CONSISTENCY_CHECKS} "Runtime consistency c
|
||||||
print_enabled_config(${GTSAM_ENABLE_MEMORY_SANITIZER} "Build with Memory Sanitizer ")
|
print_enabled_config(${GTSAM_ENABLE_MEMORY_SANITIZER} "Build with Memory Sanitizer ")
|
||||||
print_enabled_config(${GTSAM_ROT3_EXPMAP} "Rot3 retract is full ExpMap ")
|
print_enabled_config(${GTSAM_ROT3_EXPMAP} "Rot3 retract is full ExpMap ")
|
||||||
print_enabled_config(${GTSAM_POSE3_EXPMAP} "Pose3 retract is full ExpMap ")
|
print_enabled_config(${GTSAM_POSE3_EXPMAP} "Pose3 retract is full ExpMap ")
|
||||||
|
print_enabled_config(${GTSAM_DT_MERGING} "Enable branch merging in DecisionTree")
|
||||||
print_enabled_config(${GTSAM_ALLOW_DEPRECATED_SINCE_V43} "Allow features deprecated in GTSAM 4.3")
|
print_enabled_config(${GTSAM_ALLOW_DEPRECATED_SINCE_V43} "Allow features deprecated in GTSAM 4.3")
|
||||||
print_enabled_config(${GTSAM_SUPPORT_NESTED_DISSECTION} "Metis-based Nested Dissection ")
|
print_enabled_config(${GTSAM_SUPPORT_NESTED_DISSECTION} "Metis-based Nested Dissection ")
|
||||||
print_enabled_config(${GTSAM_TANGENT_PREINTEGRATION} "Use tangent-space preintegration")
|
print_enabled_config(${GTSAM_TANGENT_PREINTEGRATION} "Use tangent-space preintegration")
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,9 @@
|
||||||
#cmakedefine GTSAM_ROT3_EXPMAP
|
#cmakedefine GTSAM_ROT3_EXPMAP
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// Whether to enable merging of equal leaf nodes in the Discrete Decision Tree.
|
||||||
|
#cmakedefine GTSAM_DT_MERGING
|
||||||
|
|
||||||
// Whether we are using TBB (if TBB was found and GTSAM_WITH_TBB is enabled in CMake)
|
// Whether we are using TBB (if TBB was found and GTSAM_WITH_TBB is enabled in CMake)
|
||||||
#cmakedefine GTSAM_USE_TBB
|
#cmakedefine GTSAM_USE_TBB
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -93,7 +93,8 @@ namespace gtsam {
|
||||||
/// print
|
/// print
|
||||||
void print(const std::string& s, const LabelFormatter& labelFormatter,
|
void print(const std::string& s, const LabelFormatter& labelFormatter,
|
||||||
const ValueFormatter& valueFormatter) const override {
|
const ValueFormatter& valueFormatter) const override {
|
||||||
std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
|
std::cout << s << " Leaf [" << nrAssignments() << "]"
|
||||||
|
<< valueFormatter(constant_) << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Write graphviz format to stream `os`. */
|
/** Write graphviz format to stream `os`. */
|
||||||
|
|
@ -136,7 +137,9 @@ namespace gtsam {
|
||||||
// Applying binary operator to two leaves results in a leaf
|
// Applying binary operator to two leaves results in a leaf
|
||||||
NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override {
|
NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override {
|
||||||
// fL op gL
|
// fL op gL
|
||||||
NodePtr h(new Leaf(op(fL.constant_, constant_), nrAssignments_));
|
// TODO(Varun) nrAssignments setting is not correct.
|
||||||
|
// Depending on f and g, the nrAssignments can be different. This is a bug!
|
||||||
|
NodePtr h(new Leaf(op(fL.constant_, constant_), fL.nrAssignments()));
|
||||||
return h;
|
return h;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -198,16 +201,43 @@ namespace gtsam {
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
/// If all branches of a choice node f are the same, just return a branch.
|
/**
|
||||||
static NodePtr Unique(const ChoicePtr& f) {
|
* @brief Merge branches with equal leaf values for every choice node in a
|
||||||
#ifndef GTSAM_DT_NO_PRUNING
|
* decision tree. If all branches are the same (i.e. have the same leaf
|
||||||
|
* value), replace the choice node with the equivalent leaf node.
|
||||||
|
*
|
||||||
|
* This function applies the branch merging (if enabled) recursively on the
|
||||||
|
* decision tree represented by the root node passed in as the argument. It
|
||||||
|
* recurses to the leaf nodes and merges branches with equal leaf values in
|
||||||
|
* a bottom-up fashion.
|
||||||
|
*
|
||||||
|
* Thus, if all branches of a choice node `f` are the same,
|
||||||
|
* just return a single branch at each recursion step.
|
||||||
|
*
|
||||||
|
* @param node The root node of the decision tree.
|
||||||
|
* @return NodePtr
|
||||||
|
*/
|
||||||
|
static NodePtr Unique(const NodePtr& node) {
|
||||||
|
if (auto choice = std::dynamic_pointer_cast<const Choice>(node)) {
|
||||||
|
// Choice node, we recurse!
|
||||||
|
// Make non-const copy so we can update
|
||||||
|
auto f = std::make_shared<Choice>(choice->label(), choice->nrChoices());
|
||||||
|
|
||||||
|
// Iterate over all the branches
|
||||||
|
for (size_t i = 0; i < choice->nrChoices(); i++) {
|
||||||
|
auto branch = choice->branches_[i];
|
||||||
|
f->push_back(Unique(branch));
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef GTSAM_DT_MERGING
|
||||||
// If all the branches are the same, we can merge them into one
|
// If all the branches are the same, we can merge them into one
|
||||||
if (f->allSame_) {
|
if (f->allSame_) {
|
||||||
assert(f->branches().size() > 0);
|
assert(f->branches().size() > 0);
|
||||||
NodePtr f0 = f->branches_[0];
|
NodePtr f0 = f->branches_[0];
|
||||||
|
|
||||||
|
// Compute total number of assignments
|
||||||
size_t nrAssignments = 0;
|
size_t nrAssignments = 0;
|
||||||
for(auto branch: f->branches()) {
|
for (auto branch : f->branches()) {
|
||||||
if (auto leaf = std::dynamic_pointer_cast<const Leaf>(branch)) {
|
if (auto leaf = std::dynamic_pointer_cast<const Leaf>(branch)) {
|
||||||
nrAssignments += leaf->nrAssignments();
|
nrAssignments += leaf->nrAssignments();
|
||||||
}
|
}
|
||||||
|
|
@ -216,30 +246,12 @@ namespace gtsam {
|
||||||
new Leaf(std::dynamic_pointer_cast<const Leaf>(f0)->constant(),
|
new Leaf(std::dynamic_pointer_cast<const Leaf>(f0)->constant(),
|
||||||
nrAssignments));
|
nrAssignments));
|
||||||
return newLeaf;
|
return newLeaf;
|
||||||
|
}
|
||||||
} else
|
|
||||||
// Else we recurse
|
|
||||||
#endif
|
#endif
|
||||||
{
|
return f;
|
||||||
|
} else {
|
||||||
// Make non-const copy
|
// Leaf node, return as is
|
||||||
auto ff = std::make_shared<Choice>(f->label(), f->nrChoices());
|
return node;
|
||||||
|
|
||||||
// Iterate over all the branches
|
|
||||||
for (size_t i = 0; i < f->nrChoices(); i++) {
|
|
||||||
auto branch = f->branches_[i];
|
|
||||||
if (auto leaf = std::dynamic_pointer_cast<const Leaf>(branch)) {
|
|
||||||
// Leaf node, simply assign
|
|
||||||
ff->push_back(branch);
|
|
||||||
|
|
||||||
} else if (auto choice =
|
|
||||||
std::dynamic_pointer_cast<const Choice>(branch)) {
|
|
||||||
// Choice node, we recurse
|
|
||||||
ff->push_back(Unique(choice));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return ff;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -486,13 +498,11 @@ namespace gtsam {
|
||||||
// DecisionTree
|
// DecisionTree
|
||||||
/****************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree() {
|
DecisionTree<L, Y>::DecisionTree() {}
|
||||||
}
|
|
||||||
|
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree(const NodePtr& root) :
|
DecisionTree<L, Y>::DecisionTree(const NodePtr& root) :
|
||||||
root_(root) {
|
root_(root) {}
|
||||||
}
|
|
||||||
|
|
||||||
/****************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
|
|
@ -608,7 +618,8 @@ namespace gtsam {
|
||||||
auto choiceOnLabel = std::make_shared<Choice>(label, end - begin);
|
auto choiceOnLabel = std::make_shared<Choice>(label, end - begin);
|
||||||
for (Iterator it = begin; it != end; it++)
|
for (Iterator it = begin; it != end; it++)
|
||||||
choiceOnLabel->push_back(it->root_);
|
choiceOnLabel->push_back(it->root_);
|
||||||
return Choice::Unique(choiceOnLabel);
|
// If no reordering, no need to call Choice::Unique
|
||||||
|
return choiceOnLabel;
|
||||||
} else {
|
} else {
|
||||||
// Set up a new choice on the highest label
|
// Set up a new choice on the highest label
|
||||||
auto choiceOnHighestLabel =
|
auto choiceOnHighestLabel =
|
||||||
|
|
@ -737,7 +748,7 @@ namespace gtsam {
|
||||||
for (auto&& branch : choice->branches()) {
|
for (auto&& branch : choice->branches()) {
|
||||||
functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X));
|
functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X));
|
||||||
}
|
}
|
||||||
return LY::compose(functions.begin(), functions.end(), newLabel);
|
return Choice::Unique(LY::compose(functions.begin(), functions.end(), newLabel));
|
||||||
}
|
}
|
||||||
|
|
||||||
/****************************************************************************/
|
/****************************************************************************/
|
||||||
|
|
|
||||||
|
|
@ -307,7 +307,7 @@ namespace gtsam {
|
||||||
// Get the probabilities in the decision tree so we can threshold.
|
// Get the probabilities in the decision tree so we can threshold.
|
||||||
std::vector<double> probabilities;
|
std::vector<double> probabilities;
|
||||||
this->visitLeaf([&](const Leaf& leaf) {
|
this->visitLeaf([&](const Leaf& leaf) {
|
||||||
size_t nrAssignments = leaf.nrAssignments();
|
const size_t nrAssignments = leaf.nrAssignments();
|
||||||
double prob = leaf.constant();
|
double prob = leaf.constant();
|
||||||
probabilities.insert(probabilities.end(), nrAssignments, prob);
|
probabilities.insert(probabilities.end(), nrAssignments, prob);
|
||||||
});
|
});
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,6 @@
|
||||||
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
|
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
|
||||||
#include <gtsam/discrete/DiscreteValues.h>
|
#include <gtsam/discrete/DiscreteValues.h>
|
||||||
// headers first to make sure no missing headers
|
// headers first to make sure no missing headers
|
||||||
//#define GTSAM_DT_NO_PRUNING
|
|
||||||
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
||||||
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only
|
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only
|
||||||
#define DISABLE_TIMING
|
#define DISABLE_TIMING
|
||||||
|
|
@ -179,7 +178,11 @@ TEST(ADT, joint) {
|
||||||
dot(joint, "Asia-ASTLBEX");
|
dot(joint, "Asia-ASTLBEX");
|
||||||
joint = apply(joint, pD, &mul);
|
joint = apply(joint, pD, &mul);
|
||||||
dot(joint, "Asia-ASTLBEXD");
|
dot(joint, "Asia-ASTLBEXD");
|
||||||
|
#ifdef GTSAM_DT_MERGING
|
||||||
EXPECT_LONGS_EQUAL(346, muls);
|
EXPECT_LONGS_EQUAL(346, muls);
|
||||||
|
#else
|
||||||
|
EXPECT_LONGS_EQUAL(508, muls);
|
||||||
|
#endif
|
||||||
gttoc_(asiaJoint);
|
gttoc_(asiaJoint);
|
||||||
tictoc_getNode(asiaJointNode, asiaJoint);
|
tictoc_getNode(asiaJointNode, asiaJoint);
|
||||||
elapsed = asiaJointNode->secs() + asiaJointNode->wall();
|
elapsed = asiaJointNode->secs() + asiaJointNode->wall();
|
||||||
|
|
@ -240,7 +243,11 @@ TEST(ADT, inference) {
|
||||||
dot(joint, "Joint-Product-ASTLBEX");
|
dot(joint, "Joint-Product-ASTLBEX");
|
||||||
joint = apply(joint, pD, &mul);
|
joint = apply(joint, pD, &mul);
|
||||||
dot(joint, "Joint-Product-ASTLBEXD");
|
dot(joint, "Joint-Product-ASTLBEXD");
|
||||||
|
#ifdef GTSAM_DT_MERGING
|
||||||
EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering
|
EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering
|
||||||
|
#else
|
||||||
|
EXPECT_LONGS_EQUAL(508, (long)muls); // different ordering
|
||||||
|
#endif
|
||||||
gttoc_(asiaProd);
|
gttoc_(asiaProd);
|
||||||
tictoc_getNode(asiaProdNode, asiaProd);
|
tictoc_getNode(asiaProdNode, asiaProd);
|
||||||
elapsed = asiaProdNode->secs() + asiaProdNode->wall();
|
elapsed = asiaProdNode->secs() + asiaProdNode->wall();
|
||||||
|
|
@ -258,7 +265,11 @@ TEST(ADT, inference) {
|
||||||
dot(marginal, "Joint-Sum-ADBLE");
|
dot(marginal, "Joint-Sum-ADBLE");
|
||||||
marginal = marginal.combine(E, &add_);
|
marginal = marginal.combine(E, &add_);
|
||||||
dot(marginal, "Joint-Sum-ADBL");
|
dot(marginal, "Joint-Sum-ADBL");
|
||||||
|
#ifdef GTSAM_DT_MERGING
|
||||||
EXPECT_LONGS_EQUAL(161, (long)adds);
|
EXPECT_LONGS_EQUAL(161, (long)adds);
|
||||||
|
#else
|
||||||
|
EXPECT_LONGS_EQUAL(240, (long)adds);
|
||||||
|
#endif
|
||||||
gttoc_(asiaSum);
|
gttoc_(asiaSum);
|
||||||
tictoc_getNode(asiaSumNode, asiaSum);
|
tictoc_getNode(asiaSumNode, asiaSum);
|
||||||
elapsed = asiaSumNode->secs() + asiaSumNode->wall();
|
elapsed = asiaSumNode->secs() + asiaSumNode->wall();
|
||||||
|
|
@ -296,7 +307,11 @@ TEST(ADT, factor_graph) {
|
||||||
fg = apply(fg, pX, &mul);
|
fg = apply(fg, pX, &mul);
|
||||||
fg = apply(fg, pD, &mul);
|
fg = apply(fg, pD, &mul);
|
||||||
dot(fg, "FactorGraph");
|
dot(fg, "FactorGraph");
|
||||||
|
#ifdef GTSAM_DT_MERGING
|
||||||
EXPECT_LONGS_EQUAL(158, (long)muls);
|
EXPECT_LONGS_EQUAL(158, (long)muls);
|
||||||
|
#else
|
||||||
|
EXPECT_LONGS_EQUAL(188, (long)muls);
|
||||||
|
#endif
|
||||||
gttoc_(asiaFG);
|
gttoc_(asiaFG);
|
||||||
tictoc_getNode(asiaFGNode, asiaFG);
|
tictoc_getNode(asiaFGNode, asiaFG);
|
||||||
elapsed = asiaFGNode->secs() + asiaFGNode->wall();
|
elapsed = asiaFGNode->secs() + asiaFGNode->wall();
|
||||||
|
|
@ -315,7 +330,11 @@ TEST(ADT, factor_graph) {
|
||||||
dot(fg, "Marginalized-3E");
|
dot(fg, "Marginalized-3E");
|
||||||
fg = fg.combine(L, &add_);
|
fg = fg.combine(L, &add_);
|
||||||
dot(fg, "Marginalized-2L");
|
dot(fg, "Marginalized-2L");
|
||||||
|
#ifdef GTSAM_DT_MERGING
|
||||||
LONGS_EQUAL(49, adds);
|
LONGS_EQUAL(49, adds);
|
||||||
|
#else
|
||||||
|
LONGS_EQUAL(62, adds);
|
||||||
|
#endif
|
||||||
gttoc_(marg);
|
gttoc_(marg);
|
||||||
tictoc_getNode(margNode, marg);
|
tictoc_getNode(margNode, marg);
|
||||||
elapsed = margNode->secs() + margNode->wall();
|
elapsed = margNode->secs() + margNode->wall();
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,6 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// #define DT_DEBUG_MEMORY
|
// #define DT_DEBUG_MEMORY
|
||||||
// #define GTSAM_DT_NO_PRUNING
|
|
||||||
#define DISABLE_DOT
|
#define DISABLE_DOT
|
||||||
#include <CppUnitLite/TestHarness.h>
|
#include <CppUnitLite/TestHarness.h>
|
||||||
#include <gtsam/base/Testable.h>
|
#include <gtsam/base/Testable.h>
|
||||||
|
|
@ -192,7 +191,11 @@ TEST(DecisionTree, example) {
|
||||||
|
|
||||||
// Test choose 0
|
// Test choose 0
|
||||||
DT actual0 = notba.choose(A, 0);
|
DT actual0 = notba.choose(A, 0);
|
||||||
|
#ifdef GTSAM_DT_MERGING
|
||||||
EXPECT(assert_equal(DT(0.0), actual0));
|
EXPECT(assert_equal(DT(0.0), actual0));
|
||||||
|
#else
|
||||||
|
// EXPECT(assert_equal(DT({0.0, 0.0}), actual0));
|
||||||
|
#endif
|
||||||
DOT(actual0);
|
DOT(actual0);
|
||||||
|
|
||||||
// Test choose 1
|
// Test choose 1
|
||||||
|
|
@ -333,9 +336,11 @@ TEST(DecisionTree, NrAssignments) {
|
||||||
|
|
||||||
EXPECT_LONGS_EQUAL(8, tree.nrAssignments());
|
EXPECT_LONGS_EQUAL(8, tree.nrAssignments());
|
||||||
|
|
||||||
|
#ifdef GTSAM_DT_MERGING
|
||||||
EXPECT(tree.root_->isLeaf());
|
EXPECT(tree.root_->isLeaf());
|
||||||
auto leaf = std::dynamic_pointer_cast<const DT::Leaf>(tree.root_);
|
auto leaf = std::dynamic_pointer_cast<const DT::Leaf>(tree.root_);
|
||||||
EXPECT_LONGS_EQUAL(8, leaf->nrAssignments());
|
EXPECT_LONGS_EQUAL(8, leaf->nrAssignments());
|
||||||
|
#endif
|
||||||
|
|
||||||
DT tree2({C, B, A}, "1 1 1 2 3 4 5 5");
|
DT tree2({C, B, A}, "1 1 1 2 3 4 5 5");
|
||||||
/* The tree is
|
/* The tree is
|
||||||
|
|
@ -358,6 +363,8 @@ TEST(DecisionTree, NrAssignments) {
|
||||||
CHECK(root);
|
CHECK(root);
|
||||||
auto choice0 = std::dynamic_pointer_cast<const DT::Choice>(root->branches()[0]);
|
auto choice0 = std::dynamic_pointer_cast<const DT::Choice>(root->branches()[0]);
|
||||||
CHECK(choice0);
|
CHECK(choice0);
|
||||||
|
|
||||||
|
#ifdef GTSAM_DT_MERGING
|
||||||
EXPECT(choice0->branches()[0]->isLeaf());
|
EXPECT(choice0->branches()[0]->isLeaf());
|
||||||
auto choice00 = std::dynamic_pointer_cast<const DT::Leaf>(choice0->branches()[0]);
|
auto choice00 = std::dynamic_pointer_cast<const DT::Leaf>(choice0->branches()[0]);
|
||||||
CHECK(choice00);
|
CHECK(choice00);
|
||||||
|
|
@ -371,6 +378,7 @@ TEST(DecisionTree, NrAssignments) {
|
||||||
CHECK(choice11);
|
CHECK(choice11);
|
||||||
EXPECT(choice11->isLeaf());
|
EXPECT(choice11->isLeaf());
|
||||||
EXPECT_LONGS_EQUAL(2, choice11->nrAssignments());
|
EXPECT_LONGS_EQUAL(2, choice11->nrAssignments());
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
|
|
@ -412,27 +420,61 @@ TEST(DecisionTree, VisitWithPruned) {
|
||||||
};
|
};
|
||||||
tree.visitWith(func);
|
tree.visitWith(func);
|
||||||
|
|
||||||
|
#ifdef GTSAM_DT_MERGING
|
||||||
EXPECT_LONGS_EQUAL(6, choices.size());
|
EXPECT_LONGS_EQUAL(6, choices.size());
|
||||||
|
#else
|
||||||
|
EXPECT_LONGS_EQUAL(8, choices.size());
|
||||||
|
#endif
|
||||||
|
|
||||||
Assignment<string> expectedAssignment;
|
Assignment<string> expectedAssignment;
|
||||||
|
|
||||||
|
#ifdef GTSAM_DT_MERGING
|
||||||
expectedAssignment = {{"B", 0}, {"C", 0}};
|
expectedAssignment = {{"B", 0}, {"C", 0}};
|
||||||
EXPECT(expectedAssignment == choices.at(0));
|
EXPECT(expectedAssignment == choices.at(0));
|
||||||
|
#else
|
||||||
|
expectedAssignment = {{"A", 0}, {"B", 0}, {"C", 0}};
|
||||||
|
EXPECT(expectedAssignment == choices.at(0));
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef GTSAM_DT_MERGING
|
||||||
expectedAssignment = {{"A", 0}, {"B", 1}, {"C", 0}};
|
expectedAssignment = {{"A", 0}, {"B", 1}, {"C", 0}};
|
||||||
EXPECT(expectedAssignment == choices.at(1));
|
EXPECT(expectedAssignment == choices.at(1));
|
||||||
|
#else
|
||||||
|
expectedAssignment = {{"A", 1}, {"B", 0}, {"C", 0}};
|
||||||
|
EXPECT(expectedAssignment == choices.at(1));
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef GTSAM_DT_MERGING
|
||||||
expectedAssignment = {{"A", 1}, {"B", 1}, {"C", 0}};
|
expectedAssignment = {{"A", 1}, {"B", 1}, {"C", 0}};
|
||||||
EXPECT(expectedAssignment == choices.at(2));
|
EXPECT(expectedAssignment == choices.at(2));
|
||||||
|
#else
|
||||||
|
expectedAssignment = {{"A", 0}, {"B", 1}, {"C", 0}};
|
||||||
|
EXPECT(expectedAssignment == choices.at(2));
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef GTSAM_DT_MERGING
|
||||||
expectedAssignment = {{"B", 0}, {"C", 1}};
|
expectedAssignment = {{"B", 0}, {"C", 1}};
|
||||||
EXPECT(expectedAssignment == choices.at(3));
|
EXPECT(expectedAssignment == choices.at(3));
|
||||||
|
#else
|
||||||
|
expectedAssignment = {{"A", 1}, {"B", 1}, {"C", 0}};
|
||||||
|
EXPECT(expectedAssignment == choices.at(3));
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef GTSAM_DT_MERGING
|
||||||
expectedAssignment = {{"A", 0}, {"B", 1}, {"C", 1}};
|
expectedAssignment = {{"A", 0}, {"B", 1}, {"C", 1}};
|
||||||
EXPECT(expectedAssignment == choices.at(4));
|
EXPECT(expectedAssignment == choices.at(4));
|
||||||
|
#else
|
||||||
|
expectedAssignment = {{"A", 0}, {"B", 0}, {"C", 1}};
|
||||||
|
EXPECT(expectedAssignment == choices.at(4));
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef GTSAM_DT_MERGING
|
||||||
expectedAssignment = {{"A", 1}, {"B", 1}, {"C", 1}};
|
expectedAssignment = {{"A", 1}, {"B", 1}, {"C", 1}};
|
||||||
EXPECT(expectedAssignment == choices.at(5));
|
EXPECT(expectedAssignment == choices.at(5));
|
||||||
|
#else
|
||||||
|
expectedAssignment = {{"A", 1}, {"B", 0}, {"C", 1}};
|
||||||
|
EXPECT(expectedAssignment == choices.at(5));
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
|
|
@ -443,7 +485,11 @@ TEST(DecisionTree, fold) {
|
||||||
DT tree(B, DT(A, 1, 1), DT(A, 2, 3));
|
DT tree(B, DT(A, 1, 1), DT(A, 2, 3));
|
||||||
auto add = [](const int& y, double x) { return y + x; };
|
auto add = [](const int& y, double x) { return y + x; };
|
||||||
double sum = tree.fold(add, 0.0);
|
double sum = tree.fold(add, 0.0);
|
||||||
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); // Note, not 7, due to pruning!
|
#ifdef GTSAM_DT_MERGING
|
||||||
|
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); // Note, not 7, due to merging!
|
||||||
|
#else
|
||||||
|
EXPECT_DOUBLES_EQUAL(7.0, sum, 1e-9);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
|
|
@ -495,9 +541,14 @@ TEST(DecisionTree, threshold) {
|
||||||
auto threshold = [](int value) { return value < 5 ? 0 : value; };
|
auto threshold = [](int value) { return value < 5 ? 0 : value; };
|
||||||
DT thresholded(tree, threshold);
|
DT thresholded(tree, threshold);
|
||||||
|
|
||||||
|
#ifdef GTSAM_DT_MERGING
|
||||||
// Check number of leaves equal to zero now = 2
|
// Check number of leaves equal to zero now = 2
|
||||||
// Note: it is 2, because the pruned branches are counted as 1!
|
// Note: it is 2, because the pruned branches are counted as 1!
|
||||||
EXPECT_LONGS_EQUAL(2, thresholded.fold(count, 0));
|
EXPECT_LONGS_EQUAL(2, thresholded.fold(count, 0));
|
||||||
|
#else
|
||||||
|
// if GTSAM_DT_MERGING is disabled, the count will be larger
|
||||||
|
EXPECT_LONGS_EQUAL(5, thresholded.fold(count, 0));
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
|
|
@ -533,8 +584,13 @@ TEST(DecisionTree, ApplyWithAssignment) {
|
||||||
};
|
};
|
||||||
DT prunedTree2 = prunedTree.apply(counter);
|
DT prunedTree2 = prunedTree.apply(counter);
|
||||||
|
|
||||||
|
#ifdef GTSAM_DT_MERGING
|
||||||
// Check if apply doesn't enumerate all leaves.
|
// Check if apply doesn't enumerate all leaves.
|
||||||
EXPECT_LONGS_EQUAL(5, count);
|
EXPECT_LONGS_EQUAL(5, count);
|
||||||
|
#else
|
||||||
|
// if GTSAM_DT_MERGING is disabled, the count will be full
|
||||||
|
EXPECT_LONGS_EQUAL(8, count);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
|
|
|
||||||
|
|
@ -15,17 +15,20 @@
|
||||||
* @author Duy-Nguyen Ta
|
* @author Duy-Nguyen Ta
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
#include <gtsam/base/TestableAssertions.h>
|
||||||
|
#include <gtsam/discrete/DiscreteBayesTree.h>
|
||||||
|
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
||||||
#include <gtsam/discrete/DiscreteFactor.h>
|
#include <gtsam/discrete/DiscreteFactor.h>
|
||||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
|
||||||
#include <gtsam/discrete/DiscreteBayesTree.h>
|
|
||||||
#include <gtsam/inference/BayesNet.h>
|
#include <gtsam/inference/BayesNet.h>
|
||||||
|
#include <gtsam/inference/Symbol.h>
|
||||||
#include <CppUnitLite/TestHarness.h>
|
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
|
||||||
|
using symbol_shorthand::M;
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST_UNSAFE(DiscreteFactorGraph, debugScheduler) {
|
TEST_UNSAFE(DiscreteFactorGraph, debugScheduler) {
|
||||||
DiscreteKey PC(0, 4), ME(1, 4), AI(2, 4), A(3, 3);
|
DiscreteKey PC(0, 4), ME(1, 4), AI(2, 4), A(3, 3);
|
||||||
|
|
@ -345,6 +348,120 @@ TEST(DiscreteFactorGraph, markdown) {
|
||||||
values[1] = 0;
|
values[1] = 0;
|
||||||
EXPECT_DOUBLES_EQUAL(0.3, graph[0]->operator()(values), 1e-9);
|
EXPECT_DOUBLES_EQUAL(0.3, graph[0]->operator()(values), 1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(DiscreteFactorGraph, NrAssignments) {
|
||||||
|
#ifdef GTSAM_DT_MERGING
|
||||||
|
string expected_dfg = R"(
|
||||||
|
size: 2
|
||||||
|
factor 0: f[ (m0,2), (m1,2), (m2,2), ]
|
||||||
|
Choice(m2)
|
||||||
|
0 Choice(m1)
|
||||||
|
0 0 Leaf [2] 0
|
||||||
|
0 1 Choice(m0)
|
||||||
|
0 1 0 Leaf [1]0.27527634
|
||||||
|
0 1 1 Leaf [1] 0
|
||||||
|
1 Choice(m1)
|
||||||
|
1 0 Leaf [2] 0
|
||||||
|
1 1 Choice(m0)
|
||||||
|
1 1 0 Leaf [1]0.44944733
|
||||||
|
1 1 1 Leaf [1]0.27527634
|
||||||
|
factor 1: f[ (m0,2), (m1,2), (m2,2), (m3,2), ]
|
||||||
|
Choice(m3)
|
||||||
|
0 Choice(m2)
|
||||||
|
0 0 Choice(m1)
|
||||||
|
0 0 0 Leaf [2] 1
|
||||||
|
0 0 1 Leaf [2]0.015366387
|
||||||
|
0 1 Choice(m1)
|
||||||
|
0 1 0 Leaf [2] 1
|
||||||
|
0 1 1 Choice(m0)
|
||||||
|
0 1 1 0 Leaf [1] 1
|
||||||
|
0 1 1 1 Leaf [1]0.015365663
|
||||||
|
1 Choice(m2)
|
||||||
|
1 0 Choice(m1)
|
||||||
|
1 0 0 Leaf [2] 1
|
||||||
|
1 0 1 Choice(m0)
|
||||||
|
1 0 1 0 Leaf [1]0.0094115739
|
||||||
|
1 0 1 1 Leaf [1]0.0094115652
|
||||||
|
1 1 Choice(m1)
|
||||||
|
1 1 0 Leaf [2] 1
|
||||||
|
1 1 1 Choice(m0)
|
||||||
|
1 1 1 0 Leaf [1] 1
|
||||||
|
1 1 1 1 Leaf [1]0.009321081
|
||||||
|
)";
|
||||||
|
#else
|
||||||
|
string expected_dfg = R"(
|
||||||
|
size: 2
|
||||||
|
factor 0: f[ (m0,2), (m1,2), (m2,2), ]
|
||||||
|
Choice(m2)
|
||||||
|
0 Choice(m1)
|
||||||
|
0 0 Choice(m0)
|
||||||
|
0 0 0 Leaf [1] 0
|
||||||
|
0 0 1 Leaf [1] 0
|
||||||
|
0 1 Choice(m0)
|
||||||
|
0 1 0 Leaf [1]0.27527634
|
||||||
|
0 1 1 Leaf [1]0.44944733
|
||||||
|
1 Choice(m1)
|
||||||
|
1 0 Choice(m0)
|
||||||
|
1 0 0 Leaf [1] 0
|
||||||
|
1 0 1 Leaf [1] 0
|
||||||
|
1 1 Choice(m0)
|
||||||
|
1 1 0 Leaf [1] 0
|
||||||
|
1 1 1 Leaf [1]0.27527634
|
||||||
|
factor 1: f[ (m0,2), (m1,2), (m2,2), (m3,2), ]
|
||||||
|
Choice(m3)
|
||||||
|
0 Choice(m2)
|
||||||
|
0 0 Choice(m1)
|
||||||
|
0 0 0 Choice(m0)
|
||||||
|
0 0 0 0 Leaf [1] 1
|
||||||
|
0 0 0 1 Leaf [1] 1
|
||||||
|
0 0 1 Choice(m0)
|
||||||
|
0 0 1 0 Leaf [1]0.015366387
|
||||||
|
0 0 1 1 Leaf [1]0.015366387
|
||||||
|
0 1 Choice(m1)
|
||||||
|
0 1 0 Choice(m0)
|
||||||
|
0 1 0 0 Leaf [1] 1
|
||||||
|
0 1 0 1 Leaf [1] 1
|
||||||
|
0 1 1 Choice(m0)
|
||||||
|
0 1 1 0 Leaf [1] 1
|
||||||
|
0 1 1 1 Leaf [1]0.015365663
|
||||||
|
1 Choice(m2)
|
||||||
|
1 0 Choice(m1)
|
||||||
|
1 0 0 Choice(m0)
|
||||||
|
1 0 0 0 Leaf [1] 1
|
||||||
|
1 0 0 1 Leaf [1] 1
|
||||||
|
1 0 1 Choice(m0)
|
||||||
|
1 0 1 0 Leaf [1]0.0094115739
|
||||||
|
1 0 1 1 Leaf [1]0.0094115652
|
||||||
|
1 1 Choice(m1)
|
||||||
|
1 1 0 Choice(m0)
|
||||||
|
1 1 0 0 Leaf [1] 1
|
||||||
|
1 1 0 1 Leaf [1] 1
|
||||||
|
1 1 1 Choice(m0)
|
||||||
|
1 1 1 0 Leaf [1] 1
|
||||||
|
1 1 1 1 Leaf [1]0.009321081
|
||||||
|
)";
|
||||||
|
#endif
|
||||||
|
|
||||||
|
DiscreteKeys d0{{M(0), 2}, {M(1), 2}, {M(2), 2}};
|
||||||
|
std::vector<double> p0 = {0, 0, 0.17054468, 0.27845056, 0, 0, 0, 0.17054468};
|
||||||
|
AlgebraicDecisionTree<Key> dt(d0, p0);
|
||||||
|
//TODO(Varun) Passing ADT to DiscreteConditional causes nrAssignments to get messed up
|
||||||
|
// Issue seems to be in DecisionTreeFactor.cpp L104
|
||||||
|
DiscreteConditional f0(3, DecisionTreeFactor(d0, dt));
|
||||||
|
|
||||||
|
DiscreteKeys d1{{M(0), 2}, {M(1), 2}, {M(2), 2}, {M(3), 2}};
|
||||||
|
std::vector<double> p1 = {
|
||||||
|
1, 1, 1, 1, 0.015366387, 0.0094115739, 1, 1,
|
||||||
|
1, 1, 1, 1, 0.015366387, 0.0094115652, 0.015365663, 0.009321081};
|
||||||
|
DecisionTreeFactor f1(d1, p1);
|
||||||
|
|
||||||
|
DiscreteFactorGraph dfg;
|
||||||
|
dfg.add(f0);
|
||||||
|
dfg.add(f1);
|
||||||
|
|
||||||
|
EXPECT(assert_print_equal(expected_dfg, dfg));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
||||||
|
|
@ -173,8 +173,7 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
void HybridBayesTree::prune(const size_t maxNrLeaves) {
|
void HybridBayesTree::prune(const size_t maxNrLeaves) {
|
||||||
auto decisionTree =
|
auto decisionTree = this->roots_.at(0)->conditional()->asDiscrete();
|
||||||
this->roots_.at(0)->conditional()->asDiscrete();
|
|
||||||
|
|
||||||
DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves);
|
DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves);
|
||||||
decisionTree->root_ = prunedDecisionTree.root_;
|
decisionTree->root_ = prunedDecisionTree.root_;
|
||||||
|
|
|
||||||
|
|
@ -70,8 +70,7 @@ Ordering HybridGaussianISAM::GetOrdering(
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
void HybridGaussianISAM::updateInternal(
|
void HybridGaussianISAM::updateInternal(
|
||||||
const HybridGaussianFactorGraph& newFactors,
|
const HybridGaussianFactorGraph& newFactors,
|
||||||
HybridBayesTree::Cliques* orphans,
|
HybridBayesTree::Cliques* orphans, const std::optional<size_t>& maxNrLeaves,
|
||||||
const std::optional<size_t>& maxNrLeaves,
|
|
||||||
const std::optional<Ordering>& ordering,
|
const std::optional<Ordering>& ordering,
|
||||||
const HybridBayesTree::Eliminate& function) {
|
const HybridBayesTree::Eliminate& function) {
|
||||||
// Remove the contaminated part of the Bayes tree
|
// Remove the contaminated part of the Bayes tree
|
||||||
|
|
@ -101,8 +100,8 @@ void HybridGaussianISAM::updateInternal(
|
||||||
}
|
}
|
||||||
|
|
||||||
// eliminate all factors (top, added, orphans) into a new Bayes tree
|
// eliminate all factors (top, added, orphans) into a new Bayes tree
|
||||||
HybridBayesTree::shared_ptr bayesTree =
|
HybridBayesTree::shared_ptr bayesTree = factors.eliminateMultifrontal(
|
||||||
factors.eliminateMultifrontal(elimination_ordering, function, std::cref(index));
|
elimination_ordering, function, std::cref(index));
|
||||||
|
|
||||||
if (maxNrLeaves) {
|
if (maxNrLeaves) {
|
||||||
bayesTree->prune(*maxNrLeaves);
|
bayesTree->prune(*maxNrLeaves);
|
||||||
|
|
|
||||||
|
|
@ -191,7 +191,7 @@ class MixtureFactor : public HybridFactor {
|
||||||
std::cout << "\nMixtureFactor\n";
|
std::cout << "\nMixtureFactor\n";
|
||||||
auto valueFormatter = [](const sharedFactor& v) {
|
auto valueFormatter = [](const sharedFactor& v) {
|
||||||
if (v) {
|
if (v) {
|
||||||
return "Nonlinear factor on " + std::to_string(v->size()) + " keys";
|
return " Nonlinear factor on " + std::to_string(v->size()) + " keys";
|
||||||
} else {
|
} else {
|
||||||
return std::string("nullptr");
|
return std::string("nullptr");
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -108,7 +108,7 @@ TEST(GaussianMixtureFactor, Printing) {
|
||||||
std::string expected =
|
std::string expected =
|
||||||
R"(Hybrid [x1 x2; 1]{
|
R"(Hybrid [x1 x2; 1]{
|
||||||
Choice(1)
|
Choice(1)
|
||||||
0 Leaf :
|
0 Leaf [1]:
|
||||||
A[x1] = [
|
A[x1] = [
|
||||||
0;
|
0;
|
||||||
0
|
0
|
||||||
|
|
@ -120,7 +120,7 @@ TEST(GaussianMixtureFactor, Printing) {
|
||||||
b = [ 0 0 ]
|
b = [ 0 0 ]
|
||||||
No noise model
|
No noise model
|
||||||
|
|
||||||
1 Leaf :
|
1 Leaf [1]:
|
||||||
A[x1] = [
|
A[x1] = [
|
||||||
0;
|
0;
|
||||||
0
|
0
|
||||||
|
|
|
||||||
|
|
@ -288,8 +288,12 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
|
||||||
std::make_shared<DecisionTreeFactor>(
|
std::make_shared<DecisionTreeFactor>(
|
||||||
discreteConditionals->prune(maxNrLeaves));
|
discreteConditionals->prune(maxNrLeaves));
|
||||||
|
|
||||||
|
#ifdef GTSAM_DT_MERGING
|
||||||
EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/,
|
EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/,
|
||||||
prunedDecisionTree->nrLeaves());
|
prunedDecisionTree->nrLeaves());
|
||||||
|
#else
|
||||||
|
EXPECT_LONGS_EQUAL(8 /*full tree*/, prunedDecisionTree->nrLeaves());
|
||||||
|
#endif
|
||||||
|
|
||||||
auto original_discrete_conditionals = *(posterior->at(4)->asDiscrete());
|
auto original_discrete_conditionals = *(posterior->at(4)->asDiscrete());
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -481,6 +481,7 @@ TEST(HybridFactorGraph, Printing) {
|
||||||
const auto [hybridBayesNet, remainingFactorGraph] =
|
const auto [hybridBayesNet, remainingFactorGraph] =
|
||||||
linearizedFactorGraph.eliminatePartialSequential(ordering);
|
linearizedFactorGraph.eliminatePartialSequential(ordering);
|
||||||
|
|
||||||
|
#ifdef GTSAM_DT_MERGING
|
||||||
string expected_hybridFactorGraph = R"(
|
string expected_hybridFactorGraph = R"(
|
||||||
size: 7
|
size: 7
|
||||||
factor 0:
|
factor 0:
|
||||||
|
|
@ -492,7 +493,7 @@ factor 0:
|
||||||
factor 1:
|
factor 1:
|
||||||
Hybrid [x0 x1; m0]{
|
Hybrid [x0 x1; m0]{
|
||||||
Choice(m0)
|
Choice(m0)
|
||||||
0 Leaf :
|
0 Leaf [1]:
|
||||||
A[x0] = [
|
A[x0] = [
|
||||||
-1
|
-1
|
||||||
]
|
]
|
||||||
|
|
@ -502,7 +503,7 @@ Hybrid [x0 x1; m0]{
|
||||||
b = [ -1 ]
|
b = [ -1 ]
|
||||||
No noise model
|
No noise model
|
||||||
|
|
||||||
1 Leaf :
|
1 Leaf [1]:
|
||||||
A[x0] = [
|
A[x0] = [
|
||||||
-1
|
-1
|
||||||
]
|
]
|
||||||
|
|
@ -516,7 +517,7 @@ Hybrid [x0 x1; m0]{
|
||||||
factor 2:
|
factor 2:
|
||||||
Hybrid [x1 x2; m1]{
|
Hybrid [x1 x2; m1]{
|
||||||
Choice(m1)
|
Choice(m1)
|
||||||
0 Leaf :
|
0 Leaf [1]:
|
||||||
A[x1] = [
|
A[x1] = [
|
||||||
-1
|
-1
|
||||||
]
|
]
|
||||||
|
|
@ -526,7 +527,7 @@ Hybrid [x1 x2; m1]{
|
||||||
b = [ -1 ]
|
b = [ -1 ]
|
||||||
No noise model
|
No noise model
|
||||||
|
|
||||||
1 Leaf :
|
1 Leaf [1]:
|
||||||
A[x1] = [
|
A[x1] = [
|
||||||
-1
|
-1
|
||||||
]
|
]
|
||||||
|
|
@ -550,18 +551,104 @@ factor 4:
|
||||||
b = [ -10 ]
|
b = [ -10 ]
|
||||||
No noise model
|
No noise model
|
||||||
factor 5: P( m0 ):
|
factor 5: P( m0 ):
|
||||||
Leaf 0.5
|
Leaf [2] 0.5
|
||||||
|
|
||||||
factor 6: P( m1 | m0 ):
|
factor 6: P( m1 | m0 ):
|
||||||
Choice(m1)
|
Choice(m1)
|
||||||
0 Choice(m0)
|
0 Choice(m0)
|
||||||
0 0 Leaf 0.33333333
|
0 0 Leaf [1]0.33333333
|
||||||
0 1 Leaf 0.6
|
0 1 Leaf [1] 0.6
|
||||||
1 Choice(m0)
|
1 Choice(m0)
|
||||||
1 0 Leaf 0.66666667
|
1 0 Leaf [1]0.66666667
|
||||||
1 1 Leaf 0.4
|
1 1 Leaf [1] 0.4
|
||||||
|
|
||||||
)";
|
)";
|
||||||
|
#else
|
||||||
|
string expected_hybridFactorGraph = R"(
|
||||||
|
size: 7
|
||||||
|
factor 0:
|
||||||
|
A[x0] = [
|
||||||
|
10
|
||||||
|
]
|
||||||
|
b = [ -10 ]
|
||||||
|
No noise model
|
||||||
|
factor 1:
|
||||||
|
Hybrid [x0 x1; m0]{
|
||||||
|
Choice(m0)
|
||||||
|
0 Leaf [1]:
|
||||||
|
A[x0] = [
|
||||||
|
-1
|
||||||
|
]
|
||||||
|
A[x1] = [
|
||||||
|
1
|
||||||
|
]
|
||||||
|
b = [ -1 ]
|
||||||
|
No noise model
|
||||||
|
|
||||||
|
1 Leaf [1]:
|
||||||
|
A[x0] = [
|
||||||
|
-1
|
||||||
|
]
|
||||||
|
A[x1] = [
|
||||||
|
1
|
||||||
|
]
|
||||||
|
b = [ -0 ]
|
||||||
|
No noise model
|
||||||
|
|
||||||
|
}
|
||||||
|
factor 2:
|
||||||
|
Hybrid [x1 x2; m1]{
|
||||||
|
Choice(m1)
|
||||||
|
0 Leaf [1]:
|
||||||
|
A[x1] = [
|
||||||
|
-1
|
||||||
|
]
|
||||||
|
A[x2] = [
|
||||||
|
1
|
||||||
|
]
|
||||||
|
b = [ -1 ]
|
||||||
|
No noise model
|
||||||
|
|
||||||
|
1 Leaf [1]:
|
||||||
|
A[x1] = [
|
||||||
|
-1
|
||||||
|
]
|
||||||
|
A[x2] = [
|
||||||
|
1
|
||||||
|
]
|
||||||
|
b = [ -0 ]
|
||||||
|
No noise model
|
||||||
|
|
||||||
|
}
|
||||||
|
factor 3:
|
||||||
|
A[x1] = [
|
||||||
|
10
|
||||||
|
]
|
||||||
|
b = [ -10 ]
|
||||||
|
No noise model
|
||||||
|
factor 4:
|
||||||
|
A[x2] = [
|
||||||
|
10
|
||||||
|
]
|
||||||
|
b = [ -10 ]
|
||||||
|
No noise model
|
||||||
|
factor 5: P( m0 ):
|
||||||
|
Choice(m0)
|
||||||
|
0 Leaf [1] 0.5
|
||||||
|
1 Leaf [1] 0.5
|
||||||
|
|
||||||
|
factor 6: P( m1 | m0 ):
|
||||||
|
Choice(m1)
|
||||||
|
0 Choice(m0)
|
||||||
|
0 0 Leaf [1]0.33333333
|
||||||
|
0 1 Leaf [1] 0.6
|
||||||
|
1 Choice(m0)
|
||||||
|
1 0 Leaf [1]0.66666667
|
||||||
|
1 1 Leaf [1] 0.4
|
||||||
|
|
||||||
|
)";
|
||||||
|
#endif
|
||||||
|
|
||||||
EXPECT(assert_print_equal(expected_hybridFactorGraph, linearizedFactorGraph));
|
EXPECT(assert_print_equal(expected_hybridFactorGraph, linearizedFactorGraph));
|
||||||
|
|
||||||
// Expected output for hybridBayesNet.
|
// Expected output for hybridBayesNet.
|
||||||
|
|
@ -570,13 +657,13 @@ size: 3
|
||||||
conditional 0: Hybrid P( x0 | x1 m0)
|
conditional 0: Hybrid P( x0 | x1 m0)
|
||||||
Discrete Keys = (m0, 2),
|
Discrete Keys = (m0, 2),
|
||||||
Choice(m0)
|
Choice(m0)
|
||||||
0 Leaf p(x0 | x1)
|
0 Leaf [1] p(x0 | x1)
|
||||||
R = [ 10.0499 ]
|
R = [ 10.0499 ]
|
||||||
S[x1] = [ -0.0995037 ]
|
S[x1] = [ -0.0995037 ]
|
||||||
d = [ -9.85087 ]
|
d = [ -9.85087 ]
|
||||||
No noise model
|
No noise model
|
||||||
|
|
||||||
1 Leaf p(x0 | x1)
|
1 Leaf [1] p(x0 | x1)
|
||||||
R = [ 10.0499 ]
|
R = [ 10.0499 ]
|
||||||
S[x1] = [ -0.0995037 ]
|
S[x1] = [ -0.0995037 ]
|
||||||
d = [ -9.95037 ]
|
d = [ -9.95037 ]
|
||||||
|
|
@ -586,26 +673,26 @@ conditional 1: Hybrid P( x1 | x2 m0 m1)
|
||||||
Discrete Keys = (m0, 2), (m1, 2),
|
Discrete Keys = (m0, 2), (m1, 2),
|
||||||
Choice(m1)
|
Choice(m1)
|
||||||
0 Choice(m0)
|
0 Choice(m0)
|
||||||
0 0 Leaf p(x1 | x2)
|
0 0 Leaf [1] p(x1 | x2)
|
||||||
R = [ 10.099 ]
|
R = [ 10.099 ]
|
||||||
S[x2] = [ -0.0990196 ]
|
S[x2] = [ -0.0990196 ]
|
||||||
d = [ -9.99901 ]
|
d = [ -9.99901 ]
|
||||||
No noise model
|
No noise model
|
||||||
|
|
||||||
0 1 Leaf p(x1 | x2)
|
0 1 Leaf [1] p(x1 | x2)
|
||||||
R = [ 10.099 ]
|
R = [ 10.099 ]
|
||||||
S[x2] = [ -0.0990196 ]
|
S[x2] = [ -0.0990196 ]
|
||||||
d = [ -9.90098 ]
|
d = [ -9.90098 ]
|
||||||
No noise model
|
No noise model
|
||||||
|
|
||||||
1 Choice(m0)
|
1 Choice(m0)
|
||||||
1 0 Leaf p(x1 | x2)
|
1 0 Leaf [1] p(x1 | x2)
|
||||||
R = [ 10.099 ]
|
R = [ 10.099 ]
|
||||||
S[x2] = [ -0.0990196 ]
|
S[x2] = [ -0.0990196 ]
|
||||||
d = [ -10.098 ]
|
d = [ -10.098 ]
|
||||||
No noise model
|
No noise model
|
||||||
|
|
||||||
1 1 Leaf p(x1 | x2)
|
1 1 Leaf [1] p(x1 | x2)
|
||||||
R = [ 10.099 ]
|
R = [ 10.099 ]
|
||||||
S[x2] = [ -0.0990196 ]
|
S[x2] = [ -0.0990196 ]
|
||||||
d = [ -10 ]
|
d = [ -10 ]
|
||||||
|
|
@ -615,14 +702,14 @@ conditional 2: Hybrid P( x2 | m0 m1)
|
||||||
Discrete Keys = (m0, 2), (m1, 2),
|
Discrete Keys = (m0, 2), (m1, 2),
|
||||||
Choice(m1)
|
Choice(m1)
|
||||||
0 Choice(m0)
|
0 Choice(m0)
|
||||||
0 0 Leaf p(x2)
|
0 0 Leaf [1] p(x2)
|
||||||
R = [ 10.0494 ]
|
R = [ 10.0494 ]
|
||||||
d = [ -10.1489 ]
|
d = [ -10.1489 ]
|
||||||
mean: 1 elements
|
mean: 1 elements
|
||||||
x2: -1.0099
|
x2: -1.0099
|
||||||
No noise model
|
No noise model
|
||||||
|
|
||||||
0 1 Leaf p(x2)
|
0 1 Leaf [1] p(x2)
|
||||||
R = [ 10.0494 ]
|
R = [ 10.0494 ]
|
||||||
d = [ -10.1479 ]
|
d = [ -10.1479 ]
|
||||||
mean: 1 elements
|
mean: 1 elements
|
||||||
|
|
@ -630,14 +717,14 @@ conditional 2: Hybrid P( x2 | m0 m1)
|
||||||
No noise model
|
No noise model
|
||||||
|
|
||||||
1 Choice(m0)
|
1 Choice(m0)
|
||||||
1 0 Leaf p(x2)
|
1 0 Leaf [1] p(x2)
|
||||||
R = [ 10.0494 ]
|
R = [ 10.0494 ]
|
||||||
d = [ -10.0504 ]
|
d = [ -10.0504 ]
|
||||||
mean: 1 elements
|
mean: 1 elements
|
||||||
x2: -1.0001
|
x2: -1.0001
|
||||||
No noise model
|
No noise model
|
||||||
|
|
||||||
1 1 Leaf p(x2)
|
1 1 Leaf [1] p(x2)
|
||||||
R = [ 10.0494 ]
|
R = [ 10.0494 ]
|
||||||
d = [ -10.0494 ]
|
d = [ -10.0494 ]
|
||||||
mean: 1 elements
|
mean: 1 elements
|
||||||
|
|
|
||||||
|
|
@ -63,8 +63,8 @@ TEST(MixtureFactor, Printing) {
|
||||||
R"(Hybrid [x1 x2; 1]
|
R"(Hybrid [x1 x2; 1]
|
||||||
MixtureFactor
|
MixtureFactor
|
||||||
Choice(1)
|
Choice(1)
|
||||||
0 Leaf Nonlinear factor on 2 keys
|
0 Leaf [1] Nonlinear factor on 2 keys
|
||||||
1 Leaf Nonlinear factor on 2 keys
|
1 Leaf [1] Nonlinear factor on 2 keys
|
||||||
)";
|
)";
|
||||||
EXPECT(assert_print_equal(expected, mixtureFactor));
|
EXPECT(assert_print_equal(expected, mixtureFactor));
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue