// Wave Matrix Multiply Accumulate (WMMA) using a HIP compiler intrinsic
// Does a matrix multiplication of two 16x16 fp16 matrices and stores the result in a 16x16 fp16 matrix

-// Use half16 as an alias of the internal clang vector type of 16 fp16 values
-typedef _Float16 half16 __attribute__( ( ext_vector_type( 16 ) ) );
+// Use frag_type as an alias of the internal clang vector type of WMMA_DATA_WIDTH fp16 values
+
+#if __gfx1030__ || __gfx1031__ || __gfx1032__ || __gfx1033__ || __gfx1034__ || __gfx1035__ || __gfx1036__
+#define __gfx10__
+#endif
+
+#if __gfx1100__ || __gfx1101__ || __gfx1102__ || __gfx1103__ || __gfx1150__ || __gfx1151__
+#define __gfx11__
+#endif
+
+#if __gfx1200__ || __gfx1201__
+#define __gfx12__
+#endif
+
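+// a 16x16 tile holds 256 elements: gfx11 replicates the a/b fragments across both halves of a
+// wave32, so each lane stores a full 16-element row/column, while gfx12 spreads the tile over
+// all 32 lanes with no replication, so each lane stores 256 / 32 = 8 elements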
+#if defined( __gfx12__ )
+#define WMMA_DATA_WIDTH 8
+typedef _Float16 frag_type __attribute__( ( ext_vector_type( 8 ) ) );
+#else
+#define WMMA_DATA_WIDTH 16
+typedef _Float16 frag_type __attribute__( ( ext_vector_type( 16 ) ) );
+#endif

extern "C" __global__ void wmma_matmul( __half* a, __half* b, __half* c )
{
@@ -34,30 +55,50 @@ extern "C" __global__ void wmma_matmul( __half* a, __half* b, __half* c )
    // a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a and b
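+    // (on gfx12 each fragment instead holds WMMA_DATA_WIDTH = 8 elements, packed into 4 VGPRs)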
    // a_frag will store one column of the 16x16 matrix tile
    // b_frag will store one row of the 16x16 matrix tile
-    half16 a_frag;
-    half16 b_frag;
+    frag_type a_frag;
+    frag_type b_frag;
    // initialize c fragment to 0
-    half16 c_frag = {};
+    frag_type c_frag = {};

    // lane is (0-31) mod 16 instead of 0-31 due to matrix replication in RDNA3
    const int lane = lIdx % 16;
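+    // gfx12 has no replication: all 32 lanes hold distinct data, and laneGroup (0 or 1)
+    // selects which half of a row/column this lane handles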
+    const int laneGroup = lIdx / 16;
+#if defined( __gfx12__ )
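+    // same b access pattern as the gfx11 path below, but each lane loads only the
+    // WMMA_DATA_WIDTH-element slice selected by laneGroup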
+    for( int ele = 0; ele < WMMA_DATA_WIDTH; ++ele )
+    {
+        b_frag[ele] = b[16 * ( ele + laneGroup * WMMA_DATA_WIDTH ) + lane];
+    }

-    for( int ele = 0; ele < 16; ++ele )
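+    // likewise for a: ele is offset by laneGroup * WMMA_DATA_WIDTH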
+    for( int ele = 0; ele < WMMA_DATA_WIDTH; ++ele )
+    {
+        a_frag[ele] = a[16 * lane + ele + laneGroup * WMMA_DATA_WIDTH];
+    }
+#else
+    for( int ele = 0; ele < WMMA_DATA_WIDTH; ++ele )
    {
        b_frag[ele] = b[16 * ele + lane];
    }

-    for( int ele = 0; ele < 16; ++ele )
+    for( int ele = 0; ele < WMMA_DATA_WIDTH; ++ele )
    {
        a_frag[ele] = a[16 * lane + ele];
    }
-
+#endif
    // call the WMMA compiler intrinsic
    // more details available in the RDNA3 ISA guide - https://developer.amd.com/wp-content/resources/RDNA3_Shader_ISA_December2022.pdf
    // the last parameter of the gfx11 intrinsic is called "OPSEL"; it decides which half of the VGPRs of c_frag the results are stored into
    // this will only compile on RDNA3 (gfx11) or RDNA4 (gfx12)
+#if defined( __gfx12__ )
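+    // the gfx12 intrinsic drops the trailing OPSEL argument; each lane's 8 results fill c_frag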
+    c_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12( a_frag, b_frag, c_frag );
+#else
    c_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32( a_frag, b_frag, c_frag, false );
-
+#endif
+#if defined( __gfx12__ )
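+    // each lane writes its 8 results back using the same indexing it used for the b loads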
+    for( int ele = 0; ele < WMMA_DATA_WIDTH; ++ele )
+    {
+        c[16 * ( ele + laneGroup * WMMA_DATA_WIDTH ) + lane] = c_frag[ele];
+    }
+#else
    for( int ele = 0; ele < 8; ++ele )
    {
        const int r = ele * 2 + ( lIdx / 16 );
@@ -66,4 +107,5 @@ extern "C" __global__ void wmma_matmul( __half* a, __half* b, __half* c )
        // if OPSEL was set to "true", the line above would instead be
        // c[16 * r + lane] = c_frag[ele * 2 + 1];
    }
+#endif
}