likelihood
							parent
							
								
									dbe5c0fa81
								
							
						
					
					
						commit
						457d074858
					
				|  | @ -97,45 +97,90 @@ bool DiscreteConditional::equals(const DiscreteFactor& other, | |||
| } | ||||
| 
 | ||||
| /* ******************************************************************************** */ | ||||
| Potentials::ADT DiscreteConditional::choose( | ||||
|     const DiscreteValues& parentsValues) const { | ||||
| static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional, | ||||
|                                        const DiscreteValues& parentsValues) { | ||||
|   // Get the big decision tree with all the levels, and then go down the
 | ||||
|   // branches based on the value of the parent variables.
 | ||||
|   ADT pFS(*this); | ||||
|   DiscreteConditional::ADT adt(conditional); | ||||
|   size_t value; | ||||
|   for (Key j : parents()) { | ||||
|   for (Key j : conditional.parents()) { | ||||
|     try { | ||||
|       value = parentsValues.at(j); | ||||
|       pFS = pFS.choose(j, value);  // ADT keeps getting smaller.
 | ||||
|       adt = adt.choose(j, value);  // ADT keeps getting smaller.
 | ||||
|     } catch (exception&) { | ||||
|       cout << "Key: " << j << "  Value: " << value << endl; | ||||
|       parentsValues.print("parentsValues: "); | ||||
|       throw runtime_error("DiscreteConditional::choose: parent value missing"); | ||||
|     }; | ||||
|   } | ||||
|   return pFS; | ||||
|   return adt; | ||||
| } | ||||
| 
 | ||||
| /* ******************************************************************************** */ | ||||
| DecisionTreeFactor::shared_ptr DiscreteConditional::chooseAsFactor( | ||||
| DecisionTreeFactor::shared_ptr DiscreteConditional::choose( | ||||
|     const DiscreteValues& parentsValues) const { | ||||
|   ADT pFS = choose(parentsValues); | ||||
|   // Get the big decision tree with all the levels, and then go down the
 | ||||
|   // branches based on the value of the parent variables.
 | ||||
|   ADT adt(*this); | ||||
|   size_t value; | ||||
|   for (Key j : parents()) { | ||||
|     try { | ||||
|       value = parentsValues.at(j); | ||||
|       adt = adt.choose(j, value);  // ADT keeps getting smaller.
 | ||||
|     } catch (exception&) { | ||||
|       parentsValues.print("parentsValues: "); | ||||
|       throw runtime_error("DiscreteConditional::choose: parent value missing"); | ||||
|     }; | ||||
|   } | ||||
| 
 | ||||
|   // Convert ADT to factor.
 | ||||
|   if (nrFrontals() != 1) { | ||||
|     throw std::runtime_error("Expected only one frontal variable in choose."); | ||||
|   DiscreteKeys discreteKeys; | ||||
|   for (Key j : frontals()) { | ||||
|     discreteKeys.emplace_back(j, this->cardinality(j)); | ||||
|   } | ||||
|   DiscreteKeys keys; | ||||
|   const Key frontalKey = keys_[0]; | ||||
|   size_t frontalCardinality = this->cardinality(frontalKey); | ||||
|   keys.push_back(DiscreteKey(frontalKey, frontalCardinality)); | ||||
|   return boost::make_shared<DecisionTreeFactor>(keys, pFS); | ||||
|   return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt); | ||||
| } | ||||
| 
 | ||||
| /* ******************************************************************************** */ | ||||
| DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( | ||||
|     const DiscreteValues& frontalValues) const { | ||||
|   // Get the big decision tree with all the levels, and then go down the
 | ||||
|   // branches based on the value of the frontal variables.
 | ||||
|   ADT adt(*this); | ||||
|   size_t value; | ||||
|   for (Key j : frontals()) { | ||||
|     try { | ||||
|       value = frontalValues.at(j); | ||||
|       adt = adt.choose(j, value);  // ADT keeps getting smaller.
 | ||||
|     } catch (exception&) { | ||||
|       frontalValues.print("frontalValues: "); | ||||
|       throw runtime_error("DiscreteConditional::choose: frontal value missing"); | ||||
|     }; | ||||
|   } | ||||
| 
 | ||||
|   // Convert ADT to factor.
 | ||||
|   DiscreteKeys discreteKeys; | ||||
|   for (Key j : parents()) { | ||||
|     discreteKeys.emplace_back(j, this->cardinality(j)); | ||||
|   } | ||||
|   return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt); | ||||
| } | ||||
| 
 | ||||
| /* ******************************************************************************** */ | ||||
| DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( | ||||
|     size_t value) const { | ||||
|   if (nrFrontals() != 1) | ||||
|     throw std::invalid_argument( | ||||
|         "Single value likelihood can only be invoked on single-variable " | ||||
|         "conditional"); | ||||
|   DiscreteValues values; | ||||
|   values.emplace(keys_[0], value); | ||||
|   return likelihood(values); | ||||
| } | ||||
| 
 | ||||
| /* ******************************************************************************** */ | ||||
| void DiscreteConditional::solveInPlace(DiscreteValues* values) const { | ||||
|   // TODO: Abhijit asks: is this really the fastest way? He thinks it is.
 | ||||
|   ADT pFS = choose(*values); // P(F|S=parentsValues)
 | ||||
|   ADT pFS = Choose(*this, *values); // P(F|S=parentsValues)
 | ||||
| 
 | ||||
|   // Initialize
 | ||||
|   DiscreteValues mpe; | ||||
|  | @ -177,7 +222,7 @@ void DiscreteConditional::sampleInPlace(DiscreteValues* values) const { | |||
| size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const { | ||||
| 
 | ||||
|   // TODO: is this really the fastest way? I think it is.
 | ||||
|   ADT pFS = choose(parentsValues); // P(F|S=parentsValues)
 | ||||
|   ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
 | ||||
| 
 | ||||
|   // Then, find the max over all remaining
 | ||||
|   // TODO, only works for one key now, seems horribly slow this way
 | ||||
|  | @ -203,7 +248,7 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const { | |||
|   static mt19937 rng(2);  // random number generator
 | ||||
| 
 | ||||
|   // Get the correct conditional density
 | ||||
|   ADT pFS = choose(parentsValues);  // P(F|S=parentsValues)
 | ||||
|   ADT pFS = Choose(*this, parentsValues);  // P(F|S=parentsValues)
 | ||||
| 
 | ||||
|   // TODO(Duy): only works for one key now, seems horribly slow this way
 | ||||
|   assert(nrFrontals() == 1); | ||||
|  |  | |||
|  | @ -146,13 +146,17 @@ public: | |||
|     return DecisionTreeFactor::shared_ptr(new DecisionTreeFactor(*this)); | ||||
|   } | ||||
| 
 | ||||
|   /** Restrict to given parent values, returns AlgebraicDecisionDiagram */ | ||||
|   ADT choose(const DiscreteValues& parentsValues) const; | ||||
| 
 | ||||
|   /** Restrict to given parent values, returns DecisionTreeFactor */ | ||||
|   DecisionTreeFactor::shared_ptr chooseAsFactor( | ||||
|   DecisionTreeFactor::shared_ptr choose( | ||||
|       const DiscreteValues& parentsValues) const; | ||||
| 
 | ||||
|   /** Convert to a likelihood factor by providing value before bar. */ | ||||
|   DecisionTreeFactor::shared_ptr likelihood( | ||||
|       const DiscreteValues& frontalValues) const; | ||||
| 
 | ||||
|   /** Single variable version of likelihood. */ | ||||
|   DecisionTreeFactor::shared_ptr likelihood(size_t value) const; | ||||
| 
 | ||||
|   /**
 | ||||
|    * solve a conditional | ||||
|    * @param parentsValues Known values of the parents | ||||
|  |  | |||
|  | @ -76,8 +76,11 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { | |||
|       string s = "Discrete Conditional: ", | ||||
|       const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; | ||||
|   gtsam::DecisionTreeFactor* toFactor() const; | ||||
|   gtsam::DecisionTreeFactor* chooseAsFactor( | ||||
|   gtsam::DecisionTreeFactor* choose( | ||||
|       const gtsam::DiscreteValues& parentsValues) const; | ||||
|   gtsam::DecisionTreeFactor* likelihood( | ||||
|       const gtsam::DiscreteValues& frontalValues) const; | ||||
|   gtsam::DecisionTreeFactor* likelihood(size_t value) const; | ||||
|   size_t solve(const gtsam::DiscreteValues& parentsValues) const; | ||||
|   size_t sample(const gtsam::DiscreteValues& parentsValues) const; | ||||
|   void solveInPlace(gtsam::DiscreteValues @parentsValues) const; | ||||
|  |  | |||
|  | @ -31,24 +31,21 @@ using namespace std; | |||
| using namespace gtsam; | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| TEST( DiscreteConditional, constructors) | ||||
| { | ||||
|   DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering !
 | ||||
| TEST(DiscreteConditional, constructors) { | ||||
|   DiscreteKey X(0, 2), Y(2, 3), Z(1, 2);  // watch ordering !
 | ||||
| 
 | ||||
|   DiscreteConditional expected(X | Y = "1/1 2/3 1/4"); | ||||
|   EXPECT_LONGS_EQUAL(0, *(expected.beginFrontals())); | ||||
|   EXPECT_LONGS_EQUAL(2, *(expected.beginParents())); | ||||
|   EXPECT(expected.endParents() == expected.end()); | ||||
|   EXPECT(expected.endFrontals() == expected.beginParents()); | ||||
| 
 | ||||
|   DiscreteConditional::shared_ptr expected1 = //
 | ||||
|       boost::make_shared<DiscreteConditional>(X | Y = "1/1 2/3 1/4"); | ||||
|   EXPECT(expected1); | ||||
|   EXPECT_LONGS_EQUAL(0, *(expected1->beginFrontals())); | ||||
|   EXPECT_LONGS_EQUAL(2, *(expected1->beginParents())); | ||||
|   EXPECT(expected1->endParents() == expected1->end()); | ||||
|   EXPECT(expected1->endFrontals() == expected1->beginParents()); | ||||
|    | ||||
|   DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8"); | ||||
|   DiscreteConditional actual1(1, f1); | ||||
|   EXPECT(assert_equal(*expected1, actual1, 1e-9)); | ||||
|   EXPECT(assert_equal(expected, actual1, 1e-9)); | ||||
| 
 | ||||
|   DecisionTreeFactor f2(X & Y & Z, | ||||
|       "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); | ||||
|   DecisionTreeFactor f2( | ||||
|       X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); | ||||
|   DiscreteConditional actual2(1, f2); | ||||
|   EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9)); | ||||
| } | ||||
|  | @ -108,6 +105,20 @@ TEST(DiscreteConditional, Combine) { | |||
|   EXPECT(assert_equal(expected, *actual, 1e-5)); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| TEST(DiscreteConditional, likelihood) { | ||||
|   DiscreteKey X(0, 2), Y(1, 3); | ||||
|   DiscreteConditional conditional(X | Y = "2/8 4/6 5/5"); | ||||
| 
 | ||||
|   auto actual0 = conditional.likelihood(0); | ||||
|   DecisionTreeFactor expected0(Y, "0.2 0.4 0.5"); | ||||
|   EXPECT(assert_equal(expected0, *actual0, 1e-9)); | ||||
| 
 | ||||
|   auto actual1 = conditional.likelihood(1); | ||||
|   DecisionTreeFactor expected1(Y, "0.8 0.6 0.5"); | ||||
|   EXPECT(assert_equal(expected1, *actual1, 1e-9)); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| // Check markdown representation looks as expected, no parents.
 | ||||
| TEST(DiscreteConditional, markdown_prior) { | ||||
|  |  | |||
|  | @ -13,12 +13,26 @@ Author: Varun Agrawal | |||
| 
 | ||||
| import unittest | ||||
| 
 | ||||
| from gtsam import DiscreteConditional, DiscreteKeys | ||||
| from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteKeys | ||||
| from gtsam.utils.test_case import GtsamTestCase | ||||
| 
 | ||||
| 
 | ||||
| class TestDiscreteConditional(GtsamTestCase): | ||||
|     """Tests for Discrete Conditionals.""" | ||||
| 
 | ||||
|     def test_likelihood(self): | ||||
|         X = (0, 2) | ||||
|         Y = (1, 3) | ||||
|         conditional = DiscreteConditional(X, "2/8 4/6 5/5", Y) | ||||
| 
 | ||||
|         actual0 = conditional.likelihood(0) | ||||
|         expected0 = DecisionTreeFactor(Y, "0.2 0.4 0.5") | ||||
|         self.gtsamAssertEquals(actual0, expected0, 1e-9) | ||||
| 
 | ||||
|         actual1 = conditional.likelihood(1) | ||||
|         expected1 = DecisionTreeFactor(Y, "0.8 0.6 0.5") | ||||
|         self.gtsamAssertEquals(actual1, expected1, 1e-9) | ||||
| 
 | ||||
|     def test_markdown(self): | ||||
|         """Test whether the _repr_markdown_ method.""" | ||||
| 
 | ||||
|  | @ -32,7 +46,7 @@ class TestDiscreteConditional(GtsamTestCase): | |||
|         conditional = DiscreteConditional(A, parents, | ||||
|                                           "0/1 1/3  1/1 3/1  0/1 1/0") | ||||
|         expected = \ | ||||
|             " $P(A|B,C)$:\n" \ | ||||
|             " *P(A|B,C)*:\n\n" \ | ||||
|             "|B|C|0|1|\n" \ | ||||
|             "|:-:|:-:|:-:|:-:|\n" \ | ||||
|             "|0|0|0|1|\n" \ | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue