Fix Discretization layer graph mode bug #21514
Conversation
Summary of Changes
Hello @divyashreepathihalli, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a critical bug in the `Discretization` preprocessing layer that affected its behavior in graph execution environments. The changes ensure that the layer's internal and output data types are correctly handled, leading to consistent and accurate results regardless of whether the model is run in eager or graph mode. A new dedicated test case has been added to validate this fix.
Highlights
- Fix Discretization Layer Graph Mode: This PR addresses a bug in the `Discretization` layer that caused incorrect behavior or dtype issues when used within TensorFlow's graph mode (e.g., when building a Keras `Model` or using `model.predict`).
- Standardized Internal Dtype: The internal dtype for the `Discretization` layer's computations is now explicitly set to `float32` in the constructor. This ensures consistency and compatibility with backend numerical operations like `digitize`.
- Corrected Output Specification Dtype: The `compute_output_spec` method has been updated to correctly determine the output tensor's dtype. If `output_mode` is `"int"`, the output dtype will be `int64`; otherwise, it will use the layer's `compute_dtype` (which is now `float32`). This is crucial for proper graph compilation and type inference (see the sketch after this list).
- Added Graph Mode Test Coverage: A new test case, `test_discretization_eager_vs_graph`, has been introduced. This test specifically verifies the `Discretization` layer's output consistency when run in eager mode, via a Keras `Model` call, and through `model.predict`, ensuring the fix correctly addresses graph mode discrepancies.
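Taken together, the constructor and `compute_output_spec` changes described above might look roughly like the sketch below. This is a simplified, illustrative stand-in, not the actual PR diff; only the dtype handling is shown, and the class/argument names merely mirror the real layer.

```python
from keras import KerasTensor, layers


class DiscretizationSketch(layers.Layer):
    """Illustrative stand-in for the real Discretization layer (dtype logic only)."""

    def __init__(self, bin_boundaries, output_mode="int", dtype=None, **kwargs):
        if dtype is None:
            # Internal computations now run in float32 so backend ops like
            # digitize receive float inputs, whatever output_mode is.
            dtype = "float32"
        super().__init__(dtype=dtype, **kwargs)
        self.bin_boundaries = bin_boundaries
        self.output_mode = output_mode

    def compute_output_spec(self, inputs):
        # Symbolic outputs advertise int64 for "int" mode; otherwise they
        # keep the layer's compute_dtype (now float32).
        output_dtype = "int64" if self.output_mode == "int" else self.compute_dtype
        return KerasTensor(shape=inputs.shape, dtype=output_dtype)
```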
Code Review
This pull request aims to fix a graph mode bug in the `Discretization` layer. The changes correctly adjust the layer's `compute_dtype` to handle float inputs and fix the output specification for symbolic tensors in graph mode.
My review identifies a critical issue where the eager execution path remains inconsistent with the graph mode behavior, which could lead to unexpected dtype outputs. I've provided a suggestion to align them.
Additionally, I've recommended improving the new test case by adding assertions to make it self-verifying, which is a standard best practice for unit tests.
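Along the lines of that recommendation, a self-verifying version of the new test might look like the sketch below. The bin boundaries and input values are made up for illustration, and the assertion style may differ from what the PR ultimately uses.

```python
import numpy as np
import keras
from keras import layers


def test_discretization_eager_vs_graph():
    # Illustrative data and bin boundaries; the PR's actual test may differ.
    x = np.array([[-1.5, 1.0, 3.4, 0.5], [0.0, 3.0, 1.3, 0.0]])
    layer = layers.Discretization(bin_boundaries=[0.0, 1.0, 2.0])

    # Eager call.
    eager_out = layer(x)

    # Graph-built model, exercised via a direct call and via predict().
    inputs = keras.Input(shape=(4,))
    model = keras.Model(inputs, layer(inputs))
    call_out = model(x)
    predict_out = model.predict(x, verbose=0)

    # The test only verifies the fix if eager and graph outputs are compared.
    np.testing.assert_array_equal(np.asarray(eager_out), np.asarray(call_out))
    np.testing.assert_array_equal(np.asarray(eager_out), np.asarray(predict_out))
```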
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@            Coverage Diff             @@
##           master   #21514      +/-   ##
==========================================
- Coverage   82.72%   77.86%   -4.86%
==========================================
  Files         567      567
  Lines       56245    56247       +2
  Branches     8790     8790
==========================================
- Hits        46527    43798    -2729
- Misses       7561    10374    +2813
+ Partials     2157     2075      -82
```
Flags with carried forward coverage won't be shown. ☔ View full report in Codecov by Sentry.
Thanks for the PR -- please take a look at the test failure.
```diff
@@ -96,8 +96,7 @@ def __init__(
         name=None,
     ):
         if dtype is None:
-            dtype = "int64" if output_mode == "int" else backend.floatx()
-
+            dtype = "float32"
```
I think you should just remove the whole `if dtype is None` block; the base layer class will handle `None` properly.
I tried that! It results in the unit tests failing:

```
FAILED keras/src/layers/preprocessing/discretization_test.py::DiscretizationTest::test_discretization_basics - AssertionError: expected output dtype int64, got int32:
- int32
+ int64
```
Well, it fails the same way with or without the change on JAX.

The thing is, this class has this:

```python
@property
def input_dtype(self):
    return backend.floatx()
```

So I don't understand why the inputs were cast to ints.

Maybe you can try to override `@property ... compute_dtype`?
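If it helps, a rough sketch of that suggestion is below. This is only the reviewer's idea, not what the PR does; the subclass name is made up, and whether the rest of the layer tolerates the override is an assumption.

```python
from keras import backend, layers


class FloatComputeDiscretization(layers.Discretization):
    # Sketch of the suggestion above: force the bucketing math to run in
    # floats even when the layer's dtype (and therefore its output) is int64.
    @property
    def compute_dtype(self):
        return backend.floatx()
```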
PyTorch GPU error is unrelated.
Fixes #21468

Root cause: in the `__init__` method, the line `dtype = "int64" if output_mode == "int" else backend.floatx()` sets `compute_dtype` to `"int64"` for the `"int"` output_mode, which causes the model graph to cast float inputs to int64 before they reach the Discretization layer.
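For illustration, here is a small numpy sketch of why that cast matters. The values and boundaries are made up, and numpy's `digitize` boundary semantics may not exactly match the Keras backend op; the point is only that truncating floats to ints changes bucket assignment.

```python
import numpy as np

bin_boundaries = [0.5, 1.0, 2.0]
x = np.array([0.1, 0.8, 1.5, 2.7], dtype="float32")

# Bucketizing the float values directly.
float_bins = np.digitize(x, bin_boundaries)                # [0 1 2 3]

# Bucketizing after an int64 cast: the fractions are truncated,
# so 0.8 becomes 0 and lands in the wrong bucket.
int_bins = np.digitize(x.astype("int64"), bin_boundaries)  # [0 0 2 3]

print(float_bins, int_bins)
```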