-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy patheval_stgan.py
More file actions
99 lines (69 loc) · 3.07 KB
/
eval_stgan.py
File metadata and controls
99 lines (69 loc) · 3.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""
eval_stgan.py — Evaluation / inference script for the ST-GAN model.

Composites glasses from the pre-built glasses dataset
(``dataset/glasses.npy``) onto a single test face image and saves the
stage-0 and final-stage composite results as PNG files.

Usage
-----
Run from the repository root. Both ``src/`` (added to ``sys.path``) and
``dataset/glasses.npy`` are relative paths, so they are resolved against
the current working directory.

.. code-block:: bash

    python eval_stgan.py \\
        --group 0 --model STGAN --warpN 1 \\
        --loadGP 0_STGAN \\
        --loadImage path/to/face.png

Outputs
-------
For each glasses sample ``b`` in the batch::

    eval_{loadGP}/image_g{b}_input.png   — composite at warp stage 0
    eval_{loadGP}/image_g{b}_output.png  — composite at the final warp stage

NOTE(review): stage 0 is computed from a random initial perturbation
(``pPertFG``) passed to the warp pipeline, so the "input" composite is
presumably the perturbed initial placement rather than an unwarped
foreground — confirm against ``graph.geometric_multires``.
"""
import sys
import time

import numpy as np

# Put the relative path "src" at the front of sys.path so the `stgan` package
# is importable. A relative sys.path entry is resolved against the current
# working directory, so the script must be launched from the repository root.
sys.path.insert(0, "src")
from stgan import utils  # noqa: E402  (must come after the sys.path tweak)
def _build_eval_graph(opt) -> tuple:
    """Build the ST-GAN inference graph in a fresh default graph.

    Reads ``opt.GPUdevice``, ``opt.batchSize``, ``opt.H``, ``opt.W``,
    ``opt.pertFG``, ``opt.warpDim`` and ``opt.warpN``.

    Returns:
        A 3-tuple ``(PH, imageCompAll, varsGP)``:
        ``PH`` is ``[imageBG, imageFG]``, the background (3-channel) and
        foreground (4-channel) placeholders, in the order expected by
        ``data.makeBatchEval``;
        ``imageCompAll`` is one composite tensor per warp stage
        (``opt.warpN + 1`` entries);
        ``varsGP`` lists the global variables whose name contains
        ``"geometric"`` (the weights restored from ``opt.loadGP``).
    """
    import tensorflow as tf
    from stgan import graph

    tf.reset_default_graph()
    with tf.device(opt.GPUdevice):
        imageBG = tf.placeholder(tf.float32, shape=[opt.batchSize, opt.H, opt.W, 3])
        imageFG = tf.placeholder(tf.float32, shape=[opt.batchSize, opt.H, opt.W, 4])
        # Random initial perturbation of the foreground warp parameters,
        # scaled by opt.pertFG.
        pPertFG = opt.pertFG * tf.random_normal([opt.batchSize, opt.warpDim])
        imageFGwarpAll, _, _ = graph.geometric_multires(opt, imageBG, imageFG, pPertFG)
        imageCompAll = [
            graph.composite(opt, imageBG, imageFGwarpAll[stage])
            for stage in range(opt.warpN + 1)
        ]
        varsGP = [v for v in tf.global_variables() if "geometric" in v.name]
    return [imageBG, imageFG], imageCompAll, varsGP


def main(glasses_path: str = "dataset/glasses.npy") -> None:
    """Entry point: build the evaluation graph, restore the model, run inference.

    Builds the inference graph, loads the glasses dataset and restores the
    geometric-predictor variables from ``opt.loadGP``. It then composites
    each glasses sample in the batch onto the face image at
    ``opt.loadImage``. For each batch index ``b`` it writes
    ``eval_{opt.loadGP}/image_g{b}_input.png`` (warp stage 0) and
    ``eval_{opt.loadGP}/image_g{b}_output.png`` (final warp stage).

    Args:
        glasses_path: Path of the ``.npy`` glasses dataset. Defaults to the
            previously hard-coded ``"dataset/glasses.npy"``, which is
            resolved against the current working directory.
    """
    print(utils.toYellow("======================================================="))
    print(utils.toYellow("eval_stgan.py — ST-GAN with homography"))
    print(utils.toYellow("======================================================="))

    # TensorFlow is imported only here, after the banner — presumably so the
    # banner appears before the slow TensorFlow import.
    import tensorflow as tf
    from stgan import data, options

    opt = options.set(training=False)

    print(utils.toMagenta("building graph..."))
    PH, imageCompAll, varsGP = _build_eval_graph(opt)

    print(utils.toMagenta("loading test data..."))
    glasses = np.load(glasses_path)

    saver_GP = tf.train.Saver(var_list=varsGP)

    print(utils.toYellow("======= EVALUATION START ======="))
    t_start = time.perf_counter()
    tfConfig = tf.ConfigProto(allow_soft_placement=True)
    tfConfig.gpu_options.allow_growth = True
    with tf.Session(config=tfConfig) as sess:
        sess.run(tf.global_variables_initializer())
        utils.restoreModel(opt, sess, saver_GP, opt.loadGP, "GP")
        print(utils.toMagenta("start evaluation..."))

        out_dir = f"eval_{opt.loadGP}"
        utils.mkdir(out_dir)
        test_image = utils.imread(opt.loadImage)
        batch = data.makeBatchEval(opt, test_image, glasses, PH)
        # Only the first and last stages are fetched; intermediate stages are
        # built in the graph but not evaluated.
        ic0, icf = sess.run([imageCompAll[0], imageCompAll[-1]], feed_dict=batch)
        for b, (img_in, img_out) in enumerate(zip(ic0, icf)):
            utils.imsave(f"{out_dir}/image_g{b}_input.png", img_in)
            utils.imsave(f"{out_dir}/image_g{b}_output.png", img_out)
        print(utils.toGreen(f"Saved {opt.batchSize} composites to {out_dir}/"))
    elapsed = time.perf_counter() - t_start
    print(utils.toMagenta(f"evaluation took {elapsed:.2f}s"))
    print(utils.toYellow("======= EVALUATION DONE ======="))
if __name__ == "__main__":
    # Run inference only when executed as a script, not when imported.
    main()