CppADCodeGen  HEAD
A C++ Algorithmic Differentiation Package with Source Code Generation
evaluator_ad.hpp
1 #ifndef CPPAD_CG_EVALUATOR_AD_INCLUDED
2 #define CPPAD_CG_EVALUATOR_AD_INCLUDED
3 /* --------------------------------------------------------------------------
4  * CppADCodeGen: C++ Algorithmic Differentiation with Source Code Generation:
5  * Copyright (C) 2016 Ciengis
6  * Copyright (C) 2020 Joao Leal
7  *
8  * CppADCodeGen is distributed under multiple licenses:
9  *
10  * - Eclipse Public License Version 1.0 (EPL1), and
11  * - GNU General Public License Version 3 (GPL3).
12  *
13  * EPL1 terms and conditions can be found in the file "epl-v10.txt", while
14  * terms and conditions for the GPL3 can be found in the file "gpl3.txt".
15  * ----------------------------------------------------------------------------
16  * Author: Joao Leal
17  */
18 
19 namespace CppAD {
20 namespace cg {
21 
26 template<class ScalarIn, class ScalarOut, class FinalEvaluatorType>
27 class EvaluatorAD : public EvaluatorOperations<ScalarIn, ScalarOut, CppAD::AD<ScalarOut>, FinalEvaluatorType> {
34 public:
37  using ArgIn = Argument<ScalarIn>;
39 protected:
40  using Super::handler_;
41  using Super::evalArrayCreationOperation;
42 protected:
43  std::set<NodeIn*> evalsAtomic_;
44  std::map<size_t, CppAD::atomic_base<ScalarOut>* > atomicFunctions_;
50 public:
51 
52  inline EvaluatorAD(CodeHandler<ScalarIn>& handler) :
53  Super(handler),
54  printOutPriOperations_(true) {
55  }
56 
57  inline virtual ~EvaluatorAD() = default;
58 
63  inline void setPrintOutPrintOperations(bool print) {
64  printOutPriOperations_ = print;
65  }
66 
71  inline bool isPrintOutPrintOperations() const {
73  }
74 
83  virtual bool addAtomicFunction(size_t id, atomic_base<ScalarOut>& atomic) {
84  bool exists = atomicFunctions_.find(id) != atomicFunctions_.end();
85  atomicFunctions_[id] = &atomic;
86  return exists;
87  }
88 
89  virtual void addAtomicFunctions(const std::map<size_t, atomic_base<ScalarOut>* >& atomics) {
90  for (const auto& it : atomics) {
91  atomic_base<ScalarOut>* atomic = it.second;
92  if (atomic != nullptr) {
93  atomicFunctions_[it.first] = atomic;
94  }
95  }
96  }
97 
105  return evalsAtomic_.size();
106  }
107 
108 protected:
109 
114  inline void prepareNewEvaluation() {
120  Super::prepareNewEvaluation();
121 
122  evalsAtomic_.clear();
123  }
124 
131  inline void evalAtomicOperation(NodeIn& node) {
132 
133  if (evalsAtomic_.find(&node) != evalsAtomic_.end()) {
134  return;
135  }
136 
137  if (node.getOperationType() != CGOpCode::AtomicForward) {
138  throw CGException("Evaluator can only handle zero forward mode for atomic functions");
139  }
140 
141  const std::vector<size_t>& info = node.getInfo();
142  const std::vector<Argument<ScalarIn> >& args = node.getArguments();
143  CPPADCG_ASSERT_KNOWN(args.size() == 2, "Invalid number of arguments for atomic forward mode")
144  CPPADCG_ASSERT_KNOWN(info.size() == 3, "Invalid number of information data for atomic forward mode")
145 
146  // find the atomic function
147  size_t id = info[0];
148  typename std::map<size_t, atomic_base<ScalarOut>* >::const_iterator itaf = atomicFunctions_.find(id);
149  atomic_base<ScalarOut>* atomicFunction = nullptr;
150  if (itaf != atomicFunctions_.end()) {
151  atomicFunction = itaf->second;
152  }
153 
154  if (atomicFunction == nullptr) {
155  std::stringstream ss;
156  ss << "No atomic function defined in the evaluator for ";
157  const std::string & atomName = handler_.getAtomicFunctionName(id);
158  if (!atomName.empty()) {
159  ss << "'" << atomName << "'";
160  } else
161  ss << "id '" << id << "'";
162  throw CGException(ss.str());
163  }
164 
165  size_t p = info[2];
166  if (p != 0) {
167  throw CGException("Evaluator can only handle zero forward mode for atomic functions");
168  }
169  const std::vector<ActiveOut>& ax = evalArrayCreationOperation(*args[0].getOperation());
170  std::vector<ActiveOut>& ay = evalArrayCreationOperation(*args[1].getOperation());
171 
172  (*atomicFunction)(ax, ay);
173 
174  evalsAtomic_.insert(&node);
175  }
176 
181  inline ActiveOut evalPrint(const NodeIn& node) {
182  const std::vector<ArgIn>& args = node.getArguments();
183  CPPADCG_ASSERT_KNOWN(args.size() == 1, "Invalid number of arguments for print()")
184  ActiveOut out(this->evalArg(args, 0));
185 
186  const auto& nodePri = static_cast<const PrintOperationNode<ScalarIn>&>(node);
187  if (printOutPriOperations_) {
188  std::cout << nodePri.getBeforeString() << out << nodePri.getAfterString();
189  }
190 
191  CppAD::PrintFor(ActiveOut(0), nodePri.getBeforeString().c_str(), out, nodePri.getAfterString().c_str());
192 
193  return out;
194  }
195 
196 };
197 
201 template<class ScalarIn, class ScalarOut>
202 class Evaluator<ScalarIn, ScalarOut, CppAD::AD<ScalarOut> > : public EvaluatorAD<ScalarIn, ScalarOut, Evaluator<ScalarIn, ScalarOut, CppAD::AD<ScalarOut> > > {
203 public:
206 public:
207 
208  inline Evaluator(CodeHandler<ScalarIn>& handler) :
209  Super(handler) {
210  }
211 
212 };
213 
214 } // END cg namespace
215 } // END CppAD namespace
216 
217 #endif
const std::vector< Argument< Base > > & getArguments() const
void setPrintOutPrintOperations(bool print)
void evalAtomicOperation(NodeIn &node)
CGOpCode getOperationType() const
std::string getAtomicFunctionName(size_t id) const
size_t getNumberOfEvaluatedAtomics() const
bool isPrintOutPrintOperations() const
virtual bool addAtomicFunction(size_t id, atomic_base< ScalarOut > &atomic)
ActiveOut evalPrint(const NodeIn &node)
const std::vector< size_t > & getInfo() const