From 35502f3f32cfcf63c0a19ee12e2b5c4f7e567b32 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 1 Jan 2025 20:10:01 -0500 Subject: [PATCH] custom max-product for HybridBayesTree --- gtsam/hybrid/HybridBayesTree.cpp | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/gtsam/hybrid/HybridBayesTree.cpp b/gtsam/hybrid/HybridBayesTree.cpp index ce2ddda81..bcd6f48c4 100644 --- a/gtsam/hybrid/HybridBayesTree.cpp +++ b/gtsam/hybrid/HybridBayesTree.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -41,6 +42,25 @@ bool HybridBayesTree::equals(const This& other, double tol) const { return Base::equals(other, tol); } +/* ************************************************************************* */ +DiscreteValues HybridBayesTree::discreteMaxProduct( + const DiscreteFactorGraph& dfg) const { + TableFactor product = TableProduct(dfg); + + uint64_t maxIdx = 0; + double maxValue = 0.0; + Eigen::SparseVector sparseTable = product.sparseTable(); + for (TableFactor::SparseIt it(sparseTable); it; ++it) { + if (it.value() > maxValue) { + maxIdx = it.index(); + maxValue = it.value(); + } + } + + DiscreteValues assignment = product.findAssignments(maxIdx); + return assignment; +} + /* ************************************************************************* */ HybridValues HybridBayesTree::optimize() const { DiscreteFactorGraph discrete_fg; @@ -52,8 +72,10 @@ HybridValues HybridBayesTree::optimize() const { // The root should be discrete only, we compute the MPE if (root_conditional->isDiscrete()) { - discrete_fg.push_back(root_conditional->asDiscrete()); - mpe = discrete_fg.optimize(); + auto discrete = std::dynamic_pointer_cast( + root_conditional->asDiscrete()); + discrete_fg.push_back(discrete); + mpe = discreteMaxProduct(discrete_fg); } else { throw std::runtime_error( "HybridBayesTree root is not discrete-only. Please check elimination "