1717
1818from typing import List
1919
20- import mlir_tensorrt .compiler .api as compiler
2120import mlir_tensorrt .runtime .api as runtime
2221
22+ from tripy .backend .api .stream import default_stream
2323from tripy .backend .mlir .memref import create_empty_memref
24+ from tripy .backend .mlir .utils import MLIRRuntimeClient , convert_runtime_dtype_to_tripy_dtype
2425from tripy .backend .utils import TensorInfo
2526from tripy .common import datatype , device
2627from tripy .common .exception import raise_error
28+ from tripy .common .utils import convert_list_to_array
2729from tripy .utils import make_tuple
2830
2931
3032class Executor :
def __init__(self, executable: runtime.Executable) -> None:
    """
    Set up a runtime session for ``executable`` and cache details of its "main" signature.

    Args:
        executable: The compiled executable to load into a new runtime session.
    """
    self.runtime_client = MLIRRuntimeClient()
    self.session = runtime.RuntimeSession(
        runtime.RuntimeSessionOptions(num_devices=1, device_id=0),
        executable,
    )
    self.device = self.runtime_client.get_devices()[0]  # Assume a single device is available.
    self.signature = executable.get_signature("main")
    self.stream = default_stream()

    # Cache the argument counts so later calls do not have to query the signature again.
    signature = self.signature
    self.num_input_args = signature.get_num_input_args()
    self.num_output_args = signature.get_num_output_args()

    # Output arguments are looked up at the indices immediately after the input arguments.
    first_output = self.num_input_args
    self.output_args = [
        signature.get_arg(arg_index) for arg_index in range(first_output, first_output + self.num_output_args)
    ]
    self.output_memrefs = [runtime.MemRefType(arg) for arg in self.output_args]
4146
4247 def _create_shape_memref (self , shape ):
43- from tripy .common .utils import convert_list_to_array
44-
4548 shape = make_tuple (shape )
4649 if len (shape ) == 0 :
4750 # create an empty memref
@@ -55,34 +58,21 @@ def _create_shape_memref(self, shape):
5558 stream = self .stream ._active_cuda_stream ,
5659 )
5760
58- def _get_inputs_runtime_shape (self , inputs ):
59- inputs_shape = []
60- for input in inputs :
61- inputs_shape .append (input .trace_tensor .producer .data .shape )
62- return inputs_shape
63-
6461 def _get_outputs_shape (self ):
65- offset = self .signature .get_num_input_args ()
6662 outputs_shape = []
6763 all_outputs_known = True
68- for output_index in range (self .signature .get_num_output_args ()):
69- arg_index = output_index + offset
70- arg = self .signature .get_arg (arg_index )
71- assert compiler .MemRefType .isinstance (arg )
72- memref = runtime .MemRefType (arg )
73- rank = len (memref .shape )
74-
64+ for memref in self .output_memrefs :
7565 outputs_shape .append (memref .shape )
76- if rank > 0 :
77- all_outputs_known &= all (dim >= 0 for dim in memref .shape )
66+ all_outputs_known &= all (dim >= 0 for dim in memref .shape )
7867 return outputs_shape , all_outputs_known
7968
80- def _execute_shape_inference (self , inputs_shape , outputs_shape ):
81- # Only execute shape inference if shape function name is valid.
82- assert (
83- self . signature . get_shape_func_name ( )
84- ), f"Shape inference function is missing while output shapes are not known."
69+ def _get_inputs_runtime_shape (self , inputs ):
70+ inputs_shape = []
71+ for input in inputs :
72+ inputs_shape . append ( input . trace_tensor . producer . data . shape )
73+ return inputs_shape
8574
75+ def _execute_shape_inference (self , inputs_shape , outputs_shape ):
8676 inputs_shape_memref = [self ._create_shape_memref (inp_shape ) for inp_shape in inputs_shape ]
8777 outputs_shape_memref = [self ._create_shape_memref (out_shape ) for out_shape in outputs_shape ]
8878 self .session .execute_function (
@@ -93,41 +83,24 @@ def _execute_shape_inference(self, inputs_shape, outputs_shape):
9383 return outputs_runtime_shape
9484
def _get_output_tensor_info(self, outputs_runtime_shape, output_devices):
    """
    Build a ``TensorInfo`` for every output argument of the executable.

    Args:
        outputs_runtime_shape: Per-output runtime shapes, indexed by output position. A
            runtime dimension is used wherever the memref's own dimension is negative.
        output_devices: Per-output devices, indexed by output position. A falsy entry
            falls back to a device chosen from the memref's address space.

    Returns:
        A list of ``TensorInfo`` objects, one per output, in output order.
    """
    tensor_infos = []
    for position in range(self.num_output_args):
        memref = self.output_memrefs[position]
        dtype = convert_runtime_dtype_to_tripy_dtype(memref.dtype)

        # Prefer the requested device; otherwise use "gpu" for the device address space and "cpu" for any other.
        target_device = output_devices[position] or device(
            ("gpu" if memref.address_space == runtime.PointerType.device else "cpu", 0)
        )

        # Keep non-negative memref dimensions and take the rest from the runtime shape.
        resolved_shape = tuple(
            runtime_dim if static_dim < 0 else static_dim
            for static_dim, runtime_dim in zip(memref.shape, outputs_runtime_shape[position])
        )
        tensor_infos.append(TensorInfo(len(resolved_shape), resolved_shape, dtype, target_device))
    return tensor_infos
132105
133106 def get_output_tensor_runtime_info (self , inputs , output_devices = List [device ]):
0 commit comments