From 1c1a13f14a1353e3cfcb2cb3b0ab98f6ff2cdb6f Mon Sep 17 00:00:00 2001 From: Jhalak Patel Date: Fri, 27 Sep 2024 23:39:31 -0700 Subject: [PATCH] Add checks for non-canonical strides MLIR-TensorRT requires strides for function arguments and results in canonical order. https://github.com/NVIDIA/TensorRT-Incubator/pull/252 adds a check to validate memref stride against a canonical stride order. In Tripy, memref strides are derived from framework DL Pack tensors. Creating a memref with a non-canonical DL Pack tensor stride throws an exception. Add a try-catch block to catch such an exception and augment with suggestions on creating a DL Pack tensor with canonical stride for Tripy-supported frameworks. Add unit tests to create a non-canonical stride tensor to validate exceptions and suggestions. --- tripy/tests/backend/mlir/test_utils.py | 18 ++++++++ tripy/tests/frontend/test_stride.py | 56 ++++++++++++++++++++++++ tripy/tests/integration/test_allclose.py | 4 +- tripy/tests/integration/test_quantize.py | 2 +- tripy/tripy/backend/api/executable.py | 4 ++ tripy/tripy/backend/mlir/memref.py | 54 ++++++++++++++++++++++- tripy/tripy/backend/mlir/utils.py | 14 ++++++ 7 files changed, 148 insertions(+), 4 deletions(-) create mode 100644 tripy/tests/frontend/test_stride.py diff --git a/tripy/tests/backend/mlir/test_utils.py b/tripy/tests/backend/mlir/test_utils.py index e28304c6d..832e514e7 100644 --- a/tripy/tests/backend/mlir/test_utils.py +++ b/tripy/tests/backend/mlir/test_utils.py @@ -18,6 +18,10 @@ import pytest from mlir_tensorrt.compiler import ir +import cupy as cp +import numpy as np +import torch + import tripy from tripy.backend.mlir import utils as mlir_utils from tripy.common.datatype import DATA_TYPES @@ -47,3 +51,17 @@ def test_convert_dtype(self, dtype): "bool": ir.IntegerType.get_signless(1), }[dtype.name] ) + + @pytest.mark.parametrize( + "tensor, expected_type, expected_suggestion", + [ + (torch.tensor([1, 2, 3]), "PyTorch Tensor", "tensor.contiguous() or tensor.clone()"), + (np.array([1, 2, 3]), "NumPy Array", "np.ascontiguousarray(array) or array.copy(order='C')"), + (cp.array([1, 2, 3]), "CuPy Array", "cp.ascontiguousarray(array) or array.copy(order='C')"), + ([1, 2, 3], None, None), + ], + ) + def test_check_tensor_type_and_suggest_contiguous(self, tensor, expected_type, expected_suggestion): + result_type, result_suggestion = mlir_utils.check_tensor_type_and_suggest_contiguous(tensor) + assert result_type == expected_type + assert result_suggestion == expected_suggestion diff --git a/tripy/tests/frontend/test_stride.py b/tripy/tests/frontend/test_stride.py new file mode 100644 index 000000000..61940067a --- /dev/null +++ b/tripy/tests/frontend/test_stride.py @@ -0,0 +1,56 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest +import re + +import cupy as cp +import numpy as np +import torch + +import tripy as tp +from tests.helper import raises + + +class TestStride: + def test_non_canonical_stride(self): + test_cases = [ + ( + torch.arange(12, dtype=torch.float32).reshape(3, 4).transpose(0, 1), + lambda x: x.contiguous(), + lambda x: x.clone(memory_format=torch.contiguous_format), + ), + ( + cp.arange(12, dtype=cp.float32).reshape(3, 4).transpose(1, 0), + cp.ascontiguousarray, + lambda x: x.copy(order="C"), + ), + ( + np.arange(12, dtype=np.float32).reshape(3, 4).transpose(1, 0), + np.ascontiguousarray, + lambda x: x.copy(order="C"), + ), + ] + + for array, contiguous_func, copy_func in test_cases: + # Test for exception with non-canonical strides + with pytest.raises(tp.TripyException, match="Non-canonical strides are not supported for Tripy tensors"): + tp.Tensor(array) + + # Test successful creation with contiguous array + assert tp.Tensor(contiguous_func(array)) is not None + assert tp.Tensor(copy_func(array)) is not None diff --git a/tripy/tests/integration/test_allclose.py b/tripy/tests/integration/test_allclose.py index 0c2d9499a..cf83c4770 100644 --- a/tripy/tests/integration/test_allclose.py +++ b/tripy/tests/integration/test_allclose.py @@ -35,8 +35,8 @@ class TestAllClose: ], ) def test_all_close_float32(self, tensor_a, tensor_b, rtol, atol): - np_result = torch.allclose(torch.FloatTensor(tensor_a), torch.FloatTensor(tensor_b), rtol=rtol, atol=atol) + torch_result = torch.allclose(torch.FloatTensor(tensor_a), torch.FloatTensor(tensor_b), rtol=rtol, atol=atol) tp_result = tp.allclose( tp.Tensor(tensor_a, dtype=tp.float32), tp.Tensor(tensor_b, dtype=tp.float32), rtol=rtol, atol=atol ) - assert np_result == tp_result + assert torch_result == tp_result diff --git a/tripy/tests/integration/test_quantize.py b/tripy/tests/integration/test_quantize.py index b50293869..4ad6b2b96 100644 --- a/tripy/tests/integration/test_quantize.py +++ b/tripy/tests/integration/test_quantize.py @@ -118,4 +118,4 @@ def test_non_constant_scale(self): scale = tp.ones((4,)) quantized = tp.quantize(input, scale, tp.int8, dim=0) - assert bool(tp.all(quantized == tp.ones((4, 4), dtype=tp.int8))) + assert tp.allclose(quantized, tp.ones((4, 4), dtype=tp.int8), rtol=0.0, atol=0.0) diff --git a/tripy/tripy/backend/api/executable.py b/tripy/tripy/backend/api/executable.py index 16de171f7..4bfe4b84e 100644 --- a/tripy/tripy/backend/api/executable.py +++ b/tripy/tripy/backend/api/executable.py @@ -159,6 +159,10 @@ def add(a, b): tensor, ], ) + elif "Runtime stride mismatch" in str(err): + # Just raise the error for now. + raise raise_error(str(err)) + raise output_tensors = [Tensor(output, fetch_stack_info=False) for output in executor_outputs] diff --git a/tripy/tripy/backend/mlir/memref.py b/tripy/tripy/backend/mlir/memref.py index 73f5f9f55..5c5c670c3 100644 --- a/tripy/tripy/backend/mlir/memref.py +++ b/tripy/tripy/backend/mlir/memref.py @@ -15,9 +15,12 @@ # limitations under the License. # +import re + from functools import lru_cache from typing import Sequence +from tripy.utils import raise_error from tripy.backend.mlir import utils as mlir_utils from tripy.common import device as tp_device from tripy.common import utils as common_utils @@ -66,7 +69,56 @@ def create_memref_view(data): """ Creates a memref view of an array object that implements the dlpack interface. """ - return mlir_utils.MLIRRuntimeClient().create_memref_view_from_dlpack(data.__dlpack__()) + try: + memref = mlir_utils.MLIRRuntimeClient().create_memref_view_from_dlpack( + data.__dlpack__(), assert_canonical_strides=True + ) + except runtime.MTRTException as e: + error_msg = str(e) + match = re.search( + r"Given strides \[([\d, ]+)\] do not match canonical strides \[([\d, ]+)\] for shape \[([\d, ]+)\]", + error_msg, + ) + + if match: + given_strides = [int(s) for s in match.group(1).split(",")] + canonical_strides = [int(s) for s in match.group(2).split(",")] + shape = [int(s) for s in match.group(3).split(",")] + + def check_tensor_type_and_suggest_contiguous(obj): + obj_type = str(type(obj)) + if "torch.Tensor" in obj_type: + return "PyTorch Tensor", "tensor.contiguous() or tensor.clone()" + elif "jaxlib" in obj_type or "jax.numpy" in obj_type: + return "JAX Array", "jax.numpy.asarray(array) or jax.numpy.copy(array)" + elif "numpy.ndarray" in obj_type: + return "NumPy Array", "np.ascontiguousarray(array) or array.copy(order='C')" + elif "cupy.ndarray" in obj_type: + return "CuPy Array", "cp.ascontiguousarray(array) or array.copy(order='C')" + else: + return None, None + + tensor_type, contiguous_suggestion = check_tensor_type_and_suggest_contiguous(data) + + error_message = ( + f"Non-canonical strides detected:\n" + f" Shape: {shape}\n" + f" Current stride: {given_strides}\n" + f" Expected canonical stride: {canonical_strides}\n" + f"Non-canonical strides are not supported for Tripy tensors. " + f"This usually occurs when the tensor is not contiguous in memory. " + + ( + f"To resolve this issue:\n" + f"For {tensor_type}, use {contiguous_suggestion} to ensure contiguity before converting to a Tripy tensor." + if tensor_type is not None + else "" + ) + ) + raise_error(error_message) + else: + # If the error message doesn't match the expected format, re-raise the original exception + raise + return memref # TODO(#134): Consider move below functions to MLIR py bindings diff --git a/tripy/tripy/backend/mlir/utils.py b/tripy/tripy/backend/mlir/utils.py index 552987169..9d52cffbf 100644 --- a/tripy/tripy/backend/mlir/utils.py +++ b/tripy/tripy/backend/mlir/utils.py @@ -172,6 +172,20 @@ def get_constant_value(arg) -> Optional[ir.DenseElementsAttr]: return None +def check_tensor_type_and_suggest_contiguous(obj): + obj_type = str(type(obj)) + if "torch.Tensor" in obj_type: + return "PyTorch Tensor", "tensor.contiguous() or tensor.clone()" + elif "jaxlib" in obj_type or "jax.numpy" in obj_type: + return "JAX Array", "jax.numpy.asarray(array) or jax.numpy.copy(array)" + elif "numpy.ndarray" in obj_type: + return "NumPy Array", "np.ascontiguousarray(array) or array.copy(order='C')" + elif "cupy.ndarray" in obj_type: + return "CuPy Array", "cp.ascontiguousarray(array) or array.copy(order='C')" + else: + return None, None + + def remove_sym_attr(mlir_text: str) -> str: return re.sub(r"module @\S+ {", "module {", mlir_text)