Skip to content

Commit c6c6136

Browse files
authored
#343: Support collections of tensors in args/kwargs for compile (#701)
1 parent 34e1e17 commit c6c6136

File tree

4 files changed

+226
-10
lines changed

4 files changed

+226
-10
lines changed

tripy/nvtripy/backend/api/compile.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
from nvtripy.utils.types import obj_name_or_type_name
2929

3030

31-
# TODO (#230): Support collections of tensors in args/kwargs
3231
@export.public_api(document_under="compiling_code/compile.rst")
3332
def compile(
3433
func: Callable, optimization_level: int = 3, *, args: Sequence[Any] = [], kwargs: Dict[str, Any] = {}
@@ -157,13 +156,16 @@ def add(a, b):
157156
trace_input_map = {}
158157
input_names = set()
159158
input_infos = {}
159+
trace_inputs = [] # flattened list of trace input tensors in argument order
160+
access_plan_by_name: Dict[str, tuple] = {}
160161

161162
# Set up names for the weights in the module to make the trace easier to read.
162163
if isinstance(func, Module):
163164
for name, weight in func.state_dict().items():
164165
weight.name = name
165166

166-
def process_arg(name, arg):
167+
def process_arg_input_info(name, arg):
168+
"""Process InputInfo or DimensionInputInfo objects and create corresponding tensors."""
167169
if isinstance(arg, InputInfo):
168170
# Make new tensors for tracing.
169171
from nvtripy.common.datatype import floating, integer
@@ -184,6 +186,7 @@ def process_arg(name, arg):
184186

185187
trace_input_map[name] = tensor
186188
input_names.add(name)
189+
trace_inputs.append(tensor.trace_tensor)
187190

188191
return tensor
189192

@@ -199,11 +202,49 @@ def process_arg(name, arg):
199202

200203
trace_input_map[name] = tensor
201204
input_names.add(name)
205+
trace_inputs.append(tensor.trace_tensor)
202206

203207
return tensor
204208

205209
return arg
206210

211+
def process_arg_and_flag(top_arg_name, name, arg, steps):
212+
# Handle individual InputInfo or DimensionInputInfo objects
213+
if isinstance(arg, (InputInfo, DimensionInputInfo)):
214+
tensor_or_dim = process_arg_input_info(name, arg)
215+
access_plan_by_name[name] = (top_arg_name, tuple(steps))
216+
return tensor_or_dim, True
217+
218+
# Handle containers of InputInfo objects
219+
if isinstance(arg, dict):
220+
result = {}
221+
has_input = False
222+
for key, value in arg.items():
223+
nested_name = f"{name}.{key}"
224+
processed_child, child_has_input = process_arg_and_flag(
225+
top_arg_name, nested_name, value, (*steps, str(key))
226+
)
227+
result[key] = processed_child
228+
has_input = has_input or child_has_input
229+
return result, has_input
230+
elif isinstance(arg, (list, tuple)):
231+
result_list = []
232+
has_input = False
233+
for idx, value in enumerate(arg):
234+
nested_name = f"{name}[{idx}]"
235+
processed_child, child_has_input = process_arg_and_flag(top_arg_name, nested_name, value, (*steps, idx))
236+
result_list.append(processed_child)
237+
has_input = has_input or child_has_input
238+
return type(arg)(result_list), has_input # preserve sequence type
239+
240+
return arg, False
241+
242+
def process_arg(name, arg):
243+
processed, has_input = process_arg_and_flag(name, name, arg, tuple())
244+
if has_input:
245+
input_names.add(name)
246+
return processed
247+
207248
compiled_arg_names = []
208249

209250
new_args = []
@@ -258,8 +299,7 @@ def process_arg(name, arg):
258299
[f"Return value {index} was not a tensor: {repr(trace_out)}"],
259300
)
260301

261-
# Order of trace inputs also needs to match that of the compiled_arg_names
262-
trace_inputs = [trace_input_map[name].trace_tensor for name in compiled_arg_names]
302+
# We collected flattened trace inputs during traversal
263303
trace = Trace(
264304
[tensor.trace_tensor for tensor in trace_outputs],
265305
trace_inputs,
@@ -281,9 +321,11 @@ def process_arg(name, arg):
281321
assert isinstance(func_out, Tensor) or isinstance(
282322
func_out, Sequence
283323
), "This function is only implemented for Tensors or sequences of Tensors"
324+
284325
return Executable(
285326
executable,
286327
compiled_arg_names,
287328
return_single_tensor_as_sequence=isinstance(func_out, Sequence),
288329
input_infos=input_infos,
330+
access_plan_by_name=access_plan_by_name,
289331
)

tripy/nvtripy/backend/api/executable.py

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def __init__(
4646
arg_names,
4747
return_single_tensor_as_sequence: bool,
4848
input_infos: Dict[str, Union[InputInfo, DimensionInputInfo]],
49+
access_plan_by_name: Dict[str, Tuple[str, Tuple[Union[str, int], ...]]],
4950
):
5051
self._executable = executable
5152

@@ -78,6 +79,24 @@ def __init__(
7879
Stores metadata, like shapes and data types, for each input to the executable.
7980
"""
8081

82+
# Build accessor map from compile-time access plans
83+
self._accessor_map: Dict[str, callable] = {}
84+
name_to_index = {name: idx for idx, name in enumerate(self._arg_names)}
85+
86+
def make_accessor(arg_index: int, steps: Tuple[Union[str, int], ...]):
87+
def accessor(inputs):
88+
v = inputs[arg_index]
89+
for s in steps:
90+
v = v[s]
91+
return v
92+
93+
return accessor
94+
95+
self._access_plan_by_name = access_plan_by_name
96+
for leaf_name, (arg_name, steps) in self._access_plan_by_name.items():
97+
idx = name_to_index[arg_name]
98+
self._accessor_map[leaf_name] = make_accessor(idx, steps)
99+
81100
def __str__(self) -> str:
82101
params = [
83102
f"{name}: {str_from_type_annotation(param.annotation)}"
@@ -195,20 +214,42 @@ def add(a, b):
195214
],
196215
)
197216

217+
# Fetch flattened tensors directly via accessors
218+
input_info_names = list(self.input_infos.keys())
219+
flattened_tensors = []
220+
for name in input_info_names:
221+
try:
222+
flattened_tensors.append(self._accessor_map[name](input_tensors))
223+
except Exception as exc:
224+
raise_error(
225+
f"Missing runtime tensor for input `{name}`.",
226+
[
227+
"Ensure your provided collections include tensors for all compiled inputs.",
228+
f"Expected inputs: {input_info_names}",
229+
f"Note: Error was:\n{exc}",
230+
],
231+
)
198232
expected_devices = ["gpu" if isinstance(info, InputInfo) else "cpu" for info in self.input_infos.values()]
199-
for tensor, expected_device, arg_name in zip(input_tensors, expected_devices, self._arg_names):
233+
234+
# Validate flattened tensors against input_infos
235+
if len(flattened_tensors) != len(expected_devices):
236+
raise_error(
237+
f"Mismatch between number of flattened tensors ({len(flattened_tensors)}) and expected inputs ({len(expected_devices)})."
238+
)
239+
240+
for tensor, expected_device, info_name in zip(flattened_tensors, expected_devices, self.input_infos.keys()):
200241
producer = tensor.trace_tensor.producer
201242
if not isinstance(producer, Constant):
202-
raise_error(f"Tensor `{arg_name}` is not evaluated.", ["Hint: Try calling `.eval()` on the tensor."])
243+
raise_error(f"Tensor `{info_name}` is not evaluated.", ["Hint: Try calling `.eval()` on the tensor."])
203244
if tensor.device.kind != expected_device:
204245
raise_error(
205246
"Unexpected tensor device.",
206247
[
207-
f"For tensor: `{arg_name}`, expected to be on device: {expected_device} but got: {tensor.device.kind}.\n",
248+
f"For tensor: `{info_name}`, expected to be on device: {expected_device} but got: {tensor.device.kind}.\n",
208249
],
209250
)
210251

211-
input_memrefs = [inp.trace_tensor.producer.data for inp in input_tensors]
252+
input_memrefs = [inp.trace_tensor.producer.data for inp in flattened_tensors]
212253
try:
213254
output_memrefs = self._session.execute_function(
214255
"main", in_args=input_memrefs, stream=self.stream._active_cuda_stream, client=self._runtime_client
@@ -222,7 +263,7 @@ def add(a, b):
222263
expected_input_dtypes = [
223264
info.dtype if isinstance(info, InputInfo) else int32 for info in self.input_infos.values()
224265
]
225-
for tensor, dtype, arg_name in zip(input_tensors, expected_input_dtypes, self._arg_names):
266+
for tensor, dtype, arg_name in zip(flattened_tensors, expected_input_dtypes, self.input_infos.keys()):
226267
if tensor.dtype != dtype:
227268
raise_error(
228269
f"Unexpected tensor data type.",
@@ -237,7 +278,9 @@ def add(a, b):
237278
expected_input_shapes = [
238279
info.shape_bounds if isinstance(info, InputInfo) else tuple() for info in self.input_infos.values()
239280
]
240-
for tensor, expected_bounds, arg_name in zip(input_tensors, expected_input_shapes, self._arg_names):
281+
for tensor, expected_bounds, arg_name in zip(
282+
flattened_tensors, expected_input_shapes, self.input_infos.keys()
283+
):
241284
shape = tensor.shape
242285

243286
if len(shape) != len(expected_bounds.min):
@@ -346,6 +389,7 @@ def encode_executable(executable):
346389
"executable": base64.b64encode(executable._executable.serialize()).decode(),
347390
"_return_single_tensor_as_sequence": executable._return_single_tensor_as_sequence,
348391
"input_infos": executable.input_infos,
392+
"access_plan_by_name": executable._access_plan_by_name,
349393
}
350394

351395

@@ -357,4 +401,5 @@ def decode_executable(executable_dict):
357401
executable_dict["arg_names"],
358402
return_single_tensor_as_sequence=executable_dict["_return_single_tensor_as_sequence"],
359403
input_infos=executable_dict["input_infos"],
404+
access_plan_by_name=executable_dict["access_plan_by_name"],
360405
)

tripy/nvtripy/frontend/tensor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ def eval(self) -> "nvtripy.Tensor":
237237
name: InputInfo(list(map(int, inp.trace_tensor.shape)), inp.dtype)
238238
for name, inp in zip(arg_names, inputs)
239239
},
240+
access_plan_by_name={name: (name, tuple()) for name in arg_names},
240241
)
241242
data = executable(*inputs).trace_tensor.producer.data
242243

tripy/tests/backend/api/test_compile.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,3 +241,131 @@ def test_dimension_input(self):
241241
out = compiled(inp, dim_inp)
242242
expected = (inp_cp + inp_cp).reshape((-1, reshape_dim))
243243
assert cp.array_equal(cp.from_dlpack(out), expected)
244+
245+
def test_compile_nested_dict_input_info(self):
246+
def func(data_dict):
247+
return data_dict["a"]["inner"] + data_dict["b"]["list"][0] + data_dict["b"]["list"][1]
248+
249+
dict_input = {
250+
"a": {
251+
"inner": tp.InputInfo(shape=(2, 3), dtype=tp.float32),
252+
},
253+
"b": {
254+
"list": [
255+
tp.InputInfo(shape=(2, 3), dtype=tp.float32),
256+
tp.InputInfo(shape=(2, 3), dtype=tp.float32),
257+
],
258+
},
259+
}
260+
compiled_func = tp.compile(func, args=[dict_input])
261+
262+
test_dict = {
263+
"a": {"inner": tp.ones((2, 3), dtype=tp.float32).eval()},
264+
"b": {
265+
"list": [
266+
(tp.ones((2, 3), dtype=tp.float32) * 2).eval(),
267+
(tp.ones((2, 3), dtype=tp.float32) * 3).eval(),
268+
]
269+
},
270+
}
271+
result = compiled_func(test_dict)
272+
expected = test_dict["a"]["inner"] + test_dict["b"]["list"][0] + test_dict["b"]["list"][1]
273+
assert tp.equal(result, expected)
274+
275+
def test_compile_nested_sequence_input_info(self):
276+
def func(data_list):
277+
return data_list[0] + data_list[1][0] + data_list[1][1]
278+
279+
list_input = [
280+
tp.InputInfo(shape=(2, 3), dtype=tp.float32),
281+
[
282+
tp.InputInfo(shape=(2, 3), dtype=tp.float32),
283+
tp.ones((2, 3), dtype=tp.float32) * 2,
284+
],
285+
]
286+
compiled_func = tp.compile(func, args=[list_input])
287+
288+
test_list = [
289+
tp.ones((2, 3), dtype=tp.float32).eval(),
290+
(
291+
(tp.ones((2, 3), dtype=tp.float32) * 3).eval(),
292+
tp.ones((2, 3), dtype=tp.float32) * 2,
293+
),
294+
]
295+
result = compiled_func(test_list)
296+
expected = test_list[0] + test_list[1][0] + test_list[1][1]
297+
assert tp.equal(result, expected)
298+
299+
def test_compile_mixed_containers_and_constants(self):
300+
def func(regular_input, data_dict, data_list, const_in_dict, const):
301+
return (
302+
regular_input
303+
+ data_dict["x"]
304+
+ data_dict["y"]
305+
+ data_list[0]
306+
+ data_list[1]
307+
+ const_in_dict["z"]
308+
+ const
309+
)
310+
311+
regular_input = tp.InputInfo(shape=(2, 3), dtype=tp.float32)
312+
dict_input = {
313+
"x": tp.InputInfo(shape=(2, 3), dtype=tp.float32),
314+
"y": tp.zeros((2, 3), dtype=tp.float32),
315+
}
316+
list_input = [tp.ones((2, 3), dtype=tp.float32) * 3, tp.InputInfo(shape=(2, 3), dtype=tp.float32)]
317+
const_in_dict = {"z": tp.ones((2, 3), dtype=tp.float32) * 5}
318+
const = tp.ones((2, 3), dtype=tp.float32) * 6
319+
320+
compiled_func = tp.compile(func, args=[regular_input, dict_input, list_input, const_in_dict, const])
321+
322+
# Only InputInfo arguments should be in function signature
323+
test_regular = tp.ones((2, 3), dtype=tp.float32).eval()
324+
test_dict = {"x": (tp.ones((2, 3), dtype=tp.float32) * 2).eval()}
325+
test_list = [None, (tp.ones((2, 3), dtype=tp.float32) * 4).eval()]
326+
327+
result = compiled_func(test_regular, test_dict, test_list)
328+
expected = (
329+
test_regular + test_dict["x"] + dict_input["y"] + test_list[1] + list_input[0] + const_in_dict["z"] + const
330+
)
331+
assert tp.equal(result, expected)
332+
333+
def test_compile_missing_nested_input_fails(self):
334+
def func(data_dict):
335+
return data_dict["a"]["inner"] + data_dict["b"]["list"][1]
336+
337+
dict_input = {
338+
"a": {"inner": tp.InputInfo(shape=(2, 3), dtype=tp.float32)},
339+
"b": {"list": [tp.zeros((2, 3), dtype=tp.float32), tp.InputInfo(shape=(2, 3), dtype=tp.float32)]},
340+
}
341+
342+
compiled_func = tp.compile(func, args=[dict_input])
343+
344+
# Missing b.list[1]
345+
bad_dict = {
346+
"a": {"inner": tp.ones((2, 3), dtype=tp.float32).eval()},
347+
"b": {"list": [tp.ones((2, 3), dtype=tp.float32).eval()]},
348+
}
349+
with helper.raises(tp.TripyException, match="Missing runtime tensor for input `data_dict\.b\.list\[1\]`."):
350+
compiled_func(bad_dict)
351+
352+
# Wrong shape for b.list[1] should trigger a shape/device validation error
353+
wrong_shape = {
354+
"a": {"inner": tp.ones((2, 3), dtype=tp.float32).eval()},
355+
"b": {"list": [tp.zeros((2, 3), dtype=tp.float32), tp.ones((2, 2), dtype=tp.float32).eval()]},
356+
}
357+
with helper.raises(tp.TripyException, match="Unexpected tensor shape."):
358+
compiled_func(wrong_shape)
359+
360+
def test_compile_container_mismatch_fails(self):
361+
def func(data_list):
362+
return data_list[0] + data_list[1][0]
363+
364+
list_input = [tp.InputInfo(shape=(2, 3), dtype=tp.float32), [tp.InputInfo(shape=(2, 3), dtype=tp.float32)]]
365+
366+
compiled_func = tp.compile(func, args=[list_input])
367+
368+
bad_list = [tp.ones((2, 3), dtype=tp.float32).eval(), {"not": tp.ones((2, 3), dtype=tp.float32).eval()}]
369+
370+
with helper.raises(tp.TripyException, match="Missing runtime tensor for input `data_list\[1\]\[0\]`."):
371+
compiled_func(bad_list)

0 commit comments

Comments
 (0)