Merge pull request #1669 from borglab/discrete-error
						commit
						42b5218662
					
				| 
						 | 
				
			
			@ -62,6 +62,22 @@ namespace gtsam {
 | 
			
		|||
    return error(values.discrete());
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /* ************************************************************************ */
 | 
			
		||||
  AlgebraicDecisionTree<Key> DecisionTreeFactor::errorTree() const {
 | 
			
		||||
    // Get all possible assignments
 | 
			
		||||
    DiscreteKeys dkeys = discreteKeys();
 | 
			
		||||
    // Reverse to make cartesian product output a more natural ordering.
 | 
			
		||||
    DiscreteKeys rdkeys(dkeys.rbegin(), dkeys.rend());
 | 
			
		||||
    const auto assignments = DiscreteValues::CartesianProduct(rdkeys);
 | 
			
		||||
 | 
			
		||||
    // Construct vector with error values
 | 
			
		||||
    std::vector<double> errors;
 | 
			
		||||
    for (const auto& assignment : assignments) {
 | 
			
		||||
      errors.push_back(error(assignment));
 | 
			
		||||
    }
 | 
			
		||||
    return AlgebraicDecisionTree<Key>(dkeys, errors);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /* ************************************************************************ */
 | 
			
		||||
  double DecisionTreeFactor::safe_div(const double& a, const double& b) {
 | 
			
		||||
    // The use for safe_div is when we divide the product factor by the sum
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -292,6 +292,9 @@ namespace gtsam {
 | 
			
		|||
   */
 | 
			
		||||
  double error(const HybridValues& values) const override;
 | 
			
		||||
 | 
			
		||||
  /// Compute error for each assignment and return as a tree
 | 
			
		||||
  AlgebraicDecisionTree<Key> errorTree() const override;
 | 
			
		||||
 | 
			
		||||
  /// @}
 | 
			
		||||
 | 
			
		||||
   private:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -18,9 +18,10 @@
 | 
			
		|||
 | 
			
		||||
#pragma once
 | 
			
		||||
 | 
			
		||||
#include <gtsam/base/Testable.h>
 | 
			
		||||
#include <gtsam/discrete/AlgebraicDecisionTree.h>
 | 
			
		||||
#include <gtsam/discrete/DiscreteValues.h>
 | 
			
		||||
#include <gtsam/inference/Factor.h>
 | 
			
		||||
#include <gtsam/base/Testable.h>
 | 
			
		||||
 | 
			
		||||
#include <string>
 | 
			
		||||
namespace gtsam {
 | 
			
		||||
| 
						 | 
				
			
			@ -35,7 +36,7 @@ class HybridValues;
 | 
			
		|||
 *
 | 
			
		||||
 * @ingroup discrete
 | 
			
		||||
 */
 | 
			
		||||
class GTSAM_EXPORT DiscreteFactor: public Factor {
 | 
			
		||||
class GTSAM_EXPORT DiscreteFactor : public Factor {
 | 
			
		||||
 public:
 | 
			
		||||
  // typedefs needed to play nice with gtsam
 | 
			
		||||
  typedef DiscreteFactor This;  ///< This class
 | 
			
		||||
| 
						 | 
				
			
			@ -103,7 +104,11 @@ class GTSAM_EXPORT DiscreteFactor: public Factor {
 | 
			
		|||
   */
 | 
			
		||||
  double error(const HybridValues& c) const override;
 | 
			
		||||
 | 
			
		||||
  /// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor
 | 
			
		||||
  /// Compute error for each assignment and return as a tree
 | 
			
		||||
  virtual AlgebraicDecisionTree<Key> errorTree() const = 0;
 | 
			
		||||
 | 
			
		||||
  /// Multiply in a DecisionTreeFactor and return the result as
 | 
			
		||||
  /// DecisionTreeFactor
 | 
			
		||||
  virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
 | 
			
		||||
 | 
			
		||||
  virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
 | 
			
		||||
| 
						 | 
				
			
			@ -111,7 +116,7 @@ class GTSAM_EXPORT DiscreteFactor: public Factor {
 | 
			
		|||
  /// @}
 | 
			
		||||
  /// @name Wrapper support
 | 
			
		||||
  /// @{
 | 
			
		||||
  
 | 
			
		||||
 | 
			
		||||
  /// Translation table from values to strings.
 | 
			
		||||
  using Names = DiscreteValues::Names;
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -175,4 +180,4 @@ template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
 | 
			
		|||
std::vector<double> expNormalize(const std::vector<double> &logProbs);
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
}// namespace gtsam
 | 
			
		||||
}  // namespace gtsam
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -168,6 +168,11 @@ double TableFactor::error(const HybridValues& values) const {
 | 
			
		|||
  return error(values.discrete());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************ */
 | 
			
		||||
AlgebraicDecisionTree<Key> TableFactor::errorTree() const {
 | 
			
		||||
  return toDecisionTreeFactor().errorTree();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************ */
 | 
			
		||||
DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
 | 
			
		||||
  return toDecisionTreeFactor() * f;
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -358,6 +358,9 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
 | 
			
		|||
   */
 | 
			
		||||
  double error(const HybridValues& values) const override;
 | 
			
		||||
 | 
			
		||||
  /// Compute error for each assignment and return as a tree
 | 
			
		||||
  AlgebraicDecisionTree<Key> errorTree() const override;
 | 
			
		||||
 | 
			
		||||
  /// @}
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -67,6 +67,24 @@ TEST( DecisionTreeFactor, constructors)
 | 
			
		|||
  EXPECT_DOUBLES_EQUAL(0.8, f4(x121), 1e-9);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
TEST(DecisionTreeFactor, Error) {
 | 
			
		||||
  // Declare a bunch of keys
 | 
			
		||||
  DiscreteKey X(0,2), Y(1,3), Z(2,2);
 | 
			
		||||
 | 
			
		||||
  // Create factors
 | 
			
		||||
  DecisionTreeFactor f(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");
 | 
			
		||||
 | 
			
		||||
  auto errors = f.errorTree();
 | 
			
		||||
  // regression
 | 
			
		||||
  AlgebraicDecisionTree<Key> expected(
 | 
			
		||||
      {X, Y, Z},
 | 
			
		||||
      vector<double>{-0.69314718, -1.6094379, -1.0986123, -1.7917595,
 | 
			
		||||
                     -1.3862944, -1.9459101, -3.2188758, -4.0073332, -3.5553481,
 | 
			
		||||
                     -4.1743873, -3.8066625, -4.3174881});
 | 
			
		||||
  EXPECT(assert_equal(expected, errors, 1e-6));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
TEST(DecisionTreeFactor, multiplication) {
 | 
			
		||||
  DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2);
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -313,14 +313,14 @@ AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
/* *******************************************************************************/
 | 
			
		||||
AlgebraicDecisionTree<Key> GaussianMixture::error(
 | 
			
		||||
AlgebraicDecisionTree<Key> GaussianMixture::errorTree(
 | 
			
		||||
    const VectorValues &continuousValues) const {
 | 
			
		||||
  auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) {
 | 
			
		||||
    return conditional->error(continuousValues) +  //
 | 
			
		||||
           logConstant_ - conditional->logNormalizationConstant();
 | 
			
		||||
  };
 | 
			
		||||
  DecisionTree<Key, double> errorTree(conditionals_, errorFunc);
 | 
			
		||||
  return errorTree;
 | 
			
		||||
  DecisionTree<Key, double> error_tree(conditionals_, errorFunc);
 | 
			
		||||
  return error_tree;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* *******************************************************************************/
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -214,7 +214,7 @@ class GTSAM_EXPORT GaussianMixture
 | 
			
		|||
   * @return AlgebraicDecisionTree<Key> A decision tree on the discrete keys
 | 
			
		||||
   * only, with the leaf values as the error for each assignment.
 | 
			
		||||
   */
 | 
			
		||||
  AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
 | 
			
		||||
  AlgebraicDecisionTree<Key> errorTree(const VectorValues &continuousValues) const;
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * @brief Compute the logProbability of this Gaussian Mixture.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -102,14 +102,14 @@ GaussianFactorGraphTree GaussianMixtureFactor::asGaussianFactorGraphTree()
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
/* *******************************************************************************/
 | 
			
		||||
AlgebraicDecisionTree<Key> GaussianMixtureFactor::error(
 | 
			
		||||
AlgebraicDecisionTree<Key> GaussianMixtureFactor::errorTree(
 | 
			
		||||
    const VectorValues &continuousValues) const {
 | 
			
		||||
  // functor to convert from sharedFactor to double error value.
 | 
			
		||||
  auto errorFunc = [&continuousValues](const sharedFactor &gf) {
 | 
			
		||||
    return gf->error(continuousValues);
 | 
			
		||||
  };
 | 
			
		||||
  DecisionTree<Key, double> errorTree(factors_, errorFunc);
 | 
			
		||||
  return errorTree;
 | 
			
		||||
  DecisionTree<Key, double> error_tree(factors_, errorFunc);
 | 
			
		||||
  return error_tree;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* *******************************************************************************/
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -135,7 +135,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
 | 
			
		|||
   * @return AlgebraicDecisionTree<Key> A decision tree with the same keys
 | 
			
		||||
   * as the factors involved, and leaf values as the error.
 | 
			
		||||
   */
 | 
			
		||||
  AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
 | 
			
		||||
  AlgebraicDecisionTree<Key> errorTree(const VectorValues &continuousValues) const;
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * @brief Compute the log-likelihood, including the log-normalizing constant.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -420,7 +420,7 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************ */
 | 
			
		||||
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
 | 
			
		||||
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::errorTree(
 | 
			
		||||
    const VectorValues &continuousValues) const {
 | 
			
		||||
  AlgebraicDecisionTree<Key> error_tree(0.0);
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -431,7 +431,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
 | 
			
		|||
 | 
			
		||||
    if (auto gaussianMixture = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
 | 
			
		||||
      // Compute factor error and add it.
 | 
			
		||||
      error_tree = error_tree + gaussianMixture->error(continuousValues);
 | 
			
		||||
      error_tree = error_tree + gaussianMixture->errorTree(continuousValues);
 | 
			
		||||
    } else if (auto gaussian = dynamic_pointer_cast<GaussianFactor>(f)) {
 | 
			
		||||
      // If continuous only, get the (double) error
 | 
			
		||||
      // and add it to the error_tree
 | 
			
		||||
| 
						 | 
				
			
			@ -460,7 +460,7 @@ double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const {
 | 
			
		|||
/* ************************************************************************ */
 | 
			
		||||
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
 | 
			
		||||
    const VectorValues &continuousValues) const {
 | 
			
		||||
  AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues);
 | 
			
		||||
  AlgebraicDecisionTree<Key> error_tree = this->errorTree(continuousValues);
 | 
			
		||||
  AlgebraicDecisionTree<Key> prob_tree = error_tree.apply([](double error) {
 | 
			
		||||
    // NOTE: The 0.5 term is handled by each factor
 | 
			
		||||
    return exp(-error);
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -161,7 +161,8 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
 | 
			
		|||
   * @param continuousValues Continuous values at which to compute the error.
 | 
			
		||||
   * @return AlgebraicDecisionTree<Key>
 | 
			
		||||
   */
 | 
			
		||||
  AlgebraicDecisionTree<Key> error(const VectorValues& continuousValues) const;
 | 
			
		||||
  AlgebraicDecisionTree<Key> errorTree(
 | 
			
		||||
      const VectorValues& continuousValues) const;
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -131,13 +131,13 @@ class MixtureFactor : public HybridFactor {
 | 
			
		|||
   * @return AlgebraicDecisionTree<Key> A decision tree with the same keys
 | 
			
		||||
   * as the factor, and leaf values as the error.
 | 
			
		||||
   */
 | 
			
		||||
  AlgebraicDecisionTree<Key> error(const Values& continuousValues) const {
 | 
			
		||||
  AlgebraicDecisionTree<Key> errorTree(const Values& continuousValues) const {
 | 
			
		||||
    // functor to convert from sharedFactor to double error value.
 | 
			
		||||
    auto errorFunc = [continuousValues](const sharedFactor& factor) {
 | 
			
		||||
      return factor->error(continuousValues);
 | 
			
		||||
    };
 | 
			
		||||
    DecisionTree<Key, double> errorTree(factors_, errorFunc);
 | 
			
		||||
    return errorTree;
 | 
			
		||||
    DecisionTree<Key, double> result(factors_, errorFunc);
 | 
			
		||||
    return result;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -97,7 +97,7 @@ TEST(GaussianMixture, LogProbability) {
 | 
			
		|||
/// Check error.
 | 
			
		||||
TEST(GaussianMixture, Error) {
 | 
			
		||||
  using namespace equal_constants;
 | 
			
		||||
  auto actual = mixture.error(vv);
 | 
			
		||||
  auto actual = mixture.errorTree(vv);
 | 
			
		||||
 | 
			
		||||
  // Check result.
 | 
			
		||||
  std::vector<DiscreteKey> discrete_keys = {mode};
 | 
			
		||||
| 
						 | 
				
			
			@ -134,7 +134,7 @@ TEST(GaussianMixture, Likelihood) {
 | 
			
		|||
  std::vector<double> leaves = {conditionals[0]->likelihood(vv)->error(vv),
 | 
			
		||||
                                conditionals[1]->likelihood(vv)->error(vv)};
 | 
			
		||||
  AlgebraicDecisionTree<Key> expected(discrete_keys, leaves);
 | 
			
		||||
  EXPECT(assert_equal(expected, likelihood->error(vv), 1e-6));
 | 
			
		||||
  EXPECT(assert_equal(expected, likelihood->errorTree(vv), 1e-6));
 | 
			
		||||
 | 
			
		||||
  // Check that the ratio of probPrime to evaluate is the same for all modes.
 | 
			
		||||
  std::vector<double> ratio(2);
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -178,7 +178,7 @@ TEST(GaussianMixtureFactor, Error) {
 | 
			
		|||
  continuousValues.insert(X(2), Vector2(1, 1));
 | 
			
		||||
 | 
			
		||||
  // error should return a tree of errors, with nodes for each discrete value.
 | 
			
		||||
  AlgebraicDecisionTree<Key> error_tree = mixtureFactor.error(continuousValues);
 | 
			
		||||
  AlgebraicDecisionTree<Key> error_tree = mixtureFactor.errorTree(continuousValues);
 | 
			
		||||
 | 
			
		||||
  std::vector<DiscreteKey> discrete_keys = {m1};
 | 
			
		||||
  // Error values for regression test
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -580,7 +580,7 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) {
 | 
			
		|||
  HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential();
 | 
			
		||||
 | 
			
		||||
  HybridValues delta = hybridBayesNet->optimize();
 | 
			
		||||
  auto error_tree = graph.error(delta.continuous());
 | 
			
		||||
  auto error_tree = graph.errorTree(delta.continuous());
 | 
			
		||||
 | 
			
		||||
  std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
 | 
			
		||||
  std::vector<double> leaves = {0.9998558, 0.4902432, 0.5193694, 0.0097568};
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -97,7 +97,8 @@ TEST(MixtureFactor, Error) {
 | 
			
		|||
  continuousValues.insert<double>(X(1), 0);
 | 
			
		||||
  continuousValues.insert<double>(X(2), 1);
 | 
			
		||||
 | 
			
		||||
  AlgebraicDecisionTree<Key> error_tree = mixtureFactor.error(continuousValues);
 | 
			
		||||
  AlgebraicDecisionTree<Key> error_tree =
 | 
			
		||||
      mixtureFactor.errorTree(continuousValues);
 | 
			
		||||
 | 
			
		||||
  DiscreteKey m1(1, 2);
 | 
			
		||||
  std::vector<DiscreteKey> discrete_keys = {m1};
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -53,6 +53,11 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
 | 
			
		|||
  /// Multiply into a decisiontree
 | 
			
		||||
  DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
 | 
			
		||||
 | 
			
		||||
  /// Compute error for each assignment and return as a tree
 | 
			
		||||
  AlgebraicDecisionTree<Key> errorTree() const override {
 | 
			
		||||
    throw std::runtime_error("AllDiff::error not implemented");
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /*
 | 
			
		||||
   * Ensure Arc-consistency by checking every possible value of domain j.
 | 
			
		||||
   * @param j domain to be checked
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -91,6 +91,11 @@ class BinaryAllDiff : public Constraint {
 | 
			
		|||
      const Domains&) const override {
 | 
			
		||||
    throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented");
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /// Compute error for each assignment and return as a tree
 | 
			
		||||
  AlgebraicDecisionTree<Key> errorTree() const override {
 | 
			
		||||
    throw std::runtime_error("BinaryAllDiff::error not implemented");
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
}  // namespace gtsam
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -69,6 +69,11 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
 | 
			
		|||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /// Compute error for each assignment and return as a tree
 | 
			
		||||
  AlgebraicDecisionTree<Key> errorTree() const override {
 | 
			
		||||
    throw std::runtime_error("Domain::error not implemented");
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Return concise string representation, mostly to debug arc consistency.
 | 
			
		||||
  // Converts from base 0 to base1.
 | 
			
		||||
  std::string base1Str() const;
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -49,6 +49,11 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
 | 
			
		|||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /// Compute error for each assignment and return as a tree
 | 
			
		||||
  AlgebraicDecisionTree<Key> errorTree() const override {
 | 
			
		||||
    throw std::runtime_error("SingleValue::error not implemented");
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /// Calculate value
 | 
			
		||||
  double operator()(const DiscreteValues& values) const override;
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue