discrete now has all logProbability and evaluate versions needed.

release/4.3a0
Frank Dellaert 2023-01-10 15:16:55 -08:00
parent f89ef731a5
commit 11ef99b3f0
4 changed files with 77 additions and 3 deletions

View File

@ -18,6 +18,7 @@
*/
#include <gtsam/base/FastSet.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h>
@ -56,6 +57,16 @@ namespace gtsam {
}
}
/* ************************************************************************ */
double DecisionTreeFactor::error(const DiscreteValues& values) const {
return -std::log(evaluate(values));
}
/* ************************************************************************ */
double DecisionTreeFactor::error(const HybridValues& values) const {
return error(values.discrete());
}
/* ************************************************************************ */
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum

View File

@ -34,6 +34,7 @@
namespace gtsam {
class DiscreteConditional;
class HybridValues;
/**
* A discrete probabilistic factor.
@ -97,11 +98,20 @@ namespace gtsam {
/// @name Standard Interface
/// @{
/// Value is just look up in AlgebraicDecisionTree.
/// Calculate probability for given values `x`,
/// is just look up in AlgebraicDecisionTree.
double evaluate(const DiscreteValues& values) const {
return ADT::operator()(values);
}
/// Evaluate probability density, sugar.
double operator()(const DiscreteValues& values) const override {
return ADT::operator()(values);
}
/// Calculate error for DiscreteValues `x`, is -log(probability).
double error(const DiscreteValues& values) const;
/// multiply two factors
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
return apply(f, ADT::Ring::mul);
@ -230,7 +240,17 @@ namespace gtsam {
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const override;
/// @}
/// @}
/// @name HybridValues methods.
/// @{
/**
* Calculate error for HybridValues `x`, is -log(probability)
* Simply dispatches to DiscreteValues version.
*/
double error(const HybridValues& values) const override;
/// @}
private:
/** Serialization function */

View File

@ -18,9 +18,9 @@
#pragma once
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/Signature.h>
#include <gtsam/inference/Conditional.h>
#include <boost/make_shared.hpp>
#include <boost/shared_ptr.hpp>
@ -147,6 +147,11 @@ class GTSAM_EXPORT DiscreteConditional
/// @name Standard Interface
/// @{
/// Log-probability is just -error(x).
double logProbability(const DiscreteValues& x) const {
return -error(x);
}
/// print index signature only
void printSignature(
const std::string& s = "Discrete Conditional: ",
@ -225,6 +230,21 @@ class GTSAM_EXPORT DiscreteConditional
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const override;
/// @}
/// @name HybridValues methods.
/// @{
/**
* Calculate log-probability log(evaluate(x)) for HybridValues `x`.
* This is actually just -error(x).
*/
double logProbability(const HybridValues& x) const override {
return -error(x);
}
using DecisionTreeFactor::evaluate;
/// @}
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42

View File

@ -88,6 +88,29 @@ TEST(DiscreteConditional, constructors3) {
EXPECT(assert_equal(expected, static_cast<DecisionTreeFactor>(actual)));
}
/* ****************************************************************************/
// Test evaluate for a discrete Prior P(Asia).
TEST(DiscreteConditional, PriorProbability) {
constexpr Key asiaKey = 0;
const DiscreteKey Asia(asiaKey, 2);
DiscreteConditional dc(Asia, "4/6");
DiscreteValues values{{asiaKey, 0}};
EXPECT_DOUBLES_EQUAL(0.4, dc.evaluate(values), 1e-9);
}
/* ************************************************************************* */
// Check that error, logProbability, evaluate all work as expected.
TEST(DiscreteConditional, probability) {
DiscreteKey C(2, 2), D(4, 2), E(3, 2);
DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4");
DiscreteValues given {{C.first, 1}, {D.first, 0}, {E.first, 0}};
EXPECT_DOUBLES_EQUAL(0.2, C_given_DE.evaluate(given), 1e-9);
EXPECT_DOUBLES_EQUAL(0.2, C_given_DE(given), 1e-9);
EXPECT_DOUBLES_EQUAL(log(0.2), C_given_DE.logProbability(given), 1e-9);
EXPECT_DOUBLES_EQUAL(-log(0.2), C_given_DE.error(given), 1e-9);
}
/* ************************************************************************* */
// Check calculation of joint P(A,B)
TEST(DiscreteConditional, Multiply) {