CppADCodeGen  HEAD
A C++ Algorithmic Differentiation Package with Source Code Generation
evaluator_solve.hpp
1 #ifndef CPPAD_CG_EVALUATOR_SOLVE_INCLUDED
2 #define CPPAD_CG_EVALUATOR_SOLVE_INCLUDED
3 /* --------------------------------------------------------------------------
4  * CppADCodeGen: C++ Algorithmic Differentiation with Source Code Generation:
5  * Copyright (C) 2016 Ciengis
6  *
7  * CppADCodeGen is distributed under multiple licenses:
8  *
9  * - Eclipse Public License Version 1.0 (EPL1), and
10  * - GNU General Public License Version 3 (GPL3).
11  *
12  * EPL1 terms and conditions can be found in the file "epl-v10.txt", while
13  * terms and conditions for the GPL3 can be found in the file "gpl3.txt".
14  * ----------------------------------------------------------------------------
15  * Author: Joao Leal
16  */
17 
18 namespace CppAD {
19 namespace cg {
20 
26 template<class Scalar>
27 class EvaluatorCloneSolve : public EvaluatorCG<Scalar, Scalar, EvaluatorCloneSolve<Scalar>> {
34 public:
35  using ActiveOut = CG<Scalar>;
36  using SourceCodePath = typename CodeHandler<Scalar>::SourceCodePath;
37 protected:
39 private:
44  const std::vector<const SourceCodePath*>* paths_;
49  const std::vector<const std::vector<CG<Scalar>*>*>* replaceOnPath_;
54  const BidirGraph<Scalar>* pathGraph_;
58  const std::map<const PathNodeEdges<Scalar>*, CG<Scalar>>* replaceOnGraph_;
62  const std::set<const OperationNode<Scalar>*>* clone_;
66  const std::map<const OperationPathNode<Scalar>, CG<Scalar>>* replaceArgument_;
67 public:
68 
79  const std::vector<const SourceCodePath*>& paths,
80  const std::vector<const std::vector<CG<Scalar>*>*>& replaceOnPath) :
81  Super(handler),
82  paths_(&paths),
83  replaceOnPath_(&replaceOnPath),
84  pathGraph_(nullptr),
85  replaceOnGraph_(nullptr),
86  clone_(nullptr),
87  replaceArgument_(nullptr) {
88  CPPADCG_ASSERT_UNKNOWN(paths_->size() == replaceOnPath_->size());
89 #ifndef NDEBUG
90  for (size_t i = 0; i < paths.size(); ++i) {
91  CPPADCG_ASSERT_UNKNOWN(paths[i]->size() == replaceOnPath[i]->size());
92  }
93 #endif
94  }
95 
104  const BidirGraph<Scalar>& pathGraph,
105  const std::map<const PathNodeEdges<Scalar>*, CG<Scalar> >& replaceOnGraph) :
106  Super(handler),
107  paths_(nullptr),
108  replaceOnPath_(nullptr),
109  pathGraph_(&pathGraph),
110  replaceOnGraph_(&replaceOnGraph),
111  clone_(nullptr),
112  replaceArgument_(nullptr) {
113  }
114 
123  const std::set<const OperationNode<Scalar>*>& clone,
124  const std::map<const OperationPathNode<Scalar>, CG<Scalar>>& replaceArgument) :
125  Super(handler),
126  paths_(nullptr),
127  replaceOnPath_(nullptr),
128  pathGraph_(nullptr),
129  replaceOnGraph_(nullptr),
130  clone_(&clone),
131  replaceArgument_(&replaceArgument) {
132  }
133 
134 protected:
135 
141  CPPADCG_ASSERT_UNKNOWN(this->depth_ > 0);
142 
143  if(paths_ != nullptr) {
144  const auto& paths = *paths_;
145  for (size_t i = 0; i < paths.size(); ++i) {
146  size_t d = this->depth_ - 1;
147  if (isOnPath(*paths[i])) {
148  // in one of the paths
149 
150  auto* r = (*(*replaceOnPath_)[i])[d];
151  if (r != nullptr) {
152  return *r;
153  } else {
154  return Super::evalOperation(node);
155  }
156  }
157  }
158  }
159 
160  if(pathGraph_ != nullptr) {
161  const PathNodeEdges<Scalar>* egdes = pathGraph_->find(node);
162  if (egdes != nullptr) {
163  auto it = replaceOnGraph_->find(egdes);
164  if (it != replaceOnGraph_->end()) {
165  return it->second;
166  } else {
167  return Super::evalOperation(node);
168  }
169  }
170  }
171 
172  if (clone_ != nullptr) {
173  if (clone_->find(&node) != clone_->end()) {
174  return Super::evalOperation(node);
175  }
176  }
177 
178  if (replaceArgument_ != nullptr) {
179  size_t d = this->depth_ - 1;
180  if (d > 0) {
181  auto it = replaceArgument_->find(this->path_[d - 1]);
182  if (it != replaceArgument_->end()) {
183  return it->second;
184  }
185  }
186  }
187 
188  return CG<Scalar>(node); // use original
189  }
190 
191 private:
192  inline bool isOnPath(const SourceCodePath& path) const {
193  size_t d = this->depth_ - 1;
194 
195  if (d >= path.size())
196  return false;
197 
198  if (this->path_[d].node != path[d].node) // compare only the node
199  return false;
200 
201  if (d > 0) {
202  for (size_t j = 0; j < d; ++j) {
203  if (this->path_[j] != path[j]) { // compare node and argument index
204  return false;
205  }
206  }
207  }
208 
209  return true;
210  }
211 
212 };
213 
214 } // END cg namespace
215 } // END CppAD namespace
216 
217 #endif
ActiveOut evalOperation(OperationNode< ScalarIn > &node)
Definition: evaluator.hpp:374
EvaluatorCloneSolve(CodeHandler< Scalar > &handler, const std::set< const OperationNode< Scalar > *> &clone, const std::map< const OperationPathNode< Scalar >, CG< Scalar >> &replaceArgument)
ActiveOut evalOperation(OperationNode< Scalar > &node)
EvaluatorCloneSolve(CodeHandler< Scalar > &handler, const BidirGraph< Scalar > &pathGraph, const std::map< const PathNodeEdges< Scalar > *, CG< Scalar > > &replaceOnGraph)
EvaluatorCloneSolve(CodeHandler< Scalar > &handler, const std::vector< const SourceCodePath *> &paths, const std::vector< const std::vector< CG< Scalar > *> *> &replaceOnPath)