diff --git a/gtsam/base/DSFVector.cpp b/gtsam/base/DSFVector.cpp index 0881465ba..42a3bb96c 100644 --- a/gtsam/base/DSFVector.cpp +++ b/gtsam/base/DSFVector.cpp @@ -18,6 +18,8 @@ * As a result, the size of the forest is prefixed. */ +#include +#include #include using namespace std; @@ -26,23 +28,27 @@ namespace gtsam { /* ************************************************************************* */ DSFVector::DSFVector (const size_t numNodes) { - resize(numNodes); + v_ = boost::make_shared(numNodes); int index = 0; - for(iterator it = begin(); it!=end(); it++, index++) + keys_.reserve(numNodes); + for(V::iterator it = v_->begin(); it!=v_->end(); it++, index++) { *it = index; + keys_.push_back(index); + } } /* ************************************************************************* */ - DSFVector::Label DSFVector::findSet(const size_t& key) const { - size_t parent = at(key); - return parent == key ? key : findSet(parent); + DSFVector::DSFVector(const boost::shared_ptr& v_in, const std::vector& keys) : keys_(keys) { + v_ = v_in; + BOOST_FOREACH(const size_t key, keys) + (*v_)[key] = key; } /* ************************************************************************* */ bool DSFVector::isSingleton(const Label& label) const { bool result = false; - std::vector::const_iterator it = begin(); - for (; it != end(); ++it) { + V::const_iterator it = keys_.begin(); + for (; it != keys_.end(); ++it) { if(findSet(*it) == label) { if (!result) // find the first occurrence result = true; @@ -56,11 +62,10 @@ namespace gtsam { /* ************************************************************************* */ std::set DSFVector::set(const Label& label) const { std::set set; - size_t key = 0; - std::vector::const_iterator it = begin(); - for (; it != end(); it++, key++) { + V::const_iterator it = keys_.begin(); + for (; it != keys_.end(); it++) { if (findSet(*it) == label) - set.insert(key); + set.insert(*it); } return set; } @@ -68,17 +73,26 @@ namespace gtsam { /* ************************************************************************* */ std::map > DSFVector::sets() const { std::map > sets; - size_t key = 0; - std::vector::const_iterator it = begin(); - for (; it != end(); it++, key++) { - sets[findSet(*it)].insert(key); + V::const_iterator it = keys_.begin(); + for (; it != keys_.end(); it++) { + sets[findSet(*it)].insert(*it); } return sets; } + /* ************************************************************************* */ + std::map > DSFVector::arrays() const { + std::map > arrays; + V::const_iterator it = keys_.begin(); + for (; it != keys_.end(); it++) { + arrays[findSet(*it)].push_back(*it); + } + return arrays; + } + /* ************************************************************************* */ void DSFVector::makeUnionInPlace(const size_t& i1, const size_t& i2) { - at(findSet(i2)) = findSet(i1); + (*v_)[findSet(i2)] = findSet(i1); } } // namespace diff --git a/gtsam/base/DSFVector.h b/gtsam/base/DSFVector.h index 435304ba5..c69582633 100644 --- a/gtsam/base/DSFVector.h +++ b/gtsam/base/DSFVector.h @@ -23,23 +23,41 @@ #include #include #include +#include namespace gtsam { /** * A fast impelementation of disjoint set forests that uses vector as underly data structure. */ - class DSFVector : protected std::vector { - private: + class DSFVector { public: + typedef std::vector V; typedef size_t Label; + typedef std::vector::const_iterator const_iterator; + typedef std::vector::iterator iterator; - // constructor - DSFVector(const std::size_t numNodes); + private: + boost::shared_ptr v_; // could use existing memory to improve the efficiency + std::vector keys_; + + public: + // constructor that allocate a new memory + DSFVector(const size_t numNodes); + + // constructor that uses the existing memory + DSFVector(const boost::shared_ptr& v_in, const std::vector& keys); // find the label of the set in which {key} lives - Label findSet(const size_t& key) const; + inline Label findSet(size_t key) const { + size_t parent = (*v_)[key]; + while (parent != key) { + key = parent; + parent = (*v_)[key]; + } + return parent; + } // find whether there is one and only one occurrence for the given {label} bool isSingleton(const Label& label) const; @@ -49,9 +67,10 @@ namespace gtsam { // return all sets, i.e. a partition of all elements std::map > sets() const; + std::map > arrays() const; // the in-place version of makeUnion - void makeUnionInPlace(const std::size_t& i1, const std::size_t& i2); + void makeUnionInPlace(const size_t& i1, const size_t& i2); }; diff --git a/gtsam/base/tests/testDSFVector.cpp b/gtsam/base/tests/testDSFVector.cpp index ce4ea7f21..81e4f3b07 100644 --- a/gtsam/base/tests/testDSFVector.cpp +++ b/gtsam/base/tests/testDSFVector.cpp @@ -18,8 +18,10 @@ */ #include +#include #include #include +#include using namespace boost::assign; #include @@ -41,6 +43,15 @@ TEST(DSFVectorVector, makeUnionInPlace) { CHECK(dsf.findSet(0) == dsf.findSet(2)); } +/* ************************************************************************* */ +TEST(DSFVectorVector, makeUnionInPlace2) { + boost::shared_ptr v = boost::make_shared(5); + std::vector keys; keys += 1, 3; + DSFVector dsf(v, keys); + dsf.makeUnionInPlace(1,3); + CHECK(dsf.findSet(1) == dsf.findSet(3)); +} + /* ************************************************************************* */ TEST(DSFVector, makeUnion2) { DSFVector dsf(3); @@ -67,6 +78,17 @@ TEST(DSFVector, sets) { CHECK(expected == sets[dsf.findSet(0)]); } +/* ************************************************************************* */ +TEST(DSFVector, arrays) { + DSFVector dsf(2); + dsf.makeUnionInPlace(0,1); + map > arrays = dsf.arrays(); + LONGS_EQUAL(1, arrays.size()); + + vector expected; expected += 0, 1; + CHECK(expected == arrays[dsf.findSet(0)]); +} + /* ************************************************************************* */ TEST(DSFVector, sets2) { DSFVector dsf(3); @@ -79,6 +101,18 @@ TEST(DSFVector, sets2) { CHECK(expected == sets[dsf.findSet(0)]); } +/* ************************************************************************* */ +TEST(DSFVector, arrays2) { + DSFVector dsf(3); + dsf.makeUnionInPlace(0,1); + dsf.makeUnionInPlace(1,2); + map > arrays = dsf.arrays(); + LONGS_EQUAL(1, arrays.size()); + + vector expected; expected += 0, 1, 2; + CHECK(expected == arrays[dsf.findSet(0)]); +} + /* ************************************************************************* */ TEST(DSFVector, sets3) { DSFVector dsf(3); @@ -90,6 +124,17 @@ TEST(DSFVector, sets3) { CHECK(expected == sets[dsf.findSet(0)]); } +/* ************************************************************************* */ +TEST(DSFVector, arrays3) { + DSFVector dsf(3); + dsf.makeUnionInPlace(0,1); + map > arrays = dsf.arrays(); + LONGS_EQUAL(2, arrays.size()); + + vector expected; expected += 0, 1; + CHECK(expected == arrays[dsf.findSet(0)]); +} + /* ************************************************************************* */ TEST(DSFVector, set) { DSFVector dsf(3);