
Commit 2297178

Merge branch 'dnarayanan/remove_buckets_from_checkpoints' into 'main'
Have checkpoints be agnostic to the way parameters are mapped to buckets

See merge request ADLR/megatron-lm!1309
2 parents ec16db0 + 369e698 commit 2297178
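
At a glance, the merge drops the per-bucket tensors and padding metadata from the distributed-optimizer checkpoint and instead stores one coalesced, unpadded tensor per optimizer-state key, so the checkpoint no longer depends on how parameters happened to be mapped to buckets. A simplified, hypothetical sketch of the two state_dict layouts (the key names follow the diff below; the sizes and the single float32 buffer are made-up illustration values):

import torch

# Old format: padding metadata plus one tensor per bucket for every key.
old_state = {
    "per_bucket_numel": {0: {torch.float32: [128, 64]}},
    "per_bucket_numel_unpadded": {0: {torch.float32: [120, 60]}},
    0: {torch.float32: {
        "param": [torch.empty(128), torch.empty(64)],
        "exp_avg": [torch.empty(128), torch.empty(64)],
        "exp_avg_sq": [torch.empty(128), torch.empty(64)],
    }},
}

# New format: one coalesced, unpadded tensor per key, plus its length.
new_state = {
    "buckets_coalesced": True,
    0: {torch.float32: {
        "param": torch.empty(180),
        "exp_avg": torch.empty(180),
        "exp_avg_sq": torch.empty(180),
        "numel_unpadded": 180,
    }},
}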

File tree

2 files changed, +90 -94 lines

megatron/core/distributed/param_and_grad_buffer.py
megatron/core/optimizer/distrib_optimizer.py

megatron/core/distributed/param_and_grad_buffer.py

Lines changed: 5 additions & 0 deletions
@@ -307,8 +307,13 @@ def _does_param_require_new_bucket(param):
         # Next, create underlying storage for buffer (with numel elements that includes
         # padding as necessary).
         self.numel = data_end_index
+        self.numel_unpadded = sum(per_bucket_numel_unpadded)
+        assert self.numel_unpadded <= self.numel
         if self.ddp_config.use_distributed_optimizer:
             assert self.numel % self.data_parallel_world_size == 0
+        else:
+            assert self.numel == self.numel_unpadded
+
         self.param_data = None
         # Only re-map param tensors if using distributed optimizer.
         if self.ddp_config.use_distributed_optimizer:
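
For intuition, the new bookkeeping can be pictured with a standalone sketch (the bucket sizes and data-parallel world size below are made-up illustration values; the rounding mirrors the per-bucket padding applied when the distributed optimizer is used):

import math

data_parallel_world_size = 8
per_bucket_numel_unpadded = [1000003, 999999]   # assumed "true" elements per bucket

# Each bucket is padded so it divides evenly across data-parallel ranks.
per_bucket_numel = [
    math.ceil(n / data_parallel_world_size) * data_parallel_world_size
    for n in per_bucket_numel_unpadded
]

numel = sum(per_bucket_numel)                    # analogous to self.numel (includes padding)
numel_unpadded = sum(per_bucket_numel_unpadded)  # analogous to the new self.numel_unpadded

assert numel_unpadded <= numel
assert numel % data_parallel_world_size == 0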

megatron/core/optimizer/distrib_optimizer.py

Lines changed: 85 additions & 94 deletions
@@ -715,22 +715,38 @@ def get_parameter_state_dp_zero(self):
 
         # Collect param states.
         state = {
-            "per_bucket_numel": self.per_bucket_numel,
-            "per_bucket_numel_unpadded": self.per_bucket_numel_unpadded,
+            "buckets_coalesced": True,
         }
         for gbuf_idx, gbuf_range_maps in enumerate(self.gbuf_ranges):
 
             # Iterate grad buffers (by data type).
             dtype_state = {}
             assert len(gbuf_range_maps) == 1, "single dtype supported, for now."
             for dtype, gbuf_range_map_for_all_buckets in gbuf_range_maps.items():
+                buffer_numel_unpadded = self.buffers[gbuf_idx].numel_unpadded
+                # Create coalesced tensors for all state related to parameters in this buffer.
                 world_tensors = {}
+                if data_parallel_rank == 0:
+                    world_tensors = {
+                        key: torch.empty(
+                            (buffer_numel_unpadded,), dtype=torch.float32, device="cpu"
+                        )
+                        for key in ("param", "exp_avg", "exp_avg_sq")
+                    }
+                    world_tensors["numel_unpadded"] = buffer_numel_unpadded
+                offset_in_world_tensors = 0
                 for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):
 
                     # Compute local DP contiguous shard's size.
                     gbuf_world_numel = self.buffers[gbuf_idx].buckets[bucket_idx].grad_data.numel()
                     assert gbuf_world_numel % data_parallel_world_size == 0
                     gbuf_local_numel = gbuf_world_numel // data_parallel_world_size
+
+                    gbuf_world_numel_unpadded = (
+                        self.buffers[gbuf_idx].buckets[bucket_idx].numel_unpadded
+                    )
+                    assert gbuf_world_numel_unpadded <= gbuf_world_numel
+
                     local_shards = {
                         key: torch.empty((gbuf_local_numel,), dtype=torch.float32, device="cpu")
                         for key in ("param", "exp_avg", "exp_avg_sq")
@@ -779,9 +795,17 @@ def get_parameter_state_dp_zero(self):
 
                         # Concatenate.
                         if data_parallel_rank == 0:
-                            if key not in world_tensors:
-                                world_tensors[key] = []
-                            world_tensors[key].append(torch.cat(recv_tensors))
+                            recv_tensors_concatenated = torch.cat(recv_tensors)
+                            # Copy this bucket's collected all-gather tensors into the right place in the
+                            # tensor for the buffer. The tensor for the buffer gets rid of the padding
+                            # between buckets.
+                            start = offset_in_world_tensors
+                            end = offset_in_world_tensors + gbuf_world_numel_unpadded
+                            world_tensors[key][start:end].copy_(
+                                recv_tensors_concatenated[:gbuf_world_numel_unpadded]
+                            )
+
+                    offset_in_world_tensors += gbuf_world_numel_unpadded
 
                 # Collect world state.
                 dtype_state[dtype] = world_tensors
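
The effect of the new gather path can be reproduced outside Megatron-LM with a minimal sketch (the per-bucket tensors and sizes are made up; in the real code each padded bucket tensor is assembled from shards gathered over the data-parallel Gloo group):

import torch

per_bucket_numel_unpadded = [10, 6]
padded_buckets = [torch.randn(12), torch.randn(8)]   # padded per-bucket "world" tensors

buffer_numel_unpadded = sum(per_bucket_numel_unpadded)
world_tensor = torch.empty(buffer_numel_unpadded)

offset_in_world_tensors = 0
for bucket, numel_unpadded in zip(padded_buckets, per_bucket_numel_unpadded):
    # Copy only the unpadded prefix of each bucket, dropping the padding
    # between buckets so the checkpoint tensor is bucket-agnostic.
    start = offset_in_world_tensors
    end = offset_in_world_tensors + numel_unpadded
    world_tensor[start:end].copy_(bucket[:numel_unpadded])
    offset_in_world_tensors += numel_unpadded

assert offset_in_world_tensors == buffer_numel_unpadded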
@@ -1001,7 +1025,8 @@ def load_parameter_state_from_fs_bucket_space(self, state_dict):
             dst_tensors[key].copy_(src_tensors[key])
 
     def load_parameter_state_from_dp_zero(self, state_dict):
-        """Load parameter state (i.e., parameter & optimizer tensors) from DP 0 rank.
+        """Load parameter state (i.e., parameter & optimizer tensors) from DP 0 rank,
+        using the new checkpoint format with coalesced state across buckets.
 
         This method performs the reverse of get_parameter_state_dp_zero():
         - Scatter contiguous buffers from DP rank 0 to each DP rank (each DP
@@ -1010,13 +1035,6 @@ def load_parameter_state_from_dp_zero(self, state_dict):
           buffers. (e.g., one buffer each for main_param, exp_avg, and
           exp_avg_sq).
         """
-        if state_dict is not None and "per_bucket_numel_unpadded" in state_dict:
-            per_bucket_numel_unpadded_in_checkpoint = state_dict["per_bucket_numel_unpadded"]
-            assert self.per_bucket_numel_unpadded == per_bucket_numel_unpadded_in_checkpoint, (
-                f"Number of unpadded elements in each bucket need to be the same in current run "
-                f"({self.per_bucket_numel_unpadded}) and checkpoint "
-                f"({per_bucket_numel_unpadded_in_checkpoint})"
-            )
 
         # Data parallelism variables.
         data_parallel_world_size = self.data_parallel_group_gloo.size()
@@ -1029,74 +1047,47 @@ def load_parameter_state_from_dp_zero(self, state_dict):
         # Scatter tensors to all DP ranks.
         for gbuf_idx, gbuf_range_maps in enumerate(self.gbuf_ranges):
             for dtype, gbuf_range_map_for_all_buckets in gbuf_range_maps.items():
-                for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):
-
-                    # Compute local DP contiguous shard's size.
-                    gbuf_world_numel = self.buffers[gbuf_idx].buckets[bucket_idx].grad_data.numel()
-                    assert gbuf_world_numel == self.per_bucket_numel[gbuf_idx][dtype][bucket_idx]
-                    assert gbuf_world_numel % data_parallel_world_size == 0
-                    gbuf_local_numel = gbuf_world_numel // data_parallel_world_size
-
-                    # Contiguous local shards (received from DP rank 0).
-                    local_shards = {
-                        key: torch.empty((gbuf_local_numel,), dtype=torch.float32, device="cpu")
-                        for key in ("param", "exp_avg", "exp_avg_sq")
-                    }
+                if data_parallel_rank == 0:
+                    buffer_numel_unpadded = self.buffers[gbuf_idx].numel_unpadded
+                    checkpoint_numel_unpadded = state_dict[gbuf_idx][dtype]["numel_unpadded"]
+                    assert buffer_numel_unpadded == checkpoint_numel_unpadded, (
+                        f"Number of unpadded elements must be same in current run "
+                        f"({buffer_numel_unpadded}) and checkpoint ({checkpoint_numel_unpadded})"
+                    )
+                for key in ("param", "exp_avg", "exp_avg_sq"):
+                    offset_in_world_tensors = 0
+                    for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):
+                        # Compute local DP contiguous shard's size.
+                        gbuf_world_numel = (
+                            self.buffers[gbuf_idx].buckets[bucket_idx].grad_data.numel()
+                        )
+                        assert gbuf_world_numel % data_parallel_world_size == 0
+                        gbuf_local_numel = gbuf_world_numel // data_parallel_world_size
+                        gbuf_world_numel_unpadded = (
+                            self.buffers[gbuf_idx].buckets[bucket_idx].numel_unpadded
+                        )
+                        assert gbuf_world_numel_unpadded <= gbuf_world_numel
 
-                    # Scatter local shards from DP rank 0.
-                    for key, recv_tensor in local_shards.items():
+                        # Contiguous local shards (received from DP rank 0).
+                        recv_tensor = torch.empty(
+                            (gbuf_local_numel,), dtype=torch.float32, device="cpu"
+                        )
 
                         # Scatter tensor list.
                         if data_parallel_rank == 0:
-                            world_tensor_for_all_buckets = state_dict[gbuf_idx][dtype][key]
-                            if not isinstance(world_tensor_for_all_buckets, list):
-                                world_tensor_for_all_buckets = [world_tensor_for_all_buckets]
-                            assert bucket_idx < len(world_tensor_for_all_buckets), (
-                                f"Trying to load state for bucket_id {bucket_idx} (out of "
-                                f"{len(gbuf_range_map_for_all_buckets)} buckets) from checkpoint; "
-                                f"checkpoint only has {len(world_tensor_for_all_buckets)} bucket(s)"
+                            world_tensors = state_dict[gbuf_idx][dtype][key]
+
+                            start = offset_in_world_tensors
+                            end = offset_in_world_tensors + gbuf_world_numel_unpadded
+                            assert 0 <= start < end <= world_tensors.numel()
+                            world_tensor = world_tensors[start:end]
+                            offset_in_world_tensors += gbuf_world_numel_unpadded
+
+                            # Pad world_tensor to gbuf_world_numel. Don't pad at the front, pad at the back.
+                            world_tensor = torch.nn.functional.pad(
+                                world_tensor, (0, gbuf_world_numel - gbuf_world_numel_unpadded)
                             )
-                            # This tensor might be bigger or smaller than expected (depending on
-                            # relative sizes of per_bucket_numel_in_checkpoint and self.per_bucket_numel).
-                            world_tensor = world_tensor_for_all_buckets[bucket_idx]
-                            if "per_bucket_numel" in state_dict:
-                                numel_in_checkpoint = state_dict["per_bucket_numel"][gbuf_idx][
-                                    dtype
-                                ][bucket_idx]
-                                numel = self.per_bucket_numel[gbuf_idx][dtype][bucket_idx]
-                                numel_unpadded = self.per_bucket_numel_unpadded[gbuf_idx][dtype][
-                                    bucket_idx
-                                ]
-                                assert world_tensor.numel() == numel_in_checkpoint
-                                assert numel_unpadded <= world_tensor.numel(), (
-                                    "True number of elements should be fewer than number of elements in "
-                                    "checkpoint tensor"
-                                )
-                                if world_tensor.numel() > numel:
-                                    # Truncate extra values, which are padding anyway.
-                                    logger.info(
-                                        f"Truncating extra values from checkpoint (numel_in_checkpoint={numel_in_checkpoint}, "
-                                        f"numel={numel}, numel_unpadded={numel_unpadded})"
-                                    )
-                                    world_tensor = world_tensor[:numel]
-                                elif world_tensor.numel() < numel:
-                                    # In this case, numel > world_tensor.numel() (which is numel_in_checkpoint).
-                                    # Create new tensor with right number of values, then copy and use new tensor.
-                                    logger.info(
-                                        f"Expanding tensor from checkpoint (numel_in_checkpoint={numel_in_checkpoint}, "
-                                        f"numel={numel}, numel_unpadded={numel_unpadded})"
-                                    )
-                                    world_tensor_reshaped = torch.empty(
-                                        (numel,),
-                                        dtype=world_tensor.dtype,
-                                        device=world_tensor.device,
-                                    )
-                                    world_tensor_reshaped[:numel_in_checkpoint].copy_(world_tensor)
-                                    world_tensor = world_tensor_reshaped
-                            else:
-                                logger.info(
-                                    "***WARNING*** Using older checkpoint so skipping padding checks"
-                                )
+                            assert world_tensor.numel() == gbuf_world_numel
                             gbuf_start_idxs = list(range(0, gbuf_world_numel, gbuf_local_numel))
                             send_tensors = [
                                 world_tensor[i : (i + gbuf_local_numel)] for i in gbuf_start_idxs
@@ -1112,25 +1103,25 @@ def load_parameter_state_from_dp_zero(self, state_dict):
                                 data_parallel_group_gloo,
                             )
 
-                    # Copy local contiguous shards to param/optim shards.
-                    for model_param, param_range_map in gbuf_range_map["param_map"].items():
-
-                        # Main param & optimizer states.
-                        group_index, group_order = self.model_param_group_index_map[model_param]
-                        main_param = self.optimizer.param_groups[group_index]["params"][group_order]
-                        optim_state = self.optimizer.state[main_param]
-
-                        tensors = {
-                            "param": main_param,
-                            **optim_state,
-                        }
+                        # Copy local contiguous shards to param/optim shards.
+                        for model_param, param_range_map in gbuf_range_map["param_map"].items():
 
-                        # Copy states into contiguous shard.
-                        gbuf_local_start = param_range_map["gbuf_local"].start
-                        gbuf_local_end = param_range_map["gbuf_local"].end
-                        for key in local_shards:
-                            tensors[key].data.copy_(
-                                local_shards[key][gbuf_local_start:gbuf_local_end]
+                            # Main param & optimizer states.
+                            group_index, group_order = self.model_param_group_index_map[model_param]
+                            main_param = self.optimizer.param_groups[group_index]["params"][
+                                group_order
+                            ]
+                            if key == "param":
+                                tensor_to_copy_into = main_param
+                            else:
+                                optim_state = self.optimizer.state[main_param]
+                                tensor_to_copy_into = optim_state[key]
+
+                            # Copy states into contiguous shard.
+                            gbuf_local_start = param_range_map["gbuf_local"].start
+                            gbuf_local_end = param_range_map["gbuf_local"].end
+                            tensor_to_copy_into.data.copy_(
+                                recv_tensor[gbuf_local_start:gbuf_local_end]
                             )
 
     def load_parameter_state(self, filename: str):
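
The load path above is the mirror image; a standalone sketch under the same illustrative sizes (slice each bucket's unpadded region from the coalesced checkpoint tensor, pad it back out, and split it into per-rank shards; the actual code then scatters these shards over the data-parallel Gloo group):

import torch

data_parallel_world_size = 4
per_bucket_numel_unpadded = [10, 6]
per_bucket_numel = [12, 8]                       # padded to a multiple of the DP size
world_tensors = torch.randn(sum(per_bucket_numel_unpadded))   # coalesced checkpoint tensor

offset_in_world_tensors = 0
for numel, numel_unpadded in zip(per_bucket_numel, per_bucket_numel_unpadded):
    world_tensor = world_tensors[offset_in_world_tensors : offset_in_world_tensors + numel_unpadded]
    offset_in_world_tensors += numel_unpadded

    # Re-add padding at the back only, then split evenly across DP ranks.
    world_tensor = torch.nn.functional.pad(world_tensor, (0, numel - numel_unpadded))
    gbuf_local_numel = numel // data_parallel_world_size
    send_tensors = [
        world_tensor[i : i + gbuf_local_numel] for i in range(0, numel, gbuf_local_numel)
    ]
    assert len(send_tensors) == data_parallel_world_size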
