@@ -715,22 +715,38 @@ def get_parameter_state_dp_zero(self):

         # Collect param states.
         state = {
-            "per_bucket_numel": self.per_bucket_numel,
-            "per_bucket_numel_unpadded": self.per_bucket_numel_unpadded,
+            "buckets_coalesced": True,
         }
         for gbuf_idx, gbuf_range_maps in enumerate(self.gbuf_ranges):

             # Iterate grad buffers (by data type).
             dtype_state = {}
             assert len(gbuf_range_maps) == 1, "single dtype supported, for now."
             for dtype, gbuf_range_map_for_all_buckets in gbuf_range_maps.items():
+                buffer_numel_unpadded = self.buffers[gbuf_idx].numel_unpadded
+                # Create coalesced tensors for all state related to parameters in this buffer.
                 world_tensors = {}
+                if data_parallel_rank == 0:
+                    world_tensors = {
+                        key: torch.empty(
+                            (buffer_numel_unpadded,), dtype=torch.float32, device="cpu"
+                        )
+                        for key in ("param", "exp_avg", "exp_avg_sq")
+                    }
+                    world_tensors["numel_unpadded"] = buffer_numel_unpadded
+                offset_in_world_tensors = 0
                 for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):

                     # Compute local DP contiguous shard's size.
                     gbuf_world_numel = self.buffers[gbuf_idx].buckets[bucket_idx].grad_data.numel()
                     assert gbuf_world_numel % data_parallel_world_size == 0
                     gbuf_local_numel = gbuf_world_numel // data_parallel_world_size
+
+                    gbuf_world_numel_unpadded = (
+                        self.buffers[gbuf_idx].buckets[bucket_idx].numel_unpadded
+                    )
+                    assert gbuf_world_numel_unpadded <= gbuf_world_numel
+
                     local_shards = {
                         key: torch.empty((gbuf_local_numel,), dtype=torch.float32, device="cpu")
                         for key in ("param", "exp_avg", "exp_avg_sq")
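For orientation, here is a rough sketch (illustrative toy sizes, not part of the commit) of the coalesced `state` that `get_parameter_state_dp_zero()` now builds on DP rank 0: one flat CPU tensor per key for the whole buffer plus its `numel_unpadded`, under the top-level `buckets_coalesced` marker, replacing the old per-bucket tensor lists and the `per_bucket_numel` / `per_bucket_numel_unpadded` maps.

```python
import torch

# Hypothetical sizes for illustration; the real values come from the grad buffers.
buffer_numel_unpadded = 10
gbuf_idx, dtype = 0, torch.float32

state = {
    "buckets_coalesced": True,  # marks the new, coalesced checkpoint format
    gbuf_idx: {
        dtype: {
            # One flat CPU tensor per key for the whole buffer (no per-bucket list).
            key: torch.empty((buffer_numel_unpadded,), dtype=torch.float32, device="cpu")
            for key in ("param", "exp_avg", "exp_avg_sq")
        }
    },
}
state[gbuf_idx][dtype]["numel_unpadded"] = buffer_numel_unpadded
```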
@@ -779,9 +795,17 @@ def get_parameter_state_dp_zero(self):

                         # Concatenate.
                         if data_parallel_rank == 0:
-                            if key not in world_tensors:
-                                world_tensors[key] = []
-                            world_tensors[key].append(torch.cat(recv_tensors))
+                            recv_tensors_concatenated = torch.cat(recv_tensors)
+                            # Copy this bucket's collected all-gather tensors into the right place in the
+                            # tensor for the buffer. The tensor for the buffer gets rid of the padding
+                            # between buckets.
+                            start = offset_in_world_tensors
+                            end = offset_in_world_tensors + gbuf_world_numel_unpadded
+                            world_tensors[key][start:end].copy_(
+                                recv_tensors_concatenated[:gbuf_world_numel_unpadded]
+                            )
+
+                    offset_in_world_tensors += gbuf_world_numel_unpadded

                 # Collect world state.
                 dtype_state[dtype] = world_tensors
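The copy above is the core of the new save path: each bucket's gathered shards are concatenated and written at a running offset into the buffer-level tensor, so the padding at the end of each bucket never reaches the checkpoint. A minimal self-contained sketch of that bookkeeping, with assumed toy sizes rather than the Megatron data structures:

```python
import torch

# Assumed toy sizes: padded bucket sizes (divisible by the DP world size) and
# the number of elements in each bucket that actually hold state.
bucket_numels = [6, 8]
bucket_numels_unpadded = [5, 7]
gathered_buckets = [torch.arange(n, dtype=torch.float32) for n in bucket_numels]

# Buffer-level tensor sized without any inter-bucket padding.
world_tensor = torch.empty(sum(bucket_numels_unpadded))
offset_in_world_tensors = 0
for gathered, numel_unpadded in zip(gathered_buckets, bucket_numels_unpadded):
    # Only the unpadded prefix of each gathered bucket carries real state.
    start, end = offset_in_world_tensors, offset_in_world_tensors + numel_unpadded
    world_tensor[start:end].copy_(gathered[:numel_unpadded])
    offset_in_world_tensors += numel_unpadded

assert world_tensor.numel() == sum(bucket_numels_unpadded)  # 12, no padding stored
```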
@@ -1001,7 +1025,8 @@ def load_parameter_state_from_fs_bucket_space(self, state_dict):
                             dst_tensors[key].copy_(src_tensors[key])

     def load_parameter_state_from_dp_zero(self, state_dict):
-        """Load parameter state (i.e., parameter & optimizer tensors) from DP 0 rank.
+        """Load parameter state (i.e., parameter & optimizer tensors) from DP 0 rank,
+        using the new checkpoint format with coalesced state across buckets.

         This method performs the reverse of get_parameter_state_dp_zero():
         - Scatter contiguous buffers from DP rank 0 to each DP rank (each DP
@@ -1010,13 +1035,6 @@ def load_parameter_state_from_dp_zero(self, state_dict):
           buffers. (e.g., one buffer each for main_param, exp_avg, and
           exp_avg_sq).
         """
-        if state_dict is not None and "per_bucket_numel_unpadded" in state_dict:
-            per_bucket_numel_unpadded_in_checkpoint = state_dict["per_bucket_numel_unpadded"]
-            assert self.per_bucket_numel_unpadded == per_bucket_numel_unpadded_in_checkpoint, (
-                f"Number of unpadded elements in each bucket need to be the same in current run "
-                f"({self.per_bucket_numel_unpadded}) and checkpoint "
-                f"({per_bucket_numel_unpadded_in_checkpoint})"
-            )

         # Data parallelism variables.
         data_parallel_world_size = self.data_parallel_group_gloo.size()
@@ -1029,74 +1047,47 @@ def load_parameter_state_from_dp_zero(self, state_dict):
         # Scatter tensors to all DP ranks.
         for gbuf_idx, gbuf_range_maps in enumerate(self.gbuf_ranges):
             for dtype, gbuf_range_map_for_all_buckets in gbuf_range_maps.items():
-                for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):
-
-                    # Compute local DP contiguous shard's size.
-                    gbuf_world_numel = self.buffers[gbuf_idx].buckets[bucket_idx].grad_data.numel()
-                    assert gbuf_world_numel == self.per_bucket_numel[gbuf_idx][dtype][bucket_idx]
-                    assert gbuf_world_numel % data_parallel_world_size == 0
-                    gbuf_local_numel = gbuf_world_numel // data_parallel_world_size
-
-                    # Contiguous local shards (received from DP rank 0).
-                    local_shards = {
-                        key: torch.empty((gbuf_local_numel,), dtype=torch.float32, device="cpu")
-                        for key in ("param", "exp_avg", "exp_avg_sq")
-                    }
+                if data_parallel_rank == 0:
+                    buffer_numel_unpadded = self.buffers[gbuf_idx].numel_unpadded
+                    checkpoint_numel_unpadded = state_dict[gbuf_idx][dtype]["numel_unpadded"]
+                    assert buffer_numel_unpadded == checkpoint_numel_unpadded, (
+                        f"Number of unpadded elements must be same in current run "
+                        f"({buffer_numel_unpadded}) and checkpoint ({checkpoint_numel_unpadded})"
+                    )
+                for key in ("param", "exp_avg", "exp_avg_sq"):
+                    offset_in_world_tensors = 0
+                    for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):
+                        # Compute local DP contiguous shard's size.
+                        gbuf_world_numel = (
+                            self.buffers[gbuf_idx].buckets[bucket_idx].grad_data.numel()
+                        )
+                        assert gbuf_world_numel % data_parallel_world_size == 0
+                        gbuf_local_numel = gbuf_world_numel // data_parallel_world_size
+                        gbuf_world_numel_unpadded = (
+                            self.buffers[gbuf_idx].buckets[bucket_idx].numel_unpadded
+                        )
+                        assert gbuf_world_numel_unpadded <= gbuf_world_numel

-                    # Scatter local shards from DP rank 0.
-                    for key, recv_tensor in local_shards.items():
+                        # Contiguous local shards (received from DP rank 0).
+                        recv_tensor = torch.empty(
+                            (gbuf_local_numel,), dtype=torch.float32, device="cpu"
+                        )

                         # Scatter tensor list.
                         if data_parallel_rank == 0:
-                            world_tensor_for_all_buckets = state_dict[gbuf_idx][dtype][key]
-                            if not isinstance(world_tensor_for_all_buckets, list):
-                                world_tensor_for_all_buckets = [world_tensor_for_all_buckets]
-                            assert bucket_idx < len(world_tensor_for_all_buckets), (
-                                f"Trying to load state for bucket_id {bucket_idx} (out of "
-                                f"{len(gbuf_range_map_for_all_buckets)} buckets) from checkpoint; "
-                                f"checkpoint only has {len(world_tensor_for_all_buckets)} bucket(s)"
+                            world_tensors = state_dict[gbuf_idx][dtype][key]
+
+                            start = offset_in_world_tensors
+                            end = offset_in_world_tensors + gbuf_world_numel_unpadded
+                            assert 0 <= start < end <= world_tensors.numel()
+                            world_tensor = world_tensors[start:end]
+                            offset_in_world_tensors += gbuf_world_numel_unpadded
+
+                            # Pad world_tensor to gbuf_world_numel. Don't pad at the front, pad at the back.
+                            world_tensor = torch.nn.functional.pad(
+                                world_tensor, (0, gbuf_world_numel - gbuf_world_numel_unpadded)
                             )
-                            # This tensor might be bigger or smaller than expected (depending on
-                            # relative sizes of per_bucket_numel_in_checkpoint and self.per_bucket_numel).
-                            world_tensor = world_tensor_for_all_buckets[bucket_idx]
-                            if "per_bucket_numel" in state_dict:
-                                numel_in_checkpoint = state_dict["per_bucket_numel"][gbuf_idx][
-                                    dtype
-                                ][bucket_idx]
-                                numel = self.per_bucket_numel[gbuf_idx][dtype][bucket_idx]
-                                numel_unpadded = self.per_bucket_numel_unpadded[gbuf_idx][dtype][
-                                    bucket_idx
-                                ]
-                                assert world_tensor.numel() == numel_in_checkpoint
-                                assert numel_unpadded <= world_tensor.numel(), (
-                                    "True number of elements should be fewer than number of elements in "
-                                    "checkpoint tensor"
-                                )
-                                if world_tensor.numel() > numel:
-                                    # Truncate extra values, which are padding anyway.
-                                    logger.info(
-                                        f"Truncating extra values from checkpoint (numel_in_checkpoint={numel_in_checkpoint}, "
-                                        f"numel={numel}, numel_unpadded={numel_unpadded})"
-                                    )
-                                    world_tensor = world_tensor[:numel]
-                                elif world_tensor.numel() < numel:
-                                    # In this case, numel > world_tensor.numel() (which is numel_in_checkpoint).
-                                    # Create new tensor with right number of values, then copy and use new tensor.
-                                    logger.info(
-                                        f"Expanding tensor from checkpoint (numel_in_checkpoint={numel_in_checkpoint}, "
-                                        f"numel={numel}, numel_unpadded={numel_unpadded})"
-                                    )
-                                    world_tensor_reshaped = torch.empty(
-                                        (numel,),
-                                        dtype=world_tensor.dtype,
-                                        device=world_tensor.device,
-                                    )
-                                    world_tensor_reshaped[:numel_in_checkpoint].copy_(world_tensor)
-                                    world_tensor = world_tensor_reshaped
-                            else:
-                                logger.info(
-                                    "***WARNING*** Using older checkpoint so skipping padding checks"
-                                )
+                            assert world_tensor.numel() == gbuf_world_numel
                             gbuf_start_idxs = list(range(0, gbuf_world_numel, gbuf_local_numel))
                             send_tensors = [
                                 world_tensor[i : (i + gbuf_local_numel)] for i in gbuf_start_idxs
@@ -1112,25 +1103,25 @@ def load_parameter_state_from_dp_zero(self, state_dict):
                             data_parallel_group_gloo,
                         )

-                    # Copy local contiguous shards to param/optim shards.
-                    for model_param, param_range_map in gbuf_range_map["param_map"].items():
-
-                        # Main param & optimizer states.
-                        group_index, group_order = self.model_param_group_index_map[model_param]
-                        main_param = self.optimizer.param_groups[group_index]["params"][group_order]
-                        optim_state = self.optimizer.state[main_param]
-
-                        tensors = {
-                            "param": main_param,
-                            **optim_state,
-                        }
+                        # Copy local contiguous shards to param/optim shards.
+                        for model_param, param_range_map in gbuf_range_map["param_map"].items():

-                        # Copy states into contiguous shard.
-                        gbuf_local_start = param_range_map["gbuf_local"].start
-                        gbuf_local_end = param_range_map["gbuf_local"].end
-                        for key in local_shards:
-                            tensors[key].data.copy_(
-                                local_shards[key][gbuf_local_start:gbuf_local_end]
+                            # Main param & optimizer states.
+                            group_index, group_order = self.model_param_group_index_map[model_param]
+                            main_param = self.optimizer.param_groups[group_index]["params"][
+                                group_order
+                            ]
+                            if key == "param":
+                                tensor_to_copy_into = main_param
+                            else:
+                                optim_state = self.optimizer.state[main_param]
+                                tensor_to_copy_into = optim_state[key]
+
+                            # Copy states into contiguous shard.
+                            gbuf_local_start = param_range_map["gbuf_local"].start
+                            gbuf_local_end = param_range_map["gbuf_local"].end
+                            tensor_to_copy_into.data.copy_(
+                                recv_tensor[gbuf_local_start:gbuf_local_end]
                             )

     def load_parameter_state(self, filename: str):
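On the load side the bookkeeping runs in reverse: for each bucket, a slice of `numel_unpadded` elements is taken from the coalesced tensor at the running offset, padded back out to the bucket's padded size with `torch.nn.functional.pad`, and split into `gbuf_local_numel`-sized shards for the scatter. A minimal single-process sketch with assumed toy sizes (the real code scatters the shards over the gloo data-parallel group):

```python
import torch
import torch.nn.functional as F

# Assumed toy configuration: two buckets, two data-parallel ranks.
data_parallel_world_size = 2
bucket_numels = [6, 8]            # padded sizes, divisible by the DP world size
bucket_numels_unpadded = [5, 7]
world_tensors = torch.arange(sum(bucket_numels_unpadded), dtype=torch.float32)

offset_in_world_tensors = 0
for gbuf_world_numel, gbuf_world_numel_unpadded in zip(bucket_numels, bucket_numels_unpadded):
    # Slice this bucket's unpadded state out of the coalesced checkpoint tensor.
    start = offset_in_world_tensors
    end = offset_in_world_tensors + gbuf_world_numel_unpadded
    world_tensor = world_tensors[start:end]
    offset_in_world_tensors += gbuf_world_numel_unpadded

    # Re-add the trailing padding that was dropped when the checkpoint was saved.
    world_tensor = F.pad(world_tensor, (0, gbuf_world_numel - gbuf_world_numel_unpadded))
    assert world_tensor.numel() == gbuf_world_numel

    # One contiguous shard per data-parallel rank (what the scatter would send).
    gbuf_local_numel = gbuf_world_numel // data_parallel_world_size
    send_tensors = [
        world_tensor[i : i + gbuf_local_numel]
        for i in range(0, gbuf_world_numel, gbuf_local_numel)
    ]
    assert len(send_tensors) == data_parallel_world_size
```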