From ff730a7184fdbb0c93df7f07d09e4f6df13beefe Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Thu, 23 Dec 2021 15:57:55 -0500 Subject: [PATCH] Added conditional markdown formatter --- gtsam/discrete/DiscreteConditional.cpp | 42 ++++++++++++++++- gtsam/discrete/DiscreteConditional.h | 7 +++ .../tests/testDiscreteConditional.cpp | 47 +++++++++++++++++-- 3 files changed, 92 insertions(+), 4 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 293b69748..2a891feb0 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -222,6 +222,46 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const { return distribution(rng); } -/* ******************************************************************************** */ +/* ************************************************************************* */ +std::string DiscreteConditional::_repr_markdown_( + const KeyFormatter& keyFormatter) const { + std::stringstream ss; + + // Print out header and construct argument for `cartesianProduct`. + // TODO(dellaert): examine why we can't use "for (auto key: frontals())" + std::vector> pairs; + ss << "|"; + const_iterator it; + for (it = beginParents(); it != endParents(); ++it) { + auto key = *it; + ss << keyFormatter(key) << "|"; + pairs.emplace_back(key, cardinalities_.at(key)); + } + for (it = beginFrontals(); it != endFrontals(); ++it) { + auto key = *it; + ss << keyFormatter(key) << "|"; + pairs.emplace_back(key, cardinalities_.at(key)); + } + ss << "value|\n"; + + // Print out separator with alignment hints. + ss << "|"; + for (size_t j = 0; j < size(); j++) ss << ":-:|"; + ss << ":-:|\n"; + + // Print out all rows. + std::vector> rpairs(pairs.rbegin(), pairs.rend()); + const auto assignments = cartesianProduct(rpairs); + for (const auto& a : assignments) { + ss << "|"; + for (it = beginParents(); it != endParents(); ++it) ss << a.at(*it) << "|"; + for (it = beginFrontals(); it != endFrontals(); ++it) + ss << "*" << a.at(*it) << "*|"; + ss << operator()(a) << "|\n"; + } + return ss.str(); +} +/* ******************************************************************************** + */ }// namespace diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 06928e2e7..ad21151a8 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -167,7 +167,14 @@ public: void sampleInPlace(DiscreteValues* parentsValues) const; /// @} + /// @name Wrapper support + /// @{ + /// Render as markdown table. + std::string _repr_markdown_( + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + + /// @} }; // DiscreteConditional diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index 79714217c..d031882c1 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -101,9 +101,50 @@ TEST(DiscreteConditional, Combine) { c.push_back(boost::make_shared(A | B = "1/2 2/1")); c.push_back(boost::make_shared(B % "1/2")); DecisionTreeFactor factor(A & B, "0.111111 0.444444 0.222222 0.222222"); - DiscreteConditional actual(2, factor); - auto expected = DiscreteConditional::Combine(c.begin(), c.end()); - EXPECT(assert_equal(*expected, actual, 1e-5)); + DiscreteConditional expected(2, factor); + auto actual = DiscreteConditional::Combine(c.begin(), c.end()); + EXPECT(assert_equal(expected, *actual, 1e-5)); +} + +/* ************************************************************************* */ +// TEST(DiscreteConditional, Combine2) { +// DiscreteKey A(0, 3), B(1, 2), C(2, 2); +// vector c; +// auto P = {B, C}; +// c.push_back(boost::make_shared(A, P, "1/2 2/1 1/2 2/1")); +// c.push_back(boost::make_shared(B | C = "1/2")); +// auto actual = DiscreteConditional::Combine(c.begin(), c.end()); +// GTSAM_PRINT(*actual); +// } + +/* ************************************************************************* */ +// Check markdown representation looks as expected. +TEST(DiscreteConditional, markdown) { + DiscreteKey A(2, 2), B(1, 2), C(0, 3); + DiscreteConditional conditional(A, {B, C}, "0/1 1/3 1/1 3/1 0/1 1/0"); + EXPECT_LONGS_EQUAL(A.first, *(conditional.beginFrontals())); + EXPECT_LONGS_EQUAL(B.first, *(conditional.beginParents())); + EXPECT(conditional.endParents() == conditional.end()); + EXPECT(conditional.endFrontals() == conditional.beginParents()); + string expected = + "|B|C|A|value|\n" + "|:-:|:-:|:-:|:-:|\n" + "|0|0|*0*|0|\n" + "|0|0|*1*|1|\n" + "|0|1|*0*|0.25|\n" + "|0|1|*1*|0.75|\n" + "|0|2|*0*|0.5|\n" + "|0|2|*1*|0.5|\n" + "|1|0|*0*|0.75|\n" + "|1|0|*1*|0.25|\n" + "|1|1|*0*|0|\n" + "|1|1|*1*|1|\n" + "|1|2|*0*|1|\n" + "|1|2|*1*|0|\n"; + vector names{"C", "B", "A"}; + auto formatter = [names](Key key) { return names[key]; }; + string actual = conditional._repr_markdown_(formatter); + EXPECT(actual == expected); } /* ************************************************************************* */