Skip to content

Commit b8a59bd

Browse files
authored
Added sgemm sup kernels for zen architecture (#901)
Details: - Added optimized sgemm sup kernels for the AMD Zen architecture. - Added support for masked load/store instructions (`vmaskmovps` / `VMASKMOVPS`), used by the edge and fringe kernels. - Improves performance by reducing branch overhead and enhancing cache behavior.
1 parent eac5ee9 commit b8a59bd

12 files changed

+28515
-90
lines changed

config/zen3/bli_cntx_init_zen3.c

Lines changed: 9 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,6 @@ void bli_cntx_init_zen3( cntx_t* cntx )
6262
BLIS_GEMMTRSM_U_UKR, BLIS_FLOAT, bli_sgemmtrsm_u_haswell_asm_6x16,
6363
BLIS_GEMMTRSM_U_UKR, BLIS_DOUBLE, bli_dgemmtrsm_u_haswell_asm_6x8,
6464

65-
// gemmsup
66-
#if 0
67-
// AMD: This should be enabled in the PR which has added these kernels
68-
// Update the context with optimized small/unpacked gemm kernels.
6965
BLIS_GEMMSUP_RRR_UKR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m,
7066
BLIS_GEMMSUP_RRC_UKR, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8m,
7167
BLIS_GEMMSUP_RCR_UKR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m,
@@ -82,37 +78,6 @@ void bli_cntx_init_zen3( cntx_t* cntx )
8278
BLIS_GEMMSUP_CRC_UKR, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x16n,
8379
BLIS_GEMMSUP_CCR_UKR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n,
8480
BLIS_GEMMSUP_CCC_UKR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n,
85-
BLIS_GEMMSUP_RRR_UKR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m,
86-
BLIS_GEMMSUP_RCR_UKR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m,
87-
BLIS_GEMMSUP_CRR_UKR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m,
88-
BLIS_GEMMSUP_RCC_UKR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n,
89-
BLIS_GEMMSUP_CCR_UKR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n,
90-
BLIS_GEMMSUP_CCC_UKR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n,
91-
BLIS_GEMMSUP_RRR_UKR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m,
92-
BLIS_GEMMSUP_RCR_UKR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m,
93-
BLIS_GEMMSUP_CRR_UKR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m,
94-
BLIS_GEMMSUP_RCC_UKR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n,
95-
BLIS_GEMMSUP_CCR_UKR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n,
96-
BLIS_GEMMSUP_CCC_UKR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n,
97-
#else
98-
BLIS_GEMMSUP_RRR_UKR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m,
99-
BLIS_GEMMSUP_RRC_UKR, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8m,
100-
BLIS_GEMMSUP_RCR_UKR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m,
101-
BLIS_GEMMSUP_RCC_UKR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n,
102-
BLIS_GEMMSUP_CRR_UKR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m,
103-
BLIS_GEMMSUP_CRC_UKR, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8n,
104-
BLIS_GEMMSUP_CCR_UKR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n,
105-
BLIS_GEMMSUP_CCC_UKR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n,
106-
107-
BLIS_GEMMSUP_RRR_UKR, BLIS_FLOAT, bli_sgemmsup_rv_haswell_asm_6x16m,
108-
BLIS_GEMMSUP_RRC_UKR, BLIS_FLOAT, bli_sgemmsup_rd_haswell_asm_6x16m,
109-
BLIS_GEMMSUP_RCR_UKR, BLIS_FLOAT, bli_sgemmsup_rv_haswell_asm_6x16m,
110-
BLIS_GEMMSUP_RCC_UKR, BLIS_FLOAT, bli_sgemmsup_rv_haswell_asm_6x16n,
111-
BLIS_GEMMSUP_CRR_UKR, BLIS_FLOAT, bli_sgemmsup_rv_haswell_asm_6x16m,
112-
BLIS_GEMMSUP_CRC_UKR, BLIS_FLOAT, bli_sgemmsup_rd_haswell_asm_6x16n,
113-
BLIS_GEMMSUP_CCR_UKR, BLIS_FLOAT, bli_sgemmsup_rv_haswell_asm_6x16n,
114-
BLIS_GEMMSUP_CCC_UKR, BLIS_FLOAT, bli_sgemmsup_rv_haswell_asm_6x16n,
115-
#endif
11681

11782
// packm
11883
#if 0
@@ -232,28 +197,28 @@ void bli_cntx_init_zen3( cntx_t* cntx )
232197
// s d c z
233198
bli_blksz_init_easy( &blkszs[ BLIS_MR ], 6, 6, 3, 3 );
234199
bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 );
235-
bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 72, 36 );
236-
bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 256, 256, 256 );
237-
bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 4080, 4080, 4080 );
200+
bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 144, 18 );
201+
bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 256, 256, 566 );
202+
bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 4080, 4080, 256 );
238203

239204
bli_blksz_init_easy( &blkszs[ BLIS_AF ], 5, 5, -1, -1 );
240205
bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, -1, -1 );
241206

242207
// Initialize sup thresholds with architecture-appropriate values.
243208
// s d c z
244-
bli_blksz_init_easy( &blkszs[ BLIS_MT ], 512, 256, -1, -1 );
245-
bli_blksz_init_easy( &blkszs[ BLIS_NT ], 200, 256, -1, -1 );
246-
bli_blksz_init_easy( &blkszs[ BLIS_KT ], 240, 220, -1, -1 );
209+
bli_blksz_init_easy( &blkszs[ BLIS_MT ], 512, 256, 380, 110 );
210+
bli_blksz_init_easy( &blkszs[ BLIS_NT ], 200, 256, 256, 128 );
211+
bli_blksz_init_easy( &blkszs[ BLIS_KT ], 240, 220, 220, 110 );
247212

248213
// Initialize level-3 sup blocksize objects with architecture-specific
249214
// values.
250215
// s d c z
251216
bli_blksz_init ( &blkszs[ BLIS_MR_SUP ], 6, 6, 3, 3,
252217
9, 9, 3, 3 );
253218
bli_blksz_init_easy( &blkszs[ BLIS_NR_SUP ], 16, 8, 8, 4 );
254-
bli_blksz_init_easy( &blkszs[ BLIS_MC_SUP ], 144, 72, 72, 36 );
255-
bli_blksz_init_easy( &blkszs[ BLIS_KC_SUP ], 512, 256, 128, 64 );
256-
bli_blksz_init_easy( &blkszs[ BLIS_NC_SUP ], 8160, 4080, 2040, 1020 );
219+
bli_blksz_init_easy( &blkszs[ BLIS_MC_SUP ], 144, 72, 144, 24 );
220+
bli_blksz_init_easy( &blkszs[ BLIS_KC_SUP ], 256, 492, 256, 512 );
221+
bli_blksz_init_easy( &blkszs[ BLIS_NC_SUP ], 4080, 1600, 4080, 1536 );
257222

258223
// Update the context with the current architecture's register and cache
259224
// blocksizes (and multiples) for native execution.
@@ -289,17 +254,5 @@ void bli_cntx_init_zen3( cntx_t* cntx )
289254

290255
// -------------------------------------------------------------------------
291256

292-
#if 0
293-
// Initialize the context with the sup handlers.
294-
bli_cntx_set_l3_sup_handlers
295-
(
296-
cntx,
297-
298-
BLIS_GEMM, bli_gemmsup_ref,
299-
//BLIS_GEMMT, bli_gemmtsup_ref,
300-
301-
BLIS_VA_END
302-
);
303-
#endif
304257
}
305258

frame/include/bli_x86_asm_macros.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -915,7 +915,7 @@
915915
#define VCOMISD(_0, _1) INSTR_(vcomisd, _0, _1)
916916

917917
#define VMASKMOVPD(_0, _1, _2) INSTR_(vmaskmovpd, _0, _1, _2)
918-
918+
#define VMASKMOVPS(_0, _1, _2) INSTR_(vmaskmovps, _0, _1, _2)
919919
#define VFMADD132SS(_0, _1, _2) INSTR_(vfmadd132ss, _0, _1, _2)
920920
#define VFMADD213SS(_0, _1, _2) INSTR_(vfmadd213ss, _0, _1, _2)
921921
#define VFMADD231SS(_0, _1, _2) INSTR_(vfmadd231ss, _0, _1, _2)
@@ -1242,7 +1242,7 @@
12421242
#define vblendpd(_0, _1, _2, _3) VBLENDPD(_0, _1, _2, _3)
12431243
// Lowercase alias for VBLENDMPS. It previously forwarded to the misspelled
// name VBLENDMSD, unlike the sibling aliases (vblendpd -> VBLENDPD,
// vblendmpd -> VBLENDMPD), so any use of vblendmps() expanded to an
// undefined identifier.
#define vblendmps(_0, _1, _2) VBLENDMPS(_0, _1, _2)
12441244
#define vblendmpd(_0, _1, _2) VBLENDMPD(_0, _1, _2)
1245-
1245+
#define vmaskmovps(_0, _1, _2) VMASKMOVPS(_0, _1, _2)
12461246
#define vmaskmovpd(_0, _1, _2) VMASKMOVPD(_0, _1, _2)
12471247

12481248
// Prefetches

0 commit comments

Comments
 (0)