@@ -1957,7 +1957,7 @@ def __init__(
1957
1957
List [ShardedTensor ],
1958
1958
List [ShardedTensor ],
1959
1959
List [ShardedTensor ],
1960
- List [ShardedTensor ],
1960
+ Optional [ List [ShardedTensor ] ],
1961
1961
]
1962
1962
] = None
1963
1963
@@ -2126,26 +2126,31 @@ def _init_sharded_split_embedding_weights(
2126
2126
self ._table_name_to_weight_count_per_rank ,
2127
2127
use_param_size_as_rows = True ,
2128
2128
)
2129
- metadata_sharded_t_list = create_virtual_sharded_tensors (
2130
- emb_table_config_copy ,
2131
- metadata_list , # pyre-ignore [6]
2132
- self ._pg ,
2133
- prefix ,
2134
- self ._table_name_to_weight_count_per_rank ,
2135
- )
2129
+ metadata_sharded_t_list = None
2130
+ if metadata_list is not None :
2131
+ metadata_sharded_t_list = create_virtual_sharded_tensors (
2132
+ emb_table_config_copy ,
2133
+ metadata_list ,
2134
+ self ._pg ,
2135
+ prefix ,
2136
+ self ._table_name_to_weight_count_per_rank ,
2137
+ )
2136
2138
2137
2139
assert (
2138
2140
len (pmt_list )
2139
2141
== len (weight_ids_list ) # pyre-ignore
2140
2142
== len (bucket_cnt_list ) # pyre-ignore
2141
- == len (metadata_list ) # pyre-ignore
2142
2143
)
2143
2144
assert (
2144
2145
len (pmt_sharded_t_list )
2145
2146
== len (weight_id_sharded_t_list )
2146
2147
== len (bucket_cnt_sharded_t_list )
2147
- == len (metadata_sharded_t_list )
2148
2148
)
2149
+ if metadata_list is not None :
2150
+ assert metadata_sharded_t_list is not None
2151
+ assert len (pmt_list ) == len (metadata_list )
2152
+ assert len (pmt_sharded_t_list ) == len (metadata_sharded_t_list )
2153
+
2149
2154
self ._split_weights_res = (
2150
2155
pmt_sharded_t_list ,
2151
2156
weight_id_sharded_t_list ,
@@ -2181,10 +2186,13 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
2181
2186
for table_idx , pmt_sharded_t in enumerate (pmt_sharded_t_list ):
2182
2187
table_config = self ._config .embedding_tables [table_idx ]
2183
2188
key = append_prefix (prefix , f"{ table_config .name } " )
2189
+ metadata_sharded_t = None
2190
+ if metadata_sharded_t_list is not None :
2191
+ metadata_sharded_t = metadata_sharded_t_list [table_idx ]
2184
2192
2185
2193
yield key , pmt_sharded_t , weight_id_sharded_t_list [
2186
2194
table_idx
2187
- ], bucket_cnt_sharded_t_list [table_idx ], metadata_sharded_t_list [ table_idx ]
2195
+ ], bucket_cnt_sharded_t_list [table_idx ], metadata_sharded_t
2188
2196
2189
2197
def flush (self ) -> None :
2190
2198
"""
@@ -2849,7 +2857,7 @@ def __init__(
2849
2857
List [ShardedTensor ],
2850
2858
List [ShardedTensor ],
2851
2859
List [ShardedTensor ],
2852
- List [ShardedTensor ],
2860
+ Optional [ List [ShardedTensor ] ],
2853
2861
]
2854
2862
] = None
2855
2863
@@ -3018,26 +3026,31 @@ def _init_sharded_split_embedding_weights(
3018
3026
self ._table_name_to_weight_count_per_rank ,
3019
3027
use_param_size_as_rows = True ,
3020
3028
)
3021
- metadata_sharded_t_list = create_virtual_sharded_tensors (
3022
- emb_table_config_copy ,
3023
- metadata_list , # pyre-ignore [6]
3024
- self ._pg ,
3025
- prefix ,
3026
- self ._table_name_to_weight_count_per_rank ,
3027
- )
3029
+ metadata_sharded_t_list = None
3030
+ if metadata_list is not None :
3031
+ metadata_sharded_t_list = create_virtual_sharded_tensors (
3032
+ emb_table_config_copy ,
3033
+ metadata_list ,
3034
+ self ._pg ,
3035
+ prefix ,
3036
+ self ._table_name_to_weight_count_per_rank ,
3037
+ )
3028
3038
3029
3039
assert (
3030
3040
len (pmt_list )
3031
3041
== len (weight_ids_list ) # pyre-ignore
3032
3042
== len (bucket_cnt_list ) # pyre-ignore
3033
- == len (metadata_list ) # pyre-ignore
3034
3043
)
3035
3044
assert (
3036
3045
len (pmt_sharded_t_list )
3037
3046
== len (weight_id_sharded_t_list )
3038
3047
== len (bucket_cnt_sharded_t_list )
3039
- == len (metadata_sharded_t_list )
3040
3048
)
3049
+ if metadata_list is not None :
3050
+ assert metadata_sharded_t_list is not None
3051
+ assert len (pmt_list ) == len (metadata_list )
3052
+ assert len (pmt_sharded_t_list ) == len (metadata_sharded_t_list )
3053
+
3041
3054
self ._split_weights_res = (
3042
3055
pmt_sharded_t_list ,
3043
3056
weight_id_sharded_t_list ,
@@ -3073,10 +3086,13 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
3073
3086
for table_idx , pmt_sharded_t in enumerate (pmt_sharded_t_list ):
3074
3087
table_config = self ._config .embedding_tables [table_idx ]
3075
3088
key = append_prefix (prefix , f"{ table_config .name } " )
3089
+ metadata_sharded_t = None
3090
+ if metadata_sharded_t_list is not None :
3091
+ metadata_sharded_t = metadata_sharded_t_list [table_idx ]
3076
3092
3077
3093
yield key , pmt_sharded_t , weight_id_sharded_t_list [
3078
3094
table_idx
3079
- ], bucket_cnt_sharded_t_list [table_idx ], metadata_sharded_t_list [ table_idx ]
3095
+ ], bucket_cnt_sharded_t_list [table_idx ], metadata_sharded_t
3080
3096
3081
3097
def flush (self ) -> None :
3082
3098
"""
0 commit comments