
Commit 862c632

jomitchellnv (Jonathan Mitchell) authored
Context Parallel Squashed MR for models dir (#1337)
Adds context parallelism (CP) support to the ESM2 models inside `bionemo-recipes/models/esm2`:

- Adds a comprehensive unit test that checks (1) gradients, (2) loss, and (3) logits for non-CP vs. CP execution.
- Adds a unit test confirming that padded THD and unpadded THD inputs return the same logits after the model's forward pass.
- Adds a collator capable of computing CP shards for a given CP rank, suitable for downstream scattering to the other GPUs; all of this happens during prefetch.

#### Usage

```python
TODO: Add code snippet
```

(A hedged usage sketch is included after this commit message.)

### Type of changes

- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Refactor
- [ ] Documentation update
- [ ] Other (please describe):

### CI Pipeline Configuration

Configure CI behavior by applying the relevant labels. By default, only basic unit tests are run.

- [ciflow:skip](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:skip) - Skip all CI tests for this PR
- [ciflow:notebooks](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:notebooks) - Run Jupyter notebook execution tests for bionemo2
- [ciflow:slow](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:slow) - Run slow single-GPU integration tests marked as `@pytest.mark.slow` for bionemo2
- [ciflow:all](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:all) - Run all tests (unit tests, slow tests, and notebooks) for bionemo2; use this label to force running every bionemo2 test
- [ciflow:all-recipes](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:all-recipes) - Run tests for all recipes (under bionemo-recipes); use this label to force running tests for every recipe

Unit tests marked as `@pytest.mark.multi_gpu` or `@pytest.mark.distributed` are not run in the PR pipeline. For more details, see [CONTRIBUTING](CONTRIBUTING.md).

> [!NOTE]
> By default, only basic unit tests are run. Add the appropriate labels to enable additional test coverage.

#### Authorizing CI Runs

We use [copy-pr-bot](https://docs.gha-runners.nvidia.com/apps/copy-pr-bot/#automation) to manage authorization of CI runs on NVIDIA's compute resources.

- If a pull request is opened by a trusted user and contains only trusted changes, its code is automatically copied to a `pull-request/`-prefixed branch in the source repository (e.g. `pull-request/123`).
- If a pull request is opened by an untrusted user or contains untrusted changes, an NVIDIA org member must leave an `/ok to test` comment on the pull request to trigger CI. This must be done for each new commit.
### Pre-submit Checklist

- [ ] I have tested these changes locally
- [ ] I have updated the documentation accordingly
- [ ] I have added/updated tests as needed
- [ ] All existing tests pass successfully

---------

Signed-off-by: Jonathan Mitchell <[email protected]>
Co-authored-by: Jonathan Mitchell <[email protected]>
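The `Usage` snippet above was left as a TODO. Below is a minimal, hedged sketch of how the new collators from this commit might be wired together; the module path `esm.collator`, the tokenizer checkpoint, and the toy protein strings are assumptions rather than code taken from this commit.

```python
# Hedged usage sketch (assumed: `esm.collator` import path, ESM-2 tokenizer checkpoint,
# and the example protein strings; none of these come from this commit).
from transformers import AutoTokenizer

from esm.collator import MLMDataCollatorWithFlattening, MLMDataCollatorWithFlatteningCPAware

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")  # assumed checkpoint

cp_world_size = 2

# Base THD collator: pad each packed sequence so its length divides evenly into the
# 2 * cp_world_size zig-zag slices used by the CP split (16 is divisible by 4 here).
base_collator = MLMDataCollatorWithFlattening(
    tokenizer=tokenizer,
    mlm_probability=0.15,
    pad_sequences_to_be_divisible_by=16,
)

# CP-aware wrapper: returns one batch shard per context-parallel rank.
cp_collator = MLMDataCollatorWithFlatteningCPAware(collator=base_collator, cp_world_size=cp_world_size)

features = [tokenizer(seq) for seq in ["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", "MLKNVHVLVLGAGDVGSVVVRLLEK"]]
per_rank_shards = cp_collator(features)  # list of dicts, one per CP rank

for cp_rank, shard in enumerate(per_rank_shards):
    # Each shard carries this rank's slice of input_ids/labels plus the padded
    # cu_seq_lens metadata consumed by TransformerEngine's THD attention.
    print(cp_rank, shard["input_ids"].shape, shard["max_length_q"], shard["pad_between_seqs"])
```

In a real training loop these per-rank shards would be scattered from the data-loading process to the GPUs in the CP group, as described in the commit message.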
1 parent 7c5e7ad commit 862c632

File tree

10 files changed: +880 / -8 lines changed

bionemo-recipes/models/esm2/src/esm/collator.py

Lines changed: 179 additions & 0 deletions
```diff
@@ -24,6 +24,7 @@
 
 import datasets
 import torch
+from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import pad_thd_sequences_for_cp
 from transformers import DataCollatorForLanguageModeling, DefaultDataCollator, PreTrainedTokenizerBase
 
 
@@ -38,6 +39,7 @@ class MLMDataCollatorWithFlattening:
     2. Then applying MLM masking to the flattened sequence
     3. Providing Flash Attention metadata (cu_seq_lens) for sequence boundary awareness
     4. Optionally padding the total sequence length to be divisible by a specified number
+    5. Optionally padding each sequence to be divisible by a specified number (if provided)
 
     The result is a THD-format batch optimized for Flash Attention with sequence packing,
     eliminating the need for traditional attention masks while maintaining proper sequence
@@ -62,6 +64,9 @@ class MLMDataCollatorWithFlattening:
         seed (int | None): Random seed for reproducible masking. Defaults to None.
         pad_to_multiple_of (int | None): If set, pads the total sequence length to be divisible
             by this number by adding a mock sequence at the end. Defaults to None.
+        pad_sequences_to_be_divisible_by (int | None): If set, pads each sequence to be divisible
+            by this number by adding padding tokens and labels set to -100. Defaults to None.
+            This is used by context parallelism.
 
     Example:
         >>> from transformers import AutoTokenizer
@@ -111,6 +116,7 @@ def __init__(
         return_position_ids: bool = False,
         bshd_equivalent: bool = False,
         bshd_pad_to_multiple_of: int | None = None,
+        pad_sequences_to_be_divisible_by: int | None = None,
     ):
         """Initialize the MLMDataCollatorWithFlattening.
 
@@ -129,6 +135,9 @@ def __init__(
                 collator, at the expense of additional computation time. Defaults to False.
             bshd_pad_to_multiple_of (int | None): For the bshd_equivalent mode, mimics padding that would be done by the
                 BSHD collator. Defaults to None.
+            pad_sequences_to_be_divisible_by (int | None): If set, pads each sequence to be divisible
+                by this number by adding padding tokens and labels set to -100. Defaults to None.
+                This is used by context parallelism.
         """
         self.mlm_collator = DataCollatorForLanguageModeling(
             tokenizer=tokenizer,
@@ -145,6 +154,10 @@ def __init__(
         self.return_position_ids = return_position_ids
         self.bshd_equivalent = bshd_equivalent
         self.bshd_pad_to_multiple_of = bshd_pad_to_multiple_of
+        self.pad_sequences_to_be_divisible_by = pad_sequences_to_be_divisible_by
+
+        if self.pad_sequences_to_be_divisible_by is not None and self.pad_to_multiple_of is not None:
+            raise ValueError("pad_sequences_to_be_divisible_by and pad_to_multiple_of cannot be used together")
 
         if bshd_pad_to_multiple_of is not None and not bshd_equivalent:
             raise ValueError("bshd_pad_to_multiple_of can only be used when bshd_equivalent is True")
@@ -227,6 +240,20 @@ def __call__(self, features, return_tensors=None):
         if self.pad_to_multiple_of is not None:
             batch = self._pad_batch_to_multiple_of(batch)
 
+        elif self.pad_sequences_to_be_divisible_by is not None:
+            input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp(
+                batch["input_ids"],
+                batch["labels"],
+                batch["cu_seq_lens_q"],
+                self.pad_sequences_to_be_divisible_by,
+                padding_token_id=int(self.mlm_collator.tokenizer.pad_token_id),
+                padding_label_id=-100,
+            )
+            batch["input_ids"] = input_ids_padded.unsqueeze(0)
+            batch["labels"] = labels_padded.unsqueeze(0)
+            batch["cu_seq_lens_q_padded"] = cu_seqlens_padded.to(torch.int32)
+            batch["cu_seq_lens_k_padded"] = cu_seqlens_padded.to(torch.int32)
+
         return batch
 
     def bshd_compatible_call(self, features, return_tensors=None):
@@ -269,6 +296,53 @@ def _pad_batch_to_multiple_of(self, batch):
         )
 
 
+class MLMDataCollatorWithFlatteningCPAware:
+    """A collator that is aware of context parallelism."""
+
+    def __init__(self, collator: MLMDataCollatorWithFlattening, cp_world_size: int):
+        """Initialize the MLMDataCollatorWithFlatteningCPAware.
+
+        Args:
+            collator: The collator to use for masking tokens.
+            cp_world_size: The size of the context parallelism group.
+        """
+        self.collator = collator
+        self.cp_world_size = cp_world_size
+
+    def __call__(self, features) -> list[dict[str, Any]]:
+        """Process batches of data and create shards for each context parallelism rank.
+
+        Args:
+            features: List of tokenized sequences, each containing 'input_ids' and optionally 'labels'.
+
+        Returns:
+            A list of dictionaries, each containing a shard of the batch for a given context parallelism rank.
+        """
+        batch = self.collator(features)
+
+        combined_batch = []
+        for cp_rank in range(self.cp_world_size):
+            input_ids_sharded, labels_sharded = split_batch_by_cp_rank(
+                cu_seqlens_padded=batch["cu_seq_lens_q_padded"],
+                input_ids_padded=batch["input_ids"],
+                labels_padded=batch["labels"],
+                qvk_format="thd",
+                cp_rank=cp_rank,
+                cp_world_size=self.cp_world_size,
+            )
+            batch_shard = dict(batch)
+            batch_shard["input_ids"] = input_ids_sharded
+            batch_shard["labels"] = labels_sharded
+            # Now determine the max length of the sequence.
+            seqlens_q = batch_shard["cu_seq_lens_q_padded"][1:] - batch_shard["cu_seq_lens_q_padded"][:-1]
+            batch_shard["max_length_q"] = int((seqlens_q.max().item() + 63) // 64 * 64)
+            batch_shard["max_length_k"] = batch_shard["max_length_q"]
+            batch_shard["pad_between_seqs"] = True  # TODO(@jomitchell): Double check this on recipe MR.
+            combined_batch.append(batch_shard)
+
+        return combined_batch
+
+
 @dataclass
 class DataCollatorWithFlattening(DefaultDataCollator):
     """Data collator for sequence packing with flash attentions cu_seqlens-style attention.
@@ -441,3 +515,108 @@ def _pt_pad_to_multiple_of(batch: dict[str, Any], pad_to_multiple_of: int, token
         )
 
     return batch
+
+
+# TODO(@jomitchell): Once this gets merged: https://github.com/NVIDIA/TransformerEngine/pull/2387
+# we can replace this with the one in TransformerEngine.
+def split_batch_by_cp_rank(
+    cu_seqlens_padded: torch.Tensor,
+    input_ids_padded: torch.Tensor,
+    labels_padded: torch.Tensor,
+    cp_group: torch.distributed.ProcessGroup = None,
+    qvk_format: str = "thd",
+    cp_rank: int | None = None,
+    cp_world_size: int | None = None,
+):
+    """Slice batch input along the sequence dimension into multiple chunks for THD format.
+
+    This function is intended for use in self-attention. It will not work for cross-attention because
+    it does not handle the case where the sequence lengths of the query and key differ.
+    The resulting chunks are parallelized across GPUs in a context parallel group.
+    This version works with variable-length sequences using cumulative sequence lengths.
+
+    Args:
+        cu_seqlens_padded: Cumulative (padded) sequence lengths.
+        input_ids_padded: Input IDs.
+        labels_padded: Labels.
+        cp_group: Context parallel group.
+        qvk_format: Format of the input data.
+        cp_world_size: The size of the context parallelism group. If provided, the function will use this value to determine the rank.
+        cp_rank: Optional manual CP rank index. When provided, the function shards tensors as if it
+            were executing on that rank without querying `torch.distributed.get_rank`.
+    """
+    if qvk_format not in ["thd", "bshd", "sbhd"]:
+        raise ValueError(f"Unsupported qvk_format: {qvk_format}!")
+    if qvk_format == "thd":
+        # Get context parallel size and rank
+        if cp_world_size > 1:
+            if cp_rank is None:
+                cp_rank = torch.distributed.get_rank(group=cp_group)
+            elif not (0 <= cp_rank < cp_world_size):
+                raise ValueError(f"cp_rank must be in [0, {cp_world_size}), but received {cp_rank}.")
+
+        # Calculate the chunk sizes for each sequence
+        total_slices_of_any_sequence = 2 * cp_world_size
+        slice_sizes = (cu_seqlens_padded[1:] - cu_seqlens_padded[:-1]) // total_slices_of_any_sequence
+
+        # Process each tensor directly instead of using keys_to_change loop
+        def process_tensor(val):
+            if val is None:
+                return val
+            # Determine which dimension is the sequence dimension
+            # Ensure cu_seqlens_padded[-1] is a Python int, not a 0-dim tensor
+            if isinstance(cu_seqlens_padded[-1], torch.Tensor):
+                seq_len_val = cu_seqlens_padded[-1].item()
+            else:
+                seq_len_val = cu_seqlens_padded[-1]
+
+            # Handle 1D tensors (like position_ids that don't have batch dimension)
+            if val.ndim == 1:
+                if val.shape[0] == seq_len_val:
+                    current_seq_dim = 0
+                else:
+                    raise ValueError(
+                        "1D tensor shape doesn't match expected sequence length. Make sure the"
+                        " inputs are in THD format and padded correctly."
+                    )
+            elif val.ndim >= 2:
+                if val.shape[1] == seq_len_val:
+                    current_seq_dim = 1
+                elif val.shape[0] == seq_len_val:
+                    current_seq_dim = 0
+                else:
+                    raise ValueError("Make sure the inputs are in THD format and padded correctly.")
+            else:
+                raise ValueError("Tensor must be at least 1D")
+
+            # On this particular rank, for each sequence, get two slices, one from the beginning
+            # and one from the end.
+            cp_rank_slices = []
+            for slice_size, seq_start in zip(slice_sizes, cu_seqlens_padded[:-1]):
+                # 1st segment
+                cp_rank_slices.append(
+                    torch.arange(
+                        seq_start + (cp_rank * slice_size),
+                        seq_start + ((cp_rank + 1) * slice_size),
+                        device=val.device,
+                    )
+                )
+
+                # 2nd segment
+                cp_rank_slices.append(
+                    torch.arange(
+                        seq_start + ((total_slices_of_any_sequence - cp_rank - 1) * slice_size),
+                        seq_start + ((total_slices_of_any_sequence - cp_rank) * slice_size),
+                        device=val.device,
+                    )
+                )
+
+            return val.index_select(current_seq_dim, torch.cat(cp_rank_slices))
+
+        # Process each tensor directly
+        input_ids_padded = process_tensor(input_ids_padded)
+        labels_padded = process_tensor(labels_padded)
+    else:
+        raise ValueError(f"Support not implemented yet for qvk_format: {qvk_format}!")
+
+    return input_ids_padded, labels_padded
```
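As a reading aid for `split_batch_by_cp_rank` above: each padded sequence is cut into `2 * cp_world_size` equal slices, and rank `r` receives slice `r` from the front plus the mirrored slice from the back, the usual load-balancing layout for context parallelism. A minimal sketch of the resulting pattern for one 8-token sequence and `cp_world_size=2`; the import path is an assumption:

```python
# Minimal sketch of the zig-zag slicing pattern implemented by split_batch_by_cp_rank,
# for a single 8-token sequence and cp_world_size=2 (so 4 slices of 2 tokens each).
import torch

from esm.collator import split_batch_by_cp_rank  # assumed import path

input_ids = torch.arange(8).unsqueeze(0)                 # shape (1, 8): tokens 0..7
labels = torch.full_like(input_ids, -100)
cu_seqlens_padded = torch.tensor([0, 8], dtype=torch.int32)

for cp_rank in range(2):
    ids, _ = split_batch_by_cp_rank(
        cu_seqlens_padded=cu_seqlens_padded,
        input_ids_padded=input_ids,
        labels_padded=labels,
        qvk_format="thd",
        cp_rank=cp_rank,
        cp_world_size=2,
    )
    print(cp_rank, ids.tolist())

# Expected pattern: rank 0 -> [[0, 1, 6, 7]], rank 1 -> [[2, 3, 4, 5]]
# i.e. each rank gets one slice from the front and the mirrored slice from the back,
# so every rank ends up with the same number of tokens.
```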

bionemo-recipes/models/esm2/src/esm/modeling_esm_te.py

Lines changed: 8 additions & 2 deletions
```diff
@@ -180,7 +180,6 @@ def forward(
             **kwargs: Additional arguments, see TransformersKwargs for more details.
         """
         all_hidden_states: tuple[torch.Tensor, ...] = ()
-
         has_thd_input = [
             x is not None
             for x in [
@@ -213,7 +212,11 @@
         if self.config.attn_input_format == "bshd":
             te_rope_emb = self.rotary_embeddings(max_seq_len=hidden_states.shape[1])
         elif self.config.attn_input_format == "thd":
-            te_rope_emb = self.rotary_embeddings(max_seq_len=kwargs["cu_seq_lens_q"][-1])
+            te_rope_emb = self.rotary_embeddings(
+                max_seq_len=kwargs["cu_seq_lens_q_padded"][-1]
+                if "cu_seq_lens_q_padded" in kwargs
+                else kwargs["cu_seq_lens_q"][-1]
+            )
         te_rope_emb = te_rope_emb.to(hidden_states.device, non_blocking=True)
 
         for layer_module in self.layers:
@@ -226,8 +229,11 @@
                 rotary_pos_emb=te_rope_emb,
                 cu_seqlens_q=kwargs.get("cu_seq_lens_q", None),
                 cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None),
+                cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None),
+                cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None),
                 max_seqlen_q=kwargs.get("max_length_q", None),
                 max_seqlen_kv=kwargs.get("max_length_k", None),
+                pad_between_seqs=kwargs.get("pad_between_seqs", None),
             )
 
         hidden_states = self.emb_layer_norm_after(hidden_states)
```
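These forward changes consume the padded metadata produced by the CP-aware collator: `cu_seq_lens_q_padded` is preferred for the rotary-embedding length, and the padded cu_seqlens plus `pad_between_seqs` are forwarded to each TransformerEngine layer. Below is a hedged sketch of a THD forward call on one collator shard; the class name `NVEsmForMaskedLM`, the checkpoint id, and the single-GPU setup (no CP process group or communication) are assumptions, not code from this commit:

```python
# Hedged sketch: feeding one CP shard into the THD forward pass. The class name and
# checkpoint id are assumed; a real CP run also needs torch.distributed / CP group setup.
from esm.modeling_esm_te import NVEsmForMaskedLM  # assumed class / module path

model = NVEsmForMaskedLM.from_pretrained("nvidia/esm2-650m", attn_input_format="thd").cuda()  # assumed checkpoint
shard = per_rank_shards[0]  # one entry from the CP-aware collator sketch above

outputs = model(
    input_ids=shard["input_ids"].cuda(),
    labels=shard["labels"].cuda(),
    cu_seq_lens_q=shard["cu_seq_lens_q"].cuda(),
    cu_seq_lens_k=shard["cu_seq_lens_k"].cuda(),
    # Padded variants are preferred for the RoPE max_seq_len and passed through to TE attention.
    cu_seq_lens_q_padded=shard["cu_seq_lens_q_padded"].cuda(),
    cu_seq_lens_k_padded=shard["cu_seq_lens_k_padded"].cuda(),
    max_length_q=shard["max_length_q"],
    max_length_k=shard["max_length_k"],
    pad_between_seqs=shard["pad_between_seqs"],
)
```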

bionemo-recipes/models/esm2/tests/test_collator.py

Lines changed: 19 additions & 0 deletions
```diff
@@ -367,6 +367,25 @@ def test_mlm_data_collator_with_flattening_bshd_equivalent(tokenizer, test_proteins):
     )
 
 
+def test_mlm_data_collator_with_flattening_pad_sequences_to_be_divisible_by(tokenizer, test_proteins):
+    """Test MLMDataCollatorWithFlattening with pad_sequences_to_be_divisible_by."""
+    collator = MLMDataCollatorWithFlattening(
+        tokenizer=tokenizer,
+        mlm_probability=0.15,
+        pad_sequences_to_be_divisible_by=16,
+    )
+    features = [tokenizer(protein) for protein in test_proteins]
+    batch = collator(features)
+    assert batch["input_ids"].numel() % 16 == 0, (
+        f"Expected {batch['input_ids'].numel()} tokens to be divisible by 16, got {batch['input_ids'].numel()}"
+    )
+    assert batch["input_ids"].shape == (1, batch["input_ids"].numel()), (
+        f"Expected shape (1, {batch['input_ids'].numel()}), got {batch['input_ids'].shape}"
+    )
+    assert (batch["input_ids"][:, -1] == tokenizer.pad_token_id).all()
+    assert (batch["labels"][:, -1] == -100).all()
+
+
 def test_token_packing_dataset():
     """Test that the token packing dataset works."""
```
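In the same spirit as the test above, here is a hedged sketch of an additional sanity check (not part of this commit) that the CP-aware collator gives every rank an equal, non-overlapping share of the padded token stream:

```python
# Hedged sketch of an extra check (not in this commit): the per-rank shards from
# MLMDataCollatorWithFlatteningCPAware should partition the padded token stream evenly.
def check_cp_shards(per_rank_shards, cp_world_size):
    full_len = int(per_rank_shards[0]["cu_seq_lens_q_padded"][-1])
    shard_tokens = [shard["input_ids"].numel() for shard in per_rank_shards]
    assert sum(shard_tokens) == full_len, "shards should cover every padded token exactly once"
    assert all(n == full_len // cp_world_size for n in shard_tokens), "each rank should get an equal share"
```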
