diff --git a/test/test_compile.py b/test/test_compile.py
index 27cbcacbc..c0b7197d1 100644
--- a/test/test_compile.py
+++ b/test/test_compile.py
@@ -7,12 +7,15 @@
 import importlib.util
 import inspect
 import platform
+import sys
 from pathlib import Path
 from typing import Any, Callable
 
 import pytest
 import torch
+
+from _utils_internal import is_npu_available
 from packaging import version
 
 from tensordict import (
@@ -50,7 +53,17 @@
 _IS_OSX = platform.system() == "Darwin"
 
+npu_device_count = 0
+cur_device = "cpu"
+if torch.cuda.is_available():
+    cur_device = "cuda"
+elif is_npu_available():
+    cur_device, npu_device_count = "npu", torch.npu.device_count()
+
+@pytest.mark.skipif(
+    sys.version_info >= (3, 14), reason="torch.compile is not supported on Python 3.14+"
+)
 def test_vmap_compile():
     # Since we monkey patch vmap we need to make sure compile is happy with it
     def func(x, y):
@@ -67,6 +80,9 @@ def func(x, y):
 @pytest.mark.skipif(
     TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
 )
+@pytest.mark.skipif(
+    sys.version_info >= (3, 14), reason="torch.compile is not supported on Python 3.14+"
+)
 @pytest.mark.parametrize("mode", [None, "reduce-overhead"])
 class TestTD:
     def test_tensor_output(self, mode):
@@ -266,7 +282,7 @@ def make_td_with_names(data):
     )
     @pytest.mark.parametrize("has_device", [True, False])
     def test_to(self, has_device, mode):
-        device = "cuda:0"
+        device = f"{cur_device}:0"
 
         def test_to_device(td):
             return td.to(device)
@@ -283,6 +299,10 @@ def test_to_device(td):
         assert td_device_c.batch_size == td.batch_size
         assert td_device_c.device == torch.device(device)
 
+    @pytest.mark.skipif(
+        is_npu_available(),
+        reason="torch.device inside torch.compile is not currently supported on NPU.",
+    )
     def test_lock(self, mode):
         def locked_op(td):
             # Adding stuff uses cache, check that this doesn't break
@@ -357,6 +377,9 @@ class MyClass:
 @pytest.mark.skipif(
     TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
 )
+@pytest.mark.skipif(
+    sys.version_info >= (3, 14), reason="torch.compile is not supported on Python 3.14+"
+)
 @pytest.mark.parametrize("mode", [None, "reduce-overhead"])
 class TestTC:
     def test_tc_tensor_output(self, mode):
@@ -553,7 +576,7 @@ def clone(td: TensorDict):
     )
     @pytest.mark.parametrize("has_device", [True, False])
     def test_tc_to(self, has_device, mode):
-        device = "cuda:0"
+        device = f"{cur_device}:0"
 
         def test_to_device(tc):
             return tc.to(device)
@@ -570,6 +593,10 @@ def test_to_device(tc):
         assert tc_device_c.batch_size == data.batch_size
         assert tc_device_c.device == torch.device(device)
 
+    @pytest.mark.skipif(
+        is_npu_available(),
+        reason="torch.device inside torch.compile is not currently supported on NPU.",
+    )
     def test_tc_lock(self, mode):
         def locked_op(tc):
             # Adding stuff uses cache, check that this doesn't break
@@ -621,6 +648,9 @@ def func_c_mytd():
 @pytest.mark.skipif(
     TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
 )
+@pytest.mark.skipif(
+    sys.version_info >= (3, 14), reason="torch.compile is not supported on Python 3.14+"
+)
 @pytest.mark.parametrize("mode", [None, "reduce-overhead"])
 class TestNN:
     def test_func(self, mode):
@@ -725,6 +755,9 @@ def test_prob_module_with_kwargs(self, mode):
 @pytest.mark.skipif(
     TORCH_VERSION <= version.parse("2.4.0"), reason="requires torch>2.4"
 )
+@pytest.mark.skipif(
+    sys.version_info >= (3, 14), reason="torch.compile is not supported on Python 3.14+"
+)
 @pytest.mark.parametrize("mode", [None, "reduce-overhead"])
 class TestFunctional:
     def test_functional_error(self, mode):
@@ -1023,6 +1056,9 @@ def to_numpy(tensor):
     (TORCH_VERSION <= version.parse("2.7.0")) and _IS_OSX,
     reason="requires torch>=2.7 on OSX",
 )
+@pytest.mark.skipif(
+    sys.version_info >= (3, 14), reason="torch.compile is not supported on Python 3.14+"
+)
 @pytest.mark.parametrize("compiled", [False, True])
 class TestCudaGraphs:
     @pytest.fixture(scope="class", autouse=True)
@@ -1247,7 +1283,7 @@ class TestCompileNontensor:
     # Same issue with the decorator @tensorclass version
     @pytest.fixture(scope="class")
     def data(self):
-        return torch.zeros((4, 3), device="cuda")
+        return torch.zeros((4, 3), device=cur_device)
 
     class TensorClassWithNonTensorData(TensorClass["nocast"]):
         tensor: torch.Tensor
@@ -1265,13 +1301,13 @@ def fn_no_device(self, data):
 
     def fn_with_device(self, data):
         a = self.TensorClassWithNonTensorData(
-            tensor=data, non_tensor_data=1, batch_size=[4], device="cuda"
+            tensor=data, non_tensor_data=1, batch_size=[4], device=cur_device
         )
         return a.tensor
 
     def fn_with_device_without_batch_size(self, data):
         a = self.TensorClassWithNonTensorData(
-            tensor=data, non_tensor_data=1, device="cuda"
+            tensor=data, non_tensor_data=1, device=cur_device
         )
         return a.tensor
 
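All four test files in this diff gate their NPU-specific paths on `is_npu_available()`, imported from `_utils_internal`. The helper itself is not part of the diff; the sketch below is a hypothetical implementation, shown only to make the assumed contract explicit (the real helper in `test/_utils_internal.py` may differ):

```python
# Hypothetical sketch of the is_npu_available() helper used by these tests.
# Assumption for illustration; not the code shipped in this PR.
import importlib.util

import torch


def is_npu_available() -> bool:
    """Return True if the Ascend backend (torch_npu) is importable and sees a device."""
    if importlib.util.find_spec("torch_npu") is None:
        return False
    import torch_npu  # noqa: F401  # side effect: registers the torch.npu namespace

    return torch.npu.is_available()
```

The side effect of importing `torch_npu` is what makes `torch.npu.device_count()`, `torch.npu.set_device()`, and the `"npu"` device strings used throughout these tests resolvable.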
diff --git a/test/test_distributed.py b/test/test_distributed.py
index 7910a9900..b0e30d5cf 100644
--- a/test/test_distributed.py
+++ b/test/test_distributed.py
@@ -11,6 +11,7 @@
 import pytest
 import torch
 from _pytest.fixtures import fixture
+from _utils_internal import is_npu_available
 from packaging import version
 from packaging.version import parse
 
@@ -107,6 +108,70 @@ def test_fsdp_module(self, tmpdir):
         assert (TensorDict.load_memmap(tmpdir) == 1).all()
 
 
+@pytest.mark.skipif(
+    not is_npu_available() or torch.npu.device_count() < 2,
+    reason="requires at least 2 npu devices",
+)
+class TestNPUFSDP:
+    class MyDModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.fc1 = nn.Linear(8, 8, bias=False)
+            self.fc2 = nn.Linear(8, 8, bias=False)
+            self.relu = nn.ReLU()
+            for p in self.parameters():
+                p.data.fill_(1.0)
+
+        def forward(self, input):
+            return self.relu(self.fc1(input) + self.fc2(input))
+
+    @classmethod
+    def make_module(cls, device=None):
+        with (
+            torch.device(f"npu:{device}") if device is not None else torch.device("npu")
+        ):
+            my_module = cls.MyDModule()
+            my_sharded_module = FSDP(my_module, device_id=device)
+            return my_sharded_module
+
+    @classmethod
+    def worker(cls, rank, path):
+        os.environ["MASTER_ADDR"] = "localhost"
+        os.environ["MASTER_PORT"] = "10017"
+
+        torch.distributed.init_process_group(
+            "hccl",
+            rank=rank,
+            world_size=2,
+            init_method="tcp://localhost:10017",
+        )
+        torch.npu.set_device(rank)
+        module = cls.make_module(rank)
+        dist.barrier()
+        # cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
+        # with FSDP.state_dict_type(module, StateDictType.SHARDED_STATE_DICT): #, cfg):
+        #     tdlogger.info(module.state_dict())
+
+        # td = TensorDict(module.state_dict(), []).unflatten_keys(".")
+        td = TensorDict.from_module(module, use_state_dict=True)
+        if rank == 0:
+            td.memmap(path)
+        dist.destroy_process_group()
+
+    def test_fsdp_module(self, tmpdir):
+        try:
+            mp.set_start_method("spawn")
+        except Exception:
+            tdlogger.info("start method already set to %s", mp.get_start_method())
+        proc0 = mp.Process(target=self.worker, args=(0, tmpdir))
+        proc1 = mp.Process(target=self.worker, args=(1, tmpdir))
+        proc0.start()
+        proc1.start()
+        proc0.join(timeout=TIMEOUT)
+        proc1.join(timeout=TIMEOUT)
+        assert (TensorDict.load_memmap(tmpdir) == 1).all()
+
+
 # not using TorchVersion to make the comparison work with dev
 TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
 
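The new `TestNPUFSDP` mirrors the CUDA `TestFSDP` contract just above it: each rank builds the same all-ones module, rank 0 memory-maps the weights, and the parent process verifies them on disk. For reviewers without Ascend hardware, here is a minimal, device-agnostic sketch of that save/verify round trip (plain CPU, no FSDP or process group; illustration only):

```python
# CPU-only sketch of the memmap round trip the test asserts on.
import tempfile

import torch.nn as nn
from tensordict import TensorDict

module = nn.Linear(8, 8, bias=False)
module.weight.data.fill_(1.0)

with tempfile.TemporaryDirectory() as path:
    td = TensorDict.from_module(module, use_state_dict=True)  # weights as a TensorDict
    td.memmap(path)  # what rank 0 does in the test
    assert (TensorDict.load_memmap(path) == 1).all()  # the parent-process check
```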
diff --git a/test/test_nn.py b/test/test_nn.py
index 253514324..68276388b 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -17,6 +17,7 @@
 import pytest
 import torch
+from _utils_internal import is_npu_available
 
 from tensordict import (
     is_tensor_collection,
@@ -81,7 +82,6 @@
 except ImportError:
     from tensordict.utils import Buffer
 
-
 IS_FB = os.getenv("PYTORCH_TEST_FBCODE")
 
 # Capture all warnings
@@ -112,6 +112,17 @@
 )
 
 
+def get_device():
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda:0")
+    elif is_npu_available():
+        device = torch.device("npu:0")
+    elif torch.mps.is_available():
+        device = torch.device("mps:0")
+    return device
+
+
 class TestInteractionType:
     def test_base(self):
         with set_interaction_type("DETERMINISTIC"):
@@ -2149,6 +2160,9 @@ def test_module_buffer():
     if torch.cuda.device_count():
         module.cuda()
         assert module.td.device.type == "cuda"
+    elif is_npu_available():
+        module = module.to("npu:0")
+        assert module.td.device.type == "npu"
 
 
 @pytest.mark.parametrize(
@@ -2156,30 +2170,14 @@ def test_module_buffer():
     [
         None,
         torch.device("cpu"),
-        (
-            torch.device("cuda:0")
-            if torch.cuda.is_available()
-            else (
-                torch.device("mps:0")
-                if torch.mps.is_available()
-                else torch.device("cpu")
-            )
-        ),
+        get_device(),
     ],
 )
 @pytest.mark.parametrize(
     "new_device",
     [
         torch.device("cpu"),
-        (
-            torch.device("cuda:0")
-            if torch.cuda.is_available()
-            else (
-                torch.device("mps:0")
-                if torch.mps.is_available()
-                else torch.device("cpu")
-            )
-        ),
+        get_device(),
     ],
 )
 @pytest.mark.parametrize("tc", [True, False], ids=["tc", "td"])
diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py
index 4f0a9266a..e89a1b3ac 100644
--- a/test/test_tensorclass.py
+++ b/test/test_tensorclass.py
@@ -26,6 +26,7 @@
 import pytest
 import tensordict.utils
 import torch
+from _utils_internal import is_npu_available
 
 from tensordict import (
     assert_allclose_td,
@@ -45,6 +46,7 @@
 from tensordict._td import lazy_stack
 from tensordict.base import _GENERIC_NESTED_ERR
 from tensordict.tensorclass import from_dataclass
+
 from torch import Tensor
 
 _has_streaming = importlib.util.find_spec("streaming", None) is not None
@@ -2566,6 +2568,8 @@ def test_to(self):
         td = self.get_nested()
         if torch.cuda.is_available():
             device = torch.device("cuda:0")
+        elif is_npu_available():
+            device = torch.device("npu:0")
         else:
             device = torch.device("cpu:1")
         td_device = td.to(device)
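One pitfall all of these guards share: `torch.npu` does not exist on a stock PyTorch build; the namespace only appears after `torch_npu` is imported. The skip conditions above therefore rely on `not is_npu_available() or ...` short-circuiting before anything dereferences `torch.npu`. A defensive variant of the same idea (hypothetical helper, shown only to make that assumption explicit):

```python
# Hypothetical guard; not part of this PR. Illustrates why the skipif
# conditions must check is_npu_available() before touching torch.npu.
import torch


def safe_npu_device_count() -> int:
    npu = getattr(torch, "npu", None)  # absent unless torch_npu was imported
    if npu is None or not npu.is_available():
        return 0
    return npu.device_count()
```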