Skip to content

Commit e485838

Browse files
committed
Attempt to fix, standardize, and stabilize the data unit tests
1 parent 23a1dca commit e485838

File tree

7 files changed

+138
-130
lines changed

7 files changed

+138
-130
lines changed

tests/unit_tests/data/test_bin_reader.py

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,18 @@
66
from types import ModuleType, SimpleNamespace
77
from typing import Any, Dict
88

9-
import nltk
10-
import pytest
11-
129
try:
1310
import boto3
1411
import botocore.exceptions as exceptions
1512
except ModuleNotFoundError:
13+
# Create mock msc module
1614
boto3 = ModuleType("boto3")
17-
sys.modules[boto3.__name__] = boto3
15+
16+
# Create mock types submodule
1817
exceptions = ModuleType("botocore.exceptions")
18+
19+
# Register the mock module in sys.modules
20+
sys.modules[boto3.__name__] = boto3
1921
sys.modules[exceptions.__name__] = exceptions
2022

2123
try:
@@ -43,6 +45,8 @@ def __init__(self, offset: int, size: int):
4345
sys.modules[msc.__name__] = msc
4446
sys.modules[types_module.__name__] = types_module
4547

48+
import torch
49+
4650
from megatron.core.datasets.indexed_dataset import (
4751
IndexedDataset,
4852
ObjectStorageConfig,
@@ -58,9 +62,11 @@ def __init__(self, offset: int, size: int):
5862
gpt2_merge,
5963
gpt2_vocab,
6064
)
65+
from tests.unit_tests.test_utilities import Utils
66+
6167

6268
##
63-
# Overload client from boto3
69+
# Mock boto3
6470
##
6571

6672

@@ -72,7 +78,8 @@ def __init__(self, *args: Any) -> None:
7278

7379
def download_file(self, Bucket: str, Key: str, Filename: str) -> None:
7480
os.makedirs(os.path.dirname(Filename), exist_ok=True)
75-
os.system(f"cp {os.path.join('/', Bucket, Key)} {Filename}")
81+
remote_path = os.path.join("/", Bucket, Key)
82+
os.system(f"cp {remote_path} {Filename}")
7683
assert os.path.exists(Filename)
7784

7885
def upload_file(self, Filename: str, Bucket: str, Key: str) -> None:
@@ -104,27 +111,28 @@ def close(self) -> None:
104111

105112

106113
##
107-
# Overload ClientError from botocore.exceptions
114+
# Mock botocore.exceptions
108115
##
109116

110117

111118
class _LocalClientError(Exception):
112-
""" "Local test client error"""
119+
"""Local test client error"""
113120

114121
pass
115122

116123

117124
setattr(exceptions, "ClientError", _LocalClientError)
118125

119126
##
120-
# Mock multistorageclient module
127+
# Mock msc.open, msc.download_file, msc.resolve_storage_client
121128
##
122129

123130

124131
def _msc_download_file(remote_path, local_path):
125-
remote_path = remote_path.removeprefix(MSC_PREFIX + "default")
132+
remote_path = os.path.join("/", remote_path.removeprefix(MSC_PREFIX))
126133
os.makedirs(os.path.dirname(local_path), exist_ok=True)
127134
os.system(f"cp {remote_path} {local_path}")
135+
assert os.path.exists(local_path)
128136

129137

130138
def _msc_resolve_storage_client(path):
@@ -134,28 +142,29 @@ def read(self, path, byte_range):
134142
f.seek(byte_range.offset)
135143
return f.read(byte_range.size)
136144

137-
return StorageClient(), path.removeprefix(MSC_PREFIX + "default")
145+
return StorageClient(), os.path.join("/", path.removeprefix(MSC_PREFIX))
138146

139147

140148
setattr(msc, "open", open)
141149
setattr(msc, "download_file", _msc_download_file)
142150
setattr(msc, "resolve_storage_client", _msc_resolve_storage_client)
143151

144152

145-
@pytest.mark.flaky
146-
@pytest.mark.flaky_in_dev
147153
def test_bin_reader():
148-
with tempfile.TemporaryDirectory() as temp_dir:
149-
# set the default nltk data path
150-
os.environ["NLTK_DATA"] = os.path.join(temp_dir, "nltk_data")
151-
nltk.data.path.append(os.environ["NLTK_DATA"])
154+
if torch.distributed.is_available():
155+
Utils.initialize_distributed()
156+
if torch.distributed.get_rank() != 0:
157+
return
152158

159+
with tempfile.TemporaryDirectory() as temp_dir:
153160
path_to_raws = os.path.join(temp_dir, "sample_raws")
154161
path_to_data = os.path.join(temp_dir, "sample_data")
155-
path_to_object_storage_cache = os.path.join(temp_dir, "object_storage_cache")
162+
path_to_object_storage_cache_msc = os.path.join(temp_dir, "object_storage_cache_msc")
163+
path_to_object_storage_cache_s3 = os.path.join(temp_dir, "object_storage_cache_s3")
156164
os.mkdir(path_to_raws)
157165
os.mkdir(path_to_data)
158-
os.mkdir(path_to_object_storage_cache)
166+
os.mkdir(path_to_object_storage_cache_msc)
167+
os.mkdir(path_to_object_storage_cache_s3)
159168

160169
# create the dummy resources
161170
dummy_jsonl(path_to_raws)
@@ -195,27 +204,26 @@ def test_bin_reader():
195204
assert isinstance(indexed_dataset_mmap.bin_reader, _MMapBinReader)
196205

197206
indexed_dataset_msc = IndexedDataset(
198-
MSC_PREFIX + "default" + prefix, # use the default profile to access the filesystem
207+
MSC_PREFIX + prefix.lstrip("/"),
199208
multimodal=False,
200209
mmap=False,
201210
object_storage_config=ObjectStorageConfig(
202-
path_to_idx_cache=path_to_object_storage_cache
211+
path_to_idx_cache=path_to_object_storage_cache_msc
203212
),
204213
)
205214
assert isinstance(indexed_dataset_msc.bin_reader, _MultiStorageClientBinReader)
206215
assert len(indexed_dataset_msc) == len(indexed_dataset_file)
207216
assert len(indexed_dataset_msc) == len(indexed_dataset_mmap)
208217

209218
indexed_dataset_s3 = IndexedDataset(
210-
S3_PREFIX + prefix,
219+
S3_PREFIX + prefix.lstrip("/"),
211220
multimodal=False,
212221
mmap=False,
213222
object_storage_config=ObjectStorageConfig(
214-
path_to_idx_cache=path_to_object_storage_cache
223+
path_to_idx_cache=path_to_object_storage_cache_s3
215224
),
216225
)
217226
assert isinstance(indexed_dataset_s3.bin_reader, _S3BinReader)
218-
219227
assert len(indexed_dataset_s3) == len(indexed_dataset_file)
220228
assert len(indexed_dataset_s3) == len(indexed_dataset_mmap)
221229

@@ -226,6 +234,7 @@ def test_bin_reader():
226234
for idx in indices:
227235
assert (indexed_dataset_s3[idx] == indexed_dataset_file[idx]).all()
228236
assert (indexed_dataset_s3[idx] == indexed_dataset_mmap[idx]).all()
237+
assert (indexed_dataset_s3[idx] == indexed_dataset_msc[idx]).all()
229238

230239

231240
if __name__ == "__main__":

tests/unit_tests/data/test_gpt_dataset.py

Lines changed: 36 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,10 @@ def test_mock_gpt_dataset():
3434
if torch.distributed.get_rank() == 0:
3535
compile_helpers()
3636
torch.distributed.barrier()
37+
build_on_rank = lambda: torch.distributed.get_rank() == 0
3738
else:
3839
compile_helpers()
40+
build_on_rank = lambda: True
3941

4042
tokenizer = _NullTokenizer(vocab_size=_MOCK_VOCAB_SIZE)
4143

@@ -51,26 +53,27 @@ def test_mock_gpt_dataset():
5153
)
5254

5355
datasets = BlendedMegatronDatasetBuilder(
54-
MockGPTDataset, [100, 100, 100], lambda: True, config
56+
MockGPTDataset, [100, 100, 100], build_on_rank, config
5557
).build()
5658

57-
N = 10
59+
if build_on_rank():
60+
N = 10
5861

59-
# Check iso-index variance by split
60-
subsets = [sample_N(dataset, N, randomize=False) for dataset in datasets]
61-
assert not numpy.allclose(subsets[0], subsets[1])
62-
assert not numpy.allclose(subsets[0], subsets[2])
63-
assert not numpy.allclose(subsets[1], subsets[2])
62+
# Check iso-index variance by split
63+
subsets = [sample_N(dataset, N, randomize=False) for dataset in datasets]
64+
assert not numpy.allclose(subsets[0], subsets[1])
65+
assert not numpy.allclose(subsets[0], subsets[2])
66+
assert not numpy.allclose(subsets[1], subsets[2])
6467

65-
# Check iso-split / iso-index identity
66-
subset_1A = sample_N(datasets[0], N, randomize=False)
67-
subset_1B = sample_N(datasets[0], N, randomize=False)
68-
assert numpy.allclose(subset_1A, subset_1B)
68+
# Check iso-split / iso-index identity
69+
subset_1A = sample_N(datasets[0], N, randomize=False)
70+
subset_1B = sample_N(datasets[0], N, randomize=False)
71+
assert numpy.allclose(subset_1A, subset_1B)
6972

70-
# Check iso-split variance by index
71-
subset_1A = sample_N(datasets[0], N, randomize=True)
72-
subset_1B = sample_N(datasets[0], N, randomize=True)
73-
assert not numpy.allclose(subset_1A, subset_1B)
73+
# Check iso-split variance by index
74+
subset_1A = sample_N(datasets[0], N, randomize=True)
75+
subset_1B = sample_N(datasets[0], N, randomize=True)
76+
assert not numpy.allclose(subset_1A, subset_1B)
7477

7578
config = GPTDatasetConfig(
7679
random_seed=1234,
@@ -86,29 +89,30 @@ def test_mock_gpt_dataset():
8689
)
8790

8891
datasets = BlendedMegatronDatasetBuilder(
89-
MockGPTDataset, [0, None, 0], lambda: True, config
92+
MockGPTDataset, [0, None, 0], build_on_rank, config
9093
).build()
9194

92-
sample = datasets[1][datasets[1].shuffle_index.argmax()]
93-
argmax = sample['labels'].shape[0] - torch.flip(sample['labels'], [0]).argmax() - 1
95+
if build_on_rank():
96+
sample = datasets[1][datasets[1].shuffle_index.argmax()]
97+
argmax = sample['labels'].shape[0] - torch.flip(sample['labels'], [0]).argmax() - 1
9498

95-
# Test add_extra_token_to_sequence
96-
assert sample['tokens'][argmax] != tokenizer.eod
97-
assert sample['labels'][argmax] == tokenizer.eod
99+
# Test add_extra_token_to_sequence
100+
assert sample['tokens'][argmax] != tokenizer.eod
101+
assert sample['labels'][argmax] == tokenizer.eod
98102

99-
# Test eod_mask_loss, drop_last_partial_validation_sequence
100-
assert argmax < sample['labels'].shape[0] - 1
101-
assert torch.all(sample['labels'][argmax + 1 :] == 0)
102-
assert not torch.any(
103-
sample['loss_mask'][
104-
torch.logical_and(sample['labels'] == tokenizer.eod, sample['labels'] == 0)
105-
]
106-
)
103+
# Test eod_mask_loss, drop_last_partial_validation_sequence
104+
assert argmax < sample['labels'].shape[0] - 1
105+
assert torch.all(sample['labels'][argmax + 1 :] == 0)
106+
assert not torch.any(
107+
sample['loss_mask'][
108+
torch.logical_and(sample['labels'] == tokenizer.eod, sample['labels'] == 0)
109+
]
110+
)
107111

108-
sample = datasets[1][None]
112+
sample = datasets[1][None]
109113

110-
# Check handling of None index
111-
assert not torch.any(sample['loss_mask'])
114+
# Check handling of None index
115+
assert not torch.any(sample['loss_mask'])
112116

113117

114118
if __name__ == "__main__":

0 commit comments

Comments
 (0)