fix: support N-D weights with unit leading dims in MatMulNBitsQuantizer #28178

Merged
tianleiwu merged 8 commits into microsoft:main from Rishi-Dave:rishidave/feat/matmul-nbits-nd-weight
May 14, 2026

Conversation

@Rishi-Dave
Contributor

Description

MatMulNBitsQuantizer previously skipped any MatMul whose weight initializer wasn't exactly 2-D, logging:

MatMul weight is not 2D. Skip to quantize

This PR enables quantization of N-D weights that have unit leading batch dims (e.g. [1, K, N], [1, 1, K, N]). For those, the quantizer squeezes the weight to 2-D [K, N], runs the existing Q4 quantization path, and appends a Reshape node after the quantized op to restore the original output shape following ONNX MatMul broadcast rules ([*b_shape[:-2], -1, b_shape[-1]]). Both the HQQ and Default (QOperator + QDQ) paths are updated consistently. Weights with non-unit batch dims (true batched matmul) keep the skip behavior but now emit a more specific message. 1-D weight inputs also remain a safe skip.
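
As an illustration of the eligibility rule described above, here is a minimal sketch of how a weight shape might be classified. It is not the PR's code: the function name `squeezed_2d_shape` is hypothetical, and it only models the shape check, not the quantization itself.

```python
from typing import Optional, Sequence, Tuple


def squeezed_2d_shape(b_shape: Sequence[int]) -> Optional[Tuple[int, int]]:
    """Return (K, N) if the weight can be quantized as a 2-D matrix, else None.

    Hypothetical helper: mirrors the rule "all leading dims must be 1".
    """
    if len(b_shape) < 2:
        # 1-D weights remain a safe skip.
        return None
    if any(dim != 1 for dim in b_shape[:-2]):
        # Non-unit batch dims mean a true batched matmul: keep skipping.
        return None
    # Plain 2-D weights, or N-D weights with unit leading dims, squeeze to [K, N].
    return (b_shape[-2], b_shape[-1])


for shape in ([52, 288], [1, 52, 288], [1, 1, 52, 288], [2, 52, 288], [288]):
    print(shape, "->", squeezed_2d_shape(shape))
```

Under this sketch, `[1, 52, 288]` and `[1, 1, 52, 288]` both map to `(52, 288)`, while `[2, 52, 288]` and `[288]` map to `None` (skip).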

Motivation and Context

Fixes #25362.

Users pre-rearrange LLM attention projection weights into 3-D at model-prep time to avoid runtime transposes (the issue shows typical attention code for q_proj/k_proj/v_proj/o_proj). The Q8 quantize_dynamic tool already supports this; Q4 via matmul_nbits_quantizer did not, forcing users back to fp16 for those layers. This PR closes the gap for the common case without changing MatMul semantics.

Testing

Added three unit tests in test_op_matmul_4bits.py covering the Default QOperator, Default QDQ, and HQQ paths with a [1, 52, 288] weight. Each asserts the expected quantized op count and the presence of the inserted Reshape node, plus end-to-end correctness via check_model_correctness. All 10 tests in the file pass (2 skipped — optional neural_compressor-dependent GPTQ/RTN).

MatMulNBitsQuantizer silently skipped quantizing MatMul nodes whose
weight tensor had more than 2 dimensions (e.g. [1, K, N]), logging
"MatMul weight is not 2D. Skip to quantize." This blocked quantization
of LLM attention projection weights pre-reshaped to 3-D.

When all leading dimensions of the weight are 1, squeeze them to get
a 2-D [K, N] view, run the existing quantization logic, then insert a
Reshape node after MatMulNBits (or the final MatMul in QDQ format) to
restore the expected N-D output shape dictated by ONNX MatMul batch
broadcasting semantics ([b1, ..., batch_A, N]).

For weights with genuinely non-unit batch dimensions (true batched
matmuls), the skip is preserved with a clearer log message explaining
why quantization cannot proceed safely.

Fixes: both HQQWeightOnlyQuantizer.quantize() and
DefaultWeightOnlyQuantizer.quantize_matmul() (QOperator and QDQ paths).

Closes microsoft#25362
Contributor

Copilot AI left a comment

Pull request overview

Enables MatMulNBitsQuantizer to quantize MatMul nodes whose weight initializers are N-D tensors with unit leading (“batch”) dimensions by squeezing weights to 2-D for quantization and attempting to restore the original MatMul output shape.

Changes:

  • Update HQQ + Default quantization paths to accept N-D weights with all leading dims == 1 (skip otherwise with a more specific log message).
  • Insert a Reshape after the quantized op (and after QDQ MatMul) to restore an output shape derived from the original weight shape.
  • Add unit tests to validate quantization and presence of the inserted Reshape node for 3D weights in Default (QOperator/QDQ) and HQQ paths.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py | Adds N-D unit-leading-dim weight support by squeezing weights and inserting output Reshape nodes in HQQ + Default (QOperator/QDQ) paths. |
| onnxruntime/test/python/quantization/test_op_matmul_4bits.py | Adds tests and test helpers for 3D weight MatMul models; adds extra_quant_nodes plumbing to assert Reshape insertion. |

Comment thread onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py Outdated
Comment thread onnxruntime/test/python/quantization/test_op_matmul_4bits.py Outdated
Comment thread onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py Outdated
Comment thread onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py Outdated
…ut reshape

The previous post-MatMulNBits Reshape used [*b_original_shape[:-2], -1, N] as the
target. For batched activations (e.g., A=[B,M,K], B=[1,K,N]) this collapses A's
batch dims into the -1 sentinel and produces [1, B*M, N] instead of the ONNX
broadcast result [B, M, N].

Replace the static target shape with a small dynamic graph (Shape, Size, Sub,
Max, ConstantOfShape, Slice, Concat → Reshape) that prepends only the leading
1s required by max(rank(B_orig) - rank(A), 0) and preserves A.shape[:-1] as-is.
This is correct for arbitrary activation rank; in particular the rank(A) == 2
case still produces [1, M, N] and the new rank(A) == 3 case produces [B, S, N].

Wraps the construction in a single helper used at all three call sites
(QOperator MatMulNBits, QDQ MatMulNBits, QDQ DQ→MatMul). Adds a regression test
that uses a 3-D activation [B, S, K] with weight [1, K, N] and asserts both the
op set and full-graph numerical correctness.
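
The difference between the static and dynamic targets can be checked with plain shape arithmetic. The sketch below is an assumption-laden model rather than the helper itself: `resolve_reshape`, `static_target`, and `dynamic_target` are hypothetical names, shapes are plain Python lists, and activations are assumed to have rank >= 2.

```python
from math import prod
from typing import List, Sequence


def resolve_reshape(in_shape: Sequence[int], target: Sequence[int]) -> List[int]:
    """Resolve a single -1 in a Reshape target against the input element count."""
    known = prod(d for d in target if d != -1)
    total = prod(in_shape)
    return [total // known if d == -1 else d for d in target]


def static_target(b_orig: Sequence[int]) -> List[int]:
    """The original (buggy) target: [*b_shape[:-2], -1, N]."""
    return [*b_orig[:-2], -1, b_orig[-1]]


def dynamic_target(a_shape: Sequence[int], b_orig: Sequence[int]) -> List[int]:
    """The corrected target: leading 1s only where B outranks A, then A.shape[:-1], then N."""
    extra = max(len(b_orig) - len(a_shape), 0)
    return [1] * extra + list(a_shape[:-1]) + [b_orig[-1]]


b_orig = [1, 52, 288]          # weight [1, K, N]
a_batched = [4, 7, 52]         # activation [B, M, K]
nbits_out = [4, 7, 288]        # what the quantized op produces for a_batched

print(resolve_reshape(nbits_out, static_target(b_orig)))   # [1, 28, 288]: batch dims collapsed
print(dynamic_target(a_batched, b_orig))                    # [4, 7, 288]: batch dims preserved
print(dynamic_target([7, 52], b_orig))                      # [1, 7, 288]: rank-2 case unchanged
```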
Comment thread onnxruntime/test/python/quantization/test_op_matmul_4bits.py Fixed
tianleiwu
tianleiwu previously approved these changes May 3, 2026
Contributor

@tianleiwu tianleiwu left a comment

Review Summary

The dynamic reshape helper (_build_nbits_output_reshape) correctly implements the ONNX MatMul broadcast rule:

out_shape = [1] * max(rank(B_orig) - rank(A), 0) + A.shape[:-1] + [N]

using Shape/Size/Sub/Max/ConstantOfShape/Slice/Concat — making it work for any activation rank without static knowledge at quantization time. Previous review concerns about the static reshape have been fully addressed.

Positives:

  • Clean separation into a reusable module-level helper shared by both HQQ and Default paths
  • Correct handling of the extra_count = 0 edge case (rank(A) >= rank(B_orig)) where ConstantOfShape produces an empty tensor
  • Good test coverage including the critical 3-D activation case that validates batch dims aren't flattened
  • End-to-end correctness via check_model_correctness ensures numerical accuracy is maintained

Minor suggestions (non-blocking):

  1. The existing lintrunner RUF012 finding on _RESHAPE_HELPER_OP_COUNTS should be addressed (annotate with typing.ClassVar)
  2. When rank(A) >= rank(B_orig), the reshape chain produces a no-op (target shape == output shape). A short-circuit when A's static rank is known could avoid 11 extra nodes in the unoptimized graph, but ONNX optimizers should handle this fine

Comment thread onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py
…A) >= rank(B_orig)

When the activation rank is statically known to be >= the original weight
rank, the post-MatMulNBits Reshape (and its dynamic-shape helper ops) is
a no-op since MatMulNBits already produces the correct output shape.

- Add `_get_static_rank` helper to read a tensor's static rank from the
  graph stack (input/value_info/output).
- Apply the same `needs_reshape` guard at all three quantize_matmul call
  sites (HQQ, Default QOperator, Default QDQ); fall back to emitting
  the reshape when the rank is unknown or smaller than rank_b_orig.
- Annotate `_RESHAPE_HELPER_OP_COUNTS` with `typing.ClassVar` (RUF012).
- Update test_quantize_matmul_int4_3d_weight_3d_activation_preserves_shape
  to assert the new optimized op-count (no reshape helpers); other 3-D
  weight tests still use 2-D activations and continue to exercise the
  reshape path.
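
The guard described in this commit reduces to a small predicate. The following is a simplified sketch, not the code from the PR: the real `_get_static_rank` reads the ONNX graph (input/value_info/output), whereas the hypothetical `get_static_rank` here looks up a plain dict of tensor name to shape, purely for illustration.

```python
from typing import Dict, Optional, Sequence


def get_static_rank(known_shapes: Dict[str, Sequence[object]], name: str) -> Optional[int]:
    """Return the rank of `name` if its shape is recorded, else None.

    Dims may be symbolic (e.g. "B"); only the number of dims matters here.
    """
    shape = known_shapes.get(name)
    return None if shape is None else len(shape)


def needs_reshape(static_rank_a: Optional[int], rank_b_orig: int) -> bool:
    """Emit the reshape unless rank(A) is statically known to be >= rank(B_orig)."""
    return static_rank_a is None or static_rank_a < rank_b_orig


shapes = {"hidden_states": ["B", "S", 52], "flat_input": [7, 52]}
for tensor in ("hidden_states", "flat_input", "unknown_tensor"):
    print(tensor, needs_reshape(get_static_rank(shapes, tensor), rank_b_orig=3))
```

With a 3-D weight, the 3-D activation skips the reshape, while the 2-D activation and the tensor with unknown rank both fall back to emitting it.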
@Rishi-Dave
Contributor Author

Thanks for the review. Pushed a follow-up addressing both nits in 3faf2f5:

  • Annotated _RESHAPE_HELPER_OP_COUNTS with typing.ClassVar[dict[str, int]] (RUF012).
  • Added a _get_static_rank helper and a needs_reshape guard at all three quantize_matmul call sites (HQQ, Default QOperator, QDQ): when the activation's static rank is known and >= rank_b_orig, the post-op Reshape is elided since MatMulNBits already produces the correct shape. When the rank is unknown or smaller, the existing dynamic-shape reshape path is preserved.

Updated test_quantize_matmul_int4_3d_weight_3d_activation_preserves_shape (3-D act × 3-D weight) to assert the new optimized op-count — only MatMulNBits:1, no reshape helpers. The other three 3-D weight tests use 2-D activations, so they still exercise the reshape helper path unchanged.

Targeted tests: test_op_matmul_4bits.py 11 passed, 2 skipped (gptq/rtn optional deps), 0 failed. lintrunner -a clean.

Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

Comment thread onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py
Comment thread onnxruntime/test/python/quantization/test_op_matmul_4bits.py Outdated
ONNX MatMul promotes a 1-D activation A to rank-2 before computing the
output shape, so deriving the prepended-ones count from raw rank(A)
produced wrong shapes for 1-D activations (e.g. MatMul([K], [1,K,N])
reshaped to [1,1,N] instead of [1,N]).

Compute a_rank_eff = Max(rank(A), 2) and use it when computing the
extra-ones count. Tighten the 3-D-weight/3-D-activation test to assert
all reshape-helper ops are absent in the elision path, and add a new
3-D-weight/1-D-activation regression test that exercises the helper
and verifies output-shape preservation.
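
The effect of the `Max(rank(A), 2)` fix can be reproduced with ordinary list arithmetic. This is a sketch under assumptions: `reshape_target` is a hypothetical stand-in for the shape that the helper's dynamic subgraph computes, with the `promote_1d` flag added only to contrast the old and new behavior.

```python
from typing import List, Sequence


def reshape_target(a_shape: Sequence[int], b_orig: Sequence[int], promote_1d: bool = True) -> List[int]:
    """Model of the Reshape target: [1] * extra + A.shape[:-1] + [N]."""
    a_rank = len(a_shape)
    # ONNX MatMul treats a 1-D A as rank 2 when broadcasting.
    a_rank_eff = max(a_rank, 2) if promote_1d else a_rank
    extra = max(len(b_orig) - a_rank_eff, 0)
    # The slice still uses the raw rank, so a 1-D A contributes an empty prefix.
    return [1] * extra + list(a_shape[: a_rank - 1]) + [b_orig[-1]]


b_orig = [1, 52, 288]
print(reshape_target([52], b_orig, promote_1d=False))  # [1, 1, 288]: the bug
print(reshape_target([52], b_orig))                    # [1, 288]: after the fix
print(reshape_target([7, 52], b_orig))                 # [1, 7, 288]
print(reshape_target([4, 7, 52], b_orig))              # [4, 7, 288]
```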
@Rishi-Dave
Contributor Author

Thanks for the review. Pushed 896b393 addressing both items:

  1. _build_nbits_output_reshape: ONNX MatMul promotes a 1-D A to rank-2 before broadcasting, so deriving the prepended-ones count from raw rank(A) was wrong for 1-D activations (e.g. MatMul([K], [1,K,N]) reshaped to [1,1,N] instead of [1,N]). Now I compute a_rank_eff = Max(rank(A), 2) via a new scalar-2 initializer and Max node, and feed a_rank_eff into the Sub(rank_b_orig, a_rank_eff) that produces the extras count. The downstream Slice(Shape(A), 0, a_rank-1) still uses raw a_rank, which is correct: for 1-D it slices [0:0] and yields an empty prefix, so Concat([1], [], [N]) = [1,N].

  2. test_quantize_matmul_int4_3d_weight_3d_activation_preserves_shape: now passes extra_quant_nodes asserting count 0 for all seven reshape-helper ops (Shape/Size/Sub/Max/ConstantOfShape/Slice/Concat), so leftover helper nodes can no longer slip through. Bumped _RESHAPE_HELPER_OP_COUNTS["Max"] from 1 to 2 to account for the new node.

Also added test_quantize_matmul_int4_3d_weight_1d_activation_preserves_shape as a regression covering the 1-D activation path end-to-end.

tianleiwu
tianleiwu previously approved these changes May 7, 2026
Contributor

@tianleiwu tianleiwu left a comment

Review Summary

The dynamic reshape helper (_build_nbits_output_reshape) correctly implements the ONNX MatMul broadcast rule and properly handles 1-D, 2-D, and higher-rank activations. The Max(a_rank, 2) effectively accounts for ONNX MatMul's 1-D promotion semantics, and the static-rank elision avoids inserting dead nodes for the dominant transformer case.

All previously raised concerns (1-D activation bug, batch-dim flattening, no-op reshape elision, test assertion gaps) have been addressed in the latest commits. Test coverage is comprehensive across QOperator, QDQ, and HQQ paths with 2-D, 3-D, and 1-D activation shapes.

Two minor suggestions below — neither is blocking.

Comment thread onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py Outdated
Comment thread onnxruntime/test/python/quantization/test_op_matmul_4bits.py
…ertion

- Fix stale comment in DefaultWeightOnlyQuantizer QDQ path: the no-op
  reshape skip happens on the standard MatMul node (post-DequantizeLinear),
  not on MatMulNBits.
- Add "Reshape": 0 to no_reshape_helpers so the test explicitly asserts
  no Reshape node is inserted when the helper chain should be absent.
@Rishi-Dave
Contributor Author

Addressed both nits in d68f159:

  • matmul_nbits_quantizer.py L1163: corrected the inline comment in the QDQ branch from "MatMulNBits" to "MatMul" — that path emits a standard MatMul after DequantizeLinear, so the no-op-reshape rationale is about the MatMul output shape.
  • test_op_matmul_4bits.py: added "Reshape": 0 to no_reshape_helpers so the assertion explicitly fails if a stray Reshape gets inserted into the helper-chain-absent path.

Targeted test_op_matmul_4bits.py run is green (12 passed, 2 pre-existing skips) and lintrunner -a is clean on the two touched files.

Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

Comment thread onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py
Comment thread onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py Outdated
Address review feedback on MatMulNBits output-reshape helper:

- Replace Size(a_shape) with Shape(a_shape) + Gather(idx=0). The Size op
  on a shape tensor requires ai.onnx opset >= 13, but MatMulNBitsQuantizer
  does not bump the model's opset, so opset 11/12 inputs would produce an
  invalid graph. Shape and Gather are valid from opset 1/11 and yield the
  same scalar int64 rank value.

- Incorporate the unique output tensor name (node.output[0]) into the
  initializer name prefix. The prior prefix was node.name, which is not
  guaranteed unique across a model and would collide when multiple MatMul
  nodes share a name. node.output[0] is unique per the ONNX spec.
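
To see why Shape followed by Gather at index 0 yields the same rank as Size, the two routes can be modelled on plain lists. This is only an illustrative sketch: `shape_op`, `rank_via_size`, and `rank_via_shape_gather` are hypothetical names, a tensor is represented solely by its list of dimensions, and no ONNX API is involved.

```python
from typing import List, Sequence


def shape_op(tensor_dims: Sequence[int]) -> List[int]:
    """Model of ONNX Shape: returns the input's dims as the values of a 1-D tensor."""
    return list(tensor_dims)


def rank_via_size(a_dims: Sequence[int]) -> int:
    """Old route: Size(Shape(A)) counts the elements of the shape tensor."""
    a_shape = shape_op(a_dims)
    return len(a_shape)


def rank_via_shape_gather(a_dims: Sequence[int]) -> int:
    """New route: Shape(Shape(A)) is the 1-element tensor [rank]; Gather(idx=0) extracts it."""
    a_shape = shape_op(a_dims)               # values [d0, ..., d_{r-1}], itself of dims [r]
    shape_of_shape = shape_op([len(a_shape)])  # Shape applied to a tensor whose dims are [r]
    return shape_of_shape[0]                 # Gather with a scalar index 0 gives a scalar


for dims in ([52], [7, 52], [4, 7, 52]):
    print(dims, rank_via_size(dims), rank_via_shape_gather(dims))
```

Both routes give the same scalar, which is why downstream Max/Sub consumers need no change.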
@Rishi-Dave
Contributor Author

Thanks for the review. Addressed both items in 8788e90:

  1. Opset compatibility (_build_nbits_output_reshape, line ~914): Replaced Size(a_shape) with Shape(a_shape) + Gather(idx=0). Size requires ai.onnx opset >= 13, and MatMulNBitsQuantizer.process() does not bump the model opset, so opset 11/12 inputs would have produced an invalid graph. Shape and Gather are valid from opset 1/11; with a 0-D scalar index, Gather yields the same scalar int64 rank value that Size produced, so all downstream Max/Sub consumers are unchanged.

  2. Unique initializer names (line ~881): Incorporated final_output (which is always node.output[0], unique per ONNX spec) into the prefix used for the helper's initializers and intermediate tensors. Previously the prefix was node.name, which is not guaranteed unique; two MatMul nodes sharing a name would have collided. Verified all three call sites pass node.output[0] as final_output.

lintrunner -a clean.

Contributor

@tianleiwu tianleiwu left a comment

Follow-up review after commits d68f1594 and 8788e908.

The QDQ-path comment fix and the Reshape: 0 assertion are correct — resolved those threads. However, the opset-11 compatibility change (Size → Shape+Gather) introduced a test expectation mismatch that will cause four tests to fail.

Previously open threads from the prior review round have been re-evaluated and resolved:

  • QDQ-path comment ("MatMulNBits" → "MatMul"): fixed in d68f1594
  • Reshape: 0 in no_reshape_helpers: added in d68f1594
  • Size opset-13 issue: replaced with Shape+Gather in 8788e908
  • Initializer name uniqueness: final_output incorporated into prefix in 8788e908

Comment thread onnxruntime/test/python/quantization/test_op_matmul_4bits.py Outdated
Commit 8788e90 replaced the single Size node in _build_nbits_output_reshape
with Shape + Gather for opset-11 compatibility, but the test expectations
in _RESHAPE_HELPER_OP_COUNTS and the no_reshape_helpers sentinel were not
updated. This caused four 3-D weight tests to fail check_op_type_count.

Update the expected counts to Shape: 2, Gather: 1 (removing Size), and
refresh the two adjacent op-sequence comments to match.
@Rishi-Dave
Contributor Author

Thanks for the catch. Pushed 6668430, which updates the stale op-count expectations in test_op_matmul_4bits.py to match the Size → Shape+Gather change from 8788e90:

  • _RESHAPE_HELPER_OP_COUNTS: Shape: 1, Size: 1 → Shape: 2, Gather: 1 (all other counts unchanged).
  • Comment above the dict updated from Shape/Size/Max/Sub/… to Shape/Gather/Max/Sub/….
  • no_reshape_helpers in the elision test: Size: 0 → Gather: 0, so the no-op path also asserts no stray Gather is introduced.

Verified locally — all four tests you flagged plus test_quantize_matmul_int4_3d_weight_3d_activation_preserves_shape now pass on the branch tip.

@tianleiwu tianleiwu dismissed their stale review May 13, 2026 20:06

Addressed by commit 6668430 which updates the op-count expectations.

tianleiwu
tianleiwu previously approved these changes May 13, 2026
Contributor

@tianleiwu tianleiwu left a comment

Review Summary

All prior concerns from previous review rounds have been addressed. The latest commit (6668430e20) correctly fixes the stale op-count expectations after the Size→Shape+Gather swap. Resolved the remaining open thread.

The dynamic reshape helper (_build_nbits_output_reshape) correctly implements ONNX MatMul broadcast semantics for all activation ranks (1-D, 2-D, and higher). The Max(a_rank, 2) handles 1-D promotion, opset-11 compatibility is maintained via Shape+Gather instead of Size, and initializer names incorporate final_output for uniqueness.

The elision logic (needs_reshape) avoids the 13-node reshape chain when rank(A) >= rank(B_orig) is statically known — the common transformer case.

Previously opened threads — all resolved

| Concern | Status |
| --- | --- |
| Stale op counts after Size→Shape+Gather | Fixed in 6668430e20; resolved |
| QDQ-path comment ("MatMulNBits" → "MatMul") | Fixed in d68f1594; resolved |
| Reshape: 0 in no_reshape_helpers | Fixed in d68f1594; resolved |
| Opset-13 Size dependency | Replaced with Shape+Gather in 8788e908; resolved |
| Initializer name uniqueness | Fixed in 8788e908; resolved |

Remaining nitpick (not in diff)

Stale docstring: DefaultWeightOnlyQuantizer.quantize_matmul (line 1044) still says "Currently only support 2D constant matrix and axis 0 blockwise quantization." Consider updating it to reflect N-D support with unit leading dims.

Comment thread onnxruntime/test/python/quantization/test_op_matmul_4bits.py
Add test_quantize_matmul_int4_4d_weight_default exercising weight_shape
(1, 1, 52, 288) against a 2-D activation to verify the squeeze/reshape
helper handles rank > 3 (extra_count = 2 leading ones).

Refresh the stale DefaultWeightOnlyQuantizer.quantize_matmul docstring
to reflect N-D constant matrix support with unit leading dims.
@Rishi-Dave
Contributor Author

Thanks for the approval, @tianleiwu. Addressed the two non-blocking nits in 9ed8d55:

  • Added test_quantize_matmul_int4_4d_weight_default with weight_shape = (1, 1, 52, 288) and a 2-D activation. This exercises the rank-4 path through the squeeze/reshape helper (extra_count = max(4 - max(2, 2), 0) = 2 leading ones) and asserts against _RESHAPE_HELPER_OP_COUNTS, matching the structure of the existing 3-D test.
  • Refreshed the stale DefaultWeightOnlyQuantizer.quantize_matmul docstring to reflect that N-D constant matrices with unit leading dims are now supported (squeezed to 2-D before quantization).

Local pytest -k "3d_weight_default or 4d_weight_default" is green and lintrunner -a is clean. Happy to drop the docstring touch from this PR if you'd rather keep the diff strictly scoped.

@tianleiwu tianleiwu merged commit 2979bab into microsoft:main May 14, 2026
88 checks passed

Development

Successfully merging this pull request may close these issues.

[Feature Request] The MatMulNBits matmul_nbits_quantizer does not support 3D weight tensors.

4 participants