mlpack
pixel_shuffle_impl.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_PIXEL_SHUFFLE_IMPL_HPP
14 #define MLPACK_METHODS_ANN_LAYER_PIXEL_SHUFFLE_IMPL_HPP
15 
16 // In case it hasn't yet been included.
17 #include "pixel_shuffle.hpp"
18 
19 namespace mlpack {
20 namespace ann {
21 
22 template<typename InputDataType, typename OutputDataType>
24  PixelShuffle(0, 0, 0, 0)
25 {
26  // Nothing to do here.
27 }
28 
29 template<typename InputDataType, typename OutputDataType>
31  const size_t upscaleFactor,
32  const size_t height,
33  const size_t width,
34  const size_t size) :
35  upscaleFactor(upscaleFactor),
36  height(height),
37  width(width),
38  size(size),
39  batchSize(0),
40  outputHeight(0),
41  outputWidth(0),
42  sizeOut(0),
43  reset(false)
44 {
45  // Nothing to do here.
46 }
47 
48 template<typename InputDataType, typename OutputDataType>
49 template<typename eT>
51  const arma::Mat<eT>& input, arma::Mat<eT>& output)
52 {
53  if (!reset)
54  {
55  batchSize = input.n_cols;
56  sizeOut = size / std::pow(upscaleFactor, 2);
57  outputHeight = height * upscaleFactor;
58  outputWidth = width * upscaleFactor;
59  reset = true;
60  }
61 
62  output.zeros(outputHeight * outputWidth * sizeOut, batchSize);
63  for (size_t n = 0; n < batchSize; n++)
64  {
65  arma::cube inputTemp(const_cast<arma::mat&>(input).memptr(), height,
66  width, size * batchSize, false, false);
67  arma::cube outputTemp(const_cast<arma::mat&>(output).memptr(),
68  outputHeight, outputWidth, sizeOut * batchSize, false, false);
69 
70  for (size_t c = 0; c < sizeOut; c++)
71  {
72  for (size_t h = 0; h < outputHeight; h++)
73  {
74  for (size_t w = 0; w < outputWidth; w++)
75  {
76  size_t height_index = h / upscaleFactor;
77  size_t width_index = w / upscaleFactor;
78  size_t channel_index = (upscaleFactor * (h % upscaleFactor)) +
79  (w % upscaleFactor) + (c * std::pow(upscaleFactor, 2));
80  outputTemp(w, h, c + n * sizeOut) = inputTemp(width_index,
81  height_index, channel_index + n * size);
82  }
83  }
84  }
85  }
86 }
87 
88 template<typename InputDataType, typename OutputDataType>
89 template<typename eT>
91  const arma::Mat<eT>& input, const arma::Mat<eT>& gy, arma::Mat<eT>& g)
92 {
93  g.zeros(arma::size(input));
94  for (size_t n = 0; n < batchSize; n++)
95  {
96  arma::cube gyTemp(const_cast<arma::mat&>(gy).memptr(), outputHeight,
97  outputWidth, sizeOut * batchSize, false, false);
98  arma::cube gTemp(const_cast<arma::mat&>(g).memptr(), height, width,
99  size * batchSize, false, false);
100 
101  for (size_t c = 0; c < sizeOut; c++)
102  {
103  for (size_t h = 0; h < outputHeight; h++)
104  {
105  for (size_t w = 0; w < outputWidth; w++)
106  {
107  size_t height_index = h / upscaleFactor;
108  size_t width_index = w / upscaleFactor;
109  size_t channel_index = (upscaleFactor * (h % upscaleFactor)) +
110  (w % upscaleFactor) + (c * std::pow(upscaleFactor, 2));
111  gTemp(width_index, height_index, channel_index + n * size) =
112  gyTemp(w, h, c + n * sizeOut);
113  }
114  }
115  }
116  }
117 }
118 
119 template<typename InputDataType, typename OutputDataType>
120 template<typename Archive>
122  Archive& ar,
123  const unsigned int /* version */)
124 {
125  ar(CEREAL_NVP(delta));
126  ar(CEREAL_NVP(outputParameter));
127  ar(CEREAL_NVP(upscaleFactor));
128  ar(CEREAL_NVP(height));
129  ar(CEREAL_NVP(width));
130  ar(CEREAL_NVP(size));
131  ar(CEREAL_NVP(batchSize));
132  ar(CEREAL_NVP(outputHeight));
133  ar(CEREAL_NVP(outputHeight));
134  ar(CEREAL_NVP(outputWidth));
135  ar(CEREAL_NVP(sizeOut));
136 }
137 
138 } // namespace ann
139 } // namespace mlpack
140 
141 #endif
void Backward(const arma::Mat< eT > &input, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary backward pass of the PixelShuffle layer.
Definition: pixel_shuffle_impl.hpp:90
void serialize(Archive &ar, const unsigned int)
Serialize the layer.
Definition: pixel_shuffle_impl.hpp:121
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of the PixelShuffle layer.
Definition: pixel_shuffle_impl.hpp:50
Implementation of the PixelShuffle layer.
Definition: pixel_shuffle.hpp:49
PixelShuffle()
Create the PixelShuffle object.
Definition: pixel_shuffle_impl.hpp:23