use TableFactor instead of DecisionTreeFactor in discrete elimination
parent
02d461e359
commit
34fba6823a
|
@ -64,10 +64,18 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DecisionTreeFactor DiscreteFactorGraph::product() const {
|
||||
DecisionTreeFactor result;
|
||||
TableFactor DiscreteFactorGraph::product() const {
|
||||
TableFactor result;
|
||||
for (const sharedFactor& factor : *this) {
|
||||
if (factor) result = (*factor) * result;
|
||||
if (factor) {
|
||||
if (auto f = std::dynamic_pointer_cast<TableFactor>(factor)) {
|
||||
result = result * (*f);
|
||||
}
|
||||
else if (auto dtf =
|
||||
std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
|
||||
result = TableFactor(result * (*dtf));
|
||||
}
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
@ -116,15 +124,14 @@ namespace gtsam {
|
|||
* product to prevent underflow.
|
||||
*
|
||||
* @param factors The factors to multiply as a DiscreteFactorGraph.
|
||||
* @return DecisionTreeFactor
|
||||
* @return TableFactor
|
||||
*/
|
||||
static DecisionTreeFactor ProductAndNormalize(
|
||||
const DiscreteFactorGraph& factors) {
|
||||
static TableFactor ProductAndNormalize(const DiscreteFactorGraph& factors) {
|
||||
// PRODUCT: multiply all factors
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(DiscreteProduct);
|
||||
#endif
|
||||
DecisionTreeFactor product = factors.product();
|
||||
TableFactor product = factors.product();
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(DiscreteProduct);
|
||||
#endif
|
||||
|
@ -149,11 +156,11 @@ namespace gtsam {
|
|||
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
|
||||
EliminateForMPE(const DiscreteFactorGraph& factors,
|
||||
const Ordering& frontalKeys) {
|
||||
DecisionTreeFactor product = ProductAndNormalize(factors);
|
||||
TableFactor product = ProductAndNormalize(factors);
|
||||
|
||||
// max out frontals, this is the factor on the separator
|
||||
gttic(max);
|
||||
DecisionTreeFactor::shared_ptr max = product.max(frontalKeys);
|
||||
TableFactor::shared_ptr max = product.max(frontalKeys);
|
||||
gttoc(max);
|
||||
|
||||
// Ordering keys for the conditional so that frontalKeys are really in front
|
||||
|
@ -166,8 +173,8 @@ namespace gtsam {
|
|||
// Make lookup with product
|
||||
gttic(lookup);
|
||||
size_t nrFrontals = frontalKeys.size();
|
||||
auto lookup =
|
||||
std::make_shared<DiscreteLookupTable>(nrFrontals, orderedKeys, product);
|
||||
auto lookup = std::make_shared<DiscreteLookupTable>(
|
||||
nrFrontals, orderedKeys, product.toDecisionTreeFactor());
|
||||
gttoc(lookup);
|
||||
|
||||
return {std::dynamic_pointer_cast<DiscreteConditional>(lookup), max};
|
||||
|
@ -227,13 +234,13 @@ namespace gtsam {
|
|||
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
|
||||
EliminateDiscrete(const DiscreteFactorGraph& factors,
|
||||
const Ordering& frontalKeys) {
|
||||
DecisionTreeFactor product = ProductAndNormalize(factors);
|
||||
TableFactor product = ProductAndNormalize(factors);
|
||||
|
||||
// sum out frontals, this is the factor on the separator
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(EliminateDiscreteSum);
|
||||
#endif
|
||||
DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys);
|
||||
TableFactor::shared_ptr sum = product.sum(frontalKeys);
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(EliminateDiscreteSum);
|
||||
#endif
|
||||
|
@ -246,11 +253,18 @@ namespace gtsam {
|
|||
sum->keys().end());
|
||||
|
||||
// now divide product/sum to get conditional
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(EliminateDiscreteDivide);
|
||||
#endif
|
||||
auto c = product / (*sum);
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(EliminateDiscreteDivide);
|
||||
#endif
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(EliminateDiscreteToDiscreteConditional);
|
||||
#endif
|
||||
auto conditional =
|
||||
std::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
|
||||
auto conditional = std::make_shared<DiscreteConditional>(
|
||||
orderedKeys.size(), c.toDecisionTreeFactor());
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(EliminateDiscreteToDiscreteConditional);
|
||||
#endif
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include <gtsam/inference/FactorGraph.h>
|
||||
#include <gtsam/inference/Ordering.h>
|
||||
#include <gtsam/base/FastSet.h>
|
||||
#include <gtsam/discrete/TableFactor.h>
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
@ -147,7 +148,7 @@ class GTSAM_EXPORT DiscreteFactorGraph
|
|||
DiscreteKeys discreteKeys() const;
|
||||
|
||||
/** return product of all factors as a single factor */
|
||||
DecisionTreeFactor product() const;
|
||||
TableFactor product() const;
|
||||
|
||||
/**
|
||||
* Evaluates the factor graph given values, returns the joint probability of
|
||||
|
|
Loading…
Reference in New Issue