mlpack
save_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_CORE_DATA_SAVE_IMPL_HPP
13 #define MLPACK_CORE_DATA_SAVE_IMPL_HPP
14 
15 // In case it hasn't already been included.
16 #include "save.hpp"
17 #include "extension.hpp"
18 #include "detect_file_type.hpp"
19 
20 #include <cereal/archives/xml.hpp>
21 #include <cereal/archives/json.hpp>
22 #include <cereal/archives/binary.hpp>
23 
24 namespace mlpack {
25 namespace data {
26 
27 template<typename eT>
28 bool Save(const std::string& filename,
29  const arma::Col<eT>& vec,
30  const bool fatal,
31  arma::file_type inputSaveType)
32 {
33  // Don't transpose: one observation per line (for CSVs at least).
34  return Save(filename, vec, fatal, false, inputSaveType);
35 }
36 
37 template<typename eT>
38 bool Save(const std::string& filename,
39  const arma::Row<eT>& rowvec,
40  const bool fatal,
41  arma::file_type inputSaveType)
42 {
43  return Save(filename, rowvec, fatal, true, inputSaveType);
44 }
45 
46 template<typename eT>
47 bool Save(const std::string& filename,
48  const arma::Mat<eT>& matrix,
49  const bool fatal,
50  bool transpose,
51  arma::file_type inputSaveType)
52 {
53  Timer::Start("saving_data");
54 
55  arma::file_type saveType = inputSaveType;
56  std::string stringType = "";
57 
58  if (inputSaveType == arma::auto_detect)
59  {
60  // Detect the file type using only the extension.
61  saveType = DetectFromExtension(filename);
62  if (saveType == arma::file_type_unknown)
63  {
64  if (fatal)
65  Log::Fatal << "Could not detect type of file '" << filename << "' for "
66  << "writing. Save failed." << std::endl;
67  else
68  Log::Warn << "Could not detect type of file '" << filename << "' for "
69  << "writing. Save failed." << std::endl;
70 
71  return false;
72  }
73  }
74 
75  stringType = GetStringType(saveType);
76 
77  // Catch errors opening the file.
78  std::fstream stream;
79 #ifdef _WIN32 // Always open in binary mode on Windows.
80  stream.open(filename.c_str(), std::fstream::out | std::fstream::binary);
81 #else
82  stream.open(filename.c_str(), std::fstream::out);
83 #endif
84  if (!stream.is_open())
85  {
86  Timer::Stop("saving_data");
87  if (fatal)
88  Log::Fatal << "Cannot open file '" << filename << "' for writing. "
89  << "Save failed." << std::endl;
90  else
91  Log::Warn << "Cannot open file '" << filename << "' for writing; save "
92  << "failed." << std::endl;
93 
94  return false;
95  }
96 
97  // Try to save the file.
98  Log::Info << "Saving " << stringType << " to '" << filename << "'."
99  << std::endl;
100 
101  // Transpose the matrix.
102  if (transpose)
103  {
104  arma::Mat<eT> tmp = trans(matrix);
105 
106 #ifdef ARMA_USE_HDF5
107  // We can't save with streams for HDF5.
108  const bool success = (saveType == arma::hdf5_binary) ?
109  tmp.quiet_save(filename, saveType) :
110  tmp.quiet_save(stream, saveType);
111 #else
112  const bool success = tmp.quiet_save(stream, saveType);
113 #endif
114  if (!success)
115  {
116  Timer::Stop("saving_data");
117  if (fatal)
118  Log::Fatal << "Save to '" << filename << "' failed." << std::endl;
119  else
120  Log::Warn << "Save to '" << filename << "' failed." << std::endl;
121 
122  return false;
123  }
124  }
125  else
126  {
127 #ifdef ARMA_USE_HDF5
128  // We can't save with streams for HDF5.
129  const bool success = (saveType == arma::hdf5_binary) ?
130  matrix.quiet_save(filename, saveType) :
131  matrix.quiet_save(stream, saveType);
132 #else
133  const bool success = matrix.quiet_save(stream, saveType);
134 #endif
135  if (!success)
136  {
137  Timer::Stop("saving_data");
138  if (fatal)
139  Log::Fatal << "Save to '" << filename << "' failed." << std::endl;
140  else
141  Log::Warn << "Save to '" << filename << "' failed." << std::endl;
142 
143  return false;
144  }
145  }
146 
147  Timer::Stop("saving_data");
148 
149  // Finally return success.
150  return true;
151 }
152 
153 // Save a Sparse Matrix
154 template<typename eT>
155 bool Save(const std::string& filename,
156  const arma::SpMat<eT>& matrix,
157  const bool fatal,
158  bool transpose)
159 {
160  Timer::Start("saving_data");
161 
162  // First we will try to discriminate by file extension.
163  std::string extension = Extension(filename);
164  if (extension == "")
165  {
166  Timer::Stop("saving_data");
167  if (fatal)
168  Log::Fatal << "No extension given with filename '" << filename << "'; "
169  << "type unknown. Save failed." << std::endl;
170  else
171  Log::Warn << "No extension given with filename '" << filename << "'; "
172  << "type unknown. Save failed." << std::endl;
173 
174  return false;
175  }
176 
177  // Catch errors opening the file.
178  std::fstream stream;
179 #ifdef _WIN32 // Always open in binary mode on Windows.
180  stream.open(filename.c_str(), std::fstream::out | std::fstream::binary);
181 #else
182  stream.open(filename.c_str(), std::fstream::out);
183 #endif
184  if (!stream.is_open())
185  {
186  Timer::Stop("saving_data");
187  if (fatal)
188  Log::Fatal << "Cannot open file '" << filename << "' for writing. "
189  << "Save failed." << std::endl;
190  else
191  Log::Warn << "Cannot open file '" << filename << "' for writing; save "
192  << "failed." << std::endl;
193 
194  return false;
195  }
196 
197  bool unknownType = false;
198  arma::file_type saveType;
199  std::string stringType;
200 
201  if (extension == "txt" || extension == "tsv")
202  {
203  saveType = arma::coord_ascii;
204  stringType = "raw ASCII formatted data";
205  }
206  else if (extension == "bin")
207  {
208  saveType = arma::arma_binary;
209  stringType = "Armadillo binary formatted data";
210  }
211  else
212  {
213  unknownType = true;
214  saveType = arma::raw_binary; // Won't be used; prevent a warning.
215  stringType = "";
216  }
217 
218  // Provide error if we don't know the type.
219  if (unknownType)
220  {
221  Timer::Stop("saving_data");
222  if (fatal)
223  Log::Fatal << "Unable to determine format to save to from filename '"
224  << filename << "'. Save failed." << std::endl;
225  else
226  Log::Warn << "Unable to determine format to save to from filename '"
227  << filename << "'. Save failed." << std::endl;
228 
229  return false;
230  }
231 
232  // Try to save the file.
233  Log::Info << "Saving " << stringType << " to '" << filename << "'."
234  << std::endl;
235 
236  arma::SpMat<eT> tmp = matrix;
237 
238  // Transpose the matrix.
239  if (transpose)
240  {
241  tmp = trans(matrix);
242  }
243 
244  const bool success = tmp.quiet_save(stream, saveType);
245  if (!success)
246  {
247  Timer::Stop("saving_data");
248  if (fatal)
249  Log::Fatal << "Save to '" << filename << "' failed." << std::endl;
250  else
251  Log::Warn << "Save to '" << filename << "' failed." << std::endl;
252 
253  return false;
254  }
255 
256  Timer::Stop("saving_data");
257 
258  // Finally return success.
259  return true;
260 }
261 
263 template<typename T>
264 bool Save(const std::string& filename,
265  const std::string& name,
266  T& t,
267  const bool fatal,
268  format f)
269 {
270  if (f == format::autodetect)
271  {
272  std::string extension = Extension(filename);
273 
274  if (extension == "xml")
275  f = format::xml;
276  else if (extension == "bin")
277  f = format::binary;
278  else if (extension == "json")
279  f = format::json;
280  else
281  {
282  if (fatal)
283  Log::Fatal << "Unable to detect type of '" << filename << "'; incorrect"
284  << " extension? (allowed: xml/bin/json)" << std::endl;
285  else
286  Log::Warn << "Unable to detect type of '" << filename << "'; save "
287  << "failed. Incorrect extension? (allowed: xml/bin/json)"
288  << std::endl;
289 
290  return false;
291  }
292  }
293 
294  // Open the file to save to.
295  std::ofstream ofs;
296 #ifdef _WIN32
297  if (f == format::binary) // Open non-text types in binary mode on Windows.
298  ofs.open(filename, std::ofstream::out | std::ofstream::binary);
299  else
300  ofs.open(filename, std::ofstream::out);
301 #else
302  ofs.open(filename, std::ofstream::out);
303 #endif
304 
305  if (!ofs.is_open())
306  {
307  if (fatal)
308  Log::Fatal << "Unable to open file '" << filename << "' to save object '"
309  << name << "'." << std::endl;
310  else
311  Log::Warn << "Unable to open file '" << filename << "' to save object '"
312  << name << "'." << std::endl;
313 
314  return false;
315  }
316 
317  try
318  {
319  if (f == format::xml)
320  {
321  cereal::XMLOutputArchive ar(ofs);
322  ar(cereal::make_nvp(name.c_str(), t));
323  }
324  else if (f == format::json)
325  {
326  cereal::JSONOutputArchive ar(ofs);
327  ar(cereal::make_nvp(name.c_str(), t));
328  }
329  else if (f == format::binary)
330  {
331  cereal::BinaryOutputArchive ar(ofs);
332  ar(cereal::make_nvp(name.c_str(), t));
333  }
334 
335  return true;
336  }
337  catch (cereal::Exception& e)
338  {
339  if (fatal)
340  Log::Fatal << e.what() << std::endl;
341  else
342  Log::Warn << e.what() << std::endl;
343 
344  return false;
345  }
346 }
347 
356 template<typename eT>
357 bool Save(const std::string& filename,
358  arma::Mat<eT>& matrix,
359  ImageInfo& info,
360  const bool fatal)
361 {
362  arma::Mat<unsigned char> tmpMatrix =
363  arma::conv_to<arma::Mat<unsigned char>>::from(matrix);
364 
365  // Call out to .cpp implementation.
366  return SaveImage(filename, tmpMatrix, info, fatal);
367 }
368 
369 // Image saving API for multiple files.
370 template<typename eT>
371 bool Save(const std::vector<std::string>& files,
372  arma::Mat<eT>& matrix,
373  ImageInfo& info,
374  const bool fatal)
375 {
376  if (files.size() == 0)
377  {
378  if (fatal)
379  {
380  Log::Fatal << "Save(): vector of image files is empty; nothing to save."
381  << std::endl;
382  }
383  else
384  {
385  Log::Warn << "Save(): vector of image files is empty; nothing to save."
386  << std::endl;
387  }
388 
389  return false;
390  }
391 
392  arma::Mat<unsigned char> img;
393  bool status = true;
394 
395  for (size_t i = 0; i < files.size() ; ++i)
396  {
397  arma::Mat<eT> colImg(matrix.colptr(i), matrix.n_rows, 1,
398  false, true);
399  status &= Save(files[i], colImg, info, fatal);
400  }
401 
402  return status;
403 }
404 
405 } // namespace data
406 } // namespace mlpack
407 
408 #endif
static void Start(const std::string &name)
Start the given timer.
Definition: timers.cpp:28
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Implements meta-data of images required by data::Load and data::Save for loading and saving images in...
Definition: image_info.hpp:36
bool SaveImage(const std::string &filename, arma::Mat< unsigned char > &image, ImageInfo &info, const bool fatal=false)
Helper function to save files.
Definition: save_image.cpp:130
std::string GetStringType(const arma::file_type &type)
Given a file type, return a logical name corresponding to that file type.
Definition: detect_file_type.cpp:30
static MLPACK_EXPORT util::PrefixedOutStream Warn
Prints warning messages prefixed with [WARN ].
Definition: log.hpp:87
static void Stop(const std::string &name)
Stop the given timer.
Definition: timers.cpp:36
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84
format
Define the formats we can read through cereal.
Definition: format.hpp:20
bool Save(const std::string &filename, const arma::Mat< eT > &matrix, const bool fatal=false, bool transpose=true, arma::file_type inputSaveType=arma::auto_detect)
Saves a matrix to file, guessing the filetype from the extension.
Definition: save_impl.hpp:47
arma::file_type DetectFromExtension(const std::string &filename)
Return the type based only on the extension.
Definition: detect_file_type.cpp:310