mlpack
add_impl.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_AUGMENTED_TASKS_ADD_IMPL_HPP
14 #define MLPACK_METHODS_AUGMENTED_TASKS_ADD_IMPL_HPP
15 
16 #include "add.hpp"
17 
18 namespace mlpack {
19 namespace ann /* Artificial Neural Network */ {
20 namespace augmented /* Augmented neural network */ {
21 namespace tasks /* Task utilities for augmented */ {
22 
23 AddTask::AddTask(const size_t bitLen) : bitLen(bitLen)
24 {
25  if (bitLen <= 0)
26  {
27  std::ostringstream oss;
28  oss << "AddTask::AddTask(): binary length (" << bitLen << ") "
29  << "is not positive!"
30  << std::endl;
31  throw std::invalid_argument(oss.str());
32  }
33 }
34 
35 void AddTask::Generate(arma::field<arma::mat>& input,
36  arma::field<arma::mat>& labels,
37  const size_t batchSize,
38  bool fixedLength) const
39 {
40  arma::field<arma::vec> vecInput = arma::field<arma::colvec>(batchSize);
41  arma::field<arma::vec> vecLabels = arma::field<arma::colvec>(batchSize);
42  size_t sizeA = bitLen, sizeB = bitLen;
43  for (size_t i = 0; i < batchSize; ++i)
44  {
45  if (!fixedLength)
46  {
47  arma::vec weights(bitLen - 1);
48  weights = arma::exp2(arma::linspace(1, bitLen - 1, bitLen - 1));
49 
51  // We have two binary numbers with exactly two digits (10 and 11).
52  // Increasing length by 1 double the number of valid numbers.
53  d.Probabilities(0) = arma::exp2(
54  arma::linspace(1, bitLen - 1, bitLen - 1));
55 
56  sizeA = 2 + d.Random()(0);
57  sizeB = 2 + d.Random()(0);
58  }
59  // Construct sequence of the form
60  // (binary number with sizeA bits) + '+' + (binary number with sizeB bits).
61  vecInput(i) = arma::randi<arma::colvec>(
62  sizeA + sizeB + 1, arma::distr_param(0, 1));
63  // Insert special value for '+' delimiter.
64  vecInput(i).at(sizeA) = 2;
65 
66  int valA = 0;
67  for (size_t k = 0; k < sizeA; ++k)
68  {
69  valA += static_cast<int>(vecInput(i).at(k)) << k;
70  }
71 
72  int valB = 0;
73  for (size_t k = sizeA + 1; k < sizeA + 1 + sizeB; ++k)
74  {
75  valB += static_cast<int>(vecInput(i).at(k)) << (k - sizeA - 1);
76  }
77 
78  int tot = valA + valB;
79  std::vector<int> binarySeq;
80  while (tot > 0)
81  {
82  binarySeq.push_back(tot & 1);
83  tot >>= 1;
84  }
85  if (binarySeq.empty())
86  {
87  if (valA + valB != 0)
88  {
89  std::ostringstream oss;
90  oss << "AddTask::Generate(): output sequence is empty "
91  << "but the target sum is not 0 (=" << valA + valB << ")"
92  << std::endl;
93  throw std::domain_error(oss.str());
94  }
95  binarySeq.push_back(0);
96  }
97  vecLabels(i) = arma::colvec(binarySeq.size());
98  for (size_t j = 0; j < binarySeq.size(); ++j)
99  {
100  vecLabels(i).at(j) = binarySeq[j];
101  }
102  }
103  Binarize(vecInput, input);
104  Binarize(vecLabels, labels);
105  if (input.n_rows != labels.n_rows)
106  {
107  std::ostringstream oss;
108  oss << "AddTask::Generate(): sequences after application of "
109  << "Binarize() are not aligned ("
110  << input.n_rows << " and " << labels.n_rows << ")"
111  << std::endl;
112  throw std::logic_error(oss.str());
113  }
114  for (size_t i = 0; i < input.n_rows; ++i)
115  {
116  labels.at(i).reshape(input.at(i).n_elem, 1);
117  }
118 }
119 
120 void AddTask::Generate(arma::mat& input,
121  arma::mat& labels,
122  const size_t batchSize) const
123 {
124  arma::field<arma::mat> fieldInput, fieldLabels;
125  Generate(fieldInput, fieldLabels, batchSize, true);
126  input.set_size(fieldInput(0).n_rows, batchSize);
127  labels.set_size(fieldLabels(0).n_rows, batchSize);
128  for (size_t i = 0; i < batchSize; ++i)
129  {
130  input.col(i) = fieldInput.at(i);
131  labels.col(i) = fieldLabels.at(i);
132  }
133 }
134 
135 void AddTask::Binarize(const arma::field<arma::vec>& input,
136  arma::field<arma::mat>& output) const
137 {
138  output = arma::field<arma::mat>(input.n_elem);
139  for (size_t i = 0; i < input.n_elem; ++i)
140  {
141  output.at(i) = arma::zeros(3, input.at(i).n_elem);
142  for (size_t j = 0; j < input.at(i).n_elem; ++j)
143  {
144  size_t val = input.at(i).at(j);
145  output.at(i).at(val, j) = 1;
146  }
147  output.at(i).reshape(output.at(i).n_elem, 1);
148  }
149 }
150 
151 
152 } // namespace tasks
153 } // namespace augmented
154 } // namespace ann
155 } // namespace mlpack
156 #endif
void Generate(arma::field< arma::mat > &input, arma::field< arma::mat > &labels, const size_t batchSize, const bool fixedLength=false) const
Generate dataset of a given size.
Definition: add_impl.hpp:35
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
A discrete distribution where the only observations are discrete observations.
Definition: discrete_distribution.hpp:45
arma::vec Random() const
Return a randomly generated observation (one-dimensional vector; one observation) according to the probability distribution defined by this object.
Definition: discrete_distribution.cpp:22
AddTask(const size_t bitLen)
Creates an instance of the binary addition task.
Definition: add_impl.hpp:23
arma::vec & Probabilities(const size_t dim=0)
Return the vector of probabilities for the given dimension.
Definition: discrete_distribution.hpp:232