14 #ifndef MLPACK_METHODS_GMM_GMM_IMPL_HPP 15 #define MLPACK_METHODS_GMM_GMM_IMPL_HPP 26 template<
typename FittingType>
29 const bool useExistingModel,
32 double bestLikelihood;
39 fitter.Estimate(observations, dists, weights, useExistingModel);
40 bestLikelihood = LogLikelihood(observations, dists, weights);
48 std::vector<distribution::GaussianDistribution> distsOrig;
49 arma::vec weightsOrig;
53 weightsOrig = weights;
58 fitter.Estimate(observations, dists, weights, useExistingModel);
60 bestLikelihood = LogLikelihood(observations, dists, weights);
62 Log::Info <<
"GMM::Train(): Log-likelihood of trial 0 is " 63 << bestLikelihood <<
"." << std::endl;
66 std::vector<distribution::GaussianDistribution> distsTrial(gaussians,
68 arma::vec weightsTrial(gaussians);
70 for (
size_t trial = 1; trial < trials; ++trial)
74 distsTrial = distsOrig;
75 weightsTrial = weightsOrig;
78 fitter.Estimate(observations, distsTrial, weightsTrial, useExistingModel);
81 double newLikelihood = LogLikelihood(observations, distsTrial,
84 Log::Info <<
"GMM::Train(): Log-likelihood of trial " << trial <<
" is " 85 << newLikelihood <<
"." << std::endl;
87 if (newLikelihood > bestLikelihood)
90 bestLikelihood = newLikelihood;
93 weights = weightsTrial;
99 Log::Info <<
"GMM::Train(): log-likelihood of trained GMM is " 100 << bestLikelihood <<
"." << std::endl;
101 return bestLikelihood;
108 template<
typename FittingType>
110 const arma::vec& probabilities,
112 const bool useExistingModel,
115 double bestLikelihood;
122 fitter.Estimate(observations, probabilities, dists, weights,
124 bestLikelihood = LogLikelihood(observations, dists, weights);
132 std::vector<distribution::GaussianDistribution> distsOrig;
133 arma::vec weightsOrig;
134 if (useExistingModel)
137 weightsOrig = weights;
142 fitter.Estimate(observations, probabilities, dists, weights,
145 bestLikelihood = LogLikelihood(observations, dists, weights);
147 Log::Debug <<
"GMM::Train(): Log-likelihood of trial 0 is " 148 << bestLikelihood <<
"." << std::endl;
151 std::vector<distribution::GaussianDistribution> distsTrial(gaussians,
153 arma::vec weightsTrial(gaussians);
155 for (
size_t trial = 1; trial < trials; ++trial)
157 if (useExistingModel)
159 distsTrial = distsOrig;
160 weightsTrial = weightsOrig;
163 fitter.Estimate(observations, probabilities, distsTrial, weightsTrial,
167 double newLikelihood = LogLikelihood(observations, distsTrial,
170 Log::Debug <<
"GMM::Train(): Log-likelihood of trial " << trial <<
" is " 171 << newLikelihood <<
"." << std::endl;
173 if (newLikelihood > bestLikelihood)
176 bestLikelihood = newLikelihood;
179 weights = weightsTrial;
185 Log::Info <<
"GMM::Train(): log-likelihood of trained GMM is " 186 << bestLikelihood <<
"." << std::endl;
187 return bestLikelihood;
193 template<
typename Archive>
196 ar(CEREAL_NVP(gaussians));
197 ar(CEREAL_NVP(dimensionality));
202 if (cereal::is_loading<Archive>())
203 dists.resize(gaussians);
205 ar(CEREAL_NVP(dists));
207 ar(CEREAL_NVP(weights));
A single multivariate Gaussian distribution.
Definition: gaussian_distribution.hpp:24
static MLPACK_EXPORT util::NullOutStream Debug
MLPACK_EXPORT is required for global variables, so that they are properly exported by the Windows com...
Definition: log.hpp:79
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
double Train(const arma::mat &observations, const size_t trials=1, const bool useExistingModel=false, FittingType fitter=FittingType())
Estimate the probability distribution directly from the given observations, using the given algorithm...
Definition: gmm_impl.hpp:27
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if –verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84
void serialize(Archive &ar, const uint32_t)
Serialize the GMM.
Definition: gmm_impl.hpp:194