
Zero-Copy FlashMaskV3 Computation-Communication Overlap#110

Open
Enigmatisms wants to merge 55 commits into PaddlePaddle:main from Enigmatisms:overlap_debug

Conversation


@Enigmatisms Enigmatisms commented Mar 2, 2026

Beta version of the FlashMaskV3 computation-communication overlap code; further testing against the upstream codebase is in progress.

The computation-communication overlap implemented in this PR is almost zero-copy (apart from the unavoidable migration of local tensors into the symmetric buffer). Nearly everything else is achieved by hacking the attention module's input/output buffer pointers so that computation and send/recv complete in place. Porting this part to FA4 may therefore be non-trivial; the workload cannot be estimated until the FA4 implementation is settled.

This PR includes the following changes (flashmaskv2/distributed):

  • CP all-gather overlap main logic
  • CP reduce-scatter overlap main logic
  • Send/recv buffer definitions, plus static singleton management of the overlap communicator
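The "static singleton management" bullet above can be sketched as follows. This is a hypothetical illustration (the names `OverlapComm`, `Instance`, and `InitOnce` are illustrative, not the PR's actual identifiers): a Meyers singleton holds the communicator state so that all fwd/bwd launches in a process share one communication group and one set of symmetric send/recv buffers, and repeated initialization attempts are ignored.

```cpp
#include <array>
#include <cstdint>

// Hypothetical sketch of static singleton management for the overlap
// communicator. Real code would additionally own the NVSHMEM team and
// the symmetric send/recv buffers.
struct OverlapComm {
    int rank = -1;
    int nranks = 0;
    std::array<uint8_t, 128> unique_id{};  // 128-byte NVSHMEM unique_id
    bool initialized = false;

    // First call constructs the instance; every later call returns the
    // same object (thread-safe static local, guaranteed since C++11).
    static OverlapComm& Instance() {
        static OverlapComm comm;
        return comm;
    }

    // Idempotent: only the first initialization takes effect, so multiple
    // kernel-launch paths can call this without coordinating.
    void InitOnce(int r, int n, const std::array<uint8_t, 128>& uid) {
        if (initialized) return;
        rank = r;
        nranks = n;
        unique_id = uid;
        initialized = true;
    }
};
```

A singleton fits here because the NVSHMEM group must be created exactly once per process, while the attention kernels that use it are launched many times.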

The changes to the files under the flashmaskv2 directory are mainly:

  • fwd/bwd launch templates adapted for the overlap logic: e.g., redirecting buffers to achieve zero-copy, and managing overlap communicator creation and behavior
  • mainloop adapted for the overlap logic: write-pointer operations
  • API-related: Paddle adaptation, exporting unique_id / RS number of chunks per segment (stage), accepting rank / nranks as inputs, etc.
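The buffer-redirection idea in the first bullet can be illustrated with a toy sketch (not the PR's actual code): instead of computing into a private output buffer and then copying into a communication buffer, the launch template points the kernel's output pointer directly at this rank's slice of the symmetric buffer, so the subsequent send reads the data in place.

```cpp
#include <cstddef>
#include <vector>

// Toy model: the "symmetric buffer" is one allocation holding nranks
// equally sized slices. Zero-copy means the compute step writes straight
// into slice `rank`, so no local->comm-buffer memcpy precedes the send.
void compute_into(float* out, int n, float value) {
    for (int i = 0; i < n; ++i) out[i] = value;  // stands in for attention output
}

float* slice_of(std::vector<float>& symm_buf, int rank, int chunk) {
    return symm_buf.data() + static_cast<size_t>(rank) * chunk;
}
```

Usage: `compute_into(slice_of(symm, rank, chunk), chunk, v)` writes the local result directly into the rank's slice of the shared buffer, which is the pointer-redirection trick the launch templates perform on the real attention output buffers.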

CP overlap flashmask strictly requires three new PHI parameters:

  • rank: which rank of the CP group the current device is
  • nranks: size of the CP group
  • unique_id: NVSHMEM communication group initialization requires a CP-group-unique identifier (a 128-byte unique_id generated by NVSHMEM). Since Paddle itself has no PHI operator for generating a unique_id, this PR exports a unique_id generation interface for the corresponding Paddle-side PHI API to call. (Paddle's DeepEP has such an API, but using it would make Attention depend on DeepEP.)
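For reference, NVSHMEM's documented unique-id bootstrap follows the shape below. This is a generic sketch of that public API, not the PR's exported interface; the exact initializer macros and the broadcast mechanism (here assumed to be an existing Paddle communication channel) may differ from what the PR does.

```cpp
// Sketch of NVSHMEM unique-id bootstrap (API names per NVSHMEM docs).
nvshmemx_uniqueid_t id = NVSHMEMX_UNIQUEID_INITIALIZER;
if (rank == 0) {
    nvshmemx_get_uniqueid(&id);  // rank 0 generates the 128-byte unique_id
}
// ... broadcast `id` to all CP-group ranks over an existing channel
//     (this is what the exported generation interface enables on the
//     Paddle side, without pulling in a DeepEP dependency) ...
nvshmemx_init_attr_t attr = NVSHMEMX_INIT_ATTR_INITIALIZER;
nvshmemx_set_attr_uniqueid_args(rank, nranks, &id, &attr);
nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);
```

This is why rank, nranks, and unique_id must all reach the kernel-side code as PHI parameters: all three are inputs to the unique-id initialization path.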

The code is currently undergoing the following tests:

  • (P0) Correctness tests for the standalone FA package build and the unified Paddle package build (fully from scratch)
  • (P0) Re-run single-node benchmarks
  • (P0) Re-run single-node accuracy tests
  • (P1) Convergence of the latest EB5 research version

Enigmatisms and others added 30 commits December 30, 2025 04:41
- Add MPI package requirement

- Patch nvshmem install dir in CMake

- Removed cutlass::FastDivmod dep

- Fixed some CMake and PTX bugs (v5)

- Resolve symbol not found problem (v2)

- Link MPI bootstrap dynlib
- So that we can init nvshmem correctly without hang
- Debug info print and debug checkpoint

- Fixed memcpy bug

- Fixed incorrect TMA stride

- DeviceSync debug (pre-check + post-check)
- Check consumer write_ptr wait

- Fixed mainloop load
- No MPI bootstrap dependency
- Add CUDA_ARCH tag in CMake
- Add warn for CUDA_ARCH macro
- Using self defined nvshmemi API (test int4 & int2)
- Scale dense copy bytes by 2 to use all threads
- Fixed some minor bugs and improve docs
- Add better synced logging for debugging.
- Add wait on timeout debug kernel
- Fixed nvshmem_int_wait_on_stream bug
- Isolated debug logging macros
- Revert to 2-stream method for correct SR buffer local update
- Add cudaStreamSync to replace EventWait (this avoids hang, idk why)
- Add compile def macro for conditional compilation
- Manage overlap-comm instance via singleton pattern
- Support export nvshmem unique ptr.
- Add compile def macro for conditional compilation
- Manage overlap-comm instance via singleton pattern
- Support export nvshmem unique ptr.
- conflict resolving patches
- bwd bugs for non-overlapping mode
- Since multi-node does not support bitwise AMO
- IBRC does not allow int AMO
- Fixed standard AMO op stupid bug
- Add sparse large KV chunk kernel and dynamic scheduling
- Fixed semaphore CUDA 300 problem
- Benched multi-node and make dense copy size bigger
- Deprecation warning: delete sparse small chunk kernels and dense copy flag
- Add variable RDMA row per warp settings
- Consider KV head > 1
- Consider mask head > 1
- Memcpy multiple-times with correct offset
- Correct coordinate
- 16 reduces the comm by half for KV head > 1
Potentially beneficial, since with KV head 1, 32 might already saturate bandwidth well enough.
(cherry picked from commit 0c424ba)
- Add new AG remote_get kernel for split calls
- Refactored bf16 chunk dK dV reduce
- Correct (I hope) mainloop and Paddle adaptation
- Arrange all the kernels into run_rs_overlap_kernel
- Splitting into multiple loops (bwd)
- New comm buffer (SepSRBuffer), new semaphore ops and comm kernels for RS
- Patched multiple bugs
- Add more debug logging.
- Fix misaligned address for remote_put
- Fix hang problem. Accuracy needs fixing.
- Found zeroing-problem
- Remote get local chunk skipping
- Remote get local chunk correct offsetting
- comm_stream set the highest priority
- resolve Hyper-Q deadlock
- Batch size > 1 to be supported
- Double buffering for SepSRBuffer
- Reduce dKV does not need send buffer
- Adjust sync and visibility slightly
- Fix local rank wait hang by correctly skipping local rank and reorganize recv buffer reset
- Fix accuracy problem of not correctly notify local rank
- SMEM support for remote-put/get kernel copy chunk mask (correct ver)
- Decouple send buffer with consumer stream
- Fix hang problem
- Discard cuStream ops for set and wait values, could be catastrophic
- Tweak fine-grained semaphore sync (WIP)
- SMEM for copy_chunk_mask in split remote get kernels
- Fix dQ semaphore reset ops
- Make stream_coordinator universal and correct
- Try fixing semaphore sync hang and acc mismatch problem (better sync)
- fwd execution fence
- Deprecate and remove sparse chunk get specialized kernel
- Disable split AG remote_get kernel semaphore sync for now
- Fix one instability point for RS-overlap
- Mask skip before sync is useless
- This breaks the zero-copy trait for generality
- [WIP] Acc mismatch resolving
- Paddle end failure, can be fixed in FM side
- CP segment heuristic (and paddle porting)
- Better debug logging
- [Trial] Fused local & remote consumer notify
- Add NVSHMEM include dir for flashmask v2
- Ready for PR
