@@ -739,7 +739,6 @@ def test_disallowed_attributes(self):
739739 AttributeError ,
740740 match = "Attribute name reshape can't be used with @tensorclass" ,
741741 ):
742-
743742 @tensorclass
744743 class MyInvalidClass :
745744 x : torch .Tensor
@@ -1101,7 +1100,6 @@ class MyDataParent:
11011100 @pytest .mark .parametrize ("list_to_stack" , [True , False ])
11021101 def test_indexing (self , list_to_stack ):
11031102 with set_list_to_stack (list_to_stack ):
1104-
11051103 @tensorclass
11061104 class MyDataNested :
11071105 X : torch .Tensor
@@ -1438,8 +1436,8 @@ class MyDataNested(TensorClass):
14381436 assert (
14391437 repeated .X
14401438 == X .repeat_interleave (
1441- torch .tensor ([2 , 3 , 4 , 5 ], device = data .device ), dim = 1
1442- )
1439+ torch .tensor ([2 , 3 , 4 , 5 ], device = data .device ), dim = 1
1440+ )
14431441 ).all ()
14441442
14451443 def test_reshape (self ):
@@ -2890,23 +2888,20 @@ class FuncAutoCast:
28902888class TestShadow :
28912889 def test_no_shadow (self ):
28922890 with pytest .raises (AttributeError ):
2893-
28942891 @tensorclass
28952892 class MyClass :
28962893 x : str
28972894 y : int
28982895 batch_size : Any
28992896
29002897 with pytest .raises (AttributeError ):
2901-
29022898 @tensorclass
29032899 class MyClass : # noqa: F811
29042900 x : str
29052901 y : int
29062902 names : Any
29072903
29082904 with pytest .raises (AttributeError ):
2909-
29102905 @tensorclass
29112906 class MyClass : # noqa: F811
29122907 x : str
@@ -3104,7 +3099,7 @@ class MyClass:
31043099 _ = c / 1
31053100 _ = 1 / c
31063101
3107- _ = c ** 1
3102+ _ = c ** 1
31083103 # not implemented
31093104 # 1 ** c
31103105
@@ -3304,15 +3299,13 @@ class TensorOnly:
33043299 c : torch .Tensor | None = None
33053300
33063301 with pytest .raises (TypeError , match = "tensor_only" ):
3307-
33083302 @tensorclass (tensor_only = True , nocast = True )
33093303 class TensorOnlyNocast :
33103304 a : torch .Tensor
33113305 b : torch .Tensor
33123306 c : torch .Tensor | None = None
33133307
33143308 with pytest .raises (TypeError , match = "tensor_only" ):
3315-
33163309 @tensorclass (tensor_only = True , autocast = True )
33173310 class TensorOnlyAutocast :
33183311 a : torch .Tensor
@@ -3337,7 +3330,6 @@ class TensorOnly(TensorClass["tensor_only"]):
33373330 TypeError ,
33383331 match = "tensor_only requires types to be Tensor, Tensor-subtrypes or None" ,
33393332 ):
3340-
33413333 class TensorOnlyAny (TensorClass ["tensor_only" ]):
33423334 a : torch .Tensor
33433335 b : Any
@@ -3347,7 +3339,6 @@ class TensorOnlyAny(TensorClass["tensor_only"]):
33473339 TypeError ,
33483340 match = "tensor_only requires types to be Tensor, Tensor-subtrypes or None" ,
33493341 ):
3350-
33513342 class TensorOnlyStr (TensorClass ["tensor_only" ]):
33523343 a : torch .Tensor
33533344 b : torch .Tensor | str
@@ -3357,7 +3348,6 @@ class TensorOnlyStr(TensorClass["tensor_only"]):
33573348 TypeError ,
33583349 match = "tensor_only requires types to be Tensor, Tensor-subtrypes or None" ,
33593350 ):
3360-
33613351 class TensorOnlyStrUnion (TensorClass ["tensor_only" ]):
33623352 a : torch .Tensor
33633353 b : torch .Tensor
0 commit comments