
Conversation

@jomitchellnv jomitchellnv commented Nov 13, 2025

Description

Adds context parallelism (CP) to ESM2 by introducing a CP-aware dataloader, along with several edits to the test model file.

Usage

    train_dataloader, dataset_or_sampler = create_cp_dataloader(
        dist_config,
        cp_world_size=torch.distributed.get_world_size(group=cp_group),
        cp_group=cp_group,
        cp_rank=cp_rank,
        **args.dataset,
    )

Several other changes are also needed to run this end to end; see the diff for details.
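For orientation, here is a hedged sketch of how the cp_group and cp_rank arguments above might be obtained, assuming a torchrun launch with torch.distributed already initialized and a ("dp", "cp") device mesh; the mesh shape and dimension names are illustrative, not part of this PR:

    import torch
    from torch.distributed.device_mesh import init_device_mesh

    # Illustrative 2D mesh: a data-parallel axis and a context-parallel axis.
    cp_size = 2
    world_size = torch.distributed.get_world_size()
    mesh = init_device_mesh("cuda", (world_size // cp_size, cp_size), mesh_dim_names=("dp", "cp"))

    cp_group = mesh.get_group("cp")          # process group for this rank's CP slice
    cp_rank = mesh.get_local_rank("cp")      # this rank's position within the CP group

These are the values that get passed into create_cp_dataloader in the usage snippet above.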

Type of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Refactor
  • Documentation update
  • Other (please describe):

CI Pipeline Configuration

Configure CI behavior by applying the relevant labels. By default, only basic unit tests are run.

  • ciflow:skip - Skip all CI tests for this PR
  • ciflow:notebooks - Run Jupyter notebooks execution tests for bionemo2
  • ciflow:slow - Run slow single GPU integration tests marked as @pytest.mark.slow for bionemo2
  • ciflow:all - Run all tests (unit tests, slow tests, and notebooks) for bionemo2. This label can be used to enforce running tests for all bionemo2.
  • ciflow:all-recipes - Run tests for all recipes (under bionemo-recipes). This label can be used to enforce running tests for all recipes.

Unit tests marked as @pytest.mark.multi_gpu or @pytest.mark.distributed are not run in the PR pipeline.

For more details, see CONTRIBUTING

Note

By default, only basic unit tests are run. Add the appropriate labels to enable additional test coverage.

Authorizing CI Runs

We use copy-pr-bot to manage authorization of CI
runs on NVIDIA's compute resources.

  • If a pull request is opened by a trusted user and contains only trusted changes, the pull request's code will
    automatically be copied to a pull-request/ prefixed branch in the source repository (e.g. pull-request/123)
  • If a pull request is opened by an untrusted user or contains untrusted changes, an NVIDIA org member must leave an
    /ok to test comment on the pull request to trigger CI. This will need to be done for each new commit.

Pre-submit Checklist

  • I have tested these changes locally
  • I have updated the documentation accordingly
  • I have added/updated tests as needed
  • All existing tests pass successfully

Jonathan Mitchell added 6 commits November 12, 2025 11:59
copy-pr-bot bot commented Nov 13, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Collaborator

No files named utils.py 😆 !!

But more seriously, this looks like a mix of test-data utilities (those should go in the tests folder) and actual usable code.

Collaborator Author

I can put the get_batch_on_this_cp_rank function inside dataset.py then? Since it's going to take ~2 months to get it from TE after I push it: NVIDIA/TransformerEngine#2387
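For reference, a rough sketch of the kind of per-sequence chunking such a helper performs for THD inputs, assuming the usual load-balanced split into 2 * cp_size chunks where rank r keeps chunks r and (2 * cp_size - 1 - r); the function name and arguments here are illustrative, not the TE implementation:

    import torch

    def shard_thd_for_cp_rank(input_ids, labels, cu_seqlens_padded, cp_size, cp_rank):
        """Keep the two load-balanced chunks of each padded sequence that belong to cp_rank."""
        ids_out, labels_out = [], []
        starts = cu_seqlens_padded[:-1].tolist()
        ends = cu_seqlens_padded[1:].tolist()
        for start, end in zip(starts, ends):
            # Assumes each padded sequence length is divisible by 2 * cp_size.
            chunk = (end - start) // (2 * cp_size)
            for c in sorted((cp_rank, 2 * cp_size - 1 - cp_rank)):
                ids_out.append(input_ids[start + c * chunk : start + (c + 1) * chunk])
                labels_out.append(labels[start + c * chunk : start + (c + 1) * chunk])
        return torch.cat(ids_out), torch.cat(labels_out)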

Jonathan Mitchell added 3 commits November 14, 2025 11:30
@jomitchellnv jomitchellnv changed the title from "[DRAFT] Jm/context parallel esm2" to "Adds THD + CP for ESM2" on Nov 14, 2025
Jonathan Mitchell added 2 commits November 14, 2025 13:39
"""A dataloader that is aware of context parallelism."""
def __init__(self, dataloader: StatefulDataLoader,
cp_group: torch.distributed.ProcessGroup,
cp_rank: int,
Collaborator

this you could probably get from torch.distributed right? rather than asking for it here?

Collaborator Author
@jomitchellnv jomitchellnv Nov 14, 2025

cp_rank comes from the device_mesh which isn't available here

Comment on lines 279 to 294
            combined_batch = []
            for cp_rank in range(self.num_cp_ranks):
                input_ids_sharded, labels_sharded = get_batch_on_this_cp_rank(
                    cu_seqlens_padded=batch["cu_seq_lens_q_padded"],
                    input_ids_padded=batch["input_ids"],
                    labels_padded=batch["labels"],
                    cp_group=self.cp_group,
                    qvk_format="thd",
                    cp_rank=cp_rank,
                )
                batch_shard = dict(batch)
                batch_shard["input_ids"] = input_ids_sharded
                batch_shard["labels"] = labels_sharded
                combined_batch.append(batch_shard)
        else:
            combined_batch = None
Collaborator

We wanted to do this as a dataset.map call, right? Otherwise this won't be done as part of the dataloader's prefetch.

Collaborator Author

Uh yea -- also I didn't use a generator.

Collaborator Author

This isn't a real dataloader though; it's a wrapper class. In order to use map, wouldn't that need to be a legit dataloader?
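For the record, a hedged sketch of what the map-based variant could look like if the sharding moved into the dataset pipeline (assumes a Hugging Face datasets object and reuses the sharding helper discussed above; the key names follow the collator, everything else is illustrative):

    def add_cp_shards(example, cp_size):
        # Precompute one shard per CP rank inside the dataset pipeline, so the
        # work happens during the dataloader's prefetch instead of in the wrapper.
        example["cp_shards"] = [
            get_batch_on_this_cp_rank(
                cu_seqlens_padded=example["cu_seq_lens_q_padded"],
                input_ids_padded=example["input_ids"],
                labels_padded=example["labels"],
                cp_group=None,  # assumed unused for pure slicing
                qvk_format="thd",
                cp_rank=rank,
            )
            for rank in range(cp_size)
        ]
        return example

    dataset = dataset.map(add_cp_shards, fn_kwargs={"cp_size": cp_world_size})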

Comment on lines 227 to 239
        if self.config.use_cp:
            hidden_states = layer_module(
                hidden_states,
                attention_mask,
                rotary_pos_emb=te_rope_emb,
                cu_seqlens_q=kwargs.get("cu_seq_lens_q", None),
                cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None),
                cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None),
                cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None),
                pad_between_seqs=kwargs.get("pad_between_seqs", None),
                max_seqlen_q=kwargs.get("max_length_q", None),
                max_seqlen_kv=kwargs.get("max_length_k", None),
            )
Collaborator

why does this have to be a separate block? we can just pass those None values through in the non-CP case, right?

Collaborator Author

done

Comment on lines 611 to 615

    n_masked_per_seq = torch.nested.nested_tensor_from_jagged(
        is_masked, offsets=kwargs["cu_seq_lens_q"]
    ).sum(1)
    mask_ratio_observed = n_masked_per_seq.float() / src_lengths
Collaborator

revert

Collaborator Author

done

Comment on lines 557 to 572
    def test_sanity_convergence_ddp_cp(tmp_path, recipe_path):
        """Test that the main function can be invoked wrapping the model in DDP."""

        # Run the training script with Hydra configuration overrides
        with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
            sanity_config = compose(
                config_name="L0_sanity_cp",
                overrides=[
                    f"+wandb_init_args.dir={tmp_path}",
                    f"checkpoint.ckpt_dir={tmp_path}",
                    f"cp_size=2",
                ],
            )

        final_loss = main_ddp(sanity_config)
        assert final_loss < 3.0, f"Final loss {final_loss} is too high"
Collaborator

wait this doesn't make any sense -- wouldn't we need two GPUs for this convergence test to work? shouldn't this just hang on a single device?

Collaborator Author

Sorry, this was WIP; I haven't written this part yet.

Collaborator Author

All I did was change the function name, lawl.

        batch['pad_between_seqs'] = True
        return batch

    def _get_data_scatter_sharded(self):
Collaborator

maybe like, send_data_to_cp_ranks

Collaborator Author

done

Jonathan Mitchell added 3 commits November 14, 2025 14:15
    buffer_size: int = 10_000,
    use_stateful_dataloader: bool = False,
    mlm_probability: float = 0.15,
    pad_sequences_to_be_divisible_by: int | None = None,
Collaborator

since you know the cp_world_size, can't we initialize this to the correct value for folks?

Collaborator Author

We could, but if you also want to do FP8 + CP then this would need to be higher, right? With CP=2 the divisibility factor is 4, but you'd need a divisibility factor of 16 for MXFP8, right? I can set it, but also make it toggleable.

Collaborator Author

It's currently just set from the config
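A hedged sketch of how the default could combine both constraints mentioned above (the 2 * cp_world_size requirement and an assumed 16-token alignment for MXFP8; the use_mxfp8 flag is hypothetical, not an existing config option):

    import math

    if pad_sequences_to_be_divisible_by is None:
        cp_factor = 2 * cp_world_size
        fp8_factor = 16 if use_mxfp8 else 1  # 'use_mxfp8' is a hypothetical config flag
        pad_sequences_to_be_divisible_by = math.lcm(cp_factor, fp8_factor)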

            group=self.cp_group,
            group_src=0,
        )
        torch.distributed.barrier(group=self.cp_group)  # TODO(@jomitchell): Might not need this since it's sync.
Collaborator

I also don't think this is the right call for an async op; I'd remove it.

Collaborator Author

removed



class CPAwareDataloader:
    """A dataloader that is aware of context parallelism."""
Collaborator

Again, just a quick summary of the main steps here: this class handles synchronizing a single dataloader across multiple CP ranks. It materializes a dataloader instance on CP rank 0, which is responsible for splitting its inputs into sub-batches for each CP rank. It then uses torch.distributed.scatter to send the data to all CP ranks.

Collaborator Author

added
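For readers following along, a stripped-down sketch of that pattern (rank 0 owns the real dataloader, shards each batch, and scatters one shard to each CP rank; shard_for_cp_rank is a placeholder for the real THD splitting helper, and this is not the class as merged):

    import torch

    class CPAwareDataloaderSketch:
        """Illustrative only: synchronize one dataloader across a CP group."""

        def __init__(self, dataloader, cp_group, cp_rank, cp_world_size):
            self.dataloader = dataloader  # only consumed on CP rank 0
            self.cp_group = cp_group
            self.cp_rank = cp_rank
            self.cp_world_size = cp_world_size

        def __iter__(self):
            data_iter = iter(self.dataloader) if self.cp_rank == 0 else None
            src = torch.distributed.get_global_rank(self.cp_group, 0)
            while True:
                if self.cp_rank == 0:
                    batch = next(data_iter, None)
                    if batch is None:
                        shards = [None] * self.cp_world_size  # signal exhaustion to every rank
                    else:
                        shards = [shard_for_cp_rank(batch, r, self.cp_world_size) for r in range(self.cp_world_size)]
                else:
                    shards = None
                out = [None]
                torch.distributed.scatter_object_list(out, shards, src=src, group=self.cp_group)
                if out[0] is None:
                    break
                yield out[0]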

Comment on lines 228 to 237
            hidden_states = layer_module(
                hidden_states,
                attention_mask,
                rotary_pos_emb=te_rope_emb,
                cu_seqlens_q=kwargs.get("cu_seq_lens_q", None),
                cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None),
                cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None),
                cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None),
                pad_between_seqs=kwargs.get("pad_between_seqs", None),
                # TODO(@jomitchell): Add `max_seqlen_q` and `max_seqlen_kv` by finding the largest padded sequence length: torch.diff(cu_seqlens_q_padded).max().item()
Collaborator

does this not also work for non-cp cases?

Comment on lines 218 to 220
            te_rope_emb = self.rotary_embeddings(max_seq_len=kwargs["cu_seq_lens_q_padded"][-1])
        else:
            te_rope_emb = self.rotary_embeddings(max_seq_len=kwargs["cu_seq_lens_q"][-1])
Collaborator

I'd just check for cu_seq_lens_q_padded and use cu_seq_lens_q if it's not there.

Collaborator Author

added

    micro_batch_size: Optional[int] = None,
    max_seq_length: Optional[int] = None,
    padded_vocab_size: Optional[int] = 64,
    use_cp: bool = False,
Collaborator

I'm not sure you need this.

Collaborator Author

removed

Collaborator

context_parallel.py

    else:
        assert micro_batch_size is None, "Only one of micro_batch_size or token_micro_batch_size can be provided."
        assert token_micro_batch_size >= max_seq_length, "token_micro_batch_size must be at least max_seq_length."

Collaborator

Suggested change
    # For context parallelism, we need each sequence...
    if pad_sequences_to_be_divisible_by is None:
        pad_sequences_to_be_divisible_by = 2 * cp_world_size

batch["labels"] = labels_padded.unsqueeze(0)
batch["cu_seq_lens_q_padded"] = cu_seqlens_padded.to(torch.int32)
batch["cu_seq_lens_k_padded"] = cu_seqlens_padded.to(torch.int32)

Collaborator

pop the max_seq_lens stuff here, rather than in model.forward()
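i.e., something along these lines in the collator (a small sketch following the TODO in the model; the key names mirror the ones the layers already read):

    # Compute the max (padded) sequence lengths once in the collator so that
    # model.forward() can read them instead of deriving them per step.
    batch["max_length_q"] = int(torch.diff(cu_seqlens_padded).max().item())
    batch["max_length_k"] = batch["max_length_q"]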

    if self.config.use_cp:
        te_rope_emb = self.rotary_embeddings(max_seq_len=kwargs["cu_seq_lens_q_padded"][-1])
    else:
        te_rope_emb = self.rotary_embeddings(max_seq_len=kwargs["cu_seq_lens_q"][-1])
Collaborator

max_seq_len = kwargs["cu_seq_lens_q_padded"][-1] if "cu_seq_lens_q_padded" in kwargs else ...

Collaborator Author

added
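For completeness, a minimal version of that fallback (assumes the padded key is simply absent, rather than present-but-None, when CP is off):

    max_seq_len = kwargs.get("cu_seq_lens_q_padded", kwargs["cu_seq_lens_q"])[-1]
    te_rope_emb = self.rotary_embeddings(max_seq_len=max_seq_len)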

)

# Create an empty ESM-2 model with a masked language model head, e.g. "nvidia/esm2_t6_8M_UR50D".
config = AutoConfig.from_pretrained(args.model_tag, trust_remote_code=True, token_dropout=False, use_cp=True, dtype=torch.bfloat16)
Collaborator

Suggested change
config = AutoConfig.from_pretrained(args.model_tag, trust_remote_code=True, token_dropout=False, use_cp=True, dtype=torch.bfloat16)
config = AutoConfig.from_pretrained(
    args.model_tag,
    trust_remote_code=True,
    token_dropout=False,  # Token dropout isn't supported with CP, since it requires reduction over the entire sequence.
    dtype=torch.bfloat16,
)


pstjohn commented Nov 17, 2025

for models/esm2:

  • update collator, add test that the _padded keys are set correctly

test_thd.py:

def test_thd_vs_padded_thd_equivalence(input_data_thd):
    input_data_padded_thd = ...

    outputs_thd = model(**input_data_thd)
    outputs_padded_thd = model(**input_data_padded_thd)
    
    torch.testing.assert_close(...)

test_cp.py:

@requires_multi_gpu
def test_grads_are_equal():
    cmd = "torchrun --standalone --nproc-per-node 2 {__file__} { ?? }"

...

if __name__ == "__main__":

    # argparse?

    data = "some mock input protein sequence"

    # run the model with no CP, get gradients
    model()

    # run the model with CP = 2, get gradients
    # set the model layers to use cp
    # use the collator, copy the scatter code over.

    # compare gradients and logits where you have them 
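A hedged sketch of how the torchrun wrapper above could be filled in (the marker name comes from the CI notes earlier; the timeout and invocation details are illustrative):

    import subprocess

    import pytest

    @pytest.mark.multi_gpu
    def test_grads_are_equal():
        # Re-run this file under torchrun with 2 processes and fail if the
        # distributed comparison script exits non-zero.
        cmd = f"torchrun --standalone --nproc-per-node 2 {__file__}"
        result = subprocess.run(cmd.split(), capture_output=True, text=True, timeout=600)
        assert result.returncode == 0, result.stderr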

Jonathan Mitchell and others added 18 commits November 17, 2025 14:07