Make HybridFactorGraph just a FactorGraph<Factor> with extra methods
parent
ce27a8baa0
commit
1538452d5a
|
@ -59,6 +59,8 @@ GaussianMixture::GaussianMixture(
|
|||
Conditionals(discreteParents, conditionals)) {}
|
||||
|
||||
/* *******************************************************************************/
|
||||
// TODO(dellaert): This is copy/paste: GaussianMixture should be derived from
|
||||
// GaussianMixtureFactor, no?
|
||||
GaussianFactorGraphTree GaussianMixture::add(
|
||||
const GaussianFactorGraphTree &sum) const {
|
||||
using Y = GraphAndConstant;
|
||||
|
|
|
@ -11,8 +11,9 @@
|
|||
|
||||
/**
|
||||
* @file HybridFactorGraph.h
|
||||
* @brief Hybrid factor graph base class that uses type erasure
|
||||
* @brief Factor graph with utilities for hybrid factors.
|
||||
* @author Varun Agrawal
|
||||
* @author Frank Dellaert
|
||||
* @date May 28, 2022
|
||||
*/
|
||||
|
||||
|
@ -31,13 +32,11 @@ using SharedFactor = boost::shared_ptr<Factor>;
|
|||
|
||||
/**
|
||||
* Hybrid Factor Graph
|
||||
* -----------------------
|
||||
* This is the base hybrid factor graph.
|
||||
* Everything inside needs to be hybrid factor or hybrid conditional.
|
||||
* Factor graph with utilities for hybrid factors.
|
||||
*/
|
||||
class HybridFactorGraph : public FactorGraph<HybridFactor> {
|
||||
class HybridFactorGraph : public FactorGraph<Factor> {
|
||||
public:
|
||||
using Base = FactorGraph<HybridFactor>;
|
||||
using Base = FactorGraph<Factor>;
|
||||
using This = HybridFactorGraph; ///< this class
|
||||
using shared_ptr = boost::shared_ptr<This>; ///< shared_ptr to This
|
||||
|
||||
|
@ -140,8 +139,10 @@ class HybridFactorGraph : public FactorGraph<HybridFactor> {
|
|||
const KeySet discreteKeys() const {
|
||||
KeySet discrete_keys;
|
||||
for (auto& factor : factors_) {
|
||||
for (const DiscreteKey& k : factor->discreteKeys()) {
|
||||
discrete_keys.insert(k.first);
|
||||
if (auto p = boost::dynamic_pointer_cast<HybridFactor>(factor)) {
|
||||
for (const DiscreteKey& k : p->discreteKeys()) {
|
||||
discrete_keys.insert(k.first);
|
||||
}
|
||||
}
|
||||
}
|
||||
return discrete_keys;
|
||||
|
@ -151,8 +152,10 @@ class HybridFactorGraph : public FactorGraph<HybridFactor> {
|
|||
const KeySet continuousKeys() const {
|
||||
KeySet keys;
|
||||
for (auto& factor : factors_) {
|
||||
for (const Key& key : factor->continuousKeys()) {
|
||||
keys.insert(key);
|
||||
if (auto p = boost::dynamic_pointer_cast<HybridFactor>(factor)) {
|
||||
for (const Key& key : p->continuousKeys()) {
|
||||
keys.insert(key);
|
||||
}
|
||||
}
|
||||
}
|
||||
return keys;
|
||||
|
|
|
@ -79,48 +79,47 @@ static GaussianFactorGraphTree addGaussian(
|
|||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
// TODO(dellaert): Implementation-wise, it's probably more efficient to first
|
||||
// collect the discrete keys, and then loop over all assignments to populate a
|
||||
// vector.
|
||||
// TODO(dellaert): it's probably more efficient to first collect the discrete
|
||||
// keys, and then loop over all assignments to populate a vector.
|
||||
GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
|
||||
using boost::dynamic_pointer_cast;
|
||||
|
||||
gttic(assembleGraphTree);
|
||||
|
||||
GaussianFactorGraphTree result;
|
||||
|
||||
for (auto &f : factors_) {
|
||||
// TODO(dellaert): just use a virtual method defined in HybridFactor.
|
||||
if (f->isHybrid()) {
|
||||
if (auto gm = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
|
||||
if (auto gm = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
|
||||
result = gm->add(result);
|
||||
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
|
||||
if (auto gm = hc->asMixture()) {
|
||||
result = gm->add(result);
|
||||
} else if (auto g = hc->asGaussian()) {
|
||||
result = addGaussian(result, g);
|
||||
} else {
|
||||
// Has to be discrete.
|
||||
continue;
|
||||
}
|
||||
if (auto gm = boost::dynamic_pointer_cast<HybridConditional>(f)) {
|
||||
result = gm->asMixture()->add(result);
|
||||
}
|
||||
|
||||
} else if (f->isContinuous()) {
|
||||
if (auto gf = boost::dynamic_pointer_cast<HybridGaussianFactor>(f)) {
|
||||
result = addGaussian(result, gf->inner());
|
||||
}
|
||||
if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(f)) {
|
||||
result = addGaussian(result, cg->asGaussian());
|
||||
}
|
||||
|
||||
} else if (f->isDiscrete()) {
|
||||
} else if (auto gf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
|
||||
result = addGaussian(result, gf->inner());
|
||||
} else if (dynamic_pointer_cast<DiscreteFactor>(f) ||
|
||||
dynamic_pointer_cast<HybridDiscreteFactor>(f)) {
|
||||
// Don't do anything for discrete-only factors
|
||||
// since we want to eliminate continuous values only.
|
||||
continue;
|
||||
|
||||
} else {
|
||||
} else if (auto orphan = dynamic_pointer_cast<
|
||||
BayesTreeOrphanWrapper<HybridBayesTree::Clique>>(f)) {
|
||||
// We need to handle the case where the object is actually an
|
||||
// BayesTreeOrphanWrapper!
|
||||
auto orphan = boost::dynamic_pointer_cast<
|
||||
BayesTreeOrphanWrapper<HybridBayesTree::Clique>>(f);
|
||||
if (!orphan) {
|
||||
auto &fr = *f;
|
||||
throw std::invalid_argument(
|
||||
std::string("factor is discrete in continuous elimination ") +
|
||||
demangle(typeid(fr).name()));
|
||||
}
|
||||
throw std::invalid_argument(
|
||||
"gtsam::assembleGraphTree: BayesTreeOrphanWrapper is not implemented "
|
||||
"yet.");
|
||||
} else {
|
||||
auto &fr = *f;
|
||||
throw std::invalid_argument(
|
||||
std::string("gtsam::assembleGraphTree: factor type not handled: ") +
|
||||
demangle(typeid(fr).name()));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -377,8 +376,8 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
|
|||
// Build a map from keys to DiscreteKeys
|
||||
std::unordered_map<Key, DiscreteKey> mapFromKeyToDiscreteKey;
|
||||
for (auto &&factor : factors) {
|
||||
if (!factor->isContinuous()) {
|
||||
for (auto &k : factor->discreteKeys()) {
|
||||
if (auto p = boost::dynamic_pointer_cast<HybridFactor>(factor)) {
|
||||
for (auto &k : p->discreteKeys()) {
|
||||
mapFromKeyToDiscreteKey[k.first] = k;
|
||||
}
|
||||
}
|
||||
|
@ -451,12 +450,6 @@ void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) {
|
|||
/* ************************************************************************ */
|
||||
const Ordering HybridGaussianFactorGraph::getHybridOrdering() const {
|
||||
KeySet discrete_keys = discreteKeys();
|
||||
for (auto &factor : factors_) {
|
||||
for (const DiscreteKey &k : factor->discreteKeys()) {
|
||||
discrete_keys.insert(k.first);
|
||||
}
|
||||
}
|
||||
|
||||
const VariableIndex index(factors_);
|
||||
Ordering ordering = Ordering::ColamdConstrainedLast(
|
||||
index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
|
||||
|
@ -466,25 +459,23 @@ const Ordering HybridGaussianFactorGraph::getHybridOrdering() const {
|
|||
/* ************************************************************************ */
|
||||
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
|
||||
const VectorValues &continuousValues) const {
|
||||
using boost::dynamic_pointer_cast;
|
||||
|
||||
AlgebraicDecisionTree<Key> error_tree(0.0);
|
||||
|
||||
// Iterate over each factor.
|
||||
for (size_t idx = 0; idx < size(); idx++) {
|
||||
for (auto &f : factors_) {
|
||||
// TODO(dellaert): just use a virtual method defined in HybridFactor.
|
||||
AlgebraicDecisionTree<Key> factor_error;
|
||||
|
||||
if (factors_.at(idx)->isHybrid()) {
|
||||
// If factor is hybrid, select based on assignment.
|
||||
GaussianMixtureFactor::shared_ptr gaussianMixture =
|
||||
boost::static_pointer_cast<GaussianMixtureFactor>(factors_.at(idx));
|
||||
if (auto gaussianMixture = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
|
||||
// Compute factor error and add it.
|
||||
error_tree = error_tree + gaussianMixture->error(continuousValues);
|
||||
|
||||
} else if (factors_.at(idx)->isContinuous()) {
|
||||
} else if (auto hybridGaussianFactor =
|
||||
dynamic_pointer_cast<HybridGaussianFactor>(f)) {
|
||||
// If continuous only, get the (double) error
|
||||
// and add it to the error_tree
|
||||
auto hybridGaussianFactor =
|
||||
boost::static_pointer_cast<HybridGaussianFactor>(factors_.at(idx));
|
||||
GaussianFactor::shared_ptr gaussian = hybridGaussianFactor->inner();
|
||||
|
||||
// Compute the error of the gaussian factor.
|
||||
|
@ -493,9 +484,16 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
|
|||
error_tree = error_tree.apply(
|
||||
[error](double leaf_value) { return leaf_value + error; });
|
||||
|
||||
} else if (factors_.at(idx)->isDiscrete()) {
|
||||
} else if (dynamic_pointer_cast<DiscreteFactor>(f) ||
|
||||
dynamic_pointer_cast<HybridDiscreteFactor>(f)) {
|
||||
// If factor at `idx` is discrete-only, we skip.
|
||||
continue;
|
||||
} else {
|
||||
auto &fr = *f;
|
||||
throw std::invalid_argument(
|
||||
std::string(
|
||||
"HybridGaussianFactorGraph::error: factor type not handled: ") +
|
||||
demangle(typeid(fr).name()));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -506,7 +504,9 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
|
|||
double HybridGaussianFactorGraph::error(const HybridValues &values) const {
|
||||
double error = 0.0;
|
||||
for (auto &factor : factors_) {
|
||||
error += factor->error(values);
|
||||
if (auto p = boost::dynamic_pointer_cast<HybridFactor>(factor)) {
|
||||
error += p->error(values);
|
||||
}
|
||||
}
|
||||
return error;
|
||||
}
|
||||
|
|
|
@ -61,9 +61,11 @@ struct HybridConstructorTraversalData {
|
|||
parentData.junctionTreeNode->addChild(data.junctionTreeNode);
|
||||
|
||||
// Add all the discrete keys in the hybrid factors to the current data
|
||||
for (HybridFactor::shared_ptr& f : node->factors) {
|
||||
for (auto& k : f->discreteKeys()) {
|
||||
data.discreteKeys.insert(k.first);
|
||||
for (const auto& f : node->factors) {
|
||||
if (auto p = boost::dynamic_pointer_cast<HybridFactor>(f)) {
|
||||
for (auto& k : p->discreteKeys()) {
|
||||
data.discreteKeys.insert(k.first);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -50,47 +50,42 @@ void HybridNonlinearFactorGraph::print(const std::string& s,
|
|||
/* ************************************************************************* */
|
||||
HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize(
|
||||
const Values& continuousValues) const {
|
||||
using boost::dynamic_pointer_cast;
|
||||
|
||||
// create an empty linear FG
|
||||
auto linearFG = boost::make_shared<HybridGaussianFactorGraph>();
|
||||
|
||||
linearFG->reserve(size());
|
||||
|
||||
// linearize all hybrid factors
|
||||
for (auto&& factor : factors_) {
|
||||
for (auto& f : factors_) {
|
||||
// First check if it is a valid factor
|
||||
if (factor) {
|
||||
// Check if the factor is a hybrid factor.
|
||||
// It can be either a nonlinear MixtureFactor or a linear
|
||||
// GaussianMixtureFactor.
|
||||
if (factor->isHybrid()) {
|
||||
// Check if it is a nonlinear mixture factor
|
||||
if (auto nlmf = boost::dynamic_pointer_cast<MixtureFactor>(factor)) {
|
||||
linearFG->push_back(nlmf->linearize(continuousValues));
|
||||
} else {
|
||||
linearFG->push_back(factor);
|
||||
}
|
||||
|
||||
// Now check if the factor is a continuous only factor.
|
||||
} else if (factor->isContinuous()) {
|
||||
// In this case, we check if factor->inner() is nonlinear since
|
||||
// HybridFactors wrap over continuous factors.
|
||||
auto nlhf = boost::dynamic_pointer_cast<HybridNonlinearFactor>(factor);
|
||||
if (auto nlf =
|
||||
boost::dynamic_pointer_cast<NonlinearFactor>(nlhf->inner())) {
|
||||
auto hgf = boost::make_shared<HybridGaussianFactor>(
|
||||
nlf->linearize(continuousValues));
|
||||
linearFG->push_back(hgf);
|
||||
} else {
|
||||
linearFG->push_back(factor);
|
||||
}
|
||||
// Finally if nothing else, we are discrete-only which doesn't need
|
||||
// lineariztion.
|
||||
} else {
|
||||
linearFG->push_back(factor);
|
||||
}
|
||||
|
||||
} else {
|
||||
if (!f) {
|
||||
// TODO(dellaert): why?
|
||||
linearFG->push_back(GaussianFactor::shared_ptr());
|
||||
continue;
|
||||
}
|
||||
// Check if it is a nonlinear mixture factor
|
||||
if (auto nlmf = dynamic_pointer_cast<MixtureFactor>(f)) {
|
||||
const GaussianMixtureFactor::shared_ptr& gmf =
|
||||
nlmf->linearize(continuousValues);
|
||||
linearFG->push_back(gmf);
|
||||
} else if (auto nlhf = dynamic_pointer_cast<HybridNonlinearFactor>(f)) {
|
||||
// Nonlinear wrapper case:
|
||||
const GaussianFactor::shared_ptr& gf =
|
||||
nlhf->inner()->linearize(continuousValues);
|
||||
const auto hgf = boost::make_shared<HybridGaussianFactor>(gf);
|
||||
linearFG->push_back(hgf);
|
||||
} else if (dynamic_pointer_cast<DiscreteFactor>(f) ||
|
||||
dynamic_pointer_cast<HybridDiscreteFactor>(f)) {
|
||||
// If discrete-only: doesn't need linearization.
|
||||
linearFG->push_back(f);
|
||||
} else {
|
||||
auto& fr = *f;
|
||||
throw std::invalid_argument(
|
||||
std::string("HybridNonlinearFactorGraph::linearize: factor type "
|
||||
"not handled: ") +
|
||||
demangle(typeid(fr).name()));
|
||||
}
|
||||
}
|
||||
return linearFG;
|
||||
|
|
|
@ -47,11 +47,6 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
|
|||
using HasDerivedValueType = typename std::enable_if<
|
||||
std::is_base_of<HybridFactor, typename T::value_type>::value>::type;
|
||||
|
||||
/// Check if T has a pointer type derived from FactorType.
|
||||
template <typename T>
|
||||
using HasDerivedElementType = typename std::enable_if<std::is_base_of<
|
||||
HybridFactor, typename T::value_type::element_type>::value>::type;
|
||||
|
||||
public:
|
||||
using Base = HybridFactorGraph;
|
||||
using This = HybridNonlinearFactorGraph; ///< this class
|
||||
|
@ -124,7 +119,7 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
|
|||
* copied)
|
||||
*/
|
||||
template <typename CONTAINER>
|
||||
HasDerivedElementType<CONTAINER> push_back(const CONTAINER& container) {
|
||||
void push_back(const CONTAINER& container) {
|
||||
Base::push_back(container.begin(), container.end());
|
||||
}
|
||||
|
||||
|
|
|
@ -435,7 +435,7 @@ TEST(HybridFactorGraph, Full_Elimination) {
|
|||
|
||||
DiscreteFactorGraph discrete_fg;
|
||||
// TODO(Varun) Make this a function of HybridGaussianFactorGraph?
|
||||
for (HybridFactor::shared_ptr& factor : (*remainingFactorGraph_partial)) {
|
||||
for (auto& factor : (*remainingFactorGraph_partial)) {
|
||||
auto df = dynamic_pointer_cast<HybridDiscreteFactor>(factor);
|
||||
discrete_fg.push_back(df->inner());
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue