Expression Templates Library (ETL)
batch_k_scale.hpp
Go to the documentation of this file.
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
13 #pragma once
14 
15 #ifdef ETL_EGBLAS_MODE
16 
17 #include "etl/impl/cublas/cuda.hpp"
18 
19 #include <egblas.hpp>
20 
21 #endif
22 
23 namespace etl::impl::egblas {
24 
25 #ifdef EGBLAS_HAS_SBATCH_K_SCALE2
26 static constexpr bool has_sbatch_k_scale2 = true;
27 #else
28 static constexpr bool has_sbatch_k_scale2 = false;
29 #endif
30 
38 inline void batch_k_scale([[maybe_unused]] size_t b,
39  [[maybe_unused]] size_t k,
40  [[maybe_unused]] const float* A,
41  [[maybe_unused]] const float* gamma,
42  [[maybe_unused]] float* B) {
43 #ifdef EGBLAS_HAS_SBATCH_K_SCALE2
44  inc_counter("egblas");
45  egblas_sbatch_k_scale2(b, k, A, gamma, B);
46 #else
47  cpp_unreachable("Invalid call to egblas::batch_k_scale");
48 #endif
49 }
50 
51 #ifdef EGBLAS_HAS_DBATCH_K_SCALE2
52 static constexpr bool has_dbatch_k_scale2 = true;
53 #else
54 static constexpr bool has_dbatch_k_scale2 = false;
55 #endif
56 
64 inline void batch_k_scale([[maybe_unused]] size_t b,
65  [[maybe_unused]] size_t k,
66  [[maybe_unused]] const double* A,
67  [[maybe_unused]] const double* gamma,
68  [[maybe_unused]] double* B) {
69 #ifdef EGBLAS_HAS_DBATCH_K_SCALE2
70  inc_counter("egblas");
71  egblas_dbatch_k_scale2(b, k, A, gamma, B);
72 #else
73  cpp_unreachable("Invalid call to egblas::batch_k_scale");
74 #endif
75 }
76 
77 #ifdef EGBLAS_HAS_SBATCH_K_SCALE4
78 static constexpr bool has_sbatch_k_scale4 = true;
79 #else
80 static constexpr bool has_sbatch_k_scale4 = false;
81 #endif
82 
92 inline void batch_k_scale([[maybe_unused]] size_t b,
93  [[maybe_unused]] size_t k,
94  [[maybe_unused]] size_t m,
95  [[maybe_unused]] size_t n,
96  [[maybe_unused]] const float* A,
97  [[maybe_unused]] const float* gamma,
98  [[maybe_unused]] float* B) {
99 #ifdef EGBLAS_HAS_SBATCH_K_SCALE4
100  inc_counter("egblas");
101  egblas_sbatch_k_scale4(b, k, m, n, A, gamma, B);
102 #else
103  cpp_unreachable("Invalid call to egblas::batch_k_scale");
104 #endif
105 }
106 
107 #ifdef EGBLAS_HAS_DBATCH_K_SCALE4
108 static constexpr bool has_dbatch_k_scale4 = true;
109 #else
110 static constexpr bool has_dbatch_k_scale4 = false;
111 #endif
112 
122 inline void batch_k_scale([[maybe_unused]] size_t b,
123  [[maybe_unused]] size_t k,
124  [[maybe_unused]] size_t m,
125  [[maybe_unused]] size_t n,
126  [[maybe_unused]] const double* A,
127  [[maybe_unused]] const double* gamma,
128  [[maybe_unused]] double* B) {
129 #ifdef EGBLAS_HAS_DBATCH_K_SCALE4
130  inc_counter("egblas");
131  egblas_dbatch_k_scale4(b, k, m, n, A, gamma, B);
132 #else
133  cpp_unreachable("Invalid call to egblas::batch_k_scale");
134 #endif
135 }
136 
137 // batch_k_scale_plus
138 
139 #ifdef EGBLAS_HAS_SBATCH_K_SCALE_PLUS2
140 static constexpr bool has_sbatch_k_scale_plus2 = true;
141 #else
142 static constexpr bool has_sbatch_k_scale_plus2 = false;
143 #endif
144 
152 inline void batch_k_scale_plus([[maybe_unused]] size_t b,
153  [[maybe_unused]] size_t k,
154  [[maybe_unused]] const float* A,
155  [[maybe_unused]] const float* gamma,
156  [[maybe_unused]] const float* beta,
157  [[maybe_unused]] float* B) {
158 #ifdef EGBLAS_HAS_SBATCH_K_SCALE_PLUS2
159  inc_counter("egblas");
160  egblas_sbatch_k_scale_plus2(b, k, A, gamma, beta, B);
161 #else
162  cpp_unreachable("Invalid call to egblas::batch_k_scale_plus");
163 #endif
164 }
165 
166 #ifdef EGBLAS_HAS_DBATCH_K_SCALE_PLUS2
167 static constexpr bool has_dbatch_k_scale_plus2 = true;
168 #else
169 static constexpr bool has_dbatch_k_scale_plus2 = false;
170 #endif
171 
179 inline void batch_k_scale_plus([[maybe_unused]] size_t b,
180  [[maybe_unused]] size_t k,
181  [[maybe_unused]] const double* A,
182  [[maybe_unused]] const double* gamma,
183  [[maybe_unused]] const double* beta,
184  [[maybe_unused]] double* B) {
185 #ifdef EGBLAS_HAS_DBATCH_K_SCALE_PLUS2
186  inc_counter("egblas");
187  egblas_dbatch_k_scale_plus2(b, k, A, gamma, beta, B);
188 #else
189  cpp_unreachable("Invalid call to egblas::batch_k_scale_plus");
190 #endif
191 }
192 
193 #ifdef EGBLAS_HAS_SBATCH_K_SCALE_PLUS4
194 static constexpr bool has_sbatch_k_scale_plus4 = true;
195 #else
196 static constexpr bool has_sbatch_k_scale_plus4 = false;
197 #endif
198 
208 inline void batch_k_scale_plus([[maybe_unused]] size_t b,
209  [[maybe_unused]] size_t k,
210  [[maybe_unused]] size_t m,
211  [[maybe_unused]] size_t n,
212  [[maybe_unused]] const float* A,
213  [[maybe_unused]] const float* gamma,
214  [[maybe_unused]] const float* beta,
215  [[maybe_unused]] float* B) {
216 #ifdef EGBLAS_HAS_SBATCH_K_SCALE_PLUS4
217  inc_counter("egblas");
218  egblas_sbatch_k_scale_plus4(b, k, m, n, A, gamma, beta, B);
219 #else
220  cpp_unreachable("Invalid call to egblas::batch_k_scale_plus");
221 #endif
222 }
223 
224 #ifdef EGBLAS_HAS_DBATCH_K_SCALE_PLUS4
225 static constexpr bool has_dbatch_k_scale_plus4 = true;
226 #else
227 static constexpr bool has_dbatch_k_scale_plus4 = false;
228 #endif
229 
239 inline void batch_k_scale_plus([[maybe_unused]] size_t b,
240  [[maybe_unused]] size_t k,
241  [[maybe_unused]] size_t m,
242  [[maybe_unused]] size_t n,
243  [[maybe_unused]] const double* A,
244  [[maybe_unused]] const double* gamma,
245  [[maybe_unused]] const double* beta,
246  [[maybe_unused]] double* B) {
247 #ifdef EGBLAS_HAS_DBATCH_K_SCALE_PLUS4
248  inc_counter("egblas");
249  egblas_dbatch_k_scale_plus4(b, k, m, n, A, gamma, beta, B);
250 #else
251  cpp_unreachable("Invalid call to egblas::batch_k_scale_plus");
252 #endif
253 }
254 
255 #ifdef EGBLAS_HAS_SBATCH_K_MINUS_SCALE2
256 static constexpr bool has_sbatch_k_minus_scale2 = true;
257 #else
258 static constexpr bool has_sbatch_k_minus_scale2 = false;
259 #endif
260 
268 inline void batch_k_minus_scale([[maybe_unused]] size_t b,
269  [[maybe_unused]] size_t k,
270  [[maybe_unused]] const float* A,
271  [[maybe_unused]] const float* gamma,
272  [[maybe_unused]] const float* beta,
273  [[maybe_unused]] float* B) {
274 #ifdef EGBLAS_HAS_SBATCH_K_MINUS_SCALE2
275  inc_counter("egblas");
276  egblas_sbatch_k_minus_scale2(b, k, A, gamma, beta, B);
277 #else
278  cpp_unreachable("Invalid call to egblas::batch_k_minus_scale");
279 #endif
280 }
281 
282 #ifdef EGBLAS_HAS_DBATCH_K_MINUS_SCALE2
283 static constexpr bool has_dbatch_k_minus_scale2 = true;
284 #else
285 static constexpr bool has_dbatch_k_minus_scale2 = false;
286 #endif
287 
295 inline void batch_k_minus_scale([[maybe_unused]] size_t b,
296  [[maybe_unused]] size_t k,
297  [[maybe_unused]] const double* A,
298  [[maybe_unused]] const double* gamma,
299  [[maybe_unused]] const double* beta,
300  [[maybe_unused]] double* B) {
301 #ifdef EGBLAS_HAS_DBATCH_K_MINUS_SCALE2
302  inc_counter("egblas");
303  egblas_dbatch_k_minus_scale2(b, k, A, gamma, beta, B);
304 #else
305  cpp_unreachable("Invalid call to egblas::batch_k_minus_scale");
306 #endif
307 }
308 
309 #ifdef EGBLAS_HAS_SBATCH_K_MINUS_SCALE4
310 static constexpr bool has_sbatch_k_minus_scale4 = true;
311 #else
312 static constexpr bool has_sbatch_k_minus_scale4 = false;
313 #endif
314 
324 inline void batch_k_minus_scale([[maybe_unused]] size_t b,
325  [[maybe_unused]] size_t k,
326  [[maybe_unused]] size_t m,
327  [[maybe_unused]] size_t n,
328  [[maybe_unused]] const float* A,
329  [[maybe_unused]] const float* gamma,
330  [[maybe_unused]] const float* beta,
331  [[maybe_unused]] float* B) {
332 #ifdef EGBLAS_HAS_SBATCH_K_MINUS_SCALE4
333  inc_counter("egblas");
334  egblas_sbatch_k_minus_scale4(b, k, m, n, A, gamma, beta, B);
335 #else
336  cpp_unreachable("Invalid call to egblas::batch_k_minus_scale");
337 #endif
338 }
339 
340 #ifdef EGBLAS_HAS_DBATCH_K_MINUS_SCALE4
341 static constexpr bool has_dbatch_k_minus_scale4 = true;
342 #else
343 static constexpr bool has_dbatch_k_minus_scale4 = false;
344 #endif
345 
355 inline void batch_k_minus_scale([[maybe_unused]] size_t b,
356  [[maybe_unused]] size_t k,
357  [[maybe_unused]] size_t m,
358  [[maybe_unused]] size_t n,
359  [[maybe_unused]] const double* A,
360  [[maybe_unused]] const double* gamma,
361  [[maybe_unused]] const double* beta,
362  [[maybe_unused]] double* B) {
363 #ifdef EGBLAS_HAS_DBATCH_K_MINUS_SCALE4
364  inc_counter("egblas");
365  egblas_dbatch_k_minus_scale4(b, k, m, n, A, gamma, beta, B);
366 #else
367  cpp_unreachable("Invalid call to egblas::batch_k_minus_scale");
368 #endif
369 }
370 
371 } //end of namespace etl::impl::egblas
batch_k_scale_expr< detail::build_type< A >, detail::build_type< B > > batch_k_scale(const A &a, const B &b)
Returns a batch_k_scale expression built from the two given expressions.
Definition: batch_k_scale_expr.hpp:1495
batch_k_scale_plus_expr< detail::build_type< A >, detail::build_type< B >, detail::build_type< C > > batch_k_scale_plus(const A &a, const B &b, const C &c)
Returns a batch_k_scale_plus expression built from the three given expressions.
Definition: batch_k_scale_plus_expr.hpp:1575
batch_k_minus_scale_expr< detail::build_type< A >, detail::build_type< B >, detail::build_type< C > > batch_k_minus_scale(const A &a, const B &b, const C &c)
Returns a batch_k_minus_scale expression built from the three given expressions.
Definition: batch_k_minus_scale_expr.hpp:1575
Definition: abs.hpp:23
void inc_counter([[maybe_unused]] const char *name)
Increase the given counter.
Definition: counters.hpp:25