docs/source-fabric/advanced/compile.rst (107 additions, 1 deletion)
@@ -115,9 +115,115 @@ always exclude the first call to ``forward()`` from your measurements, since it
Compile median time: 0.0185 seconds
Speedup: 1.4x

----

**********************************************
Apply torch.compile with ModelParallelStrategy
**********************************************

:func:`torch.compile` can also be invoked as part of the `parallelize_fn` argument of :class:`~lightning.fabric.strategies.model_parallel.ModelParallelStrategy`.

This is particularly handy when :func:`torch.compile` is used in combination with the `torch.distributed.tensor` API.

Here is an example:

.. code-block:: python

    import lightning as L
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    from lightning.pytorch.demos import Transformer
    from lightning.fabric.strategies.model_parallel import ModelParallelStrategy
    from torch.distributed._composable.fsdp.fully_shard import fully_shard
    from torch.distributed.device_mesh import DeviceMesh
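    # NOTE: the remainder of this snippet is a minimal sketch, not the full example
    # from the docs. It assumes the demo Transformer model, a single node with two
    # GPUs, and sharding only over the data-parallel mesh dimension (FSDP2).

    def parallelize(model: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        dp_mesh = device_mesh["data_parallel"]
        # Shard each transformer layer individually, then the root module
        for module in model.modules():
            if isinstance(module, (nn.TransformerEncoderLayer, nn.TransformerDecoderLayer)):
                fully_shard(module, mesh=dp_mesh)
        fully_shard(model, mesh=dp_mesh)
        # Compile last, after the model has been distributed
        return torch.compile(model)


    strategy = ModelParallelStrategy(
        parallelize_fn=parallelize, data_parallel_size=2, tensor_parallel_size=1
    )
    fabric = L.Fabric(accelerator="cuda", devices=2, strategy=strategy)
    fabric.launch()

    with fabric.init_module(empty_init=True):
        model = Transformer()

    model = fabric.setup(model)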
For a full example, see our `FP8 Distributed Transformer example <https://github.com/Lightning-AI/lightning/blob/master/examples/fabric/fp8_distributed_transformer>`_.
docs/source-pytorch/advanced/compile.rst (118 additions, 2 deletions)
@@ -138,6 +138,122 @@ always exclude the first call to ``forward()``/``*_step()`` from your measurements
----

**************************************
Apply torch.compile in configure_model
**************************************

:func:`torch.compile` can also be invoked as part of the :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` hook.

This is particularly handy when :func:`torch.compile` is used in combination with :class:`~lightning.pytorch.strategies.model_parallel.ModelParallelStrategy`.

Here is an example:

.. code-block:: python

    import lightning as L
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    from lightning.pytorch.demos import Transformer
    from lightning.pytorch.strategies.model_parallel import ModelParallelStrategy
    from torch.distributed.device_mesh import DeviceMesh
    from torch.distributed._composable.fsdp.fully_shard import fully_shard
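    # NOTE: the remainder of this snippet is a minimal sketch, not the full example
    # from the docs. It assumes the demo Transformer model, a single node with two
    # GPUs, and sharding only over the data-parallel mesh dimension (FSDP2).

    class LanguageModel(L.LightningModule):
        def __init__(self):
            super().__init__()
            self.model = None

        def configure_model(self):
            if self.model is not None:
                return
            model = Transformer()
            # The strategy sets up the device mesh; shard over its data-parallel dimension
            dp_mesh = self.device_mesh["data_parallel"]
            for module in model.modules():
                if isinstance(module, (nn.TransformerEncoderLayer, nn.TransformerDecoderLayer)):
                    fully_shard(module, mesh=dp_mesh)
            fully_shard(model, mesh=dp_mesh)
            # Compile last, after the model has been distributed
            self.model = torch.compile(model)

        def training_step(self, batch):
            input, target = batch
            output = self.model(input, target)
            loss = F.nll_loss(output, target.view(-1))
            self.log("train_loss", loss, prog_bar=True)
            return loss

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-4)


    strategy = ModelParallelStrategy(data_parallel_size=2, tensor_parallel_size=1)
    trainer = L.Trainer(accelerator="cuda", devices=2, strategy=strategy)
    # trainer.fit(LanguageModel(), train_dataloader)  # supply your own DataLoader here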
For a full example, see our `FP8 Distributed Transformer example <https://github.com/Lightning-AI/lightning/blob/master/examples/pytorch/fp8_distributed_transformer>`_.

----

******************
Avoid graph breaks
@@ -253,8 +369,8 @@ Limitations
There are a few limitations you should be aware of when using ``torch.compile`` **in conjunction with the Trainer**:

* The Trainer currently does not reapply ``torch.compile`` over :class:`~lightning.pytorch.strategies.DDPStrategy` and :class:`~lightning.pytorch.strategies.FSDPStrategy`, meaning distributed operations can't benefit from speed ups at the moment.
  This limitation can be avoided by using :class:`~lightning.pytorch.strategies.model_parallel.ModelParallelStrategy`, as described in `Apply torch.compile in configure_model`_ above.

* In some cases, using ``self.log()`` in your LightningModule will cause compilation errors.
  Until addressed, you can work around these issues by applying ``torch.compile`` to the submodule(s) of your LightningModule rather than to the entire LightningModule at once, as sketched below.
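For illustration, that workaround might look like the following sketch (not taken from the docs; the submodule and loss below are placeholders):

.. code-block:: python

    import torch
    import lightning as L


    class LitModel(L.LightningModule):
        def __init__(self, submodule: torch.nn.Module):
            super().__init__()
            # Compile only the submodule; calls to ``self.log()`` in the hooks of the
            # (uncompiled) LightningModule then stay outside the compiled region.
            self.model = torch.compile(submodule)

        def training_step(self, batch, batch_idx):
            loss = self.model(batch).mean()  # placeholder loss
            self.log("train_loss", loss)
            return loss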
This example shows how to use `ModelParallelStrategy` in `Fabric` to train a Transformer model while minimizing memory usage, maximizing throughput, and distributing load across multiple GPUs.

### Training Large Models and Memory Requirements

One of the main challenges when training large models, like large language models (LLMs), is dealing with their memory footprint. LLMs can be so large that weights, activations, gradients, and optimizer state don't fit on a single GPU, so they need to be distributed across multiple GPUs, and across multiple machines. There are multiple ways of distributing computations, among them fully-sharded data parallelism (FSDP) and tensor parallelism (TP).

An additional way of reducing memory requirements is representing floating point numbers in weights and activations in low numerical precision, such as 16-bit (`bfloat16`) or 8-bit (`fp8`). This leads to savings in memory usage, as well as in memory bandwidth usage (fewer bytes transferred from device memory to GPU cores per unit time).

Roughly, reducing precision to `fp8` for linear layers can lead to a 2x reduction in memory requirements and a 1.6x improvement in throughput. Support for `fp8` weights and activations requires recent GPUs - Hopper, Ada Lovelace and above (e.g. H100, L4, L40).

The introduction of tensor subclasses in PyTorch brought two new APIs that can be used in combination to achieve memory savings and distributed training (as well as inference):

- [torch ao](https://github.com/pytorch/ao) to execute linear layers in low numerical precision (`fp8` and other quantized formats)
- [dtensors](https://pytorch.org/docs/stable/distributed.tensor.html) to distribute models across GPUs by combining TP and FSDP (referred to as FSDP2 in PyTorch)

Notably, `torch ao` introduces quantization and dequantization operations in the model that may result in slow-downs if not optimized. Using `torch.compile` after `torch ao` recovers performance by generating optimized kernels for those operations.
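The order of operations therefore matters: quantize first (torch ao), then distribute (FSDP2), then compile. As a rough sketch of the pattern (using `torchao`'s `convert_to_float8_training` and the FSDP2 `fully_shard` API; the function name and the module filter below are illustrative, not the exact code from this example):

```python
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torchao.float8 import Float8LinearConfig, convert_to_float8_training


def prepare(model: nn.Module, dp_mesh) -> nn.Module:
    # 1) torch ao: swap eligible nn.Linear layers to fp8 training
    convert_to_float8_training(
        model,
        config=Float8LinearConfig(pad_inner_dim=True),
        module_filter_fn=lambda mod, fqn: fqn != "lm_head",  # illustrative filter
    )
    # 2) FSDP2: shard parameters across the data-parallel device mesh
    fully_shard(model, mesh=dp_mesh)
    # 3) compile: generate fused kernels for the quantize/dequantize ops
    return torch.compile(model)
```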
### Vanilla Transformer Example

This example shows how to train a vanilla Transformer model using `fp8` precision and the FSDP2 distributed strategy, and then optimize the resulting model through `torch.compile`.

Specifically, we employ the `ModelParallelStrategy` and use the `configure_model` hook to distribute the model using the PyTorch DTensor API.
In the same hook we also pass the model through the `torch ao` API (prior to FSDP2), as well as `torch.compile` (after FSDP2).

The resulting code follows the PyTorch API closely, while also taking advantage of the rest of PyTorch Lightning.

To execute the code directly, just run:

```bash
python train.py
```

### A Note on torch.compile

Note that PyTorch Lightning also supports calling `torch.compile` on a `LightningModule` and passing it to the `Trainer`.

While this works for simple cases, in order to get the most out of the combination of the latest distributed, quantization, and compile PyTorch APIs, we recommend invoking `torch.compile` at the end of the `configure_model` hook, as shown in this example.