// NOTE(review): this span appears to be a lossy source-listing extraction —
// the leading numbers look like listing line numbers, and the include-guard
// directives were fused onto the same line as the template header. The
// opening brace and whatever followed the Create(...) call are not visible
// here; confirm against the real header before relying on this text.
15 #ifndef MLPACK_TESTS_MAIN_TESTS_HMM_TEST_UTILS_HPP 16 #define MLPACK_TESTS_MAIN_TESTS_HMM_TEST_UTILS_HPP 23 template<
typename HMMType>
// Build an untrained model for the given training sequences using a fixed
// number of hidden states (2). The concrete HMMType selects which Create()
// overload is used (discrete, Gaussian, GMM or diagonal GMM emissions).
// NOTE(review): the RandomInitialize() helpers are presumably invoked on the
// emissions after Create() in the part of this function that is not visible —
// confirm.
24 static void Apply(HMMType& hmm, vector<mat>* trainSeq)
26 const size_t states = 2;
// trainSeq is dereferenced unconditionally, so callers must pass non-null.
29 Create(hmm, *trainSeq, states);
37 static void Create(HMM<DiscreteDistribution>& hmm,
38 vector<mat>& trainSeq,
40 double tolerance = 1e-05)
44 arma::Col<size_t> maxEmissions(trainSeq[0].n_rows);
46 for (vector<mat>::iterator it = trainSeq.begin(); it != trainSeq.end();
49 arma::Col<size_t> maxSeqs =
50 arma::conv_to<arma::Col<size_t>>::from(arma::max(*it, 1)) + 1;
51 maxEmissions = arma::max(maxEmissions, maxSeqs);
54 hmm = HMM<DiscreteDistribution>(size_t(states),
55 DiscreteDistribution(maxEmissions), tolerance);
58 static void Create(HMM<GaussianDistribution>& hmm,
59 vector<mat>& trainSeq,
61 double tolerance = 1e-05)
64 const size_t dimensionality = trainSeq[0].n_rows;
67 for (
size_t i = 0; i < trainSeq.size(); ++i)
69 if (trainSeq[i].n_rows != dimensionality)
71 Log::Fatal <<
"Observation sequence " << i <<
" dimensionality (" 72 << trainSeq[i].n_rows <<
" is incorrect (should be " 73 << dimensionality <<
")!" << endl;
78 hmm = HMM<GaussianDistribution>(size_t(states),
79 GaussianDistribution(dimensionality), tolerance);
82 static void Create(HMM<GMM>& hmm,
83 vector<mat>& trainSeq,
85 double tolerance = 1e-05)
88 const size_t dimensionality = trainSeq[0].n_rows;
89 const int gaussians = 2;
93 Log::Fatal <<
"Number of gaussians for each GMM must be specified " 94 <<
"when type = 'gmm'!" << endl;
99 Log::Fatal <<
"Invalid number of gaussians (" << gaussians <<
"); must " 100 <<
"be greater than or equal to 1." << endl;
104 hmm = HMM<GMM>(size_t(states), GMM(
size_t(gaussians), dimensionality),
109 static void Create(HMM<DiagonalGMM>& hmm,
110 vector<mat>& trainSeq,
112 double tolerance = 1e-05)
115 const size_t dimensionality = trainSeq[0].n_rows;
116 const int gaussians = 2;
120 Log::Fatal <<
"Number of gaussians for each GMM must be specified " 121 <<
"when type = 'diag_gmm'!" << endl;
126 Log::Fatal <<
"Invalid number of gaussians (" << gaussians <<
"); must " 127 <<
"be greater than or equal to 1." << endl;
131 hmm = HMM<DiagonalGMM>(size_t(states), DiagonalGMM(
size_t(gaussians),
132 dimensionality), tolerance);
138 for (
size_t i = 0; i < e.size(); ++i)
140 e[i].Probabilities().randu();
141 e[i].Probabilities() /= arma::accu(e[i].Probabilities());
147 for (
size_t i = 0; i < e.size(); ++i)
149 const size_t dimensionality = e[i].Mean().n_rows;
152 arma::mat r = arma::randu<arma::mat>(dimensionality, dimensionality);
153 e[i].Covariance(r * r.t());
159 for (
size_t i = 0; i < e.size(); ++i)
162 e[i].Weights().randu();
163 e[i].Weights() /= arma::accu(e[i].Weights());
166 for (
int g = 0; g < 2; ++g)
168 const size_t dimensionality = e[i].Component(g).Mean().n_rows;
169 e[i].Component(g).Mean().randu();
172 arma::mat r = arma::randu<arma::mat>(dimensionality,
174 e[i].Component(g).Covariance(r * r.t());
182 for (
size_t i = 0; i < e.size(); ++i)
185 e[i].Weights().randu();
186 e[i].Weights() /= arma::accu(e[i].Weights());
189 for (
int g = 0; g < 2; ++g)
191 const size_t dimensionality = e[i].Component(g).Mean().n_rows;
192 e[i].Component(g).Mean().randu();
195 arma::vec r = arma::randu<arma::vec>(dimensionality);
196 e[i].Component(g).Covariance(r);
204 template<
typename HMMType>
205 static void Apply(HMMType& hmm, vector<arma::mat>* trainSeq)
208 hmm.Train(*trainSeq);
The core includes that mlpack expects: the standard C++ headers and Armadillo.
static void RandomInitialize(vector< DiscreteDistribution > &e)
Helper function for discrete emission distributions.
Definition: hmm_test_utils.hpp:136
static void Create(HMM< DiscreteDistribution > &hmm, vector< mat > &trainSeq, size_t states, double tolerance=1e-05)
Helper function to create discrete HMM.
Definition: hmm_test_utils.hpp:37
static void RandomInitialize(vector< DiagonalGMM > &e)
Helper function for diagonal GMM emission distributions.
Definition: hmm_test_utils.hpp:180
Definition: hmm_test_utils.hpp:21
Definition: hmm_test_utils.hpp:202
static void Create(HMM< DiagonalGMM > &hmm, vector< mat > &trainSeq, size_t states, double tolerance=1e-05)
Helper function to create Diagonal GMM HMM.
Definition: hmm_test_utils.hpp:109