Skip to content

Commit 7c52d6f

Browse files
lukebaumann authored and copybara-github committed
Adding split_by_mesh_axis to experimental for use by a new experimental reshard.
This is only available for JAX 0.7.2 and above. PiperOrigin-RevId: 810549332
1 parent cfcc50d commit 7c52d6f

File tree

2 files changed

+223
-0
lines changed

2 files changed

+223
-0
lines changed
Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Experimental split by mesh axis API."""
15+
16+
from typing import Any, Sequence
17+
18+
import jax
19+
from pathwaysutils import jax as pw_jax
20+
from pathwaysutils import lru_cache
21+
22+
23+
@lru_cache.lru_cache(maxsize=16384)
def _cached_named_sharding(
    mesh: jax.sharding.Mesh,
    spec: jax.sharding.PartitionSpec,
    memory_kind: str | None = None,
):
  """Builds a `NamedSharding` from `mesh`, `spec` and `memory_kind`.

  The function is wrapped by `lru_cache.lru_cache`, so calls are routed
  through that cache (up to 16384 entries) keyed on the arguments.

  Args:
    mesh: Mesh the sharding is defined over.
    spec: Partition spec describing how array dimensions map to mesh axes.
    memory_kind: Optional memory kind forwarded to `NamedSharding`.

  Returns:
    The `jax.sharding.NamedSharding` constructed from the arguments.
  """
  named_sharding = jax.sharding.NamedSharding(
      mesh, spec, memory_kind=memory_kind
  )
  return named_sharding
30+
31+
32+
@lru_cache.lru_cache(maxsize=1024)
def _get_per_mesh_shardings(
    meshes: tuple[jax.sharding.Mesh, ...],
    spec: jax.sharding.PartitionSpec,
    memory_kind: str | None = None,
) -> Sequence[jax.sharding.NamedSharding]:
  """Returns one `NamedSharding` per mesh, all sharing `spec` and `memory_kind`.

  Args:
    meshes: Meshes to build shardings for; the output follows this order.
    spec: Partition spec applied on every mesh.
    memory_kind: Optional memory kind applied on every mesh.

  Returns:
    A list with the sharding for `meshes[i]` at position `i`.
  """
  per_mesh_shardings = []
  for submesh in meshes:
    # Each individual sharding goes through its own cache as well.
    per_mesh_shardings.append(
        _cached_named_sharding(submesh, spec, memory_kind=memory_kind)
    )
  return per_mesh_shardings
43+
44+
45+
def split_by_mesh_axis(
46+
arrays: Any,
47+
mesh_axis: str,
48+
mesh_axis_indices_or_sections: int | Sequence[int] | None = None,
49+
*,
50+
donate: bool = False,
51+
) -> Sequence[Any]:
52+
"""Splits arrays by a mesh axis, and returns arrays on each split mesh.
53+
54+
Args:
55+
arrays: PyTree of JAX arrays with NamedSharding whose mesh is identical.
56+
mesh_axis: Mesh axis to split the arrays by.
57+
mesh_axis_indices_or_sections: If it is an integer, N, the mesh axis will be
58+
divided into N equal submeshes along `mesh_axis`. If it is a 1-D sequence,
59+
the entries indicate the boundary on the mesh axis along `mesh_axis`. For
60+
example, [2, 3] for splitting first mesh axis results in three output
61+
arrays (per each input array) on `mesh[:2], mesh[2:3], mesh[3:]`,
62+
respectively. If it is None, it will be the same as `N =
63+
mesh.axis_size[mesh.axis_names.index(mesh_axis)]`. Note: the sequence must
64+
be monotonoically increasing and should not contain the start or end
65+
boundaries.
66+
donate: Whether to donate input arrays. By default, input arrays are
67+
aliased.
68+
69+
Returns:
70+
A sequence of PyTrees whose structure is the same as `arrays`.
71+
Each element `i` has arrays with their shards filtered out to match
72+
mesh corresponding mesh constructed according to
73+
`mesh_axis_indices_or_sections`. An array's shape remains the same if the
74+
array is replicated along `mesh_axis`, or is shrunk by a split factor
75+
computed from `mesh_axis_indices_or_sections` if the array is partitioned
76+
along `mesh_axis`.
77+
"""
78+
flat_arrays, treedef = jax.tree.flatten(arrays)
79+
80+
if not flat_arrays:
81+
return arrays
82+
83+
sharding = flat_arrays[0].sharding
84+
if not isinstance(sharding, jax.sharding.NamedSharding):
85+
raise ValueError(f"Array must have a NamedSharding. Got {sharding=}")
86+
mesh = sharding.mesh
87+
mesh_axis_idx = mesh.axis_names.index(mesh_axis)
88+
sharded_dim_idxs = []
89+
for array in flat_arrays:
90+
sharding = array.sharding
91+
if not isinstance(sharding, jax.sharding.NamedSharding):
92+
raise ValueError(f"Array must have a NamedSharding. Got {sharding=}")
93+
if mesh != sharding.mesh:
94+
raise ValueError(
95+
f"Array sharding mesh must match, but got {mesh=}, {sharding.mesh=}"
96+
)
97+
if sharding._logical_device_ids is not None: # pylint: disable=protected-access
98+
raise ValueError(
99+
"Array sharding's _logical_device_ids must be None, but got"
100+
f" {sharding._logical_device_ids=}" # pylint: disable=protected-access
101+
)
102+
sharded_dim = -1
103+
for dim_idx, dim_spec in enumerate(sharding.spec):
104+
flat_dim_spec, _ = jax.tree.flatten(dim_spec)
105+
if mesh_axis in flat_dim_spec:
106+
sharded_dim = dim_idx
107+
break
108+
sharded_dim_idxs.append(sharded_dim)
109+
110+
# Transform mesh_axis_indices_or_sections into a list of axis boundaries,
111+
# with the last entry being the size of the mesh_axis.
112+
if mesh_axis_indices_or_sections is None:
113+
# If mesh_axis_indices_or_sections is None, the arrays will be divided
114+
# along the mesh_axis.
115+
mesh_axis_indices_or_sections = mesh.axis_sizes[mesh_axis_idx]
116+
if isinstance(mesh_axis_indices_or_sections, int):
117+
# Expand the mesh_axis_indices_or_sections to a list indicating the
118+
# boundaries of mesh axis.
119+
if mesh.axis_sizes[mesh_axis_idx] % mesh_axis_indices_or_sections != 0:
120+
raise ValueError(
121+
"The size of the `mesh_axis` must be divisible by"
122+
" `mesh_axis_indices_or_sections`. Got"
123+
f" {mesh.axis_sizes[mesh_axis_idx]} and"
124+
f" {mesh_axis_indices_or_sections=}"
125+
)
126+
axis_size = mesh.axis_sizes[mesh_axis_idx] // mesh_axis_indices_or_sections
127+
mesh_axis_sections = list(
128+
range(axis_size, mesh.axis_sizes[mesh_axis_idx] + 1, axis_size)
129+
)
130+
else:
131+
mesh_axis_sections = mesh_axis_indices_or_sections
132+
for i, boundary in enumerate(mesh_axis_sections):
133+
if boundary <= 0 or boundary >= mesh.axis_sizes[mesh_axis_idx]:
134+
raise ValueError(
135+
"Mesh axis sections values must be in range (0,"
136+
f" axis_size={mesh.axis_sizes[mesh_axis_idx]}) to avoid an empty"
137+
f" section, but got {mesh_axis_sections=}."
138+
)
139+
if i > 0 and mesh_axis_sections[i] <= mesh_axis_sections[i - 1]:
140+
raise ValueError(
141+
"Mesh axis sections must be monotonically increasing, but got"
142+
f" {mesh_axis_sections=}."
143+
)
144+
mesh_axis_sections += [mesh.axis_sizes[mesh_axis_idx]]
145+
146+
submeshes = []
147+
axis_boundary_start = 0
148+
slices = [slice(None)] * len(mesh.axis_sizes)
149+
for axis_boundary_end in mesh_axis_sections:
150+
slices[mesh_axis_idx] = slice(axis_boundary_start, axis_boundary_end)
151+
submeshes.append(
152+
jax.sharding.Mesh(mesh.devices[tuple(slices)], mesh.axis_names)
153+
)
154+
axis_boundary_start = axis_boundary_end
155+
156+
submeshes_tuple = tuple(submeshes)
157+
submesh_shardings = [
158+
_get_per_mesh_shardings(
159+
submeshes_tuple, x.sharding.spec, x.sharding.memory_kind
160+
)
161+
for x in flat_arrays
162+
]
163+
164+
flat_split_arrays = pw_jax.jaxlib_pathways._split_by_mesh_axis( # pylint: disable=protected-access
165+
arrays=flat_arrays,
166+
sharded_dim_idxs=sharded_dim_idxs,
167+
mesh_axis_sizes=mesh.axis_sizes,
168+
mesh_axis_idx=mesh_axis_idx,
169+
mesh_axis_sections=mesh_axis_sections,
170+
submesh_shardings=submesh_shardings,
171+
donate=donate,
172+
)
173+
174+
# Convert the flat arrays to a list of a PyTree per submesh.
175+
outer_treedef = jax.tree.structure(["*"] * len(flat_split_arrays))
176+
inner_treedef = jax.tree.structure(["*"] * len(submeshes))
177+
return [
178+
jax.tree.unflatten(treedef, flat_submesh_arrays)
179+
for flat_submesh_arrays in jax.tree.transpose(
180+
outer_treedef, inner_treedef, flat_split_arrays
181+
)
182+
]

pathwaysutils/jax/__init__.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,32 @@
1818
"""
1919

2020
from typing import Any
21+
import jax
22+
23+
24+
class _FakeJaxModule:
  """Stand-in for a JAX module that the installed JAX does not provide.

  Any attribute lookup that is not satisfied by the instance or class, and any
  call of the instance, raises an `ImportError` explaining which JAX version
  introduced the module and which version is currently installed.
  """

  def __init__(self, name, version):
    # The message is built once up front; it embeds the JAX version that is
    # installed at construction time.
    message = (
        f"Module {name} does not exist until JAX {version}. "
        f"The current version of JAX is {jax.__version__}. "
        "Using this modules results in this runtime error."
    )
    self.__name__ = name
    self.version = version
    self.error_message = message

  def __getattr__(self, name):
    # Python only calls this hook after regular attribute lookup has failed,
    # so `__name__`, `version` and `error_message` stay readable.
    raise ImportError(self.error_message)

  def __call__(self, *args, **kwargs):
    raise ImportError(self.error_message)
46+
2147

2248
try:
2349
# jax>=0.7.0
@@ -47,3 +73,18 @@ def register_backend_cache(cache: Any, name: str, util=util): # pylint: disable
4773

4874
ifrt_proxy = xla_extension.ifrt_proxy
4975
del xla_extension
76+
77+
78+
try:
  # jax>=0.7.2
  from jax.jaxlib import _pathways  # pylint: disable=g-import-not-at-top

  jaxlib_pathways = _pathways
  del _pathways
except (ImportError, AttributeError):
  # jax<0.7.2
  #
  # `ImportError` (the base class of `ModuleNotFoundError`) is caught because
  # `from pkg import name` raises a plain `ImportError` ("cannot import name")
  # when `pkg` is importable but does not provide `name`. Catching only
  # `ModuleNotFoundError` would let that case abort the import of this package
  # instead of installing the placeholder below.
  jaxlib_pathways = _FakeJaxModule("jax.jaxlib._pathways", "0.7.2")


# The placeholder class is an implementation detail; drop it from the module
# namespace once the compatibility aliases have been set up.
del _FakeJaxModule

0 commit comments

Comments
 (0)