LoRA hotswapping has been available in PEFT since 0.15.0. There is
already a diffusers
integration (huggingface/diffusers#9453), but
the transformers integration was still missing this feature. This PR
remedies this.
Hotswapping allows swapping different LoRA adapters in-place instead of
loading multiple adapters and switching between them. Not only can this
be advantageous for saving memory and potentially for quicker loading,
the biggest advantage is that if the model is compiled, we can hotswap
without triggering recompilation (loading a separate adapter would
require recompilation).
There are some caveats to using this feature, most notably that only
LoRA is supported. This was fine for diffusers, as it only works with
LoRA, but the transformers integration works with other PEFT methods
too. However, LoRA should be by far the most common method, so this
should be fine for now. This and other caveats have been documented.
To make the usage more intuitive, hotswap is now auto-enabled after
calling model.enable_peft_hotswap(). For this, we detect if
enable_peft_hotswap() was called *and* if the adapter being loaded
is *not* the first adapter (because the first adapter cannot be
hotswapped, it needs to be loaded normally).
docs/source/en/peft.md
model.enable_adapters()

# disable all adapters
model.disable_adapters()
```

## Hotswapping adapters

A common use case when serving multiple adapters is to load one adapter first, generate output, load another adapter, generate more outputs, load another adapter, etc. This can be inefficient, since each time a new adapter is loaded, new memory is reserved; moreover, if the model is compiled with `torch.compile`, it needs to be re-compiled each time a new adapter is used. When switching frequently, the compilation time may never be amortized.
To better support this common workflow, you can "hotswap" a LoRA adapter, to avoid accumulating memory and, in some cases, recompilation. It requires an adapter to already be loaded, and the new adapter weights are swapped in-place for the existing adapter. Note that other PEFT methods are not supported yet, only LoRA.
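Conceptually, hotswapping overwrites the buffers of the already-loaded adapter rather than allocating new ones. The following pure-Python sketch is only an illustration of that idea, not the actual PEFT implementation (the function names and the dict-based "model" are made up):

```python
# Illustration of in-place swapping: the existing buffer is overwritten,
# so its identity (and, in the real case, its memory address) is unchanged.
# This is NOT PEFT's implementation, just the concept.

def load_adapter(model: dict, name: str, weights: list) -> None:
    # normal loading: reserves a brand-new buffer for the adapter
    model[name] = list(weights)

def hotswap_adapter(model: dict, name: str, weights: list) -> None:
    # hotswapping: copy new values into the buffer that already exists
    buf = model[name]
    buf[:] = weights

model = {}
load_adapter(model, "default", [0.1, 0.2])
buf_id = id(model["default"])

hotswap_adapter(model, "default", [0.3, 0.4])
assert model["default"] == [0.3, 0.4]
assert id(model["default"]) == buf_id  # same buffer, new values
```

Because the buffers never change identity, a compiled graph that captured them keeps working, which is what avoids recompilation.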
Pass `hotswap=True` when loading a LoRA adapter to enable this feature. It is important to indicate the name of the existing adapter (`"default"` is the default adapter name) to be swapped.
For compiled models, it is often necessary to call [`~integrations.peft.PeftAdapterMixin.enable_peft_hotswap`] to avoid recompilation. Call this method _before_ loading the first adapter, while `torch.compile` should be called _after_ loading the first adapter.
```python
model = AutoModel.from_pretrained(...)
max_rank = ...  # the highest rank among all LoRAs that you want to load

# call *before* compiling and loading the LoRA adapter
model.enable_peft_hotswap(target_rank=max_rank)
```
The `target_rank=max_rank` argument is important for setting the maximum rank among all LoRA adapters that will be loaded. If you have one adapter with rank 8 and another with rank 16, pass `target_rank=16`. You should use a higher value if in doubt. By default, this value is 128.
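One way to see why a single `target_rank` works for adapters of different ranks: a lower-rank LoRA matrix can be zero-padded up to the target rank without changing the LoRA update, since zero rows of `A` (and zero columns of `B`) contribute nothing to `B @ A`. The sketch below illustrates the padding idea only; it is assumed logic for illustration, not PEFT's actual code:

```python
# Illustrative sketch (not PEFT's code): pad a LoRA "A" matrix of shape
# (rank, in_features) with zero rows up to target_rank, so that every
# adapter presents the same tensor shape to the compiled model.

def pad_lora_A(A: list, target_rank: int) -> list:
    rank = len(A)
    if rank > target_rank:
        raise ValueError(f"adapter rank {rank} exceeds target_rank {target_rank}")
    in_features = len(A[0])
    # zero rows do not change the LoRA update B @ A
    return A + [[0.0] * in_features for _ in range(target_rank - rank)]

A8 = [[1.0, 2.0, 3.0] for _ in range(8)]   # rank-8 adapter, in_features=3
A16 = pad_lora_A(A8, target_rank=16)
assert len(A16) == 16                      # padded up to the target rank
assert A16[8:] == [[0.0, 0.0, 0.0]] * 8    # padding rows are all zeros
```

This is also why picking `target_rank` too low fails for higher-rank adapters, while picking it higher than necessary only costs some extra memory.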
By default, hotswapping is disabled and requires you to pass `hotswap=True` to `load_adapter`. However, if you called `enable_peft_hotswap` first, hotswapping will be enabled by default. If you want to avoid using it, you need to pass `hotswap=False`.
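The enablement rule described above can be sketched as a small decision function (illustrative only, not the transformers source; the function name and parameters are made up):

```python
# Sketch of the rule: hotswap defaults to True only when
# enable_peft_hotswap() was called AND an adapter is already loaded,
# because the first adapter must be loaded normally. An explicit
# hotswap=True/False argument always wins.

def should_hotswap(hotswap_flag, hotswap_enabled: bool, num_loaded_adapters: int) -> bool:
    if hotswap_flag is not None:      # user passed hotswap=... explicitly
        return hotswap_flag
    return hotswap_enabled and num_loaded_adapters > 0

assert should_hotswap(None, hotswap_enabled=True, num_loaded_adapters=1) is True
assert should_hotswap(None, hotswap_enabled=True, num_loaded_adapters=0) is False   # first adapter
assert should_hotswap(None, hotswap_enabled=False, num_loaded_adapters=1) is False
assert should_hotswap(False, hotswap_enabled=True, num_loaded_adapters=1) is False  # opt out
```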
However, there can be situations where recompilation is unavoidable. For example, if the hotswapped adapter targets more layers than the initial adapter, then recompilation is triggered. Try to load the adapter that targets the most layers first. Refer to the PEFT docs on [hotswapping](https://huggingface.co/docs/peft/main/en/package_reference/hotswap#peft.utils.hotswap.hotswap_adapter) for more details about the limitations of this feature.
> [!Tip]
> Move your code inside the `with torch._dynamo.config.patch(error_on_recompile=True)` context manager to detect if a model was recompiled. If you detect recompilation despite following all the steps above, please open an issue with [PEFT](https://github.com/huggingface/peft/issues) with a reproducible example.
For an example of how the use of `torch.compile` in combination with hotswapping can improve runtime, check out [this blogpost](https://huggingface.co/blog/lora-fast). Although that example uses Diffusers, similar improvements can be expected here.