// NOTE(review): this region is a garbled source extraction -- original file
// line numbers are fused into the text and several lines (including the
// constructor's declarator, e.g. `EMFit<...>::EMFit(`) are missing.  Code is
// kept byte-identical; only comments are added.
// Include guard plus the start of the EMFit constructor: it stores the EM
// configuration (maxIterations, tolerance), the initial-clustering object,
// and the covariance constraint policy in the corresponding members.
13 #ifndef MLPACK_METHODS_GMM_EM_FIT_IMPL_HPP 14 #define MLPACK_METHODS_GMM_EM_FIT_IMPL_HPP 25 template<
typename InitialClusteringType,
26 typename CovarianceConstraintPolicy,
27 typename Distribution>
// Constructor parameters (declarator line missing from this extract):
29 const size_t maxIterations,
30 const double tolerance,
31 InitialClusteringType clusterer,
32 CovarianceConstraintPolicy constraint) :
// Member-initializer list; the `tolerance`/`clusterer` initializers appear to
// have been dropped by the extraction (original lines 34-35 absent).
33 maxIterations(maxIterations),
36 constraint(constraint)
// NOTE(review): garbled extraction -- the function declarator (EMFit::Estimate
// for unweighted observations) and several body lines are missing; original
// line numbers are fused into the text.  Code kept byte-identical.
// Fits `dists` and `weights` to `observations` via the EM algorithm.
39 template<
typename InitialClusteringType,
40 typename CovarianceConstraintPolicy,
41 typename Distribution>
44 std::vector<Distribution>& dists,
46 const bool useInitialModel)
// Special case: for the diagonal-covariance distribution, delegate to the
// Armadillo gmm_diag wrapper -- except (per the warning) on Visual Studio,
// where OpenMP compilation issues force the slower generic path.
48 if (std::is_same<Distribution,
52 Log::Warn <<
"Cannot use arma::gmm_diag on Visual Studio due to OpenMP" 53 <<
" compilation issues! Using slower EMFit::Estimate() instead..." 56 ArmadilloGMMWrapper(observations, dists, weights, useInitialModel);
// Warn when DiagonalConstraint is combined with a full GaussianDistribution:
// this works but is slower than using DiagonalGMM directly.
60 else if (std::is_same<CovarianceConstraintPolicy, DiagonalConstraint>::value
61 && std::is_same<Distribution, distribution::GaussianDistribution>::value)
65 Log::Warn <<
"EMFit::Estimate() using DiagonalConstraint with " 66 <<
"GaussianDistribution makes use of slower implementation, so " 67 <<
"DiagonalGMM is recommended for faster training." << std::endl;
// Seed the model (presumably only when !useInitialModel -- the guard line is
// missing from this extract) and compute the starting log-likelihood.
72 InitialClustering(observations, dists, weights);
74 double l = LogLikelihood(observations, dists, weights);
76 Log::Debug <<
"EMFit::Estimate(): initial clustering log-likelihood: " 79 double lOld = -DBL_MAX;
// condLogProb(j, i) = log P(component i | observation j) (unnormalized at
// first); one row per observation, one column per component.
80 arma::mat condLogProb(observations.n_cols, dists.size());
// Main EM loop: iterate until the log-likelihood change falls below
// `tolerance` or `maxIterations` is reached.
84 while (std::abs(l - lOld) > tolerance && iteration != maxIterations)
86 Log::Info <<
"EMFit::Estimate(): iteration " << iteration <<
", " 87 <<
"log-likelihood " << l <<
"." << std::endl;
// E-step part 1: per-component log-probabilities of every observation,
// written through an unsafe column alias, plus the component log-weight.
91 for (
size_t i = 0; i < dists.size(); ++i)
95 arma::vec condLogProbAlias = condLogProb.unsafe_col(i);
96 dists[i].LogProbability(observations, condLogProbAlias);
97 condLogProbAlias += log(weights[i]);
// E-step part 2: normalize each row in log-space.  `probSum` (the log-sum of
// the row) is computed on a line missing from this extract; -inf rows (zero
// total probability) are left untouched.
101 for (
size_t i = 0; i < condLogProb.n_rows; ++i)
106 if (probSum != -std::numeric_limits<double>::infinity())
107 condLogProb.row(i) -= probSum;
// M-step: probRowSums[i] accumulates the log of the total responsibility of
// component i (accumulation line missing from this extract).
111 arma::vec probRowSums(dists.size());
112 for (
size_t i = 0; i < dists.size(); ++i)
119 for (
size_t i = 0; i < dists.size(); ++i)
// Skip components with zero total responsibility (log-sum == -inf).
122 if (probRowSums[i] != -std::numeric_limits<double>::infinity())
// New mean: responsibility-weighted average of the observations.
123 dists[i].Mean() = observations * arma::exp(condLogProb.col(i) -
// Center the observations about the freshly-updated mean.
130 arma::mat tmp = observations.each_col() - dists[i].Mean();
// Diagonal-covariance branch: only the per-dimension variances are computed
// (a vector), weighted by the normalized responsibilities.
134 if (std::is_same<Distribution,
137 arma::vec covariance = arma::sum((tmp % tmp) %
138 (arma::ones<arma::vec>(observations.n_rows) *
139 trans(arma::exp(condLogProb.col(i) - probRowSums[i]))), 1);
142 constraint.ApplyConstraint(covariance);
143 dists[i].Covariance(std::move(covariance));
// Full-covariance branch: responsibility-weighted outer product.
147 arma::mat tmpB = tmp.each_row() % trans(arma::exp(condLogProb.col(i) -
149 arma::mat covariance = tmp * trans(tmpB);
152 constraint.ApplyConstraint(covariance);
153 dists[i].Covariance(std::move(covariance));
// New mixing weights: average responsibility per component.
159 weights = arma::exp(probRowSums - std::log(observations.n_cols));
// Recompute the log-likelihood for the convergence check (the lOld update
// and iteration increment are on lines missing from this extract).
163 l = LogLikelihood(observations, dists, weights);
// NOTE(review): garbled extraction -- the declarator of this overload of
// EMFit::Estimate (taking per-observation `probabilities`) and several body
// lines are missing.  Code kept byte-identical; comments only.
// Same EM fit as the unweighted overload, but every observation j
// contributes with weight probabilities[j] (applied in log-space).
169 template<
typename InitialClusteringType,
170 typename CovarianceConstraintPolicy,
171 typename Distribution>
174 const arma::vec& probabilities,
175 std::vector<Distribution>& dists,
177 const bool useInitialModel)
// Only cluster from scratch when no initial model was supplied.
179 if (!useInitialModel)
180 InitialClustering(observations, dists, weights);
182 double l = LogLikelihood(observations, dists, weights);
184 Log::Debug <<
"EMFit::Estimate(): initial clustering log-likelihood: " 187 double lOld = -DBL_MAX;
// condLogProb(j, i): unnormalized log-responsibility of component i for
// observation j.
188 arma::mat condLogProb(observations.n_cols, dists.size());
191 size_t iteration = 1;
// EM loop with the same convergence test as the unweighted overload.
192 while (std::abs(l - lOld) > tolerance && iteration != maxIterations)
// E-step: per-component log-probabilities plus log mixing weight.
196 for (
size_t i = 0; i < dists.size(); ++i)
200 arma::vec condLogProbAlias = condLogProb.unsafe_col(i);
201 dists[i].LogProbability(observations, condLogProbAlias);
202 condLogProbAlias += log(weights[i]);
// Normalize rows in log-space; `probSum` is computed on a missing line.
206 for (
size_t i = 0; i < condLogProb.n_rows; ++i)
211 if (probSum != -std::numeric_limits<double>::infinity())
212 condLogProb.row(i) -= probSum;
// M-step: responsibilities are additionally weighted by the per-point
// log-probabilities; probRowSums accumulation line is missing here.
217 arma::vec probRowSums(dists.size());
221 arma::vec logProbabilities = arma::log(probabilities);
222 for (
size_t i = 0; i < dists.size(); ++i)
228 arma::vec tmpProb = condLogProb.col(i) + logProbabilities;
// Skip components with zero total (weighted) responsibility.
232 if (probRowSums[i] != -std::numeric_limits<double>::infinity())
// Weighted-mean update.
234 dists[i].Mean() = observations *
235 arma::exp(condLogProb.col(i) + logProbabilities - probRowSums[i]);
// Center observations about the new mean.
242 arma::mat tmp = observations.each_col() - dists[i].Mean();
// Diagonal-covariance branch: weighted per-dimension variances.
246 if (std::is_same<Distribution,
249 arma::vec cov = arma::sum((tmp % tmp) %
250 (arma::ones<arma::vec>(observations.n_rows) *
251 trans(arma::exp(condLogProb.col(i) +
252 logProbabilities - probRowSums[i]))), 1);
255 constraint.ApplyConstraint(cov);
256 dists[i].Covariance(std::move(cov));
// Full-covariance branch: weighted outer product.
260 arma::mat tmpB = tmp.each_row() % trans(arma::exp(condLogProb.col(i) +
261 logProbabilities - probRowSums[i]));
262 arma::mat cov = (tmp * trans(tmpB));
265 constraint.ApplyConstraint(cov);
266 dists[i].Covariance(std::move(cov));
// Recompute log-likelihood for the convergence test (weights update and
// iteration increment are on lines missing from this extract).
276 l = LogLikelihood(observations, dists, weights);
// NOTE(review): garbled extraction -- the declarator of
// EMFit::InitialClustering and some body lines (e.g. the weights counting
// pass) are missing.  Code kept byte-identical; comments only.
// Seeds the mixture: hard-clusters the observations, then computes each
// cluster's empirical mean and covariance as the starting model.
282 template<
typename InitialClusteringType,
283 typename CovarianceConstraintPolicy,
284 typename Distribution>
287 std::vector<Distribution>& dists,
// Hard assignment of every observation to one cluster.
291 arma::Row<size_t> assignments;
294 clusterer.Cluster(observations, dists.size(), assignments);
// Diagonal distributions store covariances as vectors; full ones as
// matrices -- pick the accumulator type accordingly.
299 const bool isDiagGaussDist = std::is_same<Distribution,
302 std::vector<arma::vec> means(dists.size());
305 std::vector<
typename std::conditional<isDiagGaussDist,
306 arma::vec, arma::mat>::type> covs(dists.size());
// Zero-initialize the accumulators to the sizes of the existing model.
310 for (
size_t i = 0; i < dists.size(); ++i)
312 means[i].zeros(dists[i].Mean().n_elem);
315 covs[i].zeros(dists[i].Covariance().n_elem);
319 covs[i].zeros(dists[i].Covariance().n_rows,
320 dists[i].Covariance().n_cols);
// First pass: accumulate per-cluster sums (means) and raw second moments
// (element-wise squares for diagonal, outer products for full covariance).
325 for (
size_t i = 0; i < observations.n_cols; ++i)
327 const size_t cluster = assignments[i];
330 means[cluster] += observations.col(i);
334 covs[cluster] += observations.col(i) % observations.col(i);
336 covs[cluster] += observations.col(i) * trans(observations.col(i));
// Normalize means by cluster counts (stored in weights -- the counting pass
// is on missing lines); guard against empty/singleton clusters.
343 for (
size_t i = 0; i < dists.size(); ++i)
345 means[i] /= (weights[i] > 1) ? weights[i] : 1;
// Second pass: accumulate covariances about the now-known cluster means.
348 for (
size_t i = 0; i < observations.n_cols; ++i)
350 const size_t cluster = assignments[i];
351 const arma::vec normObs = observations.col(i) - means[cluster];
353 covs[cluster] += normObs % normObs;
355 covs[cluster] += normObs * normObs.t();
// Normalize covariances, floor tiny variances (clamp to [1e-10, DBL_MAX] --
// visible for the diagonal path), apply the constraint, and install the
// results into the distributions.
358 for (
size_t i = 0; i < dists.size(); ++i)
360 covs[i] /= (weights[i] > 1) ? weights[i] : 1;
364 covs[i] = arma::clamp(covs[i], 1e-10, DBL_MAX);
366 constraint.ApplyConstraint(covs[i]);
368 std::swap(dists[i].Mean(), means[i]);
369 dists[i].Covariance(std::move(covs[i]));
// Turn cluster counts into normalized mixing weights.
373 weights /= accu(weights);
// NOTE(review): garbled extraction -- the declarator of
// EMFit::LogLikelihood and the per-point accumulation lines are missing.
// Code kept byte-identical; comments only.
// Computes the total log-likelihood of `observations` under the current
// mixture (dists + weights).
376 template<
typename InitialClusteringType,
377 typename CovarianceConstraintPolicy,
378 typename Distribution>
381 const std::vector<Distribution>& dists,
382 const arma::vec& weights)
const 384 double logLikelihood = 0;
// logLikelihoods(i, j) = log(weights[i]) + log P(observation j | dist i).
387 arma::mat logLikelihoods(dists.size(), observations.n_cols);
390 for (
size_t i = 0; i < dists.size(); ++i)
392 dists[i].LogProbability(observations, logPhis);
393 logLikelihoods.row(i) = log(weights(i)) + trans(logPhis);
// Per-point reduction (the log-sum accumulation into logLikelihood is on
// missing lines); a -inf column means the point has zero likelihood under
// every component, which is flagged as a probable outlier.
396 for (
size_t j = 0; j < observations.n_cols; ++j)
399 -std::numeric_limits<double>::infinity())
401 Log::Info <<
"Likelihood of point " << j <<
" is 0! It is probably an " 402 <<
"outlier." << std::endl;
407 return logLikelihood;
// NOTE(review): garbled extraction -- the declarator of EMFit::serialize is
// missing.  Code kept byte-identical; comments only.
// Cereal serialization of the fitter's configuration: iteration cap,
// convergence tolerance, clustering object, and constraint policy.
410 template<
typename InitialClusteringType,
411 typename CovarianceConstraintPolicy,
412 typename Distribution>
413 template<
typename Archive>
417 ar(CEREAL_NVP(maxIterations));
418 ar(CEREAL_NVP(tolerance));
419 ar(CEREAL_NVP(clusterer));
420 ar(CEREAL_NVP(constraint));
// NOTE(review): garbled extraction -- the declarator of
// EMFit::ArmadilloGMMWrapper and the declaration of the Armadillo model `g`
// (presumably an arma::gmm_diag, given the use of g.dcovs) are missing.
// Code kept byte-identical; comments only.
// Delegates diagonal-covariance GMM training to Armadillo's built-in
// implementation, then copies the learned parameters back into `dists`.
423 template<
typename InitialClusteringType,
424 typename CovarianceConstraintPolicy,
425 typename Distribution>
428 std::vector<Distribution>& dists,
430 const bool useInitialModel)
// Armadillo's trainer does not honor our tolerance setting in this
// configuration, so warn the user.
437 Log::Warn <<
"GMM::Train(): tolerance ignored when training GMMs with " 438 <<
"DiagonalConstraint." << std::endl;
443 if (!useInitialModel)
449 InitialClustering(observations, dists, weights);
// Pack the current model into the matrix layout Armadillo expects:
// one column per component for both means and (diagonal) covariances.
452 arma::mat means(observations.n_rows, dists.size());
453 arma::mat covs(observations.n_rows, dists.size());
454 for (
size_t i = 0; i < dists.size(); ++i)
456 means.col(i) = dists[i].Mean();
459 covs.col(i) = dists[i].Covariance();
462 g.reset(observations.n_rows, dists.size());
463 g.set_params(std::move(means), std::move(covs), weights.t());
// With an initial model: keep_existing seeds, 0 k-means iterations.
465 g.learn(observations, dists.size(), arma::eucl_dist, arma::keep_existing, 0,
466 maxIterations, 1e-10,
false );
// Without one: static_subset seeding with 1000 k-means iterations.
472 g.learn(observations, dists.size(), arma::eucl_dist, arma::static_subset,
473 1000, maxIterations, 1e-10,
false );
// Copy the trained parameters back out, applying the covariance constraint
// to each diagonal-covariance column in place via an unsafe alias.
477 weights = g.hefts.t();
478 for (
size_t i = 0; i < dists.size(); ++i)
480 dists[i].Mean() = g.means.col(i);
483 arma::vec covsAlias = g.dcovs.unsafe_col(i);
484 constraint.ApplyConstraint(covsAlias);
487 dists[i].Covariance(g.dcovs.col(i));
This class contains methods which can fit a GMM to observations using the EM algorithm.
Definition: em_fit.hpp:45
static MLPACK_EXPORT util::NullOutStream Debug
MLPACK_EXPORT is required for global variables, so that they are properly exported by the Windows compiler.
Definition: log.hpp:79
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
EMFit(const size_t maxIterations=300, const double tolerance=1e-10, InitialClusteringType clusterer=InitialClusteringType(), CovarianceConstraintPolicy constraint=CovarianceConstraintPolicy())
Construct the EMFit object, optionally passing an InitialClusteringType object (just in case it needs...
Definition: em_fit_impl.hpp:28
static MLPACK_EXPORT util::PrefixedOutStream Warn
Prints warning messages prefixed with [WARN ].
Definition: log.hpp:87
void serialize(Archive &ar, const uint32_t version)
Serialize the fitter.
Definition: em_fit_impl.hpp:415
double Tolerance() const
Get the tolerance for the convergence of the EM algorithm.
Definition: em_fit.hpp:126
T::elem_type AccuLog(const T &x)
Log-sum a vector of log values.
Definition: log_add_impl.hpp:63
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84
void Estimate(const arma::mat &observations, std::vector< Distribution > &dists, arma::vec &weights, const bool useInitialModel=false)
Fit the observations to a Gaussian mixture model (GMM) using the EM algorithm.
Definition: em_fit_impl.hpp:43
This class implements K-Means clustering, using a variety of possible implementations of Lloyd's algorithm.
Definition: kmeans.hpp:73
A single multivariate Gaussian distribution with diagonal covariance.
Definition: diagonal_gaussian_distribution.hpp:21