Constructor from PMF
							parent
							
								
									1000825b03
								
							
						
					
					
						commit
						be5aa56df7
					
				|  | @ -48,17 +48,17 @@ class GTSAM_EXPORT DiscretePrior : public DiscreteConditional { | ||||||
|   DiscretePrior(const Signature& s) : Base(s) {} |   DiscretePrior(const Signature& s) : Base(s) {} | ||||||
| 
 | 
 | ||||||
|   /**
 |   /**
 | ||||||
|    * Construct from key and a Signature::Table specifying the |    * Construct from key and a vector of floats specifying the probability mass | ||||||
|    * conditional probability table (CPT). |    * function (PMF). | ||||||
|    * |    * | ||||||
|    * Example: DiscretePrior P(D, table); |    * Example: DiscretePrior P(D, {0.4, 0.6}); | ||||||
|    */ |    */ | ||||||
|   DiscretePrior(const DiscreteKey& key, const Signature::Table& table) |   DiscretePrior(const DiscreteKey& key, const std::vector<double>& spec) | ||||||
|       : Base(Signature(key, {}, table)) {} |       : DiscretePrior(Signature(key, {}, Signature::Table{spec})) {} | ||||||
| 
 | 
 | ||||||
|   /**
 |   /**
 | ||||||
|    * Construct from key and a string specifying the conditional |    * Construct from key and a string specifying the probability mass function | ||||||
|    * probability table (CPT). |    * (PMF). | ||||||
|    * |    * | ||||||
|    * Example: DiscretePrior P(D, "9/1 2/8 3/7 1/9"); |    * Example: DiscretePrior P(D, "9/1 2/8 3/7 1/9"); | ||||||
|    */ |    */ | ||||||
|  |  | ||||||
|  | @ -120,6 +120,7 @@ virtual class DiscretePrior : gtsam::DiscreteConditional { | ||||||
|   DiscretePrior(); |   DiscretePrior(); | ||||||
|   DiscretePrior(const gtsam::DecisionTreeFactor& f); |   DiscretePrior(const gtsam::DecisionTreeFactor& f); | ||||||
|   DiscretePrior(const gtsam::DiscreteKey& key, string spec); |   DiscretePrior(const gtsam::DiscreteKey& key, string spec); | ||||||
|  |   DiscretePrior(const gtsam::DiscreteKey& key, std::vector<double> spec); | ||||||
|   void print(string s = "Discrete Prior\n", |   void print(string s = "Discrete Prior\n", | ||||||
|              const gtsam::KeyFormatter& keyFormatter = |              const gtsam::KeyFormatter& keyFormatter = | ||||||
|                  gtsam::DefaultKeyFormatter) const; |                  gtsam::DefaultKeyFormatter) const; | ||||||
|  |  | ||||||
|  | @ -27,12 +27,19 @@ static const DiscreteKey X(0, 2); | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
| TEST(DiscretePrior, constructors) { | TEST(DiscretePrior, constructors) { | ||||||
|  |   DecisionTreeFactor f(X, "0.4 0.6"); | ||||||
|  |   DiscretePrior expected(f); | ||||||
|  | 
 | ||||||
|   DiscretePrior actual(X % "2/3"); |   DiscretePrior actual(X % "2/3"); | ||||||
|   EXPECT_LONGS_EQUAL(1, actual.nrFrontals()); |   EXPECT_LONGS_EQUAL(1, actual.nrFrontals()); | ||||||
|   EXPECT_LONGS_EQUAL(0, actual.nrParents()); |   EXPECT_LONGS_EQUAL(0, actual.nrParents()); | ||||||
|   DecisionTreeFactor f(X, "0.4 0.6"); |  | ||||||
|   DiscretePrior expected(f); |  | ||||||
|   EXPECT(assert_equal(expected, actual, 1e-9)); |   EXPECT(assert_equal(expected, actual, 1e-9)); | ||||||
|  | 
 | ||||||
|  |   const vector<double> pmf{0.4, 0.6}; | ||||||
|  |   DiscretePrior actual2(X, pmf); | ||||||
|  |   EXPECT_LONGS_EQUAL(1, actual2.nrFrontals()); | ||||||
|  |   EXPECT_LONGS_EQUAL(0, actual2.nrParents()); | ||||||
|  |   EXPECT(assert_equal(expected, actual2, 1e-9)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
|  |  | ||||||
|  | @ -25,13 +25,17 @@ class TestDiscretePrior(GtsamTestCase): | ||||||
| 
 | 
 | ||||||
|     def test_constructor(self): |     def test_constructor(self): | ||||||
|         """Test various constructors.""" |         """Test various constructors.""" | ||||||
|         actual = DiscretePrior(X, "2/3") |  | ||||||
|         keys = DiscreteKeys() |         keys = DiscreteKeys() | ||||||
|         keys.push_back(X) |         keys.push_back(X) | ||||||
|         f = DecisionTreeFactor(keys, "0.4 0.6") |         f = DecisionTreeFactor(keys, "0.4 0.6") | ||||||
|         expected = DiscretePrior(f) |         expected = DiscretePrior(f) | ||||||
|  |          | ||||||
|  |         actual = DiscretePrior(X, "2/3") | ||||||
|         self.gtsamAssertEquals(actual, expected) |         self.gtsamAssertEquals(actual, expected) | ||||||
|          |          | ||||||
|  |         actual2 = DiscretePrior(X, [0.4, 0.6]) | ||||||
|  |         self.gtsamAssertEquals(actual2, expected) | ||||||
|  | 
 | ||||||
|     def test_operator(self): |     def test_operator(self): | ||||||
|         prior = DiscretePrior(X, "2/3") |         prior = DiscretePrior(X, "2/3") | ||||||
|         self.assertAlmostEqual(prior(0), 0.4) |         self.assertAlmostEqual(prior(0), 0.4) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue