
Commit c223e06

fzimmermann89 authored and pytorchmergebot committed

Tighten type hints for tensor arithmetic (pytorch#135392)

Fixes pytorch#124015
Pull Request resolved: pytorch#135392
Approved by: https://github.com/ezyang
1 parent a96aadf commit c223e06
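In effect, the generated Tensor stubs move from a blanket Any to one union per operator category. Roughly, as an illustrative summary of the diff below, not the verbatim generated .pyi:

# Before: every binary dunder accepted Any
def __add__(self, other: Any) -> Tensor: ...

# After: per-category unions
def __add__(self, other: Union[Tensor, Number, _complex]) -> Tensor: ...  # arithmetic
def __and__(self, other: Union[Tensor, _bool]) -> Tensor: ...             # logic
def __lshift__(self, other: Union[Tensor, _int]) -> Tensor: ...           # shift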

File tree

4 files changed (+46, -28 lines)

tools/pyi/gen_pyi.py

Lines changed: 40 additions & 20 deletions
@@ -177,14 +177,18 @@ def should_bind_method(python_func: PythonSignatureNativeFunctionPair) -> bool:
     "copy_",
 ]
 
-binary_ops = (
+shift_ops = (
+    "lshift",
+    "rshift",
+    "ilshift",
+    "irshift",  # inplace ops
+)
+arithmetic_ops = (
     "add",
     "sub",
     "mul",
     "div",
     "pow",
-    "lshift",
-    "rshift",
     "mod",
     "truediv",
     "matmul",
@@ -195,24 +199,26 @@ def should_bind_method(python_func: PythonSignatureNativeFunctionPair) -> bool:
     "rtruediv",
     "rfloordiv",
     "rpow",  # reverse arithmetic
+    "iadd",
+    "idiv",
+    "imul",
+    "isub",
+    "ifloordiv",
+    "imod",  # inplace ops
+)
+logic_ops = (
     "and",
     "or",
     "xor",
     "rand",
     "ror",
-    "rxor",  # logic
-    "iadd",
+    "rxor",  # reverse logic
     "iand",
-    "idiv",
-    "ilshift",
-    "imul",
     "ior",
-    "irshift",
-    "isub",
-    "ixor",
-    "ifloordiv",
-    "imod",  # inplace ops
+    "ixor",  # inplace ops
 )
+binary_ops = shift_ops + arithmetic_ops + logic_ops
+
 symmetric_comparison_ops = ("eq", "ne")
 asymmetric_comparison_ops = ("ge", "gt", "lt", "le")
 comparison_ops = symmetric_comparison_ops + asymmetric_comparison_ops
@@ -232,14 +238,28 @@ def sig_for_ops(opname: str) -> list[str]:
     assert opname.endswith("__") and opname.startswith("__"), f"Unexpected op {opname}"
 
     name = opname[2:-2]
-    if name in binary_ops:
-        return [f"def {opname}(self, other: Any) -> Tensor: ..."]
-    elif name in comparison_ops:
-        sig = f"def {opname}(self, other: Any) -> Tensor: ..."
-        if name in symmetric_comparison_ops:
+    if name == "rpow":
+        return [  # somehow required to make mypy ci happy?
+            f"def {opname}(self, other: Union[Tensor, Number, _complex]) -> Tensor: ...  # type: ignore[has-type]"
+        ]
+    elif name in arithmetic_ops:
+        return [
+            f"def {opname}(self, other: Union[Tensor, Number, _complex]) -> Tensor: ..."
+        ]
+    elif name in logic_ops:
+        return [f"def {opname}(self, other: Union[Tensor, _bool]) -> Tensor: ..."]
+    elif name in shift_ops:
+        return [f"def {opname}(self, other: Union[Tensor, _int]) -> Tensor: ..."]
+    elif name in symmetric_comparison_ops:
+        return [
             # unsafe override https://github.com/python/mypy/issues/5704
-            sig += "  # type: ignore[override]"
-        return [sig]
+            f"def {opname}(self, other: Union[Tensor, Number, _complex]) -> Tensor: ...  # type: ignore[override]",
+            f"def {opname}(self, other: Any) -> _bool: ...",
+        ]
+    elif name in asymmetric_comparison_ops:
+        return [
+            f"def {opname}(self, other: Union[Tensor, Number, _complex]) -> Tensor: ..."
+        ]
     elif name in unary_ops:
         return [f"def {opname}(self) -> Tensor: ..."]
     elif name in to_py_type_ops:
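A quick sketch of what the tightened stubs mean for callers (ordinary runnable Python; the mypy verdicts in the comments assume stubs regenerated by this gen_pyi.py):

import torch

x = torch.ones(3, dtype=torch.int64)
b = torch.tensor([True, False, True])

x + 2.5   # arithmetic op: Union[Tensor, Number, _complex] -> accepted
x << 1    # shift op: Union[Tensor, _int] -> accepted
b & True  # logic op: Union[Tensor, _bool] -> accepted

# With the tightened hints, mypy should now flag mismatches such as:
#   x << 1.5   # float is not Union[Tensor, _int]
#   b & "yes"  # str is not Union[Tensor, _bool]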

torch/_decomp/decompositions.py

Lines changed: 2 additions & 1 deletion
@@ -2291,7 +2291,8 @@ def native_batch_norm_backward(
     mean = save_mean_cast
     invstd = save_invstd_cast
     if train:
-        assert save_mean_cast is not None and save_invstd_cast is not None
+        assert mean is not None and invstd is not None
+
     else:
         assert running_mean_cast is not None and running_var_cast is not None
         mean = running_mean_cast
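The assert is retargeted at the aliases because checkers like mypy narrow only the exact name asserted on: asserting on save_mean_cast would leave mean typed as Optional. A minimal sketch of the pattern (hypothetical function, not from the diff):

from typing import Optional

def use_mean(save_mean_cast: Optional[float]) -> float:
    mean = save_mean_cast
    # Narrow `mean`, the name actually used below; asserting on
    # `save_mean_cast` instead would not narrow this alias for mypy.
    assert mean is not None
    return mean + 1.0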

torch/_inductor/fx_passes/efficient_conv_bn_eval.py

Lines changed: 1 addition & 0 deletions
@@ -33,6 +33,7 @@ def efficient_conv_bn_eval(
     """
 
     assert bn.running_var is not None
+    assert bn.running_mean is not None
 
     # These lines of code are designed to deal with various cases
     # like bn without affine transform, and conv without bias
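Both running stats are declared Optional[Tensor] on batch-norm modules (they are None when track_running_stats=False), so each needs its own assert before unguarded arithmetic on it type-checks. A minimal sketch of the narrowing (illustrative folding math, not this function's actual body):

import torch
from torch import nn

bn = nn.BatchNorm2d(8)
assert bn.running_var is not None   # narrows Optional[Tensor] -> Tensor
assert bn.running_mean is not None  # same narrowing for the mean
# Typical conv-bn folding arithmetic that relies on the narrowing:
scale = torch.rsqrt(bn.running_var + bn.eps)
shift = bn.running_mean * scale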

torch/ao/quantization/_equalize.py

Lines changed: 3 additions & 7 deletions
@@ -128,8 +128,7 @@ def cross_layer_equalization(module1, module2, output_axis=0, input_axis=1):
             "module type not supported:", type(module1), " ", type(module2)
         )
 
-    conv1_has_bias = has_bias(module1)
-    bias = None
+    bias = get_module_bias(module1) if has_bias(module1) else None
 
     weight1 = get_module_weight(module1)
     weight2 = get_module_weight(module2)
@@ -140,9 +139,6 @@ def cross_layer_equalization(module1, module2, output_axis=0, input_axis=1):
             number input channels of second arg"
         )
 
-    if conv1_has_bias:
-        bias = get_module_bias(module1)
-
     weight1_range = channel_range(weight1, output_axis)
     weight2_range = channel_range(weight2, input_axis)
 
@@ -151,7 +147,7 @@ def cross_layer_equalization(module1, module2, output_axis=0, input_axis=1):
     scaling_factors = torch.sqrt(weight1_range / weight2_range)
     inverse_scaling_factors = torch.reciprocal(scaling_factors)
 
-    if conv1_has_bias:
+    if bias is not None:
         bias = bias * inverse_scaling_factors
 
     # formatting the scaling (1D) tensors to be applied on the given argument tensors
@@ -168,7 +164,7 @@ def cross_layer_equalization(module1, module2, output_axis=0, input_axis=1):
     weight2 = weight2 * scaling_factors
 
     set_module_weight(module1, weight1)
-    if conv1_has_bias:
+    if bias is not None:
         set_module_bias(module1, bias)
     set_module_weight(module2, weight2)
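This refactor replaces the conv1_has_bias boolean with `bias is not None` guards, which a type checker can use to narrow bias from Optional[Tensor] to Tensor inside each branch. A minimal sketch of that pattern (hypothetical helper, assuming Tensor inputs):

from typing import Optional
import torch

def scale_bias(
    bias: Optional[torch.Tensor], inverse_scaling_factors: torch.Tensor
) -> Optional[torch.Tensor]:
    # `bias is not None` narrows bias to Tensor inside the branch, so
    # the multiplication type-checks without a cast or a separate flag.
    if bias is not None:
        bias = bias * inverse_scaling_factors
    return bias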
