Skip to content
Merged
258 changes: 247 additions & 11 deletions onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,9 +697,23 @@ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeP
return [node] # only care about constant weight

b_array = onnx.numpy_helper.to_array(b_pb)
if len(b_array.shape) != 2:
logger.info("MatMul weight is not 2D. Skip to quantize")
return [node] # can only process 2-D matrix
b_original_shape = b_array.shape
if len(b_original_shape) != 2:
if len(b_original_shape) < 2:
logger.info("MatMul weight has fewer than 2 dimensions. Skip to quantize.")
return [node]
leading = b_original_shape[:-2]
if any(d != 1 for d in leading):
logger.info(
"MatMul weight has non-unit batch dims %s; N-D batched quantization not supported. "
"Skip to quantize.",
list(leading),
)
return [node]
# Squeeze all unit leading dims to obtain a 2-D weight [K, N]
b_array = b_array.reshape(b_original_shape[-2], b_original_shape[-1])
else:
b_original_shape = None # already 2-D, no reshape needed
b_array_torch = torch.from_numpy(b_array)
if torch.cuda.is_available():
b_array_torch = b_array_torch.cuda()
Expand Down Expand Up @@ -755,18 +769,42 @@ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeP
kwargs["bits"] = bits
kwargs["block_size"] = self.config.block_size

# When rank(A) >= rank_b_orig, the reshape would be a no-op since MatMulNBits
# already produces the correct output shape.
a_static_rank = _get_static_rank(node.input[0], graph_stack)
rank_b_orig = len(b_original_shape) if b_original_shape is not None else 0
needs_reshape = b_original_shape is not None and (a_static_rank is None or a_static_rank < rank_b_orig)
matmul_q_output = node.output[0] if not needs_reshape else node.output[0] + "_pre_reshape"
matmul_q_node = onnx.helper.make_node(
"MatMulNBits",
inputs=input_names,
outputs=[node.output[0]],
outputs=[matmul_q_output],
name=node.name + "_Q" + str(bits) if node.name else "",
domain="com.microsoft",
**kwargs,
)

output_nodes = [matmul_q_node]
if needs_reshape:
# Restore ONNX MatMul broadcast shape on the output.
# MatMul(A, B_orig) output shape (with B_orig leading dims all 1) is:
# [1] * max(rank(B_orig) - rank(A), 0) + A.shape[:-1] + [N]
# MatMulNBits with squeezed B produces [...A.shape[:-1], N] (rank=rank(A)),
# so we only need to prepend leading 1s when rank(B_orig) > rank(A).
output_nodes.extend(
_build_nbits_output_reshape(
a_input_name=node.input[0],
b_original_shape=b_original_shape,
target_graph=bs_graph,
name_prefix=(node.name + "_Q" + str(bits)) if node.name else (b_pb.name + "_Q" + str(bits)),
pre_reshape_output=matmul_q_output,
final_output=node.output[0],
)
)

logger.info(f"complete quantization of {node.name} ...")

return [matmul_q_node]
return output_nodes


def get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, GraphProto]:
Expand All @@ -778,6 +816,155 @@ def get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, Gr
return None, None


def _get_static_rank(tensor_name: str, graph_path: list[GraphProto]) -> int | None:
"""Return the static rank of a tensor if its shape is known, else None.

Searches graph inputs, value_info, and outputs in the graph stack (inner-most
graph first). A known shape requires ``HasField('shape')`` to be true on the
tensor_type; the rank is then ``len(shape.dim)``. Individual dim sizes may
still be symbolic — only the rank (dimension count) matters here.
"""
for gid in range(len(graph_path) - 1, -1, -1):
graph = graph_path[gid]
for vi in list(graph.input) + list(graph.value_info) + list(graph.output):
if vi.name == tensor_name:
tt = vi.type.tensor_type
if tt.HasField("shape"):
return len(tt.shape.dim)
return None
return None


def _build_nbits_output_reshape(
a_input_name: str,
b_original_shape: tuple,
target_graph: GraphProto,
name_prefix: str,
pre_reshape_output: str,
final_output: str,
) -> list[NodeProto]:
"""Build the reshape chain that restores the ONNX MatMul-broadcast output shape.

MatMulNBits produces shape ``[...A_batch_dims, M, N]`` (rank = rank(A)). To match
the original ``MatMul(A, B_orig)`` output, where B_orig has all-unit leading
dims, we need:

a_rank_eff = max(rank(A), 2) # ONNX promotes 1-D A to rank-2
out_shape = [1] * max(rank(B_orig) - a_rank_eff, 0) + A.shape[:-1] + [N]

This is built dynamically via Shape/Gather/Max/Sub/Max/ConstantOfShape/Slice/Concat
so it works regardless of A's static rank (handles rank(A) == 1, rank(A) == 2
— the common transformer case — as well as rank(A) >= rank(B_orig) where no
leading-1 prepending is needed). All ops used are valid from opset 11 onward.

Args:
a_input_name: name of the activation input edge (A) feeding MatMulNBits.
b_original_shape: the original (pre-squeeze) shape of B, e.g. ``(1, K, N)``.
target_graph: graph proto to append helper initializers into.
name_prefix: unique prefix for generated node/initializer names.
pre_reshape_output: name of the MatMulNBits output edge (the input of the
generated Reshape).
final_output: name of the final edge produced by the generated Reshape
(must match the original MatMul output edge).

Returns:
List of nodes to append to the consumer's ``output_nodes`` after the
MatMulNBits node. Initializers are appended to ``target_graph`` in place.
"""
rank_b_orig = len(b_original_shape)
n_dim = b_original_shape[-1]

# Incorporate the unique output tensor name so initializer names are unique
# even when multiple MatMul nodes share the same node.name (Fix 2).
p = name_prefix + "_" + final_output
init_zero = p + "_zero"
init_one = p + "_one"
init_two = p + "_two"
init_one_vec = p + "_one_vec"
init_rank_b = p + "_rank_b"
init_n_vec = p + "_n_vec"
init_zero_vec = p + "_zero_vec"
init_one_value_template = p + "_one_value"

target_graph.initializer.extend(
[
onnx.numpy_helper.from_array(np.array(0, dtype=np.int64), name=init_zero),
onnx.numpy_helper.from_array(np.array(1, dtype=np.int64), name=init_one),
onnx.numpy_helper.from_array(np.array(2, dtype=np.int64), name=init_two),
onnx.numpy_helper.from_array(np.array([1], dtype=np.int64), name=init_one_vec),
onnx.numpy_helper.from_array(np.array(rank_b_orig, dtype=np.int64), name=init_rank_b),
onnx.numpy_helper.from_array(np.array([n_dim], dtype=np.int64), name=init_n_vec),
onnx.numpy_helper.from_array(np.array([0], dtype=np.int64), name=init_zero_vec),
]
)

a_shape = p + "_a_shape"
a_shape_of_shape = p + "_a_shape_of_shape"
a_rank = p + "_a_rank"
a_rank_eff = p + "_a_rank_eff"
extra_raw = p + "_extra_raw"
extra_count = p + "_extra_count"
extra_count_vec = p + "_extra_count_vec"
extra_ones = p + "_extra_ones"
a_rank_minus_one = p + "_a_rank_m1"
a_rank_minus_one_vec = p + "_a_rank_m1_vec"
a_prefix_shape = p + "_a_prefix_shape"
target_shape = p + "_target_shape"

nodes = [
onnx.helper.make_node("Shape", [a_input_name], [a_shape], name=p + "_shape_a"),
# Use Shape(shape) + Gather instead of Size to stay within opset 11
# (Size requires opset >= 13 when applied to a shape tensor).
# Shape applied to the 1-D shape vector yields [rank_a] as a 1-element
# tensor; Gather with scalar index 0 extracts it as a scalar int64.
onnx.helper.make_node("Shape", [a_shape], [a_shape_of_shape], name=p + "_shape_of_a_shape"),
onnx.helper.make_node("Gather", [a_shape_of_shape, init_zero], [a_rank], name=p + "_gather_rank"),
# ONNX MatMul promotes a 1-D activation to rank-2 before computing the
Comment thread
tianleiwu marked this conversation as resolved.
# output shape, so use Max(a_rank, 2) as the effective rank when
# computing how many leading 1s to prepend.
onnx.helper.make_node("Max", [a_rank, init_two], [a_rank_eff], name=p + "_max_rank_eff"),
onnx.helper.make_node("Sub", [init_rank_b, a_rank_eff], [extra_raw], name=p + "_sub"),
onnx.helper.make_node("Max", [extra_raw, init_zero], [extra_count], name=p + "_max"),
onnx.helper.make_node("Reshape", [extra_count, init_one_vec], [extra_count_vec], name=p + "_reshape_extra"),
Comment thread
tianleiwu marked this conversation as resolved.
onnx.helper.make_node(
"ConstantOfShape",
[extra_count_vec],
[extra_ones],
name=p + "_const_ones",
value=onnx.helper.make_tensor(
name=init_one_value_template,
data_type=TensorProto.INT64,
dims=[1],
vals=[1],
),
),
onnx.helper.make_node("Sub", [a_rank, init_one], [a_rank_minus_one], name=p + "_sub_one"),
onnx.helper.make_node(
"Reshape", [a_rank_minus_one, init_one_vec], [a_rank_minus_one_vec], name=p + "_reshape_rank_m1"
),
onnx.helper.make_node(
"Slice",
[a_shape, init_zero_vec, a_rank_minus_one_vec, init_zero_vec],
[a_prefix_shape],
name=p + "_slice_a_prefix",
),
onnx.helper.make_node(
"Concat",
[extra_ones, a_prefix_shape, init_n_vec],
[target_shape],
name=p + "_concat_target",
axis=0,
),
onnx.helper.make_node(
"Reshape",
[pre_reshape_output, target_shape],
[final_output],
name=p + "_reshape_out",
),
]
return nodes


# transpose int4 matrix (packed as uint8)
def transpose_packed_int4_matrix(packed, rows, cols):
# unpack to int4 matrix
Comment thread
tianleiwu marked this conversation as resolved.
Expand Down Expand Up @@ -855,7 +1042,8 @@ def qbits_block_quant(self, fp32weight: npt.ArrayLike) -> tuple[np.ndarray, np.n
def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeProto]:
"""
Quantize weight B of MatMul node to int4 or int8.
Currently only support 2D constant matrix and axis 0 blockwise quantization.
Supports 2D constant matrix, and N-D constant matrices whose leading dimensions are all 1
(which are squeezed to 2D before quantization). Axis 0 blockwise quantization only.
"""
bits = self.config.bits
if bits == 8:
Expand All @@ -869,9 +1057,23 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> lis
return [node] # only care about constant weight

b_ndarray = ir.from_proto(b_tensor).numpy()
if len(b_ndarray.shape) != 2:
logger.info("MatMul weight is not 2D. Skip to quantize")
return [node] # can only process 2-D matrix
b_original_shape = b_ndarray.shape
if len(b_original_shape) != 2:
if len(b_original_shape) < 2:
logger.info("MatMul weight has fewer than 2 dimensions. Skip to quantize.")
return [node]
leading = b_original_shape[:-2]
if any(d != 1 for d in leading):
logger.info(
"MatMul weight has non-unit batch dims %s; N-D batched quantization not supported. "
"Skip to quantize.",
list(leading),
)
return [node]
# Squeeze all unit leading dims to obtain a 2-D weight [K, N]
b_ndarray = b_ndarray.reshape(b_original_shape[-2], b_original_shape[-1])
else:
b_original_shape = None # already 2-D, no reshape needed

bfloat16 = b_ndarray.dtype == "bfloat16"
if bfloat16:
Expand Down Expand Up @@ -931,16 +1133,33 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> lis
if self.config.accuracy_level:
kwargs["accuracy_level"] = self.config.accuracy_level

# When rank(A) >= rank_b_orig, the reshape would be a no-op since MatMulNBits
# already produces the correct output shape.
a_static_rank = _get_static_rank(node.input[0], graph_stack)
rank_b_orig = len(b_original_shape) if b_original_shape is not None else 0
needs_reshape = b_original_shape is not None and (a_static_rank is None or a_static_rank < rank_b_orig)
qop_output = node.output[0] if not needs_reshape else node.output[0] + "_pre_reshape"
matmul_qbit_node = onnx.helper.make_node(
"MatMulNBits",
inputs=input_names,
outputs=[node.output[0]],
outputs=[qop_output],
name=node.name + f"_Q{bits}" if node.name else "",
domain="com.microsoft",
**kwargs,
)

output_nodes.append(matmul_qbit_node)
if needs_reshape:
output_nodes.extend(
_build_nbits_output_reshape(
a_input_name=node.input[0],
b_original_shape=b_original_shape,
target_graph=b_graph,
name_prefix=(node.name + f"_Q{bits}") if node.name else (b_tensor.name + f"_Q{bits}"),
pre_reshape_output=qop_output,
final_output=node.output[0],
)
)
else:
dq_input_names = [b_quant.name, scales_tensor.name]
dq_output_names = [b_quant.name + "_output"]
Expand All @@ -950,7 +1169,13 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> lis
node.input[0],
tp_output_names[0] if qdq_opt_for_intel_npu_enabled else dq_output_names[0],
]
matmul_output_names = [node.output[0]]
# When rank(A) >= rank_b_orig, the reshape would be a no-op since MatMul
# already produces the correct output shape.
a_static_rank = _get_static_rank(node.input[0], graph_stack)
rank_b_orig = len(b_original_shape) if b_original_shape is not None else 0
needs_reshape = b_original_shape is not None and (a_static_rank is None or a_static_rank < rank_b_orig)
qdq_matmul_out = node.output[0] if not needs_reshape else node.output[0] + "_pre_reshape"
matmul_output_names = [qdq_matmul_out]
if not self.config.is_symmetric:
zp_tensor = onnx.helper.make_tensor(
b_tensor.name + "_DQ_zero_points", qtype, scales.shape, zero_points.tobytes(), True
Expand Down Expand Up @@ -985,6 +1210,17 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> lis
output_nodes.extend([dq_node, tp_node, matmul_node])
else:
output_nodes.extend([dq_node, matmul_node])
if needs_reshape:
output_nodes.extend(
_build_nbits_output_reshape(
a_input_name=node.input[0],
b_original_shape=b_original_shape,
target_graph=b_graph,
name_prefix=(node.name + f"_DQ_Q{bits}") if node.name else (b_tensor.name + f"_DQ_Q{bits}"),
pre_reshape_output=qdq_matmul_out,
final_output=node.output[0],
)
)

return output_nodes

Expand Down
Loading
Loading