@MalyalaKarthik66
Contributor

Add keras.layers.RandomResizedCrop layer that:

  • Samples random crop area from scale=(0.08, 1.0) and aspect ratio from ratio=(0.75, 1.33)
  • Performs center crop in inference mode (preserves target aspect ratio)
  • Supports all backends (TF/JAX/Torch/NumPy) via backend.image.resize
  • Includes bounding box and segmentation mask transformations
  • Has 14 comprehensive tests passing on all backends
  • Gracefully handles OpenVINO (raises NotImplementedError until resize support is added)

Closes #21822
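
The sampling described in the first bullet can be sketched roughly as follows. This is a minimal standard-library illustration, not the layer's code: the function name `sample_crop_window` and the clamping step are assumptions made for the example (the PR may handle out-of-bounds windows differently, for instance by retrying).

```python
import math
import random


def sample_crop_window(height, width, scale=(0.08, 1.0),
                       ratio=(0.75, 1.33), rng=None):
    """Sample (h_start, w_start, crop_h, crop_w) for one image size.

    Hypothetical helper for illustration only.
    """
    rng = rng or random.Random()
    # Relative crop area is drawn uniformly from `scale`.
    area = rng.uniform(*scale) * height * width
    # Aspect ratio (width / height) is drawn uniformly in log-space.
    log_ratio = rng.uniform(math.log(ratio[0]), math.log(ratio[1]))
    aspect = math.exp(log_ratio)
    # Solve crop_w * crop_h = area and crop_w / crop_h = aspect.
    crop_w = int(round(math.sqrt(area * aspect)))
    crop_h = int(round(math.sqrt(area / aspect)))
    # Assumption: clamp so the window fits inside the image. This can
    # shift the effective area and ratio slightly for extreme samples.
    crop_w = max(1, min(crop_w, width))
    crop_h = max(1, min(crop_h, height))
    # Pick the top-left corner uniformly among valid positions.
    h_start = rng.randint(0, height - crop_h)
    w_start = rng.randint(0, width - crop_w)
    return h_start, w_start, crop_h, crop_w
```

Sampling the ratio in log-space makes a ratio `r` and its reciprocal `1 / r` equally likely when the interval is symmetric around 1 on a log scale.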

@gemini-code-assist

Summary of Changes

Hello @MalyalaKarthik66, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a new RandomResizedCrop layer to Keras, a crucial data augmentation technique for image-based models. The layer intelligently samples random crop areas and aspect ratios during training, while ensuring a consistent center crop during inference. It's designed with multi-backend compatibility and extends its transformation capabilities to bounding boxes and segmentation masks, making it a versatile tool for various computer vision tasks.

Highlights

  • New Layer Introduction: A new keras.layers.RandomResizedCrop layer has been added, providing a powerful image augmentation technique for computer vision models.
  • Random Sampling Logic: The layer samples a random crop area from scale=(0.08, 1.0) and an aspect ratio from ratio=(0.75, 1.33) during training.
  • Inference Behavior: During inference (training=False), the layer performs a deterministic center crop that preserves the target aspect ratio, followed by resizing to the specified height and width.
  • Multi-Backend Support: The implementation supports all Keras backends (TensorFlow, JAX, PyTorch, NumPy) by utilizing backend.image.resize.
  • Comprehensive Data Transformation: Beyond images, the layer also includes logic for transforming associated bounding boxes and segmentation masks, ensuring consistency across different data types.
  • Extensive Testing: The new layer is accompanied by 14 comprehensive tests that pass on all supported backends, covering various scenarios and edge cases.
  • OpenVINO Compatibility: The layer gracefully handles the OpenVINO backend by raising a NotImplementedError until backend.image.resize support is available for OpenVINO.
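
The inference behavior in the highlights (a center crop that preserves the target aspect ratio) could look something like the sketch below. The helper name `center_crop_window` and the rounding choices are assumptions for illustration; the PR's `_get_center_crop_params` may differ in detail.

```python
def center_crop_window(in_h, in_w, target_h, target_w):
    """Return (h_start, w_start, crop_h, crop_w) for a centered crop.

    Hypothetical helper: the crop is the largest centered window whose
    aspect ratio matches target_w / target_h.
    """
    target_aspect = target_w / target_h
    if in_w / in_h > target_aspect:
        # Input is wider than the target: keep full height, trim width.
        crop_h = in_h
        crop_w = max(1, min(in_w, int(round(in_h * target_aspect))))
    else:
        # Input is taller than (or matches) the target: keep full width.
        crop_w = in_w
        crop_h = max(1, min(in_h, int(round(in_w / target_aspect))))
    h_start = (in_h - crop_h) // 2
    w_start = (in_w - crop_w) // 2
    return h_start, w_start, crop_h, crop_w
```

For a 300x400 (height x width) input and a square 224x224 target, this keeps the full height and trims 50 pixels from each side of the width before the resize step.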

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces the RandomResizedCrop layer, a valuable addition for image augmentation. The implementation is well-structured and includes comprehensive tests covering various scenarios and backends.

I've identified a few areas for improvement:

  • The class docstring is missing a usage example and shape information, which is recommended by the Keras API design guidelines.
  • There's a potential bug when an integer seed is passed directly to get_random_transformation: the random draws inside a single call could all reuse the same seed and return the same value.
  • The resizing logic for images and segmentation masks has some code duplication that can be refactored for better maintainability.

Overall, this is a great contribution. Addressing these points will make the new layer even more robust and user-friendly.

Comment on lines +85 to +114
    def get_random_transformation(self, data, training=True, seed=None):
        """Returns a crop transformation `(h_start, w_start, crop_h, crop_w)`.

        The same crop parameters are applied to all images in a batch,
        which matches the behavior of other preprocessing layers.
        """
        if isinstance(data, dict):
            images = data.get("images", None)
            input_shape = backend.shape(images)
        else:
            input_shape = backend.shape(data)

        input_height = ops.cast(input_shape[self.height_axis], "float32")
        input_width = ops.cast(input_shape[self.width_axis], "float32")

        if training:
            h_start, w_start, crop_h, crop_w = self._get_random_crop_params(
                input_height, input_width, seed
            )
        else:
            h_start, w_start, crop_h, crop_w = self._get_center_crop_params(
                input_height, input_width
            )

        return (
            ops.cast(h_start, "int32"),
            ops.cast(w_start, "int32"),
            ops.cast(crop_h, "int32"),
            ops.cast(crop_w, "int32"),
        )

high

If an integer seed is passed directly to get_random_transformation, every call to backend.random.uniform within _get_random_crop_params receives that same seed and therefore produces the same random number. This is because an integer seed is stateless: it does not advance between calls.

To ensure different random numbers are generated for scale, ratio, and crop position, you should create a stateful SeedGenerator from the integer seed if one is provided.

    def get_random_transformation(self, data, training=True, seed=None):
        """Returns a crop transformation `(h_start, w_start, crop_h, crop_w)`.

        The same crop parameters are applied to all images in a batch,
        which matches the behavior of other preprocessing layers.
        """
        if isinstance(seed, int):
            seed = SeedGenerator(seed)
        if isinstance(data, dict):
            images = data.get("images", None)
            input_shape = backend.shape(images)
        else:
            input_shape = backend.shape(data)

        input_height = ops.cast(input_shape[self.height_axis], "float32")
        input_width = ops.cast(input_shape[self.width_axis], "float32")

        if training:
            h_start, w_start, crop_h, crop_w = self._get_random_crop_params(
                input_height, input_width, seed
            )
        else:
            h_start, w_start, crop_h, crop_w = self._get_center_crop_params(
                input_height, input_width
            )

        return (
            ops.cast(h_start, "int32"),
            ops.cast(w_start, "int32"),
            ops.cast(crop_h, "int32"),
            ops.cast(crop_w, "int32"),
        )
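
The difference between a stateless integer seed and a stateful generator can be illustrated with an analogy in Python's standard `random` module. This is not the Keras API: `draw_stateless` and `draw_stateful` are hypothetical helpers that only mimic the two behaviors described above.

```python
import random


def draw_stateless(seed, n):
    """Re-seed from the same integer before every draw (the buggy pattern)."""
    return [random.Random(seed).random() for _ in range(n)]


def draw_stateful(seed, n):
    """Create one generator from the seed and advance it on every draw."""
    rng = random.Random(seed)
    return [rng.random() for _ in range(n)]


# Three draws standing in for scale, ratio, and crop position.
stateless = draw_stateless(42, 3)
stateful = draw_stateful(42, 3)

print(len(set(stateless)))  # all three draws are identical
print(len(set(stateful)))   # the three draws differ
```

Both variants stay reproducible for a fixed seed; only the stateful one yields a different value for each quantity, which is the property the suggested `SeedGenerator` wrapper is meant to restore.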

Comment on lines +12 to +45
    """Randomly crops and resizes images to a target size.

    This layer:
      1. Samples a random relative area from `scale`.
      2. Samples a random aspect ratio from `ratio`.
      3. Derives a crop window (height, width) from these values.
      4. Crops the image and resizes the crop to `(height, width)`.

    Args:
        height: Integer. Target height of the output image.
        width: Integer. Target width of the output image.
        scale: Tuple of two floats `(min_scale, max_scale)`. The
            sampled relative area (crop_area / image_area) will lie
            in this range. Default `(0.08, 1.0)`.
        ratio: Tuple of two floats `(min_ratio, max_ratio)`. Aspect
            ratio (width / height) of the crop is sampled from this
            interval in log-space. Default `(0.75, 1.33)`.
        interpolation: String. Interpolation mode used in the resize
            step, e.g. `"bilinear"`. Default `"bilinear"`.
        seed: Optional integer. Random seed.
        data_format: Optional string, `"channels_last"` or
            `"channels_first"`. Follows global image data format by
            default.
        name: Optional string name.
        **kwargs: Additional layer keyword arguments.

    Notes:
        * On inference (`training=False`), the layer performs a
          deterministic center crop that preserves the target
          aspect ratio, followed by resize to `(height, width)`.
        * On the OpenVINO backend, `backend.image.resize` is not
          implemented. In this case, the layer raises a
          `NotImplementedError` at runtime.
    """

medium

The docstring is missing a code example, as well as Input shape and Output shape sections, which are recommended by the Keras API design guidelines. Adding these will improve usability and documentation quality.

Please add a simple example demonstrating the layer's usage, and document the expected input and output shapes.

    """Randomly crops and resizes images to a target size.

    This layer:
      1. Samples a random relative area from `scale`.
      2. Samples a random aspect ratio from `ratio`.
      3. Derives a crop window (height, width) from these values.
      4. Crops the image and resizes the crop to `(height, width)`.

    Args:
        height: Integer. Target height of the output image.
        width: Integer. Target width of the output image.
        scale: Tuple of two floats `(min_scale, max_scale)`. The
            sampled relative area (crop_area / image_area) will lie
            in this range. Default `(0.08, 1.0)`.
        ratio: Tuple of two floats `(min_ratio, max_ratio)`. Aspect
            ratio (width / height) of the crop is sampled from this
            interval in log-space. Default `(0.75, 1.33)`.
        interpolation: String. Interpolation mode used in the resize
            step, e.g. `"bilinear"`. Default `"bilinear"`.
        seed: Optional integer. Random seed.
        data_format: Optional string, `"channels_last"` or
            `"channels_first"`. Follows global image data format by
            default.
        name: Optional string name.
        **kwargs: Additional layer keyword arguments.

    Input shape:
        A 3D (unbatched) or 4D (batched) tensor, with shape:
        `(..., height, width, channels)` if `data_format="channels_last"`,
        or `(..., channels, height, width)` if `data_format="channels_first"`.

    Output shape:
        A 3D (unbatched) or 4D (batched) tensor, with shape:
        `(..., self.height, self.width, channels)` if
        `data_format="channels_last"`, or
        `(..., channels, self.height, self.width)` if
        `data_format="channels_first"`.

    Example:
    ```python
    input_shape = (2, 300, 300, 3)
    input_data = keras.random.uniform(shape=input_shape)
    layer = keras.layers.RandomResizedCrop(height=224, width=224)
    output_data = layer(input_data)
    print(output_data.shape)
    # (2, 224, 224, 3)
    ```

    Notes:
        * On inference (`training=False`), the layer performs a
          deterministic center crop that preserves the target
          aspect ratio, followed by resize to `(height, width)`.
        * On the OpenVINO backend, `backend.image.resize` is not
          implemented. In this case, the layer raises a
          `NotImplementedError` at runtime.
    """

References
  1. The Keras API design guidelines recommend that all docstrings include code examples (line 145) and follow a specific structure, including Input/Output shape sections for layers (line 157).
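
The input/output shape contract documented in the suggestion above can be restated as a small shape calculation. This is only a sketch: `resized_crop_output_shape` is a hypothetical helper, not a method from the PR, and it assumes the spatial axes are the last two non-channel axes as the docstring describes.

```python
def resized_crop_output_shape(input_shape, height, width,
                              data_format="channels_last"):
    """Return the output shape for a 3D (unbatched) or 4D (batched) input."""
    shape = list(input_shape)
    if data_format == "channels_last":
        # (..., height, width, channels)
        shape[-3], shape[-2] = height, width
    else:
        # (..., channels, height, width)
        shape[-2], shape[-1] = height, width
    return tuple(shape)


print(resized_crop_output_shape((2, 300, 300, 3), 224, 224))
# (2, 224, 224, 3)
```

Only the two spatial dimensions change; the batch and channel dimensions pass through untouched, whatever crop window was sampled.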

Comment on lines +205 to +267
    def _resize_images(self, images):
        """Resize images to `(height, width)` using backend API.

        For OpenVINO, this raises NotImplementedError because the
        backend does not yet provide `backend.image.resize`.
        """
        if backend.backend() == "openvino":
            raise NotImplementedError(
                "`RandomResizedCrop` is not yet supported on the "
                "OpenVINO backend because `backend.image.resize` is "
                "not implemented there. Please use `RandomCrop` or "
                "switch to a different backend until resize support "
                "is added."
            )

        return backend.image.resize(
            images,
            size=(self.height, self.width),
            interpolation=self.interpolation,
            antialias=False,
            crop_to_aspect_ratio=False,
            pad_to_aspect_ratio=False,
            fill_mode="constant",
            fill_value=0.0,
            data_format=self.data_format,
        )

    def transform_images(self, images, transformation=None, training=True):
        """Apply random resized crop to a batch of images."""
        if transformation is None:
            transformation = self.get_random_transformation(
                images, training=training
            )
        h_start, w_start, crop_h, crop_w = transformation

        images = self._slice_images(images, h_start, w_start, crop_h, crop_w)
        images = self._resize_images(images)
        return images

    def transform_segmentation_masks(
        self, masks, transformation, training=True
    ):
        """Apply the same crop + resize to segmentation masks."""
        h_start, w_start, crop_h, crop_w = transformation
        masks = self._slice_images(masks, h_start, w_start, crop_h, crop_w)

        if backend.backend() == "openvino":
            raise NotImplementedError(
                "Segmentation mask resizing for `RandomResizedCrop` is "
                "not yet supported on the OpenVINO backend."
            )

        return backend.image.resize(
            masks,
            size=(self.height, self.width),
            interpolation="nearest",
            antialias=False,
            crop_to_aspect_ratio=False,
            pad_to_aspect_ratio=False,
            fill_mode="constant",
            fill_value=0.0,
            data_format=self.data_format,
        )

medium

There's some code duplication in _resize_images and transform_segmentation_masks for handling the OpenVINO backend and calling backend.image.resize. Also, the error message for OpenVINO in transform_segmentation_masks is less informative than the one in _resize_images.

You can refactor this into a single private helper method _resize to remove duplication and ensure consistent, helpful error messages. This improves maintainability.

    def _resize(self, images, interpolation):
        """Resize images to `(height, width)` using backend API.

        For OpenVINO, this raises NotImplementedError because the
        backend does not yet provide `backend.image.resize`.
        """
        if backend.backend() == "openvino":
            raise NotImplementedError(
                "`RandomResizedCrop` is not yet supported on the "
                "OpenVINO backend because `backend.image.resize` is "
                "not implemented there. Please use `RandomCrop` or "
                "switch to a different backend until resize support "
                "is added."
            )

        return backend.image.resize(
            images,
            size=(self.height, self.width),
            interpolation=interpolation,
            antialias=False,
            crop_to_aspect_ratio=False,
            pad_to_aspect_ratio=False,
            fill_mode="constant",
            fill_value=0.0,
            data_format=self.data_format,
        )

    def _resize_images(self, images):
        """Resize images to `(height, width)` using backend API."""
        return self._resize(images, self.interpolation)

    def transform_images(self, images, transformation=None, training=True):
        """Apply random resized crop to a batch of images."""
        if transformation is None:
            transformation = self.get_random_transformation(
                images, training=training
            )
        h_start, w_start, crop_h, crop_w = transformation

        images = self._slice_images(images, h_start, w_start, crop_h, crop_w)
        images = self._resize_images(images)
        return images

    def transform_segmentation_masks(
        self, masks, transformation, training=True
    ):
        """Apply the same crop + resize to segmentation masks."""
        h_start, w_start, crop_h, crop_w = transformation
        masks = self._slice_images(masks, h_start, w_start, crop_h, crop_w)
        return self._resize(masks, "nearest")
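
Both the original code and the refactor resize masks with `"nearest"` rather than the layer's configured interpolation. A likely reason is that segmentation masks hold class ids, and nearest-neighbor sampling never blends two ids into a value that belongs to neither class. The pure-Python sketch below illustrates this on a tiny 2D mask; `resize_nearest` is a hypothetical helper and does not reproduce `backend.image.resize` exactly.

```python
def resize_nearest(mask, out_h, out_w):
    """Nearest-neighbor resize of a 2D list of integer class ids."""
    in_h, in_w = len(mask), len(mask[0])
    resized = []
    for i in range(out_h):
        # Map the center of each output pixel back to a source row.
        src_i = min(in_h - 1, int((i + 0.5) * in_h / out_h))
        row = []
        for j in range(out_w):
            src_j = min(in_w - 1, int((j + 0.5) * in_w / out_w))
            row.append(mask[src_i][src_j])
        resized.append(row)
    return resized


mask = [[0, 7],
        [7, 3]]
resized = resize_nearest(mask, 4, 4)
for row in resized:
    print(row)
```

Every value in the output is one of the original ids (0, 3, or 7). A bilinear resize of the same mask would produce intermediate values such as 3.5 at class boundaries, which do not correspond to any label.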

@codecov-commenter

codecov-commenter commented Dec 10, 2025

Codecov Report

❌ Patch coverage is 65.27778% with 50 lines in your changes missing coverage. Please review.
✅ Project coverage is 76.89%. Comparing base (46813a3) to head (641d67e).
⚠️ Report is 2 commits behind head on master.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| ...cessing/image_preprocessing/random_resized_crop.py | 65.03% | 44 Missing and 6 partials ⚠️ |

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21917      +/-   ##
==========================================
+ Coverage   76.30%   76.89%   +0.59%     
==========================================
  Files         580      581       +1     
  Lines       60031    60175     +144     
  Branches     9433     9443      +10     
==========================================
+ Hits        45805    46271     +466     
+ Misses      11750    11549     -201     
+ Partials     2476     2355     -121     

| Flag | Coverage Δ |
| --- | --- |
| keras | 76.79% <65.27%> (+0.62%) ⬆️ |
| keras-jax | 62.13% <65.27%> (+<0.01%) ⬆️ |
| keras-numpy | 57.34% <65.27%> (+0.02%) ⬆️ |
| keras-openvino | 34.25% <13.88%> (-0.05%) ⬇️ |
| keras-tensorflow | 64.30% <65.27%> (?) |
| keras-torch | ? |

