support multivalue variables
parent
9bff152dbe
commit
16af82dc86
|
|
@ -38,8 +38,8 @@ class LoopyBelief {
|
||||||
StarGraph(const DiscreteFactorGraph::shared_ptr& _star,
|
StarGraph(const DiscreteFactorGraph::shared_ptr& _star,
|
||||||
const CorrectedBeliefIndices& _beliefIndices,
|
const CorrectedBeliefIndices& _beliefIndices,
|
||||||
const DecisionTreeFactor::shared_ptr& _unary) :
|
const DecisionTreeFactor::shared_ptr& _unary) :
|
||||||
star(_star), correctedBeliefIndices(_beliefIndices), unary(_unary), varIndex_(
|
star(_star), correctedBeliefIndices(_beliefIndices), unary(_unary), varIndex_(
|
||||||
*_star) {
|
*_star) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void print(const std::string& s = "") const {
|
void print(const std::string& s = "") const {
|
||||||
|
|
@ -64,7 +64,7 @@ public:
|
||||||
*/
|
*/
|
||||||
LoopyBelief(const DiscreteFactorGraph& graph,
|
LoopyBelief(const DiscreteFactorGraph& graph,
|
||||||
const std::map<Key, DiscreteKey>& allDiscreteKeys) :
|
const std::map<Key, DiscreteKey>& allDiscreteKeys) :
|
||||||
starGraphs_(buildStarGraphs(graph, allDiscreteKeys)) {
|
starGraphs_(buildStarGraphs(graph, allDiscreteKeys)) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// print
|
/// print
|
||||||
|
|
@ -78,12 +78,13 @@ public:
|
||||||
/// One step of belief propagation
|
/// One step of belief propagation
|
||||||
DiscreteFactorGraph::shared_ptr iterate(
|
DiscreteFactorGraph::shared_ptr iterate(
|
||||||
const std::map<Key, DiscreteKey>& allDiscreteKeys) {
|
const std::map<Key, DiscreteKey>& allDiscreteKeys) {
|
||||||
|
static const bool debug = false;
|
||||||
static DiscreteConditional::shared_ptr dummyCond; // unused by-product of elimination
|
static DiscreteConditional::shared_ptr dummyCond; // unused by-product of elimination
|
||||||
DiscreteFactorGraph::shared_ptr beliefs(new DiscreteFactorGraph());
|
DiscreteFactorGraph::shared_ptr beliefs(new DiscreteFactorGraph());
|
||||||
std::map<Key, std::map<Key, DiscreteFactor::shared_ptr> > allMessages;
|
std::map<Key, std::map<Key, DiscreteFactor::shared_ptr> > allMessages;
|
||||||
// Eliminate each star graph
|
// Eliminate each star graph
|
||||||
BOOST_FOREACH(Key key, starGraphs_ | boost::adaptors::map_keys) {
|
BOOST_FOREACH(Key key, starGraphs_ | boost::adaptors::map_keys) {
|
||||||
cout << "***** Node " << key << "*****" << endl;
|
// cout << "***** Node " << key << "*****" << endl;
|
||||||
// initialize belief to the unary factor from the original graph
|
// initialize belief to the unary factor from the original graph
|
||||||
DecisionTreeFactor::shared_ptr beliefAtKey;
|
DecisionTreeFactor::shared_ptr beliefAtKey;
|
||||||
|
|
||||||
|
|
@ -96,13 +97,13 @@ public:
|
||||||
BOOST_FOREACH(size_t factor, starGraphs_.at(key).varIndex_[neighbor]) {
|
BOOST_FOREACH(size_t factor, starGraphs_.at(key).varIndex_[neighbor]) {
|
||||||
subGraph.push_back(starGraphs_.at(key).star->at(factor));
|
subGraph.push_back(starGraphs_.at(key).star->at(factor));
|
||||||
}
|
}
|
||||||
subGraph.print("------- Subgraph:");
|
if (debug) subGraph.print("------- Subgraph:");
|
||||||
DiscreteFactor::shared_ptr message;
|
DiscreteFactor::shared_ptr message;
|
||||||
boost::tie(dummyCond, message) = EliminateDiscrete(subGraph,
|
boost::tie(dummyCond, message) = EliminateDiscrete(subGraph,
|
||||||
Ordering(list_of(neighbor)));
|
Ordering(list_of(neighbor)));
|
||||||
// store the new factor into messages
|
// store the new factor into messages
|
||||||
messages.insert(make_pair(neighbor, message));
|
messages.insert(make_pair(neighbor, message));
|
||||||
message->print("------- Message: ");
|
if (debug) message->print("------- Message: ");
|
||||||
|
|
||||||
// Belief is the product of all messages and the unary factor
|
// Belief is the product of all messages and the unary factor
|
||||||
// Incorporate new the factor to belief
|
// Incorporate new the factor to belief
|
||||||
|
|
@ -113,24 +114,27 @@ public:
|
||||||
beliefAtKey =
|
beliefAtKey =
|
||||||
make_shared<DecisionTreeFactor>(
|
make_shared<DecisionTreeFactor>(
|
||||||
(*beliefAtKey)
|
(*beliefAtKey)
|
||||||
* (*boost::dynamic_pointer_cast<DecisionTreeFactor>(
|
* (*boost::dynamic_pointer_cast<DecisionTreeFactor>(
|
||||||
message)));
|
message)));
|
||||||
}
|
}
|
||||||
if (starGraphs_.at(key).unary)
|
if (starGraphs_.at(key).unary)
|
||||||
beliefAtKey = make_shared<DecisionTreeFactor>(
|
beliefAtKey = make_shared<DecisionTreeFactor>(
|
||||||
(*beliefAtKey) * (*starGraphs_.at(key).unary));
|
(*beliefAtKey) * (*starGraphs_.at(key).unary));
|
||||||
beliefAtKey->print("New belief at key: ");
|
if (debug) beliefAtKey->print("New belief at key: ");
|
||||||
// normalize belief
|
// normalize belief
|
||||||
double sum = 0.0;
|
double sum = 0.0;
|
||||||
for (size_t v = 0; v<allDiscreteKeys.at(key).second; ++v) {
|
for (size_t v = 0; v < allDiscreteKeys.at(key).second; ++v) {
|
||||||
DiscreteFactor::Values val;
|
DiscreteFactor::Values val;
|
||||||
val[key] = v;
|
val[key] = v;
|
||||||
sum += (*beliefAtKey)(val);
|
sum += (*beliefAtKey)(val);
|
||||||
}
|
}
|
||||||
DecisionTreeFactor denomFactor(allDiscreteKeys.at(key), (boost::format("%f %f")%sum%sum).str());
|
string sumFactorTable;
|
||||||
denomFactor.print("denomFactor: ");
|
for (size_t v = 0; v < allDiscreteKeys.at(key).second; ++v)
|
||||||
beliefAtKey = make_shared<DecisionTreeFactor>((*beliefAtKey)/denomFactor);
|
sumFactorTable = (boost::format("%s %f") % sumFactorTable % sum).str();
|
||||||
beliefAtKey->print("New belief at key normalized: ");
|
DecisionTreeFactor sumFactor(allDiscreteKeys.at(key), sumFactorTable);
|
||||||
|
if (debug) sumFactor.print("denomFactor: ");
|
||||||
|
beliefAtKey = make_shared<DecisionTreeFactor>((*beliefAtKey) / sumFactor);
|
||||||
|
if (debug) beliefAtKey->print("New belief at key normalized: ");
|
||||||
beliefs->push_back(beliefAtKey);
|
beliefs->push_back(beliefAtKey);
|
||||||
allMessages[key] = messages;
|
allMessages[key] = messages;
|
||||||
}
|
}
|
||||||
|
|
@ -140,19 +144,19 @@ public:
|
||||||
BOOST_FOREACH(Key key, starGraphs_ | boost::adaptors::map_keys) {
|
BOOST_FOREACH(Key key, starGraphs_ | boost::adaptors::map_keys) {
|
||||||
std::map<Key, DiscreteFactor::shared_ptr> messages = allMessages[key];
|
std::map<Key, DiscreteFactor::shared_ptr> messages = allMessages[key];
|
||||||
BOOST_FOREACH(Key neighbor, starGraphs_.at(key).correctedBeliefIndices | boost::adaptors::map_keys) {
|
BOOST_FOREACH(Key neighbor, starGraphs_.at(key).correctedBeliefIndices | boost::adaptors::map_keys) {
|
||||||
DecisionTreeFactor
|
DecisionTreeFactor correctedBelief = (*boost::dynamic_pointer_cast<
|
||||||
correctedBelief = (*boost::dynamic_pointer_cast<DecisionTreeFactor>(beliefs->at(beliefFactors[key].front())))
|
DecisionTreeFactor>(beliefs->at(beliefFactors[key].front())))
|
||||||
/ (*boost::dynamic_pointer_cast<DecisionTreeFactor>(
|
/ (*boost::dynamic_pointer_cast<DecisionTreeFactor>(
|
||||||
messages.at(neighbor)));
|
messages.at(neighbor)));
|
||||||
correctedBelief.print("correctedBelief: ");
|
if (debug) correctedBelief.print("correctedBelief: ");
|
||||||
size_t beliefIndex =
|
size_t beliefIndex = starGraphs_.at(neighbor).correctedBeliefIndices.at(
|
||||||
starGraphs_.at(neighbor).correctedBeliefIndices.at(key);
|
key);
|
||||||
starGraphs_.at(neighbor).star->replace(beliefIndex,
|
starGraphs_.at(neighbor).star->replace(beliefIndex,
|
||||||
boost::make_shared<DecisionTreeFactor>(correctedBelief));
|
boost::make_shared<DecisionTreeFactor>(correctedBelief));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
print("After update: ");
|
if (debug) print("After update: ");
|
||||||
|
|
||||||
return beliefs;
|
return beliefs;
|
||||||
}
|
}
|
||||||
|
|
@ -182,8 +186,8 @@ private:
|
||||||
else
|
else
|
||||||
prodOfUnaries = make_shared<DecisionTreeFactor>(
|
prodOfUnaries = make_shared<DecisionTreeFactor>(
|
||||||
*prodOfUnaries
|
*prodOfUnaries
|
||||||
* (*boost::dynamic_pointer_cast<DecisionTreeFactor>(
|
* (*boost::dynamic_pointer_cast<DecisionTreeFactor>(
|
||||||
graph.at(factorIdx))));
|
graph.at(factorIdx))));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -194,10 +198,14 @@ private:
|
||||||
CorrectedBeliefIndices correctedBeliefIndices;
|
CorrectedBeliefIndices correctedBeliefIndices;
|
||||||
BOOST_FOREACH(Key neighbor, neighbors) {
|
BOOST_FOREACH(Key neighbor, neighbors) {
|
||||||
// TODO: default table for keys with more than 2 values?
|
// TODO: default table for keys with more than 2 values?
|
||||||
|
string initialBelief;
|
||||||
|
for (size_t v = 0; v < allDiscreteKeys.at(neighbor).second - 1; ++v) {
|
||||||
|
initialBelief = initialBelief + "0.0 ";
|
||||||
|
}
|
||||||
|
initialBelief = initialBelief + "1.0";
|
||||||
star->push_back(
|
star->push_back(
|
||||||
DecisionTreeFactor(allDiscreteKeys.at(neighbor), "0.0 1.0"));
|
DecisionTreeFactor(allDiscreteKeys.at(neighbor), initialBelief));
|
||||||
correctedBeliefIndices.insert(
|
correctedBeliefIndices.insert(make_pair(neighbor, star->size() - 1));
|
||||||
make_pair(neighbor, star->size() - 1));
|
|
||||||
}
|
}
|
||||||
starGraphs.insert(
|
starGraphs.insert(
|
||||||
make_pair(key,
|
make_pair(key,
|
||||||
|
|
@ -221,7 +229,7 @@ TEST_UNSAFE(LoopyBelief, construction) {
|
||||||
DecisionTreeFactor pC(C, "0.5 0.5");
|
DecisionTreeFactor pC(C, "0.5 0.5");
|
||||||
DiscreteConditional pSC(S | C = "0.5/0.5 0.9/0.1");
|
DiscreteConditional pSC(S | C = "0.5/0.5 0.9/0.1");
|
||||||
DiscreteConditional pRC(R | C = "0.8/0.2 0.2/0.8");
|
DiscreteConditional pRC(R | C = "0.8/0.2 0.2/0.8");
|
||||||
DecisionTreeFactor pSR( S & R, "0.0 0.9 0.9 0.99");
|
DecisionTreeFactor pSR(S & R, "0.0 0.9 0.9 0.99");
|
||||||
|
|
||||||
DiscreteFactorGraph graph;
|
DiscreteFactorGraph graph;
|
||||||
graph.push_back(pC);
|
graph.push_back(pC);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue