Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions Config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Configuration for the BERT meta-learning experiments (paths + hyper-parameters).

# Local filesystem path to the pretrained bert-base-uncased model.
model_path = 'E:/hugging_face_model_all/bert-base-uncased'

# Amazon Review dataset: JSON list of review dicts with a 'domain' field.
data_path = 'Amazon Review/train.json'

# Directory where model checkpoints are saved.
save_model_path = 'model_files/'

# NOTE(review): not referenced by the new training loop (which uses meta_epoch) — confirm before removing.
epochs = 2

# Number of meta-tasks sampled for each training epoch.
train_num_task = 500

# Number of meta-tasks used for evaluation.
test_num_task = 5

# Support-set size per task (inner-loop training examples).
k_support = 80

# Query-set size per task (inner-loop evaluation examples).
k_query = 20

# Number of output classes (2 => binary classification).
num_labels = 2

# Text file where training/test accuracy is appended.
log_path = 'logs/log.txt'

# Number of outer (meta) training epochs.
meta_epoch = 10

# Number of tasks per meta-update batch.
outer_batch_size = 2

# Mini-batch size used inside the inner loop.
inner_batch_size = 12

# Learning rate for the outer (meta) update.
outer_update_lr = 5e-5

# Learning rate for task-specific (inner) updates.
inner_update_lr = 5e-5

# Inner-loop gradient steps during training.
inner_update_step = 10

# Inner-loop gradient steps at evaluation time.
inner_update_step_eval = 40
2 changes: 1 addition & 1 deletion Interactive.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@
" fast_model.eval()\n",
" with torch.no_grad():\n",
" query_dataloader = DataLoader(query, sampler=None, batch_size=len(query))\n",
" query_batch = iter(query_dataloader).next()\n",
" query_batch = iter(query_dataloader).__next__()\n",
" query_batch = tuple(t.to(self.device) for t in query_batch)\n",
" q_input_ids, q_attention_mask, q_segment_ids, q_label_id = query_batch\n",
" q_outputs = fast_model(q_input_ids, q_attention_mask, q_segment_ids, labels = q_label_id)\n",
Expand Down
29 changes: 29 additions & 0 deletions datas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import pandas as pd
import Config as cg
import json
from random import shuffle
from collections import Counter


def load_data(path=None):
    """Load the review dataset from a JSON file.

    Args:
        path: Optional path to the JSON file. Defaults to ``cg.data_path``.

    Returns:
        The deserialized JSON content (expected: a list of review dicts).
    """
    if path is None:
        path = cg.data_path
    # Context manager closes the handle promptly; the original left the
    # file open until garbage collection. JSON is UTF-8 by specification.
    with open(path, encoding='utf-8') as f:
        return json.load(f)

def res_data(data=None):
    """Split reviews into meta-train and meta-test examples by domain frequency.

    The three least-represented domains are held out as the low-resource
    (test) split; all remaining domains form the training split.

    Args:
        data: Optional list of review dicts with a 'domain' key. Loaded via
            ``load_data()`` when omitted.

    Returns:
        ``(train_examples, test_examples)`` tuple of lists.
    """
    if data is None:
        data = load_data()

    domain_counts = Counter(r['domain'] for r in data)
    # Rank domains by frequency (descending); the last three are the rarest.
    # (Original re-sliced the sorted list on every loop iteration.)
    ranked = sorted(domain_counts.items(), key=lambda item: item[1], reverse=True)
    low_resource_domains = {domain for domain, _ in ranked[-3:]}

    train_examples = [r for r in data if r['domain'] not in low_resource_domains]
    test_examples = [r for r in data if r['domain'] in low_resource_domains]

    return train_examples, test_examples
110 changes: 110 additions & 0 deletions logs/log.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
epoch:0Step:0 acc:0.925
Test0.89
epoch:0Step:0 acc:0.825
Test0.9
epoch:0Step:1 acc:0.725
epoch:0Step:2 acc:0.875
epoch:0Step:3 acc:0.875
epoch:0Step:4 acc:0.875
epoch:0Step:5 acc:0.825
epoch:0Step:6 acc:0.95
epoch:0Step:7 acc:0.9
epoch:0Step:8 acc:0.8999999999999999
epoch:0Step:9 acc:0.925
epoch:0Step:10 acc:0.8500000000000001
epoch:0Step:11 acc:0.975
epoch:0Step:12 acc:0.8999999999999999
epoch:0Step:13 acc:0.85
epoch:0Step:14 acc:0.85
epoch:0Step:15 acc:0.975
epoch:0Step:16 acc:0.8500000000000001
epoch:0Step:17 acc:0.975
epoch:0Step:18 acc:0.925
epoch:0Step:19 acc:0.85
epoch:0Step:20 acc:0.875
Test0.9
epoch:0Step:21 acc:0.825
epoch:0Step:22 acc:0.925
epoch:0Step:23 acc:0.975
epoch:0Step:24 acc:0.95
epoch:0Step:25 acc:0.825
epoch:0Step:26 acc:1.0
epoch:0Step:27 acc:0.875
epoch:0Step:28 acc:0.875
epoch:0Step:29 acc:0.8999999999999999
epoch:0Step:30 acc:0.8999999999999999
epoch:0Step:31 acc:0.95
epoch:0Step:32 acc:0.95
epoch:0Step:33 acc:1.0
epoch:0Step:34 acc:0.8500000000000001
epoch:0Step:35 acc:0.975
epoch:0Step:36 acc:0.9
epoch:0Step:37 acc:0.975
epoch:0Step:38 acc:0.95
epoch:0Step:39 acc:0.8
epoch:0Step:40 acc:0.95
Test0.93
epoch:0Step:41 acc:0.925
epoch:0Step:42 acc:0.975
epoch:0Step:43 acc:0.9
epoch:0Step:44 acc:0.95
epoch:0Step:45 acc:0.925
epoch:0Step:46 acc:0.9
epoch:0Step:47 acc:0.925
epoch:0Step:48 acc:0.9
epoch:0Step:49 acc:0.925
epoch:0Step:50 acc:0.95
epoch:0Step:51 acc:0.925
epoch:0Step:52 acc:0.975
epoch:0Step:53 acc:0.975
epoch:0Step:54 acc:0.925
epoch:0Step:55 acc:0.925
epoch:0Step:56 acc:0.875
epoch:0Step:57 acc:0.975
epoch:0Step:58 acc:0.875
epoch:0Step:59 acc:0.95
epoch:0Step:60 acc:0.8999999999999999
Test0.93
epoch:0Step:61 acc:0.875
epoch:0Step:62 acc:0.925
epoch:0Step:63 acc:0.875
epoch:0Step:64 acc:0.95
epoch:0Step:65 acc:0.925
epoch:0Step:66 acc:0.925
epoch:0Step:67 acc:0.825
epoch:0Step:68 acc:0.925
epoch:0Step:69 acc:0.95
epoch:0Step:70 acc:1.0
epoch:0Step:71 acc:0.95
epoch:0Step:72 acc:0.875
epoch:0Step:73 acc:0.975
epoch:0Step:74 acc:0.95
epoch:0Step:75 acc:0.975
epoch:0Step:76 acc:0.95
epoch:0Step:77 acc:0.925
epoch:0Step:78 acc:0.925
epoch:0Step:79 acc:0.875
epoch:0Step:80 acc:0.95
Test0.96
epoch:0Step:81 acc:0.925
epoch:0Step:82 acc:1.0
epoch:0Step:83 acc:0.875
epoch:0Step:84 acc:0.95
epoch:0Step:85 acc:1.0
epoch:0Step:86 acc:0.975
epoch:0Step:87 acc:0.925
epoch:0Step:88 acc:0.975
epoch:0Step:89 acc:1.0
epoch:0Step:90 acc:0.875
epoch:0Step:91 acc:0.95
epoch:0Step:92 acc:0.95
epoch:0Step:93 acc:0.975
epoch:0Step:94 acc:0.925
epoch:0Step:95 acc:1.0
epoch:0Step:96 acc:0.95
epoch:0Step:97 acc:0.975
epoch:0Step:98 acc:0.875
epoch:0Step:99 acc:0.95
epoch:0Step:100 acc:0.925
Test0.95
epoch:0Step:101 acc:0.925
140 changes: 44 additions & 96 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,109 +1,38 @@
import json
from random import shuffle
from collections import Counter
import torch
from training import *
from datas import res_data
from meta import *
from meta_learner import *
from transformers import BertModel, BertTokenizer
import time
import logging
import argparse
import os
logger = logging.getLogger()
logger.setLevel(logging.CRITICAL)
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
from reptile import Learner
from task import MetaTask
import random
import numpy as np

def random_seed(value):
torch.backends.cudnn.deterministic=True
torch.manual_seed(value)
torch.cuda.manual_seed(value)
np.random.seed(value)
random.seed(value)

def create_batch_of_tasks(taskset, is_shuffle = True, batch_size = 4):
idxs = list(range(0,len(taskset)))
if is_shuffle:
random.shuffle(idxs)
for i in range(0,len(idxs), batch_size):
yield [taskset[idxs[i]] for i in range(i, min(i + batch_size,len(taskset)))]

def main():

parser = argparse.ArgumentParser()

parser.add_argument("--data", default='dataset.json', type=str,
help="Path to dataset file")

parser.add_argument("--bert_model", default='bert-base-uncased', type=str,
help="Path to bert model")

parser.add_argument("--num_labels", default=2, type=int,
help="Number of class for classification")
import Config as cg

parser.add_argument("--epoch", default=5, type=int,
help="Number of outer interation")

parser.add_argument("--k_spt", default=80, type=int,
help="Number of support samples per task")

parser.add_argument("--k_qry", default=20, type=int,
help="Number of query samples per task")

parser.add_argument("--outer_batch_size", default=2, type=int,
help="Batch of task size")

parser.add_argument("--inner_batch_size", default=12, type=int,
help="Training batch size in inner iteration")

parser.add_argument("--outer_update_lr", default=5e-5, type=float,
help="Meta learning rate")

parser.add_argument("--inner_update_lr", default=5e-5, type=float,
help="Inner update learning rate")

parser.add_argument("--inner_update_step", default=10, type=int,
help="Number of interation in the inner loop during train time")
# tokenizer = BertTokenizer.from_pretrained(cg.model_path, do_lower_case = True)
# train = MetaTask(train_examples, num_task = 13, k_support=10, k_query=2, tokenizer = tokenizer)

parser.add_argument("--inner_update_step_eval", default=40, type=int,
help="Number of interation in the inner loop during test time")

parser.add_argument("--num_task_train", default=500, type=int,
help="Total number of meta tasks for training")

parser.add_argument("--num_task_test", default=3, type=int,
help="Total number of tasks for testing")

args = parser.parse_args()

reviews = json.load(open(args.data))
low_resource_domains = ["office_products", "automotive", "computer_&_video_games"]
def strat_training(train_examples,tokenizer,test):

train_examples = [r for r in reviews if r['domain'] not in low_resource_domains]
test_examples = [r for r in reviews if r['domain'] in low_resource_domains]
print(len(train_examples), len(test_examples))
args = TrainingArgs()

tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case = True)
learner = Learner(args)

test = MetaTask(test_examples, num_task = args.num_task_test, k_support=args.k_spt,
k_query=args.k_qry, tokenizer = tokenizer)

global_step = 0
for epoch in range(args.epoch):

train = MetaTask(train_examples, num_task = args.num_task_train, k_support=args.k_spt,
k_query=args.k_qry, tokenizer = tokenizer)
best_model_accuracy = 0.0

for epoch in range(args.meta_epoch):

train = MetaTask(train_examples, num_task = cg.train_num_task, k_support=cg.k_support, k_query=cg.k_query, tokenizer = tokenizer)
db = create_batch_of_tasks(train, is_shuffle = True, batch_size = args.outer_batch_size)

for step, task_batch in enumerate(db):


f = open(cg.log_path, 'a')

acc = learner(task_batch)

print('Step:', step, '\ttraining Acc:', acc)


print('epoch:',epoch,'Step:', step, '\ttraining Acc:', acc)
f.write("epoch:"+str(epoch)+" Step:"+str(step)+" acc:"+str(acc) + '\n')

if global_step % 20 == 0:
random_seed(123)
print("\n-----------------Testing Mode-----------------\n")
Expand All @@ -115,10 +44,29 @@ def main():
acc_all_test.append(acc)

print('Step:', step, 'Test F1:', np.mean(acc_all_test))

f.write('Test' + str(np.mean(acc_all_test)) + '\n')

random_seed(int(time.time() % 10))

if acc > best_model_accuracy:
best_model_accuracy = acc
best_model_path = cg.save_model_path+'epoch'+str(epoch)+'_step'+str(step)+'.pth'
# 保存模型
torch.save(learner.model, best_model_path)

global_step += 1

if __name__ == "__main__":
main()
f.close()

def run():
    """Prepare data and tokenizer, build the meta-test tasks, and start training."""
    train_set, held_out = res_data()

    bert_tokenizer = BertTokenizer.from_pretrained(cg.model_path, do_lower_case=True)

    eval_tasks = MetaTask(
        held_out,
        num_task=cg.test_num_task,
        k_support=cg.k_support,
        k_query=cg.k_query,
        tokenizer=bert_tokenizer,
    )

    strat_training(train_set, bert_tokenizer, eval_tasks)


if __name__ == '__main__':
run()

2 changes: 1 addition & 1 deletion maml.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def forward(self, batch_tasks, training = True):
print("Inner Loss: ", np.mean(all_loss))

query_dataloader = DataLoader(query, sampler=None, batch_size=len(query))
query_batch = iter(query_dataloader).next()
query_batch = iter(query_dataloader).__next__()
query_batch = tuple(t.to(self.device) for t in query_batch)
q_input_ids, q_attention_mask, q_segment_ids, q_label_id = query_batch
q_outputs = fast_model(q_input_ids, q_attention_mask, q_segment_ids, labels = q_label_id)
Expand Down
Loading