diff --git a/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py index d1067c15caeb2..296718bfbff39 100644 --- a/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py @@ -697,9 +697,23 @@ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeP return [node] # only care about constant weight b_array = onnx.numpy_helper.to_array(b_pb) - if len(b_array.shape) != 2: - logger.info("MatMul weight is not 2D. Skip to quantize") - return [node] # can only process 2-D matrix + b_original_shape = b_array.shape + if len(b_original_shape) != 2: + if len(b_original_shape) < 2: + logger.info("MatMul weight has fewer than 2 dimensions. Skip to quantize.") + return [node] + leading = b_original_shape[:-2] + if any(d != 1 for d in leading): + logger.info( + "MatMul weight has non-unit batch dims %s; N-D batched quantization not supported. " + "Skip to quantize.", + list(leading), + ) + return [node] + # Squeeze all unit leading dims to obtain a 2-D weight [K, N] + b_array = b_array.reshape(b_original_shape[-2], b_original_shape[-1]) + else: + b_original_shape = None # already 2-D, no reshape needed b_array_torch = torch.from_numpy(b_array) if torch.cuda.is_available(): b_array_torch = b_array_torch.cuda() @@ -755,18 +769,42 @@ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeP kwargs["bits"] = bits kwargs["block_size"] = self.config.block_size + # When rank(A) >= rank_b_orig, the reshape would be a no-op since MatMulNBits + # already produces the correct output shape. + a_static_rank = _get_static_rank(node.input[0], graph_stack) + rank_b_orig = len(b_original_shape) if b_original_shape is not None else 0 + needs_reshape = b_original_shape is not None and (a_static_rank is None or a_static_rank < rank_b_orig) + matmul_q_output = node.output[0] if not needs_reshape else node.output[0] + "_pre_reshape" matmul_q_node = onnx.helper.make_node( "MatMulNBits", inputs=input_names, - outputs=[node.output[0]], + outputs=[matmul_q_output], name=node.name + "_Q" + str(bits) if node.name else "", domain="com.microsoft", **kwargs, ) + output_nodes = [matmul_q_node] + if needs_reshape: + # Restore ONNX MatMul broadcast shape on the output. + # MatMul(A, B_orig) output shape (with B_orig leading dims all 1) is: + # [1] * max(rank(B_orig) - rank(A), 0) + A.shape[:-1] + [N] + # MatMulNBits with squeezed B produces [...A.shape[:-1], N] (rank=rank(A)), + # so we only need to prepend leading 1s when rank(B_orig) > rank(A). + output_nodes.extend( + _build_nbits_output_reshape( + a_input_name=node.input[0], + b_original_shape=b_original_shape, + target_graph=bs_graph, + name_prefix=(node.name + "_Q" + str(bits)) if node.name else (b_pb.name + "_Q" + str(bits)), + pre_reshape_output=matmul_q_output, + final_output=node.output[0], + ) + ) + logger.info(f"complete quantization of {node.name} ...") - return [matmul_q_node] + return output_nodes def get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, GraphProto]: @@ -778,6 +816,155 @@ def get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, Gr return None, None +def _get_static_rank(tensor_name: str, graph_path: list[GraphProto]) -> int | None: + """Return the static rank of a tensor if its shape is known, else None. + + Searches graph inputs, value_info, and outputs in the graph stack (inner-most + graph first). A known shape requires ``HasField('shape')`` to be true on the + tensor_type; the rank is then ``len(shape.dim)``. Individual dim sizes may + still be symbolic — only the rank (dimension count) matters here. + """ + for gid in range(len(graph_path) - 1, -1, -1): + graph = graph_path[gid] + for vi in list(graph.input) + list(graph.value_info) + list(graph.output): + if vi.name == tensor_name: + tt = vi.type.tensor_type + if tt.HasField("shape"): + return len(tt.shape.dim) + return None + return None + + +def _build_nbits_output_reshape( + a_input_name: str, + b_original_shape: tuple, + target_graph: GraphProto, + name_prefix: str, + pre_reshape_output: str, + final_output: str, +) -> list[NodeProto]: + """Build the reshape chain that restores the ONNX MatMul-broadcast output shape. + + MatMulNBits produces shape ``[...A_batch_dims, M, N]`` (rank = rank(A)). To match + the original ``MatMul(A, B_orig)`` output, where B_orig has all-unit leading + dims, we need: + + a_rank_eff = max(rank(A), 2) # ONNX promotes 1-D A to rank-2 + out_shape = [1] * max(rank(B_orig) - a_rank_eff, 0) + A.shape[:-1] + [N] + + This is built dynamically via Shape/Gather/Max/Sub/Max/ConstantOfShape/Slice/Concat + so it works regardless of A's static rank (handles rank(A) == 1, rank(A) == 2 + — the common transformer case — as well as rank(A) >= rank(B_orig) where no + leading-1 prepending is needed). All ops used are valid from opset 11 onward. + + Args: + a_input_name: name of the activation input edge (A) feeding MatMulNBits. + b_original_shape: the original (pre-squeeze) shape of B, e.g. ``(1, K, N)``. + target_graph: graph proto to append helper initializers into. + name_prefix: unique prefix for generated node/initializer names. + pre_reshape_output: name of the MatMulNBits output edge (the input of the + generated Reshape). + final_output: name of the final edge produced by the generated Reshape + (must match the original MatMul output edge). + + Returns: + List of nodes to append to the consumer's ``output_nodes`` after the + MatMulNBits node. Initializers are appended to ``target_graph`` in place. + """ + rank_b_orig = len(b_original_shape) + n_dim = b_original_shape[-1] + + # Incorporate the unique output tensor name so initializer names are unique + # even when multiple MatMul nodes share the same node.name (Fix 2). + p = name_prefix + "_" + final_output + init_zero = p + "_zero" + init_one = p + "_one" + init_two = p + "_two" + init_one_vec = p + "_one_vec" + init_rank_b = p + "_rank_b" + init_n_vec = p + "_n_vec" + init_zero_vec = p + "_zero_vec" + init_one_value_template = p + "_one_value" + + target_graph.initializer.extend( + [ + onnx.numpy_helper.from_array(np.array(0, dtype=np.int64), name=init_zero), + onnx.numpy_helper.from_array(np.array(1, dtype=np.int64), name=init_one), + onnx.numpy_helper.from_array(np.array(2, dtype=np.int64), name=init_two), + onnx.numpy_helper.from_array(np.array([1], dtype=np.int64), name=init_one_vec), + onnx.numpy_helper.from_array(np.array(rank_b_orig, dtype=np.int64), name=init_rank_b), + onnx.numpy_helper.from_array(np.array([n_dim], dtype=np.int64), name=init_n_vec), + onnx.numpy_helper.from_array(np.array([0], dtype=np.int64), name=init_zero_vec), + ] + ) + + a_shape = p + "_a_shape" + a_shape_of_shape = p + "_a_shape_of_shape" + a_rank = p + "_a_rank" + a_rank_eff = p + "_a_rank_eff" + extra_raw = p + "_extra_raw" + extra_count = p + "_extra_count" + extra_count_vec = p + "_extra_count_vec" + extra_ones = p + "_extra_ones" + a_rank_minus_one = p + "_a_rank_m1" + a_rank_minus_one_vec = p + "_a_rank_m1_vec" + a_prefix_shape = p + "_a_prefix_shape" + target_shape = p + "_target_shape" + + nodes = [ + onnx.helper.make_node("Shape", [a_input_name], [a_shape], name=p + "_shape_a"), + # Use Shape(shape) + Gather instead of Size to stay within opset 11 + # (Size requires opset >= 13 when applied to a shape tensor). + # Shape applied to the 1-D shape vector yields [rank_a] as a 1-element + # tensor; Gather with scalar index 0 extracts it as a scalar int64. + onnx.helper.make_node("Shape", [a_shape], [a_shape_of_shape], name=p + "_shape_of_a_shape"), + onnx.helper.make_node("Gather", [a_shape_of_shape, init_zero], [a_rank], name=p + "_gather_rank"), + # ONNX MatMul promotes a 1-D activation to rank-2 before computing the + # output shape, so use Max(a_rank, 2) as the effective rank when + # computing how many leading 1s to prepend. + onnx.helper.make_node("Max", [a_rank, init_two], [a_rank_eff], name=p + "_max_rank_eff"), + onnx.helper.make_node("Sub", [init_rank_b, a_rank_eff], [extra_raw], name=p + "_sub"), + onnx.helper.make_node("Max", [extra_raw, init_zero], [extra_count], name=p + "_max"), + onnx.helper.make_node("Reshape", [extra_count, init_one_vec], [extra_count_vec], name=p + "_reshape_extra"), + onnx.helper.make_node( + "ConstantOfShape", + [extra_count_vec], + [extra_ones], + name=p + "_const_ones", + value=onnx.helper.make_tensor( + name=init_one_value_template, + data_type=TensorProto.INT64, + dims=[1], + vals=[1], + ), + ), + onnx.helper.make_node("Sub", [a_rank, init_one], [a_rank_minus_one], name=p + "_sub_one"), + onnx.helper.make_node( + "Reshape", [a_rank_minus_one, init_one_vec], [a_rank_minus_one_vec], name=p + "_reshape_rank_m1" + ), + onnx.helper.make_node( + "Slice", + [a_shape, init_zero_vec, a_rank_minus_one_vec, init_zero_vec], + [a_prefix_shape], + name=p + "_slice_a_prefix", + ), + onnx.helper.make_node( + "Concat", + [extra_ones, a_prefix_shape, init_n_vec], + [target_shape], + name=p + "_concat_target", + axis=0, + ), + onnx.helper.make_node( + "Reshape", + [pre_reshape_output, target_shape], + [final_output], + name=p + "_reshape_out", + ), + ] + return nodes + + # transpose int4 matrix (packed as uint8) def transpose_packed_int4_matrix(packed, rows, cols): # unpack to int4 matrix @@ -855,7 +1042,8 @@ def qbits_block_quant(self, fp32weight: npt.ArrayLike) -> tuple[np.ndarray, np.n def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeProto]: """ Quantize weight B of MatMul node to int4 or int8. - Currently only support 2D constant matrix and axis 0 blockwise quantization. + Supports 2D constant matrix, and N-D constant matrices whose leading dimensions are all 1 + (which are squeezed to 2D before quantization). Axis 0 blockwise quantization only. """ bits = self.config.bits if bits == 8: @@ -869,9 +1057,23 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> lis return [node] # only care about constant weight b_ndarray = ir.from_proto(b_tensor).numpy() - if len(b_ndarray.shape) != 2: - logger.info("MatMul weight is not 2D. Skip to quantize") - return [node] # can only process 2-D matrix + b_original_shape = b_ndarray.shape + if len(b_original_shape) != 2: + if len(b_original_shape) < 2: + logger.info("MatMul weight has fewer than 2 dimensions. Skip to quantize.") + return [node] + leading = b_original_shape[:-2] + if any(d != 1 for d in leading): + logger.info( + "MatMul weight has non-unit batch dims %s; N-D batched quantization not supported. " + "Skip to quantize.", + list(leading), + ) + return [node] + # Squeeze all unit leading dims to obtain a 2-D weight [K, N] + b_ndarray = b_ndarray.reshape(b_original_shape[-2], b_original_shape[-1]) + else: + b_original_shape = None # already 2-D, no reshape needed bfloat16 = b_ndarray.dtype == "bfloat16" if bfloat16: @@ -931,16 +1133,33 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> lis if self.config.accuracy_level: kwargs["accuracy_level"] = self.config.accuracy_level + # When rank(A) >= rank_b_orig, the reshape would be a no-op since MatMulNBits + # already produces the correct output shape. + a_static_rank = _get_static_rank(node.input[0], graph_stack) + rank_b_orig = len(b_original_shape) if b_original_shape is not None else 0 + needs_reshape = b_original_shape is not None and (a_static_rank is None or a_static_rank < rank_b_orig) + qop_output = node.output[0] if not needs_reshape else node.output[0] + "_pre_reshape" matmul_qbit_node = onnx.helper.make_node( "MatMulNBits", inputs=input_names, - outputs=[node.output[0]], + outputs=[qop_output], name=node.name + f"_Q{bits}" if node.name else "", domain="com.microsoft", **kwargs, ) output_nodes.append(matmul_qbit_node) + if needs_reshape: + output_nodes.extend( + _build_nbits_output_reshape( + a_input_name=node.input[0], + b_original_shape=b_original_shape, + target_graph=b_graph, + name_prefix=(node.name + f"_Q{bits}") if node.name else (b_tensor.name + f"_Q{bits}"), + pre_reshape_output=qop_output, + final_output=node.output[0], + ) + ) else: dq_input_names = [b_quant.name, scales_tensor.name] dq_output_names = [b_quant.name + "_output"] @@ -950,7 +1169,13 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> lis node.input[0], tp_output_names[0] if qdq_opt_for_intel_npu_enabled else dq_output_names[0], ] - matmul_output_names = [node.output[0]] + # When rank(A) >= rank_b_orig, the reshape would be a no-op since MatMul + # already produces the correct output shape. + a_static_rank = _get_static_rank(node.input[0], graph_stack) + rank_b_orig = len(b_original_shape) if b_original_shape is not None else 0 + needs_reshape = b_original_shape is not None and (a_static_rank is None or a_static_rank < rank_b_orig) + qdq_matmul_out = node.output[0] if not needs_reshape else node.output[0] + "_pre_reshape" + matmul_output_names = [qdq_matmul_out] if not self.config.is_symmetric: zp_tensor = onnx.helper.make_tensor( b_tensor.name + "_DQ_zero_points", qtype, scales.shape, zero_points.tobytes(), True @@ -985,6 +1210,17 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> lis output_nodes.extend([dq_node, tp_node, matmul_node]) else: output_nodes.extend([dq_node, matmul_node]) + if needs_reshape: + output_nodes.extend( + _build_nbits_output_reshape( + a_input_name=node.input[0], + b_original_shape=b_original_shape, + target_graph=b_graph, + name_prefix=(node.name + f"_DQ_Q{bits}") if node.name else (b_tensor.name + f"_DQ_Q{bits}"), + pre_reshape_output=qdq_matmul_out, + final_output=node.output[0], + ) + ) return output_nodes diff --git a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py index b049a041cf0a7..2a07b34d37dc6 100644 --- a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py +++ b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py @@ -9,6 +9,7 @@ import unittest from importlib.util import find_spec from pathlib import Path +from typing import ClassVar import numpy as np import onnx @@ -187,6 +188,7 @@ def quant_test( quant_axes: tuple[tuple[str, int], ...] = (("MatMul", 0), ("Gather", 1)), rtol: float = 0.01, atol: float = 0.05, + extra_quant_nodes: dict | None = None, ): use_qdq = quant_format == quant_utils.QuantFormat.QDQ name_prefix = "QDQ" if use_qdq else "QOperator" @@ -213,6 +215,8 @@ def quant_test( quant_nodes = {"GatherBlockQuantized": 1} else: quant_nodes = {"DequantizeLinear": 1, "MatMul": 1} if use_qdq else {"MatMulNBits": 1} + if extra_quant_nodes: + quant_nodes.update(extra_quant_nodes) check_op_type_count(self, model_int4_path, **quant_nodes) if use_qdq: @@ -254,6 +258,7 @@ def quant_test_with_algo( data_reader: TestDataFeeds, block_size: int, is_symmetric: bool, + extra_quant_nodes: dict | None = None, ): model_int4_path = str( Path(self._tmp_model_dir.name).joinpath(f"MatMulNBits_{block_size}_{is_symmetric}.onnx").absolute() @@ -282,6 +287,8 @@ def quant_test_with_algo( quant.model.save_model_to_file(model_int4_path, False) quant_nodes = {"MatMulNBits": 1} + if extra_quant_nodes: + quant_nodes.update(extra_quant_nodes) check_op_type_count(self, model_int4_path, **quant_nodes) data_reader.rewind() @@ -367,6 +374,228 @@ def test_quantize_matmul_int4_using_hqq_algo(self): data_reader = self.input_feeds(1, {"input": (100, 52)}) self.quant_test_with_algo("HQQ", model_fp32_path, data_reader, 32, False) + def construct_model_matmul_3d(self, output_model_path: str, symmetric: bool, weight_shape: tuple) -> None: + # (input) + # | + # MatMul (weight has shape [1, K, N] -- unit leading dim) + # | + # (output) + input_name = "input" + output_name = "output" + initializers = [] + + weight_data = self.fill_int4_data(weight_shape, symmetric).astype(np.float32) + initializers.append(onnx.numpy_helper.from_array(weight_data, name="linear1.weight")) + matmul_node = onnx.helper.make_node( + "MatMul", + [input_name, "linear1.weight"], + [output_name], + "MatMul_0", + ) + + # weight_shape = (b1, ..., K, N) -> input shape (-1, K), + # ONNX MatMul output shape = [b1, ..., -1, N] (leading batch dims of B come first) + k = weight_shape[-2] + n = weight_shape[-1] + leading = list(weight_shape[:-2]) + input_tensor = helper.make_tensor_value_info(input_name, TensorProto.FLOAT, [-1, k]) + output_tensor = helper.make_tensor_value_info(output_name, TensorProto.FLOAT, [*leading, -1, n]) + graph = helper.make_graph( + [matmul_node], + "matmul_4bits_3d_test", + [input_tensor], + [output_tensor], + initializer=initializers, + ) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 21)]) + model.ir_version = 10 + onnx.save(model, output_model_path) + + # The output-restoration helper builds an ONNX-broadcast-correct target shape + # dynamically (Shape/Gather/Max/Sub/Max/ConstantOfShape/Slice/Concat → Reshape) so it + # works for arbitrary activation rank, including 1-D activations. The op-count + # expectations below reflect those helper ops. + _RESHAPE_HELPER_OP_COUNTS: ClassVar[dict[str, int]] = { + "Reshape": 3, + "Shape": 2, + "Gather": 1, + "Sub": 2, + "Max": 2, + "ConstantOfShape": 1, + "Slice": 1, + "Concat": 1, + } + + def test_quantize_matmul_int4_3d_weight_default(self): + """Test that Default quantizer handles weight with unit leading dim, e.g. shape [1, K, N].""" + np.random.seed(42) + weight_shape = (1, 52, 288) # unit leading dim + model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_3d_default.onnx").absolute()) + self.construct_model_matmul_3d(model_fp32_path, symmetric=True, weight_shape=weight_shape) + data_reader = self.input_feeds(1, {"input": (100, 52)}) + self.quant_test(model_fp32_path, data_reader, 32, True, extra_quant_nodes=self._RESHAPE_HELPER_OP_COUNTS) + + def test_quantize_matmul_int4_3d_weight_default_qdq(self): + """Test QDQ format with unit leading dim 3D weight.""" + np.random.seed(42) + weight_shape = (1, 52, 288) + model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_3d_default_qdq.onnx").absolute()) + self.construct_model_matmul_3d(model_fp32_path, symmetric=True, weight_shape=weight_shape) + data_reader = self.input_feeds(1, {"input": (100, 52)}) + self.quant_test( + model_fp32_path, + data_reader, + 32, + True, + quant_utils.QuantFormat.QDQ, + extra_quant_nodes=self._RESHAPE_HELPER_OP_COUNTS, + ) + + def test_quantize_matmul_int4_3d_weight_hqq(self): + """Test that HQQ quantizer handles weight with unit leading dim, e.g. shape [1, K, N].""" + if not find_spec("torch"): + self.skipTest("skip test since torch is not installed") + np.random.seed(42) + weight_shape = (1, 52, 288) + model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_3d_hqq.onnx").absolute()) + self.construct_model_matmul_3d(model_fp32_path, symmetric=False, weight_shape=weight_shape) + data_reader = self.input_feeds(1, {"input": (100, 52)}) + self.quant_test_with_algo( + "HQQ", model_fp32_path, data_reader, 32, False, extra_quant_nodes=self._RESHAPE_HELPER_OP_COUNTS + ) + + def construct_model_matmul_3d_activation( + self, output_model_path: str, symmetric: bool, weight_shape: tuple, batch: int = 2, seq: int = 100 + ) -> None: + """Build a model with batched 3-D activation [batch, seq, K] and 3-D weight [1, K, N]. + + ONNX MatMul([batch, seq, K], [1, K, N]) -> [batch, seq, N]. The quantized graph + must preserve this shape; a naive reshape that flattens batch dims would + produce [1, batch*seq, N] and is observably wrong. + """ + input_name = "input" + output_name = "output" + initializers = [] + weight_data = self.fill_int4_data(weight_shape, symmetric).astype(np.float32) + initializers.append(onnx.numpy_helper.from_array(weight_data, name="linear1.weight")) + matmul_node = onnx.helper.make_node( + "MatMul", + [input_name, "linear1.weight"], + [output_name], + "MatMul_0", + ) + k = weight_shape[-2] + n = weight_shape[-1] + input_tensor = helper.make_tensor_value_info(input_name, TensorProto.FLOAT, [batch, seq, k]) + output_tensor = helper.make_tensor_value_info(output_name, TensorProto.FLOAT, [batch, seq, n]) + graph = helper.make_graph( + [matmul_node], "matmul_4bits_3d_activation_test", [input_tensor], [output_tensor], initializer=initializers + ) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 21)]) + model.ir_version = 10 + onnx.save(model, output_model_path) + + def test_quantize_matmul_int4_3d_weight_3d_activation_preserves_shape(self): + """ONNX MatMul([B,S,K],[1,K,N]) must yield [B,S,N] after quantization. + + When activation rank >= weight rank, MatMulNBits already produces the + correct output shape, so the post-op reshape is elided as a no-op. + The quantized graph therefore contains only MatMulNBits:1 and no + reshape helper ops (Reshape/Shape/Gather/Sub/Max/ConstantOfShape/Slice/Concat). + """ + np.random.seed(42) + weight_shape = (1, 52, 288) + batch, seq = 2, 100 + model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_3d_act.onnx").absolute()) + self.construct_model_matmul_3d_activation( + model_fp32_path, symmetric=True, weight_shape=weight_shape, batch=batch, seq=seq + ) + data_reader = self.input_feeds(1, {"input": (batch, seq, weight_shape[-2])}) + # When activation rank (3) >= weight rank (3), no reshape helpers are needed. + # Assert explicitly that none of the helper ops appear in the quantized graph. + no_reshape_helpers = { + "Shape": 0, + "Gather": 0, + "Sub": 0, + "Max": 0, + "ConstantOfShape": 0, + "Slice": 0, + "Concat": 0, + "Reshape": 0, + } + self.quant_test(model_fp32_path, data_reader, 32, True, extra_quant_nodes=no_reshape_helpers) + + # The check_model_correctness inside quant_test runs both the FP32 and + # quantized models on the same input and compares outputs. If the + # quantized graph collapsed batch dims (output [1, B*S, N] instead of + # [B, S, N]), that comparison would fail before we even reach here. + + def construct_model_matmul_3d_weight_1d_activation( + self, output_model_path: str, symmetric: bool, weight_shape: tuple + ) -> None: + """Build a model with 1-D activation [K] and 3-D weight [1, K, N]. + + ONNX MatMul([K], [1, K, N]) -> [1, N]. The quantized graph must preserve + this output shape; the reshape helper computes extra_count=1 (since + a_rank_eff=max(1,2)=2 and rank_b_orig=3) so a single leading 1 is prepended. + """ + input_name = "input" + output_name = "output" + initializers = [] + weight_data = self.fill_int4_data(weight_shape, symmetric).astype(np.float32) + initializers.append(onnx.numpy_helper.from_array(weight_data, name="linear1.weight")) + matmul_node = onnx.helper.make_node( + "MatMul", + [input_name, "linear1.weight"], + [output_name], + "MatMul_0", + ) + k = weight_shape[-2] + n = weight_shape[-1] + input_tensor = helper.make_tensor_value_info(input_name, TensorProto.FLOAT, [k]) + output_tensor = helper.make_tensor_value_info(output_name, TensorProto.FLOAT, [1, n]) + graph = helper.make_graph( + [matmul_node], + "matmul_4bits_1d_activation_test", + [input_tensor], + [output_tensor], + initializer=initializers, + ) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 21)]) + model.ir_version = 10 + onnx.save(model, output_model_path) + + def test_quantize_matmul_int4_3d_weight_1d_activation_preserves_shape(self): + """ONNX MatMul([K],[1,K,N]) must yield [1,N] after quantization. + + When the activation is 1-D, ONNX MatMul promotes it to rank-2 internally + before broadcasting against the weight. The reshape helper must use + max(rank(A), 2) as the effective activation rank so that exactly one + leading 1 is prepended, giving output shape [1, N]. + """ + np.random.seed(42) + weight_shape = (1, 52, 288) + model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_1d_act.onnx").absolute()) + self.construct_model_matmul_3d_weight_1d_activation(model_fp32_path, symmetric=True, weight_shape=weight_shape) + data_reader = self.input_feeds(1, {"input": (weight_shape[-2],)}) + # rank(A)=1, rank(B_orig)=3: a_rank_eff=max(1,2)=2, extra_count=1, so the + # reshape helper IS emitted. Verify the full helper op-count and that + # check_model_correctness confirms the output shape is [1, N]. + self.quant_test(model_fp32_path, data_reader, 32, True, extra_quant_nodes=self._RESHAPE_HELPER_OP_COUNTS) + + def test_quantize_matmul_int4_4d_weight_default(self): + """Test that Default quantizer handles weight with two unit leading dims, e.g. shape [1, 1, K, N]. + + For activation rank 2 and weight rank 4, extra_count = max(4 - max(2, 2), 0) = 2, + so the reshape helper prepends two leading 1s to the MatMulNBits output. + """ + np.random.seed(42) + weight_shape = (1, 1, 52, 288) # two unit leading dims + model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_4d_default.onnx").absolute()) + self.construct_model_matmul_3d(model_fp32_path, symmetric=True, weight_shape=weight_shape) + data_reader = self.input_feeds(1, {"input": (100, 52)}) + self.quant_test(model_fp32_path, data_reader, 32, True, extra_quant_nodes=self._RESHAPE_HELPER_OP_COUNTS) + if __name__ == "__main__": unittest.main()