Skip to content

Commit 7bb1883

Browse files
symm mem ag/rs
1 parent 5d549c1 commit 7bb1883

File tree

9 files changed

+766
-14
lines changed

9 files changed

+766
-14
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .multimem import multimem_all_gather, multimem_reduce_scatter
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
import triton
2+
import triton.language as tl
3+
4+
@triton.jit
5+
def multimem_ld_reduce_128(multicast_ptrs, mask):
6+
"""
7+
Multicast load and reduce 128 bits (4 x bf16) from all peers over nvlink
8+
Outputs are returned as 4 tl.uint32 registers, each containing 2 bf16 values
9+
"""
10+
return tl.inline_asm_elementwise(
11+
"""
12+
{
13+
.reg .pred %p0;
14+
setp.eq.s32 %p0, $5, 1;
15+
@!%p0 bra end;
16+
multimem.ld_reduce.relaxed.sys.global.add.v4.bf16x2 {$0, $1, $2, $3}, [$4];
17+
end:
18+
}
19+
""",
20+
"=r,=r,=r,=r,l,r",
21+
args=[multicast_ptrs, mask.to(tl.int32)],
22+
dtype=(tl.uint32, tl.uint32, tl.uint32, tl.uint32),
23+
is_pure=True,
24+
pack=1,
25+
)
26+
27+
28+
@triton.jit
29+
def multimem_st_128(multicast_ptrs, x, y, z, w, mask):
30+
"""
31+
Multicast store 128 bits (4 x bf16) to all peers over nvlink
32+
"""
33+
return tl.inline_asm_elementwise(
34+
"""
35+
{
36+
.reg .pred %p0;
37+
setp.eq.s32 %p0, $6, 1;
38+
@!%p0 bra end;
39+
multimem.st.relaxed.sys.global.v4.f32 [$1], {$2, $3, $4, $5};
40+
end:
41+
}
42+
""",
43+
"=r,l,r,r,r,r,r",
44+
args=[multicast_ptrs, x, y, z, w, mask.to(tl.int32)],
45+
dtype=(tl.uint32),
46+
is_pure=False,
47+
pack=1,
48+
)
49+
50+
@triton.jit
51+
def ld_128(ptr, mask):
52+
"""
53+
Load 128 bits (4 x bf16) from ptr
54+
Outputs are returned as 4 tl.uint32 registers, each containing 2 bf16 values
55+
"""
56+
return tl.inline_asm_elementwise(
57+
"""
58+
{
59+
.reg .pred %p0;
60+
setp.eq.s32 %p0, $5, 1;
61+
@!%p0 bra end;
62+
ld.global.relaxed.sys.v4.u32 {$0, $1, $2, $3}, [$4];
63+
end:
64+
}
65+
""",
66+
"=r,=r,=r,=r,l,r",
67+
args=[ptr, mask.to(tl.int32)],
68+
dtype=(tl.uint32, tl.uint32, tl.uint32, tl.uint32),
69+
is_pure=True,
70+
pack=1,
71+
)
72+
73+
74+
@triton.jit
75+
def st_128(ptr, x, y, z, w, mask):
76+
return tl.inline_asm_elementwise(
77+
"""
78+
{
79+
.reg .pred %p0;
80+
setp.eq.s32 %p0, $6, 1;
81+
@!%p0 bra end;
82+
st.global.relaxed.sys.v4.f32 [$1], {$2, $3, $4, $5};
83+
end:
84+
}
85+
""",
86+
"=r,l,r,r,r,r,r",
87+
args=[ptr, x, y, z, w, mask.to(tl.int32)],
88+
dtype=(tl.uint32),
89+
is_pure=False,
90+
pack=1,
91+
)
92+
93+
@triton.jit
94+
def add_v8_bf16_from_u32(
95+
a0, a1, a2, a3, # First vector of 8 bf16s, packed in 4 uint32s
96+
b0, b1, b2, b3, # Second vector of 8 bf16s, packed in 4 uint32s
97+
):
98+
"""
99+
Adds two vectors of 8 bfloat16 numbers.
100+
Each vector is passed as four tl.uint32 tensors.
101+
Returns the result as a tuple of four tl.uint32 tensors.
102+
"""
103+
return tl.inline_asm_elementwise(
104+
"""
105+
{
106+
add.bf16x2 $0, $4, $8;
107+
add.bf16x2 $1, $5, $9;
108+
add.bf16x2 $2, $6, $10;
109+
add.bf16x2 $3, $7, $11;
110+
}
111+
""",
112+
# 8 outputs (=r), 8 inputs (r)
113+
"=r,=r,=r,=r,r,r,r,r,r,r,r,r",
114+
args=[a0, a1, a2, a3, b0, b1, b2, b3],
115+
dtype=(tl.uint32, tl.uint32, tl.uint32, tl.uint32),
116+
is_pure=True,
117+
pack=1,
118+
)
119+
120+
@triton.jit
121+
def asm_rsqrt(x, eps):
122+
"""
123+
Computes the reciprocal square root of a float32 number using inline assembly.
124+
"""
125+
return tl.inline_asm_elementwise(
126+
"""
127+
{
128+
add.f32 $1, $1, $2;
129+
rsqrt.approx.f32 $0, $1;
130+
}
131+
""",
132+
"=f, f, f",
133+
args=[x, eps],
134+
dtype=(tl.float32),
135+
is_pure=True,
136+
pack=1,
137+
)
138+

0 commit comments

Comments
 (0)