Fix Discretization layer graph mode bug #21514

Status: Open. Wants to merge 12 commits into base: master.
File changed: keras/src/layers/preprocessing/discretization.py (13 changes: 9 additions & 4 deletions)
@@ -96,8 +96,7 @@ def __init__(
         name=None,
     ):
         if dtype is None:
-            dtype = "int64" if output_mode == "int" else backend.floatx()
-
+            dtype = "float32"
Collaborator:

I think you should just remove the whole `if dtype is None` block; the base layer class will handle `None` properly.

Collaborator (Author):

I tried that! It results in the unit tests failing:

FAILED keras/src/layers/preprocessing/discretization_test.py::DiscretizationTest::test_discretization_basics - AssertionError: expected output dtype int64, got int32:
- int32
+ int64

Collaborator:

Well, it fails the same way with or without the change on JAX.

The thing is, this class has this:

    @property
    def input_dtype(self):
        return backend.floatx()

So I don't understand why the inputs were cast to ints.

Maybe you can try to override @property ... compute_dtype?
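Not part of the PR: a standalone toy sketch (plain NumPy, hypothetical names, no Keras) of what the suggested `compute_dtype` override could look like, keeping internal math in float while the output dtype follows `output_mode`:

```python
import numpy as np

class ToyDiscretization:
    """Hypothetical stand-in for the Keras layer, for illustration only."""

    def __init__(self, bin_boundaries, output_mode="int"):
        self.bin_boundaries = bin_boundaries
        self.output_mode = output_mode

    @property
    def compute_dtype(self):
        # Bin boundaries are floats, so compute internally in float32.
        return "float32"

    @property
    def output_dtype(self):
        # "int" mode yields integer bin indices; other modes keep floats.
        return "int64" if self.output_mode == "int" else self.compute_dtype

    def __call__(self, inputs):
        x = np.asarray(inputs, dtype=self.compute_dtype)
        indices = np.digitize(x, self.bin_boundaries)
        return indices.astype(self.output_dtype)

layer = ToyDiscretization([0.0, 1.0, 2.0])
out = layer(np.array([-0.5, 0.5, 1.5, 2.5]))
print(out, out.dtype)  # [0 1 2 3] int64
```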

         super().__init__(name=name, dtype=dtype)

         if sparse and not backend.SUPPORTS_SPARSE_TENSORS:
@@ -213,7 +212,10 @@ def reset_state(self):
         self.summary = np.array([[], []], dtype="float32")

     def compute_output_spec(self, inputs):
-        return backend.KerasTensor(shape=inputs.shape, dtype=self.compute_dtype)
+        output_dtype = (
+            "int64" if self.output_mode == "int" else self.compute_dtype
+        )
+        return backend.KerasTensor(shape=inputs.shape, dtype=output_dtype)

     def load_own_variables(self, store):
         if len(store) == 1:
@@ -230,11 +232,14 @@ def call(self, inputs):
         )

         indices = self.backend.numpy.digitize(inputs, self.bin_boundaries)
+        output_dtype = (
+            "int64" if self.output_mode == "int" else self.compute_dtype
+        )
         return numerical_utils.encode_categorical_inputs(
             indices,
             output_mode=self.output_mode,
             depth=len(self.bin_boundaries) + 1,
-            dtype=self.compute_dtype,
+            dtype=output_dtype,
             sparse=self.sparse,
             backend_module=self.backend,
         )
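For context (not part of the diff): the same dtype-selection conditional now appears in both `compute_output_spec` and `call`, so the symbolic (graph-mode) output spec and the eager output agree. A standalone sketch, with `select_output_dtype` as a hypothetical helper name:

```python
import numpy as np

def select_output_dtype(output_mode, compute_dtype="float32"):
    # Mirrors the conditional added in the diff: "int" mode forces int64
    # bin indices; encoded modes (one_hot, multi_hot, count) keep the
    # layer's compute dtype.
    return "int64" if output_mode == "int" else compute_dtype

# np.digitize always produces integer indices; the selected dtype
# controls the final cast of the layer's output.
indices = np.digitize([0.5, 1.5], [0.0, 1.0])
print(select_output_dtype("int"))      # int64
print(select_output_dtype("one_hot"))  # float32
```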