
Commit 13dde1b

Add checks for non-canonical strides (#273)
MLIR-TensorRT requires the strides of function arguments and results to be in canonical order. #252 adds a check that validates memref strides against the canonical stride order. In Tripy, memref strides are derived from framework DLPack tensors, so creating a memref from a DLPack tensor with non-canonical strides now throws an exception. This change adds a try/except block that catches the exception and augments it with suggestions on how to produce a canonically strided tensor in each Tripy-supported framework, and adds unit tests that construct non-canonically strided tensors to validate the exception and the suggestions.
1 parent 60ccc43 commit 13dde1b
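
For illustration, a minimal sketch (not part of this commit) of the user-facing behavior the change introduces; it mirrors the PyTorch case from the new unit test:

import torch
import tripy as tp

x = torch.arange(12, dtype=torch.float32).reshape(3, 4).transpose(0, 1)  # transposed view: non-canonical strides
# tp.Tensor(x) now raises tp.TripyException ("Non-canonical strides are not supported for Tripy tensors ...")
# with a framework-specific suggestion; making the tensor contiguous resolves it:
t = tp.Tensor(x.contiguous())  # or x.clone(memory_format=torch.contiguous_format)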

File tree

7 files changed, +148 -4 lines changed


tripy/tests/backend/mlir/test_utils.py

Lines changed: 18 additions & 0 deletions
@@ -18,6 +18,10 @@
 import pytest
 from mlir_tensorrt.compiler import ir

+import cupy as cp
+import numpy as np
+import torch
+
 import tripy
 from tripy.backend.mlir import utils as mlir_utils
 from tripy.common.datatype import DATA_TYPES
@@ -47,3 +51,17 @@ def test_convert_dtype(self, dtype):
                     "bool": ir.IntegerType.get_signless(1),
                 }[dtype.name]
             )
+
+    @pytest.mark.parametrize(
+        "tensor, expected_type, expected_suggestion",
+        [
+            (torch.tensor([1, 2, 3]), "PyTorch Tensor", "tensor.contiguous() or tensor.clone()"),
+            (np.array([1, 2, 3]), "NumPy Array", "np.ascontiguousarray(array) or array.copy(order='C')"),
+            (cp.array([1, 2, 3]), "CuPy Array", "cp.ascontiguousarray(array) or array.copy(order='C')"),
+            ([1, 2, 3], None, None),
+        ],
+    )
+    def test_check_tensor_type_and_suggest_contiguous(self, tensor, expected_type, expected_suggestion):
+        result_type, result_suggestion = mlir_utils.check_tensor_type_and_suggest_contiguous(tensor)
+        assert result_type == expected_type
+        assert result_suggestion == expected_suggestion
Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
+#
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pytest
+import re
+
+import cupy as cp
+import numpy as np
+import torch
+
+import tripy as tp
+from tests.helper import raises
+
+
+class TestStride:
+    def test_non_canonical_stride(self):
+        test_cases = [
+            (
+                torch.arange(12, dtype=torch.float32).reshape(3, 4).transpose(0, 1),
+                lambda x: x.contiguous(),
+                lambda x: x.clone(memory_format=torch.contiguous_format),
+            ),
+            (
+                cp.arange(12, dtype=cp.float32).reshape(3, 4).transpose(1, 0),
+                cp.ascontiguousarray,
+                lambda x: x.copy(order="C"),
+            ),
+            (
+                np.arange(12, dtype=np.float32).reshape(3, 4).transpose(1, 0),
+                np.ascontiguousarray,
+                lambda x: x.copy(order="C"),
+            ),
+        ]
+
+        for array, contiguous_func, copy_func in test_cases:
+            # Test for exception with non-canonical strides
+            with pytest.raises(tp.TripyException, match="Non-canonical strides are not supported for Tripy tensors"):
+                tp.Tensor(array)
+
+            # Test successful creation with contiguous array
+            assert tp.Tensor(contiguous_func(array)) is not None
+            assert tp.Tensor(copy_func(array)) is not None

tripy/tests/integration/test_allclose.py

Lines changed: 2 additions & 2 deletions
@@ -35,8 +35,8 @@ class TestAllClose:
         ],
     )
     def test_all_close_float32(self, tensor_a, tensor_b, rtol, atol):
-        np_result = torch.allclose(torch.FloatTensor(tensor_a), torch.FloatTensor(tensor_b), rtol=rtol, atol=atol)
+        torch_result = torch.allclose(torch.FloatTensor(tensor_a), torch.FloatTensor(tensor_b), rtol=rtol, atol=atol)
         tp_result = tp.allclose(
             tp.Tensor(tensor_a, dtype=tp.float32), tp.Tensor(tensor_b, dtype=tp.float32), rtol=rtol, atol=atol
         )
-        assert np_result == tp_result
+        assert torch_result == tp_result

tripy/tests/integration/test_quantize.py

Lines changed: 1 addition & 1 deletion
@@ -118,4 +118,4 @@ def test_non_constant_scale(self):
         scale = tp.ones((4,))
         quantized = tp.quantize(input, scale, tp.int8, dim=0)

-        assert bool(tp.all(quantized == tp.ones((4, 4), dtype=tp.int8)))
+        assert tp.allclose(quantized, tp.ones((4, 4), dtype=tp.int8), rtol=0.0, atol=0.0)

tripy/tripy/backend/api/executable.py

Lines changed: 4 additions & 0 deletions
@@ -159,6 +159,10 @@ def add(a, b):
                        tensor,
                    ],
                )
+            elif "Runtime stride mismatch" in str(err):
+                # Just raise the error for now.
+                raise raise_error(str(err))
+
            raise

        output_tensors = [Tensor(output, fetch_stack_info=False) for output in executor_outputs]

tripy/tripy/backend/mlir/memref.py

Lines changed: 53 additions & 1 deletion
@@ -15,9 +15,12 @@
 # limitations under the License.
 #

+import re
+
 from functools import lru_cache
 from typing import Sequence

+from tripy.utils import raise_error
 from tripy.backend.mlir import utils as mlir_utils
 from tripy.common import device as tp_device
 from tripy.common import utils as common_utils
@@ -66,7 +69,56 @@ def create_memref_view(data):
     """
     Creates a memref view of an array object that implements the dlpack interface.
     """
-    return mlir_utils.MLIRRuntimeClient().create_memref_view_from_dlpack(data.__dlpack__())
+    try:
+        memref = mlir_utils.MLIRRuntimeClient().create_memref_view_from_dlpack(
+            data.__dlpack__(), assert_canonical_strides=True
+        )
+    except runtime.MTRTException as e:
+        error_msg = str(e)
+        match = re.search(
+            r"Given strides \[([\d, ]+)\] do not match canonical strides \[([\d, ]+)\] for shape \[([\d, ]+)\]",
+            error_msg,
+        )
+
+        if match:
+            given_strides = [int(s) for s in match.group(1).split(",")]
+            canonical_strides = [int(s) for s in match.group(2).split(",")]
+            shape = [int(s) for s in match.group(3).split(",")]
+
+            def check_tensor_type_and_suggest_contiguous(obj):
+                obj_type = str(type(obj))
+                if "torch.Tensor" in obj_type:
+                    return "PyTorch Tensor", "tensor.contiguous() or tensor.clone()"
+                elif "jaxlib" in obj_type or "jax.numpy" in obj_type:
+                    return "JAX Array", "jax.numpy.asarray(array) or jax.numpy.copy(array)"
+                elif "numpy.ndarray" in obj_type:
+                    return "NumPy Array", "np.ascontiguousarray(array) or array.copy(order='C')"
+                elif "cupy.ndarray" in obj_type:
+                    return "CuPy Array", "cp.ascontiguousarray(array) or array.copy(order='C')"
+                else:
+                    return None, None
+
+            tensor_type, contiguous_suggestion = check_tensor_type_and_suggest_contiguous(data)
+
+            error_message = (
+                f"Non-canonical strides detected:\n"
+                f"    Shape: {shape}\n"
+                f"    Current stride: {given_strides}\n"
+                f"    Expected canonical stride: {canonical_strides}\n"
+                f"Non-canonical strides are not supported for Tripy tensors. "
+                f"This usually occurs when the tensor is not contiguous in memory. "
+                + (
+                    f"To resolve this issue:\n"
+                    f"For {tensor_type}, use {contiguous_suggestion} to ensure contiguity before converting to a Tripy tensor."
+                    if tensor_type is not None
+                    else ""
+                )
+            )
+            raise_error(error_message)
+        else:
+            # If the error message doesn't match the expected format, re-raise the original exception
+            raise
+    return memref


 # TODO(#134): Consider move below functions to MLIR py bindings
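
For reference, "canonical" strides here are the row-major (C-contiguous) strides implied by the tensor shape, expressed in elements as DLPack does. A minimal sketch (not part of the commit) of how they are computed and why a transposed view fails the check:

def canonical_strides(shape):
    # Row-major: the last dimension is contiguous; each stride is the product
    # of all trailing dimension sizes.
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides

canonical_strides([3, 4])  # [4, 1]
canonical_strides([4, 3])  # [3, 1]; a transposed (3, 4) tensor has shape (4, 3) with strides [1, 4], so it is rejected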

tripy/tripy/backend/mlir/utils.py

Lines changed: 14 additions & 0 deletions
@@ -172,6 +172,20 @@ def get_constant_value(arg) -> Optional[ir.DenseElementsAttr]:
     return None


+def check_tensor_type_and_suggest_contiguous(obj):
+    obj_type = str(type(obj))
+    if "torch.Tensor" in obj_type:
+        return "PyTorch Tensor", "tensor.contiguous() or tensor.clone()"
+    elif "jaxlib" in obj_type or "jax.numpy" in obj_type:
+        return "JAX Array", "jax.numpy.asarray(array) or jax.numpy.copy(array)"
+    elif "numpy.ndarray" in obj_type:
+        return "NumPy Array", "np.ascontiguousarray(array) or array.copy(order='C')"
+    elif "cupy.ndarray" in obj_type:
+        return "CuPy Array", "cp.ascontiguousarray(array) or array.copy(order='C')"
+    else:
+        return None, None
+
+
 def remove_sym_attr(mlir_text: str) -> str:
     return re.sub(r"module @\S+ {", "module {", mlir_text)
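
As a quick illustration (not part of the commit), the helper's behavior on the inputs exercised by the parametrized test in test_utils.py above:

import numpy as np
from tripy.backend.mlir import utils as mlir_utils

mlir_utils.check_tensor_type_and_suggest_contiguous(np.array([1, 2, 3]))
# -> ("NumPy Array", "np.ascontiguousarray(array) or array.copy(order='C')")
mlir_utils.check_tensor_type_and_suggest_contiguous([1, 2, 3])
# -> (None, None): unrecognized input types produce no suggestion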
