diff --git a/gtsam_unstable/discrete/tests/testLoopyBelief.cpp b/gtsam_unstable/discrete/tests/testLoopyBelief.cpp index 9d8d3b40c..f6face900 100644 --- a/gtsam_unstable/discrete/tests/testLoopyBelief.cpp +++ b/gtsam_unstable/discrete/tests/testLoopyBelief.cpp @@ -38,8 +38,8 @@ class LoopyBelief { StarGraph(const DiscreteFactorGraph::shared_ptr& _star, const CorrectedBeliefIndices& _beliefIndices, const DecisionTreeFactor::shared_ptr& _unary) : - star(_star), correctedBeliefIndices(_beliefIndices), unary(_unary), varIndex_( - *_star) { + star(_star), correctedBeliefIndices(_beliefIndices), unary(_unary), varIndex_( + *_star) { } void print(const std::string& s = "") const { @@ -64,7 +64,7 @@ public: */ LoopyBelief(const DiscreteFactorGraph& graph, const std::map& allDiscreteKeys) : - starGraphs_(buildStarGraphs(graph, allDiscreteKeys)) { + starGraphs_(buildStarGraphs(graph, allDiscreteKeys)) { } /// print @@ -78,12 +78,13 @@ public: /// One step of belief propagation DiscreteFactorGraph::shared_ptr iterate( const std::map& allDiscreteKeys) { + static const bool debug = false; static DiscreteConditional::shared_ptr dummyCond; // unused by-product of elimination DiscreteFactorGraph::shared_ptr beliefs(new DiscreteFactorGraph()); std::map > allMessages; // Eliminate each star graph BOOST_FOREACH(Key key, starGraphs_ | boost::adaptors::map_keys) { - cout << "***** Node " << key << "*****" << endl; +// cout << "***** Node " << key << "*****" << endl; // initialize belief to the unary factor from the original graph DecisionTreeFactor::shared_ptr beliefAtKey; @@ -96,13 +97,13 @@ public: BOOST_FOREACH(size_t factor, starGraphs_.at(key).varIndex_[neighbor]) { subGraph.push_back(starGraphs_.at(key).star->at(factor)); } - subGraph.print("------- Subgraph:"); + if (debug) subGraph.print("------- Subgraph:"); DiscreteFactor::shared_ptr message; boost::tie(dummyCond, message) = EliminateDiscrete(subGraph, Ordering(list_of(neighbor))); // store the new factor into messages messages.insert(make_pair(neighbor, message)); - message->print("------- Message: "); + if (debug) message->print("------- Message: "); // Belief is the product of all messages and the unary factor // Incorporate new the factor to belief @@ -113,24 +114,27 @@ public: beliefAtKey = make_shared( (*beliefAtKey) - * (*boost::dynamic_pointer_cast( - message))); + * (*boost::dynamic_pointer_cast( + message))); } if (starGraphs_.at(key).unary) beliefAtKey = make_shared( (*beliefAtKey) * (*starGraphs_.at(key).unary)); - beliefAtKey->print("New belief at key: "); + if (debug) beliefAtKey->print("New belief at key: "); // normalize belief double sum = 0.0; - for (size_t v = 0; v((*beliefAtKey)/denomFactor); - beliefAtKey->print("New belief at key normalized: "); + string sumFactorTable; + for (size_t v = 0; v < allDiscreteKeys.at(key).second; ++v) + sumFactorTable = (boost::format("%s %f") % sumFactorTable % sum).str(); + DecisionTreeFactor sumFactor(allDiscreteKeys.at(key), sumFactorTable); + if (debug) sumFactor.print("denomFactor: "); + beliefAtKey = make_shared((*beliefAtKey) / sumFactor); + if (debug) beliefAtKey->print("New belief at key normalized: "); beliefs->push_back(beliefAtKey); allMessages[key] = messages; } @@ -140,19 +144,19 @@ public: BOOST_FOREACH(Key key, starGraphs_ | boost::adaptors::map_keys) { std::map messages = allMessages[key]; BOOST_FOREACH(Key neighbor, starGraphs_.at(key).correctedBeliefIndices | boost::adaptors::map_keys) { - DecisionTreeFactor - correctedBelief = (*boost::dynamic_pointer_cast(beliefs->at(beliefFactors[key].front()))) - / (*boost::dynamic_pointer_cast( - messages.at(neighbor))); - correctedBelief.print("correctedBelief: "); - size_t beliefIndex = - starGraphs_.at(neighbor).correctedBeliefIndices.at(key); + DecisionTreeFactor correctedBelief = (*boost::dynamic_pointer_cast< + DecisionTreeFactor>(beliefs->at(beliefFactors[key].front()))) + / (*boost::dynamic_pointer_cast( + messages.at(neighbor))); + if (debug) correctedBelief.print("correctedBelief: "); + size_t beliefIndex = starGraphs_.at(neighbor).correctedBeliefIndices.at( + key); starGraphs_.at(neighbor).star->replace(beliefIndex, boost::make_shared(correctedBelief)); } } - print("After update: "); + if (debug) print("After update: "); return beliefs; } @@ -182,8 +186,8 @@ private: else prodOfUnaries = make_shared( *prodOfUnaries - * (*boost::dynamic_pointer_cast( - graph.at(factorIdx)))); + * (*boost::dynamic_pointer_cast( + graph.at(factorIdx)))); } } @@ -194,10 +198,14 @@ private: CorrectedBeliefIndices correctedBeliefIndices; BOOST_FOREACH(Key neighbor, neighbors) { // TODO: default table for keys with more than 2 values? + string initialBelief; + for (size_t v = 0; v < allDiscreteKeys.at(neighbor).second - 1; ++v) { + initialBelief = initialBelief + "0.0 "; + } + initialBelief = initialBelief + "1.0"; star->push_back( - DecisionTreeFactor(allDiscreteKeys.at(neighbor), "0.0 1.0")); - correctedBeliefIndices.insert( - make_pair(neighbor, star->size() - 1)); + DecisionTreeFactor(allDiscreteKeys.at(neighbor), initialBelief)); + correctedBeliefIndices.insert(make_pair(neighbor, star->size() - 1)); } starGraphs.insert( make_pair(key, @@ -221,7 +229,7 @@ TEST_UNSAFE(LoopyBelief, construction) { DecisionTreeFactor pC(C, "0.5 0.5"); DiscreteConditional pSC(S | C = "0.5/0.5 0.9/0.1"); DiscreteConditional pRC(R | C = "0.8/0.2 0.2/0.8"); - DecisionTreeFactor pSR( S & R, "0.0 0.9 0.9 0.99"); + DecisionTreeFactor pSR(S & R, "0.0 0.9 0.9 0.99"); DiscreteFactorGraph graph; graph.push_back(pC);