Expression Templates Library (ETL)
outer.hpp
Go to the documentation of this file.
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
13 #pragma once
14 
15 #ifdef ETL_BLAS_MODE
16 #include "cblas.h"
17 #endif
18 
19 namespace etl::impl::blas {
20 
21 #ifdef ETL_BLAS_MODE
22 
29 template <typename A, typename B, typename C>
30 void outer(const A& a, const B& b, C&& c) {
31  c = 0;
32 
33  a.ensure_cpu_up_to_date();
34  b.ensure_cpu_up_to_date();
35  c.ensure_cpu_up_to_date();
36 
37  if constexpr (all_single_precision<A, B, C>) {
38  cblas_sger(CblasRowMajor, etl::dim<0>(a), etl::dim<0>(b), 1.0, a.memory_start(), 1, b.memory_start(), 1, c.memory_start(), etl::dim<0>(b));
39  } else {
40  cblas_dger(CblasRowMajor, etl::dim<0>(a), etl::dim<0>(b), 1.0, a.memory_start(), 1, b.memory_start(), 1, c.memory_start(), etl::dim<0>(b));
41  }
42 
43  c.invalidate_gpu();
44 }
45 
52 template <typename A, typename B, typename C>
53 void batch_outer(const A& a, const B& b, C&& c) {
54  const size_t m = etl::rows(c);
55  const size_t n = etl::columns(c);
56  const size_t k = etl::rows(a);
57 
58  a.ensure_cpu_up_to_date();
59  b.ensure_cpu_up_to_date();
60 
61  if constexpr (all_single_precision<A, B, C>) {
62  cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans, m, n, k, 1.0f, a.memory_start(), m, b.memory_start(), n, 0.0f, c.memory_start(), n);
63  } else {
64  cblas_dgemm(CblasRowMajor, CblasTrans, CblasNoTrans, m, n, k, 1.0, a.memory_start(), m, b.memory_start(), n, 0.0, c.memory_start(), n);
65  }
66 
67  c.invalidate_gpu();
68 }
69 
70 #else
71 
75 template <typename A, typename B, typename C>
76 void outer(const A& /*a*/, const B& /*b*/, C&& /*c*/) {
77  cpp_unreachable("BLAS not enabled/available");
78 }
79 
83 template <typename A, typename B, typename C>
84 void batch_outer(const A& /*a*/, const B& /*b*/, C&& /*c*/) {
85  cpp_unreachable("BLAS not enabled/available");
86 }
87 
88 #endif
89 
90 } //end of namespace etl::impl::blas
batch_outer_product_expr< detail::build_type< A >, detail::build_type< B > > batch_outer(A &&a, B &&b)
Batch outer product of two matrices.
Definition: batch_outer_product_expr.hpp:333
size_t columns(const E &expr)
Returns the number of columns of the given ETL expression.
Definition: helpers.hpp:78
Definition: dot.hpp:19
size_t rows(const E &expr)
Returns the number of rows of the given ETL expression.
Definition: helpers.hpp:58
outer_product_expr< detail::build_type< A >, detail::build_type< B > > outer(A &&a, B &&b)
Outer product of two vectors.
Definition: outer_product_expr.hpp:293