// Compute the max-pooling derivative mask for the C1 x C2 input window that
// produced output cell (j, k). Every input position whose value equals the
// pooled max is set to 1.0 in m, and the other positions to 0.0. In the
// overlapping-window path (see below) non-max positions are not written.
// Kernel (C1, C2), stride (S1, S2) and padding (P1, P2) are compile-time.
// NOTE(review): this listing is an extract. The signature line and several
// body lines (blanks, braces, some statements) are missing. Per the index at
// the end of this listing the signature is
//   pool_derivative_block(const A& in, const B& out, M& m, size_t j, size_t k)
// `max` is declared on a missing line. The runtime-parameter overload declares
// it as `auto max = out(j, k);`, presumably the same here (TODO confirm).
28 template <
size_t C1,
size_t C2,
size_t S1,
size_t S2,
size_t P1,
size_t P2,
typename A,
typename B,
typename M>
// Slow path, only compiled when padding is requested. It handles border
// output cells whose window may reach into the padding.
33 if constexpr (P1 || P2) {
34 if (cpp_unlikely(j < P1 || k < P2 || j >= etl::dim<0>(out) - P1 || k >= etl::dim<1>(out) - P2)) {
// base_j/base_k are size_t, so the subtraction wraps around when j * S1 < P1.
// Wrapped (padding) positions become huge and fail the bounds check below.
// base_j + jj wraps back to the correct index once it is inside the input.
35 const size_t base_j = j * S1 - P1;
36 const size_t base_k = k * S2 - P2;
38 for (
size_t jj = 0; jj < C1; ++jj) {
39 for (
size_t kk = 0; kk < C2; ++kk) {
// Skip window positions that fall in the padding.
40 if (base_j + jj < etl::dim<0>(m) && base_k + kk < etl::dim<1>(m)) {
41 if (
max == in(base_j + jj, base_k + kk)) {
42 m(base_j + jj, base_k + kk) = 1.0;
// NOTE(review): lines 43-47 are missing from this listing. Suppose this 0.0
// store runs for every non-max position even when windows overlap
// (C1 != S1 or C2 != S2). Then a border window can overwrite a 1.0 written by
// a neighbouring window, unlike the interior overlapping path below, which
// never writes non-max positions. Verify against the full source.
48 m(base_j + jj, base_k + kk) = 0.0;
// General path. It covers all cells when there is no padding, and interior
// cells otherwise, assuming the missing lines 49-57 return after the slow
// path (TODO confirm).
58 for (
size_t jj = 0; jj < C1; ++jj) {
59 for (
size_t kk = 0; kk < C2; ++kk) {
60 const size_t final_j = j * S1 - P1 + jj;
61 const size_t final_k = k * S2 - P2 + kk;
// Non-overlapping windows (kernel == stride): each input position belongs to
// exactly one window, so plain assignment of 1.0 / 0.0 is sufficient.
63 if constexpr (C1 == S1 && C2 == S2) {
64 if (
max == in(final_j, final_k)) {
65 m(final_j, final_k) = 1.0;
71 m(final_j, final_k) = 0.0;
// Overlapping windows: only max positions are written here, so m is
// presumably cleared beforehand (see the `C1 != S1 || C2 != S2` branch of
// apply; TODO confirm). Lines 75 and 77-79 are missing. The visible code
// either assigns 1.0 or accumulates += 1.0, and the condition that selects
// between the two is not shown. It is possibly `cudnn_compatible`, which the
// index lists (TODO confirm).
74 if (
max == in(final_j, final_k)) {
76 m(final_j, final_k) = 1.0;
80 m(final_j, final_k) += 1.0;
// Compute the 2D max-pooling derivative mask of out with respect to in and
// store it in m. Kernel (C1, C2), stride (S1, S2) and padding (P1, P2) are
// compile-time. C3, S3 and P3 are accepted but not used in the visible body,
// presumably to share an interface with the 3D functor (TODO confirm).
96 template <
size_t C1,
size_t C2,
size_t C3,
size_t S1,
size_t S2,
size_t S3,
size_t P1,
size_t P2,
size_t P3, etl_2d A,
typename B,
typename M>
97 static void apply(A&& in, B&& out, M&& m) {
// Bring in and out up to date on the CPU side before reading them.
98 in.ensure_cpu_up_to_date();
99 out.ensure_cpu_up_to_date();
// Overlapping windows. The body (lines 102-104) is missing from this listing.
// The overlapping path of pool_derivative_block writes only max positions,
// so this presumably clears m first (TODO confirm).
101 if constexpr (C1 != S1 || C2 != S2) {
// Visit every output cell and mark the max positions of its window.
105 for (
size_t j = 0; j < etl::dim<0>(out); ++j) {
106 for (
size_t k = 0; k < etl::dim<1>(out); ++k) {
107 pool_derivative_block<C1, C2, S1, S2, P1, P2>(in, out, m, j, k);
// Runtime-parameter counterpart of the compile-time 2D pool_derivative_block.
// For output cell (j, k), set m to 1.0 at every position of its c1 x c2 input
// window whose value equals out(j, k), and to 0.0 at the other positions.
// Non-max positions are not written in the overlapping-window path.
// Kernel (c1, c2), stride (s1, s2), padding (p1, p2).
125 template <
typename A,
typename B,
typename M>
126 static void pool_derivative_block(
const A& in,
const B& out, M& m,
size_t j,
size_t k,
size_t c1,
size_t c2,
size_t s1,
size_t s2,
size_t p1,
size_t p2) {
// Pooled output value; input positions equal to it are marked in m.
127 auto max = out(j, k);
// Slow path for border output cells when padding is used.
130 if (cpp_unlikely(p1 || p2)) {
131 if (cpp_unlikely(j < p1 || k < p2 || j >= etl::dim<0>(out) - p1 || k >= etl::dim<1>(out) - p2)) {
// size_t arithmetic wraps around when j * s1 < p1. The bounds check below
// skips the wrapped (padding) positions. base_j + jj wraps back to the
// correct index once it is inside the input.
132 const size_t base_j = j * s1 - p1;
133 const size_t base_k = k * s2 - p2;
135 for (
size_t jj = 0; jj < c1; ++jj) {
136 for (
size_t kk = 0; kk < c2; ++kk) {
// Skip window positions that fall in the padding.
137 if (base_j + jj < etl::dim<0>(m) && base_k + kk < etl::dim<1>(m)) {
139 if (
max == in(base_j + jj, base_k + kk)) {
140 m(base_j + jj, base_k + kk) = 1.0;
// NOTE(review): lines 141-145 are missing here. Suppose this 0.0 store runs
// for every non-max position even when windows overlap (c1 != s1 or
// c2 != s2). Then a border window can erase a 1.0 written by a neighbouring
// window, unlike the interior overlapping path below, which never writes
// non-max positions. Verify against the full source.
146 m(base_j + jj, base_k + kk) = 0.0;
// Lines 147-155 are missing. Presumably the slow path returns before this
// point (TODO confirm).
// Non-overlapping windows (kernel == stride): every input position belongs to
// one window, so plain assignment of 1.0 / 0.0 is sufficient.
156 if (c1 == s1 && c2 == s2) {
157 for (
size_t jj = 0; jj < c1; ++jj) {
158 for (
size_t kk = 0; kk < c2; ++kk) {
159 const size_t final_j = j * s1 - p1 + jj;
160 const size_t final_k = k * s2 - p2 + kk;
162 if (
max == in(final_j, final_k)) {
163 m(final_j, final_k) = 1.0;
169 m(final_j, final_k) = 0.0;
// Overlapping windows; lines 170-173 are missing, presumably the `else`.
// Only max positions are written, so m is presumably cleared by the caller
// (see the `c1 != s1 || c2 != s2` branch of the runtime apply). Lines 180 and
// 182-184 are missing. The visible code either assigns 1.0 or accumulates
// += 1.0, under a condition not shown here, possibly `cudnn_compatible`
// (TODO confirm).
174 for (
size_t jj = 0; jj < c1; ++jj) {
175 for (
size_t kk = 0; kk < c2; ++kk) {
176 const size_t final_j = j * s1 - p1 + jj;
177 const size_t final_k = k * s2 - p2 + kk;
179 if (
max == in(final_j, final_k)) {
181 m(final_j, final_k) = 1.0;
185 m(final_j, final_k) += 1.0;
// Runtime-parameter 2D apply: compute the max-pooling derivative mask of out
// with respect to in and store it in m.
// NOTE(review): signature lines 202-206 are missing from this listing. Per the
// index at the end of this listing it is
//   apply(A&& in, B&& out, M&& m, size_t c1, size_t c2, size_t c3,
//         size_t s1, size_t s2, size_t s3, size_t p1, size_t p2, size_t p3)
// c3, s3 and p3 are [[maybe_unused]], so only the first two dimensions are
// pooled.
201 template <etl_2d A,
typename B,
typename M>
207 [[maybe_unused]]
size_t c3,
210 [[maybe_unused]]
size_t s3,
213 [[maybe_unused]]
size_t p3) {
// Bring in and out up to date on the CPU side before reading them.
214 in.ensure_cpu_up_to_date();
215 out.ensure_cpu_up_to_date();
// Overlapping windows. The body (lines 218-220) is missing. The overlapping
// path of pool_derivative_block writes only max positions, so this
// presumably clears m first (TODO confirm).
217 if (c1 != s1 || c2 != s2) {
// Visit every output cell and mark the max positions of its window.
221 for (
size_t j = 0; j < etl::dim<0>(out); ++j) {
222 for (
size_t k = 0; k < etl::dim<1>(out); ++k) {
223 pool_derivative_block(in, out, m, j, k, c1, c2, s1, s2, p1, p2);
// For deep_mat inputs (presumably three or more dimensions), apply the 2D
// max-pooling derivative to each slice along the first dimension. in(i),
// out(i) and m(i) are paired by the same index i.
// NOTE(review): m is taken as M&, but the recursive call passes m(i). If m(i)
// yields a temporary view, it cannot bind to M& when in(i) is itself a
// deep_mat; only the etl_2d overload takes M&&. Verify that inputs with more
// than three dimensions compile, or switch this parameter to M&&.
241 template <
size_t C1,
size_t C2,
size_t C3,
size_t S1,
size_t S2,
size_t S3,
size_t P1,
size_t P2,
size_t P3, deep_mat A,
typename B,
typename M>
242 static void apply(A&& in, B&& out, M& m) {
243 in.ensure_cpu_up_to_date();
244 out.ensure_cpu_up_to_date();
// The loop bound comes from in. out and m are assumed to share the same
// leading dimension; the visible code does not check this.
246 for (
size_t i = 0; i < etl::dim<0>(in); ++i) {
247 apply<C1, C2, C3, S1, S2, S3, P1, P2, P3>(in(i), out(i), m(i));
// Runtime-parameter version for deep_mat inputs (presumably three or more
// dimensions). It applies the 2D max-pooling derivative to each slice along
// the first dimension and forwards all kernel/stride/padding values unchanged.
// NOTE(review): m is taken as M&, but the recursive call passes m(i). If m(i)
// yields a temporary view, it cannot bind to M& when in(i) is itself a
// deep_mat; only the etl_2d overload takes M&&. Verify for inputs with more
// than three dimensions.
262 template <deep_mat A,
typename B,
typename M>
263 static void apply(A&& in, B&& out, M& m,
size_t c1,
size_t c2,
size_t c3,
size_t s1,
size_t s2,
size_t s3,
size_t p1,
size_t p2,
size_t p3) {
264 in.ensure_cpu_up_to_date();
265 out.ensure_cpu_up_to_date();
// The loop bound comes from in. out and m are assumed to share the same
// leading dimension; the visible code does not check this.
267 for (
size_t i = 0; i < etl::dim<0>(in); ++i) {
268 apply(in(i), out(i), m(i), c1, c2, c3, s1, s2, s3, p1, p2, p3);
// For output cell (i, j, k) of 3D max pooling, set m to 1.0 at every position
// of its C1 x C2 x C3 input window whose value equals out(i, j, k), and to
// 0.0 at the other positions.
// The signature line (293) is missing from this listing. Per the index at the
// end of this listing it is
//   pool_derivative_block(const A& in, const B& out, M& m, size_t i, size_t j, size_t k)
// NOTE(review): the window origin is (i * C1, j * C2, k * C3). Strides S1..S3
// and paddings P1..P3 are never used in the visible body, so results are only
// correct when stride == kernel and padding is zero. The 2D functor handles
// both. Confirm that callers respect this restriction.
292 template <
size_t C1,
size_t C2,
size_t C3,
size_t S1,
size_t S2,
size_t S3,
size_t P1,
size_t P2,
size_t P3,
typename A,
typename B,
typename M>
// Pooled output value; input positions equal to it are marked in m.
294 auto max = out(i, j, k);
296 for (
size_t ii = 0; ii < C1; ++ii) {
297 for (
size_t jj = 0; jj < C2; ++jj) {
298 for (
size_t kk = 0; kk < C3; ++kk) {
299 if (
max == in(i * C1 + ii, j * C2 + jj, k * C3 + kk)) {
300 m(i * C1 + ii, j * C2 + jj, k * C3 + kk) = 1.0;
302 m(i * C1 + ii, j * C2 + jj, k * C3 + kk) = 0.0;
// Compute the 3D max-pooling derivative mask of out with respect to in and
// store it in m. Kernel, stride and padding are compile-time.
// Unlike the 2D functor, no branch for overlapping windows is visible here.
// The block routine assigns every position of each window, which assumes
// non-overlapping windows.
317 template <
size_t C1,
size_t C2,
size_t C3,
size_t S1,
size_t S2,
size_t S3,
size_t P1,
size_t P2,
size_t P3, etl_3d A,
typename B,
typename M>
318 static void apply(A&& in, B&& out, M&& m) {
// Bring in and out up to date on the CPU side before reading them.
319 in.ensure_cpu_up_to_date();
320 out.ensure_cpu_up_to_date();
// Visit every output cell (i, j, k) and mark the max positions of its window.
322 for (
size_t i = 0; i < etl::dim<0>(out); ++i) {
323 for (
size_t j = 0; j < etl::dim<1>(out); ++j) {
324 for (
size_t k = 0; k < etl::dim<2>(out); ++k) {
325 pool_derivative_block<C1, C2, C3, S1, S2, S3, P1, P2, P3>(in, out, m, i, j, k);
// Runtime-parameter counterpart of the 3D pool_derivative_block. For output
// cell (i, j, k), set m to 1.0 at every position of its c1 x c2 x c3 window
// whose value equals out(i, j, k), and to 0.0 at the other positions.
// Signature lines 347-355 are missing. Per the index at the end of this
// listing the parameters are (in, out, m, i, j, k, c1, c2, c3, s1, s2, s3,
// p1, p2, p3).
// NOTE(review): s1..s3 and p1..p3 are [[maybe_unused]], and the window origin
// is (i * c1, j * c2, k * c3). Stride and padding are therefore ignored, so
// results are only correct for stride == kernel and zero padding.
346 template <
typename A,
typename B,
typename M>
356 [[maybe_unused]]
size_t s1,
357 [[maybe_unused]]
size_t s2,
358 [[maybe_unused]]
size_t s3,
359 [[maybe_unused]]
size_t p1,
360 [[maybe_unused]]
size_t p2,
361 [[maybe_unused]]
size_t p3) {
// Pooled output value; input positions equal to it are marked in m.
362 auto max = out(i, j, k);
364 for (
size_t ii = 0; ii < c1; ++ii) {
365 for (
size_t jj = 0; jj < c2; ++jj) {
366 for (
size_t kk = 0; kk < c3; ++kk) {
367 if (
max == in(i * c1 + ii, j * c2 + jj, k * c3 + kk)) {
368 m(i * c1 + ii, j * c2 + jj, k * c3 + kk) = 1.0;
370 m(i * c1 + ii, j * c2 + jj, k * c3 + kk) = 0.0;
// Runtime-parameter 3D apply: compute the max-pooling derivative mask of out
// with respect to in and store it in m, visiting every output cell.
// Stride and padding are forwarded to pool_derivative_block, whose s/p
// parameters are [[maybe_unused]] in this file.
385 template <etl_3d A,
typename B,
typename M>
386 static void apply(A&& in, B&& out, M&& m,
size_t c1,
size_t c2,
size_t c3,
size_t s1,
size_t s2,
size_t s3,
size_t p1,
size_t p2,
size_t p3) {
// Bring in and out up to date on the CPU side before reading them.
387 in.ensure_cpu_up_to_date();
388 out.ensure_cpu_up_to_date();
// Visit every output cell (i, j, k) and mark the max positions of its window.
390 for (
size_t i = 0; i < etl::dim<0>(out); ++i) {
391 for (
size_t j = 0; j < etl::dim<1>(out); ++j) {
392 for (
size_t k = 0; k < etl::dim<2>(out); ++k) {
393 pool_derivative_block(in, out, m, i, j, k, c1, c2, c3, s1, s2, s3, p1, p2, p3);
// For inputs with four or more dimensions, apply the 3D max-pooling derivative
// to each slice along the first dimension. in(i), out(i) and m(i) are paired
// by the same index i.
// NOTE(review): m is taken as M&, but the recursive call passes m(i). If m(i)
// yields a temporary view, it cannot bind to M& when in(i) is itself 4D+;
// only the etl_3d overload takes M&&. Verify that inputs with five or more
// dimensions compile, or switch this parameter to M&&.
412 template <
size_t C1,
size_t C2,
size_t C3,
size_t S1,
size_t S2,
size_t S3,
size_t P1,
size_t P2,
size_t P3, etl_4d_and_plus A,
typename B,
typename M>
413 static void apply(A&& in, B&& out, M& m) {
414 in.ensure_cpu_up_to_date();
415 out.ensure_cpu_up_to_date();
// The loop bound comes from in. out and m are assumed to share the same
// leading dimension; the visible code does not check this.
417 for (
size_t i = 0; i < etl::dim<0>(in); ++i) {
418 apply<C1, C2, C3, S1, S2, S3, P1, P2, P3>(in(i), out(i), m(i));
// Runtime-parameter version for inputs with four or more dimensions. It
// applies the 3D max-pooling derivative to each slice along the first
// dimension and forwards all kernel/stride/padding values unchanged.
// The first signature line (434) is missing from this listing. Per the index
// at the end of this listing it declares `static void apply(`.
// NOTE(review): m is taken as M&, but the recursive call passes m(i). If m(i)
// yields a temporary view, it cannot bind to M& when in(i) is itself 4D+.
// Verify for inputs with five or more dimensions.
433 template <etl_4d_and_plus A,
typename B,
typename M>
435 A&& in, B&& out, M& m,
size_t c1,
size_t c2,
size_t c3,
size_t s1,
size_t s2,
size_t s3,
size_t p1,
size_t p2,
size_t p3) {
436 in.ensure_cpu_up_to_date();
437 out.ensure_cpu_up_to_date();
// The loop bound comes from in. out and m are assumed to share the same
// leading dimension; the visible code does not check this.
439 for (
size_t i = 0; i < etl::dim<0>(in); ++i) {
440 apply(in(i), out(i), m(i), c1, c2, c3, s1, s2, s3, p1, p2, p3);
static void apply(A &&in, B &&out, M &m)
Apply the 2D max-pooling derivative to each slice along the first dimension of in, storing the results in the matching slices of m.
Definition: max_pooling_derivative.hpp:242
static void pool_derivative_block(const A &in, const B &out, M &m, size_t j, size_t k, size_t c1, size_t c2, size_t s1, size_t s2, size_t p1, size_t p2)
Compute the derivative mask for the pooling window of a single output cell (j, k).
Definition: max_pooling_derivative.hpp:126
static void apply(A &&in, B &&out, M &m, size_t c1, size_t c2, size_t c3, size_t s1, size_t s2, size_t s3, size_t p1, size_t p2, size_t p3)
Apply the 2D max-pooling derivative to each slice along the first dimension of in, storing the results in the matching slices of m.
Definition: max_pooling_derivative.hpp:263
auto max(L &&lhs, R &&rhs)
Create an expression with the max value of lhs or rhs.
Definition: expression_builder.hpp:65
static void apply(A &&in, B &&out, M &m)
Apply the 3D max-pooling derivative to each slice along the first dimension of in, storing the results in the matching slices of m.
Definition: max_pooling_derivative.hpp:413
static void apply(A &&in, B &&out, M &&m)
Compute the 2D max-pooling derivative of out with respect to in and store it in m.
Definition: max_pooling_derivative.hpp:97
Functor for the derivative of 3D Max Pooling.
Definition: max_pooling_derivative.hpp:279
static void apply(A &&in, B &&out, M &&m, size_t c1, size_t c2, size_t c3, size_t s1, size_t s2, size_t s3, size_t p1, size_t p2, size_t p3)
Compute the 3D max-pooling derivative of out with respect to in and store it in m.
Definition: max_pooling_derivative.hpp:386
static void apply(A &&in, B &&out, M &m, size_t c1, size_t c2, size_t c3, size_t s1, size_t s2, size_t s3, size_t p1, size_t p2, size_t p3)
Apply the 3D max-pooling derivative to each slice along the first dimension of in, storing the results in the matching slices of m.
Definition: max_pooling_derivative.hpp:434
static void apply(A &&in, B &&out, M &&m)
Compute the 3D max-pooling derivative of out with respect to in and store it in m.
Definition: max_pooling_derivative.hpp:318
constexpr bool cudnn_compatible
Indicates if ETL is trying to generate results similar to CUDNN (default).
Definition: config.hpp:163
static void pool_derivative_block(const A &in, const B &out, M &m, size_t i, size_t j, size_t k, size_t c1, size_t c2, size_t c3, [[maybe_unused]] size_t s1, [[maybe_unused]] size_t s2, [[maybe_unused]] size_t s3, [[maybe_unused]] size_t p1, [[maybe_unused]] size_t p2, [[maybe_unused]] size_t p3)
Compute the derivative mask for the pooling window of a single output cell (i, j, k).
Definition: max_pooling_derivative.hpp:347
static void apply(A &&in, B &&out, M &&m, size_t c1, size_t c2, [[maybe_unused]] size_t c3, size_t s1, size_t s2, [[maybe_unused]] size_t s3, size_t p1, size_t p2, [[maybe_unused]] size_t p3)
Compute the 2D max-pooling derivative of out with respect to in and store it in m.
Definition: max_pooling_derivative.hpp:202
Functor for the derivative of 2D Max Pooling.
Definition: max_pooling_derivative.hpp:17
static void pool_derivative_block(const A &in, const B &out, M &m, size_t j, size_t k)
Compute the derivative mask for the pooling window of a single output cell (j, k).
Definition: max_pooling_derivative.hpp:29
static void pool_derivative_block(const A &in, const B &out, M &m, size_t i, size_t j, size_t k)
Compute the derivative mask for the pooling window of a single output cell (i, j, k).
Definition: max_pooling_derivative.hpp:293
Definition: avg_pooling_derivative.hpp:10