 from transformers.models.esm.modeling_esm import EsmForMaskedLM  # noqa: F401

 from checkpoint import load_checkpoint_fsdp2, save_checkpoint_fsdp2, save_final_model_fsdp2, should_save_checkpoint
-from dataset import create_bshd_dataloader, create_cp_dataloader, create_thd_dataloader
+from dataset import create_cp_dataloader
 from distributed_config import DistributedConfig
 from perf_logger import PerfLogger
 from scheduler import get_linear_schedule_with_warmup
@@ -65,8 +65,6 @@ def main(args: DictConfig) -> float | None:  # noqa: C901
     # Calculate DDP size (number of data parallel replicas)
     ddp_size = dist_config.world_size // args.cp_size

-
-
     # Create a device mesh for DDP and CP.
     # The mesh is organized as [CP_dimension, DDP_dimension] where:
     # - DDP dimension: number of data parallel replicas (world_size // cp_size)
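(Aside: a minimal sketch of how a 2D DDP × CP device mesh like the one described in the comment above can be built with PyTorch's `init_device_mesh`. The helper name, the dimension order, and the sub-mesh usage are illustrative assumptions, not taken from this script.)

```python
# Illustrative sketch only; names and dimension order are assumptions.
from torch.distributed.device_mesh import init_device_mesh


def build_ddp_cp_mesh(world_size: int, cp_size: int):
    """Build a 2D mesh with a data-parallel axis and a context-parallel axis."""
    ddp_size = world_size // cp_size  # number of data-parallel replicas
    # Requires torch.distributed to be initialized with world_size == ddp_size * cp_size.
    mesh = init_device_mesh("cuda", (ddp_size, cp_size), mesh_dim_names=("ddp", "cp"))
    # 1D sub-meshes for the data-parallel and context-parallel process groups.
    return mesh["ddp"], mesh["cp"]
```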
@@ -97,7 +95,9 @@ def main(args: DictConfig) -> float | None:  # noqa: C901
     )

     # Create an empty ESM-2 model with a masked language model head, e.g. "nvidia/esm2_t6_8M_UR50D".
-    config = AutoConfig.from_pretrained(args.model_tag, trust_remote_code=True, token_dropout=False, dtype=torch.bfloat16)
+    config = AutoConfig.from_pretrained(
+        args.model_tag, trust_remote_code=True, token_dropout=False, dtype=torch.bfloat16
+    )
     # If we're using sequence packing with TE layers, we need to pass the `attn_input_format` argument.
     if args.use_sequence_packing:
         config.attn_input_format = "thd"
@@ -136,7 +136,6 @@ def main(args: DictConfig) -> float | None:  # noqa: C901
     for module in model.modules():
         if hasattr(module, "reset_parameters"):
             module.reset_parameters()
-

     # Context Parallelism requires THD Sequence Packing.
     assert args.use_sequence_packing, "Context Parallelism requires THD Sequence Packing."
@@ -148,7 +147,6 @@ def main(args: DictConfig) -> float | None:  # noqa: C901
         cp_rank=cp_rank,
         **args.dataset,
     )
-

     if args.use_torch_compile:
         # If we're using torch.compile, we need to do this before loading the checkpoint to ensure key consistency.
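(Aside on the ordering in that last comment: `torch.compile` wraps a module in an `OptimizedModule`, and in current PyTorch releases the wrapper's state-dict keys typically carry an `_orig_mod.` prefix, which is why compiling before restoring the checkpoint keeps the keys consistent. A minimal sketch under that assumption, using plain `state_dict`/`load_state_dict` rather than this repo's `load_checkpoint_fsdp2`:)

```python
# Sketch of the key-consistency issue; assumes torch.compile prefixes
# state-dict keys with "_orig_mod." via its OptimizedModule wrapper.
import torch
import torch.nn as nn

model = nn.Linear(4, 4)
compiled = torch.compile(model)  # compilation itself is lazy; only the wrapper is created here

print(list(model.state_dict())[0])     # expected: "weight"
print(list(compiled.state_dict())[0])  # expected: "_orig_mod.weight"

# A checkpoint saved from the compiled wrapper loads cleanly only into another
# compiled wrapper, so compile first and then restore the checkpoint.
ckpt = compiled.state_dict()
compiled.load_state_dict(ckpt)
```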