diff --git a/examples/17_gemm_one_shot_all_reduce_pc/gemm_one_shot_all_reduce_pc.py b/examples/17_gemm_one_shot_all_reduce_pc/gemm_one_shot_all_reduce_pc.py
index e794ab8e..dfbba3d5 100644
--- a/examples/17_gemm_one_shot_all_reduce_pc/gemm_one_shot_all_reduce_pc.py
+++ b/examples/17_gemm_one_shot_all_reduce_pc/gemm_one_shot_all_reduce_pc.py
@@ -127,14 +127,16 @@ def persistent_gemm(
         # Write to local buffer
         tl.store(local_C + local_offset, acc, mask=mask, cache_modifier=".wt")
 
+        # Ensure local_C write is visible before signaling
+        tl.atomic_fence(sem="release", scope="sys")
+
         # Signal that this tile is ready
-        tl.debug_barrier()
         tl.store(locks + tile_id, 1, cache_modifier=".wt")
 
-        # Signal to all remote ranks that this tile is ready
+        # Signal to all remote ranks that this tile is ready by incrementing their counter
         for remote_rank in range(world_size):
             if remote_rank != cur_rank:
-                iris.atomic_xchg(tile_ready + tile_id, 1, cur_rank, remote_rank, heap_bases, sem="release", scope="sys")
+                iris.atomic_add(tile_ready + tile_id, 1, cur_rank, remote_rank, heap_bases, sem="release", scope="sys")
 
         if COLLECT_TIMESTAMPS:
             timestamp = read_realtime()
@@ -213,16 +215,22 @@ def persistent_all_reduce(
         while tl.load(locks + tile_id, cache_modifier=".cv", volatile=True) != 1:
             pass
 
-        # Wait for remote ranks
-        for remote_rank in range(world_size):
-            if remote_rank != cur_rank:
-                while (
-                    iris.atomic_cas(
-                        tile_ready + tile_id, 0, 0, cur_rank, remote_rank, heap_bases, sem="acquire", scope="sys"
-                    )
-                    != 1
-                ):
-                    pass
+        # Ensure local producer's writes are visible
+        tl.atomic_fence(sem="acquire", scope="gpu")
+
+        # Wait for remote ranks - each remote rank increments tile_ready when done
+        # We expect (world_size - 1) increments from all other ranks
+        while iris.atomic_cas(
+            tile_ready + tile_id,
+            0,  # Never matches when ready, so acts as atomic read
+            0,
+            cur_rank,
+            cur_rank,
+            heap_bases,
+            sem="acquire",
+            scope="sys",
+        ) < (world_size - 1):
+            pass
 
         # Map tile_id to (pid_m, pid_n)
         num_pid_in_group = GROUP_SIZE_M * num_pid_n
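For context on the protocol change above: the producer switches from `iris.atomic_xchg` (a boolean flag, which cannot distinguish one remote signaler from many) to `iris.atomic_add` (a counter), and the consumer now polls its own local counter once until it reaches `world_size - 1` instead of issuing a remote CAS loop per rank. Below is a minimal CPU-side sketch of the same counter-based readiness pattern, not the Triton/Iris kernel itself: a lock-protected Python counter stands in for the GPU release/acquire atomics, and names like `WORLD_SIZE`, `published`, and `tile_ready` are illustrative.

```python
# Sketch of the counter-based producer/consumer handshake from the diff.
# Each "rank" publishes its payload first, then atomically increments a
# shared readiness counter (analogue of iris.atomic_add with sem="release").
# The consumer polls the counter until it has seen world_size - 1 increments
# (analogue of the atomic_cas(..., 0, 0, ...) polling read with sem="acquire").
import threading

WORLD_SIZE = 4
tile_ready = 0                    # readiness counter for one tile
counter_lock = threading.Lock()   # stands in for hardware atomicity
published = [None] * WORLD_SIZE   # stands in for the local_C buffers

def producer(rank: int) -> None:
    published[rank] = f"tile data from rank {rank}"  # write payload first
    with counter_lock:                               # publish, then signal
        global tile_ready
        tile_ready += 1

def consumer(cur_rank: int) -> None:
    # Spin until every other rank has signaled; one local read per poll,
    # rather than a separate remote CAS loop per rank.
    while True:
        with counter_lock:
            if tile_ready >= WORLD_SIZE - 1:
                break
    # All remote payloads are visible now; safe to reduce.
    remote = [p for i, p in enumerate(published) if i != cur_rank]
    print(f"rank {cur_rank} reduces {len(remote)} remote tiles")

threads = [threading.Thread(target=producer, args=(r,))
           for r in range(WORLD_SIZE) if r != 0]
for t in threads:
    t.start()
consumer(0)
for t in threads:
    t.join()
```

The design point the sketch mirrors: with a flag and `atomic_xchg`, `world_size - 1` remote writers all store the same value, so the consumer must check each writer individually; with `atomic_add`, the increments compose, so a single local counter comparison suffices.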