Expression Templates Library (ETL)
cache.hpp
Go to the documentation of this file.
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
17 #pragma once
18 
19 #ifdef ETL_CUDNN_MODE
20 
21 #include "etl/impl/cublas/cuda.hpp"
22 #include "etl/impl/cudnn/cudnn.hpp"
23 
24 #endif
25 
26 namespace etl::impl::cudnn {
27 
28 #ifdef ETL_CUDNN_MODE
29 
33 template <etl_1d M>
34 bool fast_compare(M& lhs, M& rhs) {
35  return etl::dim<0>(lhs) == etl::dim<0>(rhs);
36 }
37 
41 template <etl_2d M>
42 bool fast_compare(M& lhs, M& rhs) {
43  return etl::dim<0>(lhs) == etl::dim<0>(rhs) && etl::dim<1>(lhs) == etl::dim<1>(rhs);
44 }
45 
49 template <etl_3d M>
50 bool fast_compare(M& lhs, M& rhs) {
51  return etl::dim<0>(lhs) == etl::dim<0>(rhs) && etl::dim<1>(lhs) == etl::dim<1>(rhs) && etl::dim<2>(lhs) == etl::dim<2>(rhs);
52 }
53 
57 template <etl_4d M>
58 bool fast_compare(M& lhs, M& rhs) {
59  return etl::dim<0>(lhs) == etl::dim<0>(rhs) && etl::dim<1>(lhs) == etl::dim<1>(rhs) && etl::dim<2>(lhs) == etl::dim<2>(rhs)
60  && etl::dim<3>(lhs) == etl::dim<3>(rhs);
61 }
62 
66 template <typename M, bool F, size_t D>
67 struct mat_cache_key_impl;
68 
72 template <typename M>
73 struct mat_cache_key_impl<M, false, 1> {
74  size_t a;
75 
79  mat_cache_key_impl() {
80  // Nothing else to init
81  }
82 
87  explicit mat_cache_key_impl(M& mat) : a(etl::dim<0>(mat)) {
88  // Nothing else to init
89  }
90 
96  bool operator==(M& rhs) {
97  return a == etl::dim<0>(rhs);
98  }
99 };
100 
104 template <typename M>
105 struct mat_cache_key_impl<M, false, 2> {
106  size_t a;
107  size_t b;
108 
112  mat_cache_key_impl() {
113  // Nothing else to init
114  }
115 
120  explicit mat_cache_key_impl(M& mat) : a(etl::dim<0>(mat)), b(etl::dim<1>(mat)) {
121  // Nothing else to init
122  }
123 
129  bool operator==(M& rhs) {
130  return a == etl::dim<0>(a) && b == etl::dim<1>(rhs);
131  }
132 };
133 
137 template <typename M>
138 struct mat_cache_key_impl<M, false, 3> {
139  size_t a;
140  size_t b;
141  size_t c;
142 
146  mat_cache_key_impl() {
147  // Nothing else to init
148  }
149 
154  explicit mat_cache_key_impl(M& mat) : a(etl::dim<0>(mat)), b(etl::dim<1>(mat)), c(etl::dim<2>(mat)) {
155  // Nothing else to init
156  }
157 
163  bool operator==(M& rhs) {
164  return a == etl::dim<0>(rhs) && b == etl::dim<1>(rhs) && c == etl::dim<2>(rhs);
165  }
166 };
167 
171 template <typename M>
172 struct mat_cache_key_impl<M, false, 4> {
173  size_t a;
174  size_t b;
175  size_t c;
176  size_t d;
177 
181  mat_cache_key_impl() {
182  // Nothing else to init
183  }
184 
189  explicit mat_cache_key_impl(M& mat) : a(etl::dim<0>(mat)), b(etl::dim<1>(mat)), c(etl::dim<2>(mat)), d(etl::dim<3>(mat)) {
190  // Nothing else to init
191  }
192 
198  bool operator==(M& rhs) {
199  return a == etl::dim<0>(rhs) && b == etl::dim<1>(rhs) && c == etl::dim<2>(rhs) && d == etl::dim<3>(rhs);
200  }
201 };
202 
206 template <typename M>
207 using mat_cache_key = mat_cache_key_impl<M, is_fast<M>, decay_traits<M>::dimensions()>;
208 
212 template <typename A, typename B, typename C>
213 struct ternary_cache_key {
214  mat_cache_key<A> key_a;
215  mat_cache_key<B> key_b;
216  mat_cache_key<C> key_c;
217 
221  ternary_cache_key() {
222  // Nothing else to init
223  }
224 
231  ternary_cache_key(A& a, B& b, C& c) : key_a(a), key_b(b), key_c(c) {
232  // Nothing else to init
233  }
234 
243  bool equals(A& a, B& b, C& c) {
244  return key_a == a && key_b == b && key_c == c;
245  }
246 };
247 
/*!
 * \brief Fixed-capacity cache mapping three-argument keys to values.
 *
 * Lookups scan the inserted keys linearly. Once all L slots are used,
 * insert() refuses new entries and returns last.
 *
 * \tparam K The key type: constructible from (a, b, c) and providing equals(a, b, c)
 * \tparam V The value type
 * \tparam L The capacity of the cache
 */
template <typename K, typename V, size_t L = 16>
struct ternary_static_cache {
    std::array<K, L> keys;   ///< The keys; only the first size entries are valid
    std::array<V, L> values; ///< The values; only the first size entries are valid

    size_t size = 0; ///< The number of valid entries

    static constexpr size_t last = L; ///< Sentinel index returned when lookup or insertion fails

    /*!
     * \brief Find the entry matching the given arguments.
     * \param a The first argument
     * \param b The second argument
     * \param c The third argument
     * \return the index of the matching entry, or last if there is none
     */
    template <typename A, typename B, typename C>
    size_t find(A& a, B& b, C& c) {
        for (size_t i = 0; i < size; ++i) {
            if (keys[i].equals(a, b, c)) {
                return i;
            }
        }

        return last;
    }

    /*!
     * \brief Insert a new key built from the given arguments.
     *
     * The associated value must then be set through operator[].
     *
     * \param a The first argument
     * \param b The second argument
     * \param c The third argument
     * \return the index of the new entry, or last if the cache is full
     */
    template <typename A, typename B, typename C>
    size_t insert(A& a, B& b, C& c) {
        // Every one of the L slots is usable (the former "size == last - 1" check
        // left the final slot permanently empty)
        if (size == last) {
            return last;
        }

        // Build the key in the free slot before publishing it, so that a throwing
        // key constructor does not leave a bogus entry counted in size
        new (&keys[size]) K(a, b, c);

        return size++;
    }

    /*!
     * \brief Access the value stored at index i.
     * \param i The index of the entry
     * \return a reference to the value
     */
    auto& operator[](size_t i) {
        return values[i];
    }

    /*!
     * \brief Access the value stored at index i (read-only).
     * \param i The index of the entry
     * \return a const reference to the value
     */
    const auto& operator[](size_t i) const {
        return values[i];
    }
};
308 
312 struct conv4_descriptor {
313  cudnnTensorDescriptor_t input_tensor;
314  cudnnTensorDescriptor_t output_tensor;
315  cudnnFilterDescriptor_t filter;
316  cudnnConvolutionDescriptor_t convolution;
317  cudnnConvolutionFwdAlgo_t conv_algo;
318 
319  size_t workspace_size = 0;
320 };
321 
322 #endif
323 
324 } //end of namespace etl::impl::cudnn
Definition: bias_add.hpp:24
values_t< V... > values(V... v)
Create a list of values for initializing a dyn_matrix.
Definition: dyn_base.hpp:67
Root namespace for the ETL library.
Definition: adapter.hpp:15
static constexpr size_t dimensions()
Return the number of dimensions of the expression.
Definition: traits_base.hpp:31
auto dim(E &&value, size_t i) -> detail::identity_helper< E, dim_view< detail::build_identity_type< E >, D >>
Return a view representing the ith Dth dimension.
Definition: view_expression_builder.hpp:25
Utility functions for cudnn.
const_return_type operator[](size_t j) const
Returns the element at the given index.
Definition: dyn_matrix_view.hpp:71
bool operator==(const complex< T > &lhs, const complex< T > &rhs)
Test two complex numbers for equality.
Definition: complex.hpp:168