13 #ifndef MLPACK_METHODS_ANN_LAYER_LOOKUP_IMPL_HPP 14 #define MLPACK_METHODS_ANN_LAYER_LOOKUP_IMPL_HPP 22 template <
typename InputDataType,
typename OutputDataType>
24 const size_t vocabSize,
25 const size_t embeddingSize) :
27 embeddingSize(embeddingSize)
29 weights.set_size(embeddingSize, vocabSize);
32 template<
typename InputDataType,
typename OutputDataType>
35 const arma::Mat<eT>& input, arma::Mat<eT>& output)
37 const size_t seqLength = input.n_rows;
38 const size_t batchSize = input.n_cols;
40 output.set_size(embeddingSize * seqLength, batchSize);
42 for (
size_t i = 0; i < batchSize; ++i)
47 output.col(i) = arma::vectorise(weights.cols(
48 arma::conv_to<arma::uvec>::from(input.col(i)) - 1));
52 template<
typename InputDataType,
typename OutputDataType>
55 const arma::Mat<eT>& ,
56 const arma::Mat<eT>& ,
59 Log::Fatal <<
"Lookup cannot be used as an intermediate layer." << std::endl;
62 template<
typename InputDataType,
typename OutputDataType>
65 const arma::Mat<eT>& input,
66 const arma::Mat<eT>& error,
67 arma::Mat<eT>& gradient)
69 const size_t seqLength = input.n_rows;
70 const size_t batchSize = input.n_cols;
72 arma::Cube<eT> errorTemp(
const_cast<arma::Mat<eT>&
>(error).memptr(),
73 embeddingSize, seqLength, batchSize,
false,
false);
75 gradient.set_size(arma::size(weights));
78 for (
size_t i = 0; i < batchSize; ++i)
80 gradient.cols(arma::conv_to<arma::uvec>::from(input.col(i)) - 1)
81 += errorTemp.slice(i);
85 template<
typename InputDataType,
typename OutputDataType>
86 template<
typename Archive>
88 Archive& ar,
const uint32_t )
90 ar(CEREAL_NVP(vocabSize));
91 ar(CEREAL_NVP(embeddingSize));
95 if (cereal::is_loading<Archive>())
96 weights.set_size(embeddingSize, vocabSize);
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: lookup_impl.hpp:87
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: lookup_impl.hpp:34
Lookup(const size_t vocabSize=0, const size_t embeddingSize=0)
Create the Lookup object using the specified vocabulary and embedding size.
Definition: lookup_impl.hpp:23
OutputDataType const & Gradient() const
Get the gradient.
Definition: lookup.hpp:104
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f.
Definition: lookup_impl.hpp:54