46 changes: 41 additions & 5 deletions test/test_compile.py
@@ -7,12 +7,15 @@
import importlib.util
import inspect
import platform
+import sys
from pathlib import Path
from typing import Any, Callable

import pytest

import torch

+from _utils_internal import is_npu_available
from packaging import version

from tensordict import (
@@ -50,7 +53,17 @@

_IS_OSX = platform.system() == "Darwin"

+cur_device = "cpu"  # default so `cur_device` is always defined on CPU-only hosts
+npu_device_count = 0
+if torch.cuda.is_available():
+    cur_device = "cuda"
+elif is_npu_available():
+    cur_device = "npu"
+    npu_device_count = torch.npu.device_count()

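`is_npu_available` comes from the test-suite helper module `_utils_internal`, which this diff does not show. A minimal sketch of what such a helper could look like, assuming the Ascend backend ships as the `torch_npu` extension (illustration only, not the repo's actual implementation):

```python
import importlib.util

import torch


def is_npu_available() -> bool:
    # torch.npu only exists once the Ascend extension (torch_npu) is imported.
    if importlib.util.find_spec("torch_npu") is None:
        return False
    try:
        import torch_npu  # noqa: F401  # registers the "npu" device type with torch
    except ImportError:
        return False
    return hasattr(torch, "npu") and torch.npu.is_available()
```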

+@pytest.mark.skipif(
+    sys.version_info > (3, 14), reason="torch.compile is not supported on Python 3.14+"
+)
def test_vmap_compile():
    # Since we monkey patch vmap we need to make sure compile is happy with it
    def func(x, y):
@@ -67,6 +80,9 @@ def func(x, y):
@pytest.mark.skipif(
    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
)
+@pytest.mark.skipif(
+    sys.version_info > (3, 14), reason="torch.compile is not supported on Python 3.14+"
+)
@pytest.mark.parametrize("mode", [None, "reduce-overhead"])
class TestTD:
    def test_tensor_output(self, mode):
@@ -266,7 +282,7 @@ def make_td_with_names(data):
    )
    @pytest.mark.parametrize("has_device", [True, False])
    def test_to(self, has_device, mode):
-        device = "cuda:0"
+        device = f"{cur_device}:0"

        def test_to_device(td):
            return td.to(device)
@@ -283,6 +299,10 @@ def test_to_device(td):
        assert td_device_c.batch_size == td.batch_size
        assert td_device_c.device == torch.device(device)

+    @pytest.mark.skipif(
+        is_npu_available(),
+        reason="torch.device in torch.compile is not supported on NPU currently.",
+    )
    def test_lock(self, mode):
        def locked_op(td):
            # Adding stuff uses cache, check that this doesn't break
@@ -357,6 +377,9 @@ class MyClass:
@pytest.mark.skipif(
    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
)
+@pytest.mark.skipif(
+    sys.version_info > (3, 14), reason="torch.compile is not supported on Python 3.14+"
+)
@pytest.mark.parametrize("mode", [None, "reduce-overhead"])
class TestTC:
    def test_tc_tensor_output(self, mode):
@@ -553,7 +576,7 @@ def clone(td: TensorDict):
    )
    @pytest.mark.parametrize("has_device", [True, False])
    def test_tc_to(self, has_device, mode):
-        device = "cuda:0"
+        device = f"{cur_device}:0"

        def test_to_device(tc):
            return tc.to(device)
@@ -570,6 +593,10 @@ def test_to_device(tc):
        assert tc_device_c.batch_size == data.batch_size
        assert tc_device_c.device == torch.device(device)

+    @pytest.mark.skipif(
+        is_npu_available(),
+        reason="torch.device in torch.compile is not supported on NPU currently.",
+    )
    def test_tc_lock(self, mode):
        def locked_op(tc):
            # Adding stuff uses cache, check that this doesn't break
@@ -621,6 +648,9 @@ def func_c_mytd():
@pytest.mark.skipif(
    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
)
+@pytest.mark.skipif(
+    sys.version_info > (3, 14), reason="torch.compile is not supported on Python 3.14+"
+)
@pytest.mark.parametrize("mode", [None, "reduce-overhead"])
class TestNN:
    def test_func(self, mode):
@@ -725,6 +755,9 @@ def test_prob_module_with_kwargs(self, mode):
@pytest.mark.skipif(
    TORCH_VERSION <= version.parse("2.4.0"), reason="requires torch>2.4"
)
+@pytest.mark.skipif(
+    sys.version_info > (3, 14), reason="torch.compile is not supported on Python 3.14+"
+)
@pytest.mark.parametrize("mode", [None, "reduce-overhead"])
class TestFunctional:
    def test_functional_error(self, mode):
@@ -1023,6 +1056,9 @@ def to_numpy(tensor):
    (TORCH_VERSION <= version.parse("2.7.0")) and _IS_OSX,
    reason="requires torch>=2.7 on OSX",
)
+@pytest.mark.skipif(
+    sys.version_info > (3, 14), reason="torch.compile is not supported on Python 3.14+"
+)
@pytest.mark.parametrize("compiled", [False, True])
class TestCudaGraphs:
    @pytest.fixture(scope="class", autouse=True)
@@ -1247,7 +1283,7 @@ class TestCompileNontensor:
    # Same issue with the decorator @tensorclass version
    @pytest.fixture(scope="class")
    def data(self):
-        return torch.zeros((4, 3), device="cuda")
+        return torch.zeros((4, 3), device=cur_device)

    class TensorClassWithNonTensorData(TensorClass["nocast"]):
        tensor: torch.Tensor
@@ -1265,13 +1301,13 @@ def fn_no_device(self, data):

    def fn_with_device(self, data):
        a = self.TensorClassWithNonTensorData(
-            tensor=data, non_tensor_data=1, batch_size=[4], device="cuda"
+            tensor=data, non_tensor_data=1, batch_size=[4], device=cur_device
        )
        return a.tensor

    def fn_with_device_without_batch_size(self, data):
        a = self.TensorClassWithNonTensorData(
-            tensor=data, non_tensor_data=1, device="cuda"
+            tensor=data, non_tensor_data=1, device=cur_device
        )
        return a.tensor

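A note on the `sys.version_info > (3, 14)` guards added throughout this file: the strict `>` still matches Python 3.14.0, because `sys.version_info` is a 5-tuple such as `(3, 14, 0, "final", 0)`, and tuple comparison treats a longer tuple with an equal prefix as greater:

```python
import sys

# Equal prefix (3, 14), then the longer tuple wins, so this holds on
# CPython 3.14.0 as well as every later 3.14.x -- the skip fires on all of them.
assert (3, 14, 0, "final", 0) > (3, 14)
print(sys.version_info >= (3, 14))  # equivalent spelling of the same guard
```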
65 changes: 65 additions & 0 deletions test/test_distributed.py
@@ -11,6 +11,7 @@
import pytest
import torch
from _pytest.fixtures import fixture
+from _utils_internal import is_npu_available
from packaging import version

from packaging.version import parse
@@ -107,6 +108,70 @@ def test_fsdp_module(self, tmpdir):
        assert (TensorDict.load_memmap(tmpdir) == 1).all()


+@pytest.mark.skipif(
+    not is_npu_available() or torch.npu.device_count() < 2,
+    reason="not enough npu devices",
+)
+class TestNPUFSDP:
+    class MyDModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.fc1 = nn.Linear(8, 8, bias=False)
+            self.fc2 = nn.Linear(8, 8, bias=False)
+            self.relu = nn.ReLU()
+            for p in self.parameters():
+                p.data.fill_(1.0)
+
+        def forward(self, input):
+            return self.relu(self.fc1(input) + self.fc2(input))
+
+    @classmethod
+    def make_module(cls, device=None):
+        with (
+            torch.device(f"npu:{device}") if device is not None else torch.device("npu")
+        ):
+            my_module = cls.MyDModule()
+            my_sharded_module = FSDP(my_module, device_id=device)
+            return my_sharded_module
+
+    @classmethod
+    def worker(cls, rank, path):
+        os.environ["MASTER_ADDR"] = "localhost"
+        os.environ["MASTER_PORT"] = "10017"
+
+        torch.distributed.init_process_group(
+            "hccl",
+            rank=rank,
+            world_size=2,
+            init_method="tcp://localhost:10017",
+        )
+        torch.npu.set_device(rank)
+        module = cls.make_module(rank)
+        dist.barrier()
+        # cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
+        # with FSDP.state_dict_type(module, StateDictType.SHARDED_STATE_DICT):  # , cfg):
+        #     tdlogger.info(module.state_dict())

+        # td = TensorDict(module.state_dict(), []).unflatten_keys(".")
+        td = TensorDict.from_module(module, use_state_dict=True)
+        if rank == 0:
+            td.memmap(path)
+        dist.destroy_process_group()
+
+    def test_fsdp_module(self, tmpdir):
+        try:
+            mp.set_start_method("spawn")
+        except Exception:
+            tdlogger.info("start method already set to %s", mp.get_start_method())
+        proc0 = mp.Process(target=self.worker, args=(0, tmpdir))
+        proc1 = mp.Process(target=self.worker, args=(1, tmpdir))
+        proc0.start()
+        proc1.start()
+        proc0.join(timeout=TIMEOUT)
+        proc1.join(timeout=TIMEOUT)
+        assert (TensorDict.load_memmap(tmpdir) == 1).all()


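For context on the new test class above: `"hccl"` (Huawei Collective Communication Library) is the process-group backend for Ascend NPUs, playing the role NCCL plays for CUDA, and `torch.npu` becomes available once `torch_npu` is imported. A stripped-down sketch of the same two-rank bootstrap outside pytest, assuming `torch_npu` is installed and at least two NPUs are visible:

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch_npu  # noqa: F401  # registers the "npu" device type and HCCL backend


def worker(rank: int, world_size: int = 2) -> None:
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "10017"
    dist.init_process_group("hccl", rank=rank, world_size=world_size)
    torch.npu.set_device(rank)  # one process per device, keyed by rank
    dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, nprocs=2)
```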
# not using TorchVersion to make the comparison work with dev
TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)

36 changes: 17 additions & 19 deletions test/test_nn.py
@@ -17,6 +17,7 @@

import pytest
import torch
+from _utils_internal import is_npu_available

from tensordict import (
    is_tensor_collection,
@@ -81,7 +82,6 @@
except ImportError:
    from tensordict.utils import Buffer

-
IS_FB = os.getenv("PYTORCH_TEST_FBCODE")

# Capture all warnings
@@ -112,6 +112,17 @@
)


+def get_device():
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda:0")
+    elif is_npu_available():
+        device = torch.device("npu:0")
+    elif torch.mps.is_available():
+        device = torch.device("mps:0")
+    return device


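`get_device` collapses the two copies of the nested conditional that previously sat inside the `parametrize` lists below, with priority cuda > npu > mps > cpu. Illustrative use (on an NPU host this assumes `torch_npu` has been imported):

```python
device = get_device()
x = torch.ones(2, device=device)  # lands on the best available backend
assert x.device.type in ("cuda", "npu", "mps", "cpu")
```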
class TestInteractionType:
    def test_base(self):
        with set_interaction_type("DETERMINISTIC"):
@@ -2149,37 +2160,24 @@ def test_module_buffer():
    if torch.cuda.device_count():
        module.cuda()
        assert module.td.device.type == "cuda"
+    elif is_npu_available():
+        module = module.to("npu:0")
+        assert module.td.device.type == "npu"


@pytest.mark.parametrize(
    "original_device",
    [
        None,
        torch.device("cpu"),
-        (
-            torch.device("cuda:0")
-            if torch.cuda.is_available()
-            else (
-                torch.device("mps:0")
-                if torch.mps.is_available()
-                else torch.device("cpu")
-            )
-        ),
+        get_device(),
    ],
)
@pytest.mark.parametrize(
    "new_device",
    [
        torch.device("cpu"),
-        (
-            torch.device("cuda:0")
-            if torch.cuda.is_available()
-            else (
-                torch.device("mps:0")
-                if torch.mps.is_available()
-                else torch.device("cpu")
-            )
-        ),
+        get_device(),
    ],
)
@pytest.mark.parametrize("tc", [True, False], ids=["tc", "td"])
4 changes: 4 additions & 0 deletions test/test_tensorclass.py
@@ -26,6 +26,7 @@
import pytest
import tensordict.utils
import torch
+from _utils_internal import is_npu_available

from tensordict import (
    assert_allclose_td,
@@ -45,6 +46,7 @@
from tensordict._td import lazy_stack
from tensordict.base import _GENERIC_NESTED_ERR
from tensordict.tensorclass import from_dataclass

from torch import Tensor

_has_streaming = importlib.util.find_spec("streaming", None) is not None
@@ -2566,6 +2568,8 @@ def test_to(self):
        td = self.get_nested()
        if torch.cuda.is_available():
            device = torch.device("cuda:0")
+        elif is_npu_available():
+            device = torch.device("npu:0")
        else:
            device = torch.device("cpu:1")
        td_device = td.to(device)