#include <immintrin.h>
#include <tbb/parallel_for.h>
#include <tbb/blocked_range.h>

// AVX kernels that apply the 2x2 single-qubit gate matrix u3_1qbit to pairs of rows
// of the complex matrix input, optionally conditioned on control_qbit.

// Broadcast the four kernel elements: each QGD_Complex16 is a 128-bit [real, imag]
// pair, duplicated into both 128-bit lanes of a 256-bit register.
__m128d* u3_1qubit_tmp = (__m128d*) & u3_1qbit[0];
__m256d u3_1qbit_00_vec = _mm256_broadcast_pd(u3_1qubit_tmp);

u3_1qubit_tmp = (__m128d*) & u3_1qbit[1];
__m256d u3_1qbit_01_vec = _mm256_broadcast_pd(u3_1qubit_tmp);

u3_1qubit_tmp = (__m128d*) & u3_1qbit[2];
__m256d u3_1qbit_10_vec = _mm256_broadcast_pd(u3_1qubit_tmp);

u3_1qubit_tmp = (__m128d*) & u3_1qbit[3];
__m256d u3_1qbit_11_vec = _mm256_broadcast_pd(u3_1qubit_tmp);
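// With this layout a single 256-bit register holds one kernel element duplicated over
// both 128-bit lanes, so it can multiply two consecutive complex amplitudes of a row
// at once. The loop below walks the row pairs selected by the target qubit and applies
// the kernel in place.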
for (int current_idx_pair = current_idx + index_step_target; current_idx_pair < matrix_size; current_idx_pair = current_idx_pair + (index_step_target << 1)) {

    for (int idx = 0; idx < index_step_target; idx++) {

        // rows paired by the target qubit: their indices differ only in that qubit's bit
        int current_idx_loc = current_idx + idx;
        int current_idx_pair_loc = current_idx_pair + idx;

        int row_offset = current_idx_loc * input.stride;
        int row_offset_pair = current_idx_pair_loc * input.stride;

        // act only if there is no control qubit or its bit is set in the row index
        if (control_qbit < 0 || ((current_idx_loc >> control_qbit) & 1)) {

            // raw double pointers to the interleaved (re, im) data of the two rows
            double* element = (double*)input.get_data() + 2 * row_offset;
            double* element_pair = (double*)input.get_data() + 2 * row_offset_pair;

            __m256d neg = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);
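            // Rows are stored as interleaved doubles [re, im, re, im, ...], so each pass of
            // the loop below updates 4 doubles = 2 complex amplitudes. The alternating sign
            // mask 'neg' is used to build the swapped-and-negated operand of the complex
            // products.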
            for (int col_idx = 0; col_idx < 2 * (input.cols - 1); col_idx = col_idx + 4) {

                __m256d element_vec = _mm256_load_pd(element + col_idx);
                __m256d element_pair_vec = _mm256_load_pd(element_pair + col_idx);

                __m256d vec3 = _mm256_mul_pd(u3_1qbit_00_vec, element_vec);

                __m256d element_vec_permuted = _mm256_permute_pd(element_vec, 0x5);
                element_vec_permuted = _mm256_mul_pd(element_vec_permuted, neg);
                __m256d vec4 = _mm256_mul_pd(u3_1qbit_00_vec, element_vec_permuted);
                vec3 = _mm256_hsub_pd(vec3, vec4);

                __m256d vec5 = _mm256_mul_pd(u3_1qbit_01_vec, element_pair_vec);

                __m256d element_pair_vec_permuted = _mm256_permute_pd(element_pair_vec, 0x5);
                element_pair_vec_permuted = _mm256_mul_pd(element_pair_vec_permuted, neg);
                vec4 = _mm256_mul_pd(u3_1qbit_01_vec, element_pair_vec_permuted);
                vec5 = _mm256_hsub_pd(vec5, vec4);

                vec3 = _mm256_add_pd(vec3, vec5);
                _mm256_store_pd(element + col_idx, vec3);
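                // Complex product trick: with u = ur + i*ui broadcast as [ur, ui, ur, ui]
                // and x = xr + i*xi stored as [xr, xi, ...],
                //   u_vec * x_vec                  = [ur*xr, ui*xi, ...]
                //   u_vec * (neg * permute(x_vec)) = [ur*xi, -ui*xr, ...]
                // and _mm256_hsub_pd of the two gives [ur*xr - ui*xi, ur*xi + ui*xr, ...],
                // i.e. the real and imaginary parts of u*x for both columns.
                // The store above therefore writes u3_1qbit[0]*element + u3_1qbit[1]*element_pair;
                // the same pattern is repeated below with u3_1qbit[2] and u3_1qbit[3] for the
                // paired row.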
                vec3 = _mm256_mul_pd(u3_1qbit_10_vec, element_vec);
                vec4 = _mm256_mul_pd(u3_1qbit_10_vec, element_vec_permuted);
                vec3 = _mm256_hsub_pd(vec3, vec4);

                vec5 = _mm256_mul_pd(u3_1qbit_11_vec, element_pair_vec);
                vec4 = _mm256_mul_pd(u3_1qbit_11_vec, element_pair_vec_permuted);
                vec5 = _mm256_hsub_pd(vec5, vec4);

                vec3 = _mm256_add_pd(vec3, vec5);
                _mm256_store_pd(element_pair + col_idx, vec3);
            }
            // if the number of columns is odd, process the last column with scalar
            // complex arithmetic
            if (input.cols % 2 == 1) {

                int col_idx = input.cols - 1;

                int index = row_offset + col_idx;
                int index_pair = row_offset_pair + col_idx;

                QGD_Complex16 element = input[index];
                QGD_Complex16 element_pair = input[index_pair];

                QGD_Complex16 tmp1 = mult(u3_1qbit[0], element);
                QGD_Complex16 tmp2 = mult(u3_1qbit[1], element_pair);

                input[index].real = tmp1.real + tmp2.real;
                input[index].imag = tmp1.imag + tmp2.imag;

                tmp1 = mult(u3_1qbit[2], element);
                tmp2 = mult(u3_1qbit[3], element_pair);

                input[index_pair].real = tmp1.real + tmp2.real;
                input[index_pair].imag = tmp1.imag + tmp2.imag;

            }

        }

    }
    current_idx = current_idx + (index_step_target << 1);

}
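// A second variant of the kernel: here the real and imaginary parts of the kernel
// elements are broadcast into separate registers, so each complex product can be formed
// with FMA instructions on de-interleaved data, processing 8 doubles (4 complex
// amplitudes) per iteration.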
__m256d u3_1bit_00r_vec = _mm256_broadcast_sd(&u3_1qbit[0].real);
__m256d u3_1bit_00i_vec = _mm256_broadcast_sd(&u3_1qbit[0].imag);
__m256d u3_1bit_01r_vec = _mm256_broadcast_sd(&u3_1qbit[1].real);
__m256d u3_1bit_01i_vec = _mm256_broadcast_sd(&u3_1qbit[1].imag);
__m256d u3_1bit_10r_vec = _mm256_broadcast_sd(&u3_1qbit[2].real);
__m256d u3_1bit_10i_vec = _mm256_broadcast_sd(&u3_1qbit[2].imag);
__m256d u3_1bit_11r_vec = _mm256_broadcast_sd(&u3_1qbit[3].real);
__m256d u3_1bit_11i_vec = _mm256_broadcast_sd(&u3_1qbit[3].imag);
for (int current_idx_pair = current_idx + index_step_target; current_idx_pair < matrix_size; current_idx_pair = current_idx_pair + (index_step_target << 1)) {

    for (int idx = 0; idx < index_step_target; idx++) {

        int current_idx_loc = current_idx + idx;
        int current_idx_pair_loc = current_idx_pair + idx;

        int row_offset = current_idx_loc * input.stride;
        int row_offset_pair = current_idx_pair_loc * input.stride;

        if (control_qbit < 0 || ((current_idx_loc >> control_qbit) & 1)) {

            double* element = (double*)input.get_data() + 2 * row_offset;
            double* element_pair = (double*)input.get_data() + 2 * row_offset_pair;

            for (int col_idx = 0; col_idx < 2 * (input.cols - 3); col_idx = col_idx + 8) {
                __m256d element_vec = _mm256_load_pd(element + col_idx);
                __m256d element_vec2 = _mm256_load_pd(element + col_idx + 4);
                __m256d tmp = _mm256_shuffle_pd(element_vec, element_vec2, 0);
                element_vec2 = _mm256_shuffle_pd(element_vec, element_vec2, 0xf);
                element_vec = tmp;

                __m256d element_pair_vec = _mm256_load_pd(element_pair + col_idx);
                __m256d element_pair_vec2 = _mm256_load_pd(element_pair + col_idx + 4);
                tmp = _mm256_shuffle_pd(element_pair_vec, element_pair_vec2, 0);
                element_pair_vec2 = _mm256_shuffle_pd(element_pair_vec, element_pair_vec2, 0xf);
                element_pair_vec = tmp;
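                // After the shuffles element_vec / element_pair_vec hold the real parts and
                // element_vec2 / element_pair_vec2 the imaginary parts of four amplitudes
                // (mask 0 picks the even doubles, 0xf the odd ones); the lane-wise ordering
                // is undone by the re-interleaving shuffles before the stores.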
                __m256d vec3 = _mm256_mul_pd(u3_1bit_00r_vec, element_vec);
                vec3 = _mm256_fnmadd_pd(u3_1bit_00i_vec, element_vec2, vec3);
                __m256d vec4 = _mm256_mul_pd(u3_1bit_01r_vec, element_pair_vec);
                vec4 = _mm256_fnmadd_pd(u3_1bit_01i_vec, element_pair_vec2, vec4);
                vec3 = _mm256_add_pd(vec3, vec4);
                __m256d vec5 = _mm256_mul_pd(u3_1bit_00r_vec, element_vec2);
                vec5 = _mm256_fmadd_pd(u3_1bit_00i_vec, element_vec, vec5);
                __m256d vec6 = _mm256_mul_pd(u3_1bit_01r_vec, element_pair_vec2);
                vec6 = _mm256_fmadd_pd(u3_1bit_01i_vec, element_pair_vec, vec6);
                vec5 = _mm256_add_pd(vec5, vec6);
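                // vec3/vec5 are the real/imaginary parts of u3_1qbit[0]*element + u3_1qbit[1]*element_pair:
                // fnmadd forms re*re - im*im, fmadd forms re*im + im*re.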
                // re-interleave the real and imaginary parts and store the result
                tmp = _mm256_shuffle_pd(vec3, vec5, 0);
                vec5 = _mm256_shuffle_pd(vec3, vec5, 0xf);
                vec3 = tmp;

                _mm256_store_pd(element + col_idx, vec3);
                _mm256_store_pd(element + col_idx + 4, vec5);
                // second row of the kernel: u3_1qbit[2]*element + u3_1qbit[3]*element_pair
                __m256d vec7 = _mm256_mul_pd(u3_1bit_10r_vec, element_vec);
                vec7 = _mm256_fnmadd_pd(u3_1bit_10i_vec, element_vec2, vec7);
                __m256d vec8 = _mm256_mul_pd(u3_1bit_11r_vec, element_pair_vec);
                vec8 = _mm256_fnmadd_pd(u3_1bit_11i_vec, element_pair_vec2, vec8);
                vec7 = _mm256_add_pd(vec7, vec8);
                __m256d vec9 = _mm256_mul_pd(u3_1bit_10r_vec, element_vec2);
                vec9 = _mm256_fmadd_pd(u3_1bit_10i_vec, element_vec, vec9);
                __m256d vec10 = _mm256_mul_pd(u3_1bit_11r_vec, element_pair_vec2);
                vec10 = _mm256_fmadd_pd(u3_1bit_11i_vec, element_pair_vec, vec10);
                vec9 = _mm256_add_pd(vec9, vec10);
                tmp = _mm256_shuffle_pd(vec7, vec9, 0);
                vec9 = _mm256_shuffle_pd(vec7, vec9, 0xf);
                vec7 = tmp;

                _mm256_store_pd(element_pair + col_idx, vec7);
                _mm256_store_pd(element_pair + col_idx + 4, vec9);
            }
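            // Columns that do not fill a whole 4-amplitude vector are handled below with
            // scalar complex arithmetic.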
            int remainder = input.cols % 4;
            if (remainder != 0) {

                for (int col_idx = input.cols - remainder; col_idx < input.cols; col_idx++) {

                    int index = row_offset + col_idx;
                    int index_pair = row_offset_pair + col_idx;

                    QGD_Complex16 element = input[index];
                    QGD_Complex16 element_pair = input[index_pair];

                    QGD_Complex16 tmp1 = mult(u3_1qbit[0], element);
                    QGD_Complex16 tmp2 = mult(u3_1qbit[1], element_pair);

                    input[index].real = tmp1.real + tmp2.real;
                    input[index].imag = tmp1.imag + tmp2.imag;

                    tmp1 = mult(u3_1qbit[2], element);
                    tmp2 = mult(u3_1qbit[3], element_pair);

                    input[index_pair].real = tmp1.real + tmp2.real;
                    input[index_pair].imag = tmp1.imag + tmp2.imag;

                }

            }

        }

    }
    current_idx = current_idx + (index_step_target << 1);

}
__m256d u3_1bit_00r_vec = _mm256_broadcast_sd(&u3_1qbit[0].real);
__m256d u3_1bit_00i_vec = _mm256_broadcast_sd(&u3_1qbit[0].imag);
__m256d u3_1bit_01r_vec = _mm256_broadcast_sd(&u3_1qbit[1].real);
__m256d u3_1bit_01i_vec = _mm256_broadcast_sd(&u3_1qbit[1].imag);
__m256d u3_1bit_10r_vec = _mm256_broadcast_sd(&u3_1qbit[2].real);
__m256d u3_1bit_10i_vec = _mm256_broadcast_sd(&u3_1qbit[2].imag);
__m256d u3_1bit_11r_vec = _mm256_broadcast_sd(&u3_1qbit[3].real);
__m256d u3_1bit_11i_vec = _mm256_broadcast_sd(&u3_1qbit[3].imag);
int parallel_outer_cycles = matrix_size / (index_step_target << 1);
int outer_grain_size;
if ( index_step_target <= 2 ) {
    outer_grain_size = 32;
}
else if ( index_step_target <= 4 ) {
    outer_grain_size = 16;
}
else if ( index_step_target <= 8 ) {
    outer_grain_size = 8;
}
else if ( index_step_target <= 16 ) {
    outer_grain_size = 4;
}
else {
    outer_grain_size = 2;
}
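// The grain size of the outer parallel loop shrinks as index_step_target grows: larger
// row blocks carry more work per outer cycle, so fewer cycles are grouped into a single
// TBB task.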
tbb::parallel_for( tbb::blocked_range<int>(0, parallel_outer_cycles, outer_grain_size), [&](tbb::blocked_range<int> r) {

    int current_idx = r.begin() * (index_step_target << 1);
    int current_idx_pair = index_step_target + r.begin() * (index_step_target << 1);

    for (int rdx = r.begin(); rdx < r.end(); rdx++) {
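        // Two-level parallelism: the outer parallel_for distributes blocks of row pairs,
        // while the nested parallel_for below splits the rows of one block (grain size 32)
        // into further tasks.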
        tbb::parallel_for( tbb::blocked_range<int>(0, index_step_target, 32), [&](tbb::blocked_range<int> r) {

            for (int idx = r.begin(); idx < r.end(); ++idx) {

                int current_idx_loc = current_idx + idx;
                int current_idx_pair_loc = current_idx_pair + idx;

                int row_offset = current_idx_loc * input.stride;
                int row_offset_pair = current_idx_pair_loc * input.stride;

                if (control_qbit < 0 || ((current_idx_loc >> control_qbit) & 1)) {

                    double* element = (double*)input.get_data() + 2 * row_offset;
                    double* element_pair = (double*)input.get_data() + 2 * row_offset_pair;

                    for (int col_idx = 0; col_idx < 2 * (input.cols - 3); col_idx = col_idx + 8) {
                        __m256d element_vec = _mm256_load_pd(element + col_idx);
                        __m256d element_vec2 = _mm256_load_pd(element + col_idx + 4);
                        __m256d tmp = _mm256_shuffle_pd(element_vec, element_vec2, 0);
                        element_vec2 = _mm256_shuffle_pd(element_vec, element_vec2, 0xf);
                        element_vec = tmp;

                        __m256d element_pair_vec = _mm256_load_pd(element_pair + col_idx);
                        __m256d element_pair_vec2 = _mm256_load_pd(element_pair + col_idx + 4);
                        tmp = _mm256_shuffle_pd(element_pair_vec, element_pair_vec2, 0);
                        element_pair_vec2 = _mm256_shuffle_pd(element_pair_vec, element_pair_vec2, 0xf);
                        element_pair_vec = tmp;
                        __m256d vec3 = _mm256_mul_pd(u3_1bit_00r_vec, element_vec);
                        vec3 = _mm256_fnmadd_pd(u3_1bit_00i_vec, element_vec2, vec3);
                        __m256d vec4 = _mm256_mul_pd(u3_1bit_01r_vec, element_pair_vec);
                        vec4 = _mm256_fnmadd_pd(u3_1bit_01i_vec, element_pair_vec2, vec4);
                        vec3 = _mm256_add_pd(vec3, vec4);
                        __m256d vec5 = _mm256_mul_pd(u3_1bit_00r_vec, element_vec2);
                        vec5 = _mm256_fmadd_pd(u3_1bit_00i_vec, element_vec, vec5);
                        __m256d vec6 = _mm256_mul_pd(u3_1bit_01r_vec, element_pair_vec2);
                        vec6 = _mm256_fmadd_pd(u3_1bit_01i_vec, element_pair_vec, vec6);
                        vec5 = _mm256_add_pd(vec5, vec6);
                        // re-interleave the real and imaginary parts and store the result
                        tmp = _mm256_shuffle_pd(vec3, vec5, 0);
                        vec5 = _mm256_shuffle_pd(vec3, vec5, 0xf);
                        vec3 = tmp;

                        _mm256_store_pd(element + col_idx, vec3);
                        _mm256_store_pd(element + col_idx + 4, vec5);
                        __m256d vec7 = _mm256_mul_pd(u3_1bit_10r_vec, element_vec);
                        vec7 = _mm256_fnmadd_pd(u3_1bit_10i_vec, element_vec2, vec7);
                        __m256d vec8 = _mm256_mul_pd(u3_1bit_11r_vec, element_pair_vec);
                        vec8 = _mm256_fnmadd_pd(u3_1bit_11i_vec, element_pair_vec2, vec8);
                        vec7 = _mm256_add_pd(vec7, vec8);
                        __m256d vec9 = _mm256_mul_pd(u3_1bit_10r_vec, element_vec2);
                        vec9 = _mm256_fmadd_pd(u3_1bit_10i_vec, element_vec, vec9);
                        __m256d vec10 = _mm256_mul_pd(u3_1bit_11r_vec, element_pair_vec2);
                        vec10 = _mm256_fmadd_pd(u3_1bit_11i_vec, element_pair_vec, vec10);
                        vec9 = _mm256_add_pd(vec9, vec10);
                        tmp = _mm256_shuffle_pd(vec7, vec9, 0);
                        vec9 = _mm256_shuffle_pd(vec7, vec9, 0xf);
                        vec7 = tmp;

                        _mm256_store_pd(element_pair + col_idx, vec7);
                        _mm256_store_pd(element_pair + col_idx + 4, vec9);
                    }
                    int remainder = input.cols % 4;
                    if (remainder != 0) {

                        for (int col_idx = input.cols - remainder; col_idx < input.cols; col_idx++) {

                            int index = row_offset + col_idx;
                            int index_pair = row_offset_pair + col_idx;

                            QGD_Complex16 element = input[index];
                            QGD_Complex16 element_pair = input[index_pair];

                            QGD_Complex16 tmp1 = mult(u3_1qbit[0], element);
                            QGD_Complex16 tmp2 = mult(u3_1qbit[1], element_pair);

                            input[index].real = tmp1.real + tmp2.real;
                            input[index].imag = tmp1.imag + tmp2.imag;

                            tmp1 = mult(u3_1qbit[2], element);
                            tmp2 = mult(u3_1qbit[3], element_pair);

                            input[index_pair].real = tmp1.real + tmp2.real;
                            input[index_pair].imag = tmp1.imag + tmp2.imag;

                        }

                    }

                }

            }

        });
        current_idx = current_idx + (index_step_target << 1);
        current_idx_pair = current_idx_pair + (index_step_target << 1);

    }

});
/*
 * Referenced SQUANDER API:
 *   QGD_Complex16                -- structure type representing complex numbers in the SQUANDER
 *                                   package; members double real and double imag hold the real
 *                                   and imaginary parts of a complex number
 *   QGD_Complex16 mult(QGD_Complex16 &a, QGD_Complex16 &b)
 *                                -- call to calculate the product of two complex scalars
 *   input                        -- class to store data of complex arrays and its properties:
 *       int cols                 -- the number of columns
 *       int stride               -- the column stride of the array (the elements of one row are
 *                                   a_0, a_1, ..., a_{cols-1} followed by stride-cols zeros)
 *       scalar* get_data() const -- call to get the pointer to the stored data
 */