Add MixtureFactor for nonlinear factor types
parent
2927d92a52
commit
78ea90bb27
|
|
@ -25,14 +25,9 @@ namespace gtsam {
|
||||||
HybridNonlinearFactor::HybridNonlinearFactor(NonlinearFactor::shared_ptr other)
|
HybridNonlinearFactor::HybridNonlinearFactor(NonlinearFactor::shared_ptr other)
|
||||||
: Base(other->keys()), inner_(other) {}
|
: Base(other->keys()), inner_(other) {}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
HybridNonlinearFactor::HybridNonlinearFactor(NonlinearFactor &&nf)
|
|
||||||
: Base(nf.keys()),
|
|
||||||
inner_(boost::make_shared<NonlinearFactor>(std::move(nf))) {}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
bool HybridNonlinearFactor::equals(const HybridFactor &lf, double tol) const {
|
bool HybridNonlinearFactor::equals(const HybridFactor &lf, double tol) const {
|
||||||
return Base(lf, tol);
|
return Base::equals(lf, tol);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
||||||
|
|
@ -38,9 +38,6 @@ class HybridNonlinearFactor : public HybridFactor {
|
||||||
// Explicit conversion from a shared ptr of GF
|
// Explicit conversion from a shared ptr of GF
|
||||||
explicit HybridNonlinearFactor(NonlinearFactor::shared_ptr other);
|
explicit HybridNonlinearFactor(NonlinearFactor::shared_ptr other);
|
||||||
|
|
||||||
// Forwarding constructor from concrete NonlinearFactor
|
|
||||||
explicit HybridNonlinearFactor(NonlinearFactor &&jf);
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
/// @{
|
/// @{
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,240 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
* Copyright 2020 The Ambitious Folks of the MRG
|
||||||
|
* See LICENSE for the license information
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @file MixtureFactor.h
|
||||||
|
* @brief Nonlinear Mixture factor of continuous and discrete.
|
||||||
|
* @author Kevin Doherty, kdoherty@mit.edu
|
||||||
|
* @author Varun Agrawal
|
||||||
|
* @date December 2021
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteValues.h>
|
||||||
|
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
||||||
|
#include <gtsam/hybrid/HybridFactor.h>
|
||||||
|
#include <gtsam/nonlinear/NonlinearFactor.h>
|
||||||
|
#include <gtsam/nonlinear/Symbol.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <boost/format.hpp>
|
||||||
|
#include <cmath>
|
||||||
|
#include <limits>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace gtsam {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Implementation of a discrete conditional mixture factor. Implements a
|
||||||
|
* joint discrete-continuous factor where the discrete variable serves to
|
||||||
|
* "select" a mixture component corresponding to a NonlinearFactor type
|
||||||
|
* of measurement.
|
||||||
|
*/
|
||||||
|
template <class NonlinearFactorType>
|
||||||
|
class MixtureFactor : public HybridFactor {
|
||||||
|
public:
|
||||||
|
using Base = HybridFactor;
|
||||||
|
using This = MixtureFactor;
|
||||||
|
using shared_ptr = boost::shared_ptr<MixtureFactor>;
|
||||||
|
using sharedFactor = boost::shared_ptr<NonlinearFactorType>;
|
||||||
|
|
||||||
|
/// typedef for DecisionTree which has Keys as node labels and
|
||||||
|
/// NonlinearFactorType as leaf nodes.
|
||||||
|
using Factors = DecisionTree<Key, sharedFactor>;
|
||||||
|
|
||||||
|
private:
|
||||||
|
/// Decision tree of Gaussian factors indexed by discrete keys.
|
||||||
|
Factors factors_;
|
||||||
|
bool normalized_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
MixtureFactor() = default;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Construct from Decision tree.
|
||||||
|
*
|
||||||
|
* @param keys Vector of keys for continuous factors.
|
||||||
|
* @param discreteKeys Vector of discrete keys.
|
||||||
|
* @param factors Decision tree with of shared factors.
|
||||||
|
* @param normalized Flag indicating if the factor error is already
|
||||||
|
* normalized.
|
||||||
|
*/
|
||||||
|
MixtureFactor(const KeyVector& keys, const DiscreteKeys& discreteKeys,
|
||||||
|
const Factors& factors, bool normalized = false)
|
||||||
|
: Base(keys, discreteKeys), factors_(factors), normalized_(normalized) {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Convenience constructor that generates the underlying factor
|
||||||
|
* decision tree for us.
|
||||||
|
*
|
||||||
|
* Here it is important that the vector of factors has the correct number of
|
||||||
|
* elements based on the number of discrete keys and the cardinality of the
|
||||||
|
* keys, so that the decision tree is constructed appropriately.
|
||||||
|
*
|
||||||
|
* @param keys Vector of keys for continuous factors.
|
||||||
|
* @param discreteKeys Vector of discrete keys.
|
||||||
|
* @param factors Vector of shared pointers to factors.
|
||||||
|
* @param normalized Flag indicating if the factor error is already
|
||||||
|
* normalized.
|
||||||
|
*/
|
||||||
|
MixtureFactor(const KeyVector& keys, const DiscreteKeys& discreteKeys,
|
||||||
|
const std::vector<sharedFactor>& factors,
|
||||||
|
bool normalized = false)
|
||||||
|
: MixtureFactor(keys, discreteKeys, Factors(discreteKeys, factors),
|
||||||
|
normalized) {}
|
||||||
|
|
||||||
|
~MixtureFactor() = default;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Compute error of factor given both continuous and discrete values.
|
||||||
|
*
|
||||||
|
* @param continuousVals The continuous Values.
|
||||||
|
* @param discreteVals The discrete Values.
|
||||||
|
* @return double The error of this factor.
|
||||||
|
*/
|
||||||
|
double error(const Values& continuousVals,
|
||||||
|
const DiscreteValues& discreteVals) const {
|
||||||
|
// Retrieve the factor corresponding to the assignment in discreteVals.
|
||||||
|
auto factor = factors_(discreteVals);
|
||||||
|
// Compute the error for the selected factor
|
||||||
|
const double factorError = factor->error(continuousVals);
|
||||||
|
if (normalized_) return factorError;
|
||||||
|
return factorError +
|
||||||
|
this->nonlinearFactorLogNormalizingConstant(*factor, continuousVals);
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t dim() const {
|
||||||
|
// TODO(Varun)
|
||||||
|
throw std::runtime_error("MixtureFactor::dim not implemented.");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Testable
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/// print to stdout
|
||||||
|
void print(
|
||||||
|
const std::string& s = "MixtureFactor",
|
||||||
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override {
|
||||||
|
std::cout << (s.empty() ? "" : s + " ");
|
||||||
|
std::cout << "(";
|
||||||
|
auto contKeys = keys();
|
||||||
|
auto dKeys = discreteKeys();
|
||||||
|
for (DiscreteKey key : dKeys) {
|
||||||
|
auto it = std::find(contKeys.begin(), contKeys.end(), key.first);
|
||||||
|
contKeys.erase(it);
|
||||||
|
}
|
||||||
|
for (Key key : contKeys) {
|
||||||
|
std::cout << " " << keyFormatter(key);
|
||||||
|
}
|
||||||
|
std::cout << ";";
|
||||||
|
for (DiscreteKey key : dKeys) {
|
||||||
|
std::cout << " " << keyFormatter(key.first);
|
||||||
|
}
|
||||||
|
std::cout << " ) \n";
|
||||||
|
auto valueFormatter = [](const sharedFactor& v) {
|
||||||
|
if (v) {
|
||||||
|
return (boost::format("Nonlinear factor on %d keys") % v->size()).str();
|
||||||
|
} else {
|
||||||
|
return std::string("nullptr");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
factors_.print("", keyFormatter, valueFormatter);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check equality
|
||||||
|
bool equals(const HybridFactor& other, double tol = 1e-9) const override {
|
||||||
|
// We attempt a dynamic cast from HybridFactor to MixtureFactor. If it
|
||||||
|
// fails, return false.
|
||||||
|
if (!dynamic_cast<const MixtureFactor*>(&other)) return false;
|
||||||
|
|
||||||
|
// If the cast is successful, we'll properly construct a MixtureFactor
|
||||||
|
// object from `other`
|
||||||
|
const MixtureFactor& f(static_cast<const MixtureFactor&>(other));
|
||||||
|
|
||||||
|
// Ensure that this MixtureFactor and `f` have the same `factors_`.
|
||||||
|
auto compare = [tol](const sharedFactor& a, const sharedFactor& b) {
|
||||||
|
return traits<NonlinearFactorType>::Equals(*a, *b, tol);
|
||||||
|
};
|
||||||
|
if (!factors_.equals(f.factors_, compare)) return false;
|
||||||
|
|
||||||
|
// If everything above passes, and the keys_, discreteKeys_ and normalized_
|
||||||
|
// member variables are identical, return true.
|
||||||
|
return (std::equal(keys_.begin(), keys_.end(), f.keys().begin()) &&
|
||||||
|
(discreteKeys_ == f.discreteKeys_) &&
|
||||||
|
(normalized_ == f.normalized_));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
|
||||||
|
/// Linearize specific nonlinear factors based on the assignment in
|
||||||
|
/// discreteValues.
|
||||||
|
GaussianFactor::shared_ptr linearize(
|
||||||
|
const Values& continuousVals, const DiscreteValues& discreteVals) const {
|
||||||
|
auto factor = factors_(discreteVals);
|
||||||
|
return factor->linearize(continuousVals);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Linearize all the continuous factors to get a GaussianMixtureFactor.
|
||||||
|
boost::shared_ptr<GaussianMixtureFactor> linearize(
|
||||||
|
const Values& continuousVals) const {
|
||||||
|
// functional to linearize each factor in the decision tree
|
||||||
|
auto linearizeDT = [continuousVals](const sharedFactor& factor) {
|
||||||
|
return factor->linearize(continuousVals);
|
||||||
|
};
|
||||||
|
|
||||||
|
DecisionTree<Key, GaussianFactor::shared_ptr> linearized_factors(
|
||||||
|
factors_, linearizeDT);
|
||||||
|
|
||||||
|
return boost::make_shared<GaussianMixtureFactor>(keys_, discreteKeys_,
|
||||||
|
linearized_factors);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* If the component factors are not already normalized, we want to compute
|
||||||
|
* their normalizing constants so that the resulting joint distribution is
|
||||||
|
* appropriately computed. Remember, this is the _negative_ normalizing
|
||||||
|
* constant for the measurement likelihood (since we are minimizing the
|
||||||
|
* _negative_ log-likelihood).
|
||||||
|
*/
|
||||||
|
double nonlinearFactorLogNormalizingConstant(
|
||||||
|
const NonlinearFactorType& factor, const Values& values) const {
|
||||||
|
// Information matrix (inverse covariance matrix) for the factor.
|
||||||
|
Matrix infoMat;
|
||||||
|
|
||||||
|
// NOTE: This is sloppy (and mallocs!), is there a cleaner way?
|
||||||
|
auto factorPtr = boost::make_shared<NonlinearFactorType>(factor);
|
||||||
|
|
||||||
|
// If this is a NoiseModelFactor, we'll use its noiseModel to
|
||||||
|
// otherwise noiseModelFactor will be nullptr
|
||||||
|
auto noiseModelFactor =
|
||||||
|
boost::dynamic_pointer_cast<NoiseModelFactor>(factorPtr);
|
||||||
|
if (noiseModelFactor) {
|
||||||
|
// If dynamic cast to NoiseModelFactor succeeded, see if the noise model
|
||||||
|
// is Gaussian
|
||||||
|
auto noiseModel = noiseModelFactor->noiseModel();
|
||||||
|
|
||||||
|
auto gaussianNoiseModel =
|
||||||
|
boost::dynamic_pointer_cast<noiseModel::Gaussian>(noiseModel);
|
||||||
|
if (gaussianNoiseModel) {
|
||||||
|
// If the noise model is Gaussian, retrieve the information matrix
|
||||||
|
infoMat = gaussianNoiseModel->information();
|
||||||
|
} else {
|
||||||
|
// If the factor is not a Gaussian factor, we'll linearize it to get
|
||||||
|
// something with a normalized noise model
|
||||||
|
// TODO(kevin): does this make sense to do? I think maybe not in
|
||||||
|
// general? Should we just yell at the user?
|
||||||
|
auto gaussianFactor = factor.linearize(values);
|
||||||
|
infoMat = gaussianFactor->information();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the (negative) log of the normalizing constant
|
||||||
|
return -(factor.dim() * log(2.0 * M_PI) / 2.0) -
|
||||||
|
(log(infoMat.determinant()) / 2.0);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace gtsam
|
||||||
Loading…
Reference in New Issue