discrete now has all logProbability and evaluate versions needed.
parent
f89ef731a5
commit
11ef99b3f0
|
@ -18,6 +18,7 @@
|
|||
*/
|
||||
|
||||
#include <gtsam/base/FastSet.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
|
||||
|
@ -56,6 +57,16 @@ namespace gtsam {
|
|||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
double DecisionTreeFactor::error(const DiscreteValues& values) const {
|
||||
return -std::log(evaluate(values));
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
double DecisionTreeFactor::error(const HybridValues& values) const {
|
||||
return error(values.discrete());
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
|
||||
// The use for safe_div is when we divide the product factor by the sum
|
||||
|
|
|
@ -34,6 +34,7 @@
|
|||
namespace gtsam {
|
||||
|
||||
class DiscreteConditional;
|
||||
class HybridValues;
|
||||
|
||||
/**
|
||||
* A discrete probabilistic factor.
|
||||
|
@ -97,11 +98,20 @@ namespace gtsam {
|
|||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
/// Value is just look up in AlgebraicDecisionTree.
|
||||
/// Calculate probability for given values `x`,
|
||||
/// is just look up in AlgebraicDecisionTree.
|
||||
double evaluate(const DiscreteValues& values) const {
|
||||
return ADT::operator()(values);
|
||||
}
|
||||
|
||||
/// Evaluate probability density, sugar.
|
||||
double operator()(const DiscreteValues& values) const override {
|
||||
return ADT::operator()(values);
|
||||
}
|
||||
|
||||
/// Calculate error for DiscreteValues `x`, is -log(probability).
|
||||
double error(const DiscreteValues& values) const;
|
||||
|
||||
/// multiply two factors
|
||||
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
|
||||
return apply(f, ADT::Ring::mul);
|
||||
|
@ -230,7 +240,17 @@ namespace gtsam {
|
|||
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const Names& names = {}) const override;
|
||||
|
||||
/// @}
|
||||
/// @}
|
||||
/// @name HybridValues methods.
|
||||
/// @{
|
||||
|
||||
/**
|
||||
* Calculate error for HybridValues `x`, is -log(probability)
|
||||
* Simply dispatches to DiscreteValues version.
|
||||
*/
|
||||
double error(const HybridValues& values) const override;
|
||||
|
||||
/// @}
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
|
|
|
@ -18,9 +18,9 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/inference/Conditional-inst.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
#include <gtsam/inference/Conditional.h>
|
||||
|
||||
#include <boost/make_shared.hpp>
|
||||
#include <boost/shared_ptr.hpp>
|
||||
|
@ -147,6 +147,11 @@ class GTSAM_EXPORT DiscreteConditional
|
|||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
/// Log-probability is just -error(x).
|
||||
double logProbability(const DiscreteValues& x) const {
|
||||
return -error(x);
|
||||
}
|
||||
|
||||
/// print index signature only
|
||||
void printSignature(
|
||||
const std::string& s = "Discrete Conditional: ",
|
||||
|
@ -225,6 +230,21 @@ class GTSAM_EXPORT DiscreteConditional
|
|||
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const Names& names = {}) const override;
|
||||
|
||||
|
||||
/// @}
|
||||
/// @name HybridValues methods.
|
||||
/// @{
|
||||
|
||||
/**
|
||||
* Calculate log-probability log(evaluate(x)) for HybridValues `x`.
|
||||
* This is actually just -error(x).
|
||||
*/
|
||||
double logProbability(const HybridValues& x) const override {
|
||||
return -error(x);
|
||||
}
|
||||
|
||||
using DecisionTreeFactor::evaluate;
|
||||
|
||||
/// @}
|
||||
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||
|
|
|
@ -88,6 +88,29 @@ TEST(DiscreteConditional, constructors3) {
|
|||
EXPECT(assert_equal(expected, static_cast<DecisionTreeFactor>(actual)));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test evaluate for a discrete Prior P(Asia).
|
||||
TEST(DiscreteConditional, PriorProbability) {
|
||||
constexpr Key asiaKey = 0;
|
||||
const DiscreteKey Asia(asiaKey, 2);
|
||||
DiscreteConditional dc(Asia, "4/6");
|
||||
DiscreteValues values{{asiaKey, 0}};
|
||||
EXPECT_DOUBLES_EQUAL(0.4, dc.evaluate(values), 1e-9);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check that error, logProbability, evaluate all work as expected.
|
||||
TEST(DiscreteConditional, probability) {
|
||||
DiscreteKey C(2, 2), D(4, 2), E(3, 2);
|
||||
DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4");
|
||||
|
||||
DiscreteValues given {{C.first, 1}, {D.first, 0}, {E.first, 0}};
|
||||
EXPECT_DOUBLES_EQUAL(0.2, C_given_DE.evaluate(given), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(0.2, C_given_DE(given), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(log(0.2), C_given_DE.logProbability(given), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(-log(0.2), C_given_DE.error(given), 1e-9);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check calculation of joint P(A,B)
|
||||
TEST(DiscreteConditional, Multiply) {
|
||||
|
|
Loading…
Reference in New Issue