 import sys
 import time
 from dataclasses import dataclass
-from typing import Dict, Tuple
+from typing import Dict, Optional, Tuple

 import numpy
 import torch

 from megatron.core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig
 from megatron.core.datasets.indexed_dataset import IndexedDataset
-from megatron.core.datasets.megatron_dataset import MegatronDataset, MockDataset
+from megatron.core.datasets.megatron_dataset import LowLevelDataset, MegatronDataset, MockDataset
 from megatron.core.datasets.utils import Split, log_single_rank

 logger = logging.getLogger(__name__)
@@ -29,6 +29,8 @@ class GPTDatasetConfig(BlendedMegatronDatasetConfig):

         eod_mask_loss (bool): Option to enable the EOD mask loss

+        create_attention_mask (bool): Option to enable the attention masks generation. Can be disabled if attention kernel generates masks by itself.
+
         vocab_size (int): Size of vocabulary

     """
@@ -39,6 +41,8 @@ class GPTDatasetConfig(BlendedMegatronDatasetConfig):

     eod_mask_loss: bool = None

+    create_attention_mask: bool = True
+
     vocab_size: int = sys.maxsize

     def __post_init__(self) -> None:
@@ -57,6 +61,29 @@ class MockGPTDataset(MockDataset):
     """The mock GPT dataset
     """

+    def __init__(
+        self,
+        dataset: Optional[LowLevelDataset],
+        dataset_path: Optional[str],
+        indices: Optional[numpy.ndarray],
+        num_samples: int,
+        index_split: Split,
+        config: BlendedMegatronDatasetConfig,
+    ) -> None:
+        super().__init__(dataset, dataset_path, indices, num_samples, index_split, config)
+
+        self.masks_and_position_ids_are_cacheable = not any(
+            [
+                self.config.reset_position_ids,
+                self.config.reset_attention_mask,
+                self.config.eod_mask_loss,
+            ]
+        )
+        self.masks_and_position_ids_are_cached = False
+        self.cached_attention_mask = None
+        self.cached_loss_mask = None
+        self.cached_position_ids = None
+
     def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
         """Return a sequence_length + 1 token sequence consisting of the following:
             - (1) S, the RNG length-sentinel in the range [0, sequence_length)
@@ -89,21 +116,43 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
         labels = text[1:].contiguous()
         tokens = text[:-1].contiguous()

-        attention_mask, loss_mask, position_ids = _get_ltor_masks_and_position_ids(
-            tokens,
-            eod,
-            self.config.reset_position_ids,
-            self.config.reset_attention_mask,
-            self.config.eod_mask_loss,
-        )
-
-        return {
-            "tokens": tokens,
-            "labels": labels,
-            "attention_mask": attention_mask,
-            "loss_mask": loss_mask,
-            "position_ids": position_ids,
-        }
+        if (
+            not self.masks_and_position_ids_are_cacheable
+            or not self.masks_and_position_ids_are_cached
+        ):
+            attention_mask, loss_mask, position_ids = _get_ltor_masks_and_position_ids(
+                tokens,
+                eod,
+                self.config.reset_position_ids,
+                self.config.reset_attention_mask,
+                self.config.eod_mask_loss,
+                self.config.create_attention_mask,
+            )
+            if self.masks_and_position_ids_are_cacheable:
+                self.cached_attention_mask = attention_mask
+                self.cached_loss_mask = loss_mask
+                self.cached_position_ids = position_ids
+                self.masks_and_position_ids_are_cached = True
+        else:
+            attention_mask = self.cached_attention_mask
+            loss_mask = self.cached_loss_mask
+            position_ids = self.cached_position_ids
+
+        if self.config.create_attention_mask:
+            return {
+                "tokens": tokens,
+                "labels": labels,
+                "attention_mask": attention_mask,
+                "loss_mask": loss_mask,
+                "position_ids": position_ids,
+            }
+        else:
+            return {
+                "tokens": tokens,
+                "labels": labels,
+                "loss_mask": loss_mask,
+                "position_ids": position_ids,
+            }


 class GPTDataset(MegatronDataset):
@@ -138,6 +187,18 @@ def __init__(

         self.vocab_size = config.vocab_size

+        self.masks_and_position_ids_are_cacheable = not any(
+            [
+                self.config.reset_position_ids,
+                self.config.reset_attention_mask,
+                self.config.eod_mask_loss,
+            ]
+        )
+        self.masks_and_position_ids_are_cached = False
+        self.cached_attention_mask = None
+        self.cached_loss_mask = None
+        self.cached_position_ids = None
+
     def _finalize(self) -> None:
         """Abstract method implementation

@@ -205,21 +266,43 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
             tokens >= self.vocab_size
         ), "An input token is out of bounds of the tokenizer vocabulary"

-        attention_mask, loss_mask, position_ids = _get_ltor_masks_and_position_ids(
-            tokens,
-            self.config.tokenizer.eod,
-            self.config.reset_position_ids,
-            self.config.reset_attention_mask,
-            self.config.eod_mask_loss,
-        )
-
-        return {
-            "tokens": tokens,
-            "labels": labels,
-            "attention_mask": attention_mask,
-            "loss_mask": loss_mask,
-            "position_ids": position_ids,
-        }
+        if (
+            not self.masks_and_position_ids_are_cacheable
+            or not self.masks_and_position_ids_are_cached
+        ):
+            attention_mask, loss_mask, position_ids = _get_ltor_masks_and_position_ids(
+                tokens,
+                self.config.tokenizer.eod,
+                self.config.reset_position_ids,
+                self.config.reset_attention_mask,
+                self.config.eod_mask_loss,
+                self.config.create_attention_mask,
+            )
+            if self.masks_and_position_ids_are_cacheable:
+                self.cached_attention_mask = attention_mask
+                self.cached_loss_mask = loss_mask
+                self.cached_position_ids = position_ids
+                self.masks_and_position_ids_are_cached = True
+        else:
+            attention_mask = self.cached_attention_mask
+            loss_mask = self.cached_loss_mask
+            position_ids = self.cached_position_ids
+
+        if self.config.create_attention_mask:
+            return {
+                "tokens": tokens,
+                "labels": labels,
+                "attention_mask": attention_mask,
+                "loss_mask": loss_mask,
+                "position_ids": position_ids,
+            }
+        else:
+            return {
+                "tokens": tokens,
+                "labels": labels,
+                "loss_mask": loss_mask,
+                "position_ids": position_ids,
+            }

     def _query_document_sample_shuffle_indices(
         self, idx: int
@@ -575,6 +658,7 @@ def _get_ltor_masks_and_position_ids(
     reset_position_ids: bool,
     reset_attention_mask: bool,
     eod_mask_loss: bool,
+    create_attention_mask: bool,
 ):
     """Build masks and position id for left to right model.

@@ -589,6 +673,8 @@ def _get_ltor_masks_and_position_ids(

         eod_mask_loss (bool): Switch to enable the EOD mask loss

+        create_attention_mask (bool): Switch to enable the attention masks generation. Can be disabled if attention kernel generates masks by itself.
+
     Returns:
         torch.Tensor: Attention mask needed to be used for Attention

@@ -598,9 +684,12 @@ def _get_ltor_masks_and_position_ids(
     """
     seq_length = data.numel()

-    attention_mask = torch.tril(torch.ones((seq_length, seq_length), device=data.device)).unsqueeze(
-        0
-    )
+    if create_attention_mask:
+        attention_mask = torch.tril(
+            torch.ones((seq_length, seq_length), device=data.device)
+        ).unsqueeze(0)
+    else:
+        attention_mask = None

     # Loss mask.
     loss_mask = torch.ones(seq_length, dtype=torch.float, device=data.device)
@@ -625,14 +714,15 @@ def _get_ltor_masks_and_position_ids(
         for j in range(eod_index.numel()):
             i = eod_index[j]
             # Mask attention loss.
-            if reset_attention_mask:
+            if reset_attention_mask and attention_mask is not None:
                 attention_mask[0, (i + 1) :, : (i + 1)] = 0
             # Reset positions.
             if reset_position_ids:
                 position_ids[(i + 1) :] -= i + 1 - prev_index
                 prev_index = i + 1

-    # Convert attention mask to binary:
-    attention_mask = attention_mask < 0.5
+    if attention_mask is not None:
+        # Convert attention mask to binary:
+        attention_mask = attention_mask < 0.5

     return attention_mask, loss_mask, position_ids
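
Below is a minimal sketch of the new create_attention_mask switch in isolation, assuming this diff applies to megatron/core/datasets/gpt_dataset.py as imported here; the toy token IDs and the EOD token value of 0 are placeholders for illustration, not values taken from the patch.

import torch

from megatron.core.datasets.gpt_dataset import _get_ltor_masks_and_position_ids

# Toy sequence; the 0 stands in for the tokenizer's EOD token id.
tokens = torch.tensor([5, 7, 0, 3, 9])

attention_mask, loss_mask, position_ids = _get_ltor_masks_and_position_ids(
    tokens,
    0,      # EOD token id (placeholder)
    False,  # reset_position_ids
    False,  # reset_attention_mask
    False,  # eod_mask_loss
    False,  # create_attention_mask: skip building the (1, S, S) causal mask
)

assert attention_mask is None                    # mask is never materialized
assert loss_mask.tolist() == [1.0] * 5           # loss mask is still produced
assert position_ids.tolist() == [0, 1, 2, 3, 4]  # position ids are still produced

With reset_position_ids, reset_attention_mask, and eod_mask_loss all disabled, both dataset classes above also mark the outputs as cacheable and hand back the same loss_mask/position_ids (and attention_mask, when created) tensors for every sample, so callers should treat them as read-only.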
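
On the consumer side, a batch-unpacking step can no longer assume an "attention_mask" key when create_attention_mask=False, since the __getitem__ implementations above omit it. A hedged sketch of how a caller might tolerate the missing key (unpack_gpt_sample is a hypothetical helper, not part of this patch):

from typing import Dict, Optional, Tuple

import torch


def unpack_gpt_sample(
    sample: Dict[str, torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
    """Hypothetical helper: unpack a GPTDataset/MockGPTDataset sample dict."""
    # "attention_mask" is present only when config.create_attention_mask is True;
    # default to None and let the attention kernel build its own causal mask.
    attention_mask = sample.get("attention_mask")
    return (
        sample["tokens"],
        sample["labels"],
        sample["loss_mask"],
        attention_mask,
        sample["position_ids"],
    )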