kvzch use new operator in model publish #3108

Open · wants to merge 1 commit into base: main
6 changes: 6 additions & 0 deletions torchrec/distributed/embedding_types.py
@@ -301,6 +301,12 @@ def embedding_shard_metadata(self) -> List[Optional[ShardMetadata]]:
             embedding_shard_metadata.append(table.local_metadata)
         return embedding_shard_metadata
 
+    def is_using_virtual_table(self) -> bool:
+        return self.compute_kernel in [
+            EmbeddingComputeKernel.SSD_VIRTUAL_TABLE,
+            EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE,
+        ]
+
 
 F = TypeVar("F", bound=Multistreamable)
 T = TypeVar("T")
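For reference, the helper added above is a plain membership check on the table's compute kernel: a table counts as "virtual" only when its kernel is one of the SSD/DRAM virtual-table kernels. A minimal, self-contained sketch of the same check, using a stand-in enum rather than torchrec's real EmbeddingComputeKernel:

from enum import Enum

class EmbeddingComputeKernel(Enum):  # stand-in for torchrec's enum
    FUSED = "fused"
    QUANT = "quant"
    SSD_VIRTUAL_TABLE = "ssd_virtual_table"
    DRAM_VIRTUAL_TABLE = "dram_virtual_table"

def is_using_virtual_table(compute_kernel: EmbeddingComputeKernel) -> bool:
    # Same membership test as the helper added above.
    return compute_kernel in (
        EmbeddingComputeKernel.SSD_VIRTUAL_TABLE,
        EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE,
    )

assert is_using_virtual_table(EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE)
assert not is_using_virtual_table(EmbeddingComputeKernel.QUANT)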
1 change: 1 addition & 0 deletions torchrec/distributed/embeddingbag.py
@@ -292,6 +292,7 @@ def create_sharding_infos_by_sharding_device_group(
                         getattr(config, "num_embeddings_post_pruning", None)
                         # TODO: Need to check if attribute exists for BC
                     ),
+                    use_virtual_table=config.use_virtual_table,
                 ),
                 param_sharding=parameter_sharding,
                 param=param,
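This one-line addition threads the per-table use_virtual_table flag into the sharding info so the quantized kernels further down the stack can branch on it. A rough sketch of that propagation, using hypothetical TableConfig/ShardingInfo stand-ins rather than torchrec's actual config classes:

from dataclasses import dataclass

@dataclass
class TableConfig:  # hypothetical stand-in for the embedding table config
    name: str
    use_virtual_table: bool = False

@dataclass
class ShardingInfo:  # hypothetical stand-in for the per-table sharding info
    table_name: str
    use_virtual_table: bool

def build_sharding_info(config: TableConfig) -> ShardingInfo:
    # Copy the flag through so downstream code can pick the virtual-table path.
    return ShardingInfo(table_name=config.name, use_virtual_table=config.use_virtual_table)

print(build_sharding_info(TableConfig("t1", use_virtual_table=True)))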
66 changes: 36 additions & 30 deletions torchrec/distributed/quant_embedding_kernel.py
@@ -20,6 +20,7 @@
     PoolingMode,
     rounded_row_size_in_bytes,
 )
+from fbgemm_gpu.tbe.cache.kv_embedding_ops_inference import KVEmbeddingInference
 from torchrec.distributed.batched_embedding_kernel import (
     BaseBatchedEmbedding,
     BaseBatchedEmbeddingBag,
@@ -284,6 +285,8 @@ def __init__(
 
         if self.lengths_to_tbe:
             tbe_clazz = IntNBitTableBatchedEmbeddingBagsCodegenWithLength
+        elif config.is_using_virtual_table():
+            tbe_clazz = KVEmbeddingInference
         else:
             tbe_clazz = IntNBitTableBatchedEmbeddingBagsCodegen

@@ -465,37 +468,40 @@ def __init__(
         )
         # 16 for CUDA, 1 for others like CPU and MTIA.
         self._tbe_row_alignment: int = 16 if self._runtime_device.type == "cuda" else 1
-        self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = (
-            IntNBitTableBatchedEmbeddingBagsCodegen(
-                embedding_specs=[
-                    (
-                        table.name,
-                        local_rows,
-                        (
-                            local_cols
-                            if self._quant_state_dict_split_scale_bias
-                            else table.embedding_dim
-                        ),
-                        data_type_to_sparse_type(table.data_type),
-                        location,
-                    )
-                    for local_rows, local_cols, table, location in zip(
-                        self._local_rows,
-                        self._local_cols,
-                        config.embedding_tables,
-                        managed,
-                    )
-                ],
-                device=device,
-                pooling_mode=PoolingMode.NONE,
-                feature_table_map=self._feature_table_map,
-                row_alignment=self._tbe_row_alignment,
-                uvm_host_mapped=True,  # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
-                feature_names_per_table=[
-                    table.feature_names for table in config.embedding_tables
-                ],
-                **(tbe_fused_params(fused_params) or {}),
-            )
-        )
+        embedding_clazz = (
+            KVEmbeddingInference
+            if config.is_using_virtual_table()
+            else IntNBitTableBatchedEmbeddingBagsCodegen
+        )
+        self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = embedding_clazz(
+            embedding_specs=[
+                (
+                    table.name,
+                    local_rows,
+                    (
+                        local_cols
+                        if self._quant_state_dict_split_scale_bias
+                        else table.embedding_dim
+                    ),
+                    data_type_to_sparse_type(table.data_type),
+                    location,
+                )
+                for local_rows, local_cols, table, location in zip(
+                    self._local_rows,
+                    self._local_cols,
+                    config.embedding_tables,
+                    managed,
+                )
+            ],
+            device=device,
+            pooling_mode=PoolingMode.NONE,
+            feature_table_map=self._feature_table_map,
+            row_alignment=self._tbe_row_alignment,
+            uvm_host_mapped=True,  # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
+            feature_names_per_table=[
+                table.feature_names for table in config.embedding_tables
+            ],
+            **(tbe_fused_params(fused_params) or {}),
+        )
         if device is not None:
             self._emb_module.initialize_weights()
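Both hunks in this file follow the same pattern: choose the TBE class once, based on whether the config uses a virtual table, then construct it with otherwise unchanged arguments. A minimal sketch of that selection with stand-in classes (the real ones come from fbgemm_gpu):

class IntNBitTableBatchedEmbeddingBagsCodegen:  # stand-in for the regular quantized TBE
    def __init__(self, **kwargs):
        self.kwargs = kwargs

class KVEmbeddingInference:  # stand-in for the KV (virtual-table) inference operator
    def __init__(self, **kwargs):
        self.kwargs = kwargs

def build_emb_module(is_virtual_table: bool, **tbe_kwargs):
    # Only the class choice depends on the flag; the constructor arguments are shared.
    embedding_clazz = (
        KVEmbeddingInference
        if is_virtual_table
        else IntNBitTableBatchedEmbeddingBagsCodegen
    )
    return embedding_clazz(**tbe_kwargs)

print(type(build_emb_module(True, embedding_specs=[])).__name__)   # KVEmbeddingInference
print(type(build_emb_module(False, embedding_specs=[])).__name__)  # IntNBitTableBatchedEmbeddingBagsCodegen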
26 changes: 16 additions & 10 deletions torchrec/quant/embedding_modules.py
@@ -30,6 +30,7 @@
     IntNBitTableBatchedEmbeddingBagsCodegen,
     PoolingMode,
 )
+from fbgemm_gpu.tbe.cache.kv_embedding_ops_inference import KVEmbeddingInference
 from torch import Tensor
 from torchrec.distributed.utils import none_throws
 from torchrec.modules.embedding_configs import (
@@ -357,7 +358,7 @@ def __init__(
         self._is_weighted = is_weighted
         self._embedding_bag_configs: List[EmbeddingBagConfig] = tables
         self._key_to_tables: Dict[
-            Tuple[PoolingType, DataType, bool], List[EmbeddingBagConfig]
+            Tuple[PoolingType, bool], List[EmbeddingBagConfig]
         ] = defaultdict(list)
         self._feature_names: List[str] = []
         self._feature_splits: List[int] = []
@@ -383,15 +384,13 @@ def __init__(
                 key = (table.pooling, table.use_virtual_table)
             else:
                 key = (table.pooling, False)
-            # pyre-ignore
             self._key_to_tables[key].append(table)
 
         location = (
             EmbeddingLocation.HOST if device.type == "cpu" else EmbeddingLocation.DEVICE
         )
 
-        for key, emb_configs in self._key_to_tables.items():
-            pooling = key[0]
+        for (pooling, use_virtual_table), emb_configs in self._key_to_tables.items():
             embedding_specs = []
             weight_lists: Optional[
                 List[Tuple[torch.Tensor, Optional[torch.Tensor]]]
@@ -420,7 +419,12 @@ def __init__(
                 )
                 feature_table_map.extend([idx] * table.num_features())
 
-            emb_module = IntNBitTableBatchedEmbeddingBagsCodegen(
+            embedding_clazz = (
+                KVEmbeddingInference
+                if use_virtual_table
+                else IntNBitTableBatchedEmbeddingBagsCodegen
+            )
+            emb_module = embedding_clazz(
                 embedding_specs=embedding_specs,
                 pooling_mode=pooling_type_to_pooling_mode(pooling),
                 weight_lists=weight_lists,
@@ -790,8 +794,7 @@ def __init__( # noqa C901
                 key = (table.data_type, False)
             self._key_to_tables[key].append(table)
         self._feature_splits: List[int] = []
-        for key, emb_configs in self._key_to_tables.items():
-            data_type = key[0]
+        for (data_type, use_virtual_table), emb_configs in self._key_to_tables.items():
             embedding_specs = []
             weight_lists: Optional[
                 List[Tuple[torch.Tensor, Optional[torch.Tensor]]]
@@ -816,10 +819,13 @@ def __init__( # noqa C901
                     table_name_to_quantized_weights[table.name]
                 )
                 feature_table_map.extend([idx] * table.num_features())
+                # move to here to make sure feature_names order is consistent with the embedding groups
+                self._feature_names.extend(table.feature_names)
 
-            emb_module = IntNBitTableBatchedEmbeddingBagsCodegen(
+            embedding_clazz = (
+                KVEmbeddingInference
+                if use_virtual_table
+                else IntNBitTableBatchedEmbeddingBagsCodegen
+            )
+            emb_module = embedding_clazz(
                 embedding_specs=embedding_specs,
                 pooling_mode=PoolingMode.NONE,
                 weight_lists=weight_lists,
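The net effect of the changes in this file is that tables are grouped by (pooling, use_virtual_table) for the pooled collection and by (data_type, use_virtual_table) for the non-pooled collection, so virtual-table and regular tables end up in separate TBE modules. A minimal sketch of that grouping, with a hypothetical Table stand-in for the real config classes:

from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Tuple

@dataclass
class Table:  # hypothetical stand-in for EmbeddingBagConfig
    name: str
    pooling: str  # e.g. "SUM" or "MEAN"
    use_virtual_table: bool = False

def group_tables(tables: List[Table]) -> Dict[Tuple[str, bool], List[Table]]:
    key_to_tables: Dict[Tuple[str, bool], List[Table]] = defaultdict(list)
    for table in tables:
        # Virtual-table and regular tables never share a key, so they never share a TBE module.
        key_to_tables[(table.pooling, table.use_virtual_table)].append(table)
    return key_to_tables

groups = group_tables([
    Table("t0", "SUM"),
    Table("t1", "SUM", use_virtual_table=True),
    Table("t2", "MEAN"),
])
print({k: [t.name for t in v] for k, v in groups.items()})
# {('SUM', False): ['t0'], ('SUM', True): ['t1'], ('MEAN', False): ['t2']}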