refactor and add more comments

release/4.3a0
Duy-Nguyen Ta 2013-10-12 15:27:37 +00:00
parent 1f293294fd
commit 8be7363a01
1 changed files with 65 additions and 37 deletions

View File

@ -15,44 +15,53 @@
#include <fstream> #include <fstream>
using namespace std; using namespace std;
using namespace boost;
using namespace boost::assign; using namespace boost::assign;
using namespace gtsam; using namespace gtsam;
/**
* Loopy belief solver for graphs with only binary and unary factors
*/
class LoopyBelief { class LoopyBelief {
/** Star graph struct for each node, containing
* - the star graph itself
* - the product of original unary factors so we don't have to recompute it later, and
* - the factor indices of the corrected belief factors of the neighboring nodes
*/
typedef std::map<Key, size_t> CorrectedBeliefIndices; typedef std::map<Key, size_t> CorrectedBeliefIndices;
struct StarGraph { struct StarGraph {
DiscreteFactorGraph::shared_ptr star; DiscreteFactorGraph::shared_ptr star;
DecisionTreeFactor::shared_ptr unary;
CorrectedBeliefIndices correctedBeliefIndices; CorrectedBeliefIndices correctedBeliefIndices;
StarGraph() {
}
StarGraph(const DiscreteFactorGraph::shared_ptr& _star, StarGraph(const DiscreteFactorGraph::shared_ptr& _star,
const DecisionTreeFactor::shared_ptr& _unary,
const CorrectedBeliefIndices& _beliefIndices) : const CorrectedBeliefIndices& _beliefIndices) :
star(_star), correctedBeliefIndices(_beliefIndices) { star(_star), unary(_unary), correctedBeliefIndices(_beliefIndices) {
} }
}; };
const DiscreteFactorGraph& graph_; typedef std::map<Key, StarGraph> StarGraphs;
VariableIndex varIndex_; StarGraphs starGraphs_; ///< star graph at each variable
std::map<Key, StarGraph> starGraphs_;
std::map<Key, DecisionTreeFactor> unary_;
std::map<Key, DecisionTreeFactor> belief_;
public: public:
LoopyBelief(const DiscreteFactorGraph& graph, const std::map<Key, DiscreteKey>& allDiscreteKeys) : /** Constructor
graph_(graph), varIndex_(graph) { * Need all discrete keys to access node's cardinality for creating belief factors
initialize(allDiscreteKeys); * TODO: so troublesome!!
*/
LoopyBelief(const DiscreteFactorGraph& graph,
const std::map<Key, DiscreteKey>& allDiscreteKeys) :
starGraphs_(buildStarGraphs(graph, allDiscreteKeys)) {
} }
void iterate() { /// One step of belief propagation
DiscreteFactorGraph::shared_ptr iterate() {
static DiscreteConditional::shared_ptr dummyCond; // unused by-product of elimination static DiscreteConditional::shared_ptr dummyCond; // unused by-product of elimination
DiscreteFactorGraph::shared_ptr beliefs(new DiscreteFactorGraph());
// Eliminate each star graph // Eliminate each star graph
BOOST_FOREACH(const VariableIndex::value_type& keyFactors, varIndex_) { BOOST_FOREACH(Key key, starGraphs_ | boost::adaptors::map_keys) {
Key key = keyFactors.first; // initialize belief to the unary factor from the original graph
// reset belief to the unary factor in the original graph DecisionTreeFactor beliefAtKey = *starGraphs_.at(key).unary;
belief_[key] = unary_.at(key);
// keep intermediate messages to divide later // keep intermediate messages to divide later
std::map<Key, DiscreteFactor::shared_ptr> messages; std::map<Key, DiscreteFactor::shared_ptr> messages;
@ -60,45 +69,59 @@ public:
// eliminate each neighbor in this star graph one by one // eliminate each neighbor in this star graph one by one
BOOST_FOREACH(Key neighbor, starGraphs_.at(key).correctedBeliefIndices | boost::adaptors::map_keys) { BOOST_FOREACH(Key neighbor, starGraphs_.at(key).correctedBeliefIndices | boost::adaptors::map_keys) {
DiscreteFactor::shared_ptr factor; DiscreteFactor::shared_ptr factor;
boost::tie(dummyCond, factor) = EliminateDiscrete(*starGraphs_.at(key).star, boost::tie(dummyCond, factor) = EliminateDiscrete(
Ordering(list_of(neighbor))); *starGraphs_.at(key).star, Ordering(list_of(neighbor)));
// store the new factor into messages // store the new factor into messages
messages.insert(make_pair(neighbor, factor)); messages.insert(make_pair(neighbor, factor));
// Belief is the product of all messages and the unary factor // Belief is the product of all messages and the unary factor
// Incorporate new the factor to belief // Incorporate new the factor to belief
belief_.at(key) = belief_.at(key) * (*boost::dynamic_pointer_cast<DecisionTreeFactor>(factor)); beliefAtKey = beliefAtKey
* (*boost::dynamic_pointer_cast<DecisionTreeFactor>(factor));
} }
beliefs->push_back(beliefAtKey);
// Update the corrected belief for the neighbor's stargraph // Update the corrected belief for the neighbor's stargraph
BOOST_FOREACH(Key neighbor, starGraphs_.at(key).correctedBeliefIndices | boost::adaptors::map_keys) { BOOST_FOREACH(Key neighbor, starGraphs_.at(key).correctedBeliefIndices | boost::adaptors::map_keys) {
DecisionTreeFactor correctedBelief = belief_.at(key) DecisionTreeFactor correctedBelief = beliefAtKey
/ (*boost::dynamic_pointer_cast<DecisionTreeFactor>(messages.at(neighbor))); / (*boost::dynamic_pointer_cast<DecisionTreeFactor>(
messages.at(neighbor)));
size_t beliefIndex = starGraphs_.at(neighbor).correctedBeliefIndices.at( size_t beliefIndex = starGraphs_.at(neighbor).correctedBeliefIndices.at(
key); key);
starGraphs_.at(neighbor).star->replace(beliefIndex, boost::make_shared<DecisionTreeFactor>(correctedBelief)); starGraphs_.at(neighbor).star->replace(beliefIndex,
boost::make_shared<DecisionTreeFactor>(correctedBelief));
} }
} }
return beliefs;
} }
private: private:
void initialize(const std::map<Key, DiscreteKey>& allDiscreteKeys) { /**
BOOST_FOREACH(Key key, varIndex_ | boost::adaptors::map_keys) { * Build star graphs for each node.
*/
StarGraphs buildStarGraphs(const DiscreteFactorGraph& graph,
const std::map<Key, DiscreteKey>& allDiscreteKeys) const {
StarGraphs starGraphs;
VariableIndex varIndex(graph); ///< access to all factors of each node
BOOST_FOREACH(Key key, varIndex | boost::adaptors::map_keys) {
// initialize to multiply with other unary factors later // initialize to multiply with other unary factors later
unary_.insert(make_pair(key, DecisionTreeFactor(allDiscreteKeys.at(key), "1 1"))); DecisionTreeFactor prodOfUnaries(allDiscreteKeys.at(key), "1 1");
// collect all factors involving this key in the original graph // collect all factors involving this key in the original graph
DiscreteFactorGraph::shared_ptr star(new DiscreteFactorGraph()); DiscreteFactorGraph::shared_ptr star(new DiscreteFactorGraph());
BOOST_FOREACH(size_t factorIdx, varIndex_[key]) { BOOST_FOREACH(size_t factorIdx, varIndex[key]) {
star->push_back(graph_.at(factorIdx)); star->push_back(graph.at(factorIdx));
if (graph_.at(factorIdx)->size() == 1) {
unary_.at(key) = unary_.at(key) // accumulate unary factors
if (graph.at(factorIdx)->size() == 1) {
prodOfUnaries = prodOfUnaries
* (*boost::dynamic_pointer_cast<DecisionTreeFactor>( * (*boost::dynamic_pointer_cast<DecisionTreeFactor>(
graph_.at(factorIdx))); graph.at(factorIdx)));
} }
} }
// add the belief factor for each other variable in this star graph // add the belief factor for each neighbor variable to this star graph
// also record the factor index for later modification // also record the factor index for later modification
FastSet<Key> neighbors = star->keys(); FastSet<Key> neighbors = star->keys();
neighbors.erase(key); neighbors.erase(key);
@ -109,9 +132,12 @@ private:
DecisionTreeFactor(allDiscreteKeys.at(neighbor), "1.0 0.0")); DecisionTreeFactor(allDiscreteKeys.at(neighbor), "1.0 0.0"));
correctedBeliefIndices.insert(make_pair(neighbor, star->size() - 1)); correctedBeliefIndices.insert(make_pair(neighbor, star->size() - 1));
} }
starGraphs_.insert( starGraphs.insert(
make_pair(key, StarGraph(star, correctedBeliefIndices))); make_pair(key,
StarGraph(star, make_shared<DecisionTreeFactor>(prodOfUnaries),
correctedBeliefIndices)));
} }
return starGraphs;
} }
}; };
@ -123,7 +149,7 @@ TEST_UNSAFE(LoopyBelief, construction) {
// Map from key to DiscreteKey for building belief factor. // Map from key to DiscreteKey for building belief factor.
// TODO: this is bad! // TODO: this is bad!
std::map<Key, DiscreteKey> allKeys = map_list_of(0,C)(1,S)(2,R)(3,W); std::map<Key, DiscreteKey> allKeys = map_list_of(0, C)(1, S)(2, R)(3, W);
// Build graph // Build graph
DecisionTreeFactor pC(C, "0.5 0.5"); DecisionTreeFactor pC(C, "0.5 0.5");
@ -142,7 +168,9 @@ TEST_UNSAFE(LoopyBelief, construction) {
// Main loop // Main loop
for (size_t iter = 0; iter < 10; ++iter) { for (size_t iter = 0; iter < 10; ++iter) {
solver.iterate(); DiscreteFactorGraph::shared_ptr beliefs = solver.iterate();
cout << "iteration: " << iter << endl;
beliefs->print();
} }
} }