Skip to content

Commit 342dec5

Browse files
authored
Create memerf with None stream for host data (NVIDIA#251)
Several tests creating memref for host data with stream are failing NVIDIA@54f9819#diff-d10f6cc35783e94f23c4ff7a348efac1ede25554509e0631d82ba6701f838455R71. It looks like MLIR-TRT is now propagating the stream correctly. Failure is observed in following check: ``` if (type == PointerType::host) { assert(alignment && !stream && "expected alignment, no stream for host allocation"); ... } ``` Tripy side fix is to pass None stream for host memref data.
1 parent 59b9536 commit 342dec5

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

tripy/tripy/backend/mlir/memref.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,13 @@
2828
@lru_cache(maxsize=None)
2929
def _cached_create_empty_memref(shape: Sequence[int], dtype: str, device_kind: str, stream):
3030
mlirtrt_device = mlir_utils.MLIRRuntimeClient().get_devices()[0] if device_kind == "gpu" else None
31+
mlirtrt_stream = stream if device_kind == "gpu" else None
3132
mlir_dtype = mlir_utils.convert_tripy_dtype_to_runtime_dtype(dtype)
3233
return mlir_utils.MLIRRuntimeClient().create_memref(
3334
shape=list(shape),
3435
dtype=mlir_dtype,
3536
device=mlirtrt_device,
36-
stream=stream,
37+
stream=mlirtrt_stream,
3738
)
3839

3940

0 commit comments

Comments
 (0)