14
14
"""Experimental resharding API for elastic device sets."""
15
15
16
16
import base64
17
+ import collections
18
+ import itertools
17
19
import json
18
20
from typing import Any , Dict , Sequence
19
21
@@ -92,9 +94,7 @@ def _get_resharding_plan(
92
94
donate : bool ,
93
95
) -> ReshardingPlanWrapper :
94
96
"""Returns a resharding plan for the given sharding task."""
95
- return ReshardingPlanWrapper (
96
- avals , old_shardings , new_shardings , donate
97
- )
97
+ return ReshardingPlanWrapper (avals , old_shardings , new_shardings , donate )
98
98
99
99
100
100
_get_resharding_plan_cached = lru_cache .lru_cache ()(_get_resharding_plan )
@@ -116,12 +116,11 @@ def reshard(
116
116
(must be a tree prefix of `x`), representing the device(s) and sharding to
117
117
which `x` should be sharded. The result will be committed to the
118
118
device(s) of the sharding.
119
- donate: If `True`, donate all input arrays, which may reduce the
120
- amount memory needed for resharding. Buffers donated to resharding should
121
- not be reused.
122
- may_alias: If `True`, may alias the input array with the output array.
123
- May reduce the amount of memory needed for resharding. Not used at the
124
- moment.
119
+ donate: If `True`, donate all input arrays, which may reduce the amount
120
+ of memory needed for resharding. Buffers donated to resharding should not be
121
+ reused.
122
+ may_alias: If `True`, may alias the input array with the output array. May
123
+ reduce the amount of memory needed for resharding. Not used at the moment.
125
124
cache_resharding_plans: If `True`, uses a resharding plan cache to avoid
126
125
recreating plans for the same resharding operation. May improve
127
126
performance for use cases where the same resharding operation is done many
@@ -137,43 +136,59 @@ def reshard(
137
136
"reshard sharding" , tree_def , sharding
138
137
)
139
138
140
- jax_arrays = []
141
- jax_array_dst_shardings = []
142
- non_jax_arrays = []
143
- non_jax_array_dst_shardings = []
144
- for arr , dst_sharding in zip (flat_x , flat_sharding ):
139
+ # We must split the arrays into two groups:
140
+ # 1. jax.Array
141
+ # 2. non jax.Array
142
+ # For jax.Array, we will use the ifrt client to get the resharding plan and
143
+ # execute it.
144
+ # These arrays must be further split into groups based on the device set of
145
+ # the sharding, since plugin programs only support execution on the same
146
+ # device set.
147
+ # For non jax.Array, we will use jax.device_put to put the array to the
148
+ # destination devices.
149
+ jax_arrays = collections .defaultdict (
150
+ lambda : {"arrays" : [], "indices" : [], "dst_shardings" : []}
151
+ )
152
+ non_jax_arrays = {"arrays" : [], "indices" : [], "dst_shardings" : []}
153
+ for index , (arr , dst_sharding ) in enumerate (zip (flat_x , flat_sharding )):
145
154
if not isinstance (dst_sharding , jax .sharding .Sharding ):
146
155
raise ValueError ("`sharding` must contain only `jax.sharding.Sharding`" )
147
156
if isinstance (arr , jax .Array ):
148
- jax_arrays .append (arr )
149
- jax_array_dst_shardings .append (dst_sharding )
157
+ key = frozenset (arr .sharding .device_set )
158
+ jax_arrays [key ]["arrays" ].append (arr )
159
+ jax_arrays [key ]["indices" ].append (index )
160
+ jax_arrays [key ]["dst_shardings" ].append (dst_sharding )
150
161
else :
151
- non_jax_arrays .append (arr )
152
- non_jax_array_dst_shardings .append (dst_sharding )
153
-
154
- if non_jax_arrays :
155
- non_jax_arrays = jax .device_put (non_jax_arrays , non_jax_array_dst_shardings )
162
+ non_jax_arrays ["arrays" ].append (arr )
163
+ non_jax_arrays ["indices" ].append (index )
164
+ non_jax_arrays ["dst_shardings" ].append (dst_sharding )
165
+
166
+ if non_jax_arrays ["arrays" ]:
167
+ non_jax_arrays ["arrays" ] = jax .device_put (
168
+ non_jax_arrays ["arrays" ],
169
+ non_jax_arrays ["dst_shardings" ],
170
+ donate = donate ,
171
+ may_alias = may_alias ,
172
+ )
156
173
157
- if jax_arrays :
174
+ for array_info in jax_arrays . values () :
158
175
get_resharding_plan_func = (
159
176
_get_resharding_plan_cached
160
177
if cache_resharding_plans
161
178
else _get_resharding_plan
162
179
)
163
- jax_arrays = get_resharding_plan_func (
164
- tuple (arr .aval for arr in jax_arrays ),
165
- tuple (arr .sharding for arr in jax_arrays ),
166
- tuple (jax_array_dst_shardings ),
180
+ array_info [ "arrays" ] = get_resharding_plan_func (
181
+ tuple (arr .aval for arr in array_info [ "arrays" ] ),
182
+ tuple (arr .sharding for arr in array_info [ "arrays" ] ),
183
+ tuple (array_info [ "dst_shardings" ] ),
167
184
donate ,
168
- ).execute (tuple (jax_arrays ))
185
+ ).execute (tuple (array_info [ "arrays" ] ))
169
186
170
- result = []
171
- jax_iter = iter (jax_arrays )
172
- non_jax_iter = iter (non_jax_arrays )
187
+ result = [None ] * len (flat_x )
188
+ for arr , idx in zip (non_jax_arrays ["arrays" ], non_jax_arrays ["indices" ]):
189
+ result [idx ] = arr
190
+ for array_info in jax_arrays .values ():
191
+ for arr , idx in zip (array_info ["arrays" ], array_info ["indices" ]):
192
+ result [idx ] = arr
173
193
174
- for arr in flat_x :
175
- if isinstance (arr , jax .Array ):
176
- result .append (next (jax_iter ))
177
- else :
178
- result .append (next (non_jax_iter ))
179
194
return jax .tree .unflatten (tree_def , result )
0 commit comments