
Commit c21f8aa

Add stride field to Tripy tensor representations
MLIR-TensorRT requires the strides of function arguments and results to be in canonical order. For Tripy compile mode, we therefore need to provide correct stride information for the compiled function's inputs; MLIR-TensorRT already has a runtime check that constrains strides to canonical order. For Tripy lazy mode, function inputs are materialized into data memrefs (or stablehlo.constant), so a compile-time check ensures strides are in canonical order.

This PR adds the following changes:

- Add a stride field to Tripy tensor representations.
- Update the storage op to store the stride field from framework tensors and add a compile-time check.
- Add get_stride and is_canonical_stride helper methods.
- Add a compile-time check while lowering to stablehlo.constant.
- Add test_stride.py validating both compile-time and runtime stride constraints.
1 parent 59b9536 commit c21f8aa
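For context, the canonical stride order enforced by these checks can be read off the new TestStride expectations in tripy/tests/utils/test_utils.py below. The following is an illustrative sketch only, assuming a row-major layout in which size-0 and size-1 dimensions are canonicalized to stride 1; it is not the actual tripy.utils.get_canonical_stride implementation:

    # Sketch of canonical (row-major) strides as implied by the test expectations
    # below; the real helper in tripy.utils may be implemented differently.
    from typing import Sequence, Tuple

    def canonical_stride_sketch(shape: Sequence[int]) -> Tuple[int, ...]:
        strides = []
        for i, dim in enumerate(shape):
            if dim <= 1:
                # Strides of size-0/size-1 dims never affect addressing; use 1.
                strides.append(1)
            else:
                running = 1
                for d in shape[i + 1:]:
                    running *= max(d, 1)
                strides.append(running)
        return tuple(strides)

    assert canonical_stride_sketch((2, 3)) == (3, 1)
    assert canonical_stride_sketch((3, 0, 0, 5)) == (5, 1, 1, 1)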

20 files changed: +285 −19 lines

tripy/tests/backend/test_compiler_api.py

Lines changed: 1 addition & 2 deletions
@@ -185,8 +185,7 @@ def test_function(self):
         inp = tp.ones((2, 2), dtype=tp.float32)
         out = compiled_gelu(inp)
 
-        # TODO (#225): Replace with tp.all
-        assert cp.array_equal(cp.from_dlpack(out), cp.from_dlpack(tp.relu(inp)))
+        assert tp.allclose(out, tp.relu(inp), rtol=0.0, atol=0.0)
 
     def test_module(self):
         layernorm = tp.LayerNorm(2)

tripy/tests/flat_ir/ops/test_constant.py

Lines changed: 4 additions & 1 deletion
@@ -30,7 +30,10 @@ def test_str(self):
 
         const = flat_ir.ops[-1]
         assert isinstance(const, ConstantOp)
-        assert str(const) == "out: [rank=(1), shape=((2,)), dtype=(float32), loc=(gpu:0)] = ConstantOp(data=[2.0, 3.0])"
+        assert (
+            str(const)
+            == "out: [rank=(1), shape=((2,)), stride=((1,)), dtype=(float32), loc=(gpu:0)] = ConstantOp(data=[2.0, 3.0])"
+        )
 
     def test_mlir(self):
         out = tp.Tensor([2, 3], dtype=tp.int32, name="out")
test_stride.py (new file)

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
+#
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pytest
+import re
+import torch
+
+import tripy as tp
+from tripy.utils import are_strides_equivalent
+from tests.helper import raises
+
+
+class TestStride:
+
+    def test_non_canonical_stride(self):
+        t = torch.arange(12, dtype=torch.float32).reshape(3, 4)
+        a = tp.Tensor(t)
+        assert a.stride == t.stride()
+
+        a = t.transpose(0, 1)
+        with raises(
+            tp.TripyException,
+            match=re.escape("Non-canonical strides are not supported for Tripy tensors."),
+        ):
+            print(tp.Tensor(a))
+
+    @pytest.mark.parametrize(
+        "shape",
+        [
+            (0,),
+            (0, 3),
+            (2, 0, 4),
+            (3, 0, 0, 5),
+            (1,),
+            (1, 3),
+            (2, 1, 4),
+            (1, 3, 1, 5),
+            (3, 1, 1, 1),
+            (0, 1, 3),
+            (2, 0, 1, 4),
+            (1, 0, 3, 1, 5),
+            (3, 1, 0, 2, 1),
+        ],
+    )
+    def test_tensor_stride(self, shape):
+        torch_tensor = torch.empty(shape)
+        torch_stride = torch_tensor.stride()
+        tripy_stride = tp.Tensor(torch_tensor).stride
+
+        assert are_strides_equivalent(
+            shape, tripy_stride, torch_stride
+        ), f"Mismatch for shape {shape}. Calculated: {tripy_stride}, Torch: {torch_stride}"

tripy/tests/frontend/trace/test_trace.py

Lines changed: 8 additions & 8 deletions
@@ -95,8 +95,8 @@ def test_str(self):
             str(trace)
             == dedent(
                 """
-                a = storage(data=[0], shape=(1,), dtype=int32, device=gpu:0)
-                b = storage(data=[1], shape=(1,), dtype=int32, device=gpu:0)
+                a = storage(data=[0], shape=(1,), stride=(1,), dtype=int32, device=gpu:0)
+                b = storage(data=[1], shape=(1,), stride=(1,), dtype=int32, device=gpu:0)
                 c = a + b
                 outputs:
                 c: [rank=(1), dtype=(int32), loc=(gpu:0)]
@@ -133,8 +133,8 @@ def test_multiple_outputs(self):
             str(trace)
             == dedent(
                 """
-                a = storage(data=[1.0000], shape=(1,), dtype=float32, device=gpu:0)
-                b = storage(data=[1.0000], shape=(1,), dtype=float32, device=gpu:0)
+                a = storage(data=[1.0000], shape=(1,), stride=(1,), dtype=float32, device=gpu:0)
+                b = storage(data=[1.0000], shape=(1,), stride=(1,), dtype=float32, device=gpu:0)
                 c = a + b
                 d = c + c
                 outputs:
@@ -168,8 +168,8 @@ def test_all_inputs(self):
             == dedent(
                 """
                 inputs:
-                a: [rank=(1), shape=((1,)), dtype=(float32), loc=(gpu:0)]
-                b: [rank=(1), shape=((1,)), dtype=(float32), loc=(gpu:0)]
+                a: [rank=(1), shape=((1,)), stride=((1,)), dtype=(float32), loc=(gpu:0)]
+                b: [rank=(1), shape=((1,)), stride=((1,)), dtype=(float32), loc=(gpu:0)]
                 c = a + b
                 outputs:
                 c: [rank=(1), dtype=(float32), loc=(gpu:0)]
@@ -191,8 +191,8 @@ def test_const_and_input(self):
             == dedent(
                 """
                 inputs:
-                a: [rank=(1), shape=((1,)), dtype=(float32), loc=(gpu:0)]
-                b = storage(data=[1.0000], shape=(1,), dtype=float32, device=gpu:0)
+                a: [rank=(1), shape=((1,)), stride=((1,)), dtype=(float32), loc=(gpu:0)]
+                b = storage(data=[1.0000], shape=(1,), stride=(1,), dtype=float32, device=gpu:0)
                 c = a + b
                 outputs:
                 c: [rank=(1), dtype=(float32), loc=(gpu:0)]

tripy/tests/integration/test_allclose.py

Lines changed: 2 additions & 2 deletions
@@ -35,8 +35,8 @@ class TestAllClose:
         ],
     )
     def test_all_close_float32(self, tensor_a, tensor_b, rtol, atol):
-        np_result = torch.allclose(torch.FloatTensor(tensor_a), torch.FloatTensor(tensor_b), rtol=rtol, atol=atol)
+        torch_result = torch.allclose(torch.FloatTensor(tensor_a), torch.FloatTensor(tensor_b), rtol=rtol, atol=atol)
         tp_result = tp.allclose(
             tp.Tensor(tensor_a, dtype=tp.float32), tp.Tensor(tensor_b, dtype=tp.float32), rtol=rtol, atol=atol
         )
-        assert np_result == tp_result
+        assert torch_result == tp_result

tripy/tests/integration/test_quantize.py

Lines changed: 1 addition & 1 deletion
@@ -118,4 +118,4 @@ def test_non_constant_scale(self):
         scale = tp.ones((4,))
         quantized = tp.quantize(input, scale, tp.int8, dim=0)
 
-        assert bool(tp.all(quantized == tp.ones((4, 4), dtype=tp.int8)))
+        assert tp.allclose(quantized, tp.ones((4, 4), dtype=tp.int8), rtol=0.0, atol=0.0)

tripy/tests/utils/test_utils.py

Lines changed: 47 additions & 0 deletions
@@ -21,6 +21,7 @@
 from tripy import utils
 from tests import helper
 from collections import defaultdict
+from typing import Optional, Sequence
 
 
 class TestMd5:
@@ -109,3 +110,49 @@ def test_gen_uid(self, inputs, outputs, expected_prefix):
     def test_uniqueness(self):
         uids = [utils.UniqueNameGen.gen_uid() for _ in range(100)]
         assert len(set(uids)) == 100
+
+
+class TestStride:
+    @pytest.mark.parametrize(
+        "shape, provided_stride, expected_stride",
+        [
+            ((), None, ()),
+            ((1,), None, (1,)),
+            ((2, 3), None, (3, 1)),
+            ((0, 5), None, (1, 1)),
+            ((1, 0, 3), None, (1, 1, 1)),
+            ((2, 1, 4), None, (4, 1, 1)),
+            ((3, 0, 0, 5), None, (5, 1, 1, 1)),
+            ((2, 3), (3, 1), (3, 1)),
+            ((0, 5), (5, 1), (5, 1)),
+            ((1, 0, 3), (3, 1, 1), (3, 1, 1)),
+            (None, None, None),
+        ],
+    )
+    def test_get_stride(
+        self,
+        shape: Optional[Sequence[int]],
+        provided_stride: Optional[Sequence[int]],
+        expected_stride: Optional[Sequence[int]],
+    ):
+        """Test both get_stride and get_canonical_stride functions."""
+        assert utils.get_stride(shape, provided_stride) == expected_stride
+        if provided_stride is None and shape is not None:
+            assert utils.get_canonical_stride(shape) == expected_stride
+
+    @pytest.mark.parametrize(
+        "shape, stride, expected_stride, expected_result",
+        [
+            ((2, 3), (3, 1), (3, 1), True),
+            ((2, 3), (6, 2), (3, 1), False),
+            ((0, 5), (1, 1), (5, 1), True),
+            ((1, 0, 3), (0, 3, 1), (3, 3, 1), True),
+            ((2, 1, 4), (8, 4, 1), (4, 1, 1), False),
+            ((2, 3), (3, 1, 2), (3, 1), False),  # Mismatched lengths
+        ],
+    )
+    def test_are_strides_equivalent(
+        self, shape: Sequence[int], stride: Sequence[int], expected_stride: Sequence[int], expected_result: bool
+    ):
+        """Test are_strides_equivalent function."""
+        assert utils.are_strides_equivalent(shape, stride, expected_stride) == expected_result

tripy/tripy/backend/api/executable.py

Lines changed: 4 additions & 0 deletions
@@ -161,6 +161,10 @@ def add(a, b):
                     tensor,
                 ],
             )
+        elif "Runtime stride mismatch" in str(err):
+            # Just raise the error for now.
+            raise raise_error(str(err))
+
         raise
 
     from tripy.utils.stack_info import StackInfo

tripy/tripy/backend/mlir/executor.py

Lines changed: 4 additions & 1 deletion
@@ -114,7 +114,9 @@ def _get_output_tensor_info(self, outputs_runtime_shape, output_devices):
             is_static_shape = all(dim >= 0 for dim in memref.shape)
             if is_static_shape:
                 outputs_tensor_info.append(
-                    TensorInfo(len(memref.shape), tuple(memref.shape), dtype, device(device_type))
+                    TensorInfo(
+                        len(memref.shape), tuple(memref.shape), tuple(memref.strides), dtype, device(device_type)
+                    )
                 )
             else:
                 runtime_shape = [
@@ -124,6 +126,7 @@ def _get_output_tensor_info(self, outputs_runtime_shape, output_devices):
                     TensorInfo(
                         len(runtime_shape),
                         tuple(runtime_shape),
+                        tuple(memref.strides),
                         dtype,
                         device(device_type),
                     )

tripy/tripy/backend/utils.py

Lines changed: 3 additions & 1 deletion
@@ -27,6 +27,7 @@
 class TensorInfo:
     rank: int
     shape: Sequence[int]
+    stride: Sequence[int]
     dtype: "tripy.dtype"
     device: "tripy.device"
 
@@ -36,11 +37,12 @@ def encode(tensor_info: TensorInfo) -> Dict[str, Any]:
     return {
         "rank": tensor_info.rank,
         "shape": tensor_info.shape,
+        "stride": tensor_info.stride,
         "dtype": tensor_info.dtype,
         "device": tensor_info.device,
     }
 
 
 @Decoder.register(TensorInfo)
 def decode(dct: Dict[str, Any]) -> TensorInfo:
-    return TensorInfo(dct["rank"], dct["shape"], dct["dtype"], dct["device"])
+    return TensorInfo(dct["rank"], dct["shape"], dct["stride"], dct["dtype"], dct["device"])
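With the stride field threaded through encode and decode above, a serialized TensorInfo now records the stride alongside the shape. The dict below is illustrative only (hypothetical values; dtype and device shown as strings for brevity):

    # Illustrative shape of the dict produced by the encode hook above for a
    # hypothetical 1-element float32 tensor on gpu:0.
    encoded = {
        "rank": 1,
        "shape": (1,),
        "stride": (1,),
        "dtype": "float32",
        "device": "gpu:0",
    }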
