Skip to content

Commit c8ee99f

Browse files
authored
[Tripy] Eliminate need for skip_num_stack_entries argument in convert_to_tensors (#333)
Addresses issue #310. The only use of `skip_num_stack_entries` was for `slice_helper` and addressing this issue in a systematic manner would likely require building in many hacks and assumptions, so the approach here is just to manually override the stack information in that one function.
1 parent 7fc38c8 commit c8ee99f

File tree

3 files changed

+42
-31
lines changed

3 files changed

+42
-31
lines changed

tripy/tripy/frontend/trace/ops/slice.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -250,8 +250,37 @@ def clamp_bound(bound: Union[int, Tensor]) -> Union[int, Tensor]:
250250
return out
251251

252252

253-
# Because the helper is called inside another function, we need to skip one entry in the call stack to find
254-
# the original call to user code.
255-
@frontend_utils.convert_to_tensors(skip_num_stack_entries=1)
253+
@frontend_utils.convert_to_tensors()
256254
def slice_helper(tensor, *slice_params: TensorLike):
255+
from tripy.utils import get_arg_candidate_column_offsets
256+
257+
# The default behavior of convert_to_tensors will not add the correct column info to the slice params
258+
# because this call occurs *inside* the overridden call to __getitem__, so we adjust the column info manually.
259+
260+
# Look for the stack frame index to __getitem__. We need to go one stack frame beyond to get to the *user* call of __getitem__.
261+
# It will be the same for all the slice params
262+
frame_index = -1
263+
assert slice_params
264+
265+
for idx, source_info in enumerate(slice_params[0].stack_info):
266+
if source_info._dispatch_target == "__getitem__":
267+
frame_index = idx + 1
268+
break
269+
270+
# convert_to_tensors should have taken care of this for us
271+
assert frame_index >= 0, "No call to the __getitem__ dispatch found"
272+
273+
arg_names = ["tensor"] + ["slice_params"] * len(slice_params)
274+
for arg_index, arg in enumerate(slice_params):
275+
source_info = arg.stack_info[frame_index]
276+
277+
# Note: arg_index does not account for the positional arg, hence we add 1 for the index argument
278+
candidates = get_arg_candidate_column_offsets(
279+
source_info.code, 1 + arg_index, 1, "__getitem__", False, arg_names
280+
)
281+
282+
# Now we can set the column range correctly
283+
if len(candidates) == 1:
284+
source_info.column_range = candidates[0]
285+
257286
return Slice.build(inputs=[tensor, *slice_params])

tripy/tripy/frontend/utils.py

Lines changed: 5 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ def empty_buffer():
7070

7171

7272
# Try to include correct column offsets for non-tensor arguments.
73-
def _add_column_info(arg, arg_index, is_kwarg, num_positional, func_name, skip_num_stack_entries, arg_names):
73+
def _add_column_info(arg, arg_index, is_kwarg, num_positional, func_name, arg_names):
74+
from tripy import function_registry
7475
from tripy.frontend.tensor import Tensor
7576

7677
assert isinstance(arg, Tensor), f"This function should only be called for objects that are already Tensor instances"
@@ -90,10 +91,10 @@ def _add_column_info(arg, arg_index, is_kwarg, num_positional, func_name, skip_n
9091
# Find the first caller of this function that is NOT the function registry.
9192
# Also save the last dispatch target we see.
9293
dispatch_target = None
93-
for idx, source_info in enumerate(arg.stack_info[WRAPPER_STACK_DEPTH + skip_num_stack_entries :]):
94+
for idx, source_info in enumerate(arg.stack_info[WRAPPER_STACK_DEPTH:]):
9495
dispatch_target = source_info._dispatch_target or dispatch_target
9596
if source_info.module not in utils.get_module_names_to_exclude_from_stack_info():
96-
frame_index = idx + WRAPPER_STACK_DEPTH + skip_num_stack_entries
97+
frame_index = idx + WRAPPER_STACK_DEPTH
9798
break
9899
else:
99100
# Fallback path is just to look at the user code
@@ -118,12 +119,6 @@ def _add_column_info(arg, arg_index, is_kwarg, num_positional, func_name, skip_n
118119
arg_index = 0 if arg_index == 1 else 1
119120
dispatch_target = dispatch_target.replace("__r", "__")
120121

121-
# Special case for __getitem__: It is variadic. Argument 0 is the tensor,
122-
# and all subsequent arguments are slice parameters (in start, stop, step order).
123-
# Hence, we subtract one to get the index of the slice parameters
124-
if dispatch_target == "__getitem__":
125-
arg_index -= 1
126-
127122
candidates = utils.get_arg_candidate_column_offsets(
128123
source_info.code, arg_index, num_positional, dispatch_target or func_name, is_kwarg, arg_names
129124
)
@@ -136,9 +131,7 @@ def _add_column_info(arg, arg_index, is_kwarg, num_positional, func_name, skip_n
136131

137132
# NOTE: Conversion to tensors needs to be done via a decorator so that we can add stack information
138133
# for non-tensors. Without having full context of the function signature, it is otherwise difficult to do so.
139-
def convert_to_tensors(
140-
targets: Set[str] = None, skip_num_stack_entries: int = 0, preprocess_args: Optional[Callable] = None
141-
):
134+
def convert_to_tensors(targets: Set[str] = None, preprocess_args: Optional[Callable] = None):
142135
"""
143136
Decorator that converts specified arguments to Tensors or DimensionSizes.
144137
If the argument can be converted to a DimensionSize, it is. Otherwise, it is
@@ -152,17 +145,6 @@ def convert_to_tensors(
152145
targets: Names of arguments to convert to tensors. If not supplied, any arguments annotated
153146
with `TensorLike` or `ShapeLike` are converted.
154147
155-
skip_num_stack_entries: If the decorator is used on a function that is *called by*
156-
a function that the user invokes, it will be necessary to skip stack entries
157-
in order to get the column info from the user code. The number of entries skipped
158-
should be equal to the nesting depth from a function called by user code
159-
(if the decorated function is called by the user the depth is 0;
160-
if the decorated function is called from a user function, the depth is 1; etc.).
161-
162-
NOTE: When using this, make sure any extra arguments to the decorated function are
163-
passed as keyword arguments. Otherwise, the logic for determining column information
164-
will break.
165-
166148
preprocess_args: A callback used to preprocess arguments before potential conversion. If provided,
167149
this is always called, regardless of whether the decorator actually needed to perform conversion.
168150
This will be called with all arguments that were passed to the decorated function and should
@@ -242,7 +224,6 @@ def add_arg(arg):
242224
name in kwargs,
243225
len(args),
244226
func.__name__,
245-
skip_num_stack_entries,
246227
[name for name, _ in all_args],
247228
)
248229

tripy/tripy/utils/ast.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,12 +139,13 @@ def index_into_expr(node: ast.expr, index: int) -> ast.expr:
139139
return node
140140

141141
# If we have multiple dimensions specified, then we have a tuple of slices.
142-
# Indices are given in as a list of start, stop, step
142+
# NOTE: We subtract num_positional from the index because the slice arguments would
143+
# be passed as *variadic arguments* to slice_helper and so would come after the positional argument
143144
if isinstance(node.slice, ast.Tuple):
144-
element = node.slice.elts[index // 3]
145-
arg_node = index_into_expr(element, index % 3)
145+
element = node.slice.elts[(index - num_positional) // 3]
146+
arg_node = index_into_expr(element, (index - num_positional) % 3)
146147
else:
147-
arg_node = index_into_expr(node.slice, index)
148+
arg_node = index_into_expr(node.slice, (index - num_positional))
148149

149150
if arg_node is not None:
150151
candidates.append((indentation + arg_node.col_offset, indentation + arg_node.end_col_offset))

0 commit comments

Comments
 (0)