
Qwen3.5 MoE support #120

Open
howard0su wants to merge 5 commits into Luce-Org:main from howard0su:moe35

Conversation

@howard0su (Contributor) commented May 7, 2026

Summary

Implements full MoE inference with expert weight swapping for the Qwen3.6-35B-A3B model (256 experts/layer, 8 active) on a single 2080 Ti. The model's expert weights (~18.6 GB) far exceed VRAM, so we use a two-graph-per-layer execution pattern with a dynamic LRU cache and optional layer pinning.

Key Changes

  1. Expert Cache (expert_cache.h/.cpp) — LRU cache with configurable slots, batched async H2D transfers, per-layer miss tracking, and eviction protection for priority layers (first 6 + last 3); see the sketch after this list
  2. Two-Graph Execution — Graph A (attention + router + top-K) → CPU reads the router results and loads any missed experts → Graph B (expert FFN + shared expert + residual); see the loop sketch after this list
  3. PinnedExperts — bulk-loads all 256 experts of selected layers into VRAM (zero cache misses on those layers). Strategy: always pin the first and last layer, then fill the remaining budget starting from layer 1
  4. DDTree for MoE — tree-structured speculative decoding wired through the MoE path (enabled but not recommended — wider batches amplify cache misses)
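
To make (1) concrete, here is a minimal sketch of the cache bookkeeping. The class name, slot assignment, and priority rule below are illustrative stand-ins, not the actual expert_cache.h API:

```cpp
#include <cstdint>
#include <functional>
#include <iterator>
#include <list>
#include <unordered_map>

struct ExpertKey {
    int32_t layer, expert;
    bool operator==(const ExpertKey &o) const {
        return layer == o.layer && expert == o.expert;
    }
};
struct ExpertKeyHash {
    size_t operator()(const ExpertKey &k) const {
        return std::hash<uint64_t>()((uint64_t(uint32_t(k.layer)) << 32) | uint32_t(k.expert));
    }
};

class ExpertLruCache {
public:
    ExpertLruCache(size_t n_slots, int n_layers) : n_slots_(n_slots), n_layers_(n_layers) {}

    // Returns the VRAM slot holding (layer, expert), or -1 on a miss that the
    // caller must service with an async H2D copy before Graph B runs.
    int lookup(int layer, int expert) {
        auto it = map_.find({layer, expert});
        if (it == map_.end()) return -1;
        lru_.splice(lru_.begin(), lru_, it->second.lru_it); // mark most recently used
        return it->second.slot;
    }

    // Called only after a miss: picks a slot, evicting the least recently used
    // expert whose layer is not protected.
    int insert(int layer, int expert) {
        int slot;
        if (map_.size() < n_slots_) {
            slot = int(map_.size()); // cache not full yet: take the next free slot
        } else {
            // Walk from the LRU end, skipping protected priority layers.
            auto victim = lru_.rbegin();
            while (victim != lru_.rend() && is_priority(victim->layer)) ++victim;
            if (victim == lru_.rend()) return -1; // everything protected
            auto vit = map_.find(*victim);
            slot = vit->second.slot;
            map_.erase(vit);
            lru_.erase(std::next(victim).base()); // reverse -> forward iterator
        }
        lru_.push_front({layer, expert});
        map_[{layer, expert}] = {slot, lru_.begin()};
        return slot;
    }

private:
    // First 6 and last 3 layers are never evicted (priority-layer protection).
    bool is_priority(int layer) const { return layer < 6 || layer >= n_layers_ - 3; }

    struct Entry { int slot; std::list<ExpertKey>::iterator lru_it; };
    size_t n_slots_;
    int n_layers_;
    std::list<ExpertKey> lru_; // front = most recently used
    std::unordered_map<ExpertKey, Entry, ExpertKeyHash> map_;
};
```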
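
And the shape of the per-layer decode step from (2). The helper names (run_graph_a, load_missing_experts, run_graph_b) are hypothetical; the real graphs are built in qwen35moe_target_graph.cpp:

```cpp
// Uses the ExpertLruCache sketch above.
struct RouterOut { int experts[8]; }; // top-8 of 256, read back to the host

// Stubs standing in for the real CUDA graph launches.
RouterOut run_graph_a(int layer) { /* attention + router + top-K */ return {}; }
void run_graph_b(int layer, const RouterOut &r) { /* expert FFN + shared expert + residual */ }

void load_missing_experts(int layer, const RouterOut &r, ExpertLruCache &cache) {
    for (int e : r.experts) {
        if (cache.lookup(layer, e) < 0) {
            int slot = cache.insert(layer, e);
            (void)slot; // enqueue an async H2D copy of expert e into this slot
        }
    }
}

void decode_layer_loop(int n_layers, ExpertLruCache &cache) {
    for (int il = 0; il < n_layers; ++il) {
        RouterOut r = run_graph_a(il);      // GPU: everything up to the router
        load_missing_experts(il, r, cache); // CPU: service the misses in one batch
        run_graph_b(il, r);                 // GPU: MoE FFN with all needed experts resident
    }
}
```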

Performance (64 tokens, budget=6, 2080 Ti 22GB)

| Configuration | tok/s | Hit Rate |
| --- | --- | --- |
| No pinning, 5000 dynamic slots | 12.9 | 79% |
| pin=16 (layers 0-14, 39), 2000 slots | 17.2 | 86% |
| pin=20 (layers 0-18, 39), 2000 slots | 18.2 | 87% |
| pin=38, slots=100 | 28 | 71% |
| DDTree budget=22, no pin | 3.1 | |

Usage

 cd dflash && cmake --build build --target test_dflash

 # Recommended (20 pinned layers):
 ./build/test_dflash \
   models/Qwen3.6-35B-A3B-UD-Q4_K_M.gguf \
   models/draft/draft-Qwen3.6-35B-A3B.gguf \
   test/prompt_test.bin 64 /dev/null \
   --cache-slots=2000 --budget=6 --pin-layers=20

 # Tune for your VRAM:
 #   --pin-layers=N   (each layer ≈ 470 MB; first+last always pinned)
 #   --cache-slots=M  (each slot ≈ 1.8 MB)
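
As a rough budget check using the numbers above: --pin-layers=20 costs about 20 × 470 MB ≈ 9.4 GB and --cache-slots=2000 about 2000 × 1.8 MB ≈ 3.6 GB, leaving roughly 9 GB of the 22 GB card for the dense weights, KV cache, and activations.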


@cubic-dev-ai (bot) left a comment


5 issues found across 14 files

Prompt for AI agents (unresolved issues)

Check if these issues are valid — if so, understand the root cause of each and fix them. If appropriate, use sub-agents to investigate and fix each issue separately.


<file name="dflash/test/smoke_moe_forward.cpp">

<violation number="1" location="dflash/test/smoke_moe_forward.cpp:188">
P1: NaN/Inf only breaks the loop; the smoke test still exits 0, hiding forward-pass corruption.</violation>
</file>

<file name="dflash/src/qwen35moe_target_graph.cpp">

<violation number="1" location="dflash/src/qwen35moe_target_graph.cpp:30">
P2: Mutable global debug/timing state is updated from the forward path without synchronization, so concurrent inference can race and trigger UB.</violation>

<violation number="2" location="dflash/src/qwen35moe_target_graph.cpp:309">
P2: Graph A CUDA-debug error path returns without ggml_free(ctx_a), causing a context leak.</violation>
</file>

<file name="dflash/src/gguf_target_loader.cpp">

<violation number="1" location="dflash/src/gguf_target_loader.cpp:681">
P2: MoE expert-source setup should validate tensor shapes/strides against the metadata before trusting `nb[2]` and expert counts.</violation>
</file>

<file name="dflash/test/test_dflash.cpp">

<violation number="1" location="dflash/test/test_dflash.cpp:3121">
P1: MoE path ignores the draft backend and computes the draft graph on the target backend, breaking split target/draft GPU setups.</violation>
</file>
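
For the first P1 finding, the fix shape is simple: the scan result must reach the process exit status. A minimal sketch (illustrative names, not the PR's code):

```cpp
#include <cmath>
#include <cstdio>

// Sketch: a non-finite logit must fail the smoke test's exit status,
// not just end the scan early.
int check_logits(const float *logits, int n) {
    for (int i = 0; i < n; ++i) {
        if (!std::isfinite(logits[i])) {
            std::fprintf(stderr, "non-finite logit at index %d\n", i);
            return 1; // propagate as a failing exit code
        }
    }
    return 0;
}
```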

Reply with feedback, questions, or to request a fix. Tag @cubic-dev-ai to re-run a review.

- Add MoE expert weight management (moe_experts.h/cpp): MoeExpertSource
  for mmap layout, PinnedExperts for bulk H2D VRAM pinning
- Add fused mega-graph MoE forward pass (qwen35moe_target_graph.cpp):
  single CUDA graph per (n_tokens, n_pinned) tuple, all 40 layers + lm_head
- Add shared attention block builders (qwen35_blocks.h): full-attention,
  deltanet, and SwiGLU FFN used by both dense and MoE graph paths
- Extend GGUF loader for MoE architecture fields (n_experts, expert_ffn_dim)
- Add MoE generation harness (test_dflash_moe.cpp): budget=1 pure
  autoregressive, budget>1 speculative decode with DDTree support
- Extract shared test helpers (test_helpers.h): DDTree, causal mask,
  top-K extraction reused by both dense and MoE test paths
- No expert swapping — all weights pinned in VRAM (fits 22.5GB card)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
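
The "single CUDA graph per (n_tokens, n_pinned) tuple" amounts to memoizing compiled graphs by shape. A minimal sketch, with illustrative types (the real code builds and captures ggml/CUDA graphs):

```cpp
#include <map>
#include <utility>

struct CompiledGraph { /* captured graph + I/O tensor handles */ };

std::map<std::pair<int, int>, CompiledGraph> graph_cache;

CompiledGraph &get_graph(int n_tokens, int n_pinned) {
    auto key = std::make_pair(n_tokens, n_pinned);
    auto it = graph_cache.find(key);
    if (it == graph_cache.end()) {
        // First use of this shape: build and capture once, replay afterwards.
        it = graph_cache.emplace(key, CompiledGraph{}).first;
    }
    return it->second;
}
```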
@howard0su changed the title from "MoE Expert Swapping for Qwen3.6-35B-A3B on 2080 Ti (22GB)" to "Qwen3.5 MoE support" on May 7, 2026
howard0su and others added 4 commits May 8, 2026 00:59
Replace the per-expert view+add loop (14 nodes/layer) with a single
ggml_repeat_back operation that sums along the expert dimension.

Reduces CUDA graph nodes from 3358 to 2838 (-520 nodes, -15.5%).
Steady-state decode: 12ms → 10.5ms/token (83 → 86+ tok/s).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
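
A sketch of the change, assuming expert outputs laid out as [n_embd, n_expert_used, n_tokens] (shapes illustrative). ggml_repeat_back sums the source into the smaller target shape, replacing the per-expert view + add chain:

```cpp
#include "ggml.h"

// expert_out: [n_embd, n_expert_used, n_tokens]
struct ggml_tensor *sum_expert_outputs(struct ggml_context *ctx,
                                       struct ggml_tensor *expert_out) {
    // Target shape [n_embd, 1, n_tokens]: repeat_back reduces (sums) any
    // dimension where the target is smaller than the source.
    struct ggml_tensor *target = ggml_new_tensor_3d(
        ctx, expert_out->type, expert_out->ne[0], 1, expert_out->ne[2]);
    return ggml_repeat_back(ctx, expert_out, target);
}
```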
Remove unnecessary ggml_cont and reorder graph nodes to enable
ggml-cuda's built-in topk_moe fusion pattern:
  softmax → reshape → argsort → view → get_rows → norm chain

Also enables mul_mat_id_glu fusion (gate+up+swiglu → 1 kernel).

Results:
- Fusion confirmed: 90+ tok/s with fusion vs 79 tok/s without
- Steady-state: ~9.8ms/token (was 10.5ms)
- CUDA graph nodes: 2798 (reduced from 2838)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
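
The node order, following the llama.cpp-style MoE build (a sketch; the ggml calls are real API, though the exact pattern the fuser matches may differ):

```cpp
#include "ggml.h"

// router_logits: [n_expert, n_tokens]
void build_topk_chain(struct ggml_context *ctx, struct ggml_tensor *router_logits,
                      int n_expert, int n_tokens, int n_expert_used) {
    struct ggml_tensor *probs = ggml_soft_max(ctx, router_logits);        // softmax
    struct ggml_tensor *selected = ggml_top_k(ctx, probs, n_expert_used); // argsort + view
    struct ggml_tensor *weights = ggml_get_rows(                          // get_rows
        ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected);
    (void)weights; // the weight-normalization step and the MoE FFN consume these
}
```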
The draft model (z-lab/Qwen3.6-35B-A3B-DFlash) was trained with YaRN
rope scaling (factor=64, beta_fast=32, beta_slow=1) but we were using
vanilla RoPE (ext_factor=0). This caused 81% of frequency dimensions
to have incorrect position encodings, resulting in 0% acceptance rate.

Fix: compute YaRN freq_factors per the HuggingFace config and pass
them to ggml_rope_ext via a new optional rope_freq_factors field in
DraftGraphInputs.

Results:
- Acceptance rate: 1/16 → 12.25/16 (76.6% avg, 100% on most steps)
- Spec-decode throughput: 5 tok/s → 24 tok/s
- Autoregressive baseline unchanged at 91 tok/s

Also removes debug instrumentation from previous investigation.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
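
For reference, a sketch of the YaRN freq_factors computation mirroring the HuggingFace formula, with factor, beta_fast, and beta_slow as above (names illustrative; the PR feeds the result to ggml_rope_ext via DraftGraphInputs):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Dimension index at which a RoPE channel completes n_rot rotations
// over the original context length (YaRN correction dim).
static float correction_dim(float n_rot, float head_dim, float base, float n_ctx_orig) {
    const float PI = 3.14159265358979f;
    return head_dim * std::log(n_ctx_orig / (n_rot * 2.0f * PI)) / (2.0f * std::log(base));
}

std::vector<float> yarn_freq_factors(int head_dim, float base, float factor,
                                     float beta_fast, float beta_slow, float n_ctx_orig) {
    float low  = std::max(std::floor(correction_dim(beta_fast, float(head_dim), base, n_ctx_orig)), 0.0f);
    float high = std::min(std::ceil (correction_dim(beta_slow, float(head_dim), base, n_ctx_orig)),
                          float(head_dim - 1));
    if (high <= low) high = low + 0.001f; // avoid division by zero

    std::vector<float> ff(head_dim / 2);
    for (int i = 0; i < head_dim / 2; ++i) {
        // ramp 0 -> pure extrapolation (high-frequency dims, left untouched);
        // ramp 1 -> pure interpolation (low-frequency dims, divided by factor)
        float ramp = std::clamp((float(i) - low) / (high - low), 0.0f, 1.0f);
        // ggml's rope divides theta by the freq factor, so extrapolated dims
        // get 1 and interpolated dims get `factor`.
        ff[i] = 1.0f / ((1.0f - ramp) + ramp / factor);
    }
    return ff;
}
```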
