12 #ifndef MLPACK_METHODS_ANN_LAYER_CHANNEL_SHUFFLE_IMPL_HPP 13 #define MLPACK_METHODS_ANN_LAYER_CHANNEL_SHUFFLE_IMPL_HPP 22 template<
typename InputDataType,
typename OutputDataType>
// Parameterized constructor of the ChannelShuffle layer.
// NOTE(review): this listing is an extraction with lines missing (the
// function-name line, at least one parameter and most of the member
// initializer list are not visible); only what is shown here is described.
34 template<
typename InputDataType,
typename OutputDataType>
// Visible parameters: the row and column size of the input, and the number of
// groups the channels are split into.
37 const size_t inRowSize,
38 const size_t inColSize,
40 const size_t groupCount):
44 groupCount(groupCount),
// The channel count must split evenly into groupCount groups; otherwise a
// fatal message is logged (Log::Fatal terminates the program).
47 if (depth % groupCount != 0)
49 Log::Fatal <<
"Number of channels must be divisible by groupCount!" << std::endl;
53 template<
typename InputDataType,
typename OutputDataType>
56 const arma::Mat<eT>& input, arma::Mat<eT>& output)
58 batchSize = input.n_cols;
60 if (output.is_empty())
61 output.set_size(inRowSize * inColSize * depth, batchSize);
64 assert(output.n_rows == inRowSize * inColSize * depth);
65 assert(output.n_cols == batchSize);
68 arma::cube inputAsCube(
const_cast<arma::Mat<eT>&
>(input).memptr(),
69 inRowSize, inColSize, depth * batchSize,
false,
false);
70 arma::cube outputAsCube(output.memptr(), inRowSize, inColSize,
71 depth * batchSize,
false,
true);
73 const size_t groupSize= depth / groupCount;
74 size_t outChannelIdx = 0;
75 for (
size_t k = 0; k < batchSize; ++k)
77 for (
size_t i = 0; i < groupSize; ++i)
79 for (
size_t g = 0; g < groupCount; ++g, ++outChannelIdx)
81 size_t inChannelIdx = k * batchSize + g * groupSize + i;
82 outputAsCube.slice(outChannelIdx) = inputAsCube.slice(inChannelIdx);
88 template<
typename InputDataType,
typename OutputDataType>
91 const arma::Mat<eT>& ,
92 const arma::Mat<eT>& gradient,
93 arma::Mat<eT>& output)
95 if (output.is_empty())
96 output.set_size(inRowSize * inColSize * depth, batchSize);
99 assert(output.n_rows == inRowSize * inColSize * depth);
100 assert(output.n_cols == batchSize);
103 arma::cube gradientAsCube(((arma::Mat<eT>&) gradient).memptr(), inColSize,
104 inColSize, depth * batchSize,
false,
false);
105 arma::cube outputAsCube(output.memptr(), inRowSize, inColSize,
106 depth * batchSize,
false,
true);
108 const size_t groupSize= depth / groupCount;
109 size_t gradientChannelIdx = 0;
110 for (
size_t k = 0; k < batchSize; ++k)
112 for (
size_t i = 0; i < groupSize; ++i)
114 for (
size_t g = 0; g < groupCount; ++g, ++gradientChannelIdx)
116 size_t outChannelIdx = k * batchSize + g * groupSize + i;
117 outputAsCube.slice(outChannelIdx) = gradientAsCube.slice(gradientChannelIdx);
// Serialize the layer through the given cereal-style archive.
// NOTE(review): the function-name line is missing from this listing and the
// listing may be cut off after the last visible statement; confirm against
// the real file whether groupCount is archived as well.
124 template<
typename InputDataType,
typename OutputDataType>
125 template<
typename Archive>
// The second (unnamed) uint32_t parameter is not used in the visible body.
127 Archive& ar,
const uint32_t )
// Each member is archived as a name-value pair keyed by its own name.
129 ar(CEREAL_NVP(inRowSize));
130 ar(CEREAL_NVP(inColSize));
131 ar(CEREAL_NVP(depth));
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Forward pass through the layer.
Definition: channel_shuffle_impl.hpp:55
ChannelShuffle()
Create the Channel Shuffle object.
Definition: channel_shuffle_impl.hpp:24
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: channel_shuffle_impl.hpp:126
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gradient, arma::Mat< eT > &output)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f.
Definition: channel_shuffle_impl.hpp:90