Skip to content

Conversation

@MalyalaKarthik66
Copy link
Contributor

@MalyalaKarthik66 MalyalaKarthik66 commented Nov 4, 2025

fix #21813

Add adaptive pooling support across major backends

This PR implements adaptive average and max pooling for 1D, 2D, and 3D across the JAX, NumPy, TensorFlow, and PyTorch backends.

  • For PyTorch, native adaptive pooling ops are used.

  • For JAX and TensorFlow, adaptive pooling is implemented using an efficient n-dimensional two-pool gather algorithm, eliminating multiple for-loops and providing robust performance on CPU, GPU, and TPU.

  • For NumPy, adaptive pooling is implemented using the same n-dimensional two-pool gather algorithm with NumPy stride tricks, providing an efficient pure-NumPy implementation for CPU.

  • All corresponding unit tests for JAX, NumPy, TensorFlow, and PyTorch adaptive pooling pass successfully.

  • Verified in real training model tests — both TensorFlow and PyTorch pass on GPU and CPU environments.

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @MalyalaKarthik66, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances Keras 3 by adding adaptive average and max pooling layers for 2D spatial data. These new layers allow users to specify a target output size, with the pooling kernel and stride automatically adjusted, providing greater flexibility in network architectures, particularly for tasks requiring fixed-size feature maps regardless of input dimensions. The implementation prioritizes the JAX backend while ensuring seamless integration with other Keras backends.

Highlights

  • New Adaptive Pooling Layers: Introduced AdaptiveAveragePooling2D and AdaptiveMaxPooling2D layers to Keras 3, allowing for a fixed output size regardless of input dimensions.
  • JAX Backend Implementation: Provided a PyTorch-compatible implementation for these adaptive pooling operations specifically for the JAX backend.
  • Unified Ops API: Exposed new keras.ops.adaptive_avg_pool and keras.ops.adaptive_max_pool functions for backend-agnostic usage.
  • Comprehensive Testing: Included extensive unit tests, numerical parity checks against PyTorch, and support for both channels_first and channels_last data formats.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces AdaptiveAveragePooling2D and AdaptiveMaxPooling2D layers, along with their corresponding backend operations. The changes include the layer definitions, JAX backend implementations, ops API, and comprehensive tests. The layer APIs and tests are well-designed. However, the JAX backend implementation has significant performance issues due to the use of Python loops, which are not JIT-compatible. There are also opportunities to improve code quality by removing dead code and reducing duplication. My review provides specific feedback on these points.

Comment on lines 1515 to 1533
for i in range(out_h):
for j in range(out_w):
# Calculate pooling region for this output position
start_h = jnp.floor((i * in_h) / out_h).astype(jnp.int32)
end_h = jnp.ceil(((i + 1) * in_h) / out_h).astype(jnp.int32)
start_w = jnp.floor((j * in_w) / out_w).astype(jnp.int32)
end_w = jnp.ceil(((j + 1) * in_w) / out_w).astype(jnp.int32)

# Extract region and apply average pooling
if data_format == "channels_last":
region = inputs[:, start_h:end_h, start_w:end_w, :]
# Average over spatial dimensions (axis 1, 2)
pooled = jnp.mean(region, axis=(1, 2))
else: # channels_first
region = inputs[:, :, start_h:end_h, start_w:end_w]
# Average over spatial dimensions (axis 2, 3)
pooled = jnp.mean(region, axis=(2, 3))

result_list.append(pooled)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation of adaptive pooling uses Python for loops to iterate over output positions. This is an anti-pattern in JAX as it prevents JIT compilation and leads to very poor performance, especially for larger inputs or output sizes. The computation should be expressed using JAX's vectorized operations or JIT-compatible loops like lax.fori_loop to achieve good performance. A fully vectorized einsum-based approach for average pooling, or a lax.fori_loop over output pixels for both pooling types, would be significantly more performant. This comment also applies to the adaptive_max_pool implementation.

Comment on lines 1469 to 1478
def _adaptive_pool_start_index(output_idx, output_size, input_size):
"""Calculate start index for adaptive pooling (PyTorch compatible)."""
return jnp.floor((output_idx * input_size) / output_size).astype(jnp.int32)


def _adaptive_pool_end_index(output_idx, output_size, input_size):
"""Calculate end index for adaptive pooling (PyTorch compatible)."""
return jnp.ceil(((output_idx + 1) * input_size) / output_size).astype(
jnp.int32
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The helper functions _adaptive_pool_start_index and _adaptive_pool_end_index are defined but not used. This dead code should be removed to improve code clarity.

Comment on lines 1481 to 1618
def adaptive_avg_pool(
inputs, output_size, data_format="channels_last", name=None
):
"""
Adaptive average pooling for JAX backend (PyTorch-compatible).
"""
# Convert output_size to tuple
spatial_dims = inputs.ndim - 2
if isinstance(output_size, int):
output_size = (output_size,) * spatial_dims
else:
output_size = tuple(output_size)

# Get spatial shape
if data_format == "channels_last":
batch_size = inputs.shape[0]
channels = inputs.shape[-1]
spatial_shape = inputs.shape[1:-1]
else: # channels_first
batch_size = inputs.shape[0]
channels = inputs.shape[1]
spatial_shape = inputs.shape[2:]

if len(output_size) != 2:
raise NotImplementedError(
"Only 2D adaptive pooling is currently supported"
)

out_h, out_w = output_size
in_h, in_w = spatial_shape

# Build output by iterating over output positions
result_list = []

for i in range(out_h):
for j in range(out_w):
# Calculate pooling region for this output position
start_h = jnp.floor((i * in_h) / out_h).astype(jnp.int32)
end_h = jnp.ceil(((i + 1) * in_h) / out_h).astype(jnp.int32)
start_w = jnp.floor((j * in_w) / out_w).astype(jnp.int32)
end_w = jnp.ceil(((j + 1) * in_w) / out_w).astype(jnp.int32)

# Extract region and apply average pooling
if data_format == "channels_last":
region = inputs[:, start_h:end_h, start_w:end_w, :]
# Average over spatial dimensions (axis 1, 2)
pooled = jnp.mean(region, axis=(1, 2))
else: # channels_first
region = inputs[:, :, start_h:end_h, start_w:end_w]
# Average over spatial dimensions (axis 2, 3)
pooled = jnp.mean(region, axis=(2, 3))

result_list.append(pooled)

# Stack results: (out_h*out_w, batch, channels)
output = jnp.stack(result_list, axis=0)

# Reshape and transpose to correct output shape
if data_format == "channels_last":
# (out_h*out_w, batch, channels) -> (batch, out_h, out_w, channels)
output = output.reshape(out_h, out_w, batch_size, channels)
output = jnp.transpose(output, (2, 0, 1, 3))
else: # channels_first
# (out_h*out_w, batch, channels) -> (batch, channels, out_h, out_w)
output = output.reshape(out_h, out_w, batch_size, channels)
output = jnp.transpose(output, (2, 3, 0, 1))

return output


def adaptive_max_pool(
inputs, output_size, data_format="channels_last", name=None
):
"""
Adaptive max pooling for JAX backend (PyTorch-compatible).
"""
# Convert output_size to tuple
spatial_dims = inputs.ndim - 2
if isinstance(output_size, int):
output_size = (output_size,) * spatial_dims
else:
output_size = tuple(output_size)

# Get spatial shape
if data_format == "channels_last":
batch_size = inputs.shape[0]
channels = inputs.shape[-1]
spatial_shape = inputs.shape[1:-1]
else: # channels_first
batch_size = inputs.shape[0]
channels = inputs.shape[1]
spatial_shape = inputs.shape[2:]

if len(output_size) != 2:
raise NotImplementedError(
"Only 2D adaptive pooling is currently supported"
)

out_h, out_w = output_size
in_h, in_w = spatial_shape

# Build output by iterating over output positions
result_list = []

for i in range(out_h):
for j in range(out_w):
# Calculate pooling region for this output position
start_h = jnp.floor((i * in_h) / out_h).astype(jnp.int32)
end_h = jnp.ceil(((i + 1) * in_h) / out_h).astype(jnp.int32)
start_w = jnp.floor((j * in_w) / out_w).astype(jnp.int32)
end_w = jnp.ceil(((j + 1) * in_w) / out_w).astype(jnp.int32)

# Extract region and apply max pooling
if data_format == "channels_last":
region = inputs[:, start_h:end_h, start_w:end_w, :]
# Max over spatial dimensions (axis 1, 2)
pooled = jnp.max(region, axis=(1, 2))
else: # channels_first
region = inputs[:, :, start_h:end_h, start_w:end_w]
# Max over spatial dimensions (axis 2, 3)
pooled = jnp.max(region, axis=(2, 3))

result_list.append(pooled)

# Stack results: (out_h*out_w, batch, channels)
output = jnp.stack(result_list, axis=0)

# Reshape and transpose to correct output shape
if data_format == "channels_last":
# (out_h*out_w, batch, channels) -> (batch, out_h, out_w, channels)
output = output.reshape(out_h, out_w, batch_size, channels)
output = jnp.transpose(output, (2, 0, 1, 3))
else: # channels_first
# (out_h*out_w, batch, channels) -> (batch, channels, out_h, out_w)
output = output.reshape(out_h, out_w, batch_size, channels)
output = jnp.transpose(output, (2, 3, 0, 1))

return output
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The functions adaptive_avg_pool and adaptive_max_pool are nearly identical, with the only difference being the pooling operation (jnp.mean vs jnp.max). This code duplication can be avoided by creating a generic _adaptive_pool helper function that takes the pooling function as an argument. This would improve maintainability and reduce redundancy.

For example:

def _adaptive_pool(inputs, output_size, data_format, pool_op):
    # ... common setup code ...
    for i in range(out_h):
        for j in range(out_w):
            # ... common region calculation ...
            if data_format == "channels_last":
                region = inputs[:, start_h:end_h, start_w:end_w, :]
                pooled = pool_op(region, axis=(1, 2))
            else:  # channels_first
                region = inputs[:, :, start_h:end_h, start_w:end_w]
                pooled = pool_op(region, axis=(2, 3))
            result_list.append(pooled)
    # ... common reshape and transpose code ...
    return output

def adaptive_avg_pool(inputs, output_size, data_format="channels_last", name=None):
    # ...
    return _adaptive_pool(inputs, output_size, data_format, jnp.mean)

def adaptive_max_pool(inputs, output_size, data_format="channels_last", name=None):
    # ...
    return _adaptive_pool(inputs, output_size, data_format, jnp.max)

Note that this refactoring suggestion still contains the performance issue mentioned in another comment. The primary goal here is to illustrate how to reduce code duplication.

@codecov-commenter
Copy link

codecov-commenter commented Nov 4, 2025

Codecov Report

❌ Patch coverage is 82.96623% with 116 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.59%. Comparing base (a2897fa) to head (5fb2523).

Files with missing lines Patch % Lines
keras/src/backend/tensorflow/nn.py 85.18% 14 Missing and 14 partials ⚠️
keras/src/ops/nn.py 34.88% 24 Missing and 4 partials ⚠️
keras/src/backend/jax/nn.py 90.24% 8 Missing and 8 partials ⚠️
keras/src/backend/torch/nn.py 73.33% 6 Missing and 6 partials ⚠️
keras/src/backend/numpy/nn.py 92.06% 5 Missing and 5 partials ⚠️
...s/src/layers/pooling/adaptive_average_pooling1d.py 53.84% 5 Missing and 1 partial ⚠️
keras/src/layers/pooling/adaptive_max_pooling1d.py 53.84% 5 Missing and 1 partial ⚠️
...s/src/layers/pooling/adaptive_average_pooling2d.py 81.81% 1 Missing and 1 partial ⚠️
...s/src/layers/pooling/adaptive_average_pooling3d.py 81.81% 1 Missing and 1 partial ⚠️
keras/src/layers/pooling/adaptive_max_pooling2d.py 81.81% 1 Missing and 1 partial ⚠️
... and 2 more
Additional details and impacted files
@@           Coverage Diff            @@
##           master   #21820    +/-   ##
========================================
  Coverage   82.59%   82.59%            
========================================
  Files         580      587     +7     
  Lines       60226    60907   +681     
  Branches     9444     9549   +105     
========================================
+ Hits        49742    50307   +565     
- Misses       8050     8122    +72     
- Partials     2434     2478    +44     
Flag Coverage Δ
keras 82.42% <82.96%> (+<0.01%) ⬆️
keras-jax 61.66% <38.91%> (-0.26%) ⬇️
keras-numpy 56.89% <35.68%> (-0.24%) ⬇️
keras-openvino 35.40% <12.92%> (-0.26%) ⬇️
keras-tensorflow 63.85% <41.11%> (-0.26%) ⬇️
keras-torch 62.58% <23.05%> (-0.45%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@MalyalaKarthik66 MalyalaKarthik66 changed the title Add AdaptiveAveragePooling2D and AdaptiveMaxPooling2D layers Add adaptive pooling (1D, 2D, 3D) support across JAX, TensorFlow, and PyTorch backends Nov 13, 2025
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds adaptive pooling support for JAX, TensorFlow, and PyTorch backends, which is a great addition. The implementation for JAX and TensorFlow uses a custom "Two-Pool Gather" algorithm, while the PyTorch implementation leverages native operations. The code is well-structured and includes corresponding unit tests.

My review focuses on improving maintainability by reducing code duplication in the backend implementations, ensuring user-facing elements like docstrings and error messages are clear and accurate, and maintaining code style consistency. I've provided several suggestions to address these points.

Comment on lines 1508 to 1509
n, l, c = inputs.shape
out_l = output_size[0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The variable names n, l, c, and out_l are quite short. According to the Keras API design guidelines, it's preferred to use fully spelled-out names to improve readability, with a few common exceptions like dim and num.1 Consider using more descriptive names like batch_size, length, channels, and output_length. This comment also applies to the other adaptive pooling functions in this file.

For example:
n, l, c = inputs.shape -> batch_size, length, channels = inputs.shape
out_l = output_size[0] -> output_length = output_size[0]

Style Guide References

Footnotes

  1. The style guide recommends using fully spelled-out names for variables and arguments to improve clarity, e.g., attention_scores instead of attn_scores. Short names are acceptable only for very common terms like dim or num.

Comment on lines 1499 to 1831
# ---------- 1D Adaptive Pooling ----------
def adaptive_avg_pool1d(inputs, output_size, data_format="channels_first"):
"""Adaptive Average Pooling 1D using Two-Pool Gather method."""
if isinstance(output_size, int):
output_size = (output_size,)

if data_format == "channels_first":
inputs = jnp.transpose(inputs, (0, 2, 1)) # NCL -> NLC

n, l, c = inputs.shape
out_l = output_size[0]

small_l, big_l = get_static_window_sizes(l, out_l)
gather_l = compute_static_gather_indices(l, out_l, big_l)

small_pool_l = lax.reduce_window(
inputs, 0.0, lax.add, (1, small_l, 1), (1, 1, 1), "valid"
)
small_pool_l = small_pool_l / small_l

big_pool_l = lax.reduce_window(
inputs, 0.0, lax.add, (1, big_l, 1), (1, 1, 1), "valid"
)
big_pool_l = big_pool_l / big_l

combined_l = jnp.concatenate([small_pool_l, big_pool_l], axis=1)
pooled_l = jnp.take(combined_l, gather_l, axis=1)

if data_format == "channels_first":
pooled_l = jnp.transpose(pooled_l, (0, 2, 1)) # NLC -> NCL

return pooled_l


def adaptive_max_pool1d(inputs, output_size, data_format="channels_first"):
"""Adaptive Max Pooling 1D using Two-Pool Gather method."""
if isinstance(output_size, int):
output_size = (output_size,)

if data_format == "channels_first":
inputs = jnp.transpose(inputs, (0, 2, 1)) # NCL -> NLC

n, l, c = inputs.shape
out_l = output_size[0]

small_l, big_l = get_static_window_sizes(l, out_l)
gather_l = compute_static_gather_indices(l, out_l, big_l)

small_pool_l = lax.reduce_window(
inputs, -jnp.inf, lax.max, (1, small_l, 1), (1, 1, 1), "valid"
)
big_pool_l = lax.reduce_window(
inputs, -jnp.inf, lax.max, (1, big_l, 1), (1, 1, 1), "valid"
)

combined_l = jnp.concatenate([small_pool_l, big_pool_l], axis=1)
pooled_l = jnp.take(combined_l, gather_l, axis=1)

if data_format == "channels_first":
pooled_l = jnp.transpose(pooled_l, (0, 2, 1)) # NLC -> NCL

return pooled_l


# ---------- 2D Adaptive Pooling ----------
def adaptive_avg_pool2d(inputs, output_size, data_format="channels_first"):
"""Adaptive Average Pooling 2D using Two-Pool Gather method."""
if isinstance(output_size, int):
output_size = (output_size, output_size)

if data_format == "channels_first":
inputs = jnp.transpose(inputs, (0, 2, 3, 1)) # NCHW -> NHWC

n, h, w, c = inputs.shape
out_h, out_w = output_size

small_h, big_h = get_static_window_sizes(h, out_h)
gather_h = compute_static_gather_indices(h, out_h, big_h)

small_w, big_w = get_static_window_sizes(w, out_w)
gather_w = compute_static_gather_indices(w, out_w, big_w)

small_pool_h = lax.reduce_window(
inputs, 0.0, lax.add, (1, small_h, 1, 1), (1, 1, 1, 1), "valid"
)
small_pool_h = small_pool_h / small_h

big_pool_h = lax.reduce_window(
inputs, 0.0, lax.add, (1, big_h, 1, 1), (1, 1, 1, 1), "valid"
)
big_pool_h = big_pool_h / big_h

combined_h = jnp.concatenate([small_pool_h, big_pool_h], axis=1)
pooled_h = jnp.take(combined_h, gather_h, axis=1)

small_pool_w = lax.reduce_window(
pooled_h, 0.0, lax.add, (1, 1, small_w, 1), (1, 1, 1, 1), "valid"
)
small_pool_w = small_pool_w / small_w

big_pool_w = lax.reduce_window(
pooled_h, 0.0, lax.add, (1, 1, big_w, 1), (1, 1, 1, 1), "valid"
)
big_pool_w = big_pool_w / big_w

combined_w = jnp.concatenate([small_pool_w, big_pool_w], axis=2)
pooled_w = jnp.take(combined_w, gather_w, axis=2)

if data_format == "channels_first":
pooled_w = jnp.transpose(pooled_w, (0, 3, 1, 2)) # NHWC -> NCHW

return pooled_w


def adaptive_max_pool2d(inputs, output_size, data_format="channels_first"):
"""Adaptive Max Pooling 2D using Two-Pool Gather method."""
if isinstance(output_size, int):
output_size = (output_size, output_size)

if data_format == "channels_first":
inputs = jnp.transpose(inputs, (0, 2, 3, 1)) # NCHW -> NHWC

n, h, w, c = inputs.shape
out_h, out_w = output_size

small_h, big_h = get_static_window_sizes(h, out_h)
gather_h = compute_static_gather_indices(h, out_h, big_h)

small_w, big_w = get_static_window_sizes(w, out_w)
gather_w = compute_static_gather_indices(w, out_w, big_w)

small_pool_h = lax.reduce_window(
inputs, -jnp.inf, lax.max, (1, small_h, 1, 1), (1, 1, 1, 1), "valid"
)
big_pool_h = lax.reduce_window(
inputs, -jnp.inf, lax.max, (1, big_h, 1, 1), (1, 1, 1, 1), "valid"
)

combined_h = jnp.concatenate([small_pool_h, big_pool_h], axis=1)
pooled_h = jnp.take(combined_h, gather_h, axis=1)

small_pool_w = lax.reduce_window(
pooled_h, -jnp.inf, lax.max, (1, 1, small_w, 1), (1, 1, 1, 1), "valid"
)
big_pool_w = lax.reduce_window(
pooled_h, -jnp.inf, lax.max, (1, 1, big_w, 1), (1, 1, 1, 1), "valid"
)

combined_w = jnp.concatenate([small_pool_w, big_pool_w], axis=2)
pooled_w = jnp.take(combined_w, gather_w, axis=2)

if data_format == "channels_first":
pooled_w = jnp.transpose(pooled_w, (0, 3, 1, 2)) # NHWC -> NCHW

return pooled_w


# ---------- 3D Adaptive Pooling ----------
def adaptive_avg_pool3d(inputs, output_size, data_format="channels_first"):
"""Adaptive Average Pooling 3D using Two-Pool Gather method."""
if isinstance(output_size, int):
output_size = (output_size, output_size, output_size)

if data_format == "channels_first":
inputs = jnp.transpose(inputs, (0, 2, 3, 4, 1)) # NCDHW -> NDHWC

n, d, h, w, c = inputs.shape
out_d, out_h, out_w = output_size

small_d, big_d = get_static_window_sizes(d, out_d)
gather_d = compute_static_gather_indices(d, out_d, big_d)

small_h, big_h = get_static_window_sizes(h, out_h)
gather_h = compute_static_gather_indices(h, out_h, big_h)

small_w, big_w = get_static_window_sizes(w, out_w)
gather_w = compute_static_gather_indices(w, out_w, big_w)

small_pool_d = lax.reduce_window(
inputs, 0.0, lax.add, (1, small_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid"
)
small_pool_d = small_pool_d / small_d

big_pool_d = lax.reduce_window(
inputs, 0.0, lax.add, (1, big_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid"
)
big_pool_d = big_pool_d / big_d

combined_d = jnp.concatenate([small_pool_d, big_pool_d], axis=1)
pooled_d = jnp.take(combined_d, gather_d, axis=1)

small_pool_h = lax.reduce_window(
pooled_d, 0.0, lax.add, (1, 1, small_h, 1, 1), (1, 1, 1, 1, 1), "valid"
)
small_pool_h = small_pool_h / small_h

big_pool_h = lax.reduce_window(
pooled_d, 0.0, lax.add, (1, 1, big_h, 1, 1), (1, 1, 1, 1, 1), "valid"
)
big_pool_h = big_pool_h / big_h

combined_h = jnp.concatenate([small_pool_h, big_pool_h], axis=2)
pooled_h = jnp.take(combined_h, gather_h, axis=2)

small_pool_w = lax.reduce_window(
pooled_h, 0.0, lax.add, (1, 1, 1, small_w, 1), (1, 1, 1, 1, 1), "valid"
)
small_pool_w = small_pool_w / small_w

big_pool_w = lax.reduce_window(
pooled_h, 0.0, lax.add, (1, 1, 1, big_w, 1), (1, 1, 1, 1, 1), "valid"
)
big_pool_w = big_pool_w / big_w

combined_w = jnp.concatenate([small_pool_w, big_pool_w], axis=3)
pooled_w = jnp.take(combined_w, gather_w, axis=3)

if data_format == "channels_first":
pooled_w = jnp.transpose(pooled_w, (0, 4, 1, 2, 3)) # NDHWC -> NCDHW

return pooled_w


def adaptive_max_pool3d(inputs, output_size, data_format="channels_first"):
"""Adaptive Max Pooling 3D using Two-Pool Gather method."""
if isinstance(output_size, int):
output_size = (output_size, output_size, output_size)

if data_format == "channels_first":
inputs = jnp.transpose(inputs, (0, 2, 3, 4, 1)) # NCDHW -> NDHWC

n, d, h, w, c = inputs.shape
out_d, out_h, out_w = output_size

small_d, big_d = get_static_window_sizes(d, out_d)
gather_d = compute_static_gather_indices(d, out_d, big_d)

small_h, big_h = get_static_window_sizes(h, out_h)
gather_h = compute_static_gather_indices(h, out_h, big_h)

small_w, big_w = get_static_window_sizes(w, out_w)
gather_w = compute_static_gather_indices(w, out_w, big_w)

small_pool_d = lax.reduce_window(
inputs,
-jnp.inf,
lax.max,
(1, small_d, 1, 1, 1),
(1, 1, 1, 1, 1),
"valid",
)
big_pool_d = lax.reduce_window(
inputs, -jnp.inf, lax.max, (1, big_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid"
)

combined_d = jnp.concatenate([small_pool_d, big_pool_d], axis=1)
pooled_d = jnp.take(combined_d, gather_d, axis=1)

small_pool_h = lax.reduce_window(
pooled_d,
-jnp.inf,
lax.max,
(1, 1, small_h, 1, 1),
(1, 1, 1, 1, 1),
"valid",
)
big_pool_h = lax.reduce_window(
pooled_d,
-jnp.inf,
lax.max,
(1, 1, big_h, 1, 1),
(1, 1, 1, 1, 1),
"valid",
)

combined_h = jnp.concatenate([small_pool_h, big_pool_h], axis=2)
pooled_h = jnp.take(combined_h, gather_h, axis=2)

small_pool_w = lax.reduce_window(
pooled_h,
-jnp.inf,
lax.max,
(1, 1, 1, small_w, 1),
(1, 1, 1, 1, 1),
"valid",
)
big_pool_w = lax.reduce_window(
pooled_h,
-jnp.inf,
lax.max,
(1, 1, 1, big_w, 1),
(1, 1, 1, 1, 1),
"valid",
)

combined_w = jnp.concatenate([small_pool_w, big_pool_w], axis=3)
pooled_w = jnp.take(combined_w, gather_w, axis=3)

if data_format == "channels_first":
pooled_w = jnp.transpose(pooled_w, (0, 4, 1, 2, 3)) # NDHWC -> NCDHW

return pooled_w


# ---------- Dispatcher ----------
def adaptive_avg_pool(inputs, output_size, data_format="channels_first"):
"""Dispatcher for adaptive average pooling (1D, 2D, or 3D)."""
ndims = inputs.ndim - 2
if ndims == 1:
return adaptive_avg_pool1d(inputs, output_size, data_format)
elif ndims == 2:
return adaptive_avg_pool2d(inputs, output_size, data_format)
elif ndims == 3:
return adaptive_avg_pool3d(inputs, output_size, data_format)
else:
raise ValueError(
"adaptive_avg_pool supports 1D, 2D, or 3D inputs only."
)


def adaptive_max_pool(inputs, output_size, data_format="channels_first"):
"""Dispatcher for adaptive max pooling (1D, 2D, or 3D)."""
ndims = inputs.ndim - 2
if ndims == 1:
return adaptive_max_pool1d(inputs, output_size, data_format)
elif ndims == 2:
return adaptive_max_pool2d(inputs, output_size, data_format)
elif ndims == 3:
return adaptive_max_pool3d(inputs, output_size, data_format)
else:
raise ValueError(
"adaptive_max_pool supports 1D, 2D, or 3D inputs only."
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The implementations for adaptive_avg_pool{1,2,3}d and adaptive_max_pool{1,2,3}d are very similar, leading to significant code duplication. To improve maintainability, consider refactoring this code.

Here are a couple of suggestions:

  1. Create a helper function for each dimension (e.g., _adaptive_pool1d) that takes the pooling type ('avg' or 'max') as an argument. This would halve the number of functions.
  2. A more advanced refactoring would be to create a single generic n-dimensional pooling function that iterates over the spatial dimensions. This would further consolidate the logic for 1D, 2D, and 3D pooling into one place.

Comment on lines 138 to 141
raise NotImplementedError(
"Adaptive pooling not implemented for OpenVINO. "
"Use JAX or Torch backend."
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The error message is missing TensorFlow as a supported backend for adaptive pooling. Please update the message to include it for accuracy.

Suggested change
raise NotImplementedError(
"Adaptive pooling not implemented for OpenVINO. "
"Use JAX or Torch backend."
)
raise NotImplementedError(
"Adaptive pooling not implemented for OpenVINO. "
"Use JAX, TensorFlow, or Torch backend."
)

Comment on lines 158 to 161
raise NotImplementedError(
"Adaptive pooling not implemented for OpenVINO. "
"Use JAX or Torch backend."
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The error message is missing TensorFlow as a supported backend for adaptive pooling. Please update the message to include it for accuracy.

Suggested change
raise NotImplementedError(
"Adaptive pooling not implemented for OpenVINO. "
"Use JAX or Torch backend."
)
raise NotImplementedError(
"Adaptive pooling not implemented for OpenVINO. "
"Use JAX, TensorFlow, or Torch backend."
)

Comment on lines +301 to +361
static_shape = inputs.shape.as_list()
l_static = static_shape[1]
out_l = output_size[0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The variable names l_static and out_l are quite short. According to the Keras API design guidelines, it's preferred to use fully spelled-out names to improve readability.1 Consider using more descriptive names like static_length and output_length. This comment also applies to the other adaptive pooling functions in this file.

For example:
l_static = static_shape[1] -> static_length = static_shape[1]
out_l = output_size[0] -> output_length = output_size[0]

Style Guide References

Footnotes

  1. The style guide recommends using fully spelled-out names for variables and arguments to improve clarity, e.g., attention_scores instead of attn_scores. Short names are acceptable only for very common terms like dim or num.

Comment on lines 74 to 79
raise TypeError(
"`output_size` must be an integer or tuple of 3 integers. "
"Received output_size={} of type {}".format(
output_size, type(output_size)
)
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For consistency with other new files in this PR, please use an f-string for this error message.

            raise TypeError(
                f"`output_size` must be an integer or tuple of 3 integers. "
                f"Received: output_size={output_size} of type "
                f"{type(output_size)}"
            )

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 please fix

Comment on lines 53 to 65
if not isinstance(output_size, int):
raise TypeError(
"`output_size` must be an integer. Received output_size={} "
"of type {}".format(output_size, type(output_size))
)
self.output_size = output_size
self.data_format = data_format or config.image_data_format()

if self.data_format not in {"channels_first", "channels_last"}:
raise ValueError(
"Invalid data_format: {}. Must be either 'channels_first' "
"or 'channels_last'.".format(self.data_format)
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For consistency with other new files in this PR, please use f-strings for these error messages instead of .format().

Suggested change
if not isinstance(output_size, int):
raise TypeError(
"`output_size` must be an integer. Received output_size={} "
"of type {}".format(output_size, type(output_size))
)
self.output_size = output_size
self.data_format = data_format or config.image_data_format()
if self.data_format not in {"channels_first", "channels_last"}:
raise ValueError(
"Invalid data_format: {}. Must be either 'channels_first' "
"or 'channels_last'.".format(self.data_format)
)
if not isinstance(output_size, int):
raise TypeError(
f"`output_size` must be an integer. Received: output_size={output_size} "
f"of type {type(output_size)}"
)
self.output_size = output_size
self.data_format = data_format or config.image_data_format()
if self.data_format not in {"channels_first", "channels_last"}:
raise ValueError(
f"Invalid data_format: {self.data_format}. Must be either 'channels_first' "
f"or 'channels_last'."
)

Comment on lines 66 to 84
if len(output_size) != 3:
raise ValueError(
"`output_size` must be an integer or tuple of 3 integers. "
"Received: {}".format(output_size)
)
self.output_size = tuple(output_size)
else:
raise TypeError(
"`output_size` must be an integer or tuple of 3 integers. "
"Received: {} of type {}".format(output_size, type(output_size))
)

self.data_format = data_format or config.image_data_format()

if self.data_format not in {"channels_first", "channels_last"}:
raise ValueError(
"Invalid data_format: {}. Must be either 'channels_first' or "
"'channels_last'.".format(self.data_format)
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For consistency with other new files in this PR, please use f-strings for these error messages instead of .format().

Suggested change
if len(output_size) != 3:
raise ValueError(
"`output_size` must be an integer or tuple of 3 integers. "
"Received: {}".format(output_size)
)
self.output_size = tuple(output_size)
else:
raise TypeError(
"`output_size` must be an integer or tuple of 3 integers. "
"Received: {} of type {}".format(output_size, type(output_size))
)
self.data_format = data_format or config.image_data_format()
if self.data_format not in {"channels_first", "channels_last"}:
raise ValueError(
"Invalid data_format: {}. Must be either 'channels_first' or "
"'channels_last'.".format(self.data_format)
)
if len(output_size) != 3:
raise ValueError(
f"`output_size` must be an integer or tuple of 3 integers. "
f"Received: {output_size}"
)
self.output_size = tuple(output_size)
else:
raise TypeError(
f"`output_size` must be an integer or tuple of 3 integers. "
f"Received: output_size={output_size} of type {type(output_size)}"
)
self.data_format = data_format or config.image_data_format()
if self.data_format not in {"channels_first", "channels_last"}:
raise ValueError(
f"Invalid data_format: {self.data_format}. Must be either 'channels_first' or "
f"'channels_last'."
)

Comment on lines +1172 to +1235
"""Adaptive max pooling operation.
Applies an adaptive max pooling operation that automatically computes the
kernel size and stride to pool the input to the specified `output_size`.
This operation is useful when you want a fixed output size regardless of
input size, commonly used in models like ResNet for global feature
extraction.
Args:
inputs: Tensor of rank 4. Input tensor of shape:
- If `data_format="channels_last"`:
`(batch_size, height, width, channels)`.
- If `data_format="channels_first"`:
`(batch_size, channels, height, width)`.
output_size: Integer or tuple/list of 2 integers, specifying the target
output spatial dimensions `(output_height, output_width)`. If a
single
integer is provided, the same value is used for both dimensions.
data_format: string, either `"channels_last"` or `"channels_first"`.
Defaults to the value found in your Keras config file at
`~/.keras/keras.json`. If never set, defaults to `"channels_last"`.
Returns:
A tensor of rank 4 representing the adaptive max pooled result.
Example:
>>> x = np.random.rand(2, 64, 64, 3)
>>> y = keras.ops.adaptive_max_pool(x, output_size=(32, 32))
>>> y.shape
(2, 32, 32, 3)
>>> # Works with any input size
>>> x = np.random.rand(2, 100, 80, 3)
>>> y = keras.ops.adaptive_max_pool(x, output_size=7)
>>> y.shape
(2, 7, 7, 3)
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The docstring describes the 2D case, but this function is a dispatcher for 1D, 2D, and 3D pooling. Please update the docstring to be more general and include examples for other dimensions to avoid confusion for users.1

    """Adaptive max pooling operation for 1D, 2D, and 3D data.

    This operation is useful when you want a fixed output size regardless of
    input size.

    Args:
        inputs: Input tensor. Must be 3D, 4D, or 5D.
        output_size: An integer or a tuple of integers, specifying the output
            spatial dimensions.
        data_format: string, either `"channels_last"` or `"channels_first"`.
            Defaults to the value found in your Keras config file at
            `~/.keras/keras.json`. If never set, defaults to `"channels_last"`.

    Returns:
        A tensor representing the adaptive max pooled result.

    Example:

    **2D Example**

    >>> x = np.random.rand(2, 64, 64, 3)
    >>> y = keras.ops.adaptive_max_pool(x, output_size=(32, 32))
    >>> y.shape
    (2, 32, 32, 3)

    **3D Example**

    >>> x = np.random.rand(2, 32, 32, 32, 3)
    >>> y = keras.ops.adaptive_max_pool(x, output_size=(16, 16, 16))
    >>> y.shape
    (2, 16, 16, 16, 3)
    """

Style Guide References

Footnotes

  1. Docstrings should be comprehensive and show examples for common use cases and key features to guide the user effectively.

Comment on lines 1319 to 1412
"""Adaptive average pooling operation.
Applies an adaptive average pooling operation that automatically
computes the
kernel size and stride to pool the input to the specified `output_size`.
This operation is useful when you want a fixed output size regardless of
input size, commonly used in models like ResNet for global feature
extraction.
Args:
inputs: Tensor of rank 4. Input tensor of shape:
- If `data_format="channels_last"`:
`(batch_size, height, width, channels)`.
- If `data_format="channels_first"`:
`(batch_size, channels, height, width)`.
output_size: Integer or tuple/list of 2 integers, specifying the target
output spatial dimensions `(output_height, output_width)`. If a
single
integer is provided, the same value is used for both dimensions.
data_format: string, either `"channels_last"` or `"channels_first"`.
Defaults to the value found in your Keras config file at
`~/.keras/keras.json`. If never set, defaults to `"channels_last"`.
Returns:
A tensor of rank 4 representing the adaptive average pooled result.
Example:
>>> x = np.random.rand(2, 64, 64, 3)
>>> y = keras.ops.adaptive_avg_pool(x, output_size=(32, 32))
>>> y.shape
(2, 32, 32, 3)
>>> # Works with any input size
>>> x = np.random.rand(2, 100, 80, 3)
>>> y = keras.ops.adaptive_avg_pool(x, output_size=7)
>>> y.shape
(2, 7, 7, 3)
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The docstring describes the 2D case, but this function is a dispatcher for 1D, 2D, and 3D pooling. Please update the docstring to be more general and include examples for other dimensions to avoid confusion for users.1

    """Adaptive average pooling operation for 1D, 2D, and 3D data.

    This operation is useful when you want a fixed output size regardless of
    input size.

    Args:
        inputs: Input tensor. Must be 3D, 4D, or 5D.
        output_size: An integer or a tuple of integers, specifying the output
            spatial dimensions.
        data_format: string, either `"channels_last"` or `"channels_first"`.
            Defaults to the value found in your Keras config file at
            `~/.keras/keras.json`. If never set, defaults to `"channels_last"`.

    Returns:
        A tensor representing the adaptive average pooled result.

    Example:

    **2D Example**

    >>> x = np.random.rand(2, 64, 64, 3)
    >>> y = keras.ops.adaptive_avg_pool(x, output_size=(32, 32))
    >>> y.shape
    (2, 32, 32, 3)

    **3D Example**

    >>> x = np.random.rand(2, 32, 32, 32, 3)
    >>> y = keras.ops.adaptive_avg_pool(x, output_size=(16, 16, 16))
    >>> y.shape
    (2, 16, 16, 16, 3)
    """

Style Guide References

Footnotes

  1. Docstrings should be comprehensive and show examples for common use cases and key features to guide the user effectively.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces adaptive pooling (1D, 2D, and 3D) for JAX, TensorFlow, and PyTorch backends, along with the corresponding Keras layers and tests. The implementation for PyTorch leverages native operations, while for JAX and TensorFlow, a custom "Two-Pool Gather" algorithm is used. The changes are comprehensive and well-tested. My main feedback is to refactor the JAX and TensorFlow implementations to reduce significant code duplication, which will improve maintainability. I've also noted some minor issues with error messages and docstrings.

Comment on lines 1500 to 1831
def adaptive_avg_pool1d(inputs, output_size, data_format="channels_first"):
"""Adaptive Average Pooling 1D using Two-Pool Gather method."""
if isinstance(output_size, int):
output_size = (output_size,)

if data_format == "channels_first":
inputs = jnp.transpose(inputs, (0, 2, 1)) # NCL -> NLC

n, l, c = inputs.shape
out_l = output_size[0]

small_l, big_l = get_static_window_sizes(l, out_l)
gather_l = compute_static_gather_indices(l, out_l, big_l)

small_pool_l = lax.reduce_window(
inputs, 0.0, lax.add, (1, small_l, 1), (1, 1, 1), "valid"
)
small_pool_l = small_pool_l / small_l

big_pool_l = lax.reduce_window(
inputs, 0.0, lax.add, (1, big_l, 1), (1, 1, 1), "valid"
)
big_pool_l = big_pool_l / big_l

combined_l = jnp.concatenate([small_pool_l, big_pool_l], axis=1)
pooled_l = jnp.take(combined_l, gather_l, axis=1)

if data_format == "channels_first":
pooled_l = jnp.transpose(pooled_l, (0, 2, 1)) # NLC -> NCL

return pooled_l


def adaptive_max_pool1d(inputs, output_size, data_format="channels_first"):
"""Adaptive Max Pooling 1D using Two-Pool Gather method."""
if isinstance(output_size, int):
output_size = (output_size,)

if data_format == "channels_first":
inputs = jnp.transpose(inputs, (0, 2, 1)) # NCL -> NLC

n, l, c = inputs.shape
out_l = output_size[0]

small_l, big_l = get_static_window_sizes(l, out_l)
gather_l = compute_static_gather_indices(l, out_l, big_l)

small_pool_l = lax.reduce_window(
inputs, -jnp.inf, lax.max, (1, small_l, 1), (1, 1, 1), "valid"
)
big_pool_l = lax.reduce_window(
inputs, -jnp.inf, lax.max, (1, big_l, 1), (1, 1, 1), "valid"
)

combined_l = jnp.concatenate([small_pool_l, big_pool_l], axis=1)
pooled_l = jnp.take(combined_l, gather_l, axis=1)

if data_format == "channels_first":
pooled_l = jnp.transpose(pooled_l, (0, 2, 1)) # NLC -> NCL

return pooled_l


# ---------- 2D Adaptive Pooling ----------
def adaptive_avg_pool2d(inputs, output_size, data_format="channels_first"):
"""Adaptive Average Pooling 2D using Two-Pool Gather method."""
if isinstance(output_size, int):
output_size = (output_size, output_size)

if data_format == "channels_first":
inputs = jnp.transpose(inputs, (0, 2, 3, 1)) # NCHW -> NHWC

n, h, w, c = inputs.shape
out_h, out_w = output_size

small_h, big_h = get_static_window_sizes(h, out_h)
gather_h = compute_static_gather_indices(h, out_h, big_h)

small_w, big_w = get_static_window_sizes(w, out_w)
gather_w = compute_static_gather_indices(w, out_w, big_w)

small_pool_h = lax.reduce_window(
inputs, 0.0, lax.add, (1, small_h, 1, 1), (1, 1, 1, 1), "valid"
)
small_pool_h = small_pool_h / small_h

big_pool_h = lax.reduce_window(
inputs, 0.0, lax.add, (1, big_h, 1, 1), (1, 1, 1, 1), "valid"
)
big_pool_h = big_pool_h / big_h

combined_h = jnp.concatenate([small_pool_h, big_pool_h], axis=1)
pooled_h = jnp.take(combined_h, gather_h, axis=1)

small_pool_w = lax.reduce_window(
pooled_h, 0.0, lax.add, (1, 1, small_w, 1), (1, 1, 1, 1), "valid"
)
small_pool_w = small_pool_w / small_w

big_pool_w = lax.reduce_window(
pooled_h, 0.0, lax.add, (1, 1, big_w, 1), (1, 1, 1, 1), "valid"
)
big_pool_w = big_pool_w / big_w

combined_w = jnp.concatenate([small_pool_w, big_pool_w], axis=2)
pooled_w = jnp.take(combined_w, gather_w, axis=2)

if data_format == "channels_first":
pooled_w = jnp.transpose(pooled_w, (0, 3, 1, 2)) # NHWC -> NCHW

return pooled_w


def adaptive_max_pool2d(inputs, output_size, data_format="channels_first"):
"""Adaptive Max Pooling 2D using Two-Pool Gather method."""
if isinstance(output_size, int):
output_size = (output_size, output_size)

if data_format == "channels_first":
inputs = jnp.transpose(inputs, (0, 2, 3, 1)) # NCHW -> NHWC

n, h, w, c = inputs.shape
out_h, out_w = output_size

small_h, big_h = get_static_window_sizes(h, out_h)
gather_h = compute_static_gather_indices(h, out_h, big_h)

small_w, big_w = get_static_window_sizes(w, out_w)
gather_w = compute_static_gather_indices(w, out_w, big_w)

small_pool_h = lax.reduce_window(
inputs, -jnp.inf, lax.max, (1, small_h, 1, 1), (1, 1, 1, 1), "valid"
)
big_pool_h = lax.reduce_window(
inputs, -jnp.inf, lax.max, (1, big_h, 1, 1), (1, 1, 1, 1), "valid"
)

combined_h = jnp.concatenate([small_pool_h, big_pool_h], axis=1)
pooled_h = jnp.take(combined_h, gather_h, axis=1)

small_pool_w = lax.reduce_window(
pooled_h, -jnp.inf, lax.max, (1, 1, small_w, 1), (1, 1, 1, 1), "valid"
)
big_pool_w = lax.reduce_window(
pooled_h, -jnp.inf, lax.max, (1, 1, big_w, 1), (1, 1, 1, 1), "valid"
)

combined_w = jnp.concatenate([small_pool_w, big_pool_w], axis=2)
pooled_w = jnp.take(combined_w, gather_w, axis=2)

if data_format == "channels_first":
pooled_w = jnp.transpose(pooled_w, (0, 3, 1, 2)) # NHWC -> NCHW

return pooled_w


# ---------- 3D Adaptive Pooling ----------
def adaptive_avg_pool3d(inputs, output_size, data_format="channels_first"):
"""Adaptive Average Pooling 3D using Two-Pool Gather method."""
if isinstance(output_size, int):
output_size = (output_size, output_size, output_size)

if data_format == "channels_first":
inputs = jnp.transpose(inputs, (0, 2, 3, 4, 1)) # NCDHW -> NDHWC

n, d, h, w, c = inputs.shape
out_d, out_h, out_w = output_size

small_d, big_d = get_static_window_sizes(d, out_d)
gather_d = compute_static_gather_indices(d, out_d, big_d)

small_h, big_h = get_static_window_sizes(h, out_h)
gather_h = compute_static_gather_indices(h, out_h, big_h)

small_w, big_w = get_static_window_sizes(w, out_w)
gather_w = compute_static_gather_indices(w, out_w, big_w)

small_pool_d = lax.reduce_window(
inputs, 0.0, lax.add, (1, small_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid"
)
small_pool_d = small_pool_d / small_d

big_pool_d = lax.reduce_window(
inputs, 0.0, lax.add, (1, big_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid"
)
big_pool_d = big_pool_d / big_d

combined_d = jnp.concatenate([small_pool_d, big_pool_d], axis=1)
pooled_d = jnp.take(combined_d, gather_d, axis=1)

small_pool_h = lax.reduce_window(
pooled_d, 0.0, lax.add, (1, 1, small_h, 1, 1), (1, 1, 1, 1, 1), "valid"
)
small_pool_h = small_pool_h / small_h

big_pool_h = lax.reduce_window(
pooled_d, 0.0, lax.add, (1, 1, big_h, 1, 1), (1, 1, 1, 1, 1), "valid"
)
big_pool_h = big_pool_h / big_h

combined_h = jnp.concatenate([small_pool_h, big_pool_h], axis=2)
pooled_h = jnp.take(combined_h, gather_h, axis=2)

small_pool_w = lax.reduce_window(
pooled_h, 0.0, lax.add, (1, 1, 1, small_w, 1), (1, 1, 1, 1, 1), "valid"
)
small_pool_w = small_pool_w / small_w

big_pool_w = lax.reduce_window(
pooled_h, 0.0, lax.add, (1, 1, 1, big_w, 1), (1, 1, 1, 1, 1), "valid"
)
big_pool_w = big_pool_w / big_w

combined_w = jnp.concatenate([small_pool_w, big_pool_w], axis=3)
pooled_w = jnp.take(combined_w, gather_w, axis=3)

if data_format == "channels_first":
pooled_w = jnp.transpose(pooled_w, (0, 4, 1, 2, 3)) # NDHWC -> NCDHW

return pooled_w


def adaptive_max_pool3d(inputs, output_size, data_format="channels_first"):
"""Adaptive Max Pooling 3D using Two-Pool Gather method."""
if isinstance(output_size, int):
output_size = (output_size, output_size, output_size)

if data_format == "channels_first":
inputs = jnp.transpose(inputs, (0, 2, 3, 4, 1)) # NCDHW -> NDHWC

n, d, h, w, c = inputs.shape
out_d, out_h, out_w = output_size

small_d, big_d = get_static_window_sizes(d, out_d)
gather_d = compute_static_gather_indices(d, out_d, big_d)

small_h, big_h = get_static_window_sizes(h, out_h)
gather_h = compute_static_gather_indices(h, out_h, big_h)

small_w, big_w = get_static_window_sizes(w, out_w)
gather_w = compute_static_gather_indices(w, out_w, big_w)

small_pool_d = lax.reduce_window(
inputs,
-jnp.inf,
lax.max,
(1, small_d, 1, 1, 1),
(1, 1, 1, 1, 1),
"valid",
)
big_pool_d = lax.reduce_window(
inputs, -jnp.inf, lax.max, (1, big_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid"
)

combined_d = jnp.concatenate([small_pool_d, big_pool_d], axis=1)
pooled_d = jnp.take(combined_d, gather_d, axis=1)

small_pool_h = lax.reduce_window(
pooled_d,
-jnp.inf,
lax.max,
(1, 1, small_h, 1, 1),
(1, 1, 1, 1, 1),
"valid",
)
big_pool_h = lax.reduce_window(
pooled_d,
-jnp.inf,
lax.max,
(1, 1, big_h, 1, 1),
(1, 1, 1, 1, 1),
"valid",
)

combined_h = jnp.concatenate([small_pool_h, big_pool_h], axis=2)
pooled_h = jnp.take(combined_h, gather_h, axis=2)

small_pool_w = lax.reduce_window(
pooled_h,
-jnp.inf,
lax.max,
(1, 1, 1, small_w, 1),
(1, 1, 1, 1, 1),
"valid",
)
big_pool_w = lax.reduce_window(
pooled_h,
-jnp.inf,
lax.max,
(1, 1, 1, big_w, 1),
(1, 1, 1, 1, 1),
"valid",
)

combined_w = jnp.concatenate([small_pool_w, big_pool_w], axis=3)
pooled_w = jnp.take(combined_w, gather_w, axis=3)

if data_format == "channels_first":
pooled_w = jnp.transpose(pooled_w, (0, 4, 1, 2, 3)) # NDHWC -> NCDHW

return pooled_w


# ---------- Dispatcher ----------
def adaptive_avg_pool(inputs, output_size, data_format="channels_first"):
"""Dispatcher for adaptive average pooling (1D, 2D, or 3D)."""
ndims = inputs.ndim - 2
if ndims == 1:
return adaptive_avg_pool1d(inputs, output_size, data_format)
elif ndims == 2:
return adaptive_avg_pool2d(inputs, output_size, data_format)
elif ndims == 3:
return adaptive_avg_pool3d(inputs, output_size, data_format)
else:
raise ValueError(
"adaptive_avg_pool supports 1D, 2D, or 3D inputs only."
)


def adaptive_max_pool(inputs, output_size, data_format="channels_first"):
"""Dispatcher for adaptive max pooling (1D, 2D, or 3D)."""
ndims = inputs.ndim - 2
if ndims == 1:
return adaptive_max_pool1d(inputs, output_size, data_format)
elif ndims == 2:
return adaptive_max_pool2d(inputs, output_size, data_format)
elif ndims == 3:
return adaptive_max_pool3d(inputs, output_size, data_format)
else:
raise ValueError(
"adaptive_max_pool supports 1D, 2D, or 3D inputs only."
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The implementations for 1D, 2D, and 3D adaptive pooling for both avg and max operations contain a significant amount of duplicated code. This makes the code harder to read and maintain.

Consider refactoring this by creating a generalized helper function. This function could handle the pooling logic for a single dimension and could be parameterized for average vs. max pooling.

For example, you could have a helper:
_adaptive_pool_1d_single_dim(inputs, axis, output_dim, reduce_fn, init_val, normalize=False)

Then, the 2D and 3D functions can be implemented by composing this helper function for each spatial dimension. This would greatly reduce code duplication and improve maintainability.

Comment on lines 138 to 140
raise NotImplementedError(
"Adaptive pooling not implemented for OpenVINO. "
"Use JAX or Torch backend."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The error message should be updated to include TensorFlow as a supported backend for adaptive pooling.

        "Adaptive pooling not implemented for OpenVINO. "
        "Use JAX, TensorFlow or Torch backend."

Comment on lines 158 to 160
raise NotImplementedError(
"Adaptive pooling not implemented for OpenVINO. "
"Use JAX or Torch backend."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The error message should be updated to include TensorFlow as a supported backend for adaptive pooling.

        "Adaptive pooling not implemented for OpenVINO. "
        "Use JAX, TensorFlow or Torch backend."

Comment on lines 74 to 79
raise TypeError(
"`output_size` must be an integer or tuple of 3 integers. "
"Received output_size={} of type {}".format(
output_size, type(output_size)
)
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For consistency with other new layer files in this PR, please use an f-string for this error message.

Suggested change
raise TypeError(
"`output_size` must be an integer or tuple of 3 integers. "
"Received output_size={} of type {}".format(
output_size, type(output_size)
)
)
raise TypeError(
"`output_size` must be an integer or tuple of 3 integers. "
f"Received: output_size={output_size} of type "
f"{type(output_size)}"
)

Comment on lines 53 to 57
if not isinstance(output_size, int):
raise TypeError(
"`output_size` must be an integer. Received output_size={} "
"of type {}".format(output_size, type(output_size))
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For consistency with other new layer files in this PR, please use an f-string for this error message.

            raise TypeError(
                f"`output_size` must be an integer. Received: output_size={output_size} "
                f"of type {type(output_size)}"
            )

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

Comment on lines 67 to 70
raise ValueError(
"`output_size` must be an integer or tuple of 3 integers. "
"Received: {}".format(output_size)
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For consistency, please use an f-string for this error message. Also, consider raising a TypeError instead of a ValueError here, as the check is on the length of the output_size tuple, which relates to its structure/type in this context.

Suggested change
raise ValueError(
"`output_size` must be an integer or tuple of 3 integers. "
"Received: {}".format(output_size)
)
raise ValueError(
"`output_size` must be an integer or tuple of 3 integers. "
f"Received: output_size={output_size}"
)

Comment on lines 73 to 76
raise TypeError(
"`output_size` must be an integer or tuple of 3 integers. "
"Received: {} of type {}".format(output_size, type(output_size))
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For consistency with other new layer files in this PR, please use an f-string for this error message.

Suggested change
raise TypeError(
"`output_size` must be an integer or tuple of 3 integers. "
"Received: {} of type {}".format(output_size, type(output_size))
)
raise TypeError(
"`output_size` must be an integer or tuple of 3 integers. "
f"Received: output_size={output_size} of type {type(output_size)}"
)

Comment on lines +1172 to +1215
"""Adaptive max pooling operation.
Applies an adaptive max pooling operation that automatically computes the
kernel size and stride to pool the input to the specified `output_size`.
This operation is useful when you want a fixed output size regardless of
input size, commonly used in models like ResNet for global feature
extraction.
Args:
inputs: Tensor of rank 4. Input tensor of shape:
- If `data_format="channels_last"`:
`(batch_size, height, width, channels)`.
- If `data_format="channels_first"`:
`(batch_size, channels, height, width)`.
output_size: Integer or tuple/list of 2 integers, specifying the target
output spatial dimensions `(output_height, output_width)`. If a
single
integer is provided, the same value is used for both dimensions.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The docstring for adaptive_max_pool is not entirely accurate. It states that the input is a 'Tensor of rank 4' and output_size is for 2D inputs. However, this function supports 1D, 2D, and 3D inputs (ranks 3, 4, and 5). Please update the docstring to reflect this, for example:

    Args:
        inputs: Tensor of rank 3, 4, or 5.
        output_size: Integer or tuple/list of 1, 2, or 3 integers, specifying
            the target output spatial dimensions. If a single integer is
            provided, the same value is used for all spatial dimensions.

Comment on lines 1319 to 1392
"""Adaptive average pooling operation.
Applies an adaptive average pooling operation that automatically
computes the
kernel size and stride to pool the input to the specified `output_size`.
This operation is useful when you want a fixed output size regardless of
input size, commonly used in models like ResNet for global feature
extraction.
Args:
inputs: Tensor of rank 4. Input tensor of shape:
- If `data_format="channels_last"`:
`(batch_size, height, width, channels)`.
- If `data_format="channels_first"`:
`(batch_size, channels, height, width)`.
output_size: Integer or tuple/list of 2 integers, specifying the target
output spatial dimensions `(output_height, output_width)`. If a
single
integer is provided, the same value is used for both dimensions.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The docstring for adaptive_avg_pool is not entirely accurate. It states that the input is a 'Tensor of rank 4' and output_size is for 2D inputs. However, this function supports 1D, 2D, and 3D inputs (ranks 3, 4, and 5). Please update the docstring to reflect this, for example:

    Args:
        inputs: Tensor of rank 3, 4, or 5.
        output_size: Integer or tuple/list of 1, 2, or 3 integers, specifying
            the target output spatial dimensions. If a single integer is
            provided, the same value is used for all spatial dimensions.

Copy link
Collaborator

@hertschuh hertschuh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for this PR! I can see that a lot of work went into this.

Please run ./shell/api_gen.sh to generate the __init__.py files.

Or more generally, use pre-commit hooks: https://github.com/keras-team/keras/blob/master/CONTRIBUTING.md#generating-public-api-and-formatting-the-code

Comment on lines 28 to 29
from keras.src.backend.jax.nn import adaptive_avg_pool
from keras.src.backend.jax.nn import adaptive_max_pool
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Revert this file.

Comment on lines 1469 to 1473
def get_static_window_sizes(input_dim, output_dim):
"""Calculate small and big window sizes for adaptive pooling."""
small_window = math.ceil(input_dim / output_dim)
big_window = small_window + 1
return small_window, big_window
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some version of this is done in each backend. Can we share the code between backends? It can go in keras/src/ops/operation_utils.py, you should name it something more specific like compute_adaptive_pooling_window_sizes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hertschuh
I initially moved this into ops/operation_utils, but that introduced a circular import since ops/__init__.py imports backend symbols at the top level. Would it be OK if I moved the utility into backend/common/backend_utils.py and import it from the backends instead?

return small_window, big_window


def compute_static_gather_indices(input_dim, output_size, big_window):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rename _compute_static_gather_indices to distinguish this from other ops.

return gather_indices.astype(jnp.int32)


# ---------- 1D Adaptive Pooling ----------
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove those # ---------- comments to stay consistent with the style (same for all the other ones)



# ---------- 1D Adaptive Pooling ----------
def adaptive_avg_pool1d(inputs, output_size, data_format="channels_first"):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rename _adaptive_avg_pool1d since it's not a real op, to make it obvious it is not directly an op implementation (same for the other ones in this file and the other ones in the other backends).

Comment on lines 1242 to 1255
def adaptive_max_pool(inputs, output_size, data_format=None):
"""Adaptive max pooling - Numpy backend not yet supported."""
raise NotImplementedError(
"Adaptive pooling not implemented for Numpy. "
"Use JAX, Torch or Tensorflow backend."
)


def adaptive_avg_pool(inputs, output_size, data_format=None):
"""Adaptive average pooling - Numpy backend not yet supported."""
raise NotImplementedError(
"Adaptive pooling not implemented for Numpy. "
"Use JAX, Torch or Tensorflow backend."
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any way we can have a NumPy implementation?

If not, can we plug the JAX implementation? (like we did for convolutions).



@keras_export("keras.layers.AdaptiveAveragePooling3D")
class AdaptiveAveragePooling3D(Layer):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be beneficial to have a super class BaseAdaptiveAveragePooling. Most of the code is shared between 1d/2d/3d. Look at convolutions to see the pattern.

But basically:

  • call would be in the super class only
  • get_config would be in the super class only
  • compute_output_shape could probably be generalized to be moved to the super class
  • __init__ would be both in the super class and sub-classes. The super class __init__ would take output_size and set it as a field and would take data_format, validate it and set it as a field. The sub class __init__ would validate output_size and normalize it to a tuple before passing it to super().__init__.

Same for AdaptiveMaxPooling and AdaptivePooling.

Comment on lines 74 to 79
raise TypeError(
"`output_size` must be an integer or tuple of 3 integers. "
"Received output_size={} of type {}".format(
output_size, type(output_size)
)
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 please fix

import numpy as np
import pytest

from keras.src import backend as K
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just do keras.src import backend and then use backend.foo.

We no longer use K.

Comment on lines 22 to 27
try:
import torch

TORCH_AVAILABLE = True
except ImportError:
TORCH_AVAILABLE = False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not great because JAX GPU/TPU tests and TF GPU tests don't install Torch on purpose, so the testing would only be done on CPU.

Can we instead hardcode some small input tensors and expected outputs?

@MalyalaKarthik66
Copy link
Contributor Author

@hertschuh
Thanks for the review! I’ll make all the suggested changes.

@MalyalaKarthik66
Copy link
Contributor Author

@hertschuh
I’ve updated the PR with all the requested changes. I introduced shared base classes for adaptive pooling and refactored the 1D/2D/3D layers accordingly, and also implemented the NumPy backend version.

I initially moved compute_adaptive_pooling_window_sizes into ops/operation_utils, but that created a circular import since ops/__init__ imports backend symbols. I’ve moved the utility into backend/common/backend_utils.py instead and import it from each backend.

Please let me know if anything else needs to be adjusted — happy to update further.

Copy link
Collaborator

@hertschuh hertschuh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! This is really close.

A couple comments that impact several files:

rename avg -> average (e.g. adaptive_average_pool). While there is a precedent for shortening "maximum" to "max", there isn't for "average" (e.g. the current average_pool op).

In each one of the nn.py files, can you reorder things like this for consistency:

  • existing max_pool op
  • existing average_pool op
  • _compute_adaptive_pooling_gather_indices
  • other private functions 1d / 2d / 3d
  • adaptive_average_pool
  • adaptive_max_pool
  • other ops

"""Adaptive max pooling - OpenVINO backend not yet supported."""
raise NotImplementedError(
"Adaptive pooling not implemented for OpenVINO. "
"Use JAX or Torch backend."
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove that part, just keep the part that says it's not implemented on OpenVino

"""Adaptive average pooling - OpenVINO backend not yet supported."""
raise NotImplementedError(
"Adaptive pooling not implemented for OpenVINO. "
"Use JAX or Torch backend."
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment.

return outputs


def compute_static_gather_indices(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rename _compute_static_gather_indices like the other ones.

Comment on lines 53 to 57
if not isinstance(output_size, int):
raise TypeError(
"`output_size` must be an integer. Received output_size={} "
"of type {}".format(output_size, type(output_size))
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

"""
if data_format is None:
data_format = config.image_data_format()
return backend.nn.adaptive_avg_pool(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You'll need a class called AdaptiveAveragePool. Look at AveragePool above for an example.

Then here, add:

 if any_symbolic_tensors((inputs,)):
        return AdaptiveAveragePool(output_size, data_format).symbolic_call(inputs)

"""
if data_format is None:
data_format = config.image_data_format()
return backend.nn.adaptive_max_pool(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You'll need a class called AdaptiveMaxPool. Look at AveragePool below for an example.

Then here, add:

 if any_symbolic_tensors((inputs,)):
        return AdaptiveMaxPool(output_size, data_format).symbolic_call(inputs)

Comment on lines 50 to 54
if not isinstance(output_size, int):
raise TypeError(
f"`output_size` must be an integer. "
f"Received: {output_size} of type {type(output_size)}"
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For consistency, also allow a tuple of size 1 and mention it in the error message.

Comment on lines 50 to 54
if not isinstance(output_size, int):
raise TypeError(
"`output_size` must be an integer. Received output_size={} "
"of type {}".format(output_size, type(output_size))
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For consistency, also allow a tuple of size 1 and mention it in the error message.

@MalyalaKarthik66
Copy link
Contributor Author

@hertschuh
Thanks! I've updated the PR with all the requested changes.

@hertschuh
Copy link
Collaborator

Great! This is ready to go except:

  • can you rebase (one file need merging)?
  • can you run the code formatter?

@MalyalaKarthik66
Copy link
Contributor Author

Great! This is ready to go except:

  • can you rebase (one file need merging)?
  • can you run the code formatter?

@hertschuh
Thanks! I’ve rebased onto the latest upstream master, resolved the OpenVINO conflict, and run the formatter. Please let me know if anything else is needed.

@hertschuh
Copy link
Collaborator

@MalyalaKarthik66

Can you run ./shell/api_gen.sh to reformat the code?

@MalyalaKarthik66
Copy link
Contributor Author

MalyalaKarthik66 commented Dec 14, 2025

@MalyalaKarthik66

Can you run ./shell/api_gen.sh to reformat the code?

@hertschuh
I set up Python 3.12 and installed Keras in editable mode, but ./shell/api_gen.sh keeps failing on Windows with ModuleNotFoundError: No module named "keras\src" and a WinError 32 lock on tmp_build_dir. Could you please run the API generator on Linux/macOS and commit the keras/api updates?

it keeps failing with the following error:

$ ./shell/api_gen.sh
Generating api directory with public APIs...
Traceback (most recent call last):
  File "C:\Users\karth\keras\api_gen.py", line 159, in build
    namex.generate_api_files(
  File "C:\Users\karth\keras\venv\Lib\site-packages\namex\generate.py", line 76, in generate_api_files
    mod = importlib.import_module(entry_point, package=".")
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\karth\AppData\Local\Programs\Python\Python312\Lib\importlib\__init__.py", line 90, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<frozen importlib._bootstrap>", line 1387, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1360, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1310, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 488, in _call_with_frames_removed
  File "<frozen importlib._bootstrap>", line 1387, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1360, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1324, in _find_and_load_unlocked
ModuleNotFoundError: No module named 'keras\\src'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "C:\Users\karth\keras\api_gen.py", line 187, in <module>
    build()
  File "C:\Users\karth\keras\api_gen.py", line 183, in build
    shutil.rmtree(build_dir)
  File "C:\Users\karth\AppData\Local\Programs\Python\Python312\Lib\shutil.py", line 781, in rmtree
    return _rmtree_unsafe(path, onexc)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\karth\AppData\Local\Programs\Python\Python312\Lib\shutil.py", line 639, in _rmtree_unsafe
    onexc(os.rmdir, path, err)
  File "C:\Users\karth\AppData\Local\Programs\Python\Python312\Lib\shutil.py", line 637, in _rmtree_unsafe
    os.rmdir(path)
PermissionError: [WinError 32] The process cannot access the file because it is being used by another process: 'C:\\Users\\karth\\keras\\tmp_build_dir'

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Feature Request] Add AdaptivePooling - Avg/Max

4 participants