From 70ffbf32bcb535f3275eba2c65318b2a62dda62f Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 8 Jun 2023 13:16:49 -0400 Subject: [PATCH 01/11] mark nrAssignments as const --- gtsam/discrete/DecisionTreeFactor.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 5fb5ae2e6..4aca10a28 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -307,7 +307,7 @@ namespace gtsam { // Get the probabilities in the decision tree so we can threshold. std::vector probabilities; this->visitLeaf([&](const Leaf& leaf) { - size_t nrAssignments = leaf.nrAssignments(); + const size_t nrAssignments = leaf.nrAssignments(); double prob = leaf.constant(); probabilities.insert(probabilities.end(), nrAssignments, prob); }); From 23520432ec68946495325a459509689e7f42a911 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 8 Jun 2023 13:18:06 -0400 Subject: [PATCH 02/11] rename GTSAM_DT_NO_PRUNING to GTSAM_DT_NO_MERGING to help with disambiguation --- gtsam/discrete/DecisionTree-inl.h | 2 +- gtsam/discrete/tests/testAlgebraicDecisionTree.cpp | 2 +- gtsam/discrete/tests/testDecisionTree.cpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 16a926271..7a227d8dd 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -200,7 +200,7 @@ namespace gtsam { /// If all branches of a choice node f are the same, just return a branch. static NodePtr Unique(const ChoicePtr& f) { -#ifndef GTSAM_DT_NO_PRUNING +#ifndef GTSAM_DT_NO_MERGING // If all the branches are the same, we can merge them into one if (f->allSame_) { assert(f->branches().size() > 0); diff --git a/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp index c7cb7088e..d7e7a071c 100644 --- a/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp +++ b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp @@ -20,7 +20,7 @@ #include // make sure we have traits #include // headers first to make sure no missing headers -//#define GTSAM_DT_NO_PRUNING +//#define GTSAM_DT_NO_MERGING #include #include // for convert only #define DISABLE_TIMING diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index 71daf261d..0559a782b 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -18,7 +18,7 @@ */ // #define DT_DEBUG_MEMORY -// #define GTSAM_DT_NO_PRUNING +// #define GTSAM_DT_NO_MERGING #define DISABLE_DOT #include #include From 2998820d2cedbfb867d1a889ed4b2c3a3cdf79e6 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 8 Jun 2023 17:58:12 -0400 Subject: [PATCH 03/11] bottom-up Unique method that works much, much better --- gtsam/discrete/DecisionTree-inl.h | 74 +++++++++---------- .../tests/testDiscreteFactorGraph.cpp | 72 +++++++++++++++++- 2 files changed, 102 insertions(+), 44 deletions(-) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 7a227d8dd..3e85ba70a 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -199,47 +199,41 @@ namespace gtsam { } /// If all branches of a choice node f are the same, just return a branch. - static NodePtr Unique(const ChoicePtr& f) { -#ifndef GTSAM_DT_NO_MERGING - // If all the branches are the same, we can merge them into one - if (f->allSame_) { - assert(f->branches().size() > 0); - NodePtr f0 = f->branches_[0]; - - size_t nrAssignments = 0; - for(auto branch: f->branches()) { - if (auto leaf = std::dynamic_pointer_cast(branch)) { - nrAssignments += leaf->nrAssignments(); - } - } - NodePtr newLeaf( - new Leaf(std::dynamic_pointer_cast(f0)->constant(), - nrAssignments)); - return newLeaf; - - } else - // Else we recurse -#endif - { - + static NodePtr Unique(const NodePtr& node) { + if (auto choice = std::dynamic_pointer_cast(node)) { + // Choice node, we recurse! // Make non-const copy - auto ff = std::make_shared(f->label(), f->nrChoices()); + auto f = std::make_shared(choice->label(), choice->nrChoices()); // Iterate over all the branches - for (size_t i = 0; i < f->nrChoices(); i++) { - auto branch = f->branches_[i]; - if (auto leaf = std::dynamic_pointer_cast(branch)) { - // Leaf node, simply assign - ff->push_back(branch); - - } else if (auto choice = - std::dynamic_pointer_cast(branch)) { - // Choice node, we recurse - ff->push_back(Unique(choice)); - } + for (size_t i = 0; i < choice->nrChoices(); i++) { + auto branch = choice->branches_[i]; + f->push_back(Unique(branch)); } - return ff; +#ifndef GTSAM_DT_NO_MERGING + // If all the branches are the same, we can merge them into one + if (f->allSame_) { + assert(f->branches().size() > 0); + NodePtr f0 = f->branches_[0]; + + // Compute total number of assignments + size_t nrAssignments = 0; + for (auto branch : f->branches()) { + if (auto leaf = std::dynamic_pointer_cast(branch)) { + nrAssignments += leaf->nrAssignments(); + } + } + NodePtr newLeaf( + new Leaf(std::dynamic_pointer_cast(f0)->constant(), + nrAssignments)); + return newLeaf; + } +#endif + return f; + } else { + // Leaf node, return as is + return node; } } @@ -549,7 +543,7 @@ namespace gtsam { template template DecisionTree::DecisionTree( Iterator begin, Iterator end, const L& label) { - root_ = compose(begin, end, label); + root_ = Choice::Unique(compose(begin, end, label)); } /****************************************************************************/ @@ -557,7 +551,7 @@ namespace gtsam { DecisionTree::DecisionTree(const L& label, const DecisionTree& f0, const DecisionTree& f1) { const std::vector functions{f0, f1}; - root_ = compose(functions.begin(), functions.end(), label); + root_ = Choice::Unique(compose(functions.begin(), functions.end(), label)); } /****************************************************************************/ @@ -608,7 +602,7 @@ namespace gtsam { auto choiceOnLabel = std::make_shared(label, end - begin); for (Iterator it = begin; it != end; it++) choiceOnLabel->push_back(it->root_); - return Choice::Unique(choiceOnLabel); + return choiceOnLabel; } else { // Set up a new choice on the highest label auto choiceOnHighestLabel = @@ -737,7 +731,7 @@ namespace gtsam { for (auto&& branch : choice->branches()) { functions.emplace_back(convertFrom(branch, L_of_M, Y_of_X)); } - return LY::compose(functions.begin(), functions.end(), newLabel); + return Choice::Unique(LY::compose(functions.begin(), functions.end(), newLabel)); } /****************************************************************************/ diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index bbce5e8ce..f148cf1d8 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -15,17 +15,20 @@ * @author Duy-Nguyen Ta */ +#include +#include +#include +#include #include #include -#include -#include #include - -#include +#include using namespace std; using namespace gtsam; +using symbol_shorthand::M; + /* ************************************************************************* */ TEST_UNSAFE(DiscreteFactorGraph, debugScheduler) { DiscreteKey PC(0, 4), ME(1, 4), AI(2, 4), A(3, 3); @@ -345,6 +348,67 @@ TEST(DiscreteFactorGraph, markdown) { values[1] = 0; EXPECT_DOUBLES_EQUAL(0.3, graph[0]->operator()(values), 1e-9); } + +TEST(DiscreteFactorGraph, NrAssignments) { + string expected_dfg = R"( +size: 2 +factor 0: f[ (m0,2), (m1,2), (m2,2), ] + Choice(m2) + 0 Choice(m1) + 0 0 Leaf [1] 0 + 0 1 Choice(m0) + 0 1 0 Leaf [1]0.27527634 + 0 1 1 Leaf [1]0.44944733 + 1 Choice(m1) + 1 0 Leaf [1] 0 + 1 1 Choice(m0) + 1 1 0 Leaf [1] 0 + 1 1 1 Leaf [1]0.27527634 +factor 1: f[ (m0,2), (m1,2), (m2,2), (m3,2), ] + Choice(m3) + 0 Choice(m2) + 0 0 Choice(m1) + 0 0 0 Leaf [2] 1 + 0 0 1 Leaf [2]0.015366387 + 0 1 Choice(m1) + 0 1 0 Leaf [2] 1 + 0 1 1 Choice(m0) + 0 1 1 0 Leaf [1] 1 + 0 1 1 1 Leaf [1]0.015365663 + 1 Choice(m2) + 1 0 Choice(m1) + 1 0 0 Leaf [2] 1 + 1 0 1 Choice(m0) + 1 0 1 0 Leaf [1]0.0094115739 + 1 0 1 1 Leaf [1]0.0094115652 + 1 1 Choice(m1) + 1 1 0 Leaf [2] 1 + 1 1 1 Choice(m0) + 1 1 1 0 Leaf [1] 1 + 1 1 1 1 Leaf [1]0.009321081 +)"; + + DiscreteKeys d0{{M(2), 2}, {M(1), 2}, {M(0), 2}}; + std::vector p0 = {0, 0, 0.17054468, 0.27845056, 0, 0, 0, 0.17054468}; + AlgebraicDecisionTree dt(d0, p0); + //TODO(Varun) Passing ADT to DiscreteConditional causes nrAssignments to get messed up + // Issue seems to be in DecisionTreeFactor.cpp L104 + DiscreteConditional f0(3, DecisionTreeFactor(d0, dt)); + + DiscreteKeys d1{{M(0), 2}, {M(1), 2}, {M(2), 2}, {M(3), 2}}; + std::vector p1 = { + 1, 1, 1, 1, 0.015366387, 0.0094115739, 1, 1, + 1, 1, 1, 1, 0.015366387, 0.0094115652, 0.015365663, 0.009321081}; + DecisionTreeFactor f1(d1, p1); + DecisionTree dt1(d1, p1); + + DiscreteFactorGraph dfg; + dfg.add(f0); + dfg.add(f1); + + EXPECT(assert_print_equal(expected_dfg, dfg)); +} + /* ************************************************************************* */ int main() { TestResult tr; From a66e270faacb0a86520750cda8a8d838a110136e Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 8 Jun 2023 18:29:46 -0400 Subject: [PATCH 04/11] print nrAssignments when printing decision trees --- gtsam/discrete/DecisionTree-inl.h | 3 +- gtsam/hybrid/MixtureFactor.h | 2 +- .../tests/testGaussianMixtureFactor.cpp | 4 +- .../tests/testHybridNonlinearFactorGraph.cpp | 38 +++++++++---------- gtsam/hybrid/tests/testMixtureFactor.cpp | 4 +- 5 files changed, 26 insertions(+), 25 deletions(-) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 3e85ba70a..5c2b735a0 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -93,7 +93,8 @@ namespace gtsam { /// print void print(const std::string& s, const LabelFormatter& labelFormatter, const ValueFormatter& valueFormatter) const override { - std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl; + std::cout << s << " Leaf [" << nrAssignments() << "]" + << valueFormatter(constant_) << std::endl; } /** Write graphviz format to stream `os`. */ diff --git a/gtsam/hybrid/MixtureFactor.h b/gtsam/hybrid/MixtureFactor.h index df8e0193a..529c8687b 100644 --- a/gtsam/hybrid/MixtureFactor.h +++ b/gtsam/hybrid/MixtureFactor.h @@ -191,7 +191,7 @@ class MixtureFactor : public HybridFactor { std::cout << "\nMixtureFactor\n"; auto valueFormatter = [](const sharedFactor& v) { if (v) { - return "Nonlinear factor on " + std::to_string(v->size()) + " keys"; + return " Nonlinear factor on " + std::to_string(v->size()) + " keys"; } else { return std::string("nullptr"); } diff --git a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp index 75ba5a059..5207e9372 100644 --- a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp +++ b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp @@ -108,7 +108,7 @@ TEST(GaussianMixtureFactor, Printing) { std::string expected = R"(Hybrid [x1 x2; 1]{ Choice(1) - 0 Leaf : + 0 Leaf [1]: A[x1] = [ 0; 0 @@ -120,7 +120,7 @@ TEST(GaussianMixtureFactor, Printing) { b = [ 0 0 ] No noise model - 1 Leaf : + 1 Leaf [1]: A[x1] = [ 0; 0 diff --git a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp index af3a23b94..7bcaf1762 100644 --- a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp @@ -492,7 +492,7 @@ factor 0: factor 1: Hybrid [x0 x1; m0]{ Choice(m0) - 0 Leaf : + 0 Leaf [1]: A[x0] = [ -1 ] @@ -502,7 +502,7 @@ Hybrid [x0 x1; m0]{ b = [ -1 ] No noise model - 1 Leaf : + 1 Leaf [1]: A[x0] = [ -1 ] @@ -516,7 +516,7 @@ Hybrid [x0 x1; m0]{ factor 2: Hybrid [x1 x2; m1]{ Choice(m1) - 0 Leaf : + 0 Leaf [1]: A[x1] = [ -1 ] @@ -526,7 +526,7 @@ Hybrid [x1 x2; m1]{ b = [ -1 ] No noise model - 1 Leaf : + 1 Leaf [1]: A[x1] = [ -1 ] @@ -550,16 +550,16 @@ factor 4: b = [ -10 ] No noise model factor 5: P( m0 ): - Leaf 0.5 + Leaf [2] 0.5 factor 6: P( m1 | m0 ): Choice(m1) 0 Choice(m0) - 0 0 Leaf 0.33333333 - 0 1 Leaf 0.6 + 0 0 Leaf [1]0.33333333 + 0 1 Leaf [1] 0.6 1 Choice(m0) - 1 0 Leaf 0.66666667 - 1 1 Leaf 0.4 + 1 0 Leaf [1]0.66666667 + 1 1 Leaf [1] 0.4 )"; EXPECT(assert_print_equal(expected_hybridFactorGraph, linearizedFactorGraph)); @@ -570,13 +570,13 @@ size: 3 conditional 0: Hybrid P( x0 | x1 m0) Discrete Keys = (m0, 2), Choice(m0) - 0 Leaf p(x0 | x1) + 0 Leaf [1] p(x0 | x1) R = [ 10.0499 ] S[x1] = [ -0.0995037 ] d = [ -9.85087 ] No noise model - 1 Leaf p(x0 | x1) + 1 Leaf [1] p(x0 | x1) R = [ 10.0499 ] S[x1] = [ -0.0995037 ] d = [ -9.95037 ] @@ -586,26 +586,26 @@ conditional 1: Hybrid P( x1 | x2 m0 m1) Discrete Keys = (m0, 2), (m1, 2), Choice(m1) 0 Choice(m0) - 0 0 Leaf p(x1 | x2) + 0 0 Leaf [1] p(x1 | x2) R = [ 10.099 ] S[x2] = [ -0.0990196 ] d = [ -9.99901 ] No noise model - 0 1 Leaf p(x1 | x2) + 0 1 Leaf [1] p(x1 | x2) R = [ 10.099 ] S[x2] = [ -0.0990196 ] d = [ -9.90098 ] No noise model 1 Choice(m0) - 1 0 Leaf p(x1 | x2) + 1 0 Leaf [1] p(x1 | x2) R = [ 10.099 ] S[x2] = [ -0.0990196 ] d = [ -10.098 ] No noise model - 1 1 Leaf p(x1 | x2) + 1 1 Leaf [1] p(x1 | x2) R = [ 10.099 ] S[x2] = [ -0.0990196 ] d = [ -10 ] @@ -615,14 +615,14 @@ conditional 2: Hybrid P( x2 | m0 m1) Discrete Keys = (m0, 2), (m1, 2), Choice(m1) 0 Choice(m0) - 0 0 Leaf p(x2) + 0 0 Leaf [1] p(x2) R = [ 10.0494 ] d = [ -10.1489 ] mean: 1 elements x2: -1.0099 No noise model - 0 1 Leaf p(x2) + 0 1 Leaf [1] p(x2) R = [ 10.0494 ] d = [ -10.1479 ] mean: 1 elements @@ -630,14 +630,14 @@ conditional 2: Hybrid P( x2 | m0 m1) No noise model 1 Choice(m0) - 1 0 Leaf p(x2) + 1 0 Leaf [1] p(x2) R = [ 10.0494 ] d = [ -10.0504 ] mean: 1 elements x2: -1.0001 No noise model - 1 1 Leaf p(x2) + 1 1 Leaf [1] p(x2) R = [ 10.0494 ] d = [ -10.0494 ] mean: 1 elements diff --git a/gtsam/hybrid/tests/testMixtureFactor.cpp b/gtsam/hybrid/tests/testMixtureFactor.cpp index 67a7fd8ae..03fdccff2 100644 --- a/gtsam/hybrid/tests/testMixtureFactor.cpp +++ b/gtsam/hybrid/tests/testMixtureFactor.cpp @@ -63,8 +63,8 @@ TEST(MixtureFactor, Printing) { R"(Hybrid [x1 x2; 1] MixtureFactor Choice(1) - 0 Leaf Nonlinear factor on 2 keys - 1 Leaf Nonlinear factor on 2 keys + 0 Leaf [1] Nonlinear factor on 2 keys + 1 Leaf [1] Nonlinear factor on 2 keys )"; EXPECT(assert_print_equal(expected, mixtureFactor)); } From 76568f2d7395d46af93229ae3fe97470416ab992 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 9 Jun 2023 10:18:36 -0400 Subject: [PATCH 05/11] formatting --- gtsam/hybrid/HybridBayesTree.cpp | 3 +-- gtsam/hybrid/HybridGaussianISAM.cpp | 7 +++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/gtsam/hybrid/HybridBayesTree.cpp b/gtsam/hybrid/HybridBayesTree.cpp index b252e613e..dc1c875e1 100644 --- a/gtsam/hybrid/HybridBayesTree.cpp +++ b/gtsam/hybrid/HybridBayesTree.cpp @@ -173,8 +173,7 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const { /* ************************************************************************* */ void HybridBayesTree::prune(const size_t maxNrLeaves) { - auto decisionTree = - this->roots_.at(0)->conditional()->asDiscrete(); + auto decisionTree = this->roots_.at(0)->conditional()->asDiscrete(); DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves); decisionTree->root_ = prunedDecisionTree.root_; diff --git a/gtsam/hybrid/HybridGaussianISAM.cpp b/gtsam/hybrid/HybridGaussianISAM.cpp index 0dd5fa38b..6f8b7b9ff 100644 --- a/gtsam/hybrid/HybridGaussianISAM.cpp +++ b/gtsam/hybrid/HybridGaussianISAM.cpp @@ -70,8 +70,7 @@ Ordering HybridGaussianISAM::GetOrdering( /* ************************************************************************* */ void HybridGaussianISAM::updateInternal( const HybridGaussianFactorGraph& newFactors, - HybridBayesTree::Cliques* orphans, - const std::optional& maxNrLeaves, + HybridBayesTree::Cliques* orphans, const std::optional& maxNrLeaves, const std::optional& ordering, const HybridBayesTree::Eliminate& function) { // Remove the contaminated part of the Bayes tree @@ -101,8 +100,8 @@ void HybridGaussianISAM::updateInternal( } // eliminate all factors (top, added, orphans) into a new Bayes tree - HybridBayesTree::shared_ptr bayesTree = - factors.eliminateMultifrontal(elimination_ordering, function, std::cref(index)); + HybridBayesTree::shared_ptr bayesTree = factors.eliminateMultifrontal( + elimination_ordering, function, std::cref(index)); if (maxNrLeaves) { bayesTree->prune(*maxNrLeaves); From 29c1816a81b245a4f599db7a47caa5a80db0f5f7 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 9 Jun 2023 20:13:06 -0400 Subject: [PATCH 06/11] change to GTSAM_DT_MERGING and expose via CMake --- cmake/HandleGeneralOptions.cmake | 1 + gtsam/config.h.in | 3 +++ gtsam/discrete/DecisionTree-inl.h | 2 +- gtsam/discrete/tests/testAlgebraicDecisionTree.cpp | 1 - gtsam/discrete/tests/testDecisionTree.cpp | 1 - 5 files changed, 5 insertions(+), 3 deletions(-) diff --git a/cmake/HandleGeneralOptions.cmake b/cmake/HandleGeneralOptions.cmake index 9ebb07331..c5fd9898c 100644 --- a/cmake/HandleGeneralOptions.cmake +++ b/cmake/HandleGeneralOptions.cmake @@ -19,6 +19,7 @@ option(GTSAM_FORCE_STATIC_LIB "Force gtsam to be a static library, option(GTSAM_USE_QUATERNIONS "Enable/Disable using an internal Quaternion representation for rotations instead of rotation matrices. If enable, Rot3::EXPMAP is enforced by default." OFF) option(GTSAM_POSE3_EXPMAP "Enable/Disable using Pose3::EXPMAP as the default mode. If disabled, Pose3::FIRST_ORDER will be used." ON) option(GTSAM_ROT3_EXPMAP "Ignore if GTSAM_USE_QUATERNIONS is OFF (Rot3::EXPMAP by default). Otherwise, enable Rot3::EXPMAP, or if disabled, use Rot3::CAYLEY." ON) +option(GTSAM_DT_MERGING "Enable/Disable merging of equal leaf nodes in DecisionTrees. This leads to significant speed up and memory savings." ON) option(GTSAM_ENABLE_CONSISTENCY_CHECKS "Enable/Disable expensive consistency checks" OFF) option(GTSAM_ENABLE_MEMORY_SANITIZER "Enable/Disable memory sanitizer" OFF) option(GTSAM_WITH_TBB "Use Intel Threaded Building Blocks (TBB) if available" ON) diff --git a/gtsam/config.h.in b/gtsam/config.h.in index 7f8936d1e..7c08d36bf 100644 --- a/gtsam/config.h.in +++ b/gtsam/config.h.in @@ -39,6 +39,9 @@ #cmakedefine GTSAM_ROT3_EXPMAP #endif +// Whether to enable merging of equal leaf nodes in the Discrete Decision Tree. +#cmakedefine GTSAM_DT_MERGING + // Whether we are using TBB (if TBB was found and GTSAM_WITH_TBB is enabled in CMake) #cmakedefine GTSAM_USE_TBB diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 5c2b735a0..2f36007a9 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -212,7 +212,7 @@ namespace gtsam { f->push_back(Unique(branch)); } -#ifndef GTSAM_DT_NO_MERGING +#ifdef GTSAM_DT_MERGING // If all the branches are the same, we can merge them into one if (f->allSame_) { assert(f->branches().size() > 0); diff --git a/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp index d7e7a071c..55f5b61d7 100644 --- a/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp +++ b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp @@ -20,7 +20,6 @@ #include // make sure we have traits #include // headers first to make sure no missing headers -//#define GTSAM_DT_NO_MERGING #include #include // for convert only #define DISABLE_TIMING diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index 0559a782b..336945503 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -18,7 +18,6 @@ */ // #define DT_DEBUG_MEMORY -// #define GTSAM_DT_NO_MERGING #define DISABLE_DOT #include #include From 895998268694d16fa61224481adab9fb3123cc04 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 14 Jun 2023 15:23:14 -0400 Subject: [PATCH 07/11] remove extra calls to Unique --- gtsam/discrete/DecisionTree-inl.h | 5 +++-- gtsam/discrete/tests/testDiscreteFactorGraph.cpp | 1 - 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 2f36007a9..8dc19ea21 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -544,7 +544,7 @@ namespace gtsam { template template DecisionTree::DecisionTree( Iterator begin, Iterator end, const L& label) { - root_ = Choice::Unique(compose(begin, end, label)); + root_ = compose(begin, end, label); } /****************************************************************************/ @@ -552,7 +552,7 @@ namespace gtsam { DecisionTree::DecisionTree(const L& label, const DecisionTree& f0, const DecisionTree& f1) { const std::vector functions{f0, f1}; - root_ = Choice::Unique(compose(functions.begin(), functions.end(), label)); + root_ = compose(functions.begin(), functions.end(), label); } /****************************************************************************/ @@ -603,6 +603,7 @@ namespace gtsam { auto choiceOnLabel = std::make_shared(label, end - begin); for (Iterator it = begin; it != end; it++) choiceOnLabel->push_back(it->root_); + // If no reordering, no need to call Choice::Unique return choiceOnLabel; } else { // Set up a new choice on the highest label diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index f148cf1d8..6752dbd4a 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -400,7 +400,6 @@ factor 1: f[ (m0,2), (m1,2), (m2,2), (m3,2), ] 1, 1, 1, 1, 0.015366387, 0.0094115739, 1, 1, 1, 1, 1, 1, 0.015366387, 0.0094115652, 0.015365663, 0.009321081}; DecisionTreeFactor f1(d1, p1); - DecisionTree dt1(d1, p1); DiscreteFactorGraph dfg; dfg.add(f0); From b24f20afe1b45cf75d809404c96a115ca4bfe565 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 26 Jun 2023 18:04:53 -0400 Subject: [PATCH 08/11] fix tests to work when GTSAM_DT_MERGING=OFF --- .../tests/testAlgebraicDecisionTree.cpp | 20 +++++ gtsam/discrete/tests/testDecisionTree.cpp | 59 ++++++++++++- .../tests/testDiscreteFactorGraph.cpp | 54 ++++++++++++ gtsam/hybrid/tests/testHybridBayesNet.cpp | 4 + .../tests/testHybridNonlinearFactorGraph.cpp | 87 +++++++++++++++++++ 5 files changed, 223 insertions(+), 1 deletion(-) diff --git a/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp index 55f5b61d7..19f4686c2 100644 --- a/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp +++ b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp @@ -178,7 +178,11 @@ TEST(ADT, joint) { dot(joint, "Asia-ASTLBEX"); joint = apply(joint, pD, &mul); dot(joint, "Asia-ASTLBEXD"); +#ifdef GTSAM_DT_MERGING EXPECT_LONGS_EQUAL(346, muls); +#else + EXPECT_LONGS_EQUAL(508, muls); +#endif gttoc_(asiaJoint); tictoc_getNode(asiaJointNode, asiaJoint); elapsed = asiaJointNode->secs() + asiaJointNode->wall(); @@ -239,7 +243,11 @@ TEST(ADT, inference) { dot(joint, "Joint-Product-ASTLBEX"); joint = apply(joint, pD, &mul); dot(joint, "Joint-Product-ASTLBEXD"); +#ifdef GTSAM_DT_MERGING EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering +#else + EXPECT_LONGS_EQUAL(508, (long)muls); // different ordering +#endif gttoc_(asiaProd); tictoc_getNode(asiaProdNode, asiaProd); elapsed = asiaProdNode->secs() + asiaProdNode->wall(); @@ -257,7 +265,11 @@ TEST(ADT, inference) { dot(marginal, "Joint-Sum-ADBLE"); marginal = marginal.combine(E, &add_); dot(marginal, "Joint-Sum-ADBL"); +#ifdef GTSAM_DT_MERGING EXPECT_LONGS_EQUAL(161, (long)adds); +#else + EXPECT_LONGS_EQUAL(240, (long)adds); +#endif gttoc_(asiaSum); tictoc_getNode(asiaSumNode, asiaSum); elapsed = asiaSumNode->secs() + asiaSumNode->wall(); @@ -295,7 +307,11 @@ TEST(ADT, factor_graph) { fg = apply(fg, pX, &mul); fg = apply(fg, pD, &mul); dot(fg, "FactorGraph"); +#ifdef GTSAM_DT_MERGING EXPECT_LONGS_EQUAL(158, (long)muls); +#else + EXPECT_LONGS_EQUAL(188, (long)muls); +#endif gttoc_(asiaFG); tictoc_getNode(asiaFGNode, asiaFG); elapsed = asiaFGNode->secs() + asiaFGNode->wall(); @@ -314,7 +330,11 @@ TEST(ADT, factor_graph) { dot(fg, "Marginalized-3E"); fg = fg.combine(L, &add_); dot(fg, "Marginalized-2L"); +#ifdef GTSAM_DT_MERGING LONGS_EQUAL(49, adds); +#else + LONGS_EQUAL(62, adds); +#endif gttoc_(marg); tictoc_getNode(margNode, marg); elapsed = margNode->secs() + margNode->wall(); diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index 336945503..653360fb7 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -191,7 +191,11 @@ TEST(DecisionTree, example) { // Test choose 0 DT actual0 = notba.choose(A, 0); +#ifdef GTSAM_DT_MERGING EXPECT(assert_equal(DT(0.0), actual0)); +#else + // EXPECT(assert_equal(DT({0.0, 0.0}), actual0)); +#endif DOT(actual0); // Test choose 1 @@ -332,9 +336,11 @@ TEST(DecisionTree, NrAssignments) { EXPECT_LONGS_EQUAL(8, tree.nrAssignments()); +#ifdef GTSAM_DT_MERGING EXPECT(tree.root_->isLeaf()); auto leaf = std::dynamic_pointer_cast(tree.root_); EXPECT_LONGS_EQUAL(8, leaf->nrAssignments()); +#endif DT tree2({C, B, A}, "1 1 1 2 3 4 5 5"); /* The tree is @@ -357,6 +363,8 @@ TEST(DecisionTree, NrAssignments) { CHECK(root); auto choice0 = std::dynamic_pointer_cast(root->branches()[0]); CHECK(choice0); + +#ifdef GTSAM_DT_MERGING EXPECT(choice0->branches()[0]->isLeaf()); auto choice00 = std::dynamic_pointer_cast(choice0->branches()[0]); CHECK(choice00); @@ -370,6 +378,7 @@ TEST(DecisionTree, NrAssignments) { CHECK(choice11); EXPECT(choice11->isLeaf()); EXPECT_LONGS_EQUAL(2, choice11->nrAssignments()); +#endif } /* ************************************************************************** */ @@ -411,27 +420,61 @@ TEST(DecisionTree, VisitWithPruned) { }; tree.visitWith(func); +#ifdef GTSAM_DT_MERGING EXPECT_LONGS_EQUAL(6, choices.size()); +#else + EXPECT_LONGS_EQUAL(8, choices.size()); +#endif Assignment expectedAssignment; +#ifdef GTSAM_DT_MERGING expectedAssignment = {{"B", 0}, {"C", 0}}; EXPECT(expectedAssignment == choices.at(0)); +#else + expectedAssignment = {{"A", 0}, {"B", 0}, {"C", 0}}; + EXPECT(expectedAssignment == choices.at(0)); +#endif +#ifdef GTSAM_DT_MERGING expectedAssignment = {{"A", 0}, {"B", 1}, {"C", 0}}; EXPECT(expectedAssignment == choices.at(1)); +#else + expectedAssignment = {{"A", 1}, {"B", 0}, {"C", 0}}; + EXPECT(expectedAssignment == choices.at(1)); +#endif +#ifdef GTSAM_DT_MERGING expectedAssignment = {{"A", 1}, {"B", 1}, {"C", 0}}; EXPECT(expectedAssignment == choices.at(2)); +#else + expectedAssignment = {{"A", 0}, {"B", 1}, {"C", 0}}; + EXPECT(expectedAssignment == choices.at(2)); +#endif +#ifdef GTSAM_DT_MERGING expectedAssignment = {{"B", 0}, {"C", 1}}; EXPECT(expectedAssignment == choices.at(3)); +#else + expectedAssignment = {{"A", 1}, {"B", 1}, {"C", 0}}; + EXPECT(expectedAssignment == choices.at(3)); +#endif +#ifdef GTSAM_DT_MERGING expectedAssignment = {{"A", 0}, {"B", 1}, {"C", 1}}; EXPECT(expectedAssignment == choices.at(4)); +#else + expectedAssignment = {{"A", 0}, {"B", 0}, {"C", 1}}; + EXPECT(expectedAssignment == choices.at(4)); +#endif +#ifdef GTSAM_DT_MERGING expectedAssignment = {{"A", 1}, {"B", 1}, {"C", 1}}; EXPECT(expectedAssignment == choices.at(5)); +#else + expectedAssignment = {{"A", 1}, {"B", 0}, {"C", 1}}; + EXPECT(expectedAssignment == choices.at(5)); +#endif } /* ************************************************************************** */ @@ -442,7 +485,11 @@ TEST(DecisionTree, fold) { DT tree(B, DT(A, 1, 1), DT(A, 2, 3)); auto add = [](const int& y, double x) { return y + x; }; double sum = tree.fold(add, 0.0); - EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); // Note, not 7, due to pruning! +#ifdef GTSAM_DT_MERGING + EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); // Note, not 7, due to merging! +#else + EXPECT_DOUBLES_EQUAL(7.0, sum, 1e-9); +#endif } /* ************************************************************************** */ @@ -494,9 +541,14 @@ TEST(DecisionTree, threshold) { auto threshold = [](int value) { return value < 5 ? 0 : value; }; DT thresholded(tree, threshold); +#ifdef GTSAM_DT_MERGING // Check number of leaves equal to zero now = 2 // Note: it is 2, because the pruned branches are counted as 1! EXPECT_LONGS_EQUAL(2, thresholded.fold(count, 0)); +#else + // if GTSAM_DT_MERGING is disabled, the count will be larger + EXPECT_LONGS_EQUAL(5, thresholded.fold(count, 0)); +#endif } /* ************************************************************************** */ @@ -532,8 +584,13 @@ TEST(DecisionTree, ApplyWithAssignment) { }; DT prunedTree2 = prunedTree.apply(counter); +#ifdef GTSAM_DT_MERGING // Check if apply doesn't enumerate all leaves. EXPECT_LONGS_EQUAL(5, count); +#else + // if GTSAM_DT_MERGING is disabled, the count will be full + EXPECT_LONGS_EQUAL(8, count); +#endif } /* ************************************************************************** */ diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index 6752dbd4a..33fa933d2 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -350,6 +350,7 @@ TEST(DiscreteFactorGraph, markdown) { } TEST(DiscreteFactorGraph, NrAssignments) { +#ifdef GTSAM_DT_MERGING string expected_dfg = R"( size: 2 factor 0: f[ (m0,2), (m1,2), (m2,2), ] @@ -387,6 +388,59 @@ factor 1: f[ (m0,2), (m1,2), (m2,2), (m3,2), ] 1 1 1 0 Leaf [1] 1 1 1 1 1 Leaf [1]0.009321081 )"; +#else + string expected_dfg = R"( +size: 2 +factor 0: f[ (m0,2), (m1,2), (m2,2), ] + Choice(m2) + 0 Choice(m1) + 0 0 Choice(m0) + 0 0 0 Leaf [1] 0 + 0 0 1 Leaf [1] 0 + 0 1 Choice(m0) + 0 1 0 Leaf [1]0.27527634 + 0 1 1 Leaf [1]0.44944733 + 1 Choice(m1) + 1 0 Choice(m0) + 1 0 0 Leaf [1] 0 + 1 0 1 Leaf [1] 0 + 1 1 Choice(m0) + 1 1 0 Leaf [1] 0 + 1 1 1 Leaf [1]0.27527634 +factor 1: f[ (m0,2), (m1,2), (m2,2), (m3,2), ] + Choice(m3) + 0 Choice(m2) + 0 0 Choice(m1) + 0 0 0 Choice(m0) + 0 0 0 0 Leaf [1] 1 + 0 0 0 1 Leaf [1] 1 + 0 0 1 Choice(m0) + 0 0 1 0 Leaf [1]0.015366387 + 0 0 1 1 Leaf [1]0.015366387 + 0 1 Choice(m1) + 0 1 0 Choice(m0) + 0 1 0 0 Leaf [1] 1 + 0 1 0 1 Leaf [1] 1 + 0 1 1 Choice(m0) + 0 1 1 0 Leaf [1] 1 + 0 1 1 1 Leaf [1]0.015365663 + 1 Choice(m2) + 1 0 Choice(m1) + 1 0 0 Choice(m0) + 1 0 0 0 Leaf [1] 1 + 1 0 0 1 Leaf [1] 1 + 1 0 1 Choice(m0) + 1 0 1 0 Leaf [1]0.0094115739 + 1 0 1 1 Leaf [1]0.0094115652 + 1 1 Choice(m1) + 1 1 0 Choice(m0) + 1 1 0 0 Leaf [1] 1 + 1 1 0 1 Leaf [1] 1 + 1 1 1 Choice(m0) + 1 1 1 0 Leaf [1] 1 + 1 1 1 1 Leaf [1]0.009321081 +)"; +#endif DiscreteKeys d0{{M(2), 2}, {M(1), 2}, {M(0), 2}}; std::vector p0 = {0, 0, 0.17054468, 0.27845056, 0, 0, 0, 0.17054468}; diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index f25675a55..d2f39c6ed 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -288,8 +288,12 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) { std::make_shared( discreteConditionals->prune(maxNrLeaves)); +#ifdef GTSAM_DT_MERGING EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/, prunedDecisionTree->nrLeaves()); +#else + EXPECT_LONGS_EQUAL(8 /*full tree*/, prunedDecisionTree->nrLeaves()); +#endif auto original_discrete_conditionals = *(posterior->at(4)->asDiscrete()); diff --git a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp index 7bcaf1762..0a621c42d 100644 --- a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp @@ -481,6 +481,7 @@ TEST(HybridFactorGraph, Printing) { const auto [hybridBayesNet, remainingFactorGraph] = linearizedFactorGraph.eliminatePartialSequential(ordering); +#ifdef GTSAM_DT_MERGING string expected_hybridFactorGraph = R"( size: 7 factor 0: @@ -562,6 +563,92 @@ factor 6: P( m1 | m0 ): 1 1 Leaf [1] 0.4 )"; +#else +string expected_hybridFactorGraph = R"( +size: 7 +factor 0: + A[x0] = [ + 10 +] + b = [ -10 ] + No noise model +factor 1: +Hybrid [x0 x1; m0]{ + Choice(m0) + 0 Leaf [1]: + A[x0] = [ + -1 +] + A[x1] = [ + 1 +] + b = [ -1 ] + No noise model + + 1 Leaf [1]: + A[x0] = [ + -1 +] + A[x1] = [ + 1 +] + b = [ -0 ] + No noise model + +} +factor 2: +Hybrid [x1 x2; m1]{ + Choice(m1) + 0 Leaf [1]: + A[x1] = [ + -1 +] + A[x2] = [ + 1 +] + b = [ -1 ] + No noise model + + 1 Leaf [1]: + A[x1] = [ + -1 +] + A[x2] = [ + 1 +] + b = [ -0 ] + No noise model + +} +factor 3: + A[x1] = [ + 10 +] + b = [ -10 ] + No noise model +factor 4: + A[x2] = [ + 10 +] + b = [ -10 ] + No noise model +factor 5: P( m0 ): + Choice(m0) + 0 Leaf [1] 0.5 + 1 Leaf [1] 0.5 + +factor 6: P( m1 | m0 ): + Choice(m1) + 0 Choice(m0) + 0 0 Leaf [1]0.33333333 + 0 1 Leaf [1] 0.6 + 1 Choice(m0) + 1 0 Leaf [1]0.66666667 + 1 1 Leaf [1] 0.4 + +)"; +#endif + EXPECT(assert_print_equal(expected_hybridFactorGraph, linearizedFactorGraph)); // Expected output for hybridBayesNet. From 8ffddc4077fd2522c4493016900c4a6425c49d2d Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 26 Jun 2023 18:05:05 -0400 Subject: [PATCH 09/11] print GTSAM_DT_MERGING cmake config --- cmake/HandlePrintConfiguration.cmake | 1 + 1 file changed, 1 insertion(+) diff --git a/cmake/HandlePrintConfiguration.cmake b/cmake/HandlePrintConfiguration.cmake index c5c3920cb..42fae90f7 100644 --- a/cmake/HandlePrintConfiguration.cmake +++ b/cmake/HandlePrintConfiguration.cmake @@ -90,6 +90,7 @@ print_enabled_config(${GTSAM_ENABLE_CONSISTENCY_CHECKS} "Runtime consistency c print_enabled_config(${GTSAM_ENABLE_MEMORY_SANITIZER} "Build with Memory Sanitizer ") print_enabled_config(${GTSAM_ROT3_EXPMAP} "Rot3 retract is full ExpMap ") print_enabled_config(${GTSAM_POSE3_EXPMAP} "Pose3 retract is full ExpMap ") +print_enabled_config(${GTSAM_DT_MERGING} "Enable branch merging in DecisionTree") print_enabled_config(${GTSAM_ALLOW_DEPRECATED_SINCE_V43} "Allow features deprecated in GTSAM 4.3") print_enabled_config(${GTSAM_SUPPORT_NESTED_DISSECTION} "Metis-based Nested Dissection ") print_enabled_config(${GTSAM_TANGENT_PREINTEGRATION} "Use tangent-space preintegration") From e5fea0da5204a1d9e48226e9eef4175a77e3d8fd Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 26 Jun 2023 18:16:43 -0400 Subject: [PATCH 10/11] update docstring --- gtsam/discrete/DecisionTree-inl.h | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 8dc19ea21..b65cc6bcf 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -199,11 +199,26 @@ namespace gtsam { #endif } - /// If all branches of a choice node f are the same, just return a branch. + /** + * @brief Merge branches with equal leaf values for every choice node in a + * decision tree. If all branches are the same (i.e. have the same leaf + * value), replace the choice node with the equivalent leaf node. + * + * This function applies the branch merging (if enabled) recursively on the + * decision tree represented by the root node passed in as the argument. It + * recurses to the leaf nodes and merges branches with equal leaf values in + * a bottom-up fashion. + * + * Thus, if all branches of a choice node `f` are the same, + * just return a single branch at each recursion step. + * + * @param node The root node of the decision tree. + * @return NodePtr + */ static NodePtr Unique(const NodePtr& node) { if (auto choice = std::dynamic_pointer_cast(node)) { // Choice node, we recurse! - // Make non-const copy + // Make non-const copy so we can update auto f = std::make_shared(choice->label(), choice->nrChoices()); // Iterate over all the branches From 9b7f4b3f54b4be1e913fd44cc31d4fd437836d27 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 28 Jun 2023 10:12:13 -0400 Subject: [PATCH 11/11] fix test case --- gtsam/discrete/DecisionTree-inl.h | 10 +++++----- gtsam/discrete/tests/testDiscreteFactorGraph.cpp | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index b65cc6bcf..156177d03 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -137,7 +137,9 @@ namespace gtsam { // Applying binary operator to two leaves results in a leaf NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override { // fL op gL - NodePtr h(new Leaf(op(fL.constant_, constant_), nrAssignments_)); + // TODO(Varun) nrAssignments setting is not correct. + // Depending on f and g, the nrAssignments can be different. This is a bug! + NodePtr h(new Leaf(op(fL.constant_, constant_), fL.nrAssignments())); return h; } @@ -496,13 +498,11 @@ namespace gtsam { // DecisionTree /****************************************************************************/ template - DecisionTree::DecisionTree() { - } + DecisionTree::DecisionTree() {} template DecisionTree::DecisionTree(const NodePtr& root) : - root_(root) { - } + root_(root) {} /****************************************************************************/ template diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index 33fa933d2..6e8621595 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -356,14 +356,14 @@ size: 2 factor 0: f[ (m0,2), (m1,2), (m2,2), ] Choice(m2) 0 Choice(m1) - 0 0 Leaf [1] 0 + 0 0 Leaf [2] 0 0 1 Choice(m0) 0 1 0 Leaf [1]0.27527634 - 0 1 1 Leaf [1]0.44944733 + 0 1 1 Leaf [1] 0 1 Choice(m1) - 1 0 Leaf [1] 0 + 1 0 Leaf [2] 0 1 1 Choice(m0) - 1 1 0 Leaf [1] 0 + 1 1 0 Leaf [1]0.44944733 1 1 1 Leaf [1]0.27527634 factor 1: f[ (m0,2), (m1,2), (m2,2), (m3,2), ] Choice(m3) @@ -442,7 +442,7 @@ factor 1: f[ (m0,2), (m1,2), (m2,2), (m3,2), ] )"; #endif - DiscreteKeys d0{{M(2), 2}, {M(1), 2}, {M(0), 2}}; + DiscreteKeys d0{{M(0), 2}, {M(1), 2}, {M(2), 2}}; std::vector p0 = {0, 0, 0.17054468, 0.27845056, 0, 0, 0, 0.17054468}; AlgebraicDecisionTree dt(d0, p0); //TODO(Varun) Passing ADT to DiscreteConditional causes nrAssignments to get messed up