mlpack
sort_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_AUGMENTED_TASKS_SORT_IMPL_HPP
13 #define MLPACK_METHODS_AUGMENTED_TASKS_SORT_IMPL_HPP
14 
#include "sort.hpp"

#include <algorithm>
#include <numeric>
16 
17 namespace mlpack {
18 namespace ann /* Artificial Neural Network */ {
19 namespace augmented /* Augmented neural network */ {
20 namespace tasks /* Task utilities for augmented */ {
21 
22 SortTask::SortTask(const size_t maxLength,
23  const size_t bitLen,
24  bool addSeparator)
25  : maxLength(maxLength), bitLen(bitLen), addSeparator(addSeparator)
26 {
27  if (maxLength <= 1)
28  {
29  std::ostringstream oss;
30  oss << "SortTask::SortTask(): maximum sequence length ("
31  << maxLength << ") "
32  << "should be at least 2!"
33  << std::endl;
34  throw std::invalid_argument(oss.str());
35  }
36  if (bitLen <= 0)
37  {
38  std::ostringstream oss;
39  oss << "SortTask::SortTask(): binary length (" << bitLen << ") "
40  << "is not positive!"
41  << std::endl;
42  throw std::invalid_argument(oss.str());
43  }
44 }
45 
46 void SortTask::Generate(arma::field<arma::mat>& input,
47  arma::field<arma::mat>& labels,
48  const size_t batchSize,
49  bool fixedLength) const
50 {
51  input = arma::field<arma::mat>(batchSize);
52  labels = arma::field<arma::mat>(batchSize);
53  size_t size = maxLength;
54  for (size_t i = 0; i < batchSize; ++i)
55  {
56  if (!fixedLength)
57  {
58  // Generate random uniform length from [2..maxLength].
59  size = mlpack::math::RandInt(2, maxLength+1);
60  }
61  input(i) = arma::randi<arma::mat>(bitLen, size, arma::distr_param(0, 1));
62  arma::mat itemAns(bitLen, size);
63  arma::colvec vals(size);
64  for (size_t j = 0; j < size; ++j)
65  {
66  int val = 0;
67  for (size_t k = 0; k < bitLen; ++k)
68  {
69  val <<= 1;
70  val += input(i).at(k, j);
71  }
72  vals[j] = val;
73  }
74  arma::uvec indices = arma::sort_index(vals);
75  for (size_t j = 0; j < size; ++j)
76  {
77  itemAns.col(j) = input(i).col(indices.at(j));
78  }
79  labels(i) = itemAns;
80  input(i).reshape(input(i).n_elem, 1);
81  if (addSeparator)
82  {
83  arma::mat sepInput = arma::zeros(input(i).n_elem + size, 1);
84  size_t ptr = 0, origPtr = 0;
85  for (size_t j = 0; j < size; ++j)
86  {
87  sepInput.rows(ptr, ptr + bitLen - 1) =
88  input(i).rows(origPtr, origPtr + bitLen - 1);
89  ptr += bitLen;
90  origPtr += bitLen;
91  sepInput.at(ptr, 0) = 0.5;
92  ++ptr;
93  }
94  input(i) = sepInput;
95  }
96  labels(i).reshape(labels(i).n_elem, 1);
97  }
98 }
99 
100 void SortTask::Generate(arma::mat& input,
101  arma::mat& labels,
102  const size_t batchSize) const
103 {
104  arma::field<arma::mat> fieldInput, fieldLabels;
105  Generate(fieldInput, fieldLabels, batchSize, true);
106  size_t inputRows = fieldInput(0).n_rows;
107  size_t labelRows = fieldLabels(0).n_rows;
108  size_t cols = batchSize;
109  input = arma::zeros(inputRows, cols);
110  labels = arma::zeros(labelRows, cols);
111  for (size_t i = 0; i < cols; ++i)
112  {
113  input.col(i) = fieldInput.at(i);
114  labels.col(i) = fieldLabels.at(i);
115  }
116 }
117 
118 } // namespace tasks
119 } // namespace augmented
120 } // namespace ann
121 } // namespace mlpack
122 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Generate(arma::field< arma::mat > &input, arma::field< arma::mat > &labels, const size_t batchSize, bool fixedLength=false) const
Generate dataset of a given size.
Definition: sort_impl.hpp:46
int RandInt(const int hiExclusive)
Generates a uniform random integer.
Definition: random.hpp:110
SortTask(const size_t maxLength, const size_t bitLen, bool addSeparator=false)
Creates an instance of the sequence sort task.
Definition: sort_impl.hpp:22