
[TLE] Optimize sparse MLA forward #534

Open

sunnycase wants to merge 9 commits into triton_v3.6.x from feature/tle_sparse_mla

Conversation


sunnycase (Collaborator) commented Apr 20, 2026

Summary

This PR builds on PR #489 and completes the sparse MLA forward optimization work for the triton_v3.6.x branch. It adds the TLE sparse MLA tutorial/benchmark path, TLE-owned shared-memory staging and tile-style pipeline materialization, WGMMA descriptor/fence support, and the native Triton hooks required by those TLE paths under #ifdef __TLE__.

What Changed

1. Sparse MLA forward tutorial and benchmark coverage

  • Extended python/tutorials/tle/deepseek_v32/02-sparse-mla.py with Triton, TLE, TileLang, TileLang-pipelined, TileLang-seesaw, and FlashMLA-compatible sparse MLA forward providers.
  • Added topk_length support so sparse MLA can skip shorter sparse regions instead of always iterating the full static topk.
  • Aligned Triton/TLE launch configuration with the TileLang sparse MLA setup by using fixed num_warps and num_stages instead of autotune for this tutorial path.
  • Kept benchmark reporting focused on prefill rows for this PR summary.
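The topk_length change above replaces a fixed loop bound with one derived from the runtime sparse length. A minimal pure-Python sketch of that bound computation (the names `num_topk_blocks` and `BLOCK_TOPK` are illustrative, not the PR's actual identifiers):

```python
def num_topk_blocks(topk: int, topk_length: int, BLOCK_TOPK: int) -> int:
    """Number of sparse-index tiles a program actually visits."""
    # Clamp the runtime length to the static topk capacity.
    effective = min(topk_length, topk)
    # Ceil-divide into BLOCK_TOPK-sized tiles; rows with shorter sparse
    # regions skip the trailing tiles instead of iterating the full topk.
    return (effective + BLOCK_TOPK - 1) // BLOCK_TOPK

# A row with only 512 valid sparse indices visits 8 tiles of 64,
# instead of the 32 tiles implied by the static topk=2048.
print(num_topk_blocks(2048, 512, 64))   # → 8
print(num_topk_blocks(2048, 2048, 64))  # → 32
```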

2. WGMMA descriptor view and shared operand fencing

  • Added tle.memdesc_wgmma_view as a descriptor-only view for existing shared-memory tiles consumed by WGMMA.
  • Added tle.wgmma_shared_operand_fence to order generic-proxy shared writes before WGMMA async-proxy reads.
  • Added TLE-to-LLVM lowering for WGMMA descriptor views and shared operand fences.
  • Updated NVIDIA fence insertion to emit dependency-aware TLE operand fences when WGMMA consumes shared descriptors written by async-copy/local staging paths.
  • Updated dot operand optimization so eligible local_load(existing_smem) operands can be reused directly as WGMMA shared operands instead of forcing extra register/materialization paths.
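The fence-insertion behavior described above can be modeled in a few lines of pure Python: a `tle.wgmma_shared_operand_fence` is needed only when a WGMMA read of a shared tile is preceded by a generic-proxy write to that tile with no intervening fence. This is an illustrative model of the dependency analysis, not the actual TLE pass or its op names:

```python
def insert_operand_fences(ops):
    """ops: list of (kind, buffer) pairs in program order."""
    out, dirty = [], set()  # dirty: shared buffers with unfenced generic-proxy writes
    for kind, buf in ops:
        if kind == "generic_write":  # e.g. local staging into shared memory
            dirty.add(buf)
        elif kind == "wgmma_read" and buf in dirty:
            # Order the pending generic-proxy writes before async-proxy reads.
            out.append(("operand_fence", None))
            dirty.clear()  # one fence covers all pending writes
        out.append((kind, buf))
    return out

ops = [("generic_write", "smem_a"),
       ("wgmma_read", "smem_a"),   # needs a fence before it
       ("wgmma_read", "smem_a")]   # already ordered; no extra fence
print(insert_operand_fences(ops))
```

The point of the dependency-aware variant is visible in the model: the second read emits no fence, so redundant fences are not inserted on every WGMMA.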

3. Tile-style staging and pipeline materialization

  • Added TLE pipeline scheduling/materialization support for TileLang-style preload/use loops.
  • Generalized materialization to support dynamic trip counts with runtime guards for empty and single-iteration cases.
  • Added and updated TLE MLIR regression coverage for dynamic and short-loop pipeline materialization, schedule preservation, and exclusive-cumsum layout propagation.
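The preload/use pipeline shape being materialized, including the runtime guards for empty and single-iteration trip counts, can be sketched as a pure-Python analogue (a conceptual model of the loop structure, not the TLE materialization code):

```python
def pipelined(tiles, load, use):
    """Software-pipelined loop: overlap load(i) with use(i-1)."""
    results = []
    n = len(tiles)
    if n == 0:          # runtime guard: empty trip count, emit nothing
        return results
    buf = load(tiles[0])        # prologue: preload the first tile
    for i in range(1, n):       # steady state
        nxt = load(tiles[i])    # issue the next load...
        results.append(use(buf))  # ...while consuming the previous tile
        buf = nxt
    results.append(use(buf))    # epilogue; a single-iteration loop skips
    return results              # straight from prologue to here

print(pipelined([1, 2, 3], load=lambda x: x, use=lambda x: x + 1))  # → [2, 3, 4]
```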

4. TLE-guarded native Triton hooks

  • Tightened MMAv3 chained-dot warp assignment under #ifdef __TLE__ so D/C-accumulator-only dot users do not trigger the flash-attention chained-dot heuristic.
  • Added TLE-only short static pipeline-loop handling under #ifdef __TLE__.
  • Added a TLE-only WGMMA lowering fix under #ifdef __TLE__ that materializes the initial C accumulator before wgmma.fence, preventing ptxas from seeing non-WGMMA accumulator definitions inside the WGMMA pipeline stage.
  • Kept the non-TLE/native Triton path on the original upstream behavior.
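The tightened chained-dot condition from the first bullet amounts to a predicate over how a second dot consumes the first dot's result. A toy model, using an illustrative dict-based IR rather than Triton's real data structures:

```python
def is_chained_dot(producer_id, consumer):
    """True only if the consumer uses the producer's result as an
    A/B multiplicand (the flash-attention chained-dot pattern).
    Feeding the result only into the C/D accumulator does not count."""
    return producer_id in (consumer["a"], consumer["b"])

dot0 = {"id": "dot0", "a": "q", "b": "k", "c": "zero"}
acc_only  = {"id": "dot1", "a": "p", "b": "v", "c": "dot0"}     # accumulator-only user
fa_chain  = {"id": "dot2", "a": "dot0", "b": "v", "c": "acc"}   # true chained dot

print(is_chained_dot("dot0", acc_only))  # → False: no warp reassignment
print(is_chained_dot("dot0", fa_chain))  # → True
```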

5. Regression coverage

  • Added TLE MLIR tests under third_party/tle/test/GPU/ for WGMMA descriptor views, WGMMA shared operand fences, tile-style pipeline scheduling/materialization, short/dynamic loop handling, and the WGMMA accumulator-before-fence lowering rule.
  • Kept TLE-only regression checks out of native Triton test trees.

Performance

Environment

  • GPU: NVIDIA H800
  • Device: CUDA_VISIBLE_DEVICES=6
  • Date: 2026-04-21
  • Data type: BF16
  • Timing: 2000 ms warmup, 5000 ms measurement
  • Input mode: FlashMLA-compatible prefill
  • Seed: 1
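The warmup/measurement figures above are time budgets rather than iteration counts. A minimal sketch of that style of harness (an assumed stand-in for the tutorial's actual benchmarking helper, in the spirit of `triton.testing.do_bench`):

```python
import time

def bench_ms(fn, warmup_ms=2000, rep_ms=5000):
    """Run fn for warmup_ms of wall time, then report the mean
    per-call latency in ms over a rep_ms measurement window."""
    deadline = time.perf_counter() + warmup_ms / 1e3
    while time.perf_counter() < deadline:   # warmup phase: results discarded
        fn()
    n, start = 0, time.perf_counter()
    deadline = start + rep_ms / 1e3
    while time.perf_counter() < deadline:   # measurement phase
        fn()
        n += 1
    return (time.perf_counter() - start) * 1e3 / n
```

Note that on GPU the real harness must also synchronize the device around timing; this sketch only illustrates the time-budgeted warmup/measure split.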

Sparse MLA Forward Prefill

Configuration for all rows: B=1, S=4096, H=128, HKV=1, DQK=576, DV=512, topk=2048, topk_length=2048.

| SKV | Triton ms | TLE ms | TLE vs Triton | TileLang ms | TileLang-Pipelined ms | TileLang-Seesaw ms | FlashMLA ms |
|---|---|---|---|---|---|---|---|
| 8192 | 11.832 | 10.686 | 1.11x | 62.665 | 9.473 | 10.425 | 9.323 |
| 32768 | 17.882 | 14.964 | 1.20x | 83.530 | 12.798 | 13.793 | 11.999 |
| 65536 | 19.933 | 17.569 | 1.14x | 92.959 | 15.907 | 15.905 | 14.032 |
| 98304 | 21.028 | 18.873 | 1.11x | 95.311 | 16.041 | 17.431 | 15.640 |
| 131072 | 21.332 | 20.703 | 1.03x | 95.751 | 16.891 | 17.015 | 15.638 |

Interpretation:

  • TLE is faster than Triton on all measured prefill rows, with the margin narrowing at the largest SKV.
  • TileLang-Pipelined and FlashMLA remain faster than the current TLE kernel on these large prefill rows.
  • TileLang baseline is much slower and is included only as a reference provider.
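As a sanity check on the first point, the speedup column can be recomputed from the Triton and TLE columns of the table:

```python
rows = {  # SKV -> (Triton ms, TLE ms), copied from the table above
    8192: (11.832, 10.686),
    32768: (17.882, 14.964),
    65536: (19.933, 17.569),
    98304: (21.028, 18.873),
    131072: (21.332, 20.703),
}
speedups = {skv: triton / tle for skv, (triton, tle) in rows.items()}
for skv, s in speedups.items():
    print(f"SKV={skv}: {s:.2f}x")  # TLE ahead on every row, smallest margin at SKV=131072
```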

Validation

  • Backend and Python extension rebuilt successfully from this branch.
  • TLE WGMMA accumulator/fence lowering regression passed.
  • Native Hopper conversion regression passed after keeping the TLE-only check in the TLE test tree.
  • TLE sparse MLA ptxas check no longer reports C7515 and reports 0 bytes stack frame, spill stores, and spill loads.
  • Local sparse MLA prefill correctness checks passed against Triton outputs in benchmark runs.
