Skip to content

Commit 1885423

Browse files
committed
Update planner to use consistent hashing
Differential Revision: D76303748
1 parent be4e6d7 commit 1885423

File tree

6 files changed

+244
-35
lines changed

6 files changed

+244
-35
lines changed

torchrec/distributed/planner/enumerators.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
# pyre-strict
99

10+
import copy
1011
import logging
1112
from typing import Dict, List, Optional, Set, Tuple, Union
1213

@@ -102,6 +103,11 @@ def __init__(
102103
EmbeddingStorageEstimator(topology=topology, constraints=constraints),
103104
]
104105

106+
# Initializing caching for enumerate
107+
self._last_stored_search_space: Optional[List[ShardingOption]] = None
108+
self._last_stored_module: Optional[nn.Module] = None
109+
self._last_stored_sharders: Optional[List[ModuleSharder[nn.Module]]] = None
110+
105111
def enumerate(
106112
self,
107113
module: nn.Module,
@@ -118,6 +124,12 @@ def enumerate(
118124
List[ShardingOption]: valid sharding options with values populated.
119125
"""
120126

127+
if (
128+
self._last_stored_module == module
129+
and self._last_stored_sharders == sharders
130+
):
131+
return copy.deepcopy(self._last_stored_search_space) # pyre-ignore
132+
121133
self._sharder_map = {
122134
sharder_name(sharder.module_type): sharder for sharder in sharders
123135
}
@@ -230,8 +242,20 @@ def enumerate(
230242

231243
self.populate_estimates(sharding_options)
232244

245+
self._last_stored_module = module
246+
self._last_stored_sharders = sharders
247+
248+
# Caching the search space with a copy of sharding options, to avoid unexpected modifications to list
249+
self._last_stored_search_space = copy.deepcopy(sharding_options)
233250
return sharding_options
234251

252+
@property
253+
def last_stored_search_space(self) -> Optional[List[ShardingOption]]:
254+
# NOTE: This is the last search space stored by enumerate(...), do not use
255+
# this field in place of actually calling enumerate(...) as it will varie for each
256+
# module/sharders passed in.
257+
return self._last_stored_search_space
258+
235259
def populate_estimates(self, sharding_options: List[ShardingOption]) -> None:
236260
for estimator in self._estimators:
237261
estimator.estimate(sharding_options, self._sharder_map)

torchrec/distributed/planner/planners.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
)
4040
from torchrec.distributed.planner.types import (
4141
Enumerator,
42+
hash_planner_context_inputs,
4243
ParameterConstraints,
4344
Partitioner,
4445
PerfModel,
@@ -280,25 +281,21 @@ def collective_plan(
280281
sharders,
281282
)
282283

283-
def hash_planner_context_inputs(self) -> str:
284+
def hash_planner_context_inputs(self) -> int:
284285
"""
285286
Generates a hash for all planner inputs except for partitioner, proposer, performance model, and stats.
286287
These are all the inputs needed to verify whether a previously generated sharding plan is still valid in a new context.
287288
288289
Returns:
289290
Generates a hash capturing topology, batch size, enumerator, storage reservation, stats and constraints.
290291
"""
291-
hashable_list = [
292+
return hash_planner_context_inputs(
292293
self._topology,
293294
self._batch_size,
294295
self._enumerator,
295296
self._storage_reservation,
296-
frozenset(self._constraints.items()) if self._constraints else None,
297-
]
298-
serialized_list = str(hashable_list).encode("utf-8")
299-
hash_object = hashlib.sha256(serialized_list)
300-
hash_digest = hash_object.hexdigest()
301-
return hash_digest
297+
self._constraints,
298+
)
302299

303300
def plan(
304301
self,

torchrec/distributed/planner/storage_reservations.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ class FixedPercentageStorageReservation(StorageReservation):
163163
def __init__(self, percentage: float) -> None:
164164
assert percentage >= 0 and percentage <= 1
165165
self._percentage: float = percentage
166+
self._last_reserved_topology: Optional[Topology] = None
166167

167168
def reserve(
168169
self,
@@ -174,8 +175,14 @@ def reserve(
174175
) -> Topology:
175176
reserved_topology = copy.deepcopy(topology)
176177
_reserve_storage_percentage(reserved_topology, self._percentage)
178+
self._last_reserved_topology = reserved_topology
177179
return reserved_topology
178180

181+
@property
182+
def last_reserved_topology(self) -> Optional[Topology]:
183+
"Returns a copy of the cached value of the most recent output from the reserve() method."
184+
return copy.deepcopy(self._last_reserved_topology)
185+
179186

180187
class HeuristicalStorageReservation(StorageReservation):
181188
"""
@@ -206,6 +213,7 @@ def __init__(
206213

207214
self._dense_storage: Optional[Storage] = None
208215
self._kjt_storage: Optional[Storage] = None
216+
self._last_reserved_topology: Optional[Topology] = None
209217

210218
def reserve(
211219
self,
@@ -215,6 +223,7 @@ def reserve(
215223
sharders: List[ModuleSharder[nn.Module]],
216224
constraints: Optional[Dict[str, ParameterConstraints]] = None,
217225
) -> Topology:
226+
# TODO: enable proper caching of topology values through _last_reserved_topology
218227
reserved_topology = copy.deepcopy(topology)
219228

220229
batch_inputs, shardable_modules = _get_batch_inputs_and_shardable_parameters(
@@ -262,8 +271,14 @@ def reserve(
262271
message=negative_storage_solution,
263272
)
264273

274+
self._last_reserved_topology = copy.deepcopy(reserved_topology)
265275
return reserved_topology
266276

277+
@property
278+
def last_reserved_topology(self) -> Optional[Topology]:
279+
"Cached value of the most recent output from the reserve() method."
280+
return self._last_reserved_topology
281+
267282

268283
class InferenceStorageReservation(StorageReservation):
269284
"""
@@ -291,6 +306,7 @@ def __init__(
291306

292307
self._dense_storage: Optional[Storage] = None
293308
self._kjt_storage: Optional[Storage] = None
309+
self._last_reserved_topology: Optional[Topology] = None
294310

295311
def reserve(
296312
self,
@@ -324,4 +340,9 @@ def reserve(
324340
multiplier=1,
325341
)
326342

343+
self._last_reserved_topology = copy.deepcopy(reserved_topology)
344+
327345
return reserved_topology
346+
347+
def last_reserved_topology(self) -> Optional[Topology]:
348+
return copy.deepcopy(self._last_reserved_topology)

torchrec/distributed/planner/tests/test_planners.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import torch
1414
from torch import nn
15-
from torchrec import EmbeddingConfig
15+
from torchrec import EmbeddingBagCollection, EmbeddingConfig
1616
from torchrec.distributed.embedding import EmbeddingCollectionSharder
1717
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
1818
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
@@ -306,6 +306,22 @@ def test_passing_info_through_constraints(self) -> None:
306306
class TestEmbeddingShardingHashPlannerContextInputs(unittest.TestCase):
307307

308308
def setUp(self) -> None:
309+
eb_config = EmbeddingBagConfig(
310+
name="table_0",
311+
embedding_dim=160,
312+
num_embeddings=10000,
313+
feature_names=["f1"],
314+
data_type=DataType.FP16,
315+
)
316+
module = EmbeddingBagCollection(
317+
tables=[eb_config],
318+
is_weighted=False,
319+
device=torch.device(
320+
"meta"
321+
), # Using meta device for now since only getting search space
322+
)
323+
sharders = [EmbeddingBagCollectionSharder()]
324+
309325
self.topology = Topology(
310326
local_world_size=8,
311327
world_size=1,
@@ -315,10 +331,20 @@ def setUp(self) -> None:
315331
self.enumerator = EmbeddingEnumerator(
316332
topology=self.topology, batch_size=self.batch_size
317333
)
334+
self.enumerator.enumerate(module, sharders) # pyre-ignore
335+
318336
self.storage_reservation = HeuristicalStorageReservation(percentage=0.15)
319337
self.perf_model = NoopPerfModel(topology=self.topology)
320338
self.constraints = {"table1": ParameterConstraints()}
321339

340+
self.storage_reservation.reserve(
341+
topology=self.topology,
342+
batch_size=self.batch_size,
343+
module=module,
344+
sharders=sharders, # pyre-ignore
345+
constraints=self.constraints,
346+
)
347+
322348
def test_hash_equality(self) -> None:
323349
planner1 = EmbeddingShardingPlanner(
324350
topology=self.topology,

torchrec/distributed/planner/tests/test_types.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,30 @@
88
# pyre-strict
99

1010
import unittest
11-
from typing import cast
11+
from typing import cast, Dict, Optional
1212
from unittest.mock import MagicMock
1313

1414
import torch
15+
from torch import multiprocessing
1516
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
17+
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
18+
from torchrec.distributed.planner import EmbeddingShardingPlanner
19+
from torchrec.distributed.planner.enumerators import EmbeddingEnumerator
20+
from torchrec.distributed.planner.perf_models import NoopPerfModel
21+
from torchrec.distributed.planner.storage_reservations import (
22+
HeuristicalStorageReservation,
23+
)
1624

1725
from torchrec.distributed.planner.types import (
1826
ParameterConstraints,
1927
Shard,
2028
ShardingOption,
2129
Topology,
2230
)
31+
from torchrec.distributed.test_utils.multi_process import (
32+
MultiProcessContext,
33+
MultiProcessTestBase,
34+
)
2335
from torchrec.distributed.types import (
2436
BoundsCheckMode,
2537
CacheAlgorithm,
@@ -348,3 +360,75 @@ def test_hash_inequality(self) -> None:
348360
self.assertNotEqual(
349361
hash(pc1), hash(pc2), "Hashes should be different for different instances"
350362
)
363+
364+
365+
def _test_hashing_consistency(
366+
rank: int,
367+
world_size: int,
368+
backend: str,
369+
return_hash_dict: Dict[str, int],
370+
local_size: Optional[int] = None,
371+
) -> None:
372+
with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
373+
topology = Topology(
374+
local_world_size=8,
375+
world_size=1,
376+
compute_device="cuda",
377+
)
378+
batch_size = 128
379+
enumerator = EmbeddingEnumerator(topology=topology, batch_size=batch_size)
380+
eb_config = EmbeddingBagConfig(
381+
name="table_0",
382+
embedding_dim=160,
383+
num_embeddings=10000,
384+
feature_names=["f1"],
385+
data_type=DataType.FP16,
386+
)
387+
module = EmbeddingBagCollection(
388+
tables=[eb_config],
389+
is_weighted=False,
390+
device=torch.device(
391+
"meta"
392+
), # Using meta device for now since only getting search space
393+
)
394+
sharders = [EmbeddingBagCollectionSharder()]
395+
enumerator.enumerate(module, sharders) # pyre-ignore
396+
storage_reservation = HeuristicalStorageReservation(percentage=0.15)
397+
constraints = {"table1": ParameterConstraints()}
398+
399+
storage_reservation.reserve(
400+
topology=topology,
401+
batch_size=batch_size,
402+
module=module,
403+
sharders=sharders, # pyre-ignore
404+
constraints=constraints,
405+
)
406+
perf_model = NoopPerfModel(topology=topology)
407+
408+
planner1 = EmbeddingShardingPlanner(
409+
topology=topology,
410+
batch_size=batch_size,
411+
enumerator=enumerator,
412+
storage_reservation=storage_reservation,
413+
performance_model=perf_model,
414+
constraints=constraints,
415+
)
416+
417+
h = planner1.hash_planner_context_inputs()
418+
return_hash_dict[str(rank)] = h
419+
420+
421+
class TestConsistentHashingBetweenProcesses(MultiProcessTestBase):
422+
423+
def test_hash_consistency(self) -> None:
424+
# planner
425+
world_size = 2
426+
return_hash_dict = multiprocessing.Manager().dict()
427+
self._run_multi_process_test(
428+
callable=_test_hashing_consistency,
429+
world_size=world_size,
430+
backend="nccl" if torch.cuda.is_available() else "gloo",
431+
return_hash_dict=return_hash_dict,
432+
)
433+
hashes = return_hash_dict.values()
434+
assert hashes[0] == hashes[1], "hash values are different."

0 commit comments

Comments
 (0)