From 5e9020237f8a821a641080911adebeab8da84426 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 18 Mar 2022 00:11:05 -0400 Subject: [PATCH] add test for visitWith with pruned tree --- gtsam/discrete/tests/testDecisionTree.cpp | 46 ++++++++++++++++++++--- 1 file changed, 41 insertions(+), 5 deletions(-) diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index 967023eeb..ca5daf201 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -20,12 +20,10 @@ // #define DT_DEBUG_MEMORY // #define DT_NO_PRUNING #define DISABLE_DOT -#include - -#include -#include - #include +#include +#include +#include #include using namespace boost::assign; @@ -346,6 +344,44 @@ TEST(DecisionTree, visitWith) { EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); } +/* ************************************************************************** */ +// Test visit, with Choices argument. +TEST(DecisionTree, VisitWithPruned) { + // Create pruned tree + std::pair A("A", 2), B("B", 2), C("C", 2); + std::vector> labels = {C, B, A}; + std::vector nodes = {0, 0, 2, 3, 4, 4, 6, 7}; + DT tree(labels, nodes); + + std::vector> choices; + auto func = [&](const Assignment& choice, const int& d) { + choices.push_back(choice); + }; + tree.visitWith(func); + + EXPECT_LONGS_EQUAL(6, choices.size()); + + Assignment expectedAssignment; + + expectedAssignment = {{"B", 0}, {"C", 0}}; + EXPECT(expectedAssignment == choices.at(0)); + + expectedAssignment = {{"A", 0}, {"B", 1}, {"C", 0}}; + EXPECT(expectedAssignment == choices.at(1)); + + expectedAssignment = {{"A", 1}, {"B", 1}, {"C", 0}}; + EXPECT(expectedAssignment == choices.at(2)); + + expectedAssignment = {{"B", 0}, {"C", 1}}; + EXPECT(expectedAssignment == choices.at(3)); + + expectedAssignment = {{"A", 0}, {"B", 1}, {"C", 1}}; + EXPECT(expectedAssignment == choices.at(4)); + + expectedAssignment = {{"A", 1}, {"B", 1}, {"C", 1}}; + EXPECT(expectedAssignment == choices.at(5)); +} + /* ************************************************************************** */ // Test fold. TEST(DecisionTree, fold) {