mlpack
copy_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_AUGMENTED_TASKS_COPY_IMPL_HPP
13 #define MLPACK_METHODS_AUGMENTED_TASKS_COPY_IMPL_HPP
14 
15 // In case it hasn't been included yet.
16 #include "copy.hpp"
17 
18 namespace mlpack {
19 namespace ann /* Artificial Neural Network */ {
20 namespace augmented /* Augmented neural network */ {
21 namespace tasks /* Task utilities for augmented */ {
22 
23 CopyTask::CopyTask(const size_t maxLength,
24  const size_t nRepeats,
25  const bool addSeparator) :
26  maxLength(maxLength),
27  nRepeats(nRepeats),
28  addSeparator(addSeparator)
29 {
30  if (maxLength <= 1)
31  {
32  std::ostringstream oss;
33  oss << "CopyTask::CopyTask(): maximum sequence length ("
34  << maxLength << ") "
35  << "should be at least 2!"
36  << std::endl;
37  throw std::invalid_argument(oss.str());
38  }
39  if (nRepeats <= 0)
40  {
41  std::ostringstream oss;
42  oss << "CopyTask::CopyTask(): repetition count (" << nRepeats << ") "
43  << "is not positive!"
44  << std::endl;
45  throw std::invalid_argument(oss.str());
46  }
47  // Just storing task-specific parameters.
48 }
49 
50 void CopyTask::Generate(arma::field<arma::mat>& input,
51  arma::field<arma::mat>& labels,
52  const size_t batchSize,
53  bool fixedLength) const
54 {
55  input = arma::field<arma::mat>(batchSize);
56  labels = arma::field<arma::mat>(batchSize);
57  size_t size = maxLength;
58  for (size_t i = 0; i < batchSize; ++i)
59  {
60  if (!fixedLength)
61  {
62  arma::vec weights(maxLength - 1);
63 
65  // We have two binary numbers with exactly two digits (10 and 11).
66  // Increasing length by 1 double the number of valid numbers.
67  d.Probabilities(0) =
68  arma::exp2(arma::linspace(1, maxLength - 1, maxLength - 1));
69 
70  size = 2 + d.Random()(0);
71  }
72  arma::colvec vecInput = arma::randi<arma::colvec>(
73  size, arma::distr_param(0, 1));
74  arma::colvec vecLabel = arma::conv_to<arma::colvec>::from(
75  arma::repmat(vecInput, nRepeats, 1));
76  size_t totSize = vecInput.n_elem + addSeparator + vecLabel.n_elem;
77  input(i) = arma::zeros(totSize, 2);
78  input(i).col(0).rows(0, vecInput.n_elem - 1) =
79  vecInput;
80  if (addSeparator)
81  input(i).at(vecInput.n_elem, 0) = 0.5;
82  input(i).col(1).rows(addSeparator + vecInput.n_elem, totSize - 1) =
83  arma::ones(totSize-vecInput.n_elem - addSeparator);
84  input(i) = input(i).t();
85  input(i).reshape(input(i).n_elem, 1);
86  labels(i) = arma::zeros(totSize, 1);
87  labels(i).col(0).rows(addSeparator + vecInput.n_elem, totSize - 1) =
88  vecLabel;
89  }
90 }
91 
92 void CopyTask::Generate(arma::mat& input,
93  arma::mat& labels,
94  const size_t batchSize) const
95 {
96  arma::field<arma::mat> fieldInput, fieldLabels;
97  Generate(fieldInput, fieldLabels, batchSize, true);
98  size_t cols = batchSize;
99  input = arma::zeros(fieldInput(0).n_rows, cols);
100  labels = arma::zeros(fieldLabels(0).n_rows, cols);
101  for (size_t i = 0; i < cols; ++i)
102  {
103  input.col(i) = fieldInput.at(i);
104  labels.col(i) = fieldLabels.at(i);
105  }
106 }
107 
108 
109 } // namespace tasks
110 } // namespace augmented
111 } // namespace ann
112 } // namespace mlpack
113 
114 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
A discrete distribution where the only observations are discrete observations.
Definition: discrete_distribution.hpp:45
arma::vec Random() const
Return a randomly generated observation (one-dimensional vector; one observation) according to the probability distribution defined by this object.
Definition: discrete_distribution.cpp:22
arma::vec & Probabilities(const size_t dim=0)
Return the vector of probabilities for the given dimension.
Definition: discrete_distribution.hpp:232
CopyTask(const size_t maxLength, const size_t nRepeats, const bool addSeparator=false)
Creates an instance of the sequence copy task.
Definition: copy_impl.hpp:23
void Generate(arma::field< arma::mat > &input, arma::field< arma::mat > &labels, const size_t batchSize, bool fixedLength=false) const
Generate dataset of a given size.
Definition: copy_impl.hpp:50