59 changes: 35 additions & 24 deletions tests/unit_tests/data/test_bin_reader.py
@@ -1,3 +1,5 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

import os
import random
import sys
@@ -6,16 +8,18 @@
from types import ModuleType, SimpleNamespace
from typing import Any, Dict

import nltk
import pytest

try:
import boto3
import botocore.exceptions as exceptions
except ModuleNotFoundError:
# Create mock boto3 module
boto3 = ModuleType("boto3")
sys.modules[boto3.__name__] = boto3

# Create mock types submodule
exceptions = ModuleType("botocore.exceptions")

# Register the mock module in sys.modules
sys.modules[boto3.__name__] = boto3
sys.modules[exceptions.__name__] = exceptions
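# (Registering the stubs in sys.modules means any later import of boto3 or
# botocore.exceptions resolves to these placeholder modules.)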

try:
@@ -43,6 +47,8 @@ def __init__(self, offset: int, size: int):
sys.modules[msc.__name__] = msc
sys.modules[types_module.__name__] = types_module

import torch

from megatron.core.datasets.indexed_dataset import (
IndexedDataset,
ObjectStorageConfig,
@@ -58,9 +64,11 @@ def __init__(self, offset: int, size: int):
gpt2_merge,
gpt2_vocab,
)
from tests.unit_tests.test_utilities import Utils


##
# Overload client from boto3
# Mock boto3
##


@@ -72,7 +80,8 @@ def __init__(self, *args: Any) -> None:

def download_file(self, Bucket: str, Key: str, Filename: str) -> None:
os.makedirs(os.path.dirname(Filename), exist_ok=True)
os.system(f"cp {os.path.join('/', Bucket, Key)} {Filename}")
remote_path = os.path.join("/", Bucket, Key)
os.system(f"cp {remote_path} {Filename}")
assert os.path.exists(Filename)

def upload_file(self, Filename: str, Bucket: str, Key: str) -> None:
@@ -104,27 +113,28 @@ def close(self) -> None:


##
# Overload ClientError from botocore.exceptions
# Mock botocore.exceptions
##


class _LocalClientError(Exception):
""" "Local test client error"""
"""Local test client error"""

pass


setattr(exceptions, "ClientError", _LocalClientError)

##
# Mock multistorageclient module
# Mock msc.open, msc.download_file, msc.resolve_storage_client
##


def _msc_download_file(remote_path, local_path):
remote_path = remote_path.removeprefix(MSC_PREFIX + "default")
remote_path = os.path.join("/", remote_path.removeprefix(MSC_PREFIX))
os.makedirs(os.path.dirname(local_path), exist_ok=True)
os.system(f"cp {remote_path} {local_path}")
assert os.path.exists(local_path)
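# Note: both msc mocks re-root the stripped msc:// path at "/", so the
# "remote" object is just a file on the local disk.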


def _msc_resolve_storage_client(path):
@@ -134,28 +144,29 @@ def read(self, path, byte_range):
f.seek(byte_range.offset)
return f.read(byte_range.size)

return StorageClient(), path.removeprefix(MSC_PREFIX + "default")
return StorageClient(), os.path.join("/", path.removeprefix(MSC_PREFIX))


setattr(msc, "open", open)
setattr(msc, "download_file", _msc_download_file)
setattr(msc, "resolve_storage_client", _msc_resolve_storage_client)


@pytest.mark.flaky
@pytest.mark.flaky_in_dev
def test_bin_reader():
with tempfile.TemporaryDirectory() as temp_dir:
# set the default nltk data path
os.environ["NLTK_DATA"] = os.path.join(temp_dir, "nltk_data")
nltk.data.path.append(os.environ["NLTK_DATA"])
if torch.distributed.is_available():
Utils.initialize_distributed()
if torch.distributed.get_rank() != 0:
return

with tempfile.TemporaryDirectory() as temp_dir:
path_to_raws = os.path.join(temp_dir, "sample_raws")
path_to_data = os.path.join(temp_dir, "sample_data")
path_to_object_storage_cache = os.path.join(temp_dir, "object_storage_cache")
path_to_object_storage_cache_msc = os.path.join(temp_dir, "object_storage_cache_msc")
path_to_object_storage_cache_s3 = os.path.join(temp_dir, "object_storage_cache_s3")
os.mkdir(path_to_raws)
os.mkdir(path_to_data)
os.mkdir(path_to_object_storage_cache)
os.mkdir(path_to_object_storage_cache_msc)
os.mkdir(path_to_object_storage_cache_s3)
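# One idx cache per backend keeps the MSC and S3 readers from sharing
# downloaded .idx files.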

# create the dummy resources
dummy_jsonl(path_to_raws)
@@ -195,27 +206,26 @@ def test_bin_reader():
assert isinstance(indexed_dataset_mmap.bin_reader, _MMapBinReader)

indexed_dataset_msc = IndexedDataset(
MSC_PREFIX + "default" + prefix, # use the default profile to access the filesystem
MSC_PREFIX + prefix.lstrip("/"),
multimodal=False,
mmap=False,
object_storage_config=ObjectStorageConfig(
path_to_idx_cache=path_to_object_storage_cache
path_to_idx_cache=path_to_object_storage_cache_msc
),
)
assert isinstance(indexed_dataset_msc.bin_reader, _MultiStorageClientBinReader)
assert len(indexed_dataset_msc) == len(indexed_dataset_file)
assert len(indexed_dataset_msc) == len(indexed_dataset_mmap)

indexed_dataset_s3 = IndexedDataset(
S3_PREFIX + prefix,
S3_PREFIX + prefix.lstrip("/"),
multimodal=False,
mmap=False,
object_storage_config=ObjectStorageConfig(
path_to_idx_cache=path_to_object_storage_cache
path_to_idx_cache=path_to_object_storage_cache_s3
),
)
assert isinstance(indexed_dataset_s3.bin_reader, _S3BinReader)

assert len(indexed_dataset_s3) == len(indexed_dataset_file)
assert len(indexed_dataset_s3) == len(indexed_dataset_mmap)

@@ -226,6 +236,7 @@ def test_bin_reader():
for idx in indices:
assert (indexed_dataset_s3[idx] == indexed_dataset_file[idx]).all()
assert (indexed_dataset_s3[idx] == indexed_dataset_mmap[idx]).all()
assert (indexed_dataset_s3[idx] == indexed_dataset_msc[idx]).all()


if __name__ == "__main__":
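Aside on the MSC path change above: the dataset path no longer routes through the "default" profile; it is now MSC_PREFIX plus the test prefix with its leading slash stripped. A minimal round-trip sketch under the mocks in this file, assuming MSC_PREFIX is the msc:// scheme string (the sample prefix is illustrative):

import os

MSC_PREFIX = "msc://"  # assumption: the value of the constant the test imports

prefix = "/tmp/sample_data/dataset"          # a local prefix like the test builds
msc_path = MSC_PREFIX + prefix.lstrip("/")   # "msc://tmp/sample_data/dataset"

# The mocked resolve_storage_client re-roots the stripped path at "/":
local_path = os.path.join("/", msc_path.removeprefix(MSC_PREFIX))
assert local_path == prefix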
68 changes: 36 additions & 32 deletions tests/unit_tests/data/test_gpt_dataset.py
@@ -34,8 +34,10 @@ def test_mock_gpt_dataset():
if torch.distributed.get_rank() == 0:
compile_helpers()
torch.distributed.barrier()
build_on_rank = lambda: torch.distributed.get_rank() == 0
else:
compile_helpers()
build_on_rank = lambda: True
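# build_on_rank: with torch.distributed initialized, only rank 0 builds the
# datasets; in the single-process path every caller does.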

tokenizer = _NullTokenizer(vocab_size=_MOCK_VOCAB_SIZE)

@@ -51,26 +53,27 @@
)

datasets = BlendedMegatronDatasetBuilder(
MockGPTDataset, [100, 100, 100], lambda: True, config
MockGPTDataset, [100, 100, 100], build_on_rank, config
).build()

N = 10
if build_on_rank():
N = 10

# Check iso-index variance by split
subsets = [sample_N(dataset, N, randomize=False) for dataset in datasets]
assert not numpy.allclose(subsets[0], subsets[1])
assert not numpy.allclose(subsets[0], subsets[2])
assert not numpy.allclose(subsets[1], subsets[2])
# Check iso-index variance by split
subsets = [sample_N(dataset, N, randomize=False) for dataset in datasets]
assert not numpy.allclose(subsets[0], subsets[1])
assert not numpy.allclose(subsets[0], subsets[2])
assert not numpy.allclose(subsets[1], subsets[2])

# Check iso-split / iso-index identity
subset_1A = sample_N(datasets[0], N, randomize=False)
subset_1B = sample_N(datasets[0], N, randomize=False)
assert numpy.allclose(subset_1A, subset_1B)
# Check iso-split / iso-index identity
subset_1A = sample_N(datasets[0], N, randomize=False)
subset_1B = sample_N(datasets[0], N, randomize=False)
assert numpy.allclose(subset_1A, subset_1B)

# Check iso-split variance by index
subset_1A = sample_N(datasets[0], N, randomize=True)
subset_1B = sample_N(datasets[0], N, randomize=True)
assert not numpy.allclose(subset_1A, subset_1B)
# Check iso-split variance by index
subset_1A = sample_N(datasets[0], N, randomize=True)
subset_1B = sample_N(datasets[0], N, randomize=True)
assert not numpy.allclose(subset_1A, subset_1B)

config = GPTDatasetConfig(
random_seed=1234,
@@ -86,29 +89,30 @@
)

datasets = BlendedMegatronDatasetBuilder(
MockGPTDataset, [0, None, 0], lambda: True, config
MockGPTDataset, [0, None, 0], build_on_rank, config
).build()

sample = datasets[1][datasets[1].shuffle_index.argmax()]
argmax = sample['labels'].shape[0] - torch.flip(sample['labels'], [0]).argmax() - 1
if build_on_rank():
sample = datasets[1][datasets[1].shuffle_index.argmax()]
argmax = sample['labels'].shape[0] - torch.flip(sample['labels'], [0]).argmax() - 1

# Test add_extra_token_to_sequence
assert sample['tokens'][argmax] != tokenizer.eod
assert sample['labels'][argmax] == tokenizer.eod
# Test add_extra_token_to_sequence
assert sample['tokens'][argmax] != tokenizer.eod
assert sample['labels'][argmax] == tokenizer.eod

# Test eod_mask_loss, drop_last_partial_validation_sequence
assert argmax < sample['labels'].shape[0] - 1
assert torch.all(sample['labels'][argmax + 1 :] == 0)
assert not torch.any(
sample['loss_mask'][
torch.logical_and(sample['labels'] == tokenizer.eod, sample['labels'] == 0)
]
)
# Test eod_mask_loss, drop_last_partial_validation_sequence
assert argmax < sample['labels'].shape[0] - 1
assert torch.all(sample['labels'][argmax + 1 :] == 0)
assert not torch.any(
sample['loss_mask'][
torch.logical_and(sample['labels'] == tokenizer.eod, sample['labels'] == 0)
]
)

sample = datasets[1][None]
sample = datasets[1][None]

# Check handling of None index
assert not torch.any(sample['loss_mask'])
# Check handling of None index
assert not torch.any(sample['loss_mask'])


if __name__ == "__main__":
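For context on the file above: the refactor threads a build_on_rank callable into BlendedMegatronDatasetBuilder and gates every sampling block and assertion on it, so that under torch.distributed only rank 0 builds and validates the datasets. A minimal standalone sketch of the pattern (the is_initialized guard is added here so the snippet runs outside the test harness):

import torch


def make_build_on_rank():
    # Distributed runs build only on rank 0; single-process runs build everywhere,
    # mirroring the two branches in test_mock_gpt_dataset.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return lambda: torch.distributed.get_rank() == 0
    return lambda: True


build_on_rank = make_build_on_rank()

# The same callable is handed to the builder and reused to gate the checks,
# so non-building ranks skip sampling entirely.
if build_on_rank():
    print("this rank builds and validates the datasets")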