bionemo-recipes/recipes/llama3_native_te/checkpoint.py: 565 additions, 0 deletions (large diff not rendered by default)

@@ -0,0 +1,35 @@
# Example Tiny Llama3 Checkpoint

This directory contains the model and tokenizer configuration for a tiny Llama3 model (~9.6M parameters) intended for quick sanity checks and fast convergence tests on genomic sequences.

## Contents

- **config.json**: Model configuration for a tiny Llama3 model (4 layers, 384 hidden size)
- **tokenizer.json**: Fast tokenizer for nucleotide sequences (256 vocab size)
- **tokenizer_config.json**: Tokenizer configuration
- **special_tokens_map.json**: Special tokens mapping (EOS=0, PAD=1, BOS=2, UNK=3)

## Usage

Use this directory as the `model_tag` in your training configurations:

```yaml
# In your hydra config (e.g., L0_convergence configs)
model_tag: ./example_tiny_llama_checkpoint

dataset:
  tokenizer_path: ./example_tiny_llama_checkpoint # Same directory for tokenizer
```

Using a relative path avoids hard-coded absolute paths and keeps the configurations portable across environments.
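
To sanity-check the checkpoint directory from Python, you can load the config and tokenizer with the Hugging Face `transformers` auto classes. This is a minimal sketch; it assumes you run it from the recipe root so the relative path resolves.

```python
from transformers import AutoConfig, AutoTokenizer

# Load the model config and the nucleotide tokenizer from this directory.
config = AutoConfig.from_pretrained("./example_tiny_llama_checkpoint")
tokenizer = AutoTokenizer.from_pretrained("./example_tiny_llama_checkpoint")

print(config.num_hidden_layers, config.hidden_size)  # 4 384
print(tokenizer("ACGTACGT")["input_ids"])            # ids from the 256-entry nucleotide vocab
```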

## Model Parameters

- Layers: 4
- Hidden size: 384
- Attention heads: 6
- Intermediate size: 1536
- Vocabulary size: 256 (nucleotide tokenizer)
- Max position embeddings: 8192

This checkpoint is well suited to fast convergence testing, where the goal is to verify that the model can overfit a small dataset.
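
For a quick cross-check of the parameter count, you can build a randomly initialized model from the shipped `config.json`. This is a sketch; only the configuration is stored here, so `from_config` is used rather than `from_pretrained`.

```python
from transformers import AutoConfig, AutoModelForCausalLM

# Instantiate a randomly initialized tiny Llama3 from config.json (no weights are shipped).
config = AutoConfig.from_pretrained("./example_tiny_llama_checkpoint")
model = AutoModelForCausalLM.from_config(config)

n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e6:.2f}M parameters")  # ~9.64M: 4 layers, hidden size 384, untied 256-token embeddings
```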
@@ -0,0 +1,26 @@
{
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 2,
"eos_token_id": 0,
"head_dim": 64,
"hidden_act": "silu",
"hidden_size": 384,
"initializer_range": 0.02,
"intermediate_size": 1536,
"max_position_embeddings": 8192,
"mlp_bias": false,
"model_type": "llama",
"num_attention_heads": 6,
"num_hidden_layers": 4,
"num_key_value_heads": 6,
"pad_token_id": 1,
"pretraining_tp": 1,
"rms_norm_eps": 1e-05,
"rope_scaling": null,
"rope_theta": 500000.0,
"tie_word_embeddings": false,
"transformers_version": "4.57.1",
"use_cache": true,
"vocab_size": 256
}
Binary file not shown.
@@ -0,0 +1,79 @@
# @package _global_

# Convergence test configuration with tiny Llama model (~9.6M params)
# Tests that the model can overfit on a small dataset
# Works with both DDP and FSDP2, single GPU or multi-GPU

defaults:
  - defaults
  - _self_

# Use tiny Llama config for fast convergence testing
model_tag: ./example_checkpoint

# Training steps - enough to see convergence on small dataset
num_train_steps: 1000

# Dataset configuration - use small test dataset
dataset:
  tokenizer_path: ./example_checkpoint # Tokenizer included in checkpoint directory
  micro_batch_size: 1 # Conservative for single GPU
  num_workers: 2
  max_seq_length: 8192 # Full Llama3 context length
  stride: 400 # 400bp overlap for 8K context
  buffer_size: 10_000 # Smaller buffer for faster iteration
  use_lazy_tokenization: true
  use_stateful_dataloader: false # Until https://github.com/pytorch/pytorch/pull/163102 is resolved with torchdata.
  load_dataset_kwargs:
    path: "parquet"
    data_files: "genomic_sequences_2mb.parquet" # 2MB convergence test data in recipe directory
    split: "train"
    streaming: true # Use streaming to avoid loading entire dataset into memory

# Optimizer - higher LR for faster convergence on small model
adamw_kwargs:
  lr: 5e-4 # Higher than default for faster convergence
  fused: true
  betas: [0.9, 0.98]
  eps: 1e-8
  weight_decay: 0.01

# Learning rate scheduler
lr_scheduler_kwargs:
  num_warmup_steps: 100 # Quick warmup (10% of training)
  num_training_steps: 1000

# Checkpoint configuration - disabled for fast convergence testing
checkpoint:
  ckpt_dir: null # No checkpoints
  save_final_model: false # Don't save final model
  resume_from_checkpoint: false # Start fresh for convergence test
  save_every_n_steps: null # No intermediate checkpoints

# Logging - frequent logging to track convergence
logger:
  frequency: 10 # Log every 10 steps

# WandB configuration
wandb_init_args:
  project: "llama3-genomic-convergence"
  name: "tiny-llama-convergence-test"
  mode: "online" # Online mode for real-time dashboard
  tags:
    - convergence-test
    - tiny-model
    - 9.6M-params
    - 8192-context

# Meta device and torch compile
use_meta_device: false
use_torch_compile: false # Disable for debugging

# FP8 configuration - disabled for convergence testing
fp8_config:
  enabled: false
  fp8_recipe: transformer_engine.common.recipe.DelayedScaling
  fp8_format: "HYBRID"
  fp8_recipe_kwargs: {}
  fp8_model_init_kwargs:
    enabled: false
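
For reference, an experiment config like the one above can be composed and inspected with Hydra's Python API. This is only a sketch: the config directory name (`hydra_config`) and file name (`l0_convergence`) are assumptions about the recipe layout and are not shown in this diff.

```python
from hydra import compose, initialize

# Sketch: compose the recipe's Hydra config in Python and inspect a few values.
# "hydra_config" and "l0_convergence" are assumed names, not taken from this diff.
with initialize(version_base=None, config_path="hydra_config"):
    cfg = compose(config_name="l0_convergence")

print(cfg.model_tag)               # ./example_checkpoint
print(cfg.num_train_steps)         # 1000
print(cfg.dataset.max_seq_length)  # 8192
```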
@@ -0,0 +1,45 @@
defaults:
  - defaults
  - _self_

# Training config
model_tag: ./example_checkpoint # Use tiny Llama config for testing (4 layers, 384 hidden, ~9.6M params)
num_train_steps: 250

# We want this enabled in CI/CD to validate that the script runs successfully with torch.compile,
# but it is disabled here for faster startup during testing.
use_torch_compile: false

dataset:
  tokenizer_path: ./example_checkpoint # Tokenizer included in checkpoint directory
  micro_batch_size: 1 # Small batch size for limited GPU memory
  num_workers: 1
  max_seq_length: 1024 # Smaller window for testing
  stride: 100 # Smaller stride for testing
  buffer_size: 10_000 # Smaller buffer for testing
  use_lazy_tokenization: true
  use_stateful_dataloader: false # Until https://github.com/pytorch/pytorch/pull/163102 is resolved with torchdata.
  load_dataset_kwargs:
    path: "parquet"
    split: "train"
    data_files: "test_genomic_sequences.parquet" # Use local test file in recipe directory

# WandB config
wandb_init_args:
  name: "llama3_8B_genomic_sanity"
  mode: "offline"
  project: null # Set to null by default, override with +wandb_init_args.project=your-project

# Learning rate scheduler config
lr_scheduler_kwargs:
  num_warmup_steps: 10 # Shorter warmup for quick testing
  num_training_steps: 250 # Match num_train_steps

checkpoint:
  ckpt_dir: null
  resume_from_checkpoint: true
  save_every_n_steps: 50
  save_final_model: false

logger:
  frequency: 1
@@ -0,0 +1,80 @@
# Training config
model_tag: ??? # E.g., meta-llama/Meta-Llama-3-8B or a local path
num_train_steps: ???

# TODO: Once BIONEMO-2583 and BIONEMO-2719 are fixed, enable this by default and simplify training scripts to remove the
# meta-device conditional.
use_meta_device: false

# Whether to wrap the model in torch.compile. Note, this is currently not supported with mfsdp (BIONEMO-2977).
# We leave this off by default since we don't see much of a performance improvement with TE layers.
use_torch_compile: false

# Whether to use gradient checkpointing to trade compute for memory
use_gradient_checkpointing: false

dataset:
  tokenizer_path: ./example_checkpoint # Set to the path of your tokenizer (e.g., ./example_checkpoint)
  micro_batch_size: 8
  num_workers: 1
  max_seq_length: 8192 # Window size for genomic sequences
  stride: 200 # Overlap for windowing
  buffer_size: 500_000 # Shuffle buffer size
  use_lazy_tokenization: true
  use_stateful_dataloader: false # Until https://github.com/pytorch/pytorch/pull/163102 is resolved with torchdata.
  load_dataset_kwargs:
    path: "parquet"
    split: "train"
    streaming: true

# WandB config
wandb_init_args:
  name: ???
  project: null # Optional: set to your wandb project name

# mFSDP config
fully_shard_kwargs:
  zero_dp_strategy: "optim_grads_params"
  calculate_per_token_loss: false
  init_model_with_meta_device: ${use_meta_device}
  check_for_nan_in_grad: true
  grad_reduce_in_fp32: false
  preserve_fp32_weights: true
  overlap_grad_reduce: true
  overlap_param_gather: true
  sync_model_each_microbatch: true
  average_in_collective: false

# TransformerEngine FP8 config. See
# https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html for more information on
# supported formats.
fp8_config:
  enabled: false
  fp8_recipe: transformer_engine.common.recipe.DelayedScaling
  fp8_format: "HYBRID"
  fp8_recipe_kwargs: {}
  fp8_model_init_kwargs:
    enabled: false # If this is set to true, fp8_config.enabled must also be set to true.

# Optimizer config
adamw_kwargs:
  lr: 4e-4
  fused: true
  betas: [0.9, 0.98]
  eps: 1e-8
  weight_decay: 0.01

# Learning rate scheduler config
lr_scheduler_kwargs:
  num_warmup_steps: 2_000
  num_training_steps: 500_000

# Checkpoint config
checkpoint:
  ckpt_dir: ???
  save_final_model: true
  resume_from_checkpoint: true
  save_every_n_steps: 50

logger:
  frequency: 100
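
The `fp8_config` block above mirrors TransformerEngine's recipe API (see the linked FP8 primer). The sketch below shows how those keys typically map onto TE objects; the recipe's actual training script is not part of this diff, the `te.Linear` layer is only a stand-in, and FP8 execution additionally requires supported hardware (e.g., Hopper GPUs).

```python
# Sketch: how the fp8_config keys typically map onto TransformerEngine objects.
# The recipe's real wiring lives in its training script, which this diff does not include.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID)  # fp8_recipe + fp8_format: "HYBRID"

layer = te.Linear(384, 384).cuda()       # stand-in for the model's TE layers
x = torch.randn(16, 384, device="cuda")

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):  # corresponds to fp8_config.enabled
    y = layer(x)
```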