// NOTE(review): this listing embeds the original line numbers in the text and
// skips several source lines (the numbering jumps 13 -> 25 -> 29 and
// 34 -> 38), so the constructor signature, the braces and the conditions
// guarding the two throws below are not visible here.
//
// Include guard for this implementation header, followed by the member
// initializer list of the SortTask constructor: the three constructor
// arguments are copied into the members of the same name.
12 #ifndef MLPACK_METHODS_AUGMENTED_TASKS_SORT_IMPL_HPP 13 #define MLPACK_METHODS_AUGMENTED_TASKS_SORT_IMPL_HPP 25 : maxLength(maxLength), bitLen(bitLen), addSeparator(addSeparator)
// Rejects an invalid maximum sequence length by throwing
// std::invalid_argument with a formatted message.
// NOTE(review): the guarding condition is not visible; judging by the message
// it presumably fires when maxLength is below 2 -- confirm.
29 std::ostringstream oss;
30 oss <<
"SortTask::SortTask(): maximum sequence length (" 32 <<
"should be at least 2!" 34 throw std::invalid_argument(oss.str());
// Rejects an invalid bitLen the same way. The guarding condition and the tail
// of the message (original lines 40-41) are not visible in this listing.
38 std::ostringstream oss;
39 oss <<
"SortTask::SortTask(): binary length (" << bitLen <<
") " 42 throw std::invalid_argument(oss.str());
// Field overload of SortTask::Generate: fills `input` and `labels` with
// `batchSize` entries each, where every label holds the columns of the
// corresponding random binary input reordered by a per-column sort key.
//
// NOTE(review): the first line of the signature (original line 46) and a
// number of body lines are missing from this listing; comments below only
// describe the statements that are visible.
47 arma::field<arma::mat>& labels,
48 const size_t batchSize,
49 bool fixedLength)
// One field element per generated sample.
const 51 input = arma::field<arma::mat>(batchSize);
52 labels = arma::field<arma::mat>(batchSize);
// Sequence length defaults to maxLength.
// NOTE(review): original lines 55-60 are not visible; presumably `size` is
// re-drawn at random per sample when fixedLength is false -- confirm.
53 size_t size = maxLength;
54 for (
size_t i = 0; i < batchSize; ++i)
// Random bitLen x size matrix with integer entries drawn from [0, 1].
61 input(i) = arma::randi<arma::mat>(bitLen, size, arma::distr_param(0, 1));
// itemAns receives the reordered columns; vals holds one sort key per column.
62 arma::mat itemAns(bitLen, size);
63 arma::colvec vals(size);
// Per column j, the entries of rows 0..bitLen-1 are accumulated into `val`.
// NOTE(review): the declaration of `val`, any statement preceding the
// addition, and the store into `vals` (original lines 66, 69, 71-73) are not
// visible here.
64 for (
size_t j = 0; j < size; ++j)
67 for (
size_t k = 0; k < bitLen; ++k)
70 val += input(i).at(k, j);
// Permutation of column indices that orders the keys in `vals`.
74 arma::uvec indices = arma::sort_index(vals);
// Copy the input columns into itemAns in that order.
75 for (
size_t j = 0; j < size; ++j)
77 itemAns.col(j) = input(i).col(indices.at(j));
// Flatten the input sample into a single column vector.
// NOTE(review): the assignment of labels(i) (original line 79) is not
// visible; presumably it is set from itemAns -- confirm.
80 input(i).reshape(input(i).n_elem, 1);
// Separator variant of the input: one extra row per item (n_elem + size
// rows). Each item's bitLen rows are copied over and a 0.5 marker is written
// at position `ptr`.
// NOTE(review): the enclosing condition (presumably addSeparator), the
// updates of ptr/origPtr and the write-back into input(i) are not visible in
// this listing.
83 arma::mat sepInput = arma::zeros(input(i).n_elem + size, 1);
84 size_t ptr = 0, origPtr = 0;
85 for (
size_t j = 0; j < size; ++j)
87 sepInput.rows(ptr, ptr + bitLen - 1) =
88 input(i).rows(origPtr, origPtr + bitLen - 1);
91 sepInput.at(ptr, 0) = 0.5;
// Flatten the label sample into a single column vector as well.
96 labels(i).reshape(labels(i).n_elem, 1);
// Matrix overload of SortTask::Generate: produces the samples through the
// field overload and packs them into two matrices, one sample per column.
//
// NOTE(review): the leading part of the signature (original lines 100-101) is
// not visible in this listing; `input` and `labels` are assigned as matrices
// below.
102 const size_t batchSize)
// Delegate to the field overload with fixedLength = true.
const 104 arma::field<arma::mat> fieldInput, fieldLabels;
105 Generate(fieldInput, fieldLabels, batchSize,
true);
// Row counts are taken from the first sample; this assumes every sample has
// the same number of rows (presumably guaranteed by fixedLength = true).
// NOTE(review): element 0 is accessed unconditionally, so batchSize == 0
// looks unsafe here -- confirm how callers use this.
106 size_t inputRows = fieldInput(0).n_rows;
107 size_t labelRows = fieldLabels(0).n_rows;
108 size_t cols = batchSize;
// Zero-initialised outputs with one column per sample.
109 input = arma::zeros(inputRows, cols);
110 labels = arma::zeros(labelRows, cols);
// Copy sample i into column i of each output matrix.
111 for (
size_t i = 0; i < cols; ++i)
113 input.col(i) = fieldInput.at(i);
114 labels.col(i) = fieldLabels.at(i);
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Generate(arma::field< arma::mat > &input, arma::field< arma::mat > &labels, const size_t batchSize, bool fixedLength=false) const
Generate dataset of a given size.
Definition: sort_impl.hpp:46
int RandInt(const int hiExclusive)
Generates a uniform random integer.
Definition: random.hpp:110
SortTask(const size_t maxLength, const size_t bitLen, bool addSeparator=false)
Creates an instance of the sequence sort task.
Definition: sort_impl.hpp:22