Expression Templates Library (ETL)
avx512_vectorization.hpp
Go to the documentation of this file.
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
//TODO Implementation of AVX-512 complex division (complex multiplication is implemented below)
14 
15 #pragma once
16 
17 #ifdef __AVX512F__
18 
19 #include <immintrin.h>
20 
21 #include "etl/inline.hpp"
22 
23 #ifdef VECT_DEBUG
24 #include <iostream>
25 #endif
26 
// Shorthand "inline + return type" specifiers for the AVX-512 helpers.
// ETL_INLINE_VEC_* wrap ETL_STATIC_INLINE and ETL_OUT_VEC_* wrap ETL_OUT_INLINE;
// both wrapped macros are presumably provided by "etl/inline.hpp" (confirm).
#define ETL_INLINE_VEC_VOID ETL_STATIC_INLINE(void)
#define ETL_INLINE_VEC_512 ETL_STATIC_INLINE(__m512)
#define ETL_INLINE_VEC_512D ETL_STATIC_INLINE(__m512d)
// Single-precision out-of-class variant (the name and its expansion were
// previously fused into one bogus function-like macro).
#define ETL_OUT_VEC_512 ETL_OUT_INLINE(__m512)
#define ETL_OUT_VEC_512D ETL_OUT_INLINE(__m512d)
32 
33 namespace etl {
34 
38 using avx_512_simd_float = simd_pack<vector_mode_t::AVX512, float, __m512>;
39 
43 using avx_512_simd_double = simd_pack<vector_mode_t::AVX512, double, __m512d>;
44 
48 template <typename T>
49 using avx_512_simd_complex_float = simd_pack<vector_mode_t::AVX512, T, __m512>;
50 
54 template <typename T>
55 using avx_512_simd_complex_double = simd_pack<vector_mode_t::AVX512, T, __m512d>;
56 
60 using avx_512_simd_byte = simd_pack<vector_mode_t::AVX512, int8_t, __m512i>;
61 
65 using avx_512_simd_short = simd_pack<vector_mode_t::AVX512, int16_t, __m512i>;
66 
70 using avx_512_simd_int = simd_pack<vector_mode_t::AVX512, int32_t, __m512i>;
71 
75 using avx_512_simd_long = simd_pack<vector_mode_t::AVX512, int64_t, __m512i>;
76 
/*!
 * \brief Fallback AVX-512 traits, used for every type that has no dedicated
 * specialization.
 *
 * Such a type is reported as not vectorizable: it is handled one element at a
 * time, keeps its natural alignment and is its own "intrinsic" type.
 *
 * \tparam T The element type
 */
template <typename T>
struct avx512_intrinsic_traits {
    using intrinsic_type = T; ///< No SIMD pack: the scalar type itself

    static constexpr bool vectorizable = false;      ///< T cannot be vectorized with AVX-512
    static constexpr size_t size       = 1;          ///< One element per "vector"
    static constexpr size_t alignment  = alignof(T); ///< Natural alignment of T
};
88 
92 template <>
93 struct avx512_intrinsic_traits<float> {
94  static constexpr bool vectorizable = true;
95  static constexpr size_t size = 16;
96  static constexpr size_t alignment = 64;
97 
98  using intrinsic_type = avx_512_simd_float;
99 };
100 
104 template <>
105 struct avx512_intrinsic_traits<double> {
106  static constexpr bool vectorizable = true;
107  static constexpr size_t size = 8;
108  static constexpr size_t alignment = 64;
109 
110  using intrinsic_type = avx_512_simd_double;
111 };
112 
116 template <>
117 struct avx512_intrinsic_traits<std::complex<float>> {
118  static constexpr bool vectorizable = true;
119  static constexpr size_t size = 8;
120  static constexpr size_t alignment = 64;
121 
122  using intrinsic_type = avx_512_simd_complex_float<std::complex<float>>;
123 };
124 
128 template <>
129 struct avx512_intrinsic_traits<std::complex<double>> {
130  static constexpr bool vectorizable = true;
131  static constexpr size_t size = 4;
132  static constexpr size_t alignment = 64;
133 
134  using intrinsic_type = avx_512_simd_complex_double<std::complex<double>>;
135 };
136 
140 template <>
141 struct avx512_intrinsic_traits<etl::complex<float>> {
142  static constexpr bool vectorizable = true;
143  static constexpr size_t size = 8;
144  static constexpr size_t alignment = 64;
145 
146  using intrinsic_type = avx_512_simd_complex_float<etl::complex<float>>;
147 };
148 
152 template <>
153 struct avx512_intrinsic_traits<etl::complex<double>> {
154  static constexpr bool vectorizable = true;
155  static constexpr size_t size = 4;
156  static constexpr size_t alignment = 64;
157 
158  using intrinsic_type = avx_512_simd_complex_double<etl::complex<double>>;
159 };
160 
164 template <>
165 struct avx512_intrinsic_traits<int8_t> {
166  static constexpr bool vectorizable = true;
167  static constexpr size_t size = 64;
168  static constexpr size_t alignment = 64;
169 
170  using intrinsic_type = avx_512_simd_byte;
171 };
172 
176 template <>
177 struct avx512_intrinsic_traits<int16_t> {
178  static constexpr bool vectorizable = true;
179  static constexpr size_t size = 32;
180  static constexpr size_t alignment = 64;
181 
182  using intrinsic_type = avx_512_simd_short;
183 };
184 
188 template <>
189 struct avx512_intrinsic_traits<int32_t> {
190  static constexpr bool vectorizable = true;
191  static constexpr size_t size = 16;
192  static constexpr size_t alignment = 64;
193 
194  using intrinsic_type = avx_512_simd_int;
195 };
196 
200 template <>
201 struct avx512_intrinsic_traits<int64_t> {
202  static constexpr bool vectorizable = true;
203  static constexpr size_t size = 8;
204  static constexpr size_t alignment = 64;
205 
206  using intrinsic_type = avx_512_simd_long;
207 };
208 
212 struct avx512_vec {
216  template <typename T>
217  using traits = avx512_intrinsic_traits<T>;
218 
222  template <typename T>
223  using vec_type = typename traits<T>::intrinsic_type;
224 
225 #ifdef VEC_DEBUG
226 
230  template <typename T>
231  static std::string debug_d(T value) {
232  union test {
233  __m512d vec;
234  double array[8];
235  test(__m512d vec) : vec(vec) {}
236  };
237 
238  test u_value = value;
239  std::cout << "[" << u_value.array[0] << "," << u_value.array[1] << "," << u_value.array[2] << "," << u_value.array[3] << "," << u_value.array[4] << ","
240  << u_value.array[5] << "," << u_value.array[6] << "," << u_value.array[7] << "]" << std::endl;
241  }
242 
246  template <typename T>
247  static std::string debug_s(T value) {
248  union test {
249  __m512 vec;
250  float array[16];
251  test(__m512 vec) : vec(vec) {}
252  };
253 
254  test u_value = value;
255  std::cout << "[" << u_value.array[0] << "," << u_value.array[1] << "," << u_value.array[2] << "," << u_value.array[3] << "," << u_value.array[4] << ","
256  << u_value.array[5] << "," << u_value.array[6] << "," << u_value.array[7] << "," << u_value.array[8] << "," << u_value.array[9] << ","
257  << u_value.array[10] << "," << u_value.array[11] << "," << u_value.array[12] << "," << u_value.array[13] << "," << u_value.array[14] << ","
258  << u_value.array[15] << "]" << std::endl;
259  }
260 
261 #else
262 
266  template <typename T>
267  static std::string debug_d(T) {
268  return "";
269  }
270 
274  template <typename T>
275  static std::string debug_s(T) {
276  return "";
277  }
278 
279 #endif
280 
285  ETL_INLINE_VEC_VOID storeu(float* memory, avx_512_simd_float value) {
286  _mm512_storeu_ps(memory, value.value);
287  }
288 
293  ETL_INLINE_VEC_VOID storeu(double* memory, avx_512_simd_double value) {
294  _mm512_storeu_pd(memory, value.value);
295  }
296 
301  ETL_INLINE_VEC_VOID storeu(std::complex<float>* memory, avx_512_simd_complex_float<std::complex<float>> value) {
302  _mm512_storeu_ps(reinterpret_cast<float*>(memory), value.value);
303  }
304 
309  ETL_INLINE_VEC_VOID storeu(std::complex<double>* memory, avx_512_simd_complex_double<std::complex<double>> value) {
310  _mm512_storeu_pd(reinterpret_cast<double*>(memory), value.value);
311  }
312 
317  ETL_INLINE_VEC_VOID storeu(etl::complex<float>* memory, avx_512_simd_complex_float<etl::complex<float>> value) {
318  _mm512_storeu_ps(reinterpret_cast<float*>(memory), value.value);
319  }
320 
325  ETL_INLINE_VEC_VOID storeu(etl::complex<double>* memory, avx_512_simd_complex_double<etl::complex<double>> value) {
326  _mm512_storeu_pd(reinterpret_cast<double*>(memory), value.value);
327  }
328 
333  ETL_STATIC_INLINE(void) store(int8_t* memory, avx_512_simd_byte value) {
334  _mm512_store_si512(reinterpret_cast<__m512i*>(memory), value.value);
335  }
336 
341  ETL_STATIC_INLINE(void) store(int16_t* memory, avx_512_simd_short value) {
342  _mm512_store_si512(reinterpret_cast<__m512i*>(memory), value.value);
343  }
344 
349  ETL_STATIC_INLINE(void) store(int32_t* memory, avx_512_simd_int value) {
350  _mm512_store_si512(reinterpret_cast<__m512i*>(memory), value.value);
351  }
352 
357  ETL_STATIC_INLINE(void) store(int64_t* memory, avx_512_simd_long value) {
358  _mm512_store_si512(reinterpret_cast<__m512i*>(memory), value.value);
359  }
360 
365  ETL_INLINE_VEC_VOID store(float* memory, avx_512_simd_float value) {
366  _mm512_store_ps(memory, value.value);
367  }
368 
373  ETL_INLINE_VEC_VOID store(double* memory, avx_512_simd_double value) {
374  _mm512_store_pd(memory, value.value);
375  }
376 
381  ETL_INLINE_VEC_VOID store(std::complex<float>* memory, avx_512_simd_complex_float<std::complex<float>> value) {
382  _mm512_store_ps(reinterpret_cast<float*>(memory), value.value);
383  }
384 
389  ETL_INLINE_VEC_VOID store(std::complex<double>* memory, avx_512_simd_complex_double<std::complex<double>> value) {
390  _mm512_store_pd(reinterpret_cast<double*>(memory), value.value);
391  }
392 
397  ETL_INLINE_VEC_VOID store(etl::complex<float>* memory, avx_512_simd_complex_float<etl::complex<float>> value) {
398  _mm512_store_ps(reinterpret_cast<float*>(memory), value.value);
399  }
400 
405  ETL_INLINE_VEC_VOID store(etl::complex<double>* memory, avx_512_simd_complex_double<etl::complex<double>> value) {
406  _mm512_store_pd(reinterpret_cast<double*>(memory), value.value);
407  }
408 
413  ETL_STATIC_INLINE(void) stream(int8_t* memory, avx_512_simd_byte value) {
414  _mm512_stream_si512(reinterpret_cast<__m512i*>(memory), value.value);
415  }
416 
421  ETL_STATIC_INLINE(void) stream(int16_t* memory, avx_512_simd_short value) {
422  _mm512_stream_si512(reinterpret_cast<__m512i*>(memory), value.value);
423  }
424 
429  ETL_STATIC_INLINE(void) stream(int32_t* memory, avx_512_simd_int value) {
430  _mm512_stream_si512(reinterpret_cast<__m512i*>(memory), value.value);
431  }
432 
437  ETL_STATIC_INLINE(void) stream(int64_t* memory, avx_512_simd_long value) {
438  _mm512_stream_si512(reinterpret_cast<__m512i*>(memory), value.value);
439  }
440 
445  ETL_INLINE_VEC_VOID stream(float* memory, avx_512_simd_float value) {
446  _mm512_stream_ps(memory, value.value);
447  }
448 
453  ETL_INLINE_VEC_VOID stream(double* memory, avx_512_simd_double value) {
454  _mm512_stream_pd(memory, value.value);
455  }
456 
461  ETL_INLINE_VEC_VOID stream(std::complex<float>* memory, avx_512_simd_complex_float<std::complex<float>> value) {
462  _mm512_stream_ps(reinterpret_cast<float*>(memory), value.value);
463  }
464 
469  ETL_INLINE_VEC_VOID stream(std::complex<double>* memory, avx_512_simd_complex_double<std::complex<double>> value) {
470  _mm512_stream_pd(reinterpret_cast<double*>(memory), value.value);
471  }
472 
477  ETL_INLINE_VEC_VOID stream(etl::complex<float>* memory, avx_512_simd_complex_float<etl::complex<float>> value) {
478  _mm512_stream_ps(reinterpret_cast<float*>(memory), value.value);
479  }
480 
485  ETL_INLINE_VEC_VOID stream(etl::complex<double>* memory, avx_512_simd_complex_double<etl::complex<double>> value) {
486  _mm512_stream_pd(reinterpret_cast<double*>(memory), value.value);
487  }
488 
492  ETL_STATIC_INLINE(avx_512_simd_byte) load(const int8_t* memory) {
493  return _mm512_load_si512(reinterpret_cast<const __m512i*>(memory));
494  }
495 
499  ETL_STATIC_INLINE(avx_512_simd_short) load(const int16_t* memory) {
500  return _mm512_load_si512(reinterpret_cast<const __m512i*>(memory));
501  }
502 
506  ETL_STATIC_INLINE(avx_512_simd_int) load(const int32_t* memory) {
507  return _mm512_load_si512(reinterpret_cast<const __m512i*>(memory));
508  }
509 
513  ETL_STATIC_INLINE(avx_512_simd_long) load(const int64_t* memory) {
514  return _mm512_load_si512(reinterpret_cast<const __m512i*>(memory));
515  }
516 
520  ETL_STATIC_INLINE(avx_512_simd_float) load(const float* memory) {
521  return _mm512_load_ps(memory);
522  }
523 
527  ETL_STATIC_INLINE(avx_512_simd_double) load(const double* memory) {
528  return _mm512_load_pd(memory);
529  }
530 
534  ETL_STATIC_INLINE(avx_512_simd_complex_float<std::complex<float>>) load(const std::complex<float>* memory) {
535  return _mm512_load_ps(reinterpret_cast<const float*>(memory));
536  }
537 
541  ETL_STATIC_INLINE(avx_512_simd_complex_double<std::complex<double>>) load(const std::complex<double>* memory) {
542  return _mm512_load_pd(reinterpret_cast<const double*>(memory));
543  }
544 
548  ETL_STATIC_INLINE(avx_512_simd_complex_float<etl::complex<float>>) load(const etl::complex<float>* memory) {
549  return _mm512_load_ps(reinterpret_cast<const float*>(memory));
550  }
551 
555  ETL_STATIC_INLINE(avx_512_simd_complex_double<etl::complex<double>>) load(const etl::complex<double>* memory) {
556  return _mm512_load_pd(reinterpret_cast<const double*>(memory));
557  }
558 
562  ETL_STATIC_INLINE(avx_512_simd_byte) loadu(const int8_t* memory) {
563  return _mm512_loadu_si512(reinterpret_cast<const __m512i*>(memory));
564  }
565 
569  ETL_STATIC_INLINE(avx_512_simd_short) loadu(const int16_t* memory) {
570  return _mm512_loadu_si512(reinterpret_cast<const __m512i*>(memory));
571  }
572 
576  ETL_STATIC_INLINE(avx_512_simd_int) loadu(const int32_t* memory) {
577  return _mm512_loadu_si512(reinterpret_cast<const __m512i*>(memory));
578  }
579 
583  ETL_STATIC_INLINE(avx_512_simd_long) loadu(const int64_t* memory) {
584  return _mm512_loadu_si512(reinterpret_cast<const __m512i*>(memory));
585  }
586 
590  ETL_STATIC_INLINE(avx_512_simd_float) loadu(const float* memory) {
591  return _mm512_loadu_ps(memory);
592  }
593 
597  ETL_STATIC_INLINE(avx_512_simd_double) loadu(const double* memory) {
598  return _mm512_loadu_pd(memory);
599  }
600 
604  ETL_STATIC_INLINE(avx_512_simd_complex_float<std::complex<float>>) loadu(const std::complex<float>* memory) {
605  return _mm512_loadu_ps(reinterpret_cast<const float*>(memory));
606  }
607 
611  ETL_STATIC_INLINE(avx_512_simd_complex_double<std::complex<double>>) loadu(const std::complex<double>* memory) {
612  return _mm512_loadu_pd(reinterpret_cast<const double*>(memory));
613  }
614 
618  ETL_STATIC_INLINE(avx_512_simd_complex_float<etl::complex<float>>) loadu(const etl::complex<float>* memory) {
619  return _mm512_loadu_ps(reinterpret_cast<const float*>(memory));
620  }
621 
625  ETL_STATIC_INLINE(avx_512_simd_complex_double<etl::complex<double>>) loadu(const etl::complex<double>* memory) {
626  return _mm512_loadu_pd(reinterpret_cast<const double*>(memory));
627  }
628 
632  ETL_STATIC_INLINE(avx_512_simd_byte) set(int8_t value) {
633  return _mm512_set1_epi8(value);
634  }
635 
639  ETL_STATIC_INLINE(avx_512_simd_short) set(int16_t value) {
640  return _mm512_set1_epi16(value);
641  }
642 
646  ETL_STATIC_INLINE(avx_512_simd_int) set(int32_t value) {
647  return _mm512_set1_epi32(value);
648  }
649 
653  ETL_STATIC_INLINE(avx_512_simd_long) set(int64_t value) {
654  return _mm512_set1_epi64(value);
655  }
656 
660  ETL_STATIC_INLINE(avx_512_simd_double) set(double value) {
661  return _mm512_set1_pd(value);
662  }
663 
667  ETL_STATIC_INLINE(avx_512_simd_float) set(float value) {
668  return _mm512_set1_ps(value);
669  }
670 
674  ETL_STATIC_INLINE(avx_512_simd_complex_float<std::complex<float>>) set(std::complex<float> value) {
675  std::complex<float> tmp[]{value, value, value, value, value, value, value, value};
676  return loadu(tmp);
677  }
678 
682  ETL_STATIC_INLINE(avx_512_simd_complex_double<std::complex<double>>) set(std::complex<double> value) {
683  std::complex<double> tmp[]{value, value, value, value};
684  return loadu(tmp);
685  }
686 
690  ETL_STATIC_INLINE(avx_512_simd_complex_float<etl::complex<float>>) set(etl::complex<float> value) {
691  etl::complex<float> tmp[]{value, value, value, value, value, value, value, value};
692  return loadu(tmp);
693  }
694 
698  ETL_STATIC_INLINE(avx_512_simd_complex_double<etl::complex<double>>) set(etl::complex<double> value) {
699  etl::complex<double> tmp[]{value, value, value, value};
700  return loadu(tmp);
701  }
702 
706  template <typename T>
707  ETL_TMP_INLINE(typename avx512_intrinsic_traits<T>::intrinsic_type)
708  zero();
709 
713  ETL_STATIC_INLINE(avx_512_simd_float) round_up(avx_512_simd_float x) {
714  return _mm512_roundscale_round_ps(x.value, _MM_FROUND_TO_POS_INF, _MM_FROUND_NO_EXC);
715  }
716 
720  ETL_STATIC_INLINE(avx_512_simd_double) round_up(avx_512_simd_double x) {
721  return _mm512_roundscale_round_pd(x.value, _MM_FROUND_TO_POS_INF, _MM_FROUND_NO_EXC);
722  }
723 
727  ETL_STATIC_INLINE(avx_512_simd_byte) add(avx_512_simd_byte lhs, avx_512_simd_byte rhs) {
728  return _mm512_add_epi8(lhs.value, rhs.value);
729  }
730 
734  ETL_STATIC_INLINE(avx_512_simd_short) add(avx_512_simd_short lhs, avx_512_simd_short rhs) {
735  return _mm512_add_epi16(lhs.value, rhs.value);
736  }
737 
741  ETL_STATIC_INLINE(avx_512_simd_int) add(avx_512_simd_int lhs, avx_512_simd_int rhs) {
742  return _mm512_add_epi32(lhs.value, rhs.value);
743  }
744 
748  ETL_STATIC_INLINE(avx_512_simd_long) add(avx_512_simd_long lhs, avx_512_simd_long rhs) {
749  return _mm512_add_epi64(lhs.value, rhs.value);
750  }
751 
755  ETL_STATIC_INLINE(avx_512_simd_float) add(avx_512_simd_float lhs, avx_512_simd_float rhs) {
756  return _mm512_add_ps(lhs.value, rhs.value);
757  }
758 
762  ETL_STATIC_INLINE(avx_512_simd_double) add(avx_512_simd_double lhs, avx_512_simd_double rhs) {
763  return _mm512_add_pd(lhs.value, rhs.value);
764  }
765 
769  template <typename T>
770  ETL_STATIC_INLINE(avx_512_simd_complex_float<T>)
771  add(avx_512_simd_complex_float<T> lhs, avx_512_simd_complex_float<T> rhs) {
772  return _mm512_add_ps(lhs.value, rhs.value);
773  }
774 
778  template <typename T>
779  ETL_STATIC_INLINE(avx_512_simd_complex_double<T>)
780  add(avx_512_simd_complex_double<T> lhs, avx_512_simd_complex_double<T> rhs) {
781  return _mm512_add_pd(lhs.value, rhs.value);
782  }
783 
787  ETL_STATIC_INLINE(avx_512_simd_byte) sub(avx_512_simd_byte lhs, avx_512_simd_byte rhs) {
788  return _mm512_sub_epi8(lhs.value, rhs.value);
789  }
790 
794  ETL_STATIC_INLINE(avx_512_simd_short) sub(avx_512_simd_short lhs, avx_512_simd_short rhs) {
795  return _mm512_sub_epi16(lhs.value, rhs.value);
796  }
797 
801  ETL_STATIC_INLINE(avx_512_simd_int) sub(avx_512_simd_int lhs, avx_512_simd_int rhs) {
802  return _mm512_sub_epi32(lhs.value, rhs.value);
803  }
804 
808  ETL_STATIC_INLINE(avx_512_simd_long) sub(avx_512_simd_long lhs, avx_512_simd_long rhs) {
809  return _mm512_sub_epi64(lhs.value, rhs.value);
810  }
811 
815  ETL_STATIC_INLINE(avx_512_simd_float) sub(avx_512_simd_float lhs, avx_512_simd_float rhs) {
816  return _mm512_sub_ps(lhs.value, rhs.value);
817  }
818 
822  ETL_STATIC_INLINE(avx_512_simd_double) sub(avx_512_simd_double lhs, avx_512_simd_double rhs) {
823  return _mm512_sub_pd(lhs.value, rhs.value);
824  }
825 
829  template <typename T>
830  ETL_STATIC_INLINE(avx_512_simd_complex_float<T>)
831  sub(avx_512_simd_complex_float<T> lhs, avx_512_simd_complex_float<T> rhs) {
832  return _mm512_sub_ps(lhs.value, rhs.value);
833  }
834 
838  template <typename T>
839  ETL_STATIC_INLINE(avx_512_simd_complex_double<T>)
840  sub(avx_512_simd_complex_double<T> lhs, avx_512_simd_complex_double<T> rhs) {
841  return _mm512_sub_pd(lhs.value, rhs.value);
842  }
843 
848  ETL_STATIC_INLINE(avx_512_simd_float) sqrt(avx_512_simd_float x) {
849  return _mm512_sqrt_ps(x.value);
850  }
851 
856  ETL_STATIC_INLINE(avx_512_simd_double) sqrt(avx_512_simd_double x) {
857  return _mm512_sqrt_pd(x.value);
858  }
859 
864  ETL_STATIC_INLINE(avx_512_simd_float) minus(avx_512_simd_float x) {
865  return _mm512_xor_ps(x.value, _mm512_set1_ps(-0.f));
866  }
867 
872  ETL_STATIC_INLINE(avx_512_simd_double) minus(avx_512_simd_double x) {
873  return _mm512_xor_pd(x.value, _mm512_set1_pd(-0.));
874  }
875 
879  ETL_STATIC_INLINE(avx_512_simd_byte) mul(avx_512_simd_byte lhs, avx_512_simd_byte rhs) {
880  // Split in multiple vectors (odd and even)
881  __m512i lhs_odd = _mm512_srli_epi16(lhs.value, 8);
882  __m512i rhs_odd = _mm512_srli_epi16(rhs.value, 8);
883  // Do the multiplication on each side
884  __m512i mul_even = _mm512_mullo_epi16(lhs.value, rhs.value);
885  __m512i mul_odd = _mm512_mullo_epi16(lhs_odd, rhs_odd);
886  // Combine again
887  __m512i temp = _mm512_slli_epi16(mul_odd, 8);
888  return _mm512_mask_mov_epi8(mul_even, 0xAAAAAAAAAAAAAAAA, temp);
889  }
890 
894  ETL_STATIC_INLINE(avx_512_simd_short) mul(avx_512_simd_short lhs, avx_512_simd_short rhs) {
895  return _mm512_mullo_epi16(lhs.value, rhs.value);
896  }
897 
901  ETL_STATIC_INLINE(avx_512_simd_int) mul(avx_512_simd_int lhs, avx_512_simd_int rhs) {
902  return _mm512_mullo_epi32(lhs.value, rhs.value);
903  }
904 
908  ETL_STATIC_INLINE(avx_512_simd_long) mul(avx_512_simd_long lhs, avx_512_simd_long rhs) {
909  return _mm512_mullo_epi64(lhs.value, rhs.value);
910  }
911 
915  ETL_STATIC_INLINE(avx_512_simd_float) mul(avx_512_simd_float lhs, avx_512_simd_float rhs) {
916  return _mm512_mul_ps(lhs.value, rhs.value);
917  }
918 
922  ETL_STATIC_INLINE(avx_512_simd_double) mul(avx_512_simd_double lhs, avx_512_simd_double rhs) {
923  return _mm512_mul_pd(lhs.value, rhs.value);
924  }
925 
929  template <typename T>
930  ETL_STATIC_INLINE(avx_512_simd_complex_float<T>)
931  mul(avx_512_simd_complex_float<T> lhs, avx_512_simd_complex_float<T> rhs) {
932  //lhs = [x1.real, x1.img, x2.real, x2.img, ...]
933  //rhs = [y1.real, y1.img, y2.real, y2.img, ...]
934 
935  //ymm1 = [y1.real, y1.real, y2.real, y2.real, ...]
936  __m512 ymm1 = _mm512_moveldup_ps(rhs.value);
937 
938  //ymm2 = [x1.img, x1.real, x2.img, x2.real]
939  __m512 ymm2 = _mm512_permute_ps(lhs.value, 0b10110001);
940 
941  //ymm3 = [y1.imag, y1.imag, y2.imag, y2.imag]
942  __m512 ymm3 = _mm512_movehdup_ps(rhs.value);
943 
944  //ymm4 = ymm2 * ymm3
945  __m512 ymm4 = _mm512_mul_ps(ymm2, ymm3);
946 
947  //result = [(lhs * ymm1) -+ ymm4];
948 
949  return _mm512_fmaddsub_ps(lhs.value, ymm1, ymm4);
950  }
951 
955  template <typename T>
956  ETL_STATIC_INLINE(avx_512_simd_complex_double<T>)
957  mul(avx_512_simd_complex_double<T> lhs, avx_512_simd_complex_double<T> rhs) {
958  __m512d ymm1 = _mm512_shuffle_pd(rhs.value, rhs.value, 0x55);
959  __m512d ymm2 = _mm512_shuffle_pd(lhs.value, lhs.value, 0xFF);
960  __m512d ymm3 = _mm512_shuffle_pd(lhs.value, lhs.value, 0);
961  __m512d ymm4 = _mm512_mul_pd(ymm2, ymm1);
962  return _mm512_fmaddsub_pd(ymm3, rhs.value, ymm4);
963  }
964 
968  ETL_STATIC_INLINE(avx_512_simd_float) fmadd(avx_512_simd_float a, avx_512_simd_float b, avx_512_simd_float c) {
969  return _mm512_fmadd_ps(a.value, b.value, c.value);
970  }
971 
975  ETL_STATIC_INLINE(avx_512_simd_double) fmadd(avx_512_simd_double a, avx_512_simd_double b, avx_512_simd_double c) {
976  return _mm512_fmadd_pd(a.value, b.value, c.value);
977  }
978 
982  template <typename T>
983  ETL_STATIC_INLINE(avx_512_simd_complex_float<T>)
984  fmadd(avx_512_simd_complex_float<T> a, avx_512_simd_complex_float<T> b, avx_512_simd_complex_float<T> c) {
985  return add(mul(a, b), c);
986  }
987 
991  template <typename T>
992  ETL_STATIC_INLINE(avx_512_simd_complex_double<T>)
993  fmadd(avx_512_simd_complex_double<T> a, avx_512_simd_complex_double<T> b, avx_512_simd_complex_double<T> c) {
994  return add(mul(a, b), c);
995  }
996 
1000  ETL_STATIC_INLINE(avx_512_simd_float) div(avx_512_simd_float lhs, avx_512_simd_float rhs) {
1001  return _mm512_div_ps(lhs.value, rhs.value);
1002  }
1003 
1007  ETL_STATIC_INLINE(avx_512_simd_double) div(avx_512_simd_double lhs, avx_512_simd_double rhs) {
1008  return _mm512_div_pd(lhs.value, rhs.value);
1009  }
1010 
1011  //Min
1012 
1016  ETL_STATIC_INLINE(avx_512_simd_double) min(avx_512_simd_double lhs, avx_512_simd_double rhs) {
1017  return _mm512_min_pd(lhs.value, rhs.value);
1018  }
1019 
1023  ETL_STATIC_INLINE(avx_512_simd_float) min(avx_512_simd_float lhs, avx_512_simd_float rhs) {
1024  return _mm512_min_ps(lhs.value, rhs.value);
1025  }
1026 
1027  //Max
1028 
1032  ETL_STATIC_INLINE(avx_512_simd_double) max(avx_512_simd_double lhs, avx_512_simd_double rhs) {
1033  return _mm512_max_pd(lhs.value, rhs.value);
1034  }
1035 
1039  ETL_STATIC_INLINE(avx_512_simd_float) max(avx_512_simd_float lhs, avx_512_simd_float rhs) {
1040  return _mm512_max_ps(lhs.value, rhs.value);
1041  }
1042 
1043  // Horizontal sum reductions
1044  // TODO "Vectorize" these reductions
1045 
1051  ETL_STATIC_INLINE(int8_t) hadd(avx_512_simd_byte in) {
1052  int8_t acc = 0;
1053  for (size_t i = 0; i < 64; ++i) {
1054  acc += in[i];
1055  }
1056  return acc;
1057  }
1058 
1064  ETL_STATIC_INLINE(int16_t) hadd(avx_512_simd_short in) {
1065  int16_t acc = 0;
1066  for (size_t i = 0; i < 32; ++i) {
1067  acc += in[i];
1068  }
1069  return acc;
1070  }
1071 
1077  ETL_STATIC_INLINE(int32_t) hadd(avx_512_simd_int in) {
1078  int32_t acc = 0;
1079  for (size_t i = 0; i < 16; ++i) {
1080  acc += in[i];
1081  }
1082  return acc;
1083  }
1084 
1090  ETL_STATIC_INLINE(int64_t) hadd(avx_512_simd_long in) {
1091  int64_t acc = 0;
1092  for (size_t i = 0; i < 8; ++i) {
1093  acc += in[i];
1094  }
1095  return acc;
1096  }
1097 
1103  ETL_STATIC_INLINE(float) hadd(avx_512_simd_float in) {
1104  return in[0] + in[1] + in[2] + in[3] + in[4] + in[5] + in[6] + in[7] + in[8] + in[9] + in[10] + in[11] + in[12] + in[13] + in[14] + in[15];
1105  }
1106 
1112  ETL_STATIC_INLINE(double) hadd(avx_512_simd_double in) {
1113  return in[0] + in[1] + in[2] + in[3] + in[4] + in[5] + in[6] + in[7];
1114  }
1115 
1121  template <typename T>
1122  ETL_STATIC_INLINE(T)
1123  hadd(avx_512_simd_complex_float<T> in) {
1124  return in[0] + in[1] + in[2] + in[3] + in[4] + in[5] + in[6] + in[7];
1125  }
1126 
1132  template <typename T>
1133  ETL_STATIC_INLINE(T)
1134  hadd(avx_512_simd_complex_double<T> in) {
1135  return in[0] + in[1] + in[2] + in[3];
1136  }
1137 };
1138 
1142 template <>
1143 ETL_OUT_INLINE(avx_512_simd_byte)
1144 avx512_vec::zero<int8_t>() {
1145  return _mm512_setzero_si512();
1146 }
1147 
1151 template <>
1152 ETL_OUT_INLINE(avx_512_simd_short)
1153 avx512_vec::zero<int16_t>() {
1154  return _mm512_setzero_si512();
1155 }
1156 
1160 template <>
1161 ETL_OUT_INLINE(avx_512_simd_int)
1162 avx512_vec::zero<int32_t>() {
1163  return _mm512_setzero_si512();
1164 }
1165 
1169 template <>
1170 ETL_OUT_INLINE(avx_512_simd_long)
1171 avx512_vec::zero<int64_t>() {
1172  return _mm512_setzero_si512();
1173 }
1174 
1178 template <>
1179 ETL_OUT_INLINE(avx_512_simd_float)
1180 avx512_vec::zero<float>() {
1181  return _mm512_setzero_ps();
1182 }
1183 
1187 template <>
1188 ETL_OUT_INLINE(avx_512_simd_double)
1189 avx512_vec::zero<double>() {
1190  return _mm512_setzero_pd();
1191 }
1192 
1196 template <>
1197 ETL_OUT_INLINE(avx_512_simd_complex_float<etl::complex<float>>)
1198 avx512_vec::zero<etl::complex<float>>() {
1199  return _mm512_setzero_ps();
1200 }
1201 
1205 template <>
1206 ETL_OUT_INLINE(avx_512_simd_complex_double<etl::complex<double>>)
1207 avx512_vec::zero<etl::complex<double>>() {
1208  return _mm512_setzero_pd();
1209 }
1210 
1214 template <>
1215 ETL_OUT_INLINE(avx_512_simd_complex_float<std::complex<float>>)
1216 avx512_vec::zero<std::complex<float>>() {
1217  return _mm512_setzero_ps();
1218 }
1219 
1223 template <>
1224 ETL_OUT_INLINE(avx_512_simd_complex_double<std::complex<double>>)
1225 avx512_vec::zero<std::complex<double>>() {
1226  return _mm512_setzero_pd();
1227 }
1228 
1229 } //end of namespace etl
1230 
1231 #endif //__AVX512F__
auto max(L &&lhs, R &&rhs)
Create an expression with the max value of lhs or rhs.
Definition: expression_builder.hpp:65
Complex number implementation.
Definition: complex.hpp:31
auto mul(A &&a, B &&b)
Multiply two matrices together.
Definition: gemm_expr.hpp:442
void minus([[maybe_unused]] size_t n, [[maybe_unused]] float alpha, [[maybe_unused]] float *A, [[maybe_unused]] size_t lda, [[maybe_unused]] float *B, [[maybe_unused]] size_t ldb)
Wrappers for single-precision egblas minus operation.
Definition: minus.hpp:43
auto sqrt(E &&value) -> detail::unary_helper< E, sqrt_unary_op >
Apply square root on each value of the given expression.
Definition: function_expression_builder.hpp:24
typename V::template vec_type< value_type > vec_type
The vectorization type for V.
Definition: dyn_matrix_view.hpp:43
auto load(size_t x) const noexcept
Load several elements of the expression at once.
Definition: dyn_matrix_view.hpp:143
Root namespace for the ETL library.
Definition: adapter.hpp:15
void store(vec_type< V > in, size_t i) noexcept
Store several elements in the matrix at once.
Definition: dyn_matrix_view.hpp:176
void stream(vec_type< V > in, size_t i) noexcept
Store several elements in the matrix at once, using non-temporal store.
Definition: dyn_matrix_view.hpp:165
void storeu(vec_type< V > in, size_t i) noexcept
Store several elements in the matrix at once.
Definition: dyn_matrix_view.hpp:187
auto loadu(size_t x) const noexcept
Load several elements of the expression at once.
Definition: dyn_matrix_view.hpp:154
auto min(L &&lhs, R &&rhs)
Create an expression with the min value of lhs or rhs.
Definition: expression_builder.hpp:77
Inlining macros.