perf(npu): route dispatched scalar div/sub/pow through cached-scalar kernel#566

Merged
lvyufeng merged 1 commit into
candle-org:mainfrom
lvyufeng:npu-qwen2-backward-fast
Jun 14, 2026
Conversation

@lvyufeng

Contributor

Summary

Dispatched NPU div(tensor, scalar), sub(tensor, scalar), and pow(tensor, scalar) materialized a fresh full-shape scalar tensor on every call via _scalar_to_npu_tensor (host malloc + byte-pack + H2D memcpy). add/mul already avoid this by reusing the persistent cached 0-d scalar tensor and the tensor-tensor exact kernels. This PR makes div/sub/pow do the same.

Why

Profiling the tiny-Qwen2 train step (910B3, fp16, backward-only cProfile) showed _scalar_to_npu_tensor as the #2 backward tottime row (~0.013s, 50 calls/iter):

  • 25 from MeanBackward0 (RMSNorm mean → div by numel)
  • 25 from PowTensorScalarBackward0 (RMSNorm backward → pow(x, 1.0))

Each call did a per-call host malloc + dtype byte-pack + H2D memcpy to fill a full-shape scalar tensor.
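To make the per-call cost concrete, here is a minimal host-side sketch of the byte-pack step only. It assumes the slow path fills a host buffer with one fp16 element per output element, and the function names and the 4096-element size are hypothetical; the real _scalar_to_npu_tensor may differ.

```python
import struct


def pack_full_shape(value: float, numel: int) -> bytes:
    """Byte-pack `value` as little-endian fp16 once per element (assumed slow-path shape)."""
    return struct.pack(f"<{numel}e", *([value] * numel))


def pack_zero_dim(value: float) -> bytes:
    """Byte-pack a single fp16 value, as a 0-d scalar tensor would need."""
    return struct.pack("<e", value)


numel = 4096  # hypothetical element count, not taken from the profile
full = pack_full_shape(0.5, numel)
scalar = pack_zero_dim(0.5)

# The full-shape buffer scales with numel and is rebuilt on every call;
# the 0-d buffer is two bytes and can be built once and reused.
print(len(full), len(scalar))  # 8192 2
```

The sketch leaves out the host malloc and the H2D memcpy, which also scale with the buffer size on the slow path.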

Change

  • math.py: div/sub scalar operands now route through the existing fast_div_scalar_exact / fast_sub_scalar_exact Cython kernels (cached 0-d scalar + tensor-tensor exact kernel) before falling back to _scalar_to_npu_tensor.
  • _npu_ops.pyx: fast_pow_tensor_scalar reuses the cached 0-d scalar tensor outside graph capture; the per-call materialization is preserved during capture so no new cached-fill H2D traffic is recorded into the graph.
  • The slow _scalar_to_npu_tensor path is preserved as a guarded fallback (kernel entry point not deleted).

Impact (tiny-Qwen2, 910B3, fp16, backward-only profile)

  • _scalar_to_npu_tensor removed entirely from backward hot rows.
  • Backward function calls: 63692 → 58858 (−4834/iter).
  • Backward host profile time: −9% (0.173s → 0.157s over 5 iters).
  • Wall-clock backward median: ~38.6ms → ~37.0ms.
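As a quick arithmetic check, the deltas can be recomputed from the figures listed above. The variable names are illustrative; the inputs are copied from the bullets.

```python
# Figures copied from the Impact bullets.
calls_before, calls_after = 63692, 58858
host_before_s, host_after_s = 0.173, 0.157
wall_before_ms, wall_after_ms = 38.6, 37.0

call_delta = calls_before - calls_after
host_reduction = (host_before_s - host_after_s) / host_before_s
wall_reduction = (wall_before_ms - wall_after_ms) / wall_before_ms

print(call_delta)               # 4834
print(f"{host_reduction:.1%}")  # 9.2%
print(f"{wall_reduction:.1%}")  # 4.1%
```

The host-profile reduction rounds to the stated 9%; the wall-clock median improves by roughly 4%.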

Correctness

Behavior-preserving: the cached scalar carries the same dtype as a (matching the previous full-shape materialization) and broadcasts via the same tensor-tensor kernels.

  • Qwen2 full-model parity unchanged: logits max_abs_diff 0.0, max_rel_diff 0.0, failures: [].
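For readers unfamiliar with the parity metrics, one common way to compute them is sketched below. The exact definitions used by the parity script are not shown in this PR, so the eps guard and the function names are assumptions.

```python
from typing import Sequence


def max_abs_diff(ref: Sequence[float], out: Sequence[float]) -> float:
    """Largest elementwise absolute difference."""
    return max(abs(r - o) for r, o in zip(ref, out))


def max_rel_diff(ref: Sequence[float], out: Sequence[float], eps: float = 1e-12) -> float:
    """Largest elementwise difference relative to |ref|, guarded against division by zero."""
    return max(abs(r - o) / max(abs(r), eps) for r, o in zip(ref, out))


baseline = [0.25, -1.5, 3.0]  # illustrative values, not real logits
candidate = [0.25, -1.5, 3.0]
print(max_abs_diff(baseline, candidate), max_rel_diff(baseline, candidate))  # 0.0 0.0
```

Both metrics are exactly 0.0 only when the outputs match elementwise, which is the outcome the parity bullet reports.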

Tests

  • 5 new tests in tests/npu/cython/test_dispatched_scalar_fast_path.py (RED→green): assert dispatched div/sub/pow scalar ops do not hit _scalar_to_npu_tensor, plus fp16/fp32 correctness and an end-to-end RMSNorm-shaped backward.
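The "does not hit the slow path" style of assertion can be sketched with unittest.mock. The namespace and functions below are hypothetical stand-ins so the example runs without an NPU; the actual tests in test_dispatched_scalar_fast_path.py may be structured differently.

```python
import types
from unittest import mock

# Hypothetical stand-in for the NPU ops module (not the real candle module).
ops = types.SimpleNamespace()
ops._scalar_to_npu_tensor = lambda value, numel: [float(value)] * numel
ops.fast_div_scalar_exact = lambda a, s: [x / s for x in a]


def div(a, b):
    """Dispatch stand-in: scalars take the fast path, tensors the general one."""
    if isinstance(b, (int, float)):
        return ops.fast_div_scalar_exact(a, b)
    return [x / y for x, y in zip(a, b)]


def test_div_scalar_skips_slow_path():
    # Replace the slow path with a mock that fails loudly if it is ever called.
    with mock.patch.object(
        ops, "_scalar_to_npu_tensor", side_effect=AssertionError("slow path hit")
    ) as slow:
        out = div([2.0, 4.0, 6.0], 2)
    slow.assert_not_called()
    assert out == [1.0, 2.0, 3.0]


test_div_scalar_skips_slow_path()
```

Written before the change, a test of this shape fails (RED) because the dispatched op calls the patched function; it turns green once the fast path is wired in.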

Validation

  • Full NPU suite: 608 passed, 25 skipped
  • CPU + contract: 3577 passed, 30 skipped
  • pylint: 10.00/10

🤖 Generated with Claude Code

…kernel

Dispatched NPU div(tensor, scalar), sub(tensor, scalar), and pow(tensor,
scalar) materialized a fresh full-shape scalar tensor on every call via
_scalar_to_npu_tensor (host malloc + byte-pack + H2D memcpy). add/mul already
avoid this by reusing the persistent cached 0-d scalar tensor and the
tensor-tensor exact kernels; div/sub/pow now do the same.

Profiling the tiny-Qwen2 train step (910B3, fp16) showed _scalar_to_npu_tensor
as the #2 backward tottime row (~0.013s, 50 calls/iter: 25 from MeanBackward0's
div-by-numel, 25 from PowTensorScalarBackward0's pow(x, 1.0)). Routing these
through the existing fast_div_scalar_exact / fast_sub_scalar_exact kernels and
the cached 0-d scalar in fast_pow_tensor_scalar removes that row entirely:
backward function calls drop 63692 -> 58858 (-4834/iter) and backward host time
drops ~9%. The slow _scalar_to_npu_tensor path is preserved as a guarded
fallback, and the graph-capture path for pow is unchanged.

Behavior-preserving: the cached scalar carries the same dtype as `a` (matching
the previous full-shape materialization) and broadcasts via the same
tensor-tensor kernels. Qwen2 full-model parity unchanged (logits max_abs_diff
0.0). Validated: 5 new RED->green scalar fast-path tests, full NPU suite (608
passed), CPU+contract (3577 passed), pylint 10.00/10.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@github-actions

Candle AI Review

Scope

  • Files changed: 3 | Lines: +137 / −0
  • Areas: backend/npu, dispatch, tests

Findings

  • [WARNING] src/candle/_C/_npu_ops.pyx:5730-5733 — The pow fast path requires exponent to be a Python scalar for _get_cached_scalar_tensor, but no isinstance check is performed on exponent unlike math.py's div/sub path. If a Tensor scalar leaks here, behavior diverges from the fallback. Minor, but asymmetric vs. the math.py guarding.
  • [INFO] src/candle/_C/_npu_ops.pyx:5732 — Branch checks (<TensorImpl>a)._device_type == 1 (NPU) but _npu_scalar_like already runs an ACLNN kernel on NPU; this is an optimization, not a fallback correctness issue. Acceptable, but _npu_in_graph_capture guard means behavior differs in capture vs. eager — worth confirming fast_pow tensor-tensor kernel is capture-safe (no host syncs/allocations).
  • [INFO] src/candle/_backends/npu/ops/math.py:325-340 — div/sub fast-path scalars skip the _HAS_FAST_DIV/_HAS_FAST_SUB gate entirely. The new _fast_div_scalar_impl/_fast_sub_scalar_impl are independent of those flags, which is fine since they have their own availability check (is not None). No invariant violation.
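A symmetric guard for the pow fast path, as the first finding suggests, might look like the sketch below. The helper names and the choice to exclude bool are assumptions for illustration; none of this is the code in _npu_ops.pyx.

```python
import numbers


def is_python_scalar(x) -> bool:
    """True for int/float-like Python numbers; bool is excluded (an assumption)."""
    return isinstance(x, numbers.Real) and not isinstance(x, bool)


def choose_pow_path(exponent, on_npu: bool, in_graph_capture: bool) -> str:
    """Pick the cached-scalar path only in eager mode, on NPU, for Python scalars."""
    if on_npu and not in_graph_capture and is_python_scalar(exponent):
        return "cached_scalar"
    return "per_call_materialization"


class FakeTensor:
    """Hypothetical stand-in for a tensor-typed exponent leaking into the fast path."""


print(choose_pow_path(1.0, on_npu=True, in_graph_capture=False))
print(choose_pow_path(1.0, on_npu=True, in_graph_capture=True))
print(choose_pow_path(FakeTensor(), on_npu=True, in_graph_capture=False))
```

With such a guard, a non-scalar exponent takes the same route as the fallback instead of reaching the cached-scalar lookup, and the capture case keeps its per-call materialization.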

No CPU fallback, no PyTorch imports in source, schema dispatch unchanged, API surface unchanged, and the changes are generic (apply to any model using scalar div/sub/pow, not Qwen2-specific).

PR Completeness

  • Linked issue: no
  • Test plan: filled (5 tests covering correctness + slow-path avoidance across dtypes + RMSNorm-shaped backward end-to-end)
  • Template: missing (no checklist section, no risk assessment in body)

Risk: Low

Behavior-preserving: the scalar ops are routed through the existing cached-scalar + tensor-tensor kernels, the slow path is preserved as a fallback, and parity is verified on Qwen2 logits. The only concern is a minor asymmetry in the pow fast-path type guarding.


Full review at commit e0f68e9 (2026-06-14 08:37 UTC)

@lvyufeng lvyufeng merged commit ef1adf1 into candle-org:main Jun 14, 2026
17 of 28 checks passed