likelihood

release/4.3a0
Frank Dellaert 2021-12-27 13:01:29 -05:00
parent dbe5c0fa81
commit 457d074858
5 changed files with 117 additions and 40 deletions

View File

@ -97,45 +97,90 @@ bool DiscreteConditional::equals(const DiscreteFactor& other,
}
/* ******************************************************************************** */
Potentials::ADT DiscreteConditional::choose(
const DiscreteValues& parentsValues) const {
static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
const DiscreteValues& parentsValues) {
// Get the big decision tree with all the levels, and then go down the
// branches based on the value of the parent variables.
ADT pFS(*this);
DiscreteConditional::ADT adt(conditional);
size_t value;
for (Key j : parents()) {
for (Key j : conditional.parents()) {
try {
value = parentsValues.at(j);
pFS = pFS.choose(j, value); // ADT keeps getting smaller.
adt = adt.choose(j, value); // ADT keeps getting smaller.
} catch (exception&) {
cout << "Key: " << j << " Value: " << value << endl;
parentsValues.print("parentsValues: ");
throw runtime_error("DiscreteConditional::choose: parent value missing");
};
}
return pFS;
return adt;
}
/* ******************************************************************************** */
DecisionTreeFactor::shared_ptr DiscreteConditional::chooseAsFactor(
DecisionTreeFactor::shared_ptr DiscreteConditional::choose(
const DiscreteValues& parentsValues) const {
ADT pFS = choose(parentsValues);
// Get the big decision tree with all the levels, and then go down the
// branches based on the value of the parent variables.
ADT adt(*this);
size_t value;
for (Key j : parents()) {
try {
value = parentsValues.at(j);
adt = adt.choose(j, value); // ADT keeps getting smaller.
} catch (exception&) {
parentsValues.print("parentsValues: ");
throw runtime_error("DiscreteConditional::choose: parent value missing");
};
}
// Convert ADT to factor.
if (nrFrontals() != 1) {
throw std::runtime_error("Expected only one frontal variable in choose.");
DiscreteKeys discreteKeys;
for (Key j : frontals()) {
discreteKeys.emplace_back(j, this->cardinality(j));
}
DiscreteKeys keys;
const Key frontalKey = keys_[0];
size_t frontalCardinality = this->cardinality(frontalKey);
keys.push_back(DiscreteKey(frontalKey, frontalCardinality));
return boost::make_shared<DecisionTreeFactor>(keys, pFS);
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
}
/* ******************************************************************************** */
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
const DiscreteValues& frontalValues) const {
// Get the big decision tree with all the levels, and then go down the
// branches based on the value of the frontal variables.
ADT adt(*this);
size_t value;
for (Key j : frontals()) {
try {
value = frontalValues.at(j);
adt = adt.choose(j, value); // ADT keeps getting smaller.
} catch (exception&) {
frontalValues.print("frontalValues: ");
throw runtime_error("DiscreteConditional::choose: frontal value missing");
};
}
// Convert ADT to factor.
DiscreteKeys discreteKeys;
for (Key j : parents()) {
discreteKeys.emplace_back(j, this->cardinality(j));
}
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
}
/* ******************************************************************************** */
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
size_t value) const {
if (nrFrontals() != 1)
throw std::invalid_argument(
"Single value likelihood can only be invoked on single-variable "
"conditional");
DiscreteValues values;
values.emplace(keys_[0], value);
return likelihood(values);
}
/* ******************************************************************************** */
void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
// TODO: Abhijit asks: is this really the fastest way? He thinks it is.
ADT pFS = choose(*values); // P(F|S=parentsValues)
ADT pFS = Choose(*this, *values); // P(F|S=parentsValues)
// Initialize
DiscreteValues mpe;
@ -177,7 +222,7 @@ void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const {
// TODO: is this really the fastest way? I think it is.
ADT pFS = choose(parentsValues); // P(F|S=parentsValues)
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
// Then, find the max over all remaining
// TODO, only works for one key now, seems horribly slow this way
@ -203,7 +248,7 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
static mt19937 rng(2); // random number generator
// Get the correct conditional density
ADT pFS = choose(parentsValues); // P(F|S=parentsValues)
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
// TODO(Duy): only works for one key now, seems horribly slow this way
assert(nrFrontals() == 1);

View File

@ -146,13 +146,17 @@ public:
return DecisionTreeFactor::shared_ptr(new DecisionTreeFactor(*this));
}
/** Restrict to given parent values, returns AlgebraicDecisionDiagram */
ADT choose(const DiscreteValues& parentsValues) const;
/** Restrict to given parent values, returns DecisionTreeFactor */
DecisionTreeFactor::shared_ptr chooseAsFactor(
DecisionTreeFactor::shared_ptr choose(
const DiscreteValues& parentsValues) const;
/** Convert to a likelihood factor by providing value before bar. */
DecisionTreeFactor::shared_ptr likelihood(
const DiscreteValues& frontalValues) const;
/** Single variable version of likelihood. */
DecisionTreeFactor::shared_ptr likelihood(size_t value) const;
/**
* solve a conditional
* @param parentsValues Known values of the parents

View File

@ -76,8 +76,11 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
string s = "Discrete Conditional: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
gtsam::DecisionTreeFactor* toFactor() const;
gtsam::DecisionTreeFactor* chooseAsFactor(
gtsam::DecisionTreeFactor* choose(
const gtsam::DiscreteValues& parentsValues) const;
gtsam::DecisionTreeFactor* likelihood(
const gtsam::DiscreteValues& frontalValues) const;
gtsam::DecisionTreeFactor* likelihood(size_t value) const;
size_t solve(const gtsam::DiscreteValues& parentsValues) const;
size_t sample(const gtsam::DiscreteValues& parentsValues) const;
void solveInPlace(gtsam::DiscreteValues @parentsValues) const;

View File

@ -31,24 +31,21 @@ using namespace std;
using namespace gtsam;
/* ************************************************************************* */
TEST( DiscreteConditional, constructors)
{
DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering !
TEST(DiscreteConditional, constructors) {
DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering !
DiscreteConditional expected(X | Y = "1/1 2/3 1/4");
EXPECT_LONGS_EQUAL(0, *(expected.beginFrontals()));
EXPECT_LONGS_EQUAL(2, *(expected.beginParents()));
EXPECT(expected.endParents() == expected.end());
EXPECT(expected.endFrontals() == expected.beginParents());
DiscreteConditional::shared_ptr expected1 = //
boost::make_shared<DiscreteConditional>(X | Y = "1/1 2/3 1/4");
EXPECT(expected1);
EXPECT_LONGS_EQUAL(0, *(expected1->beginFrontals()));
EXPECT_LONGS_EQUAL(2, *(expected1->beginParents()));
EXPECT(expected1->endParents() == expected1->end());
EXPECT(expected1->endFrontals() == expected1->beginParents());
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
DiscreteConditional actual1(1, f1);
EXPECT(assert_equal(*expected1, actual1, 1e-9));
EXPECT(assert_equal(expected, actual1, 1e-9));
DecisionTreeFactor f2(X & Y & Z,
"0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
DecisionTreeFactor f2(
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
DiscreteConditional actual2(1, f2);
EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9));
}
@ -108,6 +105,20 @@ TEST(DiscreteConditional, Combine) {
EXPECT(assert_equal(expected, *actual, 1e-5));
}
/* ************************************************************************* */
TEST(DiscreteConditional, likelihood) {
DiscreteKey X(0, 2), Y(1, 3);
DiscreteConditional conditional(X | Y = "2/8 4/6 5/5");
auto actual0 = conditional.likelihood(0);
DecisionTreeFactor expected0(Y, "0.2 0.4 0.5");
EXPECT(assert_equal(expected0, *actual0, 1e-9));
auto actual1 = conditional.likelihood(1);
DecisionTreeFactor expected1(Y, "0.8 0.6 0.5");
EXPECT(assert_equal(expected1, *actual1, 1e-9));
}
/* ************************************************************************* */
// Check markdown representation looks as expected, no parents.
TEST(DiscreteConditional, markdown_prior) {

View File

@ -13,12 +13,26 @@ Author: Varun Agrawal
import unittest
from gtsam import DiscreteConditional, DiscreteKeys
from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteKeys
from gtsam.utils.test_case import GtsamTestCase
class TestDiscreteConditional(GtsamTestCase):
"""Tests for Discrete Conditionals."""
def test_likelihood(self):
X = (0, 2)
Y = (1, 3)
conditional = DiscreteConditional(X, "2/8 4/6 5/5", Y)
actual0 = conditional.likelihood(0)
expected0 = DecisionTreeFactor(Y, "0.2 0.4 0.5")
self.gtsamAssertEquals(actual0, expected0, 1e-9)
actual1 = conditional.likelihood(1)
expected1 = DecisionTreeFactor(Y, "0.8 0.6 0.5")
self.gtsamAssertEquals(actual1, expected1, 1e-9)
def test_markdown(self):
"""Test whether the _repr_markdown_ method."""
@ -32,7 +46,7 @@ class TestDiscreteConditional(GtsamTestCase):
conditional = DiscreteConditional(A, parents,
"0/1 1/3 1/1 3/1 0/1 1/0")
expected = \
" $P(A|B,C)$:\n" \
" *P(A|B,C)*:\n\n" \
"|B|C|0|1|\n" \
"|:-:|:-:|:-:|:-:|\n" \
"|0|0|0|1|\n" \