
Commit 9de386d

Merge branch 'xren/dataset_fix' into 'main'

skip unnecessary attention mask generation

See merge request ADLR/megatron-lm!1259

2 parents 6835eb7 + e7f376c
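In short, the merge lets the GPT datasets skip building the [1, S, S] causal attention mask whenever the attention kernel generates its own mask, and caches the masks and position ids when they do not depend on the individual sample. A minimal, self-contained sketch of that pattern in plain PyTorch follows; the get_sample function and module-level cache are illustrative stand-ins, not the Megatron classes touched by this diff.

import torch

# Sketch only: cache sample-independent masks/position ids once, and omit the
# attention mask entirely when the attention kernel builds its own causal mask.
_cache = {}

def get_sample(tokens, create_attention_mask=True):
    seq_length = tokens.numel()
    if not _cache:  # nothing here depends on the sample, so build it once
        _cache["loss_mask"] = torch.ones(seq_length, dtype=torch.float)
        _cache["position_ids"] = torch.arange(seq_length, dtype=torch.long)
        if create_attention_mask:
            # [1, S, S] causal mask; True marks positions that must not be attended to
            _cache["attention_mask"] = (
                torch.tril(torch.ones((seq_length, seq_length))).unsqueeze(0) < 0.5
            )
    sample = {
        "tokens": tokens,
        "loss_mask": _cache["loss_mask"],
        "position_ids": _cache["position_ids"],
    }
    if create_attention_mask:
        sample["attention_mask"] = _cache["attention_mask"]
    return sample

print(sorted(get_sample(torch.arange(8), create_attention_mask=False)))
# ['loss_mask', 'position_ids', 'tokens'] -- no attention_mask key

The actual change below applies these two ideas to MockGPTDataset and GPTDataset and threads the switch through GPTDatasetConfig.create_attention_mask.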

6 files changed: +143 -41 lines

megatron/core/datasets/gpt_dataset.py

Lines changed: 128 additions & 38 deletions
@@ -5,14 +5,14 @@
 import sys
 import time
 from dataclasses import dataclass
-from typing import Dict, Tuple
+from typing import Dict, Optional, Tuple

 import numpy
 import torch

 from megatron.core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig
 from megatron.core.datasets.indexed_dataset import IndexedDataset
-from megatron.core.datasets.megatron_dataset import MegatronDataset, MockDataset
+from megatron.core.datasets.megatron_dataset import LowLevelDataset, MegatronDataset, MockDataset
 from megatron.core.datasets.utils import Split, log_single_rank

 logger = logging.getLogger(__name__)
@@ -29,6 +29,8 @@ class GPTDatasetConfig(BlendedMegatronDatasetConfig):

         eod_mask_loss (bool): Option to enable the EOD mask loss

+        create_attention_mask (bool): Option to enable the attention masks generation. Can be disabled if attention kernel generates masks by itself.
+
         vocab_size (int): Size of vocabulary

     """
@@ -39,6 +41,8 @@ class GPTDatasetConfig(BlendedMegatronDatasetConfig):

     eod_mask_loss: bool = None

+    create_attention_mask: bool = True
+
     vocab_size: int = sys.maxsize

     def __post_init__(self) -> None:
@@ -57,6 +61,29 @@ class MockGPTDataset(MockDataset):
     """The mock GPT dataset
     """

+    def __init__(
+        self,
+        dataset: Optional[LowLevelDataset],
+        dataset_path: Optional[str],
+        indices: Optional[numpy.ndarray],
+        num_samples: int,
+        index_split: Split,
+        config: BlendedMegatronDatasetConfig,
+    ) -> None:
+        super().__init__(dataset, dataset_path, indices, num_samples, index_split, config)
+
+        self.masks_and_position_ids_are_cacheable = not any(
+            [
+                self.config.reset_position_ids,
+                self.config.reset_attention_mask,
+                self.config.eod_mask_loss,
+            ]
+        )
+        self.masks_and_position_ids_are_cached = False
+        self.cached_attention_mask = None
+        self.cached_loss_mask = None
+        self.cached_position_ids = None
+
     def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
         """Return a sequence_length + 1 token sequence consisting of the following:
         - (1) S, the RNG length-sentinel in the range [0, sequence_length)
@@ -89,21 +116,43 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
         labels = text[1:].contiguous()
         tokens = text[:-1].contiguous()

-        attention_mask, loss_mask, position_ids = _get_ltor_masks_and_position_ids(
-            tokens,
-            eod,
-            self.config.reset_position_ids,
-            self.config.reset_attention_mask,
-            self.config.eod_mask_loss,
-        )
-
-        return {
-            "tokens": tokens,
-            "labels": labels,
-            "attention_mask": attention_mask,
-            "loss_mask": loss_mask,
-            "position_ids": position_ids,
-        }
+        if (
+            not self.masks_and_position_ids_are_cacheable
+            or not self.masks_and_position_ids_are_cached
+        ):
+            attention_mask, loss_mask, position_ids = _get_ltor_masks_and_position_ids(
+                tokens,
+                eod,
+                self.config.reset_position_ids,
+                self.config.reset_attention_mask,
+                self.config.eod_mask_loss,
+                self.config.create_attention_mask,
+            )
+            if self.masks_and_position_ids_are_cacheable:
+                self.cached_attention_mask = attention_mask
+                self.cached_loss_mask = loss_mask
+                self.cached_position_ids = position_ids
+                self.masks_and_position_ids_are_cached = True
+        else:
+            attention_mask = self.cached_attention_mask
+            loss_mask = self.cached_loss_mask
+            position_ids = self.cached_position_ids
+
+        if self.config.create_attention_mask:
+            return {
+                "tokens": tokens,
+                "labels": labels,
+                "attention_mask": attention_mask,
+                "loss_mask": loss_mask,
+                "position_ids": position_ids,
+            }
+        else:
+            return {
+                "tokens": tokens,
+                "labels": labels,
+                "loss_mask": loss_mask,
+                "position_ids": position_ids,
+            }


 class GPTDataset(MegatronDataset):
@@ -138,6 +187,18 @@ def __init__(

         self.vocab_size = config.vocab_size

+        self.masks_and_position_ids_are_cacheable = not any(
+            [
+                self.config.reset_position_ids,
+                self.config.reset_attention_mask,
+                self.config.eod_mask_loss,
+            ]
+        )
+        self.masks_and_position_ids_are_cached = False
+        self.cached_attention_mask = None
+        self.cached_loss_mask = None
+        self.cached_position_ids = None
+
     def _finalize(self) -> None:
         """Abstract method implementation

@@ -205,21 +266,43 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
             tokens >= self.vocab_size
         ), "An input token is out of bounds of the tokenizer vocabulary"

-        attention_mask, loss_mask, position_ids = _get_ltor_masks_and_position_ids(
-            tokens,
-            self.config.tokenizer.eod,
-            self.config.reset_position_ids,
-            self.config.reset_attention_mask,
-            self.config.eod_mask_loss,
-        )
-
-        return {
-            "tokens": tokens,
-            "labels": labels,
-            "attention_mask": attention_mask,
-            "loss_mask": loss_mask,
-            "position_ids": position_ids,
-        }
+        if (
+            not self.masks_and_position_ids_are_cacheable
+            or not self.masks_and_position_ids_are_cached
+        ):
+            attention_mask, loss_mask, position_ids = _get_ltor_masks_and_position_ids(
+                tokens,
+                self.config.tokenizer.eod,
+                self.config.reset_position_ids,
+                self.config.reset_attention_mask,
+                self.config.eod_mask_loss,
+                self.config.create_attention_mask,
+            )
+            if self.masks_and_position_ids_are_cacheable:
+                self.cached_attention_mask = attention_mask
+                self.cached_loss_mask = loss_mask
+                self.cached_position_ids = position_ids
+                self.masks_and_position_ids_are_cached = True
+        else:
+            attention_mask = self.cached_attention_mask
+            loss_mask = self.cached_loss_mask
+            position_ids = self.cached_position_ids
+
+        if self.config.create_attention_mask:
+            return {
+                "tokens": tokens,
+                "labels": labels,
+                "attention_mask": attention_mask,
+                "loss_mask": loss_mask,
+                "position_ids": position_ids,
+            }
+        else:
+            return {
+                "tokens": tokens,
+                "labels": labels,
+                "loss_mask": loss_mask,
+                "position_ids": position_ids,
+            }

     def _query_document_sample_shuffle_indices(
         self, idx: int
@@ -575,6 +658,7 @@ def _get_ltor_masks_and_position_ids(
     reset_position_ids: bool,
     reset_attention_mask: bool,
     eod_mask_loss: bool,
+    create_attention_mask: bool,
 ):
     """Build masks and position id for left to right model.

@@ -589,6 +673,8 @@ def _get_ltor_masks_and_position_ids(

         eod_mask_loss (bool): Switch to enable the EOD mask loss

+        create_attention_mask (bool): Switch to enable the attention masks generation. Can be disabled if attention kernel generates masks by itself.
+
     Returns:
         torch.Tensor: Attention mask needed to be used for Attention

@@ -598,9 +684,12 @@ def _get_ltor_masks_and_position_ids(
     """
     seq_length = data.numel()

-    attention_mask = torch.tril(torch.ones((seq_length, seq_length), device=data.device)).unsqueeze(
-        0
-    )
+    if create_attention_mask:
+        attention_mask = torch.tril(
+            torch.ones((seq_length, seq_length), device=data.device)
+        ).unsqueeze(0)
+    else:
+        attention_mask = None

     # Loss mask.
     loss_mask = torch.ones(seq_length, dtype=torch.float, device=data.device)
@@ -625,14 +714,15 @@ def _get_ltor_masks_and_position_ids(
         for j in range(eod_index.numel()):
             i = eod_index[j]
             # Mask attention loss.
-            if reset_attention_mask:
+            if reset_attention_mask and attention_mask is not None:
                 attention_mask[0, (i + 1) :, : (i + 1)] = 0
             # Reset positions.
             if reset_position_ids:
                 position_ids[(i + 1) :] -= i + 1 - prev_index
                 prev_index = i + 1

-    # Convert attention mask to binary:
-    attention_mask = attention_mask < 0.5
+    if attention_mask is not None:
+        # Convert attention mask to binary:
+        attention_mask = attention_mask < 0.5

     return attention_mask, loss_mask, position_ids
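For a sense of what skipping the mask saves, the tensor built above is a [1, seq_length, seq_length] boolean mask, so it grows quadratically with sequence length. A rough illustration (the 4096 below is an arbitrary example value, not taken from this commit):

import torch

# Size of the per-sample causal mask the dataloader would otherwise materialize.
seq_length = 4096  # example value only
attention_mask = torch.tril(torch.ones((seq_length, seq_length))).unsqueeze(0) < 0.5
print(attention_mask.shape)  # torch.Size([1, 4096, 4096])
print(attention_mask.numel() * attention_mask.element_size() / 2**20, "MiB per sample")  # 16.0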

megatron/training/arguments.py

Lines changed: 3 additions & 0 deletions
@@ -1402,6 +1402,9 @@ def _add_data_args(parser):
                        'end-of-document token.')
     group.add_argument('--eod-mask-loss', action='store_true',
                        help='Mask loss for the end of document tokens.')
+    group.add_argument('--no-create-attention-mask-in-dataloader', action='store_false',
+                       help='If set, do not create attention_masks in dataloader.',
+                       dest='create_attention_mask_in_dataloader')

     return parser
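For reference, the new option uses the standard argparse idiom of a negative flag writing False into a positively named destination, so args.create_attention_mask_in_dataloader defaults to True and only flips when the flag is passed. A standalone sketch of just that behaviour:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--no-create-attention-mask-in-dataloader', action='store_false',
                    help='If set, do not create attention_masks in dataloader.',
                    dest='create_attention_mask_in_dataloader')

print(parser.parse_args([]).create_attention_mask_in_dataloader)
# True (default: masks are created)
print(parser.parse_args(['--no-create-attention-mask-in-dataloader'])
      .create_attention_mask_in_dataloader)
# False (dataloader skips mask creation)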

megatron/training/utils.py

Lines changed: 9 additions & 3 deletions
@@ -278,7 +278,8 @@ def get_batch_on_this_tp_rank(data_iterator):
     args = get_args()

     def _broadcast(item):
-        torch.distributed.broadcast(item, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group())
+        if item is not None:
+            torch.distributed.broadcast(item, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group())

     if mpu.get_tensor_model_parallel_rank() == 0:

@@ -291,7 +292,7 @@ def _broadcast(item):
            'tokens': data["tokens"].cuda(non_blocking = True),
            'labels': data["labels"].cuda(non_blocking = True),
            'loss_mask': data["loss_mask"].cuda(non_blocking = True),
-           'attention_mask': data["attention_mask"].cuda(non_blocking = True),
+           'attention_mask': None if "attention_mask" not in data else data["attention_mask"].cuda(non_blocking = True),
            'position_ids': data["position_ids"].cuda(non_blocking = True)
        }

@@ -317,7 +318,12 @@ def _broadcast(item):
        tokens=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device())
        labels=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device())
        loss_mask=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.float32 , device = torch.cuda.current_device())
-       attention_mask=torch.empty((args.micro_batch_size,1,args.seq_length,args.seq_length), dtype = torch.bool , device = torch.cuda.current_device())
+       if args.create_attention_mask_in_dataloader:
+           attention_mask=torch.empty(
+                (args.micro_batch_size,1,args.seq_length,args.seq_length), dtype = torch.bool , device = torch.cuda.current_device()
+           )
+       else:
+           attention_mask=None
        position_ids=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device())

        if args.pipeline_model_parallel_size == 1:
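The None guard in _broadcast stays collective-safe because every tensor-parallel rank derives the same create_attention_mask_in_dataloader value from the shared args, so all ranks either broadcast the attention mask or skip it together. A single-process sketch of the guard; the "gloo" group and src=0 are stand-ins for Megatron's tensor-model-parallel group and source rank:

import torch
import torch.distributed as dist

# One-process group so the sketch runs standalone.
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

def _broadcast(item):
    # All ranks make the same decision, so either every rank broadcasts the
    # tensor or every rank skips it; no rank is left waiting on the collective.
    if item is not None:
        dist.broadcast(item, src=0)

_broadcast(torch.ones(4))  # broadcast runs
_broadcast(None)           # skipped consistently on every rank

dist.destroy_process_group()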

pretrain_gpt.py

Lines changed: 1 addition & 0 deletions
@@ -179,6 +179,7 @@ def core_gpt_dataset_config_from_args(args):
         reset_position_ids=args.reset_position_ids,
         reset_attention_mask=args.reset_attention_mask,
         eod_mask_loss=args.eod_mask_loss,
+        create_attention_mask=args.create_attention_mask_in_dataloader,
         vocab_size=get_tokenizer().vocab_size,
     )


tests/functional_tests/jet_recipes/MR-gpt.yaml

Lines changed: 1 addition & 0 deletions
@@ -56,6 +56,7 @@ spec:
 products:
   # MCore
   - {tp_size: [2], pp_size: [2]}
+  - {tp_size: [2], pp_size: [2], extra_args: ["--no-create-attention-mask-in-dataloader"], args_meta: ["no_create_attention_mask_in_dataloader"]}
   - {tp_size: [2], pp_size: [2], extra_args: ["--no-mmap-bin-files"], args_meta: ["no_mmap_bin_files"]}
   - {tp_size: [1], pp_size: [4], vp_size: [1]}
   - {tp_size: [4], pp_size: [1], extra_args: ["--qk-layernorm --test-mode"]}
@@ -0,0 +1 @@
+{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [10.92392, 10.93645, 10.89657, 10.86919, 10.74782, 10.658, 10.15864, 10.24906, 10.15088, 9.83933]}, "num-zeros": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [1735.0, 1861.0, 2111.0, 1844.0, 1762.0, 1858.0, 1554.0, 2031.0, 2309.0, 2225.0]}, "iteration_timing_avg": 0.15396205882352942}
