Skip to content

Commit 899052b

Browse files
Updates plugin frontend function to convert data types between Tripy/TensorRT
1 parent effd960 commit 899052b

File tree

3 files changed

+40
-2
lines changed

3 files changed

+40
-2
lines changed

tripy/nvtripy/frontend/ops/plugin_qdp.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,28 @@
1414
# limitations under the License.
1515
from typing import Any, Dict, List, Sequence, Union
1616

17+
import tensorrt as trt
1718
from nvtripy import export, utils
19+
from nvtripy.common import datatype
1820
from nvtripy.common.exception import raise_error
1921
from nvtripy.frontend.ops import utils as op_utils
2022
from nvtripy.trace.ops.plugin import Plugin
2123
from nvtripy.utils.types import str_from_type_annotation, type_str_from_arg
2224

25+
# Lookup table from Tripy data types (members of `nvtripy.common.datatype`) to
# the corresponding TensorRT data types (members of `trt`). The plugin frontend
# indexes this table with each input's dtype before storing the result on the
# input tensor descriptor handed to the plugin's registration function.
TRT_FROM_TRIPY_DTYPE = {
26+
datatype.float32: trt.float32,
27+
datatype.float16: trt.float16,
28+
# Note the differing names: Tripy's `float8` corresponds to TensorRT's `fp8`.
datatype.float8: trt.fp8,
29+
datatype.bfloat16: trt.bfloat16,
30+
datatype.int4: trt.int4,
31+
datatype.int8: trt.int8,
32+
datatype.int32: trt.int32,
33+
datatype.int64: trt.int64,
34+
datatype.bool: trt.bool,
35+
}
36+
37+
# Reverse lookup (TensorRT dtype -> Tripy dtype), built by swapping the keys and
# values of TRT_FROM_TRIPY_DTYPE. This relies on every TensorRT dtype appearing
# at most once as a value above; a duplicate would silently overwrite an entry.
TRIPY_FROM_TRT_DTYPE = {val: key for key, val in TRT_FROM_TRIPY_DTYPE.items()}
38+
2339

2440
# TODO (pranavm): Add link to custom layers guide once published
2541
@export.public_api(document_under="operations/functions")
@@ -93,11 +109,20 @@ def get_plugins_in_namespace(ns):
93109
input_descs = [None] * len(inputs)
94110
for i in range(len(inputs)):
95111
input_descs[i] = trtp._tensor.TensorDesc()
96-
input_descs[i].dtype = inputs[i].dtype
112+
input_descs[i].dtype = TRT_FROM_TRIPY_DTYPE[inputs[i].dtype]
97113
input_descs[i].shape_expr = trtp._tensor.ShapeExprs(inputs[i].rank, _is_dummy=True)
98114
input_descs[i]._immutable = True
115+
99116
output_descs = utils.utils.make_tuple(trtp_op.register_func(*input_descs, attrs))
100117

101-
output_info = [(len(desc.shape_expr), desc.dtype) for desc in output_descs]
118+
# Translate a TensorRT data type into its Tripy counterpart by looking it up in
# TRIPY_FROM_TRT_DTYPE.
#
# `dtype` is the value read from an output descriptor's `.dtype` attribute.
# If it is not a key of TRIPY_FROM_TRT_DTYPE, the problem is reported through
# `raise_error`, with the list of supported TensorRT types in the details.
# Otherwise the mapped Tripy data type is returned.
def tripy_from_trt_dtype(dtype):
119+
# Check membership up front so an unsupported type is reported through
# `raise_error` with a descriptive message.
if dtype not in TRIPY_FROM_TRT_DTYPE:
120+
raise_error(
121+
f"Unsupported TensorRT data type: '{dtype}'.",
122+
details=[f"Supported types are: {list(TRIPY_FROM_TRT_DTYPE.keys())}."],
123+
)
124+
return TRIPY_FROM_TRT_DTYPE[dtype]
125+
126+
output_info = [(len(desc.shape_expr), tripy_from_trt_dtype(desc.dtype)) for desc in output_descs]
102127

103128
return op_utils.create_op(Plugin, inputs, name, "1", namespace, output_info, kwargs)

tripy/tests/conftest.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import nvtripy as tp
2222
import pytest
23+
import tensorrt as trt
2324
import tensorrt.plugin as trtp
2425
import torch
2526
import triton
@@ -64,6 +65,8 @@ def add_kernel(x_ptr, n_elements, y_ptr, BLOCK_SIZE: tl.constexpr):
6465

6566
# Descriptor function for the test plugin registered under the ID
# "example::elemwise_add_plugin". It takes the descriptor of the single input
# tensor plus the `block_size` attribute and returns the output descriptor.
@trtp.register("example::elemwise_add_plugin")
6667
def add_plugin_desc(inp0: trtp.TensorDesc, block_size: int) -> trtp.TensorDesc:
68+
# QDPs should receive and return TRT data types
# (i.e. members of `trt` such as `trt.float32`, not Tripy dtypes). The
# assertion below checks that the descriptor passed in carries a TensorRT
# dtype; it hard-codes float32, so it assumes the tests only feed float32
# inputs to this plugin - TODO confirm against the tests that use it.
69+
assert inp0.dtype == trt.float32
6770
# The output descriptor is derived from the input via `like()`; presumably
# it mirrors the input's shape and dtype - confirm in the TensorRT plugin docs.
return inp0.like()
6871

6972

tripy/tests/frontend/ops/test_plugin_qdp.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,19 @@
1414
# limitations under the License.
1515
import nvtripy as tp
1616
import pytest
17+
import tensorrt as trt
18+
from nvtripy.common.datatype import DATA_TYPES
19+
from nvtripy.frontend.ops.plugin_qdp import TRIPY_FROM_TRT_DTYPE, TRT_FROM_TRIPY_DTYPE
1720
from tests import helper
1821

1922

23+
# Runs once per Tripy data type listed in `DATA_TYPES` and checks that the
# Tripy <-> TensorRT dtype tables in `plugin_qdp` cover it consistently.
@pytest.mark.parametrize("dtype", DATA_TYPES.values())
24+
def test_dtype_trt_conversion(dtype):
25+
# Every Tripy data type must have a TensorRT equivalent registered.
assert dtype in TRT_FROM_TRIPY_DTYPE
26+
# The mapped value must be an actual TensorRT dtype object.
assert isinstance(TRT_FROM_TRIPY_DTYPE[dtype], trt.DataType)
27+
# Converting to TensorRT and back must give the original Tripy type.
assert TRIPY_FROM_TRT_DTYPE[TRT_FROM_TRIPY_DTYPE[dtype]] == dtype
28+
29+
2030
class TestQuicklyDeployablePlugin:
2131
@pytest.mark.parametrize(
2232
"plugin_id, err",

0 commit comments

Comments
 (0)