Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/nncf/openvino/graph/nncf_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def _add_edges_to_nncf_graph(model: ov.Model, graph: NNCFGraph) -> None:
in_node_id = graph.get_node_by_name(op.get_friendly_name()).node_id
for output_port_id, out in enumerate(op.outputs()):
node_vs_target_inputs = defaultdict(list)
for inp in out.get_target_inputs():
for inp in sorted(out.get_target_inputs(), key=lambda inp: inp.get_node().get_friendly_name()):
Copy link
Contributor

@ljaljushkin ljaljushkin Nov 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is it needed?
do we rely on order of friendly names somewhere?

node_vs_target_inputs[inp.get_node()].append(inp)

for out_node, inputs in node_vs_target_inputs.items():
Expand Down
30 changes: 19 additions & 11 deletions src/nncf/quantization/algorithms/weight_compression/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,19 +1028,22 @@ def apply(
)
return transformed_model

def _get_activation_node_and_port(self, node: NNCFNode, nncf_graph: NNCFGraph) -> tuple[NNCFNode, int]:
def _get_activation_node_port_and_channel(self, node: NNCFNode, nncf_graph: NNCFGraph) -> tuple[NNCFNode, int]:
"""
This method returns the activation layer and corresponding port id for the node.
This method returns the activation layer, corresponding port id and channel axis for the given node.

:param node: NNCFGraph node for which the activation is sought.
:param nncf_graph: NNCFGraph instance with the node.
:return: Tuple with the activation node and port id.
:return: Tuple with the activation node, port id and channel axis.
"""
activation_port = self._backend_entity.get_activation_port_id(node, nncf_graph)
activation_edge = nncf_graph.get_input_edge_by_port_id(node, activation_port)
activation_node = activation_edge.from_node
port_id = activation_edge.output_port_id
return activation_node, port_id
activation_channel_axis = self._backend_entity.get_activation_channel_axis(
node, port_id, activation_edge.tensor_shape
)
return activation_node, port_id, activation_channel_axis

def get_matmul_input_to_output_nodes_map(
self, matmul_nodes: list[NNCFNode], graph: NNCFGraph
Expand All @@ -1061,8 +1064,8 @@ def get_matmul_input_to_output_nodes_map(
"""
matmul_input_to_output_nodes_map = defaultdict(list)
for node in matmul_nodes:
act_node, output_port_id = self._get_activation_node_and_port(node, graph)
matmul_input_to_output_nodes_map[(act_node, output_port_id)].append(node)
act_node, output_port_id, act_channel_axis = self._get_activation_node_port_and_channel(node, graph)
matmul_input_to_output_nodes_map[(act_node, output_port_id, act_channel_axis)].append(node)
return matmul_input_to_output_nodes_map

def get_compression_nodes_info(
Expand Down Expand Up @@ -1130,7 +1133,11 @@ def get_statistic_points(

# Statistics for data aware algorithms
if self._data_aware_compression:
for (node, output_port_id), node_with_weights in matmul_input_to_output_nodes_map.items():
for (
node,
output_port_id,
input_channel_axis,
), node_with_weights in matmul_input_to_output_nodes_map.items():
statistic_point = self._backend_entity.target_point(
TargetType.POST_LAYER_OPERATION, node.node_name, port_id=output_port_id
)
Expand All @@ -1145,10 +1152,11 @@ def get_statistic_points(
]
all_weight_dims.extend(weight_dims)

# by default, reduce activations across all but the last dimension. The last dimension is
# assumed to be the hidden size dimension.
# Reduce activations across all but the hidden dimension.
n_dims = len(graph.get_output_edges_by_port_id(node, output_port_id)[0].tensor_shape)
reduction_axes = tuple(range(n_dims - 1))
# negative axis (e.g. -1 for the last axis) is converted into corresponding positive value
input_channel_axis = input_channel_axis % n_dims
reduction_axes = tuple(i for i in range(n_dims) if i != input_channel_axis)

# For 3D weights, hidden dimension is the second dimension. Reduce by all other dimensions
reduction_axes = (1,) if any(weight_dim == 3 for weight_dim in all_weight_dims) else reduction_axes
Expand Down Expand Up @@ -1191,7 +1199,7 @@ def _get_statistics_for_weights_compression(
# Where mean_value is a 1D tensor representing an activation reduced over batch and sequence length dimensions,
# shape is an original shape of an activation before reduction, n is the size of the dataset (or subset_size).
statistics = {}
for (act_node, output_port_id), matmul_nodes in matmul_input_to_output_nodes_map.items():
for (act_node, output_port_id, _), matmul_nodes in matmul_input_to_output_nodes_map.items():
tensor_collectors = list(
statistic_points.get_algo_statistics_for_node(
act_node.node_name,
Expand Down
11 changes: 11 additions & 0 deletions src/nncf/quantization/algorithms/weight_compression/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,17 @@ def get_ignored_patterns() -> GraphPattern:
:return: backend-specific ignored patterns.
"""

@staticmethod
@abstractmethod
def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: tuple[int]) -> int:
"""
Returns axis number of the activation tensor which correspond to it channel.
:param node: NNCFNode instance.
:param port_id: Port ID for input.
:param input_shape: Shape of the input.
:return: Channel axis number.
"""


class AWQAlgoBackend(WeightCompressionAlgoBackend):
@staticmethod
Expand Down
35 changes: 20 additions & 15 deletions src/nncf/quantization/algorithms/weight_compression/gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,14 @@ def apply(
]:
continue
_, input_tensors = next(iter(inputs.items()))
hessian = self._calculate_hessian(node, input_tensors)
scale, zero_point = self._quantize_weights(model, graph, wc_params, hessian, input_tensors)
input_channel_axis = self._backend_entity.get_activation_channel_axis(
node, self._backend_entity.get_activation_port_id(node, graph), input_tensors[0].shape
)
hessian = self._calculate_hessian(node, input_tensors, input_channel_axis)
scale, zero_point = self._quantize_weights(
model, graph, wc_params, hessian, input_tensors, input_channel_axis
)
res[wc_params.weight_name] = CompressedWeight(None, scale, zero_point, None)

return model, res

def get_statistic_points(
Expand Down Expand Up @@ -158,7 +162,7 @@ def get_statistic_points(

return self._layerwise_engine.get_statistic_points(model, graph, filtered_nodes)

def _calculate_hessian(self, node: NNCFNode, inputs: list[Tensor]) -> Tensor:
def _calculate_hessian(self, node: NNCFNode, inputs: list[Tensor], input_channel_axis: int) -> Tensor:
"""
Calculates the Hessian matrix for the given node and inputs.

Expand All @@ -171,19 +175,18 @@ def _calculate_hessian(self, node: NNCFNode, inputs: list[Tensor]) -> Tensor:
if node.metatype in self._backend_entity.convolution_metatypes:
msg = "Convolution metatypes are not supported"
raise nncf.UnsupportedModelError(msg)
if node.layer_attributes.input_attributes["transpose"]:
msg = "Transposed input is not supported"
raise nncf.UnsupportedModelError(msg)

hessian = fns.zeros(
(inputs[0].shape[-1], inputs[0].shape[-1]), backend=inputs[0].backend, dtype=TensorDataType.float32
(inputs[0].shape[input_channel_axis], inputs[0].shape[input_channel_axis]),
backend=inputs[0].backend,
dtype=TensorDataType.float32,
)

for inp in inputs:
batch_size = 1 if len(inp.shape) == 2 else inp.shape[0]
if node.metatype in self._backend_entity.matmul_metatypes:
if len(inp.shape) == 3:
inp = inp.reshape((-1, inp.shape[-1]))
inp = inp.reshape((-1, inp.shape[input_channel_axis]))
inp = fns.transpose(inp)
hessian *= nsamples / (nsamples + batch_size)
nsamples += batch_size
Expand All @@ -199,6 +202,7 @@ def _quantize_weights(
wc_params: WeightCompressionParameters,
hessian: Tensor,
inputs: list[Tensor],
input_channel_axis: int,
):
"""
Quantizes the weights of the model based on the calculated Hessian matrix.
Expand All @@ -211,10 +215,7 @@ def _quantize_weights(
"""
if wc_params.node_with_weight.metatype in self._backend_entity.convolution_metatypes:
msg = "Convolution metatypes are not supported"
raise RuntimeError(msg)
if not wc_params.node_with_weight.layer_attributes.constant_attributes[wc_params.weight_port_id]["transpose"]:
msg = "Transpose is not supported"
raise RuntimeError(msg)
raise nncf.UnsupportedModelError(msg)

weight_tensor = self._backend_entity.get_weight(
wc_params.node_with_weight, wc_params.weight_port_id, model, graph
Expand Down Expand Up @@ -272,8 +273,12 @@ def _quantize_weights(
scales.append(scale)
else:
if self._scale_estimation and block_compression_config.num_bits == 4:
activations = [inp[..., (i1 + i) : (i1 + i + group_size)] for inp in inputs]
wc_statistics = ScaleEstimation.activations_to_wc_statistics(activations)
slicing_along_axis = [slice(None)] * len(inputs[0].shape)
slicing_along_axis[input_channel_axis] = slice(i1 + i, i1 + i + group_size)
activations = [inp[tuple(slicing_along_axis)] for inp in inputs]
wc_statistics = ScaleEstimation.activations_to_wc_statistics(
activations, input_channel_axis
)
scale, zero_point = ScaleEstimation.calculate_quantization_params(
wc_statistics,
weight_tensor[:, (i1 + i) : (i1 + i + group_size)],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def get_statistic_points(
self._set_backend_entity(model)

statistic_container = StatisticPointsContainer()
for act_node, output_port_id in nodes_and_port_ids:
for act_node, output_port_id, _ in nodes_and_port_ids:
n_dims = len(graph.get_output_edges_by_port_id(act_node, output_port_id)[0].tensor_shape)
if n_dims < 2:
msg = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from nncf.onnx.graph.model_transformer import remove_initializer
from nncf.onnx.graph.model_transformer import remove_node
from nncf.onnx.graph.model_transformer import set_initializer
from nncf.onnx.graph.node_utils import get_act_quantization_axis
from nncf.onnx.graph.node_utils import get_weight_quantization_axis
from nncf.onnx.graph.onnx_helper import ONNX_DTYPE_TO_NNCF_DTYPE
from nncf.onnx.graph.onnx_helper import get_name_to_node_map
Expand Down Expand Up @@ -301,6 +302,10 @@ def filter_func(point: StatisticPoint) -> bool:

return filter_func

@staticmethod
def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: tuple[int]) -> int:
return get_act_quantization_axis(node, port_id)

def insert_adapters(
self, wc_params: WeightCompressionParameters, lora_A: Tensor, lora_B: Tensor, int8_lora: bool
) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from nncf.openvino.graph.node_utils import convert_op
from nncf.openvino.graph.node_utils import create_ov_codebook_subgraph
from nncf.openvino.graph.node_utils import create_ov_const_from_tensor
from nncf.openvino.graph.node_utils import get_activation_channel_axis
from nncf.openvino.graph.node_utils import get_const_value_as_numpy_tensor
from nncf.openvino.graph.node_utils import get_const_value_as_ov_tensor
from nncf.openvino.graph.node_utils import get_weight_channel_axes
Expand Down Expand Up @@ -118,9 +119,6 @@ def mean_statistic_collector(

@staticmethod
def get_activation_port_id(node: NNCFNode, nncf_graph: NNCFGraph) -> int:
if node.layer_attributes.input_attributes["transpose"]:
msg = "Transposed input is not supported"
raise nncf.UnsupportedModelError(msg)
constant_ports = node.layer_attributes.get_const_port_ids()
activation_ports = [
e.input_port_id for e in nncf_graph.get_input_edges(node) if e.input_port_id not in constant_ports
Expand All @@ -137,6 +135,9 @@ def get_weight_names_and_port_ids(node: NNCFNode, graph: NNCFGraph) -> list[tupl
return result

def get_weight(self, node_with_weight: NNCFNode, weight_port_id: int, model: ov.Model, graph: NNCFGraph) -> Tensor:
if not node_with_weight.layer_attributes.constant_attributes[weight_port_id]["transpose"]:
msg = "Only transposed weights are supported"
raise nncf.UnsupportedModelError(msg)
weight_name = node_with_weight.layer_attributes.constant_attributes[weight_port_id]["name"]
weight_node = self.name_to_node_mapping[weight_name]
weight_tensor = get_const_value_as_numpy_tensor(weight_node)
Expand Down Expand Up @@ -203,7 +204,12 @@ def insert_adapters(
A_W = opset.constant(lora_A.data)
B_W = opset.constant(lora_B.data)

A_MM = opset.matmul(input_node, A_W, transpose_a=False, transpose_b=True)
A_MM = opset.matmul(
input_node,
A_W,
transpose_a=wc_params.node_with_weight.layer_attributes.input_attributes["transpose"],
transpose_b=True,
)
B_MM = opset.matmul(A_MM, B_W, transpose_a=False, transpose_b=True)

node_output_port = mm_node.output(0)
Expand Down Expand Up @@ -399,6 +405,10 @@ def get_ignored_patterns() -> GraphPattern:
pattern.add_pattern_alternative(create_sam_pe())
return pattern

@staticmethod
def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: tuple[int]) -> int:
return get_activation_channel_axis(node, port_id, input_shape)


class OVTensorWeightCompressionAlgoBackend(OVWeightCompressionAlgoBackend):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def calculate_quantization_params(
return result_scale, zp

@staticmethod
def activations_to_wc_statistics(activations: list[Tensor]) -> WCTensorStatistic:
def activations_to_wc_statistics(activations: list[Tensor], input_channel_axis: int) -> WCTensorStatistic:
"""
Mimic the activation reducing logic from WeightCompression.get_statistic_points.

Expand All @@ -393,7 +393,9 @@ def activations_to_wc_statistics(activations: list[Tensor]) -> WCTensorStatistic
shapes = []
for act in activations:
shapes.append(act.shape)
reduction_shape = tuple(range(act.ndim - 1))
# negative axis (e.g. -1 for the last axis) is converted into corresponding positive value
input_channel_axis = input_channel_axis % len(act.shape)
reduction_shape = tuple(i for i in range(len(act.shape)) if i != input_channel_axis)
mean_values.append(fns.mean(act, axis=reduction_shape))
wc_statistics = WCTensorStatistic(mean_values, shapes)
return wc_statistics
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from nncf.torch.model_graph_manager import split_const_name
from nncf.torch.model_transformer import PTModelTransformer
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.node_utils import get_activation_channel_axis as get_activation_channel_axis_util
from nncf.torch.quantization.ignored_patterns import create_rope
from nncf.torch.quantization.ignored_patterns import create_sam_pe
from nncf.torch.quantization.layers import QUANTIZATION_MODULES
Expand Down Expand Up @@ -486,6 +487,10 @@ def get_ignored_patterns() -> GraphPattern:
pattern.add_pattern_alternative(create_sam_pe())
return pattern

@staticmethod
def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: tuple[int]) -> int:
return get_activation_channel_axis_util(node, port_id)


class PTAWQAlgoAlgoBackend(AWQAlgoBackend, PTWeightCompressionAlgoBackend):
@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from nncf.torch.model_graph_manager import get_const_node
from nncf.torch.model_graph_manager import get_weight_compression_reduction_axes
from nncf.torch.model_graph_manager import get_weight_tensor_port_ids
from nncf.torch.node_utils import get_activation_channel_axis as get_activation_channel_axis_util
from nncf.torch.quantization.ignored_patterns import create_rope
from nncf.torch.quantization.ignored_patterns import create_sam_pe
from nncf.torch.quantization.layers import INT4AsymmetricWeightsDecompressor
Expand Down Expand Up @@ -262,6 +263,10 @@ def get_ignored_patterns() -> GraphPattern:
pattern.add_pattern_alternative(create_sam_pe())
return pattern

@staticmethod
def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: tuple[int]) -> int:
return get_activation_channel_axis_util(node, port_id)


class FXMixedPrecisionAlgoBackend(MixedPrecisionAlgoBackend, FXWeightCompressionAlgoBackend):
pass
Expand Down
42 changes: 42 additions & 0 deletions src/nncf/torch/node_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright (c) 2025 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import nncf
import nncf.torch.graph.operator_metatypes as op
from nncf.common.graph import NNCFNode
from nncf.torch.graph.operator_metatypes import PTAddmmMetatype
from nncf.torch.graph.operator_metatypes import PTMatMulMetatype


def get_activation_channel_axis(node: NNCFNode, port_id: int) -> int:
"""
Returns axis number of the activation tensor which correspond to it channel.

:param node: NNCFNode instance.
:param port_id: Port ID for input.
:return: Channel axis number.
"""
if node.metatype not in op.CONVOLUTION_METATYPES + op.MATMUL_METATYPES + op.UNIFICATION_PRODUCING_METATYPES:
msg = f"Activation channel axis retrieval from node with metatype {node.metatype} is not supported"
raise nncf.InternalError(msg)

if node.metatype not in [PTMatMulMetatype, PTAddmmMetatype]:
return node.metatype.output_channel_axis

if port_id == 0:
# X(port:0) * W(port:1): [..., C_IN] * [... , C_IN, C_OUT]
return -1
if port_id == 1:
# W(port:0) * X(port:1): [... , C_OUT, C_IN] * [... , C_IN, ...]
return -2

msg = f"Port id for a {node.metatype} operation is expected to be in [0, 1], {port_id} recieved"
raise nncf.InternalError(msg)
Loading