Skip to content

Commit 7630509

Browse files
committed
Apply style fixes
1 parent dd6a236 commit 7630509

File tree

2 files changed

+61
-198
lines changed

2 files changed

+61
-198
lines changed

src/transformers/integrations/peft.py

Lines changed: 38 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -205,48 +205,28 @@ def load_adapter(
205205
if hotswap == "auto":
206206
# if user called model.enable_peft_hotswap and this is not the first adapter, enable hotswap
207207
hotswap_enabled = getattr(self, "_hotswap_enabled", False)
208-
not_first_adapter = bool(
209-
self._hf_peft_config_loaded and (adapter_name in self.peft_config)
210-
)
208+
not_first_adapter = bool(self._hf_peft_config_loaded and (adapter_name in self.peft_config))
211209
hotswap = hotswap_enabled and not_first_adapter
212210

213211
if hotswap:
214212
min_version_hotswap = "0.15.0"
215-
if version.parse(importlib.metadata.version("peft")) < version.parse(
216-
min_version_hotswap
217-
):
218-
raise ValueError(
219-
f"To hotswap the adapter, you need PEFT >= v{min_version_hotswap}."
220-
)
221-
if (not self._hf_peft_config_loaded) or (
222-
adapter_name not in self.peft_config
223-
):
213+
if version.parse(importlib.metadata.version("peft")) < version.parse(min_version_hotswap):
214+
raise ValueError(f"To hotswap the adapter, you need PEFT >= v{min_version_hotswap}.")
215+
if (not self._hf_peft_config_loaded) or (adapter_name not in self.peft_config):
224216
raise ValueError(
225217
"To hotswap an adapter, there must already be an existing adapter with the same adapter name."
226218
)
227-
if any(
228-
conf.peft_type != PeftType.LORA for conf in self.peft_config.values()
229-
):
230-
raise ValueError(
231-
"Hotswapping is currently only supported for LoRA, please set `hotswap=False`."
232-
)
219+
if any(conf.peft_type != PeftType.LORA for conf in self.peft_config.values()):
220+
raise ValueError("Hotswapping is currently only supported for LoRA, please set `hotswap=False`.")
233221

234222
# peft only supports low_cpu_mem_usage starting from v0.13.0
235223
peft_load_kwargs = {}
236-
key_mapping = (
237-
adapter_kwargs.pop("key_mapping", None)
238-
if adapter_kwargs is not None
239-
else None
240-
)
241-
if key_mapping is None and any(
242-
allowed_name in self.__class__.__name__.lower() for allowed_name in VLMS
243-
):
224+
key_mapping = adapter_kwargs.pop("key_mapping", None) if adapter_kwargs is not None else None
225+
if key_mapping is None and any(allowed_name in self.__class__.__name__.lower() for allowed_name in VLMS):
244226
key_mapping = self._checkpoint_conversion_mapping
245227
if low_cpu_mem_usage:
246228
min_version_lcmu = "0.13.0"
247-
if version.parse(importlib.metadata.version("peft")) >= version.parse(
248-
min_version_lcmu
249-
):
229+
if version.parse(importlib.metadata.version("peft")) >= version.parse(min_version_lcmu):
250230
peft_load_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
251231
else:
252232
raise ValueError(
@@ -264,34 +244,20 @@ def load_adapter(
264244
from peft import PeftConfig, inject_adapter_in_model, load_peft_weights
265245
from peft.utils import set_peft_model_state_dict
266246

267-
if (
268-
self._hf_peft_config_loaded
269-
and (not hotswap)
270-
and (adapter_name in self.peft_config)
271-
):
272-
raise ValueError(
273-
f"Adapter with name {adapter_name} already exists. Please use a different name."
274-
)
275-
elif hotswap and (
276-
(not self._hf_peft_config_loaded) or (adapter_name not in self.peft_config)
277-
):
247+
if self._hf_peft_config_loaded and (not hotswap) and (adapter_name in self.peft_config):
248+
raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")
249+
elif hotswap and ((not self._hf_peft_config_loaded) or (adapter_name not in self.peft_config)):
278250
raise ValueError(
279251
"To hotswap an adapter, there must already be an existing adapter with the same adapter name."
280252
)
281253

282-
if peft_model_id is None and (
283-
adapter_state_dict is None and peft_config is None
284-
):
254+
if peft_model_id is None and (adapter_state_dict is None and peft_config is None):
285255
raise ValueError(
286256
"You should either pass a `peft_model_id` or a `peft_config` and `adapter_state_dict` to load an adapter."
287257
)
288258

289259
if "device" not in adapter_kwargs:
290-
device = (
291-
self.device
292-
if not hasattr(self, "hf_device_map")
293-
else list(self.hf_device_map.values())[0]
294-
)
260+
device = self.device if not hasattr(self, "hf_device_map") else list(self.hf_device_map.values())[0]
295261
else:
296262
device = adapter_kwargs.pop("device")
297263

@@ -302,11 +268,7 @@ def load_adapter(
302268
# We keep `revision` in the signature for backward compatibility
303269
if revision is not None and "revision" not in adapter_kwargs:
304270
adapter_kwargs["revision"] = revision
305-
elif (
306-
revision is not None
307-
and "revision" in adapter_kwargs
308-
and revision != adapter_kwargs["revision"]
309-
):
271+
elif revision is not None and "revision" in adapter_kwargs and revision != adapter_kwargs["revision"]:
310272
logger.error(
311273
"You passed a `revision` argument both in `adapter_kwargs` and as a standalone argument. "
312274
"The one in `adapter_kwargs` will be used."
@@ -337,9 +299,7 @@ def load_adapter(
337299
peft_config.inference_mode = not is_trainable
338300

339301
if peft_config.peft_type != PeftType.LORA:
340-
raise ValueError(
341-
"Hotswapping is currently only supported for LoRA, please set `hotswap=False`."
342-
)
302+
raise ValueError("Hotswapping is currently only supported for LoRA, please set `hotswap=False`.")
343303

344304
if not hotswap:
345305
# TODO: WE NEED TOO APPLY OUR DYNAMIC WEIGHT CONVERSION AT SOME POINT HERE!
@@ -350,9 +310,7 @@ def load_adapter(
350310
self._hf_peft_config_loaded = True
351311

352312
if peft_model_id is not None:
353-
adapter_state_dict = load_peft_weights(
354-
peft_model_id, token=token, device=device, **adapter_kwargs
355-
)
313+
adapter_state_dict = load_peft_weights(peft_model_id, token=token, device=device, **adapter_kwargs)
356314

357315
# We need to pre-process the state dict to remove unneeded prefixes - for backward compatibility
358316
processed_adapter_state_dict = {}
@@ -392,9 +350,7 @@ def load_adapter(
392350
# Therefore, it needs to be called here
393351
from peft.utils.hotswap import prepare_model_for_compiled_hotswap
394352

395-
prepare_model_for_compiled_hotswap(
396-
self, config=peft_config, **self._prepare_peft_hotswap_kwargs
397-
)
353+
prepare_model_for_compiled_hotswap(self, config=peft_config, **self._prepare_peft_hotswap_kwargs)
398354
# We only want to call prepare_model_for_compiled_hotswap once
399355
self._prepare_peft_hotswap_kwargs = None
400356
else:
@@ -403,9 +359,7 @@ def load_adapter(
403359
hotswap_adapter_from_state_dict,
404360
)
405361

406-
check_hotswap_configs_compatible(
407-
self.peft_config[adapter_name], peft_config
408-
)
362+
check_hotswap_configs_compatible(self.peft_config[adapter_name], peft_config)
409363
try:
410364
hotswap_adapter_from_state_dict(
411365
model=self,
@@ -414,20 +368,15 @@ def load_adapter(
414368
config=peft_config,
415369
)
416370
except Exception as e:
417-
logger.error(
418-
f"Hotswapping {adapter_name} was unsucessful with the following error: \n{e}"
419-
)
371+
logger.error(f"Hotswapping {adapter_name} was unsucessful with the following error: \n{e}")
420372
raise
421373
incompatible_keys = None
422374

423375
if incompatible_keys is not None:
424376
err_msg = ""
425377
origin_name = peft_model_id if peft_model_id is not None else "state_dict"
426378
# Check for unexpected keys.
427-
if (
428-
hasattr(incompatible_keys, "unexpected_keys")
429-
and len(incompatible_keys.unexpected_keys) > 0
430-
):
379+
if hasattr(incompatible_keys, "unexpected_keys") and len(incompatible_keys.unexpected_keys) > 0:
431380
err_msg = (
432381
f"Loading adapter weights from {origin_name} led to unexpected keys not found in the model: "
433382
f"{', '.join(incompatible_keys.unexpected_keys)}. "
@@ -437,9 +386,7 @@ def load_adapter(
437386
missing_keys = getattr(incompatible_keys, "missing_keys", None)
438387
if missing_keys:
439388
# Filter missing keys specific to the current adapter, as missing base model keys are expected.
440-
lora_missing_keys = [
441-
k for k in missing_keys if "lora_" in k and adapter_name in k
442-
]
389+
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
443390
if lora_missing_keys:
444391
err_msg += (
445392
f"Loading adapter weights from {origin_name} led to missing keys in the model: "
@@ -455,9 +402,7 @@ def load_adapter(
455402
# Re-dispatch model and hooks in case the model is offloaded to CPU / Disk.
456403
if (
457404
(getattr(self, "hf_device_map", None) is not None)
458-
and (
459-
len(set(self.hf_device_map.values()).intersection({"cpu", "disk"})) > 0
460-
)
405+
and (len(set(self.hf_device_map.values()).intersection({"cpu", "disk"})) > 0)
461406
and len(self.peft_config) == 1
462407
):
463408
self._dispatch_accelerate_model(
@@ -492,18 +437,12 @@ def enable_peft_hotswap(
492437
- "ignore": do nothing
493438
"""
494439
min_version_hotswap = "0.15.0"
495-
if version.parse(importlib.metadata.version("peft")) < version.parse(
496-
min_version_hotswap
497-
):
498-
raise ValueError(
499-
f"To hotswap the adapter, you need PEFT >= v{min_version_hotswap}."
500-
)
440+
if version.parse(importlib.metadata.version("peft")) < version.parse(min_version_hotswap):
441+
raise ValueError(f"To hotswap the adapter, you need PEFT >= v{min_version_hotswap}.")
501442

502443
if getattr(self, "peft_config", {}):
503444
if check_compiled == "error":
504-
raise RuntimeError(
505-
"Call `enable_peft_hotswap` before loading the first adapter."
506-
)
445+
raise RuntimeError("Call `enable_peft_hotswap` before loading the first adapter.")
507446
elif check_compiled == "warn":
508447
logger.warning(
509448
"It is recommended to call `enable_peft_hotswap` before loading the first adapter to avoid recompilation."
@@ -546,14 +485,10 @@ def add_adapter(self, adapter_config, adapter_name: str | None = None) -> None:
546485
if not self._hf_peft_config_loaded:
547486
self._hf_peft_config_loaded = True
548487
elif adapter_name in self.peft_config:
549-
raise ValueError(
550-
f"Adapter with name {adapter_name} already exists. Please use a different name."
551-
)
488+
raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")
552489

553490
if not isinstance(adapter_config, PeftConfig):
554-
raise TypeError(
555-
f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead."
556-
)
491+
raise TypeError(f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead.")
557492

558493
# Retrieve the name or path of the model, one could also use self.config._name_or_path
559494
# but to be consistent with what we do in PEFT: https://github.com/huggingface/peft/blob/6e783780ca9df3a623992cc4d1d665001232eae0/src/peft/mapping.py#L100
@@ -667,9 +602,7 @@ def active_adapters(self) -> list[str]:
667602
check_peft_version(min_version=MIN_PEFT_VERSION)
668603

669604
if not is_peft_available():
670-
raise ImportError(
671-
"PEFT is not available. Please install PEFT to use this function: `pip install peft`."
672-
)
605+
raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.")
673606

674607
if not self._hf_peft_config_loaded:
675608
raise ValueError("No adapter loaded. Please load an adapter first.")
@@ -687,9 +620,7 @@ def active_adapters(self) -> list[str]:
687620

688621
return active_adapters
689622

690-
def get_adapter_state_dict(
691-
self, adapter_name: Optional[str] = None, state_dict: Optional[dict] = None
692-
) -> dict:
623+
def get_adapter_state_dict(self, adapter_name: Optional[str] = None, state_dict: Optional[dict] = None) -> dict:
693624
"""
694625
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
695626
official documentation: https://huggingface.co/docs/peft
@@ -715,9 +646,7 @@ def get_adapter_state_dict(
715646
if adapter_name is None:
716647
adapter_name = self.active_adapters()[0]
717648

718-
adapter_state_dict = get_peft_model_state_dict(
719-
self, state_dict=state_dict, adapter_name=adapter_name
720-
)
649+
adapter_state_dict = get_peft_model_state_dict(self, state_dict=state_dict, adapter_name=adapter_name)
721650
return adapter_state_dict
722651

723652
def _dispatch_accelerate_model(
@@ -817,9 +746,7 @@ def old_delete_adapter(model, adapter_name, prefix=None):
817746
f">= {min_version_delete_adapter} is required."
818747
)
819748

820-
if version.parse(importlib.metadata.version("peft")) >= version.parse(
821-
min_version_delete_adapter
822-
):
749+
if version.parse(importlib.metadata.version("peft")) >= version.parse(min_version_delete_adapter):
823750
from peft.functional import delete_adapter
824751
else:
825752
delete_adapter = old_delete_adapter
@@ -828,24 +755,17 @@ def old_delete_adapter(model, adapter_name, prefix=None):
828755
adapter_names = [adapter_names]
829756

830757
# Check that all adapter names are present in the config
831-
missing_adapters = [
832-
name for name in adapter_names if name not in self.peft_config
833-
]
758+
missing_adapters = [name for name in adapter_names if name not in self.peft_config]
834759
if missing_adapters:
835760
raise ValueError(
836761
f"The following adapter(s) are not present and cannot be deleted: {', '.join(missing_adapters)}"
837762
)
838763

839-
prefixes = [
840-
f"{self.peft_config[adapter_name].peft_type.value.lower()}_"
841-
for adapter_name in adapter_names
842-
]
764+
prefixes = [f"{self.peft_config[adapter_name].peft_type.value.lower()}_" for adapter_name in adapter_names]
843765
for adapter_name, prefix in zip(adapter_names, prefixes):
844766
delete_adapter(self, adapter_name=adapter_name, prefix=prefix)
845767
# For transformers integration - we need to pop the adapter from the config
846-
if getattr(self, "_hf_peft_config_loaded", False) and hasattr(
847-
self, "peft_config"
848-
):
768+
if getattr(self, "_hf_peft_config_loaded", False) and hasattr(self, "peft_config"):
849769
self.peft_config.pop(adapter_name, None)
850770

851771
# In case all adapters are deleted, we need to delete the config
@@ -864,9 +784,7 @@ def maybe_load_adapters(
864784
return None, pretrained_model_name_or_path, adapter_kwargs
865785

866786
if "local_files_only" not in adapter_kwargs:
867-
adapter_kwargs["local_files_only"] = download_kwargs.get(
868-
"local_files_only", False
869-
)
787+
adapter_kwargs["local_files_only"] = download_kwargs.get("local_files_only", False)
870788

871789
token = download_kwargs.get("token")
872790

@@ -902,12 +820,8 @@ def maybe_load_adapters(
902820
peft_kwargs[arg_name] = download_kwargs[arg_name]
903821
if "commit_hash" in download_kwargs:
904822
peft_kwargs["_commit_hash"] = download_kwargs["commit_hash"]
905-
peft_kwargs["force_download"] = bool(
906-
download_kwargs.get("force_download", False)
907-
)
908-
peft_kwargs["local_files_only"] = bool(
909-
download_kwargs.get("local_files_only", False)
910-
)
823+
peft_kwargs["force_download"] = bool(download_kwargs.get("force_download", False))
824+
peft_kwargs["local_files_only"] = bool(download_kwargs.get("local_files_only", False))
911825
peft_kwargs["token"] = token or token_from_adapter_kwargs
912826
_adapter_model_path = find_adapter_config_file(
913827
pretrained_model_name_or_path,

0 commit comments

Comments
 (0)