| 
									
										
										
										
											2009-12-07 05:46:46 +08:00
/**
 * @file    DiscreteConditional.h
 * @brief   Discrete (binary-variable) conditional node for use in Bayes nets
 * @author  Manohar Paluri
 */
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // \callgraph
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #pragma once
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include <list>
 | 
					
						
							|  |  |  | #include <string>
 | 
					
						
							|  |  |  | #include <iostream>
 | 
					
						
							|  |  |  | #include <boost/shared_ptr.hpp>
 | 
					
						
							|  |  |  | #include <boost/foreach.hpp> // TODO: make cpp file
 | 
					
						
							|  |  |  | #include <boost/serialization/list.hpp>
 | 
					
						
							| 
									
										
										
										
											2009-12-07 07:28:46 +08:00
										 |  |  | #include <boost/serialization/vector.hpp>
 | 
					
						
							| 
									
										
										
										
											2010-08-20 01:23:19 +08:00
										 |  |  | #include <gtsam/inference/Conditional.h>
 | 
					
						
							|  |  |  | #include <gtsam/inference/Key.h>
 | 
					
						
							|  |  |  | #include <gtsam/inference/SymbolMap.h>
 | 
					
						
							| 
									
										
										
										
											2009-12-07 05:46:46 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | namespace gtsam { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	/**
 | 
					
						
							|  |  |  | 	 * Conditional node for use in a Bayes net | 
					
						
							|  |  |  | 	 */ | 
					
						
							|  |  |  | 	class BinaryConditional: public Conditional { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	private: | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2010-01-18 03:34:57 +08:00
										 |  |  | 		std::list<Symbol> parents_; | 
					
						
							| 
									
										
										
										
											2009-12-07 07:28:46 +08:00
										 |  |  | 		std::vector<double> cpt_; | 
					
						
							| 
									
										
										
										
											2009-12-07 05:46:46 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	public: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		/** convenience typename for a shared pointer to this class */ | 
					
						
							|  |  |  | 		typedef boost::shared_ptr<BinaryConditional> shared_ptr; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		/**
 | 
					
						
							|  |  |  | 		 * Empty Constructor to make serialization possible | 
					
						
							|  |  |  | 		 */ | 
					
						
							|  |  |  | 		BinaryConditional(){} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		/**
 | 
					
						
							|  |  |  | 		 * No parents | 
					
						
							|  |  |  | 		 */ | 
					
						
							| 
									
										
										
										
											2010-01-18 03:34:57 +08:00
										 |  |  | 		BinaryConditional(const Symbol& key, double p) : | 
					
						
							| 
									
										
										
										
											2009-12-07 05:46:46 +08:00
										 |  |  | 			Conditional(key) { | 
					
						
							| 
									
										
										
										
											2009-12-07 07:28:46 +08:00
										 |  |  | 			cpt_.push_back(1-p); | 
					
						
							|  |  |  | 			cpt_.push_back(p); | 
					
						
							| 
									
										
										
										
											2009-12-07 05:46:46 +08:00
										 |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		/**
 | 
					
						
							|  |  |  | 		 * Single parent | 
					
						
							|  |  |  | 		 */ | 
					
						
							| 
									
										
										
										
											2010-01-18 03:34:57 +08:00
										 |  |  | 		BinaryConditional(const Symbol& key, const Symbol& parent, const std::vector<double>& cpt) : | 
					
						
							| 
									
										
										
										
											2009-12-07 05:46:46 +08:00
										 |  |  | 			Conditional(key) { | 
					
						
							|  |  |  | 			parents_.push_back(parent); | 
					
						
							| 
									
										
										
										
											2010-05-22 01:59:26 +08:00
										 |  |  | 			for( size_t i = 0 ; i < cpt.size() ; i++ ) | 
					
						
							| 
									
										
										
										
											2009-12-07 11:25:25 +08:00
										 |  |  | 				cpt_.push_back(1-cpt[i]); // p(!x|parents)
 | 
					
						
							|  |  |  | 			cpt_.insert(cpt_.end(),cpt.begin(),cpt.end()); // p(x|parents)
 | 
					
						
							| 
									
										
										
										
											2009-12-07 05:46:46 +08:00
										 |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2010-01-22 12:41:40 +08:00
										 |  |  | 		double probability( SymbolMap<bool> config) { | 
					
						
							| 
									
										
										
										
											2009-12-07 11:25:25 +08:00
										 |  |  | 			int index = 0, count = 1; | 
					
						
							| 
									
										
										
										
											2010-01-18 03:34:57 +08:00
										 |  |  | 			BOOST_FOREACH(const Symbol& parent, parents_){ | 
					
						
							| 
									
										
										
										
											2009-12-07 11:25:25 +08:00
										 |  |  | 				index += count*(int)(config[parent]); | 
					
						
							|  |  |  | 				count = count << 1; | 
					
						
							|  |  |  | 			} | 
					
						
							| 
									
										
										
										
											2010-01-22 12:41:40 +08:00
										 |  |  | 			if( config.at(key_) ) | 
					
						
							| 
									
										
										
										
											2009-12-07 11:25:25 +08:00
										 |  |  | 				index += count; | 
					
						
							| 
									
										
										
										
											2009-12-07 08:49:13 +08:00
										 |  |  | 			return cpt_[index]; | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2009-12-07 05:46:46 +08:00
										 |  |  | 		/** print */ | 
					
						
							|  |  |  | 		void print(const std::string& s = "BinaryConditional") const { | 
					
						
							| 
									
										
										
										
											2010-01-18 03:34:57 +08:00
										 |  |  | 			std::cout << s << " P(" << (std::string)key_; | 
					
						
							| 
									
										
										
										
											2009-12-07 05:46:46 +08:00
										 |  |  | 			if (parents_.size()>0) std::cout << " |"; | 
					
						
							|  |  |  | 			BOOST_FOREACH(std::string parent, parents_) std::cout << " " << parent; | 
					
						
							|  |  |  | 			std::cout << ")" << std::endl; | 
					
						
							| 
									
										
										
										
											2009-12-07 07:28:46 +08:00
										 |  |  | 			std::cout << "Conditional Probability Table::" << std::endl; | 
					
						
							|  |  |  | 			BOOST_FOREACH(double p, cpt_) std::cout << p << "\t"; | 
					
						
							|  |  |  | 			std::cout<< std::endl; | 
					
						
							| 
									
										
										
										
											2009-12-07 05:46:46 +08:00
										 |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		/** check equality */ | 
					
						
							|  |  |  | 		bool equals(const Conditional& c, double tol = 1e-9) const { | 
					
						
							|  |  |  | 			if (!Conditional::equals(c)) return false; | 
					
						
							|  |  |  | 			const BinaryConditional* p = dynamic_cast<const BinaryConditional*> (&c); | 
					
						
							|  |  |  | 			if (p == NULL) return false; | 
					
						
							| 
									
										
										
										
											2009-12-07 07:28:46 +08:00
										 |  |  | 			return (parents_ == p->parents_ && cpt_ == p->cpt_); | 
					
						
							| 
									
										
										
										
											2009-12-07 05:46:46 +08:00
										 |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		/** return parents */ | 
					
						
							| 
									
										
										
										
											2010-01-18 03:34:57 +08:00
										 |  |  | 		std::list<Symbol> parents() const { return parents_;} | 
					
						
							| 
									
										
										
										
											2009-12-07 05:46:46 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2009-12-07 07:28:46 +08:00
										 |  |  | 		/** return Conditional probability table*/ | 
					
						
							|  |  |  | 		std::vector<double> cpt() const { return cpt_;} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2009-12-07 05:46:46 +08:00
										 |  |  | 		/** find the number of parents */ | 
					
						
							|  |  |  | 		size_t nrParents() const { | 
					
						
							|  |  |  | 			return parents_.size(); | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	private: | 
					
						
							|  |  |  | 		/** Serialization function */ | 
					
						
							|  |  |  | 		friend class boost::serialization::access; | 
					
						
							|  |  |  | 		template<class Archive> | 
					
						
							|  |  |  | 		void serialize(Archive & ar, const unsigned int version) { | 
					
						
							|  |  |  | 			ar & boost::serialization::make_nvp("Conditional", boost::serialization::base_object<Conditional>(*this)); | 
					
						
							|  |  |  | 			ar & BOOST_SERIALIZATION_NVP(parents_); | 
					
						
							| 
									
										
										
										
											2009-12-07 07:28:46 +08:00
										 |  |  | 			ar & BOOST_SERIALIZATION_NVP(cpt_); | 
					
						
							| 
									
										
										
										
											2009-12-07 05:46:46 +08:00
										 |  |  | 		} | 
					
						
							|  |  |  | 	}; | 
					
						
							|  |  |  | } /// namespace gtsam
 |