support multivalue variables

release/4.3a0
Duy-Nguyen Ta 2013-10-15 18:25:05 +00:00
parent 9bff152dbe
commit 16af82dc86
1 changed files with 36 additions and 28 deletions

View File

@ -38,8 +38,8 @@ class LoopyBelief {
StarGraph(const DiscreteFactorGraph::shared_ptr& _star, StarGraph(const DiscreteFactorGraph::shared_ptr& _star,
const CorrectedBeliefIndices& _beliefIndices, const CorrectedBeliefIndices& _beliefIndices,
const DecisionTreeFactor::shared_ptr& _unary) : const DecisionTreeFactor::shared_ptr& _unary) :
star(_star), correctedBeliefIndices(_beliefIndices), unary(_unary), varIndex_( star(_star), correctedBeliefIndices(_beliefIndices), unary(_unary), varIndex_(
*_star) { *_star) {
} }
void print(const std::string& s = "") const { void print(const std::string& s = "") const {
@ -64,7 +64,7 @@ public:
*/ */
LoopyBelief(const DiscreteFactorGraph& graph, LoopyBelief(const DiscreteFactorGraph& graph,
const std::map<Key, DiscreteKey>& allDiscreteKeys) : const std::map<Key, DiscreteKey>& allDiscreteKeys) :
starGraphs_(buildStarGraphs(graph, allDiscreteKeys)) { starGraphs_(buildStarGraphs(graph, allDiscreteKeys)) {
} }
/// print /// print
@ -78,12 +78,13 @@ public:
/// One step of belief propagation /// One step of belief propagation
DiscreteFactorGraph::shared_ptr iterate( DiscreteFactorGraph::shared_ptr iterate(
const std::map<Key, DiscreteKey>& allDiscreteKeys) { const std::map<Key, DiscreteKey>& allDiscreteKeys) {
static const bool debug = false;
static DiscreteConditional::shared_ptr dummyCond; // unused by-product of elimination static DiscreteConditional::shared_ptr dummyCond; // unused by-product of elimination
DiscreteFactorGraph::shared_ptr beliefs(new DiscreteFactorGraph()); DiscreteFactorGraph::shared_ptr beliefs(new DiscreteFactorGraph());
std::map<Key, std::map<Key, DiscreteFactor::shared_ptr> > allMessages; std::map<Key, std::map<Key, DiscreteFactor::shared_ptr> > allMessages;
// Eliminate each star graph // Eliminate each star graph
BOOST_FOREACH(Key key, starGraphs_ | boost::adaptors::map_keys) { BOOST_FOREACH(Key key, starGraphs_ | boost::adaptors::map_keys) {
cout << "***** Node " << key << "*****" << endl; // cout << "***** Node " << key << "*****" << endl;
// initialize belief to the unary factor from the original graph // initialize belief to the unary factor from the original graph
DecisionTreeFactor::shared_ptr beliefAtKey; DecisionTreeFactor::shared_ptr beliefAtKey;
@ -96,13 +97,13 @@ public:
BOOST_FOREACH(size_t factor, starGraphs_.at(key).varIndex_[neighbor]) { BOOST_FOREACH(size_t factor, starGraphs_.at(key).varIndex_[neighbor]) {
subGraph.push_back(starGraphs_.at(key).star->at(factor)); subGraph.push_back(starGraphs_.at(key).star->at(factor));
} }
subGraph.print("------- Subgraph:"); if (debug) subGraph.print("------- Subgraph:");
DiscreteFactor::shared_ptr message; DiscreteFactor::shared_ptr message;
boost::tie(dummyCond, message) = EliminateDiscrete(subGraph, boost::tie(dummyCond, message) = EliminateDiscrete(subGraph,
Ordering(list_of(neighbor))); Ordering(list_of(neighbor)));
// store the new factor into messages // store the new factor into messages
messages.insert(make_pair(neighbor, message)); messages.insert(make_pair(neighbor, message));
message->print("------- Message: "); if (debug) message->print("------- Message: ");
// Belief is the product of all messages and the unary factor // Belief is the product of all messages and the unary factor
// Incorporate new the factor to belief // Incorporate new the factor to belief
@ -113,24 +114,27 @@ public:
beliefAtKey = beliefAtKey =
make_shared<DecisionTreeFactor>( make_shared<DecisionTreeFactor>(
(*beliefAtKey) (*beliefAtKey)
* (*boost::dynamic_pointer_cast<DecisionTreeFactor>( * (*boost::dynamic_pointer_cast<DecisionTreeFactor>(
message))); message)));
} }
if (starGraphs_.at(key).unary) if (starGraphs_.at(key).unary)
beliefAtKey = make_shared<DecisionTreeFactor>( beliefAtKey = make_shared<DecisionTreeFactor>(
(*beliefAtKey) * (*starGraphs_.at(key).unary)); (*beliefAtKey) * (*starGraphs_.at(key).unary));
beliefAtKey->print("New belief at key: "); if (debug) beliefAtKey->print("New belief at key: ");
// normalize belief // normalize belief
double sum = 0.0; double sum = 0.0;
for (size_t v = 0; v<allDiscreteKeys.at(key).second; ++v) { for (size_t v = 0; v < allDiscreteKeys.at(key).second; ++v) {
DiscreteFactor::Values val; DiscreteFactor::Values val;
val[key] = v; val[key] = v;
sum += (*beliefAtKey)(val); sum += (*beliefAtKey)(val);
} }
DecisionTreeFactor denomFactor(allDiscreteKeys.at(key), (boost::format("%f %f")%sum%sum).str()); string sumFactorTable;
denomFactor.print("denomFactor: "); for (size_t v = 0; v < allDiscreteKeys.at(key).second; ++v)
beliefAtKey = make_shared<DecisionTreeFactor>((*beliefAtKey)/denomFactor); sumFactorTable = (boost::format("%s %f") % sumFactorTable % sum).str();
beliefAtKey->print("New belief at key normalized: "); DecisionTreeFactor sumFactor(allDiscreteKeys.at(key), sumFactorTable);
if (debug) sumFactor.print("denomFactor: ");
beliefAtKey = make_shared<DecisionTreeFactor>((*beliefAtKey) / sumFactor);
if (debug) beliefAtKey->print("New belief at key normalized: ");
beliefs->push_back(beliefAtKey); beliefs->push_back(beliefAtKey);
allMessages[key] = messages; allMessages[key] = messages;
} }
@ -140,19 +144,19 @@ public:
BOOST_FOREACH(Key key, starGraphs_ | boost::adaptors::map_keys) { BOOST_FOREACH(Key key, starGraphs_ | boost::adaptors::map_keys) {
std::map<Key, DiscreteFactor::shared_ptr> messages = allMessages[key]; std::map<Key, DiscreteFactor::shared_ptr> messages = allMessages[key];
BOOST_FOREACH(Key neighbor, starGraphs_.at(key).correctedBeliefIndices | boost::adaptors::map_keys) { BOOST_FOREACH(Key neighbor, starGraphs_.at(key).correctedBeliefIndices | boost::adaptors::map_keys) {
DecisionTreeFactor DecisionTreeFactor correctedBelief = (*boost::dynamic_pointer_cast<
correctedBelief = (*boost::dynamic_pointer_cast<DecisionTreeFactor>(beliefs->at(beliefFactors[key].front()))) DecisionTreeFactor>(beliefs->at(beliefFactors[key].front())))
/ (*boost::dynamic_pointer_cast<DecisionTreeFactor>( / (*boost::dynamic_pointer_cast<DecisionTreeFactor>(
messages.at(neighbor))); messages.at(neighbor)));
correctedBelief.print("correctedBelief: "); if (debug) correctedBelief.print("correctedBelief: ");
size_t beliefIndex = size_t beliefIndex = starGraphs_.at(neighbor).correctedBeliefIndices.at(
starGraphs_.at(neighbor).correctedBeliefIndices.at(key); key);
starGraphs_.at(neighbor).star->replace(beliefIndex, starGraphs_.at(neighbor).star->replace(beliefIndex,
boost::make_shared<DecisionTreeFactor>(correctedBelief)); boost::make_shared<DecisionTreeFactor>(correctedBelief));
} }
} }
print("After update: "); if (debug) print("After update: ");
return beliefs; return beliefs;
} }
@ -182,8 +186,8 @@ private:
else else
prodOfUnaries = make_shared<DecisionTreeFactor>( prodOfUnaries = make_shared<DecisionTreeFactor>(
*prodOfUnaries *prodOfUnaries
* (*boost::dynamic_pointer_cast<DecisionTreeFactor>( * (*boost::dynamic_pointer_cast<DecisionTreeFactor>(
graph.at(factorIdx)))); graph.at(factorIdx))));
} }
} }
@ -194,10 +198,14 @@ private:
CorrectedBeliefIndices correctedBeliefIndices; CorrectedBeliefIndices correctedBeliefIndices;
BOOST_FOREACH(Key neighbor, neighbors) { BOOST_FOREACH(Key neighbor, neighbors) {
// TODO: default table for keys with more than 2 values? // TODO: default table for keys with more than 2 values?
string initialBelief;
for (size_t v = 0; v < allDiscreteKeys.at(neighbor).second - 1; ++v) {
initialBelief = initialBelief + "0.0 ";
}
initialBelief = initialBelief + "1.0";
star->push_back( star->push_back(
DecisionTreeFactor(allDiscreteKeys.at(neighbor), "0.0 1.0")); DecisionTreeFactor(allDiscreteKeys.at(neighbor), initialBelief));
correctedBeliefIndices.insert( correctedBeliefIndices.insert(make_pair(neighbor, star->size() - 1));
make_pair(neighbor, star->size() - 1));
} }
starGraphs.insert( starGraphs.insert(
make_pair(key, make_pair(key,
@ -221,7 +229,7 @@ TEST_UNSAFE(LoopyBelief, construction) {
DecisionTreeFactor pC(C, "0.5 0.5"); DecisionTreeFactor pC(C, "0.5 0.5");
DiscreteConditional pSC(S | C = "0.5/0.5 0.9/0.1"); DiscreteConditional pSC(S | C = "0.5/0.5 0.9/0.1");
DiscreteConditional pRC(R | C = "0.8/0.2 0.2/0.8"); DiscreteConditional pRC(R | C = "0.8/0.2 0.2/0.8");
DecisionTreeFactor pSR( S & R, "0.0 0.9 0.9 0.99"); DecisionTreeFactor pSR(S & R, "0.0 0.9 0.9 0.99");
DiscreteFactorGraph graph; DiscreteFactorGraph graph;
graph.push_back(pC); graph.push_back(pC);