WIP for debugging nrAssignments issue
							parent
							
								
									0cd36db4d9
								
							
						
					
					
						commit
						73b563a9aa
					
				|  | @ -93,7 +93,7 @@ namespace gtsam { | |||
|     /// print
 | ||||
|     void print(const std::string& s, const LabelFormatter& labelFormatter, | ||||
|                const ValueFormatter& valueFormatter) const override { | ||||
|       std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl; | ||||
|       std::cout << s << " Leaf " << valueFormatter(constant_) << " | nrAssignments: " << nrAssignments_ << std::endl; | ||||
|     } | ||||
| 
 | ||||
|     /** Write graphviz format to stream `os`. */ | ||||
|  | @ -207,9 +207,9 @@ namespace gtsam { | |||
| 
 | ||||
|         size_t nrAssignments = 0; | ||||
|         for(auto branch: f->branches()) { | ||||
|           assert(branch->isLeaf()); | ||||
|           nrAssignments += | ||||
|               std::dynamic_pointer_cast<const Leaf>(branch)->nrAssignments(); | ||||
|           if (auto leaf = std::dynamic_pointer_cast<const Leaf>(branch)) { | ||||
|             nrAssignments += leaf->nrAssignments(); | ||||
|           } | ||||
|         } | ||||
|         NodePtr newLeaf( | ||||
|             new Leaf(std::dynamic_pointer_cast<const Leaf>(f0)->constant(), | ||||
|  | @ -217,9 +217,35 @@ namespace gtsam { | |||
|         return newLeaf; | ||||
|       } else | ||||
| #endif | ||||
|           // {
 | ||||
|           //   Choice choice_node;
 | ||||
| 
 | ||||
|           //   for (auto branch : f->branches()) {
 | ||||
|           //     if (auto choice = std::dynamic_pointer_cast<const
 | ||||
|           //     Choice>(branch)) {
 | ||||
|           //       // `branch` is a Choice node so we apply Unique to it.
 | ||||
|           //       choice_node.push_back(Unique(choice));
 | ||||
| 
 | ||||
|           //     } else if (auto leaf =
 | ||||
|           //                    std::dynamic_pointer_cast<const Leaf>(branch)) {
 | ||||
|           //       choice_node.push_back(leaf);
 | ||||
|           //     }
 | ||||
|           //   }
 | ||||
|           //   return std::make_shared<const Choice>(choice_node);
 | ||||
|           // }
 | ||||
|         return f; | ||||
|     } | ||||
| 
 | ||||
|     static NodePtr UpdateNrAssignments(const NodePtr& f) { | ||||
|       if (auto choice = std::dynamic_pointer_cast<const Choice>(f)) { | ||||
|         // `f` is a Choice node so we recurse.
 | ||||
|         return UpdateNrAssignments(f); | ||||
| 
 | ||||
|       } else if (auto leaf = std::dynamic_pointer_cast<const Leaf>(f)) { | ||||
|          | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|     bool isLeaf() const override { return false; } | ||||
| 
 | ||||
|     /// Constructor, given choice label and mandatory expected branch count.
 | ||||
|  | @ -282,7 +308,7 @@ namespace gtsam { | |||
|     void print(const std::string& s, const LabelFormatter& labelFormatter, | ||||
|                const ValueFormatter& valueFormatter) const override { | ||||
|       std::cout << s << " Choice("; | ||||
|       std::cout << labelFormatter(label_) << ") " << std::endl; | ||||
|       std::cout << labelFormatter(label_) << ") " << " | All Same: " << allSame_ << " | nrBranches: " << branches_.size() << std::endl; | ||||
|       for (size_t i = 0; i < branches_.size(); i++) { | ||||
|         branches_[i]->print(s + " " + std::to_string(i), labelFormatter, valueFormatter); | ||||
|       } | ||||
|  | @ -569,16 +595,16 @@ namespace gtsam { | |||
|     // find highest label among branches
 | ||||
|     std::optional<L> highestLabel; | ||||
|     size_t nrChoices = 0; | ||||
|     for (Iterator it = begin; it != end; it++) { | ||||
|       if (it->root_->isLeaf()) | ||||
|         continue; | ||||
|       std::shared_ptr<const Choice> c = | ||||
|           std::dynamic_pointer_cast<const Choice>(it->root_); | ||||
|       if (!highestLabel || c->label() > *highestLabel) { | ||||
|         highestLabel = c->label(); | ||||
|         nrChoices = c->nrChoices(); | ||||
|       } | ||||
|     } | ||||
|     // for (Iterator it = begin; it != end; it++) {
 | ||||
|     //   if (it->root_->isLeaf())
 | ||||
|     //     continue;
 | ||||
|     //   std::shared_ptr<const Choice> c =
 | ||||
|     //       std::dynamic_pointer_cast<const Choice>(it->root_);
 | ||||
|     //   if (!highestLabel || c->label() > *highestLabel) {
 | ||||
|     //     highestLabel = c->label();
 | ||||
|     //     nrChoices = c->nrChoices();
 | ||||
|     //   }
 | ||||
|     // }
 | ||||
| 
 | ||||
|     // if label is already in correct order, just put together a choice on label
 | ||||
|     if (!nrChoices || !highestLabel || label > *highestLabel) { | ||||
|  | @ -604,6 +630,7 @@ namespace gtsam { | |||
|         NodePtr fi = compose(functions.begin(), functions.end(), label); | ||||
|         choiceOnHighestLabel->push_back(fi); | ||||
|       } | ||||
|       // return Choice::ComputeNrAssignments(Choice::Unique(choiceOnHighestLabel));
 | ||||
|       return Choice::Unique(choiceOnHighestLabel); | ||||
|     } | ||||
|   } | ||||
|  |  | |||
|  | @ -121,7 +121,7 @@ struct Ring { | |||
| 
 | ||||
| /* ************************************************************************** */ | ||||
| // test DT
 | ||||
| TEST(DecisionTree, example) { | ||||
| TEST_DISABLED(DecisionTree, example) { | ||||
|   // Create labels
 | ||||
|   string A("A"), B("B"), C("C"); | ||||
| 
 | ||||
|  | @ -231,7 +231,7 @@ TEST(DecisionTree, example) { | |||
| bool bool_of_int(const int& y) { return y != 0; }; | ||||
| typedef DecisionTree<string, bool> StringBoolTree; | ||||
| 
 | ||||
| TEST(DecisionTree, ConvertValuesOnly) { | ||||
| TEST_DISABLED(DecisionTree, ConvertValuesOnly) { | ||||
|   // Create labels
 | ||||
|   string A("A"), B("B"); | ||||
| 
 | ||||
|  | @ -252,7 +252,7 @@ TEST(DecisionTree, ConvertValuesOnly) { | |||
| enum Label { U, V, X, Y, Z }; | ||||
| typedef DecisionTree<Label, bool> LabelBoolTree; | ||||
| 
 | ||||
| TEST(DecisionTree, ConvertBoth) { | ||||
| TEST_DISABLED(DecisionTree, ConvertBoth) { | ||||
|   // Create labels
 | ||||
|   string A("A"), B("B"); | ||||
| 
 | ||||
|  | @ -279,7 +279,7 @@ TEST(DecisionTree, ConvertBoth) { | |||
| 
 | ||||
| /* ************************************************************************** */ | ||||
| // test Compose expansion
 | ||||
| TEST(DecisionTree, Compose) { | ||||
| TEST_DISABLED(DecisionTree, Compose) { | ||||
|   // Create labels
 | ||||
|   string A("A"), B("B"), C("C"); | ||||
| 
 | ||||
|  | @ -305,7 +305,7 @@ TEST(DecisionTree, Compose) { | |||
| 
 | ||||
| /* ************************************************************************** */ | ||||
| // Check we can create a decision tree of containers.
 | ||||
| TEST(DecisionTree, Containers) { | ||||
| TEST_DISABLED(DecisionTree, Containers) { | ||||
|   using Container = std::vector<double>; | ||||
|   using StringContainerTree = DecisionTree<string, Container>; | ||||
| 
 | ||||
|  | @ -327,7 +327,7 @@ TEST(DecisionTree, Containers) { | |||
| 
 | ||||
| /* ************************************************************************** */ | ||||
| // Test nrAssignments.
 | ||||
| TEST(DecisionTree, NrAssignments) { | ||||
| TEST_DISABLED(DecisionTree, NrAssignments) { | ||||
|   const std::pair<string, size_t> A("A", 2), B("B", 2), C("C", 2); | ||||
|   DT tree({A, B, C}, "1 1 1 1 1 1 1 1"); | ||||
| 
 | ||||
|  | @ -375,7 +375,7 @@ TEST(DecisionTree, NrAssignments) { | |||
| 
 | ||||
| /* ************************************************************************** */ | ||||
| // Test visit.
 | ||||
| TEST(DecisionTree, visit) { | ||||
| TEST_DISABLED(DecisionTree, visit) { | ||||
|   // Create small two-level tree
 | ||||
|   string A("A"), B("B"); | ||||
|   DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); | ||||
|  | @ -387,7 +387,7 @@ TEST(DecisionTree, visit) { | |||
| 
 | ||||
| /* ************************************************************************** */ | ||||
| // Test visit, with Choices argument.
 | ||||
| TEST(DecisionTree, visitWith) { | ||||
| TEST_DISABLED(DecisionTree, visitWith) { | ||||
|   // Create small two-level tree
 | ||||
|   string A("A"), B("B"); | ||||
|   DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); | ||||
|  | @ -399,7 +399,7 @@ TEST(DecisionTree, visitWith) { | |||
| 
 | ||||
| /* ************************************************************************** */ | ||||
| // Test visit, with Choices argument.
 | ||||
| TEST(DecisionTree, VisitWithPruned) { | ||||
| TEST_DISABLED(DecisionTree, VisitWithPruned) { | ||||
|   // Create pruned tree
 | ||||
|   std::pair<string, size_t> A("A", 2), B("B", 2), C("C", 2); | ||||
|   std::vector<std::pair<string, size_t>> labels = {C, B, A}; | ||||
|  | @ -437,7 +437,7 @@ TEST(DecisionTree, VisitWithPruned) { | |||
| 
 | ||||
| /* ************************************************************************** */ | ||||
| // Test fold.
 | ||||
| TEST(DecisionTree, fold) { | ||||
| TEST_DISABLED(DecisionTree, fold) { | ||||
|   // Create small two-level tree
 | ||||
|   string A("A"), B("B"); | ||||
|   DT tree(B, DT(A, 1, 1), DT(A, 2, 3)); | ||||
|  | @ -448,7 +448,7 @@ TEST(DecisionTree, fold) { | |||
| 
 | ||||
| /* ************************************************************************** */ | ||||
| // Test retrieving all labels.
 | ||||
| TEST(DecisionTree, labels) { | ||||
| TEST_DISABLED(DecisionTree, labels) { | ||||
|   // Create small two-level tree
 | ||||
|   string A("A"), B("B"); | ||||
|   DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); | ||||
|  | @ -458,7 +458,7 @@ TEST(DecisionTree, labels) { | |||
| 
 | ||||
| /* ************************************************************************** */ | ||||
| // Test unzip method.
 | ||||
| TEST(DecisionTree, unzip) { | ||||
| TEST_DISABLED(DecisionTree, unzip) { | ||||
|   using DTP = DecisionTree<string, std::pair<int, string>>; | ||||
|   using DT1 = DecisionTree<string, int>; | ||||
|   using DT2 = DecisionTree<string, string>; | ||||
|  | @ -479,7 +479,7 @@ TEST(DecisionTree, unzip) { | |||
| 
 | ||||
| /* ************************************************************************** */ | ||||
| // Test thresholding.
 | ||||
| TEST(DecisionTree, threshold) { | ||||
| TEST_DISABLED(DecisionTree, threshold) { | ||||
|   // Create three level tree
 | ||||
|   const vector<DT::LabelC> keys{DT::LabelC("C", 2), DT::LabelC("B", 2), | ||||
|                                 DT::LabelC("A", 2)}; | ||||
|  | @ -524,6 +524,8 @@ TEST(DecisionTree, ApplyWithAssignment) { | |||
|   DT prunedTree = tree.apply(pruner); | ||||
| 
 | ||||
|   DT expectedTree(keys, "0 0 0 0 5 6 7 8"); | ||||
|   // expectedTree.print();
 | ||||
|   // prunedTree.print();
 | ||||
|   EXPECT(assert_equal(expectedTree, prunedTree)); | ||||
| 
 | ||||
|   size_t count = 0; | ||||
|  | @ -542,16 +544,27 @@ TEST(DecisionTree, ApplyWithAssignment) { | |||
| TEST(DecisionTree, NrAssignments2) { | ||||
|   using gtsam::symbol_shorthand::M; | ||||
| 
 | ||||
|   DiscreteKeys keys{{M(1), 2}, {M(0), 2}}; | ||||
|   std::vector<double> probs = {0, 0, 1, 2}; | ||||
|   DecisionTree<Key, double> dt1(keys, probs); | ||||
|    | ||||
|   EXPECT_LONGS_EQUAL(4, dt1.nrAssignments()); | ||||
| 
 | ||||
|   DiscreteKeys keys{{M(1), 2}, {M(0), 2}}; | ||||
|   DecisionTree<Key, double> dt1(keys, probs); | ||||
|   EXPECT_LONGS_EQUAL(4, dt1.nrAssignments()); | ||||
|   dt1.print("", DefaultKeyFormatter, [](double x) { return std::to_string(x);}); | ||||
| 
 | ||||
|   /* Create the DecisionTree
 | ||||
|     Choice(m1) | ||||
|     0 Choice(m0) | ||||
|     0 0 Leaf 0.000000 | ||||
|     0 1 Leaf 1.000000 | ||||
|     1 Choice(m0) | ||||
|     1 0 Leaf 0.000000 | ||||
|     1 1 Leaf 2.000000 | ||||
|   */ | ||||
|   DiscreteKeys keys2{{M(0), 2}, {M(1), 2}}; | ||||
|   DecisionTree<Key, double> dt2(keys2, probs); | ||||
|   //TODO(Varun) The below is failing, because the number of assignments aren't being set correctly.
 | ||||
|   EXPECT_LONGS_EQUAL(4, dt2.nrAssignments()); | ||||
|   std::cout << "\n\n" << std::endl; | ||||
|   dt2.print("", DefaultKeyFormatter, [](double x) { return std::to_string(x);}); | ||||
|   // EXPECT_LONGS_EQUAL(4, dt2.nrAssignments());
 | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue