Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions tripy/tests/backend/mlir/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
import pytest
from mlir_tensorrt.compiler import ir

import cupy as cp
import numpy as np
import torch

import tripy
from tripy.backend.mlir import utils as mlir_utils
from tripy.common.datatype import DATA_TYPES
Expand Down Expand Up @@ -47,3 +51,17 @@ def test_convert_dtype(self, dtype):
"bool": ir.IntegerType.get_signless(1),
}[dtype.name]
)

@pytest.mark.parametrize(
"tensor, expected_type, expected_suggestion",
[
(torch.tensor([1, 2, 3]), "PyTorch Tensor", "tensor.contiguous() or tensor.clone()"),
(np.array([1, 2, 3]), "NumPy Array", "np.ascontiguousarray(array) or array.copy(order='C')"),
(cp.array([1, 2, 3]), "CuPy Array", "cp.ascontiguousarray(array) or array.copy(order='C')"),
([1, 2, 3], None, None),
],
)
def test_check_tensor_type_and_suggest_contiguous(self, tensor, expected_type, expected_suggestion):
result_type, result_suggestion = mlir_utils.check_tensor_type_and_suggest_contiguous(tensor)
assert result_type == expected_type
assert result_suggestion == expected_suggestion
56 changes: 56 additions & 0 deletions tripy/tests/frontend/test_stride.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import pytest
import re

import cupy as cp
import numpy as np
import torch

import tripy as tp
from tests.helper import raises


class TestStride:
def test_non_canonical_stride(self):
test_cases = [
(
torch.arange(12, dtype=torch.float32).reshape(3, 4).transpose(0, 1),
lambda x: x.contiguous(),
lambda x: x.clone(memory_format=torch.contiguous_format),
),
(
cp.arange(12, dtype=cp.float32).reshape(3, 4).transpose(1, 0),
cp.ascontiguousarray,
lambda x: x.copy(order="C"),
),
(
np.arange(12, dtype=np.float32).reshape(3, 4).transpose(1, 0),
np.ascontiguousarray,
lambda x: x.copy(order="C"),
),
]

for array, contiguous_func, copy_func in test_cases:
# Test for exception with non-canonical strides
with pytest.raises(tp.TripyException, match="Non-canonical strides are not supported for Tripy tensors"):
tp.Tensor(array)

# Test successful creation with contiguous array
assert tp.Tensor(contiguous_func(array)) is not None
assert tp.Tensor(copy_func(array)) is not None
4 changes: 2 additions & 2 deletions tripy/tests/integration/test_allclose.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ class TestAllClose:
],
)
def test_all_close_float32(self, tensor_a, tensor_b, rtol, atol):
np_result = torch.allclose(torch.FloatTensor(tensor_a), torch.FloatTensor(tensor_b), rtol=rtol, atol=atol)
torch_result = torch.allclose(torch.FloatTensor(tensor_a), torch.FloatTensor(tensor_b), rtol=rtol, atol=atol)
tp_result = tp.allclose(
tp.Tensor(tensor_a, dtype=tp.float32), tp.Tensor(tensor_b, dtype=tp.float32), rtol=rtol, atol=atol
)
assert np_result == tp_result
assert torch_result == tp_result
2 changes: 1 addition & 1 deletion tripy/tests/integration/test_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,4 +118,4 @@ def test_non_constant_scale(self):
scale = tp.ones((4,))
quantized = tp.quantize(input, scale, tp.int8, dim=0)

assert bool(tp.all(quantized == tp.ones((4, 4), dtype=tp.int8)))
assert tp.allclose(quantized, tp.ones((4, 4), dtype=tp.int8), rtol=0.0, atol=0.0)
4 changes: 4 additions & 0 deletions tripy/tripy/backend/api/executable.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,10 @@ def add(a, b):
tensor,
],
)
elif "Runtime stride mismatch" in str(err):
# Just raise the error for now.
raise raise_error(str(err))

raise

output_tensors = [Tensor(output, fetch_stack_info=False) for output in executor_outputs]
Expand Down
54 changes: 53 additions & 1 deletion tripy/tripy/backend/mlir/memref.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@
# limitations under the License.
#

import re

from functools import lru_cache
from typing import Sequence

from tripy.utils import raise_error
from tripy.backend.mlir import utils as mlir_utils
from tripy.common import device as tp_device
from tripy.common import utils as common_utils
Expand Down Expand Up @@ -66,7 +69,56 @@ def create_memref_view(data):
"""
Creates a memref view of an array object that implements the dlpack interface.
"""
return mlir_utils.MLIRRuntimeClient().create_memref_view_from_dlpack(data.__dlpack__())
try:
memref = mlir_utils.MLIRRuntimeClient().create_memref_view_from_dlpack(
data.__dlpack__(), assert_canonical_strides=True
)
except runtime.MTRTException as e:
error_msg = str(e)
match = re.search(
r"Given strides \[([\d, ]+)\] do not match canonical strides \[([\d, ]+)\] for shape \[([\d, ]+)\]",
error_msg,
)

if match:
given_strides = [int(s) for s in match.group(1).split(",")]
canonical_strides = [int(s) for s in match.group(2).split(",")]
shape = [int(s) for s in match.group(3).split(",")]

def check_tensor_type_and_suggest_contiguous(obj):
obj_type = str(type(obj))
if "torch.Tensor" in obj_type:
return "PyTorch Tensor", "tensor.contiguous() or tensor.clone()"
elif "jaxlib" in obj_type or "jax.numpy" in obj_type:
return "JAX Array", "jax.numpy.asarray(array) or jax.numpy.copy(array)"
elif "numpy.ndarray" in obj_type:
return "NumPy Array", "np.ascontiguousarray(array) or array.copy(order='C')"
elif "cupy.ndarray" in obj_type:
return "CuPy Array", "cp.ascontiguousarray(array) or array.copy(order='C')"
else:
return None, None

tensor_type, contiguous_suggestion = check_tensor_type_and_suggest_contiguous(data)

error_message = (
f"Non-canonical strides detected:\n"
f" Shape: {shape}\n"
f" Current stride: {given_strides}\n"
f" Expected canonical stride: {canonical_strides}\n"
f"Non-canonical strides are not supported for Tripy tensors. "
f"This usually occurs when the tensor is not contiguous in memory. "
+ (
f"To resolve this issue:\n"
f"For {tensor_type}, use {contiguous_suggestion} to ensure contiguity before converting to a Tripy tensor."
if tensor_type is not None
else ""
)
)
raise_error(error_message)
else:
# If the error message doesn't match the expected format, re-raise the original exception
raise
return memref


# TODO(#134): Consider move below functions to MLIR py bindings
Expand Down
14 changes: 14 additions & 0 deletions tripy/tripy/backend/mlir/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,20 @@ def get_constant_value(arg) -> Optional[ir.DenseElementsAttr]:
return None


def check_tensor_type_and_suggest_contiguous(obj):
obj_type = str(type(obj))
if "torch.Tensor" in obj_type:
return "PyTorch Tensor", "tensor.contiguous() or tensor.clone()"
elif "jaxlib" in obj_type or "jax.numpy" in obj_type:
return "JAX Array", "jax.numpy.asarray(array) or jax.numpy.copy(array)"
elif "numpy.ndarray" in obj_type:
return "NumPy Array", "np.ascontiguousarray(array) or array.copy(order='C')"
elif "cupy.ndarray" in obj_type:
return "CuPy Array", "cp.ascontiguousarray(array) or array.copy(order='C')"
else:
return None, None


def remove_sym_attr(mlir_text: str) -> str:
return re.sub(r"module @\S+ {", "module {", mlir_text)

Expand Down
Loading