discrete now has all logProbability and evaluate versions needed.

release/4.3a0
Frank Dellaert 2023-01-10 15:16:55 -08:00
parent f89ef731a5
commit 11ef99b3f0
4 changed files with 77 additions and 3 deletions

View File

@ -18,6 +18,7 @@
*/ */
#include <gtsam/base/FastSet.h> #include <gtsam/base/FastSet.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h> #include <gtsam/discrete/DiscreteConditional.h>
@ -56,6 +57,16 @@ namespace gtsam {
} }
} }
/* ************************************************************************ */
double DecisionTreeFactor::error(const DiscreteValues& values) const {
return -std::log(evaluate(values));
}
/* ************************************************************************ */
double DecisionTreeFactor::error(const HybridValues& values) const {
return error(values.discrete());
}
/* ************************************************************************ */ /* ************************************************************************ */
double DecisionTreeFactor::safe_div(const double& a, const double& b) { double DecisionTreeFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum // The use for safe_div is when we divide the product factor by the sum

View File

@ -34,6 +34,7 @@
namespace gtsam { namespace gtsam {
class DiscreteConditional; class DiscreteConditional;
class HybridValues;
/** /**
* A discrete probabilistic factor. * A discrete probabilistic factor.
@ -97,11 +98,20 @@ namespace gtsam {
/// @name Standard Interface /// @name Standard Interface
/// @{ /// @{
/// Value is just look up in AlgebraicDecisionTree. /// Calculate probability for given values `x`,
/// is just look up in AlgebraicDecisionTree.
double evaluate(const DiscreteValues& values) const {
return ADT::operator()(values);
}
/// Evaluate probability density, sugar.
double operator()(const DiscreteValues& values) const override { double operator()(const DiscreteValues& values) const override {
return ADT::operator()(values); return ADT::operator()(values);
} }
/// Calculate error for DiscreteValues `x`, is -log(probability).
double error(const DiscreteValues& values) const;
/// multiply two factors /// multiply two factors
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override { DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
return apply(f, ADT::Ring::mul); return apply(f, ADT::Ring::mul);
@ -230,6 +240,16 @@ namespace gtsam {
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const override; const Names& names = {}) const override;
/// @}
/// @name HybridValues methods.
/// @{
/**
* Calculate error for HybridValues `x`, is -log(probability)
* Simply dispatches to DiscreteValues version.
*/
double error(const HybridValues& values) const override;
/// @} /// @}
private: private:

View File

@ -18,9 +18,9 @@
#pragma once #pragma once
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/Signature.h> #include <gtsam/discrete/Signature.h>
#include <gtsam/inference/Conditional.h>
#include <boost/make_shared.hpp> #include <boost/make_shared.hpp>
#include <boost/shared_ptr.hpp> #include <boost/shared_ptr.hpp>
@ -147,6 +147,11 @@ class GTSAM_EXPORT DiscreteConditional
/// @name Standard Interface /// @name Standard Interface
/// @{ /// @{
/// Log-probability is just -error(x).
double logProbability(const DiscreteValues& x) const {
return -error(x);
}
/// print index signature only /// print index signature only
void printSignature( void printSignature(
const std::string& s = "Discrete Conditional: ", const std::string& s = "Discrete Conditional: ",
@ -225,6 +230,21 @@ class GTSAM_EXPORT DiscreteConditional
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const override; const Names& names = {}) const override;
/// @}
/// @name HybridValues methods.
/// @{
/**
* Calculate log-probability log(evaluate(x)) for HybridValues `x`.
* This is actually just -error(x).
*/
double logProbability(const HybridValues& x) const override {
return -error(x);
}
using DecisionTreeFactor::evaluate;
/// @} /// @}
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 #ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42

View File

@ -88,6 +88,29 @@ TEST(DiscreteConditional, constructors3) {
EXPECT(assert_equal(expected, static_cast<DecisionTreeFactor>(actual))); EXPECT(assert_equal(expected, static_cast<DecisionTreeFactor>(actual)));
} }
/* ****************************************************************************/
// Test evaluate for a discrete Prior P(Asia).
TEST(DiscreteConditional, PriorProbability) {
constexpr Key asiaKey = 0;
const DiscreteKey Asia(asiaKey, 2);
DiscreteConditional dc(Asia, "4/6");
DiscreteValues values{{asiaKey, 0}};
EXPECT_DOUBLES_EQUAL(0.4, dc.evaluate(values), 1e-9);
}
/* ************************************************************************* */
// Check that error, logProbability, evaluate all work as expected.
TEST(DiscreteConditional, probability) {
DiscreteKey C(2, 2), D(4, 2), E(3, 2);
DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4");
DiscreteValues given {{C.first, 1}, {D.first, 0}, {E.first, 0}};
EXPECT_DOUBLES_EQUAL(0.2, C_given_DE.evaluate(given), 1e-9);
EXPECT_DOUBLES_EQUAL(0.2, C_given_DE(given), 1e-9);
EXPECT_DOUBLES_EQUAL(log(0.2), C_given_DE.logProbability(given), 1e-9);
EXPECT_DOUBLES_EQUAL(-log(0.2), C_given_DE.error(given), 1e-9);
}
/* ************************************************************************* */ /* ************************************************************************* */
// Check calculation of joint P(A,B) // Check calculation of joint P(A,B)
TEST(DiscreteConditional, Multiply) { TEST(DiscreteConditional, Multiply) {