Expression Templates Library (ETL)
conv_2d.hpp
Go to the documentation of this file.
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
13 #pragma once
14 
15 namespace etl::detail {
16 
27  template <typename I, typename K, typename C>
28  static void apply(const I& input, const K& kernel, C& conv) {
29  constexpr_select auto impl = select_conv2_impl_new<conv_type::FULL, I, K, C>();
30 
31  if
32  constexpr_select(impl == etl::conv_impl::VEC) {
33  inc_counter("impl:vec");
34  impl::vec::conv2_full(smart_forward(input), smart_forward(kernel), conv);
35  }
36  else if
37  constexpr_select(impl == etl::conv_impl::CUDNN) {
38  inc_counter("impl:cudnn");
39  impl::cudnn::conv2_full(smart_forward_gpu(input), smart_forward_gpu(kernel), conv);
40  }
41  else if
42  constexpr_select(impl == etl::conv_impl::STD) {
43  inc_counter("impl:std");
44  impl::standard::conv2_full(smart_forward(input), smart_forward(kernel), conv);
45  }
46  else if
47  constexpr_select(impl == etl::conv_impl::FFT_STD) {
48  inc_counter("impl:fft_std");
49  impl::standard::conv2_full_fft(smart_forward(input), smart_forward(kernel), conv);
50  }
51  else if
52  constexpr_select(impl == etl::conv_impl::FFT_MKL) {
53  inc_counter("impl:fft_mkl");
54  impl::blas::conv2_full(smart_forward(input), smart_forward(kernel), conv);
55  }
56  else if
57  constexpr_select(impl == etl::conv_impl::FFT_CUFFT) {
58  inc_counter("impl:fft_cufft");
59  impl::cufft::conv2_full(smart_forward(input), smart_forward(kernel), conv);
60  }
61  else {
62  cpp_unreachable("Invalid conv implementation selection");
63  }
64  }
65 };
66 
77  template <typename I, typename K, typename C>
78  static void apply(const I& input, const K& kernel, C& conv) {
79  constexpr_select auto impl = select_conv2_impl_new<conv_type::FULL, I, K, C>();
80 
81  if
82  constexpr_select(impl == etl::conv_impl::VEC) {
83  inc_counter("impl:vec");
84  impl::vec::conv2_full_flipped(smart_forward(input), smart_forward(kernel), conv);
85  }
86  else if
87  constexpr_select(impl == etl::conv_impl::CUDNN) {
88  inc_counter("impl:cudnn");
89  impl::cudnn::conv2_full_flipped(smart_forward_gpu(input), smart_forward_gpu(kernel), conv);
90  }
91  else if
92  constexpr_select(impl == etl::conv_impl::STD) {
93  inc_counter("impl:std");
94  impl::standard::conv2_full_flipped(smart_forward(input), smart_forward(kernel), conv);
95  }
96  else if
97  constexpr_select(impl == etl::conv_impl::FFT_STD) {
98  inc_counter("impl:fft_std");
99  impl::standard::conv2_full_fft_flipped(smart_forward(input), smart_forward(kernel), conv);
100  }
101  else if
102  constexpr_select(impl == etl::conv_impl::FFT_MKL) {
103  inc_counter("impl:fft_mkl");
104  impl::blas::conv2_full_flipped(smart_forward(input), smart_forward(kernel), conv);
105  }
106  else if
107  constexpr_select(impl == etl::conv_impl::FFT_CUFFT) {
108  inc_counter("impl:fft_cufft");
109  impl::cufft::conv2_full_flipped(smart_forward(input), smart_forward(kernel), conv);
110  }
111  else {
112  cpp_unreachable("Invalid conv implementation selection");
113  }
114  }
115 };
116 
127  template <typename I, typename K, typename C>
128  static void apply(const I& input, const K& kernel, C& conv) {
129  constexpr_select auto impl = select_conv2_impl_new<conv_type::SAME, I, K, C>();
130 
131  if
132  constexpr_select(impl == etl::conv_impl::VEC) {
133  inc_counter("impl:vec");
134  impl::vec::conv2_same(smart_forward(input), smart_forward(kernel), conv);
135  }
136  else if
137  constexpr_select(impl == etl::conv_impl::STD) {
138  inc_counter("impl:std");
139  impl::standard::conv2_same(smart_forward(input), smart_forward(kernel), conv);
140  }
141  else {
142  cpp_unreachable("Invalid conv implementation selection");
143  }
144  }
145 };
146 
157  template <typename I, typename K, typename C>
158  static void apply(const I& input, const K& kernel, C& conv) {
159  constexpr_select auto impl = select_conv2_impl_new<conv_type::SAME, I, K, C>();
160 
161  if
162  constexpr_select(impl == etl::conv_impl::VEC) {
163  inc_counter("impl:vec");
164  impl::vec::conv2_same_flipped(smart_forward(input), smart_forward(kernel), conv);
165  }
166  else if
167  constexpr_select(impl == etl::conv_impl::STD) {
168  inc_counter("impl:std");
169  impl::standard::conv2_same_flipped(smart_forward(input), smart_forward(kernel), conv);
170  }
171  else {
172  cpp_unreachable("Invalid conv implementation selection");
173  }
174  }
175 };
176 
180 template <size_t S1, size_t S2, size_t P1, size_t P2>
188  template <typename I, typename K, typename C>
189  static void apply(const I& input, const K& kernel, C& conv) {
190  constexpr_select auto impl = select_conv_impl<conv_type::VALID, I, K, C>();
191 
192  if /*constepxr_select*/ (impl == etl::conv_impl::VEC) {
193  inc_counter("impl:vec");
194  impl::vec::conv2_valid(smart_forward(input), smart_forward(kernel), conv, S1, S2, P1, P2);
195  } else if
196  constexpr_select(impl == etl::conv_impl::CUDNN) {
197  inc_counter("impl:cudnn");
198  impl::cudnn::conv2_valid(smart_forward_gpu(input), smart_forward_gpu(kernel), conv, S1, S2, P1, P2);
199  }
200  else if
201  constexpr_select(impl == etl::conv_impl::STD) {
202  inc_counter("impl:std");
203  impl::standard::conv2_valid(smart_forward(input), smart_forward(kernel), conv, S1, S2, P1, P2);
204  }
205  else {
206  cpp_unreachable("Invalid conv implementation selection");
207  }
208  }
209 };
210 
214 template <size_t S1, size_t S2, size_t P1, size_t P2>
222  template <typename I, typename K, typename C>
223  static void apply(const I& input, const K& kernel, C& conv) {
224  constexpr_select auto impl = select_conv_impl<conv_type::VALID, I, K, C>();
225 
226  if /*constepxr_select*/ (impl == etl::conv_impl::VEC) {
227  inc_counter("impl:vec");
228  impl::vec::conv2_valid_flipped(smart_forward(input), smart_forward(kernel), conv, S1, S2, P1, P2);
229  } else if
230  constexpr_select(impl == etl::conv_impl::CUDNN) {
231  inc_counter("impl:cudnn");
232  impl::cudnn::conv2_valid_flipped(smart_forward_gpu(input), smart_forward_gpu(kernel), conv, S1, S2, P1, P2);
233  }
234  else if
235  constexpr_select(impl == etl::conv_impl::STD) {
236  inc_counter("impl:std");
237  impl::standard::conv2_valid_flipped(smart_forward(input), smart_forward(kernel), conv, S1, S2, P1, P2);
238  }
239  else {
240  cpp_unreachable("Invalid conv implementation selection");
241  }
242  }
243 };
244 
255  template <typename I, typename K, typename C>
256  static void apply(const I& input, const K& kernel, C& conv, size_t s1, size_t s2, size_t p1, size_t p2) {
257  constexpr_select auto impl = select_conv_impl<conv_type::VALID, I, K, C>();
258 
259  if /*constepxr_select*/ (impl == etl::conv_impl::VEC) {
260  inc_counter("impl:vec");
261  impl::vec::conv2_valid(smart_forward(input), smart_forward(kernel), conv, s1, s2, p1, p2);
262  } else if
263  constexpr_select(impl == etl::conv_impl::CUDNN) {
264  inc_counter("impl:cudnn");
265  impl::cudnn::conv2_valid(smart_forward_gpu(input), smart_forward_gpu(kernel), conv, s1, s2, p1, p2);
266  }
267  else if
268  constexpr_select(impl == etl::conv_impl::STD) {
269  inc_counter("impl:std");
270  impl::standard::conv2_valid(smart_forward(input), smart_forward(kernel), conv, s1, s2, p1, p2);
271  }
272  else {
273  cpp_unreachable("Invalid conv implementation selection");
274  }
275  }
276 };
277 
288  template <typename I, typename K, typename C>
289  static void apply(const I& input, const K& kernel, C& conv, size_t s1, size_t s2, size_t p1, size_t p2) {
290  constexpr_select auto impl = select_conv_impl<conv_type::VALID, I, K, C>();
291 
292  if /*constepxr_select*/ (impl == etl::conv_impl::VEC) {
293  inc_counter("impl:vec");
294  impl::vec::conv2_valid_flipped(smart_forward(input), smart_forward(kernel), conv, s1, s2, p1, p2);
295  } else if
296  constexpr_select(impl == etl::conv_impl::CUDNN) {
297  inc_counter("impl:cudnn");
298  impl::cudnn::conv2_valid_flipped(smart_forward_gpu(input), smart_forward_gpu(kernel), conv, s1, s2, p1, p2);
299  }
300  else if
301  constexpr_select(impl == etl::conv_impl::STD) {
302  inc_counter("impl:std");
303  impl::standard::conv2_valid_flipped(smart_forward(input), smart_forward(kernel), conv, s1, s2, p1, p2);
304  }
305  else {
306  cpp_unreachable("Invalid conv implementation selection");
307  }
308  }
309 };
310 
311 } //end of namespace etl::detail
FFT reduction (with MKL impl)
The functor impl for 2D same conv (flipped variant).
Definition: conv_2d.hpp:150
Standard implementation.
static void apply(const I &input, const K &kernel, C &conv)
Apply the convolution.
Definition: conv_2d.hpp:223
VEC implementation.
static void apply(const I &input, const K &kernel, C &conv)
Apply the convolution.
Definition: conv_2d.hpp:28
The functor impl for 2D full conv.
Definition: conv_2d.hpp:20
Definition: expression_builder.hpp:699
static void apply(const I &input, const K &kernel, C &conv)
Apply the convolution.
Definition: conv_2d.hpp:158
The functor impl for 2D same conv.
Definition: conv_2d.hpp:120
static void apply(const I &input, const K &kernel, C &conv)
Apply the convolution.
Definition: conv_2d.hpp:78
FFT reduction (with STD impl)
static void apply(const I &input, const K &kernel, C &conv, size_t s1, size_t s2, size_t p1, size_t p2)
Apply the convolution.
Definition: conv_2d.hpp:256
The functor impl for 2D valid conv (S1, S2, P1, P2 as template parameters).
Definition: conv_2d.hpp:181
The functor impl for 2D valid conv (flipped variant, S1, S2, P1, P2 as template parameters).
Definition: conv_2d.hpp:215
GPU implementation.
The functor impl for 2D valid conv (flipped variant, s1, s2, p1, p2 given at run time).
Definition: conv_2d.hpp:281
decltype(auto) smart_forward_gpu(E &expr)
Smart forwarding for a temporary expression that will be computed in GPU.
Definition: helpers.hpp:343
The functor impl for 2D full conv (flipped variant).
Definition: conv_2d.hpp:70
decltype(auto) smart_forward(E &expr)
Smart forwarding for a temporary expression.
Definition: helpers.hpp:323
static void apply(const I &input, const K &kernel, C &conv)
Apply the convolution.
Definition: conv_2d.hpp:128
void inc_counter([[maybe_unused]] const char *name)
Increase the given counter.
Definition: counters.hpp:25
static void apply(const I &input, const K &kernel, C &conv)
Apply the convolution.
Definition: conv_2d.hpp:189
The functor impl for 2D valid conv (s1, s2, p1, p2 given at run time).
Definition: conv_2d.hpp:248
static void apply(const I &input, const K &kernel, C &conv, size_t s1, size_t s2, size_t p1, size_t p2)
Apply the convolution.
Definition: conv_2d.hpp:289
FFT reduction (with CUFFT impl)