TernaryExpression now without MPL
parent
e1b3a11957
commit
4a8dbd689a
|
@ -648,56 +648,108 @@ public:
|
|||
/// Ternary Expression
|
||||
|
||||
template<class T, class A1, class A2, class A3>
|
||||
class TernaryExpression: public FunctionalNode<T, boost::mpl::vector<A1, A2, A3> >::type {
|
||||
|
||||
typedef typename FunctionalNode<T, boost::mpl::vector<A1, A2, A3> >::type Base;
|
||||
typedef typename Base::Record Record;
|
||||
|
||||
private:
|
||||
class TernaryExpression: public ExpressionNode<T> {
|
||||
|
||||
typedef typename Expression<T>::template TernaryFunction<A1, A2, A3>::type Function;
|
||||
boost::shared_ptr<ExpressionNode<A1> > expression1_;
|
||||
boost::shared_ptr<ExpressionNode<A2> > expression2_;
|
||||
boost::shared_ptr<ExpressionNode<A3> > expression3_;
|
||||
Function function_;
|
||||
|
||||
/// Constructor with a ternary function f, and three input arguments
|
||||
public:
|
||||
|
||||
/// Constructor with a ternary function f, and two input arguments
|
||||
TernaryExpression(Function f, const Expression<A1>& e1,
|
||||
const Expression<A2>& e2, const Expression<A3>& e3) :
|
||||
expression1_(e1.root()), expression2_(e2.root()), expression3_(e3.root()), //
|
||||
function_(f) {
|
||||
this->template reset<A1, 1>(e1.root());
|
||||
this->template reset<A2, 2>(e2.root());
|
||||
this->template reset<A3, 3>(e3.root());
|
||||
ExpressionNode<T>::traceSize_ = //
|
||||
upAligned(sizeof(Record)) + e1.traceSize() + e2.traceSize()
|
||||
+ e3.traceSize();
|
||||
ExpressionNode<T>::traceSize_ = upAligned(sizeof(Record)) + //
|
||||
e1.traceSize() + e2.traceSize() + e3.traceSize();
|
||||
}
|
||||
|
||||
friend class Expression<T> ;
|
||||
|
||||
public:
|
||||
|
||||
/// Return value
|
||||
virtual T value(const Values& values) const {
|
||||
using boost::none;
|
||||
return function_(this->template expression<A1, 1>()->value(values),
|
||||
this->template expression<A2, 2>()->value(values),
|
||||
this->template expression<A3, 3>()->value(values),
|
||||
none, none, none);
|
||||
return function_(expression1_->value(values), expression2_->value(values),
|
||||
expression3_->value(values), none, none, none);
|
||||
}
|
||||
|
||||
/// Construct an execution trace for reverse AD
|
||||
/// Return keys that play in this expression
|
||||
virtual std::set<Key> keys() const {
|
||||
std::set<Key> keys = expression1_->keys();
|
||||
std::set<Key> myKeys = expression2_->keys();
|
||||
keys.insert(myKeys.begin(), myKeys.end());
|
||||
myKeys = expression3_->keys();
|
||||
keys.insert(myKeys.begin(), myKeys.end());
|
||||
return keys;
|
||||
}
|
||||
|
||||
/// Return dimensions for each argument
|
||||
virtual void dims(std::map<Key, int>& map) const {
|
||||
expression1_->dims(map);
|
||||
expression2_->dims(map);
|
||||
expression3_->dims(map);
|
||||
}
|
||||
|
||||
// Inner Record Class
|
||||
struct Record: public CallRecordImplementor<Record, traits<T>::dimension> {
|
||||
|
||||
A1 value1;
|
||||
ExecutionTrace<A1> trace1;
|
||||
typename Jacobian<T, A1>::type dTdA1;
|
||||
|
||||
A2 value2;
|
||||
ExecutionTrace<A2> trace2;
|
||||
typename Jacobian<T, A2>::type dTdA2;
|
||||
|
||||
A3 value3;
|
||||
ExecutionTrace<A3> trace3;
|
||||
typename Jacobian<T, A3>::type dTdA3;
|
||||
|
||||
/// Print to std::cout
|
||||
void print(const std::string& indent) const {
|
||||
std::cout << indent << "TernaryExpression::Record {" << std::endl;
|
||||
std::cout << indent << dTdA1.format(kMatlabFormat) << std::endl;
|
||||
trace1.print(indent);
|
||||
std::cout << indent << dTdA2.format(kMatlabFormat) << std::endl;
|
||||
trace2.print(indent);
|
||||
std::cout << indent << dTdA3.format(kMatlabFormat) << std::endl;
|
||||
trace3.print(indent);
|
||||
std::cout << indent << "}" << std::endl;
|
||||
}
|
||||
|
||||
/// Start the reverse AD process, see comments in Base
|
||||
void startReverseAD4(JacobianMap& jacobians) const {
|
||||
trace1.reverseAD1(dTdA1, jacobians);
|
||||
trace2.reverseAD1(dTdA2, jacobians);
|
||||
trace3.reverseAD1(dTdA3, jacobians);
|
||||
}
|
||||
|
||||
/// Given df/dT, multiply in dT/dA and continue reverse AD process
|
||||
template<typename SomeMatrix>
|
||||
void reverseAD4(const SomeMatrix & dFdT, JacobianMap& jacobians) const {
|
||||
trace1.reverseAD1(dFdT * dTdA1, jacobians);
|
||||
trace2.reverseAD1(dFdT * dTdA2, jacobians);
|
||||
trace3.reverseAD1(dFdT * dTdA3, jacobians);
|
||||
}
|
||||
};
|
||||
|
||||
/// Construct an execution trace for reverse AD, see UnaryExpression for explanation
|
||||
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
||||
ExecutionTraceStorage* traceStorage) const {
|
||||
|
||||
Record* record = Base::trace(values, traceStorage);
|
||||
ExecutionTraceStorage* ptr) const {
|
||||
assert(reinterpret_cast<size_t>(ptr) % TraceAlignment == 0);
|
||||
Record* record = new (ptr) Record();
|
||||
ptr += upAligned(sizeof(Record));
|
||||
record->value1 = expression1_->traceExecution(values, record->trace1, ptr);
|
||||
record->value2 = expression2_->traceExecution(values, record->trace2, ptr);
|
||||
record->value3 = expression3_->traceExecution(values, record->trace3, ptr);
|
||||
ptr += expression1_->traceSize() + expression2_->traceSize()
|
||||
+ expression3_->traceSize();
|
||||
trace.setFunction(record);
|
||||
|
||||
return function_(
|
||||
record->template value<A1, 1>(), record->template value<A2, 2>(),
|
||||
record->template value<A3, 3>(), record->template jacobian<A1, 1>(),
|
||||
record->template jacobian<A2, 2>(), record->template jacobian<A3, 3>());
|
||||
return function_(record->value1, record->value2, record->value3,
|
||||
record->dTdA1, record->dTdA2, record->dTdA3);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
}
|
||||
// namespace internal
|
||||
}// namespace gtsam
|
||||
} // namespace internal
|
||||
} // namespace gtsam
|
||||
|
|
Loading…
Reference in New Issue