Skip to content

Adding support for tracking optimizers states in Model Delta Tracker. #3143

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion torchrec/distributed/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -1515,7 +1515,7 @@ def compute_and_output_dist(
):
embs = lookup(features)
if self.post_lookup_tracker_fn is not None:
self.post_lookup_tracker_fn(features, embs)
self.post_lookup_tracker_fn(self, features, embs)

with maybe_annotate_embedding_event(
EmbeddingEvent.OUTPUT_DIST, self._module_fqn, sharding_type
Expand Down
51 changes: 48 additions & 3 deletions torchrec/distributed/embedding_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import logging
from abc import ABC
from collections import OrderedDict
from typing import Any, cast, Dict, Iterator, List, Optional, Tuple, Union
from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -206,6 +206,10 @@ def __init__(
)

self.grouped_configs = grouped_configs
# Model tracker function to tracker optimizer state
self.optim_state_tracker_fn: Optional[
Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]
] = None

def _create_embedding_kernel(
self,
Expand Down Expand Up @@ -305,7 +309,13 @@ def forward(
self._feature_splits,
)
for emb_op, features in zip(self._emb_modules, features_by_group):
embeddings.append(emb_op(features).view(-1))
lookup = emb_op(features).view(-1)
embeddings.append(lookup)

# Model tracker optimizer state function, will only be set called
# when model tracker is configured to track optimizer state
if self.optim_state_tracker_fn is not None:
self.optim_state_tracker_fn(emb_op, features, lookup)

return embeddings_cat_empty_rank_handle(embeddings, self._dummy_embs_tensor)

Expand Down Expand Up @@ -409,6 +419,19 @@ def purge(self) -> None:
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
emb_module.purge()

def register_optim_state_tracker_fn(
self,
record_fn: Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None],
) -> None:
"""
Model tracker function to tracker optimizer state

Args:
record_fn (Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done.

"""
self.optim_state_tracker_fn = record_fn


class CommOpGradientScaling(torch.autograd.Function):
@staticmethod
Expand Down Expand Up @@ -481,6 +504,10 @@ def __init__(
if scale_weight_gradients and get_gradient_division()
else 1
)
# Model tracker function to tracker optimizer state
self.optim_state_tracker_fn: Optional[
Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]
] = None

def _create_embedding_kernel(
self,
Expand Down Expand Up @@ -608,7 +635,12 @@ def forward(
features._weights, self._scale_gradient_factor
)

embeddings.append(emb_op(features))
lookup = emb_op(features)
embeddings.append(lookup)
# Model tracker optimizer state function, will only be set called
# when model tracker is configured to track optimizer state
if self.optim_state_tracker_fn is not None:
self.optim_state_tracker_fn(emb_op, features, lookup)

if features.variable_stride_per_key() and len(self._emb_modules) > 1:
stride_per_rank_per_key = list(
Expand Down Expand Up @@ -738,6 +770,19 @@ def purge(self) -> None:
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
emb_module.purge()

def register_optim_state_tracker_fn(
self,
record_fn: Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None],
) -> None:
"""
Model tracker function to tracker optimizer state

Args:
record_fn (Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done.

"""
self.optim_state_tracker_fn = record_fn


class MetaInferGroupedEmbeddingsLookup(
BaseEmbeddingLookup[KeyedJaggedTensor, torch.Tensor], TBEToRegisterMixIn
Expand Down
6 changes: 3 additions & 3 deletions torchrec/distributed/embedding_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def __init__(
self._lookups: List[nn.Module] = []
self._output_dists: List[nn.Module] = []
self.post_lookup_tracker_fn: Optional[
Callable[[KeyedJaggedTensor, torch.Tensor], None]
Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]
] = None
self.post_odist_tracker_fn: Optional[Callable[..., None]] = None

Expand Down Expand Up @@ -426,14 +426,14 @@ def train(self, mode: bool = True): # pyre-ignore[3]

def register_post_lookup_tracker_fn(
self,
record_fn: Callable[[KeyedJaggedTensor, torch.Tensor], None],
record_fn: Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None],
) -> None:
"""
Register a function to be called after lookup is done. This is used for
tracking the lookup results and optimizer states.

Args:
record_fn (Callable[[KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done.
record_fn (Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done.

"""
if self.post_lookup_tracker_fn is not None:
Expand Down
2 changes: 1 addition & 1 deletion torchrec/distributed/embeddingbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -1459,7 +1459,7 @@ def compute_and_output_dist(
):
embs = lookup(features)
if self.post_lookup_tracker_fn is not None:
self.post_lookup_tracker_fn(features, embs)
self.post_lookup_tracker_fn(self, features, embs)

with maybe_annotate_embedding_event(
EmbeddingEvent.OUTPUT_DIST,
Expand Down
82 changes: 75 additions & 7 deletions torchrec/distributed/model_tracker/model_delta_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
import torch

from torch import nn

from torchrec.distributed.embedding import ShardedEmbeddingCollection
from torchrec.distributed.embedding_lookup import (
GroupedEmbeddingsLookup,
GroupedPooledEmbeddingsLookup,
)
from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
from torchrec.distributed.model_tracker.delta_store import DeltaStore
from torchrec.distributed.model_tracker.types import (
Expand All @@ -27,9 +32,16 @@
# Only IDs are tracked, no additional state is stored.
TrackingMode.ID_ONLY: EmbdUpdateMode.NONE,
# TrackingMode.EMBEDDING utilizes EmbdUpdateMode.FIRST to ensure that
# the earliest embedding values are stored since the last checkpoint or snapshot.
# This mode is used for computing topk delta rows, which is currently achieved by running (new_emb - old_emb).norm().topk().
# the earliest embedding values are stored since the last checkpoint
# or snapshot. This mode is used for computing topk delta rows, which
# is currently achieved by running (new_emb - old_emb).norm().topk().
TrackingMode.EMBEDDING: EmbdUpdateMode.FIRST,
# TrackingMode.MOMENTUM utilizes EmbdUpdateMode.LAST to ensure that
# the most recent momentum values—capturing the accumulated gradient
# direction and magnitude—are stored since the last batch.
# This mode supports approximate top-k delta-row selection, can be
# obtained by running momentum.norm().topk().
TrackingMode.MOMENTUM_LAST: EmbdUpdateMode.LAST,
}

# Tracking is current only supported for ShardedEmbeddingCollection and ShardedEmbeddingBagCollection.
Expand Down Expand Up @@ -141,7 +153,9 @@ def trigger_compaction(self) -> None:
# Update the current compact index to the end index to avoid duplicate compaction.
self.curr_compact_index = end_idx

def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None:
def record_lookup(
self, emb_module: nn.Module, kjt: KeyedJaggedTensor, states: torch.Tensor
) -> None:
"""
Records the IDs from a given KeyedJaggedTensor and their corresponding embeddings/parameter states.

Expand All @@ -152,6 +166,7 @@ def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None:
(in ID_ONLY mode) or both IDs and their corresponding embeddings (in EMBEDDING mode).

Args:
emb_module (nn.Module): The embedding module in which the lookup was performed.
kjt (KeyedJaggedTensor): The KeyedJaggedTensor containing IDs to record.
states (torch.Tensor): The embeddings or states corresponding to the IDs in the kjt.
"""
Expand All @@ -162,7 +177,9 @@ def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None:
# In EMBEDDING mode, we track per feature IDs and corresponding embeddings received in the current batch.
elif self._mode == TrackingMode.EMBEDDING:
self.record_embeddings(kjt, states)

# In MOMENTUM_LAST mode, we track per feature IDs and corresponding momentum values received in the current batch.
elif self._mode == TrackingMode.MOMENTUM_LAST:
self.record_momentum(emb_module, kjt)
else:
raise NotImplementedError(f"Tracking mode {self._mode} is not supported")

Expand Down Expand Up @@ -228,6 +245,39 @@ def record_embeddings(
states=torch.cat(per_table_emb[table_fqn]),
)

def record_momentum(
self,
emb_module: nn.Module,
kjt: KeyedJaggedTensor,
) -> None:
# FIXME: this is the momentum from last iteration, use momentum from current iter
# for correctness.
# pyre-ignore Undefined attribute [16]:
momentum = emb_module._emb_module.momentum1_dev
# FIXME: support multiple tables per group, information can be extracted from
# module._config (i.e., GroupedEmbeddingConfig)
# pyre-ignore Undefined attribute [16]:
states = momentum.view(-1, emb_module._config.embedding_dims()[0])[
kjt.values()
].norm(dim=1)

offsets: torch.Tensor = torch.ops.fbgemm.asynchronous_complete_cumsum(
torch.tensor(kjt.length_per_key(), dtype=torch.int64)
)
assert (
kjt.values().numel() == states.numel()
), f"number of ids and states mismatch, expect {kjt.values()=}, {kjt.values().numel()}, but got {states.numel()} "

for i, key in enumerate(kjt.keys()):
fqn = self.feature_to_fqn[key]
per_key_states = states[offsets[i] : offsets[i + 1]]
self.store.append(
batch_idx=self.curr_batch_idx,
table_fqn=fqn,
ids=kjt[key].values(),
states=per_key_states,
)

def get_delta_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tensor]:
"""
Return a dictionary of hit local IDs for each sparse feature. Ids are
Expand Down Expand Up @@ -380,13 +430,31 @@ def _clean_fqn_fn(self, fqn: str) -> str:
def _validate_and_init_tracker_fns(self) -> None:
"To validate the mode is supported for the given module"
for module in self.tracked_modules.values():
# EMBEDDING mode is only supported for ShardedEmbeddingCollection
assert not (
isinstance(module, ShardedEmbeddingBagCollection)
and self._mode == TrackingMode.EMBEDDING
), "EBC's lookup returns pooled embeddings and currently, we do not support tracking raw embeddings."
# register post lookup function
# pyre-ignore[29]
module.register_post_lookup_tracker_fn(self.record_lookup)

if (
self._mode == TrackingMode.ID_ONLY
or self._mode == TrackingMode.EMBEDDING
):
# register post lookup function
# pyre-ignore[29]
module.register_post_lookup_tracker_fn(self.record_lookup)
elif self._mode == TrackingMode.MOMENTUM_LAST:
# pyre-ignore[29]:
for lookup in module._lookups:
assert isinstance(
lookup,
(GroupedEmbeddingsLookup, GroupedPooledEmbeddingsLookup),
)
lookup.register_optim_state_tracker_fn(self.record_lookup)
else:
raise NotImplementedError(
f"Tracking mode {self._mode} is not supported"
)
# register auto compaction function at odist
if self._auto_compact:
# pyre-ignore[29]
Expand Down
Loading
Loading