diff --git a/CMakeLists.txt b/CMakeLists.txt index 74019da44..21d8d1b60 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,7 +11,7 @@ endif() set (GTSAM_VERSION_MAJOR 4) set (GTSAM_VERSION_MINOR 2) set (GTSAM_VERSION_PATCH 0) -set (GTSAM_PRERELEASE_VERSION "a0") +set (GTSAM_PRERELEASE_VERSION "a1") math (EXPR GTSAM_VERSION_NUMERIC "10000 * ${GTSAM_VERSION_MAJOR} + 100 * ${GTSAM_VERSION_MINOR} + ${GTSAM_VERSION_PATCH}") if (${GTSAM_VERSION_PATCH} EQUAL 0) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 3c74e57fd..ab14b2a72 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -261,9 +261,8 @@ namespace gtsam { // Check if zero if (!showZero) { - const Leaf* leaf = dynamic_cast (branch.get()); - std::string value = valueFormatter(leaf->constant()); - if (leaf && value.compare("0")) continue; + const Leaf* leaf = dynamic_cast(branch.get()); + if (leaf && valueFormatter(leaf->constant()).compare("0")) continue; } os << "\"" << this->id() << "\" -> \"" << branch->id() << "\""; diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index b4f95780d..8ee93eb77 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -307,7 +307,7 @@ std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter, if (nrParents() == 0) { // We have no parents, call factor method. ss << ")*:\n" << std::endl; - ss << DecisionTreeFactor::markdown(keyFormatter); + ss << DecisionTreeFactor::markdown(keyFormatter, names); return ss.str(); } diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index a83732883..d17401e44 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -48,7 +48,7 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor { bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const; string dot( const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, - bool showZero = false) const; + bool showZero = true) const; std::vector> enumerate() const; string markdown(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index c4e5f06bb..716a77127 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -100,6 +100,20 @@ TEST(DecisionTreeFactor, enumerate) { EXPECT(actual == expected); } +/* ************************************************************************* */ +TEST(DiscreteFactorGraph, DotWithNames) { + DiscreteKey A(12, 3), B(5, 2); + DecisionTreeFactor f(A & B, "1 2 3 4 5 6"); + auto formatter = [](Key key) { return key == 12 ? "A" : "B"; }; + + for (bool showZero:{true, false}) { + string actual = f.dot(formatter, showZero); + // pretty weak test, as ids are pointers and not stable across platforms. + string expected = "digraph G {"; + EXPECT(actual.substr(0, 11) == expected); + } +} + /* ************************************************************************* */ // Check markdown representation looks as expected. TEST(DecisionTreeFactor, markdown) { diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index b498b0541..6d2af3cff 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -135,6 +135,24 @@ TEST(DiscreteConditional, markdown_prior) { EXPECT(actual == expected); } +/* ************************************************************************* */ +// Check markdown representation looks as expected, no parents + names. +TEST(DiscreteConditional, markdown_prior_names) { + Symbol x1('x', 1); + DiscreteKey A(x1, 3); + DiscreteConditional conditional(A % "1/2/2"); + string expected = + " *P(x1)*:\n\n" + "|x1|value|\n" + "|:-:|:-:|\n" + "|A0|0.2|\n" + "|A1|0.4|\n" + "|A2|0.4|\n"; + DecisionTreeFactor::Names names{{x1, {"A0", "A1", "A2"}}}; + string actual = conditional.markdown(DefaultKeyFormatter, names); + EXPECT(actual == expected); +} + /* ************************************************************************* */ // Check markdown representation looks as expected, multivalued. TEST(DiscreteConditional, markdown_multivalued) { @@ -155,7 +173,7 @@ TEST(DiscreteConditional, markdown_multivalued) { } /* ************************************************************************* */ -// Check markdown representation looks as expected, two parents. +// Check markdown representation looks as expected, two parents + names. TEST(DiscreteConditional, markdown) { DiscreteKey A(2, 2), B(1, 2), C(0, 3); DiscreteConditional conditional(A, {B, C}, "0/1 1/3 1/1 3/1 0/1 1/0"); diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index b6172382a..ef9efbe02 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -382,6 +382,31 @@ TEST(DiscreteFactorGraph, Dot) { EXPECT(actual == expected); } +/* ************************************************************************* */ +TEST(DiscreteFactorGraph, DotWithNames) { + // Create Factor graph + DiscreteFactorGraph graph; + DiscreteKey C(0, 2), A(1, 2), B(2, 2); + graph.add(C & A, "0.2 0.8 0.3 0.7"); + graph.add(C & B, "0.1 0.9 0.4 0.6"); + + vector names{"C", "A", "B"}; + auto formatter = [names](Key key) { return names[key]; }; + string actual = graph.dot(formatter); + string expected = + "graph {\n" + " size=\"5,5\";\n" + "\n" + " var0[label=\"C\"];\n" + " var1[label=\"A\"];\n" + " var2[label=\"B\"];\n" + "\n" + " var0--var1;\n" + " var0--var2;\n" + "}\n"; + EXPECT(actual == expected); +} + /* ************************************************************************* */ // Check markdown representation looks as expected. TEST(DiscreteFactorGraph, markdown) {