TooN
conjugate_gradient.h
1 //Copyright (C) Edward Rosten 2009, 2010, 2012
2 
3 //All rights reserved.
4 //
5 //Redistribution and use in source and binary forms, with or without
6 //modification, are permitted provided that the following conditions
7 //are met:
8 //1. Redistributions of source code must retain the above copyright
9 // notice, this list of conditions and the following disclaimer.
10 //2. Redistributions in binary form must reproduce the above copyright
11 // notice, this list of conditions and the following disclaimer in the
12 // documentation and/or other materials provided with the distribution.
13 //
14 //THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND OTHER CONTRIBUTORS ``AS IS''
15 //AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16 //IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
17 //ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR OTHER CONTRIBUTORS BE
18 //LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
19 //CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
20 //SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
21 //INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
22 //CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
23 //ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
24 //POSSIBILITY OF SUCH DAMAGE.
25 
26 #include <TooN/optimization/brent.h>
27 #include <utility>
28 #include <cmath>
29 #include <cassert>
30 #include <cstdlib>
31 
32 namespace TooN{
33  namespace Internal{
34 
35 
42  template<int Size, typename Precision, typename Func> struct LineSearch
43  {
46 
47  const Func& f;
48 
53  LineSearch(const Vector<Size, Precision>& s, const Vector<Size, Precision>& d, const Func& func)
54  :start(s),direction(d),f(func)
55  {}
56 
59  Precision operator()(Precision x) const
60  {
61  return f(start + x * direction);
62  }
63  };
64 
76  template<typename Precision, typename Func> Matrix<3,2,Precision> bracket_minimum_forward(Precision a_val, const Func& func, Precision initial_lambda, Precision zeps)
77  {
78  //Get a, b, c to bracket a minimum along a line
79  Precision a, b, c, b_val, c_val;
80 
81  a=0;
82 
83  //Search forward in steps of lambda
84  Precision lambda=initial_lambda;
85  b = lambda;
86  b_val = func(b);
87 
88  while(std::isnan(b_val))
89  {
90  //We've probably gone in to an invalid region. This can happen even
91  //if following the gradient would never get us there.
92  //try backing off lambda
93  lambda*=.5;
94  b = lambda;
95  b_val = func(b);
96 
97  }
98 
99 
100  if(b_val < a_val) //We've gone downhill, so keep searching until we go back up
101  {
102  double last_good_lambda = lambda;
103 
104  for(;;)
105  {
106  lambda *= 2;
107  c = lambda;
108  c_val = func(c);
109 
110  if(std::isnan(c_val))
111  break;
112  last_good_lambda = lambda;
113  if(c_val > b_val) // we have a bracket
114  break;
115  else
116  {
117  a = b;
118  a_val = b_val;
119  b=c;
120  b_val=c_val;
121 
122  }
123  }
124 
125  //We took a step too far.
126  //Back up: this will not attempt to ensure a bracket
127  if(std::isnan(c_val))
128  {
129  double bad_lambda=lambda;
130  double l=1;
131 
132  for(;;)
133  {
134  l*=.5;
135  c = last_good_lambda + (bad_lambda - last_good_lambda)*l;
136  c_val = func(c);
137 
138  if(!std::isnan(c_val))
139  break;
140  }
141 
142 
143  }
144 
145  }
146  else //We've overshot the minimum, so back up
147  {
148  c = b;
149  c_val = b_val;
150  //Here, c_val > a_val
151 
152  for(;;)
153  {
154  lambda *= .5;
155  b = lambda;
156  b_val = func(b);
157 
158  if(b_val < a_val)// we have a bracket
159  break;
160  else if(lambda < zeps)
161  return Zeros;
162  else //Contract the bracket
163  {
164  c = b;
165  c_val = b_val;
166  }
167  }
168  }
169 
170  Matrix<3,2> ret;
171  ret[0] = makeVector(a, a_val);
172  ret[1] = makeVector(b, b_val);
173  ret[2] = makeVector(c, c_val);
174 
175  return ret;
176  }
177 
178 }
179 
180 
225 template<int Size=Dynamic, class Precision=double> struct ConjugateGradient
226 {
227  const int size;
235  Precision y;
236  Precision old_y;
237 
238  Precision tolerance;
239  Precision epsilon;
241 
244  Precision linesearch_epsilon;
246 
247  Precision bracket_epsilon;
248 
250 
255  template<class Func, class Deriv> ConjugateGradient(const Vector<Size>& start, const Func& func, const Deriv& deriv)
256  : size(start.size()),
257  g(size),h(size),minus_h(size),old_g(size),old_h(size),x(start),old_x(size)
258  {
259  init(start, func(start), deriv(start));
260  }
261 
266  template<class Func> ConjugateGradient(const Vector<Size>& start, const Func& func, const Vector<Size>& deriv)
267  : size(start.size()),
268  g(size),h(size),minus_h(size),old_g(size),old_h(size),x(start),old_x(size)
269  {
270  init(start, func(start), deriv);
271  }
272 
277  void init(const Vector<Size>& start, const Precision& func, const Vector<Size>& deriv)
278  {
279 
280  using std::numeric_limits;
281  using std::sqrt;
282  x = start;
283 
284  //Start with the conjugate direction aligned with
285  //the gradient
286  g = deriv;
287  h = g;
288  minus_h=-h;
289 
290  y = func;
291  old_y = y;
292 
293  tolerance = sqrt(numeric_limits<Precision>::epsilon());
294  epsilon = 1e-20;
295  max_iterations = size * 100;
296 
297  bracket_initial_lambda = 1;
298 
299  linesearch_tolerance = sqrt(numeric_limits<Precision>::epsilon());
300  linesearch_epsilon = 1e-20;
301  linesearch_max_iterations=100;
302 
303  bracket_epsilon=1e-20;
304 
305  iterations=0;
306  }
307 
308 
322  template<class Func> void find_next_point(const Func& func)
323  {
324  Internal::LineSearch<Size, Precision, Func> line(x, minus_h, func);
325 
326  //Always search in the conjugate direction (h)
327  //First bracket a minimum.
328  Matrix<3,2,Precision> bracket = Internal::bracket_minimum_forward(y, line, bracket_initial_lambda, bracket_epsilon);
329 
330  double a = bracket[0][0];
331  double b = bracket[1][0];
332  double c = bracket[2][0];
333 
334  double a_val = bracket[0][1];
335  double b_val = bracket[1][1];
336  double c_val = bracket[2][1];
337 
338  old_y = y;
339  old_x = x;
340  iterations++;
341 
342  //Local maximum achieved!
343  if(a==0 && b== 0 && c == 0)
344  return;
345 
346  //We should have a bracket here
347 
348  if(c < b)
349  {
350  //Failed to bracket due to NaN, so c is the best known point.
351  //Simply go there.
352  x-=h * c;
353  y=c_val;
354 
355  }
356  else
357  {
358  assert(a < b && b < c);
359  assert(a_val > b_val && b_val < c_val);
360 
361  //Find the real minimum
362  Vector<2, Precision> m = brent_line_search(a, b, c, b_val, line, linesearch_max_iterations, linesearch_tolerance, linesearch_epsilon);
363 
364  assert(m[0] >= a && m[0] <= c);
365  assert(m[1] <= b_val);
366 
367  //Update the current position and value
368  x -= m[0] * h;
369  y = m[1];
370  }
371  }
372 
375  bool finished()
376  {
377  using std::abs;
378  return iterations > max_iterations || 2*abs(y - old_y) <= tolerance * (abs(y) + abs(old_y) + epsilon);
379  }
380 
389  void update_vectors_PR(const Vector<Size>& grad)
390  {
391  //Update the position, gradient and conjugate directions
392  old_g = g;
393  old_h = h;
394 
395  g = grad;
396  //Precision gamma = (g * g - oldg*g)/(oldg * oldg);
397  Precision gamma = (g * g - old_g*g)/(old_g * old_g);
398  h = g + gamma * old_h;
399  minus_h=-h;
400  }
401 
419  template<class Func, class Deriv> bool iterate(const Func& func, const Deriv& deriv)
420  {
421  find_next_point(func);
422 
423  if(!finished())
424  {
425  update_vectors_PR(deriv(x));
426  return 1;
427  }
428  else
429  return 0;
430  }
431 };
432 
433 }
Precision bracket_epsilon
Minimum size for initial minima bracketing. Below this, it is assumed that the system has converged...
Definition: conjugate_gradient.h:247
Precision old_y
Function at old_x.
Definition: conjugate_gradient.h:236
This class provides a nonlinear conjugate-gradient optimizer.
Definition: conjugate_gradient.h:225
Precision linesearch_epsilon
Additive term in tolerance to prevent excessive iterations if $x_{\mathrm{optimal}} = 0$. Known as ZEPS in Numerical Recipes...
Definition: conjugate_gradient.h:244
Pretty generic SFINAE introspection generator.
Definition: vec_test.cc:21
int max_iterations
Maximum number of iterations. Defaults to $\mathrm{size} \times 100$.
Definition: conjugate_gradient.h:240
Precision bracket_initial_lambda
Initial stepsize used in bracketing the minimum for the line search. Defaults to 1.
Definition: conjugate_gradient.h:242
Vector< Size > h
Conjugate vector to be searched along in the next call to iterate()
Definition: conjugate_gradient.h:229
A matrix.
Definition: matrix.hh:105
void init(const Vector< Size > &start, const Precision &func, const Vector< Size > &deriv)
Initialize the ConjugateGradient class with sensible values.
Definition: conjugate_gradient.h:277
int iterations
Number of iterations performed.
Definition: conjugate_gradient.h:249
const Vector< Size, Precision > & start
Definition: conjugate_gradient.h:44
Vector< Size > old_g
Gradient vector used to compute $h$ in the last call to iterate()
Definition: conjugate_gradient.h:231
bool isnan(const Vector< S, P, B > &v)
Returns true if any element is NaN.
Definition: helpers.h:396
bool finished()
Check to see if iteration should stop.
Definition: conjugate_gradient.h:375
Vector< Size > minus_h
Negative of h, as this is required to be passed into a function which uses references (so can't be tem...
Definition: conjugate_gradient.h:230
ConjugateGradient(const Vector< Size > &start, const Func &func, const Deriv &deriv)
Initialize the ConjugateGradient class with sensible values.
Definition: conjugate_gradient.h:255
Precision linesearch_tolerance
Tolerance used to determine if the linesearch is complete. Defaults to square root of machine precisi...
Definition: conjugate_gradient.h:243
Precision y
Function at $x$.
Definition: conjugate_gradient.h:235
Precision epsilon
Additive term in tolerance to prevent excessive iterations if the optimal function value is 0. Known as ZEPS in Numerical Recipes...
Definition: conjugate_gradient.h:239
void update_vectors_PR(const Vector< Size > &grad)
After an iteration, update the gradient and conjugate using the Polak-Ribiere equations.
Definition: conjugate_gradient.h:389
Vector< 2, Precision > brent_line_search(Precision a, Precision x, Precision b, Precision fx, const Functor &func, int maxiterations, Precision tolerance=sqrt(numeric_limits< Precision >::epsilon()), Precision epsilon=numeric_limits< Precision >::epsilon())
brent_line_search performs Brent's golden section/quadratic interpolation search on the functor provi...
Definition: brent.h:55
const Vector< Size, Precision > & direction
Definition: conjugate_gradient.h:45
Vector< Size > g
Gradient vector used by the next call to iterate()
Definition: conjugate_gradient.h:228
ConjugateGradient(const Vector< Size > &start, const Func &func, const Vector< Size > &deriv)
Initialize the ConjugateGradient class with sensible values.
Definition: conjugate_gradient.h:266
Vector< Size > old_h
Conjugate vector searched along in the last call to iterate()
Definition: conjugate_gradient.h:232
bool iterate(const Func &func, const Deriv &deriv)
Use this function to iterate over the optimization.
Definition: conjugate_gradient.h:419
Vector< Size > old_x
Previous best known point (not set at construction)
Definition: conjugate_gradient.h:234
Precision operator()(Precision x) const
Definition: conjugate_gradient.h:59
Matrix< R, C, P > sqrt(const Matrix< R, C, P, B > &m)
computes a matrix square root of a matrix m by the product form of the Denman and Beavers iteration a...
Definition: helpers.h:350
LineSearch(const Vector< Size, Precision > &s, const Vector< Size, Precision > &d, const Func &func)
Set up the line search class.
Definition: conjugate_gradient.h:53
Vector< Size > x
Current position (best known point)
Definition: conjugate_gradient.h:233
Turn a multidimensional function in to a 1D function by specifying a point and direction.
Definition: conjugate_gradient.h:42
const int size
Dimensionality of the space.
Definition: conjugate_gradient.h:227
void find_next_point(const Func &func)
Perform a linesearch from the current point (x) along the current conjugate vector (h)...
Definition: conjugate_gradient.h:322
Matrix< 3, 2, Precision > bracket_minimum_forward(Precision a_val, const Func &func, Precision initial_lambda, Precision zeps)
Bracket a 1D function by searching forward from zero.
Definition: conjugate_gradient.h:76
const Func & f
Definition: conjugate_gradient.h:47
int linesearch_max_iterations
Maximum number of iterations in the linesearch. Defaults to 100.
Definition: conjugate_gradient.h:245
Precision tolerance
Tolerance used to determine if the optimization is complete. Defaults to square root of machine preci...
Definition: conjugate_gradient.h:238