-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathTrainSeparate.py
More file actions
44 lines (37 loc) · 1.49 KB
/
TrainSeparate.py
File metadata and controls
44 lines (37 loc) · 1.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import Storage as ST
import Model as MD
import tensorflow as tf
import numpy as np
# Sample directories and the integer class label assigned to each one.
# Labels are not in directory order: Ges_3-A / Ges_3-B map to 6 / 7, while
# Ges_4 / Ges_5 map to 4 / 5.
SAMPLE_DIRS = [
    ['./TrainingSamples/Ges_0', 0],
    ['./TrainingSamples/Ges_1', 1],
    ['./TrainingSamples/Ges_2', 2],
    ['./TrainingSamples/Ges_3', 3],
    ['./TrainingSamples/Ges_3-A', 6],
    ['./TrainingSamples/Ges_3-B', 7],
    ['./TrainingSamples/Ges_4', 4],
    ['./TrainingSamples/Ges_5', 5],
]

# Number of samples requested from ST.pick_some per step. The same value is
# passed as the first argument of MD.train_operation, presumably its batch
# size -- TODO confirm against Model.train_operation.
BATCH_SIZE = 10
# Remaining MD.train_operation arguments; presumably image height, width,
# channel count and number of classes (labels above span 0..7) -- TODO confirm.
IMAGE_HEIGHT = 192
IMAGE_WIDTH = 192
IMAGE_CHANNELS = 3
NUM_CLASSES = 8

NUM_STEPS = 2400          # total loop iterations (training + evaluation)
EVAL_EVERY = 20           # every N-th step evaluates instead of training
PARAMS_PATH = 'params_separate.bin'


def count_correct(logits, one_hot_labels):
    """Return how many rows of *logits* agree with *one_hot_labels*.

    A row counts as correct when the column index of its maximum in
    *logits* equals the column index of the maximum in the matching row of
    *one_hot_labels*. Both arguments must be 2-D (rows x classes) with the
    same number of rows.

    Unlike subtracting the result from a fixed batch size, this stays correct
    for any number of rows.
    """
    predicted = np.argmax(logits, axis=1)
    expected = np.argmax(one_hot_labels, axis=1)
    return int(np.count_nonzero(predicted == expected))


def main():
    """Train the model on the samples in SAMPLE_DIRS and save its parameters.

    Runs NUM_STEPS iterations. On every EVAL_EVERY-th iteration the batch is
    only fed through the model, not trained on. The step number and the
    running accuracy in percent are then printed. The accuracy is cumulative
    over all evaluation batches so far, so it lags the current model quality.
    Evaluation batches come from the same file list used for training, so this
    is a training-set accuracy, not a held-out one. The parameters are written
    to PARAMS_PATH while the session is still open.
    """
    file_list = ST.enum_samples(SAMPLE_DIRS)
    train_op, input_ph, label_ph, model = MD.train_operation(
        BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS, NUM_CLASSES)

    correct = 0.0
    total = 0.0
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for step in range(1, NUM_STEPS + 1):
            images, labels = ST.pick_some(file_list, BATCH_SIZE)
            if step % EVAL_EVERY == 0:
                logits = sess.run(model, feed_dict={input_ph: images})
                correct += count_correct(logits, labels)
                # Count the samples actually returned rather than assuming
                # BATCH_SIZE, so the rate stays correct for short batches.
                total += len(labels)
                rate = 100 * correct / total
                print(step, rate)
            else:
                sess.run(train_op,
                         feed_dict={input_ph: images, label_ph: labels})
        # Must run before the `with` block closes the session.
        MD.save_params(PARAMS_PATH, sess)


if __name__ == '__main__':
    main()