diff --git a/gtsam_unstable/nonlinear/Expression-inl.h b/gtsam_unstable/nonlinear/Expression-inl.h index 4ffb9afcd..45a932351 100644 --- a/gtsam_unstable/nonlinear/Expression-inl.h +++ b/gtsam_unstable/nonlinear/Expression-inl.h @@ -30,8 +30,9 @@ #include #include #include -#include -#include +#include +#include + namespace MPL = boost::mpl::placeholders; namespace gtsam { @@ -152,41 +153,52 @@ struct Select<2, A> { } }; +/** + * Record the evaluation of a single argument in a functional expression + */ +template +struct SingleTrace { + typedef Eigen::Matrix JacobianTA; + typename JacobianTrace::Pointer trace; + JacobianTA dTdA; +}; + /** * Recursive Trace Class for Functional Expressions * Abrahams, David; Gurtovoy, Aleksey (2004-12-10). * C++ Template Metaprogramming: Concepts, Tools, and Techniques from Boost - * and Beyond (Kindle Location 1244). Pearson Education. + * and Beyond. Pearson Education. */ template -struct Trace: More { - // define dimensions - static const size_t m = T::dimension; - static const size_t n = A::dimension; +struct Trace: SingleTrace, More { - // define fixed size Jacobian matrix types - typedef Eigen::Matrix JacobianTA; - typedef Eigen::Matrix Jacobian2T; + typename JacobianTrace::Pointer const & myTrace() const { + return static_cast*>(this)->trace; + } - // declare trace that produces value A, and corresponding Jacobian - typename JacobianTrace::Pointer trace; - JacobianTA dTdA; + typedef Eigen::Matrix JacobianTA; + const JacobianTA& myJacobian() const { + return static_cast*>(this)->dTdA; + } /// Start the reverse AD process virtual void startReverseAD(JacobianMap& jacobians) const { More::startReverseAD(jacobians); - Select::reverseAD(trace, dTdA, jacobians); + Select::reverseAD(myTrace(), myJacobian(), jacobians); } + /// Given df/dT, multiply in dT/dA and continue reverse AD process virtual void reverseAD(const Matrix& dFdT, JacobianMap& jacobians) const { More::reverseAD(dFdT, jacobians); - trace.reverseAD(dFdT * dTdA, jacobians); + myTrace().reverseAD(dFdT * myJacobian(), jacobians); } + /// Version specialized to 2-dimensional output + typedef Eigen::Matrix Jacobian2T; virtual void reverseAD2(const Jacobian2T& dFdT, JacobianMap& jacobians) const { More::reverseAD2(dFdT, jacobians); - trace.reverseAD2(dFdT * dTdA, jacobians); + myTrace().reverseAD2(dFdT * myJacobian(), jacobians); } }; @@ -195,6 +207,7 @@ template struct GenerateTrace { typedef typename boost::mpl::fold, Trace >::type type; + }; //----------------------------------------------------------------------------- @@ -545,39 +558,20 @@ public: } /// Trace structure for reverse AD - struct Trace: public JacobianTrace { - typename JacobianTrace::Pointer trace1; - typename JacobianTrace::Pointer trace2; - JacobianTA1 dTdA1; - JacobianTA2 dTdA2; - - /// Start the reverse AD process - virtual void startReverseAD(JacobianMap& jacobians) const { - Select::reverseAD(trace1, dTdA1, jacobians); - Select::reverseAD(trace2, dTdA2, jacobians); - } - /// Given df/dT, multiply in dT/dA and continue reverse AD process - virtual void reverseAD(const Matrix& dFdT, JacobianMap& jacobians) const { - trace1.reverseAD(dFdT * dTdA1, jacobians); - trace2.reverseAD(dFdT * dTdA2, jacobians); - } - /// Version specialized to 2-dimensional output - typedef Eigen::Matrix Jacobian2T; - virtual void reverseAD2(const Jacobian2T& dFdT, - JacobianMap& jacobians) const { - trace1.reverseAD2(dFdT * dTdA1, jacobians); - trace2.reverseAD2(dFdT * dTdA2, jacobians); - } - }; + typedef boost::mpl::vector Arguments; + typedef typename GenerateTrace::type Trace; /// Construct an execution trace for reverse AD virtual T traceExecution(const Values& values, typename JacobianTrace::Pointer& p) const { Trace* trace = new Trace(); p.setFunction(trace); - A1 a1 = this->expressionA1_->traceExecution(values, trace->trace1); - A2 a2 = this->expressionA2_->traceExecution(values, trace->trace2); - return function_(a1, a2, trace->dTdA1, trace->dTdA2); + A1 a1 = this->expressionA1_->traceExecution(values, + static_cast*>(trace)->trace); + A2 a2 = this->expressionA2_->traceExecution(values, + static_cast*>(trace)->trace); + return function_(a1, a2, static_cast*>(trace)->dTdA, + static_cast*>(trace)->dTdA); } };