Commit 64f2662

Merge pull request #105 from takahiroharada/feature/ORO-0-wmma-update
Wmma update
2 parents: ef26098 + 5857369

2 files changed: +72 -15 lines changed


Test/WMMA/main.cpp

Lines changed: 21 additions & 6 deletions
@@ -94,8 +94,20 @@ int main( int argc, char** argv )
     {
         for (int j = 0; j < 16; ++j)
         {
-            a[i * 16 + j] = (__half)1.f;
-            b[i * 16 + j] = (__half)1.f;
+            a[i * 16 + j] = ( i < 8 && j < 8 ) ? (__half)0.f : (__half)1.f;
+            b[i * 16 + j] = ( i < 8 ) ? (__half)1.f : (__half)2.f;
+            if( j > 8 ) b[i * 16 + j] *= 2.f;
+        }
+    }
+
+    __half d[16 * 16] = {};
+    for( int i = 0; i < 16; ++i )
+    {
+        for( int j = 0; j < 16; ++j )
+        {
+            __half& dst = d[i * 16 + j];
+            dst = 0.f;
+            for( int k = 0; k < 16; k++ ) dst += a[i * 16 + k] * b[k * 16 + j];
         }
     }
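Editor's note (my arithmetic, not stated in the commit): with this initialization, a is 0 on its top-left 8x8 block and 1 elsewhere, while b is 1 or 2 depending on the row half and is doubled for columns j > 8. The reference product d is therefore constant on each output block: rows 0-7 come out to 16 for columns 0-8 and 32 for columns 9-15, and rows 8-15 come out to 24 and 48. All of these values, and every partial sum along the way, are small integers exactly representable in fp16, which is what makes the exact equality check added below well defined.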

@@ -114,19 +126,22 @@ int main( int argc, char** argv )
     oroFree((oroDeviceptr)c_gpu);
 
     printf( "Output matrix:\n" );
+    bool pass = true;
     for (int i = 0; i < 16; ++i)
     {
         for (int j = 0; j < 16; ++j)
         {
-            printf("%f ", (float)c[i * 16 + j]);
+            printf("%3.0f ", (float)c[i * 16 + j]);
+            if( c[i * 16 + j] != d[i * 16 + j] )
+            {
+                pass = false;
+            }
         }
         printf("\n");
     }
-    printf( "Done!\n" );
+    printf( pass ? "Pass!\n" : "Failed!\n" );
     e = oroCtxDestroy( ctx );
 
-
-
     if ( testErrorFlag )
         return OROCHI_TEST_RETCODE__ERROR;
     return OROCHI_TEST_RETCODE__SUCCESS;
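The same verification pattern, pulled out as a self-contained host-side sketch for readers skimming the diff. The helper name checkAgainstReference is illustrative (it does not exist in the repo), and float is used instead of __half so the snippet compiles without HIP headers; the actual test compares the fp16 values directly, as shown above.

// Sketch only: compare a GPU result c against a naive n x n reference product of a and b.
static bool checkAgainstReference( const float* a, const float* b, const float* c, int n )
{
    bool pass = true;
    for( int i = 0; i < n; ++i )
    {
        for( int j = 0; j < n; ++j )
        {
            float ref = 0.f;
            for( int k = 0; k < n; ++k ) ref += a[i * n + k] * b[k * n + j];
            if( c[i * n + j] != ref ) pass = false;
        }
    }
    return pass;
}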

Test/WMMA/wmma_test_kernel.h

Lines changed: 51 additions & 9 deletions
@@ -23,8 +23,29 @@
 // Wave Matrix Multiply Accumulate (WMMA) using HIP compiler intrinsic
 // Does a matrix multiplication of two 16x16, fp16 matrices, and stores them into a 16x16 fp16 result matrix
 
-// Use half16 as an alias of the internal clang vector type of 16 fp16 values
-typedef _Float16 half16 __attribute__( ( ext_vector_type( 16 ) ) );
+// Use frag_type as an alias of the internal clang vector type of 16 fp16 values
+
+
+#if __gfx1030__ || __gfx1031__ || __gfx1032__ || __gfx1033__ || __gfx1034__ || __gfx1035__ || __gfx1036__
+#define __gfx10__
+#endif
+
+#if __gfx1100__ || __gfx1101__ || __gfx1102__ || __gfx1103__ || __gfx1150__ || __gfx1151__
+#define __gfx11__
+#endif
+
+#if __gfx1200__ || __gfx1201__
+#define __gfx12__
+#endif
+
+
+#if defined(__gfx12__)
+#define WMMA_DATA_WIDTH 8
+typedef _Float16 frag_type __attribute__( ( ext_vector_type( 8 ) ) );
+#else
+#define WMMA_DATA_WIDTH 16
+typedef _Float16 frag_type __attribute__( ( ext_vector_type( 16 ) ) );
+#endif
 
 extern "C" __global__ void wmma_matmul( __half* a, __half* b, __half* c )
 {
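The __gfx10xx__/__gfx11xx__/__gfx12xx__ tokens tested above are architecture macros predefined by the HIP/Clang device compiler for the GPU being compiled, so the fragment width is resolved entirely at compile time. A minimal compile-time sanity check one could place next to the typedef (illustrative only, not part of the commit):

// Illustrative only: an ext_vector_type( N ) of _Float16 occupies N * sizeof( _Float16 ) bytes,
// so this asserts that WMMA_DATA_WIDTH and frag_type were selected consistently.
static_assert( sizeof( frag_type ) == WMMA_DATA_WIDTH * sizeof( _Float16 ),
    "frag_type must pack exactly WMMA_DATA_WIDTH fp16 values" );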
@@ -34,30 +55,50 @@ extern "C" __global__ void wmma_matmul( __half* a, __half* b, __half* c )
     // a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a and b
     // a_frag will store one column of the 16x16 matrix tile
     // b_frag will store one row of the 16x16 matrix tile
-    half16 a_frag;
-    half16 b_frag;
+    frag_type a_frag;
+    frag_type b_frag;
     // initialize c fragment to 0
-    half16 c_frag = {};
+    frag_type c_frag = {};
 
     // lane is (0-31) mod 16 instead of 0-31 due to matrix replication in RDNA3
     const int lane = lIdx % 16;
+    const int laneGroup = lIdx / 16;
+#if defined( __gfx12__ )
+    for( int ele = 0; ele < WMMA_DATA_WIDTH; ++ele )
+    {
+        b_frag[ele] = b[16 * (ele+laneGroup * WMMA_DATA_WIDTH) + lane];
+    }
 
-    for( int ele = 0; ele < 16; ++ele )
+    for( int ele = 0; ele < WMMA_DATA_WIDTH; ++ele )
+    {
+        a_frag[ele] = a[16 * lane + ele+laneGroup * WMMA_DATA_WIDTH];
+    }
+#else
+    for( int ele = 0; ele < WMMA_DATA_WIDTH; ++ele )
     {
         b_frag[ele] = b[16 * ele + lane];
     }
 
-    for( int ele = 0; ele < 16; ++ele )
+    for( int ele = 0; ele < WMMA_DATA_WIDTH; ++ele )
     {
         a_frag[ele] = a[16 * lane + ele];
     }
-
+#endif
     // call the WMMA compiler intrinsic
     // more details available in the RDNA3 ISA guide - https://developer.amd.com/wp-content/resources/RDNA3_Shader_ISA_December2022.pdf
     // the last parameter is called "OPSEL" which decides which half of the VGPRs of c_frag the results are stored into
     // this will only compile on RDNA3
+#if defined( __gfx12__ )
+    c_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12( a_frag, b_frag, c_frag );
+#else
     c_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32( a_frag, b_frag, c_frag, false );
-
+#endif
+#if defined( __gfx12__ )
+    for( int ele = 0; ele < WMMA_DATA_WIDTH; ++ele )
+    {
+        c[16 * ( ele + laneGroup * WMMA_DATA_WIDTH ) + lane] = c_frag[ele];
+    }
+#else
     for( int ele = 0; ele < 8; ++ele )
    {
         const int r = ele * 2 + ( lIdx / 16 );

@@ -66,4 +107,5 @@ extern "C" __global__ void wmma_matmul( __half* a, __half* b, __half* c )
         // if OPSEL was set to "true", the line above would instead be
         // c[16 * r + lane] = c_frag[ele*2 + 1];
     }
+#endif
 }
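The two store paths above encode how the 32 lanes of the wave own the 16x16 output: a lane always owns one column (lIdx % 16); on the non-gfx12 path it writes eight rows interleaved by half-wave (r = ele * 2 + lIdx / 16), while on the gfx12 path it writes eight consecutive rows per half-wave (ele + laneGroup * 8). A throwaway host program that prints both mappings, for illustration only (not part of the test):

#include <cstdio>

int main()
{
    for( int lIdx = 0; lIdx < 32; ++lIdx )
    {
        const int lane      = lIdx % 16; // output column owned by this lane
        const int laneGroup = lIdx / 16; // which half of the wave32 the lane belongs to

        printf( "lane %2d -> column %2d | non-gfx12 rows:", lIdx, lane );
        for( int ele = 0; ele < 8; ++ele )
            printf( " %2d", ele * 2 + laneGroup ); // same formula as the kernel's default store loop

        printf( " | gfx12 rows:" );
        for( int ele = 0; ele < 8; ++ele )
            printf( " %2d", ele + laneGroup * 8 ); // same formula as the kernel's gfx12 store loop
        printf( "\n" );
    }
    return 0;
}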
