
Commit 40d6c05

Add stride support
1 parent 8d67c62 commit 40d6c05

22 files changed: +214 -24 lines changed

tripy/tests/backend/test_compiler_api.py

Lines changed: 1 addition & 2 deletions
@@ -185,8 +185,7 @@ def test_function(self):
         inp = tp.ones((2, 2), dtype=tp.float32)
         out = compiled_gelu(inp)
 
-        # TODO (#225): Replace with tp.all
-        assert cp.array_equal(cp.from_dlpack(out), cp.from_dlpack(tp.relu(inp)))
+        assert tp.allclose(out, tp.relu(inp), rtol=0.0, atol=0.0)
 
     def test_module(self):
         layernorm = tp.LayerNorm(2)
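With both tolerances set to zero, the usual allclose criterion |a - b| <= atol + rtol * |b| reduces to |a - b| <= 0, i.e. exact element-wise equality, so the new assertion is a drop-in replacement for the array_equal check. A quick illustration with torch (assuming tp.allclose follows the same convention):

    import torch

    a = torch.tensor([1.0, 2.0])
    assert torch.allclose(a, a.clone(), rtol=0.0, atol=0.0)      # exact match passes
    assert not torch.allclose(a, a + 1e-6, rtol=0.0, atol=0.0)   # any difference fails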

tripy/tests/flat_ir/ops/test_constant.py

Lines changed: 4 additions & 1 deletion
@@ -30,7 +30,10 @@ def test_str(self):
 
         const = flat_ir.ops[-1]
         assert isinstance(const, ConstantOp)
-        assert str(const) == "out: [rank=(1), shape=((2,)), dtype=(float32), loc=(gpu:0)] = ConstantOp(data=[2.0, 3.0])"
+        assert (
+            str(const)
+            == "out: [rank=(1), shape=((2,)), stride=((1,)), dtype=(float32), loc=(gpu:0)] = ConstantOp(data=[2.0, 3.0])"
+        )
 
     def test_mlir(self):
         out = tp.Tensor([2, 3], dtype=tp.int32, name="out")
Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+#
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pytest
+import re
+import torch
+
+import tripy as tp
+from tests.helper import raises
+
+
+class TestStride:
+
+    def test_non_canonical_stride(self):
+        t = torch.arange(12, dtype=torch.float32).reshape(3, 4)
+        a = tp.Tensor(t)
+        assert a.stride == t.stride()
+
+        t = t.transpose(0, 1)
+        a = tp.Tensor(t)
+
+        assert a.stride == t.stride()
+
+    def test_lazy_stride(self):
+        a = torch.arange(12, dtype=torch.float32).reshape(4, 3).transpose(0, 1)
+        with raises(
+            tp.TripyException,
+            match=re.escape("Non-canonical strides are not supported for Tripy tensors."),
+        ):
+            print(tp.Tensor(a))
+
+    def test_compile_stride(self):
+        def twice(t):
+            return 2 * t
+
+        compiler = tp.Compiler(twice)
+
+        t = tp.Tensor(torch.arange(12, dtype=torch.float32).reshape(4, 3).transpose(0, 1))
+
+        # Create a tensor info with non-canonical stride.
+        t_info = tp.InputInfo(shape=t.shape.tolist(), dtype=t.dtype, stride=t.stride)
+        compiled_add = compiler.compile(t_info)
+
+        with raises(
+            tp.TripyException,
+            match=re.escape("Reason: InvalidArgument: Runtime stride mismatch. Expected [4, 1] but received [1, 3]"),
+        ):
+            print(compiled_add(t))
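For context: a stride is "canonical" when it matches the contiguous row-major (C-order) layout for a given shape, which is what the runtime error above compares against. A minimal sketch of that computation (canonical_strides is an illustrative helper, not part of this commit):

    def canonical_strides(shape):
        # Row-major layout: the last dimension is contiguous, and each earlier
        # dimension steps over the product of all later extents.
        strides = [1] * len(shape)
        for i in reversed(range(len(shape) - 1)):
            strides[i] = strides[i + 1] * shape[i + 1]
        return tuple(strides)

    assert canonical_strides((3, 4)) == (4, 1)

A (4, 3) tensor transposed to shape (3, 4) keeps its original memory layout and so has strides (1, 3), which is why test_compile_stride expects "Expected [4, 1] but received [1, 3]".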

tripy/tests/frontend/trace/test_trace.py

Lines changed: 8 additions & 8 deletions
@@ -95,8 +95,8 @@ def test_str(self):
             str(trace)
             == dedent(
                 """
-                a = storage(data=[0], shape=(1,), dtype=int32, device=gpu:0)
-                b = storage(data=[1], shape=(1,), dtype=int32, device=gpu:0)
+                a = storage(data=[0], shape=(1,), stride=(1,), dtype=int32, device=gpu:0)
+                b = storage(data=[1], shape=(1,), stride=(1,), dtype=int32, device=gpu:0)
                 c = a + b
                 outputs:
                     c: [rank=(1), dtype=(int32), loc=(gpu:0)]
@@ -133,8 +133,8 @@ def test_multiple_outputs(self):
             str(trace)
             == dedent(
                 """
-                a = storage(data=[1.0000], shape=(1,), dtype=float32, device=gpu:0)
-                b = storage(data=[1.0000], shape=(1,), dtype=float32, device=gpu:0)
+                a = storage(data=[1.0000], shape=(1,), stride=(1,), dtype=float32, device=gpu:0)
+                b = storage(data=[1.0000], shape=(1,), stride=(1,), dtype=float32, device=gpu:0)
                 c = a + b
                 d = c + c
                 outputs:
@@ -168,8 +168,8 @@ def test_all_inputs(self):
             == dedent(
                 """
                 inputs:
-                    a: [rank=(1), shape=((1,)), dtype=(float32), loc=(gpu:0)]
-                    b: [rank=(1), shape=((1,)), dtype=(float32), loc=(gpu:0)]
+                    a: [rank=(1), shape=((1,)), stride=((1,)), dtype=(float32), loc=(gpu:0)]
+                    b: [rank=(1), shape=((1,)), stride=((1,)), dtype=(float32), loc=(gpu:0)]
                 c = a + b
                 outputs:
                     c: [rank=(1), dtype=(float32), loc=(gpu:0)]
@@ -191,8 +191,8 @@ def test_const_and_input(self):
             == dedent(
                 """
                 inputs:
-                    a: [rank=(1), shape=((1,)), dtype=(float32), loc=(gpu:0)]
-                b = storage(data=[1.0000], shape=(1,), dtype=float32, device=gpu:0)
+                    a: [rank=(1), shape=((1,)), stride=((1,)), dtype=(float32), loc=(gpu:0)]
+                b = storage(data=[1.0000], shape=(1,), stride=(1,), dtype=float32, device=gpu:0)
                 c = a + b
                 outputs:
                     c: [rank=(1), dtype=(float32), loc=(gpu:0)]

tripy/tests/integration/test_allclose.py

Lines changed: 2 additions & 2 deletions
@@ -35,8 +35,8 @@ class TestAllClose:
         ],
     )
     def test_all_close_float32(self, tensor_a, tensor_b, rtol, atol):
-        np_result = torch.allclose(torch.FloatTensor(tensor_a), torch.FloatTensor(tensor_b), rtol=rtol, atol=atol)
+        torch_result = torch.allclose(torch.FloatTensor(tensor_a), torch.FloatTensor(tensor_b), rtol=rtol, atol=atol)
         tp_result = tp.allclose(
             tp.Tensor(tensor_a, dtype=tp.float32), tp.Tensor(tensor_b, dtype=tp.float32), rtol=rtol, atol=atol
         )
-        assert np_result == tp_result
+        assert torch_result == tp_result

tripy/tests/integration/test_quantize.py

Lines changed: 1 addition & 1 deletion
@@ -118,4 +118,4 @@ def test_non_constant_scale(self):
         scale = tp.ones((4,))
         quantized = tp.quantize(input, scale, tp.int8, dim=0)
 
-        assert bool(tp.all(quantized == tp.ones((4, 4), dtype=tp.int8)))
+        assert tp.allclose(quantized, tp.ones((4, 4), dtype=tp.int8), rtol=0.0, atol=0.0)

tripy/tripy/backend/api/compiler.py

Lines changed: 3 additions & 1 deletion
@@ -147,6 +147,7 @@ def add(a, b):
         """
 
         shapes = []
+        strides = []
         trace_input_map = {}
         input_names = set()
 
@@ -162,6 +163,7 @@ def process_arg(name, arg):
 
                 trace_input_map[name] = tensor
                 shapes.append(arg.shape_bounds)
+                strides.append(arg.stride)
                 input_names.add(name)
 
             return tensor
@@ -196,7 +198,7 @@ def process_arg(name, arg):
 
         # Order of trace inputs also needs to match that of the compiled_arg_names
         trace_inputs = [trace_input_map[name] for name in compiled_arg_names]
-        trace = Trace(trace_outputs, trace_inputs, shapes=shapes)
+        trace = Trace(trace_outputs, trace_inputs, shapes=shapes, strides=strides)
 
         flat_ir = trace.to_flat_ir()
         mlir = flat_ir.to_mlir()

tripy/tripy/backend/api/executable.py

Lines changed: 6 additions & 1 deletion
@@ -161,6 +161,10 @@ def add(a, b):
                         tensor,
                     ],
                 )
+            elif "Runtime stride mismatch" in str(err):
+                # Just raise the error for now.
+                raise raise_error(str(err))
+
             raise
 
         from tripy.utils.stack_info import StackInfo
@@ -175,10 +179,11 @@ def _get_arg_info(self, idx):
         arg = runtime.MemRefType(arg)
         arg_bound = self._executable_signature.get_arg_bound(idx)
         shape_bounds = tuple(zip(arg_bound.min(), arg_bound.max()))
+        stride = arg.strides
         if len(shape_bounds) == 0:
             # For static shape arguments, get_arg_bound returns an empty list and we fallback to arg.shape
             shape_bounds = tuple((x, x) for x in arg.shape)
-        return ArgInfo(shape_bounds, mlir_utils.convert_runtime_dtype_to_tripy_dtype(arg.dtype))
+        return ArgInfo(shape_bounds, stride, mlir_utils.convert_runtime_dtype_to_tripy_dtype(arg.dtype))
 
     def get_input_info(self) -> Sequence[ArgInfo]:
         """

tripy/tripy/backend/api/input_info.py

Lines changed: 10 additions & 2 deletions
@@ -29,7 +29,10 @@ class InputInfo:
     """
 
     def __init__(
-        self, shape: Sequence[Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]]], dtype: "tripy.dtype"
+        self,
+        shape: Sequence[Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]]],
+        dtype: "tripy.dtype",
+        stride: Sequence[int] = None,
     ) -> None:
         """
         Args:
@@ -88,10 +91,13 @@ def __init__(
             max_shape.append(elem[2])
 
         self.shape_bounds = ShapeBounds(tuple(min_shape), tuple(opt_shape), tuple(max_shape))
+        self.stride = stride
         self.dtype = dtype
 
     def __str__(self) -> str:
-        return f"InputInfo(min={self.shape_bounds.min}, opt={self.shape_bounds.opt}, max={self.shape_bounds.max}, dtype={self.dtype})"
+        base_info = f"InputInfo(min={self.shape_bounds.min}, opt={self.shape_bounds.opt}, max={self.shape_bounds.max}"
+        stride_info = f", stride={self.stride}" if self.stride is not None else ""
+        return f"{base_info}{stride_info}, dtype={self.dtype})"
 
 
 # TODO(MLIR-TRT #923): Can generalize `InputInfo` and drop this class.
@@ -100,5 +106,7 @@ def __str__(self) -> str:
 class ArgInfo:
     shape_bounds: Sequence[Tuple[int, int]]
     """A sequence of tuple(min, max) indicating the bounds of each dimension"""
+    stride: Sequence[int]
+    """A sequence of integers indicating stride"""
     dtype: "tripy.dtype"
     """The datatype of the argument"""

tripy/tripy/backend/mlir/executor.py

Lines changed: 8 additions & 2 deletions
@@ -114,7 +114,9 @@ def _get_output_tensor_info(self, outputs_runtime_shape, output_devices):
             is_static_shape = all(dim >= 0 for dim in memref.shape)
             if is_static_shape:
                 outputs_tensor_info.append(
-                    TensorInfo(len(memref.shape), tuple(memref.shape), dtype, device(device_type))
+                    TensorInfo(
+                        len(memref.shape), tuple(memref.shape), tuple(memref.strides), dtype, device(device_type)
+                    )
                 )
             else:
                 runtime_shape = [
@@ -124,6 +126,7 @@ def _get_output_tensor_info(self, outputs_runtime_shape, output_devices):
                 TensorInfo(
                     len(runtime_shape),
                     tuple(runtime_shape),
+                    tuple(memref.strides),
                     dtype,
                     device(device_type),
                 )
@@ -174,7 +177,10 @@ def execute(self, output_devices=List[device], inputs: List["Tensor"] = []) -> L
         # Allocate output memory and store buffer pointers.
         outputs = [
             create_empty_memref(
-                shape=info.shape, dtype=info.dtype, device=info.device, stream=self.stream._active_cuda_stream
+                shape=info.shape,
+                dtype=info.dtype,
+                device=info.device,
+                stream=self.stream._active_cuda_stream,
             )
             for info in out_tensor_info
        ]
