mlpack
channel_shuffle_impl.hpp
1 
12 #ifndef MLPACK_METHODS_ANN_LAYER_CHANNEL_SHUFFLE_IMPL_HPP
13 #define MLPACK_METHODS_ANN_LAYER_CHANNEL_SHUFFLE_IMPL_HPP
14 
15 // In case it hasn't yet been included.
16 #include "channel_shuffle.hpp"
17 
18 namespace mlpack {
19 namespace ann {
20 
21 
22 template<typename InputDataType, typename OutputDataType>
25  inRowSize(0),
26  inColSize(0),
27  depth(0),
28  groupCount(0),
29  batchSize(0)
30 {
31  // Nothing to do here.
32 }
33 
34 template<typename InputDataType, typename OutputDataType>
37  const size_t inRowSize,
38  const size_t inColSize,
39  const size_t depth,
40  const size_t groupCount):
41  inRowSize(inRowSize),
42  inColSize(inColSize),
43  depth(depth),
44  groupCount(groupCount),
45  batchSize(0)
46 {
47  if (depth % groupCount != 0)
48  {
49  Log::Fatal << "Number of channels must be divisible by groupCount!" << std::endl;
50  }
51 }
52 
53 template<typename InputDataType, typename OutputDataType>
54 template<typename eT>
56  const arma::Mat<eT>& input, arma::Mat<eT>& output)
57 {
58  batchSize = input.n_cols;
59 
60  if (output.is_empty())
61  output.set_size(inRowSize * inColSize * depth, batchSize);
62  else
63  {
64  assert(output.n_rows == inRowSize * inColSize * depth);
65  assert(output.n_cols == batchSize);
66  }
67 
68  arma::cube inputAsCube(const_cast<arma::Mat<eT>&>(input).memptr(),
69  inRowSize, inColSize, depth * batchSize, false, false);
70  arma::cube outputAsCube(output.memptr(), inRowSize, inColSize,
71  depth * batchSize, false, true);
72 
73  const size_t groupSize= depth / groupCount;
74  size_t outChannelIdx = 0;
75  for (size_t k = 0; k < batchSize; ++k)
76  {
77  for (size_t i = 0; i < groupSize; ++i)
78  {
79  for (size_t g = 0; g < groupCount; ++g, ++outChannelIdx)
80  {
81  size_t inChannelIdx = k * batchSize + g * groupSize + i;
82  outputAsCube.slice(outChannelIdx) = inputAsCube.slice(inChannelIdx);
83  }
84  }
85  }
86 }
87 
88 template<typename InputDataType, typename OutputDataType>
89 template<typename eT>
91  const arma::Mat<eT>& /*input*/,
92  const arma::Mat<eT>& gradient,
93  arma::Mat<eT>& output)
94 {
95  if (output.is_empty())
96  output.set_size(inRowSize * inColSize * depth, batchSize);
97  else
98  {
99  assert(output.n_rows == inRowSize * inColSize * depth);
100  assert(output.n_cols == batchSize);
101  }
102 
103  arma::cube gradientAsCube(((arma::Mat<eT>&) gradient).memptr(), inColSize,
104  inColSize, depth * batchSize, false, false);
105  arma::cube outputAsCube(output.memptr(), inRowSize, inColSize,
106  depth * batchSize, false, true);
107 
108  const size_t groupSize= depth / groupCount;
109  size_t gradientChannelIdx = 0;
110  for (size_t k = 0; k < batchSize; ++k)
111  {
112  for (size_t i = 0; i < groupSize; ++i)
113  {
114  for (size_t g = 0; g < groupCount; ++g, ++gradientChannelIdx)
115  {
116  size_t outChannelIdx = k * batchSize + g * groupSize + i;
117  outputAsCube.slice(outChannelIdx) = gradientAsCube.slice(gradientChannelIdx);
118  }
119  }
120  }
121 
122 }
123 
124 template<typename InputDataType, typename OutputDataType>
125 template<typename Archive>
127  Archive& ar, const uint32_t /* version */)
128 {
129  ar(CEREAL_NVP(inRowSize));
130  ar(CEREAL_NVP(inColSize));
131  ar(CEREAL_NVP(depth));
132 }
133 
134 } // namespace ann
135 } // namespace mlpack
136 
137 #endif
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Forward pass through the layer.
Definition: channel_shuffle_impl.hpp:55
ChannelShuffle()
Create the Channel Shuffle object.
Definition: channel_shuffle_impl.hpp:24
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: channel_shuffle_impl.hpp:126
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gradient, arma::Mat< eT > &output)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f.
Definition: channel_shuffle_impl.hpp:90