Skip to content

Commit f21724a

Browse files
authored
fixed sbrc global load and added more adhoc tests
1 parent ca3ba3b commit f21724a

File tree

2 files changed

+86
-27
lines changed

2 files changed

+86
-27
lines changed

clients/tests/accuracy_test_adhoc.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ std::vector<std::vector<size_t>> adhoc_sizes = {
3232

3333
// L1D_CC subplan of 3D_TRTRTR
3434
{4, 4, 8192},
35+
36+
// SBRC 192 with special param
37+
{192, 192, 192},
38+
{192, 84, 84},
3539
};
3640

3741
const static std::vector<std::vector<size_t>> stride_range = {{1}};

library/src/device/generator/stockham_gen_rc.h

Lines changed: 82 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,6 @@ struct StockhamKernelRC : public StockhamKernel
9494
stmts += Declaration{
9595
len_along_block,
9696
Ternary{Parens{sbrc_type == "SBRC_3D_FFT_TRANS_XY_Z"}, lengths[2], lengths[1]}};
97-
stmts += Declaration{
98-
len_along_plane,
99-
Ternary{Parens{sbrc_type == "SBRC_3D_FFT_TRANS_XY_Z"}, lengths[1], lengths[2]}};
10097
stmts += Declaration{
10198
stride_load_in,
10299
Ternary{Parens{sbrc_type == "SBRC_3D_FFT_TRANS_XY_Z"}, stride_in[2], stride_in[1]}};
@@ -105,18 +102,10 @@ struct StockhamKernelRC : public StockhamKernel
105102
Ternary{Parens{sbrc_type == "SBRC_3D_FFT_TRANS_XY_Z" || sbrc_type == "SBRC_2D"},
106103
stride_out[1],
107104
stride_out[2]}};
108-
stmts += Declaration{
109-
stride_plane_in,
110-
Ternary{Parens{sbrc_type == "SBRC_3D_FFT_TRANS_XY_Z"}, stride_in[1], stride_in[2]}};
111-
stmts += Declaration{
112-
stride_plane_out,
113-
Ternary{Parens{sbrc_type == "SBRC_3D_FFT_TRANS_XY_Z"}, stride_out[2], stride_out[1]}};
114105

115106
stmts += LineBreak{};
116107
stmts += Declaration{num_of_tiles_in_batch};
117108
stmts += Declaration{tile_index_in_plane};
118-
stmts += Declaration{plane_id};
119-
stmts += Declaration{tile_serial_in_batch};
120109

121110
stmts += LineBreak{};
122111
stmts
@@ -172,6 +161,18 @@ struct StockhamKernelRC : public StockhamKernel
172161
// " tile_index_in_plane means index of the tile in that xy-plane",
173162
// " plane_id means the index of current xy-plane (inwards along Z-axis)"};
174163

164+
offset_3d += Declaration{
165+
len_along_plane,
166+
Ternary{Parens{sbrc_type == "SBRC_3D_FFT_TRANS_XY_Z"}, lengths[1], lengths[2]}};
167+
offset_3d += Declaration{
168+
stride_plane_in,
169+
Ternary{Parens{sbrc_type == "SBRC_3D_FFT_TRANS_XY_Z"}, stride_in[1], stride_in[2]}};
170+
offset_3d += Declaration{
171+
stride_plane_out,
172+
Ternary{Parens{sbrc_type == "SBRC_3D_FFT_TRANS_XY_Z"}, stride_out[2], stride_out[1]}};
173+
stmts += Declaration{plane_id};
174+
stmts += Declaration{tile_serial_in_batch};
175+
175176
// offset_3d += Assign{num_of_tiles_in_plane, (len_along_block - 1) / transforms_per_block + 1};
176177
offset_3d += Assign{num_of_tiles_in_batch, num_of_tiles_in_plane * len_along_plane};
177178
offset_3d += Assign{tile_serial_in_batch, block_id % num_of_tiles_in_batch};
@@ -257,6 +258,8 @@ struct StockhamKernelRC : public StockhamKernel
257258
auto num_load_blocks = (length * transforms_per_block) / workgroup_size;
258259
// #-row for a load block (global mem) = each thread will across these rows
259260
auto tid0_inc_step = transforms_per_block / num_load_blocks;
261+
// tpb/num_load_blocks, also = wgs/length, it's possible that they aren't divisible.
262+
bool divisible = (transforms_per_block % num_load_blocks) == 0;
260263

261264
stmts += If{transpose_type == "TILE_UNALIGNED",
262265
{Assign{edge,
@@ -268,14 +271,30 @@ struct StockhamKernelRC : public StockhamKernel
268271
// [dim0, dim1] = [tid0, tid1] :
269272
// each thread reads position [tid0, tid1], [tid0+step_h*1, tid1] , [tid0+step_h*2, tid1]...
270273
// tid0 walks the columns; tid1 walks the rows
271-
stmts += Assign{tid0, thread_id / length};
272-
stmts += Assign{tid1, thread_id % length};
274+
if(divisible)
275+
{
276+
stmts += Assign{tid0, thread_id / length};
277+
stmts += Assign{tid1, thread_id % length};
278+
}
273279

280+
// we need to take care about two diff cases for offset in buf and lds
281+
// divisible: each load leads to a perfect block: update offset much simpler
282+
// indivisible: need extra div and mod, otherwise each load will have some elements un-loaded:
274283
auto offset_tile_rbuf = [&](unsigned int i) {
275-
return tid1 * stride0 + (tid0 + i * tid0_inc_step) * stride_load_in;
284+
if(divisible)
285+
return tid1 * stride0 + (tid0 + i * tid0_inc_step) * stride_load_in;
286+
287+
else
288+
return ((thread_id + i * workgroup_size) % length) * stride0
289+
+ ((thread_id + i * workgroup_size) / length) * stride_load_in;
290+
};
291+
auto offset_tile_wlds = [&](unsigned int i) {
292+
if(divisible)
293+
return tid1 * 1 + (tid0 + i * tid0_inc_step) * stride_lds;
294+
else
295+
return ((thread_id + i * workgroup_size) % length) * 1
296+
+ ((thread_id + i * workgroup_size) / length) * stride_lds;
276297
};
277-
auto offset_tile_wlds
278-
= [&](unsigned int i) { return tid1 * 1 + (tid0 + i * tid0_inc_step) * stride_lds; };
279298

280299
StatementList regular_load;
281300
for(unsigned int i = 0; i < num_load_blocks; ++i)
@@ -284,13 +303,30 @@ struct StockhamKernelRC : public StockhamKernel
284303

285304
StatementList edge_load;
286305
Variable t{"t", "unsigned int"};
287-
edge_load += For{
288-
t,
289-
0,
290-
Parens{(tile_index_in_plane * transforms_per_block + tid0 + t) < len_along_block},
291-
tid0_inc_step,
292-
{Assign{lds_complex[tid1 * 1 + (tid0 + t) * stride_lds],
306+
if(divisible)
307+
{
308+
edge_load += For{
309+
t,
310+
0,
311+
Parens{(tile_index_in_plane * transforms_per_block + tid0 + t) < len_along_block},
312+
tid0_inc_step,
313+
{Assign{
314+
lds_complex[tid1 * 1 + (tid0 + t) * stride_lds],
293315
LoadGlobal{buf, offset_in + tid1 * stride0 + (tid0 + t) * stride_load_in}}}};
316+
}
317+
else
318+
{
319+
edge_load
320+
+= For{t,
321+
0,
322+
Parens{(thread_id + t) < (length * transforms_per_block)},
323+
workgroup_size,
324+
{Assign{lds_complex[((thread_id + t) % length) * 1
325+
+ ((thread_id + t) / length) * stride_lds],
326+
LoadGlobal{buf,
327+
offset_in + ((thread_id + t) % length) * stride0
328+
+ ((thread_id + t) / length) * stride_load_in}}}};
329+
}
294330

295331
stmts += If{Or{transpose_type != "TILE_UNALIGNED", Not{edge}}, regular_load};
296332
stmts += Else{edge_load};
@@ -306,6 +342,10 @@ struct StockhamKernelRC : public StockhamKernel
306342
auto num_store_blocks = (length * transforms_per_block) / workgroup_size;
307343
// #-row for a store block (global mem) = each thread will across these rows
308344
auto tid0_inc_step = length / num_store_blocks;
345+
// length / ((length * transforms_per_block) / workgroup_size) = wgs/tpb = blockwidth
346+
// so divisible should always true. But still put the logic here, since it's a
347+
// generator-wise if-else test, not inside the kernel code.
348+
bool divisible = (length % num_store_blocks) == 0;
309349

310350
StatementList stmts;
311351

@@ -335,14 +375,29 @@ struct StockhamKernelRC : public StockhamKernel
335375
// [dim0, dim1] = [tid0, tid1]:
336376
// each thread write GLOBAL_POS [tid0, tid1], [tid0+step_h*1, tid1] , [tid0+step_h*2, tid1].
337377
// NB: This is a transpose from LDS to global, so the pos of lds_read should be remapped
338-
stmts += Assign{tid0, thread_id / store_block_w};
339-
stmts += Assign{tid1, thread_id % store_block_w};
378+
if(divisible)
379+
{
380+
stmts += Assign{tid0, thread_id / store_block_w};
381+
stmts += Assign{tid1, thread_id % store_block_w};
382+
}
340383

384+
// we need to take care about two diff cases for offset in buf and lds
385+
// divisible: each store leads to a perfect block: update offset much simpler
386+
// indivisible: need extra div and mod, otherwise each store will have some elements un-set:
341387
auto offset_tile_wbuf = [&](unsigned int i) {
342-
return tid1 * stride0 + (tid0 + i * tid0_inc_step) * stride_store_out;
388+
if(divisible)
389+
return tid1 * stride0 + (tid0 + i * tid0_inc_step) * stride_store_out;
390+
else
391+
return ((thread_id + i * workgroup_size) % store_block_w) * stride0
392+
+ (tid0 + i * tid0_inc_step) * stride_store_out;
393+
};
394+
auto offset_tile_rlds = [&](unsigned int i) {
395+
if(divisible)
396+
return tid1 * stride_lds + (tid0 + i * tid0_inc_step) * 1;
397+
else
398+
return ((thread_id + i * workgroup_size) % store_block_w) * stride_lds
399+
+ ((thread_id + i * workgroup_size) / store_block_w) * 1;
343400
};
344-
auto offset_tile_rlds
345-
= [&](unsigned int i) { return tid1 * stride_lds + (tid0 + i * tid0_inc_step) * 1; };
346401

347402
StatementList regular_store;
348403
Expression pred{tile_index_in_plane * transforms_per_block + tid1 < len_along_block};

0 commit comments

Comments
 (0)