TernaryExpression now without MPL

release/4.3a0
dellaert 2015-05-12 01:25:34 -07:00
parent e1b3a11957
commit 4a8dbd689a
1 changed files with 86 additions and 34 deletions

View File

@ -648,56 +648,108 @@ public:
/// Ternary Expression
template<class T, class A1, class A2, class A3>
class TernaryExpression: public FunctionalNode<T, boost::mpl::vector<A1, A2, A3> >::type {
typedef typename FunctionalNode<T, boost::mpl::vector<A1, A2, A3> >::type Base;
typedef typename Base::Record Record;
private:
class TernaryExpression: public ExpressionNode<T> {
typedef typename Expression<T>::template TernaryFunction<A1, A2, A3>::type Function;
boost::shared_ptr<ExpressionNode<A1> > expression1_;
boost::shared_ptr<ExpressionNode<A2> > expression2_;
boost::shared_ptr<ExpressionNode<A3> > expression3_;
Function function_;
/// Constructor with a ternary function f, and three input arguments
public:
/// Constructor with a ternary function f, and two input arguments
TernaryExpression(Function f, const Expression<A1>& e1,
const Expression<A2>& e2, const Expression<A3>& e3) :
expression1_(e1.root()), expression2_(e2.root()), expression3_(e3.root()), //
function_(f) {
this->template reset<A1, 1>(e1.root());
this->template reset<A2, 2>(e2.root());
this->template reset<A3, 3>(e3.root());
ExpressionNode<T>::traceSize_ = //
upAligned(sizeof(Record)) + e1.traceSize() + e2.traceSize()
+ e3.traceSize();
ExpressionNode<T>::traceSize_ = upAligned(sizeof(Record)) + //
e1.traceSize() + e2.traceSize() + e3.traceSize();
}
friend class Expression<T> ;
public:
/// Return value
virtual T value(const Values& values) const {
using boost::none;
return function_(this->template expression<A1, 1>()->value(values),
this->template expression<A2, 2>()->value(values),
this->template expression<A3, 3>()->value(values),
none, none, none);
return function_(expression1_->value(values), expression2_->value(values),
expression3_->value(values), none, none, none);
}
/// Construct an execution trace for reverse AD
/// Return keys that play in this expression
virtual std::set<Key> keys() const {
std::set<Key> keys = expression1_->keys();
std::set<Key> myKeys = expression2_->keys();
keys.insert(myKeys.begin(), myKeys.end());
myKeys = expression3_->keys();
keys.insert(myKeys.begin(), myKeys.end());
return keys;
}
/// Return dimensions for each argument
virtual void dims(std::map<Key, int>& map) const {
expression1_->dims(map);
expression2_->dims(map);
expression3_->dims(map);
}
// Inner Record Class
struct Record: public CallRecordImplementor<Record, traits<T>::dimension> {
A1 value1;
ExecutionTrace<A1> trace1;
typename Jacobian<T, A1>::type dTdA1;
A2 value2;
ExecutionTrace<A2> trace2;
typename Jacobian<T, A2>::type dTdA2;
A3 value3;
ExecutionTrace<A3> trace3;
typename Jacobian<T, A3>::type dTdA3;
/// Print to std::cout
void print(const std::string& indent) const {
std::cout << indent << "TernaryExpression::Record {" << std::endl;
std::cout << indent << dTdA1.format(kMatlabFormat) << std::endl;
trace1.print(indent);
std::cout << indent << dTdA2.format(kMatlabFormat) << std::endl;
trace2.print(indent);
std::cout << indent << dTdA3.format(kMatlabFormat) << std::endl;
trace3.print(indent);
std::cout << indent << "}" << std::endl;
}
/// Start the reverse AD process, see comments in Base
void startReverseAD4(JacobianMap& jacobians) const {
trace1.reverseAD1(dTdA1, jacobians);
trace2.reverseAD1(dTdA2, jacobians);
trace3.reverseAD1(dTdA3, jacobians);
}
/// Given df/dT, multiply in dT/dA and continue reverse AD process
template<typename SomeMatrix>
void reverseAD4(const SomeMatrix & dFdT, JacobianMap& jacobians) const {
trace1.reverseAD1(dFdT * dTdA1, jacobians);
trace2.reverseAD1(dFdT * dTdA2, jacobians);
trace3.reverseAD1(dFdT * dTdA3, jacobians);
}
};
/// Construct an execution trace for reverse AD, see UnaryExpression for explanation
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
ExecutionTraceStorage* traceStorage) const {
Record* record = Base::trace(values, traceStorage);
ExecutionTraceStorage* ptr) const {
assert(reinterpret_cast<size_t>(ptr) % TraceAlignment == 0);
Record* record = new (ptr) Record();
ptr += upAligned(sizeof(Record));
record->value1 = expression1_->traceExecution(values, record->trace1, ptr);
record->value2 = expression2_->traceExecution(values, record->trace2, ptr);
record->value3 = expression3_->traceExecution(values, record->trace3, ptr);
ptr += expression1_->traceSize() + expression2_->traceSize()
+ expression3_->traceSize();
trace.setFunction(record);
return function_(
record->template value<A1, 1>(), record->template value<A2, 2>(),
record->template value<A3, 3>(), record->template jacobian<A1, 1>(),
record->template jacobian<A2, 2>(), record->template jacobian<A3, 3>());
return function_(record->value1, record->value2, record->value3,
record->dTdA1, record->dTdA2, record->dTdA3);
}
};
}
// namespace internal
}// namespace gtsam
} // namespace internal
} // namespace gtsam