mlpack
mean_pooling_impl.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_MEAN_POOLING_IMPL_HPP
14 #define MLPACK_METHODS_ANN_LAYER_MEAN_POOLING_IMPL_HPP
15 
16 // In case it hasn't yet been included.
17 #include "mean_pooling.hpp"
18 
19 namespace mlpack {
20 namespace ann {
21 
22 template<typename InputDataType, typename OutputDataType>
24 {
25  // Nothing to do here.
26 }
27 
28 template<typename InputDataType, typename OutputDataType>
30  const size_t kernelWidth,
31  const size_t kernelHeight,
32  const size_t strideWidth,
33  const size_t strideHeight,
34  const bool floor) :
35  kernelWidth(kernelWidth),
36  kernelHeight(kernelHeight),
37  strideWidth(strideWidth),
38  strideHeight(strideHeight),
39  floor(floor),
40  inSize(0),
41  outSize(0),
42  inputWidth(0),
43  inputHeight(0),
44  outputWidth(0),
45  outputHeight(0),
46  reset(false),
47  deterministic(false),
48  batchSize(0)
49 {
50  // Nothing to do here.
51 }
52 
53 template<typename InputDataType, typename OutputDataType>
54 template<typename eT>
56  const arma::Mat<eT>& input, arma::Mat<eT>& output)
57 {
58  batchSize = input.n_cols;
59  inSize = input.n_elem / (inputWidth * inputHeight * batchSize);
60  inputTemp = arma::cube(const_cast<arma::Mat<eT>&>(input).memptr(),
61  inputWidth, inputHeight, batchSize * inSize, false, false);
62 
63  if (floor)
64  {
65  outputWidth = std::floor((inputWidth -
66  (double) kernelWidth) / (double) strideWidth + 1);
67  outputHeight = std::floor((inputHeight -
68  (double) kernelHeight) / (double) strideHeight + 1);
69  }
70  else
71  {
72  outputWidth = std::ceil((inputWidth -
73  (double) kernelWidth) / (double) strideWidth + 1);
74  outputHeight = std::ceil((inputHeight -
75  (double) kernelHeight) / (double) strideHeight + 1);
76  }
77 
78  outputTemp = arma::zeros<arma::Cube<eT> >(outputWidth, outputHeight,
79  batchSize * inSize);
80 
81  for (size_t s = 0; s < inputTemp.n_slices; s++)
82  Pooling(inputTemp.slice(s), outputTemp.slice(s));
83 
84  output = arma::Mat<eT>(outputTemp.memptr(), outputTemp.n_elem / batchSize,
85  batchSize);
86 
87  outputWidth = outputTemp.n_rows;
88  outputHeight = outputTemp.n_cols;
89  outSize = batchSize * inSize;
90 }
91 
92 template<typename InputDataType, typename OutputDataType>
93 template<typename eT>
95  const arma::Mat<eT>& /* input */,
96  const arma::Mat<eT>& gy,
97  arma::Mat<eT>& g)
98 {
99  arma::cube mappedError = arma::cube(((arma::Mat<eT>&) gy).memptr(),
100  outputWidth, outputHeight, outSize, false, false);
101 
102  gTemp = arma::zeros<arma::cube>(inputTemp.n_rows,
103  inputTemp.n_cols, inputTemp.n_slices);
104 
105  for (size_t s = 0; s < mappedError.n_slices; s++)
106  {
107  Unpooling(inputTemp.slice(s), mappedError.slice(s), gTemp.slice(s));
108  }
109 
110  g = arma::mat(gTemp.memptr(), gTemp.n_elem / batchSize, batchSize);
111 }
112 
113 template<typename InputDataType, typename OutputDataType>
114 template<typename Archive>
116  Archive& ar,
117  const uint32_t /* version */)
118 {
119  ar(CEREAL_NVP(kernelWidth));
120  ar(CEREAL_NVP(kernelHeight));
121  ar(CEREAL_NVP(strideWidth));
122  ar(CEREAL_NVP(strideHeight));
123  ar(CEREAL_NVP(batchSize));
124  ar(CEREAL_NVP(floor));
125  ar(CEREAL_NVP(inputWidth));
126  ar(CEREAL_NVP(inputHeight));
127  ar(CEREAL_NVP(outputWidth));
128  ar(CEREAL_NVP(outputHeight));
129 }
130 
131 } // namespace ann
132 } // namespace mlpack
133 
134 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
MeanPooling()
Create the MeanPooling object.
Definition: mean_pooling_impl.hpp:23
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: mean_pooling_impl.hpp:115
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: mean_pooling_impl.hpp:55
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary backward pass of a neural network, using 3rd-order tensors as input, computing the gradient g by propagating the error gy backwards through f.
Definition: mean_pooling_impl.hpp:94