We provide a training script [train_ddp_cp](./esm2_native_te/train_ddp_cp.py) and a sample config [L0_sanity_cp](./hydra_config/L0_sanity_cp.yaml) that uses context parallelism.

In the config, the `--cp_size` argument lets the user set the size of the context parallel group. When paired with Distributed Data Parallelism (DDP), the number of context parallel groups is `world_size // cp_size`.

Thus, for example, if a user has 8 processes and sets `cp_size=2`, they will have 4 DDP groups, each of which forms a CP group of 2 ranks. During dataloading we make no assumptions about whether the data pipeline is deterministic. We simply feed unique data to each DDP group and select the relevant CP shard for each rank within its CP group.
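As a minimal sketch of this arithmetic (assuming the common layout where each contiguous block of `cp_size` ranks forms one CP group; the actual group construction in `train_ddp_cp.py` may differ):

```python
# Minimal sketch of the DDP/CP bookkeeping described above (not the repo's code).
# Assumes rank blocks [0,1], [2,3], ... each form one CP group / DDP data stream.
world_size = 8
cp_size = 2

num_ddp_groups = world_size // cp_size  # 4 unique data streams in this example
for rank in range(world_size):
    ddp_group = rank // cp_size  # which dataloader this rank draws from
    cp_rank = rank % cp_size     # this rank's position inside its CP group
    print(f"rank {rank}: DDP group {ddp_group}, CP rank {cp_rank}")
```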
For example, let's say that we have 2 DDP groups with `cp_size=2`, so each DDP group contains 2 CP ranks (CP0 and CP1 below). Each DDP group has its own dataloader: DP0 for DDP group 0 and DP1 for DDP group 1. CP works by running ring attention, which expects the tokens on each device to follow a particular layout. For this CP implementation we use [Dual Chunk Swapping](https://github.com/NVIDIA/TransformerEngine/blob/1df4a69f761672f633d40ea3605327087d1ea737/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py#L3714-L3770). If DP0 outputs sequence `1 2 3 4 5 6 7 8` and DP1 outputs `9 10 11 12 13 14 15 16`, then the `CPAwareDataloader` defined in [datasets](./dataset.py) will create CP shards from each DP group as follows:
```
    | DP0     | DP1            |
CP0 | 1,2,7,8 | 9, 10, 15, 16  |
CP1 | 3,4,5,6 | 11, 12, 13, 14 |
```
You may look at these shards and wonder why they are laid out the way they are. We did. The reason is that the CP shards are built from slices: the full input sequence (such as `1 2 3 4 5 6 7 8`) is sliced into `2 * cp_size` chunks, and CP0 takes the first and last chunks while CP1 takes the middle chunks of each sequence.

In this example we only show one sequence per DP group, but it's important to note that this slicing is applied to every sequence: if a second sequence is also available, it will be sliced in the same manner, with CP0 taking the first and last chunks and CP1 taking the middle chunks of every sequence.
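The chunking pattern can be sketched as follows (an illustration only; `cp_shard` is a hypothetical helper rather than the `CPAwareDataloader` API, and it assumes the sequence length divides evenly into `2 * cp_size` chunks):

```python
import torch

def cp_shard(seq: torch.Tensor, cp_size: int, cp_rank: int) -> torch.Tensor:
    """Split `seq` into 2 * cp_size chunks; CP rank r keeps chunk r and its
    mirror chunk (2 * cp_size - 1 - r), reproducing the table above."""
    chunks = seq.chunk(2 * cp_size, dim=0)
    return torch.cat([chunks[cp_rank], chunks[2 * cp_size - 1 - cp_rank]], dim=0)

seq = torch.arange(1, 9)                    # DP0's sequence: 1 2 3 4 5 6 7 8
print(cp_shard(seq, cp_size=2, cp_rank=0))  # tensor([1, 2, 7, 8])
print(cp_shard(seq, cp_size=2, cp_rank=1))  # tensor([3, 4, 5, 6])
```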
### Comparing Against the HF Transformers Reference Implementation
To launch training with the ESM-2 model as implemented in HF Transformers, pass a `facebook/esm2` checkpoint as the