Skip to content

Commit 6dda227

Browse files
lukebaumann authored and
copybara-github committed
Update pathways.experimental.reshard so that PyTrees with arrays that have different device sets can be resharded.
We group arrays in the pytree by device sets and issue a reshard call for each group. This is needed because plugin executable arguments must be on the same set of devices. PiperOrigin-RevId: 811039183
1 parent 7c52d6f commit 6dda227

File tree

1 file changed

+50
-35
lines changed

1 file changed

+50
-35
lines changed

pathwaysutils/experimental/reshard.py

Lines changed: 50 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
"""Experimental resharding API for elastic device sets."""
1515

1616
import base64
17+
import collections
18+
import itertools
1719
import json
1820
from typing import Any, Dict, Sequence
1921

@@ -92,9 +94,7 @@ def _get_resharding_plan(
9294
donate: bool,
9395
) -> ReshardingPlanWrapper:
9496
"""Returns a resharding plan for the given sharding task."""
95-
return ReshardingPlanWrapper(
96-
avals, old_shardings, new_shardings, donate
97-
)
97+
return ReshardingPlanWrapper(avals, old_shardings, new_shardings, donate)
9898

9999

100100
_get_resharding_plan_cached = lru_cache.lru_cache()(_get_resharding_plan)
@@ -116,12 +116,11 @@ def reshard(
116116
(must be a tree prefix of `x`), representing the device(s) and sharding to
117117
which `x` should be sharded to. The result will be committed to the
118118
device(s) of the sharding.
119-
donate: If `True`, donate all input arrays, which may reduce the
120-
amount memory needed for resharding. Buffers donated to resharding should
121-
not be reused.
122-
may_alias: If `True`, may alias the input array with the output array.
123-
May reduce the amount of memory needed for resharding. Not used at the
124-
moment.
119+
donate: If `True`, donate all input arrays, which may reduce the amount of
120+
memory needed for resharding. Buffers donated to resharding should not be
121+
reused.
122+
may_alias: If `True`, may alias the input array with the output array. May
123+
reduce the amount of memory needed for resharding. Not used at the moment.
125124
cache_resharding_plans: If `True`, uses a resharding plan cache to avoid
126125
recreating plans for the same resharding operation. May improve
127126
performance for use cases where the same resharding operation is done many
@@ -137,43 +136,59 @@ def reshard(
137136
"reshard sharding", tree_def, sharding
138137
)
139138

140-
jax_arrays = []
141-
jax_array_dst_shardings = []
142-
non_jax_arrays = []
143-
non_jax_array_dst_shardings = []
144-
for arr, dst_sharding in zip(flat_x, flat_sharding):
139+
# We must split the arrays into two groups:
140+
# 1. jax.Array
141+
# 2. non jax.Array
142+
# For jax.Array, we will use the ifrt client to get the resharding plan and
143+
# execute it.
144+
# These arrays must be further split into groups based on the device set of
145+
# the sharding, since plugin programs only supports execution on the same
146+
# device set.
147+
# For non jax.Array, we will use jax.device_put to put the array to the
148+
# destination devices.
149+
jax_arrays = collections.defaultdict(
150+
lambda: {"arrays": [], "indices": [], "dst_shardings": []}
151+
)
152+
non_jax_arrays = {"arrays": [], "indices": [], "dst_shardings": []}
153+
for index, (arr, dst_sharding) in enumerate(zip(flat_x, flat_sharding)):
145154
if not isinstance(dst_sharding, jax.sharding.Sharding):
146155
raise ValueError("`sharding` must contain only `jax.sharding.Sharding`")
147156
if isinstance(arr, jax.Array):
148-
jax_arrays.append(arr)
149-
jax_array_dst_shardings.append(dst_sharding)
157+
key = frozenset(arr.sharding.device_set)
158+
jax_arrays[key]["arrays"].append(arr)
159+
jax_arrays[key]["indices"].append(index)
160+
jax_arrays[key]["dst_shardings"].append(dst_sharding)
150161
else:
151-
non_jax_arrays.append(arr)
152-
non_jax_array_dst_shardings.append(dst_sharding)
153-
154-
if non_jax_arrays:
155-
non_jax_arrays = jax.device_put(non_jax_arrays, non_jax_array_dst_shardings)
162+
non_jax_arrays["arrays"].append(arr)
163+
non_jax_arrays["indices"].append(index)
164+
non_jax_arrays["dst_shardings"].append(dst_sharding)
165+
166+
if non_jax_arrays["arrays"]:
167+
non_jax_arrays["arrays"] = jax.device_put(
168+
non_jax_arrays["arrays"],
169+
non_jax_arrays["dst_shardings"],
170+
donate=donate,
171+
may_alias=may_alias,
172+
)
156173

157-
if jax_arrays:
174+
for array_info in jax_arrays.values():
158175
get_resharding_plan_func = (
159176
_get_resharding_plan_cached
160177
if cache_resharding_plans
161178
else _get_resharding_plan
162179
)
163-
jax_arrays = get_resharding_plan_func(
164-
tuple(arr.aval for arr in jax_arrays),
165-
tuple(arr.sharding for arr in jax_arrays),
166-
tuple(jax_array_dst_shardings),
180+
array_info["arrays"] = get_resharding_plan_func(
181+
tuple(arr.aval for arr in array_info["arrays"]),
182+
tuple(arr.sharding for arr in array_info["arrays"]),
183+
tuple(array_info["dst_shardings"]),
167184
donate,
168-
).execute(tuple(jax_arrays))
185+
).execute(tuple(array_info["arrays"]))
169186

170-
result = []
171-
jax_iter = iter(jax_arrays)
172-
non_jax_iter = iter(non_jax_arrays)
187+
result = [None] * len(flat_x)
188+
for arr, idx in zip(non_jax_arrays["arrays"], non_jax_arrays["indices"]):
189+
result[idx] = arr
190+
for array_info in jax_arrays.values():
191+
for arr, idx in zip(array_info["arrays"], array_info["indices"]):
192+
result[idx] = arr
173193

174-
for arr in flat_x:
175-
if isinstance(arr, jax.Array):
176-
result.append(next(jax_iter))
177-
else:
178-
result.append(next(non_jax_iter))
179194
return jax.tree.unflatten(tree_def, result)

0 commit comments

Comments
 (0)