diff --git a/cython/gtsam/tests/test_dsf_map.py b/cython/gtsam/tests/test_dsf_map.py new file mode 100644 index 000000000..6eb551034 --- /dev/null +++ b/cython/gtsam/tests/test_dsf_map.py @@ -0,0 +1,38 @@ +""" +GTSAM Copyright 2010-2019, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for Disjoint Set Forest. +Author: Frank Dellaert & Varun Agrawal +""" +# pylint: disable=invalid-name, no-name-in-module, no-member + +from __future__ import print_function + +import unittest + +import gtsam +from gtsam.utils.test_case import GtsamTestCase + + +class TestDSFMap(GtsamTestCase): + """Tests for DSFMap.""" + + def test_all(self): + """Test everything in DFSMap.""" + def key(index_pair): + return index_pair.i(), index_pair.j() + + dsf = gtsam.DSFMapIndexPair() + pair1 = gtsam.IndexPair(1, 18) + self.assertEqual(key(dsf.find(pair1)), key(pair1)) + pair2 = gtsam.IndexPair(2, 2) + dsf.merge(pair1, pair2) + self.assertTrue(dsf.find(pair1), dsf.find(pair1)) + + +if __name__ == '__main__': + unittest.main() diff --git a/gtsam.h b/gtsam.h index 73684c96a..aefe4e028 100644 --- a/gtsam.h +++ b/gtsam.h @@ -241,11 +241,28 @@ class FactorIndices { size_t back() const; void push_back(size_t factorIndex) const; }; + //************************************************************************* // base //************************************************************************* /** gtsam namespace functions */ + +#include +class IndexPair { + IndexPair(); + IndexPair(size_t i, size_t j); + size_t i() const; + size_t j() const; +}; + +template +class DSFMap { + DSFMap(); + KEY find(const KEY& key) const; + void merge(const KEY& x, const KEY& y); +}; + #include bool linear_independent(Matrix A, Matrix B, double tol); diff --git a/gtsam/base/DSFMap.h b/gtsam/base/DSFMap.h index 6ddb74cfd..dfce185dc 100644 --- a/gtsam/base/DSFMap.h +++ b/gtsam/base/DSFMap.h @@ -1,6 +1,6 @@ /* ---------------------------------------------------------------------------- - * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * GTSAM Copyright 2010, Georgia Tech Research Corporation, * Atlanta, Georgia 30332-0415 * All Rights Reserved * Authors: Frank Dellaert, et al. (see THANKS for the full author list) @@ -18,9 +18,9 @@ #pragma once +#include // Provides size_t #include #include -#include // Provides size_t namespace gtsam { @@ -29,11 +29,9 @@ namespace gtsam { * Uses rank compression and union by rank, iterator version * @addtogroup base */ -template +template class DSFMap { - -protected: - + protected: /// We store the forest in an STL map, but parents are done with pointers struct Entry { typename std::map::iterator parent_; @@ -41,7 +39,7 @@ protected: Entry() {} }; - typedef typename std::map Map; + typedef typename std::map Map; typedef typename Map::iterator iterator; mutable Map entries_; @@ -62,8 +60,7 @@ protected: iterator find_(const iterator& it) const { // follow parent pointers until we reach set representative iterator& parent = it->second.parent_; - if (parent != it) - parent = find_(parent); // not yet, recurse! + if (parent != it) parent = find_(parent); // not yet, recurse! return parent; } @@ -73,13 +70,11 @@ protected: return find_(initial); } -public: - + public: typedef std::set Set; /// constructor - DSFMap() { - } + DSFMap() {} /// Given key, find the representative key for the set in which it lives inline KEY find(const KEY& key) const { @@ -89,12 +84,10 @@ public: /// Merge two sets void merge(const KEY& x, const KEY& y) { - // straight from http://en.wikipedia.org/wiki/Disjoint-set_data_structure iterator xRoot = find_(x); iterator yRoot = find_(y); - if (xRoot == yRoot) - return; + if (xRoot == yRoot) return; // Merge sets if (xRoot->second.rank_ < yRoot->second.rank_) @@ -117,7 +110,14 @@ public: } return sets; } - }; -} +/// Small utility class for representing a wrappable pairs of ints. +class IndexPair : public std::pair { + public: + IndexPair(): std::pair(0,0) {} + IndexPair(size_t i, size_t j) : std::pair(i,j) {} + inline size_t i() const { return first; }; + inline size_t j() const { return second; }; +}; +} // namespace gtsam diff --git a/gtsam/base/tests/testDSFMap.cpp b/gtsam/base/tests/testDSFMap.cpp index 5ffa0d12a..8bc4cd7ca 100644 --- a/gtsam/base/tests/testDSFMap.cpp +++ b/gtsam/base/tests/testDSFMap.cpp @@ -139,6 +139,12 @@ TEST(DSFMap, sets){ EXPECT(s2 == sets[4]); } +/* ************************************************************************* */ +TEST(DSFMap, findIndexPair) { + DSFMap dsf; + EXPECT(dsf.find(IndexPair(1,2))==IndexPair(1,2)); + EXPECT(dsf.find(IndexPair(1,2)) != dsf.find(IndexPair(1,3))); +} /* ************************************************************************* */ int main() { TestResult tr; return TestRegistry::runAllTests(tr);} /* ************************************************************************* */