#ifndef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
#define EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H

#pragma GCC target("cpu=power10,htm")

#if !__has_builtin(__builtin_vsx_assemble_pair)
#define __builtin_vsx_assemble_pair __builtin_mma_assemble_pair
#endif

namespace Eigen {

namespace internal {
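// Clear a 512-bit MMA accumulator (__vector_quad) before a sequence of
// outer-product (ger) updates is accumulated into it.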
template<typename Scalar, typename Packet>
EIGEN_ALWAYS_INLINE void bsetzeroMMA(__vector_quad* acc)
{
  __builtin_mma_xxsetaccz(acc);
}
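// Disassemble an accumulator into four packets, load the matching block of the
// result matrix, update it as tRes += alpha * result (bscale) and store it back.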
template<typename DataMapper, typename Index, typename Packet, const Index accCols>
EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, const DataMapper& data, const Packet& alpha, __vector_quad* acc)
{
  PacketBlock<Packet, 4> result;
  __builtin_mma_disassemble_acc(&result.packet, acc);

  PacketBlock<Packet, 4> tRes;
  bload<DataMapper, Packet, Index, accCols, ColMajor, false, 4>(tRes, data, i, 0);

  bscale<Packet, 4>(tRes, result, alpha);

  data.template storePacketBlock<Packet, 4>(i, 0, tRes);
}
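// Complex analogue of storeAccumulator: the real and imaginary accumulators are
// disassembled, combined with the complex alpha (bscalec), interleaved back into
// complex packets (bcouple) and stored as two halves of the row strip.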
template<typename DataMapper, typename Index, typename Packet, typename Packetc, const Index accColsC>
EIGEN_ALWAYS_INLINE void storeComplexAccumulator(Index i, const DataMapper& data, const Packet& alphaReal, const Packet& alphaImag, __vector_quad* accReal, __vector_quad* accImag)
{
  PacketBlock<Packet, 4> resultReal, resultImag;
  __builtin_mma_disassemble_acc(&resultReal.packet, accReal);
  __builtin_mma_disassemble_acc(&resultImag.packet, accImag);

  PacketBlock<Packetc, 8> tRes;
  bload<DataMapper, Packetc, Index, accColsC, ColMajor, true, 4>(tRes, data, i, 0);

  PacketBlock<Packet, 4> taccReal, taccImag;
  bscalec<Packet, 4>(resultReal, resultImag, alphaReal, alphaImag, taccReal, taccImag);

  PacketBlock<Packetc, 4> acc1, acc2;
  bcouple<Packet, Packetc, 4>(taccReal, taccImag, tRes, acc1, acc2);

  data.template storePacketBlock<Packetc, 4>(i, 0, acc1);
  data.template storePacketBlock<Packetc, 4>(i + accColsC, 0, acc2);
}
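// Rank-1 (ger) update of one accumulator. This overload is the float32 path;
// NegativeAccumulate selects the variant that subtracts the outer product
// instead of adding it.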
template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const RhsPacket& a, const LhsPacket& b)
{
  if(NegativeAccumulate)
  {
    __builtin_mma_xvf32gernp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
  } else {
    __builtin_mma_xvf32gerpp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
  }
}
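// Double-precision overloads: the xvf64ger builtins consume the rhs operand as a
// __vector_pair (two 128-bit vectors), either reinterpreted from a
// PacketBlock<Packet2d,2> or passed through directly.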
template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const PacketBlock<Packet2d,2>& a, const Packet2d& b)
{
  __vector_pair* a0 = (__vector_pair *)(&a.packet[0]);
  if(NegativeAccumulate)
  {
    __builtin_mma_xvf64gernp(acc, *a0, (__vector unsigned char)b);
  } else {
    __builtin_mma_xvf64gerpp(acc, *a0, (__vector unsigned char)b);
  }
}
template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const __vector_pair& a, const Packet2d& b)
{
  if(NegativeAccumulate)
  {
    __builtin_mma_xvf64gernp(acc, (__vector_pair)a, (__vector unsigned char)b);
  } else {
    __builtin_mma_xvf64gerpp(acc, (__vector_pair)a, (__vector unsigned char)b);
  }
}
template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad*, const __vector_pair&, const Packet4f&)
{
  // Just for compilation; the (__vector_pair, Packet4f) combination is never taken at runtime.
}
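// Complex rank-1 update, split into real/imaginary parts. Up to four real ger
// updates are issued; the ones whose operands are identically zero (purely real
// lhs or rhs) are skipped, and conjugation flips signs via NegativeAccumulate.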
template<typename Scalar, typename Packet, typename RhsPacket, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void pgercMMA(__vector_quad* accReal, __vector_quad* accImag, const Packet& lhsV, const Packet& lhsVi, const RhsPacket& rhsV, const RhsPacket& rhsVi)
{
  pgerMMA<Packet, RhsPacket, false>(accReal, rhsV, lhsV);
  if(LhsIsReal) {
    pgerMMA<Packet, RhsPacket, ConjugateRhs>(accImag, rhsVi, lhsV);
    EIGEN_UNUSED_VARIABLE(lhsVi);
  } else {
    if(!RhsIsReal) {
      pgerMMA<Packet, RhsPacket, ConjugateLhs == ConjugateRhs>(accReal, rhsVi, lhsVi);
      pgerMMA<Packet, RhsPacket, ConjugateRhs>(accImag, rhsVi, lhsV);
    } else {
      EIGEN_UNUSED_VARIABLE(rhsVi);
    }
    pgerMMA<Packet, RhsPacket, ConjugateLhs>(accImag, rhsV, lhsVi);
  }
}
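// Load one packet of rhs values. For double the rhs is consumed as a pair of
// vectors, hence the specializations below for PacketBlock<Packet2d,2> and
// __vector_pair.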
template<typename Scalar, typename Packet>
EIGEN_ALWAYS_INLINE void ploadRhsMMA(const Scalar* rhs, Packet& rhsV)
{
  rhsV = ploadRhs<Scalar, Packet>(rhs);
}
template<>
EIGEN_ALWAYS_INLINE void ploadRhsMMA<double, PacketBlock<Packet2d, 2> >(const double* rhs, PacketBlock<Packet2d, 2>& rhsV)
{
  rhsV.packet[0] = ploadRhs<double, Packet2d>((const double *)((Packet2d *)rhs      ));
  rhsV.packet[1] = ploadRhs<double, Packet2d>((const double *)(((Packet2d *)rhs) + 1));
}
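// Build the rhs __vector_pair: clang provides __builtin_vsx_assemble_pair
// (aliased above if needed), while on GCC a single lxvp instruction loads the
// register pair directly.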
template<>
EIGEN_ALWAYS_INLINE void ploadRhsMMA<double, __vector_pair>(const double* rhs, __vector_pair& rhsV)
{
#if EIGEN_COMP_LLVM
  __builtin_vsx_assemble_pair(&rhsV,
    (__vector unsigned char)(ploadRhs<double, Packet2d>((const double *)(((Packet2d *)rhs) + 1))),
    (__vector unsigned char)(ploadRhs<double, Packet2d>((const double *)((Packet2d *)rhs      ))));
#else
  __asm__ ("lxvp %x0,%1" : "=wa" (rhsV) : "Y" (*rhs));
#endif
}
template<>
EIGEN_ALWAYS_INLINE void ploadRhsMMA(const float*, __vector_pair&)
{
  // Just for compilation; the float path never uses a __vector_pair rhs.
}

// PEEL_MMA loop factor.
#define PEEL_MMA 7
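// The micro-kernel below is generated with macros so the compiler sees a fully
// unrolled body: up to 8 row strips (MICRO_MMA_UNROLL), PEEL_MMA depth steps
// peeled per iteration (MICRO_MMA_TYPE_PEEL), and separate load/work/store
// steps. Names such as unroll_factor, lhs_ptr<i>, rhsV<peel> and accZero<i> are
// expected to exist in the function expanding the macros
// (gemm_unrolled_MMA_iteration).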
#define MICRO_MMA_UNROLL(func) \
  func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)
#define MICRO_MMA_LOAD_ONE(iter) \
  if (unroll_factor > iter) { \
    lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr##iter); \
    lhs_ptr##iter += accCols; \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhsV##iter); \
  }
#define MICRO_MMA_WORK_ONE(iter, type, peel) \
  if (unroll_factor > iter) { \
    pgerMMA<Packet, type, false>(&accZero##iter, rhsV##peel, lhsV##iter); \
  }
#define MICRO_MMA_TYPE_PEEL(func, func2, type, peel) \
  if (PEEL_MMA > peel) { \
    Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
    ploadRhsMMA<Scalar, type>(rhs_ptr + (accRows * peel), rhsV##peel); \
    MICRO_MMA_UNROLL(func2); \
    func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) \
    func(4,type,peel) func(5,type,peel) func(6,type,peel) func(7,type,peel) \
  } else { \
    EIGEN_UNUSED_VARIABLE(rhsV##peel); \
  }
#define MICRO_MMA_UNROLL_TYPE_PEEL(func, func2, type) \
  type rhsV0, rhsV1, rhsV2, rhsV3, rhsV4, rhsV5, rhsV6, rhsV7; \
  MICRO_MMA_TYPE_PEEL(func,func2,type,0); MICRO_MMA_TYPE_PEEL(func,func2,type,1); \
  MICRO_MMA_TYPE_PEEL(func,func2,type,2); MICRO_MMA_TYPE_PEEL(func,func2,type,3); \
  MICRO_MMA_TYPE_PEEL(func,func2,type,4); MICRO_MMA_TYPE_PEEL(func,func2,type,5); \
  MICRO_MMA_TYPE_PEEL(func,func2,type,6); MICRO_MMA_TYPE_PEEL(func,func2,type,7);
#define MICRO_MMA_UNROLL_TYPE_ONE(func, func2, type) \
  type rhsV0; \
  MICRO_MMA_TYPE_PEEL(func,func2,type,0);
#define MICRO_MMA_ONE_PEEL \
  if (sizeof(Scalar) == sizeof(float)) { \
    MICRO_MMA_UNROLL_TYPE_PEEL(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, RhsPacket); \
  } else { \
    MICRO_MMA_UNROLL_TYPE_PEEL(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, __vector_pair); \
  } \
  rhs_ptr += (accRows * PEEL_MMA);
#define MICRO_MMA_ONE \
  if (sizeof(Scalar) == sizeof(float)) { \
    MICRO_MMA_UNROLL_TYPE_ONE(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, RhsPacket); \
  } else { \
    MICRO_MMA_UNROLL_TYPE_ONE(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, __vector_pair); \
  } \
  rhs_ptr += accRows;
#define MICRO_MMA_DST_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    bsetzeroMMA<Scalar, Packet>(&accZero##iter); \
  } else { \
    EIGEN_UNUSED_VARIABLE(accZero##iter); \
  }

#define MICRO_MMA_DST_PTR MICRO_MMA_UNROLL(MICRO_MMA_DST_PTR_ONE)
#define MICRO_MMA_SRC_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols; \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \
  }

#define MICRO_MMA_SRC_PTR MICRO_MMA_UNROLL(MICRO_MMA_SRC_PTR_ONE)
#define MICRO_MMA_PREFETCH_ONE(iter) \
  if (unroll_factor > iter) { \
    EIGEN_POWER_PREFETCH(lhs_ptr##iter); \
  }

#define MICRO_MMA_PREFETCH MICRO_MMA_UNROLL(MICRO_MMA_PREFETCH_ONE)
#define MICRO_MMA_STORE_ONE(iter) \
  if (unroll_factor > iter) { \
    storeAccumulator<DataMapper, Index, Packet, accCols>(row + iter*accCols, res, pAlpha, &accZero##iter); \
  }

#define MICRO_MMA_STORE MICRO_MMA_UNROLL(MICRO_MMA_STORE_ONE)
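// One fully unrolled panel: processes unroll_factor*accCols rows of the current
// column block across the whole depth, with one MMA accumulator per
// accCols-wide row strip.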
template<int unroll_factor, typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols>
EIGEN_ALWAYS_INLINE void gemm_unrolled_MMA_iteration(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index& row,
  const Packet& pAlpha)
{
  const Scalar* rhs_ptr = rhs_base;
  const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, * lhs_ptr7 = NULL;
  __vector_quad accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;

  MICRO_MMA_SRC_PTR
  MICRO_MMA_DST_PTR

  Index k = 0;
  for(; k + PEEL_MMA <= depth; k+= PEEL_MMA)
  {
    EIGEN_POWER_PREFETCH(rhs_ptr);
    MICRO_MMA_PREFETCH
    MICRO_MMA_ONE_PEEL
  }
  for(; k < depth; k++)
  {
    MICRO_MMA_ONE
  }
  MICRO_MMA_STORE

  row += unroll_factor*accCols;
}
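// Process one block of accRows columns: run the widest unrolled iteration while
// enough rows remain, fall back to smaller unroll factors for the tail, and
// finally hand the last (rows % accCols) rows to the generic gemm_extra_row path.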
template<typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols>
EIGEN_ALWAYS_INLINE void gemmMMA_cols(
  const DataMapper& res,
  const Scalar* blockA,
  const Scalar* blockB,
  Index depth,
  Index strideA,
  Index offsetA,
  Index strideB,
  Index offsetB,
  Index col,
  Index rows,
  Index cols,
  Index remaining_rows,
  const Packet& pAlpha,
  const Packet& pMask)
{
  const DataMapper res3 = res.getSubMapper(0, col);

  const Scalar* rhs_base = blockB + col*strideB + accRows*offsetB;
  const Scalar* lhs_base = blockA + accCols*offsetA;
  Index row = 0;

#define MAX_MMA_UNROLL 7
  while(row + MAX_MMA_UNROLL*accCols <= rows) {
    gemm_unrolled_MMA_iteration<MAX_MMA_UNROLL, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
  }
  switch( (rows-row)/accCols ) {
#if MAX_MMA_UNROLL > 7
    case 7: gemm_unrolled_MMA_iteration<7, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); break;
#endif
#if MAX_MMA_UNROLL > 6
    case 6: gemm_unrolled_MMA_iteration<6, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); break;
#endif
#if MAX_MMA_UNROLL > 5
    case 5: gemm_unrolled_MMA_iteration<5, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); break;
#endif
#if MAX_MMA_UNROLL > 4
    case 4: gemm_unrolled_MMA_iteration<4, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); break;
#endif
#if MAX_MMA_UNROLL > 3
    case 3: gemm_unrolled_MMA_iteration<3, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); break;
#endif
#if MAX_MMA_UNROLL > 2
    case 2: gemm_unrolled_MMA_iteration<2, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); break;
#endif
#if MAX_MMA_UNROLL > 1
    case 1: gemm_unrolled_MMA_iteration<1, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); break;
#endif
    default: break;
  }
#undef MAX_MMA_UNROLL

  if(remaining_rows > 0)
  {
    gemm_extra_row<Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, blockA, rhs_base, depth, strideA, offsetA, row, col, rows, cols, remaining_rows, pAlpha, pMask);
  }
}
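// Entry point for the real-valued MMA GEMM: res += alpha * blockA * blockB over
// packed panels. The instantiation below is illustrative only (the actual call
// sites live in the Power GEBP kernel in MatrixProduct.h, which supplies the
// traits-derived template arguments); a float kernel would look roughly like
//   gemmMMA<float, Index, Packet4f, RhsPacket, DataMapper, accRows, accCols>(
//       res, blockA, blockB, rows, depth, cols, alpha,
//       strideA, strideB, offsetA, offsetB);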
template<typename Scalar, typename Index, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
  const Index remaining_rows = rows % accCols;

  if( strideA == -1 ) strideA = depth;
  if( strideB == -1 ) strideB = depth;

  const Packet pAlpha = pset1<Packet>(alpha);
  const Packet pMask  = bmask<Packet>((const int)(remaining_rows));

  Index col = 0;
  for(; col + accRows <= cols; col += accRows)
  {
    gemmMMA_cols<Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
  }

  gemm_extra_cols<Scalar, Packet, DataMapper, Index, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
}
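// Complex kernels. The packed panels store real and imaginary parts in separate
// blocks, so strides advance by 2 unless one operand is purely real
// (advanceRows/advanceCols); accColsC is the accumulator width counted in
// complex elements.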
#define accColsC (accCols / 2)
#define advanceRows ((LhsIsReal) ? 1 : 2)
#define advanceCols ((RhsIsReal) ? 1 : 2)

#define PEEL_COMPLEX_MMA 3
#define MICRO_COMPLEX_MMA_UNROLL(func) \
  func(0) func(1) func(2) func(3)
#define MICRO_COMPLEX_MMA_LOAD_ONE(iter) \
  if (unroll_factor > iter) { \
    lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr_real##iter); \
    if(!LhsIsReal) { \
      lhsVi##iter = ploadLhs<Scalar, Packet>(lhs_ptr_real##iter + imag_delta); \
    } else { \
      EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
    } \
    lhs_ptr_real##iter += accCols; \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhsV##iter); \
    EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
  }
#define MICRO_COMPLEX_MMA_WORK_ONE(iter, type, peel) \
  if (unroll_factor > iter) { \
    pgercMMA<Scalar, Packet, type, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \
  }
#define MICRO_COMPLEX_MMA_TYPE_PEEL(func, func2, type, peel) \
  if (PEEL_COMPLEX_MMA > peel) { \
    Packet lhsV0, lhsV1, lhsV2, lhsV3; \
    Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3; \
    ploadRhsMMA<Scalar, type>(rhs_ptr_real + (accRows * peel), rhsV##peel); \
    if(!RhsIsReal) { \
      ploadRhsMMA<Scalar, type>(rhs_ptr_imag + (accRows * peel), rhsVi##peel); \
    } else { \
      EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
    } \
    MICRO_COMPLEX_MMA_UNROLL(func2); \
    func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) \
  } else { \
    EIGEN_UNUSED_VARIABLE(rhsV##peel); \
    EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
  }
#define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(func, func2, type) \
  type rhsV0, rhsV1, rhsV2, rhsV3; \
  type rhsVi0, rhsVi1, rhsVi2, rhsVi3; \
  MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,0); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,1); \
  MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,2); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,3);
#define MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(func, func2, type) \
  type rhsV0, rhsVi0; \
  MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,0);
#define MICRO_COMPLEX_MMA_ONE_PEEL \
  if (sizeof(Scalar) == sizeof(float)) { \
    MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, RhsPacket); \
  } else { \
    MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, __vector_pair); \
  } \
  rhs_ptr_real += (accRows * PEEL_COMPLEX_MMA); \
  if(!RhsIsReal) rhs_ptr_imag += (accRows * PEEL_COMPLEX_MMA);
#define MICRO_COMPLEX_MMA_ONE \
  if (sizeof(Scalar) == sizeof(float)) { \
    MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, RhsPacket); \
  } else { \
    MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, __vector_pair); \
  } \
  rhs_ptr_real += accRows; \
  if(!RhsIsReal) rhs_ptr_imag += accRows;
#define MICRO_COMPLEX_MMA_DST_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    bsetzeroMMA<Scalar, Packet>(&accReal##iter); \
    bsetzeroMMA<Scalar, Packet>(&accImag##iter); \
  } else { \
    EIGEN_UNUSED_VARIABLE(accReal##iter); \
    EIGEN_UNUSED_VARIABLE(accImag##iter); \
  }

#define MICRO_COMPLEX_MMA_DST_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_DST_PTR_ONE)
#define MICRO_COMPLEX_MMA_SRC_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    lhs_ptr_real##iter = lhs_base + ( ((advanceRows*row)/accCols) + iter*advanceRows )*strideA*accCols; \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhs_ptr_real##iter); \
  }

#define MICRO_COMPLEX_MMA_SRC_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_SRC_PTR_ONE)
#define MICRO_COMPLEX_MMA_PREFETCH_ONE(iter) \
  if (unroll_factor > iter) { \
    EIGEN_POWER_PREFETCH(lhs_ptr_real##iter); \
  }

#define MICRO_COMPLEX_MMA_PREFETCH MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_PREFETCH_ONE)
#define MICRO_COMPLEX_MMA_STORE_ONE(iter) \
  if (unroll_factor > iter) { \
    storeComplexAccumulator<DataMapper, Index, Packet, Packetc, accColsC>(row + iter*accCols, res, pAlphaReal, pAlphaImag, &accReal##iter, &accImag##iter); \
  }

#define MICRO_COMPLEX_MMA_STORE MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_STORE_ONE)
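// Complex analogue of gemm_unrolled_MMA_iteration: two accumulators (real and
// imaginary) per row strip; the imaginary half of the packed lhs lives
// imag_delta elements after the real half.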
template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_MMA_iteration(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index strideB,
  Index& row,
  const Packet& pAlphaReal,
  const Packet& pAlphaImag)
{
  const Scalar* rhs_ptr_real = rhs_base;
  const Scalar* rhs_ptr_imag = NULL;
  const Index imag_delta = accCols*strideA;
  if(!RhsIsReal) {
    rhs_ptr_imag = rhs_base + accRows*strideB;
  } else {
    EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
  }
  const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_real1 = NULL;
  const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_real3 = NULL;
  __vector_quad accReal0, accImag0, accReal1, accImag1, accReal2, accImag2, accReal3, accImag3;

  MICRO_COMPLEX_MMA_SRC_PTR
  MICRO_COMPLEX_MMA_DST_PTR

  Index k = 0;
  for(; k + PEEL_COMPLEX_MMA <= depth; k+= PEEL_COMPLEX_MMA)
  {
    EIGEN_POWER_PREFETCH(rhs_ptr_real);
    if(!RhsIsReal) {
      EIGEN_POWER_PREFETCH(rhs_ptr_imag);
    }
    MICRO_COMPLEX_MMA_PREFETCH
    MICRO_COMPLEX_MMA_ONE_PEEL
  }
  for(; k < depth; k++)
  {
    MICRO_COMPLEX_MMA_ONE
  }
  MICRO_COMPLEX_MMA_STORE

  row += unroll_factor*accCols;
}
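// Complex analogue of gemmMMA_cols, with a maximum unroll of 4 row strips: each
// strip needs a real and an imaginary accumulator, so all 8 MMA accumulators
// are in use at that width.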
template<typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void gemmMMA_complex_cols(
  const DataMapper& res,
  const Scalar* blockA,
  const Scalar* blockB,
  Index depth,
  Index strideA,
  Index offsetA,
  Index strideB,
  Index offsetB,
  Index col,
  Index rows,
  Index cols,
  Index remaining_rows,
  const Packet& pAlphaReal,
  const Packet& pAlphaImag,
  const Packet& pMask)
{
  const DataMapper res3 = res.getSubMapper(0, col);

  const Scalar* rhs_base = blockB + advanceCols*col*strideB + accRows*offsetB;
  const Scalar* lhs_base = blockA + accCols*offsetA;
  Index row = 0;

#define MAX_COMPLEX_MMA_UNROLL 4
  while(row + MAX_COMPLEX_MMA_UNROLL*accCols <= rows) {
    gemm_complex_unrolled_MMA_iteration<MAX_COMPLEX_MMA_UNROLL, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
  }
  switch( (rows-row)/accCols ) {
#if MAX_COMPLEX_MMA_UNROLL > 4
    case 4: gemm_complex_unrolled_MMA_iteration<4, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag); break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 3
    case 3: gemm_complex_unrolled_MMA_iteration<3, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag); break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 2
    case 2: gemm_complex_unrolled_MMA_iteration<2, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag); break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 1
    case 1: gemm_complex_unrolled_MMA_iteration<1, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag); break;
#endif
    default: break;
  }
#undef MAX_COMPLEX_MMA_UNROLL

  if(remaining_rows > 0)
  {
    gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
  }
}
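// Entry point for the complex MMA GEMM: broadcasts the real and imaginary parts
// of alpha, reinterprets the packed complex blocks as arrays of the underlying
// real scalar, and walks the column blocks like gemmMMA.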
template<typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Index, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
void gemm_complexMMA(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
  const Index remaining_rows = rows % accCols;

  if( strideA == -1 ) strideA = depth;
  if( strideB == -1 ) strideB = depth;

  const Packet pAlphaReal = pset1<Packet>(alpha.real());
  const Packet pAlphaImag = pset1<Packet>(alpha.imag());
  const Packet pMask = bmask<Packet>((const int)(remaining_rows));

  const Scalar* blockA = (Scalar *) blockAc;
  const Scalar* blockB = (Scalar *) blockBc;

  Index col = 0;
  for(; col + accRows <= cols; col += accRows)
  {
    gemmMMA_complex_cols<Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
  }

  gemm_complex_extra_cols<Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
}

#undef accColsC
#undef advanceRows
#undef advanceCols

#pragma GCC reset_options
} // end namespace internal

} // end namespace Eigen

#endif // EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H