@@ -43,7 +43,6 @@ def main(_):
     tl.files.exists_or_mkdir(FLAGS.sample_dir)

     z_dim = 100
-
     with tf.device("/gpu:0"):
         ##========================= DEFINE MODEL ===========================##
         z = tf.placeholder(tf.float32, [FLAGS.batch_size, z_dim], name='z_noise')
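
Note: the only change in this hunk is dropping a stray blank line; the context shows where the generator's latent input is defined. A minimal runnable sketch of how a placeholder like z is fed at run time (TensorFlow 1.x; batch_size=64 is an assumed stand-in for FLAGS.batch_size, and the reduce_mean op stands in for the generator built on top of z):

    import numpy as np
    import tensorflow as tf

    batch_size, z_dim = 64, 100  # assumed values; the script takes these from FLAGS and the z_dim above
    z = tf.placeholder(tf.float32, [batch_size, z_dim], name='z_noise')
    out = tf.reduce_mean(z)      # stand-in for the generator network built on z

    with tf.Session() as sess:
        noise = np.random.normal(0.0, 1.0, size=(batch_size, z_dim)).astype(np.float32)
        print(sess.run(out, feed_dict={z: noise}))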
@@ -94,15 +93,14 @@ def main(_):
     net_d_name = os.path.join(save_dir, 'net_d.npz')

     data_files = glob(os.path.join("./data", FLAGS.dataset, "*.jpg"))
-    # sample_seed = np.random.uniform(low=-1, high=1, size=(FLAGS.sample_size, z_dim)).astype(np.float32)
-    sample_seed = np.random.normal(loc=0.0, scale=1.0, size=(FLAGS.sample_size, z_dim)).astype(np.float32)
+
+    sample_seed = np.random.normal(loc=0.0, scale=1.0, size=(FLAGS.sample_size, z_dim)).astype(np.float32)  # sample_seed = np.random.uniform(low=-1, high=1, size=(FLAGS.sample_size, z_dim)).astype(np.float32)

     ##========================= TRAIN MODELS ================================##
     iter_counter = 0
     for epoch in range(FLAGS.epoch):
         ## shuffle data
         shuffle(data_files)
-        print("[*] Dataset shuffled!")

         ## update sample files based on shuffled data
         sample_files = data_files[0:FLAGS.sample_size]
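
The substantive change in this hunk swaps the prior for the fixed visualization seed from uniform on [-1, 1] to a standard normal, keeping the old sampler as a trailing comment. A minimal NumPy sketch of the two alternatives (sample_size=64 and z_dim=100 are assumed to mirror the script's flags; 64 matches the 8x8 sample grid saved during training):

    import numpy as np

    sample_size, z_dim = 64, 100
    # before this commit: uniform prior on [-1, 1]
    z_uniform = np.random.uniform(low=-1, high=1, size=(sample_size, z_dim)).astype(np.float32)
    # after this commit: standard normal prior
    z_normal = np.random.normal(loc=0.0, scale=1.0, size=(sample_size, z_dim)).astype(np.float32)
    print(z_uniform.shape, z_normal.shape)  # (64, 100) (64, 100)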
@@ -119,46 +117,28 @@ def main(_):
             # more image augmentation functions in http://tensorlayer.readthedocs.io/en/latest/modules/prepro.html
             batch = [get_image(batch_file, FLAGS.image_size, is_crop=FLAGS.is_crop, resize_w=FLAGS.output_size, is_grayscale=0) for batch_file in batch_files]
             batch_images = np.array(batch).astype(np.float32)
-            # batch_z = np.random.uniform(low=-1, high=1, size=(FLAGS.batch_size, z_dim)).astype(np.float32)
-            batch_z = np.random.normal(loc=0.0, scale=1.0, size=(FLAGS.sample_size, z_dim)).astype(np.float32)
+            batch_z = np.random.normal(loc=0.0, scale=1.0, size=(FLAGS.batch_size, z_dim)).astype(np.float32)  # batch_z = np.random.uniform(low=-1, high=1, size=(FLAGS.batch_size, z_dim)).astype(np.float32)
             start_time = time.time()
             # updates the discriminator
             errD, _ = sess.run([d_loss, d_optim], feed_dict={z: batch_z, real_images: batch_images})
             # updates the generator, run generator twice to make sure that d_loss does not go to zero (difference from paper)
             for _ in range(2):
                 errG, _ = sess.run([g_loss, g_optim], feed_dict={z: batch_z})
             print("Epoch: [%2d/%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
-                % (epoch, FLAGS.epoch, idx, batch_idxs,
-                   time.time() - start_time, errD, errG))
-            sys.stdout.flush()
+                % (epoch, FLAGS.epoch, idx, batch_idxs, time.time() - start_time, errD, errG))

             iter_counter += 1
             if np.mod(iter_counter, FLAGS.sample_step) == 0:
                 # generate and visualize generated images
                 img, errD, errG = sess.run([net_g2.outputs, d_loss, g_loss], feed_dict={z: sample_seed, real_images: sample_images})
-                save_images(img, [8, 8],
-                            './{}/train_{:02d}_{:04d}.png'.format(FLAGS.sample_dir, epoch, idx))
+                tl.visualize.save_images(img, [8, 8], './{}/train_{:02d}_{:04d}.png'.format(FLAGS.sample_dir, epoch, idx))
                 print("[Sample] d_loss: %.8f, g_loss: %.8f" % (errD, errG))
-                sys.stdout.flush()

             if np.mod(iter_counter, FLAGS.save_step) == 0:
                 # save current network parameters
                 print("[*] Saving checkpoints...")
-                img, errD, errG = sess.run([net_g2.outputs, d_loss, g_loss], feed_dict={z: sample_seed, real_images: sample_images})
-                model_dir = "%s_%s_%s" % (FLAGS.dataset, FLAGS.batch_size, FLAGS.output_size)
-                save_dir = os.path.join(FLAGS.checkpoint_dir, model_dir)
-                if not os.path.exists(save_dir):
-                    os.makedirs(save_dir)
-                # the latest version location
-                net_g_name = os.path.join(save_dir, 'net_g.npz')
-                net_d_name = os.path.join(save_dir, 'net_d.npz')
-                # # this version is for future re-check and visualization analysis
-                # net_g_iter_name = os.path.join(save_dir, 'net_g_%d.npz' % iter_counter)
-                # net_d_iter_name = os.path.join(save_dir, 'net_d_%d.npz' % iter_counter)
-                # tl.files.save_npz(net_g.all_params, name=net_g_name, sess=sess)
-                # tl.files.save_npz(net_d.all_params, name=net_d_name, sess=sess)
-                # tl.files.save_npz(net_g.all_params, name=net_g_iter_name, sess=sess)
-                # tl.files.save_npz(net_d.all_params, name=net_d_iter_name, sess=sess)
+                tl.files.save_npz(net_g.all_params, name=net_g_name, sess=sess)
+                tl.files.save_npz(net_d.all_params, name=net_d_name, sess=sess)
                 print("[*] Saving checkpoints SUCCESS!")

 if __name__ == '__main__':
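
The checkpoint block above drops the per-iteration directory bookkeeping (save_dir, net_g_name, and net_d_name are already built once before the training loop) in favor of two direct tl.files.save_npz calls. A self-contained sketch of that save/restore round trip with the TensorLayer 1.x API (the one-layer network is a stand-in for the script's net_g and net_d):

    import tensorflow as tf
    import tensorlayer as tl

    x = tf.placeholder(tf.float32, [None, 100])
    net = tl.layers.InputLayer(x, name='input')
    net = tl.layers.DenseLayer(net, n_units=64, act=tf.nn.relu, name='dense')

    sess = tf.InteractiveSession()
    tl.layers.initialize_global_variables(sess)

    # save the current parameters to an .npz file, as the commit does for net_g and net_d
    tl.files.save_npz(net.all_params, name='net.npz', sess=sess)
    # restore them later into a freshly built network of the same architecture
    tl.files.load_and_assign_npz(sess=sess, name='net.npz', network=net)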