12 #ifndef MLPACK_METHODS_PERCEPTRON_PERCEPTRON_IMPL_HPP 13 #define MLPACK_METHODS_PERCEPTRON_PERCEPTRON_IMPL_HPP 18 namespace perceptron {
26 typename WeightInitializationPolicy,
30 const size_t numClasses,
31 const size_t dimensionality,
32 const size_t maxIterations) :
33 maxIterations(maxIterations)
35 WeightInitializationPolicy wip;
36 wip.Initialize(weights, biases, dimensionality, numClasses);
51 typename WeightInitializationPolicy,
56 const arma::Row<size_t>& labels,
57 const size_t numClasses,
58 const size_t maxIterations) :
59 maxIterations(maxIterations)
62 Train(data, labels, numClasses);
78 typename WeightInitializationPolicy,
84 const arma::Row<size_t>& labels,
85 const size_t numClasses,
86 const arma::rowvec& instanceWeights) :
87 maxIterations(other.maxIterations)
89 Train(data, labels, numClasses, instanceWeights);
101 typename LearnPolicy,
102 typename WeightInitializationPolicy,
107 arma::Row<size_t>& predictedLabels)
109 arma::vec tempLabelMat;
110 arma::uword maxIndex = 0;
113 for (
size_t i = 0; i < test.n_cols; ++i)
115 tempLabelMat = weights.t() * test.col(i) + biases;
116 tempLabelMat.max(maxIndex);
117 predictedLabels(0, i) = maxIndex;
131 typename LearnPolicy,
132 typename WeightInitializationPolicy,
137 const arma::Row<size_t>& labels,
138 const size_t numClasses,
139 const arma::rowvec& instanceWeights)
142 if (weights.n_elem != numClasses)
144 WeightInitializationPolicy wip;
145 wip.Initialize(weights, biases, data.n_rows, numClasses);
149 bool converged =
false;
151 arma::uword maxIndexRow = 0, maxIndexCol = 0;
152 arma::mat tempLabelMat;
156 const bool hasWeights = (instanceWeights.n_elem > 0);
158 while ((i < maxIterations) && (!converged))
166 for (j = 0; j < data.n_cols; ++j)
170 tempLabelMat = weights.t() * data.col(j) + biases;
172 tempLabelMat.max(maxIndexRow, maxIndexCol);
175 if (maxIndexRow != labels(0, j))
179 tempLabel = labels(0, j);
185 LP.UpdateWeights(data.col(j), weights, biases, maxIndexRow, tempLabel,
188 LP.UpdateWeights(data.col(j), weights, biases, maxIndexRow,
196 template<
typename LearnPolicy,
197 typename WeightInitializationPolicy,
199 template<
typename Archive>
206 ar(CEREAL_NVP(maxIterations));
207 ar(CEREAL_NVP(weights));
208 ar(CEREAL_NVP(biases));
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Perceptron(const size_t numClasses=0, const size_t dimensionality=0, const size_t maxIterations=1000)
Constructor: create the perceptron with the given number of classes and initialize the weight matrix and biases using the weight initialization policy.
Definition: perceptron_impl.hpp:29
void Train(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const arma::rowvec &instanceWeights=arma::rowvec())
Train the perceptron on the given data for up to the maximum number of iterations (specified in the constructor).
Definition: perceptron_impl.hpp:135
void Classify(const MatType &test, arma::Row< size_t > &predictedLabels)
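Classification function: for each column of the test data, computes weights.t() * point + biases and stores the index of the largest resulting score in predictedLabels.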
Definition: perceptron_impl.hpp:105
void serialize(Archive &ar, const uint32_t)
Serialize the perceptron.
Definition: perceptron_impl.hpp:200
This class implements a simple perceptron (i.e., a single-layer neural network).
Definition: perceptron.hpp:36