diff --git a/tripy/nvtripy/frontend/ops/slice.py b/tripy/nvtripy/frontend/ops/slice.py index df8353f3e..010e04dee 100644 --- a/tripy/nvtripy/frontend/ops/slice.py +++ b/tripy/nvtripy/frontend/ops/slice.py @@ -208,11 +208,19 @@ def get_min(a, b): else minimum(cast_to_dim_size(a), cast_to_dim_size(b)) ) + def get_max(a, b): + return ( + max(a, b) + if isinstance(a, int) and isinstance(b, int) + else maximum(cast_to_dim_size(a), cast_to_dim_size(b)) + ) + if slice_idx.start is not None: start = to_positive_idx(slice_idx.start) # If `start` is past the end, clamp it - if we're going backwards, we need to clamp it to a valid value; # otherwise, we can clamp it out of bounds (which will yield an empty tensor): start = get_min(start, select(step >= 0, dim_size, dim_size - 1)) + start = get_max(start, 0) # if the adjusted start is still negative, clamp it to 0 else: start = default_start diff --git a/tripy/tests/integration/test_slice.py b/tripy/tests/integration/test_slice.py index f603fb3c4..ab875b8e6 100644 --- a/tripy/tests/integration/test_slice.py +++ b/tripy/tests/integration/test_slice.py @@ -69,6 +69,7 @@ class TestSliceOp: ((2, 3, 4), lambda t: t[None]), ((2, 3, 4), lambda t: t[None, 1:2, :, None]), ((2, 3, 4), lambda t: t[..., None, 0, None]), + ((1, 2), lambda t: t[:, -4:]), # negative start greater than dimension size ], ) def test_slice(self, use_constant, shape, slice_func, eager_or_compiled):