-
Notifications
You must be signed in to change notification settings - Fork 103
Adds THD + CP for ESM2 #1320
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adds THD + CP for ESM2 #1320
Changes from 18 commits
bca505b
d6d1a8d
dd741bc
39a771e
187ca33
5270370
4ba1bc1
9909e26
493a0e9
f65b491
1915662
877b98c
b4d422c
1a58228
c5b7f06
2824357
942d9b0
5b845c9
12e23fd
f8baf99
90f2ef8
507e2b0
6059811
faedff2
4882dd7
58be150
f40647c
7286243
f52cea3
805f1ee
c3ccbf9
6391db7
c0b07da
77d3789
4fb3904
d022290
c5c74ab
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,11 +20,14 @@ | |
|
|
||
| import logging | ||
| from dataclasses import dataclass | ||
| from typing import Any | ||
| from typing import Any, Callable | ||
|
|
||
| import datasets | ||
| import torch | ||
| from transformers import DataCollatorForLanguageModeling, DefaultDataCollator, PreTrainedTokenizerBase | ||
| from utils import split_batch_by_cp_rank | ||
|
|
||
| from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import pad_thd_sequences_for_cp | ||
|
|
||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
@@ -38,6 +41,7 @@ class MLMDataCollatorWithFlattening: | |
| 2. Then applying MLM masking to the flattened sequence | ||
| 3. Providing Flash Attention metadata (cu_seq_lens) for sequence boundary awareness | ||
| 4. Optionally padding the total sequence length to be divisible by a specified number | ||
| 5. Optionally, pad each sequence to be divisible by a specified number (if provided). | ||
|
|
||
| The result is a THD-format batch optimized for Flash Attention with sequence packing, | ||
| eliminating the need for traditional attention masks while maintaining proper sequence | ||
|
|
@@ -62,6 +66,9 @@ class MLMDataCollatorWithFlattening: | |
| seed (int | None): Random seed for reproducible masking. Defaults to None. | ||
| pad_to_multiple_of (int | None): If set, pads the total sequence length to be divisible | ||
| by this number by adding a mock sequence at the end. Defaults to None. | ||
| pad_sequences_to_be_divisible_by (int | None): If set, pads each sequence to be divisible | ||
| by this number by adding padding tokens and labels set to -100. Defaults to None. | ||
| This is used by context parallelism. | ||
|
|
||
| Example: | ||
| >>> from transformers import AutoTokenizer | ||
|
|
@@ -111,6 +118,7 @@ def __init__( | |
| return_position_ids: bool = False, | ||
| bshd_equivalent: bool = False, | ||
| bshd_pad_to_multiple_of: int | None = None, | ||
| pad_sequences_to_be_divisible_by: int | None = None, | ||
| ): | ||
| """Initialize the MLMDataCollatorWithFlattening. | ||
|
|
||
|
|
@@ -129,6 +137,9 @@ def __init__( | |
| collator, at the expense of additional computation time. Defaults to False. | ||
| bshd_pad_to_multiple_of (int | None): For the bshd_equivalent mode, mimics padding that would be done by the | ||
| BSHD collator. Defaults to None. | ||
| pad_sequences_to_be_divisible_by (int | None): If set, pads each sequence to be divisible | ||
| by this number by adding padding tokens and labels set to -100. Defaults to None. | ||
| This is used by context parallelism. | ||
| """ | ||
| self.mlm_collator = DataCollatorForLanguageModeling( | ||
| tokenizer=tokenizer, | ||
|
|
@@ -145,6 +156,10 @@ def __init__( | |
| self.return_position_ids = return_position_ids | ||
| self.bshd_equivalent = bshd_equivalent | ||
| self.bshd_pad_to_multiple_of = bshd_pad_to_multiple_of | ||
| self.pad_sequences_to_be_divisible_by = pad_sequences_to_be_divisible_by | ||
|
|
||
| if self.pad_sequences_to_be_divisible_by is not None and self.pad_to_multiple_of is not None: | ||
| raise ValueError("pad_sequences_to_be_divisible_by and pad_to_multiple_of cannot be used together") | ||
|
|
||
| if bshd_pad_to_multiple_of is not None and not bshd_equivalent: | ||
| raise ValueError("bshd_pad_to_multiple_of can only be used when bshd_equivalent is True") | ||
|
|
@@ -227,6 +242,23 @@ def __call__(self, features, return_tensors=None): | |
| if self.pad_to_multiple_of is not None: | ||
| batch = self._pad_batch_to_multiple_of(batch) | ||
|
|
||
| elif self.pad_sequences_to_be_divisible_by is not None: | ||
| # import pdb; pdb.set_trace() | ||
| input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp( | ||
| batch["input_ids"], | ||
| batch["labels"], | ||
| batch["cu_seq_lens_q"], | ||
| self.pad_sequences_to_be_divisible_by, | ||
| padding_token_id=int( | ||
| self.mlm_collator.tokenizer.pad_token_id | ||
| ), | ||
| padding_label_id=-100, | ||
| ) | ||
| batch["input_ids"] = input_ids_padded.unsqueeze(0) | ||
| batch["labels"] = labels_padded.unsqueeze(0) | ||
| batch["cu_seq_lens_q_padded"] = cu_seqlens_padded.to(torch.int32) | ||
| batch["cu_seq_lens_k_padded"] = cu_seqlens_padded.to(torch.int32) | ||
|
|
||
| return batch | ||
|
|
||
| def bshd_compatible_call(self, features, return_tensors=None): | ||
|
|
@@ -269,6 +301,32 @@ def _pad_batch_to_multiple_of(self, batch): | |
| ) | ||
|
|
||
|
|
||
| class MLMDataCollatorWithFlatteningCPAware: | ||
| """A collator that is aware of context parallelism.""" | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could we get a bit more exposition in the docstring here about what this is doing and why? e.g., in context parallelism we split the input sequences along the sequence dimension. this example uses torch.distributed primitives to load and split inputs on a single CP rank to avoid the requirement of creating synchronized distributed dataloaders. it returns a list of batches with length equal to the number of context parallel ranks, where batch[i] is the batch intended for cp_rank i
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yea -- i put alot of that in the docs but i can put it here as well |
||
| def __init__(self, collator: MLMDataCollatorWithFlattening, cp_world_size: int): | ||
| self.collator = collator | ||
| self.cp_world_size = cp_world_size | ||
|
|
||
| def __call__(self, features): | ||
| batch = self.collator(features) | ||
|
|
||
| combined_batch = [] | ||
| for cp_rank in range(self.cp_world_size): | ||
| input_ids_sharded, labels_sharded = split_batch_by_cp_rank( | ||
| cu_seqlens_padded=batch["cu_seq_lens_q_padded"], | ||
| input_ids_padded=batch["input_ids"], | ||
| labels_padded=batch["labels"], | ||
| qvk_format="thd", | ||
| cp_rank=cp_rank, | ||
| cp_world_size=self.cp_world_size, | ||
| ) | ||
| batch_shard = dict(batch) | ||
| batch_shard["input_ids"] = input_ids_sharded | ||
| batch_shard["labels"] = labels_sharded | ||
| combined_batch.append(batch_shard) | ||
|
|
||
| return combined_batch # [<cp_rank_0_shard>, <cp_rank_1_shard>, ..., <cp_rank_n_shard>] | ||
|
|
||
| @dataclass | ||
| class DataCollatorWithFlattening(DefaultDataCollator): | ||
| """Data collator for sequence packing with flash attentions cu_seqlens-style attention. | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -15,17 +15,18 @@ | |||||||||
|
|
||||||||||
| import logging | ||||||||||
|
|
||||||||||
| from typing import Optional | ||||||||||
| import torch | ||||||||||
| import datasets | ||||||||||
| import datasets.distributed | ||||||||||
| from torch.utils.data import DataLoader, DistributedSampler | ||||||||||
| from torchdata.stateful_dataloader import StatefulDataLoader | ||||||||||
| from transformers import AutoTokenizer | ||||||||||
| from transformers.data.data_collator import DataCollatorForLanguageModeling | ||||||||||
|
|
||||||||||
| from collator import MLMDataCollatorWithFlattening, TokenPackingDataset | ||||||||||
| from collator import MLMDataCollatorWithFlattening, TokenPackingDataset, MLMDataCollatorWithFlatteningCPAware | ||||||||||
| from distributed_config import DistributedConfig | ||||||||||
|
|
||||||||||
|
|
||||||||||
| logger = logging.getLogger(__name__) | ||||||||||
|
|
||||||||||
|
|
||||||||||
|
|
@@ -143,7 +144,6 @@ def create_bshd_dataloader( | |||||||||
| collate_fn=data_collator, | ||||||||||
| num_workers=num_workers, | ||||||||||
| pin_memory=True if not use_stateful_dataloader else False, | ||||||||||
| persistent_workers=True, | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| return train_dataloader, tokenized_dataset if sampler is None else sampler | ||||||||||
|
|
@@ -161,6 +161,7 @@ def create_thd_dataloader( | |||||||||
| buffer_size: int = 10_000, | ||||||||||
| use_stateful_dataloader: bool = False, | ||||||||||
| mlm_probability: float = 0.15, | ||||||||||
| pad_sequences_to_be_divisible_by: int | None = None, | ||||||||||
| ): | ||||||||||
| """Create a dataloader that packs up to the maximum number of tokens per batch. | ||||||||||
|
|
||||||||||
|
|
@@ -178,7 +179,8 @@ def create_thd_dataloader( | |||||||||
| buffer_size: The buffer size to use for the distributed sampler. | ||||||||||
| use_stateful_dataloader: Whether to use the StatefulDataLoader to enable checkpointing the dataloader state. | ||||||||||
| mlm_probability: The probability of masking tokens for MLM (default 0.15). Set to 0 for no masking. | ||||||||||
| **kwargs: Unused, here to enable kwargs to match the signature of create_bshd_dataloader. | ||||||||||
| pad_sequences_to_be_divisible_by: If provided, sequences will be padded to be divisible by this value. | ||||||||||
| This is useful for context parallelism. Defaults to None. | ||||||||||
|
|
||||||||||
| Returns: | ||||||||||
| A dataloader that can be used for training. | ||||||||||
|
|
@@ -203,8 +205,9 @@ def create_thd_dataloader( | |||||||||
| data_collator = MLMDataCollatorWithFlattening( | ||||||||||
| tokenizer=tokenizer, | ||||||||||
| mlm_probability=mlm_probability, | ||||||||||
| pad_to_multiple_of=token_micro_batch_size, | ||||||||||
| pad_to_multiple_of=token_micro_batch_size if pad_sequences_to_be_divisible_by is None else None, | ||||||||||
| seed=seed, | ||||||||||
| pad_sequences_to_be_divisible_by=pad_sequences_to_be_divisible_by, | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| # TODO(BIONEMO-3246) - remove the pin_memory=False once StatefulDataLoader supports pin_memory again. | ||||||||||
|
|
@@ -215,7 +218,158 @@ def create_thd_dataloader( | |||||||||
| collate_fn=data_collator, | ||||||||||
| num_workers=num_workers, | ||||||||||
| pin_memory=True if not use_stateful_dataloader else False, | ||||||||||
| persistent_workers=True, | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| return train_dataloader, tokenized_dataset | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def create_cp_dataloader( | ||||||||||
| distributed_config: DistributedConfig, | ||||||||||
| tokenizer_name: str, | ||||||||||
| load_dataset_kwargs: dict, | ||||||||||
| micro_batch_size: int | None = None, | ||||||||||
| token_micro_batch_size: int | None = None, | ||||||||||
| num_workers: int = 1, | ||||||||||
| max_seq_length: int = 1024, | ||||||||||
| seed: int = 42, | ||||||||||
| buffer_size: int = 10_000, | ||||||||||
| use_stateful_dataloader: bool = False, | ||||||||||
| mlm_probability: float = 0.15, | ||||||||||
| pad_sequences_to_be_divisible_by: int | None = None, | ||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. since you know the cp_world_size, can't we initialize this to the correct value for folks?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We could -- but if you also want to do FP8 + CP then this would need to be higher right? Since CP=2, Divisibility_factor=4, but you would need Divisibility_factor=16 for MXFP8 right? I can set it, but also make it togggleable.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's currently just set from the config |
||||||||||
| cp_world_size: int = 1, | ||||||||||
| cp_group: torch.distributed.ProcessGroup = None, | ||||||||||
| cp_rank: int = 0, | ||||||||||
| ): | ||||||||||
| """Create a dataloader that packs up to the maximum number of tokens per batch. This dataload is also | ||||||||||
| amenable toward context parallelism. It produces batches of data on CP rank 0, creates shards from that data for all other | ||||||||||
| CP ranks, and then scatters the shards to the other CP ranks. | ||||||||||
|
|
||||||||||
|
|
||||||||||
| Args: | ||||||||||
| distributed_config: The distributed configuration. | ||||||||||
| tokenizer_name: The name of the tokenizer to pull from the HuggingFace Hub. | ||||||||||
| load_dataset_kwargs: Keyword arguments to pass to `load_dataset` for the train dataset. | ||||||||||
| micro_batch_size: The batch size (number of sequences) per device. This will set the token_micro_batch_size to | ||||||||||
| micro_batch_size * max_seq_length. Defaults to None. | ||||||||||
| token_micro_batch_size: The maximum number of tokens per batch. If None, the micro_batch_size * max_seq_length | ||||||||||
| will be used. Defaults to None. | ||||||||||
| num_workers: The number of workers to use for the dataloader. For iterable datasets, this should be 1. | ||||||||||
| max_seq_length: The maximum length of the protein sequences. | ||||||||||
| seed: The seed to use for the distributed sampler and data collator. | ||||||||||
| buffer_size: The buffer size to use for the distributed sampler. | ||||||||||
| use_stateful_dataloader: Whether to use the StatefulDataLoader to enable checkpointing the dataloader state. | ||||||||||
| mlm_probability: The probability of masking tokens for MLM (default 0.15). Set to 0 for no masking. | ||||||||||
| pad_sequences_to_be_divisible_by: If provided, sequences will be padded to be divisible by this value. | ||||||||||
| This is useful for context parallelism. Defaults to None. | ||||||||||
| cp_world_size: The size of the context parallel group. | ||||||||||
| cp_group: The context parallel group. | ||||||||||
| cp_rank: The rank of the current context parallel process. | ||||||||||
| Returns: | ||||||||||
| A CPAwareDataloader that can be used for training. | ||||||||||
| """ | ||||||||||
| tokenized_dataset, tokenizer = create_tokenized_dataset( | ||||||||||
| distributed_config=distributed_config, | ||||||||||
| tokenizer_name=tokenizer_name, | ||||||||||
| load_dataset_kwargs=load_dataset_kwargs, | ||||||||||
| max_seq_length=max_seq_length, | ||||||||||
| buffer_size=buffer_size, | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| assert isinstance(tokenized_dataset, datasets.IterableDataset), "THD token packing requires a streaming dataset." | ||||||||||
| if token_micro_batch_size is None: | ||||||||||
| assert micro_batch_size is not None, "Only one of micro_batch_size or token_micro_batch_size can be provided." | ||||||||||
| token_micro_batch_size = micro_batch_size * max_seq_length | ||||||||||
| else: | ||||||||||
| assert micro_batch_size is None, "Only one of micro_batch_size or token_micro_batch_size can be provided." | ||||||||||
| assert token_micro_batch_size >= max_seq_length, "token_micro_batch_size must be greater than max_seq_length." | ||||||||||
|
|
||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
| # For THD, we pad out to the maximum number of tokens per batch for consistent array shapes. | ||||||||||
| data_collator = MLMDataCollatorWithFlattening( | ||||||||||
| tokenizer=tokenizer, | ||||||||||
| mlm_probability=mlm_probability, | ||||||||||
| pad_to_multiple_of=token_micro_batch_size if pad_sequences_to_be_divisible_by is None else None, | ||||||||||
| seed=seed, | ||||||||||
| pad_sequences_to_be_divisible_by=pad_sequences_to_be_divisible_by, | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| data_collator = MLMDataCollatorWithFlatteningCPAware( | ||||||||||
| collator=data_collator, | ||||||||||
| cp_world_size=cp_world_size, | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| # TODO(BIONEMO-3246) - remove the pin_memory=False once StatefulDataLoader supports pin_memory again. | ||||||||||
| dataloader_class = StatefulDataLoader if use_stateful_dataloader else DataLoader | ||||||||||
| train_dataloader = dataloader_class( | ||||||||||
| TokenPackingDataset(tokenized_dataset, max_tokens_per_batch=token_micro_batch_size), | ||||||||||
| batch_size=None, # The TokenPackingDataset will handle the batching. | ||||||||||
| collate_fn=data_collator, | ||||||||||
| num_workers=num_workers, | ||||||||||
| pin_memory=True if not use_stateful_dataloader else False, | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| return CPAwareDataloader(train_dataloader, cp_group, cp_rank), tokenized_dataset | ||||||||||
|
|
||||||||||
|
|
||||||||||
| class CPAwareDataloader: | ||||||||||
| """A dataloader that is aware of context parallelism.""" | ||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. again, just a quick summary of the main steps here -- This class handles synchronizing a single dataloader across multiple CP ranks. it materializes a dataloader instance on CP rank 0, which is responsible for splitting its inputs into sub-batches for each CP rank. It then uses torch.distributed.scatter to send the data to all cp ranks
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added |
||||||||||
| def __init__(self, dataloader: StatefulDataLoader, | ||||||||||
| cp_group: torch.distributed.ProcessGroup, | ||||||||||
| cp_rank: int, | ||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this you could probably get from
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||||||
| ): | ||||||||||
| """Initialize the CPAwareDataloader.""" | ||||||||||
| self.dataloader = dataloader | ||||||||||
| self.cp_rank = cp_rank | ||||||||||
| self.cp_group = cp_group | ||||||||||
| self.num_cp_ranks = cp_group.size() | ||||||||||
| self._iterator = None | ||||||||||
|
|
||||||||||
| def __iter__(self): | ||||||||||
| """Make the dataloader iterable.""" | ||||||||||
| self._iterator = iter(self.dataloader) # < --- collator output. | ||||||||||
| return self | ||||||||||
|
|
||||||||||
| def __next__(self): | ||||||||||
| """Get the batch from the dataloader for the current CP rank.""" | ||||||||||
| batch = self._send_data_to_cp_ranks() | ||||||||||
| batch['pad_between_seqs'] = True | ||||||||||
| return batch | ||||||||||
|
|
||||||||||
| def _send_data_to_cp_ranks(self): | ||||||||||
| """ | ||||||||||
| This function will get the batch from the dataloader on CP rank 0, and then determine | ||||||||||
| the shards for all the different CP group members. | ||||||||||
| combined_batch = [<cp_rank_0_shard>, <cp_rank_1_shard>, ..., <cp_rank_n_shard>] | ||||||||||
| Then it will scatter the shards to the different CP group members. | ||||||||||
| The shards are then combined into a single batch and returned to the caller | ||||||||||
| for the current CP rank. | ||||||||||
|
|
||||||||||
| Scalability: | ||||||||||
| Rank 0's work grows linearly with CP size, but the other ranks do not need to store all the shards so they do not | ||||||||||
| grow linearly with CP size. | ||||||||||
|
|
||||||||||
| Args: | ||||||||||
| None | ||||||||||
|
|
||||||||||
| Returns: | ||||||||||
| batch: The batch for the current CP rank. | ||||||||||
|
|
||||||||||
| """ | ||||||||||
| if self.cp_rank == 0: | ||||||||||
| # Get data once, then make copies for each rank. | ||||||||||
| if self._iterator is None: | ||||||||||
| self._iterator = iter(self.dataloader) | ||||||||||
| combined_batch = next(self._iterator) | ||||||||||
|
|
||||||||||
| else: | ||||||||||
| combined_batch = None | ||||||||||
|
|
||||||||||
| scatter_object_output_list = [None] | ||||||||||
| # Note: This does not provide an async_op handle. Thus its blocking. | ||||||||||
| torch.distributed.scatter_object_list( | ||||||||||
| scatter_object_output_list=scatter_object_output_list, | ||||||||||
| scatter_object_input_list=combined_batch, | ||||||||||
| group=self.cp_group, | ||||||||||
| group_src=0, | ||||||||||
| ) | ||||||||||
| torch.distributed.barrier(group=self.cp_group) # TODO(@jomitchell): Might not need this since its sync. | ||||||||||
|
||||||||||
| return scatter_object_output_list[0] | ||||||||||
|
|
||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pop the max_seq_lens stuff here, rather than in model.forward()