CppADCodeGen  HEAD
A C++ Algorithmic Differentiation Package with Source Code Generation
operation_path.hpp
1 #ifndef CPPAD_CG_OPERATION_PATH_INCLUDED
2 #define CPPAD_CG_OPERATION_PATH_INCLUDED
3 /* --------------------------------------------------------------------------
4  * CppADCodeGen: C++ Algorithmic Differentiation with Source Code Generation:
5  * Copyright (C) 2012 Ciengis
6  *
7  * CppADCodeGen is distributed under multiple licenses:
8  *
9  * - Eclipse Public License Version 1.0 (EPL1), and
10  * - GNU General Public License Version 3 (GPL3).
11  *
12  * EPL1 terms and conditions can be found in the file "epl-v10.txt", while
13  * terms and conditions for the GPL3 can be found in the file "gpl3.txt".
14  * ----------------------------------------------------------------------------
15  * Author: Joao Leal
16  */
17 
18 #include <cppad/cg/operation_path_node.hpp>
19 #include <cppad/cg/bidir_graph.hpp>
20 
21 namespace CppAD {
22 namespace cg {
23 
34 template<class Base>
35 inline bool findPathGraph(BidirGraph<Base>& foundGraph,
36  OperationNode<Base>& root,
37  OperationNode<Base>& target,
38  size_t& bifurcations,
39  size_t maxBifurcations = (std::numeric_limits<size_t>::max)()) {
40  if (bifurcations >= maxBifurcations) {
41  return false;
42  }
43 
44  if (&root == &target) {
45  return true;
46  }
47 
48  if(foundGraph.contains(root)) {
49  return true; // been here and it was saved in foundGraph
50  }
51 
52  auto* h = root.getCodeHandler();
53 
54  if(h->isVisited(root)) {
55  return false; // been here but it was not saved in foundGraph
56  }
57 
58  // not visited yet
59  h->markVisited(root); // mark node as visited
60 
61  PathNodeEdges<Base>& info = foundGraph[root];
62 
63  const auto& args = root.getArguments();
64 
65  bool found = false;
66  for(size_t i = 0; i < args.size(); ++i) {
67  const Argument<Base>& a = args[i];
68  if(a.getOperation() != nullptr ) {
69  auto& aNode = *a.getOperation();
70  if(findPathGraph(foundGraph, aNode, target, bifurcations, maxBifurcations)) {
71  foundGraph.connect(info, root, i);
72  if(found) {
73  bifurcations++; // multiple ways to get to target
74  } else {
75  found = true;
76  }
77  }
78  }
79  }
80 
81  if(!found) {
82  foundGraph.erase(root);
83  }
84 
85  return found;
86 }
87 
88 template<class Base>
89 inline BidirGraph<Base> CodeHandler<Base>::findPathGraph(OperationNode<Base>& root,
90  OperationNode<Base>& target) {
91  size_t bifurcations = 0;
92  return findPathGraph(root, target, bifurcations);
93 }
94 
95 template<class Base>
96 inline BidirGraph<Base> CodeHandler<Base>::findPathGraph(OperationNode<Base>& root,
97  OperationNode<Base>& target,
98  size_t& bifurcations,
99  size_t maxBifurcations) {
100  startNewOperationTreeVisit();
101 
102  BidirGraph<Base> foundGraph;
103 
104  if (bifurcations <= maxBifurcations) {
105  if (&root == &target) {
106  foundGraph[root];
107  } else {
108  CppAD::cg::findPathGraph<Base>(foundGraph, root, target, bifurcations, maxBifurcations);
109  }
110  }
111 
112  return foundGraph;
113 }
114 
115 
116 template<class Base>
117 inline std::vector<std::vector<OperationPathNode<Base> > > CodeHandler<Base>::findPaths(OperationNode<Base>& root,
118  OperationNode<Base>& code,
119  size_t max) {
120  std::vector<std::vector<OperationPathNode<Base> > > found;
121 
122  startNewOperationTreeVisit();
123 
124  if (max > 0) {
125  std::vector<OperationPathNode<Base> > path2node;
126  path2node.reserve(30);
127  path2node.push_back(OperationPathNode<Base> (&root, 0));
128 
129  if (&root == &code) {
130  found.push_back(path2node);
131  } else {
132  findPaths(path2node, code, found, max);
133  }
134  }
135 
136  return found;
137 }
138 
139 template<class Base>
140 inline void CodeHandler<Base>::findPaths(SourceCodePath& currPath,
141  OperationNode<Base>& code,
142  std::vector<SourceCodePath>& found,
143  size_t max) {
144 
145  OperationNode<Base>* currNode = currPath.back().node;
146  if (&code == currNode) {
147  found.push_back(currPath);
148  return;
149  }
150 
151  const std::vector<Argument<Base> >& args = currNode->getArguments();
152  if (args.empty())
153  return; // nothing to look in
154 
155  if (isVisited(*currNode)) {
156  // already searched inside this node
157  // any match would have been saved in found
158  std::vector<SourceCodePath> pathsFromNode = findPathsFromNode(found, *currNode);
159  for (const SourceCodePath& pathFromNode : pathsFromNode) {
160  SourceCodePath newPath(currPath.size() + pathFromNode.size());
161  std::copy(currPath.begin(), currPath.end(), newPath.begin());
162  std::copy(pathFromNode.begin(), pathFromNode.end(), newPath.begin() + currPath.size());
163  found.push_back(newPath);
164  }
165 
166  } else {
167  // not visited yet
168  markVisited(*currNode); // mark node as visited
169 
170  size_t size = args.size();
171  for (size_t i = 0; i < size; ++i) {
172  OperationNode<Base>* a = args[i].getOperation();
173  if (a != nullptr) {
174  currPath.push_back(OperationPathNode<Base> (a, i));
175  findPaths(currPath, code, found, max);
176  currPath.pop_back();
177  if (found.size() == max) {
178  return;
179  }
180  }
181  }
182  }
183 }
184 
185 template<class Base>
186 inline std::vector<std::vector<OperationPathNode<Base> > > CodeHandler<Base>::findPathsFromNode(const std::vector<SourceCodePath> nodePaths,
187  OperationNode<Base>& node) {
188 
189  std::vector<SourceCodePath> foundPaths;
190  std::set<size_t> argsFound;
191 
192  for (const SourceCodePath& path : nodePaths) {
193  size_t size = path.size();
194  for (size_t i = 0; i < size - 1; i++) {
195  const OperationPathNode<Base>& pnode = path[i];
196  if (pnode.node == &node) {
197  if (argsFound.find(path[i + 1].argIndex) == argsFound.end()) {
198  foundPaths.push_back(SourceCodePath(path.begin() + i + 1, path.end()));
199  argsFound.insert(path[i + 1].argIndex);
200  }
201  }
202  }
203  }
204 
205  return foundPaths;
206 }
207 
208 } // END cg namespace
209 } // END CppAD namespace
210 
211 #endif
std::vector< SourceCodePath > findPaths(Node &root, Node &target, size_t max)
const std::vector< Argument< Base > > & getArguments() const
OperationNode< Base > * node