Skip to content

Commit e924c8b

Browse files
committed
Cache InputInfo structure at compile time, improve test coverage
1 parent 87ecff6 commit e924c8b

File tree

4 files changed

+99
-42
lines changed

4 files changed

+99
-42
lines changed

tripy/nvtripy/backend/api/compile.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,9 +316,22 @@ def process_arg(name, arg):
316316
assert isinstance(func_out, Tensor) or isinstance(
317317
func_out, Sequence
318318
), "This function is only implemented for Tensors or sequences of Tensors"
319+
320+
# Group leaf input names by top-level argument for efficient runtime extraction
321+
leaf_names_by_arg = {}
322+
leaf_names = list(input_infos.keys())
323+
for arg_name in compiled_arg_names:
324+
matching = [
325+
leaf
326+
for leaf in leaf_names
327+
if leaf == arg_name or leaf.startswith(f"{arg_name}.") or leaf.startswith(f"{arg_name}[")
328+
]
329+
leaf_names_by_arg[arg_name] = matching
330+
319331
return Executable(
320332
executable,
321333
compiled_arg_names,
322334
return_single_tensor_as_sequence=isinstance(func_out, Sequence),
323335
input_infos=input_infos,
336+
leaf_names_by_arg=leaf_names_by_arg,
324337
)

tripy/nvtripy/backend/api/executable.py

Lines changed: 38 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def __init__(
4646
arg_names,
4747
return_single_tensor_as_sequence: bool,
4848
input_infos: Dict[str, Union[InputInfo, DimensionInputInfo]],
49+
leaf_names_by_arg: Dict[str, Sequence[str]],
4950
):
5051
self._executable = executable
5152

@@ -78,6 +79,8 @@ def __init__(
7879
Stores metadata, like shapes and data types, for each input to the executable.
7980
"""
8081

82+
self._leaf_names_by_arg = leaf_names_by_arg
83+
8184
def __str__(self) -> str:
8285
params = [
8386
f"{name}: {str_from_type_annotation(param.annotation)}"
@@ -195,34 +198,36 @@ def add(a, b):
195198
],
196199
)
197200

198-
# Recursively build a name->tensor map
199-
def extract_inputs(tensors, input_info_names):
200-
name_to_tensor = {}
201-
202-
def extract_recursive(value, name_prefix):
203-
if name_prefix in input_info_names:
204-
name_to_tensor[name_prefix] = value
205-
return
206-
if isinstance(value, dict):
207-
for key, item in value.items():
208-
nested_name = f"{name_prefix}.{key}"
209-
extract_recursive(item, nested_name)
210-
elif isinstance(value, (list, tuple)):
211-
for idx, item in enumerate(value):
212-
nested_name = f"{name_prefix}[{idx}]"
213-
extract_recursive(item, nested_name)
214-
else:
215-
print(f"Leaf tensor: {name_prefix}: {value}")
216-
return
217-
218-
for name_idx, tensor in enumerate(tensors):
219-
arg_name = self._arg_names[name_idx]
220-
extract_recursive(tensor, arg_name)
221-
222-
return name_to_tensor
223-
201+
# Build a name->tensor map using precomputed leaf names to avoid unnecessary recursion
224202
input_info_names = list(self.input_infos.keys())
225-
name_to_tensor = extract_inputs(input_tensors, set(input_info_names))
203+
name_to_tensor: Dict[str, Tensor] = {}
204+
205+
def extract_recursive(value, name_prefix, allowed_names):
206+
if name_prefix in allowed_names:
207+
name_to_tensor[name_prefix] = value
208+
return
209+
if isinstance(value, dict):
210+
for key, item in value.items():
211+
nested_name = f"{name_prefix}.{key}"
212+
extract_recursive(item, nested_name, allowed_names)
213+
elif isinstance(value, (list, tuple)):
214+
for idx, item in enumerate(value):
215+
nested_name = f"{name_prefix}[{idx}]"
216+
extract_recursive(item, nested_name, allowed_names)
217+
else:
218+
return
219+
220+
for name_idx, tensor in enumerate(input_tensors):
221+
arg_name = self._arg_names[name_idx]
222+
# Fast path: direct leaf input
223+
if arg_name in self.input_infos:
224+
name_to_tensor[arg_name] = tensor
225+
continue
226+
# If this arg has no compiled leaves beneath it, skip any recursion
227+
allowed = self._leaf_names_by_arg.get(arg_name)
228+
if not allowed:
229+
continue
230+
extract_recursive(tensor, arg_name, set(allowed))
226231
try:
227232
flattened_tensors = [name_to_tensor[name] for name in input_info_names]
228233
except KeyError as missing:
@@ -267,7 +272,7 @@ def extract_recursive(value, name_prefix):
267272
expected_input_dtypes = [
268273
info.dtype if isinstance(info, InputInfo) else int32 for info in self.input_infos.values()
269274
]
270-
for tensor, dtype, arg_name in zip(input_tensors, expected_input_dtypes, self._arg_names):
275+
for tensor, dtype, arg_name in zip(flattened_tensors, expected_input_dtypes, self.input_infos.keys()):
271276
if tensor.dtype != dtype:
272277
raise_error(
273278
f"Unexpected tensor data type.",
@@ -282,16 +287,17 @@ def extract_recursive(value, name_prefix):
282287
expected_input_shapes = [
283288
info.shape_bounds if isinstance(info, InputInfo) else tuple() for info in self.input_infos.values()
284289
]
285-
for tensor, expected_bounds, arg_name in zip(input_tensors, expected_input_shapes, self._arg_names):
290+
for tensor, expected_bounds, arg_name in zip(
291+
flattened_tensors, expected_input_shapes, self.input_infos.keys()
292+
):
286293
shape = tensor.shape
287294

288295
if len(shape) != len(expected_bounds.min):
289296
raise_error(
290297
f"Unexpected tensor rank.",
291298
[
292299
f"For tensor: `{arg_name}`, expected a rank of: {len(expected_bounds.min)} but got: {len(shape)}.\n"
293-
f"Note: The provided argument was: ",
294-
tensor,
300+
f"Note: The provided argument was a tensor with shape: {shape}",
295301
],
296302
)
297303

@@ -302,8 +308,7 @@ def extract_recursive(value, name_prefix):
302308
[
303309
f"For tensor: `{arg_name}`, expected a shape within the bounds: min={expected_bounds.min}, max={expected_bounds.max}, but got: {shape}.\n"
304310
f"Dimension {i} has a shape of {shape[i]}, which is not within the expected bounds of [{expected_bounds.min[i]}, {expected_bounds.max[i]}].\n"
305-
f"Note: The provided argument was: ",
306-
tensor,
311+
f"Note: The provided argument was a tensor with shape: {shape}",
307312
],
308313
)
309314
raise_error(str(err))

tripy/nvtripy/frontend/tensor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ def eval(self) -> "nvtripy.Tensor":
237237
name: InputInfo(list(map(int, inp.trace_tensor.shape)), inp.dtype)
238238
for name, inp in zip(arg_names, inputs)
239239
},
240+
leaf_names_by_arg={name: [name] for name in arg_names}, # every argument is a direct input
240241
)
241242
data = executable(*inputs).trace_tensor.producer.data
242243

tripy/tests/backend/api/test_compile.py

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ def func(data_dict):
270270
}
271271
result = compiled_func(test_dict)
272272
expected = test_dict["a"]["inner"] + test_dict["b"]["list"][0] + test_dict["b"]["list"][1]
273-
assert cp.array_equal(cp.from_dlpack(result), cp.from_dlpack(expected))
273+
assert tp.equal(result, expected)
274274

275275
def test_compile_nested_sequence_input_info(self):
276276
def func(data_list):
@@ -294,11 +294,9 @@ def func(data_list):
294294
]
295295
result = compiled_func(test_list)
296296
expected = test_list[0] + test_list[1][0] + test_list[1][1]
297-
assert cp.array_equal(cp.from_dlpack(result), cp.from_dlpack(expected))
297+
assert tp.equal(result, expected)
298298

299299
def test_compile_mixed_containers_and_constants(self):
300-
"""Test compilation with comprehensive mix: regular InputInfo, dict container, list container, and standalone constant."""
301-
302300
def func(regular_input, data_dict, data_list, const_in_dict, const):
303301
return (
304302
regular_input
@@ -315,19 +313,59 @@ def func(regular_input, data_dict, data_list, const_in_dict, const):
315313
"x": tp.InputInfo(shape=(2, 3), dtype=tp.float32),
316314
"y": tp.zeros((2, 3), dtype=tp.float32),
317315
}
318-
list_input = [tp.InputInfo(shape=(2, 3), dtype=tp.float32), tp.ones((2, 3), dtype=tp.float32) * 3]
316+
list_input = [tp.ones((2, 3), dtype=tp.float32) * 3, tp.InputInfo(shape=(2, 3), dtype=tp.float32)]
319317
const_in_dict = {"z": tp.ones((2, 3), dtype=tp.float32) * 5}
320318
const = tp.ones((2, 3), dtype=tp.float32) * 6
321319

322320
compiled_func = tp.compile(func, args=[regular_input, dict_input, list_input, const_in_dict, const])
323321

324322
# Only InputInfo arguments should be in function signature
325323
test_regular = tp.ones((2, 3), dtype=tp.float32).eval()
326-
test_dict = {"x": (tp.ones((2, 3), dtype=tp.float32) * 2).eval(), "y": tp.zeros((2, 3), dtype=tp.float32)}
327-
test_list = [(tp.ones((2, 3), dtype=tp.float32) * 4).eval(), tp.ones((2, 3), dtype=tp.float32) * 3]
324+
test_dict = {"x": (tp.ones((2, 3), dtype=tp.float32) * 2).eval()}
325+
test_list = [None, (tp.ones((2, 3), dtype=tp.float32) * 4).eval()]
328326

329327
result = compiled_func(test_regular, test_dict, test_list)
330328
expected = (
331-
test_regular + test_dict["x"] + test_dict["y"] + test_list[0] + test_list[1] + const_in_dict["z"] + const
329+
test_regular + test_dict["x"] + dict_input["y"] + test_list[1] + list_input[0] + const_in_dict["z"] + const
332330
)
333-
assert cp.array_equal(cp.from_dlpack(result), cp.from_dlpack(expected))
331+
assert tp.equal(result, expected)
332+
333+
def test_compile_missing_nested_input_fails(self):
334+
def func(data_dict):
335+
return data_dict["a"]["inner"] + data_dict["b"]["list"][1]
336+
337+
dict_input = {
338+
"a": {"inner": tp.InputInfo(shape=(2, 3), dtype=tp.float32)},
339+
"b": {"list": [tp.zeros((2, 3), dtype=tp.float32), tp.InputInfo(shape=(2, 3), dtype=tp.float32)]},
340+
}
341+
342+
compiled_func = tp.compile(func, args=[dict_input])
343+
344+
# Missing b.list[1]
345+
bad_dict = {
346+
"a": {"inner": tp.ones((2, 3), dtype=tp.float32).eval()},
347+
"b": {"list": [tp.ones((2, 3), dtype=tp.float32).eval()]},
348+
}
349+
with helper.raises(tp.TripyException, match="Missing runtime tensor for input `data_dict\.b\.list\[1\]`."):
350+
compiled_func(bad_dict)
351+
352+
# Wrong shape for b.list[1] should trigger a shape/device validation error
353+
wrong_shape = {
354+
"a": {"inner": tp.ones((2, 3), dtype=tp.float32).eval()},
355+
"b": {"list": [tp.zeros((2, 3), dtype=tp.float32), tp.ones((2, 2), dtype=tp.float32).eval()]},
356+
}
357+
with helper.raises(tp.TripyException, match="Unexpected tensor shape."):
358+
compiled_func(wrong_shape)
359+
360+
def test_compile_container_mismatch_fails(self):
361+
def func(data_list):
362+
return data_list[0] + data_list[1][0]
363+
364+
list_input = [tp.InputInfo(shape=(2, 3), dtype=tp.float32), [tp.InputInfo(shape=(2, 3), dtype=tp.float32)]]
365+
366+
compiled_func = tp.compile(func, args=[list_input])
367+
368+
bad_list = [tp.ones((2, 3), dtype=tp.float32).eval(), {"not": tp.ones((2, 3), dtype=tp.float32).eval()}]
369+
370+
with helper.raises(tp.TripyException, match="Missing runtime tensor for input `data_list\[1\]\[0\]`."):
371+
compiled_func(bad_list)

0 commit comments

Comments (0)