Unit tests for DepthFirstForest and CloneForest

release/4.3a0
Richard Roberts 2013-06-06 15:36:53 +00:00
parent ec2df2df3c
commit 33443b3a13
2 changed files with 81 additions and 18 deletions

View File

@ -16,6 +16,8 @@
*/ */
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <gtsam/base/TestableAssertions.h>
#include <gtsam/base/treeTraversal-inst.h>
#include <vector> #include <vector>
#include <list> #include <list>
@ -25,16 +27,19 @@
using boost::assign::operator+=; using boost::assign::operator+=;
using namespace std; using namespace std;
using namespace gtsam;
struct TestNode { struct TestNode {
typedef boost::shared_ptr<TestNode> shared_ptr; typedef boost::shared_ptr<TestNode> shared_ptr;
int data; int data;
vector<shared_ptr> children; vector<shared_ptr> children;
TestNode() : data(-1) {}
TestNode(int data) : data(data) {} TestNode(int data) : data(data) {}
}; };
struct TestForest { struct TestForest {
typedef TestNode::shared_ptr sharedNode; typedef TestNode Node;
typedef Node::shared_ptr sharedNode;
vector<sharedNode> roots_; vector<sharedNode> roots_;
const vector<sharedNode>& roots() const { return roots_; } const vector<sharedNode>& roots() const { return roots_; }
}; };
@ -70,8 +75,9 @@ struct PreOrderVisitor {
node->data == 1 ? -1 : node->data == 1 ? -1 :
node->data == 2 ? 0 : node->data == 2 ? 0 :
node->data == 3 ? 0 : node->data == 3 ? 0 :
node->data == 4 ? 0 : node->data == 4 ? 3 :
(throw std::runtime_error("Unexpected node index"), -1); node->data == 10 ? 0 :
(parentsMatched = false, -1);
if(expectedParentIndex != parentData) if(expectedParentIndex != parentData)
parentsMatched = false; parentsMatched = false;
return node->data; return node->data;
@ -87,17 +93,65 @@ struct PostOrderVisitor {
} }
}; };
/* ************************************************************************* */
std::list<int> getPreorder(const TestForest& forest) {
std::list<int> result;
PreOrderVisitor preVisitor;
int rootData = -1;
treeTraversal::DepthFirstForest(forest, rootData, preVisitor);
result = preVisitor.visited;
return result;
}
/* ************************************************************************* */ /* ************************************************************************* */
TEST(treeTraversal, DepthFirst) TEST(treeTraversal, DepthFirst)
{ {
// Get test forest // Get test forest
TestForest testForest = makeTestForest(); TestForest testForest = makeTestForest();
// Expected pre-order // Expected visit order
std::list<int> preOrderExpected; std::list<int> preOrderExpected;
preOrderExpected += 0, 2, 3, 4, 1; preOrderExpected += 0, 2, 3, 4, 1;
std::list<int> postOrderExpected; std::list<int> postOrderExpected;
postOrderExpected += 2, 4, 3, 0, 1; postOrderExpected += 2, 4, 3, 0, 1;
// Actual visit order
PreOrderVisitor preVisitor;
PostOrderVisitor postVisitor;
int rootData = -1;
treeTraversal::DepthFirstForest(testForest, rootData, preVisitor, postVisitor);
EXPECT(preVisitor.parentsMatched);
EXPECT(assert_container_equality(preOrderExpected, preVisitor.visited));
EXPECT(assert_container_equality(postOrderExpected, postVisitor.visited));
}
/* ************************************************************************* */
TEST(treeTraversal, CloneForest)
{
// Get test forest
TestForest testForest1 = makeTestForest();
TestForest testForest2;
testForest2.roots_ = treeTraversal::CloneForest(testForest1);
// Check that the original and clone both are expected
std::list<int> preOrder1Expected;
preOrder1Expected += 0, 2, 3, 4, 1;
std::list<int> preOrder1Actual = getPreorder(testForest1);
std::list<int> preOrder2Actual = getPreorder(testForest2);
EXPECT(assert_container_equality(preOrder1Expected, preOrder1Actual));
EXPECT(assert_container_equality(preOrder1Expected, preOrder2Actual));
// Modify clone - should not modify original
testForest2.roots_[0]->children[1]->data = 10;
std::list<int> preOrderModifiedExpected;
preOrderModifiedExpected += 0, 2, 10, 4, 1;
// Check that original is the same and only the clone is modified
std::list<int> preOrder1ModActual = getPreorder(testForest1);
std::list<int> preOrder2ModActual = getPreorder(testForest2);
EXPECT(assert_container_equality(preOrder1Expected, preOrder1ModActual));
EXPECT(assert_container_equality(preOrderModifiedExpected, preOrder2ModActual));
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -34,9 +34,10 @@ namespace gtsam {
struct TraversalNode { struct TraversalNode {
bool expanded; bool expanded;
const boost::shared_ptr<NODE>& treeNode; const boost::shared_ptr<NODE>& treeNode;
DATA data; DATA& parentData;
TraversalNode(const boost::shared_ptr<NODE>& _treeNode, const DATA& _data) : typename FastList<DATA>::iterator dataPointer;
expanded(false), treeNode(_treeNode), data(_data) {} TraversalNode(const boost::shared_ptr<NODE>& _treeNode, DATA& _parentData) :
expanded(false), treeNode(_treeNode), parentData(_parentData) {}
}; };
/// Do nothing - default argument for post-visitor for tree traversal /// Do nothing - default argument for post-visitor for tree traversal
@ -67,33 +68,40 @@ namespace gtsam {
// Depth first traversal stack // Depth first traversal stack
typedef TraversalNode<typename FOREST::Node, DATA> TraversalNode; typedef TraversalNode<typename FOREST::Node, DATA> TraversalNode;
typedef std::stack<TraversalNode, FastList<TraversalNode> > Stack; typedef FastList<TraversalNode> Stack;
Stack stack; Stack stack;
FastList<DATA> dataList; // List to store node data as it is returned from the pre-order visitor
// Add roots to stack (use reverse iterators so children are processed in the order they // Add roots to stack (insert such that they are visited and processed in order
// appear) {
BOOST_REVERSE_FOREACH(const sharedNode& root, forest.roots()) Stack::iterator insertLocation = stack.begin();
stack.push(TraversalNode(root, visitorPre(root, rootData))); BOOST_FOREACH(const sharedNode& root, forest.roots())
stack.insert(insertLocation, TraversalNode(root, rootData));
}
// Traverse // Traverse
while(!stack.empty()) while(!stack.empty())
{ {
// Get next node // Get next node
TraversalNode& node = stack.top(); TraversalNode& node = stack.front();
if(node.expanded) { if(node.expanded) {
// If already expanded, then the data stored in the node is no longer needed, so visit // If already expanded, then the data stored in the node is no longer needed, so visit
// then delete it. // then delete it.
(void) visitorPost(node.treeNode, node.data); (void) visitorPost(node.treeNode, *node.dataPointer);
stack.pop(); dataList.erase(node.dataPointer);
stack.pop_front();
} else { } else {
// If not already visited, visit the node and add its children (use reverse iterators so // If not already visited, visit the node and add its children (use reverse iterators so
// children are processed in the order they appear) // children are processed in the order they appear)
BOOST_REVERSE_FOREACH(const sharedNode& child, node.treeNode->children) node.dataPointer = dataList.insert(dataList.end(), visitorPre(node.treeNode, node.parentData));
stack.push(TraversalNode(child, visitorPre(child, node.data))); Stack::iterator insertLocation = stack.begin();
BOOST_FOREACH(const sharedNode& child, node.treeNode->children)
stack.insert(insertLocation, TraversalNode(child, *node.dataPointer));
node.expanded = true; node.expanded = true;
} }
} }
assert(dataList.empty());
} }
/** Traverse a forest depth-first, with a pre-order visit but no post-order visit. /** Traverse a forest depth-first, with a pre-order visit but no post-order visit.
@ -123,6 +131,7 @@ namespace gtsam {
{ {
// Clone the current node and add it to its cloned parent // Clone the current node and add it to its cloned parent
boost::shared_ptr<NODE> clone = boost::make_shared<NODE>(*node); boost::shared_ptr<NODE> clone = boost::make_shared<NODE>(*node);
clone->children.clear();
parentPointer->children.push_back(clone); parentPointer->children.push_back(clone);
return clone; return clone;
} }
@ -160,7 +169,7 @@ namespace gtsam {
/** Print a tree, prefixing each line with \c str, and formatting keys using \c keyFormatter. /** Print a tree, prefixing each line with \c str, and formatting keys using \c keyFormatter.
* To print each node, this function calls the \c print function of the tree nodes. */ * To print each node, this function calls the \c print function of the tree nodes. */
template<class FOREST> template<class FOREST>
void PrintForest(const FOREST& forest, const std::string& str, const KeyFormatter& keyFormatter) { void PrintForest(const FOREST& forest, std::string str, const KeyFormatter& keyFormatter) {
typedef typename FOREST::Node Node; typedef typename FOREST::Node Node;
DepthFirstForest(forest, str, boost::bind(PrintForestVisitorPre<Node>, _1, _2, keyFormatter)); DepthFirstForest(forest, str, boost::bind(PrintForestVisitorPre<Node>, _1, _2, keyFormatter));
} }