diff --git a/gtsam/inference/Conditional-inst.h b/gtsam/inference/Conditional-inst.h index 1b439649e..8445b74bd 100644 --- a/gtsam/inference/Conditional-inst.h +++ b/gtsam/inference/Conditional-inst.h @@ -63,4 +63,22 @@ double Conditional::normalizationConstant() const { return std::exp(logNormalizationConstant()); } +/* ************************************************************************* */ +template +template +bool Conditional::CheckInvariants( + const DERIVEDCONDITIONAL& conditional, const VALUES& values) { + const double probability = conditional.evaluate(values); + if (probability < 0.0 || probability > 1.0) + return false; // probability is not in [0,1] + const double logProb = conditional.logProbability(values); + if (std::abs(probability - std::exp(logProb)) > 1e-9) + return false; // logProb is not consistent with probability + const double expected = + conditional.logNormalizationConstant() - conditional.error(values); + if (std::abs(logProb - expected) > 1e-9) + return false; // logProb is not consistent with error + return true; +} + } // namespace gtsam diff --git a/gtsam/inference/Conditional.h b/gtsam/inference/Conditional.h index bba4c7bd5..b4b1080aa 100644 --- a/gtsam/inference/Conditional.h +++ b/gtsam/inference/Conditional.h @@ -181,6 +181,10 @@ namespace gtsam { /** Mutable iterator pointing past the last parent key. */ typename FACTOR::iterator endParents() { return asFactor().end(); } + template + static bool CheckInvariants(const DERIVEDCONDITIONAL& conditional, + const VALUES& values); + /// @} private: diff --git a/gtsam/linear/tests/testGaussianConditional.cpp b/gtsam/linear/tests/testGaussianConditional.cpp index 12c668c25..0479ce9a1 100644 --- a/gtsam/linear/tests/testGaussianConditional.cpp +++ b/gtsam/linear/tests/testGaussianConditional.cpp @@ -135,23 +135,6 @@ static const auto unitPrior = noiseModel::Isotropic::Sigma(1, sigma)); } // namespace density -/* ************************************************************************* */ -template -bool checkInvariants(const GaussianConditional& conditional, - const VALUES& values) { - const double probability = conditional.evaluate(values); - if (probability < 0.0 || probability > 1.0) - return false; // probability is not in [0,1] - const double logProb = conditional.logProbability(values); - if (std::abs(probability - std::exp(logProb)) > 1e-9) - return false; // logProb is not consistent with probability - const double expected = - conditional.logNormalizationConstant() - conditional.error(values); - if (std::abs(logProb - expected) > 1e-9) - return false; // logProb is not consistent with error - return true; -} - /* ************************************************************************* */ // Check that the evaluate function matches direct calculation with R. TEST(GaussianConditional, Evaluate1) { @@ -174,8 +157,9 @@ TEST(GaussianConditional, Evaluate1) { // Check Invariants at the mean and a different value for (auto vv : {mean, VectorValues{{key, Vector1(4)}}}) { - EXPECT(checkInvariants(density::unitPrior, vv)); - EXPECT(checkInvariants(density::unitPrior, HybridValues{vv, {}, {}})); + EXPECT(GaussianConditional::CheckInvariants(density::unitPrior, vv)); + EXPECT(GaussianConditional::CheckInvariants(density::unitPrior, + HybridValues{vv, {}, {}})); } // Let's numerically integrate and see that we integrate to 1.0. @@ -206,8 +190,9 @@ TEST(GaussianConditional, Evaluate2) { // Check Invariants at the mean and a different value for (auto vv : {mean, VectorValues{{key, Vector1(4)}}}) { - EXPECT(checkInvariants(density::widerPrior, vv)); - EXPECT(checkInvariants(density::widerPrior, HybridValues{vv, {}, {}})); + EXPECT(GaussianConditional::CheckInvariants(density::widerPrior, vv)); + EXPECT(GaussianConditional::CheckInvariants(density::widerPrior, + HybridValues{vv, {}, {}})); } // Let's numerically integrate and see that we integrate to 1.0. @@ -422,8 +407,9 @@ TEST(GaussianConditional, FromMeanAndStddev) { // Check Invariants for both conditionals for (auto conditional : {conditional1, conditional2}) { - EXPECT(checkInvariants(conditional, values)); - EXPECT(checkInvariants(conditional, HybridValues{values, {}, {}})); + EXPECT(GaussianConditional::CheckInvariants(conditional, values)); + EXPECT(GaussianConditional::CheckInvariants(conditional, + HybridValues{values, {}, {}})); } }