parent
							
								
									159a78581f
								
							
						
					
					
						commit
						bd4230baae
					
				|  | @ -80,7 +80,7 @@ namespace gtsam { | ||||||
|      * @endcode |      * @endcode | ||||||
|      * |      * | ||||||
|      * The values in the table should be laid out so that the first key varies |      * The values in the table should be laid out so that the first key varies | ||||||
|      * the slowest. and the last key the fastest. |      * the slowest, and the last key the fastest. | ||||||
|      */ |      */ | ||||||
|     DecisionTreeFactor(const DiscreteKeys& keys, |     DecisionTreeFactor(const DiscreteKeys& keys, | ||||||
|                        const std::vector<double>& table); |                        const std::vector<double>& table); | ||||||
|  | @ -101,7 +101,7 @@ namespace gtsam { | ||||||
|      * @endcode |      * @endcode | ||||||
|      * |      * | ||||||
|      * The values in the table should be laid out so that the first key varies |      * The values in the table should be laid out so that the first key varies | ||||||
|      * the slowest. and the last key the fastest. |      * the slowest, and the last key the fastest. | ||||||
|      */ |      */ | ||||||
|     DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table); |     DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table); | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -58,6 +58,11 @@ class GTSAM_EXPORT DiscreteBayesTreeClique | ||||||
| 
 | 
 | ||||||
|   //** evaluate conditional probability of subtree for given DiscreteValues */
 |   //** evaluate conditional probability of subtree for given DiscreteValues */
 | ||||||
|   double evaluate(const DiscreteValues& values) const; |   double evaluate(const DiscreteValues& values) const; | ||||||
|  | 
 | ||||||
|  |   //** (Preferred) sugar for the above for given DiscreteValues */
 | ||||||
|  |   double operator()(const DiscreteValues& values) const { | ||||||
|  |     return evaluate(values); | ||||||
|  |   } | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
|  |  | ||||||
|  | @ -215,6 +215,7 @@ class DiscreteBayesTreeClique { | ||||||
|       const string& s = "Clique: ", |       const string& s = "Clique: ", | ||||||
|       const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; |       const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; | ||||||
|   double evaluate(const gtsam::DiscreteValues& values) const; |   double evaluate(const gtsam::DiscreteValues& values) const; | ||||||
|  |   double operator()(const gtsam::DiscreteValues& values) const; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| class DiscreteBayesTree { | class DiscreteBayesTree { | ||||||
|  | @ -229,6 +230,7 @@ class DiscreteBayesTree { | ||||||
|   const DiscreteBayesTreeClique* operator[](size_t j) const; |   const DiscreteBayesTreeClique* operator[](size_t j) const; | ||||||
| 
 | 
 | ||||||
|   double evaluate(const gtsam::DiscreteValues& values) const; |   double evaluate(const gtsam::DiscreteValues& values) const; | ||||||
|  |   double operator()(const gtsam::DiscreteValues& values) const; | ||||||
| 
 | 
 | ||||||
|   string dot(const gtsam::KeyFormatter& keyFormatter = |   string dot(const gtsam::KeyFormatter& keyFormatter = | ||||||
|                  gtsam::DefaultKeyFormatter) const; |                  gtsam::DefaultKeyFormatter) const; | ||||||
|  |  | ||||||
|  | @ -16,6 +16,7 @@ | ||||||
|  */ |  */ | ||||||
| 
 | 
 | ||||||
| #include <gtsam/base/Vector.h> | #include <gtsam/base/Vector.h> | ||||||
|  | #include <gtsam/inference/Symbol.h> | ||||||
| #include <gtsam/inference/BayesNet.h> | #include <gtsam/inference/BayesNet.h> | ||||||
| #include <gtsam/discrete/DiscreteBayesNet.h> | #include <gtsam/discrete/DiscreteBayesNet.h> | ||||||
| #include <gtsam/discrete/DiscreteBayesTree.h> | #include <gtsam/discrete/DiscreteBayesTree.h> | ||||||
|  | @ -26,7 +27,6 @@ | ||||||
| #include <iostream> | #include <iostream> | ||||||
| #include <vector> | #include <vector> | ||||||
| 
 | 
 | ||||||
| using namespace std; |  | ||||||
| using namespace gtsam; | using namespace gtsam; | ||||||
| static constexpr bool debug = false; | static constexpr bool debug = false; | ||||||
| 
 | 
 | ||||||
|  | @ -108,7 +108,7 @@ TEST(DiscreteBayesTree, ThinTree) { | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
| // Check calculation of separator marginals
 | // Check calculation of separator marginals
 | ||||||
| TEST(DiscreteBayesTree, separatorMarginal) { | TEST(DiscreteBayesTree, SeparatorMarginals) { | ||||||
|   TestFixture self; |   TestFixture self; | ||||||
| 
 | 
 | ||||||
|   // Calculate some marginals for DiscreteValues==all1
 |   // Calculate some marginals for DiscreteValues==all1
 | ||||||
|  | @ -141,7 +141,7 @@ TEST(DiscreteBayesTree, separatorMarginal) { | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
| // Check shortcuts in the tree
 | // Check shortcuts in the tree
 | ||||||
| TEST(DiscreteBayesTree, shortcut) { | TEST(DiscreteBayesTree, Shortcuts) { | ||||||
|   TestFixture self; |   TestFixture self; | ||||||
| 
 | 
 | ||||||
|   // Calculate some marginals for DiscreteValues==all1
 |   // Calculate some marginals for DiscreteValues==all1
 | ||||||
|  | @ -199,7 +199,7 @@ TEST(DiscreteBayesTree, shortcut) { | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
| // Check all marginals
 | // Check all marginals
 | ||||||
| TEST(DiscreteBayesTree, marginalFactor) { | TEST(DiscreteBayesTree, MarginalFactors) { | ||||||
|   TestFixture self; |   TestFixture self; | ||||||
| 
 | 
 | ||||||
|   Vector marginals = Vector::Zero(15); |   Vector marginals = Vector::Zero(15); | ||||||
|  | @ -286,7 +286,7 @@ TEST(DiscreteBayesTree, Joints) { | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
| TEST(DiscreteBayesTree, Dot) { | TEST(DiscreteBayesTree, Dot) { | ||||||
|   TestFixture self; |   TestFixture self; | ||||||
|   string actual = self.bayesTree->dot(); |   std::string actual = self.bayesTree->dot(); | ||||||
|   EXPECT(actual == |   EXPECT(actual == | ||||||
|          "digraph G{\n" |          "digraph G{\n" | ||||||
|          "0[label=\"13, 11, 6, 7\"];\n" |          "0[label=\"13, 11, 6, 7\"];\n" | ||||||
|  | @ -313,6 +313,61 @@ TEST(DiscreteBayesTree, Dot) { | ||||||
|          "}"); |          "}"); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | /* ************************************************************************* */ | ||||||
|  | // Check that we can have a multi-frontal lookup table
 | ||||||
|  | TEST(DiscreteBayesTree, Lookup) { | ||||||
|  |   using gtsam::symbol_shorthand::A; | ||||||
|  |   using gtsam::symbol_shorthand::X; | ||||||
|  | 
 | ||||||
|  |   // Make a small planning-like graph: 3 states, 2 actions
 | ||||||
|  |   DiscreteFactorGraph graph; | ||||||
|  |   const DiscreteKey x1{X(1), 3}, x2{X(2), 3}, x3{X(3), 3}; | ||||||
|  |   const DiscreteKey a1{A(1), 2}, a2{A(2), 2}; | ||||||
|  |   const DiscreteKeys keys{x1, x2, x3, a1, a2}; | ||||||
|  |   // Constraint on start and goal
 | ||||||
|  |   graph.add(DiscreteKeys{x1}, std::vector<double>{1, 0, 0}); | ||||||
|  |   graph.add(DiscreteKeys{x3}, std::vector<double>{0, 0, 1}); | ||||||
|  |   // Should I stay or should I go?
 | ||||||
|  |   // "Reward" (exp(-cost)) for an action is 10, and rewards multiply:
 | ||||||
|  |   const double r = 10; | ||||||
|  |   std::vector<double> table{ | ||||||
|  |       r, 0, 0, 0, r, 0,  // x1 = 0
 | ||||||
|  |       0, r, 0, 0, 0, r,  // x1 = 1
 | ||||||
|  |       0, 0, r, 0, 0, r   // x1 = 2
 | ||||||
|  |   }; | ||||||
|  |   graph.add(DiscreteKeys{x1, a1, x2}, table); | ||||||
|  |   graph.add(DiscreteKeys{x2, a2, x3}, table); | ||||||
|  | 
 | ||||||
|  |   // eliminate for MPE (maximum probable explanation).
 | ||||||
|  |   Ordering ordering{A(2), X(3), X(1), A(1), X(2)}; | ||||||
|  |   auto lookup = graph.eliminateMultifrontal(ordering, EliminateForMPE); | ||||||
|  | 
 | ||||||
|  |   // Check that the lookup table is correct
 | ||||||
|  |   EXPECT_LONGS_EQUAL(2, lookup->size()); | ||||||
|  |   auto lookup_x1_a1_x2 = (*lookup)[X(1)]->conditional(); | ||||||
|  |   EXPECT_LONGS_EQUAL(3, lookup_x1_a1_x2->frontals().size()); | ||||||
|  |   // check that sum is 100
 | ||||||
|  |   DiscreteValues empty; | ||||||
|  |   EXPECT_DOUBLES_EQUAL(100, (*lookup_x1_a1_x2->sum(3))(empty), 1e-9); | ||||||
|  |   // And that only non-zero reward is for x1 a1 x2 == 0 1 1
 | ||||||
|  |   EXPECT_DOUBLES_EQUAL(100, (*lookup_x1_a1_x2)({{X(1),0},{A(1),1},{X(2),1}}), 1e-9); | ||||||
|  | 
 | ||||||
|  |   auto lookup_a2_x3 = (*lookup)[X(3)]->conditional(); | ||||||
|  |   // check that the sum depends on x2 and is non-zero only for x2 \in {1,2}
 | ||||||
|  |   auto sum_x2 = lookup_a2_x3->sum(2); | ||||||
|  |   EXPECT_DOUBLES_EQUAL(0, (*sum_x2)({{X(2),0}}), 1e-9); | ||||||
|  |   EXPECT_DOUBLES_EQUAL(10, (*sum_x2)({{X(2),1}}), 1e-9); | ||||||
|  |   EXPECT_DOUBLES_EQUAL(20, (*sum_x2)({{X(2),2}}), 1e-9); | ||||||
|  |   EXPECT_LONGS_EQUAL(2, lookup_a2_x3->frontals().size()); | ||||||
|  |   // And that the non-zero rewards are for 
 | ||||||
|  |   // x2 a2 x3 == 1 1 2
 | ||||||
|  |   EXPECT_DOUBLES_EQUAL(10, (*lookup_a2_x3)({{X(2),1},{A(2),1},{X(3),2}}), 1e-9); | ||||||
|  |   // x2 a2 x3 == 2 0 2
 | ||||||
|  |   EXPECT_DOUBLES_EQUAL(10, (*lookup_a2_x3)({{X(2),2},{A(2),0},{X(3),2}}), 1e-9); | ||||||
|  |   // x2 a2 x3 == 2 1 2
 | ||||||
|  |   EXPECT_DOUBLES_EQUAL(10, (*lookup_a2_x3)({{X(2),2},{A(2),1},{X(3),2}}), 1e-9); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
| int main() { | int main() { | ||||||
|   TestResult tr; |   TestResult tr; | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue