// Include guard for the PixelShuffle layer implementation header, followed by
// the template header of the default constructor PixelShuffle() (its body is
// not visible in this extracted listing).
13 #ifndef MLPACK_METHODS_ANN_LAYER_PIXEL_SHUFFLE_IMPL_HPP 14 #define MLPACK_METHODS_ANN_LAYER_PIXEL_SHUFFLE_IMPL_HPP 22 template<
typename InputDataType,
typename OutputDataType>
// Parameterized constructor: takes the upscale factor and stores it in the
// member of the same name via the initializer list.
// NOTE(review): the remaining constructor parameters and member initializers
// are missing from this extracted listing; confirm against the real header.
29 template<
typename InputDataType,
typename OutputDataType>
31 const size_t upscaleFactor,
35 upscaleFactor(upscaleFactor),
48 template<
typename InputDataType,
typename OutputDataType>
51 const arma::Mat<eT>& input, arma::Mat<eT>& output)
55 batchSize = input.n_cols;
56 sizeOut = size / std::pow(upscaleFactor, 2);
57 outputHeight = height * upscaleFactor;
58 outputWidth = width * upscaleFactor;
62 output.zeros(outputHeight * outputWidth * sizeOut, batchSize);
63 for (
size_t n = 0; n < batchSize; n++)
65 arma::cube inputTemp(const_cast<arma::mat&>(input).memptr(), height,
66 width, size * batchSize,
false,
false);
67 arma::cube outputTemp(const_cast<arma::mat&>(output).memptr(),
68 outputHeight, outputWidth, sizeOut * batchSize,
false,
false);
70 for (
size_t c = 0; c < sizeOut; c++)
72 for (
size_t h = 0; h < outputHeight; h++)
74 for (
size_t w = 0; w < outputWidth; w++)
76 size_t height_index = h / upscaleFactor;
77 size_t width_index = w / upscaleFactor;
78 size_t channel_index = (upscaleFactor * (h % upscaleFactor)) +
79 (w % upscaleFactor) + (c * std::pow(upscaleFactor, 2));
80 outputTemp(w, h, c + n * sizeOut) = inputTemp(width_index,
81 height_index, channel_index + n * size);
88 template<
typename InputDataType,
typename OutputDataType>
91 const arma::Mat<eT>& input,
const arma::Mat<eT>& gy, arma::Mat<eT>& g)
93 g.zeros(arma::size(input));
94 for (
size_t n = 0; n < batchSize; n++)
96 arma::cube gyTemp(const_cast<arma::mat&>(gy).memptr(), outputHeight,
97 outputWidth, sizeOut * batchSize,
false,
false);
98 arma::cube gTemp(const_cast<arma::mat&>(g).memptr(), height, width,
99 size * batchSize,
false,
false);
101 for (
size_t c = 0; c < sizeOut; c++)
103 for (
size_t h = 0; h < outputHeight; h++)
105 for (
size_t w = 0; w < outputWidth; w++)
107 size_t height_index = h / upscaleFactor;
108 size_t width_index = w / upscaleFactor;
109 size_t channel_index = (upscaleFactor * (h % upscaleFactor)) +
110 (w % upscaleFactor) + (c * std::pow(upscaleFactor, 2));
111 gTemp(width_index, height_index, channel_index + n * size) =
112 gyTemp(w, h, c + n * sizeOut);
119 template<
typename InputDataType,
typename OutputDataType>
120 template<
typename Archive>
125 ar(CEREAL_NVP(delta));
126 ar(CEREAL_NVP(outputParameter));
127 ar(CEREAL_NVP(upscaleFactor));
128 ar(CEREAL_NVP(height));
129 ar(CEREAL_NVP(width));
130 ar(CEREAL_NVP(size));
131 ar(CEREAL_NVP(batchSize));
132 ar(CEREAL_NVP(outputHeight));
133 ar(CEREAL_NVP(outputHeight));
134 ar(CEREAL_NVP(outputWidth));
135 ar(CEREAL_NVP(sizeOut));
void Backward(const arma::Mat< eT > &input, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary backward pass of the PixelShuffle layer.
Definition: pixel_shuffle_impl.hpp:90
void serialize(Archive &ar, const unsigned int)
Serialize the layer.
Definition: pixel_shuffle_impl.hpp:121
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of the PixelShuffle layer.
Definition: pixel_shuffle_impl.hpp:50
Implementation of the PixelShuffle layer.
Definition: pixel_shuffle.hpp:49
PixelShuffle()
Create the PixelShuffle object.
Definition: pixel_shuffle_impl.hpp:23