Skip to content

Commit 3379879

Browse files
committed
Add checks for non-canonical strides
MLIR-TensorRT requires the strides of function arguments and results to be in canonical order. Add a check in the storage op to ensure strides are in canonical order. An exception is now raised in both compile mode (at execution time) and lazy mode when inputs with a non-canonical stride order are materialized into a storage op. This PR makes the following changes:
- Add a stride field to the Trace tensor. It is materialized when a storage op is populated from a framework tensor.
- Add an assert check while lowering a data memref to a `stablehlo.constant` op.
- Add helper methods for the stride check: `get_canonical_stride`, `are_strides_equivalent`, and `is_canonical_stride`. Add corresponding tests.
- Add `test_stride.py`, which validates strides for framework tensors.
1 parent 342dec5 commit 3379879

File tree

8 files changed

+125
-7
lines changed

8 files changed

+125
-7
lines changed

tripy/tests/backend/test_compiler_api.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,8 +185,7 @@ def test_function(self):
185185
inp = tp.ones((2, 2), dtype=tp.float32)
186186
out = compiled_gelu(inp)
187187

188-
# TODO (#225): Replace with tp.all
189-
assert cp.array_equal(cp.from_dlpack(out), cp.from_dlpack(tp.relu(inp)))
188+
assert tp.allclose(out, tp.relu(inp), rtol=0.0, atol=0.0)
190189

191190
def test_module(self):
192191
layernorm = tp.LayerNorm(2)
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
#
2+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
import pytest
19+
import re
20+
21+
import cupy as cp
22+
import numpy as np
23+
import torch
24+
25+
import tripy as tp
26+
from tests.helper import raises
27+
28+
29+
class TestStride:
    """Tests how ``tp.Tensor`` reacts to framework tensors with non-canonical strides."""

    def assert_error_message(self, excinfo, tensor_type, expected_suggestion):
        """
        Asserts that the captured exception reports unsupported strides and
        includes the framework-specific suggestion.

        Args:
            excinfo: The object produced by ``pytest.raises``; only ``excinfo.value`` is read.
            tensor_type: The framework label expected in the message, e.g. "NumPy Array".
            expected_suggestion: The remedy text expected after "use" in the message.
        """
        message = str(excinfo.value)
        expected_fragments = (
            "Non-canonical strides are not supported for Tripy tensors.",
            f"For {tensor_type}, use {expected_suggestion}",
        )
        for fragment in expected_fragments:
            assert fragment in message

    def tripy_byte_order_strides(self, data):
        """
        Returns the strides reported by ``tp.Tensor(data).stride()``, each
        multiplied by ``data.dtype.itemsize``.
        """
        element_strides = tp.Tensor(data).stride()
        return tuple(stride * data.dtype.itemsize for stride in element_strides)

    def test_non_canonical_stride(self):
        """
        For PyTorch, CuPy and NumPy in turn: a transposed (3, 4) array must be
        rejected with a ``tp.TripyException``, while contiguous copies of that
        same array must be accepted.
        """
        # --- PyTorch ---
        torch_transposed = torch.arange(12, dtype=torch.float32).reshape(3, 4).transpose(0, 1)
        with pytest.raises(tp.TripyException) as excinfo:
            tp.Tensor(torch_transposed)
        self.assert_error_message(excinfo, "PyTorch Tensor", "tensor.contiguous() or tensor.clone()")

        # Neither of the suggested remedies should raise.
        print(tp.Tensor(torch_transposed.contiguous()))
        print(tp.Tensor(torch_transposed.clone(memory_format=torch.contiguous_format)))

        # --- CuPy ---
        cupy_transposed = cp.arange(12, dtype=cp.float32).reshape(3, 4).transpose(1, 0)
        with pytest.raises(tp.TripyException) as excinfo:
            tp.Tensor(cupy_transposed)
        self.assert_error_message(excinfo, "CuPy Array", "cp.ascontiguousarray(array) or array.copy(order='C')")

        print(tp.Tensor(cp.ascontiguousarray(cupy_transposed)))
        print(tp.Tensor(cupy_transposed.copy(order="C")))

        # --- NumPy ---
        numpy_transposed = np.arange(12, dtype=np.float32).reshape(3, 4).transpose(1, 0)
        with pytest.raises(tp.TripyException) as excinfo:
            tp.Tensor(numpy_transposed)
        self.assert_error_message(excinfo, "NumPy Array", "np.ascontiguousarray(array) or array.copy(order='C')")

        print(tp.Tensor(np.ascontiguousarray(numpy_transposed)))
        print(tp.Tensor(numpy_transposed.copy(order="C")))

tripy/tests/integration/test_allclose.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ class TestAllClose:
3535
],
3636
)
3737
def test_all_close_float32(self, tensor_a, tensor_b, rtol, atol):
38-
np_result = torch.allclose(torch.FloatTensor(tensor_a), torch.FloatTensor(tensor_b), rtol=rtol, atol=atol)
38+
torch_result = torch.allclose(torch.FloatTensor(tensor_a), torch.FloatTensor(tensor_b), rtol=rtol, atol=atol)
3939
tp_result = tp.allclose(
4040
tp.Tensor(tensor_a, dtype=tp.float32), tp.Tensor(tensor_b, dtype=tp.float32), rtol=rtol, atol=atol
4141
)
42-
assert np_result == tp_result
42+
assert torch_result == tp_result

tripy/tests/integration/test_quantize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,4 +118,4 @@ def test_non_constant_scale(self):
118118
scale = tp.ones((4,))
119119
quantized = tp.quantize(input, scale, tp.int8, dim=0)
120120

121-
assert bool(tp.all(quantized == tp.ones((4, 4), dtype=tp.int8)))
121+
assert tp.allclose(quantized, tp.ones((4, 4), dtype=tp.int8), rtol=0.0, atol=0.0)

tripy/tripy/backend/api/executable.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,10 @@ def add(a, b):
161161
tensor,
162162
],
163163
)
164+
elif "Runtime stride mismatch" in str(err):
165+
# Just raise the error for now.
166+
raise raise_error(str(err))
167+
164168
raise
165169

166170
from tripy.utils.stack_info import StackInfo

tripy/tripy/backend/mlir/memref.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,12 @@
1515
# limitations under the License.
1616
#
1717

18+
import re
19+
1820
from functools import lru_cache
1921
from typing import Sequence
2022

23+
from tripy import utils as tripy_utils
2124
from tripy.backend.mlir import utils as mlir_utils
2225
from tripy.common import device as tp_device
2326
from tripy.common import utils as common_utils
@@ -66,7 +69,50 @@ def create_memref_view(data):
6669
"""
6770
Creates a memref view of an array object that implements the dlpack interface.
6871
"""
69-
return mlir_utils.MLIRRuntimeClient().create_memref_view_from_dlpack(data.__dlpack__())
72+
try:
73+
memref = mlir_utils.MLIRRuntimeClient().create_memref_view_from_dlpack(data.__dlpack__())
74+
except runtime.MTRTException as e:
75+
error_msg = str(e)
76+
match = re.search(
77+
r"Given strides \[([\d, ]+)\] do not match canonical strides \[([\d, ]+)\] for shape \[([\d, ]+)\]",
78+
error_msg,
79+
)
80+
81+
if match:
82+
given_strides = [int(s) for s in match.group(1).split(",")]
83+
canonical_strides = [int(s) for s in match.group(2).split(",")]
84+
shape = [int(s) for s in match.group(3).split(",")]
85+
86+
def check_tensor_type_and_suggest_contiguous(obj):
    """
    Identifies which framework `obj` belongs to and suggests how to make it contiguous.

    The framework is recognized by searching for a marker substring in the printed
    form of the object's class and of each of its base classes. Matching on text
    means no framework has to be imported here, and walking the base classes means
    that subclasses of a framework type (for example, a subclass of `numpy.ndarray`)
    are recognized rather than being reported as unknown.

    Args:
        obj: The array or tensor object to classify.

    Returns:
        A `(tensor_type, suggestion)` tuple of strings. When no marker matches,
        `("Unknown Type", "Cannot suggest method for unknown type")` is returned.
    """
    # `type(obj).__mro__` starts with the object's own class, so anything that
    # matched on `str(type(obj))` alone still matches here.
    class_names = [str(cls) for cls in type(obj).__mro__]

    def has_marker(*markers):
        return any(marker in name for name in class_names for marker in markers)

    # The order of the checks is significant and mirrors the framework priority.
    if has_marker("torch.Tensor"):
        return "PyTorch Tensor", "tensor.contiguous() or tensor.clone()"
    if has_marker("jaxlib", "jax.numpy"):
        return "JAX Array", "jax.numpy.asarray(array) or jax.numpy.copy(array)"
    if has_marker("numpy.ndarray"):
        return "NumPy Array", "np.ascontiguousarray(array) or array.copy(order='C')"
    if has_marker("cupy.ndarray"):
        return "CuPy Array", "cp.ascontiguousarray(array) or array.copy(order='C')"
    return "Unknown Type", "Cannot suggest method for unknown type"
98+
99+
tensor_type, contiguous_suggestion = check_tensor_type_and_suggest_contiguous(data)
100+
101+
error_message = (
102+
f"Non-canonical strides detected:\n"
103+
f" Shape: {shape}\n"
104+
f" Current stride: {given_strides}\n"
105+
f" Expected canonical stride: {canonical_strides}\n"
106+
f"Non-canonical strides are not supported for Tripy tensors. "
107+
f"This usually occurs when the tensor is not contiguous in memory. "
108+
f"To resolve this issue:\n"
109+
f"For {tensor_type}, use {contiguous_suggestion} to ensure contiguity before converting to a Tripy tensor."
110+
)
111+
tripy_utils.raise_error(error_message)
112+
else:
113+
# If the error message doesn't match the expected format, re-raise the original exception
114+
raise
115+
return memref
70116

71117

72118
# TODO(#134): Consider move below functions to MLIR py bindings

tripy/tripy/backend/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def encode(tensor_info: TensorInfo) -> Dict[str, Any]:
3636
return {
3737
"rank": tensor_info.rank,
3838
"shape": tensor_info.shape,
39+
"stride": tensor_info.stride,
3940
"dtype": tensor_info.dtype,
4041
"device": tensor_info.device,
4142
}

tripy/tripy/utils/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import os
2424
import time
2525
import typing
26-
from typing import Any, List, Sequence, Union
26+
from typing import Any, List, Optional, Sequence, Union
2727

2828
from colored import Fore, Style
2929

0 commit comments

Comments
 (0)