
Commit 74a05ae

adds Context Parallel README

Signed-off-by: Jonathan Mitchell <[email protected]>

1 parent d32ed9e

2 files changed (+60, -2 lines)

bionemo-recipes/README.md

Lines changed: 2 additions & 2 deletions
@@ -29,7 +29,7 @@ The biological AI community is actively prototyping model architectures and need

| Directory | Description | FSDP | BF16 | FP8<sup>[1]</sup> | THD | FP8 + THD | MXFP8<sup>[2]</sup> | NVFP4<sup>[2]</sup> | CP |
| ------------------------------------------------------------------------------------------------ | ----------------------------------------------------------------------------------------------------------------------------- | ------------ | ---- | ----------------- | --- | --------- | ------------------- | ------------------- | --- |
| `models/amplify`,<br> [available on Hugging Face](https://huggingface.co/nvidia/AMPLIFY_350M) | TE accelerated protein BERT, [Amgen](https://www.biorxiv.org/content/10.1101/2024.09.23.614603v1) ||||| 🚧 ||||
- | `models/esm2`,<br> [available on Hugging Face](https://huggingface.co/nvidia/esm2_t48_15B_UR50D) | TE accelerated protein BERT, [Meta](https://www.biorxiv.org/content/10.1101/2022.07.20.500902v1) ||||||| 🚧 | 🚧 |
+ | `models/esm2`,<br> [available on Hugging Face](https://huggingface.co/nvidia/esm2_t48_15B_UR50D) | TE accelerated protein BERT, [Meta](https://www.biorxiv.org/content/10.1101/2022.07.20.500902v1) ||||||| 🚧 | |
| `recipes/`<br>`codonfm_ptl_te` | Recipe for [CodonFM](https://research.nvidia.com/labs/dbr/assets/data/manuscripts/nv-codonfm-preprint.pdf)'s Encodon using TE | 🚧 || 🚧 || 🚧 | 🚧 | 🚧 | 🚧 |
| `recipes/`<br>`esm2_accelerate_te` | Recipe for `esm2/amplify` TE + HF Accelerate | 🚧 ||| 🚧 || 🚧 | 🚧 | 🚧 |
| `recipes/`<br>`esm2_native_te` | Recipe for `esm2/amplify` + native PyTorch | mFSDP, FSDP2 |||||| 🚧 | 🚧 |
@@ -48,7 +48,7 @@ Abbreviations:

- MXFP8<sup>[2]</sup>: [Multi Scale 8-bit floating point](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html), as compact as FP8 but with better numerical precision.
- NVFP4<sup>[2]</sup>: [NVIDIA 4-bit floating point](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#Beyond-FP8---training-with-NVFP4), faster than FP8, retaining accuracy via multi-scale.
- THD: **T**otal **H**eads **D**imension, also known as ["sequence packing"](https://docs.nvidia.com/nemo-framework/user-guide/24.07/nemotoolkit/features/optimizations/sequence_packing.html#sequence-packing-for-sft-peft). A way to construct a batch with sequences of different length so there are no pads, therefore no compute is wasted on computing attention for padding tokens. This is in contrast to **B**atch **S**equence **H**ead **D**imension (BSHD) format, which uses pads to create a rectangular batch.
- - CP: Context parallel, also known as sequence parallel. A way to distribute the memory required to process long sequences across multiple GPUs.
+ - CP: Context parallel, also known as sequence parallel. A way to distribute the memory required to process long sequences across multiple GPUs. For more information, please see [context parallel](./recipes/context_parallel.md).

\[1\]: Requires [compute capability](https://developer.nvidia.com/cuda-gpus) 9.0 and above (Hopper+) <br/>
\[2\]: Requires [compute capability](https://developer.nvidia.com/cuda-gpus) 10.0 and 10.3 (Blackwell), 12.0 support pending <br/>
bionemo-recipes/recipes/context_parallel.md

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
# Context Parallelism

## What is it

When training transformer-based models, context is everything: it is what tells the model to look at each token in relation to all the other tokens. The model's ability to establish context is paramount for many LLM tasks because it "grounds" each prediction in the surrounding sequence.

But what happens when you can't make that context window any bigger on a single device? Enter Context Parallelism (CP). CP parallelizes the context across multiple GPUs: sequences are sharded and split up so that no single GPU has to hold the entire context in memory; the GPUs share the load.

In short, Context Parallelism distributes sequences across devices. It's one of the "Ds" in what's known as 5D parallelism (Tensor Parallel, Pipeline Parallel, Data Parallel, Expert Parallel, Context Parallel).

CP acts very similarly to Data Parallelism in that the activations for input tokens are distributed across devices. The key difference is that in CP, the activations on different devices belong to the same input sequence, whereas in Data Parallelism they belong to different sequences.

## How does it work?

The core idea behind CP is to partition each input sequence into chunks, with each chunk assigned to a different device (a GPU in this case). During the forward pass, each device computes attention locally on its chunk while coordinating with the other devices to access the key-value pairs needed for the full attention computation.

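
To make "local attention plus coordinated access to key-value pairs" concrete, here is a minimal single-process sketch (ours, not the Transformer Engine kernel): each simulated CP rank owns one query chunk and visits the key/value chunks one at a time, keeping running softmax statistics so the exact full-attention result is recovered without any rank ever holding the whole score matrix.

```python
# Minimal single-process sketch of context-parallel attention (illustrative only,
# not the Transformer Engine kernel). Each simulated "rank" owns one query chunk
# and sees the key/value chunks one at a time, accumulating softmax stats online.
import torch

torch.manual_seed(0)
cp_size, seq_len, d = 4, 16, 8
q = torch.randn(seq_len, d)
k = torch.randn(seq_len, d)
v = torch.randn(seq_len, d)

# Shard the sequence into cp_size contiguous chunks (one per simulated CP rank).
q_chunks, k_chunks, v_chunks = q.chunk(cp_size), k.chunk(cp_size), v.chunk(cp_size)


def attend_local(q_local, kv_chunks, scale):
    """Attention for one query chunk, visiting KV chunks one at a time."""
    m = torch.full((q_local.shape[0],), float("-inf"))      # running row max
    num = torch.zeros(q_local.shape[0], q_local.shape[1])   # running numerator
    den = torch.zeros(q_local.shape[0])                     # running denominator
    for k_j, v_j in kv_chunks:  # in a real setup these arrive via ring send/recv
        s = (q_local @ k_j.T) * scale
        m_new = torch.maximum(m, s.max(dim=-1).values)
        alpha = torch.exp(m - m_new)                         # rescale old partial sums
        p = torch.exp(s - m_new[:, None])
        num = num * alpha[:, None] + p @ v_j
        den = den * alpha + p.sum(dim=-1)
        m = m_new
    return num / den[:, None]


scale = d ** -0.5
out_cp = torch.cat(
    [attend_local(q_r, list(zip(k_chunks, v_chunks)), scale) for q_r in q_chunks]
)

# Reference: ordinary full attention over the unsharded sequence.
ref = torch.softmax((q @ k.T) * scale, dim=-1) @ v
assert torch.allclose(out_cp, ref, atol=1e-5)
```
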
## What does the data generation part look like?

In BioNeMo, we've created some abstractions to partition the data for you. There is a [CPAwareDataloader](esm2_native_te/dataset.py) that shards the data for you and sends each shard to its device. It operates on sequence-packed (THD) data ([link](https://docs.nvidia.com/nemo-framework/user-guide/24.12/nemotoolkit/features/optimizations/sequence_packing.html)). The `CPAwareDataloader` takes your CP group and local CP rank as arguments, calls its underlying dataloader to generate unique sequences, and then shards those sequences across the ranks of your CP group. This is beneficial because you don't need to maintain a deterministic data pipeline: unique data only needs to be generated across the non-CP ranks, and it is replicated (and then sharded) across the ranks within each CP group. More details below.

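
The CP group and local CP rank that the dataloader takes are ordinary `torch.distributed` constructs. The sketch below (not BioNeMo code; launch with `torchrun`, and the `cp_size` value and group layout are illustrative) shows one way to carve the world into CP subgroups; ranks that share a `dp_rank` must see the same underlying sequences, and the CP-aware dataloader then hands each `cp_rank` its shard.

```python
# Sketch: building context-parallel subgroups with plain torch.distributed.
# Run with, e.g.: torchrun --nproc_per_node=4 cp_groups.py
# Group layout (cp_size=2): ranks {0, 1} form one CP group, {2, 3} another.
import torch.distributed as dist

dist.init_process_group(backend="gloo")  # "nccl" on GPU nodes
world_size = dist.get_world_size()
rank = dist.get_rank()
cp_size = 2
assert world_size % cp_size == 0

cp_group = None
for start in range(0, world_size, cp_size):
    ranks = list(range(start, start + cp_size))
    group = dist.new_group(ranks=ranks)  # every rank must call new_group for every group
    if rank in ranks:
        cp_group = group

cp_rank = dist.get_rank(group=cp_group)  # this rank's position within its CP group
dp_rank = rank // cp_size                # which "unique data" stream this rank belongs to
print(f"global rank {rank}: cp_rank={cp_rank}, dp_rank={dp_rank}")
# Ranks with the same dp_rank must see the same underlying sequences;
# a CP-aware dataloader then hands each cp_rank its shard of those sequences.
```
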
Alternatively, you could use any dataloader, such as the canonical [PyTorch DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader); however, you would then have to ensure that your dataset is synchronized across CP ranks. If your data pipeline is non-deterministic, even requesting the same batch on every CP rank may yield different data because of non-deterministic preprocessing stages such as masking. For more information on preserving determinism in your datasets, please see [MegatronLMDataModule](https://nvidia.github.io/bionemo-framework/main/about/background/megatron_datasets/).

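
As one sketch of that synchronization (ours, not the recipe's code), you can drive every source of randomness, including the sampler and the masking, from a seed that is shared by all ranks within a CP group but differs across data-parallel ranks:

```python
# Sketch: keeping random preprocessing (e.g. masking) identical across CP ranks
# by seeding every source of randomness from the data-parallel rank only.
import torch
from torch.utils.data import DataLoader, Dataset


class ToyMaskedDataset(Dataset):
    """Stand-in dataset whose random masking is driven by an explicit generator."""

    def __init__(self, seed: int):
        self.generator = torch.Generator().manual_seed(seed)

    def __len__(self):
        return 1000

    def __getitem__(self, idx):
        tokens = torch.arange(idx, idx + 8)
        # Random masking: identical on every rank that used the same seed.
        mask = torch.rand(8, generator=self.generator) < 0.15
        return {"input_ids": tokens, "mask": mask}


base_seed = 1234
dp_rank = 0  # in practice: global_rank // cp_size, as in the group sketch above
shared_seed = base_seed + dp_rank  # same value for every rank in one CP group

dataset = ToyMaskedDataset(seed=shared_seed)
loader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    # The shuffling order must also match across CP ranks, so seed it explicitly.
    generator=torch.Generator().manual_seed(shared_seed),
)
batch = next(iter(loader))  # identical on all CP ranks that share `shared_seed`
```
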
### Context Parallelism Sharding Example

**Original packed sequences (2 seqs):**

```
┌────────────────┐
│ 1, 2, 3 | 5, 6 │
└────────────────┘
```

**Pad to divisibility:**

```
┌─────────────────────────────────────┐
│ 1, 2, 3, <pad> | 5, 6, <pad>, <pad> │
└─────────────────────────────────────┘
```

**Distributed across CP ranks:**

```
CP0: [1, <pad> | 5, <pad>]
CP1: [2, 3 | 6, <pad>]
```

In the example above, imagine that we have one CP group of size 2, containing ranks CP0 and CP1. The `CPAwareDataloader` takes as an argument an `UnderlyingDataloader`, which generates the unique sequences `1, 2, 3` and `5, 6`. In CP we need to pad these sequences so that their lengths are divisible by `cp_size * 2` to enable chunking for [Ring Attention](https://arxiv.org/abs/2310.01889). In this case, with `cp_size=2`, we need to make each sequence divisible by 4.

After we've padded the sequences, we distribute the shards across the CP ranks (CP0 and CP1). Each CP rank takes a slice from both the first and the second sequence: CP0 takes the first and last token of each padded sequence, while CP1 takes the middle tokens.

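
The following self-contained sketch (ours, not the `CPAwareDataloader` source) reproduces the sharding above: pad each packed sequence to a multiple of `2 * cp_size`, split it into `2 * cp_size` chunks, and give rank `r` chunk `r` together with its mirror chunk `2 * cp_size - 1 - r`:

```python
# Sketch: padding and context-parallel sharding matching the example above.
# Each sequence is padded to a multiple of 2 * cp_size, split into 2 * cp_size
# chunks, and CP rank r keeps chunk r plus its mirror chunk (2 * cp_size - 1 - r).
PAD = "<pad>"


def pad_sequence(seq, cp_size):
    multiple = 2 * cp_size
    remainder = len(seq) % multiple
    return seq + [PAD] * ((multiple - remainder) % multiple)


def shard_for_rank(seq, cp_size, cp_rank):
    padded = pad_sequence(seq, cp_size)
    chunk_len = len(padded) // (2 * cp_size)
    chunks = [padded[i * chunk_len:(i + 1) * chunk_len] for i in range(2 * cp_size)]
    return chunks[cp_rank] + chunks[2 * cp_size - 1 - cp_rank]


packed = [["1", "2", "3"], ["5", "6"]]  # the two packed sequences from the example
cp_size = 2
for cp_rank in range(cp_size):
    shards = [shard_for_rank(seq, cp_size, cp_rank) for seq in packed]
    print(f"CP{cp_rank}:", " | ".join(", ".join(s) for s in shards))
# Prints:
# CP0: 1, <pad> | 5, <pad>
# CP1: 2, 3 | 6, <pad>
```
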
After ring attention, the activations will also be sharded across those CP ranks, so no single device has to hold all of them!

### Resources

For more information related to Context Parallelism, please see our recipes:

- [esm2/train_ddp_cp.py](esm2_native_te/train_ddp_cp.py)
