21 #include "etl/impl/cublas/cuda.hpp" 34 bool fast_compare(M& lhs, M& rhs) {
35 return etl::dim<0>(lhs) == etl::dim<0>(rhs);
42 bool fast_compare(M& lhs, M& rhs) {
43 return etl::dim<0>(lhs) == etl::dim<0>(rhs) && etl::dim<1>(lhs) == etl::dim<1>(rhs);
50 bool fast_compare(M& lhs, M& rhs) {
51 return etl::dim<0>(lhs) == etl::dim<0>(rhs) && etl::dim<1>(lhs) == etl::dim<1>(rhs) && etl::dim<2>(lhs) == etl::dim<2>(rhs);
58 bool fast_compare(M& lhs, M& rhs) {
59 return etl::dim<0>(lhs) == etl::dim<0>(rhs) && etl::dim<1>(lhs) == etl::dim<1>(rhs) && etl::dim<2>(lhs) == etl::dim<2>(rhs)
60 && etl::dim<3>(lhs) == etl::dim<3>(rhs);
66 template <
typename M,
bool F,
size_t D>
67 struct mat_cache_key_impl;
73 struct mat_cache_key_impl<M, false, 1> {
79 mat_cache_key_impl() {
87 explicit mat_cache_key_impl(M& mat) : a(
etl::
dim<0>(mat)) {
97 return a == etl::dim<0>(rhs);
104 template <
typename M>
105 struct mat_cache_key_impl<M, false, 2> {
112 mat_cache_key_impl() {
120 explicit mat_cache_key_impl(M& mat) : a(
etl::
dim<0>(mat)), b(
etl::
dim<1>(mat)) {
130 return a == etl::dim<0>(a) && b == etl::dim<1>(rhs);
137 template <
typename M>
138 struct mat_cache_key_impl<M, false, 3> {
146 mat_cache_key_impl() {
154 explicit mat_cache_key_impl(M& mat) : a(
etl::
dim<0>(mat)), b(
etl::
dim<1>(mat)), c(
etl::
dim<2>(mat)) {
164 return a == etl::dim<0>(rhs) && b == etl::dim<1>(rhs) && c == etl::dim<2>(rhs);
171 template <
typename M>
172 struct mat_cache_key_impl<M, false, 4> {
181 mat_cache_key_impl() {
189 explicit mat_cache_key_impl(M& mat) : a(
etl::
dim<0>(mat)), b(
etl::
dim<1>(mat)), c(
etl::
dim<2>(mat)), d(
etl::
dim<3>(mat)) {
199 return a == etl::dim<0>(rhs) && b == etl::dim<1>(rhs) && c == etl::dim<2>(rhs) && d == etl::dim<3>(rhs);
206 template <
typename M>
212 template <
typename A,
typename B,
typename C>
213 struct ternary_cache_key {
214 mat_cache_key<A> key_a;
215 mat_cache_key<B> key_b;
216 mat_cache_key<C> key_c;
221 ternary_cache_key() {
231 ternary_cache_key(A& a, B& b, C& c) : key_a(a), key_b(b), key_c(c) {
243 bool equals(A& a, B& b, C& c) {
244 return key_a == a && key_b == b && key_c == c;
251 template <
typename K,
typename V,
size_t L = 16>
252 struct ternary_static_cache {
253 std::array<K, L> keys;
258 static constexpr
size_t last = L;
267 template <
typename A,
typename B,
typename C>
268 size_t find(A& a, B& b, C& c) {
269 for (
size_t i = 0; i < size; ++i) {
270 if (keys[i].equals(a, b, c)) {
286 template <
typename A,
typename B,
typename C>
287 size_t insert(A& a, B& b, C& c) {
288 if (size == last - 1) {
294 new (&keys[size - 1]) K(a, b, c);
312 struct conv4_descriptor {
313 cudnnTensorDescriptor_t input_tensor;
314 cudnnTensorDescriptor_t output_tensor;
315 cudnnFilterDescriptor_t filter;
316 cudnnConvolutionDescriptor_t convolution;
317 cudnnConvolutionFwdAlgo_t conv_algo;
319 size_t workspace_size = 0;
Definition: bias_add.hpp:24
values_t< V... > values(V... v)
Create a list of values for initializing a dyn_matrix.
Definition: dyn_base.hpp:67
Root namespace for the ETL library.
Definition: adapter.hpp:15
static constexpr size_t dimensions()
Return the number of dimensions of the expression.
Definition: traits_base.hpp:31
auto dim(E &&value, size_t i) -> detail::identity_helper< E, dim_view< detail::build_identity_type< E >, D >>
Return a view representing the ith Dth dimension.
Definition: view_expression_builder.hpp:25
Utility functions for cudnn.
const_return_type operator[](size_t j) const
Returns the element at the given index.
Definition: dyn_matrix_view.hpp:71
bool operator==(const complex< T > &lhs, const complex< T > &rhs)
Test two complex numbers for equality.
Definition: complex.hpp:168