// NOTE(review): this span is a Doxygen source-listing extraction, not the
// compilable header. The original line numbers are fused into the text, and
// several lines are missing: both MeanPooling constructor signatures, the
// parameter that follows strideHeight, and the rest of the member-initializer
// list. The code tokens are therefore left untouched.
13 #ifndef MLPACK_METHODS_ANN_LAYER_MEAN_POOLING_IMPL_HPP 14 #define MLPACK_METHODS_ANN_LAYER_MEAN_POOLING_IMPL_HPP 22 template<
typename InputDataType,
typename OutputDataType>
// Parameterized constructor: stores the pooling window size (kernelWidth x
// kernelHeight) and the step between windows (strideWidth, strideHeight).
// The missing parameter is presumably the `floor` flag that serialize()
// stores - TODO confirm against the real header.
28 template<
typename InputDataType,
typename OutputDataType>
30 const size_t kernelWidth,
31 const size_t kernelHeight,
32 const size_t strideWidth,
33 const size_t strideHeight,
35 kernelWidth(kernelWidth),
36 kernelHeight(kernelHeight),
37 strideWidth(strideWidth),
38 strideHeight(strideHeight),
// Forward pass. It views the input matrix as a cube of 2-D feature maps,
// mean-pools every slice through Pooling(), and flattens the pooled cube back
// into `output`, using batchSize columns per the visible row count.
// NOTE(review): this is an extracted fragment. The Forward signature line, the
// `if (floor)` / `else` lines that choose between the two output-size
// computations, and the trailing arguments of two calls are missing here.
53 template<
typename InputDataType,
typename OutputDataType>
56 const arma::Mat<eT>& input, arma::Mat<eT>& output)
// One column per sample. inSize is the number of inputWidth x inputHeight
// maps (channels) contained in each column.
// NOTE(review): this divides by zero if inputWidth or inputHeight is still 0
// when Forward is called.
58 batchSize = input.n_cols;
59 inSize = input.n_elem / (inputWidth * inputHeight * batchSize);
// Alias the caller's input memory as an
// inputWidth x inputHeight x (batchSize * inSize) cube without copying it
// (copy_aux_mem = false, strict = false). Backward() later reads inputTemp,
// so the input buffer presumably has to stay alive and unchanged between
// Forward and Backward - TODO confirm with callers.
60 inputTemp = arma::cube(
const_cast<arma::Mat<eT>&
>(input).memptr(),
61 inputWidth, inputHeight, batchSize * inSize,
false,
false);
// Output size per dimension in floor mode: floor((in - kernel) / stride + 1),
// i.e. only windows that fit entirely inside the input.
// NOTE(review): when kernel > input this value can be negative. Assigning it
// to an unsigned size (presumably size_t) is undefined behaviour.
65 outputWidth = std::floor((inputWidth -
66 (
double) kernelWidth) / (
double) strideWidth + 1);
67 outputHeight = std::floor((inputHeight -
68 (
double) kernelHeight) / (
double) strideHeight + 1);
// Output size in ceil mode: ceil((in - kernel) / stride + 1). This also counts
// a last window that starts inside the input but extends past its edge;
// presumably Pooling() clips such windows.
72 outputWidth = std::ceil((inputWidth -
73 (
double) kernelWidth) / (
double) strideWidth + 1);
74 outputHeight = std::ceil((inputHeight -
75 (
double) kernelHeight) / (
double) strideHeight + 1);
// Zero-initialized pooled cube. The slice-count argument is missing from this
// extraction.
78 outputTemp = arma::zeros<arma::Cube<eT> >(outputWidth, outputHeight,
// Pool each feature map independently.
81 for (
size_t s = 0; s < inputTemp.n_slices; s++)
82 Pooling(inputTemp.slice(s), outputTemp.slice(s));
// Flatten the pooled cube to one column per sample. The remaining
// constructor arguments are missing from this extraction.
84 output = arma::Mat<eT>(outputTemp.memptr(), outputTemp.n_elem / batchSize,
// Cache the pooled dimensions and the slice count. Backward() uses these to
// re-map gy into a cube.
87 outputWidth = outputTemp.n_rows;
88 outputHeight = outputTemp.n_cols;
89 outSize = batchSize * inSize;
// Backward pass. It maps the incoming error gy onto a cube of pooled maps,
// passes each slice to Unpooling() (defined elsewhere) together with the
// matching forward-input slice, and flattens the resulting input gradient
// into g with one column per sample.
// NOTE(review): this is an extracted fragment. The Backward signature lines,
// including the `arma::Mat<eT>& g` parameter, are missing here.
// NOTE(review): arma::cube and arma::mat are Armadillo's double-precision
// typedefs. The memptr() aliasing of gy and the construction of g below
// therefore only compile for eT = double; arma::Cube<eT> / arma::Mat<eT>
// would keep the layer generic. The C-style cast on gy also silently drops
// const; a const_cast would make that explicit.
92 template<
typename InputDataType,
typename OutputDataType>
// The first argument (the layer input) is unused. The forward input is read
// from inputTemp, which Forward() cached.
95 const arma::Mat<eT>& ,
96 const arma::Mat<eT>& gy,
// View gy without copying (copy_aux_mem = false) as
// outputWidth x outputHeight x outSize. This relies on the dimensions cached
// by the preceding Forward() call.
99 arma::cube mappedError = arma::cube(((arma::Mat<eT>&) gy).memptr(),
100 outputWidth, outputHeight, outSize,
false,
false);
// Zero-initialized gradient buffer with the same shape as the cached forward
// input.
102 gTemp = arma::zeros<arma::cube>(inputTemp.n_rows,
103 inputTemp.n_cols, inputTemp.n_slices);
// mappedError has outSize (= batchSize * inSize) slices. This matches
// inputTemp's slice count only if Forward() saw the same batch.
105 for (
size_t s = 0; s < mappedError.n_slices; s++)
107 Unpooling(inputTemp.slice(s), mappedError.slice(s), gTemp.slice(s));
// Copy (default copy_aux_mem = true) the gradient cube into g, one column
// per sample.
110 g = arma::mat(gTemp.memptr(), gTemp.n_elem / batchSize, batchSize);
// Serializes the pooling hyperparameters (kernel size, stride, floor flag)
// and the cached batch and spatial dimensions with cereal name-value pairs.
// NOTE(review): this is an extracted fragment. The serialize signature and
// closing brace are missing here. In the visible lines, inSize, outSize and
// the temporaries are not serialized - TODO confirm that is intentional.
113 template<
typename InputDataType,
typename OutputDataType>
114 template<
typename Archive>
119 ar(CEREAL_NVP(kernelWidth));
120 ar(CEREAL_NVP(kernelHeight));
121 ar(CEREAL_NVP(strideWidth));
122 ar(CEREAL_NVP(strideHeight));
123 ar(CEREAL_NVP(batchSize));
124 ar(CEREAL_NVP(floor));
125 ar(CEREAL_NVP(inputWidth));
126 ar(CEREAL_NVP(inputHeight));
127 ar(CEREAL_NVP(outputWidth));
128 ar(CEREAL_NVP(outputHeight));
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
MeanPooling()
Create the MeanPooling object.
Definition: mean_pooling_impl.hpp:23
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: mean_pooling_impl.hpp:115
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed-forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: mean_pooling_impl.hpp:55
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary backward pass of a neural network, computing the gradient with respect to the layer input by propagating the error gy backwards through f; the matrices are internally viewed as 3rd-order tensors (cubes).
Definition: mean_pooling_impl.hpp:94