make evaluate a common base method
							parent
							
								
									3add91d399
								
							
						
					
					
						commit
						88b36da602
					
				|  | @ -130,14 +130,9 @@ namespace gtsam { | ||||||
|     /// @name Standard Interface
 |     /// @name Standard Interface
 | ||||||
|     /// @{
 |     /// @{
 | ||||||
| 
 | 
 | ||||||
|     /// Calculate probability for given values `x`, 
 |     /// Calculate probability for given values, 
 | ||||||
|     /// is just look up in AlgebraicDecisionTree.
 |     /// is just look up in AlgebraicDecisionTree.
 | ||||||
|     double evaluate(const Assignment<Key>& values) const  { |     double operator()(const Assignment<Key>& values) const override { | ||||||
|       return ADT::operator()(values); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /// Evaluate probability distribution, sugar.
 |  | ||||||
|     double operator()(const DiscreteValues& values) const override { |  | ||||||
|       return ADT::operator()(values); |       return ADT::operator()(values); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -92,8 +92,21 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { | ||||||
| 
 | 
 | ||||||
|   size_t cardinality(Key j) const { return cardinalities_.at(j); } |   size_t cardinality(Key j) const { return cardinalities_.at(j); } | ||||||
| 
 | 
 | ||||||
|  |   /**
 | ||||||
|  |    * @brief Calculate probability for given values. | ||||||
|  |    * Calls specialized evaluation under the hood. | ||||||
|  |    * | ||||||
|  |    * Note: Uses Assignment<Key> as it is the base class of DiscreteValues. | ||||||
|  |    * | ||||||
|  |    * @param values Discrete assignment. | ||||||
|  |    * @return double | ||||||
|  |    */ | ||||||
|  |   double evaluate(const Assignment<Key>& values) const { | ||||||
|  |     return operator()(values); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   /// Find value for given assignment of values to variables
 |   /// Find value for given assignment of values to variables
 | ||||||
|   virtual double operator()(const DiscreteValues&) const = 0; |   virtual double operator()(const Assignment<Key>& values) const = 0; | ||||||
| 
 | 
 | ||||||
|   /// Error is just -log(value)
 |   /// Error is just -log(value)
 | ||||||
|   virtual double error(const DiscreteValues& values) const; |   virtual double error(const DiscreteValues& values) const; | ||||||
|  |  | ||||||
|  | @ -133,7 +133,7 @@ bool TableFactor::equals(const DiscreteFactor& other, double tol) const { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************ */ | /* ************************************************************************ */ | ||||||
| double TableFactor::operator()(const DiscreteValues& values) const { | double TableFactor::operator()(const Assignment<Key>& values) const { | ||||||
|   // a b c d => D * (C * (B * (a) + b) + c) + d
 |   // a b c d => D * (C * (B * (a) + b) + c) + d
 | ||||||
|   uint64_t idx = 0, card = 1; |   uint64_t idx = 0, card = 1; | ||||||
|   for (auto it = sorted_dkeys_.rbegin(); it != sorted_dkeys_.rend(); ++it) { |   for (auto it = sorted_dkeys_.rbegin(); it != sorted_dkeys_.rend(); ++it) { | ||||||
|  | @ -180,6 +180,7 @@ DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { | ||||||
|   for (auto i = 0; i < sparse_table_.size(); i++) { |   for (auto i = 0; i < sparse_table_.size(); i++) { | ||||||
|     table.push_back(sparse_table_.coeff(i)); |     table.push_back(sparse_table_.coeff(i)); | ||||||
|   } |   } | ||||||
|  |   // NOTE(Varun): This constructor is really expensive!!
 | ||||||
|   DecisionTreeFactor f(dkeys, table); |   DecisionTreeFactor f(dkeys, table); | ||||||
|   return f; |   return f; | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -155,14 +155,8 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { | ||||||
|   // /// @name Standard Interface
 |   // /// @name Standard Interface
 | ||||||
|   // /// @{
 |   // /// @{
 | ||||||
| 
 | 
 | ||||||
|   /// Calculate probability for given values `x`,
 |   /// Evaluate probability distribution, is just look up in TableFactor.
 | ||||||
|   /// is just look up in TableFactor.
 |   double operator()(const Assignment<Key>& values) const override; | ||||||
|   double evaluate(const DiscreteValues& values) const { |  | ||||||
|     return operator()(values); |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   /// Evaluate probability distribution, sugar.
 |  | ||||||
|   double operator()(const DiscreteValues& values) const override; |  | ||||||
| 
 | 
 | ||||||
|   /// Calculate error for DiscreteValues `x`, is -log(probability).
 |   /// Calculate error for DiscreteValues `x`, is -log(probability).
 | ||||||
|   double error(const DiscreteValues& values) const override; |   double error(const DiscreteValues& values) const override; | ||||||
|  |  | ||||||
|  | @ -26,7 +26,7 @@ void AllDiff::print(const std::string& s, const KeyFormatter& formatter) const { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
| double AllDiff::operator()(const DiscreteValues& values) const { | double AllDiff::operator()(const Assignment<Key>& values) const { | ||||||
|   std::set<size_t> taken;  // record values taken by keys
 |   std::set<size_t> taken;  // record values taken by keys
 | ||||||
|   for (Key dkey : keys_) { |   for (Key dkey : keys_) { | ||||||
|     size_t value = values.at(dkey);      // get the value for that key
 |     size_t value = values.at(dkey);      // get the value for that key
 | ||||||
|  |  | ||||||
|  | @ -45,7 +45,7 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   /// Calculate value = expensive !
 |   /// Calculate value = expensive !
 | ||||||
|   double operator()(const DiscreteValues& values) const override; |   double operator()(const Assignment<Key>& values) const override; | ||||||
| 
 | 
 | ||||||
|   /// Convert into a decisiontree, can be *very* expensive !
 |   /// Convert into a decisiontree, can be *very* expensive !
 | ||||||
|   DecisionTreeFactor toDecisionTreeFactor() const override; |   DecisionTreeFactor toDecisionTreeFactor() const override; | ||||||
|  |  | ||||||
|  | @ -47,7 +47,7 @@ class BinaryAllDiff : public Constraint { | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   /// Calculate value
 |   /// Calculate value
 | ||||||
|   double operator()(const DiscreteValues& values) const override { |   double operator()(const Assignment<Key>& values) const override { | ||||||
|     return (double)(values.at(keys_[0]) != values.at(keys_[1])); |     return (double)(values.at(keys_[0]) != values.at(keys_[1])); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -30,7 +30,7 @@ string Domain::base1Str() const { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
| double Domain::operator()(const DiscreteValues& values) const { | double Domain::operator()(const Assignment<Key>& values) const { | ||||||
|   return contains(values.at(key())); |   return contains(values.at(key())); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -82,7 +82,7 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { | ||||||
|   bool contains(size_t value) const { return values_.count(value) > 0; } |   bool contains(size_t value) const { return values_.count(value) > 0; } | ||||||
| 
 | 
 | ||||||
|   /// Calculate value
 |   /// Calculate value
 | ||||||
|   double operator()(const DiscreteValues& values) const override; |   double operator()(const Assignment<Key>& values) const override; | ||||||
| 
 | 
 | ||||||
|   /// Convert into a decisiontree
 |   /// Convert into a decisiontree
 | ||||||
|   DecisionTreeFactor toDecisionTreeFactor() const override; |   DecisionTreeFactor toDecisionTreeFactor() const override; | ||||||
|  |  | ||||||
|  | @ -22,7 +22,7 @@ void SingleValue::print(const string& s, const KeyFormatter& formatter) const { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
| double SingleValue::operator()(const DiscreteValues& values) const { | double SingleValue::operator()(const Assignment<Key>& values) const { | ||||||
|   return (double)(values.at(keys_[0]) == value_); |   return (double)(values.at(keys_[0]) == value_); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -55,7 +55,7 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   /// Calculate value
 |   /// Calculate value
 | ||||||
|   double operator()(const DiscreteValues& values) const override; |   double operator()(const Assignment<Key>& values) const override; | ||||||
| 
 | 
 | ||||||
|   /// Convert into a decisiontree
 |   /// Convert into a decisiontree
 | ||||||
|   DecisionTreeFactor toDecisionTreeFactor() const override; |   DecisionTreeFactor toDecisionTreeFactor() const override; | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue