Commit e24ad40

adds initial comparative benchmark test
1 parent e1679a2 commit e24ad40

2 files changed: +53 −15 lines changed

tripy/tests/performance/test_perf.py

Lines changed: 51 additions & 12 deletions
@@ -12,11 +12,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import time
+
 import pytest
 import torch
+from tests import helper
 
 import tripy as tp
-from tests import helper
 
 
 # TODO: File issue for FP32:
@@ -25,39 +27,52 @@ def linear_block(request):
     class LinearBlock(tp.Module):
         def __init__(self):
             self.layers = [tp.Linear(256, 256, bias=False, dtype=request.param) for _ in range(10)]
+            for layer in self.layers:
+                # Adjust the weights to prevent FP16 overflows.
+                layer.weight = tp.Parameter((tp.iota((256, 256), dim=1, dtype=request.param) / 256.0) - 0.5)
 
         def __call__(self, input):
             for layer in self.layers:
                 input = layer(input)
+                print(torch.from_dlpack(input))
             return input
 
     class TorchLinearBlock(torch.nn.Module):
         def __init__(self):
             super().__init__()
-            self.layers = [
-                torch.nn.Linear(256, 256, bias=False, dtype=helper.TORCH_DTYPES[request.param]) for _ in range(10)
-            ]
+            dtype = helper.TORCH_DTYPES[request.param]
+            self.layers = torch.nn.ModuleList(
+                [torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=torch.device("cuda")) for _ in range(10)]
+            )
 
         def forward(self, input):
             for layer in self.layers:
                 input = layer(input)
+                print(input)
             return input
 
     tripy_block = LinearBlock()
     torch_block = TorchLinearBlock()
 
-    tripy_block.load_from_state_dict(state_dict={key: tp.Parameter(value) for key, value in torch_block.state_dict()})
+    torch_state_dict = {key: torch.from_dlpack(value) for key, value in tripy_block.state_dict().items()}
+    torch_block.load_state_dict(torch_state_dict)
 
     input_infos = {"input": tp.InputInfo(shape=(1024, 256), dtype=request.param)}
 
-    compiler = tp.Compiler(tripy_block)
-    tripy_compiled = compiler.compile(**input_infos)
+    # compiler = tp.Compiler(tripy_block)
+    # tripy_compiled = compiler.compile(**input_infos)
+    tripy_compiled = tripy_block
 
-    inputs = {key: tp.iota(input_info.shape_bounds.opt, dtype=request.param) for key, input_info in input_infos.items()}
+    inputs = {
+        key: tp.iota(input_info.shape_bounds.opt, dtype=request.param) / 100.0
+        for key, input_info in input_infos.items()
+    }
     for tensor in inputs.values():
         tensor.eval()
 
-    return tripy_compiled, torch_block, inputs
+    torch_compiled = torch.compile(torch_block)
+
+    return tripy_compiled, torch_compiled, inputs
 
 
 def test_perf_regression(linear_block, benchmark):
@@ -67,8 +82,32 @@ def test_perf_regression(linear_block, benchmark):
 
 
 def test_perf_comparative(linear_block):
-    compiled_tripy_module, torch_module, inputs = linear_block
+    compiled_tripy_module, compiled_torch_module, inputs = linear_block
+
+    # TODO: Change to 100:
+    NUM_ITERS = 1
+
+    # TODO: Add warm-up runs, factor out into function.
+    start = time.perf_counter()
+    for _ in range(NUM_ITERS):
+        tripy_out = compiled_tripy_module(**inputs)
+    end = time.perf_counter()
+
+    tripy_time = end - start
+
+    start = time.perf_counter()
+    for _ in range(NUM_ITERS):
+        torch_out = compiled_torch_module(**{key: torch.from_dlpack(value) for key, value in inputs.items()})
+    end = time.perf_counter()
+
+    torch_time = end - start
 
-    # TODO: Check accuracy - update fixture to make weights same
+    # If the outputs don't match, then we're either not comparing apples-to-apples
+    # or there is an accuracy bug somewhere - either way we want to catch it here.
+    # TODO: Adjust tolerance per test?
+    # TODO: File accuracy bug? Check if delta is within expected FP16 error - maybe check CUDA vs. torch CPU.
+    assert torch.allclose(torch_out, torch.from_dlpack(tripy_out), atol=0.01)
 
-    # TODO: Compare perf after compiling? Maybe compile in fixture
+    # Check that Tripy inference is at least 5% faster
+    print(f"Tripy was {torch_time / float(tripy_time)}x faster than Torch")
+    assert (tripy_time * 1.05) < torch_time
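
Note: the in-diff TODO about warm-up runs and factoring the timing loop into a function could be addressed with a helper along these lines. This is only a sketch, not part of the commit: the time_fn name, the warm-up count, and the torch.cuda.synchronize() calls are all assumptions. Synchronization matters because CUDA launches are asynchronous, so wall-clock timing around an unsynchronized launch loop can under-measure the actual GPU work.

import time

import torch


def time_fn(fn, num_warmup=10, num_iters=100):
    # Warm-up iterations let lazy initialization and compilation caches
    # settle before the clock starts.
    for _ in range(num_warmup):
        fn()

    # CUDA launches are asynchronous; synchronize so the measured window
    # covers the GPU work itself rather than just the enqueue time.
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(num_iters):
        out = fn()
    torch.cuda.synchronize()
    end = time.perf_counter()

    return out, (end - start) / num_iters

With such a helper, each timing block in the test would reduce to something like tripy_out, tripy_time = time_fn(lambda: compiled_tripy_module(**inputs)), and likewise for the Torch module.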

tripy/tripy/frontend/module/linear.py

Lines changed: 2 additions & 3 deletions
@@ -98,9 +98,8 @@ def __init__(
 
         self.quant_dtype = quant_dtype
         self.weight_quant_dim = weight_quant_dim
-        if quant_dtype is not None:
-            self.weight_scale = None
-            self.input_scale = None
+        self.weight_scale = None
+        self.input_scale = None
 
     def __call__(self, x: "tripy.Tensor") -> "tripy.Tensor":
         r"""
