diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 9eb9bde55..7fa97051a 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -20,7 +20,6 @@ #include #include #include -#include #include #include @@ -53,7 +52,15 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { // Multiply into one big conditional. NOTE: possibly quite expensive. DiscreteConditional joint; for (auto &&conditional : marginal) { - joint = joint * (*conditional); + // The last discrete conditional may be a DiscreteTableConditional + if (auto dtc = + std::dynamic_pointer_cast(conditional)) { + DiscreteConditional dc(dtc->nrFrontals(), + dtc->table().toDecisionTreeFactor()); + joint = joint * dc; + } else { + joint = joint * (*conditional); + } } // Prune the joint. NOTE: again, possibly quite expensive. @@ -127,7 +134,14 @@ HybridValues HybridBayesNet::optimize() const { for (auto &&conditional : *this) { if (conditional->isDiscrete()) { - discrete_fg.push_back(conditional->asDiscrete()); + if (auto dtc = conditional->asDiscrete()) { + // The number of keys should be small so should not + // be expensive to convert to DiscreteConditional. + discrete_fg.push_back(DiscreteConditional( + dtc->nrFrontals(), dtc->table().toDecisionTreeFactor())); + } else { + discrete_fg.push_back(conditional->asDiscrete()); + } } } diff --git a/gtsam/hybrid/HybridBayesTree.cpp b/gtsam/hybrid/HybridBayesTree.cpp index bcd6f48c4..55a9c7e88 100644 --- a/gtsam/hybrid/HybridBayesTree.cpp +++ b/gtsam/hybrid/HybridBayesTree.cpp @@ -201,7 +201,8 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const { /* ************************************************************************* */ void HybridBayesTree::prune(const size_t maxNrLeaves) { - auto discreteProbs = this->roots_.at(0)->conditional()->asDiscrete(); + auto discreteProbs = + this->roots_.at(0)->conditional()->asDiscrete(); DiscreteConditional::shared_ptr prunedDiscreteProbs = discreteProbs->prune(maxNrLeaves);