Commit e9da05a

Fix to pass test case
1 parent eb89ab6 commit e9da05a

File tree

1 file changed

+105 -102 lines changed


torchscale/component/flash_attention.py

Lines changed: 105 additions & 102 deletions
@@ -5,116 +5,119 @@
 from typing import Any, Optional
 import torch
 
-try:
-    if torch.cuda.get_device_capability()[0] > 7:
-        from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func
+if torch.cuda.is_available():
+    try:
+        if torch.cuda.get_device_capability()[0] > 7:
+            from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func
 
-        def flash_attn_func(q, k, v, dropout=0.0, bias=None, softmax_scale=None, is_causal=False):
-            assert bias is None
-            attn, lse, _ = _flash_attn_func(q, k, v, dropout_p=dropout, softmax_scale=softmax_scale, causal=is_causal, return_attn_probs=True)
-            return attn, lse
+            def flash_attn_func(q, k, v, dropout=0.0, bias=None, softmax_scale=None, is_causal=False):
+                assert bias is None
+                attn, lse, _ = _flash_attn_func(q, k, v, dropout_p=dropout, softmax_scale=softmax_scale, causal=is_causal, return_attn_probs=True)
+                return attn, lse
 
-    else:
-        from xformers.ops.fmha import (
-            cutlass,
-            Inputs,
-            Context,
-            _memory_efficient_attention_forward_requires_grad,
-            _memory_efficient_attention_backward,
-            LowerTriangularMask,
-        )
+        else:
+            from xformers.ops.fmha import (
+                cutlass,
+                Inputs,
+                Context,
+                _memory_efficient_attention_forward_requires_grad,
+                _memory_efficient_attention_backward,
+                LowerTriangularMask,
+            )
 
-        class FlashAttnFunc(torch.autograd.Function):
-            @staticmethod
-            # type: ignore
-            def forward(ctx, q, k, v, dropout=0.0, bias=None, softmax_scale=None, is_causal=False):
-                if is_causal:
-                    assert bias is None
-                    attn_bias = LowerTriangularMask()
-                else:
-                    attn_bias = bias
+            class FlashAttnFunc(torch.autograd.Function):
+                @staticmethod
+                # type: ignore
+                def forward(ctx, q, k, v, dropout=0.0, bias=None, softmax_scale=None, is_causal=False):
+                    if is_causal:
+                        assert bias is None
+                        attn_bias = LowerTriangularMask()
+                    else:
+                        attn_bias = bias
 
-                inp = Inputs(
-                    query=q,
-                    key=k,
-                    value=v,
-                    attn_bias=attn_bias,
-                    p=dropout,
-                    scale=softmax_scale,
-                )
-                op_fw = cutlass.FwOp
-                op_bw = cutlass.BwOp
+                    inp = Inputs(
+                        query=q,
+                        key=k,
+                        value=v,
+                        attn_bias=attn_bias,
+                        p=dropout,
+                        scale=softmax_scale,
+                    )
+                    op_fw = cutlass.FwOp
+                    op_bw = cutlass.BwOp
 
-                out, op_ctx = _memory_efficient_attention_forward_requires_grad(
-                    inp=inp, op=op_fw
-                )
+                    out, op_ctx = _memory_efficient_attention_forward_requires_grad(
+                        inp=inp, op=op_fw
+                    )
 
-                # Saving attn_bias is a bit complicated, as the
-                # torch part should go in `save_for_backward`
-                if isinstance(inp.attn_bias, torch.Tensor):
-                    attn_bias_tensor = inp.attn_bias
-                    attn_bias_ctx = None
-                else:
-                    attn_bias_tensor = None
-                    attn_bias_ctx = inp.attn_bias
+                    # Saving attn_bias is a bit complicated, as the
+                    # torch part should go in `save_for_backward`
+                    if isinstance(inp.attn_bias, torch.Tensor):
+                        attn_bias_tensor = inp.attn_bias
+                        attn_bias_ctx = None
+                    else:
+                        attn_bias_tensor = None
+                        attn_bias_ctx = inp.attn_bias
 
-                ctx.save_for_backward(
-                    inp.query,
-                    inp.key,
-                    inp.value,
-                    op_ctx.out,
-                    op_ctx.lse,
-                )
-                ctx.rng_state = op_ctx.rng_state
-                ctx.attn_bias_tensor = attn_bias_tensor
-                if op_ctx.op_bw is not None:
-                    if op_bw is not None and op_bw is not op_ctx.op_bw:
-                        raise ValueError(
-                            f"Specified op_bw={op_bw.NAME}, but forward op "
-                            f"can only run with op_bw={op_ctx.op_bw.NAME}. Please set op_bw=None."
-                        )
-                    op_bw = op_ctx.op_bw
-                ctx.op_fw = op_fw
-                ctx.op_bw = op_bw
-                ctx.p = inp.p
+                    ctx.save_for_backward(
+                        inp.query,
+                        inp.key,
+                        inp.value,
+                        op_ctx.out,
+                        op_ctx.lse,
+                    )
+                    ctx.rng_state = op_ctx.rng_state
+                    ctx.attn_bias_tensor = attn_bias_tensor
+                    if op_ctx.op_bw is not None:
+                        if op_bw is not None and op_bw is not op_ctx.op_bw:
+                            raise ValueError(
+                                f"Specified op_bw={op_bw.NAME}, but forward op "
+                                f"can only run with op_bw={op_ctx.op_bw.NAME}. Please set op_bw=None."
+                            )
+                        op_bw = op_ctx.op_bw
+                    ctx.op_fw = op_fw
+                    ctx.op_bw = op_bw
+                    ctx.p = inp.p
 
-                ctx.scale = inp.scale
-                ctx.attn_bias_ctx = attn_bias_ctx
-                return out, op_ctx.lse
+                    ctx.scale = inp.scale
+                    ctx.attn_bias_ctx = attn_bias_ctx
+                    return out, op_ctx.lse
 
-            @staticmethod
-            def deserialize_bias(
-                attn_bias_ctx, attn_bias_tensor: Optional[torch.Tensor]
-            ) -> Any:
-                if attn_bias_tensor is None:
-                    return attn_bias_ctx
-                return attn_bias_tensor
+                @staticmethod
+                def deserialize_bias(
+                    attn_bias_ctx, attn_bias_tensor: Optional[torch.Tensor]
+                ) -> Any:
+                    if attn_bias_tensor is None:
+                        return attn_bias_ctx
+                    return attn_bias_tensor
 
-            @classmethod
-            @torch.autograd.function.once_differentiable
-            def backward(cls, ctx, grad, dlse):
-                # Re-create context
-                query, key, value, out, lse = ctx.saved_tensors
-                attn_bias_tensor = ctx.attn_bias_tensor
-                rng_state = ctx.rng_state
-                inp = Inputs(
-                    query=query,
-                    key=key,
-                    value=value,
-                    attn_bias=cls.deserialize_bias(ctx.attn_bias_ctx, attn_bias_tensor),
-                    p=ctx.p,
-                    scale=ctx.scale,
-                )
-                op_ctx = Context(
-                    lse=lse,
-                    out=out,
-                    rng_state=rng_state,
-                )
-                grads = _memory_efficient_attention_backward(
-                    ctx=op_ctx, inp=inp, grad=grad, op=ctx.op_bw
-                )
-                return grads.dq, grads.dk, grads.dv, None, grads.db, None, None
-
-        flash_attn_func = FlashAttnFunc.apply
-except ModuleNotFoundError:
+                @classmethod
+                @torch.autograd.function.once_differentiable
+                def backward(cls, ctx, grad, dlse):
+                    # Re-create context
+                    query, key, value, out, lse = ctx.saved_tensors
+                    attn_bias_tensor = ctx.attn_bias_tensor
+                    rng_state = ctx.rng_state
+                    inp = Inputs(
+                        query=query,
+                        key=key,
+                        value=value,
+                        attn_bias=cls.deserialize_bias(ctx.attn_bias_ctx, attn_bias_tensor),
+                        p=ctx.p,
+                        scale=ctx.scale,
+                    )
+                    op_ctx = Context(
+                        lse=lse,
+                        out=out,
+                        rng_state=rng_state,
+                    )
+                    grads = _memory_efficient_attention_backward(
+                        ctx=op_ctx, inp=inp, grad=grad, op=ctx.op_bw
+                    )
+                    return grads.dq, grads.dk, grads.dv, None, grads.db, None, None
+
+            flash_attn_func = FlashAttnFunc.apply
+    except ModuleNotFoundError:
+        flash_attn_func = None
+else:
     flash_attn_func = None
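
For reference, a minimal usage sketch of the guarded symbol after this change. The `attention` helper and the CPU fallback below are illustrative assumptions, not part of this commit; only the import path and the behaviour of `flash_attn_func` (a fused kernel when CUDA is available, `None` otherwise) come from the file above.

# Minimal sketch (assumption, not part of this commit): consuming the guarded symbol.
# On a CPU-only machine the module now imports cleanly and flash_attn_func is None,
# so the caller can fall back to PyTorch's built-in scaled dot-product attention.
import torch
import torch.nn.functional as F

from torchscale.component.flash_attention import flash_attn_func


def attention(q, k, v, is_causal=False):
    # q, k, v: (batch, seqlen, num_heads, head_dim), the layout the fused kernels expect.
    if flash_attn_func is not None:
        # CUDA path: either the flash-attn kernel (compute capability > 7) or the
        # xFormers cutlass op; both variants return (output, logsumexp).
        # Arguments are passed positionally so the same call works for both branches.
        out, _lse = flash_attn_func(q, k, v, 0.0, None, None, is_causal)
        return out
    # CPU fallback (illustrative only): PyTorch 2.x SDPA expects
    # (batch, num_heads, seqlen, head_dim), hence the transposes.
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=is_causal
    )
    return out.transpose(1, 2)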
