@@ -179,9 +179,14 @@ def get_dequantize_per_tensor_activation_pattern(is_tensor_overload=False):
179179)
180180
181181
182- def get_qconv_pt2e_pattern(users=1):
182+ def get_qconv_pt2e_pattern(x_scale_zp_are_tensors=False, users=1):
183+     qconv_op = (
184+         torch.ops.onednn.qconv_pointwise.tensor
185+         if x_scale_zp_are_tensors
186+         else torch.ops.onednn.qconv_pointwise.default
187+     )
183188     return CallFunction(
184-         torch.ops.onednn.qconv_pointwise.default,
189+         qconv_op,
185190         KeywordArg("x"),
186191         KeywordArg("x_scale"),
187192         KeywordArg("x_zp"),
@@ -203,9 +208,14 @@ def get_qconv_pt2e_pattern(users=1):
203208 )
204209
205210
206- def get_qconv2d_binary_pt2e_pattern(users=1):
211+ def get_qconv2d_binary_pt2e_pattern(x_scale_zp_are_tensors=False, users=1):
212+     qconv_op = (
213+         torch.ops.onednn.qconv2d_pointwise.binary_tensor
214+         if x_scale_zp_are_tensors
215+         else torch.ops.onednn.qconv2d_pointwise.binary
216+     )
207217     return CallFunction(
208-         torch.ops.onednn.qconv2d_pointwise.binary,
218+         qconv_op,
209219         KeywordArg("x"),
210220         KeywordArg("x_scale"),
211221         KeywordArg("x_zp"),
@@ -431,7 +441,13 @@ def qconv(match: Match, *args, **kwargs):
431441         kwargs["groups"],
432442     )
433443     output_dtype = _get_pattern_output_dtype(match)
434-     assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16]
444+     assert output_dtype in [
445+         torch.int8,
446+         torch.uint8,
447+         torch.float8_e4m3fn,
448+         torch.float32,
449+         torch.bfloat16,
450+     ]
435451     # Output QParams
436452     o_inv_scale = kwargs["output_scale"]
437453     o_zero_point = kwargs["output_zero_point"]
@@ -816,12 +832,17 @@ def qconv_binary(match: Match, *args, **kwargs):
816832
817833 def _register_quantization_unary_lowering():
818834     # QConv2d
819-     for users in [1, 2]:
820-         qconv_pattern = get_qconv_pt2e_pattern(users)
835+     for x_scale_zp_are_tensors, users in itertools.product([False, True], [1, 2]):
836+         qconv_pattern = get_qconv_pt2e_pattern(x_scale_zp_are_tensors, users)
837+         computation_op = (
838+             torch.ops.onednn.qconv_pointwise.tensor
839+             if x_scale_zp_are_tensors
840+             else torch.ops.onednn.qconv_pointwise.default
841+         )
821842         _register_quantized_conv_lowering(
822843             qconv_pattern,
823844             2,  # pass_number
824-             torch.ops.onednn.qconv_pointwise.default,  # computation_op
845+             computation_op,
825846         )
826847
827848     # QLinear
@@ -841,12 +862,17 @@ def _register_quantization_unary_lowering():
841862
842863 def _register_quantization_binary_lowering():
843864     # QConv2d
844-     for users in (1, 2):
845-         qconv_pattern = get_qconv2d_binary_pt2e_pattern(users)
865+     for x_scale_zp_are_tensors, users in itertools.product([False, True], [1, 2]):
866+         qconv_pattern = get_qconv2d_binary_pt2e_pattern(x_scale_zp_are_tensors, users)
867+         computation_op = (
868+             torch.ops.onednn.qconv2d_pointwise.binary_tensor
869+             if x_scale_zp_are_tensors
870+             else torch.ops.onednn.qconv2d_pointwise.binary
871+         )
846872         _register_quantized_conv_binary_lowering(
847873             qconv_pattern,
848874             2,  # pass_number
849-             torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
875+             computation_op,
850876         )
851877
852878     # QLinear
@@ -3027,21 +3053,21 @@ def _register_qconv_unary_fusion():
30273053 PostOpAttr(
30283054     "none", None, "none", [], ""
30293055 ): generate_pattern_with_output_quant(
3030-     get_qconv_pt2e_pattern(1),
3056+     get_qconv_pt2e_pattern(users=1),
30313057 ),
30323058 PostOpAttr(
30333059     "none", None, "relu", [], ""
30343060 ): generate_pattern_with_output_quant(
30353061     generate_pattern_with_unary(
3036-         get_qconv_pt2e_pattern(1), aten.relu.default
3062+         get_qconv_pt2e_pattern(users=1), aten.relu.default
30373063     ),
30383064 ),
30393065 PostOpAttr(
30403066     "none", None, "hardtanh", [], ""
30413067 ): generate_pattern_with_output_quant(
30423068     _unary_fusion_pattern(
30433069         _hardtanh_fusion,
3044-         get_qconv_pt2e_pattern(1),
3070+         get_qconv_pt2e_pattern(users=1),
30453071         1,
30463072         is_bf16,
30473073     ),
@@ -3052,7 +3078,7 @@ def _register_qconv_unary_fusion():
30523078 ): generate_pattern_with_output_quant(
30533079     _unary_fusion_pattern(
30543080         _hardswish_fusion,
3055-         get_qconv_pt2e_pattern(1 if is_bf16 else 2),
3081+         get_qconv_pt2e_pattern(users=1 if is_bf16 else 2),
30563082         2,
30573083         is_bf16,
30583084     ),
@@ -3063,7 +3089,7 @@ def _register_qconv_unary_fusion():
30633089 ): generate_pattern_with_output_quant(
30643090     _unary_fusion_pattern(
30653091         _silu_fusion,
3066-         get_qconv_pt2e_pattern(1 if is_bf16 else 2),
3092+         get_qconv_pt2e_pattern(users=1 if is_bf16 else 2),
30673093         2,
30683094         is_bf16,
30693095     ),
@@ -3083,14 +3109,14 @@ def _register_qconv_unary_fusion():
30833109 # Priority 2 to match: QConv2d Unary pattern with fp32/bfloat16 output
30843110 conv_unary_replace_float_out_patterns = {
30853111     PostOpAttr("none", None, "relu", [], ""): generate_pattern_with_unary(
3086-         get_qconv_pt2e_pattern(1), aten.relu.default
3112+         get_qconv_pt2e_pattern(users=1), aten.relu.default
30873113     ),
30883114     PostOpAttr(
30893115         "none", None, "hardtanh", [], ""
30903116     ): _may_generate_pattern_with_dtype_convert(
30913117         _unary_fusion_pattern(
30923118             _hardtanh_fusion,
3093-             get_qconv_pt2e_pattern(1),
3119+             get_qconv_pt2e_pattern(users=1),
30943120             1,
30953121             is_bf16,
30963122         ),
@@ -3102,7 +3128,7 @@ def _register_qconv_unary_fusion():
31023128 ): _may_generate_pattern_with_dtype_convert(
31033129     _unary_fusion_pattern(
31043130         _hardswish_fusion,
3105-         get_qconv_pt2e_pattern(1 if is_bf16 else 2),
3131+         get_qconv_pt2e_pattern(users=1 if is_bf16 else 2),
31063132         2,
31073133         is_bf16,
31083134     ),
@@ -3114,7 +3140,7 @@ def _register_qconv_unary_fusion():
31143140 ): _may_generate_pattern_with_dtype_convert(
31153141     _unary_fusion_pattern(
31163142         _silu_fusion,
3117-         get_qconv_pt2e_pattern(1 if is_bf16 else 2),
3143+         get_qconv_pt2e_pattern(users=1 if is_bf16 else 2),
31183144         2,
31193145         is_bf16,
31203146     ),
@@ -3146,7 +3172,7 @@ def _register_qconv_binary_fusion():
31463172 ): generate_pattern_with_output_quant(
31473173     generate_pattern_with_binary(
31483174         aten.add.Tensor,
3149-         get_qconv_pt2e_pattern(1),
3175+         get_qconv_pt2e_pattern(users=1),
31503176         dequantize_accum_pattern,
31513177         int8_mixed_bf16_with_inplace_add,
31523178         swap_inputs=swap_inputs,
@@ -3158,7 +3184,7 @@ def _register_qconv_binary_fusion():
31583184 generate_pattern_with_unary(
31593185     generate_pattern_with_binary(
31603186         aten.add.Tensor,
3161-         get_qconv_pt2e_pattern(1),
3187+         get_qconv_pt2e_pattern(users=1),
31623188         dequantize_accum_pattern,
31633189         int8_mixed_bf16_with_inplace_add,
31643190         swap_inputs=swap_inputs,
@@ -3185,7 +3211,7 @@ def _register_qconv_binary_fusion():
31853211 PostOpAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary(
31863212     generate_pattern_with_binary(
31873213         aten.add.Tensor,
3188-         get_qconv_pt2e_pattern(1),
3214+         get_qconv_pt2e_pattern(users=1),
31893215         KeywordArg("accum_after_dequant"),
31903216         int8_mixed_bf16_with_inplace_add,
31913217         swap_inputs=swap_inputs,
@@ -3223,7 +3249,7 @@ def _register_qconv_binary_fusion():
32233249 "sum" , 1.0 , "none" , [], ""
32243250 ): generate_pattern_with_binary (
32253251 aten .add .Tensor ,
3226- get_qconv_pt2e_pattern (1 ),
3252+ get_qconv_pt2e_pattern (users = 1 ),
32273253 KeywordArg ("accum_after_dequant" ),
32283254 int8_mixed_bf16_with_inplace_add ,
32293255 swap_inputs = swap_inputs ,
0 commit comments