Switch to pointers - nice improvement
parent
982dc29d2f
commit
3c1c9c6d12
|
@ -136,13 +136,15 @@ public:
|
|||
|
||||
//-----------------------------------------------------------------------------
|
||||
struct JacobianTrace {
|
||||
virtual ~JacobianTrace() {
|
||||
}
|
||||
virtual void reverseAD(JacobianMap& jacobians) const = 0;
|
||||
virtual void reverseAD(const Matrix& dFdT, JacobianMap& jacobians) const = 0;
|
||||
// template<class JacobianFT>
|
||||
// void reverseAD(const JacobianFT& dFdT, JacobianMap& jacobians) const {
|
||||
};
|
||||
|
||||
typedef boost::shared_ptr<JacobianTrace> TracePtr;
|
||||
typedef JacobianTrace* TracePtr;
|
||||
|
||||
//template <class Derived>
|
||||
//struct TypedTrace {
|
||||
|
@ -235,7 +237,7 @@ public:
|
|||
|
||||
/// Construct an execution trace for reverse AD
|
||||
virtual std::pair<T, TracePtr> traceExecution(const Values& values) const {
|
||||
boost::shared_ptr<Trace> trace = boost::make_shared<Trace>();
|
||||
Trace* trace = new Trace();
|
||||
return std::make_pair(constant_, trace);
|
||||
}
|
||||
};
|
||||
|
@ -298,7 +300,7 @@ public:
|
|||
|
||||
/// Construct an execution trace for reverse AD
|
||||
virtual std::pair<T, TracePtr> traceExecution(const Values& values) const {
|
||||
boost::shared_ptr<Trace> trace = boost::make_shared<Trace>();
|
||||
Trace* trace = new Trace();
|
||||
trace->key = key_;
|
||||
return std::make_pair(values.at<T>(key_), trace);
|
||||
}
|
||||
|
@ -357,6 +359,9 @@ public:
|
|||
struct Trace: public JacobianTrace {
|
||||
TracePtr trace;
|
||||
JacobianTA dTdA;
|
||||
virtual ~Trace() {
|
||||
delete trace;
|
||||
}
|
||||
/// Start the reverse AD process
|
||||
virtual void reverseAD(JacobianMap& jacobians) const {
|
||||
trace->reverseAD(dTdA, jacobians);
|
||||
|
@ -370,9 +375,9 @@ public:
|
|||
/// Construct an execution trace for reverse AD
|
||||
virtual std::pair<T, TracePtr> traceExecution(const Values& values) const {
|
||||
A a;
|
||||
boost::shared_ptr<Trace> trace = boost::make_shared<Trace>();
|
||||
Trace* trace = new Trace();
|
||||
boost::tie(a, trace->trace) = this->expressionA_->traceExecution(values);
|
||||
return std::make_pair(function_(a, trace->dTdA),trace);
|
||||
return std::make_pair(function_(a, trace->dTdA), trace);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -443,6 +448,10 @@ public:
|
|||
TracePtr trace1, trace2;
|
||||
JacobianTA1 dTdA1;
|
||||
JacobianTA2 dTdA2;
|
||||
virtual ~Trace() {
|
||||
delete trace1;
|
||||
delete trace2;
|
||||
}
|
||||
/// Start the reverse AD process
|
||||
virtual void reverseAD(JacobianMap& jacobians) const {
|
||||
trace1->reverseAD(dTdA1, jacobians);
|
||||
|
@ -459,7 +468,7 @@ public:
|
|||
virtual std::pair<T, TracePtr> traceExecution(const Values& values) const {
|
||||
A1 a1;
|
||||
A2 a2;
|
||||
boost::shared_ptr<Trace> trace = boost::make_shared<Trace>();
|
||||
Trace* trace = new Trace();
|
||||
boost::tie(a1, trace->trace1) = this->expressionA1_->traceExecution(values);
|
||||
boost::tie(a2, trace->trace2) = this->expressionA2_->traceExecution(values);
|
||||
return std::make_pair(function_(a1, a2, trace->dTdA1, trace->dTdA2), trace);
|
||||
|
@ -543,12 +552,15 @@ public:
|
|||
|
||||
/// Trace structure for reverse AD
|
||||
struct Trace: public JacobianTrace {
|
||||
TracePtr trace1;
|
||||
TracePtr trace2;
|
||||
TracePtr trace3;
|
||||
TracePtr trace1, trace2, trace3;
|
||||
JacobianTA1 dTdA1;
|
||||
JacobianTA2 dTdA2;
|
||||
JacobianTA3 dTdA3;
|
||||
virtual ~Trace() {
|
||||
delete trace1;
|
||||
delete trace2;
|
||||
delete trace3;
|
||||
}
|
||||
/// Start the reverse AD process
|
||||
virtual void reverseAD(JacobianMap& jacobians) const {
|
||||
trace1->reverseAD(dTdA1, jacobians);
|
||||
|
@ -568,7 +580,7 @@ public:
|
|||
A1 a1;
|
||||
A2 a2;
|
||||
A3 a3;
|
||||
boost::shared_ptr<Trace> trace = boost::make_shared<Trace>();
|
||||
Trace* trace = new Trace();
|
||||
boost::tie(a1, trace->trace1) = this->expressionA1_->traceExecution(values);
|
||||
boost::tie(a2, trace->trace2) = this->expressionA2_->traceExecution(values);
|
||||
boost::tie(a3, trace->trace3) = this->expressionA3_->traceExecution(values);
|
||||
|
|
|
@ -122,6 +122,7 @@ public:
|
|||
boost::tie(value,trace) = root_->traceExecution(values);
|
||||
Augmented<T> augmented(value);
|
||||
trace->reverseAD(augmented.jacobians());
|
||||
delete trace;
|
||||
return augmented;
|
||||
#else
|
||||
return root_->forward(values);
|
||||
|
|
Loading…
Reference in New Issue