|
17 | 17 | TODO(b/188399175): Use the public ExtensionType API instead. |
18 | 18 | """ |
19 | 19 |
|
| 20 | +import os |
| 21 | + |
| 22 | +## |
| 23 | +## Part 1: TensorFlow symbols |
| 24 | +## |
| 25 | + |
20 | 26 | # The following imports work in all supported versions of TF. |
21 | 27 | # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top,g-bad-import-order |
22 | 28 | from tensorflow.python.framework import composite_tensor |
|
32 | 38 | except ImportError: |
33 | 39 | type_spec_registry = None # Not available before TF 2.12. |
34 | 40 |
|
35 | | -# NOTE: See ../__init__.py for an up-front check of supported Keras versions. |
36 | | - |
37 | | -try: |
38 | | - try: |
39 | | - # Get Keras v2 from the separate tf_keras package. |
40 | | - # In OSS, it exists for TF2.14+. It may become required for TF2.16+. |
41 | | - from tf_keras.src.engine import keras_tensor # pytype: disable=import-error |
42 | | - from tf_keras.src.layers import core as core_layers # pytype: disable=import-error |
43 | | - import tf_keras.src.backend as keras_backend # pytype: disable=import-error |
44 | | - except ImportError: |
45 | | - # Get Keras v2 from the keras package. |
46 | | - # In OSS, this is possible for TF2.15 and older. |
47 | | - import keras # pytype: disable=import-error |
48 | | - if not keras.__version__.startswith('2.'): |
49 | | - raise ImportError( |
50 | | - 'tensorflow_gnn requires tf_keras to be installed or keras version <' |
51 | | - f' 3. Instead got keras version {keras.__version__}.' |
52 | | - ) from None # A Keras version mismatch is different to lacking tf_keras. |
53 | | - import keras # pytype: disable=import-error |
54 | | - if hasattr(keras, 'src'): # As of TF/Keras 2.13. |
55 | | - from keras.src.engine import keras_tensor # pytype: disable=import-error |
56 | | - from keras.src.layers import core as core_layers # pytype: disable=import-error |
57 | | - import keras.src.backend as keras_backend # pytype: disable=import-error |
58 | | - else: |
59 | | - from keras.engine import keras_tensor # pytype: disable=import-error |
60 | | - from keras.layers import core as core_layers # pytype: disable=import-error |
61 | | - import keras.backend as keras_backend # pytype: disable=import-error |
62 | | -except ImportError: |
63 | | - # Internal |
64 | | - keras_tensor = tf._keras_internal.engine.keras_tensor # pylint: disable=protected-access |
65 | | - core_layers = tf._keras_internal.layers.core # pylint: disable=protected-access |
66 | | - keras_backend = tf._keras_internal.backend # pylint: disable=protected-access |
67 | | - |
68 | 41 | CompositeTensor = composite_tensor.CompositeTensor |
69 | 42 | BatchableTypeSpec = type_spec.BatchableTypeSpec |
70 | 43 | type_spec_register = ( |
|
79 | 52 | type_spec_registry.lookup if type_spec_registry is not None |
80 | 53 | else type_spec.lookup) |
81 | 54 |
|
82 | | -try: |
83 | | - # These types are semi-public as of TF/Keras 2.13. |
84 | | - # Whenever possible, get them the official way. |
| 55 | +OpDispatcher = tf.__internal__.dispatch.OpDispatcher |
| 56 | + |
| 57 | + |
| 58 | +## |
| 59 | +## Part 2: Keras symbols, compatible with `tf.keras.*` |
| 60 | +## |
| 61 | + |
| 62 | +# pytype: disable=import-error |
| 63 | + |
| 64 | +if tf.__version__.startswith("2.12."): |
| 65 | + # tf.keras is keras 2.12, which does not yet have the `src` subdirectory. |
| 66 | + from keras import backend as keras_backend |
| 67 | + from keras.engine import keras_tensor as kt |
| 68 | + from keras.layers import core as core_layers |
| 69 | + # In 2.12, these symbols are not exposed yet under tf.keras.__internal__. |
| 70 | + KerasTensor = kt.KerasTensor |
| 71 | + RaggedKerasTensor = kt.RaggedKerasTensor |
| 72 | + |
| 73 | +elif tf.__version__.startswith("2.13.") or tf.__version__.startswith("2.14."): |
85 | 74 | KerasTensor = tf.keras.__internal__.KerasTensor |
86 | 75 | RaggedKerasTensor = tf.keras.__internal__.RaggedKerasTensor |
87 | | -except AttributeError: |
88 | | - KerasTensor = keras_tensor.KerasTensor |
89 | | - RaggedKerasTensor = keras_tensor.RaggedKerasTensor |
90 | | -# These KerasTensor helpers are still private in TF/Keras 2.13. |
91 | | -register_keras_tensor_specialization = ( |
92 | | - keras_tensor.register_keras_tensor_specialization) |
93 | | -delegate_property = core_layers._delegate_property # pylint: disable=protected-access |
94 | | -delegate_method = core_layers._delegate_method # pylint: disable=protected-access |
| 76 | + # tf.keras is keras. |
| 77 | + # For TF 2.14, there also exists a tf_keras package, but TF does not use it. |
| 78 | + from keras.src import backend as keras_backend |
| 79 | + from keras.src.engine import keras_tensor as kt |
| 80 | + from keras.src.layers import core as core_layers |
95 | 81 |
|
96 | | -OpDispatcher = tf.__internal__.dispatch.OpDispatcher |
| 82 | +elif tf.__version__.startswith("2.15."): |
| 83 | + KerasTensor = tf.keras.__internal__.KerasTensor |
| 84 | + RaggedKerasTensor = tf.keras.__internal__.RaggedKerasTensor |
| 85 | + # OSS TensorFlow 2.15 can choose between keras 2.15 and tf_keras 2.15 |
| 86 | + # BUT THESE ARE DIFFERENT PACKAGES WITH SEPARATE GLOBAL REGISTRIES |
| 87 | + # so it is essential that we pick the right one by replicating the logic from |
| 88 | + # https://github.com/tensorflow/tensorflow/blob/r2.15/tensorflow/python/util/lazy_loader.py#L96 |
| 89 | + if os.environ.get("TF_USE_LEGACY_KERAS", None) in ("true", "True", "1"): |
| 90 | + from tf_keras.src import backend as keras_backend |
| 91 | + from tf_keras.src.layers import core as core_layers |
| 92 | + from tf_keras.src.engine import keras_tensor as kt |
| 93 | + else: |
| 94 | + from keras.src import backend as keras_backend |
| 95 | + from keras.src.layers import core as core_layers |
| 96 | + from keras.src.engine import keras_tensor as kt |
97 | 97 |
|
| 98 | +elif hasattr(tf, "_keras_internal"): # Special case: internal. |
| 99 | + KerasTensor = tf.keras.__internal__.KerasTensor |
| 100 | + RaggedKerasTensor = tf.keras.__internal__.RaggedKerasTensor |
| 101 | + kt = tf._keras_internal.engine.keras_tensor # pylint: disable=protected-access |
| 102 | + core_layers = tf._keras_internal.layers.core # pylint: disable=protected-access |
| 103 | + keras_backend = tf._keras_internal.backend # pylint: disable=protected-access |
| 104 | + |
| 105 | +else: # TF2.16 and onwards. |
| 106 | + # ../__init__.py has already checked that tf.keras has version 2, not 3, |
| 107 | + # which implies that tf.keras is tf_keras, and we do not second-guess |
| 108 | + # the selection logic. |
| 109 | + KerasTensor = tf.keras.__internal__.KerasTensor |
| 110 | + RaggedKerasTensor = tf.keras.__internal__.RaggedKerasTensor |
| 111 | + from tf_keras.src import backend as keras_backend |
| 112 | + from tf_keras.src.layers import core as core_layers |
| 113 | + from tf_keras.src.engine import keras_tensor as kt |
| 114 | + |
| 115 | +# pytype: enable=import-error |
| 116 | + |
| 117 | +register_keras_tensor_specialization = kt.register_keras_tensor_specialization |
| 118 | +delegate_property = core_layers._delegate_property # pylint: disable=protected-access |
| 119 | +delegate_method = core_layers._delegate_method # pylint: disable=protected-access |
98 | 120 | unique_keras_object_name = keras_backend.unique_object_name |
99 | 121 |
|
100 | 122 | # Delete imports, in their order above. |
101 | 123 | del composite_tensor |
102 | 124 | del type_spec |
103 | 125 | del tf |
104 | 126 | del type_spec_registry |
105 | | -del keras_tensor |
| 127 | +del keras_backend |
106 | 128 | del core_layers |
| 129 | +del kt |
0 commit comments