diff --git a/bionemo-recipes/models/llama3/create_tokenizer.py b/bionemo-recipes/models/llama3/create_tokenizer.py new file mode 100644 index 000000000..a2d32bde7 --- /dev/null +++ b/bionemo-recipes/models/llama3/create_tokenizer.py @@ -0,0 +1,126 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 + +"""Script to create the HuggingFace PreTrainedTokenizerFast for nucleotide sequences. + +This script creates a tokenizer that: +1. Maps each character to its ord() value (ASCII encoding) +2. Uses special tokens with NeMo convention (EOS=0, PAD=1, BOS=2, UNK=3) +3. Works with AutoTokenizer.from_pretrained() + +Run this script to regenerate the tokenizer files if needed. +""" + +import logging +import os + +from tokenizers import Tokenizer, processors +from tokenizers.models import WordLevel +from tokenizers.pre_tokenizers import Split +from transformers import PreTrainedTokenizerFast + + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def create_nucleotide_tokenizer( + eos_id: int = 0, + pad_id: int = 1, + bos_id: int = 2, + unk_id: int = 3, +) -> PreTrainedTokenizerFast: + """Create a PreTrainedTokenizerFast for nucleotide sequences. 
+ + Uses special token IDs for causal language modeling: + - BOS = 2 (beginning of sequence) + - EOS = 0 (end of sequence) + - PAD = 1 (padding) + - UNK = 3 (unknown) + + Args: + eos_id: End-of-sequence token ID (default: 0) + pad_id: Padding token ID (default: 1) + bos_id: Beginning-of-sequence token ID (default: 2) + unk_id: Unknown token ID (default: 3) + + Returns: + PreTrainedTokenizerFast ready to use and save + """ + # Define special tokens + special_tokens = { + "": bos_id, + "": eos_id, + "": pad_id, + "": unk_id, + } + + # Build vocab: Map each ASCII character to its ord() value + # IMPORTANT: Exclude reserved IDs for special tokens + reserved_ids = set(special_tokens.values()) + vocab = {chr(i): i for i in range(256) if i not in reserved_ids} + vocab = {**vocab, **special_tokens} + + # Create Rust tokenizer backend with WordLevel model + tokenizer = Tokenizer(WordLevel(vocab, unk_token="")) + + # Configure pre-tokenizer: Split into individual characters + tokenizer.pre_tokenizer = Split(pattern="", behavior="isolated") + + # Configure post-processor: Add BOS/EOS tokens automatically + tokenizer.post_processor = processors.TemplateProcessing( + single=" $A ", + pair=" $A $B ", + special_tokens=[ + ("", bos_id), + ("", eos_id), + ], + ) + + # Wrap in HuggingFace PreTrainedTokenizerFast + hf_tokenizer = PreTrainedTokenizerFast( + tokenizer_object=tokenizer, + unk_token="", + pad_token="", + eos_token="", + bos_token="", + ) + + return hf_tokenizer + + +def main(): + """Create and save the nucleotide tokenizer.""" + logger.info("Creating nucleotide tokenizer") + + # Create tokenizer with default settings (BOS=2, EOS=0, PAD=1, UNK=3) + tokenizer = create_nucleotide_tokenizer() + + logger.info(f"Vocab size: {tokenizer.vocab_size}") + logger.info( + f"Special tokens: BOS={tokenizer.bos_token_id}, EOS={tokenizer.eos_token_id}, PAD={tokenizer.pad_token_id}, UNK={tokenizer.unk_token_id}" + ) + + # Save to default location + save_path = os.path.join(os.path.dirname(__file__), "nucleotide_fast_tokenizer") + tokenizer.save_pretrained(save_path) + logger.info(f"Tokenizer saved to: {save_path}") + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/models/llama3/nucleotide_fast_tokenizer/special_tokens_map.json b/bionemo-recipes/models/llama3/nucleotide_fast_tokenizer/special_tokens_map.json new file mode 100644 index 000000000..a1e19488e --- /dev/null +++ b/bionemo-recipes/models/llama3/nucleotide_fast_tokenizer/special_tokens_map.json @@ -0,0 +1,6 @@ +{ + "bos_token": "", + "eos_token": "", + "pad_token": "", + "unk_token": "" +} diff --git a/bionemo-recipes/models/llama3/nucleotide_fast_tokenizer/tokenizer.json b/bionemo-recipes/models/llama3/nucleotide_fast_tokenizer/tokenizer.json new file mode 100644 index 000000000..5dbce19e6 --- /dev/null +++ b/bionemo-recipes/models/llama3/nucleotide_fast_tokenizer/tokenizer.json @@ -0,0 +1,396 @@ +{ + "version": "1.0", + "truncation": null, + "padding": null, + "added_tokens": [ + { + "id": 0, + "content": "", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 1, + "content": "", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 2, + "content": "", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 3, + "content": "", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + } + ], + 
"normalizer": null, + "pre_tokenizer": { + "type": "Split", + "pattern": { + "String": "" + }, + "behavior": "Isolated", + "invert": false + }, + "post_processor": { + "type": "TemplateProcessing", + "single": [ + { + "SpecialToken": { + "id": "", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "A", + "type_id": 0 + } + }, + { + "SpecialToken": { + "id": "", + "type_id": 0 + } + } + ], + "pair": [ + { + "SpecialToken": { + "id": "", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "A", + "type_id": 0 + } + }, + { + "SpecialToken": { + "id": "", + "type_id": 0 + } + }, + { + "SpecialToken": { + "id": "", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "B", + "type_id": 0 + } + }, + { + "SpecialToken": { + "id": "", + "type_id": 0 + } + } + ], + "special_tokens": { + "": { + "id": "", + "ids": [ + 2 + ], + "tokens": [ + "" + ] + }, + "": { + "id": "", + "ids": [ + 0 + ], + "tokens": [ + "" + ] + } + } + }, + "decoder": null, + "model": { + "type": "WordLevel", + "vocab": { + "": 0, + "": 1, + "": 2, + "": 3, + "\u0004": 4, + "\u0005": 5, + "\u0006": 6, + "\u0007": 7, + "\b": 8, + "\t": 9, + "\n": 10, + "\u000b": 11, + "\f": 12, + "\r": 13, + "\u000e": 14, + "\u000f": 15, + "\u0010": 16, + "\u0011": 17, + "\u0012": 18, + "\u0013": 19, + "\u0014": 20, + "\u0015": 21, + "\u0016": 22, + "\u0017": 23, + "\u0018": 24, + "\u0019": 25, + "\u001a": 26, + "\u001b": 27, + "\u001c": 28, + "\u001d": 29, + "\u001e": 30, + "\u001f": 31, + " ": 32, + "!": 33, + "\"": 34, + "#": 35, + "$": 36, + "%": 37, + "&": 38, + "'": 39, + "(": 40, + ")": 41, + "*": 42, + "+": 43, + ",": 44, + "-": 45, + ".": 46, + "/": 47, + "0": 48, + "1": 49, + "2": 50, + "3": 51, + "4": 52, + "5": 53, + "6": 54, + "7": 55, + "8": 56, + "9": 57, + ":": 58, + ";": 59, + "<": 60, + "=": 61, + ">": 62, + "?": 63, + "@": 64, + "A": 65, + "B": 66, + "C": 67, + "D": 68, + "E": 69, + "F": 70, + "G": 71, + "H": 72, + "I": 73, + "J": 74, + "K": 75, + "L": 76, + "M": 77, + "N": 78, + "O": 79, + "P": 80, + "Q": 81, + "R": 82, + "S": 83, + "T": 84, + "U": 85, + "V": 86, + "W": 87, + "X": 88, + "Y": 89, + "Z": 90, + "[": 91, + "\\": 92, + "]": 93, + "^": 94, + "_": 95, + "`": 96, + "a": 97, + "b": 98, + "c": 99, + "d": 100, + "e": 101, + "f": 102, + "g": 103, + "h": 104, + "i": 105, + "j": 106, + "k": 107, + "l": 108, + "m": 109, + "n": 110, + "o": 111, + "p": 112, + "q": 113, + "r": 114, + "s": 115, + "t": 116, + "u": 117, + "v": 118, + "w": 119, + "x": 120, + "y": 121, + "z": 122, + "{": 123, + "|": 124, + "}": 125, + "~": 126, + "": 127, + "€": 128, + "": 129, + "‚": 130, + "ƒ": 131, + "„": 132, + "…": 133, + "†": 134, + "‡": 135, + "ˆ": 136, + "‰": 137, + "Š": 138, + "‹": 139, + "Œ": 140, + "": 141, + "Ž": 142, + "": 143, + "": 144, + "‘": 145, + "’": 146, + "“": 147, + "”": 148, + "•": 149, + "–": 150, + "—": 151, + "˜": 152, + "™": 153, + "š": 154, + "›": 155, + "œ": 156, + "": 157, + "ž": 158, + "Ÿ": 159, + " ": 160, + "¡": 161, + "¢": 162, + "£": 163, + "¤": 164, + "¥": 165, + "¦": 166, + "§": 167, + "¨": 168, + "©": 169, + "ª": 170, + "«": 171, + "¬": 172, + "­": 173, + "®": 174, + "¯": 175, + "°": 176, + "±": 177, + "²": 178, + "³": 179, + "´": 180, + "µ": 181, + "¶": 182, + "·": 183, + "¸": 184, + "¹": 185, + "º": 186, + "»": 187, + "¼": 188, + "½": 189, + "¾": 190, + "¿": 191, + "À": 192, + "Á": 193, + "Â": 194, + "Ã": 195, + "Ä": 196, + "Å": 197, + "Æ": 198, + "Ç": 199, + "È": 200, + "É": 201, + "Ê": 202, + "Ë": 203, + "Ì": 204, + "Í": 205, + "Î": 206, + "Ï": 207, + "Ð": 208, + "Ñ": 209, + "Ò": 210, + "Ó": 211, + "Ô": 
212, + "Õ": 213, + "Ö": 214, + "×": 215, + "Ø": 216, + "Ù": 217, + "Ú": 218, + "Û": 219, + "Ü": 220, + "Ý": 221, + "Þ": 222, + "ß": 223, + "à": 224, + "á": 225, + "â": 226, + "ã": 227, + "ä": 228, + "å": 229, + "æ": 230, + "ç": 231, + "è": 232, + "é": 233, + "ê": 234, + "ë": 235, + "ì": 236, + "í": 237, + "î": 238, + "ï": 239, + "ð": 240, + "ñ": 241, + "ò": 242, + "ó": 243, + "ô": 244, + "õ": 245, + "ö": 246, + "÷": 247, + "ø": 248, + "ù": 249, + "ú": 250, + "û": 251, + "ü": 252, + "ý": 253, + "þ": 254, + "ÿ": 255 + }, + "unk_token": "" + } +} diff --git a/bionemo-recipes/models/llama3/nucleotide_fast_tokenizer/tokenizer_config.json b/bionemo-recipes/models/llama3/nucleotide_fast_tokenizer/tokenizer_config.json new file mode 100644 index 000000000..5e189bec3 --- /dev/null +++ b/bionemo-recipes/models/llama3/nucleotide_fast_tokenizer/tokenizer_config.json @@ -0,0 +1,44 @@ +{ + "added_tokens_decoder": { + "0": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "1": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "2": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "3": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + } + }, + "bos_token": "", + "clean_up_tokenization_spaces": false, + "eos_token": "", + "extra_special_tokens": {}, + "model_max_length": 1000000000000000019884624838656, + "pad_token": "", + "tokenizer_class": "PreTrainedTokenizerFast", + "unk_token": "" +} diff --git a/bionemo-recipes/models/llama3/tests/test_tokenizer.py b/bionemo-recipes/models/llama3/tests/test_tokenizer.py new file mode 100644 index 000000000..99b28a2fd --- /dev/null +++ b/bionemo-recipes/models/llama3/tests/test_tokenizer.py @@ -0,0 +1,276 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 + +""" +Unit tests for ASCII nucleotide tokenizer. 
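+
+These tests assume the ASCII token-ID scheme produced by create_tokenizer.py: EOS=0,
+PAD=1, BOS=2, UNK=3, with every other byte mapped to its ord() value (for example
+A=65, C=67, G=71, T=84).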
+""" + +from pathlib import Path + +import pytest +import torch +from transformers import AutoTokenizer + + +@pytest.fixture(scope="session") +def tokenizer(): + """Load the ASCII nucleotide tokenizer.""" + tokenizer_path = Path(__file__).parent.parent / "nucleotide_fast_tokenizer" + return AutoTokenizer.from_pretrained(str(tokenizer_path)) + + +def test_tokenizer_special_token_ids(tokenizer): + """Test that the tokenizer's special token IDs are correct (match NeMo)""" + assert tokenizer.eos_token_id == 0 + assert tokenizer.pad_token_id == 1 + assert tokenizer.bos_token_id == 2 + assert tokenizer.unk_token_id == 3 + + +def test_tokenizer_encode_simple_sequences(tokenizer): + """Test encoding a simple repeated character sequences.""" + sequence = "AAAA" + encoded = tokenizer.encode(sequence, add_special_tokens=True) + + # Expected: BOS + AAAA + EOS = [2, 65, 65, 65, 65, 0] + expected = [2, 65, 65, 65, 65, 0] + assert encoded == expected + + sequence = "C" + encoded = tokenizer.encode(sequence, add_special_tokens=True) + + # Expected: BOS + C + EOS = [2, 67, 0] + expected = [2, 67, 0] + assert encoded == expected + + sequence = "G" * 20 + encoded = tokenizer.encode(sequence, add_special_tokens=True) + expected = [2] + [71] * 20 + [0] + assert encoded == expected + + +def test_tokenizer_encode_without_special_tokens(tokenizer): + """Test encoding without BOS/EOS tokens.""" + sequence = "TTTT" + encoded = tokenizer.encode(sequence, add_special_tokens=False) + + # Expected: just the Ts (T=84) + expected = [84, 84, 84, 84] + assert encoded == expected + + +def test_tokenizer_roundtrip_encode_decode(tokenizer): + """Test that encoding and decoding produces the original sequence.""" + sequence = "ATCGATCG" + encoded = tokenizer.encode(sequence, add_special_tokens=True) + decoded = tokenizer.decode(encoded, skip_special_tokens=True) + + # Decoded may have spaces between tokens, so compare without spaces + assert sequence == decoded.replace(" ", "") + + +def test_tokenizer_nucleotide_mappings(tokenizer): + """Test each nucleotide maps to its ASCII value.""" + # A=65, T=84, C=67, G=71 + assert tokenizer.encode("A", add_special_tokens=False) == [65] + assert tokenizer.encode("T", add_special_tokens=False) == [84] + assert tokenizer.encode("C", add_special_tokens=False) == [67] + assert tokenizer.encode("G", add_special_tokens=False) == [71] + + +def test_tokenizer_padding_to_longest(tokenizer): + """Test padding pads to longest sequence in batch.""" + batch = tokenizer(["AAAA", "TTTTTTTT"], padding=True, add_special_tokens=True, return_tensors="pt") + + # AAAA → [2, 65, 65, 65, 65, 0] = 6 tokens + # TTTTTTTT → [2, 84, 84, 84, 84, 84, 84, 84, 84, 0] = 10 tokens + # Should pad to 10 + assert batch["input_ids"].shape == torch.Size([2, 10]) + + # First sequence should have padding (PAD=1) + assert batch["input_ids"][0, 6].item() == 1 # First padding position + assert batch["input_ids"][0, 9].item() == 1 # Last padding position + + # Attention mask: 1 for real tokens, 0 for padding + assert batch["attention_mask"][0, 5].item() == 1 # Last real token + assert batch["attention_mask"][0, 6].item() == 0 # First padding + + +def test_tokenizer_attention_mask_correct(tokenizer): + """Test attention mask is 1 for real tokens, 0 for padding.""" + batch = tokenizer(["GG", "GGGGGG"], padding=True, add_special_tokens=True, return_tensors="pt") + + # GG → 4 tokens (BOS + GG + EOS) + # GGGGGG → 8 tokens (BOS + GGGGGG + EOS) + # Padded to 8 tokens + + # First sequence: 4 real + 4 padding + expected_mask_0 = [1, 1, 1, 1, 
0, 0, 0, 0] + assert batch["attention_mask"][0].tolist() == expected_mask_0 + + # Second sequence: all real + expected_mask_1 = [1, 1, 1, 1, 1, 1, 1, 1] + assert batch["attention_mask"][1].tolist() == expected_mask_1 + + +def test_tokenizer_mixed_nucleotides(tokenizer): + """Test all standard nucleotides encode correctly.""" + sequence = "ATCGGTC" + encoded = tokenizer.encode(sequence, add_special_tokens=False) + + # A=65, T=84, C=67, G=71 + # ATCGGTC = A, T, C, G, G, T, C + expected = [65, 84, 67, 71, 71, 84, 67] + assert encoded == expected + + +def test_tokenizer_special_nucleotides(tokenizer): + """Test that sequences with ambiguity tokens (N, R, Y) encodes correctly.""" + sequence = "AANNNRY" + encoded = tokenizer.encode(sequence, add_special_tokens=False) + + # A=65, N=78, R=82, Y=89 + expected = [65, 65, 78, 78, 78, 82, 89] + assert encoded == expected + + +def test_10kbp_sequence_creates_expected_window_count(tokenizer): + """Test 10kbp sequence creates correct number of windows with seq_length=1000, stride=800. + + Verifies windowing math: 10000bp with seq_length=1000, stride=800. + """ + sequence = "A" * 10000 # 10kbp + + result = tokenizer( + sequence, + max_length=1000, + stride=800, # 800 token overlap + truncation=True, + return_overflowing_tokens=True, + add_special_tokens=True, + ) + + # Hardcoded expectation based on input data: + # 10000bp with 1000 token windows and 800 token stride + # Step forward = 1000 - 800 = 200 tokens per window + assert len(result["input_ids"]) == 47 + + +def test_overlapping_windows_creates_more_samples(tokenizer): + """Test overlapping stride creates more windows than less overlapping.""" + sequence = "ATCG" * 2500 # 10kbp + + result_more_overlap = tokenizer( + sequence, + max_length=1000, + stride=800, # 200 token step (80% overlap) + truncation=True, + return_overflowing_tokens=True, + add_special_tokens=True, + ) + + result_less_overlap = tokenizer( + sequence, + max_length=1000, + stride=500, # 500 token step (50% overlap) + truncation=True, + return_overflowing_tokens=True, + add_special_tokens=True, + ) + + # Hardcoded expectations + assert len(result_more_overlap["input_ids"]) == 47 # With more overlap (smaller step) + assert len(result_less_overlap["input_ids"]) == 20 # With less overlap (larger step) + assert len(result_more_overlap["input_ids"]) > len(result_less_overlap["input_ids"]) + + +def test_production_window_length_creates_expected_samples(tokenizer): + """Test production settings (8192 window, 200 overlap) create correct number of windows.""" + sequence = "A" * 50000 # 50kbp sequence + + result = tokenizer( + sequence, + max_length=8192, + stride=200, # 200 token overlap + truncation=True, + return_overflowing_tokens=True, + add_special_tokens=True, + ) + + # Hardcoded expectation with production settings: + # 50000bp with 8192 window and 200 stride (overlap) + # Step forward = 8192 - 200 = 7992 tokens per window + assert len(result["input_ids"]) == 7 + + +def test_short_sequences_dont_overflow(tokenizer): + """Test that short sequences (< max_length) don't create overflow windows.""" + sequence = "ATCG" * 100 # 400bp + + result = tokenizer( + sequence, + max_length=1000, + stride=800, + truncation=True, + return_overflowing_tokens=True, + add_special_tokens=True, + ) + + # Sequence is shorter than max_length, should only create 1 window + assert len(result["input_ids"]) == 1 + # Length should be 400bp + BOS + EOS = 402 tokens + assert len(result["input_ids"][0]) == 402 + + +def test_bos_eos_in_overlapping_windows(tokenizer): 
+ """Test that BOS/EOS tokens are added to every overlapping window. + + Verifies that when using return_overflowing_tokens with add_special_tokens=True, + each window gets its own BOS and EOS tokens, treating each as an independent sequence. + This matches the behavior needed for causal language modeling training. + """ + # Use a short genomic sequence that will produce exactly 2 overlapping windows + # With max_length=7 and stride=4, sequence of 8bp should give 2 windows + sequence = "ATCGATCG" # 8bp + + result = tokenizer( + sequence, + max_length=7, # BOS + 5 content + EOS = 7 tokens total + stride=4, # Overlap of 4 tokens between windows + truncation=True, + return_overflowing_tokens=True, + add_special_tokens=True, + ) + + # Should produce exactly 2 windows + num_windows = len(result["input_ids"]) + assert num_windows >= 2, f"Should produce at least 2 overlapping windows, got {num_windows}" + + first_window = result["input_ids"][0] + second_window = result["input_ids"][1] + + # Verify both windows have BOS at start and EOS at end + assert first_window[0] == tokenizer.bos_token_id + assert first_window[-1] == tokenizer.eos_token_id + assert second_window[0] == tokenizer.bos_token_id + assert second_window[-1] == tokenizer.eos_token_id + + # Verify windows are actually overlapping by checking they share some content + first_content = set(first_window[1:-1]) + second_content = set(second_window[1:-1]) + assert len(first_content & second_content) > 0 diff --git a/bionemo-recipes/recipes/llama3_native_te/dataset.py b/bionemo-recipes/recipes/llama3_native_te/dataset.py new file mode 100644 index 000000000..cd61a6b55 --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/dataset.py @@ -0,0 +1,167 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import datasets +import datasets.distributed +from distributed_config import DistributedConfig +from torch.utils.data import DataLoader, DistributedSampler +from torchdata.stateful_dataloader import StatefulDataLoader +from transformers import AutoTokenizer +from transformers.data.data_collator import DataCollatorForLanguageModeling + + +logger = logging.getLogger(__name__) + + +def create_tokenized_dataset( + distributed_config: DistributedConfig, + tokenizer_path: str, + load_dataset_kwargs: dict, + max_seq_length: int = 8192, + stride: int = 200, + buffer_size: int = 500_000, + use_lazy_tokenization: bool = True, + sequence_column: str = "sequence", +): + """Create a tokenized dataset with windowing. + + Args: + distributed_config: The distributed configuration. + tokenizer_path: Path to the nucleotide tokenizer directory. + load_dataset_kwargs: Keyword arguments to pass to `load_dataset`. + max_seq_length: The maximum length of sequences (window size). + stride: The stride for windowing (overlap = stride tokens). + buffer_size: The buffer size for shuffle. 
+ use_lazy_tokenization: Whether to use datasets.set_transform for tokenization. + sequence_column: Name of the column containing genomic sequences (default: "sequence"). + + Returns: + Tuple of (tokenized_dataset, tokenizer). + """ + logger.info(f"Loading dataset with kwargs: {load_dataset_kwargs}") + dataset = datasets.load_dataset(**load_dataset_kwargs) + logger.info(f"Loaded dataset: {dataset}") + + if isinstance(dataset, datasets.IterableDataset): + dataset = datasets.distributed.split_dataset_by_node( + dataset, + rank=distributed_config.rank, + world_size=distributed_config.world_size, + ) + dataset = dataset.shuffle(seed=42, buffer_size=buffer_size) + + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + + def tokenize_with_windowing(examples): + """Tokenize nucleotide sequences with windowing (one-to-many mapping).""" + # Tokenize with windowing using return_overflowing_tokens + result = tokenizer( + examples[sequence_column], + max_length=max_seq_length, + stride=stride, + truncation=True, + return_overflowing_tokens=True, + add_special_tokens=True, + ) + return result + + if isinstance(dataset, datasets.Dataset) and use_lazy_tokenization: + # Using dataset.map on a non-streaming dataset will automatically perform and cache the transform + tokenized_dataset = dataset.with_transform(tokenize_with_windowing) + else: + tokenized_dataset = dataset.map( + tokenize_with_windowing, + batched=True, + remove_columns=dataset.column_names, + ) + + return tokenized_dataset, tokenizer + + +def create_bshd_dataloader( + distributed_config: DistributedConfig, + tokenizer_path: str, + load_dataset_kwargs: dict, + micro_batch_size: int, + num_workers: int = 0, + max_seq_length: int = 8192, + stride: int = 200, + seed: int = 42, + buffer_size: int = 500_000, + use_lazy_tokenization: bool = True, + use_stateful_dataloader: bool = False, + sequence_column: str = "sequence", +): + """Create a BSHD dataloader for genomic sequences using CLM (causal language modeling). + + Args: + distributed_config: The distributed configuration. + tokenizer_path: Path to the nucleotide tokenizer directory. + load_dataset_kwargs: Keyword arguments to pass to `load_dataset`. + micro_batch_size: The batch size per device. + num_workers: The number of workers to use for the dataloader. + max_seq_length: The maximum length of sequences (window size). + stride: The stride for windowing (overlap = stride tokens). + seed: The seed to use for the distributed sampler and data collator. + buffer_size: The buffer size for shuffle. + use_lazy_tokenization: Whether to use datasets.set_transform for tokenization. + use_stateful_dataloader: Whether to use the StatefulDataLoader to enable checkpointing the dataloader state. + sequence_column: Name of the column containing genomic sequences (default: "sequence"). + + Returns: + A tuple of (dataloader, dataset_or_sampler). 
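+
+    Example (illustrative only; the parquet path below is a placeholder):
+
+        dist_config = DistributedConfig()
+        dataloader, _ = create_bshd_dataloader(
+            distributed_config=dist_config,
+            tokenizer_path="path/to/nucleotide_fast_tokenizer",
+            load_dataset_kwargs={"path": "parquet", "data_files": "sequences.parquet", "split": "train"},
+            micro_batch_size=2,
+        )
+        batch = next(iter(dataloader))  # dict with "input_ids", "attention_mask", "labels"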
+ """ + tokenized_dataset, tokenizer = create_tokenized_dataset( + distributed_config=distributed_config, + tokenizer_path=tokenizer_path, + load_dataset_kwargs=load_dataset_kwargs, + max_seq_length=max_seq_length, + stride=stride, + buffer_size=buffer_size, + use_lazy_tokenization=use_lazy_tokenization, + sequence_column=sequence_column, + ) + + if isinstance(tokenized_dataset, datasets.IterableDataset): + sampler = None + else: + sampler = DistributedSampler( + tokenized_dataset, + rank=distributed_config.rank, + num_replicas=distributed_config.world_size, + seed=seed, + ) + + # Use DataCollatorForLanguageModeling with mlm=False for CLM + data_collator = DataCollatorForLanguageModeling( + tokenizer=tokenizer, + mlm=False, # Causal language modeling (no masking) + ) + + # TODO(BIONEMO-3246) - remove the pin_memory=False once StatefulDataLoader supports pin_memory again. + dataloader_class = StatefulDataLoader if use_stateful_dataloader else DataLoader + train_dataloader = dataloader_class( + tokenized_dataset, + sampler=sampler, + batch_size=micro_batch_size, + collate_fn=data_collator, + num_workers=num_workers, + pin_memory=True if not use_stateful_dataloader else False, + persistent_workers=num_workers > 0, + ) + + return train_dataloader, tokenized_dataset if sampler is None else sampler diff --git a/bionemo-recipes/recipes/llama3_native_te/distributed_config.py b/bionemo-recipes/recipes/llama3_native_te/distributed_config.py new file mode 100644 index 000000000..271a5ffcf --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/distributed_config.py @@ -0,0 +1,44 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +from dataclasses import dataclass, field + + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class DistributedConfig: + """Class to track distributed ranks and handle basic distributed training setup. + + If torch distributed environment variables are not set, we set them to default values for single-process training. + + Attributes: + rank: The rank of the process. + local_rank: The local rank of the process. + world_size: The total number of processes. 
+ """ + + rank: int = field(default_factory=lambda: int(os.environ.setdefault("RANK", "0"))) + local_rank: int = field(default_factory=lambda: int(os.environ.setdefault("LOCAL_RANK", "0"))) + world_size: int = field(default_factory=lambda: int(os.environ.setdefault("WORLD_SIZE", "1"))) + _master_addr: str = field(default_factory=lambda: os.environ.setdefault("MASTER_ADDR", "localhost")) + _master_port: str = field(default_factory=lambda: os.environ.setdefault("MASTER_PORT", "12355")) + + def is_main_process(self) -> bool: + """This is the global rank 0 process, to be used for wandb logging, etc.""" + return self.rank == 0 diff --git a/bionemo-recipes/recipes/llama3_native_te/example_checkpoint/special_tokens_map.json b/bionemo-recipes/recipes/llama3_native_te/example_checkpoint/special_tokens_map.json new file mode 100644 index 000000000..a1e19488e --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/example_checkpoint/special_tokens_map.json @@ -0,0 +1,6 @@ +{ + "bos_token": "", + "eos_token": "", + "pad_token": "", + "unk_token": "" +} diff --git a/bionemo-recipes/recipes/llama3_native_te/example_checkpoint/tokenizer.json b/bionemo-recipes/recipes/llama3_native_te/example_checkpoint/tokenizer.json new file mode 100644 index 000000000..5dbce19e6 --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/example_checkpoint/tokenizer.json @@ -0,0 +1,396 @@ +{ + "version": "1.0", + "truncation": null, + "padding": null, + "added_tokens": [ + { + "id": 0, + "content": "", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 1, + "content": "", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 2, + "content": "", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 3, + "content": "", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + } + ], + "normalizer": null, + "pre_tokenizer": { + "type": "Split", + "pattern": { + "String": "" + }, + "behavior": "Isolated", + "invert": false + }, + "post_processor": { + "type": "TemplateProcessing", + "single": [ + { + "SpecialToken": { + "id": "", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "A", + "type_id": 0 + } + }, + { + "SpecialToken": { + "id": "", + "type_id": 0 + } + } + ], + "pair": [ + { + "SpecialToken": { + "id": "", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "A", + "type_id": 0 + } + }, + { + "SpecialToken": { + "id": "", + "type_id": 0 + } + }, + { + "SpecialToken": { + "id": "", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "B", + "type_id": 0 + } + }, + { + "SpecialToken": { + "id": "", + "type_id": 0 + } + } + ], + "special_tokens": { + "": { + "id": "", + "ids": [ + 2 + ], + "tokens": [ + "" + ] + }, + "": { + "id": "", + "ids": [ + 0 + ], + "tokens": [ + "" + ] + } + } + }, + "decoder": null, + "model": { + "type": "WordLevel", + "vocab": { + "": 0, + "": 1, + "": 2, + "": 3, + "\u0004": 4, + "\u0005": 5, + "\u0006": 6, + "\u0007": 7, + "\b": 8, + "\t": 9, + "\n": 10, + "\u000b": 11, + "\f": 12, + "\r": 13, + "\u000e": 14, + "\u000f": 15, + "\u0010": 16, + "\u0011": 17, + "\u0012": 18, + "\u0013": 19, + "\u0014": 20, + "\u0015": 21, + "\u0016": 22, + "\u0017": 23, + "\u0018": 24, + "\u0019": 25, + "\u001a": 26, + "\u001b": 27, + "\u001c": 28, + "\u001d": 29, + "\u001e": 30, + "\u001f": 31, + " ": 32, + "!": 33, + "\"": 34, + "#": 35, + "$": 36, + 
"%": 37, + "&": 38, + "'": 39, + "(": 40, + ")": 41, + "*": 42, + "+": 43, + ",": 44, + "-": 45, + ".": 46, + "/": 47, + "0": 48, + "1": 49, + "2": 50, + "3": 51, + "4": 52, + "5": 53, + "6": 54, + "7": 55, + "8": 56, + "9": 57, + ":": 58, + ";": 59, + "<": 60, + "=": 61, + ">": 62, + "?": 63, + "@": 64, + "A": 65, + "B": 66, + "C": 67, + "D": 68, + "E": 69, + "F": 70, + "G": 71, + "H": 72, + "I": 73, + "J": 74, + "K": 75, + "L": 76, + "M": 77, + "N": 78, + "O": 79, + "P": 80, + "Q": 81, + "R": 82, + "S": 83, + "T": 84, + "U": 85, + "V": 86, + "W": 87, + "X": 88, + "Y": 89, + "Z": 90, + "[": 91, + "\\": 92, + "]": 93, + "^": 94, + "_": 95, + "`": 96, + "a": 97, + "b": 98, + "c": 99, + "d": 100, + "e": 101, + "f": 102, + "g": 103, + "h": 104, + "i": 105, + "j": 106, + "k": 107, + "l": 108, + "m": 109, + "n": 110, + "o": 111, + "p": 112, + "q": 113, + "r": 114, + "s": 115, + "t": 116, + "u": 117, + "v": 118, + "w": 119, + "x": 120, + "y": 121, + "z": 122, + "{": 123, + "|": 124, + "}": 125, + "~": 126, + "": 127, + "€": 128, + "": 129, + "‚": 130, + "ƒ": 131, + "„": 132, + "…": 133, + "†": 134, + "‡": 135, + "ˆ": 136, + "‰": 137, + "Š": 138, + "‹": 139, + "Œ": 140, + "": 141, + "Ž": 142, + "": 143, + "": 144, + "‘": 145, + "’": 146, + "“": 147, + "”": 148, + "•": 149, + "–": 150, + "—": 151, + "˜": 152, + "™": 153, + "š": 154, + "›": 155, + "œ": 156, + "": 157, + "ž": 158, + "Ÿ": 159, + " ": 160, + "¡": 161, + "¢": 162, + "£": 163, + "¤": 164, + "¥": 165, + "¦": 166, + "§": 167, + "¨": 168, + "©": 169, + "ª": 170, + "«": 171, + "¬": 172, + "­": 173, + "®": 174, + "¯": 175, + "°": 176, + "±": 177, + "²": 178, + "³": 179, + "´": 180, + "µ": 181, + "¶": 182, + "·": 183, + "¸": 184, + "¹": 185, + "º": 186, + "»": 187, + "¼": 188, + "½": 189, + "¾": 190, + "¿": 191, + "À": 192, + "Á": 193, + "Â": 194, + "Ã": 195, + "Ä": 196, + "Å": 197, + "Æ": 198, + "Ç": 199, + "È": 200, + "É": 201, + "Ê": 202, + "Ë": 203, + "Ì": 204, + "Í": 205, + "Î": 206, + "Ï": 207, + "Ð": 208, + "Ñ": 209, + "Ò": 210, + "Ó": 211, + "Ô": 212, + "Õ": 213, + "Ö": 214, + "×": 215, + "Ø": 216, + "Ù": 217, + "Ú": 218, + "Û": 219, + "Ü": 220, + "Ý": 221, + "Þ": 222, + "ß": 223, + "à": 224, + "á": 225, + "â": 226, + "ã": 227, + "ä": 228, + "å": 229, + "æ": 230, + "ç": 231, + "è": 232, + "é": 233, + "ê": 234, + "ë": 235, + "ì": 236, + "í": 237, + "î": 238, + "ï": 239, + "ð": 240, + "ñ": 241, + "ò": 242, + "ó": 243, + "ô": 244, + "õ": 245, + "ö": 246, + "÷": 247, + "ø": 248, + "ù": 249, + "ú": 250, + "û": 251, + "ü": 252, + "ý": 253, + "þ": 254, + "ÿ": 255 + }, + "unk_token": "" + } +} diff --git a/bionemo-recipes/recipes/llama3_native_te/example_checkpoint/tokenizer_config.json b/bionemo-recipes/recipes/llama3_native_te/example_checkpoint/tokenizer_config.json new file mode 100644 index 000000000..5e189bec3 --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/example_checkpoint/tokenizer_config.json @@ -0,0 +1,44 @@ +{ + "added_tokens_decoder": { + "0": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "1": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "2": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "3": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + } + }, + "bos_token": "", + 
"clean_up_tokenization_spaces": false, + "eos_token": "", + "extra_special_tokens": {}, + "model_max_length": 1000000000000000019884624838656, + "pad_token": "", + "tokenizer_class": "PreTrainedTokenizerFast", + "unk_token": "" +} diff --git a/bionemo-recipes/recipes/llama3_native_te/requirements.txt b/bionemo-recipes/recipes/llama3_native_te/requirements.txt new file mode 100644 index 000000000..794de698a --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/requirements.txt @@ -0,0 +1,10 @@ +datasets +hydra-core +torch +torchao!=0.14.0 +torchdata +torchmetrics +tqdm +transformer_engine[pytorch] +transformers +wandb diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/conftest.py b/bionemo-recipes/recipes/llama3_native_te/tests/conftest.py new file mode 100644 index 000000000..c811180c3 --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/tests/conftest.py @@ -0,0 +1,107 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from pathlib import Path +from unittest import mock + +import pyarrow as pa +import pyarrow.parquet as pq +import pytest +import torch +from torch.distributed.device_mesh import _mesh_resources, init_device_mesh + + +sys.path.append(Path(__file__).parent.parent.as_posix()) +sys.path.append(Path(__file__).parent.as_posix()) + +from distributed_config import DistributedConfig + + +@pytest.fixture +def recipe_path() -> Path: + """Return the root directory of the recipe.""" + return Path(__file__).parent.parent + + +@pytest.fixture(scope="session") +def mock_genomic_parquet(tmp_path_factory) -> Path: + """Create a mock genomic sequences parquet file for testing. + + This fixture creates a small parquet file with synthetic genomic sequences + that can be used for training tests without relying on external data files. + + Returns: + Path to the generated parquet file + """ + tmp_dir = tmp_path_factory.mktemp("data") + parquet_path = tmp_dir / "test_genomic_sequences.parquet" + + # Create mock genomic sequences with simple repeating patterns + # These are easy for the model to overfit to, which is perfect for sanity tests + sequences = [ + "ATCG" * 300, # 1200 bp - simple ATCG repeat + "AAAA" * 250 + "TTTT" * 250, # 2000 bp - alternating A and T blocks + "GCGC" * 200, # 800 bp - GC repeat + "ACGT" * 400, # 1600 bp - all 4 nucleotides + "TGCA" * 350, # 1400 bp - reverse pattern + ] + + # Create parquet table with 'sequence' column + table = pa.table( + { + "sequence": sequences, + } + ) + + pq.write_table(table, parquet_path) + return parquet_path + + +@pytest.fixture(scope="session", autouse=True) +def device_mesh(): + """Create a re-usable device mesh for testing. + + This is a "auto-use", session-scope fixture so that a single device mesh is created and used in all tests. 
+
+    Megatron-FSDP runs into issues when re-creating the torch device mesh in the same process, starting in the 25.09
+    NGC pytorch container release. To work around this, we create a re-usable device mesh that is used in all
+    single-process tests.
+    """
+    # Initialize the distributed configuration, including creating the distributed process group.
+    dist_config = DistributedConfig()
+    device = torch.device(f"cuda:{dist_config.local_rank}")
+    torch.distributed.init_process_group(backend="nccl", device_id=device)
+    torch.cuda.set_device(dist_config.local_rank)
+    device_mesh = init_device_mesh("cuda", mesh_shape=(1, 1), mesh_dim_names=("dp", "tp"))
+
+    # Mock these torch.distributed functions so that we re-use the same device mesh, and don't re-create or destroy the
+    # global process group.
+    with (
+        mock.patch("torch.distributed.device_mesh.init_device_mesh", return_value=device_mesh),
+        mock.patch("torch.distributed.init_process_group", return_value=None),
+        mock.patch("torch.distributed.destroy_process_group", return_value=None),
+    ):
+        yield
+
+    # At the end of all tests, destroy the process group and clear the device mesh resources.
+    torch.distributed.destroy_process_group()
+    _mesh_resources.mesh_stack.clear()
+    _mesh_resources.child_to_root_mapping.clear()
+    _mesh_resources.root_to_flatten_mapping.clear()
+    _mesh_resources.flatten_name_to_root_dims.clear()
+    _mesh_resources.mesh_dim_group_options.clear()
+    torch.cuda.empty_cache()
+    torch.cuda.synchronize()
diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_dataset.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_dataset.py
new file mode 100644
index 000000000..31df7b321
--- /dev/null
+++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_dataset.py
@@ -0,0 +1,414 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pyarrow as pa
+import pyarrow.parquet as pq
+import pytest
+import torch
+from dataset import create_bshd_dataloader, create_tokenized_dataset
+from distributed_config import DistributedConfig
+
+
+@pytest.fixture
+def tokenizer_path(recipe_path):
+    """Get the path to the nucleotide tokenizer."""
+    return str(recipe_path / "example_checkpoint")
+
+
+@pytest.fixture
+def simple_parquet(tmp_path):
+    """Create a simple Parquet file with multiple genomic sequences for testing batching."""
+    parquet_path = tmp_path / "genomic_sequences.parquet"
+
+    # Create multiple sequences of varying lengths for better batching tests
+    sequences = [
+        "A" * 1000,
+        "T" * 1200,
+        "C" * 800,
+        "G" * 1500,
+        "ATCG" * 300,
+    ]
+
+    table = pa.table(
+        {
+            "sequence": sequences,
+        }
+    )
+
+    pq.write_table(table, parquet_path)
+    return str(parquet_path)
+
+
+def test_dataset_loads_and_tokenizes_sequence(tokenizer_path, tmp_path):
+    """Test that the dataset loads and tokenizes a sequence correctly, with exact token verification.
+ + Uses single sequence so shuffling doesn't affect test (similar to SQLite test approach). + Pattern: expected_sequence = [nucleotide_id] * seqlen + """ + # Create a Parquet file with a single T sequence of known length + parquet_path = tmp_path / "genomic_sequences.parquet" + sequence = "T" * 10 # Small, predictable sequence + table = pa.table({"sequence": [sequence]}) + pq.write_table(table, parquet_path) + + distributed_config = DistributedConfig(rank=0, world_size=1) + + load_dataset_kwargs = { + "path": "parquet", + "data_files": str(parquet_path), + "split": "train", + } + + tokenized_dataset, tokenizer = create_tokenized_dataset( + distributed_config=distributed_config, + tokenizer_path=tokenizer_path, + load_dataset_kwargs=load_dataset_kwargs, + max_seq_length=20, # Large enough to fit the sequence + stride=10, + buffer_size=10_000, + use_lazy_tokenization=False, # Eager to get predictable dataset + ) + + # Only 1 sequence → 1 window → dataset[0] is predictable regardless of shuffle + sample = tokenized_dataset[0] + assert "input_ids" in sample + + # Get nucleotides (remove BOS and EOS) + tokens = sample["input_ids"] + nucleotides = tokens[1:-1] + + # Verify exact expected sequence + bos = 2 + eos = 0 + t = 84 # ASCII value of 'T' + + expected_sequence = [t] * 10 # All Ts + received_sequence = nucleotides + + assert tokens[0] == bos, f"First token should be BOS (2), got {tokens[0]}" + assert tokens[-1] == eos, f"Last token should be EOS (0), got {tokens[-1]}" + assert received_sequence == expected_sequence, f"Expected {expected_sequence}, got {received_sequence}" + + +def test_dataloader_returns_expected_batch(tokenizer_path, tmp_path): + """Test dataloader returns exact expected batch with known input. + + Creates minimal test data with exactly one sequence to get deterministic output. + Verifies exact token values match expected hardcoded batch. 
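+
+    Expected layout for "AAAAA" (BOS=2, A=65, EOS=0): input_ids == labels == [[2, 65, 65, 65, 65, 65, 0]].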
+ """ + # Create minimal test parquet with exactly 1 sequence + parquet_path = tmp_path / "single_sequence.parquet" + sequence = "A" * 5 # 5 As + table = pa.table({"sequence": [sequence]}) + pq.write_table(table, parquet_path) + + distributed_config = DistributedConfig(rank=0, world_size=1) + + load_dataset_kwargs = { + "path": "parquet", + "data_files": str(parquet_path), + "split": "train", + } + + dataloader, _ = create_bshd_dataloader( + distributed_config=distributed_config, + tokenizer_path=tokenizer_path, + load_dataset_kwargs=load_dataset_kwargs, + micro_batch_size=1, # Just one sample per batch + num_workers=0, + max_seq_length=10, # Large enough for 5bp sequence + stride=5, + use_lazy_tokenization=False, # Eager for deterministic behavior + ) + + returned_batch = next(iter(dataloader)) + + # Hardcode expected batch (1 sequence, deterministic output) + # seq: 5bp of As -> BOS + 5 As + EOS + bos = 2 + eos = 0 + a = 65 # ASCII value of 'A' + + expected_input_ids = torch.tensor([[bos, a, a, a, a, a, eos]], dtype=torch.long) + expected_labels = torch.tensor([[bos, a, a, a, a, a, eos]], dtype=torch.long) # CLM: labels = input_ids + expected_attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1]], dtype=torch.long) # All real tokens + + assert torch.equal(returned_batch["input_ids"], expected_input_ids), ( + f"Expected input_ids {expected_input_ids}, got {returned_batch['input_ids']}" + ) + assert torch.equal(returned_batch["labels"], expected_labels), ( + f"Expected labels {expected_labels}, got {returned_batch['labels']}" + ) + assert torch.equal(returned_batch["attention_mask"], expected_attention_mask), ( + f"Expected attention_mask {expected_attention_mask}, got {returned_batch['attention_mask']}" + ) + + +def test_attention_mask_aligns_with_labels(tokenizer_path, simple_parquet): + """Test attention_mask correctly identifies real vs padded positions in labels. 
+ + Where attention_mask=1: labels should contain real token IDs (matching input_ids) + Where attention_mask=0: labels should contain ignore_index value (-100) + """ + # HuggingFace's DataCollatorForLanguageModeling uses -100 as ignore_index by default + ignore_pad_token = -100 + + distributed_config = DistributedConfig(rank=0, world_size=1) + + load_dataset_kwargs = { + "path": "parquet", + "data_files": simple_parquet, + "split": "train", + } + + # Use a moderate window size to ensure we get padding in batches + dataloader, _ = create_bshd_dataloader( + distributed_config=distributed_config, + tokenizer_path=tokenizer_path, + load_dataset_kwargs=load_dataset_kwargs, + micro_batch_size=2, + num_workers=0, + max_seq_length=500, + stride=100, + ) + + batch = next(iter(dataloader)) + + # Check first sequence in batch + attention_mask = batch["attention_mask"][0] + labels = batch["labels"][0] + input_ids = batch["input_ids"][0] + + # Where attention_mask=1, labels should equal input_ids (real tokens) + real_positions = attention_mask == 1 + real_labels = labels[real_positions] + real_input_ids = input_ids[real_positions] + + # For CLM (Causal Language Modeling), labels should match input_ids at real positions + assert torch.all(real_labels == real_input_ids), "Labels should match input_ids at real token positions" + + # Verify specific token positions contain expected values + assert real_labels[0].item() == 2, "First token should be BOS (2)" + assert real_labels[-1].item() == 0, "Last real token should be EOS (0)" + # Middle tokens should be nucleotides (A=65, T=84, C=67, G=71) + if len(real_labels) > 2: + middle_token = real_labels[1].item() + assert middle_token in [65, 84, 67, 71], f"Nucleotide tokens should be A/T/C/G, got {middle_token}" + + # Ensure NO real position has the ignore padding value + assert torch.all(real_labels != ignore_pad_token), "Real tokens should not have IGNORE_PAD_TOKEN" + + # Where attention_mask=0, labels should be IGNORE_PAD_TOKEN (-100) + padded_positions = attention_mask == 0 + if padded_positions.any(): + padded_labels = labels[padded_positions] + assert torch.all(padded_labels == ignore_pad_token), ( + f"Padded positions should have IGNORE_PAD_TOKEN (-100), got {padded_labels.unique()}" + ) + + +def test_windowing_in_dataset_creates_multiple_samples(tokenizer_path, tmp_path): + """Test that the dataset's windowing creates expected number of samples.""" + # Create a 3kbp sequence + parquet_path = tmp_path / "genomic_sequences.parquet" + sequence = "A" * 3000 + table = pa.table({"sequence": [sequence]}) + pq.write_table(table, parquet_path) + + distributed_config = DistributedConfig(rank=0, world_size=1) + + load_dataset_kwargs = { + "path": "parquet", + "data_files": str(parquet_path), + "split": "train", + } + + tokenized_dataset, tokenizer = create_tokenized_dataset( + distributed_config=distributed_config, + tokenizer_path=tokenizer_path, + load_dataset_kwargs=load_dataset_kwargs, + max_seq_length=1000, + stride=800, # 800 token overlap, so 200 token step + buffer_size=10_000, + use_lazy_tokenization=False, # Use eager tokenization to expand windows + ) + + # Count samples + num_samples = len(tokenized_dataset) + + # With 3000bp sequence, max_length=1000, stride=800 (800 overlap, 200 step) + # Formula: ceil((3000+2 - 1000) / 200) + 1 = ceil(2002/200) + 1 = 11 + 1 = 12 windows + assert num_samples == 12, f"Expected exactly 12 windows, got {num_samples}" + + +def test_lazy_tokenization_returns_batch(tokenizer_path, simple_parquet): + """Test that lazy 
tokenization works and returns valid batches.""" + distributed_config = DistributedConfig(rank=0, world_size=1) + + load_dataset_kwargs = { + "path": "parquet", + "data_files": simple_parquet, + "split": "train", + "streaming": False, + } + + dataloader, _ = create_bshd_dataloader( + distributed_config=distributed_config, + tokenizer_path=tokenizer_path, + load_dataset_kwargs=load_dataset_kwargs, + micro_batch_size=2, + num_workers=0, + max_seq_length=500, + stride=100, + use_lazy_tokenization=True, + ) + + # Get a batch + batch = next(iter(dataloader)) + + # Verify batch is not None and has correct structure + assert batch is not None + assert "input_ids" in batch + assert "attention_mask" in batch + assert "labels" in batch + assert isinstance(batch["input_ids"], torch.Tensor) + + # With lazy tokenization and windowing, batch size can vary due to on-the-fly window expansion + # Just verify we get at least one sample (lazy tokenization + windowing makes exact count unpredictable) + assert batch["input_ids"].shape[0] >= 1, f"Expected at least 1 sample in batch, got {batch['input_ids'].shape[0]}" + + +@pytest.mark.parametrize("streaming", [False, True]) +def test_multiple_sequences_batch_correctly(tokenizer_path, simple_parquet, streaming): + """Test that multiple sequences batch together correctly in both streaming and non-streaming modes. + + This test catches bugs that only appear with multi-row datasets vs single-row: + - Batching/collation works with multiple sequences + - Sequences in batch are different (not duplicated) + - Padding aligns correctly across multiple sequences + - All sequences are processed across batches + - Works in both streaming=True and streaming=False modes + """ + distributed_config = DistributedConfig(rank=0, world_size=1) + + load_dataset_kwargs = { + "path": "parquet", + "data_files": simple_parquet, + "split": "train", + "streaming": streaming, + } + + dataloader, _ = create_bshd_dataloader( + distributed_config=distributed_config, + tokenizer_path=tokenizer_path, + load_dataset_kwargs=load_dataset_kwargs, + micro_batch_size=2, + num_workers=0, + max_seq_length=500, + stride=100, + buffer_size=10_000, # Only used for streaming + use_lazy_tokenization=False, + ) + + # Get first batch + batch = next(iter(dataloader)) + + # KEY TEST 1: Verify batch contains MULTIPLE sequences (not just 1) + assert batch["input_ids"].shape[0] == 2, f"Batch should contain 2 sequences, got {batch['input_ids'].shape[0]}" + + # KEY TEST 2: Verify sequences in batch are DIFFERENT (catch duplication bugs) + seq1 = batch["input_ids"][0] + seq2 = batch["input_ids"][1] + assert not torch.equal(seq1, seq2), "Sequences in batch should be different, not duplicates" + + # KEY TEST 3: Verify padding aligns across all tensors in batch + batch_size, seq_length = batch["input_ids"].shape + assert batch["attention_mask"].shape == (batch_size, seq_length) + assert batch["labels"].shape == (batch_size, seq_length) + + # KEY TEST 4: Verify all sequences are processed (multiple batches produced) + # With 5 sequences from simple_parquet (800-1500bp) and max_seq_length=500, + # windowing will create ~11+ windows total. With batch_size=2, expect ~5-6 batches. + # We already consumed 1 batch, so should have at least 4 remaining batches. 
+ all_batches = list(dataloader) + total_batches = len(all_batches) + 1 # +1 for first batch already consumed + assert len(all_batches) >= 4, ( + f"Expected at least 4 remaining batches (5 total), got {len(all_batches)} remaining ({total_batches} total)" + ) + + # KEY TEST 5: Verify subsequent batches also valid (not just first batch) + if len(all_batches) > 0: + second_batch = all_batches[0] + # Check structure is consistent across batches + assert "input_ids" in second_batch + assert "attention_mask" in second_batch + assert "labels" in second_batch + # Verify it also has multiple sequences (could be different count due to windowing) + assert second_batch["input_ids"].shape[0] >= 1, ( + f"Second batch should have at least 1 sequence, got {second_batch['input_ids'].shape[0]}" + ) + # Verify tensors align + batch_size_2, seq_length_2 = second_batch["input_ids"].shape + assert second_batch["attention_mask"].shape == (batch_size_2, seq_length_2) + assert second_batch["labels"].shape == (batch_size_2, seq_length_2) + + +def test_batching_produces_correct_batch_size(tokenizer_path, tmp_path): + """Test that batching combines multiple sequences correctly with exact batch counts. + + Creates 5 short sequences (no windowing) with micro_batch_size=2. + Should produce exactly 3 batches with shapes: [2, 2, 1]. + """ + # Create 5 sequences that won't trigger windowing (all very short) + parquet_path = tmp_path / "five_sequences.parquet" + sequences = [ + "A" * 10, # Seq 1 + "T" * 15, # Seq 2 + "C" * 12, # Seq 3 + "G" * 8, # Seq 4 + "ATCG" * 3, # Seq 5 (12bp) + ] + table = pa.table({"sequence": sequences}) + pq.write_table(table, parquet_path) + + distributed_config = DistributedConfig(rank=0, world_size=1) + + load_dataset_kwargs = { + "path": "parquet", + "data_files": str(parquet_path), + "split": "train", + } + + dataloader, _ = create_bshd_dataloader( + distributed_config=distributed_config, + tokenizer_path=tokenizer_path, + load_dataset_kwargs=load_dataset_kwargs, + micro_batch_size=2, + num_workers=0, + max_seq_length=50, # Large enough - no windowing + stride=10, + use_lazy_tokenization=False, # Use eager to ensure predictable batching + ) + + # Collect all batches + batches = list(dataloader) + + # With 5 sequences and batch_size=2, expect exactly 3 batches: [2, 2, 1] + assert len(batches) == 3, f"Expected exactly 3 batches from 5 sequences, got {len(batches)}" + + # Check each batch has correct shape + assert batches[0]["input_ids"].shape[0] == 2, "Batch 0 should have 2 sequences" + assert batches[1]["input_ids"].shape[0] == 2, "Batch 1 should have 2 sequences" + assert batches[2]["input_ids"].shape[0] == 1, "Batch 2 should have 1 sequence (remainder)"
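+
+
+# Reference for the hardcoded window counts asserted above (a sketch of the windowing math,
+# assuming `stride` is the token overlap used by return_overflowing_tokens):
+#   tokens  = sequence_length + 2                      # BOS and EOS
+#   step    = max_length - stride
+#   windows = ceil((tokens - max_length) / step) + 1   # when tokens > max_length, else 1 window
+# e.g. 3000 bp -> 3002 tokens; max_length=1000, stride=800 -> step=200;
+#      ceil(2002 / 200) + 1 = 12, matching test_windowing_in_dataset_creates_multiple_samples.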