@@ -39,7 +39,7 @@
 from torchvision.transforms._functional_tensor import _max_value as get_max_value
 from torchvision.transforms.functional import pil_modes_mapping
 from torchvision.transforms.v2 import functional as F
-from torchvision.transforms.v2.functional._utils import _KERNEL_REGISTRY
+from torchvision.transforms.v2.functional._utils import _get_kernel, _KERNEL_REGISTRY, _noop, _register_kernel_internal


 @pytest.fixture(autouse=True)
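For orientation, the `_utils` helpers imported above sit on top of a two-level registry: `_KERNEL_REGISTRY` maps a dispatcher to a dict from input type to kernel, which is how the tests below access it via `_KERNEL_REGISTRY[dispatcher][input_type]`. A minimal sketch of that mechanism, assuming MRO-based resolution (all names here are illustrative, not torchvision's actual implementation):

```python
from collections import defaultdict

_REGISTRY = defaultdict(dict)  # dispatcher -> {input_type: kernel}

def register_kernel(dispatcher, input_type):
    # Decorator that files a kernel under (dispatcher, input_type).
    def decorator(kernel):
        _REGISTRY[dispatcher][input_type] = kernel
        return kernel
    return decorator

def get_kernel(dispatcher, input_type):
    # Walk the MRO so a subclass of a registered type resolves to the
    # kernel of its closest registered superclass.
    for cls in input_type.__mro__:
        if cls in _REGISTRY[dispatcher]:
            return _REGISTRY[dispatcher][cls]
    raise TypeError(f"{dispatcher.__name__} supports no inputs of type {input_type}")
```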
@@ -173,59 +173,32 @@ def _check_dispatcher_scripted_smoke(dispatcher, input, *args, **kwargs):
     dispatcher_scripted(input.as_subclass(torch.Tensor), *args, **kwargs)


-def _check_dispatcher_dispatch(dispatcher, kernel, input, *args, **kwargs):
-    """Checks if the dispatcher correctly dispatches the input to the corresponding kernel and that the input type is
-    preserved in doing so. For bounding boxes also checks that the format is preserved.
-    """
-    input_type = type(input)
-
-    if isinstance(input, datapoints.Datapoint):
-        wrapped_kernel = _KERNEL_REGISTRY[dispatcher][input_type]
-
-        # In case the wrapper was decorated with @functools.wraps, we can make the check more strict and test if the
-        # proper kernel was wrapped
-        if hasattr(wrapped_kernel, "__wrapped__"):
-            assert wrapped_kernel.__wrapped__ is kernel
-
-        spy = mock.MagicMock(wraps=wrapped_kernel, name=wrapped_kernel.__name__)
-        with mock.patch.dict(_KERNEL_REGISTRY[dispatcher], values={input_type: spy}):
-            output = dispatcher(input, *args, **kwargs)
-
-        spy.assert_called_once()
-    else:
-        with mock.patch(f"{dispatcher.__module__}.{kernel.__name__}", wraps=kernel) as spy:
-            output = dispatcher(input, *args, **kwargs)
-
-        spy.assert_called_once()
-
-    assert isinstance(output, input_type)
-
-    if isinstance(input, datapoints.BoundingBoxes):
-        assert output.format == input.format
-
-
 def check_dispatcher(
     dispatcher,
+    # TODO: remove this parameter
     kernel,
     input,
     *args,
     check_scripted_smoke=True,
-    check_dispatch=True,
     **kwargs,
 ):
     unknown_input = object()
+    with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))):
+        dispatcher(unknown_input, *args, **kwargs)
+
     with mock.patch("torch._C._log_api_usage_once", wraps=torch._C._log_api_usage_once) as spy:
-        with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))):
-            dispatcher(unknown_input, *args, **kwargs)
+        output = dispatcher(input, *args, **kwargs)

         spy.assert_any_call(f"{dispatcher.__module__}.{dispatcher.__name__}")

+    assert isinstance(output, type(input))
+
+    if isinstance(input, datapoints.BoundingBoxes):
+        assert output.format == input.format
+
     if check_scripted_smoke:
         _check_dispatcher_scripted_smoke(dispatcher, input, *args, **kwargs)

-    if check_dispatch:
-        _check_dispatcher_dispatch(dispatcher, kernel, input, *args, **kwargs)
-

 def check_dispatcher_kernel_signature_match(dispatcher, *, kernel, input_type):
     """Checks if the signature of the dispatcher matches the kernel signature."""
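A rough usage sketch of the slimmed-down helper, with arguments modeled on calls elsewhere in this file (`make_image` is one of the file's input factories; the exact kwargs are illustrative):

```python
check_dispatcher(
    F.resize,
    F.resize_image_tensor,  # the TODO-marked kernel parameter, still required for now
    make_image(),
    size=[32, 32],
)
```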
@@ -412,18 +385,20 @@ def transform(bbox):


 @pytest.mark.parametrize(
-    ("dispatcher", "registered_datapoint_clss"),
+    ("dispatcher", "registered_input_types"),
     [(dispatcher, set(registry.keys())) for dispatcher, registry in _KERNEL_REGISTRY.items()],
 )
-def test_exhaustive_kernel_registration(dispatcher, registered_datapoint_clss):
+def test_exhaustive_kernel_registration(dispatcher, registered_input_types):
     missing = {
+        torch.Tensor,
+        PIL.Image.Image,
         datapoints.Image,
         datapoints.BoundingBoxes,
         datapoints.Mask,
         datapoints.Video,
-    } - registered_datapoint_clss
+    } - registered_input_types
     if missing:
-        names = sorted(f"datapoints.{cls.__name__}" for cls in missing)
+        names = sorted(str(t) for t in missing)
         raise AssertionError(
             "\n".join(
                 [
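Conceptually, each registry entry iterated by the parametrization above now looks roughly like this (shape inferred from the `TestGetKernel.KERNELS` mapping further down; the actual registry holds internally wrapped kernels, not the bare functions):

```python
{
    F.resize: {
        torch.Tensor: F.resize_image_tensor,
        PIL.Image.Image: F.resize_image_pil,
        datapoints.Image: F.resize_image_tensor,
        datapoints.BoundingBoxes: F.resize_bounding_boxes,
        datapoints.Mask: F.resize_mask,
        datapoints.Video: F.resize_video,
    },
    # ... one entry per dispatcher
}
```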
@@ -1753,11 +1728,6 @@ def test_dispatcher(self, kernel, make_input, input_dtype, output_dtype, device,
            F.to_dtype,
            kernel,
            make_input(dtype=input_dtype, device=device),
-            # TODO: we could leave check_dispatch to True but it currently fails
-            # in _check_dispatcher_dispatch because there is no to_dtype() method on the datapoints.
-            # We should be able to put this back if we change the dispatch
-            # mechanism e.g. via https://github.com/pytorch/vision/pull/7733
-            check_dispatch=False,
            dtype=output_dtype,
            scale=scale,
        )
@@ -2208,9 +2178,105 @@ def new_resize(dp, *args, **kwargs):
         t(torch.rand(3, 10, 10)).shape == (3, 224, 224)
         t(datapoints.Image(torch.rand(3, 10, 10))).shape == (3, 224, 224)

-    def test_bad_disaptcher_name(self):
-        class CustomDatapoint(datapoints.Datapoint):
+    def test_errors(self):
+        with pytest.raises(ValueError, match="Could not find dispatcher with name"):
+            F.register_kernel("bad_name", datapoints.Image)
+
+        with pytest.raises(ValueError, match="Kernels can only be registered on dispatchers"):
+            F.register_kernel(datapoints.Image, F.resize)
+
+        with pytest.raises(ValueError, match="Kernels can only be registered for subclasses"):
+            F.register_kernel(F.resize, object)
+
+        with pytest.raises(ValueError, match="already has a kernel registered for type"):
+            F.register_kernel(F.resize, datapoints.Image)(F.resize_image_tensor)
+
+
+class TestGetKernel:
+    # We are using F.resize as the dispatcher and the kernels below as proxies. Any other dispatcher / kernel
+    # combination would also be fine.
+    KERNELS = {
+        torch.Tensor: F.resize_image_tensor,
+        PIL.Image.Image: F.resize_image_pil,
+        datapoints.Image: F.resize_image_tensor,
+        datapoints.BoundingBoxes: F.resize_bounding_boxes,
+        datapoints.Mask: F.resize_mask,
+        datapoints.Video: F.resize_video,
+    }
+
+    def test_unsupported_types(self):
+        class MyTensor(torch.Tensor):
             pass

-        with pytest.raises(ValueError, match="Could not find dispatcher with name"):
-            F.register_kernel("bad_name", CustomDatapoint)
+        class MyPILImage(PIL.Image.Image):
+            pass
+
+        for input_type in [str, int, object, MyTensor, MyPILImage]:
+            with pytest.raises(
+                TypeError,
+                match=(
+                    "supports inputs of type torch.Tensor, PIL.Image.Image, "
+                    "and subclasses of torchvision.datapoints.Datapoint"
+                ),
+            ):
+                _get_kernel(F.resize, input_type)
+
+    def test_exact_match(self):
+        # We cannot use F.resize together with the self.KERNELS mapping directly here, since that mapping is only
+        # the ideal wrapping. In practice, there is an intermediate wrapper layer. Thus, we create a new resize
+        # dispatcher here, register the kernels without the wrapper, and check for an exact match afterwards.
+        def resize_with_pure_kernels():
+            pass
+
+        for input_type, kernel in self.KERNELS.items():
+            _register_kernel_internal(resize_with_pure_kernels, input_type, datapoint_wrapper=False)(kernel)
+
+            assert _get_kernel(resize_with_pure_kernels, input_type) is kernel
+
+    def test_builtin_datapoint_subclass(self):
+        # We cannot use F.resize together with the self.KERNELS mapping directly here, since that mapping is only
+        # the ideal wrapping. In practice, there is an intermediate wrapper layer. Thus, we create a new resize
+        # dispatcher here, register the kernels without the wrapper, and check that subclasses of our builtin
+        # datapoints get dispatched to the kernel of the corresponding superclass.
+        def resize_with_pure_kernels():
+            pass
+
+        class MyImage(datapoints.Image):
+            pass
+
+        class MyBoundingBoxes(datapoints.BoundingBoxes):
+            pass
+
+        class MyMask(datapoints.Mask):
+            pass
+
+        class MyVideo(datapoints.Video):
+            pass
+
+        for custom_datapoint_subclass in [
+            MyImage,
+            MyBoundingBoxes,
+            MyMask,
+            MyVideo,
+        ]:
+            builtin_datapoint_class = custom_datapoint_subclass.__mro__[1]
+            builtin_datapoint_kernel = self.KERNELS[builtin_datapoint_class]
+            _register_kernel_internal(resize_with_pure_kernels, builtin_datapoint_class, datapoint_wrapper=False)(
+                builtin_datapoint_kernel
+            )
+
+            assert _get_kernel(resize_with_pure_kernels, custom_datapoint_subclass) is builtin_datapoint_kernel
+
+    def test_datapoint_subclass(self):
+        class MyDatapoint(datapoints.Datapoint):
+            pass
+
+        # Note that this will be an error in the future
+        assert _get_kernel(F.resize, MyDatapoint) is _noop
+
+        def resize_my_datapoint():
+            pass
+
+        _register_kernel_internal(F.resize, MyDatapoint, datapoint_wrapper=False)(resize_my_datapoint)
+
+        assert _get_kernel(F.resize, MyDatapoint) is resize_my_datapoint
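`test_datapoint_subclass` goes through the private `_register_kernel_internal`; the equivalent flow through the public API would presumably look like the sketch below (the kernel body and signature are placeholders, not part of this diff):

```python
class MyDatapoint(datapoints.Datapoint):
    pass

@F.register_kernel(F.resize, MyDatapoint)
def resize_my_datapoint(my_dp, size, **kwargs):
    # Placeholder: a real kernel would return the resized datapoint.
    return my_dp

# F.resize(my_dp, size=[32, 32]) would now dispatch to resize_my_datapoint.
```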