Skip to content

Commit 2030d20

Browse files
pmeier and NicolasHug authored
register tensor and PIL kernel the same way as datapoints (#7797)
Co-authored-by: Nicolas Hug <[email protected]>
1 parent 84db2ac commit 2030d20

File tree

9 files changed

+552
-687
lines changed

9 files changed

+552
-687
lines changed

test/test_transforms_v2_functional.py

Lines changed: 3 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import math
33
import os
44
import re
5-
from unittest import mock
65

76
import numpy as np
87
import PIL.Image
@@ -25,7 +24,6 @@
2524
from torchvision.transforms.v2 import functional as F
2625
from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding
2726
from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_format_bounding_boxes
28-
from torchvision.transforms.v2.functional._utils import _KERNEL_REGISTRY
2927
from torchvision.transforms.v2.utils import is_simple_tensor
3028
from transforms_v2_dispatcher_infos import DISPATCHER_INFOS
3129
from transforms_v2_kernel_infos import KERNEL_INFOS
@@ -359,18 +357,6 @@ def test_scripted_smoke(self, info, args_kwargs, device):
359357
def test_scriptable(self, dispatcher):
360358
script(dispatcher)
361359

362-
@image_sample_inputs
363-
def test_dispatch_simple_tensor(self, info, args_kwargs, spy_on):
364-
(image_datapoint, *other_args), kwargs = args_kwargs.load()
365-
image_simple_tensor = torch.Tensor(image_datapoint)
366-
367-
kernel_info = info.kernel_infos[datapoints.Image]
368-
spy = spy_on(kernel_info.kernel, module=info.dispatcher.__module__, name=kernel_info.id)
369-
370-
info.dispatcher(image_simple_tensor, *other_args, **kwargs)
371-
372-
spy.assert_called_once()
373-
374360
@image_sample_inputs
375361
def test_simple_tensor_output_type(self, info, args_kwargs):
376362
(image_datapoint, *other_args), kwargs = args_kwargs.load()
@@ -381,25 +367,6 @@ def test_simple_tensor_output_type(self, info, args_kwargs):
381367
# We cannot use `isinstance` here since all datapoints are instances of `torch.Tensor` as well
382368
assert type(output) is torch.Tensor
383369

384-
@make_info_args_kwargs_parametrization(
385-
[info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
386-
args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
387-
)
388-
def test_dispatch_pil(self, info, args_kwargs, spy_on):
389-
(image_datapoint, *other_args), kwargs = args_kwargs.load()
390-
391-
if image_datapoint.ndim > 3:
392-
pytest.skip("Input is batched")
393-
394-
image_pil = F.to_image_pil(image_datapoint)
395-
396-
pil_kernel_info = info.pil_kernel_info
397-
spy = spy_on(pil_kernel_info.kernel, module=info.dispatcher.__module__, name=pil_kernel_info.id)
398-
399-
info.dispatcher(image_pil, *other_args, **kwargs)
400-
401-
spy.assert_called_once()
402-
403370
@make_info_args_kwargs_parametrization(
404371
[info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
405372
args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
@@ -416,28 +383,6 @@ def test_pil_output_type(self, info, args_kwargs):
416383

417384
assert isinstance(output, PIL.Image.Image)
418385

419-
@make_info_args_kwargs_parametrization(
420-
DISPATCHER_INFOS,
421-
args_kwargs_fn=lambda info: info.sample_inputs(),
422-
)
423-
def test_dispatch_datapoint(self, info, args_kwargs, spy_on):
424-
(datapoint, *other_args), kwargs = args_kwargs.load()
425-
426-
input_type = type(datapoint)
427-
428-
wrapped_kernel = _KERNEL_REGISTRY[info.dispatcher][input_type]
429-
430-
# In case the wrapper was decorated with @functools.wraps, we can make the check more strict and test if the
431-
# proper kernel was wrapped
432-
if hasattr(wrapped_kernel, "__wrapped__"):
433-
assert wrapped_kernel.__wrapped__ is info.kernels[input_type]
434-
435-
spy = mock.MagicMock(wraps=wrapped_kernel, name=wrapped_kernel.__name__)
436-
with mock.patch.dict(_KERNEL_REGISTRY[info.dispatcher], values={input_type: spy}):
437-
info.dispatcher(datapoint, *other_args, **kwargs)
438-
439-
spy.assert_called_once()
440-
441386
@make_info_args_kwargs_parametrization(
442387
DISPATCHER_INFOS,
443388
args_kwargs_fn=lambda info: info.sample_inputs(),
@@ -449,6 +394,9 @@ def test_datapoint_output_type(self, info, args_kwargs):
449394

450395
assert isinstance(output, type(datapoint))
451396

397+
if isinstance(datapoint, datapoints.BoundingBoxes) and info.dispatcher is not F.convert_format_bounding_boxes:
398+
assert output.format == datapoint.format
399+
452400
@pytest.mark.parametrize(
453401
("dispatcher_info", "datapoint_type", "kernel_info"),
454402
[

test/test_transforms_v2_refactored.py

Lines changed: 117 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from torchvision.transforms._functional_tensor import _max_value as get_max_value
4040
from torchvision.transforms.functional import pil_modes_mapping
4141
from torchvision.transforms.v2 import functional as F
42-
from torchvision.transforms.v2.functional._utils import _KERNEL_REGISTRY
42+
from torchvision.transforms.v2.functional._utils import _get_kernel, _KERNEL_REGISTRY, _noop, _register_kernel_internal
4343

4444

4545
@pytest.fixture(autouse=True)
@@ -173,59 +173,32 @@ def _check_dispatcher_scripted_smoke(dispatcher, input, *args, **kwargs):
173173
dispatcher_scripted(input.as_subclass(torch.Tensor), *args, **kwargs)
174174

175175

176-
def _check_dispatcher_dispatch(dispatcher, kernel, input, *args, **kwargs):
177-
"""Checks if the dispatcher correctly dispatches the input to the corresponding kernel and that the input type is
178-
preserved in doing so. For bounding boxes also checks that the format is preserved.
179-
"""
180-
input_type = type(input)
181-
182-
if isinstance(input, datapoints.Datapoint):
183-
wrapped_kernel = _KERNEL_REGISTRY[dispatcher][input_type]
184-
185-
# In case the wrapper was decorated with @functools.wraps, we can make the check more strict and test if the
186-
# proper kernel was wrapped
187-
if hasattr(wrapped_kernel, "__wrapped__"):
188-
assert wrapped_kernel.__wrapped__ is kernel
189-
190-
spy = mock.MagicMock(wraps=wrapped_kernel, name=wrapped_kernel.__name__)
191-
with mock.patch.dict(_KERNEL_REGISTRY[dispatcher], values={input_type: spy}):
192-
output = dispatcher(input, *args, **kwargs)
193-
194-
spy.assert_called_once()
195-
else:
196-
with mock.patch(f"{dispatcher.__module__}.{kernel.__name__}", wraps=kernel) as spy:
197-
output = dispatcher(input, *args, **kwargs)
198-
199-
spy.assert_called_once()
200-
201-
assert isinstance(output, input_type)
202-
203-
if isinstance(input, datapoints.BoundingBoxes):
204-
assert output.format == input.format
205-
206-
207176
def check_dispatcher(
208177
dispatcher,
178+
# TODO: remove this parameter
209179
kernel,
210180
input,
211181
*args,
212182
check_scripted_smoke=True,
213-
check_dispatch=True,
214183
**kwargs,
215184
):
216185
unknown_input = object()
186+
with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))):
187+
dispatcher(unknown_input, *args, **kwargs)
188+
217189
with mock.patch("torch._C._log_api_usage_once", wraps=torch._C._log_api_usage_once) as spy:
218-
with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))):
219-
dispatcher(unknown_input, *args, **kwargs)
190+
output = dispatcher(input, *args, **kwargs)
220191

221192
spy.assert_any_call(f"{dispatcher.__module__}.{dispatcher.__name__}")
222193

194+
assert isinstance(output, type(input))
195+
196+
if isinstance(input, datapoints.BoundingBoxes):
197+
assert output.format == input.format
198+
223199
if check_scripted_smoke:
224200
_check_dispatcher_scripted_smoke(dispatcher, input, *args, **kwargs)
225201

226-
if check_dispatch:
227-
_check_dispatcher_dispatch(dispatcher, kernel, input, *args, **kwargs)
228-
229202

230203
def check_dispatcher_kernel_signature_match(dispatcher, *, kernel, input_type):
231204
"""Checks if the signature of the dispatcher matches the kernel signature."""
@@ -412,18 +385,20 @@ def transform(bbox):
412385

413386

414387
@pytest.mark.parametrize(
415-
("dispatcher", "registered_datapoint_clss"),
388+
("dispatcher", "registered_input_types"),
416389
[(dispatcher, set(registry.keys())) for dispatcher, registry in _KERNEL_REGISTRY.items()],
417390
)
418-
def test_exhaustive_kernel_registration(dispatcher, registered_datapoint_clss):
391+
def test_exhaustive_kernel_registration(dispatcher, registered_input_types):
419392
missing = {
393+
torch.Tensor,
394+
PIL.Image.Image,
420395
datapoints.Image,
421396
datapoints.BoundingBoxes,
422397
datapoints.Mask,
423398
datapoints.Video,
424-
} - registered_datapoint_clss
399+
} - registered_input_types
425400
if missing:
426-
names = sorted(f"datapoints.{cls.__name__}" for cls in missing)
401+
names = sorted(str(t) for t in missing)
427402
raise AssertionError(
428403
"\n".join(
429404
[
@@ -1753,11 +1728,6 @@ def test_dispatcher(self, kernel, make_input, input_dtype, output_dtype, device,
17531728
F.to_dtype,
17541729
kernel,
17551730
make_input(dtype=input_dtype, device=device),
1756-
# TODO: we could leave check_dispatch to True but it currently fails
1757-
# in _check_dispatcher_dispatch because there is no to_dtype() method on the datapoints.
1758-
# We should be able to put this back if we change the dispatch
1759-
# mechanism e.g. via https://github.com/pytorch/vision/pull/7733
1760-
check_dispatch=False,
17611731
dtype=output_dtype,
17621732
scale=scale,
17631733
)
@@ -2208,9 +2178,105 @@ def new_resize(dp, *args, **kwargs):
22082178
t(torch.rand(3, 10, 10)).shape == (3, 224, 224)
22092179
t(datapoints.Image(torch.rand(3, 10, 10))).shape == (3, 224, 224)
22102180

2211-
def test_bad_disaptcher_name(self):
2212-
class CustomDatapoint(datapoints.Datapoint):
2181+
def test_errors(self):
2182+
with pytest.raises(ValueError, match="Could not find dispatcher with name"):
2183+
F.register_kernel("bad_name", datapoints.Image)
2184+
2185+
with pytest.raises(ValueError, match="Kernels can only be registered on dispatchers"):
2186+
F.register_kernel(datapoints.Image, F.resize)
2187+
2188+
with pytest.raises(ValueError, match="Kernels can only be registered for subclasses"):
2189+
F.register_kernel(F.resize, object)
2190+
2191+
with pytest.raises(ValueError, match="already has a kernel registered for type"):
2192+
F.register_kernel(F.resize, datapoints.Image)(F.resize_image_tensor)
2193+
2194+
2195+
class TestGetKernel:
2196+
# We are using F.resize as dispatcher and the kernels below as proxy. Any other dispatcher / kernels combination
2197+
# would also be fine
2198+
KERNELS = {
2199+
torch.Tensor: F.resize_image_tensor,
2200+
PIL.Image.Image: F.resize_image_pil,
2201+
datapoints.Image: F.resize_image_tensor,
2202+
datapoints.BoundingBoxes: F.resize_bounding_boxes,
2203+
datapoints.Mask: F.resize_mask,
2204+
datapoints.Video: F.resize_video,
2205+
}
2206+
2207+
def test_unsupported_types(self):
2208+
class MyTensor(torch.Tensor):
22132209
pass
22142210

2215-
with pytest.raises(ValueError, match="Could not find dispatcher with name"):
2216-
F.register_kernel("bad_name", CustomDatapoint)
2211+
class MyPILImage(PIL.Image.Image):
2212+
pass
2213+
2214+
for input_type in [str, int, object, MyTensor, MyPILImage]:
2215+
with pytest.raises(
2216+
TypeError,
2217+
match=(
2218+
"supports inputs of type torch.Tensor, PIL.Image.Image, "
2219+
"and subclasses of torchvision.datapoints.Datapoint"
2220+
),
2221+
):
2222+
_get_kernel(F.resize, input_type)
2223+
2224+
def test_exact_match(self):
2225+
# We cannot use F.resize together with the self.KERNELS mapping directly here, since this is only the
2226+
# ideal wrapping. Practically, we have an intermediate wrapper layer. Thus, we create a new resize dispatcher
2227+
# here, register the kernels without wrapper, and check the exact matching afterwards.
2228+
def resize_with_pure_kernels():
2229+
pass
2230+
2231+
for input_type, kernel in self.KERNELS.items():
2232+
_register_kernel_internal(resize_with_pure_kernels, input_type, datapoint_wrapper=False)(kernel)
2233+
2234+
assert _get_kernel(resize_with_pure_kernels, input_type) is kernel
2235+
2236+
def test_builtin_datapoint_subclass(self):
2237+
# We cannot use F.resize together with the self.KERNELS mapping directly here, since this is only the
2238+
# ideal wrapping. Practically, we have an intermediate wrapper layer. Thus, we create a new resize dispatcher
2239+
# here, register the kernels without wrapper, and check if subclasses of our builtin datapoints get dispatched
2240+
# to the kernel of the corresponding superclass
2241+
def resize_with_pure_kernels():
2242+
pass
2243+
2244+
class MyImage(datapoints.Image):
2245+
pass
2246+
2247+
class MyBoundingBoxes(datapoints.BoundingBoxes):
2248+
pass
2249+
2250+
class MyMask(datapoints.Mask):
2251+
pass
2252+
2253+
class MyVideo(datapoints.Video):
2254+
pass
2255+
2256+
for custom_datapoint_subclass in [
2257+
MyImage,
2258+
MyBoundingBoxes,
2259+
MyMask,
2260+
MyVideo,
2261+
]:
2262+
builtin_datapoint_class = custom_datapoint_subclass.__mro__[1]
2263+
builtin_datapoint_kernel = self.KERNELS[builtin_datapoint_class]
2264+
_register_kernel_internal(resize_with_pure_kernels, builtin_datapoint_class, datapoint_wrapper=False)(
2265+
builtin_datapoint_kernel
2266+
)
2267+
2268+
assert _get_kernel(resize_with_pure_kernels, custom_datapoint_subclass) is builtin_datapoint_kernel
2269+
2270+
def test_datapoint_subclass(self):
2271+
class MyDatapoint(datapoints.Datapoint):
2272+
pass
2273+
2274+
# Note that this will be an error in the future
2275+
assert _get_kernel(F.resize, MyDatapoint) is _noop
2276+
2277+
def resize_my_datapoint():
2278+
pass
2279+
2280+
_register_kernel_internal(F.resize, MyDatapoint, datapoint_wrapper=False)(resize_my_datapoint)
2281+
2282+
assert _get_kernel(F.resize, MyDatapoint) is resize_my_datapoint

torchvision/transforms/v2/functional/_augment.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
88
from torchvision.utils import _log_api_usage_once
99

10-
from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, is_simple_tensor
10+
from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal
1111

1212

1313
@_register_explicit_noop(datapoints.Mask, datapoints.BoundingBoxes, warn_passthrough=True)
@@ -20,23 +20,16 @@ def erase(
2020
v: torch.Tensor,
2121
inplace: bool = False,
2222
) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]:
23-
if not torch.jit.is_scripting():
24-
_log_api_usage_once(erase)
25-
26-
if torch.jit.is_scripting() or is_simple_tensor(inpt):
23+
if torch.jit.is_scripting():
2724
return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
28-
elif isinstance(inpt, datapoints.Datapoint):
29-
kernel = _get_kernel(erase, type(inpt))
30-
return kernel(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
31-
elif isinstance(inpt, PIL.Image.Image):
32-
return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
33-
else:
34-
raise TypeError(
35-
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
36-
f"but got {type(inpt)} instead."
37-
)
25+
26+
_log_api_usage_once(erase)
27+
28+
kernel = _get_kernel(erase, type(inpt))
29+
return kernel(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
3830

3931

32+
@_register_kernel_internal(erase, torch.Tensor)
4033
@_register_kernel_internal(erase, datapoints.Image)
4134
def erase_image_tensor(
4235
image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
@@ -48,7 +41,7 @@ def erase_image_tensor(
4841
return image
4942

5043

51-
@torch.jit.unused
44+
@_register_kernel_internal(erase, PIL.Image.Image)
5245
def erase_image_pil(
5346
image: PIL.Image.Image, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
5447
) -> PIL.Image.Image:

0 commit comments

Comments
 (0)