
Commit 046cbec

adds perf fixture
1 parent 467e132 commit 046cbec

File tree

1 file changed (+57, -44 lines)

tripy/tests/performance/test_perf.py

Lines changed: 57 additions & 44 deletions
@@ -21,58 +21,72 @@
 import tripy as tp
 
 
+def perf_fixture(dtypes):
+    def perf_fixture_impl(func):
+        @pytest.fixture(params=dtypes, scope="session")
+        def wrapped(request):
+            tripy_module, torch_module, input_infos = func(request.param, helper.TORCH_DTYPES[request.param])
+
+            torch_state_dict = {key: torch.from_dlpack(value) for key, value in tripy_module.state_dict().items()}
+            torch_module.load_state_dict(torch_state_dict)
+
+            compiler = tp.Compiler(tripy_module)
+            tripy_compiled = compiler.compile(**input_infos)
+
+            inputs = {
+                key: tp.iota(input_info.shape_bounds.opt, dtype=request.param)
+                for key, input_info in input_infos.items()
+            }
+            for tensor in inputs.values():
+                tensor.eval()
+
+            torch_compiled = torch.compile(torch_module)
+
+            return tripy_compiled, torch_compiled, inputs
+
+        return wrapped
+
+    return perf_fixture_impl
+
+
 # TODO: File issue for FP32:
-@pytest.fixture(params=[pytest.param(tp.float32, marks=pytest.mark.skip("Bug in MLIR-TRT")), tp.float16])
-def linear_block(request):
+@perf_fixture(dtypes=[pytest.param(tp.float32, marks=pytest.mark.skip("Bug in MLIR-TRT")), tp.float16])
+def linear_block(tripy_dtype, torch_dtype):
+
     class LinearBlock(tp.Module):
         def __init__(self):
-            self.layers = [tp.Linear(256, 256, bias=False, dtype=request.param) for _ in range(10)]
+            self.layers = [tp.Linear(256, 256, bias=False, dtype=tripy_dtype) for _ in range(10)]
             for layer in self.layers:
                 # Adjust the weights to prevent FP16 overflows.
-                layer.weight = tp.Parameter((tp.iota((256, 256), dim=1, dtype=request.param) / 256.0) - 0.5)
+                weight = torch.tile(
+                    torch.tensor([[-1, 1], [1, -1]], dtype=torch_dtype, device=torch.device("cuda")), (128, 128)
+                )
+                layer.weight = tp.Parameter(weight)
 
         def __call__(self, input):
             for layer in self.layers:
                 input = layer(input)
-            print(torch.from_dlpack(input))
             return input
 
     class TorchLinearBlock(torch.nn.Module):
         def __init__(self):
             super().__init__()
-            dtype = helper.TORCH_DTYPES[request.param]
             self.layers = torch.nn.ModuleList(
-                [torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=torch.device("cuda")) for _ in range(10)]
+                [
+                    torch.nn.Linear(256, 256, bias=False, dtype=torch_dtype, device=torch.device("cuda"))
+                    for _ in range(10)
+                ]
             )
 
         def forward(self, input):
             for layer in self.layers:
                 input = layer(input)
-            print(input)
             return input
 
     tripy_block = LinearBlock()
     torch_block = TorchLinearBlock()
-
-    torch_state_dict = {key: torch.from_dlpack(value) for key, value in tripy_block.state_dict().items()}
-    torch_block.load_state_dict(torch_state_dict)
-
-    input_infos = {"input": tp.InputInfo(shape=(1024, 256), dtype=request.param)}
-
-    # compiler = tp.Compiler(tripy_block)
-    # tripy_compiled = compiler.compile(**input_infos)
-    tripy_compiled = tripy_block
-
-    inputs = {
-        key: tp.iota(input_info.shape_bounds.opt, dtype=request.param) / 100.0
-        for key, input_info in input_infos.items()
-    }
-    for tensor in inputs.values():
-        tensor.eval()
-
-    torch_compiled = torch.compile(torch_block)
-
-    return tripy_compiled, torch_compiled, inputs
+    input_infos = {"input": tp.InputInfo(shape=(1024, 256), dtype=tripy_dtype)}
+    return tripy_block, torch_block, input_infos
 
 
 def test_perf_regression(linear_block, benchmark):
@@ -84,30 +98,29 @@ def test_perf_regression(linear_block, benchmark):
 def test_perf_comparative(linear_block):
     compiled_tripy_module, compiled_torch_module, inputs = linear_block
 
-    # TODO: Change to 100:
-    NUM_ITERS = 1
+    def time_func(func, kwargs, warm_up_runs=2, iterations=100):
+        for _ in range(warm_up_runs):
+            func(**kwargs)
 
-    # TODO: Add warm-up runs, factor out into function.
-    start = time.perf_counter()
-    for _ in range(NUM_ITERS):
-        tripy_out = compiled_tripy_module(**inputs)
-    end = time.perf_counter()
+        start = time.perf_counter()
+        for _ in range(iterations):
+            out = func(**kwargs)
+        end = time.perf_counter()
 
-    tripy_time = end - start
+        return out, end - start
 
-    start = time.perf_counter()
-    for _ in range(NUM_ITERS):
-        torch_out = compiled_torch_module(**{key: torch.from_dlpack(value) for key, value in inputs.items()})
-    end = time.perf_counter()
+    tripy_out, tripy_time = time_func(compiled_tripy_module, inputs)
 
-    torch_time = end - start
+    # TODO: Figure out how to time torch more accurately:
+    torch_out, torch_time = time_func(
+        compiled_torch_module, {key: torch.from_dlpack(value) for key, value in inputs.items()}
+    )
 
     # If the outputs don't match, then we're either not comparing apples-to-apples
    # or there is an accuracy bug somewhere - either way we want to catch it here.
-    # TODO: Adjust tolerance per test?
-    # TODO: File accuracy bug? Check if delta is within expected FP16 error - maybe check CUDA vs. torch CPU.
-    assert torch.allclose(torch_out, torch.from_dlpack(tripy_out), atol=0.01)
+    assert torch.allclose(torch_out, torch.from_dlpack(tripy_out))
 
+    # TODO: Make this threshold adjustable
     # Check that Tripy inference is at least 5% faster
     print(f"Tripy was {torch_time / float(tripy_time)}x faster than Torch")
     assert (tripy_time * 1.05) < torch_time
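Note: the new perf_fixture decorator centralizes the state-dict transfer, compilation, and input setup, so a benchmark only has to return its Tripy module, its Torch counterpart, and the input metadata. As a rough illustration only (not part of this commit), a second fixture inside this test module might look like the sketch below; the single-Linear modules, names, and shapes are placeholders chosen just to show the expected return values.

# Sketch only: a hypothetical second benchmark reusing perf_fixture.
@perf_fixture(dtypes=[tp.float16])
def single_linear(tripy_dtype, torch_dtype):
    # Placeholder modules; any tp.Module / torch.nn.Module pair whose state
    # dicts line up should work with the fixture's load_state_dict transfer.
    tripy_block = tp.Linear(128, 128, bias=False, dtype=tripy_dtype)
    torch_block = torch.nn.Linear(128, 128, bias=False, dtype=torch_dtype, device=torch.device("cuda"))
    input_infos = {"input": tp.InputInfo(shape=(512, 128), dtype=tripy_dtype)}
    return tripy_block, torch_block, input_infos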

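The time_func helper measures wall-clock time with time.perf_counter(); the in-diff TODO about timing Torch more accurately likely stems from CUDA kernels launching asynchronously, so the Python loop can finish before the GPU work does. One common option (an assumption here, not something this commit implements) is to synchronize the device around the timed region:

import time

import torch


def time_func_synced(func, kwargs, warm_up_runs=2, iterations=100):
    # Sketch only: same structure as time_func, but torch.cuda.synchronize()
    # makes the measurement include queued GPU work, not just launch overhead.
    for _ in range(warm_up_runs):
        func(**kwargs)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iterations):
        out = func(**kwargs)
    torch.cuda.synchronize()
    end = time.perf_counter()
    return out, end - start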