@@ -46,7 +46,7 @@ def __init__(
4646 arg_names ,
4747 return_single_tensor_as_sequence : bool ,
4848 input_infos : Dict [str , Union [InputInfo , DimensionInputInfo ]],
49- leaf_names_by_arg : Dict [str , Sequence [str ]],
49+ access_plan_by_name : Dict [str , Tuple [str , Tuple [ Union [ str , int ], ...] ]],
5050 ):
5151 self ._executable = executable
5252
@@ -79,7 +79,23 @@ def __init__(
7979 Stores metadata, like shapes and data types, for each input to the executable.
8080 """
8181
82- self ._leaf_names_by_arg = leaf_names_by_arg
82+ # Build accessor map from compile-time access plans
83+ self ._accessor_map : Dict [str , callable ] = {}
84+ name_to_index = {name : idx for idx , name in enumerate (self ._arg_names )}
85+
86+ def make_accessor (arg_index : int , steps : Tuple [Union [str , int ], ...]):
87+ def accessor (inputs , idx = arg_index , stps = steps ):
88+ v = inputs [idx ]
89+ for s in stps :
90+ v = v [s ]
91+ return v
92+
93+ return accessor
94+
95+ self ._access_plan_by_name = access_plan_by_name
96+ for leaf_name , (arg_name , steps ) in self ._access_plan_by_name .items ():
97+ idx = name_to_index [arg_name ]
98+ self ._accessor_map [leaf_name ] = make_accessor (idx , steps )
8399
84100 def __str__ (self ) -> str :
85101 params = [
@@ -198,46 +214,20 @@ def add(a, b):
198214 ],
199215 )
200216
201- # Build a name->tensor map using precomputed leaf names to avoid unnecessary recursion
217+ # Fetch flattened tensors directly via accessors
202218 input_info_names = list (self .input_infos .keys ())
203- name_to_tensor : Dict [str , Tensor ] = {}
204-
205- def extract_recursive (value , name_prefix , allowed_names ):
206- if name_prefix in allowed_names :
207- name_to_tensor [name_prefix ] = value
208- return
209- if isinstance (value , dict ):
210- for key , item in value .items ():
211- nested_name = f"{ name_prefix } .{ key } "
212- extract_recursive (item , nested_name , allowed_names )
213- elif isinstance (value , (list , tuple )):
214- for idx , item in enumerate (value ):
215- nested_name = f"{ name_prefix } [{ idx } ]"
216- extract_recursive (item , nested_name , allowed_names )
217- else :
218- return
219-
220- for name_idx , tensor in enumerate (input_tensors ):
221- arg_name = self ._arg_names [name_idx ]
222- # Fast path: direct leaf input
223- if arg_name in self .input_infos :
224- name_to_tensor [arg_name ] = tensor
225- continue
226- # If this arg has no compiled leaves beneath it, skip any recursion
227- allowed = self ._leaf_names_by_arg .get (arg_name )
228- if not allowed :
229- continue
230- extract_recursive (tensor , arg_name , set (allowed ))
231- try :
232- flattened_tensors = [name_to_tensor [name ] for name in input_info_names ]
233- except KeyError as missing :
234- raise_error (
235- f"Missing runtime tensor for input `{ missing .args [0 ]} `." ,
236- [
237- "Ensure your provided containers include tensors for all compiled inputs." ,
238- f"Expected inputs: { input_info_names } " ,
239- ],
240- )
219+ flattened_tensors = []
220+ for name in input_info_names :
221+ try :
222+ flattened_tensors .append (self ._accessor_map [name ](input_tensors ))
223+ except Exception :
224+ raise_error (
225+ f"Missing runtime tensor for input `{ name } `." ,
226+ [
227+ "Ensure your provided collections include tensors for all compiled inputs." ,
228+ f"Expected inputs: { input_info_names } " ,
229+ ],
230+ )
241231 expected_devices = ["gpu" if isinstance (info , InputInfo ) else "cpu" for info in self .input_infos .values ()]
242232
243233 # Validate flattened tensors against input_infos
@@ -398,7 +388,7 @@ def encode_executable(executable):
398388 "executable" : base64 .b64encode (executable ._executable .serialize ()).decode (),
399389 "_return_single_tensor_as_sequence" : executable ._return_single_tensor_as_sequence ,
400390 "input_infos" : executable .input_infos ,
401- "leaf_names_by_arg " : executable ._leaf_names_by_arg ,
391+ "access_plan_by_name " : executable ._access_plan_by_name ,
402392 }
403393
404394
@@ -410,5 +400,5 @@ def decode_executable(executable_dict):
410400 executable_dict ["arg_names" ],
411401 return_single_tensor_as_sequence = executable_dict ["_return_single_tensor_as_sequence" ],
412402 input_infos = executable_dict ["input_infos" ],
413- leaf_names_by_arg = executable_dict . get ( "leaf_names_by_arg" ) ,
403+ access_plan_by_name = executable_dict [ "access_plan_by_name" ] ,
414404 )
0 commit comments