Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions tripy/nvtripy/frontend/ops/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,11 +208,19 @@ def get_min(a, b):
else minimum(cast_to_dim_size(a), cast_to_dim_size(b))
)

def get_max(a, b):
return (
max(a, b)
if isinstance(a, int) and isinstance(b, int)
else maximum(cast_to_dim_size(a), cast_to_dim_size(b))
)

if slice_idx.start is not None:
start = to_positive_idx(slice_idx.start)
# If `start` is past the end, clamp it - if we're going backwards, we need to clamp it to a valid value;
# otherwise, we can clamp it out of bounds (which will yield an empty tensor):
start = get_min(start, select(step >= 0, dim_size, dim_size - 1))
start = get_max(start, 0) # if the start is negative and greater than the dimension size, clamp it to 0
else:
start = default_start

Expand Down