Commit 0d0f700

linting
Signed-off-by: Jonathan Mitchell <[email protected]>
1 parent f964f17 commit 0d0f700

3 files changed, +8 -8 lines changed

bionemo-recipes/recipes/esm2_native_te/tests/test_train.py

Lines changed: 2 additions & 1 deletion
@@ -555,6 +555,7 @@ def test_sanity_ddp_thd_token_packing_huggingface_model(tmp_path, recipe_path):
 
     main_ddp(sanity_config)
 
+
 def test_sanity_fsdp2_cp_thd_token_packing(tmp_path, monkeypatch, recipe_path):
     if torch.cuda.get_device_capability() == (12, 0):
         # TODO(BIONEMO-2840): On sm120, we need to set NVTE_FUSED_ATTN to 0 since TE will choose fused attn by default,
@@ -573,4 +574,4 @@ def test_sanity_fsdp2_cp_thd_token_packing(tmp_path, monkeypatch, recipe_path):
         ],
     )
 
-    main_fsdp2_cp(sanity_config)
\ No newline at end of file
+    main_fsdp2_cp(sanity_config)
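
The TODO(BIONEMO-2840) comment in the first hunk refers to the sm120 workaround applied at the top of this test. The test body itself is outside the diff, so the snippet below is only a sketch of how such a guard is commonly written with pytest's monkeypatch fixture; the helper name and placement are assumptions, not the recipe's code.

import torch

def _apply_sm120_workaround(monkeypatch):
    # Sketch (assumption): fall back from TE's default fused attention on sm120,
    # per the TODO(BIONEMO-2840) comment in the diff above.
    if torch.cuda.get_device_capability() == (12, 0):
        monkeypatch.setenv("NVTE_FUSED_ATTN", "0")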

bionemo-recipes/recipes/esm2_native_te/tests/test_train_two_gpu.py

Lines changed: 2 additions & 1 deletion
@@ -146,6 +146,7 @@ def test_multi_gpu_train_te_ddp_cp(tmp_path, recipe_path):
         recipe_path,
     )
 
+
 @requires_multi_gpu
 @requires_datacenter_hardware
 def test_multi_gpu_train_te_fsdp2_cp(tmp_path, recipe_path):
@@ -161,4 +162,4 @@ def test_multi_gpu_train_te_fsdp2_cp(tmp_path, recipe_path):
             "cp_size=2",
         ],
         recipe_path,
-    )
\ No newline at end of file
+    )
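
The @requires_multi_gpu and @requires_datacenter_hardware markers used above come from the recipe's test utilities and are not part of this diff. Below is a minimal sketch of what such skip markers typically look like; the conditions are illustrative assumptions, not the repository's definitions.

import pytest
import torch

# Sketch (assumption): illustrative skip markers; the recipe's actual conditions may differ.
requires_multi_gpu = pytest.mark.skipif(
    torch.cuda.device_count() < 2,
    reason="Requires at least two CUDA devices",
)
requires_datacenter_hardware = pytest.mark.skipif(
    not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 0),
    reason="Requires datacenter-class (sm80 or newer) GPUs",
)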

bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py

Lines changed: 4 additions & 6 deletions
@@ -31,7 +31,7 @@
 from transformers.models.esm.modeling_esm import EsmForMaskedLM  # noqa: F401
 
 from checkpoint import load_checkpoint_fsdp2, save_checkpoint_fsdp2, save_final_model_fsdp2, should_save_checkpoint
-from dataset import create_bshd_dataloader, create_cp_dataloader, create_thd_dataloader
+from dataset import create_cp_dataloader
 from distributed_config import DistributedConfig
 from perf_logger import PerfLogger
 from scheduler import get_linear_schedule_with_warmup
@@ -65,8 +65,6 @@ def main(args: DictConfig) -> float | None: # noqa: C901
     # Calculate DDP size (number of data parallel replicas)
     ddp_size = dist_config.world_size // args.cp_size
 
-
-
     # Create a device mesh for DDP and CP.
     # The mesh is organized as [CP_dimension, DDP_dimension] where:
     # - DDP dimension: number of data parallel replicas (world_size // cp_size)
@@ -97,7 +95,9 @@ def main(args: DictConfig) -> float | None: # noqa: C901
     )
 
     # Create an empty ESM-2 model with a masked language model head, e.g. "nvidia/esm2_t6_8M_UR50D".
-    config = AutoConfig.from_pretrained(args.model_tag, trust_remote_code=True, token_dropout=False, dtype=torch.bfloat16)
+    config = AutoConfig.from_pretrained(
+        args.model_tag, trust_remote_code=True, token_dropout=False, dtype=torch.bfloat16
+    )
     # If we're using sequence packing with TE layers, we need to pass the `attn_input_format` argument.
     if args.use_sequence_packing:
         config.attn_input_format = "thd"
@@ -136,7 +136,6 @@ def main(args: DictConfig) -> float | None: # noqa: C901
     for module in model.modules():
         if hasattr(module, "reset_parameters"):
             module.reset_parameters()
-
 
     # Context Parallelism requires THD Sequence Packing.
     assert args.use_sequence_packing, "Context Parallelism requires THD Sequence Packing."
@@ -148,7 +147,6 @@ def main(args: DictConfig) -> float | None: # noqa: C901
         cp_rank=cp_rank,
         **args.dataset,
     )
-
 
     if args.use_torch_compile:
         # If we're using torch.compile, we need to do this before loading the checkpoint to ensure key consistency.
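
For readers following the device-mesh comments in the second hunk ([CP_dimension, DDP_dimension], with ddp_size = world_size // cp_size): the mesh construction itself is not part of this diff, so the snippet below is only a sketch of that layout using torch.distributed.device_mesh. It assumes the process group is already initialized (e.g. by torchrun), and the variable names mirror the diff comments rather than the recipe's actual code.

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# Sketch (assumption): a 2-D [CP, DDP] mesh matching the layout described in the comments above.
cp_size = 2                                     # args.cp_size in the recipe config
world_size = dist.get_world_size()              # assumes torchrun already initialized the job
ddp_size = world_size // cp_size                # number of data-parallel replicas

mesh = init_device_mesh("cuda", (cp_size, ddp_size), mesh_dim_names=("cp", "ddp"))
cp_rank = mesh.get_local_rank("cp")             # e.g. handed to create_cp_dataloader(..., cp_rank=cp_rank)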
