
Commit 0b2d0b4

savitha-eng committed
Training scripts, tests, and config for llama3; very similar to ESM2 native te
Signed-off-by: savitha-eng <[email protected]>
1 parent 5c0316e commit 0b2d0b4

File tree

10 files changed: +1599 −0 lines changed


bionemo-recipes/recipes/llama3/checkpoint.py

Lines changed: 566 additions & 0 deletions (large diff not rendered here)
Lines changed: 81 additions & 0 deletions
# @package _global_

# Convergence test configuration for DDP with tiny Llama model (~10M params)
# Tests that the model can overfit on a small genomic dataset (a 2MB parquet subset here)
# Single-GPU version

defaults:
  - defaults
  - _self_

# Use tiny Llama config for fast convergence testing
model_tag: /workspaces/bionemo-framework/bionemo-recipes/recipes/llama3/tiny_llama_config

# Training steps - enough to see convergence on small dataset
num_train_steps: 1000

# Dataset configuration - use 2MB subset
dataset:
  tokenizer_path: /workspaces/bionemo-framework/bionemo-recipes/models/llama3/nucleotide_fast_tokenizer
  micro_batch_size: 1  # Conservative for single GPU
  num_workers: 2
  max_seq_length: 8192  # Full Llama3 context length
  stride: 400  # 400bp overlap for 8K context
  buffer_size: 10_000  # Smaller buffer for faster iteration
  use_lazy_tokenization: true
  load_dataset_kwargs:
    path: "parquet"
    data_files: "/workspaces/bionemo-framework/data/genomic_sequences_2mb.parquet"
    split: "train"
    streaming: true  # Use streaming to avoid loading entire dataset into memory

# Optimizer - higher LR for faster convergence on small model
adamw_kwargs:
  lr: 5e-4  # Higher than default for faster convergence
  fused: true
  betas: [0.9, 0.98]
  eps: 1e-8
  weight_decay: 0.01

# Learning rate scheduler
lr_scheduler_kwargs:
  num_warmup_steps: 100  # Quick warmup (10% of training)
  num_training_steps: 1000

# Checkpoint configuration - disabled for fast convergence testing
checkpoint:
  ckpt_dir: null  # No checkpoints
  save_final_model: false  # Don't save final model
  resume_from_checkpoint: false  # Start fresh for convergence test
  save_every_n_steps: null  # No intermediate checkpoints

# Logging - frequent logging to track convergence
logger:
  frequency: 10  # Log every 10 steps

# WandB configuration
wandb_init_args:
  project: "llama3-genomic-convergence"
  name: "tiny-llama-ddp-convergence-test"
  mode: "online"  # Online mode for real-time dashboard
  tags:
    - convergence-test
    - ddp
    - tiny-model
    - 10M-params
    - single-gpu
    - 8192-context

# Meta device and torch compile
use_meta_device: false
use_torch_compile: false  # Disable for debugging

# FP8 configuration - disabled for convergence testing
fp8_config:
  enabled: false
  fp8_recipe: transformer_engine.common.recipe.DelayedScaling
  fp8_format: "HYBRID"
  fp8_recipe_kwargs: {}
  fp8_model_init_kwargs:
    enabled: false
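For reference, a minimal sketch of how a Hydra entrypoint might consume the adamw_kwargs and lr_scheduler_kwargs blocks above. The config_path, config_name, script layout, and the choice of a linear warmup/decay schedule are assumptions for illustration, not the recipe's actual training script.

# Hypothetical sketch: building the optimizer and LR scheduler from the Hydra config above.
import hydra
import torch
from omegaconf import DictConfig, OmegaConf
from transformers import AutoConfig, AutoModelForCausalLM, get_linear_schedule_with_warmup


@hydra.main(config_path="hydra_config", config_name="convergence_ddp", version_base="1.2")  # names are assumed
def main(cfg: DictConfig) -> None:
    # model_tag points at a local config directory, so build the model from its config.
    model_config = AutoConfig.from_pretrained(cfg.model_tag)
    model = AutoModelForCausalLM.from_config(model_config).cuda()

    # adamw_kwargs maps one-to-one onto torch.optim.AdamW keyword arguments.
    optimizer = torch.optim.AdamW(
        model.parameters(), **OmegaConf.to_container(cfg.adamw_kwargs, resolve=True)
    )

    # The recipe may use a different schedule; a linear warmup/decay scheduler stands in here.
    scheduler = get_linear_schedule_with_warmup(optimizer, **cfg.lr_scheduler_kwargs)

    for _ in range(cfg.num_train_steps):
        ...  # forward/backward elided; see the recipe's training script
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()


if __name__ == "__main__":
    main()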
Lines changed: 81 additions & 0 deletions
# @package _global_

# Convergence test configuration for FSDP2 with tiny Llama model (~10M params)
# Tests that the model can overfit on a small genomic dataset (a 2MB parquet subset here)
# Works with single GPU (no sharding) or multi-GPU (sharded)

defaults:
  - defaults
  - _self_

# Use tiny Llama config for fast convergence testing
model_tag: /workspaces/bionemo-framework/bionemo-recipes/recipes/llama3/tiny_llama_config

# Training steps - enough to see convergence on small dataset
num_train_steps: 1000

# Dataset configuration - use 2MB subset
dataset:
  tokenizer_path: /workspaces/bionemo-framework/bionemo-recipes/models/llama3/nucleotide_fast_tokenizer
  micro_batch_size: 1  # Conservative for single GPU
  num_workers: 2
  max_seq_length: 8192  # Full Llama3 context length
  stride: 400  # 400bp overlap for 8K context
  buffer_size: 10_000  # Smaller buffer for faster iteration
  use_lazy_tokenization: true
  load_dataset_kwargs:
    path: "parquet"
    data_files: "/workspaces/bionemo-framework/data/genomic_sequences_2mb.parquet"
    split: "train"
    streaming: true  # Use streaming to avoid loading entire dataset into memory

# Optimizer - higher LR for faster convergence on small model
adamw_kwargs:
  lr: 5e-4  # Higher than default for faster convergence
  fused: true
  betas: [0.9, 0.98]
  eps: 1e-8
  weight_decay: 0.01

# Learning rate scheduler
lr_scheduler_kwargs:
  num_warmup_steps: 100  # Quick warmup (10% of training)
  num_training_steps: 1000

# Checkpoint configuration - disabled for fast convergence testing
checkpoint:
  ckpt_dir: null  # No checkpoints
  save_final_model: false  # Don't save final model
  resume_from_checkpoint: false  # Start fresh for convergence test
  save_every_n_steps: null  # No intermediate checkpoints

# Logging - frequent logging to track convergence
logger:
  frequency: 10  # Log every 10 steps

# WandB configuration
wandb_init_args:
  project: "llama3-genomic-convergence"
  name: "tiny-llama-fsdp2-convergence-test"
  mode: "online"  # Online mode for real-time dashboard
  tags:
    - convergence-test
    - fsdp2
    - tiny-model
    - 10M-params
    - single-node
    - 8192-context

# Meta device and torch compile
use_meta_device: false
use_torch_compile: false  # Disable for debugging

# FP8 configuration - disabled for convergence testing
fp8_config:
  enabled: false
  fp8_recipe: transformer_engine.common.recipe.DelayedScaling
  fp8_format: "HYBRID"
  fp8_recipe_kwargs: {}
  fp8_model_init_kwargs:
    enabled: false
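As a rough sketch of how the dataset block above might be consumed (the recipe's actual dataloader module is not shown in this diff): load_dataset_kwargs unpacks directly into datasets.load_dataset, and buffer_size sizes the streaming shuffle buffer. The seed and the final peek at a record are illustrative only; column names depend on the parquet schema.

# Illustrative sketch: streaming parquet load with a shuffle buffer, per the dataset block above.
from datasets import load_dataset

# streaming=True returns an IterableDataset, so the parquet file is never fully loaded into memory.
ds = load_dataset(
    path="parquet",
    data_files="/workspaces/bionemo-framework/data/genomic_sequences_2mb.parquet",
    split="train",
    streaming=True,
)

# IterableDataset.shuffle() keeps a bounded in-memory buffer; buffer_size mirrors dataset.buffer_size above.
ds = ds.shuffle(seed=42, buffer_size=10_000)

# Peek at the first record to see the raw column layout.
print(next(iter(ds)))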
Lines changed: 44 additions & 0 deletions
defaults:
  - defaults
  - _self_

# Training config
model_tag: ./small_llama_config  # Use small Llama config for testing (4 layers, 2048 hidden)
num_train_steps: 250

# We want this on in CI/CD to validate that the script runs successfully with torch.compile.
use_torch_compile: false  # Disable for faster startup during testing

dataset:
  tokenizer_path: /workspaces/bionemo-framework/bionemo-recipes/models/llama3/nucleotide_fast_tokenizer
  micro_batch_size: 1  # Small batch size for limited GPU memory
  num_workers: 1
  max_seq_length: 1024  # Smaller window for testing
  stride: 100  # Smaller stride for testing
  buffer_size: 10_000  # Smaller buffer for testing
  use_lazy_tokenization: true
  load_dataset_kwargs:
    path: "parquet"
    split: "train"
    data_files: "test_genomic_sequences.parquet"  # Use local test file for now

# WandB config
wandb_init_args:
  name: "llama3_8B_genomic_sanity"
  mode: "offline"

# Learning rate scheduler config
lr_scheduler_kwargs:
  num_warmup_steps: 10  # Shorter warmup for quick testing
  num_training_steps: 250  # Match num_train_steps

checkpoint:
  ckpt_dir: null
  resume_from_checkpoint: true
  save_every_n_steps: 50
  save_final_model: false

logger:
  frequency: 1
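The max_seq_length and stride settings above correspond to overlapping-window tokenization. Below is a hedged sketch of how that looks with the Hugging Face fast tokenizer shipped in this recipe; the recipe's own preprocessing may differ (for example when use_lazy_tokenization is enabled), and the input string is made up.

# Sketch: overlapping-window tokenization driven by max_seq_length and stride from the config above.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "/workspaces/bionemo-framework/bionemo-recipes/models/llama3/nucleotide_fast_tokenizer"
)

long_sequence = "ACGT" * 2000  # made-up genomic string, longer than one window

windows = tokenizer(
    long_sequence,
    max_length=1024,                 # dataset.max_seq_length in this sanity config
    stride=100,                      # dataset.stride: tokens of overlap between consecutive windows
    truncation=True,
    return_overflowing_tokens=True,  # emit every window instead of only the first
)

print(len(windows["input_ids"]), "windows of up to 1024 tokens each")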
Lines changed: 78 additions & 0 deletions
# Training config
model_tag: ???  # E.g., meta-llama/Meta-Llama-3-8B or a local path
num_train_steps: ???

# TODO: Once BIONEMO-2583 and BIONEMO-2719 are fixed, enable this by default and simplify training scripts to remove the
# meta-device conditional.
use_meta_device: false

# Whether to wrap the model in torch.compile. Note, this is currently not supported with mfsdp (BIONEMO-2977).
# We leave this off by default since we don't see much of a performance improvement with TE layers.
use_torch_compile: false

dataset:
  tokenizer_path: /workspaces/bionemo-framework/bionemo-recipes/models/llama3/nucleotide_fast_tokenizer
  micro_batch_size: ???
  num_workers: 1
  max_seq_length: 8192  # Window size for genomic sequences
  stride: 200  # Overlap for windowing
  buffer_size: 500_000  # Shuffle buffer size
  use_lazy_tokenization: true
  load_dataset_kwargs:
    path: "parquet"
    split: "train"
    streaming: True

# WandB config
wandb_init_args:
  name: ???

# mFSDP config
fully_shard_kwargs:
  zero_dp_strategy: "optim_grads_params"
  calculate_per_token_loss: false
  init_model_with_meta_device: ${use_meta_device}
  check_for_nan_in_grad: true
  grad_reduce_in_fp32: false
  preserve_fp32_weights: true
  overlap_grad_reduce: true
  overlap_param_gather: true
  sync_model_each_microbatch: true
  average_in_collective: false

# TransformerEngine FP8 config. See
# https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html for more information on
# supported formats.
fp8_config:
  enabled: false
  fp8_recipe: transformer_engine.common.recipe.DelayedScaling
  fp8_format: "HYBRID"
  fp8_recipe_kwargs: {}
  fp8_model_init_kwargs:
    enabled: false  # If this is set to true, fp8_config.enabled must also be set to true.

# Optimizer config
adamw_kwargs:
  lr: 4e-4
  fused: true
  betas: [0.9, 0.98]
  eps: 1e-8
  weight_decay: 0.01

# Learning rate scheduler config
lr_scheduler_kwargs:
  num_warmup_steps: 2_000
  num_training_steps: 500_000

# Checkpoint config
checkpoint:
  ckpt_dir: ???
  save_final_model: true
  resume_from_checkpoint: true
  save_every_n_steps: 50

logger:
  frequency: 100
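For context on the fp8_config block, here is a rough sketch of how the dotted fp8_recipe path and fp8_format string can be turned into a TransformerEngine recipe and applied around the forward pass. It illustrates the TransformerEngine API, not necessarily how the recipe's training script wires it up; the te.Linear layer and random input are stand-ins.

# Sketch: resolving fp8_config into a TransformerEngine autocast context (assumes a CUDA GPU with FP8 support).
import torch
import transformer_engine.pytorch as te
from hydra.utils import get_class
from transformer_engine.common.recipe import Format

recipe_cls = get_class("transformer_engine.common.recipe.DelayedScaling")  # fp8_config.fp8_recipe
fp8_recipe = recipe_cls(fp8_format=Format.HYBRID)  # fp8_config.fp8_format; fp8_recipe_kwargs would be unpacked here too

# fp8_model_init_kwargs controls whether parameters are allocated directly in FP8 (off by default above).
with te.fp8_model_init(enabled=False):
    layer = te.Linear(256, 256).cuda()  # stand-in for the recipe's TE-based Llama layers

x = torch.randn(16, 256, device="cuda")

# fp8_config.enabled gates this context; the HYBRID format uses E4M3 in the forward pass and E5M2 for gradients.
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = layer(x)
out.sum().backward()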
