test(dflash): contract test for draft SWA mask wiring #141

Open
javierpazo wants to merge 2 commits into Luce-Org:main from javierpazo:xabicasa/dflash-test-draft-swa-mask-contract

Conversation

@javierpazo
Contributor

test(dflash): contract test for draft SWA mask wiring

Adds a focused regression test that pins the draft-side SWA mask
wiring as a contract, independent of which SWA implementation
eventually lands.

What it tests:

  • When draft_graph_needs_swa_mask(weights, ctx_len) returns
    true, the graph reads attn_mask from DraftGraphInputs and
    propagates it to ggml_flash_attn_ext on SWA layers.
  • When SWA is not active, or total_k <= swa_window, the mask
    is not consumed.
  • build_draft_swa_mask produces the documented shape and
    values (0.0 for visible positions and -inf for masked
    positions, F16 storage).
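
The gating described in the bullets above can be sketched as follows. This is a minimal illustration, not the PR's code: `SketchDraftConfig`, its fields, and `sketch_needs_swa_mask` are hypothetical names, and the sketch takes `total_k` directly because the PR text does not show how the real `draft_graph_needs_swa_mask(weights, ctx_len)` derives it.

```cpp
#include <cassert>

// Hypothetical stand-in for the real DraftWeights; only what this sketch needs.
struct SketchDraftConfig {
    bool has_swa_layers; // assumption: true when at least one draft layer uses SWA
    int  swa_window;     // sliding-window size in tokens
};

// Assumed gating rule, read off the PR description: the mask is consumed only
// when SWA is active AND the number of keys exceeds the window.
inline bool sketch_needs_swa_mask(const SketchDraftConfig & cfg, int total_k) {
    assert(total_k >= 0);
    return cfg.has_swa_layers && total_k > cfg.swa_window;
}
```

Under these assumptions, a contract test would cover three cases: SWA active with `total_k` above the window (mask consumed), SWA active with `total_k` equal to the window (not consumed), and SWA inactive at any length (not consumed).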

Why a separate PR:

Depends on the companion feat(dflash): wire caller-provided SWA mask through draft graph (PR to be opened in parallel). The test
needs the helpers exported in that PR's dflash_graph.h.

Build registration:

Adds the test to dflash/CMakeLists.txt next to test_vs_oracle
with the same EXISTS-guard pattern, so a cmake --build ... --target test_draft_swa_mask_contract builds and runs it on a
clean checkout. No effect on other targets.

Validation:

  • cmake --build dflash/build/Release --target test_draft_swa_mask_contract succeeds on RTX 6000 Ada (sm_89),
    Windows MSVC + CUDA 12.x.
  • All assertions green: SWA-active + long-ctx case consumes the
    mask; non-SWA / short-ctx cases do not; mask helper output
    matches the documented shape and values.
  • No regressions on smoke_draft_graph or test_vs_oracle
    (untouched here).

Author: Javier Pazo <xabicasa@gmail.com>

javierpazo added 2 commits May 9, 2026 12:37
When the Qwen3.6 draft graph is built for a context that exceeds
the SWA window, the caller-provided sliding-window attention mask
must reach `ggml_flash_attn_ext` on the SWA layers.

Previously the mask was constructed and then nullified
post-construction in qwen3_dflash_graph, so SWA layers ran without
the intended visibility constraint at long contexts.

This change makes the wiring explicit and pinable as a contract:

  * `dflash_graph.h` — `DraftGraphInputs` gains an optional
    `attn_mask` field. Documented as caller-owned, type F16, with
    shape `[kv_len, q_len]` (or padded `[kv_pad, q_pad]`), values
    `0` for visible positions and `-inf` for masked positions.
    Two helpers added so callers do not reimplement the same logic:

      bool draft_graph_needs_swa_mask(const DraftWeights & w,
                                      int ctx_len);
      void build_draft_swa_mask(std::vector<uint16_t> & out,
                                int ctx_len, int q_len,
                                int swa_window);

    `lm_head` is normalised to a default-null member at the same
    time (small consistency fix; layout unchanged).

  * `qwen3_dflash_graph.cpp` — when `total_k > swa_window` on a SWA
    layer, the graph wires the caller-provided mask through to
    `ggml_flash_attn_ext` and stops nullifying it. Layers that are
    not SWA still ignore the mask, as before.

  * `smoke_draft_graph.cpp` and `test_vs_oracle.cpp` — small
    alignment so the existing tests can build / fill the draft SWA
    mask when `ctx_len + q_len > swa_window`. No new test scaffolding
    is added in this commit; the focused regression test lives in a
    separate PR (`test(dflash): contract test for draft SWA mask
    wiring`) so each PR keeps one concern.
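
To make the documented mask contract concrete, here is a hedged sketch of a mask builder with the same parameter list as `build_draft_swa_mask`. It is not the PR's implementation. The name `sketch_build_swa_mask`, the assumption that `kv_len = ctx_len + q_len`, the memory layout (kv index varying fastest), and the causal sliding-window visibility rule are all assumptions made for illustration, and padding to `[kv_pad, q_pad]` is left out.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// IEEE 754 binary16 bit patterns stored as raw uint16_t.
constexpr std::uint16_t kF16Zero   = 0x0000; // +0.0  -> position is visible
constexpr std::uint16_t kF16NegInf = 0xFC00; // -inf  -> position is masked

// Hypothetical sketch, NOT the PR's build_draft_swa_mask.
// Layout assumption: element (k, q) lives at out[q * kv_len + k].
// Visibility assumption: the query at absolute position p = ctx_len + q sees
// key k iff k <= p and p - k < swa_window (causal sliding window).
inline void sketch_build_swa_mask(std::vector<std::uint16_t> & out,
                                  int ctx_len, int q_len, int swa_window) {
    assert(ctx_len >= 0 && q_len > 0 && swa_window > 0);
    const int kv_len = ctx_len + q_len; // assumption: keys = context + queries
    // Start fully masked, then open the window for each query row.
    out.assign(static_cast<std::size_t>(kv_len) * q_len, kF16NegInf);
    for (int q = 0; q < q_len; ++q) {
        const int p     = ctx_len + q;
        const int lo    = p - swa_window + 1;
        const int first = lo > 0 ? lo : 0;
        for (int k = first; k <= p; ++k) {
            out[static_cast<std::size_t>(q) * kv_len + k] = kF16Zero;
        }
    }
}
```

With `ctx_len = 4`, `q_len = 2`, and `swa_window = 3`, this sketch yields a 6 x 2 mask in which the first query row exposes keys 2 to 4 and the second exposes keys 3 to 5. Storing raw `uint16_t` bit patterns avoids depending on a half-precision type on the host side.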

Validation:

  * Built and ran `smoke_draft_graph` and `test_vs_oracle` on
    RTX 6000 Ada (sm_89), Heretic Q4_K_M target, FP16 safetensors
    drafter, FA_WINDOW=0. Both tests pass before and after; the
    behaviour at ctx_len <= swa_window is unchanged (mask not
    needed and not consumed).
  * At long context the SWA layers now respect the caller mask.

Verification vs existing community PRs:

  Compatible with and complementary to (COMP-COMPL) PR Luce-Org#94 (open, "support Qwen3.6-27B-DFlash draft
  (SWA layers)", Quitetall) and PR Luce-Org#129 (open Draft, "sliding
  window attention for Qwen3.6 draft model", howard0su).

  * PR Luce-Org#94 wires SWA via masks (same family as this PR).
  * PR Luce-Org#129 wires SWA via per-layer K/V truncation instead.

  The interface added here (caller-mask field + helpers) is small
  enough that it can survive either approach landing first. If
  PR Luce-Org#94 lands first, this commit should rebase cleanly because it
  formalises the mask path Luce-Org#94 already needs internally; if
  PR Luce-Org#129 lands first, the mask path here remains useful for
  callers that prefer mask semantics. Maintainers, happy to
  coordinate ordering.

Author: Javier Pazo <xabicasa@gmail.com>
Adds a focused regression test that pins the draft-side SWA mask
wiring as a contract, independent of which SWA implementation
eventually lands.

What it tests:

  * When `draft_graph_needs_swa_mask(weights, ctx_len)` returns
    true, the graph reads `attn_mask` from `DraftGraphInputs` and
    propagates it to `ggml_flash_attn_ext` on SWA layers.
  * When SWA is not active, or `total_k <= swa_window`, the mask
    is not consumed.
  * `build_draft_swa_mask` produces the documented shape and
    values (`0.0` for visible positions and `-inf` for masked
    positions, F16 storage).

Why a separate PR:

  * Keeps "one concern per PR" per CONTRIBUTING — the wiring is
    feature, this is test.
  * The test stays useful regardless of the approach that lands
    upstream:
      - PR Luce-Org#94 (mask-style) — exercises exactly the wiring this
        test pins.
      - PR Luce-Org#129 (per-layer K/V truncation) — the contract still
        prevents regressions on the mask code path that callers
        rely on.

Depends on the companion `feat(dflash): wire caller-provided SWA
mask through draft graph` (PR to be opened in parallel). The test
needs the helpers exported in that PR's `dflash_graph.h`.

Build registration:

  Adds the test to `dflash/CMakeLists.txt` next to `test_vs_oracle`
  with the same `EXISTS`-guard pattern, so a `cmake --build ...
  --target test_draft_swa_mask_contract` builds and runs it on a
  clean checkout. No effect on other targets.

Validation:

  * `cmake --build dflash/build/Release --target
    test_draft_swa_mask_contract` succeeds on RTX 6000 Ada (sm_89),
    Windows MSVC + CUDA 12.x.
  * All assertions green: SWA-active + long-ctx case consumes the
    mask; non-SWA / short-ctx cases do not; mask helper output
    matches the documented shape and values.
  * No regressions on `smoke_draft_graph` or `test_vs_oracle`
    (untouched here).

Author: Javier Pazo <xabicasa@gmail.com>

@cubic-dev-ai cubic-dev-ai Bot left a comment

No issues found across 6 files
