2012-06-06 11:25:56 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								/* ----------------------------------------------------------------------------
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								 * GTSAM Copyright 2010, Georgia Tech Research Corporation,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								 * Atlanta, Georgia 30332-0415
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								 * All Rights Reserved
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								 * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								 * See LICENSE for the license information
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								 * -------------------------------------------------------------------------- */
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								/**
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								 * @file small.cpp
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								 * @brief UGM (undirected graphical model) examples: chain
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								 * @author Frank Dellaert
							 | 
						
					
						
							
								
									
										
										
										
											2012-06-09 07:28:22 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								 * @author Abhijit Kundu
							 | 
						
					
						
							
								
									
										
										
										
											2012-06-06 11:25:56 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								 *
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								 * See http://www.di.ens.fr/~mschmidt/Software/UGM/chain.html
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								 * for more explanation. This code demos the same example using GTSAM.
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								 */
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								#include <gtsam/discrete/DiscreteFactorGraph.h>
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								#include <gtsam/discrete/DiscreteSequentialSolver.h>
							 | 
						
					
						
							
								
									
										
										
										
											2012-06-08 07:20:40 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								#include <gtsam/discrete/DiscreteMarginals.h>
							 | 
						
					
						
							
								
									
										
										
										
											2012-06-08 08:18:32 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								#include <gtsam/base/timing.h>
							 | 
						
					
						
							
								
									
										
										
										
											2012-06-06 11:25:56 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								#include <iomanip>
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								using namespace std;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								using namespace gtsam;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								int main(int argc, char** argv) {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    // Set Number of Nodes in the Graph
							 | 
						
					
						
							
								
									
										
										
										
											2012-06-08 08:18:32 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    const int nrNodes = 60;
							 | 
						
					
						
							
								
									
										
										
										
											2012-06-06 11:25:56 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2012-10-02 22:40:07 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								  // Each node takes 1 of 7 possible states denoted by 0-6 in following order:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								  // ["VideoGames"  "Industry"  "GradSchool"  "VideoGames(with PhD)"
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								  // "Industry(with PhD)"  "Academia"  "Deceased"] 
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								  const size_t nrStates = 7;
							 | 
						
					
						
							
								
									
										
										
										
											2012-06-06 11:25:56 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2012-10-02 22:40:07 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								  // define variables
							 | 
						
					
						
							
								
									
										
										
										
											2012-06-06 11:25:56 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    vector<DiscreteKey> nodes;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    for (int i = 0; i < nrNodes; i++){
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        DiscreteKey dk(i, nrStates);
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        nodes.push_back(dk);
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    }
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2012-10-02 22:40:07 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								  // create graph
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								  DiscreteFactorGraph graph;
							 | 
						
					
						
							
								
									
										
										
										
											2012-06-06 11:25:56 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2012-10-02 22:40:07 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								  // add node potentials
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								  graph.add(nodes[0], ".3 .6 .1 0 0 0 0");
							 | 
						
					
						
							
								
									
										
										
										
											2012-06-06 11:25:56 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    for (int i = 1; i < nrNodes; i++)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        graph.add(nodes[i], "1 1 1 1 1 1 1");
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    const std::string edgePotential =   ".08 .9 .01 0 0 0 .01 "
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                                        ".03 .95 .01 0 0 0 .01 "
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                                        ".06 .06 .75 .05 .05 .02 .01 "
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                                        "0 0 0 .3 .6 .09 .01 "
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                                        "0 0 0 .02 .95 .02 .01 "
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                                        "0 0 0 .01 .01 .97 .01 "
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                                        "0 0 0 0 0 0 1";
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2012-10-02 22:40:07 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								  // add edge potentials
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								  for (int i = 0; i < nrNodes - 1; i++)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    graph.add(nodes[i] & nodes[i + 1], edgePotential);
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								  cout << "Created Factor Graph with " << nrNodes << " variable nodes and "
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								      << graph.size() << " factors (Unary+Edge).";
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								  // "Decoding", i.e., configuration with largest value
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								  // We use sequential variable elimination
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								  DiscreteSequentialSolver solver(graph);
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								  DiscreteFactor::sharedValues optimalDecoding = solver.optimize();
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								  optimalDecoding->print("\nMost Probable Explanation (optimalDecoding)\n");
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								  // "Inference" Computing marginals for each node
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								  cout << "\nComputing Node Marginals ..(Sequential Elimination)" << endl;
							 | 
						
					
						
							
								
									
										
										
										
											2012-10-03 04:18:41 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								  gttic_(Sequential);
							 | 
						
					
						
							
								
									
										
										
										
											2012-10-02 22:40:07 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								  for (vector<DiscreteKey>::iterator itr = nodes.begin(); itr != nodes.end();
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								      ++itr) {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    //Compute the marginal
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    Vector margProbs = solver.marginalProbabilities(*itr);
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    //Print the marginals
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    cout << "Node#" << setw(4) << itr->first << " :  ";
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    print(margProbs);
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								  }
							 | 
						
					
						
							
								
									
										
										
										
											2012-10-03 04:18:41 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								  gttoc_(Sequential);
							 | 
						
					
						
							
								
									
										
										
										
											2012-10-02 22:40:07 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								  // Here we'll make use of DiscreteMarginals class, which makes use of
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								  // bayes-tree based shortcut evaluation of marginals
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								  DiscreteMarginals marginals(graph);
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								  cout << "\nComputing Node Marginals ..(BayesTree based)" << endl;
							 | 
						
					
						
							
								
									
										
										
										
											2012-10-03 04:18:41 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								  gttic_(Multifrontal);
							 | 
						
					
						
							
								
									
										
										
										
											2012-10-02 22:40:07 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								  for (vector<DiscreteKey>::iterator itr = nodes.begin(); itr != nodes.end();
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								      ++itr) {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    //Compute the marginal
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    Vector margProbs = marginals.marginalProbabilities(*itr);
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    //Print the marginals
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    cout << "Node#" << setw(4) << itr->first << " :  ";
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    print(margProbs);
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								  }
							 | 
						
					
						
							
								
									
										
										
										
											2012-10-03 04:18:41 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								  gttoc_(Multifrontal);
							 | 
						
					
						
							
								
									
										
										
										
											2012-10-02 22:40:07 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								  tictoc_print_();
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								  return 0;
							 | 
						
					
						
							
								
									
										
										
										
											2012-06-06 11:25:56 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 |