1 #ifndef CPPAD_CG_EVALUATOR_AD_INCLUDED 2 #define CPPAD_CG_EVALUATOR_AD_INCLUDED 26 template<
class ScalarIn,
class ScalarOut,
class FinalEvaluatorType>
40 using Super::handler_;
41 using Super::evalArrayCreationOperation;
43 std::set<NodeIn*> evalsAtomic_;
44 std::map<size_t, CppAD::atomic_base<ScalarOut>* > atomicFunctions_;
54 printOutPriOperations_(true) {
64 printOutPriOperations_ = print;
84 bool exists = atomicFunctions_.find(
id) != atomicFunctions_.end();
85 atomicFunctions_[id] = &atomic;
89 virtual void addAtomicFunctions(
const std::map<
size_t, atomic_base<ScalarOut>* >& atomics) {
90 for (
const auto& it : atomics) {
91 atomic_base<ScalarOut>* atomic = it.second;
92 if (atomic !=
nullptr) {
93 atomicFunctions_[it.first] = atomic;
105 return evalsAtomic_.size();
120 Super::prepareNewEvaluation();
122 evalsAtomic_.clear();
133 if (evalsAtomic_.find(&node) != evalsAtomic_.end()) {
138 throw CGException(
"Evaluator can only handle zero forward mode for atomic functions");
141 const std::vector<size_t>& info = node.
getInfo();
142 const std::vector<Argument<ScalarIn> >& args = node.
getArguments();
143 CPPADCG_ASSERT_KNOWN(args.size() == 2,
"Invalid number of arguments for atomic forward mode")
144 CPPADCG_ASSERT_KNOWN(info.size() == 3,
"Invalid number of information data for atomic forward mode")
148 typename std::map<size_t, atomic_base<ScalarOut>* >::const_iterator itaf = atomicFunctions_.find(
id);
149 atomic_base<ScalarOut>* atomicFunction =
nullptr;
150 if (itaf != atomicFunctions_.end()) {
151 atomicFunction = itaf->second;
154 if (atomicFunction ==
nullptr) {
155 std::stringstream ss;
156 ss <<
"No atomic function defined in the evaluator for ";
158 if (!atomName.empty()) {
159 ss <<
"'" << atomName <<
"'";
161 ss <<
"id '" <<
id <<
"'";
167 throw CGException(
"Evaluator can only handle zero forward mode for atomic functions");
169 const std::vector<ActiveOut>& ax = evalArrayCreationOperation(*args[0].getOperation());
170 std::vector<ActiveOut>& ay = evalArrayCreationOperation(*args[1].getOperation());
172 (*atomicFunction)(ax, ay);
174 evalsAtomic_.insert(&node);
183 CPPADCG_ASSERT_KNOWN(args.size() == 1,
"Invalid number of arguments for print()")
187 if (printOutPriOperations_) {
188 std::cout << nodePri.getBeforeString() << out << nodePri.getAfterString();
191 CppAD::PrintFor(
ActiveOut(0), nodePri.getBeforeString().c_str(), out, nodePri.getAfterString().c_str());
201 template<
class ScalarIn,
class ScalarOut>
202 class Evaluator<ScalarIn, ScalarOut,
CppAD::
AD<ScalarOut> > :
public EvaluatorAD<ScalarIn, ScalarOut, Evaluator<ScalarIn, ScalarOut, CppAD::AD<ScalarOut> > > {
const std::vector< Argument< Base > > & getArguments() const
void setPrintOutPrintOperations(bool print)
void evalAtomicOperation(NodeIn &node)
bool printOutPriOperations_
CGOpCode getOperationType() const
void prepareNewEvaluation()
std::string getAtomicFunctionName(size_t id) const
size_t getNumberOfEvaluatedAtomics() const
bool isPrintOutPrintOperations() const
virtual bool addAtomicFunction(size_t id, atomic_base< ScalarOut > &atomic)
ActiveOut evalPrint(const NodeIn &node)
const std::vector< size_t > & getInfo() const