diff --git a/gtsam/hybrid/HybridBayesTree.cpp b/gtsam/hybrid/HybridBayesTree.cpp index ce2ddda81..bcd6f48c4 100644 --- a/gtsam/hybrid/HybridBayesTree.cpp +++ b/gtsam/hybrid/HybridBayesTree.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -41,6 +42,25 @@ bool HybridBayesTree::equals(const This& other, double tol) const { return Base::equals(other, tol); } +/* ************************************************************************* */ +DiscreteValues HybridBayesTree::discreteMaxProduct( + const DiscreteFactorGraph& dfg) const { + TableFactor product = TableProduct(dfg); + + uint64_t maxIdx = 0; + double maxValue = 0.0; + Eigen::SparseVector sparseTable = product.sparseTable(); + for (TableFactor::SparseIt it(sparseTable); it; ++it) { + if (it.value() > maxValue) { + maxIdx = it.index(); + maxValue = it.value(); + } + } + + DiscreteValues assignment = product.findAssignments(maxIdx); + return assignment; +} + /* ************************************************************************* */ HybridValues HybridBayesTree::optimize() const { DiscreteFactorGraph discrete_fg; @@ -52,8 +72,10 @@ HybridValues HybridBayesTree::optimize() const { // The root should be discrete only, we compute the MPE if (root_conditional->isDiscrete()) { - discrete_fg.push_back(root_conditional->asDiscrete()); - mpe = discrete_fg.optimize(); + auto discrete = std::dynamic_pointer_cast( + root_conditional->asDiscrete()); + discrete_fg.push_back(discrete); + mpe = discreteMaxProduct(discrete_fg); } else { throw std::runtime_error( "HybridBayesTree root is not discrete-only. Please check elimination "