Skip to content

Commit 5926022

Browse files
32bit accumulation.
1 parent 0bd2133 commit 5926022

File tree

1 file changed

+7
-13
lines changed

1 file changed

+7
-13
lines changed

Test/WMMA/wmma_test_kernel.h

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,11 @@
4242
#if defined(__gfx12__)
4343
#define WMMA_DATA_WIDTH 8
4444
typedef _Float16 frag_type __attribute__( ( ext_vector_type( 8 ) ) );
45+
typedef float frag_type_c __attribute__( ( ext_vector_type( 8 ) ) );
4546
#else
4647
#define WMMA_DATA_WIDTH 16
4748
typedef _Float16 frag_type __attribute__( ( ext_vector_type( 16 ) ) );
49+
typedef _Float16 frag_type_c __attribute__( ( ext_vector_type( 16 ) ) );
4850
#endif
4951

5052
extern "C" __global__ void wmma_matmul( __half* a, __half* b, __half* c )
@@ -58,38 +60,30 @@ extern "C" __global__ void wmma_matmul( __half* a, __half* b, __half* c )
5860
frag_type a_frag;
5961
frag_type b_frag;
6062
// initialize c fragment to 0
61-
frag_type c_frag = {};
63+
frag_type_c c_frag = {};
6264

63-
// lane is (0-31) mod 16 instead of 0-31 due to matrix replication in RDNA3
6465
const int lane = lIdx % 16;
6566
const int laneGroup = lIdx / 16;
6667
#if defined( __gfx12__ )
6768
for( int ele = 0; ele < WMMA_DATA_WIDTH; ++ele )
6869
{
6970
b_frag[ele] = b[16 * (ele+laneGroup * WMMA_DATA_WIDTH) + lane];
70-
}
71-
72-
for( int ele = 0; ele < WMMA_DATA_WIDTH; ++ele )
73-
{
74-
a_frag[ele] = a[16 * lane + ele+laneGroup * WMMA_DATA_WIDTH];
71+
a_frag[ele] = a[16 * lane + (ele+laneGroup * WMMA_DATA_WIDTH)];
7572
}
7673
#else
74+
// lane is (0-31) mod 16 instead of 0-31 due to matrix replication in RDNA3
7775
for( int ele = 0; ele < WMMA_DATA_WIDTH; ++ele )
7876
{
7977
b_frag[ele] = b[16 * ele + lane];
80-
}
81-
82-
for( int ele = 0; ele < WMMA_DATA_WIDTH; ++ele )
83-
{
8478
a_frag[ele] = a[16 * lane + ele];
8579
}
8680
#endif
8781
// call the WMMA compiler intrinsic
8882
// more details available in the RDNA3 ISA guide - https://developer.amd.com/wp-content/resources/RDNA3_Shader_ISA_December2022.pdf
83+
// more details available in the RDNA4 ISA guide - https://www.amd.com/content/dam/amd/en/documents/radeon-tech-docs/instruction-set-architectures/rdna4-instruction-set-architecture.pdf
8984
// the last parameter is called "OPSEL" which decides which half of the VGPRs of c_frag the results are stored into
90-
// this will only compile on RDNA3
9185
#if defined( __gfx12__ )
92-
c_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12( a_frag, b_frag, c_frag );
86+
c_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12( a_frag, b_frag, c_frag );
9387
#else
9488
c_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32( a_frag, b_frag, c_frag, false );
9589
#endif

0 commit comments

Comments
 (0)