Added conditional markdown formatter

release/4.3a0
Frank Dellaert 2021-12-23 15:57:55 -05:00
parent c5e6650d67
commit ff730a7184
3 changed files with 92 additions and 4 deletions

View File

@ -222,6 +222,46 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
return distribution(rng);
}
/* ******************************************************************************** */
/* ************************************************************************* */
std::string DiscreteConditional::_repr_markdown_(
const KeyFormatter& keyFormatter) const {
std::stringstream ss;
// Print out header and construct argument for `cartesianProduct`.
// TODO(dellaert): examine why we can't use "for (auto key: frontals())"
std::vector<std::pair<Key, size_t>> pairs;
ss << "|";
const_iterator it;
for (it = beginParents(); it != endParents(); ++it) {
auto key = *it;
ss << keyFormatter(key) << "|";
pairs.emplace_back(key, cardinalities_.at(key));
}
for (it = beginFrontals(); it != endFrontals(); ++it) {
auto key = *it;
ss << keyFormatter(key) << "|";
pairs.emplace_back(key, cardinalities_.at(key));
}
ss << "value|\n";
// Print out separator with alignment hints.
ss << "|";
for (size_t j = 0; j < size(); j++) ss << ":-:|";
ss << ":-:|\n";
// Print out all rows.
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
const auto assignments = cartesianProduct(rpairs);
for (const auto& a : assignments) {
ss << "|";
for (it = beginParents(); it != endParents(); ++it) ss << a.at(*it) << "|";
for (it = beginFrontals(); it != endFrontals(); ++it)
ss << "*" << a.at(*it) << "*|";
ss << operator()(a) << "|\n";
}
return ss.str();
}
/* ********************************************************************************
*/
}// namespace

View File

@ -167,7 +167,14 @@ public:
void sampleInPlace(DiscreteValues* parentsValues) const;
/// @}
/// @name Wrapper support
/// @{
/// Render as markdown table.
std::string _repr_markdown_(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// @}
};
// DiscreteConditional

View File

@ -101,9 +101,50 @@ TEST(DiscreteConditional, Combine) {
c.push_back(boost::make_shared<DiscreteConditional>(A | B = "1/2 2/1"));
c.push_back(boost::make_shared<DiscreteConditional>(B % "1/2"));
DecisionTreeFactor factor(A & B, "0.111111 0.444444 0.222222 0.222222");
DiscreteConditional actual(2, factor);
auto expected = DiscreteConditional::Combine(c.begin(), c.end());
EXPECT(assert_equal(*expected, actual, 1e-5));
DiscreteConditional expected(2, factor);
auto actual = DiscreteConditional::Combine(c.begin(), c.end());
EXPECT(assert_equal(expected, *actual, 1e-5));
}
/* ************************************************************************* */
// TEST(DiscreteConditional, Combine2) {
// DiscreteKey A(0, 3), B(1, 2), C(2, 2);
// vector<DiscreteConditional::shared_ptr> c;
// auto P = {B, C};
// c.push_back(boost::make_shared<DiscreteConditional>(A, P, "1/2 2/1 1/2 2/1"));
// c.push_back(boost::make_shared<DiscreteConditional>(B | C = "1/2"));
// auto actual = DiscreteConditional::Combine(c.begin(), c.end());
// GTSAM_PRINT(*actual);
// }
/* ************************************************************************* */
// Check markdown representation looks as expected.
TEST(DiscreteConditional, markdown) {
DiscreteKey A(2, 2), B(1, 2), C(0, 3);
DiscreteConditional conditional(A, {B, C}, "0/1 1/3 1/1 3/1 0/1 1/0");
EXPECT_LONGS_EQUAL(A.first, *(conditional.beginFrontals()));
EXPECT_LONGS_EQUAL(B.first, *(conditional.beginParents()));
EXPECT(conditional.endParents() == conditional.end());
EXPECT(conditional.endFrontals() == conditional.beginParents());
string expected =
"|B|C|A|value|\n"
"|:-:|:-:|:-:|:-:|\n"
"|0|0|*0*|0|\n"
"|0|0|*1*|1|\n"
"|0|1|*0*|0.25|\n"
"|0|1|*1*|0.75|\n"
"|0|2|*0*|0.5|\n"
"|0|2|*1*|0.5|\n"
"|1|0|*0*|0.75|\n"
"|1|0|*1*|0.25|\n"
"|1|1|*0*|0|\n"
"|1|1|*1*|1|\n"
"|1|2|*0*|1|\n"
"|1|2|*1*|0|\n";
vector<string> names{"C", "B", "A"};
auto formatter = [names](Key key) { return names[key]; };
string actual = conditional._repr_markdown_(formatter);
EXPECT(actual == expected);
}
/* ************************************************************************* */