discretePosterior
							parent
							
								
									38ed609614
								
							
						
					
					
						commit
						a898ad3661
					
				| 
						 | 
					@ -23,6 +23,7 @@
 | 
				
			||||||
#include <gtsam/discrete/DiscreteEliminationTree.h>
 | 
					#include <gtsam/discrete/DiscreteEliminationTree.h>
 | 
				
			||||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
 | 
					#include <gtsam/discrete/DiscreteFactorGraph.h>
 | 
				
			||||||
#include <gtsam/discrete/DiscreteJunctionTree.h>
 | 
					#include <gtsam/discrete/DiscreteJunctionTree.h>
 | 
				
			||||||
 | 
					#include <gtsam/discrete/DiscreteKey.h>
 | 
				
			||||||
#include <gtsam/hybrid/HybridConditional.h>
 | 
					#include <gtsam/hybrid/HybridConditional.h>
 | 
				
			||||||
#include <gtsam/hybrid/HybridEliminationTree.h>
 | 
					#include <gtsam/hybrid/HybridEliminationTree.h>
 | 
				
			||||||
#include <gtsam/hybrid/HybridFactor.h>
 | 
					#include <gtsam/hybrid/HybridFactor.h>
 | 
				
			||||||
| 
						 | 
					@ -42,7 +43,6 @@
 | 
				
			||||||
#include <algorithm>
 | 
					#include <algorithm>
 | 
				
			||||||
#include <cstddef>
 | 
					#include <cstddef>
 | 
				
			||||||
#include <iostream>
 | 
					#include <iostream>
 | 
				
			||||||
#include <iterator>
 | 
					 | 
				
			||||||
#include <memory>
 | 
					#include <memory>
 | 
				
			||||||
#include <stdexcept>
 | 
					#include <stdexcept>
 | 
				
			||||||
#include <utility>
 | 
					#include <utility>
 | 
				
			||||||
| 
						 | 
					@ -342,14 +342,20 @@ static std::shared_ptr<Factor> createHybridGaussianFactor(
 | 
				
			||||||
  return std::make_shared<HybridGaussianFactor>(discreteSeparator, newFactors);
 | 
					  return std::make_shared<HybridGaussianFactor>(discreteSeparator, newFactors);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* *******************************************************************************/
 | 
				
			||||||
 | 
					/// Get the discrete keys from the HybridGaussianFactorGraph as DiscreteKeys.
 | 
				
			||||||
 | 
					static auto GetDiscreteKeys =
 | 
				
			||||||
 | 
					    [](const HybridGaussianFactorGraph &hfg) -> DiscreteKeys {
 | 
				
			||||||
 | 
					  const std::set<DiscreteKey> discreteKeySet = hfg.discreteKeys();
 | 
				
			||||||
 | 
					  return {discreteKeySet.begin(), discreteKeySet.end()};
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/* *******************************************************************************/
 | 
					/* *******************************************************************************/
 | 
				
			||||||
std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
 | 
					std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
 | 
				
			||||||
HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
 | 
					HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
 | 
				
			||||||
  // Since we eliminate all continuous variables first,
 | 
					  // Since we eliminate all continuous variables first,
 | 
				
			||||||
  // the discrete separator will be *all* the discrete keys.
 | 
					  // the discrete separator will be *all* the discrete keys.
 | 
				
			||||||
  const std::set<DiscreteKey> keysForDiscreteVariables = discreteKeys();
 | 
					  DiscreteKeys discreteSeparator = GetDiscreteKeys(*this);
 | 
				
			||||||
  DiscreteKeys discreteSeparator(keysForDiscreteVariables.begin(),
 | 
					 | 
				
			||||||
                                 keysForDiscreteVariables.end());
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // Collect all the factors to create a set of Gaussian factor graphs in a
 | 
					  // Collect all the factors to create a set of Gaussian factor graphs in a
 | 
				
			||||||
  // decision tree indexed by all discrete keys involved.
 | 
					  // decision tree indexed by all discrete keys involved.
 | 
				
			||||||
| 
						 | 
					@ -525,14 +531,21 @@ double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const {
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/* ************************************************************************ */
 | 
					/* ************************************************************************ */
 | 
				
			||||||
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
 | 
					DecisionTreeFactor HybridGaussianFactorGraph::probPrime(
 | 
				
			||||||
    const VectorValues &continuousValues) const {
 | 
					    const VectorValues &continuousValues) const {
 | 
				
			||||||
  AlgebraicDecisionTree<Key> error_tree = this->errorTree(continuousValues);
 | 
					  AlgebraicDecisionTree<Key> error_tree = this->errorTree(continuousValues);
 | 
				
			||||||
  AlgebraicDecisionTree<Key> prob_tree = error_tree.apply([](double error) {
 | 
					  AlgebraicDecisionTree<Key> prob_tree = error_tree.apply([](double error) {
 | 
				
			||||||
    // NOTE: The 0.5 term is handled by each factor
 | 
					    // NOTE: The 0.5 term is handled by each factor
 | 
				
			||||||
    return exp(-error);
 | 
					    return exp(-error);
 | 
				
			||||||
  });
 | 
					  });
 | 
				
			||||||
  return prob_tree;
 | 
					  return {GetDiscreteKeys(*this), prob_tree};
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************ */
 | 
				
			||||||
 | 
					DiscreteConditional HybridGaussianFactorGraph::discretePosterior(
 | 
				
			||||||
 | 
					    const VectorValues &continuousValues) const {
 | 
				
			||||||
 | 
					  auto p = probPrime(continuousValues);
 | 
				
			||||||
 | 
					  return {p.size(), p};
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/* ************************************************************************ */
 | 
					/* ************************************************************************ */
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -18,6 +18,7 @@
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#pragma once
 | 
					#pragma once
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <gtsam/discrete/DiscreteKey.h>
 | 
				
			||||||
#include <gtsam/hybrid/HybridFactor.h>
 | 
					#include <gtsam/hybrid/HybridFactor.h>
 | 
				
			||||||
#include <gtsam/hybrid/HybridFactorGraph.h>
 | 
					#include <gtsam/hybrid/HybridFactorGraph.h>
 | 
				
			||||||
#include <gtsam/hybrid/HybridGaussianFactor.h>
 | 
					#include <gtsam/hybrid/HybridGaussianFactor.h>
 | 
				
			||||||
| 
						 | 
					@ -187,17 +188,6 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
 | 
				
			||||||
  AlgebraicDecisionTree<Key> errorTree(
 | 
					  AlgebraicDecisionTree<Key> errorTree(
 | 
				
			||||||
      const VectorValues& continuousValues) const;
 | 
					      const VectorValues& continuousValues) const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /**
 | 
					 | 
				
			||||||
   * @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
 | 
					 | 
				
			||||||
   * for each discrete assignment, and return as a tree.
 | 
					 | 
				
			||||||
   *
 | 
					 | 
				
			||||||
   * @param continuousValues Continuous values at which to compute the
 | 
					 | 
				
			||||||
   * probability.
 | 
					 | 
				
			||||||
   * @return AlgebraicDecisionTree<Key>
 | 
					 | 
				
			||||||
   */
 | 
					 | 
				
			||||||
  AlgebraicDecisionTree<Key> probPrime(
 | 
					 | 
				
			||||||
      const VectorValues& continuousValues) const;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  /**
 | 
					  /**
 | 
				
			||||||
   * @brief Compute the unnormalized posterior probability for a continuous
 | 
					   * @brief Compute the unnormalized posterior probability for a continuous
 | 
				
			||||||
   * vector values given a specific assignment.
 | 
					   * vector values given a specific assignment.
 | 
				
			||||||
| 
						 | 
					@ -206,6 +196,26 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
  double probPrime(const HybridValues& values) const;
 | 
					  double probPrime(const HybridValues& values) const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /**
 | 
				
			||||||
 | 
					   * @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
 | 
				
			||||||
 | 
					   * for each discrete assignment, and return as a tree.
 | 
				
			||||||
 | 
					   *
 | 
				
			||||||
 | 
					   * @param continuousValues Continuous values at which to compute probability.
 | 
				
			||||||
 | 
					   * @return DecisionTreeFactor
 | 
				
			||||||
 | 
					   */
 | 
				
			||||||
 | 
					  DecisionTreeFactor probPrime(const VectorValues& continuousValues) const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /**
 | 
				
			||||||
 | 
					   * @brief Computer posterior P(M|X=x) when all continuous values X are given.
 | 
				
			||||||
 | 
					   * This is very efficient as this simply probPrime normalized into a
 | 
				
			||||||
 | 
					   * conditional.
 | 
				
			||||||
 | 
					   *
 | 
				
			||||||
 | 
					   * @param continuousValues Continuous values x to condition on.
 | 
				
			||||||
 | 
					   * @return DecisionTreeFactor
 | 
				
			||||||
 | 
					   */
 | 
				
			||||||
 | 
					  DiscreteConditional discretePosterior(
 | 
				
			||||||
 | 
					      const VectorValues& continuousValues) const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /**
 | 
					  /**
 | 
				
			||||||
   * @brief Create a decision tree of factor graphs out of this hybrid factor
 | 
					   * @brief Create a decision tree of factor graphs out of this hybrid factor
 | 
				
			||||||
   * graph.
 | 
					   * graph.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -603,29 +603,34 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrime) {
 | 
				
			||||||
/* ****************************************************************************/
 | 
					/* ****************************************************************************/
 | 
				
			||||||
// Test hybrid gaussian factor graph error and unnormalized probabilities
 | 
					// Test hybrid gaussian factor graph error and unnormalized probabilities
 | 
				
			||||||
TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) {
 | 
					TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) {
 | 
				
			||||||
 | 
					  // Create switching network with three continuous variables and two discrete:
 | 
				
			||||||
 | 
					  // ϕ(x0) ϕ(x0,x1,m0) ϕ(x1,x2,m1) ϕ(x0;z0) ϕ(x1;z1) ϕ(x2;z2) ϕ(m0) ϕ(m0,m1)
 | 
				
			||||||
  Switching s(3);
 | 
					  Switching s(3);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  HybridGaussianFactorGraph graph = s.linearizedFactorGraph;
 | 
					  const HybridGaussianFactorGraph &graph = s.linearizedFactorGraph;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential();
 | 
					  const HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  HybridValues delta = hybridBayesNet->optimize();
 | 
					  const HybridValues delta = hybridBayesNet->optimize();
 | 
				
			||||||
  auto error_tree = graph.errorTree(delta.continuous());
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
  std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
 | 
					  // regression test for errorTree
 | 
				
			||||||
  std::vector<double> leaves = {0.9998558, 0.4902432, 0.5193694, 0.0097568};
 | 
					  std::vector<double> leaves = {0.9998558, 0.4902432, 0.5193694, 0.0097568};
 | 
				
			||||||
  AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
 | 
					  AlgebraicDecisionTree<Key> expectedErrors(s.modes, leaves);
 | 
				
			||||||
 | 
					  const auto error_tree = graph.errorTree(delta.continuous());
 | 
				
			||||||
  // regression
 | 
					  EXPECT(assert_equal(expectedErrors, error_tree, 1e-7));
 | 
				
			||||||
  EXPECT(assert_equal(expected_error, error_tree, 1e-7));
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // regression test for probPrime
 | 
				
			||||||
 | 
					  const DecisionTreeFactor expectedFactor(
 | 
				
			||||||
 | 
					      s.modes, std::vector{0.36793249, 0.61247742, 0.59489556, 0.99029064});
 | 
				
			||||||
  auto probabilities = graph.probPrime(delta.continuous());
 | 
					  auto probabilities = graph.probPrime(delta.continuous());
 | 
				
			||||||
  std::vector<double> prob_leaves = {0.36793249, 0.61247742, 0.59489556,
 | 
					  EXPECT(assert_equal(expectedFactor, probabilities, 1e-7));
 | 
				
			||||||
                                     0.99029064};
 | 
					 | 
				
			||||||
  AlgebraicDecisionTree<Key> expected_probabilities(discrete_keys, prob_leaves);
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // regression
 | 
					  // regression test for discretePosterior
 | 
				
			||||||
  EXPECT(assert_equal(expected_probabilities, probabilities, 1e-7));
 | 
					  const DecisionTreeFactor normalized(
 | 
				
			||||||
 | 
					      s.modes, std::vector{0.14341014, 0.23872714, 0.23187421, 0.38598852});
 | 
				
			||||||
 | 
					  DiscreteConditional expectedPosterior(2, normalized);
 | 
				
			||||||
 | 
					  auto posterior = graph.discretePosterior(delta.continuous());
 | 
				
			||||||
 | 
					  EXPECT(assert_equal(expectedPosterior, posterior, 1e-7));
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/* ****************************************************************************/
 | 
					/* ****************************************************************************/
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue