
Commit db3a3f7

Merge branch 'boxiangw/mlperf-option-add-one-extra-token' into 'main'

[MLPerf] GPT dataset features: drop the last partial validation sequence, drop the extra token, return samples with all-1s loss masks, and add mock dataset testing.

See merge request ADLR/megatron-lm!1223

2 parents 2297178 + c90aa16

27 files changed (+543, -414 lines)

examples/run_simple_mcore_train_loop.py (1 addition, 2 deletions)

@@ -49,8 +49,7 @@ def get_train_data_iterator():
     config = GPTDatasetConfig(
         random_seed = 0,
         sequence_length = 64,
-        blend=[],
-        mock=True,
+        blend=None,
         reset_position_ids=False,
         reset_attention_mask=False,
         eod_mask_loss=False,
megatron/core/QuickStart.md (3 additions, 4 deletions)

@@ -86,10 +86,9 @@ from megatron.core.datasets.gpt_dataset import GPTDatasetConfig, MockGPTDataset
 
 def get_train_data_iterator():
     config = GPTDatasetConfig(
-        random_seed = 0,
-        sequence_length = 64,
-        blend=[],
-        mock=True,
+        random_seed=0,
+        sequence_length=64,
+        blend=None,
         reset_position_ids=False,
         reset_attention_mask=False,
         eod_mask_loss=False,
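
Together with the blended_megatron_dataset_config.py change below, these two examples show the new contract: mock mode is no longer requested explicitly via `mock=True`; it is inferred when neither `blend` nor `blend_per_split` is given. A minimal sketch of the updated usage follows; the `split` argument is an assumption based on the new `__post_init__` assert (it is not visible in this hunk), and its value is illustrative:

```python
from megatron.core.datasets.gpt_dataset import GPTDatasetConfig

config = GPTDatasetConfig(
    random_seed=0,
    sequence_length=64,
    blend=None,  # no data blend: together with blend_per_split=None this implies mock mode
    split="990,9,1",  # assumption: __post_init__ now requires a split even for mock configs
    reset_position_ids=False,
    reset_attention_mask=False,
    eod_mask_loss=False,
)
assert config.mock  # derived in __post_init__; mock is no longer a constructor argument
```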

megatron/core/datasets/bert_dataset.py (2 additions, 5 deletions)

@@ -38,7 +38,7 @@ class BERTMaskedWordPieceDataset(MaskedWordPieceDataset):
 
         indexed_indices (numpy.ndarray): The set of the documents indices to expose
 
-        num_samples (int): The number of samples to draw from the indexed dataset
+        num_samples (Optional[int]): The number of samples to draw from the indexed dataset. When None, build as many samples as correspond to one epoch.
 
         index_split (Split): The indexed_indices Split
 
@@ -50,17 +50,14 @@ def __init__(
         indexed_dataset: IndexedDataset,
         dataset_path: str,
         indexed_indices: numpy.ndarray,
-        num_samples: int,
+        num_samples: Optional[int],
         index_split: Split,
         config: BERTMaskedWordPieceDatasetConfig,
     ) -> None:
         super().__init__(
             indexed_dataset, dataset_path, indexed_indices, num_samples, index_split, config
         )
 
-    def _finalize(self) -> None:
-        """Abstract method implementation
-        """
         self.token_lookup = list(self.config.tokenizer.inv_vocab.keys())
         # Account for the single <cls> and two <sep> token ids
         self.sample_index = self._build_sample_index(
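
The updated docstring pins down the new `Optional[int]` semantics: `None` means one full epoch. A tiny illustrative sketch of that contract; the helper and its `samples_per_epoch` parameter are hypothetical, not part of the dataset API:

```python
from typing import Optional

def resolve_num_samples(num_samples: Optional[int], samples_per_epoch: int) -> int:
    # None means "build as many samples as correspond to one epoch"
    return samples_per_epoch if num_samples is None else num_samples

assert resolve_num_samples(None, 1024) == 1024  # one epoch's worth
assert resolve_num_samples(512, 1024) == 512    # explicit sample budget wins
```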

megatron/core/datasets/blended_dataset.py (1 addition, 1 deletion)

@@ -166,7 +166,7 @@ def _build_indices(self) -> Tuple[numpy.ndarray, numpy.ndarray]:
             log_single_rank(
                 logger,
                 logging.WARNING,
-                "Unable to save the blending indexes because path_to_cache is None",
+                f"Unable to save the {type(self).__name__} indexes because path_to_cache is None",
             )
 
         t_end = time.time()

megatron/core/datasets/blended_megatron_dataset_builder.py (22 additions, 28 deletions)

@@ -9,7 +9,7 @@
 
 from megatron.core.datasets.blended_dataset import BlendedDataset
 from megatron.core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig
-from megatron.core.datasets.megatron_dataset import LowLevelDataset, MegatronDataset, MockDataset
+from megatron.core.datasets.megatron_dataset import LowLevelDataset, MegatronDataset
 from megatron.core.datasets.utils import Split, log_single_rank, normalize
 from megatron.core.parallel_state import get_virtual_pipeline_model_parallel_rank
 
@@ -51,13 +51,11 @@ def __init__(
 
         log_single_rank(
             logger,
-            logging.WARNING,
+            logging.INFO,
             f"Building dataset splits with cls={cls.__name__}, sizes={self.sizes}, and config={self.config}",
         )
 
-        if self.config.mock:
-            assert issubclass(self.cls, MockDataset)
-        else:
+        if not self.config.mock:
             for split in Split:
                 size_is_none = self.sizes[split.value] is None
                 if self.config.blend_per_split is None:
@@ -151,7 +149,13 @@ def _build_blended_dataset_splits(self,) -> List[Optional[TopLevelDataset]]:
         # Return fake "mock" datasets
         ##
         if self.config.mock:
-            return self._build_megatron_dataset_splits(None, None, self.sizes)
+            split = self.config.split_matrix
+            try:
+                return self._build_megatron_dataset_splits(None, split, self.sizes)
+            except Exception as error:
+                raise Exception(
+                    f"{self.cls.__name__} failed to build as a mock data generator"
+                ) from error
 
         ##
         # All splits come from the same distribution
@@ -282,7 +286,7 @@ def _build_megatron_dataset_splits(
         """Build each MidLevelDataset split from a single LowLevelDataset
 
         Args:
-            dataset_path (Optional[str]): The path on disk which defines the underlying LowLevelDataset, e.g. the .bin and .idx file prefix when self.cls is of type IndexedMegatronDataset or None when self.cls is of type MockDataset
+            dataset_path (Optional[str]): The path on disk which defines the underlying LowLevelDataset, or None for mock dataset classes
 
             split (List[Tuple[float, float]]): The dataset split matrix
 
@@ -292,33 +296,23 @@ def _build_megatron_dataset_splits(
             List[Optional[MidLevelDataset]]: The MidLevelDataset (or None) per split
         """
         # Build the low level dataset
-        if issubclass(self.cls, MockDataset):
-            low_level_dataset = None
-        elif issubclass(self.cls, MegatronDataset):
-            low_level_dataset = self.cls.build_low_level_dataset(dataset_path, self.config)
-        else:
-            raise NotImplementedError
+        low_level_dataset = self.cls.build_low_level_dataset(dataset_path, self.config)
 
         # Build the split indices for the low level dataset
-        if low_level_dataset is not None:
-            num_elements = self.cls.numel_low_level_dataset(low_level_dataset)
-            split_indices = []
-            for i, _ in enumerate(Split):
-                if split[i] is not None:
-                    beg = int(round(split[i][0] * float(num_elements)))
-                    end = int(round(split[i][1] * float(num_elements)))
-                    split_indices.append(
-                        numpy.arange(start=beg, stop=end, step=1, dtype=numpy.int32)
-                    )
-                else:
-                    split_indices.append(None)
-        else:
-            split_indices = [None for _ in Split]
+        num_elements = self.cls.numel_low_level_dataset(low_level_dataset)
+        split_indices = []
+        for i, _ in enumerate(Split):
+            if split[i] is not None:
+                beg = int(round(split[i][0] * float(num_elements)))
+                end = int(round(split[i][1] * float(num_elements)))
+                split_indices.append(numpy.arange(start=beg, stop=end, step=1, dtype=numpy.int32))
+            else:
+                split_indices.append(None)
 
         # Build the mid level dataset
         mid_level_datasets = []
         for i, _split in enumerate(Split):
-            if not self.config.mock and split[i] is None:
+            if split[i] is None:
                 mid_level_datasets.append(None)
             else:
                 mid_level_datasets.append(
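
The collapsed branch above is the heart of the builder change: mock datasets now supply a real split matrix instead of skipping index construction, so every dataset class goes through the same `numpy.arange` slicing. A standalone sketch of that arithmetic with illustrative sizes and fractions:

```python
import numpy

# Illustrative values: 1000 low-level elements (e.g. documents) and a
# train/valid/test split matrix of (begin, end) fractions; None marks an unused split.
num_elements = 1000
split_matrix = [(0.0, 0.8), (0.8, 0.9), (0.9, 1.0)]

split_indices = []
for bounds in split_matrix:
    if bounds is not None:
        beg = int(round(bounds[0] * float(num_elements)))
        end = int(round(bounds[1] * float(num_elements)))
        split_indices.append(numpy.arange(start=beg, stop=end, step=1, dtype=numpy.int32))
    else:
        split_indices.append(None)

# train gets indices [0, 800), valid [800, 900), test [900, 1000)
assert [len(x) for x in split_indices] == [800, 100, 100]
```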

megatron/core/datasets/blended_megatron_dataset_config.py (35 additions, 32 deletions)

@@ -6,8 +6,6 @@
 from dataclasses import dataclass, field
 from typing import List, Optional, Tuple
 
-import torch
-
 from megatron.core.datasets.megatron_tokenizer import MegatronTokenizer
 from megatron.core.datasets.utils import Split, log_single_rank, normalize
 
@@ -53,46 +51,51 @@ class BlendedMegatronDatasetConfig:
     mmap_bin_files: bool = True
     """Whether to mmap the .bin files or use file pointers."""
 
-    mock: bool = False
-    """Whether to bypass real data loading and validation in favor of mock data generation."""
+    mock: bool = field(init=False, default=False)
+    """Whether to bypass real data loading and validation in favor of mock data generation.
+    Created automatically from 'blend' and 'blend_per_split'. Not to be passed in to the
+    constructor.
+    """
 
     tokenizer: Optional[MegatronTokenizer] = None
     """The MegatronTokenizer instance or None. Required for datasets which do online tokenization."""
 
     def __post_init__(self) -> None:
         """Do asserts and set fields post init
         """
-        log_single_rank(logger, logging.INFO, f"mock = {self.mock}")
-
-        if not self.mock:
-            if self.blend_per_split is not None and any(self.blend_per_split):
-                assert self.blend is None, "blend and blend_per_split are incompatible"
-                assert self.split is None, "split and blend_per_split are incompatible"
-                assert len(self.blend_per_split) == len(
-                    Split
-                ), f"blend_per_split must contain {len(Split)} blends"
-                for split in Split:
-                    if self.blend_per_split[split.value] is None:
-                        log_single_rank(
-                            logger, logging.INFO, f"blend not provided for {split.name} split"
-                        )
-                    else:
-                        assert self.blend_per_split[split.value][1] is None or len(
-                            self.blend_per_split[split.value][0]
-                        ) == len(
-                            self.blend_per_split[split.value][1]
-                        ), "blend per split prefixes and weights must be equal in number"
-            else:
-                assert (
-                    self.blend is not None
-                ), "one of either blend or blend_per_split must be provided"
-                assert self.split is not None, "both blend and split must be provided"
+        if self.blend_per_split is not None and any(self.blend_per_split):
+            assert self.blend is None, "blend and blend_per_split are incompatible"
+            assert self.split is None, "split and blend_per_split are incompatible"
+            assert len(self.blend_per_split) == len(
+                Split
+            ), f"blend_per_split must contain {len(Split)} blends"
+            for split in Split:
+                if self.blend_per_split[split.value] is None:
+                    log_single_rank(
+                        logger, logging.INFO, f"blend not provided for {split.name} split"
+                    )
+                else:
+                    assert self.blend_per_split[split.value][1] is None or len(
+                        self.blend_per_split[split.value][0]
+                    ) == len(
+                        self.blend_per_split[split.value][1]
+                    ), "blend per split prefixes and weights must be equal in number"
+        else:
+            assert self.split is not None, "split must be provided in absence of blend_per_split"
+            split_vector = parse_and_normalize_split(self.split)
+            self.split_matrix = convert_split_vector_to_split_matrix(split_vector)
+            log_single_rank(logger, logging.INFO, f"Let split_matrix = {self.split_matrix}")
+            if self.blend is not None:
                 assert self.blend[1] is None or len(self.blend[0]) == len(
                     self.blend[1]
                 ), "blend prefixes and weights must be equal in number"
-            split_vector = parse_and_normalize_split(self.split)
-            self.split_matrix = convert_split_vector_to_split_matrix(split_vector)
-            log_single_rank(logger, logging.INFO, f"Let split_matrix = {self.split_matrix}")
+            else:
+                self.mock = True
+                log_single_rank(
+                    logger,
+                    logging.INFO,
+                    f"Let mock = True, as both blend and blend_per_split are None",
+                )
 
 
 def parse_and_normalize_split(split: str) -> List[float]:
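
The switch from `mock: bool = False` to `field(init=False, default=False)` is what makes the docstring's "not to be passed in to the constructor" enforceable: dataclasses exclude `init=False` fields from the generated `__init__`, and `__post_init__` derives the value instead. A minimal self-contained sketch of the pattern; the `Config` class here is illustrative, not the Megatron one:

```python
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class Config:
    blend: Optional[List[str]] = None
    mock: bool = field(init=False, default=False)  # derived, never a constructor argument

    def __post_init__(self) -> None:
        # Mirror the rule above: no blend provided means mock data generation.
        if self.blend is None:
            self.mock = True

assert Config().mock is True
assert Config(blend=["my-dataset-prefix"]).mock is False
# Config(mock=True) raises TypeError: 'mock' is not an __init__ parameter
```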
