Issues with JAX 0.9.x and other latest libraries #8655
Description
I was trying to convert modules from the latest Flax/NNX v0.12.6 to tfjs and ran into several issues. The report below was written with Claude. One additional note from me: ydf==0.16.1 is available now, which may or may not offer a way to resolve the issue with tensorflow_decision_forests.
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow.js): yes
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): macOS 26.3, M3 Air
- TensorFlow.js installed from (npm or script link): 0fc04d9
- TensorFlow.js version (use command below): 4.22.0 0fc04d9
- Browser version: N/A
- Tensorflow.js Converter Version: 0fc04d9
Describe the current behavior
The Python tensorflowjs converter cannot be installed or used on Python 3.13 with current versions of its dependencies (NumPy 2.x, JAX 0.9.x, TensorFlow 2.20). There are four distinct issues:
1. **`setup.py` missing modules:** `tf_module_mapper` and `normalize_bias_add` are imported at runtime but not listed in `py_modules` in `setup.py`, so they are missing from built wheels. This causes `ImportError: cannot import name 'tf_module_mapper'` on import.
2. **Hard imports of optional dependencies block import on Python 3.13:** `tf_saved_model_conversion_v2.py` has a top-level `import tensorflow_decision_forests` (line 28) and `import tensorflow_hub as hub` (line 54). `tensorflow_decision_forests` has no Python 3.13 wheels, and `tensorflow_hub` is incompatible with TensorFlow 2.20+. Since these are only needed for specific conversion paths (TFDF saved models and TF Hub modules), they should be guarded with `try`/`except`.
3. **`requirements.txt` lists unavailable dependencies:** `tensorflow-decision-forests>=1.9.0` and `tensorflow-hub>=0.16.1` are listed as hard requirements, but neither is installable on Python 3.13. These should be moved to optional/extra dependencies.
4. **`convert_jax` broken with JAX >= 0.9:** `jax_conversion.py` passes `enable_xla=False` to `jax2tf.convert` (line 122), but JAX 0.9.x has deprecated and removed both the `enable_xla` and `native_serialization` parameters. JAX now always uses native serialization, which produces `XlaCallModule` ops that TF.js does not support. The only supported calling convention versions are 9 and 10, both of which emit `XlaCallModule`. This makes `convert_jax` completely non-functional with current JAX.
Describe the expected behavior
The package should install and convert_jax should work on Python 3.13 with current dependency versions.
Standalone code to reproduce the issue
```
uv venv --python 3.13
source .venv/bin/activate
```

To get past issues 1-3, apply the following patches before installing:
In `tfjs/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py`, replace `import tensorflow_decision_forests` (line 28) with:

```python
try:
  import tensorflow_decision_forests
except ImportError:
  pass
```

And replace `import tensorflow_hub as hub` (line 54) with:

```python
try:
  import tensorflow_hub as hub
except (ImportError, AttributeError):
  hub = None
```

In `tfjs/tfjs-converter/python/setup.py`, add the missing modules to `py_modules`:

```python
'tensorflowjs.converters.normalize_bias_add',
'tensorflowjs.converters.tf_module_mapper',
```

In `tfjs/tfjs-converter/python/requirements.txt`, remove `tensorflow-decision-forests>=1.9.0` and `tensorflow-hub>=0.16.1`.
Then install:

```
uv pip install tensorflow tf-keras h5py jax jaxlib flax numpy six packaging
uv pip install --no-build-isolation /path/to/tfjs/tfjs-converter/python
```

Then run:
```python
import tensorflow as tf
import tensorflowjs as tfjs
import numpy as np

def linear(params, x):
  return x @ params['w'] + params['b']

tfjs.converters.convert_jax(
    apply_fn=linear,
    params={'w': np.random.randn(4, 3).astype(np.float32),
            'b': np.zeros(3, dtype=np.float32)},
    input_signatures=[tf.TensorSpec([1, 4], tf.float32)],
    model_dir='/tmp/tfjs_test',
)
```

This produces:

```
ValueError: Unsupported Ops in the model before optimization
XlaCallModule
```
Other info / logs
Environment:

```
Python      3.13.12
numpy       2.4.3
tensorflow  2.20.0
jax         0.9.2
jaxlib      0.9.2
flax        0.12.6
```
Full traceback for issue 4 (XlaCallModule):

```
File "demo.py", line 8, in <module>
  tfjs.converters.convert_jax(...)
File "tensorflowjs/converters/jax_conversion.py", line 148, in convert_jax
  saved_model_conversion.convert_tf_saved_model(saved_model_dir, model_dir, **tfjs_converter_params)
File "tensorflowjs/converters/tf_saved_model_conversion_v2.py", line 988, in convert_tf_saved_model
  _convert_tf_saved_model(...)
File "tensorflowjs/converters/tf_saved_model_conversion_v2.py", line 863, in _convert_tf_saved_model
  optimized_graph = optimize_graph(...)
File "tensorflowjs/converters/tf_saved_model_conversion_v2.py", line 159, in optimize_graph
  raise ValueError('Unsupported Ops in the model before optimization\n' + ', '.join(unsupported))
ValueError: Unsupported Ops in the model before optimization
XlaCallModule
```
JAX 0.9.x removed the non-XLA lowering path entirely. jax2tf.convert now ignores enable_xla=False (deprecated) and always produces XlaCallModule ops via native serialization. The jax_export_calling_convention_version config only accepts versions 9-10, both of which use native serialization. There is no configuration in current JAX that avoids XlaCallModule ops.
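Given that, the converter could at least fail fast with a clear message instead of surfacing the `XlaCallModule` error deep inside graph optimization. A minimal sketch of such a guard; the helper name is hypothetical and the 0.9 cutoff is taken from this report, not from the actual converter:

```python
def supports_enable_xla(jax_version: str) -> bool:
    """True if this JAX version still honors jax2tf's enable_xla=False.

    Per this report, JAX 0.9 removed the non-native lowering path, so any
    0.9+ release always emits XlaCallModule ops via native serialization.
    """
    major, minor = (int(p) for p in jax_version.split(".")[:2])
    return (major, minor) < (0, 9)

print(supports_enable_xla("0.4.30"))  # True
print(supports_enable_xla("0.9.2"))   # False
```

Comparing `(major, minor)` tuples rather than the raw strings keeps `"0.10"` correctly ordered after `"0.9"`.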