From 5216304188747eb6b0e52b542429b15986f51d23 Mon Sep 17 00:00:00 2001 From: Joe Date: Tue, 25 Mar 2025 18:04:19 +0900 Subject: [PATCH] Exclude special tokens from mlm_masking - Added a `special_tokens` parameter to allow specifying token IDs (e.g., [CLS], [SEP]) to be excluded from masking. - Updated the eligible mask logic to filter out both special tokens and pad tokens before applying masking. - Refactored the masking process to ensure that only eligible tokens are considered for the 80/10/10 masking scheme. --- src/sequence_packer.py | 32 +++++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/src/sequence_packer.py b/src/sequence_packer.py index e937865a..c450a9a1 100644 --- a/src/sequence_packer.py +++ b/src/sequence_packer.py @@ -287,6 +287,7 @@ def mlm_masking( mask_token: int, pad_token: int = -1, ignore_index: int = -100, + special_tokens: Optional[Any] = None, np_rng=np.random.default_rng(), ) -> tuple[np.ndarray, np.ndarray]: """ @@ -312,28 +313,41 @@ def mlm_masking( mask_prob (float): probability of initially masking a token, in the first "wave" of masking mask_token (int): token to use for masking ignore_index (int): the token indicating that position should be ignored during training. We call it `ignore_index` to conform to the API of the cross entropy loss function. + special_tokens (Optional[Any]): a list or array of token IDs that should be excluded from masking. Returns: tuple[np.array,np.array]: (masked_seq, labels) masked_seq: the input seq with some tokens replaced by `mask_token` labels: the original input seq with non-masked tokens replaced by `ignore_index` """ - # Create labels - labels = np.where(seq == pad_token, ignore_index, seq) - - # Create a single mask + # Create a boolean mask indicating tokens eligible for masking. + # Exclude special tokens and pad tokens from being considered. + eligible_mask = np.ones_like(seq, dtype=bool) + if special_tokens is not None: + # Exclude tokens that are in the special_tokens list. + eligible_mask &= ~np.isin(seq, special_tokens) + if pad_token is not None: + # Exclude pad tokens. + eligible_mask &= seq != pad_token + + # Generate a random array and set non-eligible positions to 1.0 so they never meet the mask_prob threshold. rand = np_rng.random(seq.shape) + rand = np.where(eligible_mask, rand, 1.0) + + # Select masking candidates with the specified mask_prob probability (applied only to eligible tokens). + mask_candidates = rand < mask_prob - # Partition the probability space appropriately using a single mask + # Generate another random array to partition the mask candidates into 80/10/10 groups. + branch_rand = np_rng.random(seq.shape) # 80% of the time, we mask the token - mask_mask = rand < mask_prob * 0.8 + mask_mask = mask_candidates & (branch_rand < 0.8) # 10% of the time, we replace the token with a random token - random_mask = (rand >= mask_prob * 0.8) & (rand < mask_prob * 0.9) + random_mask = mask_candidates & (branch_rand >= 0.8) & (branch_rand < 0.9) # 10% of the time, we keep the token the same - keep_mask = (rand >= mask_prob * 0.9) & (rand < mask_prob) + keep_mask = mask_candidates & (branch_rand >= 0.9) # We only compute loss over the tokens marked for masking - labels = np.where(mask_mask | random_mask | keep_mask, labels, ignore_index) + labels = np.where(mask_candidates, seq, ignore_index) # Apply masking seq = np.where(mask_mask, mask_token, seq)