Exclude Special Tokens from Masking in mlm_masking #218

Status: Open · wants to merge 1 commit into base: main
32 changes: 23 additions & 9 deletions src/sequence_packer.py
@@ -287,6 +287,7 @@ def mlm_masking(
     mask_token: int,
     pad_token: int = -1,
     ignore_index: int = -100,
+    special_tokens: Optional[Any] = None,
     np_rng=np.random.default_rng(),
 ) -> tuple[np.ndarray, np.ndarray]:
     """
@@ -312,28 +313,41 @@ def mlm_masking(
         mask_prob (float): probability of initially masking a token, in the first "wave" of masking
         mask_token (int): token to use for masking
         ignore_index (int): the token indicating that position should be ignored during training. We call it `ignore_index` to conform to the API of the cross entropy loss function.
+        special_tokens (Optional[Any]): a list or array of token IDs that should be excluded from masking.
 
     Returns:
         tuple[np.array,np.array]: (masked_seq, labels)
             masked_seq: the input seq with some tokens replaced by `mask_token`
             labels: the original input seq with non-masked tokens replaced by `ignore_index`
     """
     # Create labels
     labels = np.where(seq == pad_token, ignore_index, seq)
 
-    # Create a single mask
+    # Create a boolean mask indicating tokens eligible for masking.
+    # Exclude special tokens and pad tokens from being considered.
+    eligible_mask = np.ones_like(seq, dtype=bool)
+    if special_tokens is not None:
+        # Exclude tokens that are in the special_tokens list.
+        eligible_mask &= ~np.isin(seq, special_tokens)
+    if pad_token is not None:
+        # Exclude pad tokens.
+        eligible_mask &= seq != pad_token
+
+    # Generate a random array and set non-eligible positions to 1.0 so they never meet the mask_prob threshold.
     rand = np_rng.random(seq.shape)
+    rand = np.where(eligible_mask, rand, 1.0)
+
+    # Select masking candidates with the specified mask_prob probability (applied only to eligible tokens).
+    mask_candidates = rand < mask_prob
 
-    # Partition the probability space appropriately using a single mask
+    # Generate another random array to partition the mask candidates into 80/10/10 groups.
+    branch_rand = np_rng.random(seq.shape)
     # 80% of the time, we mask the token
-    mask_mask = rand < mask_prob * 0.8
+    mask_mask = mask_candidates & (branch_rand < 0.8)
     # 10% of the time, we replace the token with a random token
-    random_mask = (rand >= mask_prob * 0.8) & (rand < mask_prob * 0.9)
+    random_mask = mask_candidates & (branch_rand >= 0.8) & (branch_rand < 0.9)
     # 10% of the time, we keep the token the same
-    keep_mask = (rand >= mask_prob * 0.9) & (rand < mask_prob)
+    keep_mask = mask_candidates & (branch_rand >= 0.9)
 
     # We only compute loss over the tokens marked for masking
-    labels = np.where(mask_mask | random_mask | keep_mask, labels, ignore_index)
+    labels = np.where(mask_candidates, seq, ignore_index)
 
     # Apply masking
     seq = np.where(mask_mask, mask_token, seq)
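As a sanity check, the eligibility and 80/10/10 partition logic in this diff can be exercised in isolation. The sketch below is a hypothetical standalone re-implementation of just those steps (the function name `mlm_mask_sketch` and the concrete token IDs are assumptions for illustration, not part of `src/sequence_packer.py`). Setting `mask_prob=1.0` makes every eligible position a candidate, so the candidate mask directly exposes which positions were excluded:

```python
import numpy as np

def mlm_mask_sketch(seq, mask_prob, special_tokens=None, pad_token=-1, seed=0):
    # Hypothetical standalone version of the candidate-selection and
    # branch-partition steps from this PR, for illustration only.
    rng = np.random.default_rng(seed)
    eligible = np.ones_like(seq, dtype=bool)
    if special_tokens is not None:
        eligible &= ~np.isin(seq, special_tokens)  # never mask special tokens
    if pad_token is not None:
        eligible &= seq != pad_token               # never mask padding
    # Ineligible positions are set to 1.0, which can never be < mask_prob.
    rand = np.where(eligible, rng.random(seq.shape), 1.0)
    candidates = rand < mask_prob
    # Independent draw partitions candidates into 80% mask / 10% random / 10% keep.
    branch = rng.random(seq.shape)
    mask_mask = candidates & (branch < 0.8)
    random_mask = candidates & (branch >= 0.8) & (branch < 0.9)
    keep_mask = candidates & (branch >= 0.9)
    return candidates, mask_mask, random_mask, keep_mask

# 101/102 stand in for special tokens (e.g. CLS/SEP), -1 is the pad token.
seq = np.array([101, 5, 6, 7, 102, -1, -1])
candidates, m, r, k = mlm_mask_sketch(seq, mask_prob=1.0, special_tokens=[101, 102])

# The three branches are disjoint and exactly cover the candidate set,
# and special/pad positions are never candidates, regardless of seed.
assert not (m & r).any() and not (m & k).any() and not (r & k).any()
assert ((m | r | k) == candidates).all()
print(candidates.tolist())  # → [False, True, True, True, False, False, False]
```

This also illustrates why the PR switches from thresholding a single random array to a separate `branch_rand` draw: the branch split stays an exact 80/10/10 partition of the candidates even after ineligible positions have had their `rand` values overwritten with 1.0.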