// Include guard for this implementation header, followed by the tail of the
// CopyTask constructor's parameter list (nRepeats, addSeparator).
//
// NOTE(review): the leading numbers on each line are line numbers carried over
// from the original listing. They jump (13 -> 24, 28 -> 32, 37 -> 41), so the
// first line of the signature, the remaining member initializers, the braces
// and the conditions guarding the two throws below are not visible here.
// Restore them from the upstream header before compiling.
12 #ifndef MLPACK_METHODS_AUGMENTED_TASKS_COPY_IMPL_HPP 13 #define MLPACK_METHODS_AUGMENTED_TASKS_COPY_IMPL_HPP 24 const size_t nRepeats,
25 const bool addSeparator) :
// Member initializer: stores the constructor argument in the member of the
// same name.
28 addSeparator(addSeparator)
// Error path 1: builds a message about the maximum sequence length ("should be
// at least 2!") and throws std::invalid_argument.
// NOTE(review): presumably guarded by a check on maxLength; the condition is
// not visible in this listing.
32 std::ostringstream oss;
33 oss <<
"CopyTask::CopyTask(): maximum sequence length (" 35 <<
"should be at least 2!" 37 throw std::invalid_argument(oss.str());
// Error path 2: builds a message that includes nRepeats and throws
// std::invalid_argument.
// NOTE(review): presumably guarded by a check on nRepeats; neither the
// condition nor the rest of the message is visible in this listing.
41 std::ostringstream oss;
42 oss <<
"CopyTask::CopyTask(): repetition count (" << nRepeats <<
") " 45 throw std::invalid_argument(oss.str());
// CopyTask::Generate (arma::field overload): fills `input` and `labels` with
// `batchSize` generated samples, one matrix per field element.
//
// NOTE(review): the embedded listing numbers jump (45 -> 51, 58 -> 62,
// 62 -> 68, 68 -> 72, 78 -> 81, 87 -> 94), so the first parameter line, the
// braces, the branch that uses `fixedLength`, and the right-hand sides of two
// assignments are not visible here. Comments below describe only what is shown.
51 arma::field<arma::mat>& labels,
52 const size_t batchSize,
53 bool fixedLength)
// Re-create both output fields with one slot per sample.
const 55 input = arma::field<arma::mat>(batchSize);
56 labels = arma::field<arma::mat>(batchSize);
// Sequence length used for each sample; starts at maxLength.
// NOTE(review): presumably overwritten in the non-fixed-length branch, which
// is not visible in this listing.
57 size_t size = maxLength;
// One iteration per sample.
58 for (
size_t i = 0; i < batchSize; ++i)
// Vector with maxLength - 1 elements; no use of it is visible in this listing.
62 arma::vec weights(maxLength - 1);
// Evaluates 2^1, 2^2, ..., 2^(maxLength - 1) (linspace has maxLength - 1
// points from 1 to maxLength - 1).
// NOTE(review): the assignment target is not visible; presumably these are the
// weights of the discrete distribution the sequence length is drawn from.
68 arma::exp2(arma::linspace(1, maxLength - 1, maxLength - 1));
// Random binary sequence: `size` integers drawn from {0, 1}.
72 arma::colvec vecInput = arma::randi<arma::colvec>(
73 size, arma::distr_param(0, 1));
// Expected output: the input sequence stacked vertically nRepeats times.
74 arma::colvec vecLabel = arma::conv_to<arma::colvec>::from(
75 arma::repmat(vecInput, nRepeats, 1));
// Total number of time steps: input, optional separator (the bool
// addSeparator contributes 0 or 1), then the repeated output.
76 size_t totSize = vecInput.n_elem + addSeparator + vecLabel.n_elem;
// Two columns per time step, initially all zero.
77 input(i) = arma::zeros(totSize, 2);
// Column 0, first vecInput.n_elem rows.
// NOTE(review): the right-hand side is not visible; presumably vecInput.
78 input(i).col(0).rows(0, vecInput.n_elem - 1) =
// Writes 0.5 into column 0 at the row right after the input sequence.
// NOTE(review): presumably executed only when addSeparator is set; the guard
// is not visible in this listing.
81 input(i).at(vecInput.n_elem, 0) = 0.5;
// Column 1 is set to 1 for every row after the input (and separator) rows.
82 input(i).col(1).rows(addSeparator + vecInput.n_elem, totSize - 1) =
83 arma::ones(totSize-vecInput.n_elem - addSeparator);
// Transpose to 2 x totSize, then reshape to a single column; because the
// reshape fills column by column, the two values of each time step end up
// adjacent in the resulting vector.
84 input(i) = input(i).t();
85 input(i).reshape(input(i).n_elem, 1);
// Labels: a zero column of totSize rows; the rows after the input (and
// separator) rows are then assigned.
// NOTE(review): the right-hand side is not visible; presumably vecLabel.
86 labels(i) = arma::zeros(totSize, 1);
87 labels(i).col(0).rows(addSeparator + vecInput.n_elem, totSize - 1) =
// CopyTask::Generate (matrix overload): generates `batchSize` samples through
// the field overload with fixedLength = true, then copies sample i into
// column i of `input` and `labels`.
//
// NOTE(review): the opening of the signature (the `input` and `labels`
// parameters) and the braces are not visible in this listing; the two outputs
// are presumably arma::mat references — confirm against the upstream header.
94 const size_t batchSize)
const 96 arma::field<arma::mat> fieldInput, fieldLabels;
// Request fixed-length samples (last argument).
97 Generate(fieldInput, fieldLabels, batchSize,
true);
// One output column per sample.
98 size_t cols = batchSize;
// Output row counts are taken from sample 0, so this relies on every sample
// having the same number of rows.
// NOTE(review): fieldInput(0) is read with no visible guard for
// batchSize == 0 — confirm callers never pass 0.
99 input = arma::zeros(fieldInput(0).n_rows, cols);
100 labels = arma::zeros(fieldLabels(0).n_rows, cols);
// Copy each sample into its own column.
101 for (
size_t i = 0; i < cols; ++i)
103 input.col(i) = fieldInput.at(i);
104 labels.col(i) = fieldLabels.at(i);
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
A discrete distribution where the only observations are discrete observations.
Definition: discrete_distribution.hpp:45
arma::vec Random() const
Return a randomly generated observation (one-dimensional vector; one observation) according to the probability distribution defined by this object.
Definition: discrete_distribution.cpp:22
arma::vec & Probabilities(const size_t dim=0)
Return the vector of probabilities for the given dimension.
Definition: discrete_distribution.hpp:232
CopyTask(const size_t maxLength, const size_t nRepeats, const bool addSeparator=false)
Creates an instance of the sequence copy task.
Definition: copy_impl.hpp:23
void Generate(arma::field< arma::mat > &input, arma::field< arma::mat > &labels, const size_t batchSize, bool fixedLength=false) const
Generate dataset of a given size.
Definition: copy_impl.hpp:50