Merge pull request #1358 from borglab/hybrid/gaussian-mixture-factor
						commit
						d0821a57de
					
				|  | @ -22,6 +22,7 @@ | ||||||
| #include <gtsam/discrete/DiscreteValues.h> | #include <gtsam/discrete/DiscreteValues.h> | ||||||
| #include <gtsam/hybrid/GaussianMixture.h> | #include <gtsam/hybrid/GaussianMixture.h> | ||||||
| #include <gtsam/hybrid/GaussianMixtureFactor.h> | #include <gtsam/hybrid/GaussianMixtureFactor.h> | ||||||
|  | #include <gtsam/hybrid/HybridValues.h> | ||||||
| #include <gtsam/inference/Conditional-inst.h> | #include <gtsam/inference/Conditional-inst.h> | ||||||
| #include <gtsam/linear/GaussianFactorGraph.h> | #include <gtsam/linear/GaussianFactorGraph.h> | ||||||
| 
 | 
 | ||||||
|  | @ -149,17 +150,19 @@ boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood( | ||||||
|   const DiscreteKeys discreteParentKeys = discreteKeys(); |   const DiscreteKeys discreteParentKeys = discreteKeys(); | ||||||
|   const KeyVector continuousParentKeys = continuousParents(); |   const KeyVector continuousParentKeys = continuousParents(); | ||||||
|   const GaussianMixtureFactor::Factors likelihoods( |   const GaussianMixtureFactor::Factors likelihoods( | ||||||
|       conditionals(), [&](const GaussianConditional::shared_ptr &conditional) { |       conditionals_, [&](const GaussianConditional::shared_ptr &conditional) { | ||||||
|         return conditional->likelihood(frontals); |         return GaussianMixtureFactor::FactorAndConstant{ | ||||||
|  |             conditional->likelihood(frontals), | ||||||
|  |             conditional->logNormalizationConstant()}; | ||||||
|       }); |       }); | ||||||
|   return boost::make_shared<GaussianMixtureFactor>( |   return boost::make_shared<GaussianMixtureFactor>( | ||||||
|       continuousParentKeys, discreteParentKeys, likelihoods); |       continuousParentKeys, discreteParentKeys, likelihoods); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
| std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) { | std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) { | ||||||
|   std::set<DiscreteKey> s; |   std::set<DiscreteKey> s; | ||||||
|   s.insert(dkeys.begin(), dkeys.end()); |   s.insert(discreteKeys.begin(), discreteKeys.end()); | ||||||
|   return s; |   return s; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -184,7 +187,7 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) { | ||||||
|                     const GaussianConditional::shared_ptr &conditional) |                     const GaussianConditional::shared_ptr &conditional) | ||||||
|       -> GaussianConditional::shared_ptr { |       -> GaussianConditional::shared_ptr { | ||||||
|     // typecast so we can use this to get probability value
 |     // typecast so we can use this to get probability value
 | ||||||
|     DiscreteValues values(choices); |     const DiscreteValues values(choices); | ||||||
| 
 | 
 | ||||||
|     // Case where the gaussian mixture has the same
 |     // Case where the gaussian mixture has the same
 | ||||||
|     // discrete keys as the decision tree.
 |     // discrete keys as the decision tree.
 | ||||||
|  | @ -254,11 +257,10 @@ AlgebraicDecisionTree<Key> GaussianMixture::error( | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /* *******************************************************************************/ | /* *******************************************************************************/ | ||||||
| double GaussianMixture::error(const VectorValues &continuousValues, | double GaussianMixture::error(const HybridValues &values) const { | ||||||
|                               const DiscreteValues &discreteValues) const { |  | ||||||
|   // Directly index to get the conditional, no need to build the whole tree.
 |   // Directly index to get the conditional, no need to build the whole tree.
 | ||||||
|   auto conditional = conditionals_(discreteValues); |   auto conditional = conditionals_(values.discrete()); | ||||||
|   return conditional->error(continuousValues); |   return conditional->error(values.continuous()); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace gtsam
 | }  // namespace gtsam
 | ||||||
|  |  | ||||||
|  | @ -30,6 +30,7 @@ | ||||||
| namespace gtsam { | namespace gtsam { | ||||||
| 
 | 
 | ||||||
| class GaussianMixtureFactor; | class GaussianMixtureFactor; | ||||||
|  | class HybridValues; | ||||||
| 
 | 
 | ||||||
| /**
 | /**
 | ||||||
|  * @brief A conditional of gaussian mixtures indexed by discrete variables, as |  * @brief A conditional of gaussian mixtures indexed by discrete variables, as | ||||||
|  | @ -87,7 +88,7 @@ class GTSAM_EXPORT GaussianMixture | ||||||
|   /// @name Constructors
 |   /// @name Constructors
 | ||||||
|   /// @{
 |   /// @{
 | ||||||
| 
 | 
 | ||||||
|   /// Defaut constructor, mainly for serialization.
 |   /// Default constructor, mainly for serialization.
 | ||||||
|   GaussianMixture() = default; |   GaussianMixture() = default; | ||||||
| 
 | 
 | ||||||
|   /**
 |   /**
 | ||||||
|  | @ -135,6 +136,7 @@ class GTSAM_EXPORT GaussianMixture | ||||||
|   /// @name Standard API
 |   /// @name Standard API
 | ||||||
|   /// @{
 |   /// @{
 | ||||||
| 
 | 
 | ||||||
|  |   /// @brief Return the conditional Gaussian for the given discrete assignment.
 | ||||||
|   GaussianConditional::shared_ptr operator()( |   GaussianConditional::shared_ptr operator()( | ||||||
|       const DiscreteValues &discreteValues) const; |       const DiscreteValues &discreteValues) const; | ||||||
| 
 | 
 | ||||||
|  | @ -165,12 +167,10 @@ class GTSAM_EXPORT GaussianMixture | ||||||
|    * @brief Compute the error of this Gaussian Mixture given the continuous |    * @brief Compute the error of this Gaussian Mixture given the continuous | ||||||
|    * values and a discrete assignment. |    * values and a discrete assignment. | ||||||
|    * |    * | ||||||
|    * @param continuousValues Continuous values at which to compute the error. |    * @param values Continuous values and discrete assignment. | ||||||
|    * @param discreteValues The discrete assignment for a specific mode sequence. |  | ||||||
|    * @return double |    * @return double | ||||||
|    */ |    */ | ||||||
|   double error(const VectorValues &continuousValues, |   double error(const HybridValues &values) const override; | ||||||
|                const DiscreteValues &discreteValues) const; |  | ||||||
| 
 | 
 | ||||||
|   /**
 |   /**
 | ||||||
|    * @brief Prune the decision tree of Gaussian factors as per the discrete |    * @brief Prune the decision tree of Gaussian factors as per the discrete | ||||||
|  | @ -193,7 +193,7 @@ class GTSAM_EXPORT GaussianMixture | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| /// Return the DiscreteKey vector as a set.
 | /// Return the DiscreteKey vector as a set.
 | ||||||
| std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys); | std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys); | ||||||
| 
 | 
 | ||||||
| // traits
 | // traits
 | ||||||
| template <> | template <> | ||||||
|  |  | ||||||
|  | @ -22,6 +22,8 @@ | ||||||
| #include <gtsam/discrete/DecisionTree-inl.h> | #include <gtsam/discrete/DecisionTree-inl.h> | ||||||
| #include <gtsam/discrete/DecisionTree.h> | #include <gtsam/discrete/DecisionTree.h> | ||||||
| #include <gtsam/hybrid/GaussianMixtureFactor.h> | #include <gtsam/hybrid/GaussianMixtureFactor.h> | ||||||
|  | #include <gtsam/hybrid/HybridValues.h> | ||||||
|  | #include <gtsam/linear/GaussianFactor.h> | ||||||
| #include <gtsam/linear/GaussianFactorGraph.h> | #include <gtsam/linear/GaussianFactorGraph.h> | ||||||
| 
 | 
 | ||||||
| namespace gtsam { | namespace gtsam { | ||||||
|  | @ -29,8 +31,11 @@ namespace gtsam { | ||||||
| /* *******************************************************************************/ | /* *******************************************************************************/ | ||||||
| GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys, | GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys, | ||||||
|                                              const DiscreteKeys &discreteKeys, |                                              const DiscreteKeys &discreteKeys, | ||||||
|                                              const Factors &factors) |                                              const Mixture &factors) | ||||||
|     : Base(continuousKeys, discreteKeys), factors_(factors) {} |     : Base(continuousKeys, discreteKeys), | ||||||
|  |       factors_(factors, [](const GaussianFactor::shared_ptr &gf) { | ||||||
|  |         return FactorAndConstant{gf, 0.0}; | ||||||
|  |       }) {} | ||||||
| 
 | 
 | ||||||
| /* *******************************************************************************/ | /* *******************************************************************************/ | ||||||
| bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const { | bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const { | ||||||
|  | @ -43,11 +48,11 @@ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const { | ||||||
| 
 | 
 | ||||||
|   // Check the base and the factors:
 |   // Check the base and the factors:
 | ||||||
|   return Base::equals(*e, tol) && |   return Base::equals(*e, tol) && | ||||||
|          factors_.equals(e->factors_, |          factors_.equals(e->factors_, [tol](const FactorAndConstant &f1, | ||||||
|                          [tol](const GaussianFactor::shared_ptr &f1, |                                             const FactorAndConstant &f2) { | ||||||
|                                const GaussianFactor::shared_ptr &f2) { |            return f1.factor->equals(*(f2.factor), tol) && | ||||||
|                            return f1->equals(*f2, tol); |                   std::abs(f1.constant - f2.constant) < tol; | ||||||
|                          }); |          }); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /* *******************************************************************************/ | /* *******************************************************************************/ | ||||||
|  | @ -60,7 +65,8 @@ void GaussianMixtureFactor::print(const std::string &s, | ||||||
|   } else { |   } else { | ||||||
|     factors_.print( |     factors_.print( | ||||||
|         "", [&](Key k) { return formatter(k); }, |         "", [&](Key k) { return formatter(k); }, | ||||||
|         [&](const GaussianFactor::shared_ptr &gf) -> std::string { |         [&](const FactorAndConstant &gf_z) -> std::string { | ||||||
|  |           auto gf = gf_z.factor; | ||||||
|           RedirectCout rd; |           RedirectCout rd; | ||||||
|           std::cout << ":\n"; |           std::cout << ":\n"; | ||||||
|           if (gf && !gf->empty()) { |           if (gf && !gf->empty()) { | ||||||
|  | @ -75,8 +81,10 @@ void GaussianMixtureFactor::print(const std::string &s, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /* *******************************************************************************/ | /* *******************************************************************************/ | ||||||
| const GaussianMixtureFactor::Factors &GaussianMixtureFactor::factors() { | const GaussianMixtureFactor::Mixture GaussianMixtureFactor::factors() const { | ||||||
|   return factors_; |   return Mixture(factors_, [](const FactorAndConstant &factor_z) { | ||||||
|  |     return factor_z.factor; | ||||||
|  |   }); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /* *******************************************************************************/ | /* *******************************************************************************/ | ||||||
|  | @ -95,9 +103,9 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::add( | ||||||
| /* *******************************************************************************/ | /* *******************************************************************************/ | ||||||
| GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree() | GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree() | ||||||
|     const { |     const { | ||||||
|   auto wrap = [](const GaussianFactor::shared_ptr &factor) { |   auto wrap = [](const FactorAndConstant &factor_z) { | ||||||
|     GaussianFactorGraph result; |     GaussianFactorGraph result; | ||||||
|     result.push_back(factor); |     result.push_back(factor_z.factor); | ||||||
|     return result; |     return result; | ||||||
|   }; |   }; | ||||||
|   return {factors_, wrap}; |   return {factors_, wrap}; | ||||||
|  | @ -107,21 +115,18 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree() | ||||||
| AlgebraicDecisionTree<Key> GaussianMixtureFactor::error( | AlgebraicDecisionTree<Key> GaussianMixtureFactor::error( | ||||||
|     const VectorValues &continuousValues) const { |     const VectorValues &continuousValues) const { | ||||||
|   // functor to convert from sharedFactor to double error value.
 |   // functor to convert from sharedFactor to double error value.
 | ||||||
|   auto errorFunc = |   auto errorFunc = [continuousValues](const FactorAndConstant &factor_z) { | ||||||
|       [continuousValues](const GaussianFactor::shared_ptr &factor) { |     return factor_z.error(continuousValues); | ||||||
|         return factor->error(continuousValues); |   }; | ||||||
|       }; |  | ||||||
|   DecisionTree<Key, double> errorTree(factors_, errorFunc); |   DecisionTree<Key, double> errorTree(factors_, errorFunc); | ||||||
|   return errorTree; |   return errorTree; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /* *******************************************************************************/ | /* *******************************************************************************/ | ||||||
| double GaussianMixtureFactor::error( | double GaussianMixtureFactor::error(const HybridValues &values) const { | ||||||
|     const VectorValues &continuousValues, |   const FactorAndConstant factor_z = factors_(values.discrete()); | ||||||
|     const DiscreteValues &discreteValues) const { |   return factor_z.error(values.continuous()); | ||||||
|   // Directly index to get the conditional, no need to build the whole tree.
 |  | ||||||
|   auto factor = factors_(discreteValues); |  | ||||||
|   return factor->error(continuousValues); |  | ||||||
| } | } | ||||||
|  | /* *******************************************************************************/ | ||||||
| 
 | 
 | ||||||
| }  // namespace gtsam
 | }  // namespace gtsam
 | ||||||
|  |  | ||||||
|  | @ -23,17 +23,15 @@ | ||||||
| #include <gtsam/discrete/AlgebraicDecisionTree.h> | #include <gtsam/discrete/AlgebraicDecisionTree.h> | ||||||
| #include <gtsam/discrete/DecisionTree.h> | #include <gtsam/discrete/DecisionTree.h> | ||||||
| #include <gtsam/discrete/DiscreteKey.h> | #include <gtsam/discrete/DiscreteKey.h> | ||||||
| #include <gtsam/discrete/DiscreteValues.h> | #include <gtsam/hybrid/HybridFactor.h> | ||||||
| #include <gtsam/hybrid/HybridGaussianFactor.h> |  | ||||||
| #include <gtsam/linear/GaussianFactor.h> | #include <gtsam/linear/GaussianFactor.h> | ||||||
| #include <gtsam/linear/VectorValues.h> |  | ||||||
| 
 | 
 | ||||||
| namespace gtsam { | namespace gtsam { | ||||||
| 
 | 
 | ||||||
| class GaussianFactorGraph; | class GaussianFactorGraph; | ||||||
| 
 | class HybridValues; | ||||||
| // Needed for wrapper.
 | class DiscreteValues; | ||||||
| using GaussianFactorVector = std::vector<gtsam::GaussianFactor::shared_ptr>; | class VectorValues; | ||||||
| 
 | 
 | ||||||
| /**
 | /**
 | ||||||
|  * @brief Implementation of a discrete conditional mixture factor. |  * @brief Implementation of a discrete conditional mixture factor. | ||||||
|  | @ -53,9 +51,29 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { | ||||||
|   using shared_ptr = boost::shared_ptr<This>; |   using shared_ptr = boost::shared_ptr<This>; | ||||||
| 
 | 
 | ||||||
|   using Sum = DecisionTree<Key, GaussianFactorGraph>; |   using Sum = DecisionTree<Key, GaussianFactorGraph>; | ||||||
|  |   using sharedFactor = boost::shared_ptr<GaussianFactor>; | ||||||
| 
 | 
 | ||||||
|   /// typedef for Decision Tree of Gaussian Factors
 |   /// Gaussian factor and log of normalizing constant.
 | ||||||
|   using Factors = DecisionTree<Key, GaussianFactor::shared_ptr>; |   struct FactorAndConstant { | ||||||
|  |     sharedFactor factor; | ||||||
|  |     double constant; | ||||||
|  | 
 | ||||||
|  |     // Return error with constant correction.
 | ||||||
|  |     double error(const VectorValues &values) const { | ||||||
|  |       // Note minus sign: constant is log of normalization constant for probabilities.
 | ||||||
|  |       // Errors is the negative log-likelihood, hence we subtract the constant here.
 | ||||||
|  |       return factor->error(values) - constant; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // Check pointer equality.
 | ||||||
|  |     bool operator==(const FactorAndConstant &other) const { | ||||||
|  |       return factor == other.factor && constant == other.constant; | ||||||
|  |     } | ||||||
|  |   }; | ||||||
|  | 
 | ||||||
|  |   /// typedef for Decision Tree of Gaussian factors and log-constant.
 | ||||||
|  |   using Factors = DecisionTree<Key, FactorAndConstant>; | ||||||
|  |   using Mixture = DecisionTree<Key, sharedFactor>; | ||||||
| 
 | 
 | ||||||
|  private: |  private: | ||||||
|   /// Decision tree of Gaussian factors indexed by discrete keys.
 |   /// Decision tree of Gaussian factors indexed by discrete keys.
 | ||||||
|  | @ -82,12 +100,17 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { | ||||||
|    * @param continuousKeys A vector of keys representing continuous variables. |    * @param continuousKeys A vector of keys representing continuous variables. | ||||||
|    * @param discreteKeys A vector of keys representing discrete variables and |    * @param discreteKeys A vector of keys representing discrete variables and | ||||||
|    * their cardinalities. |    * their cardinalities. | ||||||
|    * @param factors The decision tree of Gaussian Factors stored as the mixture |    * @param factors The decision tree of Gaussian factors stored as the mixture | ||||||
|    * density. |    * density. | ||||||
|    */ |    */ | ||||||
|   GaussianMixtureFactor(const KeyVector &continuousKeys, |   GaussianMixtureFactor(const KeyVector &continuousKeys, | ||||||
|                         const DiscreteKeys &discreteKeys, |                         const DiscreteKeys &discreteKeys, | ||||||
|                         const Factors &factors); |                         const Mixture &factors); | ||||||
|  | 
 | ||||||
|  |   GaussianMixtureFactor(const KeyVector &continuousKeys, | ||||||
|  |                         const DiscreteKeys &discreteKeys, | ||||||
|  |                         const Factors &factors_and_z) | ||||||
|  |       : Base(continuousKeys, discreteKeys), factors_(factors_and_z) {} | ||||||
| 
 | 
 | ||||||
|   /**
 |   /**
 | ||||||
|    * @brief Construct a new GaussianMixtureFactor object using a vector of |    * @brief Construct a new GaussianMixtureFactor object using a vector of | ||||||
|  | @ -99,9 +122,9 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { | ||||||
|    */ |    */ | ||||||
|   GaussianMixtureFactor(const KeyVector &continuousKeys, |   GaussianMixtureFactor(const KeyVector &continuousKeys, | ||||||
|                         const DiscreteKeys &discreteKeys, |                         const DiscreteKeys &discreteKeys, | ||||||
|                         const std::vector<GaussianFactor::shared_ptr> &factors) |                         const std::vector<sharedFactor> &factors) | ||||||
|       : GaussianMixtureFactor(continuousKeys, discreteKeys, |       : GaussianMixtureFactor(continuousKeys, discreteKeys, | ||||||
|                               Factors(discreteKeys, factors)) {} |                               Mixture(discreteKeys, factors)) {} | ||||||
| 
 | 
 | ||||||
|   /// @}
 |   /// @}
 | ||||||
|   /// @name Testable
 |   /// @name Testable
 | ||||||
|  | @ -113,9 +136,11 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { | ||||||
|       const std::string &s = "GaussianMixtureFactor\n", |       const std::string &s = "GaussianMixtureFactor\n", | ||||||
|       const KeyFormatter &formatter = DefaultKeyFormatter) const override; |       const KeyFormatter &formatter = DefaultKeyFormatter) const override; | ||||||
|   /// @}
 |   /// @}
 | ||||||
|  |   /// @name Standard API
 | ||||||
|  |   /// @{
 | ||||||
| 
 | 
 | ||||||
|   /// Getter for the underlying Gaussian Factor Decision Tree.
 |   /// Getter for the underlying Gaussian Factor Decision Tree.
 | ||||||
|   const Factors &factors(); |   const Mixture factors() const; | ||||||
| 
 | 
 | ||||||
|   /**
 |   /**
 | ||||||
|    * @brief Combine the Gaussian Factor Graphs in `sum` and `this` while |    * @brief Combine the Gaussian Factor Graphs in `sum` and `this` while | ||||||
|  | @ -137,21 +162,17 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { | ||||||
|   AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const; |   AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const; | ||||||
| 
 | 
 | ||||||
|   /**
 |   /**
 | ||||||
|    * @brief Compute the error of this Gaussian Mixture given the continuous |    * @brief Compute the log-likelihood, including the log-normalizing constant. | ||||||
|    * values and a discrete assignment. |  | ||||||
|    * |  | ||||||
|    * @param continuousValues Continuous values at which to compute the error. |  | ||||||
|    * @param discreteValues The discrete assignment for a specific mode sequence. |  | ||||||
|    * @return double |    * @return double | ||||||
|    */ |    */ | ||||||
|   double error(const VectorValues &continuousValues, |   double error(const HybridValues &values) const override; | ||||||
|                const DiscreteValues &discreteValues) const; |  | ||||||
| 
 | 
 | ||||||
|   /// Add MixtureFactor to a Sum, syntactic sugar.
 |   /// Add MixtureFactor to a Sum, syntactic sugar.
 | ||||||
|   friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) { |   friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) { | ||||||
|     sum = factor.add(sum); |     sum = factor.add(sum); | ||||||
|     return sum; |     return sum; | ||||||
|   } |   } | ||||||
|  |   /// @}
 | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| // traits
 | // traits
 | ||||||
|  |  | ||||||
|  | @ -1,5 +1,5 @@ | ||||||
| /* ----------------------------------------------------------------------------
 | /* ----------------------------------------------------------------------------
 | ||||||
|  * GTSAM Copyright 2010, Georgia Tech Research Corporation, |  * GTSAM Copyright 2010-2022, Georgia Tech Research Corporation, | ||||||
|  * Atlanta, Georgia 30332-0415 |  * Atlanta, Georgia 30332-0415 | ||||||
|  * All Rights Reserved |  * All Rights Reserved | ||||||
|  * Authors: Frank Dellaert, et al. (see THANKS for the full author list) |  * Authors: Frank Dellaert, et al. (see THANKS for the full author list) | ||||||
|  | @ -12,6 +12,7 @@ | ||||||
|  * @author Fan Jiang |  * @author Fan Jiang | ||||||
|  * @author Varun Agrawal |  * @author Varun Agrawal | ||||||
|  * @author Shangjie Xue |  * @author Shangjie Xue | ||||||
|  |  * @author Frank Dellaert | ||||||
|  * @date   January 2022 |  * @date   January 2022 | ||||||
|  */ |  */ | ||||||
| 
 | 
 | ||||||
|  | @ -321,10 +322,9 @@ HybridValues HybridBayesNet::sample() const { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
| double HybridBayesNet::error(const VectorValues &continuousValues, | double HybridBayesNet::error(const HybridValues &values) const { | ||||||
|                              const DiscreteValues &discreteValues) const { |   GaussianBayesNet gbn = choose(values.discrete()); | ||||||
|   GaussianBayesNet gbn = choose(discreteValues); |   return gbn.error(values.continuous()); | ||||||
|   return gbn.error(continuousValues); |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
|  |  | ||||||
|  | @ -206,12 +206,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> { | ||||||
|    * @brief 0.5 * sum of squared Mahalanobis distances |    * @brief 0.5 * sum of squared Mahalanobis distances | ||||||
|    * for a specific discrete assignment. |    * for a specific discrete assignment. | ||||||
|    * |    * | ||||||
|    * @param continuousValues Continuous values at which to compute the error. |    * @param values Continuous values and discrete assignment. | ||||||
|    * @param discreteValues Discrete assignment for a specific mode sequence. |  | ||||||
|    * @return double |    * @return double | ||||||
|    */ |    */ | ||||||
|   double error(const VectorValues &continuousValues, |   double error(const HybridValues &values) const; | ||||||
|                const DiscreteValues &discreteValues) const; |  | ||||||
| 
 | 
 | ||||||
|   /**
 |   /**
 | ||||||
|    * @brief Compute conditional error for each discrete assignment, |    * @brief Compute conditional error for each discrete assignment, | ||||||
|  |  | ||||||
|  | @ -52,7 +52,7 @@ namespace gtsam { | ||||||
|  * having diamond inheritances, and neutralized the need to change other |  * having diamond inheritances, and neutralized the need to change other | ||||||
|  * components of GTSAM to make hybrid elimination work. |  * components of GTSAM to make hybrid elimination work. | ||||||
|  * |  * | ||||||
|  * A great reference to the type-erasure pattern is Eduaado Madrid's CppCon |  * A great reference to the type-erasure pattern is Eduardo Madrid's CppCon | ||||||
|  * talk (https://www.youtube.com/watch?v=s082Qmd_nHs).
 |  * talk (https://www.youtube.com/watch?v=s082Qmd_nHs).
 | ||||||
|  * |  * | ||||||
|  * @ingroup hybrid |  * @ingroup hybrid | ||||||
|  | @ -129,33 +129,6 @@ class GTSAM_EXPORT HybridConditional | ||||||
|    */ |    */ | ||||||
|   HybridConditional(boost::shared_ptr<GaussianMixture> gaussianMixture); |   HybridConditional(boost::shared_ptr<GaussianMixture> gaussianMixture); | ||||||
| 
 | 
 | ||||||
|   /**
 |  | ||||||
|    * @brief Return HybridConditional as a GaussianMixture |  | ||||||
|    * @return nullptr if not a mixture |  | ||||||
|    * @return GaussianMixture::shared_ptr otherwise |  | ||||||
|    */ |  | ||||||
|   GaussianMixture::shared_ptr asMixture() { |  | ||||||
|     return boost::dynamic_pointer_cast<GaussianMixture>(inner_); |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   /**
 |  | ||||||
|    * @brief Return HybridConditional as a GaussianConditional |  | ||||||
|    * @return nullptr if not a GaussianConditional |  | ||||||
|    * @return GaussianConditional::shared_ptr otherwise |  | ||||||
|    */ |  | ||||||
|   GaussianConditional::shared_ptr asGaussian() { |  | ||||||
|     return boost::dynamic_pointer_cast<GaussianConditional>(inner_); |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   /**
 |  | ||||||
|    * @brief Return conditional as a DiscreteConditional |  | ||||||
|    * @return nullptr if not a DiscreteConditional |  | ||||||
|    * @return DiscreteConditional::shared_ptr |  | ||||||
|    */ |  | ||||||
|   DiscreteConditional::shared_ptr asDiscrete() { |  | ||||||
|     return boost::dynamic_pointer_cast<DiscreteConditional>(inner_); |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   /// @}
 |   /// @}
 | ||||||
|   /// @name Testable
 |   /// @name Testable
 | ||||||
|   /// @{
 |   /// @{
 | ||||||
|  | @ -169,10 +142,52 @@ class GTSAM_EXPORT HybridConditional | ||||||
|   bool equals(const HybridFactor& other, double tol = 1e-9) const override; |   bool equals(const HybridFactor& other, double tol = 1e-9) const override; | ||||||
| 
 | 
 | ||||||
|   /// @}
 |   /// @}
 | ||||||
|  |   /// @name Standard Interface
 | ||||||
|  |   /// @{
 | ||||||
|  | 
 | ||||||
|  |   /**
 | ||||||
|  |    * @brief Return HybridConditional as a GaussianMixture | ||||||
|  |    * @return nullptr if not a mixture | ||||||
|  |    * @return GaussianMixture::shared_ptr otherwise | ||||||
|  |    */ | ||||||
|  |   GaussianMixture::shared_ptr asMixture() const { | ||||||
|  |     return boost::dynamic_pointer_cast<GaussianMixture>(inner_); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /**
 | ||||||
|  |    * @brief Return HybridConditional as a GaussianConditional | ||||||
|  |    * @return nullptr if not a GaussianConditional | ||||||
|  |    * @return GaussianConditional::shared_ptr otherwise | ||||||
|  |    */ | ||||||
|  |   GaussianConditional::shared_ptr asGaussian() const { | ||||||
|  |     return boost::dynamic_pointer_cast<GaussianConditional>(inner_); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /**
 | ||||||
|  |    * @brief Return conditional as a DiscreteConditional | ||||||
|  |    * @return nullptr if not a DiscreteConditional | ||||||
|  |    * @return DiscreteConditional::shared_ptr | ||||||
|  |    */ | ||||||
|  |   DiscreteConditional::shared_ptr asDiscrete() const { | ||||||
|  |     return boost::dynamic_pointer_cast<DiscreteConditional>(inner_); | ||||||
|  |   } | ||||||
| 
 | 
 | ||||||
|   /// Get the type-erased pointer to the inner type
 |   /// Get the type-erased pointer to the inner type
 | ||||||
|   boost::shared_ptr<Factor> inner() { return inner_; } |   boost::shared_ptr<Factor> inner() { return inner_; } | ||||||
| 
 | 
 | ||||||
|  |   /// Return the error of the underlying conditional.
 | ||||||
|  |   /// Currently only implemented for Gaussian mixture.
 | ||||||
|  |   double error(const HybridValues& values) const override { | ||||||
|  |     if (auto gm = asMixture()) { | ||||||
|  |       return gm->error(values); | ||||||
|  |     } else { | ||||||
|  |       throw std::runtime_error( | ||||||
|  |           "HybridConditional::error: only implemented for Gaussian mixture"); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /// @}
 | ||||||
|  | 
 | ||||||
|  private: |  private: | ||||||
|   /** Serialization function */ |   /** Serialization function */ | ||||||
|   friend class boost::serialization::access; |   friend class boost::serialization::access; | ||||||
|  |  | ||||||
|  | @ -17,6 +17,7 @@ | ||||||
|  */ |  */ | ||||||
| 
 | 
 | ||||||
| #include <gtsam/hybrid/HybridDiscreteFactor.h> | #include <gtsam/hybrid/HybridDiscreteFactor.h> | ||||||
|  | #include <gtsam/hybrid/HybridValues.h> | ||||||
| 
 | 
 | ||||||
| #include <boost/make_shared.hpp> | #include <boost/make_shared.hpp> | ||||||
| 
 | 
 | ||||||
|  | @ -50,4 +51,10 @@ void HybridDiscreteFactor::print(const std::string &s, | ||||||
|   inner_->print("\n", formatter); |   inner_->print("\n", formatter); | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
|  | /* ************************************************************************ */ | ||||||
|  | double HybridDiscreteFactor::error(const HybridValues &values) const { | ||||||
|  |   return -log((*inner_)(values.discrete())); | ||||||
|  | } | ||||||
|  | /* ************************************************************************ */ | ||||||
|  | 
 | ||||||
| }  // namespace gtsam
 | }  // namespace gtsam
 | ||||||
|  |  | ||||||
|  | @ -24,10 +24,12 @@ | ||||||
| 
 | 
 | ||||||
| namespace gtsam { | namespace gtsam { | ||||||
| 
 | 
 | ||||||
|  | class HybridValues; | ||||||
|  | 
 | ||||||
| /**
 | /**
 | ||||||
|  * A HybridDiscreteFactor is a thin container for DiscreteFactor, which allows |  * A HybridDiscreteFactor is a thin container for DiscreteFactor, which | ||||||
|  * us to hide the implementation of DiscreteFactor and thus avoid diamond |  * allows us to hide the implementation of DiscreteFactor and thus avoid | ||||||
|  * inheritance. |  * diamond inheritance. | ||||||
|  * |  * | ||||||
|  * @ingroup hybrid |  * @ingroup hybrid | ||||||
|  */ |  */ | ||||||
|  | @ -59,9 +61,15 @@ class GTSAM_EXPORT HybridDiscreteFactor : public HybridFactor { | ||||||
|       const KeyFormatter &formatter = DefaultKeyFormatter) const override; |       const KeyFormatter &formatter = DefaultKeyFormatter) const override; | ||||||
| 
 | 
 | ||||||
|   /// @}
 |   /// @}
 | ||||||
|  |   /// @name Standard Interface
 | ||||||
|  |   /// @{
 | ||||||
| 
 | 
 | ||||||
|   /// Return pointer to the internal discrete factor
 |   /// Return pointer to the internal discrete factor
 | ||||||
|   DiscreteFactor::shared_ptr inner() const { return inner_; } |   DiscreteFactor::shared_ptr inner() const { return inner_; } | ||||||
|  | 
 | ||||||
|  |   /// Return the error of the underlying Discrete Factor.
 | ||||||
|  |   double error(const HybridValues &values) const override; | ||||||
|  |   /// @}
 | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| // traits
 | // traits
 | ||||||
|  |  | ||||||
|  | @ -26,6 +26,8 @@ | ||||||
| #include <string> | #include <string> | ||||||
| namespace gtsam { | namespace gtsam { | ||||||
| 
 | 
 | ||||||
|  | class HybridValues; | ||||||
|  | 
 | ||||||
| KeyVector CollectKeys(const KeyVector &continuousKeys, | KeyVector CollectKeys(const KeyVector &continuousKeys, | ||||||
|                       const DiscreteKeys &discreteKeys); |                       const DiscreteKeys &discreteKeys); | ||||||
| KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2); | KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2); | ||||||
|  | @ -110,6 +112,15 @@ class GTSAM_EXPORT HybridFactor : public Factor { | ||||||
|   /// @name Standard Interface
 |   /// @name Standard Interface
 | ||||||
|   /// @{
 |   /// @{
 | ||||||
| 
 | 
 | ||||||
|  |   /**
 | ||||||
|  |    * @brief Compute the error of this Gaussian Mixture given the continuous | ||||||
|  |    * values and a discrete assignment. | ||||||
|  |    * | ||||||
|  |    * @param values Continuous values and discrete assignment. | ||||||
|  |    * @return double | ||||||
|  |    */ | ||||||
|  |   virtual double error(const HybridValues &values) const = 0; | ||||||
|  | 
 | ||||||
|   /// True if this is a factor of discrete variables only.
 |   /// True if this is a factor of discrete variables only.
 | ||||||
|   bool isDiscrete() const { return isDiscrete_; } |   bool isDiscrete() const { return isDiscrete_; } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -16,6 +16,7 @@ | ||||||
|  */ |  */ | ||||||
| 
 | 
 | ||||||
| #include <gtsam/hybrid/HybridGaussianFactor.h> | #include <gtsam/hybrid/HybridGaussianFactor.h> | ||||||
|  | #include <gtsam/hybrid/HybridValues.h> | ||||||
| #include <gtsam/linear/HessianFactor.h> | #include <gtsam/linear/HessianFactor.h> | ||||||
| #include <gtsam/linear/JacobianFactor.h> | #include <gtsam/linear/JacobianFactor.h> | ||||||
| 
 | 
 | ||||||
|  | @ -54,4 +55,10 @@ void HybridGaussianFactor::print(const std::string &s, | ||||||
|   inner_->print("\n", formatter); |   inner_->print("\n", formatter); | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
|  | /* ************************************************************************ */ | ||||||
|  | double HybridGaussianFactor::error(const HybridValues &values) const { | ||||||
|  |   return inner_->error(values.continuous()); | ||||||
|  | } | ||||||
|  | /* ************************************************************************ */ | ||||||
|  | 
 | ||||||
| }  // namespace gtsam
 | }  // namespace gtsam
 | ||||||
|  |  | ||||||
|  | @ -25,6 +25,7 @@ namespace gtsam { | ||||||
| // Forward declarations
 | // Forward declarations
 | ||||||
| class JacobianFactor; | class JacobianFactor; | ||||||
| class HessianFactor; | class HessianFactor; | ||||||
|  | class HybridValues; | ||||||
| 
 | 
 | ||||||
| /**
 | /**
 | ||||||
|  * A HybridGaussianFactor is a layer over GaussianFactor so that we do not have |  * A HybridGaussianFactor is a layer over GaussianFactor so that we do not have | ||||||
|  | @ -92,8 +93,15 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { | ||||||
|       const KeyFormatter &formatter = DefaultKeyFormatter) const override; |       const KeyFormatter &formatter = DefaultKeyFormatter) const override; | ||||||
| 
 | 
 | ||||||
|   /// @}
 |   /// @}
 | ||||||
|  |   /// @name Standard Interface
 | ||||||
|  |   /// @{
 | ||||||
| 
 | 
 | ||||||
|  |   /// Return pointer to the internal discrete factor
 | ||||||
|   GaussianFactor::shared_ptr inner() const { return inner_; } |   GaussianFactor::shared_ptr inner() const { return inner_; } | ||||||
|  | 
 | ||||||
|  |   /// Return the error of the underlying Discrete Factor.
 | ||||||
|  |   double error(const HybridValues &values) const override; | ||||||
|  |   /// @}
 | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| // traits
 | // traits
 | ||||||
|  |  | ||||||
|  | @ -55,13 +55,14 @@ | ||||||
| 
 | 
 | ||||||
| namespace gtsam { | namespace gtsam { | ||||||
| 
 | 
 | ||||||
|  | /// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
 | ||||||
| template class EliminateableFactorGraph<HybridGaussianFactorGraph>; | template class EliminateableFactorGraph<HybridGaussianFactorGraph>; | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************ */ | /* ************************************************************************ */ | ||||||
| static GaussianMixtureFactor::Sum &addGaussian( | static GaussianMixtureFactor::Sum &addGaussian( | ||||||
|     GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) { |     GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) { | ||||||
|   using Y = GaussianFactorGraph; |   using Y = GaussianFactorGraph; | ||||||
|   // If the decision tree is not intiialized, then intialize it.
 |   // If the decision tree is not initialized, then initialize it.
 | ||||||
|   if (sum.empty()) { |   if (sum.empty()) { | ||||||
|     GaussianFactorGraph result; |     GaussianFactorGraph result; | ||||||
|     result.push_back(factor); |     result.push_back(factor); | ||||||
|  | @ -89,8 +90,9 @@ GaussianMixtureFactor::Sum sumFrontals( | ||||||
| 
 | 
 | ||||||
|   for (auto &f : factors) { |   for (auto &f : factors) { | ||||||
|     if (f->isHybrid()) { |     if (f->isHybrid()) { | ||||||
|       if (auto cgmf = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f)) { |       // TODO(dellaert): just use a virtual method defined in HybridFactor.
 | ||||||
|         sum = cgmf->add(sum); |       if (auto gm = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f)) { | ||||||
|  |         sum = gm->add(sum); | ||||||
|       } |       } | ||||||
|       if (auto gm = boost::dynamic_pointer_cast<HybridConditional>(f)) { |       if (auto gm = boost::dynamic_pointer_cast<HybridConditional>(f)) { | ||||||
|         sum = gm->asMixture()->add(sum); |         sum = gm->asMixture()->add(sum); | ||||||
|  | @ -184,7 +186,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors, | ||||||
|                   const KeySet &continuousSeparator, |                   const KeySet &continuousSeparator, | ||||||
|                   const std::set<DiscreteKey> &discreteSeparatorSet) { |                   const std::set<DiscreteKey> &discreteSeparatorSet) { | ||||||
|   // NOTE: since we use the special JunctionTree,
 |   // NOTE: since we use the special JunctionTree,
 | ||||||
|   // only possiblity is continuous conditioned on discrete.
 |   // only possibility is continuous conditioned on discrete.
 | ||||||
|   DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(), |   DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(), | ||||||
|                                  discreteSeparatorSet.end()); |                                  discreteSeparatorSet.end()); | ||||||
| 
 | 
 | ||||||
|  | @ -204,16 +206,16 @@ hybridElimination(const HybridGaussianFactorGraph &factors, | ||||||
|   }; |   }; | ||||||
|   sum = GaussianMixtureFactor::Sum(sum, emptyGaussian); |   sum = GaussianMixtureFactor::Sum(sum, emptyGaussian); | ||||||
| 
 | 
 | ||||||
|   using EliminationPair = GaussianFactorGraph::EliminationResult; |   using EliminationPair = std::pair<boost::shared_ptr<GaussianConditional>, | ||||||
|  |                                     GaussianMixtureFactor::FactorAndConstant>; | ||||||
| 
 | 
 | ||||||
|   KeyVector keysOfEliminated;  // Not the ordering
 |   KeyVector keysOfEliminated;  // Not the ordering
 | ||||||
|   KeyVector keysOfSeparator;   // TODO(frank): Is this just (keys - ordering)?
 |   KeyVector keysOfSeparator;   // TODO(frank): Is this just (keys - ordering)?
 | ||||||
| 
 | 
 | ||||||
|   // This is the elimination method on the leaf nodes
 |   // This is the elimination method on the leaf nodes
 | ||||||
|   auto eliminate = [&](const GaussianFactorGraph &graph) |   auto eliminate = [&](const GaussianFactorGraph &graph) -> EliminationPair { | ||||||
|       -> GaussianFactorGraph::EliminationResult { |  | ||||||
|     if (graph.empty()) { |     if (graph.empty()) { | ||||||
|       return {nullptr, nullptr}; |       return {nullptr, {nullptr, 0.0}}; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| #ifdef HYBRID_TIMING | #ifdef HYBRID_TIMING | ||||||
|  | @ -222,18 +224,18 @@ hybridElimination(const HybridGaussianFactorGraph &factors, | ||||||
| 
 | 
 | ||||||
|     std::pair<boost::shared_ptr<GaussianConditional>, |     std::pair<boost::shared_ptr<GaussianConditional>, | ||||||
|               boost::shared_ptr<GaussianFactor>> |               boost::shared_ptr<GaussianFactor>> | ||||||
|         result = EliminatePreferCholesky(graph, frontalKeys); |         conditional_factor = EliminatePreferCholesky(graph, frontalKeys); | ||||||
| 
 | 
 | ||||||
|     // Initialize the keysOfEliminated to be the keys of the
 |     // Initialize the keysOfEliminated to be the keys of the
 | ||||||
|     // eliminated GaussianConditional
 |     // eliminated GaussianConditional
 | ||||||
|     keysOfEliminated = result.first->keys(); |     keysOfEliminated = conditional_factor.first->keys(); | ||||||
|     keysOfSeparator = result.second->keys(); |     keysOfSeparator = conditional_factor.second->keys(); | ||||||
| 
 | 
 | ||||||
| #ifdef HYBRID_TIMING | #ifdef HYBRID_TIMING | ||||||
|     gttoc_(hybrid_eliminate); |     gttoc_(hybrid_eliminate); | ||||||
| #endif | #endif | ||||||
| 
 | 
 | ||||||
|     return result; |     return {conditional_factor.first, {conditional_factor.second, 0.0}}; | ||||||
|   }; |   }; | ||||||
| 
 | 
 | ||||||
|   // Perform elimination!
 |   // Perform elimination!
 | ||||||
|  | @ -246,8 +248,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors, | ||||||
| 
 | 
 | ||||||
|   // Separate out decision tree into conditionals and remaining factors.
 |   // Separate out decision tree into conditionals and remaining factors.
 | ||||||
|   auto pair = unzip(eliminationResults); |   auto pair = unzip(eliminationResults); | ||||||
| 
 |   const auto &separatorFactors = pair.second; | ||||||
|   const GaussianMixtureFactor::Factors &separatorFactors = pair.second; |  | ||||||
| 
 | 
 | ||||||
|   // Create the GaussianMixture from the conditionals
 |   // Create the GaussianMixture from the conditionals
 | ||||||
|   auto conditional = boost::make_shared<GaussianMixture>( |   auto conditional = boost::make_shared<GaussianMixture>( | ||||||
|  | @ -257,16 +258,19 @@ hybridElimination(const HybridGaussianFactorGraph &factors, | ||||||
|   // DiscreteFactor, with the error for each discrete choice.
 |   // DiscreteFactor, with the error for each discrete choice.
 | ||||||
|   if (keysOfSeparator.empty()) { |   if (keysOfSeparator.empty()) { | ||||||
|     VectorValues empty_values; |     VectorValues empty_values; | ||||||
|     auto factorProb = [&](const GaussianFactor::shared_ptr &factor) { |     auto factorProb = | ||||||
|       if (!factor) { |         [&](const GaussianMixtureFactor::FactorAndConstant &factor_z) { | ||||||
|         return 0.0;  // If nullptr, return 0.0 probability
 |           GaussianFactor::shared_ptr factor = factor_z.factor; | ||||||
|       } else { |           if (!factor) { | ||||||
|         // This is the probability q(μ) at the MLE point.
 |             return 0.0;  // If nullptr, return 0.0 probability
 | ||||||
|         double error = |           } else { | ||||||
|             0.5 * std::abs(factor->augmentedInformation().determinant()); |             // This is the probability q(μ) at the MLE point.
 | ||||||
|         return std::exp(-error); |             double error = | ||||||
|       } |                 0.5 * std::abs(factor->augmentedInformation().determinant()) + | ||||||
|     }; |                 factor_z.constant; | ||||||
|  |             return std::exp(-error); | ||||||
|  |           } | ||||||
|  |         }; | ||||||
|     DecisionTree<Key, double> fdt(separatorFactors, factorProb); |     DecisionTree<Key, double> fdt(separatorFactors, factorProb); | ||||||
| 
 | 
 | ||||||
|     auto discreteFactor = |     auto discreteFactor = | ||||||
|  | @ -452,6 +456,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error( | ||||||
| 
 | 
 | ||||||
|   // Iterate over each factor.
 |   // Iterate over each factor.
 | ||||||
|   for (size_t idx = 0; idx < size(); idx++) { |   for (size_t idx = 0; idx < size(); idx++) { | ||||||
|  |     // TODO(dellaert): just use a virtual method defined in HybridFactor.
 | ||||||
|     AlgebraicDecisionTree<Key> factor_error; |     AlgebraicDecisionTree<Key> factor_error; | ||||||
| 
 | 
 | ||||||
|     if (factors_.at(idx)->isHybrid()) { |     if (factors_.at(idx)->isHybrid()) { | ||||||
|  | @ -491,38 +496,17 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error( | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************ */ | /* ************************************************************************ */ | ||||||
| double HybridGaussianFactorGraph::error( | double HybridGaussianFactorGraph::error(const HybridValues &values) const { | ||||||
|     const VectorValues &continuousValues, |  | ||||||
|     const DiscreteValues &discreteValues) const { |  | ||||||
|   double error = 0.0; |   double error = 0.0; | ||||||
|   for (size_t idx = 0; idx < size(); idx++) { |   for (auto &factor : factors_) { | ||||||
|     auto factor = factors_.at(idx); |     error += factor->error(values); | ||||||
| 
 |  | ||||||
|     if (factor->isHybrid()) { |  | ||||||
|       if (auto c = boost::dynamic_pointer_cast<HybridConditional>(factor)) { |  | ||||||
|         error += c->asMixture()->error(continuousValues, discreteValues); |  | ||||||
|       } |  | ||||||
|       if (auto f = boost::dynamic_pointer_cast<GaussianMixtureFactor>(factor)) { |  | ||||||
|         error += f->error(continuousValues, discreteValues); |  | ||||||
|       } |  | ||||||
| 
 |  | ||||||
|     } else if (factor->isContinuous()) { |  | ||||||
|       if (auto f = boost::dynamic_pointer_cast<HybridGaussianFactor>(factor)) { |  | ||||||
|         error += f->inner()->error(continuousValues); |  | ||||||
|       } |  | ||||||
|       if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(factor)) { |  | ||||||
|         error += cg->asGaussian()->error(continuousValues); |  | ||||||
|       } |  | ||||||
|     } |  | ||||||
|   } |   } | ||||||
|   return error; |   return error; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************ */ | /* ************************************************************************ */ | ||||||
| double HybridGaussianFactorGraph::probPrime( | double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const { | ||||||
|     const VectorValues &continuousValues, |   double error = this->error(values); | ||||||
|     const DiscreteValues &discreteValues) const { |  | ||||||
|   double error = this->error(continuousValues, discreteValues); |  | ||||||
|   // NOTE: The 0.5 term is handled by each factor
 |   // NOTE: The 0.5 term is handled by each factor
 | ||||||
|   return std::exp(-error); |   return std::exp(-error); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -12,7 +12,7 @@ | ||||||
| /**
 | /**
 | ||||||
|  * @file   HybridGaussianFactorGraph.h |  * @file   HybridGaussianFactorGraph.h | ||||||
|  * @brief  Linearized Hybrid factor graph that uses type erasure |  * @brief  Linearized Hybrid factor graph that uses type erasure | ||||||
|  * @author Fan Jiang, Varun Agrawal |  * @author Fan Jiang, Varun Agrawal, Frank Dellaert | ||||||
|  * @date   Mar 11, 2022 |  * @date   Mar 11, 2022 | ||||||
|  */ |  */ | ||||||
| 
 | 
 | ||||||
|  | @ -38,6 +38,7 @@ class HybridBayesTree; | ||||||
| class HybridJunctionTree; | class HybridJunctionTree; | ||||||
| class DecisionTreeFactor; | class DecisionTreeFactor; | ||||||
| class JacobianFactor; | class JacobianFactor; | ||||||
|  | class HybridValues; | ||||||
| 
 | 
 | ||||||
| /**
 | /**
 | ||||||
|  * @brief Main elimination function for HybridGaussianFactorGraph. |  * @brief Main elimination function for HybridGaussianFactorGraph. | ||||||
|  | @ -186,14 +187,9 @@ class GTSAM_EXPORT HybridGaussianFactorGraph | ||||||
|    * @brief Compute error given a continuous vector values |    * @brief Compute error given a continuous vector values | ||||||
|    * and a discrete assignment. |    * and a discrete assignment. | ||||||
|    * |    * | ||||||
|    * @param continuousValues The continuous VectorValues |  | ||||||
|    * for computing the error. |  | ||||||
|    * @param discreteValues The specific discrete assignment |  | ||||||
|    * whose error we wish to compute. |  | ||||||
|    * @return double |    * @return double | ||||||
|    */ |    */ | ||||||
|   double error(const VectorValues& continuousValues, |   double error(const HybridValues& values) const; | ||||||
|                const DiscreteValues& discreteValues) const; |  | ||||||
| 
 | 
 | ||||||
|   /**
 |   /**
 | ||||||
|    * @brief Compute unnormalized probability \f$ P(X | M, Z) \f$ |    * @brief Compute unnormalized probability \f$ P(X | M, Z) \f$ | ||||||
|  | @ -210,13 +206,9 @@ class GTSAM_EXPORT HybridGaussianFactorGraph | ||||||
|    * @brief Compute the unnormalized posterior probability for a continuous |    * @brief Compute the unnormalized posterior probability for a continuous | ||||||
|    * vector values given a specific assignment. |    * vector values given a specific assignment. | ||||||
|    * |    * | ||||||
|    * @param continuousValues The vector values for which to compute the |  | ||||||
|    * posterior probability. |  | ||||||
|    * @param discreteValues The specific assignment to use for the computation. |  | ||||||
|    * @return double |    * @return double | ||||||
|    */ |    */ | ||||||
|   double probPrime(const VectorValues& continuousValues, |   double probPrime(const HybridValues& values) const; | ||||||
|                    const DiscreteValues& discreteValues) const; |  | ||||||
| 
 | 
 | ||||||
|   /**
 |   /**
 | ||||||
|    * @brief Return a Colamd constrained ordering where the discrete keys are |    * @brief Return a Colamd constrained ordering where the discrete keys are | ||||||
|  |  | ||||||
|  | @ -51,12 +51,22 @@ class HybridNonlinearFactor : public HybridFactor { | ||||||
|       const KeyFormatter &formatter = DefaultKeyFormatter) const override; |       const KeyFormatter &formatter = DefaultKeyFormatter) const override; | ||||||
| 
 | 
 | ||||||
|   /// @}
 |   /// @}
 | ||||||
|  |   /// @name Standard Interface
 | ||||||
|  |   /// @{
 | ||||||
| 
 | 
 | ||||||
|   NonlinearFactor::shared_ptr inner() const { return inner_; } |   NonlinearFactor::shared_ptr inner() const { return inner_; } | ||||||
| 
 | 
 | ||||||
|  |   /// Error for HybridValues is not provided for nonlinear factor.
 | ||||||
|  |   double error(const HybridValues &values) const override { | ||||||
|  |     throw std::runtime_error( | ||||||
|  |         "HybridNonlinearFactor::error(HybridValues) not implemented."); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   /// Linearize to a HybridGaussianFactor at the linearization point `c`.
 |   /// Linearize to a HybridGaussianFactor at the linearization point `c`.
 | ||||||
|   boost::shared_ptr<HybridGaussianFactor> linearize(const Values &c) const { |   boost::shared_ptr<HybridGaussianFactor> linearize(const Values &c) const { | ||||||
|     return boost::make_shared<HybridGaussianFactor>(inner_->linearize(c)); |     return boost::make_shared<HybridGaussianFactor>(inner_->linearize(c)); | ||||||
|   } |   } | ||||||
|  | 
 | ||||||
|  |   /// @}
 | ||||||
| }; | }; | ||||||
| }  // namespace gtsam
 | }  // namespace gtsam
 | ||||||
|  |  | ||||||
|  | @ -161,6 +161,12 @@ class MixtureFactor : public HybridFactor { | ||||||
|                              factor, continuousValues); |                              factor, continuousValues); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |   /// Error for HybridValues is not provided for nonlinear hybrid factor.
 | ||||||
|  |   double error(const HybridValues &values) const override { | ||||||
|  |     throw std::runtime_error( | ||||||
|  |         "MixtureFactor::error(HybridValues) not implemented."); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   size_t dim() const { |   size_t dim() const { | ||||||
|     // TODO(Varun)
 |     // TODO(Varun)
 | ||||||
|     throw std::runtime_error("MixtureFactor::dim not implemented."); |     throw std::runtime_error("MixtureFactor::dim not implemented."); | ||||||
|  |  | ||||||
|  | @ -183,10 +183,8 @@ class HybridGaussianFactorGraph { | ||||||
|   bool equals(const gtsam::HybridGaussianFactorGraph& fg, double tol = 1e-9) const; |   bool equals(const gtsam::HybridGaussianFactorGraph& fg, double tol = 1e-9) const; | ||||||
| 
 | 
 | ||||||
|   // evaluation |   // evaluation | ||||||
|   double error(const gtsam::VectorValues& continuousValues, |   double error(const gtsam::HybridValues& values) const; | ||||||
|                const gtsam::DiscreteValues& discreteValues) const; |   double probPrime(const gtsam::HybridValues& values) const; | ||||||
|   double probPrime(const gtsam::VectorValues& continuousValues, |  | ||||||
|                    const gtsam::DiscreteValues& discreteValues) const; |  | ||||||
| 
 | 
 | ||||||
|   gtsam::HybridBayesNet* eliminateSequential(); |   gtsam::HybridBayesNet* eliminateSequential(); | ||||||
|   gtsam::HybridBayesNet* eliminateSequential( |   gtsam::HybridBayesNet* eliminateSequential( | ||||||
|  |  | ||||||
|  | @ -128,9 +128,9 @@ TEST(GaussianMixture, Error) { | ||||||
|   // Regression for non-tree version.
 |   // Regression for non-tree version.
 | ||||||
|   DiscreteValues assignment; |   DiscreteValues assignment; | ||||||
|   assignment[M(1)] = 0; |   assignment[M(1)] = 0; | ||||||
|   EXPECT_DOUBLES_EQUAL(0.5, mixture.error(values, assignment), 1e-8); |   EXPECT_DOUBLES_EQUAL(0.5, mixture.error({values, assignment}), 1e-8); | ||||||
|   assignment[M(1)] = 1; |   assignment[M(1)] = 1; | ||||||
|   EXPECT_DOUBLES_EQUAL(4.3252595155709335, mixture.error(values, assignment), |   EXPECT_DOUBLES_EQUAL(4.3252595155709335, mixture.error({values, assignment}), | ||||||
|                        1e-8); |                        1e-8); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -179,7 +179,9 @@ TEST(GaussianMixture, Likelihood) { | ||||||
|   const GaussianMixtureFactor::Factors factors( |   const GaussianMixtureFactor::Factors factors( | ||||||
|       gm.conditionals(), |       gm.conditionals(), | ||||||
|       [measurements](const GaussianConditional::shared_ptr& conditional) { |       [measurements](const GaussianConditional::shared_ptr& conditional) { | ||||||
|         return conditional->likelihood(measurements); |         return GaussianMixtureFactor::FactorAndConstant{ | ||||||
|  |             conditional->likelihood(measurements), | ||||||
|  |             conditional->logNormalizationConstant()}; | ||||||
|       }); |       }); | ||||||
|   const GaussianMixtureFactor expected({X(0)}, {mode}, factors); |   const GaussianMixtureFactor expected({X(0)}, {mode}, factors); | ||||||
|   EXPECT(assert_equal(*factor, expected)); |   EXPECT(assert_equal(*factor, expected)); | ||||||
|  |  | ||||||
|  | @ -22,6 +22,7 @@ | ||||||
| #include <gtsam/discrete/DiscreteValues.h> | #include <gtsam/discrete/DiscreteValues.h> | ||||||
| #include <gtsam/hybrid/GaussianMixture.h> | #include <gtsam/hybrid/GaussianMixture.h> | ||||||
| #include <gtsam/hybrid/GaussianMixtureFactor.h> | #include <gtsam/hybrid/GaussianMixtureFactor.h> | ||||||
|  | #include <gtsam/hybrid/HybridValues.h> | ||||||
| #include <gtsam/inference/Symbol.h> | #include <gtsam/inference/Symbol.h> | ||||||
| #include <gtsam/linear/GaussianFactorGraph.h> | #include <gtsam/linear/GaussianFactorGraph.h> | ||||||
| 
 | 
 | ||||||
|  | @ -188,7 +189,8 @@ TEST(GaussianMixtureFactor, Error) { | ||||||
|   DiscreteValues discreteValues; |   DiscreteValues discreteValues; | ||||||
|   discreteValues[m1.first] = 1; |   discreteValues[m1.first] = 1; | ||||||
|   EXPECT_DOUBLES_EQUAL( |   EXPECT_DOUBLES_EQUAL( | ||||||
|       4.0, mixtureFactor.error(continuousValues, discreteValues), 1e-9); |       4.0, mixtureFactor.error({continuousValues, discreteValues}), | ||||||
|  |       1e-9); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
|  |  | ||||||
|  | @ -188,14 +188,14 @@ TEST(HybridBayesNet, Optimize) { | ||||||
| 
 | 
 | ||||||
|   HybridValues delta = hybridBayesNet->optimize(); |   HybridValues delta = hybridBayesNet->optimize(); | ||||||
| 
 | 
 | ||||||
|   //TODO(Varun) The expectedAssignment should be 111, not 101
 |   // TODO(Varun) The expectedAssignment should be 111, not 101
 | ||||||
|   DiscreteValues expectedAssignment; |   DiscreteValues expectedAssignment; | ||||||
|   expectedAssignment[M(0)] = 1; |   expectedAssignment[M(0)] = 1; | ||||||
|   expectedAssignment[M(1)] = 0; |   expectedAssignment[M(1)] = 0; | ||||||
|   expectedAssignment[M(2)] = 1; |   expectedAssignment[M(2)] = 1; | ||||||
|   EXPECT(assert_equal(expectedAssignment, delta.discrete())); |   EXPECT(assert_equal(expectedAssignment, delta.discrete())); | ||||||
| 
 | 
 | ||||||
|   //TODO(Varun) This should be all -Vector1::Ones()
 |   // TODO(Varun) This should be all -Vector1::Ones()
 | ||||||
|   VectorValues expectedValues; |   VectorValues expectedValues; | ||||||
|   expectedValues.insert(X(0), -0.999904 * Vector1::Ones()); |   expectedValues.insert(X(0), -0.999904 * Vector1::Ones()); | ||||||
|   expectedValues.insert(X(1), -0.99029 * Vector1::Ones()); |   expectedValues.insert(X(1), -0.99029 * Vector1::Ones()); | ||||||
|  | @ -243,8 +243,8 @@ TEST(HybridBayesNet, Error) { | ||||||
|   double total_error = 0; |   double total_error = 0; | ||||||
|   for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) { |   for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) { | ||||||
|     if (hybridBayesNet->at(idx)->isHybrid()) { |     if (hybridBayesNet->at(idx)->isHybrid()) { | ||||||
|       double error = hybridBayesNet->atMixture(idx)->error(delta.continuous(), |       double error = hybridBayesNet->atMixture(idx)->error( | ||||||
|                                                            discrete_values); |           {delta.continuous(), discrete_values}); | ||||||
|       total_error += error; |       total_error += error; | ||||||
|     } else if (hybridBayesNet->at(idx)->isContinuous()) { |     } else if (hybridBayesNet->at(idx)->isContinuous()) { | ||||||
|       double error = hybridBayesNet->atGaussian(idx)->error(delta.continuous()); |       double error = hybridBayesNet->atGaussian(idx)->error(delta.continuous()); | ||||||
|  | @ -253,7 +253,7 @@ TEST(HybridBayesNet, Error) { | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   EXPECT_DOUBLES_EQUAL( |   EXPECT_DOUBLES_EQUAL( | ||||||
|       total_error, hybridBayesNet->error(delta.continuous(), discrete_values), |       total_error, hybridBayesNet->error({delta.continuous(), discrete_values}), | ||||||
|       1e-9); |       1e-9); | ||||||
|   EXPECT_DOUBLES_EQUAL(total_error, error_tree(discrete_values), 1e-9); |   EXPECT_DOUBLES_EQUAL(total_error, error_tree(discrete_values), 1e-9); | ||||||
|   EXPECT_DOUBLES_EQUAL(total_error, pruned_error_tree(discrete_values), 1e-9); |   EXPECT_DOUBLES_EQUAL(total_error, pruned_error_tree(discrete_values), 1e-9); | ||||||
|  |  | ||||||
|  | @ -273,7 +273,7 @@ AlgebraicDecisionTree<Key> getProbPrimeTree( | ||||||
|       continue; |       continue; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     double error = graph.error(delta, assignment); |     double error = graph.error({delta, assignment}); | ||||||
|     probPrimes.push_back(exp(-error)); |     probPrimes.push_back(exp(-error)); | ||||||
|   } |   } | ||||||
|   AlgebraicDecisionTree<Key> probPrimeTree(discrete_keys, probPrimes); |   AlgebraicDecisionTree<Key> probPrimeTree(discrete_keys, probPrimes); | ||||||
|  | @ -487,8 +487,8 @@ TEST(HybridEstimation, CorrectnessViaSampling) { | ||||||
|          const HybridValues& sample) -> double { |          const HybridValues& sample) -> double { | ||||||
|     const DiscreteValues assignment = sample.discrete(); |     const DiscreteValues assignment = sample.discrete(); | ||||||
|     // Compute in log form for numerical stability
 |     // Compute in log form for numerical stability
 | ||||||
|     double log_ratio = bayesNet->error(sample.continuous(), assignment) - |     double log_ratio = bayesNet->error({sample.continuous(), assignment}) - | ||||||
|                        factorGraph->error(sample.continuous(), assignment); |                        factorGraph->error({sample.continuous(), assignment}); | ||||||
|     double ratio = exp(-log_ratio); |     double ratio = exp(-log_ratio); | ||||||
|     return ratio; |     return ratio; | ||||||
|   }; |   }; | ||||||
|  |  | ||||||
|  | @ -575,18 +575,14 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrime) { | ||||||
|   HybridBayesNet::shared_ptr hybridBayesNet = |   HybridBayesNet::shared_ptr hybridBayesNet = | ||||||
|       graph.eliminateSequential(hybridOrdering); |       graph.eliminateSequential(hybridOrdering); | ||||||
| 
 | 
 | ||||||
|   HybridValues delta = hybridBayesNet->optimize(); |   const HybridValues delta = hybridBayesNet->optimize(); | ||||||
|   double error = graph.error(delta.continuous(), delta.discrete()); |   const double error = graph.error(delta); | ||||||
| 
 |  | ||||||
|   double expected_error = 0.490243199; |  | ||||||
|   // regression
 |  | ||||||
|   EXPECT(assert_equal(expected_error, error, 1e-9)); |  | ||||||
| 
 |  | ||||||
|   double probs = exp(-error); |  | ||||||
|   double expected_probs = graph.probPrime(delta.continuous(), delta.discrete()); |  | ||||||
| 
 | 
 | ||||||
|   // regression
 |   // regression
 | ||||||
|   EXPECT(assert_equal(expected_probs, probs, 1e-7)); |   EXPECT(assert_equal(1.58886, error, 1e-5)); | ||||||
|  | 
 | ||||||
|  |   // Real test:
 | ||||||
|  |   EXPECT(assert_equal(graph.probPrime(delta), exp(-error), 1e-7)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /* ****************************************************************************/ | /* ****************************************************************************/ | ||||||
|  |  | ||||||
|  | @ -168,26 +168,30 @@ namespace gtsam { | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
| double GaussianConditional::logDeterminant() const { | double GaussianConditional::logDeterminant() const { | ||||||
|   double logDet; |   if (get_model()) { | ||||||
|   if (this->get_model()) { |     Vector diag = R().diagonal(); | ||||||
|     Vector diag = this->R().diagonal(); |     get_model()->whitenInPlace(diag); | ||||||
|     this->get_model()->whitenInPlace(diag); |     return diag.unaryExpr([](double x) { return log(x); }).sum(); | ||||||
|     logDet = diag.unaryExpr([](double x) { return log(x); }).sum(); |  | ||||||
|   } else { |   } else { | ||||||
|     logDet = |     return R().diagonal().unaryExpr([](double x) { return log(x); }).sum(); | ||||||
|         this->R().diagonal().unaryExpr([](double x) { return log(x); }).sum(); |  | ||||||
|   } |   } | ||||||
|   return logDet; |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
| //  density = exp(-error(x)) / sqrt((2*pi)^n*det(Sigma))
 | //  normalization constant = 1.0 / sqrt((2*pi)^n*det(Sigma))
 | ||||||
| //  log = -error(x) - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma)
 | //  log = - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma)
 | ||||||
| double GaussianConditional::logDensity(const VectorValues& x) const { | double GaussianConditional::logNormalizationConstant() const { | ||||||
|   constexpr double log2pi = 1.8378770664093454835606594728112; |   constexpr double log2pi = 1.8378770664093454835606594728112; | ||||||
|   size_t n = d().size(); |   size_t n = d().size(); | ||||||
|   // log det(Sigma)) = - 2.0 * logDeterminant()
 |   // log det(Sigma)) = - 2.0 * logDeterminant()
 | ||||||
|   return - error(x) - 0.5 * n * log2pi + logDeterminant(); |   return - 0.5 * n * log2pi + logDeterminant(); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | /* ************************************************************************* */ | ||||||
|  | //  density = k exp(-error(x))
 | ||||||
|  | //  log = log(k) -error(x) - 0.5 * n*log(2*pi)
 | ||||||
|  | double GaussianConditional::logDensity(const VectorValues& x) const { | ||||||
|  |   return logNormalizationConstant() - error(x); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
|  |  | ||||||
|  | @ -169,7 +169,7 @@ namespace gtsam { | ||||||
|      * |      * | ||||||
|      * @return double |      * @return double | ||||||
|      */ |      */ | ||||||
|     double determinant() const { return exp(this->logDeterminant()); } |     inline double determinant() const { return exp(logDeterminant()); } | ||||||
| 
 | 
 | ||||||
|     /**
 |     /**
 | ||||||
|      * @brief Compute the log determinant of the R matrix. |      * @brief Compute the log determinant of the R matrix. | ||||||
|  | @ -184,6 +184,19 @@ namespace gtsam { | ||||||
|      */ |      */ | ||||||
|     double logDeterminant() const; |     double logDeterminant() const; | ||||||
| 
 | 
 | ||||||
|  |     /**
 | ||||||
|  |      * normalization constant = 1.0 / sqrt((2*pi)^n*det(Sigma)) | ||||||
|  |      * log = - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma) | ||||||
|  |      */ | ||||||
|  |     double logNormalizationConstant() const; | ||||||
|  | 
 | ||||||
|  |     /**
 | ||||||
|  |      * normalization constant = 1.0 / sqrt((2*pi)^n*det(Sigma)) | ||||||
|  |      */ | ||||||
|  |     inline double normalizationConstant() const { | ||||||
|  |       return exp(logNormalizationConstant()); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     /**
 |     /**
 | ||||||
|     * Solves a conditional Gaussian and writes the solution into the entries of |     * Solves a conditional Gaussian and writes the solution into the entries of | ||||||
|     * \c x for each frontal variable of the conditional.  The parents are |     * \c x for each frontal variable of the conditional.  The parents are | ||||||
|  |  | ||||||
|  | @ -6,7 +6,7 @@ All Rights Reserved | ||||||
| See LICENSE for the license information | See LICENSE for the license information | ||||||
| 
 | 
 | ||||||
| Unit tests for Hybrid Factor Graphs. | Unit tests for Hybrid Factor Graphs. | ||||||
| Author: Fan Jiang | Author: Fan Jiang, Varun Agrawal, Frank Dellaert | ||||||
| """ | """ | ||||||
| # pylint: disable=invalid-name, no-name-in-module, no-member | # pylint: disable=invalid-name, no-name-in-module, no-member | ||||||
| 
 | 
 | ||||||
|  | @ -18,13 +18,14 @@ from gtsam.utils.test_case import GtsamTestCase | ||||||
| 
 | 
 | ||||||
| import gtsam | import gtsam | ||||||
| from gtsam import (DiscreteConditional, DiscreteKeys, GaussianConditional, | from gtsam import (DiscreteConditional, DiscreteKeys, GaussianConditional, | ||||||
|                    GaussianMixture, GaussianMixtureFactor, |                    GaussianMixture, GaussianMixtureFactor, HybridBayesNet, HybridValues, | ||||||
|                    HybridGaussianFactorGraph, JacobianFactor, Ordering, |                    HybridGaussianFactorGraph, JacobianFactor, Ordering, | ||||||
|                    noiseModel) |                    noiseModel) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class TestHybridGaussianFactorGraph(GtsamTestCase): | class TestHybridGaussianFactorGraph(GtsamTestCase): | ||||||
|     """Unit tests for HybridGaussianFactorGraph.""" |     """Unit tests for HybridGaussianFactorGraph.""" | ||||||
|  | 
 | ||||||
|     def test_create(self): |     def test_create(self): | ||||||
|         """Test construction of hybrid factor graph.""" |         """Test construction of hybrid factor graph.""" | ||||||
|         model = noiseModel.Unit.Create(3) |         model = noiseModel.Unit.Create(3) | ||||||
|  | @ -81,13 +82,13 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): | ||||||
|         self.assertEqual(hv.atDiscrete(C(0)), 1) |         self.assertEqual(hv.atDiscrete(C(0)), 1) | ||||||
| 
 | 
 | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def tiny(num_measurements: int = 1) -> gtsam.HybridBayesNet: |     def tiny(num_measurements: int = 1) -> HybridBayesNet: | ||||||
|         """ |         """ | ||||||
|         Create a tiny two variable hybrid model which represents |         Create a tiny two variable hybrid model which represents | ||||||
|         the generative probability P(z, x, n) = P(z | x, n)P(x)P(n). |         the generative probability P(z, x, n) = P(z | x, n)P(x)P(n). | ||||||
|         """ |         """ | ||||||
|         # Create hybrid Bayes net. |         # Create hybrid Bayes net. | ||||||
|         bayesNet = gtsam.HybridBayesNet() |         bayesNet = HybridBayesNet() | ||||||
| 
 | 
 | ||||||
|         # Create mode key: 0 is low-noise, 1 is high-noise. |         # Create mode key: 0 is low-noise, 1 is high-noise. | ||||||
|         mode = (M(0), 2) |         mode = (M(0), 2) | ||||||
|  | @ -113,35 +114,76 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): | ||||||
|         bayesNet.addGaussian(prior_on_x0) |         bayesNet.addGaussian(prior_on_x0) | ||||||
| 
 | 
 | ||||||
|         # Add prior on mode. |         # Add prior on mode. | ||||||
|         bayesNet.emplaceDiscrete(mode, "1/1") |         bayesNet.emplaceDiscrete(mode, "4/6") | ||||||
| 
 | 
 | ||||||
|         return bayesNet |         return bayesNet | ||||||
| 
 | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     def factor_graph_from_bayes_net(bayesNet: HybridBayesNet, sample: HybridValues): | ||||||
|  |         """Create a factor graph from the Bayes net with sampled measurements. | ||||||
|  |             The factor graph is `P(x)P(n) ϕ(x, n; z0) ϕ(x, n; z1) ...` | ||||||
|  |             and thus represents the same joint probability as the Bayes net. | ||||||
|  |         """ | ||||||
|  |         fg = HybridGaussianFactorGraph() | ||||||
|  |         num_measurements = bayesNet.size() - 2 | ||||||
|  |         for i in range(num_measurements): | ||||||
|  |             conditional = bayesNet.atMixture(i) | ||||||
|  |             measurement = gtsam.VectorValues() | ||||||
|  |             measurement.insert(Z(i), sample.at(Z(i))) | ||||||
|  |             factor = conditional.likelihood(measurement) | ||||||
|  |             fg.push_back(factor) | ||||||
|  |         fg.push_back(bayesNet.atGaussian(num_measurements)) | ||||||
|  |         fg.push_back(bayesNet.atDiscrete(num_measurements+1)) | ||||||
|  |         return fg | ||||||
|  | 
 | ||||||
|  |     @classmethod | ||||||
|  |     def estimate_marginals(cls, bayesNet: HybridBayesNet, sample: HybridValues, N=10000): | ||||||
|  |         """Do importance sampling to get an estimate of the discrete marginal P(mode).""" | ||||||
|  |         # Use prior on x0, mode as proposal density. | ||||||
|  |         prior = cls.tiny(num_measurements=0)  # just P(x0)P(mode) | ||||||
|  | 
 | ||||||
|  |         # Allocate space for marginals. | ||||||
|  |         marginals = np.zeros((2,)) | ||||||
|  | 
 | ||||||
|  |         # Do importance sampling. | ||||||
|  |         num_measurements = bayesNet.size() - 2 | ||||||
|  |         for s in range(N): | ||||||
|  |             proposed = prior.sample() | ||||||
|  |             for i in range(num_measurements): | ||||||
|  |                 z_i = sample.at(Z(i)) | ||||||
|  |                 proposed.insert(Z(i), z_i) | ||||||
|  |             weight = bayesNet.evaluate(proposed) / prior.evaluate(proposed) | ||||||
|  |             marginals[proposed.atDiscrete(M(0))] += weight | ||||||
|  | 
 | ||||||
|  |         # print marginals: | ||||||
|  |         marginals /= marginals.sum() | ||||||
|  |         return marginals | ||||||
|  | 
 | ||||||
|     def test_tiny(self): |     def test_tiny(self): | ||||||
|         """Test a tiny two variable hybrid model.""" |         """Test a tiny two variable hybrid model.""" | ||||||
|         bayesNet = self.tiny() |         bayesNet = self.tiny() | ||||||
|         sample = bayesNet.sample() |         sample = bayesNet.sample() | ||||||
|         # print(sample) |         # print(sample) | ||||||
| 
 | 
 | ||||||
|         # Create a factor graph from the Bayes net with sampled measurements. |         # Estimate marginals using importance sampling. | ||||||
|         fg = HybridGaussianFactorGraph() |         marginals = self.estimate_marginals(bayesNet, sample) | ||||||
|         conditional = bayesNet.atMixture(0) |         # print(f"True mode: {sample.atDiscrete(M(0))}") | ||||||
|         measurement = gtsam.VectorValues() |         # print(f"P(mode=0; z0) = {marginals[0]}") | ||||||
|         measurement.insert(Z(0), sample.at(Z(0))) |         # print(f"P(mode=1; z0) = {marginals[1]}") | ||||||
|         factor = conditional.likelihood(measurement) |  | ||||||
|         fg.push_back(factor) |  | ||||||
|         fg.push_back(bayesNet.atGaussian(1)) |  | ||||||
|         fg.push_back(bayesNet.atDiscrete(2)) |  | ||||||
| 
 | 
 | ||||||
|  |         # Check that the estimate is close to the true value. | ||||||
|  |         self.assertAlmostEqual(marginals[0], 0.4, delta=0.1) | ||||||
|  |         self.assertAlmostEqual(marginals[1], 0.6, delta=0.1) | ||||||
|  | 
 | ||||||
|  |         fg = self.factor_graph_from_bayes_net(bayesNet, sample) | ||||||
|         self.assertEqual(fg.size(), 3) |         self.assertEqual(fg.size(), 3) | ||||||
| 
 | 
 | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def calculate_ratio(bayesNet, fg, sample): |     def calculate_ratio(bayesNet: HybridBayesNet, | ||||||
|  |                         fg: HybridGaussianFactorGraph, | ||||||
|  |                         sample: HybridValues): | ||||||
|         """Calculate ratio  between Bayes net probability and the factor graph.""" |         """Calculate ratio  between Bayes net probability and the factor graph.""" | ||||||
|         continuous = gtsam.VectorValues() |         return bayesNet.evaluate(sample) / fg.probPrime(sample) if fg.probPrime(sample) > 0 else 0 | ||||||
|         continuous.insert(X(0), sample.at(X(0))) |  | ||||||
|         return bayesNet.evaluate(sample) / fg.probPrime( |  | ||||||
|             continuous, sample.discrete()) |  | ||||||
| 
 | 
 | ||||||
|     def test_ratio(self): |     def test_ratio(self): | ||||||
|         """ |         """ | ||||||
|  | @ -153,23 +195,22 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): | ||||||
|         # Create the Bayes net representing the generative model P(z, x, n)=P(z|x, n)P(x)P(n) |         # Create the Bayes net representing the generative model P(z, x, n)=P(z|x, n)P(x)P(n) | ||||||
|         bayesNet = self.tiny(num_measurements=2) |         bayesNet = self.tiny(num_measurements=2) | ||||||
|         # Sample from the Bayes net. |         # Sample from the Bayes net. | ||||||
|         sample: gtsam.HybridValues = bayesNet.sample() |         sample: HybridValues = bayesNet.sample() | ||||||
|         # print(sample) |         # print(sample) | ||||||
| 
 | 
 | ||||||
|         # Create a factor graph from the Bayes net with sampled measurements. |         # Estimate marginals using importance sampling. | ||||||
|         # The factor graph is `P(x)P(n) ϕ(x, n; z1) ϕ(x, n; z2)` |         marginals = self.estimate_marginals(bayesNet, sample) | ||||||
|         # and thus represents the same joint probability as the Bayes net. |         # print(f"True mode: {sample.atDiscrete(M(0))}") | ||||||
|         fg = HybridGaussianFactorGraph() |         # print(f"P(mode=0; z0, z1) = {marginals[0]}") | ||||||
|         for i in range(2): |         # print(f"P(mode=1; z0, z1) = {marginals[1]}") | ||||||
|             conditional = bayesNet.atMixture(i) |  | ||||||
|             measurement = gtsam.VectorValues() |  | ||||||
|             measurement.insert(Z(i), sample.at(Z(i))) |  | ||||||
|             factor = conditional.likelihood(measurement) |  | ||||||
|             fg.push_back(factor) |  | ||||||
|         fg.push_back(bayesNet.atGaussian(2)) |  | ||||||
|         fg.push_back(bayesNet.atDiscrete(3)) |  | ||||||
| 
 | 
 | ||||||
|         # print(fg) |         # Check marginals based on sampled mode. | ||||||
|  |         if sample.atDiscrete(M(0)) == 0: | ||||||
|  |             self.assertGreater(marginals[0], marginals[1]) | ||||||
|  |         else: | ||||||
|  |             self.assertGreater(marginals[1], marginals[0]) | ||||||
|  | 
 | ||||||
|  |         fg = self.factor_graph_from_bayes_net(bayesNet, sample) | ||||||
|         self.assertEqual(fg.size(), 4) |         self.assertEqual(fg.size(), 4) | ||||||
| 
 | 
 | ||||||
|         # Calculate ratio between Bayes net probability and the factor graph: |         # Calculate ratio between Bayes net probability and the factor graph: | ||||||
|  | @ -185,10 +226,10 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): | ||||||
|         for i in range(10): |         for i in range(10): | ||||||
|             other = bayesNet.sample() |             other = bayesNet.sample() | ||||||
|             other.update(measurements) |             other.update(measurements) | ||||||
|             # print(other) |             ratio = self.calculate_ratio(bayesNet, fg, other) | ||||||
|             # ratio = self.calculate_ratio(bayesNet, fg, other) |  | ||||||
|             # print(f"Ratio: {ratio}\n") |             # print(f"Ratio: {ratio}\n") | ||||||
|             # self.assertAlmostEqual(ratio, expected_ratio) |             if (ratio > 0): | ||||||
|  |                 self.assertAlmostEqual(ratio, expected_ratio) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue