@@ -94,9 +94,6 @@ struct StockhamKernelRC : public StockhamKernel
9494 stmts += Declaration{
9595 len_along_block,
9696 Ternary{Parens{sbrc_type == " SBRC_3D_FFT_TRANS_XY_Z" }, lengths[2 ], lengths[1 ]}};
97- stmts += Declaration{
98- len_along_plane,
99- Ternary{Parens{sbrc_type == " SBRC_3D_FFT_TRANS_XY_Z" }, lengths[1 ], lengths[2 ]}};
10097 stmts += Declaration{
10198 stride_load_in,
10299 Ternary{Parens{sbrc_type == " SBRC_3D_FFT_TRANS_XY_Z" }, stride_in[2 ], stride_in[1 ]}};
@@ -105,18 +102,10 @@ struct StockhamKernelRC : public StockhamKernel
105102 Ternary{Parens{sbrc_type == " SBRC_3D_FFT_TRANS_XY_Z" || sbrc_type == " SBRC_2D" },
106103 stride_out[1 ],
107104 stride_out[2 ]}};
108- stmts += Declaration{
109- stride_plane_in,
110- Ternary{Parens{sbrc_type == " SBRC_3D_FFT_TRANS_XY_Z" }, stride_in[1 ], stride_in[2 ]}};
111- stmts += Declaration{
112- stride_plane_out,
113- Ternary{Parens{sbrc_type == " SBRC_3D_FFT_TRANS_XY_Z" }, stride_out[2 ], stride_out[1 ]}};
114105
115106 stmts += LineBreak{};
116107 stmts += Declaration{num_of_tiles_in_batch};
117108 stmts += Declaration{tile_index_in_plane};
118- stmts += Declaration{plane_id};
119- stmts += Declaration{tile_serial_in_batch};
120109
121110 stmts += LineBreak{};
122111 stmts
@@ -172,6 +161,18 @@ struct StockhamKernelRC : public StockhamKernel
172161 // " tile_index_in_plane means index of the tile in that xy-plane",
173162 // " plane_id means the index of current xy-plane (inwards along Z-axis)"};
174163
164+ offset_3d += Declaration{
165+ len_along_plane,
166+ Ternary{Parens{sbrc_type == " SBRC_3D_FFT_TRANS_XY_Z" }, lengths[1 ], lengths[2 ]}};
167+ offset_3d += Declaration{
168+ stride_plane_in,
169+ Ternary{Parens{sbrc_type == " SBRC_3D_FFT_TRANS_XY_Z" }, stride_in[1 ], stride_in[2 ]}};
170+ offset_3d += Declaration{
171+ stride_plane_out,
172+ Ternary{Parens{sbrc_type == " SBRC_3D_FFT_TRANS_XY_Z" }, stride_out[2 ], stride_out[1 ]}};
173+ stmts += Declaration{plane_id};
174+ stmts += Declaration{tile_serial_in_batch};
175+
175176 // offset_3d += Assign{num_of_tiles_in_plane, (len_along_block - 1) / transforms_per_block + 1};
176177 offset_3d += Assign{num_of_tiles_in_batch, num_of_tiles_in_plane * len_along_plane};
177178 offset_3d += Assign{tile_serial_in_batch, block_id % num_of_tiles_in_batch};
@@ -257,6 +258,8 @@ struct StockhamKernelRC : public StockhamKernel
257258 auto num_load_blocks = (length * transforms_per_block) / workgroup_size;
258259 // #-row for a load block (global mem) = each thread will across these rows
259260 auto tid0_inc_step = transforms_per_block / num_load_blocks;
261+ // tpb/num_load_blocks, also = wgs/length, it's possible that they aren't divisible.
262+ bool divisible = (transforms_per_block % num_load_blocks) == 0 ;
260263
261264 stmts += If{transpose_type == " TILE_UNALIGNED" ,
262265 {Assign{edge,
@@ -268,14 +271,30 @@ struct StockhamKernelRC : public StockhamKernel
268271 // [dim0, dim1] = [tid0, tid1] :
269272 // each thread reads position [tid0, tid1], [tid0+step_h*1, tid1] , [tid0+step_h*2, tid1]...
270273 // tid0 walks the columns; tid1 walks the rows
271- stmts += Assign{tid0, thread_id / length};
272- stmts += Assign{tid1, thread_id % length};
274+ if (divisible)
275+ {
276+ stmts += Assign{tid0, thread_id / length};
277+ stmts += Assign{tid1, thread_id % length};
278+ }
273279
280+ // we need to take care about two diff cases for offset in buf and lds
281+ // divisible: each load leads to a perfect block: update offset much simpler
282+ // indivisible: need extra div and mod, otherwise each load will have some elements un-loaded:
274283 auto offset_tile_rbuf = [&](unsigned int i) {
275- return tid1 * stride0 + (tid0 + i * tid0_inc_step) * stride_load_in;
284+ if (divisible)
285+ return tid1 * stride0 + (tid0 + i * tid0_inc_step) * stride_load_in;
286+
287+ else
288+ return ((thread_id + i * workgroup_size) % length) * stride0
289+ + ((thread_id + i * workgroup_size) / length) * stride_load_in;
290+ };
291+ auto offset_tile_wlds = [&](unsigned int i) {
292+ if (divisible)
293+ return tid1 * 1 + (tid0 + i * tid0_inc_step) * stride_lds;
294+ else
295+ return ((thread_id + i * workgroup_size) % length) * 1
296+ + ((thread_id + i * workgroup_size) / length) * stride_lds;
276297 };
277- auto offset_tile_wlds
278- = [&](unsigned int i) { return tid1 * 1 + (tid0 + i * tid0_inc_step) * stride_lds; };
279298
280299 StatementList regular_load;
281300 for (unsigned int i = 0 ; i < num_load_blocks; ++i)
@@ -284,13 +303,30 @@ struct StockhamKernelRC : public StockhamKernel
284303
285304 StatementList edge_load;
286305 Variable t{" t" , " unsigned int" };
287- edge_load += For{
288- t,
289- 0 ,
290- Parens{(tile_index_in_plane * transforms_per_block + tid0 + t) < len_along_block},
291- tid0_inc_step,
292- {Assign{lds_complex[tid1 * 1 + (tid0 + t) * stride_lds],
306+ if (divisible)
307+ {
308+ edge_load += For{
309+ t,
310+ 0 ,
311+ Parens{(tile_index_in_plane * transforms_per_block + tid0 + t) < len_along_block},
312+ tid0_inc_step,
313+ {Assign{
314+ lds_complex[tid1 * 1 + (tid0 + t) * stride_lds],
293315 LoadGlobal{buf, offset_in + tid1 * stride0 + (tid0 + t) * stride_load_in}}}};
316+ }
317+ else
318+ {
319+ edge_load
320+ += For{t,
321+ 0 ,
322+ Parens{(thread_id + t) < (length * transforms_per_block)},
323+ workgroup_size,
324+ {Assign{lds_complex[((thread_id + t) % length) * 1
325+ + ((thread_id + t) / length) * stride_lds],
326+ LoadGlobal{buf,
327+ offset_in + ((thread_id + t) % length) * stride0
328+ + ((thread_id + t) / length) * stride_load_in}}}};
329+ }
294330
295331 stmts += If{Or{transpose_type != " TILE_UNALIGNED" , Not{edge}}, regular_load};
296332 stmts += Else{edge_load};
@@ -306,6 +342,10 @@ struct StockhamKernelRC : public StockhamKernel
306342 auto num_store_blocks = (length * transforms_per_block) / workgroup_size;
307343 // #-row for a store block (global mem) = each thread will across these rows
308344 auto tid0_inc_step = length / num_store_blocks;
345+ // length / ((length * transforms_per_block) / workgroup_size) = wgs/tpb = blockwidth
346+ // so divisible should always true. But still put the logic here, since it's a
347+ // generator-wise if-else test, not inside the kernel code.
348+ bool divisible = (length % num_store_blocks) == 0 ;
309349
310350 StatementList stmts;
311351
@@ -335,14 +375,29 @@ struct StockhamKernelRC : public StockhamKernel
335375 // [dim0, dim1] = [tid0, tid1]:
336376 // each thread write GLOBAL_POS [tid0, tid1], [tid0+step_h*1, tid1] , [tid0+step_h*2, tid1].
337377 // NB: This is a transpose from LDS to global, so the pos of lds_read should be remapped
338- stmts += Assign{tid0, thread_id / store_block_w};
339- stmts += Assign{tid1, thread_id % store_block_w};
378+ if (divisible)
379+ {
380+ stmts += Assign{tid0, thread_id / store_block_w};
381+ stmts += Assign{tid1, thread_id % store_block_w};
382+ }
340383
384+ // we need to take care about two diff cases for offset in buf and lds
385+ // divisible: each store leads to a perfect block: update offset much simpler
386+ // indivisible: need extra div and mod, otherwise each store will have some elements un-set:
341387 auto offset_tile_wbuf = [&](unsigned int i) {
342- return tid1 * stride0 + (tid0 + i * tid0_inc_step) * stride_store_out;
388+ if (divisible)
389+ return tid1 * stride0 + (tid0 + i * tid0_inc_step) * stride_store_out;
390+ else
391+ return ((thread_id + i * workgroup_size) % store_block_w) * stride0
392+ + (tid0 + i * tid0_inc_step) * stride_store_out;
393+ };
394+ auto offset_tile_rlds = [&](unsigned int i) {
395+ if (divisible)
396+ return tid1 * stride_lds + (tid0 + i * tid0_inc_step) * 1 ;
397+ else
398+ return ((thread_id + i * workgroup_size) % store_block_w) * stride_lds
399+ + ((thread_id + i * workgroup_size) / store_block_w) * 1 ;
343400 };
344- auto offset_tile_rlds
345- = [&](unsigned int i) { return tid1 * stride_lds + (tid0 + i * tid0_inc_step) * 1 ; };
346401
347402 StatementList regular_store;
348403 Expression pred{tile_index_in_plane * transforms_per_block + tid1 < len_along_block};
0 commit comments