Skip to content

Commit 863ff90

Browse files
Support DimensionInputInfo in tp.compile (#618)
Signed-off-by: yizhuoz004 <[email protected]> Co-authored-by: pranavm-nvidia <[email protected]>
1 parent 3c4d177 commit 863ff90

File tree

20 files changed

+237
-65
lines changed

20 files changed

+237
-65
lines changed

tripy/docs/post0_developer_guides/02-debugging.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ We include some environment variables to enable extra debugging information from
1212
- `export TRIPY_MLIR_DEBUG_PATH=<mlir-debug-path>` sets the directory for IR dumps. The default path is `mlir-dumps`.
1313
- `export TRIPY_TRT_DEBUG_ENABLED=1` will dump TensorRT engines and their layer information.
1414
- `export TRIPY_TRT_DEBUG_PATH=<trt-debug-path>` sets the directory for TensorRT dumps. The default path is `tensorrt-dumps`.
15+
- `export MTRT_TENSORRT_NVTX=DETAILED` will enable detailed NVTX profiling verbosity for TRT layers.
1516

1617

1718
## Using A Debugger

tripy/examples/segment-anything-model-v2/sam2/build_sam.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,7 @@ def get_component_configs(model, cfg):
8181
(seq_len, mem_attention_batch, 64),
8282
getattr(tp, model_precision),
8383
),
84-
# TODO (#594): Remove this hack once we are able to pass in DimensionSizes directly:
85-
tp.InputInfo(((4, 16, 64),), tp.int32),
84+
tp.DimensionInputInfo(value_bounds=(4, 16, 64)),
8685
],
8786
"skip_dtype_convert": [],
8887
},

tripy/examples/segment-anything-model-v2/sam2/modeling/memory_attention.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,10 +186,8 @@ def forward(
186186
memory: tp.Tensor, # cross-attention inputs
187187
curr_pos: Optional[tp.Tensor] = None, # pos_enc for self-attention inputs
188188
memory_pos: Optional[tp.Tensor] = None, # pos_enc for cross-attention inputs
189-
num_obj_ptr_tokens: Optional[tp.Tensor] = None, # number of object pointer *tokens*
189+
num_obj_ptr_tokens: Optional[tp.DimensionSize] = None, # number of object pointer *tokens*
190190
):
191-
# TODO (#594): Remove this hack once we are able to pass in DimensionSizes directly:
192-
num_obj_ptr_tokens = num_obj_ptr_tokens.shape[0]
193191
output = curr
194192
if self.pos_enc_at_input and curr_pos is not None:
195193
output = output + 0.1 * curr_pos

tripy/examples/segment-anything-model-v2/sam2/modeling/sam2_base.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -242,8 +242,6 @@ def _build_sam_heads(self):
242242
else:
243243
self.obj_ptr_tpos_proj = torch.nn.Identity()
244244

245-
self.fake_object_ptrs = torch.ones((1,), dtype=torch.int32, device="cuda")
246-
247245
def _forward_sam_heads(
248246
self,
249247
backbone_features,
@@ -667,14 +665,12 @@ def _prepare_memory_conditioned_features(
667665
memory = torch.cat(to_cat_memory, dim=0)
668666
memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)
669667
if isinstance(self.memory_attention, tp.Module) or isinstance(self.memory_attention, tp.Executable):
670-
if self.fake_object_ptrs.shape != (num_obj_ptr_tokens,):
671-
self.fake_object_ptrs = torch.ones((num_obj_ptr_tokens,), dtype=torch.int32, device="cuda")
672668
pix_feat_with_mem = self.memory_attention(
673669
curr=tp.Tensor(current_vision_feats[0].half().contiguous()),
674670
memory=tp.Tensor(memory.half().contiguous()),
675671
curr_pos=tp.Tensor(current_vision_pos_embeds[0].half().contiguous()),
676672
memory_pos=tp.Tensor(memory_pos_embed.half().contiguous()),
677-
num_obj_ptr_tokens=tp.Tensor(self.fake_object_ptrs),
673+
num_obj_ptr_tokens=tp.DimensionSize(num_obj_ptr_tokens),
678674
)
679675
else:
680676
pix_feat_with_mem = self.memory_attention(

tripy/nvtripy/backend/api/shape_bounds.py renamed to tripy/nvtripy/backend/api/bounds.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,37 +25,37 @@
2525

2626
@export.public_api(document_under="compiling_code/input_info", document_init_sig=False)
2727
@dataclass
28-
class ShapeBounds:
28+
class Bounds:
2929
min: Tuple[IntLike]
3030
"""
31-
The minimum shape.
31+
The minimum value.
3232
"""
3333
opt: Tuple[IntLike]
3434
"""
35-
The shape to optimize for.
35+
The value to optimize for.
3636
"""
3737
max: Tuple[IntLike]
3838
"""
39-
The maximum shape.
39+
The maximum value.
4040
"""
4141

4242
def is_static(self):
4343
return self.min == self.opt == self.max
4444

4545

46-
@json_utils.Encoder.register(ShapeBounds)
47-
def encode_shape_bounds(shape_bounds):
46+
@json_utils.Encoder.register(Bounds)
47+
def encode_bounds(bounds):
4848
return {
49-
"min": shape_bounds.min,
50-
"opt": shape_bounds.opt,
51-
"max": shape_bounds.max,
49+
"min": bounds.min,
50+
"opt": bounds.opt,
51+
"max": bounds.max,
5252
}
5353

5454

55-
@json_utils.Decoder.register(ShapeBounds)
56-
def decode_shape_bounds(shape_bounds_dict):
57-
return ShapeBounds(
58-
min=tuple(shape_bounds_dict["min"]),
59-
opt=tuple(shape_bounds_dict["opt"]),
60-
max=tuple(shape_bounds_dict["max"]),
55+
@json_utils.Decoder.register(Bounds)
56+
def decode_bounds(bounds_dict):
57+
return Bounds(
58+
min=tuple(bounds_dict["min"]),
59+
opt=tuple(bounds_dict["opt"]),
60+
max=tuple(bounds_dict["max"]),
6161
)

tripy/nvtripy/backend/api/compile.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from nvtripy import constants, export, utils
2222
from nvtripy.backend.api.executable import Executable
23-
from nvtripy.backend.api.input_info import InputInfo
23+
from nvtripy.backend.api.input_info import InputInfo, DimensionInputInfo
2424
from nvtripy.backend.mlir import Compiler
2525
from nvtripy.common.exception import raise_error
2626
from nvtripy.frontend import Tensor, Trace
@@ -106,6 +106,30 @@ def add(a, b):
106106
107107
big_out = compiled_add(big_a, big_b)
108108
109+
.. code-block:: python
110+
:linenos:
111+
:caption: Shape Input
112+
113+
def dynamic_reshape(x, s):
114+
return tp.reshape(x, (-1, s))
115+
116+
# doc: no-print-locals compiled_reshape
117+
118+
# Support dynamic dim in the range of 1 to 4, optimizing for a
119+
# dim value of 2
120+
compiled_reshape = tp.compile(
121+
dynamic_reshape,
122+
args=[
123+
tp.InputInfo(shape=(3, (2, 4, 6)), dtype=tp.float32),
124+
tp.DimensionInputInfo(value_bounds=(1, 2, 4)),
125+
],
126+
)
127+
128+
a = tp.ones((3, 4), dtype=tp.float32).eval()
129+
s = tp.DimensionSize(2)
130+
131+
out = compiled_reshape(a, s)
132+
assert out.shape == (6, 2)
109133
110134
.. code-block:: python
111135
:linenos:
@@ -162,6 +186,22 @@ def process_arg(name, arg):
162186
input_names.add(name)
163187

164188
return tensor
189+
190+
if isinstance(arg, DimensionInputInfo):
191+
from nvtripy.frontend.dimension_size import DimensionSize
192+
193+
input_infos[name] = arg
194+
195+
tensor = DimensionSize(arg.value_bounds.opt[0])
196+
tensor.name = name
197+
tensor.trace_tensor.is_compile_tracer = True
198+
assert tensor.trace_tensor.shape == ()
199+
200+
trace_input_map[name] = tensor
201+
input_names.add(name)
202+
203+
return tensor
204+
165205
return arg
166206

167207
compiled_arg_names = []

tripy/nvtripy/backend/api/executable.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import mlir_tensorrt.runtime.api as runtime
2020
from nvtripy import config, export
21-
from nvtripy.backend.api.input_info import InputInfo
21+
from nvtripy.backend.api.input_info import InputInfo, DimensionInputInfo
2222
from nvtripy.backend.api.stream import default_stream
2323
from nvtripy.backend.mlir.utils import MLIRRuntimeClient
2424
from nvtripy.common.exception import raise_error
@@ -41,7 +41,11 @@ class Executable:
4141
# `return_single_tensor_as_sequence` indicates whether the return type should be a sequence even if
4242
# there is only one output.
4343
def __init__(
44-
self, executable, arg_names, return_single_tensor_as_sequence: bool, input_infos: Dict[str, InputInfo]
44+
self,
45+
executable,
46+
arg_names,
47+
return_single_tensor_as_sequence: bool,
48+
input_infos: Dict[str, Union[InputInfo, DimensionInputInfo]],
4549
):
4650
self._executable = executable
4751

@@ -69,7 +73,7 @@ def __init__(
6973

7074
self.__signature__ = inspect.Signature(params, return_annotation=return_annotation)
7175

72-
self.input_infos: Dict[str, InputInfo] = input_infos
76+
self.input_infos: Dict[str, Union[InputInfo, DimensionInputInfo]] = input_infos
7377
"""
7478
Stores metadata, like shapes and data types, for each input to the executable.
7579
"""
@@ -191,15 +195,16 @@ def add(a, b):
191195
],
192196
)
193197

194-
for tensor in input_tensors:
198+
expected_devices = ["gpu" if isinstance(info, InputInfo) else "cpu" for info in self.input_infos.values()]
199+
for tensor, expected_device, arg_name in zip(input_tensors, expected_devices, self._arg_names):
195200
producer = tensor.trace_tensor.producer
196-
if not isinstance(producer, Constant) or tensor.device.kind != "gpu":
201+
if not isinstance(producer, Constant):
202+
raise_error(f"Tensor `{arg_name}` is not evaluated.", ["Hint: Try calling `.eval()` on the tensor."])
203+
if tensor.device.kind != expected_device:
197204
raise_error(
198-
"Inputs to compiled executables must be evaluated tensors on the GPU.",
205+
"Unexpected tensor device.",
199206
[
200-
"Got input" + (f" on device '{tensor.device}':" if tensor.device.kind != "gpu" else ":"),
201-
tensor,
202-
"Hint: Try calling `.eval()` on the tensor to ensure it is a GPU constant.",
207+
f"For tensor: `{arg_name}`, expected to be on device: {expected_device} but got: {tensor.device.kind}.\n",
203208
],
204209
)
205210

@@ -212,7 +217,11 @@ def add(a, b):
212217
# TODO: Evaluate whether this should be moved into the executor
213218
if "function expects a memref type with element type" in str(err):
214219
# If the problem is a mismatched data type, we can provide a better error message than the executor can.
215-
expected_input_dtypes = [info.dtype for info in self.input_infos.values()]
220+
from nvtripy.common.datatype import int32
221+
222+
expected_input_dtypes = [
223+
info.dtype if isinstance(info, InputInfo) else int32 for info in self.input_infos.values()
224+
]
216225
for tensor, dtype, arg_name in zip(input_tensors, expected_input_dtypes, self._arg_names):
217226
if tensor.dtype != dtype:
218227
raise_error(
@@ -225,7 +234,9 @@ def add(a, b):
225234
),
226235
)
227236
elif "InternalError: failed to set input shape" in str(err) or "Runtime shape mismatch" in str(err):
228-
expected_input_shapes = [info.shape_bounds for info in self.input_infos.values()]
237+
expected_input_shapes = [
238+
info.shape_bounds if isinstance(info, InputInfo) else tuple() for info in self.input_infos.values()
239+
]
229240
for tensor, expected_bounds, arg_name in zip(input_tensors, expected_input_shapes, self._arg_names):
230241
shape = tensor.shape
231242

tripy/nvtripy/backend/api/input_info.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from nvtripy import export
1818
from nvtripy.backend.api.named_dimension import NamedDimension
19-
from nvtripy.backend.api.shape_bounds import ShapeBounds
19+
from nvtripy.backend.api.bounds import Bounds
2020
from nvtripy.frontend.dimension_size import DimensionSize
2121
from nvtripy.types import IntLike
2222
from nvtripy.utils import json as json_utils
@@ -74,7 +74,6 @@ def __init__(
7474
"""
7575
is_int_like = lambda arg: any(isinstance(arg, typ) for typ in {int, DimensionSize})
7676

77-
# TODO (#252): Allow `shape` to be a shape tensor
7877
min_shape = []
7978
opt_shape = []
8079
max_shape = []
@@ -98,7 +97,7 @@ def __init__(
9897
A mapping of dimension indices to their names, if set.
9998
"""
10099

101-
self.shape_bounds: ShapeBounds = ShapeBounds(tuple(min_shape), tuple(opt_shape), tuple(max_shape))
100+
self.shape_bounds: Bounds = Bounds(tuple(min_shape), tuple(opt_shape), tuple(max_shape))
102101
"""
103102
The shape bounds of the input.
104103
"""
@@ -129,3 +128,51 @@ def decode_input_info(input_info_dict):
129128
input_info.shape_bounds = input_info_dict["shape_bounds"]
130129
input_info.dimension_names = {int(k): v for k, v in input_info_dict.get("dimension_names", {}).items()}
131130
return input_info
131+
132+
133+
@export.public_api(document_under="compiling_code")
134+
class DimensionInputInfo:
135+
"""
136+
Captures information about a dimension size input to a compiled function.
137+
"""
138+
139+
def __init__(self, value_bounds: Tuple[IntLike, IntLike, IntLike]) -> None:
140+
"""
141+
Args:
142+
value_bounds: The value bounds of the dimension size input, consisting of minimum, optimum, and maximum values.
143+
144+
.. code-block:: python
145+
:linenos:
146+
:caption: Dimension Size Input
147+
148+
# The dimension size will support values in the range [1, 3],
149+
# optimizing for a size of 2.
150+
dim_inp = tp.DimensionInputInfo((1, 2, 3))
151+
assert dim_inp.value_bounds.min == (1,)
152+
assert dim_inp.value_bounds.opt == (2,)
153+
assert dim_inp.value_bounds.max == (3,)
154+
"""
155+
# Evaluate `DimensionSize` early to avoid duplicate evaluation
156+
value_bounds = tuple(map(int, value_bounds))
157+
self.value_bounds = Bounds(
158+
min=tuple([value_bounds[0]]), opt=tuple([value_bounds[1]]), max=tuple([value_bounds[2]])
159+
)
160+
161+
def __str__(self) -> str:
162+
return (
163+
f"DimensionInputInfo(min={self.value_bounds.min}, opt={self.value_bounds.opt}, max={self.value_bounds.max})"
164+
)
165+
166+
167+
@json_utils.Encoder.register(DimensionInputInfo)
168+
def encode_dim_input_info(dim_input_info):
169+
return {
170+
"value_bounds": dim_input_info.value_bounds,
171+
}
172+
173+
174+
@json_utils.Decoder.register(DimensionInputInfo)
175+
def decode_dim_input_info(dim_input_info_dict):
176+
dim_input_info = DimensionInputInfo((-1, -1, -1))
177+
dim_input_info.value_bounds = dim_input_info_dict["value_bounds"]
178+
return dim_input_info

tripy/nvtripy/frontend/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#
2-
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33
# SPDX-License-Identifier: Apache-2.0
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,4 +16,5 @@
1616
#
1717

1818
from nvtripy.frontend.tensor import Tensor
19+
from nvtripy.frontend.dimension_size import DimensionSize
1920
from nvtripy.trace.trace import Trace

tripy/nvtripy/frontend/dimension_size.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,9 @@
1515
# limitations under the License.
1616
#
1717

18-
from typing import Optional, Union
18+
from typing import Optional
1919

2020
from nvtripy import export
21-
from nvtripy.common.datatype import int32
2221
from nvtripy.frontend.tensor import Tensor
2322

2423

@@ -47,7 +46,27 @@ def __str__(self) -> str:
4746
assert isinstance(val, int)
4847
return str(val)
4948

50-
def eval(self) -> "nvtripy.Tensor":
49+
def eval(self) -> "nvtripy.DimensionSize":
50+
"""
51+
Immediately evaluates this ``DimensionSize`` object.
52+
53+
.. note:: ``DimensionSize`` will always reside on the host even after it is evaluated.
54+
55+
Returns:
56+
The evaluated ``DimensionSize``.
57+
58+
.. code-block:: python
59+
:linenos:
60+
61+
62+
dim_size = tp.ones((2, 2)).shape[0]
63+
dim_size.eval()
64+
print(dim_size.device)
65+
assert dim_size.device.kind == "cpu"
66+
67+
"""
68+
from nvtripy.backend.mlir import memref
69+
from nvtripy.trace.ops.constant import Constant
5170
from nvtripy.trace.ops.shape import GetDimensionSize, Shape
5271

5372
# TODO (#593): Generalize this to any branchy graph:
@@ -62,4 +81,9 @@ def eval(self) -> "nvtripy.Tensor":
6281
dim_size.outputs[0].is_compile_tracer = self.trace_tensor.is_compile_tracer
6382
self.trace_tensor = dim_size.outputs[0]
6483

65-
return super().eval()
84+
if not isinstance(producer, Constant):
85+
super().eval()
86+
dim_value = memref.tolist(self.trace_tensor.producer.data)
87+
dim_size = DimensionSize(data=int(dim_value), name=self.name)
88+
self.trace_tensor = dim_size.trace_tensor
89+
return self

0 commit comments

Comments
 (0)