diff --git a/gtsam/base/SpecialCommaInitializer.h b/gtsam/base/SpecialCommaInitializer.h new file mode 100644 index 000000000..fbcbd09cc --- /dev/null +++ b/gtsam/base/SpecialCommaInitializer.h @@ -0,0 +1,218 @@ +/** + * @file SpecialCommaInitializer.h + * @brief A special comma initializer for Eigen that is implicitly convertible to Vector and Matrix. + * @author Richard Roberts + * @created Oct 10, 2013 + */ + +#pragma once + +#include + +namespace Eigen { + namespace internal { + // Row-vectors not tested + //template + //inline void resizeHelper(XprType& xpr, DenseIndex sizeIncrement, + // typename boost::enable_if_c< + // XprType::ColsAtCompileTime == Dynamic && XprType::RowsAtCompileTime == 1>::type* = 0) + //{ + // xpr.conservativeResize(xpr.cols() + sizeIncrement); + //} + + template + inline void resizeHelper(XprType& xpr, typename XprType::Index sizeIncrement, + typename boost::enable_if_c< + XprType::RowsAtCompileTime == Dynamic && XprType::ColsAtCompileTime == 1>::type* = 0) + { + xpr.conservativeResize(xpr.rows() + sizeIncrement); + } + + template + inline void resizeHelper(XprType& xpr, typename XprType::Index sizeIncrement, + typename boost::enable_if_c< + XprType::ColsAtCompileTime == Dynamic>::type* = 0) + { + assert(false); + } + } + + /// A special comma initializer for Eigen that is implicitly convertible to Vector and Matrix. + template + class SpecialCommaInitializer : + public CommaInitializer, + public MatrixBase > + { + private: + bool dynamic_; + + public: + typedef MatrixBase > Base; + typedef CommaInitializer CommaBase; + + EIGEN_DENSE_PUBLIC_INTERFACE(SpecialCommaInitializer) + typedef typename internal::conditional::ret, + XprType, const XprType&>::type ExpressionTypeNested; + typedef typename XprType::InnerIterator InnerIterator; + + // Forward to base class + inline SpecialCommaInitializer(XprType& xpr, const typename XprType::Scalar& s, bool dynamic) : + CommaBase(xpr, s), dynamic_(dynamic) {} + + // Forward to base class + template + inline SpecialCommaInitializer(XprType& xpr, const DenseBase& other, bool dynamic) : + CommaBase(xpr, other), dynamic_(dynamic) {} + + inline Index rows() const { return CommaBase::m_xpr.rows(); } + inline Index cols() const { return CommaBase::m_xpr.cols(); } + inline Index outerStride() const { return CommaBase::m_xpr.outerStride(); } + inline Index innerStride() const { return CommaBase::m_xpr.innerStride(); } + + inline CoeffReturnType coeff(Index row, Index col) const + { + return CommaBase::m_xpr.coeff(row, col); + } + + inline CoeffReturnType coeff(Index index) const + { + return CommaBase::m_xpr.coeff(index); + } + + inline const Scalar& coeffRef(Index row, Index col) const + { + return CommaBase::m_xpr.const_cast_derived().coeffRef(row, col); + } + + inline const Scalar& coeffRef(Index index) const + { + return CommaBase::m_xpr.const_cast_derived().coeffRef(index); + } + + inline Scalar& coeffRef(Index row, Index col) + { + return CommaBase::m_xpr.const_cast_derived().coeffRef(row, col); + } + + inline Scalar& coeffRef(Index index) + { + return CommaBase::m_xpr.const_cast_derived().coeffRef(index); + } + + template + inline const PacketScalar packet(Index row, Index col) const + { + return CommaBase::m_xpr.template packet(row, col); + } + + template + inline void writePacket(Index row, Index col, const PacketScalar& x) + { + CommaBase::m_xpr.const_cast_derived().template writePacket(row, col, x); + } + + template + inline const PacketScalar packet(Index index) const + { + return CommaBase::m_xpr.template packet(index); + } + + template + inline void writePacket(Index index, const PacketScalar& x) + { + CommaBase::m_xpr.const_cast_derived().template writePacket(index, x); + } + + const XprType& _expression() const { return CommaBase::m_xpr; } + + /// Override base class comma operators to return this class instead of the base class. + SpecialCommaInitializer& operator,(const typename XprType::Scalar& s) + { + // If dynamic, resize the underlying object + if(dynamic_) + { + // Dynamic expansion currently only tested for column-vectors + assert(XprType::RowsAtCompileTime == Dynamic); + // Current col should be zero and row should be at the end + assert(CommaBase::m_col == 1); + assert(CommaBase::m_row == CommaBase::m_xpr.rows() - CommaBase::m_currentBlockRows); + resizeHelper(CommaBase::m_xpr, 1); + } + (void) CommaBase::operator,(s); + return *this; + } + + /// Override base class comma operators to return this class instead of the base class. + template + SpecialCommaInitializer& operator,(const DenseBase& other) + { + // If dynamic, resize the underlying object + if(dynamic_) + { + // Dynamic expansion currently only tested for column-vectors + assert(XprType::RowsAtCompileTime == Dynamic); + // Current col should be zero and row should be at the end + assert(CommaBase::m_col == 1); + assert(CommaBase::m_row == CommaBase::m_xpr.rows() - CommaBase::m_currentBlockRows); + resizeHelper(CommaBase::m_xpr, other.size()); + } + (void) CommaBase::operator,(other); + return *this; + } + }; + + namespace internal { + template + struct traits > : traits + { + }; + } + +} + +namespace gtsam { + class Vec + { + Eigen::VectorXd vector_; + bool dynamic_; + + public: + Vec(Eigen::VectorXd::Index size) : vector_(size), dynamic_(false) {} + + Vec() : dynamic_(true) {} + + Eigen::SpecialCommaInitializer operator<< (double s) + { + if(dynamic_) + vector_.resize(1); + return Eigen::SpecialCommaInitializer(vector_, s, dynamic_); + } + + template + Eigen::SpecialCommaInitializer operator<<(const Eigen::DenseBase& other) + { + if(dynamic_) + vector_.resize(other.size()); + return Eigen::SpecialCommaInitializer(vector_, other, dynamic_); + } + }; + + class Mat + { + Eigen::MatrixXd matrix_; + + public: + Mat(Eigen::MatrixXd::Index rows, Eigen::MatrixXd::Index cols) : matrix_(rows, cols) {} + + Eigen::SpecialCommaInitializer operator<< (double s) + { + return Eigen::SpecialCommaInitializer(matrix_, s, false); + } + + template + Eigen::SpecialCommaInitializer operator<<(const Eigen::DenseBase& other) + { + return Eigen::SpecialCommaInitializer(matrix_, other, false); + } + }; +} diff --git a/gtsam/base/Vector.h b/gtsam/base/Vector.h index 2c6066042..b56ba7a7c 100644 --- a/gtsam/base/Vector.h +++ b/gtsam/base/Vector.h @@ -25,6 +25,7 @@ #include #include #include +#include namespace gtsam { diff --git a/gtsam/base/tests/testMatrix.cpp b/gtsam/base/tests/testMatrix.cpp index c5fae2585..fd156d978 100644 --- a/gtsam/base/tests/testMatrix.cpp +++ b/gtsam/base/tests/testMatrix.cpp @@ -69,6 +69,51 @@ TEST( matrix, Matrix_ ) } +namespace { + /* ************************************************************************* */ + template + Matrix testFcn1(const Eigen::DenseBase& in) + { + return in; + } + + /* ************************************************************************* */ + template + Matrix testFcn2(const Eigen::MatrixBase& in) + { + return in; + } +} + +/* ************************************************************************* */ +TEST( matrix, special_comma_initializer) +{ + Matrix expected(2,2); + expected(0,0) = 1; + expected(0,1) = 2; + expected(1,0) = 3; + expected(1,1) = 4; + + Matrix actual1 = (Mat(2,2) << 1, 2, 3, 4); + Matrix actual2((Mat(2,2) << 1, 2, 3, 4)); + + Matrix submat1 = (Mat(1,2) << 3, 4); + Matrix actual3 = (Mat(2,2) << 1, 2, submat1); + + Matrix submat2 = (Mat(1,2) << 1, 2); + Matrix actual4 = (Mat(2,2) << submat2, 3, 4); + + Matrix actual5 = testFcn1((Mat(2,2) << 1, 2, 3, 4)); + Matrix actual6 = testFcn2((Mat(2,2) << 1, 2, 3, 4)); + + EXPECT(assert_equal(expected, actual1)); + EXPECT(assert_equal(expected, actual2)); + EXPECT(assert_equal(expected, actual3)); + EXPECT(assert_equal(expected, actual4)); + EXPECT(assert_equal(expected, actual5)); + EXPECT(assert_equal(expected, actual6)); +} + /* ************************************************************************* */ TEST( matrix, col_major ) { diff --git a/gtsam/base/tests/testVector.cpp b/gtsam/base/tests/testVector.cpp index fd9c1f92c..232d3bac7 100644 --- a/gtsam/base/tests/testVector.cpp +++ b/gtsam/base/tests/testVector.cpp @@ -23,112 +23,6 @@ using namespace std; using namespace gtsam; -#include - -// Row-vectors not tested -//template -//inline void resizeHelper(XprType& xpr, DenseIndex sizeIncrement, -// typename boost::enable_if_c< -// XprType::ColsAtCompileTime == Eigen::Dynamic && XprType::RowsAtCompileTime == 1>::type* = 0) -//{ -// xpr.conservativeResize(xpr.cols() + sizeIncrement); -//} - -template -inline void resizeHelper(XprType& xpr, DenseIndex sizeIncrement, - typename boost::enable_if_c< - XprType::RowsAtCompileTime == Eigen::Dynamic && XprType::ColsAtCompileTime == 1>::type* = 0) -{ - xpr.conservativeResize(xpr.rows() + sizeIncrement); -} - -/// A special comma initializer for Eigen that is implicitly convertible to Vector and Matrix. -template -class SpecialCommaInitializer : public Eigen::CommaInitializer -{ -private: - bool dynamic_; - -public: - typedef Eigen::CommaInitializer Base; - - // Forward to base class - inline SpecialCommaInitializer(XprType& xpr, const typename XprType::Scalar& s, bool dynamic) : - Base(xpr, s), dynamic_(dynamic) {} - - // Forward to base class - template - inline SpecialCommaInitializer(XprType& xpr, const Eigen::DenseBase& other, bool dynamic) : - Base(xpr, other), dynamic_(dynamic) {} - - /// Implicit conversion to expression type, e.g. Vector or Matrix - inline operator XprType () - { - return this->finished(); - } - - /// Override base class comma operators to return this class instead of the base class. - SpecialCommaInitializer& operator,(const typename XprType::Scalar& s) - { - // If dynamic, resize the underlying object - if(dynamic_) - { - // Dynamic expansion currently only tested for column-vectors - assert(XprType::RowsAtCompileTime == Eigen::Dynamic); - // Current col should be zero and row should be at the end - assert(Base::m_col == 1); - assert(Base::m_row == Base::m_xpr.rows() - Base::m_currentBlockRows); - resizeHelper(Base::m_xpr, 1); - } - (void) Base::operator,(s); - return *this; - } - - /// Override base class comma operators to return this class instead of the base class. - template - SpecialCommaInitializer& operator,(const Eigen::DenseBase& other) - { - // If dynamic, resize the underlying object - if(dynamic_) - { - // Dynamic expansion currently only tested for column-vectors - assert(XprType::RowsAtCompileTime == Eigen::Dynamic); - // Current col should be zero and row should be at the end - assert(Base::m_col == 1); - assert(Base::m_row == Base::m_xpr.rows() - Base::m_currentBlockRows); - resizeHelper(Base::m_xpr, other.size()); - } - (void) Base::operator,(other); - return *this; - } -}; - -class Vec -{ - Vector vector_; - bool dynamic_; - -public: - Vec(DenseIndex size) : vector_(size), dynamic_(false) {} - - Vec() : dynamic_(true) {} - - SpecialCommaInitializer operator<< (double s) - { - if(dynamic_) - vector_.resize(1); - return SpecialCommaInitializer(vector_, s, dynamic_); - } - - template - SpecialCommaInitializer operator<<(const Eigen::DenseBase& other) - { - if(dynamic_) - vector_.resize(other.size()); - return SpecialCommaInitializer(vector_, other, dynamic_); - } -}; - /* ************************************************************************* */ TEST( TestVector, Vector_variants ) { @@ -138,6 +32,22 @@ TEST( TestVector, Vector_variants ) EXPECT(assert_equal(a, b)); } +namespace { + /* ************************************************************************* */ + template + Vector testFcn1(const Eigen::DenseBase& in) + { + return in; + } + + /* ************************************************************************* */ + template + Vector testFcn2(const Eigen::MatrixBase& in) + { + return in; + } +} + /* ************************************************************************* */ TEST( TestVector, special_comma_initializer) { @@ -156,11 +66,16 @@ TEST( TestVector, special_comma_initializer) Vector subvec2 = (Vec() << 1, 2); Vector actual5 = (Vec() << subvec2, 3); + Vector actual6 = testFcn1((Vec() << 1, 2, 3)); + Vector actual7 = testFcn2((Vec() << 1, 2, 3)); + EXPECT(assert_equal(expected, actual1)); EXPECT(assert_equal(expected, actual2)); EXPECT(assert_equal(expected, actual3)); EXPECT(assert_equal(expected, actual4)); EXPECT(assert_equal(expected, actual5)); + EXPECT(assert_equal(expected, actual6)); + EXPECT(assert_equal(expected, actual7)); } /* ************************************************************************* */