mlpack
nearest_interpolation_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_LAYER_NEAREST_INTERPOLATION_IMPL_HPP
13 #define MLPACK_METHODS_ANN_LAYER_NEAREST_INTERPOLATION_IMPL_HPP
14 
15 // In case it hasn't yet been included.
17 
18 namespace mlpack {
19 namespace ann {
20 
21 
22 template<typename InputDataType, typename OutputDataType>
25  inRowSize(0),
26  inColSize(0),
27  outRowSize(0),
28  outColSize(0),
29  depth(0),
30  batchSize(0)
31 {
32  // Nothing to do here.
33 }
34 
35 template<typename InputDataType, typename OutputDataType>
37 NearestInterpolation(const size_t inRowSize,
38  const size_t inColSize,
39  const size_t outRowSize,
40  const size_t outColSize,
41  const size_t depth) :
42  inRowSize(inRowSize),
43  inColSize(inColSize),
44  outRowSize(outRowSize),
45  outColSize(outColSize),
46  depth(depth),
47  batchSize(0)
48 {
49  // Nothing to do here.
50 }
51 
52 template<typename InputDataType, typename OutputDataType>
53 template<typename eT>
55  const arma::Mat<eT>& input, arma::Mat<eT>& output)
56 {
57  batchSize = input.n_cols;
58  if (output.is_empty())
59  output.set_size(outRowSize * outColSize * depth, batchSize);
60  else
61  {
62  assert(output.n_rows == outRowSize * outColSize * depth);
63  assert(output.n_cols == batchSize);
64  }
65 
66  assert(inRowSize >= 2);
67  assert(inColSize >= 2);
68 
69  arma::cube inputAsCube(const_cast<arma::Mat<eT>&>(input).memptr(),
70  inRowSize, inColSize, depth * batchSize, false, false);
71  arma::cube outputAsCube(output.memptr(), outRowSize, outColSize,
72  depth * batchSize, false, true);
73 
74  double scaleRow = (double) inRowSize / (double) outRowSize;
75  double scaleCol = (double) inColSize / (double) outColSize;
76 
77  for (size_t i = 0; i < outRowSize; ++i)
78  {
79  const size_t rOrigin = std::floor(i * scaleRow);
80 
81  for (size_t j = 0; j < outColSize; ++j)
82  {
83  const size_t cOrigin = std::floor(j * scaleCol);
84 
85  for (size_t k = 0; k < depth * batchSize; ++k)
86  {
87  outputAsCube(i, j, k) = inputAsCube.slice(k)(
88  rOrigin, cOrigin);
89  }
90  }
91  }
92 }
93 
94 template<typename InputDataType, typename OutputDataType>
95 template<typename eT>
97  const arma::Mat<eT>& /*input*/,
98  const arma::Mat<eT>& gradient,
99  arma::Mat<eT>& output)
100 {
101  if (output.is_empty())
102  {
103  output.zeros(inRowSize * inColSize * depth, batchSize);
104  }
105  else
106  {
107  assert(output.n_rows == inRowSize * inColSize * depth);
108  assert(output.n_cols == batchSize);
109  }
110 
111  assert(outRowSize >= 2);
112  assert(outColSize >= 2);
113 
114  arma::cube outputAsCube(output.memptr(), inRowSize, inColSize,
115  depth * batchSize, false, true);
116  arma::cube gradientAsCube(((arma::Mat<eT>&) gradient).memptr(), outRowSize,
117  outColSize, depth * batchSize, false, false);
118 
119  double scaleRow = (double)(inRowSize) / outRowSize;
120  double scaleCol = (double)(inColSize) / outColSize;
121 
122  if (gradient.n_elem == output.n_elem)
123  {
124  outputAsCube = gradientAsCube;
125  }
126  else
127  {
128  for (size_t i = 0; i < outRowSize; ++i)
129  {
130  const size_t rOrigin = std::floor(i * scaleRow);
131 
132  for (size_t j = 0; j < outColSize; ++j)
133  {
134  const size_t cOrigin = std::floor(j * scaleCol);
135 
136  for (size_t k = 0; k < depth * batchSize; ++k)
137  {
138  outputAsCube(rOrigin, cOrigin, k) +=
139  gradientAsCube(i, j, k);
140  }
141  }
142  }
143  }
144 }
145 
146 template<typename InputDataType, typename OutputDataType>
147 template<typename Archive>
149  Archive& ar, const uint32_t /* version */)
150 {
151  ar(CEREAL_NVP(inRowSize));
152  ar(CEREAL_NVP(inColSize));
153  ar(CEREAL_NVP(outRowSize));
154  ar(CEREAL_NVP(outColSize));
155  ar(CEREAL_NVP(depth));
156 }
157 
158 } // namespace ann
159 } // namespace mlpack
160 
161 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Forward pass through the layer.
Definition: nearest_interpolation_impl.hpp:54
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: nearest_interpolation_impl.hpp:148
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gradient, arma::Mat< eT > &output)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
Definition: nearest_interpolation_impl.hpp:96
NearestInterpolation()
Create the NearestInterpolation object.
Definition: nearest_interpolation_impl.hpp:24