Expression Templates Library (ETL)
max_pooling_derivative.hpp
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
8 #pragma once
9 
10 namespace etl::impl {
11 
12 // TODO Optimize max pool derivative like max pooling upsampling was optimized
13 
28  template <size_t C1, size_t C2, size_t S1, size_t S2, size_t P1, size_t P2, typename A, typename B, typename M>
29  static void pool_derivative_block(const A& in, const B& out, M& m, size_t j, size_t k) {
30  auto max = out(j, k);
31 
32  // Slow path for cells with padding
33  if constexpr (P1 || P2) {
34  if (cpp_unlikely(j < P1 || k < P2 || j >= etl::dim<0>(out) - P1 || k >= etl::dim<1>(out) - P2)) {
35  const size_t base_j = j * S1 - P1;
36  const size_t base_k = k * S2 - P2;
37 
38  for (size_t jj = 0; jj < C1; ++jj) {
39  for (size_t kk = 0; kk < C2; ++kk) {
40  if (base_j + jj < etl::dim<0>(m) && base_k + kk < etl::dim<1>(m)) {
41  if (max == in(base_j + jj, base_k + kk)) {
42  m(base_j + jj, base_k + kk) = 1.0;
43 
44  if constexpr (cudnn_compatible) {
45  return;
46  }
47  } else {
48  m(base_j + jj, base_k + kk) = 0.0;
49  }
50  }
51  }
52  }
53 
54  return;
55  }
56  }
57 
58  for (size_t jj = 0; jj < C1; ++jj) {
59  for (size_t kk = 0; kk < C2; ++kk) {
60  const size_t final_j = j * S1 - P1 + jj;
61  const size_t final_k = k * S2 - P2 + kk;
62 
63  if constexpr (C1 == S1 && C2 == S2) {
64  if (max == in(final_j, final_k)) {
65  m(final_j, final_k) = 1.0;
66 
67  if constexpr (cudnn_compatible) {
68  return;
69  }
70  } else {
71  m(final_j, final_k) = 0.0;
72  }
73  } else {
74  if (max == in(final_j, final_k)) {
75  if constexpr (cudnn_compatible) {
76  m(final_j, final_k) = 1.0;
77 
78  return;
79  } else {
80  m(final_j, final_k) += 1.0;
81  }
82  }
83  }
84  }
85  }
86  }
87 
96  template <size_t C1, size_t C2, size_t C3, size_t S1, size_t S2, size_t S3, size_t P1, size_t P2, size_t P3, etl_2d A, typename B, typename M>
97  static void apply(A&& in, B&& out, M&& m) {
98  in.ensure_cpu_up_to_date();
99  out.ensure_cpu_up_to_date();
100 
101  if constexpr (C1 != S1 || C2 != S2) {
102  m = 0;
103  }
104 
105  for (size_t j = 0; j < etl::dim<0>(out); ++j) {
106  for (size_t k = 0; k < etl::dim<1>(out); ++k) {
107  pool_derivative_block<C1, C2, S1, S2, P1, P2>(in, out, m, j, k);
108  }
109  }
110 
111  m.invalidate_gpu();
112  m.validate_cpu();
113  }
114 
125  template <typename A, typename B, typename M>
126  static void pool_derivative_block(const A& in, const B& out, M& m, size_t j, size_t k, size_t c1, size_t c2, size_t s1, size_t s2, size_t p1, size_t p2) {
127  auto max = out(j, k);
128 
129  // Slow path for cells with padding
130  if (cpp_unlikely(p1 || p2)) {
131  if (cpp_unlikely(j < p1 || k < p2 || j >= etl::dim<0>(out) - p1 || k >= etl::dim<1>(out) - p2)) {
132  const size_t base_j = j * s1 - p1;
133  const size_t base_k = k * s2 - p2;
134 
135  for (size_t jj = 0; jj < c1; ++jj) {
136  for (size_t kk = 0; kk < c2; ++kk) {
137  if (base_j + jj < etl::dim<0>(m) && base_k + kk < etl::dim<1>(m)) {
138 
139  if (max == in(base_j + jj, base_k + kk)) {
140  m(base_j + jj, base_k + kk) = 1.0;
141 
142  if constexpr (cudnn_compatible) {
143  return;
144  }
145  } else {
146  m(base_j + jj, base_k + kk) = 0.0;
147  }
148  }
149  }
150  }
151 
152  return;
153  }
154  }
155 
156  if (c1 == s1 && c2 == s2) {
157  for (size_t jj = 0; jj < c1; ++jj) {
158  for (size_t kk = 0; kk < c2; ++kk) {
159  const size_t final_j = j * s1 - p1 + jj;
160  const size_t final_k = k * s2 - p2 + kk;
161 
162  if (max == in(final_j, final_k)) {
163  m(final_j, final_k) = 1.0;
164 
165  if constexpr (cudnn_compatible) {
166  return;
167  }
168  } else {
169  m(final_j, final_k) = 0.0;
170  }
171  }
172  }
173  } else {
174  for (size_t jj = 0; jj < c1; ++jj) {
175  for (size_t kk = 0; kk < c2; ++kk) {
176  const size_t final_j = j * s1 - p1 + jj;
177  const size_t final_k = k * s2 - p2 + kk;
178 
179  if (max == in(final_j, final_k)) {
180  if constexpr (cudnn_compatible) {
181  m(final_j, final_k) = 1.0;
182 
183  return;
184  } else {
185  m(final_j, final_k) += 1.0;
186  }
187  }
188  }
189  }
190  }
191  }
192 
201  template <etl_2d A, typename B, typename M>
202  static void apply(A&& in,
203  B&& out,
204  M&& m,
205  size_t c1,
206  size_t c2,
207  [[maybe_unused]] size_t c3,
208  size_t s1,
209  size_t s2,
210  [[maybe_unused]] size_t s3,
211  size_t p1,
212  size_t p2,
213  [[maybe_unused]] size_t p3) {
214  in.ensure_cpu_up_to_date();
215  out.ensure_cpu_up_to_date();
216 
217  if (c1 != s1 || c2 != s2) {
218  m = 0;
219  }
220 
221  for (size_t j = 0; j < etl::dim<0>(out); ++j) {
222  for (size_t k = 0; k < etl::dim<1>(out); ++k) {
223  pool_derivative_block(in, out, m, j, k, c1, c2, s1, s2, p1, p2);
224  }
225  }
226 
227  m.invalidate_gpu();
228  m.validate_cpu();
229  }
230 
231  // Deep handling
232 
241  template <size_t C1, size_t C2, size_t C3, size_t S1, size_t S2, size_t S3, size_t P1, size_t P2, size_t P3, deep_mat A, typename B, typename M>
242  static void apply(A&& in, B&& out, M& m) {
243  in.ensure_cpu_up_to_date();
244  out.ensure_cpu_up_to_date();
245 
246  for (size_t i = 0; i < etl::dim<0>(in); ++i) {
247  apply<C1, C2, C3, S1, S2, S3, P1, P2, P3>(in(i), out(i), m(i));
248  }
249 
250  m.invalidate_gpu();
251  m.validate_cpu();
252  }
253 
262  template <deep_mat A, typename B, typename M>
263  static void apply(A&& in, B&& out, M& m, size_t c1, size_t c2, size_t c3, size_t s1, size_t s2, size_t s3, size_t p1, size_t p2, size_t p3) {
264  in.ensure_cpu_up_to_date();
265  out.ensure_cpu_up_to_date();
266 
267  for (size_t i = 0; i < etl::dim<0>(in); ++i) {
268  apply(in(i), out(i), m(i), c1, c2, c3, s1, s2, s3, p1, p2, p3);
269  }
270 
271  m.invalidate_gpu();
272  m.validate_cpu();
273  }
274 };
275 
/*!
 * \brief Compute the 3D max pooling derivative for the output cell (i, j, k).
 *
 * Every cell of the C1 x C2 x C3 input window starting at
 * (i * C1, j * C2, k * C3) is set to 1 in m if it equals the pooled value
 * out(i, j, k), and to 0 otherwise. All ties are marked.
 *
 * NOTE(review): the window origin is derived from the window size, and the
 * strides (S1, S2, S3) and paddings (P1, P2, P3) are not used. This only
 * matches a forward pass with S == C and P == 0 — confirm.
 *
 * \param in The input of the pooling
 * \param out The pooled output
 * \param m The derivative, indexed with input coordinates
 * \param i The first output index
 * \param j The second output index
 * \param k The third output index
 */
template <size_t C1, size_t C2, size_t C3, size_t S1, size_t S2, size_t S3, size_t P1, size_t P2, size_t P3, typename A, typename B, typename M>
static void pool_derivative_block(const A& in, const B& out, M& m, size_t i, size_t j, size_t k) {
    const auto pooled = out(i, j, k);

    const size_t origin_i = i * C1;
    const size_t origin_j = j * C2;
    const size_t origin_k = k * C3;

    for (size_t a = 0; a < C1; ++a) {
        const size_t y = origin_i + a;

        for (size_t b = 0; b < C2; ++b) {
            const size_t x = origin_j + b;

            for (size_t c = 0; c < C3; ++c) {
                const size_t z = origin_k + c;

                m(y, x, z) = (in(y, x, z) == pooled) ? 1.0 : 0.0;
            }
        }
    }
}
308 
317  template <size_t C1, size_t C2, size_t C3, size_t S1, size_t S2, size_t S3, size_t P1, size_t P2, size_t P3, etl_3d A, typename B, typename M>
318  static void apply(A&& in, B&& out, M&& m) {
319  in.ensure_cpu_up_to_date();
320  out.ensure_cpu_up_to_date();
321 
322  for (size_t i = 0; i < etl::dim<0>(out); ++i) {
323  for (size_t j = 0; j < etl::dim<1>(out); ++j) {
324  for (size_t k = 0; k < etl::dim<2>(out); ++k) {
325  pool_derivative_block<C1, C2, C3, S1, S2, S3, P1, P2, P3>(in, out, m, i, j, k);
326  }
327  }
328  }
329 
330  m.invalidate_gpu();
331  m.validate_cpu();
332  }
333 
/*!
 * \brief Compute the 3D max pooling derivative for the output cell (i, j, k), with runtime window sizes.
 *
 * Every cell of the c1 x c2 x c3 input window starting at
 * (i * c1, j * c2, k * c3) is set to 1 in m if it equals the pooled value
 * out(i, j, k), and to 0 otherwise. All ties are marked.
 *
 * NOTE(review): the strides (s1, s2, s3) and paddings (p1, p2, p3) are not
 * used, and the window origin is derived from the window size. This only
 * matches a forward pass with s == c and p == 0 — confirm.
 *
 * \param in The input of the pooling
 * \param out The pooled output
 * \param m The derivative, indexed with input coordinates
 * \param i The first output index
 * \param j The second output index
 * \param k The third output index
 * \param c1 The first dimension of the pooling window
 * \param c2 The second dimension of the pooling window
 * \param c3 The third dimension of the pooling window
 */
template <typename A, typename B, typename M>
static void pool_derivative_block(const A& in,
                                  const B& out,
                                  M& m,
                                  size_t i,
                                  size_t j,
                                  size_t k,
                                  size_t c1,
                                  size_t c2,
                                  size_t c3,
                                  [[maybe_unused]] size_t s1,
                                  [[maybe_unused]] size_t s2,
                                  [[maybe_unused]] size_t s3,
                                  [[maybe_unused]] size_t p1,
                                  [[maybe_unused]] size_t p2,
                                  [[maybe_unused]] size_t p3) {
    const auto pooled = out(i, j, k);

    const size_t origin_i = i * c1;
    const size_t origin_j = j * c2;
    const size_t origin_k = k * c3;

    for (size_t a = 0; a < c1; ++a) {
        const size_t y = origin_i + a;

        for (size_t b = 0; b < c2; ++b) {
            const size_t x = origin_j + b;

            for (size_t c = 0; c < c3; ++c) {
                const size_t z = origin_k + c;

                m(y, x, z) = (in(y, x, z) == pooled) ? 1.0 : 0.0;
            }
        }
    }
}
376 
385  template <etl_3d A, typename B, typename M>
386  static void apply(A&& in, B&& out, M&& m, size_t c1, size_t c2, size_t c3, size_t s1, size_t s2, size_t s3, size_t p1, size_t p2, size_t p3) {
387  in.ensure_cpu_up_to_date();
388  out.ensure_cpu_up_to_date();
389 
390  for (size_t i = 0; i < etl::dim<0>(out); ++i) {
391  for (size_t j = 0; j < etl::dim<1>(out); ++j) {
392  for (size_t k = 0; k < etl::dim<2>(out); ++k) {
393  pool_derivative_block(in, out, m, i, j, k, c1, c2, c3, s1, s2, s3, p1, p2, p3);
394  }
395  }
396  }
397 
398  m.invalidate_gpu();
399  m.validate_cpu();
400  }
401 
402  // Deep handling
403 
412  template <size_t C1, size_t C2, size_t C3, size_t S1, size_t S2, size_t S3, size_t P1, size_t P2, size_t P3, etl_4d_and_plus A, typename B, typename M>
413  static void apply(A&& in, B&& out, M& m) {
414  in.ensure_cpu_up_to_date();
415  out.ensure_cpu_up_to_date();
416 
417  for (size_t i = 0; i < etl::dim<0>(in); ++i) {
418  apply<C1, C2, C3, S1, S2, S3, P1, P2, P3>(in(i), out(i), m(i));
419  }
420 
421  m.invalidate_gpu();
422  m.validate_cpu();
423  }
424 
433  template <etl_4d_and_plus A, typename B, typename M>
434  static void apply(
435  A&& in, B&& out, M& m, size_t c1, size_t c2, size_t c3, size_t s1, size_t s2, size_t s3, size_t p1, size_t p2, size_t p3) {
436  in.ensure_cpu_up_to_date();
437  out.ensure_cpu_up_to_date();
438 
439  for (size_t i = 0; i < etl::dim<0>(in); ++i) {
440  apply(in(i), out(i), m(i), c1, c2, c3, s1, s2, s3, p1, p2, p3);
441  }
442 
443  m.invalidate_gpu();
444  m.validate_cpu();
445  }
446 };
447 
448 } //end of namespace etl::impl
static void apply(A &&in, B &&out, M &m)
Apply the functor on sub and store the result in m.
Definition: max_pooling_derivative.hpp:242
static void pool_derivative_block(const A &in, const B &out, M &m, size_t j, size_t k, size_t c1, size_t c2, size_t s1, size_t s2, size_t p1, size_t p2)
Pool a block of the sub expression.
Definition: max_pooling_derivative.hpp:126
static void apply(A &&in, B &&out, M &m, size_t c1, size_t c2, size_t c3, size_t s1, size_t s2, size_t s3, size_t p1, size_t p2, size_t p3)
Apply the functor on sub and store the result in m.
Definition: max_pooling_derivative.hpp:263
auto max(L &&lhs, R &&rhs)
Create an expression with the max value of lhs or rhs.
Definition: expression_builder.hpp:65
static void apply(A &&in, B &&out, M &m)
Apply the functor on sub and store the result in m.
Definition: max_pooling_derivative.hpp:413
static void apply(A &&in, B &&out, M &&m)
Apply the functor on sub and store the result in m.
Definition: max_pooling_derivative.hpp:97
Functor for the derivative of 3D Max Pooling.
Definition: max_pooling_derivative.hpp:279
static void apply(A &&in, B &&out, M &&m, size_t c1, size_t c2, size_t c3, size_t s1, size_t s2, size_t s3, size_t p1, size_t p2, size_t p3)
Apply the functor on sub and store the result in m.
Definition: max_pooling_derivative.hpp:386
static void apply(A &&in, B &&out, M &m, size_t c1, size_t c2, size_t c3, size_t s1, size_t s2, size_t s3, size_t p1, size_t p2, size_t p3)
Apply the functor on sub and store the result in m.
Definition: max_pooling_derivative.hpp:434
static void apply(A &&in, B &&out, M &&m)
Apply the functor on sub and store the result in m.
Definition: max_pooling_derivative.hpp:318
constexpr bool cudnn_compatible
Indicates if ETL is trying to generate results similar to CUDNN (default).
Definition: config.hpp:163
static void pool_derivative_block(const A &in, const B &out, M &m, size_t i, size_t j, size_t k, size_t c1, size_t c2, size_t c3, [[maybe_unused]] size_t s1, [[maybe_unused]] size_t s2, [[maybe_unused]] size_t s3, [[maybe_unused]] size_t p1, [[maybe_unused]] size_t p2, [[maybe_unused]] size_t p3)
Pool a block of the sub expression.
Definition: max_pooling_derivative.hpp:347
static void apply(A &&in, B &&out, M &&m, size_t c1, size_t c2, [[maybe_unused]] size_t c3, size_t s1, size_t s2, [[maybe_unused]] size_t s3, size_t p1, size_t p2, [[maybe_unused]] size_t p3)
Apply the functor on sub and store the result in m.
Definition: max_pooling_derivative.hpp:202
Functor for the derivative of 2D Max Pooling.
Definition: max_pooling_derivative.hpp:17
static void pool_derivative_block(const A &in, const B &out, M &m, size_t j, size_t k)
Pool a block of the sub expression.
Definition: max_pooling_derivative.hpp:29
static void pool_derivative_block(const A &in, const B &out, M &m, size_t i, size_t j, size_t k)
Pool a block of the sub expression.
Definition: max_pooling_derivative.hpp:293
Definition: avg_pooling_derivative.hpp:10