diff --git a/examples/zch/Readme.md b/examples/zch/Readme.md new file mode 100644 index 000000000..371034bb0 --- /dev/null +++ b/examples/zch/Readme.md @@ -0,0 +1,74 @@ +# Managed Collision Hash Example + +This example demonstrates the usage of the managed collision hash feature in TorchRec, which is designed to efficiently handle hash collisions in embedding tables. We include two implementations of the feature: sorted managed collision hash (MCH) and multi-probe zero collision hash (MPZCH). + +## Folder Structure + +``` +zch/ +├── Readme.md # This documentation file +├── __init__.py # Python package marker +├── docs/ # Data-flow diagram for the MPZCH module +├── main.py # Main script to run the benchmark +├── sparse_arch.py # Implementation of the sparse architecture with managed collision +└── zero_collision_hash_tutorial.ipynb # Jupyter notebook on the motivation for zero collision hash and the use of zero collision hash modules in TorchRec +``` + +### Introduction of MPZCH + +Multi-Probe Zero Collision Hash (MPZCH) is a technique that can be used to reduce the collision rate of embedding table lookups. For the concept of hash collisions and why we need to manage them, please refer to the [zero collision hash tutorial](zero_collision_hash_tutorial.ipynb). + +An MPZCH module contains two essential tables: the identity table and the metadata table. +The identity table records the mapping from input hash value to remapped ID. The value in each identity table slot is an input hash value, and that hash value's remapped ID is the index of the slot. +The metadata table shares the same length as the identity table. The time when a hash value is inserted into an identity table slot is recorded in the same-indexed slot of the metadata table. + +Specifically, MPZCH includes the following two steps: +1. **First Probe**: Check if there are available or evictable slots in its identity table. +2. **Second Probe**: Check if the slot indexed by the input hash value is occupied. If not, directly insert the input hash value into that slot. Otherwise, perform a linear probe to find the next available slot. If all the slots are occupied, find the next evictable slot whose value has stayed in the table longer than a threshold, and replace the expired hash value with the input one. A simplified sketch of this probe-and-evict procedure is shown at the beginning of the Usage section below. + +The use of the MPZCH module `HashZchManagedCollisionModule` is introduced with detailed comments in the [sparse_arch.py](sparse_arch.py) file. + +The module can be configured to use different eviction policies and parameters. + +The detailed function calls are shown in the diagram below. +![MPZCH Module Data Flow](docs/mpzch_module_dataflow.png) + +#### Relationship among Important Parameters + +The `HashZchManagedCollisionModule` module has the following important parameters for initialization: +- `num_embeddings`: the number of embeddings in the embedding table +- `num_buckets`: the number of buckets in the hash table + +The `num_buckets` is used as the minimal sharding unit for the embedding table. Because MPZCH uses linear probing, when resharding the embedding table we want to avoid separating the hash value of an input feature ID and its remapped index onto different ranks. So we make sure they stay in the same bucket and move whole buckets during resharding. + +## Usage +We also provide a profiling example of a SparseArch module implemented with different ZCH techniques.
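Before running it, the sketch below gives a rough, purely illustrative picture of the probe-and-evict procedure described in the introduction. It is a minimal sketch under simplifying assumptions (a single flat Python list per table, a `max_probes` limit, and a wall-clock TTL in hours are all invented for the example); it is not the actual `HashZchManagedCollisionModule` implementation, which operates on tensors and shards by bucket.

```python
import time

def mpzch_insert(identity, metadata, hash_value, max_probes=128, ttl_hours=1):
    """Return a remapped index for hash_value, inserting or evicting if needed."""
    size = len(identity)
    now = time.time()
    start = hash_value % size
    evictable = None
    for i in range(max_probes):
        slot = (start + i) % size
        if identity[slot] == hash_value:  # value already mapped: reuse its slot
            return slot
        if identity[slot] is None:  # empty slot: insert the value here
            identity[slot], metadata[slot] = hash_value, now
            return slot
        if evictable is None and now - metadata[slot] > ttl_hours * 3600:
            evictable = slot  # remember the first expired entry we probed
    if evictable is not None:  # every probed slot is occupied: evict an expired one
        identity[evictable], metadata[evictable] = hash_value, now
        return evictable
    return start  # nothing free or expired within the probe range: accept a collision

# Toy usage: a table of 1000 slots, all initially empty.
identity, metadata = [None] * 1000, [0.0] * 1000
print(mpzch_insert(identity, metadata, hash_value=123456789))  # -> 789
```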
+To run the profiling example with sorted ZCH: + +```bash +python main.py +``` + +To run the profiling example with MPZCH: + +```bash +python main.py --use_mpzch +``` + +You can also specify the `batch_size`, `num_iters`, and `num_embeddings_per_table`: +```bash +python main.py --use_mpzch --batch_size 8 --num_iters 100 --num_embeddings_per_table 1000 +``` + +The example allows you to compare the QPS of embedding operations with sorted ZCH and MPZCH. On our server with an A100 GPU, the initial QPS benchmark results with `batch_size=8`, `num_iters=100`, and `num_embeddings_per_table=1000` are presented in the table below: + +| ZCH module | QPS | +| --- | --- | +| sorted ZCH | 1371.6942797862002 | +| MPZCH | 2750.4449443587414 | + +With `batch_size=1024`, `num_iters=1000`, and `num_embeddings_per_table=1000`, the results are: + +| ZCH module | QPS | +| --- | --- | +| sorted ZCH | 263827.54955056956 | +| MPZCH | 551306.9687760604 | diff --git a/examples/zch/__init__.py b/examples/zch/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/zch/docs/mpzch_module_dataflow.png b/examples/zch/docs/mpzch_module_dataflow.png new file mode 100644 index 000000000..8ff4ba9e4 Binary files /dev/null and b/examples/zch/docs/mpzch_module_dataflow.png differ diff --git a/examples/zch/main.py b/examples/zch/main.py new file mode 100644 index 000000000..18b114b3b --- /dev/null +++ b/examples/zch/main.py @@ -0,0 +1,131 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +# pyre-strict + +import argparse +import time + +import torch + +from torchrec import EmbeddingConfig, KeyedJaggedTensor +from torchrec.distributed.benchmark.benchmark_utils import get_inputs +from tqdm import tqdm + +from .sparse_arch import SparseArch + + +def main(args: argparse.Namespace) -> None: + """ + This function benchmarks the performance of the SparseArch module with or without the MPZCH feature.
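+    On GPU the forward passes are timed with CUDA events, and on CPU with wall-clock time; QPS is then computed as (num_iters * batch_size) / total forward-pass time in seconds.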
+ Arguments: + use_mpzch: bool, whether to enable MPZCH or not + Prints: + duration: time for a forward pass of the Sparse module with or without MPZCH enabled + collision_rate: the collision rate of the MPZCH feature + """ + print(f"Is use MPZCH: {args.use_mpzch}") + + # check available devices + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + # device = torch.device("cpu") + + print(f"Using device: {device}") + + # create an embedding configuration + embedding_config = [ + EmbeddingConfig( + name="table_0", + feature_names=["feature_0"], + embedding_dim=8, + num_embeddings=args.num_embeddings_per_table, + ), + EmbeddingConfig( + name="table_1", + feature_names=["feature_1"], + embedding_dim=8, + num_embeddings=args.num_embeddings_per_table, + ), + ] + + # generate kjt input list + input_kjt_list = [] + for _ in range(args.num_iters): + input_kjt_single = KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + # pick a set of 24 random numbers from 0 to args.num_embeddings_per_table + values=torch.LongTensor( + list( + torch.randint( + 0, args.num_embeddings_per_table, (3 * args.batch_size,) + ) + ) + ), + lengths=torch.LongTensor([1] * args.batch_size + [2] * args.batch_size), + weights=None, + ) + input_kjt_single = input_kjt_single.to(device) + input_kjt_list.append(input_kjt_single) + + num_requests = args.num_iters * args.batch_size + + # make the model + model = SparseArch( + tables=embedding_config, + device=device, + return_remapped=True, + use_mpzch=args.use_mpzch, + buckets=1, + ) + + # do the forward pass + if device.type == "cuda": + torch.cuda.synchronize() + starter = torch.cuda.Event(enable_timing=True) + ender = torch.cuda.Event(enable_timing=True) + + # record the start time + starter.record() + for it_idx in tqdm(range(args.num_iters)): + # ec_out, remapped_ids_out = model(input_kjt_single) + input_kjt = input_kjt_list[it_idx].to(device) + ec_out, remapped_ids_out = model(input_kjt) + # record the end time + ender.record() + # wait for the end time to be recorded + torch.cuda.synchronize() + duration = starter.elapsed_time(ender) / 1000.0 # convert to seconds + else: + # in cpu mode, MPZCH can only run in inference mode, so we profile the model in eval mode + model.eval() + if args.use_mpzch: + # when using MPZCH modules, we need to manually set the modules to be in inference mode + # pyre-ignore + model._mc_ec._managed_collision_collection._managed_collision_modules[ + "table_0" + ].reset_inference_mode() + # pyre-ignore + model._mc_ec._managed_collision_collection._managed_collision_modules[ + "table_1" + ].reset_inference_mode() + + start_time = time.time() + for it_idx in tqdm(range(args.num_iters)): + input_kjt = input_kjt_list[it_idx].to(device) + ec_out, remapped_ids_out = model(input_kjt) + end_time = time.time() + duration = end_time - start_time + # get qps + qps = num_requests / duration + print(f"qps: {qps}") + # print the duration + print(f"duration: {duration} seconds") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--use_mpzch", action="store_true", default=False) + parser.add_argument("--num_iters", type=int, default=100) + parser.add_argument("--batch_size", type=int, default=8) + parser.add_argument("--num_embeddings_per_table", type=int, default=1000) + args: argparse.Namespace = parser.parse_args() + main(args) diff --git a/examples/zch/sparse_arch.py b/examples/zch/sparse_arch.py new file mode 100644 index 000000000..b8be4abaa --- /dev/null +++ 
b/examples/zch/sparse_arch.py @@ -0,0 +1,137 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +# pyre-strict + +from typing import Dict, List, Optional, Tuple, Union + +import torch +from torch import nn + +from torchrec import ( + EmbeddingCollection, + EmbeddingConfig, + JaggedTensor, + KeyedJaggedTensor, + KeyedTensor, +) + +# For MPZCH +from torchrec.modules.hash_mc_evictions import ( + HashZchEvictionConfig, + HashZchEvictionPolicyName, +) + +# For MPZCH +from torchrec.modules.hash_mc_modules import HashZchManagedCollisionModule +from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection + +# For original MC +from torchrec.modules.mc_modules import ( + DistanceLFU_EvictionPolicy, + ManagedCollisionCollection, + MCHManagedCollisionModule, +) + +""" +Class SparseArch +An example of SparseArch with 2 tables, each with 2 features. +It looks up the corresponding embedding for incoming KeyedJaggedTensors with 2 features +and returns the corresponding embeddings. + +Parameters: + tables(List[EmbeddingConfig]): List of EmbeddingConfig that defines the embedding table + device(torch.device): device on which the embedding table should be placed + buckets(int): number of buckets for each table + input_hash_size(int): input hash size for each table + return_remapped(bool): whether to return remapped features, if so, the return will be + a tuple of (Embedding(KeyedTensor), Remapped_ID(KeyedJaggedTensor)), otherwise, the return will be + a tuple of (Embedding(KeyedTensor), None) + is_inference(bool): whether to use inference mode. In inference mode, the module will not update the embedding table + use_mpzch(bool): whether to use MPZCH or not. If true, the module will use MPZCH managed collision module, + otherwise, it will use original MC managed collision module +""" + + +class SparseArch(nn.Module): + def __init__( + self, + tables: List[EmbeddingConfig], + device: torch.device, + buckets: int = 4, + input_hash_size: int = 4000, + return_remapped: bool = False, + is_inference: bool = False, + use_mpzch: bool = False, + ) -> None: + super().__init__() + self._return_remapped = return_remapped + + mc_modules = {} + + if ( + use_mpzch + ): # if using the MPZCH module, we create a HashZchManagedCollisionModule for each table + mc_modules["table_0"] = HashZchManagedCollisionModule( + is_inference=is_inference, + zch_size=( + tables[0].num_embeddings + ), # the zch size, that is, the size of local embedding table, should be the same as the size of the embedding table + input_hash_size=input_hash_size, # the input hash size, that is, the size of the input id space + device=device, # the device on which the embedding table should be placed + total_num_buckets=buckets, # the number of buckets, the detailed explanation of the use of buckets can be found in the readme file + eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION, # the eviction policy name, in this example use the single ttl eviction policy, which assume an id is evictable if it has been in the table longer than the ttl (time to live) + eviction_config=HashZchEvictionConfig( # Here we need to specify for each feature, what is the ttl, that is, how long an id can stay in the table before it is evictable + features=[ + "feature_0" + ], # because we only have one feature "feature_0" in this table, so we only need to specify the ttl for this feature + single_ttl=1, # The unit of ttl is hour. 
Let's set the ttl to be default to 1, which means an id is evictable if it has been in the table for more than one hour. + ), + ) + mc_modules["table_1"] = HashZchManagedCollisionModule( + is_inference=is_inference, + zch_size=(tables[1].num_embeddings), + device=device, + input_hash_size=input_hash_size, + total_num_buckets=buckets, + eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION, + eviction_config=HashZchEvictionConfig( + features=["feature_1"], + single_ttl=1, + ), + ) + else: # if not using the MPZCH module, we create a MCHManagedCollisionModule for each table + mc_modules["table_0"] = MCHManagedCollisionModule( + zch_size=(tables[0].num_embeddings), + input_hash_size=input_hash_size, + device=device, + eviction_interval=2, + eviction_policy=DistanceLFU_EvictionPolicy(), + ) + mc_modules["table_1"] = MCHManagedCollisionModule( + zch_size=(tables[1].num_embeddings), + device=device, + input_hash_size=input_hash_size, + eviction_interval=1, + eviction_policy=DistanceLFU_EvictionPolicy(), + ) + + self._mc_ec: ManagedCollisionEmbeddingCollection = ( + ManagedCollisionEmbeddingCollection( + EmbeddingCollection( + tables=tables, + device=device, + ), + ManagedCollisionCollection( + managed_collision_modules=mc_modules, + embedding_configs=tables, + ), + return_remapped_features=self._return_remapped, + ) + ) + + def forward( + self, kjt: KeyedJaggedTensor + ) -> Tuple[ + Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor] + ]: + return self._mc_ec(kjt) diff --git a/examples/zch/zero_collision_hash_tutorial.ipynb b/examples/zch/zero_collision_hash_tutorial.ipynb new file mode 100644 index 000000000..547601897 --- /dev/null +++ b/examples/zch/zero_collision_hash_tutorial.ipynb @@ -0,0 +1,511 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Zero-collision Hash Tutorial\n", + "This example notebook goes through the following topics:\n", + "- Why do we need zero-collision hash?\n", + "- How to use the zero-collision module in TorchRec?\n", + "\n", + "## Pre-requisite\n", + "Before dive into the details, let's import all the necessary packages first. This needs you to [have the latest `torchrec` library installed](https://docs.pytorch.org/torchrec/setup-torchrec.html#installation)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "output": { + "id": 1181435817001907, + "loadingStatus": "loaded" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "I0611 161033.883 _utils_internal.py:282] NCCL_DEBUG env var is set to None\n", + "I0611 161033.885 _utils_internal.py:291] NCCL_DEBUG is WARN from /etc/nccl.conf\n", + "I0611 161039.736 pyper_torch_elastic_logging_utils.py:234] initialized PyperTorchElasticEventHandler\n" + ] + } + ], + "source": [ + "import torch\n", + "from torch import nn\n", + "from torchrec import (\n", + " EmbeddingCollection,\n", + " EmbeddingConfig,\n", + " JaggedTensor,\n", + " KeyedJaggedTensor,\n", + " KeyedTensor,\n", + ")\n", + "\n", + "from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection\n", + "\n", + "from torchrec.modules.mc_modules import (\n", + " DistanceLFU_EvictionPolicy,\n", + " ManagedCollisionCollection,\n", + " MCHManagedCollisionModule,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Hash and Zero Collision Hash\n", + "In this section, we present the motivation that\n", + "- Why do we need to perform hash on incoming features?\n", + "- Why do we need to implement zero-collision hash?\n", + "\n", + "Let's first take a look in the question that why do we need to perform hashing for sparse feature inputs in the recommendation model? \n", + "We firstly create an embedding table of 1000 embeddings." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# define the number of embeddings\n", + "num_embeddings = 1000\n", + "table_config = EmbeddingConfig(\n", + " name=\"t1\",\n", + " embedding_dim=16,\n", + " num_embeddings=1000,\n", + " feature_names=[\"f1\"],\n", + ")\n", + "ec = EmbeddingCollection(tables=[table_config])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Usually, for each input sparse feature ID, we regard it as the index of the embedding in the embedding table, and fetch the embedding at the corresponding slot in the embedding table. However, while embedding tables is fixed when instantiating the models, the number of sparse features, such as tags of videos, can keep growing. After a while, the ID of a sparse feature can be larger the size of our embedding table." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "feature_id = num_embeddings + 1\n", + "input_kjt = KeyedJaggedTensor.from_lengths_sync(\n", + " keys=[\"f1\"],\n", + " values=torch.tensor([feature_id]),\n", + " lengths=torch.tensor([1]),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "At that point, the query will lead to an `index out of range` error." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "output": { + "id": 1225052112737471, + "loadingStatus": "loaded" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Query the embedding table of size 1000 with sparse feature ID tensor([1001])\n", + "This query throws an IndexError: index out of range in self\n" + ] + } + ], + "source": [ + "try:\n", + " feature_embedding = ec(input_kjt)\n", + "except IndexError as e:\n", + " print(f\"Query the embedding table of size {num_embeddings} with sparse feature ID {input_kjt['f1'].values()}\")\n", + " print(f\"This query throws an IndexError: {e}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To avoid this error from happening, we hash the sparse feature ID to a value within the range of the embedding table size, and use the hashed value as the feature ID to query the embedding table. \n", + "\n", + "For the purpose of demonstration, we use Python's built-in hash function to hash an integer (which will not change the value) and remap it to the range of `[0, num_embeddings)` by taking the modulo of `num_embeddings`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def remap(input_jt_value: int, num_embeddings: int):\n", + " input_hash = hash(input_jt_value)\n", + " return input_hash % num_embeddings" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can query the embedding table with the remapped id without error." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "output": { + "id": 990950286247121, + "loadingStatus": "loaded" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Query the embedding table of size 1000 with remapped sparse feature ID 1 from original ID 1001\n", + "This query does not throw an IndexError, and returns the embedding of the remapped ID: {'f1': }\n" + ] + } + ], + "source": [ + "remapped_id = remap(feature_id, num_embeddings)\n", + "remapped_kjt = KeyedJaggedTensor.from_lengths_sync(\n", + " keys=[\"f1\"],\n", + " values=torch.tensor([remapped_id]),\n", + " lengths=torch.tensor([1]),\n", + ")\n", + "feature_embedding = ec(remapped_kjt)\n", + "print(f\"Query the embedding table of size {num_embeddings} with remapped sparse feature ID {remapped_id} from original ID {feature_id}\")\n", + "print(f\"This query does not throw an IndexError, and returns the embedding of the remapped ID: {feature_embedding}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After answering the first question: __Why do we need to perform hash on incoming features?__, now we can answer the second question: __Why do we need to implement zero-collision hash?__\n", + "\n", + "Because we are casting a larger range of values into a small range, there will be some values being remapped to the same index. For example, using our `remap` function, it will give the same remapped id for feature `num_embeddings + 1` and `1`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "output": { + "id": 1024965419837378, + "loadingStatus": "loaded" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "feature ID 1 is remapped to ID 1\n", + "feature ID 1001 is remapped to ID 1\n", + "Check if remapped feature ID 1 and 1 are the same: True\n" + ] + } + ], + "source": [ + "feature_id_1 = 1\n", + "feature_id_2 = num_embeddings + 1\n", + "remapped_feature_id_1 = remap(feature_id_1, num_embeddings)\n", + "remapped_feature_id_2 = remap(feature_id_2, num_embeddings)\n", + "print(f\"feature ID {feature_id_1} is remapped to ID {remapped_feature_id_1}\")\n", + "print(f\"feature ID {feature_id_2} is remapped to ID {remapped_feature_id_2}\")\n", + "print(f\"Check if remapped feature ID {remapped_feature_id_1} and {remapped_feature_id_2} are the same: {remapped_feature_id_1 == remapped_feature_id_2}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this case, two totally different features can share the same embedding. The situation when two feature IDs share the same remapped ID is called a **hash collision**." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "output": { + "id": 923188026604331, + "loadingStatus": "loaded" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Embedding of feature ID 1 is tensor([ 0.0232, 0.0075, 0.0281, -0.0195, -0.0301, 0.0033, 0.0303, 0.0294,\n", + " 0.0301, -0.0287, -0.0130, -0.0194, 0.0263, 0.0287, 0.0261, -0.0080],\n", + " grad_fn=)\n", + "Embedding of feature ID 1 is tensor([ 0.0232, 0.0075, 0.0281, -0.0195, -0.0301, 0.0033, 0.0303, 0.0294,\n", + " 0.0301, -0.0287, -0.0130, -0.0194, 0.0263, 0.0287, 0.0261, -0.0080],\n", + " grad_fn=)\n", + "Check if the embeddings of feature ID 1 and 1 are the same: True\n" + ] + } + ], + "source": [ + "input_kjt = KeyedJaggedTensor.from_lengths_sync(\n", + " keys=[\"f1\"],\n", + " values=torch.tensor([remapped_feature_id_1, remapped_feature_id_2]),\n", + " lengths=torch.tensor([1, 1]),\n", + ")\n", + "feature_embeddings = ec(input_kjt)\n", + "feature_id_1_embedding = feature_embeddings[\"f1\"].values()[0]\n", + "feature_id_2_embedding = feature_embeddings[\"f1\"].values()[1]\n", + "print(f\"Embedding of feature ID {remapped_feature_id_1} is {feature_id_1_embedding}\")\n", + "print(f\"Embedding of feature ID {remapped_feature_id_2} is {feature_id_2_embedding}\")\n", + "print(f\"Check if the embeddings of feature ID {remapped_feature_id_1} and {remapped_feature_id_2} are the same: {torch.equal(feature_id_1_embedding, feature_id_2_embedding)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "Making two different (and potentially totally irrelavant) features share the same embedding will cause inaccurate recommendations.\n", + "Lukily, for many sparse features, though their range can be larger than the the embedding table size, their IDs are sparsely located on the range.\n", + "In some other cases, the embedding table may only receive frequent queries for a subset of the features.\n", + "So we can design some __managed collision hash__ modules to avoid the hash collision from happening." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## TorchRec Zero Collision Hash Modules\n", + "\n", + "TorchRec implements managed collision hash strategies such as *sorted zero collision hash* and *multi-probe zero collision hash (MPZCH)*.\n", + "\n", + "They help hash and remap the feature IDs to embedding table indices with (near-)zero collisions.\n", + "\n", + "In the following content we will use the MPZCH module as an example for how to use the zero-collision modules in TorchRec. The name of the MPZCH module is `HashZchManagedCollisionModule`.\n", + "\n", + "Let's assume we have two tables: `table_0` and `table_1`, each with embeddings for `feature_0` and `feature_1`, respectively." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# define the table sizes\n", + "num_embeddings_table_0 = 1000\n", + "num_embeddings_table_1 = 2000\n", + "\n", + "# create table configs\n", + "table_0_config = EmbeddingConfig(\n", + " name=\"table_0\",\n", + " embedding_dim=16,\n", + " num_embeddings=num_embeddings_table_0,\n", + " feature_names=[\"feature_0\"],\n", + ")\n", + "\n", + "table_1_config = EmbeddingConfig(\n", + " name=\"table_1\",\n", + " embedding_dim=16,\n", + " num_embeddings=num_embeddings_table_1,\n", + " feature_names=[\"feature_1\"],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Before turning the table configurations into embedding table collection, we instantiate our managed collision modules.\n", + "\n", + "The managed collision modules for a collection of embedding tables are intended to format as a dictionary with `{table_name: mc_module_for_the_table}`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mc_modules = {}\n", + "\n", + "# Instantiate the module, we provide detailed comments on\n", + "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') #\n", + "input_hash_size = 10000\n", + "mc_modules[\"table_0\"] = MCHManagedCollisionModule(\n", + " zch_size=(table_0_config.num_embeddings),\n", + " input_hash_size=input_hash_size,\n", + " device=device,\n", + " eviction_interval=2,\n", + " eviction_policy=DistanceLFU_EvictionPolicy(),\n", + " )\n", + "mc_modules[\"table_1\"] = MCHManagedCollisionModule(\n", + " zch_size=(table_1_config.num_embeddings),\n", + " device=device,\n", + " input_hash_size=input_hash_size,\n", + " eviction_interval=1,\n", + " eviction_policy=DistanceLFU_EvictionPolicy(),\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For embedding tables with managed collision modules, TorchRec uses a wrapper module `ManagedCollisionEmbeddingCollection` that contains both the embedding table collections and the managed collision modules. 
Users only need to pass their table configurations and" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mc_ec = ManagedCollisionEmbeddingCollection = (\n", + " ManagedCollisionEmbeddingCollection(\n", + " EmbeddingCollection(\n", + " tables=[\n", + " table_0_config,\n", + " table_1_config\n", + " ],\n", + " device=device,\n", + " ),\n", + " ManagedCollisionCollection(\n", + " managed_collision_modules=mc_modules,\n", + " embedding_configs=[\n", + " table_0_config,\n", + " table_1_config\n", + " ],\n", + " ),\n", + " return_remapped_features=True, # whether to return the remapped feature IDs\n", + " )\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `ManagedCollisionEmbeddingCollection` module will perform remapping and table look-up for the input. Users only need to pass the keyyed jagged tensor queries into the module." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "output": { + "id": 1363945501556497, + "loadingStatus": "loaded" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "feature name: feature_0, feature jt: JaggedTensor({\n", + " [[1000], [10001]]\n", + "})\n", + "\n", + "feature jt values: tensor([ 1000, 10001])\n", + "feature name: feature_1, feature jt: JaggedTensor({\n", + " [[2000], [20001]]\n", + "})\n", + "\n", + "feature jt values: tensor([ 2000, 20001])\n" + ] + } + ], + "source": [ + "input_kjt = KeyedJaggedTensor.from_lengths_sync(\n", + " keys=[\"feature_0\", \"feature_1\"],\n", + " values=torch.tensor([1000, 10001, 2000, 20001]),\n", + " lengths=torch.tensor([1, 1, 1, 1]),\n", + ")\n", + "for feature_name, feature_jt in input_kjt.to_dict().items():\n", + " print(f\"feature name: {feature_name}, feature jt: {feature_jt}\")\n", + " print(f\"feature jt values: {feature_jt.values()}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "output": { + "id": 1555795345806679, + "loadingStatus": "loaded" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "feature name: feature_0, feature embedding: JaggedTensor({\n", + " [[[0.022659072652459145, 0.0053002419881522655, -0.025007368996739388, -0.013145492412149906, -0.031139537692070007, -0.01486812811344862, -0.01133741531521082, 0.0027838051319122314, 0.026786740869283676, -0.010626785457134247, 0.01148480549454689, 0.02036162279546261, 0.013492186553776264, -0.024412740021944046, 0.01599711738526821, -0.02390478551387787]], [[-0.029269251972436905, 0.01744556427001953, 0.024260954931378365, 0.029459983110427856, -0.026435773819684982, -0.0034603318199515343, -0.007642757147550583, -0.02111411839723587, 0.027316255494952202, 0.015309474430978298, 0.03137263283133507, 0.01699884422123432, 0.02302604913711548, -0.015266639180481434, -0.019045181572437286, 0.006964980624616146]]]\n", + "})\n", + "\n", + "feature name: feature_1, feature embedding: JaggedTensor({\n", + " [[[0.009506281465291977, 0.012826820835471153, -0.0017535268561914563, -0.0009170559933409095, -0.014913717284798622, 0.0040654330514371395, -0.011355634778738022, 0.008443576283752918, 0.0007347835344262421, -0.00907053705304861, 0.010160156525671482, 0.016830360516905785, 0.002154064131900668, -0.010799579322338104, -0.0197420883923769, -0.0025849745143204927]], [[-0.020103629678487778, 0.01041398011147976, -0.01699216105043888, 0.01291638519614935, 0.018798867240548134, 0.01616138033568859, -0.020600538700819016, 
-0.017695769667625427, 0.0100017711520195, -0.010470695793628693, -0.018935278058052063, -0.011798662133514881, -0.014235826209187508, -0.01985463872551918, 0.009744714014232159, -0.004050525836646557]]]\n", + "})\n", + "\n", + "feature name: feature_0, feature jt values: tensor([997, 998], device='cuda:0')\n", + "feature name: feature_1, feature jt values: tensor([1997, 1998], device='cuda:0')\n" + ] + } + ], + "source": [ + "output_embeddings, remapped_ids = mc_ec(input_kjt.to(device))\n", + "# show output embeddings\n", + "for feature_name, feature_embedding in output_embeddings.items():\n", + " print(f\"feature name: {feature_name}, feature embedding: {feature_embedding}\")\n", + "# show remapped ids\n", + "for feature_name, feature_jt in remapped_ids.to_dict().items():\n", + " print(f\"feature name: {feature_name}, feature jt values: {feature_jt.values()}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we have a basic example of how to use the managed collision modules in TorchRec. \n", + "\n", + "We also provide a profiling example to compare the efficiency of sorted ZCH and MPZCH modules. Check the [Readme](Readme.md) file for more details." + ] + } + ], + "metadata": { + "fileHeader": "", + "fileUid": "ee3845a2-a85b-4a8e-8c42-9ce4690c9956", + "isAdHoc": false, + "kernelspec": { + "display_name": "torchrec", + "language": "python", + "name": "bento_kernel_torchrec" + }, + "language_info": { + "name": "plaintext" + }, + "orig_nbformat": 4 + } +} diff --git a/torchrec/distributed/benchmark/benchmark_zch_dlrmv2.py b/torchrec/distributed/benchmark/benchmark_zch_dlrmv2.py new file mode 100644 index 000000000..2cd41501c --- /dev/null +++ b/torchrec/distributed/benchmark/benchmark_zch_dlrmv2.py @@ -0,0 +1,936 @@ +import argparse +import csv +import json +import multiprocessing +import os +import sys +import time + +from typing import cast, Dict, Iterator, List, Optional + +import numpy as np + +import torch +import torch.nn as nn +import torchmetrics # @manual=fbsource//third-party/pypi/torchmetrics:torchmetrics +from pyre_extensions import none_throws +from torch import distributed as dist +from torch.utils.data import DataLoader + +from torchrec.datasets.criteo import ( + CAT_FEATURE_COUNT, + DAYS, + DEFAULT_CAT_NAMES, + DEFAULT_INT_NAMES, + InMemoryBinaryCriteoIterDataPipe, +) +from torchrec.distributed.comm import get_local_size +from torchrec.distributed.mc_modules import ManagedCollisionCollectionSharder + +from torchrec.distributed.model_parallel import ( + DistributedModelParallel, + get_default_sharders, +) +from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology + +from torchrec.distributed.planner.storage_reservations import ( + HeuristicalStorageReservation, +) + +from torchrec.distributed.types import ModuleSharder +from torchrec.models.dlrm import DLRM, DLRMTrain +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.modules.embedding_modules import EmbeddingBagCollection +from torchrec.modules.hash_mc_evictions import ( + HashZchEvictionConfig, + HashZchEvictionPolicyName, +) +from torchrec.modules.hash_mc_modules import HashZchManagedCollisionModule +from torchrec.modules.mc_adapter import McEmbeddingBagCollectionAdapter + +from torchrec.modules.mc_embedding_modules import ( + ManagedCollisionEmbeddingBagCollection, + ManagedCollisionEmbeddingCollection, +) + +from torchrec.modules.mc_modules import ( + DistanceLFU_EvictionPolicy, + ManagedCollisionCollection, + MCHManagedCollisionModule, +) 
+from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward + +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor +from torchrec.test_utils import get_free_port +from tqdm import tqdm + +from .benchmark_zch_utils import BenchmarkMCProbe + +from .data.dlrm_dataloader import get_dataloader + + +def parse_args(argv: List[str]) -> argparse.Namespace: + parser = argparse.ArgumentParser(description="torchrec dlrm example trainer") + parser.add_argument( + "--epochs", + type=int, + default=1, + help="number of epochs to train", + ) + parser.add_argument( + "--batch_size", + type=int, + default=4096, + help="batch size to use for training", + ) + parser.add_argument( + "--drop_last_training_batch", + dest="drop_last_training_batch", + action="store_true", + help="Drop the last non-full training batch", + ) + parser.add_argument( + "--test_batch_size", + type=int, + default=None, + help="batch size to use for validation and testing", + ) + parser.add_argument( + "--limit_train_batches", + type=int, + default=None, + help="number of train batches", + ) + parser.add_argument( + "--limit_val_batches", + type=int, + default=None, + help="number of validation batches", + ) + parser.add_argument( + "--limit_test_batches", + type=int, + default=None, + help="number of test batches", + ) + parser.add_argument( + "--dataset_name", + type=str, + choices=["criteo_1t", "criteo_kaggle"], + default="criteo_kaggle", + help="dataset for experiment, current support criteo_1tb, criteo_kaggle", + ) + parser.add_argument( + "--num_embeddings", # ratio of feature ids to embedding table size # 3 axis: x-bath_idx; y-collisions; zembedding table sizes + type=int, + default=100_000, + help="max_ind_size. The number of embeddings in each embedding table. Defaults" + " to 100_000 if num_embeddings_per_feature is not supplied.", + ) + parser.add_argument( + "--num_embeddings_per_feature", + type=str, + default=None, + help="Comma separated max_ind_size per sparse feature. The number of embeddings" + " in each embedding table. 26 values are expected for the Criteo dataset.", + ) + parser.add_argument( + "--dense_arch_layer_sizes", + type=str, + default="512,256,64", + help="Comma separated layer sizes for dense arch.", + ) + parser.add_argument( + "--over_arch_layer_sizes", + type=str, + default="512,512,256,1", + help="Comma separated layer sizes for over arch.", + ) + parser.add_argument( + "--embedding_dim", + type=int, + default=64, + help="Size of each embedding.", + ) + parser.add_argument( + "--interaction_branch1_layer_sizes", + type=str, + default="2048,2048", + help="Comma separated layer sizes for interaction branch1 (only on dlrm with projection).", + ) + parser.add_argument( + "--interaction_branch2_layer_sizes", + type=str, + default="2048,2048", + help="Comma separated layer sizes for interaction branch2 (only on dlrm with projection).", + ) + parser.add_argument( + "--dcn_num_layers", + type=int, + default=3, + help="Number of DCN layers in interaction layer (only on dlrm with DCN).", + ) + parser.add_argument( + "--dcn_low_rank_dim", + type=int, + default=512, + help="Low rank dimension for DCN in interaction layer (only on dlrm with DCN).", + ) + parser.add_argument( + "--undersampling_rate", + type=float, + help="Desired proportion of zero-labeled samples to retain (i.e. undersampling zero-labeled rows)." + " Ex. 0.3 indicates only 30pct of the rows with label 0 will be kept." + " All rows with label 1 will be kept. Value should be between 0 and 1." 
+ " When not supplied, no undersampling occurs.", + ) + parser.add_argument( + "--seed", + type=int, + help="Random seed for reproducibility.", + default=0, + ) + parser.add_argument( + "--pin_memory", + dest="pin_memory", + action="store_true", + help="Use pinned memory when loading data.", + ) + parser.add_argument( + "--mmap_mode", + dest="mmap_mode", + action="store_true", + help="--mmap_mode mmaps the dataset." + " That is, the dataset is kept on disk but is accessed as if it were in memory." + " --mmap_mode is intended mostly for faster debugging. Use --mmap_mode to bypass" + " preloading the dataset when preloading takes too long or when there is " + " insufficient memory available to load the full dataset.", + ) + parser.add_argument( + "--in_memory_binary_criteo_path", + type=str, + default=None, + help="Directory path containing the Criteo dataset npy files.", + ) + parser.add_argument( + "--synthetic_multi_hot_criteo_path", + type=str, + default=None, + help="Directory path containing the MLPerf v2 synthetic multi-hot dataset npz files.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=15.0, + help="Learning rate.", + ) + parser.add_argument( + "--eps", + type=float, + default=1e-8, + help="Epsilon for Adagrad optimizer.", + ) + parser.add_argument( + "--shuffle_batches", + dest="shuffle_batches", + action="store_true", + help="Shuffle each batch during training.", + ) + parser.add_argument( + "--shuffle_training_set", + dest="shuffle_training_set", + action="store_true", + help="Shuffle the training set in memory. This will override mmap_mode", + ) + parser.add_argument( + "--validation_freq_within_epoch", + type=int, + default=None, + help="Frequency at which validation will be run within an epoch.", + ) + parser.set_defaults( + pin_memory=None, + mmap_mode=None, + drop_last=None, + shuffle_batches=None, + shuffle_training_set=None, + ) + parser.add_argument( + "--adagrad", + dest="adagrad", + action="store_true", + help="Flag to determine if adagrad optimizer should be used.", + ) + parser.add_argument( + "--collect_multi_hot_freqs_stats", + dest="collect_multi_hot_freqs_stats", + action="store_true", + help="Flag to determine whether to collect stats on freq of embedding access.", + ) + parser.add_argument( + "--multi_hot_sizes", + type=str, + default=None, + help="Comma separated multihot size per sparse feature. 
26 values are expected for the Criteo dataset.", + ) + parser.add_argument( + "--multi_hot_distribution_type", + type=str, + choices=["uniform", "pareto"], + default=None, + help="Multi-hot distribution options.", + ) + parser.add_argument("--lr_warmup_steps", type=int, default=0) + parser.add_argument("--lr_decay_start", type=int, default=0) + parser.add_argument("--lr_decay_steps", type=int, default=0) + parser.add_argument( + "--print_lr", + action="store_true", + help="Print learning rate every iteration.", + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help="Enable TensorFloat-32 mode for matrix multiplications on A100 (or newer) GPUs.", + ) + parser.add_argument( + "--print_sharding_plan", + action="store_true", + help="Print the sharding plan used for each embedding table.", + ) + parser.add_argument( + "--input_hash_size", + type=int, + default=100_000, + help="Input feature value range", + ) + parser.add_argument( + "--num_buckets", + type=int, + default=4, + help="Number of buckets for identity table", + ) + parser.add_argument( + "--profiling_result_folder", + type=str, + default="profiling_result", + help="Folder to save profiling results", + ) + parser.add_argument( + "--use_zch", + action="store_true", + help="If use zch or not", + ) + return parser.parse_args(argv) + + +def hash_kjt( + sparse_features: KeyedJaggedTensor, num_embeddings: int +) -> KeyedJaggedTensor: + """ + convert the values in the input sparse_features to hashed ones in the range of [0, num_embeddings) + """ + hashed_feature_values_dict = {} # {feature_name: hashed_feature_values} + for feature_name, feature_values_jt in sparse_features.to_dict().items(): + hashed_feature_values_dict[feature_name] = [] + for feature_value in feature_values_jt.values(): + feature_value = feature_value.unsqueeze(0) # convert to [1, ] + feature_value = feature_value.to(torch.uint64) # convert to uint64 + hashed_feature_value = torch.ops.fbgemm.murmur_hash3(feature_value, 0, 0) + # convert to int64 + hashed_feature_value = hashed_feature_value.to( + torch.int64 + ) # convert to int64 + # convert to [0, num_embeddings) + hashed_feature_value = ( + hashed_feature_value % num_embeddings + ) # convert to [0, num_embeddings) + # convert to [1, ] + hashed_feature_value = hashed_feature_value.unsqueeze(0) # convert to [1, ] + hashed_feature_values_dict[feature_name].append(hashed_feature_value) + hashed_feature_values_dict[feature_name] = JaggedTensor.from_dense( + hashed_feature_values_dict[feature_name] + ) + # convert to [batch_size, ] + hashed_feature_kjt = KeyedJaggedTensor.from_jt_dict(hashed_feature_values_dict) + return hashed_feature_kjt + + +# def hash_kjt( +# sparse_features: KeyedJaggedTensor, num_embeddings: int +# ) -> KeyedJaggedTensor: +# """ +# convert the values in the input sparse_features to hashed ones in the range of [0, num_embeddings) +# """ +# hashed_feature_values_dict = {} # {feature_name: hashed_feature_values} +# for feature_name, feature_values_jt in sparse_features.to_dict().items(): +# hashed_feature_values_dict[feature_name] = [] +# feature_values = feature_values_jt.values() +# feature_value = feature_values.to(torch.uint64) # convert to uint64 +# hashed_feature_value = torch.ops.fbgemm.murmur_hash3(feature_value, 0, 0) +# # convert to int64 +# hashed_feature_value = hashed_feature_value.to(torch.int64) # convert to int64 +# # convert to [0, num_embeddings) +# hashed_feature_value = ( +# hashed_feature_value % num_embeddings +# ) # convert to [0, num_embeddings) +# # convert to 
[1, ] +# # hashed_feature_value = hashed_feature_value.unsqueeze(0) # convert to [1, ] +# # hashed_feature_values_dict[feature_name].append(hashed_feature_value) +# hashed_feature_values_dict[feature_name] = JaggedTensor.from_dense( +# hashed_feature_value +# ) +# # convert to [batch_size, ] +# hashed_feature_kjt = KeyedJaggedTensor.from_jt_dict(hashed_feature_values_dict) +# return hashed_feature_kjt + + +def main(rank: int, args: argparse.Namespace, queue: multiprocessing.Queue) -> None: + # seed everything for reproducibility + torch.manual_seed(args.seed) + np.random.seed(args.seed) + + # convert input hash size to num_embeddings if not using zch + if not args.use_zch: + args.input_hash_size = args.num_embeddings + + # setup environment + os.environ["RANK"] = str(rank) + if torch.cuda.is_available(): + device: torch.device = torch.device(f"cuda:{rank}") + backend = "nccl" + torch.cuda.set_device(device) + else: + device: torch.device = torch.device("cpu") + backend = "gloo" + dist.init_process_group(backend=backend, init_method="env://") + + # TEST FOR DATASET HASH + # train_dataloader = get_dataloader(args, backend, "train") + # for batch in train_dataloader: + # batch = batch.to(device) + # print("before hash", batch.sparse_features.to_dict()["cat_0"].values()[:5]) + # batch.sparse_features = hash_kjt(batch.sparse_features, args.num_embeddings) + # print("after hash", batch.sparse_features.to_dict()["cat_0"].values()[:5]) + # break + + # exit(0) + + # END TEST FOR DATASET HASH + + # get dataset + train_dataloader = get_dataloader(args, backend, "train") + val_dataloader = get_dataloader(args, backend, "val") + test_dataloader = get_dataloader(args, backend, "test") + + # create embedding configs + ebc_configs = [ + EmbeddingBagConfig( + name=f"t_{feature_name}", + embedding_dim=args.embedding_dim, + num_embeddings=( + none_throws(args.num_embeddings_per_feature)[feature_idx] + if args.num_embeddings is None + else args.num_embeddings + ), + feature_names=[feature_name], + ) + for feature_idx, feature_name in enumerate(DEFAULT_CAT_NAMES) + ] + + # get managed collision embedding bag collection + if args.use_zch: + ebc = ( + McEmbeddingBagCollectionAdapter( # TODO: add switch for other ZCH or no ZCH + tables=ebc_configs, + input_hash_size=args.input_hash_size, + device=torch.device("meta"), + world_size=get_local_size(), + use_mpzch=True, + mpzch_num_buckets=args.num_buckets, + ) + ) + else: + ebc = EmbeddingBagCollection(tables=ebc_configs, device=torch.device("meta")) + + # create model + dlrm_model = DLRM( + embedding_bag_collection=ebc, + dense_in_features=len(DEFAULT_INT_NAMES), + dense_arch_layer_sizes=[int(x) for x in args.dense_arch_layer_sizes.split(",")], + over_arch_layer_sizes=[int(x) for x in args.over_arch_layer_sizes.split(",")], + dense_device=device, + ) + + print(dlrm_model) + train_model = DLRMTrain(dlrm_model) + + # apply optimizer to the model + embedding_optimizer = torch.optim.Adagrad if args.adagrad else torch.optim.SGD + optimizer_kwargs = {"lr": args.learning_rate} + if args.adagrad: + optimizer_kwargs["eps"] = args.eps + apply_optimizer_in_backward( + embedding_optimizer, + train_model.model.sparse_arch.embedding_bag_collection.parameters(), + optimizer_kwargs, + ) + + # shard the model + planner = EmbeddingShardingPlanner( + topology=Topology( + local_world_size=get_local_size(), + world_size=dist.get_world_size(), + compute_device=device.type, + ), + batch_size=args.batch_size, + # If experience OOM, increase the percentage. 
see + # https://pytorch.org/torchrec/torchrec.distributed.planner.html#torchrec.distributed.planner.storage_reservations.HeuristicalStorageReservation + storage_reservation=HeuristicalStorageReservation(percentage=0.05), + ) + + sharders = get_default_sharders() + sharders.append(cast(ModuleSharder[nn.Module], ManagedCollisionCollectionSharder())) + + plan = planner.collective_plan(train_model, sharders, dist.GroupMember.WORLD) + + model = DistributedModelParallel( + module=train_model, + device=device, + plan=plan, + ) + + collision_remapping_tensor_dict = ( + {} + ) # feature_name: collision_tensor filled with all -1s with num_embedding size at the beginning, used only for non-zch case + benchmark_probe = None + if args.use_zch: + benchmark_probe = BenchmarkMCProbe( + mcec=model._dmp_wrapped_module.module.model.sparse_arch.embedding_bag_collection.mc_embedding_bag_collection._managed_collision_collection._managed_collision_modules, + mc_method="mpzch", + rank=rank, + ) + + interval_num_batches_show_qps = 50 + + total_time_in_training = 0 + total_num_queries_in_training = 0 + + # train the model + for epoch_idx in range(args.epochs): + model.train() + starter_list = [] + ender_list = [] + num_queries_per_batch_list = [] + pbar = tqdm(train_dataloader, desc=f"Epoch {epoch_idx}") + for batch_idx, batch in enumerate(pbar): + batch = batch.to(device) + starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event( + enable_timing=True + ) + if args.use_zch: + benchmark_probe.record_mcec_state(stage="before_fwd") + # forward pass + starter.record() + loss, outputs = model(batch) + ender.record() + loss.backward() + # do statistics + num_queries_per_batch = len(batch.labels) + starter_list.append(starter) + ender_list.append(ender) + num_queries_per_batch_list.append(num_queries_per_batch) + if args.use_zch: + benchmark_probe.record_mcec_state(stage="after_fwd") + # update zch statistics + benchmark_probe.update() + # push the zch stats to the queue + msg_content = { + "epoch_idx": epoch_idx, + "batch_idx": batch_idx, + "rank": rank, + "mch_stats": benchmark_probe.get_mch_stats(), + } + queue.put( + ("mch_stats", msg_content), + ) + if batch_idx % interval_num_batches_show_qps == 0 or batch_idx == len( + train_dataloader + ): + if batch_idx == 0: + # skip the first batch since it is not a full batch + continue + # synchronize all the threads to get the exact number of batches + torch.cuda.synchronize() + # calculate the qps + # NOTE: why do this qps calculation every interval_num_batches_show_qps batches? + # because performing this calculation needs to synchronize all the ranks by calling torch.cuda.synchronize() + # and this is a heavy operation (takes several milliseconds). So we only do this calculation every + # interval_num_batches_show_qps batches to reduce the overhead. 
+ ## get per batch time list by calculating the time difference between the start and end events of each batch + per_batch_time_list = [] + for i in range(interval_num_batches_show_qps): + per_batch_time_list.append( + starter_list[i].elapsed_time(ender_list[i]) / 1000 + ) # convert to seconds by dividing by 1000 + ## calculate the total time in the interval + total_time_in_interval = sum(per_batch_time_list) + ## calculate the total number of queries in the interval + total_num_queries_in_interval = sum(num_queries_per_batch_list) + ## fabricate the message and total_num_queries_in_interval to the queue + interval_start_batch_idx = ( + batch_idx - interval_num_batches_show_qps + ) # the start batch index of the interval + interval_end_batch_idx = ( + batch_idx # the end batch index of the interval + ) + ## fabricate the message content + msg_content = { + "epoch_idx": epoch_idx, + "rank": rank, + "interval_start_batch_idx": interval_start_batch_idx, + "interval_end_batch_idx": interval_end_batch_idx, + "per_batch_time_list": per_batch_time_list, + "per_batch_num_queries_list": num_queries_per_batch_list, + } + ## put the message into the queue + queue.put(("duration_and_num_queries", msg_content)) + qps_per_interval = ( + total_num_queries_in_interval / total_time_in_interval + ) + total_time_in_training += total_time_in_interval + total_num_queries_in_training += total_num_queries_in_interval + pbar.set_postfix( + { + "QPS": qps_per_interval, + } + ) + pbar.update(interval_num_batches_show_qps) + # reset the lists + starter_list = [] + ender_list = [] + num_queries_per_batch_list = [] + # if batch_idx > 2: + # time.sleep(5) + # queue.put(("finished", {"rank": rank})) + # print("finished") + # exit(0) + # after each epoch, do validation + eval_result_dict = evaluation(model, val_dataloader, device) + # print the evaluation result + print(f"Epoch {epoch_idx} validation result: {eval_result_dict}") + # send the evaluation result to the queue + msg_content = { + "epoch_idx": epoch_idx, + "rank": rank, + "eval_result_dict": eval_result_dict, + } + queue.put(("eval_result", msg_content)) + + time.sleep(30) + queue.put(("finished", {"rank": rank})) + print("finished") + return + + # print("Total time in training: ", total_time_in_training) + # print("Total number of queries in training: ", total_num_queries_in_training) + # print("Average QPS: ", total_num_queries_in_training / total_time_in_training) + + +def evaluation(model: DLRMTrain, data_loader: DataLoader, device: torch.device): + """ + Evaluate the model on the given data loader. 
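+    Returns a dict with AUROC and Normalized Entropy (NE). As a descriptive note on the computation below: NE is the mean per-example log loss divided by the log loss of a constant predictor that always outputs p, the average label over the data loader, i.e. NE = mean(logloss) / -(p*log(p) + (1-p)*log(1-p)).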
+ """ + model.eval() + auroc = torchmetrics.AUROC(task="multiclass", num_classes=2).to(device) + log_loss_list = [] + label_val_sums = 0 + num_labels = 0 + for batch in tqdm(data_loader): + batch = batch.to(device) + with torch.no_grad(): + loss, outputs = model(batch) + loss_val, logits, labels = outputs + preds = torch.sigmoid(logits) + preds_reshaped = torch.stack((1 - preds, preds), dim=1) + auroc.update(preds_reshaped, labels) + # calculate log loss + batch_log_loss_list = -( + (1 + labels) / 2 * torch.log(preds) + + (1 - labels) / 2 * torch.log(1 - preds) + ) + log_loss_list.extend(batch_log_loss_list.tolist()) + label_val_sums += labels.sum().item() + num_labels += labels.shape[0] + auroc_result = auroc.compute().item() + # calculate ne as mean(log_loss_list) / log_loss(avg_label) + avg_label = label_val_sums / num_labels + avg_label = torch.tensor(avg_label).to(device) + avg_label_log_loss = -( + avg_label * torch.log(avg_label) + (1 - avg_label) * torch.log(1 - avg_label) + ) + ne = torch.mean(torch.tensor(log_loss_list)).item() / avg_label_log_loss.item() + print(f"AUROC: {auroc_result}, NE: {ne}") + eval_result_dict = { + "auroc": auroc_result, + "ne": ne, + } + return eval_result_dict + + +def statistic(args: argparse.Namespace, queue: multiprocessing.Queue): + """ + The process to perform statistic calculations + """ + mch_buffer = ( + {} + ) # {epcoh_idx:{end_batch_idx: {rank: data_dict}}} where data dict is {metric_name: metric_value} + num_processed_batches = 0 # counter of the number of processed batches + world_size = int(os.environ["WORLD_SIZE"]) # world size + finished_counter = 0 # counter of the number of finished processes + + # create a profiling result folder + os.makedirs(args.profiling_result_folder, exist_ok=True) + # create a csv file to save the zch_metrics + if args.use_zch: + zch_metrics_file_path = os.path.join( + args.profiling_result_folder, "zch_metrics.csv" + ) + with open(zch_metrics_file_path, "w") as f: + writer = csv.writer(f) + writer.writerow( + [ + "epoch_idx", + "batch_idx", + "feature_name", + "hit_cnt", + "total_cnt", + "insert_cnt", + "collision_cnt", + "hit_rate", + "insert_rate", + "collision_rate", + "rank_total_cnt", + ] + ) + # create a csv file to save the qps_metrics + qps_metrics_file_path = os.path.join( + args.profiling_result_folder, "qps_metrics.csv" + ) + with open(qps_metrics_file_path, "w") as f: + writer = csv.writer(f) + writer.writerow( + [ + "epoch_idx", + "batch_idx", + "rank", + "num_queries", + "duration", + "qps", + ] + ) + # create a csv file to save the eval_metrics + eval_metrics_file_path = os.path.join( + args.profiling_result_folder, "eval_metrics.csv" + ) + with open(eval_metrics_file_path, "w") as f: + writer = csv.writer(f) + writer.writerow( + [ + "epoch_idx", + "rank", + "auroc", + "ne", + ] + ) + + while finished_counter < world_size: + try: + # get the data from the queue + msg_type, msg_content = queue.get( + timeout=0.5 + ) # data are put into the queue im the form of (msg_type, epoch_idx, batch_idx, rank, rank_data_dict) + except Exception as e: + # if the queue is empty, check if all the processes have finished + # if finished_counter >= world_size: + # print(f"All processes have finished. {finished_counter} / {world_size}") + # break + # else: + # continue # keep waiting for the queue to be filled + # if queue is empty, check if all the processes have finished + if finished_counter >= world_size: + print(f"All processes have finished. 
{finished_counter} / {world_size}") + break + else: + continue # keep waiting for the queue to be filled + # when getting the data, check if the data is from the last batch + if ( + msg_type == "finished" + ): # if the message type is "finished", the process has finished + rank = msg_content["rank"] + finished_counter += 1 + print(f"Process {rank} has finished. {finished_counter} / {world_size}") + continue + elif msg_type == "mch_stats": + if not args.use_zch: + continue + epoch_idx = msg_content["epoch_idx"] + batch_idx = msg_content["batch_idx"] + rank = msg_content["rank"] + rank_batch_mch_stats = msg_content["mch_stats"] + # other wise, aggregate the data into the buffer + if epoch_idx not in mch_buffer: + mch_buffer[epoch_idx] = {} + if batch_idx not in mch_buffer[epoch_idx]: + mch_buffer[epoch_idx][batch_idx] = {} + mch_buffer[epoch_idx][batch_idx][rank] = rank_batch_mch_stats + num_processed_batches += 1 + # check if we have all the data from all the ranks for a batch in an epoch + # if we have all the data, combine the data from all the ranks + if len(mch_buffer[epoch_idx][batch_idx]) == world_size: + # create a dictionary to store the statistics for each batch + batch_stats = ( + {} + ) # {feature_name: {hit_cnt: 0, total_cnt: 0, insert_cnt: 0, collision_cnt: 0}} + # combine the data from all the ranks + for mch_stats_rank_idx in mch_buffer[epoch_idx][batch_idx].keys(): + rank_batch_mch_stats = mch_buffer[epoch_idx][batch_idx][ + mch_stats_rank_idx + ] + # for each feature table in the mch stats information + for mch_stats_feature_name in rank_batch_mch_stats.keys(): + # create the dictionary for the feature table if not created + if mch_stats_feature_name not in batch_stats: + batch_stats[mch_stats_feature_name] = { + "hit_cnt": 0, + "total_cnt": 0, + "insert_cnt": 0, + "collision_cnt": 0, + "rank_total_cnt": {}, # dictionary of {rank_idx: num_quries_mapped_to_the_rank} + } + # aggregate the data from all the ranks + ## aggregate the hit count + batch_stats[mch_stats_feature_name][ + "hit_cnt" + ] += rank_batch_mch_stats[mch_stats_feature_name]["hit_cnt"] + ## aggregate the total count + batch_stats[mch_stats_feature_name][ + "total_cnt" + ] += rank_batch_mch_stats[mch_stats_feature_name]["total_cnt"] + ## aggregate the insert count + batch_stats[mch_stats_feature_name][ + "insert_cnt" + ] += rank_batch_mch_stats[mch_stats_feature_name]["insert_cnt"] + ## aggregate the collision count + batch_stats[mch_stats_feature_name][ + "collision_cnt" + ] += rank_batch_mch_stats[mch_stats_feature_name][ + "collision_cnt" + ] + ## for rank total count, get the data from the rank data dict + batch_stats[mch_stats_feature_name]["rank_total_cnt"][ + mch_stats_rank_idx + ] = rank_batch_mch_stats[mch_stats_feature_name][ + "rank_total_cnt" + ] + # clear the buffer for the batch + del mch_buffer[epoch_idx][batch_idx] + # save the zch statistics to a file + with open(zch_metrics_file_path, "a") as f: + writer = csv.writer(f) + for feature_name, stats in batch_stats.items(): + hit_rate = stats["hit_cnt"] / stats["total_cnt"] + insert_rate = stats["insert_cnt"] / stats["total_cnt"] + collision_rate = stats["collision_cnt"] / stats["total_cnt"] + rank_total_cnt = json.dumps(stats["rank_total_cnt"]) + writer.writerow( + [ + epoch_idx, + batch_idx, + feature_name, + stats["hit_cnt"], + stats["total_cnt"], + stats["insert_cnt"], + stats["collision_cnt"], + hit_rate, + insert_rate, + collision_rate, + rank_total_cnt, + ] + ) + elif msg_type == "duration_and_num_queries": + epoch_idx = 
msg_content["epoch_idx"] + rank = msg_content["rank"] + interval_start_batch_idx = msg_content["interval_start_batch_idx"] + interval_end_batch_idx = msg_content["interval_end_batch_idx"] + per_batch_time_list = msg_content["per_batch_time_list"] + per_batch_num_queries_list = msg_content["per_batch_num_queries_list"] + # save the qps statistics to a file + with open(qps_metrics_file_path, "a") as f: + writer = csv.writer(f) + for i in range(len(per_batch_time_list)): + writer.writerow( + [ + epoch_idx, + str(interval_end_batch_idx + i), + rank, + per_batch_num_queries_list[i], + per_batch_time_list[i], + ( + per_batch_num_queries_list[i] / per_batch_time_list[i] + if per_batch_time_list[i] > 0 + else 0 + ), + ] + ) + elif msg_type == "eval_result": + epoch_idx = msg_content["epoch_idx"] + rank = msg_content["rank"] + eval_result_dict = msg_content["eval_result_dict"] + # save the evaluation result to a file + with open(eval_metrics_file_path, "a") as f: + writer = csv.writer(f) + writer.writerow( + [ + epoch_idx, + rank, + eval_result_dict["auroc"], + eval_result_dict["ne"], + ] + ) + else: + # raise a warning if the message type is not recognized + print("Warning: Unknown message type") + continue + + +if __name__ == "__main__": + args = parse_args(sys.argv[1:]) + + # set environment variables + os.environ["MASTER_ADDR"] = str("localhost") + os.environ["MASTER_PORT"] = str(get_free_port()) + # set a multiprocessing context + ctx = multiprocessing.get_context("spawn") + # create a queue to communicate between processes + queue = ctx.Queue() + # create a process to perform statistic calculations + stat_process = ctx.Process( + target=statistic, args=(args, queue) + ) # create a process to perform statistic calculations + stat_process.start() + # create a process to perform benchmarking + train_processes = [] + for rank in range(int(os.environ["WORLD_SIZE"])): + p = ctx.Process( + target=main, + args=(rank, args, queue), + ) + p.start() + train_processes.append(p) + + # wait for the training processes to finish + for p in train_processes: + p.join() + # wait for the statistic process to finish + stat_process.join() diff --git a/torchrec/distributed/benchmark/benchmark_zch_utils.py b/torchrec/distributed/benchmark/benchmark_zch_utils.py new file mode 100644 index 000000000..7dedf7a36 --- /dev/null +++ b/torchrec/distributed/benchmark/benchmark_zch_utils.py @@ -0,0 +1,240 @@ +import copy +import json +import os +from typing import Dict + +import numpy as np + +import torch +import torch.nn as nn +from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection +from torchrec.modules.mc_modules import ( + DistanceLFU_EvictionPolicy, + ManagedCollisionCollection, + MCHManagedCollisionModule, +) + +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + +class BenchmarkMCProbe(nn.Module): + def __init__( + self, + mcec: Dict[str, ManagedCollisionEmbeddingCollection], + mc_method: str, # method for managing collisions, one of ["zch", "mpzch"] + rank: int, # rank of the current model shard + log_file_folder: str = "benchmark_logs", # folder to store the logging file + ) -> None: + super().__init__() + # self._mcec is a pointer to the mcec object passed in + self._mcec = mcec + # record the mc_method + self._mc_method = mc_method + # initialize the logging file handler + os.makedirs(log_file_folder, exist_ok=True) + self._log_file_path = os.path.join(log_file_folder, f"rank_{rank}.json") + self._rank = rank # record the rank of the current model shard + # get the 
output_offsets of the mcec + self.per_table_output_offsets = ( + {} + ) # dict of {table_name [str]: output_offsets [torch.Tensor]} TODO: find out relationship between table_name and feature_name + if self._mc_method == "mpzch": + for table_name, mcec_module in self._mcec.items(): + self.per_table_output_offsets[table_name] = ( + mcec_module._output_global_offset_tensor + ) + # create a dictionary to store the state of mcec modules + self.mcec_state = {} + # create a dictionary to store the statistics of mch modules + self._mch_stats = ( + {} + ) # dictionary of {table_name [str]: {metric_name [str]: metric_value [int]}} + + # record mcec state to file + def record_mcec_state(self, stage: str) -> None: + """ + record the state of mcec modules to the log file + The recorded state is a dictionary of + {{stage: {table_name: {metric_name: state}}}} + It only covers for the current batch + + params: + stage (str): before_fwd, after_fwd + return: + None + """ + # check if the stage in the desired options + assert stage in ( + "before_fwd", + "after_fwd", + ), f"stage {stage} is not supported, valid options are before_fwd, after_fwd" + # create a dictionary to store the state of mcec modules + if stage not in self.mcec_state: + self.mcec_state[stage] = {} # dict of {table_name: {metric_name: state}} + # if the stage is before_fwd, only record the remapping_table + # save the mcec table state for each embedding table + self.mcec_state[stage][ + "table_state" + ] = {} # dict of {table_name: {"remapping_table": state}} + for table_name, mc_module in self._mcec.items(): + self.mcec_state[stage]["table_state"][table_name] = {} + # + if self._mc_method == "zch": + self.mcec_state[stage]["table_state"][table_name][ + "remapping_table" + ] = mc_module._mch_sorted_raw_ids + # save t + elif self._mc_method == "mpzch": + self.mcec_state[stage]["table_state"][table_name]["remapping_table"] = ( + mc_module._hash_zch_identities.clone() + .to_dense() + .squeeze() + .cpu() + .numpy() + .tolist() + ) + else: + raise NotImplementedError( + f"mc method {self._mc_method} is not supported yet" + ) + # for before_fwd, we only need to record the remapping_table + if stage == "before_fwd": + return + # for after_fwd, we need to record the feature values + # check if the "before_fwd" stage is recorded + assert ( + "before_fwd" in self.mcec_state + ), "before_fwd stage is not recorded, please call record_mcec_state before calling record_mcec_state after_fwd" + # create the dirctionary to store the mcec feature values before forward + self.mcec_state["before_fwd"]["feature_values"] = {} + # create the dirctionary to store the mcec feature values after forward + self.mcec_state[stage]["feature_values"] = {} # dict of {table_name: state} + # save the mcec feature values for each embedding table + for table_name, mc_module in self._mcec.items(): + # record the remapped feature values + if self._mc_method == "mpzch": # when using mpzch mc modules + # record the remapped feature values first + self.mcec_state[stage]["feature_values"][table_name] = ( + mc_module.table_name_on_device_remapped_ids_dict[table_name] + .cpu() + .numpy() + .tolist() + ) + # record the input feature values + self.mcec_state["before_fwd"]["feature_values"][table_name] = ( + mc_module.table_name_on_device_input_ids_dict[table_name] + .cpu() + .numpy() + .tolist() + ) + # check if the input feature values list is empty + if ( + len(self.mcec_state["before_fwd"]["feature_values"][table_name]) + == 0 + ): + # if the input feature values list is empty, make 
it a list of -2 with the same length as the remapped feature values + self.mcec_state["before_fwd"]["feature_values"][table_name] = [ + -2 + ] * len(self.mcec_state[stage]["feature_values"][table_name]) + else: # when using other zch mc modules # TODO: implement the feature value recording for zch + raise NotImplementedError( + f"zc method {self._mc_method} is not supported yet" + ) + return + + def get_mcec_state(self) -> Dict[str, Dict[str, Dict[str, Dict[str, int]]]]: + """ + get the state of mcec modules + the state is a dictionary of + {{stage: {table_name: {data_name: state}}}} + """ + return self.mcec_state + + def save_mcec_state(self) -> None: + """ + save the state of mcec modules to the log file + """ + with open(self._log_file_path, "w") as f: + json.dump(self.mcec_state, f, indent=4) + + def get_mch_stats(self) -> Dict[str, Dict[str, int]]: + """ + get the statistics of mch modules + the statistics is a dictionary of + {{table_name: {metric_name: metric_value}}} + """ + return self._mch_stats + + def update(self) -> None: + """ + Update the ZCH statistics for the current batch + Params: + None + Return: + None + Require: + self.mcec_state is not None and has recorded both "before_fwd" and "after_fwd" for a batch + Update: + self._mch_stats + """ + # create a dictionary to store the statistics for each batch + batch_stats = ( + {} + ) # table_name: {hit_cnt: 0, total_cnt: 0, insert_cnt: 0, collision_cnt: 0} + # calculate the statistics for each rank + # get the remapping id table before forward pass and the input feature values + rank_feature_value_before_fwd = self.mcec_state["before_fwd"]["feature_values"] + # get the remapping id table after forward pass and the remapped feature ids + rank_feature_value_after_fwd = self.mcec_state["after_fwd"]["feature_values"] + # for each feature table in the remapped information + for ( + feature_name, + remapped_feature_ids, + ) in rank_feature_value_after_fwd.items(): + # create a new diction for the feature table if not created + if feature_name not in batch_stats: + batch_stats[feature_name] = { + "hit_cnt": 0, + "total_cnt": 0, + "insert_cnt": 0, + "collision_cnt": 0, + "rank_total_cnt": 0, + } + # get the input faeture values + input_feature_values = np.array(rank_feature_value_before_fwd[feature_name]) + # get the values stored in the remapping table for each remapped feature id after forward pass + prev_remapped_values = np.array( + self.mcec_state["before_fwd"]["table_state"][f"{feature_name}"][ + "remapping_table" + ] + )[remapped_feature_ids] + # get the values stored in the remapping table for each remapped feature id before forward pass + after_remapped_values = np.array( + self.mcec_state["after_fwd"]["table_state"][f"{feature_name}"][ + "remapping_table" + ] + )[remapped_feature_ids] + # count the number of same values in prev_remapped_values and after_remapped_values + # hit count = number of remapped values that exist in the remapping table before forward pass + this_rank_hits_count = np.sum(prev_remapped_values == input_feature_values) + batch_stats[feature_name]["hit_cnt"] += int(this_rank_hits_count) + # count the number of insertions + ## insert count = the decreased number of empty slots in the remapping table + ## before and after forward pass + this_rank_insert_count = np.sum(prev_remapped_values == -1) - np.sum( + after_remapped_values == -1 + ) + batch_stats[feature_name]["insert_cnt"] += int(this_rank_insert_count) + # count the number of total values + ## total count = the number of remapped values in the 
remapping table after forward pass + this_rank_total_count = int(len(remapped_feature_ids)) + # count the number of values redirected to the rank + batch_stats[feature_name]["rank_total_cnt"] = this_rank_total_count + batch_stats[feature_name]["total_cnt"] += this_rank_total_count + # count the number of collisions + # collision count = total count - hit count - insert count + this_rank_collision_count = ( + this_rank_total_count - this_rank_hits_count - this_rank_insert_count + ) + batch_stats[feature_name]["collision_cnt"] += int(this_rank_collision_count) + self._mch_stats = batch_stats diff --git a/torchrec/distributed/benchmark/count_non_zch_collision.py b/torchrec/distributed/benchmark/count_non_zch_collision.py new file mode 100644 index 000000000..f70677279 --- /dev/null +++ b/torchrec/distributed/benchmark/count_non_zch_collision.py @@ -0,0 +1,189 @@ +import csv +import json +import multiprocessing +import os +import sys + +import numpy as np + +import torch +from benchmark_zch_dlrmv2 import parse_args +from data.dlrm_dataloader import get_dataloader +from torch import distributed as dist +from torchrec.test_utils import get_free_port +from tqdm import tqdm + + +def main(rank, args): + # seed everything for reproducibility + torch.manual_seed(args.seed) + np.random.seed(args.seed) + # setup environment + os.environ["RANK"] = str(rank) + if torch.cuda.is_available(): + device: torch.device = torch.device(f"cuda:{rank}") + backend = "nccl" + torch.cuda.set_device(device) + else: + device: torch.device = torch.device("cpu") + backend = "gloo" + dist.init_process_group(backend=backend, init_method="env://") + + train_dataloader = get_dataloader(args, backend, "train") + + # make folder to save the collision dict + os.makedirs(args.profiling_result_folder, exist_ok=True) + + # collision dict + collison_dict = {} # feature_name: {remapped_id: original_id} + collision_stat = {} # feature_name: {hit: 0, collision: 0, total: 0} + hash_value_lookup_table = ( + {} + ) # feature_name: {original_id: remapped_id} # used to look up the remapped id for the original id to save the time of hashing the original id again + remapping_tensor_dict = {} # feature_name: remapping_tensor + zch_metrics_file_path = os.path.join( + args.profiling_result_folder, "zch_metrics.csv" + ) + with open(zch_metrics_file_path, "w") as f: + writer = csv.writer(f) + writer.writerow( + [ + "epoch_idx", + "batch_idx", + "feature_name", + "hit_cnt", + "total_cnt", + "insert_cnt", + "collision_cnt", + "hit_rate", + "insert_rate", + "collision_rate", + "rank_idx", + ] + ) + + pbar = tqdm(train_dataloader, desc=f"Rank {rank}") + for batch_idx, batch in enumerate(pbar): + batch = batch.to(device) + for feature_name, feature_values_jt in batch.sparse_features.to_dict().items(): + if feature_name not in collison_dict: + collison_dict[feature_name] = {} + if feature_name not in collision_stat: + collision_stat[feature_name] = { + "hit_cnt": 0, + "collision_cnt": 0, + "total_cnt": 0, + "insert_cnt": 0, + } + if feature_name not in remapping_tensor_dict: + remapping_tensor_dict[feature_name] = ( + torch.zeros(args.num_embeddings, dtype=torch.int64) - 1 + ).to( + device + ) # create a tensor of size [num_embeddings] and initialize it with -1 + num_empty_slots_before_remapping = ( + torch.sum(remapping_tensor_dict[feature_name] == -1).cpu().item() + ) # count the number of empty slots in the remapping tensor + if feature_name not in hash_value_lookup_table: + hash_value_lookup_table[feature_name] = {} + # create progress bar of 
feature values + remapped_tensor_values = torch.zeros_like(feature_values_jt.values()) + input_feature_values = feature_values_jt.values() + for feature_value_idx in range(len(input_feature_values)): + feature_value = input_feature_values[feature_value_idx] + if feature_value.cpu().item() in hash_value_lookup_table[feature_name]: + hashed_feature_value = hash_value_lookup_table[feature_name][ + feature_value.cpu().item() + ] + else: + feature_value = feature_value.unsqueeze(0) # convert to [1, ] + feature_value = feature_value.to(torch.uint64) # convert to uint64 + hashed_feature_value = torch.ops.fbgemm.murmur_hash3( + feature_value, 0, 0 + ) + # convert to int64 + hashed_feature_value = hashed_feature_value.to( + torch.int64 + ) # convert to int64 + # convert to [0, num_embeddings) + hashed_feature_value = ( + (hashed_feature_value % args.num_embeddings).cpu().item() + ) # convert to [0, num_embeddings) + # save the hashed feature value to the lookup table + hash_value_lookup_table[feature_name][ + feature_value.cpu().item() + ] = hashed_feature_value + remapped_tensor_values[feature_value_idx] = hashed_feature_value + # check if the remapping_tensor_dict at remapped_value's indexed slot value is -1 + if remapping_tensor_dict[feature_name][hashed_feature_value] == -1: + # if the remapping_tensor_dict at remapped_value's indexed slot value is -1, update the remapping_tensor_dict at remapped_value's indexed slot value to feature_value + remapping_tensor_dict[feature_name][ + hashed_feature_value + ] = feature_value + # check if the hashed feature value is in the collision dict + num_empty_slots_after_remapping = ( + torch.sum(remapping_tensor_dict[feature_name] == -1).cpu().item() + ) # count the number of empty slots in the remapping tensor + insert_cnt = ( + num_empty_slots_before_remapping - num_empty_slots_after_remapping + ) + hit_cnt = ( + ( + torch.sum( + torch.eq( + input_feature_values, + remapping_tensor_dict[feature_name][remapped_tensor_values], + ) + ) + - insert_cnt + ) + .cpu() + .item() + ) + total_cnt = len(input_feature_values) + collision_cnt = total_cnt - hit_cnt - insert_cnt + collision_stat[feature_name]["hit_cnt"] = hit_cnt + collision_stat[feature_name]["collision_cnt"] = collision_cnt + collision_stat[feature_name]["total_cnt"] = total_cnt + collision_stat[feature_name]["insert_cnt"] = insert_cnt + + # save the collision stat + with open(zch_metrics_file_path, "a") as f: + writer = csv.writer(f) + for feature_name, stats in collision_stat.items(): + hit_rate = stats["hit_cnt"] / stats["total_cnt"] + insert_rate = stats["insert_cnt"] / stats["total_cnt"] + collision_rate = stats["collision_cnt"] / stats["total_cnt"] + writer.writerow( + [ + 0, + batch_idx, + feature_name, + stats["hit_cnt"], + stats["total_cnt"], + stats["insert_cnt"], + stats["collision_cnt"], + hit_rate, + insert_rate, + collision_rate, + rank, + ] + ) + + +if __name__ == "__main__": + args = parse_args(sys.argv[1:]) + # set environment variables + os.environ["MASTER_ADDR"] = str("localhost") + os.environ["MASTER_PORT"] = str(get_free_port()) + # set a multiprocessing context + ctx = multiprocessing.get_context("spawn") + # create a process to perform benchmarking + processes = [] + for rank in range(int(os.environ["WORLD_SIZE"])): + p = ctx.Process( + target=main, + args=(rank, args), + ) + p.start() + processes.append(p) diff --git a/torchrec/distributed/benchmark/data/__init__.py b/torchrec/distributed/benchmark/data/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git 
a/torchrec/distributed/benchmark/data/dlrm_dataloader.py b/torchrec/distributed/benchmark/data/dlrm_dataloader.py new file mode 100644 index 000000000..dc7c8b39f --- /dev/null +++ b/torchrec/distributed/benchmark/data/dlrm_dataloader.py @@ -0,0 +1,169 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os +from typing import List + +from torch import distributed as dist +from torch.utils.data import DataLoader +from torchrec.datasets.criteo import ( + CAT_FEATURE_COUNT, + DAYS, + DEFAULT_CAT_NAMES, + DEFAULT_INT_NAMES, + InMemoryBinaryCriteoIterDataPipe, +) +from torchrec.datasets.random import RandomRecDataset + +# OSS import +try: + # pyre-ignore[21] + # @manual=torchrec/distributed/benchmark/data:multi_hot_criteo + from data.multi_hot_criteo import MultiHotCriteoIterDataPipe + +except ImportError: + pass + +# internal import +try: + from .multi_hot_criteo import MultiHotCriteoIterDataPipe # noqa F811 +except ImportError: + pass + +STAGES = ["train", "val", "test"] + + +def _get_random_dataloader( + args: argparse.Namespace, + stage: str, +) -> DataLoader: + attr = f"limit_{stage}_batches" + num_batches = getattr(args, attr) + if stage in ["val", "test"] and args.test_batch_size is not None: + batch_size = args.test_batch_size + else: + batch_size = args.batch_size + return DataLoader( + RandomRecDataset( + keys=DEFAULT_CAT_NAMES, + batch_size=batch_size, + hash_size=args.num_embeddings, + hash_sizes=( + args.num_embeddings_per_feature + if hasattr(args, "num_embeddings_per_feature") + else None + ), + manual_seed=getattr(args, "seed", None), + ids_per_feature=1, + num_dense=len(DEFAULT_INT_NAMES), + num_batches=num_batches, + ), + batch_size=None, + batch_sampler=None, + pin_memory=args.pin_memory, + num_workers=0, + ) + + +def _get_in_memory_dataloader( + args: argparse.Namespace, + stage: str, +) -> DataLoader: + if args.in_memory_binary_criteo_path is not None: + dir_path = args.in_memory_binary_criteo_path + sparse_part = "sparse.npy" + datapipe = InMemoryBinaryCriteoIterDataPipe + else: + dir_path = args.synthetic_multi_hot_criteo_path + sparse_part = "sparse_multi_hot.npz" + datapipe = MultiHotCriteoIterDataPipe + + if args.dataset_name == "criteo_kaggle": + # criteo_kaggle has no validation set, so use 2nd half of training set for now. + # Setting stage to "test" will get the 2nd half of the dataset. + # Setting root_name to "train" reads from the training set file. 
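+            # Clarifying note (added comment): for stage="val" or stage="test" the tuple
+            # below resolves to ("train", "test"), i.e. evaluation still reads the
+            # train_* files but is served the second half of their rows.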
+ (root_name, stage) = ( + ("train", "train") if stage == "train" else ("train", "test") + ) + stage_files: List[List[str]] = [ + [os.path.join(dir_path, f"{root_name}_dense.npy")], + [os.path.join(dir_path, f"{root_name}_{sparse_part}")], + [os.path.join(dir_path, f"{root_name}_labels.npy")], + ] + # criteo_1tb code path uses below two conditionals + elif stage == "train": + stage_files: List[List[str]] = [ + [os.path.join(dir_path, f"day_{i}_dense.npy") for i in range(DAYS - 1)], + [os.path.join(dir_path, f"day_{i}_{sparse_part}") for i in range(DAYS - 1)], + [os.path.join(dir_path, f"day_{i}_labels.npy") for i in range(DAYS - 1)], + ] + elif stage in ["val", "test"]: + stage_files: List[List[str]] = [ + [os.path.join(dir_path, f"day_{DAYS-1}_dense.npy")], + [os.path.join(dir_path, f"day_{DAYS-1}_{sparse_part}")], + [os.path.join(dir_path, f"day_{DAYS-1}_labels.npy")], + ] + if stage in ["val", "test"] and args.test_batch_size is not None: + batch_size = args.test_batch_size + else: + batch_size = args.batch_size + dataloader = DataLoader( + datapipe( + stage, + *stage_files, # pyre-ignore[6] + batch_size=batch_size, + rank=dist.get_rank(), + world_size=dist.get_world_size(), + drop_last=args.drop_last_training_batch if stage == "train" else False, + shuffle_batches=args.shuffle_batches, + shuffle_training_set=args.shuffle_training_set, + shuffle_training_set_random_seed=args.seed, + mmap_mode=args.mmap_mode, + hashes=( + [args.num_embeddings] * CAT_FEATURE_COUNT + if args.input_hash_size is None + else ([args.input_hash_size] * CAT_FEATURE_COUNT) + ), + ), + batch_size=None, + pin_memory=args.pin_memory, + collate_fn=lambda x: x, + ) + return dataloader + + +def get_dataloader(args: argparse.Namespace, backend: str, stage: str) -> DataLoader: + """ + Gets desired dataloader from dlrm_main command line options. Currently, this + function is able to return either a DataLoader wrapped around a RandomRecDataset or + a Dataloader wrapped around an InMemoryBinaryCriteoIterDataPipe. + + Args: + args (argparse.Namespace): Command line options supplied to dlrm_main.py's main + function. + backend (str): "nccl" or "gloo". + stage (str): "train", "val", or "test". + + Returns: + dataloader (DataLoader): PyTorch dataloader for the specified options. + + """ + stage = stage.lower() + if stage not in STAGES: + raise ValueError(f"Supplied stage was {stage}. Must be one of {STAGES}.") + + args.pin_memory = ( + (backend == "nccl") if not hasattr(args, "pin_memory") else args.pin_memory + ) + + if ( + args.in_memory_binary_criteo_path is None + and args.synthetic_multi_hot_criteo_path is None + ): + return _get_random_dataloader(args, stage) + else: + return _get_in_memory_dataloader(args, stage) diff --git a/torchrec/distributed/benchmark/data/multi_hot_criteo.py b/torchrec/distributed/benchmark/data/multi_hot_criteo.py new file mode 100644 index 000000000..e6126290c --- /dev/null +++ b/torchrec/distributed/benchmark/data/multi_hot_criteo.py @@ -0,0 +1,306 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import zipfile +from typing import Dict, Iterator, List, Optional + +import numpy as np +import torch +from iopath.common.file_io import PathManager, PathManagerFactory +from pyre_extensions import none_throws +from torch.utils.data import IterableDataset +from torchrec.datasets.criteo import CAT_FEATURE_COUNT, DEFAULT_CAT_NAMES +from torchrec.datasets.utils import Batch, PATH_MANAGER_KEY +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + +class MultiHotCriteoIterDataPipe(IterableDataset): + """ + Datapipe designed to operate over the MLPerf DLRM v2 synthetic multi-hot dataset. + This dataset can be created by following the steps in + torchrec_dlrm/scripts/materialize_synthetic_multihot_dataset.py. + Each rank reads only the data for the portion of the dataset it is responsible for. + + Args: + stage (str): "train", "val", or "test". + dense_paths (List[str]): List of path strings to dense npy files. + sparse_paths (List[str]): List of path strings to multi-hot sparse npz files. + labels_paths (List[str]): List of path strings to labels npy files. + batch_size (int): batch size. + rank (int): rank. + world_size (int): world size. + drop_last (Optional[bool]): Whether to drop the last batch if it is incomplete. + shuffle_batches (bool): Whether to shuffle batches + shuffle_training_set (bool): Whether to shuffle all samples in the dataset. + shuffle_training_set_random_seed (int): The random generator seed used when + shuffling the training set. + hashes (Optional[int]): List of max categorical feature value for each feature. + Length of this list should be CAT_FEATURE_COUNT. + path_manager_key (str): Path manager key used to load from different + filesystems. + + Example:: + + datapipe = MultiHotCriteoIterDataPipe( + dense_paths=["day_0_dense.npy"], + sparse_paths=["day_0_sparse_multi_hot.npz"], + labels_paths=["day_0_labels.npy"], + batch_size=1024, + rank=torch.distributed.get_rank(), + world_size=torch.distributed.get_world_size(), + ) + batch = next(iter(datapipe)) + """ + + def __init__( + self, + stage: str, + dense_paths: List[str], + sparse_paths: List[str], + labels_paths: List[str], + batch_size: int, + rank: int, + world_size: int, + drop_last: Optional[bool] = False, + shuffle_batches: bool = False, + shuffle_training_set: bool = False, + shuffle_training_set_random_seed: int = 0, + mmap_mode: bool = False, + hashes: Optional[List[int]] = None, + path_manager_key: str = PATH_MANAGER_KEY, + ) -> None: + self.stage = stage + self.dense_paths = dense_paths + self.sparse_paths = sparse_paths + self.labels_paths = labels_paths + self.batch_size = batch_size + self.rank = rank + self.world_size = world_size + self.drop_last = drop_last + self.shuffle_batches = shuffle_batches + self.shuffle_training_set = shuffle_training_set + np.random.seed(shuffle_training_set_random_seed) + self.mmap_mode = mmap_mode + # hashes are not used because they were already applied in the + # script that generates the multi-hot dataset. + self.hashes: np.ndarray = np.array(hashes).reshape((CAT_FEATURE_COUNT, 1)) + self.path_manager_key = path_manager_key + self.path_manager: PathManager = PathManagerFactory().get(path_manager_key) + + if shuffle_training_set and stage == "train": + # Currently not implemented for the materialized multi-hot dataset. 
+ self._shuffle_and_load_data_for_rank() + else: + m = "r" if mmap_mode else None + self.dense_arrs: List[np.ndarray] = [ + np.load(f, mmap_mode=m) for f in self.dense_paths + ] + self.labels_arrs: List[np.ndarray] = [ + np.load(f, mmap_mode=m) for f in self.labels_paths + ] + self.sparse_arrs: List = [] + for sparse_path in self.sparse_paths: + multi_hot_ids_l = [] + for feat_id_num in range(CAT_FEATURE_COUNT): + multi_hot_ft_ids = self._load_from_npz( + sparse_path, f"{feat_id_num}.npy" + ) + multi_hot_ids_l.append(multi_hot_ft_ids) + self.sparse_arrs.append(multi_hot_ids_l) + len_d0 = len(self.dense_arrs[0]) + second_half_start_index = int(len_d0 // 2 + len_d0 % 2) + if stage == "val": + self.dense_arrs[0] = self.dense_arrs[0][:second_half_start_index, :] + self.labels_arrs[0] = self.labels_arrs[0][:second_half_start_index, :] + self.sparse_arrs[0] = [ + feats[:second_half_start_index, :] for feats in self.sparse_arrs[0] + ] + elif stage == "test": + self.dense_arrs[0] = self.dense_arrs[0][second_half_start_index:, :] + self.labels_arrs[0] = self.labels_arrs[0][second_half_start_index:, :] + self.sparse_arrs[0] = [ + feats[second_half_start_index:, :] for feats in self.sparse_arrs[0] + ] + # When mmap_mode is enabled, sparse features are hashed when + # samples are batched in def __iter__. Otherwise, the dataset has been + # preloaded with sparse features hashed in the preload stage, here: + # if not self.mmap_mode and self.hashes is not None: + # for k, _ in enumerate(self.sparse_arrs): + # self.sparse_arrs[k] = [ + # feat % hash + # for (feat, hash) in zip(self.sparse_arrs[k], self.hashes) + # ] + + self.num_rows_per_file: List[int] = list(map(len, self.dense_arrs)) + total_rows = sum(self.num_rows_per_file) + self.num_full_batches: int = ( + total_rows // batch_size // self.world_size * self.world_size + ) + self.last_batch_sizes: np.ndarray = np.array( + [0 for _ in range(self.world_size)] + ) + remainder = total_rows % (self.world_size * batch_size) + if not self.drop_last and 0 < remainder: + if remainder < self.world_size: + self.num_full_batches -= self.world_size + self.last_batch_sizes += batch_size + else: + self.last_batch_sizes += remainder // self.world_size + self.last_batch_sizes[: remainder % self.world_size] += 1 + + self.multi_hot_sizes: List[int] = [ + multi_hot_feat.shape[-1] for multi_hot_feat in self.sparse_arrs[0] + ] + + # These values are the same for the KeyedJaggedTensors in all batches, so they + # are computed once here. This avoids extra work from the KeyedJaggedTensor sync + # functions. 
+ self.keys: List[str] = DEFAULT_CAT_NAMES + self.index_per_key: Dict[str, int] = { + key: i for (i, key) in enumerate(self.keys) + } + + def _load_from_npz(self, fname, npy_name): + # figure out offset of .npy in .npz + zf = zipfile.ZipFile(fname) + info = zf.NameToInfo[npy_name] + assert info.compress_type == 0 + zf.fp.seek(info.header_offset + len(info.FileHeader()) + 20) + # read .npy header + zf.open(npy_name, "r") + version = np.lib.format.read_magic(zf.fp) + shape, fortran_order, dtype = np.lib.format._read_array_header(zf.fp, version) + assert ( + dtype == "int32" + ), f"sparse multi-hot dtype is {dtype} but should be int32" + offset = zf.fp.tell() + # create memmap + return np.memmap( + zf.filename, + dtype=dtype, + shape=shape, + order="F" if fortran_order else "C", + mode="r", + offset=offset, + ) + + def _np_arrays_to_batch( + self, + dense: np.ndarray, + sparse: List[np.ndarray], + labels: np.ndarray, + ) -> Batch: + if self.shuffle_batches: + # Shuffle all 3 in unison + shuffler = np.random.permutation(len(dense)) + sparse = [multi_hot_ft[shuffler, :] for multi_hot_ft in sparse] + dense = dense[shuffler] + labels = labels[shuffler] + + batch_size = len(dense) + lengths = torch.ones((CAT_FEATURE_COUNT * batch_size), dtype=torch.int32) + for k, multi_hot_size in enumerate(self.multi_hot_sizes): + lengths[k * batch_size : (k + 1) * batch_size] = multi_hot_size + offsets = torch.cumsum(torch.concat((torch.tensor([0]), lengths)), dim=0) + length_per_key = [ + batch_size * multi_hot_size for multi_hot_size in self.multi_hot_sizes + ] + offset_per_key = torch.cumsum( + torch.concat((torch.tensor([0]), torch.tensor(length_per_key))), dim=0 + ) + values = torch.concat([torch.from_numpy(feat).flatten() for feat in sparse]) + return Batch( + dense_features=torch.from_numpy(dense.copy()), + sparse_features=KeyedJaggedTensor( + keys=self.keys, + values=values, + lengths=lengths, + offsets=offsets, + stride=batch_size, + length_per_key=length_per_key, + offset_per_key=offset_per_key.tolist(), + index_per_key=self.index_per_key, + ), + labels=torch.from_numpy(labels.reshape(-1).copy()), + ) + + def __iter__(self) -> Iterator[Batch]: + # Invariant: buffer never contains more than batch_size rows. + buffer: Optional[List[np.ndarray]] = None + + def append_to_buffer( + dense: np.ndarray, + sparse: List[np.ndarray], + labels: np.ndarray, + ) -> None: + nonlocal buffer + if buffer is None: + buffer = [dense, sparse, labels] + else: + buffer[0] = np.concatenate((buffer[0], dense)) + buffer[1] = [np.concatenate((b, s)) for b, s in zip(buffer[1], sparse)] + buffer[2] = np.concatenate((buffer[2], labels)) + + # Maintain a buffer that can contain up to batch_size rows. Fill buffer as + # much as possible on each iteration. Only return a new batch when batch_size + # rows are filled. 
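+        # Clarifying comment: batches are handed out round-robin across ranks; the loop
+        # below only materializes and yields a batch when batch_idx % world_size == rank,
+        # so each rank reads and returns every world_size-th batch of the dataset.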
+ file_idx = 0 + row_idx = 0 + batch_idx = 0 + buffer_row_count = 0 + cur_batch_size = ( + self.batch_size if self.num_full_batches > 0 else self.last_batch_sizes[0] + ) + while ( + batch_idx + < self.num_full_batches + (self.last_batch_sizes[0] > 0) * self.world_size + ): + if buffer_row_count == cur_batch_size or file_idx == len(self.dense_arrs): + if batch_idx % self.world_size == self.rank: + yield self._np_arrays_to_batch(*none_throws(buffer)) + buffer = None + buffer_row_count = 0 + batch_idx += 1 + if 0 <= batch_idx - self.num_full_batches < self.world_size and ( + self.last_batch_sizes[0] > 0 + ): + cur_batch_size = self.last_batch_sizes[ + batch_idx - self.num_full_batches + ] + else: + rows_to_get = min( + cur_batch_size - buffer_row_count, + self.num_rows_per_file[file_idx] - row_idx, + ) + buffer_row_count += rows_to_get + slice_ = slice(row_idx, row_idx + rows_to_get) + + if batch_idx % self.world_size == self.rank: + dense_inputs = self.dense_arrs[file_idx][slice_, :] + sparse_inputs = [ + feats[slice_, :] for feats in self.sparse_arrs[file_idx] + ] + target_labels = self.labels_arrs[file_idx][slice_, :] + + # if self.mmap_mode and self.hashes is not None: + # sparse_inputs = [ + # feats % hash + # for (feats, hash) in zip(sparse_inputs, self.hashes) + # ] + + append_to_buffer( + dense_inputs, + sparse_inputs, + target_labels, + ) + row_idx += rows_to_get + + if row_idx >= self.num_rows_per_file[file_idx]: + file_idx += 1 + row_idx = 0 + + def __len__(self) -> int: + return self.num_full_batches // self.world_size + (self.last_batch_sizes[0] > 0) diff --git a/torchrec/distributed/hash_mc_embedding.py b/torchrec/distributed/hash_mc_embedding.py new file mode 100644 index 000000000..4171e1092 --- /dev/null +++ b/torchrec/distributed/hash_mc_embedding.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict +import logging as logger +from collections import defaultdict +from typing import Dict, List + +import torch +from torchrec.distributed.quant_state import WeightSpec +from torchrec.distributed.types import ShardingType +from torchrec.modules.hash_mc_modules import HashZchManagedCollisionModule + + +def sharded_zchs_buffers_spec( + sharded_model: torch.nn.Module, +) -> Dict[str, WeightSpec]: + # OUTPUT: + # Example: + # "main_module.module.ec_in_task_arch_hash._decoupled_embedding_collection._mcec_lookup.0.0._mcc_remapper.zchs.viewer_rid_duplicate._hash_zch_identities", [0, 0], [500, 1]) + # "main_module.module.ec_in_task_arch_hash._decoupled_embedding_collection._mcec_lookup.0.1._mcc_remapper.zchs.viewer_rid_duplicate._hash_zch_identities", [500, 0], [1000, 1]) + + # 'main_module.module.ec_in_task_arch_hash._decoupled_embedding_collection._mcec_lookup.0.0._mcc_remapper.zchs.viewer_rid_duplicate._hash_zch_identities': WeightSpec(fqn='main_module.module.ec_in_task_arch_hash._ d_embedding_collection._managed_collision_collection.viewer_rid_duplicate._hash_zch_identities' + def _get_table_names( + sharded_module: torch.nn.Module, + ) -> List[str]: + table_names: List[str] = [] + for _, module in sharded_module.named_modules(): + type_name: str = type(module).__name__ + if "ShardedMCCRemapper" in type_name: + for table_name in module._tables: + if table_name not in table_names: + table_names.append(table_name) + return table_names + + def _get_unsharded_fqn_identities( + sharded_module: torch.nn.Module, + fqn: str, + table_name: str, + ) -> str: + for module_fqn, module in sharded_module.named_modules(): + type_name: str = type(module).__name__ + if "ManagedCollisionCollection" in type_name: + if table_name in module._table_to_features: + return f"{fqn}.{module_fqn}._managed_collision_modules.{table_name}.{HashZchManagedCollisionModule.IDENTITY_BUFFER}" + logger.info(f"did not find table {table_name} in module {fqn}") + return "" + + ret: Dict[str, WeightSpec] = defaultdict() + for module_fqn, module in sharded_model.named_modules(): + type_name: str = type(module).__name__ + if "ShardedQuantManagedCollisionEmbeddingCollection" in type_name: + sharding_type = ShardingType.ROW_WISE.value + table_name_to_unsharded_fqn_identities: Dict[str, str] = {} + for subfqn, submodule in module.named_modules(): + type_name: str = type(submodule).__name__ + if "ShardedMCCRemapper" in type_name: + for table_name in submodule.zchs.keys(): + # identities tensor has only one column + shard_offsets: List[int] = [ + submodule._shard_metadata[table_name][0], + 0, + ] + shard_sizes: List[int] = [ + submodule._shard_metadata[table_name][1], + 1, + ] + if table_name not in table_name_to_unsharded_fqn_identities: + table_name_to_unsharded_fqn_identities[table_name] = ( + _get_unsharded_fqn_identities( + module, module_fqn, table_name + ) + ) + unsharded_fqn_identities: str = ( + table_name_to_unsharded_fqn_identities[table_name] + ) + # subfqn contains the index of sharding, so no need to add it specifically here + sharded_fqn_identities: str = ( + f"{module_fqn}.{subfqn}.zchs.{table_name}.{HashZchManagedCollisionModule.IDENTITY_BUFFER}" + ) + ret[sharded_fqn_identities] = WeightSpec( + fqn=unsharded_fqn_identities, + shard_offsets=shard_offsets, + shard_sizes=shard_sizes, + sharding_type=sharding_type, + ) + return ret diff --git a/torchrec/distributed/tests/test_hash_zch_mc.py b/torchrec/distributed/tests/test_hash_zch_mc.py new file mode 100644 index 000000000..7cf9906d1 --- /dev/null +++ 
b/torchrec/distributed/tests/test_hash_zch_mc.py @@ -0,0 +1,228 @@ +#!/usr/bin/env python3 +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#!/usr/bin/env python3 + +# pyre-strict + +import copy +import multiprocessing +import unittest +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from pyre_extensions import none_throws +from torch import nn +from torchrec import ( + EmbeddingCollection, + EmbeddingConfig, + JaggedTensor, + KeyedJaggedTensor, + KeyedTensor, +) +from torchrec.distributed import ModuleSharder, ShardingEnv +from torchrec.distributed.mc_modules import ManagedCollisionCollectionSharder + +from torchrec.distributed.shard import _shard_modules +from torchrec.distributed.sharding_plan import ( + construct_module_sharding_plan, + EmbeddingCollectionSharder, + ManagedCollisionEmbeddingCollectionSharder, + row_wise, +) +from torchrec.distributed.test_utils.multi_process import ( + MultiProcessContext, + MultiProcessTestBase, +) +from torchrec.distributed.types import ShardingPlan +from torchrec.modules.hash_mc_evictions import ( + HashZchEvictionConfig, + HashZchEvictionPolicyName, +) +from torchrec.modules.hash_mc_modules import HashZchManagedCollisionModule +from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection +from torchrec.modules.mc_modules import ManagedCollisionCollection + +BASE_LEAF_MODULES = [ + "IntNBitTableBatchedEmbeddingBagsCodegen", + "HashZchManagedCollisionModule", +] + + +class SparseArch(nn.Module): + def __init__( + self, + tables: List[EmbeddingConfig], + device: torch.device, + buckets: int, + return_remapped: bool = False, + input_hash_size: int = 4000, + is_inference: bool = False, + ) -> None: + super().__init__() + self._return_remapped = return_remapped + + mc_modules = {} + mc_modules["table_0"] = HashZchManagedCollisionModule( + is_inference=is_inference, + zch_size=(tables[0].num_embeddings), + input_hash_size=input_hash_size, + device=device, + total_num_buckets=buckets, + eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION, + eviction_config=HashZchEvictionConfig( + features=["feature_0"], + single_ttl=1, + ), + ) + + mc_modules["table_1"] = HashZchManagedCollisionModule( + is_inference=is_inference, + zch_size=(tables[1].num_embeddings), + device=device, + input_hash_size=input_hash_size, + total_num_buckets=buckets, + eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION, + eviction_config=HashZchEvictionConfig( + features=["feature_1"], + single_ttl=1, + ), + ) + + self._mc_ec: ManagedCollisionEmbeddingCollection = ( + ManagedCollisionEmbeddingCollection( + EmbeddingCollection( + tables=tables, + device=device, + ), + ManagedCollisionCollection( + managed_collision_modules=mc_modules, + embedding_configs=tables, + ), + return_remapped_features=self._return_remapped, + ) + ) + + def forward( + self, kjt: KeyedJaggedTensor + ) -> Tuple[ + Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor] + ]: + return self._mc_ec(kjt) + + +class TestHashZchMcEmbedding(MultiProcessTestBase): + # pyre-ignore + @unittest.skipIf(torch.cuda.device_count() <= 1, "Not enough GPUs, skipping") + def test_hash_zch_mc_ec(self) -> None: + + WORLD_SIZE = 2 + + embedding_config = [ + EmbeddingConfig( + name="table_0", + feature_names=["feature_0"], + embedding_dim=8, + num_embeddings=16, + ), + EmbeddingConfig( + name="table_1", + feature_names=["feature_1"], + embedding_dim=8, + num_embeddings=32, + ), + ] + + train_input_per_rank = [ + 
KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + list(range(1000, 1025)), + ), + lengths=torch.LongTensor([1] * 8 + [2] * 8), + weights=None, + ), + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + list(range(25000, 25025)), + ), + lengths=torch.LongTensor([1] * 8 + [2] * 8), + weights=None, + ), + ] + train_state_dict = multiprocessing.Manager().dict() + + # Train Model with ZCH on GPU + self._run_multi_process_test( + callable=_train_model, + world_size=WORLD_SIZE, + tables=embedding_config, + num_buckets=2, + kjt_input_per_rank=train_input_per_rank, + sharder=ManagedCollisionEmbeddingCollectionSharder( + EmbeddingCollectionSharder(), + ManagedCollisionCollectionSharder(), + ), + return_dict=train_state_dict, + backend="nccl", + ) + + +def _train_model( + tables: List[EmbeddingConfig], + num_buckets: int, + rank: int, + world_size: int, + kjt_input_per_rank: List[KeyedJaggedTensor], + sharder: ModuleSharder[nn.Module], + backend: str, + return_dict: Dict[str, Any], + local_size: Optional[int] = None, +) -> None: + with MultiProcessContext(rank, world_size, backend, local_size) as ctx: + kjt_input = kjt_input_per_rank[rank].to(ctx.device) + + train_model = SparseArch( + tables=tables, + device=torch.device("cuda"), + input_hash_size=0, + return_remapped=True, + buckets=num_buckets, + ) + train_sharding_plan = construct_module_sharding_plan( + train_model._mc_ec, + per_param_sharding={"table_0": row_wise(), "table_1": row_wise()}, + local_size=local_size, + world_size=world_size, + device_type="cuda", + sharder=sharder, + ) + print(f"train_sharding_plan: {train_sharding_plan}") + sharded_train_model = _shard_modules( + module=copy.deepcopy(train_model), + plan=ShardingPlan({"_mc_ec": train_sharding_plan}), + env=ShardingEnv.from_process_group(none_throws(ctx.pg)), + sharders=[sharder], + device=ctx.device, + ) + # train + sharded_train_model(kjt_input.to(ctx.device)) + + for ( + key, + value, + ) in ( + # pyre-ignore + sharded_train_model._mc_ec._managed_collision_collection._managed_collision_modules.state_dict().items() + ): + return_dict[f"mc_{key}_{rank}"] = value.cpu() + for ( + key, + value, + # pyre-ignore + ) in sharded_train_model._mc_ec._embedding_collection.state_dict().items(): + tensors = [] + for i in range(len(value.local_shards())): + tensors.append(value.local_shards()[i].tensor.cpu()) + return_dict[f"ec_{key}_{rank}"] = torch.cat(tensors, dim=0) diff --git a/torchrec/modules/hash_mc_evictions.py b/torchrec/modules/hash_mc_evictions.py new file mode 100644 index 000000000..415a4cd95 --- /dev/null +++ b/torchrec/modules/hash_mc_evictions.py @@ -0,0 +1,260 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +import logging +import time +from dataclasses import dataclass +from enum import Enum, unique +from typing import List, Optional, Tuple + +import torch +from pyre_extensions import none_throws + +from torchrec.sparse.jagged_tensor import JaggedTensor + +logger: logging.Logger = logging.getLogger(__name__) + + +@unique +class HashZchEvictionPolicyName(Enum): + # eviction based on the time the ID is last seen during training, + # and a single TTL + SINGLE_TTL_EVICTION = "SINGLE_TTL_EVICTION" + # eviction based on the time the ID is last seen during training, + # and per-feature TTLs + PER_FEATURE_TTL_EVICTION = "PER_FEATURE_TTL_EVICTION" + # eviction based on least recently seen ID within the probe range + LRU_EVICTION = "LRU_EVICTION" + + +@torch.jit.script +@dataclass +class HashZchEvictionConfig: + features: List[str] + single_ttl: Optional[int] = None + per_feature_ttl: Optional[List[int]] = None + + +@torch.fx.wrap +def get_kernel_from_policy( + policy_name: Optional[HashZchEvictionPolicyName], +) -> int: + return ( + 1 + if policy_name is not None + and policy_name == HashZchEvictionPolicyName.LRU_EVICTION + else 0 + ) + + +class HashZchEvictionScorer: + def __init__(self, config: HashZchEvictionConfig) -> None: + self._config: HashZchEvictionConfig = config + + def gen_score(self, feature: JaggedTensor, device: torch.device) -> torch.Tensor: + return torch.empty(0, device=device) + + def gen_threshold(self) -> int: + return -1 + + +class HashZchSingleTtlScorer(HashZchEvictionScorer): + def gen_score(self, feature: JaggedTensor, device: torch.device) -> torch.Tensor: + assert ( + self._config.single_ttl is not None and self._config.single_ttl > 0 + ), "To use scorer HashZchSingleTtlScorer, a positive single_ttl is required." + + return torch.full_like( + feature.values(), + # pyre-ignore [58] + self._config.single_ttl + + int( + time.time() / 3600 + ), # add the current time to the single_ttl, this is the time whem each value is expired and becomes evictable + dtype=torch.int32, + device=device, + ) + + def gen_threshold(self) -> int: + return int(time.time() / 3600) + + +class HashZchPerFeatureTtlScorer(HashZchEvictionScorer): + def __init__(self, config: HashZchEvictionConfig) -> None: + super().__init__(config) + + assert self._config.per_feature_ttl is not None and len( + self._config.features + ) == len( + # pyre-ignore [6] + self._config.per_feature_ttl + ), "To use scorer HashZchPerFeatureTtlScorer, a 1:1 mapping between features and per_feature_ttl is required." 
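+        # Hypothetical usage sketch (feature names and TTL values invented for illustration):
+        #   HashZchEvictionConfig(features=["feature_0", "feature_1"], per_feature_ttl=[24, 72])
+        # would make feature_0 IDs evictable 24 hours after they were last seen and
+        # feature_1 IDs after 72 hours; TTLs are expressed in hours, since the scorers
+        # work on int(time.time() / 3600).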
+ + self._per_feature_ttl = torch.IntTensor(self._config.per_feature_ttl) + + def gen_score(self, feature: JaggedTensor, device: torch.device) -> torch.Tensor: + feature_split = feature.weights() + assert feature_split.size(0) == self._per_feature_ttl.size(0) + + scores = self._per_feature_ttl.repeat_interleave(feature_split) + int( + time.time() / 3600 + ) + + return scores.to(device=device) + + def gen_threshold(self) -> int: + return int(time.time() / 3600) + + +@torch.fx.wrap +def get_eviction_scorer( + policy_name: str, config: HashZchEvictionConfig +) -> HashZchEvictionScorer: + if policy_name == HashZchEvictionPolicyName.SINGLE_TTL_EVICTION: + return HashZchSingleTtlScorer(config) + elif policy_name == HashZchEvictionPolicyName.PER_FEATURE_TTL_EVICTION: + return HashZchPerFeatureTtlScorer(config) + elif policy_name == HashZchEvictionPolicyName.LRU_EVICTION: + return HashZchSingleTtlScorer(config) + else: + return HashZchEvictionScorer(config) + + +class HashZchThresholdEvictionModule(torch.nn.Module): + """ + This module manages the computation of eviction score for input IDs. Based on the selected + eviction policy, a scorer is initiated to generate a score for each ID. The kernel + will use this score to make eviction decisions. + + Args: + policy_name: an enum value that indicates the eviction policy to use. + config: a config that contains information needed to run the eviction policy. + + Example:: + module = HashZchThresholdEvictionModule(...) + score = module(feature) + """ + + _eviction_scorer: HashZchEvictionScorer + + def __init__( + self, + policy_name: HashZchEvictionPolicyName, + config: HashZchEvictionConfig, + ) -> None: + super().__init__() + + self._policy_name: HashZchEvictionPolicyName = policy_name + self._config: HashZchEvictionConfig = config + self._eviction_scorer = get_eviction_scorer( + policy_name=self._policy_name, + config=self._config, + ) + + logger.info( + f"HashZchThresholdEvictionModule: {self._policy_name=}, {self._config=}" + ) + + def forward( + self, feature: JaggedTensor, device: torch.device + ) -> Tuple[torch.Tensor, int]: + """ + Args: + feature: a jagged tensor that contains the input IDs, and their lengths and + weights (feature split). + device: device of the tensor. + + Returns: + a tensor that contains the eviction score for each ID, plus an eviction threshold. + """ + return ( + self._eviction_scorer.gen_score(feature, device), + self._eviction_scorer.gen_threshold(), + ) + + +class HashZchOptEvictionModule(torch.nn.Module): + """ + This module manages the eviction of IDs from the ZCH table based on the selected eviction policy. + Args: + policy_name: an enum value that indicates the eviction policy to use. + Example: + module = HashZchOptEvictionModule(policy_name=HashZchEvictionPolicyName.LRU_EVICTION) + """ + + def __init__( + self, + policy_name: HashZchEvictionPolicyName, + ) -> None: + super().__init__() + + self._policy_name: HashZchEvictionPolicyName = policy_name + + def forward(self, feature: JaggedTensor, device: torch.device) -> Tuple[None, int]: + """ + Does not apply to this Eviction Policy. Returns None and -1. 
+ Args: + feature: No op + Returns: + None, -1 + """ + return None, -1 + + +@torch.fx.wrap +def get_eviction_module( + policy_name: HashZchEvictionPolicyName, config: Optional[HashZchEvictionConfig] +) -> torch.nn.Module: + if policy_name in ( + HashZchEvictionPolicyName.SINGLE_TTL_EVICTION, + HashZchEvictionPolicyName.PER_FEATURE_TTL_EVICTION, + HashZchEvictionPolicyName.LRU_EVICTION, + ): + return HashZchThresholdEvictionModule(policy_name, none_throws(config)) + else: + return HashZchOptEvictionModule(policy_name) + + +class HashZchEvictionModule(torch.nn.Module): + """ + This module manages the eviction of IDs from the ZCH table based on the selected eviction policy. + Args: + policy_name: an enum value that indicates the eviction policy to use. + device: device of the tensor. + config: an optional config required if threshold based eviction is selected. + Example: + module = HashZchEvictionModule(policy_name=HashZchEvictionPolicyName.LRU_EVICTION) + """ + + def __init__( + self, + policy_name: HashZchEvictionPolicyName, + device: torch.device, + config: Optional[HashZchEvictionConfig], + ) -> None: + super().__init__() + + self._policy_name: HashZchEvictionPolicyName = policy_name + self._device: torch.device = device + self._eviction_module: torch.nn.Module = get_eviction_module( + self._policy_name, config + ) + + logger.info(f"HashZchEvictionModule: {self._policy_name=}, {self._device=}") + + def forward(self, feature: JaggedTensor) -> Tuple[Optional[torch.Tensor], int]: + """ + Args: + feature: a jagged tensor that contains the input IDs, and their lengths and + weights (feature split). + + Returns: + For threshold eviction, a tensor that contains the eviction score for each ID, plus an eviction threshold. Otherwise None and -1. + """ + return self._eviction_module(feature, self._device) diff --git a/torchrec/modules/hash_mc_metrics.py b/torchrec/modules/hash_mc_metrics.py new file mode 100644 index 000000000..714cf8c2a --- /dev/null +++ b/torchrec/modules/hash_mc_metrics.py @@ -0,0 +1,200 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import logging +import time +from typing import Optional + +import torch + +from torchrec.modules.hash_mc_evictions import HashZchEvictionConfig + + +class ScalarLogger(torch.nn.Module): + """ + A logger to report various metrics related to multi-probe ZCH. + + Args: + name: name of the embedding table. + zch_size: size of the sharded embedding table. + frequency: frequency of reporting metrics. + start_bucket: start bucket of the rank. + + + Example:: + logger = ScalarLogger(...) 
+ logger(run_type, identities) + """ + + STEPS_BUFFER: str = "_scalar_logger_steps" + SECONDS_IN_HOUR: int = 3600 + MAX_HOURS: int = 2**31 - 1 + + def __init__( + self, + name: str, + zch_size: int, + frequency: int, + start_bucket: int, + log_file_path: str = "", + ) -> None: + super().__init__() + + self.register_buffer( + ScalarLogger.STEPS_BUFFER, + torch.tensor(1, dtype=torch.int64), + persistent=False, + ) + + self._name: str = name + self._zch_size: int = zch_size + self._frequency: int = frequency + self._start_bucket: int = start_bucket + + self._dtype_checked: bool = False + self._total_cnt: int = 0 + self._hit_cnt: int = 0 + self._insert_cnt: int = 0 + self._collision_cnt: int = 0 + self._eviction_cnt: int = 0 + self._opt_in_cnt: int = 0 + self._sum_eviction_age: float = 0.0 + + self.logger: logging.Logger = logging.getLogger() + if ( + log_file_path != "" + ): # if a log file path is provided, create a file handler to output logs to the file + file_handler = logging.FileHandler( + log_file_path, mode="w" + ) # initialize file handler + self.logger.addHandler(file_handler) # add file handler to logger + + def should_report(self) -> bool: + # We only need to report metrics from rank0 (start_bucket = 0) + + return ( + self._start_bucket == 0 + and self._total_cnt > 0 + and + # pyre-fixme[29]: `Union[(self: TensorBase, other: Any) -> Tensor, Tensor, + # Module]` is not a function. + self._scalar_logger_steps % self._frequency == 0 + ) + + def update( + self, + identities_0: torch.Tensor, + identities_1: torch.Tensor, + values: torch.Tensor, + remapped_ids: torch.Tensor, + evicted_emb_indices: Optional[torch.Tensor], + metadata: Optional[torch.Tensor], + num_reserved_slots: int, + eviction_config: Optional[HashZchEvictionConfig] = None, + ) -> None: + if not self._dtype_checked: + assert ( + identities_0.dtype == values.dtype + ), "identity type and feature type must match for meaningful metrics collection." 
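+        # Clarifying comment: identities_0 and identities_1 are expected to be the
+        # identity-table snapshots from before and after this batch was remapped;
+        # a slot equal to -1 is empty, so the insert count below is simply
+        # (empty slots before) - (empty slots after).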
+ self._dtype_checked = True + + remapped_identities_0 = torch.index_select(identities_0, 0, remapped_ids)[:, 0] + remapped_identities_1 = torch.index_select(identities_1, 0, remapped_ids)[:, 0] + empty_slot_cnt_before_process = remapped_identities_0 == -1 + empty_slot_cnt_after_process = remapped_identities_1 == -1 + insert_cnt = int(torch.sum(empty_slot_cnt_before_process).item()) - int( + torch.sum(empty_slot_cnt_after_process).item() + ) + + self._insert_cnt += insert_cnt + self._total_cnt += values.numel() + hits = torch.eq(remapped_identities_0, values) + hit_cnt = int(torch.sum(hits).item()) + self._hit_cnt += hit_cnt + self._collision_cnt += values.numel() - hit_cnt - insert_cnt + + opt_in_range = self._zch_size - num_reserved_slots + opt_in_ids = torch.lt(remapped_ids, opt_in_range) + self._opt_in_cnt += int(torch.sum(opt_in_ids).item()) + + if evicted_emb_indices is not None and evicted_emb_indices.numel() > 0: + deduped_evicted_indices = torch.unique(evicted_emb_indices) + self._eviction_cnt += deduped_evicted_indices.numel() + + assert ( + metadata is not None + ), "metadata cannot be None when evicted_emb_indices has values" + now_c = int(time.time()) + cur_hour = now_c / ScalarLogger.SECONDS_IN_HOUR % ScalarLogger.MAX_HOURS + if eviction_config is not None and eviction_config.single_ttl is not None: + self._sum_eviction_age += int( + torch.sum( + cur_hour + + eviction_config.single_ttl + - metadata[deduped_evicted_indices, 0] + ).item() + ) + + def forward( + self, + run_type: str, + identities: torch.Tensor, + ) -> None: + """ + Args: + run_type: type of the run (train, eval, etc). + identities: the identities tensor for metrics computation. + + Returns: + None + """ + + if self.should_report(): + hit_rate = round(self._hit_cnt / self._total_cnt, 3) + insert_rate = round(self._insert_cnt / self._total_cnt, 3) + collision_rate = round(self._collision_cnt / self._total_cnt, 3) + eviction_rate = round(self._eviction_cnt / self._total_cnt, 3) + total_unused_slots = int(torch.sum(identities[:, 0] == -1).item()) + table_usage_ratio = round( + (self._zch_size - total_unused_slots) / self._zch_size, 3 + ) + opt_in_rate = ( + round(self._opt_in_cnt / self._total_cnt, 3) + if self._total_cnt > 0 + else 0 + ) + avg_eviction_age = ( + round(self._sum_eviction_age / self._eviction_cnt, 3) + if self._eviction_cnt > 0 + else 0 + ) + + # log the metrics to console (if no log file path is provided) or to the file (if a log file path is provided) + self.logger.info( + f"{self._name=}, {run_type=}, " + f"{self._total_cnt=}, {self._hit_cnt=}, {hit_rate=}, " + f"{self._insert_cnt=}, {insert_rate=}, " + f"{self._collision_cnt=}, {collision_rate=}, " + f"{self._eviction_cnt=}, {eviction_rate=}, {avg_eviction_age=}, " + f"{self._opt_in_cnt=}, {opt_in_rate=}, " + f"{total_unused_slots=}, {table_usage_ratio=}" + ) + + # reset the counter after reporting + self._total_cnt = 0 + self._hit_cnt = 0 + self._insert_cnt = 0 + self._collision_cnt = 0 + self._eviction_cnt = 0 + self._opt_in_cnt = 0 + self._sum_eviction_age = 0.0 + + # pyre-ignore[16]: `ScalarLogger` has no attribute `_scalar_logger_steps`. + # pyre-ignore[29]: `Union[(self: TensorBase, other: Any) -> Tensor, Tensor, Module]` is not a function. + self._scalar_logger_steps += 1 diff --git a/torchrec/modules/hash_mc_modules.py b/torchrec/modules/hash_mc_modules.py new file mode 100644 index 000000000..81eb3138b --- /dev/null +++ b/torchrec/modules/hash_mc_modules.py @@ -0,0 +1,600 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. 
and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import logging +import math +from typing import Any, Dict, Iterator, List, Optional, Tuple + +import fbgemm_gpu # @manual=//deeplearning/fbgemm/fbgemm_gpu:fbgemm_gpu + +import torch + +from torchrec.modules.hash_mc_evictions import ( + get_kernel_from_policy, + HashZchEvictionConfig, + HashZchEvictionModule, + HashZchEvictionPolicyName, +) +from torchrec.modules.hash_mc_metrics import ScalarLogger +from torchrec.modules.mc_modules import ManagedCollisionModule +from torchrec.sparse.jagged_tensor import JaggedTensor + +logger: logging.Logger = logging.getLogger(__name__) + + +@torch.fx.wrap +def _tensor_may_to_device( + src: torch.Tensor, + device: Optional[torch.device] = None, +) -> Tuple[torch.Tensor, torch.device]: + src_device: torch.device = src.device + if device is None: + return (src, src_device) + + if device.type != "meta" and src_device != device: + return (src.to(device), src_device) + return (src, src_device) + + +class TrainInputMapper(torch.nn.Module): + """ + Module used to generate sizes and offsets information corresponding to + the train ranks for inference inputs. This is due to we currently merge + all identity tensors that are row-wise sharded across training ranks at + inference time. So we need to map the inputs to the chunk of identities + that the input would go at training time to generate appropriate indices. + + Args: + input_hash_size: the max size of input IDs + total_num_buckets: the total number of buckets across all ranks at training time + size_per_rank: the size of the identity tensor/embedding size per rank + train_rank_offsets: the offset of the embedding table indices per rank + inference_dispatch_div_train_world_size: the flag to control whether to divide input by + world_size https://fburl.com/code/c9x98073 + name: the name of the embedding table + + Example:: + mapper = TrainInputMapper(...) + mapper(values, output_offset) + """ + + def __init__( + self, + input_hash_size: int, + total_num_buckets: int, + size_per_rank: torch.Tensor, + train_rank_offsets: torch.Tensor, + inference_dispatch_div_train_world_size: bool = False, + name: Optional[str] = None, + ) -> None: + super().__init__() + + self._input_hash_size = input_hash_size + assert total_num_buckets > 0, f"{total_num_buckets=} must be positive" + self._buckets = total_num_buckets + self._inference_dispatch_div_train_world_size = ( + inference_dispatch_div_train_world_size + ) + self._name = name + self.register_buffer( + "_zch_size_per_training_rank", size_per_rank, persistent=False + ) + self.register_buffer( + "_train_rank_offsets", train_rank_offsets, persistent=False + ) + logger.info( + f"TrainInputMapper: {self._name=}, {self._input_hash_size=}, {self._zch_size_per_training_rank=}, " + f"{self._train_rank_offsets=}, {self._inference_dispatch_div_train_world_size=}" + ) + + # TODO: make a kernel + def _get_values_sizes_offsets( + self, x: torch.Tensor, output_offset: Optional[torch.Tensor] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + zch_size_per_training_rank, _ = _tensor_may_to_device( + self._zch_size_per_training_rank, x.device + ) + train_rank_offsets, _ = _tensor_may_to_device( + self._train_rank_offsets, x.device + ) + + # NOTE: This assumption has to be the same as TorchRec input_dist logic + # https://fburl.com/code/c9x98073. 
Do not use torch.where() for performance. + if self._input_hash_size == 0: + train_ranks = x % self._buckets + if self._inference_dispatch_div_train_world_size: + x = x // self._buckets + else: + blk_size = (self._input_hash_size // self._buckets) + ( + 0 if self._input_hash_size % self._buckets == 0 else 1 + ) + train_ranks = x // blk_size + if self._inference_dispatch_div_train_world_size: + x = x % blk_size + + local_sizes = zch_size_per_training_rank.index_select( + dim=0, index=train_ranks + ) # This line causes error where zch_size_per_training_rank = tensor([25000, 25000, 25000, 25000], device='cuda:1') and train_ranks = tensor([291, 34, 15], device='cuda:1'), leading to index error: index out of range + offsets = train_rank_offsets.index_select(dim=0, index=train_ranks) + if output_offset is not None: + offsets -= output_offset + + return (x, local_sizes, offsets) + + def forward( + self, + values: torch.Tensor, + output_offset: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Args: + values: ID values to compute bucket assignment and offset. + output_offset: global offset of the start bucket per rank, used to compute bucket offset within the rank. + + Returns: + A tuple of three tensors: + - values: transformed ID values, different from input value only if inference_dispatch_div_train_world_size is True. + - local_sizes: bucket sizes of the input values. + - offsets: in-rank bucket offsets of the input values. + """ + + values, local_sizes, offsets = self._get_values_sizes_offsets( + values, output_offset + ) + return (values, local_sizes, offsets) + + +@torch.fx.wrap +def _get_device(hash_zch_identities: torch.Tensor) -> torch.device: + return hash_zch_identities.device + + +class HashZchManagedCollisionModule(ManagedCollisionModule): + """ + Module to manage multi-probe ZCH (MPZCH), including lookup (remapping), eviction, metrics collection, and required auxiliary tensors. + + Args: + zch_size: local size of the embedding table + device: the compute device + total_num_buckets: logical shard within each rank for resharding purpose, note that + 1) zch_size must be a multiple of total_num_buckets, and 2) total_num_buckets must be a multiple of world size + max_probe: the number of times MPZCH kernel attempts to run linear search for lookup or insertion + input_hash_size: the max size of input IDs (default to 0) + output_segments: the index range of each bucket, which is computed before sharding and typically not provided by user + is_inference: the flag to indicate if the module is used in inference as opposed to train or eval + name: the name of the embedding table + tb_logging_frequency: the frequency of emitting metrics to TensorBoard, measured by the number of batches + eviction_policy_name: the specific policy to be used for eviction operations + eviction_config: the config associated with the selected eviction policy + inference_dispatch_div_train_world_size: the flag to control whether to divide input by + world_size https://fburl.com/code/c9x98073 + start_bucket: start bucket of the current rank, typically not provided by user + end_bucket: end bucket of the current rank, typically not provided by user + opt_in_prob: the probability of an ID to be opted in from a statistical aspect + percent_reserved_slots: percentage of slots to be reserved when opt-in is enabled, the value must be in [0, 100) + + Example:: + module = HashZchManagedCollisionModule(...) 
+ module(features) + """ + + _evicted_indices: List[torch.Tensor] + + IDENTITY_BUFFER: str = "_hash_zch_identities" + METADATA_BUFFER: str = "_hash_zch_metadata" + + def __init__( + self, + zch_size: int, + device: torch.device, + total_num_buckets: int, + max_probe: int = 128, + input_hash_size: int = 0, + output_segments: Optional[List[int]] = None, + is_inference: bool = False, + name: Optional[str] = None, + tb_logging_frequency: int = 0, + eviction_policy_name: Optional[HashZchEvictionPolicyName] = None, + eviction_config: Optional[HashZchEvictionConfig] = None, + inference_dispatch_div_train_world_size: bool = False, + start_bucket: int = 0, + end_bucket: Optional[int] = None, + opt_in_prob: int = -1, + percent_reserved_slots: float = 0, + ) -> None: + if output_segments is None: + assert ( + zch_size % total_num_buckets == 0 + ), f"please pass output segments if not uniform buckets {zch_size=}, {total_num_buckets=}" + output_segments = [ + (zch_size // total_num_buckets) * bucket + for bucket in range(total_num_buckets + 1) + ] + + super().__init__( + device=device, + output_segments=output_segments, + skip_state_validation=True, # avoid peristent buffers for TGIF Puslishing + ) + + self._zch_size: int = zch_size + self._output_segments: List[int] = output_segments + self._start_bucket: int = start_bucket + self._end_bucket: int = ( + end_bucket if end_bucket is not None else total_num_buckets + ) + self._output_global_offset_tensor: Optional[torch.Tensor] = None + if output_segments[start_bucket] > 0: + self._output_global_offset_tensor = torch.tensor( + [output_segments[start_bucket]], + dtype=torch.int64, + device=device if device.type != "meta" else torch.device("cpu"), + ) + + self._device: torch.device = device + self._input_hash_size: int = input_hash_size + self._is_inference: bool = is_inference + self._name: Optional[str] = name + self._tb_logging_frequency: int = tb_logging_frequency + self._scalar_logger: Optional[ScalarLogger] = None + self._eviction_policy_name: Optional[HashZchEvictionPolicyName] = ( + eviction_policy_name + ) + self._eviction_config: Optional[HashZchEvictionConfig] = eviction_config + self._eviction_module: Optional[HashZchEvictionModule] = ( + HashZchEvictionModule( + policy_name=self._eviction_policy_name, + device=self._device, + config=self._eviction_config, + ) + if self._eviction_policy_name is not None and self.training + else None + ) + self._opt_in_prob: int = opt_in_prob + assert ( + percent_reserved_slots >= 0 and percent_reserved_slots < 100 + ), "percent_reserved_slots must be in [0, 100)" + self._percent_reserved_slots: float = percent_reserved_slots + if self._opt_in_prob > 0: + assert ( + self._percent_reserved_slots > 0 + ), "percent_reserved_slots must be positive when opt_in_prob is positive" + assert ( + self._eviction_policy_name is None + or self._eviction_policy_name != HashZchEvictionPolicyName.LRU_EVICTION + ), "LRU eviction is not compatible with opt-in at this time" + + if torch.jit.is_scripting() or self._is_inference or self._name is None: + self._tb_logging_frequency = 0 + + if self._tb_logging_frequency > 0 and self._device.type != "meta": + assert self._name is not None + self._scalar_logger = ScalarLogger( + name=self._name, + zch_size=self._zch_size, + frequency=self._tb_logging_frequency, + start_bucket=self._start_bucket, + ) + else: + logger.info( + f"ScalarLogger is disabled because {self._tb_logging_frequency=} and {self._device.type=}" + ) + + identities, metadata = torch.ops.fbgemm.create_zch_buffer( + 
size=self._zch_size, + support_evict=self._eviction_module is not None, + device=self._device, + long_type=True, # deprecated, always True + ) + + self._hash_zch_identities = torch.nn.Parameter(identities, requires_grad=False) + self.register_buffer(HashZchManagedCollisionModule.METADATA_BUFFER, metadata) + + self._max_probe = max_probe + self._buckets = total_num_buckets + # Do not need to store in buffer since this is created and consumed + # at each step https://fburl.com/code/axzimmbx + self._evicted_indices = [] + + # do not pass device, so its initialized on default physical device ('meta' will result in silent failure) + size_per_rank = torch.diff( + torch.tensor(self._output_segments, dtype=torch.int64) + ) + + self.input_mapper: torch.nn.Module = TrainInputMapper( + input_hash_size=self._input_hash_size, + total_num_buckets=total_num_buckets, + size_per_rank=size_per_rank, + train_rank_offsets=torch.tensor( + torch.ops.fbgemm.asynchronous_exclusive_cumsum(size_per_rank) + ), + # be consistent with https://fburl.com/code/p4mj4mc1 + inference_dispatch_div_train_world_size=inference_dispatch_div_train_world_size, + name=self._name, + ) + + if self._is_inference is True: + self.reset_inference_mode() + + # create two dictionaries to store the input values and remapped ids on the current rank + # these values are used for calculating zch metrics like hit rate and collision rate + ## on-device remapped ids + self.table_name_on_device_remapped_ids_dict: Dict[str, torch.Tensor] = ( + {} + ) # {table_name: on_device_remapped_ids} + ## on-device input ids + self.table_name_on_device_input_ids_dict: Dict[str, torch.Tensor] = ( + {} + ) # {table_name: input JT values that maps to the current rank} + + logger.info( + f"HashZchManagedCollisionModule: {self._name=}, {self.device=}, " + f"{self._zch_size=}, {self._input_hash_size=}, {self._max_probe=}, " + f"{self._is_inference=}, {self._tb_logging_frequency=}, " + f"{self._eviction_policy_name=}, {self._eviction_config=}, " + f"{self._buckets=}, {self._start_bucket=}, {self._end_bucket=}, " + f"{self._output_global_offset_tensor=}, {self._output_segments=}, " + f"{inference_dispatch_div_train_world_size=}, " + f"{self._opt_in_prob=}, {self._percent_reserved_slots=}" + ) + + @property + def device(self) -> torch.device: + return _get_device(self._hash_zch_identities) + + def buckets(self) -> int: + return self._buckets + + # TODO: This is hacky as we are using parameters to go through publishing. + # Can remove once working out buffer solution. 
+ def named_buffers( + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True + ) -> Iterator[Tuple[str, torch.Tensor]]: + yield from super().named_buffers(prefix, recurse, remove_duplicate) + key: str = HashZchManagedCollisionModule.IDENTITY_BUFFER + if prefix: + key = f"{prefix}.{key}" + yield (key, self._hash_zch_identities.data) + + def validate_state(self) -> None: + raise NotImplementedError() + + def reset_inference_mode( + self, + ) -> None: + logger.info("HashZchManagedCollisionModule resetting inference mode") + # not revertable + self.eval() + self._is_inference = True + self._hash_zch_metadata = None + self._evicted_indices = [] + self._eviction_policy_name = None + self._eviction_module = None + + def _load_state_dict_pre_hook( + module: "HashZchManagedCollisionModule", + state_dict: Dict[str, Any], + prefix: str, + *args: Any, + ) -> None: + logger.info("HashZchManagedCollisionModule loading state dict") + # We store the full identity in checkpoint and predictor, cut it at inference loading + if not self._is_inference: + return + if "_hash_zch_metadata" in state_dict: + del state_dict["_hash_zch_metadata"] + + self._register_load_state_dict_pre_hook( + _load_state_dict_pre_hook, with_module=True + ) + + def preprocess( + self, + features: Dict[str, JaggedTensor], + ) -> Dict[str, JaggedTensor]: + return features + + def evict(self) -> Optional[torch.Tensor]: + if len(self._evicted_indices) == 0: + return None + out = torch.unique(torch.cat(self._evicted_indices)) + self._evicted_indices = [] + return ( + out + self._output_global_offset_tensor + if self._output_global_offset_tensor + else out + ) + + def profile( + self, + features: Dict[str, JaggedTensor], + ) -> Dict[str, JaggedTensor]: + return features + + def get_reserved_slots_per_bucket(self) -> int: + if self._opt_in_prob == -1: + return -1 + + return math.floor( + self._zch_size + * self._percent_reserved_slots + / 100 + / (self._end_bucket - self._start_bucket) + ) + + def remap(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]: + metadata: Optional[torch.Tensor] = self._hash_zch_metadata + readonly: bool = False + if self._output_global_offset_tensor is not None: + self._output_global_offset_tensor, _ = _tensor_may_to_device( + self._output_global_offset_tensor, self.device + ) + if not self.training: + readonly = True + metadata = None + + # _evicted_indices will be reset in evict(): https://fburl.com/code/r3fxcs1y + assert len(self._evicted_indices) == 0 + + # `torch.no_grad()` Annotatin prevents torchscripting `JaggedTensor` for some reason... 
+ with torch.no_grad(): + remapped_features: Dict[str, JaggedTensor] = {} + identities_0 = ( + self._hash_zch_identities.data.clone() + if self._tb_logging_frequency > 0 + else None + ) + + for name, feature in features.items(): + values = feature.values() + input_metadata, eviction_threshold = ( + self._eviction_module(feature) + if self._eviction_module is not None + else (None, -1) + ) + + opt_in_rands = ( + (torch.rand_like(values, dtype=torch.float) * 100).to(torch.int32) + if self._opt_in_prob != -1 and self.training + else None + ) + + values, orig_device = _tensor_may_to_device(values, self.device) + values, local_sizes, offsets = self.input_mapper( + values=values, + output_offset=self._output_global_offset_tensor, + ) + + self.table_name_on_device_input_ids_dict[name] = values.clone() + + num_reserved_slots = self.get_reserved_slots_per_bucket() + remapped_ids, evictions = torch.ops.fbgemm.zero_collision_hash( + input=values, + identities=self._hash_zch_identities, + max_probe=self._max_probe, + circular_probe=True, + exp_hours=-1, # deprecated, always -1 + readonly=readonly, + local_sizes=local_sizes, + offsets=offsets, + metadata=metadata, + # Use self._is_inference to turn on writing to pinned + # CPU memory directly. But may not have perf benefit. + output_on_uvm=False, # self._is_inference, + disable_fallback=False, + _modulo_identity_DPRECATED=False, # deprecated, always False + input_metadata=input_metadata, + eviction_threshold=eviction_threshold, + eviction_policy=get_kernel_from_policy(self._eviction_policy_name), + opt_in_prob=self._opt_in_prob, + num_reserved_slots=num_reserved_slots, + opt_in_rands=opt_in_rands, + ) + + # record the on-device remapped ids + self.table_name_on_device_remapped_ids_dict[name] = remapped_ids.clone() + + if self._scalar_logger is not None: + assert identities_0 is not None + self._scalar_logger.update( + identities_0=identities_0, + identities_1=self._hash_zch_identities, + values=values, + remapped_ids=remapped_ids, + evicted_emb_indices=evictions, + metadata=metadata, + num_reserved_slots=num_reserved_slots, + eviction_config=self._eviction_config, + ) + + output_global_offset_tensor = self._output_global_offset_tensor + if output_global_offset_tensor is not None: + remapped_ids = remapped_ids + output_global_offset_tensor + + _append_eviction_indice(self._evicted_indices, evictions) + remapped_ids, _ = _tensor_may_to_device(remapped_ids, orig_device) + + remapped_features[name] = JaggedTensor( + values=remapped_ids, + lengths=feature.lengths(), + offsets=feature.offsets(), + weights=feature.weights_or_none(), + ) + + # if name == "t_cat_0": + # print("remapped_feature", remapped_ids) + + if self._scalar_logger is not None: + self._scalar_logger( + run_type="train" if self.training else "eval", + identities=self._hash_zch_identities.data, + ) + + return remapped_features + + def forward( + self, + features: Dict[str, JaggedTensor], + ) -> Dict[str, JaggedTensor]: + return self.remap(features) + + def output_size(self) -> int: + return self._zch_size + + def input_size(self) -> int: + return self._input_hash_size + + def open_slots(self) -> torch.Tensor: + return torch.tensor([0]) + + def rebuild_with_output_id_range( + self, + output_id_range: Tuple[int, int], + output_segments: Optional[List[int]] = None, + device: Optional[torch.device] = None, + ) -> "HashZchManagedCollisionModule": + # rebuild should use existing output_segments instead of the input one and should not + # recalculate since the output segments are calculated based on 
the original embedding
+        # table size, total bucket number, which might not be available for the rebuild caller
+        try:
+            start_idx = self._output_segments.index(output_id_range[0])
+            end_idx = self._output_segments.index(output_id_range[1])
+        except ValueError:
+            raise RuntimeError(
+                f"Attempting to shard HashZchManagedCollisionModule, but rank {device} does not align with bucket boundaries;"
+                + f" please check kwarg total_num_buckets={self._buckets} is a multiple of world size."
+            )
+        new_zch_size = output_id_range[1] - output_id_range[0]
+
+        return self.__class__(
+            zch_size=new_zch_size,
+            device=device or self.device,
+            max_probe=self._max_probe,
+            total_num_buckets=self._buckets,
+            input_hash_size=self._input_hash_size,
+            is_inference=self._is_inference,
+            start_bucket=start_idx,
+            end_bucket=end_idx,
+            output_segments=self._output_segments,
+            name=self._name,
+            tb_logging_frequency=self._tb_logging_frequency,
+            eviction_policy_name=self._eviction_policy_name,
+            eviction_config=self._eviction_config,
+            opt_in_prob=self._opt_in_prob,
+            percent_reserved_slots=self._percent_reserved_slots,
+        )
+
+
+@torch.fx.wrap
+def _append_eviction_indice(
+    evicted_indices: List[torch.Tensor],
+    evictions: Optional[torch.Tensor],
+) -> None:
+    if evictions is not None and evictions.numel() > 0:
+        evicted_indices.append(evictions)
diff --git a/torchrec/modules/mc_adapter.py b/torchrec/modules/mc_adapter.py
new file mode 100644
index 000000000..fd971e389
--- /dev/null
+++ b/torchrec/modules/mc_adapter.py
@@ -0,0 +1,188 @@
+# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+# pyre-strict
+import sys
+from typing import Dict, Iterator, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from torch.nn.parameter import Parameter
+from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig
+from torchrec.modules.embedding_modules import (
+    EmbeddingBagCollection,
+    EmbeddingCollection,
+)
+from torchrec.modules.hash_mc_evictions import (
+    HashZchEvictionConfig,
+    HashZchEvictionPolicyName,
+)
+from torchrec.modules.hash_mc_modules import HashZchManagedCollisionModule
+from torchrec.modules.mc_embedding_modules import (
+    ManagedCollisionEmbeddingBagCollection,
+    ManagedCollisionEmbeddingCollection,
+)
+from torchrec.modules.mc_modules import (
+    DistanceLFU_EvictionPolicy,
+    ManagedCollisionCollection,
+    MCHManagedCollisionModule,
+)
+from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
+
+
+class McEmbeddingCollectionAdapter(nn.Module):
+    """
+    Managed Collision Embedding Collection Adapter.
+    The adapter converts an existing EmbeddingCollection into a managed collision
+    embedding collection module. It reuses the original EmbeddingCollection tables
+    but passes input IDs through managed collision modules before lookup.
+    """
+
+    def __init__(
+        self,
+        embedding_collection: EmbeddingCollection,
+        input_hash_size: int,
+        device: torch.device,
+        eviction_interval: int = 2,
+        allow_in_place_embed_weight_update: bool = False,
+    ) -> None:
+        """
+        Initialize a McEmbeddingCollectionAdapter.
+        Parameters:
+            embedding_collection (EmbeddingCollection): the existing EmbeddingCollection to wrap
+            input_hash_size (int): the upper bound of input feature values
+            device (torch.device): the device to use
+            eviction_interval (int): the eviction interval in hours, default to 2 hours
+            allow_in_place_embed_weight_update (bool): whether to allow in-place embedding weight update
+        """
+        super().__init__()
+        # build dictionary for {table_name: managed collision module}
+        mc_modules = {}
+        for table_config in embedding_collection.embedding_configs():
+            table_name = table_config.name
+            mc_modules[table_name] = MCHManagedCollisionModule(
+                zch_size=table_config.num_embeddings,
+                device=device,
+                input_hash_size=input_hash_size,
+                eviction_interval=eviction_interval,
+                eviction_policy=DistanceLFU_EvictionPolicy(),
+            )
+        self.mc_embedding_collection = ManagedCollisionEmbeddingCollection(
+            embedding_collection=embedding_collection,
+            managed_collision_collection=ManagedCollisionCollection(
+                managed_collision_modules=mc_modules,
+                embedding_configs=embedding_collection.embedding_configs(),
+            ),
+            allow_in_place_embed_weight_update=allow_in_place_embed_weight_update,
+            return_remapped_features=True,  # return the remapped ids alongside the embeddings
+        )
+        self.remapped_ids: Optional[KeyedJaggedTensor] = None  # to store remapped ids
+
+    def forward(self, input: KeyedJaggedTensor) -> Dict[str, JaggedTensor]:
+        """
+        Args:
+            input (KeyedJaggedTensor): KJT of form [F X B X L].
+        Returns:
+            Dict[str, JaggedTensor]: dictionary of {'feature_name': JaggedTensor}
+        """
+        mc_ec_out, remapped_ids = self.mc_embedding_collection(input)
+        self.remapped_ids = remapped_ids
+        return mc_ec_out
+
+
+class McEmbeddingBagCollectionAdapter(nn.Module):
+    """
+    Managed Collision Embedding Bag Collection Adapter.
+    The adapter converts an existing EmbeddingBagCollection into a managed collision
+    embedding bag collection module. It reuses the original EmbeddingBagCollection
+    tables but passes input IDs through managed collision modules before lookup.
+    """
+
+    def __init__(
+        self,
+        tables: List[EmbeddingBagConfig],
+        input_hash_size: int,
+        device: torch.device,
+        world_size: int,
+        eviction_interval: int = 1,
+        allow_in_place_embed_weight_update: bool = False,
+        use_mpzch: bool = False,
+        mpzch_num_buckets: Optional[int] = None,
+    ) -> None:
+        """
+        Initialize a McEmbeddingBagCollectionAdapter.
+        Parameters:
+            tables (List[EmbeddingBagConfig]): list of EmbeddingBagConfig, should be the same as the original EmbeddingBagCollection
+            input_hash_size (int): the upper bound of input feature values
+            device (torch.device): the device to use
+            world_size (int): the world size
+            eviction_interval (int): the eviction interval, default to 1 hour
+            allow_in_place_embed_weight_update (bool): whether to allow in-place embedding weight update
+            use_mpzch (bool): whether to use MPZCH or not  # TODO: change this to a str to support different zch
+            mpzch_num_buckets (Optional[int]): the number of buckets for MPZCH  # TODO: change this to a config dict to support different zch configs
+        """
+        # super().__init__(tables=tables, device=device)
+        super().__init__()
+        # create ebc from table configs
+        ebc = EmbeddingBagCollection(tables=tables, device=torch.device("meta"))
+        # build dictionary for {table_name: managed collision module}
+        mc_modules = {}
+        for table_config in ebc.embedding_bag_configs():
+            table_name = table_config.name
+            if use_mpzch:
+                # if use_mpzch is True, create a HashZchManagedCollisionModule
+                mc_modules[table_name] = HashZchManagedCollisionModule(  # MPZCH
+                    is_inference=False,
+                    zch_size=(table_config.num_embeddings),
+                    input_hash_size=input_hash_size,
+                    device=device,
+                    total_num_buckets=(
+                        mpzch_num_buckets if mpzch_num_buckets else world_size
+                    ),  # if total_num_buckets is not passed, use world_size; world_size should be a factor of total_num_buckets
+                    eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION,  # default to the single-TTL eviction policy
+                    eviction_config=HashZchEvictionConfig(
+                        features=table_config.feature_names,
+                        single_ttl=eviction_interval,
+                    ),
+                )
+            else:  # otherwise, create a MCHManagedCollisionModule using the sorted ZCH algorithm
+                mc_modules[table_name] = MCHManagedCollisionModule(  # sorted ZCH
+                    zch_size=table_config.num_embeddings,
+                    device=device,
+                    input_hash_size=input_hash_size,
+                    eviction_interval=eviction_interval,
+                    eviction_policy=DistanceLFU_EvictionPolicy(),
+                )  # NOTE: the benchmark for sorted ZCH is not implemented yet
+        # TODO: add the pure hash module here
+
+        # create the mcebc module with the mc modules and the original ebc
+        self.mc_embedding_bag_collection = (
+            ManagedCollisionEmbeddingBagCollection(  # ZCH or not
+                embedding_bag_collection=ebc,
+                managed_collision_collection=ManagedCollisionCollection(
+                    managed_collision_modules=mc_modules,
+                    embedding_configs=ebc.embedding_bag_configs(),
+                ),
+                allow_in_place_embed_weight_update=allow_in_place_embed_weight_update,
+                return_remapped_features=False,  # do not return remapped features
+            )
+        )
+        self.per_table_remapped_id: Optional[KeyedJaggedTensor] = None
+
+    def forward(self, input_kjt: KeyedJaggedTensor) -> KeyedTensor:
+        """
+        Args:
+            input_kjt (KeyedJaggedTensor): KJT of form [F X B X L].
+        Returns:
+            KeyedTensor: pooled embeddings keyed by feature name
+        """
+        mc_ebc_out, per_table_remapped_id = self.mc_embedding_bag_collection(input_kjt)
+        self.per_table_remapped_id = per_table_remapped_id
+        return mc_ebc_out
+
+    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
+        # only return the parameters of the original EmbeddingBagCollection, not the _managed_collision_collection modules
+        return self.mc_embedding_bag_collection._embedding_module.parameters(
+            recurse=recurse
+        )
+
+    def embedding_bag_configs(self) -> List[EmbeddingBagConfig]:
+        """
+        Returns:
+            List[EmbeddingBagConfig]: the embedding bag configs of the wrapped EmbeddingBagCollection
+        """
+        # pyre-ignore[16]: `ManagedCollisionEmbeddingBagCollection` has no attribute `_embedding_module`
+        return (
+            self.mc_embedding_bag_collection._embedding_module.embedding_bag_configs()
+        )
+
+    def get_per_table_remapped_id(self) -> Optional[KeyedJaggedTensor]:
+        return self.per_table_remapped_id
diff --git a/torchrec/modules/tests/test_hash_mc_evictions.py b/torchrec/modules/tests/test_hash_mc_evictions.py
new file mode 100644
index 000000000..e62b0d819
--- /dev/null
+++ b/torchrec/modules/tests/test_hash_mc_evictions.py
@@ -0,0 +1,72 @@
+#!/usr/bin/env python3
+# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+ +# pyre-strict + +import unittest +from unittest.mock import patch + +import torch +from torchrec.modules.hash_mc_evictions import ( + HashZchEvictionConfig, + HashZchPerFeatureTtlScorer, + HashZchSingleTtlScorer, +) +from torchrec.sparse.jagged_tensor import JaggedTensor + + +class TestEvictionScorer(unittest.TestCase): + # pyre-ignore [56] + @unittest.skipIf( + torch.cuda.device_count() < 1, + "This test requires CUDA device", + ) + def test_single_ttl_scorer(self) -> None: + scorer = HashZchSingleTtlScorer( + config=HashZchEvictionConfig(features=["f1"], single_ttl=24) + ) + + jt = JaggedTensor( + values=torch.arange(0, 5, dtype=torch.int64), + lengths=torch.tensor([2, 2, 1], dtype=torch.int64), + ) + + with patch("time.time") as mock_time: + mock_time.return_value = 36000000 # hour 10000 + score = scorer.gen_score(jt, device=torch.device("cuda")) + self.assertTrue( + torch.equal( + score, + torch.tensor([10024, 10024, 10024, 10024, 10024], device="cuda"), + ), + f"{torch.unique(score)=}", + ) + + # pyre-ignore [56] + @unittest.skipIf( + torch.cuda.device_count() < 1, + "This test requires CUDA device", + ) + def test_per_feature_ttl_scorer(self) -> None: + scorer = HashZchPerFeatureTtlScorer( + config=HashZchEvictionConfig( + features=["f1", "f2"], per_feature_ttl=[24, 48] + ) + ) + + jt = JaggedTensor( + values=torch.arange(0, 5, dtype=torch.int64), + lengths=torch.tensor([2, 2, 1], dtype=torch.int64), + weights=torch.tensor([4, 1], dtype=torch.int64), + ) + + with patch("time.time") as mock_time: + mock_time.return_value = 36000000 # hour 10000 + score = scorer.gen_score(jt, device=torch.device("cuda")) + self.assertTrue( + torch.equal( + score, + torch.tensor([10024, 10024, 10024, 10024, 10048], device="cuda"), + ), + f"{torch.unique(score)=}", + ) diff --git a/torchrec/modules/tests/test_hash_mc_modules.py b/torchrec/modules/tests/test_hash_mc_modules.py new file mode 100644 index 000000000..113c05b5b --- /dev/null +++ b/torchrec/modules/tests/test_hash_mc_modules.py @@ -0,0 +1,650 @@ +#!/usr/bin/env python3 +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
+ +# pyre-strict + +import unittest +from typing import cast + +import torch +from hypothesis import given, settings, strategies as st +from pyre_extensions import none_throws +from torchrec.distributed.embedding_sharding import bucketize_kjt_before_all2all +from torchrec.modules.embedding_configs import EmbeddingConfig +from torchrec.modules.hash_mc_evictions import ( + HashZchEvictionConfig, + HashZchEvictionPolicyName, +) +from torchrec.modules.hash_mc_modules import HashZchManagedCollisionModule +from torchrec.modules.mc_modules import ( + ManagedCollisionCollection, + ManagedCollisionModule, +) +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor + + +class TestMCH(unittest.TestCase): + # pyre-ignore[56] + @unittest.skipIf( + torch.cuda.device_count() < 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + def test_zch_hash_inference(self) -> None: + # prepare + m1 = HashZchManagedCollisionModule( + zch_size=20, + device=torch.device("cuda"), + total_num_buckets=2, + eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION, + eviction_config=HashZchEvictionConfig( + features=[], + single_ttl=10, + ), + ) + self.assertEqual(m1._hash_zch_identities.dtype, torch.int64) + in1 = { + "f": JaggedTensor( + values=torch.arange(0, 20, 2, dtype=torch.int64, device="cuda"), + lengths=torch.tensor([4, 6], dtype=torch.int64, device="cuda"), + ), + } + o1 = m1(in1)["f"].values() + self.assertTrue( + torch.equal(torch.unique(o1), torch.arange(0, 10, device="cuda")), + f"{torch.unique(o1)=}", + ) + + in2 = { + "f": JaggedTensor( + values=torch.arange(1, 20, 2, dtype=torch.int64, device="cuda"), + lengths=torch.tensor([8, 2], dtype=torch.int64, device="cuda"), + ), + } + o2 = m1(in2)["f"].values() + self.assertTrue( + torch.equal(torch.unique(o2), torch.arange(10, 20, device="cuda")), + f"{torch.unique(o2)=}", + ) + + for device_str in ["cpu", "cuda"]: + # Inference + m_infer = HashZchManagedCollisionModule( + zch_size=20, + device=torch.device(device_str), + total_num_buckets=2, + ) + + m_infer.reset_inference_mode() + m_infer.to(device_str) + + self.assertTrue( + torch.equal( + # pyre-fixme[6]: For 1st argument expected `Tensor` but got + # `Union[Tensor, Module]`. + none_throws(m_infer.input_mapper._zch_size_per_training_rank), + torch.tensor([10, 10], dtype=torch.int64, device=device_str), + ) + ) + self.assertTrue( + torch.equal( + # pyre-fixme[6]: For 1st argument expected `Tensor` but got + # `Union[Tensor, Module]`. 
+ none_throws(m_infer.input_mapper._train_rank_offsets), + torch.tensor([0, 10], dtype=torch.int64, device=device_str), + ) + ) + + m_infer._hash_zch_identities = torch.nn.Parameter( + m1._hash_zch_identities[:, :1], + requires_grad=False, + ) + in12 = { + "f": JaggedTensor( + values=torch.arange(0, 20, dtype=torch.int64, device=device_str), + lengths=torch.tensor( + [4, 6, 8, 2], dtype=torch.int64, device=device_str + ), + ), + } + m_infer = torch.jit.script(m_infer) + o_infer = m_infer(in12)["f"].values() + o12 = torch.stack([o1, o2], dim=1).view(-1).to(device_str) + self.assertTrue(torch.equal(o_infer, o12), f"{o_infer=} vs {o12=}") + + m3 = HashZchManagedCollisionModule( + zch_size=10, + device=torch.device("cuda"), + total_num_buckets=2, + eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION, + eviction_config=HashZchEvictionConfig( + features=[], + single_ttl=10, + ), + ) + self.assertEqual(m3._hash_zch_identities.dtype, torch.int64) + in3 = { + "f": JaggedTensor( + values=torch.arange(10, 20, dtype=torch.int64, device="cuda"), + lengths=torch.tensor([4, 6], dtype=torch.int64, device="cuda"), + ), + } + o3 = m3(in3)["f"].values() + self.assertTrue( + torch.equal(torch.unique(o3), torch.arange(0, 10, device="cuda")), + f"{torch.unique(o3)=}", + ) + # validate that original ids are assigned to identities + self.assertTrue( + torch.equal( + torch.unique(m3._hash_zch_identities), + torch.arange(10, 20, device="cuda"), + ), + f"{torch.unique(m3._hash_zch_identities)=}", + ) + + # pyre-ignore[56] + @unittest.skipIf( + torch.cuda.device_count() < 1, + "This test requires CUDA device", + ) + def test_scriptability(self) -> None: + zch_size = 10 + mc_modules = { + "t1": cast( + ManagedCollisionModule, + HashZchManagedCollisionModule( + zch_size=zch_size, + device=torch.device("cpu"), + eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION, + eviction_config=HashZchEvictionConfig( + features=["feature"], + ), + total_num_buckets=2, + ), + ) + } + + embedding_configs = [ + EmbeddingConfig( + name="t1", + embedding_dim=8, + num_embeddings=zch_size, + feature_names=["f1", "f2"], + ), + ] + + mcc_ec = ManagedCollisionCollection( + managed_collision_modules=mc_modules, + embedding_configs=embedding_configs, + ) + torch.jit.script(mcc_ec) + + # pyre-ignore[56] + @unittest.skipIf( + torch.cuda.device_count() < 1, + "This test requires CUDA device", + ) + def test_scriptability_lru(self) -> None: + zch_size = 10 + mc_modules = { + "t1": cast( + ManagedCollisionModule, + HashZchManagedCollisionModule( + zch_size=zch_size, + device=torch.device("cpu"), + total_num_buckets=2, + eviction_policy_name=HashZchEvictionPolicyName.LRU_EVICTION, + eviction_config=HashZchEvictionConfig( + features=["feature"], + single_ttl=12, + ), + ), + ) + } + + embedding_configs = [ + EmbeddingConfig( + name="t1", + embedding_dim=8, + num_embeddings=zch_size, + feature_names=["f1", "f2"], + ), + ] + + mcc_ec = ManagedCollisionCollection( + managed_collision_modules=mc_modules, + embedding_configs=embedding_configs, + ) + torch.jit.script(mcc_ec) + + @unittest.skipIf( + torch.cuda.device_count() < 1, + "Not enough GPUs, this test requires at least one GPUs", + ) + # pyre-ignore [56] + @given(hash_size=st.sampled_from([0, 80]), keep_original_indices=st.booleans()) + @settings(max_examples=6, deadline=None) + def test_zch_hash_train_to_inf_block_bucketize( + self, hash_size: int, keep_original_indices: bool + ) -> None: + # rank 0 + world_size = 2 + kjt = KeyedJaggedTensor( + keys=["f"], + 
values=torch.cat( + [ + torch.arange(0, 20, 2, dtype=torch.int64, device="cuda"), + torch.arange(30, 60, 3, dtype=torch.int64, device="cuda"), + ] + ), + lengths=torch.cat( + [ + torch.tensor([4, 6], dtype=torch.int64, device="cuda"), + torch.tensor([4, 6], dtype=torch.int64, device="cuda"), + ] + ), + ) + block_sizes = torch.tensor( + [(size + world_size - 1) // world_size for size in [hash_size]], + dtype=torch.int64, + device="cuda", + ) + + bucketized_kjt, _ = bucketize_kjt_before_all2all( + kjt, + num_buckets=world_size, + block_sizes=block_sizes, + keep_original_indices=keep_original_indices, + ) + in1, in2 = bucketized_kjt.split([len(kjt.keys())] * world_size) + in1 = in1.to_dict() + in2 = in2.to_dict() + m0 = HashZchManagedCollisionModule( + zch_size=20, + device=torch.device("cuda"), + input_hash_size=hash_size, + total_num_buckets=2, + eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION, + eviction_config=HashZchEvictionConfig( + features=[], + single_ttl=10, + ), + ) + m1 = m0.rebuild_with_output_id_range((0, 10)) + m2 = m0.rebuild_with_output_id_range((10, 20)) + + # simulate calls to each rank + o1 = m1(in1) + o2 = m2(in2) + + m0.reset_inference_mode() + full_zch_identities = torch.cat( + [ + m1.state_dict()["_hash_zch_identities"], + m2.state_dict()["_hash_zch_identities"], + ] + ) + state_dict = m0.state_dict() + state_dict["_hash_zch_identities"] = full_zch_identities + m0.load_state_dict(state_dict) + + # now pass in original kjt + inf_input = kjt.to_dict() + inf_output = m0(inf_input) + + torch.allclose( + inf_output["f"].values(), torch.cat([o1["f"].values(), o2["f"].values()]) + ) + + @unittest.skipIf( + torch.cuda.device_count() < 1, + "Not enough GPUs, this test requires at least one GPUs", + ) + # pyre-ignore [56] + @given(hash_size=st.sampled_from([0, 80])) + @settings(max_examples=5, deadline=None) + def test_zch_hash_train_rescales_two(self, hash_size: int) -> None: + keep_original_indices = False + # rank 0 + world_size = 2 + kjt = KeyedJaggedTensor( + keys=["f"], + values=torch.cat( + [ + torch.randint( + 0, + hash_size if hash_size > 0 else 1000, + (20,), + dtype=torch.int64, + device="cuda", + ), + ] + ), + lengths=torch.cat( + [ + torch.tensor([4, 6], dtype=torch.int64, device="cuda"), + torch.tensor([4, 6], dtype=torch.int64, device="cuda"), + ] + ), + ) + block_sizes = torch.tensor( + [(size + world_size - 1) // world_size for size in [hash_size]], + dtype=torch.int64, + device="cuda", + ) + sub_block_sizes = torch.tensor( + [(size + 2 - 1) // 2 for size in [block_sizes[0]]], + dtype=torch.int64, + device="cuda", + ) + bucketized_kjt, _ = bucketize_kjt_before_all2all( + kjt, + num_buckets=world_size, + block_sizes=block_sizes, + keep_original_indices=keep_original_indices, + ) + in1, in2 = bucketized_kjt.split([len(kjt.keys())] * world_size) + + bucketized_in1, _ = bucketize_kjt_before_all2all( + in1, + num_buckets=2, + block_sizes=sub_block_sizes, + keep_original_indices=keep_original_indices, + ) + bucketized_in2, _ = bucketize_kjt_before_all2all( + in2, + num_buckets=2, + block_sizes=sub_block_sizes, + keep_original_indices=keep_original_indices, + ) + in1_1, in1_2 = bucketized_in1.split([len(kjt.keys())] * 2) + in2_1, in2_2 = bucketized_in2.split([len(kjt.keys())] * 2) + + in1_1, in1_2 = in1_1.to_dict(), in1_2.to_dict() + in2_1, in2_2 = in2_1.to_dict(), in2_2.to_dict() + + m0 = HashZchManagedCollisionModule( + zch_size=20, + device=torch.device("cuda"), + input_hash_size=hash_size, + total_num_buckets=4, + 
eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION, + eviction_config=HashZchEvictionConfig( + features=[], + single_ttl=10, + ), + ) + + m1_1 = m0.rebuild_with_output_id_range((0, 5)) + m1_2 = m0.rebuild_with_output_id_range((5, 10)) + m2_1 = m0.rebuild_with_output_id_range((10, 15)) + m2_2 = m0.rebuild_with_output_id_range((15, 20)) + + # simulate calls to each rank + o1_1 = m1_1(in1_1) + o1_2 = m1_2(in1_2) + o2_1 = m2_1(in2_1) + o2_2 = m2_2(in2_2) + + m0.reset_inference_mode() + + full_zch_identities = torch.cat( + [ + m1_1.state_dict()["_hash_zch_identities"], + m1_2.state_dict()["_hash_zch_identities"], + m2_1.state_dict()["_hash_zch_identities"], + m2_2.state_dict()["_hash_zch_identities"], + ] + ) + state_dict = m0.state_dict() + state_dict["_hash_zch_identities"] = full_zch_identities + m0.load_state_dict(state_dict) + + # now pass in original kjt + inf_input = kjt.to_dict() + inf_output = m0(inf_input) + torch.allclose( + inf_output["f"].values(), + torch.cat([x["f"].values() for x in [o1_1, o1_2, o2_1, o2_2]]), + ) + + @unittest.skipIf( + torch.cuda.device_count() < 1, + "Not enough GPUs, this test requires at least one GPUs", + ) + # pyre-ignore [56] + @given(hash_size=st.sampled_from([0, 80])) + @settings(max_examples=5, deadline=None) + def test_zch_hash_train_rescales_four(self, hash_size: int) -> None: + keep_original_indices = True + kjt = KeyedJaggedTensor( + keys=["f"], + values=torch.cat( + [ + torch.randint( + 0, + hash_size if hash_size > 0 else 1000, + (20,), + dtype=torch.int64, + device="cuda", + ), + ] + ), + lengths=torch.cat( + [ + torch.tensor([4, 6], dtype=torch.int64, device="cuda"), + torch.tensor([4, 6], dtype=torch.int64, device="cuda"), + ] + ), + ) + + # initialize mch with 8 buckets + m0 = HashZchManagedCollisionModule( + zch_size=40, + device=torch.device("cuda"), + input_hash_size=hash_size, + total_num_buckets=4, + eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION, + eviction_config=HashZchEvictionConfig( + features=[], + single_ttl=10, + ), + ) + + # start with world_size = 4 + world_size = 4 + block_sizes = torch.tensor( + [(size + world_size - 1) // world_size for size in [hash_size]], + dtype=torch.int64, + device="cuda", + ) + + m1_1 = m0.rebuild_with_output_id_range((0, 10)) + m2_1 = m0.rebuild_with_output_id_range((10, 20)) + m3_1 = m0.rebuild_with_output_id_range((20, 30)) + m4_1 = m0.rebuild_with_output_id_range((30, 40)) + + # shard, now world size 2! 
+ # start with world_size = 4 + if hash_size > 0: + world_size = 2 + block_sizes = torch.tensor( + [(size + world_size - 1) // world_size for size in [hash_size]], + dtype=torch.int64, + device="cuda", + ) + # simulate kjt call + bucketized_kjt, permute = bucketize_kjt_before_all2all( + kjt, + num_buckets=world_size, + block_sizes=block_sizes, + keep_original_indices=keep_original_indices, + output_permute=True, + ) + in1_2, in2_2 = bucketized_kjt.split([len(kjt.keys())] * world_size) + else: + bucketized_kjt, permute = bucketize_kjt_before_all2all( + kjt, + num_buckets=world_size, + block_sizes=block_sizes, + keep_original_indices=keep_original_indices, + output_permute=True, + ) + kjts = bucketized_kjt.split([len(kjt.keys())] * world_size) + # rebuild kjt + in1_2 = KeyedJaggedTensor( + keys=kjts[0].keys(), + values=torch.cat([kjts[0].values(), kjts[1].values()], dim=0), + lengths=torch.cat([kjts[0].lengths(), kjts[1].lengths()], dim=0), + ) + in2_2 = KeyedJaggedTensor( + keys=kjts[2].keys(), + values=torch.cat([kjts[2].values(), kjts[3].values()], dim=0), + lengths=torch.cat([kjts[2].lengths(), kjts[3].lengths()], dim=0), + ) + + m1_2 = m0.rebuild_with_output_id_range((0, 20)) + m2_2 = m0.rebuild_with_output_id_range((20, 40)) + m1_zch_identities = torch.cat( + [ + m1_1.state_dict()["_hash_zch_identities"], + m2_1.state_dict()["_hash_zch_identities"], + ] + ) + m1_zch_metadata = torch.cat( + [ + m1_1.state_dict()["_hash_zch_metadata"], + m2_1.state_dict()["_hash_zch_metadata"], + ] + ) + state_dict = m1_2.state_dict() + state_dict["_hash_zch_identities"] = m1_zch_identities + state_dict["_hash_zch_metadata"] = m1_zch_metadata + m1_2.load_state_dict(state_dict) + + m2_zch_identities = torch.cat( + [ + m3_1.state_dict()["_hash_zch_identities"], + m4_1.state_dict()["_hash_zch_identities"], + ] + ) + m2_zch_metadata = torch.cat( + [ + m3_1.state_dict()["_hash_zch_metadata"], + m4_1.state_dict()["_hash_zch_metadata"], + ] + ) + state_dict = m2_2.state_dict() + state_dict["_hash_zch_identities"] = m2_zch_identities + state_dict["_hash_zch_metadata"] = m2_zch_metadata + m2_2.load_state_dict(state_dict) + + _ = m1_2(in1_2.to_dict()) + _ = m2_2(in2_2.to_dict()) + + m0.reset_inference_mode() # just clears out training state + full_zch_identities = torch.cat( + [ + m1_2.state_dict()["_hash_zch_identities"], + m2_2.state_dict()["_hash_zch_identities"], + ] + ) + state_dict = m0.state_dict() + state_dict["_hash_zch_identities"] = full_zch_identities + m0.load_state_dict(state_dict) + + # now set all models to eval, and run kjt + m1_2.eval() + m2_2.eval() + assert m0.training is False + + inf_input = kjt.to_dict() + inf_output = m0(inf_input) + + o1_2 = m1_2(in1_2.to_dict()) + o2_2 = m2_2(in2_2.to_dict()) + self.assertTrue( + torch.allclose( + inf_output["f"].values(), + torch.index_select( + torch.cat([x["f"].values() for x in [o1_2, o2_2]]), + dim=0, + index=cast(torch.Tensor, permute), + ), + ) + ) + + # pyre-ignore[56] + @unittest.skipIf( + torch.cuda.device_count() < 1, + "This test requires CUDA device", + ) + def test_output_global_offset_tensor(self) -> None: + m = HashZchManagedCollisionModule( + zch_size=20, + device=torch.device("cpu"), + total_num_buckets=4, + ) + self.assertIsNone(m._output_global_offset_tensor) + + bucket2 = m.rebuild_with_output_id_range((5, 10)) + self.assertIsNotNone(bucket2._output_global_offset_tensor) + self.assertTrue( + # pyre-ignore [6] + torch.equal(bucket2._output_global_offset_tensor, torch.tensor([5])) + ) + self.assertEqual(bucket2._start_bucket, 1) + + 
m.reset_inference_mode() + bucket3 = m.rebuild_with_output_id_range((10, 15)) + self.assertIsNotNone(bucket3._output_global_offset_tensor) + self.assertTrue( + # pyre-ignore [6] + torch.equal(bucket3._output_global_offset_tensor, torch.tensor([10])) + ) + self.assertEqual(bucket3._start_bucket, 2) + self.assertEqual( + # pyre-ignore [16] + bucket3._output_global_offset_tensor.device.type, + "cpu", + ) + + remapped_indices = bucket3.remap( + { + "test": JaggedTensor( + values=torch.tensor( + [6, 10, 14, 18, 22], dtype=torch.int64, device="cpu" + ), + lengths=torch.tensor([5], dtype=torch.int64, device="cpu"), + ) + } + ) + self.assertTrue( + torch.allclose( + remapped_indices["test"].values(), torch.tensor([14, 10, 10, 11, 10]) + ) + ) + + gpu_zch = HashZchManagedCollisionModule( + zch_size=20, + device=torch.device("cuda"), + total_num_buckets=4, + ) + bucket4 = gpu_zch.rebuild_with_output_id_range((15, 20)) + self.assertIsNotNone(bucket4._output_global_offset_tensor) + self.assertTrue(bucket4._output_global_offset_tensor.device.type == "cuda") + self.assertEqual( + bucket4._output_global_offset_tensor, torch.tensor([15], device="cuda") + ) + + meta_zch = HashZchManagedCollisionModule( + zch_size=20, + device=torch.device("meta"), + total_num_buckets=4, + ) + meta_zch.reset_inference_mode() + self.assertIsNone(meta_zch._output_global_offset_tensor) + bucket5 = meta_zch.rebuild_with_output_id_range((15, 20)) + self.assertIsNotNone(bucket5._output_global_offset_tensor) + self.assertTrue(bucket5._output_global_offset_tensor.device.type == "cpu") + self.assertEqual(bucket5._output_global_offset_tensor, torch.tensor([15]))
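+
+    # Illustrative usage sketch, added as documentation rather than as part of the
+    # original patch: a minimal CPU, inference-mode remap through
+    # HashZchManagedCollisionModule, mirroring the CPU paths exercised above. The
+    # test name and the specific input values are assumptions for illustration only.
+    def test_cpu_inference_remap_usage_sketch(self) -> None:
+        m = HashZchManagedCollisionModule(
+            zch_size=8,
+            device=torch.device("cpu"),
+            total_num_buckets=2,
+        )
+        # switch to inference mode so the lookup is read-only on the identity table
+        m.reset_inference_mode()
+        remapped = m(
+            {
+                "f": JaggedTensor(
+                    values=torch.tensor([3, 7, 11], dtype=torch.int64),
+                    lengths=torch.tensor([2, 1], dtype=torch.int64),
+                )
+            }
+        )["f"].values()
+        # every remapped id is a valid slot index of the local identity table
+        self.assertTrue(bool(((remapped >= 0) & (remapped < 8)).all().item()))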