
Commit 61cafd9

ENH: Add support for LoRA hotswapping (#41297)
LoRA hotswapping has been available in PEFT since 0.15.0. There is already a diffusers integration (huggingface/diffusers#9453), but the transformers integration was still missing this feature. This PR remedies that. Hotswapping allows swapping different LoRA adapters in-place instead of loading multiple adapters and switching between them. Not only can this save memory and potentially speed up loading, but the biggest advantage is that if the model is compiled, we can hotswap without triggering recompilation (loading a separate adapter would require recompilation). There are some caveats to using this feature, most notably that only LoRA is supported. This was fine for diffusers, as it only works with LoRA, but the transformers integration works with other PEFT methods too. However, LoRA is by far the most common method, so this should be fine for now. This and other caveats have been documented. To make the usage more intuitive, hotswap is now auto-enabled after calling model.enable_peft_hotswap(). For this, we detect whether enable_peft_hotswap() was called *and* whether the adapter being loaded is *not* the first adapter (the first adapter cannot be hotswapped; it needs to be loaded normally).
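The auto-enable detection described above can be sketched as a standalone function. This is illustrative only: the name `resolve_hotswap` and its parameters are hypothetical, and the real logic lives inside `load_adapter` (which additionally consults the mixin's `_hf_peft_config_loaded` flag).

```python
def resolve_hotswap(hotswap, hotswap_enabled, adapter_name, peft_config):
    """Resolve hotswap="auto" the way the commit message describes.

    hotswap_enabled: True once enable_peft_hotswap() was called.
    peft_config: mapping of already-loaded adapter names to their configs.
    """
    if hotswap == "auto":
        # auto-enable only when the user opted in AND this is not the
        # first adapter (an adapter with this name is already loaded)
        not_first_adapter = adapter_name in peft_config
        return hotswap_enabled and not_first_adapter
    # an explicit True/False always wins over the auto detection
    return hotswap


# first adapter: loaded normally even after enable_peft_hotswap()
assert resolve_hotswap("auto", True, "default", {}) is False
# second adapter with the same name: hotswapped automatically
assert resolve_hotswap("auto", True, "default", {"default": object()}) is True
```

Passing `hotswap=True` or `hotswap=False` explicitly bypasses the detection entirely, matching the documented opt-out behavior.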
1 parent 453a246 commit 61cafd9

File tree

3 files changed: +382 additions, -8 deletions


docs/source/en/peft.md

Lines changed: 46 additions & 0 deletions
@@ -151,3 +151,49 @@ model.enable_adapters()
 # disable all adapters
 model.disable_adapters()
 ```
+
+## Hotswapping adapters
+
+A common use case when serving multiple adapters is to load one adapter, generate output, load another adapter, generate more output, load yet another adapter, and so on. This can be inefficient: each time a new adapter is loaded, new memory is reserved, and if the model is compiled with `torch.compile`, it must be recompiled each time a new adapter is used. When switching frequently, the compilation time may never be amortized.
+
+To better support this workflow, you can "hotswap" a LoRA adapter to avoid accumulating memory and, in some cases, recompilation. Hotswapping requires an adapter to already be loaded; the new adapter weights are then swapped in-place into the existing adapter. Note that only LoRA is supported, not other PEFT methods.
+
+Pass `hotswap=True` when loading a LoRA adapter to enable this feature. It is important to pass the name of the existing adapter to be swapped (`"default"` is the default adapter name).
+
+```python
+model = AutoModel.from_pretrained(...)
+# load adapter 1 as normal
+model.load_adapter(file_name_adapter_1)
+# generate outputs with adapter 1
+...
+# now hotswap the 2nd adapter
+model.load_adapter(file_name_adapter_2, hotswap=True, adapter_name="default")
+# generate outputs with adapter 2
+```
+
+For compiled models, it is often necessary to call [`~integrations.peft.PeftAdapterMixin.enable_peft_hotswap`] to avoid recompilation. Call this method _before_ loading the first adapter, while `torch.compile` should be called _after_ loading the first adapter.
+
+```python
+model = AutoModel.from_pretrained(...)
+max_rank = ...  # the highest rank among all LoRAs that you want to load
+# call *before* compiling and loading the LoRA adapter
+model.enable_peft_hotswap(target_rank=max_rank)
+model.load_adapter(file_name_1, adapter_name="default")
+# optionally compile the model now
+model = torch.compile(model, ...)
+output_1 = model(...)
+# now you can hotswap the 2nd adapter, use the same name as for the 1st
+model.load_adapter(file_name_2, adapter_name="default")
+output_2 = model(...)
+```
+
+The `target_rank=max_rank` argument sets the maximum rank among all LoRA adapters that will be loaded. If you have one adapter with rank 8 and another with rank 16, pass `target_rank=16`. If in doubt, use a higher value; the default is 128.
+
+By default, hotswapping is disabled and requires passing `hotswap=True` to `load_adapter`. However, if you called `enable_peft_hotswap` first, hotswapping is enabled by default. To avoid it in that case, pass `hotswap=False`.
+
+There are situations where recompilation is unavoidable. For example, if the hotswapped adapter targets more layers than the initial adapter, recompilation is triggered. Try to load the adapter that targets the most layers first. Refer to the PEFT docs on [hotswapping](https://huggingface.co/docs/peft/main/en/package_reference/hotswap#peft.utils.hotswap.hotswap_adapter) for more details about the limitations of this feature.
+
+> [!Tip]
+> Wrap your code in the `with torch._dynamo.config.patch(error_on_recompile=True)` context manager to detect if a model was recompiled. If you detect recompilation despite following all the steps above, please open an issue with [PEFT](https://github.com/huggingface/peft/issues) with a reproducible example.
+
+For an example of how using `torch.compile` in combination with hotswapping can improve runtime, check out [this blog post](https://huggingface.co/blog/lora-fast). Although that example uses Diffusers, similar improvements can be expected here.
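Why a fixed `target_rank` avoids recompilation: padding every adapter's weights to a common rank keeps tensor shapes constant across swaps, so the compiled graph stays valid. Below is a minimal sketch of the padding idea using nested lists; the helper `pad_lora_A` is hypothetical, and in practice PEFT's `prepare_model_for_compiled_hotswap` handles this.

```python
def pad_lora_A(weight, target_rank):
    """Pad a lora_A weight of shape (rank, in_features), given as nested
    lists, with zero rows up to target_rank. The extra zero rows contribute
    nothing to the adapter's output, so behavior is unchanged while every
    adapter ends up with the same tensor shape."""
    rank, in_features = len(weight), len(weight[0])
    if rank > target_rank:
        raise ValueError(f"adapter rank {rank} exceeds target_rank {target_rank}")
    return weight + [[0.0] * in_features for _ in range(target_rank - rank)]


# a rank-1 adapter padded to target_rank=4: shape becomes (4, 3)
padded = pad_lora_A([[0.1, 0.2, 0.3]], target_rank=4)
```

This also shows why an adapter with a rank higher than `target_rank` cannot be hotswapped: it would not fit the pre-allocated shape.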

src/transformers/integrations/peft.py

Lines changed: 147 additions & 8 deletions
@@ -17,7 +17,7 @@
 import json
 import os
 import re
-from typing import Any, Optional, Union
+from typing import Any, Literal, Optional, Union
 
 from packaging import version
 
@@ -89,6 +89,7 @@ class PeftAdapterMixin:
     """
 
     _hf_peft_config_loaded = False
+    _prepare_peft_hotswap_kwargs: Optional[dict] = None
 
     def load_adapter(
         self,
@@ -104,6 +105,7 @@ def load_adapter(
         adapter_state_dict: Optional[dict[str, "torch.Tensor"]] = None,
         low_cpu_mem_usage: bool = False,
         is_trainable: bool = False,
+        hotswap: bool | Literal["auto"] = "auto",
         adapter_kwargs: Optional[dict[str, Any]] = None,
     ) -> None:
         """
@@ -159,12 +161,63 @@ def load_adapter(
             is_trainable (`bool`, *optional*, defaults to `False`):
                 Whether the adapter should be trainable or not. If `False`, the adapter will be frozen and can only be
                 used for inference.
+            hotswap (`"auto"` or `bool`, *optional*, defaults to `"auto"`):
+                Whether to substitute an existing (LoRA) adapter with the newly loaded adapter in-place. This means
+                that, instead of loading an additional adapter, this will take the existing adapter weights and replace
+                them with the weights of the new adapter. This can be faster and more memory efficient. However, the
+                main advantage of hotswapping is that when the model is compiled with `torch.compile`, loading the new
+                adapter does not require recompilation of the model. When using hotswapping, the passed `adapter_name`
+                should be the name of an already loaded adapter.
+
+                If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
+                to call an additional method before loading the adapter:
+
+                ```py
+                model = AutoModel.from_pretrained(...)
+                max_rank = ...  # the highest rank among all LoRAs that you want to load
+                # call *before* compiling and loading the LoRA adapter
+                model.enable_peft_hotswap(target_rank=max_rank)
+                model.load_adapter(file_name_1, adapter_name="default")
+                # optionally compile the model now
+                model = torch.compile(model, ...)
+                output_1 = model(...)
+                # now you can hotswap the 2nd adapter, use the same name as for the 1st
+                # hotswap is activated by default since enable_peft_hotswap was called
+                model.load_adapter(file_name_2, adapter_name="default")
+                output_2 = model(...)
+                ```
+
+                By default, hotswap is disabled and requires passing `hotswap=True`. If you called
+                `enable_peft_hotswap` first, it is enabled. You can still manually disable it in that case by passing
+                `hotswap=False`.
+
+                Note that hotswapping comes with a couple of limitations documented here:
+                https://huggingface.co/docs/peft/main/en/package_reference/hotswap
             adapter_kwargs (`dict[str, Any]`, *optional*):
                 Additional keyword arguments passed along to the `from_pretrained` method of the adapter config and
                 `find_adapter_config_file` method.
         """
         check_peft_version(min_version=MIN_PEFT_VERSION)
 
+        from peft import PeftType
+
+        if hotswap == "auto":
+            # if the user called model.enable_peft_hotswap and this is not the first adapter, enable hotswap
+            hotswap_enabled = getattr(self, "_hotswap_enabled", False)
+            not_first_adapter = bool(self._hf_peft_config_loaded and (adapter_name in self.peft_config))
+            hotswap = hotswap_enabled and not_first_adapter
+
+        if hotswap:
+            min_version_hotswap = "0.15.0"
+            if version.parse(importlib.metadata.version("peft")) < version.parse(min_version_hotswap):
+                raise ValueError(f"To hotswap the adapter, you need PEFT >= v{min_version_hotswap}.")
+            if (not self._hf_peft_config_loaded) or (adapter_name not in self.peft_config):
+                raise ValueError(
+                    "To hotswap an adapter, there must already be an existing adapter with the same adapter name."
+                )
+            if any(conf.peft_type != PeftType.LORA for conf in self.peft_config.values()):
+                raise ValueError("Hotswapping is currently only supported for LoRA, please set `hotswap=False`.")
+
         # peft only supports low_cpu_mem_usage starting from v0.13.0
         peft_load_kwargs = {}
         key_mapping = adapter_kwargs.pop("key_mapping", None) if adapter_kwargs is not None else None
@@ -187,8 +240,12 @@ def load_adapter(
         from peft import PeftConfig, inject_adapter_in_model, load_peft_weights
         from peft.utils import set_peft_model_state_dict
 
-        if self._hf_peft_config_loaded and adapter_name in self.peft_config:
+        if self._hf_peft_config_loaded and (not hotswap) and (adapter_name in self.peft_config):
             raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")
+        elif hotswap and ((not self._hf_peft_config_loaded) or (adapter_name not in self.peft_config)):
+            raise ValueError(
+                "To hotswap an adapter, there must already be an existing adapter with the same adapter name."
+            )
 
         if peft_model_id is None and (adapter_state_dict is None and peft_config is None):
             raise ValueError(
@@ -236,9 +293,14 @@ def load_adapter(
             **adapter_kwargs,
         )
         peft_config.inference_mode = not is_trainable
-        # TODO: WE NEED TO APPLY OUR DYNAMIC WEIGHT CONVERSION AT SOME POINT HERE!
-        # Create and add fresh new adapters into the model.
-        inject_adapter_in_model(peft_config, self, adapter_name, **peft_load_kwargs)
+
+        if hotswap and peft_config.peft_type != PeftType.LORA:
+            raise ValueError("Hotswapping is currently only supported for LoRA, please set `hotswap=False`.")
+
+        if not hotswap:
+            # TODO: WE NEED TO APPLY OUR DYNAMIC WEIGHT CONVERSION AT SOME POINT HERE!
+            # Create and add fresh new adapters into the model, unless the weights are hotswapped
+            inject_adapter_in_model(peft_config, self, adapter_name, **peft_load_kwargs)
 
@@ -261,12 +323,47 @@ def load_adapter(
                     # Early exit of the loop
                     if n_replace > 0:
                         break
+
+            # For hotswapping, we need the adapter name to be present in the state dict keys
+            if hotswap:
+                if key.endswith("lora_A.weight") or key.endswith("lora_B.weight"):
+                    new_key = new_key[: -len(".weight")] + f".{adapter_name}.weight"
+                elif key.endswith("lora_B.bias"):  # lora_bias=True option
+                    new_key = new_key[: -len(".bias")] + f".{adapter_name}.bias"
             processed_adapter_state_dict[new_key] = value
 
         # Load state dict
-        incompatible_keys = set_peft_model_state_dict(
-            self, processed_adapter_state_dict, adapter_name, **peft_load_kwargs
-        )
+        if not hotswap:
+            incompatible_keys = set_peft_model_state_dict(
+                self, processed_adapter_state_dict, adapter_name, **peft_load_kwargs
+            )
+
+            if self._prepare_peft_hotswap_kwargs is not None:
+                # For hotswapping of compiled models or adapters with different ranks.
+                # If the user called enable_peft_hotswap, we need to ensure it is called:
+                # - after the first adapter was loaded
+                # - before the model is compiled and the 2nd adapter is being hotswapped in
+                # Therefore, it needs to be called here
+                from peft.utils.hotswap import prepare_model_for_compiled_hotswap
+
+                prepare_model_for_compiled_hotswap(self, config=peft_config, **self._prepare_peft_hotswap_kwargs)
+                # We only want to call prepare_model_for_compiled_hotswap once
+                self._prepare_peft_hotswap_kwargs = None
+        else:
+            from peft.utils.hotswap import check_hotswap_configs_compatible, hotswap_adapter_from_state_dict
+
+            check_hotswap_configs_compatible(self.peft_config[adapter_name], peft_config)
+            try:
+                hotswap_adapter_from_state_dict(
+                    model=self,
+                    state_dict=processed_adapter_state_dict,
+                    adapter_name=adapter_name,
+                    config=peft_config,
+                )
+            except Exception as e:
+                logger.error(f"Hotswapping {adapter_name} was unsuccessful with the following error: \n{e}")
+                raise
+            incompatible_keys = None
 
         if incompatible_keys is not None:
             err_msg = ""
@@ -308,6 +405,47 @@ def load_adapter(
                 offload_index=offload_index,
             )
 
+    def enable_peft_hotswap(
+        self, target_rank: int = 128, check_compiled: Literal["error", "warn", "ignore"] = "error"
+    ) -> None:
+        """Enables hotswapping of PEFT adapters with different ranks, or, if the model is compiled, without
+        triggering recompilation.
+
+        Right now, hotswapping is only supported for LoRA.
+
+        Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
+        the loaded adapters differ. If the ranks are all identical and the model is not compiled, hotswapping works
+        without calling this method first.
+
+        Args:
+            target_rank (`int`, *optional*, defaults to `128`):
+                The highest rank among all the adapters that will be loaded.
+            check_compiled (`str`, *optional*, defaults to `"error"`):
+                How to handle the case when the model is already compiled, which should generally be avoided. The
+                options are:
+                - "error" (default): raise an error
+                - "warn": issue a warning
+                - "ignore": do nothing
+        """
+        min_version_hotswap = "0.15.0"
+        if version.parse(importlib.metadata.version("peft")) < version.parse(min_version_hotswap):
+            raise ValueError(f"To hotswap the adapter, you need PEFT >= v{min_version_hotswap}.")
+
+        if getattr(self, "peft_config", {}):
+            if check_compiled == "error":
+                raise RuntimeError("Call `enable_peft_hotswap` before loading the first adapter.")
+            elif check_compiled == "warn":
+                logger.warning(
+                    "It is recommended to call `enable_peft_hotswap` before loading the first adapter to avoid recompilation."
+                )
+            elif check_compiled != "ignore":
+                raise ValueError(
+                    f"check_compiled should be one of 'error', 'warn', or 'ignore', got '{check_compiled}' instead."
+                )
+
+        self._hotswap_enabled = True
+        self._prepare_peft_hotswap_kwargs = {"target_rank": target_rank, "check_compiled": check_compiled}
+
     def add_adapter(self, adapter_config, adapter_name: Optional[str] = None) -> None:
         r"""
         If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
@@ -343,6 +481,7 @@ def add_adapter(self, adapter_config, adapter_name: Optional[str] = None) -> Non
         # Retrieve the name or path of the model, one could also use self.config._name_or_path
         # but to be consistent with what we do in PEFT: https://github.com/huggingface/peft/blob/6e783780ca9df3a623992cc4d1d665001232eae0/src/peft/mapping.py#L100
         adapter_config.base_model_name_or_path = self.__dict__.get("name_or_path", None)
+        # TODO: WE NEED TO APPLY OUR DYNAMIC WEIGHT CONVERSION AT SOME POINT HERE!
         inject_adapter_in_model(adapter_config, self, adapter_name)
 
         self.set_adapter(adapter_name)
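The state-dict key remapping that hotswapping requires (inserting the adapter name before the final `.weight`/`.bias`, as in the diff above) can be exercised in isolation. The helper name `add_adapter_name_to_key` is hypothetical; the commit performs this transformation inline while iterating over the state dict.

```python
def add_adapter_name_to_key(key: str, adapter_name: str) -> str:
    # PEFT's hotswap utilities expect the adapter name inside the key,
    # e.g. "lora_A.weight" -> "lora_A.default.weight"
    if key.endswith("lora_A.weight") or key.endswith("lora_B.weight"):
        return key[: -len(".weight")] + f".{adapter_name}.weight"
    if key.endswith("lora_B.bias"):  # lora_bias=True option
        return key[: -len(".bias")] + f".{adapter_name}.bias"
    return key  # non-LoRA keys pass through unchanged


assert add_adapter_name_to_key("model.layers.0.lora_A.weight", "default") == "model.layers.0.lora_A.default.weight"
assert add_adapter_name_to_key("model.layers.0.lora_B.bias", "default") == "model.layers.0.lora_B.default.bias"
```

This remapping is only applied on the hotswap path; the normal loading path via `set_peft_model_state_dict` handles adapter names itself.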
