From 1c1dbe0c359d5ddcc808c531aa64595ede9b726b Mon Sep 17 00:00:00 2001 From: Akhil Goel Date: Fri, 10 Oct 2025 16:13:19 -0700 Subject: [PATCH 1/7] #230: Support collections of tensors in args/kwargs for compile --- tripy/nvtripy/backend/api/compile.py | 48 +++++++++++++++-- tripy/nvtripy/backend/api/executable.py | 49 ++++++++++++++++-- tripy/tests/backend/api/test_compile.py | 68 +++++++++++++++++++++++++ 3 files changed, 158 insertions(+), 7 deletions(-) diff --git a/tripy/nvtripy/backend/api/compile.py b/tripy/nvtripy/backend/api/compile.py index 9ba38ecd4..a92a5f2ef 100644 --- a/tripy/nvtripy/backend/api/compile.py +++ b/tripy/nvtripy/backend/api/compile.py @@ -28,7 +28,6 @@ from nvtripy.utils.types import obj_name_or_type_name -# TODO (#230): Support collections of tensors in args/kwargs @export.public_api(document_under="compiling_code/compile.rst") def compile( func: Callable, optimization_level: int = 3, *, args: Sequence[Any] = [], kwargs: Dict[str, Any] = {} @@ -163,7 +162,8 @@ def add(a, b): for name, weight in func.state_dict().items(): weight.name = name - def process_arg(name, arg): + def process_arg_input_info(name, arg): + """Process InputInfo or DimensionInputInfo objects and create corresponding tensors.""" if isinstance(arg, InputInfo): # Make new tensors for tracing. from nvtripy.common.datatype import floating, integer @@ -204,6 +204,31 @@ def process_arg(name, arg): return arg + def process_arg(name, arg): + # Handle individual InputInfo or DimensionInputInfo objects + if isinstance(arg, (InputInfo, DimensionInputInfo)): + return process_arg_input_info(name, arg) + + # Handle containers of InputInfo objects + if isinstance(arg, dict): + if any(isinstance(v, (InputInfo, DimensionInputInfo)) for v in arg.values()): + input_names.add(name) + result = {} + for key, value in arg.items(): + nested_name = f"{name}.{key}" + result[key] = process_arg(nested_name, value) + return result + elif isinstance(arg, (list, tuple)): + if any(isinstance(v, (InputInfo, DimensionInputInfo)) for v in arg): + input_names.add(name) + result = [] + for idx, value in enumerate(arg): + nested_name = f"{name}[{idx}]" + result.append(process_arg(nested_name, value)) + return type(arg)(result) + + return arg + compiled_arg_names = [] new_args = [] @@ -259,7 +284,24 @@ def process_arg(name, arg): ) # Order of trace inputs also needs to match that of the compiled_arg_names - trace_inputs = [trace_input_map[name].trace_tensor for name in compiled_arg_names] + # For containers, we need to collect all individual trace tensors + def collect_trace_tensors(name): + """Collect trace tensors for a name, flattening containers.""" + if name in trace_input_map: + # Regular InputInfo or DimensionInputInfo + return [trace_input_map[name].trace_tensor] + else: + # Collect all nested trace tensors inside the container + nested_tensors = [] + for nested_name in sorted(trace_input_map.keys()): + if nested_name.startswith(f"{name}.") or nested_name.startswith(f"{name}["): + nested_tensors.append(trace_input_map[nested_name].trace_tensor) + return nested_tensors + + # Flatten all trace tensors from containers and individual inputs + trace_inputs = [] + for name in compiled_arg_names: + trace_inputs.extend(collect_trace_tensors(name)) trace = Trace( [tensor.trace_tensor for tensor in trace_outputs], trace_inputs, diff --git a/tripy/nvtripy/backend/api/executable.py b/tripy/nvtripy/backend/api/executable.py index 57b3a78c7..e3539d64a 100644 --- a/tripy/nvtripy/backend/api/executable.py +++ 
b/tripy/nvtripy/backend/api/executable.py @@ -195,20 +195,61 @@ def add(a, b): ], ) + # Recursively extract inputs from containers to get individual tensors for validation and execution + def extract_inputs(tensors, input_info_names): + def extract_recursive(tensor, name_prefix): + if isinstance(tensor, dict): + result = [] + for key in sorted(tensor.keys()): + nested_name = f"{name_prefix}.{key}" + if nested_name in input_info_names: + result.append(tensor[key]) + else: + result.extend(extract_recursive(tensor[key], nested_name)) + return result + elif isinstance(tensor, (list, tuple)): + result = [] + for idx, value in enumerate(tensor): + nested_name = f"{name_prefix}[{idx}]" + if nested_name in input_info_names: + result.append(value) + else: + result.extend(extract_recursive(value, nested_name)) + return result + else: # Regular tensor + if name_prefix in input_info_names: + return [tensor] + else: + return [] + + flattened = [] + for name_idx, tensor in enumerate(tensors): + arg_name = self._arg_names[name_idx] + flattened.extend(extract_recursive(tensor, arg_name)) + return flattened + + flattened_tensors = extract_inputs(input_tensors, set(self.input_infos.keys())) expected_devices = ["gpu" if isinstance(info, InputInfo) else "cpu" for info in self.input_infos.values()] - for tensor, expected_device, arg_name in zip(input_tensors, expected_devices, self._arg_names): + + # Validate flattened tensors against input_infos + if len(flattened_tensors) != len(expected_devices): + raise_error( + f"Mismatch between number of flattened tensors ({len(flattened_tensors)}) and expected inputs ({len(expected_devices)})." + ) + + for tensor, expected_device, info_name in zip(flattened_tensors, expected_devices, self.input_infos.keys()): producer = tensor.trace_tensor.producer if not isinstance(producer, Constant): - raise_error(f"Tensor `{arg_name}` is not evaluated.", ["Hint: Try calling `.eval()` on the tensor."]) + raise_error(f"Tensor `{info_name}` is not evaluated.", ["Hint: Try calling `.eval()` on the tensor."]) if tensor.device.kind != expected_device: raise_error( "Unexpected tensor device.", [ - f"For tensor: `{arg_name}`, expected to be on device: {expected_device} but got: {tensor.device.kind}.\n", + f"For tensor: `{info_name}`, expected to be on device: {expected_device} but got: {tensor.device.kind}.\n", ], ) - input_memrefs = [inp.trace_tensor.producer.data for inp in input_tensors] + input_memrefs = [inp.trace_tensor.producer.data for inp in flattened_tensors] try: output_memrefs = self._session.execute_function( "main", in_args=input_memrefs, stream=self.stream._active_cuda_stream, client=self._runtime_client diff --git a/tripy/tests/backend/api/test_compile.py b/tripy/tests/backend/api/test_compile.py index b25ba4b92..0c3edef63 100644 --- a/tripy/tests/backend/api/test_compile.py +++ b/tripy/tests/backend/api/test_compile.py @@ -241,3 +241,71 @@ def test_dimension_input(self): out = compiled(inp, dim_inp) expected = (inp_cp + inp_cp).reshape((-1, reshape_dim)) assert cp.array_equal(cp.from_dlpack(out), expected) + + def test_compile_dict_input_info(self): + """Test compilation with dictionary of InputInfo objects.""" + + def func(data_dict): + return data_dict["a"] + data_dict["b"] + + dict_input = { + "a": tp.InputInfo(shape=(2, 3), dtype=tp.float32), + "b": tp.InputInfo(shape=(2, 3), dtype=tp.float32), + } + compiled_func = tp.compile(func, args=[dict_input]) + + test_dict = {"a": tp.ones((2, 3), dtype=tp.float32).eval(), "b": (tp.ones((2, 3), dtype=tp.float32) * 2).eval()} + 
result = compiled_func(test_dict) + expected = test_dict["a"] + test_dict["b"] + assert cp.array_equal(cp.from_dlpack(result), cp.from_dlpack(expected)) + + def test_compile_nested_list_input_info(self): + """Test compilation with nested list containers.""" + + def func(data_list): + return data_list[0] + data_list[1][0] + data_list[1][1] + + list_input = [ + tp.InputInfo(shape=(2, 3), dtype=tp.float32), + [ # Nested list + tp.InputInfo(shape=(2, 3), dtype=tp.float32), + tp.ones((2, 3), dtype=tp.float32) * 2, # Constant in nested list + ], + ] + compiled_func = tp.compile(func, args=[list_input]) + + test_list = [ + tp.ones((2, 3), dtype=tp.float32).eval(), + [ # Nested list in test data + (tp.ones((2, 3), dtype=tp.float32) * 3).eval(), + tp.ones((2, 3), dtype=tp.float32) * 2, # Should match baked constant + ], + ] + result = compiled_func(test_list) + expected = test_list[0] + test_list[1][0] + test_list[1][1] + assert cp.array_equal(cp.from_dlpack(result), cp.from_dlpack(expected)) + + def test_compile_mixed_containers_and_constants(self): + """Test compilation with comprehensive mix: regular InputInfo, dict container, list container, and standalone constant.""" + + def func(regular_input, data_dict, data_list, constant_value): + return regular_input + data_dict["x"] + data_dict["y"] + data_list[0] + data_list[1] + constant_value + + regular_input = tp.InputInfo(shape=(2, 3), dtype=tp.float32) + dict_input = { + "x": tp.InputInfo(shape=(2, 3), dtype=tp.float32), + "y": tp.zeros((2, 3), dtype=tp.float32), # Constant in dict + } + list_input = [tp.InputInfo(shape=(2, 3), dtype=tp.float32), tp.ones((2, 3), dtype=tp.float32) * 3] + constant_value = tp.ones((2, 3), dtype=tp.float32) * 5 + + compiled_func = tp.compile(func, args=[regular_input, dict_input, list_input, constant_value]) + + # Only InputInfo arguments should be in function signature + test_regular = tp.ones((2, 3), dtype=tp.float32).eval() + test_dict = {"x": (tp.ones((2, 3), dtype=tp.float32) * 2).eval(), "y": tp.zeros((2, 3), dtype=tp.float32)} + test_list = [(tp.ones((2, 3), dtype=tp.float32) * 4).eval(), tp.ones((2, 3), dtype=tp.float32) * 3] + + result = compiled_func(test_regular, test_dict, test_list) + expected = test_regular + test_dict["x"] + test_dict["y"] + test_list[0] + test_list[1] + constant_value + assert cp.array_equal(cp.from_dlpack(result), cp.from_dlpack(expected)) From 85774cb518d16d218569ffd53ce18cc21f683509 Mon Sep 17 00:00:00 2001 From: Akhil Goel Date: Fri, 24 Oct 2025 15:22:21 -0700 Subject: [PATCH 2/7] improve tests, review fixes --- tripy/nvtripy/backend/api/compile.py | 65 +++++++++++------------- tripy/nvtripy/backend/api/executable.py | 60 +++++++++++----------- tripy/tests/backend/api/test_compile.py | 66 ++++++++++++++++--------- 3 files changed, 105 insertions(+), 86 deletions(-) diff --git a/tripy/nvtripy/backend/api/compile.py b/tripy/nvtripy/backend/api/compile.py index a92a5f2ef..d8d7db5bb 100644 --- a/tripy/nvtripy/backend/api/compile.py +++ b/tripy/nvtripy/backend/api/compile.py @@ -156,6 +156,7 @@ def add(a, b): trace_input_map = {} input_names = set() input_infos = {} + trace_inputs = [] # flattened list of trace input tensors in argument order # Set up names for the weights in the module to make the trace easier to read. 
if isinstance(func, Module): @@ -184,6 +185,7 @@ def process_arg_input_info(name, arg): trace_input_map[name] = tensor input_names.add(name) + trace_inputs.append(tensor.trace_tensor) return tensor @@ -199,35 +201,44 @@ def process_arg_input_info(name, arg): trace_input_map[name] = tensor input_names.add(name) + trace_inputs.append(tensor.trace_tensor) return tensor return arg - def process_arg(name, arg): + def process_arg_and_flag(name, arg): # Handle individual InputInfo or DimensionInputInfo objects if isinstance(arg, (InputInfo, DimensionInputInfo)): - return process_arg_input_info(name, arg) + return process_arg_input_info(name, arg), True # Handle containers of InputInfo objects if isinstance(arg, dict): - if any(isinstance(v, (InputInfo, DimensionInputInfo)) for v in arg.values()): - input_names.add(name) - result = {} - for key, value in arg.items(): - nested_name = f"{name}.{key}" - result[key] = process_arg(nested_name, value) - return result + result = {} + has_input = False + for key, value in arg.items(): + nested_name = f"{name}.{key}" + processed_child, child_has_input = process_arg_and_flag(nested_name, value) + result[key] = processed_child + has_input = has_input or child_has_input + return result, has_input elif isinstance(arg, (list, tuple)): - if any(isinstance(v, (InputInfo, DimensionInputInfo)) for v in arg): - input_names.add(name) - result = [] - for idx, value in enumerate(arg): - nested_name = f"{name}[{idx}]" - result.append(process_arg(nested_name, value)) - return type(arg)(result) + result_list = [] + has_input = False + for idx, value in enumerate(arg): + nested_name = f"{name}[{idx}]" + processed_child, child_has_input = process_arg_and_flag(nested_name, value) + result_list.append(processed_child) + has_input = has_input or child_has_input + return type(arg)(result_list), has_input # preserve sequence type - return arg + return arg, False + + def process_arg(name, arg): + processed, has_input = process_arg_and_flag(name, arg) + if has_input: + input_names.add(name) + return processed compiled_arg_names = [] @@ -283,25 +294,7 @@ def process_arg(name, arg): [f"Return value {index} was not a tensor: {repr(trace_out)}"], ) - # Order of trace inputs also needs to match that of the compiled_arg_names - # For containers, we need to collect all individual trace tensors - def collect_trace_tensors(name): - """Collect trace tensors for a name, flattening containers.""" - if name in trace_input_map: - # Regular InputInfo or DimensionInputInfo - return [trace_input_map[name].trace_tensor] - else: - # Collect all nested trace tensors inside the container - nested_tensors = [] - for nested_name in sorted(trace_input_map.keys()): - if nested_name.startswith(f"{name}.") or nested_name.startswith(f"{name}["): - nested_tensors.append(trace_input_map[nested_name].trace_tensor) - return nested_tensors - - # Flatten all trace tensors from containers and individual inputs - trace_inputs = [] - for name in compiled_arg_names: - trace_inputs.extend(collect_trace_tensors(name)) + # We collected flattened trace inputs during traversal trace = Trace( [tensor.trace_tensor for tensor in trace_outputs], trace_inputs, diff --git a/tripy/nvtripy/backend/api/executable.py b/tripy/nvtripy/backend/api/executable.py index e3539d64a..331de3619 100644 --- a/tripy/nvtripy/backend/api/executable.py +++ b/tripy/nvtripy/backend/api/executable.py @@ -195,40 +195,44 @@ def add(a, b): ], ) - # Recursively extract inputs from containers to get individual tensors for validation and execution + # 
Recursively build a name->tensor map
         def extract_inputs(tensors, input_info_names):
-            def extract_recursive(tensor, name_prefix):
-                if isinstance(tensor, dict):
-                    result = []
-                    for key in sorted(tensor.keys()):
+            name_to_tensor = {}
+
+            def extract_recursive(value, name_prefix):
+                if name_prefix in input_info_names:
+                    name_to_tensor[name_prefix] = value
+                    return
+                if isinstance(value, dict):
+                    for key, item in value.items():
                         nested_name = f"{name_prefix}.{key}"
-                        if nested_name in input_info_names:
-                            result.append(tensor[key])
-                        else:
-                            result.extend(extract_recursive(tensor[key], nested_name))
-                    return result
-                elif isinstance(tensor, (list, tuple)):
-                    result = []
-                    for idx, value in enumerate(tensor):
+                        extract_recursive(item, nested_name)
+                elif isinstance(value, (list, tuple)):
+                    for idx, item in enumerate(value):
                         nested_name = f"{name_prefix}[{idx}]"
-                        if nested_name in input_info_names:
-                            result.append(value)
-                        else:
-                            result.extend(extract_recursive(value, nested_name))
-                    return result
-                else:  # Regular tensor
-                    if name_prefix in input_info_names:
-                        return [tensor]
-                    else:
-                        return []
-
-            flattened = []
+                        extract_recursive(item, nested_name)
+                else:
+                    # Not a compiled input; nothing to collect for this leaf.
+                    return
+
             for name_idx, tensor in enumerate(tensors):
                 arg_name = self._arg_names[name_idx]
-                flattened.extend(extract_recursive(tensor, arg_name))
-            return flattened
+                extract_recursive(tensor, arg_name)
+
+            return name_to_tensor
 
-        flattened_tensors = extract_inputs(input_tensors, set(self.input_infos.keys()))
+        input_info_names = list(self.input_infos.keys())
+        name_to_tensor = extract_inputs(input_tensors, set(input_info_names))
+        try:
+            flattened_tensors = [name_to_tensor[name] for name in input_info_names]
+        except KeyError as missing:
+            raise_error(
+                f"Missing runtime tensor for input `{missing.args[0]}`.",
+                [
+                    "Ensure your provided containers include tensors for all compiled inputs.",
+                    f"Expected inputs: {input_info_names}",
+                ],
+            )
         expected_devices = ["gpu" if isinstance(info, InputInfo) else "cpu" for info in self.input_infos.values()]
 
         # Validate flattened tensors against input_infos
diff --git a/tripy/tests/backend/api/test_compile.py b/tripy/tests/backend/api/test_compile.py
index 0c3edef63..7028defbe 100644
--- a/tripy/tests/backend/api/test_compile.py
+++ b/tripy/tests/backend/api/test_compile.py
@@ -242,44 +242,55 @@ def test_dimension_input(self):
         expected = (inp_cp + inp_cp).reshape((-1, reshape_dim))
         assert cp.array_equal(cp.from_dlpack(out), expected)
 
-    def test_compile_dict_input_info(self):
-        """Test compilation with dictionary of InputInfo objects."""
-
+    def test_compile_nested_dict_input_info(self):
         def func(data_dict):
-            return data_dict["a"] + data_dict["b"]
+            return data_dict["a"]["inner"] + data_dict["b"]["list"][0] + data_dict["b"]["list"][1]
 
         dict_input = {
-            "a": tp.InputInfo(shape=(2, 3), dtype=tp.float32),
-            "b": tp.InputInfo(shape=(2, 3), dtype=tp.float32),
+            "a": {
+                "inner": tp.InputInfo(shape=(2, 3), dtype=tp.float32),
+            },
+            "b": {
+                "list": [
+                    tp.InputInfo(shape=(2, 3), dtype=tp.float32),
+                    tp.InputInfo(shape=(2, 3), dtype=tp.float32),
+                ],
+            },
         }
         compiled_func = tp.compile(func, args=[dict_input])
 
-        test_dict = {"a": tp.ones((2, 3), dtype=tp.float32).eval(), "b": (tp.ones((2, 3), dtype=tp.float32) * 2).eval()}
+        test_dict = {
+            "a": {"inner": tp.ones((2, 3), dtype=tp.float32).eval()},
+            "b": {
+                "list": [
+                    (tp.ones((2, 3), dtype=tp.float32) * 2).eval(),
+                    (tp.ones((2, 3), dtype=tp.float32) * 3).eval(),
+                ]
+            },
+        }
         result = 
compiled_func(test_dict) - expected = test_dict["a"] + test_dict["b"] + expected = test_dict["a"]["inner"] + test_dict["b"]["list"][0] + test_dict["b"]["list"][1] assert cp.array_equal(cp.from_dlpack(result), cp.from_dlpack(expected)) - def test_compile_nested_list_input_info(self): - """Test compilation with nested list containers.""" - + def test_compile_nested_sequence_input_info(self): def func(data_list): return data_list[0] + data_list[1][0] + data_list[1][1] list_input = [ tp.InputInfo(shape=(2, 3), dtype=tp.float32), - [ # Nested list + [ tp.InputInfo(shape=(2, 3), dtype=tp.float32), - tp.ones((2, 3), dtype=tp.float32) * 2, # Constant in nested list + tp.ones((2, 3), dtype=tp.float32) * 2, ], ] compiled_func = tp.compile(func, args=[list_input]) test_list = [ tp.ones((2, 3), dtype=tp.float32).eval(), - [ # Nested list in test data + ( (tp.ones((2, 3), dtype=tp.float32) * 3).eval(), - tp.ones((2, 3), dtype=tp.float32) * 2, # Should match baked constant - ], + tp.ones((2, 3), dtype=tp.float32) * 2, + ), ] result = compiled_func(test_list) expected = test_list[0] + test_list[1][0] + test_list[1][1] @@ -288,18 +299,27 @@ def func(data_list): def test_compile_mixed_containers_and_constants(self): """Test compilation with comprehensive mix: regular InputInfo, dict container, list container, and standalone constant.""" - def func(regular_input, data_dict, data_list, constant_value): - return regular_input + data_dict["x"] + data_dict["y"] + data_list[0] + data_list[1] + constant_value + def func(regular_input, data_dict, data_list, const_in_dict, const): + return ( + regular_input + + data_dict["x"] + + data_dict["y"] + + data_list[0] + + data_list[1] + + const_in_dict["z"] + + const + ) regular_input = tp.InputInfo(shape=(2, 3), dtype=tp.float32) dict_input = { "x": tp.InputInfo(shape=(2, 3), dtype=tp.float32), - "y": tp.zeros((2, 3), dtype=tp.float32), # Constant in dict + "y": tp.zeros((2, 3), dtype=tp.float32), } list_input = [tp.InputInfo(shape=(2, 3), dtype=tp.float32), tp.ones((2, 3), dtype=tp.float32) * 3] - constant_value = tp.ones((2, 3), dtype=tp.float32) * 5 + const_in_dict = {"z": tp.ones((2, 3), dtype=tp.float32) * 5} + const = tp.ones((2, 3), dtype=tp.float32) * 6 - compiled_func = tp.compile(func, args=[regular_input, dict_input, list_input, constant_value]) + compiled_func = tp.compile(func, args=[regular_input, dict_input, list_input, const_in_dict, const]) # Only InputInfo arguments should be in function signature test_regular = tp.ones((2, 3), dtype=tp.float32).eval() @@ -307,5 +327,7 @@ def func(regular_input, data_dict, data_list, constant_value): test_list = [(tp.ones((2, 3), dtype=tp.float32) * 4).eval(), tp.ones((2, 3), dtype=tp.float32) * 3] result = compiled_func(test_regular, test_dict, test_list) - expected = test_regular + test_dict["x"] + test_dict["y"] + test_list[0] + test_list[1] + constant_value + expected = ( + test_regular + test_dict["x"] + test_dict["y"] + test_list[0] + test_list[1] + const_in_dict["z"] + const + ) assert cp.array_equal(cp.from_dlpack(result), cp.from_dlpack(expected)) From fbf1afb2fbe09346f24e71b2d0bd280143e67e1f Mon Sep 17 00:00:00 2001 From: Akhil Goel Date: Fri, 24 Oct 2025 18:10:53 -0700 Subject: [PATCH 3/7] Cache inputinfo structure at compile time, improve test coverage --- tripy/nvtripy/backend/api/compile.py | 13 +++++ tripy/nvtripy/backend/api/executable.py | 71 +++++++++++++------------ tripy/nvtripy/frontend/tensor.py | 1 + tripy/tests/backend/api/test_compile.py | 56 +++++++++++++++---- 4 files changed, 99 
insertions(+), 42 deletions(-) diff --git a/tripy/nvtripy/backend/api/compile.py b/tripy/nvtripy/backend/api/compile.py index d8d7db5bb..558e279d8 100644 --- a/tripy/nvtripy/backend/api/compile.py +++ b/tripy/nvtripy/backend/api/compile.py @@ -316,9 +316,22 @@ def process_arg(name, arg): assert isinstance(func_out, Tensor) or isinstance( func_out, Sequence ), "This function is only implemented for Tensors or sequences of Tensors" + + # Group leaf input names by top-level argument for efficient runtime extraction + leaf_names_by_arg = {} + leaf_names = list(input_infos.keys()) + for arg_name in compiled_arg_names: + matching = [ + leaf + for leaf in leaf_names + if leaf == arg_name or leaf.startswith(f"{arg_name}.") or leaf.startswith(f"{arg_name}[") + ] + leaf_names_by_arg[arg_name] = matching + return Executable( executable, compiled_arg_names, return_single_tensor_as_sequence=isinstance(func_out, Sequence), input_infos=input_infos, + leaf_names_by_arg=leaf_names_by_arg, ) diff --git a/tripy/nvtripy/backend/api/executable.py b/tripy/nvtripy/backend/api/executable.py index 331de3619..3048f55f4 100644 --- a/tripy/nvtripy/backend/api/executable.py +++ b/tripy/nvtripy/backend/api/executable.py @@ -46,6 +46,7 @@ def __init__( arg_names, return_single_tensor_as_sequence: bool, input_infos: Dict[str, Union[InputInfo, DimensionInputInfo]], + leaf_names_by_arg: Dict[str, Sequence[str]], ): self._executable = executable @@ -78,6 +79,8 @@ def __init__( Stores metadata, like shapes and data types, for each input to the executable. """ + self._leaf_names_by_arg = leaf_names_by_arg + def __str__(self) -> str: params = [ f"{name}: {str_from_type_annotation(param.annotation)}" @@ -195,34 +198,36 @@ def add(a, b): ], ) - # Recursively build a name->tensor map - def extract_inputs(tensors, input_info_names): - name_to_tensor = {} - - def extract_recursive(value, name_prefix): - if name_prefix in input_info_names: - name_to_tensor[name_prefix] = value - return - if isinstance(value, dict): - for key, item in value.items(): - nested_name = f"{name_prefix}.{key}" - extract_recursive(item, nested_name) - elif isinstance(value, (list, tuple)): - for idx, item in enumerate(value): - nested_name = f"{name_prefix}[{idx}]" - extract_recursive(item, nested_name) - else: - print(f"Leaf tensor: {name_prefix}: {value}") - return - - for name_idx, tensor in enumerate(tensors): - arg_name = self._arg_names[name_idx] - extract_recursive(tensor, arg_name) - - return name_to_tensor - + # Build a name->tensor map using precomputed leaf names to avoid unnecessary recursion input_info_names = list(self.input_infos.keys()) - name_to_tensor = extract_inputs(input_tensors, set(input_info_names)) + name_to_tensor: Dict[str, Tensor] = {} + + def extract_recursive(value, name_prefix, allowed_names): + if name_prefix in allowed_names: + name_to_tensor[name_prefix] = value + return + if isinstance(value, dict): + for key, item in value.items(): + nested_name = f"{name_prefix}.{key}" + extract_recursive(item, nested_name, allowed_names) + elif isinstance(value, (list, tuple)): + for idx, item in enumerate(value): + nested_name = f"{name_prefix}[{idx}]" + extract_recursive(item, nested_name, allowed_names) + else: + return + + for name_idx, tensor in enumerate(input_tensors): + arg_name = self._arg_names[name_idx] + # Fast path: direct leaf input + if arg_name in self.input_infos: + name_to_tensor[arg_name] = tensor + continue + # If this arg has no compiled leaves beneath it, skip any recursion + allowed = 
self._leaf_names_by_arg.get(arg_name) + if not allowed: + continue + extract_recursive(tensor, arg_name, set(allowed)) try: flattened_tensors = [name_to_tensor[name] for name in input_info_names] except KeyError as missing: @@ -267,7 +272,7 @@ def extract_recursive(value, name_prefix): expected_input_dtypes = [ info.dtype if isinstance(info, InputInfo) else int32 for info in self.input_infos.values() ] - for tensor, dtype, arg_name in zip(input_tensors, expected_input_dtypes, self._arg_names): + for tensor, dtype, arg_name in zip(flattened_tensors, expected_input_dtypes, self.input_infos.keys()): if tensor.dtype != dtype: raise_error( f"Unexpected tensor data type.", @@ -282,7 +287,9 @@ def extract_recursive(value, name_prefix): expected_input_shapes = [ info.shape_bounds if isinstance(info, InputInfo) else tuple() for info in self.input_infos.values() ] - for tensor, expected_bounds, arg_name in zip(input_tensors, expected_input_shapes, self._arg_names): + for tensor, expected_bounds, arg_name in zip( + flattened_tensors, expected_input_shapes, self.input_infos.keys() + ): shape = tensor.shape if len(shape) != len(expected_bounds.min): @@ -290,8 +297,7 @@ def extract_recursive(value, name_prefix): f"Unexpected tensor rank.", [ f"For tensor: `{arg_name}`, expected a rank of: {len(expected_bounds.min)} but got: {len(shape)}.\n" - f"Note: The provided argument was: ", - tensor, + f"Note: The provided argument was a tensor with shape: {shape}", ], ) @@ -302,8 +308,7 @@ def extract_recursive(value, name_prefix): [ f"For tensor: `{arg_name}`, expected a shape within the bounds: min={expected_bounds.min}, max={expected_bounds.max}, but got: {shape}.\n" f"Dimension {i} has a shape of {shape[i]}, which is not within the expected bounds of [{expected_bounds.min[i]}, {expected_bounds.max[i]}].\n" - f"Note: The provided argument was: ", - tensor, + f"Note: The provided argument was a tensor with shape: {shape}", ], ) raise_error(str(err)) diff --git a/tripy/nvtripy/frontend/tensor.py b/tripy/nvtripy/frontend/tensor.py index 91146a3d6..0268acf55 100644 --- a/tripy/nvtripy/frontend/tensor.py +++ b/tripy/nvtripy/frontend/tensor.py @@ -237,6 +237,7 @@ def eval(self) -> "nvtripy.Tensor": name: InputInfo(list(map(int, inp.trace_tensor.shape)), inp.dtype) for name, inp in zip(arg_names, inputs) }, + leaf_names_by_arg={name: [name] for name in arg_names}, # every argument is a direct input ) data = executable(*inputs).trace_tensor.producer.data diff --git a/tripy/tests/backend/api/test_compile.py b/tripy/tests/backend/api/test_compile.py index 7028defbe..6cc935628 100644 --- a/tripy/tests/backend/api/test_compile.py +++ b/tripy/tests/backend/api/test_compile.py @@ -270,7 +270,7 @@ def func(data_dict): } result = compiled_func(test_dict) expected = test_dict["a"]["inner"] + test_dict["b"]["list"][0] + test_dict["b"]["list"][1] - assert cp.array_equal(cp.from_dlpack(result), cp.from_dlpack(expected)) + assert tp.equal(result, expected) def test_compile_nested_sequence_input_info(self): def func(data_list): @@ -294,11 +294,9 @@ def func(data_list): ] result = compiled_func(test_list) expected = test_list[0] + test_list[1][0] + test_list[1][1] - assert cp.array_equal(cp.from_dlpack(result), cp.from_dlpack(expected)) + assert tp.equal(result, expected) def test_compile_mixed_containers_and_constants(self): - """Test compilation with comprehensive mix: regular InputInfo, dict container, list container, and standalone constant.""" - def func(regular_input, data_dict, data_list, const_in_dict, const): return ( 
regular_input
                 + data_dict["x"]
                 + data_dict["y"]
                 + data_list[0]
                 + data_list[1]
                 + const_in_dict["z"]
                 + const
             )
 
         regular_input = tp.InputInfo(shape=(2, 3), dtype=tp.float32)
         dict_input = {
             "x": tp.InputInfo(shape=(2, 3), dtype=tp.float32),
             "y": tp.zeros((2, 3), dtype=tp.float32),
         }
-        list_input = [tp.InputInfo(shape=(2, 3), dtype=tp.float32), tp.ones((2, 3), dtype=tp.float32) * 3]
+        list_input = [tp.ones((2, 3), dtype=tp.float32) * 3, tp.InputInfo(shape=(2, 3), dtype=tp.float32)]
         const_in_dict = {"z": tp.ones((2, 3), dtype=tp.float32) * 5}
         const = tp.ones((2, 3), dtype=tp.float32) * 6
@@ -323,11 +321,51 @@ def func(regular_input, data_dict, data_list, const_in_dict, const):
 
         # Only InputInfo arguments should be in function signature
         test_regular = tp.ones((2, 3), dtype=tp.float32).eval()
-        test_dict = {"x": (tp.ones((2, 3), dtype=tp.float32) * 2).eval(), "y": tp.zeros((2, 3), dtype=tp.float32)}
-        test_list = [(tp.ones((2, 3), dtype=tp.float32) * 4).eval(), tp.ones((2, 3), dtype=tp.float32) * 3]
+        test_dict = {"x": (tp.ones((2, 3), dtype=tp.float32) * 2).eval()}
+        test_list = [None, (tp.ones((2, 3), dtype=tp.float32) * 4).eval()]
 
         result = compiled_func(test_regular, test_dict, test_list)
         expected = (
-            test_regular + test_dict["x"] + test_dict["y"] + test_list[0] + test_list[1] + const_in_dict["z"] + const
+            test_regular + test_dict["x"] + dict_input["y"] + test_list[1] + list_input[0] + const_in_dict["z"] + const
         )
-        assert cp.array_equal(cp.from_dlpack(result), cp.from_dlpack(expected))
+        assert tp.equal(result, expected)
+
+    def test_compile_missing_nested_input_fails(self):
+        def func(data_dict):
+            return data_dict["a"]["inner"] + data_dict["b"]["list"][1]
+
+        dict_input = {
+            "a": {"inner": tp.InputInfo(shape=(2, 3), dtype=tp.float32)},
+            "b": {"list": [tp.zeros((2, 3), dtype=tp.float32), tp.InputInfo(shape=(2, 3), dtype=tp.float32)]},
+        }
+
+        compiled_func = tp.compile(func, args=[dict_input])
+
+        # Missing b.list[1]
+        bad_dict = {
+            "a": {"inner": tp.ones((2, 3), dtype=tp.float32).eval()},
+            "b": {"list": [tp.ones((2, 3), dtype=tp.float32).eval()]},
+        }
+        with helper.raises(tp.TripyException, match=r"Missing runtime tensor for input `data_dict\.b\.list\[1\]`."):
+            compiled_func(bad_dict)
+
+        # Wrong shape for b.list[1] should trigger a shape/device validation error
+        wrong_shape = {
+            "a": {"inner": tp.ones((2, 3), dtype=tp.float32).eval()},
+            "b": {"list": [tp.zeros((2, 3), dtype=tp.float32), tp.ones((2, 2), dtype=tp.float32).eval()]},
+        }
+        with helper.raises(tp.TripyException, match="Unexpected tensor shape."):
+            compiled_func(wrong_shape)
+
+    def test_compile_container_mismatch_fails(self):
+        def func(data_list):
+            return data_list[0] + data_list[1][0]
+
+        list_input = [tp.InputInfo(shape=(2, 3), dtype=tp.float32), [tp.InputInfo(shape=(2, 3), dtype=tp.float32)]]
+
+        compiled_func = tp.compile(func, args=[list_input])
+
+        bad_list = [tp.ones((2, 3), dtype=tp.float32).eval(), {"not": tp.ones((2, 3), dtype=tp.float32).eval()}]
+
+        with helper.raises(tp.TripyException, match=r"Missing runtime tensor for input `data_list\[1\]\[0\]`."):
+            compiled_func(bad_list)

From 68407090f709698fc52a689a498c21b97bfa4e7c Mon Sep 17 00:00:00 2001
From: Akhil Goel
Date: Thu, 30 Oct 2025 17:22:39 -0700
Subject: [PATCH 4/7] Update serialization and deserialization

---
 tripy/nvtripy/backend/api/executable.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tripy/nvtripy/backend/api/executable.py b/tripy/nvtripy/backend/api/executable.py
index 3048f55f4..5db4c6548 100644
--- a/tripy/nvtripy/backend/api/executable.py
+++ b/tripy/nvtripy/backend/api/executable.py
@@ -396,6 +396,7 @@ def 
encode_executable(executable): "executable": base64.b64encode(executable._executable.serialize()).decode(), "_return_single_tensor_as_sequence": executable._return_single_tensor_as_sequence, "input_infos": executable.input_infos, + "leaf_names_by_arg": executable._leaf_names_by_arg, } @@ -407,4 +408,5 @@ def decode_executable(executable_dict): executable_dict["arg_names"], return_single_tensor_as_sequence=executable_dict["_return_single_tensor_as_sequence"], input_infos=executable_dict["input_infos"], + leaf_names_by_arg=executable_dict.get("leaf_names_by_arg"), ) From de3ebe08570ae926811dac509460c7bc27a3f6b3 Mon Sep 17 00:00:00 2001 From: Akhil Goel Date: Thu, 30 Oct 2025 17:53:09 -0700 Subject: [PATCH 5/7] fix missing tensor in stack info for shape mismatch --- tripy/nvtripy/backend/api/executable.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tripy/nvtripy/backend/api/executable.py b/tripy/nvtripy/backend/api/executable.py index 5db4c6548..1ca480f74 100644 --- a/tripy/nvtripy/backend/api/executable.py +++ b/tripy/nvtripy/backend/api/executable.py @@ -297,7 +297,8 @@ def extract_recursive(value, name_prefix, allowed_names): f"Unexpected tensor rank.", [ f"For tensor: `{arg_name}`, expected a rank of: {len(expected_bounds.min)} but got: {len(shape)}.\n" - f"Note: The provided argument was a tensor with shape: {shape}", + f"Note: The provided argument was: ", + tensor, ], ) @@ -308,7 +309,8 @@ def extract_recursive(value, name_prefix, allowed_names): [ f"For tensor: `{arg_name}`, expected a shape within the bounds: min={expected_bounds.min}, max={expected_bounds.max}, but got: {shape}.\n" f"Dimension {i} has a shape of {shape[i]}, which is not within the expected bounds of [{expected_bounds.min[i]}, {expected_bounds.max[i]}].\n" - f"Note: The provided argument was a tensor with shape: {shape}", + f"Note: The provided argument was: ", + tensor, ], ) raise_error(str(err)) From db60c258b7610f09a08c6e5650e0e710cf6f437c Mon Sep 17 00:00:00 2001 From: Akhil Goel Date: Wed, 12 Nov 2025 21:11:52 -0800 Subject: [PATCH 6/7] Enable direct fast access of inputs inside containers --- tripy/nvtripy/backend/api/compile.py | 28 ++++----- tripy/nvtripy/backend/api/executable.py | 76 +++++++++++-------------- tripy/nvtripy/frontend/tensor.py | 2 +- 3 files changed, 45 insertions(+), 61 deletions(-) diff --git a/tripy/nvtripy/backend/api/compile.py b/tripy/nvtripy/backend/api/compile.py index 558e279d8..624ace597 100644 --- a/tripy/nvtripy/backend/api/compile.py +++ b/tripy/nvtripy/backend/api/compile.py @@ -157,6 +157,7 @@ def add(a, b): input_names = set() input_infos = {} trace_inputs = [] # flattened list of trace input tensors in argument order + access_plan_by_name: Dict[str, tuple] = {} # Set up names for the weights in the module to make the trace easier to read. 
if isinstance(func, Module): @@ -207,10 +208,12 @@ def process_arg_input_info(name, arg): return arg - def process_arg_and_flag(name, arg): + def process_arg_and_flag(top_arg_name, name, arg, steps): # Handle individual InputInfo or DimensionInputInfo objects if isinstance(arg, (InputInfo, DimensionInputInfo)): - return process_arg_input_info(name, arg), True + tensor_or_dim = process_arg_input_info(name, arg) + access_plan_by_name[name] = (top_arg_name, tuple(steps)) + return tensor_or_dim, True # Handle containers of InputInfo objects if isinstance(arg, dict): @@ -218,7 +221,9 @@ def process_arg_and_flag(name, arg): has_input = False for key, value in arg.items(): nested_name = f"{name}.{key}" - processed_child, child_has_input = process_arg_and_flag(nested_name, value) + processed_child, child_has_input = process_arg_and_flag( + top_arg_name, nested_name, value, (*steps, str(key)) + ) result[key] = processed_child has_input = has_input or child_has_input return result, has_input @@ -227,7 +232,7 @@ def process_arg_and_flag(name, arg): has_input = False for idx, value in enumerate(arg): nested_name = f"{name}[{idx}]" - processed_child, child_has_input = process_arg_and_flag(nested_name, value) + processed_child, child_has_input = process_arg_and_flag(top_arg_name, nested_name, value, (*steps, idx)) result_list.append(processed_child) has_input = has_input or child_has_input return type(arg)(result_list), has_input # preserve sequence type @@ -235,7 +240,7 @@ def process_arg_and_flag(name, arg): return arg, False def process_arg(name, arg): - processed, has_input = process_arg_and_flag(name, arg) + processed, has_input = process_arg_and_flag(name, name, arg, ()) if has_input: input_names.add(name) return processed @@ -317,21 +322,10 @@ def process_arg(name, arg): func_out, Sequence ), "This function is only implemented for Tensors or sequences of Tensors" - # Group leaf input names by top-level argument for efficient runtime extraction - leaf_names_by_arg = {} - leaf_names = list(input_infos.keys()) - for arg_name in compiled_arg_names: - matching = [ - leaf - for leaf in leaf_names - if leaf == arg_name or leaf.startswith(f"{arg_name}.") or leaf.startswith(f"{arg_name}[") - ] - leaf_names_by_arg[arg_name] = matching - return Executable( executable, compiled_arg_names, return_single_tensor_as_sequence=isinstance(func_out, Sequence), input_infos=input_infos, - leaf_names_by_arg=leaf_names_by_arg, + access_plan_by_name=access_plan_by_name, ) diff --git a/tripy/nvtripy/backend/api/executable.py b/tripy/nvtripy/backend/api/executable.py index 1ca480f74..599983b69 100644 --- a/tripy/nvtripy/backend/api/executable.py +++ b/tripy/nvtripy/backend/api/executable.py @@ -46,7 +46,7 @@ def __init__( arg_names, return_single_tensor_as_sequence: bool, input_infos: Dict[str, Union[InputInfo, DimensionInputInfo]], - leaf_names_by_arg: Dict[str, Sequence[str]], + access_plan_by_name: Dict[str, Tuple[str, Tuple[Union[str, int], ...]]], ): self._executable = executable @@ -79,7 +79,23 @@ def __init__( Stores metadata, like shapes and data types, for each input to the executable. 
""" - self._leaf_names_by_arg = leaf_names_by_arg + # Build accessor map from compile-time access plans + self._accessor_map: Dict[str, callable] = {} + name_to_index = {name: idx for idx, name in enumerate(self._arg_names)} + + def make_accessor(arg_index: int, steps: Tuple[Union[str, int], ...]): + def accessor(inputs, idx=arg_index, stps=steps): + v = inputs[idx] + for s in stps: + v = v[s] + return v + + return accessor + + self._access_plan_by_name = access_plan_by_name + for leaf_name, (arg_name, steps) in self._access_plan_by_name.items(): + idx = name_to_index[arg_name] + self._accessor_map[leaf_name] = make_accessor(idx, steps) def __str__(self) -> str: params = [ @@ -198,46 +214,20 @@ def add(a, b): ], ) - # Build a name->tensor map using precomputed leaf names to avoid unnecessary recursion + # Fetch flattened tensors directly via accessors input_info_names = list(self.input_infos.keys()) - name_to_tensor: Dict[str, Tensor] = {} - - def extract_recursive(value, name_prefix, allowed_names): - if name_prefix in allowed_names: - name_to_tensor[name_prefix] = value - return - if isinstance(value, dict): - for key, item in value.items(): - nested_name = f"{name_prefix}.{key}" - extract_recursive(item, nested_name, allowed_names) - elif isinstance(value, (list, tuple)): - for idx, item in enumerate(value): - nested_name = f"{name_prefix}[{idx}]" - extract_recursive(item, nested_name, allowed_names) - else: - return - - for name_idx, tensor in enumerate(input_tensors): - arg_name = self._arg_names[name_idx] - # Fast path: direct leaf input - if arg_name in self.input_infos: - name_to_tensor[arg_name] = tensor - continue - # If this arg has no compiled leaves beneath it, skip any recursion - allowed = self._leaf_names_by_arg.get(arg_name) - if not allowed: - continue - extract_recursive(tensor, arg_name, set(allowed)) - try: - flattened_tensors = [name_to_tensor[name] for name in input_info_names] - except KeyError as missing: - raise_error( - f"Missing runtime tensor for input `{missing.args[0]}`.", - [ - "Ensure your provided containers include tensors for all compiled inputs.", - f"Expected inputs: {input_info_names}", - ], - ) + flattened_tensors = [] + for name in input_info_names: + try: + flattened_tensors.append(self._accessor_map[name](input_tensors)) + except Exception: + raise_error( + f"Missing runtime tensor for input `{name}`.", + [ + "Ensure your provided collections include tensors for all compiled inputs.", + f"Expected inputs: {input_info_names}", + ], + ) expected_devices = ["gpu" if isinstance(info, InputInfo) else "cpu" for info in self.input_infos.values()] # Validate flattened tensors against input_infos @@ -398,7 +388,7 @@ def encode_executable(executable): "executable": base64.b64encode(executable._executable.serialize()).decode(), "_return_single_tensor_as_sequence": executable._return_single_tensor_as_sequence, "input_infos": executable.input_infos, - "leaf_names_by_arg": executable._leaf_names_by_arg, + "access_plan_by_name": executable._access_plan_by_name, } @@ -410,5 +400,5 @@ def decode_executable(executable_dict): executable_dict["arg_names"], return_single_tensor_as_sequence=executable_dict["_return_single_tensor_as_sequence"], input_infos=executable_dict["input_infos"], - leaf_names_by_arg=executable_dict.get("leaf_names_by_arg"), + access_plan_by_name=executable_dict["access_plan_by_name"], ) diff --git a/tripy/nvtripy/frontend/tensor.py b/tripy/nvtripy/frontend/tensor.py index 0268acf55..50a899f2f 100644 --- a/tripy/nvtripy/frontend/tensor.py +++ 
b/tripy/nvtripy/frontend/tensor.py @@ -237,7 +237,7 @@ def eval(self) -> "nvtripy.Tensor": name: InputInfo(list(map(int, inp.trace_tensor.shape)), inp.dtype) for name, inp in zip(arg_names, inputs) }, - leaf_names_by_arg={name: [name] for name in arg_names}, # every argument is a direct input + access_plan_by_name={name: (name, tuple()) for name in arg_names}, ) data = executable(*inputs).trace_tensor.producer.data From 1cbd89ba00021687219b8ef75eab46658f9530e9 Mon Sep 17 00:00:00 2001 From: Akhil Goel Date: Thu, 13 Nov 2025 17:15:03 -0800 Subject: [PATCH 7/7] review fixes --- tripy/nvtripy/backend/api/compile.py | 2 +- tripy/nvtripy/backend/api/executable.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tripy/nvtripy/backend/api/compile.py b/tripy/nvtripy/backend/api/compile.py index 624ace597..d9296f53a 100644 --- a/tripy/nvtripy/backend/api/compile.py +++ b/tripy/nvtripy/backend/api/compile.py @@ -240,7 +240,7 @@ def process_arg_and_flag(top_arg_name, name, arg, steps): return arg, False def process_arg(name, arg): - processed, has_input = process_arg_and_flag(name, name, arg, ()) + processed, has_input = process_arg_and_flag(name, name, arg, tuple()) if has_input: input_names.add(name) return processed diff --git a/tripy/nvtripy/backend/api/executable.py b/tripy/nvtripy/backend/api/executable.py index 599983b69..e28fd3301 100644 --- a/tripy/nvtripy/backend/api/executable.py +++ b/tripy/nvtripy/backend/api/executable.py @@ -84,9 +84,9 @@ def __init__( name_to_index = {name: idx for idx, name in enumerate(self._arg_names)} def make_accessor(arg_index: int, steps: Tuple[Union[str, int], ...]): - def accessor(inputs, idx=arg_index, stps=steps): - v = inputs[idx] - for s in stps: + def accessor(inputs): + v = inputs[arg_index] + for s in steps: v = v[s] return v @@ -220,12 +220,13 @@ def add(a, b): for name in input_info_names: try: flattened_tensors.append(self._accessor_map[name](input_tensors)) - except Exception: + except Exception as exc: raise_error( f"Missing runtime tensor for input `{name}`.", [ "Ensure your provided collections include tensors for all compiled inputs.", f"Expected inputs: {input_info_names}", + f"Note: Error was:\n{exc}", ], ) expected_devices = ["gpu" if isinstance(info, InputInfo) else "cpu" for info in self.input_infos.values()]
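
Note on usage: the sketch below shows how the container support added in this series fits together end to end, following the tests above. It is a minimal sketch rather than part of the patches; it assumes a CUDA-capable GPU with nvtripy installed (imported as `tp`), and the shapes and multipliers are illustrative only.

    import nvtripy as tp

    def func(data):
        return data["x"] + data["lst"][0] + data["lst"][1]

    # Leaves may mix InputInfo (runtime inputs) with concrete tensors,
    # which are baked into the executable as constants.
    arg = {
        "x": tp.InputInfo(shape=(2, 3), dtype=tp.float32),
        "lst": [
            tp.InputInfo(shape=(2, 3), dtype=tp.float32),
            tp.ones((2, 3), dtype=tp.float32) * 2,  # baked-in constant
        ],
    }
    compiled = tp.compile(func, args=[arg])

    # At call time, evaluated tensors must sit at the same container
    # positions; baked-constant slots may hold a placeholder (the tests
    # above pass None).
    out = compiled(
        {
            "x": tp.ones((2, 3), dtype=tp.float32).eval(),
            "lst": [(tp.ones((2, 3), dtype=tp.float32) * 3).eval(), None],
        }
    )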
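The access-plan mechanism introduced in patch 6 boils down to recording, at compile time, the path (dict keys and sequence indices) from each top-level argument to each InputInfo leaf, then replaying those paths at call time instead of re-walking the containers. The standalone sketch below illustrates the idea in plain Python; the names (`build_plans`, `make_accessor`, the "INPUT" marker) are hypothetical and independent of the nvtripy sources.

    from typing import Any, Dict, Tuple, Union

    Step = Union[str, int]  # a dict key or a sequence index
    Plan = Tuple[str, Tuple[Step, ...]]  # (top-level arg name, steps to the leaf)

    def build_plans(top: str, name: str, value: Any, steps: Tuple[Step, ...], plans: Dict[str, Plan]) -> None:
        if value == "INPUT":  # stand-in for isinstance(value, InputInfo)
            plans[name] = (top, steps)
        elif isinstance(value, dict):
            for k, v in value.items():
                build_plans(top, f"{name}.{k}", v, (*steps, k), plans)
        elif isinstance(value, (list, tuple)):
            for i, v in enumerate(value):
                build_plans(top, f"{name}[{i}]", v, (*steps, i), plans)

    def make_accessor(arg_index: int, steps: Tuple[Step, ...]):
        def accessor(args):
            v = args[arg_index]
            for s in steps:  # replay the recorded path
                v = v[s]
            return v
        return accessor

    # "Compile time": record where the inputs live inside the argument.
    plans: Dict[str, Plan] = {}
    build_plans("data", "data", {"a": {"inner": "INPUT"}, "b": ["CONST", "INPUT"]}, (), plans)
    accessors = {leaf: make_accessor(0, steps) for leaf, (_, steps) in plans.items()}

    # "Run time": pull the flattened inputs straight out of the containers.
    runtime_args = [{"a": {"inner": 1.0}, "b": [None, 2.0]}]
    flat = [accessors[leaf](runtime_args) for leaf in plans]
    assert flat == [1.0, 2.0]

In the real executable, one accessor per compiled leaf input is built once in `__init__` from the deserialized plans, which is what makes the per-call flattening loop in `__call__` a straight sequence of indexing operations.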