diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index b7b9d7034..9816aa3fa 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -134,5 +134,38 @@ namespace gtsam { return boost::make_shared(dkeys, result); } -/* ************************************************************************* */ + /* ************************************************************************* */ + std::string DecisionTreeFactor::_repr_markdown_() const { + std::stringstream ss; + + // Print out header and calculate number of rows. + ss << "|"; + size_t m = 1; // number of rows + for (auto& key : cardinalities_) { + size_t k = key.second; + m *= k; + ss << key.first << "(" << k << ")|"; + } + ss << "value|\n"; + + // Print out separator with alignment hints. + size_t n = cardinalities_.size(); + ss << "|"; + for (size_t j = 0; j < n; j++) ss << ":-:|"; + ss << ":-:|\n"; + + // Print out all rows. + std::vector> keys(cardinalities_.begin(), + cardinalities_.end()); + const auto assignments = cartesianProduct(keys); + for (auto &&assignment : assignments) { + ss << "|"; + for (auto& kv : assignment) ss << kv.second << "|"; + const double value = operator()(assignment); + ss << value << "|\n"; + } + return ss.str(); + } + + /* ************************************************************************* */ } // namespace gtsam diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index aa718e35d..68d629e9c 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -163,6 +163,14 @@ namespace gtsam { // } /// @} + /// @name Wrapper support + /// @{ + + /// Render as markdown table. + std::string _repr_markdown_() const; + + /// @} + }; // DecisionTreeFactor diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 9c53b3b70..75c56b0dd 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -39,6 +39,7 @@ virtual class DecisionTreeFactor: gtsam::DiscreteFactor { gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const; string dot(bool showZero = false) const; + string _repr_markdown_() const; }; #include @@ -65,6 +66,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { size_t sample(const gtsam::DiscreteValues& parentsValues) const; void solveInPlace(gtsam::DiscreteValues@ parentsValues) const; void sampleInPlace(gtsam::DiscreteValues@ parentsValues) const; + string _repr_markdown_() const; }; #include diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index bcab70bd9..3d73b4481 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -30,8 +30,10 @@ using namespace gtsam; /* ************************************************************************* */ TEST( DecisionTreeFactor, constructors) { + // Declare a bunch of keys DiscreteKey X(0,2), Y(1,3), Z(2,2); + // Create factors DecisionTreeFactor f1(X, "2 8"); DecisionTreeFactor f2(X & Y, "2 5 3 6 4 7"); DecisionTreeFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75"); @@ -39,10 +41,6 @@ TEST( DecisionTreeFactor, constructors) EXPECT_LONGS_EQUAL(2,f2.size()); EXPECT_LONGS_EQUAL(3,f3.size()); - // f1.print("f1:"); - // f2.print("f2:"); - // f3.print("f3:"); - DiscreteValues values; values[0] = 1; // x values[1] = 2; // y @@ -55,37 +53,26 @@ TEST( DecisionTreeFactor, constructors) /* ************************************************************************* */ TEST_UNSAFE( DecisionTreeFactor, multiplication) { - // Declare a bunch of keys DiscreteKey v0(0,2), v1(1,2), v2(2,2); - // Create a factor DecisionTreeFactor f1(v0 & v1, "1 2 3 4"); DecisionTreeFactor f2(v1 & v2, "5 6 7 8"); -// f1.print("f1:"); -// f2.print("f2:"); DecisionTreeFactor expected(v0 & v1 & v2, "5 6 14 16 15 18 28 32"); DecisionTreeFactor actual = f1 * f2; -// actual.print("actual: "); CHECK(assert_equal(expected, actual)); } /* ************************************************************************* */ TEST( DecisionTreeFactor, sum_max) { - // Declare a bunch of keys DiscreteKey v0(0,3), v1(1,2); - - // Create a factor DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6"); DecisionTreeFactor expected(v1, "9 12"); DecisionTreeFactor::shared_ptr actual = f1.sum(1); CHECK(assert_equal(expected, *actual, 1e-5)); -// f1.print("f1:"); -// actual->print("actual: "); -// actual->printCache("actual cache: "); DecisionTreeFactor expected2(v1, "5 6"); DecisionTreeFactor::shared_ptr actual2 = f1.max(1); @@ -93,11 +80,26 @@ TEST( DecisionTreeFactor, sum_max) DecisionTreeFactor f2(v1 & v0, "1 2 3 4 5 6"); DecisionTreeFactor::shared_ptr actual22 = f2.sum(1); -// f2.print("f2: "); -// actual22->print("actual22: "); - } +/* ************************************************************************* */ +TEST( DecisionTreeFactor, markdown) +{ + DiscreteKey v0(0,3), v1(1,2); + DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6"); + string expected = + "|0(3)|1(2)|value|\n" + "|:-:|:-:|:-:|\n" + "|0|0|1|\n" + "|1|0|3|\n" + "|2|0|5|\n" + "|0|1|2|\n" + "|1|1|4|\n" + "|2|1|6|\n"; + string actual = f1._repr_markdown_(); + EXPECT(actual == expected); + } + /* ************************************************************************* */ int main() { TestResult tr;