44from cutlass .cutlass_dsl import T , dsl_user_op
55from cutlass ._mlir .dialects import llvm
66
7+
78@dsl_user_op
8- def ld_acquire (lock_ptr : cute .Pointer , * , loc = None , ip = None ) -> cutlass .Int32 :
9+ def ld_acquire (lock_ptr : cute .Pointer , * , loc = None , ip = None ) -> cutlass .Int32 :
910 lock_ptr_i64 = lock_ptr .toint (loc = loc , ip = ip ).ir_value ()
1011 state = llvm .inline_asm (
1112 T .i32 (),
@@ -18,8 +19,11 @@ def ld_acquire(lock_ptr : cute.Pointer, *, loc=None, ip=None) -> cutlass.Int32:
1819 )
1920 return cutlass .Int32 (state )
2021
22+
2123@dsl_user_op
22- def red_relaxed (lock_ptr : cute .Pointer , val : cutlass .Constexpr [Int32 ], * , loc = None , ip = None ) -> None :
24+ def red_relaxed (
25+ lock_ptr : cute .Pointer , val : cutlass .Constexpr [Int32 ], * , loc = None , ip = None
26+ ) -> None :
2327 lock_ptr_i64 = lock_ptr .toint (loc = loc , ip = ip ).ir_value ()
2428 llvm .inline_asm (
2529 None ,
@@ -31,8 +35,11 @@ def red_relaxed(lock_ptr : cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=N
3135 asm_dialect = llvm .AsmDialect .AD_ATT ,
3236 )
3337
38+
3439@dsl_user_op
35- def red_release (lock_ptr : cute .Pointer , val : cutlass .Constexpr [Int32 ], * , loc = None , ip = None ) -> None :
40+ def red_release (
41+ lock_ptr : cute .Pointer , val : cutlass .Constexpr [Int32 ], * , loc = None , ip = None
42+ ) -> None :
3643 lock_ptr_i64 = lock_ptr .toint (loc = loc , ip = ip ).ir_value ()
3744 llvm .inline_asm (
3845 None ,
@@ -43,28 +50,22 @@ def red_release(lock_ptr : cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=N
4350 is_align_stack = False ,
4451 asm_dialect = llvm .AsmDialect .AD_ATT ,
4552 )
46-
53+
54+
4755@cute .jit
48- def wait_eq (
49- lock_ptr : cute .Pointer ,
50- thread_idx : int | Int32 ,
51- flag_offset : int ,
52- val : Int32
53- ) -> None :
56+ def wait_eq (lock_ptr : cute .Pointer , thread_idx : int | Int32 , flag_offset : int , val : Int32 ) -> None :
5457 flag_ptr = lock_ptr + flag_offset
5558 if thread_idx == 0 :
5659 read_val = Int32 (0 )
5760 while read_val != val :
5861 read_val = ld_acquire (flag_ptr )
5962
63+
6064@cute .jit
6165def arrive_inc (
62- lock_ptr : cute .Pointer ,
63- thread_idx : int | Int32 ,
64- flag_offset : int ,
65- val : cutlass .Constexpr [Int32 ]
66+ lock_ptr : cute .Pointer , thread_idx : int | Int32 , flag_offset : int , val : cutlass .Constexpr [Int32 ]
6667) -> None :
6768 flag_ptr = lock_ptr + flag_offset
6869 if thread_idx == 0 :
6970 red_release (flag_ptr , val )
70- # red_relaxed(flag_ptr, val)
71+ # red_relaxed(flag_ptr, val)
0 commit comments