diff --git a/gtsam_unstable/discrete/tests/testLoopyBelief.cpp b/gtsam_unstable/discrete/tests/testLoopyBelief.cpp index ed1064fed..552db4dca 100644 --- a/gtsam_unstable/discrete/tests/testLoopyBelief.cpp +++ b/gtsam_unstable/discrete/tests/testLoopyBelief.cpp @@ -15,44 +15,53 @@ #include using namespace std; +using namespace boost; using namespace boost::assign; using namespace gtsam; +/** + * Loopy belief solver for graphs with only binary and unary factors + */ class LoopyBelief { + /** Star graph struct for each node, containing + * - the star graph itself + * - the product of original unary factors so we don't have to recompute it later, and + * - the factor indices of the corrected belief factors of the neighboring nodes + */ typedef std::map CorrectedBeliefIndices; - struct StarGraph { DiscreteFactorGraph::shared_ptr star; + DecisionTreeFactor::shared_ptr unary; CorrectedBeliefIndices correctedBeliefIndices; - StarGraph() { - } StarGraph(const DiscreteFactorGraph::shared_ptr& _star, + const DecisionTreeFactor::shared_ptr& _unary, const CorrectedBeliefIndices& _beliefIndices) : - star(_star), correctedBeliefIndices(_beliefIndices) { + star(_star), unary(_unary), correctedBeliefIndices(_beliefIndices) { } }; - const DiscreteFactorGraph& graph_; - VariableIndex varIndex_; - std::map starGraphs_; - std::map unary_; - std::map belief_; + typedef std::map StarGraphs; + StarGraphs starGraphs_; ///< star graph at each variable public: - LoopyBelief(const DiscreteFactorGraph& graph, const std::map& allDiscreteKeys) : - graph_(graph), varIndex_(graph) { - initialize(allDiscreteKeys); + /** Constructor + * Need all discrete keys to access node's cardinality for creating belief factors + * TODO: so troublesome!! + */ + LoopyBelief(const DiscreteFactorGraph& graph, + const std::map& allDiscreteKeys) : + starGraphs_(buildStarGraphs(graph, allDiscreteKeys)) { } - void iterate() { + /// One step of belief propagation + DiscreteFactorGraph::shared_ptr iterate() { static DiscreteConditional::shared_ptr dummyCond; // unused by-product of elimination - + DiscreteFactorGraph::shared_ptr beliefs(new DiscreteFactorGraph()); // Eliminate each star graph - BOOST_FOREACH(const VariableIndex::value_type& keyFactors, varIndex_) { - Key key = keyFactors.first; - // reset belief to the unary factor in the original graph - belief_[key] = unary_.at(key); + BOOST_FOREACH(Key key, starGraphs_ | boost::adaptors::map_keys) { + // initialize belief to the unary factor from the original graph + DecisionTreeFactor beliefAtKey = *starGraphs_.at(key).unary; // keep intermediate messages to divide later std::map messages; @@ -60,45 +69,59 @@ public: // eliminate each neighbor in this star graph one by one BOOST_FOREACH(Key neighbor, starGraphs_.at(key).correctedBeliefIndices | boost::adaptors::map_keys) { DiscreteFactor::shared_ptr factor; - boost::tie(dummyCond, factor) = EliminateDiscrete(*starGraphs_.at(key).star, - Ordering(list_of(neighbor))); + boost::tie(dummyCond, factor) = EliminateDiscrete( + *starGraphs_.at(key).star, Ordering(list_of(neighbor))); // store the new factor into messages messages.insert(make_pair(neighbor, factor)); // Belief is the product of all messages and the unary factor // Incorporate new the factor to belief - belief_.at(key) = belief_.at(key) * (*boost::dynamic_pointer_cast(factor)); + beliefAtKey = beliefAtKey + * (*boost::dynamic_pointer_cast(factor)); } + beliefs->push_back(beliefAtKey); // Update the corrected belief for the neighbor's stargraph BOOST_FOREACH(Key neighbor, starGraphs_.at(key).correctedBeliefIndices | boost::adaptors::map_keys) { - DecisionTreeFactor correctedBelief = belief_.at(key) - / (*boost::dynamic_pointer_cast(messages.at(neighbor))); + DecisionTreeFactor correctedBelief = beliefAtKey + / (*boost::dynamic_pointer_cast( + messages.at(neighbor))); size_t beliefIndex = starGraphs_.at(neighbor).correctedBeliefIndices.at( key); - starGraphs_.at(neighbor).star->replace(beliefIndex, boost::make_shared(correctedBelief)); + starGraphs_.at(neighbor).star->replace(beliefIndex, + boost::make_shared(correctedBelief)); } } + + return beliefs; } private: - void initialize(const std::map& allDiscreteKeys) { - BOOST_FOREACH(Key key, varIndex_ | boost::adaptors::map_keys) { + /** + * Build star graphs for each node. + */ + StarGraphs buildStarGraphs(const DiscreteFactorGraph& graph, + const std::map& allDiscreteKeys) const { + StarGraphs starGraphs; + VariableIndex varIndex(graph); ///< access to all factors of each node + BOOST_FOREACH(Key key, varIndex | boost::adaptors::map_keys) { // initialize to multiply with other unary factors later - unary_.insert(make_pair(key, DecisionTreeFactor(allDiscreteKeys.at(key), "1 1"))); + DecisionTreeFactor prodOfUnaries(allDiscreteKeys.at(key), "1 1"); // collect all factors involving this key in the original graph DiscreteFactorGraph::shared_ptr star(new DiscreteFactorGraph()); - BOOST_FOREACH(size_t factorIdx, varIndex_[key]) { - star->push_back(graph_.at(factorIdx)); - if (graph_.at(factorIdx)->size() == 1) { - unary_.at(key) = unary_.at(key) + BOOST_FOREACH(size_t factorIdx, varIndex[key]) { + star->push_back(graph.at(factorIdx)); + + // accumulate unary factors + if (graph.at(factorIdx)->size() == 1) { + prodOfUnaries = prodOfUnaries * (*boost::dynamic_pointer_cast( - graph_.at(factorIdx))); + graph.at(factorIdx))); } } - // add the belief factor for each other variable in this star graph + // add the belief factor for each neighbor variable to this star graph // also record the factor index for later modification FastSet neighbors = star->keys(); neighbors.erase(key); @@ -109,9 +132,12 @@ private: DecisionTreeFactor(allDiscreteKeys.at(neighbor), "1.0 0.0")); correctedBeliefIndices.insert(make_pair(neighbor, star->size() - 1)); } - starGraphs_.insert( - make_pair(key, StarGraph(star, correctedBeliefIndices))); + starGraphs.insert( + make_pair(key, + StarGraph(star, make_shared(prodOfUnaries), + correctedBeliefIndices))); } + return starGraphs; } }; @@ -123,7 +149,7 @@ TEST_UNSAFE(LoopyBelief, construction) { // Map from key to DiscreteKey for building belief factor. // TODO: this is bad! - std::map allKeys = map_list_of(0,C)(1,S)(2,R)(3,W); + std::map allKeys = map_list_of(0, C)(1, S)(2, R)(3, W); // Build graph DecisionTreeFactor pC(C, "0.5 0.5"); @@ -142,7 +168,9 @@ TEST_UNSAFE(LoopyBelief, construction) { // Main loop for (size_t iter = 0; iter < 10; ++iter) { - solver.iterate(); + DiscreteFactorGraph::shared_ptr beliefs = solver.iterate(); + cout << "iteration: " << iter << endl; + beliefs->print(); } }