CppADCodeGen  HEAD
A C++ Algorithmic Differentiation Package with Source Code Generation
bidir_graph.hpp
1 #ifndef CPPAD_CG_BIDIR_GRAPH_INCLUDED
2 #define CPPAD_CG_BIDIR_GRAPH_INCLUDED
3 /* --------------------------------------------------------------------------
4  * CppADCodeGen: C++ Algorithmic Differentiation with Source Code Generation:
5  * Copyright (C) 2016 Ciengis
6  *
7  * CppADCodeGen is distributed under multiple licenses:
8  *
9  * - Eclipse Public License Version 1.0 (EPL1), and
10  * - GNU General Public License Version 3 (GPL3).
11  *
12  * EPL1 terms and conditions can be found in the file "epl-v10.txt", while
13  * terms and conditions for the GPL3 can be found in the file "gpl3.txt".
14  * ----------------------------------------------------------------------------
15  * Author: Joao Leal
16  */
17 
18 namespace CppAD {
19 namespace cg {
20 
21 template<class Base>
23 public:
24  using Node = OperationNode<Base>;
26 public:
27  std::vector<size_t> arguments;
28  std::vector<Path> usage; // parent node and argument index in that node
29 };
30 
34 template<class Base>
35 class BidirGraph {
36 public:
37  using Node = OperationNode<Base>;
38  using SourceCodePath = typename CodeHandler<Base>::SourceCodePath;
39 private:
40  std::map<Node*, PathNodeEdges<Base> > graph_;
41 public:
42  inline virtual ~BidirGraph() { }
43 
44  inline bool empty() const {
45  return graph_.empty();
46  }
47 
48  inline void connect(Node& node,
49  size_t argument) {
50  connect(graph_[&node], node, argument);
51  }
52 
53  inline void connect(PathNodeEdges<Base>& nodeInfo,
54  Node& node,
55  size_t argument) {
56  CPPADCG_ASSERT_UNKNOWN(argument < node.getArguments().size());
57  CPPADCG_ASSERT_UNKNOWN(node.getArguments()[argument].getOperation() != nullptr);
58  CPPADCG_ASSERT_UNKNOWN(&graph_[&node] == &nodeInfo);
59 
60  nodeInfo.arguments.push_back(argument);
61 
62  auto* aNode = node.getArguments()[argument].getOperation();
63  graph_[aNode].usage.push_back(OperationPathNode<Base>(&node, argument));
64  }
65 
66  inline bool contains(Node& node) const {
67  auto it = graph_.find(&node);
68  return it != graph_.end();
69  }
70 
71  inline PathNodeEdges<Base>* find(Node& node) {
72  auto it = graph_.find(&node);
73  if (it != graph_.end())
74  return &it->second;
75  else
76  return nullptr;
77  }
78 
79  inline const PathNodeEdges<Base>* find(Node& node) const {
80  auto it = graph_.find(&node);
81  if (it != graph_.end())
82  return &it->second;
83  else
84  return nullptr;
85  }
86 
87  inline bool erase(Node& node) {
88  return graph_.erase(&node) > 0;
89  }
90 
91  inline PathNodeEdges<Base>& operator[](Node& node) {
92  return graph_[&node];
93  }
94 
99  inline std::vector<SourceCodePath> findSingleBifurcation(Node& expression,
100  Node& target,
101  size_t& bifIndex) const {
102 
103  std::vector<SourceCodePath> paths;
104  bifIndex = -1;
105 
106  if (empty()) {
107  return paths;
108  }
109 
110  const PathNodeEdges<Base>* tail = find(target);
111  if (tail == nullptr)
112  return paths;
113 
114  paths.reserve(2);
115  paths.resize(1);
116  paths[0].reserve(20); // path down
117 
118  if (tail->usage.empty()) {
119  // only one path with one element
120  paths[0].push_back(OperationPathNode<Base>(&target, -1));
121  return paths;
122  }
123 
124  paths = findPathUpTo(expression, target);
125  if (paths.size() > 1)
126  bifIndex = 0;
127 
128  if (paths[0][0].node != &expression) {
133  SourceCodePath pathCommon;
134 
135  auto* n = paths[0][0].node;
136  auto* edges = find(*n);
137  CPPADCG_ASSERT_UNKNOWN(edges != nullptr); // must exist
138 
139  while (true) {
140  n = edges->usage.begin()->node; // ignore other usages for now!!!!
141 
142  pathCommon.push_back(*edges->usage.begin());
143  if (n == &expression)
144  break;
145 
146  edges = find(*n);
147  CPPADCG_ASSERT_UNKNOWN(edges != nullptr);
148  CPPADCG_ASSERT_UNKNOWN(!edges->usage.empty());
149  }
150 
151  bifIndex = pathCommon.size();
152 
153  std::reverse(pathCommon.begin(), pathCommon.end());
154  for (auto& p: paths)
155  p.insert(p.begin(), pathCommon.begin(), pathCommon.end());
156  }
157 
158  return paths;
159  }
160 
161 private:
162 
166  std::vector<SourceCodePath> findPathUpTo(Node& node,
167  Node& target) const {
168  auto* n = &node;
169 
170  auto* edges = find(*n);
171  CPPADCG_ASSERT_UNKNOWN(edges != nullptr); // must exist
172 
173  std::vector<SourceCodePath> paths;
174  paths.reserve(2);
175  paths.resize(1);
176 
177  while (!edges->arguments.empty()) {
178  if (edges->arguments.size() > 1) {
179  // found bifurcation: must restart!
180  size_t a1Index = edges->arguments[0];
181  const auto& a1 = n->getArguments()[a1Index];
182  paths = findPathUpTo(*a1.getOperation(), target);
183  if (paths.size() == 2) {
184  return paths;
185  }
186 
187  size_t a2Index = edges->arguments[1];
188  const auto& a2 = n->getArguments()[a2Index];
189  auto paths2 = findPathUpTo(*a2.getOperation(), target);
190  if (paths2.size() == 2) {
191  return paths2;
192  }
193 
194  paths[0].insert(paths[0].begin(), OperationPathNode<Base>(n, a1Index));
195 
196  paths.resize(2);
197  paths[1].reserve(paths2[0].size() + 1);
198  paths[1].insert(paths[1].begin(), OperationPathNode<Base>(n, a2Index));
199  paths[1].insert(paths[1].begin() + 1, paths2[0].begin(), paths2[0].end());
200  return paths;
201  }
202 
203  size_t argIndex1 = *edges->arguments.begin(); // only one argument
204  paths[0].push_back(OperationPathNode<Base>(n, argIndex1));
205 
206  n = n->getArguments()[argIndex1].getOperation();
207  edges = find(*n);
208  CPPADCG_ASSERT_UNKNOWN(edges != nullptr); // must exist
209  }
210 
211  paths[0].push_back(OperationPathNode<Base>(n, -1));
212 
213  return paths;
214  }
215 
216 #if 0
217  void findPathDownThenUp() {
218  for (const auto& arg0: tail->usage) {
219  paths.resize(1);
220  paths[0].clear();
221  paths[0].push_back(OperationPathNode<Base>(&target, -1));
222 
223  Node* n = arg0.node;
224  size_t argIndex = arg0.argIndex;
225 
226  const PathNodeEdges<Base>* edges = find(*n);
227  CPPADCG_ASSERT_UNKNOWN(edges != nullptr);
228 
229  while (true) {
230  paths[0].push_back(OperationPathNode<Base>(n, argIndex));
231 
232  if(edges->arguments.size() != 1)
233  break; // a bifurcation
234 
235  if(edges->usage.empty())
236  break;
237  n = edges->usage.begin()->node; // ignore other usages for now!!!!
238  argIndex = edges->usage.begin()->argIndex;
239 
240  edges = find(*n);
241  CPPADCG_ASSERT_UNKNOWN(edges != nullptr);
242  }
243 
244  CPPADCG_ASSERT_UNKNOWN(!edges->arguments.empty());
245 
246  //if(edges->arguments.size() > 2) {
247  // continue; // should not use this???
248  //}
249 
250  // flip paths[0] so that it starts at bifurcation
251  std::reverse(paths[0].begin(), paths[0].end());
252 
253  if (edges->arguments.size() == 1) {
254  // there is only one path (there are no bifurcations)
255  return paths;
256  }
257 
258  // there is another path up to target
259  paths.resize(2);
260  paths[1].reserve(20); // path up
261 
265  // use the other argument to go up
266  auto* n1 = paths[0][1].node;
267  size_t argIndex1 = n->getArguments()[edges->arguments[0]].getOperation() == n1? edges->arguments[1]: edges->arguments[0];
268  paths[1].push_back(OperationPathNode<Base>(n, argIndex1)); // start at the same location (but different argument index)
269 
270  n = n->getArguments()[argIndex1].getOperation();
271 
272  edges = find(*n);
273  CPPADCG_ASSERT_UNKNOWN(edges != nullptr); // must exist
274 
275  while (!edges->arguments.empty()) {
276  argIndex1 = *edges->arguments.begin(); // ignore other arguments for now!!!!
277  paths[1].push_back(OperationPathNode<Base>(n, argIndex1));
278 
279  n = n->getArguments()[argIndex1].getOperation();
280  edges = find(*n);
281  CPPADCG_ASSERT_UNKNOWN(edges != nullptr); // must exist
282  }
283 
284  paths[1].push_back(OperationPathNode<Base>(n, -1));
285 
286  bifIndex = 0;
287 
288  break;
289  }
290  }
291 #endif
292 
293 };
294 
295 } // END cg namespace
296 } // END CppAD namespace
297 
298 #endif
std::vector< SourceCodePath > findSingleBifurcation(Node &expression, Node &target, size_t &bifIndex) const
Definition: bidir_graph.hpp:99
const std::vector< Argument< Base > > & getArguments() const