CppADCodeGen  HEAD
A C++ Algorithmic Differentiation Package with Source Code Generation
solver.hpp
1 #ifndef CPPAD_CG_SOLVER_INCLUDED
2 #define CPPAD_CG_SOLVER_INCLUDED
3 /* --------------------------------------------------------------------------
4  * CppADCodeGen: C++ Algorithmic Differentiation with Source Code Generation:
5  * Copyright (C) 2012 Ciengis
6  *
7  * CppADCodeGen is distributed under multiple licenses:
8  *
9  * - Eclipse Public License Version 1.0 (EPL1), and
10  * - GNU General Public License Version 3 (GPL3).
11  *
12  * EPL1 terms and conditions can be found in the file "epl-v10.txt", while
13  * terms and conditions for the GPL3 can be found in the file "gpl3.txt".
14  * ----------------------------------------------------------------------------
15  * Author: Joao Leal
16  */
17 
18 #include <cppad/cg/evaluator/evaluator_solve.hpp>
19 #include <cppad/cg/lang/dot/dot.hpp>
20 
21 namespace CppAD {
22 namespace cg {
23 
24 template<class Base>
26  OperationNode<Base>& var) {
27  using std::vector;
28 
29  // find code in expression
30  if (&expression == &var)
31  return CG<Base>(var);
32 
33  size_t bifurcations = (std::numeric_limits<size_t>::max)(); // so that it is possible to enter the loop
34 
35  std::vector<SourceCodePath> paths;
36  BidirGraph<Base> foundGraph;
37  OperationNode<Base> *root = &expression;
38 
39  while (bifurcations > 0) {
40  CPPADCG_ASSERT_UNKNOWN(root != nullptr);
41 
42  // find possible paths from expression to var
43  size_t oldBif = bifurcations;
44  bifurcations = 0;
45  foundGraph = findPathGraph(*root, var, bifurcations, 50000);
46  CPPADCG_ASSERT_UNKNOWN(oldBif > bifurcations);
47 
48  if (!foundGraph.contains(var)) {
49  std::cerr << "Missing variable " << var << std::endl;
50  printExpression(expression, std::cerr);
51  throw CGException("The provided variable ", var.getName() != nullptr ? ("(" + *var.getName() + ")") : "", " is not present in the expression");
52  }
53 
54  // find a bifurcation which does not contain any other bifurcations
55  size_t bifPos = 0;
56  paths = foundGraph.findSingleBifurcation(*root, var, bifPos);
57  if (paths.empty()) {
58  throw CGException("The provided variable is not present in the expression");
59 
60  } else if (paths.size() == 1) {
61  CPPADCG_ASSERT_UNKNOWN(paths[0][0].node == root);
62  CPPADCG_ASSERT_UNKNOWN(paths[0].back().node == &var);
63 
64  return solveFor(paths[0]);
65 
66  } else {
67  CPPADCG_ASSERT_UNKNOWN(paths.size() >= 1);
68  CPPADCG_ASSERT_UNKNOWN(paths[0].back().node == &var);
69 
70  CG<Base> expression2 = collectVariable(*root, paths[0], paths[1], bifPos);
71  root = expression2.getOperationNode();
72  if (root == nullptr) {
73  throw CGException("It is not possible to solve the expression for the requested variable: the variable disappears after symbolic manipulations (e.g., y=x-x).");
74  }
75  }
76  }
77 
78  CPPADCG_ASSERT_UNKNOWN(paths.size() == 1);
79  return solveFor(paths[0]);
80 }
81 
82 template<class Base>
83 inline CG<Base> CodeHandler<Base>::solveFor(const SourceCodePath& path) {
84 
85  CG<Base> rightHs(0.0);
86 
87  for (size_t n = 0; n < path.size() - 1; ++n) {
88  const OperationPathNode<Base>& pnodeOp = path[n];
89  size_t argIndex = path[n].argIndex;
90  const std::vector<Argument<Base> >& args = pnodeOp.node->getArguments();
91 
92  CGOpCode op = pnodeOp.node->getOperationType();
93  switch (op) {
94  case CGOpCode::Mul:
95  {
96  const Argument<Base>& other = args[argIndex == 0 ? 1 : 0];
97  rightHs /= CG<Base>(other);
98  break;
99  }
100  case CGOpCode::Div:
101  if (argIndex == 0) {
102  const Argument<Base>& other = args[1];
103  rightHs *= CG<Base>(other);
104  } else {
105  const Argument<Base>& other = args[0];
106  rightHs = CG<Base>(other) / rightHs;
107  }
108  break;
109 
110  case CGOpCode::UnMinus:
111  rightHs *= Base(-1.0);
112  break;
113  case CGOpCode::Add:
114  {
115  const Argument<Base>& other = args[argIndex == 0 ? 1 : 0];
116  rightHs -= CG<Base>(other);
117  break;
118  }
119  case CGOpCode::Alias:
120  // do nothing
121  break;
122  case CGOpCode::Sub:
123  {
124  if (argIndex == 0) {
125  rightHs += CG<Base>(args[1]);
126  } else {
127  rightHs = CG<Base>(args[0]) - rightHs;
128  }
129  break;
130  }
131  case CGOpCode::Exp:
132  rightHs = log(rightHs);
133  break;
134  case CGOpCode::Log:
135  rightHs = exp(rightHs);
136  break;
137  case CGOpCode::Pow:
138  {
139  if (argIndex == 0) {
140  // base
141  const Argument<Base>& exponent = args[1];
142  if (exponent.getParameter() != nullptr && *exponent.getParameter() == Base(0.0)) {
143  throw CGException("Invalid zero exponent");
144  } else if (exponent.getParameter() != nullptr && *exponent.getParameter() == Base(1.0)) {
145  continue; // do nothing
146  } else {
147  throw CGException("Unable to invert operation '", op, "'");
148  /*
149  if (exponent.getParameter() != nullptr && *exponent.getParameter() == Base(2.0)) {
150  rightHs = sqrt(rightHs); // TODO: should -sqrt(rightHs) somehow be considered???
151  } else {
152  rightHs = pow(rightHs, Base(1.0) / CG<Base>(exponent));
153  }
154  */
155  }
156  } else {
157  //
158  const Argument<Base>& base = args[0];
159  rightHs = log(rightHs) / log(CG<Base>(base));
160  }
161  break;
162  }
163  case CGOpCode::Sqrt:
164  rightHs *= rightHs;
165  break;
166  //case CGAcosOp: // asin(variable)
167  //case CGAsinOp: // asin(variable)
168  //case Atan: // atan(variable)
169  case CGOpCode::Cosh: // cosh(variable)
170  {
171  rightHs = log(rightHs + sqrt(rightHs * rightHs - Base(1.0))); // asinh
172  break;
173  //case Cos: // cos(variable)
174  }
175  case CGOpCode::Sinh: // sinh(variable)
176  rightHs = log(rightHs + sqrt(rightHs * rightHs + Base(1.0))); // asinh
177  break;
178  //case CGSinOp: // sin(variable)
179  case CGOpCode::Tanh: // tanh(variable)
180  rightHs = Base(0.5) * (log(Base(1.0) + rightHs) - log(Base(1.0) - rightHs)); // atanh
181  break;
182  //case CGTanOp: // tan(variable)
183  default:
184  throw CGException("Unable to invert operation '", op, "'");
185  };
186  }
187 
188  return rightHs;
189 }
190 
191 template<class Base>
192 inline bool CodeHandler<Base>::isSolvable(OperationNode<Base>& expression,
193  OperationNode<Base>& var) {
194  size_t bifurcations = 0;
195  BidirGraph<Base> g = findPathGraph(expression, var, bifurcations);
196 
197  if(bifurcations == 0) {
198  size_t bifIndex = 0;
199  auto paths = g.findSingleBifurcation(expression, var, bifIndex);
200  if (paths.empty() || paths[0].empty())
201  return false;
202 
203  return isSolvable(paths[0]);
204  } else {
205  // TODO: improve this
206  //bool v = isCollectableVariableAddSub();
207  try {
208  solveFor(expression, var);
209  return true;
210  } catch(const CGException& e) {
211  return false;
212  }
213  }
214 }
215 
216 template<class Base>
217 inline bool CodeHandler<Base>::isSolvable(const SourceCodePath& path) const {
218  for (size_t n = 0; n < path.size() - 1; ++n) {
219  const OperationPathNode<Base>& pnodeOp = path[n];
220  size_t argIndex = path[n].argIndex;
221  const std::vector<Argument<Base> >& args = pnodeOp.node->getArguments();
222 
223  CGOpCode op = pnodeOp.node->getOperationType();
224  switch (op) {
225  case CGOpCode::Mul:
226  case CGOpCode::Div:
227  case CGOpCode::UnMinus:
228  case CGOpCode::Add:
229  case CGOpCode::Alias:
230  case CGOpCode::Sub:
231  case CGOpCode::Exp:
232  case CGOpCode::Log:
233  case CGOpCode::Sqrt:
234  case CGOpCode::Cosh: // cosh(variable)
235  case CGOpCode::Sinh: // sinh(variable)
236  case CGOpCode::Tanh: // tanh(variable)
237  break;
238  case CGOpCode::Pow:
239  {
240  if (argIndex == 0) {
241  // base
242  const Argument<Base>& exponent = args[1];
243  if (exponent.getParameter() != nullptr && *exponent.getParameter() == Base(0.0)) {
244  return false;
245  } else if (exponent.getParameter() != nullptr && *exponent.getParameter() == Base(1.0)) {
246  break;
247  } else {
248  return false;
249  }
250  } else {
251  break;
252  }
253  break;
254  }
255 
256  default:
257  return false;
258  };
259  }
260  return true;
261 }
262 
263 } // END cg namespace
264 } // END CppAD namespace
265 
266 #endif
std::vector< SourceCodePath > findSingleBifurcation(Node &expression, Node &target, size_t &bifIndex) const
Definition: bidir_graph.hpp:99
const std::string * getName() const
CGB solveFor(Node &expression, Node &var)
Definition: solver.hpp:25
OperationNode< Base > * node