diff --git a/gtsam/inference/JunctionTree-inst.h b/gtsam/inference/JunctionTree-inst.h index b2012f3bf..296bdb434 100644 --- a/gtsam/inference/JunctionTree-inst.h +++ b/gtsam/inference/JunctionTree-inst.h @@ -58,7 +58,7 @@ namespace gtsam { /* ************************************************************************* */ // Post-order visitor function template - void ConstructorTraversalVisitorPost( + void ConstructorTraversalVisitorPostAlg1( const boost::shared_ptr& ETreeNode, const ConstructorTraversalData& myData) { @@ -114,6 +114,69 @@ namespace gtsam { } myData.myJTNode->problemSize_ = combinedProblemSize; } + + /* ************************************************************************* */ + // Post-order visitor function + template + void ConstructorTraversalVisitorPostAlg2( + const boost::shared_ptr& ETreeNode, + const ConstructorTraversalData& myData) + { + // In this post-order visitor, we combine the symbolic elimination results from the + // elimination tree children and symbolically eliminate the current elimination tree node. We + // then check whether each of our elimination tree child nodes should be merged with us. The + // check for this is that our number of symbolic elimination parents is exactly 1 less than + // our child's symbolic elimination parents - this condition indicates that eliminating the + // current node did not introduce any parents beyond those already in the child. + + // Do symbolic elimination for this node + SymbolicFactorGraph symbolicFactors; + symbolicFactors.reserve(ETreeNode->factors.size() + myData.childSymbolicFactors.size()); + // Add symbolic versions of the ETree node factors + BOOST_FOREACH(const typename GRAPH::sharedFactor& factor, ETreeNode->factors) { + symbolicFactors.push_back(boost::make_shared( + SymbolicFactor::FromKeys(*factor))); + } + // Add symbolic factors passed up from children + symbolicFactors.push_back(myData.childSymbolicFactors.begin(), myData.childSymbolicFactors.end()); + Ordering keyAsOrdering; keyAsOrdering.push_back(ETreeNode->key); + std::pair symbolicElimResult = + EliminateSymbolic(symbolicFactors, keyAsOrdering); + + // Store symbolic elimination results in the parent + myData.parentData->childSymbolicConditionals.push_back(symbolicElimResult.first); + myData.parentData->childSymbolicFactors.push_back(symbolicElimResult.second); + + // Merge our children if they are in our clique - if our conditional has exactly one fewer + // parent than our child's conditional. + size_t myNrFrontals = 1; + const size_t myNrParents = symbolicElimResult.first->nrParents(); + size_t nrMergedChildren = 0; + assert(myData.myJTNode->children.size() == myData.childSymbolicConditionals.size()); + // Loop over children + int combinedProblemSize = (int) (symbolicElimResult.first->size() * symbolicFactors.size()); + for(size_t child = 0; child < myData.childSymbolicConditionals.size(); ++child) { + // Check if we should merge the child + if(myNrParents + myNrFrontals == myData.childSymbolicConditionals[child]->nrParents()) { + // Get a reference to the child, adjusting the index to account for children previously + // merged and removed from the child list. + const typename JunctionTree::Node& childToMerge = + *myData.myJTNode->children[child - nrMergedChildren]; + // Merge keys, factors, and children. + myData.myJTNode->keys.insert(myData.myJTNode->keys.begin(), childToMerge.keys.begin(), childToMerge.keys.end()); + myData.myJTNode->factors.insert(myData.myJTNode->factors.end(), childToMerge.factors.begin(), childToMerge.factors.end()); + myData.myJTNode->children.insert(myData.myJTNode->children.end(), childToMerge.children.begin(), childToMerge.children.end()); + // Increment problem size + combinedProblemSize = std::max(combinedProblemSize, childToMerge.problemSize_); + // Remove child from list. + myData.myJTNode->children.erase(myData.myJTNode->children.begin() + (child - nrMergedChildren)); + // Increment number of merged children + myNrFrontals += childToMerge.keys.size(); + ++ nrMergedChildren; + } + } + myData.myJTNode->problemSize_ = combinedProblemSize; + } } /* ************************************************************************* */ @@ -134,7 +197,7 @@ namespace gtsam { ConstructorTraversalData rootData(0); rootData.myJTNode = boost::make_shared(); // Make a dummy node to gather the junction tree roots treeTraversal::DepthFirstForest(eliminationTree, rootData, - ConstructorTraversalVisitorPre, ConstructorTraversalVisitorPost); + ConstructorTraversalVisitorPre, ConstructorTraversalVisitorPostAlg2); // Assign roots from the dummy node Base::roots_ = rootData.myJTNode->children;