13 #ifndef MLPACK_METHODS_AUGMENTED_TASKS_ADD_IMPL_HPP 14 #define MLPACK_METHODS_AUGMENTED_TASKS_ADD_IMPL_HPP 27 std::ostringstream oss;
28 oss <<
"AddTask::AddTask(): binary length (" << bitLen <<
") " 31 throw std::invalid_argument(oss.str());
36 arma::field<arma::mat>& labels,
37 const size_t batchSize,
38 bool fixedLength)
const 40 arma::field<arma::vec> vecInput = arma::field<arma::colvec>(batchSize);
41 arma::field<arma::vec> vecLabels = arma::field<arma::colvec>(batchSize);
42 size_t sizeA = bitLen, sizeB = bitLen;
43 for (
size_t i = 0; i < batchSize; ++i)
47 arma::vec weights(bitLen - 1);
48 weights = arma::exp2(arma::linspace(1, bitLen - 1, bitLen - 1));
54 arma::linspace(1, bitLen - 1, bitLen - 1));
61 vecInput(i) = arma::randi<arma::colvec>(
62 sizeA + sizeB + 1, arma::distr_param(0, 1));
64 vecInput(i).at(sizeA) = 2;
67 for (
size_t k = 0; k < sizeA; ++k)
69 valA +=
static_cast<int>(vecInput(i).at(k)) << k;
73 for (
size_t k = sizeA + 1; k < sizeA + 1 + sizeB; ++k)
75 valB +=
static_cast<int>(vecInput(i).at(k)) << (k - sizeA - 1);
78 int tot = valA + valB;
79 std::vector<int> binarySeq;
82 binarySeq.push_back(tot & 1);
85 if (binarySeq.empty())
89 std::ostringstream oss;
90 oss <<
"AddTask::Generate(): output sequence is empty " 91 <<
"but the target sum is not 0 (=" << valA + valB <<
")" 93 throw std::domain_error(oss.str());
95 binarySeq.push_back(0);
97 vecLabels(i) = arma::colvec(binarySeq.size());
98 for (
size_t j = 0; j < binarySeq.size(); ++j)
100 vecLabels(i).at(j) = binarySeq[j];
103 Binarize(vecInput, input);
104 Binarize(vecLabels, labels);
105 if (input.n_rows != labels.n_rows)
107 std::ostringstream oss;
108 oss <<
"AddTask::Generate(): sequences after application of " 109 <<
"Binarize() are not aligned (" 110 << input.n_rows <<
" and " << labels.n_rows <<
")" 112 throw std::logic_error(oss.str());
114 for (
size_t i = 0; i < input.n_rows; ++i)
116 labels.at(i).reshape(input.at(i).n_elem, 1);
122 const size_t batchSize)
const 124 arma::field<arma::mat> fieldInput, fieldLabels;
125 Generate(fieldInput, fieldLabels, batchSize,
true);
126 input.set_size(fieldInput(0).n_rows, batchSize);
127 labels.set_size(fieldLabels(0).n_rows, batchSize);
128 for (
size_t i = 0; i < batchSize; ++i)
130 input.col(i) = fieldInput.at(i);
131 labels.col(i) = fieldLabels.at(i);
135 void AddTask::Binarize(
const arma::field<arma::vec>& input,
136 arma::field<arma::mat>& output)
const 138 output = arma::field<arma::mat>(input.n_elem);
139 for (
size_t i = 0; i < input.n_elem; ++i)
141 output.at(i) = arma::zeros(3, input.at(i).n_elem);
142 for (
size_t j = 0; j < input.at(i).n_elem; ++j)
144 size_t val = input.at(i).at(j);
145 output.at(i).at(val, j) = 1;
147 output.at(i).reshape(output.at(i).n_elem, 1);
void Generate(arma::field< arma::mat > &input, arma::field< arma::mat > &labels, const size_t batchSize, const bool fixedLength=false) const
Generate dataset of a given size.
Definition: add_impl.hpp:35
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
A discrete distribution where the only observations are discrete observations.
Definition: discrete_distribution.hpp:45
arma::vec Random() const
Return a randomly generated observation (one-dimensional vector; one observation) according to the pr...
Definition: discrete_distribution.cpp:22
AddTask(const size_t bitLen)
Creates an instance of the binary addition task.
Definition: add_impl.hpp:23
arma::vec & Probabilities(const size_t dim=0)
Return the vector of probabilities for the given dimension.
Definition: discrete_distribution.hpp:232