
Commit ce56fce

Jonathan Mitchell authored and committed
Context Parallel Squashed MR for recipes
- removed models dir

Signed-off-by: Jonathan Mitchell <[email protected]>
1 parent e21115d commit ce56fce

File tree

9 files changed: +1457, -12 lines


bionemo-recipes/models/esm2/tests/test_cp.py

Lines changed: 387 additions & 0 deletions
Large diffs are not rendered by default.

bionemo-recipes/recipes/esm2_native_te/README.md

Lines changed: 22 additions & 1 deletion
@@ -17,7 +17,7 @@ bionemo-framework repository. You can download a zipped directory of this folder

| Model | BF16 | FP8<sup>[1]</sup> | THD Input Format | FP8 with THD Input Format | MXFP8<sup>[2]</sup> | Context Parallelism |
| ----------------------------------------- | ---- | ----------------- | ---------------- | ------------------------- | ------------------- | ------------------- |
-| [ESM-2](../../models/esm2/README.md) |||||| 🚧 |
+| [ESM-2](../../models/esm2/README.md) |||||| ✅ |
| [AMPLIFY](../../models/amplify/README.md) ||| 🚧 ||| 🚧 |

✅: Supported <br/>
@@ -88,6 +88,27 @@ python train_fsdp2.py --config-name L0_sanity \
    use_sequence_packing=true
```

### Context Parallelism

We provide a training script [train_ddp_cp](./train_ddp_cp.py) and a sample config [L0_sanity_cp](./hydra_config/L0_sanity_cp.yaml) that use context parallelism.

In the config, the `cp_size` argument sets the size of each context parallel group. When paired with Distributed Data Parallelism (DDP), the number of DDP groups (each with its own dataloader) is `world_size // cp_size`, and the ranks within each DDP group form one context parallel group of size `cp_size`.

For example, if a user has 8 processes and sets `cp_size=2`, they will have 4 DDP groups, each containing a CP group of 2 ranks. During dataloading we make no assumptions about whether the data pipeline is deterministic. We simply give each DDP group its own unique data and then select the relevant CP shard for each rank of its CP group.
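
As a rough sketch of the resulting layout (illustrative only; the actual group setup lives in `train_ddp_cp.py` and may differ), the relationship between DDP replicas and CP groups can be pictured as a 2-D device mesh:

```python
from torch.distributed.device_mesh import init_device_mesh

# Illustrative sketch: assumes this runs under torchrun with 8 GPUs and that
# cp_size evenly divides the world size.
world_size = 8
cp_size = 2
dp_size = world_size // cp_size  # 4 DDP groups, each with its own dataloader

# Rows of the mesh are DDP replicas; columns are the CP ranks inside one replica.
mesh = init_device_mesh("cuda", (dp_size, cp_size), mesh_dim_names=("dp", "cp"))

dp_group = mesh.get_group("dp")  # gradients are averaged across this group
cp_group = mesh.get_group("cp")  # ring attention / dual chunk swapping use this group
```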

For example, suppose we have 2 DDP groups, each with a CP group of 2 ranks. Each DDP group has its own dataloader: DP0 for DDP group 0 and DP1 for DDP group 1. CP works by running ring attention, which expects the tokens of each sequence to be laid out across the devices of a CP group in a particular way. This implementation uses [Dual Chunk Swapping](https://github.com/NVIDIA/TransformerEngine/blob/1df4a69f761672f633d40ea3605327087d1ea737/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py#L3714-L3770). If DP0 outputs the sequence `1 2 3 4 5 6 7 8` and DP1 outputs `9 10 11 12 13 14 15 16`, then when these batches run through the `CPAwareDataloader` defined in [dataset.py](./dataset.py), the dataloader creates CP shards from each DP group's batch as follows:

```
    |   DP0   |      DP1       |
CP0 | 1,2,7,8 | 9, 10, 15, 16  |
CP1 | 3,4,5,6 | 11, 12, 13, 14 |
```

You may look at these shards and wonder why they are laid out this way. We did. The reason is that each sequence is sharded across CP ranks using slices: the full input sequence (such as `1 2 3 4 5 6 7 8`) is sliced into `2 * cp_size` equal chunks. CP0 then takes the first and last chunk, while CP1 takes the middle chunks of each sequence. (In general, CP rank `r` takes chunk `r` and chunk `2 * cp_size - 1 - r`.)

In this example we only show one sequence, but it is important to note that the slicing is applied to every sequence: if a second sequence is present, it is sliced in the same manner, with CP0 taking the first and last chunk of every sequence and CP1 the middle chunks.
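
To make the slicing concrete, the standalone snippet below (illustrative, not part of the recipe) reproduces the CP0/CP1 shards in the table above for `cp_size=2`:

```python
import torch

cp_size = 2
seq = torch.arange(1, 9)        # tokens 1..8 from dataloader DP0
num_chunks = 2 * cp_size        # each sequence is cut into 2 * cp_size equal chunks
chunks = seq.chunk(num_chunks)  # (1,2), (3,4), (5,6), (7,8)

for cp_rank in range(cp_size):
    # CP rank r takes chunk r and its mirror chunk (num_chunks - 1 - r).
    shard = torch.cat([chunks[cp_rank], chunks[num_chunks - 1 - cp_rank]])
    print(f"CP{cp_rank}: {shard.tolist()}")

# CP0: [1, 2, 7, 8]
# CP1: [3, 4, 5, 6]
```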

### Comparing Against the HF Transformers Reference Implementation

To launch training with the ESM-2 model as implemented in HF Transformers, pass a `facebook/esm2` checkpoint as the

bionemo-recipes/recipes/esm2_native_te/collator.py

Lines changed: 163 additions & 1 deletion
@@ -20,12 +20,14 @@

import logging
from dataclasses import dataclass
-from typing import Any
+from typing import Any, Callable, Optional

import datasets
import torch
from transformers import DataCollatorForLanguageModeling, DefaultDataCollator, PreTrainedTokenizerBase

from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import pad_thd_sequences_for_cp


logger = logging.getLogger(__name__)

@@ -38,6 +40,7 @@ class MLMDataCollatorWithFlattening:
    2. Then applying MLM masking to the flattened sequence
    3. Providing Flash Attention metadata (cu_seq_lens) for sequence boundary awareness
    4. Optionally padding the total sequence length to be divisible by a specified number
    5. Optionally padding each individual sequence to be divisible by a specified number

    The result is a THD-format batch optimized for Flash Attention with sequence packing,
    eliminating the need for traditional attention masks while maintaining proper sequence
@@ -62,6 +65,9 @@ class MLMDataCollatorWithFlattening:
        seed (int | None): Random seed for reproducible masking. Defaults to None.
        pad_to_multiple_of (int | None): If set, pads the total sequence length to be divisible
            by this number by adding a mock sequence at the end. Defaults to None.
        pad_sequences_to_be_divisible_by (int | None): If set, pads each individual sequence to be
            divisible by this number by appending padding tokens whose labels are set to -100.
            Defaults to None. This is used by context parallelism.

    Example:
        >>> from transformers import AutoTokenizer
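
As a usage sketch (hedged: it assumes the constructor defaults shown in this diff and an illustrative ESM-2 tokenizer checkpoint), `pad_sequences_to_be_divisible_by` would typically be set to `2 * cp_size` so that every packed sequence splits evenly into the dual chunks used by context parallelism:

```python
from transformers import AutoTokenizer

cp_size = 2
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")  # illustrative checkpoint

collator = MLMDataCollatorWithFlattening(
    tokenizer=tokenizer,
    pad_sequences_to_be_divisible_by=2 * cp_size,
)

batch = collator([{"input_ids": tokenizer(seq)["input_ids"]} for seq in ("MKTAYIAKQR", "GAVLIPFW")])

# Every entry of cu_seq_lens_q_padded is now a multiple of 2 * cp_size, so each
# packed sequence can later be cut into 2 * cp_size equal slices.
print(batch["cu_seq_lens_q_padded"])
```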
@@ -111,6 +117,7 @@ def __init__(
        return_position_ids: bool = False,
        bshd_equivalent: bool = False,
        bshd_pad_to_multiple_of: int | None = None,
        pad_sequences_to_be_divisible_by: int | None = None,
    ):
        """Initialize the MLMDataCollatorWithFlattening.
@@ -129,6 +136,9 @@ def __init__(
                collator, at the expense of additional computation time. Defaults to False.
            bshd_pad_to_multiple_of (int | None): For the bshd_equivalent mode, mimics padding that would be done by the
                BSHD collator. Defaults to None.
            pad_sequences_to_be_divisible_by (int | None): If set, pads each individual sequence to be
                divisible by this number by appending padding tokens whose labels are set to -100.
                Defaults to None. This is used by context parallelism.
        """
        self.mlm_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
@@ -145,6 +155,10 @@ def __init__(
        self.return_position_ids = return_position_ids
        self.bshd_equivalent = bshd_equivalent
        self.bshd_pad_to_multiple_of = bshd_pad_to_multiple_of
        self.pad_sequences_to_be_divisible_by = pad_sequences_to_be_divisible_by

        if self.pad_sequences_to_be_divisible_by is not None and self.pad_to_multiple_of is not None:
            raise ValueError("pad_sequences_to_be_divisible_by and pad_to_multiple_of cannot be used together")

        if bshd_pad_to_multiple_of is not None and not bshd_equivalent:
            raise ValueError("bshd_pad_to_multiple_of can only be used when bshd_equivalent is True")
@@ -227,6 +241,23 @@ def __call__(self, features, return_tensors=None):
        if self.pad_to_multiple_of is not None:
            batch = self._pad_batch_to_multiple_of(batch)

        elif self.pad_sequences_to_be_divisible_by is not None:
            input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp(
                batch["input_ids"],
                batch["labels"],
                batch["cu_seq_lens_q"],
                self.pad_sequences_to_be_divisible_by,
                padding_token_id=int(self.mlm_collator.tokenizer.pad_token_id),
                padding_label_id=-100,
            )
            batch["input_ids"] = input_ids_padded.unsqueeze(0)
            batch["labels"] = labels_padded.unsqueeze(0)
            batch["cu_seq_lens_q_padded"] = cu_seqlens_padded.to(torch.int32)
            batch["cu_seq_lens_k_padded"] = cu_seqlens_padded.to(torch.int32)

        return batch

    def bshd_compatible_call(self, features, return_tensors=None):
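
For intuition, the snippet below is a toy, self-contained illustration of what this per-sequence padding produces (it is not the TransformerEngine implementation; the real code path calls `pad_thd_sequences_for_cp`, imported above):

```python
import torch


def pad_each_sequence_to_multiple(input_ids, labels, cu_seqlens, multiple, pad_id, pad_label=-100):
    """Toy illustration: pad every packed sequence up to the next multiple of `multiple`."""
    ids_out, labels_out, cu_out = [], [], [0]
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        length = end - start
        padded_len = -(-length // multiple) * multiple  # ceil to the next multiple
        pad = padded_len - length
        ids_out.append(torch.cat([input_ids[start:end], torch.full((pad,), pad_id)]))
        labels_out.append(torch.cat([labels[start:end], torch.full((pad,), pad_label)]))
        cu_out.append(cu_out[-1] + padded_len)
    return torch.cat(ids_out), torch.cat(labels_out), torch.tensor(cu_out, dtype=torch.int32)


# Two packed sequences of lengths 5 and 3, padded so each divides by 4 (2 * cp_size for cp_size=2).
ids = torch.arange(1, 9)
labels = torch.full((8,), -100)
cu = torch.tensor([0, 5, 8])
ids_p, labels_p, cu_p = pad_each_sequence_to_multiple(ids, labels, cu, multiple=4, pad_id=0)
print(cu_p.tolist())  # [0, 8, 12]: per-sequence lengths go from 5 -> 8 and 3 -> 4
```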
@@ -269,6 +300,36 @@ def _pad_batch_to_multiple_of(self, batch):
        )


class MLMDataCollatorWithFlatteningCPAware:
    """A collator that is aware of context parallelism."""

    def __init__(self, collator: MLMDataCollatorWithFlattening, cp_world_size: int):
        self.collator = collator
        self.cp_world_size = cp_world_size

    def __call__(self, features):
        batch = self.collator(features)

        combined_batch = []
        for cp_rank in range(self.cp_world_size):
            input_ids_sharded, labels_sharded = split_batch_by_cp_rank(
                cu_seqlens_padded=batch["cu_seq_lens_q_padded"],
                input_ids_padded=batch["input_ids"],
                labels_padded=batch["labels"],
                qvk_format="thd",
                cp_rank=cp_rank,
                cp_world_size=self.cp_world_size,
            )
            batch_shard = dict(batch)
            batch_shard["input_ids"] = input_ids_sharded
            batch_shard["labels"] = labels_sharded
            # Now determine the max sequence length, rounded up to a multiple of 64.
            seqlens_q = batch_shard["cu_seq_lens_q_padded"][1:] - batch_shard["cu_seq_lens_q_padded"][:-1]
            batch_shard["max_length_q"] = int((seqlens_q.max().item() + 63) // 64 * 64)  # TODO(@jomitchell): Not sure if I need this anymore.
            batch_shard["max_length_k"] = batch_shard["max_length_q"]
            combined_batch.append(batch_shard)

        return combined_batch  # [<cp_rank_0_shard>, <cp_rank_1_shard>, ..., <cp_rank_n_shard>]
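
A hypothetical wiring sketch follows (the actual wiring lives in `dataset.py` and `train_ddp_cp.py`, and this assumes the collator defaults shown in this diff): the wrapper produces one shard per CP rank, and rank `r` keeps `shards[r]`.

```python
from transformers import AutoTokenizer

cp_world_size = 2
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")  # illustrative checkpoint

base_collator = MLMDataCollatorWithFlattening(
    tokenizer=tokenizer,
    pad_sequences_to_be_divisible_by=2 * cp_world_size,
)
cp_collator = MLMDataCollatorWithFlatteningCPAware(base_collator, cp_world_size=cp_world_size)

features = [{"input_ids": tokenizer(seq)["input_ids"]} for seq in ("MKTAYIAKQR", "GAVLIPFW")]
shards = cp_collator(features)

# One dict per CP rank; rank r would feed shards[r] to the TE model on its GPU.
assert len(shards) == cp_world_size
print(shards[0]["input_ids"].shape, shards[1]["input_ids"].shape)
```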

@dataclass
class DataCollatorWithFlattening(DefaultDataCollator):
    """Data collator for sequence packing with flash attentions cu_seqlens-style attention.
@@ -441,3 +502,104 @@ def _pt_pad_to_multiple_of(batch: dict[str, Any], pad_to_multiple_of: int, token
    )

    return batch


# TODO(@jomitchell): Once this gets merged: https://github.com/NVIDIA/TransformerEngine/pull/2387
# we can replace this with the one in TransformerEngine.
def split_batch_by_cp_rank(
    cu_seqlens_padded: torch.Tensor,
    input_ids_padded: torch.Tensor,
    labels_padded: torch.Tensor,
    cp_group: Optional[torch.distributed.ProcessGroup] = None,
    qvk_format: str = "thd",
    cp_rank: Optional[int] = None,
    cp_world_size: Optional[int] = None,
):
    """Slice a THD-format batch along the sequence dimension into chunks that are parallelized across GPUs in a context parallel group.

    This function is intended for use in self attention. It will not work for cross attention because
    it does not handle the case where the sequence lengths of the query and key are different.
    This version works with variable-length sequences using cumulative sequence lengths.

    Args:
        cp_rank: Optional manual CP rank index. When provided, the function shards tensors as if it
            were executing on that rank without querying `torch.distributed.get_rank`.
    """
    if qvk_format not in ["thd", "bshd", "sbhd"]:
        raise ValueError(f"Unsupported qvk_format: {qvk_format}!")
    if qvk_format == "thd":
        # Get context parallel size and rank
        if cp_world_size > 1:
            if cp_rank is None:
                cp_rank = torch.distributed.get_rank(group=cp_group)
            elif not (0 <= cp_rank < cp_world_size):
                raise ValueError(f"cp_rank must be in [0, {cp_world_size}), but received {cp_rank}.")

        # Calculate the chunk sizes for each sequence
        total_slices_of_any_sequence = 2 * cp_world_size
        slice_sizes = (cu_seqlens_padded[1:] - cu_seqlens_padded[:-1]) // total_slices_of_any_sequence

        # Process each tensor directly instead of using a keys_to_change loop
        def process_tensor(val):
            if val is None:
                return val
            # Determine which dimension is the sequence dimension.
            # Ensure cu_seqlens_padded[-1] is a Python int, not a 0-dim tensor.
            if isinstance(cu_seqlens_padded[-1], torch.Tensor):
                seq_len_val = cu_seqlens_padded[-1].item()
            else:
                seq_len_val = cu_seqlens_padded[-1]

            # Handle 1D tensors (like position_ids that don't have a batch dimension)
            if val.ndim == 1:
                if val.shape[0] == seq_len_val:
                    current_seq_dim = 0
                else:
                    raise ValueError(
                        "1D tensor shape doesn't match expected sequence length. Make sure the"
                        " inputs are in THD format and padded correctly."
                    )
            elif val.ndim >= 2:
                if val.shape[1] == seq_len_val:
                    current_seq_dim = 1
                elif val.shape[0] == seq_len_val:
                    current_seq_dim = 0
                else:
                    raise ValueError("Make sure the inputs are in THD format and padded correctly.")
            else:
                raise ValueError("Tensor must be at least 1D")

            # On this particular rank, for each sequence, take two slices: one from the beginning
            # and one from the end (the dual-chunk layout).
            cp_rank_slices = []
            for slice_size, seq_start in zip(slice_sizes, cu_seqlens_padded[:-1]):
                # 1st segment
                cp_rank_slices.append(
                    torch.arange(
                        seq_start + (cp_rank * slice_size),
                        seq_start + ((cp_rank + 1) * slice_size),
                        device=val.device,
                    )
                )
                # 2nd segment
                cp_rank_slices.append(
                    torch.arange(
                        seq_start + ((total_slices_of_any_sequence - cp_rank - 1) * slice_size),
                        seq_start + ((total_slices_of_any_sequence - cp_rank) * slice_size),
                        device=val.device,
                    )
                )

            return val.index_select(current_seq_dim, torch.cat(cp_rank_slices))

        # Process each tensor directly
        input_ids_padded = process_tensor(input_ids_padded)
        labels_padded = process_tensor(labels_padded)
    else:
        raise ValueError(f"Support not implemented yet for qvk_format: {qvk_format}!")

    return input_ids_padded, labels_padded
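
A quick illustrative sanity check with a single packed sequence of 8 tokens and `cp_world_size=2` reproduces the dual-chunk layout documented in the recipe README:

```python
import torch

input_ids = torch.arange(1, 9)  # tokens 1..8 in THD format, one sequence
labels = torch.arange(1, 9)
cu_seqlens_padded = torch.tensor([0, 8])

for rank in range(2):
    ids, _ = split_batch_by_cp_rank(
        cu_seqlens_padded=cu_seqlens_padded,
        input_ids_padded=input_ids,
        labels_padded=labels,
        qvk_format="thd",
        cp_rank=rank,
        cp_world_size=2,
    )
    print(f"CP{rank}: {ids.tolist()}")

# CP0: [1, 2, 7, 8]
# CP1: [3, 4, 5, 6]
```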
