Closed
Changes from 18 commits
37 commits
bca505b
context parallel start
Nov 10, 2025
d6d1a8d
good early signals of CP training
Nov 12, 2025
dd741bc
notes
Nov 12, 2025
39a771e
async on
Nov 13, 2025
187ca33
scatter object list implementation
Nov 13, 2025
5270370
double check
Nov 13, 2025
4ba1bc1
adds unit test for cp with no padding
Nov 14, 2025
9909e26
adds test for CP shard slicing
Nov 14, 2025
493a0e9
adds second unit test
Nov 14, 2025
f65b491
cleanup
Nov 14, 2025
1915662
x
Nov 14, 2025
877b98c
x
Nov 14, 2025
b4d422c
fix in review
Nov 14, 2025
1a58228
adds TODO for max seqlen in CP mode
Nov 14, 2025
c5b7f06
Updates CPDataset to do work in collation
Nov 14, 2025
2824357
fixes unit tests
Nov 14, 2025
942d9b0
linting
Nov 14, 2025
5b845c9
adds docs for CP
Nov 17, 2025
12e23fd
moves test
jomitchellnv Nov 17, 2025
f8baf99
removes use_cp from modeling file
Nov 17, 2025
90f2ef8
adds link to docs readme for dualchunk
Nov 17, 2025
507e2b0
adds docs to dataloader
Nov 17, 2025
6059811
adds docs to dataloader
Nov 17, 2025
faedff2
moves utils to context parallel file
Nov 17, 2025
4882dd7
utils -> context parallel file
Nov 17, 2025
58be150
adds max seqlen using correct algo
Nov 18, 2025
f40647c
fixes comment in model file
Nov 18, 2025
7286243
copies collator and modeling file to models
Nov 18, 2025
f52cea3
adds thd padded sequence equivalence test
Nov 19, 2025
805f1ee
adds test cp
Nov 19, 2025
c3ccbf9
test cp code has logit match
Nov 19, 2025
6391db7
cleanup cp test
Nov 19, 2025
c0b07da
loss value checks added
Nov 19, 2025
77d3789
linting
Nov 19, 2025
4fb3904
adds gradient comparisons to context parallel test
Nov 20, 2025
d022290
cleanup code
Nov 20, 2025
c5c74ab
linting models side
Nov 20, 2025
23 changes: 22 additions & 1 deletion bionemo-recipes/recipes/esm2_native_te/README.md
Expand Up @@ -17,7 +17,7 @@ bionemo-framework repository. You can download a zipped directory of this folder

| Model | BF16 | FP8<sup>[1]</sup> | THD Input Format | FP8 with THD Input Format | MXFP8<sup>[2]</sup> | Context Parallelism |
| ----------------------------------------- | ---- | ----------------- | ---------------- | ------------------------- | ------------------- | ------------------- |
| [ESM-2](../../models/esm2/README.md) | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 |
| [ESM-2](../../models/esm2/README.md) | ✅ | ✅ | ✅ | ✅ | ✅ | |
| [AMPLIFY](../../models/amplify/README.md) | ✅ | ❌ | 🚧 | ❌ | ❌ | 🚧 |

✅: Supported <br/>
Expand Down Expand Up @@ -88,6 +88,27 @@ python train_fsdp2.py --config-name L0_sanity \
use_sequence_packing=true
```

### Context Parallelism
We provide a training script [train_ddp_cp](./esm2_native_te/train_ddp_cp.py) and a sample config [L0_sanity_cp](./hydra_config/L0_sanity_cp.yaml) that use context parallelism.

In the config, the argument `--cp_size` sets the size of the context parallel distributed group. When paired with Distributed Data Parallelism (DDP), the number of context parallel groups is `world_size // cp_size`.

Thus, for example, if a user has 8 processes and sets `cp_size=2`, they will have `4` CP groups of `2` ranks each, and each CP group acts as one DDP replica. During dataloading we make no assumptions about whether the data pipeline is deterministic. We simply make the data unique across DDP groups and select the relevant CP shard for each CP rank within a group.
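
As a rough illustration of the `world_size // cp_size` arithmetic, the sketch below enumerates which global ranks would share a CP group. It assumes the CP dimension varies fastest across ranks (as in a 2D `(dp, cp)` device mesh); the helper name `cp_group_layout` is hypothetical and is not part of this recipe.

```python
def cp_group_layout(world_size: int, cp_size: int) -> list[list[int]]:
    """Return the global ranks in each CP group.

    Assumption: CP is the fastest-varying mesh dimension, so consecutive
    ranks belong to the same CP group.
    """
    if world_size % cp_size != 0:
        raise ValueError("world_size must be divisible by cp_size")
    num_cp_groups = world_size // cp_size
    return [list(range(g * cp_size, (g + 1) * cp_size)) for g in range(num_cp_groups)]


layout = cp_group_layout(world_size=8, cp_size=2)
print(layout)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```

Each inner list is one CP group; all ranks in a group work on shards of the same batch, while different groups receive different data.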

For example, let's say that we have 2 DDP groups and a CP size of 2. Each DDP group has its own dataloader: DP0 for DDP group 0 and DP1 for DDP group 1. CP works by running ring attention, which expects the tokens on each device to follow a particular layout. This CP implementation uses Dual Chunk slicing. If DP0 outputs the sequence `1 2 3 4 5 6 7 8` and DP1 outputs `9 10 11 12 13 14 15 16`, then the `CPAwareDataloader` defined in [dataset.py](./dataset.py) creates the CP shards for each DP group as follows:

```
    | DP0        | DP1            |
CP0 | 1, 2, 7, 8 | 9, 10, 15, 16  |
CP1 | 3, 4, 5, 6 | 11, 12, 13, 14 |
```
You may notice these shards and wonder why they are the way they are. We did. The reason is that CP ranks are sharded using slices. The full input sequence (such as `1 2 3 4 5 6 7 8`) is cut into `2 * cp_size` slices. CP0 then takes the first and last slice, while CP1 takes the middle slices, of each sequence.
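
The slicing rule can be sketched in a few lines of plain Python. This is a minimal illustration, not the code used by the recipe: the function name `dual_chunk_shard` is hypothetical, and the general pairing of slice `r` with slice `2 * cp_size - 1 - r` is an assumption that reproduces the two-rank table above.

```python
def dual_chunk_shard(tokens: list[int], cp_size: int, cp_rank: int) -> list[int]:
    """Return the tokens that CP rank `cp_rank` keeps under dual-chunk slicing."""
    num_slices = 2 * cp_size
    if len(tokens) % num_slices != 0:
        raise ValueError("sequence length must be divisible by 2 * cp_size")
    slice_len = len(tokens) // num_slices
    slices = [tokens[i * slice_len : (i + 1) * slice_len] for i in range(num_slices)]
    # Assumed pairing: rank r takes slice r from the front and its mirror from the back.
    return slices[cp_rank] + slices[num_slices - 1 - cp_rank]


sequence = [1, 2, 3, 4, 5, 6, 7, 8]
shards = [dual_chunk_shard(sequence, cp_size=2, cp_rank=r) for r in range(2)]
print(shards)  # [[1, 2, 7, 8], [3, 4, 5, 6]]
```

Pairing an early slice with a late one gives every rank a similar mix of early and late positions.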

In this example we only show one sequence, but it's important to note that slicing takes place on every sequence, so if a second sequence is also available, it will be sliced in the same manner. CP0 will take the first and last slice of every sequence, while CP1 will take the middle slices of each sequence.
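
To make the per-sequence behaviour concrete, here is a small sketch that shards a packed batch of two sequences. It assumes each sequence is first padded to a multiple of `2 * cp_size`, which matches the purpose of the collator's `pad_sequences_to_be_divisible_by` option; the function name `shard_packed_batch` and the pad value `0` are hypothetical, and the real implementation operates on tensors and `cu_seq_lens` rather than Python lists.

```python
def shard_packed_batch(
    sequences: list[list[int]], cp_size: int, cp_rank: int, pad_id: int = 0
) -> list[int]:
    """Pad each sequence to a multiple of 2 * cp_size, then keep this rank's two slices of each."""
    num_slices = 2 * cp_size
    shard: list[int] = []
    for seq in sequences:
        # Round the length up to the next multiple of num_slices.
        padded_len = -(-len(seq) // num_slices) * num_slices
        padded = list(seq) + [pad_id] * (padded_len - len(seq))
        slice_len = padded_len // num_slices
        front_start = cp_rank * slice_len
        back_start = (num_slices - 1 - cp_rank) * slice_len
        shard.extend(padded[front_start : front_start + slice_len])
        shard.extend(padded[back_start : back_start + slice_len])
    return shard


packed = [[1, 2, 3, 4, 5, 6, 7, 8], [9, 10, 11, 12, 13, 14]]
cp0 = shard_packed_batch(packed, cp_size=2, cp_rank=0)
cp1 = shard_packed_batch(packed, cp_size=2, cp_rank=1)
print(cp0)  # [1, 2, 7, 8, 9, 10, 0, 0]
print(cp1)  # [3, 4, 5, 6, 11, 12, 13, 14]
```

The second sequence has six tokens, so it is padded to eight before slicing, and the padding lands in CP0's final slice.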



### Comparing Against the HF Transformers Reference Implementation

To launch training with the ESM-2 model as implemented in HF Transformers, pass a `facebook/esm2` checkpoint as the
Expand Down
60 changes: 59 additions & 1 deletion bionemo-recipes/recipes/esm2_native_te/collator.py
Expand Up @@ -20,11 +20,14 @@

import logging
from dataclasses import dataclass
from typing import Any
from typing import Any, Callable

import datasets
import torch
from transformers import DataCollatorForLanguageModeling, DefaultDataCollator, PreTrainedTokenizerBase
from utils import split_batch_by_cp_rank

from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import pad_thd_sequences_for_cp


logger = logging.getLogger(__name__)
Expand All @@ -38,6 +41,7 @@ class MLMDataCollatorWithFlattening:
2. Then applying MLM masking to the flattened sequence
3. Providing Flash Attention metadata (cu_seq_lens) for sequence boundary awareness
4. Optionally padding the total sequence length to be divisible by a specified number
5. Optionally padding each sequence so that its length is divisible by a specified number

The result is a THD-format batch optimized for Flash Attention with sequence packing,
eliminating the need for traditional attention masks while maintaining proper sequence
Expand All @@ -62,6 +66,9 @@ class MLMDataCollatorWithFlattening:
seed (int | None): Random seed for reproducible masking. Defaults to None.
pad_to_multiple_of (int | None): If set, pads the total sequence length to be divisible
by this number by adding a mock sequence at the end. Defaults to None.
pad_sequences_to_be_divisible_by (int | None): If set, pads each sequence to be divisible
by this number by adding padding tokens and labels set to -100. Defaults to None.
This is used by context parallelism.

Example:
>>> from transformers import AutoTokenizer
Expand Down Expand Up @@ -111,6 +118,7 @@ def __init__(
return_position_ids: bool = False,
bshd_equivalent: bool = False,
bshd_pad_to_multiple_of: int | None = None,
pad_sequences_to_be_divisible_by: int | None = None,
):
"""Initialize the MLMDataCollatorWithFlattening.

Expand All @@ -129,6 +137,9 @@ def __init__(
collator, at the expense of additional computation time. Defaults to False.
bshd_pad_to_multiple_of (int | None): For the bshd_equivalent mode, mimics padding that would be done by the
BSHD collator. Defaults to None.
pad_sequences_to_be_divisible_by (int | None): If set, pads each sequence to be divisible
by this number by adding padding tokens and labels set to -100. Defaults to None.
This is used by context parallelism.
"""
self.mlm_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
Expand All @@ -145,6 +156,10 @@ def __init__(
self.return_position_ids = return_position_ids
self.bshd_equivalent = bshd_equivalent
self.bshd_pad_to_multiple_of = bshd_pad_to_multiple_of
self.pad_sequences_to_be_divisible_by = pad_sequences_to_be_divisible_by

if self.pad_sequences_to_be_divisible_by is not None and self.pad_to_multiple_of is not None:
raise ValueError("pad_sequences_to_be_divisible_by and pad_to_multiple_of cannot be used together")

if bshd_pad_to_multiple_of is not None and not bshd_equivalent:
raise ValueError("bshd_pad_to_multiple_of can only be used when bshd_equivalent is True")
Expand Down Expand Up @@ -227,6 +242,23 @@ def __call__(self, features, return_tensors=None):
if self.pad_to_multiple_of is not None:
batch = self._pad_batch_to_multiple_of(batch)

elif self.pad_sequences_to_be_divisible_by is not None:
# import pdb; pdb.set_trace()
input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp(
batch["input_ids"],
batch["labels"],
batch["cu_seq_lens_q"],
self.pad_sequences_to_be_divisible_by,
padding_token_id=int(
self.mlm_collator.tokenizer.pad_token_id
),
padding_label_id=-100,
)
batch["input_ids"] = input_ids_padded.unsqueeze(0)
batch["labels"] = labels_padded.unsqueeze(0)
batch["cu_seq_lens_q_padded"] = cu_seqlens_padded.to(torch.int32)
batch["cu_seq_lens_k_padded"] = cu_seqlens_padded.to(torch.int32)

Collaborator
pop the max_seq_lens stuff here, rather than in model.forward()

return batch

def bshd_compatible_call(self, features, return_tensors=None):
Expand Down Expand Up @@ -269,6 +301,32 @@ def _pad_batch_to_multiple_of(self, batch):
)


class MLMDataCollatorWithFlatteningCPAware:
    """A collator that is aware of context parallelism."""
Collaborator
could we get a bit more exposition in the docstring here about what this is doing and why?

e.g., in context parallelism we split the input sequences along the sequence dimension. this example uses torch.distributed primitives to load and split inputs on a single CP rank to avoid the requirement of creating synchronized distributed dataloaders. it returns a list of batches with length equal to the number of context parallel ranks, where batch[i] is the batch intended for cp_rank i

Collaborator Author
Yea -- I put a lot of that in the docs but I can put it here as well

    def __init__(self, collator: MLMDataCollatorWithFlattening, cp_world_size: int):
        self.collator = collator
        self.cp_world_size = cp_world_size

    def __call__(self, features):
        batch = self.collator(features)

        combined_batch = []
        for cp_rank in range(self.cp_world_size):
            input_ids_sharded, labels_sharded = split_batch_by_cp_rank(
                cu_seqlens_padded=batch["cu_seq_lens_q_padded"],
                input_ids_padded=batch["input_ids"],
                labels_padded=batch["labels"],
                qvk_format="thd",
                cp_rank=cp_rank,
                cp_world_size=self.cp_world_size,
            )
            batch_shard = dict(batch)
            batch_shard["input_ids"] = input_ids_sharded
            batch_shard["labels"] = labels_sharded
            combined_batch.append(batch_shard)

        return combined_batch  # [<cp_rank_0_shard>, <cp_rank_1_shard>, ..., <cp_rank_n_shard>]

@dataclass
class DataCollatorWithFlattening(DefaultDataCollator):
"""Data collator for sequence packing with flash attentions cu_seqlens-style attention.
Expand Down
168 changes: 161 additions & 7 deletions bionemo-recipes/recipes/esm2_native_te/dataset.py
Expand Up @@ -15,17 +15,18 @@

import logging

from typing import Optional
import torch
import datasets
import datasets.distributed
from torch.utils.data import DataLoader, DistributedSampler
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import AutoTokenizer
from transformers.data.data_collator import DataCollatorForLanguageModeling

from collator import MLMDataCollatorWithFlattening, TokenPackingDataset
from collator import MLMDataCollatorWithFlattening, TokenPackingDataset, MLMDataCollatorWithFlatteningCPAware
from distributed_config import DistributedConfig


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -143,7 +144,6 @@ def create_bshd_dataloader(
collate_fn=data_collator,
num_workers=num_workers,
pin_memory=True if not use_stateful_dataloader else False,
persistent_workers=True,
)

return train_dataloader, tokenized_dataset if sampler is None else sampler
Expand All @@ -161,6 +161,7 @@ def create_thd_dataloader(
buffer_size: int = 10_000,
use_stateful_dataloader: bool = False,
mlm_probability: float = 0.15,
pad_sequences_to_be_divisible_by: int | None = None,
):
"""Create a dataloader that packs up to the maximum number of tokens per batch.

Expand All @@ -178,7 +179,8 @@ def create_thd_dataloader(
buffer_size: The buffer size to use for the distributed sampler.
use_stateful_dataloader: Whether to use the StatefulDataLoader to enable checkpointing the dataloader state.
mlm_probability: The probability of masking tokens for MLM (default 0.15). Set to 0 for no masking.
**kwargs: Unused, here to enable kwargs to match the signature of create_bshd_dataloader.
pad_sequences_to_be_divisible_by: If provided, sequences will be padded to be divisible by this value.
This is useful for context parallelism. Defaults to None.

Returns:
A dataloader that can be used for training.
Expand All @@ -203,8 +205,9 @@ def create_thd_dataloader(
data_collator = MLMDataCollatorWithFlattening(
tokenizer=tokenizer,
mlm_probability=mlm_probability,
pad_to_multiple_of=token_micro_batch_size,
pad_to_multiple_of=token_micro_batch_size if pad_sequences_to_be_divisible_by is None else None,
seed=seed,
pad_sequences_to_be_divisible_by=pad_sequences_to_be_divisible_by,
)

# TODO(BIONEMO-3246) - remove the pin_memory=False once StatefulDataLoader supports pin_memory again.
Expand All @@ -215,7 +218,158 @@ def create_thd_dataloader(
collate_fn=data_collator,
num_workers=num_workers,
pin_memory=True if not use_stateful_dataloader else False,
persistent_workers=True,
)

return train_dataloader, tokenized_dataset


def create_cp_dataloader(
distributed_config: DistributedConfig,
tokenizer_name: str,
load_dataset_kwargs: dict,
micro_batch_size: int | None = None,
token_micro_batch_size: int | None = None,
num_workers: int = 1,
max_seq_length: int = 1024,
seed: int = 42,
buffer_size: int = 10_000,
use_stateful_dataloader: bool = False,
mlm_probability: float = 0.15,
pad_sequences_to_be_divisible_by: int | None = None,
Collaborator
since you know the cp_world_size, can't we initialize this to the correct value for folks?

Collaborator Author
We could -- but if you also want to do FP8 + CP then this would need to be higher, right? With CP=2 the divisibility factor is 4, but you would need a divisibility factor of 16 for MXFP8, right? I can set it, but also make it toggleable.

Collaborator Author
It's currently just set from the config

cp_world_size: int = 1,
cp_group: torch.distributed.ProcessGroup = None,
cp_rank: int = 0,
):
"""Create a dataloader that packs up to the maximum number of tokens per batch.

This dataloader also supports context parallelism. It produces batches of data on CP rank 0, creates shards from
that data for every CP rank, and then scatters the shards to the other CP ranks.


Args:
distributed_config: The distributed configuration.
tokenizer_name: The name of the tokenizer to pull from the HuggingFace Hub.
load_dataset_kwargs: Keyword arguments to pass to `load_dataset` for the train dataset.
micro_batch_size: The batch size (number of sequences) per device. This will set the token_micro_batch_size to
micro_batch_size * max_seq_length. Defaults to None.
token_micro_batch_size: The maximum number of tokens per batch. If None, the micro_batch_size * max_seq_length
will be used. Defaults to None.
num_workers: The number of workers to use for the dataloader. For iterable datasets, this should be 1.
max_seq_length: The maximum length of the protein sequences.
seed: The seed to use for the distributed sampler and data collator.
buffer_size: The buffer size to use for the distributed sampler.
use_stateful_dataloader: Whether to use the StatefulDataLoader to enable checkpointing the dataloader state.
mlm_probability: The probability of masking tokens for MLM (default 0.15). Set to 0 for no masking.
pad_sequences_to_be_divisible_by: If provided, sequences will be padded to be divisible by this value.
This is useful for context parallelism. Defaults to None.
cp_world_size: The size of the context parallel group.
cp_group: The context parallel group.
cp_rank: The rank of the current context parallel process.
Returns:
A CPAwareDataloader that can be used for training.
"""
tokenized_dataset, tokenizer = create_tokenized_dataset(
distributed_config=distributed_config,
tokenizer_name=tokenizer_name,
load_dataset_kwargs=load_dataset_kwargs,
max_seq_length=max_seq_length,
buffer_size=buffer_size,
)

assert isinstance(tokenized_dataset, datasets.IterableDataset), "THD token packing requires a streaming dataset."
if token_micro_batch_size is None:
assert micro_batch_size is not None, "One of micro_batch_size or token_micro_batch_size must be provided."
token_micro_batch_size = micro_batch_size * max_seq_length
else:
assert micro_batch_size is None, "Only one of micro_batch_size or token_micro_batch_size can be provided."
assert token_micro_batch_size >= max_seq_length, "token_micro_batch_size must be greater than or equal to max_seq_length."

Collaborator
Suggested change
# For context parallelism, we need each sequence...
if pad_sequences_to_be_divisible_by is None:
pad_sequences_to_be_divisible_by = 2 * cp_world_size

# For THD, we pad out to the maximum number of tokens per batch for consistent array shapes.
data_collator = MLMDataCollatorWithFlattening(
tokenizer=tokenizer,
mlm_probability=mlm_probability,
pad_to_multiple_of=token_micro_batch_size if pad_sequences_to_be_divisible_by is None else None,
seed=seed,
pad_sequences_to_be_divisible_by=pad_sequences_to_be_divisible_by,
)

data_collator = MLMDataCollatorWithFlatteningCPAware(
collator=data_collator,
cp_world_size=cp_world_size,
)

# TODO(BIONEMO-3246) - remove the pin_memory=False once StatefulDataLoader supports pin_memory again.
dataloader_class = StatefulDataLoader if use_stateful_dataloader else DataLoader
train_dataloader = dataloader_class(
TokenPackingDataset(tokenized_dataset, max_tokens_per_batch=token_micro_batch_size),
batch_size=None, # The TokenPackingDataset will handle the batching.
collate_fn=data_collator,
num_workers=num_workers,
pin_memory=True if not use_stateful_dataloader else False,
)

return CPAwareDataloader(train_dataloader, cp_group, cp_rank), tokenized_dataset


class CPAwareDataloader:
    """A dataloader that is aware of context parallelism."""
Collaborator
again, just a quick summary of the main steps here -- This class handles synchronizing a single dataloader across multiple CP ranks. it materializes a dataloader instance on CP rank 0, which is responsible for splitting its inputs into sub-batches for each CP rank. It then uses torch.distributed.scatter to send the data to all cp ranks

Collaborator Author
added

    def __init__(self, dataloader: StatefulDataLoader,
                 cp_group: torch.distributed.ProcessGroup,
                 cp_rank: int,
Collaborator
this you could probably get from torch.distributed right? rather than asking for it here?

Collaborator Author (@jomitchellnv, Nov 14, 2025)
cp_rank comes from the device_mesh which isn't available here

    ):
        """Initialize the CPAwareDataloader."""
        self.dataloader = dataloader
        self.cp_rank = cp_rank
        self.cp_group = cp_group
        self.num_cp_ranks = cp_group.size()
        self._iterator = None

    def __iter__(self):
        """Make the dataloader iterable."""
        self._iterator = iter(self.dataloader)  # <--- collator output.
        return self

    def __next__(self):
        """Get the batch from the dataloader for the current CP rank."""
        batch = self._send_data_to_cp_ranks()
        batch['pad_between_seqs'] = True
        return batch

    def _send_data_to_cp_ranks(self):
        """Fetch a batch on CP rank 0 and scatter one shard to each CP rank.

        This function gets the batch from the dataloader on CP rank 0, which already contains
        the shards for all the different CP group members:
        combined_batch = [<cp_rank_0_shard>, <cp_rank_1_shard>, ..., <cp_rank_n_shard>]
        It then scatters the shards to the different CP group members, and the shard for the
        current CP rank is returned to the caller.

        Scalability:
            Rank 0's work grows linearly with CP size, but the other ranks do not need to store all the shards,
            so their work does not grow linearly with CP size.

        Args:
            None

        Returns:
            batch: The batch for the current CP rank.

        """
        if self.cp_rank == 0:
            # Get data once, then make copies for each rank.
            if self._iterator is None:
                self._iterator = iter(self.dataloader)
            combined_batch = next(self._iterator)

        else:
            combined_batch = None

        scatter_object_output_list = [None]
        # Note: This does not provide an async_op handle, so it is blocking.
        torch.distributed.scatter_object_list(
            scatter_object_output_list=scatter_object_output_list,
            scatter_object_input_list=combined_batch,
            group=self.cp_group,
            group_src=0,
        )
        torch.distributed.barrier(group=self.cp_group)  # TODO(@jomitchell): Might not need this since it's sync.
Collaborator
i also don't think this is the right call for an async op, i'd remove

Collaborator Author
removed

        return scatter_object_output_list[0]

Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"mask_token_id": 32,
"max_position_embeddings": 1026,
"max_seq_length": null,
"use_cp": false,
"micro_batch_size": null,
"model_type": "nv_esm",
"num_attention_heads": 20,
Expand Down