Merge pull request #1358 from borglab/hybrid/gaussian-mixture-factor
						commit
						d0821a57de
					
				|  | @ -22,6 +22,7 @@ | |||
| #include <gtsam/discrete/DiscreteValues.h> | ||||
| #include <gtsam/hybrid/GaussianMixture.h> | ||||
| #include <gtsam/hybrid/GaussianMixtureFactor.h> | ||||
| #include <gtsam/hybrid/HybridValues.h> | ||||
| #include <gtsam/inference/Conditional-inst.h> | ||||
| #include <gtsam/linear/GaussianFactorGraph.h> | ||||
| 
 | ||||
|  | @ -149,17 +150,19 @@ boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood( | |||
|   const DiscreteKeys discreteParentKeys = discreteKeys(); | ||||
|   const KeyVector continuousParentKeys = continuousParents(); | ||||
|   const GaussianMixtureFactor::Factors likelihoods( | ||||
|       conditionals(), [&](const GaussianConditional::shared_ptr &conditional) { | ||||
|         return conditional->likelihood(frontals); | ||||
|       conditionals_, [&](const GaussianConditional::shared_ptr &conditional) { | ||||
|         return GaussianMixtureFactor::FactorAndConstant{ | ||||
|             conditional->likelihood(frontals), | ||||
|             conditional->logNormalizationConstant()}; | ||||
|       }); | ||||
|   return boost::make_shared<GaussianMixtureFactor>( | ||||
|       continuousParentKeys, discreteParentKeys, likelihoods); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) { | ||||
| std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) { | ||||
|   std::set<DiscreteKey> s; | ||||
|   s.insert(dkeys.begin(), dkeys.end()); | ||||
|   s.insert(discreteKeys.begin(), discreteKeys.end()); | ||||
|   return s; | ||||
| } | ||||
| 
 | ||||
|  | @ -184,7 +187,7 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) { | |||
|                     const GaussianConditional::shared_ptr &conditional) | ||||
|       -> GaussianConditional::shared_ptr { | ||||
|     // typecast so we can use this to get probability value
 | ||||
|     DiscreteValues values(choices); | ||||
|     const DiscreteValues values(choices); | ||||
| 
 | ||||
|     // Case where the gaussian mixture has the same
 | ||||
|     // discrete keys as the decision tree.
 | ||||
|  | @ -254,11 +257,10 @@ AlgebraicDecisionTree<Key> GaussianMixture::error( | |||
| } | ||||
| 
 | ||||
| /* *******************************************************************************/ | ||||
| double GaussianMixture::error(const VectorValues &continuousValues, | ||||
|                               const DiscreteValues &discreteValues) const { | ||||
| double GaussianMixture::error(const HybridValues &values) const { | ||||
|   // Directly index to get the conditional, no need to build the whole tree.
 | ||||
|   auto conditional = conditionals_(discreteValues); | ||||
|   return conditional->error(continuousValues); | ||||
|   auto conditional = conditionals_(values.discrete()); | ||||
|   return conditional->error(values.continuous()); | ||||
| } | ||||
| 
 | ||||
| }  // namespace gtsam
 | ||||
|  |  | |||
|  | @ -30,6 +30,7 @@ | |||
| namespace gtsam { | ||||
| 
 | ||||
| class GaussianMixtureFactor; | ||||
| class HybridValues; | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief A conditional of gaussian mixtures indexed by discrete variables, as | ||||
|  | @ -87,7 +88,7 @@ class GTSAM_EXPORT GaussianMixture | |||
|   /// @name Constructors
 | ||||
|   /// @{
 | ||||
| 
 | ||||
|   /// Defaut constructor, mainly for serialization.
 | ||||
|   /// Default constructor, mainly for serialization.
 | ||||
|   GaussianMixture() = default; | ||||
| 
 | ||||
|   /**
 | ||||
|  | @ -135,6 +136,7 @@ class GTSAM_EXPORT GaussianMixture | |||
|   /// @name Standard API
 | ||||
|   /// @{
 | ||||
| 
 | ||||
|   /// @brief Return the conditional Gaussian for the given discrete assignment.
 | ||||
|   GaussianConditional::shared_ptr operator()( | ||||
|       const DiscreteValues &discreteValues) const; | ||||
| 
 | ||||
|  | @ -165,12 +167,10 @@ class GTSAM_EXPORT GaussianMixture | |||
|    * @brief Compute the error of this Gaussian Mixture given the continuous | ||||
|    * values and a discrete assignment. | ||||
|    * | ||||
|    * @param continuousValues Continuous values at which to compute the error. | ||||
|    * @param discreteValues The discrete assignment for a specific mode sequence. | ||||
|    * @param values Continuous values and discrete assignment. | ||||
|    * @return double | ||||
|    */ | ||||
|   double error(const VectorValues &continuousValues, | ||||
|                const DiscreteValues &discreteValues) const; | ||||
|   double error(const HybridValues &values) const override; | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Prune the decision tree of Gaussian factors as per the discrete | ||||
|  | @ -193,7 +193,7 @@ class GTSAM_EXPORT GaussianMixture | |||
| }; | ||||
| 
 | ||||
| /// Return the DiscreteKey vector as a set.
 | ||||
| std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys); | ||||
| std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys); | ||||
| 
 | ||||
| // traits
 | ||||
| template <> | ||||
|  |  | |||
|  | @ -22,6 +22,8 @@ | |||
| #include <gtsam/discrete/DecisionTree-inl.h> | ||||
| #include <gtsam/discrete/DecisionTree.h> | ||||
| #include <gtsam/hybrid/GaussianMixtureFactor.h> | ||||
| #include <gtsam/hybrid/HybridValues.h> | ||||
| #include <gtsam/linear/GaussianFactor.h> | ||||
| #include <gtsam/linear/GaussianFactorGraph.h> | ||||
| 
 | ||||
| namespace gtsam { | ||||
|  | @ -29,8 +31,11 @@ namespace gtsam { | |||
| /* *******************************************************************************/ | ||||
| GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys, | ||||
|                                              const DiscreteKeys &discreteKeys, | ||||
|                                              const Factors &factors) | ||||
|     : Base(continuousKeys, discreteKeys), factors_(factors) {} | ||||
|                                              const Mixture &factors) | ||||
|     : Base(continuousKeys, discreteKeys), | ||||
|       factors_(factors, [](const GaussianFactor::shared_ptr &gf) { | ||||
|         return FactorAndConstant{gf, 0.0}; | ||||
|       }) {} | ||||
| 
 | ||||
| /* *******************************************************************************/ | ||||
| bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const { | ||||
|  | @ -43,10 +48,10 @@ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const { | |||
| 
 | ||||
|   // Check the base and the factors:
 | ||||
|   return Base::equals(*e, tol) && | ||||
|          factors_.equals(e->factors_, | ||||
|                          [tol](const GaussianFactor::shared_ptr &f1, | ||||
|                                const GaussianFactor::shared_ptr &f2) { | ||||
|                            return f1->equals(*f2, tol); | ||||
|          factors_.equals(e->factors_, [tol](const FactorAndConstant &f1, | ||||
|                                             const FactorAndConstant &f2) { | ||||
|            return f1.factor->equals(*(f2.factor), tol) && | ||||
|                   std::abs(f1.constant - f2.constant) < tol; | ||||
|          }); | ||||
| } | ||||
| 
 | ||||
|  | @ -60,7 +65,8 @@ void GaussianMixtureFactor::print(const std::string &s, | |||
|   } else { | ||||
|     factors_.print( | ||||
|         "", [&](Key k) { return formatter(k); }, | ||||
|         [&](const GaussianFactor::shared_ptr &gf) -> std::string { | ||||
|         [&](const FactorAndConstant &gf_z) -> std::string { | ||||
|           auto gf = gf_z.factor; | ||||
|           RedirectCout rd; | ||||
|           std::cout << ":\n"; | ||||
|           if (gf && !gf->empty()) { | ||||
|  | @ -75,8 +81,10 @@ void GaussianMixtureFactor::print(const std::string &s, | |||
| } | ||||
| 
 | ||||
| /* *******************************************************************************/ | ||||
| const GaussianMixtureFactor::Factors &GaussianMixtureFactor::factors() { | ||||
|   return factors_; | ||||
| const GaussianMixtureFactor::Mixture GaussianMixtureFactor::factors() const { | ||||
|   return Mixture(factors_, [](const FactorAndConstant &factor_z) { | ||||
|     return factor_z.factor; | ||||
|   }); | ||||
| } | ||||
| 
 | ||||
| /* *******************************************************************************/ | ||||
|  | @ -95,9 +103,9 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::add( | |||
| /* *******************************************************************************/ | ||||
| GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree() | ||||
|     const { | ||||
|   auto wrap = [](const GaussianFactor::shared_ptr &factor) { | ||||
|   auto wrap = [](const FactorAndConstant &factor_z) { | ||||
|     GaussianFactorGraph result; | ||||
|     result.push_back(factor); | ||||
|     result.push_back(factor_z.factor); | ||||
|     return result; | ||||
|   }; | ||||
|   return {factors_, wrap}; | ||||
|  | @ -107,21 +115,18 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree() | |||
| AlgebraicDecisionTree<Key> GaussianMixtureFactor::error( | ||||
|     const VectorValues &continuousValues) const { | ||||
|   // functor to convert from sharedFactor to double error value.
 | ||||
|   auto errorFunc = | ||||
|       [continuousValues](const GaussianFactor::shared_ptr &factor) { | ||||
|         return factor->error(continuousValues); | ||||
|   auto errorFunc = [continuousValues](const FactorAndConstant &factor_z) { | ||||
|     return factor_z.error(continuousValues); | ||||
|   }; | ||||
|   DecisionTree<Key, double> errorTree(factors_, errorFunc); | ||||
|   return errorTree; | ||||
| } | ||||
| 
 | ||||
| /* *******************************************************************************/ | ||||
| double GaussianMixtureFactor::error( | ||||
|     const VectorValues &continuousValues, | ||||
|     const DiscreteValues &discreteValues) const { | ||||
|   // Directly index to get the conditional, no need to build the whole tree.
 | ||||
|   auto factor = factors_(discreteValues); | ||||
|   return factor->error(continuousValues); | ||||
| double GaussianMixtureFactor::error(const HybridValues &values) const { | ||||
|   const FactorAndConstant factor_z = factors_(values.discrete()); | ||||
|   return factor_z.error(values.continuous()); | ||||
| } | ||||
| /* *******************************************************************************/ | ||||
| 
 | ||||
| }  // namespace gtsam
 | ||||
|  |  | |||
|  | @ -23,17 +23,15 @@ | |||
| #include <gtsam/discrete/AlgebraicDecisionTree.h> | ||||
| #include <gtsam/discrete/DecisionTree.h> | ||||
| #include <gtsam/discrete/DiscreteKey.h> | ||||
| #include <gtsam/discrete/DiscreteValues.h> | ||||
| #include <gtsam/hybrid/HybridGaussianFactor.h> | ||||
| #include <gtsam/hybrid/HybridFactor.h> | ||||
| #include <gtsam/linear/GaussianFactor.h> | ||||
| #include <gtsam/linear/VectorValues.h> | ||||
| 
 | ||||
| namespace gtsam { | ||||
| 
 | ||||
| class GaussianFactorGraph; | ||||
| 
 | ||||
| // Needed for wrapper.
 | ||||
| using GaussianFactorVector = std::vector<gtsam::GaussianFactor::shared_ptr>; | ||||
| class HybridValues; | ||||
| class DiscreteValues; | ||||
| class VectorValues; | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief Implementation of a discrete conditional mixture factor. | ||||
|  | @ -53,9 +51,29 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { | |||
|   using shared_ptr = boost::shared_ptr<This>; | ||||
| 
 | ||||
|   using Sum = DecisionTree<Key, GaussianFactorGraph>; | ||||
|   using sharedFactor = boost::shared_ptr<GaussianFactor>; | ||||
| 
 | ||||
|   /// typedef for Decision Tree of Gaussian Factors
 | ||||
|   using Factors = DecisionTree<Key, GaussianFactor::shared_ptr>; | ||||
|   /// Gaussian factor and log of normalizing constant.
 | ||||
|   struct FactorAndConstant { | ||||
|     sharedFactor factor; | ||||
|     double constant; | ||||
| 
 | ||||
|     // Return error with constant correction.
 | ||||
|     double error(const VectorValues &values) const { | ||||
|       // Note minus sign: constant is log of normalization constant for probabilities.
 | ||||
|       // Errors is the negative log-likelihood, hence we subtract the constant here.
 | ||||
|       return factor->error(values) - constant; | ||||
|     } | ||||
| 
 | ||||
|     // Check pointer equality.
 | ||||
|     bool operator==(const FactorAndConstant &other) const { | ||||
|       return factor == other.factor && constant == other.constant; | ||||
|     } | ||||
|   }; | ||||
| 
 | ||||
|   /// typedef for Decision Tree of Gaussian factors and log-constant.
 | ||||
|   using Factors = DecisionTree<Key, FactorAndConstant>; | ||||
|   using Mixture = DecisionTree<Key, sharedFactor>; | ||||
| 
 | ||||
|  private: | ||||
|   /// Decision tree of Gaussian factors indexed by discrete keys.
 | ||||
|  | @ -82,12 +100,17 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { | |||
|    * @param continuousKeys A vector of keys representing continuous variables. | ||||
|    * @param discreteKeys A vector of keys representing discrete variables and | ||||
|    * their cardinalities. | ||||
|    * @param factors The decision tree of Gaussian Factors stored as the mixture | ||||
|    * @param factors The decision tree of Gaussian factors stored as the mixture | ||||
|    * density. | ||||
|    */ | ||||
|   GaussianMixtureFactor(const KeyVector &continuousKeys, | ||||
|                         const DiscreteKeys &discreteKeys, | ||||
|                         const Factors &factors); | ||||
|                         const Mixture &factors); | ||||
| 
 | ||||
|   GaussianMixtureFactor(const KeyVector &continuousKeys, | ||||
|                         const DiscreteKeys &discreteKeys, | ||||
|                         const Factors &factors_and_z) | ||||
|       : Base(continuousKeys, discreteKeys), factors_(factors_and_z) {} | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Construct a new GaussianMixtureFactor object using a vector of | ||||
|  | @ -99,9 +122,9 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { | |||
|    */ | ||||
|   GaussianMixtureFactor(const KeyVector &continuousKeys, | ||||
|                         const DiscreteKeys &discreteKeys, | ||||
|                         const std::vector<GaussianFactor::shared_ptr> &factors) | ||||
|                         const std::vector<sharedFactor> &factors) | ||||
|       : GaussianMixtureFactor(continuousKeys, discreteKeys, | ||||
|                               Factors(discreteKeys, factors)) {} | ||||
|                               Mixture(discreteKeys, factors)) {} | ||||
| 
 | ||||
|   /// @}
 | ||||
|   /// @name Testable
 | ||||
|  | @ -113,9 +136,11 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { | |||
|       const std::string &s = "GaussianMixtureFactor\n", | ||||
|       const KeyFormatter &formatter = DefaultKeyFormatter) const override; | ||||
|   /// @}
 | ||||
|   /// @name Standard API
 | ||||
|   /// @{
 | ||||
| 
 | ||||
|   /// Getter for the underlying Gaussian Factor Decision Tree.
 | ||||
|   const Factors &factors(); | ||||
|   const Mixture factors() const; | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Combine the Gaussian Factor Graphs in `sum` and `this` while | ||||
|  | @ -137,21 +162,17 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { | |||
|   AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const; | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Compute the error of this Gaussian Mixture given the continuous | ||||
|    * values and a discrete assignment. | ||||
|    * | ||||
|    * @param continuousValues Continuous values at which to compute the error. | ||||
|    * @param discreteValues The discrete assignment for a specific mode sequence. | ||||
|    * @brief Compute the log-likelihood, including the log-normalizing constant. | ||||
|    * @return double | ||||
|    */ | ||||
|   double error(const VectorValues &continuousValues, | ||||
|                const DiscreteValues &discreteValues) const; | ||||
|   double error(const HybridValues &values) const override; | ||||
| 
 | ||||
|   /// Add MixtureFactor to a Sum, syntactic sugar.
 | ||||
|   friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) { | ||||
|     sum = factor.add(sum); | ||||
|     return sum; | ||||
|   } | ||||
|   /// @}
 | ||||
| }; | ||||
| 
 | ||||
| // traits
 | ||||
|  |  | |||
|  | @ -1,5 +1,5 @@ | |||
| /* ----------------------------------------------------------------------------
 | ||||
|  * GTSAM Copyright 2010, Georgia Tech Research Corporation, | ||||
|  * GTSAM Copyright 2010-2022, Georgia Tech Research Corporation, | ||||
|  * Atlanta, Georgia 30332-0415 | ||||
|  * All Rights Reserved | ||||
|  * Authors: Frank Dellaert, et al. (see THANKS for the full author list) | ||||
|  | @ -12,6 +12,7 @@ | |||
|  * @author Fan Jiang | ||||
|  * @author Varun Agrawal | ||||
|  * @author Shangjie Xue | ||||
|  * @author Frank Dellaert | ||||
|  * @date   January 2022 | ||||
|  */ | ||||
| 
 | ||||
|  | @ -321,10 +322,9 @@ HybridValues HybridBayesNet::sample() const { | |||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| double HybridBayesNet::error(const VectorValues &continuousValues, | ||||
|                              const DiscreteValues &discreteValues) const { | ||||
|   GaussianBayesNet gbn = choose(discreteValues); | ||||
|   return gbn.error(continuousValues); | ||||
| double HybridBayesNet::error(const HybridValues &values) const { | ||||
|   GaussianBayesNet gbn = choose(values.discrete()); | ||||
|   return gbn.error(values.continuous()); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
|  |  | |||
|  | @ -206,12 +206,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> { | |||
|    * @brief 0.5 * sum of squared Mahalanobis distances | ||||
|    * for a specific discrete assignment. | ||||
|    * | ||||
|    * @param continuousValues Continuous values at which to compute the error. | ||||
|    * @param discreteValues Discrete assignment for a specific mode sequence. | ||||
|    * @param values Continuous values and discrete assignment. | ||||
|    * @return double | ||||
|    */ | ||||
|   double error(const VectorValues &continuousValues, | ||||
|                const DiscreteValues &discreteValues) const; | ||||
|   double error(const HybridValues &values) const; | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Compute conditional error for each discrete assignment, | ||||
|  |  | |||
|  | @ -52,7 +52,7 @@ namespace gtsam { | |||
|  * having diamond inheritances, and neutralized the need to change other | ||||
|  * components of GTSAM to make hybrid elimination work. | ||||
|  * | ||||
|  * A great reference to the type-erasure pattern is Eduaado Madrid's CppCon | ||||
|  * A great reference to the type-erasure pattern is Eduardo Madrid's CppCon | ||||
|  * talk (https://www.youtube.com/watch?v=s082Qmd_nHs).
 | ||||
|  * | ||||
|  * @ingroup hybrid | ||||
|  | @ -129,33 +129,6 @@ class GTSAM_EXPORT HybridConditional | |||
|    */ | ||||
|   HybridConditional(boost::shared_ptr<GaussianMixture> gaussianMixture); | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Return HybridConditional as a GaussianMixture | ||||
|    * @return nullptr if not a mixture | ||||
|    * @return GaussianMixture::shared_ptr otherwise | ||||
|    */ | ||||
|   GaussianMixture::shared_ptr asMixture() { | ||||
|     return boost::dynamic_pointer_cast<GaussianMixture>(inner_); | ||||
|   } | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Return HybridConditional as a GaussianConditional | ||||
|    * @return nullptr if not a GaussianConditional | ||||
|    * @return GaussianConditional::shared_ptr otherwise | ||||
|    */ | ||||
|   GaussianConditional::shared_ptr asGaussian() { | ||||
|     return boost::dynamic_pointer_cast<GaussianConditional>(inner_); | ||||
|   } | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Return conditional as a DiscreteConditional | ||||
|    * @return nullptr if not a DiscreteConditional | ||||
|    * @return DiscreteConditional::shared_ptr | ||||
|    */ | ||||
|   DiscreteConditional::shared_ptr asDiscrete() { | ||||
|     return boost::dynamic_pointer_cast<DiscreteConditional>(inner_); | ||||
|   } | ||||
| 
 | ||||
|   /// @}
 | ||||
|   /// @name Testable
 | ||||
|   /// @{
 | ||||
|  | @ -169,10 +142,52 @@ class GTSAM_EXPORT HybridConditional | |||
|   bool equals(const HybridFactor& other, double tol = 1e-9) const override; | ||||
| 
 | ||||
|   /// @}
 | ||||
|   /// @name Standard Interface
 | ||||
|   /// @{
 | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Return HybridConditional as a GaussianMixture | ||||
|    * @return nullptr if not a mixture | ||||
|    * @return GaussianMixture::shared_ptr otherwise | ||||
|    */ | ||||
|   GaussianMixture::shared_ptr asMixture() const { | ||||
|     return boost::dynamic_pointer_cast<GaussianMixture>(inner_); | ||||
|   } | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Return HybridConditional as a GaussianConditional | ||||
|    * @return nullptr if not a GaussianConditional | ||||
|    * @return GaussianConditional::shared_ptr otherwise | ||||
|    */ | ||||
|   GaussianConditional::shared_ptr asGaussian() const { | ||||
|     return boost::dynamic_pointer_cast<GaussianConditional>(inner_); | ||||
|   } | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Return conditional as a DiscreteConditional | ||||
|    * @return nullptr if not a DiscreteConditional | ||||
|    * @return DiscreteConditional::shared_ptr | ||||
|    */ | ||||
|   DiscreteConditional::shared_ptr asDiscrete() const { | ||||
|     return boost::dynamic_pointer_cast<DiscreteConditional>(inner_); | ||||
|   } | ||||
| 
 | ||||
|   /// Get the type-erased pointer to the inner type
 | ||||
|   boost::shared_ptr<Factor> inner() { return inner_; } | ||||
| 
 | ||||
|   /// Return the error of the underlying conditional.
 | ||||
|   /// Currently only implemented for Gaussian mixture.
 | ||||
|   double error(const HybridValues& values) const override { | ||||
|     if (auto gm = asMixture()) { | ||||
|       return gm->error(values); | ||||
|     } else { | ||||
|       throw std::runtime_error( | ||||
|           "HybridConditional::error: only implemented for Gaussian mixture"); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   /// @}
 | ||||
| 
 | ||||
|  private: | ||||
|   /** Serialization function */ | ||||
|   friend class boost::serialization::access; | ||||
|  |  | |||
|  | @ -17,6 +17,7 @@ | |||
|  */ | ||||
| 
 | ||||
| #include <gtsam/hybrid/HybridDiscreteFactor.h> | ||||
| #include <gtsam/hybrid/HybridValues.h> | ||||
| 
 | ||||
| #include <boost/make_shared.hpp> | ||||
| 
 | ||||
|  | @ -50,4 +51,10 @@ void HybridDiscreteFactor::print(const std::string &s, | |||
|   inner_->print("\n", formatter); | ||||
| }; | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| double HybridDiscreteFactor::error(const HybridValues &values) const { | ||||
|   return -log((*inner_)(values.discrete())); | ||||
| } | ||||
| /* ************************************************************************ */ | ||||
| 
 | ||||
| }  // namespace gtsam
 | ||||
|  |  | |||
|  | @ -24,10 +24,12 @@ | |||
| 
 | ||||
| namespace gtsam { | ||||
| 
 | ||||
| class HybridValues; | ||||
| 
 | ||||
| /**
 | ||||
|  * A HybridDiscreteFactor is a thin container for DiscreteFactor, which allows | ||||
|  * us to hide the implementation of DiscreteFactor and thus avoid diamond | ||||
|  * inheritance. | ||||
|  * A HybridDiscreteFactor is a thin container for DiscreteFactor, which | ||||
|  * allows us to hide the implementation of DiscreteFactor and thus avoid | ||||
|  * diamond inheritance. | ||||
|  * | ||||
|  * @ingroup hybrid | ||||
|  */ | ||||
|  | @ -59,9 +61,15 @@ class GTSAM_EXPORT HybridDiscreteFactor : public HybridFactor { | |||
|       const KeyFormatter &formatter = DefaultKeyFormatter) const override; | ||||
| 
 | ||||
|   /// @}
 | ||||
|   /// @name Standard Interface
 | ||||
|   /// @{
 | ||||
| 
 | ||||
|   /// Return pointer to the internal discrete factor
 | ||||
|   DiscreteFactor::shared_ptr inner() const { return inner_; } | ||||
| 
 | ||||
|   /// Return the error of the underlying Discrete Factor.
 | ||||
|   double error(const HybridValues &values) const override; | ||||
|   /// @}
 | ||||
| }; | ||||
| 
 | ||||
| // traits
 | ||||
|  |  | |||
|  | @ -26,6 +26,8 @@ | |||
| #include <string> | ||||
| namespace gtsam { | ||||
| 
 | ||||
| class HybridValues; | ||||
| 
 | ||||
| KeyVector CollectKeys(const KeyVector &continuousKeys, | ||||
|                       const DiscreteKeys &discreteKeys); | ||||
| KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2); | ||||
|  | @ -110,6 +112,15 @@ class GTSAM_EXPORT HybridFactor : public Factor { | |||
|   /// @name Standard Interface
 | ||||
|   /// @{
 | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Compute the error of this Gaussian Mixture given the continuous | ||||
|    * values and a discrete assignment. | ||||
|    * | ||||
|    * @param values Continuous values and discrete assignment. | ||||
|    * @return double | ||||
|    */ | ||||
|   virtual double error(const HybridValues &values) const = 0; | ||||
| 
 | ||||
|   /// True if this is a factor of discrete variables only.
 | ||||
|   bool isDiscrete() const { return isDiscrete_; } | ||||
| 
 | ||||
|  |  | |||
|  | @ -16,6 +16,7 @@ | |||
|  */ | ||||
| 
 | ||||
| #include <gtsam/hybrid/HybridGaussianFactor.h> | ||||
| #include <gtsam/hybrid/HybridValues.h> | ||||
| #include <gtsam/linear/HessianFactor.h> | ||||
| #include <gtsam/linear/JacobianFactor.h> | ||||
| 
 | ||||
|  | @ -54,4 +55,10 @@ void HybridGaussianFactor::print(const std::string &s, | |||
|   inner_->print("\n", formatter); | ||||
| }; | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| double HybridGaussianFactor::error(const HybridValues &values) const { | ||||
|   return inner_->error(values.continuous()); | ||||
| } | ||||
| /* ************************************************************************ */ | ||||
| 
 | ||||
| }  // namespace gtsam
 | ||||
|  |  | |||
|  | @ -25,6 +25,7 @@ namespace gtsam { | |||
| // Forward declarations
 | ||||
| class JacobianFactor; | ||||
| class HessianFactor; | ||||
| class HybridValues; | ||||
| 
 | ||||
| /**
 | ||||
|  * A HybridGaussianFactor is a layer over GaussianFactor so that we do not have | ||||
|  | @ -92,8 +93,15 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { | |||
|       const KeyFormatter &formatter = DefaultKeyFormatter) const override; | ||||
| 
 | ||||
|   /// @}
 | ||||
|   /// @name Standard Interface
 | ||||
|   /// @{
 | ||||
| 
 | ||||
|   /// Return pointer to the internal discrete factor
 | ||||
|   GaussianFactor::shared_ptr inner() const { return inner_; } | ||||
| 
 | ||||
|   /// Return the error of the underlying Discrete Factor.
 | ||||
|   double error(const HybridValues &values) const override; | ||||
|   /// @}
 | ||||
| }; | ||||
| 
 | ||||
| // traits
 | ||||
|  |  | |||
|  | @ -55,13 +55,14 @@ | |||
| 
 | ||||
| namespace gtsam { | ||||
| 
 | ||||
| /// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
 | ||||
| template class EliminateableFactorGraph<HybridGaussianFactorGraph>; | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| static GaussianMixtureFactor::Sum &addGaussian( | ||||
|     GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) { | ||||
|   using Y = GaussianFactorGraph; | ||||
|   // If the decision tree is not intiialized, then intialize it.
 | ||||
|   // If the decision tree is not initialized, then initialize it.
 | ||||
|   if (sum.empty()) { | ||||
|     GaussianFactorGraph result; | ||||
|     result.push_back(factor); | ||||
|  | @ -89,8 +90,9 @@ GaussianMixtureFactor::Sum sumFrontals( | |||
| 
 | ||||
|   for (auto &f : factors) { | ||||
|     if (f->isHybrid()) { | ||||
|       if (auto cgmf = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f)) { | ||||
|         sum = cgmf->add(sum); | ||||
|       // TODO(dellaert): just use a virtual method defined in HybridFactor.
 | ||||
|       if (auto gm = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f)) { | ||||
|         sum = gm->add(sum); | ||||
|       } | ||||
|       if (auto gm = boost::dynamic_pointer_cast<HybridConditional>(f)) { | ||||
|         sum = gm->asMixture()->add(sum); | ||||
|  | @ -184,7 +186,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors, | |||
|                   const KeySet &continuousSeparator, | ||||
|                   const std::set<DiscreteKey> &discreteSeparatorSet) { | ||||
|   // NOTE: since we use the special JunctionTree,
 | ||||
|   // only possiblity is continuous conditioned on discrete.
 | ||||
|   // only possibility is continuous conditioned on discrete.
 | ||||
|   DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(), | ||||
|                                  discreteSeparatorSet.end()); | ||||
| 
 | ||||
|  | @ -204,16 +206,16 @@ hybridElimination(const HybridGaussianFactorGraph &factors, | |||
|   }; | ||||
|   sum = GaussianMixtureFactor::Sum(sum, emptyGaussian); | ||||
| 
 | ||||
|   using EliminationPair = GaussianFactorGraph::EliminationResult; | ||||
|   using EliminationPair = std::pair<boost::shared_ptr<GaussianConditional>, | ||||
|                                     GaussianMixtureFactor::FactorAndConstant>; | ||||
| 
 | ||||
|   KeyVector keysOfEliminated;  // Not the ordering
 | ||||
|   KeyVector keysOfSeparator;   // TODO(frank): Is this just (keys - ordering)?
 | ||||
| 
 | ||||
|   // This is the elimination method on the leaf nodes
 | ||||
|   auto eliminate = [&](const GaussianFactorGraph &graph) | ||||
|       -> GaussianFactorGraph::EliminationResult { | ||||
|   auto eliminate = [&](const GaussianFactorGraph &graph) -> EliminationPair { | ||||
|     if (graph.empty()) { | ||||
|       return {nullptr, nullptr}; | ||||
|       return {nullptr, {nullptr, 0.0}}; | ||||
|     } | ||||
| 
 | ||||
| #ifdef HYBRID_TIMING | ||||
|  | @ -222,18 +224,18 @@ hybridElimination(const HybridGaussianFactorGraph &factors, | |||
| 
 | ||||
|     std::pair<boost::shared_ptr<GaussianConditional>, | ||||
|               boost::shared_ptr<GaussianFactor>> | ||||
|         result = EliminatePreferCholesky(graph, frontalKeys); | ||||
|         conditional_factor = EliminatePreferCholesky(graph, frontalKeys); | ||||
| 
 | ||||
|     // Initialize the keysOfEliminated to be the keys of the
 | ||||
|     // eliminated GaussianConditional
 | ||||
|     keysOfEliminated = result.first->keys(); | ||||
|     keysOfSeparator = result.second->keys(); | ||||
|     keysOfEliminated = conditional_factor.first->keys(); | ||||
|     keysOfSeparator = conditional_factor.second->keys(); | ||||
| 
 | ||||
| #ifdef HYBRID_TIMING | ||||
|     gttoc_(hybrid_eliminate); | ||||
| #endif | ||||
| 
 | ||||
|     return result; | ||||
|     return {conditional_factor.first, {conditional_factor.second, 0.0}}; | ||||
|   }; | ||||
| 
 | ||||
|   // Perform elimination!
 | ||||
|  | @ -246,8 +248,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors, | |||
| 
 | ||||
|   // Separate out decision tree into conditionals and remaining factors.
 | ||||
|   auto pair = unzip(eliminationResults); | ||||
| 
 | ||||
|   const GaussianMixtureFactor::Factors &separatorFactors = pair.second; | ||||
|   const auto &separatorFactors = pair.second; | ||||
| 
 | ||||
|   // Create the GaussianMixture from the conditionals
 | ||||
|   auto conditional = boost::make_shared<GaussianMixture>( | ||||
|  | @ -257,13 +258,16 @@ hybridElimination(const HybridGaussianFactorGraph &factors, | |||
|   // DiscreteFactor, with the error for each discrete choice.
 | ||||
|   if (keysOfSeparator.empty()) { | ||||
|     VectorValues empty_values; | ||||
|     auto factorProb = [&](const GaussianFactor::shared_ptr &factor) { | ||||
|     auto factorProb = | ||||
|         [&](const GaussianMixtureFactor::FactorAndConstant &factor_z) { | ||||
|           GaussianFactor::shared_ptr factor = factor_z.factor; | ||||
|           if (!factor) { | ||||
|             return 0.0;  // If nullptr, return 0.0 probability
 | ||||
|           } else { | ||||
|             // This is the probability q(μ) at the MLE point.
 | ||||
|             double error = | ||||
|             0.5 * std::abs(factor->augmentedInformation().determinant()); | ||||
|                 0.5 * std::abs(factor->augmentedInformation().determinant()) + | ||||
|                 factor_z.constant; | ||||
|             return std::exp(-error); | ||||
|           } | ||||
|         }; | ||||
|  | @ -452,6 +456,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error( | |||
| 
 | ||||
|   // Iterate over each factor.
 | ||||
|   for (size_t idx = 0; idx < size(); idx++) { | ||||
|     // TODO(dellaert): just use a virtual method defined in HybridFactor.
 | ||||
|     AlgebraicDecisionTree<Key> factor_error; | ||||
| 
 | ||||
|     if (factors_.at(idx)->isHybrid()) { | ||||
|  | @ -491,38 +496,17 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error( | |||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| double HybridGaussianFactorGraph::error( | ||||
|     const VectorValues &continuousValues, | ||||
|     const DiscreteValues &discreteValues) const { | ||||
| double HybridGaussianFactorGraph::error(const HybridValues &values) const { | ||||
|   double error = 0.0; | ||||
|   for (size_t idx = 0; idx < size(); idx++) { | ||||
|     auto factor = factors_.at(idx); | ||||
| 
 | ||||
|     if (factor->isHybrid()) { | ||||
|       if (auto c = boost::dynamic_pointer_cast<HybridConditional>(factor)) { | ||||
|         error += c->asMixture()->error(continuousValues, discreteValues); | ||||
|       } | ||||
|       if (auto f = boost::dynamic_pointer_cast<GaussianMixtureFactor>(factor)) { | ||||
|         error += f->error(continuousValues, discreteValues); | ||||
|       } | ||||
| 
 | ||||
|     } else if (factor->isContinuous()) { | ||||
|       if (auto f = boost::dynamic_pointer_cast<HybridGaussianFactor>(factor)) { | ||||
|         error += f->inner()->error(continuousValues); | ||||
|       } | ||||
|       if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(factor)) { | ||||
|         error += cg->asGaussian()->error(continuousValues); | ||||
|       } | ||||
|     } | ||||
|   for (auto &factor : factors_) { | ||||
|     error += factor->error(values); | ||||
|   } | ||||
|   return error; | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| double HybridGaussianFactorGraph::probPrime( | ||||
|     const VectorValues &continuousValues, | ||||
|     const DiscreteValues &discreteValues) const { | ||||
|   double error = this->error(continuousValues, discreteValues); | ||||
| double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const { | ||||
|   double error = this->error(values); | ||||
|   // NOTE: The 0.5 term is handled by each factor
 | ||||
|   return std::exp(-error); | ||||
| } | ||||
|  |  | |||
|  | @ -12,7 +12,7 @@ | |||
| /**
 | ||||
|  * @file   HybridGaussianFactorGraph.h | ||||
|  * @brief  Linearized Hybrid factor graph that uses type erasure | ||||
|  * @author Fan Jiang, Varun Agrawal | ||||
|  * @author Fan Jiang, Varun Agrawal, Frank Dellaert | ||||
|  * @date   Mar 11, 2022 | ||||
|  */ | ||||
| 
 | ||||
|  | @ -38,6 +38,7 @@ class HybridBayesTree; | |||
| class HybridJunctionTree; | ||||
| class DecisionTreeFactor; | ||||
| class JacobianFactor; | ||||
| class HybridValues; | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief Main elimination function for HybridGaussianFactorGraph. | ||||
|  | @ -186,14 +187,9 @@ class GTSAM_EXPORT HybridGaussianFactorGraph | |||
|    * @brief Compute error given a continuous vector values | ||||
|    * and a discrete assignment. | ||||
|    * | ||||
|    * @param continuousValues The continuous VectorValues | ||||
|    * for computing the error. | ||||
|    * @param discreteValues The specific discrete assignment | ||||
|    * whose error we wish to compute. | ||||
|    * @return double | ||||
|    */ | ||||
|   double error(const VectorValues& continuousValues, | ||||
|                const DiscreteValues& discreteValues) const; | ||||
|   double error(const HybridValues& values) const; | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Compute unnormalized probability \f$ P(X | M, Z) \f$ | ||||
|  | @ -210,13 +206,9 @@ class GTSAM_EXPORT HybridGaussianFactorGraph | |||
|    * @brief Compute the unnormalized posterior probability for a continuous | ||||
|    * vector values given a specific assignment. | ||||
|    * | ||||
|    * @param continuousValues The vector values for which to compute the | ||||
|    * posterior probability. | ||||
|    * @param discreteValues The specific assignment to use for the computation. | ||||
|    * @return double | ||||
|    */ | ||||
|   double probPrime(const VectorValues& continuousValues, | ||||
|                    const DiscreteValues& discreteValues) const; | ||||
|   double probPrime(const HybridValues& values) const; | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Return a Colamd constrained ordering where the discrete keys are | ||||
|  |  | |||
|  | @ -51,12 +51,22 @@ class HybridNonlinearFactor : public HybridFactor { | |||
|       const KeyFormatter &formatter = DefaultKeyFormatter) const override; | ||||
| 
 | ||||
|   /// @}
 | ||||
|   /// @name Standard Interface
 | ||||
|   /// @{
 | ||||
| 
 | ||||
|   NonlinearFactor::shared_ptr inner() const { return inner_; } | ||||
| 
 | ||||
|   /// Error for HybridValues is not provided for nonlinear factor.
 | ||||
|   double error(const HybridValues &values) const override { | ||||
|     throw std::runtime_error( | ||||
|         "HybridNonlinearFactor::error(HybridValues) not implemented."); | ||||
|   } | ||||
| 
 | ||||
|   /// Linearize to a HybridGaussianFactor at the linearization point `c`.
 | ||||
|   boost::shared_ptr<HybridGaussianFactor> linearize(const Values &c) const { | ||||
|     return boost::make_shared<HybridGaussianFactor>(inner_->linearize(c)); | ||||
|   } | ||||
| 
 | ||||
|   /// @}
 | ||||
| }; | ||||
| }  // namespace gtsam
 | ||||
|  |  | |||
|  | @ -161,6 +161,12 @@ class MixtureFactor : public HybridFactor { | |||
|                              factor, continuousValues); | ||||
|   } | ||||
| 
 | ||||
|   /// Error for HybridValues is not provided for nonlinear hybrid factor.
 | ||||
|   double error(const HybridValues &values) const override { | ||||
|     throw std::runtime_error( | ||||
|         "MixtureFactor::error(HybridValues) not implemented."); | ||||
|   } | ||||
| 
 | ||||
|   size_t dim() const { | ||||
|     // TODO(Varun)
 | ||||
|     throw std::runtime_error("MixtureFactor::dim not implemented."); | ||||
|  |  | |||
|  | @ -183,10 +183,8 @@ class HybridGaussianFactorGraph { | |||
|   bool equals(const gtsam::HybridGaussianFactorGraph& fg, double tol = 1e-9) const; | ||||
| 
 | ||||
|   // evaluation | ||||
|   double error(const gtsam::VectorValues& continuousValues, | ||||
|                const gtsam::DiscreteValues& discreteValues) const; | ||||
|   double probPrime(const gtsam::VectorValues& continuousValues, | ||||
|                    const gtsam::DiscreteValues& discreteValues) const; | ||||
|   double error(const gtsam::HybridValues& values) const; | ||||
|   double probPrime(const gtsam::HybridValues& values) const; | ||||
| 
 | ||||
|   gtsam::HybridBayesNet* eliminateSequential(); | ||||
|   gtsam::HybridBayesNet* eliminateSequential( | ||||
|  |  | |||
|  | @ -128,9 +128,9 @@ TEST(GaussianMixture, Error) { | |||
|   // Regression for non-tree version.
 | ||||
|   DiscreteValues assignment; | ||||
|   assignment[M(1)] = 0; | ||||
|   EXPECT_DOUBLES_EQUAL(0.5, mixture.error(values, assignment), 1e-8); | ||||
|   EXPECT_DOUBLES_EQUAL(0.5, mixture.error({values, assignment}), 1e-8); | ||||
|   assignment[M(1)] = 1; | ||||
|   EXPECT_DOUBLES_EQUAL(4.3252595155709335, mixture.error(values, assignment), | ||||
|   EXPECT_DOUBLES_EQUAL(4.3252595155709335, mixture.error({values, assignment}), | ||||
|                        1e-8); | ||||
| } | ||||
| 
 | ||||
|  | @ -179,7 +179,9 @@ TEST(GaussianMixture, Likelihood) { | |||
|   const GaussianMixtureFactor::Factors factors( | ||||
|       gm.conditionals(), | ||||
|       [measurements](const GaussianConditional::shared_ptr& conditional) { | ||||
|         return conditional->likelihood(measurements); | ||||
|         return GaussianMixtureFactor::FactorAndConstant{ | ||||
|             conditional->likelihood(measurements), | ||||
|             conditional->logNormalizationConstant()}; | ||||
|       }); | ||||
|   const GaussianMixtureFactor expected({X(0)}, {mode}, factors); | ||||
|   EXPECT(assert_equal(*factor, expected)); | ||||
|  |  | |||
|  | @ -22,6 +22,7 @@ | |||
| #include <gtsam/discrete/DiscreteValues.h> | ||||
| #include <gtsam/hybrid/GaussianMixture.h> | ||||
| #include <gtsam/hybrid/GaussianMixtureFactor.h> | ||||
| #include <gtsam/hybrid/HybridValues.h> | ||||
| #include <gtsam/inference/Symbol.h> | ||||
| #include <gtsam/linear/GaussianFactorGraph.h> | ||||
| 
 | ||||
|  | @ -188,7 +189,8 @@ TEST(GaussianMixtureFactor, Error) { | |||
|   DiscreteValues discreteValues; | ||||
|   discreteValues[m1.first] = 1; | ||||
|   EXPECT_DOUBLES_EQUAL( | ||||
|       4.0, mixtureFactor.error(continuousValues, discreteValues), 1e-9); | ||||
|       4.0, mixtureFactor.error({continuousValues, discreteValues}), | ||||
|       1e-9); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
|  |  | |||
|  | @ -243,8 +243,8 @@ TEST(HybridBayesNet, Error) { | |||
|   double total_error = 0; | ||||
|   for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) { | ||||
|     if (hybridBayesNet->at(idx)->isHybrid()) { | ||||
|       double error = hybridBayesNet->atMixture(idx)->error(delta.continuous(), | ||||
|                                                            discrete_values); | ||||
|       double error = hybridBayesNet->atMixture(idx)->error( | ||||
|           {delta.continuous(), discrete_values}); | ||||
|       total_error += error; | ||||
|     } else if (hybridBayesNet->at(idx)->isContinuous()) { | ||||
|       double error = hybridBayesNet->atGaussian(idx)->error(delta.continuous()); | ||||
|  | @ -253,7 +253,7 @@ TEST(HybridBayesNet, Error) { | |||
|   } | ||||
| 
 | ||||
|   EXPECT_DOUBLES_EQUAL( | ||||
|       total_error, hybridBayesNet->error(delta.continuous(), discrete_values), | ||||
|       total_error, hybridBayesNet->error({delta.continuous(), discrete_values}), | ||||
|       1e-9); | ||||
|   EXPECT_DOUBLES_EQUAL(total_error, error_tree(discrete_values), 1e-9); | ||||
|   EXPECT_DOUBLES_EQUAL(total_error, pruned_error_tree(discrete_values), 1e-9); | ||||
|  |  | |||
|  | @ -273,7 +273,7 @@ AlgebraicDecisionTree<Key> getProbPrimeTree( | |||
|       continue; | ||||
|     } | ||||
| 
 | ||||
|     double error = graph.error(delta, assignment); | ||||
|     double error = graph.error({delta, assignment}); | ||||
|     probPrimes.push_back(exp(-error)); | ||||
|   } | ||||
|   AlgebraicDecisionTree<Key> probPrimeTree(discrete_keys, probPrimes); | ||||
|  | @ -487,8 +487,8 @@ TEST(HybridEstimation, CorrectnessViaSampling) { | |||
|          const HybridValues& sample) -> double { | ||||
|     const DiscreteValues assignment = sample.discrete(); | ||||
|     // Compute in log form for numerical stability
 | ||||
|     double log_ratio = bayesNet->error(sample.continuous(), assignment) - | ||||
|                        factorGraph->error(sample.continuous(), assignment); | ||||
|     double log_ratio = bayesNet->error({sample.continuous(), assignment}) - | ||||
|                        factorGraph->error({sample.continuous(), assignment}); | ||||
|     double ratio = exp(-log_ratio); | ||||
|     return ratio; | ||||
|   }; | ||||
|  |  | |||
|  | @ -575,18 +575,14 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrime) { | |||
|   HybridBayesNet::shared_ptr hybridBayesNet = | ||||
|       graph.eliminateSequential(hybridOrdering); | ||||
| 
 | ||||
|   HybridValues delta = hybridBayesNet->optimize(); | ||||
|   double error = graph.error(delta.continuous(), delta.discrete()); | ||||
| 
 | ||||
|   double expected_error = 0.490243199; | ||||
|   // regression
 | ||||
|   EXPECT(assert_equal(expected_error, error, 1e-9)); | ||||
| 
 | ||||
|   double probs = exp(-error); | ||||
|   double expected_probs = graph.probPrime(delta.continuous(), delta.discrete()); | ||||
|   const HybridValues delta = hybridBayesNet->optimize(); | ||||
|   const double error = graph.error(delta); | ||||
| 
 | ||||
|   // regression
 | ||||
|   EXPECT(assert_equal(expected_probs, probs, 1e-7)); | ||||
|   EXPECT(assert_equal(1.58886, error, 1e-5)); | ||||
| 
 | ||||
|   // Real test:
 | ||||
|   EXPECT(assert_equal(graph.probPrime(delta), exp(-error), 1e-7)); | ||||
| } | ||||
| 
 | ||||
| /* ****************************************************************************/ | ||||
|  |  | |||
|  | @ -168,26 +168,30 @@ namespace gtsam { | |||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| double GaussianConditional::logDeterminant() const { | ||||
|   double logDet; | ||||
|   if (this->get_model()) { | ||||
|     Vector diag = this->R().diagonal(); | ||||
|     this->get_model()->whitenInPlace(diag); | ||||
|     logDet = diag.unaryExpr([](double x) { return log(x); }).sum(); | ||||
|   if (get_model()) { | ||||
|     Vector diag = R().diagonal(); | ||||
|     get_model()->whitenInPlace(diag); | ||||
|     return diag.unaryExpr([](double x) { return log(x); }).sum(); | ||||
|   } else { | ||||
|     logDet = | ||||
|         this->R().diagonal().unaryExpr([](double x) { return log(x); }).sum(); | ||||
|     return R().diagonal().unaryExpr([](double x) { return log(x); }).sum(); | ||||
|   } | ||||
|   return logDet; | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| //  density = exp(-error(x)) / sqrt((2*pi)^n*det(Sigma))
 | ||||
| //  log = -error(x) - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma)
 | ||||
| double GaussianConditional::logDensity(const VectorValues& x) const { | ||||
| //  normalization constant = 1.0 / sqrt((2*pi)^n*det(Sigma))
 | ||||
| //  log = - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma)
 | ||||
| double GaussianConditional::logNormalizationConstant() const { | ||||
|   constexpr double log2pi = 1.8378770664093454835606594728112; | ||||
|   size_t n = d().size(); | ||||
|   // log det(Sigma)) = - 2.0 * logDeterminant()
 | ||||
|   return - error(x) - 0.5 * n * log2pi + logDeterminant(); | ||||
|   return - 0.5 * n * log2pi + logDeterminant(); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| //  density = k exp(-error(x))
 | ||||
| //  log = log(k) -error(x) - 0.5 * n*log(2*pi)
 | ||||
| double GaussianConditional::logDensity(const VectorValues& x) const { | ||||
|   return logNormalizationConstant() - error(x); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
|  |  | |||
|  | @ -169,7 +169,7 @@ namespace gtsam { | |||
|      * | ||||
|      * @return double | ||||
|      */ | ||||
|     double determinant() const { return exp(this->logDeterminant()); } | ||||
|     inline double determinant() const { return exp(logDeterminant()); } | ||||
| 
 | ||||
|     /**
 | ||||
|      * @brief Compute the log determinant of the R matrix. | ||||
|  | @ -184,6 +184,19 @@ namespace gtsam { | |||
|      */ | ||||
|     double logDeterminant() const; | ||||
| 
 | ||||
|     /**
 | ||||
|      * normalization constant = 1.0 / sqrt((2*pi)^n*det(Sigma)) | ||||
|      * log = - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma) | ||||
|      */ | ||||
|     double logNormalizationConstant() const; | ||||
| 
 | ||||
|     /**
 | ||||
|      * normalization constant = 1.0 / sqrt((2*pi)^n*det(Sigma)) | ||||
|      */ | ||||
|     inline double normalizationConstant() const { | ||||
|       return exp(logNormalizationConstant()); | ||||
|     } | ||||
| 
 | ||||
|     /**
 | ||||
|     * Solves a conditional Gaussian and writes the solution into the entries of | ||||
|     * \c x for each frontal variable of the conditional.  The parents are | ||||
|  |  | |||
|  | @ -6,7 +6,7 @@ All Rights Reserved | |||
| See LICENSE for the license information | ||||
| 
 | ||||
| Unit tests for Hybrid Factor Graphs. | ||||
| Author: Fan Jiang | ||||
| Author: Fan Jiang, Varun Agrawal, Frank Dellaert | ||||
| """ | ||||
| # pylint: disable=invalid-name, no-name-in-module, no-member | ||||
| 
 | ||||
|  | @ -18,13 +18,14 @@ from gtsam.utils.test_case import GtsamTestCase | |||
| 
 | ||||
| import gtsam | ||||
| from gtsam import (DiscreteConditional, DiscreteKeys, GaussianConditional, | ||||
|                    GaussianMixture, GaussianMixtureFactor, | ||||
|                    GaussianMixture, GaussianMixtureFactor, HybridBayesNet, HybridValues, | ||||
|                    HybridGaussianFactorGraph, JacobianFactor, Ordering, | ||||
|                    noiseModel) | ||||
| 
 | ||||
| 
 | ||||
| class TestHybridGaussianFactorGraph(GtsamTestCase): | ||||
|     """Unit tests for HybridGaussianFactorGraph.""" | ||||
| 
 | ||||
|     def test_create(self): | ||||
|         """Test construction of hybrid factor graph.""" | ||||
|         model = noiseModel.Unit.Create(3) | ||||
|  | @ -81,13 +82,13 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): | |||
|         self.assertEqual(hv.atDiscrete(C(0)), 1) | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def tiny(num_measurements: int = 1) -> gtsam.HybridBayesNet: | ||||
|     def tiny(num_measurements: int = 1) -> HybridBayesNet: | ||||
|         """ | ||||
|         Create a tiny two variable hybrid model which represents | ||||
|         the generative probability P(z, x, n) = P(z | x, n)P(x)P(n). | ||||
|         """ | ||||
|         # Create hybrid Bayes net. | ||||
|         bayesNet = gtsam.HybridBayesNet() | ||||
|         bayesNet = HybridBayesNet() | ||||
| 
 | ||||
|         # Create mode key: 0 is low-noise, 1 is high-noise. | ||||
|         mode = (M(0), 2) | ||||
|  | @ -113,35 +114,76 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): | |||
|         bayesNet.addGaussian(prior_on_x0) | ||||
| 
 | ||||
|         # Add prior on mode. | ||||
|         bayesNet.emplaceDiscrete(mode, "1/1") | ||||
|         bayesNet.emplaceDiscrete(mode, "4/6") | ||||
| 
 | ||||
|         return bayesNet | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def factor_graph_from_bayes_net(bayesNet: HybridBayesNet, sample: HybridValues): | ||||
|         """Create a factor graph from the Bayes net with sampled measurements. | ||||
|             The factor graph is `P(x)P(n) ϕ(x, n; z0) ϕ(x, n; z1) ...` | ||||
|             and thus represents the same joint probability as the Bayes net. | ||||
|         """ | ||||
|         fg = HybridGaussianFactorGraph() | ||||
|         num_measurements = bayesNet.size() - 2 | ||||
|         for i in range(num_measurements): | ||||
|             conditional = bayesNet.atMixture(i) | ||||
|             measurement = gtsam.VectorValues() | ||||
|             measurement.insert(Z(i), sample.at(Z(i))) | ||||
|             factor = conditional.likelihood(measurement) | ||||
|             fg.push_back(factor) | ||||
|         fg.push_back(bayesNet.atGaussian(num_measurements)) | ||||
|         fg.push_back(bayesNet.atDiscrete(num_measurements+1)) | ||||
|         return fg | ||||
| 
 | ||||
|     @classmethod | ||||
|     def estimate_marginals(cls, bayesNet: HybridBayesNet, sample: HybridValues, N=10000): | ||||
|         """Do importance sampling to get an estimate of the discrete marginal P(mode).""" | ||||
|         # Use prior on x0, mode as proposal density. | ||||
|         prior = cls.tiny(num_measurements=0)  # just P(x0)P(mode) | ||||
| 
 | ||||
|         # Allocate space for marginals. | ||||
|         marginals = np.zeros((2,)) | ||||
| 
 | ||||
|         # Do importance sampling. | ||||
|         num_measurements = bayesNet.size() - 2 | ||||
|         for s in range(N): | ||||
|             proposed = prior.sample() | ||||
|             for i in range(num_measurements): | ||||
|                 z_i = sample.at(Z(i)) | ||||
|                 proposed.insert(Z(i), z_i) | ||||
|             weight = bayesNet.evaluate(proposed) / prior.evaluate(proposed) | ||||
|             marginals[proposed.atDiscrete(M(0))] += weight | ||||
| 
 | ||||
|         # print marginals: | ||||
|         marginals /= marginals.sum() | ||||
|         return marginals | ||||
| 
 | ||||
|     def test_tiny(self): | ||||
|         """Test a tiny two variable hybrid model.""" | ||||
|         bayesNet = self.tiny() | ||||
|         sample = bayesNet.sample() | ||||
|         # print(sample) | ||||
| 
 | ||||
|         # Create a factor graph from the Bayes net with sampled measurements. | ||||
|         fg = HybridGaussianFactorGraph() | ||||
|         conditional = bayesNet.atMixture(0) | ||||
|         measurement = gtsam.VectorValues() | ||||
|         measurement.insert(Z(0), sample.at(Z(0))) | ||||
|         factor = conditional.likelihood(measurement) | ||||
|         fg.push_back(factor) | ||||
|         fg.push_back(bayesNet.atGaussian(1)) | ||||
|         fg.push_back(bayesNet.atDiscrete(2)) | ||||
|         # Estimate marginals using importance sampling. | ||||
|         marginals = self.estimate_marginals(bayesNet, sample) | ||||
|         # print(f"True mode: {sample.atDiscrete(M(0))}") | ||||
|         # print(f"P(mode=0; z0) = {marginals[0]}") | ||||
|         # print(f"P(mode=1; z0) = {marginals[1]}") | ||||
| 
 | ||||
|         # Check that the estimate is close to the true value. | ||||
|         self.assertAlmostEqual(marginals[0], 0.4, delta=0.1) | ||||
|         self.assertAlmostEqual(marginals[1], 0.6, delta=0.1) | ||||
| 
 | ||||
|         fg = self.factor_graph_from_bayes_net(bayesNet, sample) | ||||
|         self.assertEqual(fg.size(), 3) | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def calculate_ratio(bayesNet, fg, sample): | ||||
|     def calculate_ratio(bayesNet: HybridBayesNet, | ||||
|                         fg: HybridGaussianFactorGraph, | ||||
|                         sample: HybridValues): | ||||
|         """Calculate ratio  between Bayes net probability and the factor graph.""" | ||||
|         continuous = gtsam.VectorValues() | ||||
|         continuous.insert(X(0), sample.at(X(0))) | ||||
|         return bayesNet.evaluate(sample) / fg.probPrime( | ||||
|             continuous, sample.discrete()) | ||||
|         return bayesNet.evaluate(sample) / fg.probPrime(sample) if fg.probPrime(sample) > 0 else 0 | ||||
| 
 | ||||
|     def test_ratio(self): | ||||
|         """ | ||||
|  | @ -153,23 +195,22 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): | |||
|         # Create the Bayes net representing the generative model P(z, x, n)=P(z|x, n)P(x)P(n) | ||||
|         bayesNet = self.tiny(num_measurements=2) | ||||
|         # Sample from the Bayes net. | ||||
|         sample: gtsam.HybridValues = bayesNet.sample() | ||||
|         sample: HybridValues = bayesNet.sample() | ||||
|         # print(sample) | ||||
| 
 | ||||
|         # Create a factor graph from the Bayes net with sampled measurements. | ||||
|         # The factor graph is `P(x)P(n) ϕ(x, n; z1) ϕ(x, n; z2)` | ||||
|         # and thus represents the same joint probability as the Bayes net. | ||||
|         fg = HybridGaussianFactorGraph() | ||||
|         for i in range(2): | ||||
|             conditional = bayesNet.atMixture(i) | ||||
|             measurement = gtsam.VectorValues() | ||||
|             measurement.insert(Z(i), sample.at(Z(i))) | ||||
|             factor = conditional.likelihood(measurement) | ||||
|             fg.push_back(factor) | ||||
|         fg.push_back(bayesNet.atGaussian(2)) | ||||
|         fg.push_back(bayesNet.atDiscrete(3)) | ||||
|         # Estimate marginals using importance sampling. | ||||
|         marginals = self.estimate_marginals(bayesNet, sample) | ||||
|         # print(f"True mode: {sample.atDiscrete(M(0))}") | ||||
|         # print(f"P(mode=0; z0, z1) = {marginals[0]}") | ||||
|         # print(f"P(mode=1; z0, z1) = {marginals[1]}") | ||||
| 
 | ||||
|         # print(fg) | ||||
|         # Check marginals based on sampled mode. | ||||
|         if sample.atDiscrete(M(0)) == 0: | ||||
|             self.assertGreater(marginals[0], marginals[1]) | ||||
|         else: | ||||
|             self.assertGreater(marginals[1], marginals[0]) | ||||
| 
 | ||||
|         fg = self.factor_graph_from_bayes_net(bayesNet, sample) | ||||
|         self.assertEqual(fg.size(), 4) | ||||
| 
 | ||||
|         # Calculate ratio between Bayes net probability and the factor graph: | ||||
|  | @ -185,10 +226,10 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): | |||
|         for i in range(10): | ||||
|             other = bayesNet.sample() | ||||
|             other.update(measurements) | ||||
|             # print(other) | ||||
|             # ratio = self.calculate_ratio(bayesNet, fg, other) | ||||
|             ratio = self.calculate_ratio(bayesNet, fg, other) | ||||
|             # print(f"Ratio: {ratio}\n") | ||||
|             # self.assertAlmostEqual(ratio, expected_ratio) | ||||
|             if (ratio > 0): | ||||
|                 self.assertAlmostEqual(ratio, expected_ratio) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue