
Commit 061bcad

EddyLXJ authored and facebook-github-bot committed
Skip load metadata tensor
Summary: X-link: facebookresearch/FBGEMM#1856 The metadata tensor is newly added for the kvzch table, and some old checkpoints may not have this fqn, so directly loading an old checkpoint can cause an fqn-missing error. This diff skips initializing the metadata tensor in the load-checkpoint function. The metadata tensor is not used in training, so it is okay to skip loading it; it will be created when the checkpoint is saved. Differential Revision: D81811024
1 parent 1b1e2b3 commit 061bcad
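
For orientation, a minimal sketch of the guard pattern this commit applies (the function name build_optional_metadata and the make_sharded callback are illustrative, not the actual torchrec API): when an old checkpoint carries no metadata tensor, the sharded metadata list is simply left as None and recreated at the next checkpoint save.

    from typing import Callable, List, Optional

    import torch
    from torch.distributed._shard.sharded_tensor import ShardedTensor


    def build_optional_metadata(
        metadata_list: Optional[List[torch.Tensor]],
        make_sharded: Callable[[List[torch.Tensor]], List[ShardedTensor]],
    ) -> Optional[List[ShardedTensor]]:
        # Old checkpoints created before the kvzch metadata tensor existed
        # pass None here; skip building the sharded view instead of failing
        # on a missing fqn. The tensor is recreated at the next checkpoint save.
        if metadata_list is None:
            return None
        return make_sharded(metadata_list)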

3 files changed: +52 -36 lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 38 additions & 22 deletions
@@ -1957,7 +1957,7 @@ def __init__(
                 List[ShardedTensor],
                 List[ShardedTensor],
                 List[ShardedTensor],
-                List[ShardedTensor],
+                Optional[List[ShardedTensor]],
             ]
         ] = None

@@ -2126,26 +2126,31 @@ def _init_sharded_split_embedding_weights(
             self._table_name_to_weight_count_per_rank,
             use_param_size_as_rows=True,
         )
-        metadata_sharded_t_list = create_virtual_sharded_tensors(
-            emb_table_config_copy,
-            metadata_list,  # pyre-ignore [6]
-            self._pg,
-            prefix,
-            self._table_name_to_weight_count_per_rank,
-        )
+        metadata_sharded_t_list = None
+        if metadata_list is not None:
+            metadata_sharded_t_list = create_virtual_sharded_tensors(
+                emb_table_config_copy,
+                metadata_list,
+                self._pg,
+                prefix,
+                self._table_name_to_weight_count_per_rank,
+            )

         assert (
             len(pmt_list)
             == len(weight_ids_list)  # pyre-ignore
             == len(bucket_cnt_list)  # pyre-ignore
-            == len(metadata_list)  # pyre-ignore
         )
         assert (
             len(pmt_sharded_t_list)
             == len(weight_id_sharded_t_list)
             == len(bucket_cnt_sharded_t_list)
-            == len(metadata_sharded_t_list)
         )
+        if metadata_list is not None:
+            assert metadata_sharded_t_list is not None
+            assert len(pmt_list) == len(metadata_list)
+            assert len(pmt_sharded_t_list) == len(metadata_sharded_t_list)
+
         self._split_weights_res = (
             pmt_sharded_t_list,
             weight_id_sharded_t_list,

@@ -2181,10 +2186,13 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
         for table_idx, pmt_sharded_t in enumerate(pmt_sharded_t_list):
             table_config = self._config.embedding_tables[table_idx]
             key = append_prefix(prefix, f"{table_config.name}")
+            metadata_sharded_t = None
+            if metadata_sharded_t_list is not None:
+                metadata_sharded_t = metadata_sharded_t_list[table_idx]

             yield key, pmt_sharded_t, weight_id_sharded_t_list[
                 table_idx
-            ], bucket_cnt_sharded_t_list[table_idx], metadata_sharded_t_list[table_idx]
+            ], bucket_cnt_sharded_t_list[table_idx], metadata_sharded_t

     def flush(self) -> None:
         """

@@ -2849,7 +2857,7 @@ def __init__(
                 List[ShardedTensor],
                 List[ShardedTensor],
                 List[ShardedTensor],
-                List[ShardedTensor],
+                Optional[List[ShardedTensor]],
             ]
         ] = None

@@ -3018,26 +3026,31 @@ def _init_sharded_split_embedding_weights(
             self._table_name_to_weight_count_per_rank,
             use_param_size_as_rows=True,
         )
-        metadata_sharded_t_list = create_virtual_sharded_tensors(
-            emb_table_config_copy,
-            metadata_list,  # pyre-ignore [6]
-            self._pg,
-            prefix,
-            self._table_name_to_weight_count_per_rank,
-        )
+        metadata_sharded_t_list = None
+        if metadata_list is not None:
+            metadata_sharded_t_list = create_virtual_sharded_tensors(
+                emb_table_config_copy,
+                metadata_list,
+                self._pg,
+                prefix,
+                self._table_name_to_weight_count_per_rank,
+            )

         assert (
             len(pmt_list)
             == len(weight_ids_list)  # pyre-ignore
             == len(bucket_cnt_list)  # pyre-ignore
-            == len(metadata_list)  # pyre-ignore
         )
         assert (
             len(pmt_sharded_t_list)
             == len(weight_id_sharded_t_list)
             == len(bucket_cnt_sharded_t_list)
-            == len(metadata_sharded_t_list)
         )
+        if metadata_list is not None:
+            assert metadata_sharded_t_list is not None
+            assert len(pmt_list) == len(metadata_list)
+            assert len(pmt_sharded_t_list) == len(metadata_sharded_t_list)
+
         self._split_weights_res = (
             pmt_sharded_t_list,
             weight_id_sharded_t_list,

@@ -3073,10 +3086,13 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
         for table_idx, pmt_sharded_t in enumerate(pmt_sharded_t_list):
             table_config = self._config.embedding_tables[table_idx]
             key = append_prefix(prefix, f"{table_config.name}")
+            metadata_sharded_t = None
+            if metadata_sharded_t_list is not None:
+                metadata_sharded_t = metadata_sharded_t_list[table_idx]

             yield key, pmt_sharded_t, weight_id_sharded_t_list[
                 table_idx
-            ], bucket_cnt_sharded_t_list[table_idx], metadata_sharded_t_list[table_idx]
+            ], bucket_cnt_sharded_t_list[table_idx], metadata_sharded_t

     def flush(self) -> None:
         """

torchrec/distributed/embedding.py

Lines changed: 7 additions & 7 deletions
@@ -1067,7 +1067,6 @@ def post_state_dict_hook(
             assert (
                 weight_ids_sharded_t is not None
                 and id_cnt_per_bucket_sharded_t is not None
-                and metadata_sharded_t is not None
             )
             # The logic here assumes there is only one shard per table on any particular rank
             # if there are cases each rank has >1 shards, we need to update here accordingly

@@ -1121,12 +1120,13 @@ def update_destination(
                 destination,
                 virtual_table_sharded_t_map[table_name][1],
             )
-            update_destination(
-                table_name,
-                "metadata",
-                destination,
-                virtual_table_sharded_t_map[table_name][2],
-            )
+            if virtual_table_sharded_t_map[table_name][2] is not None:
+                update_destination(
+                    table_name,
+                    "metadata",
+                    destination,
+                    virtual_table_sharded_t_map[table_name][2],
+                )

     def _post_load_state_dict_hook(
         module: "ShardedEmbeddingCollection",
torchrec/distributed/embeddingbag.py

Lines changed: 7 additions & 7 deletions
@@ -1213,7 +1213,6 @@ def post_state_dict_hook(
             assert (
                 weight_ids_sharded_t is not None
                 and id_cnt_per_bucket_sharded_t is not None
-                and metadata_sharded_t is not None
             )
             # The logic here assumes there is only one shard per table on any particular rank
             # if there are cases each rank has >1 shards, we need to update here accordingly

@@ -1267,12 +1266,13 @@ def update_destination(
                 destination,
                 virtual_table_sharded_t_map[table_name][1],
             )
-            update_destination(
-                table_name,
-                "metadata",
-                destination,
-                virtual_table_sharded_t_map[table_name][2],
-            )
+            if virtual_table_sharded_t_map[table_name][2] is not None:
+                update_destination(
+                    table_name,
+                    "metadata",
+                    destination,
+                    virtual_table_sharded_t_map[table_name][2],
+                )

     def _post_load_state_dict_hook(
         module: "ShardedEmbeddingBagCollection",
