| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  | /* ----------------------------------------------------------------------------
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |  * GTSAM Copyright 2010, Georgia Tech Research Corporation,  | 
					
						
							|  |  |  |  * Atlanta, Georgia 30332-0415 | 
					
						
							|  |  |  |  * All Rights Reserved | 
					
						
							|  |  |  |  * Authors: Frank Dellaert, et al. (see THANKS for the full author list) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |  * See LICENSE for the license information | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |  * -------------------------------------------------------------------------- */ | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /**
 | 
					
						
							|  |  |  |  * @file Expression-inl.h | 
					
						
							|  |  |  |  * @date September 18, 2014 | 
					
						
							|  |  |  |  * @author Frank Dellaert | 
					
						
							|  |  |  |  * @author Paul Furgale | 
					
						
							|  |  |  |  * @brief Internals for Expression.h, not for general consumption | 
					
						
							|  |  |  |  */ | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-09-30 19:00:37 +08:00
										 |  |  | #pragma once
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-09-30 18:29:57 +08:00
										 |  |  | #include <gtsam/nonlinear/Values.h>
 | 
					
						
							| 
									
										
										
										
											2014-10-01 17:25:49 +08:00
										 |  |  | #include <gtsam/base/Matrix.h>
 | 
					
						
							| 
									
										
										
										
											2014-10-11 16:27:30 +08:00
										 |  |  | #include <gtsam/base/Testable.h>
 | 
					
						
							| 
									
										
										
										
											2014-10-19 06:35:25 +08:00
										 |  |  | #include <gtsam/base/Manifold.h>
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  | #include <boost/foreach.hpp>
 | 
					
						
							| 
									
										
										
										
											2014-10-07 22:11:55 +08:00
										 |  |  | #include <boost/tuple/tuple.hpp>
 | 
					
						
							| 
									
										
										
										
											2014-10-11 14:17:46 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-10 17:41:01 +08:00
										 |  |  | // template meta-programming headers
 | 
					
						
							|  |  |  | #include <boost/mpl/vector.hpp>
 | 
					
						
							|  |  |  | #include <boost/mpl/plus.hpp>
 | 
					
						
							|  |  |  | #include <boost/mpl/front.hpp>
 | 
					
						
							|  |  |  | #include <boost/mpl/pop_front.hpp>
 | 
					
						
							|  |  |  | #include <boost/mpl/fold.hpp>
 | 
					
						
							| 
									
										
										
										
											2014-10-10 19:29:56 +08:00
										 |  |  | #include <boost/mpl/empty_base.hpp>
 | 
					
						
							|  |  |  | #include <boost/mpl/placeholders.hpp>
 | 
					
						
							| 
									
										
										
										
											2014-10-14 00:32:58 +08:00
										 |  |  | #include <boost/mpl/transform.hpp>
 | 
					
						
							|  |  |  | #include <boost/mpl/at.hpp>
 | 
					
						
							| 
									
										
										
										
											2014-10-10 18:38:26 +08:00
										 |  |  | namespace MPL = boost::mpl::placeholders; | 
					
						
							| 
									
										
										
										
											2014-10-10 17:41:01 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-14 23:46:57 +08:00
										 |  |  | #include <new> // for placement new
 | 
					
						
							| 
									
										
										
										
											2014-10-13 16:50:05 +08:00
										 |  |  | class ExpressionFactorBinaryTest; | 
					
						
							|  |  |  | // Forward declare for testing
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  | namespace gtsam { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template<typename T> | 
					
						
							|  |  |  | class Expression; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-14 21:43:41 +08:00
										 |  |  | typedef std::map<Key, Eigen::Block<Matrix> > JacobianMap; | 
					
						
							| 
									
										
										
										
											2014-10-13 02:16:08 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-08 19:58:15 +08:00
										 |  |  | //-----------------------------------------------------------------------------
 | 
					
						
							| 
									
										
										
										
											2014-10-10 23:45:39 +08:00
										 |  |  | /**
 | 
					
						
							|  |  |  |  * The CallRecord class stores the Jacobians of applying a function | 
					
						
							|  |  |  |  * with respect to each of its arguments. It also stores an executation trace | 
					
						
							|  |  |  |  * (defined below) for each of its arguments. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * It is sub-classed in the function-style ExpressionNode sub-classes below. | 
					
						
							|  |  |  |  */ | 
					
						
							| 
									
										
										
										
											2014-10-11 15:00:03 +08:00
										 |  |  | template<int COLS> | 
					
						
							| 
									
										
										
										
											2014-10-10 23:45:39 +08:00
										 |  |  | struct CallRecord { | 
					
						
							| 
									
										
										
										
											2014-10-13 04:17:21 +08:00
										 |  |  |   static size_t const N = 0; | 
					
						
							| 
									
										
										
										
											2014-10-11 21:20:12 +08:00
										 |  |  |   virtual void print(const std::string& indent) const { | 
					
						
							| 
									
										
										
										
											2014-10-08 19:58:15 +08:00
										 |  |  |   } | 
					
						
							| 
									
										
										
										
											2014-10-10 18:31:40 +08:00
										 |  |  |   virtual void startReverseAD(JacobianMap& jacobians) const { | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  |   virtual void reverseAD(const Matrix& dFdT, JacobianMap& jacobians) const { | 
					
						
							|  |  |  |   } | 
					
						
							| 
									
										
										
										
											2014-10-11 15:00:03 +08:00
										 |  |  |   typedef Eigen::Matrix<double, 2, COLS> Jacobian2T; | 
					
						
							| 
									
										
										
										
											2014-10-09 05:50:17 +08:00
										 |  |  |   virtual void reverseAD2(const Jacobian2T& dFdT, | 
					
						
							| 
									
										
										
										
											2014-10-10 18:31:40 +08:00
										 |  |  |       JacobianMap& jacobians) const { | 
					
						
							|  |  |  |   } | 
					
						
							| 
									
										
										
										
											2014-10-09 05:50:17 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-15 05:40:21 +08:00
										 |  |  | //-----------------------------------------------------------------------------
 | 
					
						
							|  |  |  | /// Handle Leaf Case: reverseAD ends here, by writing a matrix into Jacobians
 | 
					
						
							|  |  |  | template<int ROWS, int COLS> | 
					
						
							| 
									
										
										
										
											2014-10-19 06:35:25 +08:00
										 |  |  | void handleLeafCase(const Eigen::Matrix<double, ROWS, COLS>& dTdA, | 
					
						
							| 
									
										
										
										
											2014-10-15 05:40:21 +08:00
										 |  |  |     JacobianMap& jacobians, Key key) { | 
					
						
							|  |  |  |   JacobianMap::iterator it = jacobians.find(key); | 
					
						
							| 
									
										
										
										
											2014-10-19 06:35:25 +08:00
										 |  |  |   it->second.block<ROWS, COLS>(0, 0) += dTdA; // block makes HUGE difference
 | 
					
						
							| 
									
										
										
										
											2014-10-15 05:40:21 +08:00
										 |  |  | } | 
					
						
							|  |  |  | /// Handle Leaf Case for Dynamic Matrix type (slower)
 | 
					
						
							|  |  |  | template<> | 
					
						
							| 
									
										
										
										
											2014-10-19 06:35:25 +08:00
										 |  |  | void handleLeafCase( | 
					
						
							|  |  |  |     const Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic>& dTdA, | 
					
						
							| 
									
										
										
										
											2014-10-15 05:40:21 +08:00
										 |  |  |     JacobianMap& jacobians, Key key) { | 
					
						
							|  |  |  |   JacobianMap::iterator it = jacobians.find(key); | 
					
						
							|  |  |  |   it->second += dTdA; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-10 23:45:39 +08:00
										 |  |  | //-----------------------------------------------------------------------------
 | 
					
						
							|  |  |  | /**
 | 
					
						
							| 
									
										
										
										
											2014-10-19 17:19:09 +08:00
										 |  |  |  * The ExecutionTrace class records a tree-structured expression's execution. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * The class looks a bit complicated but it is so for performance. | 
					
						
							| 
									
										
										
										
											2014-10-10 23:45:39 +08:00
										 |  |  |  * It is a tagged union that obviates the need to create | 
					
						
							|  |  |  |  * a ExecutionTrace subclass for Constants and Leaf Expressions. Instead | 
					
						
							|  |  |  |  * the key for the leaf is stored in the space normally used to store a | 
					
						
							|  |  |  |  * CallRecord*. Nothing is stored for a Constant. | 
					
						
							| 
									
										
										
										
											2014-10-19 17:19:09 +08:00
										 |  |  |  * | 
					
						
							|  |  |  |  * A full execution trace of a Binary(Unary(Binary(Leaf,Constant)),Leaf) would be: | 
					
						
							|  |  |  |  * Trace(Function) -> | 
					
						
							|  |  |  |  *   BinaryRecord with two traces in it | 
					
						
							|  |  |  |  *     trace1(Function) -> | 
					
						
							|  |  |  |  *       UnaryRecord with one trace in it | 
					
						
							|  |  |  |  *         trace1(Function) -> | 
					
						
							|  |  |  |  *           BinaryRecord with two traces in it | 
					
						
							|  |  |  |  *             trace1(Leaf) | 
					
						
							|  |  |  |  *             trace2(Constant) | 
					
						
							|  |  |  |  *     trace2(Leaf) | 
					
						
							|  |  |  |  * Hence, there are three Record structs, written to memory by traceExecution | 
					
						
							| 
									
										
										
										
											2014-10-10 23:45:39 +08:00
										 |  |  |  */ | 
					
						
							|  |  |  | template<class T> | 
					
						
							|  |  |  | class ExecutionTrace { | 
					
						
							| 
									
										
										
										
											2014-10-21 07:26:17 +08:00
										 |  |  |   static const int Dim = traits::dimension<T>::value; | 
					
						
							| 
									
										
										
										
											2014-10-10 23:45:39 +08:00
										 |  |  |   enum { | 
					
						
							|  |  |  |     Constant, Leaf, Function | 
					
						
							| 
									
										
										
										
											2014-10-13 19:04:37 +08:00
										 |  |  |   } kind; | 
					
						
							| 
									
										
										
										
											2014-10-10 23:45:39 +08:00
										 |  |  |   union { | 
					
						
							|  |  |  |     Key key; | 
					
						
							| 
									
										
										
										
											2014-10-19 06:35:25 +08:00
										 |  |  |     CallRecord<Dim>* ptr; | 
					
						
							| 
									
										
										
										
											2014-10-10 23:45:39 +08:00
										 |  |  |   } content; | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  |   /// Pointer always starts out as a Constant
 | 
					
						
							|  |  |  |   ExecutionTrace() : | 
					
						
							| 
									
										
										
										
											2014-10-13 19:04:37 +08:00
										 |  |  |       kind(Constant) { | 
					
						
							| 
									
										
										
										
											2014-10-10 23:45:39 +08:00
										 |  |  |   } | 
					
						
							|  |  |  |   /// Change pointer to a Leaf Record
 | 
					
						
							|  |  |  |   void setLeaf(Key key) { | 
					
						
							| 
									
										
										
										
											2014-10-13 19:04:37 +08:00
										 |  |  |     kind = Leaf; | 
					
						
							| 
									
										
										
										
											2014-10-10 23:45:39 +08:00
										 |  |  |     content.key = key; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  |   /// Take ownership of pointer to a Function Record
 | 
					
						
							| 
									
										
										
										
											2014-10-19 06:35:25 +08:00
										 |  |  |   void setFunction(CallRecord<Dim>* record) { | 
					
						
							| 
									
										
										
										
											2014-10-13 19:04:37 +08:00
										 |  |  |     kind = Function; | 
					
						
							| 
									
										
										
										
											2014-10-10 23:45:39 +08:00
										 |  |  |     content.ptr = record; | 
					
						
							|  |  |  |   } | 
					
						
							| 
									
										
										
										
											2014-10-11 16:27:30 +08:00
										 |  |  |   /// Print
 | 
					
						
							| 
									
										
										
										
											2014-10-11 17:03:35 +08:00
										 |  |  |   void print(const std::string& indent = "") const { | 
					
						
							| 
									
										
										
										
											2014-10-13 19:04:37 +08:00
										 |  |  |     if (kind == Constant) | 
					
						
							| 
									
										
										
										
											2014-10-11 17:03:35 +08:00
										 |  |  |       std::cout << indent << "Constant" << std::endl; | 
					
						
							| 
									
										
										
										
											2014-10-13 19:04:37 +08:00
										 |  |  |     else if (kind == Leaf) | 
					
						
							| 
									
										
										
										
											2014-10-11 17:03:35 +08:00
										 |  |  |       std::cout << indent << "Leaf, key = " << content.key << std::endl; | 
					
						
							| 
									
										
										
										
											2014-10-13 19:04:37 +08:00
										 |  |  |     else if (kind == Function) { | 
					
						
							| 
									
										
										
										
											2014-10-11 17:03:35 +08:00
										 |  |  |       std::cout << indent << "Function" << std::endl; | 
					
						
							|  |  |  |       content.ptr->print(indent + "  "); | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2014-10-11 16:27:30 +08:00
										 |  |  |   } | 
					
						
							| 
									
										
										
										
											2014-10-11 14:52:24 +08:00
										 |  |  |   /// Return record pointer, quite unsafe, used only for testing
 | 
					
						
							|  |  |  |   template<class Record> | 
					
						
							|  |  |  |   boost::optional<Record*> record() { | 
					
						
							| 
									
										
										
										
											2014-10-13 19:04:37 +08:00
										 |  |  |     if (kind != Function) | 
					
						
							| 
									
										
										
										
											2014-10-11 14:52:24 +08:00
										 |  |  |       return boost::none; | 
					
						
							|  |  |  |     else { | 
					
						
							|  |  |  |       Record* p = dynamic_cast<Record*>(content.ptr); | 
					
						
							|  |  |  |       return p ? boost::optional<Record*>(p) : boost::none; | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2014-10-11 14:41:39 +08:00
										 |  |  |   } | 
					
						
							| 
									
										
										
										
											2014-10-14 21:43:41 +08:00
										 |  |  |   /**
 | 
					
						
							|  |  |  |    *  *** This is the main entry point for reverseAD, called from Expression *** | 
					
						
							|  |  |  |    * Called only once, either inserts I into Jacobians (Leaf) or starts AD (Function) | 
					
						
							|  |  |  |    */ | 
					
						
							| 
									
										
										
										
											2014-10-19 06:35:25 +08:00
										 |  |  |   typedef Eigen::Matrix<double, Dim, Dim> JacobianTT; | 
					
						
							| 
									
										
										
										
											2014-10-10 23:45:39 +08:00
										 |  |  |   void startReverseAD(JacobianMap& jacobians) const { | 
					
						
							| 
									
										
										
										
											2014-10-13 19:04:37 +08:00
										 |  |  |     if (kind == Leaf) { | 
					
						
							| 
									
										
										
										
											2014-10-10 23:45:39 +08:00
										 |  |  |       // This branch will only be called on trivial Leaf expressions, i.e. Priors
 | 
					
						
							| 
									
										
										
										
											2014-10-15 05:40:21 +08:00
										 |  |  |       static const JacobianTT I = JacobianTT::Identity(); | 
					
						
							|  |  |  |       handleLeafCase(I, jacobians, content.key); | 
					
						
							| 
									
										
										
										
											2014-10-13 19:04:37 +08:00
										 |  |  |     } else if (kind == Function) | 
					
						
							| 
									
										
										
										
											2014-10-10 23:45:39 +08:00
										 |  |  |       // This is the more typical entry point, starting the AD pipeline
 | 
					
						
							| 
									
										
										
										
											2014-10-14 21:43:41 +08:00
										 |  |  |       // Inside the startReverseAD that the correctly dimensioned pipeline is chosen.
 | 
					
						
							| 
									
										
										
										
											2014-10-10 23:45:39 +08:00
										 |  |  |       content.ptr->startReverseAD(jacobians); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  |   // Either add to Jacobians (Leaf) or propagate (Function)
 | 
					
						
							|  |  |  |   void reverseAD(const Matrix& dTdA, JacobianMap& jacobians) const { | 
					
						
							| 
									
										
										
										
											2014-10-14 21:43:41 +08:00
										 |  |  |     if (kind == Leaf) | 
					
						
							|  |  |  |       handleLeafCase(dTdA, jacobians, content.key); | 
					
						
							|  |  |  |     else if (kind == Function) | 
					
						
							| 
									
										
										
										
											2014-10-10 23:45:39 +08:00
										 |  |  |       content.ptr->reverseAD(dTdA, jacobians); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  |   // Either add to Jacobians (Leaf) or propagate (Function)
 | 
					
						
							| 
									
										
										
										
											2014-10-19 06:35:25 +08:00
										 |  |  |   typedef Eigen::Matrix<double, 2, Dim> Jacobian2T; | 
					
						
							| 
									
										
										
										
											2014-10-10 23:45:39 +08:00
										 |  |  |   void reverseAD2(const Jacobian2T& dTdA, JacobianMap& jacobians) const { | 
					
						
							| 
									
										
										
										
											2014-10-14 21:43:41 +08:00
										 |  |  |     if (kind == Leaf) | 
					
						
							|  |  |  |       handleLeafCase(dTdA, jacobians, content.key); | 
					
						
							|  |  |  |     else if (kind == Function) | 
					
						
							| 
									
										
										
										
											2014-10-10 23:45:39 +08:00
										 |  |  |       content.ptr->reverseAD2(dTdA, jacobians); | 
					
						
							|  |  |  |   } | 
					
						
							| 
									
										
										
										
											2014-10-13 19:04:37 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   /// Define type so we can apply it as a meta-function
 | 
					
						
							|  |  |  |   typedef ExecutionTrace<T> type; | 
					
						
							| 
									
										
										
										
											2014-10-10 23:45:39 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-09 06:21:10 +08:00
										 |  |  | /// Primary template calls the generic Matrix reverseAD pipeline
 | 
					
						
							| 
									
										
										
										
											2014-10-13 15:25:06 +08:00
										 |  |  | template<size_t ROWS, class A> | 
					
						
							| 
									
										
										
										
											2014-10-09 05:50:17 +08:00
										 |  |  | struct Select { | 
					
						
							| 
									
										
										
										
											2014-10-21 07:26:17 +08:00
										 |  |  |   typedef Eigen::Matrix<double, ROWS, traits::dimension<A>::value> Jacobian; | 
					
						
							| 
									
										
										
										
											2014-10-10 23:45:39 +08:00
										 |  |  |   static void reverseAD(const ExecutionTrace<A>& trace, const Jacobian& dTdA, | 
					
						
							|  |  |  |       JacobianMap& jacobians) { | 
					
						
							| 
									
										
										
										
											2014-10-09 05:50:17 +08:00
										 |  |  |     trace.reverseAD(dTdA, jacobians); | 
					
						
							|  |  |  |   } | 
					
						
							| 
									
										
										
										
											2014-10-08 19:58:15 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-09 06:21:10 +08:00
										 |  |  | /// Partially specialized template calls the 2-dimensional output version
 | 
					
						
							| 
									
										
										
										
											2014-10-09 05:50:17 +08:00
										 |  |  | template<class A> | 
					
						
							|  |  |  | struct Select<2, A> { | 
					
						
							| 
									
										
										
										
											2014-10-21 07:26:17 +08:00
										 |  |  |   typedef Eigen::Matrix<double, 2, traits::dimension<A>::value> Jacobian; | 
					
						
							| 
									
										
										
										
											2014-10-10 23:45:39 +08:00
										 |  |  |   static void reverseAD(const ExecutionTrace<A>& trace, const Jacobian& dTdA, | 
					
						
							|  |  |  |       JacobianMap& jacobians) { | 
					
						
							| 
									
										
										
										
											2014-10-09 05:50:17 +08:00
										 |  |  |     trace.reverseAD2(dTdA, jacobians); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | }; | 
					
						
							| 
									
										
										
										
											2014-10-08 19:58:15 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  | //-----------------------------------------------------------------------------
 | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  | /**
 | 
					
						
							|  |  |  |  * Expression node. The superclass for objects that do the heavy lifting | 
					
						
							|  |  |  |  * An Expression<T> has a pointer to an ExpressionNode<T> underneath | 
					
						
							|  |  |  |  * allowing Expressions to have polymorphic behaviour even though they | 
					
						
							|  |  |  |  * are passed by value. This is the same way boost::function works. | 
					
						
							|  |  |  |  * http://loki-lib.sourceforge.net/html/a00652.html
 | 
					
						
							|  |  |  |  */ | 
					
						
							|  |  |  | template<class T> | 
					
						
							|  |  |  | class ExpressionNode { | 
					
						
							| 
									
										
										
										
											2014-09-30 19:00:37 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  | protected: | 
					
						
							| 
									
										
										
										
											2014-10-11 18:11:22 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-13 00:52:12 +08:00
										 |  |  |   size_t traceSize_; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Constructor, traceSize is size of the execution trace of expression rooted here
 | 
					
						
							|  |  |  |   ExpressionNode(size_t traceSize = 0) : | 
					
						
							|  |  |  |       traceSize_(traceSize) { | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  |   } | 
					
						
							| 
									
										
										
										
											2014-09-30 19:00:37 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  | public: | 
					
						
							| 
									
										
										
										
											2014-09-30 19:00:37 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   /// Destructor
 | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  |   virtual ~ExpressionNode() { | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Return keys that play in this expression as a set
 | 
					
						
							| 
									
										
										
										
											2014-10-13 06:37:46 +08:00
										 |  |  |   virtual std::set<Key> keys() const { | 
					
						
							|  |  |  |     std::set<Key> keys; | 
					
						
							|  |  |  |     return keys; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-14 23:46:57 +08:00
										 |  |  |   /// Return dimensions for each argument, as a map
 | 
					
						
							| 
									
										
										
										
											2014-10-16 18:01:20 +08:00
										 |  |  |   virtual void dims(std::map<Key, size_t>& map) const { | 
					
						
							| 
									
										
										
										
											2014-10-14 23:46:57 +08:00
										 |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-13 00:52:12 +08:00
										 |  |  |   // Return size needed for memory buffer in traceExecution
 | 
					
						
							|  |  |  |   size_t traceSize() const { | 
					
						
							|  |  |  |     return traceSize_; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  |   /// Return value
 | 
					
						
							|  |  |  |   virtual T value(const Values& values) const = 0; | 
					
						
							| 
									
										
										
										
											2014-10-03 05:28:19 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-05 17:22:14 +08:00
										 |  |  |   /// Construct an execution trace for reverse AD
 | 
					
						
							| 
									
										
										
										
											2014-10-11 17:03:35 +08:00
										 |  |  |   virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace, | 
					
						
							| 
									
										
										
										
											2014-10-11 19:07:58 +08:00
										 |  |  |       char* raw) const = 0; | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  | //-----------------------------------------------------------------------------
 | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  | /// Constant Expression
 | 
					
						
							|  |  |  | template<class T> | 
					
						
							|  |  |  | class ConstantExpression: public ExpressionNode<T> { | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  |   /// The constant value
 | 
					
						
							|  |  |  |   T constant_; | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   /// Constructor with a value, yielding a constant
 | 
					
						
							|  |  |  |   ConstantExpression(const T& value) : | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  |       constant_(value) { | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   friend class Expression<T> ; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  |   /// Return value
 | 
					
						
							|  |  |  |   virtual T value(const Values& values) const { | 
					
						
							|  |  |  |     return constant_; | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  |   } | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-05 17:22:14 +08:00
										 |  |  |   /// Construct an execution trace for reverse AD
 | 
					
						
							| 
									
										
										
										
											2014-10-11 17:03:35 +08:00
										 |  |  |   virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace, | 
					
						
							| 
									
										
										
										
											2014-10-11 19:07:58 +08:00
										 |  |  |       char* raw) const { | 
					
						
							| 
									
										
										
										
											2014-10-11 17:03:35 +08:00
										 |  |  |     return constant_; | 
					
						
							| 
									
										
										
										
											2014-10-05 17:22:14 +08:00
										 |  |  |   } | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | //-----------------------------------------------------------------------------
 | 
					
						
							|  |  |  | /// Leaf Expression
 | 
					
						
							|  |  |  | template<class T> | 
					
						
							|  |  |  | class LeafExpression: public ExpressionNode<T> { | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  |   /// The key into values
 | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  |   Key key_; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Constructor with a single key
 | 
					
						
							|  |  |  |   LeafExpression(Key key) : | 
					
						
							|  |  |  |       key_(key) { | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   friend class Expression<T> ; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Return keys that play in this expression
 | 
					
						
							|  |  |  |   virtual std::set<Key> keys() const { | 
					
						
							|  |  |  |     std::set<Key> keys; | 
					
						
							|  |  |  |     keys.insert(key_); | 
					
						
							|  |  |  |     return keys; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-14 15:53:47 +08:00
										 |  |  |   /// Return dimensions for each argument
 | 
					
						
							| 
									
										
										
										
											2014-10-16 18:01:20 +08:00
										 |  |  |   virtual void dims(std::map<Key, size_t>& map) const { | 
					
						
							| 
									
										
										
										
											2014-10-21 07:26:17 +08:00
										 |  |  |     map[key_] = traits::dimension<T>::value; | 
					
						
							| 
									
										
										
										
											2014-10-14 15:53:47 +08:00
										 |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  |   /// Return value
 | 
					
						
							|  |  |  |   virtual T value(const Values& values) const { | 
					
						
							|  |  |  |     return values.at<T>(key_); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-05 19:37:51 +08:00
										 |  |  |   /// Construct an execution trace for reverse AD
 | 
					
						
							| 
									
										
										
										
											2014-10-11 17:03:35 +08:00
										 |  |  |   virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace, | 
					
						
							| 
									
										
										
										
											2014-10-11 19:07:58 +08:00
										 |  |  |       char* raw) const { | 
					
						
							| 
									
										
										
										
											2014-10-11 16:27:30 +08:00
										 |  |  |     trace.setLeaf(key_); | 
					
						
							| 
									
										
										
										
											2014-10-11 17:03:35 +08:00
										 |  |  |     return values.at<T>(key_); | 
					
						
							| 
									
										
										
										
											2014-10-05 19:37:51 +08:00
										 |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-13 05:03:33 +08:00
										 |  |  | //-----------------------------------------------------------------------------
 | 
					
						
							| 
									
										
										
										
											2014-10-13 14:49:12 +08:00
// Below we use the "Class Composition" technique described in the book
//   C++ Template Metaprogramming: Concepts, Tools, and Techniques from Boost
//   and Beyond. Abrahams, David; Gurtovoy, Aleksey. Pearson Education.
// to recursively generate a class that will be the base for function nodes.
// For three arguments A1, A2, and A3, the generated classes are:
 | 
					
						
							| 
									
										
										
										
											2014-10-13 14:49:12 +08:00
										 |  |  | //
 | 
					
						
							| 
									
										
										
										
											2014-10-13 16:10:46 +08:00
										 |  |  | // struct Base1 : Argument<T,A1,1>, FunctionalBase<T> {
 | 
					
						
							| 
									
										
										
										
											2014-10-13 14:49:12 +08:00
										 |  |  | //   ... storage related to A1 ...
 | 
					
						
							|  |  |  | //   ... methods that work on A1 ...
 | 
					
						
							|  |  |  | // };
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | // struct Base2 : Argument<T,A2,2>, Base1 {
 | 
					
						
							|  |  |  | //   ... storage related to A2 ...
 | 
					
						
							| 
									
										
										
										
											2014-10-19 05:48:51 +08:00
										 |  |  | //   ... methods that work on A2 and (recursively) on A1 ...
 | 
					
						
							| 
									
										
										
										
											2014-10-13 14:49:12 +08:00
										 |  |  | // };
 | 
					
						
							|  |  |  | //
 | 
					
						
							| 
									
										
										
										
											2014-10-19 05:48:51 +08:00
										 |  |  | // struct Base3 : Argument<T,A3,3>, Base2 {
 | 
					
						
							| 
									
										
										
										
											2014-10-13 14:49:12 +08:00
										 |  |  | //   ... storage related to A3 ...
 | 
					
						
							| 
									
										
										
										
											2014-10-19 05:48:51 +08:00
										 |  |  | //   ... methods that work on A3 and (recursively) on A2 and A1 ...
 | 
					
						
							| 
									
										
										
										
											2014-10-13 14:49:12 +08:00
										 |  |  | // };
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | // struct FunctionalNode : Base3 {
 | 
					
						
							|  |  |  | //   Provides convenience access to storage in hierarchy by using
 | 
					
						
							|  |  |  | //   static_cast<Argument<T, A, N> &>(*this)
 | 
					
						
// };
 | 
					
						
							|  |  |  | //
 | 
					
						
// All this magic happens when we generate Base3, the base class of FunctionalNode,
// by invoking boost::mpl::fold over the meta-function GenerateFunctionalNode.
 | 
					
						
							|  |  |  | //-----------------------------------------------------------------------------
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-13 16:50:05 +08:00
										 |  |  | /// meta-function to generate fixed-size JacobianTA type
 | 
					
						
							|  |  |  | template<class T, class A> | 
					
						
							|  |  |  | struct Jacobian { | 
					
						
							| 
									
										
										
										
											2014-10-21 07:26:17 +08:00
										 |  |  |   typedef Eigen::Matrix<double, traits::dimension<T>::value, | 
					
						
							|  |  |  |       traits::dimension<A>::value> type; | 
					
						
							| 
									
										
										
										
											2014-10-14 00:32:58 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /// meta-function to generate JacobianTA optional reference
 | 
					
						
							|  |  |  | template<class T, class A> | 
					
						
							|  |  |  | struct Optional { | 
					
						
							| 
									
										
										
										
											2014-10-21 07:26:17 +08:00
										 |  |  |   typedef Eigen::Matrix<double, traits::dimension<T>::value, | 
					
						
							|  |  |  |       traits::dimension<A>::value> Jacobian; | 
					
						
							| 
									
										
										
										
											2014-10-14 00:32:58 +08:00
										 |  |  |   typedef boost::optional<Jacobian&> type; | 
					
						
							| 
									
										
										
										
											2014-10-13 16:50:05 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-13 05:03:33 +08:00
										 |  |  | /**
 | 
					
						
							| 
									
										
										
										
											2014-10-13 16:10:46 +08:00
										 |  |  |  * Base case for recursive FunctionalNode class | 
					
						
							|  |  |  |  */ | 
					
						
							|  |  |  | template<class T> | 
					
						
							|  |  |  | struct FunctionalBase: ExpressionNode<T> { | 
					
						
							|  |  |  |   static size_t const N = 0; // number of arguments
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-21 07:26:17 +08:00
										 |  |  |   typedef CallRecord<traits::dimension<T>::value> Record; | 
					
						
							| 
									
										
										
										
											2014-10-13 16:10:46 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   /// Construct an execution trace for reverse AD
 | 
					
						
							| 
									
										
										
										
											2014-10-13 16:50:05 +08:00
										 |  |  |   void trace(const Values& values, Record* record, char*& raw) const { | 
					
						
							| 
									
										
										
										
											2014-10-13 16:10:46 +08:00
										 |  |  |   } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /**
 | 
					
						
							|  |  |  |  * Building block for recursive FunctionalNode class | 
					
						
							| 
									
										
										
										
											2014-10-13 14:49:12 +08:00
										 |  |  |  * The integer argument N is to guarantee a unique type signature, | 
					
						
							|  |  |  |  * so we are guaranteed to be able to extract their values by static cast. | 
					
						
							| 
									
										
										
										
											2014-10-13 05:03:33 +08:00
										 |  |  |  */ | 
					
						
							|  |  |  | template<class T, class A, size_t N> | 
					
						
							|  |  |  | struct Argument { | 
					
						
							| 
									
										
										
										
											2014-10-13 14:49:12 +08:00
										 |  |  |   /// Expression that will generate value/derivatives for argument
 | 
					
						
							| 
									
										
										
										
											2014-10-13 05:31:48 +08:00
										 |  |  |   boost::shared_ptr<ExpressionNode<A> > expression; | 
					
						
							| 
									
										
										
										
											2014-10-13 05:03:33 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-13 16:50:05 +08:00
										 |  |  | /**
 | 
					
						
							|  |  |  |  * Building block for Recursive Record Class | 
					
						
							|  |  |  |  * Records the evaluation of a single argument in a functional expression | 
					
						
							|  |  |  |  */ | 
					
						
							|  |  |  | template<class T, class A, size_t N> | 
					
						
							|  |  |  | struct JacobianTrace { | 
					
						
							| 
									
										
										
										
											2014-10-13 17:37:47 +08:00
										 |  |  |   A value; | 
					
						
							| 
									
										
										
										
											2014-10-13 16:50:05 +08:00
										 |  |  |   ExecutionTrace<A> trace; | 
					
						
							|  |  |  |   typename Jacobian<T, A>::type dTdA; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-13 05:03:33 +08:00
										 |  |  | /**
 | 
					
						
							|  |  |  |  * Recursive Definition of Functional ExpressionNode | 
					
						
							|  |  |  |  */ | 
					
						
							|  |  |  | template<class T, class A, class Base> | 
					
						
							| 
									
										
										
										
											2014-10-13 05:31:48 +08:00
										 |  |  | struct GenerateFunctionalNode: Argument<T, A, Base::N + 1>, Base { | 
					
						
							| 
									
										
										
										
											2014-10-13 05:03:33 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-13 14:49:12 +08:00
										 |  |  |   static size_t const N = Base::N + 1; ///< Number of arguments in hierarchy
 | 
					
						
							|  |  |  |   typedef Argument<T, A, N> This; ///< The storage we have direct access to
 | 
					
						
							| 
									
										
										
										
											2014-10-13 05:03:33 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-13 06:37:46 +08:00
										 |  |  |   /// Return keys that play in this expression
 | 
					
						
							|  |  |  |   virtual std::set<Key> keys() const { | 
					
						
							|  |  |  |     std::set<Key> keys = Base::keys(); | 
					
						
							|  |  |  |     std::set<Key> myKeys = This::expression->keys(); | 
					
						
							|  |  |  |     keys.insert(myKeys.begin(), myKeys.end()); | 
					
						
							|  |  |  |     return keys; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-14 15:53:47 +08:00
										 |  |  |   /// Return dimensions for each argument
 | 
					
						
							| 
									
										
										
										
											2014-10-16 18:01:20 +08:00
										 |  |  |   virtual void dims(std::map<Key, size_t>& map) const { | 
					
						
							|  |  |  |     Base::dims(map); | 
					
						
							|  |  |  |     This::expression->dims(map); | 
					
						
							| 
									
										
										
										
											2014-10-14 15:53:47 +08:00
										 |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-13 16:50:05 +08:00
										 |  |  |   /// Recursive Record Class for Functional Expressions
 | 
					
						
							|  |  |  |   struct Record: JacobianTrace<T, A, N>, Base::Record { | 
					
						
							| 
									
										
										
										
											2014-10-13 16:10:46 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     typedef T return_type; | 
					
						
							|  |  |  |     typedef JacobianTrace<T, A, N> This; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     /// Print to std::cout
 | 
					
						
							|  |  |  |     virtual void print(const std::string& indent) const { | 
					
						
							| 
									
										
										
										
											2014-10-13 16:50:05 +08:00
										 |  |  |       Base::Record::print(indent); | 
					
						
							| 
									
										
										
										
											2014-10-13 16:10:46 +08:00
										 |  |  |       static const Eigen::IOFormat matlab(0, 1, " ", "; ", "", "", "[", "]"); | 
					
						
							|  |  |  |       std::cout << This::dTdA.format(matlab) << std::endl; | 
					
						
							|  |  |  |       This::trace.print(indent); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     /// Start the reverse AD process
 | 
					
						
							|  |  |  |     virtual void startReverseAD(JacobianMap& jacobians) const { | 
					
						
							| 
									
										
										
										
											2014-10-13 16:50:05 +08:00
										 |  |  |       Base::Record::startReverseAD(jacobians); | 
					
						
							| 
									
										
										
										
											2014-10-21 07:26:17 +08:00
										 |  |  |       Select<traits::dimension<T>::value, A>::reverseAD(This::trace, This::dTdA, | 
					
						
							| 
									
										
										
										
											2014-10-19 06:35:25 +08:00
										 |  |  |           jacobians); | 
					
						
							| 
									
										
										
										
											2014-10-13 16:10:46 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     /// Given df/dT, multiply in dT/dA and continue reverse AD process
 | 
					
						
							|  |  |  |     virtual void reverseAD(const Matrix& dFdT, JacobianMap& jacobians) const { | 
					
						
							| 
									
										
										
										
											2014-10-13 16:50:05 +08:00
										 |  |  |       Base::Record::reverseAD(dFdT, jacobians); | 
					
						
							| 
									
										
										
										
											2014-10-13 16:10:46 +08:00
										 |  |  |       This::trace.reverseAD(dFdT * This::dTdA, jacobians); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     /// Version specialized to 2-dimensional output
 | 
					
						
							| 
									
										
										
										
											2014-10-21 07:26:17 +08:00
										 |  |  |     typedef Eigen::Matrix<double, 2, traits::dimension<T>::value> Jacobian2T; | 
					
						
							| 
									
										
										
										
											2014-10-13 16:10:46 +08:00
										 |  |  |     virtual void reverseAD2(const Jacobian2T& dFdT, | 
					
						
							|  |  |  |         JacobianMap& jacobians) const { | 
					
						
							| 
									
										
										
										
											2014-10-13 16:50:05 +08:00
										 |  |  |       Base::Record::reverseAD2(dFdT, jacobians); | 
					
						
							| 
									
										
										
										
											2014-10-13 16:10:46 +08:00
										 |  |  |       This::trace.reverseAD2(dFdT * This::dTdA, jacobians); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |   }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Construct an execution trace for reverse AD
 | 
					
						
							| 
									
										
										
										
											2014-10-13 16:50:05 +08:00
										 |  |  |   void trace(const Values& values, Record* record, char*& raw) const { | 
					
						
							| 
									
										
										
										
											2014-10-13 17:55:16 +08:00
										 |  |  |     Base::trace(values, record, raw); // recurse
 | 
					
						
							|  |  |  |     // Write an Expression<A> execution trace in record->trace
 | 
					
						
							|  |  |  |     // Iff Constant or Leaf, this will not write to raw, only to trace.
 | 
					
						
							|  |  |  |     // Iff the expression is functional, write all Records in raw buffer
 | 
					
						
							|  |  |  |     // Return value of type T is recorded in record->value
 | 
					
						
							| 
									
										
										
										
											2014-10-13 17:37:47 +08:00
										 |  |  |     record->Record::This::value = This::expression->traceExecution(values, | 
					
						
							|  |  |  |         record->Record::This::trace, raw); | 
					
						
							| 
									
										
										
										
											2014-10-13 17:55:16 +08:00
										 |  |  |     // raw is never modified by traceExecution, but if traceExecution has
 | 
					
						
							|  |  |  |     // written in the buffer, the next caller expects we advance the pointer
 | 
					
						
							|  |  |  |     raw += This::expression->traceSize(); | 
					
						
							| 
									
										
										
										
											2014-10-13 16:10:46 +08:00
										 |  |  |   } | 
					
						
							| 
									
										
										
										
											2014-10-13 05:03:33 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-13 14:49:12 +08:00
										 |  |  | /**
 | 
					
						
							|  |  |  |  *  Recursive GenerateFunctionalNode class Generator | 
					
						
							|  |  |  |  */ | 
					
						
							| 
									
										
										
										
											2014-10-13 05:03:33 +08:00
										 |  |  | template<class T, class TYPES> | 
					
						
							| 
									
										
										
										
											2014-10-13 16:10:46 +08:00
										 |  |  | struct FunctionalNode { | 
					
						
							| 
									
										
										
										
											2014-10-14 00:32:58 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-13 16:10:46 +08:00
										 |  |  |   typedef typename boost::mpl::fold<TYPES, FunctionalBase<T>, | 
					
						
							|  |  |  |       GenerateFunctionalNode<T, MPL::_2, MPL::_1> >::type Base; | 
					
						
							| 
									
										
										
										
											2014-10-13 05:03:33 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-13 16:10:46 +08:00
										 |  |  |   struct type: public Base { | 
					
						
							| 
									
										
										
										
											2014-10-13 05:03:33 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-14 00:32:58 +08:00
										 |  |  |     // Argument types and derived, note these are base 0 !
 | 
					
						
							|  |  |  |     typedef TYPES Arguments; | 
					
						
							|  |  |  |     typedef typename boost::mpl::transform<TYPES, Jacobian<T, MPL::_1> >::type Jacobians; | 
					
						
							|  |  |  |     typedef typename boost::mpl::transform<TYPES, Optional<T, MPL::_1> >::type Optionals; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-13 16:10:46 +08:00
										 |  |  |     /// Reset expression shared pointer
 | 
					
						
							|  |  |  |     template<class A, size_t N> | 
					
						
							|  |  |  |     void reset(const boost::shared_ptr<ExpressionNode<A> >& ptr) { | 
					
						
							|  |  |  |       static_cast<Argument<T, A, N> &>(*this).expression = ptr; | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2014-10-13 06:31:03 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-13 16:10:46 +08:00
										 |  |  |     /// Access Expression shared pointer
 | 
					
						
							|  |  |  |     template<class A, size_t N> | 
					
						
							|  |  |  |     boost::shared_ptr<ExpressionNode<A> > expression() const { | 
					
						
							|  |  |  |       return static_cast<Argument<T, A, N> const &>(*this).expression; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-13 16:50:05 +08:00
										 |  |  |     /// Provide convenience access to Record storage
 | 
					
						
							|  |  |  |     struct Record: public Base::Record { | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-13 17:37:47 +08:00
										 |  |  |       /// Access Value
 | 
					
						
							|  |  |  |       template<class A, size_t N> | 
					
						
							|  |  |  |       const A& value() const { | 
					
						
							|  |  |  |         return static_cast<JacobianTrace<T, A, N> const &>(*this).value; | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-13 16:50:05 +08:00
										 |  |  |       /// Access Jacobian
 | 
					
						
							|  |  |  |       template<class A, size_t N> | 
					
						
							|  |  |  |       typename Jacobian<T, A>::type& jacobian() { | 
					
						
							|  |  |  |         return static_cast<JacobianTrace<T, A, N>&>(*this).dTdA; | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-13 16:10:46 +08:00
										 |  |  |     /// Construct an execution trace for reverse AD
 | 
					
						
							| 
									
										
										
										
											2014-10-13 17:37:47 +08:00
										 |  |  |     Record* trace(const Values& values, char* raw) const { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       // Create the record and advance the pointer
 | 
					
						
							| 
									
										
										
										
											2014-10-13 16:50:05 +08:00
										 |  |  |       Record* record = new (raw) Record(); | 
					
						
							| 
									
										
										
										
											2014-10-13 16:10:46 +08:00
										 |  |  |       raw = (char*) (record + 1); | 
					
						
							| 
									
										
										
										
											2014-10-13 05:31:48 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-13 17:37:47 +08:00
										 |  |  |       // Record the traces for all arguments
 | 
					
						
							|  |  |  |       // After this, the raw pointer is set to after what was written
 | 
					
						
							|  |  |  |       Base::trace(values, record, raw); | 
					
						
							| 
									
										
										
										
											2014-10-13 16:10:46 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-13 17:37:47 +08:00
										 |  |  |       // Return the record for this function evaluation
 | 
					
						
							|  |  |  |       return record; | 
					
						
							| 
									
										
										
										
											2014-10-13 16:10:46 +08:00
										 |  |  |     } | 
					
						
							|  |  |  |   }; | 
					
						
							|  |  |  | }; | 
					
						
							| 
									
										
										
										
											2014-10-03 18:40:26 +08:00
										 |  |  | //-----------------------------------------------------------------------------
 | 
					
						
							| 
									
										
										
										
											2014-10-13 05:31:48 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-03 18:40:26 +08:00
										 |  |  | /// Unary Function Expression
 | 
					
						
							| 
									
										
										
										
											2014-10-09 05:50:17 +08:00
										 |  |  | template<class T, class A1> | 
					
						
							| 
									
										
										
										
											2014-10-13 16:10:46 +08:00
										 |  |  | class UnaryExpression: public FunctionalNode<T, boost::mpl::vector<A1> >::type { | 
					
						
							| 
									
										
										
										
											2014-10-03 18:40:26 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-14 00:32:58 +08:00
										 |  |  |   typedef boost::function<T(const A1&, typename Optional<T, A1>::type)> Function; | 
					
						
							| 
									
										
										
										
											2014-10-13 16:50:05 +08:00
										 |  |  |   typedef typename FunctionalNode<T, boost::mpl::vector<A1> >::type Base; | 
					
						
							|  |  |  |   typedef typename Base::Record Record; | 
					
						
							| 
									
										
										
										
											2014-10-03 18:40:26 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | private: | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-03 18:52:35 +08:00
										 |  |  |   Function function_; | 
					
						
							| 
									
										
										
										
											2014-10-03 18:40:26 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   /// Constructor with a unary function f, and input argument e
 | 
					
						
							| 
									
										
										
										
											2014-10-13 00:52:12 +08:00
										 |  |  |   UnaryExpression(Function f, const Expression<A1>& e1) : | 
					
						
							| 
									
										
										
										
											2014-10-13 06:31:03 +08:00
										 |  |  |       function_(f) { | 
					
						
							| 
									
										
										
										
											2014-10-13 06:57:11 +08:00
										 |  |  |     this->template reset<A1, 1>(e1.root()); | 
					
						
							| 
									
										
										
										
											2014-10-13 05:31:48 +08:00
										 |  |  |     ExpressionNode<T>::traceSize_ = sizeof(Record) + e1.traceSize(); | 
					
						
							| 
									
										
										
										
											2014-10-03 18:40:26 +08:00
										 |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   friend class Expression<T> ; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  |   /// Return value
 | 
					
						
							|  |  |  |   virtual T value(const Values& values) const { | 
					
						
							| 
									
										
										
										
											2014-10-13 06:31:03 +08:00
										 |  |  |     return function_(this->template expression<A1, 1>()->value(values), boost::none); | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-05 19:33:23 +08:00
										 |  |  |   /// Construct an execution trace for reverse AD
 | 
					
						
							| 
									
										
										
										
											2014-10-11 17:03:35 +08:00
										 |  |  |   virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace, | 
					
						
							| 
									
										
										
										
											2014-10-11 19:07:58 +08:00
										 |  |  |       char* raw) const { | 
					
						
							| 
									
										
										
										
											2014-10-13 16:10:46 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-13 17:37:47 +08:00
										 |  |  |     Record* record = Base::trace(values, raw); | 
					
						
							|  |  |  |     trace.setFunction(record); | 
					
						
							| 
									
										
										
										
											2014-10-11 21:20:12 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-13 17:37:47 +08:00
										 |  |  |     return function_(record->template value<A1, 1>(), | 
					
						
							|  |  |  |         record->template jacobian<A1, 1>()); | 
					
						
							| 
									
										
										
										
											2014-10-05 19:33:23 +08:00
										 |  |  |   } | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | //-----------------------------------------------------------------------------
 | 
					
						
							| 
									
										
										
										
											2014-10-03 18:48:28 +08:00
										 |  |  | /// Binary Expression
 | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  | template<class T, class A1, class A2> | 
					
						
							| 
									
										
										
										
											2014-10-13 16:10:46 +08:00
										 |  |  | class BinaryExpression: public FunctionalNode<T, boost::mpl::vector<A1, A2> >::type { | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-09-30 19:19:44 +08:00
										 |  |  | public: | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-03 18:48:28 +08:00
										 |  |  |   typedef boost::function< | 
					
						
							| 
									
										
										
										
											2014-10-14 00:32:58 +08:00
										 |  |  |       T(const A1&, const A2&, typename Optional<T, A1>::type, | 
					
						
							|  |  |  |           typename Optional<T, A2>::type)> Function; | 
					
						
							| 
									
										
										
										
											2014-10-13 16:50:05 +08:00
										 |  |  |   typedef typename FunctionalNode<T, boost::mpl::vector<A1, A2> >::type Base; | 
					
						
							|  |  |  |   typedef typename Base::Record Record; | 
					
						
							| 
									
										
										
										
											2014-09-30 19:19:44 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | private: | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-03 18:52:35 +08:00
										 |  |  |   Function function_; | 
					
						
							| 
									
										
										
										
											2014-09-30 19:19:44 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-13 06:31:03 +08:00
										 |  |  |   /// Constructor with a ternary function f, and three input arguments
 | 
					
						
							|  |  |  |   BinaryExpression(Function f, const Expression<A1>& e1, | 
					
						
							|  |  |  |       const Expression<A2>& e2) : | 
					
						
							|  |  |  |       function_(f) { | 
					
						
							| 
									
										
										
										
											2014-10-13 06:57:11 +08:00
										 |  |  |     this->template reset<A1, 1>(e1.root()); | 
					
						
							|  |  |  |     this->template reset<A2, 2>(e2.root()); | 
					
						
							| 
									
										
										
										
											2014-10-13 05:31:48 +08:00
										 |  |  |     ExpressionNode<T>::traceSize_ = //
 | 
					
						
							|  |  |  |         sizeof(Record) + e1.traceSize() + e2.traceSize(); | 
					
						
							| 
									
										
										
										
											2014-09-30 19:19:44 +08:00
										 |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-11 14:41:39 +08:00
										 |  |  |   friend class Expression<T> ; | 
					
						
							| 
									
										
										
										
											2014-10-13 16:50:05 +08:00
										 |  |  |   friend class ::ExpressionFactorBinaryTest; | 
					
						
							| 
									
										
										
										
											2014-09-30 19:19:44 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  |   /// Return value
 | 
					
						
							|  |  |  |   virtual T value(const Values& values) const { | 
					
						
							|  |  |  |     using boost::none; | 
					
						
							| 
									
										
										
										
											2014-10-13 06:31:03 +08:00
										 |  |  |     return function_(this->template expression<A1, 1>()->value(values), | 
					
						
							| 
									
										
										
										
											2014-10-13 06:57:11 +08:00
										 |  |  |     this->template expression<A2, 2>()->value(values), | 
					
						
							|  |  |  |     none, none); | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-05 19:27:41 +08:00
										 |  |  |   /// Construct an execution trace for reverse AD
 | 
					
						
							| 
									
										
										
										
											2014-10-11 17:03:35 +08:00
										 |  |  |   virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace, | 
					
						
							| 
									
										
										
										
											2014-10-11 19:07:58 +08:00
										 |  |  |       char* raw) const { | 
					
						
							| 
									
										
										
										
											2014-10-13 16:10:46 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-13 17:37:47 +08:00
										 |  |  |     Record* record = Base::trace(values, raw); | 
					
						
							|  |  |  |     trace.setFunction(record); | 
					
						
							| 
									
										
										
										
											2014-10-11 21:20:12 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-13 17:37:47 +08:00
										 |  |  |     return function_(record->template value<A1, 1>(), | 
					
						
							|  |  |  |         record->template value<A2,2>(), record->template jacobian<A1, 1>(), | 
					
						
							| 
									
										
										
										
											2014-10-13 05:57:08 +08:00
										 |  |  |         record->template jacobian<A2, 2>()); | 
					
						
							| 
									
										
										
										
											2014-10-05 19:27:41 +08:00
										 |  |  |   } | 
					
						
							| 
									
										
										
										
											2014-09-30 19:19:44 +08:00
										 |  |  | }; | 
					
						
							| 
									
										
										
										
											2014-10-06 01:09:16 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-03 18:48:28 +08:00
										 |  |  | //-----------------------------------------------------------------------------
 | 
					
						
							| 
									
										
										
										
											2014-10-06 04:09:24 +08:00
										 |  |  | /// Ternary Expression
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template<class T, class A1, class A2, class A3> | 
					
						
							| 
									
										
										
										
											2014-10-13 16:10:46 +08:00
										 |  |  | class TernaryExpression: public FunctionalNode<T, boost::mpl::vector<A1, A2, A3> >::type { | 
					
						
							| 
									
										
										
										
											2014-10-06 04:09:24 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   typedef boost::function< | 
					
						
							| 
									
										
										
										
											2014-10-14 00:32:58 +08:00
										 |  |  |       T(const A1&, const A2&, const A3&, typename Optional<T, A1>::type, | 
					
						
							|  |  |  |           typename Optional<T, A2>::type, typename Optional<T, A3>::type)> Function; | 
					
						
							| 
									
										
										
										
											2014-10-13 16:50:05 +08:00
										 |  |  |   typedef typename FunctionalNode<T, boost::mpl::vector<A1, A2, A3> >::type Base; | 
					
						
							|  |  |  |   typedef typename Base::Record Record; | 
					
						
							| 
									
										
										
										
											2014-10-06 04:09:24 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | private: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   Function function_; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Constructor with a ternary function f, and three input arguments
 | 
					
						
							| 
									
										
										
										
											2014-10-13 05:31:48 +08:00
										 |  |  |   TernaryExpression(Function f, const Expression<A1>& e1, | 
					
						
							|  |  |  |       const Expression<A2>& e2, const Expression<A3>& e3) : | 
					
						
							| 
									
										
										
										
											2014-10-13 06:31:03 +08:00
										 |  |  |       function_(f) { | 
					
						
							| 
									
										
										
										
											2014-10-13 06:57:11 +08:00
										 |  |  |     this->template reset<A1, 1>(e1.root()); | 
					
						
							|  |  |  |     this->template reset<A2, 2>(e2.root()); | 
					
						
							|  |  |  |     this->template reset<A3, 3>(e3.root()); | 
					
						
							| 
									
										
										
										
											2014-10-13 05:31:48 +08:00
										 |  |  |     ExpressionNode<T>::traceSize_ = //
 | 
					
						
							|  |  |  |         sizeof(Record) + e1.traceSize() + e2.traceSize() + e3.traceSize(); | 
					
						
							| 
									
										
										
										
											2014-10-06 04:09:24 +08:00
										 |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   friend class Expression<T> ; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Return value
 | 
					
						
							|  |  |  |   virtual T value(const Values& values) const { | 
					
						
							|  |  |  |     using boost::none; | 
					
						
							| 
									
										
										
										
											2014-10-13 06:31:03 +08:00
										 |  |  |     return function_(this->template expression<A1, 1>()->value(values), | 
					
						
							| 
									
										
										
										
											2014-10-13 06:57:11 +08:00
										 |  |  |     this->template expression<A2, 2>()->value(values), | 
					
						
							|  |  |  |     this->template expression<A3, 3>()->value(values), | 
					
						
							|  |  |  |     none, none, none); | 
					
						
							| 
									
										
										
										
											2014-10-06 04:09:24 +08:00
										 |  |  |   } | 
					
						
							| 
									
										
										
										
											2014-09-30 19:19:44 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-06 05:40:11 +08:00
										 |  |  |   /// Construct an execution trace for reverse AD
 | 
					
						
							| 
									
										
										
										
											2014-10-11 17:03:35 +08:00
										 |  |  |   virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace, | 
					
						
							| 
									
										
										
										
											2014-10-11 19:07:58 +08:00
										 |  |  |       char* raw) const { | 
					
						
							| 
									
										
										
										
											2014-10-13 16:10:46 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-13 17:37:47 +08:00
										 |  |  |     Record* record = Base::trace(values, raw); | 
					
						
							|  |  |  |     trace.setFunction(record); | 
					
						
							| 
									
										
										
										
											2014-10-11 21:20:12 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-13 17:37:47 +08:00
										 |  |  |     return function_( | 
					
						
							|  |  |  |         record->template value<A1, 1>(), record->template value<A2, 2>(), | 
					
						
							|  |  |  |         record->template value<A3, 3>(), record->template jacobian<A1, 1>(), | 
					
						
							| 
									
										
										
										
											2014-10-13 05:57:08 +08:00
										 |  |  |         record->template jacobian<A2, 2>(), record->template jacobian<A3, 3>()); | 
					
						
							| 
									
										
										
										
											2014-10-06 05:40:11 +08:00
										 |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-06 04:09:24 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | //-----------------------------------------------------------------------------
 | 
					
						
							| 
									
										
										
										
											2014-10-14 21:43:41 +08:00
										 |  |  | } | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  | 
 |