/**
 * @file gmm_train_main.cpp
 *
 * mlpack binding for training a Gaussian mixture model (GMM) with the EM
 * algorithm.
 */
#include <mlpack/prereqs.hpp>
#include <mlpack/core/util/io.hpp>
#include <mlpack/core/util/mlpack_main.hpp>
#include "gmm.hpp"
#include "diagonal_gmm.hpp"
#include "no_constraint.hpp"
#include "diagonal_constraint.hpp"
#include <mlpack/methods/kmeans/refined_start.hpp>
// Binding documentation (program name, descriptions, examples, see-also).
BINDING_NAME ("Gaussian Mixture Model (GMM) Training") | |
BINDING_SHORT_DESC ("An implementation of the EM algorithm for training Gaussian mixture " "models (GMMs). Given a dataset, this can train a GMM for future use " "with other tools.") | |
BINDING_LONG_DESC ("This program takes a parametric estimate of a Gaussian mixture model (GMM)" " using the EM algorithm to find the maximum likelihood estimate. The " "model may be saved and reused by other mlpack GMM tools." "\" "The input data to train on must be specified with the "+PRINT_PARAM_STRING("input")+" parameter, and the number of Gaussians in " "the model must be specified with the "+PRINT_PARAM_STRING("gaussians")+" parameter. Optionally, many trials with different random " "initializations may be run, and the result with highest log-likelihood on " "the training data will be taken. The number of trials to run is specified" " with the "+PRINT_PARAM_STRING("trials")+" parameter. By default, " "only one trial is run." "\" "The tolerance for convergence and maximum number of iterations of the EM " "algorithm are specified with the "+PRINT_PARAM_STRING("tolerance")+" and "+PRINT_PARAM_STRING("max_iterations")+" parameters, " "respectively. The GMM may be initialized for training with another model," " specified with the "+PRINT_PARAM_STRING("input_model")+" parameter." " Otherwise, the model is initialized by running k-means on the data. The " "k-means clustering initialization can be controlled with the "+PRINT_PARAM_STRING("kmeans_max_iterations")+", "+PRINT_PARAM_STRING("refined_start")+", "+PRINT_PARAM_STRING("samplings")+", and "+PRINT_PARAM_STRING("percentage")+" parameters. If "+PRINT_PARAM_STRING("refined_start")+" is specified, then the " "Bradley-Fayyad refined start initialization will be used. This can often " "lead to better clustering results." "\" "The 'diagonal_covariance' flag will cause the learned covariances to be " "diagonal matrices. This significantly simplifies the model itself and " "causes training to be faster, but restricts the ability to fit more " "complex GMMs." 
"\" "If GMM training fails with an error indicating that a covariance matrix " "could not be inverted, make sure that the "+PRINT_PARAM_STRING("no_force_positive")+" parameter is not " "specified. Alternately, adding a small amount of Gaussian noise (using " "the "+PRINT_PARAM_STRING("noise")+" parameter) to the entire dataset" " may help prevent Gaussians with zero variance in a particular dimension, " "which is usually the cause of non-invertible covariance matrices." "\" "The "+PRINT_PARAM_STRING("no_force_positive")+" parameter, if set, " "will avoid the checks after each iteration of the EM algorithm which " "ensure that the covariance matrices are positive definite. Specifying " "the flag can cause faster runtime, but may also cause non-positive " "definite covariance matrices, which will cause the program to crash.") | |
BINDING_EXAMPLE ("As an example, to train a 6-Gaussian GMM on the data in "+PRINT_DATASET("data")+" with a maximum of 100 iterations of EM and 3 " "trials, saving the trained GMM to "+PRINT_MODEL("gmm")+", the " "following command can be used:" "\"+PRINT_CALL("gmm_train", "input", "data", "gaussians", 6, "trials", 3, "output_model", "gmm")+"\" "To re-train that GMM on another set of data "+PRINT_DATASET("data2")+", the following command may be used: " "\"+PRINT_CALL("gmm_train", "input_model", "gmm", "input", "data2", "gaussians", 6, "output_model", "new_gmm")) | |
BINDING_SEE_ALSO ("@gmm_generate", "#gmm_generate") | |
BINDING_SEE_ALSO ("@gmm_probability", "#gmm_probability") | |
BINDING_SEE_ALSO ("Gaussian Mixture Models on Wikipedia", "https://en.wikipedia.org/wiki/Mixture_model#Gaussian_mixture_model") | |
BINDING_SEE_ALSO ("mlpack::gmm::GMM class documentation", "@doxygen/classmlpack_1_1gmm_1_1GMM.html") | |
// Required parameters.
PARAM_MATRIX_IN_REQ("input", "The training data on which the model will be "
    "fit.", "i");
PARAM_INT_IN_REQ("gaussians", "Number of Gaussians in the GMM.", "g");

// Miscellaneous training options.
PARAM_INT_IN("seed", "Random seed. If 0, 'std::time(NULL)' is used.", "s", 0);
PARAM_INT_IN("trials", "Number of trials to perform in training GMM.", "t", 1);

// Parameters for EM algorithm.
PARAM_DOUBLE_IN("tolerance", "Tolerance for convergence of EM.", "T", 1e-10);
PARAM_FLAG("no_force_positive", "Do not force the covariance matrices to be "
    "positive definite.", "P");
PARAM_INT_IN("max_iterations", "Maximum number of iterations of EM algorithm "
    "(passing 0 will run until convergence).", "n", 250);
PARAM_FLAG("diagonal_covariance", "Force the covariance of the Gaussians to "
    "be diagonal. This can accelerate training time significantly.", "d");
PARAM_DOUBLE_IN("noise", "Variance of zero-mean Gaussian noise to add to data.",
    "N", 0);

// Parameters for dataset modification.
PARAM_INT_IN("kmeans_max_iterations", "Maximum number of iterations for the "
    "k-means algorithm (used to initialize EM).", "k", 1000);

// Parameters for k-means initialization.
PARAM_FLAG("refined_start", "During the initialization, use refined initial "
    "positions for k-means clustering (Bradley and Fayyad, 1998).", "r");
PARAM_INT_IN("samplings", "If using --refined_start, specify the number of "
    "samplings used for initial points.", "S", 100);
PARAM_DOUBLE_IN("percentage", "If using --refined_start, specify the percentage"
    " of the dataset used for each sampling (should be between 0.0 and 1.0).",
    "p", 0.02);

// Model input/output parameters.
PARAM_MODEL_IN(GMM, "input_model", "Initial input GMM model to start training "
    "with.", "m");
PARAM_MODEL_OUT(GMM, "output_model", "Output for trained GMM model.", "M");
// This program trains a mixture of Gaussians on a given data matrix.
// mlpack is free software; you may redistribute it and/or modify it under the
// terms of the 3-clause BSD license. You should have received a copy of the
// 3-clause BSD license along with mlpack. If not, see
// http://www.opensource.org/licenses/BSD-3-Clause for more information.