likelihood
parent
dbe5c0fa81
commit
457d074858
|
@ -97,45 +97,90 @@ bool DiscreteConditional::equals(const DiscreteFactor& other,
|
|||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
Potentials::ADT DiscreteConditional::choose(
|
||||
const DiscreteValues& parentsValues) const {
|
||||
static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
|
||||
const DiscreteValues& parentsValues) {
|
||||
// Get the big decision tree with all the levels, and then go down the
|
||||
// branches based on the value of the parent variables.
|
||||
ADT pFS(*this);
|
||||
DiscreteConditional::ADT adt(conditional);
|
||||
size_t value;
|
||||
for (Key j : parents()) {
|
||||
for (Key j : conditional.parents()) {
|
||||
try {
|
||||
value = parentsValues.at(j);
|
||||
pFS = pFS.choose(j, value); // ADT keeps getting smaller.
|
||||
adt = adt.choose(j, value); // ADT keeps getting smaller.
|
||||
} catch (exception&) {
|
||||
cout << "Key: " << j << " Value: " << value << endl;
|
||||
parentsValues.print("parentsValues: ");
|
||||
throw runtime_error("DiscreteConditional::choose: parent value missing");
|
||||
};
|
||||
}
|
||||
return pFS;
|
||||
return adt;
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
DecisionTreeFactor::shared_ptr DiscreteConditional::chooseAsFactor(
|
||||
DecisionTreeFactor::shared_ptr DiscreteConditional::choose(
|
||||
const DiscreteValues& parentsValues) const {
|
||||
ADT pFS = choose(parentsValues);
|
||||
// Get the big decision tree with all the levels, and then go down the
|
||||
// branches based on the value of the parent variables.
|
||||
ADT adt(*this);
|
||||
size_t value;
|
||||
for (Key j : parents()) {
|
||||
try {
|
||||
value = parentsValues.at(j);
|
||||
adt = adt.choose(j, value); // ADT keeps getting smaller.
|
||||
} catch (exception&) {
|
||||
parentsValues.print("parentsValues: ");
|
||||
throw runtime_error("DiscreteConditional::choose: parent value missing");
|
||||
};
|
||||
}
|
||||
|
||||
// Convert ADT to factor.
|
||||
if (nrFrontals() != 1) {
|
||||
throw std::runtime_error("Expected only one frontal variable in choose.");
|
||||
DiscreteKeys discreteKeys;
|
||||
for (Key j : frontals()) {
|
||||
discreteKeys.emplace_back(j, this->cardinality(j));
|
||||
}
|
||||
DiscreteKeys keys;
|
||||
const Key frontalKey = keys_[0];
|
||||
size_t frontalCardinality = this->cardinality(frontalKey);
|
||||
keys.push_back(DiscreteKey(frontalKey, frontalCardinality));
|
||||
return boost::make_shared<DecisionTreeFactor>(keys, pFS);
|
||||
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||
const DiscreteValues& frontalValues) const {
|
||||
// Get the big decision tree with all the levels, and then go down the
|
||||
// branches based on the value of the frontal variables.
|
||||
ADT adt(*this);
|
||||
size_t value;
|
||||
for (Key j : frontals()) {
|
||||
try {
|
||||
value = frontalValues.at(j);
|
||||
adt = adt.choose(j, value); // ADT keeps getting smaller.
|
||||
} catch (exception&) {
|
||||
frontalValues.print("frontalValues: ");
|
||||
throw runtime_error("DiscreteConditional::choose: frontal value missing");
|
||||
};
|
||||
}
|
||||
|
||||
// Convert ADT to factor.
|
||||
DiscreteKeys discreteKeys;
|
||||
for (Key j : parents()) {
|
||||
discreteKeys.emplace_back(j, this->cardinality(j));
|
||||
}
|
||||
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||
size_t value) const {
|
||||
if (nrFrontals() != 1)
|
||||
throw std::invalid_argument(
|
||||
"Single value likelihood can only be invoked on single-variable "
|
||||
"conditional");
|
||||
DiscreteValues values;
|
||||
values.emplace(keys_[0], value);
|
||||
return likelihood(values);
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
|
||||
// TODO: Abhijit asks: is this really the fastest way? He thinks it is.
|
||||
ADT pFS = choose(*values); // P(F|S=parentsValues)
|
||||
ADT pFS = Choose(*this, *values); // P(F|S=parentsValues)
|
||||
|
||||
// Initialize
|
||||
DiscreteValues mpe;
|
||||
|
@ -177,7 +222,7 @@ void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
|
|||
size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const {
|
||||
|
||||
// TODO: is this really the fastest way? I think it is.
|
||||
ADT pFS = choose(parentsValues); // P(F|S=parentsValues)
|
||||
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
|
||||
|
||||
// Then, find the max over all remaining
|
||||
// TODO, only works for one key now, seems horribly slow this way
|
||||
|
@ -203,7 +248,7 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
|
|||
static mt19937 rng(2); // random number generator
|
||||
|
||||
// Get the correct conditional density
|
||||
ADT pFS = choose(parentsValues); // P(F|S=parentsValues)
|
||||
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
|
||||
|
||||
// TODO(Duy): only works for one key now, seems horribly slow this way
|
||||
assert(nrFrontals() == 1);
|
||||
|
|
|
@ -146,13 +146,17 @@ public:
|
|||
return DecisionTreeFactor::shared_ptr(new DecisionTreeFactor(*this));
|
||||
}
|
||||
|
||||
/** Restrict to given parent values, returns AlgebraicDecisionDiagram */
|
||||
ADT choose(const DiscreteValues& parentsValues) const;
|
||||
|
||||
/** Restrict to given parent values, returns DecisionTreeFactor */
|
||||
DecisionTreeFactor::shared_ptr chooseAsFactor(
|
||||
DecisionTreeFactor::shared_ptr choose(
|
||||
const DiscreteValues& parentsValues) const;
|
||||
|
||||
/** Convert to a likelihood factor by providing value before bar. */
|
||||
DecisionTreeFactor::shared_ptr likelihood(
|
||||
const DiscreteValues& frontalValues) const;
|
||||
|
||||
/** Single variable version of likelihood. */
|
||||
DecisionTreeFactor::shared_ptr likelihood(size_t value) const;
|
||||
|
||||
/**
|
||||
* solve a conditional
|
||||
* @param parentsValues Known values of the parents
|
||||
|
|
|
@ -76,8 +76,11 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
|||
string s = "Discrete Conditional: ",
|
||||
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
||||
gtsam::DecisionTreeFactor* toFactor() const;
|
||||
gtsam::DecisionTreeFactor* chooseAsFactor(
|
||||
gtsam::DecisionTreeFactor* choose(
|
||||
const gtsam::DiscreteValues& parentsValues) const;
|
||||
gtsam::DecisionTreeFactor* likelihood(
|
||||
const gtsam::DiscreteValues& frontalValues) const;
|
||||
gtsam::DecisionTreeFactor* likelihood(size_t value) const;
|
||||
size_t solve(const gtsam::DiscreteValues& parentsValues) const;
|
||||
size_t sample(const gtsam::DiscreteValues& parentsValues) const;
|
||||
void solveInPlace(gtsam::DiscreteValues @parentsValues) const;
|
||||
|
|
|
@ -31,24 +31,21 @@ using namespace std;
|
|||
using namespace gtsam;
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST( DiscreteConditional, constructors)
|
||||
{
|
||||
DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering !
|
||||
TEST(DiscreteConditional, constructors) {
|
||||
DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering !
|
||||
|
||||
DiscreteConditional expected(X | Y = "1/1 2/3 1/4");
|
||||
EXPECT_LONGS_EQUAL(0, *(expected.beginFrontals()));
|
||||
EXPECT_LONGS_EQUAL(2, *(expected.beginParents()));
|
||||
EXPECT(expected.endParents() == expected.end());
|
||||
EXPECT(expected.endFrontals() == expected.beginParents());
|
||||
|
||||
DiscreteConditional::shared_ptr expected1 = //
|
||||
boost::make_shared<DiscreteConditional>(X | Y = "1/1 2/3 1/4");
|
||||
EXPECT(expected1);
|
||||
EXPECT_LONGS_EQUAL(0, *(expected1->beginFrontals()));
|
||||
EXPECT_LONGS_EQUAL(2, *(expected1->beginParents()));
|
||||
EXPECT(expected1->endParents() == expected1->end());
|
||||
EXPECT(expected1->endFrontals() == expected1->beginParents());
|
||||
|
||||
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
|
||||
DiscreteConditional actual1(1, f1);
|
||||
EXPECT(assert_equal(*expected1, actual1, 1e-9));
|
||||
EXPECT(assert_equal(expected, actual1, 1e-9));
|
||||
|
||||
DecisionTreeFactor f2(X & Y & Z,
|
||||
"0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
||||
DecisionTreeFactor f2(
|
||||
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
||||
DiscreteConditional actual2(1, f2);
|
||||
EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9));
|
||||
}
|
||||
|
@ -108,6 +105,20 @@ TEST(DiscreteConditional, Combine) {
|
|||
EXPECT(assert_equal(expected, *actual, 1e-5));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteConditional, likelihood) {
|
||||
DiscreteKey X(0, 2), Y(1, 3);
|
||||
DiscreteConditional conditional(X | Y = "2/8 4/6 5/5");
|
||||
|
||||
auto actual0 = conditional.likelihood(0);
|
||||
DecisionTreeFactor expected0(Y, "0.2 0.4 0.5");
|
||||
EXPECT(assert_equal(expected0, *actual0, 1e-9));
|
||||
|
||||
auto actual1 = conditional.likelihood(1);
|
||||
DecisionTreeFactor expected1(Y, "0.8 0.6 0.5");
|
||||
EXPECT(assert_equal(expected1, *actual1, 1e-9));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check markdown representation looks as expected, no parents.
|
||||
TEST(DiscreteConditional, markdown_prior) {
|
||||
|
|
|
@ -13,12 +13,26 @@ Author: Varun Agrawal
|
|||
|
||||
import unittest
|
||||
|
||||
from gtsam import DiscreteConditional, DiscreteKeys
|
||||
from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteKeys
|
||||
from gtsam.utils.test_case import GtsamTestCase
|
||||
|
||||
|
||||
class TestDiscreteConditional(GtsamTestCase):
|
||||
"""Tests for Discrete Conditionals."""
|
||||
|
||||
def test_likelihood(self):
|
||||
X = (0, 2)
|
||||
Y = (1, 3)
|
||||
conditional = DiscreteConditional(X, "2/8 4/6 5/5", Y)
|
||||
|
||||
actual0 = conditional.likelihood(0)
|
||||
expected0 = DecisionTreeFactor(Y, "0.2 0.4 0.5")
|
||||
self.gtsamAssertEquals(actual0, expected0, 1e-9)
|
||||
|
||||
actual1 = conditional.likelihood(1)
|
||||
expected1 = DecisionTreeFactor(Y, "0.8 0.6 0.5")
|
||||
self.gtsamAssertEquals(actual1, expected1, 1e-9)
|
||||
|
||||
def test_markdown(self):
|
||||
"""Test whether the _repr_markdown_ method."""
|
||||
|
||||
|
@ -32,7 +46,7 @@ class TestDiscreteConditional(GtsamTestCase):
|
|||
conditional = DiscreteConditional(A, parents,
|
||||
"0/1 1/3 1/1 3/1 0/1 1/0")
|
||||
expected = \
|
||||
" $P(A|B,C)$:\n" \
|
||||
" *P(A|B,C)*:\n\n" \
|
||||
"|B|C|0|1|\n" \
|
||||
"|:-:|:-:|:-:|:-:|\n" \
|
||||
"|0|0|0|1|\n" \
|
||||
|
|
Loading…
Reference in New Issue