From 669e5d941488b79482739584b32769044d3c13bc Mon Sep 17 00:00:00 2001 From: Chenyu Zhang Date: Tue, 8 Jul 2025 23:15:59 -0700 Subject: [PATCH] kvzch use new operator in model publish (#3108) Summary: Publish change to enable KVEmbeddingInference when use_virtual_table is set to true Reviewed By: emlin Differential Revision: D75321284 --- .../distributed/quant_embedding_kernel.py | 72 +++++++++++-------- torchrec/quant/embedding_modules.py | 26 ++++--- .../quant/tests/test_embedding_modules.py | 24 +++++++ 3 files changed, 82 insertions(+), 40 deletions(-) diff --git a/torchrec/distributed/quant_embedding_kernel.py b/torchrec/distributed/quant_embedding_kernel.py index 4e0dc31f3..18b5dc7f8 100644 --- a/torchrec/distributed/quant_embedding_kernel.py +++ b/torchrec/distributed/quant_embedding_kernel.py @@ -20,6 +20,7 @@ PoolingMode, rounded_row_size_in_bytes, ) +from fbgemm_gpu.tbe.cache.kv_embedding_ops_inference import KVEmbeddingInference from torchrec.distributed.batched_embedding_kernel import ( BaseBatchedEmbedding, BaseBatchedEmbeddingBag, @@ -237,6 +238,7 @@ def __init__( super().__init__(config, pg, device) managed: List[EmbeddingLocation] = [] + is_virtual_table: bool = False for table in config.embedding_tables: if device is not None and device.type == "cuda": managed.append( @@ -244,6 +246,8 @@ def __init__( ) else: managed.append(EmbeddingLocation.HOST) + if table.use_virtual_table: + is_virtual_table = True self._config: GroupedEmbeddingConfig = config self._emb_module_registered: bool = is_fused_param_register_tbe(fused_params) self._is_weighted: Optional[bool] = config.is_weighted @@ -284,6 +288,8 @@ def __init__( if self.lengths_to_tbe: tbe_clazz = IntNBitTableBatchedEmbeddingBagsCodegenWithLength + elif is_virtual_table: + tbe_clazz = KVEmbeddingInference else: tbe_clazz = IntNBitTableBatchedEmbeddingBagsCodegen @@ -448,6 +454,7 @@ def __init__( super().__init__(config, pg, device) managed: List[EmbeddingLocation] = [] + is_virtual_table = False for table in config.embedding_tables: if device is not None and device.type == "cuda": managed.append( @@ -455,6 +462,8 @@ def __init__( ) else: managed.append(EmbeddingLocation.HOST) + if table.use_virtual_table: + is_virtual_table = True self._config: GroupedEmbeddingConfig = config self._emb_module_registered: bool = is_fused_param_register_tbe(fused_params) self._quant_state_dict_split_scale_bias: bool = ( @@ -465,37 +474,40 @@ def __init__( ) # 16 for CUDA, 1 for others like CPU and MTIA. 
self._tbe_row_alignment: int = 16 if self._runtime_device.type == "cuda" else 1 - self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = ( - IntNBitTableBatchedEmbeddingBagsCodegen( - embedding_specs=[ + embedding_clazz = ( + KVEmbeddingInference + if is_virtual_table + else IntNBitTableBatchedEmbeddingBagsCodegen + ) + self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = embedding_clazz( + embedding_specs=[ + ( + table.name, + local_rows, ( - table.name, - local_rows, - ( - local_cols - if self._quant_state_dict_split_scale_bias - else table.embedding_dim - ), - data_type_to_sparse_type(table.data_type), - location, - ) - for local_rows, local_cols, table, location in zip( - self._local_rows, - self._local_cols, - config.embedding_tables, - managed, - ) - ], - device=device, - pooling_mode=PoolingMode.NONE, - feature_table_map=self._feature_table_map, - row_alignment=self._tbe_row_alignment, - uvm_host_mapped=True, # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue - feature_names_per_table=[ - table.feature_names for table in config.embedding_tables - ], - **(tbe_fused_params(fused_params) or {}), - ) + local_cols + if self._quant_state_dict_split_scale_bias + else table.embedding_dim + ), + data_type_to_sparse_type(table.data_type), + location, + ) + for local_rows, local_cols, table, location in zip( + self._local_rows, + self._local_cols, + config.embedding_tables, + managed, + ) + ], + device=device, + pooling_mode=PoolingMode.NONE, + feature_table_map=self._feature_table_map, + row_alignment=self._tbe_row_alignment, + uvm_host_mapped=True, # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue + feature_names_per_table=[ + table.feature_names for table in config.embedding_tables + ], + **(tbe_fused_params(fused_params) or {}), ) if device is not None: self._emb_module.initialize_weights() diff --git a/torchrec/quant/embedding_modules.py b/torchrec/quant/embedding_modules.py index bcd428a4e..3e979b34d 100644 --- a/torchrec/quant/embedding_modules.py +++ b/torchrec/quant/embedding_modules.py @@ -30,6 +30,7 @@ IntNBitTableBatchedEmbeddingBagsCodegen, PoolingMode, ) +from fbgemm_gpu.tbe.cache.kv_embedding_ops_inference import KVEmbeddingInference from torch import Tensor from torchrec.distributed.utils import none_throws from torchrec.modules.embedding_configs import ( @@ -357,7 +358,7 @@ def __init__( self._is_weighted = is_weighted self._embedding_bag_configs: List[EmbeddingBagConfig] = tables self._key_to_tables: Dict[ - Tuple[PoolingType, DataType, bool], List[EmbeddingBagConfig] + Tuple[PoolingType, bool], List[EmbeddingBagConfig] ] = defaultdict(list) self._feature_names: List[str] = [] self._feature_splits: List[int] = [] @@ -383,15 +384,13 @@ def __init__( key = (table.pooling, table.use_virtual_table) else: key = (table.pooling, False) - # pyre-ignore self._key_to_tables[key].append(table) location = ( EmbeddingLocation.HOST if device.type == "cpu" else EmbeddingLocation.DEVICE ) - for key, emb_configs in self._key_to_tables.items(): - pooling = key[0] + for (pooling, use_virtual_table), emb_configs in self._key_to_tables.items(): embedding_specs = [] weight_lists: Optional[ List[Tuple[torch.Tensor, Optional[torch.Tensor]]] @@ -420,7 +419,12 @@ def __init__( ) feature_table_map.extend([idx] * table.num_features()) - emb_module = IntNBitTableBatchedEmbeddingBagsCodegen( + embedding_clazz = ( + KVEmbeddingInference + if use_virtual_table + else IntNBitTableBatchedEmbeddingBagsCodegen + ) + emb_module = embedding_clazz( 
embedding_specs=embedding_specs, pooling_mode=pooling_type_to_pooling_mode(pooling), weight_lists=weight_lists, @@ -790,8 +794,7 @@ def __init__( # noqa C901 key = (table.data_type, False) self._key_to_tables[key].append(table) self._feature_splits: List[int] = [] - for key, emb_configs in self._key_to_tables.items(): - data_type = key[0] + for (data_type, use_virtual_table), emb_configs in self._key_to_tables.items(): embedding_specs = [] weight_lists: Optional[ List[Tuple[torch.Tensor, Optional[torch.Tensor]]] @@ -816,10 +819,13 @@ def __init__( # noqa C901 table_name_to_quantized_weights[table.name] ) feature_table_map.extend([idx] * table.num_features()) - # move to here to make sure feature_names order is consistent with the embedding groups self._feature_names.extend(table.feature_names) - - emb_module = IntNBitTableBatchedEmbeddingBagsCodegen( + embedding_clazz = ( + KVEmbeddingInference + if use_virtual_table + else IntNBitTableBatchedEmbeddingBagsCodegen + ) + emb_module = embedding_clazz( embedding_specs=embedding_specs, pooling_mode=PoolingMode.NONE, weight_lists=weight_lists, diff --git a/torchrec/quant/tests/test_embedding_modules.py b/torchrec/quant/tests/test_embedding_modules.py index 4f52d89e3..4858901d8 100644 --- a/torchrec/quant/tests/test_embedding_modules.py +++ b/torchrec/quant/tests/test_embedding_modules.py @@ -7,6 +7,7 @@ # pyre-strict +import logging import unittest from dataclasses import replace from typing import Dict, List, Optional, Type @@ -44,6 +45,19 @@ KeyedTensor, ) +logger: logging.Logger = logging.getLogger(__name__) + + +def load_required_dram_kv_embedding_libraries() -> bool: + try: + torch.ops.load_library( + "//deeplearning/fbgemm/fbgemm_gpu:dram_kv_embedding_inference" + ) + return True + except Exception as e: + logger.error(f"Failed to load dram_kv_embedding libraries, skipping test: {e}") + return False + class EmbeddingBagCollectionTest(unittest.TestCase): def _asserting_same_embeddings( @@ -260,6 +274,11 @@ def test_multiple_features(self) -> None: ) self._test_ebc([eb1_config, eb2_config], features) + # pyre-ignore: Invalid decoration [56] + @unittest.skipIf( + not load_required_dram_kv_embedding_libraries(), + "Skip when required libraries are not available", + ) def test_multiple_kernels_per_ebc_table(self) -> None: class TestModule(torch.nn.Module): def __init__(self, m: torch.nn.Module) -> None: @@ -780,6 +799,11 @@ def __init__(self, m: torch.nn.Module) -> None: self.assertEqual(config.name, "t2") self.assertEqual(config.data_type, DataType.INT8) + # pyre-ignore: Invalid decoration [56] + @unittest.skipIf( + not load_required_dram_kv_embedding_libraries(), + "Skip when required libraries are not available", + ) def test_multiple_kernels_per_ec_table(self) -> None: class TestModule(torch.nn.Module): def __init__(self, m: torch.nn.Module) -> None:
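
Reviewer note (not part of the patch): the selection logic added to torchrec/distributed/quant_embedding_kernel.py reduces to "if any table in the grouped config is a virtual table, build the TBE with KVEmbeddingInference; otherwise keep IntNBitTableBatchedEmbeddingBagsCodegen". The sketch below restates that rule in isolation. It assumes the usual fbgemm_gpu import path for IntNBitTableBatchedEmbeddingBagsCodegen, and HasVirtualTableFlag is a hypothetical stand-in for the table config objects the patch iterates over.

# Sketch only: restates the kernel-selection rule introduced by this patch.
from typing import Iterable, Protocol

from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
    IntNBitTableBatchedEmbeddingBagsCodegen,
)
from fbgemm_gpu.tbe.cache.kv_embedding_ops_inference import KVEmbeddingInference


class HasVirtualTableFlag(Protocol):
    # Hypothetical stand-in for the entries of config.embedding_tables.
    use_virtual_table: bool


def select_inference_tbe_class(tables: Iterable[HasVirtualTableFlag]) -> type:
    """A single virtual table in the group routes the whole group to the KV-backed
    inference kernel; otherwise the existing int-N-bit TBE is kept."""
    if any(table.use_virtual_table for table in tables):
        return KVEmbeddingInference
    return IntNBitTableBatchedEmbeddingBagsCodegen

Note that in QuantBatchedEmbeddingBag the lengths-based codegen still takes precedence: the patch checks self.lengths_to_tbe first and only falls back to KVEmbeddingInference when that branch is not taken.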
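
Similarly, in torchrec/quant/embedding_modules.py the grouping key gains the use_virtual_table flag, so virtual and non-virtual tables never share a TBE and each group picks its own class (the EmbeddingCollection path does the same with (data_type, use_virtual_table) keys). Below is a minimal grouping sketch under that reading, with TableConfig as a hypothetical stand-in for EmbeddingBagConfig and a plain string in place of torchrec's PoolingType enum.

# Sketch only: mirrors the (pooling, use_virtual_table) grouping used by the patched
# EmbeddingBagCollection so each group can construct its own emb_module class.
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class TableConfig:
    name: str
    pooling: str  # the real code uses torchrec's PoolingType
    use_virtual_table: bool = False


def group_tables(
    tables: List[TableConfig],
) -> Dict[Tuple[str, bool], List[TableConfig]]:
    key_to_tables: Dict[Tuple[str, bool], List[TableConfig]] = defaultdict(list)
    for table in tables:
        # Virtual and non-virtual tables land in separate groups; a group with
        # use_virtual_table=True would be built with KVEmbeddingInference.
        key_to_tables[(table.pooling, table.use_virtual_table)].append(table)
    return key_to_tables


# Two SUM-pooled tables, one of them virtual, end up in two separate groups.
groups = group_tables(
    [
        TableConfig("t0", "SUM"),
        TableConfig("t1", "SUM", use_virtual_table=True),
    ]
)
assert len(groups) == 2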