Commit 6bfe76e

add flash-attn

Signed-off-by: Peter St. John <[email protected]>
1 parent 01a36d3

File tree

3 files changed: 35 additions, 7 deletions
Lines changed: 16 additions & 6 deletions
```diff
@@ -1,24 +1,34 @@
-# An example, minimal Dockerfile to install dependencies in a fresh python environment with CUDA support.
+# An example, minimal Dockerfile to install dependencies in a fresh python environment with CUDA support. This image
+# ends up with two copies of CUDA libraries; the first is the one installed by the base image, and the second is brought
+# in when we pip install torch.
 
 FROM nvcr.io/nvidia/cuda:13.0.2-cudnn-devel-ubuntu24.04
 
 ENV UV_LINK_MODE=copy
 SHELL ["/bin/bash", "-c"]
 
-RUN mkdir -p /workspace && chown -R ubuntu:ubuntu /workspace
-USER ubuntu
+# Install torch, transformer-engine, and flash-attn
 RUN --mount=type=cache,target=/root/.cache/uv \
     --mount=type=cache,target=/root/.cache/pip \
-    --mount=type=bind,source=requirements.txt,target=/requirements.txt \
     --mount=from=ghcr.io/astral-sh/uv,source=/uv,target=/bin/uv \
     <<EOF
 uv venv --python 3.12 --seed /workspace/.venv
 source /workspace/.venv/bin/activate
 uv pip install torch==2.9.0 --index-url https://download.pytorch.org/whl/cu130
-uv pip install wheel packaging
+uv pip install wheel packaging psutil
+pip install --no-build-isolation flash-attn
 pip install --no-build-isolation transformer-engine[pytorch]==2.9.0
-uv pip install -r /requirements.txt
 EOF
 
+# Install recipe-specific dependencies
+RUN --mount=type=cache,target=/root/.cache/uv \
+    --mount=type=cache,target=/root/.cache/pip \
+    --mount=type=bind,source=requirements.txt,target=/requirements.txt \
+    --mount=from=ghcr.io/astral-sh/uv,source=/uv,target=/bin/uv \
+    uv pip install -r /requirements.txt
+
+USER ubuntu
+RUN chown -R ubuntu:ubuntu /workspace
+
 ENV PATH="/workspace/.venv/bin:$PATH"
 WORKDIR /workspace/bionemo
```
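To sanity-check the new build stages, one can build the image and import both GPU packages. A minimal sketch, assuming the recipe directory as the build context and a hypothetical `esm2-te` tag (neither is part of this commit):

```bash
# Build from the directory containing the Dockerfile and requirements.txt.
docker build -t esm2-te .

# Confirm flash-attn and transformer-engine import against the bundled torch.
docker run --rm --gpus all esm2-te \
    python -c "import flash_attn, transformer_engine.pytorch; print(flash_attn.__version__)"
```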

bionemo-recipes/recipes/esm2_native_te/README.md

Lines changed: 2 additions & 1 deletion
````diff
@@ -50,7 +50,8 @@ CUDA 13.0):
 uv venv --python 3.12 --seed /workspace/.venv
 source /workspace/.venv/bin/activate
 uv pip install torch==2.9.0 --index-url https://download.pytorch.org/whl/cu130
-uv pip install wheel packaging
+uv pip install wheel packaging psutil
+pip install --no-build-isolation flash-attn
 pip install --no-build-isolation transformer-engine[pytorch]==2.9.0
 uv pip install -r /requirements.txt
 ```
````
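The install order above matters: flash-attn builds from source, and its build setup imports torch, packaging, and psutil, which is why those are installed first and why `--no-build-isolation` makes pip reuse the active environment rather than a clean build sandbox. A quick functional check after installing, sketched here with flash-attn's public `flash_attn_func` API (requires a CUDA GPU; shapes are illustrative):

```bash
python -c "
import torch
from flash_attn import flash_attn_func

# (batch, seqlen, nheads, headdim); flash-attn requires fp16/bf16 tensors on CUDA.
q = torch.randn(1, 128, 8, 64, dtype=torch.bfloat16, device='cuda')
print(flash_attn_func(q, q, q, causal=True).shape)
"
```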

bionemo-recipes/recipes/esm2_native_te/tests/test_distributed_checkpointing.py

Lines changed: 17 additions & 0 deletions
```diff
@@ -73,6 +73,7 @@ def test_checkpoint_save_and_load_single_process_ddp(recipe_path, tmp_path):
         overrides=[
             f"checkpoint.ckpt_dir={temp_dir}",
             f"+wandb_init_args.dir={tmp_path}",
+            f"hydra.run.dir={tmp_path}",
             "num_train_steps=10",
             "checkpoint.save_every_n_steps=5",
             "checkpoint.resume_from_checkpoint=false",  # Start fresh
@@ -121,6 +122,7 @@ def test_checkpoint_save_and_load_single_process_ddp(recipe_path, tmp_path):
         overrides=[
             f"checkpoint.ckpt_dir={temp_dir}",
             f"+wandb_init_args.dir={tmp_path}",
+            f"hydra.run.dir={tmp_path}",
             "num_train_steps=15",
             "checkpoint.save_every_n_steps=5",
             "checkpoint.resume_from_checkpoint=true",  # Resume from checkpoint
@@ -205,6 +207,7 @@ def test_checkpoint_save_and_load_two_processes_ddp(recipe_path, tmp_path):
         "checkpoint.save_every_n_steps=5",
         "checkpoint.resume_from_checkpoint=false",  # Start fresh
         "dataset.use_stateful_dataloader=true",
+        f"hydra.run.dir={tmp_path}",
     ]
 
     result1 = subprocess.run(cmd_phase1, check=False, capture_output=True, text=True, env=env)
@@ -268,6 +271,7 @@ def test_checkpoint_save_and_load_two_processes_ddp(recipe_path, tmp_path):
         "checkpoint.save_every_n_steps=5",
         "checkpoint.resume_from_checkpoint=true",  # Resume from checkpoint
         "dataset.use_stateful_dataloader=true",
+        f"hydra.run.dir={tmp_path}",
     ]
 
     result2 = subprocess.run(cmd_phase2, check=False, capture_output=True, text=True, env=env)
@@ -346,6 +350,7 @@ def test_checkpoint_save_and_load_single_process_mfsdp(recipe_path, tmp_path):
         overrides=[
             f"checkpoint.ckpt_dir={temp_dir}",
             f"+wandb_init_args.dir={tmp_path}",
+            f"hydra.run.dir={tmp_path}",
             "num_train_steps=10",
             "checkpoint.save_every_n_steps=5",
             "checkpoint.resume_from_checkpoint=false",  # Start fresh
@@ -390,6 +395,7 @@ def test_checkpoint_save_and_load_single_process_mfsdp(recipe_path, tmp_path):
         overrides=[
             f"checkpoint.ckpt_dir={temp_dir}",
             f"+wandb_init_args.dir={tmp_path}",
+            f"hydra.run.dir={tmp_path}",
             "num_train_steps=15",
             "checkpoint.save_every_n_steps=5",
             "checkpoint.resume_from_checkpoint=true",  # Resume from checkpoint
@@ -457,6 +463,7 @@ def test_checkpoint_save_and_load_two_processes_mfsdp(recipe_path, tmp_path):
         "checkpoint.save_every_n_steps=5",
         "checkpoint.resume_from_checkpoint=false",  # Start fresh
         "dataset.use_stateful_dataloader=true",
+        f"hydra.run.dir={tmp_path}",
     ]
 
     result1 = subprocess.run(cmd_phase1, check=False, capture_output=True, text=True, env=env)
@@ -503,6 +510,7 @@ def test_checkpoint_save_and_load_two_processes_mfsdp(recipe_path, tmp_path):
         "checkpoint.save_every_n_steps=5",
         "checkpoint.resume_from_checkpoint=true",  # Resume from checkpoint
         "dataset.use_stateful_dataloader=true",
+        f"hydra.run.dir={tmp_path}",
     ]
 
     result2 = subprocess.run(cmd_phase2, check=False, capture_output=True, text=True, env=env)
@@ -559,6 +567,7 @@ def test_checkpoint_save_and_load_single_process_fsdp2(recipe_path, tmp_path):
         overrides=[
             f"checkpoint.ckpt_dir={temp_dir}",
             f"+wandb_init_args.dir={tmp_path}",
+            f"hydra.run.dir={tmp_path}",
             "num_train_steps=10",
             "checkpoint.save_every_n_steps=5",
             "checkpoint.resume_from_checkpoint=false",  # Start fresh
@@ -668,6 +677,7 @@ def test_checkpoint_save_and_load_two_processes_fsdp2(recipe_path, tmp_path):
         "num_train_steps=10",
         "checkpoint.save_every_n_steps=5",
         "dataset.use_stateful_dataloader=true",
+        f"hydra.run.dir={tmp_path}",
     ]
 
     result1 = subprocess.run(cmd_phase1, check=False, capture_output=True, text=True, env=env)
@@ -714,6 +724,7 @@ def test_checkpoint_save_and_load_two_processes_fsdp2(recipe_path, tmp_path):
         "checkpoint.save_every_n_steps=5",
         "checkpoint.resume_from_checkpoint=true",  # Resume from checkpoint
         "dataset.use_stateful_dataloader=true",
+        f"hydra.run.dir={tmp_path}",
     ]
 
     result2 = subprocess.run(cmd_phase2, check=False, capture_output=True, text=True, env=env)
@@ -797,6 +808,7 @@ def test_final_model_save_mfsdp(recipe_path, tmp_path):
         overrides=[
             f"checkpoint.ckpt_dir={temp_dir}",
             f"+wandb_init_args.dir={tmp_path}",
+            f"hydra.run.dir={tmp_path}",
             "num_train_steps=3",
             "checkpoint.save_final_model=true",
         ],
@@ -831,6 +843,7 @@ def test_final_model_save_fsdp2(recipe_path, tmp_path):
         overrides=[
             f"checkpoint.ckpt_dir={temp_dir}",
             f"+wandb_init_args.dir={tmp_path}",
+            f"hydra.run.dir={tmp_path}",
             "checkpoint.save_final_model=true",
             "num_train_steps=3",
         ],
@@ -874,6 +887,7 @@ def test_scheduler_resume_single_gpu(recipe_path, tmp_path):
         overrides=[
             f"checkpoint.ckpt_dir={temp_dir}",
             f"+wandb_init_args.dir={tmp_path}",
+            f"hydra.run.dir={tmp_path}",
             "num_train_steps=10",
             "checkpoint.save_every_n_steps=5",
             "checkpoint.resume_from_checkpoint=false",  # Start fresh, don't look for checkpoints
@@ -891,6 +905,7 @@ def test_scheduler_resume_single_gpu(recipe_path, tmp_path):
         overrides=[
             f"checkpoint.ckpt_dir={temp_dir}",
             f"+wandb_init_args.dir={tmp_path}",
+            f"hydra.run.dir={tmp_path}",
             "num_train_steps=15",
             "checkpoint.save_every_n_steps=5",
             "checkpoint.resume_from_checkpoint=true",  # Resume from checkpoint
@@ -951,6 +966,7 @@ def test_scheduler_resume_two_gpu(recipe_path, tmp_path):
         "checkpoint.resume_from_checkpoint=false",  # Start fresh, don't look for checkpoints
         "lr_scheduler_kwargs.num_warmup_steps=20",
         "lr_scheduler_kwargs.num_training_steps=100",
+        f"hydra.run.dir={tmp_path}",
     ]
 
     result1 = subprocess.run(cmd_phase1, check=False, capture_output=True, text=True, env=env)
@@ -974,6 +990,7 @@ def test_scheduler_resume_two_gpu(recipe_path, tmp_path):
         "checkpoint.resume_from_checkpoint=true",  # Resume from checkpoint
         "lr_scheduler_kwargs.num_warmup_steps=20",
         "lr_scheduler_kwargs.num_training_steps=100",
+        f"hydra.run.dir={tmp_path}",
     ]
 
     result2 = subprocess.run(cmd_phase2, check=False, capture_output=True, text=True, env=env)
```
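Every hunk above adds the same `hydra.run.dir={tmp_path}` override: by default Hydra creates an `outputs/<date>/<time>` run directory under the current working directory, so without the override each test run leaves output directories behind in the repo. Pinning the run dir to pytest's `tmp_path` keeps those artifacts in the per-test temporary directory. The shell equivalent of the overrides, sketched against a hypothetical `train.py` entry point (only the config keys come from the diff):

```bash
RUN_DIR=$(mktemp -d)
torchrun --nproc_per_node=2 train.py \
    checkpoint.ckpt_dir="$RUN_DIR/ckpts" \
    +wandb_init_args.dir="$RUN_DIR" \
    hydra.run.dir="$RUN_DIR" \
    num_train_steps=10 \
    checkpoint.save_every_n_steps=5 \
    checkpoint.resume_from_checkpoint=false
```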
