Change from pair to small struct

release/4.3a0
Frank Dellaert 2022-12-30 12:09:56 -05:00
parent 38a6154c55
commit b972be0b8f
4 changed files with 62 additions and 58 deletions

View File

@ -149,9 +149,10 @@ boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
const DiscreteKeys discreteParentKeys = discreteKeys(); const DiscreteKeys discreteParentKeys = discreteKeys();
const KeyVector continuousParentKeys = continuousParents(); const KeyVector continuousParentKeys = continuousParents();
const GaussianMixtureFactor::Factors likelihoods( const GaussianMixtureFactor::Factors likelihoods(
conditionals(), [&](const GaussianConditional::shared_ptr &conditional) { conditionals_, [&](const GaussianConditional::shared_ptr &conditional) {
return std::make_pair(conditional->likelihood(frontals), return GaussianMixtureFactor::FactorAndConstant{
0.5 * conditional->logDeterminant()); conditional->likelihood(frontals),
0.5 * conditional->logDeterminant()};
}); });
return boost::make_shared<GaussianMixtureFactor>( return boost::make_shared<GaussianMixtureFactor>(
continuousParentKeys, discreteParentKeys, likelihoods); continuousParentKeys, discreteParentKeys, likelihoods);

View File

@ -22,6 +22,8 @@
#include <gtsam/discrete/DecisionTree-inl.h> #include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/DecisionTree.h> #include <gtsam/discrete/DecisionTree.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h> #include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/linear/GaussianFactor.h>
#include <gtsam/linear/GaussianFactorGraph.h> #include <gtsam/linear/GaussianFactorGraph.h>
namespace gtsam { namespace gtsam {
@ -32,7 +34,7 @@ GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys,
const Mixture &factors) const Mixture &factors)
: Base(continuousKeys, discreteKeys), : Base(continuousKeys, discreteKeys),
factors_(factors, [](const GaussianFactor::shared_ptr &gf) { factors_(factors, [](const GaussianFactor::shared_ptr &gf) {
return std::make_pair(gf, 0.0); return FactorAndConstant{gf, 0.0};
}) {} }) {}
/* *******************************************************************************/ /* *******************************************************************************/
@ -46,10 +48,10 @@ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
// Check the base and the factors: // Check the base and the factors:
return Base::equals(*e, tol) && return Base::equals(*e, tol) &&
factors_.equals(e->factors_, factors_.equals(e->factors_, [tol](const FactorAndConstant &f1,
[tol](const GaussianMixtureFactor::FactorAndLogZ &f1, const FactorAndConstant &f2) {
const GaussianMixtureFactor::FactorAndLogZ &f2) { return f1.factor->equals(*(f2.factor), tol) &&
return f1.first->equals(*(f2.first), tol); std::abs(f1.constant - f2.constant) < tol;
}); });
} }
@ -63,8 +65,8 @@ void GaussianMixtureFactor::print(const std::string &s,
} else { } else {
factors_.print( factors_.print(
"", [&](Key k) { return formatter(k); }, "", [&](Key k) { return formatter(k); },
[&](const GaussianMixtureFactor::FactorAndLogZ &gf_z) -> std::string { [&](const FactorAndConstant &gf_z) -> std::string {
auto gf = gf_z.first; auto gf = gf_z.factor;
RedirectCout rd; RedirectCout rd;
std::cout << ":\n"; std::cout << ":\n";
if (gf && !gf->empty()) { if (gf && !gf->empty()) {
@ -79,10 +81,10 @@ void GaussianMixtureFactor::print(const std::string &s,
} }
/* *******************************************************************************/ /* *******************************************************************************/
const GaussianMixtureFactor::Mixture GaussianMixtureFactor::factors() { const GaussianMixtureFactor::Mixture GaussianMixtureFactor::factors() const {
// Unzip to tree of Gaussian factors and tree of log-constants, return Mixture(factors_, [](const FactorAndConstant &factor_z) {
// and return the first tree. return factor_z.factor;
return unzip(factors_).first; });
} }
/* *******************************************************************************/ /* *******************************************************************************/
@ -101,9 +103,9 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::add(
/* *******************************************************************************/ /* *******************************************************************************/
GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree() GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
const { const {
auto wrap = [](const GaussianMixtureFactor::FactorAndLogZ &factor_z) { auto wrap = [](const FactorAndConstant &factor_z) {
GaussianFactorGraph result; GaussianFactorGraph result;
result.push_back(factor_z.first); result.push_back(factor_z.factor);
return result; return result;
}; };
return {factors_, wrap}; return {factors_, wrap};
@ -113,26 +115,17 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
AlgebraicDecisionTree<Key> GaussianMixtureFactor::error( AlgebraicDecisionTree<Key> GaussianMixtureFactor::error(
const VectorValues &continuousValues) const { const VectorValues &continuousValues) const {
// functor to convert from sharedFactor to double error value. // functor to convert from sharedFactor to double error value.
auto errorFunc = auto errorFunc = [continuousValues](const FactorAndConstant &factor_z) {
[continuousValues](const GaussianMixtureFactor::FactorAndLogZ &factor_z) { return factor_z.error(continuousValues);
GaussianFactor::shared_ptr factor;
double log_z;
std::tie(factor, log_z) = factor_z;
return factor->error(continuousValues) + log_z;
}; };
DecisionTree<Key, double> errorTree(factors_, errorFunc); DecisionTree<Key, double> errorTree(factors_, errorFunc);
return errorTree; return errorTree;
} }
/* *******************************************************************************/ /* *******************************************************************************/
double GaussianMixtureFactor::error( double GaussianMixtureFactor::error(const HybridValues &values) const {
const VectorValues &continuousValues, const FactorAndConstant factor_z = factors_(values.discrete());
const DiscreteValues &discreteValues) const { return factor_z.factor->error(values.continuous()) + factor_z.constant;
// Directly index to get the conditional, no need to build the whole tree.
GaussianFactor::shared_ptr factor;
double log_z;
std::tie(factor, log_z) = factors_(discreteValues);
return factor->error(continuousValues) + log_z;
} }
} // namespace gtsam } // namespace gtsam

View File

@ -23,17 +23,15 @@
#include <gtsam/discrete/AlgebraicDecisionTree.h> #include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DecisionTree.h> #include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DiscreteKey.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteValues.h> #include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/linear/GaussianFactor.h> #include <gtsam/linear/GaussianFactor.h>
#include <gtsam/linear/VectorValues.h>
namespace gtsam { namespace gtsam {
class GaussianFactorGraph; class GaussianFactorGraph;
class HybridValues;
// Needed for wrapper. class DiscreteValues;
using GaussianFactorVector = std::vector<gtsam::GaussianFactor::shared_ptr>; class VectorValues;
/** /**
* @brief Implementation of a discrete conditional mixture factor. * @brief Implementation of a discrete conditional mixture factor.
@ -53,12 +51,27 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
using shared_ptr = boost::shared_ptr<This>; using shared_ptr = boost::shared_ptr<This>;
using Sum = DecisionTree<Key, GaussianFactorGraph>; using Sum = DecisionTree<Key, GaussianFactorGraph>;
using sharedFactor = boost::shared_ptr<GaussianFactor>;
/// typedef of pair of Gaussian factor and log of normalizing constant. /// Gaussian factor and log of normalizing constant.
using FactorAndLogZ = std::pair<GaussianFactor::shared_ptr, double>; struct FactorAndConstant {
/// typedef for Decision Tree of Gaussian Factors and log-constant. sharedFactor factor;
using Factors = DecisionTree<Key, FactorAndLogZ>; double constant;
using Mixture = DecisionTree<Key, GaussianFactor::shared_ptr>;
// Return error with constant added.
double error(const VectorValues &values) const {
return factor->error(values) + constant;
}
// Check pointer equality.
bool operator==(const FactorAndConstant &other) const {
return factor == other.factor && constant == other.constant;
}
};
/// typedef for Decision Tree of Gaussian factors and log-constant.
using Factors = DecisionTree<Key, FactorAndConstant>;
using Mixture = DecisionTree<Key, sharedFactor>;
private: private:
/// Decision tree of Gaussian factors indexed by discrete keys. /// Decision tree of Gaussian factors indexed by discrete keys.
@ -85,7 +98,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
* @param continuousKeys A vector of keys representing continuous variables. * @param continuousKeys A vector of keys representing continuous variables.
* @param discreteKeys A vector of keys representing discrete variables and * @param discreteKeys A vector of keys representing discrete variables and
* their cardinalities. * their cardinalities.
* @param factors The decision tree of Gaussian Factors stored as the mixture * @param factors The decision tree of Gaussian factors stored as the mixture
* density. * density.
*/ */
GaussianMixtureFactor(const KeyVector &continuousKeys, GaussianMixtureFactor(const KeyVector &continuousKeys,
@ -107,7 +120,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
*/ */
GaussianMixtureFactor(const KeyVector &continuousKeys, GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys, const DiscreteKeys &discreteKeys,
const std::vector<GaussianFactor::shared_ptr> &factors) const std::vector<sharedFactor> &factors)
: GaussianMixtureFactor(continuousKeys, discreteKeys, : GaussianMixtureFactor(continuousKeys, discreteKeys,
Mixture(discreteKeys, factors)) {} Mixture(discreteKeys, factors)) {}
@ -121,9 +134,11 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
const std::string &s = "GaussianMixtureFactor\n", const std::string &s = "GaussianMixtureFactor\n",
const KeyFormatter &formatter = DefaultKeyFormatter) const override; const KeyFormatter &formatter = DefaultKeyFormatter) const override;
/// @} /// @}
/// @name Standard API
/// @{
/// Getter for the underlying Gaussian Factor Decision Tree. /// Getter for the underlying Gaussian Factor Decision Tree.
const Mixture factors(); const Mixture factors() const;
/** /**
* @brief Combine the Gaussian Factor Graphs in `sum` and `this` while * @brief Combine the Gaussian Factor Graphs in `sum` and `this` while
@ -145,21 +160,17 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const; AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
/** /**
* @brief Compute the error of this Gaussian Mixture given the continuous * @brief Compute the log-likelihood, including the log-normalizing constant.
* values and a discrete assignment.
*
* @param continuousValues Continuous values at which to compute the error.
* @param discreteValues The discrete assignment for a specific mode sequence.
* @return double * @return double
*/ */
double error(const VectorValues &continuousValues, double error(const HybridValues &values) const;
const DiscreteValues &discreteValues) const;
/// Add MixtureFactor to a Sum, syntactic sugar. /// Add MixtureFactor to a Sum, syntactic sugar.
friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) { friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) {
sum = factor.add(sum); sum = factor.add(sum);
return sum; return sum;
} }
/// @}
}; };
// traits // traits

View File

@ -263,16 +263,15 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
if (keysOfSeparator.empty()) { if (keysOfSeparator.empty()) {
VectorValues empty_values; VectorValues empty_values;
auto factorProb = auto factorProb =
[&](const GaussianMixtureFactor::FactorAndLogZ &factor_z) { [&](const GaussianMixtureFactor::FactorAndConstant &factor_z) {
if (!factor_z.first) { GaussianFactor::shared_ptr factor = factor_z.factor;
if (!factor) {
return 0.0; // If nullptr, return 0.0 probability return 0.0; // If nullptr, return 0.0 probability
} else { } else {
GaussianFactor::shared_ptr factor = factor_z.first;
double log_z = factor_z.second;
// This is the probability q(μ) at the MLE point. // This is the probability q(μ) at the MLE point.
double error = double error =
0.5 * std::abs(factor->augmentedInformation().determinant()) + 0.5 * std::abs(factor->augmentedInformation().determinant()) +
log_z; factor_z.constant;
return std::exp(-error); return std::exp(-error);
} }
}; };