Make HybridFactorGraph just a FactorGraph<Factor> with extra methods

release/4.3a0
Frank Dellaert 2023-01-06 16:04:54 -08:00
parent ce27a8baa0
commit 1538452d5a
7 changed files with 96 additions and 99 deletions

View File

@ -59,6 +59,8 @@ GaussianMixture::GaussianMixture(
Conditionals(discreteParents, conditionals)) {}
/* *******************************************************************************/
// TODO(dellaert): This is copy/paste: GaussianMixture should be derived from
// GaussianMixtureFactor, no?
GaussianFactorGraphTree GaussianMixture::add(
const GaussianFactorGraphTree &sum) const {
using Y = GraphAndConstant;

View File

@ -11,8 +11,9 @@
/**
* @file HybridFactorGraph.h
* @brief Hybrid factor graph base class that uses type erasure
* @brief Factor graph with utilities for hybrid factors.
* @author Varun Agrawal
* @author Frank Dellaert
* @date May 28, 2022
*/
@ -31,13 +32,11 @@ using SharedFactor = boost::shared_ptr<Factor>;
/**
* Hybrid Factor Graph
* -----------------------
* This is the base hybrid factor graph.
* Everything inside needs to be hybrid factor or hybrid conditional.
* Factor graph with utilities for hybrid factors.
*/
class HybridFactorGraph : public FactorGraph<HybridFactor> {
class HybridFactorGraph : public FactorGraph<Factor> {
public:
using Base = FactorGraph<HybridFactor>;
using Base = FactorGraph<Factor>;
using This = HybridFactorGraph; ///< this class
using shared_ptr = boost::shared_ptr<This>; ///< shared_ptr to This
@ -140,8 +139,10 @@ class HybridFactorGraph : public FactorGraph<HybridFactor> {
const KeySet discreteKeys() const {
KeySet discrete_keys;
for (auto& factor : factors_) {
for (const DiscreteKey& k : factor->discreteKeys()) {
discrete_keys.insert(k.first);
if (auto p = boost::dynamic_pointer_cast<HybridFactor>(factor)) {
for (const DiscreteKey& k : p->discreteKeys()) {
discrete_keys.insert(k.first);
}
}
}
return discrete_keys;
@ -151,8 +152,10 @@ class HybridFactorGraph : public FactorGraph<HybridFactor> {
const KeySet continuousKeys() const {
KeySet keys;
for (auto& factor : factors_) {
for (const Key& key : factor->continuousKeys()) {
keys.insert(key);
if (auto p = boost::dynamic_pointer_cast<HybridFactor>(factor)) {
for (const Key& key : p->continuousKeys()) {
keys.insert(key);
}
}
}
return keys;

View File

@ -79,48 +79,47 @@ static GaussianFactorGraphTree addGaussian(
}
/* ************************************************************************ */
// TODO(dellaert): Implementation-wise, it's probably more efficient to first
// collect the discrete keys, and then loop over all assignments to populate a
// vector.
// TODO(dellaert): it's probably more efficient to first collect the discrete
// keys, and then loop over all assignments to populate a vector.
GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
using boost::dynamic_pointer_cast;
gttic(assembleGraphTree);
GaussianFactorGraphTree result;
for (auto &f : factors_) {
// TODO(dellaert): just use a virtual method defined in HybridFactor.
if (f->isHybrid()) {
if (auto gm = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
if (auto gm = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
result = gm->add(result);
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
if (auto gm = hc->asMixture()) {
result = gm->add(result);
} else if (auto g = hc->asGaussian()) {
result = addGaussian(result, g);
} else {
// Has to be discrete.
continue;
}
if (auto gm = boost::dynamic_pointer_cast<HybridConditional>(f)) {
result = gm->asMixture()->add(result);
}
} else if (f->isContinuous()) {
if (auto gf = boost::dynamic_pointer_cast<HybridGaussianFactor>(f)) {
result = addGaussian(result, gf->inner());
}
if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(f)) {
result = addGaussian(result, cg->asGaussian());
}
} else if (f->isDiscrete()) {
} else if (auto gf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
result = addGaussian(result, gf->inner());
} else if (dynamic_pointer_cast<DiscreteFactor>(f) ||
dynamic_pointer_cast<HybridDiscreteFactor>(f)) {
// Don't do anything for discrete-only factors
// since we want to eliminate continuous values only.
continue;
} else {
} else if (auto orphan = dynamic_pointer_cast<
BayesTreeOrphanWrapper<HybridBayesTree::Clique>>(f)) {
// We need to handle the case where the object is actually an
// BayesTreeOrphanWrapper!
auto orphan = boost::dynamic_pointer_cast<
BayesTreeOrphanWrapper<HybridBayesTree::Clique>>(f);
if (!orphan) {
auto &fr = *f;
throw std::invalid_argument(
std::string("factor is discrete in continuous elimination ") +
demangle(typeid(fr).name()));
}
throw std::invalid_argument(
"gtsam::assembleGraphTree: BayesTreeOrphanWrapper is not implemented "
"yet.");
} else {
auto &fr = *f;
throw std::invalid_argument(
std::string("gtsam::assembleGraphTree: factor type not handled: ") +
demangle(typeid(fr).name()));
}
}
@ -377,8 +376,8 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
// Build a map from keys to DiscreteKeys
std::unordered_map<Key, DiscreteKey> mapFromKeyToDiscreteKey;
for (auto &&factor : factors) {
if (!factor->isContinuous()) {
for (auto &k : factor->discreteKeys()) {
if (auto p = boost::dynamic_pointer_cast<HybridFactor>(factor)) {
for (auto &k : p->discreteKeys()) {
mapFromKeyToDiscreteKey[k.first] = k;
}
}
@ -451,12 +450,6 @@ void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) {
/* ************************************************************************ */
const Ordering HybridGaussianFactorGraph::getHybridOrdering() const {
KeySet discrete_keys = discreteKeys();
for (auto &factor : factors_) {
for (const DiscreteKey &k : factor->discreteKeys()) {
discrete_keys.insert(k.first);
}
}
const VariableIndex index(factors_);
Ordering ordering = Ordering::ColamdConstrainedLast(
index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
@ -466,25 +459,23 @@ const Ordering HybridGaussianFactorGraph::getHybridOrdering() const {
/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
const VectorValues &continuousValues) const {
using boost::dynamic_pointer_cast;
AlgebraicDecisionTree<Key> error_tree(0.0);
// Iterate over each factor.
for (size_t idx = 0; idx < size(); idx++) {
for (auto &f : factors_) {
// TODO(dellaert): just use a virtual method defined in HybridFactor.
AlgebraicDecisionTree<Key> factor_error;
if (factors_.at(idx)->isHybrid()) {
// If factor is hybrid, select based on assignment.
GaussianMixtureFactor::shared_ptr gaussianMixture =
boost::static_pointer_cast<GaussianMixtureFactor>(factors_.at(idx));
if (auto gaussianMixture = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
// Compute factor error and add it.
error_tree = error_tree + gaussianMixture->error(continuousValues);
} else if (factors_.at(idx)->isContinuous()) {
} else if (auto hybridGaussianFactor =
dynamic_pointer_cast<HybridGaussianFactor>(f)) {
// If continuous only, get the (double) error
// and add it to the error_tree
auto hybridGaussianFactor =
boost::static_pointer_cast<HybridGaussianFactor>(factors_.at(idx));
GaussianFactor::shared_ptr gaussian = hybridGaussianFactor->inner();
// Compute the error of the gaussian factor.
@ -493,9 +484,16 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
error_tree = error_tree.apply(
[error](double leaf_value) { return leaf_value + error; });
} else if (factors_.at(idx)->isDiscrete()) {
} else if (dynamic_pointer_cast<DiscreteFactor>(f) ||
dynamic_pointer_cast<HybridDiscreteFactor>(f)) {
// If factor at `idx` is discrete-only, we skip.
continue;
} else {
auto &fr = *f;
throw std::invalid_argument(
std::string(
"HybridGaussianFactorGraph::error: factor type not handled: ") +
demangle(typeid(fr).name()));
}
}
@ -506,7 +504,9 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
double HybridGaussianFactorGraph::error(const HybridValues &values) const {
double error = 0.0;
for (auto &factor : factors_) {
error += factor->error(values);
if (auto p = boost::dynamic_pointer_cast<HybridFactor>(factor)) {
error += p->error(values);
}
}
return error;
}

View File

@ -61,9 +61,11 @@ struct HybridConstructorTraversalData {
parentData.junctionTreeNode->addChild(data.junctionTreeNode);
// Add all the discrete keys in the hybrid factors to the current data
for (HybridFactor::shared_ptr& f : node->factors) {
for (auto& k : f->discreteKeys()) {
data.discreteKeys.insert(k.first);
for (const auto& f : node->factors) {
if (auto p = boost::dynamic_pointer_cast<HybridFactor>(f)) {
for (auto& k : p->discreteKeys()) {
data.discreteKeys.insert(k.first);
}
}
}

View File

@ -50,47 +50,42 @@ void HybridNonlinearFactorGraph::print(const std::string& s,
/* ************************************************************************* */
HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize(
const Values& continuousValues) const {
using boost::dynamic_pointer_cast;
// create an empty linear FG
auto linearFG = boost::make_shared<HybridGaussianFactorGraph>();
linearFG->reserve(size());
// linearize all hybrid factors
for (auto&& factor : factors_) {
for (auto& f : factors_) {
// First check if it is a valid factor
if (factor) {
// Check if the factor is a hybrid factor.
// It can be either a nonlinear MixtureFactor or a linear
// GaussianMixtureFactor.
if (factor->isHybrid()) {
// Check if it is a nonlinear mixture factor
if (auto nlmf = boost::dynamic_pointer_cast<MixtureFactor>(factor)) {
linearFG->push_back(nlmf->linearize(continuousValues));
} else {
linearFG->push_back(factor);
}
// Now check if the factor is a continuous only factor.
} else if (factor->isContinuous()) {
// In this case, we check if factor->inner() is nonlinear since
// HybridFactors wrap over continuous factors.
auto nlhf = boost::dynamic_pointer_cast<HybridNonlinearFactor>(factor);
if (auto nlf =
boost::dynamic_pointer_cast<NonlinearFactor>(nlhf->inner())) {
auto hgf = boost::make_shared<HybridGaussianFactor>(
nlf->linearize(continuousValues));
linearFG->push_back(hgf);
} else {
linearFG->push_back(factor);
}
// Finally if nothing else, we are discrete-only which doesn't need
// lineariztion.
} else {
linearFG->push_back(factor);
}
} else {
if (!f) {
// TODO(dellaert): why?
linearFG->push_back(GaussianFactor::shared_ptr());
continue;
}
// Check if it is a nonlinear mixture factor
if (auto nlmf = dynamic_pointer_cast<MixtureFactor>(f)) {
const GaussianMixtureFactor::shared_ptr& gmf =
nlmf->linearize(continuousValues);
linearFG->push_back(gmf);
} else if (auto nlhf = dynamic_pointer_cast<HybridNonlinearFactor>(f)) {
// Nonlinear wrapper case:
const GaussianFactor::shared_ptr& gf =
nlhf->inner()->linearize(continuousValues);
const auto hgf = boost::make_shared<HybridGaussianFactor>(gf);
linearFG->push_back(hgf);
} else if (dynamic_pointer_cast<DiscreteFactor>(f) ||
dynamic_pointer_cast<HybridDiscreteFactor>(f)) {
// If discrete-only: doesn't need linearization.
linearFG->push_back(f);
} else {
auto& fr = *f;
throw std::invalid_argument(
std::string("HybridNonlinearFactorGraph::linearize: factor type "
"not handled: ") +
demangle(typeid(fr).name()));
}
}
return linearFG;

View File

@ -47,11 +47,6 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
using HasDerivedValueType = typename std::enable_if<
std::is_base_of<HybridFactor, typename T::value_type>::value>::type;
/// Check if T has a pointer type derived from FactorType.
template <typename T>
using HasDerivedElementType = typename std::enable_if<std::is_base_of<
HybridFactor, typename T::value_type::element_type>::value>::type;
public:
using Base = HybridFactorGraph;
using This = HybridNonlinearFactorGraph; ///< this class
@ -124,7 +119,7 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
* copied)
*/
template <typename CONTAINER>
HasDerivedElementType<CONTAINER> push_back(const CONTAINER& container) {
void push_back(const CONTAINER& container) {
Base::push_back(container.begin(), container.end());
}

View File

@ -435,7 +435,7 @@ TEST(HybridFactorGraph, Full_Elimination) {
DiscreteFactorGraph discrete_fg;
// TODO(Varun) Make this a function of HybridGaussianFactorGraph?
for (HybridFactor::shared_ptr& factor : (*remainingFactorGraph_partial)) {
for (auto& factor : (*remainingFactorGraph_partial)) {
auto df = dynamic_pointer_cast<HybridDiscreteFactor>(factor);
discrete_fg.push_back(df->inner());
}