Skip to content

Commit 8ebb42f

Browse files
committed
Update to TensorFlow 1.1 and TensorLayer 1.4.5
1 parent 27d83c3 commit 8ebb42f

28 files changed

+437
-332
lines changed

main.py

Lines changed: 7 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ def main(_):
4343
tl.files.exists_or_mkdir(FLAGS.sample_dir)
4444

4545
z_dim = 100
46-
4746
with tf.device("/gpu:0"):
4847
##========================= DEFINE MODEL ===========================##
4948
z = tf.placeholder(tf.float32, [FLAGS.batch_size, z_dim], name='z_noise')
@@ -94,15 +93,14 @@ def main(_):
9493
net_d_name = os.path.join(save_dir, 'net_d.npz')
9594

9695
data_files = glob(os.path.join("./data", FLAGS.dataset, "*.jpg"))
97-
# sample_seed = np.random.uniform(low=-1, high=1, size=(FLAGS.sample_size, z_dim)).astype(np.float32)
98-
sample_seed = np.random.normal(loc=0.0, scale=1.0, size=(FLAGS.sample_size, z_dim)).astype(np.float32)
96+
97+
sample_seed = np.random.normal(loc=0.0, scale=1.0, size=(FLAGS.sample_size, z_dim)).astype(np.float32)# sample_seed = np.random.uniform(low=-1, high=1, size=(FLAGS.sample_size, z_dim)).astype(np.float32)
9998

10099
##========================= TRAIN MODELS ================================##
101100
iter_counter = 0
102101
for epoch in range(FLAGS.epoch):
103102
## shuffle data
104103
shuffle(data_files)
105-
print("[*] Dataset shuffled!")
106104

107105
## update sample files based on shuffled data
108106
sample_files = data_files[0:FLAGS.sample_size]
@@ -119,46 +117,28 @@ def main(_):
119117
# more image augmentation functions in http://tensorlayer.readthedocs.io/en/latest/modules/prepro.html
120118
batch = [get_image(batch_file, FLAGS.image_size, is_crop=FLAGS.is_crop, resize_w=FLAGS.output_size, is_grayscale = 0) for batch_file in batch_files]
121119
batch_images = np.array(batch).astype(np.float32)
122-
# batch_z = np.random.uniform(low=-1, high=1, size=(FLAGS.batch_size, z_dim)).astype(np.float32)
123-
batch_z = np.random.normal(loc=0.0, scale=1.0, size=(FLAGS.sample_size, z_dim)).astype(np.float32)
120+
batch_z = np.random.normal(loc=0.0, scale=1.0, size=(FLAGS.sample_size, z_dim)).astype(np.float32) # batch_z = np.random.uniform(low=-1, high=1, size=(FLAGS.batch_size, z_dim)).astype(np.float32)
124121
start_time = time.time()
125122
# updates the discriminator
126123
errD, _ = sess.run([d_loss, d_optim], feed_dict={z: batch_z, real_images: batch_images })
127124
# updates the generator, run generator twice to make sure that d_loss does not go to zero (difference from paper)
128125
for _ in range(2):
129126
errG, _ = sess.run([g_loss, g_optim], feed_dict={z: batch_z})
130127
print("Epoch: [%2d/%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
131-
% (epoch, FLAGS.epoch, idx, batch_idxs,
132-
time.time() - start_time, errD, errG))
133-
sys.stdout.flush()
128+
% (epoch, FLAGS.epoch, idx, batch_idxs, time.time() - start_time, errD, errG))
134129

135130
iter_counter += 1
136131
if np.mod(iter_counter, FLAGS.sample_step) == 0:
137132
# generate and visualize generated images
138133
img, errD, errG = sess.run([net_g2.outputs, d_loss, g_loss], feed_dict={z : sample_seed, real_images: sample_images})
139-
save_images(img, [8, 8],
140-
'./{}/train_{:02d}_{:04d}.png'.format(FLAGS.sample_dir, epoch, idx))
134+
tl.visualize.save_images(img, [8, 8], './{}/train_{:02d}_{:04d}.png'.format(FLAGS.sample_dir, epoch, idx))
141135
print("[Sample] d_loss: %.8f, g_loss: %.8f" % (errD, errG))
142-
sys.stdout.flush()
143136

144137
if np.mod(iter_counter, FLAGS.save_step) == 0:
145138
# save current network parameters
146139
print("[*] Saving checkpoints...")
147-
img, errD, errG = sess.run([net_g2.outputs, d_loss, g_loss], feed_dict={z : sample_seed, real_images: sample_images})
148-
model_dir = "%s_%s_%s" % (FLAGS.dataset, FLAGS.batch_size, FLAGS.output_size)
149-
save_dir = os.path.join(FLAGS.checkpoint_dir, model_dir)
150-
if not os.path.exists(save_dir):
151-
os.makedirs(save_dir)
152-
# the latest version location
153-
net_g_name = os.path.join(save_dir, 'net_g.npz')
154-
net_d_name = os.path.join(save_dir, 'net_d.npz')
155-
# # this version is for future re-check and visualization analysis
156-
# net_g_iter_name = os.path.join(save_dir, 'net_g_%d.npz' % iter_counter)
157-
# net_d_iter_name = os.path.join(save_dir, 'net_d_%d.npz' % iter_counter)
158-
# tl.files.save_npz(net_g.all_params, name=net_g_name, sess=sess)
159-
# tl.files.save_npz(net_d.all_params, name=net_d_name, sess=sess)
160-
# tl.files.save_npz(net_g.all_params, name=net_g_iter_name, sess=sess)
161-
# tl.files.save_npz(net_d.all_params, name=net_d_iter_name, sess=sess)
140+
tl.files.save_npz(net_g.all_params, name=net_g_name, sess=sess)
141+
tl.files.save_npz(net_d.all_params, name=net_d_name, sess=sess)
162142
print("[*] Saving checkpoints SUCCESS!")
163143

164144
if __name__ == '__main__':

model.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,8 @@ def generator_simplified_api(inputs, is_train=True, reuse=False):
1212
gf_dim = 64 # Dimension of gen filters in first conv layer. [64]
1313
c_dim = FLAGS.c_dim # n_color 3
1414
batch_size = FLAGS.batch_size # 64
15-
1615
w_init = tf.random_normal_initializer(stddev=0.02)
1716
gamma_init = tf.random_normal_initializer(1., 0.02)
18-
1917
with tf.variable_scope("generator", reuse=reuse):
2018
tl.layers.set_name_reuse(reuse)
2119

@@ -47,15 +45,12 @@ def generator_simplified_api(inputs, is_train=True, reuse=False):
4745
net_h4.outputs = tf.nn.tanh(net_h4.outputs)
4846
return net_h4, logits
4947

50-
5148
def discriminator_simplified_api(inputs, is_train=True, reuse=False):
5249
df_dim = 64 # Dimension of discrim filters in first conv layer. [64]
5350
c_dim = FLAGS.c_dim # n_color 3
5451
batch_size = FLAGS.batch_size # 64
55-
5652
w_init = tf.random_normal_initializer(stddev=0.02)
5753
gamma_init = tf.random_normal_initializer(1., 0.02)
58-
5954
with tf.variable_scope("discriminator", reuse=reuse):
6055
tl.layers.set_name_reuse(reuse)
6156

tensorlayer/__init__.py

100755100644
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from . import rein
2626

2727

28-
__version__ = "1.4.2"
28+
__version__ = "1.4.5"
2929

3030
global_flag = {}
3131
global_dict = {}
-914 Bytes
Binary file not shown.
-2.94 KB
Binary file not shown.
-19.1 KB
Binary file not shown.
-28.1 KB
Binary file not shown.
-9.11 KB
Binary file not shown.
-169 KB
Binary file not shown.
-31 KB
Binary file not shown.

0 commit comments

Comments (0)