
Conversation

@savitha-eng
Collaborator

Description

This PR adds a nucleotide tokenizer for the llama3 model, following HuggingFace's PreTrainedTokenizerFast pattern with a Rust backend for performance. It follows NeMo special-token conventions (for EOS, BOS, PAD, etc.).

Key features:

  • ASCII-based tokenizer that maps nucleotides (A, T, C, G, N, R, Y) to their ASCII values
  • NeMo-style special tokens: EOS=0, PAD=1, BOS=2, UNK=3
  • Pre-built tokenizer files in nucleotide_fast_tokenizer/ directory
  • create_tokenizer.py script for reproducibility
  • Comprehensive unit tests covering tokenization, padding, and attention masks

This tokenizer will be used in genomic data loading pipelines for training llama3 models on nucleotide sequences.
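
For reference, a tokenizer with this behavior can be assembled from HuggingFace tokenizers building blocks and wrapped in PreTrainedTokenizerFast. The sketch below only illustrates the approach; it is not the actual contents of create_tokenizer.py, and the special-token strings are assumptions.

# Illustrative sketch only; not the actual create_tokenizer.py from this PR.
# The special-token strings ("<bos>", "<eos>", ...) are assumptions.
from tokenizers import Regex, Tokenizer, models, pre_tokenizers, processors
from transformers import PreTrainedTokenizerFast

# NeMo-style special token IDs: EOS=0, PAD=1, BOS=2, UNK=3
vocab = {"<eos>": 0, "<pad>": 1, "<bos>": 2, "<unk>": 3}
# Each supported nucleotide maps to its ASCII code, e.g. "A" -> 65
vocab.update({base: ord(base) for base in "ATCGNRY"})

tokenizer = Tokenizer(models.WordLevel(vocab=vocab, unk_token="<unk>"))
# Split the input into single characters so every base becomes one token
tokenizer.pre_tokenizer = pre_tokenizers.Split(Regex("."), behavior="isolated")
# Wrap every encoded sequence with BOS and EOS
tokenizer.post_processor = processors.TemplateProcessing(
    single="<bos> $A <eos>",
    special_tokens=[("<bos>", 2), ("<eos>", 0)],
)

fast_tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=tokenizer,
    bos_token="<bos>",
    eos_token="<eos>",
    pad_token="<pad>",
    unk_token="<unk>",
)
fast_tokenizer.save_pretrained("nucleotide_fast_tokenizer")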

Usage

from transformers import AutoTokenizer

# Load the nucleotide tokenizer
tokenizer = AutoTokenizer.from_pretrained("bionemo-recipes/models/llama3/nucleotide_fast_tokenizer")

# Tokenize a nucleotide sequence
sequence = "ATCGATCG"
encoded = tokenizer.encode(sequence, add_special_tokens=True)
# Returns: [2, 65, 84, 67, 71, 65, 84, 67, 71, 0]
# [BOS, A, T, C, G, A, T, C, G, EOS]

# Batch tokenization with padding
batch = tokenizer(["AAAA", "TTTTTTTT"], padding=True, return_tensors="pt")
# Returns: {'input_ids': tensor([[2, 65, 65, 65, 65, 0, 1, 1, 1, 1],
#                                  [2, 84, 84, 84, 84, 84, 84, 84, 84, 0]]),
#           'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
#                                      [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
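
Because each nucleotide ID is simply the ASCII code of the base, the mapping can be checked directly (an illustrative check continuing from the snippet above; not part of the PR):

# Without special tokens, each ID is ord() of the base character
ids = tokenizer.encode("ATCG", add_special_tokens=False)
assert ids == [ord(base) for base in "ATCG"]  # [65, 84, 67, 71]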

Type of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Refactor
  • Documentation update
  • Other (please describe):

CI Pipeline Configuration

Configure CI behavior by applying the relevant labels. By default, only basic unit tests are run.

  • ciflow:skip - Skip all CI tests for this PR
  • ciflow:notebooks - Run Jupyter notebooks execution tests for bionemo2
  • ciflow:slow - Run slow single GPU integration tests marked as @pytest.mark.slow for bionemo2
  • ciflow:all - Run all tests (unit tests, slow tests, and notebooks) for bionemo2. This label can be used to enforce running tests for all bionemo2.
  • ciflow:all-recipes - Run tests for all recipes (under bionemo-recipes). This label can be used to enforce running tests for all recipes.

Unit tests marked as @pytest.mark.multi_gpu or @pytest.mark.distributed are not run in the PR pipeline.

For more details, see CONTRIBUTING

Note

By default, only basic unit tests are run. Add appropriate labels to enable additional test coverage.

Authorizing CI Runs

We use copy-pr-bot to manage authorization of CI
runs on NVIDIA's compute resources.

  • If a pull request is opened by a trusted user and contains only trusted changes, the pull request's code will
    automatically be copied to a pull-request/ prefixed branch in the source repository (e.g. pull-request/123)
  • If a pull request is opened by an untrusted user or contains untrusted changes, an NVIDIA org member must leave an
    /ok to test comment on the pull request to trigger CI. This will need to be done for each new commit.

Pre-submit Checklist

  • I have tested these changes locally
  • I have updated the documentation accordingly
  • I have added/updated tests as needed
  • All existing tests pass successfully

@copy-pr-bot

copy-pr-bot bot commented Nov 7, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@savitha-eng savitha-eng marked this pull request as ready for review November 12, 2025 08:41
@savitha-eng savitha-eng requested a review from pstjohn November 12, 2025 22:42
@savitha-eng savitha-eng force-pushed the savitha/llama3-recipes-dataloader-add-tokenizer branch from f9ee9f8 to e88bf5c November 14, 2025 19:42
@pstjohn
Collaborator

pstjohn commented Nov 17, 2025

/ok to test bc7100f

@pstjohn pstjohn enabled auto-merge November 17, 2025 14:36
…1318)

### Description

This PR adds a genomic dataset module for training Llama3 models on
nucleotide sequences, following the ESM2 native TE pattern (much of the
code is similar to the ESM2 dataset).

**Key features:**
- **Streaming Parquet datasets**: Efficient loading of large genomic
datasets using HuggingFace `datasets` library with `streaming=True` to
avoid loading entire datasets into memory
- **Windowing/Strided sampling**: Automatic creation of overlapping
windows from long genomic sequences using the tokenizer's built-in
`return_overflowing_tokens=True` parameter
- **Shuffle buffer**: Large shuffle buffer (500K samples by default) for
better randomization during streaming
- **Distributed training support**: Built-in support for multi-GPU
training with `split_dataset_by_node`
- **Causal LM collation**: Uses `DataCollatorForLanguageModeling` with
`mlm=False` for next-token prediction (a minimal collation sketch follows this list)
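
A minimal sketch of what the collation step does, using stock HuggingFace components; the tokenizer path is a placeholder and this illustrates the behavior, not the module's exact wiring:

```python
# Minimal causal-LM collation sketch; tokenizer path is a placeholder.
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("/path/to/nucleotide_fast_tokenizer")
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

features = [tokenizer(seq) for seq in ["AAAA", "TTTTTTTT"]]
batch = collator(features)
# "labels" is a copy of "input_ids" with padded positions set to -100 so they
# are ignored by the loss; the model shifts labels internally when computing
# the next-token loss.
print(batch["input_ids"].shape, batch["labels"][0])
```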

**Implementation details:**
- `create_tokenized_dataset()`: Loads Parquet data, handles dataset
splits, applies tokenization with windowing (the streaming/splitting pieces are sketched after this list)
- `create_bshd_dataloader()`: Creates PyTorch DataLoader with
appropriate sampler and collator
- Supports both streaming and non-streaming modes
- Supports both lazy and eager tokenization
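
A rough sketch of the streaming and distributed pieces such a loader combines (the function and its internals are assumptions for illustration, not the PR's actual implementation):

```python
# Hypothetical sketch of streaming Parquet loading with per-node splitting
# and a shuffle buffer; names and defaults are illustrative only.
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node


def load_streaming_split(data_files, rank, world_size, buffer_size=500_000, seed=42):
    # Stream the Parquet shards so the full dataset never sits in memory
    ds = load_dataset("parquet", data_files=data_files, split="train", streaming=True)
    # Give each rank a disjoint shard of the stream for multi-GPU training
    ds = split_dataset_by_node(ds, rank=rank, world_size=world_size)
    # Approximate shuffling via a large in-memory buffer
    return ds.shuffle(seed=seed, buffer_size=buffer_size)
```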

**Testing:**
- 8 dataset tests covering windowing, streaming, lazy tokenization, and
batch structure
- Mock data fixtures in `conftest.py` for CI/CD compatibility
- Note that the tests focus on the single-node behavior of the
dataloader; distributed dataset tests following the ESM2 pattern are a TODO

#### Usage

```python
from dataset import create_bshd_dataloader
from distributed_config import DistributedConfig

# Configure distributed training (defaults to single GPU if env vars not set)
distributed_config = DistributedConfig()

# Create dataloader for genomic sequences
dataloader, sampler = create_bshd_dataloader(
    distributed_config=distributed_config,
    tokenizer_path="/path/to/nucleotide_fast_tokenizer",
    load_dataset_kwargs={
        "path": "parquet",
        "data_files": "/path/to/genomic_sequences.parquet",
        "split": "train",
        "streaming": True,  # Memory-efficient streaming
    },
    micro_batch_size=4,
    max_seq_length=8192,    # Window size
    stride=200,              # Overlap between windows (200 tokens)
    buffer_size=500_000,     # Shuffle buffer size
)

# Train
for batch in dataloader:
    # batch contains: input_ids, attention_mask, labels
    # labels = input_ids with padding masked to -100; the model shifts them
    # internally for next-token prediction
    optimizer.zero_grad()
    outputs = model(**batch)
    loss = outputs.loss
    loss.backward()
    optimizer.step()
```

**Key parameters:**
- `max_seq_length`: Window size for genomic sequences (e.g., 8192 for
full Llama3 context)
- `stride`: Overlap between consecutive windows in tokens (e.g., 200 for
200bp overlap); the windowing behavior is sketched after this list
- `streaming`: Set to `True` for large datasets to avoid loading
everything into memory
- `buffer_size`: Shuffle buffer size for streaming mode (larger = better
randomization)
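
The windowing driven by `max_seq_length` and `stride` comes from the tokenizer's built-in `return_overflowing_tokens` support. A toy illustration (deliberately small window size so the overlap is easy to see; the tokenizer path is a placeholder):

```python
# Toy strided-windowing example; window size and stride are deliberately tiny.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("/path/to/nucleotide_fast_tokenizer")

long_sequence = "ATCG" * 8  # 32 bases
encoded = tokenizer(
    long_sequence,
    max_length=16,                   # window size (max_seq_length)
    stride=4,                        # tokens shared between adjacent windows
    truncation=True,
    return_overflowing_tokens=True,  # emit every window, not just the first
)
# Each entry in input_ids is one overlapping window over the long sequence,
# so nothing past the first window is silently dropped.
for window in encoded["input_ids"]:
    print(len(window), window)
```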

### Type of changes

- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Refactor
- [ ] Documentation update
- [ ] Other (please describe):

### CI Pipeline Configuration

Configure CI behavior by applying the relevant labels. By default, only
basic unit tests are run.

- [ciflow:skip](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:skip) - Skip all CI tests for this PR
- [ciflow:notebooks](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:notebooks) - Run Jupyter notebooks execution tests for bionemo2
- [ciflow:slow](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:slow) - Run slow single GPU integration tests marked as @pytest.mark.slow for bionemo2
- [ciflow:all](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:all) - Run all tests (unit tests, slow tests, and notebooks) for bionemo2. This label can be used to enforce running tests for all bionemo2.
- [ciflow:all-recipes](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:all-recipes) - Run tests for all recipes (under bionemo-recipes). This label can be used to enforce running tests for all recipes.

Unit tests marked as `@pytest.mark.multi_gpu` or
`@pytest.mark.distributed` are not run in the PR pipeline.

For more details, see [CONTRIBUTING](CONTRIBUTING.md)

> [!NOTE]
> By default, only basic unit tests are run. Add appropriate labels to enable additional test coverage.

#### Authorizing CI Runs

We use
[copy-pr-bot](https://docs.gha-runners.nvidia.com/apps/copy-pr-bot/#automation)
to manage authorization of CI
runs on NVIDIA's compute resources.

- If a pull request is opened by a trusted user and contains only
trusted changes, the pull request's code will
automatically be copied to a pull-request/ prefixed branch in the source
repository (e.g. pull-request/123)
- If a pull request is opened by an untrusted user or contains untrusted
changes, an NVIDIA org member must leave an
`/ok to test` comment on the pull request to trigger CI. This will need
to be done for each new commit.

### Pre-submit Checklist

- [x] I have tested these changes locally
- [x] I have updated the documentation accordingly
- [x] I have added/updated tests as needed
- [x] All existing tests pass successfully

---------

Signed-off-by: savitha-eng <[email protected]>
pull bot pushed a commit to mahdi-shafiei/bionemo-framework that referenced this pull request Nov 17, 2025
merges NVIDIA#1318 and NVIDIA#1314 to main to start the llama3 recipe, fixes a few
pre-commit lints

---------

Signed-off-by: Peter St. John <[email protected]>
Co-authored-by: savitha-eng <[email protected]>