gtsam/gtsam/discrete/DiscreteFactorGraph.cpp

290 lines
10 KiB
C++
Raw Normal View History

2014-01-10 05:38:47 +08:00
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
 * @file DiscreteFactorGraph.cpp
 * @date Feb 14, 2011
 * @author Duy-Nguyen Ta
 * @author Frank Dellaert
 * @author Varun Agrawal
 */
#include <gtsam/discrete/DiscreteBayesTree.h>
#include <gtsam/discrete/DiscreteConditional.h>
2014-01-10 05:38:47 +08:00
#include <gtsam/discrete/DiscreteEliminationTree.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
2014-01-10 05:38:47 +08:00
#include <gtsam/discrete/DiscreteJunctionTree.h>
2022-01-22 02:18:28 +08:00
#include <gtsam/discrete/DiscreteLookupDAG.h>
2014-01-10 05:38:47 +08:00
#include <gtsam/inference/EliminateableFactorGraph-inst.h>
#include <gtsam/inference/FactorGraph-inst.h>
using std::vector;
using std::string;
using std::map;
2014-01-10 05:38:47 +08:00
namespace gtsam {
// Instantiate base classes: explicit instantiation of the templated graph
// bases for the discrete factor / discrete factor graph types.
template class FactorGraph<DiscreteFactor>;
template class EliminateableFactorGraph<DiscreteFactorGraph>;
2024-12-11 00:05:50 +08:00
/* ************************************************************************ */
bool DiscreteFactorGraph::equals(const This& fg, double tol) const {
2014-01-10 05:38:47 +08:00
return Base::equals(fg, tol);
}
2024-12-11 00:05:50 +08:00
/* ************************************************************************ */
/// Collect the keys of every non-null factor in the graph into one KeySet.
KeySet DiscreteFactorGraph::keys() const {
  KeySet allKeys;
  for (const sharedFactor& f : *this) {
    // Empty slots contribute no keys.
    if (!f) continue;
    allKeys.insert(f->begin(), f->end());
  }
  return allKeys;
}
2024-12-11 00:05:50 +08:00
/* ************************************************************************ */
2022-01-23 00:06:06 +08:00
/// Concatenate the discrete keys of every non-null factor, in graph order.
/// Keys shared by several factors appear once per factor (no de-duplication).
DiscreteKeys DiscreteFactorGraph::discreteKeys() const {
  DiscreteKeys result;
  for (const sharedFactor& factor : *this) {
    // The graph already stores DiscreteFactor pointers, so a plain null check
    // is all that is needed here (no dynamic cast, no extra shared_ptr copy).
    if (!factor) continue;
    const DiscreteKeys factorKeys = factor->discreteKeys();
    result.insert(result.end(), factorKeys.begin(), factorKeys.end());
  }
  return result;
}
2024-12-11 00:05:50 +08:00
/* ************************************************************************ */
2014-01-10 05:38:47 +08:00
/// Multiply all non-null factors together, starting from a
/// default-constructed DecisionTreeFactor.
DecisionTreeFactor DiscreteFactorGraph::product() const {
  DecisionTreeFactor accumulated;
  for (const sharedFactor& f : *this) {
    if (!f) continue;
    // Keep the operand order: the new factor is the left-hand side.
    accumulated = (*f) * accumulated;
  }
  return accumulated;
}
2024-12-11 00:05:50 +08:00
/* ************************************************************************ */
/// Evaluate the graph at `values`: the product of every non-null factor
/// evaluated at `values` (1.0 when there is nothing to multiply).
double DiscreteFactorGraph::operator()(const DiscreteValues& values) const {
  double value = 1.0;
  for (const sharedFactor& f : factors_) {
    if (!f) continue;
    value *= (*f)(values);
  }
  return value;
}
2024-12-11 00:05:50 +08:00
/* ************************************************************************ */
void DiscreteFactorGraph::print(const string& s,
2024-12-11 00:05:50 +08:00
const KeyFormatter& formatter) const {
2014-01-10 05:38:47 +08:00
std::cout << s << std::endl;
std::cout << "size: " << size() << std::endl;
for (size_t i = 0; i < factors_.size(); i++) {
std::stringstream ss;
ss << "factor " << i << ": ";
2020-04-07 05:31:05 +08:00
if (factors_[i] != nullptr) factors_[i]->print(ss.str(), formatter);
2014-01-10 05:38:47 +08:00
}
}
// /* ************************************************************************* */
// void DiscreteFactorGraph::permuteWithInverse(
// const Permutation& inversePermutation) {
// for(const sharedFactor& factor: factors_) {
// if(factor)
// factor->permuteWithInverse(inversePermutation);
// }
// }
//
// /* ************************************************************************* */
// void DiscreteFactorGraph::reduceWithInverse(
// const internal::Reduction& inverseReduction) {
// for(const sharedFactor& factor: factors_) {
// if(factor)
// factor->reduceWithInverse(inverseReduction);
// }
// }
/**
* @brief Multiply all the `factors` and normalize the
* product to prevent underflow.
*
* @param factors The factors to multiply as a DiscreteFactorGraph.
* @return DecisionTreeFactor
*/
2024-12-11 00:05:50 +08:00
static DecisionTreeFactor ProductAndNormalize(
const DiscreteFactorGraph& factors) {
2022-01-21 23:12:31 +08:00
// PRODUCT: multiply all factors
2024-12-28 01:02:21 +08:00
#if GTSAM_HYBRID_TIMING
gttic_(DiscreteProduct);
#endif
DecisionTreeFactor product = factors.product();
2024-12-28 01:02:21 +08:00
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteProduct);
#endif
2022-01-21 23:12:31 +08:00
// Max over all the potentials by pretending all keys are frontal:
2024-12-26 02:13:21 +08:00
auto normalizer = product.max(product.size());
2024-12-28 01:02:21 +08:00
#if GTSAM_HYBRID_TIMING
gttic_(DiscreteNormalize);
#endif
// Normalize the product factor to prevent underflow.
2024-12-26 02:13:21 +08:00
product = product / (*normalizer);
2024-12-28 01:02:21 +08:00
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteNormalize);
#endif
2024-12-11 00:10:25 +08:00
return product;
}
/* ************************************************************************ */
// Alternate eliminate function for MPE
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
DecisionTreeFactor product = ProductAndNormalize(factors);
2022-01-21 23:12:31 +08:00
// max out frontals, this is the factor on the separator
gttic(max);
DecisionTreeFactor::shared_ptr max = product.max(frontalKeys);
gttoc(max);
// Ordering keys for the conditional so that frontalKeys are really in front
DiscreteKeys orderedKeys;
for (auto&& key : frontalKeys)
orderedKeys.emplace_back(key, product.cardinality(key));
for (auto&& key : max->keys())
orderedKeys.emplace_back(key, product.cardinality(key));
// Make lookup with product
gttic(lookup);
size_t nrFrontals = frontalKeys.size();
2024-12-11 00:05:50 +08:00
auto lookup =
std::make_shared<DiscreteLookupTable>(nrFrontals, orderedKeys, product);
2022-01-21 23:12:31 +08:00
gttoc(lookup);
2023-02-05 01:08:34 +08:00
return {std::dynamic_pointer_cast<DiscreteConditional>(lookup), max};
2022-01-21 23:12:31 +08:00
}
2022-01-26 06:15:52 +08:00
/* ************************************************************************ */
// sumProduct is just an alias for regular eliminateSequential: eliminate with
// the requested ordering type and return a copy of the resulting Bayes net.
DiscreteBayesNet DiscreteFactorGraph::sumProduct(
    OptionalOrderingType orderingType) const {
  gttic(DiscreteFactorGraph_sumProduct);
  return *eliminateSequential(orderingType);
}
2022-01-26 12:47:53 +08:00
// Same as above, but eliminating in the explicitly given `ordering`.
DiscreteBayesNet DiscreteFactorGraph::sumProduct(
    const Ordering& ordering) const {
  gttic(DiscreteFactorGraph_sumProduct);
  return *eliminateSequential(ordering);
}
2022-01-21 23:12:31 +08:00
/* ************************************************************************ */
// The max-product solution below is a bit clunky: the elimination machinery
// does not allow for differently *typed* versions of elimination, so we
// eliminate into a Bayes Net using the special eliminate function above, and
// then create the DiscreteLookupDAG after the fact, in linear time.
2022-01-22 02:18:28 +08:00
DiscreteLookupDAG DiscreteFactorGraph::maxProduct(
2022-01-21 23:12:31 +08:00
OptionalOrderingType orderingType) const {
gttic(DiscreteFactorGraph_maxProduct);
2022-01-26 12:47:53 +08:00
auto bayesNet = eliminateSequential(orderingType, EliminateForMPE);
return DiscreteLookupDAG::FromBayesNet(*bayesNet);
}
2022-01-22 02:18:28 +08:00
// Max-product elimination in the explicitly given `ordering`: eliminate with
// EliminateForMPE, then convert the Bayes net into a DiscreteLookupDAG.
DiscreteLookupDAG DiscreteFactorGraph::maxProduct(
    const Ordering& ordering) const {
  gttic(DiscreteFactorGraph_maxProduct);
  return DiscreteLookupDAG::FromBayesNet(
      *eliminateSequential(ordering, EliminateForMPE));
}
/* ************************************************************************ */
// Run max-product with the requested ordering type and return the argmax of
// the resulting lookup DAG.
DiscreteValues DiscreteFactorGraph::optimize(
    OptionalOrderingType orderingType) const {
  gttic(DiscreteFactorGraph_optimize);
  return maxProduct(orderingType).argmax();
}
2024-12-26 02:13:21 +08:00
// Run max-product in the given `ordering` and return the argmax of the
// resulting lookup DAG.
DiscreteValues DiscreteFactorGraph::optimize(const Ordering& ordering) const {
  gttic(DiscreteFactorGraph_optimize);
  return maxProduct(ordering).argmax();
}
2022-01-21 23:12:31 +08:00
/* ************************************************************************ */
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
2022-01-21 23:12:31 +08:00
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
DecisionTreeFactor product = ProductAndNormalize(factors);
2014-01-10 05:38:47 +08:00
// sum out frontals, this is the factor on the separator
2024-12-28 01:02:21 +08:00
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteSum);
#endif
2014-01-10 05:38:47 +08:00
DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys);
2024-12-28 01:02:21 +08:00
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteSum);
#endif
2014-01-10 05:38:47 +08:00
// Ordering keys for the conditional so that frontalKeys are really in front
Ordering orderedKeys;
2022-01-21 23:12:31 +08:00
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(),
frontalKeys.end());
orderedKeys.insert(orderedKeys.end(), sum->keys().begin(),
sum->keys().end());
2014-01-10 05:38:47 +08:00
// now divide product/sum to get conditional
2024-12-28 01:02:21 +08:00
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteToDiscreteConditional);
#endif
2022-01-21 23:12:31 +08:00
auto conditional =
std::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
2024-12-28 01:02:21 +08:00
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteToDiscreteConditional);
#endif
2014-01-10 05:38:47 +08:00
2023-02-05 01:08:34 +08:00
return {conditional, sum};
2014-01-10 05:38:47 +08:00
}
2022-01-10 00:42:56 +08:00
/* ************************************************************************ */
/// Render the graph as markdown: a header line with the graph size, followed
/// by one "factor <index>:" section per non-null factor.
/// @param keyFormatter Formatter forwarded to each factor's markdown().
/// @param names Names forwarded to each factor's markdown().
/// @return The markdown text as a string.
string DiscreteFactorGraph::markdown(
    const KeyFormatter& keyFormatter,
    const DiscreteFactor::Names& names) const {
  using std::endl;
  std::stringstream ss;
  ss << "`DiscreteFactorGraph` of size " << size() << endl << endl;
  for (size_t i = 0; i < factors_.size(); i++) {
    // Skip empty slots, as print() does, rather than dereference a null
    // factor pointer.
    if (!factors_[i]) continue;
    ss << "factor " << i << ":\n";
    ss << factors_[i]->markdown(keyFormatter, names) << endl;
  }
  return ss.str();
}
2014-01-10 05:38:47 +08:00
2022-01-10 00:42:56 +08:00
/* ************************************************************************ */
/// Render the graph as an HTML fragment: a <div> holding a header paragraph
/// with the graph size, followed by one labelled section per non-null factor.
/// @param keyFormatter Formatter forwarded to each factor's html().
/// @param names Names forwarded to each factor's html().
/// @return The HTML text as a string.
string DiscreteFactorGraph::html(const KeyFormatter& keyFormatter,
                                 const DiscreteFactor::Names& names) const {
  using std::endl;
  std::stringstream ss;
  ss << "<div><p><tt>DiscreteFactorGraph</tt> of size " << size() << "</p>";
  for (size_t i = 0; i < factors_.size(); i++) {
    // Skip empty slots, as print() does, rather than dereference a null
    // factor pointer.
    if (!factors_[i]) continue;
    ss << "<p>factor " << i << ":</p>";
    ss << factors_[i]->html(keyFormatter, names) << endl;
  }
  // Close the <div> opened above so the fragment is well-formed.
  ss << "</div>";
  return ss.str();
}
/* ************************************************************************ */
2021-12-25 02:27:02 +08:00
} // namespace gtsam