Change from pair to small struct
parent
38a6154c55
commit
b972be0b8f
|
@ -149,9 +149,10 @@ boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
|
|||
const DiscreteKeys discreteParentKeys = discreteKeys();
|
||||
const KeyVector continuousParentKeys = continuousParents();
|
||||
const GaussianMixtureFactor::Factors likelihoods(
|
||||
conditionals(), [&](const GaussianConditional::shared_ptr &conditional) {
|
||||
return std::make_pair(conditional->likelihood(frontals),
|
||||
0.5 * conditional->logDeterminant());
|
||||
conditionals_, [&](const GaussianConditional::shared_ptr &conditional) {
|
||||
return GaussianMixtureFactor::FactorAndConstant{
|
||||
conditional->likelihood(frontals),
|
||||
0.5 * conditional->logDeterminant()};
|
||||
});
|
||||
return boost::make_shared<GaussianMixtureFactor>(
|
||||
continuousParentKeys, discreteParentKeys, likelihoods);
|
||||
|
|
|
@ -22,6 +22,8 @@
|
|||
#include <gtsam/discrete/DecisionTree-inl.h>
|
||||
#include <gtsam/discrete/DecisionTree.h>
|
||||
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
#include <gtsam/linear/GaussianFactor.h>
|
||||
#include <gtsam/linear/GaussianFactorGraph.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
@ -32,7 +34,7 @@ GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys,
|
|||
const Mixture &factors)
|
||||
: Base(continuousKeys, discreteKeys),
|
||||
factors_(factors, [](const GaussianFactor::shared_ptr &gf) {
|
||||
return std::make_pair(gf, 0.0);
|
||||
return FactorAndConstant{gf, 0.0};
|
||||
}) {}
|
||||
|
||||
/* *******************************************************************************/
|
||||
|
@ -46,11 +48,11 @@ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
|
|||
|
||||
// Check the base and the factors:
|
||||
return Base::equals(*e, tol) &&
|
||||
factors_.equals(e->factors_,
|
||||
[tol](const GaussianMixtureFactor::FactorAndLogZ &f1,
|
||||
const GaussianMixtureFactor::FactorAndLogZ &f2) {
|
||||
return f1.first->equals(*(f2.first), tol);
|
||||
});
|
||||
factors_.equals(e->factors_, [tol](const FactorAndConstant &f1,
|
||||
const FactorAndConstant &f2) {
|
||||
return f1.factor->equals(*(f2.factor), tol) &&
|
||||
std::abs(f1.constant - f2.constant) < tol;
|
||||
});
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
|
@ -63,8 +65,8 @@ void GaussianMixtureFactor::print(const std::string &s,
|
|||
} else {
|
||||
factors_.print(
|
||||
"", [&](Key k) { return formatter(k); },
|
||||
[&](const GaussianMixtureFactor::FactorAndLogZ &gf_z) -> std::string {
|
||||
auto gf = gf_z.first;
|
||||
[&](const FactorAndConstant &gf_z) -> std::string {
|
||||
auto gf = gf_z.factor;
|
||||
RedirectCout rd;
|
||||
std::cout << ":\n";
|
||||
if (gf && !gf->empty()) {
|
||||
|
@ -79,10 +81,10 @@ void GaussianMixtureFactor::print(const std::string &s,
|
|||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
const GaussianMixtureFactor::Mixture GaussianMixtureFactor::factors() {
|
||||
// Unzip to tree of Gaussian factors and tree of log-constants,
|
||||
// and return the first tree.
|
||||
return unzip(factors_).first;
|
||||
const GaussianMixtureFactor::Mixture GaussianMixtureFactor::factors() const {
|
||||
return Mixture(factors_, [](const FactorAndConstant &factor_z) {
|
||||
return factor_z.factor;
|
||||
});
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
|
@ -101,9 +103,9 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::add(
|
|||
/* *******************************************************************************/
|
||||
GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
|
||||
const {
|
||||
auto wrap = [](const GaussianMixtureFactor::FactorAndLogZ &factor_z) {
|
||||
auto wrap = [](const FactorAndConstant &factor_z) {
|
||||
GaussianFactorGraph result;
|
||||
result.push_back(factor_z.first);
|
||||
result.push_back(factor_z.factor);
|
||||
return result;
|
||||
};
|
||||
return {factors_, wrap};
|
||||
|
@ -113,26 +115,17 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
|
|||
AlgebraicDecisionTree<Key> GaussianMixtureFactor::error(
|
||||
const VectorValues &continuousValues) const {
|
||||
// functor to convert from sharedFactor to double error value.
|
||||
auto errorFunc =
|
||||
[continuousValues](const GaussianMixtureFactor::FactorAndLogZ &factor_z) {
|
||||
GaussianFactor::shared_ptr factor;
|
||||
double log_z;
|
||||
std::tie(factor, log_z) = factor_z;
|
||||
return factor->error(continuousValues) + log_z;
|
||||
};
|
||||
auto errorFunc = [continuousValues](const FactorAndConstant &factor_z) {
|
||||
return factor_z.error(continuousValues);
|
||||
};
|
||||
DecisionTree<Key, double> errorTree(factors_, errorFunc);
|
||||
return errorTree;
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
double GaussianMixtureFactor::error(
|
||||
const VectorValues &continuousValues,
|
||||
const DiscreteValues &discreteValues) const {
|
||||
// Directly index to get the conditional, no need to build the whole tree.
|
||||
GaussianFactor::shared_ptr factor;
|
||||
double log_z;
|
||||
std::tie(factor, log_z) = factors_(discreteValues);
|
||||
return factor->error(continuousValues) + log_z;
|
||||
double GaussianMixtureFactor::error(const HybridValues &values) const {
|
||||
const FactorAndConstant factor_z = factors_(values.discrete());
|
||||
return factor_z.factor->error(values.continuous()) + factor_z.constant;
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -23,17 +23,15 @@
|
|||
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
||||
#include <gtsam/discrete/DecisionTree.h>
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam/discrete/DiscreteValues.h>
|
||||
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
||||
#include <gtsam/hybrid/HybridFactor.h>
|
||||
#include <gtsam/linear/GaussianFactor.h>
|
||||
#include <gtsam/linear/VectorValues.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
class GaussianFactorGraph;
|
||||
|
||||
// Needed for wrapper.
|
||||
using GaussianFactorVector = std::vector<gtsam::GaussianFactor::shared_ptr>;
|
||||
class HybridValues;
|
||||
class DiscreteValues;
|
||||
class VectorValues;
|
||||
|
||||
/**
|
||||
* @brief Implementation of a discrete conditional mixture factor.
|
||||
|
@ -53,12 +51,27 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
|||
using shared_ptr = boost::shared_ptr<This>;
|
||||
|
||||
using Sum = DecisionTree<Key, GaussianFactorGraph>;
|
||||
using sharedFactor = boost::shared_ptr<GaussianFactor>;
|
||||
|
||||
/// typedef of pair of Gaussian factor and log of normalizing constant.
|
||||
using FactorAndLogZ = std::pair<GaussianFactor::shared_ptr, double>;
|
||||
/// typedef for Decision Tree of Gaussian Factors and log-constant.
|
||||
using Factors = DecisionTree<Key, FactorAndLogZ>;
|
||||
using Mixture = DecisionTree<Key, GaussianFactor::shared_ptr>;
|
||||
/// Gaussian factor and log of normalizing constant.
|
||||
struct FactorAndConstant {
|
||||
sharedFactor factor;
|
||||
double constant;
|
||||
|
||||
// Return error with constant added.
|
||||
double error(const VectorValues &values) const {
|
||||
return factor->error(values) + constant;
|
||||
}
|
||||
|
||||
// Check pointer equality.
|
||||
bool operator==(const FactorAndConstant &other) const {
|
||||
return factor == other.factor && constant == other.constant;
|
||||
}
|
||||
};
|
||||
|
||||
/// typedef for Decision Tree of Gaussian factors and log-constant.
|
||||
using Factors = DecisionTree<Key, FactorAndConstant>;
|
||||
using Mixture = DecisionTree<Key, sharedFactor>;
|
||||
|
||||
private:
|
||||
/// Decision tree of Gaussian factors indexed by discrete keys.
|
||||
|
@ -85,7 +98,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
|||
* @param continuousKeys A vector of keys representing continuous variables.
|
||||
* @param discreteKeys A vector of keys representing discrete variables and
|
||||
* their cardinalities.
|
||||
* @param factors The decision tree of Gaussian Factors stored as the mixture
|
||||
* @param factors The decision tree of Gaussian factors stored as the mixture
|
||||
* density.
|
||||
*/
|
||||
GaussianMixtureFactor(const KeyVector &continuousKeys,
|
||||
|
@ -107,7 +120,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
|||
*/
|
||||
GaussianMixtureFactor(const KeyVector &continuousKeys,
|
||||
const DiscreteKeys &discreteKeys,
|
||||
const std::vector<GaussianFactor::shared_ptr> &factors)
|
||||
const std::vector<sharedFactor> &factors)
|
||||
: GaussianMixtureFactor(continuousKeys, discreteKeys,
|
||||
Mixture(discreteKeys, factors)) {}
|
||||
|
||||
|
@ -121,9 +134,11 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
|||
const std::string &s = "GaussianMixtureFactor\n",
|
||||
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
|
||||
/// @}
|
||||
/// @name Standard API
|
||||
/// @{
|
||||
|
||||
/// Getter for the underlying Gaussian Factor Decision Tree.
|
||||
const Mixture factors();
|
||||
const Mixture factors() const;
|
||||
|
||||
/**
|
||||
* @brief Combine the Gaussian Factor Graphs in `sum` and `this` while
|
||||
|
@ -145,21 +160,17 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
|||
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
|
||||
|
||||
/**
|
||||
* @brief Compute the error of this Gaussian Mixture given the continuous
|
||||
* values and a discrete assignment.
|
||||
*
|
||||
* @param continuousValues Continuous values at which to compute the error.
|
||||
* @param discreteValues The discrete assignment for a specific mode sequence.
|
||||
* @brief Compute the log-likelihood, including the log-normalizing constant.
|
||||
* @return double
|
||||
*/
|
||||
double error(const VectorValues &continuousValues,
|
||||
const DiscreteValues &discreteValues) const;
|
||||
double error(const HybridValues &values) const;
|
||||
|
||||
/// Add MixtureFactor to a Sum, syntactic sugar.
|
||||
friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) {
|
||||
sum = factor.add(sum);
|
||||
return sum;
|
||||
}
|
||||
/// @}
|
||||
};
|
||||
|
||||
// traits
|
||||
|
|
|
@ -263,16 +263,15 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
|||
if (keysOfSeparator.empty()) {
|
||||
VectorValues empty_values;
|
||||
auto factorProb =
|
||||
[&](const GaussianMixtureFactor::FactorAndLogZ &factor_z) {
|
||||
if (!factor_z.first) {
|
||||
[&](const GaussianMixtureFactor::FactorAndConstant &factor_z) {
|
||||
GaussianFactor::shared_ptr factor = factor_z.factor;
|
||||
if (!factor) {
|
||||
return 0.0; // If nullptr, return 0.0 probability
|
||||
} else {
|
||||
GaussianFactor::shared_ptr factor = factor_z.first;
|
||||
double log_z = factor_z.second;
|
||||
// This is the probability q(μ) at the MLE point.
|
||||
double error =
|
||||
0.5 * std::abs(factor->augmentedInformation().determinant()) +
|
||||
log_z;
|
||||
factor_z.constant;
|
||||
return std::exp(-error);
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue