115 lines
		
	
	
		
			3.6 KiB
		
	
	
	
		
			C++
		
	
	
			
		
		
	
	
			115 lines
		
	
	
		
			3.6 KiB
		
	
	
	
		
			C++
		
	
	
/* ----------------------------------------------------------------------------
 | 
						|
 | 
						|
 * GTSAM Copyright 2010, Georgia Tech Research Corporation,
 | 
						|
 * Atlanta, Georgia 30332-0415
 | 
						|
 * All Rights Reserved
 | 
						|
 * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
 | 
						|
 | 
						|
 * See LICENSE for the license information
 | 
						|
 | 
						|
 * -------------------------------------------------------------------------- */
 | 
						|
 | 
						|
/**
 | 
						|
 * @file small.cpp
 | 
						|
 * @brief UGM (undirected graphical model) examples: chain
 | 
						|
 * @author Frank Dellaert
 | 
						|
 * @author Abhijit Kundu
 | 
						|
 *
 | 
						|
 * See http://www.di.ens.fr/~mschmidt/Software/UGM/chain.html
 | 
						|
 * for more explanation. This code demos the same example using GTSAM.
 | 
						|
 */
 | 
						|
 | 
						|
#include <gtsam/discrete/DiscreteFactorGraph.h>
 | 
						|
#include <gtsam/discrete/DiscreteSequentialSolver.h>
 | 
						|
#include <gtsam/discrete/DiscreteMarginals.h>
 | 
						|
#include <gtsam/base/timing.h>
 | 
						|
 | 
						|
#include <iomanip>
 | 
						|
 | 
						|
using namespace std;
 | 
						|
using namespace gtsam;
 | 
						|
 | 
						|
int main(int argc, char** argv) {
 | 
						|
 | 
						|
    // Set Number of Nodes in the Graph
 | 
						|
    const int nrNodes = 60;
 | 
						|
 | 
						|
  // Each node takes 1 of 7 possible states denoted by 0-6 in following order:
 | 
						|
  // ["VideoGames"  "Industry"  "GradSchool"  "VideoGames(with PhD)"
 | 
						|
  // "Industry(with PhD)"  "Academia"  "Deceased"]
 | 
						|
  const size_t nrStates = 7;
 | 
						|
 | 
						|
  // define variables
 | 
						|
    vector<DiscreteKey> nodes;
 | 
						|
    for (int i = 0; i < nrNodes; i++){
 | 
						|
        DiscreteKey dk(i, nrStates);
 | 
						|
        nodes.push_back(dk);
 | 
						|
    }
 | 
						|
 | 
						|
  // create graph
 | 
						|
  DiscreteFactorGraph graph;
 | 
						|
 | 
						|
  // add node potentials
 | 
						|
  graph.add(nodes[0], ".3 .6 .1 0 0 0 0");
 | 
						|
    for (int i = 1; i < nrNodes; i++)
 | 
						|
        graph.add(nodes[i], "1 1 1 1 1 1 1");
 | 
						|
 | 
						|
    const std::string edgePotential =   ".08 .9 .01 0 0 0 .01 "
 | 
						|
                                        ".03 .95 .01 0 0 0 .01 "
 | 
						|
                                        ".06 .06 .75 .05 .05 .02 .01 "
 | 
						|
                                        "0 0 0 .3 .6 .09 .01 "
 | 
						|
                                        "0 0 0 .02 .95 .02 .01 "
 | 
						|
                                        "0 0 0 .01 .01 .97 .01 "
 | 
						|
                                        "0 0 0 0 0 0 1";
 | 
						|
 | 
						|
  // add edge potentials
 | 
						|
  for (int i = 0; i < nrNodes - 1; i++)
 | 
						|
    graph.add(nodes[i] & nodes[i + 1], edgePotential);
 | 
						|
 | 
						|
  cout << "Created Factor Graph with " << nrNodes << " variable nodes and "
 | 
						|
      << graph.size() << " factors (Unary+Edge).";
 | 
						|
 | 
						|
  // "Decoding", i.e., configuration with largest value
 | 
						|
  // We use sequential variable elimination
 | 
						|
  DiscreteSequentialSolver solver(graph);
 | 
						|
  DiscreteFactor::sharedValues optimalDecoding = solver.optimize();
 | 
						|
  optimalDecoding->print("\nMost Probable Explanation (optimalDecoding)\n");
 | 
						|
 | 
						|
  // "Inference" Computing marginals for each node
 | 
						|
 | 
						|
 | 
						|
  cout << "\nComputing Node Marginals ..(Sequential Elimination)" << endl;
 | 
						|
  gttic_(Sequential);
 | 
						|
  for (vector<DiscreteKey>::iterator itr = nodes.begin(); itr != nodes.end();
 | 
						|
      ++itr) {
 | 
						|
    //Compute the marginal
 | 
						|
    Vector margProbs = solver.marginalProbabilities(*itr);
 | 
						|
 | 
						|
    //Print the marginals
 | 
						|
    cout << "Node#" << setw(4) << itr->first << " :  ";
 | 
						|
    print(margProbs);
 | 
						|
  }
 | 
						|
  gttoc_(Sequential);
 | 
						|
 | 
						|
  // Here we'll make use of DiscreteMarginals class, which makes use of
 | 
						|
  // bayes-tree based shortcut evaluation of marginals
 | 
						|
  DiscreteMarginals marginals(graph);
 | 
						|
 | 
						|
  cout << "\nComputing Node Marginals ..(BayesTree based)" << endl;
 | 
						|
  gttic_(Multifrontal);
 | 
						|
  for (vector<DiscreteKey>::iterator itr = nodes.begin(); itr != nodes.end();
 | 
						|
      ++itr) {
 | 
						|
    //Compute the marginal
 | 
						|
    Vector margProbs = marginals.marginalProbabilities(*itr);
 | 
						|
 | 
						|
    //Print the marginals
 | 
						|
    cout << "Node#" << setw(4) << itr->first << " :  ";
 | 
						|
    print(margProbs);
 | 
						|
  }
 | 
						|
  gttoc_(Multifrontal);
 | 
						|
 | 
						|
  tictoc_print_();
 | 
						|
  return 0;
 | 
						|
}
 | 
						|
 |