mlpack
lookup_impl.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_LOOKUP_IMPL_HPP
14 #define MLPACK_METHODS_ANN_LAYER_LOOKUP_IMPL_HPP
15 
16 // In case it hasn't yet been included.
17 #include "lookup.hpp"
18 
19 namespace mlpack {
20 namespace ann {
21 
22 template <typename InputDataType, typename OutputDataType>
24  const size_t vocabSize,
25  const size_t embeddingSize) :
26  vocabSize(vocabSize),
27  embeddingSize(embeddingSize)
28 {
29  weights.set_size(embeddingSize, vocabSize);
30 }
31 
32 template<typename InputDataType, typename OutputDataType>
33 template<typename eT>
35  const arma::Mat<eT>& input, arma::Mat<eT>& output)
36 {
37  const size_t seqLength = input.n_rows;
38  const size_t batchSize = input.n_cols;
39 
40  output.set_size(embeddingSize * seqLength, batchSize);
41 
42  for (size_t i = 0; i < batchSize; ++i)
43  {
44  // ith column of output is a vectorized form of a matrix of shape
45  // (embeddingSize, seqLength) selected as a combination of columns from the
46  // weights.
47  output.col(i) = arma::vectorise(weights.cols(
48  arma::conv_to<arma::uvec>::from(input.col(i)) - 1));
49  }
50 }
51 
52 template<typename InputDataType, typename OutputDataType>
53 template<typename eT>
55  const arma::Mat<eT>& /* input */,
56  const arma::Mat<eT>& /* gy */,
57  arma::Mat<eT>& /* g */)
58 {
59  Log::Fatal << "Lookup cannot be used as an intermediate layer." << std::endl;
60 }
61 
62 template<typename InputDataType, typename OutputDataType>
63 template<typename eT>
65  const arma::Mat<eT>& input,
66  const arma::Mat<eT>& error,
67  arma::Mat<eT>& gradient)
68 {
69  const size_t seqLength = input.n_rows;
70  const size_t batchSize = input.n_cols;
71 
72  arma::Cube<eT> errorTemp(const_cast<arma::Mat<eT>&>(error).memptr(),
73  embeddingSize, seqLength, batchSize, false, false);
74 
75  gradient.set_size(arma::size(weights));
76  gradient.zeros();
77 
78  for (size_t i = 0; i < batchSize; ++i)
79  {
80  gradient.cols(arma::conv_to<arma::uvec>::from(input.col(i)) - 1)
81  += errorTemp.slice(i);
82  }
83 }
84 
85 template<typename InputDataType, typename OutputDataType>
86 template<typename Archive>
88  Archive& ar, const uint32_t /* version */)
89 {
90  ar(CEREAL_NVP(vocabSize));
91  ar(CEREAL_NVP(embeddingSize));
92 
93  // This is inefficient, but we have to allocate this memory so that
94  // WeightSetVisitor gets the right size.
95  if (cereal::is_loading<Archive>())
96  weights.set_size(embeddingSize, vocabSize);
97 }
98 
99 } // namespace ann
100 } // namespace mlpack
101 
102 #endif
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: lookup_impl.hpp:87
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: lookup_impl.hpp:34
Lookup(const size_t vocabSize=0, const size_t embeddingSize=0)
Create the Lookup object using the specified vocabulary and embedding size.
Definition: lookup_impl.hpp:23
OutputDataType const & Gradient() const
Get the gradient.
Definition: lookup.hpp:104
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
Definition: lookup_impl.hpp:54