Skip to content

Commit 76ef99f

Browse files
author
chenhao388
committed
Add NPU support alongside CUDA for the previously supported use cases.
1 parent b59ca7c commit 76ef99f

File tree

4 files changed

+109
-32
lines changed

4 files changed

+109
-32
lines changed

test/test_compile.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838

3939
from tensordict.tensorclass import TensorClass
4040

41+
from _utils_internal import is_npu_available
42+
4143
from torch.utils._pytree import SUPPORTED_NODES, tree_map
4244

4345
TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
@@ -50,6 +52,13 @@
5052

5153
_IS_OSX = platform.system() == "Darwin"
5254

55+
# Accelerator backend used by the device-casting tests in this module.
# CUDA takes precedence over NPU when both are available.
# Default to "cpu" so the name is always bound, even on a host without any
# accelerator (the visible tests that read it are skipped in that case).
cur_device = "cpu"
npu_device_count = 0
if torch.cuda.is_available():
    cur_device = "cuda"
elif is_npu_available():
    cur_device = "npu"
    npu_device_count = torch.npu.device_count()
61+
5362

5463
def test_vmap_compile():
5564
# Since we monkey patch vmap we need to make sure compile is happy with it
@@ -262,11 +271,11 @@ def make_td_with_names(data):
262271
make_td_with_names_c(data_dict)
263272

264273
@pytest.mark.skipif(
265-
not torch.cuda.is_available(), reason="cuda required to test device casting"
274+
not torch.cuda.is_available() and not is_npu_available(), reason="cuda or npu required to test device casting"
266275
)
267276
@pytest.mark.parametrize("has_device", [True, False])
268277
def test_to(self, has_device, mode):
269-
device = "cuda:0"
278+
device = f"{cur_device}:0"
270279

271280
def test_to_device(td):
272281
return td.to(device)
@@ -549,11 +558,11 @@ def clone(td: TensorDict):
549558
assert clone_c(data).a.b is data.a.b
550559

551560
@pytest.mark.skipif(
552-
not torch.cuda.is_available(), reason="cuda required to test device casting"
561+
not torch.cuda.is_available() and not is_npu_available(), reason="cuda or npu required to test device casting"
553562
)
554563
@pytest.mark.parametrize("has_device", [True, False])
555564
def test_tc_to(self, has_device, mode):
556-
device = "cuda:0"
565+
device = f"{cur_device}:0"
557566

558567
def test_to_device(tc):
559568
return tc.to(device)
@@ -1242,12 +1251,12 @@ def test_state_dict(self, compiled):
12421251
torch.testing.assert_close(y1, y2)
12431252

12441253

1245-
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available")
1254+
@pytest.mark.skipif(not torch.cuda.is_available() and not is_npu_available(), reason="cuda or npu is not available")
12461255
class TestCompileNontensor:
12471256
# Same issue with the decorator @tensorclass version
12481257
@pytest.fixture(scope="class")
12491258
def data(self):
1250-
return torch.zeros((4, 3), device="cuda")
1259+
return torch.zeros((4, 3), device=cur_device)
12511260

12521261
class TensorClassWithNonTensorData(TensorClass["nocast"]):
12531262
tensor: torch.Tensor
@@ -1265,13 +1274,13 @@ def fn_no_device(self, data):
12651274

12661275
def fn_with_device(self, data):
12671276
a = self.TensorClassWithNonTensorData(
1268-
tensor=data, non_tensor_data=1, batch_size=[4], device="cuda"
1277+
tensor=data, non_tensor_data=1, batch_size=[4], device=cur_device
12691278
)
12701279
return a.tensor
12711280

12721281
def fn_with_device_without_batch_size(self, data):
12731282
a = self.TensorClassWithNonTensorData(
1274-
tensor=data, non_tensor_data=1, device="cuda"
1283+
tensor=data, non_tensor_data=1, device=cur_device
12751284
)
12761285
return a.tensor
12771286

test/test_distributed.py

Lines changed: 72 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from tensordict import LazyStackedTensorDict, MemoryMappedTensor, TensorDict
1919
from tensordict.utils import logger as tdlogger
20+
from _utils_internal import is_npu_available
2021
from torch import distributed as dist, multiprocessing as mp, nn
2122
from torch.distributed._tensor import (
2223
DeviceMesh,
@@ -107,6 +108,71 @@ def test_fsdp_module(self, tmpdir):
107108
assert (TensorDict.load_memmap(tmpdir) == 1).all()
108109

109110

111+
@pytest.mark.skipif(
    # NOTE(review): the workers below only use ranks 0 and 1, so ``>= 2`` may
    # be sufficient here -- confirm before relaxing the requirement.
    not is_npu_available() or not torch.npu.device_count() > 2,
    reason="not enough npu devices",
)
class TestNPUFSDP:
    """FSDP state-dict round trip on NPU devices using the ``hccl`` backend.

    Two worker processes each wrap a small module in FSDP and convert it to a
    ``TensorDict``; rank 0 memory-maps the result to disk so that the parent
    process can load it and check the stored values.
    """

    class MyDModule(nn.Module):
        """Two bias-free linear layers whose parameters are all filled with 1."""

        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(8, 8, bias=False)
            self.fc2 = nn.Linear(8, 8, bias=False)
            self.relu = nn.ReLU()
            # Constant weights make the values checked by the test known in
            # advance.
            for p in self.parameters():
                p.data.fill_(1.0)

        def forward(self, input):
            return self.relu(self.fc1(input) + self.fc2(input))

    @classmethod
    def make_module(cls, device=None):
        """Build ``MyDModule`` under an NPU device context and wrap it in FSDP.

        Args:
            device: NPU index to build the module on and to pass to FSDP as
                ``device_id``. When ``None``, the generic ``"npu"`` device is
                used for construction.

        Returns:
            The FSDP-wrapped module.
        """
        target = torch.device(f"npu:{device}" if device is not None else "npu")
        with target:
            my_module = cls.MyDModule()
        my_sharded_module = FSDP(my_module, device_id=device)
        return my_sharded_module

    @classmethod
    def worker(cls, rank, path):
        """Entry point of one worker process.

        Args:
            rank: rank of this process in the two-process group; also used as
                the NPU index.
            path: directory where rank 0 memory-maps the resulting TensorDict.
        """
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "10017"

        torch.distributed.init_process_group(
            "hccl",
            rank=rank,
            world_size=2,
            init_method="tcp://localhost:10017",
        )
        torch.npu.set_device(rank)
        module = cls.make_module(rank)
        dist.barrier()
        td = TensorDict.from_module(module, use_state_dict=True)
        # Only one rank writes, so the two processes do not race on ``path``.
        if rank == 0:
            td.memmap(path)
        dist.destroy_process_group()

    def test_fsdp_module(self, tmpdir):
        """Run both workers and check that every saved value equals 1."""
        try:
            mp.set_start_method("spawn")
        except Exception:
            # Best effort: the start method can only be set once per process.
            # Use a %s placeholder so the logger can format the extra argument.
            tdlogger.info("start method already set to %s", mp.get_start_method())
        proc0 = mp.Process(target=self.worker, args=(0, tmpdir))
        proc1 = mp.Process(target=self.worker, args=(1, tmpdir))
        proc0.start()
        proc1.start()
        proc0.join(timeout=TIMEOUT)
        proc1.join(timeout=TIMEOUT)
        assert (TensorDict.load_memmap(tmpdir) == 1).all()
174+
175+
110176
# not using TorchVersion to make the comparison work with dev
111177
TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
112178

@@ -241,8 +307,8 @@ def server(queue):
241307
},
242308
[2],
243309
)
244-
.expand(1, 2)
245-
.contiguous()
310+
.expand(1, 2)
311+
.contiguous()
246312
)
247313
td.gather_and_stack(0)
248314
assert (td != 0).all()
@@ -314,8 +380,8 @@ def server(queue, op, async_op, return_premature):
314380
},
315381
[2],
316382
)
317-
.expand(1, 2)
318-
.contiguous()
383+
.expand(1, 2)
384+
.contiguous()
319385
)
320386
out = td.reduce(0, op=op, async_op=async_op, return_premature=return_premature)
321387
if not async_op:
@@ -798,8 +864,8 @@ def server(cls, queue):
798864
},
799865
[2],
800866
)
801-
.expand(1, 2)
802-
.contiguous()
867+
.expand(1, 2)
868+
.contiguous()
803869
)
804870
td.init_remote(dst=1)
805871

test/test_nn.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@
5757
skip_existing,
5858
)
5959
from tensordict.tensorclass import TensorClass
60+
from _utils_internal import is_npu_available
61+
6062

6163
from torch import distributions, nn
6264
from torch.distributions import Categorical, Normal
@@ -111,6 +113,15 @@
111113
pytest.mark.filterwarnings("ignore:inplace"),
112114
)
113115

116+
def get_device():
    """Return the preferred device for device-dependent tests.

    Backends are probed in the order CUDA, NPU, MPS, and the first available
    one is returned with index 0. The CPU device is returned when none of
    them is available.
    """
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    if is_npu_available():
        return torch.device("npu:0")
    if torch.mps.is_available():
        return torch.device("mps:0")
    return torch.device("cpu")
114125

115126
class TestInteractionType:
116127
def test_base(self):
@@ -2149,37 +2160,24 @@ def test_module_buffer():
21492160
if torch.cuda.device_count():
21502161
module.cuda()
21512162
assert module.td.device.type == "cuda"
2163+
elif is_npu_available():
2164+
module.npu()
2165+
assert module.td.device.type == "npu"
21522166

21532167

21542168
@pytest.mark.parametrize(
21552169
"original_device",
21562170
[
21572171
None,
21582172
torch.device("cpu"),
2159-
(
2160-
torch.device("cuda:0")
2161-
if torch.cuda.is_available()
2162-
else (
2163-
torch.device("mps:0")
2164-
if torch.mps.is_available()
2165-
else torch.device("cpu")
2166-
)
2167-
),
2173+
get_device(),
21682174
],
21692175
)
21702176
@pytest.mark.parametrize(
21712177
"new_device",
21722178
[
21732179
torch.device("cpu"),
2174-
(
2175-
torch.device("cuda:0")
2176-
if torch.cuda.is_available()
2177-
else (
2178-
torch.device("mps:0")
2179-
if torch.mps.is_available()
2180-
else torch.device("cpu")
2181-
)
2182-
),
2180+
get_device(),
21832181
],
21842182
)
21852183
@pytest.mark.parametrize("tc", [True, False], ids=["tc", "td"])

test/test_tensorclass.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@
4545
from tensordict._td import lazy_stack
4646
from tensordict.base import _GENERIC_NESTED_ERR
4747
from tensordict.tensorclass import from_dataclass
48+
from _utils_internal import is_npu_available
49+
4850
from torch import Tensor
4951

5052
_has_streaming = importlib.util.find_spec("streaming", None) is not None
@@ -2566,6 +2568,8 @@ def test_to(self):
25662568
td = self.get_nested()
25672569
if torch.cuda.is_available():
25682570
device = torch.device("cuda:0")
2571+
elif is_npu_available():
2572+
device = torch.device("npu:0")
25692573
else:
25702574
device = torch.device("cpu:1")
25712575
td_device = td.to(device)

0 commit comments

Comments
 (0)