parent
							
								
									647d3c0744
								
							
						
					
					
						commit
						2db08281c6
					
				|  | @ -53,17 +53,26 @@ namespace gtsam { | |||
|     /** constant stored in this leaf */ | ||||
|     Y constant_; | ||||
| 
 | ||||
|     /** The number of assignments contained within this leaf.
 | ||||
|      * Particularly useful when leaves have been pruned. | ||||
|      */ | ||||
|     size_t nrAssignments_; | ||||
| 
 | ||||
|     /// Default constructor for serialization.
 | ||||
|     Leaf() {} | ||||
| 
 | ||||
|     /// Constructor from constant
 | ||||
|     Leaf(const Y& constant) : constant_(constant) {} | ||||
|     Leaf(const Y& constant, size_t nrAssignments = 1) | ||||
|         : constant_(constant), nrAssignments_(nrAssignments) {} | ||||
| 
 | ||||
|     /// Return the constant
 | ||||
|     const Y& constant() const { | ||||
|       return constant_; | ||||
|     } | ||||
| 
 | ||||
|     /// Return the number of assignments contained within this leaf.
 | ||||
|     size_t nrAssignments() const { return nrAssignments_; } | ||||
| 
 | ||||
|     /// Leaf-Leaf equality
 | ||||
|     bool sameLeaf(const Leaf& q) const override { | ||||
|       return constant_ == q.constant_; | ||||
|  | @ -84,7 +93,8 @@ namespace gtsam { | |||
|     /// print
 | ||||
|     void print(const std::string& s, const LabelFormatter& labelFormatter, | ||||
|                const ValueFormatter& valueFormatter) const override { | ||||
|       std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl; | ||||
|       std::cout << s << " Leaf [" << nrAssignments() << "]" | ||||
|                 << valueFormatter(constant_) << std::endl; | ||||
|     } | ||||
| 
 | ||||
|     /** Write graphviz format to stream `os`. */ | ||||
|  | @ -104,14 +114,14 @@ namespace gtsam { | |||
| 
 | ||||
|     /** apply unary operator */ | ||||
|     NodePtr apply(const Unary& op) const override { | ||||
|       NodePtr f(new Leaf(op(constant_))); | ||||
|       NodePtr f(new Leaf(op(constant_), nrAssignments_)); | ||||
|       return f; | ||||
|     } | ||||
| 
 | ||||
|     /// Apply unary operator with assignment
 | ||||
|     NodePtr apply(const UnaryAssignment& op, | ||||
|                   const Assignment<L>& assignment) const override { | ||||
|       NodePtr f(new Leaf(op(assignment, constant_))); | ||||
|       NodePtr f(new Leaf(op(assignment, constant_), nrAssignments_)); | ||||
|       return f; | ||||
|     } | ||||
| 
 | ||||
|  | @ -127,7 +137,9 @@ namespace gtsam { | |||
|     // Applying binary operator to two leaves results in a leaf
 | ||||
|     NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override { | ||||
|       // fL op gL
 | ||||
|       NodePtr h(new Leaf(op(fL.constant_, constant_))); | ||||
|       // TODO(Varun) nrAssignments setting is not correct.
 | ||||
|       // Depending on f and g, the nrAssignments can be different. This is a bug!
 | ||||
|       NodePtr h(new Leaf(op(fL.constant_, constant_), fL.nrAssignments())); | ||||
|       return h; | ||||
|     } | ||||
| 
 | ||||
|  | @ -138,7 +150,7 @@ namespace gtsam { | |||
| 
 | ||||
|     /** choose a branch, create new memory ! */ | ||||
|     NodePtr choose(const L& label, size_t index) const override { | ||||
|       return NodePtr(new Leaf(constant())); | ||||
|       return NodePtr(new Leaf(constant(), nrAssignments())); | ||||
|     } | ||||
| 
 | ||||
|     bool isLeaf() const override { return true; } | ||||
|  | @ -153,6 +165,7 @@ namespace gtsam { | |||
|     void serialize(ARCHIVE& ar, const unsigned int /*version*/) { | ||||
|       ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); | ||||
|       ar& BOOST_SERIALIZATION_NVP(constant_); | ||||
|       ar& BOOST_SERIALIZATION_NVP(nrAssignments_); | ||||
|     } | ||||
| #endif | ||||
|   };  // Leaf
 | ||||
|  | @ -222,8 +235,16 @@ namespace gtsam { | |||
|           assert(f->branches().size() > 0); | ||||
|           NodePtr f0 = f->branches_[0]; | ||||
| 
 | ||||
|           // Compute total number of assignments
 | ||||
|           size_t nrAssignments = 0; | ||||
|           for (auto branch : f->branches()) { | ||||
|             if (auto leaf = std::dynamic_pointer_cast<const Leaf>(branch)) { | ||||
|               nrAssignments += leaf->nrAssignments(); | ||||
|             } | ||||
|           } | ||||
|           NodePtr newLeaf( | ||||
|               new Leaf(std::dynamic_pointer_cast<const Leaf>(f0)->constant())); | ||||
|               new Leaf(std::dynamic_pointer_cast<const Leaf>(f0)->constant(), | ||||
|                        nrAssignments)); | ||||
|           return newLeaf; | ||||
|         } | ||||
| #endif | ||||
|  | @ -709,7 +730,7 @@ namespace gtsam { | |||
|     // If leaf, apply unary conversion "op" and create a unique leaf.
 | ||||
|     using MXLeaf = typename DecisionTree<M, X>::Leaf; | ||||
|     if (auto leaf = std::dynamic_pointer_cast<const MXLeaf>(f)) { | ||||
|       return NodePtr(new Leaf(Y_of_X(leaf->constant()))); | ||||
|       return NodePtr(new Leaf(Y_of_X(leaf->constant()), leaf->nrAssignments())); | ||||
|     } | ||||
| 
 | ||||
|     // Check if Choice
 | ||||
|  | @ -856,6 +877,16 @@ namespace gtsam { | |||
|     return total; | ||||
|   } | ||||
| 
 | ||||
|   /****************************************************************************/ | ||||
|   template <typename L, typename Y> | ||||
|   size_t DecisionTree<L, Y>::nrAssignments() const { | ||||
|     size_t n = 0; | ||||
|     this->visitLeaf([&n](const DecisionTree<L, Y>::Leaf& leaf) { | ||||
|       n += leaf.nrAssignments(); | ||||
|     }); | ||||
|     return n; | ||||
|   } | ||||
| 
 | ||||
|   /****************************************************************************/ | ||||
|   // fold is just done with a visit
 | ||||
|   template <typename L, typename Y> | ||||
|  |  | |||
|  | @ -307,6 +307,42 @@ namespace gtsam { | |||
|     /// Return the number of leaves in the tree.
 | ||||
|     size_t nrLeaves() const; | ||||
| 
 | ||||
|     /**
 | ||||
|      * @brief This is a convenience function which returns the total number of | ||||
|      * leaf assignments in the decision tree. | ||||
|      * This function is not used for anymajor operations within the discrete | ||||
|      * factor graph framework. | ||||
|      * | ||||
|      * Leaf assignments represent the cardinality of each leaf node, e.g. in a | ||||
|      * binary tree each leaf has 2 assignments. This includes counts removed | ||||
|      * from implicit pruning hence, it will always be >= nrLeaves(). | ||||
|      * | ||||
|      * E.g. we have a decision tree as below, where each node has 2 branches: | ||||
|      * | ||||
|      * Choice(m1) | ||||
|      * 0 Choice(m0) | ||||
|      * 0 0 Leaf 0.0 | ||||
|      * 0 1 Leaf 0.0 | ||||
|      * 1 Choice(m0) | ||||
|      * 1 0 Leaf 1.0 | ||||
|      * 1 1 Leaf 2.0 | ||||
|      * | ||||
|      * In the unpruned form, the tree will have 4 assignments, 2 for each key, | ||||
|      * and 4 leaves. | ||||
|      * | ||||
|      * In the pruned form, the number of assignments is still 4 but the number | ||||
|      * of leaves is now 3, as below: | ||||
|      * | ||||
|      * Choice(m1) | ||||
|      * 0 Leaf 0.0 | ||||
|      * 1 Choice(m0) | ||||
|      * 1 0 Leaf 1.0 | ||||
|      * 1 1 Leaf 2.0 | ||||
|      * | ||||
|      * @return size_t | ||||
|      */ | ||||
|     size_t nrAssignments() const; | ||||
| 
 | ||||
|     /**
 | ||||
|      * @brief Fold a binary function over the tree, returning accumulator. | ||||
|      * | ||||
|  |  | |||
|  | @ -328,6 +328,59 @@ TEST(DecisionTree, Containers) { | |||
|   StringContainerTree converted(stringIntTree, container_of_int); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************** */ | ||||
| // Test nrAssignments.
 | ||||
| TEST(DecisionTree, NrAssignments) { | ||||
|   const std::pair<string, size_t> A("A", 2), B("B", 2), C("C", 2); | ||||
|   DT tree({A, B, C}, "1 1 1 1 1 1 1 1"); | ||||
| 
 | ||||
|   EXPECT_LONGS_EQUAL(8, tree.nrAssignments()); | ||||
| 
 | ||||
| #ifdef GTSAM_DT_MERGING | ||||
|   EXPECT(tree.root_->isLeaf()); | ||||
|   auto leaf = std::dynamic_pointer_cast<const DT::Leaf>(tree.root_); | ||||
|   EXPECT_LONGS_EQUAL(8, leaf->nrAssignments()); | ||||
| #endif | ||||
| 
 | ||||
|   DT tree2({C, B, A}, "1 1 1 2 3 4 5 5"); | ||||
|   /* The tree is
 | ||||
|     Choice(C)  | ||||
|     0 Choice(B)  | ||||
|     0 0 Leaf 1 | ||||
|     0 1 Choice(A)  | ||||
|     0 1 0 Leaf 1 | ||||
|     0 1 1 Leaf 2 | ||||
|     1 Choice(B)  | ||||
|     1 0 Choice(A)  | ||||
|     1 0 0 Leaf 3 | ||||
|     1 0 1 Leaf 4 | ||||
|     1 1 Leaf 5 | ||||
|   */ | ||||
| 
 | ||||
|   EXPECT_LONGS_EQUAL(8, tree2.nrAssignments()); | ||||
| 
 | ||||
|   auto root = std::dynamic_pointer_cast<const DT::Choice>(tree2.root_); | ||||
|   CHECK(root); | ||||
|   auto choice0 = std::dynamic_pointer_cast<const DT::Choice>(root->branches()[0]); | ||||
|   CHECK(choice0); | ||||
| 
 | ||||
| #ifdef GTSAM_DT_MERGING | ||||
|   EXPECT(choice0->branches()[0]->isLeaf()); | ||||
|   auto choice00 = std::dynamic_pointer_cast<const DT::Leaf>(choice0->branches()[0]); | ||||
|   CHECK(choice00); | ||||
|   EXPECT_LONGS_EQUAL(2, choice00->nrAssignments()); | ||||
| 
 | ||||
|   auto choice1 = std::dynamic_pointer_cast<const DT::Choice>(root->branches()[1]); | ||||
|   CHECK(choice1); | ||||
|   auto choice10 = std::dynamic_pointer_cast<const DT::Choice>(choice1->branches()[0]); | ||||
|   CHECK(choice10); | ||||
|   auto choice11 = std::dynamic_pointer_cast<const DT::Leaf>(choice1->branches()[1]); | ||||
|   CHECK(choice11); | ||||
|   EXPECT(choice11->isLeaf()); | ||||
|   EXPECT_LONGS_EQUAL(2, choice11->nrAssignments()); | ||||
| #endif | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************** */ | ||||
| // Test visit.
 | ||||
| TEST(DecisionTree, visit) { | ||||
|  | @ -540,6 +593,38 @@ TEST(DecisionTree, ApplyWithAssignment) { | |||
| #endif | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************** */ | ||||
| // Test number of assignments.
 | ||||
| TEST(DecisionTree, NrAssignments2) { | ||||
|   using gtsam::symbol_shorthand::M; | ||||
| 
 | ||||
|   std::vector<double> probs = {0, 0, 1, 2}; | ||||
| 
 | ||||
|   /* Create the decision tree
 | ||||
|     Choice(m1) | ||||
|     0 Leaf 0.000000 | ||||
|     1 Choice(m0) | ||||
|     1 0 Leaf 1.000000 | ||||
|     1 1 Leaf 2.000000 | ||||
|   */ | ||||
|   DiscreteKeys keys{{M(1), 2}, {M(0), 2}}; | ||||
|   DecisionTree<Key, double> dt1(keys, probs); | ||||
|   EXPECT_LONGS_EQUAL(4, dt1.nrAssignments()); | ||||
| 
 | ||||
|   /* Create the DecisionTree
 | ||||
|     Choice(m1) | ||||
|     0 Choice(m0) | ||||
|     0 0 Leaf 0.000000 | ||||
|     0 1 Leaf 1.000000 | ||||
|     1 Choice(m0) | ||||
|     1 0 Leaf 0.000000 | ||||
|     1 1 Leaf 2.000000 | ||||
|   */ | ||||
|   DiscreteKeys keys2{{M(0), 2}, {M(1), 2}}; | ||||
|   DecisionTree<Key, double> dt2(keys2, probs); | ||||
|   EXPECT_LONGS_EQUAL(4, dt2.nrAssignments()); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| int main() { | ||||
|   TestResult tr; | ||||
|  |  | |||
|  | @ -349,6 +349,119 @@ TEST(DiscreteFactorGraph, markdown) { | |||
|   EXPECT_DOUBLES_EQUAL(0.3, graph[0]->operator()(values), 1e-9); | ||||
| } | ||||
| 
 | ||||
| TEST(DiscreteFactorGraph, NrAssignments) { | ||||
| #ifdef GTSAM_DT_MERGING | ||||
|   string expected_dfg = R"( | ||||
| size: 2 | ||||
| factor 0:  f[ (m0,2), (m1,2), (m2,2), ] | ||||
|  Choice(m2)  | ||||
|  0 Choice(m1)  | ||||
|  0 0 Leaf [2]   0 | ||||
|  0 1 Choice(m0)  | ||||
|  0 1 0 Leaf [1]0.27527634 | ||||
|  0 1 1 Leaf [1]   0 | ||||
|  1 Choice(m1)  | ||||
|  1 0 Leaf [2]   0 | ||||
|  1 1 Choice(m0)  | ||||
|  1 1 0 Leaf [1]0.44944733 | ||||
|  1 1 1 Leaf [1]0.27527634 | ||||
| factor 1:  f[ (m0,2), (m1,2), (m2,2), (m3,2), ] | ||||
|  Choice(m3)  | ||||
|  0 Choice(m2)  | ||||
|  0 0 Choice(m1)  | ||||
|  0 0 0 Leaf [2]   1 | ||||
|  0 0 1 Leaf [2]0.015366387 | ||||
|  0 1 Choice(m1)  | ||||
|  0 1 0 Leaf [2]   1 | ||||
|  0 1 1 Choice(m0)  | ||||
|  0 1 1 0 Leaf [1]   1 | ||||
|  0 1 1 1 Leaf [1]0.015365663 | ||||
|  1 Choice(m2)  | ||||
|  1 0 Choice(m1)  | ||||
|  1 0 0 Leaf [2]   1 | ||||
|  1 0 1 Choice(m0)  | ||||
|  1 0 1 0 Leaf [1]0.0094115739 | ||||
|  1 0 1 1 Leaf [1]0.0094115652 | ||||
|  1 1 Choice(m1)  | ||||
|  1 1 0 Leaf [2]   1 | ||||
|  1 1 1 Choice(m0)  | ||||
|  1 1 1 0 Leaf [1]   1 | ||||
|  1 1 1 1 Leaf [1]0.009321081 | ||||
| )"; | ||||
| #else | ||||
|   string expected_dfg = R"( | ||||
| size: 2 | ||||
| factor 0:  f[ (m0,2), (m1,2), (m2,2), ] | ||||
|  Choice(m2)  | ||||
|  0 Choice(m1)  | ||||
|  0 0 Choice(m0)  | ||||
|  0 0 0 Leaf [1]   0 | ||||
|  0 0 1 Leaf [1]   0 | ||||
|  0 1 Choice(m0)  | ||||
|  0 1 0 Leaf [1]0.27527634 | ||||
|  0 1 1 Leaf [1]0.44944733 | ||||
|  1 Choice(m1)  | ||||
|  1 0 Choice(m0)  | ||||
|  1 0 0 Leaf [1]   0 | ||||
|  1 0 1 Leaf [1]   0 | ||||
|  1 1 Choice(m0)  | ||||
|  1 1 0 Leaf [1]   0 | ||||
|  1 1 1 Leaf [1]0.27527634 | ||||
| factor 1:  f[ (m0,2), (m1,2), (m2,2), (m3,2), ] | ||||
|  Choice(m3)  | ||||
|  0 Choice(m2)  | ||||
|  0 0 Choice(m1)  | ||||
|  0 0 0 Choice(m0)  | ||||
|  0 0 0 0 Leaf [1]   1 | ||||
|  0 0 0 1 Leaf [1]   1 | ||||
|  0 0 1 Choice(m0)  | ||||
|  0 0 1 0 Leaf [1]0.015366387 | ||||
|  0 0 1 1 Leaf [1]0.015366387 | ||||
|  0 1 Choice(m1)  | ||||
|  0 1 0 Choice(m0)  | ||||
|  0 1 0 0 Leaf [1]   1 | ||||
|  0 1 0 1 Leaf [1]   1 | ||||
|  0 1 1 Choice(m0)  | ||||
|  0 1 1 0 Leaf [1]   1 | ||||
|  0 1 1 1 Leaf [1]0.015365663 | ||||
|  1 Choice(m2)  | ||||
|  1 0 Choice(m1)  | ||||
|  1 0 0 Choice(m0)  | ||||
|  1 0 0 0 Leaf [1]   1 | ||||
|  1 0 0 1 Leaf [1]   1 | ||||
|  1 0 1 Choice(m0)  | ||||
|  1 0 1 0 Leaf [1]0.0094115739 | ||||
|  1 0 1 1 Leaf [1]0.0094115652 | ||||
|  1 1 Choice(m1)  | ||||
|  1 1 0 Choice(m0)  | ||||
|  1 1 0 0 Leaf [1]   1 | ||||
|  1 1 0 1 Leaf [1]   1 | ||||
|  1 1 1 Choice(m0)  | ||||
|  1 1 1 0 Leaf [1]   1 | ||||
|  1 1 1 1 Leaf [1]0.009321081 | ||||
| )"; | ||||
| #endif | ||||
| 
 | ||||
|   DiscreteKeys d0{{M(0), 2}, {M(1), 2}, {M(2), 2}}; | ||||
|   std::vector<double> p0 = {0, 0, 0.17054468, 0.27845056, 0, 0, 0, 0.17054468}; | ||||
|   AlgebraicDecisionTree<Key> dt(d0, p0); | ||||
|   //TODO(Varun) Passing ADT to DiscreteConditional causes nrAssignments to get messed up
 | ||||
|   // Issue seems to be in DecisionTreeFactor.cpp L104
 | ||||
|   DiscreteConditional f0(3, DecisionTreeFactor(d0, dt)); | ||||
| 
 | ||||
|   DiscreteKeys d1{{M(0), 2}, {M(1), 2}, {M(2), 2}, {M(3), 2}}; | ||||
|   std::vector<double> p1 = { | ||||
|       1, 1, 1, 1, 0.015366387, 0.0094115739, 1,           1, | ||||
|       1, 1, 1, 1, 0.015366387, 0.0094115652, 0.015365663, 0.009321081}; | ||||
|   DecisionTreeFactor f1(d1, p1); | ||||
| 
 | ||||
|   DiscreteFactorGraph dfg; | ||||
|   dfg.add(f0); | ||||
|   dfg.add(f1); | ||||
| 
 | ||||
|   EXPECT(assert_print_equal(expected_dfg, dfg)); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| int main() { | ||||
| TestResult tr; | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue