Removed debug code, added marginal function
							parent
							
								
									ec6611ae56
								
							
						
					
					
						commit
						4865edb883
					
				|  | @ -6,6 +6,7 @@ | ||||||
| 
 | 
 | ||||||
| #include <boost/foreach.hpp> | #include <boost/foreach.hpp> | ||||||
| #include "BayesTree.h" | #include "BayesTree.h" | ||||||
|  | #include "FactorGraph.h" | ||||||
| 
 | 
 | ||||||
| namespace gtsam { | namespace gtsam { | ||||||
| 
 | 
 | ||||||
|  | @ -27,7 +28,7 @@ namespace gtsam { | ||||||
| 			if (!separator_.empty()) { | 			if (!separator_.empty()) { | ||||||
| 				cout << " :"; | 				cout << " :"; | ||||||
| 				BOOST_FOREACH(string key, separator_) | 				BOOST_FOREACH(string key, separator_) | ||||||
| 				cout << " " << key; | 					cout << " " << key; | ||||||
| 			} | 			} | ||||||
| 			cout << endl; | 			cout << endl; | ||||||
| 		} | 		} | ||||||
|  | @ -48,10 +49,10 @@ namespace gtsam { | ||||||
| 	/* ************************************************************************* */ | 	/* ************************************************************************* */ | ||||||
| 	// TODO: traversal is O(n*log(n)) but could be O(n) with better bayesNet
 | 	// TODO: traversal is O(n*log(n)) but could be O(n) with better bayesNet
 | ||||||
| 	template<class Conditional> | 	template<class Conditional> | ||||||
| 	BayesTree<Conditional>::BayesTree(const BayesNet<Conditional>& bayesNet, bool verbose) { | 	BayesTree<Conditional>::BayesTree(const BayesNet<Conditional>& bayesNet) { | ||||||
| 		typename BayesNet<Conditional>::const_reverse_iterator rit; | 		typename BayesNet<Conditional>::const_reverse_iterator rit; | ||||||
| 		for ( rit=bayesNet.rbegin(); rit != bayesNet.rend(); ++rit ) | 		for ( rit=bayesNet.rbegin(); rit != bayesNet.rend(); ++rit ) | ||||||
| 			insert(*rit,verbose); | 			insert(*rit); | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	/* ************************************************************************* */ | 	/* ************************************************************************* */ | ||||||
|  | @ -73,22 +74,29 @@ namespace gtsam { | ||||||
| 
 | 
 | ||||||
| 	/* ************************************************************************* */ | 	/* ************************************************************************* */ | ||||||
| 	template<class Conditional> | 	template<class Conditional> | ||||||
| 	void BayesTree<Conditional>::insert(const boost::shared_ptr<Conditional>& conditional, bool verbose) { | 	void BayesTree<Conditional>::addClique | ||||||
|  | 	(const boost::shared_ptr<Conditional>& conditional, node_ptr parent_clique) | ||||||
|  | 	{ | ||||||
|  | 		node_ptr new_clique(new Node(conditional)); | ||||||
|  | 		nodeMap_.insert(make_pair(conditional->key(), nodes_.size())); | ||||||
|  | 		nodes_.push_back(new_clique); | ||||||
|  | 		if (parent_clique==NULL) return; | ||||||
|  | 		new_clique->parent_ = parent_clique; | ||||||
|  | 		parent_clique->children_.push_back(new_clique); | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 		string key =  conditional->key(); | 	/* ************************************************************************* */ | ||||||
| 		if (verbose) cout << "Inserting " << key << "| "; | 	template<class Conditional> | ||||||
| 
 | 	void BayesTree<Conditional>::insert | ||||||
| 		// get parents
 | 	(const boost::shared_ptr<Conditional>& conditional) | ||||||
|  | 	{ | ||||||
|  | 		// get key and parents
 | ||||||
|  | 		string key = conditional->key(); | ||||||
| 		list<string> parents = conditional->parents(); | 		list<string> parents = conditional->parents(); | ||||||
| 		if (verbose) BOOST_FOREACH(string p, parents) cout << p << " "; |  | ||||||
| 		if (verbose) cout << endl; |  | ||||||
| 
 | 
 | ||||||
| 		// if no parents, start a new root clique
 | 		// if no parents, start a new root clique
 | ||||||
| 		if (parents.empty()) { | 		if (parents.empty()) { | ||||||
| 			if (verbose) cout << "Creating root clique" << endl; | 			addClique(conditional); | ||||||
| 			node_ptr root(new Node(conditional)); |  | ||||||
| 			nodes_.push_back(root); |  | ||||||
| 			nodeMap_.insert(make_pair(key, 0)); |  | ||||||
| 			return; | 			return; | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | @ -96,26 +104,35 @@ namespace gtsam { | ||||||
| 		string parent = parents.front(); | 		string parent = parents.front(); | ||||||
| 		NodeMap::const_iterator it = nodeMap_.find(parent); | 		NodeMap::const_iterator it = nodeMap_.find(parent); | ||||||
| 		if (it == nodeMap_.end()) throw(invalid_argument( | 		if (it == nodeMap_.end()) throw(invalid_argument( | ||||||
| 						"BayesTree::insert('"+key+"'): parent '" + parent + "' was not yet inserted")); | 				"BayesTree::insert('"+key+"'): parent '" + parent + "' not yet inserted")); | ||||||
| 		int index = it->second; | 		int parent_index = it->second; | ||||||
| 		node_ptr parent_clique = nodes_[index]; | 		node_ptr parent_clique = nodes_[parent_index]; | ||||||
| 		if (verbose) cout << "Parent clique " << index << " of size " << parent_clique->size() << endl; |  | ||||||
| 
 | 
 | ||||||
| 		// if the parents and parent clique have the same size, add to parent clique
 | 		// if the parents and parent clique have the same size, add to parent clique
 | ||||||
| 		if (parent_clique->size() == parents.size()) { | 		if (parent_clique->size() == parents.size()) { | ||||||
| 			if (verbose) cout << "Adding to clique " << index << endl; | 			nodeMap_.insert(make_pair(key, parent_index)); | ||||||
| 			nodeMap_.insert(make_pair(key, index)); |  | ||||||
| 			parent_clique->push_front(conditional); | 			parent_clique->push_front(conditional); | ||||||
| 			return; | 			return; | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		// otherwise, start a new clique and add it to the tree
 | 		// otherwise, start a new clique and add it to the tree
 | ||||||
| 		if (verbose) cout << "Starting new clique" << endl; | 		addClique(conditional,parent_clique); | ||||||
| 		node_ptr new_clique(new Node(conditional)); | 	} | ||||||
| 		new_clique->parent_ = parent_clique; | 
 | ||||||
| 		parent_clique->children_.push_back(new_clique); | 	/* ************************************************************************* */ | ||||||
| 		nodeMap_.insert(make_pair(key, nodes_.size())); | 	template<class Conditional> | ||||||
| 		nodes_.push_back(new_clique); | 	boost::shared_ptr<Conditional> BayesTree<Conditional>::marginal(const string& key) const { | ||||||
|  | 
 | ||||||
|  | 		// find the clique to which key belongs
 | ||||||
|  | 		NodeMap::const_iterator it = nodeMap_.find(key); | ||||||
|  | 		if (it == nodeMap_.end()) throw(invalid_argument( | ||||||
|  | 						"BayesTree::marginal('"+key+"'): key not found")); | ||||||
|  | 
 | ||||||
|  | 		// find all cliques on the path to the root
 | ||||||
|  | 		// FactorGraph
 | ||||||
|  | 
 | ||||||
|  | 		boost::shared_ptr<Conditional> result(new Conditional); | ||||||
|  | 		return result; | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	/* ************************************************************************* */ | 	/* ************************************************************************* */ | ||||||
|  |  | ||||||
|  | @ -30,7 +30,6 @@ namespace gtsam { | ||||||
| 	public: | 	public: | ||||||
| 
 | 
 | ||||||
| 		typedef boost::shared_ptr<Conditional> conditional_ptr; | 		typedef boost::shared_ptr<Conditional> conditional_ptr; | ||||||
| 		typedef std::pair<std::string,conditional_ptr> NamedConditional; |  | ||||||
| 
 | 
 | ||||||
| 	private: | 	private: | ||||||
| 
 | 
 | ||||||
|  | @ -66,13 +65,16 @@ namespace gtsam { | ||||||
| 		typedef std::map<std::string, int> NodeMap; | 		typedef std::map<std::string, int> NodeMap; | ||||||
| 		NodeMap nodeMap_; | 		NodeMap nodeMap_; | ||||||
| 
 | 
 | ||||||
|  | 		/** add a clique */ | ||||||
|  | 		void addClique(const conditional_ptr& conditional, node_ptr parent_clique=node_ptr()); | ||||||
|  | 
 | ||||||
| 	public: | 	public: | ||||||
| 
 | 
 | ||||||
| 		/** Create an empty Bayes Tree */ | 		/** Create an empty Bayes Tree */ | ||||||
| 		BayesTree(); | 		BayesTree(); | ||||||
| 
 | 
 | ||||||
| 		/** Create a Bayes Tree from a Bayes Net */ | 		/** Create a Bayes Tree from a Bayes Net */ | ||||||
| 		BayesTree(const BayesNet<Conditional>& bayesNet, bool verbose=false); | 		BayesTree(const BayesNet<Conditional>& bayesNet); | ||||||
| 
 | 
 | ||||||
| 		/** Destructor */ | 		/** Destructor */ | ||||||
| 		virtual ~BayesTree() {} | 		virtual ~BayesTree() {} | ||||||
|  | @ -84,7 +86,7 @@ namespace gtsam { | ||||||
| 		bool equals(const BayesTree<Conditional>& other, double tol = 1e-9) const; | 		bool equals(const BayesTree<Conditional>& other, double tol = 1e-9) const; | ||||||
| 
 | 
 | ||||||
| 		/** insert a new conditional */ | 		/** insert a new conditional */ | ||||||
| 		void insert(const boost::shared_ptr<Conditional>& conditional, bool verbose=false); | 		void insert(const boost::shared_ptr<Conditional>& conditional); | ||||||
| 
 | 
 | ||||||
| 			/** number of cliques */ | 			/** number of cliques */ | ||||||
| 		inline size_t size() const { return nodes_.size();} | 		inline size_t size() const { return nodes_.size();} | ||||||
|  | @ -92,6 +94,9 @@ namespace gtsam { | ||||||
| 		/** return root clique */ | 		/** return root clique */ | ||||||
| 		const BayesNet<Conditional>& root() const {return *(nodes_[0]);} | 		const BayesNet<Conditional>& root() const {return *(nodes_[0]);} | ||||||
| 
 | 
 | ||||||
|  | 		/** return marginal on any variable */ | ||||||
|  | 		boost::shared_ptr<Conditional> marginal(const std::string& key) const; | ||||||
|  | 
 | ||||||
| 	}; // BayesTree
 | 	}; // BayesTree
 | ||||||
| 
 | 
 | ||||||
| } /// namespace gtsam
 | } /// namespace gtsam
 | ||||||
|  |  | ||||||
|  | @ -27,10 +27,10 @@ SymbolicConditional::shared_ptr B(new SymbolicConditional("B")), L( | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
| TEST( BayesTree, Front ) | TEST( BayesTree, Front ) | ||||||
| { | { | ||||||
| 	BayesNet<SymbolicConditional> f1; | 	SymbolicBayesNet f1; | ||||||
| 	f1.push_back(B); | 	f1.push_back(B); | ||||||
| 	f1.push_back(L); | 	f1.push_back(L); | ||||||
| 	BayesNet<SymbolicConditional> f2; | 	SymbolicBayesNet f2; | ||||||
| 	f2.push_back(L); | 	f2.push_back(L); | ||||||
| 	f2.push_back(B); | 	f2.push_back(B); | ||||||
| 	CHECK(f1.equals(f1)); | 	CHECK(f1.equals(f1)); | ||||||
|  | @ -68,9 +68,8 @@ TEST( BayesTree, constructor ) | ||||||
| 	ASIA.push_back(E); | 	ASIA.push_back(E); | ||||||
| 	ASIA.push_back(L); | 	ASIA.push_back(L); | ||||||
| 	ASIA.push_back(B); | 	ASIA.push_back(B); | ||||||
| 	bool verbose = false; | 	BayesTree<SymbolicConditional> bayesTree2(ASIA); | ||||||
| 	BayesTree<SymbolicConditional> bayesTree2(ASIA,verbose); | 	//bayesTree2.print("bayesTree2");
 | ||||||
| 	if (verbose) bayesTree2.print("bayesTree2"); |  | ||||||
| 
 | 
 | ||||||
| 	// Check whether the same
 | 	// Check whether the same
 | ||||||
| 	CHECK(assert_equal(bayesTree,bayesTree2)); | 	CHECK(assert_equal(bayesTree,bayesTree2)); | ||||||
|  | @ -97,7 +96,7 @@ TEST( BayesTree, smoother ) | ||||||
| 	GaussianBayesNet::shared_ptr chordalBayesNet = smoother.eliminate(ordering); | 	GaussianBayesNet::shared_ptr chordalBayesNet = smoother.eliminate(ordering); | ||||||
| 
 | 
 | ||||||
| 	// Create the Bayes tree
 | 	// Create the Bayes tree
 | ||||||
| 	BayesTree<ConditionalGaussian> bayesTree(*chordalBayesNet,false); | 	BayesTree<ConditionalGaussian> bayesTree(*chordalBayesNet); | ||||||
| 	LONGS_EQUAL(6,bayesTree.size()); | 	LONGS_EQUAL(6,bayesTree.size()); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -108,7 +107,7 @@ TEST( BayesTree, smoother ) | ||||||
|      x1 : x2 |      x1 : x2 | ||||||
|    x7 : x6 |    x7 : x6 | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
| TEST( BayesTree, balanced_smoother ) | TEST( BayesTree, balanced_smoother_marginals ) | ||||||
| { | { | ||||||
| 	// Create smoother with 7 nodes
 | 	// Create smoother with 7 nodes
 | ||||||
| 	LinearFactorGraph smoother = createSmoother(7); | 	LinearFactorGraph smoother = createSmoother(7); | ||||||
|  | @ -119,8 +118,21 @@ TEST( BayesTree, balanced_smoother ) | ||||||
| 	GaussianBayesNet::shared_ptr chordalBayesNet = smoother.eliminate(ordering); | 	GaussianBayesNet::shared_ptr chordalBayesNet = smoother.eliminate(ordering); | ||||||
| 
 | 
 | ||||||
| 	// Create the Bayes tree
 | 	// Create the Bayes tree
 | ||||||
| 	BayesTree<ConditionalGaussian> bayesTree(*chordalBayesNet,false); | 	BayesTree<ConditionalGaussian> bayesTree(*chordalBayesNet); | ||||||
| 	LONGS_EQUAL(4,bayesTree.size()); | 	LONGS_EQUAL(4,bayesTree.size()); | ||||||
|  | 
 | ||||||
|  | 	// Check root clique
 | ||||||
|  | 	//BayesNet<ConditionalGaussian> expected_root;
 | ||||||
|  | 	//BayesNet<ConditionalGaussian> actual_root = bayesTree.root();
 | ||||||
|  | 	//CHECK(assert_equal(expected_root,actual_root));
 | ||||||
|  | 
 | ||||||
|  | 	// Check marginal on x1
 | ||||||
|  | 	ConditionalGaussian expected; | ||||||
|  | 	ConditionalGaussian::shared_ptr actual = bayesTree.marginal("x1"); | ||||||
|  | 	CHECK(assert_equal(expected,*actual)); | ||||||
|  | 
 | ||||||
|  | 	// JunctionTree is an undirected tree of cliques
 | ||||||
|  | 	// JunctionTree<ConditionalGaussian> marginals = bayesTree.marginals();
 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue