Expression Templates Library (ETL)
conv_multi.hpp
Go to the documentation of this file.
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
13 #pragma once
14 
15 namespace etl::detail {
16 
20 template <size_t S1, size_t S2, size_t P1, size_t P2>
28  template <typename I, typename K, typename C>
29  static void apply(I&& input, K&& kernel, C&& conv) {
30  constexpr_select auto impl = select_conv_valid_multi_impl<I, K, C>();
31 
32  if
33  constexpr_select(impl == etl::conv_multi_impl::BLAS_VEC) {
34  inc_counter("inc:blas_vec");
35  impl::vec::blas_conv2_valid_multi(smart_forward(input), smart_forward(kernel), conv, S1, S2, P1, P2);
36  }
37  else if
38  constexpr_select(impl == etl::conv_multi_impl::BLAS_MKL) {
39  inc_counter("inc:blas_mkl");
40  impl::blas::blas_conv2_valid_multi(smart_forward(input), smart_forward(kernel), conv, S1, S2, P1, P2);
41  }
42  else if
43  constexpr_select(impl == etl::conv_multi_impl::VALID_FFT_MKL) {
44  inc_counter("inc:fft_mkl");
45  impl::blas::fft_conv2_valid_multi(smart_forward(input), smart_forward(kernel), conv, S1, S2, P1, P2);
46  }
47  else if
48  constexpr_select(impl == etl::conv_multi_impl::CUDNN) {
49  inc_counter("inc:cudnn");
50  impl::cudnn::conv2_valid_multi(smart_forward_gpu(input), smart_forward_gpu(kernel), conv, S1, S2, P1, P2);
51  }
52  else if
53  constexpr_select(impl == etl::conv_multi_impl::VEC) {
54  inc_counter("inc:vec");
55  impl::vec::conv2_valid_multi(smart_forward(input), smart_forward(kernel), conv, S1, S2, P1, P2);
56  }
57  else if
58  constexpr_select(impl == etl::conv_multi_impl::STD) {
59  inc_counter("inc:std");
60  impl::standard::conv2_valid_multi(smart_forward(input), smart_forward(kernel), conv, S1, S2, P1, P2);
61  }
62  else {
63  cpp_unreachable("Invalid conv implementation selection");
64  }
65  }
66 };
67 
71 template <size_t S1, size_t S2, size_t P1, size_t P2>
79  template <typename I, typename K, typename C>
80  static void apply(I&& input, K&& kernel, C&& conv) {
81  constexpr_select auto impl = select_conv_valid_multi_impl<I, K, C>();
82 
83  if
84  constexpr_select(impl == etl::conv_multi_impl::BLAS_VEC) {
85  inc_counter("inc:blas_vec");
86  impl::vec::blas_conv2_valid_multi_flipped(smart_forward(input), smart_forward(kernel), conv, S1, S2, P1, P2);
87  }
88  else if
89  constexpr_select(impl == etl::conv_multi_impl::BLAS_MKL) {
90  inc_counter("inc:blas_mkl");
91  impl::blas::blas_conv2_valid_multi_flipped(smart_forward(input), smart_forward(kernel), conv, S1, S2, P1, P2);
92  }
93  else if
94  constexpr_select(impl == etl::conv_multi_impl::VALID_FFT_MKL) {
95  inc_counter("inc:fft_mkl");
96  impl::blas::fft_conv2_valid_multi_flipped(smart_forward(input), smart_forward(kernel), conv, S1, S2, P1, P2);
97  }
98  else if
99  constexpr_select(impl == etl::conv_multi_impl::CUDNN) {
100  inc_counter("inc:cudnn");
101  impl::cudnn::conv2_valid_multi_flipped(smart_forward_gpu(input), smart_forward_gpu(kernel), conv, S1, S2, P1, P2);
102  }
103  else if
104  constexpr_select(impl == etl::conv_multi_impl::VEC) {
105  inc_counter("inc:vec");
106  impl::vec::conv2_valid_multi_flipped(smart_forward(input), smart_forward(kernel), conv, S1, S2, P1, P2);
107  }
108  else if
109  constexpr_select(impl == etl::conv_multi_impl::STD) {
110  inc_counter("inc:std");
111  impl::standard::conv2_valid_multi_flipped(smart_forward(input), smart_forward(kernel), conv, S1, S2, P1, P2);
112  }
113  else {
114  cpp_unreachable("Invalid conv implementation selection");
115  }
116  }
117 };
118 
122 template <size_t S1, size_t S2, size_t P1, size_t P2>
130  template <typename I, typename K, typename C>
131  static void apply(I&& input, K&& kernel, C&& conv) {
132  constexpr_select auto impl = select_conv_valid_multi_multi_impl<I, K, C>();
133 
134  if
135  constexpr_select(impl == etl::conv_multi_impl::BLAS_VEC) {
136  inc_counter("inc:blas_vec");
137  impl::vec::blas_conv2_valid_multi_multi(smart_forward(input), smart_forward(kernel), conv, S1, S2, P1, P2);
138  }
139  else if
140  constexpr_select(impl == etl::conv_multi_impl::BLAS_MKL) {
141  inc_counter("inc:blas_mkl");
142  impl::blas::blas_conv2_valid_multi_multi(smart_forward(input), smart_forward(kernel), conv, S1, S2, P1, P2);
143  }
144  else if
145  constexpr_select(impl == etl::conv_multi_impl::VALID_FFT_MKL) {
146  inc_counter("inc:fft_mkl");
147  impl::blas::fft_conv2_valid_multi_multi(smart_forward(input), smart_forward(kernel), conv, S1, S2, P1, P2);
148  }
149  else if
150  constexpr_select(impl == etl::conv_multi_impl::VEC) {
151  inc_counter("inc:vec");
152  impl::vec::conv2_valid_multi_multi(smart_forward(input), smart_forward(kernel), conv, S1, S2, P1, P2);
153  }
154  else if
155  constexpr_select(impl == etl::conv_multi_impl::STD) {
156  inc_counter("inc:std");
157  impl::standard::conv2_valid_multi_multi(smart_forward(input), smart_forward(kernel), conv, S1, S2, P1, P2);
158  }
159  else {
160  cpp_unreachable("Invalid conv implementation selection");
161  }
162  }
163 };
164 
168 template <size_t S1, size_t S2, size_t P1, size_t P2>
176  template <typename I, typename K, typename C>
177  static void apply(I&& input, K&& kernel, C&& conv) {
178  constexpr_select auto impl = select_conv_valid_multi_multi_impl<I, K, C>();
179 
180  if
181  constexpr_select(impl == etl::conv_multi_impl::BLAS_VEC) {
182  inc_counter("inc:blas_vec");
183  impl::vec::blas_conv2_valid_multi_multi_flipped(smart_forward(input), smart_forward(kernel), conv, S1, S2, P1, P2);
184  }
185  else if
186  constexpr_select(impl == etl::conv_multi_impl::BLAS_MKL) {
187  inc_counter("inc:blas_mkl");
188  impl::blas::blas_conv2_valid_multi_multi_flipped(smart_forward(input), smart_forward(kernel), conv, S1, S2, P1, P2);
189  }
190  else if
191  constexpr_select(impl == etl::conv_multi_impl::VALID_FFT_MKL) {
192  inc_counter("inc:fft_mkl");
193  impl::blas::fft_conv2_valid_multi_multi_flipped(smart_forward(input), smart_forward(kernel), conv, S1, S2, P1, P2);
194  }
195  else if
196  constexpr_select(impl == etl::conv_multi_impl::VEC) {
197  inc_counter("inc:vec");
198  impl::vec::conv2_valid_multi_multi_flipped(smart_forward(input), smart_forward(kernel), conv, S1, S2, P1, P2);
199  }
200  else if
201  constexpr_select(impl == etl::conv_multi_impl::STD) {
202  inc_counter("inc:std");
203  impl::standard::conv2_valid_multi_multi_flipped(smart_forward(input), smart_forward(kernel), conv, S1, S2, P1, P2);
204  }
205  else {
206  cpp_unreachable("Invalid conv implementation selection");
207  }
208  }
209 };
210 
221  template <typename I, typename K, typename C>
222  static void apply(I&& input, K&& kernel, C&& conv, size_t s1, size_t s2, size_t p1, size_t p2) {
223  constexpr_select auto impl = select_conv_valid_multi_impl<I, K, C>();
224 
225  if
226  constexpr_select(impl == etl::conv_multi_impl::BLAS_VEC) {
227  inc_counter("inc:blas_vec");
228  impl::vec::blas_conv2_valid_multi(smart_forward(input), smart_forward(kernel), conv, s1, s2, p1, p2);
229  }
230  else if
231  constexpr_select(impl == etl::conv_multi_impl::BLAS_MKL) {
232  inc_counter("inc:blas_mkl");
233  impl::blas::blas_conv2_valid_multi(smart_forward(input), smart_forward(kernel), conv, s1, s2, p1, p2);
234  }
235  else if
236  constexpr_select(impl == etl::conv_multi_impl::VALID_FFT_MKL) {
237  inc_counter("inc:fft_mkl");
238  impl::blas::fft_conv2_valid_multi(smart_forward(input), smart_forward(kernel), conv, s1, s2, p1, p2);
239  }
240  else if
241  constexpr_select(impl == etl::conv_multi_impl::CUDNN) {
242  inc_counter("inc:cudnn");
243  impl::cudnn::conv2_valid_multi(smart_forward_gpu(input), smart_forward_gpu(kernel), conv, s1, s2, p1, p2);
244  }
245  else if
246  constexpr_select(impl == etl::conv_multi_impl::VEC) {
247  inc_counter("inc:vec");
248  impl::vec::conv2_valid_multi(smart_forward(input), smart_forward(kernel), conv, s1, s2, p1, p2);
249  }
250  else if
251  constexpr_select(impl == etl::conv_multi_impl::STD) {
252  inc_counter("inc:std");
253  impl::standard::conv2_valid_multi(smart_forward(input), smart_forward(kernel), conv, s1, s2, p1, p2);
254  }
255  else {
256  cpp_unreachable("Invalid conv implementation selection");
257  }
258  }
259 };
260 
271  template <typename I, typename K, typename C>
272  static void apply(I&& input, K&& kernel, C&& conv, size_t s1, size_t s2, size_t p1, size_t p2) {
273  constexpr_select auto impl = select_conv_valid_multi_impl<I, K, C>();
274 
275  if
276  constexpr_select(impl == etl::conv_multi_impl::BLAS_VEC) {
277  inc_counter("inc:blas_vec");
278  impl::vec::blas_conv2_valid_multi_flipped(smart_forward(input), smart_forward(kernel), conv, s1, s2, p1, p2);
279  }
280  else if
281  constexpr_select(impl == etl::conv_multi_impl::BLAS_MKL) {
282  inc_counter("inc:blas_mkl");
283  impl::blas::blas_conv2_valid_multi_flipped(smart_forward(input), smart_forward(kernel), conv, s1, s2, p1, p2);
284  }
285  else if
286  constexpr_select(impl == etl::conv_multi_impl::VALID_FFT_MKL) {
287  inc_counter("inc:fft_mkl");
288  impl::blas::fft_conv2_valid_multi_flipped(smart_forward(input), smart_forward(kernel), conv, s1, s2, p1, p2);
289  }
290  else if
291  constexpr_select(impl == etl::conv_multi_impl::CUDNN) {
292  inc_counter("inc:cudnn");
293  impl::cudnn::conv2_valid_multi_flipped(smart_forward_gpu(input), smart_forward_gpu(kernel), conv, s1, s2, p1, p2);
294  }
295  else if
296  constexpr_select(impl == etl::conv_multi_impl::VEC) {
297  inc_counter("inc:vec");
298  impl::vec::conv2_valid_multi_flipped(smart_forward(input), smart_forward(kernel), conv, s1, s2, p1, p2);
299  }
300  else if
301  constexpr_select(impl == etl::conv_multi_impl::STD) {
302  inc_counter("inc:std");
303  impl::standard::conv2_valid_multi_flipped(smart_forward(input), smart_forward(kernel), conv, s1, s2, p1, p2);
304  }
305  else {
306  cpp_unreachable("Invalid conv implementation selection");
307  }
308  }
309 };
310 
321  template <typename I, typename K, typename C>
322  static void apply(I&& input, K&& kernel, C&& conv, size_t s1, size_t s2, size_t p1, size_t p2) {
323  constexpr_select auto impl = select_conv_valid_multi_multi_impl<I, K, C>();
324 
325  if
326  constexpr_select(impl == etl::conv_multi_impl::BLAS_VEC) {
327  inc_counter("inc:blas_vec");
328  impl::vec::blas_conv2_valid_multi_multi(smart_forward(input), smart_forward(kernel), conv, s1, s2, p1, p2);
329  }
330  else if
331  constexpr_select(impl == etl::conv_multi_impl::BLAS_MKL) {
332  inc_counter("inc:blas_mkl");
333  impl::blas::blas_conv2_valid_multi_multi(smart_forward(input), smart_forward(kernel), conv, s1, s2, p1, p2);
334  }
335  else if
336  constexpr_select(impl == etl::conv_multi_impl::VALID_FFT_MKL) {
337  inc_counter("inc:fft_mkl");
338  impl::blas::fft_conv2_valid_multi_multi(smart_forward(input), smart_forward(kernel), conv, s1, s2, p1, p2);
339  }
340  else if
341  constexpr_select(impl == etl::conv_multi_impl::VEC) {
342  inc_counter("inc:vec");
343  impl::vec::conv2_valid_multi_multi(smart_forward(input), smart_forward(kernel), conv, s1, s2, p1, p2);
344  }
345  else if
346  constexpr_select(impl == etl::conv_multi_impl::STD) {
347  inc_counter("inc:std");
348  impl::standard::conv2_valid_multi_multi(smart_forward(input), smart_forward(kernel), conv, s1, s2, p1, p2);
349  }
350  else {
351  cpp_unreachable("Invalid conv implementation selection");
352  }
353  }
354 };
355 
366  template <typename I, typename K, typename C>
367  static void apply(I&& input, K&& kernel, C&& conv, size_t s1, size_t s2, size_t p1, size_t p2) {
368  constexpr_select auto impl = select_conv_valid_multi_multi_impl<I, K, C>();
369 
370  if
371  constexpr_select(impl == etl::conv_multi_impl::BLAS_VEC) {
372  inc_counter("inc:blas_vec");
373  impl::vec::blas_conv2_valid_multi_multi_flipped(smart_forward(input), smart_forward(kernel), conv, s1, s2, p1, p2);
374  }
375  else if
376  constexpr_select(impl == etl::conv_multi_impl::BLAS_MKL) {
377  inc_counter("inc:blas_mkl");
378  impl::blas::blas_conv2_valid_multi_multi_flipped(smart_forward(input), smart_forward(kernel), conv, s1, s2, p1, p2);
379  }
380  else if
381  constexpr_select(impl == etl::conv_multi_impl::VALID_FFT_MKL) {
382  inc_counter("inc:fft_mkl");
383  impl::blas::fft_conv2_valid_multi_multi_flipped(smart_forward(input), smart_forward(kernel), conv, s1, s2, p1, p2);
384  }
385  else if
386  constexpr_select(impl == etl::conv_multi_impl::VEC) {
387  inc_counter("inc:vec");
388  impl::vec::conv2_valid_multi_multi_flipped(smart_forward(input), smart_forward(kernel), conv, s1, s2, p1, p2);
389  }
390  else if
391  constexpr_select(impl == etl::conv_multi_impl::STD) {
392  inc_counter("inc:std");
393  impl::standard::conv2_valid_multi_multi_flipped(smart_forward(input), smart_forward(kernel), conv, s1, s2, p1, p2);
394  }
395  else {
396  cpp_unreachable("Invalid conv implementation selection");
397  }
398  }
399 };
400 
401 } //end of namespace etl::detail
The functor impl for 2D valid conv, with multiple kernels.
Definition: conv_multi.hpp:314
Standard implementation.
static void apply(I &&input, K &&kernel, C &&conv)
Apply the convolution.
Definition: conv_multi.hpp:29
The functor impl for 2D valid conv, with multiple kernels.
Definition: conv_multi.hpp:214
static void apply(I &&input, K &&kernel, C &&conv, size_t s1, size_t s2, size_t p1, size_t p2)
Apply the convolution.
Definition: conv_multi.hpp:272
The functor impl for 2D valid conv, with multiple kernels.
Definition: conv_multi.hpp:72
VEC implementation.
The functor impl for 2D valid conv, with multiple kernels.
Definition: conv_multi.hpp:169
Definition: expression_builder.hpp:699
The functor impl for 2D valid conv, with multiple kernels.
Definition: conv_multi.hpp:264
static void apply(I &&input, K &&kernel, C &&conv)
Apply the convolution.
Definition: conv_multi.hpp:131
static void apply(I &&input, K &&kernel, C &&conv)
Apply the convolution.
Definition: conv_multi.hpp:80
The functor impl for 2D valid conv, with multiple kernels.
Definition: conv_multi.hpp:21
static void apply(I &&input, K &&kernel, C &&conv, size_t s1, size_t s2, size_t p1, size_t p2)
Apply the convolution.
Definition: conv_multi.hpp:367
GPU implementation.
static void apply(I &&input, K &&kernel, C &&conv, size_t s1, size_t s2, size_t p1, size_t p2)
Apply the convolution.
Definition: conv_multi.hpp:222
The functor impl for 2D valid conv, with multiple kernels.
Definition: conv_multi.hpp:123
decltype(auto) smart_forward_gpu(E &expr)
Smart forwarding for a temporary expression that will be computed in GPU.
Definition: helpers.hpp:343
Reduction to FFT (valid)
The functor impl for 2D valid conv, with multiple kernels.
Definition: conv_multi.hpp:359
decltype(auto) smart_forward(E &expr)
Smart forwarding for a temporary expression.
Definition: helpers.hpp:323
static void apply(I &&input, K &&kernel, C &&conv, size_t s1, size_t s2, size_t p1, size_t p2)
Apply the convolution.
Definition: conv_multi.hpp:322
static void apply(I &&input, K &&kernel, C &&conv)
Apply the convolution.
Definition: conv_multi.hpp:177
void inc_counter([[maybe_unused]] const char *name)
Increase the given counter.
Definition: counters.hpp:25