Expression Templates Library (ETL)
max_pooling_upsample.hpp
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
8 #pragma once
9 
10 namespace etl::impl::standard {
11 
26  template <size_t C1, size_t C2, size_t S1, size_t S2, size_t P1, size_t P2, typename A, typename B, typename C, typename M>
27  static void pool_block_2d(const A& in, const B& out, const C& errors, M& m, size_t i, size_t j) {
28  auto max = out(i, j);
29  auto error = errors(i, j);
30 
31  // Slow path for cells with padding
32  if constexpr (P1 || P2) {
33  if (cpp_unlikely(i < P1 || j < P2 || i >= etl::dim<0>(out) - P1 || j >= etl::dim<1>(out) - P2)) {
34  const size_t base_i = i * S1 - P1;
35  const size_t base_j = j * S2 - P2;
36 
37  for (size_t ii = 0; ii < C1; ++ii) {
38  for (size_t jj = 0; jj < C2; ++jj) {
39  if (base_i + ii < etl::dim<0>(m) && base_j + jj < etl::dim<1>(m)) {
40  if constexpr (S1 == C1 && S2 == C2) {
41  if (max == in(base_i + ii, base_j + jj)) {
42  m(base_i + ii, base_j + jj) = error;
43  } else {
44  m(base_i + ii, base_j + jj) = 0.0;
45  }
46  } else {
47  if (max == in(base_i + ii, base_j + jj)) {
48  m(base_i + ii, base_j + jj) += error;
49  }
50  }
51  }
52  }
53  }
54 
55  return;
56  }
57  }
58 
59  if constexpr (S1 == C1 && S2 == C2) {
60  for (size_t ii = 0; ii < C1; ++ii) {
61  for (size_t jj = 0; jj < C2; ++jj) {
62  if (max == in(i * S1 - P1 + ii, j * S2 - P2 + jj)) {
63  m(i * S1 - P1 + ii, j * S2 - P2 + jj) = error;
64  } else {
65  m(i * S1 - P1 + ii, j * S2 - P2 + jj) = 0.0;
66  }
67  }
68  }
69  } else {
70  for (size_t ii = 0; ii < C1; ++ii) {
71  for (size_t jj = 0; jj < C2; ++jj) {
72  if (max == in(i * S1 - P1 + ii, j * S2 - P2 + jj)) {
73  m(i * S1 - P1 + ii, j * S2 - P2 + jj) += error;
74  }
75  }
76  }
77  }
78  }
79 
90  template <size_t C1, size_t C2, size_t S1, size_t S2, size_t P1, size_t P2, typename A, typename B, typename C, typename M>
91  static void pool_block_3d(const A& in, const B& out, const C& errors, M& m, size_t q, size_t i, size_t j) {
92  auto max = out(q, i, j);
93  auto error = errors(q, i, j);
94 
95  // Slow path for cells with padding
96  if constexpr (P1 || P2) {
97  if (cpp_unlikely(i < P1 || j < P2 || i >= etl::dim<1>(out) - P1 || j >= etl::dim<2>(out) - P2)) {
98  const size_t base_i = i * S1 - P1;
99  const size_t base_j = j * S2 - P2;
100 
101  for (size_t ii = 0; ii < C1; ++ii) {
102  for (size_t jj = 0; jj < C2; ++jj) {
103  if (base_i + ii < etl::dim<1>(m) && base_j + jj < etl::dim<2>(m)) {
104  if constexpr (S1 == C1 && S2 == C2) {
105  if (max == in(q, base_i + ii, base_j + jj)) {
106  m(q, base_i + ii, base_j + jj) = error;
107  } else {
108  m(q, base_i + ii, base_j + jj) = 0.0;
109  }
110  } else {
111  if (max == in(q, base_i + ii, base_j + jj)) {
112  m(q, base_i + ii, base_j + jj) += error;
113  }
114  }
115  }
116  }
117  }
118 
119  return;
120  }
121  }
122 
123  if constexpr (S1 == C1 && S2 == C2) {
124  for (size_t ii = 0; ii < C1; ++ii) {
125  for (size_t jj = 0; jj < C2; ++jj) {
126  if (max == in(q, i * S1 - P1 + ii, j * S2 - P2 + jj)) {
127  m(q, i * S1 - P1 + ii, j * S2 - P2 + jj) = error;
128  } else {
129  m(q, i * S1 - P1 + ii, j * S2 - P2 + jj) = 0.0;
130  }
131  }
132  }
133  } else {
134  for (size_t ii = 0; ii < C1; ++ii) {
135  for (size_t jj = 0; jj < C2; ++jj) {
136  if (max == in(q, i * S1 -P1 + ii, j * S2 -P2 + jj)) {
137  m(q, i * S1 - P1 + ii, j * S2 - P2 + jj) += error;
138  }
139  }
140  }
141  }
142  }
143 
154  template <size_t C1, size_t C2, size_t S1, size_t S2, size_t P1, size_t P2, typename A, typename B, typename C, typename M>
155  static void pool_block_4d(const A& in, const B& out, const C& errors, M& m, size_t p, size_t q, size_t i, size_t j) {
156  auto max = out(p, q, i, j);
157  auto error = errors(p, q, i, j);
158 
159  // Slow path for cells with padding
160  if constexpr (P1 || P2) {
161  if (cpp_unlikely(i < P1 || j < P2 || i >= etl::dim<2>(out) - P1 || j >= etl::dim<3>(out) - P2)) {
162  const size_t base_i = i * S1 - P1;
163  const size_t base_j = j * S2 - P2;
164 
165  for (size_t ii = 0; ii < C1; ++ii) {
166  for (size_t jj = 0; jj < C2; ++jj) {
167  if (base_i + ii < etl::dim<2>(m) && base_j + jj < etl::dim<3>(m)) {
168  if constexpr (S1 == C1 && S2 == C2) {
169  if (max == in(p, q, base_i + ii, base_j + jj)) {
170  m(p, q, base_i + ii, base_j + jj) = error;
171  } else {
172  m(p, q, base_i + ii, base_j + jj) = 0.0;
173  }
174  } else {
175  if (max == in(p, q, base_i + ii, base_j + jj)) {
176  m(p, q, base_i + ii, base_j + jj) += error;
177  }
178  }
179  }
180  }
181  }
182 
183  return;
184  }
185  }
186 
187  if constexpr (S1 == C1 && S2 == C2) {
188  for (size_t ii = 0; ii < C1; ++ii) {
189  for (size_t jj = 0; jj < C2; ++jj) {
190  if (max == in(p, q, i * S1 -P1+ ii, j * S2 -P2 + jj)) {
191  m(p, q, i * S1 - P1 + ii, j * S2 - P2 + jj) = error;
192  } else {
193  m(p, q, i * S1 - P1 + ii, j * S2 - P2 + jj) = 0.0;
194  }
195  }
196  }
197  } else {
198  for (size_t ii = 0; ii < C1; ++ii) {
199  for (size_t jj = 0; jj < C2; ++jj) {
200  if (max == in(p, q, i * S1 - P1 + ii, j * S2 - P2 + jj)) {
201  m(p, q, i * S1 - P1 + ii, j * S2 - P2 + jj) += error;
202  }
203  }
204  }
205  }
206  }
207 
218  template <typename A, typename B, typename C, typename M>
219  static void pool_block_2d(const A& in, const B& out, const C& errors, M& m, size_t i, size_t j, size_t c1, size_t c2, size_t s1, size_t s2, size_t p1, size_t p2) {
220  auto max = out(i, j);
221  auto error = errors(i, j);
222 
223  // Slow path for cells with padding
224  if (cpp_unlikely(p1 || p2)) {
225  if (cpp_unlikely(i < p1 || j < p2 || i >= etl::dim<0>(out) - p1 || j >= etl::dim<1>(out) - p2)) {
226  const size_t base_i = i * s1 - p1;
227  const size_t base_j = j * s2 - p2;
228 
229  for (size_t ii = 0; ii < c1; ++ii) {
230  for (size_t jj = 0; jj < c2; ++jj) {
231  if (base_i + ii < etl::dim<0>(m) && base_j + jj < etl::dim<1>(m)) {
232  if (s1 == c1 && s2 == c2) {
233  if (max == in(base_i + ii, base_j + jj)) {
234  m(base_i + ii, base_j + jj) = error;
235  } else {
236  m(base_i + ii, base_j + jj) = 0.0;
237  }
238  } else {
239  if (max == in(base_i + ii, base_j + jj)) {
240  m(base_i + ii, base_j + jj) += error;
241  }
242  }
243  }
244  }
245  }
246 
247  return;
248  }
249  }
250 
251  if (s1 == c1 && s2 == c2) {
252  for (size_t ii = 0; ii < c1; ++ii) {
253  for (size_t jj = 0; jj < c2; ++jj) {
254  if (max == in(i * s1 - p1 + ii, j * s2 - p2 + jj)) {
255  m(i * s1 - p1 + ii, j * s2 - p2 + jj) = error;
256  } else {
257  m(i * s1 - p1 + ii, j * s2 - p2 + jj) = 0.0;
258  }
259  }
260  }
261  } else {
262  for (size_t ii = 0; ii < c1; ++ii) {
263  for (size_t jj = 0; jj < c2; ++jj) {
264  if (max == in(i * s1 - p1 + ii, j * s2 - p2 + jj)) {
265  m(i * s1 - p1 + ii, j * s2 - p2 + jj) += error;
266  }
267  }
268  }
269  }
270  }
271 
282  template <typename A, typename B, typename C, typename M>
283  static void pool_block_3d(const A& in, const B& out, const C& errors, M& m, size_t q, size_t i, size_t j, size_t c1, size_t c2, size_t s1, size_t s2, size_t p1, size_t p2) {
284  auto max = out(q, i, j);
285  auto error = errors(q, i, j);
286 
287  // Slow path for cells with padding
288  if (cpp_unlikely(p1 || p2)) {
289  if (cpp_unlikely(i < p1 || j < p2 || i >= etl::dim<1>(out) - p1 || j >= etl::dim<2>(out) - p2)) {
290  const size_t base_i = i * s1 - p1;
291  const size_t base_j = j * s2 - p2;
292 
293  for (size_t ii = 0; ii < c1; ++ii) {
294  for (size_t jj = 0; jj < c2; ++jj) {
295  if (base_i + ii < etl::dim<1>(m) && base_j + jj < etl::dim<2>(m)) {
296  if (s1 == c1 && s2 == c2) {
297  if (max == in(q, base_i + ii, base_j + jj)) {
298  m(q, base_i + ii, base_j + jj) = error;
299  } else {
300  m(q, base_i + ii, base_j + jj) = 0.0;
301  }
302  } else {
303  if (max == in(q, base_i + ii, base_j + jj)) {
304  m(q, base_i + ii, base_j + jj) += error;
305  }
306  }
307  }
308  }
309  }
310 
311  return;
312  }
313  }
314 
315  if (s1 == c1 && s2 == c2) {
316  for (size_t ii = 0; ii < c1; ++ii) {
317  for (size_t jj = 0; jj < c2; ++jj) {
318  if (max == in(q, i * s1 - p1 + ii, j * s2 - p2 + jj)) {
319  m(q, i * s1 - p1 + ii, j * s2 - p2 + jj) = error;
320  } else {
321  m(q, i * s1 - p1 + ii, j * s2 - p2 + jj) = 0.0;
322  }
323  }
324  }
325  } else {
326  for (size_t ii = 0; ii < c1; ++ii) {
327  for (size_t jj = 0; jj < c2; ++jj) {
328  if (max == in(q, i * s1 - p1 + ii, j * s2 - p2 + jj)) {
329  m(q, i * s1 - p1 + ii, j * s2 - p2 + jj) += error;
330  }
331  }
332  }
333  }
334  }
335 
346  template <typename A, typename B, typename C, typename M>
347  static void pool_block_4d(const A& in, const B& out, const C& errors, M& m, size_t p, size_t q, size_t i, size_t j, size_t c1, size_t c2, size_t s1, size_t s2, size_t p1, size_t p2) {
348  auto max = out(p, q, i, j);
349  auto error = errors(p, q, i, j);
350 
351  // Slow path for cells with padding
352  if (cpp_unlikely(p1 || p2)) {
353  if (cpp_unlikely(i < p1 || j < p2 || i >= etl::dim<2>(out) - p1 || j >= etl::dim<3>(out) - p2)) {
354  const size_t base_i = i * s1 - p1;
355  const size_t base_j = j * s2 - p2;
356 
357  for (size_t ii = 0; ii < c1; ++ii) {
358  for (size_t jj = 0; jj < c2; ++jj) {
359  if (base_i + ii < etl::dim<2>(m) && base_j + jj < etl::dim<3>(m)) {
360  if (s1 == c1 && s2 == c2) {
361  if (max == in(p, q, base_i + ii, base_j + jj)) {
362  m(p, q, base_i + ii, base_j + jj) = error;
363  } else {
364  m(p, q, base_i + ii, base_j + jj) = 0.0;
365  }
366  } else {
367  if (max == in(p, q, base_i + ii, base_j + jj)) {
368  m(p, q, base_i + ii, base_j + jj) += error;
369  }
370  }
371  }
372  }
373  }
374 
375  return;
376  }
377  }
378 
379  if (s1 == c1 && s2 == c2) {
380  for (size_t ii = 0; ii < c1; ++ii) {
381  for (size_t jj = 0; jj < c2; ++jj) {
382  if (max == in(p, q, i * s1 - p1 + ii, j * s2 - p2 + jj)) {
383  m(p, q, i * s1 - p1 + ii, j * s2 - p2 + jj) = error;
384  } else {
385  m(p, q, i * s1 - p1 + ii, j * s2 - p2 + jj) = 0.0;
386  }
387  }
388  }
389  } else {
390  for (size_t ii = 0; ii < c1; ++ii) {
391  for (size_t jj = 0; jj < c2; ++jj) {
392  if (max == in(p, q, i * s1 - p1 + ii, j * s2 - p2 + jj)) {
393  m(p, q, i * s1 - p1 + ii, j * s2 - p2 + jj) += error;
394  }
395  }
396  }
397  }
398  }
399 
400  // 2D Handling
401 
409  template <size_t C1, size_t C2, size_t S1, size_t S2, size_t P1, size_t P2, etl_2d A, typename B, typename C, typename M>
410  static void apply(A&& in, B&& out, C&& errors, M&& m) {
411  if constexpr (S1 != C1 || S2 != C2) {
412  m = 0;
413  }
414 
415  for (size_t i = 0; i < etl::dim<0>(out); ++i) {
416  for (size_t j = 0; j < etl::dim<1>(out); ++j) {
417  pool_block_2d<C1, C2, S1, S2, P1, P2>(in, out, errors, m, i, j);
418  }
419  }
420  }
421 
429  template <etl_2d A, typename B, typename C, typename M>
430  static void apply(A&& in, B&& out, C&& errors, M&& m, size_t c1, size_t c2, size_t s1, size_t s2, size_t p1, size_t p2) {
431  if (s1 != c1 || s2 != c2) {
432  m = 0;
433  }
434 
435  for (size_t i = 0; i < etl::dim<0>(out); ++i) {
436  for (size_t j = 0; j < etl::dim<1>(out); ++j) {
437  pool_block_2d(in, out, errors, m, i, j, c1, c2, s1, s2, p1, p2);
438  }
439  }
440  }
441 
442  // 3D Handling
443 
451  template <size_t C1, size_t C2, size_t S1, size_t S2, size_t P1, size_t P2, etl_3d A, typename B, typename C, typename M>
452  static void apply(A&& in, B&& out, C&& errors, M&& m) {
453  if (S1 != C1 || S2 != C2) {
454  m = 0;
455  }
456 
457  auto batch_fun = [&](const size_t first, const size_t last) {
458  for (size_t q = first; q < last; ++q) {
459  for (size_t i = 0; i < etl::dim<1>(out); ++i) {
460  for (size_t j = 0; j < etl::dim<2>(out); ++j) {
461  pool_block_3d<C1, C2, S1, S2, P1, P2>(in, out, errors, m, q, i, j);
462  }
463  }
464  }
465  };
466 
467  const size_t N = etl::dim<0>(out);
468 
469  engine_dispatch_1d_serial(batch_fun, 0, N, 2UL);
470  }
471 
479  template <etl_3d A, typename B, typename C, typename M>
480  static void apply(A&& in, B&& out, C&& errors, M&& m, size_t c1, size_t c2, size_t s1, size_t s2, size_t p1, size_t p2) {
481  if (s1 != c1 || s2 != c2) {
482  m = 0;
483  }
484 
485  auto batch_fun = [&](const size_t first, const size_t last) {
486  for (size_t q = first; q < last; ++q) {
487  for (size_t i = 0; i < etl::dim<1>(out); ++i) {
488  for (size_t j = 0; j < etl::dim<2>(out); ++j) {
489  pool_block_3d(in, out, errors, m, q, i, j, c1, c2, s1, s2, p1, p2);
490  }
491  }
492  }
493  };
494 
495  const size_t N = etl::dim<0>(out);
496 
497  engine_dispatch_1d_serial(batch_fun, 0, N, 2UL);
498  }
499 
500  // 4D Handling
501 
509  template <size_t C1, size_t C2, size_t S1, size_t S2, size_t P1, size_t P2, etl_4d A, typename B, typename C, typename M>
510  static void apply(A&& in, B&& out, C&& errors, M&& m) {
511  if (S1 != C1 || S2 != C2) {
512  m = 0;
513  }
514 
515  auto batch_fun = [&](const size_t first, const size_t last) {
516  for (size_t p = first; p < last; ++p) {
517  for (size_t q = 0; q < etl::dim<1>(out); ++q) {
518  for (size_t i = 0; i < etl::dim<2>(out); ++i) {
519  for (size_t j = 0; j < etl::dim<3>(out); ++j) {
520  pool_block_4d<C1, C2, S1, S2, P1, P2>(in, out, errors, m, p, q, i, j);
521  }
522  }
523  }
524  }
525  };
526 
527  const size_t N = etl::dim<0>(out);
528 
529  engine_dispatch_1d_serial(batch_fun, 0, N, 2UL);
530  }
531 
539  template <etl_4d A, typename B, typename C, typename M>
540  static void apply(A&& in, B&& out, C&& errors, M&& m, size_t c1, size_t c2, size_t s1, size_t s2, size_t p1, size_t p2) {
541  if (s1 != c1 || s2 != c2) {
542  m = 0;
543  }
544 
545  auto batch_fun = [&](const size_t first, const size_t last) {
546  for (size_t p = first; p < last; ++p) {
547  for (size_t q = 0; q < etl::dim<1>(out); ++q) {
548  for (size_t i = 0; i < etl::dim<2>(out); ++i) {
549  for (size_t j = 0; j < etl::dim<3>(out); ++j) {
550  pool_block_4d(in, out, errors, m, p, q, i, j, c1, c2, s1, s2, p1, p2);
551  }
552  }
553  }
554  }
555  };
556 
557  const size_t N = etl::dim<0>(out);
558 
559  engine_dispatch_1d_serial(batch_fun, 0, N, 2UL);
560  }
561 
562  // Deep handling
563 
571  template <size_t C1, size_t C2, size_t S1, size_t S2, size_t P1, size_t P2, etl_5d_and_plus A, typename B, typename C, typename M>
572  static void apply(A&& in, B&& out, C&& errors, M& m) {
573  for (size_t i = 0; i < etl::dim<0>(in); ++i) {
574  apply<C1, C2, S1, S2, P1, P2>(in(i), out(i), errors(i), m(i));
575  }
576  }
577 
585  template <etl_5d_and_plus A, typename B, typename C, typename M>
586  static void apply(A&& in, B&& out, C&& errors, M& m, size_t c1, size_t c2, size_t s1, size_t s2, size_t p1, size_t p2) {
587  for (size_t i = 0; i < etl::dim<0>(in); ++i) {
588  apply(in(i), out(i), errors(i), m(i), c1, c2, s1, s2, p1, p2);
589  }
590  }
591 };
592 
/*!
 * \brief Propagate the error of the output cell (i, j, k) back to its C1xC2xC3 window.
 *
 * Windows start at (i * C1, j * C2, k * C3) and do not overlap: each cell of
 * the window is assigned errors(i, j, k) when its input value equals
 * out(i, j, k), and zero otherwise.
 *
 * \param in The input of the pooling
 * \param out The output of the pooling
 * \param errors The errors on the output of the pooling
 * \param m The storage for the upsampled errors
 * \param i The first index of the output cell
 * \param j The second index of the output cell
 * \param k The third index of the output cell
 */
template <size_t C1, size_t C2, size_t C3, typename A, typename B, typename C, typename M>
static void pool_block_3d(const A& in, const B& out, const C& errors, M& m, size_t i, size_t j, size_t k) {
    auto max   = out(i, j, k);
    auto error = errors(i, j, k);

    // First input cell covered by this window
    const size_t x0 = i * C1;
    const size_t y0 = j * C2;
    const size_t z0 = k * C3;

    for (size_t x = x0; x < x0 + C1; ++x) {
        for (size_t y = y0; y < y0 + C2; ++y) {
            for (size_t z = z0; z < z0 + C3; ++z) {
                if (max == in(x, y, z)) {
                    m(x, y, z) = error;
                } else {
                    m(x, y, z) = 0.0;
                }
            }
        }
    }
}
626 
/*!
 * \brief Propagate the error of the output cell (n, i, j, k) back to its C1xC2xC3 window.
 *
 * n is passed through unchanged; the window starts at (i * C1, j * C2, k * C3)
 * in the last three dimensions. Each cell of the window is assigned
 * errors(n, i, j, k) when its input value equals out(n, i, j, k), and zero
 * otherwise.
 *
 * \param in The input of the pooling
 * \param out The output of the pooling
 * \param errors The errors on the output of the pooling
 * \param m The storage for the upsampled errors
 * \param n The first index of the output cell
 * \param i The second index of the output cell
 * \param j The third index of the output cell
 * \param k The fourth index of the output cell
 */
template <size_t C1, size_t C2, size_t C3, typename A, typename B, typename C, typename M>
static void pool_block_4d(const A& in, const B& out, const C& errors, M& m, size_t n, size_t i, size_t j, size_t k) {
    auto max   = out(n, i, j, k);
    auto error = errors(n, i, j, k);

    // First input cell covered by this window
    const size_t x0 = i * C1;
    const size_t y0 = j * C2;
    const size_t z0 = k * C3;

    for (size_t x = x0; x < x0 + C1; ++x) {
        for (size_t y = y0; y < y0 + C2; ++y) {
            for (size_t z = z0; z < z0 + C3; ++z) {
                if (max == in(n, x, y, z)) {
                    m(n, x, y, z) = error;
                } else {
                    m(n, x, y, z) = 0.0;
                }
            }
        }
    }
}
656 
665  template <size_t C1, size_t C2, size_t C3, etl_3d A, typename B, typename C, typename M>
666  static void apply(A&& in, B&& out, C&& errors, M&& m) {
667  for (size_t i = 0; i < etl::dim<0>(out); ++i) {
668  for (size_t j = 0; j < etl::dim<1>(out); ++j) {
669  for (size_t k = 0; k < etl::dim<2>(out); ++k) {
670  pool_block_3d<C1, C2, C3>(in, out, errors, m, i, j, k);
671  }
672  }
673  }
674  }
675 
/*!
 * \brief Propagate the error of the output cell (i, j, k) back to its c1xc2xc3 window
 * (runtime pooling ratios).
 *
 * Windows start at (i * c1, j * c2, k * c3) and do not overlap: each cell of
 * the window is assigned errors(i, j, k) when its input value equals
 * out(i, j, k), and zero otherwise.
 *
 * \param in The input of the pooling
 * \param out The output of the pooling
 * \param errors The errors on the output of the pooling
 * \param m The storage for the upsampled errors
 * \param i The first index of the output cell
 * \param j The second index of the output cell
 * \param k The third index of the output cell
 * \param c1 The first dimension pooling ratio
 * \param c2 The second dimension pooling ratio
 * \param c3 The third dimension pooling ratio
 */
template <typename A, typename B, typename C, typename M>
static void pool_block_3d(const A& in, const B& out, const C& errors, M& m, size_t i, size_t j, size_t k, size_t c1, size_t c2, size_t c3) {
    auto max   = out(i, j, k);
    auto error = errors(i, j, k);

    // First input cell covered by this window
    const size_t x0 = i * c1;
    const size_t y0 = j * c2;
    const size_t z0 = k * c3;

    for (size_t x = x0; x < x0 + c1; ++x) {
        for (size_t y = y0; y < y0 + c2; ++y) {
            for (size_t z = z0; z < z0 + c3; ++z) {
                if (max == in(x, y, z)) {
                    m(x, y, z) = error;
                } else {
                    m(x, y, z) = 0.0;
                }
            }
        }
    }
}
705 
/*!
 * \brief Propagate the error of the output cell (n, i, j, k) back to its c1xc2xc3 window
 * (runtime pooling ratios).
 *
 * n is passed through unchanged; the window starts at (i * c1, j * c2, k * c3)
 * in the last three dimensions. Each cell of the window is assigned
 * errors(n, i, j, k) when its input value equals out(n, i, j, k), and zero
 * otherwise.
 *
 * \param in The input of the pooling
 * \param out The output of the pooling
 * \param errors The errors on the output of the pooling
 * \param m The storage for the upsampled errors
 * \param n The first index of the output cell
 * \param i The second index of the output cell
 * \param j The third index of the output cell
 * \param k The fourth index of the output cell
 * \param c1 The first dimension pooling ratio
 * \param c2 The second dimension pooling ratio
 * \param c3 The third dimension pooling ratio
 */
template <typename A, typename B, typename C, typename M>
static void pool_block_4d(const A& in, const B& out, const C& errors, M& m, size_t n, size_t i, size_t j, size_t k, size_t c1, size_t c2, size_t c3) {
    auto max   = out(n, i, j, k);
    auto error = errors(n, i, j, k);

    // First input cell covered by this window
    const size_t x0 = i * c1;
    const size_t y0 = j * c2;
    const size_t z0 = k * c3;

    for (size_t x = x0; x < x0 + c1; ++x) {
        for (size_t y = y0; y < y0 + c2; ++y) {
            for (size_t z = z0; z < z0 + c3; ++z) {
                if (max == in(n, x, y, z)) {
                    m(n, x, y, z) = error;
                } else {
                    m(n, x, y, z) = 0.0;
                }
            }
        }
    }
}
735 
744  template <etl_3d A, typename B, typename C, typename M>
745  static void apply(A&& in, B&& out, C&& errors, M&& m, size_t c1, size_t c2, size_t c3) {
746  for (size_t i = 0; i < etl::dim<0>(out); ++i) {
747  for (size_t j = 0; j < etl::dim<1>(out); ++j) {
748  for (size_t k = 0; k < etl::dim<2>(out); ++k) {
749  pool_block_3d(in, out, errors, m, i, j, k, c1, c2, c3);
750  }
751  }
752  }
753  }
754 
755  /*
756  * 4D handling
757  *
758  * This is especially optimized because this is the most common
759  * case in machine learning. Moreover, this is also easy to
760  * parallelize and optimize
761  */
762 
771  template <size_t C1, size_t C2, size_t C3, etl_4d A, typename B, typename C, typename M>
772  static void apply(A&& in, B&& out, C&& errors, M& m) {
773  auto batch_fun_n = [&](const size_t first, const size_t last) {
774  for (size_t n = first; n < last; ++n) {
775  for (size_t i = 0; i < etl::dim<1>(out); ++i) {
776  for (size_t j = 0; j < etl::dim<2>(out); ++j) {
777  for (size_t k = 0; k < etl::dim<3>(out); ++k) {
778  max_pool_upsample_3d::pool_block_4d<C1, C2, C3>(in, out, errors, m, n, i, j, k);
779  }
780  }
781  }
782  }
783  };
784 
785  const size_t N = etl::dim<0>(out);
786 
787  engine_dispatch_1d_serial(batch_fun_n, 0, N, 2UL);
788  }
789 
798  template <etl_4d A, typename B, typename C, typename M>
799  static void apply(A&& in, B&& out, C&& errors, M& m, size_t c1, size_t c2, size_t c3) {
800  auto batch_fun_n = [&](const size_t first, const size_t last) {
801  for (size_t n = first; n < last; ++n) {
802  for (size_t i = 0; i < etl::dim<1>(out); ++i) {
803  for (size_t j = 0; j < etl::dim<2>(out); ++j) {
804  for (size_t k = 0; k < etl::dim<3>(out); ++k) {
805  max_pool_upsample_3d::pool_block_4d(in, out, errors, m, n, i, j, k, c1, c2, c3);
806  }
807  }
808  }
809  }
810  };
811 
812  const size_t N = etl::dim<0>(out);
813 
814  engine_dispatch_1d_serial(batch_fun_n, 0, N, 2UL);
815  }
816 
817  // Deep handling
818 
827  template <size_t C1, size_t C2, size_t C3, etl_5d_and_plus A, typename B, typename C, typename M>
828  static void apply(A&& in, B&& out, C&& errors, M& m) {
829  for (size_t i = 0; i < etl::dim<0>(in); ++i) {
830  apply<C1, C2, C3>(in(i), out(i), errors(i), m(i));
831  }
832  }
833 
842  template <etl_5d_and_plus A, typename B, typename C, typename M>
843  static void apply(A&& in, B&& out, C&& errors, M& m, size_t c1, size_t c2, size_t c3) {
844  for (size_t i = 0; i < etl::dim<0>(in); ++i) {
845  apply(in(i), out(i), errors(i), m(i), c1, c2, c3);
846  }
847  }
848 };
849 
850 } //end of namespace etl::impl::standard
static void pool_block_3d(const A &in, const B &out, const C &errors, M &m, size_t i, size_t j, size_t k)
Pool a 3D block of the sub expression.
Definition: max_pooling_upsample.hpp:610
auto max(L &&lhs, R &&rhs)
Create an expression with the max value of lhs or rhs.
Definition: expression_builder.hpp:65
static void apply(A &&in, B &&out, C &&errors, M &&m)
Apply the functor on sub and store the result in m.
Definition: max_pooling_upsample.hpp:666
void engine_dispatch_1d_serial(Functor &&functor, size_t first, size_t last, size_t threshold, [[maybe_unused]] size_t n_threads=etl::threads)
Dispatch the elements of a range to a functor in a parallel manner, using the global thread engine...
Definition: parallel_support.hpp:734
static void apply(A &&in, B &&out, C &&errors, M &&m)
Apply the functor on sub and store the result in m.
Definition: max_pooling_upsample.hpp:410
Functor for the derivative of 2D Max Pooling.
Definition: max_pooling_upsample.hpp:15
static void pool_block_3d(const A &in, const B &out, const C &errors, M &m, size_t q, size_t i, size_t j)
Pool a block of the sub expression.
Definition: max_pooling_upsample.hpp:91
static void pool_block_4d(const A &in, const B &out, const C &errors, M &m, size_t p, size_t q, size_t i, size_t j)
Pool a block of the sub expression.
Definition: max_pooling_upsample.hpp:155
static void apply(A &&in, B &&out, C &&errors, M &m)
Apply the functor on sub and store the result in m.
Definition: max_pooling_upsample.hpp:572
Definition: prob_pooling.hpp:10
static void pool_block_2d(const A &in, const B &out, const C &errors, M &m, size_t i, size_t j, size_t c1, size_t c2, size_t s1, size_t s2, size_t p1, size_t p2)
Pool a block of the sub expression.
Definition: max_pooling_upsample.hpp:219
static void pool_block_3d(const A &in, const B &out, const C &errors, M &m, size_t i, size_t j, size_t k, size_t c1, size_t c2, size_t c3)
Pool a 3D block of the sub expression.
Definition: max_pooling_upsample.hpp:689
static void apply(A &&in, B &&out, C &&errors, M &m, size_t c1, size_t c2, size_t s1, size_t s2, size_t p1, size_t p2)
Apply the functor on sub and store the result in m.
Definition: max_pooling_upsample.hpp:586
static void pool_block_3d(const A &in, const B &out, const C &errors, M &m, size_t q, size_t i, size_t j, size_t c1, size_t c2, size_t s1, size_t s2, size_t p1, size_t p2)
Pool a block of the sub expression.
Definition: max_pooling_upsample.hpp:283
static void pool_block_2d(const A &in, const B &out, const C &errors, M &m, size_t i, size_t j)
Pool a block of the sub expression.
Definition: max_pooling_upsample.hpp:27
static void pool_block_4d(const A &in, const B &out, const C &errors, M &m, size_t n, size_t i, size_t j, size_t k, size_t c1, size_t c2, size_t c3)
Pool a 4D block of the sub expression.
Definition: max_pooling_upsample.hpp:719
Functor for the derivative of 3D Max Pooling.
Definition: max_pooling_upsample.hpp:596
static void apply(A &&in, B &&out, C &&errors, M &&m, size_t c1, size_t c2, size_t c3)
Apply the functor on sub and store the result in m.
Definition: max_pooling_upsample.hpp:745
static void pool_block_4d(const A &in, const B &out, const C &errors, M &m, size_t n, size_t i, size_t j, size_t k)
Pool a 4D block of the sub expression.
Definition: max_pooling_upsample.hpp:640
static void pool_block_4d(const A &in, const B &out, const C &errors, M &m, size_t p, size_t q, size_t i, size_t j, size_t c1, size_t c2, size_t s1, size_t s2, size_t p1, size_t p2)
Pool a block of the sub expression.
Definition: max_pooling_upsample.hpp:347
static void apply(A &&in, B &&out, C &&errors, M &m)
Apply the functor on sub and store the result in m.
Definition: max_pooling_upsample.hpp:772
static void apply(A &&in, B &&out, C &&errors, M &m, size_t c1, size_t c2, size_t c3)
Apply the functor on sub and store the result in m.
Definition: max_pooling_upsample.hpp:799
static void apply(A &&in, B &&out, C &&errors, M &&m, size_t c1, size_t c2, size_t s1, size_t s2, size_t p1, size_t p2)
Apply the functor on sub and store the result in m.
Definition: max_pooling_upsample.hpp:430