mlpack
perceptron_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_PERCEPTRON_PERCEPTRON_IMPL_HPP
13 #define MLPACK_METHODS_PERCEPTRON_PERCEPTRON_IMPL_HPP
14 
15 #include "perceptron.hpp"
16 
17 namespace mlpack {
18 namespace perceptron {
19 
24 template<
25  typename LearnPolicy,
26  typename WeightInitializationPolicy,
27  typename MatType
28 >
30  const size_t numClasses,
31  const size_t dimensionality,
32  const size_t maxIterations) :
33  maxIterations(maxIterations)
34 {
35  WeightInitializationPolicy wip;
36  wip.Initialize(weights, biases, dimensionality, numClasses);
37 }
38 
49 template<
50  typename LearnPolicy,
51  typename WeightInitializationPolicy,
52  typename MatType
53 >
55  const MatType& data,
56  const arma::Row<size_t>& labels,
57  const size_t numClasses,
58  const size_t maxIterations) :
59  maxIterations(maxIterations)
60 {
61  // Start training.
62  Train(data, labels, numClasses);
63 }
64 
76 template<
77  typename LearnPolicy,
78  typename WeightInitializationPolicy,
79  typename MatType
80 >
82  const Perceptron& other,
83  const MatType& data,
84  const arma::Row<size_t>& labels,
85  const size_t numClasses,
86  const arma::rowvec& instanceWeights) :
87  maxIterations(other.maxIterations)
88 {
89  Train(data, labels, numClasses, instanceWeights);
90 }
91 
100 template<
101  typename LearnPolicy,
102  typename WeightInitializationPolicy,
103  typename MatType
104 >
106  const MatType& test,
107  arma::Row<size_t>& predictedLabels)
108 {
109  arma::vec tempLabelMat;
110  arma::uword maxIndex = 0;
111 
112  // Could probably be faster if done in batch.
113  for (size_t i = 0; i < test.n_cols; ++i)
114  {
115  tempLabelMat = weights.t() * test.col(i) + biases;
116  tempLabelMat.max(maxIndex);
117  predictedLabels(0, i) = maxIndex;
118  }
119 }
120 
130 template<
131  typename LearnPolicy,
132  typename WeightInitializationPolicy,
133  typename MatType
134 >
136  const MatType& data,
137  const arma::Row<size_t>& labels,
138  const size_t numClasses,
139  const arma::rowvec& instanceWeights)
140 {
141  // Do we need to resize the weights?
142  if (weights.n_elem != numClasses)
143  {
144  WeightInitializationPolicy wip;
145  wip.Initialize(weights, biases, data.n_rows, numClasses);
146  }
147 
148  size_t j, i = 0;
149  bool converged = false;
150  size_t tempLabel;
151  arma::uword maxIndexRow = 0, maxIndexCol = 0;
152  arma::mat tempLabelMat;
153 
154  LearnPolicy LP;
155 
156  const bool hasWeights = (instanceWeights.n_elem > 0);
157 
158  while ((i < maxIterations) && (!converged))
159  {
160  // This outer loop is for each iteration, and we use the 'converged'
161  // variable for noting whether or not convergence has been reached.
162  ++i;
163  converged = true;
164 
165  // Now this inner loop is for going through the dataset in each iteration.
166  for (j = 0; j < data.n_cols; ++j)
167  {
168  // Multiply for each variable and check whether the current weight vector
169  // correctly classifies this.
170  tempLabelMat = weights.t() * data.col(j) + biases;
171 
172  tempLabelMat.max(maxIndexRow, maxIndexCol);
173 
174  // Check whether prediction is correct.
175  if (maxIndexRow != labels(0, j))
176  {
177  // Due to incorrect prediction, convergence set to false.
178  converged = false;
179  tempLabel = labels(0, j);
180 
181  // Send maxIndexRow for knowing which weight to update, send j to know
182  // the value of the vector to update it with. Send tempLabel to know
183  // the correct class.
184  if (hasWeights)
185  LP.UpdateWeights(data.col(j), weights, biases, maxIndexRow, tempLabel,
186  instanceWeights(j));
187  else
188  LP.UpdateWeights(data.col(j), weights, biases, maxIndexRow,
189  tempLabel);
190  }
191  }
192  }
193 }
194 
196 template<typename LearnPolicy,
197  typename WeightInitializationPolicy,
198  typename MatType>
199 template<typename Archive>
201  Archive& ar,
202  const uint32_t /* version */)
203 {
204  // We just need to serialize the maximum number of iterations, the weights,
205  // and the biases.
206  ar(CEREAL_NVP(maxIterations));
207  ar(CEREAL_NVP(weights));
208  ar(CEREAL_NVP(biases));
209 }
210 
211 } // namespace perceptron
212 } // namespace mlpack
213 
214 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Perceptron(const size_t numClasses=0, const size_t dimensionality=0, const size_t maxIterations=1000)
Constructor: create the perceptron with the given number of classes and initialize the weight matrix, but do not perform any training.
Definition: perceptron_impl.hpp:29
void Train(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const arma::rowvec &instanceWeights=arma::rowvec())
Train the perceptron on the given data for up to the maximum number of iterations (specified in the constructor), stopping early once a full pass classifies every point correctly.
Definition: perceptron_impl.hpp:135
void Classify(const MatType &test, arma::Row< size_t > &predictedLabels)
Classification function: assign each column of test the class whose score (weights^T x + biases) is largest.
Definition: perceptron_impl.hpp:105
void serialize(Archive &ar, const uint32_t)
Serialize the perceptron.
Definition: perceptron_impl.hpp:200
This class implements a simple perceptron (i.e., a single layer neural network).
Definition: perceptron.hpp:36