Skip to content

Commit 022fc47

Browse files
committed
Let TF-GNN choose between keras or tf_keras consistently with TF 2.15:
both provide Keras 2.15, but it matters which one is used, because they have separate class hierarchies and global registries. Along the way, refactor the nested case distinctions of tf_internal.py into a clear list of supported older TF/Keras versions. PiperOrigin-RevId: 604619976
1 parent b7a9027 commit 022fc47

File tree

1 file changed

+69
-46
lines changed

1 file changed

+69
-46
lines changed

tensorflow_gnn/graph/tf_internal.py

Lines changed: 69 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,12 @@
1717
TODO(b/188399175): Use the public ExtensionType API instead.
1818
"""
1919

20+
import os
21+
22+
##
23+
## Part 1: TensorFlow symbols
24+
##
25+
2026
# The following imports work in all supported versions of TF.
2127
# pylint: disable=g-direct-tensorflow-import,g-import-not-at-top,g-bad-import-order
2228
from tensorflow.python.framework import composite_tensor
@@ -32,39 +38,6 @@
3238
except ImportError:
3339
type_spec_registry = None # Not available before TF 2.12.
3440

35-
# NOTE: See ../__init__.py for an up-front check of supported Keras versions.
36-
37-
try:
38-
try:
39-
# Get Keras v2 from the separate tf_keras package.
40-
# In OSS, it exists for TF2.14+. It may become required for TF2.16+.
41-
from tf_keras.src.engine import keras_tensor # pytype: disable=import-error
42-
from tf_keras.src.layers import core as core_layers # pytype: disable=import-error
43-
import tf_keras.src.backend as keras_backend # pytype: disable=import-error
44-
except ImportError:
45-
# Get Keras v2 from the keras package.
46-
# In OSS, this is possible for TF2.15 and older.
47-
import keras # pytype: disable=import-error
48-
if not keras.__version__.startswith('2.'):
49-
raise ImportError(
50-
'tensorflow_gnn requires tf_keras to be installed or keras version <'
51-
f' 3. Instead got keras version {keras.__version__}.'
52-
) from None # A Keras version mismatch is different to lacking tf_keras.
53-
import keras # pytype: disable=import-error
54-
if hasattr(keras, 'src'): # As of TF/Keras 2.13.
55-
from keras.src.engine import keras_tensor # pytype: disable=import-error
56-
from keras.src.layers import core as core_layers # pytype: disable=import-error
57-
import keras.src.backend as keras_backend # pytype: disable=import-error
58-
else:
59-
from keras.engine import keras_tensor # pytype: disable=import-error
60-
from keras.layers import core as core_layers # pytype: disable=import-error
61-
import keras.backend as keras_backend # pytype: disable=import-error
62-
except ImportError:
63-
# Internal
64-
keras_tensor = tf._keras_internal.engine.keras_tensor # pylint: disable=protected-access
65-
core_layers = tf._keras_internal.layers.core # pylint: disable=protected-access
66-
keras_backend = tf._keras_internal.backend # pylint: disable=protected-access
67-
6841
CompositeTensor = composite_tensor.CompositeTensor
6942
BatchableTypeSpec = type_spec.BatchableTypeSpec
7043
type_spec_register = (
@@ -79,28 +52,78 @@
7952
type_spec_registry.lookup if type_spec_registry is not None
8053
else type_spec.lookup)
8154

82-
try:
83-
# These types are semi-public as of TF/Keras 2.13.
84-
# Whenever possible, get them the official way.
55+
OpDispatcher = tf.__internal__.dispatch.OpDispatcher
56+
57+
58+
##
59+
## Part 2: Keras symbols, compatible with `tf.keras.*`
60+
##
61+
62+
# pytype: disable=import-error
63+
64+
if tf.__version__.startswith("2.12."):
65+
# tf.keras is keras 2.12, which does not yet have the `src` subdirectory.
66+
from keras import backend as keras_backend
67+
from keras.engine import keras_tensor as kt
68+
from keras.layers import core as core_layers
69+
# In 2.12, these symbols are not exposed yet under tf.keras.__internal__.
70+
KerasTensor = kt.KerasTensor
71+
RaggedKerasTensor = kt.RaggedKerasTensor
72+
73+
elif tf.__version__.startswith("2.13.") or tf.__version__.startswith("2.14."):
8574
KerasTensor = tf.keras.__internal__.KerasTensor
8675
RaggedKerasTensor = tf.keras.__internal__.RaggedKerasTensor
87-
except AttributeError:
88-
KerasTensor = keras_tensor.KerasTensor
89-
RaggedKerasTensor = keras_tensor.RaggedKerasTensor
90-
# These KerasTensor helpers are still private in TF/Keras 2.13.
91-
register_keras_tensor_specialization = (
92-
keras_tensor.register_keras_tensor_specialization)
93-
delegate_property = core_layers._delegate_property # pylint: disable=protected-access
94-
delegate_method = core_layers._delegate_method # pylint: disable=protected-access
76+
# tf.keras is keras.
77+
# For TF 2.14, there also exists a tf_keras package, but TF does not use it.
78+
from keras.src import backend as keras_backend
79+
from keras.src.engine import keras_tensor as kt
80+
from keras.src.layers import core as core_layers
9581

96-
OpDispatcher = tf.__internal__.dispatch.OpDispatcher
82+
elif tf.__version__.startswith("2.15."):
83+
KerasTensor = tf.keras.__internal__.KerasTensor
84+
RaggedKerasTensor = tf.keras.__internal__.RaggedKerasTensor
85+
# OSS TensorFlow 2.15 can choose between keras 2.15 and tf_keras 2.15
86+
# BUT THESE ARE DIFFERENT PACKAGES WITH SEPARATE GLOBAL REGISTRIES
87+
# so it is essential that we pick the right one by replicating the logic from
88+
# https://github.com/tensorflow/tensorflow/blob/r2.15/tensorflow/python/util/lazy_loader.py#L96
89+
if os.environ.get("TF_USE_LEGACY_KERAS", None) in ("true", "True", "1"):
90+
from tf_keras.src import backend as keras_backend
91+
from tf_keras.src.layers import core as core_layers
92+
from tf_keras.src.engine import keras_tensor as kt
93+
else:
94+
from keras.src import backend as keras_backend
95+
from keras.src.layers import core as core_layers
96+
from keras.src.engine import keras_tensor as kt
9797

98+
elif hasattr(tf, "_keras_internal"): # Special case: internal.
99+
KerasTensor = tf.keras.__internal__.KerasTensor
100+
RaggedKerasTensor = tf.keras.__internal__.RaggedKerasTensor
101+
kt = tf._keras_internal.engine.keras_tensor # pylint: disable=protected-access
102+
core_layers = tf._keras_internal.layers.core # pylint: disable=protected-access
103+
keras_backend = tf._keras_internal.backend # pylint: disable=protected-access
104+
105+
else: # TF2.16 and onwards.
106+
# ../__init__.py has already checked that tf.keras has version 2, not 3,
107+
# which implies that tf.keras is tf_keras, and we do not second-guess
108+
# the selection logic.
109+
KerasTensor = tf.keras.__internal__.KerasTensor
110+
RaggedKerasTensor = tf.keras.__internal__.RaggedKerasTensor
111+
from tf_keras.src import backend as keras_backend
112+
from tf_keras.src.layers import core as core_layers
113+
from tf_keras.src.engine import keras_tensor as kt
114+
115+
# pytype: enable=import-error
116+
117+
register_keras_tensor_specialization = kt.register_keras_tensor_specialization
118+
delegate_property = core_layers._delegate_property # pylint: disable=protected-access
119+
delegate_method = core_layers._delegate_method # pylint: disable=protected-access
98120
unique_keras_object_name = keras_backend.unique_object_name
99121

100122
# Delete imports, in their order above.
101123
del composite_tensor
102124
del type_spec
103125
del tf
104126
del type_spec_registry
105-
del keras_tensor
127+
del keras_backend
106128
del core_layers
129+
del kt

0 commit comments

Comments
 (0)