Skip to content

Commit db60c25

Browse files
committed
Enable direct fast access of inputs inside containers
1 parent de3ebe0 commit db60c25

File tree

3 files changed: 45 additions and 61 deletions.

tripy/nvtripy/backend/api/compile.py

Lines changed: 11 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -157,6 +157,7 @@ def add(a, b):
157157
input_names = set()
158158
input_infos = {}
159159
trace_inputs = [] # flattened list of trace input tensors in argument order
160+
access_plan_by_name: Dict[str, tuple] = {}
160161

161162
# Set up names for the weights in the module to make the trace easier to read.
162163
if isinstance(func, Module):
@@ -207,18 +208,22 @@ def process_arg_input_info(name, arg):
207208

208209
return arg
209210

210-
def process_arg_and_flag(name, arg):
211+
def process_arg_and_flag(top_arg_name, name, arg, steps):
211212
# Handle individual InputInfo or DimensionInputInfo objects
212213
if isinstance(arg, (InputInfo, DimensionInputInfo)):
213-
return process_arg_input_info(name, arg), True
214+
tensor_or_dim = process_arg_input_info(name, arg)
215+
access_plan_by_name[name] = (top_arg_name, tuple(steps))
216+
return tensor_or_dim, True
214217

215218
# Handle containers of InputInfo objects
216219
if isinstance(arg, dict):
217220
result = {}
218221
has_input = False
219222
for key, value in arg.items():
220223
nested_name = f"{name}.{key}"
221-
processed_child, child_has_input = process_arg_and_flag(nested_name, value)
224+
processed_child, child_has_input = process_arg_and_flag(
225+
top_arg_name, nested_name, value, (*steps, str(key))
226+
)
222227
result[key] = processed_child
223228
has_input = has_input or child_has_input
224229
return result, has_input
@@ -227,15 +232,15 @@ def process_arg_and_flag(name, arg):
227232
has_input = False
228233
for idx, value in enumerate(arg):
229234
nested_name = f"{name}[{idx}]"
230-
processed_child, child_has_input = process_arg_and_flag(nested_name, value)
235+
processed_child, child_has_input = process_arg_and_flag(top_arg_name, nested_name, value, (*steps, idx))
231236
result_list.append(processed_child)
232237
has_input = has_input or child_has_input
233238
return type(arg)(result_list), has_input # preserve sequence type
234239

235240
return arg, False
236241

237242
def process_arg(name, arg):
238-
processed, has_input = process_arg_and_flag(name, arg)
243+
processed, has_input = process_arg_and_flag(name, name, arg, ())
239244
if has_input:
240245
input_names.add(name)
241246
return processed
@@ -317,21 +322,10 @@ def process_arg(name, arg):
317322
func_out, Sequence
318323
), "This function is only implemented for Tensors or sequences of Tensors"
319324

320-
# Group leaf input names by top-level argument for efficient runtime extraction
321-
leaf_names_by_arg = {}
322-
leaf_names = list(input_infos.keys())
323-
for arg_name in compiled_arg_names:
324-
matching = [
325-
leaf
326-
for leaf in leaf_names
327-
if leaf == arg_name or leaf.startswith(f"{arg_name}.") or leaf.startswith(f"{arg_name}[")
328-
]
329-
leaf_names_by_arg[arg_name] = matching
330-
331325
return Executable(
332326
executable,
333327
compiled_arg_names,
334328
return_single_tensor_as_sequence=isinstance(func_out, Sequence),
335329
input_infos=input_infos,
336-
leaf_names_by_arg=leaf_names_by_arg,
330+
access_plan_by_name=access_plan_by_name,
337331
)

tripy/nvtripy/backend/api/executable.py

Lines changed: 33 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,7 @@ def __init__(
4646
arg_names,
4747
return_single_tensor_as_sequence: bool,
4848
input_infos: Dict[str, Union[InputInfo, DimensionInputInfo]],
49-
leaf_names_by_arg: Dict[str, Sequence[str]],
49+
access_plan_by_name: Dict[str, Tuple[str, Tuple[Union[str, int], ...]]],
5050
):
5151
self._executable = executable
5252

@@ -79,7 +79,23 @@ def __init__(
7979
Stores metadata, like shapes and data types, for each input to the executable.
8080
"""
8181

82-
self._leaf_names_by_arg = leaf_names_by_arg
82+
# Build accessor map from compile-time access plans
83+
self._accessor_map: Dict[str, callable] = {}
84+
name_to_index = {name: idx for idx, name in enumerate(self._arg_names)}
85+
86+
def make_accessor(arg_index: int, steps: Tuple[Union[str, int], ...]):
87+
def accessor(inputs, idx=arg_index, stps=steps):
88+
v = inputs[idx]
89+
for s in stps:
90+
v = v[s]
91+
return v
92+
93+
return accessor
94+
95+
self._access_plan_by_name = access_plan_by_name
96+
for leaf_name, (arg_name, steps) in self._access_plan_by_name.items():
97+
idx = name_to_index[arg_name]
98+
self._accessor_map[leaf_name] = make_accessor(idx, steps)
8399

84100
def __str__(self) -> str:
85101
params = [
@@ -198,46 +214,20 @@ def add(a, b):
198214
],
199215
)
200216

201-
# Build a name->tensor map using precomputed leaf names to avoid unnecessary recursion
217+
# Fetch flattened tensors directly via accessors
202218
input_info_names = list(self.input_infos.keys())
203-
name_to_tensor: Dict[str, Tensor] = {}
204-
205-
def extract_recursive(value, name_prefix, allowed_names):
206-
if name_prefix in allowed_names:
207-
name_to_tensor[name_prefix] = value
208-
return
209-
if isinstance(value, dict):
210-
for key, item in value.items():
211-
nested_name = f"{name_prefix}.{key}"
212-
extract_recursive(item, nested_name, allowed_names)
213-
elif isinstance(value, (list, tuple)):
214-
for idx, item in enumerate(value):
215-
nested_name = f"{name_prefix}[{idx}]"
216-
extract_recursive(item, nested_name, allowed_names)
217-
else:
218-
return
219-
220-
for name_idx, tensor in enumerate(input_tensors):
221-
arg_name = self._arg_names[name_idx]
222-
# Fast path: direct leaf input
223-
if arg_name in self.input_infos:
224-
name_to_tensor[arg_name] = tensor
225-
continue
226-
# If this arg has no compiled leaves beneath it, skip any recursion
227-
allowed = self._leaf_names_by_arg.get(arg_name)
228-
if not allowed:
229-
continue
230-
extract_recursive(tensor, arg_name, set(allowed))
231-
try:
232-
flattened_tensors = [name_to_tensor[name] for name in input_info_names]
233-
except KeyError as missing:
234-
raise_error(
235-
f"Missing runtime tensor for input `{missing.args[0]}`.",
236-
[
237-
"Ensure your provided containers include tensors for all compiled inputs.",
238-
f"Expected inputs: {input_info_names}",
239-
],
240-
)
219+
flattened_tensors = []
220+
for name in input_info_names:
221+
try:
222+
flattened_tensors.append(self._accessor_map[name](input_tensors))
223+
except Exception:
224+
raise_error(
225+
f"Missing runtime tensor for input `{name}`.",
226+
[
227+
"Ensure your provided collections include tensors for all compiled inputs.",
228+
f"Expected inputs: {input_info_names}",
229+
],
230+
)
241231
expected_devices = ["gpu" if isinstance(info, InputInfo) else "cpu" for info in self.input_infos.values()]
242232

243233
# Validate flattened tensors against input_infos
@@ -398,7 +388,7 @@ def encode_executable(executable):
398388
"executable": base64.b64encode(executable._executable.serialize()).decode(),
399389
"_return_single_tensor_as_sequence": executable._return_single_tensor_as_sequence,
400390
"input_infos": executable.input_infos,
401-
"leaf_names_by_arg": executable._leaf_names_by_arg,
391+
"access_plan_by_name": executable._access_plan_by_name,
402392
}
403393

404394

@@ -410,5 +400,5 @@ def decode_executable(executable_dict):
410400
executable_dict["arg_names"],
411401
return_single_tensor_as_sequence=executable_dict["_return_single_tensor_as_sequence"],
412402
input_infos=executable_dict["input_infos"],
413-
leaf_names_by_arg=executable_dict.get("leaf_names_by_arg"),
403+
access_plan_by_name=executable_dict["access_plan_by_name"],
414404
)

tripy/nvtripy/frontend/tensor.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -237,7 +237,7 @@ def eval(self) -> "nvtripy.Tensor":
237237
name: InputInfo(list(map(int, inp.trace_tensor.shape)), inp.dtype)
238238
for name, inp in zip(arg_names, inputs)
239239
},
240-
leaf_names_by_arg={name: [name] for name in arg_names}, # every argument is a direct input
240+
access_plan_by_name={name: (name, tuple()) for name in arg_names},
241241
)
242242
data = executable(*inputs).trace_tensor.producer.data
243243

0 commit comments

Comments
 (0)