TooN
downhill_simplex.h
1 //Copyright (C) Edward Rosten 2009
2 
3 //All rights reserved.
4 //
5 //Redistribution and use in source and binary forms, with or without
6 //modification, are permitted provided that the following conditions
7 //are met:
8 //1. Redistributions of source code must retain the above copyright
9 // notice, this list of conditions and the following disclaimer.
10 //2. Redistributions in binary form must reproduce the above copyright
11 // notice, this list of conditions and the following disclaimer in the
12 // documentation and/or other materials provided with the distribution.
13 //
14 //THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND OTHER CONTRIBUTORS ``AS IS''
15 //AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16 //IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
17 //ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR OTHER CONTRIBUTORS BE
18 //LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
19 //CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
20 //SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
21 //INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
22 //CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
23 //ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
24 //POSSIBILITY OF SUCH DAMAGE.
25 
26 #ifndef TOON_DOWNHILL_SIMPLEX_H
27 #define TOON_DOWNHILL_SIMPLEX_H
#include <TooN/TooN.h>
#include <TooN/helpers.h>
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <limits>
32 
33 namespace TooN
34 {
35 
102 template<int N=-1, typename Precision=double> class DownhillSimplex
103 {
104  static const int Vertices = (N==-1?-1:N+1);
107 
108  public:
117  template<class Function> DownhillSimplex(const Function& func, const Vector<N>& c, Precision spread=1)
118  :simplex(c.size()+1, c.size()),values(c.size()+1)
119  {
120  alpha = 1.0;
121  rho = 2.0;
122  gamma = 0.5;
123  sigma = 0.5;
124 
125  using std::sqrt;
126  epsilon = sqrt(numeric_limits<Precision>::epsilon());
127  zero_epsilon = 1e-20;
128 
129  restart(func, c, spread);
130  }
131 
138  template<class Function> void restart(const Function& func, const Vector<N>& c, Precision spread)
139  {
140  for(int i=0; i < simplex.num_rows(); i++)
141  simplex[i] = c;
142 
143  for(int i=0; i < simplex.num_cols(); i++)
144  simplex[i][i] += spread;
145 
146  for(int i=0; i < values.size(); i++)
147  values[i] = func(simplex[i]);
148  }
149 
155  bool finished()
156  {
157  Precision span = norm(simplex[get_best()] - simplex[get_worst()]);
158  Precision scale = norm(simplex[get_best()]);
159 
160  if(span/scale < epsilon || span < zero_epsilon)
161  return 1;
162  else
163  return 0;
164  }
165 
170  template<class Function> void restart(const Function& func, Precision spread)
171  {
172  restart(func, simplex[get_best()], spread);
173  }
174 
176  const Simplex& get_simplex() const
177  {
178  return simplex;
179  }
180 
182  const Values& get_values() const
183  {
184  return values;
185  }
186 
188  int get_best() const
189  {
190  return std::min_element(&values[0], &values[0] + values.size()) - &values[0];
191  }
192 
194  int get_worst() const
195  {
196  return std::max_element(&values[0], &values[0] + values.size()) - &values[0];
197  }
198 
201  template<class Function> void find_next_point(const Function& func)
202  {
203  //Find various things:
204  // - The worst point
205  // - The second worst point
206  // - The best point
207  // - The centroid of all the points but the worst
208  int worst = get_worst();
209  Precision second_worst_val=-HUGE_VAL, bestval = HUGE_VAL, worst_val = values[worst];
210  int best=0;
211  Vector<N> x0 = Zeros(simplex.num_cols());
212 
213 
214  for(int i=0; i < simplex.num_rows(); i++)
215  {
216  if(values[i] < bestval)
217  {
218  bestval = values[i];
219  best = i;
220  }
221 
222  if(i != worst)
223  {
224  if(values[i] > second_worst_val)
225  second_worst_val = values[i];
226 
227  //Compute the centroid of the non-worst points;
228  x0 += simplex[i];
229  }
230  }
231  x0 *= 1.0 / simplex.num_cols();
232 
233 
234  //Reflect the worst point about the centroid.
235  Vector<N> xr = (1 + alpha) * x0 - alpha * simplex[worst];
236  Precision fr = func(xr);
237 
238  if(fr < bestval)
239  {
240  //If the new point is better than the smallest, then try expanding the simplex.
241  Vector<N> xe = rho * xr + (1-rho) * x0;
242  Precision fe = func(xe);
243 
244  //Keep whichever is best
245  if(fe < fr)
246  {
247  simplex[worst] = xe;
248  values[worst] = fe;
249  }
250  else
251  {
252  simplex[worst] = xr;
253  values[worst] = fr;
254  }
255 
256  return;
257  }
258 
259  //Otherwise, if the new point lies between the other points
260  //then keep it and move on to the next iteration.
261  if(fr < second_worst_val)
262  {
263  simplex[worst] = xr;
264  values[worst] = fr;
265  return;
266  }
267 
268 
269  //Otherwise, if the new point is a bit better than the worst point,
270  //(ie, it's got just a little bit better) then contract the simplex
271  //a bit.
272  if(fr < worst_val)
273  {
274  Vector<N> xc = (1 + gamma) * x0 - gamma * simplex[worst];
275  Precision fc = func(xc);
276 
277  //If this helped, use it
278  if(fc <= fr)
279  {
280  simplex[worst] = xc;
281  values[worst] = fc;
282  return;
283  }
284  }
285 
286  //Otherwise, fr is worse than the worst point, or the fc was worse
287  //than fr. So shrink the whole simplex around the best point.
288  for(int i=0; i < simplex.num_rows(); i++)
289  if(i != best)
290  {
291  simplex[i] = simplex[best] + sigma * (simplex[i] - simplex[best]);
292  values[i] = func(simplex[i]);
293  }
294  }
295 
299  template<class Function> bool iterate(const Function& func)
300  {
301  find_next_point(func);
302  return !finished();
303  }
304 
305  Precision alpha;
306  Precision rho;
307  Precision gamma;
308  Precision sigma;
309  Precision epsilon;
310  Precision zero_epsilon;
311 
312  private:
313 
314  //Each row is a simplex vertex
315  Simplex simplex;
316 
317  //Function values for each vertex
318  Values values;
319 
320 
321 };
322 }
323 #endif
This is an implementation of the Downhill Simplex (Nelder & Mead, 1965) algorithm.
Definition: downhill_simplex.h:102
DownhillSimplex(const Function &func, const Vector< N > &c, Precision spread=1)
Initialize the DownhillSimplex class.
Definition: downhill_simplex.h:117
Precision norm(const Vector< Size, Precision, Base > &v)
Compute the norm of v.
Definition: helpers.h:97
void restart(const Function &func, Precision spread)
This function resets the simplex around the best current point.
Definition: downhill_simplex.h:170
bool iterate(const Function &func)
Perform one iteration of the downhill Simplex algorithm, and return the result of not DownhillSimplex::finished().
Definition: downhill_simplex.h:299
Precision zero_epsilon
Additive term in tolerance to prevent excessive iterations if $x_0 = 0$ (i.e. the best vertex is at the origin). Known as ZEPS in Numerical Recipes. Defaults to 1e-20.
Definition: downhill_simplex.h:310
Pretty generic SFINAE introspection generator.
Definition: vec_test.cc:21
int get_worst() const
Get the index of the worst vertex.
Definition: downhill_simplex.h:194
void restart(const Function &func, const Vector< N > &c, Precision spread)
This function sets up the simplex around c, with one point at c and the remaining points made by moving c by spread along each axis-aligned unit vector.
Definition: downhill_simplex.h:138
int get_best() const
Get the index of the best vertex.
Definition: downhill_simplex.h:188
Precision alpha
Reflected size. Defaults to 1.
Definition: downhill_simplex.h:305
Precision gamma
Contraction ratio. Defaults to .5.
Definition: downhill_simplex.h:307
Precision epsilon
Tolerance used to determine if the optimization is complete. Defaults to the square root of machine precision.
Definition: downhill_simplex.h:309
Precision sigma
Shrink ratio. Defaults to .5.
Definition: downhill_simplex.h:308
Matrix< R, C, P > sqrt(const Matrix< R, C, P, B > &m)
computes a matrix square root of a matrix m by the product form of the Denman and Beavers iteration a...
Definition: helpers.h:350
const Simplex & get_simplex() const
Return the simplex.
Definition: downhill_simplex.h:176
bool finished()
Check to see if iteration should stop.
Definition: downhill_simplex.h:155
Precision rho
Expansion ratio. Defaults to 2.
Definition: downhill_simplex.h:306
const Values & get_values() const
Return the score at the vertices.
Definition: downhill_simplex.h:182
void find_next_point(const Function &func)
Perform one iteration of the downhill Simplex algorithm.
Definition: downhill_simplex.h:201