support multivalue variables

release/4.3a0
Duy-Nguyen Ta 2013-10-15 18:25:05 +00:00
parent 9bff152dbe
commit 16af82dc86
1 changed files with 36 additions and 28 deletions

View File

@ -78,12 +78,13 @@ public:
/// One step of belief propagation
DiscreteFactorGraph::shared_ptr iterate(
const std::map<Key, DiscreteKey>& allDiscreteKeys) {
static const bool debug = false;
static DiscreteConditional::shared_ptr dummyCond; // unused by-product of elimination
DiscreteFactorGraph::shared_ptr beliefs(new DiscreteFactorGraph());
std::map<Key, std::map<Key, DiscreteFactor::shared_ptr> > allMessages;
// Eliminate each star graph
BOOST_FOREACH(Key key, starGraphs_ | boost::adaptors::map_keys) {
cout << "***** Node " << key << "*****" << endl;
// cout << "***** Node " << key << "*****" << endl;
// initialize belief to the unary factor from the original graph
DecisionTreeFactor::shared_ptr beliefAtKey;
@ -96,13 +97,13 @@ public:
BOOST_FOREACH(size_t factor, starGraphs_.at(key).varIndex_[neighbor]) {
subGraph.push_back(starGraphs_.at(key).star->at(factor));
}
subGraph.print("------- Subgraph:");
if (debug) subGraph.print("------- Subgraph:");
DiscreteFactor::shared_ptr message;
boost::tie(dummyCond, message) = EliminateDiscrete(subGraph,
Ordering(list_of(neighbor)));
// store the new factor into messages
messages.insert(make_pair(neighbor, message));
message->print("------- Message: ");
if (debug) message->print("------- Message: ");
// Belief is the product of all messages and the unary factor
// Incorporate new the factor to belief
@ -119,18 +120,21 @@ public:
if (starGraphs_.at(key).unary)
beliefAtKey = make_shared<DecisionTreeFactor>(
(*beliefAtKey) * (*starGraphs_.at(key).unary));
beliefAtKey->print("New belief at key: ");
if (debug) beliefAtKey->print("New belief at key: ");
// normalize belief
double sum = 0.0;
for (size_t v = 0; v<allDiscreteKeys.at(key).second; ++v) {
for (size_t v = 0; v < allDiscreteKeys.at(key).second; ++v) {
DiscreteFactor::Values val;
val[key] = v;
sum += (*beliefAtKey)(val);
}
DecisionTreeFactor denomFactor(allDiscreteKeys.at(key), (boost::format("%f %f")%sum%sum).str());
denomFactor.print("denomFactor: ");
beliefAtKey = make_shared<DecisionTreeFactor>((*beliefAtKey)/denomFactor);
beliefAtKey->print("New belief at key normalized: ");
string sumFactorTable;
for (size_t v = 0; v < allDiscreteKeys.at(key).second; ++v)
sumFactorTable = (boost::format("%s %f") % sumFactorTable % sum).str();
DecisionTreeFactor sumFactor(allDiscreteKeys.at(key), sumFactorTable);
if (debug) sumFactor.print("denomFactor: ");
beliefAtKey = make_shared<DecisionTreeFactor>((*beliefAtKey) / sumFactor);
if (debug) beliefAtKey->print("New belief at key normalized: ");
beliefs->push_back(beliefAtKey);
allMessages[key] = messages;
}
@ -140,19 +144,19 @@ public:
BOOST_FOREACH(Key key, starGraphs_ | boost::adaptors::map_keys) {
std::map<Key, DiscreteFactor::shared_ptr> messages = allMessages[key];
BOOST_FOREACH(Key neighbor, starGraphs_.at(key).correctedBeliefIndices | boost::adaptors::map_keys) {
DecisionTreeFactor
correctedBelief = (*boost::dynamic_pointer_cast<DecisionTreeFactor>(beliefs->at(beliefFactors[key].front())))
DecisionTreeFactor correctedBelief = (*boost::dynamic_pointer_cast<
DecisionTreeFactor>(beliefs->at(beliefFactors[key].front())))
/ (*boost::dynamic_pointer_cast<DecisionTreeFactor>(
messages.at(neighbor)));
correctedBelief.print("correctedBelief: ");
size_t beliefIndex =
starGraphs_.at(neighbor).correctedBeliefIndices.at(key);
if (debug) correctedBelief.print("correctedBelief: ");
size_t beliefIndex = starGraphs_.at(neighbor).correctedBeliefIndices.at(
key);
starGraphs_.at(neighbor).star->replace(beliefIndex,
boost::make_shared<DecisionTreeFactor>(correctedBelief));
}
}
print("After update: ");
if (debug) print("After update: ");
return beliefs;
}
@ -194,10 +198,14 @@ private:
CorrectedBeliefIndices correctedBeliefIndices;
BOOST_FOREACH(Key neighbor, neighbors) {
// TODO: default table for keys with more than 2 values?
string initialBelief;
for (size_t v = 0; v < allDiscreteKeys.at(neighbor).second - 1; ++v) {
initialBelief = initialBelief + "0.0 ";
}
initialBelief = initialBelief + "1.0";
star->push_back(
DecisionTreeFactor(allDiscreteKeys.at(neighbor), "0.0 1.0"));
correctedBeliefIndices.insert(
make_pair(neighbor, star->size() - 1));
DecisionTreeFactor(allDiscreteKeys.at(neighbor), initialBelief));
correctedBeliefIndices.insert(make_pair(neighbor, star->size() - 1));
}
starGraphs.insert(
make_pair(key,
@ -221,7 +229,7 @@ TEST_UNSAFE(LoopyBelief, construction) {
DecisionTreeFactor pC(C, "0.5 0.5");
DiscreteConditional pSC(S | C = "0.5/0.5 0.9/0.1");
DiscreteConditional pRC(R | C = "0.8/0.2 0.2/0.8");
DecisionTreeFactor pSR( S & R, "0.0 0.9 0.9 0.99");
DecisionTreeFactor pSR(S & R, "0.0 0.9 0.9 0.99");
DiscreteFactorGraph graph;
graph.push_back(pC);