Change from pair to small struct

release/4.3a0
Frank Dellaert 2022-12-30 12:09:56 -05:00
parent 38a6154c55
commit b972be0b8f
4 changed files with 62 additions and 58 deletions

View File

@ -149,9 +149,10 @@ boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
const DiscreteKeys discreteParentKeys = discreteKeys();
const KeyVector continuousParentKeys = continuousParents();
const GaussianMixtureFactor::Factors likelihoods(
conditionals(), [&](const GaussianConditional::shared_ptr &conditional) {
return std::make_pair(conditional->likelihood(frontals),
0.5 * conditional->logDeterminant());
conditionals_, [&](const GaussianConditional::shared_ptr &conditional) {
return GaussianMixtureFactor::FactorAndConstant{
conditional->likelihood(frontals),
0.5 * conditional->logDeterminant()};
});
return boost::make_shared<GaussianMixtureFactor>(
continuousParentKeys, discreteParentKeys, likelihoods);

View File

@ -22,6 +22,8 @@
#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/linear/GaussianFactor.h>
#include <gtsam/linear/GaussianFactorGraph.h>
namespace gtsam {
@ -32,7 +34,7 @@ GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys,
const Mixture &factors)
: Base(continuousKeys, discreteKeys),
factors_(factors, [](const GaussianFactor::shared_ptr &gf) {
return std::make_pair(gf, 0.0);
return FactorAndConstant{gf, 0.0};
}) {}
/* *******************************************************************************/
@ -46,11 +48,11 @@ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
// Check the base and the factors:
return Base::equals(*e, tol) &&
factors_.equals(e->factors_,
[tol](const GaussianMixtureFactor::FactorAndLogZ &f1,
const GaussianMixtureFactor::FactorAndLogZ &f2) {
return f1.first->equals(*(f2.first), tol);
});
factors_.equals(e->factors_, [tol](const FactorAndConstant &f1,
const FactorAndConstant &f2) {
return f1.factor->equals(*(f2.factor), tol) &&
std::abs(f1.constant - f2.constant) < tol;
});
}
/* *******************************************************************************/
@ -63,8 +65,8 @@ void GaussianMixtureFactor::print(const std::string &s,
} else {
factors_.print(
"", [&](Key k) { return formatter(k); },
[&](const GaussianMixtureFactor::FactorAndLogZ &gf_z) -> std::string {
auto gf = gf_z.first;
[&](const FactorAndConstant &gf_z) -> std::string {
auto gf = gf_z.factor;
RedirectCout rd;
std::cout << ":\n";
if (gf && !gf->empty()) {
@ -79,10 +81,10 @@ void GaussianMixtureFactor::print(const std::string &s,
}
/* *******************************************************************************/
const GaussianMixtureFactor::Mixture GaussianMixtureFactor::factors() {
// Unzip to tree of Gaussian factors and tree of log-constants,
// and return the first tree.
return unzip(factors_).first;
const GaussianMixtureFactor::Mixture GaussianMixtureFactor::factors() const {
return Mixture(factors_, [](const FactorAndConstant &factor_z) {
return factor_z.factor;
});
}
/* *******************************************************************************/
@ -101,9 +103,9 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::add(
/* *******************************************************************************/
GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
const {
auto wrap = [](const GaussianMixtureFactor::FactorAndLogZ &factor_z) {
auto wrap = [](const FactorAndConstant &factor_z) {
GaussianFactorGraph result;
result.push_back(factor_z.first);
result.push_back(factor_z.factor);
return result;
};
return {factors_, wrap};
@ -113,26 +115,17 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
AlgebraicDecisionTree<Key> GaussianMixtureFactor::error(
const VectorValues &continuousValues) const {
// functor to convert from sharedFactor to double error value.
auto errorFunc =
[continuousValues](const GaussianMixtureFactor::FactorAndLogZ &factor_z) {
GaussianFactor::shared_ptr factor;
double log_z;
std::tie(factor, log_z) = factor_z;
return factor->error(continuousValues) + log_z;
};
auto errorFunc = [continuousValues](const FactorAndConstant &factor_z) {
return factor_z.error(continuousValues);
};
DecisionTree<Key, double> errorTree(factors_, errorFunc);
return errorTree;
}
/* *******************************************************************************/
double GaussianMixtureFactor::error(
const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
// Directly index to get the conditional, no need to build the whole tree.
GaussianFactor::shared_ptr factor;
double log_z;
std::tie(factor, log_z) = factors_(discreteValues);
return factor->error(continuousValues) + log_z;
double GaussianMixtureFactor::error(const HybridValues &values) const {
const FactorAndConstant factor_z = factors_(values.discrete());
return factor_z.factor->error(values.continuous()) + factor_z.constant;
}
} // namespace gtsam

View File

@ -23,17 +23,15 @@
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/linear/GaussianFactor.h>
#include <gtsam/linear/VectorValues.h>
namespace gtsam {
class GaussianFactorGraph;
// Needed for wrapper.
using GaussianFactorVector = std::vector<gtsam::GaussianFactor::shared_ptr>;
class HybridValues;
class DiscreteValues;
class VectorValues;
/**
* @brief Implementation of a discrete conditional mixture factor.
@ -53,12 +51,27 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
using shared_ptr = boost::shared_ptr<This>;
using Sum = DecisionTree<Key, GaussianFactorGraph>;
using sharedFactor = boost::shared_ptr<GaussianFactor>;
/// typedef of pair of Gaussian factor and log of normalizing constant.
using FactorAndLogZ = std::pair<GaussianFactor::shared_ptr, double>;
/// typedef for Decision Tree of Gaussian Factors and log-constant.
using Factors = DecisionTree<Key, FactorAndLogZ>;
using Mixture = DecisionTree<Key, GaussianFactor::shared_ptr>;
/// Gaussian factor and log of normalizing constant.
struct FactorAndConstant {
sharedFactor factor;
double constant;
// Return error with constant added.
double error(const VectorValues &values) const {
return factor->error(values) + constant;
}
// Check pointer equality.
bool operator==(const FactorAndConstant &other) const {
return factor == other.factor && constant == other.constant;
}
};
/// typedef for Decision Tree of Gaussian factors and log-constant.
using Factors = DecisionTree<Key, FactorAndConstant>;
using Mixture = DecisionTree<Key, sharedFactor>;
private:
/// Decision tree of Gaussian factors indexed by discrete keys.
@ -85,7 +98,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
* @param continuousKeys A vector of keys representing continuous variables.
* @param discreteKeys A vector of keys representing discrete variables and
* their cardinalities.
* @param factors The decision tree of Gaussian Factors stored as the mixture
* @param factors The decision tree of Gaussian factors stored as the mixture
* density.
*/
GaussianMixtureFactor(const KeyVector &continuousKeys,
@ -107,7 +120,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
*/
GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys,
const std::vector<GaussianFactor::shared_ptr> &factors)
const std::vector<sharedFactor> &factors)
: GaussianMixtureFactor(continuousKeys, discreteKeys,
Mixture(discreteKeys, factors)) {}
@ -121,9 +134,11 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
const std::string &s = "GaussianMixtureFactor\n",
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
/// @}
/// @name Standard API
/// @{
/// Getter for the underlying Gaussian Factor Decision Tree.
const Mixture factors();
const Mixture factors() const;
/**
* @brief Combine the Gaussian Factor Graphs in `sum` and `this` while
@ -145,21 +160,17 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
/**
* @brief Compute the error of this Gaussian Mixture given the continuous
* values and a discrete assignment.
*
* @param continuousValues Continuous values at which to compute the error.
* @param discreteValues The discrete assignment for a specific mode sequence.
* @brief Compute the log-likelihood, including the log-normalizing constant.
* @return double
*/
double error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const;
double error(const HybridValues &values) const;
/// Add MixtureFactor to a Sum, syntactic sugar.
friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) {
sum = factor.add(sum);
return sum;
}
/// @}
};
// traits

View File

@ -263,16 +263,15 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
if (keysOfSeparator.empty()) {
VectorValues empty_values;
auto factorProb =
[&](const GaussianMixtureFactor::FactorAndLogZ &factor_z) {
if (!factor_z.first) {
[&](const GaussianMixtureFactor::FactorAndConstant &factor_z) {
GaussianFactor::shared_ptr factor = factor_z.factor;
if (!factor) {
return 0.0; // If nullptr, return 0.0 probability
} else {
GaussianFactor::shared_ptr factor = factor_z.first;
double log_z = factor_z.second;
// This is the probability q(μ) at the MLE point.
double error =
0.5 * std::abs(factor->augmentedInformation().determinant()) +
log_z;
factor_z.constant;
return std::exp(-error);
}
};