refactor and add more comments
parent
1f293294fd
commit
8be7363a01
|
@ -15,44 +15,53 @@
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
using namespace boost;
|
||||||
using namespace boost::assign;
|
using namespace boost::assign;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Loopy belief solver for graphs with only binary and unary factors
|
||||||
|
*/
|
||||||
class LoopyBelief {
|
class LoopyBelief {
|
||||||
|
|
||||||
|
/** Star graph struct for each node, containing
|
||||||
|
* - the star graph itself
|
||||||
|
* - the product of original unary factors so we don't have to recompute it later, and
|
||||||
|
* - the factor indices of the corrected belief factors of the neighboring nodes
|
||||||
|
*/
|
||||||
typedef std::map<Key, size_t> CorrectedBeliefIndices;
|
typedef std::map<Key, size_t> CorrectedBeliefIndices;
|
||||||
|
|
||||||
struct StarGraph {
|
struct StarGraph {
|
||||||
DiscreteFactorGraph::shared_ptr star;
|
DiscreteFactorGraph::shared_ptr star;
|
||||||
|
DecisionTreeFactor::shared_ptr unary;
|
||||||
CorrectedBeliefIndices correctedBeliefIndices;
|
CorrectedBeliefIndices correctedBeliefIndices;
|
||||||
StarGraph() {
|
|
||||||
}
|
|
||||||
StarGraph(const DiscreteFactorGraph::shared_ptr& _star,
|
StarGraph(const DiscreteFactorGraph::shared_ptr& _star,
|
||||||
|
const DecisionTreeFactor::shared_ptr& _unary,
|
||||||
const CorrectedBeliefIndices& _beliefIndices) :
|
const CorrectedBeliefIndices& _beliefIndices) :
|
||||||
star(_star), correctedBeliefIndices(_beliefIndices) {
|
star(_star), unary(_unary), correctedBeliefIndices(_beliefIndices) {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const DiscreteFactorGraph& graph_;
|
typedef std::map<Key, StarGraph> StarGraphs;
|
||||||
VariableIndex varIndex_;
|
StarGraphs starGraphs_; ///< star graph at each variable
|
||||||
std::map<Key, StarGraph> starGraphs_;
|
|
||||||
std::map<Key, DecisionTreeFactor> unary_;
|
|
||||||
std::map<Key, DecisionTreeFactor> belief_;
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
LoopyBelief(const DiscreteFactorGraph& graph, const std::map<Key, DiscreteKey>& allDiscreteKeys) :
|
/** Constructor
|
||||||
graph_(graph), varIndex_(graph) {
|
* Need all discrete keys to access node's cardinality for creating belief factors
|
||||||
initialize(allDiscreteKeys);
|
* TODO: so troublesome!!
|
||||||
|
*/
|
||||||
|
LoopyBelief(const DiscreteFactorGraph& graph,
|
||||||
|
const std::map<Key, DiscreteKey>& allDiscreteKeys) :
|
||||||
|
starGraphs_(buildStarGraphs(graph, allDiscreteKeys)) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void iterate() {
|
/// One step of belief propagation
|
||||||
|
DiscreteFactorGraph::shared_ptr iterate() {
|
||||||
static DiscreteConditional::shared_ptr dummyCond; // unused by-product of elimination
|
static DiscreteConditional::shared_ptr dummyCond; // unused by-product of elimination
|
||||||
|
DiscreteFactorGraph::shared_ptr beliefs(new DiscreteFactorGraph());
|
||||||
// Eliminate each star graph
|
// Eliminate each star graph
|
||||||
BOOST_FOREACH(const VariableIndex::value_type& keyFactors, varIndex_) {
|
BOOST_FOREACH(Key key, starGraphs_ | boost::adaptors::map_keys) {
|
||||||
Key key = keyFactors.first;
|
// initialize belief to the unary factor from the original graph
|
||||||
// reset belief to the unary factor in the original graph
|
DecisionTreeFactor beliefAtKey = *starGraphs_.at(key).unary;
|
||||||
belief_[key] = unary_.at(key);
|
|
||||||
|
|
||||||
// keep intermediate messages to divide later
|
// keep intermediate messages to divide later
|
||||||
std::map<Key, DiscreteFactor::shared_ptr> messages;
|
std::map<Key, DiscreteFactor::shared_ptr> messages;
|
||||||
|
@ -60,45 +69,59 @@ public:
|
||||||
// eliminate each neighbor in this star graph one by one
|
// eliminate each neighbor in this star graph one by one
|
||||||
BOOST_FOREACH(Key neighbor, starGraphs_.at(key).correctedBeliefIndices | boost::adaptors::map_keys) {
|
BOOST_FOREACH(Key neighbor, starGraphs_.at(key).correctedBeliefIndices | boost::adaptors::map_keys) {
|
||||||
DiscreteFactor::shared_ptr factor;
|
DiscreteFactor::shared_ptr factor;
|
||||||
boost::tie(dummyCond, factor) = EliminateDiscrete(*starGraphs_.at(key).star,
|
boost::tie(dummyCond, factor) = EliminateDiscrete(
|
||||||
Ordering(list_of(neighbor)));
|
*starGraphs_.at(key).star, Ordering(list_of(neighbor)));
|
||||||
// store the new factor into messages
|
// store the new factor into messages
|
||||||
messages.insert(make_pair(neighbor, factor));
|
messages.insert(make_pair(neighbor, factor));
|
||||||
|
|
||||||
// Belief is the product of all messages and the unary factor
|
// Belief is the product of all messages and the unary factor
|
||||||
// Incorporate new the factor to belief
|
// Incorporate new the factor to belief
|
||||||
belief_.at(key) = belief_.at(key) * (*boost::dynamic_pointer_cast<DecisionTreeFactor>(factor));
|
beliefAtKey = beliefAtKey
|
||||||
|
* (*boost::dynamic_pointer_cast<DecisionTreeFactor>(factor));
|
||||||
}
|
}
|
||||||
|
beliefs->push_back(beliefAtKey);
|
||||||
|
|
||||||
// Update the corrected belief for the neighbor's stargraph
|
// Update the corrected belief for the neighbor's stargraph
|
||||||
BOOST_FOREACH(Key neighbor, starGraphs_.at(key).correctedBeliefIndices | boost::adaptors::map_keys) {
|
BOOST_FOREACH(Key neighbor, starGraphs_.at(key).correctedBeliefIndices | boost::adaptors::map_keys) {
|
||||||
DecisionTreeFactor correctedBelief = belief_.at(key)
|
DecisionTreeFactor correctedBelief = beliefAtKey
|
||||||
/ (*boost::dynamic_pointer_cast<DecisionTreeFactor>(messages.at(neighbor)));
|
/ (*boost::dynamic_pointer_cast<DecisionTreeFactor>(
|
||||||
|
messages.at(neighbor)));
|
||||||
size_t beliefIndex = starGraphs_.at(neighbor).correctedBeliefIndices.at(
|
size_t beliefIndex = starGraphs_.at(neighbor).correctedBeliefIndices.at(
|
||||||
key);
|
key);
|
||||||
starGraphs_.at(neighbor).star->replace(beliefIndex, boost::make_shared<DecisionTreeFactor>(correctedBelief));
|
starGraphs_.at(neighbor).star->replace(beliefIndex,
|
||||||
|
boost::make_shared<DecisionTreeFactor>(correctedBelief));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return beliefs;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void initialize(const std::map<Key, DiscreteKey>& allDiscreteKeys) {
|
/**
|
||||||
BOOST_FOREACH(Key key, varIndex_ | boost::adaptors::map_keys) {
|
* Build star graphs for each node.
|
||||||
|
*/
|
||||||
|
StarGraphs buildStarGraphs(const DiscreteFactorGraph& graph,
|
||||||
|
const std::map<Key, DiscreteKey>& allDiscreteKeys) const {
|
||||||
|
StarGraphs starGraphs;
|
||||||
|
VariableIndex varIndex(graph); ///< access to all factors of each node
|
||||||
|
BOOST_FOREACH(Key key, varIndex | boost::adaptors::map_keys) {
|
||||||
// initialize to multiply with other unary factors later
|
// initialize to multiply with other unary factors later
|
||||||
unary_.insert(make_pair(key, DecisionTreeFactor(allDiscreteKeys.at(key), "1 1")));
|
DecisionTreeFactor prodOfUnaries(allDiscreteKeys.at(key), "1 1");
|
||||||
|
|
||||||
// collect all factors involving this key in the original graph
|
// collect all factors involving this key in the original graph
|
||||||
DiscreteFactorGraph::shared_ptr star(new DiscreteFactorGraph());
|
DiscreteFactorGraph::shared_ptr star(new DiscreteFactorGraph());
|
||||||
BOOST_FOREACH(size_t factorIdx, varIndex_[key]) {
|
BOOST_FOREACH(size_t factorIdx, varIndex[key]) {
|
||||||
star->push_back(graph_.at(factorIdx));
|
star->push_back(graph.at(factorIdx));
|
||||||
if (graph_.at(factorIdx)->size() == 1) {
|
|
||||||
unary_.at(key) = unary_.at(key)
|
// accumulate unary factors
|
||||||
|
if (graph.at(factorIdx)->size() == 1) {
|
||||||
|
prodOfUnaries = prodOfUnaries
|
||||||
* (*boost::dynamic_pointer_cast<DecisionTreeFactor>(
|
* (*boost::dynamic_pointer_cast<DecisionTreeFactor>(
|
||||||
graph_.at(factorIdx)));
|
graph.at(factorIdx)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// add the belief factor for each other variable in this star graph
|
// add the belief factor for each neighbor variable to this star graph
|
||||||
// also record the factor index for later modification
|
// also record the factor index for later modification
|
||||||
FastSet<Key> neighbors = star->keys();
|
FastSet<Key> neighbors = star->keys();
|
||||||
neighbors.erase(key);
|
neighbors.erase(key);
|
||||||
|
@ -109,9 +132,12 @@ private:
|
||||||
DecisionTreeFactor(allDiscreteKeys.at(neighbor), "1.0 0.0"));
|
DecisionTreeFactor(allDiscreteKeys.at(neighbor), "1.0 0.0"));
|
||||||
correctedBeliefIndices.insert(make_pair(neighbor, star->size() - 1));
|
correctedBeliefIndices.insert(make_pair(neighbor, star->size() - 1));
|
||||||
}
|
}
|
||||||
starGraphs_.insert(
|
starGraphs.insert(
|
||||||
make_pair(key, StarGraph(star, correctedBeliefIndices)));
|
make_pair(key,
|
||||||
|
StarGraph(star, make_shared<DecisionTreeFactor>(prodOfUnaries),
|
||||||
|
correctedBeliefIndices)));
|
||||||
}
|
}
|
||||||
|
return starGraphs;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -123,7 +149,7 @@ TEST_UNSAFE(LoopyBelief, construction) {
|
||||||
|
|
||||||
// Map from key to DiscreteKey for building belief factor.
|
// Map from key to DiscreteKey for building belief factor.
|
||||||
// TODO: this is bad!
|
// TODO: this is bad!
|
||||||
std::map<Key, DiscreteKey> allKeys = map_list_of(0,C)(1,S)(2,R)(3,W);
|
std::map<Key, DiscreteKey> allKeys = map_list_of(0, C)(1, S)(2, R)(3, W);
|
||||||
|
|
||||||
// Build graph
|
// Build graph
|
||||||
DecisionTreeFactor pC(C, "0.5 0.5");
|
DecisionTreeFactor pC(C, "0.5 0.5");
|
||||||
|
@ -142,7 +168,9 @@ TEST_UNSAFE(LoopyBelief, construction) {
|
||||||
|
|
||||||
// Main loop
|
// Main loop
|
||||||
for (size_t iter = 0; iter < 10; ++iter) {
|
for (size_t iter = 0; iter < 10; ++iter) {
|
||||||
solver.iterate();
|
DiscreteFactorGraph::shared_ptr beliefs = solver.iterate();
|
||||||
|
cout << "iteration: " << iter << endl;
|
||||||
|
beliefs->print();
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue