4141
4242#if defined(__gfx12__ )
4343#define WMMA_DATA_WIDTH 8
44- typedef _Float16 frag_type __attribute__( ( ext_vector_type ( 8 ) ) );
44+ typedef __fp16 frag_type __attribute__( ( ext_vector_type ( 8 ) ) );
4545typedef float frag_type_c __attribute__( ( ext_vector_type ( 8 ) ) );
46+ typedef __fp16 half_2 __attribute__( ( ext_vector_type ( 2 ) ) );
4647#else
4748#define WMMA_DATA_WIDTH 16
48- typedef _Float16 frag_type __attribute__( ( ext_vector_type ( 16 ) ) );
49- typedef _Float16 frag_type_c __attribute__( ( ext_vector_type ( 16 ) ) );
49+ typedef __fp16 frag_type __attribute__( ( ext_vector_type ( 16 ) ) );
50+ typedef __fp16 frag_type_c __attribute__( ( ext_vector_type ( 16 ) ) );
5051#endif
5152
52- extern "C" __global__ void wmma_matmul ( __half * a , __half * b , __half * c )
53+ __device__ half_2 packFp32s ( float a , float b ) { return __builtin_amdgcn_cvt_pkrtz ( a , b ); }
54+
55+ extern "C" __global__ void wmma_matmul ( __fp16 * a , __fp16 * b , __fp16 * c )
5356{
5457 const int gIdx = blockIdx .x * blockDim .x + threadIdx .x ;
5558 const int lIdx = threadIdx .x ;
@@ -65,11 +68,25 @@ extern "C" __global__ void wmma_matmul( __half* a, __half* b, __half* c )
6568 const int lane = lIdx % 16 ;
6669 const int laneGroup = lIdx / 16 ;
6770#if defined( __gfx12__ )
71+ #if 1
6872 for ( int ele = 0 ; ele < WMMA_DATA_WIDTH ; ++ ele )
6973 {
70- b_frag [ele ] = b [16 * (ele + laneGroup * WMMA_DATA_WIDTH ) + lane ];
71- a_frag [ele ] = a [16 * lane + (ele + laneGroup * WMMA_DATA_WIDTH )];
74+ b_frag [ele ] = b [16 * ( ele + laneGroup * WMMA_DATA_WIDTH ) + lane ];
75+ a_frag [ele ] = a [16 * lane + ( ele + laneGroup * WMMA_DATA_WIDTH )];
76+ }
77+ #else
78+ {//with __builtin_amdgcn_cvt_pkrtz
79+ half_2 * a_ptr = reinterpret_cast < half_2 * > ( & a_frag );
80+ half_2 * b_ptr = reinterpret_cast < half_2 * > ( & b_frag );
81+ for ( int ele = 0 ; ele < WMMA_DATA_WIDTH / 2 ; ++ ele )
82+ {
83+ const int e0 = ele * 2 + 0 ;
84+ const int e1 = ele * 2 + 1 ;
85+ b_ptr [ele ] = packFp32s ( b [16 * ( e0 + laneGroup * WMMA_DATA_WIDTH ) + lane ], b [16 * ( e1 + laneGroup * WMMA_DATA_WIDTH ) + lane ] );
86+ a_ptr [ele ] = packFp32s ( a [16 * lane + ( e0 + laneGroup * WMMA_DATA_WIDTH )], a [16 * lane + ( e1 + laneGroup * WMMA_DATA_WIDTH )] );
87+ }
7288 }
89+ #endif
7390#else
7491 // lane is (0-31) mod 16 instead of 0-31 due to matrix replication in RDNA3
7592 for ( int ele = 0 ; ele < WMMA_DATA_WIDTH ; ++ ele )
0 commit comments