Skip to content

Implementation of the GaussianNoise transform for uint8 inputs #9169

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 45 additions & 7 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3978,7 +3978,7 @@ class TestGaussianNoise:
"make_input",
[make_image_tensor, make_image, make_video],
)
def test_kernel(self, make_input):
def test_kernel_float(self, make_input):
check_kernel(
F.gaussian_noise,
make_input(dtype=torch.float32),
Expand All @@ -3990,9 +3990,28 @@ def test_kernel(self, make_input):
"make_input",
[make_image_tensor, make_image, make_video],
)
def test_functional(self, make_input):
def test_kernel_uint8(self, make_input):
check_kernel(
F.gaussian_noise,
make_input(dtype=torch.uint8),
# This cannot pass because the noise on a batch in not per-image
check_batched_vs_unbatched=False,
)

@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image, make_video],
)
def test_functional_float(self, make_input):
check_functional(F.gaussian_noise, make_input(dtype=torch.float32))

@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image, make_video],
)
def test_functional_uint8(self, make_input):
check_functional(F.gaussian_noise, make_input(dtype=torch.uint8))

@pytest.mark.parametrize(
("kernel", "input_type"),
[
Expand All @@ -4008,10 +4027,11 @@ def test_functional_signature(self, kernel, input_type):
"make_input",
[make_image_tensor, make_image, make_video],
)
def test_transform(self, make_input):
def test_transform_float(self, make_input):
def adapter(_, input, __):
# This transform doesn't support uint8 so we have to convert the auto-generated uint8 tensors to float32
# Same for PIL images
# We have two different implementations for floats and uint8
# To test this implementation we'll convert the auto-generated uint8 tensors to float32
# We don't support other int dtypes nor pil images
for key, value in input.items():
if isinstance(value, torch.Tensor) and not value.is_floating_point():
input[key] = value.to(torch.float32)
Expand All @@ -4021,11 +4041,29 @@ def adapter(_, input, __):

check_transform(transforms.GaussianNoise(), make_input(dtype=torch.float32), check_sample_input=adapter)

@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image, make_video],
)
def test_transform_uint8(self, make_input):
def adapter(_, input, __):
# We have two different implementations for floats and uint8
# To test this implementation we'll convert every tensor to uint8
# We don't support other int dtypes nor pil images
for key, value in input.items():
if isinstance(value, torch.Tensor) and not value.dtype != torch.uint8:
input[key] = value.to(torch.uint8)
if isinstance(value, PIL.Image.Image):
input[key] = F.pil_to_tensor(value).to(torch.uint8)
return input

check_transform(transforms.GaussianNoise(), make_input(dtype=torch.uint8), check_sample_input=adapter)

def test_bad_input(self):
with pytest.raises(ValueError, match="Gaussian Noise is not implemented for PIL images."):
F.gaussian_noise(make_image_pil())
with pytest.raises(ValueError, match="Input tensor is expected to be in float dtype"):
F.gaussian_noise(make_image(dtype=torch.uint8))
with pytest.raises(ValueError, match="Input tensor is expected to be in uint8 or float dtype"):
F.gaussian_noise(make_image(dtype=torch.int32))
with pytest.raises(ValueError, match="sigma shouldn't be negative"):
F.gaussian_noise(make_image(dtype=torch.float32), sigma=-1)

Expand Down
15 changes: 12 additions & 3 deletions torchvision/transforms/v2/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,13 +214,22 @@ class GaussianNoise(Transform):
Each image or frame in a batch will be transformed independently i.e. the
noise added to each image will be different.

The input tensor is also expected to be of float dtype in ``[0, 1]``.
This transform does not support PIL images.
The input tensor is also expected to be of float dtype in ``[0, 1]``,
or of ``uint8`` dtype in ``[0, 255]``. This transform does not support PIL
images.

Regardless of the dtype used, the parameters of the function use the same
scale, so a ``mean`` parameter of 0.5 will result in an average value
increase of 0.5 units for float images, and an average increase of 127.5
units for ``uint8`` images.

Args:
mean (float): Mean of the sampled normal distribution. Default is 0.
sigma (float): Standard deviation of the sampled normal distribution. Default is 0.1.
clip (bool, optional): Whether to clip the values in ``[0, 1]`` after adding noise. Default is True.
clip (bool, optional): Whether to clip the values after adding noise, be it to
``[0, 1]`` for floats or to ``[0, 255]`` for ``uint8``. Setting this parameter to
``False`` may cause unsigned integer overflows with uint8 inputs.
Default is True.
"""

def __init__(self, mean: float = 0.0, sigma: float = 0.1, clip=True) -> None:
Expand Down
25 changes: 18 additions & 7 deletions torchvision/transforms/v2/functional/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,16 +195,27 @@ def gaussian_noise(inpt: torch.Tensor, mean: float = 0.0, sigma: float = 0.1, cl
@_register_kernel_internal(gaussian_noise, torch.Tensor)
@_register_kernel_internal(gaussian_noise, tv_tensors.Image)
def gaussian_noise_image(image: torch.Tensor, mean: float = 0.0, sigma: float = 0.1, clip: bool = True) -> torch.Tensor:
if not image.is_floating_point():
raise ValueError(f"Input tensor is expected to be in float dtype, got dtype={image.dtype}")
if sigma < 0:
raise ValueError(f"sigma shouldn't be negative. Got {sigma}")

noise = mean + torch.randn_like(image) * sigma
out = image + noise
if clip:
out = torch.clamp(out, 0, 1)
return out
if image.is_floating_point():
noise = mean + torch.randn_like(image) * sigma
out = image + noise
if clip:
out = torch.clamp(out, 0, 1)
return out

elif image.dtype == torch.uint8:
# Convert to intermediate dtype int16 to add to input more efficiently
noise = ((mean * 255) + torch.randn_like(image, dtype=torch.float32) * (sigma * 255)).to(torch.int16)
out = image + noise

if clip:
out = torch.clamp(out, 0, 255)
return out.to(torch.uint8)

else:
raise ValueError(f"Input tensor is expected to be in uint8 or float dtype, got dtype={image.dtype}")


@_register_kernel_internal(gaussian_noise, tv_tensors.Video)
Expand Down