Skip to content

Commit 932d3ab

Browse files
format barrier.py with ruff
1 parent f3561a6 commit 932d3ab

File tree

1 file changed

+16
-15
lines changed

1 file changed

+16
-15
lines changed

flash_attn/cute/barrier.py

Lines changed: 16 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -4,8 +4,9 @@
44
from cutlass.cutlass_dsl import T, dsl_user_op
55
from cutlass._mlir.dialects import llvm
66

7+
78
@dsl_user_op
8-
def ld_acquire(lock_ptr : cute.Pointer, *, loc=None, ip=None) -> cutlass.Int32:
9+
def ld_acquire(lock_ptr: cute.Pointer, *, loc=None, ip=None) -> cutlass.Int32:
910
lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value()
1011
state = llvm.inline_asm(
1112
T.i32(),
@@ -18,8 +19,11 @@ def ld_acquire(lock_ptr : cute.Pointer, *, loc=None, ip=None) -> cutlass.Int32:
1819
)
1920
return cutlass.Int32(state)
2021

22+
2123
@dsl_user_op
22-
def red_relaxed(lock_ptr : cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None) -> None:
24+
def red_relaxed(
25+
lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None
26+
) -> None:
2327
lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value()
2428
llvm.inline_asm(
2529
None,
@@ -31,8 +35,11 @@ def red_relaxed(lock_ptr : cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=N
3135
asm_dialect=llvm.AsmDialect.AD_ATT,
3236
)
3337

38+
3439
@dsl_user_op
35-
def red_release(lock_ptr : cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None) -> None:
40+
def red_release(
41+
lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None
42+
) -> None:
3643
lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value()
3744
llvm.inline_asm(
3845
None,
@@ -43,28 +50,22 @@ def red_release(lock_ptr : cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=N
4350
is_align_stack=False,
4451
asm_dialect=llvm.AsmDialect.AD_ATT,
4552
)
46-
53+
54+
4755
@cute.jit
48-
def wait_eq(
49-
lock_ptr : cute.Pointer,
50-
thread_idx : int | Int32,
51-
flag_offset : int,
52-
val : Int32
53-
) -> None:
56+
def wait_eq(lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: Int32) -> None:
5457
flag_ptr = lock_ptr + flag_offset
5558
if thread_idx == 0:
5659
read_val = Int32(0)
5760
while read_val != val:
5861
read_val = ld_acquire(flag_ptr)
5962

63+
6064
@cute.jit
6165
def arrive_inc(
62-
lock_ptr : cute.Pointer,
63-
thread_idx : int | Int32,
64-
flag_offset : int,
65-
val : cutlass.Constexpr[Int32]
66+
lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: cutlass.Constexpr[Int32]
6667
) -> None:
6768
flag_ptr = lock_ptr + flag_offset
6869
if thread_idx == 0:
6970
red_release(flag_ptr, val)
70-
# red_relaxed(flag_ptr, val)
71+
# red_relaxed(flag_ptr, val)

0 commit comments

Comments (0)