@@ -465,8 +465,9 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict
     """
     import torch.distributed as dist
     from torch.distributed.tensor import distribute_tensor
-
-    # Model was previously copied to meta device
+    from accelerate.state import PartialState
+
+    # Model was previously copied to meta device
     meta_sharded_sd = model.state_dict()
     sharded_sd = {}
 
@@ -498,8 +499,8 @@ def _cast_and_contiguous(tensor, to_contiguous, dtype):
 
     if accelerator.is_main_process:
         for (param_name, full_param), sharded_param in zip(full_sd.items(), meta_sharded_sd.values()):
-            full_param = full_param.detach().cuda()
             mesh = sharded_param.device_mesh
+            full_param = full_param.detach().to(mesh.device_type)
             dist.broadcast(full_param, src=0, group=mesh.get_group())
             sharded_tensor = distribute_tensor(full_param, mesh, sharded_param.placements)
             to_contiguous, casting_dtype = _infer_parameter_dtype(
@@ -512,8 +513,8 @@ def _cast_and_contiguous(tensor, to_contiguous, dtype):
     # We need this else to have a matching `broadcast` for all of the ranks, else we deadlock
     else:
         for param_name, sharded_param in meta_sharded_sd.items():
-            full_tensor = torch.empty(sharded_param.size(), device="cuda", dtype=sharded_param.dtype)
             mesh = sharded_param.device_mesh
+            full_tensor = torch.empty(sharded_param.size(), device=mesh.device_type, dtype=sharded_param.dtype)
             dist.broadcast(full_tensor, src=0, group=mesh.get_group())
             sharded_tensor = distribute_tensor(full_tensor, mesh, sharded_param.placements)
             to_contiguous, casting_dtype = _infer_parameter_dtype(
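For context, here is a minimal, self-contained sketch (not part of the PR) of the broadcast-and-reshard pattern the diff makes device-agnostic: rank 0 moves each full weight to the mesh's device type instead of hard-coding `.cuda()`, every other rank allocates an empty buffer of matching shape and dtype, and all ranks join the same `broadcast` before re-sharding with `distribute_tensor`. The helper name `broadcast_and_shard` and its standalone signature are illustrative only; the real function also handles dtype casting via `_infer_parameter_dtype` and `_cast_and_contiguous`, which are omitted here.

```python
import torch
import torch.distributed as dist
from torch.distributed.tensor import distribute_tensor


def broadcast_and_shard(full_sd, meta_sharded_sd, is_main_process):
    """Illustrative helper: broadcast full weights from rank 0 and re-shard them.

    `meta_sharded_sd` is assumed to be a state dict of DTensors (e.g. from a model
    materialized on the meta device under FSDP2), each carrying a `device_mesh`
    and `placements`. `full_sd` only needs real data on the main process.
    """
    sharded_sd = {}
    for param_name, sharded_param in meta_sharded_sd.items():
        mesh = sharded_param.device_mesh
        if is_main_process:
            # Move the full weight to the mesh's device type ("cuda", "xpu",
            # "cpu", ...) rather than hard-coding .cuda(), as the diff does.
            full_tensor = full_sd[param_name].detach().to(mesh.device_type)
        else:
            # Non-main ranks allocate an empty buffer of matching shape/dtype so
            # every rank participates in the same broadcast (avoids a deadlock).
            full_tensor = torch.empty(
                sharded_param.size(), device=mesh.device_type, dtype=sharded_param.dtype
            )
        dist.broadcast(full_tensor, src=0, group=mesh.get_group())
        # Turn the now-replicated full tensor into this rank's shard of a DTensor.
        sharded_sd[param_name] = distribute_tensor(full_tensor, mesh, sharded_param.placements)
    return sharded_sd
```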