commit
						0a1a7510f9
					
				|  | @ -20,13 +20,14 @@ | ||||||
| #include <vector> | #include <vector> | ||||||
| #include <map> | #include <map> | ||||||
| #include <boost/shared_ptr.hpp> | #include <boost/shared_ptr.hpp> | ||||||
|  | #include <gtsam/inference/BayesNet.h> | ||||||
| #include <gtsam/inference/FactorGraph.h> | #include <gtsam/inference/FactorGraph.h> | ||||||
| #include <gtsam/discrete/DiscreteConditional.h> | #include <gtsam/discrete/DiscreteConditional.h> | ||||||
| 
 | 
 | ||||||
| namespace gtsam { | namespace gtsam { | ||||||
| 
 | 
 | ||||||
| /** A Bayes net made from linear-Discrete densities */ | /** A Bayes net made from linear-Discrete densities */ | ||||||
|   class GTSAM_EXPORT DiscreteBayesNet: public FactorGraph<DiscreteConditional> |   class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> | ||||||
|   { |   { | ||||||
|   public: |   public: | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -29,13 +29,32 @@ namespace gtsam { | ||||||
|   template class BayesTreeCliqueBase<DiscreteBayesTreeClique, DiscreteFactorGraph>; |   template class BayesTreeCliqueBase<DiscreteBayesTreeClique, DiscreteFactorGraph>; | ||||||
|   template class BayesTree<DiscreteBayesTreeClique>; |   template class BayesTree<DiscreteBayesTreeClique>; | ||||||
| 
 | 
 | ||||||
|  |   /* ************************************************************************* */ | ||||||
|  |   double DiscreteBayesTreeClique::evaluate( | ||||||
|  |       const DiscreteConditional::Values& values) const { | ||||||
|  |     // evaluate all conditionals and multiply
 | ||||||
|  |     double result = (*conditional_)(values); | ||||||
|  |     for (const auto& child : children) { | ||||||
|  |       result *= child->evaluate(values); | ||||||
|  |     } | ||||||
|  |     return result; | ||||||
|  |   } | ||||||
| 
 | 
 | ||||||
|   /* ************************************************************************* */ |   /* ************************************************************************* */ | ||||||
|   bool DiscreteBayesTree::equals(const This& other, double tol) const |   bool DiscreteBayesTree::equals(const This& other, double tol) const { | ||||||
|   { |  | ||||||
|     return Base::equals(other, tol); |     return Base::equals(other, tol); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |   /* ************************************************************************* */ | ||||||
|  |   double DiscreteBayesTree::evaluate( | ||||||
|  |       const DiscreteConditional::Values& values) const { | ||||||
|  |     double result = 1.0; | ||||||
|  |     for (const auto& root : roots_) { | ||||||
|  |       result *= root->evaluate(values); | ||||||
|  |     } | ||||||
|  |     return result; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
| } // \namespace gtsam
 | } // \namespace gtsam
 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -11,7 +11,8 @@ | ||||||
| 
 | 
 | ||||||
| /**
 | /**
 | ||||||
|  * @file    DiscreteBayesTree.h |  * @file    DiscreteBayesTree.h | ||||||
|  * @brief   Discrete Bayes Tree, the result of eliminating a DiscreteJunctionTree |  * @brief   Discrete Bayes Tree, the result of eliminating a | ||||||
|  |  * DiscreteJunctionTree | ||||||
|  * @brief   DiscreteBayesTree |  * @brief   DiscreteBayesTree | ||||||
|  * @author  Frank Dellaert |  * @author  Frank Dellaert | ||||||
|  * @author  Richard Roberts |  * @author  Richard Roberts | ||||||
|  | @ -22,45 +23,62 @@ | ||||||
| #include <gtsam/discrete/DiscreteBayesNet.h> | #include <gtsam/discrete/DiscreteBayesNet.h> | ||||||
| #include <gtsam/discrete/DiscreteFactorGraph.h> | #include <gtsam/discrete/DiscreteFactorGraph.h> | ||||||
| #include <gtsam/inference/BayesTree.h> | #include <gtsam/inference/BayesTree.h> | ||||||
|  | #include <gtsam/inference/Conditional.h> | ||||||
| #include <gtsam/inference/BayesTreeCliqueBase.h> | #include <gtsam/inference/BayesTreeCliqueBase.h> | ||||||
| 
 | 
 | ||||||
|  | #include <string> | ||||||
|  | 
 | ||||||
| namespace gtsam { | namespace gtsam { | ||||||
| 
 | 
 | ||||||
|   // Forward declarations
 | // Forward declarations
 | ||||||
|   class DiscreteConditional; | class DiscreteConditional; | ||||||
|   class VectorValues; | class VectorValues; | ||||||
| 
 | 
 | ||||||
|   /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
|   /** A clique in a DiscreteBayesTree */ | /** A clique in a DiscreteBayesTree */ | ||||||
|   class GTSAM_EXPORT DiscreteBayesTreeClique : | class GTSAM_EXPORT DiscreteBayesTreeClique | ||||||
|     public BayesTreeCliqueBase<DiscreteBayesTreeClique, DiscreteFactorGraph> |     : public BayesTreeCliqueBase<DiscreteBayesTreeClique, DiscreteFactorGraph> { | ||||||
|   { |  public: | ||||||
|   public: |   typedef DiscreteBayesTreeClique This; | ||||||
|     typedef DiscreteBayesTreeClique This; |   typedef BayesTreeCliqueBase<DiscreteBayesTreeClique, DiscreteFactorGraph> | ||||||
|     typedef BayesTreeCliqueBase<DiscreteBayesTreeClique, DiscreteFactorGraph> Base; |       Base; | ||||||
|     typedef boost::shared_ptr<This> shared_ptr; |   typedef boost::shared_ptr<This> shared_ptr; | ||||||
|     typedef boost::weak_ptr<This> weak_ptr; |   typedef boost::weak_ptr<This> weak_ptr; | ||||||
|     DiscreteBayesTreeClique() {} |   DiscreteBayesTreeClique() {} | ||||||
|     DiscreteBayesTreeClique(const boost::shared_ptr<DiscreteConditional>& conditional) : Base(conditional) {} |   DiscreteBayesTreeClique( | ||||||
|   }; |       const boost::shared_ptr<DiscreteConditional>& conditional) | ||||||
|  |       : Base(conditional) {} | ||||||
| 
 | 
 | ||||||
|   /* ************************************************************************* */ |   /// print index signature only
 | ||||||
|   /** A Bayes tree representing a Discrete density */ |   void printSignature( | ||||||
|   class GTSAM_EXPORT DiscreteBayesTree : |       const std::string& s = "Clique: ", | ||||||
|     public BayesTree<DiscreteBayesTreeClique> |       const KeyFormatter& formatter = DefaultKeyFormatter) const { | ||||||
|   { |     conditional_->printSignature(s, formatter); | ||||||
|   private: |   } | ||||||
|     typedef BayesTree<DiscreteBayesTreeClique> Base; |  | ||||||
| 
 | 
 | ||||||
|   public: |   //** evaluate conditional probability of subtree for given Values */
 | ||||||
|     typedef DiscreteBayesTree This; |   double evaluate(const DiscreteConditional::Values& values) const; | ||||||
|     typedef boost::shared_ptr<This> shared_ptr; | }; | ||||||
| 
 | 
 | ||||||
|     /** Default constructor, creates an empty Bayes tree */ | /* ************************************************************************* */ | ||||||
|     DiscreteBayesTree() {} | /** A Bayes tree representing a Discrete density */ | ||||||
|  | class GTSAM_EXPORT DiscreteBayesTree | ||||||
|  |     : public BayesTree<DiscreteBayesTreeClique> { | ||||||
|  |  private: | ||||||
|  |   typedef BayesTree<DiscreteBayesTreeClique> Base; | ||||||
| 
 | 
 | ||||||
|     /** Check equality */ |  public: | ||||||
|     bool equals(const This& other, double tol = 1e-9) const; |   typedef DiscreteBayesTree This; | ||||||
|   }; |   typedef boost::shared_ptr<This> shared_ptr; | ||||||
| 
 | 
 | ||||||
| } |   /** Default constructor, creates an empty Bayes tree */ | ||||||
|  |   DiscreteBayesTree() {} | ||||||
|  | 
 | ||||||
|  |   /** Check equality */ | ||||||
|  |   bool equals(const This& other, double tol = 1e-9) const; | ||||||
|  | 
 | ||||||
|  |   //** evaluate probability for given Values */
 | ||||||
|  |   double evaluate(const DiscreteConditional::Values& values) const; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | }  // namespace gtsam
 | ||||||
|  |  | ||||||
|  | @ -24,6 +24,8 @@ | ||||||
| #include <boost/shared_ptr.hpp> | #include <boost/shared_ptr.hpp> | ||||||
| #include <boost/make_shared.hpp> | #include <boost/make_shared.hpp> | ||||||
| 
 | 
 | ||||||
|  | #include <string> | ||||||
|  | 
 | ||||||
| namespace gtsam { | namespace gtsam { | ||||||
| 
 | 
 | ||||||
| /**
 | /**
 | ||||||
|  | @ -92,6 +94,13 @@ public: | ||||||
|   /// @name Standard Interface
 |   /// @name Standard Interface
 | ||||||
|   /// @{
 |   /// @{
 | ||||||
| 
 | 
 | ||||||
|  |   /// print index signature only
 | ||||||
|  |   void printSignature( | ||||||
|  |       const std::string& s = "Discrete Conditional: ", | ||||||
|  |       const KeyFormatter& formatter = DefaultKeyFormatter) const { | ||||||
|  |     static_cast<const BaseConditional*>(this)->print(s, formatter); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   /// Evaluate, just look up in AlgebraicDecisonTree
 |   /// Evaluate, just look up in AlgebraicDecisonTree
 | ||||||
|   virtual double operator()(const Values& values) const { |   virtual double operator()(const Values& values) const { | ||||||
|     return Potentials::operator()(values); |     return Potentials::operator()(values); | ||||||
|  |  | ||||||
|  | @ -1,261 +1,216 @@ | ||||||
| ///* ----------------------------------------------------------------------------
 | /* ----------------------------------------------------------------------------
 | ||||||
| //
 | 
 | ||||||
| // * GTSAM Copyright 2010, Georgia Tech Research Corporation,
 | * GTSAM Copyright 2010-2020, Georgia Tech Research Corporation, | ||||||
| // * Atlanta, Georgia 30332-0415
 | * Atlanta, Georgia 30332-0415 | ||||||
| // * All Rights Reserved
 | * All Rights Reserved | ||||||
| // * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
 | * Authors: Frank Dellaert, et al. (see THANKS for the full author list) | ||||||
| //
 | 
 | ||||||
| // * See LICENSE for the license information
 | * See LICENSE for the license information | ||||||
| //
 | 
 | ||||||
| // * -------------------------------------------------------------------------- */
 | * -------------------------------------------------------------------------- */ | ||||||
| //
 | 
 | ||||||
| ///*
 | /*
 | ||||||
| // * @file testDiscreteBayesTree.cpp
 |  * @file testDiscreteBayesTree.cpp | ||||||
| // * @date sept 15, 2012
 |  * @date sept 15, 2012 | ||||||
| // * @author Frank Dellaert
 |  * @author Frank Dellaert | ||||||
| // */
 |  */ | ||||||
| //
 | 
 | ||||||
| //#include <gtsam/discrete/DiscreteBayesNet.h>
 | #include <gtsam/base/Vector.h> | ||||||
| //#include <gtsam/discrete/DiscreteBayesTree.h>
 | #include <gtsam/discrete/DiscreteBayesNet.h> | ||||||
| //#include <gtsam/discrete/DiscreteFactorGraph.h>
 | #include <gtsam/discrete/DiscreteBayesTree.h> | ||||||
| //
 | #include <gtsam/discrete/DiscreteFactorGraph.h> | ||||||
| //#include <boost/assign/std/vector.hpp>
 | #include <gtsam/inference/BayesNet-inst.h> | ||||||
| //using namespace boost::assign;
 | 
 | ||||||
| //
 | #include <boost/assign/std/vector.hpp> | ||||||
|  | using namespace boost::assign; | ||||||
|  | 
 | ||||||
| #include <CppUnitLite/TestHarness.h> | #include <CppUnitLite/TestHarness.h> | ||||||
| //
 | 
 | ||||||
| //using namespace std;
 | #include <vector> | ||||||
| //using namespace gtsam;
 | 
 | ||||||
| //
 | using namespace std; | ||||||
| //static bool debug = false;
 | using namespace gtsam; | ||||||
| //
 | 
 | ||||||
| ///**
 | static bool debug = false; | ||||||
| // * Custom clique class to debug shortcuts
 | 
 | ||||||
| // */
 | /* ************************************************************************* */ | ||||||
| ////class Clique: public BayesTreeCliqueBaseOrdered<Clique, DiscreteConditional> {
 | 
 | ||||||
| ////
 | TEST_UNSAFE(DiscreteBayesTree, ThinTree) { | ||||||
| ////protected:
 |   const int nrNodes = 15; | ||||||
| ////
 |   const size_t nrStates = 2; | ||||||
| ////public:
 | 
 | ||||||
| ////
 |   // define variables
 | ||||||
| ////  typedef BayesTreeCliqueBaseOrdered<Clique, DiscreteConditional> Base;
 |   vector<DiscreteKey> key; | ||||||
| ////  typedef boost::shared_ptr<Clique> shared_ptr;
 |   for (int i = 0; i < nrNodes; i++) { | ||||||
| ////
 |     DiscreteKey key_i(i, nrStates); | ||||||
| ////  // Constructors
 |     key.push_back(key_i); | ||||||
| ////  Clique() {
 |   } | ||||||
| ////  }
 | 
 | ||||||
| ////  Clique(const DiscreteConditional::shared_ptr& conditional) :
 |   // create a thin-tree Bayesnet, a la Jean-Guillaume
 | ||||||
| ////      Base(conditional) {
 |   DiscreteBayesNet bayesNet; | ||||||
| ////  }
 |   bayesNet.add(key[14] % "1/3"); | ||||||
| ////  Clique(
 | 
 | ||||||
| ////      const std::pair<DiscreteConditional::shared_ptr,
 |   bayesNet.add(key[13] | key[14] = "1/3 3/1"); | ||||||
| ////          DiscreteConditional::FactorType::shared_ptr>& result) :
 |   bayesNet.add(key[12] | key[14] = "3/1 3/1"); | ||||||
| ////      Base(result) {
 | 
 | ||||||
| ////  }
 |   bayesNet.add((key[11] | key[13], key[14]) = "1/4 2/3 3/2 4/1"); | ||||||
| ////
 |   bayesNet.add((key[10] | key[13], key[14]) = "1/4 3/2 2/3 4/1"); | ||||||
| ////  /// print index signature only
 |   bayesNet.add((key[9] | key[12], key[14]) = "4/1 2/3 F 1/4"); | ||||||
| ////  void printSignature(const std::string& s = "Clique: ",
 |   bayesNet.add((key[8] | key[12], key[14]) = "T 1/4 3/2 4/1"); | ||||||
| ////      const KeyFormatter& indexFormatter = DefaultKeyFormatter) const {
 | 
 | ||||||
| ////    ((IndexConditionalOrdered::shared_ptr) conditional_)->print(s, indexFormatter);
 |   bayesNet.add((key[7] | key[11], key[13]) = "1/4 2/3 3/2 4/1"); | ||||||
| ////  }
 |   bayesNet.add((key[6] | key[11], key[13]) = "1/4 3/2 2/3 4/1"); | ||||||
| ////
 |   bayesNet.add((key[5] | key[10], key[13]) = "4/1 2/3 3/2 1/4"); | ||||||
| ////  /// evaluate value of sub-tree
 |   bayesNet.add((key[4] | key[10], key[13]) = "2/3 1/4 3/2 4/1"); | ||||||
| ////  double evaluate(const DiscreteConditional::Values & values) {
 | 
 | ||||||
| ////    double result = (*(this->conditional_))(values);
 |   bayesNet.add((key[3] | key[9], key[12]) = "1/4 2/3 3/2 4/1"); | ||||||
| ////    // evaluate all children and multiply into result
 |   bayesNet.add((key[2] | key[9], key[12]) = "1/4 8/2 2/3 4/1"); | ||||||
| ////    for(boost::shared_ptr<Clique> c: children_)
 |   bayesNet.add((key[1] | key[8], key[12]) = "4/1 2/3 3/2 1/4"); | ||||||
| ////      result *= c->evaluate(values);
 |   bayesNet.add((key[0] | key[8], key[12]) = "2/3 1/4 3/2 4/1"); | ||||||
| ////    return result;
 | 
 | ||||||
| ////  }
 |   if (debug) { | ||||||
| ////
 |     GTSAM_PRINT(bayesNet); | ||||||
| ////};
 |     bayesNet.saveGraph("/tmp/discreteBayesNet.dot"); | ||||||
| //
 |   } | ||||||
| ////typedef BayesTreeOrdered<DiscreteConditional, Clique> DiscreteBayesTree;
 | 
 | ||||||
| ////
 |   // create a BayesTree out of a Bayes net
 | ||||||
| /////* ************************************************************************* */
 |   auto bayesTree = DiscreteFactorGraph(bayesNet).eliminateMultifrontal(); | ||||||
| ////double evaluate(const DiscreteBayesTree& tree,
 |   if (debug) { | ||||||
| ////    const DiscreteConditional::Values & values) {
 |     GTSAM_PRINT(*bayesTree); | ||||||
| ////  return tree.root()->evaluate(values);
 |     bayesTree->saveGraph("/tmp/discreteBayesTree.dot"); | ||||||
| ////}
 |   } | ||||||
| //
 | 
 | ||||||
| ///* ************************************************************************* */
 |   auto R = bayesTree->roots().front(); | ||||||
| //
 | 
 | ||||||
| //TEST_UNSAFE( DiscreteBayesTree, thinTree ) {
 |   // Check whether BN and BT give the same answer on all configurations
 | ||||||
| //
 |   vector<DiscreteFactor::Values> allPosbValues = cartesianProduct( | ||||||
| //  const int nrNodes = 15;
 |       key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7] & | ||||||
| //  const size_t nrStates = 2;
 |       key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]); | ||||||
| //
 |   for (size_t i = 0; i < allPosbValues.size(); ++i) { | ||||||
| //  // define variables
 |     DiscreteFactor::Values x = allPosbValues[i]; | ||||||
| //  vector<DiscreteKey> key;
 |     double expected = bayesNet.evaluate(x); | ||||||
| //  for (int i = 0; i < nrNodes; i++) {
 |     double actual = bayesTree->evaluate(x); | ||||||
| //    DiscreteKey key_i(i, nrStates);
 |     DOUBLES_EQUAL(expected, actual, 1e-9); | ||||||
| //    key.push_back(key_i);
 |   } | ||||||
| //  }
 | 
 | ||||||
| //
 |   // Calculate all some marginals for Values==all1
 | ||||||
| //  // create a thin-tree Bayesnet, a la Jean-Guillaume
 |   Vector marginals = Vector::Zero(15); | ||||||
| //  DiscreteBayesNet bayesNet;
 |   double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0, | ||||||
| //  bayesNet.add(key[14] % "1/3");
 |          joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0, | ||||||
| //
 |          joint_4_11 = 0, joint_11_13 = 0, joint_11_13_14 = 0, | ||||||
| //  bayesNet.add(key[13] | key[14] = "1/3 3/1");
 |          joint_11_12_13_14 = 0, joint_9_11_12_13 = 0, joint_8_11_12_13 = 0; | ||||||
| //  bayesNet.add(key[12] | key[14] = "3/1 3/1");
 |   for (size_t i = 0; i < allPosbValues.size(); ++i) { | ||||||
| //
 |     DiscreteFactor::Values x = allPosbValues[i]; | ||||||
| //  bayesNet.add((key[11] | key[13], key[14]) = "1/4 2/3 3/2 4/1");
 |     double px = bayesTree->evaluate(x); | ||||||
| //  bayesNet.add((key[10] | key[13], key[14]) = "1/4 3/2 2/3 4/1");
 |     for (size_t i = 0; i < 15; i++) | ||||||
| //  bayesNet.add((key[9] | key[12], key[14]) = "4/1 2/3 F 1/4");
 |       if (x[i]) marginals[i] += px; | ||||||
| //  bayesNet.add((key[8] | key[12], key[14]) = "T 1/4 3/2 4/1");
 |     if (x[12] && x[14]) joint_12_14 += px; | ||||||
| //
 |     if (x[9] && x[12] && x[14]) joint_9_12_14 += px; | ||||||
| //  bayesNet.add((key[7] | key[11], key[13]) = "1/4 2/3 3/2 4/1");
 |     if (x[8] && x[12] && x[14]) joint_8_12_14 += px; | ||||||
| //  bayesNet.add((key[6] | key[11], key[13]) = "1/4 3/2 2/3 4/1");
 |     if (x[8] && x[12]) joint_8_12 += px; | ||||||
| //  bayesNet.add((key[5] | key[10], key[13]) = "4/1 2/3 3/2 1/4");
 |     if (x[8] && x[2]) joint82 += px; | ||||||
| //  bayesNet.add((key[4] | key[10], key[13]) = "2/3 1/4 3/2 4/1");
 |     if (x[1] && x[2]) joint12 += px; | ||||||
| //
 |     if (x[2] && x[4]) joint24 += px; | ||||||
| //  bayesNet.add((key[3] | key[9], key[12]) = "1/4 2/3 3/2 4/1");
 |     if (x[4] && x[5]) joint45 += px; | ||||||
| //  bayesNet.add((key[2] | key[9], key[12]) = "1/4 8/2 2/3 4/1");
 |     if (x[4] && x[6]) joint46 += px; | ||||||
| //  bayesNet.add((key[1] | key[8], key[12]) = "4/1 2/3 3/2 1/4");
 |     if (x[4] && x[11]) joint_4_11 += px; | ||||||
| //  bayesNet.add((key[0] | key[8], key[12]) = "2/3 1/4 3/2 4/1");
 |     if (x[11] && x[13]) { | ||||||
| //
 |       joint_11_13 += px; | ||||||
| ////  if (debug) {
 |       if (x[8] && x[12]) joint_8_11_12_13 += px; | ||||||
| ////    GTSAM_PRINT(bayesNet);
 |       if (x[9] && x[12]) joint_9_11_12_13 += px; | ||||||
| ////    bayesNet.saveGraph("/tmp/discreteBayesNet.dot");
 |       if (x[14]) { | ||||||
| ////  }
 |         joint_11_13_14 += px; | ||||||
| //
 |         if (x[12]) { | ||||||
| //  // create a BayesTree out of a Bayes net
 |           joint_11_12_13_14 += px; | ||||||
| //  DiscreteBayesTree bayesTree(bayesNet);
 |         } | ||||||
| //  if (debug) {
 |       } | ||||||
| //    GTSAM_PRINT(bayesTree);
 |     } | ||||||
| //    bayesTree.saveGraph("/tmp/discreteBayesTree.dot");
 |   } | ||||||
| //  }
 |   DiscreteFactor::Values all1 = allPosbValues.back(); | ||||||
| //
 | 
 | ||||||
| //  // Check whether BN and BT give the same answer on all configurations
 |   // check separator marginal P(S0)
 | ||||||
| //  // Also calculate all some marginals
 |   auto c = (*bayesTree)[0]; | ||||||
| //  Vector marginals = zero(15);
 |   DiscreteFactorGraph separatorMarginal0 = | ||||||
| //  double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0,
 |       c->separatorMarginal(EliminateDiscrete); | ||||||
| //      joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0,
 |   DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9); | ||||||
| //      joint_4_11 = 0;
 | 
 | ||||||
| //  vector<DiscreteFactor::Values> allPosbValues = cartesianProduct(
 |   // check separator marginal P(S9), should be P(14)
 | ||||||
| //      key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7]
 |   c = (*bayesTree)[9]; | ||||||
| //          & key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]);
 |   DiscreteFactorGraph separatorMarginal9 = | ||||||
| //  for (size_t i = 0; i < allPosbValues.size(); ++i) {
 |       c->separatorMarginal(EliminateDiscrete); | ||||||
| //    DiscreteFactor::Values x = allPosbValues[i];
 |   DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9); | ||||||
| //    double expected = evaluate(bayesNet, x);
 | 
 | ||||||
| //    double actual = evaluate(bayesTree, x);
 |   // check separator marginal of root, should be empty
 | ||||||
| //    DOUBLES_EQUAL(expected, actual, 1e-9);
 |   c = (*bayesTree)[11]; | ||||||
| //    // collect marginals
 |   DiscreteFactorGraph separatorMarginal11 = | ||||||
| //    for (size_t i = 0; i < 15; i++)
 |       c->separatorMarginal(EliminateDiscrete); | ||||||
| //      if (x[i])
 |   LONGS_EQUAL(0, separatorMarginal11.size()); | ||||||
| //        marginals[i] += actual;
 | 
 | ||||||
| //    // calculate shortcut 8 and 0
 |   // check shortcut P(S9||R) to root
 | ||||||
| //    if (x[12] && x[14])
 |   c = (*bayesTree)[9]; | ||||||
| //      joint_12_14 += actual;
 |   DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete); | ||||||
| //    if (x[9] && x[12] & x[14])
 |   LONGS_EQUAL(1, shortcut.size()); | ||||||
| //      joint_9_12_14 += actual;
 |   DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9); | ||||||
| //    if (x[8] && x[12] & x[14])
 | 
 | ||||||
| //      joint_8_12_14 += actual;
 |   // check shortcut P(S8||R) to root
 | ||||||
| //    if (x[8] && x[12])
 |   c = (*bayesTree)[8]; | ||||||
| //      joint_8_12 += actual;
 |   shortcut = c->shortcut(R, EliminateDiscrete); | ||||||
| //    if (x[8] && x[2])
 |   DOUBLES_EQUAL(joint_11_12_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9); | ||||||
| //      joint82 += actual;
 | 
 | ||||||
| //    if (x[1] && x[2])
 |   // check shortcut P(S2||R) to root
 | ||||||
| //      joint12 += actual;
 |   c = (*bayesTree)[2]; | ||||||
| //    if (x[2] && x[4])
 |   shortcut = c->shortcut(R, EliminateDiscrete); | ||||||
| //      joint24 += actual;
 |   DOUBLES_EQUAL(joint_9_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9); | ||||||
| //    if (x[4] && x[5])
 | 
 | ||||||
| //      joint45 += actual;
 |   // check shortcut P(S0||R) to root
 | ||||||
| //    if (x[4] && x[6])
 |   c = (*bayesTree)[0]; | ||||||
| //      joint46 += actual;
 |   shortcut = c->shortcut(R, EliminateDiscrete); | ||||||
| //    if (x[4] && x[11])
 |   DOUBLES_EQUAL(joint_8_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9); | ||||||
| //      joint_4_11 += actual;
 | 
 | ||||||
| //  }
 |   // calculate all shortcuts to root
 | ||||||
| //  DiscreteFactor::Values all1 = allPosbValues.back();
 |   DiscreteBayesTree::Nodes cliques = bayesTree->nodes(); | ||||||
| //
 |   for (auto c : cliques) { | ||||||
| //  Clique::shared_ptr R = bayesTree.root();
 |     DiscreteBayesNet shortcut = c.second->shortcut(R, EliminateDiscrete); | ||||||
| //
 |     if (debug) { | ||||||
| //  // check separator marginal P(S0)
 |       c.second->conditional_->printSignature(); | ||||||
| //  Clique::shared_ptr c = bayesTree[0];
 |       shortcut.print("shortcut:"); | ||||||
| //  DiscreteFactorGraph separatorMarginal0 = c->separatorMarginal(R,
 |     } | ||||||
| //      EliminateDiscrete);
 |   } | ||||||
| //  EXPECT_DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
 | 
 | ||||||
| //
 |   // Check all marginals
 | ||||||
| //  // check separator marginal P(S9), should be P(14)
 |   DiscreteFactor::shared_ptr marginalFactor; | ||||||
| //  c = bayesTree[9];
 |   for (size_t i = 0; i < 15; i++) { | ||||||
| //  DiscreteFactorGraph separatorMarginal9 = c->separatorMarginal(R,
 |     marginalFactor = bayesTree->marginalFactor(i, EliminateDiscrete); | ||||||
| //      EliminateDiscrete);
 |     double actual = (*marginalFactor)(all1); | ||||||
| //  EXPECT_DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9);
 |     DOUBLES_EQUAL(marginals[i], actual, 1e-9); | ||||||
| //
 |   } | ||||||
| //  // check separator marginal of root, should be empty
 | 
 | ||||||
| //  c = bayesTree[11];
 |   DiscreteBayesNet::shared_ptr actualJoint; | ||||||
| //  DiscreteFactorGraph separatorMarginal11 = c->separatorMarginal(R,
 | 
 | ||||||
| //      EliminateDiscrete);
 |   // Check joint P(8, 2)
 | ||||||
| //  EXPECT_LONGS_EQUAL(0, separatorMarginal11.size());
 |   actualJoint = bayesTree->jointBayesNet(8, 2, EliminateDiscrete); | ||||||
| //
 |   DOUBLES_EQUAL(joint82, actualJoint->evaluate(all1), 1e-9); | ||||||
| //  // check shortcut P(S9||R) to root
 | 
 | ||||||
| //  c = bayesTree[9];
 |   // Check joint P(1, 2)
 | ||||||
| //  DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete);
 |   actualJoint = bayesTree->jointBayesNet(1, 2, EliminateDiscrete); | ||||||
| //  EXPECT_LONGS_EQUAL(0, shortcut.size());
 |   DOUBLES_EQUAL(joint12, actualJoint->evaluate(all1), 1e-9); | ||||||
| //
 | 
 | ||||||
| //  // check shortcut P(S8||R) to root
 |   // Check joint P(2, 4)
 | ||||||
| //  c = bayesTree[8];
 |   actualJoint = bayesTree->jointBayesNet(2, 4, EliminateDiscrete); | ||||||
| //  shortcut = c->shortcut(R, EliminateDiscrete);
 |   DOUBLES_EQUAL(joint24, actualJoint->evaluate(all1), 1e-9); | ||||||
| //  EXPECT_DOUBLES_EQUAL(joint_12_14/marginals[14], evaluate(shortcut,all1),
 | 
 | ||||||
| //      1e-9);
 |   // Check joint P(4, 5)
 | ||||||
| //
 |   actualJoint = bayesTree->jointBayesNet(4, 5, EliminateDiscrete); | ||||||
| //  // check shortcut P(S2||R) to root
 |   DOUBLES_EQUAL(joint45, actualJoint->evaluate(all1), 1e-9); | ||||||
| //  c = bayesTree[2];
 | 
 | ||||||
| //  shortcut = c->shortcut(R, EliminateDiscrete);
 |   // Check joint P(4, 6)
 | ||||||
| //  EXPECT_DOUBLES_EQUAL(joint_9_12_14/marginals[14], evaluate(shortcut,all1),
 |   actualJoint = bayesTree->jointBayesNet(4, 6, EliminateDiscrete); | ||||||
| //      1e-9);
 |   DOUBLES_EQUAL(joint46, actualJoint->evaluate(all1), 1e-9); | ||||||
| //
 | 
 | ||||||
| //  // check shortcut P(S0||R) to root
 |   // Check joint P(4, 11)
 | ||||||
| //  c = bayesTree[0];
 |   actualJoint = bayesTree->jointBayesNet(4, 11, EliminateDiscrete); | ||||||
| //  shortcut = c->shortcut(R, EliminateDiscrete);
 |   DOUBLES_EQUAL(joint_4_11, actualJoint->evaluate(all1), 1e-9); | ||||||
| //  EXPECT_DOUBLES_EQUAL(joint_8_12_14/marginals[14], evaluate(shortcut,all1),
 | } | ||||||
| //      1e-9);
 |  | ||||||
| //
 |  | ||||||
| //  // calculate all shortcuts to root
 |  | ||||||
| //  DiscreteBayesTree::Nodes cliques = bayesTree.nodes();
 |  | ||||||
| //  for(Clique::shared_ptr c: cliques) {
 |  | ||||||
| //    DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete);
 |  | ||||||
| //    if (debug) {
 |  | ||||||
| //      c->printSignature();
 |  | ||||||
| //      shortcut.print("shortcut:");
 |  | ||||||
| //    }
 |  | ||||||
| //  }
 |  | ||||||
| //
 |  | ||||||
| //  // Check all marginals
 |  | ||||||
| //  DiscreteFactor::shared_ptr marginalFactor;
 |  | ||||||
| //  for (size_t i = 0; i < 15; i++) {
 |  | ||||||
| //    marginalFactor = bayesTree.marginalFactor(i, EliminateDiscrete);
 |  | ||||||
| //    double actual = (*marginalFactor)(all1);
 |  | ||||||
| //    EXPECT_DOUBLES_EQUAL(marginals[i], actual, 1e-9);
 |  | ||||||
| //  }
 |  | ||||||
| //
 |  | ||||||
| //  DiscreteBayesNet::shared_ptr actualJoint;
 |  | ||||||
| //
 |  | ||||||
| //  // Check joint P(8,2) TODO: not disjoint !
 |  | ||||||
| ////  actualJoint = bayesTree.jointBayesNet(8, 2, EliminateDiscrete);
 |  | ||||||
| ////  EXPECT_DOUBLES_EQUAL(joint82, evaluate(*actualJoint,all1), 1e-9);
 |  | ||||||
| //
 |  | ||||||
| //  // Check joint P(1,2) TODO: not disjoint !
 |  | ||||||
| ////  actualJoint = bayesTree.jointBayesNet(1, 2, EliminateDiscrete);
 |  | ||||||
| ////  EXPECT_DOUBLES_EQUAL(joint12, evaluate(*actualJoint,all1), 1e-9);
 |  | ||||||
| //
 |  | ||||||
| //  // Check joint P(2,4)
 |  | ||||||
| //  actualJoint = bayesTree.jointBayesNet(2, 4, EliminateDiscrete);
 |  | ||||||
| //  EXPECT_DOUBLES_EQUAL(joint24, evaluate(*actualJoint,all1), 1e-9);
 |  | ||||||
| //
 |  | ||||||
| //  // Check joint P(4,5) TODO: not disjoint !
 |  | ||||||
| ////  actualJoint = bayesTree.jointBayesNet(4, 5, EliminateDiscrete);
 |  | ||||||
| ////  EXPECT_DOUBLES_EQUAL(joint46, evaluate(*actualJoint,all1), 1e-9);
 |  | ||||||
| //
 |  | ||||||
| //  // Check joint P(4,6) TODO: not disjoint !
 |  | ||||||
| ////  actualJoint = bayesTree.jointBayesNet(4, 6, EliminateDiscrete);
 |  | ||||||
| ////  EXPECT_DOUBLES_EQUAL(joint46, evaluate(*actualJoint,all1), 1e-9);
 |  | ||||||
| //
 |  | ||||||
| //  // Check joint P(4,11)
 |  | ||||||
| //  actualJoint = bayesTree.jointBayesNet(4, 11, EliminateDiscrete);
 |  | ||||||
| //  EXPECT_DOUBLES_EQUAL(joint_4_11, evaluate(*actualJoint,all1), 1e-9);
 |  | ||||||
| //
 |  | ||||||
| //}
 |  | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
| int main() { | int main() { | ||||||
|  | @ -263,4 +218,3 @@ int main() { | ||||||
|   return TestRegistry::runAllTests(tr); |   return TestRegistry::runAllTests(tr); | ||||||
| } | } | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
| 
 |  | ||||||
|  |  | ||||||
										
											Binary file not shown.
										
									
								
							|  | @ -19,6 +19,7 @@ | ||||||
| #include <gtsam/discrete/DiscreteFactorGraph.h> | #include <gtsam/discrete/DiscreteFactorGraph.h> | ||||||
| #include <gtsam/discrete/DiscreteEliminationTree.h> | #include <gtsam/discrete/DiscreteEliminationTree.h> | ||||||
| #include <gtsam/discrete/DiscreteBayesTree.h> | #include <gtsam/discrete/DiscreteBayesTree.h> | ||||||
|  | #include <gtsam/inference/BayesNet-inst.h> | ||||||
| 
 | 
 | ||||||
| #include <CppUnitLite/TestHarness.h> | #include <CppUnitLite/TestHarness.h> | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -136,57 +136,61 @@ namespace gtsam { | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   /* ************************************************************************* */ |   /* *********************************************************************** */ | ||||||
|   // separator marginal, uses separator marginal of parent recursively
 |   // separator marginal, uses separator marginal of parent recursively
 | ||||||
|   // P(C) = P(F|S) P(S)
 |   // P(C) = P(F|S) P(S)
 | ||||||
|   /* ************************************************************************* */ |   /* *********************************************************************** */ | ||||||
|   template<class DERIVED, class FACTORGRAPH> |   template <class DERIVED, class FACTORGRAPH> | ||||||
|   typename BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::FactorGraphType |   typename BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::FactorGraphType | ||||||
|     BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::separatorMarginal(Eliminate function) const |   BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::separatorMarginal( | ||||||
|   { |       Eliminate function) const { | ||||||
|     gttic(BayesTreeCliqueBase_separatorMarginal); |     gttic(BayesTreeCliqueBase_separatorMarginal); | ||||||
|     // Check if the Separator marginal was already calculated
 |     // Check if the Separator marginal was already calculated
 | ||||||
|     if (!cachedSeparatorMarginal_) |     if (!cachedSeparatorMarginal_) { | ||||||
|     { |  | ||||||
|       gttic(BayesTreeCliqueBase_separatorMarginal_cachemiss); |       gttic(BayesTreeCliqueBase_separatorMarginal_cachemiss); | ||||||
|  | 
 | ||||||
|       // If this is the root, there is no separator
 |       // If this is the root, there is no separator
 | ||||||
|       if (parent_.expired() /*(if we're the root)*/) |       if (parent_.expired() /*(if we're the root)*/) { | ||||||
|       { |  | ||||||
|         // we are root, return empty
 |         // we are root, return empty
 | ||||||
|         FactorGraphType empty; |         FactorGraphType empty; | ||||||
|         cachedSeparatorMarginal_ = empty; |         cachedSeparatorMarginal_ = empty; | ||||||
|       } |       } else { | ||||||
|       else |         // Flatten recursion in timing outline
 | ||||||
|       { |         gttoc(BayesTreeCliqueBase_separatorMarginal_cachemiss); | ||||||
|  |         gttoc(BayesTreeCliqueBase_separatorMarginal); | ||||||
|  | 
 | ||||||
|         // Obtain P(S) = \int P(Cp) = \int P(Fp|Sp) P(Sp)
 |         // Obtain P(S) = \int P(Cp) = \int P(Fp|Sp) P(Sp)
 | ||||||
|         // initialize P(Cp) with the parent separator marginal
 |         // initialize P(Cp) with the parent separator marginal
 | ||||||
|         derived_ptr parent(parent_.lock()); |         derived_ptr parent(parent_.lock()); | ||||||
|         gttoc(BayesTreeCliqueBase_separatorMarginal_cachemiss); // Flatten recursion in timing outline
 |         FactorGraphType p_Cp(parent->separatorMarginal(function));  // P(Sp)
 | ||||||
|         gttoc(BayesTreeCliqueBase_separatorMarginal); | 
 | ||||||
|         FactorGraphType p_Cp(parent->separatorMarginal(function)); // P(Sp)
 |  | ||||||
|         gttic(BayesTreeCliqueBase_separatorMarginal); |         gttic(BayesTreeCliqueBase_separatorMarginal); | ||||||
|         gttic(BayesTreeCliqueBase_separatorMarginal_cachemiss); |         gttic(BayesTreeCliqueBase_separatorMarginal_cachemiss); | ||||||
|  | 
 | ||||||
|         // now add the parent conditional
 |         // now add the parent conditional
 | ||||||
|         p_Cp += parent->conditional_; // P(Fp|Sp)
 |         p_Cp += parent->conditional_;  // P(Fp|Sp)
 | ||||||
| 
 | 
 | ||||||
|         // The variables we want to keepSet are exactly the ones in S
 |         // The variables we want to keepSet are exactly the ones in S
 | ||||||
|         KeyVector indicesS(this->conditional()->beginParents(), this->conditional()->endParents()); |         KeyVector indicesS(this->conditional()->beginParents(), | ||||||
|         cachedSeparatorMarginal_ = *p_Cp.marginalMultifrontalBayesNet(Ordering(indicesS), function); |                            this->conditional()->endParents()); | ||||||
|  |         auto separatorMarginal = | ||||||
|  |             p_Cp.marginalMultifrontalBayesNet(Ordering(indicesS), function); | ||||||
|  |         cachedSeparatorMarginal_.reset(*separatorMarginal); | ||||||
|       } |       } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     // return the shortcut P(S||B)
 |     // return the shortcut P(S||B)
 | ||||||
|     return *cachedSeparatorMarginal_; // return the cached version
 |     return *cachedSeparatorMarginal_;  // return the cached version
 | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   /* ************************************************************************* */ |   /* *********************************************************************** */ | ||||||
|   // marginal2, uses separator marginal of parent recursively
 |   // marginal2, uses separator marginal of parent
 | ||||||
|   // P(C) = P(F|S) P(S)
 |   // P(C) = P(F|S) P(S)
 | ||||||
|   /* ************************************************************************* */ |   /* *********************************************************************** */ | ||||||
|   template<class DERIVED, class FACTORGRAPH> |   template <class DERIVED, class FACTORGRAPH> | ||||||
|   typename BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::FactorGraphType |   typename BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::FactorGraphType | ||||||
|     BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::marginal2(Eliminate function) const |   BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::marginal2( | ||||||
|   { |       Eliminate function) const { | ||||||
|     gttic(BayesTreeCliqueBase_marginal2); |     gttic(BayesTreeCliqueBase_marginal2); | ||||||
|     // initialize with separator marginal P(S)
 |     // initialize with separator marginal P(S)
 | ||||||
|     FactorGraphType p_C = this->separatorMarginal(function); |     FactorGraphType p_C = this->separatorMarginal(function); | ||||||
|  |  | ||||||
|  | @ -65,6 +65,8 @@ namespace gtsam { | ||||||
|     Conditional(size_t nrFrontals) : nrFrontals_(nrFrontals) {} |     Conditional(size_t nrFrontals) : nrFrontals_(nrFrontals) {} | ||||||
| 
 | 
 | ||||||
|     /// @}
 |     /// @}
 | ||||||
|  | 
 | ||||||
|  |   public: | ||||||
|     /// @name Testable
 |     /// @name Testable
 | ||||||
|     /// @{
 |     /// @{
 | ||||||
| 
 | 
 | ||||||
|  | @ -76,7 +78,6 @@ namespace gtsam { | ||||||
| 
 | 
 | ||||||
|     /// @}
 |     /// @}
 | ||||||
| 
 | 
 | ||||||
|   public: |  | ||||||
|     /// @name Standard Interface
 |     /// @name Standard Interface
 | ||||||
|     /// @{
 |     /// @{
 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue