use TableFactor instead of DecisionTreeFactor in discrete elimination

release/4.3a0
Varun Agrawal 2024-12-27 14:29:16 -05:00
parent 02d461e359
commit 34fba6823a
2 changed files with 31 additions and 16 deletions

View File

@ -64,10 +64,18 @@ namespace gtsam {
} }
/* ************************************************************************ */ /* ************************************************************************ */
DecisionTreeFactor DiscreteFactorGraph::product() const { TableFactor DiscreteFactorGraph::product() const {
DecisionTreeFactor result; TableFactor result;
for (const sharedFactor& factor : *this) { for (const sharedFactor& factor : *this) {
if (factor) result = (*factor) * result; if (factor) {
if (auto f = std::dynamic_pointer_cast<TableFactor>(factor)) {
result = result * (*f);
}
else if (auto dtf =
std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
result = TableFactor(result * (*dtf));
}
}
} }
return result; return result;
} }
@ -116,15 +124,14 @@ namespace gtsam {
* product to prevent underflow. * product to prevent underflow.
* *
* @param factors The factors to multiply as a DiscreteFactorGraph. * @param factors The factors to multiply as a DiscreteFactorGraph.
* @return DecisionTreeFactor * @return TableFactor
*/ */
static DecisionTreeFactor ProductAndNormalize( static TableFactor ProductAndNormalize(const DiscreteFactorGraph& factors) {
const DiscreteFactorGraph& factors) {
// PRODUCT: multiply all factors // PRODUCT: multiply all factors
#if GTSAM_HYBRID_TIMING #if GTSAM_HYBRID_TIMING
gttic_(DiscreteProduct); gttic_(DiscreteProduct);
#endif #endif
DecisionTreeFactor product = factors.product(); TableFactor product = factors.product();
#if GTSAM_HYBRID_TIMING #if GTSAM_HYBRID_TIMING
gttoc_(DiscreteProduct); gttoc_(DiscreteProduct);
#endif #endif
@ -149,11 +156,11 @@ namespace gtsam {
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> // std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
EliminateForMPE(const DiscreteFactorGraph& factors, EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) { const Ordering& frontalKeys) {
DecisionTreeFactor product = ProductAndNormalize(factors); TableFactor product = ProductAndNormalize(factors);
// max out frontals, this is the factor on the separator // max out frontals, this is the factor on the separator
gttic(max); gttic(max);
DecisionTreeFactor::shared_ptr max = product.max(frontalKeys); TableFactor::shared_ptr max = product.max(frontalKeys);
gttoc(max); gttoc(max);
// Ordering keys for the conditional so that frontalKeys are really in front // Ordering keys for the conditional so that frontalKeys are really in front
@ -166,8 +173,8 @@ namespace gtsam {
// Make lookup with product // Make lookup with product
gttic(lookup); gttic(lookup);
size_t nrFrontals = frontalKeys.size(); size_t nrFrontals = frontalKeys.size();
auto lookup = auto lookup = std::make_shared<DiscreteLookupTable>(
std::make_shared<DiscreteLookupTable>(nrFrontals, orderedKeys, product); nrFrontals, orderedKeys, product.toDecisionTreeFactor());
gttoc(lookup); gttoc(lookup);
return {std::dynamic_pointer_cast<DiscreteConditional>(lookup), max}; return {std::dynamic_pointer_cast<DiscreteConditional>(lookup), max};
@ -227,13 +234,13 @@ namespace gtsam {
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> // std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
EliminateDiscrete(const DiscreteFactorGraph& factors, EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) { const Ordering& frontalKeys) {
DecisionTreeFactor product = ProductAndNormalize(factors); TableFactor product = ProductAndNormalize(factors);
// sum out frontals, this is the factor on the separator // sum out frontals, this is the factor on the separator
#if GTSAM_HYBRID_TIMING #if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteSum); gttic_(EliminateDiscreteSum);
#endif #endif
DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys); TableFactor::shared_ptr sum = product.sum(frontalKeys);
#if GTSAM_HYBRID_TIMING #if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteSum); gttoc_(EliminateDiscreteSum);
#endif #endif
@ -246,11 +253,18 @@ namespace gtsam {
sum->keys().end()); sum->keys().end());
// now divide product/sum to get conditional // now divide product/sum to get conditional
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteDivide);
#endif
auto c = product / (*sum);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteDivide);
#endif
#if GTSAM_HYBRID_TIMING #if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteToDiscreteConditional); gttic_(EliminateDiscreteToDiscreteConditional);
#endif #endif
auto conditional = auto conditional = std::make_shared<DiscreteConditional>(
std::make_shared<DiscreteConditional>(product, *sum, orderedKeys); orderedKeys.size(), c.toDecisionTreeFactor());
#if GTSAM_HYBRID_TIMING #if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteToDiscreteConditional); gttoc_(EliminateDiscreteToDiscreteConditional);
#endif #endif

View File

@ -25,6 +25,7 @@
#include <gtsam/inference/FactorGraph.h> #include <gtsam/inference/FactorGraph.h>
#include <gtsam/inference/Ordering.h> #include <gtsam/inference/Ordering.h>
#include <gtsam/base/FastSet.h> #include <gtsam/base/FastSet.h>
#include <gtsam/discrete/TableFactor.h>
#include <string> #include <string>
#include <utility> #include <utility>
@ -147,7 +148,7 @@ class GTSAM_EXPORT DiscreteFactorGraph
DiscreteKeys discreteKeys() const; DiscreteKeys discreteKeys() const;
/** return product of all factors as a single factor */ /** return product of all factors as a single factor */
DecisionTreeFactor product() const; TableFactor product() const;
/** /**
* Evaluates the factor graph given values, returns the joint probability of * Evaluates the factor graph given values, returns the joint probability of