megatron/core/dist_checkpointing/strategies (1 file changed: +6 -2 lines)

@@ -221,6 +221,7 @@ def sharded_tensor_to_torch_sharded_tensor(
     ]
 
     # Create a ShardedTensor without invoking communication. Determine global shards
+    world_size = torch.distributed.get_world_size()
     shard_metadata = []
     # NOTE: here we assume a regular grid of shards
     for fragment_offsets in itertools.product(*map(range, some_sh_ten.axis_fragmentations)):
@@ -244,13 +245,16 @@ def sharded_tensor_to_torch_sharded_tensor(
 
         else:
             # for shards from other ranks we provide simplistic data - this information will be discarded
-            # during TorchShardedTensor._init_from_local_shards_and_global_metadata call
+            # during TorchShardedTensor._init_from_local_shards_and_global_metadata call.
+            # Due to a bug in the PyT 24.05 container we must specify some concrete rank within the world size.
+            # The exact rank doesn't matter as long as it's different from the current rank - hence (rank + 1) % WS.
+            placement = f"rank:{(rank + 1) % world_size}/cuda"
             if has_flattened_range and not is_flattened_range_1d:
                 offset = offset + (0,)
                 size = (1,) * len(offsets_shape) + global_shape[-1:]
             else:
                 size = offsets_shape
-            shard_metadata.append(ShardMetadata(offset, size, "cuda"))
+            shard_metadata.append(ShardMetadata(offset, size, placement))
 
     tensor = some_sh_ten.data
     sharded_tensor_metadata = ShardedTensorMetadata(
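
For context, a minimal sketch of how placement strings of the form "rank:<r>/cuda" can describe every shard of a regularly sharded tensor without any communication, using the same convention as the change above. This is not part of the change: the build_shard_metadata helper, the round-robin owner assignment, and the example shapes are assumptions for illustration; only ShardMetadata and the placement string format come from torch.distributed._shard.

import itertools
from typing import List, Tuple

from torch.distributed._shard.metadata import ShardMetadata


def build_shard_metadata(
    rank: int,
    world_size: int,
    global_shape: Tuple[int, ...],
    axis_fragmentations: Tuple[int, ...],
) -> List[ShardMetadata]:
    """Describe every shard of a regularly sharded tensor without communication."""
    # Assume each axis divides evenly into its number of fragments (regular grid).
    shard_shape = tuple(dim // frags for dim, frags in zip(global_shape, axis_fragmentations))
    metadata = []
    for fragment_offsets in itertools.product(*map(range, axis_fragmentations)):
        offsets = [frag * size for frag, size in zip(fragment_offsets, shard_shape)]
        # Hypothetical owner assignment: shards handed out round-robin over ranks.
        owner = sum(fragment_offsets) % world_size
        if owner == rank:
            placement = f"rank:{rank}/cuda"
        else:
            # Remote shards only need some valid rank different from ours; this
            # metadata is discarded when the TorchShardedTensor is assembled via
            # _init_from_local_shards_and_global_metadata.
            placement = f"rank:{(rank + 1) % world_size}/cuda"
        metadata.append(ShardMetadata(offsets, list(shard_shape), placement))
    return metadata


if __name__ == "__main__":
    # A (4, 8) tensor split into 2 fragments along dim 0, seen from rank 0 of 2.
    for md in build_shard_metadata(rank=0, world_size=2, global_shape=(4, 8), axis_fragmentations=(2, 1)):
        print(md)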