cache offsets / lengths on updates #3135

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 1 commit into base: main
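
Rough summary of the change (inferred from the diff below, not an official PR description): the UVM-caching KJT pool lookups used to rebuild their interleaved jagged lengths and offsets on every lookup; after this change, update() caches self._jagged_lengths / self._jagged_offsets via _infer_jagged_lengths_inclusive_offsets(), and lookup() only gathers the cached lengths for the requested ids. A toy, standalone illustration of the interleaved [used, padding] length layout that explains the 2 * ids indexing (bit_dim, key_lengths, and ids below are made-up values, not taken from the PR):

import torch

# Each pooled row occupies a fixed number of slots; the jagged lengths interleave
# [elements actually used, trailing padding] per row, so row i's data segment sits
# at index 2 * i. This mirrors the lengths the removed lookup code rebuilt on the fly.
bit_dim = 8  # assumed fixed row width (stand-in for self._bit_dims_t)
key_lengths = torch.tensor([[1, 2], [2, 1], [1, 3], [2, 4]])  # per-id, per-feature lengths
lengths_sum = key_lengths.sum(dim=1)            # tensor([3, 3, 4, 6])
padded = bit_dim - lengths_sum                  # tensor([5, 5, 4, 2])
jagged_lengths = torch.stack([lengths_sum, padded], dim=1).flatten()
# jagged_lengths == tensor([3, 5, 3, 5, 4, 4, 6, 2]); even slots = data, odd slots = padding
ids = torch.tensor([1, 3])
print(jagged_lengths[2 * ids])  # tensor([3, 6]) -> cached lengths for the requested rows
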
78 changes: 26 additions & 52 deletions torchrec/modules/object_pool_lookups.py
@@ -149,9 +149,11 @@ def forward(self, ids: torch.Tensor) -> JaggedTensor:
         """
         return self.lookup(ids)
 
-    @abc.abstractmethod
     def states_to_register(self) -> Iterator[Tuple[str, torch.Tensor]]:
-        pass
+        yield "values", self._tbe_state
+        yield "key_lengths", self._key_lengths
+        if self._is_weighted:
+            yield "weights", self._tbe_weights_state
 
 
 class TensorJaggedIndexSelectLookup(KeyedJaggedTensorPoolLookup):
@@ -249,20 +251,21 @@ def lookup(self, ids: torch.Tensor) -> JaggedTensor:
             )
 
         return JaggedTensor(
-            values=values, weights=weights, lengths=key_lengths_for_ids.flatten()
+            values=values,
+            weights=weights,
+            lengths=key_lengths_for_ids.flatten(),
         )
 
     def update(self, ids: torch.Tensor, values: JaggedTensor) -> None:
-
-        with record_function("## TensorPool update ##"):
+        with record_function("## KJTPool update ##"):
             key_lengths = (
                 # pyre-ignore
                 values.lengths()
                 .view(-1, len(self._feature_max_lengths))
                 .sum(axis=1)
             )
             key_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(key_lengths)
 
             padded_values = torch.ops.fbgemm.jagged_to_padded_dense(
                 values.values(),
                 [key_offsets],
@@ -290,12 +293,6 @@ def update(self, ids: torch.Tensor, values: JaggedTensor) -> None:
             self._jagged_lengths = lengths
             self._jagged_offsets = offsets
-
-    def states_to_register(self) -> Iterator[Tuple[str, torch.Tensor]]:
-        yield "values", self._values
-        yield "key_lengths", self._key_lengths
-        if self._is_weighted:
-            yield "weights", self._weights
 
 
 class UVMCachingInt64Lookup(KeyedJaggedTensorPoolLookup):
     def __init__(
@@ -378,29 +375,18 @@ def lookup(self, ids: torch.Tensor) -> JaggedTensor:
         kjt_dense_values = output_int_upper | output_int_lower
 
         key_lengths_for_ids = self._key_lengths[ids]
-        lengths_sum = key_lengths_for_ids.sum(dim=1)
-
-        padded_lengths = self._bit_dims_t - lengths_sum
-        # TODO: pre-compute this on class init
-        jagged_lengths = torch.stack(
-            [
-                lengths_sum,
-                padded_lengths,
-            ],
-            dim=1,
-        ).flatten()
-
-        lookup_indices = torch.arange(0, ids.shape[0] * 2, 2, device=self._device)
-        output_lengths = jagged_lengths[lookup_indices]
+        lookup_indices = 2 * ids
+        lengths = self._jagged_lengths[lookup_indices]
+        offsets = torch.ops.fbgemm.asynchronous_inclusive_cumsum(lengths)
         values = jagged_index_select_with_empty(
             kjt_dense_values.flatten().unsqueeze(-1),
             lookup_indices,
-            torch.ops.fbgemm.asynchronous_inclusive_cumsum(jagged_lengths),
-            torch.ops.fbgemm.asynchronous_inclusive_cumsum(output_lengths),
+            self._jagged_offsets,
+            offsets,
         )
 
         return JaggedTensor(
-            values=values.flatten(),
+            values=values,
             lengths=key_lengths_for_ids.flatten(),
         )
 
@@ -435,10 +421,9 @@ def update(self, ids: torch.Tensor, values: JaggedTensor) -> None:
             .to(self._key_lengths.dtype)
         )
 
-    def states_to_register(self) -> Iterator[Tuple[str, torch.Tensor]]:
-        yield "values_upper_and_lower_bits", self._tbe_state
-        if self._is_weighted:
-            yield "weights", self._tbe_weights_state
+        lengths, offsets = self._infer_jagged_lengths_inclusive_offsets()
+        self._jagged_lengths = lengths
+        self._jagged_offsets = offsets
 
 
 class UVMCachingInt32Lookup(KeyedJaggedTensorPoolLookup):
@@ -520,24 +505,14 @@ def lookup(self, ids: torch.Tensor) -> JaggedTensor:
         kjt_dense_values = output.view(torch.int32)
 
         key_lengths_for_ids = self._key_lengths[ids]
-        lengths_sum = key_lengths_for_ids.sum(dim=1)
-
-        padded_lengths = self._bit_dims_t - lengths_sum
-        jagged_lengths = torch.stack(
-            [
-                lengths_sum,
-                padded_lengths,
-            ],
-            dim=1,
-        ).flatten()
-
-        lookup_ids = 2 * torch.arange(ids.shape[0], device=self._device)
-        output_lengths = jagged_lengths[lookup_ids]
+        lookup_indices = 2 * ids
+        lengths = self._jagged_lengths[lookup_indices]
+        offsets = torch.ops.fbgemm.asynchronous_inclusive_cumsum(lengths)
         values = jagged_index_select_with_empty(
             kjt_dense_values.flatten().unsqueeze(-1),
-            lookup_ids,
-            torch.ops.fbgemm.asynchronous_inclusive_cumsum(jagged_lengths),
-            torch.ops.fbgemm.asynchronous_inclusive_cumsum(output_lengths),
+            lookup_indices,
+            self._jagged_offsets,
+            offsets
         )
 
         return JaggedTensor(
@@ -569,10 +544,9 @@ def update(self, ids: torch.Tensor, values: JaggedTensor) -> None:
             .to(self._key_lengths.dtype)
         )
 
-    def states_to_register(self) -> Iterator[Tuple[str, torch.Tensor]]:
-        yield "values", self._tbe_state
-        if self._is_weighted:
-            yield "weights", self._tbe_weights_state
+        lengths, offsets = self._infer_jagged_lengths_inclusive_offsets()
+        self._jagged_lengths = lengths
+        self._jagged_offsets = offsets
 
 
 class TensorPoolLookup(abc.ABC, torch.nn.Module):
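
Side note (added context, not part of the diff): of the two FBGEMM cumsum ops used above, asynchronous_complete_cumsum returns length N+1 offsets with a leading zero, while asynchronous_inclusive_cumsum returns a length N running total, which is the form jagged_index_select_with_empty is given here. A rough plain-PyTorch sketch of the difference, assuming a small example lengths tensor:

import torch

lengths = torch.tensor([3, 5, 4, 4])

# Approximate plain-torch equivalents of the FBGEMM ops referenced in the diff above.
complete_offsets = torch.cat([torch.zeros(1, dtype=lengths.dtype), lengths.cumsum(0)])
inclusive_offsets = lengths.cumsum(0)

print(complete_offsets.tolist())   # [0, 3, 8, 12, 16]
print(inclusive_offsets.tolist())  # [3, 8, 12, 16]
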
58 changes: 50 additions & 8 deletions torchrec/modules/tests/test_kjt_pool_lookup.py
@@ -9,20 +9,23 @@
 import unittest
 
 import torch
-from torchrec.modules.object_pool_lookups import UVMCachingInt64Lookup
+from torchrec.modules.object_pool_lookups import (
+    UVMCachingInt32Lookup,
+    UVMCachingInt64Lookup,
+)
 from torchrec.sparse.jagged_tensor import JaggedTensor
 
 
 class KeyedJaggedTensorPoolLookupTest(unittest.TestCase):
     # pyre-fixme[56]: Pyre was not able to infer the type of argument
     @unittest.skipIf(
-        torch.cuda.device_count() <= 1,
-        "Not enough GPUs, this test requires at least two GPUs",
+        not torch.cuda.is_available(),
+        "This test requires a GPU to run",
     )
     def test_uvm_caching_int64_lookup(
         self,
     ) -> None:
-        device = torch.device("cuda:0")
+        device = torch.device("cuda")
 
         pool_size, feature_max_lengths = 4, {"f1": 2, "f2": 4}
         lookup = UVMCachingInt64Lookup(
@@ -43,10 +46,7 @@ def test_uvm_caching_int64_lookup(
 
         lookup.update(
             ids=ids,
-            values=JaggedTensor(
-                jt_values,
-                lengths=jt_lengths,
-            ),
+            values=JaggedTensor(jt_values, lengths=jt_lengths),
         )
 
         torch.testing.assert_close(lookup.lookup(ids).values(), jt_values)
@@ -63,3 +63,45 @@ def test_uvm_caching_int64_lookup(
         torch.testing.assert_close(
             lookup.lookup(ids).values(), jt_values + INT64_VALUE_SHIFT
         )
+
+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "This test requires a GPU to run",
+    )
+    def test_uvm_caching_int32_lookup(
+        self,
+    ) -> None:
+        device = torch.device("cuda")
+
+        pool_size, feature_max_lengths = 4, {"f1": 2, "f2": 4}
+        lookup = UVMCachingInt32Lookup(
+            pool_size=pool_size,
+            feature_max_lengths=feature_max_lengths,
+            is_weighted=False,
+            device=device,
+        )
+        ids = torch.tensor([0, 1, 2, 3], device=device)
+        jt_values = torch.tensor(
+            [1, 3, 3, 2, 2, 4, 11, 13, 13, 13, 12, 12, 14, 14, 14, 14],
+            dtype=torch.int32,
+            device=device,
+        )
+        jt_lengths = torch.tensor(
+            [1, 2, 2, 1, 1, 3, 2, 4], dtype=torch.int, device=device
+        )
+
+        lookup.update(
+            ids=ids,
+            values=JaggedTensor(jt_values, lengths=jt_lengths),
+        )
+
+        torch.testing.assert_close(lookup.lookup(ids).values(), jt_values)
+
+        lookup.update(
+            ids=ids,
+            values=JaggedTensor(jt_values, lengths=jt_lengths),
+        )
+
+        torch.testing.assert_close(lookup.lookup(ids).values(), jt_values)