 
 import mlir_tensorrt.runtime.api as runtime
 from nvtripy import config, export
-from nvtripy.backend.api.input_info import InputInfo
+from nvtripy.backend.api.input_info import InputInfo, DimensionInputInfo
 from nvtripy.backend.api.stream import default_stream
 from nvtripy.backend.mlir.utils import MLIRRuntimeClient
 from nvtripy.common.exception import raise_error
@@ -41,7 +41,11 @@ class Executable:
     # `return_single_tensor_as_sequence` indicates whether the return type should be a sequence even if
     # there is only one output.
     def __init__(
-        self, executable, arg_names, return_single_tensor_as_sequence: bool, input_infos: Dict[str, InputInfo]
+        self,
+        executable,
+        arg_names,
+        return_single_tensor_as_sequence: bool,
+        input_infos: Dict[str, Union[InputInfo, DimensionInputInfo]],
     ):
         self._executable = executable
 
@@ -69,7 +73,7 @@ def __init__(
 
         self.__signature__ = inspect.Signature(params, return_annotation=return_annotation)
 
-        self.input_infos: Dict[str, InputInfo] = input_infos
+        self.input_infos: Dict[str, Union[InputInfo, DimensionInputInfo]] = input_infos
         """
         Stores metadata, like shapes and data types, for each input to the executable.
         """
@@ -191,15 +195,16 @@ def add(a, b):
                 ],
             )
 
-        for tensor in input_tensors:
+        expected_devices = ["gpu" if isinstance(info, InputInfo) else "cpu" for info in self.input_infos.values()]
+        for tensor, expected_device, arg_name in zip(input_tensors, expected_devices, self._arg_names):
             producer = tensor.trace_tensor.producer
-            if not isinstance(producer, Constant) or tensor.device.kind != "gpu":
+            if not isinstance(producer, Constant):
+                raise_error(f"Tensor `{arg_name}` is not evaluated.", ["Hint: Try calling `.eval()` on the tensor."])
+            if tensor.device.kind != expected_device:
                 raise_error(
-                    "Inputs to compiled executables must be evaluated tensors on the GPU.",
+                    "Unexpected tensor device.",
                     [
-                        "Got input" + (f" on device '{tensor.device}':" if tensor.device.kind != "gpu" else ":"),
-                        tensor,
-                        "Hint: Try calling `.eval()` on the tensor to ensure it is a GPU constant.",
+                        f"For tensor: `{arg_name}`, expected to be on device: {expected_device} but got: {tensor.device.kind}.\n",
                     ],
                 )
 
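The hunk above replaces the single GPU-only check with a per-input expectation: regular `InputInfo` inputs must be evaluated tensors on the GPU, while `DimensionInputInfo` inputs are expected on the CPU, and unevaluated tensors now get a separate error. Below is a minimal standalone sketch of that dispatch logic; `RuntimeInput`, `DimensionInput`, and `FakeTensor` are hypothetical stand-ins, not the real nvtripy classes.

# Minimal sketch of the per-input device check; all class names here are hypothetical stand-ins.
from dataclasses import dataclass
from typing import Dict, Union

@dataclass
class RuntimeInput:      # stands in for InputInfo: a regular tensor input, expected on the GPU
    dtype: str

@dataclass
class DimensionInput:    # stands in for DimensionInputInfo: a scalar shape input, expected on the CPU
    pass

@dataclass
class FakeTensor:
    device_kind: str     # "gpu" or "cpu"
    evaluated: bool      # stands in for "producer is a Constant"

def check_inputs(infos: Dict[str, Union[RuntimeInput, DimensionInput]], tensors: Dict[str, FakeTensor]) -> None:
    # Derive the expected device per input, then validate evaluation status and device separately.
    expected = {name: "gpu" if isinstance(info, RuntimeInput) else "cpu" for name, info in infos.items()}
    for name, tensor in tensors.items():
        if not tensor.evaluated:
            raise ValueError(f"Tensor `{name}` is not evaluated. Hint: call `.eval()` on it first.")
        if tensor.device_kind != expected[name]:
            raise ValueError(f"For tensor `{name}`: expected device {expected[name]}, got {tensor.device_kind}.")

# Example: the regular input must be on the GPU, the dimension input on the CPU.
check_inputs(
    {"x": RuntimeInput("float32"), "dim": DimensionInput()},
    {"x": FakeTensor("gpu", True), "dim": FakeTensor("cpu", True)},
)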
@@ -212,7 +217,11 @@ def add(a, b):
             # TODO: Evaluate whether this should be moved into the executor
             if "function expects a memref type with element type" in str(err):
                 # If the problem is a mismatched data type, we can provide a better error message than the executor can.
-                expected_input_dtypes = [info.dtype for info in self.input_infos.values()]
+                from nvtripy.common.datatype import int32
+
+                expected_input_dtypes = [
+                    info.dtype if isinstance(info, InputInfo) else int32 for info in self.input_infos.values()
+                ]
                 for tensor, dtype, arg_name in zip(input_tensors, expected_input_dtypes, self._arg_names):
                     if tensor.dtype != dtype:
                         raise_error(
@@ -225,7 +234,9 @@ def add(a, b):
                             ),
                         )
             elif "InternalError: failed to set input shape" in str(err) or "Runtime shape mismatch" in str(err):
-                expected_input_shapes = [info.shape_bounds for info in self.input_infos.values()]
+                expected_input_shapes = [
+                    info.shape_bounds if isinstance(info, InputInfo) else tuple() for info in self.input_infos.values()
+                ]
                 for tensor, expected_bounds, arg_name in zip(input_tensors, expected_input_shapes, self._arg_names):
                     shape = tensor.shape
 
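The two error-reporting hunks make the same assumption: a dimension input is a scalar, so its expected data type falls back to `int32` and its shape bounds to an empty tuple. Below is a hedged sketch of how those expectation lists are assembled; `Info` and `DimInfo` are hypothetical stand-ins for `InputInfo` and `DimensionInputInfo`, and the bounds format is illustrative only.

# Sketch of the expected-dtype / expected-shape fallbacks; class names and bounds format are assumptions.
from typing import Dict, Union

class Info:
    def __init__(self, dtype: str, shape_bounds: tuple):
        self.dtype = dtype
        self.shape_bounds = shape_bounds

class DimInfo:  # a dimension (shape) input: treated as a scalar int32 with no shape bounds
    pass

input_infos: Dict[str, Union[Info, DimInfo]] = {
    "x": Info("float32", ((1, 8), (1, 8))),   # illustrative per-dimension (min, max) bounds
    "dim": DimInfo(),
}

# Mirrors the two list comprehensions in the diff: dimension inputs report an int32
# data type and an empty shape-bounds tuple when building error messages.
expected_dtypes = [info.dtype if isinstance(info, Info) else "int32" for info in input_infos.values()]
expected_shapes = [info.shape_bounds if isinstance(info, Info) else tuple() for info in input_infos.values()]

assert expected_dtypes == ["float32", "int32"]
assert expected_shapes == [((1, 8), (1, 8)), ()]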