Commit 0d5db82

add UnbindScaledDotProductAttentionModel torch2 quantization test
1 parent d034abd

3 files changed: +37, -0 lines

tests/cross_fw/test_templates/helpers.py

Lines changed: 9 additions & 0 deletions
@@ -440,6 +440,15 @@ def forward(self, query, key, value):
         return nn.functional.scaled_dot_product_attention(query, key, value)
 
 
+class UnbindScaledDotProductAttentionModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        query, key, value = x.unbind(0)
+        return nn.functional.scaled_dot_product_attention(query, key, value)
+
+
 class DepthwiseConvTestModel(nn.Module):
     INPUT_SIZE = [1, 2, 4, 4]
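For reference, a quick sanity check of the new model outside the test harness. The shape follows the {"x": [3, 1, 8, 16]} spec registered in test_quantized_graphs.py below; the snippet itself is illustrative and not part of this commit:

    import torch

    from tests.cross_fw.test_templates.helpers import UnbindScaledDotProductAttentionModel

    model = UnbindScaledDotProductAttentionModel()
    # unbind(0) removes dim 0, turning one stacked (3, 1, 8, 16) tensor
    # into three (1, 8, 16) tensors: query, key, value.
    x = torch.randn(3, 1, 8, 16)
    out = model(x)
    assert out.shape == (1, 8, 16)  # SDPA output matches the query shape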

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+strict digraph {
+x [id=0, type="nncf_model_input", metatype=PTInputNoopMetatype];
+"/unbind/0" [id=1, type=unbind, metatype=PTSplitMetatype];
+"__nncf_hooks.pre_hooks./scaled_dot_product_attention/0__0.0._scale_param_storage" [id=2, type="nncf_model_const", metatype=PTConstNoopMetatype];
+"pre_hook__-scaled_dot_product_attention-0__0[0]/symmetric_quantize/0" [id=3, type="symmetric_quantize", metatype=UnknownMetatype];
+"__nncf_hooks.pre_hooks./scaled_dot_product_attention/0__1.0._scale_param_storage" [id=4, type="nncf_model_const", metatype=PTConstNoopMetatype];
+"pre_hook__-scaled_dot_product_attention-0__1[0]/symmetric_quantize/0" [id=5, type="symmetric_quantize", metatype=UnknownMetatype];
+"/scaled_dot_product_attention/0" [id=6, type="scaled_dot_product_attention", metatype=PTScaledDotProductAttentionMetatype];
+output [id=7, type="nncf_model_output", metatype=PTOutputNoopMetatype];
+x -> "/unbind/0" [dtype=float, shape="(3, 1, 8, 16)", out_port_id=0, in_port_id=0];
+"/unbind/0" -> "pre_hook__-scaled_dot_product_attention-0__0[0]/symmetric_quantize/0" [dtype=float, shape="(1, 8, 16)", out_port_id=0, in_port_id=0];
+"/unbind/0" -> "pre_hook__-scaled_dot_product_attention-0__1[0]/symmetric_quantize/0" [dtype=float, shape="(1, 8, 16)", out_port_id=1, in_port_id=0];
+"/unbind/0" -> "/scaled_dot_product_attention/0" [dtype=float, shape="(1, 8, 16)", out_port_id=2, in_port_id=2];
+"__nncf_hooks.pre_hooks./scaled_dot_product_attention/0__0.0._scale_param_storage" -> "pre_hook__-scaled_dot_product_attention-0__0[0]/symmetric_quantize/0" [dtype=float, shape="(1,)", out_port_id=0, in_port_id=4];
+"pre_hook__-scaled_dot_product_attention-0__0[0]/symmetric_quantize/0" -> "/scaled_dot_product_attention/0" [dtype=float, shape="(1, 8, 16)", out_port_id=0, in_port_id=0];
+"__nncf_hooks.pre_hooks./scaled_dot_product_attention/0__1.0._scale_param_storage" -> "pre_hook__-scaled_dot_product_attention-0__1[0]/symmetric_quantize/0" [dtype=float, shape="(1,)", out_port_id=0, in_port_id=4];
+"pre_hook__-scaled_dot_product_attention-0__1[0]/symmetric_quantize/0" -> "/scaled_dot_product_attention/0" [dtype=float, shape="(1, 8, 16)", out_port_id=0, in_port_id=1];
+"/scaled_dot_product_attention/0" -> output [dtype=float, shape="(1, 8, 16)", out_port_id=0, in_port_id=0];
+}
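The reference graph wires the three unbind outputs into scaled_dot_product_attention separately: output ports 0 (query) and 1 (key) each pass through a symmetric_quantize pre-hook, while port 2 (value) feeds in_port_id=2 directly, without quantization. For orientation, a minimal sketch of producing a quantized model like this through NNCF's public nncf.quantize API; this is illustrative only, and the test harness may drive quantization through its own helpers:

    import nncf
    import torch

    from tests.cross_fw.test_templates.helpers import UnbindScaledDotProductAttentionModel

    model = UnbindScaledDotProductAttentionModel().eval()
    # A few random stacked-QKV samples matching the (3, 1, 8, 16) input above.
    calibration_dataset = nncf.Dataset([torch.randn(3, 1, 8, 16) for _ in range(4)])
    quantized_model = nncf.quantize(model, calibration_dataset)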

tests/torch2/function_hook/quantization/test_quantized_graphs.py

Lines changed: 9 additions & 0 deletions
@@ -25,6 +25,7 @@
 from tests.cross_fw.test_templates.helpers import EmbeddingModel
 from tests.cross_fw.test_templates.helpers import RoPEModel
 from tests.cross_fw.test_templates.helpers import ScaledDotProductAttentionModel
+from tests.cross_fw.test_templates.helpers import UnbindScaledDotProductAttentionModel
 from tests.torch import test_models
 from tests.torch.quantization.test_algo_quantization import SharedLayersModel
 from tests.torch.test_compressed_graph import ModelDesc
@@ -45,6 +46,14 @@
         ),
         {},
     ),
+    (
+        ModelDesc(
+            "unbind_scaled_dot_product_attention_model",
+            UnbindScaledDotProductAttentionModel,
+            {"x": [3, 1, 8, 16]},
+        ),
+        {},
+    ),
     (ModelDesc("shared_model", SharedLayersModel, [1, 1, 5, 6]), {}),
     (ModelDesc("alexnet", test_models.AlexNet, [1, 3, 32, 32]), {}),
     (ModelDesc("lenet", test_models.LeNet, [1, 3, 32, 32]), {}),
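Unlike the list-form specs used by the other entries, the new entry passes the input spec as a dict mapping an argument name to its shape. A minimal sketch of how such a spec plausibly materializes into keyword example inputs; the actual ModelDesc internals are not shown in this diff and may differ:

    import torch

    from tests.cross_fw.test_templates.helpers import UnbindScaledDotProductAttentionModel

    # Hypothetical reconstruction; the real ModelDesc helper may build inputs differently.
    input_spec = {"x": [3, 1, 8, 16]}
    example_kwargs = {name: torch.ones(shape) for name, shape in input_spec.items()}
    output = UnbindScaledDotProductAttentionModel()(**example_kwargs)  # forward(x=...)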
