discrete now has all logProbability and evaluate versions needed.
parent
f89ef731a5
commit
11ef99b3f0
|
@ -18,6 +18,7 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/base/FastSet.h>
|
#include <gtsam/base/FastSet.h>
|
||||||
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
#include <gtsam/discrete/DiscreteConditional.h>
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
|
|
||||||
|
@ -56,6 +57,16 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
double DecisionTreeFactor::error(const DiscreteValues& values) const {
|
||||||
|
return -std::log(evaluate(values));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
double DecisionTreeFactor::error(const HybridValues& values) const {
|
||||||
|
return error(values.discrete());
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
|
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
|
||||||
// The use for safe_div is when we divide the product factor by the sum
|
// The use for safe_div is when we divide the product factor by the sum
|
||||||
|
|
|
@ -34,6 +34,7 @@
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
class DiscreteConditional;
|
class DiscreteConditional;
|
||||||
|
class HybridValues;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A discrete probabilistic factor.
|
* A discrete probabilistic factor.
|
||||||
|
@ -97,11 +98,20 @@ namespace gtsam {
|
||||||
/// @name Standard Interface
|
/// @name Standard Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// Value is just look up in AlgebraicDecisionTree.
|
/// Calculate probability for given values `x`,
|
||||||
|
/// is just look up in AlgebraicDecisionTree.
|
||||||
|
double evaluate(const DiscreteValues& values) const {
|
||||||
|
return ADT::operator()(values);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Evaluate probability density, sugar.
|
||||||
double operator()(const DiscreteValues& values) const override {
|
double operator()(const DiscreteValues& values) const override {
|
||||||
return ADT::operator()(values);
|
return ADT::operator()(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Calculate error for DiscreteValues `x`, is -log(probability).
|
||||||
|
double error(const DiscreteValues& values) const;
|
||||||
|
|
||||||
/// multiply two factors
|
/// multiply two factors
|
||||||
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
|
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
|
||||||
return apply(f, ADT::Ring::mul);
|
return apply(f, ADT::Ring::mul);
|
||||||
|
@ -230,6 +240,16 @@ namespace gtsam {
|
||||||
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
const Names& names = {}) const override;
|
const Names& names = {}) const override;
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
/// @name HybridValues methods.
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculate error for HybridValues `x`, is -log(probability)
|
||||||
|
* Simply dispatches to DiscreteValues version.
|
||||||
|
*/
|
||||||
|
double error(const HybridValues& values) const override;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -18,9 +18,9 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <gtsam/inference/Conditional-inst.h>
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
#include <gtsam/discrete/Signature.h>
|
#include <gtsam/discrete/Signature.h>
|
||||||
#include <gtsam/inference/Conditional.h>
|
|
||||||
|
|
||||||
#include <boost/make_shared.hpp>
|
#include <boost/make_shared.hpp>
|
||||||
#include <boost/shared_ptr.hpp>
|
#include <boost/shared_ptr.hpp>
|
||||||
|
@ -147,6 +147,11 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
/// @name Standard Interface
|
/// @name Standard Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
/// Log-probability is just -error(x).
|
||||||
|
double logProbability(const DiscreteValues& x) const {
|
||||||
|
return -error(x);
|
||||||
|
}
|
||||||
|
|
||||||
/// print index signature only
|
/// print index signature only
|
||||||
void printSignature(
|
void printSignature(
|
||||||
const std::string& s = "Discrete Conditional: ",
|
const std::string& s = "Discrete Conditional: ",
|
||||||
|
@ -225,6 +230,21 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
const Names& names = {}) const override;
|
const Names& names = {}) const override;
|
||||||
|
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
/// @name HybridValues methods.
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculate log-probability log(evaluate(x)) for HybridValues `x`.
|
||||||
|
* This is actually just -error(x).
|
||||||
|
*/
|
||||||
|
double logProbability(const HybridValues& x) const override {
|
||||||
|
return -error(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
using DecisionTreeFactor::evaluate;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||||
|
|
|
@ -88,6 +88,29 @@ TEST(DiscreteConditional, constructors3) {
|
||||||
EXPECT(assert_equal(expected, static_cast<DecisionTreeFactor>(actual)));
|
EXPECT(assert_equal(expected, static_cast<DecisionTreeFactor>(actual)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ****************************************************************************/
|
||||||
|
// Test evaluate for a discrete Prior P(Asia).
|
||||||
|
TEST(DiscreteConditional, PriorProbability) {
|
||||||
|
constexpr Key asiaKey = 0;
|
||||||
|
const DiscreteKey Asia(asiaKey, 2);
|
||||||
|
DiscreteConditional dc(Asia, "4/6");
|
||||||
|
DiscreteValues values{{asiaKey, 0}};
|
||||||
|
EXPECT_DOUBLES_EQUAL(0.4, dc.evaluate(values), 1e-9);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check that error, logProbability, evaluate all work as expected.
|
||||||
|
TEST(DiscreteConditional, probability) {
|
||||||
|
DiscreteKey C(2, 2), D(4, 2), E(3, 2);
|
||||||
|
DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4");
|
||||||
|
|
||||||
|
DiscreteValues given {{C.first, 1}, {D.first, 0}, {E.first, 0}};
|
||||||
|
EXPECT_DOUBLES_EQUAL(0.2, C_given_DE.evaluate(given), 1e-9);
|
||||||
|
EXPECT_DOUBLES_EQUAL(0.2, C_given_DE(given), 1e-9);
|
||||||
|
EXPECT_DOUBLES_EQUAL(log(0.2), C_given_DE.logProbability(given), 1e-9);
|
||||||
|
EXPECT_DOUBLES_EQUAL(-log(0.2), C_given_DE.error(given), 1e-9);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// Check calculation of joint P(A,B)
|
// Check calculation of joint P(A,B)
|
||||||
TEST(DiscreteConditional, Multiply) {
|
TEST(DiscreteConditional, Multiply) {
|
||||||
|
|
Loading…
Reference in New Issue