10 changes: 8 additions & 2 deletions chapter_natural-language-processing-pretraining/bert.md
@@ -196,6 +196,7 @@ class BERTEncoder(nn.Block):
     def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens, num_heads,
                  num_blks, dropout, max_len=1000, **kwargs):
         super(BERTEncoder, self).__init__(**kwargs)
+        self.num_hiddens = num_hiddens
         self.token_embedding = nn.Embedding(vocab_size, num_hiddens)
         self.segment_embedding = nn.Embedding(2, num_hiddens)
         self.blks = nn.Sequential()
@@ -210,7 +211,9 @@ class BERTEncoder(nn.Block):
     def forward(self, tokens, segments, valid_lens):
         # Shape of `X` remains unchanged in the following code snippet:
         # (batch size, max sequence length, `num_hiddens`)
-        X = self.token_embedding(tokens) + self.segment_embedding(segments)
+        # The embedding values are multiplied by the square root of the
+        # embedding dimension to rescale them before they are summed up
+        X = self.token_embedding(tokens) * math.sqrt(self.num_hiddens) + self.segment_embedding(segments)
         X = X + self.pos_embedding.data(ctx=X.ctx)[:, :X.shape[1], :]
         for blk in self.blks:
             X = blk(X, valid_lens)
@@ -225,6 +228,7 @@ class BERTEncoder(nn.Module):
     def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens, num_heads,
                  num_blks, dropout, max_len=1000, **kwargs):
         super(BERTEncoder, self).__init__(**kwargs)
+        self.num_hiddens = num_hiddens
         self.token_embedding = nn.Embedding(vocab_size, num_hiddens)
         self.segment_embedding = nn.Embedding(2, num_hiddens)
         self.blks = nn.Sequential()
@@ -239,7 +243,9 @@ class BERTEncoder(nn.Module):
     def forward(self, tokens, segments, valid_lens):
         # Shape of `X` remains unchanged in the following code snippet:
         # (batch size, max sequence length, `num_hiddens`)
-        X = self.token_embedding(tokens) + self.segment_embedding(segments)
+        # The embedding values are multiplied by the square root of the
+        # embedding dimension to rescale them before they are summed up
+        X = self.token_embedding(tokens) * math.sqrt(self.num_hiddens) + self.segment_embedding(segments)
         X = X + self.pos_embedding[:, :X.shape[1], :]
         for blk in self.blks:
             X = blk(X, valid_lens)
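For context (not part of the diff): scaling token embeddings by the square root of the embedding dimension follows the convention from the original Transformer, which multiplies embeddings by sqrt(d_model) before adding positional information. Below is a minimal, runnable sketch of the effect; the sizes are hypothetical, it uses PyTorch's default `nn.Embedding` initialization, and it assumes `math` is imported in `bert.md`, since the new line calls `math.sqrt`.

```python
import math

import torch
from torch import nn

# Hypothetical sizes, for illustration only
vocab_size, num_hiddens = 10000, 768

token_embedding = nn.Embedding(vocab_size, num_hiddens)
segment_embedding = nn.Embedding(2, num_hiddens)

tokens = torch.tensor([[8, 15, 16, 23]])  # one sequence of 4 token ids
segments = torch.tensor([[0, 0, 1, 1]])   # first/second-sentence markers

unscaled = token_embedding(tokens)
scaled = unscaled * math.sqrt(num_hiddens)  # sqrt(768) is roughly 27.7

# Sum the rescaled token embeddings with the segment embeddings,
# mirroring the updated line in `BERTEncoder.forward`
X = scaled + segment_embedding(segments)
print(f'mean |unscaled token embedding|: {unscaled.abs().mean():.3f}')
print(f'mean |scaled token embedding|:   {scaled.abs().mean():.3f}')
```

With PyTorch's default unit-variance initialization, the scaled token term comes out roughly sqrt(num_hiddens) times larger than the segment term, which is the rescaling the new comment describes.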