support multivalue variables
parent
9bff152dbe
commit
16af82dc86
|
|
@ -38,8 +38,8 @@ class LoopyBelief {
|
|||
StarGraph(const DiscreteFactorGraph::shared_ptr& _star,
|
||||
const CorrectedBeliefIndices& _beliefIndices,
|
||||
const DecisionTreeFactor::shared_ptr& _unary) :
|
||||
star(_star), correctedBeliefIndices(_beliefIndices), unary(_unary), varIndex_(
|
||||
*_star) {
|
||||
star(_star), correctedBeliefIndices(_beliefIndices), unary(_unary), varIndex_(
|
||||
*_star) {
|
||||
}
|
||||
|
||||
void print(const std::string& s = "") const {
|
||||
|
|
@ -64,7 +64,7 @@ public:
|
|||
*/
|
||||
LoopyBelief(const DiscreteFactorGraph& graph,
|
||||
const std::map<Key, DiscreteKey>& allDiscreteKeys) :
|
||||
starGraphs_(buildStarGraphs(graph, allDiscreteKeys)) {
|
||||
starGraphs_(buildStarGraphs(graph, allDiscreteKeys)) {
|
||||
}
|
||||
|
||||
/// print
|
||||
|
|
@ -78,12 +78,13 @@ public:
|
|||
/// One step of belief propagation
|
||||
DiscreteFactorGraph::shared_ptr iterate(
|
||||
const std::map<Key, DiscreteKey>& allDiscreteKeys) {
|
||||
static const bool debug = false;
|
||||
static DiscreteConditional::shared_ptr dummyCond; // unused by-product of elimination
|
||||
DiscreteFactorGraph::shared_ptr beliefs(new DiscreteFactorGraph());
|
||||
std::map<Key, std::map<Key, DiscreteFactor::shared_ptr> > allMessages;
|
||||
// Eliminate each star graph
|
||||
BOOST_FOREACH(Key key, starGraphs_ | boost::adaptors::map_keys) {
|
||||
cout << "***** Node " << key << "*****" << endl;
|
||||
// cout << "***** Node " << key << "*****" << endl;
|
||||
// initialize belief to the unary factor from the original graph
|
||||
DecisionTreeFactor::shared_ptr beliefAtKey;
|
||||
|
||||
|
|
@ -96,13 +97,13 @@ public:
|
|||
BOOST_FOREACH(size_t factor, starGraphs_.at(key).varIndex_[neighbor]) {
|
||||
subGraph.push_back(starGraphs_.at(key).star->at(factor));
|
||||
}
|
||||
subGraph.print("------- Subgraph:");
|
||||
if (debug) subGraph.print("------- Subgraph:");
|
||||
DiscreteFactor::shared_ptr message;
|
||||
boost::tie(dummyCond, message) = EliminateDiscrete(subGraph,
|
||||
Ordering(list_of(neighbor)));
|
||||
// store the new factor into messages
|
||||
messages.insert(make_pair(neighbor, message));
|
||||
message->print("------- Message: ");
|
||||
if (debug) message->print("------- Message: ");
|
||||
|
||||
// Belief is the product of all messages and the unary factor
|
||||
// Incorporate new the factor to belief
|
||||
|
|
@ -113,24 +114,27 @@ public:
|
|||
beliefAtKey =
|
||||
make_shared<DecisionTreeFactor>(
|
||||
(*beliefAtKey)
|
||||
* (*boost::dynamic_pointer_cast<DecisionTreeFactor>(
|
||||
message)));
|
||||
* (*boost::dynamic_pointer_cast<DecisionTreeFactor>(
|
||||
message)));
|
||||
}
|
||||
if (starGraphs_.at(key).unary)
|
||||
beliefAtKey = make_shared<DecisionTreeFactor>(
|
||||
(*beliefAtKey) * (*starGraphs_.at(key).unary));
|
||||
beliefAtKey->print("New belief at key: ");
|
||||
if (debug) beliefAtKey->print("New belief at key: ");
|
||||
// normalize belief
|
||||
double sum = 0.0;
|
||||
for (size_t v = 0; v<allDiscreteKeys.at(key).second; ++v) {
|
||||
for (size_t v = 0; v < allDiscreteKeys.at(key).second; ++v) {
|
||||
DiscreteFactor::Values val;
|
||||
val[key] = v;
|
||||
sum += (*beliefAtKey)(val);
|
||||
}
|
||||
DecisionTreeFactor denomFactor(allDiscreteKeys.at(key), (boost::format("%f %f")%sum%sum).str());
|
||||
denomFactor.print("denomFactor: ");
|
||||
beliefAtKey = make_shared<DecisionTreeFactor>((*beliefAtKey)/denomFactor);
|
||||
beliefAtKey->print("New belief at key normalized: ");
|
||||
string sumFactorTable;
|
||||
for (size_t v = 0; v < allDiscreteKeys.at(key).second; ++v)
|
||||
sumFactorTable = (boost::format("%s %f") % sumFactorTable % sum).str();
|
||||
DecisionTreeFactor sumFactor(allDiscreteKeys.at(key), sumFactorTable);
|
||||
if (debug) sumFactor.print("denomFactor: ");
|
||||
beliefAtKey = make_shared<DecisionTreeFactor>((*beliefAtKey) / sumFactor);
|
||||
if (debug) beliefAtKey->print("New belief at key normalized: ");
|
||||
beliefs->push_back(beliefAtKey);
|
||||
allMessages[key] = messages;
|
||||
}
|
||||
|
|
@ -140,19 +144,19 @@ public:
|
|||
BOOST_FOREACH(Key key, starGraphs_ | boost::adaptors::map_keys) {
|
||||
std::map<Key, DiscreteFactor::shared_ptr> messages = allMessages[key];
|
||||
BOOST_FOREACH(Key neighbor, starGraphs_.at(key).correctedBeliefIndices | boost::adaptors::map_keys) {
|
||||
DecisionTreeFactor
|
||||
correctedBelief = (*boost::dynamic_pointer_cast<DecisionTreeFactor>(beliefs->at(beliefFactors[key].front())))
|
||||
/ (*boost::dynamic_pointer_cast<DecisionTreeFactor>(
|
||||
messages.at(neighbor)));
|
||||
correctedBelief.print("correctedBelief: ");
|
||||
size_t beliefIndex =
|
||||
starGraphs_.at(neighbor).correctedBeliefIndices.at(key);
|
||||
DecisionTreeFactor correctedBelief = (*boost::dynamic_pointer_cast<
|
||||
DecisionTreeFactor>(beliefs->at(beliefFactors[key].front())))
|
||||
/ (*boost::dynamic_pointer_cast<DecisionTreeFactor>(
|
||||
messages.at(neighbor)));
|
||||
if (debug) correctedBelief.print("correctedBelief: ");
|
||||
size_t beliefIndex = starGraphs_.at(neighbor).correctedBeliefIndices.at(
|
||||
key);
|
||||
starGraphs_.at(neighbor).star->replace(beliefIndex,
|
||||
boost::make_shared<DecisionTreeFactor>(correctedBelief));
|
||||
}
|
||||
}
|
||||
|
||||
print("After update: ");
|
||||
if (debug) print("After update: ");
|
||||
|
||||
return beliefs;
|
||||
}
|
||||
|
|
@ -182,8 +186,8 @@ private:
|
|||
else
|
||||
prodOfUnaries = make_shared<DecisionTreeFactor>(
|
||||
*prodOfUnaries
|
||||
* (*boost::dynamic_pointer_cast<DecisionTreeFactor>(
|
||||
graph.at(factorIdx))));
|
||||
* (*boost::dynamic_pointer_cast<DecisionTreeFactor>(
|
||||
graph.at(factorIdx))));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -194,10 +198,14 @@ private:
|
|||
CorrectedBeliefIndices correctedBeliefIndices;
|
||||
BOOST_FOREACH(Key neighbor, neighbors) {
|
||||
// TODO: default table for keys with more than 2 values?
|
||||
string initialBelief;
|
||||
for (size_t v = 0; v < allDiscreteKeys.at(neighbor).second - 1; ++v) {
|
||||
initialBelief = initialBelief + "0.0 ";
|
||||
}
|
||||
initialBelief = initialBelief + "1.0";
|
||||
star->push_back(
|
||||
DecisionTreeFactor(allDiscreteKeys.at(neighbor), "0.0 1.0"));
|
||||
correctedBeliefIndices.insert(
|
||||
make_pair(neighbor, star->size() - 1));
|
||||
DecisionTreeFactor(allDiscreteKeys.at(neighbor), initialBelief));
|
||||
correctedBeliefIndices.insert(make_pair(neighbor, star->size() - 1));
|
||||
}
|
||||
starGraphs.insert(
|
||||
make_pair(key,
|
||||
|
|
@ -221,7 +229,7 @@ TEST_UNSAFE(LoopyBelief, construction) {
|
|||
DecisionTreeFactor pC(C, "0.5 0.5");
|
||||
DiscreteConditional pSC(S | C = "0.5/0.5 0.9/0.1");
|
||||
DiscreteConditional pRC(R | C = "0.8/0.2 0.2/0.8");
|
||||
DecisionTreeFactor pSR( S & R, "0.0 0.9 0.9 0.99");
|
||||
DecisionTreeFactor pSR(S & R, "0.0 0.9 0.9 0.99");
|
||||
|
||||
DiscreteFactorGraph graph;
|
||||
graph.push_back(pC);
|
||||
|
|
|
|||
Loading…
Reference in New Issue