diff --git a/gtsam/nonlinear/internal/ExpressionNode.h b/gtsam/nonlinear/internal/ExpressionNode.h index e72445716..65b521d6c 100644 --- a/gtsam/nonlinear/internal/ExpressionNode.h +++ b/gtsam/nonlinear/internal/ExpressionNode.h @@ -648,56 +648,108 @@ public: /// Ternary Expression template -class TernaryExpression: public FunctionalNode >::type { - - typedef typename FunctionalNode >::type Base; - typedef typename Base::Record Record; - -private: +class TernaryExpression: public ExpressionNode { typedef typename Expression::template TernaryFunction::type Function; + boost::shared_ptr > expression1_; + boost::shared_ptr > expression2_; + boost::shared_ptr > expression3_; Function function_; - /// Constructor with a ternary function f, and three input arguments +public: + + /// Constructor with a ternary function f, and two input arguments TernaryExpression(Function f, const Expression& e1, const Expression& e2, const Expression& e3) : + expression1_(e1.root()), expression2_(e2.root()), expression3_(e3.root()), // function_(f) { - this->template reset(e1.root()); - this->template reset(e2.root()); - this->template reset(e3.root()); - ExpressionNode::traceSize_ = // - upAligned(sizeof(Record)) + e1.traceSize() + e2.traceSize() - + e3.traceSize(); + ExpressionNode::traceSize_ = upAligned(sizeof(Record)) + // + e1.traceSize() + e2.traceSize() + e3.traceSize(); } - friend class Expression ; - -public: - /// Return value virtual T value(const Values& values) const { using boost::none; - return function_(this->template expression()->value(values), - this->template expression()->value(values), - this->template expression()->value(values), - none, none, none); + return function_(expression1_->value(values), expression2_->value(values), + expression3_->value(values), none, none, none); } - /// Construct an execution trace for reverse AD + /// Return keys that play in this expression + virtual std::set keys() const { + std::set keys = expression1_->keys(); + std::set myKeys = expression2_->keys(); + keys.insert(myKeys.begin(), myKeys.end()); + myKeys = expression3_->keys(); + keys.insert(myKeys.begin(), myKeys.end()); + return keys; + } + + /// Return dimensions for each argument + virtual void dims(std::map& map) const { + expression1_->dims(map); + expression2_->dims(map); + expression3_->dims(map); + } + + // Inner Record Class + struct Record: public CallRecordImplementor::dimension> { + + A1 value1; + ExecutionTrace trace1; + typename Jacobian::type dTdA1; + + A2 value2; + ExecutionTrace trace2; + typename Jacobian::type dTdA2; + + A3 value3; + ExecutionTrace trace3; + typename Jacobian::type dTdA3; + + /// Print to std::cout + void print(const std::string& indent) const { + std::cout << indent << "TernaryExpression::Record {" << std::endl; + std::cout << indent << dTdA1.format(kMatlabFormat) << std::endl; + trace1.print(indent); + std::cout << indent << dTdA2.format(kMatlabFormat) << std::endl; + trace2.print(indent); + std::cout << indent << dTdA3.format(kMatlabFormat) << std::endl; + trace3.print(indent); + std::cout << indent << "}" << std::endl; + } + + /// Start the reverse AD process, see comments in Base + void startReverseAD4(JacobianMap& jacobians) const { + trace1.reverseAD1(dTdA1, jacobians); + trace2.reverseAD1(dTdA2, jacobians); + trace3.reverseAD1(dTdA3, jacobians); + } + + /// Given df/dT, multiply in dT/dA and continue reverse AD process + template + void reverseAD4(const SomeMatrix & dFdT, JacobianMap& jacobians) const { + trace1.reverseAD1(dFdT * dTdA1, jacobians); + trace2.reverseAD1(dFdT * dTdA2, jacobians); + trace3.reverseAD1(dFdT * dTdA3, jacobians); + } + }; + + /// Construct an execution trace for reverse AD, see UnaryExpression for explanation virtual T traceExecution(const Values& values, ExecutionTrace& trace, - ExecutionTraceStorage* traceStorage) const { - - Record* record = Base::trace(values, traceStorage); + ExecutionTraceStorage* ptr) const { + assert(reinterpret_cast(ptr) % TraceAlignment == 0); + Record* record = new (ptr) Record(); + ptr += upAligned(sizeof(Record)); + record->value1 = expression1_->traceExecution(values, record->trace1, ptr); + record->value2 = expression2_->traceExecution(values, record->trace2, ptr); + record->value3 = expression3_->traceExecution(values, record->trace3, ptr); + ptr += expression1_->traceSize() + expression2_->traceSize() + + expression3_->traceSize(); trace.setFunction(record); - - return function_( - record->template value(), record->template value(), - record->template value(), record->template jacobian(), - record->template jacobian(), record->template jacobian()); + return function_(record->value1, record->value2, record->value3, + record->dTdA1, record->dTdA2, record->dTdA3); } - }; -} - // namespace internal -}// namespace gtsam +} // namespace internal +} // namespace gtsam