@@ -46,6 +46,7 @@ def __init__(
4646 arg_names ,
4747 return_single_tensor_as_sequence : bool ,
4848 input_infos : Dict [str , Union [InputInfo , DimensionInputInfo ]],
49+ leaf_names_by_arg : Dict [str , Sequence [str ]],
4950 ):
5051 self ._executable = executable
5152
@@ -78,6 +79,8 @@ def __init__(
7879 Stores metadata, like shapes and data types, for each input to the executable.
7980 """
8081
82+ self ._leaf_names_by_arg = leaf_names_by_arg
83+
8184 def __str__ (self ) -> str :
8285 params = [
8386 f"{ name } : { str_from_type_annotation (param .annotation )} "
@@ -195,34 +198,36 @@ def add(a, b):
195198 ],
196199 )
197200
198- # Recursively build a name->tensor map
199- def extract_inputs (tensors , input_info_names ):
200- name_to_tensor = {}
201-
202- def extract_recursive (value , name_prefix ):
203- if name_prefix in input_info_names :
204- name_to_tensor [name_prefix ] = value
205- return
206- if isinstance (value , dict ):
207- for key , item in value .items ():
208- nested_name = f"{ name_prefix } .{ key } "
209- extract_recursive (item , nested_name )
210- elif isinstance (value , (list , tuple )):
211- for idx , item in enumerate (value ):
212- nested_name = f"{ name_prefix } [{ idx } ]"
213- extract_recursive (item , nested_name )
214- else :
215- print (f"Leaf tensor: { name_prefix } : { value } " )
216- return
217-
218- for name_idx , tensor in enumerate (tensors ):
219- arg_name = self ._arg_names [name_idx ]
220- extract_recursive (tensor , arg_name )
221-
222- return name_to_tensor
223-
201+ # Build a name->tensor map using precomputed leaf names to avoid unnecessary recursion
224202 input_info_names = list (self .input_infos .keys ())
225- name_to_tensor = extract_inputs (input_tensors , set (input_info_names ))
203+ name_to_tensor : Dict [str , Tensor ] = {}
204+
205+ def extract_recursive (value , name_prefix , allowed_names ):
206+ if name_prefix in allowed_names :
207+ name_to_tensor [name_prefix ] = value
208+ return
209+ if isinstance (value , dict ):
210+ for key , item in value .items ():
211+ nested_name = f"{ name_prefix } .{ key } "
212+ extract_recursive (item , nested_name , allowed_names )
213+ elif isinstance (value , (list , tuple )):
214+ for idx , item in enumerate (value ):
215+ nested_name = f"{ name_prefix } [{ idx } ]"
216+ extract_recursive (item , nested_name , allowed_names )
217+ else :
218+ return
219+
220+ for name_idx , tensor in enumerate (input_tensors ):
221+ arg_name = self ._arg_names [name_idx ]
222+ # Fast path: direct leaf input
223+ if arg_name in self .input_infos :
224+ name_to_tensor [arg_name ] = tensor
225+ continue
226+ # If this arg has no compiled leaves beneath it, skip any recursion
227+ allowed = self ._leaf_names_by_arg .get (arg_name )
228+ if not allowed :
229+ continue
230+ extract_recursive (tensor , arg_name , set (allowed ))
226231 try :
227232 flattened_tensors = [name_to_tensor [name ] for name in input_info_names ]
228233 except KeyError as missing :
@@ -267,7 +272,7 @@ def extract_recursive(value, name_prefix):
267272 expected_input_dtypes = [
268273 info .dtype if isinstance (info , InputInfo ) else int32 for info in self .input_infos .values ()
269274 ]
270- for tensor , dtype , arg_name in zip (input_tensors , expected_input_dtypes , self ._arg_names ):
275+ for tensor , dtype , arg_name in zip (flattened_tensors , expected_input_dtypes , self .input_infos . keys () ):
271276 if tensor .dtype != dtype :
272277 raise_error (
273278 f"Unexpected tensor data type." ,
@@ -282,16 +287,17 @@ def extract_recursive(value, name_prefix):
282287 expected_input_shapes = [
283288 info .shape_bounds if isinstance (info , InputInfo ) else tuple () for info in self .input_infos .values ()
284289 ]
285- for tensor , expected_bounds , arg_name in zip (input_tensors , expected_input_shapes , self ._arg_names ):
290+ for tensor , expected_bounds , arg_name in zip (
291+ flattened_tensors , expected_input_shapes , self .input_infos .keys ()
292+ ):
286293 shape = tensor .shape
287294
288295 if len (shape ) != len (expected_bounds .min ):
289296 raise_error (
290297 f"Unexpected tensor rank." ,
291298 [
292299 f"For tensor: `{ arg_name } `, expected a rank of: { len (expected_bounds .min )} but got: { len (shape )} .\n "
293- f"Note: The provided argument was: " ,
294- tensor ,
300+ f"Note: The provided argument was a tensor with shape: { shape } " ,
295301 ],
296302 )
297303
@@ -302,8 +308,7 @@ def extract_recursive(value, name_prefix):
302308 [
303309 f"For tensor: `{ arg_name } `, expected a shape within the bounds: min={ expected_bounds .min } , max={ expected_bounds .max } , but got: { shape } .\n "
304310 f"Dimension { i } has a shape of { shape [i ]} , which is not within the expected bounds of [{ expected_bounds .min [i ]} , { expected_bounds .max [i ]} ].\n "
305- f"Note: The provided argument was: " ,
306- tensor ,
311+ f"Note: The provided argument was a tensor with shape: { shape } " ,
307312 ],
308313 )
309314 raise_error (str (err ))
0 commit comments