Bug fix in BayesTree marginal, re-enabled joint and unit tests

release/4.3a0
Richard Roberts 2010-10-22 22:11:23 +00:00
parent c47893f105
commit 8ff5bf5c7c
7 changed files with 147 additions and 163 deletions

View File

@ -31,7 +31,7 @@ namespace gtsam {
* percent. * percent.
*/ */
template<typename VALUE> template<typename VALUE>
class FastSet: public std::set<VALUE, std::less<VALUE>, boost::fast_pool_allocator<VALUE> > { class FastSet : public std::set<VALUE, std::less<VALUE>, boost::fast_pool_allocator<VALUE> > {
public: public:

View File

@ -108,14 +108,7 @@ namespace gtsam {
cout << "\n"; cout << "\n";
BOOST_FOREACH(const sharedConditional& conditional, this->conditionals_) { BOOST_FOREACH(const sharedConditional& conditional, this->conditionals_) {
conditional->print(" " + s + "conditional"); conditional->print(" " + s + "conditional");
// cout << " " << conditional->key();
} }
// if (!separator_.empty()) {
// cout << " :";
// BOOST_FOREACH(Index key, separator_)
// cout << " " << key;
// }
// cout << endl;
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -345,11 +338,11 @@ namespace gtsam {
// for(Index j=0; j<integrands.size(); ++j) // for(Index j=0; j<integrands.size(); ++j)
// p_S_R.pop_front(); // p_S_R.pop_front();
// Undo the permutation on the shortcut // Undo the permutation
p_S_R.permuteWithInverse(toBack); p_S_R.permuteWithInverse(toBack);
// return the parent shortcut P(Sp|R) // return the parent shortcut P(Sp|R)
return p_S_R; return *GenericSequentialSolver<typename CONDITIONAL::Factor>(p_S_R).eliminate();
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -374,31 +367,33 @@ namespace gtsam {
return *GenericSequentialSolver<typename CONDITIONAL::Factor>(p_FSR).joint(keys()); return *GenericSequentialSolver<typename CONDITIONAL::Factor>(p_FSR).joint(keys());
} }
// /* ************************************************************************* */ /* ************************************************************************* */
// // P(C1,C2) = \int_R P(F1|S1) P(S1|R) P(F2|S1) P(S2|R) P(R) // P(C1,C2) = \int_R P(F1|S1) P(S1|R) P(F2|S1) P(S2|R) P(R)
// /* ************************************************************************* */ /* ************************************************************************* */
// template<class CONDITIONAL> template<class CONDITIONAL>
// template<class Factor> FactorGraph<typename CONDITIONAL::Factor> BayesTree<CONDITIONAL>::Clique::joint(shared_ptr C2, shared_ptr R) {
// pair<FactorGraph<Factor>, Ordering> // For now, assume neither is the root
// BayesTree<CONDITIONAL>::Clique::joint(shared_ptr C2, shared_ptr R) {
// // For now, assume neither is the root // Combine P(F1|S1), P(S1|R), P(F2|S2), P(S2|R), and P(R)
// FactorGraph<typename CONDITIONAL::Factor> joint;
// // Combine P(F1|S1), P(S1|R), P(F2|S2), P(S2|R), and P(R) if (!isRoot()) joint.push_back(*this); // P(F1|S1)
// sharedBayesNet bn(new BayesNet<CONDITIONAL>); if (!isRoot()) joint.push_back(shortcut(R)); // P(S1|R)
// if (!isRoot()) bn->push_back(*this); // P(F1|S1) if (!C2->isRoot()) joint.push_back(*C2); // P(F2|S2)
// if (!isRoot()) bn->push_back(shortcut<Factor>(R)); // P(S1|R) if (!C2->isRoot()) joint.push_back(C2->shortcut(R)); // P(S2|R)
// if (!C2->isRoot()) bn->push_back(*C2); // P(F2|S2) joint.push_back(*R); // P(R)
// if (!C2->isRoot()) bn->push_back(C2->shortcut<Factor>(R)); // P(S2|R)
// bn->push_back(*R); // P(R) // Find the keys of both C1 and C2
// vector<Index> keys1(keys());
// // Find the keys of both C1 and C2 vector<Index> keys2(C2->keys());
// Ordering keys12 = keys(); FastSet<Index> keys12;
// BOOST_FOREACH(Index key,C2->keys()) keys12.push_back(key); keys12.insert(keys1.begin(), keys1.end());
// keys12.unique(); keys12.insert(keys2.begin(), keys2.end());
//
// // Calculate the marginal // Calculate the marginal
// return make_pair(marginalize<Factor,CONDITIONAL>(*bn,keys12), keys12); vector<Index> keys12vector; keys12vector.reserve(keys12.size());
// } keys12vector.insert(keys12vector.begin(), keys12.begin(), keys12.end());
return *GenericSequentialSolver<typename CONDITIONAL::Factor>(joint).joint(keys12vector);
}
/* ************************************************************************* */ /* ************************************************************************* */
template<class CONDITIONAL> template<class CONDITIONAL>
@ -693,7 +688,7 @@ namespace gtsam {
} }
// Now fill in the nodes index // Now fill in the nodes index
if(subtree->back()->key() > (nodes_.size() - 1)) if(nodes_.size() == 0 || subtree->back()->key() > (nodes_.size() - 1))
nodes_.resize(subtree->back()->key() + 1); nodes_.resize(subtree->back()->key() + 1);
fillNodesIndex(subtree); fillNodesIndex(subtree);
} }
@ -712,79 +707,45 @@ namespace gtsam {
FactorGraph<typename CONDITIONAL::Factor> cliqueMarginal = clique->marginal(root_); FactorGraph<typename CONDITIONAL::Factor> cliqueMarginal = clique->marginal(root_);
return GenericSequentialSolver<typename CONDITIONAL::Factor>(cliqueMarginal).marginal(key); return GenericSequentialSolver<typename CONDITIONAL::Factor>(cliqueMarginal).marginal(key);
// // Reorder so that only the requested key is not eliminated
// typename FACTORGRAPH::variableindex_type varIndex(cliqueMarginal);
// vector<Index> keyAsVector(1); keyAsVector[0] = key;
// Permutation toBack(Permutation::PushToBack(keyAsVector, varIndex.size()));
// Permutation::shared_ptr toBackInverse(toBack.inverse());
// varIndex.permute(toBack);
// BOOST_FOREACH(const typename FACTORGRAPH::sharedFactor& factor, cliqueMarginal) {
// factor->permuteWithInverse(*toBackInverse);
// }
//
// // partially eliminate, remaining factor graph is requested marginal
// SymbolicSequentialSolver::EliminateUntil(cliqueMarginal, varIndex.size()-1, varIndex);
// BOOST_FOREACH(const typename FACTORGRAPH::sharedFactor& factor, cliqueMarginal) {
// if(factor)
// factor->permuteWithInverse(toBack);
// }
// return cliqueMarginal;
} }
/* ************************************************************************* */ /* ************************************************************************* */
template<class CONDITIONAL> template<class CONDITIONAL>
typename BayesNet<CONDITIONAL>::shared_ptr BayesTree<CONDITIONAL>::marginalBayesNet(Index key) const { typename BayesNet<CONDITIONAL>::shared_ptr BayesTree<CONDITIONAL>::marginalBayesNet(Index key) const {
// calculate marginal as a factor graph // calculate marginal as a factor graph
typename FactorGraph<typename CONDITIONAL::Factor>::shared_ptr fg( FactorGraph<typename CONDITIONAL::Factor> fg;
new FactorGraph<typename CONDITIONAL::Factor>()); fg.push_back(this->marginal(key));
fg->push_back(this->marginal(key));
// eliminate further to Bayes net // eliminate factor graph marginal to a Bayes net
return GenericSequentialSolver<typename CONDITIONAL::Factor>(*fg).eliminate(); return GenericSequentialSolver<typename CONDITIONAL::Factor>(fg).eliminate();
} }
// /* ************************************************************************* */ /* ************************************************************************* */
// // Find two cliques, their joint, then marginalizes // Find two cliques, their joint, then marginalizes
// /* ************************************************************************* */ /* ************************************************************************* */
// template<class CONDITIONAL> template<class CONDITIONAL>
// template<class Factor> typename FactorGraph<typename CONDITIONAL::Factor>::shared_ptr
// FactorGraph<Factor> BayesTree<CONDITIONAL>::joint(Index key1, Index key2) const {
// BayesTree<CONDITIONAL>::joint(Index key1, Index key2) const {
//
// // get clique C1 and C2
// sharedClique C1 = (*this)[key1], C2 = (*this)[key2];
//
// // calculate joint
// Ordering ord;
// FactorGraph<Factor> p_C1C2;
// boost::tie(p_C1C2,ord) = C1->joint<Factor>(C2,root_);
//
// // create an ordering where both requested keys are not eliminated
// ord.remove(key1);
// ord.remove(key2);
//
// // partially eliminate, remaining factor graph is requested joint
// // TODO, make eliminate functional
// eliminate<Factor,CONDITIONAL>(p_C1C2,ord);
// return p_C1C2;
// }
// /* ************************************************************************* */ // get clique C1 and C2
// template<class CONDITIONAL> sharedClique C1 = (*this)[key1], C2 = (*this)[key2];
// template<class Factor>
// BayesNet<CONDITIONAL> // calculate joint
// BayesTree<CONDITIONAL>::jointBayesNet(Index key1, Index key2) const { FactorGraph<typename CONDITIONAL::Factor> p_C1C2(C1->joint(C2, root_));
//
// // calculate marginal as a factor graph // eliminate remaining factor graph to get requested joint
// FactorGraph<Factor> fg = this->joint<Factor>(key1,key2); vector<Index> key12(2); key12[0] = key1; key12[1] = key2;
// return GenericSequentialSolver<typename CONDITIONAL::Factor>(p_C1C2).joint(key12);
// // eliminate further to Bayes net }
// Ordering ordering;
// ordering += key1, key2; /* ************************************************************************* */
// return eliminate<Factor,CONDITIONAL>(fg,ordering); template<class CONDITIONAL>
// } typename BayesNet<CONDITIONAL>::shared_ptr BayesTree<CONDITIONAL>::jointBayesNet(Index key1, Index key2) const {
// eliminate factor graph marginal to a Bayes net
return GenericSequentialSolver<typename CONDITIONAL::Factor>(*this->joint(key1, key2)).eliminate();
}
/* ************************************************************************* */ /* ************************************************************************* */
template<class CONDITIONAL> template<class CONDITIONAL>

View File

@ -113,9 +113,8 @@ namespace gtsam {
/** return the marginal P(C) of the clique */ /** return the marginal P(C) of the clique */
FactorGraph<typename CONDITIONAL::Factor> marginal(shared_ptr root); FactorGraph<typename CONDITIONAL::Factor> marginal(shared_ptr root);
// /** return the joint P(C1,C2), where C1==this. TODO: not a method? */ /** return the joint P(C1,C2), where C1==this. TODO: not a method? */
// template<class Factor> FactorGraph<typename CONDITIONAL::Factor> joint(shared_ptr C2, shared_ptr root);
// std::pair<FactorGraph<Factor>,Ordering> joint(shared_ptr C2, shared_ptr root);
}; };
// typedef for shared pointers to cliques // typedef for shared pointers to cliques
@ -261,13 +260,11 @@ namespace gtsam {
/** return marginal on any variable, as a Bayes Net */ /** return marginal on any variable, as a Bayes Net */
typename BayesNet<CONDITIONAL>::shared_ptr marginalBayesNet(Index key) const; typename BayesNet<CONDITIONAL>::shared_ptr marginalBayesNet(Index key) const;
// /** return joint on two variables */ /** return joint on two variables */
// template<class Factor> typename FactorGraph<typename CONDITIONAL::Factor>::shared_ptr joint(Index key1, Index key2) const;
// FactorGraph<Factor> joint(Index key1, Index key2) const;
// /** return joint on two variables as a BayesNet */
// /** return joint on two variables as a BayesNet */ typename BayesNet<CONDITIONAL>::shared_ptr jointBayesNet(Index key1, Index key2) const;
// template<class Factor>
// BayesNet<CONDITIONAL> jointBayesNet(Index key1, Index key2) const;
/** /**
* Read only with side effects * Read only with side effects

View File

@ -26,6 +26,8 @@
#include <boost/foreach.hpp> #include <boost/foreach.hpp>
using namespace std;
namespace gtsam { namespace gtsam {
/* ************************************************************************* */ /* ************************************************************************* */
@ -42,6 +44,23 @@ GenericMultifrontalSolver<FACTOR, JUNCTIONTREE>::eliminate() const {
return bayesTree; return bayesTree;
} }
/* ************************************************************************* */
template<class FACTOR, class JUNCTIONTREE>
typename FactorGraph<FACTOR>::shared_ptr
GenericMultifrontalSolver<FACTOR, JUNCTIONTREE>::joint(const std::vector<Index>& js) const {
// We currently have code written only for computing the
if(js.size() != 2)
throw domain_error(
"*MultifrontalSolver::joint(js) currently can only compute joint marginals\n"
"for exactly two variables. You can call marginal to compute the\n"
"marginal for one variable. *SequentialSolver::joint(js) can compute the\n"
"joint marginal over any number of variables, so use that if necessary.\n");
return eliminate()->joint(js[0], js[1]);
}
/* ************************************************************************* */ /* ************************************************************************* */
template<class FACTOR, class JUNCTIONTREE> template<class FACTOR, class JUNCTIONTREE>
typename FACTOR::shared_ptr GenericMultifrontalSolver<FACTOR, JUNCTIONTREE>::marginal(Index j) const { typename FACTOR::shared_ptr GenericMultifrontalSolver<FACTOR, JUNCTIONTREE>::marginal(Index j) const {

View File

@ -48,6 +48,13 @@ public:
*/ */
typename JUNCTIONTREE::BayesTree::shared_ptr eliminate() const; typename JUNCTIONTREE::BayesTree::shared_ptr eliminate() const;
/**
* Compute the marginal joint over a set of variables, by integrating out
* all of the other variables. This function returns the result as a factor
* graph.
*/
typename FactorGraph<FACTOR>::shared_ptr joint(const std::vector<Index>& js) const;
/** /**
* Compute the marginal Gaussian density over a variable, by integrating out * Compute the marginal Gaussian density over a variable, by integrating out
* all of the other variables. This function returns the result as a factor. * all of the other variables. This function returns the result as a factor.

View File

@ -54,12 +54,6 @@ public:
*/ */
typename BayesNet<typename FACTOR::Conditional>::shared_ptr eliminate() const; typename BayesNet<typename FACTOR::Conditional>::shared_ptr eliminate() const;
/**
* Compute the marginal Gaussian density over a variable, by integrating out
* all of the other variables. This function returns the result as a factor.
*/
typename FACTOR::shared_ptr marginal(Index j) const;
/** /**
* Compute the marginal joint over a set of variables, by integrating out * Compute the marginal joint over a set of variables, by integrating out
* all of the other variables. This function returns the result as a factor * all of the other variables. This function returns the result as a factor
@ -67,6 +61,12 @@ public:
*/ */
typename FactorGraph<FACTOR>::shared_ptr joint(const std::vector<Index>& js) const; typename FactorGraph<FACTOR>::shared_ptr joint(const std::vector<Index>& js) const;
/**
* Compute the marginal Gaussian density over a variable, by integrating out
* all of the other variables. This function returns the result as a factor.
*/
typename FACTOR::shared_ptr marginal(Index j) const;
}; };
} }

View File

@ -278,66 +278,66 @@ TEST( BayesTree, balanced_smoother_shortcuts )
//} //}
/* ************************************************************************* */ /* ************************************************************************* */
// SL-FIX TEST( BayesTree, balanced_smoother_joint ) TEST( BayesTree, balanced_smoother_joint )
//{ {
// // Create smoother with 7 nodes // Create smoother with 7 nodes
// GaussianFactorGraph smoother = createSmoother(7); Ordering ordering;
// Ordering ordering; ordering += "x1","x3","x5","x7","x2","x6","x4";
// ordering += "x1","x3","x5","x7","x2","x6","x4"; GaussianFactorGraph smoother = createSmoother(7, ordering).first;
//
// // Create the Bayes tree, expected to look like: // Create the Bayes tree, expected to look like:
// // x5 x6 x4 // x5 x6 x4
// // x3 x2 : x4 // x3 x2 : x4
// // x1 : x2 // x1 : x2
// // x7 : x6 // x7 : x6
// GaussianBayesNet chordalBayesNet = smoother.eliminate(ordering); GaussianBayesNet chordalBayesNet = *GaussianSequentialSolver(smoother).eliminate();
// GaussianISAM bayesTree(chordalBayesNet); GaussianISAM bayesTree(chordalBayesNet);
//
// // Conditional density elements reused by both tests // Conditional density elements reused by both tests
// const Vector sigma = ones(2); const Vector sigma = ones(2);
// const Matrix I = eye(2), A = -0.00429185*I; const Matrix I = eye(2), A = -0.00429185*I;
//
// // Check the joint density P(x1,x7) factored as P(x1|x7)P(x7) // Check the joint density P(x1,x7) factored as P(x1|x7)P(x7)
// GaussianBayesNet expected1; GaussianBayesNet expected1;
// // Why does the sign get flipped on the prior? // Why does the sign get flipped on the prior?
// GaussianConditional::shared_ptr GaussianConditional::shared_ptr
// parent1(new GaussianConditional("x7", zero(2), -1*I/sigmax7, ones(2))); parent1(new GaussianConditional(ordering["x7"], zero(2), -1*I/sigmax7, ones(2)));
// expected1.push_front(parent1); expected1.push_front(parent1);
// push_front(expected1,"x1", zero(2), I/sigmax7, "x7", A/sigmax7, sigma); push_front(expected1,ordering["x1"], zero(2), I/sigmax7, ordering["x7"], A/sigmax7, sigma);
// GaussianBayesNet actual1 = bayesTree.jointBayesNet<GaussianFactor>("x1","x7"); GaussianBayesNet actual1 = *bayesTree.jointBayesNet(ordering["x1"],ordering["x7"]);
// CHECK(assert_equal(expected1,actual1,tol)); CHECK(assert_equal(expected1,actual1,tol));
//
// // Check the joint density P(x7,x1) factored as P(x7|x1)P(x1) // // Check the joint density P(x7,x1) factored as P(x7|x1)P(x1)
// GaussianBayesNet expected2; // GaussianBayesNet expected2;
// GaussianConditional::shared_ptr // GaussianConditional::shared_ptr
// parent2(new GaussianConditional("x1", zero(2), -1*I/sigmax1, ones(2))); // parent2(new GaussianConditional(ordering["x1"], zero(2), -1*I/sigmax1, ones(2)));
// expected2.push_front(parent2); // expected2.push_front(parent2);
// push_front(expected2,"x7", zero(2), I/sigmax1, "x1", A/sigmax1, sigma); // push_front(expected2,ordering["x7"], zero(2), I/sigmax1, ordering["x1"], A/sigmax1, sigma);
// GaussianBayesNet actual2 = bayesTree.jointBayesNet<GaussianFactor>("x7","x1"); // GaussianBayesNet actual2 = *bayesTree.jointBayesNet(ordering["x7"],ordering["x1"]);
// CHECK(assert_equal(expected2,actual2,tol)); // CHECK(assert_equal(expected2,actual2,tol));
//
// // Check the joint density P(x1,x4), i.e. with a root variable // Check the joint density P(x1,x4), i.e. with a root variable
// GaussianBayesNet expected3; GaussianBayesNet expected3;
// GaussianConditional::shared_ptr GaussianConditional::shared_ptr
// parent3(new GaussianConditional("x4", zero(2), I/sigmax4, ones(2))); parent3(new GaussianConditional(ordering["x4"], zero(2), I/sigmax4, ones(2)));
// expected3.push_front(parent3); expected3.push_front(parent3);
// double sig14 = 0.784465; double sig14 = 0.784465;
// Matrix A14 = -0.0769231*I; Matrix A14 = -0.0769231*I;
// push_front(expected3,"x1", zero(2), I/sig14, "x4", A14/sig14, sigma); push_front(expected3,ordering["x1"], zero(2), I/sig14, ordering["x4"], A14/sig14, sigma);
// GaussianBayesNet actual3 = bayesTree.jointBayesNet<GaussianFactor>("x1","x4"); GaussianBayesNet actual3 = *bayesTree.jointBayesNet(ordering["x1"],ordering["x4"]);
// CHECK(assert_equal(expected3,actual3,tol)); CHECK(assert_equal(expected3,actual3,tol));
//
// // Check the joint density P(x4,x1), i.e. with a root variable, factored the other way // // Check the joint density P(x4,x1), i.e. with a root variable, factored the other way
// GaussianBayesNet expected4; // GaussianBayesNet expected4;
// GaussianConditional::shared_ptr // GaussianConditional::shared_ptr
// parent4(new GaussianConditional("x1", zero(2), -1.0*I/sigmax1, ones(2))); // parent4(new GaussianConditional(ordering["x1"], zero(2), -1.0*I/sigmax1, ones(2)));
// expected4.push_front(parent4); // expected4.push_front(parent4);
// double sig41 = 0.668096; // double sig41 = 0.668096;
// Matrix A41 = -0.055794*I; // Matrix A41 = -0.055794*I;
// push_front(expected4,"x4", zero(2), I/sig41, "x1", A41/sig41, sigma); // push_front(expected4,ordering["x4"], zero(2), I/sig41, ordering["x1"], A41/sig41, sigma);
// GaussianBayesNet actual4 = bayesTree.jointBayesNet<GaussianFactor>("x4","x1"); // GaussianBayesNet actual4 = *bayesTree.jointBayesNet(ordering["x4"],ordering["x1"]);
// CHECK(assert_equal(expected4,actual4,tol)); // CHECK(assert_equal(expected4,actual4,tol));
//} }
/* ************************************************************************* */ /* ************************************************************************* */
int main() { TestResult tr; return TestRegistry::runAllTests(tr);} int main() { TestResult tr; return TestRegistry::runAllTests(tr);}