diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 439889ebf..f6a64f11f 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -248,8 +248,9 @@ namespace gtsam { void dot(std::ostream& os, bool showZero) const override { os << "\"" << this->id() << "\" [shape=circle, label=\"" << label_ << "\"]\n"; - for (size_t i = 0; i < branches_.size(); i++) { - NodePtr branch = branches_[i]; + size_t B = branches_.size(); + for (size_t i = 0; i < B; i++) { + const NodePtr& branch = branches_[i]; // Check if zero if (!showZero) { @@ -258,8 +259,10 @@ namespace gtsam { } os << "\"" << this->id() << "\" -> \"" << branch->id() << "\""; - if (i == 0) os << " [style=dashed]"; - if (i > 1) os << " [style=bold]"; + if (B == 2) { + if (i == 0) os << " [style=dashed]"; + if (i > 1) os << " [style=bold]"; + } os << std::endl; branch->dot(os, showZero); } @@ -671,7 +674,14 @@ namespace gtsam { int result = system( ("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null").c_str()); if (result==-1) throw std::runtime_error("DecisionTree::dot system call failed"); -} + } + + template + std::string DecisionTree::dot(bool showZero) const { + std::stringstream ss; + dot(ss, showZero); + return ss.str(); + } /*********************************************************************************/ diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 0ee0b8be0..1f76f4ca3 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -198,6 +198,9 @@ namespace gtsam { /** output to graphviz format, open a file */ void dot(const std::string& name, bool showZero = true) const; + /** output to graphviz format string */ + std::string dot(bool showZero = true) const; + /// @name Advanced Interface /// @{ diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index daea84e70..a883226cc 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -39,6 +39,7 @@ virtual class DecisionTreeFactor: gtsam::DiscreteFactor { gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const; double operator()(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat??? + string dot(bool showZero = false) const; }; #include @@ -67,6 +68,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { size_t sample(const gtsam::DiscreteValues& parentsValues) const; void solveInPlace(gtsam::DiscreteValues@ parentsValues) const; void sampleInPlace(gtsam::DiscreteValues@ parentsValues) const; + string dot(bool showZero = false) const; }; #include