diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py
index 362240d88..e6b9a13d2 100644
--- a/torchrec/distributed/embedding.py
+++ b/torchrec/distributed/embedding.py
@@ -1515,7 +1515,7 @@ def compute_and_output_dist(
         ):
             embs = lookup(features)
             if self.post_lookup_tracker_fn is not None:
-                self.post_lookup_tracker_fn(features, embs)
+                self.post_lookup_tracker_fn(self, features, embs)
 
         with maybe_annotate_embedding_event(
             EmbeddingEvent.OUTPUT_DIST, self._module_fqn, sharding_type
diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py
index fce86dc19..86eece133 100644
--- a/torchrec/distributed/embedding_lookup.py
+++ b/torchrec/distributed/embedding_lookup.py
@@ -10,7 +10,7 @@
 import logging
 from abc import ABC
 from collections import OrderedDict
-from typing import Any, cast, Dict, Iterator, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -206,6 +206,10 @@ def __init__(
         )
         self.grouped_configs = grouped_configs
+        # Model tracker function used to track optimizer state
+        self.optim_state_tracker_fn: Optional[
+            Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]
+        ] = None
 
     def _create_embedding_kernel(
         self,
@@ -305,7 +309,13 @@ def forward(
             self._feature_splits,
         )
         for emb_op, features in zip(self._emb_modules, features_by_group):
-            embeddings.append(emb_op(features).view(-1))
+            lookup = emb_op(features).view(-1)
+            embeddings.append(lookup)
+
+            # Model tracker optimizer state function; it is only set and called
+            # when the model tracker is configured to track optimizer state.
+            if self.optim_state_tracker_fn is not None:
+                self.optim_state_tracker_fn(emb_op, features, lookup)
 
         return embeddings_cat_empty_rank_handle(embeddings, self._dummy_embs_tensor)
 
@@ -409,6 +419,19 @@ def purge(self) -> None:
             # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
             emb_module.purge()
 
+    def register_optim_state_tracker_fn(
+        self,
+        record_fn: Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None],
+    ) -> None:
+        """
+        Register a model tracker function to track optimizer state.
+
+        Args:
+            record_fn (Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done.
+
+        """
+        self.optim_state_tracker_fn = record_fn
+
 
 class CommOpGradientScaling(torch.autograd.Function):
     @staticmethod
@@ -481,6 +504,10 @@ def __init__(
             if scale_weight_gradients and get_gradient_division()
             else 1
         )
+        # Model tracker function used to track optimizer state
+        self.optim_state_tracker_fn: Optional[
+            Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]
+        ] = None
 
     def _create_embedding_kernel(
         self,
@@ -608,7 +635,12 @@ def forward(
                     features._weights, self._scale_gradient_factor
                 )
 
-            embeddings.append(emb_op(features))
+            lookup = emb_op(features)
+            embeddings.append(lookup)
+            # Model tracker optimizer state function; it is only set and called
+            # when the model tracker is configured to track optimizer state.
+            if self.optim_state_tracker_fn is not None:
+                self.optim_state_tracker_fn(emb_op, features, lookup)
 
         if features.variable_stride_per_key() and len(self._emb_modules) > 1:
             stride_per_rank_per_key = list(
@@ -738,6 +770,19 @@ def purge(self) -> None:
             # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
             emb_module.purge()
 
+    def register_optim_state_tracker_fn(
+        self,
+        record_fn: Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None],
+    ) -> None:
+        """
+        Register a model tracker function to track optimizer state.
+
+        Args:
+            record_fn (Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done.
+
+        """
+        self.optim_state_tracker_fn = record_fn
+
 
 class MetaInferGroupedEmbeddingsLookup(
     BaseEmbeddingLookup[KeyedJaggedTensor, torch.Tensor], TBEToRegisterMixIn
diff --git a/torchrec/distributed/embedding_types.py b/torchrec/distributed/embedding_types.py
index 20f0a4c88..eee2cba27 100644
--- a/torchrec/distributed/embedding_types.py
+++ b/torchrec/distributed/embedding_types.py
@@ -373,7 +373,7 @@ def __init__(
         self._lookups: List[nn.Module] = []
         self._output_dists: List[nn.Module] = []
         self.post_lookup_tracker_fn: Optional[
-            Callable[[KeyedJaggedTensor, torch.Tensor], None]
+            Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]
         ] = None
         self.post_odist_tracker_fn: Optional[Callable[..., None]] = None
 
@@ -426,14 +426,14 @@ def train(self, mode: bool = True):  # pyre-ignore[3]
 
     def register_post_lookup_tracker_fn(
         self,
-        record_fn: Callable[[KeyedJaggedTensor, torch.Tensor], None],
+        record_fn: Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None],
     ) -> None:
         """
         Register a function to be called after lookup is done. This is used for
         tracking the lookup results and optimizer states.
 
         Args:
-            record_fn (Callable[[KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done.
+            record_fn (Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done.
 
         """
         if self.post_lookup_tracker_fn is not None:
diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py
index 2beaf3aef..af63585b4 100644
--- a/torchrec/distributed/embeddingbag.py
+++ b/torchrec/distributed/embeddingbag.py
@@ -1459,7 +1459,7 @@ def compute_and_output_dist(
         ):
             embs = lookup(features)
             if self.post_lookup_tracker_fn is not None:
-                self.post_lookup_tracker_fn(features, embs)
+                self.post_lookup_tracker_fn(self, features, embs)
 
         with maybe_annotate_embedding_event(
             EmbeddingEvent.OUTPUT_DIST,
diff --git a/torchrec/distributed/model_tracker/model_delta_tracker.py b/torchrec/distributed/model_tracker/model_delta_tracker.py
index 905bf7648..4dabdb198 100644
--- a/torchrec/distributed/model_tracker/model_delta_tracker.py
+++ b/torchrec/distributed/model_tracker/model_delta_tracker.py
@@ -8,12 +8,23 @@
 # pyre-strict
 import logging as logger
 from collections import Counter, OrderedDict
-from typing import Dict, Iterable, List, Optional
+from typing import Dict, Iterable, List, Optional, Tuple
 
 import torch
+from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType
+from fbgemm_gpu.split_table_batched_embeddings_ops import (
+    SplitTableBatchedEmbeddingBagsCodegen,
+)
 from torch import nn
+from torchrec.distributed.batched_embedding_kernel import BatchedFusedEmbedding
+
 from torchrec.distributed.embedding import ShardedEmbeddingCollection
+from torchrec.distributed.embedding_lookup import (
+    BatchedFusedEmbeddingBag,
+    GroupedEmbeddingsLookup,
+    GroupedPooledEmbeddingsLookup,
+)
 from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
 from torchrec.distributed.model_tracker.delta_store import DeltaStore
 from torchrec.distributed.model_tracker.types import (
@@ -21,15 +32,32 @@
     EmbdUpdateMode,
     TrackingMode,
 )
+from torchrec.distributed.utils import none_throws
+
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 
 UPDATE_MODE_MAP: Dict[TrackingMode, EmbdUpdateMode] = {
     # Only IDs are tracked, no additional state is stored.
     TrackingMode.ID_ONLY: EmbdUpdateMode.NONE,
     # TrackingMode.EMBEDDING utilizes EmbdUpdateMode.FIRST to ensure that
-    # the earliest embedding values are stored since the last checkpoint or snapshot.
-    # This mode is used for computing topk delta rows, which is currently achieved by running (new_emb - old_emb).norm().topk().
+    # the earliest embedding values are stored since the last checkpoint
+    # or snapshot. This mode is used for computing topk delta rows, which
+    # is currently achieved by running (new_emb - old_emb).norm().topk().
    TrackingMode.EMBEDDING: EmbdUpdateMode.FIRST,
+    # TrackingMode.MOMENTUM_LAST utilizes EmbdUpdateMode.LAST to ensure that
+    # the most recent momentum values—capturing the accumulated gradient
+    # direction and magnitude—are stored since the last batch.
+    # This mode supports approximate top-k delta-row selection, which can be
+    # obtained by running momentum.norm().topk().
+    TrackingMode.MOMENTUM_LAST: EmbdUpdateMode.LAST,
+    # MOMENTUM_DIFF keeps a running sum of the square of the gradients per row.
+    # Within each publishing interval, we track the starting value of this running
+    # sum on all used rows and then do a lookup when ``get_delta`` is called to query
+    # the latest sum. Then we can compute the delta of the two values and return them
+    # together with the row ids.
+    TrackingMode.MOMENTUM_DIFF: EmbdUpdateMode.FIRST,
+    # The same as MOMENTUM_DIFF. Added for backward compatibility.
+    TrackingMode.ROWWISE_ADAGRAD: EmbdUpdateMode.FIRST,
 }
 
 # Tracking is current only supported for ShardedEmbeddingCollection and ShardedEmbeddingBagCollection.
@@ -87,6 +115,7 @@ def __init__(
         # from module FQN to ShardedEmbeddingCollection/ShardedEmbeddingBagCollection
         self.tracked_modules: Dict[str, nn.Module] = {}
+        self.table_to_fqn: Dict[str, str] = {}
         self.feature_to_fqn: Dict[str, str] = {}
         # Generate the mapping from FQN to feature names.
         self.fqn_to_feature_names()
@@ -141,7 +170,9 @@ def trigger_compaction(self) -> None:
         # Update the current compact index to the end index to avoid duplicate compaction.
         self.curr_compact_index = end_idx
 
-    def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None:
+    def record_lookup(
+        self, emb_module: nn.Module, kjt: KeyedJaggedTensor, states: torch.Tensor
+    ) -> None:
         """
         Records the IDs from a given KeyedJaggedTensor and their corresponding
         embeddings/parameter states.
@@ -152,6 +183,7 @@ def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None:
         (in ID_ONLY mode) or both IDs and their corresponding embeddings (in EMBEDDING mode).
 
         Args:
+            emb_module (nn.Module): The embedding module in which the lookup was performed.
             kjt (KeyedJaggedTensor): The KeyedJaggedTensor containing IDs to record.
             states (torch.Tensor): The embeddings or states corresponding to the IDs in the kjt.
         """
@@ -162,7 +194,14 @@ def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None:
         # In EMBEDDING mode, we track per feature IDs and corresponding embeddings received in the current batch.
         elif self._mode == TrackingMode.EMBEDDING:
             self.record_embeddings(kjt, states)
-
+        # In MOMENTUM_LAST mode, we track per feature IDs and corresponding momentum values received in the current batch.
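+        # record_momentum (below) reads the fused kernel's momentum1 buffer and stores
+        # the per-row L2 norm for the hit ids; the MOMENTUM_DIFF / ROWWISE_ADAGRAD branch
+        # instead snapshots the rowwise optimizer state so get_delta() can later diff it
+        # against the latest values returned by get_latest().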
+        elif self._mode == TrackingMode.MOMENTUM_LAST:
+            self.record_momentum(emb_module, kjt)
+        elif (
+            self._mode == TrackingMode.MOMENTUM_DIFF
+            or self._mode == TrackingMode.ROWWISE_ADAGRAD
+        ):
+            self.record_rowwise_optim_state(emb_module, kjt)
         else:
             raise NotImplementedError(f"Tracking mode {self._mode} is not supported")
 
@@ -228,6 +267,93 @@ def record_embeddings(
                 states=torch.cat(per_table_emb[table_fqn]),
             )
 
+    def record_momentum(
+        self,
+        emb_module: nn.Module,
+        kjt: KeyedJaggedTensor,
+    ) -> None:
+        # FIXME: this is the momentum from the last iteration; use the momentum from
+        # the current iteration for correctness.
+        # pyre-ignore Undefined attribute [16]:
+        momentum = emb_module._emb_module.momentum1_dev
+        # FIXME: support multiple tables per group; the information can be extracted
+        # from module._config (i.e., GroupedEmbeddingConfig).
+        # pyre-ignore Undefined attribute [16]:
+        states = momentum.view(-1, emb_module._config.embedding_dims()[0])[
+            kjt.values()
+        ].norm(dim=1)
+
+        offsets: torch.Tensor = torch.ops.fbgemm.asynchronous_complete_cumsum(
+            torch.tensor(kjt.length_per_key(), dtype=torch.int64)
+        )
+        assert (
+            kjt.values().numel() == states.numel()
+        ), f"number of ids and states mismatch, expect {kjt.values()=}, {kjt.values().numel()}, but got {states.numel()}"
+
+        for i, key in enumerate(kjt.keys()):
+            fqn = self.feature_to_fqn[key]
+            per_key_states = states[offsets[i] : offsets[i + 1]]
+            self.store.append(
+                batch_idx=self.curr_batch_idx,
+                table_fqn=fqn,
+                ids=kjt[key].values(),
+                states=per_key_states,
+            )
+
+    def record_rowwise_optim_state(
+        self,
+        emb_module: nn.Module,
+        kjt: KeyedJaggedTensor,
+    ) -> None:
+        opt_states: List[List[torch.Tensor]] = (
+            # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
+            # `split_optimizer_states`.
+            emb_module._emb_module.split_optimizer_states()
+        )
+        proxy: torch.Tensor = torch.cat([state[0] for state in opt_states])
+        states = proxy[kjt.values()]
+        assert (
+            kjt.values().numel() == states.numel()
+        ), f"number of ids and states mismatch, expect {kjt.values()=}, {kjt.values().numel()}, but got {states.numel()}"
+        offsets: torch.Tensor = torch.ops.fbgemm.asynchronous_complete_cumsum(
+            torch.tensor(kjt.length_per_key(), dtype=torch.int64)
+        )
+        for i, key in enumerate(kjt.keys()):
+            fqn = self.feature_to_fqn[key]
+            per_key_states = states[offsets[i] : offsets[i + 1]]
+            self.store.append(
+                batch_idx=self.curr_batch_idx,
+                table_fqn=fqn,
+                ids=kjt[key].values(),
+                states=per_key_states,
+            )
+
+    def get_latest(self) -> Dict[str, torch.Tensor]:
+        ret: Dict[str, torch.Tensor] = {}
+        for module in self.tracked_modules.values():
+            # pyre-fixme[29]:
+            for lookup in module._lookups:
+                for embs_module in lookup._emb_modules:
+                    assert isinstance(
+                        embs_module, (BatchedFusedEmbeddingBag, BatchedFusedEmbedding)
+                    ), f"expect BatchedFusedEmbeddingBag or BatchedFusedEmbedding, but {type(embs_module)} found"
+                    tbe = embs_module._emb_module
+
+                    assert isinstance(tbe, SplitTableBatchedEmbeddingBagsCodegen)
+                    table_names = [t.name for t in embs_module._config.embedding_tables]
+                    opt_states = tbe.split_optimizer_states()
+                    assert len(table_names) == len(opt_states)
+
+                    for i, table_name in enumerate(table_names):
+                        emb_fqn = self.table_to_fqn[table_name]
+                        table_state = opt_states[i][0]
+                        assert (
+                            emb_fqn not in ret
+                        ), f"a table with {emb_fqn} already exists"
+                        ret[emb_fqn] = table_state
+
+        return ret
+
     def get_delta_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tensor]:
         """
         Return a dictionary of hit local IDs for each sparse feature. Ids are
@@ -239,7 +365,13 @@ def get_delta_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tensor]:
         per_table_delta_rows = self.get_delta(consumer)
         return {fqn: delta_rows.ids for fqn, delta_rows in per_table_delta_rows.items()}
 
-    def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
+    def get_delta(
+        self,
+        consumer: Optional[str] = None,
+        top_percentage: Optional[float] = 1.0,
+        per_table_percentage: Optional[Dict[str, Tuple[float, str]]] = None,
+        sorted_by_indices: Optional[bool] = True,
+    ) -> Dict[str, DeltaRows]:
         """
         Return a dictionary of hit local IDs and parameter states / embeddings for each sparse
         feature. The Values are first keyed by submodule FQN.
@@ -264,6 +396,65 @@ def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
         self.per_consumer_batch_idx[consumer] = index_end
         if self._delete_on_read:
             self.store.delete(up_to_idx=min(self.per_consumer_batch_idx.values()))
+
+        if self._mode in (TrackingMode.MOMENTUM_DIFF, TrackingMode.ROWWISE_ADAGRAD):
+            square_sum_map = self.get_latest()
+            for fqn, rows in tracker_rows.items():
+                assert (
+                    fqn in square_sum_map
+                ), f"{fqn} not found in {square_sum_map.keys()}"
+                # compute delta sum_t(g^2) for t in [t1, t2] through
+                # sum_t2(g^2) - sum_t1(g^2)
+                # pyre-fixme[58]: `-` is not supported for operand types `Tensor`
+                # and `Optional[Tensor]`.
+                rows.states = square_sum_map[fqn][rows.ids] - rows.states
+
+                if rows.states is not None:
+                    default_k = rows.states.size(-1)
+                    top_k = (
+                        int(top_percentage * default_k)
+                        if top_percentage is not None
+                        else default_k
+                    )
+
+                    if (
+                        per_table_percentage is not None
+                        and per_table_percentage.get(fqn) is not None
+                    ):
+                        per_table_k = int(per_table_percentage[fqn][0] * default_k)
+                        policy = per_table_percentage[fqn][1]
+
+                        if policy == "MIN":
+                            top_k = min(top_k, per_table_k)
+                        elif policy == "MAX":
+                            top_k = max(top_k, per_table_k)
+                        elif policy == "OVERRIDE":
+                            top_k = per_table_k
+                        else:
+                            logger.warning(
+                                f"Unknown policy {policy}, will keep using original top_k {top_k}"
+                            )
+
+                    logger.info(f"get_unique {fqn=} {top_k=} {default_k=}")
+
+                    if top_k >= default_k:
+                        continue
+
+                    if sorted_by_indices:
+                        sorted_indices, _ = torch.sort(
+                            torch.topk(
+                                none_throws(rows.states), top_k, sorted=False
+                            ).indices,
+                            stable=False,
+                        )
+                        rows.ids = rows.ids[sorted_indices]
+                        rows.states = none_throws(rows.states)[sorted_indices]
+                    else:
+                        rows.states, indices = torch.topk(
+                            none_throws(rows.states), top_k, sorted=False
+                        )
+                        rows.ids = rows.ids[indices]
+
         return tracker_rows
 
     def get_tracked_modules(self) -> Dict[str, nn.Module]:
@@ -280,7 +471,6 @@ def fqn_to_feature_names(self) -> Dict[str, List[str]]:
             return self._fqn_to_feature_map
 
         table_to_feature_names: Dict[str, List[str]] = OrderedDict()
-        table_to_fqn: Dict[str, str] = OrderedDict()
         for fqn, named_module in self._model.named_modules():
             split_fqn = fqn.split(".")
             # Skipping partial FQNs present in fqns_to_skip
@@ -306,13 +496,13 @@ def fqn_to_feature_names(self) -> Dict[str, List[str]]:
                     # will incorrectly match fqn with all the table names that have the same prefix
                     if table_name in split_fqn:
                         embedding_fqn = self._clean_fqn_fn(fqn)
-                        if table_name in table_to_fqn:
+                        if table_name in self.table_to_fqn:
                             # Sanity check for validating that we don't have more then one table mapping to same fqn.
                             logger.warning(
-                                f"Override {table_to_fqn[table_name]} with {embedding_fqn} for entry {table_name}"
+                                f"Override {self.table_to_fqn[table_name]} with {embedding_fqn} for entry {table_name}"
                             )
-                        table_to_fqn[table_name] = embedding_fqn
-        logger.info(f"Table to fqn: {table_to_fqn}")
+                        self.table_to_fqn[table_name] = embedding_fqn
+        logger.info(f"Table to fqn: {self.table_to_fqn}")
         flatten_names = [
             name for names in table_to_feature_names.values() for name in names
         ]
@@ -325,15 +515,15 @@ def fqn_to_feature_names(self) -> Dict[str, List[str]]:
 
         fqn_to_feature_names: Dict[str, List[str]] = OrderedDict()
         for table_name in table_to_feature_names:
-            if table_name not in table_to_fqn:
+            if table_name not in self.table_to_fqn:
                 # This is likely unexpected, where we can't locate the FQN associated with this table.
                 logger.warning(
-                    f"Table {table_name} not found in {table_to_fqn}, skipping"
+                    f"Table {table_name} not found in {self.table_to_fqn}, skipping"
                 )
                 continue
-            fqn_to_feature_names[table_to_fqn[table_name]] = table_to_feature_names[
-                table_name
-            ]
+            fqn_to_feature_names[self.table_to_fqn[table_name]] = (
+                table_to_feature_names[table_name]
+            )
         self._fqn_to_feature_map = fqn_to_feature_names
         return fqn_to_feature_names
 
@@ -380,13 +570,49 @@ def _clean_fqn_fn(self, fqn: str) -> str:
     def _validate_and_init_tracker_fns(self) -> None:
         "To validate the mode is supported for the given module"
         for module in self.tracked_modules.values():
+            # EMBEDDING mode is only supported for ShardedEmbeddingCollection
             assert not (
                 isinstance(module, ShardedEmbeddingBagCollection)
                 and self._mode == TrackingMode.EMBEDDING
             ), "EBC's lookup returns pooled embeddings and currently, we do not support tracking raw embeddings."
-            # register post lookup function
-            # pyre-ignore[29]
-            module.register_post_lookup_tracker_fn(self.record_lookup)
+
+            if (
+                self._mode == TrackingMode.ID_ONLY
+                or self._mode == TrackingMode.EMBEDDING
+            ):
+                # register post lookup function
+                # pyre-ignore[29]
+                module.register_post_lookup_tracker_fn(self.record_lookup)
+            elif self._mode == TrackingMode.MOMENTUM_LAST:
+                # pyre-ignore[29]:
+                for lookup in module._lookups:
+                    assert isinstance(
+                        lookup,
+                        (GroupedEmbeddingsLookup, GroupedPooledEmbeddingsLookup),
+                    )
+                    lookup.register_optim_state_tracker_fn(self.record_lookup)
+            elif (
+                self._mode == TrackingMode.ROWWISE_ADAGRAD
+                or self._mode == TrackingMode.MOMENTUM_DIFF
+            ):
+                # pyre-ignore[29]:
+                for lookup in module._lookups:
+                    assert isinstance(
+                        lookup,
+                        (GroupedEmbeddingsLookup, GroupedPooledEmbeddingsLookup),
+                    ) and all(
+                        # TorchRec maps ROWWISE_ADAGRAD to EXACT_ROWWISE_ADAGRAD
+                        # pyre-ignore[16]:
+                        emb._emb_module.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD
+                        # pyre-ignore[16]:
+                        or emb._emb_module.optimizer == OptimType.PARTIAL_ROWWISE_ADAM
+                        for emb in lookup._emb_modules
+                    )
+                    lookup.register_optim_state_tracker_fn(self.record_lookup)
+            else:
+                raise NotImplementedError(
+                    f"Tracking mode {self._mode} is not supported"
+                )
 
             # register auto compaction function at odist
             if self._auto_compact:
                 # pyre-ignore[29]
diff --git a/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py b/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py
index a92a9b286..9362357fe 100644
--- a/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py
+++ b/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py
@@ -13,6 +13,9 @@
 import torch
 import torchrec
 from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType
+from fbgemm_gpu.split_table_batched_embeddings_ops import (
+    SplitTableBatchedEmbeddingBagsCodegen,
+)
 from parameterized import parameterized
 from torch import nn
 
@@ -1463,6 +1466,196 @@ def test_multiple_consumers(
             output_params=output_params,
         )
 
+    @parameterized.expand(
+        [
+            (
+                "EC_and_single_feature",
+                ModelDeltaTrackerInputTestParams(
+                    embedding_config_type=EmbeddingConfig,
+                    embedding_tables=[
+                        EmbeddingTableProps(
+                            embedding_table_config=EmbeddingConfig(
+                                name="sparse_table_1",
+                                num_embeddings=NUM_EMBEDDINGS,
+                                embedding_dim=EMBEDDING_DIM,
+                                feature_names=["f1"],
+                            ),
+                            sharding=ShardingType.ROW_WISE,
+                        ),
+                    ],
+                    model_tracker_config=ModelTrackerConfig(
+                        tracking_mode=TrackingMode.MOMENTUM_LAST,
+                        delete_on_read=True,
+                    ),
+                    model_inputs=[
+                        ModelInput(
+                            keys=["f1"],
+                            values=torch.tensor([0, 2, 4, 6, 8, 10, 12, 14]),
+                            offsets=torch.tensor([0, 2, 2, 4, 6, 7, 8]),
+                        ),
+                        ModelInput(
+                            keys=["f1"],
+                            values=torch.tensor([8, 10, 12, 14, 0, 2, 4, 6]),
+                            offsets=torch.tensor([0, 2, 2, 4, 6, 6, 8]),
+                        ),
+                        ModelInput(
+                            keys=["f1"],
+                            values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
+                            offsets=torch.tensor([0, 0, 0, 4, 4, 4, 8]),
+                        ),
+                    ],
+                ),
+            ),
+            (
+                "EBC_and_multiple_feature",
+                ModelDeltaTrackerInputTestParams(
+                    embedding_config_type=EmbeddingBagConfig,
+                    embedding_tables=[
+                        EmbeddingTableProps(
+                            embedding_table_config=EmbeddingBagConfig(
+                                name="sparse_table_1",
+                                num_embeddings=NUM_EMBEDDINGS,
+                                embedding_dim=EMBEDDING_DIM,
+                                feature_names=["f1", "f2"],
+                                pooling=PoolingType.SUM,
+                            ),
+                            sharding=ShardingType.ROW_WISE,
+                        ),
+                    ],
+                    model_tracker_config=ModelTrackerConfig(
+                        tracking_mode=TrackingMode.MOMENTUM_LAST,
+                        delete_on_read=True,
+                    ),
+                    model_inputs=[
+                        ModelInput(
+                            keys=["f1", "f2"],
+                            values=torch.tensor([0, 2, 4, 6, 8, 10, 12, 14]),
+                            offsets=torch.tensor([0, 2, 2, 4, 6, 7, 8]),
+                        ),
+                        ModelInput(
+                            keys=["f1", "f2"],
+                            values=torch.tensor([8, 10, 12, 14, 0, 2, 4, 6]),
+                            offsets=torch.tensor([0, 2, 2, 4, 6, 6, 8]),
+                        ),
+                        ModelInput(
+                            keys=["f1", "f2"],
+                            values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
+                            offsets=torch.tensor([0, 0, 0, 4, 4, 4, 8]),
+                        ),
+                    ],
+                ),
+            ),
+        ]
+    )
+    @skip_if_asan
+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @unittest.skipIf(torch.cuda.device_count() < 2, "test requires 2+ GPUs")
+    def test_duplication_with_momentum(
+        self,
+        _test_name: str,
+        test_params: ModelDeltaTrackerInputTestParams,
+    ) -> None:
+        self._run_multi_process_test(
+            callable=_test_duplication_with_momentum,
+            world_size=self.world_size,
+            test_params=test_params,
+        )
+
+    @parameterized.expand(
+        [
+            (
+                "EC_and_single_feature",
+                ModelDeltaTrackerInputTestParams(
+                    embedding_config_type=EmbeddingConfig,
+                    embedding_tables=[
+                        EmbeddingTableProps(
+                            embedding_table_config=EmbeddingConfig(
+                                name="sparse_table_1",
+                                num_embeddings=NUM_EMBEDDINGS,
+                                embedding_dim=EMBEDDING_DIM,
+                                feature_names=["f1"],
+                            ),
+                            sharding=ShardingType.ROW_WISE,
+                        ),
+                    ],
+                    model_tracker_config=ModelTrackerConfig(
+                        tracking_mode=TrackingMode.MOMENTUM_DIFF,
+                        delete_on_read=True,
+                    ),
+                    model_inputs=[
+                        ModelInput(
+                            keys=["f1"],
+                            values=torch.tensor([0, 2, 4, 6, 8, 10, 12, 14]),
+                            offsets=torch.tensor([0, 2, 2, 4, 6, 7, 8]),
+                        ),
+                        ModelInput(
+                            keys=["f1"],
+                            values=torch.tensor([8, 10, 12, 14, 0, 2, 4, 6]),
+                            offsets=torch.tensor([0, 2, 2, 4, 6, 6, 8]),
+                        ),
+                        ModelInput(
+                            keys=["f1"],
+                            values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
+                            offsets=torch.tensor([0, 0, 0, 4, 4, 4, 8]),
+                        ),
+                    ],
+                ),
+            ),
+            (
+                "EBC_and_multiple_feature",
+                ModelDeltaTrackerInputTestParams(
+                    embedding_config_type=EmbeddingBagConfig,
+                    embedding_tables=[
+                        EmbeddingTableProps(
+                            embedding_table_config=EmbeddingBagConfig(
+                                name="sparse_table_1",
+                                num_embeddings=NUM_EMBEDDINGS,
+                                embedding_dim=EMBEDDING_DIM,
+                                feature_names=["f1", "f2"],
+                                pooling=PoolingType.SUM,
+                            ),
+                            sharding=ShardingType.ROW_WISE,
+                        ),
+                    ],
+                    model_tracker_config=ModelTrackerConfig(
+                        tracking_mode=TrackingMode.ROWWISE_ADAGRAD,
+                        delete_on_read=True,
+                    ),
+                    model_inputs=[
+                        ModelInput(
+                            keys=["f1", "f2"],
+                            values=torch.tensor([0, 2, 4, 6, 8, 10, 12, 14]),
+                            offsets=torch.tensor([0, 2, 2, 4, 6, 7, 8]),
+                        ),
+                        ModelInput(
+                            keys=["f1", "f2"],
+                            values=torch.tensor([8, 10, 12, 14, 0, 2, 4, 6]),
+                            offsets=torch.tensor([0, 2, 2, 4, 6, 6, 8]),
+                        ),
+                        ModelInput(
+                            keys=["f1", "f2"],
+                            values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
+                            offsets=torch.tensor([0, 0, 0, 4, 4, 4, 8]),
+                        ),
+                    ],
+                ),
+            ),
+        ]
+    )
+    @skip_if_asan
+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @unittest.skipIf(torch.cuda.device_count() < 2, "test requires 2+ GPUs")
+    def test_duplication_with_rowwise_adagrad(
+        self,
+        _test_name: str,
+        test_params: ModelDeltaTrackerInputTestParams,
+    ) -> None:
+        self._run_multi_process_test(
+            callable=_test_duplication_with_rowwise_adagrad,
+            world_size=self.world_size,
+            test_params=test_params,
+        )
+
 
 def _test_fqn_to_feature_names(
     rank: int,
@@ -1859,3 +2052,124 @@ def _test_multiple_consumer(
             and returned.allclose(expected_ids),
             f"{i=}, {table_fqn=}, mismatch {returned=} vs {expected_ids=}",
         )
+
+
+def _test_duplication_with_momentum(
+    rank: int,
+    world_size: int,
+    test_params: ModelDeltaTrackerInputTestParams,
+) -> None:
+    """
+    Test momentum tracking functionality in the model delta tracker.
+
+    Validates that the tracker correctly captures and stores momentum values from
+    optimizer states when using TrackingMode.MOMENTUM_LAST mode.
+    """
+    with MultiProcessContext(
+        rank=rank,
+        world_size=world_size,
+        backend="nccl" if torch.cuda.is_available() else "gloo",
+    ) as ctx:
+        dt_model, baseline_model = get_models(
+            rank=rank,
+            world_size=world_size,
+            ctx=ctx,
+            embedding_config_type=test_params.embedding_config_type,
+            tables=test_params.embedding_tables,
+            config=test_params.model_tracker_config,
+        )
+        dt_model_opt = torch.optim.Adam(dt_model.parameters(), lr=0.1)
+        baseline_opt = torch.optim.Adam(baseline_model.parameters(), lr=0.1)
+        features_list = model_input_generator(test_params.model_inputs, rank)
+        dt = dt_model.get_model_tracker()
+        table_fqns = dt.fqn_to_feature_names().keys()
+        table_fqns_list = list(table_fqns)
+        for features in features_list:
+            tracked_out = dt_model(features)
+            baseline_out = baseline_model(features)
+            unittest.TestCase().assertTrue(tracked_out.allclose(baseline_out))
+            tracked_out.sum().backward()
+            baseline_out.sum().backward()
+            dt_model_opt.step()
+            baseline_opt.step()
+
+        delta_rows = dt.get_delta()
+        for table_fqn in table_fqns_list:
+            ids = delta_rows[table_fqn].ids
+            states = none_throws(delta_rows[table_fqn].states)
+
+            unittest.TestCase().assertTrue(states is not None)
+            unittest.TestCase().assertTrue(ids.numel() == states.numel())
+            unittest.TestCase().assertTrue(bool((states != 0).all().item()))
+
+
+def _test_duplication_with_rowwise_adagrad(
+    rank: int,
+    world_size: int,
+    test_params: ModelDeltaTrackerInputTestParams,
+) -> None:
+    with MultiProcessContext(
+        rank=rank,
+        world_size=world_size,
+        backend="nccl" if torch.cuda.is_available() else "gloo",
+    ) as ctx:
+        dt_model, baseline_model = get_models(
+            rank=rank,
+            world_size=world_size,
+            ctx=ctx,
+            embedding_config_type=test_params.embedding_config_type,
+            tables=test_params.embedding_tables,
+            config=test_params.model_tracker_config,
+            optimizer_type=OptimType.EXACT_ROWWISE_ADAGRAD,
+        )
+
+        # read momentum directly from the table
+        tbe: SplitTableBatchedEmbeddingBagsCodegen = (
+            (
+                # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no
+                # attribute `ec`.
+                dt_model._dmp_wrapped_module.module.ec._lookups[0]
+                ._emb_modules[0]
+                .emb_module
+            )
+            if test_params.embedding_config_type == EmbeddingConfig
+            else (
+                dt_model._dmp_wrapped_module.module.ebc._lookups[0]  # pyre-ignore
+                ._emb_modules[0]
+                .emb_module
+            )
+        )
+        assert isinstance(tbe, SplitTableBatchedEmbeddingBagsCodegen)
+        start_momentums = tbe.split_optimizer_states()[0][0].detach().clone()
+
+        dt_model_opt = torch.optim.Adam(dt_model.parameters(), lr=0.1)
+        baseline_opt = torch.optim.Adam(baseline_model.parameters(), lr=0.1)
+        features_list = model_input_generator(test_params.model_inputs, rank)
+
+        dt = dt_model.get_model_tracker()
+        table_fqns = dt.fqn_to_feature_names().keys()
+        table_fqns_list = list(table_fqns)
+
+        for features in features_list:
+            tracked_out = dt_model(features)
+            baseline_out = baseline_model(features)
+            unittest.TestCase().assertTrue(tracked_out.allclose(baseline_out))
+            tracked_out.sum().backward()
+            baseline_out.sum().backward()
+
+            dt_model_opt.step()
+            baseline_opt.step()
+
+        end_momentums = tbe.split_optimizer_states()[0][0].detach().clone()
+
+        delta_rows = dt.get_delta()
+        table_fqn = table_fqns_list[0]
+
+        ids = delta_rows[table_fqn].ids
+        tracked_momentum = none_throws(delta_rows[table_fqn].states)
+        unittest.TestCase().assertTrue(tracked_momentum is not None)
+        unittest.TestCase().assertTrue(ids.numel() == tracked_momentum.numel())
+        unittest.TestCase().assertTrue(bool((tracked_momentum != 0).all().item()))
+
+        expected_momentum = end_momentums[ids] - start_momentums[ids]
+        unittest.TestCase().assertTrue(tracked_momentum.allclose(expected_momentum))
diff --git a/torchrec/distributed/model_tracker/types.py b/torchrec/distributed/model_tracker/types.py
index cec95af91..43a1b9223 100644
--- a/torchrec/distributed/model_tracker/types.py
+++ b/torchrec/distributed/model_tracker/types.py
@@ -41,13 +41,21 @@ class TrackingMode(Enum):
     Tracking mode for ``ModelDeltaTracker``.
 
     Enums:
-    ID_ONLY: Tracks row IDs only, providing a lightweight option for monitoring.
-    EMBEDDING: Tracks both row IDs and their corresponding embedding values,
-        enabling precise top-k result calculations. However, this option comes with increased memory usage.
+        ID_ONLY: Tracks row IDs only, providing a lightweight option for monitoring.
+        EMBEDDING: Tracks both row IDs and their corresponding embedding values,
+            enabling precise top-k result calculations. However, this option comes
+            with increased memory usage.
+        MOMENTUM_LAST: Tracks both row IDs and their corresponding momentum values. This mode
+            supports approximate top-k delta-row selection.
+        MOMENTUM_DIFF: Tracks both row IDs and their corresponding momentum difference values.
+        ROWWISE_ADAGRAD: Tracks both row IDs and their corresponding rowwise adagrad states.
     """
 
     ID_ONLY = "id_only"
     EMBEDDING = "embedding"
+    MOMENTUM_LAST = "momentum_last"
+    MOMENTUM_DIFF = "momentum_diff"
+    ROWWISE_ADAGRAD = "rowwise_adagrad"
 
 
 class EmbdUpdateMode(Enum):
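Usage note (editorial, not part of the patch): a minimal sketch of how the new tracking modes and get_delta() arguments compose. It assumes a sharded model wrapper that exposes get_model_tracker(), as the tests above do; the table FQN, the training loop, and the model setup are illustrative only.

    from torchrec.distributed.model_tracker.types import ModelTrackerConfig, TrackingMode

    # MOMENTUM_DIFF tracks, per touched row, the change in the rowwise Adagrad
    # running sum of squared gradients between publishes.
    config = ModelTrackerConfig(
        tracking_mode=TrackingMode.MOMENTUM_DIFF,
        delete_on_read=True,
    )
    # ... shard the model with this tracker config, run forward/backward/step ...
    tracker = model.get_model_tracker()  # assumes a wrapper exposing the tracker
    delta = tracker.get_delta(
        top_percentage=0.5,  # keep the top 50% of hit rows per table
        per_table_percentage={"ec.embeddings.sparse_table_1": (0.2, "OVERRIDE")},  # illustrative FQN
        sorted_by_indices=True,  # return rows ordered by id position rather than by top-k rank
    )
    for fqn, rows in delta.items():
        # rows.ids are the hit row ids; rows.states hold the accumulated-gradient deltas
        print(fqn, rows.ids.shape, rows.states.shape)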