Skip to content

Commit bb8174c

Browse files
committed
Address comments
1 parent 5c95436 commit bb8174c

File tree

11 files changed

+64
-93
lines changed

11 files changed

+64
-93
lines changed

tripy/nvtripy/backend/api/shape_bounds.py renamed to tripy/nvtripy/backend/api/bounds.py

Lines changed: 15 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -25,62 +25,37 @@
2525

2626
@export.public_api(document_under="compiling_code/input_info", document_init_sig=False)
2727
@dataclass
28-
class ShapeBounds:
28+
class Bounds:
2929
min: Tuple[IntLike]
3030
"""
31-
The minimum shape.
31+
The minimum value.
3232
"""
3333
opt: Tuple[IntLike]
3434
"""
35-
The shape to optimize for.
35+
The value to optimize for.
3636
"""
3737
max: Tuple[IntLike]
3838
"""
39-
The maximum shape.
39+
The maximum value.
4040
"""
4141

4242
def is_static(self):
4343
return self.min == self.opt == self.max
4444

4545

46-
@json_utils.Encoder.register(ShapeBounds)
47-
def encode_shape_bounds(shape_bounds):
46+
@json_utils.Encoder.register(Bounds)
47+
def encode_bounds(bounds):
4848
return {
49-
"min": shape_bounds.min,
50-
"opt": shape_bounds.opt,
51-
"max": shape_bounds.max,
49+
"min": bounds.min,
50+
"opt": bounds.opt,
51+
"max": bounds.max,
5252
}
5353

5454

55-
@json_utils.Decoder.register(ShapeBounds)
56-
def decode_shape_bounds(shape_bounds_dict):
57-
return ShapeBounds(
58-
min=tuple(shape_bounds_dict["min"]),
59-
opt=tuple(shape_bounds_dict["opt"]),
60-
max=tuple(shape_bounds_dict["max"]),
61-
)
62-
63-
64-
@dataclass
65-
class ValueBounds:
66-
min: Tuple[IntLike]
67-
opt: Tuple[IntLike]
68-
max: Tuple[IntLike]
69-
70-
71-
@json_utils.Encoder.register(ValueBounds)
72-
def encode_value_bounds(value_bounds):
73-
return {
74-
"min": tuple(value_bounds.min),
75-
"opt": tuple(value_bounds.opt),
76-
"max": tuple(value_bounds.max),
77-
}
78-
79-
80-
@json_utils.Decoder.register(ValueBounds)
81-
def decode_value_bounds(value_bounds_dict):
82-
return ValueBounds(
83-
min=tuple(value_bounds_dict["min"]),
84-
opt=tuple(value_bounds_dict["opt"]),
85-
max=tuple(value_bounds_dict["max"]),
55+
@json_utils.Decoder.register(Bounds)
56+
def decode_bounds(bounds_dict):
57+
return Bounds(
58+
min=tuple(bounds_dict["min"]),
59+
opt=tuple(bounds_dict["opt"]),
60+
max=tuple(bounds_dict["max"]),
8661
)

tripy/nvtripy/backend/api/input_info.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from nvtripy import export
1818
from nvtripy.backend.api.named_dimension import NamedDimension
19-
from nvtripy.backend.api.shape_bounds import ShapeBounds, ValueBounds
19+
from nvtripy.backend.api.bounds import Bounds
2020
from nvtripy.frontend.dimension_size import DimensionSize
2121
from nvtripy.types import IntLike
2222
from nvtripy.utils import json as json_utils
@@ -97,7 +97,7 @@ def __init__(
9797
A mapping of dimension indices to their names, if set.
9898
"""
9999

100-
self.shape_bounds: ShapeBounds = ShapeBounds(tuple(min_shape), tuple(opt_shape), tuple(max_shape))
100+
self.shape_bounds: Bounds = Bounds(tuple(min_shape), tuple(opt_shape), tuple(max_shape))
101101
"""
102102
The shape bounds of the input.
103103
"""
@@ -152,7 +152,9 @@ def __init__(self, value_bounds: Tuple[IntLike, IntLike, IntLike]) -> None:
152152
assert dim_inp.opt == 2
153153
assert dim_inp.max == 3
154154
"""
155-
self.value_bounds = ValueBounds(
155+
# Evaluate `DimensionSize` early to avoid duplicate evaluation
156+
value_bounds = tuple(map(int, value_bounds))
157+
self.value_bounds = Bounds(
156158
min=tuple([value_bounds[0]]), opt=tuple([value_bounds[1]]), max=tuple([value_bounds[2]])
157159
)
158160

@@ -171,5 +173,6 @@ def encode_dim_input_info(dim_input_info):
171173

172174
@json_utils.Decoder.register(DimensionInputInfo)
173175
def decode_dim_input_info(dim_input_info_dict):
174-
dim_input_info_dict.value_bounds = dim_input_info_dict["value_bounds"]
175-
return dim_input_info_dict
176+
dim_input_info = DimensionInputInfo((-1, -1, -1))
177+
dim_input_info.value_bounds = dim_input_info_dict["value_bounds"]
178+
return dim_input_info

tripy/nvtripy/frontend/dimension_size.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ def __str__(self) -> str:
4848
return str(val)
4949

5050
def eval(self) -> "nvtripy.Tensor":
51+
from nvtripy.common import device
5152
from nvtripy.trace.ops.shape import GetDimensionSize, Shape
53+
from nvtripy.frontend.ops.copy import copy
5254

5355
# TODO (#593): Generalize this to any branchy graph:
5456
# If we find a pattern like Shape -> GetDimensionSize, we want to eval the Shape operation
@@ -62,4 +64,4 @@ def eval(self) -> "nvtripy.Tensor":
6264
dim_size.outputs[0].is_compile_tracer = self.trace_tensor.is_compile_tracer
6365
self.trace_tensor = dim_size.outputs[0]
6466

65-
return super().eval()
67+
return copy(super().eval(), device("cpu"))

tripy/nvtripy/frontend/tensor.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,6 @@ def eval(self) -> "nvtripy.Tensor":
181181
"""
182182
Immediately evaluates this tensor. By default, tensors are evaluated lazily.
183183
184-
.. note:: The evaluated tensor will always be in **GPU memory**.
185-
186184
Returns:
187185
The evaluated tensor.
188186

tripy/nvtripy/trace/trace.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,8 @@ def num_known_dims(ranked_tensor_type):
227227
for idx, name in input_info.dimension_names.items()
228228
}
229229
)
230-
elif isinstance(input_info, DimensionInputInfo):
230+
else:
231+
assert isinstance(input_info, DimensionInputInfo)
231232
value_bounds = input_info.value_bounds
232233
attr["tensorrt.value_bounds"] = ir.Attribute.parse(
233234
f"#tensorrt.shape_profile<min={list(value_bounds.min)}, opt={list(value_bounds.opt)}, max={list(value_bounds.max)}>"

tripy/tests/backend/api/conftest.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
import nvtripy as tp
16+
17+
1518
def add(a, b):
1619
return a + b
1720

@@ -38,3 +41,8 @@ def variadic_positional(*args):
3841

3942
def variadic_keyword(**kwargs):
4043
return sum(kwargs.values())
44+
45+
46+
def dynamic_reshape(a, b):
47+
a = a + a
48+
return tp.reshape(a, (-1, b))

tripy/tests/backend/api/test_compile.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,3 +224,20 @@ def func(a):
224224
tp.TripyException, match="Tensors that are not inputs to compiled functions must reside in CPU memory."
225225
):
226226
tp.compile(func, args=[tp.InputInfo((2, 3), dtype=tp.float32)])
227+
228+
def test_dimension_input(self):
229+
dummy = tp.ones((3, 4))
230+
compiled = tp.compile(
231+
dynamic_reshape,
232+
args=[
233+
tp.InputInfo(shape=((2, 4, 6), 4), dtype=tp.float32),
234+
tp.DimensionInputInfo(value_bounds=(2, dummy.shape[1], 6)),
235+
],
236+
)
237+
for reshape_dim in [2, 4, 6]:
238+
inp_cp = cp.arange(12, dtype=cp.float32).reshape((3, 4))
239+
inp = tp.Tensor(inp_cp)
240+
dim_inp = tp.DimensionSize(reshape_dim)
241+
out = compiled(inp, dim_inp)
242+
expected = (inp_cp + inp_cp).reshape((-1, reshape_dim))
243+
assert cp.array_equal(cp.from_dlpack(out), expected)

tripy/tests/backend/api/test_input_info.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,17 @@ def test_dimension_names(self):
7070
assert inp.shape_bounds.max == (3,)
7171
assert inp.dimension_names == {0: "batch"}
7272

73-
def test_serialize(self):
73+
def test_serialize_input_info(self):
7474
batch = tp.NamedDimension("batch", 1, 2, 3)
7575
inp_info = tp.InputInfo(shape=[batch, 3, 28, 28], dtype=tp.float32)
7676

7777
deserialized = json_utils.from_json(json_utils.to_json(inp_info))
7878

7979
assert inp_info.__dict__ == deserialized.__dict__
80+
81+
def test_serialize_dim_input_info(self):
82+
dim_inp_info = tp.DimensionInputInfo(value_bounds=(2, 4, 6))
83+
84+
deserialized = json_utils.from_json(json_utils.to_json(dim_inp_info))
85+
86+
assert dim_inp_info.__dict__ == deserialized.__dict__

tripy/tests/integration/test_dimension_input.py

Lines changed: 0 additions & 40 deletions
This file was deleted.

tripy/tests/trace/test_trace.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,8 +229,8 @@ def test_str_for_dynamic_shapes(self):
229229
== dedent(
230230
r"""
231231
def main(
232-
a : tensor<?xi32:gpu:0> : InputInfo<ShapeBounds(min=(2,), opt=(3,), max=(4,)), dimension names: {0: 'dim'}, dtype: int32>,
233-
b : tensor<?xi32:gpu:0> : InputInfo<ShapeBounds(min=(2,), opt=(3,), max=(4,)), dimension names: {}, dtype: int32>
232+
a : tensor<?xi32:gpu:0> : InputInfo<Bounds(min=(2,), opt=(3,), max=(4,)), dimension names: {0: 'dim'}, dtype: int32>,
233+
b : tensor<?xi32:gpu:0> : InputInfo<Bounds(min=(2,), opt=(3,), max=(4,)), dimension names: {}, dtype: int32>
234234
) -> (
235235
c : tensor<?xi32:gpu:0>
236236
):

0 commit comments

Comments
 (0)