| 
									
										
										
										
											2009-10-30 21:03:38 +08:00
										 |  |  | /**
 | 
					
						
							|  |  |  |  * @file    BayesTree.cpp | 
					
						
							|  |  |  |  * @brief   Bayes Tree is a tree of cliques of a Bayes Chain | 
					
						
							|  |  |  |  * @author  Frank Dellaert | 
					
						
							|  |  |  |  */ | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2009-11-01 00:57:36 +08:00
										 |  |  | #include <boost/foreach.hpp>
 | 
					
						
							| 
									
										
										
										
											2009-10-30 21:03:38 +08:00
										 |  |  | #include "BayesTree.h"
 | 
					
						
							| 
									
										
										
										
											2009-11-05 14:30:50 +08:00
										 |  |  | #include "FactorGraph-inl.h"
 | 
					
						
							| 
									
										
										
										
											2009-10-30 21:03:38 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | namespace gtsam { | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2009-10-31 13:12:39 +08:00
										 |  |  | 	using namespace std; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2009-11-01 00:57:36 +08:00
										 |  |  | 	/* ************************************************************************* */ | 
					
						
							|  |  |  | 	template<class Conditional> | 
					
						
							| 
									
										
										
										
											2009-11-02 13:17:44 +08:00
										 |  |  | 	BayesTree<Conditional>::Node::Node(const boost::shared_ptr<Conditional>& conditional) { | 
					
						
							|  |  |  | 			separator_ = conditional->parents(); | 
					
						
							|  |  |  | 			this->push_back(conditional); | 
					
						
							|  |  |  | 		} | 
					
						
							| 
									
										
										
										
											2009-11-01 00:57:36 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	/* ************************************************************************* */ | 
					
						
							|  |  |  | 	template<class Conditional> | 
					
						
							| 
									
										
										
										
											2009-11-02 13:17:44 +08:00
										 |  |  | 	void BayesTree<Conditional>::Node::print(const string& s) const { | 
					
						
							|  |  |  | 			cout << s; | 
					
						
							|  |  |  | 			BOOST_REVERSE_FOREACH(const conditional_ptr& conditional, this->conditionals_) | 
					
						
							|  |  |  | 				cout << " " << conditional->key(); | 
					
						
							|  |  |  | 			if (!separator_.empty()) { | 
					
						
							|  |  |  | 				cout << " :"; | 
					
						
							|  |  |  | 				BOOST_FOREACH(string key, separator_) | 
					
						
							| 
									
										
										
										
											2009-11-04 11:22:29 +08:00
										 |  |  | 					cout << " " << key; | 
					
						
							| 
									
										
										
										
											2009-11-02 13:17:44 +08:00
										 |  |  | 			} | 
					
						
							|  |  |  | 			cout << endl; | 
					
						
							| 
									
										
										
										
											2009-11-01 00:57:36 +08:00
										 |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	/* ************************************************************************* */ | 
					
						
							|  |  |  | 	template<class Conditional> | 
					
						
							| 
									
										
										
										
											2009-11-02 13:17:44 +08:00
										 |  |  | 	void BayesTree<Conditional>::Node::printTree(const string& indent) const { | 
					
						
							|  |  |  | 			print(indent); | 
					
						
							|  |  |  | 			BOOST_FOREACH(shared_ptr child, children_) | 
					
						
							|  |  |  | 				child->printTree(indent+"  "); | 
					
						
							|  |  |  | 		} | 
					
						
							| 
									
										
										
										
											2009-11-01 00:57:36 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2009-10-31 13:12:39 +08:00
										 |  |  | 	/* ************************************************************************* */ | 
					
						
							|  |  |  | 	template<class Conditional> | 
					
						
							|  |  |  | 	BayesTree<Conditional>::BayesTree() { | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	/* ************************************************************************* */ | 
					
						
							| 
									
										
										
										
											2009-11-02 11:50:30 +08:00
										 |  |  | 	// TODO: traversal is O(n*log(n)) but could be O(n) with better bayesNet
 | 
					
						
							| 
									
										
										
										
											2009-10-30 21:03:38 +08:00
										 |  |  | 	template<class Conditional> | 
					
						
							| 
									
										
										
										
											2009-11-04 11:22:29 +08:00
										 |  |  | 	BayesTree<Conditional>::BayesTree(const BayesNet<Conditional>& bayesNet) { | 
					
						
							| 
									
										
										
										
											2009-11-02 11:50:30 +08:00
										 |  |  | 		typename BayesNet<Conditional>::const_reverse_iterator rit; | 
					
						
							| 
									
										
										
										
											2009-11-03 14:29:56 +08:00
										 |  |  | 		for ( rit=bayesNet.rbegin(); rit != bayesNet.rend(); ++rit ) | 
					
						
							| 
									
										
										
										
											2009-11-04 11:22:29 +08:00
										 |  |  | 			insert(*rit); | 
					
						
							| 
									
										
										
										
											2009-10-30 21:03:38 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2009-10-31 13:12:39 +08:00
										 |  |  | 	/* ************************************************************************* */ | 
					
						
							| 
									
										
										
										
											2009-10-30 21:03:38 +08:00
										 |  |  | 	template<class Conditional> | 
					
						
							| 
									
										
										
										
											2009-11-01 00:57:36 +08:00
										 |  |  | 	void BayesTree<Conditional>::print(const string& s) const { | 
					
						
							| 
									
										
										
										
											2009-10-31 13:12:39 +08:00
										 |  |  | 		cout << s << ": size == " << nodes_.size() << endl; | 
					
						
							|  |  |  | 		if (nodes_.empty()) return; | 
					
						
							| 
									
										
										
										
											2009-11-05 13:29:47 +08:00
										 |  |  | 		root_->printTree(""); | 
					
						
							| 
									
										
										
										
											2009-10-30 21:03:38 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2009-10-31 13:12:39 +08:00
										 |  |  | 	/* ************************************************************************* */ | 
					
						
							| 
									
										
										
										
											2009-10-30 21:03:38 +08:00
										 |  |  | 	template<class Conditional> | 
					
						
							|  |  |  | 	bool BayesTree<Conditional>::equals(const BayesTree<Conditional>& other, | 
					
						
							|  |  |  | 			double tol) const { | 
					
						
							| 
									
										
										
										
											2009-11-05 13:29:47 +08:00
										 |  |  | 		return size()==other.size(); | 
					
						
							|  |  |  | 		//&& equal(nodes_.begin(),nodes_.end(),other.nodes_.begin(),equals_star<Node>(tol));
 | 
					
						
							| 
									
										
										
										
											2009-10-30 21:03:38 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2009-10-31 13:12:39 +08:00
										 |  |  | 	/* ************************************************************************* */ | 
					
						
							|  |  |  | 	template<class Conditional> | 
					
						
							| 
									
										
										
										
											2009-11-05 13:29:47 +08:00
										 |  |  | 	boost::shared_ptr<typename BayesTree<Conditional>::Node> BayesTree<Conditional>::addClique | 
					
						
							| 
									
										
										
										
											2009-11-04 11:22:29 +08:00
										 |  |  | 	(const boost::shared_ptr<Conditional>& conditional, node_ptr parent_clique) | 
					
						
							|  |  |  | 	{ | 
					
						
							|  |  |  | 		node_ptr new_clique(new Node(conditional)); | 
					
						
							| 
									
										
										
										
											2009-11-05 13:29:47 +08:00
										 |  |  | 		nodes_.insert(make_pair(conditional->key(), new_clique)); | 
					
						
							|  |  |  | 	if (parent_clique!=NULL) { | 
					
						
							|  |  |  | 			new_clique->parent_ = parent_clique; | 
					
						
							|  |  |  | 			parent_clique->children_.push_back(new_clique); | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		return new_clique; | 
					
						
							| 
									
										
										
										
											2009-11-04 11:22:29 +08:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2009-11-01 00:57:36 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2009-11-04 11:22:29 +08:00
										 |  |  | 	/* ************************************************************************* */ | 
					
						
							|  |  |  | 	template<class Conditional> | 
					
						
							|  |  |  | 	void BayesTree<Conditional>::insert | 
					
						
							|  |  |  | 	(const boost::shared_ptr<Conditional>& conditional) | 
					
						
							|  |  |  | 	{ | 
					
						
							|  |  |  | 		// get key and parents
 | 
					
						
							|  |  |  | 		string key = conditional->key(); | 
					
						
							| 
									
										
										
										
											2009-10-31 13:12:39 +08:00
										 |  |  | 		list<string> parents = conditional->parents(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		// if no parents, start a new root clique
 | 
					
						
							|  |  |  | 		if (parents.empty()) { | 
					
						
							| 
									
										
										
										
											2009-11-05 13:29:47 +08:00
										 |  |  | 			root_ = addClique(conditional); | 
					
						
							| 
									
										
										
										
											2009-10-31 13:12:39 +08:00
										 |  |  | 			return; | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		// otherwise, find the parent clique
 | 
					
						
							|  |  |  | 		string parent = parents.front(); | 
					
						
							| 
									
										
										
										
											2009-11-05 13:29:47 +08:00
										 |  |  | 		typename Nodes::const_iterator it = nodes_.find(parent); | 
					
						
							|  |  |  | 		if (it == nodes_.end()) throw(invalid_argument( | 
					
						
							| 
									
										
										
										
											2009-11-04 11:22:29 +08:00
										 |  |  | 				"BayesTree::insert('"+key+"'): parent '" + parent + "' not yet inserted")); | 
					
						
							| 
									
										
										
										
											2009-11-05 13:29:47 +08:00
										 |  |  | 		node_ptr parent_clique = it->second; | 
					
						
							| 
									
										
										
										
											2009-10-31 13:12:39 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 		// if the parents and parent clique have the same size, add to parent clique
 | 
					
						
							|  |  |  | 		if (parent_clique->size() == parents.size()) { | 
					
						
							| 
									
										
										
										
											2009-11-05 13:29:47 +08:00
										 |  |  | 			nodes_.insert(make_pair(key, parent_clique)); | 
					
						
							| 
									
										
										
										
											2009-11-03 14:29:56 +08:00
										 |  |  | 			parent_clique->push_front(conditional); | 
					
						
							| 
									
										
										
										
											2009-10-31 13:12:39 +08:00
										 |  |  | 			return; | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		// otherwise, start a new clique and add it to the tree
 | 
					
						
							| 
									
										
										
										
											2009-11-04 11:22:29 +08:00
										 |  |  | 		addClique(conditional,parent_clique); | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2009-11-06 13:43:03 +08:00
										 |  |  | 	/* ************************************************************************* */ | 
					
						
							|  |  |  | 	// Desired: recursive, memoizing version
 | 
					
						
							|  |  |  | 	// Once we know the clique, can we do all with Nodes ?
 | 
					
						
							|  |  |  | 	// Sure, as P(x) = \int P(C|root)
 | 
					
						
							|  |  |  | 	// The natural cache is P(C|root), memoized, of course, in the clique C
 | 
					
						
							|  |  |  | 	// When any marginal is asked for, we calculate P(C|root) = P(C|Pi)P(Pi|root)
 | 
					
						
							|  |  |  | 	// Super-naturally recursive !!!!!
 | 
					
						
							| 
									
										
										
										
											2009-11-04 11:22:29 +08:00
										 |  |  | 	/* ************************************************************************* */ | 
					
						
							|  |  |  | 	template<class Conditional> | 
					
						
							| 
									
										
										
										
											2009-11-05 14:30:50 +08:00
										 |  |  | 	template<class Factor> | 
					
						
							| 
									
										
										
										
											2009-11-04 11:22:29 +08:00
										 |  |  | 	boost::shared_ptr<Conditional> BayesTree<Conditional>::marginal(const string& key) const { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		// find the clique to which key belongs
 | 
					
						
							| 
									
										
										
										
											2009-11-05 13:29:47 +08:00
										 |  |  | 		typename Nodes::const_iterator it = nodes_.find(key); | 
					
						
							|  |  |  | 		if (it == nodes_.end()) throw(invalid_argument( | 
					
						
							| 
									
										
										
										
											2009-11-04 11:22:29 +08:00
										 |  |  | 						"BayesTree::marginal('"+key+"'): key not found")); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2009-11-05 16:06:32 +08:00
										 |  |  | 		// get clique containing key, and remove all factors below key
 | 
					
						
							|  |  |  | 		node_ptr clique = it->second; | 
					
						
							|  |  |  | 		Ordering ordering = clique->ordering(); | 
					
						
							|  |  |  | 		FactorGraph<Factor> graph(*clique); | 
					
						
							|  |  |  | 		while(ordering.front()!=key) { | 
					
						
							|  |  |  | 			graph.findAndRemoveFactors(ordering.front()); | 
					
						
							|  |  |  | 			ordering.pop_front(); | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2009-11-05 13:29:47 +08:00
										 |  |  | 		// find all cliques on the path to the root and turn into factor graph
 | 
					
						
							| 
									
										
										
										
											2009-11-05 16:06:32 +08:00
										 |  |  | 		while (clique->parent_!=NULL) { | 
					
						
							|  |  |  | 			// move up the tree
 | 
					
						
							|  |  |  | 			clique = clique->parent_; | 
					
						
							| 
									
										
										
										
											2009-11-05 14:30:50 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 			// extend ordering
 | 
					
						
							| 
									
										
										
										
											2009-11-05 16:06:32 +08:00
										 |  |  | 			Ordering cliqueOrdering = clique->ordering(); | 
					
						
							| 
									
										
										
										
											2009-11-05 14:30:50 +08:00
										 |  |  | 			ordering.splice (ordering.end(), cliqueOrdering); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			// extend factor graph
 | 
					
						
							| 
									
										
										
										
											2009-11-05 16:06:32 +08:00
										 |  |  | 			FactorGraph<Factor> cliqueGraph(*clique); | 
					
						
							| 
									
										
										
										
											2009-11-05 14:30:50 +08:00
										 |  |  | 			typename FactorGraph<Factor>::const_iterator factor=cliqueGraph.begin(); | 
					
						
							|  |  |  | 			for(; factor!=cliqueGraph.end(); factor++) | 
					
						
							|  |  |  | 				graph.push_back(*factor); | 
					
						
							| 
									
										
										
										
											2009-11-05 13:29:47 +08:00
										 |  |  | 		} | 
					
						
							| 
									
										
										
										
											2009-11-04 11:22:29 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2009-11-06 13:43:03 +08:00
										 |  |  | 		// TODO: can we prove reverse ordering is efficient?
 | 
					
						
							| 
									
										
										
										
											2009-11-05 14:30:50 +08:00
										 |  |  | 		ordering.reverse(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		// eliminate to get marginal
 | 
					
						
							|  |  |  | 		boost::shared_ptr<BayesNet<Conditional> > bayesNet; | 
					
						
							|  |  |  | 		typename boost::shared_ptr<BayesNet<Conditional> > chordalBayesNet = | 
					
						
							|  |  |  | 				graph.eliminate(bayesNet,ordering); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2009-11-06 13:43:03 +08:00
										 |  |  | 		return chordalBayesNet->back(); // the root is the marginal
 | 
					
						
							| 
									
										
										
										
											2009-10-31 13:12:39 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2009-11-01 00:57:36 +08:00
										 |  |  | 	/* ************************************************************************* */ | 
					
						
							| 
									
										
										
										
											2009-10-31 13:12:39 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2009-11-02 11:50:30 +08:00
										 |  |  | } | 
					
						
							|  |  |  | /// namespace gtsam
 |