
Commit 0877643

aliafzal authored and facebook-github-bot committed

Test commit (#3084)

Summary: Pull Request resolved: #3084
Rollback Plan:
Differential Revision: D76457454
1 parent c8495ec commit 0877643

File tree

2 files changed: +264, -7 lines


torchrec/distributed/model_tracker/model_delta_tracker.py

Lines changed: 71 additions & 7 deletions
@@ -6,7 +6,9 @@
 # LICENSE file in the root directory of this source tree.
 
 # pyre-strict
-from typing import Dict, List, Optional, Union
+import logging as logger
+from collections import Counter, OrderedDict
+from typing import Dict, Iterable, List, Optional, Union
 
 import torch
 
@@ -49,6 +51,8 @@ class ModelDeltaTracker:
             call.
         delete_on_read (bool, optional): whether to delete the tracked ids after all consumers have read them.
         mode (TrackingMode, optional): tracking mode to use from supported tracking modes. Default: TrackingMode.ID_ONLY.
+        fqns_to_skip (Iterable[str], optional): iterable of FQNs to skip tracking. Default: ().
+
     """
 
     DEFAULT_CONSUMER: str = "default"
@@ -59,11 +63,15 @@ def __init__(
         consumers: Optional[List[str]] = None,
         delete_on_read: bool = True,
         mode: TrackingMode = TrackingMode.ID_ONLY,
+        fqns_to_skip: Iterable[str] = (),
     ) -> None:
         self._model = model
         self._consumers: List[str] = consumers or [self.DEFAULT_CONSUMER]
         self._delete_on_read = delete_on_read
         self._mode = mode
+        self._fqn_to_feature_map: Dict[str, List[str]] = {}
+        self._fqns_to_skip: Iterable[str] = fqns_to_skip
+        self.fqn_to_feature_names()
         pass
 
     def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None:
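To make the new argument concrete, here is a minimal construction sketch based only on the signature added above. The variable my_model and the consumer name are placeholders, and the import path is taken from the file path shown for this diff:

from torchrec.distributed.model_tracker.model_delta_tracker import ModelDeltaTracker

# `my_model` stands in for whatever nn.Module the tracker should watch.
tracker = ModelDeltaTracker(
    model=my_model,
    consumers=["downstream_publisher"],  # defaults to ["default"] if omitted
    delete_on_read=True,
    fqns_to_skip=["sparse_arch"],  # modules whose FQN contains this segment are not tracked
)

# fqn_to_feature_names() is now invoked once in __init__ and its result is cached in
# self._fqn_to_feature_map, so later calls return the cached mapping.
feature_map = tracker.fqn_to_feature_names()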
@@ -85,14 +93,70 @@ def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
         """
         return {}
 
-    def fqn_to_feature_names(self, module: nn.Module) -> Dict[str, List[str]]:
+    def fqn_to_feature_names(self) -> Dict[str, List[str]]:
         """
-        Returns a mapping from FQN to feature names for a given module.
-
-        Args:
-            module (nn.Module): the module to retrieve feature names for.
+        Returns a mapping of FQN to feature names from all supported modules (EmbeddingCollection and EmbeddingBagCollection) present in the given model.
         """
-        return {}
+        if (self._fqn_to_feature_map is not None) and len(self._fqn_to_feature_map) > 0:
+            return self._fqn_to_feature_map
+
+        table_to_feature_names: Dict[str, List[str]] = OrderedDict()
+        table_to_fqn: Dict[str, str] = OrderedDict()
+        for fqn, named_module in self._model.named_modules():
+            split_fqn = fqn.split(".")
+            # Skipping partial FQNs present in fqns_to_skip
+            # TODO: Validate if we need to support more complex patterns for skipping fqns
+            should_skip = False
+            for fqn_to_skip in self._fqns_to_skip:
+                if fqn_to_skip in split_fqn:
+                    logger.info(f"Skipping {fqn} because it is part of fqns_to_skip")
+                    should_skip = True
+                    break
+            if should_skip:
+                continue
+
+            # Using FQNs of the embedding and mapping them to features as state_dict() API uses these to key states.
+            if isinstance(named_module, SUPPORTED_MODULES):
+                for table_name, config in named_module._table_name_to_config.items():
+                    logger.info(
+                        f"Found {table_name} for {fqn} with features {config.feature_names}"
+                    )
+                    table_to_feature_names[table_name] = config.feature_names
+            for table_name in table_to_feature_names:
+                # Using the split FQN to get the exact table name matching. Otherwise, checking "table_name in fqn"
+                # will incorrectly match fqn with all the table names that have the same prefix
+                if table_name in split_fqn:
+                    embedding_fqn = fqn.replace("_dmp_wrapped_module.module.", "")
+                    if table_name in table_to_fqn:
+                        # Sanity check for validating that we don't have more than one table mapping to the same fqn.
+                        logger.warning(
+                            f"Override {table_to_fqn[table_name]} with {embedding_fqn} for entry {table_name}"
+                        )
+                    table_to_fqn[table_name] = embedding_fqn
+        logger.info(f"Table to fqn: {table_to_fqn}")
+        flatten_names = [
+            name for names in table_to_feature_names.values() for name in names
+        ]
+        # TODO: Validate if there is a better way to handle duplicate feature names.
+        # Logging a warning if duplicate feature names are found across tables, but continue execution as this could be a valid case.
+        if len(set(flatten_names)) != len(flatten_names):
+            counts = Counter(flatten_names)
+            duplicates = [item for item, count in counts.items() if count > 1]
+            logger.warning(f"duplicate feature names found: {duplicates}")
+
+        fqn_to_feature_names: Dict[str, List[str]] = OrderedDict()
+        for table_name in table_to_feature_names:
+            if table_name not in table_to_fqn:
+                # This is likely unexpected, where we can't locate the FQN associated with this table.
+                logger.warning(
+                    f"Table {table_name} not found in {table_to_fqn}, skipping"
+                )
+                continue
+            fqn_to_feature_names[table_to_fqn[table_name]] = table_to_feature_names[
+                table_name
+            ]
+        self._fqn_to_feature_map = fqn_to_feature_names
+        return fqn_to_feature_names
 
     def clear(self, consumer: Optional[str] = None) -> None:
         """
Lines changed: 193 additions & 0 deletions
@@ -0,0 +1,193 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from dataclasses import dataclass
from typing import cast, Dict, Iterable, List, Optional, Union

import torch

from torch import nn
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.planner import ParameterConstraints
from torchrec.distributed.types import ShardingType
from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig
from torchrec.modules.embedding_modules import (
    EmbeddingBagCollection,
    EmbeddingCollection,
)
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

@dataclass
class EmbeddingTableProps:
    """
    Properties of an embedding table.

    Args:
        embedding_table_config (Union[EmbeddingConfig, EmbeddingBagConfig]): config of the embedding table
        sharding (ShardingType): sharding type of the table
        is_weighted (bool): whether the table takes weighted inputs
    """

    embedding_table_config: Union[EmbeddingConfig, EmbeddingBagConfig]
    sharding: ShardingType
    is_weighted: bool = False

class TestECModel(nn.Module):
    """
    Test model with EmbeddingCollection and Linear layers.

    Args:
        tables (List[EmbeddingConfig]): list of embedding tables
        device (Optional[torch.device]): device on which buffers will be initialized

    Example:
        TestECModel(tables=[EmbeddingConfig(...)])
    """

    def __init__(
        self, tables: List[EmbeddingConfig], device: Optional[torch.device] = None
    ) -> None:
        super().__init__()
        self.ec: EmbeddingCollection = EmbeddingCollection(
            tables=tables,
            device=device if device else torch.device("meta"),
        )

        embedding_dim = tables[0].embedding_dim

        self.seq: nn.Sequential = nn.Sequential(
            *[nn.Linear(embedding_dim, embedding_dim) for _ in range(3)]
        )

    def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
        """
        Forward pass of the TestECModel.

        Args:
            features (KeyedJaggedTensor): Input features for the model.

        Returns:
            torch.Tensor: Output tensor after processing through the model.
        """

        lookup_result = self.ec(features)
        return self.seq(torch.cat([jt.values() for _, jt in lookup_result.items()]))

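As a usage sketch (table and feature names invented for illustration, and assuming the definitions above are in scope), the following runs on CPU end to end: EmbeddingCollection returns one JaggedTensor per feature, and its values feed the Linear stack.

import torch
from torchrec.modules.embedding_configs import EmbeddingConfig
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

tables = [
    EmbeddingConfig(
        name="user_table",
        embedding_dim=8,
        num_embeddings=100,
        feature_names=["user_id"],
    )
]
model = TestECModel(tables=tables, device=torch.device("cpu"))

# Two samples for feature "user_id": ids [3, 7] and [11].
kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["user_id"],
    values=torch.tensor([3, 7, 11]),
    lengths=torch.tensor([2, 1]),
)
out = model(kjt)  # one 8-dim row per looked-up id -> shape [3, 8]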
class TestEBCModel(nn.Module):
    """
    Test model with EmbeddingBagCollection and Linear layers.

    Args:
        tables (List[EmbeddingBagConfig]): list of embedding tables
        device (Optional[torch.device]): device on which buffers will be initialized

    Example:
        TestEBCModel(tables=[EmbeddingBagConfig(...)])
    """

    def __init__(
        self, tables: List[EmbeddingBagConfig], device: Optional[torch.device] = None
    ) -> None:
        super().__init__()
        self.ebc: EmbeddingBagCollection = EmbeddingBagCollection(
            tables=tables,
            device=device if device else torch.device("meta"),
        )

        embedding_dim = tables[0].embedding_dim
        self.seq: nn.Sequential = nn.Sequential(
            *[nn.Linear(embedding_dim, embedding_dim) for _ in range(3)]
        )

    def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
        """
        Forward pass of the TestEBCModel.

        Args:
            features (KeyedJaggedTensor): Input features for the model.

        Returns:
            torch.Tensor: Output tensor after processing through the model.
        """

        lookup_result = self.ebc(features).to_dict()
        return self.seq(torch.cat(tuple(lookup_result.values())))

def create_ec_model(
    tables: Iterable[EmbeddingTableProps],
    device: Optional[torch.device] = None,
) -> nn.Module:
    """
    Create an EmbeddingCollection model with the given tables.

    Args:
        tables (Iterable[EmbeddingTableProps]): embedding tables
        device (Optional[torch.device]): device on which buffers will be initialized

    Returns:
        nn.Module: EmbeddingCollection model
    """
    return TestECModel(
        tables=[
            cast(EmbeddingConfig, table.embedding_table_config) for table in tables
        ],
        device=device,
    )

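A short sketch of wiring EmbeddingTableProps into create_ec_model; the names are illustrative and ShardingType.TABLE_WISE is just one valid choice:

import torch
from torchrec.distributed.types import ShardingType
from torchrec.modules.embedding_configs import EmbeddingConfig

table_props = [
    EmbeddingTableProps(
        embedding_table_config=EmbeddingConfig(
            name="user_table",
            embedding_dim=8,
            num_embeddings=100,
            feature_names=["user_id"],
        ),
        sharding=ShardingType.TABLE_WISE,
    )
]
# Builds a TestECModel whose EmbeddingCollection holds the configs above.
ec_model = create_ec_model(table_props, device=torch.device("cpu"))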
def create_ebc_model(
    tables: Iterable[EmbeddingTableProps],
    device: Optional[torch.device] = None,
) -> nn.Module:
    """
    Create an EmbeddingBagCollection model with the given tables.

    Args:
        tables (Iterable[EmbeddingTableProps]): embedding tables
        device (Optional[torch.device]): device on which buffers will be initialized

    Returns:
        nn.Module: EmbeddingBagCollection model
    """
    return TestEBCModel(
        tables=[
            cast(EmbeddingBagConfig, table.embedding_table_config) for table in tables
        ],
        device=device,
    )

def generate_planner_constraints(
    tables: Iterable[EmbeddingTableProps],
) -> Dict[str, ParameterConstraints]:
    """
    Generate planner constraints for the given tables.

    Args:
        tables (Iterable[EmbeddingTableProps]): embedding tables

    Returns:
        Dict[str, ParameterConstraints]: planner constraints keyed by table name
    """
    constraints: Dict[str, ParameterConstraints] = {}
    for table in tables:
        sharding_types = [table.sharding.value]
        constraints[table.embedding_table_config.name] = ParameterConstraints(
            sharding_types=sharding_types,
            compute_kernels=[EmbeddingComputeKernel.FUSED.value],
            feature_names=table.embedding_table_config.feature_names,
            pooling_factors=[1.0],
        )
    return constraints
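A sketch of generate_planner_constraints on the same kind of table list (names invented, definitions above assumed in scope); the printed values follow from ShardingType.ROW_WISE.value and EmbeddingComputeKernel.FUSED.value:

from torchrec.distributed.types import ShardingType
from torchrec.modules.embedding_configs import EmbeddingBagConfig

table_props = [
    EmbeddingTableProps(
        embedding_table_config=EmbeddingBagConfig(
            name="ad_table",
            embedding_dim=16,
            num_embeddings=1000,
            feature_names=["ad_id"],
        ),
        sharding=ShardingType.ROW_WISE,
    )
]
constraints = generate_planner_constraints(table_props)
print(constraints["ad_table"].sharding_types)   # ["row_wise"]
print(constraints["ad_table"].compute_kernels)  # ["fused"]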
