14 #ifndef MLPACK_METHODS_HMM_HMM_IMPL_HPP 15 #define MLPACK_METHODS_HMM_HMM_IMPL_HPP 28 template<
typename Distribution>
30 const Distribution emissions,
31 const double tolerance) :
32 emission(states, emissions),
33 transitionProxy(
arma::randu<
arma::mat>(states, states)),
34 initialProxy(
arma::randu<
arma::vec>(states) / (double) states),
35 dimensionality(emissions.Dimensionality()),
37 recalculateInitial(false),
38 recalculateTransition(false)
41 initialProxy /= arma::accu(initialProxy);
42 for (
size_t i = 0; i < transitionProxy.n_cols; ++i)
43 transitionProxy.col(i) /= arma::accu(transitionProxy.col(i));
45 logTransition = log(transitionProxy);
46 logInitial = log(initialProxy);
53 template<
typename Distribution>
55 const arma::mat& transition,
56 const std::vector<Distribution>& emission,
57 const double tolerance) :
59 transitionProxy(transition),
60 logTransition(log(transition)),
61 initialProxy(initial),
62 logInitial(log(initial)),
64 recalculateInitial(false),
65 recalculateTransition(false)
68 if (emission.size() > 0)
69 dimensionality = emission[0].Dimensionality();
72 Log::Warn <<
"HMM::HMM(): no emission distributions given; assuming a " 73 <<
"dimensionality of 0 and hoping it gets set right later." 94 template<
typename Distribution>
102 size_t iterations = 1000;
105 size_t totalLength = 0;
106 for (
size_t seq = 0; seq < dataSeq.size(); seq++)
108 totalLength += dataSeq[seq].n_cols;
110 if (dataSeq[seq].n_rows != dimensionality)
111 Log::Fatal <<
"HMM::Train(): data sequence " << seq <<
" has " 112 <<
"dimensionality " << dataSeq[seq].n_rows <<
" (expected " 113 << dimensionality <<
" dimensions)." << std::endl;
118 std::vector<arma::vec> emissionProb(logTransition.n_cols,
119 arma::vec(totalLength));
120 arma::mat emissionList(dimensionality, totalLength);
125 for (
size_t iter = 0; iter < iterations; iter++)
128 arma::vec newLogInitial(logTransition.n_rows);
129 newLogInitial.fill(-std::numeric_limits<double>::infinity());
130 arma::mat newLogTransition(logTransition.n_rows, logTransition.n_cols);
131 newLogTransition.fill(-std::numeric_limits<double>::infinity());
140 for (
size_t seq = 0; seq < dataSeq.size(); seq++)
142 arma::mat stateLogProb;
143 arma::mat forwardLog;
144 arma::mat backwardLog;
148 loglik += LogEstimate(dataSeq[seq], stateLogProb, forwardLog,
149 backwardLog, logScales);
152 math::LogSumExp<arma::vec, true>(stateLogProb.unsafe_col(0),
156 arma::mat logProbs(dataSeq[seq].n_cols, logTransition.n_rows);
158 for (
size_t i = 0; i < logTransition.n_rows; i++)
161 arma::vec alias(logProbs.colptr(i), logProbs.n_rows,
false,
true);
163 emission[i].LogProbability(dataSeq[seq], alias);
172 for (
size_t t = 0; t < dataSeq[seq].n_cols; ++t)
175 if (t < dataSeq[seq].n_cols - 1)
179 const arma::vec tmp = backwardLog.col(t + 1) +
180 logProbs.row(t + 1).t() - logScales[t + 1];
184 for (
size_t j = 0; j < logTransition.n_cols; ++j)
189 arma::vec tmp2 = output + forwardLog(j, t);
190 arma::vec alias = newLogTransition.unsafe_col(j);
191 math::LogSumExp<arma::vec, true>(tmp2, alias);
196 for (
size_t j = 0; j < logTransition.n_cols; ++j)
197 emissionProb[j][sumTime] = exp(stateLogProb(j, t));
198 emissionList.col(sumTime) = dataSeq[seq].col(t);
203 if (std::abs(oldLoglik - loglik) < tolerance)
205 Log::Debug <<
"Converged after " << iter <<
" iterations." << std::endl;
212 if (dataSeq.size() > 1)
213 logInitial = newLogInitial - std::log(dataSeq.size());
215 logInitial = newLogInitial;
221 logTransition += newLogTransition;
224 for (
size_t i = 0; i < logTransition.n_cols; i++)
227 if (std::isfinite(sum))
228 logTransition.col(i) -= sum;
230 logTransition.col(i).fill(-log((
double) logTransition.n_rows));
233 initialProxy = exp(logInitial);
234 transitionProxy = exp(logTransition);
236 for (
size_t state = 0; state < logTransition.n_cols; state++)
237 emission[state].
Train(emissionList, emissionProb[state]);
239 Log::Debug <<
"Iteration " << iter <<
": log-likelihood " << loglik
249 template<
typename Distribution>
251 const std::vector<arma::Row<size_t> >& stateSeq)
254 if (dataSeq.size() != stateSeq.size())
256 Log::Fatal <<
"HMM::Train(): number of data sequences (" << dataSeq.size()
257 <<
") not equal to number of state sequences (" << stateSeq.size()
258 <<
")." << std::endl;
261 arma::mat initial = arma::zeros(logInitial.n_elem);
262 arma::mat transition = arma::zeros(logTransition.n_rows,
263 logTransition.n_cols);
268 std::vector<std::vector<std::pair<size_t, size_t> > >
269 emissionList(transition.n_cols);
270 for (
size_t seq = 0; seq < dataSeq.size(); seq++)
273 if (dataSeq[seq].n_cols != stateSeq[seq].n_elem)
275 Log::Fatal <<
"HMM::Train(): number of observations (" 276 << dataSeq[seq].n_cols <<
") in sequence " << seq
277 <<
" not equal to number of states (" << stateSeq[seq].n_cols
278 <<
") in sequence " << seq <<
"." << std::endl;
281 if (dataSeq[seq].n_rows != dimensionality)
283 Log::Fatal <<
"HMM::Train(): data sequence " << seq <<
" has " 284 <<
"dimensionality " << dataSeq[seq].n_rows <<
" (expected " 285 << dimensionality <<
" dimensions)." << std::endl;
290 initial[stateSeq[seq][0]]++;
291 for (
size_t t = 0; t < dataSeq[seq].n_cols - 1; t++)
293 transition(stateSeq[seq][t + 1], stateSeq[seq][t])++;
294 emissionList[stateSeq[seq][t]].push_back(std::make_pair(seq, t));
298 emissionList[stateSeq[seq][stateSeq[seq].n_elem - 1]].push_back(
299 std::make_pair(seq, stateSeq[seq].n_elem - 1));
303 initial /= accu(initial);
306 for (
size_t col = 0; col < transition.n_cols; col++)
311 double sum = accu(transition.col(col));
313 transition.col(col) /= sum;
316 initialProxy = initial;
317 transitionProxy = transition;
318 logTransition = log(transition);
319 logInitial = log(initial);
322 for (
size_t state = 0; state < transition.n_cols; state++)
326 if (emissionList[state].size() > 0)
328 arma::mat emissions(dimensionality, emissionList[state].size());
329 for (
size_t i = 0; i < emissions.n_cols; i++)
331 emissions.col(i) = dataSeq[emissionList[state][i].first].col(
332 emissionList[state][i].second);
335 emission[state].Train(emissions);
339 Log::Warn <<
"There are no observations in training data with hidden " 340 <<
"state " << state <<
"! The corresponding emission distribution " 341 <<
"is likely to be meaningless." << std::endl;
350 template<
typename Distribution>
352 arma::mat& stateLogProb,
353 arma::mat& forwardLogProb,
354 arma::mat& backwardLogProb,
355 arma::vec& logScales)
const 357 arma::mat logProbs(dataSeq.n_cols, logTransition.n_rows);
360 for (
size_t i = 0; i < logTransition.n_rows; i++)
363 arma::vec alias(logProbs.colptr(i), logProbs.n_rows,
false,
true);
365 emission[i].LogProbability(dataSeq, alias);
369 Forward(dataSeq, logScales, forwardLogProb, logProbs);
370 Backward(dataSeq, logScales, backwardLogProb, logProbs);
374 stateLogProb = forwardLogProb + backwardLogProb;
377 return accu(logScales);
384 template<
typename Distribution>
386 arma::mat& stateProb,
387 arma::mat& forwardProb,
388 arma::mat& backwardProb,
389 arma::vec& scales)
const 391 arma::mat stateLogProb;
392 arma::mat forwardLogProb;
393 arma::mat backwardLogProb;
396 const double loglikelihood = LogEstimate(dataSeq, stateLogProb,
397 forwardLogProb, backwardLogProb, logScales);
399 stateProb = exp(stateLogProb);
400 forwardProb = exp(forwardLogProb);
401 backwardProb = exp(backwardLogProb);
402 scales = exp(logScales);
404 return loglikelihood;
411 template<
typename Distribution>
413 arma::mat& stateProb)
const 416 arma::mat stateLogProb;
417 arma::mat forwardLogProb;
418 arma::mat backwardLogProb;
421 const double loglikelihood = LogEstimate(dataSeq, stateLogProb,
422 forwardLogProb, backwardLogProb, logScales);
424 stateProb = exp(stateLogProb);
426 return loglikelihood;
434 template<
typename Distribution>
436 arma::mat& dataSequence,
437 arma::Row<size_t>& stateSequence,
438 const size_t startState)
const 441 stateSequence.set_size(length);
442 dataSequence.set_size(dimensionality, length);
445 stateSequence[0] = startState;
452 dataSequence.col(0) = emission[startState].Random();
457 for (
size_t t = 1; t < length; t++)
465 for (
size_t st = 0; st < logTransition.n_rows; st++)
467 probSum += exp(logTransition(st, stateSequence[t - 1]));
468 if (randValue <= probSum)
470 stateSequence[t] = st;
476 dataSequence.col(t) = emission[stateSequence[t]].Random();
485 template<
typename Distribution>
487 arma::Row<size_t>& stateSeq)
const 493 stateSeq.set_size(dataSeq.n_cols);
494 arma::mat logStateProb(logTransition.n_rows, dataSeq.n_cols);
495 arma::mat stateSeqBack(logTransition.n_rows, dataSeq.n_cols);
502 logStateProb.col(0).zeros();
503 for (
size_t state = 0; state < logTransition.n_rows; state++)
505 logStateProb(state, 0) = logInitial[state] +
506 emission[state].LogProbability(dataSeq.unsafe_col(0));
507 stateSeqBack(state, 0) = state;
514 arma::mat logProbs(dataSeq.n_cols, logTransition.n_rows);
517 for (
size_t i = 0; i < logTransition.n_rows; i++)
520 arma::vec alias(logProbs.colptr(i), logProbs.n_rows,
false,
true);
522 emission[i].LogProbability(dataSeq, alias);
525 for (
size_t t = 1; t < dataSeq.n_cols; t++)
530 for (
size_t j = 0; j < logTransition.n_rows; j++)
532 arma::vec prob = logStateProb.col(t - 1) + logTransition.row(j).t();
533 logStateProb(j, t) = prob.max(index) + logProbs(t, j);
534 stateSeqBack(j, t) = index;
539 logStateProb.unsafe_col(dataSeq.n_cols - 1).max(index);
540 stateSeq[dataSeq.n_cols - 1] = index;
541 for (
size_t t = 2; t <= dataSeq.n_cols; t++)
543 stateSeq[dataSeq.n_cols - t] =
544 stateSeqBack(stateSeq[dataSeq.n_cols - t + 1], dataSeq.n_cols - t + 1);
547 return logStateProb(stateSeq(dataSeq.n_cols - 1), dataSeq.n_cols - 1);
553 template<
typename Distribution>
556 arma::mat forwardLog;
560 arma::mat logProbs(dataSeq.n_cols, logTransition.n_rows);
563 for (
size_t i = 0; i < logTransition.n_rows; i++)
566 arma::vec alias(logProbs.colptr(i), logProbs.n_rows,
false,
true);
568 emission[i].LogProbability(dataSeq, alias);
571 Forward(dataSeq, logScales, forwardLog, logProbs);
574 return accu(logScales);
582 template<
typename Distribution>
584 const arma::vec& emissionLogProb,
585 arma::vec& forwardLogProb)
const 588 if (forwardLogProb.empty())
591 forwardLogProb = ForwardAtT0(emissionLogProb, curLogScale);
595 forwardLogProb = ForwardAtTn(emissionLogProb, curLogScale,
605 template<
typename Distribution>
607 const arma::vec& emissionLogProb,
608 double& logLikelihood,
609 arma::vec& forwardLogProb)
const 611 bool isStartOfSeq = forwardLogProb.empty();
612 double curLogScale = EmissionLogScaleFactor(emissionLogProb, forwardLogProb);
613 logLikelihood = isStartOfSeq ? curLogScale : curLogScale + logLikelihood;
614 return logLikelihood;
622 template<
typename Distribution>
624 arma::vec& forwardLogProb)
const 626 arma::vec emissionLogProb(logTransition.n_rows);
628 for (
size_t state = 0; state < logTransition.n_rows; state++)
630 emissionLogProb(state) = emission[state].LogProbability(data);
633 return EmissionLogScaleFactor(emissionLogProb, forwardLogProb);
639 template<
typename Distribution>
641 double& logLikelihood,
642 arma::vec& forwardLogProb)
const 644 bool isStartOfSeq = forwardLogProb.empty();
645 double curLogScale = LogScaleFactor(data, forwardLogProb);
646 logLikelihood = isStartOfSeq ? curLogScale : curLogScale + logLikelihood;
647 return logLikelihood;
653 template<
typename Distribution>
655 arma::mat& filterSeq,
659 arma::mat forwardLogProb;
662 arma::mat logProbs(dataSeq.n_cols, logTransition.n_rows);
665 for (
size_t i = 0; i < logTransition.n_rows; i++)
668 arma::vec alias(logProbs.colptr(i), logProbs.n_rows,
false,
true);
670 emission[i].LogProbability(dataSeq, alias);
673 Forward(dataSeq, logScales, forwardLogProb, logProbs);
677 forwardLogProb += ahead * logTransition;
679 arma::mat forwardProb = exp(forwardLogProb);
683 filterSeq.zeros(dimensionality, dataSeq.n_cols);
684 for (
size_t i = 0; i < emission.size(); i++)
685 filterSeq += emission[i].Mean() * forwardProb.row(i);
691 template<
typename Distribution>
693 arma::mat& smoothSeq)
const 696 arma::mat stateLogProb;
697 arma::mat forwardLogProb;
698 arma::mat backwardLogProb;
700 LogEstimate(dataSeq, stateLogProb, forwardLogProb, backwardLogProb,
705 smoothSeq.zeros(dimensionality, dataSeq.n_cols);
706 for (
size_t i = 0; i < emission.size(); i++)
707 smoothSeq += emission[i].Mean() * exp(stateLogProb.row(i));
713 template<
typename Distribution>
715 double& logScales)
const 726 arma::vec forwardLogProb = logInitial + emissionLogProb;
730 if (std::isfinite(logScales))
731 forwardLogProb -= logScales;
733 return forwardLogProb;
739 template<
typename Distribution>
742 const arma::vec& prevForwardLogProb)
752 arma::vec forwardLogProb;
753 arma::mat tmp = logTransition + repmat(prevForwardLogProb.t(),
754 logTransition.n_rows, 1);
756 forwardLogProb += emissionLogProb;
760 if (std::isfinite(logScales))
761 forwardLogProb -= logScales;
763 return forwardLogProb;
769 template<
typename Distribution>
771 arma::vec& logScales,
772 arma::mat& forwardLogProb,
773 arma::mat& logProbs)
const 777 forwardLogProb.resize(logTransition.n_rows, dataSeq.n_cols);
778 forwardLogProb.fill(-std::numeric_limits<double>::infinity());
779 logScales.resize(dataSeq.n_cols);
780 logScales.fill(-std::numeric_limits<double>::infinity());
788 forwardLogProb.col(0) = ForwardAtT0(logProbs.row(0).t(), logScales(0));
791 for (
size_t t = 1; t < dataSeq.n_cols; t++)
793 forwardLogProb.col(t) = ForwardAtTn(logProbs.row(t).t(), logScales(t),
794 forwardLogProb.col(t - 1));
798 template<
typename Distribution>
800 const arma::vec& logScales,
801 arma::mat& backwardLogProb,
802 arma::mat& logProbs)
const 806 backwardLogProb.resize(logTransition.n_rows, dataSeq.n_cols);
807 backwardLogProb.fill(-std::numeric_limits<double>::infinity());
810 backwardLogProb.col(dataSeq.n_cols - 1).fill(0);
813 for (
size_t t = dataSeq.n_cols - 2; t + 1 > 0; t--)
820 const arma::mat tmp = logTransition +
821 repmat(backwardLogProb.col(t + 1), 1, logTransition.n_cols) +
822 repmat(logProbs.row(t + 1).t(), 1, logTransition.n_cols);
823 arma::vec alias = backwardLogProb.unsafe_col(t);
824 math::LogSumExpT<arma::mat, true>(tmp, alias);
827 if (std::isfinite(logScales[t + 1]))
828 backwardLogProb.col(t) -= logScales[t + 1];
836 template<
typename Distribution>
839 if (recalculateInitial)
841 logInitial = log(initialProxy);
842 recalculateInitial =
false;
845 if (recalculateTransition)
847 logTransition = log(transitionProxy);
848 recalculateTransition =
false;
853 template<
typename Distribution>
854 template<
typename Archive>
857 arma::mat transition;
859 ar(CEREAL_NVP(dimensionality));
860 ar(CEREAL_NVP(tolerance));
861 ar(CEREAL_NVP(transition));
862 ar(CEREAL_NVP(initial));
866 emission.resize(transition.n_rows);
868 ar(CEREAL_NVP(emission));
870 logTransition = log(transition);
871 logInitial = log(initial);
872 initialProxy = std::move(initial);
873 transitionProxy = std::move(transition);
877 template<
typename Distribution>
878 template<
typename Archive>
880 const uint32_t )
const 882 arma::mat transition = exp(logTransition);
883 arma::vec initial = exp(logInitial);
884 ar(CEREAL_NVP(dimensionality));
885 ar(CEREAL_NVP(tolerance));
886 ar(CEREAL_NVP(transition));
887 ar(CEREAL_NVP(initial));
888 ar(CEREAL_NVP(emission));
void LogSumExp(const T &x, arma::Col< typename T::elem_type > &y)
Compute the sum of exponentials of each element in each column, then compute the log of that...
Definition: log_add_impl.hpp:78
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Definition: hmm_train_main.cpp:300
A class that represents a Hidden Markov Model with an arbitrary type of emission distribution.
Definition: hmm.hpp:85
T::elem_type AccuLog(const T &x)
Log-sum a vector of log values.
Definition: log_add_impl.hpp:63
double Random()
Generates a uniform random number between 0 and 1.
Definition: random.hpp:83
HMM(const size_t states=0, const Distribution emissions=Distribution(), const double tolerance=1e-5)
Create the Hidden Markov Model with the given number of hidden states and the given default distribut...
Definition: hmm_impl.hpp:29