|
15 | 15 | # limitations under the License. |
16 | 16 | # |
17 | 17 |
|
| 18 | +import re |
| 19 | + |
18 | 20 | from functools import lru_cache |
19 | 21 | from typing import Sequence |
20 | 22 |
|
| 23 | +from tripy.utils import raise_error |
21 | 24 | from tripy.backend.mlir import utils as mlir_utils |
22 | 25 | from tripy.common import device as tp_device |
23 | 26 | from tripy.common import utils as common_utils |
@@ -66,7 +69,56 @@ def create_memref_view(data): |
66 | 69 | """ |
67 | 70 | Creates a memref view of an array object that implements the dlpack interface. |
68 | 71 | """ |
69 | | - return mlir_utils.MLIRRuntimeClient().create_memref_view_from_dlpack(data.__dlpack__()) |
| 72 | + try: |
| 73 | + memref = mlir_utils.MLIRRuntimeClient().create_memref_view_from_dlpack( |
| 74 | + data.__dlpack__(), assert_canonical_strides=True |
| 75 | + ) |
| 76 | + except runtime.MTRTException as e: |
| 77 | + error_msg = str(e) |
| 78 | + match = re.search( |
| 79 | + r"Given strides \[([\d, ]+)\] do not match canonical strides \[([\d, ]+)\] for shape \[([\d, ]+)\]", |
| 80 | + error_msg, |
| 81 | + ) |
| 82 | + |
| 83 | + if match: |
| 84 | + given_strides = [int(s) for s in match.group(1).split(",")] |
| 85 | + canonical_strides = [int(s) for s in match.group(2).split(",")] |
| 86 | + shape = [int(s) for s in match.group(3).split(",")] |
| 87 | + |
| 88 | + def check_tensor_type_and_suggest_contiguous(obj): |
| 89 | + obj_type = str(type(obj)) |
| 90 | + if "torch.Tensor" in obj_type: |
| 91 | + return "PyTorch Tensor", "tensor.contiguous() or tensor.clone()" |
| 92 | + elif "jaxlib" in obj_type or "jax.numpy" in obj_type: |
| 93 | + return "JAX Array", "jax.numpy.asarray(array) or jax.numpy.copy(array)" |
| 94 | + elif "numpy.ndarray" in obj_type: |
| 95 | + return "NumPy Array", "np.ascontiguousarray(array) or array.copy(order='C')" |
| 96 | + elif "cupy.ndarray" in obj_type: |
| 97 | + return "CuPy Array", "cp.ascontiguousarray(array) or array.copy(order='C')" |
| 98 | + else: |
| 99 | + return None, None |
| 100 | + |
| 101 | + tensor_type, contiguous_suggestion = check_tensor_type_and_suggest_contiguous(data) |
| 102 | + |
| 103 | + error_message = ( |
| 104 | + f"Non-canonical strides detected:\n" |
| 105 | + f" Shape: {shape}\n" |
| 106 | + f" Current stride: {given_strides}\n" |
| 107 | + f" Expected canonical stride: {canonical_strides}\n" |
| 108 | + f"Non-canonical strides are not supported for Tripy tensors. " |
| 109 | + f"This usually occurs when the tensor is not contiguous in memory. " |
| 110 | + + ( |
| 111 | + f"To resolve this issue:\n" |
| 112 | + f"For {tensor_type}, use {contiguous_suggestion} to ensure contiguity before converting to a Tripy tensor." |
| 113 | + if tensor_type is not None |
| 114 | + else "" |
| 115 | + ) |
| 116 | + ) |
| 117 | + raise_error(error_message) |
| 118 | + else: |
| 119 | + # If the error message doesn't match the expected format, re-raise the original exception |
| 120 | + raise |
| 121 | + return memref |
70 | 122 |
|
71 | 123 |
|
72 | 124 | # TODO(#134): Consider move below functions to MLIR py bindings |
|
0 commit comments