Skip to content

Commit b18d15c

Browse files
committed
[TensorRT] Copy tensorrt.host_tensor attribute in outline pass
WIP: Add tp.DimensionInputInfo, support shape tensor input
1 parent c1d6e9b commit b18d15c

File tree

8 files changed

+260
-45
lines changed

8 files changed

+260
-45
lines changed

mlir-tensorrt/compiler/lib/Compiler/TensorRTToExecutable/Passes.cpp

Lines changed: 38 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -280,12 +280,17 @@ outlineOp(RewriterBase &rewriter, tensorrt::TensorRTModuleOp trtModule,
280280
mlir::tensorrt::TensorRTDialect::getShapeProfileArgAttrName();
281281
StringRef tensorrtDimensionNamesAttrName =
282282
mlir::tensorrt::TensorRTDialect::getDimensionNamesArgAttrName();
283+
StringRef tensorrtValueBoundsAttrName =
284+
mlir::tensorrt::TensorRTDialect::getShapeTensorValueBoundsArgAttrName();
285+
StringRef hostTensorAttrName = mlir::getHostTensorArgAttrName();
286+
StringRef memorySpaceAttrName =
287+
plan::PlanDialect::getMemorySpaceConstraintAttrName();
283288

284289
SmallVector<Attribute> profileAttrsPerInput;
285290
SmallVector<Attribute> dimensionNamesAttrsPerInput;
286291
for (Value v : inputs) {
287292
auto rtt = dyn_cast<RankedTensorType>(v.getType());
288-
if (!rtt || rtt.hasStaticShape()) {
293+
if (!rtt) {
289294
profileAttrsPerInput.push_back(Attribute{});
290295
dimensionNamesAttrsPerInput.push_back(Attribute{});
291296
continue;
@@ -299,30 +304,41 @@ outlineOp(RewriterBase &rewriter, tensorrt::TensorRTModuleOp trtModule,
299304
}
300305

301306
int64_t argIndex = blockArg.getArgNumber();
302-
profileAttrsPerInput.push_back(
303-
parentFunc.getArgAttrOfType<tensorrt::ShapeProfileAttr>(
304-
argIndex, tensorrtShapeBoundsAttrName));
305-
306-
dimensionNamesAttrsPerInput.push_back(
307-
parentFunc.getArgAttrOfType<DictionaryAttr>(
308-
argIndex, tensorrtDimensionNamesAttrName));
309-
310-
if (!profileAttrsPerInput.back()) {
311-
return emitError(blockArg.getLoc())
312-
<< "Profile attribute (" << tensorrtShapeBoundsAttrName
313-
<< ") of argument " << argIndex << " is not set";
307+
// Get shape profile and dynamision name attributes of the input
308+
if (rtt.hasStaticShape()) {
309+
// static-shaped argument can only have value bound attr (shape input)
310+
auto valueBoundAttr =
311+
parentFunc.getArgAttrOfType<tensorrt::ShapeProfileAttr>(
312+
argIndex, tensorrtValueBoundsAttrName);
313+
if (valueBoundAttr) {
314+
func->setArgAttr(argIndex, tensorrtValueBoundsAttrName, valueBoundAttr);
315+
}
316+
// Get host tensor attribute of the input
317+
auto memorySpaceAttr = parentFunc.getArgAttr(argIndex, memorySpaceAttrName);
318+
if (memorySpaceAttr) {
319+
func->setArgAttr(argIndex, memorySpaceAttrName, memorySpaceAttr);
320+
// Add tensorrt.host_tensor attr, it is needed by NetworkEncoder for now
321+
func->setArgAttr(argIndex, hostTensorAttrName, rewriter.getUnitAttr());
322+
}
323+
} else {
324+
auto shapeBoundAttr =
325+
parentFunc.getArgAttrOfType<tensorrt::ShapeProfileAttr>(
326+
argIndex, tensorrtShapeBoundsAttrName);
327+
if (!shapeBoundAttr) {
328+
return emitError(blockArg.getLoc())
329+
<< "Profile attribute (" << tensorrtShapeBoundsAttrName
330+
<< ") of argument " << argIndex << " is not set";
331+
}
332+
func->setArgAttr(argIndex, tensorrtShapeBoundsAttrName, shapeBoundAttr);
333+
auto dimensionNameAttr = parentFunc.getArgAttrOfType<DictionaryAttr>(
334+
argIndex, tensorrtDimensionNamesAttrName);
335+
if (dimensionNameAttr) {
336+
func->setArgAttr(argIndex, tensorrtDimensionNamesAttrName,
337+
dimensionNameAttr);
338+
}
314339
}
315340
}
316341

317-
for (unsigned idx = 0; idx < func->getNumArguments(); idx++) {
318-
if (profileAttrsPerInput[idx])
319-
func->setArgAttr(idx, tensorrtShapeBoundsAttrName,
320-
profileAttrsPerInput[idx]);
321-
if (dimensionNamesAttrsPerInput[idx])
322-
func->setArgAttr(idx, tensorrtDimensionNamesAttrName,
323-
dimensionNamesAttrsPerInput[idx]);
324-
}
325-
326342
rewriter.setInsertionPoint(inlineGroupOp);
327343
auto callOp = rewriter.create<tensorrt::CallAllocOp>(
328344
inlineGroupOp.getLoc(), inlineGroupOp.getResultTypes(), inputs,

tripy/nvtripy/backend/api/compile.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from nvtripy import constants, export, utils
2222
from nvtripy.backend.api.executable import Executable
23-
from nvtripy.backend.api.input_info import InputInfo
23+
from nvtripy.backend.api.input_info import InputInfo, DimensionInputInfo
2424
from nvtripy.backend.mlir import Compiler
2525
from nvtripy.common.exception import raise_error
2626
from nvtripy.frontend import Tensor, Trace
@@ -162,6 +162,22 @@ def process_arg(name, arg):
162162
input_names.add(name)
163163

164164
return tensor
165+
166+
if isinstance(arg, DimensionInputInfo):
167+
from nvtripy.frontend.dimension_size import DimensionSize
168+
169+
input_infos[name] = arg
170+
171+
tensor = DimensionSize(arg.value_bounds.opt[0])
172+
tensor.name = name
173+
tensor.trace_tensor.is_compile_tracer = True
174+
assert tensor.trace_tensor.shape == ()
175+
176+
trace_input_map[name] = tensor
177+
input_names.add(name)
178+
179+
return tensor
180+
165181
return arg
166182

167183
compiled_arg_names = []

tripy/nvtripy/backend/api/executable.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -191,17 +191,17 @@ def add(a, b):
191191
],
192192
)
193193

194-
for tensor in input_tensors:
195-
producer = tensor.trace_tensor.producer
196-
if not isinstance(producer, Constant) or tensor.device.kind != "gpu":
197-
raise_error(
198-
"Inputs to compiled executables must be evaluated tensors on the GPU.",
199-
[
200-
"Got input" + (f" on device '{tensor.device}':" if tensor.device.kind != "gpu" else ":"),
201-
tensor,
202-
"Hint: Try calling `.eval()` on the tensor to ensure it is a GPU constant.",
203-
],
204-
)
194+
# for tensor in input_tensors:
195+
# producer = tensor.trace_tensor.producer
196+
# if not isinstance(producer, Constant) or tensor.device.kind != "gpu":
197+
# raise_error(
198+
# "Inputs to compiled executables must be evaluated tensors on the GPU.",
199+
# [
200+
# "Got input" + (f" on device '{tensor.device}':" if tensor.device.kind != "gpu" else ":"),
201+
# tensor,
202+
# "Hint: Try calling `.eval()` on the tensor to ensure it is a GPU constant.",
203+
# ],
204+
# )
205205

206206
input_memrefs = [inp.trace_tensor.producer.data for inp in input_tensors]
207207
try:

tripy/nvtripy/backend/api/input_info.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from nvtripy import export
1818
from nvtripy.backend.api.named_dimension import NamedDimension
19-
from nvtripy.backend.api.shape_bounds import ShapeBounds
19+
from nvtripy.backend.api.shape_bounds import ShapeBounds, ValueBounds
2020
from nvtripy.frontend.dimension_size import DimensionSize
2121
from nvtripy.types import IntLike
2222
from nvtripy.utils import json as json_utils
@@ -74,7 +74,6 @@ def __init__(
7474
"""
7575
is_int_like = lambda arg: any(isinstance(arg, typ) for typ in {int, DimensionSize})
7676

77-
# TODO (#252): Allow `shape` to be a shape tensor
7877
min_shape = []
7978
opt_shape = []
8079
max_shape = []
@@ -129,3 +128,48 @@ def decode_input_info(input_info_dict):
129128
input_info.shape_bounds = input_info_dict["shape_bounds"]
130129
input_info.dimension_names = {int(k): v for k, v in input_info_dict.get("dimension_names", {}).items()}
131130
return input_info
131+
132+
133+
@export.public_api(document_under="compiling_code")
134+
class DimensionInputInfo:
135+
"""
136+
Captures information about a dimension size input to a compiled function.
137+
"""
138+
139+
def __init__(self, value_bounds: Tuple[IntLike, IntLike, IntLike]) -> None:
140+
"""
141+
Args:
142+
value_bounds: The value bound of the dimension size input, consisting of minimum, optimum, and maximum values.
143+
144+
.. code-block:: python
145+
:linenos:
146+
:caption: Dynamic Dimensions
147+
148+
# The dimension size will support values in the range [1, 3],
149+
# optimizing for a size of 2.
150+
dim_inp = tp.DimensionInputInfo((1, 2, 3))
151+
assert dim_inp.min == 1
152+
assert dim_inp.opt == 2
153+
assert dim_inp.max == 3
154+
"""
155+
self.value_bounds = ValueBounds(
156+
min=tuple([value_bounds[0]]), opt=tuple([value_bounds[1]]), max=tuple([value_bounds[2]])
157+
)
158+
159+
def __str__(self) -> str:
160+
return (
161+
f"DimensionInputInfo(min={self.value_bounds.min}, opt={self.value_bounds.opt}, max={self.value_bounds.max})"
162+
)
163+
164+
165+
@json_utils.Encoder.register(DimensionInputInfo)
166+
def encode_dim_input_info(dim_input_info):
167+
return {
168+
"value_bounds": dim_input_info.value_bounds,
169+
}
170+
171+
172+
@json_utils.Decoder.register(DimensionInputInfo)
173+
def decode_dim_input_info(dim_input_info_dict):
174+
dim_input_info_dict.value_bounds = dim_input_info_dict["value_bounds"]
175+
return dim_input_info_dict

tripy/nvtripy/backend/api/shape_bounds.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,28 @@ def decode_shape_bounds(shape_bounds_dict):
5959
opt=tuple(shape_bounds_dict["opt"]),
6060
max=tuple(shape_bounds_dict["max"]),
6161
)
62+
63+
64+
@dataclass
65+
class ValueBounds:
66+
min: Tuple[IntLike]
67+
opt: Tuple[IntLike]
68+
max: Tuple[IntLike]
69+
70+
71+
@json_utils.Encoder.register(ValueBounds)
72+
def encode_value_bounds(value_bounds):
73+
return {
74+
"min": tuple(value_bounds.min),
75+
"opt": tuple(value_bounds.opt),
76+
"max": tuple(value_bounds.max),
77+
}
78+
79+
80+
@json_utils.Decoder.register(ValueBounds)
81+
def decode_value_bounds(value_bounds_dict):
82+
return ValueBounds(
83+
min=tuple(value_bounds_dict["min"]),
84+
opt=tuple(value_bounds_dict["opt"]),
85+
max=tuple(value_bounds_dict["max"]),
86+
)

tripy/nvtripy/trace/trace.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
#
1717

1818
from textwrap import indent
19-
from typing import Dict, List, Optional, Sequence, Set
19+
from typing import Dict, List, Optional, Sequence, Set, Union
2020

2121
from mlir_tensorrt.compiler import ir
2222
from mlir_tensorrt.compiler.dialects import func as func_dialect
@@ -43,7 +43,7 @@ def __init__(
4343
self,
4444
outputs: Sequence[TraceTensor],
4545
inputs: Sequence[TraceTensor] = [],
46-
input_infos: Optional[Dict[str, "nvtripy.InputInfo"]] = None,
46+
input_infos: Optional[Dict[str, Union["nvtripy.InputInfo", "nvtripy.DimensionInputInfo"]]] = None,
4747
name: str = "main",
4848
) -> None:
4949
# ops/inputs/outputs are populated by `trace()`
@@ -132,6 +132,8 @@ def get_sep(lst):
132132
return "\n".join(layer_strs)
133133

134134
def to_mlir(self):
135+
from nvtripy.backend.api.input_info import InputInfo, DimensionInputInfo
136+
135137
def to_mlir_impl():
136138

137139
with make_ir_context(), ir.Location.unknown():
@@ -195,13 +197,23 @@ def num_known_dims(ranked_tensor_type):
195197
attr = {}
196198
if self.input_infos:
197199
input_info = self.input_infos[inp.name]
198-
shape_bounds = input_info.shape_bounds
199-
attr["tensorrt.shape_profile"] = ir.Attribute.parse(
200-
f"#tensorrt.shape_profile<min={list(shape_bounds.min)}, opt={list(shape_bounds.opt)}, max={list(shape_bounds.max)}>"
201-
)
202-
attr["tensorrt.dimension_names"] = ir.DictAttr.get(
203-
{str(idx): ir.StringAttr.get(name) for idx, name in input_info.dimension_names.items()}
204-
)
200+
if isinstance(input_info, InputInfo):
201+
shape_bounds = input_info.shape_bounds
202+
attr["tensorrt.shape_profile"] = ir.Attribute.parse(
203+
f"#tensorrt.shape_profile<min={list(shape_bounds.min)}, opt={list(shape_bounds.opt)}, max={list(shape_bounds.max)}>"
204+
)
205+
attr["tensorrt.dimension_names"] = ir.DictAttr.get(
206+
{
207+
str(idx): ir.StringAttr.get(name)
208+
for idx, name in input_info.dimension_names.items()
209+
}
210+
)
211+
elif isinstance(input_info, DimensionInputInfo):
212+
value_bounds = input_info.value_bounds
213+
attr["tensorrt.value_bounds"] = ir.Attribute.parse(
214+
f"#tensorrt.shape_profile<min={list(value_bounds.min)}, opt={list(value_bounds.opt)}, max={list(value_bounds.max)}>"
215+
)
216+
attr["plan.memory_space"] = ir.Attribute.parse("#plan.memory_space<host>")
205217

206218
arg_attrs.append(ir.DictAttr.get(attr))
207219

tripy/test_out_of_bound.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
import nvtripy as tp
16+
import cupy as cp
17+
18+
from nvtripy.logging import logger
19+
20+
logger.verbosity = "ir"
21+
import mlir_tensorrt.runtime.api as runtime
22+
23+
24+
def func(x):
25+
x = x + x
26+
return x
27+
28+
29+
compiled_func = tp.compile(func, args=[tp.InputInfo(shape=((2, 4, 6), 4), dtype=tp.float32)])
30+
31+
sig = compiled_func._executable_signature
32+
33+
for idx in range(2):
34+
35+
arg = sig.get_arg(idx)
36+
memref = runtime.MemRefType(arg)
37+
print(f"Arg {idx}: ", memref.address_space)
38+
39+
print("Shape: ", memref.shape)
40+
bound = sig.get_arg_bound(idx)
41+
print(f"Bound: {bound.min()}, {bound.max()}")
42+
43+
# inp = cp.ones((8, 4), dtype=cp.float32)
44+
# inp = tp.Tensor(inp)
45+
# out = compiled_func(inp)

tripy/test_shape_input.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
import nvtripy as tp
16+
import cupy as cp
17+
18+
from nvtripy.logging import logger
19+
20+
logger.verbosity = "ir"
21+
import mlir_tensorrt.runtime.api as runtime
22+
23+
24+
def func(x, y):
25+
x = x + x
26+
x = tp.reshape(x, (-1, y))
27+
return x
28+
29+
30+
compiled_func = tp.compile(
31+
func, args=[tp.InputInfo(shape=((2, 4, 6), 4), dtype=tp.float32), tp.DimensionInputInfo(value_bounds=(1, 2, 3))]
32+
)
33+
34+
print("compilation complete.")
35+
36+
sig = compiled_func._executable_signature
37+
38+
for idx in range(2):
39+
40+
arg = sig.get_arg(idx)
41+
memref = runtime.MemRefType(arg)
42+
print(f"Arg {idx}: ", memref.address_space)
43+
44+
print("Shape: ", memref.shape)
45+
bound = sig.get_arg_bound(idx)
46+
print(f"Bound: {bound.min()}, {bound.max()}")
47+
48+
49+
# import pdb
50+
# pdb.set_trace()
51+
52+
53+
inp = cp.ones((4, 4), dtype=cp.float32)
54+
inp = tp.Tensor(inp)
55+
dim_inp = tp.DimensionSize(2)
56+
out = compiled_func(inp, dim_inp)
57+
print(out)

0 commit comments

Comments
 (0)