mlpack
mlpack::ann::Lookup< InputDataType, OutputDataType > Class Template Reference

The Lookup class stores word embeddings and retrieves them using tokens.

#include <lookup.hpp>

Public Member Functions

 Lookup (const size_t vocabSize=0, const size_t embeddingSize=0)
 Create the Lookup object using the specified vocabulary and embedding size.
 
template<typename eT >
void Forward (const arma::Mat< eT > &input, arma::Mat< eT > &output)
 Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
 
template<typename eT >
void Backward (const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
 Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f.
 
template<typename eT >
void Gradient (const arma::Mat< eT > &input, const arma::Mat< eT > &error, arma::Mat< eT > &gradient)
 Calculate the gradient using the output delta and the input activation.
 
OutputDataType const & Parameters () const
 Get the parameters.
 
OutputDataType & Parameters ()
 Modify the parameters.
 
OutputDataType const & OutputParameter () const
 Get the output parameter.
 
OutputDataType & OutputParameter ()
 Modify the output parameter.
 
OutputDataType const & Delta () const
 Get the delta.
 
OutputDataType & Delta ()
 Modify the delta.
 
OutputDataType const & Gradient () const
 Get the gradient.
 
OutputDataType & Gradient ()
 Modify the gradient.
 
size_t VocabSize () const
 Get the size of the vocabulary.
 
size_t EmbeddingSize () const
 Get the length of each embedding vector.
 
template<typename Archive >
void serialize (Archive &ar, const uint32_t)
 Serialize the layer.
 

Detailed Description

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
class mlpack::ann::Lookup< InputDataType, OutputDataType >

The Lookup class stores word embeddings and retrieves them using tokens.

The Lookup layer is always the first layer of the network. The input to the Lookup class is a matrix of shape (sequenceLength, batchSize) whose entries are tokens; each token is used to look up its embedding in the lookup table (i.e., the layer's weights).

The input shape is (sequenceLength, batchSize). The output shape is (embeddingSize, sequenceLength, batchSize).
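Since Forward() writes into a two-dimensional arma::Mat, the three-dimensional output has to be stored in flattened form. The sketch below assumes that each output column holds one batch item's (embeddingSize, sequenceLength) block in column-major order, i.e. an (embeddingSize * sequenceLength, batchSize) matrix. This layout is an assumption, and FlattenOutputIndex is a hypothetical helper, not part of mlpack.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper (not part of mlpack). Assumed layout: each output column
// holds one batch item's (embeddingSize x sequenceLength) block, flattened in
// column-major order.
struct FlatIndex
{
  std::size_t row;
  std::size_t col;
};

// Maps the logical 3-D index (e, s, b) = (embedding component, sequence
// position, batch item) onto a (row, column) position in the 2-D output.
FlatIndex FlattenOutputIndex(std::size_t e, std::size_t s, std::size_t b,
                             std::size_t embeddingSize,
                             std::size_t sequenceLength)
{
  assert(e < embeddingSize && s < sequenceLength);
  // Within one column, the embedding of position s occupies rows
  // [s * embeddingSize, (s + 1) * embeddingSize).
  return FlatIndex{s * embeddingSize + e, b};
}
```

For example, with embeddingSize = 4 and sequenceLength = 5, component 1 of the embedding at position 2 of batch item 3 lands in row 2 * 4 + 1 = 9 of column 3.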

Template Parameters
InputDataType: Type of the input data (arma::colvec, arma::mat, arma::sp_mat, or arma::cube).
OutputDataType: Type of the output data (arma::colvec, arma::mat, arma::sp_mat, or arma::cube).

Constructor & Destructor Documentation

◆ Lookup()

template<typename InputDataType , typename OutputDataType >
mlpack::ann::Lookup< InputDataType, OutputDataType >::Lookup (const size_t vocabSize = 0, const size_t embeddingSize = 0)

Create the Lookup object using the specified vocabulary and embedding size.

Parameters
vocabSize: The size of the vocabulary.
embeddingSize: The length of each embedding vector.
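Together, the two constructor arguments fix the size of the lookup table: one embedding vector of length embeddingSize per vocabulary entry, so vocabSize * embeddingSize trainable values in total. The following is a minimal sketch of such a table, assuming a column-major (embeddingSize, vocabSize) layout and 0-based tokens; EmbeddingTable is a hypothetical class for illustration, not mlpack's internal representation.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the layer's weights (not mlpack code): one
// embedding column per vocabulary entry, stored column-major in a flat buffer.
class EmbeddingTable
{
 public:
  EmbeddingTable(std::size_t vocabSize, std::size_t embeddingSize)
      : vocabSize(vocabSize),
        embeddingSize(embeddingSize),
        weights(vocabSize * embeddingSize, 0.0)
  { }

  // Number of trainable values: one per (component, token) pair.
  std::size_t NumParameters() const { return weights.size(); }

  // Component e of the embedding for the given token (0-based in this sketch).
  double& At(std::size_t e, std::size_t token)
  {
    assert(e < embeddingSize && token < vocabSize);
    return weights[token * embeddingSize + e];
  }

 private:
  std::size_t vocabSize;
  std::size_t embeddingSize;
  std::vector<double> weights;
};
```

A vocabulary of 10 tokens with 4-dimensional embeddings therefore yields a table of 40 values, all zero until they are initialized or trained.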

Member Function Documentation

◆ Backward()

template<typename InputDataType , typename OutputDataType >
template<typename eT >
void mlpack::ann::Lookup< InputDataType, OutputDataType >::Backward (const arma::Mat< eT > &, const arma::Mat< eT > & gy, arma::Mat< eT > & g)

Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.

Parameters
input: The propagated input activation (this parameter is unnamed in the signature).
gy: The backpropagated error.
g: The calculated gradient.

◆ Forward()

template<typename InputDataType , typename OutputDataType >
template<typename eT >
void mlpack::ann::Lookup< InputDataType, OutputDataType >::Forward (const arma::Mat< eT > & input, arma::Mat< eT > & output)

Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.

Parameters
input: Input data used for evaluating the specified function.
output: Resulting output activation.
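For this layer, "evaluating f(x)" amounts to copying table columns: every token in a sequence selects one embedding, and the selected embeddings are concatenated. The sketch below is a plain C++ illustration of that lookup, not mlpack's implementation. It assumes a column-major (embeddingSize, vocabSize) table, 0-based integer token ids (mlpack itself passes tokens as floating-point values in an arma::Mat<eT>, and its indexing convention may differ), and an output flattened to (embeddingSize * sequenceLength, batchSize). LookupForward is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch of the lookup forward pass (not mlpack's implementation).
// weights: column-major (embeddingSize x vocabSize) table.
// tokens:  column-major (sequenceLength x batchSize) matrix of 0-based token ids.
// Returns a column-major (embeddingSize * sequenceLength) x batchSize matrix in
// which each column concatenates the embeddings of one sequence.
std::vector<double> LookupForward(const std::vector<double>& weights,
                                  const std::vector<std::size_t>& tokens,
                                  std::size_t vocabSize,
                                  std::size_t embeddingSize,
                                  std::size_t sequenceLength,
                                  std::size_t batchSize)
{
  assert(weights.size() == embeddingSize * vocabSize);
  assert(tokens.size() == sequenceLength * batchSize);

  const std::size_t outRows = embeddingSize * sequenceLength;
  std::vector<double> output(outRows * batchSize);
  for (std::size_t b = 0; b < batchSize; ++b)
  {
    for (std::size_t s = 0; s < sequenceLength; ++s)
    {
      const std::size_t token = tokens[b * sequenceLength + s];
      assert(token < vocabSize);
      // Copy the token's embedding column into rows
      // [s * embeddingSize, (s + 1) * embeddingSize) of output column b.
      for (std::size_t e = 0; e < embeddingSize; ++e)
        output[b * outRows + s * embeddingSize + e] =
            weights[token * embeddingSize + e];
    }
  }
  return output;
}
```

With a 3-token vocabulary of 2-dimensional embeddings and the two sequences (2, 0) and (1, 1), the first output column is embedding 2 followed by embedding 0, and the second column is embedding 1 repeated twice.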

◆ Gradient()

template<typename InputDataType , typename OutputDataType >
template<typename eT >
void mlpack::ann::Lookup< InputDataType, OutputDataType >::Gradient (const arma::Mat< eT > & input, const arma::Mat< eT > & error, arma::Mat< eT > & gradient)

Calculate the gradient using the output delta and the input activation.

Parameters
input: The input parameter used for calculating the gradient.
error: The calculated error.
gradient: The calculated gradient.
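Because the forward pass only copies table columns, the gradient with respect to the table is a scatter-add. Each token's column accumulates the error at every position where that token occurred, and tokens that never appear in the batch receive a zero gradient. The sketch below illustrates this in plain C++ and is not mlpack's implementation. It assumes 0-based integer token ids, a column-major (embeddingSize, vocabSize) gradient, and an error matrix flattened to (embeddingSize * sequenceLength, batchSize). LookupGradient is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch of the lookup-table gradient (not mlpack's implementation).
// tokens: column-major (sequenceLength x batchSize) matrix of 0-based token ids.
// error:  column-major (embeddingSize * sequenceLength) x batchSize matrix.
// Returns a column-major (embeddingSize x vocabSize) gradient.
std::vector<double> LookupGradient(const std::vector<std::size_t>& tokens,
                                   const std::vector<double>& error,
                                   std::size_t vocabSize,
                                   std::size_t embeddingSize,
                                   std::size_t sequenceLength,
                                   std::size_t batchSize)
{
  assert(tokens.size() == sequenceLength * batchSize);
  assert(error.size() == embeddingSize * sequenceLength * batchSize);

  // Tokens that never occur keep a zero gradient.
  std::vector<double> gradient(embeddingSize * vocabSize, 0.0);
  const std::size_t errRows = embeddingSize * sequenceLength;
  for (std::size_t b = 0; b < batchSize; ++b)
  {
    for (std::size_t s = 0; s < sequenceLength; ++s)
    {
      const std::size_t token = tokens[b * sequenceLength + s];
      assert(token < vocabSize);
      // Scatter-add: repeated tokens sum their contributions.
      for (std::size_t e = 0; e < embeddingSize; ++e)
        gradient[token * embeddingSize + e] +=
            error[b * errRows + s * embeddingSize + e];
    }
  }
  return gradient;
}
```

Summing, rather than overwriting, matters whenever a token repeats within a batch. In the sequence (1, 1), for instance, token 1 must receive the error from both positions.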
