Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import os
import sys
import warnings
import time
import json

warnings.filterwarnings('ignore')

Expand Down Expand Up @@ -206,6 +208,11 @@ def _parse_args():
type=float,
default=5.0,
help="Classifier free guidance scale.")
parser.add_argument(
"--perf_save_file",
type=str,
default=None,
help="The file to save performance metrics to.")

args = parser.parse_args()

Expand Down Expand Up @@ -233,6 +240,9 @@ def generate(args):
device = local_rank
_init_logging(rank)

# Start timing total execution
total_start_time = time.time()

if args.offload_model is None:
args.offload_model = False if world_size > 1 else True
logging.info(
Expand Down Expand Up @@ -330,6 +340,7 @@ def generate(args):

logging.info(
f"Generating {'image' if 't2i' in args.task else 'video'} ...")
start_time = time.time()
video = wan_t2v.generate(
args.prompt,
size=SIZE_CONFIGS[args.size],
Expand All @@ -340,6 +351,7 @@ def generate(args):
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
generation_time = time.time() - start_time

elif "i2v" in args.task:
if args.prompt is None:
Expand Down Expand Up @@ -386,6 +398,7 @@ def generate(args):
)

logging.info("Generating video ...")
start_time = time.time()
video = wan_i2v.generate(
args.prompt,
img,
Expand All @@ -397,6 +410,7 @@ def generate(args):
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
generation_time = time.time() - start_time
else:
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
Expand Down Expand Up @@ -444,6 +458,7 @@ def generate(args):
)

logging.info("Generating video ...")
start_time = time.time()
video = wan_flf2v.generate(
args.prompt,
first_frame,
Expand All @@ -457,6 +472,7 @@ def generate(args):
seed=args.base_seed,
offload_model=args.offload_model
)
generation_time = time.time() - start_time

if rank == 0:
if args.save_file is None:
Expand All @@ -483,6 +499,29 @@ def generate(args):
nrow=1,
normalize=True,
value_range=(-1, 1))

# Save performance metrics if requested
if args.perf_save_file is not None:
# Calculate total time including model loading
total_time = time.time() - total_start_time

perf_data = {
"task": args.task,
"size": args.size,
"frame_num": args.frame_num,
"sample_steps": args.sample_steps,
"ulysses_size": args.ulysses_size,
"ring_size": args.ring_size,
"generation_time_seconds": generation_time,
"total_time_seconds": total_time
}

with open(args.perf_save_file, 'w') as f:
json.dump(perf_data, f, indent=2)
logging.info(f"Saved performance metrics to {args.perf_save_file}")
logging.info(f"Generation time: {generation_time:.2f} seconds")
logging.info(f"Total time (including model loading): {total_time:.2f} seconds")

logging.info("Finished.")


Expand Down