gtsam/examples/UGM_chain.cpp

96 lines
2.8 KiB
C++
Raw Permalink Normal View History

2012-06-06 11:25:56 +08:00
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
 * @file UGM_chain.cpp
 * @brief UGM (undirected graphical model) examples: chain
 * @author Frank Dellaert
 * @author Abhijit Kundu
 *
 * See http://www.di.ens.fr/~mschmidt/Software/UGM/chain.html
 * for more explanation. This code demonstrates the same example using GTSAM.
 */
2020-07-10 08:46:12 +08:00
#include <gtsam/base/timing.h>
2012-06-06 11:25:56 +08:00
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteMarginals.h>
2012-06-06 11:25:56 +08:00
#include <iomanip>
using namespace std;
using namespace gtsam;
int main(int argc, char** argv) {
2020-07-10 08:46:12 +08:00
// Set Number of Nodes in the Graph
const int nrNodes = 60;
2012-06-06 11:25:56 +08:00
// Each node takes 1 of 7 possible states denoted by 0-6 in following order:
// ["VideoGames" "Industry" "GradSchool" "VideoGames(with PhD)"
2019-02-11 22:39:48 +08:00
// "Industry(with PhD)" "Academia" "Deceased"]
const size_t nrStates = 7;
2012-06-06 11:25:56 +08:00
// define variables
2020-07-10 08:46:12 +08:00
vector<DiscreteKey> nodes;
for (int i = 0; i < nrNodes; i++) {
DiscreteKey dk(i, nrStates);
nodes.push_back(dk);
}
2012-06-06 11:25:56 +08:00
// create graph
DiscreteFactorGraph graph;
2012-06-06 11:25:56 +08:00
// add node potentials
graph.add(nodes[0], ".3 .6 .1 0 0 0 0");
2020-07-10 08:46:12 +08:00
for (int i = 1; i < nrNodes; i++) graph.add(nodes[i], "1 1 1 1 1 1 1");
2012-06-06 11:25:56 +08:00
2020-07-10 08:46:12 +08:00
const std::string edgePotential =
".08 .9 .01 0 0 0 .01 "
".03 .95 .01 0 0 0 .01 "
".06 .06 .75 .05 .05 .02 .01 "
"0 0 0 .3 .6 .09 .01 "
"0 0 0 .02 .95 .02 .01 "
"0 0 0 .01 .01 .97 .01 "
"0 0 0 0 0 0 1";
2012-06-06 11:25:56 +08:00
// add edge potentials
for (int i = 0; i < nrNodes - 1; i++)
graph.add(nodes[i] & nodes[i + 1], edgePotential);
cout << "Created Factor Graph with " << nrNodes << " variable nodes and "
2020-07-10 08:46:12 +08:00
<< graph.size() << " factors (Unary+Edge).";
// "Decoding", i.e., configuration with largest value
2022-01-22 07:12:38 +08:00
// Uses max-product.
auto optimalDecoding = graph.optimize();
2021-11-21 05:34:53 +08:00
optimalDecoding.print("\nMost Probable Explanation (optimalDecoding)\n");
// "Inference" Computing marginals for each node
// Here we'll make use of DiscreteMarginals class, which makes use of
// bayes-tree based shortcut evaluation of marginals
DiscreteMarginals marginals(graph);
cout << "\nComputing Node Marginals ..(BayesTree based)" << endl;
gttic_(Multifrontal);
2020-07-10 08:46:12 +08:00
for (vector<DiscreteKey>::iterator it = nodes.begin(); it != nodes.end();
++it) {
// Compute the marginal
Vector margProbs = marginals.marginalProbabilities(*it);
2020-07-10 08:46:12 +08:00
// Print the marginals
cout << "Node#" << setw(4) << it->first << " : ";
print(margProbs);
}
gttoc_(Multifrontal);
tictoc_print_();
return 0;
2012-06-06 11:25:56 +08:00
}