perf(npu): route dispatched scalar div/sub/pow through cached-scalar kernel#566
Merged
Merged
Conversation
…kernel Dispatched NPU div(tensor, scalar), sub(tensor, scalar), and pow(tensor, scalar) materialized a fresh full-shape scalar tensor on every call via _scalar_to_npu_tensor (host malloc + byte-pack + H2D memcpy). add/mul already avoid this by reusing the persistent cached 0-d scalar tensor and the tensor-tensor exact kernels; div/sub/pow now do the same. Profiling the tiny-Qwen2 train step (910B3, fp16) showed _scalar_to_npu_tensor as the #2 backward tottime row (~0.013s, 50 calls/iter: 25 from MeanBackward0's div-by-numel, 25 from PowTensorScalarBackward0's pow(x, 1.0)). Routing these through the existing fast_div_scalar_exact / fast_sub_scalar_exact kernels and the cached 0-d scalar in fast_pow_tensor_scalar removes that row entirely: backward function calls drop 63692 -> 58858 (-4834/iter) and backward host time drops ~9%. The slow _scalar_to_npu_tensor path is preserved as a guarded fallback, and the graph-capture path for pow is unchanged. Behavior-preserving: the cached scalar carries the same dtype as `a` (matching the previous full-shape materialization) and broadcasts via the same tensor-tensor kernels. Qwen2 full-model parity unchanged (logits max_abs_diff 0.0). Validated: 5 new RED->green scalar fast-path tests, full NPU suite (608 passed), CPU+contract (3577 passed), pylint 10.00/10. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Candle AI ReviewScope
Findings
No CPU fallback, no PyTorch imports in source, schema dispatch unchanged, API surface unchanged, and the changes are generic (apply to any model using scalar div/sub/pow, not Qwen2-specific). PR Completeness
Risk: LowBehavior-preserving routed through existing cached-scalar + tensor-tensor kernels with the slow path preserved as fallback and parity verified on Qwen2 logits; only minor asymmetry in the pow fast-path type guarding. Full review at commit e0f68e9 (2026-06-14 08:37 UTC) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Dispatched NPU
div(tensor, scalar),sub(tensor, scalar), andpow(tensor, scalar)materialized a fresh full-shape scalar tensor on every call via_scalar_to_npu_tensor(host malloc + byte-pack + H2D memcpy).add/mulalready avoid this by reusing the persistent cached 0-d scalar tensor and the tensor-tensor exact kernels. This PR makesdiv/sub/powdo the same.Why
Profiling the tiny-Qwen2 train step (910B3, fp16, backward-only cProfile) showed
_scalar_to_npu_tensoras the #2 backwardtottimerow (~0.013s, 50 calls/iter):MeanBackward0(RMSNormmean→ div bynumel)PowTensorScalarBackward0(RMSNormx²backward →pow(x, 1.0))Each call did a per-call host malloc + dtype byte-pack + H2D memcpy to fill a full-shape scalar tensor.
Change
math.py:div/subscalar operands now route through the existingfast_div_scalar_exact/fast_sub_scalar_exactCython kernels (cached 0-d scalar + tensor-tensor exact kernel) before falling back to_scalar_to_npu_tensor._npu_ops.pyx:fast_pow_tensor_scalarreuses the cached 0-d scalar tensor outside graph capture; the per-call materialization is preserved during capture so no new cached-fill H2D traffic is recorded into the graph._scalar_to_npu_tensorpath is preserved as a guarded fallback (kernel entry point not deleted).Impact (tiny-Qwen2, 910B3, fp16, backward-only profile)
_scalar_to_npu_tensorremoved entirely from backward hot rows.Correctness
Behavior-preserving: the cached scalar carries the same dtype as
a(matching the previous full-shape materialization) and broadcasts via the same tensor-tensor kernels.max_abs_diff 0.0,max_rel_diff 0.0,failures: [].Tests
tests/npu/cython/test_dispatched_scalar_fast_path.py(RED→green): assert dispatcheddiv/sub/powscalar ops do not hit_scalar_to_npu_tensor, plus fp16/fp32 correctness and an end-to-end RMSNorm-shaped backward.Validation
🤖 Generated with Claude Code