SOT
shepard.h
1 
9 #ifndef SOT_SHEPARD_H
10 #define SOT_SHEPARD_H
11 
12 #include "common.h"
13 #include "utils.h"
14 #include "surrogate.h"
15 
17 namespace sot {
18 
20 
31  class Shepard : public Surrogate {
32  protected:
33  double mp;
34  double mDistTol = 1e-10;
35  int mMaxPoints;
36  int mNumPoints;
37  int mDim;
38  mat mX;
39  mat mfX;
40  public:
42 
47  Shepard(int maxPoints, int dim, double p) {
48  mNumPoints = 0;
49  mMaxPoints = maxPoints;
50  mp = p;
51  mDim = dim;
52  mX.resize(dim, maxPoints);
53  mfX.resize(maxPoints);
54  }
55  int dim() const {
56  return mDim;
57  }
58  int numPoints() const {
59  return mNumPoints;
60  }
61  vec X(int i) const {
62  return mX.col(i);
63  }
64  mat X() const {
65  return mX.cols(0, mNumPoints-1);
66  }
67  double fX(int i) const {
68  return mfX(i);
69  }
70  vec fX() const {
71  return mfX.rows(0, mNumPoints-1);
72  }
73 
75 
81  void addPoint(const vec &point, double funVal) {
82  if(mNumPoints >= mMaxPoints) {
83  throw std::logic_error("Capacity exceeded");
84  }
85  mX.col(mNumPoints) = point;
86  mfX(mNumPoints) = funVal;
87  mNumPoints++;
88  }
89 
91 
97  void addPoints(const mat &points, const vec &funVals) {
98  int n = points.n_cols;
99 
100  if(n < 2) {
101  throw std::logic_error("Use add_point instead");
102  }
103  if(mNumPoints + n > mMaxPoints) {
104  throw std::logic_error("Capacity exceeded");
105  }
106 
107  mX.cols(mNumPoints, mNumPoints + n - 1) = points;
108  mfX.rows(mNumPoints, mNumPoints + n - 1) = funVals;
109  mNumPoints += n;
110  }
111 
112  double eval(const vec &point) const {
113  vec dists = squaredPointSetDistance<mat,vec>(point, X());
114  if (arma::min(dists) < mDistTol) { // Just return the closest point
115  arma::uword closest;
116  double scores = dists.min(closest);
117  return mfX(closest);
118  }
119  else {
120  vec weights = arma::pow(dists, -mp/2.0);
121  return arma::dot(weights, fX())/arma::sum(weights);
122  }
123  }
124 
125  double eval(const vec &point, const vec &dists) const {
126  return eval(point);
127  }
128 
129  vec evals(const mat &points) const {
130  vec vals = arma::zeros<vec>(points.n_cols);
131  for(int i=0; i < points.n_cols; i++) {
132  vals(i) = eval(points.col(i));
133  }
134  return vals;
135  }
136 
137  vec evals(const mat &points, const mat &dists) const {
138  return evals(points);
139  }
140 
142 
145  vec deriv(const vec &point) const {
146  throw std::logic_error("No derivatives for Shepard");
147  }
148  void reset() { mNumPoints = 0; }
150  void fit() { return; }
151  };
152 }
153 
154 #endif
void addPoint(const vec &point, double funVal)
Method for adding a point with a known value.
Definition: shepard.h:81
int mMaxPoints
Definition: shepard.h:35
arma::vec vec
Default (column) vector class.
Definition: common.h:17
int numPoints() const
Method for getting the current number of points.
Definition: shepard.h:58
mat mfX
Definition: shepard.h:39
double mp
Definition: shepard.h:33
Abstract class for a SOT surrogate model.
Definition: surrogate.h:25
Shepard(int maxPoints, int dim, double p)
Constructor.
Definition: shepard.h:47
double eval(const vec &point, const vec &dists) const
Method for evaluating the surrogate at a point (the precomputed distances are ignored).
Definition: shepard.h:125
vec X(int i) const
Method for getting current point number i (0 is the first)
Definition: shepard.h:61
vec evals(const mat &points, const mat &dists) const
Method for evaluating the surrogate at multiple points.
Definition: shepard.h:137
Shepard's method
Definition: shepard.h:31
mat mX
Definition: shepard.h:38
int mNumPoints
Definition: shepard.h:36
vec fX() const
Method for getting the values of the current points.
Definition: shepard.h:70
double fX(int i) const
Method for getting the value of current point number i (0 is the first)
Definition: shepard.h:67
double eval(const vec &point) const
Method for evaluating the surrogate model at a point.
Definition: shepard.h:112
mat X() const
Method for getting the current points.
Definition: shepard.h:64
vec deriv(const vec &point) const
Method for evaluating the derivative of the Shepard surrogate at one point (not implemented; always throws)
Definition: shepard.h:145
int dim() const
Method for getting the number of dimensions.
Definition: shepard.h:55
void reset()
Method for resetting the surrogate model.
Definition: shepard.h:148
void fit()
Fits the interpolant (does nothing)
Definition: shepard.h:150
double mDistTol
Definition: shepard.h:34
void addPoints(const mat &points, const vec &funVals)
Method for adding multiple points with known values.
Definition: shepard.h:97
SOT namespace.
Definition: sot.h:27
int mDim
Definition: shepard.h:37
arma::mat mat
Default matrix class.
Definition: common.h:16
vec evals(const mat &points) const
Method for evaluating the surrogate at multiple points.
Definition: shepard.h:129