From 1dfb388587b86c6808e354b8b3086c3725bdfcb4 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 20 Jul 2023 15:47:29 -0400 Subject: [PATCH] fix odd behavior in nrAssignments --- gtsam/discrete/DecisionTree-inl.h | 29 ++++++++++++------- gtsam/discrete/DecisionTree.h | 8 +++-- gtsam/discrete/tests/testDecisionTree.cpp | 2 +- .../tests/testDiscreteFactorGraph.cpp | 11 ++++--- 4 files changed, 30 insertions(+), 20 deletions(-) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index c3f6b3c0e..541bd77f4 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -137,8 +137,8 @@ namespace gtsam { // Applying binary operator to two leaves results in a leaf NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override { // fL op gL - // TODO(Varun) nrAssignments setting is not correct. - // Depending on f and g, the nrAssignments can be different. This is a bug! + // The nrAssignments is always set to fL since we consider g operating on + // (or modifying) f. NodePtr h(new Leaf(op(fL.constant_, constant_), fL.nrAssignments())); return h; } @@ -149,8 +149,9 @@ namespace gtsam { } /** choose a branch, create new memory ! */ - NodePtr choose(const L& label, size_t index) const override { - return NodePtr(new Leaf(constant(), nrAssignments())); + NodePtr choose(const L& label, size_t index, + bool make_unique = true) const override { + return NodePtr(new Leaf(constant(), 1)); } bool isLeaf() const override { return true; } @@ -468,14 +469,22 @@ namespace gtsam { } /** choose a branch, recursively */ - NodePtr choose(const L& label, size_t index) const override { + NodePtr choose(const L& label, size_t index, + bool make_unique = true) const override { if (label_ == label) return branches_[index]; // choose branch // second case, not label of interest, just recurse auto r = std::make_shared(label_, branches_.size()); - for (auto&& branch : branches_) - r->push_back(branch->choose(label, index)); - return Unique(r); + for (auto&& branch : branches_) { + r->push_back(branch->choose(label, index, make_unique)); + } + + if (make_unique) { + return Unique(r); + } else { + return r; + } + // return Unique(r); } private: @@ -997,9 +1006,9 @@ namespace gtsam { template DecisionTree DecisionTree::combine(const L& label, size_t cardinality, const Binary& op) const { - DecisionTree result = choose(label, 0); + DecisionTree result = choose(label, 0, false); for (size_t index = 1; index < cardinality; index++) { - DecisionTree chosen = choose(label, index); + DecisionTree chosen = choose(label, index, false); result = result.apply(chosen, op); } return result; diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 48a1a4596..7761de84e 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -129,7 +129,8 @@ namespace gtsam { virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0; virtual Ptr apply_g_op_fL(const Leaf&, const Binary&) const = 0; virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0; - virtual Ptr choose(const L& label, size_t index) const = 0; + virtual Ptr choose(const L& label, size_t index, + bool make_unique = true) const = 0; virtual bool isLeaf() const = 0; private: @@ -403,8 +404,9 @@ namespace gtsam { /** create a new function where value(label)==index * It's like "restrict" in Darwiche09book pg329, 330? */ - DecisionTree choose(const L& label, size_t index) const { - NodePtr newRoot = root_->choose(label, index); + DecisionTree choose(const L& label, size_t index, + bool make_unique = true) const { + NodePtr newRoot = root_->choose(label, index, make_unique); return DecisionTree(newRoot); } diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index e94c72c34..ff50c8952 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -236,7 +236,7 @@ TEST(DecisionTree, Example) { #ifdef GTSAM_DT_MERGING EXPECT(assert_equal(DT(0.0), actual0)); #else - // EXPECT(assert_equal(DT({0.0, 0.0}), actual0)); + EXPECT(assert_equal(DT({0.0, 0.0}), actual0)); #endif DOT(actual0); diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index 08b5f2db9..871583a2c 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -349,6 +349,7 @@ TEST(DiscreteFactorGraph, markdown) { EXPECT_DOUBLES_EQUAL(0.3, graph[0]->operator()(values), 1e-9); } +/* ************************************************************************* */ TEST(DiscreteFactorGraph, NrAssignments) { #ifdef GTSAM_DT_MERGING string expected_dfg = R"( @@ -358,13 +359,13 @@ factor 0: f[ (m0,2), (m1,2), (m2,2), ] 0 Choice(m1) 0 0 Leaf [2] 0 0 1 Choice(m0) - 0 1 0 Leaf [1] 0.27527634 + 0 1 0 Leaf [1] 0.17054468 0 1 1 Leaf [1] 0 1 Choice(m1) 1 0 Leaf [2] 0 1 1 Choice(m0) - 1 1 0 Leaf [1] 0.44944733 - 1 1 1 Leaf [1] 0.27527634 + 1 1 0 Leaf [1] 0.27845056 + 1 1 1 Leaf [1] 0.17054468 factor 1: f[ (m0,2), (m1,2), (m2,2), (m3,2), ] Choice(m3) 0 Choice(m2) @@ -445,9 +446,7 @@ factor 1: f[ (m0,2), (m1,2), (m2,2), (m3,2), ] DiscreteKeys d0{{M(0), 2}, {M(1), 2}, {M(2), 2}}; std::vector p0 = {0, 0, 0.17054468, 0.27845056, 0, 0, 0, 0.17054468}; AlgebraicDecisionTree dt(d0, p0); - //TODO(Varun) Passing ADT to DiscreteConditional causes nrAssignments to get messed up - // Issue seems to be in DecisionTreeFactor.cpp L104 - DiscreteConditional f0(3, DecisionTreeFactor(d0, dt)); + DiscreteConditional f0(3, d0, dt); DiscreteKeys d1{{M(0), 2}, {M(1), 2}, {M(2), 2}, {M(3), 2}}; std::vector p1 = {