
Fix Discretization layer graph mode bug #21514


Open · wants to merge 12 commits into master

Conversation

@divyashreepathihalli (Collaborator) commented Jul 25, 2025

fixes #21468

Root cause: in the __init__ method, the line

    dtype = "int64" if output_mode == "int" else backend.floatx()

sets compute_dtype="int64" for the "int" output_mode, which causes the model graph to cast float inputs to int64 before they reach the Discretization layer.
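
For context, a minimal sketch of the inconsistency this causes (the bin boundaries and input values below are illustrative assumptions, not taken from the original report):

    import numpy as np
    import keras

    # Illustrative repro sketch: before this fix, the functional graph
    # casts the float inputs to the layer's compute_dtype (int64), so
    # fractional values are truncated before bucketization.
    layer = keras.layers.Discretization(bin_boundaries=[0.0, 0.5, 1.0])
    x = np.array([[-0.1, 0.3, 0.7, 1.2]])

    eager_out = layer(x)  # eager path: floats reach the layer intact

    inputs = keras.Input(shape=(4,))
    model = keras.Model(inputs, layer(inputs))
    graph_out = model.predict(x)  # graph path: disagreed with eager_out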

@gemini-code-assist (bot, Contributor) left a comment

Summary of Changes

This pull request resolves a critical bug in the Discretization preprocessing layer that affected its behavior in graph execution environments. The changes ensure that the layer's internal and output data types are correctly handled, leading to consistent and accurate results regardless of whether the model is run in eager or graph mode. A new dedicated test case has been added to validate this fix.

Highlights

  • Fix Discretization Layer Graph Mode: This PR addresses a bug in the Discretization layer that caused incorrect behavior or dtype issues when used within TensorFlow's graph mode (e.g., when building a Keras Model or using model.predict).
  • Standardized Internal Dtype: The internal dtype for the Discretization layer's computations is now explicitly set to float32 in the constructor. This ensures consistency and compatibility with backend numerical operations like digitize.
  • Corrected Output Specification Dtype: The compute_output_spec method has been updated to correctly determine the output tensor's dtype. If output_mode is 'int', the output dtype will be int64; otherwise, it will use the layer's compute_dtype (which is now float32). This is crucial for proper graph compilation and type inference (see the sketch after this list).
  • Added Graph Mode Test Coverage: A new test case, test_discretization_eager_vs_graph, has been introduced. This test specifically verifies the Discretization layer's output consistency when run in eager mode, via a Keras Model call, and through model.predict, ensuring the fix correctly addresses graph mode discrepancies.
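
A sketch of the dtype selection described in the two highlights above (hypothetical body, assuming Keras 3's public KerasTensor; the PR's actual compute_output_spec may differ, and the shape changes for the one_hot/multi_hot/count modes are omitted):

    from keras import KerasTensor

    def compute_output_spec(self, inputs):
        # "int" mode yields integer bucket indices; other modes stay in
        # the layer's (now float32) compute dtype.
        output_dtype = (
            "int64" if self.output_mode == "int" else self.compute_dtype
        )
        return KerasTensor(shape=inputs.shape, dtype=output_dtype)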

@gemini-code-assist (bot) left a comment

Code Review

This pull request aims to fix a graph mode bug in the Discretization layer. The changes correctly adjust the layer's compute_dtype to handle float inputs and fix the output specification for symbolic tensors in graph mode.

My review identifies a critical issue where the eager execution path remains inconsistent with the graph mode behavior, which could lead to unexpected dtype outputs. I've provided a suggestion to align them.

Additionally, I've recommended improving the new test case by adding assertions to make it self-verifying, which is a standard best practice for unit tests.
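
A sketch of what a self-verifying version of that test could look like (the boundaries and values are illustrative, not copied from the PR):

    import numpy as np
    import keras

    def test_discretization_eager_vs_graph():
        layer = keras.layers.Discretization(bin_boundaries=[0.0, 0.5, 1.0])
        x = np.array([[-0.1, 0.3, 0.7, 1.2]])

        eager_out = np.asarray(layer(x))

        inputs = keras.Input(shape=(4,))
        model = keras.Model(inputs, layer(inputs))
        call_out = np.asarray(model(x))  # symbolic model, called eagerly
        predict_out = model.predict(x)   # graph execution

        # Self-verifying: all three execution paths must agree.
        np.testing.assert_array_equal(eager_out, call_out)
        np.testing.assert_array_equal(eager_out, predict_out)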

@codecov-commenter commented Jul 25, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 77.86%. Comparing base (8bf6a58) to head (949bf6d).

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21514      +/-   ##
==========================================
- Coverage   82.72%   77.86%   -4.86%     
==========================================
  Files         567      567              
  Lines       56245    56247       +2     
  Branches     8790     8790              
==========================================
- Hits        46527    43798    -2729     
- Misses       7561    10374    +2813     
+ Partials     2157     2075      -82     
Flag              Coverage Δ
keras             77.70% <100.00%> (-4.83%) ⬇️
keras-jax         ?
keras-numpy       58.42% <100.00%> (+<0.01%) ⬆️
keras-openvino    34.57% <0.00%> (-0.01%) ⬇️
keras-tensorflow  64.35% <100.00%> (+<0.01%) ⬆️
keras-torch       63.98% <100.00%> (+<0.01%) ⬆️

Flags with carried forward coverage won't be shown.


@fchollet (Collaborator) commented:

Thanks for the PR -- please take a look at the test failure:

FAILED keras/src/layers/preprocessing/discretization_test.py::DiscretizationTest::test_discretization_basics - AssertionError: expected output dtype int64, got float16

@@ -96,8 +96,7 @@ def __init__(
         name=None,
     ):
-        if dtype is None:
-            dtype = "int64" if output_mode == "int" else backend.floatx()
-
+        dtype = "float32"
Collaborator:

I think you should just remove the whole if dtype is None block, the base layer class will handle None properly.

@divyashreepathihalli (Author):

I tried that! It results in the unit tests failing:

FAILED keras/src/layers/preprocessing/discretization_test.py::DiscretizationTest::test_discretization_basics - AssertionError: expected output dtype int64, got int32:
- int32
+ int64

Collaborator:

Well, it fails the same way with or without that change on JAX.

The thing is, this class has this:

    @property
    def input_dtype(self):
        return backend.floatx()

So I don't understand why the inputs were cast to ints.

Maybe you can try to override @property ... compute_dtype?
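
A minimal sketch of that suggestion (the subclass name is hypothetical, and whether this is the right fix is exactly what the thread is debating):

    from keras import backend
    from keras.layers import Discretization

    class PatchedDiscretization(Discretization):
        # Hypothetical override: keep computation in floatx so float
        # inputs are not cast to int64 before bucketization; the output
        # dtype for output_mode="int" is handled in compute_output_spec.
        @property
        def compute_dtype(self):
            return backend.floatx()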

@divyashreepathihalli (Collaborator, Author):

The PyTorch GPU error is unrelated.


Successfully merging this pull request may close these issues.

#21468: Discretization layer returns inconsistent result in graph mode.