With tilelang >= 0.1.8, the original issue apache/tvm#18798 no longer raises a readable error message. Instead, the process crashes with the following output:

```
python(69796,0x1f328ec40) malloc: *** error for object 0x8000000000000070: pointer being freed was not allocated
python(69796,0x1f328ec40) malloc: *** set a breakpoint in malloc_error_break to debug
```
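As a debugging aid (my own suggestion, not part of the original report), enabling the stdlib `faulthandler` before importing tilelang can sometimes recover the Python-level traceback: the malloc error ends in `abort()`, i.e. SIGABRT, which `faulthandler` traps.

```python
import faulthandler

# Dump the Python traceback on fatal signals (SIGSEGV, SIGABRT, ...).
# The malloc error above aborts the process, so this may reveal which
# Python call was active when the native crash happened.
faulthandler.enable()
```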
It can be reproduced with the following tilelang script:
```python
import tilelang

print("Imported tilelang")
from tilelang import tvm as tvm
from time import sleep

# import tilelang.testing
import tilelang.language as T
import json
import torch
import os

print("Imports done", flush=True)
from tilelang.engine.callback import register_metal_postproc_callback


@register_metal_postproc_callback
def _p(code, target):
    print(code)
    return code


@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float32", accum_dtype="float"):

    @T.prim_func
    def gemm(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared")
            B_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared")
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[ko * block_K, bx * block_N], B_shared)
                for i, j in T.Parallel(block_M, block_N):
                    for k in T.Serial(block_K):
                        C_local[i, j] += A_shared[i, k] * B_shared[k, j]
            T.copy(C_local, C[by * block_M, bx * block_N])

    return gemm


def benchmark(f, n, *args, **kwargs):
    # trigger jit
    f(*args, **kwargs)
    torch.mps.synchronize()
    with torch.mps.profiler.profile(mode="interval,event", wait_until_completed=True):
        start = torch.mps.Event(enable_timing=True)
        end = torch.mps.Event(enable_timing=True)
        start.record()
        for _ in range(n):
            f(*args, **kwargs)
        end.record()
        start.synchronize()
        end.synchronize()
    return start.elapsed_time(end) / 1000


if __name__ == "__main__":
    m = n = k = 128
    torch_dtype = torch.float16
    dtype = "float16"
    a = torch.randn(m, k, device="mps", dtype=torch_dtype)
    b = torch.randn(k, n, device="mps", dtype=torch_dtype)
    c = torch.zeros(m, n, device="mps", dtype=torch_dtype)
    # torch_add = lambda: torch.matmul(a, b, out=c)
    # torch_add()
    # print(benchmark(torch_add, n=100))
    print("Starting compilation...", flush=True)
    jit_kernel = matmul(m, n, k, 16, 16, 16, dtype=dtype, accum_dtype="float")
    print("Compilation finished.", flush=True)
    # print(jit_kernel.get_kernel_source())
    jit_kernel(a, b, c)
    print(c)
    print(a @ b)
```
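As a side note, once the crash is resolved, the two prints at the end of the script could be replaced by an explicit correctness check. This is a sketch continuing the script above; the float16 tolerances are arbitrary picks, not values from the report:

```python
# Compare the tilelang result against the PyTorch reference on MPS.
# Tolerances are loose, arbitrary choices to absorb float16 accumulation error.
ref = a @ b
torch.mps.synchronize()
torch.testing.assert_close(c, ref, rtol=1e-2, atol=1e-2)
print("tilelang matmul matches torch.matmul")
```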