diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index ad4cbad43..7bd9e9b7f 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -67,7 +67,7 @@ namespace gtsam { void DecisionTreeFactor::print(const string& s, const KeyFormatter& formatter) const { cout << s; - ADT::print("Potentials:",formatter); + ADT::print("", formatter); } /* ************************************************************************* */ @@ -163,6 +163,18 @@ namespace gtsam { return result; } + /* ************************************************************************* */ + DiscreteKeys DecisionTreeFactor::discreteKeys() const { + DiscreteKeys result; + for (auto&& key : keys()) { + DiscreteKey dkey(key, cardinality(key)); + if (std::find(result.begin(), result.end(), dkey) == result.end()) { + result.push_back(dkey); + } + } + return result; + } + /* ************************************************************************* */ static std::string valueFormatter(const double& v) { return (boost::format("%4.2g") % v).str(); diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 8beeb4c4a..0bfdf6b90 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -183,6 +183,9 @@ namespace gtsam { /// Enumerate all values into a map from values to double. std::vector> enumerate() const; + /// Return all the discrete keys associated with this factor. + DiscreteKeys discreteKeys() const; + /// @} /// @name Wrapper support /// @{