Commit 48b5d7d

#230: Support collections of tensors in args/kwargs for compile
1 parent d1a8447 commit 48b5d7d

File tree: 3 files changed, +158 -7 lines
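
The change lets tp.compile accept dictionaries, lists, and tuples of tp.InputInfo in args/kwargs: entries that are InputInfo (or DimensionInputInfo) become compiled inputs, while everything else in the container is baked into the executable as a constant. A minimal usage sketch, adapted from the new tests below and assuming the usual `import nvtripy as tp` alias; shapes and names are illustrative only:

    import nvtripy as tp

    def add_pair(pair):
        return pair["a"] + pair["b"]

    # A dict whose values are InputInfo compiles as a single dict-shaped argument.
    compiled_add = tp.compile(
        add_pair,
        args=[{
            "a": tp.InputInfo(shape=(2, 3), dtype=tp.float32),
            "b": tp.InputInfo(shape=(2, 3), dtype=tp.float32),
        }],
    )

    # At call time, pass a dict of evaluated tensors under the same keys.
    out = compiled_add({
        "a": tp.ones((2, 3), dtype=tp.float32).eval(),
        "b": (tp.ones((2, 3), dtype=tp.float32) * 2).eval(),
    })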


tripy/nvtripy/backend/api/compile.py

Lines changed: 45 additions & 3 deletions
@@ -28,7 +28,6 @@
 from nvtripy.utils.types import obj_name_or_type_name


-# TODO (#230): Support collections of tensors in args/kwargs
 @export.public_api(document_under="compiling_code/compile.rst")
 def compile(
     func: Callable, optimization_level: int = 3, *, args: Sequence[Any] = [], kwargs: Dict[str, Any] = {}
@@ -163,7 +162,8 @@ def add(a, b):
     for name, weight in func.state_dict().items():
         weight.name = name

-    def process_arg(name, arg):
+    def process_arg_input_info(name, arg):
+        """Process InputInfo or DimensionInputInfo objects and create corresponding tensors."""
         if isinstance(arg, InputInfo):
             # Make new tensors for tracing.
             from nvtripy.common.datatype import floating, integer
@@ -204,6 +204,31 @@ def process_arg(name, arg):

         return arg

+    def process_arg(name, arg):
+        # Handle individual InputInfo or DimensionInputInfo objects
+        if isinstance(arg, (InputInfo, DimensionInputInfo)):
+            return process_arg_input_info(name, arg)
+
+        # Handle containers of InputInfo objects
+        if isinstance(arg, dict):
+            if any(isinstance(v, (InputInfo, DimensionInputInfo)) for v in arg.values()):
+                input_names.add(name)
+            result = {}
+            for key, value in arg.items():
+                nested_name = f"{name}.{key}"
+                result[key] = process_arg(nested_name, value)
+            return result
+        elif isinstance(arg, (list, tuple)):
+            if any(isinstance(v, (InputInfo, DimensionInputInfo)) for v in arg):
+                input_names.add(name)
+            result = []
+            for idx, value in enumerate(arg):
+                nested_name = f"{name}[{idx}]"
+                result.append(process_arg(nested_name, value))
+            return type(arg)(result)
+
+        return arg
+
     compiled_arg_names = []

     new_args = []
@@ -259,7 +284,24 @@ def process_arg(name, arg):
         )

     # Order of trace inputs also needs to match that of the compiled_arg_names
-    trace_inputs = [trace_input_map[name].trace_tensor for name in compiled_arg_names]
+    # For containers, we need to collect all individual trace tensors
+    def collect_trace_tensors(name):
+        """Collect trace tensors for a name, flattening containers."""
+        if name in trace_input_map:
+            # Regular InputInfo or DimensionInputInfo
+            return [trace_input_map[name].trace_tensor]
+        else:
+            # Collect all nested trace tensors inside the container
+            nested_tensors = []
+            for nested_name in sorted(trace_input_map.keys()):
+                if nested_name.startswith(f"{name}.") or nested_name.startswith(f"{name}["):
+                    nested_tensors.append(trace_input_map[nested_name].trace_tensor)
+            return nested_tensors
+
+    # Flatten all trace tensors from containers and individual inputs
+    trace_inputs = []
+    for name in compiled_arg_names:
+        trace_inputs.extend(collect_trace_tensors(name))
     trace = Trace(
         [tensor.trace_tensor for tensor in trace_outputs],
         trace_inputs,
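
process_arg gives every nested entry a flat name: dict values become "{name}.{key}" and list/tuple elements become "{name}[{idx}]", recursing into nested containers, and collect_trace_tensors later matches on those prefixes so the trace inputs line up with compiled_arg_names. A standalone sketch of just that naming convention (flatten_names is a hypothetical helper for illustration, not part of the library):

    def flatten_names(name, arg):
        # Mirrors process_arg's naming: dots for dict keys, brackets for sequence indices.
        if isinstance(arg, dict):
            names = []
            for key, value in arg.items():
                names.extend(flatten_names(f"{name}.{key}", value))
            return names
        if isinstance(arg, (list, tuple)):
            names = []
            for idx, value in enumerate(arg):
                names.extend(flatten_names(f"{name}[{idx}]", value))
            return names
        return [name]

    # flatten_names("data", {"x": 1, "y": [2, 3]}) == ["data.x", "data.y[0]", "data.y[1]"]

Note that collect_trace_tensors sorts the nested names, and extract_recursive in executable.py (below) sorts dict keys at call time; that shared sorting is what keeps the compile-time and call-time orderings consistent.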

tripy/nvtripy/backend/api/executable.py

Lines changed: 45 additions & 4 deletions
@@ -195,20 +195,61 @@ def add(a, b):
                 ],
             )

+        # Recursively extract inputs from containers to get individual tensors for validation and execution
+        def extract_inputs(tensors, input_info_names):
+            def extract_recursive(tensor, name_prefix):
+                if isinstance(tensor, dict):
+                    result = []
+                    for key in sorted(tensor.keys()):
+                        nested_name = f"{name_prefix}.{key}"
+                        if nested_name in input_info_names:
+                            result.append(tensor[key])
+                        else:
+                            result.extend(extract_recursive(tensor[key], nested_name))
+                    return result
+                elif isinstance(tensor, (list, tuple)):
+                    result = []
+                    for idx, value in enumerate(tensor):
+                        nested_name = f"{name_prefix}[{idx}]"
+                        if nested_name in input_info_names:
+                            result.append(value)
+                        else:
+                            result.extend(extract_recursive(value, nested_name))
+                    return result
+                else:  # Regular tensor
+                    if name_prefix in input_info_names:
+                        return [tensor]
+                    else:
+                        return []
+
+            flattened = []
+            for name_idx, tensor in enumerate(tensors):
+                arg_name = self._arg_names[name_idx]
+                flattened.extend(extract_recursive(tensor, arg_name))
+            return flattened
+
+        flattened_tensors = extract_inputs(input_tensors, set(self.input_infos.keys()))
         expected_devices = ["gpu" if isinstance(info, InputInfo) else "cpu" for info in self.input_infos.values()]
-        for tensor, expected_device, arg_name in zip(input_tensors, expected_devices, self._arg_names):
+
+        # Validate flattened tensors against input_infos
+        if len(flattened_tensors) != len(expected_devices):
+            raise_error(
+                f"Mismatch between number of flattened tensors ({len(flattened_tensors)}) and expected inputs ({len(expected_devices)})."
+            )
+
+        for tensor, expected_device, info_name in zip(flattened_tensors, expected_devices, self.input_infos.keys()):
             producer = tensor.trace_tensor.producer
             if not isinstance(producer, Constant):
-                raise_error(f"Tensor `{arg_name}` is not evaluated.", ["Hint: Try calling `.eval()` on the tensor."])
+                raise_error(f"Tensor `{info_name}` is not evaluated.", ["Hint: Try calling `.eval()` on the tensor."])
             if tensor.device.kind != expected_device:
                 raise_error(
                     "Unexpected tensor device.",
                     [
-                        f"For tensor: `{arg_name}`, expected to be on device: {expected_device} but got: {tensor.device.kind}.\n",
+                        f"For tensor: `{info_name}`, expected to be on device: {expected_device} but got: {tensor.device.kind}.\n",
                     ],
                 )

-        input_memrefs = [inp.trace_tensor.producer.data for inp in input_tensors]
+        input_memrefs = [inp.trace_tensor.producer.data for inp in flattened_tensors]
         try:
             output_memrefs = self._session.execute_function(
                 "main", in_args=input_memrefs, stream=self.stream._active_cuda_stream, client=self._runtime_client

tripy/tests/backend/api/test_compile.py

Lines changed: 68 additions & 0 deletions
@@ -241,3 +241,71 @@ def test_dimension_input(self):
         out = compiled(inp, dim_inp)
         expected = (inp_cp + inp_cp).reshape((-1, reshape_dim))
         assert cp.array_equal(cp.from_dlpack(out), expected)
+
+    def test_compile_dict_input_info(self):
+        """Test compilation with dictionary of InputInfo objects."""
+
+        def func(data_dict):
+            return data_dict["a"] + data_dict["b"]
+
+        dict_input = {
+            "a": tp.InputInfo(shape=(2, 3), dtype=tp.float32),
+            "b": tp.InputInfo(shape=(2, 3), dtype=tp.float32),
+        }
+        compiled_func = tp.compile(func, args=[dict_input])
+
+        test_dict = {"a": tp.ones((2, 3), dtype=tp.float32).eval(), "b": (tp.ones((2, 3), dtype=tp.float32) * 2).eval()}
+        result = compiled_func(test_dict)
+        expected = test_dict["a"] + test_dict["b"]
+        assert cp.array_equal(cp.from_dlpack(result), cp.from_dlpack(expected))
+
+    def test_compile_nested_list_input_info(self):
+        """Test compilation with nested list containers."""
+
+        def func(data_list):
+            return data_list[0] + data_list[1][0] + data_list[1][1]
+
+        list_input = [
+            tp.InputInfo(shape=(2, 3), dtype=tp.float32),
+            [  # Nested list
+                tp.InputInfo(shape=(2, 3), dtype=tp.float32),
+                tp.ones((2, 3), dtype=tp.float32) * 2,  # Constant in nested list
+            ],
+        ]
+        compiled_func = tp.compile(func, args=[list_input])
+
+        test_list = [
+            tp.ones((2, 3), dtype=tp.float32).eval(),
+            [  # Nested list in test data
+                (tp.ones((2, 3), dtype=tp.float32) * 3).eval(),
+                tp.ones((2, 3), dtype=tp.float32) * 2,  # Should match baked constant
+            ],
+        ]
+        result = compiled_func(test_list)
+        expected = test_list[0] + test_list[1][0] + test_list[1][1]
+        assert cp.array_equal(cp.from_dlpack(result), cp.from_dlpack(expected))
+
+    def test_compile_mixed_containers_and_constants(self):
+        """Test compilation with comprehensive mix: regular InputInfo, dict container, list container, and standalone constant."""
+
+        def func(regular_input, data_dict, data_list, constant_value):
+            return regular_input + data_dict["x"] + data_dict["y"] + data_list[0] + data_list[1] + constant_value
+
+        regular_input = tp.InputInfo(shape=(2, 3), dtype=tp.float32)
+        dict_input = {
+            "x": tp.InputInfo(shape=(2, 3), dtype=tp.float32),
+            "y": tp.zeros((2, 3), dtype=tp.float32),  # Constant in dict
+        }
+        list_input = [tp.InputInfo(shape=(2, 3), dtype=tp.float32), tp.ones((2, 3), dtype=tp.float32) * 3]
+        constant_value = tp.ones((2, 3), dtype=tp.float32) * 5
+
+        compiled_func = tp.compile(func, args=[regular_input, dict_input, list_input, constant_value])
+
+        # Only InputInfo arguments should be in function signature
+        test_regular = tp.ones((2, 3), dtype=tp.float32).eval()
+        test_dict = {"x": (tp.ones((2, 3), dtype=tp.float32) * 2).eval(), "y": tp.zeros((2, 3), dtype=tp.float32)}
+        test_list = [(tp.ones((2, 3), dtype=tp.float32) * 4).eval(), tp.ones((2, 3), dtype=tp.float32) * 3]
+
+        result = compiled_func(test_regular, test_dict, test_list)
+        expected = test_regular + test_dict["x"] + test_dict["y"] + test_list[0] + test_list[1] + constant_value
+        assert cp.array_equal(cp.from_dlpack(result), cp.from_dlpack(expected))
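
The commit title also mentions kwargs, which these tests do not exercise directly. Since compile accepts kwargs: Dict[str, Any] alongside args, the same container handling should presumably apply when the container is supplied as a keyword; a hypothetical sketch in the spirit of test_compile_dict_input_info (an assumption, not behavior pinned down by the tests above):

    def add_pair(pair):
        return pair["a"] + pair["b"]

    compiled_add = tp.compile(
        add_pair,
        kwargs={"pair": {
            "a": tp.InputInfo(shape=(2, 3), dtype=tp.float32),
            "b": tp.InputInfo(shape=(2, 3), dtype=tp.float32),
        }},
    )
    out = compiled_add({
        "a": tp.ones((2, 3), dtype=tp.float32).eval(),
        "b": (tp.ones((2, 3), dtype=tp.float32) * 2).eval(),
    })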
