Skip to content

Commit 5857369

Browse files
Clean up.
1 parent e9f427d commit 5857369

File tree

2 files changed

+18
-6
lines changed

2 files changed

+18
-6
lines changed

Test/WMMA/main.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ int main( int argc, char** argv )
131131
{
132132
for (int j = 0; j < 16; ++j)
133133
{
134-
printf("%3.1f ", (float)c[i * 16 + j]);
134+
printf("%3.0f ", (float)c[i * 16 + j]);
135135
if( c[i * 16 + j] != d[i * 16 + j] )
136136
{
137137
pass = false;

Test/WMMA/wmma_test_kernel.h

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,21 @@
2525

2626
// Use frag_type as an alias of the internal clang vector type of 16 fp16 values
2727

28-
//#define ENABLE_TEST 1
2928

30-
#if defined(ENABLE_TEST)
29+
#if __gfx1030__ || __gfx1031__ || __gfx1032__ || __gfx1033__ || __gfx1034__ || __gfx1035__ || __gfx1036__
30+
#define __gfx10__
31+
#endif
32+
33+
#if __gfx1100__ || __gfx1101__ || __gfx1102__ || __gfx1103__ || __gfx1150__ || __gfx1151__
34+
#define __gfx11__
35+
#endif
36+
37+
#if __gfx1200__ || __gfx1201__
38+
#define __gfx12__
39+
#endif
40+
41+
42+
#if defined(__gfx12__)
3143
#define WMMA_DATA_WIDTH 8
3244
typedef _Float16 frag_type __attribute__( ( ext_vector_type( 8 ) ) );
3345
#else
@@ -51,7 +63,7 @@ extern "C" __global__ void wmma_matmul( __half* a, __half* b, __half* c )
5163
// lane is (0-31) mod 16 instead of 0-31 due to matrix replication in RDNA3
5264
const int lane = lIdx % 16;
5365
const int laneGroup = lIdx / 16;
54-
#if defined( ENABLE_TEST )
66+
#if defined( __gfx12__ )
5567
for( int ele = 0; ele < WMMA_DATA_WIDTH; ++ele )
5668
{
5769
b_frag[ele] = b[16 * (ele+laneGroup * WMMA_DATA_WIDTH) + lane];
@@ -76,12 +88,12 @@ extern "C" __global__ void wmma_matmul( __half* a, __half* b, __half* c )
7688
// more details available in the RDNA3 ISA guide - https://developer.amd.com/wp-content/resources/RDNA3_Shader_ISA_December2022.pdf
7789
// the last parameter is called "OPSEL" which decides which half of the VGPRs of c_frag the results are stored into
7890
// this will only compile on RDNA3
79-
#if defined( ENABLE_TEST )
91+
#if defined( __gfx12__ )
8092
c_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12( a_frag, b_frag, c_frag );
8193
#else
8294
c_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32( a_frag, b_frag, c_frag, false );
8395
#endif
84-
#if defined( ENABLE_TEST )
96+
#if defined( __gfx12__ )
8597
for( int ele = 0; ele < WMMA_DATA_WIDTH; ++ele )
8698
{
8799
c[16 * ( ele + laneGroup * WMMA_DATA_WIDTH ) + lane] = c_frag[ele];

0 commit comments

Comments
 (0)