Commit 0fce551

Adding Whole Word Masking
1 parent d66a146 commit 0fce551

2 files changed (+96 / -19 lines)

README.md

Lines changed: 52 additions & 2 deletions
@@ -1,10 +1,56 @@
 # BERT
 
+**\*\*\*\*\* New May 31st, 2019: Whole Word Masking Models \*\*\*\*\***
+
+This is a release of several new models which were the result of an improvement
+to the pre-processing code.
+
+In the original pre-processing code, we randomly select WordPiece tokens to
+mask. For example:
+
+`Input Text: the man jumped up , put his basket on phil ##am ##mon ' s head`
+`Original Masked Input: [MASK] man [MASK] up , put his [MASK] on phil
+[MASK] ##mon ' s head`
+
+The new technique is called Whole Word Masking. In this case, we always mask
+*all* of the tokens corresponding to a word at once. The overall masking
+rate remains the same.
+
+`Whole Word Masked Input: the man [MASK] up , put his basket on [MASK] [MASK]
+[MASK] ' s head`
+
+The training is identical -- we still predict each masked WordPiece token
+independently. The improvement comes from the fact that the original prediction
+task was too 'easy' for words that had been split into multiple WordPieces.
+
+This can be enabled during data generation by passing the flag
+`--do_whole_word_mask=True` to `create_pretraining_data.py`.
+
+Pre-trained models with Whole Word Masking are linked below. The data and
+training were otherwise identical, and the models have identical structure and
+vocab to the original models. We only include BERT-Large models. When using
+these models, please make it clear in the paper that you are using the Whole
+Word Masking variant of BERT-Large.
+
+*   **[`BERT-Large, Uncased (Whole Word Masking)`](https://storage.googleapis.com/bert_models/2019_05_30/wwm_uncased_L-24_H-1024_A-16.zip)**:
+    24-layer, 1024-hidden, 16-heads, 340M parameters
+
+*   **[`BERT-Large, Cased (Whole Word Masking)`](https://storage.googleapis.com/bert_models/2019_05_30/wwm_cased_L-24_H-1024_A-16.zip)**:
+    24-layer, 1024-hidden, 16-heads, 340M parameters
+
+Model                                    | SQuAD 1.1 F1/EM | Multi NLI Accuracy
+---------------------------------------- | :-------------: | :----------------:
+BERT-Large, Uncased (Original)           | 91.0/84.3       | 86.05
+BERT-Large, Uncased (Whole Word Masking) | 92.8/86.7       | 87.07
+BERT-Large, Cased (Original)             | 91.5/84.8       | 86.09
+BERT-Large, Cased (Whole Word Masking)   | 92.9/86.7       | 86.46
+
 **\*\*\*\*\* New February 7th, 2019: TfHub Module \*\*\*\*\***
 
 BERT has been uploaded to [TensorFlow Hub](https://tfhub.dev). See
-`run_classifier_with_tfhub.py` for an example of how to use the TF Hub module,
-or run an example in the browser on [Colab](https://colab.sandbox.google.com/github/google-research/bert/blob/master/predicting_movie_reviews_with_bert_on_tf_hub.ipynb).
+`run_classifier_with_tfhub.py` for an example of how to use the TF Hub module,
+or run an example in the browser on
+[Colab](https://colab.sandbox.google.com/github/google-research/bert/blob/master/predicting_movie_reviews_with_bert_on_tf_hub.ipynb).
 
 **\*\*\*\*\* New November 23rd, 2018: Un-normalized multilingual model + Thai +
 Mongolian \*\*\*\*\***
@@ -225,6 +271,10 @@ using your own script.)**
 
 The links to the models are here (right-click, 'Save link as...' on the name):
 
+*   **[`BERT-Large, Uncased (Whole Word Masking)`](https://storage.googleapis.com/bert_models/2019_05_30/wwm_uncased_L-24_H-1024_A-16.zip)**:
+    24-layer, 1024-hidden, 16-heads, 340M parameters
+*   **[`BERT-Large, Cased (Whole Word Masking)`](https://storage.googleapis.com/bert_models/2019_05_30/wwm_cased_L-24_H-1024_A-16.zip)**:
+    24-layer, 1024-hidden, 16-heads, 340M parameters
 *   **[`BERT-Base, Uncased`](https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip)**:
     12-layer, 768-hidden, 12-heads, 110M parameters
 *   **[`BERT-Large, Uncased`](https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-24_H-1024_A-16.zip)**:
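To make the README's description concrete, here is a minimal, standalone sketch of the idea: WordPieces are grouped into whole words via the `##` continuation prefix, and a chosen word is always masked in full. This is an illustration only, not the commit's code (that lives in `create_pretraining_data.py`, diffed below); the 80%/10%/10% replacement rule is deliberately omitted, and the `masking_rate` and seed are arbitrary.

```python
import random


def toy_whole_word_mask(tokens, masking_rate=0.25, seed=12345):
  """Toy whole-word masking: every WordPiece of a chosen word is masked.

  Unlike the real script, this always substitutes [MASK] and skips the
  80%/10%/10% replacement rule.
  """
  rng = random.Random(seed)

  # Group WordPiece indexes into whole words: a piece starting with "##"
  # continues the word started by the piece before it.
  word_groups = []
  for i, token in enumerate(tokens):
    if token.startswith("##") and word_groups:
      word_groups[-1].append(i)
    else:
      word_groups.append([i])

  # Mask whole words until roughly masking_rate of the WordPieces are covered.
  num_to_mask = max(1, int(round(len(tokens) * masking_rate)))
  rng.shuffle(word_groups)
  output = list(tokens)
  masked = 0
  for group in word_groups:
    if masked + len(group) > num_to_mask:
      continue  # A partial word is never masked.
    for i in group:
      output[i] = "[MASK]"
    masked += len(group)
  return output


tokens = "the man jumped up , put his basket on phil ##am ##mon ' s head".split()
print(toy_whole_word_mask(tokens))
# "phil ##am ##mon" is either masked as a unit ([MASK] [MASK] [MASK]) or left
# fully intact; no word is ever partially masked.
```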

create_pretraining_data.py

Lines changed: 44 additions & 17 deletions
@@ -42,6 +42,10 @@
     "Whether to lower case the input text. Should be True for uncased "
     "models and False for cased models.")
 
+flags.DEFINE_bool(
+    "do_whole_word_mask", False,
+    "Whether to use whole word masking rather than per-WordPiece masking.")
+
 flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")
 
 flags.DEFINE_integer("max_predictions_per_seq", 20,
@@ -343,7 +347,20 @@ def create_masked_lm_predictions(tokens, masked_lm_prob,
   for (i, token) in enumerate(tokens):
     if token == "[CLS]" or token == "[SEP]":
       continue
-    cand_indexes.append(i)
+    # Whole Word Masking means that we mask all of the WordPieces
+    # corresponding to an original word. When a word has been split into
+    # WordPieces, the first token does not have any marker and any subsequent
+    # tokens are prefixed with ##. So whenever we see the ## token, we
+    # append it to the previous set of word indexes.
+    #
+    # Note that Whole Word Masking does *not* change the training code
+    # at all -- we still predict each WordPiece independently, softmaxed
+    # over the entire vocabulary.
+    if (FLAGS.do_whole_word_mask and len(cand_indexes) >= 1 and
+        token.startswith("##")):
+      cand_indexes[-1].append(i)
+    else:
+      cand_indexes.append([i])
 
   rng.shuffle(cand_indexes)
 
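For orientation, here is a small standalone trace of the grouping this hunk introduces. The `group_candidates` helper below is an illustrative re-statement of the loop above, with the `FLAGS` lookup replaced by a plain parameter; it is not code from the commit.

```python
def group_candidates(tokens, whole_word_mask):
  """Mirrors the candidate-index grouping in the hunk above (illustration)."""
  cand_indexes = []
  for i, token in enumerate(tokens):
    if token in ("[CLS]", "[SEP]"):
      continue
    if whole_word_mask and cand_indexes and token.startswith("##"):
      cand_indexes[-1].append(i)
    else:
      cand_indexes.append([i])
  return cand_indexes


tokens = ["[CLS]", "the", "phil", "##am", "##mon", "'", "s", "head", "[SEP]"]
print(group_candidates(tokens, whole_word_mask=True))
# [[1], [2, 3, 4], [5], [6], [7]]  -- one candidate per whole word
print(group_candidates(tokens, whole_word_mask=False))
# [[1], [2], [3], [4], [5], [6], [7]]  -- one candidate per WordPiece
```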

@@ -354,29 +371,39 @@ def create_masked_lm_predictions(tokens, masked_lm_prob,
 
   masked_lms = []
   covered_indexes = set()
-  for index in cand_indexes:
+  for index_set in cand_indexes:
     if len(masked_lms) >= num_to_predict:
       break
-    if index in covered_indexes:
+    # If adding a whole-word mask would exceed the maximum number of
+    # predictions, then just skip this candidate.
+    if len(masked_lms) + len(index_set) > num_to_predict:
+      continue
+    is_any_index_covered = False
+    for index in index_set:
+      if index in covered_indexes:
+        is_any_index_covered = True
+        break
+    if is_any_index_covered:
       continue
-    covered_indexes.add(index)
+    for index in index_set:
+      covered_indexes.add(index)
 
-    masked_token = None
-    # 80% of the time, replace with [MASK]
-    if rng.random() < 0.8:
-      masked_token = "[MASK]"
-    else:
-      # 10% of the time, keep original
-      if rng.random() < 0.5:
-        masked_token = tokens[index]
-      # 10% of the time, replace with random word
+      masked_token = None
+      # 80% of the time, replace with [MASK]
+      if rng.random() < 0.8:
+        masked_token = "[MASK]"
       else:
-        masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
-
-    output_tokens[index] = masked_token
+        # 10% of the time, keep original
+        if rng.random() < 0.5:
+          masked_token = tokens[index]
+        # 10% of the time, replace with random word
+        else:
+          masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
 
-    masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
+      output_tokens[index] = masked_token
 
+      masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
+  assert len(masked_lms) <= num_to_predict
   masked_lms = sorted(masked_lms, key=lambda x: x.index)
 
   masked_lm_positions = []
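One consequence of the grouping worth noting: a multi-piece word is either masked in full or skipped entirely, so the final count can fall slightly below `num_to_predict` but never exceed it, which is what the new `assert` checks. A toy walk-through of that budget rule, using hypothetical candidate groups rather than the commit's own data structures:

```python
# Hypothetical shuffled candidates: two single-piece words and one
# three-piece word, with a budget of three predictions.
cand_indexes = [[1], [4, 5, 6], [8]]
num_to_predict = 3

selected = []
for index_set in cand_indexes:
  if len(selected) >= num_to_predict:
    break
  if len(selected) + len(index_set) > num_to_predict:
    continue  # Masking this whole word would overshoot the budget; skip it.
  selected.extend(index_set)

print(selected)  # [1, 8] -- the three-piece word was skipped, not split
assert len(selected) <= num_to_predict
```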
