TernaryExpression now without MPL
parent
e1b3a11957
commit
4a8dbd689a
|
@ -648,56 +648,108 @@ public:
|
||||||
/// Ternary Expression
|
/// Ternary Expression
|
||||||
|
|
||||||
template<class T, class A1, class A2, class A3>
|
template<class T, class A1, class A2, class A3>
|
||||||
class TernaryExpression: public FunctionalNode<T, boost::mpl::vector<A1, A2, A3> >::type {
|
class TernaryExpression: public ExpressionNode<T> {
|
||||||
|
|
||||||
typedef typename FunctionalNode<T, boost::mpl::vector<A1, A2, A3> >::type Base;
|
|
||||||
typedef typename Base::Record Record;
|
|
||||||
|
|
||||||
private:
|
|
||||||
|
|
||||||
typedef typename Expression<T>::template TernaryFunction<A1, A2, A3>::type Function;
|
typedef typename Expression<T>::template TernaryFunction<A1, A2, A3>::type Function;
|
||||||
|
boost::shared_ptr<ExpressionNode<A1> > expression1_;
|
||||||
|
boost::shared_ptr<ExpressionNode<A2> > expression2_;
|
||||||
|
boost::shared_ptr<ExpressionNode<A3> > expression3_;
|
||||||
Function function_;
|
Function function_;
|
||||||
|
|
||||||
/// Constructor with a ternary function f, and three input arguments
|
public:
|
||||||
|
|
||||||
|
/// Constructor with a ternary function f, and two input arguments
|
||||||
TernaryExpression(Function f, const Expression<A1>& e1,
|
TernaryExpression(Function f, const Expression<A1>& e1,
|
||||||
const Expression<A2>& e2, const Expression<A3>& e3) :
|
const Expression<A2>& e2, const Expression<A3>& e3) :
|
||||||
|
expression1_(e1.root()), expression2_(e2.root()), expression3_(e3.root()), //
|
||||||
function_(f) {
|
function_(f) {
|
||||||
this->template reset<A1, 1>(e1.root());
|
ExpressionNode<T>::traceSize_ = upAligned(sizeof(Record)) + //
|
||||||
this->template reset<A2, 2>(e2.root());
|
e1.traceSize() + e2.traceSize() + e3.traceSize();
|
||||||
this->template reset<A3, 3>(e3.root());
|
|
||||||
ExpressionNode<T>::traceSize_ = //
|
|
||||||
upAligned(sizeof(Record)) + e1.traceSize() + e2.traceSize()
|
|
||||||
+ e3.traceSize();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
friend class Expression<T> ;
|
|
||||||
|
|
||||||
public:
|
|
||||||
|
|
||||||
/// Return value
|
/// Return value
|
||||||
virtual T value(const Values& values) const {
|
virtual T value(const Values& values) const {
|
||||||
using boost::none;
|
using boost::none;
|
||||||
return function_(this->template expression<A1, 1>()->value(values),
|
return function_(expression1_->value(values), expression2_->value(values),
|
||||||
this->template expression<A2, 2>()->value(values),
|
expression3_->value(values), none, none, none);
|
||||||
this->template expression<A3, 3>()->value(values),
|
|
||||||
none, none, none);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Construct an execution trace for reverse AD
|
/// Return keys that play in this expression
|
||||||
|
virtual std::set<Key> keys() const {
|
||||||
|
std::set<Key> keys = expression1_->keys();
|
||||||
|
std::set<Key> myKeys = expression2_->keys();
|
||||||
|
keys.insert(myKeys.begin(), myKeys.end());
|
||||||
|
myKeys = expression3_->keys();
|
||||||
|
keys.insert(myKeys.begin(), myKeys.end());
|
||||||
|
return keys;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return dimensions for each argument
|
||||||
|
virtual void dims(std::map<Key, int>& map) const {
|
||||||
|
expression1_->dims(map);
|
||||||
|
expression2_->dims(map);
|
||||||
|
expression3_->dims(map);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inner Record Class
|
||||||
|
struct Record: public CallRecordImplementor<Record, traits<T>::dimension> {
|
||||||
|
|
||||||
|
A1 value1;
|
||||||
|
ExecutionTrace<A1> trace1;
|
||||||
|
typename Jacobian<T, A1>::type dTdA1;
|
||||||
|
|
||||||
|
A2 value2;
|
||||||
|
ExecutionTrace<A2> trace2;
|
||||||
|
typename Jacobian<T, A2>::type dTdA2;
|
||||||
|
|
||||||
|
A3 value3;
|
||||||
|
ExecutionTrace<A3> trace3;
|
||||||
|
typename Jacobian<T, A3>::type dTdA3;
|
||||||
|
|
||||||
|
/// Print to std::cout
|
||||||
|
void print(const std::string& indent) const {
|
||||||
|
std::cout << indent << "TernaryExpression::Record {" << std::endl;
|
||||||
|
std::cout << indent << dTdA1.format(kMatlabFormat) << std::endl;
|
||||||
|
trace1.print(indent);
|
||||||
|
std::cout << indent << dTdA2.format(kMatlabFormat) << std::endl;
|
||||||
|
trace2.print(indent);
|
||||||
|
std::cout << indent << dTdA3.format(kMatlabFormat) << std::endl;
|
||||||
|
trace3.print(indent);
|
||||||
|
std::cout << indent << "}" << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Start the reverse AD process, see comments in Base
|
||||||
|
void startReverseAD4(JacobianMap& jacobians) const {
|
||||||
|
trace1.reverseAD1(dTdA1, jacobians);
|
||||||
|
trace2.reverseAD1(dTdA2, jacobians);
|
||||||
|
trace3.reverseAD1(dTdA3, jacobians);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Given df/dT, multiply in dT/dA and continue reverse AD process
|
||||||
|
template<typename SomeMatrix>
|
||||||
|
void reverseAD4(const SomeMatrix & dFdT, JacobianMap& jacobians) const {
|
||||||
|
trace1.reverseAD1(dFdT * dTdA1, jacobians);
|
||||||
|
trace2.reverseAD1(dFdT * dTdA2, jacobians);
|
||||||
|
trace3.reverseAD1(dFdT * dTdA3, jacobians);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Construct an execution trace for reverse AD, see UnaryExpression for explanation
|
||||||
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
||||||
ExecutionTraceStorage* traceStorage) const {
|
ExecutionTraceStorage* ptr) const {
|
||||||
|
assert(reinterpret_cast<size_t>(ptr) % TraceAlignment == 0);
|
||||||
Record* record = Base::trace(values, traceStorage);
|
Record* record = new (ptr) Record();
|
||||||
|
ptr += upAligned(sizeof(Record));
|
||||||
|
record->value1 = expression1_->traceExecution(values, record->trace1, ptr);
|
||||||
|
record->value2 = expression2_->traceExecution(values, record->trace2, ptr);
|
||||||
|
record->value3 = expression3_->traceExecution(values, record->trace3, ptr);
|
||||||
|
ptr += expression1_->traceSize() + expression2_->traceSize()
|
||||||
|
+ expression3_->traceSize();
|
||||||
trace.setFunction(record);
|
trace.setFunction(record);
|
||||||
|
return function_(record->value1, record->value2, record->value3,
|
||||||
return function_(
|
record->dTdA1, record->dTdA2, record->dTdA3);
|
||||||
record->template value<A1, 1>(), record->template value<A2, 2>(),
|
|
||||||
record->template value<A3, 3>(), record->template jacobian<A1, 1>(),
|
|
||||||
record->template jacobian<A2, 2>(), record->template jacobian<A3, 3>());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
} // namespace internal
|
||||||
// namespace internal
|
} // namespace gtsam
|
||||||
}// namespace gtsam
|
|
||||||
|
|
Loading…
Reference in New Issue