
Commit 02f66fc

Merge branch 'tripy-resnet50' of github.com:NVIDIA/TensorRT-Incubator into tripy-resnet50
2 parents: b6bd4f4 + 22e1612

36 files changed: +348, -200 lines

tripy/docs/packages.html

Lines changed: 27 additions & 1 deletion
@@ -9,6 +9,12 @@
 
 <body>
 <h1>Package Index</h1>
+<a
+href="https://github.com/NVIDIA/TensorRT-Incubator/releases/download/tripy-v0.0.5/tripy-0.0.5-py3-none-any.whl">tripy-0.0.5-py3-none-any.whl</a><br>
+
+<a
+href="https://github.com/NVIDIA/TensorRT-Incubator/releases/download/tripy-v0.0.4/tripy-0.0.4-py3-none-any.whl">tripy-0.0.4-py3-none-any.whl</a><br>
+
 <a
 href="https://github.com/NVIDIA/TensorRT-Incubator/releases/download/tripy-v0.0.3/tripy-0.0.3-py3-none-any.whl">tripy-0.0.3-py3-none-any.whl</a><br>
 

@@ -102,6 +108,26 @@ <h1>Package Index</h1>
 href="https://github.com/NVIDIA/TensorRT-Incubator/releases/download/mlir-tensorrt-v0.1.36/mlir_tensorrt_runtime-0.1.36+cuda12.trt102-cp312-cp312-linux_x86_64.whl">mlir_tensorrt_runtime-0.1.36+cuda12.trt102-cp312-cp312-linux_x86_64.whl</a><br>
 <a
 href="https://github.com/NVIDIA/TensorRT-Incubator/releases/download/mlir-tensorrt-v0.1.36/mlir_tensorrt_runtime-0.1.36+cuda12.trt102-cp39-cp39-linux_x86_64.whl">mlir_tensorrt_runtime-0.1.36+cuda12.trt102-cp39-cp39-linux_x86_64.whl</a><br>
-</body>
+
+
+<a
+href="https://github.com/NVIDIA/TensorRT-Incubator/releases/download/mlir-tensorrt-v0.1.37/mlir_tensorrt_compiler-0.1.37+cuda12.trt102-cp310-cp310-linux_x86_64.whl">mlir_tensorrt_compiler-0.1.37+cuda12.trt102-cp310-cp310-linux_x86_64.whl</a><br>
+<a
+href="https://github.com/NVIDIA/TensorRT-Incubator/releases/download/mlir-tensorrt-v0.1.37/mlir_tensorrt_compiler-0.1.37+cuda12.trt102-cp311-cp311-linux_x86_64.whl">mlir_tensorrt_compiler-0.1.37+cuda12.trt102-cp311-cp311-linux_x86_64.whl</a><br>
+<a
+href="https://github.com/NVIDIA/TensorRT-Incubator/releases/download/mlir-tensorrt-v0.1.37/mlir_tensorrt_compiler-0.1.37+cuda12.trt102-cp312-cp312-linux_x86_64.whl">mlir_tensorrt_compiler-0.1.37+cuda12.trt102-cp312-cp312-linux_x86_64.whl</a><br>
+<a
+href="https://github.com/NVIDIA/TensorRT-Incubator/releases/download/mlir-tensorrt-v0.1.37/mlir_tensorrt_compiler-0.1.37+cuda12.trt102-cp39-cp39-linux_x86_64.whl">mlir_tensorrt_compiler-0.1.37+cuda12.trt102-cp39-cp39-linux_x86_64.whl</a><br>
+<a
+href="https://github.com/NVIDIA/TensorRT-Incubator/releases/download/mlir-tensorrt-v0.1.37/mlir_tensorrt_runtime-0.1.37+cuda12.trt102-cp310-cp310-linux_x86_64.whl">mlir_tensorrt_runtime-0.1.37+cuda12.trt102-cp310-cp310-linux_x86_64.whl</a><br>
+<a
+href="https://github.com/NVIDIA/TensorRT-Incubator/releases/download/mlir-tensorrt-v0.1.37/mlir_tensorrt_runtime-0.1.37+cuda12.trt102-cp311-cp311-linux_x86_64.whl">mlir_tensorrt_runtime-0.1.37+cuda12.trt102-cp311-cp311-linux_x86_64.whl</a><br>
+<a
+href="https://github.com/NVIDIA/TensorRT-Incubator/releases/download/mlir-tensorrt-v0.1.37/mlir_tensorrt_runtime-0.1.37+cuda12.trt102-cp312-cp312-linux_x86_64.whl">mlir_tensorrt_runtime-0.1.37+cuda12.trt102-cp312-cp312-linux_x86_64.whl</a><br>
+<a
+href="https://github.com/NVIDIA/TensorRT-Incubator/releases/download/mlir-tensorrt-v0.1.37/mlir_tensorrt_runtime-0.1.37+cuda12.trt102-cp39-cp39-linux_x86_64.whl">mlir_tensorrt_runtime-0.1.37+cuda12.trt102-cp39-cp39-linux_x86_64.whl</a><br>
+
+
+</body>
 
 </html>
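
Note: packages.html is a flat index of wheel links, so these packages can be installed with pip's --find-links mechanism, e.g. `python3 -m pip install tripy -f <URL of the hosted packages.html>`. The hosted URL is not shown in this commit, so it is left as a placeholder here.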

tripy/pyproject.toml

Lines changed: 3 additions & 3 deletions
@@ -1,15 +1,15 @@
 [project]
 name = "tripy"
-version = "0.0.3"
+version = "0.0.5"
 authors = [{name = "NVIDIA", email="[email protected]"}]
 description = "Tripy: A Python Programming Model For TensorRT"
 readme = "README.md"
 requires-python = ">= 3.9"
 license = {text = "Apache 2.0"}
 dependencies = [
     "tensorrt~=10.0",
-    "mlir-tensorrt-compiler==0.1.36+cuda12.trt102",
-    "mlir-tensorrt-runtime==0.1.36+cuda12.trt102",
+    "mlir-tensorrt-compiler==0.1.37+cuda12.trt102",
+    "mlir-tensorrt-runtime==0.1.37+cuda12.trt102",
     "colored==2.2.3",
 ]
 

New file (name not shown in this view)

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+#
+# SPDX-FileCopyrightText: Copyright (c) 2024-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pytest
+
+import tripy as tp
+
+
+@pytest.fixture(params=["compile", "eager"])
+def eager_or_compiled(request):
+    def wrapper(func, *args, **kwargs):
+        def get_input_info(x: tp.Tensor):
+            return tp.InputInfo(list(map(int, x.shape)), dtype=x.dtype)
+
+        if request.param == "eager":
+            return func(*args, **kwargs)
+
+        assert request.param == "compile"
+
+        compile_args = []
+        for arg in args:
+            # We don't want to feed DimensionSize as a dynamic input to the compiler (https://github.com/NVIDIA/TensorRT-Incubator/issues/65).
+            if isinstance(arg, tp.Tensor) and not isinstance(arg, tp.DimensionSize):
+                compile_args.append(get_input_info(arg))
+            else:
+                compile_args.append(arg)
+        compile_args = tuple(compile_args)
+
+        compile_kwargs = dict(
+            (
+                k,
+                ((get_input_info(v) if isinstance(v, tp.Tensor) and not isinstance(v, tp.DimensionSize) else v)),
+            )
+            for k, v in kwargs.items()
+        )
+
+        compiled_func = tp.compile(func, args=compile_args, kwargs=compile_kwargs)
+
+        tensor_args = tuple(x for x in args if isinstance(x, tp.Tensor) and not isinstance(x, tp.DimensionSize))
+
+        tensor_kwargs = {
+            k: v for k, v in kwargs.items() if isinstance(v, tp.Tensor) and not isinstance(v, tp.DimensionSize)
+        }
+
+        return compiled_func(*tensor_args, **tensor_kwargs)
+
+    return wrapper
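
For context (not part of the diff): the fixture returns a wrapper, so any test that accepts `eager_or_compiled` runs once per fixture param, once eagerly and once through tp.compile. A minimal sketch of how a test consumes it; the test name and tensor values here are hypothetical:

    import tripy as tp

    # Hypothetical test; pytest injects `eager_or_compiled` from the fixture above.
    def test_add(eager_or_compiled):
        a = tp.Tensor([1.0, 2.0, 3.0])
        b = tp.Tensor([4.0, 5.0, 6.0])

        def add(x, y):
            return x + y

        # "eager" param: calls add(a, b) directly.
        # "compile" param: compiles add() with InputInfo placeholders derived
        # from a and b, then runs the compiled function on the same tensors.
        out = eager_or_compiled(add, a, b)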

tripy/tests/integration/test_batchnorm.py

Lines changed: 2 additions & 2 deletions
@@ -26,7 +26,7 @@ class TestBatchNorm:
 
     @pytest.mark.parametrize("torch_dtype, tp_dtype", DTYPES)
     @pytest.mark.parametrize("input_shape", [(2, 2, 2, 2)])
-    def test_batchnorm_accuracy(self, torch_dtype, tp_dtype, input_shape):
+    def test_batchnorm_accuracy(self, torch_dtype, tp_dtype, input_shape, eager_or_compiled):
         eps = 1e-5
         num_features = input_shape[1]  # Number of channels in the input tensor
         batchnorm = torch.nn.BatchNorm2d(num_features=num_features, eps=eps, dtype=torch_dtype)
@@ -45,7 +45,7 @@ def test_batchnorm_accuracy(self, torch_dtype, tp_dtype, input_shape):
         input = torch.randn(input_shape, dtype=torch_dtype).to("cuda")
         tp_input = tp.Tensor(input, dtype=tp_dtype)
 
-        output = tp_batchnorm(tp_input)
+        output = eager_or_compiled(tp_batchnorm, tp_input)
 
         batchnorm.to("cuda").eval()
         with torch.no_grad():
tripy/tests/integration/test_cast.py

Lines changed: 13 additions & 14 deletions
@@ -30,54 +30,53 @@ class TestCast:
         [
             (np.int32, np.float32),
             (np.float32, np.int32),
-            (np.int64, np.float32),
-            (np.float32, np.int64),
-            (np.int64, np.int32),
-            (np.int64, np.int8),
             (np.int32, np.int8),
             (np.float32, np.int8),
-            (np.int8, np.int64),
             (np.int8, np.int32),
             (np.int8, np.float32),
             # important to test conversion into bool because default StableHLO semantics
             # are simply to truncate to i1, which is not desirable
             (np.float32, bool),
             (np.int32, bool),
-            (np.int64, bool),
             # requires a dequantization first
             # TODO(#219): Dequantize fails with dynamic shapes
             # (np.int8, bool),
         ],
     )
-    def test_cast(self, input_dtype, target_dtype):
+    def test_cast(self, input_dtype, target_dtype, eager_or_compiled):
         tp_input_dtype = NUMPY_TO_TRIPY[input_dtype]
         tp_target_dtype = NUMPY_TO_TRIPY[target_dtype]
 
         # TODO(#222): Integer casts with negative numbers fail in many cases
         input_tensor = tp.Tensor([0, 1, 2], dtype=tp_input_dtype)
         np_input = cp.from_dlpack(input_tensor).get()
-        output = tp.cast(input_tensor, tp_target_dtype)
+        output = eager_or_compiled(tp.cast, input_tensor, tp_target_dtype)
 
         assert np.array_equal(cp.from_dlpack(output).get(), np_input.astype(target_dtype))
 
     # these dtypes don't have analogues in numpy
     @pytest.mark.parametrize("source_dtype", [pytest.param(tp.float8, marks=skip_if_older_than_sm89), tp.int4])
-    def test_cast_quantized_dtypes_into_bool(self, source_dtype):
+    def test_cast_quantized_dtypes_into_bool(self, source_dtype, eager_or_compiled):
         # TODO(#223): Using an odd size leads to a strange crash, so can't just use [-1.0, 0.0, 1.0]
         input_tensor = tp.Tensor([-1.0, 0.0, 0.0, 1.0], dtype=tp.float32)
-        q = tp.quantize(input_tensor, scale=1.0, dtype=source_dtype)
-        output = tp.cast(q, tp.bool)
+
+        def func(input):
+            q = tp.quantize(input, scale=1.0, dtype=source_dtype)
+            output = tp.cast(q, tp.bool)
+            return output
+
+        output = eager_or_compiled(func, input_tensor)
         assert cp.from_dlpack(output).get().tolist() == [True, False, False, True]
 
-    @pytest.mark.parametrize("target_dtype", [np.float32, np.int32, np.int64, np.int8])
-    def test_cast_from_bool(self, target_dtype):
+    @pytest.mark.parametrize("target_dtype", [np.float32, np.int32, np.int8])
+    def test_cast_from_bool(self, target_dtype, eager_or_compiled):
         tp_target_dtype = NUMPY_TO_TRIPY[target_dtype]
 
         # in principle, it is not important what *specific* values we convert to,
         # so long as false is mapped to 0 and true to nonzero
         input_tensor = tp.Tensor([False, True], dtype=tp.bool)
         np_input = cp.from_dlpack(input_tensor).get()
-        output = tp.cast(input_tensor, tp_target_dtype)
+        output = eager_or_compiled(tp.cast, input_tensor, tp_target_dtype)
 
         tp_compare_to_zero = cp.from_dlpack(output).get() == 0
         np_compare_to_zero = np_input.astype(target_dtype) == 0
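
For context (not part of the diff): the `func` wrapper introduced in test_cast_quantized_dtypes_into_bool exists because tp.compile traces a single callable, so chaining quantize and cast inside one function lets the fixture compile the whole sequence as a unit. A minimal sketch of what the fixture effectively does in compile mode; the shape and dtypes here are illustrative:

    import tripy as tp

    def func(input):
        q = tp.quantize(input, scale=1.0, dtype=tp.int4)
        return tp.cast(q, tp.bool)

    # Mirrors the fixture: tensor arguments are replaced by InputInfo placeholders.
    compiled = tp.compile(func, args=[tp.InputInfo([4], dtype=tp.float32)])
    out = compiled(tp.Tensor([-1.0, 0.0, 0.0, 1.0]))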

tripy/tests/integration/test_concatenate.py

Lines changed: 4 additions & 4 deletions
@@ -33,9 +33,9 @@ class TestConcatenate:
             ([(2, 3, 4)], 0),
         ],
     )
-    def test_concat(self, tensor_shapes, dim):
+    def test_concat(self, tensor_shapes, dim, eager_or_compiled):
         tensors = [tp.ones(shape) for shape in tensor_shapes]
-        out = tp.concatenate(tensors, dim=dim)
+        out = eager_or_compiled(tp.concatenate, tensors, dim=dim)
         assert np.array_equal(
             cp.from_dlpack(out).get(), np.concatenate([np.ones(shape) for shape in tensor_shapes], axis=dim)
         )
@@ -44,8 +44,8 @@ def test_concat(self, tensor_shapes, dim):
         "tensor_shapes, dim",
         [([(2, 3, 4), (2, 4, 4)], 0), ([(4, 5, 6), (4, 1, 6)], -1)],
     )
-    def test_negative_concat(self, tensor_shapes, dim):
+    def test_negative_concat(self, tensor_shapes, dim, eager_or_compiled):
         tensors = [tp.ones(shape) for shape in tensor_shapes]
         with helper.raises(tp.TripyException, match=f"not compatible at non-concat index"):
-            out = tp.concatenate(tensors, dim=dim)
+            out = eager_or_compiled(tp.concatenate, tensors, dim=dim)
             print(out)

tripy/tests/integration/test_conv.py

Lines changed: 8 additions & 8 deletions
@@ -75,7 +75,7 @@ class ConvTestCase:
 @pytest.mark.parametrize("torch_dtype,tp_dtype", DTYPES)
 class TestConvolution:
     @pytest.mark.parametrize("test_case", test_cases_1d)
-    def test_convolution_1d(self, torch_dtype, tp_dtype, test_case):
+    def test_convolution_1d(self, torch_dtype, tp_dtype, test_case, eager_or_compiled):
         if not test_case.torch_pad:
             test_case.torch_pad = 0
         if not test_case.stride:
@@ -122,7 +122,7 @@ def test_convolution_1d(self, torch_dtype, tp_dtype, test_case):
         conv_layer.bias = tp.cast(tp.Tensor(conv_layer_torch.bias.data), tp_dtype)
 
         expected = conv_layer_torch(input_torch).to(torch_dtype)
-        output = conv_layer(input)
+        output = eager_or_compiled(conv_layer, input)
 
         # FP32 kernel seems to lose some precision, and FP16 needs to be run in FP32 on torch
         rtol_ = 4e-5 if tp_dtype == tp.float32 else 1e-3
@@ -131,7 +131,7 @@ def test_convolution_1d(self, torch_dtype, tp_dtype, test_case):
         assert list(output_torch.shape) == list(expected.shape)
 
     @pytest.mark.parametrize("test_case", test_cases_2d)
-    def test_convolution_2d(self, torch_dtype, tp_dtype, test_case):
+    def test_convolution_2d(self, torch_dtype, tp_dtype, test_case, eager_or_compiled):
         if not test_case.torch_pad:
             test_case.torch_pad = 0
         if not test_case.stride:
@@ -178,15 +178,15 @@ def test_convolution_2d(self, torch_dtype, tp_dtype, test_case):
         conv_layer.bias = tp.cast(tp.Tensor(conv_layer_torch.bias.data), tp_dtype)
 
         expected = conv_layer_torch(input_torch).to(torch_dtype)
-        output = conv_layer(input)
+        output = eager_or_compiled(conv_layer, input)
 
         rtol_ = 2e-7 if tp_dtype == tp.float32 else 1.5e-3
         output_torch = torch.from_dlpack(output)
         assert torch.allclose(output_torch, expected, rtol=rtol_)
         assert list(output_torch.shape) == list(expected.shape)
 
     @pytest.mark.parametrize("test_case", test_cases_3d)
-    def test_convolution_3d(self, torch_dtype, tp_dtype, test_case):
+    def test_convolution_3d(self, torch_dtype, tp_dtype, test_case, eager_or_compiled):
         pytest.skip("TODO (#260): Fix accuracy bugs in 3D conv")
         if not test_case.torch_pad:
             test_case.torch_pad = 0
@@ -245,14 +245,14 @@ def test_convolution_3d(self, torch_dtype, tp_dtype, test_case):
             return
 
         expected = conv_layer_torch(input_torch).to(torch_dtype)
-        output = conv_layer(input)
+        output = eager_or_compiled(conv_layer, input)
 
         rtol_ = 2e-4 if tp_dtype == tp.float32 else 1.4e-3  # 3d conv has greater accumulation error
         output_torch = torch.from_dlpack(output)
         assert torch.allclose(output_torch, expected, rtol=rtol_)
         assert list(output_torch.shape) == list(expected.shape)
 
-    def test_uneven_padding(self, torch_dtype, tp_dtype):
+    def test_uneven_padding(self, torch_dtype, tp_dtype, eager_or_compiled):
         input_torch = torch.arange(200, dtype=torch.float32, device=torch.device("cuda")).reshape(*(2, 4, 5, 5))
         input = tp.cast(tp.Tensor(input_torch), tp_dtype)
 
@@ -282,7 +282,7 @@ def test_uneven_padding(self, torch_dtype, tp_dtype):
 
         input_torch = torch_pad(input_torch)
         expected = conv_layer_torch(input_torch).to(torch_dtype)
-        output = conv_layer(input)
+        output = eager_or_compiled(conv_layer, input)
 
         rtol_ = 2e-7 if tp_dtype == tp.float32 else 2e-3
         output_torch = torch.from_dlpack(output)

tripy/tests/integration/test_conv_transpose.py

Lines changed: 12 additions & 12 deletions
@@ -81,7 +81,7 @@ class ConvTestCase:
 @pytest.mark.parametrize("torch_dtype,tp_dtype", DTYPES)
 class TestConvolution:
     @pytest.mark.parametrize("test_case", test_cases_transpose_1d)
-    def test_transposed_convolution_1d(self, torch_dtype, tp_dtype, test_case):
+    def test_transposed_convolution_1d(self, torch_dtype, tp_dtype, test_case, eager_or_compiled):
         if not test_case.torch_pad:
             test_case.torch_pad = 0
         if not test_case.stride:
@@ -129,14 +129,14 @@ def test_transposed_convolution_1d(self, torch_dtype, tp_dtype, test_case):
         conv_layer.bias = tp.cast(tp.Tensor(conv_layer_torch.bias.data), tp_dtype)
 
         expected = conv_layer_torch(input_torch).to(torch_dtype)
-        output = conv_layer(input)
+        output = eager_or_compiled(conv_layer, input)
 
-        rtol_ = 1e-3
+        rtol_ = 3e-3
         assert tp.allclose(output, tp.Tensor(expected), rtol=rtol_)
         assert output.shape == list(expected.shape)
 
     @pytest.mark.parametrize("test_case", test_cases_transpose_2d)
-    def test_transposed_convolution_2d(self, torch_dtype, tp_dtype, test_case):
+    def test_transposed_convolution_2d(self, torch_dtype, tp_dtype, test_case, eager_or_compiled):
         if not test_case.torch_pad:
             test_case.torch_pad = 0
         if not test_case.stride:
@@ -184,14 +184,14 @@ def test_transposed_convolution_2d(self, torch_dtype, tp_dtype, test_case):
         conv_layer.bias = tp.cast(tp.Tensor(conv_layer_torch.bias.data), tp_dtype)
 
         expected = conv_layer_torch(input_torch).to(torch_dtype)
-        output = conv_layer(input)
+        output = eager_or_compiled(conv_layer, input)
 
         rtol_ = 1e-2
         assert tp.allclose(output, tp.Tensor(expected), rtol=rtol_)
         assert output.shape == list(expected.shape)
 
     @pytest.mark.parametrize("test_case", test_cases_transpose_3d)
-    def test_transposed_convolution_3d(self, torch_dtype, tp_dtype, test_case):
+    def test_transposed_convolution_3d(self, torch_dtype, tp_dtype, test_case, eager_or_compiled):
         if not test_case.torch_pad:
             test_case.torch_pad = 0
         if not test_case.stride:
@@ -239,12 +239,12 @@ def test_transposed_convolution_3d(self, torch_dtype, tp_dtype, test_case):
         conv_layer.bias = tp.cast(tp.Tensor(conv_layer_torch.bias.data), tp_dtype)
 
         expected = conv_layer_torch(input_torch).to(torch_dtype)
-        output = conv_layer(input)
+        output = eager_or_compiled(conv_layer, input)
         rtol_ = 1.3e-6 if tp_dtype == tp.float32 else 1.6e-3
         assert tp.allclose(output, tp.Tensor(expected), rtol=rtol_)
         assert output.shape == list(expected.shape)
 
-    def test_transposed_equivalency(self, torch_dtype, tp_dtype):
+    def test_transposed_equivalency(self, torch_dtype, tp_dtype, eager_or_compiled):
         input_torch = torch.arange(9, dtype=torch.float32, device=torch.device("cuda")).reshape(*(1, 1, 3, 3))
         input = tp.cast(tp.Tensor(input_torch), tp_dtype)
 
@@ -277,8 +277,8 @@ def test_transposed_equivalency(self, torch_dtype, tp_dtype):
 
         expected = conv_layer_torch(input_torch).to(torch_dtype)
         expected_transpose = conv_transpose_layer_torch(input_torch).to(torch_dtype)
-        output = conv_layer(input)
-        output_transpose = conv_transpose_layer(input)
+        output = eager_or_compiled(conv_layer, input)
+        output_transpose = eager_or_compiled(conv_transpose_layer, input)
 
         rtol_ = 2e-7 if tp_dtype == tp.float32 else 9e-4
         assert tp.allclose(output, tp.Tensor(expected), rtol=rtol_)
@@ -291,7 +291,7 @@ def test_transposed_equivalency(self, torch_dtype, tp_dtype):
         assert list(expected.shape) == list(expected_transpose.shape)
 
     @pytest.mark.parametrize("test_case", test_cases_transpose_downscale)
-    def test_transposed_downscale(self, torch_dtype, tp_dtype, test_case):
+    def test_transposed_downscale(self, torch_dtype, tp_dtype, test_case, eager_or_compiled):
         input_torch = torch.arange(9, dtype=torch.float32, device=torch.device("cuda")).reshape(*(1, 1, 3, 3))
         input = tp.cast(tp.Tensor(input_torch), tp_dtype)
 
@@ -320,7 +320,7 @@ def test_transposed_downscale(self, torch_dtype, tp_dtype, test_case):
         conv_layer.weight = tp.cast(tp.Tensor(conv_layer_torch.weight.data), tp_dtype)
 
         expected = conv_layer_torch(input_torch).to(torch_dtype)
-        output = conv_layer(input)
+        output = eager_or_compiled(conv_layer, input)
 
         rtol_ = 1e-15 if tp_dtype == tp.float32 else 1e-10
         assert tp.allclose(output, tp.Tensor(expected), rtol=rtol_)
