# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import time
+
import pytest
import torch
+from tests import helper

import tripy as tp
-from tests import helper


# TODO: File issue for FP32:
@@ -25,39 +27,52 @@ def linear_block(request):
    class LinearBlock(tp.Module):
        def __init__(self):
            self.layers = [tp.Linear(256, 256, bias=False, dtype=request.param) for _ in range(10)]
+            for layer in self.layers:
+                # Adjust the weights to prevent FP16 overflows.
+                layer.weight = tp.Parameter((tp.iota((256, 256), dim=1, dtype=request.param) / 256.0) - 0.5)

        def __call__(self, input):
            for layer in self.layers:
                input = layer(input)
+                print(torch.from_dlpack(input))
            return input

    class TorchLinearBlock(torch.nn.Module):
        def __init__(self):
            super().__init__()
-            self.layers = [
-                torch.nn.Linear(256, 256, bias=False, dtype=helper.TORCH_DTYPES[request.param]) for _ in range(10)
-            ]
+            dtype = helper.TORCH_DTYPES[request.param]
+            self.layers = torch.nn.ModuleList(
+                [torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=torch.device("cuda")) for _ in range(10)]
+            )

        def forward(self, input):
            for layer in self.layers:
                input = layer(input)
+                print(input)
            return input

    tripy_block = LinearBlock()
    torch_block = TorchLinearBlock()

-    tripy_block.load_from_state_dict(state_dict={key: tp.Parameter(value) for key, value in torch_block.state_dict()})
+    torch_state_dict = {key: torch.from_dlpack(value) for key, value in tripy_block.state_dict().items()}
+    torch_block.load_state_dict(torch_state_dict)

    input_infos = {"input": tp.InputInfo(shape=(1024, 256), dtype=request.param)}

-    compiler = tp.Compiler(tripy_block)
-    tripy_compiled = compiler.compile(**input_infos)
+    # compiler = tp.Compiler(tripy_block)
+    # tripy_compiled = compiler.compile(**input_infos)
+    tripy_compiled = tripy_block

-    inputs = {key: tp.iota(input_info.shape_bounds.opt, dtype=request.param) for key, input_info in input_infos.items()}
+    inputs = {
+        key: tp.iota(input_info.shape_bounds.opt, dtype=request.param) / 100.0
+        for key, input_info in input_infos.items()
+    }
    for tensor in inputs.values():
        tensor.eval()

-    return tripy_compiled, torch_block, inputs
+    torch_compiled = torch.compile(torch_block)
+
+    return tripy_compiled, torch_compiled, inputs


def test_perf_regression(linear_block, benchmark):
@@ -67,8 +82,32 @@ def test_perf_regression(linear_block, benchmark):


def test_perf_comparative(linear_block):
-    compiled_tripy_module, torch_module, inputs = linear_block
+    compiled_tripy_module, compiled_torch_module, inputs = linear_block
+
+    # TODO: Change to 100:
+    NUM_ITERS = 1
+
+    # TODO: Add warm-up runs, factor out into a function.
+    start = time.perf_counter()
+    for _ in range(NUM_ITERS):
+        tripy_out = compiled_tripy_module(**inputs)
+    end = time.perf_counter()
+
+    tripy_time = end - start
+
+    start = time.perf_counter()
+    for _ in range(NUM_ITERS):
+        torch_out = compiled_torch_module(**{key: torch.from_dlpack(value) for key, value in inputs.items()})
+    end = time.perf_counter()
+
+    torch_time = end - start

-    # TODO: Check accuracy - update fixture to make weights same
+    # If the outputs don't match, then we're either not comparing apples-to-apples
+    # or there is an accuracy bug somewhere - either way we want to catch it here.
+    # TODO: Adjust tolerance per test?
+    # TODO: File accuracy bug? Check if delta is within expected FP16 error - maybe check CUDA vs. torch CPU.
+    assert torch.allclose(torch_out, torch.from_dlpack(tripy_out), atol=0.01)

-    # TODO: Compare perf after compiling? Maybe compile in fixture
+    # Check that Tripy inference is at least 5% faster
+    print(f"Tripy was {torch_time / float(tripy_time)}x faster than Torch")
+    assert (tripy_time * 1.05) < torch_time
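The warm-up TODO above could be handled by factoring the two timing loops into a small helper along these lines. This is only a sketch: `run_timed`, `num_warmup`, and the `sync` callable are hypothetical names rather than anything that exists in this repo, and it assumes the modules run on CUDA, so a synchronization step is needed before reading the timer.

```python
import time


def run_timed(module, inputs, num_warmup=5, num_iters=100, sync=None):
    # Hypothetical helper for the "add warm-up runs, factor out into a function" TODO.
    # `sync` is an optional callable (e.g. torch.cuda.synchronize) invoked before each
    # timer read so that asynchronous GPU work has actually finished when we measure.
    for _ in range(num_warmup):
        out = module(**inputs)

    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(num_iters):
        out = module(**inputs)
    if sync is not None:
        sync()
    end = time.perf_counter()

    return out, end - start
```

In `test_perf_comparative` this could replace the two hand-rolled loops, e.g. `torch_out, torch_time = run_timed(compiled_torch_module, torch_inputs, sync=torch.cuda.synchronize)`, where `torch_inputs` is the `torch.from_dlpack` dict currently built inline. For the Tripy module, the output would likely also need to be forced inside the timed loop (e.g. via `.eval()`, as the fixture already does for its inputs) if execution is lazy; otherwise the loop mostly measures graph construction rather than inference.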