mlpack
mlpack::ann::ChannelShuffle< InputDataType, OutputDataType > Class Template Reference

Definition and implementation of the Channel Shuffle Layer.

#include <channel_shuffle.hpp>

Public Member Functions

 ChannelShuffle ()
 Create the Channel Shuffle object.
 
 ChannelShuffle (const size_t inRowSize, const size_t inColSize, const size_t depth, const size_t groupCount)
 The constructor for the Channel Shuffle.
 
template<typename eT >
void Forward (const arma::Mat< eT > &input, arma::Mat< eT > &output)
 Forward pass through the layer.
 
template<typename eT >
void Backward (const arma::Mat< eT > &, const arma::Mat< eT > &gradient, arma::Mat< eT > &output)
 Ordinary feed-backward pass of a neural network: propagates the gradient back through the layer by undoing the channel shuffle.
 
OutputDataType const & OutputParameter () const
 Get the output parameter.
 
OutputDataType & OutputParameter ()
 Modify the output parameter.
 
OutputDataType const & Delta () const
 Get the delta.
 
OutputDataType & Delta ()
 Modify the delta.
 
size_t const & InRowSize () const
 Get the row size of the input.
 
size_t & InRowSize ()
 Modify the row size of the input.
 
size_t const & InColSize () const
 Get the column size of the input.
 
size_t & InColSize ()
 Modify the column size of the input.
 
size_t const & InDepth () const
 Get the depth of the input.
 
size_t & InDepth ()
 Modify the depth of the input.
 
size_t const & InGroupCount () const
 Get the number of groups the channels are divided into.
 
size_t & InGroupCount ()
 Modify the number of groups the channels are divided into.
 
size_t InputShape () const
 Get the shape of the input.
 
template<typename Archive >
void serialize (Archive &ar, const uint32_t)
 Serialize the layer.
 

Detailed Description

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
class mlpack::ann::ChannelShuffle< InputDataType, OutputDataType >

Definition and implementation of the Channel Shuffle Layer.

Channel Shuffle divides the channels (units) of a tensor into groups and rearranges them so that consecutive output channels come from different groups, while keeping the original tensor shape.

For more information, refer to the following paper,

@article{zhang2018shufflenet,
  author = {Xiangyu Zhang and Xinyu Zhou and Mengxiao Lin and Jian Sun},
  title  = {ShuffleNet: An Extremely Efficient Convolutional Neural
            Network for Mobile Devices},
  year   = {2018},
  url    = {https://arxiv.org/pdf/1707.01083},
}
Template Parameters
InputDataType: Type of the input data (arma::colvec, arma::mat, arma::sp_mat or arma::cube).
OutputDataType: Type of the output data (arma::colvec, arma::mat, arma::sp_mat or arma::cube).

Constructor & Destructor Documentation

◆ ChannelShuffle()

template<typename InputDataType, typename OutputDataType>
mlpack::ann::ChannelShuffle<InputDataType, OutputDataType>::ChannelShuffle(
    const size_t inRowSize,
    const size_t inColSize,
    const size_t depth,
    const size_t groupCount)

The constructor for the Channel Shuffle.

Parameters
inRowSize: Number of input rows.
inColSize: Number of input columns.
depth: Number of input slices (channels).
groupCount: Number of groups for shuffling channels.

Member Function Documentation

◆ Backward()

template<typename InputDataType, typename OutputDataType>
template<typename eT>
void mlpack::ann::ChannelShuffle<InputDataType, OutputDataType>::Backward(
    const arma::Mat<eT>& /* input */,
    const arma::Mat<eT>& gradient,
    arma::Mat<eT>& output)

Ordinary feed-backward pass of a neural network: propagates the gradient with respect to the layer's output back to the layer's input.

Since the layer has no learnable parameters and only reorders channels, the backward pass applies the inverse of the forward shuffle to the gradient: each gradient slice is moved back to the channel position it occupied before the shuffle. The result already has the same size as the input, so no resampling is needed.

Parameters
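(input): The input matrix (unused).
gradient: The computed backward gradient, i.e. the gradient with respect to the layer's output.
output: The resulting gradient with respect to the input, with channels restored to their original order.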

◆ Forward()

template<typename InputDataType, typename OutputDataType>
template<typename eT>
void mlpack::ann::ChannelShuffle<InputDataType, OutputDataType>::Forward(
    const arma::Mat<eT>& input,
    arma::Mat<eT>& output)

Forward pass through the layer.

Parameters
input: The input matrix.
output: The resulting output matrix, with the same size as the input and its channels shuffled.
