Skip to content

Commit a283b13

Browse files
authored
[Tripy] Handle variadic arguments in preprocess_args for the convert_to_tensors decorator (#329)
Addresses issue #311.
1 parent 3b7a2af commit a283b13

File tree

5 files changed

+40
-17
lines changed

5 files changed

+40
-17
lines changed

tripy/tests/frontend/test_utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,3 +178,17 @@ def func(a: tp.types.TensorLike, b: tp.types.TensorLike):
178178
a, b = func(1, 2)
179179

180180
assert b.tolist() == 3
181+
182+
def test_variadic_args(self):
183+
184+
def increment(a, *args):
185+
return {"a": a + 1, "args": list(map(lambda arg: arg + 1, args))}
186+
187+
@convert_to_tensors(preprocess_args=increment)
188+
def func(a: tp.Tensor, *args):
189+
return [a] + list(args)
190+
191+
a, b, c = func(tp.Tensor(1), tp.Tensor(2), tp.Tensor(3))
192+
assert a.tolist() == 2
193+
assert b.tolist() == 3
194+
assert c.tolist() == 4

tripy/tripy/backend/api/compile.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,8 @@ def process_arg(name, arg):
151151
return arg
152152

153153
new_args = []
154-
for name, arg in utils.get_positional_arg_names(func, *args):
154+
positional_arg_info, _ = utils.get_positional_arg_names(func, *args)
155+
for name, arg in positional_arg_info:
155156
new_args.append(process_arg(name, arg))
156157

157158
new_kwargs = {}

tripy/tripy/constraints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def wrapper(*args, **kwargs):
105105
from tripy.common.datatype import dtype
106106
from tripy.frontend.tensor import Tensor
107107

108-
merged_args = utils.merge_function_arguments(func, *args, **kwargs)
108+
merged_args, _ = utils.merge_function_arguments(func, *args, **kwargs)
109109

110110
# The first arguments seen for each type variable. Other arguments with the same variable
111111
# must use the same data types.

tripy/tripy/frontend/utils.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import functools
1919
import inspect
2020
from collections import deque
21-
from typing import Callable, List, Optional, Set, Union
21+
from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
2222

2323
from tripy import utils
2424
from tripy.common.exception import raise_error
@@ -166,7 +166,8 @@ def convert_to_tensors(
166166
preprocess_args: A callback used to preprocess arguments before potential conversion. If provided,
167167
this is always called, regardless of whether the decorator actually needed to perform conversion.
168168
This will be called with all arguments that were passed to the decorated function and should
169-
return a dictionary of all updated arguments.
169+
return a dictionary of all updated arguments. For a variadic arg, the dictionary entry for the name
170+
should have a list of all the updated values.
170171
"""
171172

172173
def impl(func):
@@ -193,17 +194,20 @@ def wrapper(*args, **kwargs):
193194
from tripy.frontend.tensor import Tensor
194195
from tripy.frontend.trace.ops.cast import cast
195196

196-
all_args = utils.merge_function_arguments(func, *args, **kwargs)
197+
all_args, var_arg_info = utils.merge_function_arguments(func, *args, **kwargs)
197198

198199
if preprocess_args is not None:
200+
201+
var_arg_name, var_arg_start_idx = utils.default(var_arg_info, (None, None))
199202
new_args = preprocess_args(*args, **kwargs)
200-
# TODO (#311): Make this work for variadic arguments. If `name` appears multiple times in `all_args`, then
201-
# we know we're dealing with a variadic argument. In that case, we could expect a list in `new_args` and
202-
# then unpack it over the corresponding arguments in `all_args`.
203203
for index in range(len(all_args)):
204204
name, _ = all_args[index]
205205
if name in new_args:
206-
all_args[index] = (name, new_args[name])
206+
if name == var_arg_name:
207+
assert var_arg_start_idx is not None
208+
all_args[index] = (name, new_args[name][index - var_arg_start_idx])
209+
else:
210+
all_args[index] = (name, new_args[name])
207211

208212
# Materialize type variables from tensors.
209213
type_vars = {}

tripy/tripy/utils/utils.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import math
2525
import time
2626
import typing
27-
from typing import Any, List, Sequence, Union
27+
from typing import Any, List, Optional, Sequence, Tuple, Union
2828

2929
from colored import Fore, Style
3030

@@ -415,10 +415,11 @@ def gen_uid(inputs=None, outputs=None):
415415
##
416416
## Functions
417417
##
418-
def get_positional_arg_names(func, *args):
418+
def get_positional_arg_names(func, *args) -> Tuple[List[Tuple[str, Any]], Optional[Tuple[str, int]]]:
419419
# Returns the names of positional arguments by inspecting the function signature.
420420
# In the case of variadic positional arguments, we cannot determine names, so we use
421-
# None instead.
421+
# None instead. To assist in further processing, this function also returns the name
422+
# and start index of the variadic args in a pair if present (None if not).
422423
signature = inspect.signature(func)
423424
arg_names = []
424425
varargs_name = None
@@ -432,12 +433,15 @@ def get_positional_arg_names(func, *args):
432433
arg_names.append(name)
433434

434435
# For all variadic positional arguments, assign the name of the variadic group.
435-
arg_names.extend([varargs_name] * (len(args) - len(arg_names)))
436-
return list(zip(arg_names, args))
436+
num_variadic_args = len(args) - len(arg_names)
437+
variadic_start_idx = len(arg_names)
438+
arg_names.extend([varargs_name] * num_variadic_args)
439+
return list(zip(arg_names, args)), (varargs_name, variadic_start_idx) if num_variadic_args > 0 else None
437440

438441

439-
def merge_function_arguments(func, *args, **kwargs):
442+
def merge_function_arguments(func, *args, **kwargs) -> Tuple[List[Tuple[str, Any]], Optional[Tuple[str, int]]]:
440443
# Merge positional and keyword arguments, trying to determine names where possible.
441-
all_args = get_positional_arg_names(func, *args)
444+
# Also returns a pair containing the variadic arg name and start index if present (None otherwise).
445+
all_args, var_arg_info = get_positional_arg_names(func, *args)
442446
all_args.extend(kwargs.items())
443-
return all_args
447+
return all_args, var_arg_info

0 commit comments

Comments
 (0)