44 changes: 39 additions & 5 deletions test/test_compile.py
@@ -7,6 +7,7 @@
import importlib.util
import inspect
import platform
import sys
from pathlib import Path
from typing import Any, Callable

@@ -38,6 +39,8 @@

from tensordict.tensorclass import TensorClass

from _utils_internal import is_npu_available

from torch.utils._pytree import SUPPORTED_NODES, tree_map

TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
@@ -50,7 +53,17 @@

_IS_OSX = platform.system() == "Darwin"

npu_device_count = 0
cur_device = "cuda"  # default so cur_device is defined even without an accelerator
if torch.cuda.is_available():
    cur_device = "cuda"
elif is_npu_available():
    cur_device = "npu"
    npu_device_count = torch.npu.device_count()


@pytest.mark.skipif(
    sys.version_info > (3, 14), reason="torch.compile is not supported on python 3.14+"
)
def test_vmap_compile():
    # Since we monkey patch vmap we need to make sure compile is happy with it
    def func(x, y):
@@ -67,6 +80,9 @@ def func(x, y):
@pytest.mark.skipif(
    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
)
@pytest.mark.skipif(
    sys.version_info > (3, 14), reason="torch.compile is not supported on python 3.14+"
)
@pytest.mark.parametrize("mode", [None, "reduce-overhead"])
class TestTD:
    def test_tensor_output(self, mode):
@@ -266,7 +282,7 @@ def make_td_with_names(data):
    )
    @pytest.mark.parametrize("has_device", [True, False])
    def test_to(self, has_device, mode):
        device = "cuda:0"
        device = f"{cur_device}:0"

        def test_to_device(td):
            return td.to(device)
@@ -283,6 +299,9 @@ def test_to_device(td):
        assert td_device_c.batch_size == td.batch_size
        assert td_device_c.device == torch.device(device)

    @pytest.mark.skipif(
        is_npu_available(),
        reason="torch.device in torch.compile is not supported on NPU currently.",
    )
    def test_lock(self, mode):
        def locked_op(td):
            # Adding stuff uses cache, check that this doesn't break
@@ -357,6 +376,9 @@ class MyClass:
@pytest.mark.skipif(
    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
)
@pytest.mark.skipif(
    sys.version_info > (3, 14), reason="torch.compile is not supported on python 3.14+"
)
@pytest.mark.parametrize("mode", [None, "reduce-overhead"])
class TestTC:
    def test_tc_tensor_output(self, mode):
@@ -553,7 +575,7 @@ def clone(td: TensorDict):
    )
    @pytest.mark.parametrize("has_device", [True, False])
    def test_tc_to(self, has_device, mode):
        device = "cuda:0"
        device = f"{cur_device}:0"

        def test_to_device(tc):
            return tc.to(device)
@@ -570,6 +592,9 @@ def test_to_device(tc):
        assert tc_device_c.batch_size == data.batch_size
        assert tc_device_c.device == torch.device(device)

    @pytest.mark.skipif(
        is_npu_available(),
        reason="torch.device in torch.compile is not supported on NPU currently.",
    )
    def test_tc_lock(self, mode):
        def locked_op(tc):
            # Adding stuff uses cache, check that this doesn't break
@@ -621,6 +646,9 @@ def func_c_mytd():
@pytest.mark.skipif(
    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
)
@pytest.mark.skipif(
    sys.version_info > (3, 14), reason="torch.compile is not supported on python 3.14+"
)
@pytest.mark.parametrize("mode", [None, "reduce-overhead"])
class TestNN:
    def test_func(self, mode):
@@ -725,6 +753,9 @@ def test_prob_module_with_kwargs(self, mode):
@pytest.mark.skipif(
    TORCH_VERSION <= version.parse("2.4.0"), reason="requires torch>2.4"
)
@pytest.mark.skipif(
    sys.version_info > (3, 14), reason="torch.compile is not supported on python 3.14+"
)
@pytest.mark.parametrize("mode", [None, "reduce-overhead"])
class TestFunctional:
    def test_functional_error(self, mode):
@@ -1023,6 +1054,9 @@ def to_numpy(tensor):
    (TORCH_VERSION <= version.parse("2.7.0")) and _IS_OSX,
    reason="requires torch>=2.7 on OSX",
)
@pytest.mark.skipif(
    sys.version_info > (3, 14), reason="torch.compile is not supported on python 3.14+"
)
@pytest.mark.parametrize("compiled", [False, True])
class TestCudaGraphs:
    @pytest.fixture(scope="class", autouse=True)
@@ -1247,7 +1281,7 @@ class TestCompileNontensor:
    # Same issue with the decorator @tensorclass version
    @pytest.fixture(scope="class")
    def data(self):
        return torch.zeros((4, 3), device="cuda")
        return torch.zeros((4, 3), device=cur_device)

    class TensorClassWithNonTensorData(TensorClass["nocast"]):
        tensor: torch.Tensor
@@ -1265,13 +1299,13 @@ def fn_no_device(self, data):

    def fn_with_device(self, data):
        a = self.TensorClassWithNonTensorData(
            tensor=data, non_tensor_data=1, batch_size=[4], device="cuda"
            tensor=data, non_tensor_data=1, batch_size=[4], device=cur_device
        )
        return a.tensor

    def fn_with_device_without_batch_size(self, data):
        a = self.TensorClassWithNonTensorData(
            tensor=data, non_tensor_data=1, device="cuda"
            tensor=data, non_tensor_data=1, device=cur_device
        )
        return a.tensor

78 changes: 72 additions & 6 deletions test/test_distributed.py
@@ -17,6 +17,7 @@

from tensordict import LazyStackedTensorDict, MemoryMappedTensor, TensorDict
from tensordict.utils import logger as tdlogger
from _utils_internal import is_npu_available
from torch import distributed as dist, multiprocessing as mp, nn
from torch.distributed._tensor import (
    DeviceMesh,
@@ -107,6 +108,71 @@ def test_fsdp_module(self, tmpdir):
        assert (TensorDict.load_memmap(tmpdir) == 1).all()


@pytest.mark.skipif(
    not is_npu_available() or torch.npu.device_count() <= 2,
    reason="not enough npu devices",
)
class TestNPUFSDP:
    class MyDModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(8, 8, bias=False)
            self.fc2 = nn.Linear(8, 8, bias=False)
            self.relu = nn.ReLU()
            for p in self.parameters():
                p.data.fill_(1.0)

        def forward(self, input):
            return self.relu(self.fc1(input) + self.fc2(input))

    @classmethod
    def make_module(cls, device=None):
        with (
            torch.device(f"npu:{device}")
            if device is not None
            else torch.device("npu")
        ):
            my_module = cls.MyDModule()
            my_sharded_module = FSDP(my_module, device_id=device)
            return my_sharded_module

    @classmethod
    def worker(cls, rank, path):
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "10017"

        torch.distributed.init_process_group(
            "hccl",
            rank=rank,
            world_size=2,
            init_method="tcp://localhost:10017",
        )
        torch.npu.set_device(rank)
        module = cls.make_module(rank)
        dist.barrier()
        # cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        # with FSDP.state_dict_type(module, StateDictType.SHARDED_STATE_DICT): #, cfg):
        # tdlogger.info(module.state_dict())

        # td = TensorDict(module.state_dict(), []).unflatten_keys(".")
        td = TensorDict.from_module(module, use_state_dict=True)
        if rank == 0:
            td.memmap(path)
        dist.destroy_process_group()

    def test_fsdp_module(self, tmpdir):
        try:
            mp.set_start_method("spawn")
        except Exception:
            tdlogger.info("start method already set to", mp.get_start_method())
        proc0 = mp.Process(target=self.worker, args=(0, tmpdir))
        proc1 = mp.Process(target=self.worker, args=(1, tmpdir))
        proc0.start()
        proc1.start()
        proc0.join(timeout=TIMEOUT)
        proc1.join(timeout=TIMEOUT)
        assert (TensorDict.load_memmap(tmpdir) == 1).all()


# not using TorchVersion to make the comparison work with dev
TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)

@@ -241,8 +307,8 @@ def server(queue):
},
[2],
)
.expand(1, 2)
.contiguous()
.expand(1, 2)
.contiguous()
)
td.gather_and_stack(0)
assert (td != 0).all()
@@ -314,8 +380,8 @@ def server(queue, op, async_op, return_premature):
},
[2],
)
.expand(1, 2)
.contiguous()
.expand(1, 2)
.contiguous()
)
out = td.reduce(0, op=op, async_op=async_op, return_premature=return_premature)
if not async_op:
@@ -798,8 +864,8 @@ def server(cls, queue):
},
[2],
)
.expand(1, 2)
.contiguous()
.expand(1, 2)
.contiguous()
)
td.init_remote(dst=1)

37 changes: 17 additions & 20 deletions test/test_nn.py
@@ -57,6 +57,7 @@
    skip_existing,
)
from tensordict.tensorclass import TensorClass
from _utils_internal import is_npu_available

from torch import distributions, nn
from torch.distributions import Categorical, Normal
@@ -81,7 +82,6 @@
except ImportError:
    from tensordict.utils import Buffer


IS_FB = os.getenv("PYTORCH_TEST_FBCODE")

# Capture all warnings
@@ -112,6 +112,17 @@
)


def get_device():
    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    elif is_npu_available():
        device = torch.device("npu:0")
    elif torch.mps.is_available():
        device = torch.device("mps:0")
    return device


class TestInteractionType:
    def test_base(self):
        with set_interaction_type("DETERMINISTIC"):
@@ -2149,37 +2160,24 @@ def test_module_buffer():
    if torch.cuda.device_count():
        module.cuda()
        assert module.td.device.type == "cuda"
    elif is_npu_available():
        module = module.to("npu:0")
        assert module.td.device.type == "npu"


@pytest.mark.parametrize(
    "original_device",
    [
        None,
        torch.device("cpu"),
        (
            torch.device("cuda:0")
            if torch.cuda.is_available()
            else (
                torch.device("mps:0")
                if torch.mps.is_available()
                else torch.device("cpu")
            )
        ),
        get_device(),
    ],
)
@pytest.mark.parametrize(
    "new_device",
    [
        torch.device("cpu"),
        (
            torch.device("cuda:0")
            if torch.cuda.is_available()
            else (
                torch.device("mps:0")
                if torch.mps.is_available()
                else torch.device("cpu")
            )
        ),
        get_device(),
    ],
)
@pytest.mark.parametrize("tc", [True, False], ids=["tc", "td"])
@@ -2188,7 +2186,6 @@ def test_module_buffer():
)
def test_to_context(original_device, new_device, tc):
    if tc:

        class MyTC(TensorClass):
            x: torch.Tensor
            y: torch.Tensor | None = None
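
The tests above import an is_npu_available() helper from _utils_internal, which is not part of this diff. A minimal sketch of what such a helper could look like, assuming the Ascend torch_npu extension registers the torch.npu backend; the names and structure here are illustrative, not the repository's actual implementation:

# Hypothetical sketch only; the real _utils_internal helper is not shown in this diff.
import importlib.util

import torch


def is_npu_available() -> bool:
    # torch.npu exists only once the Ascend extension (torch_npu) is installed
    # and has registered its backend with PyTorch.
    if importlib.util.find_spec("torch_npu") is None:
        return False
    try:
        import torch_npu  # noqa: F401
    except ImportError:
        return False
    return hasattr(torch, "npu") and torch.npu.is_available()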