Skip to content

Commit 65607a5

Browse files
committed
WIP: Add tp.DimensionInputInfo, support shape tensor input
1 parent 05ce756 commit 65607a5

File tree

7 files changed

+208
-18
lines changed

7 files changed

+208
-18
lines changed

tripy/nvtripy/backend/api/compile.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from nvtripy import constants, export, utils
2222
from nvtripy.backend.api.executable import Executable
23-
from nvtripy.backend.api.input_info import InputInfo
23+
from nvtripy.backend.api.input_info import InputInfo, DimensionInputInfo
2424
from nvtripy.backend.mlir import Compiler
2525
from nvtripy.common.exception import raise_error
2626
from nvtripy.frontend import Tensor, Trace
@@ -164,6 +164,21 @@ def process_arg(name, arg):
164164
input_names.add(name)
165165

166166
return tensor
167+
168+
if isinstance(arg, DimensionInputInfo):
169+
from nvtripy.frontend.dimension_size import DimensionSize
170+
171+
tensor = DimensionSize(arg.value_bounds.opt[0])
172+
tensor.name = name
173+
tensor.trace_tensor.is_compile_tracer = True
174+
assert tensor.trace_tensor.shape == ()
175+
176+
trace_input_map[name] = tensor
177+
shapes.append(arg.value_bounds)
178+
input_names.add(name)
179+
180+
return tensor
181+
167182
return arg
168183

169184
compiled_arg_names = []

tripy/nvtripy/backend/api/executable.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -190,17 +190,17 @@ def add(a, b):
190190
],
191191
)
192192

193-
for tensor in input_tensors:
194-
producer = tensor.trace_tensor.producer
195-
if not isinstance(producer, Constant) or tensor.device.kind != "gpu":
196-
raise_error(
197-
"Inputs to compiled executables must be evaluated tensors on the GPU.",
198-
[
199-
"Got input" + (f" on device '{tensor.device}':" if tensor.device.kind != "gpu" else ":"),
200-
tensor,
201-
"Hint: Try calling `.eval()` on the tensor to ensure it is a GPU constant.",
202-
],
203-
)
193+
# for tensor in input_tensors:
194+
# producer = tensor.trace_tensor.producer
195+
# if not isinstance(producer, Constant) or tensor.device.kind != "gpu":
196+
# raise_error(
197+
# "Inputs to compiled executables must be evaluated tensors on the GPU.",
198+
# [
199+
# "Got input" + (f" on device '{tensor.device}':" if tensor.device.kind != "gpu" else ":"),
200+
# tensor,
201+
# "Hint: Try calling `.eval()` on the tensor to ensure it is a GPU constant.",
202+
# ],
203+
# )
204204

205205
input_memrefs = [inp.trace_tensor.producer.data for inp in input_tensors]
206206
try:

tripy/nvtripy/backend/api/input_info.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from typing import Sequence, Tuple, Union
1616

1717
from nvtripy import export
18-
from nvtripy.common.shape_bounds import ShapeBounds
18+
from nvtripy.common.shape_bounds import ShapeBounds, ValueBounds
1919
from nvtripy.frontend.dimension_size import DimensionSize
2020
from nvtripy.types import IntLike
2121
from nvtripy.utils import json as json_utils
@@ -57,7 +57,6 @@ def __init__(
5757
"""
5858
is_int_like = lambda arg: any(isinstance(arg, typ) for typ in {int, DimensionSize})
5959

60-
# TODO (#252): Allow `shape` to be a shape tensor
6160
min_shape = []
6261
opt_shape = []
6362
max_shape = []
@@ -100,3 +99,44 @@ def decode_input_info(input_info_dict):
10099
input_info = InputInfo(shape=[], dtype=input_info_dict["dtype"])
101100
input_info.shape_bounds = input_info_dict["shape_bounds"]
102101
return input_info
102+
103+
@export.public_api(document_under="compiling_code")
104+
class DimensionInputInfo:
105+
"""
106+
Captures information about a dimension size input to a compiled function.
107+
"""
108+
109+
def __init__(self, value_bounds: Tuple[IntLike, IntLike, IntLike]) -> None:
110+
"""
111+
Args:
112+
value_bounds: The value bound of the dimension size input, consisting of minimum, optimum, and maximum values.
113+
114+
.. code-block:: python
115+
:linenos:
116+
:caption: Dynamic Dimensions
117+
118+
# The dimension size will support values in the range [1, 3],
119+
# optimizing for a size of 2.
120+
dim_inp = tp.DimensionInputInfo((1, 2, 3))
121+
assert dim_inp.min == 1
122+
assert dim_inp.opt == 2
123+
assert dim_inp.max == 3
124+
"""
125+
self.value_bounds = ValueBounds(min=tuple(value_bounds[0]), opt=tuple(value_bounds[1]), max=tuple(value_bounds[2]))
126+
127+
def __str__(self) -> str:
128+
return (
129+
f"DimensionInputInfo(min={self.value_bounds.min}, opt={self.value_bounds.opt}, max={self.value_bounds.max})"
130+
)
131+
132+
@json_utils.Encoder.register(DimensionInputInfo)
133+
def encode_input_info(dim_input_info):
134+
return {
135+
"value_bounds": dim_input_info.value_bounds,
136+
}
137+
138+
139+
@json_utils.Decoder.register(DimensionInputInfo)
140+
def decode_input_info(dim_input_info_dict):
141+
dim_input_info_dict.value_bounds = dim_input_info_dict["value_bounds"]
142+
return dim_input_info_dict

tripy/nvtripy/common/shape_bounds.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,26 @@ def decode_shape_bounds(shape_bounds_dict):
5959
opt=tuple(shape_bounds_dict["opt"]),
6060
max=tuple(shape_bounds_dict["max"]),
6161
)
62+
63+
@dataclass
64+
class ValueBounds:
65+
min: Tuple[IntLike]
66+
opt: Tuple[IntLike]
67+
max: Tuple[IntLike]
68+
69+
@json_utils.Encoder.register(ValueBounds)
70+
def encode_shape_bounds(value_bounds):
71+
return {
72+
"min": tuple(value_bounds.min),
73+
"opt": tuple(value_bounds.opt),
74+
"max": tuple(value_bounds.max),
75+
}
76+
77+
78+
@json_utils.Decoder.register(ValueBounds)
79+
def decode_shape_bounds(shape_bounds_dict):
80+
return ValueBounds(
81+
min=tuple(shape_bounds_dict["min"]),
82+
opt=tuple(shape_bounds_dict["opt"]),
83+
max=tuple(shape_bounds_dict["max"]),
84+
)

tripy/nvtripy/trace/trace.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
#
1717

1818
from textwrap import indent
19-
from typing import Dict, List, Optional, Sequence, Set
19+
from typing import Dict, List, Optional, Sequence, Set, Union
2020

2121
from mlir_tensorrt.compiler import ir
2222
from mlir_tensorrt.compiler.dialects import func as func_dialect
@@ -29,7 +29,7 @@
2929
redirect_stderr,
3030
)
3131
from nvtripy.common.exception import raise_error
32-
from nvtripy.common.shape_bounds import ShapeBounds
32+
from nvtripy.common.shape_bounds import ShapeBounds, ValueBounds
3333
from nvtripy.logging import logger
3434
from nvtripy.trace.tensor import TraceTensor
3535
from nvtripy.trace.utils import topological_sort
@@ -44,7 +44,7 @@ def __init__(
4444
self,
4545
outputs: Sequence[TraceTensor],
4646
inputs: Sequence[TraceTensor] = [],
47-
shapes: Optional[Sequence[ShapeBounds]] = None,
47+
shapes: Optional[Sequence[Union[ShapeBounds, ValueBounds]]] = None,
4848
name: str = "main",
4949
) -> None:
5050
"""
@@ -199,9 +199,16 @@ def num_known_dims(ranked_tensor_type):
199199
for idx in range(len(self.inputs)):
200200
attr = {}
201201
if self.shapes:
202-
attr["tensorrt.shape_profile"] = ir.Attribute.parse(
202+
attr_name = (
203+
"tensorrt.value_bounds"
204+
if isinstance(self.shapes[idx], ValueBounds)
205+
else "tensorrt.shape_profile"
206+
)
207+
attr[attr_name] = ir.Attribute.parse(
203208
f"#tensorrt.shape_profile<min={list(self.shapes[idx].min)}, opt={list(self.shapes[idx].opt)}, max={list(self.shapes[idx].max)}>"
204209
)
210+
if idx == 1:
211+
attr["tensorrt.host_tensor"] = ir.UnitAttr.get()
205212

206213
arg_attrs.append(ir.DictAttr.get(attr))
207214

tripy/test_out_of_bound.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
import nvtripy as tp
16+
import cupy as cp
17+
18+
from nvtripy.logging import logger
19+
20+
logger.verbosity = "ir"
21+
import mlir_tensorrt.runtime.api as runtime
22+
23+
24+
def func(x):
25+
x = x + x
26+
return x
27+
28+
29+
compiled_func = tp.compile(func, args=[tp.InputInfo(shape=((2, 4, 6), 4), dtype=tp.float32)])
30+
31+
sig = compiled_func._executable_signature
32+
33+
for idx in range(2):
34+
35+
arg = sig.get_arg(idx)
36+
memref = runtime.MemRefType(arg)
37+
print(f"Arg {idx}: ", memref.address_space)
38+
39+
print("Shape: ", memref.shape)
40+
bound = sig.get_arg_bound(idx)
41+
print(f"Bound: {bound.min()}, {bound.max()}")
42+
43+
# inp = cp.ones((8, 4), dtype=cp.float32)
44+
# inp = tp.Tensor(inp)
45+
# out = compiled_func(inp)

tripy/test_shape_input.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
import nvtripy as tp
16+
import cupy as cp
17+
18+
from nvtripy.logging import logger
19+
20+
logger.verbosity = "ir"
21+
import mlir_tensorrt.runtime.api as runtime
22+
23+
24+
def func(x, y):
25+
x = x + x
26+
x = tp.reshape(x, (-1, y))
27+
return x
28+
29+
30+
compiled_func = tp.compile(
31+
func, args=[tp.InputInfo(shape=((2, 4, 6), 4), dtype=tp.float32), tp.DimensionInputInfo(value_bounds=(1, 2, 3))]
32+
)
33+
34+
print("compilation complete.")
35+
36+
sig = compiled_func._executable_signature
37+
38+
input_info = compiled_func._get_input_info()
39+
print(input_info[0])
40+
print(input_info[1])
41+
for idx in range(2):
42+
43+
arg = sig.get_arg(idx)
44+
memref = runtime.MemRefType(arg)
45+
print(f"Arg {idx}: ", memref.address_space)
46+
47+
print("Shape: ", memref.shape)
48+
bound = sig.get_arg_bound(idx)
49+
print(f"Bound: {bound.min()}, {bound.max()}")
50+
51+
52+
# import pdb
53+
# pdb.set_trace()
54+
55+
56+
inp = cp.ones((4, 4), dtype=cp.float32)
57+
inp = tp.Tensor(inp)
58+
dim_inp = tp.DimensionSize(2)
59+
out = compiled_func(inp, dim_inp)
60+
print(out)

0 commit comments

Comments
 (0)