From 11ef99b3f098343941b840fc5ce92f2aa241cb4b Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Tue, 10 Jan 2023 15:16:55 -0800 Subject: [PATCH] discrete now has all logProbability and evaluate versions needed. --- gtsam/discrete/DecisionTreeFactor.cpp | 11 +++++++++ gtsam/discrete/DecisionTreeFactor.h | 24 +++++++++++++++++-- gtsam/discrete/DiscreteConditional.h | 22 ++++++++++++++++- .../tests/testDiscreteConditional.cpp | 23 ++++++++++++++++++ 4 files changed, 77 insertions(+), 3 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 7f604086c..14a24b6e6 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -18,6 +18,7 @@ */ #include +#include #include #include @@ -56,6 +57,16 @@ namespace gtsam { } } + /* ************************************************************************ */ + double DecisionTreeFactor::error(const DiscreteValues& values) const { + return -std::log(evaluate(values)); + } + + /* ************************************************************************ */ + double DecisionTreeFactor::error(const HybridValues& values) const { + return error(values.discrete()); + } + /* ************************************************************************ */ double DecisionTreeFactor::safe_div(const double& a, const double& b) { // The use for safe_div is when we divide the product factor by the sum diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index f759c10f3..dd292cae8 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -34,6 +34,7 @@ namespace gtsam { class DiscreteConditional; + class HybridValues; /** * A discrete probabilistic factor. @@ -97,11 +98,20 @@ namespace gtsam { /// @name Standard Interface /// @{ - /// Value is just look up in AlgebraicDecisionTree. + /// Calculate probability for given values `x`, + /// is just look up in AlgebraicDecisionTree. + double evaluate(const DiscreteValues& values) const { + return ADT::operator()(values); + } + + /// Evaluate probability density, sugar. double operator()(const DiscreteValues& values) const override { return ADT::operator()(values); } + /// Calculate error for DiscreteValues `x`, is -log(probability). + double error(const DiscreteValues& values) const; + /// multiply two factors DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override { return apply(f, ADT::Ring::mul); @@ -230,7 +240,17 @@ namespace gtsam { std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, const Names& names = {}) const override; - /// @} + /// @} + /// @name HybridValues methods. + /// @{ + + /** + * Calculate error for HybridValues `x`, is -log(probability) + * Simply dispatches to DiscreteValues version. + */ + double error(const HybridValues& values) const override; + + /// @} private: /** Serialization function */ diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index b68953eb5..94451d407 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -18,9 +18,9 @@ #pragma once +#include #include #include -#include #include #include @@ -147,6 +147,11 @@ class GTSAM_EXPORT DiscreteConditional /// @name Standard Interface /// @{ + /// Log-probability is just -error(x). + double logProbability(const DiscreteValues& x) const { + return -error(x); + } + /// print index signature only void printSignature( const std::string& s = "Discrete Conditional: ", @@ -225,6 +230,21 @@ class GTSAM_EXPORT DiscreteConditional std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, const Names& names = {}) const override; + + /// @} + /// @name HybridValues methods. + /// @{ + + /** + * Calculate log-probability log(evaluate(x)) for HybridValues `x`. + * This is actually just -error(x). + */ + double logProbability(const HybridValues& x) const override { + return -error(x); + } + + using DecisionTreeFactor::evaluate; + /// @} #ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index f4a2e30ea..fdfe4a145 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -88,6 +88,29 @@ TEST(DiscreteConditional, constructors3) { EXPECT(assert_equal(expected, static_cast(actual))); } +/* ****************************************************************************/ +// Test evaluate for a discrete Prior P(Asia). +TEST(DiscreteConditional, PriorProbability) { + constexpr Key asiaKey = 0; + const DiscreteKey Asia(asiaKey, 2); + DiscreteConditional dc(Asia, "4/6"); + DiscreteValues values{{asiaKey, 0}}; + EXPECT_DOUBLES_EQUAL(0.4, dc.evaluate(values), 1e-9); +} + +/* ************************************************************************* */ +// Check that error, logProbability, evaluate all work as expected. +TEST(DiscreteConditional, probability) { + DiscreteKey C(2, 2), D(4, 2), E(3, 2); + DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4"); + + DiscreteValues given {{C.first, 1}, {D.first, 0}, {E.first, 0}}; + EXPECT_DOUBLES_EQUAL(0.2, C_given_DE.evaluate(given), 1e-9); + EXPECT_DOUBLES_EQUAL(0.2, C_given_DE(given), 1e-9); + EXPECT_DOUBLES_EQUAL(log(0.2), C_given_DE.logProbability(given), 1e-9); + EXPECT_DOUBLES_EQUAL(-log(0.2), C_given_DE.error(given), 1e-9); +} + /* ************************************************************************* */ // Check calculation of joint P(A,B) TEST(DiscreteConditional, Multiply) {