
Commit b3bc797

jiayisunx authored and pytorchmergebot committed

[Inductor][Quant] Support qconv_pointwise.tensor and qconv2d_pointwise.binary_tensor (pytorch#166608)

Pull Request resolved: pytorch#166608
Approved by: https://github.com/Xia-Weiwen, https://github.com/mingfeima, https://github.com/jansel
1 parent a928c9d commit b3bc797

File tree

5 files changed: +181 -47 lines changed
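
This commit teaches Inductor's quantization passes to match and lower the tensor-overload variants of the oneDNN quantized-conv ops, torch.ops.onednn.qconv_pointwise.tensor and torch.ops.onednn.qconv2d_pointwise.binary_tensor, whose activation scale and zero point are passed as Tensors rather than as Python scalars (the .default and .binary overloads). A minimal sketch of the overload distinction — an illustration, not code from the commit, assuming a PyTorch build with the onednn ops registered:

# Illustration (not from this commit): the two overload families differ only in
# how the activation quantization parameters are passed.
import torch

def pick_qconv_overload(x_scale, x_zp):
    # Tensor qparams (e.g. produced at runtime) -> the ".tensor" overload;
    # plain Python scalars -> the pre-existing ".default" overload.
    if isinstance(x_scale, torch.Tensor) or isinstance(x_zp, torch.Tensor):
        return torch.ops.onednn.qconv_pointwise.tensor
    return torch.ops.onednn.qconv_pointwise.default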

test/inductor/test_mkldnn_pattern_matcher.py

Lines changed: 84 additions & 0 deletions

@@ -1164,6 +1164,25 @@ def matcher_check_fn():
             quantization_with_autocast=quantization_with_autocast,
         )
 
+        if torch._inductor.config.cpp_wrapper:
+            self._test_code_common(
+                mod,
+                (v,),
+                [f"aoti_torch_{device}__qconv_pointwise_tensor"],
+                [],
+                check_quantization=True,
+                num_include_ops=[3],
+            )
+        else:
+            self._test_code_common(
+                mod,
+                (v,),
+                ["torch.ops.onednn.qconv_pointwise.tensor"],
+                [],
+                check_quantization=True,
+                num_include_ops=[3],
+            )
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
     @skipIfRocm
@@ -1270,6 +1289,25 @@ def matcher_check_fn():
             matcher_check_fn=matcher_check_fn,
         )
 
+        if torch._inductor.config.cpp_wrapper:
+            self._test_code_common(
+                mod,
+                (v,),
+                [f"aoti_torch_{device}__qconv_pointwise_tensor"],
+                [],
+                check_quantization=True,
+                num_include_ops=[2],
+            )
+        else:
+            self._test_code_common(
+                mod,
+                (v,),
+                ["torch.ops.onednn.qconv_pointwise.tensor"],
+                [],
+                check_quantization=True,
+                num_include_ops=[2],
+            )
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
     def test_qconv2d_relu_cpu(self):
@@ -1548,6 +1586,32 @@ def matcher_check_fn():
             check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
         )
 
+        if not TEST_ACL:
+            if torch._inductor.config.cpp_wrapper:
+                self._test_code_common(
+                    mod,
+                    (v,),
+                    [
+                        f"aoti_torch_{device}__qconv_pointwise_tensor",
+                        f"aoti_torch_{device}__qconv2d_pointwise_binary_tensor",
+                    ],
+                    [],
+                    check_quantization=True,
+                    num_include_ops=[2, 2],
+                )
+            else:
+                self._test_code_common(
+                    mod,
+                    (v,),
+                    [
+                        "torch.ops.onednn.qconv_pointwise.tensor",
+                        "torch.ops.onednn.qconv2d_pointwise.binary_tensor",
+                    ],
+                    [],
+                    check_quantization=True,
+                    num_include_ops=[2, 2],
+                )
+
     def _qconv2d_add_test_helper2(
         self, device="cpu", use_relu=False, int8_mixed_bf16=False
     ):
@@ -1645,6 +1709,26 @@ def matcher_check_fn():
             check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
         )
 
+        if not TEST_ACL:
+            if torch._inductor.config.cpp_wrapper:
+                self._test_code_common(
+                    mod,
+                    (x, x2, x3),
+                    [f"aoti_torch_{device}__qconv2d_pointwise_binary_tensor"],
+                    [],
+                    check_quantization=True,
+                    num_include_ops=[2],
+                )
+            else:
+                self._test_code_common(
+                    mod,
+                    (x, x2, x3),
+                    ["torch.ops.onednn.qconv2d_pointwise.binary_tensor"],
+                    [],
+                    check_quantization=True,
+                    num_include_ops=[2],
+                )
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
     def test_qconv2d_add_cpu(self):
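
The added assertions inspect the code Inductor actually generates, not just the pattern-match counts: under cpp_wrapper (AOTI) the lowered kernel appears as the C-shim symbol aoti_torch_{device}__qconv_pointwise_tensor, while the default Python wrapper emits torch.ops.onednn.qconv_pointwise.tensor, and num_include_ops pins how many times each name must occur. A hypothetical stand-in for that counting check (the real logic lives in _test_code_common in the test harness):

# Hypothetical sketch: each expected substring must appear exactly
# num_include_ops[i] times in the generated wrapper code.
def check_include_ops(code: str, include_ops: list[str], num_include_ops: list[int]) -> None:
    for op, expected in zip(include_ops, num_include_ops):
        found = code.count(op)
        assert found == expected, f"{op!r}: found {found}, expected {expected}"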

torch/_inductor/fx_passes/quantization.py

Lines changed: 50 additions & 24 deletions

@@ -179,9 +179,14 @@ def get_dequantize_per_tensor_activation_pattern(is_tensor_overload=False):
 )
 
 
-def get_qconv_pt2e_pattern(users=1):
+def get_qconv_pt2e_pattern(x_scale_zp_are_tensors=False, users=1):
+    qconv_op = (
+        torch.ops.onednn.qconv_pointwise.tensor
+        if x_scale_zp_are_tensors
+        else torch.ops.onednn.qconv_pointwise.default
+    )
     return CallFunction(
-        torch.ops.onednn.qconv_pointwise.default,
+        qconv_op,
         KeywordArg("x"),
         KeywordArg("x_scale"),
         KeywordArg("x_zp"),
@@ -203,9 +208,14 @@ def get_qconv_pt2e_pattern(users=1):
 )
 
 
-def get_qconv2d_binary_pt2e_pattern(users=1):
+def get_qconv2d_binary_pt2e_pattern(x_scale_zp_are_tensors=False, users=1):
+    qconv_op = (
+        torch.ops.onednn.qconv2d_pointwise.binary_tensor
+        if x_scale_zp_are_tensors
+        else torch.ops.onednn.qconv2d_pointwise.binary
+    )
     return CallFunction(
-        torch.ops.onednn.qconv2d_pointwise.binary,
+        qconv_op,
         KeywordArg("x"),
         KeywordArg("x_scale"),
         KeywordArg("x_zp"),
@@ -431,7 +441,13 @@ def qconv(match: Match, *args, **kwargs):
         kwargs["groups"],
     )
     output_dtype = _get_pattern_output_dtype(match)
-    assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16]
+    assert output_dtype in [
+        torch.int8,
+        torch.uint8,
+        torch.float8_e4m3fn,
+        torch.float32,
+        torch.bfloat16,
+    ]
     # Output QParams
     o_inv_scale = kwargs["output_scale"]
     o_zero_point = kwargs["output_zero_point"]
@@ -816,12 +832,17 @@ def qconv_binary(match: Match, *args, **kwargs):
 
 def _register_quantization_unary_lowering():
     # QConv2d
-    for users in [1, 2]:
-        qconv_pattern = get_qconv_pt2e_pattern(users)
+    for x_scale_zp_are_tensors, users in itertools.product([False, True], [1, 2]):
+        qconv_pattern = get_qconv_pt2e_pattern(x_scale_zp_are_tensors, users)
+        computation_op = (
+            torch.ops.onednn.qconv_pointwise.tensor
+            if x_scale_zp_are_tensors
+            else torch.ops.onednn.qconv_pointwise.default
+        )
         _register_quantized_conv_lowering(
             qconv_pattern,
             2,  # pass_number
-            torch.ops.onednn.qconv_pointwise.default,  # computation_op
+            computation_op,
         )
 
     # QLinear
@@ -841,12 +862,17 @@ def _register_quantization_unary_lowering():
 
 def _register_quantization_binary_lowering():
     # QConv2d
-    for users in (1, 2):
-        qconv_pattern = get_qconv2d_binary_pt2e_pattern(users)
+    for x_scale_zp_are_tensors, users in itertools.product([False, True], [1, 2]):
+        qconv_pattern = get_qconv2d_binary_pt2e_pattern(x_scale_zp_are_tensors, users)
+        computation_op = (
+            torch.ops.onednn.qconv2d_pointwise.binary_tensor
+            if x_scale_zp_are_tensors
+            else torch.ops.onednn.qconv2d_pointwise.binary
+        )
         _register_quantized_conv_binary_lowering(
            qconv_pattern,
            2,  # pass_number
-            torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
+            computation_op,
        )
 
     # QLinear
@@ -3027,21 +3053,21 @@ def _register_qconv_unary_fusion():
             PostOpAttr(
                 "none", None, "none", [], ""
             ): generate_pattern_with_output_quant(
-                get_qconv_pt2e_pattern(1),
+                get_qconv_pt2e_pattern(users=1),
             ),
             PostOpAttr(
                 "none", None, "relu", [], ""
             ): generate_pattern_with_output_quant(
                 generate_pattern_with_unary(
-                    get_qconv_pt2e_pattern(1), aten.relu.default
+                    get_qconv_pt2e_pattern(users=1), aten.relu.default
                 ),
             ),
             PostOpAttr(
                 "none", None, "hardtanh", [], ""
             ): generate_pattern_with_output_quant(
                 _unary_fusion_pattern(
                     _hardtanh_fusion,
-                    get_qconv_pt2e_pattern(1),
+                    get_qconv_pt2e_pattern(users=1),
                     1,
                     is_bf16,
                 ),
@@ -3052,7 +3078,7 @@ def _register_qconv_unary_fusion():
             ): generate_pattern_with_output_quant(
                 _unary_fusion_pattern(
                     _hardswish_fusion,
-                    get_qconv_pt2e_pattern(1 if is_bf16 else 2),
+                    get_qconv_pt2e_pattern(users=1 if is_bf16 else 2),
                     2,
                     is_bf16,
                 ),
@@ -3063,7 +3089,7 @@ def _register_qconv_unary_fusion():
             ): generate_pattern_with_output_quant(
                 _unary_fusion_pattern(
                     _silu_fusion,
-                    get_qconv_pt2e_pattern(1 if is_bf16 else 2),
+                    get_qconv_pt2e_pattern(users=1 if is_bf16 else 2),
                     2,
                     is_bf16,
                 ),
@@ -3083,14 +3109,14 @@ def _register_qconv_unary_fusion():
         # Priority 2 to match: QConv2d Unary pattern with fp32/bfloat16 output
         conv_unary_replace_float_out_patterns = {
             PostOpAttr("none", None, "relu", [], ""): generate_pattern_with_unary(
-                get_qconv_pt2e_pattern(1), aten.relu.default
+                get_qconv_pt2e_pattern(users=1), aten.relu.default
             ),
             PostOpAttr(
                 "none", None, "hardtanh", [], ""
             ): _may_generate_pattern_with_dtype_convert(
                 _unary_fusion_pattern(
                     _hardtanh_fusion,
-                    get_qconv_pt2e_pattern(1),
+                    get_qconv_pt2e_pattern(users=1),
                     1,
                     is_bf16,
                 ),
@@ -3102,7 +3128,7 @@ def _register_qconv_unary_fusion():
             ): _may_generate_pattern_with_dtype_convert(
                 _unary_fusion_pattern(
                     _hardswish_fusion,
-                    get_qconv_pt2e_pattern(1 if is_bf16 else 2),
+                    get_qconv_pt2e_pattern(users=1 if is_bf16 else 2),
                     2,
                     is_bf16,
                 ),
@@ -3114,7 +3140,7 @@ def _register_qconv_unary_fusion():
             ): _may_generate_pattern_with_dtype_convert(
                 _unary_fusion_pattern(
                     _silu_fusion,
-                    get_qconv_pt2e_pattern(1 if is_bf16 else 2),
+                    get_qconv_pt2e_pattern(users=1 if is_bf16 else 2),
                     2,
                     is_bf16,
                 ),
@@ -3146,7 +3172,7 @@ def _register_qconv_binary_fusion():
             ): generate_pattern_with_output_quant(
                 generate_pattern_with_binary(
                     aten.add.Tensor,
-                    get_qconv_pt2e_pattern(1),
+                    get_qconv_pt2e_pattern(users=1),
                     dequantize_accum_pattern,
                     int8_mixed_bf16_with_inplace_add,
                     swap_inputs=swap_inputs,
@@ -3158,7 +3184,7 @@ def _register_qconv_binary_fusion():
                 generate_pattern_with_unary(
                     generate_pattern_with_binary(
                         aten.add.Tensor,
-                        get_qconv_pt2e_pattern(1),
+                        get_qconv_pt2e_pattern(users=1),
                         dequantize_accum_pattern,
                         int8_mixed_bf16_with_inplace_add,
                         swap_inputs=swap_inputs,
@@ -3185,7 +3211,7 @@ def _register_qconv_binary_fusion():
             PostOpAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary(
                 generate_pattern_with_binary(
                     aten.add.Tensor,
-                    get_qconv_pt2e_pattern(1),
+                    get_qconv_pt2e_pattern(users=1),
                     KeywordArg("accum_after_dequant"),
                     int8_mixed_bf16_with_inplace_add,
                     swap_inputs=swap_inputs,
@@ -3223,7 +3249,7 @@ def _register_qconv_binary_fusion():
                 "sum", 1.0, "none", [], ""
             ): generate_pattern_with_binary(
                 aten.add.Tensor,
-                get_qconv_pt2e_pattern(1),
+                get_qconv_pt2e_pattern(users=1),
                 KeywordArg("accum_after_dequant"),
                 int8_mixed_bf16_with_inplace_add,
                 swap_inputs=swap_inputs,
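
Two details of the lowering changes are worth calling out. First, because x_scale_zp_are_tensors is now the leading positional parameter of get_qconv_pt2e_pattern, all existing call sites switch from get_qconv_pt2e_pattern(1) to the keyword form get_qconv_pt2e_pattern(users=1); a positional 1 would now silently set the wrong flag. Second, the registration loops iterate itertools.product over the overload flag and the user count, so each conv lowering is registered for all four variants. A standalone illustration of that product:

# Illustration only: the four (overload flag, users) variants registered
# per quantized-conv lowering.
import itertools

for x_scale_zp_are_tensors, users in itertools.product([False, True], [1, 2]):
    print(f"tensor qparams: {x_scale_zp_are_tensors}, users: {users}")
# -> (False, 1), (False, 2), (True, 1), (True, 2)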

torch/_inductor/mkldnn_ir.py

Lines changed: 3 additions & 3 deletions

@@ -603,7 +603,7 @@ def __init__(
             inputs,
             constant_args,
             None,
-            op_overload=torch.ops.onednn.qconv_pointwise.default,
+            op_overload=torch.ops.onednn.qconv_pointwise.tensor,
             cpp_kernel_name=f"aoti_torch_{self.device_type}__qconv_pointwise_tensor",
         )
 
@@ -623,7 +623,7 @@ def create(
         x_zero_point: Union["ShapeAsConstantBuffer", "TensorBox"],
         qw: "TensorBox",  # qw
         w_scale: "TensorBox",
-        w_zero_point: "TensorBox",
+        w_zero_point,
         bias: "TensorBox",
         stride: list[int],
         padding: list[int],
@@ -711,7 +711,7 @@ def __init__(
             inputs,
             constant_args,
             None,
-            op_overload=torch.ops.onednn.qconv2d_pointwise.binary,
+            op_overload=torch.ops.onednn.qconv2d_pointwise.binary_tensor,
             cpp_kernel_name=(
                 f"aoti_torch_{self.device_type}__qconv2d_pointwise_binary_tensor"
            ),
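
These IR nodes already lowered to the "_tensor" C-shim kernels; the fix makes op_overload agree with them (.tensor and .binary_tensor instead of .default and .binary). An illustrative consistency check, using a hypothetical helper that is not part of the commit and assuming a build with the onednn namespace registered:

# Hypothetical: derive the expected AOTI C-shim name from an onednn overload.
import torch

def expected_cshim_name(op_overload, device_type: str) -> str:
    # str(torch.ops.onednn.qconv_pointwise.tensor) == "onednn.qconv_pointwise.tensor"
    suffix = str(op_overload).replace("onednn.", "").replace(".", "_")
    return f"aoti_torch_{device_type}__{suffix}"

assert (
    expected_cshim_name(torch.ops.onednn.qconv_pointwise.tensor, "cpu")
    == "aoti_torch_cpu__qconv_pointwise_tensor"
)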
