|
14 | 14 | # limitations under the License. |
15 | 15 | from typing import Any, Dict, List, Sequence, Union |
16 | 16 |
|
| 17 | +import tensorrt as trt |
17 | 18 | from nvtripy import export, utils |
| 19 | +from nvtripy.common import datatype |
18 | 20 | from nvtripy.common.exception import raise_error |
19 | 21 | from nvtripy.frontend.ops import utils as op_utils |
20 | 22 | from nvtripy.trace.ops.plugin import Plugin |
21 | 23 | from nvtripy.utils.types import str_from_type_annotation, type_str_from_arg |
22 | 24 |
|
| 25 | +TRT_FROM_TRIPY_DTYPE = { |
| 26 | + datatype.float32: trt.float32, |
| 27 | + datatype.float16: trt.float16, |
| 28 | + datatype.float8: trt.fp8, |
| 29 | + datatype.bfloat16: trt.bfloat16, |
| 30 | + datatype.int4: trt.int4, |
| 31 | + datatype.int8: trt.int8, |
| 32 | + datatype.int32: trt.int32, |
| 33 | + datatype.int64: trt.int64, |
| 34 | + datatype.bool: trt.bool, |
| 35 | +} |
| 36 | + |
| 37 | +TRIPY_FROM_TRT_DTYPE = {val: key for key, val in TRT_FROM_TRIPY_DTYPE.items()} |
| 38 | + |
23 | 39 |
|
24 | 40 | # TODO (pranavm): Add link to custom layers guide once published |
25 | 41 | @export.public_api(document_under="operations/functions") |
@@ -93,11 +109,20 @@ def get_plugins_in_namespace(ns): |
93 | 109 | input_descs = [None] * len(inputs) |
94 | 110 | for i in range(len(inputs)): |
95 | 111 | input_descs[i] = trtp._tensor.TensorDesc() |
96 | | - input_descs[i].dtype = inputs[i].dtype |
| 112 | + input_descs[i].dtype = TRT_FROM_TRIPY_DTYPE[inputs[i].dtype] |
97 | 113 | input_descs[i].shape_expr = trtp._tensor.ShapeExprs(inputs[i].rank, _is_dummy=True) |
98 | 114 | input_descs[i]._immutable = True |
| 115 | + |
99 | 116 | output_descs = utils.utils.make_tuple(trtp_op.register_func(*input_descs, attrs)) |
100 | 117 |
|
101 | | - output_info = [(len(desc.shape_expr), desc.dtype) for desc in output_descs] |
| 118 | + def tripy_from_trt_dtype(dtype): |
| 119 | + if dtype not in TRIPY_FROM_TRT_DTYPE: |
| 120 | + raise_error( |
| 121 | + f"Unsupported TensorRT data type: '{dtype}'.", |
| 122 | + details=[f"Supported types are: {list(TRIPY_FROM_TRT_DTYPE.keys())}."], |
| 123 | + ) |
| 124 | + return TRIPY_FROM_TRT_DTYPE[dtype] |
| 125 | + |
| 126 | + output_info = [(len(desc.shape_expr), tripy_from_trt_dtype(desc.dtype)) for desc in output_descs] |
102 | 127 |
|
103 | 128 | return op_utils.create_op(Plugin, inputs, name, "1", namespace, output_info, kwargs) |
0 commit comments