
Commit cc2b071

Add checks for non-canonical strides
MLIR-TensorRT requires the strides of function arguments and results to be in canonical order. This change adds a check in the storage op to ensure strides are in canonical order; an exception is thrown both in compile mode (at execution time) and in lazy mode when inputs are materialized to a storage op with a non-canonical stride order. This PR adds the following changes:

- Add a stride field to the Trace tensor, which is materialized when a storage op is populated from a framework tensor.
- Add an assert check while lowering a data memref to a `stablehlo.constant` op.
- Add helper methods for the stride check: `get_canonical_stride`, `are_strides_equivalent`, and `is_canonical_stride`, along with corresponding tests.
- Add test_stride.py validating strides for framework tensors.
1 parent 59b9536 commit cc2b071
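
As a rough sketch of what "canonical order" means here: a canonical stride is the row-major (C-contiguous) stride for a shape, with dimensions of extent 0 or 1 assigned a stride of 1. The helper below is an illustrative reconstruction inferred from the tests added in this commit, not the actual `tripy.utils` implementation.

def get_canonical_stride(shape):
    # Row-major (C-contiguous) strides, e.g. (2, 3) -> (3, 1).
    # Dimensions of extent 0 or 1 are assigned stride 1, and zero extents are
    # treated as 1 when accumulating the trailing product, e.g. (2, 0, 3) -> (3, 1, 1).
    strides = []
    for i, extent in enumerate(shape):
        if extent <= 1:
            strides.append(1)
        else:
            product = 1
            for trailing in shape[i + 1:]:
                product *= max(trailing, 1)
            strides.append(product)
    return tuple(strides)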

File tree

14 files changed: +282, -7 lines changed


tripy/tests/backend/test_compiler_api.py

Lines changed: 1 addition & 2 deletions

@@ -185,8 +185,7 @@ def test_function(self):
         inp = tp.ones((2, 2), dtype=tp.float32)
         out = compiled_gelu(inp)
 
-        # TODO (#225): Replace with tp.all
-        assert cp.array_equal(cp.from_dlpack(out), cp.from_dlpack(tp.relu(inp)))
+        assert tp.allclose(out, tp.relu(inp), rtol=0.0, atol=0.0)
 
     def test_module(self):
         layernorm = tp.LayerNorm(2)
test_stride.py (new file)

Lines changed: 109 additions & 0 deletions

@@ -0,0 +1,109 @@
#
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import pytest
import re

import cupy as cp
import numpy as np
import torch

import tripy as tp
from tripy.utils import are_strides_equivalent
from tests.helper import raises


class TestStride:
    def assert_error_message(self, excinfo, tensor_type, expected_suggestion):
        error_message = str(excinfo.value)
        assert "Non-canonical strides are not supported for Tripy tensors." in error_message
        assert f"For {tensor_type}, use {expected_suggestion}" in error_message

    def tripy_byte_order_strides(self, data):
        return tuple(s * data.dtype.itemsize for s in tp.Tensor(data).stride())

    def test_non_canonical_stride(self):
        # PyTorch test
        t_torch = torch.arange(12, dtype=torch.float32).reshape(3, 4)
        a_torch = t_torch.transpose(0, 1)
        with pytest.raises(tp.TripyException) as excinfo:
            tp.Tensor(a_torch)
        self.assert_error_message(excinfo, "PyTorch Tensor", "tensor.contiguous() or tensor.clone()")

        assert tp.Tensor(a_torch.contiguous()).stride() == a_torch.contiguous().stride()
        assert (
            tp.Tensor(a_torch.clone(memory_format=torch.contiguous_format)).stride()
            == a_torch.clone(memory_format=torch.contiguous_format).stride()
        )

        # CuPy test
        t_cupy = cp.arange(12, dtype=cp.float32).reshape(3, 4)
        a_cupy = t_cupy.transpose(1, 0)
        with pytest.raises(tp.TripyException) as excinfo:
            tp.Tensor(a_cupy)
        self.assert_error_message(excinfo, "CuPy Array", "cp.ascontiguousarray(array) or array.copy(order='C')")

        # CuPy's and NumPy's `strides` attributes return byte-based strides in their frontend APIs.
        # However, when these arrays are converted to DLPack tensors (which Tripy uses internally),
        # the strides are represented as element-wise strides.
        # As a result, Tripy's `get_canonical_stride` method produces element-wise strides
        # that match the expected strides of a memref.
        # Multiply Tripy's strides by the item size when comparing to the original CuPy or NumPy strides here.
        assert self.tripy_byte_order_strides(cp.ascontiguousarray(a_cupy)) == cp.ascontiguousarray(a_cupy).strides
        assert self.tripy_byte_order_strides(a_cupy.copy(order="C")) == a_cupy.copy(order="C").strides

        # NumPy test
        t_numpy = np.arange(12, dtype=np.float32).reshape(3, 4)
        a_numpy = t_numpy.transpose(1, 0)
        with pytest.raises(tp.TripyException) as excinfo:
            tp.Tensor(a_numpy)
        self.assert_error_message(excinfo, "NumPy Array", "np.ascontiguousarray(array) or array.copy(order='C')")

        assert self.tripy_byte_order_strides(np.ascontiguousarray(a_numpy)) == np.ascontiguousarray(a_numpy).strides
        assert self.tripy_byte_order_strides(a_numpy.copy(order="C")) == a_numpy.copy(order="C").strides

        # Test for canonical strides (should not raise an exception)
        assert tp.Tensor(t_torch).stride() == t_torch.stride()
        assert self.tripy_byte_order_strides(t_cupy) == t_cupy.strides
        assert self.tripy_byte_order_strides(t_numpy) == t_numpy.strides

    @pytest.mark.parametrize(
        "shape",
        [
            (0,),
            (0, 3),
            (2, 0, 4),
            (3, 0, 0, 5),
            (1,),
            (1, 3),
            (2, 1, 4),
            (1, 3, 1, 5),
            (3, 1, 1, 1),
            (0, 1, 3),
            (2, 0, 1, 4),
            (1, 0, 3, 1, 5),
            (3, 1, 0, 2, 1),
        ],
    )
    def test_tensor_stride(self, shape):
        torch_tensor = torch.empty(shape)
        torch_stride = torch_tensor.stride()
        tripy_stride = tp.Tensor(torch_tensor).stride()

        assert are_strides_equivalent(
            shape, tripy_stride, torch_stride
        ), f"Mismatch for shape {shape}. Calculated: {tripy_stride}, Torch: {torch_stride}"

tripy/tests/integration/test_allclose.py

Lines changed: 2 additions & 2 deletions

@@ -35,8 +35,8 @@ class TestAllClose:
         ],
     )
     def test_all_close_float32(self, tensor_a, tensor_b, rtol, atol):
-        np_result = torch.allclose(torch.FloatTensor(tensor_a), torch.FloatTensor(tensor_b), rtol=rtol, atol=atol)
+        torch_result = torch.allclose(torch.FloatTensor(tensor_a), torch.FloatTensor(tensor_b), rtol=rtol, atol=atol)
         tp_result = tp.allclose(
             tp.Tensor(tensor_a, dtype=tp.float32), tp.Tensor(tensor_b, dtype=tp.float32), rtol=rtol, atol=atol
         )
-        assert np_result == tp_result
+        assert torch_result == tp_result

tripy/tests/integration/test_quantize.py

Lines changed: 1 addition & 1 deletion

@@ -118,4 +118,4 @@ def test_non_constant_scale(self):
         scale = tp.ones((4,))
         quantized = tp.quantize(input, scale, tp.int8, dim=0)
 
-        assert bool(tp.all(quantized == tp.ones((4, 4), dtype=tp.int8)))
+        assert tp.allclose(quantized, tp.ones((4, 4), dtype=tp.int8), rtol=0.0, atol=0.0)

tripy/tests/utils/test_utils.py

Lines changed: 43 additions & 0 deletions

@@ -21,6 +21,7 @@
 from tripy import utils
 from tests import helper
 from collections import defaultdict
+from typing import Optional, Sequence
 
 
 class TestMd5:

@@ -109,3 +110,45 @@ def test_gen_uid(self, inputs, outputs, expected_prefix):
     def test_uniqueness(self):
         uids = [utils.UniqueNameGen.gen_uid() for _ in range(100)]
         assert len(set(uids)) == 100
+
+
+class TestStride:
+    @pytest.mark.parametrize(
+        "shape, expected_stride",
+        [
+            ((), ()),
+            ((1,), (1,)),
+            ((2, 3), (3, 1)),
+            ((0, 5), (1, 1)),
+            ((1, 0, 3), (1, 1, 1)),
+            ((2, 1, 4), (4, 1, 1)),
+            ((3, 0, 0, 5), (5, 1, 1, 1)),
+            ((2, 3), (3, 1)),
+            ((0, 5), (1, 1)),
+            ((2, 0, 3), (3, 1, 1)),
+            (None, None),
+        ],
+    )
+    def test_get_stride(
+        self,
+        shape: Optional[Sequence[int]],
+        expected_stride: Optional[Sequence[int]],
+    ):
+        if shape is not None:
+            assert utils.get_canonical_stride(shape) == expected_stride
+
+    @pytest.mark.parametrize(
+        "shape, stride, expected_stride, expected_result",
+        [
+            ((2, 3), (3, 1), (3, 1), True),
+            ((2, 3), (6, 2), (3, 1), False),
+            ((0, 5), (1, 1), (5, 1), True),
+            ((1, 0, 3), (0, 3, 1), (3, 3, 1), True),
+            ((2, 1, 4), (8, 4, 1), (4, 1, 1), False),
+            ((2, 3), (3, 1, 2), (3, 1), False),  # Mismatched lengths
+        ],
+    )
+    def test_are_strides_equivalent(
+        self, shape: Sequence[int], stride: Sequence[int], expected_stride: Sequence[int], expected_result: bool
+    ):
+        assert utils.are_strides_equivalent(shape, stride, expected_stride) == expected_result
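
A reference sketch consistent with the parametrized cases above; this is an inferred reconstruction of the helpers (including the assumption that `is_canonical_stride` is built on top of `are_strides_equivalent` and the `get_canonical_stride` sketch shown near the top), not the committed `tripy.utils` code:

def are_strides_equivalent(shape, stride, expected_stride):
    # Mismatched lengths, e.g. ((2, 3), (3, 1, 2), (3, 1)), are never equivalent.
    if not (len(shape) == len(stride) == len(expected_stride)):
        return False
    # Only dimensions with extent > 1 constrain the memory layout; strides of
    # 0- or 1-sized dimensions are ignored, so ((0, 5), (1, 1), (5, 1)) is equivalent.
    return all(s == e for extent, s, e in zip(shape, stride, expected_stride) if extent > 1)


def is_canonical_stride(shape, stride):
    # A stride is canonical if it is equivalent to the canonical stride for the shape.
    return are_strides_equivalent(shape, stride, get_canonical_stride(shape))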

tripy/tripy/backend/api/executable.py

Lines changed: 4 additions & 0 deletions

@@ -161,6 +161,10 @@ def add(a, b):
                     tensor,
                 ],
             )
+        elif "Runtime stride mismatch" in str(err):
+            # Just raise the error for now.
+            raise raise_error(str(err))
+
         raise
 
     from tripy.utils.stack_info import StackInfo

tripy/tripy/backend/mlir/memref.py

Lines changed: 32 additions & 1 deletion

@@ -18,6 +18,7 @@
 from functools import lru_cache
 from typing import Sequence
 
+from tripy import utils as tripy_utils
 from tripy.backend.mlir import utils as mlir_utils
 from tripy.common import device as tp_device
 from tripy.common import utils as common_utils

@@ -65,7 +66,37 @@ def create_memref_view(data):
     """
     Creates a memref view of an array object that implements the dlpack interface.
     """
-    return mlir_utils.MLIRRuntimeClient().create_memref_view_from_dlpack(data.__dlpack__())
+    memref = mlir_utils.MLIRRuntimeClient().create_memref_view_from_dlpack(data.__dlpack__())
+
+    def check_tensor_type_and_suggest_contiguous(obj):
+        obj_type = str(type(obj))
+        if "torch.Tensor" in obj_type:
+            return "PyTorch Tensor", "tensor.contiguous() or tensor.clone()"
+        elif "jaxlib" in obj_type or "jax.numpy" in obj_type:
+            return "JAX Array", "jax.numpy.asarray(array) or jax.numpy.copy(array)"
+        elif "numpy.ndarray" in obj_type:
+            return "NumPy Array", "np.ascontiguousarray(array) or array.copy(order='C')"
+        elif "cupy.ndarray" in obj_type:
+            return "CuPy Array", "cp.ascontiguousarray(array) or array.copy(order='C')"
+        else:
+            return "Unknown Type", "Cannot suggest method for unknown type"
+
+    tensor_type, contiguous_suggestion = check_tensor_type_and_suggest_contiguous(data)
+
+    if not tripy_utils.is_canonical_stride(memref.shape, memref.strides):
+        canonical_stride = tripy_utils.get_canonical_stride(memref.shape)
+        error_message = (
+            f"Non-canonical strides detected:\n"
+            f"    Shape: {memref.shape}\n"
+            f"    Current stride: {tripy_utils.make_tuple(memref.strides)}\n"
+            f"    Expected canonical stride: {canonical_stride}\n"
+            f"Non-canonical strides are not supported for Tripy tensors. "
+            f"This usually occurs when the tensor is not contiguous in memory. "
+            f"To resolve this issue:\n"
+            f"For {tensor_type}, use {contiguous_suggestion} to ensure contiguity before converting to a Tripy tensor."
+        )
+        tripy_utils.raise_error(error_message)
+    return memref
 
 
 # TODO(#134): Consider move below functions to MLIR py bindings
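
For reference, the user-facing effect of the check added above, mirroring the PyTorch case in test_stride.py; an illustrative usage sketch, not code from the commit:

import torch
import tripy as tp

t = torch.arange(12, dtype=torch.float32).reshape(3, 4)
view = t.transpose(0, 1)      # non-contiguous view: element strides (1, 4) instead of the canonical (3, 1)

try:
    tp.Tensor(view)           # raises tp.TripyException with the error message constructed above
except tp.TripyException as e:
    print(e)

tp.Tensor(view.contiguous())  # copying to a contiguous layout resolves the issue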

tripy/tripy/backend/utils.py

Lines changed: 1 addition & 0 deletions

@@ -36,6 +36,7 @@ def encode(tensor_info: TensorInfo) -> Dict[str, Any]:
     return {
         "rank": tensor_info.rank,
         "shape": tensor_info.shape,
+        "stride": tensor_info.stride,
         "dtype": tensor_info.dtype,
         "device": tensor_info.device,
     }

tripy/tripy/flat_ir/ops/constant.py

Lines changed: 1 addition & 0 deletions

@@ -72,6 +72,7 @@ def to_mlir(self, operands):
             constant_op = stablehlo.ConstantOp(attr)
             return [stablehlo.ConvertOp(result=cast_output, operand=constant_op)]
 
+        assert utils.is_canonical_stride(data_memref.shape, utils.make_tuple(data_memref.strides))
         attr = ir.DenseElementsAttr.get(
             array=data_memref, type=mlir_utils.get_mlir_dtype(self.outputs[0].dtype), shape=data_memref.shape
         )

tripy/tripy/flat_ir/tensor.py

Lines changed: 1 addition & 0 deletions

@@ -38,6 +38,7 @@ class FlatIRTensor:
     rank: int
     producer: "BaseFlatIROp" = None
     shape: Optional[List[int]] = None
+    stride: Optional[List[int]] = None
     reason_details: Optional[List[Any]] = None
     """
     Describes why this tensor was created.
