Skip to content

Commit 6b2d3e6

Browse files
committed
Add verbose flag, update tests and clean up example
1 parent 0d8861c commit 6b2d3e6

File tree

2 files changed

+152
-0
lines changed

2 files changed

+152
-0
lines changed
461 KB
Loading
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
#
2+
# SPDX-FileCopyrightText: Copyright (c) 2025-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
import argparse
19+
import os
20+
import numpy as np
21+
from PIL import Image
22+
from argparse import Namespace
23+
from skimage.metrics import structural_similarity
24+
25+
from example import tripy_diffusion
26+
27+
28+
def load_reference_image(image_path, verbose=False):
    """Load the reference image stored at ``image_path``.

    Args:
        image_path: Filesystem path of the reference image.
        verbose: When true, print an ``[I]`` progress line before loading.

    Returns:
        The opened PIL image with its pixel data already read into memory.

    Raises:
        FileNotFoundError: If ``image_path`` is not an existing regular file
            (the path is missing, or it names a directory).
    """
    # isfile (rather than exists) so that a directory is reported through the
    # same FileNotFoundError instead of failing inside Image.open with a
    # different exception type.
    if not os.path.isfile(image_path):
        raise FileNotFoundError(f"Reference image not found: {image_path}")

    if verbose:
        print(f"[I] Loading reference image from {image_path}")

    # Image.open is lazy and keeps the file open. Read the pixels now and let
    # the context manager release the file handle; the loaded image remains
    # usable after the with-block.
    with Image.open(image_path) as image:
        image.load()
    return image
36+
37+
38+
def compare_images(tripy_img, reference_img, threshold=0.80):
    """Compare two images using the structural similarity index (SSIM).

    Both images are converted to mode ``"L"`` (grayscale) before comparison.
    If their shapes differ, the reference image is resized to the size of
    ``tripy_img`` and a warning is printed.

    Args:
        tripy_img: PIL image produced by the tripy pipeline.
        reference_img: PIL image to compare against.
        threshold: Minimum SSIM score for the images to count as similar.

    Returns:
        ``True`` if the SSIM score is at least ``threshold``, else ``False``.
    """
    # Convert both images to grayscale numpy arrays for comparison.
    tripy_array = np.array(tripy_img.convert("L"))
    reference_array = np.array(reference_img.convert("L"))

    # SSIM requires equally shaped inputs; resize the reference to match the
    # tripy output when they differ.
    if tripy_array.shape != reference_array.shape:
        print(f"[W] Image shape mismatch: tripy {tripy_array.shape} vs reference {reference_array.shape}")
        reference_img_resized = reference_img.resize(tripy_img.size, Image.Resampling.LANCZOS)
        reference_array = np.array(reference_img_resized.convert("L"))

    # Coerce to a plain float so the comparison below yields a Python bool.
    ssim = float(structural_similarity(tripy_array, reference_array))

    # Report the measured score so a failure shows how far off the output was.
    print(f"[I] SSIM score: {ssim:.4f}")

    is_similar = ssim >= threshold
    if is_similar:
        print(f"[I] Passed: Images are similar (SSIM >= {threshold})")
    else:
        print(f"[I] Failed: Images are not similar enough (SSIM < {threshold})")
    return is_similar
60+
61+
62+
def main():
    """Run tripy diffusion and compare its output with a reference image.

    Parses command-line arguments, loads the reference image, runs
    ``tripy_diffusion`` with the parsed parameters, compares the result with
    the reference via ``compare_images`` and optionally saves the output.

    Returns:
        ``0`` if the images are similar. ``1`` if they are not, if the
        reference image cannot be found, or if the diffusion run raises.
    """
    parser = argparse.ArgumentParser(
        description="Compare tripy diffusion output with a reference image",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Reference image argument
    parser.add_argument(
        "--reference",
        type=str,
        default="assets/torch_ref_fp16_fuji_steps50_seed420.png",
        help="Path to reference image file to compare against",
    )

    # Diffusion parameters (matching example.py)
    parser.add_argument("--steps", type=int, default=30, help="Number of denoising steps in diffusion")
    parser.add_argument(
        "--prompt",
        type=str,
        default="a beautiful photograph of Mt. Fuji during cherry blossom",
        help="Phrase to render",
    )
    parser.add_argument("--fp16", action="store_true", help="Cast the weights to float16")
    parser.add_argument("--seed", type=int, help="Set the random latent seed")
    parser.add_argument("--guidance", type=float, default=7.5, help="Prompt strength")
    parser.add_argument(
        "--hf-token", type=str, default="", help="HuggingFace API access token for downloading model checkpoints"
    )
    parser.add_argument("--engine-dir", type=str, default="engines", help="Output directory for TensorRT engines")

    # Comparison parameters
    parser.add_argument("--threshold", type=float, default=0.80, help="SSIM threshold for considering images similar")
    parser.add_argument("--save-output", type=str, default=None, help="Save the tripy output image to this path")
    parser.add_argument(
        "--verbose", action="store_true", default=False, help="Enable verbose output with timing and progress bars"
    )

    args = parser.parse_args()

    # Load reference image, forwarding --verbose so the loader's progress
    # message is actually shown when requested.
    try:
        reference_img = load_reference_image(args.reference, verbose=args.verbose)
    except FileNotFoundError as e:
        print(f"[E] {e}")
        return 1

    # Create args namespace for tripy_diffusion
    tripy_args = Namespace(
        steps=args.steps,
        prompt=args.prompt,
        out=args.save_output,
        fp16=args.fp16,
        seed=args.seed,
        guidance=args.guidance,
        torch_inference=False,
        hf_token=args.hf_token,
        engine_dir=args.engine_dir,
        verbose=args.verbose,
    )

    # Run tripy diffusion
    if args.verbose:
        print("[I] Running tripy diffusion with parameters:")
        print(f"    Prompt: {args.prompt}")
        print(f"    Steps: {args.steps}")
        print(f"    FP16: {args.fp16}")
        print(f"    Seed: {args.seed}")
        print(f"    Guidance: {args.guidance}")

    # Top-level boundary: any failure in the pipeline is reported and turned
    # into a non-zero exit code rather than a traceback.
    try:
        # The second element of the returned pair is not used here.
        tripy_img, _ = tripy_diffusion(tripy_args)
    except Exception as e:
        print(f"[E] Error running tripy diffusion: {e}")
        return 1

    # Compare images
    is_similar = compare_images(tripy_img, reference_img, args.threshold)

    # Save output if requested
    if args.save_output:
        # A bare filename has an empty dirname; os.makedirs("") raises, so only
        # create the directory when there is one to create.
        output_dir = os.path.dirname(args.save_output)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        tripy_img.save(args.save_output)
        print(f"[I] Saved tripy output to {args.save_output}")

    # Return appropriate exit code
    return 0 if is_similar else 1
149+
150+
151+
if __name__ == "__main__":
    # Use SystemExit directly: the bare exit() helper is injected by the site
    # module and is not guaranteed to exist (e.g. under `python -S`).
    raise SystemExit(main())

0 commit comments

Comments
 (0)