fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py (143 additions, 1 deletion)
@@ -22,6 +22,7 @@

import torch # usort:skip
from torch import nn, Tensor # usort:skip
from torch.autograd.profiler import record_function # usort:skip

# @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers
import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
@@ -626,6 +627,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
lxu_cache_locations_list: List[Tensor]
lxu_cache_locations_empty: Tensor
timesteps_prefetched: List[int]
prefetched_info: List[Tuple[Tensor, Tensor, Optional[Tensor]]]
record_cache_metrics: RecordCacheMetrics
# pyre-fixme[13]: Attribute `uvm_cache_stats` is never initialized.
uvm_cache_stats: torch.Tensor
@@ -690,6 +692,8 @@ def __init__( # noqa C901
embedding_table_index_type: torch.dtype = torch.int64,
embedding_table_offset_type: torch.dtype = torch.int64,
embedding_shard_info: Optional[List[Tuple[int, int, int, int]]] = None,
enable_raw_embedding_streaming: bool = False,
res_params: Optional[RESParams] = None,
) -> None:
super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__()
self.uuid = str(uuid.uuid4())
@@ -700,6 +704,7 @@ def __init__( # noqa C901
)

self.logging_table_name: str = self.get_table_name_for_logging(table_names)
self.enable_raw_embedding_streaming: bool = enable_raw_embedding_streaming
self.pooling_mode = pooling_mode
self.is_nobag: bool = self.pooling_mode == PoolingMode.NONE

@@ -1460,6 +1465,30 @@ def __init__( # noqa C901
)
self.embedding_table_offset_type: torch.dtype = embedding_table_offset_type

self.prefetched_info: List[Tuple[Tensor, Tensor, Optional[Tensor]]] = (
torch.jit.annotate(List[Tuple[Tensor, Tensor, Optional[Tensor]]], [])
)
if self.enable_raw_embedding_streaming:
self.res_params: RESParams = res_params or RESParams()
self.res_params.table_sizes = [0] + list(accumulate(rows))
res_port_from_env = os.getenv("LOCAL_RES_PORT")
self.res_params.res_server_port = (
int(res_port_from_env) if res_port_from_env else 0
)
# pyre-fixme[4]: Attribute must be annotated.
self._raw_embedding_streamer = torch.classes.fbgemm.RawEmbeddingStreamer(
self.uuid,
self.enable_raw_embedding_streaming,
self.res_params.res_store_shards,
self.res_params.res_server_port,
self.res_params.table_names,
self.res_params.table_offsets,
self.res_params.table_sizes,
)
logging.info(
f"{self.uuid} raw embedding streaming enabled with {self.res_params=}"
)

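A minimal usage sketch of the new constructor knobs, assuming the RESParams fields referenced in this diff (res_store_shards, table_names, table_offsets; the full RESParams definition is not shown here):

    # Hypothetical example; field names are taken from this diff.
    res_params = RESParams(
        res_store_shards=2,  # number of remote store shards to stream to
        table_names=["table_a", "table_b"],
        table_offsets=[0, 0],
    )
    emb = SplitTableBatchedEmbeddingBagsCodegen(
        embedding_specs=embedding_specs,  # caller-defined, as usual
        enable_raw_embedding_streaming=True,
        res_params=res_params,
    )
    # __init__ then fills in res_params.table_sizes from the table row
    # counts and res_params.res_server_port from the LOCAL_RES_PORT
    # environment variable (0 if unset).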
@torch.jit.ignore
def log(self, msg: str) -> None:
"""
@@ -1979,8 +2008,13 @@ def forward( # noqa: C901
# In forward, we don't enable multi-pass prefetch as we want the process
# to be as fast as possible and memory usage doesn't matter (will be recycled
# by dense fwd/bwd)
# TODO: Properly pass in the hash_zch_identities
self._prefetch(
indices, offsets, vbe_metadata, multipass_prefetch_config=None
indices,
offsets,
vbe_metadata,
multipass_prefetch_config=None,
hash_zch_identities=None,
)

if len(self.timesteps_prefetched) > 0:
@@ -2503,6 +2537,7 @@ def _prefetch(
offsets: Tensor,
vbe_metadata: Optional[invokers.lookup_args.VBEMetadata] = None,
multipass_prefetch_config: Optional[MultiPassPrefetchConfig] = None,
hash_zch_identities: Optional[Tensor] = None,
) -> None:
if not is_torchdynamo_compiling():
# Mutations of nn.Module attr forces dynamo restart of Analysis which increases compilation time
@@ -2521,7 +2556,13 @@
self.local_uvm_cache_stats.zero_()
self._report_io_size_count("prefetch_input", indices)

# stream the previously prefetched rows before this prefetch updates the cache
self.raw_embedding_stream()

final_lxu_cache_locations = torch.empty_like(indices, dtype=torch.int32)
linear_cache_indices_merged = torch.zeros(
0, dtype=indices.dtype, device=indices.device
)
for (
partial_indices,
partial_lxu_cache_locations,
@@ -2537,6 +2578,9 @@
vbe_metadata.max_B if vbe_metadata is not None else -1,
base_offset,
)
linear_cache_indices_merged = torch.cat(
[linear_cache_indices_merged, linear_cache_indices]
)

if (
self.record_cache_metrics.record_cache_miss_counter
@@ -2617,6 +2661,8 @@
if self.should_log():
self.print_uvm_cache_stats(use_local_cache=False)

self._store_prefetched_tensors(linear_cache_indices_merged, hash_zch_identities)

def should_log(self) -> bool:
"""Determines if we should log for this step, using exponentially decreasing frequency.

Expand Down Expand Up @@ -3829,6 +3875,102 @@ def _debug_print_input_stats_factory_null(
return _debug_print_input_stats_factory_impl
return _debug_print_input_stats_factory_null

@torch.jit.ignore
def raw_embedding_stream(self) -> None:
if not self.enable_raw_embedding_streaming:
return None
# When pipelining is enabled, the prefetch for iter i runs before the
# sparse backward of iter i - 1, so the embeddings for iter i - 1's
# changed ids have not been updated yet; we can only stream the
# indices saved at iter i - 2.
# When pipelining is disabled, the prefetch for iter i runs before the
# forward of iter i, so iter i - 1's changed ids can be streamed safely.
target_prev_iter = 1
if self.prefetch_pipeline:
target_prev_iter = 2
if len(self.prefetched_info) < target_prev_iter:
return None
with record_function(
"## uvm_lookup_prefetched_rows {} {} ##".format(self.timestep, self.uuid)
):
(updated_indices, updated_count, updated_identities) = (
self.prefetched_info.pop(0)
)
updated_locations = torch.ops.fbgemm.lxu_cache_lookup(
updated_indices,
self.lxu_cache_state,
self.total_cache_hash_size,
gather_cache_stats=False, # not collecting cache stats
num_uniq_cache_indices=updated_count,
)
updated_weights = torch.empty(
[updated_indices.size()[0], self.max_D_cache],
# pyre-ignore Incompatible parameter type [6]: In call `torch._C._VariableFunctions.empty`, for argument `dtype`, expected `Optional[dtype]` but got `Union[Module, dtype, Tensor]`
dtype=self.lxu_cache_weights.dtype,
# pyre-ignore Incompatible parameter type [6]: In call `torch._C._VariableFunctions.empty`, for argument `device`, expected `Union[None, int, str, device]` but got `Union[Module, device, Tensor]`
device=self.lxu_cache_weights.device,
)
torch.ops.fbgemm.masked_index_select(
updated_weights,
updated_locations,
self.lxu_cache_weights,
updated_count,
)
# stream weights
self._raw_embedding_streamer.stream(
updated_indices.to(device=torch.device("cpu")),
updated_weights.to(device=torch.device("cpu")),
(
updated_identities.to(device=torch.device("cpu"))
if updated_identities is not None
else None
),
updated_count.to(device=torch.device("cpu")),
False, # require_tensor_copy
False, # blocking_tensor_copy
)

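The interplay between raw_embedding_stream (consumer) and _store_prefetched_tensors below (producer) is a FIFO with a fixed lag; a small stand-alone sketch of the timing, using plain ints in place of the tensor tuples:

    # Hypothetical simulation of the prefetched_info FIFO lag.
    def simulate(num_iters: int, prefetch_pipeline: bool) -> None:
        prefetched_info = []  # stands in for self.prefetched_info
        target_prev_iter = 2 if prefetch_pipeline else 1
        for i in range(num_iters):
            # raw_embedding_stream(): pop only once enough iters are queued
            if len(prefetched_info) >= target_prev_iter:
                print(f"iter {i}: streams ids saved at iter {prefetched_info.pop(0)}")
            # _store_prefetched_tensors(): save this iteration's unique ids
            prefetched_info.append(i)

    simulate(4, prefetch_pipeline=False)  # iter 1 streams iter 0's ids, ...
    simulate(4, prefetch_pipeline=True)   # iter 2 streams iter 0's ids, ...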
@torch.jit.ignore
def _store_prefetched_tensors(
self,
linear_cache_indices_merged: torch.Tensor,
hash_zch_identities: Optional[torch.Tensor],
) -> None:
"""
NOTE: this needs to be a method with jit.ignore as the identities tensor is conditional.
This function stores the prefetched tensors for the raw embedding streaming.
"""
if self.enable_raw_embedding_streaming:
with record_function(
"## uvm_save_prefetched_rows {} {} ##".format(self.timestep, self.uuid)
):
(
linear_unique_indices,
linear_unique_indices_length,
_,
) = torch.ops.fbgemm.get_unique_indices(
linear_cache_indices_merged,
self.total_cache_hash_size,
compute_count=False,
)
linear_unique_indices = linear_unique_indices.narrow(
0, 0, linear_unique_indices_length[0]
)
self.prefetched_info.append(
(
linear_unique_indices,
linear_unique_indices_length,
(
hash_zch_identities.index_select(
dim=0, index=linear_unique_indices
).to(device=torch.device("cpu"))
if hash_zch_identities is not None
else None
),
)
)

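For intuition, the dedup step above keeps one entry per linearized cache index; torch.unique is a rough stand-in for fbgemm's get_unique_indices op (the real op also returns the unique count as a device-side tensor, used to narrow the padded result):

    import torch

    # Rough stand-in for the dedup performed by get_unique_indices above.
    linear_cache_indices_merged = torch.tensor([7, 3, 7, 0, 3])
    linear_unique_indices = torch.unique(linear_cache_indices_merged)
    print(linear_unique_indices)  # tensor([0, 3, 7])
    # hash_zch_identities rows for these unique indices are then gathered
    # with index_select(dim=0, ...) and moved to CPU alongside them.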

class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
"""
@@ -19,13 +19,16 @@ namespace fbgemm_gpu {
struct StreamQueueItem {
at::Tensor indices;
at::Tensor weights;
std::optional<at::Tensor> identities;
at::Tensor count;
StreamQueueItem(
at::Tensor src_indices,
at::Tensor src_weights,
std::optional<at::Tensor> src_identities,
at::Tensor src_count) {
indices = std::move(src_indices);
weights = std::move(src_weights);
identities = std::move(src_identities);
count = std::move(src_count);
}
};
@@ -67,21 +70,24 @@ class RawEmbeddingStreamer : public torch::jit::CustomClassHolder {
void stream(
const at::Tensor& indices,
const at::Tensor& weights,
std::optional<at::Tensor> identities,
const at::Tensor& count,
bool require_tensor_copy,
bool blocking_tensor_copy = true);

#ifdef FBGEMM_FBCODE
folly::coro::Task<void> tensor_stream(
const at::Tensor& indices,
const at::Tensor& weights);
const at::Tensor& weights,
std::optional<at::Tensor> identities);
/*
* Copy the indices, weights and count tensors and enqueue them for
* asynchronous stream.
*/
void copy_and_enqueue_stream_tensors(
const at::Tensor& indices,
const at::Tensor& weights,
std::optional<at::Tensor> identities,
const at::Tensor& count);

/*
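On the Python side, the extended C++ signature surfaces through the TorchScript custom class; a sketch of a direct call with the new optional identities argument (constructor arguments mirror the __init__ call above; shapes and values are illustrative assumptions):

    import torch

    # Assumes fbgemm_gpu is imported so the custom class is registered.
    streamer = torch.classes.fbgemm.RawEmbeddingStreamer(
        "uuid-0",     # uuid
        True,         # enable_raw_embedding_streaming
        1,            # res_store_shards
        0,            # res_server_port
        ["table_a"],  # table_names
        [0],          # table_offsets
        [0, 100],     # table_sizes
    )
    indices = torch.tensor([0, 3, 7])
    weights = torch.zeros(3, 128)
    identities = torch.tensor([11, 22, 33])  # may be None when absent
    count = torch.tensor([3])
    streamer.stream(indices, weights, identities, count, False, False)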