mlpack
hmm_util_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_HMM_HMM_UTIL_IMPL_HPP
13 #define MLPACK_METHODS_HMM_HMM_UTIL_IMPL_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 
20 
21 namespace mlpack {
22 namespace hmm {
23 
// Forward declarations of utility functions.

/**
 * Open modelFile with the input archive type ArchiveType, determine which HMM
 * type it holds, and run ActionType::Apply() on the deserialized model.
 *
 * @param modelFile Path of the model file to load.
 * @param x Extra information forwarded to ActionType::Apply().
 */
template<typename ActionType, typename ArchiveType, typename ExtraInfoType>
void LoadHMMAndPerformActionHelper(const std::string& modelFile,
                                   ExtraInfoType* x = nullptr);
30 
/**
 * Deserialize an HMM of the concrete type HMMType from an already-opened
 * archive and run ActionType::Apply() on it.
 *
 * @param ar Input archive positioned at the serialized HMM.
 * @param x Extra information forwarded to ActionType::Apply().
 */
template<typename ActionType,
         typename ArchiveType,
         typename HMMType,
         typename ExtraInfoType>
void DeserializeHMMAndPerformAction(ArchiveType& ar,
                                    ExtraInfoType* x = nullptr);
37 
38 template<typename ActionType, typename ExtraInfoType>
39 void LoadHMMAndPerformAction(const std::string& modelFile,
40  ExtraInfoType* x)
41 {
42  const std::string extension = data::Extension(modelFile);
43  if (extension == "xml")
44  {
45  LoadHMMAndPerformActionHelper<ActionType, cereal::XMLInputArchive>(
46  modelFile, x);
47  }
48  else if (extension == "bin")
49  {
50  LoadHMMAndPerformActionHelper<ActionType, cereal::BinaryInputArchive>(
51  modelFile, x);
52  }
53  else if (extension == "json")
54  {
55  LoadHMMAndPerformActionHelper<ActionType, cereal::JSONInputArchive>(
56  modelFile, x);
57  }
58  else
59  {
60  Log::Fatal << "Unknown extension '" << extension << "' for HMM model file "
61  << "(known: 'xml', 'json', 'bin')." << std::endl;
62  }
63 }
64 
65 template<typename ActionType,
66  typename ArchiveType,
67  typename ExtraInfoType>
68 void LoadHMMAndPerformActionHelper(const std::string& modelFile,
69  ExtraInfoType* x)
70 {
71  std::ifstream ifs(modelFile);
72  if (ifs.fail())
73  Log::Fatal << "Cannot open model file '" << modelFile << "' for loading!"
74  << std::endl;
75  ArchiveType ar(ifs);
76 
77  // Read in the unsigned integer that denotes the type of the model.
78  char type;
79  ar(CEREAL_NVP(type));
80 
81  using namespace mlpack::distribution;
82 
83  switch (type)
84  {
85  case HMMType::DiscreteHMM:
86  DeserializeHMMAndPerformAction<ActionType, ArchiveType,
88  break;
89 
90  case HMMType::GaussianHMM:
91  DeserializeHMMAndPerformAction<ActionType, ArchiveType,
93  break;
94 
95  case HMMType::GaussianMixtureModelHMM:
96  DeserializeHMMAndPerformAction<ActionType, ArchiveType,
97  HMM<gmm::GMM>>(ar, x);
98  break;
99 
100  case HMMType::DiagonalGaussianMixtureModelHMM:
101  DeserializeHMMAndPerformAction<ActionType, ArchiveType,
102  HMM<gmm::DiagonalGMM>>(ar, x);
103 
104  default:
105  Log::Fatal << "Unknown HMM type '" << (unsigned int) type << "'!"
106  << std::endl;
107  }
108 }
109 
/**
 * Deserialize an HMM of the concrete type HMMType from ar and pass it, together
 * with x, to ActionType::Apply().
 *
 * @param ar Input archive positioned just after the HMM type tag.
 * @param x Extra information forwarded to ActionType::Apply().
 */
template<class ActionType,
         class ArchiveType,
         class HMMType,
         class ExtraInfoType>
void DeserializeHMMAndPerformAction(ArchiveType& ar, ExtraInfoType* x)
{
  // The local is deliberately named `hmm`.  CEREAL_NVP() derives the archive
  // entry name from the identifier, so renaming it would change which entry
  // is looked up in XML/JSON archives.
  HMMType hmm;
  ar(CEREAL_NVP(hmm));

  // Hand the freshly loaded model to the caller-supplied action.
  ActionType::Apply(hmm, x);
}
121 
// Forward declarations of the saving helpers.

// Open modelFile with output archive type ArchiveType and write the HMM's type
// tag followed by the HMM itself.
template<class ArchiveType, class HMMType>
void SaveHMMHelper(HMMType& hmm, const std::string& modelFile);

// Map an HMM type to the char tag stored in front of a saved model.
template<class HMMType>
char GetHMMType();
128 
129 template<typename HMMType>
130 void SaveHMM(HMMType& hmm, const std::string& modelFile)
131 {
132  const std::string extension = data::Extension(modelFile);
133  if (extension == "xml")
134  SaveHMMHelper<cereal::XMLOutputArchive>(hmm, modelFile);
135  else if (extension == "bin")
136  SaveHMMHelper<cereal::BinaryOutputArchive>(hmm, modelFile);
137  else if (extension == "json")
138  SaveHMMHelper<cereal::JSONOutputArchive>(hmm, modelFile);
139  else
140  Log::Fatal << "Unknown extension '" << extension << "' for HMM model file."
141  << std::endl;
142 }
143 
144 template<typename ArchiveType, typename HMMType>
145 void SaveHMMHelper(HMMType& hmm, const std::string& modelFile)
146 {
147  std::ofstream ofs(modelFile);
148  if (ofs.fail())
149  Log::Fatal << "Cannot open model file '" << modelFile << "' for saving!"
150  << std::endl;
151  ArchiveType ar(ofs);
152 
153  // Write out the unsigned integer that denotes the type of the model.
154  char type = GetHMMType<HMMType>();
155  if (type == char(-1))
156  Log::Fatal << "Unknown HMM type given to SaveHMM()!" << std::endl;
157 
158  ar(CEREAL_NVP(type));
159  ar(CEREAL_NVP(hmm));
160 }
161 
// Map an HMM type to the char tag written in front of a saved model.  Types
// without an explicit specialization yield the sentinel char(-1), which
// SaveHMMHelper() rejects.
template<class HMMType>
char GetHMMType()
{
  return static_cast<char>(-1);
}
165 
166 template<>
167 char GetHMMType<HMM<distribution::DiscreteDistribution>>()
168 {
169  return HMMType::DiscreteHMM;
170 }
171 
172 template<>
173 char GetHMMType<HMM<distribution::GaussianDistribution>>()
174 {
175  return HMMType::GaussianHMM;
176 }
177 
178 template<>
179 char GetHMMType<HMM<gmm::GMM>>()
180 {
181  return HMMType::GaussianMixtureModelHMM;
182 }
183 
184 template<>
185 char GetHMMType<HMM<gmm::DiagonalGMM>>()
186 {
187  return HMMType::DiagonalGaussianMixtureModelHMM;
188 }
189 
190 } // namespace hmm
191 } // namespace mlpack
192 
193 #endif
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.
void LoadHMMAndPerformAction(const std::string &modelFile, ExtraInfoType *x=NULL)
ActionType should implement static void Apply(HMMType&, ExtraInfoType*).
Definition: hmm_util_impl.hpp:39
Probability distributions.
Definition: diagonal_gaussian_distribution.hpp:18
A class that represents a Hidden Markov Model with an arbitrary type of emission distribution.
Definition: hmm.hpp:85
void SaveHMM(HMMType &hmm, const std::string &modelFile)
Save an HMM to a file.
Definition: hmm_util_impl.hpp:130