
Issues with JAX 0.9.x and other latest libraries #8655

@DBraun

Description

I was trying to convert modules from the latest Flax/NNX v0.12.6 to tfjs and encountered some issues. The report below was written with Claude. My other input is that ydf==0.16.1 is now available, which may or may not offer a way to resolve the issue with tensorflow_decision_forests.

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow.js): yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): macOS 26.3, M3 Air
  • TensorFlow.js installed from (npm or script link): 0fc04d9
  • TensorFlow.js version (use command below): 4.22.0 0fc04d9
  • Browser version: N/A
  • Tensorflow.js Converter Version: 0fc04d9

Describe the current behavior

The Python tensorflowjs converter cannot be installed or used on Python 3.13 with current versions of its dependencies (NumPy 2.x, JAX 0.9.x, TensorFlow 2.20). There are four distinct issues:

  1. setup.py missing modules: tf_module_mapper and normalize_bias_add are imported at runtime but not listed in py_modules in setup.py, so they are missing from built wheels. This causes ImportError: cannot import name 'tf_module_mapper' on import.

  2. Hard imports of optional dependencies block import on Python 3.13: tf_saved_model_conversion_v2.py has top-level import tensorflow_decision_forests (line 28) and import tensorflow_hub as hub (line 54). tensorflow_decision_forests has no Python 3.13 wheels, and tensorflow_hub is incompatible with TensorFlow 2.20+. Since these are only needed for specific conversion paths (TFDF saved models and TF Hub modules), they should be guarded with try/except.

  3. requirements.txt lists unavailable dependencies: tensorflow-decision-forests>=1.9.0 and tensorflow-hub>=0.16.1 are listed as hard requirements, but neither is installable on Python 3.13. These should be moved to optional/extra dependencies.

  4. convert_jax broken with JAX >= 0.9: jax_conversion.py passes enable_xla=False to jax2tf.convert (line 122), but JAX 0.9.x has deprecated and removed both the enable_xla and native_serialization parameters. JAX now always uses native serialization, which produces XlaCallModule ops that TF.js does not support. The only supported calling convention versions are 9 and 10, both of which emit XlaCallModule. This makes convert_jax completely non-functional with current JAX.

Describe the expected behavior

The package should install and convert_jax should work on Python 3.13 with current dependency versions.

Standalone code to reproduce the issue

uv venv --python 3.13
source .venv/bin/activate

To get past issues 1-3, apply the following patches before installing:

In tfjs/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py, replace import tensorflow_decision_forests (line 28) with:

try:
  import tensorflow_decision_forests
except ImportError:
  pass

And replace import tensorflow_hub as hub (line 54) with:

try:
  import tensorflow_hub as hub
except (ImportError, AttributeError):
  hub = None
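With hub possibly None after a failed import, the TF Hub conversion path should fail with an actionable error rather than an AttributeError at some later call site. A minimal sketch (the guard function name is illustrative, not the converter's actual API):

```python
try:
    import tensorflow_hub as hub
except (ImportError, AttributeError):
    hub = None

def require_hub():
    """Raise a clear error only when a TF Hub conversion path is used."""
    if hub is None:
        raise ImportError(
            'tensorflow_hub is required to convert TF Hub modules; '
            'install it with: pip install tensorflow-hub')
    return hub
```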

In tfjs/tfjs-converter/python/setup.py, add the missing modules to py_modules:

'tensorflowjs.converters.normalize_bias_add',
'tensorflowjs.converters.tf_module_mapper',

In tfjs/tfjs-converter/python/requirements.txt, remove tensorflow-decision-forests>=1.9.0 and tensorflow-hub>=0.16.1.
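One way to keep these installable as opt-ins is an extras_require stanza in setup.py; the extra names below are illustrative, not existing tensorflowjs extras:

```python
# Hypothetical setup.py fragment: the two hard dependencies become
# optional extras, e.g. `pip install tensorflowjs[tfdf]` or `[hub]`.
extras_require = {
    'tfdf': ['tensorflow-decision-forests>=1.9.0'],
    'hub': ['tensorflow-hub>=0.16.1'],
}
```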

Then install:

uv pip install tensorflow tf-keras h5py jax jaxlib flax numpy six packaging
uv pip install --no-build-isolation /path/to/tfjs/tfjs-converter/python

Then run:

import tensorflow as tf
import tensorflowjs as tfjs
import numpy as np

def linear(params, x):
    return x @ params['w'] + params['b']

tfjs.converters.convert_jax(
    apply_fn=linear,
    params={'w': np.random.randn(4, 3).astype(np.float32),
            'b': np.zeros(3, dtype=np.float32)},
    input_signatures=[tf.TensorSpec([1, 4], tf.float32)],
    model_dir='/tmp/tfjs_test',
)

This produces:

ValueError: Unsupported Ops in the model before optimization
XlaCallModule

Other info / logs

Environment:

Python 3.13.12
numpy 2.4.3
tensorflow 2.20.0
jax 0.9.2
jaxlib 0.9.2
flax 0.12.6

Full traceback for issue 4 (XlaCallModule):

  File "demo.py", line 8, in <module>
    tfjs.converters.convert_jax(...)
  File "tensorflowjs/converters/jax_conversion.py", line 148, in convert_jax
    saved_model_conversion.convert_tf_saved_model(saved_model_dir, model_dir, **tfjs_converter_params)
  File "tensorflowjs/converters/tf_saved_model_conversion_v2.py", line 988, in convert_tf_saved_model
    _convert_tf_saved_model(...)
  File "tensorflowjs/converters/tf_saved_model_conversion_v2.py", line 863, in _convert_tf_saved_model
    optimized_graph = optimize_graph(...)
  File "tensorflowjs/converters/tf_saved_model_conversion_v2.py", line 159, in optimize_graph
    raise ValueError('Unsupported Ops in the model before optimization\n' + ', '.join(unsupported))
ValueError: Unsupported Ops in the model before optimization
XlaCallModule

JAX 0.9.x removed the non-XLA lowering path entirely. jax2tf.convert now ignores enable_xla=False (deprecated) and always produces XlaCallModule ops via native serialization. The jax_export_calling_convention_version config only accepts versions 9-10, both of which use native serialization. There is no configuration in current JAX that avoids XlaCallModule ops.
