15 #ifdef ETL_EGBLAS_MODE 17 #include "etl/impl/cublas/cuda.hpp" 25 #ifdef EGBLAS_HAS_SBATCH_K_SCALE2 26 static constexpr
bool has_sbatch_k_scale2 =
true;
28 static constexpr
bool has_sbatch_k_scale2 =
false;
39 [[maybe_unused]]
size_t k,
40 [[maybe_unused]]
const float* A,
41 [[maybe_unused]]
const float* gamma,
42 [[maybe_unused]]
float* B) {
43 #ifdef EGBLAS_HAS_SBATCH_K_SCALE2 45 egblas_sbatch_k_scale2(b, k, A, gamma, B);
47 cpp_unreachable(
"Invalid call to egblas::batch_k_scale");
51 #ifdef EGBLAS_HAS_DBATCH_K_SCALE2 52 static constexpr
bool has_dbatch_k_scale2 =
true;
54 static constexpr
bool has_dbatch_k_scale2 =
false;
65 [[maybe_unused]]
size_t k,
66 [[maybe_unused]]
const double* A,
67 [[maybe_unused]]
const double* gamma,
68 [[maybe_unused]]
double* B) {
69 #ifdef EGBLAS_HAS_DBATCH_K_SCALE2 71 egblas_dbatch_k_scale2(b, k, A, gamma, B);
73 cpp_unreachable(
"Invalid call to egblas::batch_k_scale");
77 #ifdef EGBLAS_HAS_SBATCH_K_SCALE4 78 static constexpr
bool has_sbatch_k_scale4 =
true;
80 static constexpr
bool has_sbatch_k_scale4 =
false;
93 [[maybe_unused]]
size_t k,
94 [[maybe_unused]]
size_t m,
95 [[maybe_unused]]
size_t n,
96 [[maybe_unused]]
const float* A,
97 [[maybe_unused]]
const float* gamma,
98 [[maybe_unused]]
float* B) {
99 #ifdef EGBLAS_HAS_SBATCH_K_SCALE4 101 egblas_sbatch_k_scale4(b, k, m, n, A, gamma, B);
103 cpp_unreachable(
"Invalid call to egblas::batch_k_scale");
107 #ifdef EGBLAS_HAS_DBATCH_K_SCALE4 108 static constexpr
bool has_dbatch_k_scale4 =
true;
110 static constexpr
bool has_dbatch_k_scale4 =
false;
123 [[maybe_unused]]
size_t k,
124 [[maybe_unused]]
size_t m,
125 [[maybe_unused]]
size_t n,
126 [[maybe_unused]]
const double* A,
127 [[maybe_unused]]
const double* gamma,
128 [[maybe_unused]]
double* B) {
129 #ifdef EGBLAS_HAS_DBATCH_K_SCALE4 131 egblas_dbatch_k_scale4(b, k, m, n, A, gamma, B);
133 cpp_unreachable(
"Invalid call to egblas::batch_k_scale");
139 #ifdef EGBLAS_HAS_SBATCH_K_SCALE_PLUS2 140 static constexpr
bool has_sbatch_k_scale_plus2 =
true;
142 static constexpr
bool has_sbatch_k_scale_plus2 =
false;
153 [[maybe_unused]]
size_t k,
154 [[maybe_unused]]
const float* A,
155 [[maybe_unused]]
const float* gamma,
156 [[maybe_unused]]
const float* beta,
157 [[maybe_unused]]
float* B) {
158 #ifdef EGBLAS_HAS_SBATCH_K_SCALE_PLUS2 160 egblas_sbatch_k_scale_plus2(b, k, A, gamma, beta, B);
162 cpp_unreachable(
"Invalid call to egblas::batch_k_scale_plus");
166 #ifdef EGBLAS_HAS_DBATCH_K_SCALE_PLUS2 167 static constexpr
bool has_dbatch_k_scale_plus2 =
true;
169 static constexpr
bool has_dbatch_k_scale_plus2 =
false;
180 [[maybe_unused]]
size_t k,
181 [[maybe_unused]]
const double* A,
182 [[maybe_unused]]
const double* gamma,
183 [[maybe_unused]]
const double* beta,
184 [[maybe_unused]]
double* B) {
185 #ifdef EGBLAS_HAS_DBATCH_K_SCALE_PLUS2 187 egblas_dbatch_k_scale_plus2(b, k, A, gamma, beta, B);
189 cpp_unreachable(
"Invalid call to egblas::batch_k_scale_plus");
193 #ifdef EGBLAS_HAS_SBATCH_K_SCALE_PLUS4 194 static constexpr
bool has_sbatch_k_scale_plus4 =
true;
196 static constexpr
bool has_sbatch_k_scale_plus4 =
false;
209 [[maybe_unused]]
size_t k,
210 [[maybe_unused]]
size_t m,
211 [[maybe_unused]]
size_t n,
212 [[maybe_unused]]
const float* A,
213 [[maybe_unused]]
const float* gamma,
214 [[maybe_unused]]
const float* beta,
215 [[maybe_unused]]
float* B) {
216 #ifdef EGBLAS_HAS_SBATCH_K_SCALE_PLUS4 218 egblas_sbatch_k_scale_plus4(b, k, m, n, A, gamma, beta, B);
220 cpp_unreachable(
"Invalid call to egblas::batch_k_scale_plus");
224 #ifdef EGBLAS_HAS_DBATCH_K_SCALE_PLUS4 225 static constexpr
bool has_dbatch_k_scale_plus4 =
true;
227 static constexpr
bool has_dbatch_k_scale_plus4 =
false;
240 [[maybe_unused]]
size_t k,
241 [[maybe_unused]]
size_t m,
242 [[maybe_unused]]
size_t n,
243 [[maybe_unused]]
const double* A,
244 [[maybe_unused]]
const double* gamma,
245 [[maybe_unused]]
const double* beta,
246 [[maybe_unused]]
double* B) {
247 #ifdef EGBLAS_HAS_DBATCH_K_SCALE_PLUS4 249 egblas_dbatch_k_scale_plus4(b, k, m, n, A, gamma, beta, B);
251 cpp_unreachable(
"Invalid call to egblas::batch_k_scale_plus");
255 #ifdef EGBLAS_HAS_SBATCH_K_MINUS_SCALE2 256 static constexpr
bool has_sbatch_k_minus_scale2 =
true;
258 static constexpr
bool has_sbatch_k_minus_scale2 =
false;
269 [[maybe_unused]]
size_t k,
270 [[maybe_unused]]
const float* A,
271 [[maybe_unused]]
const float* gamma,
272 [[maybe_unused]]
const float* beta,
273 [[maybe_unused]]
float* B) {
274 #ifdef EGBLAS_HAS_SBATCH_K_MINUS_SCALE2 276 egblas_sbatch_k_minus_scale2(b, k, A, gamma, beta, B);
278 cpp_unreachable(
"Invalid call to egblas::batch_k_minus_scale");
282 #ifdef EGBLAS_HAS_DBATCH_K_MINUS_SCALE2 283 static constexpr
bool has_dbatch_k_minus_scale2 =
true;
285 static constexpr
bool has_dbatch_k_minus_scale2 =
false;
296 [[maybe_unused]]
size_t k,
297 [[maybe_unused]]
const double* A,
298 [[maybe_unused]]
const double* gamma,
299 [[maybe_unused]]
const double* beta,
300 [[maybe_unused]]
double* B) {
301 #ifdef EGBLAS_HAS_DBATCH_K_MINUS_SCALE2 303 egblas_dbatch_k_minus_scale2(b, k, A, gamma, beta, B);
305 cpp_unreachable(
"Invalid call to egblas::batch_k_minus_scale");
309 #ifdef EGBLAS_HAS_SBATCH_K_MINUS_SCALE4 310 static constexpr
bool has_sbatch_k_minus_scale4 =
true;
312 static constexpr
bool has_sbatch_k_minus_scale4 =
false;
325 [[maybe_unused]]
size_t k,
326 [[maybe_unused]]
size_t m,
327 [[maybe_unused]]
size_t n,
328 [[maybe_unused]]
const float* A,
329 [[maybe_unused]]
const float* gamma,
330 [[maybe_unused]]
const float* beta,
331 [[maybe_unused]]
float* B) {
332 #ifdef EGBLAS_HAS_SBATCH_K_MINUS_SCALE4 334 egblas_sbatch_k_minus_scale4(b, k, m, n, A, gamma, beta, B);
336 cpp_unreachable(
"Invalid call to egblas::batch_k_minus_scale");
340 #ifdef EGBLAS_HAS_DBATCH_K_MINUS_SCALE4 341 static constexpr
bool has_dbatch_k_minus_scale4 =
true;
343 static constexpr
bool has_dbatch_k_minus_scale4 =
false;
356 [[maybe_unused]]
size_t k,
357 [[maybe_unused]]
size_t m,
358 [[maybe_unused]]
size_t n,
359 [[maybe_unused]]
const double* A,
360 [[maybe_unused]]
const double* gamma,
361 [[maybe_unused]]
const double* beta,
362 [[maybe_unused]]
double* B) {
363 #ifdef EGBLAS_HAS_DBATCH_K_MINUS_SCALE4 365 egblas_dbatch_k_minus_scale4(b, k, m, n, A, gamma, beta, B);
367 cpp_unreachable(
"Invalid call to egblas::batch_k_minus_scale");
batch_k_scale_expr< detail::build_type< A >, detail::build_type< B > > batch_k_scale(const A &a, const B &b)
Returns a batch_k_scale expression built from the two given expressions.
Definition: batch_k_scale_expr.hpp:1495
batch_k_scale_plus_expr< detail::build_type< A >, detail::build_type< B >, detail::build_type< C > > batch_k_scale_plus(const A &a, const B &b, const C &c)
Returns a batch_k_scale_plus expression built from the three given expressions.
Definition: batch_k_scale_plus_expr.hpp:1575
batch_k_minus_scale_expr< detail::build_type< A >, detail::build_type< B >, detail::build_type< C > > batch_k_minus_scale(const A &a, const B &b, const C &c)
Returns a batch_k_minus_scale expression built from the three given expressions.
Definition: batch_k_minus_scale_expr.hpp:1575
void inc_counter([[maybe_unused]] const char *name)
Increase the given counter.
Definition: counters.hpp:25