Skip to content

Commit 96bdf1d

Browse files
committed
Add checks for non-canonical strides
MLIR-TensorRT requires strides for function arguments and results in canonical order. NVIDIA#252 adds a check to validate memref stride against a canonical stride order. In Tripy, memref strides are derived from framework DL Pack tensors. Creating a memref with a non-canonical DL Pack tensor stride throws an exception. Add a try-catch block to catch such an exception and augment with suggestions on creating a DL Pack tensor with canonical stride for Tripy-supported frameworks. Add unit tests to create a non-canonical stride tensor to validate exceptions and suggestions.
1 parent 342dec5 commit 96bdf1d

File tree

6 files changed

+127
-6
lines changed

6 files changed

+127
-6
lines changed

tripy/tests/backend/test_compiler_api.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,8 +185,7 @@ def test_function(self):
185185
inp = tp.ones((2, 2), dtype=tp.float32)
186186
out = compiled_gelu(inp)
187187

188-
# TODO (#225): Replace with tp.all
189-
assert cp.array_equal(cp.from_dlpack(out), cp.from_dlpack(tp.relu(inp)))
188+
assert tp.allclose(out, tp.relu(inp), rtol=0.0, atol=0.0)
190189

191190
def test_module(self):
192191
layernorm = tp.LayerNorm(2)
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
#
2+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
import pytest
19+
import re
20+
21+
import cupy as cp
22+
import numpy as np
23+
import torch
24+
25+
import tripy as tp
26+
from tests.helper import raises
27+
28+
29+
class TestStride:
30+
def assert_error_message(self, excinfo, tensor_type, expected_suggestion):
31+
error_message = str(excinfo.value)
32+
assert "Non-canonical strides are not supported for Tripy tensors." in error_message
33+
assert f"For {tensor_type}, use {expected_suggestion}" in error_message
34+
35+
def tripy_byte_order_strides(self, data):
36+
return tuple(s * data.dtype.itemsize for s in tp.Tensor(data).stride())
37+
38+
def test_non_canonical_stride(self):
39+
# PyTorch test
40+
t_torch = torch.arange(12, dtype=torch.float32).reshape(3, 4)
41+
a_torch = t_torch.transpose(0, 1)
42+
with pytest.raises(tp.TripyException) as excinfo:
43+
tp.Tensor(a_torch)
44+
self.assert_error_message(excinfo, "PyTorch Tensor", "tensor.contiguous() or tensor.clone()")
45+
46+
# No exception is thrown.
47+
print(tp.Tensor(a_torch.contiguous()))
48+
print(tp.Tensor(a_torch.clone(memory_format=torch.contiguous_format)))
49+
50+
# CuPy test
51+
t_cupy = cp.arange(12, dtype=cp.float32).reshape(3, 4)
52+
a_cupy = t_cupy.transpose(1, 0)
53+
with pytest.raises(tp.TripyException) as excinfo:
54+
tp.Tensor(a_cupy)
55+
self.assert_error_message(excinfo, "CuPy Array", "cp.ascontiguousarray(array) or array.copy(order='C')")
56+
57+
print(tp.Tensor(cp.ascontiguousarray(a_cupy)))
58+
print(tp.Tensor(a_cupy.copy(order="C")))
59+
60+
# NumPy test
61+
t_numpy = np.arange(12, dtype=np.float32).reshape(3, 4)
62+
a_numpy = t_numpy.transpose(1, 0)
63+
with pytest.raises(tp.TripyException) as excinfo:
64+
tp.Tensor(a_numpy)
65+
self.assert_error_message(excinfo, "NumPy Array", "np.ascontiguousarray(array) or array.copy(order='C')")
66+
67+
print(tp.Tensor(np.ascontiguousarray(a_numpy)))
68+
print(tp.Tensor(a_numpy.copy(order="C")))

tripy/tests/integration/test_allclose.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ class TestAllClose:
3535
],
3636
)
3737
def test_all_close_float32(self, tensor_a, tensor_b, rtol, atol):
38-
np_result = torch.allclose(torch.FloatTensor(tensor_a), torch.FloatTensor(tensor_b), rtol=rtol, atol=atol)
38+
torch_result = torch.allclose(torch.FloatTensor(tensor_a), torch.FloatTensor(tensor_b), rtol=rtol, atol=atol)
3939
tp_result = tp.allclose(
4040
tp.Tensor(tensor_a, dtype=tp.float32), tp.Tensor(tensor_b, dtype=tp.float32), rtol=rtol, atol=atol
4141
)
42-
assert np_result == tp_result
42+
assert torch_result == tp_result

tripy/tests/integration/test_quantize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,4 +118,4 @@ def test_non_constant_scale(self):
118118
scale = tp.ones((4,))
119119
quantized = tp.quantize(input, scale, tp.int8, dim=0)
120120

121-
assert bool(tp.all(quantized == tp.ones((4, 4), dtype=tp.int8)))
121+
assert tp.allclose(quantized, tp.ones((4, 4), dtype=tp.int8), rtol=0.0, atol=0.0)

tripy/tripy/backend/api/executable.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,10 @@ def add(a, b):
161161
tensor,
162162
],
163163
)
164+
elif "Runtime stride mismatch" in str(err):
165+
# Just raise the error for now.
166+
raise raise_error(str(err))
167+
164168
raise
165169

166170
from tripy.utils.stack_info import StackInfo

tripy/tripy/backend/mlir/memref.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,12 @@
1515
# limitations under the License.
1616
#
1717

18+
import re
19+
1820
from functools import lru_cache
1921
from typing import Sequence
2022

23+
from tripy.utils import raise_error
2124
from tripy.backend.mlir import utils as mlir_utils
2225
from tripy.common import device as tp_device
2326
from tripy.common import utils as common_utils
@@ -66,7 +69,54 @@ def create_memref_view(data):
6669
"""
6770
Creates a memref view of an array object that implements the dlpack interface.
6871
"""
69-
return mlir_utils.MLIRRuntimeClient().create_memref_view_from_dlpack(data.__dlpack__())
72+
try:
73+
memref = mlir_utils.MLIRRuntimeClient().create_memref_view_from_dlpack(data.__dlpack__())
74+
except runtime.MTRTException as e:
75+
error_msg = str(e)
76+
match = re.search(
77+
r"Given strides \[([\d, ]+)\] do not match canonical strides \[([\d, ]+)\] for shape \[([\d, ]+)\]",
78+
error_msg,
79+
)
80+
81+
if match:
82+
given_strides = [int(s) for s in match.group(1).split(",")]
83+
canonical_strides = [int(s) for s in match.group(2).split(",")]
84+
shape = [int(s) for s in match.group(3).split(",")]
85+
86+
def check_tensor_type_and_suggest_contiguous(obj):
87+
obj_type = str(type(obj))
88+
if "torch.Tensor" in obj_type:
89+
return "PyTorch Tensor", "tensor.contiguous() or tensor.clone()"
90+
elif "jaxlib" in obj_type or "jax.numpy" in obj_type:
91+
return "JAX Array", "jax.numpy.asarray(array) or jax.numpy.copy(array)"
92+
elif "numpy.ndarray" in obj_type:
93+
return "NumPy Array", "np.ascontiguousarray(array) or array.copy(order='C')"
94+
elif "cupy.ndarray" in obj_type:
95+
return "CuPy Array", "cp.ascontiguousarray(array) or array.copy(order='C')"
96+
else:
97+
return None, None
98+
99+
tensor_type, contiguous_suggestion = check_tensor_type_and_suggest_contiguous(data)
100+
101+
error_message = (
102+
f"Non-canonical strides detected:\n"
103+
f" Shape: {shape}\n"
104+
f" Current stride: {given_strides}\n"
105+
f" Expected canonical stride: {canonical_strides}\n"
106+
f"Non-canonical strides are not supported for Tripy tensors. "
107+
f"This usually occurs when the tensor is not contiguous in memory. "
108+
+ (
109+
f"To resolve this issue:\n"
110+
f"For {tensor_type}, use {contiguous_suggestion} to ensure contiguity before converting to a Tripy tensor."
111+
if tensor_type is not None
112+
else ""
113+
)
114+
)
115+
raise_error(error_message)
116+
else:
117+
# If the error message doesn't match the expected format, re-raise the original exception
118+
raise
119+
return memref
70120

71121

72122
# TODO(#134): Consider move below functions to MLIR py bindings

0 commit comments

Comments
 (0)