diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 2d3a5ddf8..328af1ca3 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -97,45 +97,90 @@ bool DiscreteConditional::equals(const DiscreteFactor& other, } /* ******************************************************************************** */ -Potentials::ADT DiscreteConditional::choose( - const DiscreteValues& parentsValues) const { +static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional, + const DiscreteValues& parentsValues) { // Get the big decision tree with all the levels, and then go down the // branches based on the value of the parent variables. - ADT pFS(*this); + DiscreteConditional::ADT adt(conditional); size_t value; - for (Key j : parents()) { + for (Key j : conditional.parents()) { try { value = parentsValues.at(j); - pFS = pFS.choose(j, value); // ADT keeps getting smaller. + adt = adt.choose(j, value); // ADT keeps getting smaller. } catch (exception&) { - cout << "Key: " << j << " Value: " << value << endl; parentsValues.print("parentsValues: "); throw runtime_error("DiscreteConditional::choose: parent value missing"); }; } - return pFS; + return adt; } /* ******************************************************************************** */ -DecisionTreeFactor::shared_ptr DiscreteConditional::chooseAsFactor( +DecisionTreeFactor::shared_ptr DiscreteConditional::choose( const DiscreteValues& parentsValues) const { - ADT pFS = choose(parentsValues); + // Get the big decision tree with all the levels, and then go down the + // branches based on the value of the parent variables. + ADT adt(*this); + size_t value; + for (Key j : parents()) { + try { + value = parentsValues.at(j); + adt = adt.choose(j, value); // ADT keeps getting smaller. + } catch (exception&) { + parentsValues.print("parentsValues: "); + throw runtime_error("DiscreteConditional::choose: parent value missing"); + }; + } // Convert ADT to factor. - if (nrFrontals() != 1) { - throw std::runtime_error("Expected only one frontal variable in choose."); + DiscreteKeys discreteKeys; + for (Key j : frontals()) { + discreteKeys.emplace_back(j, this->cardinality(j)); } - DiscreteKeys keys; - const Key frontalKey = keys_[0]; - size_t frontalCardinality = this->cardinality(frontalKey); - keys.push_back(DiscreteKey(frontalKey, frontalCardinality)); - return boost::make_shared(keys, pFS); + return boost::make_shared(discreteKeys, adt); +} + +/* ******************************************************************************** */ +DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( + const DiscreteValues& frontalValues) const { + // Get the big decision tree with all the levels, and then go down the + // branches based on the value of the frontal variables. + ADT adt(*this); + size_t value; + for (Key j : frontals()) { + try { + value = frontalValues.at(j); + adt = adt.choose(j, value); // ADT keeps getting smaller. + } catch (exception&) { + frontalValues.print("frontalValues: "); + throw runtime_error("DiscreteConditional::choose: frontal value missing"); + }; + } + + // Convert ADT to factor. + DiscreteKeys discreteKeys; + for (Key j : parents()) { + discreteKeys.emplace_back(j, this->cardinality(j)); + } + return boost::make_shared(discreteKeys, adt); +} + +/* ******************************************************************************** */ +DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( + size_t value) const { + if (nrFrontals() != 1) + throw std::invalid_argument( + "Single value likelihood can only be invoked on single-variable " + "conditional"); + DiscreteValues values; + values.emplace(keys_[0], value); + return likelihood(values); } /* ******************************************************************************** */ void DiscreteConditional::solveInPlace(DiscreteValues* values) const { // TODO: Abhijit asks: is this really the fastest way? He thinks it is. - ADT pFS = choose(*values); // P(F|S=parentsValues) + ADT pFS = Choose(*this, *values); // P(F|S=parentsValues) // Initialize DiscreteValues mpe; @@ -177,7 +222,7 @@ void DiscreteConditional::sampleInPlace(DiscreteValues* values) const { size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const { // TODO: is this really the fastest way? I think it is. - ADT pFS = choose(parentsValues); // P(F|S=parentsValues) + ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues) // Then, find the max over all remaining // TODO, only works for one key now, seems horribly slow this way @@ -203,7 +248,7 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const { static mt19937 rng(2); // random number generator // Get the correct conditional density - ADT pFS = choose(parentsValues); // P(F|S=parentsValues) + ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues) // TODO(Duy): only works for one key now, seems horribly slow this way assert(nrFrontals() == 1); diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index e95dfb515..c72f076d8 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -146,13 +146,17 @@ public: return DecisionTreeFactor::shared_ptr(new DecisionTreeFactor(*this)); } - /** Restrict to given parent values, returns AlgebraicDecisionDiagram */ - ADT choose(const DiscreteValues& parentsValues) const; - /** Restrict to given parent values, returns DecisionTreeFactor */ - DecisionTreeFactor::shared_ptr chooseAsFactor( + DecisionTreeFactor::shared_ptr choose( const DiscreteValues& parentsValues) const; + /** Convert to a likelihood factor by providing value before bar. */ + DecisionTreeFactor::shared_ptr likelihood( + const DiscreteValues& frontalValues) const; + + /** Single variable version of likelihood. */ + DecisionTreeFactor::shared_ptr likelihood(size_t value) const; + /** * solve a conditional * @param parentsValues Known values of the parents diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 971250ba1..daacee371 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -76,8 +76,11 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { string s = "Discrete Conditional: ", const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; gtsam::DecisionTreeFactor* toFactor() const; - gtsam::DecisionTreeFactor* chooseAsFactor( + gtsam::DecisionTreeFactor* choose( const gtsam::DiscreteValues& parentsValues) const; + gtsam::DecisionTreeFactor* likelihood( + const gtsam::DiscreteValues& frontalValues) const; + gtsam::DecisionTreeFactor* likelihood(size_t value) const; size_t solve(const gtsam::DiscreteValues& parentsValues) const; size_t sample(const gtsam::DiscreteValues& parentsValues) const; void solveInPlace(gtsam::DiscreteValues @parentsValues) const; diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index 604b0ce71..00ae1acd0 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -31,24 +31,21 @@ using namespace std; using namespace gtsam; /* ************************************************************************* */ -TEST( DiscreteConditional, constructors) -{ - DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering ! +TEST(DiscreteConditional, constructors) { + DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering ! + + DiscreteConditional expected(X | Y = "1/1 2/3 1/4"); + EXPECT_LONGS_EQUAL(0, *(expected.beginFrontals())); + EXPECT_LONGS_EQUAL(2, *(expected.beginParents())); + EXPECT(expected.endParents() == expected.end()); + EXPECT(expected.endFrontals() == expected.beginParents()); - DiscreteConditional::shared_ptr expected1 = // - boost::make_shared(X | Y = "1/1 2/3 1/4"); - EXPECT(expected1); - EXPECT_LONGS_EQUAL(0, *(expected1->beginFrontals())); - EXPECT_LONGS_EQUAL(2, *(expected1->beginParents())); - EXPECT(expected1->endParents() == expected1->end()); - EXPECT(expected1->endFrontals() == expected1->beginParents()); - DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8"); DiscreteConditional actual1(1, f1); - EXPECT(assert_equal(*expected1, actual1, 1e-9)); + EXPECT(assert_equal(expected, actual1, 1e-9)); - DecisionTreeFactor f2(X & Y & Z, - "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); + DecisionTreeFactor f2( + X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); DiscreteConditional actual2(1, f2); EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9)); } @@ -108,6 +105,20 @@ TEST(DiscreteConditional, Combine) { EXPECT(assert_equal(expected, *actual, 1e-5)); } +/* ************************************************************************* */ +TEST(DiscreteConditional, likelihood) { + DiscreteKey X(0, 2), Y(1, 3); + DiscreteConditional conditional(X | Y = "2/8 4/6 5/5"); + + auto actual0 = conditional.likelihood(0); + DecisionTreeFactor expected0(Y, "0.2 0.4 0.5"); + EXPECT(assert_equal(expected0, *actual0, 1e-9)); + + auto actual1 = conditional.likelihood(1); + DecisionTreeFactor expected1(Y, "0.8 0.6 0.5"); + EXPECT(assert_equal(expected1, *actual1, 1e-9)); +} + /* ************************************************************************* */ // Check markdown representation looks as expected, no parents. TEST(DiscreteConditional, markdown_prior) { diff --git a/python/gtsam/tests/test_DiscreteConditional.py b/python/gtsam/tests/test_DiscreteConditional.py index 5e24dc40b..0cd02ce6a 100644 --- a/python/gtsam/tests/test_DiscreteConditional.py +++ b/python/gtsam/tests/test_DiscreteConditional.py @@ -13,12 +13,26 @@ Author: Varun Agrawal import unittest -from gtsam import DiscreteConditional, DiscreteKeys +from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteKeys from gtsam.utils.test_case import GtsamTestCase class TestDiscreteConditional(GtsamTestCase): """Tests for Discrete Conditionals.""" + + def test_likelihood(self): + X = (0, 2) + Y = (1, 3) + conditional = DiscreteConditional(X, "2/8 4/6 5/5", Y) + + actual0 = conditional.likelihood(0) + expected0 = DecisionTreeFactor(Y, "0.2 0.4 0.5") + self.gtsamAssertEquals(actual0, expected0, 1e-9) + + actual1 = conditional.likelihood(1) + expected1 = DecisionTreeFactor(Y, "0.8 0.6 0.5") + self.gtsamAssertEquals(actual1, expected1, 1e-9) + def test_markdown(self): """Test whether the _repr_markdown_ method.""" @@ -32,7 +46,7 @@ class TestDiscreteConditional(GtsamTestCase): conditional = DiscreteConditional(A, parents, "0/1 1/3 1/1 3/1 0/1 1/0") expected = \ - " $P(A|B,C)$:\n" \ + " *P(A|B,C)*:\n\n" \ "|B|C|0|1|\n" \ "|:-:|:-:|:-:|:-:|\n" \ "|0|0|0|1|\n" \