//    This file is part of ELPA.
//
//    The ELPA library was originally created by the ELPA consortium,
//    consisting of the following organizations:
//
//    - Max Planck Computing and Data Facility (MPCDF), formerly known as
//      Rechenzentrum Garching der Max-Planck-Gesellschaft (RZG),
//    - Bergische Universität Wuppertal, Lehrstuhl für angewandte
//      Informatik,
//    - Technische Universität München, Lehrstuhl für Informatik mit
//      Schwerpunkt Wissenschaftliches Rechnen,
//    - Fritz-Haber-Institut, Berlin, Abt. Theorie,
//    - Max-Planck-Institut für Mathematik in den Naturwissenschaften,
//      Leipzig, Abt. Komplexe Strukturen in Biologie und Kognition,
//      and
//    - IBM Deutschland GmbH
//
//    This particular source code file contains additions, changes and
//    enhancements authored by Intel Corporation which is not part of
//    the ELPA consortium.
//
//    More information can be found here:
//    http://elpa.mpcdf.mpg.de/
//
//    ELPA is free software: you can redistribute it and/or modify
//    it under the terms of the version 3 of the license of the
//    GNU Lesser General Public License as published by the Free
//    Software Foundation.
//
//    ELPA is distributed in the hope that it will be useful,
//    but WITHOUT ANY WARRANTY; without even the implied warranty of
//    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//    GNU Lesser General Public License for more details.
//
//    You should have received a copy of the GNU Lesser General Public License
//    along with ELPA.  If not, see <http://www.gnu.org/licenses/>
//
//    ELPA reflects a substantial effort on the part of the original
//    ELPA consortium, and we ask you to respect the spirit of the
//    license that we chose: i.e., please contribute any changes you
//    may have back to the original ELPA library distribution, and keep
//    any derivatives of ELPA under the same license that we chose for
//    the original distribution, the GNU Lesser General Public License.
//
// Author: Andreas Marek (andreas.marek@mpcdf.mpg.de)
// --------------------------------------------------------------------------------------------------

#include "config-f90.h"

#include <x86intrin.h>

#define __forceinline __attribute__((always_inline)) static

#ifdef HAVE_AVX512
#define __ELPA_USE_FMA__
#define _mm512_FMA_ps(a,b,c) _mm512_fmadd_ps(a,b,c)
#endif


//Forward declaration
//__forceinline void hh_trafo_kernel_8_AVX512_2hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s);
__forceinline void hh_trafo_kernel_16_AVX512_2hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s);
//__forceinline void hh_trafo_kernel_24_AVX512_2hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s);
__forceinline void hh_trafo_kernel_32_AVX512_2hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s);
__forceinline void hh_trafo_kernel_48_AVX512_2hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s);
__forceinline void hh_trafo_kernel_64_AVX512_2hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s);

void double_hh_trafo_real_avx512_2hv_single(float* q, float* hh, int* pnb, int* pnq, int* pldq, int* pldh);
/*
!f>#if defined(HAVE_AVX512)
!f> interface
!f>   subroutine double_hh_trafo_real_avx512_2hv_single(q, hh, pnb, pnq, pldq, pldh) &
!f>                             bind(C, name="double_hh_trafo_real_avx512_2hv_single")
!f>     use, intrinsic :: iso_c_binding
!f>     integer(kind=c_int)     :: pnb, pnq, pldq, pldh
!f>     type(c_ptr), value      :: q
!f>     real(kind=c_float)      :: hh(pnb,6)
!f>   end subroutine
!f> end interface
!f>#endif
*/
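// Driver: applies two Householder reflectors, stored in hh, to q.
// nb is the length of the Householder vectors, ldq/ldh are the leading
// dimensions of q and hh. Note that nq is taken from pldq below, i.e. the
// number of rows appears to be assumed padded up to ldq (a multiple of 16).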

void double_hh_trafo_real_avx512_2hv_single(float* q, float* hh, int* pnb, int* pnq, int* pldq, int* pldh)
{
	int i;
	int nb = *pnb;
	int nq = *pldq;
	int ldq = *pldq;
	int ldh = *pldh;

	// calculate the scalar product of the two Householder vectors,
	// needed to apply both of them in one sweep
	float s = hh[(ldh)+1]*1.0f;

	#pragma ivdep
	for (i = 2; i < nb; i++)
	{
		s += hh[i-1] * hh[(i+ldh)];
	}
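	// s now holds the inner product of the two (row-shifted) Householder
	// vectors: s = hh[ldh+1] + sum_{i=2..nb-1} hh[i-1]*hh[ldh+i]. It couples
	// the two reflectors in the rank-2 update performed by the kernels below.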

	// Production level kernel calls with padding
	for (i = 0; i < nq-48; i+=64)
	{
		hh_trafo_kernel_64_AVX512_2hv_single(&q[i], hh, nb, ldq, ldh, s);
	}

	if (nq == i)
	{
		return;
	}

	if (nq-i == 48)
	{
		hh_trafo_kernel_48_AVX512_2hv_single(&q[i], hh, nb, ldq, ldh, s);
	}
	else if (nq-i == 32)
	{
		hh_trafo_kernel_32_AVX512_2hv_single(&q[i], hh, nb, ldq, ldh, s);
	}
	else
	{
		hh_trafo_kernel_16_AVX512_2hv_single(&q[i], hh, nb, ldq, ldh, s);
	}
}
/**
 * Unrolled kernel that computes
 * 64 rows of Q simultaneously, a
 * matrix vector product with two householder
 * vectors + a rank 2 update is performed
 */
 __forceinline void hh_trafo_kernel_64_AVX512_2hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s)
{
	/////////////////////////////////////////////////////
	// Matrix Vector Multiplication, Q [64 x nb+1] * hh
	// hh contains two householder vectors, with offset 1
	/////////////////////////////////////////////////////
	int i;
	// Needed bit mask for floating point sign flip
	// careful here: reinterpret the integer sign-bit pattern as a float vector
	__m512 sign = (__m512)_mm512_set1_epi32(0x80000000);
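	// XOR-ing a float vector with this mask flips the sign bit of every lane;
	// it is used below to form -tau1 and -tau2 without a multiplication.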

	__m512 x1 = _mm512_load_ps(&q[ldq]);
	__m512 x2 = _mm512_load_ps(&q[ldq+16]);
	__m512 x3 = _mm512_load_ps(&q[ldq+32]);
	__m512 x4 = _mm512_load_ps(&q[ldq+48]);


	__m512 h1 = _mm512_set1_ps(hh[ldh+1]);
	__m512 h2;

	__m512 q1 = _mm512_load_ps(q);
	__m512 y1 = _mm512_FMA_ps(x1, h1, q1);
	__m512 q2 = _mm512_load_ps(&q[16]);
	__m512 y2 = _mm512_FMA_ps(x2, h1, q2);
	__m512 q3 = _mm512_load_ps(&q[32]);
	__m512 y3 = _mm512_FMA_ps(x3, h1, q3);
	__m512 q4 = _mm512_load_ps(&q[48]);
	__m512 y4 = _mm512_FMA_ps(x4, h1, q4);
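	// x1..x4 accumulate Q(:,1:nb)*v1 and y1..y4 accumulate Q(:,0:nb-1)*v2,
	// where v1 and v2 are the two Householder vectors (implicit leading 1,
	// shifted by one row against each other); each register covers 16 rows.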

	for(i = 2; i < nb; i++)
	{
		h1 = _mm512_set1_ps(hh[i-1]);
		h2 = _mm512_set1_ps(hh[ldh+i]);

		q1 = _mm512_load_ps(&q[i*ldq]);
		x1 = _mm512_FMA_ps(q1, h1, x1);
		y1 = _mm512_FMA_ps(q1, h2, y1);
		q2 = _mm512_load_ps(&q[(i*ldq)+16]);
		x2 = _mm512_FMA_ps(q2, h1, x2);
		y2 = _mm512_FMA_ps(q2, h2, y2);
		q3 = _mm512_load_ps(&q[(i*ldq)+32]);
		x3 = _mm512_FMA_ps(q3, h1, x3);
		y3 = _mm512_FMA_ps(q3, h2, y3);
		q4 = _mm512_load_ps(&q[(i*ldq)+48]);
		x4 = _mm512_FMA_ps(q4, h1, x4);
		y4 = _mm512_FMA_ps(q4, h2, y4);

	}
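	// Last column: v2 is one element shorter than v1 due to the row shift,
	// so only the x accumulators get a final contribution from hh[nb-1].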

	h1 = _mm512_set1_ps(hh[nb-1]);

	q1 = _mm512_load_ps(&q[nb*ldq]);
	x1 = _mm512_FMA_ps(q1, h1, x1);
	q2 = _mm512_load_ps(&q[(nb*ldq)+16]);
	x2 = _mm512_FMA_ps(q2, h1, x2);
	q3 = _mm512_load_ps(&q[(nb*ldq)+32]);
	x3 = _mm512_FMA_ps(q3, h1, x3);
	q4 = _mm512_load_ps(&q[(nb*ldq)+48]);
	x4 = _mm512_FMA_ps(q4, h1, x4);


	/////////////////////////////////////////////////////
	// Rank-2 update of Q [64 x nb+1]
	/////////////////////////////////////////////////////

	__m512 tau1 = _mm512_set1_ps(hh[0]);
	__m512 tau2 = _mm512_set1_ps(hh[ldh]);
	__m512 vs = _mm512_set1_ps(s);

	h1 = (__m512) _mm512_xor_epi32((__m512i) tau1, (__m512i) sign);
	x1 = _mm512_mul_ps(x1, h1);
	x2 = _mm512_mul_ps(x2, h1);
	x3 = _mm512_mul_ps(x3, h1);
	x4 = _mm512_mul_ps(x4, h1);

	h1 = (__m512) _mm512_xor_epi32((__m512i) tau2, (__m512i) sign);
	h2 = _mm512_mul_ps(h1, vs);
	y1 = _mm512_FMA_ps(y1, h1, _mm512_mul_ps(x1,h2));
	y2 = _mm512_FMA_ps(y2, h1, _mm512_mul_ps(x2,h2));
	y3 = _mm512_FMA_ps(y3, h1, _mm512_mul_ps(x3,h2));
	y4 = _mm512_FMA_ps(y4, h1, _mm512_mul_ps(x4,h2));
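	// Now x = -tau1*(Q*v1) and y = -tau2*(Q*v2 + s*x); these are the two
	// update vectors of the rank-2 update Q <- Q + x*v1^T + y*v2^T below.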

	q1 = _mm512_load_ps(q);
	q1 = _mm512_add_ps(q1, y1);
	_mm512_store_ps(q,q1);
	q2 = _mm512_load_ps(&q[16]);
	q2 = _mm512_add_ps(q2, y2);
	_mm512_store_ps(&q[16],q2);
	q3 = _mm512_load_ps(&q[32]);
	q3 = _mm512_add_ps(q3, y3);
	_mm512_store_ps(&q[32],q3);
	q4 = _mm512_load_ps(&q[48]);
	q4 = _mm512_add_ps(q4, y4);
	_mm512_store_ps(&q[48],q4);

	h2 = _mm512_set1_ps(hh[ldh+1]);

	q1 = _mm512_load_ps(&q[ldq]);
	q1 = _mm512_add_ps(q1, _mm512_FMA_ps(y1, h2, x1));
	_mm512_store_ps(&q[ldq],q1);
	q2 = _mm512_load_ps(&q[ldq+16]);
	q2 = _mm512_add_ps(q2, _mm512_FMA_ps(y2, h2, x2));
	_mm512_store_ps(&q[ldq+16],q2);
	q3 = _mm512_load_ps(&q[ldq+32]);
	q3 = _mm512_add_ps(q3, _mm512_FMA_ps(y3, h2, x3));
	_mm512_store_ps(&q[ldq+32],q3);
	q4 = _mm512_load_ps(&q[ldq+48]);
	q4 = _mm512_add_ps(q4, _mm512_FMA_ps(y4, h2, x4));
	_mm512_store_ps(&q[ldq+48],q4);
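	// Columns 0 and 1 are handled separately because v2 and v1 have their
	// implicit unit entries there.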

	for (i = 2; i < nb; i++)
	{
		h1 = _mm512_set1_ps(hh[i-1]);
		h2 = _mm512_set1_ps(hh[ldh+i]);

		q1 = _mm512_load_ps(&q[i*ldq]);
		q1 = _mm512_FMA_ps(x1, h1, q1);
		q1 = _mm512_FMA_ps(y1, h2, q1);
		_mm512_store_ps(&q[i*ldq],q1);
		q2 = _mm512_load_ps(&q[(i*ldq)+16]);
		q2 = _mm512_FMA_ps(x2, h1, q2);
		q2 = _mm512_FMA_ps(y2, h2, q2);
		_mm512_store_ps(&q[(i*ldq)+16],q2);
		q3 = _mm512_load_ps(&q[(i*ldq)+32]);
		q3 = _mm512_FMA_ps(x3, h1, q3);
		q3 = _mm512_FMA_ps(y3, h2, q3);
		_mm512_store_ps(&q[(i*ldq)+32],q3);
		q4 = _mm512_load_ps(&q[(i*ldq)+48]);
		q4 = _mm512_FMA_ps(x4, h1, q4);
		q4 = _mm512_FMA_ps(y4, h2, q4);
		_mm512_store_ps(&q[(i*ldq)+48],q4);

	}

	h1 = _mm512_set1_ps(hh[nb-1]);

	q1 = _mm512_load_ps(&q[nb*ldq]);
	q1 = _mm512_FMA_ps(x1, h1, q1);
	_mm512_store_ps(&q[nb*ldq],q1);
	q2 = _mm512_load_ps(&q[(nb*ldq)+16]);
	q2 = _mm512_FMA_ps(x2, h1, q2);
	_mm512_store_ps(&q[(nb*ldq)+16],q2);
	q3 = _mm512_load_ps(&q[(nb*ldq)+32]);
	q3 = _mm512_FMA_ps(x3, h1, q3);
	_mm512_store_ps(&q[(nb*ldq)+32],q3);
	q4 = _mm512_load_ps(&q[(nb*ldq)+48]);
	q4 = _mm512_FMA_ps(x4, h1, q4);
	_mm512_store_ps(&q[(nb*ldq)+48],q4);

}
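// The 48-, 32- and 16-row kernels below follow exactly the same scheme with
// fewer register blocks; the unused blocks are kept as commented-out code.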

/**
 * Unrolled kernel that computes
 * 48 rows of Q simultaneously, a
 * matrix vector product with two householder
 * vectors + a rank 2 update is performed
 */
 __forceinline void hh_trafo_kernel_48_AVX512_2hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s)
{
	/////////////////////////////////////////////////////
	// Matrix Vector Multiplication, Q [48 x nb+1] * hh
	// hh contains two householder vectors, with offset 1
	/////////////////////////////////////////////////////
	int i;
	// Needed bit mask for floating point sign flip
	// careful here: reinterpret the integer sign-bit pattern as a float vector
	__m512 sign = (__m512)_mm512_set1_epi32(0x80000000);

	__m512 x1 = _mm512_load_ps(&q[ldq]);
	__m512 x2 = _mm512_load_ps(&q[ldq+16]);
	__m512 x3 = _mm512_load_ps(&q[ldq+32]);
//	__m512 x4 = _mm512_load_ps(&q[ldq+48]);


	__m512 h1 = _mm512_set1_ps(hh[ldh+1]);
	__m512 h2;

	__m512 q1 = _mm512_load_ps(q);
	__m512 y1 = _mm512_FMA_ps(x1, h1, q1);
	__m512 q2 = _mm512_load_ps(&q[16]);
	__m512 y2 = _mm512_FMA_ps(x2, h1, q2);
	__m512 q3 = _mm512_load_ps(&q[32]);
	__m512 y3 = _mm512_FMA_ps(x3, h1, q3);
//	__m512 q4 = _mm512_load_ps(&q[48]);
//	__m512 y4 = _mm512_FMA_ps(x4, h1, q4);

	for(i = 2; i < nb; i++)
	{
		h1 = _mm512_set1_ps(hh[i-1]);
		h2 = _mm512_set1_ps(hh[ldh+i]);

		q1 = _mm512_load_ps(&q[i*ldq]);
		x1 = _mm512_FMA_ps(q1, h1, x1);
		y1 = _mm512_FMA_ps(q1, h2, y1);
		q2 = _mm512_load_ps(&q[(i*ldq)+16]);
		x2 = _mm512_FMA_ps(q2, h1, x2);
		y2 = _mm512_FMA_ps(q2, h2, y2);
		q3 = _mm512_load_ps(&q[(i*ldq)+32]);
		x3 = _mm512_FMA_ps(q3, h1, x3);
		y3 = _mm512_FMA_ps(q3, h2, y3);
//		q4 = _mm512_load_ps(&q[(i*ldq)+48]);
//		x4 = _mm512_FMA_ps(q4, h1, x4);
//		y4 = _mm512_FMA_ps(q4, h2, y4);

	}

	h1 = _mm512_set1_ps(hh[nb-1]);

	q1 = _mm512_load_ps(&q[nb*ldq]);
	x1 = _mm512_FMA_ps(q1, h1, x1);
	q2 = _mm512_load_ps(&q[(nb*ldq)+16]);
	x2 = _mm512_FMA_ps(q2, h1, x2);
	q3 = _mm512_load_ps(&q[(nb*ldq)+32]);
	x3 = _mm512_FMA_ps(q3, h1, x3);
//	q4 = _mm512_load_ps(&q[(nb*ldq)+48]);
//	x4 = _mm512_FMA_ps(q4, h1, x4);


	/////////////////////////////////////////////////////
	// Rank-2 update of Q [48 x nb+1]
	/////////////////////////////////////////////////////

	__m512 tau1 = _mm512_set1_ps(hh[0]);
	__m512 tau2 = _mm512_set1_ps(hh[ldh]);
	__m512 vs = _mm512_set1_ps(s);

	h1 = (__m512) _mm512_xor_epi32((__m512i) tau1, (__m512i) sign);
	x1 = _mm512_mul_ps(x1, h1);
	x2 = _mm512_mul_ps(x2, h1);
	x3 = _mm512_mul_ps(x3, h1);
//	x4 = _mm512_mul_ps(x4, h1);

	h1 = (__m512) _mm512_xor_epi32((__m512i) tau2, (__m512i) sign);
	h2 = _mm512_mul_ps(h1, vs);
	y1 = _mm512_FMA_ps(y1, h1, _mm512_mul_ps(x1,h2));
	y2 = _mm512_FMA_ps(y2, h1, _mm512_mul_ps(x2,h2));
	y3 = _mm512_FMA_ps(y3, h1, _mm512_mul_ps(x3,h2));
//	y4 = _mm512_FMA_ps(y4, h1, _mm512_mul_ps(x4,h2));

	q1 = _mm512_load_ps(q);
	q1 = _mm512_add_ps(q1, y1);
	_mm512_store_ps(q,q1);
	q2 = _mm512_load_ps(&q[16]);
	q2 = _mm512_add_ps(q2, y2);
	_mm512_store_ps(&q[16],q2);
	q3 = _mm512_load_ps(&q[32]);
	q3 = _mm512_add_ps(q3, y3);
	_mm512_store_ps(&q[32],q3);
//	q4 = _mm512_load_ps(&q[48]);
//	q4 = _mm512_add_ps(q4, y4);
//	_mm512_store_ps(&q[48],q4);

	h2 = _mm512_set1_ps(hh[ldh+1]);

	q1 = _mm512_load_ps(&q[ldq]);
	q1 = _mm512_add_ps(q1, _mm512_FMA_ps(y1, h2, x1));
	_mm512_store_ps(&q[ldq],q1);
	q2 = _mm512_load_ps(&q[ldq+16]);
	q2 = _mm512_add_ps(q2, _mm512_FMA_ps(y2, h2, x2));
	_mm512_store_ps(&q[ldq+16],q2);
	q3 = _mm512_load_ps(&q[ldq+32]);
	q3 = _mm512_add_ps(q3, _mm512_FMA_ps(y3, h2, x3));
	_mm512_store_ps(&q[ldq+32],q3);
//	q4 = _mm512_load_ps(&q[ldq+48]);
//	q4 = _mm512_add_ps(q4, _mm512_FMA_ps(y4, h2, x4));
//	_mm512_store_ps(&q[ldq+48],q4);

	for (i = 2; i < nb; i++)
	{
		h1 = _mm512_set1_ps(hh[i-1]);
		h2 = _mm512_set1_ps(hh[ldh+i]);

		q1 = _mm512_load_ps(&q[i*ldq]);
		q1 = _mm512_FMA_ps(x1, h1, q1);
		q1 = _mm512_FMA_ps(y1, h2, q1);
		_mm512_store_ps(&q[i*ldq],q1);
		q2 = _mm512_load_ps(&q[(i*ldq)+16]);
		q2 = _mm512_FMA_ps(x2, h1, q2);
		q2 = _mm512_FMA_ps(y2, h2, q2);
		_mm512_store_ps(&q[(i*ldq)+16],q2);
		q3 = _mm512_load_ps(&q[(i*ldq)+32]);
		q3 = _mm512_FMA_ps(x3, h1, q3);
		q3 = _mm512_FMA_ps(y3, h2, q3);
		_mm512_store_ps(&q[(i*ldq)+32],q3);
//		q4 = _mm512_load_ps(&q[(i*ldq)+48]);
//		q4 = _mm512_FMA_ps(x4, h1, q4);
//		q4 = _mm512_FMA_ps(y4, h2, q4);
//		_mm512_store_ps(&q[(i*ldq)+48],q4);

	}

	h1 = _mm512_set1_ps(hh[nb-1]);

	q1 = _mm512_load_ps(&q[nb*ldq]);
	q1 = _mm512_FMA_ps(x1, h1, q1);
	_mm512_store_ps(&q[nb*ldq],q1);
	q2 = _mm512_load_ps(&q[(nb*ldq)+16]);
	q2 = _mm512_FMA_ps(x2, h1, q2);
	_mm512_store_ps(&q[(nb*ldq)+16],q2);
	q3 = _mm512_load_ps(&q[(nb*ldq)+32]);
	q3 = _mm512_FMA_ps(x3, h1, q3);
	_mm512_store_ps(&q[(nb*ldq)+32],q3);
//	q4 = _mm512_load_ps(&q[(nb*ldq)+48]);
//	q4 = _mm512_FMA_ps(x4, h1, q4);
//	_mm512_store_ps(&q[(nb*ldq)+48],q4);

}


/**
 * Unrolled kernel that computes
 * 32 rows of Q simultaneously, a
 * matrix vector product with two householder
 * vectors + a rank 2 update is performed
 */
 __forceinline void hh_trafo_kernel_32_AVX512_2hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s)
{
	/////////////////////////////////////////////////////
	// Matrix Vector Multiplication, Q [32 x nb+1] * hh
	// hh contains two householder vectors, with offset 1
	/////////////////////////////////////////////////////
	int i;
	// Needed bit mask for floating point sign flip
	// careful here: reinterpret the integer sign-bit pattern as a float vector
	__m512 sign = (__m512)_mm512_set1_epi32(0x80000000);

	__m512 x1 = _mm512_load_ps(&q[ldq]);
	__m512 x2 = _mm512_load_ps(&q[ldq+16]);
//	__m512 x3 = _mm512_load_ps(&q[ldq+32]);
//	__m512 x4 = _mm512_load_ps(&q[ldq+48]);


	__m512 h1 = _mm512_set1_ps(hh[ldh+1]);
	__m512 h2;

	__m512 q1 = _mm512_load_ps(q);
	__m512 y1 = _mm512_FMA_ps(x1, h1, q1);
	__m512 q2 = _mm512_load_ps(&q[16]);
	__m512 y2 = _mm512_FMA_ps(x2, h1, q2);
//	__m512 q3 = _mm512_load_ps(&q[32]);
//	__m512 y3 = _mm512_FMA_ps(x3, h1, q3);
//	__m512 q4 = _mm512_load_ps(&q[48]);
//	__m512 y4 = _mm512_FMA_ps(x4, h1, q4);

	for(i = 2; i < nb; i++)
	{
		h1 = _mm512_set1_ps(hh[i-1]);
		h2 = _mm512_set1_ps(hh[ldh+i]);

		q1 = _mm512_load_ps(&q[i*ldq]);
		x1 = _mm512_FMA_ps(q1, h1, x1);
		y1 = _mm512_FMA_ps(q1, h2, y1);
		q2 = _mm512_load_ps(&q[(i*ldq)+16]);
		x2 = _mm512_FMA_ps(q2, h1, x2);
		y2 = _mm512_FMA_ps(q2, h2, y2);
//		q3 = _mm512_load_ps(&q[(i*ldq)+32]);
//		x3 = _mm512_FMA_ps(q3, h1, x3);
//		y3 = _mm512_FMA_ps(q3, h2, y3);
//		q4 = _mm512_load_ps(&q[(i*ldq)+48]);
//		x4 = _mm512_FMA_ps(q4, h1, x4);
//		y4 = _mm512_FMA_ps(q4, h2, y4);

	}

	h1 = _mm512_set1_ps(hh[nb-1]);

	q1 = _mm512_load_ps(&q[nb*ldq]);
	x1 = _mm512_FMA_ps(q1, h1, x1);
	q2 = _mm512_load_ps(&q[(nb*ldq)+16]);
	x2 = _mm512_FMA_ps(q2, h1, x2);
//	q3 = _mm512_load_ps(&q[(nb*ldq)+32]);
//	x3 = _mm512_FMA_ps(q3, h1, x3);
//	q4 = _mm512_load_ps(&q[(nb*ldq)+48]);
//	x4 = _mm512_FMA_ps(q4, h1, x4);


	/////////////////////////////////////////////////////
	// Rank-2 update of Q [32 x nb+1]
	/////////////////////////////////////////////////////

	__m512 tau1 = _mm512_set1_ps(hh[0]);
	__m512 tau2 = _mm512_set1_ps(hh[ldh]);
	__m512 vs = _mm512_set1_ps(s);

	h1 = (__m512) _mm512_xor_epi32((__m512i) tau1, (__m512i) sign);
	x1 = _mm512_mul_ps(x1, h1);
	x2 = _mm512_mul_ps(x2, h1);
//	x3 = _mm512_mul_ps(x3, h1);
//	x4 = _mm512_mul_ps(x4, h1);

	h1 = (__m512) _mm512_xor_epi32((__m512i) tau2, (__m512i) sign);
	h2 = _mm512_mul_ps(h1, vs);
	y1 = _mm512_FMA_ps(y1, h1, _mm512_mul_ps(x1,h2));
	y2 = _mm512_FMA_ps(y2, h1, _mm512_mul_ps(x2,h2));
//	y3 = _mm512_FMA_ps(y3, h1, _mm512_mul_ps(x3,h2));
//	y4 = _mm512_FMA_ps(y4, h1, _mm512_mul_ps(x4,h2));

	q1 = _mm512_load_ps(q);
	q1 = _mm512_add_ps(q1, y1);
	_mm512_store_ps(q,q1);
	q2 = _mm512_load_ps(&q[16]);
	q2 = _mm512_add_ps(q2, y2);
	_mm512_store_ps(&q[16],q2);
//	q3 = _mm512_load_ps(&q[32]);
//	q3 = _mm512_add_ps(q3, y3);
//	_mm512_store_ps(&q[32],q3);
//	q4 = _mm512_load_ps(&q[48]);
//	q4 = _mm512_add_ps(q4, y4);
//	_mm512_store_ps(&q[48],q4);

	h2 = _mm512_set1_ps(hh[ldh+1]);

	q1 = _mm512_load_ps(&q[ldq]);
	q1 = _mm512_add_ps(q1, _mm512_FMA_ps(y1, h2, x1));
	_mm512_store_ps(&q[ldq],q1);
	q2 = _mm512_load_ps(&q[ldq+16]);
	q2 = _mm512_add_ps(q2, _mm512_FMA_ps(y2, h2, x2));
	_mm512_store_ps(&q[ldq+16],q2);
//	q3 = _mm512_load_ps(&q[ldq+32]);
//	q3 = _mm512_add_ps(q3, _mm512_FMA_ps(y3, h2, x3));
//	_mm512_store_ps(&q[ldq+32],q3);
//	q4 = _mm512_load_ps(&q[ldq+48]);
//	q4 = _mm512_add_ps(q4, _mm512_FMA_ps(y4, h2, x4));
//	_mm512_store_ps(&q[ldq+48],q4);

	for (i = 2; i < nb; i++)
	{
		h1 = _mm512_set1_ps(hh[i-1]);
		h2 = _mm512_set1_ps(hh[ldh+i]);

		q1 = _mm512_load_ps(&q[i*ldq]);
		q1 = _mm512_FMA_ps(x1, h1, q1);
		q1 = _mm512_FMA_ps(y1, h2, q1);
		_mm512_store_ps(&q[i*ldq],q1);
		q2 = _mm512_load_ps(&q[(i*ldq)+16]);
		q2 = _mm512_FMA_ps(x2, h1, q2);
		q2 = _mm512_FMA_ps(y2, h2, q2);
		_mm512_store_ps(&q[(i*ldq)+16],q2);
//		q3 = _mm512_load_ps(&q[(i*ldq)+32]);
//		q3 = _mm512_FMA_ps(x3, h1, q3);
//		q3 = _mm512_FMA_ps(y3, h2, q3);
//		_mm512_store_ps(&q[(i*ldq)+32],q3);
//		q4 = _mm512_load_ps(&q[(i*ldq)+48]);
//		q4 = _mm512_FMA_ps(x4, h1, q4);
//		q4 = _mm512_FMA_ps(y4, h2, q4);
//		_mm512_store_ps(&q[(i*ldq)+48],q4);

	}

	h1 = _mm512_set1_ps(hh[nb-1]);

	q1 = _mm512_load_ps(&q[nb*ldq]);
	q1 = _mm512_FMA_ps(x1, h1, q1);
	_mm512_store_ps(&q[nb*ldq],q1);
	q2 = _mm512_load_ps(&q[(nb*ldq)+16]);
	q2 = _mm512_FMA_ps(x2, h1, q2);
	_mm512_store_ps(&q[(nb*ldq)+16],q2);
//	q3 = _mm512_load_ps(&q[(nb*ldq)+32]);
//	q3 = _mm512_FMA_ps(x3, h1, q3);
//	_mm512_store_ps(&q[(nb*ldq)+32],q3);
//	q4 = _mm512_load_ps(&q[(nb*ldq)+48]);
//	q4 = _mm512_FMA_ps(x4, h1, q4);
//	_mm512_store_ps(&q[(nb*ldq)+48],q4);

}


/**
 * Unrolled kernel that computes
 * 16 rows of Q simultaneously, a
 * matrix vector product with two householder
 * vectors + a rank 2 update is performed
 */
 __forceinline void hh_trafo_kernel_16_AVX512_2hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s)
{
	/////////////////////////////////////////////////////
	// Matrix Vector Multiplication, Q [16 x nb+1] * hh
	// hh contains two householder vectors, with offset 1
	/////////////////////////////////////////////////////
	int i;
	// Needed bit mask for floating point sign flip
	// careful here: reinterpret the integer sign-bit pattern as a float vector
	__m512 sign = (__m512)_mm512_set1_epi32(0x80000000);

	__m512 x1 = _mm512_load_ps(&q[ldq]);
//	__m512 x2 = _mm512_load_ps(&q[ldq+16]);
//	__m512 x3 = _mm512_load_ps(&q[ldq+32]);
//	__m512 x4 = _mm512_load_ps(&q[ldq+48]);


	__m512 h1 = _mm512_set1_ps(hh[ldh+1]);
	__m512 h2;

	__m512 q1 = _mm512_load_ps(q);
	__m512 y1 = _mm512_FMA_ps(x1, h1, q1);
//	__m512 q2 = _mm512_load_ps(&q[16]);
//	__m512 y2 = _mm512_FMA_ps(x2, h1, q2);
//	__m512 q3 = _mm512_load_ps(&q[32]);
//	__m512 y3 = _mm512_FMA_ps(x3, h1, q3);
//	__m512 q4 = _mm512_load_ps(&q[48]);
//	__m512 y4 = _mm512_FMA_ps(x4, h1, q4);

	for(i = 2; i < nb; i++)
	{
		h1 = _mm512_set1_ps(hh[i-1]);
		h2 = _mm512_set1_ps(hh[ldh+i]);

		q1 = _mm512_load_ps(&q[i*ldq]);
		x1 = _mm512_FMA_ps(q1, h1, x1);
		y1 = _mm512_FMA_ps(q1, h2, y1);
//		q2 = _mm512_load_ps(&q[(i*ldq)+16]);
//		x2 = _mm512_FMA_ps(q2, h1, x2);
//		y2 = _mm512_FMA_ps(q2, h2, y2);
//		q3 = _mm512_load_ps(&q[(i*ldq)+32]);
//		x3 = _mm512_FMA_ps(q3, h1, x3);
//		y3 = _mm512_FMA_ps(q3, h2, y3);
//		q4 = _mm512_load_ps(&q[(i*ldq)+48]);
//		x4 = _mm512_FMA_ps(q4, h1, x4);
//		y4 = _mm512_FMA_ps(q4, h2, y4);

	}

	h1 = _mm512_set1_ps(hh[nb-1]);

	q1 = _mm512_load_ps(&q[nb*ldq]);
	x1 = _mm512_FMA_ps(q1, h1, x1);
//	q2 = _mm512_load_ps(&q[(nb*ldq)+16]);
//	x2 = _mm512_FMA_ps(q2, h1, x2);
//	q3 = _mm512_load_ps(&q[(nb*ldq)+32]);
//	x3 = _mm512_FMA_ps(q3, h1, x3);
//	q4 = _mm512_load_ps(&q[(nb*ldq)+48]);
//	x4 = _mm512_FMA_ps(q4, h1, x4);


	/////////////////////////////////////////////////////
	// Rank-2 update of Q [16 x nb+1]
	/////////////////////////////////////////////////////

	__m512 tau1 = _mm512_set1_ps(hh[0]);
	__m512 tau2 = _mm512_set1_ps(hh[ldh]);
	__m512 vs = _mm512_set1_ps(s);

	h1 = (__m512) _mm512_xor_epi32((__m512i) tau1, (__m512i) sign);
	x1 = _mm512_mul_ps(x1, h1);
//	x2 = _mm512_mul_ps(x2, h1);
//	x3 = _mm512_mul_ps(x3, h1);
//	x4 = _mm512_mul_ps(x4, h1);

	h1 = (__m512) _mm512_xor_epi32((__m512i) tau2, (__m512i) sign);
	h2 = _mm512_mul_ps(h1, vs);
	y1 = _mm512_FMA_ps(y1, h1, _mm512_mul_ps(x1,h2));
//	y2 = _mm512_FMA_ps(y2, h1, _mm512_mul_ps(x2,h2));
//	y3 = _mm512_FMA_ps(y3, h1, _mm512_mul_ps(x3,h2));
//	y4 = _mm512_FMA_ps(y4, h1, _mm512_mul_ps(x4,h2));

	q1 = _mm512_load_ps(q);
	q1 = _mm512_add_ps(q1, y1);
	_mm512_store_ps(q,q1);
//	q2 = _mm512_load_ps(&q[16]);
//	q2 = _mm512_add_ps(q2, y2);
//	_mm512_store_ps(&q[16],q2);
//	q3 = _mm512_load_ps(&q[32]);
//	q3 = _mm512_add_ps(q3, y3);
//	_mm512_store_ps(&q[32],q3);
//	q4 = _mm512_load_ps(&q[48]);
//	q4 = _mm512_add_ps(q4, y4);
//	_mm512_store_ps(&q[48],q4);

	h2 = _mm512_set1_ps(hh[ldh+1]);

	q1 = _mm512_load_ps(&q[ldq]);
	q1 = _mm512_add_ps(q1, _mm512_FMA_ps(y1, h2, x1));
	_mm512_store_ps(&q[ldq],q1);
//	q2 = _mm512_load_ps(&q[ldq+16]);
//	q2 = _mm512_add_ps(q2, _mm512_FMA_ps(y2, h2, x2));
//	_mm512_store_ps(&q[ldq+16],q2);
//	q3 = _mm512_load_ps(&q[ldq+32]);
//	q3 = _mm512_add_ps(q3, _mm512_FMA_ps(y3, h2, x3));
//	_mm512_store_ps(&q[ldq+32],q3);
//	q4 = _mm512_load_ps(&q[ldq+48]);
//	q4 = _mm512_add_ps(q4, _mm512_FMA_ps(y4, h2, x4));
//	_mm512_store_ps(&q[ldq+48],q4);

	for (i = 2; i < nb; i++)
	{
		h1 = _mm512_set1_ps(hh[i-1]);
		h2 = _mm512_set1_ps(hh[ldh+i]);

		q1 = _mm512_load_ps(&q[i*ldq]);
		q1 = _mm512_FMA_ps(x1, h1, q1);
		q1 = _mm512_FMA_ps(y1, h2, q1);
		_mm512_store_ps(&q[i*ldq],q1);
//		q2 = _mm512_load_ps(&q[(i*ldq)+16]);
//		q2 = _mm512_FMA_ps(x2, h1, q2);
//		q2 = _mm512_FMA_ps(y2, h2, q2);
//		_mm512_store_ps(&q[(i*ldq)+16],q2);
//		q3 = _mm512_load_ps(&q[(i*ldq)+32]);
//		q3 = _mm512_FMA_ps(x3, h1, q3);
//		q3 = _mm512_FMA_ps(y3, h2, q3);
//		_mm512_store_ps(&q[(i*ldq)+32],q3);
//		q4 = _mm512_load_ps(&q[(i*ldq)+48]);
//		q4 = _mm512_FMA_ps(x4, h1, q4);
//		q4 = _mm512_FMA_ps(y4, h2, q4);
//		_mm512_store_ps(&q[(i*ldq)+48],q4);

	}

	h1 = _mm512_set1_ps(hh[nb-1]);

	q1 = _mm512_load_ps(&q[nb*ldq]);
	q1 = _mm512_FMA_ps(x1, h1, q1);
	_mm512_store_ps(&q[nb*ldq],q1);
//	q2 = _mm512_load_ps(&q[(nb*ldq)+16]);
//	q2 = _mm512_FMA_ps(x2, h1, q2);
//	_mm512_store_ps(&q[(nb*ldq)+16],q2);
//	q3 = _mm512_load_ps(&q[(nb*ldq)+32]);
//	q3 = _mm512_FMA_ps(x3, h1, q3);
//	_mm512_store_ps(&q[(nb*ldq)+32],q3);
//	q4 = _mm512_load_ps(&q[(nb*ldq)+48]);
//	q4 = _mm512_FMA_ps(x4, h1, q4);
//	_mm512_store_ps(&q[(nb*ldq)+48],q4);

}