diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py
index 362240d88..e6b9a13d2 100644
--- a/torchrec/distributed/embedding.py
+++ b/torchrec/distributed/embedding.py
@@ -1515,7 +1515,7 @@ def compute_and_output_dist(
         ):
             embs = lookup(features)
             if self.post_lookup_tracker_fn is not None:
-                self.post_lookup_tracker_fn(features, embs)
+                self.post_lookup_tracker_fn(self, features, embs)
 
             with maybe_annotate_embedding_event(
                 EmbeddingEvent.OUTPUT_DIST, self._module_fqn, sharding_type
diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py
index fce86dc19..86eece133 100644
--- a/torchrec/distributed/embedding_lookup.py
+++ b/torchrec/distributed/embedding_lookup.py
@@ -10,7 +10,7 @@
 import logging
 from abc import ABC
 from collections import OrderedDict
-from typing import Any, cast, Dict, Iterator, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -206,6 +206,10 @@ def __init__(
         )
 
         self.grouped_configs = grouped_configs
+        # Model tracker function to track optimizer state
+        self.optim_state_tracker_fn: Optional[
+            Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]
+        ] = None
 
     def _create_embedding_kernel(
         self,
@@ -305,7 +309,13 @@ def forward(
             self._feature_splits,
         )
         for emb_op, features in zip(self._emb_modules, features_by_group):
-            embeddings.append(emb_op(features).view(-1))
+            lookup = emb_op(features).view(-1)
+            embeddings.append(lookup)
+
+            # Model tracker optimizer state function; it is only called when the
+            # model tracker is configured to track optimizer state
+            if self.optim_state_tracker_fn is not None:
+                self.optim_state_tracker_fn(emb_op, features, lookup)
 
         return embeddings_cat_empty_rank_handle(embeddings, self._dummy_embs_tensor)
 
@@ -409,6 +419,19 @@ def purge(self) -> None:
             # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
             emb_module.purge()
 
+    def register_optim_state_tracker_fn(
+        self,
+        record_fn: Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None],
+    ) -> None:
+        """
+        Register a model tracker function to track optimizer state.
+
+        Args:
+            record_fn (Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done.
+
+        """
+        self.optim_state_tracker_fn = record_fn
+
 
 class CommOpGradientScaling(torch.autograd.Function):
     @staticmethod
@@ -481,6 +504,10 @@ def __init__(
             if scale_weight_gradients and get_gradient_division()
             else 1
         )
+        # Model tracker function to track optimizer state
+        self.optim_state_tracker_fn: Optional[
+            Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]
+        ] = None
 
     def _create_embedding_kernel(
         self,
@@ -608,7 +635,12 @@ def forward(
                     features._weights, self._scale_gradient_factor
                 )
 
-            embeddings.append(emb_op(features))
+            lookup = emb_op(features)
+            embeddings.append(lookup)
+            # Model tracker optimizer state function; it is only called when the
+            # model tracker is configured to track optimizer state
+            if self.optim_state_tracker_fn is not None:
+                self.optim_state_tracker_fn(emb_op, features, lookup)
 
         if features.variable_stride_per_key() and len(self._emb_modules) > 1:
             stride_per_rank_per_key = list(
@@ -738,6 +770,19 @@ def purge(self) -> None:
             # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
             emb_module.purge()
 
+    def register_optim_state_tracker_fn(
+        self,
+        record_fn: Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None],
+    ) -> None:
+        """
+        Register a model tracker function to track optimizer state.
+
+        Args:
+            record_fn (Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done.
+
+        """
+        self.optim_state_tracker_fn = record_fn
+
 
 class MetaInferGroupedEmbeddingsLookup(
     BaseEmbeddingLookup[KeyedJaggedTensor, torch.Tensor], TBEToRegisterMixIn
diff --git a/torchrec/distributed/embedding_types.py b/torchrec/distributed/embedding_types.py
index 20f0a4c88..eee2cba27 100644
--- a/torchrec/distributed/embedding_types.py
+++ b/torchrec/distributed/embedding_types.py
@@ -373,7 +373,7 @@ def __init__(
         self._lookups: List[nn.Module] = []
         self._output_dists: List[nn.Module] = []
         self.post_lookup_tracker_fn: Optional[
-            Callable[[KeyedJaggedTensor, torch.Tensor], None]
+            Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]
         ] = None
         self.post_odist_tracker_fn: Optional[Callable[..., None]] = None
 
@@ -426,14 +426,14 @@ def train(self, mode: bool = True):  # pyre-ignore[3]
 
     def register_post_lookup_tracker_fn(
         self,
-        record_fn: Callable[[KeyedJaggedTensor, torch.Tensor], None],
+        record_fn: Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None],
     ) -> None:
         """
         Register a function to be called after lookup is done. This is
         used for tracking the lookup results and optimizer states.
 
         Args:
-            record_fn (Callable[[KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done.
+            record_fn (Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done.
 
         """
         if self.post_lookup_tracker_fn is not None:
diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py
index 2beaf3aef..af63585b4 100644
--- a/torchrec/distributed/embeddingbag.py
+++ b/torchrec/distributed/embeddingbag.py
@@ -1459,7 +1459,7 @@ def compute_and_output_dist(
         ):
             embs = lookup(features)
             if self.post_lookup_tracker_fn is not None:
-                self.post_lookup_tracker_fn(features, embs)
+                self.post_lookup_tracker_fn(self, features, embs)
 
             with maybe_annotate_embedding_event(
                 EmbeddingEvent.OUTPUT_DIST,
diff --git a/torchrec/distributed/model_tracker/model_delta_tracker.py b/torchrec/distributed/model_tracker/model_delta_tracker.py
index 905bf7648..29c854175 100644
--- a/torchrec/distributed/model_tracker/model_delta_tracker.py
+++ b/torchrec/distributed/model_tracker/model_delta_tracker.py
@@ -13,7 +13,12 @@
 import torch
 
 from torch import nn
+
 from torchrec.distributed.embedding import ShardedEmbeddingCollection
+from torchrec.distributed.embedding_lookup import (
+    GroupedEmbeddingsLookup,
+    GroupedPooledEmbeddingsLookup,
+)
 from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
 from torchrec.distributed.model_tracker.delta_store import DeltaStore
 from torchrec.distributed.model_tracker.types import (
@@ -27,9 +32,16 @@
     # Only IDs are tracked, no additional state is stored.
     TrackingMode.ID_ONLY: EmbdUpdateMode.NONE,
     # TrackingMode.EMBEDDING utilizes EmbdUpdateMode.FIRST to ensure that
-    # the earliest embedding values are stored since the last checkpoint or snapshot.
-    # This mode is used for computing topk delta rows, which is currently achieved by running (new_emb - old_emb).norm().topk().
+    # the earliest embedding values are stored since the last checkpoint
+    # or snapshot. This mode is used for computing topk delta rows, which
+    # is currently achieved by running (new_emb - old_emb).norm().topk().
     TrackingMode.EMBEDDING: EmbdUpdateMode.FIRST,
+    # TrackingMode.MOMENTUM_LAST utilizes EmbdUpdateMode.LAST to ensure that
+    # the most recent momentum values, capturing the accumulated gradient
+    # direction and magnitude, are stored since the last batch.
+    # This mode supports approximate top-k delta-row selection, which can be
+    # obtained by running momentum.norm().topk().
+    TrackingMode.MOMENTUM_LAST: EmbdUpdateMode.LAST,
 }
 
 # Tracking is current only supported for ShardedEmbeddingCollection and ShardedEmbeddingBagCollection.
@@ -141,7 +153,9 @@ def trigger_compaction(self) -> None:
         # Update the current compact index to the end index to avoid duplicate compaction.
         self.curr_compact_index = end_idx
 
-    def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None:
+    def record_lookup(
+        self, emb_module: nn.Module, kjt: KeyedJaggedTensor, states: torch.Tensor
+    ) -> None:
         """
         Records the IDs from a given KeyedJaggedTensor and their corresponding embeddings/parameter states.
 
@@ -152,6 +166,7 @@ def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None:
         (in ID_ONLY mode) or both IDs and their corresponding embeddings (in EMBEDDING mode).
 
         Args:
+            emb_module (nn.Module): The embedding module in which the lookup was performed.
             kjt (KeyedJaggedTensor): The KeyedJaggedTensor containing IDs to record.
             states (torch.Tensor): The embeddings or states corresponding to the IDs in the kjt.
         """
@@ -162,7 +177,9 @@ def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None:
         # In EMBEDDING mode, we track per feature IDs and corresponding embeddings received in the current batch.
         elif self._mode == TrackingMode.EMBEDDING:
             self.record_embeddings(kjt, states)
-
+        # In MOMENTUM_LAST mode, we track per feature IDs and corresponding momentum values received in the current batch.
+        elif self._mode == TrackingMode.MOMENTUM_LAST:
+            self.record_momentum(emb_module, kjt)
         else:
             raise NotImplementedError(f"Tracking mode {self._mode} is not supported")
 
@@ -228,6 +245,39 @@ def record_embeddings(
                 states=torch.cat(per_table_emb[table_fqn]),
             )
 
+    def record_momentum(
+        self,
+        emb_module: nn.Module,
+        kjt: KeyedJaggedTensor,
+    ) -> None:
+        # FIXME: this is the momentum from last iteration, use momentum from current iter
+        # for correctness.
+        # pyre-ignore Undefined attribute [16]:
+        momentum = emb_module._emb_module.momentum1_dev
+        # FIXME: support multiple tables per group, information can be extracted from
+        # module._config (i.e., GroupedEmbeddingConfig)
+        # pyre-ignore Undefined attribute [16]:
+        states = momentum.view(-1, emb_module._config.embedding_dims()[0])[
+            kjt.values()
+        ].norm(dim=1)
+
+        offsets: torch.Tensor = torch.ops.fbgemm.asynchronous_complete_cumsum(
+            torch.tensor(kjt.length_per_key(), dtype=torch.int64)
+        )
+        assert (
+            kjt.values().numel() == states.numel()
+        ), f"number of ids and states mismatch, expected {kjt.values().numel()} but got {states.numel()}"
+
+        for i, key in enumerate(kjt.keys()):
+            fqn = self.feature_to_fqn[key]
+            per_key_states = states[offsets[i] : offsets[i + 1]]
+            self.store.append(
+                batch_idx=self.curr_batch_idx,
+                table_fqn=fqn,
+                ids=kjt[key].values(),
+                states=per_key_states,
+            )
+
     def get_delta_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tensor]:
         """
         Return a dictionary of hit local IDs for each sparse feature. Ids are
@@ -380,13 +430,31 @@ def _clean_fqn_fn(self, fqn: str) -> str:
     def _validate_and_init_tracker_fns(self) -> None:
         "To validate the mode is supported for the given module"
         for module in self.tracked_modules.values():
+            # EMBEDDING mode is only supported for ShardedEmbeddingCollection
             assert not (
                 isinstance(module, ShardedEmbeddingBagCollection)
                 and self._mode == TrackingMode.EMBEDDING
             ), "EBC's lookup returns pooled embeddings and currently, we do not support tracking raw embeddings."
-            # register post lookup function
-            # pyre-ignore[29]
-            module.register_post_lookup_tracker_fn(self.record_lookup)
+
+            if (
+                self._mode == TrackingMode.ID_ONLY
+                or self._mode == TrackingMode.EMBEDDING
+            ):
+                # register post lookup function
+                # pyre-ignore[29]
+                module.register_post_lookup_tracker_fn(self.record_lookup)
+            elif self._mode == TrackingMode.MOMENTUM_LAST:
+                # pyre-ignore[29]:
+                for lookup in module._lookups:
+                    assert isinstance(
+                        lookup,
+                        (GroupedEmbeddingsLookup, GroupedPooledEmbeddingsLookup),
+                    )
+                    lookup.register_optim_state_tracker_fn(self.record_lookup)
+            else:
+                raise NotImplementedError(
+                    f"Tracking mode {self._mode} is not supported"
+                )
             # register auto compaction function at odist
             if self._auto_compact:
                 # pyre-ignore[29]
diff --git a/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py b/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py
index a92a9b286..c3f641b98 100644
--- a/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py
+++ b/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py
@@ -1463,6 +1463,101 @@ def test_multiple_consumers(
             output_params=output_params,
         )
 
+    @parameterized.expand(
+        [
+            (
+                "EC_and_single_feature",
+                ModelDeltaTrackerInputTestParams(
+                    embedding_config_type=EmbeddingConfig,
+                    embedding_tables=[
+                        EmbeddingTableProps(
+                            embedding_table_config=EmbeddingConfig(
+                                name="sparse_table_1",
+                                num_embeddings=NUM_EMBEDDINGS,
+                                embedding_dim=EMBEDDING_DIM,
+                                feature_names=["f1"],
+                            ),
+                            sharding=ShardingType.ROW_WISE,
+                        ),
+                    ],
+                    model_tracker_config=ModelTrackerConfig(
+                        tracking_mode=TrackingMode.MOMENTUM_LAST,
+                        delete_on_read=True,
+                    ),
+                    model_inputs=[
+                        ModelInput(
+                            keys=["f1"],
+                            values=torch.tensor([0, 2, 4, 6, 8, 10, 12, 14]),
+                            offsets=torch.tensor([0, 2, 2, 4, 6, 7, 8]),
+                        ),
+                        ModelInput(
+                            keys=["f1"],
+                            values=torch.tensor([8, 10, 12, 14, 0, 2, 4, 6]),
+                            offsets=torch.tensor([0, 2, 2, 4, 6, 6, 8]),
+                        ),
+                        ModelInput(
+                            keys=["f1"],
+                            values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
+                            offsets=torch.tensor([0, 0, 0, 4, 4, 4, 8]),
+                        ),
+                    ],
+                ),
+            ),
+            (
+                "EBC_and_multiple_feature",
+                ModelDeltaTrackerInputTestParams(
+                    embedding_config_type=EmbeddingBagConfig,
+                    embedding_tables=[
+                        EmbeddingTableProps(
+                            embedding_table_config=EmbeddingBagConfig(
+                                name="sparse_table_1",
+                                num_embeddings=NUM_EMBEDDINGS,
+                                embedding_dim=EMBEDDING_DIM,
+                                feature_names=["f1", "f2"],
+                                pooling=PoolingType.SUM,
+                            ),
+                            sharding=ShardingType.ROW_WISE,
+                        ),
+                    ],
+                    model_tracker_config=ModelTrackerConfig(
+                        tracking_mode=TrackingMode.MOMENTUM_LAST,
+                        delete_on_read=True,
+                    ),
+                    model_inputs=[
+                        ModelInput(
+                            keys=["f1", "f2"],
+                            values=torch.tensor([0, 2, 4, 6, 8, 10, 12, 14]),
+                            offsets=torch.tensor([0, 2, 2, 4, 6, 7, 8]),
+                        ),
+                        ModelInput(
+                            keys=["f1", "f2"],
+                            values=torch.tensor([8, 10, 12, 14, 0, 2, 4, 6]),
+                            offsets=torch.tensor([0, 2, 2, 4, 6, 6, 8]),
+                        ),
+                        ModelInput(
+                            keys=["f1", "f2"],
+                            values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
+                            offsets=torch.tensor([0, 0, 0, 4, 4, 4, 8]),
+                        ),
+                    ],
+                ),
+            ),
+        ]
+    )
+    @skip_if_asan
+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @unittest.skipIf(torch.cuda.device_count() < 2, "test requires 2+ GPUs")
+    def test_duplication_with_momentum(
+        self,
+        _test_name: str,
+        test_params: ModelDeltaTrackerInputTestParams,
+    ) -> None:
+        self._run_multi_process_test(
+            callable=_test_duplication_with_momentum,
+            world_size=self.world_size,
+            test_params=test_params,
+        )
+
 
 def _test_fqn_to_feature_names(
     rank: int,
@@ -1859,3 +1954,52 @@ def _test_multiple_consumer(
             and returned.allclose(expected_ids),
             f"{i=}, {table_fqn=}, mismatch {returned=} vs {expected_ids=}",
         )
+
+
+def _test_duplication_with_momentum(
+    rank: int,
+    world_size: int,
+    test_params: ModelDeltaTrackerInputTestParams,
+) -> None:
+    """
+    Test momentum tracking functionality in model delta tracker.
+
+    Validates that the tracker correctly captures and stores momentum values from
+    optimizer states when using TrackingMode.MOMENTUM_LAST mode.
+    """
+    with MultiProcessContext(
+        rank=rank,
+        world_size=world_size,
+        backend="nccl" if torch.cuda.is_available() else "gloo",
+    ) as ctx:
+        dt_model, baseline_model = get_models(
+            rank=rank,
+            world_size=world_size,
+            ctx=ctx,
+            embedding_config_type=test_params.embedding_config_type,
+            tables=test_params.embedding_tables,
+            config=test_params.model_tracker_config,
+        )
+        dt_model_opt = torch.optim.Adam(dt_model.parameters(), lr=0.1)
+        baseline_opt = torch.optim.Adam(baseline_model.parameters(), lr=0.1)
+        features_list = model_input_generator(test_params.model_inputs, rank)
+        dt = dt_model.get_model_tracker()
+        table_fqns = dt.fqn_to_feature_names().keys()
+        table_fqns_list = list(table_fqns)
+        for features in features_list:
+            tracked_out = dt_model(features)
+            baseline_out = baseline_model(features)
+            unittest.TestCase().assertTrue(tracked_out.allclose(baseline_out))
+            tracked_out.sum().backward()
+            baseline_out.sum().backward()
+            dt_model_opt.step()
+            baseline_opt.step()
+
+        delta_rows = dt.get_delta()
+        for table_fqn in table_fqns_list:
+            ids = delta_rows[table_fqn].ids
+            states = none_throws(delta_rows[table_fqn].states)
+
+            unittest.TestCase().assertTrue(states is not None)
+            unittest.TestCase().assertTrue(ids.numel() == states.numel())
+            unittest.TestCase().assertTrue(bool((states != 0).all().item()))
diff --git a/torchrec/distributed/model_tracker/types.py b/torchrec/distributed/model_tracker/types.py
index cec95af91..a5a56514c 100644
--- a/torchrec/distributed/model_tracker/types.py
+++ b/torchrec/distributed/model_tracker/types.py
@@ -41,13 +41,17 @@ class TrackingMode(Enum):
     Tracking mode for ``ModelDeltaTracker``.
 
     Enums:
-        ID_ONLY: Tracks row IDs only, providing a lightweight option for monitoring.
-        EMBEDDING: Tracks both row IDs and their corresponding embedding values,
-            enabling precise top-k result calculations. However, this option comes with increased memory usage.
+        ID_ONLY: Tracks row IDs only, providing a lightweight option for monitoring.
+        EMBEDDING: Tracks both row IDs and their corresponding embedding values,
+            enabling precise top-k result calculations. However, this option comes
+            with increased memory usage.
+        MOMENTUM_LAST: Tracks both row IDs and their corresponding momentum values. This mode
+            supports approximate top-k delta-row selection.
     """
 
     ID_ONLY = "id_only"
     EMBEDDING = "embedding"
+    MOMENTUM_LAST = "momentum_last"
 
 
 class EmbdUpdateMode(Enum):
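The MOMENTUM_LAST path boils down to one small piece of arithmetic: view the fused optimizer's flat momentum1_dev buffer as (num_rows, embedding_dim), index it with the ids looked up in the batch, and keep one L2 norm per row as the tracked state. Below is a minimal sketch of that arithmetic and of the approximate top-k selection it enables; it is illustrative only, not part of the patch, and num_rows, embedding_dim, ids, and k are made-up toy values.

# Illustrative sketch only (not part of the patch): the per-row momentum norm
# computed in record_momentum(), and the approximate top-k delta-row selection
# described in UPDATE_MODE_MAP for TrackingMode.MOMENTUM_LAST.
import torch

# Toy stand-in for the fused optimizer state: momentum1_dev is flat, so it is
# viewed as (num_rows, embedding_dim) before indexing by the looked-up ids.
num_rows, embedding_dim = 4, 2
momentum = torch.arange(num_rows * embedding_dim, dtype=torch.float32)
ids = torch.tensor([0, 3])  # rows referenced by the batch's KeyedJaggedTensor
states = momentum.view(-1, embedding_dim)[ids].norm(dim=1)  # one norm per touched row

# The momentum norm acts as a proxy for how far each row has drifted, so an
# approximate "most changed rows" set is just a top-k over the stored norms.
k = min(1, states.numel())
_, topk_idx = torch.topk(states, k)
print(ids[topk_idx])  # tensor([3]) for this toy state

In the tracker itself these per-row norms are appended to the DeltaStore keyed by table FQN and surfaced through get_delta(), which is what the new _test_duplication_with_momentum test asserts on.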