Commit ea42e5f

Attempt to fix, standardize, and stabilize the data unit tests

1 parent 23a1dca

7 files changed: +142, -131 lines


tests/unit_tests/data/test_bin_reader.py

Lines changed: 35 additions & 24 deletions
@@ -1,3 +1,5 @@
+# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+
 import os
 import random
 import sys
@@ -6,16 +8,18 @@
 from types import ModuleType, SimpleNamespace
 from typing import Any, Dict
 
-import nltk
-import pytest
-
 try:
     import boto3
     import botocore.exceptions as exceptions
 except ModuleNotFoundError:
+    # Create mock boto3 module
     boto3 = ModuleType("boto3")
-    sys.modules[boto3.__name__] = boto3
+
+    # Create mock exceptions submodule
     exceptions = ModuleType("botocore.exceptions")
+
+    # Register the mock modules in sys.modules
+    sys.modules[boto3.__name__] = boto3
     sys.modules[exceptions.__name__] = exceptions
 
 try:
@@ -43,6 +47,8 @@ def __init__(self, offset: int, size: int):
     sys.modules[msc.__name__] = msc
     sys.modules[types_module.__name__] = types_module
 
+import torch
+
 from megatron.core.datasets.indexed_dataset import (
     IndexedDataset,
     ObjectStorageConfig,
@@ -58,9 +64,11 @@ def __init__(self, offset: int, size: int):
     gpt2_merge,
     gpt2_vocab,
 )
+from tests.unit_tests.test_utilities import Utils
+
 
 ##
-# Overload client from boto3
+# Mock boto3
 ##
 
 
@@ -72,7 +80,8 @@ def __init__(self, *args: Any) -> None:
 
     def download_file(self, Bucket: str, Key: str, Filename: str) -> None:
         os.makedirs(os.path.dirname(Filename), exist_ok=True)
-        os.system(f"cp {os.path.join('/', Bucket, Key)} {Filename}")
+        remote_path = os.path.join("/", Bucket, Key)
+        os.system(f"cp {remote_path} {Filename}")
         assert os.path.exists(Filename)
 
     def upload_file(self, Filename: str, Bucket: str, Key: str) -> None:
@@ -104,27 +113,28 @@ def close(self) -> None:
 
 
 ##
-# Overload ClientError from botocore.exceptions
+# Mock botocore.exceptions
 ##
 
 
 class _LocalClientError(Exception):
-    """ "Local test client error"""
+    """Local test client error"""
 
     pass
 
 
 setattr(exceptions, "ClientError", _LocalClientError)
 
 ##
-# Mock multistorageclient module
+# Mock msc.open, msc.download_file, msc.resolve_storage_client
 ##
 
 
 def _msc_download_file(remote_path, local_path):
-    remote_path = remote_path.removeprefix(MSC_PREFIX + "default")
+    remote_path = os.path.join("/", remote_path.removeprefix(MSC_PREFIX))
     os.makedirs(os.path.dirname(local_path), exist_ok=True)
     os.system(f"cp {remote_path} {local_path}")
+    assert os.path.exists(local_path)
 
 
 def _msc_resolve_storage_client(path):
@@ -134,28 +144,29 @@ def read(self, path, byte_range):
                 f.seek(byte_range.offset)
                 return f.read(byte_range.size)
 
-    return StorageClient(), path.removeprefix(MSC_PREFIX + "default")
+    return StorageClient(), os.path.join("/", path.removeprefix(MSC_PREFIX))
 
 
 setattr(msc, "open", open)
 setattr(msc, "download_file", _msc_download_file)
 setattr(msc, "resolve_storage_client", _msc_resolve_storage_client)
 
 
-@pytest.mark.flaky
-@pytest.mark.flaky_in_dev
 def test_bin_reader():
-    with tempfile.TemporaryDirectory() as temp_dir:
-        # set the default nltk data path
-        os.environ["NLTK_DATA"] = os.path.join(temp_dir, "nltk_data")
-        nltk.data.path.append(os.environ["NLTK_DATA"])
+    if torch.distributed.is_available():
+        Utils.initialize_distributed()
+        if torch.distributed.get_rank() != 0:
+            return
 
+    with tempfile.TemporaryDirectory() as temp_dir:
         path_to_raws = os.path.join(temp_dir, "sample_raws")
         path_to_data = os.path.join(temp_dir, "sample_data")
-        path_to_object_storage_cache = os.path.join(temp_dir, "object_storage_cache")
+        path_to_object_storage_cache_msc = os.path.join(temp_dir, "object_storage_cache_msc")
+        path_to_object_storage_cache_s3 = os.path.join(temp_dir, "object_storage_cache_s3")
        os.mkdir(path_to_raws)
         os.mkdir(path_to_data)
-        os.mkdir(path_to_object_storage_cache)
+        os.mkdir(path_to_object_storage_cache_msc)
+        os.mkdir(path_to_object_storage_cache_s3)
 
         # create the dummy resources
         dummy_jsonl(path_to_raws)
@@ -195,27 +206,26 @@ def test_bin_reader():
         assert isinstance(indexed_dataset_mmap.bin_reader, _MMapBinReader)
 
         indexed_dataset_msc = IndexedDataset(
-            MSC_PREFIX + "default" + prefix,  # use the default profile to access the filesystem
+            MSC_PREFIX + prefix.lstrip("/"),
             multimodal=False,
             mmap=False,
             object_storage_config=ObjectStorageConfig(
-                path_to_idx_cache=path_to_object_storage_cache
+                path_to_idx_cache=path_to_object_storage_cache_msc
             ),
         )
         assert isinstance(indexed_dataset_msc.bin_reader, _MultiStorageClientBinReader)
         assert len(indexed_dataset_msc) == len(indexed_dataset_file)
         assert len(indexed_dataset_msc) == len(indexed_dataset_mmap)
 
         indexed_dataset_s3 = IndexedDataset(
-            S3_PREFIX + prefix,
+            S3_PREFIX + prefix.lstrip("/"),
             multimodal=False,
             mmap=False,
             object_storage_config=ObjectStorageConfig(
-                path_to_idx_cache=path_to_object_storage_cache
+                path_to_idx_cache=path_to_object_storage_cache_s3
             ),
         )
         assert isinstance(indexed_dataset_s3.bin_reader, _S3BinReader)
-
         assert len(indexed_dataset_s3) == len(indexed_dataset_file)
         assert len(indexed_dataset_s3) == len(indexed_dataset_mmap)
 
@@ -226,6 +236,7 @@ def test_bin_reader():
         for idx in indices:
             assert (indexed_dataset_s3[idx] == indexed_dataset_file[idx]).all()
             assert (indexed_dataset_s3[idx] == indexed_dataset_mmap[idx]).all()
+            assert (indexed_dataset_s3[idx] == indexed_dataset_msc[idx]).all()
 
 
 if __name__ == "__main__":
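
Aside: the recurring pattern in this file is to register stand-in ModuleType objects in sys.modules so that imports of optional dependencies (boto3, botocore.exceptions, multistorageclient) resolve to local mocks when the real packages are absent. A minimal, self-contained sketch of that technique, using a hypothetical optional dependency named fancylib rather than any module from this commit:

import sys
from types import ModuleType

try:
    import fancylib  # hypothetical optional dependency, not part of this commit
except ModuleNotFoundError:
    # Create the stand-in module object
    fancylib = ModuleType("fancylib")

    # Attach only the attributes the code under test actually uses
    def _fake_download_file(remote_path: str, local_path: str) -> None:
        raise NotImplementedError("mocked out in tests")

    setattr(fancylib, "download_file", _fake_download_file)

    # Register the stand-in so later `import fancylib` statements resolve to it
    sys.modules[fancylib.__name__] = fancylib

import fancylib  # now guaranteed to succeed, real or mocked

Because Python caches modules in sys.modules, every later import statement anywhere in the test process picks up the stub, which is why the registration in the diff happens at module scope, before the megatron imports.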

tests/unit_tests/data/test_gpt_dataset.py

Lines changed: 36 additions & 32 deletions
@@ -34,8 +34,10 @@ def test_mock_gpt_dataset():
         if torch.distributed.get_rank() == 0:
             compile_helpers()
         torch.distributed.barrier()
+        build_on_rank = lambda: torch.distributed.get_rank() == 0
     else:
         compile_helpers()
+        build_on_rank = lambda: True
 
     tokenizer = _NullTokenizer(vocab_size=_MOCK_VOCAB_SIZE)
 
@@ -51,26 +53,27 @@ def test_mock_gpt_dataset():
     )
 
     datasets = BlendedMegatronDatasetBuilder(
-        MockGPTDataset, [100, 100, 100], lambda: True, config
+        MockGPTDataset, [100, 100, 100], build_on_rank, config
     ).build()
 
-    N = 10
+    if build_on_rank():
+        N = 10
 
-    # Check iso-index variance by split
-    subsets = [sample_N(dataset, N, randomize=False) for dataset in datasets]
-    assert not numpy.allclose(subsets[0], subsets[1])
-    assert not numpy.allclose(subsets[0], subsets[2])
-    assert not numpy.allclose(subsets[1], subsets[2])
+        # Check iso-index variance by split
+        subsets = [sample_N(dataset, N, randomize=False) for dataset in datasets]
+        assert not numpy.allclose(subsets[0], subsets[1])
+        assert not numpy.allclose(subsets[0], subsets[2])
+        assert not numpy.allclose(subsets[1], subsets[2])
 
-    # Check iso-split / iso-index identity
-    subset_1A = sample_N(datasets[0], N, randomize=False)
-    subset_1B = sample_N(datasets[0], N, randomize=False)
-    assert numpy.allclose(subset_1A, subset_1B)
+        # Check iso-split / iso-index identity
+        subset_1A = sample_N(datasets[0], N, randomize=False)
+        subset_1B = sample_N(datasets[0], N, randomize=False)
+        assert numpy.allclose(subset_1A, subset_1B)
 
-    # Check iso-split variance by index
-    subset_1A = sample_N(datasets[0], N, randomize=True)
-    subset_1B = sample_N(datasets[0], N, randomize=True)
-    assert not numpy.allclose(subset_1A, subset_1B)
+        # Check iso-split variance by index
+        subset_1A = sample_N(datasets[0], N, randomize=True)
+        subset_1B = sample_N(datasets[0], N, randomize=True)
+        assert not numpy.allclose(subset_1A, subset_1B)
 
 
 config = GPTDatasetConfig(
@@ -86,29 +89,30 @@
     )
 
     datasets = BlendedMegatronDatasetBuilder(
-        MockGPTDataset, [0, None, 0], lambda: True, config
+        MockGPTDataset, [0, None, 0], build_on_rank, config
     ).build()
 
-    sample = datasets[1][datasets[1].shuffle_index.argmax()]
-    argmax = sample['labels'].shape[0] - torch.flip(sample['labels'], [0]).argmax() - 1
+    if build_on_rank():
+        sample = datasets[1][datasets[1].shuffle_index.argmax()]
+        argmax = sample['labels'].shape[0] - torch.flip(sample['labels'], [0]).argmax() - 1
 
-    # Test add_extra_token_to_sequence
-    assert sample['tokens'][argmax] != tokenizer.eod
-    assert sample['labels'][argmax] == tokenizer.eod
+        # Test add_extra_token_to_sequence
+        assert sample['tokens'][argmax] != tokenizer.eod
+        assert sample['labels'][argmax] == tokenizer.eod
 
-    # Test eod_mask_loss, drop_last_partial_validation_sequence
-    assert argmax < sample['labels'].shape[0] - 1
-    assert torch.all(sample['labels'][argmax + 1 :] == 0)
-    assert not torch.any(
-        sample['loss_mask'][
-            torch.logical_and(sample['labels'] == tokenizer.eod, sample['labels'] == 0)
-        ]
-    )
+        # Test eod_mask_loss, drop_last_partial_validation_sequence
+        assert argmax < sample['labels'].shape[0] - 1
+        assert torch.all(sample['labels'][argmax + 1 :] == 0)
+        assert not torch.any(
+            sample['loss_mask'][
+                torch.logical_and(sample['labels'] == tokenizer.eod, sample['labels'] == 0)
+            ]
+        )
 
-    sample = datasets[1][None]
+        sample = datasets[1][None]
 
-    # Check handling of None index
-    assert not torch.any(sample['loss_mask'])
+        # Check handling of None index
+        assert not torch.any(sample['loss_mask'])
 
 
 if __name__ == "__main__":
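
Both hunks apply the same gating idiom: a build_on_rank callable decides which process actually builds the datasets and runs the assertions, so multi-rank runs no longer assert on ranks that were never asked to build. A condensed sketch of the idiom under torch.distributed (the make_build_on_rank helper is our naming for illustration, not a function from the commit):

import torch

def make_build_on_rank():
    # With an initialized process group, only rank 0 does the work;
    # in a single-process run, the predicate is always true.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return lambda: torch.distributed.get_rank() == 0
    return lambda: True

build_on_rank = make_build_on_rank()

if build_on_rank():
    # rank-0-only section: build datasets, sample them, run assertions
    samples = list(range(10))
    assert len(samples) == 10

Passing the same predicate to BlendedMegatronDatasetBuilder and to the assertion block keeps the builder and the checks consistent about which rank holds data.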

0 commit comments