Make HybridFactorGraph just a FactorGraph<Factor> with extra methods

release/4.3a0
Frank Dellaert 2023-01-06 16:04:54 -08:00
parent ce27a8baa0
commit 1538452d5a
7 changed files with 96 additions and 99 deletions

View File

@ -59,6 +59,8 @@ GaussianMixture::GaussianMixture(
Conditionals(discreteParents, conditionals)) {} Conditionals(discreteParents, conditionals)) {}
/* *******************************************************************************/ /* *******************************************************************************/
// TODO(dellaert): This is copy/paste: GaussianMixture should be derived from
// GaussianMixtureFactor, no?
GaussianFactorGraphTree GaussianMixture::add( GaussianFactorGraphTree GaussianMixture::add(
const GaussianFactorGraphTree &sum) const { const GaussianFactorGraphTree &sum) const {
using Y = GraphAndConstant; using Y = GraphAndConstant;

View File

@ -11,8 +11,9 @@
/** /**
* @file HybridFactorGraph.h * @file HybridFactorGraph.h
* @brief Hybrid factor graph base class that uses type erasure * @brief Factor graph with utilities for hybrid factors.
* @author Varun Agrawal * @author Varun Agrawal
* @author Frank Dellaert
* @date May 28, 2022 * @date May 28, 2022
*/ */
@ -31,13 +32,11 @@ using SharedFactor = boost::shared_ptr<Factor>;
/** /**
* Hybrid Factor Graph * Hybrid Factor Graph
* ----------------------- * Factor graph with utilities for hybrid factors.
* This is the base hybrid factor graph.
* Everything inside needs to be hybrid factor or hybrid conditional.
*/ */
class HybridFactorGraph : public FactorGraph<HybridFactor> { class HybridFactorGraph : public FactorGraph<Factor> {
public: public:
using Base = FactorGraph<HybridFactor>; using Base = FactorGraph<Factor>;
using This = HybridFactorGraph; ///< this class using This = HybridFactorGraph; ///< this class
using shared_ptr = boost::shared_ptr<This>; ///< shared_ptr to This using shared_ptr = boost::shared_ptr<This>; ///< shared_ptr to This
@ -140,8 +139,10 @@ class HybridFactorGraph : public FactorGraph<HybridFactor> {
const KeySet discreteKeys() const { const KeySet discreteKeys() const {
KeySet discrete_keys; KeySet discrete_keys;
for (auto& factor : factors_) { for (auto& factor : factors_) {
for (const DiscreteKey& k : factor->discreteKeys()) { if (auto p = boost::dynamic_pointer_cast<HybridFactor>(factor)) {
discrete_keys.insert(k.first); for (const DiscreteKey& k : p->discreteKeys()) {
discrete_keys.insert(k.first);
}
} }
} }
return discrete_keys; return discrete_keys;
@ -151,8 +152,10 @@ class HybridFactorGraph : public FactorGraph<HybridFactor> {
const KeySet continuousKeys() const { const KeySet continuousKeys() const {
KeySet keys; KeySet keys;
for (auto& factor : factors_) { for (auto& factor : factors_) {
for (const Key& key : factor->continuousKeys()) { if (auto p = boost::dynamic_pointer_cast<HybridFactor>(factor)) {
keys.insert(key); for (const Key& key : p->continuousKeys()) {
keys.insert(key);
}
} }
} }
return keys; return keys;

View File

@ -79,48 +79,47 @@ static GaussianFactorGraphTree addGaussian(
} }
/* ************************************************************************ */ /* ************************************************************************ */
// TODO(dellaert): Implementation-wise, it's probably more efficient to first // TODO(dellaert): it's probably more efficient to first collect the discrete
// collect the discrete keys, and then loop over all assignments to populate a // keys, and then loop over all assignments to populate a vector.
// vector.
GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const { GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
using boost::dynamic_pointer_cast;
gttic(assembleGraphTree); gttic(assembleGraphTree);
GaussianFactorGraphTree result; GaussianFactorGraphTree result;
for (auto &f : factors_) { for (auto &f : factors_) {
// TODO(dellaert): just use a virtual method defined in HybridFactor. // TODO(dellaert): just use a virtual method defined in HybridFactor.
if (f->isHybrid()) { if (auto gm = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
if (auto gm = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f)) { result = gm->add(result);
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
if (auto gm = hc->asMixture()) {
result = gm->add(result); result = gm->add(result);
} else if (auto g = hc->asGaussian()) {
result = addGaussian(result, g);
} else {
// Has to be discrete.
continue;
} }
if (auto gm = boost::dynamic_pointer_cast<HybridConditional>(f)) { } else if (auto gf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
result = gm->asMixture()->add(result); result = addGaussian(result, gf->inner());
} } else if (dynamic_pointer_cast<DiscreteFactor>(f) ||
dynamic_pointer_cast<HybridDiscreteFactor>(f)) {
} else if (f->isContinuous()) {
if (auto gf = boost::dynamic_pointer_cast<HybridGaussianFactor>(f)) {
result = addGaussian(result, gf->inner());
}
if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(f)) {
result = addGaussian(result, cg->asGaussian());
}
} else if (f->isDiscrete()) {
// Don't do anything for discrete-only factors // Don't do anything for discrete-only factors
// since we want to eliminate continuous values only. // since we want to eliminate continuous values only.
continue; continue;
} else if (auto orphan = dynamic_pointer_cast<
} else { BayesTreeOrphanWrapper<HybridBayesTree::Clique>>(f)) {
// We need to handle the case where the object is actually an // We need to handle the case where the object is actually an
// BayesTreeOrphanWrapper! // BayesTreeOrphanWrapper!
auto orphan = boost::dynamic_pointer_cast< throw std::invalid_argument(
BayesTreeOrphanWrapper<HybridBayesTree::Clique>>(f); "gtsam::assembleGraphTree: BayesTreeOrphanWrapper is not implemented "
if (!orphan) { "yet.");
auto &fr = *f; } else {
throw std::invalid_argument( auto &fr = *f;
std::string("factor is discrete in continuous elimination ") + throw std::invalid_argument(
demangle(typeid(fr).name())); std::string("gtsam::assembleGraphTree: factor type not handled: ") +
} demangle(typeid(fr).name()));
} }
} }
@ -377,8 +376,8 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
// Build a map from keys to DiscreteKeys // Build a map from keys to DiscreteKeys
std::unordered_map<Key, DiscreteKey> mapFromKeyToDiscreteKey; std::unordered_map<Key, DiscreteKey> mapFromKeyToDiscreteKey;
for (auto &&factor : factors) { for (auto &&factor : factors) {
if (!factor->isContinuous()) { if (auto p = boost::dynamic_pointer_cast<HybridFactor>(factor)) {
for (auto &k : factor->discreteKeys()) { for (auto &k : p->discreteKeys()) {
mapFromKeyToDiscreteKey[k.first] = k; mapFromKeyToDiscreteKey[k.first] = k;
} }
} }
@ -451,12 +450,6 @@ void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) {
/* ************************************************************************ */ /* ************************************************************************ */
const Ordering HybridGaussianFactorGraph::getHybridOrdering() const { const Ordering HybridGaussianFactorGraph::getHybridOrdering() const {
KeySet discrete_keys = discreteKeys(); KeySet discrete_keys = discreteKeys();
for (auto &factor : factors_) {
for (const DiscreteKey &k : factor->discreteKeys()) {
discrete_keys.insert(k.first);
}
}
const VariableIndex index(factors_); const VariableIndex index(factors_);
Ordering ordering = Ordering::ColamdConstrainedLast( Ordering ordering = Ordering::ColamdConstrainedLast(
index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true); index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
@ -466,25 +459,23 @@ const Ordering HybridGaussianFactorGraph::getHybridOrdering() const {
/* ************************************************************************ */ /* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error( AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
const VectorValues &continuousValues) const { const VectorValues &continuousValues) const {
using boost::dynamic_pointer_cast;
AlgebraicDecisionTree<Key> error_tree(0.0); AlgebraicDecisionTree<Key> error_tree(0.0);
// Iterate over each factor. // Iterate over each factor.
for (size_t idx = 0; idx < size(); idx++) { for (auto &f : factors_) {
// TODO(dellaert): just use a virtual method defined in HybridFactor. // TODO(dellaert): just use a virtual method defined in HybridFactor.
AlgebraicDecisionTree<Key> factor_error; AlgebraicDecisionTree<Key> factor_error;
if (factors_.at(idx)->isHybrid()) { if (auto gaussianMixture = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
// If factor is hybrid, select based on assignment.
GaussianMixtureFactor::shared_ptr gaussianMixture =
boost::static_pointer_cast<GaussianMixtureFactor>(factors_.at(idx));
// Compute factor error and add it. // Compute factor error and add it.
error_tree = error_tree + gaussianMixture->error(continuousValues); error_tree = error_tree + gaussianMixture->error(continuousValues);
} else if (factors_.at(idx)->isContinuous()) { } else if (auto hybridGaussianFactor =
dynamic_pointer_cast<HybridGaussianFactor>(f)) {
// If continuous only, get the (double) error // If continuous only, get the (double) error
// and add it to the error_tree // and add it to the error_tree
auto hybridGaussianFactor =
boost::static_pointer_cast<HybridGaussianFactor>(factors_.at(idx));
GaussianFactor::shared_ptr gaussian = hybridGaussianFactor->inner(); GaussianFactor::shared_ptr gaussian = hybridGaussianFactor->inner();
// Compute the error of the gaussian factor. // Compute the error of the gaussian factor.
@ -493,9 +484,16 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
error_tree = error_tree.apply( error_tree = error_tree.apply(
[error](double leaf_value) { return leaf_value + error; }); [error](double leaf_value) { return leaf_value + error; });
} else if (factors_.at(idx)->isDiscrete()) { } else if (dynamic_pointer_cast<DiscreteFactor>(f) ||
dynamic_pointer_cast<HybridDiscreteFactor>(f)) {
// If factor at `idx` is discrete-only, we skip. // If factor at `idx` is discrete-only, we skip.
continue; continue;
} else {
auto &fr = *f;
throw std::invalid_argument(
std::string(
"HybridGaussianFactorGraph::error: factor type not handled: ") +
demangle(typeid(fr).name()));
} }
} }
@ -506,7 +504,9 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
double HybridGaussianFactorGraph::error(const HybridValues &values) const { double HybridGaussianFactorGraph::error(const HybridValues &values) const {
double error = 0.0; double error = 0.0;
for (auto &factor : factors_) { for (auto &factor : factors_) {
error += factor->error(values); if (auto p = boost::dynamic_pointer_cast<HybridFactor>(factor)) {
error += p->error(values);
}
} }
return error; return error;
} }

View File

@ -61,9 +61,11 @@ struct HybridConstructorTraversalData {
parentData.junctionTreeNode->addChild(data.junctionTreeNode); parentData.junctionTreeNode->addChild(data.junctionTreeNode);
// Add all the discrete keys in the hybrid factors to the current data // Add all the discrete keys in the hybrid factors to the current data
for (HybridFactor::shared_ptr& f : node->factors) { for (const auto& f : node->factors) {
for (auto& k : f->discreteKeys()) { if (auto p = boost::dynamic_pointer_cast<HybridFactor>(f)) {
data.discreteKeys.insert(k.first); for (auto& k : p->discreteKeys()) {
data.discreteKeys.insert(k.first);
}
} }
} }

View File

@ -50,47 +50,42 @@ void HybridNonlinearFactorGraph::print(const std::string& s,
/* ************************************************************************* */ /* ************************************************************************* */
HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize( HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize(
const Values& continuousValues) const { const Values& continuousValues) const {
using boost::dynamic_pointer_cast;
// create an empty linear FG // create an empty linear FG
auto linearFG = boost::make_shared<HybridGaussianFactorGraph>(); auto linearFG = boost::make_shared<HybridGaussianFactorGraph>();
linearFG->reserve(size()); linearFG->reserve(size());
// linearize all hybrid factors // linearize all hybrid factors
for (auto&& factor : factors_) { for (auto& f : factors_) {
// First check if it is a valid factor // First check if it is a valid factor
if (factor) { if (!f) {
// Check if the factor is a hybrid factor. // TODO(dellaert): why?
// It can be either a nonlinear MixtureFactor or a linear
// GaussianMixtureFactor.
if (factor->isHybrid()) {
// Check if it is a nonlinear mixture factor
if (auto nlmf = boost::dynamic_pointer_cast<MixtureFactor>(factor)) {
linearFG->push_back(nlmf->linearize(continuousValues));
} else {
linearFG->push_back(factor);
}
// Now check if the factor is a continuous only factor.
} else if (factor->isContinuous()) {
// In this case, we check if factor->inner() is nonlinear since
// HybridFactors wrap over continuous factors.
auto nlhf = boost::dynamic_pointer_cast<HybridNonlinearFactor>(factor);
if (auto nlf =
boost::dynamic_pointer_cast<NonlinearFactor>(nlhf->inner())) {
auto hgf = boost::make_shared<HybridGaussianFactor>(
nlf->linearize(continuousValues));
linearFG->push_back(hgf);
} else {
linearFG->push_back(factor);
}
// Finally if nothing else, we are discrete-only which doesn't need
// lineariztion.
} else {
linearFG->push_back(factor);
}
} else {
linearFG->push_back(GaussianFactor::shared_ptr()); linearFG->push_back(GaussianFactor::shared_ptr());
continue;
}
// Check if it is a nonlinear mixture factor
if (auto nlmf = dynamic_pointer_cast<MixtureFactor>(f)) {
const GaussianMixtureFactor::shared_ptr& gmf =
nlmf->linearize(continuousValues);
linearFG->push_back(gmf);
} else if (auto nlhf = dynamic_pointer_cast<HybridNonlinearFactor>(f)) {
// Nonlinear wrapper case:
const GaussianFactor::shared_ptr& gf =
nlhf->inner()->linearize(continuousValues);
const auto hgf = boost::make_shared<HybridGaussianFactor>(gf);
linearFG->push_back(hgf);
} else if (dynamic_pointer_cast<DiscreteFactor>(f) ||
dynamic_pointer_cast<HybridDiscreteFactor>(f)) {
// If discrete-only: doesn't need linearization.
linearFG->push_back(f);
} else {
auto& fr = *f;
throw std::invalid_argument(
std::string("HybridNonlinearFactorGraph::linearize: factor type "
"not handled: ") +
demangle(typeid(fr).name()));
} }
} }
return linearFG; return linearFG;

View File

@ -47,11 +47,6 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
using HasDerivedValueType = typename std::enable_if< using HasDerivedValueType = typename std::enable_if<
std::is_base_of<HybridFactor, typename T::value_type>::value>::type; std::is_base_of<HybridFactor, typename T::value_type>::value>::type;
/// Check if T has a pointer type derived from FactorType.
template <typename T>
using HasDerivedElementType = typename std::enable_if<std::is_base_of<
HybridFactor, typename T::value_type::element_type>::value>::type;
public: public:
using Base = HybridFactorGraph; using Base = HybridFactorGraph;
using This = HybridNonlinearFactorGraph; ///< this class using This = HybridNonlinearFactorGraph; ///< this class
@ -124,7 +119,7 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
* copied) * copied)
*/ */
template <typename CONTAINER> template <typename CONTAINER>
HasDerivedElementType<CONTAINER> push_back(const CONTAINER& container) { void push_back(const CONTAINER& container) {
Base::push_back(container.begin(), container.end()); Base::push_back(container.begin(), container.end());
} }

View File

@ -435,7 +435,7 @@ TEST(HybridFactorGraph, Full_Elimination) {
DiscreteFactorGraph discrete_fg; DiscreteFactorGraph discrete_fg;
// TODO(Varun) Make this a function of HybridGaussianFactorGraph? // TODO(Varun) Make this a function of HybridGaussianFactorGraph?
for (HybridFactor::shared_ptr& factor : (*remainingFactorGraph_partial)) { for (auto& factor : (*remainingFactorGraph_partial)) {
auto df = dynamic_pointer_cast<HybridDiscreteFactor>(factor); auto df = dynamic_pointer_cast<HybridDiscreteFactor>(factor);
discrete_fg.push_back(df->inner()); discrete_fg.push_back(df->inner());
} }