@@ -205,48 +205,28 @@ def load_adapter(
         if hotswap == "auto":
             # if user called model.enable_peft_hotswap and this is not the first adapter, enable hotswap
             hotswap_enabled = getattr(self, "_hotswap_enabled", False)
-            not_first_adapter = bool(
-                self._hf_peft_config_loaded and (adapter_name in self.peft_config)
-            )
+            not_first_adapter = bool(self._hf_peft_config_loaded and (adapter_name in self.peft_config))
             hotswap = hotswap_enabled and not_first_adapter
 
         if hotswap:
             min_version_hotswap = "0.15.0"
-            if version.parse(importlib.metadata.version("peft")) < version.parse(
-                min_version_hotswap
-            ):
-                raise ValueError(
-                    f"To hotswap the adapter, you need PEFT >= v{min_version_hotswap}."
-                )
-            if (not self._hf_peft_config_loaded) or (
-                adapter_name not in self.peft_config
-            ):
+            if version.parse(importlib.metadata.version("peft")) < version.parse(min_version_hotswap):
+                raise ValueError(f"To hotswap the adapter, you need PEFT >= v{min_version_hotswap}.")
+            if (not self._hf_peft_config_loaded) or (adapter_name not in self.peft_config):
                 raise ValueError(
                     "To hotswap an adapter, there must already be an existing adapter with the same adapter name."
                 )
-            if any(
-                conf.peft_type != PeftType.LORA for conf in self.peft_config.values()
-            ):
-                raise ValueError(
-                    "Hotswapping is currently only supported for LoRA, please set `hotswap=False`."
-                )
+            if any(conf.peft_type != PeftType.LORA for conf in self.peft_config.values()):
+                raise ValueError("Hotswapping is currently only supported for LoRA, please set `hotswap=False`.")
 
         # peft only supports low_cpu_mem_usage starting from v0.13.0
         peft_load_kwargs = {}
-        key_mapping = (
-            adapter_kwargs.pop("key_mapping", None)
-            if adapter_kwargs is not None
-            else None
-        )
-        if key_mapping is None and any(
-            allowed_name in self.__class__.__name__.lower() for allowed_name in VLMS
-        ):
+        key_mapping = adapter_kwargs.pop("key_mapping", None) if adapter_kwargs is not None else None
+        if key_mapping is None and any(allowed_name in self.__class__.__name__.lower() for allowed_name in VLMS):
             key_mapping = self._checkpoint_conversion_mapping
         if low_cpu_mem_usage:
             min_version_lcmu = "0.13.0"
-            if version.parse(importlib.metadata.version("peft")) >= version.parse(
-                min_version_lcmu
-            ):
+            if version.parse(importlib.metadata.version("peft")) >= version.parse(min_version_lcmu):
                 peft_load_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
             else:
                 raise ValueError(
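
For context, the `hotswap="auto"` branch reflowed above only fires when the user opted in beforehand and an adapter with the same name is already loaded. A minimal usage sketch, assuming PEFT >= 0.15.0, default arguments for `enable_peft_hotswap`, and hypothetical model/adapter ids:

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("some-base-model")     # hypothetical id
    model.enable_peft_hotswap()                                         # sets _hotswap_enabled
    model.load_adapter("user/lora-a", adapter_name="default")           # first load: regular injection
    model.load_adapter("user/lora-b", adapter_name="default")           # same name again: hotswapped in place
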
@@ -264,34 +244,20 @@ def load_adapter(
         from peft import PeftConfig, inject_adapter_in_model, load_peft_weights
         from peft.utils import set_peft_model_state_dict
 
-        if (
-            self._hf_peft_config_loaded
-            and (not hotswap)
-            and (adapter_name in self.peft_config)
-        ):
-            raise ValueError(
-                f"Adapter with name {adapter_name} already exists. Please use a different name."
-            )
-        elif hotswap and (
-            (not self._hf_peft_config_loaded) or (adapter_name not in self.peft_config)
-        ):
+        if self._hf_peft_config_loaded and (not hotswap) and (adapter_name in self.peft_config):
+            raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")
+        elif hotswap and ((not self._hf_peft_config_loaded) or (adapter_name not in self.peft_config)):
             raise ValueError(
                 "To hotswap an adapter, there must already be an existing adapter with the same adapter name."
             )
 
-        if peft_model_id is None and (
-            adapter_state_dict is None and peft_config is None
-        ):
+        if peft_model_id is None and (adapter_state_dict is None and peft_config is None):
             raise ValueError(
                 "You should either pass a `peft_model_id` or a `peft_config` and `adapter_state_dict` to load an adapter."
             )
 
         if "device" not in adapter_kwargs:
-            device = (
-                self.device
-                if not hasattr(self, "hf_device_map")
-                else list(self.hf_device_map.values())[0]
-            )
+            device = self.device if not hasattr(self, "hf_device_map") else list(self.hf_device_map.values())[0]
         else:
             device = adapter_kwargs.pop("device")
 
@@ -302,11 +268,7 @@ def load_adapter(
         # We keep `revision` in the signature for backward compatibility
         if revision is not None and "revision" not in adapter_kwargs:
             adapter_kwargs["revision"] = revision
-        elif (
-            revision is not None
-            and "revision" in adapter_kwargs
-            and revision != adapter_kwargs["revision"]
-        ):
+        elif revision is not None and "revision" in adapter_kwargs and revision != adapter_kwargs["revision"]:
             logger.error(
                 "You passed a `revision` argument both in `adapter_kwargs` and as a standalone argument. "
                 "The one in `adapter_kwargs` will be used."
@@ -337,9 +299,7 @@ def load_adapter(
             peft_config.inference_mode = not is_trainable
 
             if peft_config.peft_type != PeftType.LORA:
-                raise ValueError(
-                    "Hotswapping is currently only supported for LoRA, please set `hotswap=False`."
-                )
+                raise ValueError("Hotswapping is currently only supported for LoRA, please set `hotswap=False`.")
 
         if not hotswap:
             # TODO: WE NEED TOO APPLY OUR DYNAMIC WEIGHT CONVERSION AT SOME POINT HERE!
@@ -350,9 +310,7 @@ def load_adapter(
             self._hf_peft_config_loaded = True
 
         if peft_model_id is not None:
-            adapter_state_dict = load_peft_weights(
-                peft_model_id, token=token, device=device, **adapter_kwargs
-            )
+            adapter_state_dict = load_peft_weights(peft_model_id, token=token, device=device, **adapter_kwargs)
 
         # We need to pre-process the state dict to remove unneeded prefixes - for backward compatibility
         processed_adapter_state_dict = {}
@@ -392,9 +350,7 @@ def load_adapter(
                 # Therefore, it needs to be called here
                 from peft.utils.hotswap import prepare_model_for_compiled_hotswap
 
-                prepare_model_for_compiled_hotswap(
-                    self, config=peft_config, **self._prepare_peft_hotswap_kwargs
-                )
+                prepare_model_for_compiled_hotswap(self, config=peft_config, **self._prepare_peft_hotswap_kwargs)
                 # We only want to call prepare_model_for_compiled_hotswap once
                 self._prepare_peft_hotswap_kwargs = None
         else:
@@ -403,9 +359,7 @@ def load_adapter(
                 hotswap_adapter_from_state_dict,
             )
 
-            check_hotswap_configs_compatible(
-                self.peft_config[adapter_name], peft_config
-            )
+            check_hotswap_configs_compatible(self.peft_config[adapter_name], peft_config)
             try:
                 hotswap_adapter_from_state_dict(
                     model=self,
@@ -414,20 +368,15 @@ def load_adapter(
                     config=peft_config,
                 )
             except Exception as e:
-                logger.error(
-                    f"Hotswapping {adapter_name} was unsucessful with the following error: \n{e}"
-                )
+                logger.error(f"Hotswapping {adapter_name} was unsucessful with the following error: \n{e}")
                 raise
             incompatible_keys = None
 
         if incompatible_keys is not None:
             err_msg = ""
             origin_name = peft_model_id if peft_model_id is not None else "state_dict"
             # Check for unexpected keys.
-            if (
-                hasattr(incompatible_keys, "unexpected_keys")
-                and len(incompatible_keys.unexpected_keys) > 0
-            ):
+            if hasattr(incompatible_keys, "unexpected_keys") and len(incompatible_keys.unexpected_keys) > 0:
                 err_msg = (
                     f"Loading adapter weights from {origin_name} led to unexpected keys not found in the model: "
                     f"{', '.join(incompatible_keys.unexpected_keys)}. "
@@ -437,9 +386,7 @@ def load_adapter(
             missing_keys = getattr(incompatible_keys, "missing_keys", None)
             if missing_keys:
                 # Filter missing keys specific to the current adapter, as missing base model keys are expected.
-                lora_missing_keys = [
-                    k for k in missing_keys if "lora_" in k and adapter_name in k
-                ]
+                lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
                 if lora_missing_keys:
                     err_msg += (
                         f"Loading adapter weights from {origin_name} led to missing keys in the model: "
@@ -455,9 +402,7 @@ def load_adapter(
         # Re-dispatch model and hooks in case the model is offloaded to CPU / Disk.
         if (
             (getattr(self, "hf_device_map", None) is not None)
-            and (
-                len(set(self.hf_device_map.values()).intersection({"cpu", "disk"})) > 0
-            )
+            and (len(set(self.hf_device_map.values()).intersection({"cpu", "disk"})) > 0)
             and len(self.peft_config) == 1
         ):
             self._dispatch_accelerate_model(
@@ -492,18 +437,12 @@ def enable_peft_hotswap(
                 - "ignore": do nothing
         """
         min_version_hotswap = "0.15.0"
-        if version.parse(importlib.metadata.version("peft")) < version.parse(
-            min_version_hotswap
-        ):
-            raise ValueError(
-                f"To hotswap the adapter, you need PEFT >= v{min_version_hotswap}."
-            )
+        if version.parse(importlib.metadata.version("peft")) < version.parse(min_version_hotswap):
+            raise ValueError(f"To hotswap the adapter, you need PEFT >= v{min_version_hotswap}.")
 
         if getattr(self, "peft_config", {}):
             if check_compiled == "error":
-                raise RuntimeError(
-                    "Call `enable_peft_hotswap` before loading the first adapter."
-                )
+                raise RuntimeError("Call `enable_peft_hotswap` before loading the first adapter.")
             elif check_compiled == "warn":
                 logger.warning(
                     "It is recommended to call `enable_peft_hotswap` before loading the first adapter to avoid recompilation."
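
The `check_compiled` guard above only matters if adapters were already loaded when `enable_peft_hotswap` is called. A hedged sketch of the intended call order, reusing `model` from the earlier sketch and assuming defaults for the other arguments:

    model.enable_peft_hotswap(check_compiled="error")  # default: raise if an adapter is already loaded
    model.load_adapter("user/lora-a")                  # load adapters only after enabling hotswap
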
@@ -546,14 +485,10 @@ def add_adapter(self, adapter_config, adapter_name: str | None = None) -> None:
         if not self._hf_peft_config_loaded:
             self._hf_peft_config_loaded = True
         elif adapter_name in self.peft_config:
-            raise ValueError(
-                f"Adapter with name {adapter_name} already exists. Please use a different name."
-            )
+            raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")
 
         if not isinstance(adapter_config, PeftConfig):
-            raise TypeError(
-                f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead."
-            )
+            raise TypeError(f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead.")
 
         # Retrieve the name or path of the model, one could also use self.config._name_or_path
         # but to be consistent with what we do in PEFT: https://github.com/huggingface/peft/blob/6e783780ca9df3a623992cc4d1d665001232eae0/src/peft/mapping.py#L100
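
As a reference for the two checks reflowed above (unique adapter name, `PeftConfig` instance), a minimal `add_adapter` call; the LoRA settings are purely illustrative:

    from peft import LoraConfig

    lora_config = LoraConfig(r=8, target_modules=["q_proj", "v_proj"])
    model.add_adapter(lora_config, adapter_name="my_adapter")  # name must not already be in model.peft_config
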
@@ -667,9 +602,7 @@ def active_adapters(self) -> list[str]:
         check_peft_version(min_version=MIN_PEFT_VERSION)
 
         if not is_peft_available():
-            raise ImportError(
-                "PEFT is not available. Please install PEFT to use this function: `pip install peft`."
-            )
+            raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.")
 
         if not self._hf_peft_config_loaded:
             raise ValueError("No adapter loaded. Please load an adapter first.")
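
Both error paths above surface directly to the caller; a one-line sketch of the calling side, reusing `model` from the earlier sketches:

    adapters = model.active_adapters()  # ImportError without peft installed, ValueError if no adapter is loaded
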
@@ -687,9 +620,7 @@ def active_adapters(self) -> list[str]:
 
         return active_adapters
 
-    def get_adapter_state_dict(
-        self, adapter_name: Optional[str] = None, state_dict: Optional[dict] = None
-    ) -> dict:
+    def get_adapter_state_dict(self, adapter_name: Optional[str] = None, state_dict: Optional[dict] = None) -> dict:
         """
         If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
         official documentation: https://huggingface.co/docs/peft
@@ -715,9 +646,7 @@ def get_adapter_state_dict(
         if adapter_name is None:
             adapter_name = self.active_adapters()[0]
 
-        adapter_state_dict = get_peft_model_state_dict(
-            self, state_dict=state_dict, adapter_name=adapter_name
-        )
+        adapter_state_dict = get_peft_model_state_dict(self, state_dict=state_dict, adapter_name=adapter_name)
         return adapter_state_dict
 
     def _dispatch_accelerate_model(
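
A usage sketch of the single-line signature introduced above; when `adapter_name` is omitted it falls back to the first entry of `active_adapters()`, as the hunk shows (again reusing `model` from the earlier sketches):

    default_sd = model.get_adapter_state_dict()                       # first active adapter
    named_sd = model.get_adapter_state_dict(adapter_name="my_adapter")
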
@@ -817,9 +746,7 @@ def old_delete_adapter(model, adapter_name, prefix=None):
                 f">= {min_version_delete_adapter} is required."
             )
 
-        if version.parse(importlib.metadata.version("peft")) >= version.parse(
-            min_version_delete_adapter
-        ):
+        if version.parse(importlib.metadata.version("peft")) >= version.parse(min_version_delete_adapter):
             from peft.functional import delete_adapter
         else:
             delete_adapter = old_delete_adapter
@@ -828,24 +755,17 @@ def old_delete_adapter(model, adapter_name, prefix=None):
             adapter_names = [adapter_names]
 
         # Check that all adapter names are present in the config
-        missing_adapters = [
-            name for name in adapter_names if name not in self.peft_config
-        ]
+        missing_adapters = [name for name in adapter_names if name not in self.peft_config]
         if missing_adapters:
             raise ValueError(
                 f"The following adapter(s) are not present and cannot be deleted: {', '.join(missing_adapters)}"
             )
 
-        prefixes = [
-            f"{self.peft_config[adapter_name].peft_type.value.lower()}_"
-            for adapter_name in adapter_names
-        ]
+        prefixes = [f"{self.peft_config[adapter_name].peft_type.value.lower()}_" for adapter_name in adapter_names]
         for adapter_name, prefix in zip(adapter_names, prefixes):
             delete_adapter(self, adapter_name=adapter_name, prefix=prefix)
             # For transformers integration - we need to pop the adapter from the config
-            if getattr(self, "_hf_peft_config_loaded", False) and hasattr(
-                self, "peft_config"
-            ):
+            if getattr(self, "_hf_peft_config_loaded", False) and hasattr(self, "peft_config"):
                 self.peft_config.pop(adapter_name, None)
 
         # In case all adapters are deleted, we need to delete the config
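
For reference, the `delete_adapter` flow touched above accepts a single name or a list, and every name must already be registered before anything is deleted; a minimal sketch, reusing `model` from the earlier sketches:

    model.delete_adapter("my_adapter")                # a single name is wrapped into a list
    model.delete_adapter(["adapter_a", "adapter_b"])  # unknown names raise ValueError up front
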
@@ -864,9 +784,7 @@ def maybe_load_adapters(
         return None, pretrained_model_name_or_path, adapter_kwargs
 
     if "local_files_only" not in adapter_kwargs:
-        adapter_kwargs["local_files_only"] = download_kwargs.get(
-            "local_files_only", False
-        )
+        adapter_kwargs["local_files_only"] = download_kwargs.get("local_files_only", False)
 
     token = download_kwargs.get("token")
 
@@ -902,12 +820,8 @@ def maybe_load_adapters(
                 peft_kwargs[arg_name] = download_kwargs[arg_name]
         if "commit_hash" in download_kwargs:
             peft_kwargs["_commit_hash"] = download_kwargs["commit_hash"]
-        peft_kwargs["force_download"] = bool(
-            download_kwargs.get("force_download", False)
-        )
-        peft_kwargs["local_files_only"] = bool(
-            download_kwargs.get("local_files_only", False)
-        )
+        peft_kwargs["force_download"] = bool(download_kwargs.get("force_download", False))
+        peft_kwargs["local_files_only"] = bool(download_kwargs.get("local_files_only", False))
         peft_kwargs["token"] = token or token_from_adapter_kwargs
         _adapter_model_path = find_adapter_config_file(
             pretrained_model_name_or_path,