From 41354fb8e1e0037d624b9c3fd0e315206600a4d4 Mon Sep 17 00:00:00 2001
From: Matthew Murphy <15095986+murphymatt@users.noreply.github.com>
Date: Tue, 24 Jun 2025 16:20:10 +0000
Subject: [PATCH] cache offsets / lengths on updates

Cache the pool's jagged lengths and inclusive offsets when update() runs
instead of recomputing them on every lookup(). Also consolidate the
duplicated per-class states_to_register() implementations into the
KeyedJaggedTensorPoolLookup base class, and add a UVMCachingInt32Lookup
unit test.

---
 torchrec/modules/object_pool_lookups.py       | 78 +++++++------------
 .../modules/tests/test_kjt_pool_lookup.py     | 58 ++++++++++++--
 2 files changed, 76 insertions(+), 60 deletions(-)

diff --git a/torchrec/modules/object_pool_lookups.py b/torchrec/modules/object_pool_lookups.py
index b30358f19..5b3e284e0 100644
--- a/torchrec/modules/object_pool_lookups.py
+++ b/torchrec/modules/object_pool_lookups.py
@@ -149,9 +149,11 @@ def forward(self, ids: torch.Tensor) -> JaggedTensor:
         """
         return self.lookup(ids)
 
-    @abc.abstractmethod
     def states_to_register(self) -> Iterator[Tuple[str, torch.Tensor]]:
-        pass
+        yield "values", self._tbe_state
+        yield "key_lengths", self._key_lengths
+        if self._is_weighted:
+            yield "weights", self._tbe_weights_state
 
 
 class TensorJaggedIndexSelectLookup(KeyedJaggedTensorPoolLookup):
@@ -249,12 +251,14 @@ def lookup(self, ids: torch.Tensor) -> JaggedTensor:
 
             )
 
         return JaggedTensor(
-            values=values, weights=weights, lengths=key_lengths_for_ids.flatten()
+            values=values,
+            weights=weights,
+            lengths=key_lengths_for_ids.flatten(),
         )
 
     def update(self, ids: torch.Tensor, values: JaggedTensor) -> None:
-        with record_function("## TensorPool update ##"):
+        with record_function("## KJTPool update ##"):
             key_lengths = (
                 # pyre-ignore
                 values.lengths()
@@ -262,7 +266,6 @@ def update(self, ids: torch.Tensor, values: JaggedTensor) -> None:
                 .sum(axis=1)
             )
             key_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(key_lengths)
-
             padded_values = torch.ops.fbgemm.jagged_to_padded_dense(
                 values.values(),
                 [key_offsets],
@@ -290,12 +293,6 @@ def update(self, ids: torch.Tensor, values: JaggedTensor) -> None:
             self._jagged_lengths = lengths
             self._jagged_offsets = offsets
 
-    def states_to_register(self) -> Iterator[Tuple[str, torch.Tensor]]:
-        yield "values", self._values
-        yield "key_lengths", self._key_lengths
-        if self._is_weighted:
-            yield "weights", self._weights
-
 
 class UVMCachingInt64Lookup(KeyedJaggedTensorPoolLookup):
     def __init__(
@@ -378,29 +375,18 @@ def lookup(self, ids: torch.Tensor) -> JaggedTensor:
         kjt_dense_values = output_int_upper | output_int_lower
 
         key_lengths_for_ids = self._key_lengths[ids]
-        lengths_sum = key_lengths_for_ids.sum(dim=1)
-
-        padded_lengths = self._bit_dims_t - lengths_sum
-        # TODO: pre-compute this on class init
-        jagged_lengths = torch.stack(
-            [
-                lengths_sum,
-                padded_lengths,
-            ],
-            dim=1,
-        ).flatten()
-
-        lookup_indices = torch.arange(0, ids.shape[0] * 2, 2, device=self._device)
-        output_lengths = jagged_lengths[lookup_indices]
+        lookup_indices = 2 * ids
+        lengths = self._jagged_lengths[lookup_indices]
+        offsets = torch.ops.fbgemm.asynchronous_inclusive_cumsum(lengths)
         values = jagged_index_select_with_empty(
             kjt_dense_values.flatten().unsqueeze(-1),
             lookup_indices,
-            torch.ops.fbgemm.asynchronous_inclusive_cumsum(jagged_lengths),
-            torch.ops.fbgemm.asynchronous_inclusive_cumsum(output_lengths),
+            self._jagged_offsets,
+            offsets,
         )
 
         return JaggedTensor(
-            values=values.flatten(),
+            values=values,
             lengths=key_lengths_for_ids.flatten(),
         )
 
@@ -435,10 +421,9 @@ def update(self, ids: torch.Tensor, values: JaggedTensor) -> None:
             .to(self._key_lengths.dtype)
         )
 
-    def states_to_register(self) -> Iterator[Tuple[str, torch.Tensor]]:
-        yield "values_upper_and_lower_bits", self._tbe_state
-        if self._is_weighted:
-            yield "weights", self._tbe_weights_state
+        lengths, offsets = self._infer_jagged_lengths_inclusive_offsets()
+        self._jagged_lengths = lengths
+        self._jagged_offsets = offsets
 
 
 class UVMCachingInt32Lookup(KeyedJaggedTensorPoolLookup):
@@ -520,24 +505,14 @@ def lookup(self, ids: torch.Tensor) -> JaggedTensor:
         kjt_dense_values = output.view(torch.int32)
 
         key_lengths_for_ids = self._key_lengths[ids]
-        lengths_sum = key_lengths_for_ids.sum(dim=1)
-
-        padded_lengths = self._bit_dims_t - lengths_sum
-        jagged_lengths = torch.stack(
-            [
-                lengths_sum,
-                padded_lengths,
-            ],
-            dim=1,
-        ).flatten()
-
-        lookup_ids = 2 * torch.arange(ids.shape[0], device=self._device)
-        output_lengths = jagged_lengths[lookup_ids]
+        lookup_indices = 2 * ids
+        lengths = self._jagged_lengths[lookup_indices]
+        offsets = torch.ops.fbgemm.asynchronous_inclusive_cumsum(lengths)
         values = jagged_index_select_with_empty(
             kjt_dense_values.flatten().unsqueeze(-1),
-            lookup_ids,
-            torch.ops.fbgemm.asynchronous_inclusive_cumsum(jagged_lengths),
-            torch.ops.fbgemm.asynchronous_inclusive_cumsum(output_lengths),
+            lookup_indices,
+            self._jagged_offsets,
+            offsets,
         )
 
         return JaggedTensor(
@@ -569,10 +544,9 @@ def update(self, ids: torch.Tensor, values: JaggedTensor) -> None:
             .to(self._key_lengths.dtype)
         )
 
-    def states_to_register(self) -> Iterator[Tuple[str, torch.Tensor]]:
-        yield "values", self._tbe_state
-        if self._is_weighted:
-            yield "weights", self._tbe_weights_state
+        lengths, offsets = self._infer_jagged_lengths_inclusive_offsets()
+        self._jagged_lengths = lengths
+        self._jagged_offsets = offsets
 
 
 class TensorPoolLookup(abc.ABC, torch.nn.Module):
diff --git a/torchrec/modules/tests/test_kjt_pool_lookup.py b/torchrec/modules/tests/test_kjt_pool_lookup.py
index 3b9617f3e..61fe31653 100644
--- a/torchrec/modules/tests/test_kjt_pool_lookup.py
+++ b/torchrec/modules/tests/test_kjt_pool_lookup.py
@@ -9,20 +9,23 @@
 import unittest
 
 import torch
-from torchrec.modules.object_pool_lookups import UVMCachingInt64Lookup
+from torchrec.modules.object_pool_lookups import (
+    UVMCachingInt32Lookup,
+    UVMCachingInt64Lookup,
+)
 from torchrec.sparse.jagged_tensor import JaggedTensor
 
 
 class KeyedJaggedTensorPoolLookupTest(unittest.TestCase):
     # pyre-fixme[56]: Pyre was not able to infer the type of argument
     @unittest.skipIf(
-        torch.cuda.device_count() <= 1,
-        "Not enough GPUs, this test requires at least two GPUs",
+        not torch.cuda.is_available(),
+        "This test requires a GPU to run",
     )
     def test_uvm_caching_int64_lookup(
         self,
     ) -> None:
-        device = torch.device("cuda:0")
+        device = torch.device("cuda")
 
         pool_size, feature_max_lengths = 4, {"f1": 2, "f2": 4}
         lookup = UVMCachingInt64Lookup(
@@ -43,10 +46,7 @@ def test_uvm_caching_int64_lookup(
 
         lookup.update(
             ids=ids,
-            values=JaggedTensor(
-                jt_values,
-                lengths=jt_lengths,
-            ),
+            values=JaggedTensor(jt_values, lengths=jt_lengths),
         )
 
         torch.testing.assert_close(lookup.lookup(ids).values(), jt_values)
@@ -63,3 +63,45 @@ def test_uvm_caching_int64_lookup(
         torch.testing.assert_close(
             lookup.lookup(ids).values(), jt_values + INT64_VALUE_SHIFT
         )
+
+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "This test requires a GPU to run",
+    )
+    def test_uvm_caching_int32_lookup(
+        self,
+    ) -> None:
+        device = torch.device("cuda")
+
+        pool_size, feature_max_lengths = 4, {"f1": 2, "f2": 4}
+        lookup = UVMCachingInt32Lookup(
+            pool_size=pool_size,
+            feature_max_lengths=feature_max_lengths,
+            is_weighted=False,
+            device=device,
+        )
+        ids = torch.tensor([0, 1, 2, 3], device=device)
+        jt_values = torch.tensor(
+            [1, 3, 3, 2, 2, 4, 11, 13, 13, 13, 12, 12, 14, 14, 14, 14],
+            dtype=torch.int32,
+            device=device,
+        )
+        jt_lengths = torch.tensor(
+            [1, 2, 2, 1, 1, 3, 2, 4], dtype=torch.int, device=device
+        )
+
+        lookup.update(
+            ids=ids,
+            values=JaggedTensor(jt_values, lengths=jt_lengths),
+        )
+
+        torch.testing.assert_close(lookup.lookup(ids).values(), jt_values)
+
+        lookup.update(
+            ids=ids,
+            values=JaggedTensor(jt_values, lengths=jt_lengths),
+        )
+
+        torch.testing.assert_close(lookup.lookup(ids).values(), jt_values)
+
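Note on the cached state: both UVM caching update() paths now end by calling
self._infer_jagged_lengths_inclusive_offsets(), whose body lies outside the
hunks above. Below is a minimal, self-contained sketch of what that helper
presumably computes, reconstructed from the lookup() code this patch deletes;
the standalone function name and signature are illustrative, not the actual
torchrec API.

    from typing import Tuple

    import fbgemm_gpu  # noqa: F401  registers the torch.ops.fbgemm.* operators
    import torch


    def infer_jagged_lengths_inclusive_offsets(
        key_lengths: torch.Tensor,  # [pool_size, num_features] per-key lengths
        bit_dims_t: torch.Tensor,  # padded dense width of each pool row
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Each dense pool row holds lengths_sum valid elements followed by
        # padding, so the jagged view of the pool interleaves [valid, padding]
        # lengths, one pair per row.
        lengths_sum = key_lengths.sum(dim=1)
        padded_lengths = bit_dims_t - lengths_sum
        jagged_lengths = torch.stack(
            [lengths_sum, padded_lengths], dim=1
        ).flatten()
        # Caching this inclusive cumsum is the point of the patch: lookup()
        # can then select the valid segments with lookup_indices = 2 * ids
        # instead of rebuilding lengths and offsets on every call.
        offsets = torch.ops.fbgemm.asynchronous_inclusive_cumsum(jagged_lengths)
        return jagged_lengths, offsets

With the lengths/offsets pair cached at update() time, each lookup() only has
to cumsum the per-batch output lengths, which is what the new
offsets = torch.ops.fbgemm.asynchronous_inclusive_cumsum(lengths) line does.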