mlpack
max_pooling_impl.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_MAX_POOLING_IMPL_HPP
14 #define MLPACK_METHODS_ANN_LAYER_MAX_POOLING_IMPL_HPP
15 
16 // In case it hasn't yet been included.
17 #include "max_pooling.hpp"
18 
19 namespace mlpack {
20 namespace ann {
21 
// Default (no-argument) constructor; its body performs no work.
// NOTE(review): the signature line (doxygen source line 23) is missing from
// this rendering -- presumably
// "MaxPooling<InputDataType, OutputDataType>::MaxPooling()"; confirm.
// Which members receive default values cannot be seen from here; check the
// member declarations in max_pooling.hpp.
22 template<typename InputDataType, typename OutputDataType>
24 {
25  // Nothing to do here.
26 }
27 
// Construct a max pooling layer with the given pooling window and stride.
//
// kernelWidth, kernelHeight: size of the pooling window.
// strideWidth, strideHeight: step between consecutive windows.
// floor: if true, Forward() rounds the output size down (std::floor);
//   otherwise it rounds it up (std::ceil).
//
// All size bookkeeping members (inSize, outSize, input/output width and
// height, batchSize) start at 0; the reset and deterministic flags start
// false.
// NOTE(review): inputWidth and inputHeight stay 0 here, yet Forward() divides
// by inputWidth * inputHeight -- presumably they are set (e.g. by the owning
// network) before the first Forward() call; confirm.
// NOTE(review): the opening line of the signature (doxygen source line 29)
// is missing from this rendering.
28 template<typename InputDataType, typename OutputDataType>
30  const size_t kernelWidth,
31  const size_t kernelHeight,
32  const size_t strideWidth,
33  const size_t strideHeight,
34  const bool floor) :
35  kernelWidth(kernelWidth),
36  kernelHeight(kernelHeight),
37  strideWidth(strideWidth),
38  strideHeight(strideHeight),
39  floor(floor),
40  inSize(0),
41  outSize(0),
42  reset(false),
43  inputWidth(0),
44  inputHeight(0),
45  outputWidth(0),
46  outputHeight(0),
47  deterministic(false),
48  batchSize(0)
49 {
50  // Nothing to do here.
51 }
52 
// Forward pass: max-pool every inputWidth x inputHeight map of the input.
//
// input: one column per batch item; each column holds inSize stacked
//   inputWidth x inputHeight maps (inSize is derived from the element count).
// output: one column per batch item, holding that item's pooled maps.
//
// Side effects: sets batchSize, inSize, outputWidth, outputHeight and
// outSize, and rebinds inputTemp/outputTemp. In non-deterministic mode it
// also pushes one cube onto poolingIndices, which Backward() later consumes.
// NOTE(review): the signature line (doxygen source line 55) is missing from
// this rendering -- presumably
// "void MaxPooling<InputDataType, OutputDataType>::Forward("; confirm.
53 template<typename InputDataType, typename OutputDataType>
54 template<typename eT>
56  const arma::Mat<eT>& input, arma::Mat<eT>& output)
57 {
58  batchSize = input.n_cols;
  // Number of maps (channels) per batch item.
  // NOTE(review): divides by zero if inputWidth or inputHeight is still 0.
59  inSize = input.n_elem / (inputWidth * inputHeight * batchSize);
  // View the input memory as a cube without copying (copy_aux_mem = false).
  // NOTE(review): the const_cast makes inputTemp alias the caller's input.
  // arma::cube is the double-typed cube, so this looks like it only compiles
  // for eT == double; confirm.
60  inputTemp = arma::cube(const_cast<arma::Mat<eT>&>(input).memptr(),
61  inputWidth, inputHeight, batchSize * inSize, false, false);
62 
  // Output size per dimension is (in - kernel) / stride + 1, rounded down
  // when `floor` is set and up otherwise.
  // NOTE(review): if the input is smaller than the kernel, the rounded value
  // is negative before being stored; if the output size members are unsigned
  // (presumably size_t), that conversion is invalid. Confirm callers prevent
  // this.
63  if (floor)
64  {
65  outputWidth = std::floor((inputWidth -
66  (double) kernelWidth) / (double) strideWidth + 1);
67  outputHeight = std::floor((inputHeight -
68  (double) kernelHeight) / (double) strideHeight + 1);
69  }
70  else
71  {
72  outputWidth = std::ceil((inputWidth -
73  (double) kernelWidth) / (double) strideWidth + 1);
74  outputHeight = std::ceil((inputHeight -
75  (double) kernelHeight) / (double) strideHeight + 1);
76  }
77 
78  outputTemp = arma::zeros<arma::Cube<eT> >(outputWidth, outputHeight,
79  batchSize * inSize);
80 
  // In non-deterministic (training) mode, save a zero cube shaped like the
  // output. It is passed to PoolingOperation below (presumably filled with
  // the selected max positions) and later read by Backward()'s Unpooling.
81  if (!deterministic)
82  {
83  poolingIndices.push_back(outputTemp);
84  }
85 
  // On the first call only, build an inputWidth x inputHeight matrix holding
  // the linear element indices 0 .. inputWidth * inputHeight - 1.
  // NOTE(review): `indices` is not rebuilt if the input dimensions change
  // after the first call.
86  if (!reset)
87  {
88  size_t elements = inputWidth * inputHeight;
89  indicesCol = arma::linspace<arma::Col<size_t> >(0, (elements - 1),
90  elements);
91 
92  indices = arma::Mat<size_t>(indicesCol.memptr(), inputWidth, inputHeight);
93 
94  reset = true;
95  }
96 
  // Pool each slice (one map of one batch item) independently.
  // NOTE(review): in deterministic mode inputTemp.slice(s) is passed as the
  // indices argument, presumably as an unused placeholder. Since inputTemp
  // aliases the caller's input, confirm PoolingOperation never writes to that
  // argument in deterministic mode.
97  for (size_t s = 0; s < inputTemp.n_slices; s++)
98  {
99  if (!deterministic)
100  {
101  PoolingOperation(inputTemp.slice(s), outputTemp.slice(s),
102  poolingIndices.back().slice(s));
103  }
104  else
105  {
106  PoolingOperation(inputTemp.slice(s), outputTemp.slice(s),
107  inputTemp.slice(s));
108  }
109  }
110 
  // Reshape the pooled cube into one column per batch item. No copy_aux_mem
  // argument is passed, so Armadillo's default for this constructor applies.
111  output = arma::Mat<eT>(outputTemp.memptr(), outputTemp.n_elem / batchSize,
112  batchSize);
113 
  // Record the actual output geometry. outSize is the total number of pooled
  // slices, which Backward() uses to view gy as a cube.
114  outputWidth = outputTemp.n_rows;
115  outputHeight = outputTemp.n_cols;
116  outSize = batchSize * inSize;
117 }
118 
// Backward pass: map the output error gy back onto the input positions,
// using the indices saved by the most recent non-deterministic Forward().
//
// gy: error w.r.t. the layer output, viewed as an outputWidth x outputHeight
//   x outSize cube (the layout Forward() produced).
// g: receives the error w.r.t. the layer input, one column per batch item.
// The input argument is unused.
//
// Side effect: removes the last entry of poolingIndices.
// NOTE(review): if no non-deterministic Forward() preceded this call,
// poolingIndices may be empty. Calling back()/pop_back() on an empty
// container (presumably a std::vector) is invalid.
// NOTE(review): arma::cube / arma::mat (double types) are used regardless of
// eT, and the C-style cast drops const from gy.
// NOTE(review): the signature line (doxygen source line 121) is missing from
// this rendering.
119 template<typename InputDataType, typename OutputDataType>
120 template<typename eT>
122  const arma::Mat<eT>& /* input */, const arma::Mat<eT>& gy, arma::Mat<eT>& g)
123 {
  // View gy's memory as a cube without copying (copy_aux_mem = false).
124  arma::cube mappedError = arma::cube(((arma::Mat<eT>&) gy).memptr(),
125  outputWidth, outputHeight, outSize, false, false);
126 
  // Zero gradient shaped like inputTemp from the most recent Forward().
127  gTemp = arma::zeros<arma::cube>(inputTemp.n_rows,
128  inputTemp.n_cols, inputTemp.n_slices);
129 
  // For each slice, Unpooling presumably scatters the error values to the
  // positions recorded in the last saved indices cube.
130  for (size_t s = 0; s < mappedError.n_slices; s++)
131  {
132  Unpooling(mappedError.slice(s), gTemp.slice(s),
133  poolingIndices.back().slice(s));
134  }
135 
  // The saved indices have been used; discard them.
136  poolingIndices.pop_back();
137 
  // Reshape to one column per batch item (batchSize from the last Forward()).
138  g = arma::mat(gTemp.memptr(), gTemp.n_elem / batchSize, batchSize);
139 }
140 
// Serialize the layer's configuration and geometry: kernel size, stride,
// batchSize, the floor flag, and the input/output dimensions.
// The version argument is ignored.
// NOTE(review): inSize, outSize, reset, deterministic, indices and
// poolingIndices are not serialized. Forward() recomputes inSize and outSize,
// but reset keeps whatever value the object being loaded into already has.
// If that value is true and the loaded input size differs, the cached
// `indices` will be stale. Confirm this is intended.
// NOTE(review): the signature line (doxygen source line 143) is missing from
// this rendering.
141 template<typename InputDataType, typename OutputDataType>
142 template<typename Archive>
144  Archive& ar,
145  const uint32_t /* version */)
146 {
147  ar(CEREAL_NVP(kernelWidth));
148  ar(CEREAL_NVP(kernelHeight));
149  ar(CEREAL_NVP(strideWidth));
150  ar(CEREAL_NVP(strideHeight));
151  ar(CEREAL_NVP(batchSize));
152  ar(CEREAL_NVP(floor));
153  ar(CEREAL_NVP(inputWidth));
154  ar(CEREAL_NVP(inputHeight));
155  ar(CEREAL_NVP(outputWidth));
156  ar(CEREAL_NVP(outputHeight));
157 }
158 
159 } // namespace ann
160 } // namespace mlpack
161 
162 #endif
MaxPooling()
Create the MaxPooling object.
Definition: max_pooling_impl.hpp:23
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: max_pooling_impl.hpp:143
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, using 3rd-order tensors as input: computes the error with respect to the layer input by propagating the output error gy backwards through f.
Definition: max_pooling_impl.hpp:121
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: max_pooling_impl.hpp:55