diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index bac1285e1..cabfd28b8 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -59,6 +59,8 @@ GaussianMixture::GaussianMixture( Conditionals(discreteParents, conditionals)) {} /* *******************************************************************************/ +// TODO(dellaert): This is copy/paste: GaussianMixture should be derived from +// GaussianMixtureFactor, no? GaussianFactorGraphTree GaussianMixture::add( const GaussianFactorGraphTree &sum) const { using Y = GraphAndConstant; diff --git a/gtsam/hybrid/HybridFactorGraph.h b/gtsam/hybrid/HybridFactorGraph.h index 05a17b000..e2322ee0b 100644 --- a/gtsam/hybrid/HybridFactorGraph.h +++ b/gtsam/hybrid/HybridFactorGraph.h @@ -11,8 +11,9 @@ /** * @file HybridFactorGraph.h - * @brief Hybrid factor graph base class that uses type erasure + * @brief Factor graph with utilities for hybrid factors. * @author Varun Agrawal + * @author Frank Dellaert * @date May 28, 2022 */ @@ -31,13 +32,11 @@ using SharedFactor = boost::shared_ptr; /** * Hybrid Factor Graph - * ----------------------- - * This is the base hybrid factor graph. - * Everything inside needs to be hybrid factor or hybrid conditional. + * Factor graph with utilities for hybrid factors. */ -class HybridFactorGraph : public FactorGraph { +class HybridFactorGraph : public FactorGraph { public: - using Base = FactorGraph; + using Base = FactorGraph; using This = HybridFactorGraph; ///< this class using shared_ptr = boost::shared_ptr; ///< shared_ptr to This @@ -140,8 +139,10 @@ class HybridFactorGraph : public FactorGraph { const KeySet discreteKeys() const { KeySet discrete_keys; for (auto& factor : factors_) { - for (const DiscreteKey& k : factor->discreteKeys()) { - discrete_keys.insert(k.first); + if (auto p = boost::dynamic_pointer_cast(factor)) { + for (const DiscreteKey& k : p->discreteKeys()) { + discrete_keys.insert(k.first); + } } } return discrete_keys; @@ -151,8 +152,10 @@ class HybridFactorGraph : public FactorGraph { const KeySet continuousKeys() const { KeySet keys; for (auto& factor : factors_) { - for (const Key& key : factor->continuousKeys()) { - keys.insert(key); + if (auto p = boost::dynamic_pointer_cast(factor)) { + for (const Key& key : p->continuousKeys()) { + keys.insert(key); + } } } return keys; diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index f6b713a76..a2f420c3f 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -79,48 +79,47 @@ static GaussianFactorGraphTree addGaussian( } /* ************************************************************************ */ -// TODO(dellaert): Implementation-wise, it's probably more efficient to first -// collect the discrete keys, and then loop over all assignments to populate a -// vector. +// TODO(dellaert): it's probably more efficient to first collect the discrete +// keys, and then loop over all assignments to populate a vector. GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const { + using boost::dynamic_pointer_cast; + gttic(assembleGraphTree); GaussianFactorGraphTree result; for (auto &f : factors_) { // TODO(dellaert): just use a virtual method defined in HybridFactor. - if (f->isHybrid()) { - if (auto gm = boost::dynamic_pointer_cast(f)) { + if (auto gm = dynamic_pointer_cast(f)) { + result = gm->add(result); + } else if (auto hc = dynamic_pointer_cast(f)) { + if (auto gm = hc->asMixture()) { result = gm->add(result); + } else if (auto g = hc->asGaussian()) { + result = addGaussian(result, g); + } else { + // Has to be discrete. + continue; } - if (auto gm = boost::dynamic_pointer_cast(f)) { - result = gm->asMixture()->add(result); - } - - } else if (f->isContinuous()) { - if (auto gf = boost::dynamic_pointer_cast(f)) { - result = addGaussian(result, gf->inner()); - } - if (auto cg = boost::dynamic_pointer_cast(f)) { - result = addGaussian(result, cg->asGaussian()); - } - - } else if (f->isDiscrete()) { + } else if (auto gf = dynamic_pointer_cast(f)) { + result = addGaussian(result, gf->inner()); + } else if (dynamic_pointer_cast(f) || + dynamic_pointer_cast(f)) { // Don't do anything for discrete-only factors // since we want to eliminate continuous values only. continue; - - } else { + } else if (auto orphan = dynamic_pointer_cast< + BayesTreeOrphanWrapper>(f)) { // We need to handle the case where the object is actually an // BayesTreeOrphanWrapper! - auto orphan = boost::dynamic_pointer_cast< - BayesTreeOrphanWrapper>(f); - if (!orphan) { - auto &fr = *f; - throw std::invalid_argument( - std::string("factor is discrete in continuous elimination ") + - demangle(typeid(fr).name())); - } + throw std::invalid_argument( + "gtsam::assembleGraphTree: BayesTreeOrphanWrapper is not implemented " + "yet."); + } else { + auto &fr = *f; + throw std::invalid_argument( + std::string("gtsam::assembleGraphTree: factor type not handled: ") + + demangle(typeid(fr).name())); } } @@ -377,8 +376,8 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors, // Build a map from keys to DiscreteKeys std::unordered_map mapFromKeyToDiscreteKey; for (auto &&factor : factors) { - if (!factor->isContinuous()) { - for (auto &k : factor->discreteKeys()) { + if (auto p = boost::dynamic_pointer_cast(factor)) { + for (auto &k : p->discreteKeys()) { mapFromKeyToDiscreteKey[k.first] = k; } } @@ -451,12 +450,6 @@ void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) { /* ************************************************************************ */ const Ordering HybridGaussianFactorGraph::getHybridOrdering() const { KeySet discrete_keys = discreteKeys(); - for (auto &factor : factors_) { - for (const DiscreteKey &k : factor->discreteKeys()) { - discrete_keys.insert(k.first); - } - } - const VariableIndex index(factors_); Ordering ordering = Ordering::ColamdConstrainedLast( index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true); @@ -466,25 +459,23 @@ const Ordering HybridGaussianFactorGraph::getHybridOrdering() const { /* ************************************************************************ */ AlgebraicDecisionTree HybridGaussianFactorGraph::error( const VectorValues &continuousValues) const { + using boost::dynamic_pointer_cast; + AlgebraicDecisionTree error_tree(0.0); // Iterate over each factor. - for (size_t idx = 0; idx < size(); idx++) { + for (auto &f : factors_) { // TODO(dellaert): just use a virtual method defined in HybridFactor. AlgebraicDecisionTree factor_error; - if (factors_.at(idx)->isHybrid()) { - // If factor is hybrid, select based on assignment. - GaussianMixtureFactor::shared_ptr gaussianMixture = - boost::static_pointer_cast(factors_.at(idx)); + if (auto gaussianMixture = dynamic_pointer_cast(f)) { // Compute factor error and add it. error_tree = error_tree + gaussianMixture->error(continuousValues); - } else if (factors_.at(idx)->isContinuous()) { + } else if (auto hybridGaussianFactor = + dynamic_pointer_cast(f)) { // If continuous only, get the (double) error // and add it to the error_tree - auto hybridGaussianFactor = - boost::static_pointer_cast(factors_.at(idx)); GaussianFactor::shared_ptr gaussian = hybridGaussianFactor->inner(); // Compute the error of the gaussian factor. @@ -493,9 +484,16 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::error( error_tree = error_tree.apply( [error](double leaf_value) { return leaf_value + error; }); - } else if (factors_.at(idx)->isDiscrete()) { + } else if (dynamic_pointer_cast(f) || + dynamic_pointer_cast(f)) { // If factor at `idx` is discrete-only, we skip. continue; + } else { + auto &fr = *f; + throw std::invalid_argument( + std::string( + "HybridGaussianFactorGraph::error: factor type not handled: ") + + demangle(typeid(fr).name())); } } @@ -506,7 +504,9 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::error( double HybridGaussianFactorGraph::error(const HybridValues &values) const { double error = 0.0; for (auto &factor : factors_) { - error += factor->error(values); + if (auto p = boost::dynamic_pointer_cast(factor)) { + error += p->error(values); + } } return error; } diff --git a/gtsam/hybrid/HybridJunctionTree.cpp b/gtsam/hybrid/HybridJunctionTree.cpp index 422c200a4..573df7eca 100644 --- a/gtsam/hybrid/HybridJunctionTree.cpp +++ b/gtsam/hybrid/HybridJunctionTree.cpp @@ -61,9 +61,11 @@ struct HybridConstructorTraversalData { parentData.junctionTreeNode->addChild(data.junctionTreeNode); // Add all the discrete keys in the hybrid factors to the current data - for (HybridFactor::shared_ptr& f : node->factors) { - for (auto& k : f->discreteKeys()) { - data.discreteKeys.insert(k.first); + for (const auto& f : node->factors) { + if (auto p = boost::dynamic_pointer_cast(f)) { + for (auto& k : p->discreteKeys()) { + data.discreteKeys.insert(k.first); + } } } diff --git a/gtsam/hybrid/HybridNonlinearFactorGraph.cpp b/gtsam/hybrid/HybridNonlinearFactorGraph.cpp index 3a3bf720b..6ab1962d4 100644 --- a/gtsam/hybrid/HybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/HybridNonlinearFactorGraph.cpp @@ -50,47 +50,42 @@ void HybridNonlinearFactorGraph::print(const std::string& s, /* ************************************************************************* */ HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize( const Values& continuousValues) const { + using boost::dynamic_pointer_cast; + // create an empty linear FG auto linearFG = boost::make_shared(); linearFG->reserve(size()); // linearize all hybrid factors - for (auto&& factor : factors_) { + for (auto& f : factors_) { // First check if it is a valid factor - if (factor) { - // Check if the factor is a hybrid factor. - // It can be either a nonlinear MixtureFactor or a linear - // GaussianMixtureFactor. - if (factor->isHybrid()) { - // Check if it is a nonlinear mixture factor - if (auto nlmf = boost::dynamic_pointer_cast(factor)) { - linearFG->push_back(nlmf->linearize(continuousValues)); - } else { - linearFG->push_back(factor); - } - - // Now check if the factor is a continuous only factor. - } else if (factor->isContinuous()) { - // In this case, we check if factor->inner() is nonlinear since - // HybridFactors wrap over continuous factors. - auto nlhf = boost::dynamic_pointer_cast(factor); - if (auto nlf = - boost::dynamic_pointer_cast(nlhf->inner())) { - auto hgf = boost::make_shared( - nlf->linearize(continuousValues)); - linearFG->push_back(hgf); - } else { - linearFG->push_back(factor); - } - // Finally if nothing else, we are discrete-only which doesn't need - // lineariztion. - } else { - linearFG->push_back(factor); - } - - } else { + if (!f) { + // TODO(dellaert): why? linearFG->push_back(GaussianFactor::shared_ptr()); + continue; + } + // Check if it is a nonlinear mixture factor + if (auto nlmf = dynamic_pointer_cast(f)) { + const GaussianMixtureFactor::shared_ptr& gmf = + nlmf->linearize(continuousValues); + linearFG->push_back(gmf); + } else if (auto nlhf = dynamic_pointer_cast(f)) { + // Nonlinear wrapper case: + const GaussianFactor::shared_ptr& gf = + nlhf->inner()->linearize(continuousValues); + const auto hgf = boost::make_shared(gf); + linearFG->push_back(hgf); + } else if (dynamic_pointer_cast(f) || + dynamic_pointer_cast(f)) { + // If discrete-only: doesn't need linearization. + linearFG->push_back(f); + } else { + auto& fr = *f; + throw std::invalid_argument( + std::string("HybridNonlinearFactorGraph::linearize: factor type " + "not handled: ") + + demangle(typeid(fr).name())); } } return linearFG; diff --git a/gtsam/hybrid/HybridNonlinearFactorGraph.h b/gtsam/hybrid/HybridNonlinearFactorGraph.h index b48e8bb5c..25314cfd3 100644 --- a/gtsam/hybrid/HybridNonlinearFactorGraph.h +++ b/gtsam/hybrid/HybridNonlinearFactorGraph.h @@ -47,11 +47,6 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph { using HasDerivedValueType = typename std::enable_if< std::is_base_of::value>::type; - /// Check if T has a pointer type derived from FactorType. - template - using HasDerivedElementType = typename std::enable_if::value>::type; - public: using Base = HybridFactorGraph; using This = HybridNonlinearFactorGraph; ///< this class @@ -124,7 +119,7 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph { * copied) */ template - HasDerivedElementType push_back(const CONTAINER& container) { + void push_back(const CONTAINER& container) { Base::push_back(container.begin(), container.end()); } diff --git a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp index d84f4b352..83a71a7d7 100644 --- a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp @@ -435,7 +435,7 @@ TEST(HybridFactorGraph, Full_Elimination) { DiscreteFactorGraph discrete_fg; // TODO(Varun) Make this a function of HybridGaussianFactorGraph? - for (HybridFactor::shared_ptr& factor : (*remainingFactorGraph_partial)) { + for (auto& factor : (*remainingFactorGraph_partial)) { auto df = dynamic_pointer_cast(factor); discrete_fg.push_back(df->inner()); }