Skip to content

Commit c3b6096

Browse files
lukebaumanncopybara-github
authored andcommitted
Introduce jax shim to support multiple versions of JAX
PiperOrigin-RevId: 789952304
1 parent 240224c commit c3b6096

File tree

4 files changed

+56
-8
lines changed

4 files changed

+56
-8
lines changed

pathwaysutils/jax/__init__.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Pathways JAX abstractions.
15+
16+
This introduces an abstrction layer some JAX APIs that have changed over
17+
`pathwaysutils`'s compatibility window.
18+
"""
19+
20+
from typing import Any
21+
22+
try:
23+
# jax>=0.7.0
24+
from jax.extend import backend # pylint: disable=g-import-not-at-top
25+
26+
register_backend_cache = backend.register_backend_cache
27+
28+
del backend
29+
except AttributeError:
30+
# jax<0.7.0
31+
from jax._src import util # pylint: disable=g-import-not-at-top
32+
33+
def register_backend_cache(cache: Any, name: str): # pylint: disable=unused-argument
34+
return util.cache_clearing_funs.add(cache.cache_clear)
35+
36+
del util
37+
38+
try:
39+
# jax>0.7.0
40+
from jax.extend import backend # pylint: disable=g-import-not-at-top
41+
42+
ifrt_proxy = backend.ifrt_proxy
43+
del backend
44+
except AttributeError:
45+
# jax<=0.7.0
46+
from jax.lib import xla_extension # pylint: disable=g-import-not-at-top
47+
48+
ifrt_proxy = xla_extension.ifrt_proxy
49+
del xla_extension

pathwaysutils/lru_cache.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import functools
1717
from typing import Any, Callable
1818

19-
from jax._src import util
19+
from pathwaysutils import jax as pw_jax
2020

2121

2222
def lru_cache(
@@ -38,7 +38,7 @@ def wrap(f):
3838

3939
wrapper.cache_clear = cached.cache_clear
4040
wrapper.cache_info = cached.cache_info
41-
util.cache_clearing_funs.add(wrapper.cache_clear)
41+
pw_jax.register_backend_cache(wrapper, "Pathways LRU cache")
4242
return wrapper
4343

4444
return wrap

pathwaysutils/proxy_backend.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,15 @@
1515

1616
import jax
1717
from jax.extend import backend
18-
from jax.lib.xla_extension import ifrt_proxy
18+
from pathwaysutils import jax as pw_jax
1919

2020

2121
def register_backend_factory():
2222
backend.register_backend_factory(
2323
"proxy",
24-
lambda: ifrt_proxy.get_client(
24+
lambda: pw_jax.ifrt_proxy.get_client(
2525
jax.config.read("jax_backend_target"),
26-
ifrt_proxy.ClientConnectionOptions(),
26+
pw_jax.ifrt_proxy.ClientConnectionOptions(),
2727
),
2828
priority=-1,
2929
)

pathwaysutils/test/proxy_backend_test.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,9 @@
1717

1818
import jax
1919
from jax.extend import backend
20-
from jax.lib.xla_extension import ifrt_proxy
20+
from pathwaysutils import jax as pw_jax
2121
from pathwaysutils import proxy_backend
2222

23-
2423
from absl.testing import absltest
2524

2625

@@ -39,7 +38,7 @@ def test_no_proxy_backend_registration_raises_error(self):
3938
def test_proxy_backend_registration(self):
4039
self.enter_context(
4140
mock.patch.object(
42-
ifrt_proxy,
41+
pw_jax.ifrt_proxy,
4342
"get_client",
4443
return_value=mock.MagicMock(),
4544
)

0 commit comments

Comments
 (0)