diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 688cd85a6..f8f2835e5 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -169,6 +169,9 @@ namespace gtsam { } } + /// Convert into a decision tree + DecisionTreeFactor toDecisionTreeFactor() const override { return *this; } + /// Create new factor by summing all values with the same separator values DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override { return combine(nrFrontals, ADT::Ring::add); diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 4c1d0afb1..e2d32e828 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -113,6 +113,8 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { virtual DiscreteFactor::shared_ptr operator*( const DiscreteFactor::shared_ptr&) const = 0; + virtual DecisionTreeFactor toDecisionTreeFactor() const = 0; + /// Create new factor by summing all values with the same separator values virtual DiscreteFactor::shared_ptr sum(size_t nrFrontals) const = 0; diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 727c96ce4..50d15ff5e 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -180,6 +180,20 @@ DiscreteFactor::shared_ptr TableFactor::operator*( } } +/* ************************************************************************ */ +DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { + DiscreteKeys dkeys = discreteKeys(); + std::vector table; + for (auto i = 0; i < sparse_table_.size(); i++) { + table.push_back(sparse_table_.coeff(i)); + } + gttic_(toDecisionTreeFactor_Constructor); + // NOTE(Varun): This constructor is really expensive!! + DecisionTreeFactor f(dkeys, table); + gttoc_(toDecisionTreeFactor_Constructor); + return f; +} + /* ************************************************************************ */ TableFactor TableFactor::choose(const DiscreteValues parent_assign, DiscreteKeys parent_keys) const { diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 497c42dc2..47a7c6bbb 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -207,6 +207,9 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { } } + /// Convert into a decisiontree + DecisionTreeFactor toDecisionTreeFactor() const override; + /// Create a TableFactor that is a subset of this TableFactor TableFactor choose(const DiscreteValues assignments, DiscreteKeys parent_keys) const;