diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 2037dd951..68892b1a4 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -64,10 +64,18 @@ namespace gtsam { } /* ************************************************************************ */ - DecisionTreeFactor DiscreteFactorGraph::product() const { - DecisionTreeFactor result; + TableFactor DiscreteFactorGraph::product() const { + TableFactor result; for (const sharedFactor& factor : *this) { - if (factor) result = (*factor) * result; + if (factor) { + if (auto f = std::dynamic_pointer_cast(factor)) { + result = result * (*f); + } + else if (auto dtf = + std::dynamic_pointer_cast(factor)) { + result = TableFactor(result * (*dtf)); + } + } } return result; } @@ -116,15 +124,14 @@ namespace gtsam { * product to prevent underflow. * * @param factors The factors to multiply as a DiscreteFactorGraph. - * @return DecisionTreeFactor + * @return TableFactor */ - static DecisionTreeFactor ProductAndNormalize( - const DiscreteFactorGraph& factors) { + static TableFactor ProductAndNormalize(const DiscreteFactorGraph& factors) { // PRODUCT: multiply all factors #if GTSAM_HYBRID_TIMING gttic_(DiscreteProduct); #endif - DecisionTreeFactor product = factors.product(); + TableFactor product = factors.product(); #if GTSAM_HYBRID_TIMING gttoc_(DiscreteProduct); #endif @@ -149,11 +156,11 @@ namespace gtsam { std::pair // EliminateForMPE(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { - DecisionTreeFactor product = ProductAndNormalize(factors); + TableFactor product = ProductAndNormalize(factors); // max out frontals, this is the factor on the separator gttic(max); - DecisionTreeFactor::shared_ptr max = product.max(frontalKeys); + TableFactor::shared_ptr max = product.max(frontalKeys); gttoc(max); // Ordering keys for the conditional so that frontalKeys are really in front @@ -166,8 +173,8 @@ namespace gtsam { // Make lookup with product gttic(lookup); size_t nrFrontals = frontalKeys.size(); - auto lookup = - std::make_shared(nrFrontals, orderedKeys, product); + auto lookup = std::make_shared( + nrFrontals, orderedKeys, product.toDecisionTreeFactor()); gttoc(lookup); return {std::dynamic_pointer_cast(lookup), max}; @@ -227,13 +234,13 @@ namespace gtsam { std::pair // EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { - DecisionTreeFactor product = ProductAndNormalize(factors); + TableFactor product = ProductAndNormalize(factors); // sum out frontals, this is the factor on the separator #if GTSAM_HYBRID_TIMING gttic_(EliminateDiscreteSum); #endif - DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys); + TableFactor::shared_ptr sum = product.sum(frontalKeys); #if GTSAM_HYBRID_TIMING gttoc_(EliminateDiscreteSum); #endif @@ -246,11 +253,18 @@ namespace gtsam { sum->keys().end()); // now divide product/sum to get conditional +#if GTSAM_HYBRID_TIMING + gttic_(EliminateDiscreteDivide); +#endif + auto c = product / (*sum); +#if GTSAM_HYBRID_TIMING + gttoc_(EliminateDiscreteDivide); +#endif #if GTSAM_HYBRID_TIMING gttic_(EliminateDiscreteToDiscreteConditional); #endif - auto conditional = - std::make_shared(product, *sum, orderedKeys); + auto conditional = std::make_shared( + orderedKeys.size(), c.toDecisionTreeFactor()); #if GTSAM_HYBRID_TIMING gttoc_(EliminateDiscreteToDiscreteConditional); #endif diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index c57d2258c..f1575cd7e 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -147,7 +148,7 @@ class GTSAM_EXPORT DiscreteFactorGraph DiscreteKeys discreteKeys() const; /** return product of all factors as a single factor */ - DecisionTreeFactor product() const; + TableFactor product() const; /** * Evaluates the factor graph given values, returns the joint probability of