diff --git a/chapter_natural-language-processing-pretraining/bert.md b/chapter_natural-language-processing-pretraining/bert.md
index 007e32e024..e46979384b 100644
--- a/chapter_natural-language-processing-pretraining/bert.md
+++ b/chapter_natural-language-processing-pretraining/bert.md
@@ -196,6 +196,7 @@ class BERTEncoder(nn.Block):
     def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens, num_heads,
                  num_blks, dropout, max_len=1000, **kwargs):
         super(BERTEncoder, self).__init__(**kwargs)
+        self.num_hiddens = num_hiddens
         self.token_embedding = nn.Embedding(vocab_size, num_hiddens)
         self.segment_embedding = nn.Embedding(2, num_hiddens)
         self.blks = nn.Sequential()
@@ -210,7 +211,9 @@ class BERTEncoder(nn.Block):
     def forward(self, tokens, segments, valid_lens):
         # Shape of `X` remains unchanged in the following code snippet:
         # (batch size, max sequence length, `num_hiddens`)
-        X = self.token_embedding(tokens) + self.segment_embedding(segments)
+        # The token embedding values are rescaled by the square root of the
+        # embedding dimension before the other embeddings are added
+        X = self.token_embedding(tokens) * math.sqrt(self.num_hiddens) + self.segment_embedding(segments)
         X = X + self.pos_embedding.data(ctx=X.ctx)[:, :X.shape[1], :]
         for blk in self.blks:
             X = blk(X, valid_lens)
@@ -225,6 +228,7 @@ class BERTEncoder(nn.Module):
     def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens, num_heads,
                  num_blks, dropout, max_len=1000, **kwargs):
         super(BERTEncoder, self).__init__(**kwargs)
+        self.num_hiddens = num_hiddens
         self.token_embedding = nn.Embedding(vocab_size, num_hiddens)
         self.segment_embedding = nn.Embedding(2, num_hiddens)
         self.blks = nn.Sequential()
@@ -239,7 +243,9 @@ class BERTEncoder(nn.Module):
     def forward(self, tokens, segments, valid_lens):
         # Shape of `X` remains unchanged in the following code snippet:
         # (batch size, max sequence length, `num_hiddens`)
-        X = self.token_embedding(tokens) + self.segment_embedding(segments)
+        # The token embedding values are rescaled by the square root of the
+        # embedding dimension before the other embeddings are added
+        X = self.token_embedding(tokens) * math.sqrt(self.num_hiddens) + self.segment_embedding(segments)
         X = X + self.pos_embedding[:, :X.shape[1], :]
         for blk in self.blks:
             X = blk(X, valid_lens)
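
For reference, the rescaling this patch introduces multiplies only the token embeddings by the square root of the embedding dimension before the segment (and positional) embeddings are added, matching the scaling used in the original Transformer. Below is a minimal, self-contained PyTorch sketch of just that step; the sizes and variable names are illustrative, not taken from the book, and `math.sqrt` requires `import math` to be available in the patched module.

```python
import math

import torch
from torch import nn

# Illustrative sizes, not the book's defaults
vocab_size, num_hiddens = 100, 64
token_embedding = nn.Embedding(vocab_size, num_hiddens)
segment_embedding = nn.Embedding(2, num_hiddens)

# A batch of 2 sequences of length 8, with BERT-style segment ids 0/1
tokens = torch.randint(0, vocab_size, (2, 8))
segments = torch.tensor([[0] * 8, [0] * 4 + [1] * 4])

# Rescale only the token embeddings by sqrt(num_hiddens) before adding
# the segment embeddings, as in the patched forward pass
X = token_embedding(tokens) * math.sqrt(num_hiddens) + segment_embedding(segments)
print(X.shape)  # torch.Size([2, 8, 64])
```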