Switch to using HybridValues
							parent
							
								
									b972be0b8f
								
							
						
					
					
						commit
						9cf3e5c26a
					
				| 
						 | 
				
			
			@ -22,6 +22,7 @@
 | 
			
		|||
#include <gtsam/discrete/DiscreteValues.h>
 | 
			
		||||
#include <gtsam/hybrid/GaussianMixture.h>
 | 
			
		||||
#include <gtsam/hybrid/GaussianMixtureFactor.h>
 | 
			
		||||
#include <gtsam/hybrid/HybridValues.h>
 | 
			
		||||
#include <gtsam/inference/Conditional-inst.h>
 | 
			
		||||
#include <gtsam/linear/GaussianFactorGraph.h>
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -159,9 +160,9 @@ boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
 | 
			
		||||
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
 | 
			
		||||
  std::set<DiscreteKey> s;
 | 
			
		||||
  s.insert(dkeys.begin(), dkeys.end());
 | 
			
		||||
  s.insert(discreteKeys.begin(), discreteKeys.end());
 | 
			
		||||
  return s;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -186,7 +187,7 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
 | 
			
		|||
                    const GaussianConditional::shared_ptr &conditional)
 | 
			
		||||
      -> GaussianConditional::shared_ptr {
 | 
			
		||||
    // typecast so we can use this to get probability value
 | 
			
		||||
    DiscreteValues values(choices);
 | 
			
		||||
    const DiscreteValues values(choices);
 | 
			
		||||
 | 
			
		||||
    // Case where the gaussian mixture has the same
 | 
			
		||||
    // discrete keys as the decision tree.
 | 
			
		||||
| 
						 | 
				
			
			@ -256,11 +257,10 @@ AlgebraicDecisionTree<Key> GaussianMixture::error(
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
/* *******************************************************************************/
 | 
			
		||||
double GaussianMixture::error(const VectorValues &continuousValues,
 | 
			
		||||
                              const DiscreteValues &discreteValues) const {
 | 
			
		||||
double GaussianMixture::error(const HybridValues &values) const {
 | 
			
		||||
  // Directly index to get the conditional, no need to build the whole tree.
 | 
			
		||||
  auto conditional = conditionals_(discreteValues);
 | 
			
		||||
  return conditional->error(continuousValues);
 | 
			
		||||
  auto conditional = conditionals_(values.discrete());
 | 
			
		||||
  return conditional->error(values.continuous());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace gtsam
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -30,6 +30,7 @@
 | 
			
		|||
namespace gtsam {
 | 
			
		||||
 | 
			
		||||
class GaussianMixtureFactor;
 | 
			
		||||
class HybridValues;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * @brief A conditional of gaussian mixtures indexed by discrete variables, as
 | 
			
		||||
| 
						 | 
				
			
			@ -87,7 +88,7 @@ class GTSAM_EXPORT GaussianMixture
 | 
			
		|||
  /// @name Constructors
 | 
			
		||||
  /// @{
 | 
			
		||||
 | 
			
		||||
  /// Defaut constructor, mainly for serialization.
 | 
			
		||||
  /// Default constructor, mainly for serialization.
 | 
			
		||||
  GaussianMixture() = default;
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
| 
						 | 
				
			
			@ -135,6 +136,7 @@ class GTSAM_EXPORT GaussianMixture
 | 
			
		|||
  /// @name Standard API
 | 
			
		||||
  /// @{
 | 
			
		||||
 | 
			
		||||
  /// @brief Return the conditional Gaussian for the given discrete assignment.
 | 
			
		||||
  GaussianConditional::shared_ptr operator()(
 | 
			
		||||
      const DiscreteValues &discreteValues) const;
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -165,12 +167,10 @@ class GTSAM_EXPORT GaussianMixture
 | 
			
		|||
   * @brief Compute the error of this Gaussian Mixture given the continuous
 | 
			
		||||
   * values and a discrete assignment.
 | 
			
		||||
   *
 | 
			
		||||
   * @param continuousValues Continuous values at which to compute the error.
 | 
			
		||||
   * @param discreteValues The discrete assignment for a specific mode sequence.
 | 
			
		||||
   * @param values Continuous values and discrete assignment.
 | 
			
		||||
   * @return double
 | 
			
		||||
   */
 | 
			
		||||
  double error(const VectorValues &continuousValues,
 | 
			
		||||
               const DiscreteValues &discreteValues) const;
 | 
			
		||||
  double error(const HybridValues &values) const;
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * @brief Prune the decision tree of Gaussian factors as per the discrete
 | 
			
		||||
| 
						 | 
				
			
			@ -193,7 +193,7 @@ class GTSAM_EXPORT GaussianMixture
 | 
			
		|||
};
 | 
			
		||||
 | 
			
		||||
/// Return the DiscreteKey vector as a set.
 | 
			
		||||
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys);
 | 
			
		||||
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys);
 | 
			
		||||
 | 
			
		||||
// traits
 | 
			
		||||
template <>
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,5 +1,5 @@
 | 
			
		|||
/* ----------------------------------------------------------------------------
 | 
			
		||||
 * GTSAM Copyright 2010, Georgia Tech Research Corporation,
 | 
			
		||||
 * GTSAM Copyright 2010-2022, Georgia Tech Research Corporation,
 | 
			
		||||
 * Atlanta, Georgia 30332-0415
 | 
			
		||||
 * All Rights Reserved
 | 
			
		||||
 * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
 | 
			
		||||
| 
						 | 
				
			
			@ -12,6 +12,7 @@
 | 
			
		|||
 * @author Fan Jiang
 | 
			
		||||
 * @author Varun Agrawal
 | 
			
		||||
 * @author Shangjie Xue
 | 
			
		||||
 * @author Frank Dellaert
 | 
			
		||||
 * @date   January 2022
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -321,10 +322,9 @@ HybridValues HybridBayesNet::sample() const {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
double HybridBayesNet::error(const VectorValues &continuousValues,
 | 
			
		||||
                             const DiscreteValues &discreteValues) const {
 | 
			
		||||
  GaussianBayesNet gbn = choose(discreteValues);
 | 
			
		||||
  return gbn.error(continuousValues);
 | 
			
		||||
double HybridBayesNet::error(const HybridValues &values) const {
 | 
			
		||||
  GaussianBayesNet gbn = choose(values.discrete());
 | 
			
		||||
  return gbn.error(values.continuous());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -206,12 +206,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
 | 
			
		|||
   * @brief 0.5 * sum of squared Mahalanobis distances
 | 
			
		||||
   * for a specific discrete assignment.
 | 
			
		||||
   *
 | 
			
		||||
   * @param continuousValues Continuous values at which to compute the error.
 | 
			
		||||
   * @param discreteValues Discrete assignment for a specific mode sequence.
 | 
			
		||||
   * @param values Continuous values and discrete assignment.
 | 
			
		||||
   * @return double
 | 
			
		||||
   */
 | 
			
		||||
  double error(const VectorValues &continuousValues,
 | 
			
		||||
               const DiscreteValues &discreteValues) const;
 | 
			
		||||
  double error(const HybridValues &values) const;
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * @brief Compute conditional error for each discrete assignment,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -55,13 +55,14 @@
 | 
			
		|||
 | 
			
		||||
namespace gtsam {
 | 
			
		||||
 | 
			
		||||
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
 | 
			
		||||
template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************ */
 | 
			
		||||
static GaussianMixtureFactor::Sum &addGaussian(
 | 
			
		||||
    GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) {
 | 
			
		||||
  using Y = GaussianFactorGraph;
 | 
			
		||||
  // If the decision tree is not intiialized, then intialize it.
 | 
			
		||||
  // If the decision tree is not initialized, then initialize it.
 | 
			
		||||
  if (sum.empty()) {
 | 
			
		||||
    GaussianFactorGraph result;
 | 
			
		||||
    result.push_back(factor);
 | 
			
		||||
| 
						 | 
				
			
			@ -89,8 +90,9 @@ GaussianMixtureFactor::Sum sumFrontals(
 | 
			
		|||
 | 
			
		||||
  for (auto &f : factors) {
 | 
			
		||||
    if (f->isHybrid()) {
 | 
			
		||||
      if (auto cgmf = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
 | 
			
		||||
        sum = cgmf->add(sum);
 | 
			
		||||
      // TODO(dellaert): just use a virtual method defined in HybridFactor.
 | 
			
		||||
      if (auto gm = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
 | 
			
		||||
        sum = gm->add(sum);
 | 
			
		||||
      }
 | 
			
		||||
      if (auto gm = boost::dynamic_pointer_cast<HybridConditional>(f)) {
 | 
			
		||||
        sum = gm->asMixture()->add(sum);
 | 
			
		||||
| 
						 | 
				
			
			@ -184,7 +186,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
 | 
			
		|||
                  const KeySet &continuousSeparator,
 | 
			
		||||
                  const std::set<DiscreteKey> &discreteSeparatorSet) {
 | 
			
		||||
  // NOTE: since we use the special JunctionTree,
 | 
			
		||||
  // only possiblity is continuous conditioned on discrete.
 | 
			
		||||
  // only possibility is continuous conditioned on discrete.
 | 
			
		||||
  DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(),
 | 
			
		||||
                                 discreteSeparatorSet.end());
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -251,8 +253,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
 | 
			
		|||
 | 
			
		||||
  // Separate out decision tree into conditionals and remaining factors.
 | 
			
		||||
  auto pair = unzip(eliminationResults);
 | 
			
		||||
 | 
			
		||||
  const GaussianMixtureFactor::Factors &separatorFactors = pair.second;
 | 
			
		||||
  const auto &separatorFactors = pair.second;
 | 
			
		||||
 | 
			
		||||
  // Create the GaussianMixture from the conditionals
 | 
			
		||||
  auto conditional = boost::make_shared<GaussianMixture>(
 | 
			
		||||
| 
						 | 
				
			
			@ -460,6 +461,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
 | 
			
		|||
 | 
			
		||||
  // Iterate over each factor.
 | 
			
		||||
  for (size_t idx = 0; idx < size(); idx++) {
 | 
			
		||||
    // TODO(dellaert): just use a virtual method defined in HybridFactor.
 | 
			
		||||
    AlgebraicDecisionTree<Key> factor_error;
 | 
			
		||||
 | 
			
		||||
    if (factors_.at(idx)->isHybrid()) {
 | 
			
		||||
| 
						 | 
				
			
			@ -499,27 +501,26 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************ */
 | 
			
		||||
double HybridGaussianFactorGraph::error(
 | 
			
		||||
    const VectorValues &continuousValues,
 | 
			
		||||
    const DiscreteValues &discreteValues) const {
 | 
			
		||||
double HybridGaussianFactorGraph::error(const HybridValues &values) const {
 | 
			
		||||
  double error = 0.0;
 | 
			
		||||
  for (size_t idx = 0; idx < size(); idx++) {
 | 
			
		||||
    // TODO(dellaert): just use a virtual method defined in HybridFactor.
 | 
			
		||||
    auto factor = factors_.at(idx);
 | 
			
		||||
 | 
			
		||||
    if (factor->isHybrid()) {
 | 
			
		||||
      if (auto c = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
 | 
			
		||||
        error += c->asMixture()->error(continuousValues, discreteValues);
 | 
			
		||||
        error += c->asMixture()->error(values);
 | 
			
		||||
      }
 | 
			
		||||
      if (auto f = boost::dynamic_pointer_cast<GaussianMixtureFactor>(factor)) {
 | 
			
		||||
        error += f->error(continuousValues, discreteValues);
 | 
			
		||||
        error += f->error(values);
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
    } else if (factor->isContinuous()) {
 | 
			
		||||
      if (auto f = boost::dynamic_pointer_cast<HybridGaussianFactor>(factor)) {
 | 
			
		||||
        error += f->inner()->error(continuousValues);
 | 
			
		||||
        error += f->inner()->error(values.continuous());
 | 
			
		||||
      }
 | 
			
		||||
      if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
 | 
			
		||||
        error += cg->asGaussian()->error(continuousValues);
 | 
			
		||||
        error += cg->asGaussian()->error(values.continuous());
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
| 
						 | 
				
			
			@ -527,10 +528,8 @@ double HybridGaussianFactorGraph::error(
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************ */
 | 
			
		||||
double HybridGaussianFactorGraph::probPrime(
 | 
			
		||||
    const VectorValues &continuousValues,
 | 
			
		||||
    const DiscreteValues &discreteValues) const {
 | 
			
		||||
  double error = this->error(continuousValues, discreteValues);
 | 
			
		||||
double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const {
 | 
			
		||||
  double error = this->error(values);
 | 
			
		||||
  // NOTE: The 0.5 term is handled by each factor
 | 
			
		||||
  return std::exp(-error);
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -186,14 +186,9 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
 | 
			
		|||
   * @brief Compute error given a continuous vector values
 | 
			
		||||
   * and a discrete assignment.
 | 
			
		||||
   *
 | 
			
		||||
   * @param continuousValues The continuous VectorValues
 | 
			
		||||
   * for computing the error.
 | 
			
		||||
   * @param discreteValues The specific discrete assignment
 | 
			
		||||
   * whose error we wish to compute.
 | 
			
		||||
   * @return double
 | 
			
		||||
   */
 | 
			
		||||
  double error(const VectorValues& continuousValues,
 | 
			
		||||
               const DiscreteValues& discreteValues) const;
 | 
			
		||||
  double error(const HybridValues& values) const;
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
 | 
			
		||||
| 
						 | 
				
			
			@ -210,13 +205,9 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
 | 
			
		|||
   * @brief Compute the unnormalized posterior probability for a continuous
 | 
			
		||||
   * vector values given a specific assignment.
 | 
			
		||||
   *
 | 
			
		||||
   * @param continuousValues The vector values for which to compute the
 | 
			
		||||
   * posterior probability.
 | 
			
		||||
   * @param discreteValues The specific assignment to use for the computation.
 | 
			
		||||
   * @return double
 | 
			
		||||
   */
 | 
			
		||||
  double probPrime(const VectorValues& continuousValues,
 | 
			
		||||
                   const DiscreteValues& discreteValues) const;
 | 
			
		||||
  double probPrime(const HybridValues& values) const;
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * @brief Return a Colamd constrained ordering where the discrete keys are
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue