 import tripy as tp


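+# Note: `perf_fixture` below is a decorator factory. Given a list of dtypes, it wraps a
+# function returning (tripy_module, torch_module, input_infos) into a session-scoped,
+# dtype-parametrized pytest fixture that copies the Tripy weights into the Torch module,
+# compiles both modules, and pre-evaluates the inputs.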
+def perf_fixture(dtypes):
+    def perf_fixture_impl(func):
+        @pytest.fixture(params=dtypes, scope="session")
+        def wrapped(request):
+            tripy_module, torch_module, input_infos = func(request.param, helper.TORCH_DTYPES[request.param])
+
+            torch_state_dict = {key: torch.from_dlpack(value) for key, value in tripy_module.state_dict().items()}
+            torch_module.load_state_dict(torch_state_dict)
+
+            compiler = tp.Compiler(tripy_module)
+            tripy_compiled = compiler.compile(**input_infos)
+
+            inputs = {
+                key: tp.iota(input_info.shape_bounds.opt, dtype=request.param)
+                for key, input_info in input_infos.items()
+            }
+            for tensor in inputs.values():
+                tensor.eval()
+
+            torch_compiled = torch.compile(torch_module)
+
+            return tripy_compiled, torch_compiled, inputs
+
+        return wrapped
+
+    return perf_fixture_impl
+
+
 # TODO: File issue for FP32:
-@pytest.fixture(params=[pytest.param(tp.float32, marks=pytest.mark.skip("Bug in MLIR-TRT")), tp.float16])
-def linear_block(request):
+@perf_fixture(dtypes=[pytest.param(tp.float32, marks=pytest.mark.skip("Bug in MLIR-TRT")), tp.float16])
+def linear_block(tripy_dtype, torch_dtype):
+
     class LinearBlock(tp.Module):
         def __init__(self):
-            self.layers = [tp.Linear(256, 256, bias=False, dtype=request.param) for _ in range(10)]
+            self.layers = [tp.Linear(256, 256, bias=False, dtype=tripy_dtype) for _ in range(10)]
             for layer in self.layers:
                 # Adjust the weights to prevent FP16 overflows.
-                layer.weight = tp.Parameter((tp.iota((256, 256), dim=1, dtype=request.param) / 256.0) - 0.5)
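+                # Tiling the 2x2 block [[-1, 1], [1, -1]] yields a 256x256 checkerboard of +/-1
+                # whose rows sum to zero, which keeps the layer outputs small in FP16.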
+                weight = torch.tile(
+                    torch.tensor([[-1, 1], [1, -1]], dtype=torch_dtype, device=torch.device("cuda")), (128, 128)
+                )
+                layer.weight = tp.Parameter(weight)

         def __call__(self, input):
             for layer in self.layers:
                 input = layer(input)
-                print(torch.from_dlpack(input))
             return input

     class TorchLinearBlock(torch.nn.Module):
         def __init__(self):
             super().__init__()
-            dtype = helper.TORCH_DTYPES[request.param]
             self.layers = torch.nn.ModuleList(
-                [torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=torch.device("cuda")) for _ in range(10)]
+                [
+                    torch.nn.Linear(256, 256, bias=False, dtype=torch_dtype, device=torch.device("cuda"))
+                    for _ in range(10)
+                ]
             )

         def forward(self, input):
             for layer in self.layers:
                 input = layer(input)
-                print(input)
             return input

     tripy_block = LinearBlock()
     torch_block = TorchLinearBlock()
-
-    torch_state_dict = {key: torch.from_dlpack(value) for key, value in tripy_block.state_dict().items()}
-    torch_block.load_state_dict(torch_state_dict)
-
-    input_infos = {"input": tp.InputInfo(shape=(1024, 256), dtype=request.param)}
-
-    # compiler = tp.Compiler(tripy_block)
-    # tripy_compiled = compiler.compile(**input_infos)
-    tripy_compiled = tripy_block
-
-    inputs = {
-        key: tp.iota(input_info.shape_bounds.opt, dtype=request.param) / 100.0
-        for key, input_info in input_infos.items()
-    }
-    for tensor in inputs.values():
-        tensor.eval()
-
-    torch_compiled = torch.compile(torch_block)
-
-    return tripy_compiled, torch_compiled, inputs
+    input_infos = {"input": tp.InputInfo(shape=(1024, 256), dtype=tripy_dtype)}
+    return tripy_block, torch_block, input_infos


 def test_perf_regression(linear_block, benchmark):
@@ -84,30 +98,29 @@ def test_perf_regression(linear_block, benchmark):
 def test_perf_comparative(linear_block):
     compiled_tripy_module, compiled_torch_module, inputs = linear_block

-    # TODO: Change to 100:
-    NUM_ITERS = 1
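+    # Time `func(**kwargs)` over `iterations` calls, after a few warm-up runs that are
+    # excluded from the measurement; returns the last output and the elapsed wall time.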
+    def time_func(func, kwargs, warm_up_runs=2, iterations=100):
+        for _ in range(warm_up_runs):
+            func(**kwargs)

-    # TODO: Add warm-up runs, factor out into function.
-    start = time.perf_counter()
-    for _ in range(NUM_ITERS):
-        tripy_out = compiled_tripy_module(**inputs)
-    end = time.perf_counter()
+        start = time.perf_counter()
+        for _ in range(iterations):
+            out = func(**kwargs)
+        end = time.perf_counter()

-    tripy_time = end - start
+        return out, end - start

-    start = time.perf_counter()
-    for _ in range(NUM_ITERS):
-        torch_out = compiled_torch_module(**{key: torch.from_dlpack(value) for key, value in inputs.items()})
-    end = time.perf_counter()
+    tripy_out, tripy_time = time_func(compiled_tripy_module, inputs)

-    torch_time = end - start
+    # TODO: Figure out how to time torch more accurately:
+    torch_out, torch_time = time_func(
+        compiled_torch_module, {key: torch.from_dlpack(value) for key, value in inputs.items()}
+    )

     # If the outputs don't match, then we're either not comparing apples-to-apples
     # or there is an accuracy bug somewhere - either way we want to catch it here.
-    # TODO: Adjust tolerance per test?
-    # TODO: File accuracy bug? Check if delta is within expected FP16 error - maybe check CUDA vs. torch CPU.
-    assert torch.allclose(torch_out, torch.from_dlpack(tripy_out), atol=0.01)
+    assert torch.allclose(torch_out, torch.from_dlpack(tripy_out))

+    # TODO: Make this threshold adjustable
     # Check that Tripy inference is at least 5% faster
     print(f"Tripy was {torch_time / float(tripy_time)}x faster than Torch")
     assert (tripy_time * 1.05) < torch_time