@@ -81,7 +81,7 @@ class ConvTestCase:
 @pytest.mark.parametrize("torch_dtype,tp_dtype", DTYPES)
 class TestConvolution:
     @pytest.mark.parametrize("test_case", test_cases_transpose_1d)
-    def test_transposed_convolution_1d(self, torch_dtype, tp_dtype, test_case):
+    def test_transposed_convolution_1d(self, torch_dtype, tp_dtype, test_case, eager_or_compiled):
         if not test_case.torch_pad:
             test_case.torch_pad = 0
         if not test_case.stride:
@@ -129,14 +129,14 @@ def test_transposed_convolution_1d(self, torch_dtype, tp_dtype, test_case):
         conv_layer.bias = tp.cast(tp.Tensor(conv_layer_torch.bias.data), tp_dtype)

         expected = conv_layer_torch(input_torch).to(torch_dtype)
-        output = conv_layer(input)
+        output = eager_or_compiled(conv_layer, input)

-        rtol_ = 1e-3
+        rtol_ = 3e-3
         assert tp.allclose(output, tp.Tensor(expected), rtol=rtol_)
         assert output.shape == list(expected.shape)

     @pytest.mark.parametrize("test_case", test_cases_transpose_2d)
-    def test_transposed_convolution_2d(self, torch_dtype, tp_dtype, test_case):
+    def test_transposed_convolution_2d(self, torch_dtype, tp_dtype, test_case, eager_or_compiled):
         if not test_case.torch_pad:
             test_case.torch_pad = 0
         if not test_case.stride:
@@ -184,14 +184,14 @@ def test_transposed_convolution_2d(self, torch_dtype, tp_dtype, test_case):
         conv_layer.bias = tp.cast(tp.Tensor(conv_layer_torch.bias.data), tp_dtype)

         expected = conv_layer_torch(input_torch).to(torch_dtype)
-        output = conv_layer(input)
+        output = eager_or_compiled(conv_layer, input)

         rtol_ = 1e-2
         assert tp.allclose(output, tp.Tensor(expected), rtol=rtol_)
         assert output.shape == list(expected.shape)

     @pytest.mark.parametrize("test_case", test_cases_transpose_3d)
-    def test_transposed_convolution_3d(self, torch_dtype, tp_dtype, test_case):
+    def test_transposed_convolution_3d(self, torch_dtype, tp_dtype, test_case, eager_or_compiled):
         if not test_case.torch_pad:
             test_case.torch_pad = 0
         if not test_case.stride:
@@ -239,12 +239,12 @@ def test_transposed_convolution_3d(self, torch_dtype, tp_dtype, test_case):
         conv_layer.bias = tp.cast(tp.Tensor(conv_layer_torch.bias.data), tp_dtype)

         expected = conv_layer_torch(input_torch).to(torch_dtype)
-        output = conv_layer(input)
+        output = eager_or_compiled(conv_layer, input)
         rtol_ = 1.3e-6 if tp_dtype == tp.float32 else 1.6e-3
         assert tp.allclose(output, tp.Tensor(expected), rtol=rtol_)
         assert output.shape == list(expected.shape)

-    def test_transposed_equivalency(self, torch_dtype, tp_dtype):
+    def test_transposed_equivalency(self, torch_dtype, tp_dtype, eager_or_compiled):
         input_torch = torch.arange(9, dtype=torch.float32, device=torch.device("cuda")).reshape(*(1, 1, 3, 3))
         input = tp.cast(tp.Tensor(input_torch), tp_dtype)

@@ -277,8 +277,8 @@ def test_transposed_equivalency(self, torch_dtype, tp_dtype):

         expected = conv_layer_torch(input_torch).to(torch_dtype)
         expected_transpose = conv_transpose_layer_torch(input_torch).to(torch_dtype)
-        output = conv_layer(input)
-        output_transpose = conv_transpose_layer(input)
+        output = eager_or_compiled(conv_layer, input)
+        output_transpose = eager_or_compiled(conv_transpose_layer, input)

         rtol_ = 2e-7 if tp_dtype == tp.float32 else 9e-4
         assert tp.allclose(output, tp.Tensor(expected), rtol=rtol_)
@@ -291,7 +291,7 @@ def test_transposed_equivalency(self, torch_dtype, tp_dtype):
         assert list(expected.shape) == list(expected_transpose.shape)

     @pytest.mark.parametrize("test_case", test_cases_transpose_downscale)
-    def test_transposed_downscale(self, torch_dtype, tp_dtype, test_case):
+    def test_transposed_downscale(self, torch_dtype, tp_dtype, test_case, eager_or_compiled):
         input_torch = torch.arange(9, dtype=torch.float32, device=torch.device("cuda")).reshape(*(1, 1, 3, 3))
         input = tp.cast(tp.Tensor(input_torch), tp_dtype)

@@ -320,7 +320,7 @@ def test_transposed_downscale(self, torch_dtype, tp_dtype, test_case):
         conv_layer.weight = tp.cast(tp.Tensor(conv_layer_torch.weight.data), tp_dtype)

         expected = conv_layer_torch(input_torch).to(torch_dtype)
-        output = conv_layer(input)
+        output = eager_or_compiled(conv_layer, input)

         rtol_ = 1e-15 if tp_dtype == tp.float32 else 1e-10
         assert tp.allclose(output, tp.Tensor(expected), rtol=rtol_)
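
Note: every test above now takes an eager_or_compiled fixture and routes module calls through it, so each case exercises both the eager path and the compiled path; the 1-D tolerance is loosened from 1e-3 to 3e-3, presumably to absorb small numeric drift on the compiled path. The fixture itself is not part of this diff. Below is a minimal sketch of what it might look like in conftest.py, assuming a tp.compile entry point that takes tp.InputInfo descriptors for each argument; the names and signatures here are assumptions, not confirmed by this diff.

import pytest
import tripy as tp  # assuming the tp alias used in the tests

@pytest.fixture(params=["eager", "compiled"])
def eager_or_compiled(request):
    # Hypothetical sketch; the real fixture lives in the suite's conftest.py.
    def run(func, *args):
        if request.param == "eager":
            # Eager path: call the tp.Module (or function) directly.
            return func(*args)
        # Compiled path: describe each input, compile once, then execute.
        # Assumes tp.compile/tp.InputInfo semantics; adjust to the real API.
        compiled = tp.compile(
            func,
            args=[tp.InputInfo(list(arg.shape), dtype=arg.dtype) for arg in args],
        )
        return compiled(*args)

    return run

Parametrizing the fixture rather than each test keeps the diff mechanical: every call site changes from conv_layer(input) to eager_or_compiled(conv_layer, input), and pytest doubles the test matrix automatically.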