diff --git a/test.py b/test.py
index bb4884b6b..35ad4d7ba 100644
--- a/test.py
+++ b/test.py
@@ -8,6 +8,9 @@
 import gc
 import os
 import unittest
+import argparse
+import sys
+import time  # Add time module import
 
 import torch
 from torchbenchmark import (
@@ -26,6 +29,15 @@
 # unresponsive for 5 minutes the parent will presume it dead / incapacitated.)
 TIMEOUT = int(os.getenv("TIMEOUT", 300))  # Seconds
 
+# Add argument parser
+parser = argparse.ArgumentParser(description='Run benchmark tests', add_help=False)
+parser.add_argument('-t', '--iterations', type=int, default=300,
+                    help='Number of iterations to run inference (default: 300)')
+# Parse only known arguments to avoid interfering with unittest
+args, unknown = parser.parse_known_args()
+
+# Store iterations in a global variable
+ITERATIONS = args.iterations
 
 class TestBenchmark(unittest.TestCase):
     def setUp(self):
@@ -55,6 +67,8 @@ def _create_example_model_instance(task: ModelTask, device: str):
 
 def _load_test(path, device):
     model_name = os.path.basename(path)
+    print(f"Loading test for model {model_name} on {device}")
+
 
     def _skip_cuda_memory_check_p(metadata):
         if device != "cuda":
@@ -94,12 +108,32 @@ def train_fn(self):
             skip=_skip_cuda_memory_check_p(metadata), assert_equal=self.assertEqual
         ):
             try:
+                # Measure model initialization time
+                init_start_time = time.time()
                 task.make_model_instance(
                     test="train", device=device, batch_size=batch_size
                 )
-                task.invoke()
-                task.check_details_train(device=device, md=metadata)
+                init_time = time.time() - init_start_time
+                print(f"\nModel initialization time: {init_time:.2f} seconds")
+
+                # Measure training time
+                train_start_time = time.time()
+                # Run training for specified number of iterations
+                for _ in range(ITERATIONS):
+                    task.invoke()
+                    task.check_details_train(device=device, md=metadata)
+                train_time = time.time() - train_start_time
+                print(f"Training time: {train_time:.2f} seconds")
+
+                # Measure cleanup time
+                cleanup_start_time = time.time()
                 task.del_model_instance()
+                cleanup_time = time.time() - cleanup_start_time
+                print(f"Cleanup time: {cleanup_time:.2f} seconds")
+
+                # Print total time
+                total_time = init_time + train_time + cleanup_time
+                print(f"Total time: {total_time:.2f} seconds")
             except NotImplementedError as e:
                 self.skipTest(
                     f'Method train on {device} is not implemented because "{e}", skipping...'
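An aside on the timing pattern the train hunk introduces: each phase (model construction, the iteration loop, teardown) is bracketed by its own start/stop timestamps, and the total is the sum of the three. The sketch below restates that structure in isolation; make_instance, run_once and cleanup are hypothetical callables standing in for task.make_model_instance, task.invoke and task.del_model_instance, and it uses time.perf_counter(), which is generally preferable to time.time() for measuring elapsed intervals, although the patch itself uses time.time().

import time

def timed_phases(make_instance, run_once, cleanup, iterations=300):
    # Phase 1: construction
    start = time.perf_counter()
    instance = make_instance()
    init_time = time.perf_counter() - start
    print(f"Model initialization time: {init_time:.2f} seconds")

    # Phase 2: run the workload for the requested number of iterations
    start = time.perf_counter()
    for _ in range(iterations):
        run_once(instance)
    run_time = time.perf_counter() - start
    print(f"Run time: {run_time:.2f} seconds")

    # Phase 3: teardown
    start = time.perf_counter()
    cleanup(instance)
    cleanup_time = time.perf_counter() - start
    print(f"Cleanup time: {cleanup_time:.2f} seconds")

    print(f"Total time: {init_time + run_time + cleanup_time:.2f} seconds")

if __name__ == "__main__":
    # Trivial smoke run with no-op callables
    timed_phases(object, lambda inst: None, lambda inst: None, iterations=10)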
@@ -117,13 +151,33 @@ def eval_fn(self):
             skip=_skip_cuda_memory_check_p(metadata), assert_equal=self.assertEqual
         ):
             try:
+                # Measure model initialization time
+                init_start_time = time.time()
                 task.make_model_instance(
                     test="eval", device=device, batch_size=batch_size
                 )
-                task.invoke()
-                task.check_details_eval(device=device, md=metadata)
-                task.check_eval_output()
+                init_time = time.time() - init_start_time
+                print(f"\nModel initialization time: {init_time:.2f} seconds")
+
+                # Measure evaluation time
+                eval_start_time = time.time()
+                # Run inference for specified number of iterations
+                for _ in range(ITERATIONS):
+                    task.invoke()
+                    task.check_details_eval(device=device, md=metadata)
+                    task.check_eval_output()
+                eval_time = time.time() - eval_start_time
+                print(f"Evaluation time: {eval_time:.2f} seconds")
+
+                # Measure cleanup time
+                cleanup_start_time = time.time()
                 task.del_model_instance()
+                cleanup_time = time.time() - cleanup_start_time
+                print(f"Cleanup time: {cleanup_time:.2f} seconds")
+
+                # Print total time
+                total_time = init_time + eval_time + cleanup_time
+                print(f"Total time: {total_time:.2f} seconds")
             except NotImplementedError as e:
                 self.skipTest(
                     f'Method eval on {device} is not implemented because "{e}", skipping...'
@@ -187,4 +241,6 @@ def _load_tests():
 _load_tests()
 
 if __name__ == "__main__":
-    unittest.main()
+    # Pass unknown arguments to unittest
+    sys.argv[1:] = unknown
+    unittest.main()
\ No newline at end of file
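A note on how the new flag coexists with unittest's own command line: parse_known_args() consumes only the options this parser defines and returns everything else untouched, and assigning that remainder back to sys.argv[1:] before unittest.main() lets unittest handle its usual options (-v, -k, explicit test names) as before. Below is a minimal, self-contained sketch of that interplay; the SmokeTest class is purely illustrative and not part of the patch.

import argparse
import sys
import unittest

# add_help=False leaves -h/--help to unittest rather than this parser.
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument("-t", "--iterations", type=int, default=300)
args, unknown = parser.parse_known_args()
ITERATIONS = args.iterations

class SmokeTest(unittest.TestCase):
    def test_iterations_flag(self):
        # Tests read the flag through the module-level global.
        self.assertGreater(ITERATIONS, 0)

if __name__ == "__main__":
    # Forward everything argparse did not recognize to unittest.
    sys.argv[1:] = unknown
    unittest.main()

Invoked as, say, "python sketch.py -t 50 -v", argparse consumes "-t 50" and unittest still sees "-v".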