Adds THD + CP for ESM2 #1320
Conversation
No files named utils.py 😆 !!
But more seriously, this looks like a mix of test data utilities (those should go in the tests folder) and actual usable code.
I can put the get_batch_on_this_cp_rank function inside dataset.py then? Since it's going to take ~2 months to get it from TE after I push it: NVIDIA/TransformerEngine#2387
| """A dataloader that is aware of context parallelism.""" | ||
| def __init__(self, dataloader: StatefulDataLoader, | ||
| cp_group: torch.distributed.ProcessGroup, | ||
| cp_rank: int, |
this you could probably get from torch.distributed right? rather than asking for it here?
cp_rank comes from the device_mesh which isn't available here
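For reference, a minimal sketch of the reviewer's suggestion, i.e. deriving the CP rank from the process group itself via torch.distributed rather than passing it in separately (the "cp" mesh-dimension name in the comment below is an assumption):

import torch.distributed as dist

cp_rank = dist.get_rank(group=cp_group)  # rank of this process within cp_group
# Roughly equivalent to device_mesh["cp"].get_local_rank() when the DeviceMesh is in scope;
# the "cp" dimension name is hypothetical here.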
    combined_batch = []
    for cp_rank in range(self.num_cp_ranks):
        input_ids_sharded, labels_sharded = get_batch_on_this_cp_rank(
            cu_seqlens_padded=batch["cu_seq_lens_q_padded"],
            input_ids_padded=batch["input_ids"],
            labels_padded=batch["labels"],
            cp_group=self.cp_group,
            qvk_format="thd",
            cp_rank=cp_rank,
        )
        batch_shard = dict(batch)
        batch_shard["input_ids"] = input_ids_sharded
        batch_shard["labels"] = labels_sharded
        combined_batch.append(batch_shard)
else:
    combined_batch = None
We wanted to do this as a dataset.map call, right? Otherwise this won't be done as part of the dataloader's prefetch.
Uh yea -- also I didn't use a generator.
This isn't a real dataloader tho -- it's a wrapper class. In order to use map, wouldn't that need to be a legit dataloader?
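For illustration, a hedged sketch of what the dataset.map version could look like. It assumes the THD batches already live in a Hugging Face datasets.Dataset, that map runs single-process, and that get_batch_on_this_cp_rank only slices tensors (no collectives); the argument names follow the hunk below, everything else is an assumption:

def shard_for_cp(batch, cp_group, cp_rank):
    # Slice this batch's padded THD tensors down to the shard owned by cp_rank.
    input_ids_sharded, labels_sharded = get_batch_on_this_cp_rank(
        cu_seqlens_padded=batch["cu_seq_lens_q_padded"],
        input_ids_padded=batch["input_ids"],
        labels_padded=batch["labels"],
        cp_group=cp_group,
        qvk_format="thd",
        cp_rank=cp_rank,
    )
    return {**batch, "input_ids": input_ids_sharded, "labels": labels_sharded}

ds = ds.map(shard_for_cp, fn_kwargs={"cp_group": cp_group, "cp_rank": cp_rank})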
if self.config.use_cp:
    hidden_states = layer_module(
        hidden_states,
        attention_mask,
        rotary_pos_emb=te_rope_emb,
        cu_seqlens_q=kwargs.get("cu_seq_lens_q", None),
        cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None),
        cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None),
        cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None),
        pad_between_seqs=kwargs.get("pad_between_seqs", None),
        max_seqlen_q=kwargs.get("max_length_q", None),
        max_seqlen_kv=kwargs.get("max_length_k", None),
    )
why does this have to be a separate block? we can just pass those None values through in the non-CP case, right?
done
n_masked_per_seq = torch.nested.nested_tensor_from_jagged(
    is_masked, offsets=kwargs["cu_seq_lens_q"]
).sum(1)
mask_ratio_observed = n_masked_per_seq.float() / src_lengths
revert
done
def test_sanity_convergence_ddp_cp(tmp_path, recipe_path):
    """Test that the main function can be invoked wrapping the model in DDP."""

    # Run the training script with Hydra configuration overrides
    with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
        sanity_config = compose(
            config_name="L0_sanity_cp",
            overrides=[
                f"+wandb_init_args.dir={tmp_path}",
                f"checkpoint.ckpt_dir={tmp_path}",
                f"cp_size=2",
            ],
        )

    final_loss = main_ddp(sanity_config)
    assert final_loss < 3.0, f"Final loss {final_loss} is too high"
wait this doesn't make any sense -- wouldn't we need two GPUs for this convergence test to work? shouldn't this just hang on a single device?
Sorry, this was WIP -- I haven't written this part yet.
All I did was change the function name lawl
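One way to keep a CP convergence test from hanging on a single device is to gate it on GPU count and the repo's multi-GPU marker (see the CI notes at the end of this PR); a rough sketch, with the skip condition as an assumption rather than the PR's actual test:

import pytest
import torch

@pytest.mark.multi_gpu
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="cp_size=2 requires at least 2 GPUs")
def test_sanity_convergence_ddp_cp(tmp_path, recipe_path):
    ...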
batch['pad_between_seqs'] = True
return batch

def _get_data_scatter_sharded(self):
maybe like, send_data_to_cp_ranks
done
buffer_size: int = 10_000,
use_stateful_dataloader: bool = False,
mlm_probability: float = 0.15,
pad_sequences_to_be_divisible_by: int | None = None,
since you know the cp_world_size, can't we initialize this to the correct value for folks?
We could -- but if you also want to do FP8 + CP, then this would need to be higher, right? Since CP=2, Divisibility_factor=4, but you would need Divisibility_factor=16 for MXFP8, right? I can set it, but also make it toggleable.
It's currently just set from the config
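A possible default based on this thread: the 2 * cp_world_size factor comes from the suggestion later in this review, and the 16 for MXFP8 comes from the comment above; the function and flag names here are hypothetical.

import math

def default_pad_divisibility(cp_world_size: int, use_mxfp8: bool = False) -> int:
    cp_factor = 2 * cp_world_size  # CP splits each sequence into 2 * cp_size chunks
    # If MXFP8 also needs a multiple of 16 (per the discussion), take the least common multiple.
    return math.lcm(cp_factor, 16) if use_mxfp8 else cp_factor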
    group=self.cp_group,
    group_src=0,
)
torch.distributed.barrier(group=self.cp_group)  # TODO(@jomitchell): Might not need this since its sync.
I also don't think this is the right call for an async op; I'd remove it.
removed
class CPAwareDataloader:
    """A dataloader that is aware of context parallelism."""
Again, just a quick summary of the main steps here -- this class handles synchronizing a single dataloader across multiple CP ranks. It materializes a dataloader instance on CP rank 0, which is responsible for splitting its inputs into sub-batches for each CP rank. It then uses torch.distributed.scatter to send the data to all CP ranks.
added
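For reference, a minimal sketch of the flow described above, not the PR's exact implementation: it assumes a list of per-rank batch dicts held on CP rank 0 and uses torch.distributed.scatter_object_list rather than the tensor scatter shown in the diff.

import torch.distributed as dist

def send_data_to_cp_ranks(combined_batch, cp_group):
    """Rank 0 holds one batch shard per CP rank; every rank receives its own shard."""
    output = [None]
    dist.scatter_object_list(
        output,
        combined_batch if dist.get_rank(group=cp_group) == 0 else None,
        src=dist.get_global_rank(cp_group, 0),  # scatter from the group's first rank
        group=cp_group,
    )
    return output[0]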
hidden_states = layer_module(
    hidden_states,
    attention_mask,
    rotary_pos_emb=te_rope_emb,
    cu_seqlens_q=kwargs.get("cu_seq_lens_q", None),
    cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None),
    cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None),
    cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None),
    pad_between_seqs=kwargs.get("pad_between_seqs", None),
    # TODO(@jomitchell): Add `max_seqlen_q` and `max_seqlen_kv` by finding the largest padded sequence length. torch.diff(cu_seqlens_q_padded).max().item()
does this not also work for non-cp cases?
    te_rope_emb = self.rotary_embeddings(max_seq_len=kwargs["cu_seq_lens_q_padded"][-1])
else:
    te_rope_emb = self.rotary_embeddings(max_seq_len=kwargs["cu_seq_lens_q"][-1])
I'd just check for a cu_seq_lens_q_padded and use cu_seq_lens_q if it's not there.
added
micro_batch_size: Optional[int] = None,
max_seq_length: Optional[int] = None,
padded_vocab_size: Optional[int] = 64,
use_cp: bool = False,
i'm not sure you need this
removed
context_parallel.py
else:
    assert micro_batch_size is None, "Only one of micro_batch_size or token_micro_batch_size can be provided."
    assert token_micro_batch_size >= max_seq_length, "token_micro_batch_size must be greater than max_seq_length."
Suggested change:

# For context parallelism, we need each sequence...
if pad_sequences_to_be_divisible_by is None:
    pad_sequences_to_be_divisible_by = 2 * cp_world_size
| batch["labels"] = labels_padded.unsqueeze(0) | ||
| batch["cu_seq_lens_q_padded"] = cu_seqlens_padded.to(torch.int32) | ||
| batch["cu_seq_lens_k_padded"] = cu_seqlens_padded.to(torch.int32) | ||
|
|
pop the max_seq_lens stuff here, rather than in model.forward()
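A hedged sketch of that suggestion, reusing the torch.diff idea from the TODO in the model hunk above; the key names mirror the batch dict and the max_length_q / max_length_k kwargs already used in this PR.

# In the collator, right after building cu_seqlens_padded:
max_seqlen = int(torch.diff(cu_seqlens_padded).max())  # longest padded sequence in the batch
batch["max_length_q"] = max_seqlen
batch["max_length_k"] = max_seqlen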
if self.config.use_cp:
    te_rope_emb = self.rotary_embeddings(max_seq_len=kwargs["cu_seq_lens_q_padded"][-1])
else:
    te_rope_emb = self.rotary_embeddings(max_seq_len=kwargs["cu_seq_lens_q"][-1])
max_seq_len = kwargs["cu_seq_lens_q_padded"][-1] if "cu_seq_lens_q_padded" in kwargs else kwargs["cu_seq_lens_q"][-1]
added
)

# Create an empty ESM-2 model with a masked language model head, e.g. "nvidia/esm2_t6_8M_UR50D".
config = AutoConfig.from_pretrained(args.model_tag, trust_remote_code=True, token_dropout=False, use_cp=True, dtype=torch.bfloat16)
Suggested change:

config = AutoConfig.from_pretrained(
    args.model_tag,
    trust_remote_code=True,
    token_dropout=False,  # Token dropout isn't supported with CP, since it requires reduction over the entire sequence.
    dtype=torch.bfloat16
)
For models/esm2:

def test_thd_vs_padded_thd_equivalence(input_data_thd):
    input_data_padded_thd = ...
    outputs_thd = model(**input_data_thd)
    outputs_padded_thd = model(**input_data_padded_thd)
    torch.testing.assert_close(...)

@requires_multi_gpu
def test_grads_are_equal():
    cmd = "torchrun --standalone --nproc-per-node 2 {__file__} { ?? }"
    ...

if __name__ == "__main__":
    # argparse?
    data = "some mock input protein sequence"
    # run the model with no CP, get gradients
    model()
    # run the model with CP = 2, get gradients
    # set the model layers to use cp
    # use the collator, copy the scatter code over.
    # compare gradients and logits where you have them
Description
Adds context parallelism to ESM2 through the addition of a CPAwareDataloader, as well as several edits to the model file used for testing.
Usage
There are also a number of other changes needed to run this.
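As a rough usage sketch based on the sanity test in this PR (the config name, overrides, and main_ddp come from that test; the exact recipe layout and launch command are assumptions):

from hydra import compose, initialize_config_dir

# recipe_path points at the ESM2 recipe directory, as in the sanity test.
with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
    cfg = compose(config_name="L0_sanity_cp", overrides=["cp_size=2"])

final_loss = main_ddp(cfg)  # run under torchrun with --nproc-per-node equal to cp_size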
Type of changes
CI Pipeline Configuration
Configure CI behavior by applying the relevant labels. By default, only basic unit tests are run.
Unit tests marked as @pytest.mark.multi_gpu or @pytest.mark.distributed are not run in the PR pipeline. For more details, see CONTRIBUTING.
Note
By default, only basic unit tests are run. Add appropriate labels to enable additional test coverage.
Authorizing CI Runs
We use copy-pr-bot to manage authorization of CI runs on NVIDIA's compute resources.
If a pull request is opened by a trusted user and contains only trusted changes, its code will automatically be copied to a pull-request/ prefixed branch in the source repository (e.g. pull-request/123).
Otherwise, a repository maintainer must leave an /ok to test comment on the pull request to trigger CI. This will need to be done for each new commit.
Pre-submit Checklist