Skip to content

Commit 54efd4d

Browse files
authored
#675: Fix OOB negative slice bug (#676)
Fixes #675
1 parent b9173ce commit 54efd4d

File tree

2 files changed

+9
-0
lines changed

2 files changed

+9
-0
lines changed

tripy/nvtripy/frontend/ops/slice.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,11 +208,19 @@ def get_min(a, b):
208208
else minimum(cast_to_dim_size(a), cast_to_dim_size(b))
209209
)
210210

211+
def get_max(a, b):
212+
return (
213+
max(a, b)
214+
if isinstance(a, int) and isinstance(b, int)
215+
else maximum(cast_to_dim_size(a), cast_to_dim_size(b))
216+
)
217+
211218
if slice_idx.start is not None:
212219
start = to_positive_idx(slice_idx.start)
213220
# If `start` is past the end, clamp it - if we're going backwards, we need to clamp it to a valid value;
214221
# otherwise, we can clamp it out of bounds (which will yield an empty tensor):
215222
start = get_min(start, select(step >= 0, dim_size, dim_size - 1))
223+
start = get_max(start, 0) # if the adjusted start is still negative, clamp it to 0
216224
else:
217225
start = default_start
218226

tripy/tests/integration/test_slice.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ class TestSliceOp:
6969
((2, 3, 4), lambda t: t[None]),
7070
((2, 3, 4), lambda t: t[None, 1:2, :, None]),
7171
((2, 3, 4), lambda t: t[..., None, 0, None]),
72+
((1, 2), lambda t: t[:, -4:]), # negative start greater than dimension size
7273
],
7374
)
7475
def test_slice(self, use_constant, shape, slice_func, eager_or_compiled):

0 commit comments

Comments
 (0)