| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  | /* ----------------------------------------------------------------------------
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |  * GTSAM Copyright 2010, Georgia Tech Research Corporation,  | 
					
						
							|  |  |  |  * Atlanta, Georgia 30332-0415 | 
					
						
							|  |  |  |  * All Rights Reserved | 
					
						
							|  |  |  |  * Authors: Frank Dellaert, et al. (see THANKS for the full author list) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |  * See LICENSE for the license information | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |  * -------------------------------------------------------------------------- */ | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /**
 | 
					
						
							|  |  |  |  * @file Expression-inl.h | 
					
						
							|  |  |  |  * @date September 18, 2014 | 
					
						
							|  |  |  |  * @author Frank Dellaert | 
					
						
							|  |  |  |  * @author Paul Furgale | 
					
						
							|  |  |  |  * @brief Internals for Expression.h, not for general consumption | 
					
						
							|  |  |  |  */ | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-09-30 19:00:37 +08:00
										 |  |  | #pragma once
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-09-30 18:29:57 +08:00
										 |  |  | #include <gtsam/nonlinear/Values.h>
 | 
					
						
							| 
									
										
										
										
											2014-10-01 17:25:49 +08:00
										 |  |  | #include <gtsam/base/Matrix.h>
 | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  | #include <boost/foreach.hpp>
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | namespace gtsam { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template<typename T> | 
					
						
							|  |  |  | class Expression; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-05 23:12:38 +08:00
										 |  |  | class JacobianMap: public std::map<Key, Matrix> { | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  |   void add(Key key, const Matrix& H) { | 
					
						
							|  |  |  |     iterator it = find(key); | 
					
						
							|  |  |  |     if (it != end()) | 
					
						
							|  |  |  |       it->second += H; | 
					
						
							|  |  |  |     else | 
					
						
							|  |  |  |       insert(std::make_pair(key, H)); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | }; | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | //-----------------------------------------------------------------------------
 | 
					
						
							|  |  |  | /**
 | 
					
						
							|  |  |  |  * Value and Jacobians | 
					
						
							|  |  |  |  */ | 
					
						
							|  |  |  | template<class T> | 
					
						
							|  |  |  | class Augmented { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | private: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   T value_; | 
					
						
							|  |  |  |   JacobianMap jacobians_; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   typedef std::pair<Key, Matrix> Pair; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-05 21:00:10 +08:00
										 |  |  |   /// Insert terms into jacobians_, premultiplying by H, adding if already exists
 | 
					
						
							|  |  |  |   void add(const JacobianMap& terms) { | 
					
						
							| 
									
										
										
										
											2014-10-05 23:12:38 +08:00
										 |  |  |     BOOST_FOREACH(const Pair& term, terms) | 
					
						
							|  |  |  |       jacobians_.add(term.first, term.second); | 
					
						
							| 
									
										
										
										
											2014-10-05 21:00:10 +08:00
										 |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  |   /// Insert terms into jacobians_, premultiplying by H, adding if already exists
 | 
					
						
							|  |  |  |   void add(const Matrix& H, const JacobianMap& terms) { | 
					
						
							| 
									
										
										
										
											2014-10-05 23:12:38 +08:00
										 |  |  |     BOOST_FOREACH(const Pair& term, terms) | 
					
						
							|  |  |  |       jacobians_.add(term.first, H * term.second); | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Construct value that does not depend on anything
 | 
					
						
							|  |  |  |   Augmented(const T& t) : | 
					
						
							|  |  |  |       value_(t) { | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Construct value dependent on a single key
 | 
					
						
							|  |  |  |   Augmented(const T& t, Key key) : | 
					
						
							|  |  |  |       value_(t) { | 
					
						
							|  |  |  |     size_t n = t.dim(); | 
					
						
							|  |  |  |     jacobians_[key] = Eigen::MatrixXd::Identity(n, n); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-05 19:37:51 +08:00
										 |  |  |   /// Construct value dependent on a single key, with Jacobain H
 | 
					
						
							|  |  |  |   Augmented(const T& t, Key key, const Matrix& H) : | 
					
						
							|  |  |  |       value_(t) { | 
					
						
							|  |  |  |     jacobians_[key] = H; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  |   /// Construct value, pre-multiply jacobians by H
 | 
					
						
							|  |  |  |   Augmented(const T& t, const Matrix& H, const JacobianMap& jacobians) : | 
					
						
							|  |  |  |       value_(t) { | 
					
						
							|  |  |  |     add(H, jacobians); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Construct value, pre-multiply jacobians by H
 | 
					
						
							|  |  |  |   Augmented(const T& t, const Matrix& H1, const JacobianMap& jacobians1, | 
					
						
							|  |  |  |       const Matrix& H2, const JacobianMap& jacobians2) : | 
					
						
							|  |  |  |       value_(t) { | 
					
						
							|  |  |  |     add(H1, jacobians1); | 
					
						
							|  |  |  |     add(H2, jacobians2); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Return value
 | 
					
						
							|  |  |  |   const T& value() const { | 
					
						
							|  |  |  |     return value_; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Return jacobians
 | 
					
						
							|  |  |  |   const JacobianMap& jacobians() const { | 
					
						
							|  |  |  |     return jacobians_; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-05 23:12:38 +08:00
										 |  |  |   /// Return jacobians
 | 
					
						
							|  |  |  |   JacobianMap& jacobians() { | 
					
						
							|  |  |  |     return jacobians_; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  |   /// Not dependent on any key
 | 
					
						
							|  |  |  |   bool constant() const { | 
					
						
							|  |  |  |     return jacobians_.empty(); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// debugging
 | 
					
						
							|  |  |  |   void print(const KeyFormatter& keyFormatter = DefaultKeyFormatter) { | 
					
						
							|  |  |  |     BOOST_FOREACH(const Pair& term, jacobians_) | 
					
						
							|  |  |  |       std::cout << "(" << keyFormatter(term.first) << ", " << term.second.rows() | 
					
						
							|  |  |  |           << "x" << term.second.cols() << ") "; | 
					
						
							|  |  |  |     std::cout << std::endl; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-05 23:12:38 +08:00
										 |  |  | //-----------------------------------------------------------------------------
 | 
					
						
							|  |  |  | template<class T> | 
					
						
							|  |  |  | struct JacobianTrace { | 
					
						
							|  |  |  |   T t; | 
					
						
							|  |  |  |   T value() const { | 
					
						
							|  |  |  |     return t; | 
					
						
							|  |  |  |   } | 
					
						
							| 
									
										
										
										
											2014-10-06 01:27:52 +08:00
										 |  |  |   virtual void reverseAD(JacobianMap& jacobians) const = 0; | 
					
						
							| 
									
										
										
										
											2014-10-05 23:20:55 +08:00
										 |  |  |   virtual void reverseAD(const Matrix& H, JacobianMap& jacobians) const = 0; | 
					
						
							| 
									
										
										
										
											2014-10-05 23:12:38 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  | //-----------------------------------------------------------------------------
 | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  | /**
 | 
					
						
							|  |  |  |  * Expression node. The superclass for objects that do the heavy lifting | 
					
						
							|  |  |  |  * An Expression<T> has a pointer to an ExpressionNode<T> underneath | 
					
						
							|  |  |  |  * allowing Expressions to have polymorphic behaviour even though they | 
					
						
							|  |  |  |  * are passed by value. This is the same way boost::function works. | 
					
						
							|  |  |  |  * http://loki-lib.sourceforge.net/html/a00652.html
 | 
					
						
							|  |  |  |  */ | 
					
						
							|  |  |  | template<class T> | 
					
						
							|  |  |  | class ExpressionNode { | 
					
						
							| 
									
										
										
										
											2014-09-30 19:00:37 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  | protected: | 
					
						
							|  |  |  |   ExpressionNode() { | 
					
						
							|  |  |  |   } | 
					
						
							| 
									
										
										
										
											2014-09-30 19:00:37 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  | public: | 
					
						
							| 
									
										
										
										
											2014-09-30 19:00:37 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   /// Destructor
 | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  |   virtual ~ExpressionNode() { | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Return keys that play in this expression as a set
 | 
					
						
							|  |  |  |   virtual std::set<Key> keys() const = 0; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  |   /// Return value
 | 
					
						
							|  |  |  |   virtual T value(const Values& values) const = 0; | 
					
						
							| 
									
										
										
										
											2014-10-03 05:28:19 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  |   /// Return value and derivatives
 | 
					
						
							| 
									
										
										
										
											2014-10-04 03:13:34 +08:00
										 |  |  |   virtual Augmented<T> forward(const Values& values) const = 0; | 
					
						
							| 
									
										
										
										
											2014-10-03 05:28:19 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-05 17:22:14 +08:00
										 |  |  |   /// Construct an execution trace for reverse AD
 | 
					
						
							| 
									
										
										
										
											2014-10-05 23:20:55 +08:00
										 |  |  |   virtual boost::shared_ptr<JacobianTrace<T> > traceExecution( | 
					
						
							| 
									
										
										
										
											2014-10-05 23:12:38 +08:00
										 |  |  |       const Values& values) const = 0; | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  | //-----------------------------------------------------------------------------
 | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  | /// Constant Expression
 | 
					
						
							|  |  |  | template<class T> | 
					
						
							|  |  |  | class ConstantExpression: public ExpressionNode<T> { | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  |   /// The constant value
 | 
					
						
							|  |  |  |   T constant_; | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   /// Constructor with a value, yielding a constant
 | 
					
						
							|  |  |  |   ConstantExpression(const T& value) : | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  |       constant_(value) { | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   friend class Expression<T> ; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-09-30 19:00:37 +08:00
										 |  |  |   /// Destructor
 | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  |   virtual ~ConstantExpression() { | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Return keys that play in this expression, i.e., the empty set
 | 
					
						
							|  |  |  |   virtual std::set<Key> keys() const { | 
					
						
							|  |  |  |     std::set<Key> keys; | 
					
						
							|  |  |  |     return keys; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  |   /// Return value
 | 
					
						
							|  |  |  |   virtual T value(const Values& values) const { | 
					
						
							|  |  |  |     return constant_; | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  |   } | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   /// Return value and derivatives
 | 
					
						
							| 
									
										
										
										
											2014-10-04 03:13:34 +08:00
										 |  |  |   virtual Augmented<T> forward(const Values& values) const { | 
					
						
							| 
									
										
										
										
											2014-10-05 17:22:14 +08:00
										 |  |  |     return Augmented<T>(constant_); | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-05 19:41:20 +08:00
										 |  |  |   /// Trace structure for reverse AD
 | 
					
						
							| 
									
										
										
										
											2014-10-05 23:12:38 +08:00
										 |  |  |   struct Trace: public JacobianTrace<T> { | 
					
						
							| 
									
										
										
										
											2014-10-06 01:27:52 +08:00
										 |  |  |     /// If the expression is just a constant, we do nothing
 | 
					
						
							|  |  |  |     virtual void reverseAD(JacobianMap& jacobians) const { | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     /// Base case: we simply ignore the given df/dT
 | 
					
						
							| 
									
										
										
										
											2014-10-05 23:20:55 +08:00
										 |  |  |     virtual void reverseAD(const Matrix& H, JacobianMap& jacobians) const { | 
					
						
							| 
									
										
										
										
											2014-10-05 19:41:20 +08:00
										 |  |  |     } | 
					
						
							|  |  |  |   }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-05 17:22:14 +08:00
										 |  |  |   /// Construct an execution trace for reverse AD
 | 
					
						
							| 
									
										
										
										
											2014-10-05 23:20:55 +08:00
										 |  |  |   virtual boost::shared_ptr<JacobianTrace<T> > traceExecution( | 
					
						
							| 
									
										
										
										
											2014-10-05 23:12:38 +08:00
										 |  |  |       const Values& values) const { | 
					
						
							| 
									
										
										
										
											2014-10-05 19:41:20 +08:00
										 |  |  |     boost::shared_ptr<Trace> trace = boost::make_shared<Trace>(); | 
					
						
							|  |  |  |     trace->t = constant_; | 
					
						
							|  |  |  |     return trace; | 
					
						
							| 
									
										
										
										
											2014-10-05 17:22:14 +08:00
										 |  |  |   } | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | //-----------------------------------------------------------------------------
 | 
					
						
							|  |  |  | /// Leaf Expression
 | 
					
						
							|  |  |  | template<class T> | 
					
						
							|  |  |  | class LeafExpression: public ExpressionNode<T> { | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  |   /// The key into values
 | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  |   Key key_; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Constructor with a single key
 | 
					
						
							|  |  |  |   LeafExpression(Key key) : | 
					
						
							|  |  |  |       key_(key) { | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   friend class Expression<T> ; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-09-30 19:00:37 +08:00
										 |  |  |   /// Destructor
 | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  |   virtual ~LeafExpression() { | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Return keys that play in this expression
 | 
					
						
							|  |  |  |   virtual std::set<Key> keys() const { | 
					
						
							|  |  |  |     std::set<Key> keys; | 
					
						
							|  |  |  |     keys.insert(key_); | 
					
						
							|  |  |  |     return keys; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  |   /// Return value
 | 
					
						
							|  |  |  |   virtual T value(const Values& values) const { | 
					
						
							|  |  |  |     return values.at<T>(key_); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Return value and derivatives
 | 
					
						
							| 
									
										
										
										
											2014-10-04 03:13:34 +08:00
										 |  |  |   virtual Augmented<T> forward(const Values& values) const { | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  |     T t = value(values); | 
					
						
							|  |  |  |     return Augmented<T>(t, key_); | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-05 19:37:51 +08:00
										 |  |  |   /// Trace structure for reverse AD
 | 
					
						
							| 
									
										
										
										
											2014-10-05 23:12:38 +08:00
										 |  |  |   struct Trace: public JacobianTrace<T> { | 
					
						
							| 
									
										
										
										
											2014-10-05 19:37:51 +08:00
										 |  |  |     Key key; | 
					
						
							| 
									
										
										
										
											2014-10-06 01:27:52 +08:00
										 |  |  |     /// If the expression is just a leaf, we just insert an identity matrix
 | 
					
						
							|  |  |  |     virtual void reverseAD(JacobianMap& jacobians) const { | 
					
						
							|  |  |  |       size_t n = T::Dim(); | 
					
						
							|  |  |  |       jacobians.add(key, Eigen::MatrixXd::Identity(n, n)); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     /// Base case: given df/dT, add it jacobians with correct key and we are done
 | 
					
						
							| 
									
										
										
										
											2014-10-05 23:20:55 +08:00
										 |  |  |     virtual void reverseAD(const Matrix& H, JacobianMap& jacobians) const { | 
					
						
							| 
									
										
										
										
											2014-10-05 23:12:38 +08:00
										 |  |  |       jacobians.add(key, H); | 
					
						
							| 
									
										
										
										
											2014-10-05 19:37:51 +08:00
										 |  |  |     } | 
					
						
							|  |  |  |   }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Construct an execution trace for reverse AD
 | 
					
						
							| 
									
										
										
										
											2014-10-05 23:20:55 +08:00
										 |  |  |   virtual boost::shared_ptr<JacobianTrace<T> > traceExecution( | 
					
						
							| 
									
										
										
										
											2014-10-05 23:12:38 +08:00
										 |  |  |       const Values& values) const { | 
					
						
							| 
									
										
										
										
											2014-10-05 19:37:51 +08:00
										 |  |  |     boost::shared_ptr<Trace> trace = boost::make_shared<Trace>(); | 
					
						
							|  |  |  |     trace->t = value(values); | 
					
						
							|  |  |  |     trace->key = key_; | 
					
						
							|  |  |  |     return trace; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-03 18:40:26 +08:00
										 |  |  | //-----------------------------------------------------------------------------
 | 
					
						
							|  |  |  | /// Unary Function Expression
 | 
					
						
							|  |  |  | template<class T, class A> | 
					
						
							| 
									
										
										
										
											2014-10-03 19:18:25 +08:00
										 |  |  | class UnaryExpression: public ExpressionNode<T> { | 
					
						
							| 
									
										
										
										
											2014-10-03 18:40:26 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-03 18:52:35 +08:00
										 |  |  |   typedef boost::function<T(const A&, boost::optional<Matrix&>)> Function; | 
					
						
							| 
									
										
										
										
											2014-10-03 18:40:26 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | private: | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-03 18:52:35 +08:00
										 |  |  |   Function function_; | 
					
						
							| 
									
										
										
										
											2014-10-03 19:18:25 +08:00
										 |  |  |   boost::shared_ptr<ExpressionNode<A> > expressionA_; | 
					
						
							| 
									
										
										
										
											2014-10-03 18:40:26 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   /// Constructor with a unary function f, and input argument e
 | 
					
						
							| 
									
										
										
										
											2014-10-03 19:18:25 +08:00
										 |  |  |   UnaryExpression(Function f, const Expression<A>& e) : | 
					
						
							|  |  |  |       function_(f), expressionA_(e.root()) { | 
					
						
							| 
									
										
										
										
											2014-10-03 18:40:26 +08:00
										 |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   friend class Expression<T> ; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Destructor
 | 
					
						
							| 
									
										
										
										
											2014-10-03 19:18:25 +08:00
										 |  |  |   virtual ~UnaryExpression() { | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Return keys that play in this expression
 | 
					
						
							|  |  |  |   virtual std::set<Key> keys() const { | 
					
						
							|  |  |  |     return expressionA_->keys(); | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Return value
 | 
					
						
							|  |  |  |   virtual T value(const Values& values) const { | 
					
						
							| 
									
										
										
										
											2014-10-03 18:40:26 +08:00
										 |  |  |     return function_(this->expressionA_->value(values), boost::none); | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Return value and derivatives
 | 
					
						
							| 
									
										
										
										
											2014-10-04 03:13:34 +08:00
										 |  |  |   virtual Augmented<T> forward(const Values& values) const { | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  |     using boost::none; | 
					
						
							| 
									
										
										
										
											2014-10-04 03:13:34 +08:00
										 |  |  |     Augmented<A> argument = this->expressionA_->forward(values); | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  |     Matrix H; | 
					
						
							| 
									
										
										
										
											2014-10-03 18:40:26 +08:00
										 |  |  |     T t = function_(argument.value(), | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  |         argument.constant() ? none : boost::optional<Matrix&>(H)); | 
					
						
							|  |  |  |     return Augmented<T>(t, H, argument.jacobians()); | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-05 19:33:23 +08:00
										 |  |  |   /// Trace structure for reverse AD
 | 
					
						
							| 
									
										
										
										
											2014-10-05 23:12:38 +08:00
										 |  |  |   struct Trace: public JacobianTrace<T> { | 
					
						
							|  |  |  |     boost::shared_ptr<JacobianTrace<A> > trace1; | 
					
						
							| 
									
										
										
										
											2014-10-05 19:33:23 +08:00
										 |  |  |     Matrix H1; | 
					
						
							| 
									
										
										
										
											2014-10-06 01:27:52 +08:00
										 |  |  |     /// Start the reverse AD process
 | 
					
						
							|  |  |  |     virtual void reverseAD(JacobianMap& jacobians) const { | 
					
						
							|  |  |  |       trace1->reverseAD(H1, jacobians); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     /// Given df/dT, multiply in dT/dA and continue reverse AD process
 | 
					
						
							| 
									
										
										
										
											2014-10-05 23:20:55 +08:00
										 |  |  |     virtual void reverseAD(const Matrix& H, JacobianMap& jacobians) const { | 
					
						
							|  |  |  |       trace1->reverseAD(H * H1, jacobians); | 
					
						
							| 
									
										
										
										
											2014-10-05 19:33:23 +08:00
										 |  |  |     } | 
					
						
							|  |  |  |   }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Construct an execution trace for reverse AD
 | 
					
						
							| 
									
										
										
										
											2014-10-05 23:20:55 +08:00
										 |  |  |   virtual boost::shared_ptr<JacobianTrace<T> > traceExecution( | 
					
						
							| 
									
										
										
										
											2014-10-05 23:12:38 +08:00
										 |  |  |       const Values& values) const { | 
					
						
							| 
									
										
										
										
											2014-10-05 19:33:23 +08:00
										 |  |  |     boost::shared_ptr<Trace> trace = boost::make_shared<Trace>(); | 
					
						
							| 
									
										
										
										
											2014-10-05 23:20:55 +08:00
										 |  |  |     trace->trace1 = this->expressionA_->traceExecution(values); | 
					
						
							| 
									
										
										
										
											2014-10-05 19:33:23 +08:00
										 |  |  |     trace->t = function_(trace->trace1->value(), trace->H1); | 
					
						
							|  |  |  |     return trace; | 
					
						
							|  |  |  |   } | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | //-----------------------------------------------------------------------------
 | 
					
						
							| 
									
										
										
										
											2014-10-03 18:48:28 +08:00
										 |  |  | /// Binary Expression
 | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  | template<class T, class A1, class A2> | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  | class BinaryExpression: public ExpressionNode<T> { | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-09-30 19:19:44 +08:00
										 |  |  | public: | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-03 18:48:28 +08:00
										 |  |  |   typedef boost::function< | 
					
						
							|  |  |  |       T(const A1&, const A2&, boost::optional<Matrix&>, | 
					
						
							| 
									
										
										
										
											2014-10-03 18:52:35 +08:00
										 |  |  |           boost::optional<Matrix&>)> Function; | 
					
						
							| 
									
										
										
										
											2014-09-30 19:19:44 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | private: | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-03 18:52:35 +08:00
										 |  |  |   Function function_; | 
					
						
							| 
									
										
										
										
											2014-10-03 19:18:25 +08:00
										 |  |  |   boost::shared_ptr<ExpressionNode<A1> > expressionA1_; | 
					
						
							|  |  |  |   boost::shared_ptr<ExpressionNode<A2> > expressionA2_; | 
					
						
							| 
									
										
										
										
											2014-09-30 19:19:44 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   /// Constructor with a binary function f, and two input arguments
 | 
					
						
							| 
									
										
										
										
											2014-10-03 19:18:25 +08:00
										 |  |  |   BinaryExpression(Function f, //
 | 
					
						
							| 
									
										
										
										
											2014-10-03 18:48:28 +08:00
										 |  |  |       const Expression<A1>& e1, const Expression<A2>& e2) : | 
					
						
							| 
									
										
										
										
											2014-10-03 19:18:25 +08:00
										 |  |  |       function_(f), expressionA1_(e1.root()), expressionA2_(e2.root()) { | 
					
						
							| 
									
										
										
										
											2014-09-30 19:19:44 +08:00
										 |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   friend class Expression<T> ; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Destructor
 | 
					
						
							| 
									
										
										
										
											2014-10-03 19:18:25 +08:00
										 |  |  |   virtual ~BinaryExpression() { | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Return keys that play in this expression
 | 
					
						
							|  |  |  |   virtual std::set<Key> keys() const { | 
					
						
							|  |  |  |     std::set<Key> keys1 = expressionA1_->keys(); | 
					
						
							|  |  |  |     std::set<Key> keys2 = expressionA2_->keys(); | 
					
						
							|  |  |  |     keys1.insert(keys2.begin(), keys2.end()); | 
					
						
							|  |  |  |     return keys1; | 
					
						
							| 
									
										
										
										
											2014-09-30 19:19:44 +08:00
										 |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  |   /// Return value
 | 
					
						
							|  |  |  |   virtual T value(const Values& values) const { | 
					
						
							|  |  |  |     using boost::none; | 
					
						
							| 
									
										
										
										
											2014-10-03 18:48:28 +08:00
										 |  |  |     return function_(this->expressionA1_->value(values), | 
					
						
							|  |  |  |         this->expressionA2_->value(values), none, none); | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Return value and derivatives
 | 
					
						
							| 
									
										
										
										
											2014-10-04 03:13:34 +08:00
										 |  |  |   virtual Augmented<T> forward(const Values& values) const { | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  |     using boost::none; | 
					
						
							| 
									
										
										
										
											2014-10-04 03:13:34 +08:00
										 |  |  |     Augmented<A1> argument1 = this->expressionA1_->forward(values); | 
					
						
							|  |  |  |     Augmented<A2> argument2 = this->expressionA2_->forward(values); | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  |     Matrix H1, H2; | 
					
						
							| 
									
										
										
										
											2014-10-03 18:48:28 +08:00
										 |  |  |     T t = function_(argument1.value(), argument2.value(), | 
					
						
							| 
									
										
										
										
											2014-10-03 16:25:02 +08:00
										 |  |  |         argument1.constant() ? none : boost::optional<Matrix&>(H1), | 
					
						
							|  |  |  |         argument2.constant() ? none : boost::optional<Matrix&>(H2)); | 
					
						
							|  |  |  |     return Augmented<T>(t, H1, argument1.jacobians(), H2, argument2.jacobians()); | 
					
						
							| 
									
										
										
										
											2014-09-30 19:19:44 +08:00
										 |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-05 19:27:41 +08:00
										 |  |  |   /// Trace structure for reverse AD
 | 
					
						
							| 
									
										
										
										
											2014-10-05 23:12:38 +08:00
										 |  |  |   struct Trace: public JacobianTrace<T> { | 
					
						
							|  |  |  |     boost::shared_ptr<JacobianTrace<A1> > trace1; | 
					
						
							|  |  |  |     boost::shared_ptr<JacobianTrace<A2> > trace2; | 
					
						
							| 
									
										
										
										
											2014-10-05 19:27:41 +08:00
										 |  |  |     Matrix H1, H2; | 
					
						
							| 
									
										
										
										
											2014-10-06 01:27:52 +08:00
										 |  |  |     /// Start the reverse AD process
 | 
					
						
							|  |  |  |     virtual void reverseAD(JacobianMap& jacobians) const { | 
					
						
							|  |  |  |       trace1->reverseAD(H1, jacobians); | 
					
						
							|  |  |  |       trace2->reverseAD(H2, jacobians); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     /// Given df/dT, multiply in dT/dA and continue reverse AD process
 | 
					
						
							| 
									
										
										
										
											2014-10-05 23:20:55 +08:00
										 |  |  |     virtual void reverseAD(const Matrix& H, JacobianMap& jacobians) const { | 
					
						
							|  |  |  |       trace1->reverseAD(H * H1, jacobians); | 
					
						
							|  |  |  |       trace2->reverseAD(H * H2, jacobians); | 
					
						
							| 
									
										
										
										
											2014-10-05 19:27:41 +08:00
										 |  |  |     } | 
					
						
							|  |  |  |   }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Construct an execution trace for reverse AD
 | 
					
						
							| 
									
										
										
										
											2014-10-05 23:20:55 +08:00
										 |  |  |   virtual boost::shared_ptr<JacobianTrace<T> > traceExecution( | 
					
						
							| 
									
										
										
										
											2014-10-05 23:12:38 +08:00
										 |  |  |       const Values& values) const { | 
					
						
							| 
									
										
										
										
											2014-10-05 19:27:41 +08:00
										 |  |  |     boost::shared_ptr<Trace> trace = boost::make_shared<Trace>(); | 
					
						
							| 
									
										
										
										
											2014-10-05 23:20:55 +08:00
										 |  |  |     trace->trace1 = this->expressionA1_->traceExecution(values); | 
					
						
							|  |  |  |     trace->trace2 = this->expressionA2_->traceExecution(values); | 
					
						
							| 
									
										
										
										
											2014-10-05 19:27:41 +08:00
										 |  |  |     trace->t = function_(trace->trace1->value(), trace->trace2->value(), | 
					
						
							|  |  |  |         trace->H1, trace->H2); | 
					
						
							|  |  |  |     return trace; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-06 01:09:16 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-03 18:48:28 +08:00
										 |  |  | //-----------------------------------------------------------------------------
 | 
					
						
							| 
									
										
										
										
											2014-09-30 19:19:44 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-09-30 18:20:02 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 |