Expression Templates Library (ETL)
sse_exp.hpp
Go to the documentation of this file.
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
8 // Most of the code has been taken from Julien Pommier and adapted
9 // for ETL
10 
11 /* Copyright (C) 2007 Julien Pommier
12 
13  This software is provided 'as-is', without any express or implied
14  warranty. In no event will the authors be held liable for any damages
15  arising from the use of this software.
16 
17  Permission is granted to anyone to use this software for any purpose,
18  including commercial applications, and to alter it and redistribute it
19  freely, subject to the following restrictions:
20 
21  1. The origin of this software must not be misrepresented; you must not
22  claim that you wrote the original software. If you use this software
23  in a product, an acknowledgment in the product documentation would be
24  appreciated but is not required.
25  2. Altered source versions must be plainly marked as such, and must not be
26  misrepresented as being the original software.
27  3. This notice may not be removed or altered from any source distribution.
28 
29  (this is the zlib license)
30 */
31 
37 #pragma once
38 
39 #ifdef __SSE3__
40 
41 #include <immintrin.h>
42 #include <xmmintrin.h>
43 #include <emmintrin.h>
44 
// Shorthands for a static inline function returning a packed SSE vector.
// ETL_STATIC_INLINE is provided by the ETL configuration headers.
#define ETL_INLINE_VEC_128 ETL_STATIC_INLINE(__m128)
#define ETL_INLINE_VEC_128D ETL_STATIC_INLINE(__m128d)
47 
48 namespace etl {
49 
// Alignment helpers: on GCC/Clang the alignment attribute goes after the
// declarator, hence the empty BEG marker and the attribute-carrying END.
#define ALIGN16_BEG
#define ALIGN16_END __attribute__((aligned(16)))

/* declare some SSE constants -- why can't I figure a better way to do that? */
// Each macro declares a 16-byte-aligned table of four identical 32-bit lanes
// that can be loaded straight into an XMM register via *(__m128*)_ps_Name.
#define PS_CONST(Name, Val) static const ALIGN16_BEG float _ps_##Name[4] ALIGN16_END = {Val, Val, Val, Val}
#define PI32_CONST(Name, Val) static const ALIGN16_BEG int _pi32_##Name[4] ALIGN16_END = {Val, Val, Val, Val}
#define PS_CONST_TYPE(Name, Type, Val) static const ALIGN16_BEG Type _ps_##Name[4] ALIGN16_END = {Val, Val, Val, Val}

PS_CONST(1, 1.0f);
PS_CONST(0p5, 0.5f);

/* the smallest non denormalized float number */
PS_CONST_TYPE(min_norm_pos, int, 0x00800000);
// IEEE-754 single-precision exponent field (bits 23..30) and its complement.
PS_CONST_TYPE(mant_mask, int, 0x7f800000);
PS_CONST_TYPE(inv_mant_mask, int, ~0x7f800000);

// Sign bit (bit 31) and its complement, used for abs()/sign-extraction tricks.
PS_CONST_TYPE(sign_mask, int, (int)0x80000000);
PS_CONST_TYPE(inv_sign_mask, int, ~0x80000000);

// Small integer constants used by the trigonometric range reduction below.
PI32_CONST(1, 1);
PI32_CONST(inv1, ~1);
PI32_CONST(2, 2);
PI32_CONST(4, 4);
// 127: the IEEE-754 single-precision exponent bias.
PI32_CONST(0x7f, 0x7f);

// sqrt(1/2): mantissa-normalization threshold used in log_ps.
PS_CONST(cephes_SQRTHF, 0.707106781186547524);
// Cephes degree-8 polynomial coefficients for log(1 + x) (Horner order p0..p8).
PS_CONST(cephes_log_p0, 7.0376836292E-2);
PS_CONST(cephes_log_p1, -1.1514610310E-1);
PS_CONST(cephes_log_p2, 1.1676998740E-1);
PS_CONST(cephes_log_p3, -1.2420140846E-1);
PS_CONST(cephes_log_p4, +1.4249322787E-1);
PS_CONST(cephes_log_p5, -1.6668057665E-1);
PS_CONST(cephes_log_p6, +2.0000714765E-1);
PS_CONST(cephes_log_p7, -2.4999993993E-1);
PS_CONST(cephes_log_p8, +3.3333331174E-1);
// log(2) split as q2 (high part) + q1 (low part) for extended precision.
PS_CONST(cephes_log_q1, -2.12194440e-4);
PS_CONST(cephes_log_q2, 0.693359375);
87 
/*!
 * \brief Compute the natural logarithm of each of the four packed
 * single-precision floats in x.
 *
 * Cephes-style algorithm: split x into a mantissa m in [0.5, 1) and an
 * exponent e so that log(x) = log(m) + e * log(2), then evaluate a degree-8
 * polynomial on the reduced argument. Lanes with x <= 0 return NaN; inputs
 * are clamped up to the smallest normalized float first.
 */
ETL_INLINE_VEC_128 log_ps(__m128 x) {
    __m128i emm0;
    __m128 one = *(__m128*)_ps_1;

    // Remember which lanes are invalid for log (x <= 0): all-ones mask there.
    __m128 invalid_mask = _mm_cmple_ps(x, _mm_setzero_ps());

    x = _mm_max_ps(x, *(__m128*)_ps_min_norm_pos); /* cut off denormalized stuff */

    // Extract the raw IEEE-754 exponent field (bits 23..30).
    emm0 = _mm_srli_epi32(_mm_castps_si128(x), 23);

    /* keep only the fractional part */
    // Clear the exponent bits, then force the exponent of 0.5 so the
    // mantissa m ends up scaled into [0.5, 1).
    x = _mm_and_ps(x, *(__m128*)_ps_inv_mant_mask);
    x = _mm_or_ps(x, *(__m128*)_ps_0p5);

    // Unbias the exponent (subtract 127) and convert it to float.
    emm0 = _mm_sub_epi32(emm0, *(__m128i*)_pi32_0x7f);
    __m128 e = _mm_cvtepi32_ps(emm0);

    e = _mm_add_ps(e, one);

    // If m < sqrt(1/2): use 2m - 1 as the reduced argument and decrement the
    // exponent; otherwise use m - 1. Keeps the polynomial argument near 0.
    __m128 mask = _mm_cmplt_ps(x, *(__m128*)_ps_cephes_SQRTHF);
    __m128 tmp = _mm_and_ps(x, mask);
    x = _mm_sub_ps(x, one);
    e = _mm_sub_ps(e, _mm_and_ps(one, mask));
    x = _mm_add_ps(x, tmp);

    __m128 z = _mm_mul_ps(x, x);

    // Horner evaluation of the degree-8 Cephes log polynomial in x.
    __m128 y = *(__m128*)_ps_cephes_log_p0;
    y = _mm_mul_ps(y, x);
    y = _mm_add_ps(y, *(__m128*)_ps_cephes_log_p1);
    y = _mm_mul_ps(y, x);
    y = _mm_add_ps(y, *(__m128*)_ps_cephes_log_p2);
    y = _mm_mul_ps(y, x);
    y = _mm_add_ps(y, *(__m128*)_ps_cephes_log_p3);
    y = _mm_mul_ps(y, x);
    y = _mm_add_ps(y, *(__m128*)_ps_cephes_log_p4);
    y = _mm_mul_ps(y, x);
    y = _mm_add_ps(y, *(__m128*)_ps_cephes_log_p5);
    y = _mm_mul_ps(y, x);
    y = _mm_add_ps(y, *(__m128*)_ps_cephes_log_p6);
    y = _mm_mul_ps(y, x);
    y = _mm_add_ps(y, *(__m128*)_ps_cephes_log_p7);
    y = _mm_mul_ps(y, x);
    y = _mm_add_ps(y, *(__m128*)_ps_cephes_log_p8);
    y = _mm_mul_ps(y, x);

    y = _mm_mul_ps(y, z);

    // Fold in the exponent: e * log(2) is applied as e*q1 + e*q2 (low part
    // first) for extended precision.
    tmp = _mm_mul_ps(e, *(__m128*)_ps_cephes_log_q1);
    y = _mm_add_ps(y, tmp);

    // log(1 + x) ~ x - x^2/2 + x^3 * P(x): subtract the z/2 term.
    tmp = _mm_mul_ps(z, *(__m128*)_ps_0p5);
    y = _mm_sub_ps(y, tmp);

    tmp = _mm_mul_ps(e, *(__m128*)_ps_cephes_log_q2);
    x = _mm_add_ps(x, y);
    x = _mm_add_ps(x, tmp);
    x = _mm_or_ps(x, invalid_mask); // negative arg will be NAN
    return x;
}
153 
// Clamp bounds for exp_ps: beyond these the single-precision result would
// overflow / underflow.
PS_CONST(exp_hi, 88.3762626647949f);
PS_CONST(exp_lo, -88.3762626647949f);

// log2(e) = 1/log(2), used for the range reduction exp(x) = 2^n * exp(g).
PS_CONST(cephes_LOG2EF, 1.44269504088896341);
// log(2) split into a high (C1) and low (C2) part; C1 + C2 ~ log(2), so the
// two-step subtraction keeps extra precision.
PS_CONST(cephes_exp_C1, 0.693359375);
PS_CONST(cephes_exp_C2, -2.12194440e-4);

// Degree-5 polynomial coefficients for exp on the reduced argument
// (Horner order p0..p5).
PS_CONST(cephes_exp_p0, 1.9875691500E-4);
PS_CONST(cephes_exp_p1, 1.3981999507E-3);
PS_CONST(cephes_exp_p2, 8.3334519073E-3);
PS_CONST(cephes_exp_p3, 4.1665795894E-2);
PS_CONST(cephes_exp_p4, 1.6666665459E-1);
PS_CONST(cephes_exp_p5, 5.0000001201E-1);
167 
173 ETL_INLINE_VEC_128D exp_pd(__m128d x) {
174  const __m128i offset = _mm_setr_epi32(1023, 1023, 0, 0);
175 
176  __m128i k1;
177  __m128d p1;
178  __m128d a1;
179  __m128d x1;
180 
181  auto xmm0 = _mm_set1_pd(7.09782712893383996843e2);
182  auto xmm1 = _mm_set1_pd(-7.08396418532264106224e2);
183 
184  x1 = _mm_min_pd(x, xmm0);
185  x1 = _mm_max_pd(x1, xmm1);
186 
187  /* a = x / log2 */
188  xmm0 = _mm_set1_pd(1.4426950408889634073599);
189  xmm1 = _mm_setzero_pd();
190  a1 = _mm_mul_pd(x1, xmm0);
191 
192  /* k = (int)floor(a) p = (float)k */
193  p1 = _mm_cmplt_pd(a1, xmm1);
194  xmm0 = _mm_set1_pd(1.0);
195  p1 = _mm_and_pd(p1, xmm0);
196  a1 = _mm_sub_pd(a1, p1);
197  k1 = _mm_cvttpd_epi32(a1);
198  p1 = _mm_cvtepi32_pd(k1);
199 
200  /* x -= p * log2 */
201  xmm0 = _mm_set1_pd(6.93145751953125E-1);
202  xmm1 = _mm_set1_pd(1.42860682030941723212E-6);
203 
204 #ifdef __FMA__
205  x1 = _mm_fnmadd_pd(p1, xmm0, x1);
206  x1 = _mm_fnmadd_pd(p1, xmm1, x1);
207 #else
208  a1 = _mm_mul_pd(p1, xmm0);
209  x1 = _mm_sub_pd(x1, a1);
210  a1 = _mm_mul_pd(p1, xmm1);
211  x1 = _mm_sub_pd(x1, a1);
212 #endif
213 
214  /* Compute e^x using a polynomial approximation. */
215  xmm0 = _mm_set1_pd(1.185268231308989403584147407056378360798378534739e-2);
216  xmm1 = _mm_set1_pd(3.87412011356070379615759057344100690905653320886699e-2);
217 
218 #ifdef __FMA__
219  a1 = _mm_fmadd_pd(x1, xmm0, xmm1);
220 #else
221  a1 = _mm_mul_pd(x1, xmm0);
222  a1 = _mm_add_pd(a1, xmm1);
223 #endif
224 
225  xmm0 = _mm_set1_pd(0.16775408658617866431779970932853611481292418818223);
226  xmm1 = _mm_set1_pd(0.49981934577169208735732248650232562589934399402426);
227 
228 #ifdef __FMA__
229  a1 = _mm_fmadd_pd(a1, x1, xmm0);
230  a1 = _mm_fmadd_pd(a1, x1, xmm1);
231 #else
232  a1 = _mm_mul_pd(a1, x1);
233  a1 = _mm_add_pd(a1, xmm0);
234  a1 = _mm_mul_pd(a1, x1);
235  a1 = _mm_add_pd(a1, xmm1);
236 #endif
237 
238  xmm0 = _mm_set1_pd(1.00001092396453942157124178508842412412025643386873);
239  xmm1 = _mm_set1_pd(0.99999989311082729779536722205742989232069120354073);
240 
241 #ifdef __FMA__
242  a1 = _mm_fmadd_pd(a1, x1, xmm0);
243  a1 = _mm_fmadd_pd(a1, x1, xmm1);
244 #else
245  a1 = _mm_mul_pd(a1, x1);
246  a1 = _mm_add_pd(a1, xmm0);
247  a1 = _mm_mul_pd(a1, x1);
248  a1 = _mm_add_pd(a1, xmm1);
249 #endif
250 
251  /* p = 2^k */
252  k1 = _mm_add_epi32(k1, offset);
253  k1 = _mm_slli_epi32(k1, 20);
254  k1 = _mm_shuffle_epi32(k1, _MM_SHUFFLE(1, 3, 0, 2));
255  p1 = _mm_castsi128_pd(k1);
256 
257  /* a *= 2^k */
258  a1 = _mm_mul_pd(a1, p1);
259 
260  return a1;
261 }
262 
/*!
 * \brief Compute e^x for each of the four packed single-precision floats in x.
 *
 * Classic Cephes scheme: write exp(x) = 2^n * exp(g) with n = round(x/log 2),
 * approximate exp(g) with a degree-5 polynomial, then scale by 2^n built
 * directly in the float exponent field. Arguments are clamped to
 * +/-88.376..., the representable single-precision range.
 */
ETL_INLINE_VEC_128 exp_ps(__m128 x) {
    __m128 tmp, fx;
    __m128i emm0;
    __m128 one = *(__m128*)_ps_1;

    // Clamp so the final 2^n scaling cannot overflow / underflow.
    x = _mm_min_ps(x, *(__m128*)_ps_exp_hi);
    x = _mm_max_ps(x, *(__m128*)_ps_exp_lo);

    /* express exp(x) as exp(g + n*log(2)) */
    // fx = x/log(2) + 0.5, so that flooring below rounds to nearest.
    fx = _mm_mul_ps(x, *(__m128*)_ps_cephes_LOG2EF);
    fx = _mm_add_ps(fx, *(__m128*)_ps_0p5);

    /* how to perform a floorf with SSE: just below */
    // Truncate toward zero, then compensate for negative values.
    emm0 = _mm_cvttps_epi32(fx);
    tmp = _mm_cvtepi32_ps(emm0);
    /* if greater, subtract 1 */
    __m128 mask = _mm_cmpgt_ps(tmp, fx);
    mask = _mm_and_ps(mask, one);
    fx = _mm_sub_ps(tmp, mask);

    // x -= fx * log(2), with log(2) split as C1 + C2 for extended precision.
    tmp = _mm_mul_ps(fx, *(__m128*)_ps_cephes_exp_C1);
    __m128 z = _mm_mul_ps(fx, *(__m128*)_ps_cephes_exp_C2);
    x = _mm_sub_ps(x, tmp);
    x = _mm_sub_ps(x, z);

    z = _mm_mul_ps(x, x);

    // Horner evaluation: exp(g) ~ 1 + g + g^2 * P(g).
    __m128 y = *(__m128*)_ps_cephes_exp_p0;
    y = _mm_mul_ps(y, x);
    y = _mm_add_ps(y, *(__m128*)_ps_cephes_exp_p1);
    y = _mm_mul_ps(y, x);
    y = _mm_add_ps(y, *(__m128*)_ps_cephes_exp_p2);
    y = _mm_mul_ps(y, x);
    y = _mm_add_ps(y, *(__m128*)_ps_cephes_exp_p3);
    y = _mm_mul_ps(y, x);
    y = _mm_add_ps(y, *(__m128*)_ps_cephes_exp_p4);
    y = _mm_mul_ps(y, x);
    y = _mm_add_ps(y, *(__m128*)_ps_cephes_exp_p5);
    y = _mm_mul_ps(y, z);
    y = _mm_add_ps(y, x);
    y = _mm_add_ps(y, one);

    /* build 2^n */
    // Add the exponent bias (127) and shift into the exponent field (bit 23).
    emm0 = _mm_cvttps_epi32(fx);
    emm0 = _mm_add_epi32(emm0, *(__m128i*)_pi32_0x7f);
    emm0 = _mm_slli_epi32(emm0, 23);
    __m128 pow2n = _mm_castsi128_ps(emm0);
    y = _mm_mul_ps(y, pow2n);
    return y;
}
318 
// -pi/4 split into three parts (DP1 + DP2 + DP3 ~ -pi/4) for the
// extended-precision range reduction in sin_ps / cos_ps.
PS_CONST(minus_cephes_DP1, -0.78515625);
PS_CONST(minus_cephes_DP2, -2.4187564849853515625e-4);
PS_CONST(minus_cephes_DP3, -3.77489497744594108e-8);
// Polynomial coefficients for the sine branch: sin(x) ~ x + x^3 * P(x^2).
PS_CONST(sincof_p0, -1.9515295891E-4);
PS_CONST(sincof_p1, 8.3321608736E-3);
PS_CONST(sincof_p2, -1.6666654611E-1);
// Polynomial coefficients for the cosine branch: cos(x) ~ 1 - x^2/2 + x^4 * P(x^2).
PS_CONST(coscof_p0, 2.443315711809948E-005);
PS_CONST(coscof_p1, -1.388731625493765E-003);
PS_CONST(coscof_p2, 4.166664568298827E-002);
PS_CONST(cephes_FOPI, 1.27323954473516); // 4 / M_PI
329 
/*!
 * \brief Compute the sine of each of the four packed single-precision floats.
 *
 * Range reduction: an octant index j is derived from x * 4/pi, x is reduced
 * modulo pi/4 with a three-part extended-precision subtraction, then both the
 * sine and cosine polynomials are evaluated and the correct one is selected
 * per lane, with the sign patched from the octant and the input's sign.
 */
ETL_INLINE_VEC_128 sin_ps(__m128 x) { // any x
    __m128 xmm1, xmm2, xmm3, sign_bit, y;

    __m128i emm0, emm2;
    sign_bit = x;
    /* take the absolute value */
    x = _mm_and_ps(x, *(__m128*)_ps_inv_sign_mask);
    /* extract the sign bit (upper one) */
    sign_bit = _mm_and_ps(sign_bit, *(__m128*)_ps_sign_mask);

    /* scale by 4/Pi */
    y = _mm_mul_ps(x, *(__m128*)_ps_cephes_FOPI);

    /* store the integer part of y in mm0 */
    // j = (j + 1) & ~1: round the octant index to the next even value, and
    // recompute y as the matching float.
    emm2 = _mm_cvttps_epi32(y);
    emm2 = _mm_add_epi32(emm2, *(__m128i*)_pi32_1);
    emm2 = _mm_and_si128(emm2, *(__m128i*)_pi32_inv1);
    y = _mm_cvtepi32_ps(emm2);

    /* get the swap sign flag */
    // Bit 2 of the octant index, moved into the float sign-bit position.
    emm0 = _mm_and_si128(emm2, *(__m128i*)_pi32_4);
    emm0 = _mm_slli_epi32(emm0, 29);

    // Polynomial selection mask from bit 1 of the octant index: all-ones in
    // the lanes where the sine polynomial must be used.
    emm2 = _mm_and_si128(emm2, *(__m128i*)_pi32_2);
    emm2 = _mm_cmpeq_epi32(emm2, _mm_setzero_si128());

    __m128 swap_sign_bit = _mm_castsi128_ps(emm0);
    __m128 poly_mask = _mm_castsi128_ps(emm2);
    sign_bit = _mm_xor_ps(sign_bit, swap_sign_bit);

    /* The magic pass: "Extended precision modular arithmetic"
     x = ((x - y * DP1) - y * DP2) - y * DP3; */
    xmm1 = *(__m128*)_ps_minus_cephes_DP1;
    xmm2 = *(__m128*)_ps_minus_cephes_DP2;
    xmm3 = *(__m128*)_ps_minus_cephes_DP3;
    xmm1 = _mm_mul_ps(y, xmm1);
    xmm2 = _mm_mul_ps(y, xmm2);
    xmm3 = _mm_mul_ps(y, xmm3);
    x = _mm_add_ps(x, xmm1);
    x = _mm_add_ps(x, xmm2);
    x = _mm_add_ps(x, xmm3);

    /* Evaluate the first polynom (cosine branch, 0 <= x <= Pi/4) */
    y = *(__m128*)_ps_coscof_p0;
    __m128 z = _mm_mul_ps(x, x);

    y = _mm_mul_ps(y, z);
    y = _mm_add_ps(y, *(__m128*)_ps_coscof_p1);
    y = _mm_mul_ps(y, z);
    y = _mm_add_ps(y, *(__m128*)_ps_coscof_p2);
    y = _mm_mul_ps(y, z);
    y = _mm_mul_ps(y, z);
    // cos(x) ~ 1 - x^2/2 + x^4 * P(x^2)
    __m128 tmp = _mm_mul_ps(z, *(__m128*)_ps_0p5);
    y = _mm_sub_ps(y, tmp);
    y = _mm_add_ps(y, *(__m128*)_ps_1);

    /* Evaluate the second polynom (sine branch): sin(x) ~ x + x^3 * P(x^2) */

    __m128 y2 = *(__m128*)_ps_sincof_p0;
    y2 = _mm_mul_ps(y2, z);
    y2 = _mm_add_ps(y2, *(__m128*)_ps_sincof_p1);
    y2 = _mm_mul_ps(y2, z);
    y2 = _mm_add_ps(y2, *(__m128*)_ps_sincof_p2);
    y2 = _mm_mul_ps(y2, z);
    y2 = _mm_mul_ps(y2, x);
    y2 = _mm_add_ps(y2, x);

    /* select the correct result from the two polynoms */
    xmm3 = poly_mask;
    y2 = _mm_and_ps(xmm3, y2);
    y = _mm_andnot_ps(xmm3, y);
    y = _mm_add_ps(y, y2);
    /* update the sign */
    y = _mm_xor_ps(y, sign_bit);
    return y;
}
411 
/*!
 * \brief Compute the cosine of each of the four packed single-precision
 * floats.
 *
 * Same structure as sin_ps: reduce |x| modulo pi/4 via an octant index, but
 * the index is shifted by 2 (cos(x) = sin(x + pi/2)), the input sign is
 * dropped (cosine is even), and the swap-sign flag is derived with andnot.
 */
ETL_INLINE_VEC_128 cos_ps(__m128 x) { // any x
    __m128 xmm1, xmm2, xmm3, y;
    __m128i emm0, emm2;
    /* take the absolute value */
    x = _mm_and_ps(x, *(__m128*)_ps_inv_sign_mask);

    /* scale by 4/Pi */
    y = _mm_mul_ps(x, *(__m128*)_ps_cephes_FOPI);

    /* store the integer part of y in mm0 */
    // j = (j + 1) & ~1: round the octant index to the next even value.
    emm2 = _mm_cvttps_epi32(y);
    emm2 = _mm_add_epi32(emm2, *(__m128i*)_pi32_1);
    emm2 = _mm_and_si128(emm2, *(__m128i*)_pi32_inv1);
    y = _mm_cvtepi32_ps(emm2);

    // Shift the octant by -2: cos(x) = sin(x + pi/2).
    emm2 = _mm_sub_epi32(emm2, *(__m128i*)_pi32_2);

    /* get the swap sign flag */
    // ~j & 4, moved into the float sign-bit position.
    emm0 = _mm_andnot_si128(emm2, *(__m128i*)_pi32_4);
    emm0 = _mm_slli_epi32(emm0, 29);
    /* get the polynom selection mask */
    // All-ones in the lanes where the sine polynomial must be used.
    emm2 = _mm_and_si128(emm2, *(__m128i*)_pi32_2);
    emm2 = _mm_cmpeq_epi32(emm2, _mm_setzero_si128());

    __m128 sign_bit = _mm_castsi128_ps(emm0);
    __m128 poly_mask = _mm_castsi128_ps(emm2);

    // The magic pass: "Extended precision modular arithmetic"
    // x = ((x - y * DP1) - y * DP2) - y * DP3
    xmm1 = *(__m128*)_ps_minus_cephes_DP1;
    xmm2 = *(__m128*)_ps_minus_cephes_DP2;
    xmm3 = *(__m128*)_ps_minus_cephes_DP3;
    xmm1 = _mm_mul_ps(y, xmm1);
    xmm2 = _mm_mul_ps(y, xmm2);
    xmm3 = _mm_mul_ps(y, xmm3);
    x = _mm_add_ps(x, xmm1);
    x = _mm_add_ps(x, xmm2);
    x = _mm_add_ps(x, xmm3);

    /* Evaluate the first polynom (cosine branch, 0 <= x <= Pi/4) */
    y = *(__m128*)_ps_coscof_p0;
    __m128 z = _mm_mul_ps(x, x);

    y = _mm_mul_ps(y, z);
    y = _mm_add_ps(y, *(__m128*)_ps_coscof_p1);
    y = _mm_mul_ps(y, z);
    y = _mm_add_ps(y, *(__m128*)_ps_coscof_p2);
    y = _mm_mul_ps(y, z);
    y = _mm_mul_ps(y, z);
    // cos(x) ~ 1 - x^2/2 + x^4 * P(x^2)
    __m128 tmp = _mm_mul_ps(z, *(__m128*)_ps_0p5);
    y = _mm_sub_ps(y, tmp);
    y = _mm_add_ps(y, *(__m128*)_ps_1);

    /* Evaluate the second polynom (sine branch): sin(x) ~ x + x^3 * P(x^2) */

    __m128 y2 = *(__m128*)_ps_sincof_p0;
    y2 = _mm_mul_ps(y2, z);
    y2 = _mm_add_ps(y2, *(__m128*)_ps_sincof_p1);
    y2 = _mm_mul_ps(y2, z);
    y2 = _mm_add_ps(y2, *(__m128*)_ps_sincof_p2);
    y2 = _mm_mul_ps(y2, z);
    y2 = _mm_mul_ps(y2, x);
    y2 = _mm_add_ps(y2, x);

    /* select the correct result from the two polynoms */
    xmm3 = poly_mask;
    y2 = _mm_and_ps(xmm3, y2);
    y = _mm_andnot_ps(xmm3, y);
    y = _mm_add_ps(y, y2);
    /* update the sign */
    y = _mm_xor_ps(y, sign_bit);

    return y;
}
490 
491 } //end of namespace etl
492 
493 #endif //__SSE3__
Root namespace for the ETL library.
Definition: adapter.hpp:15