1414# limitations under the License.
1515import time
1616from textwrap import dedent
17+ from typing import Callable
1718
1819import pytest
1920import torch
2526import tripy as tp
2627
2728
29+ def run_timed_trials (thunk : Callable [[], None ], warm_up_runs = 10 , iterations = 1000 ):
30+ """
31+ Returns the average time measured for calls to the thunk (the function intended to be timed)
32+ in microseconds. First performs the specified number of untimed warm-ups.
33+ """
34+
35+ for _ in range (warm_up_runs ):
36+ thunk ()
37+
38+ start = time .perf_counter_ns ()
39+ for _ in range (iterations ):
40+ thunk ()
41+ end = time .perf_counter_ns ()
42+ return (end - start ) / (iterations * 1000.0 )
43+
44+
2845@pytest .mark .parametrize ("perf_case" , PERF_CASES )
2946def test_perf_regression (perf_case , benchmark ):
3047 compiled_tripy_module , _ , inputs , _ = perf_case
@@ -115,15 +132,10 @@ def func({arg_str}):
115132 for input in inputs :
116133 input .eval ()
117134
118- for _ in range (warm_up_runs ):
119- compiled_one_io (* inputs )
120-
121- start = time .perf_counter_ns ()
122- for _ in range (iterations ):
123- compiled_one_io (* inputs )
124- end = time .perf_counter_ns ()
135+ def measure_thunk ():
136+ return compiled_one_io (* inputs )
125137
126- return ( end - start ) / ( iterations * 1000.0 )
138+ return run_timed_trials ( measure_thunk , warm_up_runs = warm_up_runs , iterations = iterations )
127139
128140 assert measure_overhead (1 ) < 60.0
129141
@@ -137,3 +149,13 @@ def func({arg_str}):
137149 # Ensure all deltas are within a few microseconds of each other
138150 average_delta = sum (deltas ) / float (len (deltas ))
139151 assert all (abs (delta - average_delta ) < 10 for delta in deltas )
152+
153+
154+ def test_tripy_param_update (benchmark ):
155+ m = tp .Module ()
156+ m .param = tp .Parameter ([1 , 2 , 3 , 4 ])
157+
158+ def measure_thunk ():
159+ m .param = tp .Parameter ([5 , 6 , 7 , 8 ])
160+
161+ benchmark (measure_thunk )
0 commit comments