@@ -70,7 +70,8 @@ def empty_buffer():
7070
7171
7272# Try to include correct column offsets for non-tensor arguments.
73- def _add_column_info (arg , arg_index , is_kwarg , num_positional , func_name , skip_num_stack_entries , arg_names ):
73+ def _add_column_info (arg , arg_index , is_kwarg , num_positional , func_name , arg_names ):
74+ from tripy import function_registry
7475 from tripy .frontend .tensor import Tensor
7576
7677 assert isinstance (arg , Tensor ), f"This function should only be called for objects that are already Tensor instances"
@@ -90,10 +91,10 @@ def _add_column_info(arg, arg_index, is_kwarg, num_positional, func_name, skip_n
9091 # Find the first caller of this function that is NOT the function registry.
9192 # Also save the last dispatch target we see.
9293 dispatch_target = None
93- for idx , source_info in enumerate (arg .stack_info [WRAPPER_STACK_DEPTH + skip_num_stack_entries :]):
94+ for idx , source_info in enumerate (arg .stack_info [WRAPPER_STACK_DEPTH :]):
9495 dispatch_target = source_info ._dispatch_target or dispatch_target
9596 if source_info .module not in utils .get_module_names_to_exclude_from_stack_info ():
96- frame_index = idx + WRAPPER_STACK_DEPTH + skip_num_stack_entries
97+ frame_index = idx + WRAPPER_STACK_DEPTH
9798 break
9899 else :
99100 # Fallback path is just to look at the user code
@@ -118,12 +119,6 @@ def _add_column_info(arg, arg_index, is_kwarg, num_positional, func_name, skip_n
118119 arg_index = 0 if arg_index == 1 else 1
119120 dispatch_target = dispatch_target .replace ("__r" , "__" )
120121
121- # Special case for __getitem__: It is variadic. Argument 0 is the tensor,
122- # and all subsequent arguments are slice parameters (in start, stop, step order).
123- # Hence, we subtract one to get the index of the slice parameters
124- if dispatch_target == "__getitem__" :
125- arg_index -= 1
126-
127122 candidates = utils .get_arg_candidate_column_offsets (
128123 source_info .code , arg_index , num_positional , dispatch_target or func_name , is_kwarg , arg_names
129124 )
@@ -136,9 +131,7 @@ def _add_column_info(arg, arg_index, is_kwarg, num_positional, func_name, skip_n
136131
137132# NOTE: Conversion to tensors needs to be done via a decorator so that we can add stack information
138133# for non-tensors. Without having full context of the function signature, it is otherwise difficult to do so.
139- def convert_to_tensors (
140- targets : Set [str ] = None , skip_num_stack_entries : int = 0 , preprocess_args : Optional [Callable ] = None
141- ):
134+ def convert_to_tensors (targets : Set [str ] = None , preprocess_args : Optional [Callable ] = None ):
142135 """
143136 Decorator that converts specified arguments to Tensors or DimensionSizes.
144137 If the argument can be converted to a DimensionSize, it is. Otherwise, it is
@@ -152,17 +145,6 @@ def convert_to_tensors(
152145 targets: Names of arguments to convert to tensors. If not supplied, any arguments annotated
153146 with `TensorLike` or `ShapeLike` are converted.
154147
155- skip_num_stack_entries: If the decorator is used on a function that is *called by*
156- a function that the user invokes, it will be necessary to skip stack entries
157- in order to get the column info from the user code. The number of entries skipped
158- should be equal to the nesting depth from a function called by user code
159- (if the decorated function is called by the user the depth is 0;
160- if the decorated function is called from a user function, the depth is 1; etc.).
161-
162- NOTE: When using this, make sure any extra arguments to the decorated function are
163- passed as keyword arguments. Otherwise, the logic for determining column information
164- will break.
165-
166148 preprocess_args: A callback used to preprocess arguments before potential conversion. If provided,
167149 this is always called, regardless of whether the decorator actually needed to perform conversion.
168150 This will be called with all arguments that were passed to the decorated function and should
@@ -242,7 +224,6 @@ def add_arg(arg):
242224 name in kwargs ,
243225 len (args ),
244226 func .__name__ ,
245- skip_num_stack_entries ,
246227 [name for name , _ in all_args ],
247228 )
248229
0 commit comments