Commit f2db538

fixes torch benchmarking, refactors infra
1 parent 663d3c1 commit f2db538

File tree

6 files changed: 205 additions & 84 deletions

tripy/pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -56,6 +56,7 @@ test = [
     "pytest-cov==4.1.0",
     "pytest-xdist==3.6.1",
     "pytest-benchmark==4.0.0",
+    "pytest-lazy-fixture==0.6.3",
     # Triton is required for torch.compile
     "triton==3.0.0",
     "snakeviz==2.2.0",
tripy/tests/performance/cases/__init__.py

Lines changed: 54 additions & 0 deletions

@@ -0,0 +1,54 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = []


# In order to make the pytest fixtures defined in this submodule visible, we
# need to import them in the test using their function names. To do so, we can
# export them via this file by making them local variables and adding them to `__all__`.
#
# Note that just importing the module is sufficient to update PERF_CASES, but does
# not make the actual fixture function visible to pytest.
def __discover_modules():
    import importlib
    import pkgutil

    mods = [importlib.import_module("tests.performance.cases")]
    while mods:
        mod = mods.pop(0)

        yield mod

        if hasattr(mod, "__path__"):
            mods.extend(
                [
                    importlib.import_module(f"{mod.__name__}.{submod.name}")
                    for submod in pkgutil.iter_modules(mod.__path__)
                ]
            )


modules = list(__discover_modules())[1:]

# Discover and import all perf fixtures.
from tests.performance.conftest import PERF_CASES

__perf_case_names = {case.name for case in PERF_CASES}

for mod in modules:
    for name, obj in mod.__dict__.items():
        if name in __perf_case_names:
            locals()[name] = obj
            __all__.append(name)
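
The discovery helper above is a standard breadth-first walk over a package with pkgutil. For reference, a self-contained sketch of the same pattern applied to a hypothetical package named "mypkg":

import importlib
import pkgutil


def iter_submodules(root_name):
    # Breadth-first walk: yield the root package, then every importable
    # module found beneath it.
    mods = [importlib.import_module(root_name)]
    while mods:
        mod = mods.pop(0)
        yield mod
        # Only packages have a __path__; plain modules contain no submodules.
        if hasattr(mod, "__path__"):
            mods.extend(
                importlib.import_module(f"{mod.__name__}.{sub.name}")
                for sub in pkgutil.iter_modules(mod.__path__)
            )


# Drop the first entry (the root package itself), mirroring the [1:] slice
# in the __init__.py above.
print([mod.__name__ for mod in iter_submodules("mypkg")][1:])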
tripy/tests/performance/cases/linear_block.py

Lines changed: 59 additions & 0 deletions

@@ -0,0 +1,59 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from tests.performance.conftest import perf_fixture

import tripy as tp


# TODO: File issue for FP32:
@perf_fixture(dtypes=[pytest.param(tp.float32, marks=pytest.mark.skip("Bug in MLIR-TRT")), tp.float16])
def linear_block(tripy_dtype, torch_dtype):

    class LinearBlock(tp.Module):
        def __init__(self):
            self.layers = [tp.Linear(256, 256, bias=False, dtype=tripy_dtype) for _ in range(10)]
            for layer in self.layers:
                # Adjust the weights to prevent FP16 overflows.
                weight = torch.tile(
                    torch.tensor([[-1, 1], [1, -1]], dtype=torch_dtype, device=torch.device("cuda")), (128, 128)
                )
                layer.weight = tp.Parameter(weight)

        def __call__(self, input):
            for layer in self.layers:
                input = layer(input)
            return input

    class TorchLinearBlock(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = torch.nn.ModuleList(
                [
                    torch.nn.Linear(256, 256, bias=False, dtype=torch_dtype, device=torch.device("cuda"))
                    for _ in range(10)
                ]
            )

        def forward(self, input):
            for layer in self.layers:
                input = layer(input)
            return input

    tripy_block = LinearBlock()
    torch_block = TorchLinearBlock()
    input_infos = {"input": tp.InputInfo(shape=(1024, 256), dtype=tripy_dtype)}
    return tripy_block, torch_block, input_infos
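
The tiled [[-1, 1], [1, -1]] pattern above produces a 256x256 matrix of alternating +1/-1 entries, so partial sums cancel instead of compounding across the ten stacked layers. A quick standalone check of that property (illustrative only, runs on CPU):

import torch

weight = torch.tile(torch.tensor([[-1.0, 1.0], [1.0, -1.0]]), (128, 128))
print(weight.shape)  # torch.Size([256, 256])

# Each row holds 128 entries of +1 and 128 of -1, so a linear layer applying
# this weight (x @ weight.T) cannot blow past the FP16 range: for an all-ones
# input the terms cancel exactly.
x = torch.ones(1, 256)
print((x @ weight.T).abs().max())  # tensor(0.)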
tripy/tests/performance/conftest.py

Lines changed: 51 additions & 0 deletions

@@ -0,0 +1,51 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from tests import helper

import tripy as tp

PERF_CASES = []


def perf_fixture(dtypes):
    def perf_fixture_impl(func):
        PERF_CASES.append(pytest.lazy_fixture(func.__qualname__))

        @pytest.fixture(params=dtypes, scope="session")
        def wrapped(request):
            tripy_module, torch_module, input_infos = func(request.param, helper.TORCH_DTYPES[request.param])

            torch_state_dict = {key: torch.from_dlpack(value) for key, value in tripy_module.state_dict().items()}
            torch_module.load_state_dict(torch_state_dict)

            compiler = tp.Compiler(tripy_module)
            tripy_compiled = compiler.compile(**input_infos)

            inputs = {
                key: tp.iota(input_info.shape_bounds.opt, dtype=request.param)
                for key, input_info in input_infos.items()
            }
            for tensor in inputs.values():
                tensor.eval()

            torch_compiled = torch.compile(torch_module)

            return tripy_compiled, torch_compiled, inputs

        return wrapped

    return perf_fixture_impl
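
With this decorator in place, adding a new benchmark is a matter of dropping a fixture into the cases package. A hedged sketch of what that could look like (the my_case fixture and its shapes are hypothetical; the linear_block file earlier in this diff is the real example):

import torch

import tripy as tp
from tests.performance.conftest import perf_fixture


@perf_fixture(dtypes=[tp.float32, tp.float16])
def my_case(tripy_dtype, torch_dtype):
    # Return matching tripy/torch modules plus input metadata; the wrapper
    # above copies the tripy weights into torch, compiles both modules, and
    # materializes the inputs.
    tripy_module = tp.Linear(64, 64, dtype=tripy_dtype)
    torch_module = torch.nn.Linear(64, 64, dtype=torch_dtype, device=torch.device("cuda"))
    input_infos = {"input": tp.InputInfo(shape=(32, 64), dtype=tripy_dtype)}
    return tripy_module, torch_module, input_infos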

tripy/tests/performance/test_perf.py

Lines changed: 39 additions & 82 deletions
@@ -16,105 +16,62 @@

 import pytest
 import torch
-from tests import helper

-import tripy as tp
-
-
-def perf_fixture(dtypes):
-    def perf_fixture_impl(func):
-        @pytest.fixture(params=dtypes, scope="session")
-        def wrapped(request):
-            tripy_module, torch_module, input_infos = func(request.param, helper.TORCH_DTYPES[request.param])
-
-            torch_state_dict = {key: torch.from_dlpack(value) for key, value in tripy_module.state_dict().items()}
-            torch_module.load_state_dict(torch_state_dict)
-
-            compiler = tp.Compiler(tripy_module)
-            tripy_compiled = compiler.compile(**input_infos)
-
-            inputs = {
-                key: tp.iota(input_info.shape_bounds.opt, dtype=request.param)
-                for key, input_info in input_infos.items()
-            }
-            for tensor in inputs.values():
-                tensor.eval()
-
-            torch_compiled = torch.compile(torch_module)
+# Need to import cases in order to populate PERF_CASES and load pytest fixtures
+from tests.performance.cases import *
+from tests.performance.conftest import PERF_CASES

-            return tripy_compiled, torch_compiled, inputs
-
-        return wrapped
-
-    return perf_fixture_impl
-
-
-# TODO: File issue for FP32:
-@perf_fixture(dtypes=[pytest.param(tp.float32, marks=pytest.mark.skip("Bug in MLIR-TRT")), tp.float16])
-def linear_block(tripy_dtype, torch_dtype):
+import tripy as tp

-    class LinearBlock(tp.Module):
-        def __init__(self):
-            self.layers = [tp.Linear(256, 256, bias=False, dtype=tripy_dtype) for _ in range(10)]
-            for layer in self.layers:
-                # Adjust the weights to prevent FP16 overflows.
-                weight = torch.tile(
-                    torch.tensor([[-1, 1], [1, -1]], dtype=torch_dtype, device=torch.device("cuda")), (128, 128)
-                )
-                layer.weight = tp.Parameter(weight)

-        def __call__(self, input):
-            for layer in self.layers:
-                input = layer(input)
-            return input
+@pytest.mark.parametrize("perf_case", PERF_CASES)
+def test_perf_regression(perf_case, benchmark):
+    compiled_tripy_module, _, inputs = perf_case

-    class TorchLinearBlock(torch.nn.Module):
-        def __init__(self):
-            super().__init__()
-            self.layers = torch.nn.ModuleList(
-                [
-                    torch.nn.Linear(256, 256, bias=False, dtype=torch_dtype, device=torch.device("cuda"))
-                    for _ in range(10)
-                ]
-            )
+    benchmark(compiled_tripy_module, **inputs)

-        def forward(self, input):
-            for layer in self.layers:
-                input = layer(input)
-            return input

-    tripy_block = LinearBlock()
-    torch_block = TorchLinearBlock()
-    input_infos = {"input": tp.InputInfo(shape=(1024, 256), dtype=tripy_dtype)}
-    return tripy_block, torch_block, input_infos
+@pytest.mark.parametrize("perf_case", PERF_CASES)
+def test_perf_comparative(perf_case):
+    compiled_tripy_module, compiled_torch_module, inputs = perf_case

+    WARM_UP_RUNS = 2
+    ITERATIONS = 100

-def test_perf_regression(linear_block, benchmark):
-    compiled_tripy_module, _, inputs = linear_block
+    # Time Tripy
+    stream = tp.default_stream()

-    benchmark(compiled_tripy_module, **inputs)
+    for _ in range(WARM_UP_RUNS):
+        compiled_tripy_module(**inputs)
+    stream.synchronize()

+    start = time.perf_counter()
+    for _ in range(ITERATIONS):
+        tripy_out = compiled_tripy_module(**inputs)
+    stream.synchronize()
+    end = time.perf_counter()

-def test_perf_comparative(linear_block):
-    compiled_tripy_module, compiled_torch_module, inputs = linear_block
+    # Torch will report time in ms:
+    tripy_time = (end - start) * 1000

-    def time_func(func, kwargs, warm_up_runs=2, iterations=100):
-        for _ in range(warm_up_runs):
-            func(**kwargs)
+    # Time Torch
+    torch_inputs = {key: torch.from_dlpack(value).to(device="cuda") for key, value in inputs.items()}

-        start = time.perf_counter()
-        for _ in range(iterations):
-            out = func(**kwargs)
-        end = time.perf_counter()
+    with torch.no_grad():
+        for _ in range(WARM_UP_RUNS):
+            compiled_torch_module(**torch_inputs)
+        torch.cuda.synchronize()

-        return out, end - start
+        start = torch.cuda.Event(enable_timing=True)
+        end = torch.cuda.Event(enable_timing=True)

-    tripy_out, tripy_time = time_func(compiled_tripy_module, inputs)
+        start.record()
+        for _ in range(ITERATIONS):
+            torch_out = compiled_torch_module(**torch_inputs)
+        end.record()
+        torch.cuda.synchronize()

-    # TODO: Figure out how to time torch more accurately:
-    torch_out, torch_time = time_func(
-        compiled_torch_module, {key: torch.from_dlpack(value) for key, value in inputs.items()}
-    )
+        torch_time = start.elapsed_time(end)

     # If the outputs don't match, then we're either not comparing apples-to-apples
     # or there is an accuracy bug somewhere - either way we want to catch it here.
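
The torch-side benchmarking fix here is the switch from wall-clock timing to CUDA events. Torch kernels launch asynchronously, so wrapping an unsynchronized launch loop in time.perf_counter() mostly measures kernel-launch overhead rather than GPU work, which is what the removed "TODO: Figure out how to time torch more accurately" pointed at. A generic sketch of the event-timing pattern, independent of this test:

import torch


def time_gpu_ms(fn, iterations=100):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()  # enqueued on the current CUDA stream
    for _ in range(iterations):
        fn()
    end.record()
    torch.cuda.synchronize()  # wait until `end` has actually been reached
    return start.elapsed_time(end)  # elapsed GPU time in milliseconds


# Usage with a hypothetical compiled module and input tensor:
# ms = time_gpu_ms(lambda: compiled_module(x))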

tripy/tripy/backend/mlir/executor.py

Lines changed: 1 addition & 2 deletions
@@ -24,7 +24,7 @@
 from tripy.backend.utils import TensorInfo
 from tripy.common import datatype, device
 from tripy.common.exception import raise_error
-from tripy.utils import log_time, make_tuple
+from tripy.utils import make_tuple


 class Executor:
@@ -146,7 +146,6 @@ def stream(self):
     def stream(self, stream):
         self._stream = stream

-    @log_time
     def execute(self, output_devices=List[device], inputs: List["Tensor"] = []) -> List[runtime.MemRefValue]:
         from tripy.frontend.trace.ops import Storage

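Dropping @log_time removes per-call bookkeeping from Executor.execute, the hot path that every benchmark iteration goes through. For context, a timing decorator of this general kind looks roughly like the sketch below; this is an assumption for illustration, not tripy's actual log_time implementation from tripy.utils:

import functools
import time


def log_time(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        # This bookkeeping runs on every call, which skews per-iteration
        # numbers in tight benchmark loops.
        print(f"{func.__qualname__} took {(time.perf_counter() - start) * 1000:.3f} ms")
        return result

    return wrapper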