fix: support N-D weights with unit leading dims in MatMulNBitsQuantizer #28178
MatMulNBitsQuantizer silently skipped quantizing MatMul nodes whose weight tensor had more than 2 dimensions (e.g. [1, K, N]), logging "MatMul weight is not 2D. Skip to quantize." This blocked quantization of LLM attention projection weights pre-reshaped to 3-D.
When all leading dimensions of the weight are 1, squeeze them to get a 2-D [K, N] view, run the existing quantization logic, then insert a Reshape node after MatMulNBits (or the final MatMul in QDQ format) to restore the expected N-D output shape dictated by ONNX MatMul batch broadcasting semantics ([b1, ..., batch_A, N]).
For weights with genuinely non-unit batch dimensions (true batched matmuls), the skip is preserved with a clearer log message explaining why quantization cannot proceed safely.
Fixes: both HQQWeightOnlyQuantizer.quantize() and DefaultWeightOnlyQuantizer.quantize_matmul() (QOperator and QDQ paths).
Closes microsoft#25362
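The eligibility rule in this commit message (squeeze only when every leading dimension is 1, otherwise keep skipping) can be sketched on shapes alone. This is a minimal illustration, not the quantizer's code: `squeezable_2d_shape` is a hypothetical name, and the sketch also assumes weights of rank below 2 are skipped.

```python
from typing import Optional, Sequence, Tuple


def squeezable_2d_shape(weight_shape: Sequence[int]) -> Optional[Tuple[int, int]]:
    """Return (K, N) if the weight can be viewed as 2-D, else None.

    Hypothetical helper for illustration; the real quantizer may differ.
    """
    if len(weight_shape) < 2:
        return None  # assumed: weights below rank 2 are skipped
    leading = weight_shape[:-2]
    if any(dim != 1 for dim in leading):
        return None  # true batched matmul: quantization is skipped
    return (weight_shape[-2], weight_shape[-1])


print(squeezable_2d_shape([52, 288]))      # (52, 288): already 2-D
print(squeezable_2d_shape([1, 52, 288]))   # (52, 288): unit leading dim is squeezed
print(squeezable_2d_shape([4, 52, 288]))   # None: real batch dim, skip
```

Working on the shape rather than the tensor keeps the check cheap; the actual squeeze of the initializer data would only happen once this check passes.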
Pull request overview
Enables MatMulNBitsQuantizer to quantize MatMul nodes whose weight initializers are N-D tensors with unit leading (“batch”) dimensions by squeezing weights to 2-D for quantization and attempting to restore the original MatMul output shape.
Changes:
- Update HQQ + Default quantization paths to accept N-D weights with all leading dims == 1 (skip otherwise with a more specific log message).
- Insert a `Reshape` after the quantized op (and after QDQ MatMul) to restore an output shape derived from the original weight shape.
- Add unit tests to validate quantization and presence of the inserted `Reshape` node for 3D weights in Default (QOperator/QDQ) and HQQ paths.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py | Adds N-D unit-leading-dim weight support by squeezing weights and inserting output Reshape nodes in HQQ + Default (QOperator/QDQ) paths. |
| onnxruntime/test/python/quantization/test_op_matmul_4bits.py | Adds tests and test helpers for 3D weight MatMul models; adds extra_quant_nodes plumbing to assert Reshape insertion. |
…ut reshape The previous post-MatMulNBits Reshape used [*b_original_shape[:-2], -1, N] as the target. For batched activations (e.g., A=[B,M,K], B=[1,K,N]) this collapses A's batch dims into the -1 sentinel and produces [1, B*M, N] instead of the ONNX broadcast result [B, M, N]. Replace the static target shape with a small dynamic graph (Shape, Size, Sub, Max, ConstantOfShape, Slice, Concat → Reshape) that prepends only the leading 1s required by max(rank(B_orig) - rank(A), 0) and preserves A.shape[:-1] as-is. This is correct for arbitrary activation rank; in particular the rank(A) == 2 case still produces [1, M, N] and the new rank(A) == 3 case produces [B, S, N]. Wraps the construction in a single helper used at all three call sites (QOperator MatMulNBits, QDQ MatMulNBits, QDQ DQ→MatMul). Adds a regression test that uses a 3-D activation [B, S, K] with weight [1, K, N] and asserts both the op set and full-graph numerical correctness.
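The collapse described in this commit message can be reproduced with shape arithmetic only. The sketch below is an illustration under stated assumptions: `resolve_reshape` and `static_target` are hypothetical names, and the example sizes (B=2, M=3, K=52, N=288) are chosen for the demonstration.

```python
from typing import List, Sequence


def resolve_reshape(input_shape: Sequence[int], target: Sequence[int]) -> List[int]:
    """Resolve a single -1 in a Reshape target from the input's element count."""
    total = 1
    for d in input_shape:
        total *= d
    known = 1
    for d in target:
        if d != -1:
            known *= d
    return [total // known if d == -1 else d for d in target]


def static_target(b_original_shape: Sequence[int]) -> List[int]:
    """The earlier static target: [*b_original_shape[:-2], -1, N]."""
    return [*b_original_shape[:-2], -1, b_original_shape[-1]]


B, M, K, N = 2, 3, 52, 288
quantized_output = [B, M, N]  # output shape for A=[B, M, K] against the squeezed 2-D weight
collapsed = resolve_reshape(quantized_output, static_target([1, K, N]))
print(collapsed)  # [1, 6, 288]: A's batch dims were folded into the -1 slot
```

The expected broadcast result is `[2, 3, 288]`, so a target that depends only on the weight shape cannot be right for batched activations.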
tianleiwu
left a comment
Review Summary
The dynamic reshape helper (_build_nbits_output_reshape) correctly implements the ONNX MatMul broadcast rule:
out_shape = [1] * max(rank(B_orig) - rank(A), 0) + A.shape[:-1] + [N]
using Shape/Size/Sub/Max/ConstantOfShape/Slice/Concat — making it work for any activation rank without static knowledge at quantization time. Previous review concerns about the static reshape have been fully addressed.
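As a rough illustration of the rule quoted in this review, the target shape can be computed in plain Python, with each step annotated by the ONNX op it loosely corresponds to. The mapping of steps to ops is an assumption made for the sketch, and `broadcast_out_shape` is a hypothetical name rather than the helper in the PR.

```python
from typing import List, Sequence


def broadcast_out_shape(a_shape: Sequence[int], b_orig_shape: Sequence[int]) -> List[int]:
    """Emulate out_shape = [1] * max(rank(B_orig) - rank(A), 0) + A.shape[:-1] + [N]."""
    rank_a = len(a_shape)                   # Shape(A) followed by Size
    rank_b = len(b_orig_shape)              # known from the weight initializer
    extra_count = max(rank_b - rank_a, 0)   # Sub, then Max with 0
    leading_ones = [1] * extra_count        # ConstantOfShape; empty when extra_count == 0
    a_prefix = list(a_shape[:-1])           # Slice: every dim of A except the last
    n = [b_orig_shape[-1]]                  # constant N taken from the weight
    return leading_ones + a_prefix + n      # Concat, used as the Reshape target


print(broadcast_out_shape([3, 52], [1, 52, 288]))      # [1, 3, 288]
print(broadcast_out_shape([2, 3, 52], [1, 52, 288]))   # [2, 3, 288]
```

In the second call `extra_count` is 0, so the leading-ones list is empty and the target equals the quantized op's own output shape.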
Positives:
- Clean separation into a reusable module-level helper shared by both HQQ and Default paths
- Correct handling of the `extra_count = 0` edge case (rank(A) >= rank(B_orig)) where ConstantOfShape produces an empty tensor
- Good test coverage including the critical 3-D activation case that validates batch dims aren't flattened
- End-to-end correctness via `check_model_correctness` ensures numerical accuracy is maintained
Minor suggestions (non-blocking):
- The existing lintrunner RUF012 finding on `_RESHAPE_HELPER_OP_COUNTS` should be addressed (annotate with `typing.ClassVar`)
- When `rank(A) >= rank(B_orig)`, the reshape chain produces a no-op (target shape == output shape). A short-circuit when A's static rank is known could avoid 11 extra nodes in the unoptimized graph, but ONNX optimizers should handle this fine
…A) >= rank(B_orig)
When the activation rank is statically known to be >= the original weight rank, the post-MatMulNBits Reshape (and its dynamic-shape helper ops) is a no-op since MatMulNBits already produces the correct output shape.
- Add `_get_static_rank` helper to read a tensor's static rank from the graph stack (input/value_info/output).
- Apply the same `needs_reshape` guard at all three quantize_matmul call sites (HQQ, Default QOperator, Default QDQ); fall back to emitting the reshape when the rank is unknown or smaller than rank_b_orig.
- Annotate `_RESHAPE_HELPER_OP_COUNTS` with `typing.ClassVar` (RUF012).
- Update test_quantize_matmul_int4_3d_weight_3d_activation_preserves_shape to assert the new optimized op-count (no reshape helpers); other 3-D weight tests still use 2-D activations and continue to exercise the reshape path.
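The guard described in this commit can be sketched as a small predicate. This is an illustration only: the function signature is an assumption, and the static rank is passed in directly as `None` when unknown instead of being read from a graph by `_get_static_rank`.

```python
from typing import Optional


def needs_reshape(static_rank_a: Optional[int], rank_b_orig: int) -> bool:
    """Decide whether the output Reshape chain must be emitted.

    Hypothetical signature: static_rank_a is None when the activation's
    rank cannot be determined at quantization time.
    """
    if static_rank_a is None:
        return True  # unknown rank: fall back to emitting the reshape
    return static_rank_a < rank_b_orig  # elide when rank(A) >= rank(B_orig)


print(needs_reshape(3, 3))     # False: 3-D activation with 3-D weight, reshape elided
print(needs_reshape(2, 3))     # True: 2-D activation needs a leading 1 restored
print(needs_reshape(None, 3))  # True: rank unknown, stay conservative
```

Defaulting to `True` when the rank is unknown keeps the output shape correct at the cost of extra nodes in the graph.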
Thanks for the review. Pushed a follow-up addressing both nits in 3faf2f5:
Updated Targeted tests:
Pull request overview
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
ONNX MatMul promotes a 1-D activation A to rank-2 before computing the output shape, so deriving the prepended-ones count from raw rank(A) produced wrong shapes for 1-D activations (e.g. MatMul([K], [1,K,N]) reshaped to [1,1,N] instead of [1,N]). Compute a_rank_eff = Max(rank(A), 2) and use it when computing the extra-ones count. Tighten the 3-D-weight/3-D-activation test to assert all reshape-helper ops are absent in the elision path, and add a new 3-D-weight/1-D-activation regression test that exercises the helper and verifies output-shape preservation.
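The effect of the `Max(rank(A), 2)` correction can be seen by computing the reshape target with and without it. The sketch below is a shape-only illustration; `reshape_target` and its `promote_1d` flag are hypothetical names introduced for the comparison.

```python
from typing import List, Sequence


def reshape_target(a_shape: Sequence[int], b_orig_shape: Sequence[int], promote_1d: bool) -> List[int]:
    """Compute the reshape target, optionally using a_rank_eff = Max(rank(A), 2)."""
    rank_a = len(a_shape)
    rank_eff = max(rank_a, 2) if promote_1d else rank_a
    extra = max(len(b_orig_shape) - rank_eff, 0)  # number of leading 1s to prepend
    return [1] * extra + list(a_shape[:-1]) + [b_orig_shape[-1]]


K, N = 52, 288
print(reshape_target([K], [1, K, N], promote_1d=False))    # [1, 1, 288]: the wrong shape
print(reshape_target([K], [1, K, N], promote_1d=True))     # [1, 288]: matches MatMul
print(reshape_target([3, K], [1, K, N], promote_1d=True))  # [1, 3, 288]: 2-D case unchanged
```

For activations of rank 2 or higher the effective rank equals the raw rank, so the correction only changes the 1-D case.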
Thanks for the review. Pushed 896b393 addressing both items:
Also added
tianleiwu
left a comment
Review Summary
The dynamic reshape helper (_build_nbits_output_reshape) correctly implements the ONNX MatMul broadcast rule and properly handles 1-D, 2-D, and higher-rank activations. The Max(a_rank, 2) effectively accounts for ONNX MatMul's 1-D promotion semantics, and the static-rank elision avoids inserting dead nodes for the dominant transformer case.
All previously raised concerns (1-D activation bug, batch-dim flattening, no-op reshape elision, test assertion gaps) have been addressed in the latest commits. Test coverage is comprehensive across QOperator, QDQ, and HQQ paths with 2-D, 3-D, and 1-D activation shapes.
Two minor suggestions below — neither is blocking.
…ertion
- Fix stale comment in DefaultWeightOnlyQuantizer QDQ path: the no-op reshape skip happens on the standard MatMul node (post-DequantizeLinear), not on MatMulNBits.
- Add "Reshape": 0 to no_reshape_helpers so the test explicitly asserts no Reshape node is inserted when the helper chain should be absent.
Addressed both nits in d68f159:
Targeted
Address review feedback on MatMulNBits output-reshape helper:
- Replace Size(a_shape) with Shape(a_shape) + Gather(idx=0). The Size op on a shape tensor requires ai.onnx opset >= 13, but MatMulNBitsQuantizer does not bump the model's opset, so opset 11/12 inputs would produce an invalid graph. Shape and Gather are valid from opset 1/11 and yield the same scalar int64 rank value.
- Incorporate the unique output tensor name (node.output[0]) into the initializer name prefix. The prior prefix was node.name, which is not guaranteed unique across a model and would collide when multiple MatMul nodes share a name. node.output[0] is unique per the ONNX spec.
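Why the two formulations yield the same rank value can be shown with a small emulation, where a Python list stands in for the 1-D shape tensor. This is only a sketch of the equivalence; both function names are hypothetical and no ONNX graph is built.

```python
from typing import Sequence


def rank_via_size(a_dims: Sequence[int]) -> int:
    """Rank as Size(Shape(A)): the element count of the 1-D shape tensor."""
    a_shape = list(a_dims)  # Shape(A): a 1-D tensor holding A's dims
    return len(a_shape)     # Size(a_shape)


def rank_via_shape_gather(a_dims: Sequence[int]) -> int:
    """Rank as Gather(Shape(Shape(A)), 0)."""
    a_shape = list(a_dims)           # Shape(A)
    shape_of_shape = [len(a_shape)]  # Shape(a_shape): a one-element tensor [rank]
    return shape_of_shape[0]         # Gather with index 0 extracts the scalar rank


for dims in ([52], [3, 52], [2, 3, 52]):
    print(dims, rank_via_size(dims), rank_via_shape_gather(dims))
```

Because the shape tensor is always 1-D, its own shape has exactly one entry, which is why a single `Gather` at index 0 is enough.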
Thanks for the review. Addressed both items in 8788e90:
tianleiwu
left a comment
Follow-up review after commits d68f1594 and 8788e908.
The QDQ-path comment fix and the Reshape: 0 assertion are correct — resolved those threads. However, the opset-11 compatibility change (Size → Shape+Gather) introduced a test expectation mismatch that will cause four tests to fail.
Previously open threads from the prior review round have been re-evaluated and resolved:
- QDQ-path comment ("MatMulNBits" → "MatMul"): fixed in d68f1594
- `Reshape: 0` in `no_reshape_helpers`: added in d68f1594
- `Size` opset-13 issue: replaced with Shape+Gather in 8788e908
- Initializer name uniqueness: `final_output` incorporated into prefix in 8788e908
Commit 8788e90 replaced the single Size node in _build_nbits_output_reshape with Shape + Gather for opset-11 compatibility, but the test expectations in _RESHAPE_HELPER_OP_COUNTS and the no_reshape_helpers sentinel were not updated. This caused four 3-D weight tests to fail check_op_type_count. Update the expected counts to Shape: 2, Gather: 1 (removing Size), and refresh the two adjacent op-sequence comments to match.
Thanks for the catch. Pushed 6668430 which updates the stale op-count expectations in
Verified locally: all four tests you flagged plus
Addressed by commit 6668430 which updates the op-count expectations.
tianleiwu
left a comment
Review Summary
All prior concerns from previous review rounds have been addressed. The latest commit (6668430e20) correctly fixes the stale op-count expectations after the Size→Shape+Gather swap. Resolved the remaining open thread.
The dynamic reshape helper (_build_nbits_output_reshape) correctly implements ONNX MatMul broadcast semantics for all activation ranks (1-D, 2-D, and higher). The Max(a_rank, 2) handles 1-D promotion, opset-11 compatibility is maintained via Shape+Gather instead of Size, and initializer names incorporate final_output for uniqueness.
The elision logic (needs_reshape) avoids the 13-node reshape chain when rank(A) >= rank(B_orig) is statically known — the common transformer case.
Previously opened threads — all resolved
| Concern | Status |
|---|---|
| Stale op counts after Size→Shape+Gather | Fixed in 6668430e20 (resolved) |
| QDQ-path comment ("MatMulNBits" → "MatMul") | Fixed in d68f1594 (resolved) |
| `Reshape: 0` in `no_reshape_helpers` | Fixed in d68f1594 (resolved) |
| Opset-13 `Size` dependency | Replaced with Shape+Gather in 8788e908 (resolved) |
| Initializer name uniqueness | Fixed in 8788e908 (resolved) |
Remaining nitpick (not in diff)
Stale docstring: DefaultQuantizer.quantize_matmul (line 1044) still says "Currently only support 2D constant matrix and axis 0 blockwise quantization." Consider updating to reflect N-D support with unit leading dims.
Add test_quantize_matmul_int4_4d_weight_default exercising weight_shape (1, 1, 52, 288) against a 2-D activation to verify the squeeze/reshape helper handles rank > 3 (extra_count = 2 leading ones). Refresh the stale DefaultWeightOnlyQuantizer.quantize_matmul docstring to reflect N-D constant matrix support with unit leading dims.
Thanks for the approval, @tianleiwu. Addressed the two non-blocking nits in 9ed8d55:
Local
Description
`MatMulNBitsQuantizer` previously skipped any MatMul whose weight initializer wasn't exactly 2-D, logging: "MatMul weight is not 2D. Skip to quantize."
This PR enables quantization of N-D weights that have unit leading batch dims (e.g. `[1, K, N]`, `[1, 1, K, N]`). For those, the quantizer squeezes the weight to 2-D `[K, N]`, runs the existing Q4 quantization path, and appends a `Reshape` node after the quantized op to restore the original output shape following ONNX MatMul broadcast rules (`[*b_shape[:-2], -1, b_shape[-1]]`). Both the HQQ and Default (QOperator + QDQ) paths are updated consistently. Weights with non-unit batch dims (true batched matmul) keep the skip behavior but now emit a more specific message. 1-D weight inputs also remain a safe skip.
Motivation and Context
Fixes #25362.
Users pre-rearrange LLM attention projection weights into 3-D at model-prep time to avoid runtime transposes (the issue shows typical attention code for `q_proj`/`k_proj`/`v_proj`/`o_proj`). The Q8 `quantize_dynamic` tool already supports this; Q4 via `matmul_nbits_quantizer` did not, forcing users back to fp16 for those layers. This PR closes the gap for the common case without changing MatMul semantics.
Testing
Added three unit tests in `test_op_matmul_4bits.py` covering the Default QOperator, Default QDQ, and HQQ paths with a `[1, 52, 288]` weight. Each asserts the expected quantized op count and the presence of the inserted `Reshape` node, plus end-to-end correctness via `check_model_correctness`. All 10 tests in the file pass (2 skipped: optional `neural_compressor`-dependent GPTQ/RTN).