25 changes: 24 additions & 1 deletion bionemo-recipes/recipes/llama3_native_te/dataset.py
@@ -83,10 +83,33 @@ def tokenize_with_windowing(examples):
         # Using dataset.map on a non-streaming dataset will automatically perform and cache the transform
         tokenized_dataset = dataset.with_transform(tokenize_with_windowing)
     else:
+        # WORKAROUND for OpenGenome2's inconsistent schema:
+        # OpenGenome2 has inconsistent schemas across shards - some have a 'record' column, some don't.
+        # This causes dataset.column_names to be None for the streaming IterableDataset.
+        #
+        # For an IterableDataset with None column_names (OpenGenome2):
+        # - Must explicitly list columns to remove: [sequence_column, "record"]
+        # - IterableDataset.map() handles missing columns gracefully
+        #
+        # For a regular Dataset (non-streaming, or streaming with a consistent schema like ESM2):
+        # - Use dataset.column_names (which is available and accurate)
+        # - Dataset.map() raises an error if a column doesn't exist
+        #
+        # TODO: Remove this workaround once Arc Institute fixes OpenGenome2 schema consistency.
+        # When all shards have the same columns, dataset.column_names will work for both cases.
+        if isinstance(dataset, datasets.IterableDataset):
+            # Streaming dataset: column_names may be None due to the inconsistent schema
+            columns_to_remove = [sequence_column, "record"]
+        else:
+            # Non-streaming dataset: use the actual column names
+            columns_to_remove = dataset.column_names
+
+        logger.info(f"Applying dataset.map with columns to remove: {columns_to_remove}")
+
         tokenized_dataset = dataset.map(
             tokenize_with_windowing,
             batched=True,
-            remove_columns=dataset.column_names,
+            remove_columns=columns_to_remove,
         )
 
     return tokenized_dataset, tokenizer
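For orientation, below is a minimal standalone sketch (not part of the PR) of how the new branch behaves with the Hugging Face `datasets` API. The parquet glob, the `text`/`record` column names, and the stand-in `tokenize_with_windowing` are placeholders, and the assumption that `IterableDataset.map()` tolerates `remove_columns` entries absent from a given shard follows the comment in the diff rather than a documented library guarantee.

```python
import datasets

# Hypothetical streaming load; the parquet glob is a placeholder for the
# OpenGenome2-style shards described in the comment above.
ds = datasets.load_dataset(
    "parquet",
    data_files="opengenome2_shard_*.parquet",
    split="train",
    streaming=True,
)

# The PR reports that mixed shard schemas leave this as None for OpenGenome2;
# datasets with a consistent schema typically report a list of column names.
print(ds.column_names)


def tokenize_with_windowing(batch):
    # Stand-in for the recipe's real windowed tokenization; it only needs to
    # return tokenizer-style columns for this sketch.
    return {"input_ids": [[0, 1, 2, 3] for _ in batch["text"]]}


# Mirror the workaround from the diff: an explicit column list when streaming,
# dataset.column_names otherwise.
if isinstance(ds, datasets.IterableDataset):
    columns_to_remove = ["text", "record"]  # "record" may be absent in some shards
else:
    columns_to_remove = ds.column_names

tokenized = ds.map(tokenize_with_windowing, batched=True, remove_columns=columns_to_remove)
print(next(iter(tokenized)).keys())  # expected: tokenizer outputs only
```

The isinstance check keeps the non-streaming path identical to the old behavior, where Dataset.map still validates that every column listed in remove_columns actually exists.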
129 changes: 129 additions & 0 deletions bionemo-recipes/recipes/llama3_native_te/tests/test_dataset.py
@@ -412,3 +412,132 @@ def test_batching_produces_correct_batch_size(tokenizer_path, tmp_path):
     assert batches[0]["input_ids"].shape[0] == 2, "Batch 0 should have 2 sequences"
     assert batches[1]["input_ids"].shape[0] == 2, "Batch 1 should have 2 sequences"
     assert batches[2]["input_ids"].shape[0] == 1, "Batch 2 should have 1 sequence (remainder)"
+
+
+def test_streaming_dataset_removes_columns_correctly(tokenizer_path, tmp_path):
+    """Test that streaming datasets properly remove input columns (text, record) during tokenization.
+
+    This is a regression test for the OpenGenome2-specific bug where dataset.column_names is None
+    for streaming datasets with inconsistent schemas across shards. Passing that None to
+    remove_columns removes nothing, leaving raw text/record columns in the tokenized dataset.
+
+    OpenGenome2 has inconsistent schemas:
+    - Some shards: ["text", "record"]
+    - Some shards: ["text"] only
+    - Result: dataset.column_names = None (can't determine upfront)
+
+    Note: Regular datasets (like ESM2) don't have this issue because they have consistent schemas.
+
+    Reference: https://github.com/NVIDIA/bionemo-framework/commit/3c0aee6de065ef494389591ca9028e8301dc385a
+    """
+    # Create a Parquet file with both 'text' (sequence column) and 'record' (metadata)
+    parquet_path = tmp_path / "genomic_with_metadata.parquet"
+    sequences = ["ATCGATCG" * 10, "GCTAGCTA" * 10]
+    records = ["chr1:1000-1080", "chr2:2000-2080"]
+
+    table = pa.table(
+        {
+            "text": sequences,  # Using 'text' to match OpenGenome2 format
+            "record": records,  # Metadata column that should be removed
+        }
+    )
+    pq.write_table(table, parquet_path)
+
+    distributed_config = DistributedConfig(rank=0, world_size=1)
+
+    # Load as streaming dataset
+    load_dataset_kwargs = {
+        "path": "parquet",
+        "data_files": str(parquet_path),
+        "split": "train",
+        "streaming": True,  # This makes dataset.column_names = None
+    }
+
+    tokenized_dataset, tokenizer = create_tokenized_dataset(
+        distributed_config=distributed_config,
+        tokenizer_path=tokenizer_path,
+        load_dataset_kwargs=load_dataset_kwargs,
+        max_seq_length=100,
+        stride=10,
+        buffer_size=1000,
+        use_lazy_tokenization=False,
+        sequence_column="text",  # Specify which column has sequences
+    )
+
+    # Get first sample from streaming dataset
+    sample = next(iter(tokenized_dataset))
+
+    # Verify that only tokenizer outputs remain (no raw text or record columns)
+    expected_keys = {"input_ids", "attention_mask", "token_type_ids", "overflow_to_sample_mapping"}
+    actual_keys = set(sample.keys())
+
+    assert actual_keys.issubset(expected_keys), (
+        f"Unexpected columns found in tokenized dataset. "
+        f"Expected only {expected_keys}, but got {actual_keys}. "
+        f"Columns 'text' and 'record' should have been removed."
+    )
+
+    # Specifically check that problematic columns are NOT present
+    assert "text" not in sample, "Column 'text' should have been removed during tokenization"
+    assert "record" not in sample, "Column 'record' should have been removed during tokenization"
+
+    # Verify tokenizer outputs are present and valid
+    assert "input_ids" in sample, "input_ids should be present"
+    assert isinstance(sample["input_ids"], list), "input_ids should be a list"
+    assert len(sample["input_ids"]) > 0, "input_ids should not be empty"
+
+
+def test_streaming_dataset_handles_missing_record_column(tokenizer_path, tmp_path):
+    """Test that remove_columns handles a missing 'record' column gracefully (OpenGenome2 workaround).
+
+    OpenGenome2 has inconsistent schemas across shards:
+    - Some shards have a 'record' column (metadata)
+    - Some shards don't have a 'record' column
+
+    This test verifies that explicitly listing 'record' in columns_to_remove doesn't
+    cause errors when the column is absent. This is part of the OpenGenome2 workaround.
+
+    TODO: Remove this workaround once Arc Institute fixes OpenGenome2 schema consistency.
+
+    Reference: https://github.com/NVIDIA/bionemo-framework/commit/a41f306eda7605552ee736e3291c098f2623828a
+    """
+    # Create a Parquet file with ONLY 'text' column (no 'record')
+    parquet_path = tmp_path / "genomic_no_record.parquet"
+    sequences = ["ATCGATCG" * 10, "GCTAGCTA" * 10]
+
+    table = pa.table(
+        {
+            "text": sequences,  # Only text, no record column
+        }
+    )
+    pq.write_table(table, parquet_path)
+
+    distributed_config = DistributedConfig(rank=0, world_size=1)
+
+    # Load as streaming dataset
+    load_dataset_kwargs = {
+        "path": "parquet",
+        "data_files": str(parquet_path),
+        "split": "train",
+        "streaming": True,
+    }
+
+    # This should NOT raise an error even though 'record' is in columns_to_remove
+    tokenized_dataset, tokenizer = create_tokenized_dataset(
+        distributed_config=distributed_config,
+        tokenizer_path=tokenizer_path,
+        load_dataset_kwargs=load_dataset_kwargs,
+        max_seq_length=100,
+        stride=10,
+        buffer_size=1000,
+        use_lazy_tokenization=False,
+        sequence_column="text",
+    )
+
+    # Get first sample - should work without errors
+    sample = next(iter(tokenized_dataset))
+
+    # Verify only tokenizer outputs are present
+    assert "text" not in sample, "Column 'text' should have been removed"
+    assert "record" not in sample, "Column 'record' was never present, so shouldn't be in output"
+    assert "input_ids" in sample, "input_ids should be present"