Made CheckInvariants a static method in Conditional.*
							parent
							
								
									b99d464049
								
							
						
					
					
						commit
						cd2d37e724
					
				|  | @ -63,4 +63,22 @@ double Conditional<FACTOR, DERIVEDCONDITIONAL>::normalizationConstant() const { | ||||||
|   return std::exp(logNormalizationConstant()); |   return std::exp(logNormalizationConstant()); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | /* ************************************************************************* */ | ||||||
|  | template <class FACTOR, class DERIVEDCONDITIONAL> | ||||||
|  | template <class VALUES> | ||||||
|  | bool Conditional<FACTOR, DERIVEDCONDITIONAL>::CheckInvariants( | ||||||
|  |     const DERIVEDCONDITIONAL& conditional, const VALUES& values) { | ||||||
|  |   const double probability = conditional.evaluate(values); | ||||||
|  |   if (probability < 0.0 || probability > 1.0) | ||||||
|  |     return false;  // probability is not in [0,1]
 | ||||||
|  |   const double logProb = conditional.logProbability(values); | ||||||
|  |   if (std::abs(probability - std::exp(logProb)) > 1e-9) | ||||||
|  |     return false;  // logProb is not consistent with probability
 | ||||||
|  |   const double expected = | ||||||
|  |       conditional.logNormalizationConstant() - conditional.error(values); | ||||||
|  |   if (std::abs(logProb - expected) > 1e-9) | ||||||
|  |     return false;  // logProb is not consistent with error
 | ||||||
|  |   return true; | ||||||
|  | } | ||||||
|  | 
 | ||||||
| }  // namespace gtsam
 | }  // namespace gtsam
 | ||||||
|  |  | ||||||
|  | @ -181,6 +181,10 @@ namespace gtsam { | ||||||
|     /** Mutable iterator pointing past the last parent key. */ |     /** Mutable iterator pointing past the last parent key. */ | ||||||
|     typename FACTOR::iterator endParents() { return asFactor().end(); } |     typename FACTOR::iterator endParents() { return asFactor().end(); } | ||||||
| 
 | 
 | ||||||
|  |     template <class VALUES> | ||||||
|  |     static bool CheckInvariants(const DERIVEDCONDITIONAL& conditional, | ||||||
|  |                                 const VALUES& values); | ||||||
|  | 
 | ||||||
|     /// @}
 |     /// @}
 | ||||||
| 
 | 
 | ||||||
|   private: |   private: | ||||||
|  |  | ||||||
|  | @ -135,23 +135,6 @@ static const auto unitPrior = | ||||||
|                       noiseModel::Isotropic::Sigma(1, sigma)); |                       noiseModel::Isotropic::Sigma(1, sigma)); | ||||||
| }  // namespace density
 | }  // namespace density
 | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************* */ |  | ||||||
| template <class VALUES> |  | ||||||
| bool checkInvariants(const GaussianConditional& conditional, |  | ||||||
|                      const VALUES& values) { |  | ||||||
|   const double probability = conditional.evaluate(values); |  | ||||||
|   if (probability < 0.0 || probability > 1.0) |  | ||||||
|     return false;  // probability is not in [0,1]
 |  | ||||||
|   const double logProb = conditional.logProbability(values); |  | ||||||
|   if (std::abs(probability - std::exp(logProb)) > 1e-9) |  | ||||||
|     return false;  // logProb is not consistent with probability
 |  | ||||||
|   const double expected = |  | ||||||
|       conditional.logNormalizationConstant() - conditional.error(values); |  | ||||||
|   if (std::abs(logProb - expected) > 1e-9) |  | ||||||
|     return false;  // logProb is not consistent with error
 |  | ||||||
|   return true; |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
| // Check that the evaluate function matches direct calculation with R.
 | // Check that the evaluate function matches direct calculation with R.
 | ||||||
| TEST(GaussianConditional, Evaluate1) { | TEST(GaussianConditional, Evaluate1) { | ||||||
|  | @ -174,8 +157,9 @@ TEST(GaussianConditional, Evaluate1) { | ||||||
| 
 | 
 | ||||||
|   // Check Invariants at the mean and a different value
 |   // Check Invariants at the mean and a different value
 | ||||||
|   for (auto vv : {mean, VectorValues{{key, Vector1(4)}}}) { |   for (auto vv : {mean, VectorValues{{key, Vector1(4)}}}) { | ||||||
|     EXPECT(checkInvariants(density::unitPrior, vv)); |     EXPECT(GaussianConditional::CheckInvariants(density::unitPrior, vv)); | ||||||
|     EXPECT(checkInvariants(density::unitPrior, HybridValues{vv, {}, {}})); |     EXPECT(GaussianConditional::CheckInvariants(density::unitPrior, | ||||||
|  |                                                 HybridValues{vv, {}, {}})); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   // Let's numerically integrate and see that we integrate to 1.0.
 |   // Let's numerically integrate and see that we integrate to 1.0.
 | ||||||
|  | @ -206,8 +190,9 @@ TEST(GaussianConditional, Evaluate2) { | ||||||
| 
 | 
 | ||||||
|   // Check Invariants at the mean and a different value
 |   // Check Invariants at the mean and a different value
 | ||||||
|   for (auto vv : {mean, VectorValues{{key, Vector1(4)}}}) { |   for (auto vv : {mean, VectorValues{{key, Vector1(4)}}}) { | ||||||
|     EXPECT(checkInvariants(density::widerPrior, vv)); |     EXPECT(GaussianConditional::CheckInvariants(density::widerPrior, vv)); | ||||||
|     EXPECT(checkInvariants(density::widerPrior, HybridValues{vv, {}, {}})); |     EXPECT(GaussianConditional::CheckInvariants(density::widerPrior, | ||||||
|  |                                                 HybridValues{vv, {}, {}})); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   // Let's numerically integrate and see that we integrate to 1.0.
 |   // Let's numerically integrate and see that we integrate to 1.0.
 | ||||||
|  | @ -422,8 +407,9 @@ TEST(GaussianConditional, FromMeanAndStddev) { | ||||||
| 
 | 
 | ||||||
|   // Check Invariants for both conditionals
 |   // Check Invariants for both conditionals
 | ||||||
|   for (auto conditional : {conditional1, conditional2}) { |   for (auto conditional : {conditional1, conditional2}) { | ||||||
|     EXPECT(checkInvariants(conditional, values)); |     EXPECT(GaussianConditional::CheckInvariants(conditional, values)); | ||||||
|     EXPECT(checkInvariants(conditional, HybridValues{values, {}, {}})); |     EXPECT(GaussianConditional::CheckInvariants(conditional, | ||||||
|  |                                                 HybridValues{values, {}, {}})); | ||||||
|   } |   } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue