2525
2626// Use frag_type as an alias of the internal clang vector type of 16 fp16 values (8 fp16 values when building for gfx12, see WMMA_DATA_WIDTH)
2727
28- //#define ENABLE_TEST 1
2928
30- #if defined(ENABLE_TEST )
29+ #if __gfx1030__ || __gfx1031__ || __gfx1032__ || __gfx1033__ || __gfx1034__ || __gfx1035__ || __gfx1036__
30+ #define __gfx10__
31+ #endif
32+
33+ #if __gfx1100__ || __gfx1101__ || __gfx1102__ || __gfx1103__ || __gfx1150__ || __gfx1151__
34+ #define __gfx11__
35+ #endif
36+
37+ #if __gfx1200__ || __gfx1201__
38+ #define __gfx12__
39+ #endif
40+
41+
42+ #if defined(__gfx12__ )
3143#define WMMA_DATA_WIDTH 8
3244typedef _Float16 frag_type __attribute__( ( ext_vector_type ( 8 ) ) );
3345#else
@@ -51,7 +63,7 @@ extern "C" __global__ void wmma_matmul( __half* a, __half* b, __half* c )
5163 // lane is lIdx mod 16 (i.e. 0-15) rather than the full 0-31 lane index, due to matrix replication in RDNA3
5264 const int lane = lIdx % 16 ;
5365 const int laneGroup = lIdx / 16 ;
54- #if defined( ENABLE_TEST )
66+ #if defined( __gfx12__ )
5567 for ( int ele = 0 ; ele < WMMA_DATA_WIDTH ; ++ ele )
5668 {
5769 b_frag [ele ] = b [16 * (ele + laneGroup * WMMA_DATA_WIDTH ) + lane ];
@@ -76,12 +88,12 @@ extern "C" __global__ void wmma_matmul( __half* a, __half* b, __half* c )
7688 // more details available in the RDNA3 ISA guide - https://developer.amd.com/wp-content/resources/RDNA3_Shader_ISA_December2022.pdf
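7789 // on the non-gfx12 path the last parameter is called "OPSEL", which decides which half of the VGPRs of c_frag the results are stored into; the gfx12 builtin below is called without an OPSEL argument
7890 // this will only compile for targets that provide the selected WMMA builtin (the _gfx12 variant when __gfx12__ is defined, the RDNA3 variant otherwise)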
79- #if defined( ENABLE_TEST )
91+ #if defined( __gfx12__ )
8092 c_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12 ( a_frag , b_frag , c_frag );
8193#else
8294 c_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32 ( a_frag , b_frag , c_frag , false );
8395#endif
84- #if defined( ENABLE_TEST )
96+ #if defined( __gfx12__ )
8597 for ( int ele = 0 ; ele < WMMA_DATA_WIDTH ; ++ ele )
8698 {
8799 c [16 * ( ele + laneGroup * WMMA_DATA_WIDTH ) + lane ] = c_frag [ele ];
0 commit comments