 
 import pytest
 import torch
-from tests import helper
 
-import tripy as tp
-
-
-def perf_fixture(dtypes):
-    def perf_fixture_impl(func):
-        @pytest.fixture(params=dtypes, scope="session")
-        def wrapped(request):
-            tripy_module, torch_module, input_infos = func(request.param, helper.TORCH_DTYPES[request.param])
-
-            torch_state_dict = {key: torch.from_dlpack(value) for key, value in tripy_module.state_dict().items()}
-            torch_module.load_state_dict(torch_state_dict)
-
-            compiler = tp.Compiler(tripy_module)
-            tripy_compiled = compiler.compile(**input_infos)
-
-            inputs = {
-                key: tp.iota(input_info.shape_bounds.opt, dtype=request.param)
-                for key, input_info in input_infos.items()
-            }
-            for tensor in inputs.values():
-                tensor.eval()
-
-            torch_compiled = torch.compile(torch_module)
+# Importing the cases module populates PERF_CASES and registers the pytest fixtures:
+from tests.performance.cases import *
+from tests.performance.conftest import PERF_CASES
 
-            return tripy_compiled, torch_compiled, inputs
-
-        return wrapped
-
-    return perf_fixture_impl
-
-
-# TODO: File issue for FP32:
-@perf_fixture(dtypes=[pytest.param(tp.float32, marks=pytest.mark.skip("Bug in MLIR-TRT")), tp.float16])
-def linear_block(tripy_dtype, torch_dtype):
+import tripy as tp
 
-    class LinearBlock(tp.Module):
-        def __init__(self):
-            self.layers = [tp.Linear(256, 256, bias=False, dtype=tripy_dtype) for _ in range(10)]
-            for layer in self.layers:
-                # Adjust the weights to prevent FP16 overflows.
-                weight = torch.tile(
-                    torch.tensor([[-1, 1], [1, -1]], dtype=torch_dtype, device=torch.device("cuda")), (128, 128)
-                )
-                layer.weight = tp.Parameter(weight)
 
-        def __call__(self, input):
-            for layer in self.layers:
-                input = layer(input)
-            return input
+@pytest.mark.parametrize("perf_case", PERF_CASES)
+def test_perf_regression(perf_case, benchmark):
+    compiled_tripy_module, _, inputs = perf_case
 
-    class TorchLinearBlock(torch.nn.Module):
-        def __init__(self):
-            super().__init__()
-            self.layers = torch.nn.ModuleList(
-                [
-                    torch.nn.Linear(256, 256, bias=False, dtype=torch_dtype, device=torch.device("cuda"))
-                    for _ in range(10)
-                ]
-            )
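+    # The benchmark fixture (from pytest-benchmark) calibrates and times this call itself: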
+    benchmark(compiled_tripy_module, **inputs)
 
-        def forward(self, input):
-            for layer in self.layers:
-                input = layer(input)
-            return input
 
-    tripy_block = LinearBlock()
-    torch_block = TorchLinearBlock()
-    input_infos = {"input": tp.InputInfo(shape=(1024, 256), dtype=tripy_dtype)}
-    return tripy_block, torch_block, input_infos
+@pytest.mark.parametrize("perf_case", PERF_CASES)
+def test_perf_comparative(perf_case):
+    compiled_tripy_module, compiled_torch_module, inputs = perf_case
 
+    WARM_UP_RUNS = 2
+    ITERATIONS = 100
 
-def test_perf_regression(linear_block, benchmark):
-    compiled_tripy_module, _, inputs = linear_block
+    # Time Tripy
+    stream = tp.default_stream()
 
-    benchmark(compiled_tripy_module, **inputs)
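+    # Warm up so one-time costs (allocations, caching) don't skew the timed runs: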
+    for _ in range(WARM_UP_RUNS):
+        compiled_tripy_module(**inputs)
+    stream.synchronize()
 
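+    # Execution is asynchronous, so synchronize the stream before reading the host clock: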
+    start = time.perf_counter()
+    for _ in range(ITERATIONS):
+        tripy_out = compiled_tripy_module(**inputs)
+    stream.synchronize()
+    end = time.perf_counter()
 
-def test_perf_comparative(linear_block):
-    compiled_tripy_module, compiled_torch_module, inputs = linear_block
+    # The CUDA events below report milliseconds, so convert to match:
+    tripy_time = (end - start) * 1000
 
-    def time_func(func, kwargs, warm_up_runs=2, iterations=100):
-        for _ in range(warm_up_runs):
-            func(**kwargs)
+    # Time Torch
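+    # Share the same input buffers with Torch via DLPack so both modules see identical data: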
+    torch_inputs = {key: torch.from_dlpack(value).to(device="cuda") for key, value in inputs.items()}
 
-        start = time.perf_counter()
-        for _ in range(iterations):
-            out = func(**kwargs)
-        end = time.perf_counter()
+    with torch.no_grad():
+        for _ in range(WARM_UP_RUNS):
+            compiled_torch_module(**torch_inputs)
+        torch.cuda.synchronize()
 
-        return out, end - start
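+        # CUDA events time the GPU work directly, excluding host-side launch overhead: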
+        start = torch.cuda.Event(enable_timing=True)
+        end = torch.cuda.Event(enable_timing=True)
 
-    tripy_out, tripy_time = time_func(compiled_tripy_module, inputs)
+        start.record()
+        for _ in range(ITERATIONS):
+            torch_out = compiled_torch_module(**torch_inputs)
+        end.record()
+        torch.cuda.synchronize()
 
-    # TODO: Figure out how to time torch more accurately:
-    torch_out, torch_time = time_func(
-        compiled_torch_module, {key: torch.from_dlpack(value) for key, value in inputs.items()}
-    )
+        torch_time = start.elapsed_time(end)
 
 # If the outputs don't match, then we're either not comparing apples-to-apples
 # or there is an accuracy bug somewhere - either way we want to catch it here.
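For reference, the removed linear_block fixture presumably moves into the new tests/performance/cases/ package. Below is a minimal sketch of what such a case module might look like, assuming perf_fixture moved to tests/performance/conftest.py largely unchanged and still returns the (tripy_module, torch_module, input_infos) triple that both tests unpack; the file path, the import location of perf_fixture, and the omission of the FP16 weight adjustment are all assumptions, not part of this commit.

# Hypothetical tests/performance/cases/linear_block.py -- path and helper locations assumed.
import torch

import tripy as tp
from tests.performance.conftest import perf_fixture  # assumed new home of the decorator


@perf_fixture(dtypes=[tp.float16])
def linear_block(tripy_dtype, torch_dtype):
    # The same module pair the old fixture built inline: a stack of ten 256x256 linears.
    class LinearBlock(tp.Module):
        def __init__(self):
            self.layers = [tp.Linear(256, 256, bias=False, dtype=tripy_dtype) for _ in range(10)]

        def __call__(self, input):
            for layer in self.layers:
                input = layer(input)
            return input

    class TorchLinearBlock(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = torch.nn.ModuleList(
                [torch.nn.Linear(256, 256, bias=False, dtype=torch_dtype, device="cuda") for _ in range(10)]
            )

        def forward(self, input):
            for layer in self.layers:
                input = layer(input)
            return input

    input_infos = {"input": tp.InputInfo(shape=(1024, 256), dtype=tripy_dtype)}
    return LinearBlock(), TorchLinearBlock(), input_infos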