custom max-product for HybridBayesTree

release/4.3a0
Varun Agrawal 2025-01-01 20:10:01 -05:00
parent cafac6317e
commit 35502f3f32
1 changed files with 24 additions and 2 deletions

View File

@ -20,6 +20,7 @@
#include <gtsam/base/treeTraversal-inst.h>
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteTableConditional.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/hybrid/HybridConditional.h>
@ -41,6 +42,25 @@ bool HybridBayesTree::equals(const This& other, double tol) const {
return Base::equals(other, tol);
}
/* ************************************************************************* */
DiscreteValues HybridBayesTree::discreteMaxProduct(
const DiscreteFactorGraph& dfg) const {
TableFactor product = TableProduct(dfg);
uint64_t maxIdx = 0;
double maxValue = 0.0;
Eigen::SparseVector<double> sparseTable = product.sparseTable();
for (TableFactor::SparseIt it(sparseTable); it; ++it) {
if (it.value() > maxValue) {
maxIdx = it.index();
maxValue = it.value();
}
}
DiscreteValues assignment = product.findAssignments(maxIdx);
return assignment;
}
/* ************************************************************************* */
HybridValues HybridBayesTree::optimize() const {
DiscreteFactorGraph discrete_fg;
@ -52,8 +72,10 @@ HybridValues HybridBayesTree::optimize() const {
// The root should be discrete only, we compute the MPE
if (root_conditional->isDiscrete()) {
discrete_fg.push_back(root_conditional->asDiscrete());
mpe = discrete_fg.optimize();
auto discrete = std::dynamic_pointer_cast<DiscreteTableConditional>(
root_conditional->asDiscrete());
discrete_fg.push_back(discrete);
mpe = discreteMaxProduct(discrete_fg);
} else {
throw std::runtime_error(
"HybridBayesTree root is not discrete-only. Please check elimination "