Create markdown representation in DTFactor
parent
fb3f00d656
commit
a27437690c
|
@ -134,5 +134,38 @@ namespace gtsam {
|
||||||
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
|
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
std::string DecisionTreeFactor::_repr_markdown_() const {
|
||||||
|
std::stringstream ss;
|
||||||
|
|
||||||
|
// Print out header and calculate number of rows.
|
||||||
|
ss << "|";
|
||||||
|
size_t m = 1; // number of rows
|
||||||
|
for (auto& key : cardinalities_) {
|
||||||
|
size_t k = key.second;
|
||||||
|
m *= k;
|
||||||
|
ss << key.first << "(" << k << ")|";
|
||||||
|
}
|
||||||
|
ss << "value|\n";
|
||||||
|
|
||||||
|
// Print out separator with alignment hints.
|
||||||
|
size_t n = cardinalities_.size();
|
||||||
|
ss << "|";
|
||||||
|
for (size_t j = 0; j < n; j++) ss << ":-:|";
|
||||||
|
ss << ":-:|\n";
|
||||||
|
|
||||||
|
// Print out all rows.
|
||||||
|
std::vector<std::pair<Key, size_t>> keys(cardinalities_.begin(),
|
||||||
|
cardinalities_.end());
|
||||||
|
const auto assignments = cartesianProduct(keys);
|
||||||
|
for (auto &&assignment : assignments) {
|
||||||
|
ss << "|";
|
||||||
|
for (auto& kv : assignment) ss << kv.second << "|";
|
||||||
|
const double value = operator()(assignment);
|
||||||
|
ss << value << "|\n";
|
||||||
|
}
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -163,6 +163,14 @@ namespace gtsam {
|
||||||
// }
|
// }
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
/// @name Wrapper support
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/// Render as markdown table.
|
||||||
|
std::string _repr_markdown_() const;
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
|
||||||
};
|
};
|
||||||
// DecisionTreeFactor
|
// DecisionTreeFactor
|
||||||
|
|
||||||
|
|
|
@ -39,6 +39,7 @@ virtual class DecisionTreeFactor: gtsam::DiscreteFactor {
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
|
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
|
||||||
string dot(bool showZero = false) const;
|
string dot(bool showZero = false) const;
|
||||||
|
string _repr_markdown_() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteConditional.h>
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
|
@ -65,6 +66,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
||||||
size_t sample(const gtsam::DiscreteValues& parentsValues) const;
|
size_t sample(const gtsam::DiscreteValues& parentsValues) const;
|
||||||
void solveInPlace(gtsam::DiscreteValues@ parentsValues) const;
|
void solveInPlace(gtsam::DiscreteValues@ parentsValues) const;
|
||||||
void sampleInPlace(gtsam::DiscreteValues@ parentsValues) const;
|
void sampleInPlace(gtsam::DiscreteValues@ parentsValues) const;
|
||||||
|
string _repr_markdown_() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||||
|
|
|
@ -30,8 +30,10 @@ using namespace gtsam;
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( DecisionTreeFactor, constructors)
|
TEST( DecisionTreeFactor, constructors)
|
||||||
{
|
{
|
||||||
|
// Declare a bunch of keys
|
||||||
DiscreteKey X(0,2), Y(1,3), Z(2,2);
|
DiscreteKey X(0,2), Y(1,3), Z(2,2);
|
||||||
|
|
||||||
|
// Create factors
|
||||||
DecisionTreeFactor f1(X, "2 8");
|
DecisionTreeFactor f1(X, "2 8");
|
||||||
DecisionTreeFactor f2(X & Y, "2 5 3 6 4 7");
|
DecisionTreeFactor f2(X & Y, "2 5 3 6 4 7");
|
||||||
DecisionTreeFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");
|
DecisionTreeFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");
|
||||||
|
@ -39,10 +41,6 @@ TEST( DecisionTreeFactor, constructors)
|
||||||
EXPECT_LONGS_EQUAL(2,f2.size());
|
EXPECT_LONGS_EQUAL(2,f2.size());
|
||||||
EXPECT_LONGS_EQUAL(3,f3.size());
|
EXPECT_LONGS_EQUAL(3,f3.size());
|
||||||
|
|
||||||
// f1.print("f1:");
|
|
||||||
// f2.print("f2:");
|
|
||||||
// f3.print("f3:");
|
|
||||||
|
|
||||||
DiscreteValues values;
|
DiscreteValues values;
|
||||||
values[0] = 1; // x
|
values[0] = 1; // x
|
||||||
values[1] = 2; // y
|
values[1] = 2; // y
|
||||||
|
@ -55,37 +53,26 @@ TEST( DecisionTreeFactor, constructors)
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST_UNSAFE( DecisionTreeFactor, multiplication)
|
TEST_UNSAFE( DecisionTreeFactor, multiplication)
|
||||||
{
|
{
|
||||||
// Declare a bunch of keys
|
|
||||||
DiscreteKey v0(0,2), v1(1,2), v2(2,2);
|
DiscreteKey v0(0,2), v1(1,2), v2(2,2);
|
||||||
|
|
||||||
// Create a factor
|
|
||||||
DecisionTreeFactor f1(v0 & v1, "1 2 3 4");
|
DecisionTreeFactor f1(v0 & v1, "1 2 3 4");
|
||||||
DecisionTreeFactor f2(v1 & v2, "5 6 7 8");
|
DecisionTreeFactor f2(v1 & v2, "5 6 7 8");
|
||||||
// f1.print("f1:");
|
|
||||||
// f2.print("f2:");
|
|
||||||
|
|
||||||
DecisionTreeFactor expected(v0 & v1 & v2, "5 6 14 16 15 18 28 32");
|
DecisionTreeFactor expected(v0 & v1 & v2, "5 6 14 16 15 18 28 32");
|
||||||
|
|
||||||
DecisionTreeFactor actual = f1 * f2;
|
DecisionTreeFactor actual = f1 * f2;
|
||||||
// actual.print("actual: ");
|
|
||||||
CHECK(assert_equal(expected, actual));
|
CHECK(assert_equal(expected, actual));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( DecisionTreeFactor, sum_max)
|
TEST( DecisionTreeFactor, sum_max)
|
||||||
{
|
{
|
||||||
// Declare a bunch of keys
|
|
||||||
DiscreteKey v0(0,3), v1(1,2);
|
DiscreteKey v0(0,3), v1(1,2);
|
||||||
|
|
||||||
// Create a factor
|
|
||||||
DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6");
|
DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6");
|
||||||
|
|
||||||
DecisionTreeFactor expected(v1, "9 12");
|
DecisionTreeFactor expected(v1, "9 12");
|
||||||
DecisionTreeFactor::shared_ptr actual = f1.sum(1);
|
DecisionTreeFactor::shared_ptr actual = f1.sum(1);
|
||||||
CHECK(assert_equal(expected, *actual, 1e-5));
|
CHECK(assert_equal(expected, *actual, 1e-5));
|
||||||
// f1.print("f1:");
|
|
||||||
// actual->print("actual: ");
|
|
||||||
// actual->printCache("actual cache: ");
|
|
||||||
|
|
||||||
DecisionTreeFactor expected2(v1, "5 6");
|
DecisionTreeFactor expected2(v1, "5 6");
|
||||||
DecisionTreeFactor::shared_ptr actual2 = f1.max(1);
|
DecisionTreeFactor::shared_ptr actual2 = f1.max(1);
|
||||||
|
@ -93,11 +80,26 @@ TEST( DecisionTreeFactor, sum_max)
|
||||||
|
|
||||||
DecisionTreeFactor f2(v1 & v0, "1 2 3 4 5 6");
|
DecisionTreeFactor f2(v1 & v0, "1 2 3 4 5 6");
|
||||||
DecisionTreeFactor::shared_ptr actual22 = f2.sum(1);
|
DecisionTreeFactor::shared_ptr actual22 = f2.sum(1);
|
||||||
// f2.print("f2: ");
|
|
||||||
// actual22->print("actual22: ");
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST( DecisionTreeFactor, markdown)
|
||||||
|
{
|
||||||
|
DiscreteKey v0(0,3), v1(1,2);
|
||||||
|
DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6");
|
||||||
|
string expected =
|
||||||
|
"|0(3)|1(2)|value|\n"
|
||||||
|
"|:-:|:-:|:-:|\n"
|
||||||
|
"|0|0|1|\n"
|
||||||
|
"|1|0|3|\n"
|
||||||
|
"|2|0|5|\n"
|
||||||
|
"|0|1|2|\n"
|
||||||
|
"|1|1|4|\n"
|
||||||
|
"|2|1|6|\n";
|
||||||
|
string actual = f1._repr_markdown_();
|
||||||
|
EXPECT(actual == expected);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
Loading…
Reference in New Issue