likelihood

release/4.3a0
Frank Dellaert 2021-12-27 13:01:29 -05:00
parent dbe5c0fa81
commit 457d074858
5 changed files with 117 additions and 40 deletions

View File

@ -97,45 +97,90 @@ bool DiscreteConditional::equals(const DiscreteFactor& other,
} }
/* ******************************************************************************** */ /* ******************************************************************************** */
Potentials::ADT DiscreteConditional::choose( static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
const DiscreteValues& parentsValues) const { const DiscreteValues& parentsValues) {
// Get the big decision tree with all the levels, and then go down the // Get the big decision tree with all the levels, and then go down the
// branches based on the value of the parent variables. // branches based on the value of the parent variables.
ADT pFS(*this); DiscreteConditional::ADT adt(conditional);
size_t value; size_t value;
for (Key j : parents()) { for (Key j : conditional.parents()) {
try { try {
value = parentsValues.at(j); value = parentsValues.at(j);
pFS = pFS.choose(j, value); // ADT keeps getting smaller. adt = adt.choose(j, value); // ADT keeps getting smaller.
} catch (exception&) { } catch (exception&) {
cout << "Key: " << j << " Value: " << value << endl;
parentsValues.print("parentsValues: "); parentsValues.print("parentsValues: ");
throw runtime_error("DiscreteConditional::choose: parent value missing"); throw runtime_error("DiscreteConditional::choose: parent value missing");
}; };
} }
return pFS; return adt;
} }
/* ******************************************************************************** */ /* ******************************************************************************** */
DecisionTreeFactor::shared_ptr DiscreteConditional::chooseAsFactor( DecisionTreeFactor::shared_ptr DiscreteConditional::choose(
const DiscreteValues& parentsValues) const { const DiscreteValues& parentsValues) const {
ADT pFS = choose(parentsValues); // Get the big decision tree with all the levels, and then go down the
// branches based on the value of the parent variables.
ADT adt(*this);
size_t value;
for (Key j : parents()) {
try {
value = parentsValues.at(j);
adt = adt.choose(j, value); // ADT keeps getting smaller.
} catch (exception&) {
parentsValues.print("parentsValues: ");
throw runtime_error("DiscreteConditional::choose: parent value missing");
};
}
// Convert ADT to factor. // Convert ADT to factor.
if (nrFrontals() != 1) { DiscreteKeys discreteKeys;
throw std::runtime_error("Expected only one frontal variable in choose."); for (Key j : frontals()) {
discreteKeys.emplace_back(j, this->cardinality(j));
} }
DiscreteKeys keys; return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
const Key frontalKey = keys_[0]; }
size_t frontalCardinality = this->cardinality(frontalKey);
keys.push_back(DiscreteKey(frontalKey, frontalCardinality)); /* ******************************************************************************** */
return boost::make_shared<DecisionTreeFactor>(keys, pFS); DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
const DiscreteValues& frontalValues) const {
// Get the big decision tree with all the levels, and then go down the
// branches based on the value of the frontal variables.
ADT adt(*this);
size_t value;
for (Key j : frontals()) {
try {
value = frontalValues.at(j);
adt = adt.choose(j, value); // ADT keeps getting smaller.
} catch (exception&) {
frontalValues.print("frontalValues: ");
throw runtime_error("DiscreteConditional::choose: frontal value missing");
};
}
// Convert ADT to factor.
DiscreteKeys discreteKeys;
for (Key j : parents()) {
discreteKeys.emplace_back(j, this->cardinality(j));
}
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
}
/* ******************************************************************************** */
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
size_t value) const {
if (nrFrontals() != 1)
throw std::invalid_argument(
"Single value likelihood can only be invoked on single-variable "
"conditional");
DiscreteValues values;
values.emplace(keys_[0], value);
return likelihood(values);
} }
/* ******************************************************************************** */ /* ******************************************************************************** */
void DiscreteConditional::solveInPlace(DiscreteValues* values) const { void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
// TODO: Abhijit asks: is this really the fastest way? He thinks it is. // TODO: Abhijit asks: is this really the fastest way? He thinks it is.
ADT pFS = choose(*values); // P(F|S=parentsValues) ADT pFS = Choose(*this, *values); // P(F|S=parentsValues)
// Initialize // Initialize
DiscreteValues mpe; DiscreteValues mpe;
@ -177,7 +222,7 @@ void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const { size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const {
// TODO: is this really the fastest way? I think it is. // TODO: is this really the fastest way? I think it is.
ADT pFS = choose(parentsValues); // P(F|S=parentsValues) ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
// Then, find the max over all remaining // Then, find the max over all remaining
// TODO, only works for one key now, seems horribly slow this way // TODO, only works for one key now, seems horribly slow this way
@ -203,7 +248,7 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
static mt19937 rng(2); // random number generator static mt19937 rng(2); // random number generator
// Get the correct conditional density // Get the correct conditional density
ADT pFS = choose(parentsValues); // P(F|S=parentsValues) ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
// TODO(Duy): only works for one key now, seems horribly slow this way // TODO(Duy): only works for one key now, seems horribly slow this way
assert(nrFrontals() == 1); assert(nrFrontals() == 1);

View File

@ -146,13 +146,17 @@ public:
return DecisionTreeFactor::shared_ptr(new DecisionTreeFactor(*this)); return DecisionTreeFactor::shared_ptr(new DecisionTreeFactor(*this));
} }
/** Restrict to given parent values, returns AlgebraicDecisionDiagram */
ADT choose(const DiscreteValues& parentsValues) const;
/** Restrict to given parent values, returns DecisionTreeFactor */ /** Restrict to given parent values, returns DecisionTreeFactor */
DecisionTreeFactor::shared_ptr chooseAsFactor( DecisionTreeFactor::shared_ptr choose(
const DiscreteValues& parentsValues) const; const DiscreteValues& parentsValues) const;
/** Convert to a likelihood factor by providing value before bar. */
DecisionTreeFactor::shared_ptr likelihood(
const DiscreteValues& frontalValues) const;
/** Single variable version of likelihood. */
DecisionTreeFactor::shared_ptr likelihood(size_t value) const;
/** /**
* solve a conditional * solve a conditional
* @param parentsValues Known values of the parents * @param parentsValues Known values of the parents

View File

@ -76,8 +76,11 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
string s = "Discrete Conditional: ", string s = "Discrete Conditional: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
gtsam::DecisionTreeFactor* toFactor() const; gtsam::DecisionTreeFactor* toFactor() const;
gtsam::DecisionTreeFactor* chooseAsFactor( gtsam::DecisionTreeFactor* choose(
const gtsam::DiscreteValues& parentsValues) const; const gtsam::DiscreteValues& parentsValues) const;
gtsam::DecisionTreeFactor* likelihood(
const gtsam::DiscreteValues& frontalValues) const;
gtsam::DecisionTreeFactor* likelihood(size_t value) const;
size_t solve(const gtsam::DiscreteValues& parentsValues) const; size_t solve(const gtsam::DiscreteValues& parentsValues) const;
size_t sample(const gtsam::DiscreteValues& parentsValues) const; size_t sample(const gtsam::DiscreteValues& parentsValues) const;
void solveInPlace(gtsam::DiscreteValues @parentsValues) const; void solveInPlace(gtsam::DiscreteValues @parentsValues) const;

View File

@ -31,24 +31,21 @@ using namespace std;
using namespace gtsam; using namespace gtsam;
/* ************************************************************************* */ /* ************************************************************************* */
TEST( DiscreteConditional, constructors) TEST(DiscreteConditional, constructors) {
{ DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering !
DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering !
DiscreteConditional expected(X | Y = "1/1 2/3 1/4");
EXPECT_LONGS_EQUAL(0, *(expected.beginFrontals()));
EXPECT_LONGS_EQUAL(2, *(expected.beginParents()));
EXPECT(expected.endParents() == expected.end());
EXPECT(expected.endFrontals() == expected.beginParents());
DiscreteConditional::shared_ptr expected1 = //
boost::make_shared<DiscreteConditional>(X | Y = "1/1 2/3 1/4");
EXPECT(expected1);
EXPECT_LONGS_EQUAL(0, *(expected1->beginFrontals()));
EXPECT_LONGS_EQUAL(2, *(expected1->beginParents()));
EXPECT(expected1->endParents() == expected1->end());
EXPECT(expected1->endFrontals() == expected1->beginParents());
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8"); DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
DiscreteConditional actual1(1, f1); DiscreteConditional actual1(1, f1);
EXPECT(assert_equal(*expected1, actual1, 1e-9)); EXPECT(assert_equal(expected, actual1, 1e-9));
DecisionTreeFactor f2(X & Y & Z, DecisionTreeFactor f2(
"0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
DiscreteConditional actual2(1, f2); DiscreteConditional actual2(1, f2);
EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9)); EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9));
} }
@ -108,6 +105,20 @@ TEST(DiscreteConditional, Combine) {
EXPECT(assert_equal(expected, *actual, 1e-5)); EXPECT(assert_equal(expected, *actual, 1e-5));
} }
/* ************************************************************************* */
TEST(DiscreteConditional, likelihood) {
DiscreteKey X(0, 2), Y(1, 3);
DiscreteConditional conditional(X | Y = "2/8 4/6 5/5");
auto actual0 = conditional.likelihood(0);
DecisionTreeFactor expected0(Y, "0.2 0.4 0.5");
EXPECT(assert_equal(expected0, *actual0, 1e-9));
auto actual1 = conditional.likelihood(1);
DecisionTreeFactor expected1(Y, "0.8 0.6 0.5");
EXPECT(assert_equal(expected1, *actual1, 1e-9));
}
/* ************************************************************************* */ /* ************************************************************************* */
// Check markdown representation looks as expected, no parents. // Check markdown representation looks as expected, no parents.
TEST(DiscreteConditional, markdown_prior) { TEST(DiscreteConditional, markdown_prior) {

View File

@ -13,12 +13,26 @@ Author: Varun Agrawal
import unittest import unittest
from gtsam import DiscreteConditional, DiscreteKeys from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteKeys
from gtsam.utils.test_case import GtsamTestCase from gtsam.utils.test_case import GtsamTestCase
class TestDiscreteConditional(GtsamTestCase): class TestDiscreteConditional(GtsamTestCase):
"""Tests for Discrete Conditionals.""" """Tests for Discrete Conditionals."""
def test_likelihood(self):
X = (0, 2)
Y = (1, 3)
conditional = DiscreteConditional(X, "2/8 4/6 5/5", Y)
actual0 = conditional.likelihood(0)
expected0 = DecisionTreeFactor(Y, "0.2 0.4 0.5")
self.gtsamAssertEquals(actual0, expected0, 1e-9)
actual1 = conditional.likelihood(1)
expected1 = DecisionTreeFactor(Y, "0.8 0.6 0.5")
self.gtsamAssertEquals(actual1, expected1, 1e-9)
def test_markdown(self): def test_markdown(self):
"""Test whether the _repr_markdown_ method.""" """Test whether the _repr_markdown_ method."""
@ -32,7 +46,7 @@ class TestDiscreteConditional(GtsamTestCase):
conditional = DiscreteConditional(A, parents, conditional = DiscreteConditional(A, parents,
"0/1 1/3 1/1 3/1 0/1 1/0") "0/1 1/3 1/1 3/1 0/1 1/0")
expected = \ expected = \
" $P(A|B,C)$:\n" \ " *P(A|B,C)*:\n\n" \
"|B|C|0|1|\n" \ "|B|C|0|1|\n" \
"|:-:|:-:|:-:|:-:|\n" \ "|:-:|:-:|:-:|:-:|\n" \
"|0|0|0|1|\n" \ "|0|0|0|1|\n" \