From b972be0b8f6f2a9f0f21fbfd42d89b57d93b5588 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 30 Dec 2022 12:09:56 -0500 Subject: [PATCH] Change from pair to small struct --- gtsam/hybrid/GaussianMixture.cpp | 7 +-- gtsam/hybrid/GaussianMixtureFactor.cpp | 51 +++++++++------------ gtsam/hybrid/GaussianMixtureFactor.h | 53 +++++++++++++--------- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 9 ++-- 4 files changed, 62 insertions(+), 58 deletions(-) diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index ddcfaf0e8..10521244f 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -149,9 +149,10 @@ boost::shared_ptr GaussianMixture::likelihood( const DiscreteKeys discreteParentKeys = discreteKeys(); const KeyVector continuousParentKeys = continuousParents(); const GaussianMixtureFactor::Factors likelihoods( - conditionals(), [&](const GaussianConditional::shared_ptr &conditional) { - return std::make_pair(conditional->likelihood(frontals), - 0.5 * conditional->logDeterminant()); + conditionals_, [&](const GaussianConditional::shared_ptr &conditional) { + return GaussianMixtureFactor::FactorAndConstant{ + conditional->likelihood(frontals), + 0.5 * conditional->logDeterminant()}; }); return boost::make_shared( continuousParentKeys, discreteParentKeys, likelihoods); diff --git a/gtsam/hybrid/GaussianMixtureFactor.cpp b/gtsam/hybrid/GaussianMixtureFactor.cpp index 0759cf3be..e07b300fa 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.cpp +++ b/gtsam/hybrid/GaussianMixtureFactor.cpp @@ -22,6 +22,8 @@ #include #include #include +#include +#include #include namespace gtsam { @@ -32,7 +34,7 @@ GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys, const Mixture &factors) : Base(continuousKeys, discreteKeys), factors_(factors, [](const GaussianFactor::shared_ptr &gf) { - return std::make_pair(gf, 0.0); + return FactorAndConstant{gf, 0.0}; }) {} /* *******************************************************************************/ @@ -46,11 +48,11 @@ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const { // Check the base and the factors: return Base::equals(*e, tol) && - factors_.equals(e->factors_, - [tol](const GaussianMixtureFactor::FactorAndLogZ &f1, - const GaussianMixtureFactor::FactorAndLogZ &f2) { - return f1.first->equals(*(f2.first), tol); - }); + factors_.equals(e->factors_, [tol](const FactorAndConstant &f1, + const FactorAndConstant &f2) { + return f1.factor->equals(*(f2.factor), tol) && + std::abs(f1.constant - f2.constant) < tol; + }); } /* *******************************************************************************/ @@ -63,8 +65,8 @@ void GaussianMixtureFactor::print(const std::string &s, } else { factors_.print( "", [&](Key k) { return formatter(k); }, - [&](const GaussianMixtureFactor::FactorAndLogZ &gf_z) -> std::string { - auto gf = gf_z.first; + [&](const FactorAndConstant &gf_z) -> std::string { + auto gf = gf_z.factor; RedirectCout rd; std::cout << ":\n"; if (gf && !gf->empty()) { @@ -79,10 +81,10 @@ void GaussianMixtureFactor::print(const std::string &s, } /* *******************************************************************************/ -const GaussianMixtureFactor::Mixture GaussianMixtureFactor::factors() { - // Unzip to tree of Gaussian factors and tree of log-constants, - // and return the first tree. - return unzip(factors_).first; +const GaussianMixtureFactor::Mixture GaussianMixtureFactor::factors() const { + return Mixture(factors_, [](const FactorAndConstant &factor_z) { + return factor_z.factor; + }); } /* *******************************************************************************/ @@ -101,9 +103,9 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::add( /* *******************************************************************************/ GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree() const { - auto wrap = [](const GaussianMixtureFactor::FactorAndLogZ &factor_z) { + auto wrap = [](const FactorAndConstant &factor_z) { GaussianFactorGraph result; - result.push_back(factor_z.first); + result.push_back(factor_z.factor); return result; }; return {factors_, wrap}; @@ -113,26 +115,17 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree() AlgebraicDecisionTree GaussianMixtureFactor::error( const VectorValues &continuousValues) const { // functor to convert from sharedFactor to double error value. - auto errorFunc = - [continuousValues](const GaussianMixtureFactor::FactorAndLogZ &factor_z) { - GaussianFactor::shared_ptr factor; - double log_z; - std::tie(factor, log_z) = factor_z; - return factor->error(continuousValues) + log_z; - }; + auto errorFunc = [continuousValues](const FactorAndConstant &factor_z) { + return factor_z.error(continuousValues); + }; DecisionTree errorTree(factors_, errorFunc); return errorTree; } /* *******************************************************************************/ -double GaussianMixtureFactor::error( - const VectorValues &continuousValues, - const DiscreteValues &discreteValues) const { - // Directly index to get the conditional, no need to build the whole tree. - GaussianFactor::shared_ptr factor; - double log_z; - std::tie(factor, log_z) = factors_(discreteValues); - return factor->error(continuousValues) + log_z; +double GaussianMixtureFactor::error(const HybridValues &values) const { + const FactorAndConstant factor_z = factors_(values.discrete()); + return factor_z.factor->error(values.continuous()) + factor_z.constant; } } // namespace gtsam diff --git a/gtsam/hybrid/GaussianMixtureFactor.h b/gtsam/hybrid/GaussianMixtureFactor.h index b3e603bc3..aca4f365b 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.h +++ b/gtsam/hybrid/GaussianMixtureFactor.h @@ -23,17 +23,15 @@ #include #include #include -#include -#include +#include #include -#include namespace gtsam { class GaussianFactorGraph; - -// Needed for wrapper. -using GaussianFactorVector = std::vector; +class HybridValues; +class DiscreteValues; +class VectorValues; /** * @brief Implementation of a discrete conditional mixture factor. @@ -53,12 +51,27 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { using shared_ptr = boost::shared_ptr; using Sum = DecisionTree; + using sharedFactor = boost::shared_ptr; - /// typedef of pair of Gaussian factor and log of normalizing constant. - using FactorAndLogZ = std::pair; - /// typedef for Decision Tree of Gaussian Factors and log-constant. - using Factors = DecisionTree; - using Mixture = DecisionTree; + /// Gaussian factor and log of normalizing constant. + struct FactorAndConstant { + sharedFactor factor; + double constant; + + // Return error with constant added. + double error(const VectorValues &values) const { + return factor->error(values) + constant; + } + + // Check pointer equality. + bool operator==(const FactorAndConstant &other) const { + return factor == other.factor && constant == other.constant; + } + }; + + /// typedef for Decision Tree of Gaussian factors and log-constant. + using Factors = DecisionTree; + using Mixture = DecisionTree; private: /// Decision tree of Gaussian factors indexed by discrete keys. @@ -85,7 +98,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { * @param continuousKeys A vector of keys representing continuous variables. * @param discreteKeys A vector of keys representing discrete variables and * their cardinalities. - * @param factors The decision tree of Gaussian Factors stored as the mixture + * @param factors The decision tree of Gaussian factors stored as the mixture * density. */ GaussianMixtureFactor(const KeyVector &continuousKeys, @@ -107,7 +120,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { */ GaussianMixtureFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, - const std::vector &factors) + const std::vector &factors) : GaussianMixtureFactor(continuousKeys, discreteKeys, Mixture(discreteKeys, factors)) {} @@ -121,9 +134,11 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { const std::string &s = "GaussianMixtureFactor\n", const KeyFormatter &formatter = DefaultKeyFormatter) const override; /// @} + /// @name Standard API + /// @{ /// Getter for the underlying Gaussian Factor Decision Tree. - const Mixture factors(); + const Mixture factors() const; /** * @brief Combine the Gaussian Factor Graphs in `sum` and `this` while @@ -145,21 +160,17 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { AlgebraicDecisionTree error(const VectorValues &continuousValues) const; /** - * @brief Compute the error of this Gaussian Mixture given the continuous - * values and a discrete assignment. - * - * @param continuousValues Continuous values at which to compute the error. - * @param discreteValues The discrete assignment for a specific mode sequence. + * @brief Compute the log-likelihood, including the log-normalizing constant. * @return double */ - double error(const VectorValues &continuousValues, - const DiscreteValues &discreteValues) const; + double error(const HybridValues &values) const; /// Add MixtureFactor to a Sum, syntactic sugar. friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) { sum = factor.add(sum); return sum; } + /// @} }; // traits diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 15a84b27a..5c1c2daf3 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -263,16 +263,15 @@ hybridElimination(const HybridGaussianFactorGraph &factors, if (keysOfSeparator.empty()) { VectorValues empty_values; auto factorProb = - [&](const GaussianMixtureFactor::FactorAndLogZ &factor_z) { - if (!factor_z.first) { + [&](const GaussianMixtureFactor::FactorAndConstant &factor_z) { + GaussianFactor::shared_ptr factor = factor_z.factor; + if (!factor) { return 0.0; // If nullptr, return 0.0 probability } else { - GaussianFactor::shared_ptr factor = factor_z.first; - double log_z = factor_z.second; // This is the probability q(μ) at the MLE point. double error = 0.5 * std::abs(factor->augmentedInformation().determinant()) + - log_z; + factor_z.constant; return std::exp(-error); } };