Commit 8be7f0e (parent: 1584032)

TD-BERT Implementation

TD-BERT implementation from songyouwei#147
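As the new models/td_bert.py below shows, the model runs BERT over the full input sequence, max-pools the encoder's hidden states over the target (aspect) tokens, and classifies the pooled vector through dropout and a linear layer. data_utils.py is extended to emit the two lengths (left_context_len, aspect_len) that locate the aspect span, and train.py registers the model and its input columns.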

3 files changed: +56 -0 lines

data_utils.py (2 additions, 0 deletions)
@@ -205,6 +205,8 @@ def __init__(self, fname, tokenizer):
                 'aspect_boundary': aspect_boundary,
                 'dependency_graph': dependency_graph,
                 'polarity': polarity,
+                'left_context_len': left_context_len,
+                'aspect_len': aspect_len,
             }
 
             all_data.append(data)
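The two new fields tell the model where the aspect sits in the flat BERT input: with the leading [CLS] token at position 0, the aspect tokens occupy positions left_context_len + 1 through left_context_len + aspect_len. A minimal illustration (the sentence and values are made up, not taken from the dataset):

# Illustrative only: locating the aspect span "battery life" inside
# "[CLS] the battery life is great [SEP]".
tokens = ['[CLS]', 'the', 'battery', 'life', 'is', 'great', '[SEP]']
left_context_len = 1   # tokens left of the aspect, excluding [CLS]
aspect_len = 2         # 'battery' and 'life'

start = left_context_len + 1  # +1 skips the [CLS] token
end = start + aspect_len
assert tokens[start:end] == ['battery', 'life']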

models/td_bert.py (new file, 50 additions)
@@ -0,0 +1,50 @@
+# -*- coding: utf-8 -*-
+# file: td_bert.py
+# author: xiangpan <[email protected]>
+# Copyright (C) 2020. All Rights Reserved.
+import torch
+import torch.nn as nn
+from layers.attention import Attention  # note: unused in this module
+
+
+class TD_BERT(nn.Module):
+    def __init__(self, bert, opt):
+        super(TD_BERT, self).__init__()
+        self.bert = bert
+        self.dropout = nn.Dropout(opt.dropout)
+        self.opt = opt
+        self.dense = nn.Linear(opt.bert_dim, opt.polarities_dim)
+
+    def forward(self, inputs):
+        text_bert_indices, bert_segments_ids, left_context_len, aspect_len = (
+            inputs[0],
+            inputs[1],
+            inputs[2],
+            inputs[3],
+        )
+
+        encoded_layers, cls_output = self.bert(
+            text_bert_indices, bert_segments_ids
+        )
+
+        # max-pool the hidden states of the aspect tokens, sample by sample
+        pooled_list = []
+        for i in range(encoded_layers.shape[0]):  # i-th sample of the batch
+            encoded_layers_i = encoded_layers[i]
+            left_context_len_i = left_context_len[i]
+            aspect_len_i = aspect_len[i]
+            e_list = []
+            if aspect_len_i == 0:  # empty aspect span: fall back to [CLS]
+                e_list.append(encoded_layers_i[0])
+            for j in range(left_context_len_i + 1, left_context_len_i + 1 + aspect_len_i):
+                e_list.append(encoded_layers_i[j])  # aspect token hidden state ([CLS] is at 0)
+            e = torch.stack(e_list, 0)
+            embed = e.unsqueeze(0)  # (1, aspect_len, bert_dim)
+            pooled = nn.functional.max_pool2d(embed, (embed.size(1), 1)).squeeze(1)
+            pooled_list.append(pooled)
+        pooled_output = torch.cat(pooled_list)  # (batch_size, bert_dim)
+        pooled_output = self.dropout(pooled_output)
+
+        logits = self.dense(pooled_output)
+
+        return logits
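As a quick shape sanity check, here is a hedged usage sketch. It assumes models/td_bert.py is importable, that bert follows the pytorch-pretrained-bert convention of returning (sequence_output, pooled_output), which the unpacking in forward implies, and it substitutes a made-up FakeBert and Opt for the real encoder and argparse options:

import torch
import torch.nn as nn
from models.td_bert import TD_BERT

class FakeBert(nn.Module):
    # Stand-in encoder with the (sequence_output, pooled_output)
    # return convention that TD_BERT.forward unpacks; not real BERT.
    def __init__(self, dim=768):
        super().__init__()
        self.dim = dim

    def forward(self, indices, segments):
        batch, seq_len = indices.shape
        seq_out = torch.randn(batch, seq_len, self.dim)
        return seq_out, seq_out[:, 0]

class Opt:
    # Hypothetical stand-in for the repo's argparse namespace.
    dropout = 0.1
    bert_dim = 768
    polarities_dim = 3

model = TD_BERT(FakeBert(), Opt())
text = torch.zeros(2, 10, dtype=torch.long)
segments = torch.zeros(2, 10, dtype=torch.long)
left_context_len = torch.tensor([2, 3])  # tokens left of each aspect
aspect_len = torch.tensor([1, 2])        # aspect span lengths
logits = model([text, segments, left_context_len, aspect_len])
print(logits.shape)  # torch.Size([2, 3])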

train.py (4 additions, 0 deletions)
@@ -26,6 +26,7 @@
 from models.bert_spc import BERT_SPC
 from models.albert_spc import ALBERT_SPC
 from models.roberta_spc import ROBERTA_SPC
+from models.td_bert import TD_BERT
 
 logger = logging.getLogger()
 logger.setLevel(logging.INFO)
@@ -249,6 +250,8 @@ def main():
         'lcf_bert': LCF_BERT,
         'albert_spc': ALBERT_SPC,
         'roberta_spc': ROBERTA_SPC,
+        'td_bert': TD_BERT,
+
     # default hyper-parameters for the LCF-BERT model are as follows:
     # lr: 2e-5
     # l2: 1e-5
@@ -298,6 +301,7 @@ def main():
         'aen_bert': ['text_bert_indices', 'aspect_bert_indices'],
         'aen_bert': ['text_bert_indices', 'aspect_bert_indices'],
         'lcf_bert': ['concat_bert_indices', 'concat_segments_indices', 'text_bert_indices', 'aspect_bert_indices'],
+        'td_bert': ['text_bert_indices', 'bert_segments_ids', 'left_context_len', 'aspect_len'],
     }
     initializers = {
         'xavier_uniform_': torch.nn.init.xavier_uniform_,
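With these three changes the model is selectable by its dictionary key. Assuming this train.py keeps the upstream repo's usual --model_name switch (the argument parsing is not visible in this diff), training would be invoked along the lines of python train.py --model_name td_bert.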
