@@ -160,10 +160,14 @@ def save_checkpoint(self, update, last=False):
                     return
                 self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{update}.pt")
                 if self.keep_last_n_checkpoints > 0:
+                    # Updated logic to exclude pretrained model from rotation
                     checkpoints = [
                         f
                         for f in os.listdir(self.checkpoint_path)
-                        if f.startswith("model_") and f.endswith(".pt") and f != "model_last.pt"
+                        if f.startswith("model_")
+                        and not f.startswith("pretrained_")  # Exclude pretrained models
+                        and f.endswith(".pt")
+                        and f != "model_last.pt"
                     ]
                     checkpoints.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))
                     while len(checkpoints) > self.keep_last_n_checkpoints:
@@ -183,10 +187,24 @@ def load_checkpoint(self):
         if "model_last.pt" in os.listdir(self.checkpoint_path):
             latest_checkpoint = "model_last.pt"
         else:
-            latest_checkpoint = sorted(
-                [f for f in os.listdir(self.checkpoint_path) if f.endswith(".pt")],
-                key=lambda x: int("".join(filter(str.isdigit, x))),
-            )[-1]
+            # Updated to consider pretrained models for loading but prioritize training checkpoints
+            all_checkpoints = [
+                f
+                for f in os.listdir(self.checkpoint_path)
+                if (f.startswith("model_") or f.startswith("pretrained_")) and f.endswith(".pt")
+            ]
+
+            # First try to find regular training checkpoints
+            training_checkpoints = [f for f in all_checkpoints if f.startswith("model_") and f != "model_last.pt"]
+            if training_checkpoints:
+                latest_checkpoint = sorted(
+                    training_checkpoints,
+                    key=lambda x: int("".join(filter(str.isdigit, x))),
+                )[-1]
+            else:
+                # If no training checkpoints, use pretrained model
+                latest_checkpoint = next(f for f in all_checkpoints if f.startswith("pretrained_"))
+
         # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device)  # rather use accelerator.load_state ಥ_ಥ
         checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
 
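Taken together, the two hunks make the rotation in save_checkpoint skip pretrained weights, and make load_checkpoint fall back to a pretrained_*.pt file only when no regular training checkpoint exists. Below is a minimal standalone sketch of that selection logic, with the directory listing replaced by a plain Python list; the helper names and file names are hypothetical and only illustrative, while the actual methods read the listing from self.checkpoint_path via os.listdir().

# Hypothetical helpers; the real logic lives inline in save_checkpoint / load_checkpoint.

def rotatable_checkpoints(filenames):
    # save_checkpoint side: only regular training checkpoints count toward
    # keep_last_n_checkpoints; pretrained_*.pt and model_last.pt are never rotated out.
    checkpoints = [
        f
        for f in filenames
        if f.startswith("model_")
        and not f.startswith("pretrained_")
        and f.endswith(".pt")
        and f != "model_last.pt"
    ]
    checkpoints.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))  # oldest first
    return checkpoints


def pick_latest_checkpoint(filenames):
    # load_checkpoint side: prefer model_last.pt, then the newest training checkpoint,
    # and only fall back to a pretrained_*.pt file when nothing else is available.
    if "model_last.pt" in filenames:
        return "model_last.pt"
    all_checkpoints = [
        f for f in filenames if (f.startswith("model_") or f.startswith("pretrained_")) and f.endswith(".pt")
    ]
    training_checkpoints = [f for f in all_checkpoints if f.startswith("model_") and f != "model_last.pt"]
    if training_checkpoints:
        return sorted(training_checkpoints, key=lambda x: int("".join(filter(str.isdigit, x))))[-1]
    return next(f for f in all_checkpoints if f.startswith("pretrained_"))


files = ["pretrained_model_1200000.pt", "model_5000.pt", "model_10000.pt"]
print(rotatable_checkpoints(files))                             # ['model_5000.pt', 'model_10000.pt']
print(pick_latest_checkpoint(files))                            # model_10000.pt
print(pick_latest_checkpoint(["pretrained_model_1200000.pt"]))  # pretrained_model_1200000.pt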