
Commit 48a6dac

frozenleaves authored and committed
fix bug: fsdp2 cannot run on npu, because the device was hardcoded to cuda
1 parent 9cb1a6b commit 48a6dac

File tree

1 file changed (+5, -4 lines)


src/accelerate/utils/fsdp_utils.py

Lines changed: 5 additions & 4 deletions
@@ -465,8 +465,9 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict
     """
     import torch.distributed as dist
     from torch.distributed.tensor import distribute_tensor
-
-    # Model was previously copied to meta device
+    from accelerate.state import PartialState
+
+    # Model was previously copied to meta device
     meta_sharded_sd = model.state_dict()
     sharded_sd = {}

@@ -498,8 +499,8 @@ def _cast_and_contiguous(tensor, to_contiguous, dtype):

     if accelerator.is_main_process:
         for (param_name, full_param), sharded_param in zip(full_sd.items(), meta_sharded_sd.values()):
-            full_param = full_param.detach().cuda()
             mesh = sharded_param.device_mesh
+            full_param = full_param.detach().to(mesh.device_type)
             dist.broadcast(full_param, src=0, group=mesh.get_group())
             sharded_tensor = distribute_tensor(full_param, mesh, sharded_param.placements)
             to_contiguous, casting_dtype = _infer_parameter_dtype(
@@ -512,8 +513,8 @@ def _cast_and_contiguous(tensor, to_contiguous, dtype):
     # We need this else to have a matching `broadcast` for all of the ranks, else we deadlock
     else:
         for param_name, sharded_param in meta_sharded_sd.items():
-            full_tensor = torch.empty(sharded_param.size(), device="cuda", dtype=sharded_param.dtype)
             mesh = sharded_param.device_mesh
+            full_tensor = torch.empty(sharded_param.size(), device=mesh.device_type, dtype=sharded_param.dtype)
             dist.broadcast(full_tensor, src=0, group=mesh.get_group())
             sharded_tensor = distribute_tensor(full_tensor, mesh, sharded_param.placements)
             to_contiguous, casting_dtype = _infer_parameter_dtype(
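
The fix boils down to one device-agnostic pattern: derive the target device from the DTensor's DeviceMesh (`mesh.device_type` is "cuda" on NVIDIA boxes, "npu" on Ascend, "xpu" on Intel, and so on) instead of hardcoding `.cuda()` or `device="cuda"`. Below is a minimal sketch of that broadcast-and-shard loop in isolation, with a hypothetical helper name. It assumes torch.distributed is already initialized, that `model` has been sharded with FSDP2 so its parameters are DTensors, and that only rank 0 holds the full state dict; the dtype handling done by `_infer_parameter_dtype`/`_cast_and_contiguous` in the real function is omitted. (The diff also imports accelerate's `PartialState`, presumably for device bookkeeping elsewhere in the function; the sketch sticks to the mesh-based pattern visible above.)

# Minimal sketch (not the library's exact code): load a full state dict into an
# FSDP2-sharded model without assuming a CUDA backend.
import torch
import torch.distributed as dist
from torch.distributed.tensor import distribute_tensor


def load_full_state_dict_device_agnostic(model: torch.nn.Module, full_sd: dict):
    # Sharded params are DTensors; their DeviceMesh knows which backend they live on.
    meta_sharded_sd = model.state_dict()
    sharded_sd = {}
    if dist.get_rank() == 0:
        for (name, full_param), sharded_param in zip(full_sd.items(), meta_sharded_sd.values()):
            mesh = sharded_param.device_mesh
            # mesh.device_type is "cuda", "npu", "xpu", ... -- no hardcoding needed
            full_param = full_param.detach().to(mesh.device_type)
            dist.broadcast(full_param, src=0, group=mesh.get_group())
            sharded_sd[name] = distribute_tensor(full_param, mesh, sharded_param.placements)
    else:
        # Matching broadcast on every other rank, with the receive buffer allocated
        # on the mesh's device rather than on "cuda".
        for name, sharded_param in meta_sharded_sd.items():
            mesh = sharded_param.device_mesh
            full_tensor = torch.empty(sharded_param.size(), device=mesh.device_type, dtype=sharded_param.dtype)
            dist.broadcast(full_tensor, src=0, group=mesh.get_group())
            sharded_sd[name] = distribute_tensor(full_tensor, mesh, sharded_param.placements)
    model.load_state_dict(sharded_sd, assign=True)

Because every rank reaches the same `dist.broadcast` call per parameter, the loop stays deadlock-free, exactly as the comment in the second hunk notes; the only behavioral change in the commit is where the full tensors are materialized.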
