
Commit a95e7e1

emlin authored and facebook-github-bot committed
support multiple kernels in QEC when virtual table type is enabled (#2984)
Summary:
Pull Request resolved: #2984

- Enabled a separate embedding group for virtual tables
- Fixed the feature order when features are grouped differently from their definition order

Reviewed By: kausv

Differential Revision: D73059492

fbshipit-source-id: 655ae97fb4c6a8e14574438f74d06d7f6eb8319b
1 parent 8b75d3c

2 files changed (+141, −8 lines)

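The two behavior changes in the summary boil down to (1) including use_virtual_table in the key used to group tables into quantized kernels, so virtual tables land in their own kernel group, and (2) collecting feature names per group rather than in table-definition order. Below is a minimal, self-contained Python sketch of that grouping idea; TableConfig and group_tables are hypothetical stand-ins for illustration, not the torchrec implementation.

# Minimal standalone sketch (assumption: not the torchrec code) of the grouping
# change: tables are bucketed by a composite key that now includes
# use_virtual_table, so virtual tables get their own kernel group, and feature
# names are collected in group order rather than table-definition order.
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class TableConfig:  # hypothetical stand-in for EmbeddingConfig / EmbeddingBagConfig
    name: str
    data_type: str
    feature_names: List[str]
    use_virtual_table: bool = False


def group_tables(
    tables: List[TableConfig],
) -> Tuple[Dict[Tuple[str, bool], List[TableConfig]], List[str]]:
    key_to_tables: Dict[Tuple[str, bool], List[TableConfig]] = defaultdict(list)
    for table in tables:
        # virtual tables no longer share a group with regular tables of the same dtype
        key_to_tables[(table.data_type, table.use_virtual_table)].append(table)

    feature_names: List[str] = []
    for group in key_to_tables.values():
        for table in group:
            # extend per group so the feature order matches the kernel grouping
            feature_names.extend(table.feature_names)
    return key_to_tables, feature_names


tables = [
    TableConfig("t1", "int8", ["f1"]),
    TableConfig("t2", "int8", ["f2"], use_virtual_table=True),
    TableConfig("t3", "int8", ["f3"]),
]
groups, names = group_tables(tables)
assert len(groups) == 2  # t1/t3 share one group; virtual t2 gets its own
assert names == ["f1", "f3", "f2"]  # group order, not definition order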

torchrec/quant/embedding_modules.py

Lines changed: 16 additions & 8 deletions
@@ -357,7 +357,7 @@ def __init__(
         self._is_weighted = is_weighted
         self._embedding_bag_configs: List[EmbeddingBagConfig] = tables
         self._key_to_tables: Dict[
-            Tuple[PoolingType, DataType], List[EmbeddingBagConfig]
+            Tuple[PoolingType, DataType, bool], List[EmbeddingBagConfig]
         ] = defaultdict(list)
         self._feature_names: List[str] = []
         self._feature_splits: List[int] = []
@@ -379,14 +379,16 @@ def __init__(
             if table.name in table_names:
                 raise ValueError(f"Duplicate table name {table.name}")
             table_names.add(table.name)
+            key = (table.pooling, table.use_virtual_table)
             # pyre-ignore
-            self._key_to_tables[table.pooling].append(table)
+            self._key_to_tables[key].append(table)
 
         location = (
             EmbeddingLocation.HOST if device.type == "cpu" else EmbeddingLocation.DEVICE
         )
 
-        for pooling, emb_configs in self._key_to_tables.items():
+        for key, emb_configs in self._key_to_tables.items():
+            pooling = key[0]
             embedding_specs = []
             weight_lists: Optional[
                 List[Tuple[torch.Tensor, Optional[torch.Tensor]]]
@@ -756,7 +758,9 @@ def __init__( # noqa C901
         self._output_dtype = output_dtype
         self._device = device
         self.row_alignment = row_alignment
-        self._key_to_tables: Dict[DataType, List[EmbeddingConfig]] = defaultdict(list)
+        self._key_to_tables: Dict[Tuple[DataType, bool], List[EmbeddingConfig]] = (
+            defaultdict(list)
+        )
         self._feature_names: List[str] = []
         self._features_order: Optional[List[int]] = None
 
@@ -778,12 +782,11 @@ def __init__( # noqa C901
                     + f" Violating case: {table.name}'s embedding_dim {table.embedding_dim} !="
                     + f" {self._embedding_dim}"
                 )
-            key = table.data_type
+            key = (table.data_type, table.use_virtual_table)
             self._key_to_tables[key].append(table)
-            self._feature_names.extend(table.feature_names)
         self._feature_splits: List[int] = []
         for key, emb_configs in self._key_to_tables.items():
-            data_type = key
+            data_type = key[0]
             embedding_specs = []
             weight_lists: Optional[
                 List[Tuple[torch.Tensor, Optional[torch.Tensor]]]
@@ -808,6 +811,9 @@ def __init__( # noqa C901
                     table_name_to_quantized_weights[table.name]
                 )
                 feature_table_map.extend([idx] * table.num_features())
+                # move to here to make sure feature_names order is consistent with the embedding groups
+                self._feature_names.extend(table.feature_names)
+
             emb_module = IntNBitTableBatchedEmbeddingBagsCodegen(
                 embedding_specs=embedding_specs,
                 pooling_mode=PoolingMode.NONE,
@@ -852,7 +858,9 @@ def __init__( # noqa C901
                     "weight_qbias", qbias
                 )
 
-        self._embedding_names_by_batched_tables: Dict[DataType, List[str]] = {
+        self._embedding_names_by_batched_tables: Dict[
+            Tuple[DataType, bool], List[str]
+        ] = {
             key: list(itertools.chain(*get_embedding_names_by_table(table)))
             for key, table in self._key_to_tables.items()
         }
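With self._feature_names.extend(...) relocated into the per-group loop above, the collection's feature order now follows the kernel grouping, which is why the new tests below expect ["f1", "f3", "f2"] instead of the definition order. As a rough sketch of the bookkeeping this implies, inputs keyed in definition order can be permuted into group order and then split per kernel group; compute_features_order and split_per_group are hypothetical helpers written for illustration, not torchrec APIs.

# Hedged sketch under assumed helper names; shows how a features-order
# permutation and per-group splits keep inputs aligned with grouped kernels.
from typing import Dict, List


def compute_features_order(input_keys: List[str], grouped_names: List[str]) -> List[int]:
    # position of each grouped feature within the incoming key order
    key_to_pos = {key: i for i, key in enumerate(input_keys)}
    return [key_to_pos[name] for name in grouped_names]


def split_per_group(
    values_per_feature: Dict[str, List[int]],
    grouped_names: List[str],
    feature_splits: List[int],
) -> List[List[List[int]]]:
    # reorder per-feature values into group order, then split by group sizes
    ordered = [values_per_feature[name] for name in grouped_names]
    out, start = [], 0
    for split in feature_splits:
        out.append(ordered[start : start + split])
        start += split
    return out


order = compute_features_order(["f1", "f2", "f3"], ["f1", "f3", "f2"])
assert order == [0, 2, 1]

groups = split_per_group({"f1": [0], "f2": [1], "f3": [2]}, ["f1", "f3", "f2"], [2, 1])
assert groups == [[[0], [2]], [[1]]]  # regular-table group (t1, t3), then virtual-table group (t2)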

torchrec/quant/tests/test_embedding_modules.py

Lines changed: 125 additions & 0 deletions
@@ -260,6 +260,44 @@ def test_multiple_features(self) -> None:
         )
         self._test_ebc([eb1_config, eb2_config], features)
 
+    def test_multiple_kernels_per_ebc_table(self) -> None:
+        class TestModule(torch.nn.Module):
+            def __init__(self, m: torch.nn.Module) -> None:
+                super().__init__()
+                self.m = m
+
+        eb1_config = EmbeddingBagConfig(
+            name="t1", embedding_dim=16, num_embeddings=10, feature_names=["f1"]
+        )
+        eb2_config = EmbeddingBagConfig(
+            name="t2",
+            embedding_dim=16,
+            num_embeddings=10,
+            feature_names=["f2"],
+            use_virtual_table=True,
+        )
+        eb3_config = EmbeddingBagConfig(
+            name="t3", embedding_dim=16, num_embeddings=10, feature_names=["f3"]
+        )
+        ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config, eb3_config])
+        model = TestModule(ebc)
+        qebc = trec_infer.modules.quantize_embeddings(
+            model,
+            dtype=torch.int8,
+            inplace=True,
+            per_table_weight_dtype={"t1": torch.float16},
+        )
+        self.assertTrue(isinstance(qebc.m, QuantEmbeddingBagCollection))
+        # feature name should be consistent with the order of grouped embeddings
+        self.assertEqual(qebc.m._feature_names, ["f1", "f3", "f2"])
+
+        features = KeyedJaggedTensor(
+            keys=["f1", "f2", "f3"],
+            values=torch.as_tensor([0, 1, 2]),
+            lengths=torch.as_tensor([1, 1, 1]),
+        )
+        self._test_ebc([eb1_config, eb2_config, eb3_config], features)
+
     # pyre-ignore
     @given(
         data_type=st.sampled_from(
@@ -742,6 +780,93 @@ def __init__(self, m: torch.nn.Module) -> None:
         self.assertEqual(config.name, "t2")
         self.assertEqual(config.data_type, DataType.INT8)
 
+    def test_multiple_kernels_per_ec_table(self) -> None:
+        class TestModule(torch.nn.Module):
+            def __init__(self, m: torch.nn.Module) -> None:
+                super().__init__()
+                self.m = m
+
+        eb1_config = EmbeddingConfig(
+            name="t1", embedding_dim=16, num_embeddings=10, feature_names=["f1"]
+        )
+        eb2_config = EmbeddingConfig(
+            name="t2",
+            embedding_dim=16,
+            num_embeddings=10,
+            feature_names=["f2"],
+            use_virtual_table=True,
+        )
+        eb3_config = EmbeddingConfig(
+            name="t3",
+            embedding_dim=16,
+            num_embeddings=10,
+            feature_names=["f3"],
+        )
+        ec = EmbeddingCollection(tables=[eb1_config, eb2_config, eb3_config])
+        model = TestModule(ec)
+        qconfig_spec_keys: List[Type[torch.nn.Module]] = [EmbeddingCollection]
+        quant_mapping: Dict[Type[torch.nn.Module], Type[torch.nn.Module]] = {
+            EmbeddingCollection: QuantEmbeddingCollection
+        }
+        qec = trec_infer.modules.quantize_embeddings(
+            model,
+            dtype=torch.int8,
+            additional_qconfig_spec_keys=qconfig_spec_keys,
+            additional_mapping=quant_mapping,
+            inplace=True,
+            per_table_weight_dtype={
+                "t1": torch.float16,
+                "t2": torch.float16,
+                "t3": torch.float16,
+            },
+        )
+        self.assertTrue(isinstance(qec.m, QuantEmbeddingCollection))
+        # feature name should be consistent with the order of grouped embeddings
+        self.assertEqual(qec.m._feature_names, ["f1", "f3", "f2"])
+
+        # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
+        configs = model.m.embedding_configs()
+        self.assertEqual(len(configs), 3)
+        features = KeyedJaggedTensor(
+            keys=["f1", "f2", "f3"],
+            values=torch.as_tensor(
+                [
+                    5,
+                    1,
+                    0,
+                    0,
+                    4,
+                    3,
+                    4,
+                    9,
+                    2,
+                    2,
+                    3,
+                    3,
+                    1,
+                    5,
+                    0,
+                    7,
+                    5,
+                    0,
+                    9,
+                    9,
+                    3,
+                    5,
+                    6,
+                    6,
+                    9,
+                    3,
+                    7,
+                    8,
+                    7,
+                    7,
+                ]
+            ),
+            lengths=torch.as_tensor([9, 12, 9]),
+        )
+        self._test_ec(tables=[eb3_config, eb1_config, eb2_config], features=features)
+
     def test_different_quantization_dtype_per_ebc_table(self) -> None:
         class TestModule(torch.nn.Module):
             def __init__(self, m: torch.nn.Module) -> None:
