SOT
kNN.h
1 
9 #ifndef SOT_kNN_H
10 #define SOT_kNN_H
11 
12 #include "common.h"
13 #include "utils.h"
14 #include "surrogate.h"
15 
17 namespace sot {
18 
20 
28  class kNN : public Surrogate {
29  protected:
30  int mDim;
31  int mMaxPoints;
32  int mNumPoints;
33  int mk;
34  mat mX;
35  vec mfX;
36  public:
38 
43  kNN(int maxPoints, int dim, int k) {
44  mDim = dim;
45  mNumPoints = 0;
46  mMaxPoints = maxPoints;
47  mk = k;
48  mX.resize(dim, maxPoints);
49  mfX.resize(maxPoints);
50  }
51 
52  int dim() const {
53  return mDim;
54  }
55  int numPoints() const {
56  return mNumPoints;
57  }
58  mat X() const {
59  return mX.cols(0, mNumPoints-1);
60  }
61  vec X(int i) const {
62  return mX.col(i);
63  }
64  vec fX() const {
65  return mfX.rows(0, mNumPoints-1);
66  }
67  double fX(int i) const {
68  return mfX(i);
69  }
70 
72 
78  void addPoint(const vec &point, double funVal) {
79  if(mNumPoints >= mMaxPoints) {
80  throw std::logic_error("Capacity exceeded");
81  }
82  mX.col(mNumPoints) = point;
83  mfX(mNumPoints) = funVal;
84  mNumPoints++;
85  }
86 
88 
94  void addPoints(const mat &points, const vec &funVals) {
95  int n = points.n_cols;
96 
97  if(n < 2) {
98  throw std::logic_error("Use add_point instead");
99  }
100  if(mNumPoints + n > mMaxPoints) {
101  throw std::logic_error("Capacity exceeded");
102  }
103 
104  mX.cols(mNumPoints, mNumPoints + n - 1) = points;
105  mfX.rows(mNumPoints, mNumPoints + n - 1) = funVals;
106  mNumPoints += n;
107  }
108 
109  double eval(const vec &point) const {
110  vec dists = squaredPointSetDistance(point, X());
111  uvec indices = sort_index(dists);
112  return arma::mean(mfX(indices.rows(0, mk - 1)));
113  }
114 
115  double eval(const vec &point, const vec &dists) const {
116  return eval(point);
117  }
118  vec evals(const mat &points) const {
119  vec vals = arma::zeros<vec>(points.n_cols);
120  for(int i=0; i < points.n_cols; i++) {
121  vals(i) = eval(points.col(i));
122  }
123  return vals;
124  }
125  vec evals(const mat &points, const mat &dists) const {
126  return evals(points);
127  }
129 
132  vec deriv(const vec& point) const {
133  throw std::logic_error("No derivatives for kNN");
134  }
135  void reset() {
136  mNumPoints = 0;
137  }
139  void fit() {
140  return;
141  }
142  };
143 }
144 
145 #endif
double eval(const vec &point, const vec &dists) const
Method for evaluating the surrogate at a single point; the precomputed distances are ignored by kNN.
Definition: kNN.h:115
void addPoint(const vec &point, double funVal)
Method for adding a point with a known value.
Definition: kNN.h:78
int numPoints() const
Method for getting the current number of points.
Definition: kNN.h:55
arma::vec vec
Default (column) vector class.
Definition: common.h:17
vec fX() const
Method for getting the values of the current points.
Definition: kNN.h:64
arma::uvec uvec
Default unsigned (column) vector class.
Definition: common.h:22
void addPoints(const mat &points, const vec &funVals)
Method for adding multiple points with known values.
Definition: kNN.h:94
vec evals(const mat &points, const mat &dists) const
Method for evaluating the surrogate at multiple points.
Definition: kNN.h:125
vec mfX
Definition: kNN.h:35
int mMaxPoints
Definition: kNN.h:31
mat X() const
Method for getting the current points.
Definition: kNN.h:58
Abstract class for a SOT surrogate model.
Definition: surrogate.h:25
k-nearest neighbors (kNN) surrogate model.
Definition: kNN.h:28
double eval(const vec &point) const
Method for evaluating the surrogate model at a point.
Definition: kNN.h:109
int mDim
Definition: kNN.h:30
void fit()
Fits kNN (does nothing)
Definition: kNN.h:139
double fX(int i) const
Method for getting the value of current point number i (0 is the first)
Definition: kNN.h:67
int mk
Definition: kNN.h:33
kNN(int maxPoints, int dim, int k)
Constructor.
Definition: kNN.h:43
int dim() const
Method for getting the number of dimensions.
Definition: kNN.h:52
void reset()
Method for resetting the surrogate model.
Definition: kNN.h:135
VecType squaredPointSetDistance(const VecType &x, const MatType &Y)
Fast squared L2 (Euclidean) distance computation between one point and a set of points.
Definition: utils.h:32
vec evals(const mat &points) const
Method for evaluating the surrogate at multiple points.
Definition: kNN.h:118
vec X(int i) const
Method for getting current point number i (0 is the first)
Definition: kNN.h:61
mat mX
Definition: kNN.h:34
int mNumPoints
Definition: kNN.h:32
SOT namespace.
Definition: sot.h:27
arma::mat mat
Default matrix class.
Definition: common.h:16
vec deriv(const vec &point) const
Method for evaluating the kNN derivative at one point (not implemented; always throws std::logic_error)
Definition: kNN.h:132