Zero-Copy FlashMaskV3 Computation-Communication Overlap #110
Enigmatisms wants to merge 55 commits into PaddlePaddle:main from
- Add MPI package requirement
- Patch nvshmem install dir in CMake
- Removed cutlass::FastDivmod dep
- Fixed some CMake and PTX bugs (v5)
- Resolve symbol not found problem (v2)
- Link MPI bootstrap dynlib
- So that we can init nvshmem correctly without hang
- Debug info print
- Debug info print and debug checkpoint
- Fixed memcpy bug
- Fixed incorrect TMA stride
- DeviceSync debug (pre-check + post-check)
- Check consumer write_ptr wait
- Fixed mainloop load
- Add dense copy kernel
- No MPI bootstrap dependency
- Add CUDA_ARCH tag in CMake
- Add warn for CUDA_ARCH macro
- Using self-defined nvshmemi API (test int4 & int2)
- Scale dense copy bytes by 2 to use all threads
- Fixed some minor bugs and improved docs
- Add better synced logging for debugging
- Add wait-on-timeout debug kernel
- Fixed nvshmem_int_wait_on_stream bug
- Isolated debug logging macros
- Revert to 2-stream method for correct SR buffer local update
- Add cudaStreamSync to replace EventWait (this avoids hang, idk why)
- Add compile def macro for conditional compilation
- Manage overlap-comm instance via singleton pattern
- Support export nvshmem unique ptr
- Conflict resolving patches
- bwd bugs for non-overlapping mode
This reverts commit f3154a2.
- Since multi-node does not support bitwise AMO
- IBRC does not allow int AMO
- Fixed standard AMO op stupid bug
- Add sparse large KV chunk kernel and dynamic scheduling
- Fixed semaphore CUDA 300 problem
- Benched multi-node and made dense copy size bigger
- Deprecation warning: delete sparse small chunk kernels and dense copy flag
- Add variable RDMA rows-per-warp setting
- Consider KV head > 1
- Consider mask head > 1
- Memcpy multiple times with correct offset
- Correct coordinate
- 16 rows per warp reduces the comm by half for KV head > 1; potentially beneficial, since with KV head = 1, 32 rows might have already saturated bandwidth well enough
(cherry picked from commit 0c424ba)
- Better generalization
- Add new AG remote_get kernel for split calls
- Refactored bf16 chunk dK dV reduce
- Correct (I hope) mainloop and Paddle adaptation
- Arrange all the kernels into run_rs_overlap_kernel
- Splitting into multiple loops (bwd)
- New comm buffer (SepSRBuffer), new semaphore ops and comm kernels for RS
- Patched multiple bugs
- Add more debug logging
- Fix misaligned address for remote_put
- Fix hang problem; accuracy needs fixing
- Found zeroing problem
- Remote get local chunk skipping
- Remote get local chunk correct offsetting
- comm_stream set to the highest priority
- Resolve Hyper-Q deadlock
- Batch size > 1 to be supported
- Double buffering for SepSRBuffer
- Reduce dKV does not need a send buffer
- Adjust sync and visibility slightly
- Fix local rank wait hang by correctly skipping local rank and reorganizing recv buffer reset
- Fix accuracy problem of not correctly notifying local rank
- SMEM support for remote-put/get kernel copy chunk mask (correct ver)
- Decouple send buffer from consumer stream
- Fix hang problem
- Discard cuStream ops for set and wait values; could be catastrophic
- Tweak fine-grained semaphore sync (WIP)
- SMEM for copy_chunk_mask in split remote get kernels
- Fix dQ semaphore reset ops
- Make stream_coordinator universal and correct
- Try fixing semaphore sync hang and acc mismatch problem (better sync)
- fwd execution fence
- Deprecate and remove sparse chunk get specialized kernel
- Disable split AG remote_get kernel semaphore sync for now
- Fixed hang bug due to masking skip
- Fix one instability point for RS-overlap
- Mask skip before sync is useless
- This breaks the zero-copy trait for generality
- [WIP] Acc mismatch resolving
- Paddle-end failure, can be fixed on the FM side
- CP segment heuristic (and Paddle porting)
- Better debug logging
- [Trial] Fused local & remote consumer notify
- Add NVSHMEM include dir for flashmask v2
- Ready for PR
This reverts commit de28678.
This is a beta version of the FlashMaskV3 computation-communication overlap code; further testing against the upstream codebase is in progress.
The computation-communication overlap implemented in this PR is almost zero-copy: apart from the necessary migration of local tensors into the symmetric buffer, nearly all operations hack the input/output buffer pointers of the attention module so that computation and send/recv complete in place. Porting this part to FA4 may therefore be non-trivial; the workload can only be assessed once the FA4 implementation is settled.
This part includes the following changes (flashmaskv2/distributed). The modifications to the files under the flashmaskv2 folder are mainly:
CP overlap flashmask strictly requires three new PHI parameters:
The code is currently undergoing the following tests: