Skip to content

Commit 63a6bc4

Browse files
author
chenhao388
committed
lint
1 parent 76ef99f commit 63a6bc4

File tree

2 files changed

+6
-17
lines changed

2 files changed

+6
-17
lines changed

test/test_nn.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@
5959
from tensordict.tensorclass import TensorClass
6060
from _utils_internal import is_npu_available
6161

62-
6362
from torch import distributions, nn
6463
from torch.distributions import Categorical, Normal
6564
from torch.utils._pytree import tree_map
@@ -83,7 +82,6 @@
8382
except ImportError:
8483
from tensordict.utils import Buffer
8584

86-
8785
IS_FB = os.getenv("PYTORCH_TEST_FBCODE")
8886

8987
# Capture all warnings
@@ -113,6 +111,7 @@
113111
pytest.mark.filterwarnings("ignore:inplace"),
114112
)
115113

114+
116115
def get_device():
117116
device = torch.device("cpu")
118117
if torch.cuda.is_available():
@@ -123,6 +122,7 @@ def get_device():
123122
device = torch.device("mps:0")
124123
return device
125124

125+
126126
class TestInteractionType:
127127
def test_base(self):
128128
with set_interaction_type("DETERMINISTIC"):
@@ -2161,7 +2161,7 @@ def test_module_buffer():
21612161
module.cuda()
21622162
assert module.td.device.type == "cuda"
21632163
elif is_npu_available():
2164-
module.npu()
2164+
module = module.to("npu:0")
21652165
assert module.td.device.type == "npu"
21662166

21672167

@@ -2186,7 +2186,6 @@ def test_module_buffer():
21862186
)
21872187
def test_to_context(original_device, new_device, tc):
21882188
if tc:
2189-
21902189
class MyTC(TensorClass):
21912190
x: torch.Tensor
21922191
y: torch.Tensor | None = None

test/test_tensorclass.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -739,7 +739,6 @@ def test_disallowed_attributes(self):
739739
AttributeError,
740740
match="Attribute name reshape can't be used with @tensorclass",
741741
):
742-
743742
@tensorclass
744743
class MyInvalidClass:
745744
x: torch.Tensor
@@ -1101,7 +1100,6 @@ class MyDataParent:
11011100
@pytest.mark.parametrize("list_to_stack", [True, False])
11021101
def test_indexing(self, list_to_stack):
11031102
with set_list_to_stack(list_to_stack):
1104-
11051103
@tensorclass
11061104
class MyDataNested:
11071105
X: torch.Tensor
@@ -1438,8 +1436,8 @@ class MyDataNested(TensorClass):
14381436
assert (
14391437
repeated.X
14401438
== X.repeat_interleave(
1441-
torch.tensor([2, 3, 4, 5], device=data.device), dim=1
1442-
)
1439+
torch.tensor([2, 3, 4, 5], device=data.device), dim=1
1440+
)
14431441
).all()
14441442

14451443
def test_reshape(self):
@@ -2890,23 +2888,20 @@ class FuncAutoCast:
28902888
class TestShadow:
28912889
def test_no_shadow(self):
28922890
with pytest.raises(AttributeError):
2893-
28942891
@tensorclass
28952892
class MyClass:
28962893
x: str
28972894
y: int
28982895
batch_size: Any
28992896

29002897
with pytest.raises(AttributeError):
2901-
29022898
@tensorclass
29032899
class MyClass: # noqa: F811
29042900
x: str
29052901
y: int
29062902
names: Any
29072903

29082904
with pytest.raises(AttributeError):
2909-
29102905
@tensorclass
29112906
class MyClass: # noqa: F811
29122907
x: str
@@ -3104,7 +3099,7 @@ class MyClass:
31043099
_ = c / 1
31053100
_ = 1 / c
31063101

3107-
_ = c**1
3102+
_ = c ** 1
31083103
# not implemented
31093104
# 1 ** c
31103105

@@ -3304,15 +3299,13 @@ class TensorOnly:
33043299
c: torch.Tensor | None = None
33053300

33063301
with pytest.raises(TypeError, match="tensor_only"):
3307-
33083302
@tensorclass(tensor_only=True, nocast=True)
33093303
class TensorOnlyNocast:
33103304
a: torch.Tensor
33113305
b: torch.Tensor
33123306
c: torch.Tensor | None = None
33133307

33143308
with pytest.raises(TypeError, match="tensor_only"):
3315-
33163309
@tensorclass(tensor_only=True, autocast=True)
33173310
class TensorOnlyAutocast:
33183311
a: torch.Tensor
@@ -3337,7 +3330,6 @@ class TensorOnly(TensorClass["tensor_only"]):
33373330
TypeError,
33383331
match="tensor_only requires types to be Tensor, Tensor-subtrypes or None",
33393332
):
3340-
33413333
class TensorOnlyAny(TensorClass["tensor_only"]):
33423334
a: torch.Tensor
33433335
b: Any
@@ -3347,7 +3339,6 @@ class TensorOnlyAny(TensorClass["tensor_only"]):
33473339
TypeError,
33483340
match="tensor_only requires types to be Tensor, Tensor-subtrypes or None",
33493341
):
3350-
33513342
class TensorOnlyStr(TensorClass["tensor_only"]):
33523343
a: torch.Tensor
33533344
b: torch.Tensor | str
@@ -3357,7 +3348,6 @@ class TensorOnlyStr(TensorClass["tensor_only"]):
33573348
TypeError,
33583349
match="tensor_only requires types to be Tensor, Tensor-subtrypes or None",
33593350
):
3360-
33613351
class TensorOnlyStrUnion(TensorClass["tensor_only"]):
33623352
a: torch.Tensor
33633353
b: torch.Tensor

0 commit comments

Comments
 (0)