likelihood
parent
dbe5c0fa81
commit
457d074858
|
@ -97,45 +97,90 @@ bool DiscreteConditional::equals(const DiscreteFactor& other,
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ******************************************************************************** */
|
||||||
Potentials::ADT DiscreteConditional::choose(
|
static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
|
||||||
const DiscreteValues& parentsValues) const {
|
const DiscreteValues& parentsValues) {
|
||||||
// Get the big decision tree with all the levels, and then go down the
|
// Get the big decision tree with all the levels, and then go down the
|
||||||
// branches based on the value of the parent variables.
|
// branches based on the value of the parent variables.
|
||||||
ADT pFS(*this);
|
DiscreteConditional::ADT adt(conditional);
|
||||||
size_t value;
|
size_t value;
|
||||||
for (Key j : parents()) {
|
for (Key j : conditional.parents()) {
|
||||||
try {
|
try {
|
||||||
value = parentsValues.at(j);
|
value = parentsValues.at(j);
|
||||||
pFS = pFS.choose(j, value); // ADT keeps getting smaller.
|
adt = adt.choose(j, value); // ADT keeps getting smaller.
|
||||||
} catch (exception&) {
|
} catch (exception&) {
|
||||||
cout << "Key: " << j << " Value: " << value << endl;
|
|
||||||
parentsValues.print("parentsValues: ");
|
parentsValues.print("parentsValues: ");
|
||||||
throw runtime_error("DiscreteConditional::choose: parent value missing");
|
throw runtime_error("DiscreteConditional::choose: parent value missing");
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
return pFS;
|
return adt;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ******************************************************************************** */
|
||||||
DecisionTreeFactor::shared_ptr DiscreteConditional::chooseAsFactor(
|
DecisionTreeFactor::shared_ptr DiscreteConditional::choose(
|
||||||
const DiscreteValues& parentsValues) const {
|
const DiscreteValues& parentsValues) const {
|
||||||
ADT pFS = choose(parentsValues);
|
// Get the big decision tree with all the levels, and then go down the
|
||||||
|
// branches based on the value of the parent variables.
|
||||||
|
ADT adt(*this);
|
||||||
|
size_t value;
|
||||||
|
for (Key j : parents()) {
|
||||||
|
try {
|
||||||
|
value = parentsValues.at(j);
|
||||||
|
adt = adt.choose(j, value); // ADT keeps getting smaller.
|
||||||
|
} catch (exception&) {
|
||||||
|
parentsValues.print("parentsValues: ");
|
||||||
|
throw runtime_error("DiscreteConditional::choose: parent value missing");
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
// Convert ADT to factor.
|
// Convert ADT to factor.
|
||||||
if (nrFrontals() != 1) {
|
DiscreteKeys discreteKeys;
|
||||||
throw std::runtime_error("Expected only one frontal variable in choose.");
|
for (Key j : frontals()) {
|
||||||
|
discreteKeys.emplace_back(j, this->cardinality(j));
|
||||||
}
|
}
|
||||||
DiscreteKeys keys;
|
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
|
||||||
const Key frontalKey = keys_[0];
|
}
|
||||||
size_t frontalCardinality = this->cardinality(frontalKey);
|
|
||||||
keys.push_back(DiscreteKey(frontalKey, frontalCardinality));
|
/* ******************************************************************************** */
|
||||||
return boost::make_shared<DecisionTreeFactor>(keys, pFS);
|
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||||
|
const DiscreteValues& frontalValues) const {
|
||||||
|
// Get the big decision tree with all the levels, and then go down the
|
||||||
|
// branches based on the value of the frontal variables.
|
||||||
|
ADT adt(*this);
|
||||||
|
size_t value;
|
||||||
|
for (Key j : frontals()) {
|
||||||
|
try {
|
||||||
|
value = frontalValues.at(j);
|
||||||
|
adt = adt.choose(j, value); // ADT keeps getting smaller.
|
||||||
|
} catch (exception&) {
|
||||||
|
frontalValues.print("frontalValues: ");
|
||||||
|
throw runtime_error("DiscreteConditional::choose: frontal value missing");
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert ADT to factor.
|
||||||
|
DiscreteKeys discreteKeys;
|
||||||
|
for (Key j : parents()) {
|
||||||
|
discreteKeys.emplace_back(j, this->cardinality(j));
|
||||||
|
}
|
||||||
|
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ******************************************************************************** */
|
||||||
|
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||||
|
size_t value) const {
|
||||||
|
if (nrFrontals() != 1)
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"Single value likelihood can only be invoked on single-variable "
|
||||||
|
"conditional");
|
||||||
|
DiscreteValues values;
|
||||||
|
values.emplace(keys_[0], value);
|
||||||
|
return likelihood(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ******************************************************************************** */
|
||||||
void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
|
void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
|
||||||
// TODO: Abhijit asks: is this really the fastest way? He thinks it is.
|
// TODO: Abhijit asks: is this really the fastest way? He thinks it is.
|
||||||
ADT pFS = choose(*values); // P(F|S=parentsValues)
|
ADT pFS = Choose(*this, *values); // P(F|S=parentsValues)
|
||||||
|
|
||||||
// Initialize
|
// Initialize
|
||||||
DiscreteValues mpe;
|
DiscreteValues mpe;
|
||||||
|
@ -177,7 +222,7 @@ void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
|
||||||
size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const {
|
size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const {
|
||||||
|
|
||||||
// TODO: is this really the fastest way? I think it is.
|
// TODO: is this really the fastest way? I think it is.
|
||||||
ADT pFS = choose(parentsValues); // P(F|S=parentsValues)
|
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
|
||||||
|
|
||||||
// Then, find the max over all remaining
|
// Then, find the max over all remaining
|
||||||
// TODO, only works for one key now, seems horribly slow this way
|
// TODO, only works for one key now, seems horribly slow this way
|
||||||
|
@ -203,7 +248,7 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
|
||||||
static mt19937 rng(2); // random number generator
|
static mt19937 rng(2); // random number generator
|
||||||
|
|
||||||
// Get the correct conditional density
|
// Get the correct conditional density
|
||||||
ADT pFS = choose(parentsValues); // P(F|S=parentsValues)
|
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
|
||||||
|
|
||||||
// TODO(Duy): only works for one key now, seems horribly slow this way
|
// TODO(Duy): only works for one key now, seems horribly slow this way
|
||||||
assert(nrFrontals() == 1);
|
assert(nrFrontals() == 1);
|
||||||
|
|
|
@ -146,13 +146,17 @@ public:
|
||||||
return DecisionTreeFactor::shared_ptr(new DecisionTreeFactor(*this));
|
return DecisionTreeFactor::shared_ptr(new DecisionTreeFactor(*this));
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Restrict to given parent values, returns AlgebraicDecisionDiagram */
|
|
||||||
ADT choose(const DiscreteValues& parentsValues) const;
|
|
||||||
|
|
||||||
/** Restrict to given parent values, returns DecisionTreeFactor */
|
/** Restrict to given parent values, returns DecisionTreeFactor */
|
||||||
DecisionTreeFactor::shared_ptr chooseAsFactor(
|
DecisionTreeFactor::shared_ptr choose(
|
||||||
const DiscreteValues& parentsValues) const;
|
const DiscreteValues& parentsValues) const;
|
||||||
|
|
||||||
|
/** Convert to a likelihood factor by providing value before bar. */
|
||||||
|
DecisionTreeFactor::shared_ptr likelihood(
|
||||||
|
const DiscreteValues& frontalValues) const;
|
||||||
|
|
||||||
|
/** Single variable version of likelihood. */
|
||||||
|
DecisionTreeFactor::shared_ptr likelihood(size_t value) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* solve a conditional
|
* solve a conditional
|
||||||
* @param parentsValues Known values of the parents
|
* @param parentsValues Known values of the parents
|
||||||
|
|
|
@ -76,8 +76,11 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
||||||
string s = "Discrete Conditional: ",
|
string s = "Discrete Conditional: ",
|
||||||
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
||||||
gtsam::DecisionTreeFactor* toFactor() const;
|
gtsam::DecisionTreeFactor* toFactor() const;
|
||||||
gtsam::DecisionTreeFactor* chooseAsFactor(
|
gtsam::DecisionTreeFactor* choose(
|
||||||
const gtsam::DiscreteValues& parentsValues) const;
|
const gtsam::DiscreteValues& parentsValues) const;
|
||||||
|
gtsam::DecisionTreeFactor* likelihood(
|
||||||
|
const gtsam::DiscreteValues& frontalValues) const;
|
||||||
|
gtsam::DecisionTreeFactor* likelihood(size_t value) const;
|
||||||
size_t solve(const gtsam::DiscreteValues& parentsValues) const;
|
size_t solve(const gtsam::DiscreteValues& parentsValues) const;
|
||||||
size_t sample(const gtsam::DiscreteValues& parentsValues) const;
|
size_t sample(const gtsam::DiscreteValues& parentsValues) const;
|
||||||
void solveInPlace(gtsam::DiscreteValues @parentsValues) const;
|
void solveInPlace(gtsam::DiscreteValues @parentsValues) const;
|
||||||
|
|
|
@ -31,24 +31,21 @@ using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( DiscreteConditional, constructors)
|
TEST(DiscreteConditional, constructors) {
|
||||||
{
|
DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering !
|
||||||
DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering !
|
|
||||||
|
DiscreteConditional expected(X | Y = "1/1 2/3 1/4");
|
||||||
|
EXPECT_LONGS_EQUAL(0, *(expected.beginFrontals()));
|
||||||
|
EXPECT_LONGS_EQUAL(2, *(expected.beginParents()));
|
||||||
|
EXPECT(expected.endParents() == expected.end());
|
||||||
|
EXPECT(expected.endFrontals() == expected.beginParents());
|
||||||
|
|
||||||
DiscreteConditional::shared_ptr expected1 = //
|
|
||||||
boost::make_shared<DiscreteConditional>(X | Y = "1/1 2/3 1/4");
|
|
||||||
EXPECT(expected1);
|
|
||||||
EXPECT_LONGS_EQUAL(0, *(expected1->beginFrontals()));
|
|
||||||
EXPECT_LONGS_EQUAL(2, *(expected1->beginParents()));
|
|
||||||
EXPECT(expected1->endParents() == expected1->end());
|
|
||||||
EXPECT(expected1->endFrontals() == expected1->beginParents());
|
|
||||||
|
|
||||||
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
|
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
|
||||||
DiscreteConditional actual1(1, f1);
|
DiscreteConditional actual1(1, f1);
|
||||||
EXPECT(assert_equal(*expected1, actual1, 1e-9));
|
EXPECT(assert_equal(expected, actual1, 1e-9));
|
||||||
|
|
||||||
DecisionTreeFactor f2(X & Y & Z,
|
DecisionTreeFactor f2(
|
||||||
"0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
||||||
DiscreteConditional actual2(1, f2);
|
DiscreteConditional actual2(1, f2);
|
||||||
EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9));
|
EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9));
|
||||||
}
|
}
|
||||||
|
@ -108,6 +105,20 @@ TEST(DiscreteConditional, Combine) {
|
||||||
EXPECT(assert_equal(expected, *actual, 1e-5));
|
EXPECT(assert_equal(expected, *actual, 1e-5));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DiscreteConditional, likelihood) {
|
||||||
|
DiscreteKey X(0, 2), Y(1, 3);
|
||||||
|
DiscreteConditional conditional(X | Y = "2/8 4/6 5/5");
|
||||||
|
|
||||||
|
auto actual0 = conditional.likelihood(0);
|
||||||
|
DecisionTreeFactor expected0(Y, "0.2 0.4 0.5");
|
||||||
|
EXPECT(assert_equal(expected0, *actual0, 1e-9));
|
||||||
|
|
||||||
|
auto actual1 = conditional.likelihood(1);
|
||||||
|
DecisionTreeFactor expected1(Y, "0.8 0.6 0.5");
|
||||||
|
EXPECT(assert_equal(expected1, *actual1, 1e-9));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// Check markdown representation looks as expected, no parents.
|
// Check markdown representation looks as expected, no parents.
|
||||||
TEST(DiscreteConditional, markdown_prior) {
|
TEST(DiscreteConditional, markdown_prior) {
|
||||||
|
|
|
@ -13,12 +13,26 @@ Author: Varun Agrawal
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from gtsam import DiscreteConditional, DiscreteKeys
|
from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteKeys
|
||||||
from gtsam.utils.test_case import GtsamTestCase
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
|
|
||||||
class TestDiscreteConditional(GtsamTestCase):
|
class TestDiscreteConditional(GtsamTestCase):
|
||||||
"""Tests for Discrete Conditionals."""
|
"""Tests for Discrete Conditionals."""
|
||||||
|
|
||||||
|
def test_likelihood(self):
|
||||||
|
X = (0, 2)
|
||||||
|
Y = (1, 3)
|
||||||
|
conditional = DiscreteConditional(X, "2/8 4/6 5/5", Y)
|
||||||
|
|
||||||
|
actual0 = conditional.likelihood(0)
|
||||||
|
expected0 = DecisionTreeFactor(Y, "0.2 0.4 0.5")
|
||||||
|
self.gtsamAssertEquals(actual0, expected0, 1e-9)
|
||||||
|
|
||||||
|
actual1 = conditional.likelihood(1)
|
||||||
|
expected1 = DecisionTreeFactor(Y, "0.8 0.6 0.5")
|
||||||
|
self.gtsamAssertEquals(actual1, expected1, 1e-9)
|
||||||
|
|
||||||
def test_markdown(self):
|
def test_markdown(self):
|
||||||
"""Test whether the _repr_markdown_ method."""
|
"""Test whether the _repr_markdown_ method."""
|
||||||
|
|
||||||
|
@ -32,7 +46,7 @@ class TestDiscreteConditional(GtsamTestCase):
|
||||||
conditional = DiscreteConditional(A, parents,
|
conditional = DiscreteConditional(A, parents,
|
||||||
"0/1 1/3 1/1 3/1 0/1 1/0")
|
"0/1 1/3 1/1 3/1 0/1 1/0")
|
||||||
expected = \
|
expected = \
|
||||||
" $P(A|B,C)$:\n" \
|
" *P(A|B,C)*:\n\n" \
|
||||||
"|B|C|0|1|\n" \
|
"|B|C|0|1|\n" \
|
||||||
"|:-:|:-:|:-:|:-:|\n" \
|
"|:-:|:-:|:-:|:-:|\n" \
|
||||||
"|0|0|0|1|\n" \
|
"|0|0|0|1|\n" \
|
||||||
|
|
Loading…
Reference in New Issue