Skip to content

Commit ad8d6df

Browse files
Pack version.
1 parent 5926022 commit ad8d6df

File tree

1 file changed

+23
-6
lines changed

1 file changed

+23
-6
lines changed

Test/WMMA/wmma_test_kernel.h

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,18 @@
4141

4242
#if defined(__gfx12__)
4343
#define WMMA_DATA_WIDTH 8
44-
typedef _Float16 frag_type __attribute__( ( ext_vector_type( 8 ) ) );
44+
typedef __fp16 frag_type __attribute__( ( ext_vector_type( 8 ) ) );
4545
typedef float frag_type_c __attribute__( ( ext_vector_type( 8 ) ) );
46+
typedef __fp16 half_2 __attribute__( ( ext_vector_type( 2 ) ) );
4647
#else
4748
#define WMMA_DATA_WIDTH 16
48-
typedef _Float16 frag_type __attribute__( ( ext_vector_type( 16 ) ) );
49-
typedef _Float16 frag_type_c __attribute__( ( ext_vector_type( 16 ) ) );
49+
typedef __fp16 frag_type __attribute__( ( ext_vector_type( 16 ) ) );
50+
typedef __fp16 frag_type_c __attribute__( ( ext_vector_type( 16 ) ) );
5051
#endif
5152

52-
extern "C" __global__ void wmma_matmul( __half* a, __half* b, __half* c )
53+
__device__ half_2 packFp32s( float a, float b ) { return __builtin_amdgcn_cvt_pkrtz( a, b ); }
54+
55+
extern "C" __global__ void wmma_matmul( __fp16* a, __fp16* b, __fp16* c )
5356
{
5457
const int gIdx = blockIdx.x * blockDim.x + threadIdx.x;
5558
const int lIdx = threadIdx.x;
@@ -65,11 +68,25 @@ extern "C" __global__ void wmma_matmul( __half* a, __half* b, __half* c )
6568
const int lane = lIdx % 16;
6669
const int laneGroup = lIdx / 16;
6770
#if defined( __gfx12__ )
71+
#if 1
6872
for( int ele = 0; ele < WMMA_DATA_WIDTH; ++ele )
6973
{
70-
b_frag[ele] = b[16 * (ele+laneGroup * WMMA_DATA_WIDTH) + lane];
71-
a_frag[ele] = a[16 * lane + (ele+laneGroup * WMMA_DATA_WIDTH)];
74+
b_frag[ele] = b[16 * ( ele + laneGroup * WMMA_DATA_WIDTH ) + lane];
75+
a_frag[ele] = a[16 * lane + ( ele + laneGroup * WMMA_DATA_WIDTH )];
76+
}
77+
#else
78+
{//with __builtin_amdgcn_cvt_pkrtz
79+
half_2* a_ptr = reinterpret_cast<half_2*>( &a_frag );
80+
half_2* b_ptr = reinterpret_cast<half_2*>( &b_frag );
81+
for( int ele = 0; ele < WMMA_DATA_WIDTH / 2; ++ele )
82+
{
83+
const int e0 = ele * 2 + 0;
84+
const int e1 = ele * 2 + 1;
85+
b_ptr[ele] = packFp32s( b[16 * ( e0 + laneGroup * WMMA_DATA_WIDTH ) + lane], b[16 * ( e1 + laneGroup * WMMA_DATA_WIDTH ) + lane] );
86+
a_ptr[ele] = packFp32s( a[16 * lane + ( e0 + laneGroup * WMMA_DATA_WIDTH )], a[16 * lane + ( e1 + laneGroup * WMMA_DATA_WIDTH )] );
87+
}
7288
}
89+
#endif
7390
#else
7491
// lane is (0-31) mod 16 instead of 0-31 due to matrix replication in RDNA3
7592
for( int ele = 0; ele < WMMA_DATA_WIDTH; ++ele )

0 commit comments

Comments
 (0)