-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrain.py
More file actions
25 lines (22 loc) · 846 Bytes
/
train.py
File metadata and controls
25 lines (22 loc) · 846 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import numpy as np
import nn
# Base name for the training log. nn.set_log presumably writes to
# '<logfilename>.log'; that file is read back by net.plot_log at the end.
logfilename = 'NN'
# Configure logging before importing nnh1 below. NOTE(review): the late
# `import nnh1` looks deliberate (nnh1 may bind nn.logger at import time);
# confirm before hoisting it to the top of the file.
nn.set_log(logfilename)

# `fetch_mldata` was deprecated in scikit-learn 0.20 and removed in 0.22
# (mldata.org is offline), so load the same 70,000-image MNIST set from
# OpenML instead. `as_frame=False` (scikit-learn >= 0.22) keeps `data` and
# `target` as NumPy arrays instead of pandas objects.
from sklearn.datasets import fetch_openml

# The first 60,000 samples of mnist_784 are the standard MNIST training
# set; the remaining 10,000 are the standard test set.
N_TRAIN = 60000

nn.logger.info('Downloading MNIST dataset.')
mnist = fetch_openml('mnist_784', version=1, as_frame=False)
# Convert pixels to float32 and divide by 255 (MNIST intensities are 0-255).
mnist.data = mnist.data.astype(np.float32)
mnist.data /= 255
# OpenML stores the labels as strings ('0'..'9'); convert to int32 class ids.
mnist.target = mnist.target.astype(np.int32)
train_img, test_img = np.split(mnist.data, [N_TRAIN])
train_targets, test_targets = np.split(mnist.target, [N_TRAIN])
# net.train below receives lists of (image, label) tuples.
train_data = list(zip(train_img, train_targets))
test_data = list(zip(test_img, test_targets))
# Shuffle the training pairs in place (unseeded, so order differs per run).
np.random.shuffle(train_data)

import nnh1

# Layer sizes 784 -> 100 -> 10: 784 = 28x28 input pixels, 10 digit classes.
net = nnh1.NNH1(n_units=[784, 100, 10])
net.train(
    train_data, test_data,
    batch_size=10, test_step=1000, epoch=10,
    # presumably step decay: lr is multiplied by lr_mult every lr_step
    # iterations -- TODO confirm against nnh1.NNH1.train
    lr=0.01, lr_step=30000, lr_mult=0.1,
    wdecay=0.0005, momentum=0.9,
    drop_ph=1.0, disp_step=100,
)
net.save('model.npz')
net.plot_log(logfilename + '.log')