From 46fdcd0c29fa09789bde6d197c20b1617f66c9b3 Mon Sep 17 00:00:00 2001 From: zwy Date: Thu, 9 Jan 2020 18:45:00 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E8=AE=AD=E7=BB=83=E5=8A=A0=E5=85=A5batch?= =?UTF-8?q?=5Fsize?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- captcha_train.py | 2 +- my_dataset.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/captcha_train.py b/captcha_train.py index dc290bf3a..8eede5398 100644 --- a/captcha_train.py +++ b/captcha_train.py @@ -18,7 +18,7 @@ def main(): optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate) # Train the Model - train_dataloader = my_dataset.get_train_data_loader() + train_dataloader = my_dataset.get_train_data_loader(batch_size=batch_size) for epoch in range(num_epochs): for i, (images, labels) in enumerate(train_dataloader): images = Variable(images) diff --git a/my_dataset.py b/my_dataset.py index 3ce66b577..3d59bb092 100755 --- a/my_dataset.py +++ b/my_dataset.py @@ -30,10 +30,10 @@ def __getitem__(self, idx): transforms.ToTensor(), # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) -def get_train_data_loader(): +def get_train_data_loader(batch_size=64): dataset = mydataset(captcha_setting.TRAIN_DATASET_PATH, transform=transform) - return DataLoader(dataset, batch_size=64, shuffle=True) + return DataLoader(dataset, batch_size=batch_size, shuffle=True) def get_test_data_loader(): dataset = mydataset(captcha_setting.TEST_DATASET_PATH, transform=transform) From 31e4603db36a1c01e541a99f897fd12bad0257cc Mon Sep 17 00:00:00 2001 From: zwy Date: Thu, 9 Jan 2020 18:45:09 +0800 Subject: [PATCH 2/2] fine_tuning --- fine_tuning.py | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 fine_tuning.py diff --git a/fine_tuning.py b/fine_tuning.py new file mode 100644 index 000000000..8856f6a39 --- /dev/null +++ b/fine_tuning.py @@ -0,0 +1,43 @@ +import torch +import torch.nn as nn +from torch.autograd import Variable +import my_dataset +from captcha_cnn_model import CNN + +learning_rate = 0.001 +batch_size = 10 +num_epochs = 15 + + +def fine_tuning(): + cnn = CNN() + cnn.eval() + cnn.load_state_dict(torch.load('model.pkl')) + print("load cnn net.") + criterion = nn.MultiLabelSoftMarginLoss() + optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate) + train_data_loader = my_dataset.get_train_data_loader(batch_size) + + for epoch in range(num_epochs): + for i, (images, labels) in enumerate(train_data_loader): + images = Variable(images) + labels = Variable(labels.float()) + predict_labels = cnn(images) + # print(predict_labels.type) + # print(labels.type) + loss = criterion(predict_labels, labels) + optimizer.zero_grad() + loss.backward() + optimizer.step() + if (i + 1) % 10 == 0: + print("epoch:", epoch, "step:", i, "loss:", loss.item()) + if (i + 1) % 100 == 0: + torch.save(cnn.state_dict(), "./model.pkl") # current is model.pkl + print("save model") + print("epoch:", epoch, "step:", i, "loss:", loss.item()) + torch.save(cnn.state_dict(), "./model.pkl") # current is model.pkl + print("save last model") + +if __name__ == '__main__': + fine_tuning() +