17 #ifndef MLPACK_METHODS_NAIVE_BAYES_NAIVE_BAYES_CLASSIFIER_IMPL_HPP 18 #define MLPACK_METHODS_NAIVE_BAYES_NAIVE_BAYES_CLASSIFIER_IMPL_HPP 26 namespace naive_bayes {
// Training constructor (fragment): sizes the model members `probabilities`,
// `means` and `variances` from the data and `numClasses`, then calls Train().
// NOTE(review): this listing omits several original lines (30-31, 36-38,
// 42-47, 51-53, 57-58); the function-name line, the `data` parameter, the
// member initializers after `:` and the branch conditions are not visible.
28 template<
typename ModelMatType>
29 template<
typename MatType>
32 const arma::Row<size_t>& labels,
33 const size_t numClasses,
34 const bool incremental,
35 const double epsilon) :
// Compile-time check: the element type of the data matrix must be exactly the
// model's ElemType.
39 static_assert(std::is_same<ElemType, typename MatType::elem_type>::value,
40 "NaiveBayesClassifier: element type of given data must match the element " 41 "type of the model!");
// First initialization variant: resize and zero-fill. One probability entry
// per class; means/variances get data.n_rows rows and one column per class.
// Presumably selected by `incremental` -- the condition is on a line missing
// from this listing; confirm.
48 probabilities.zeros(numClasses);
49 means.zeros(data.n_rows, numClasses);
50 variances.zeros(data.n_rows, numClasses);
// Second initialization variant: resize only, with no explicit fill here.
54 probabilities.set_size(numClasses);
55 means.set_size(data.n_rows, numClasses);
56 variances.set_size(data.n_rows, numClasses);
// Delegate the actual parameter estimation to Train() with the same arguments.
59 Train(data, labels, numClasses, incremental);
// Constructor taking only the model shape (fragment): no data is passed and no
// Train() call is visible; the model is just zero-initialized.
// NOTE(review): the function-name line (63) and the member initializers after
// `:` (67-70) are missing from this listing.
62 template<
typename ModelMatType>
64 const size_t dimensionality,
65 const size_t numClasses,
66 const double epsilon) :
// One probability entry per class; means/variances are
// dimensionality x numClasses, all set to zero.
71 probabilities.zeros(numClasses);
72 means.zeros(dimensionality, numClasses);
73 variances.zeros(dimensionality, numClasses);
// Batch Train() (fragment): estimates per-class probabilities, per-class
// feature means and per-class feature variances from the columns of `data`
// and their `labels`. Two estimation paths are visible: a one-pass running
// update and a multi-pass computation.
// NOTE(review): the function-name line, the `data` parameter, all braces and
// the conditions choosing between the paths are on lines missing from this
// listing (e.g. 78-79, 90-95, 105-113, 130-134); presumably `incremental`
// selects the path -- confirm against the full file.
76 template<
typename ModelMatType>
77 template<
typename MatType>
80 const arma::Row<size_t>& labels,
81 const size_t numClasses,
82 const bool incremental)
// Compile-time check: data element type must equal the model's ElemType.
84 static_assert(std::is_same<ElemType, typename MatType::elem_type>::value,
85 "NaiveBayesClassifier: element type of given data must match the element " 86 "type of the model!");
// The model is (re)sized when the stored number of classes differs from
// numClasses. Presumably the two initialization variants below are nested
// under this check -- the braces are not visible; confirm.
89 if (probabilities.n_elem != numClasses)
// Variant 1: resize and zero-fill.
96 probabilities.zeros(numClasses);
97 means.zeros(data.n_rows, numClasses);
98 variances.zeros(data.n_rows, numClasses);
// Variant 2: resize only, no explicit fill here.
102 probabilities.set_size(numClasses);
103 means.set_size(data.n_rows, numClasses);
104 variances.set_size(data.n_rows, numClasses);
// ---- Running-update path ----
// Scale the stored probabilities by trainingPoints before updating them
// per point (presumably turning normalized priors back into per-class
// counts -- confirm).
114 probabilities *= trainingPoints;
// For each column j: bump the count of its class, then update that class's
// mean by delta / count and accumulate delta % (x - updated mean) into the
// class's variance column (element-wise product).
116 for (
size_t j = 0; j < data.n_cols; ++j)
118 const size_t label = labels[j];
119 ++probabilities[label];
// NOTE(review): `delta` is hard-coded as arma::vec although the model is
// templated on ModelMatType/ElemType -- confirm this is intended when
// ElemType is not double.
121 arma::vec delta = data.col(j) - means.col(label);
122 means.col(label) += delta / probabilities[label];
123 variances.col(label) += delta % (data.col(j) - means.col(label));
// Turn accumulated sums into variances by dividing by (count - 1).
// NOTE(review): the guard here is `> 2`, while the multi-pass path below
// (line 167) uses `> 1` for the same division -- confirm whether the
// difference is intentional.
126 for (
size_t i = 0; i < probabilities.n_elem; ++i)
128 if (probabilities[i] > 2)
129 variances.col(i) /= (probabilities[i] - 1);
// ---- Multi-pass path ----
// Reset the per-class counts. Lines 136-145 are missing from this listing,
// so any matching reset of means/variances is not visible here.
135 probabilities.zeros();
// Pass 1: count the points of each class and sum their columns into means.
146 for (
size_t j = 0; j < data.n_cols; ++j)
148 const size_t label = labels[j];
149 ++probabilities[label];
150 means.col(label) += data.col(j);
// Pass 2: divide each class's sum by its count; classes with a zero count
// are skipped to avoid dividing by zero.
154 for (
size_t i = 0; i < probabilities.n_elem; ++i)
155 if (probabilities[i] != 0.0)
156 means.col(i) /= probabilities[i];
// Pass 3: accumulate squared deviations from the class mean.
159 for (
size_t j = 0; j < data.n_cols; ++j)
161 const size_t label = labels[j];
162 variances.col(label) += square(data.col(j) - means.col(label));
// Pass 4: divide by (count - 1) for classes with more than one point.
166 for (
size_t i = 0; i < probabilities.n_elem; ++i)
167 if (probabilities[i] > 1)
168 variances.col(i) /= (probabilities[i] - 1);
// ---- Common tail ----
// Add epsilon to every variance entry; presumably this keeps the
// `1.0 / variances` computed at line 214 finite -- confirm.
172 variances += epsilon;
// Normalize the per-class values by the number of columns in this batch and
// record how many points have been seen.
// NOTE(review): the divisor is data.n_cols, not the updated trainingPoints;
// after the running-update path (which first scaled by the previous
// trainingPoints) this looks like it would not sum to 1 when trainingPoints
// was already non-zero -- confirm.
174 probabilities /= data.n_cols;
175 trainingPoints += data.n_cols;
// Single-point Train() (fragment): folds one labelled point into the running
// per-class probability, mean and variance.
// NOTE(review): the signature lines (180-182) are missing from this listing;
// `point` and `label` are used below but their declarations are not visible.
178 template<
typename ModelMatType>
179 template<
typename VecType>
// Compile-time check: the point's element type must equal the model's ElemType.
183 static_assert(std::is_same<ElemType, typename VecType::elem_type>::value,
184 "NaiveBayesClassifier: element type of given data must match the element " 185 "type of the model!");
// Scale probabilities by trainingPoints (presumably back to per-class counts --
// confirm), then count the new point for its class.
188 probabilities *= trainingPoints;
189 probabilities[label]++;
// Update the class mean by delta / count.
// NOTE(review): `delta` is hard-coded as arma::vec although the model is
// templated on ElemType -- confirm this is intended for non-double models.
191 arma::vec delta = point - means.col(label);
192 means.col(label) += delta / probabilities[label];
// With more than two points in the class, multiply the stored variance by
// (count - 2) before adding the new term, i.e. undo the previous division by
// (old count - 1).
193 if (probabilities[label] > 2)
194 variances.col(label) *= (probabilities[label] - 2);
// Add this point's contribution: delta % (point - updated mean), element-wise.
195 variances.col(label) += (delta % (point - means.col(label)));
// Divide by (count - 1) once the class has more than one point.
196 if (probabilities[label] > 1)
197 variances.col(label) /= probabilities[label] - 1;
// Renormalize by trainingPoints.
// NOTE(review): no increment of trainingPoints is visible between lines 188
// and 200 in this listing (lines 198-199 are missing) -- confirm the count is
// updated before this division.
200 probabilities /= trainingPoints;
// Fills `logLikelihoods` with one row per class and one column per column of
// `data`: log(class probability) plus the log of a diagonal-covariance
// Gaussian density evaluated with that class's means and variances.
// Callers elsewhere in this file invoke it as LogLikelihood(data, out).
// NOTE(review): the function-name line and the `data` parameter (205-206) are
// missing from this listing.
203 template<
typename ModelMatType>
204 template<
typename MatType>
207 ModelMatType& logLikelihoods)
// Compile-time check: data element type must equal the model's ElemType.
const 209 static_assert(std::is_same<ElemType, typename MatType::elem_type>::value,
210 "NaiveBayesClassifier: element type of given data must match the element " 211 "type of the model!");
// Start from the log of the class probabilities, replicated once per data
// column.
213 logLikelihoods = arma::log(arma::repmat(probabilities, 1, data.n_cols));
// Element-wise reciprocal of the variances, computed once for all classes.
214 ModelMatType invVar = 1.0 / variances;
// One iteration per class (column of `means`).
220 for (
size_t i = 0; i < means.n_cols; ++i)
// diffs: every data column minus the class-i mean.
224 ModelMatType diffs = data - arma::repmat(means.col(i), 1, data.n_cols);
// rhs: diffs scaled row-wise by -0.5 / variance of class i.
225 ModelMatType rhs = -0.5 * arma::diagmat(invVar.col(i)) * diffs;
// Column sums of diffs % rhs give, per point, -0.5 * sum_d diff_d^2 / var_d.
226 arma::Mat<ElemType> exponents = arma::sum(diffs % rhs, 0);
// Add the log normalization term: -(n_rows / 2) * log(2*pi)
// - 0.5 * sum(log(variances of class i)), plus the per-point exponent.
228 logLikelihoods.row(i) += (data.n_rows / -2.0 * log(2 * M_PI) - 0.5 *
229 arma::accu(arma::log(variances.col(i))) + exponents);
// Single-point Classify() (fragment): computes the per-class log-likelihoods
// of `point` and locates the index of the largest one.
// NOTE(review): the signature lines (235-236) and the return statement are
// missing from this listing; presumably maxIndex is what gets returned --
// confirm.
233 template<
typename ModelMatType>
234 template<
typename VecType>
// Compile-time check: the point's element type must equal the model's ElemType.
237 static_assert(std::is_same<ElemType, typename VecType::elem_type>::value,
238 "NaiveBayesClassifier: element type of given data must match the element " 239 "type of the model!");
// Per-class log-likelihoods for this one point.
242 ModelMatType logLikelihoods;
243 LogLikelihood(point, logLikelihoods);
// max(maxIndex) writes the position of the largest entry into maxIndex; the
// maximum value itself is discarded.
245 arma::uword maxIndex = 0;
246 logLikelihoods.max(maxIndex);
// Single-point Classify() with probabilities (fragment): writes the predicted
// class index into `prediction` and per-class probabilities into the output
// parameter `probabilities`.
// NOTE(review): the function-name line and the `prediction` parameter
// (252, 254) are missing from this listing. The output parameter
// `probabilities` has the same name as the model member used elsewhere in
// this file, so inside this function the name refers to the parameter.
250 template<
typename ModelMatType>
251 template<
typename VecType,
typename ProbabilitiesVecType>
253 const VecType& point,
255 ProbabilitiesVecType& probabilities)
// Compile-time checks: both the point and the probabilities output must use
// the model's ElemType.
const 257 static_assert(std::is_same<ElemType, typename VecType::elem_type>::value,
258 "NaiveBayesClassifier: element type of given data must match the element " 259 "type of the model!");
260 static_assert(std::is_same<ElemType,
261 typename ProbabilitiesVecType::elem_type>::value,
262 "NaiveBayesClassifier: element type of given data must match the element " 263 "type of the model!");
// Per-class log-likelihoods for this one point.
269 ModelMatType logLikelihoods;
270 LogLikelihood(point, logLikelihoods);
// Subtract the largest log-likelihood before exponentiating and summing.
// NOTE(review): the continuation of the logProbX expression (line 276) is
// missing from this listing; the matrix overload adds maxValue back at the
// equivalent point (lines 336-337) -- presumably the same happens here.
274 const double maxValue = arma::max(logLikelihoods);
275 const double logProbX = log(arma::accu(exp(logLikelihoods - maxValue))) +
// Normalize: exp(log-likelihood - logProbX) per class.
277 probabilities = exp(logLikelihoods - logProbX);
// The prediction is the index of the largest log-likelihood.
279 arma::uword maxIndex = 0;
280 logLikelihoods.max(maxIndex);
281 prediction = (size_t) maxIndex;
// Matrix Classify() (fragment): for every column of `data`, stores in
// `predictions` the index of the class with the largest log-likelihood.
// NOTE(review): the function-name line and the `data` parameter (286-287) are
// missing from this listing.
284 template<
typename ModelMatType>
285 template<
typename MatType>
288 arma::Row<size_t>& predictions)
// Compile-time check: data element type must equal the model's ElemType.
const 290 static_assert(std::is_same<ElemType, typename MatType::elem_type>::value,
291 "NaiveBayesClassifier: element type of given data must match the element " 292 "type of the model!");
// One prediction per data column.
294 predictions.set_size(data.n_cols);
// Log-likelihoods for all points at once (classes in rows, points in columns).
296 ModelMatType logLikelihoods;
297 LogLikelihood(data, logLikelihoods);
// For each point, take the row index of the maximum in its column.
299 for (
size_t i = 0; i < data.n_cols; ++i)
301 arma::uword maxIndex = 0;
302 logLikelihoods.unsafe_col(i).max(maxIndex);
303 predictions[i] = maxIndex;
// Matrix Classify() with probabilities (fragment): for every column of `data`,
// stores the predicted class index in `predictions` and the normalized
// per-class probabilities in the matching column of `predictionProbs`.
// NOTE(review): the function-name line and the `data` parameter (309-310) are
// missing from this listing.
307 template<
typename ModelMatType>
308 template<
typename MatType,
typename ProbabilitiesMatType>
311 arma::Row<size_t>& predictions,
312 ProbabilitiesMatType& predictionProbs)
// Compile-time checks: both the data and the probabilities output must use
// the model's ElemType.
const 314 static_assert(std::is_same<ElemType, typename MatType::elem_type>::value,
315 "NaiveBayesClassifier: element type of given data must match the element " 316 "type of the model!");
317 static_assert(std::is_same<ElemType,
318 typename ProbabilitiesMatType::elem_type>::value,
319 "NaiveBayesClassifier: element type of given data must match the element " 320 "type of the model!");
// One prediction per data column.
322 predictions.set_size(data.n_cols);
// Log-likelihoods for all points at once (classes in rows, points in columns).
324 ModelMatType logLikelihoods;
325 LogLikelihood(data, logLikelihoods);
// The probabilities output has the same shape as the log-likelihood matrix.
327 predictionProbs.set_size(arma::size(logLikelihoods));
328 double maxValue, logProbX;
// Per point: subtract the column maximum before exponentiating and summing,
// add it back after the log, then normalize the column with
// exp(log-likelihood - logProbX).
329 for (
size_t j = 0; j < data.n_cols; ++j)
335 maxValue = arma::max(logLikelihoods.col(j));
336 logProbX = log(arma::accu(exp(logLikelihoods.col(j) -
337 maxValue))) + maxValue;
338 predictionProbs.col(j) = arma::exp(logLikelihoods.col(j) - logProbX);
// Second sweep: the prediction for each point is the row index of the
// largest log-likelihood in its column.
342 for (
size_t i = 0; i < data.n_cols; ++i)
344 arma::uword maxIndex = 0;
345 logLikelihoods.unsafe_col(i).max(maxIndex);
346 predictions[i] = maxIndex;
// serialize() (fragment): passes the model's `means`, `variances` and
// `probabilities` to the archive `ar` as name-value pairs.
// NOTE(review): the signature and any surrounding lines (352-355, 359+) are
// missing from this listing; `trainingPoints` and `epsilon` are not among the
// visible archived fields -- confirm whether that is intentional.
350 template<
typename ModelMatType>
351 template<
typename Archive>
356 ar(CEREAL_NVP(means));
357 ar(CEREAL_NVP(variances));
358 ar(CEREAL_NVP(probabilities));
void serialize(Archive &ar, const uint32_t)
Serialize the classifier.
Definition: naive_bayes_classifier_impl.hpp:352
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Train(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const bool incremental=true)
Train the Naive Bayes classifier on the given dataset.
Definition: naive_bayes_classifier_impl.hpp:78
NaiveBayesClassifier(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const bool incrementalVariance=false, const double epsilon=1e-10)
Initializes the classifier as per the input and then trains it by calculating the sample mean and variances.
Definition: naive_bayes_classifier_impl.hpp:30
The core includes that mlpack expects; standard C++ includes and Armadillo.
The simple Naive Bayes classifier.
Definition: naive_bayes_classifier.hpp:58
size_t Classify(const VecType &point) const
Classify the given point, using the trained NaiveBayesClassifier model.
Definition: naive_bayes_classifier_impl.hpp:235