Skip to content

Commit b13211b

Browse files
jd7-trfacebook-github-bot
authored and committed
Update KJT stride calculation logic to be based off of inverse_indices for VBE KJTs. (#3119)
Summary: For VBE KJTs, update the `_maybe_compute_stride_kjt` logic to calculate stride based off of `inverse_indices` when it is set. Currently, the stride of a VBE KJT with `stride_per_key_per_rank` is calculated as the max "stride per key". This is different from the batch size of the EBC output KeyedTensor, which is based off of `inverse_indices`. This causes issues in IR module serialization. Reviewed By: TroyGarden Differential Revision: D76997485
1 parent 0836fe8 commit b13211b

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

torchrec/sparse/jagged_tensor.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1129,13 +1129,18 @@ def _maybe_compute_stride_kjt(
11291129
lengths: Optional[torch.Tensor],
11301130
offsets: Optional[torch.Tensor],
11311131
stride_per_key_per_rank: Optional[torch.IntTensor],
1132+
inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None,
11321133
) -> int:
11331134
if stride is None:
11341135
if len(keys) == 0:
11351136
stride = 0
11361137
elif (
11371138
stride_per_key_per_rank is not None and stride_per_key_per_rank.numel() > 0
11381139
):
1140+
# For VBE KJT, batch size should be based on inverse_indices when set.
1141+
if inverse_indices is not None:
1142+
return inverse_indices[1].shape[-1]
1143+
11391144
s = stride_per_key_per_rank.sum(dim=1).max().item()
11401145
if not torch.jit.is_scripting() and is_non_strict_exporting():
11411146
stride = torch.sym_int(s)
@@ -1882,7 +1887,9 @@ def from_lengths_sync(
18821887
lengths: torch.Tensor,
18831888
weights: Optional[torch.Tensor] = None,
18841889
stride: Optional[int] = None,
1885-
stride_per_key_per_rank: Optional[List[List[int]]] = None,
1890+
stride_per_key_per_rank: Optional[
1891+
Union[torch.IntTensor, List[List[int]]]
1892+
] = None,
18861893
inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None,
18871894
) -> "KeyedJaggedTensor":
18881895
"""
@@ -2180,6 +2187,7 @@ def stride(self) -> int:
21802187
self._lengths,
21812188
self._offsets,
21822189
self._stride_per_key_per_rank,
2190+
self._inverse_indices,
21832191
)
21842192
self._stride = stride
21852193
return stride

torchrec/sparse/tests/test_keyed_jagged_tensor.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,6 +1017,18 @@ def test_meta_device_compatibility(self) -> None:
10171017
lengths=torch.tensor([], device=torch.device("meta")),
10181018
)
10191019

1020+
def test_vbe_kjt_stride(self) -> None:
1021+
inverse_indices = torch.tensor([[0, 1, 0], [0, 0, 0]])
1022+
kjt = KeyedJaggedTensor(
1023+
keys=["f1", "f2", "f3"],
1024+
values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
1025+
lengths=torch.tensor([3, 3, 2]),
1026+
stride_per_key_per_rank=[[2], [1]],
1027+
inverse_indices=(["f1", "f2"], inverse_indices),
1028+
)
1029+
1030+
self.assertEqual(kjt.stride(), inverse_indices.shape[-1])
1031+
10201032

10211033
class TestKeyedJaggedTensorScripting(unittest.TestCase):
10221034
def test_scriptable_forward(self) -> None:

0 commit comments

Comments
 (0)