@@ -5,116 +5,119 @@
 from typing import Any, Optional
 import torch
 
-try:
-    if torch.cuda.get_device_capability()[0] > 7:
-        from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func
+if torch.cuda.is_available():
+    try:
+        if torch.cuda.get_device_capability()[0] > 7:
+            from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func
 
-        def flash_attn_func(q, k, v, dropout=0.0, bias=None, softmax_scale=None, is_causal=False):
-            assert bias is None
-            attn, lse, _ = _flash_attn_func(q, k, v, dropout_p=dropout, softmax_scale=softmax_scale, causal=is_causal, return_attn_probs=True)
-            return attn, lse
+            def flash_attn_func(q, k, v, dropout=0.0, bias=None, softmax_scale=None, is_causal=False):
+                assert bias is None
+                attn, lse, _ = _flash_attn_func(q, k, v, dropout_p=dropout, softmax_scale=softmax_scale, causal=is_causal, return_attn_probs=True)
+                return attn, lse
 
-    else:
-        from xformers.ops.fmha import (
-            cutlass,
-            Inputs,
-            Context,
-            _memory_efficient_attention_forward_requires_grad,
-            _memory_efficient_attention_backward,
-            LowerTriangularMask,
-        )
+        else:
+            from xformers.ops.fmha import (
+                cutlass,
+                Inputs,
+                Context,
+                _memory_efficient_attention_forward_requires_grad,
+                _memory_efficient_attention_backward,
+                LowerTriangularMask,
+            )
 
-        class FlashAttnFunc(torch.autograd.Function):
-            @staticmethod
-            # type: ignore
-            def forward(ctx, q, k, v, dropout=0.0, bias=None, softmax_scale=None, is_causal=False):
-                if is_causal:
-                    assert bias is None
-                    attn_bias = LowerTriangularMask()
-                else:
-                    attn_bias = bias
+            class FlashAttnFunc(torch.autograd.Function):
+                @staticmethod
+                # type: ignore
+                def forward(ctx, q, k, v, dropout=0.0, bias=None, softmax_scale=None, is_causal=False):
+                    if is_causal:
+                        assert bias is None
+                        attn_bias = LowerTriangularMask()
+                    else:
+                        attn_bias = bias
 
-                inp = Inputs(
-                    query=q,
-                    key=k,
-                    value=v,
-                    attn_bias=attn_bias,
-                    p=dropout,
-                    scale=softmax_scale,
-                )
-                op_fw = cutlass.FwOp
-                op_bw = cutlass.BwOp
+                    inp = Inputs(
+                        query=q,
+                        key=k,
+                        value=v,
+                        attn_bias=attn_bias,
+                        p=dropout,
+                        scale=softmax_scale,
+                    )
+                    op_fw = cutlass.FwOp
+                    op_bw = cutlass.BwOp
 
-                out, op_ctx = _memory_efficient_attention_forward_requires_grad(
-                    inp=inp, op=op_fw
-                )
+                    out, op_ctx = _memory_efficient_attention_forward_requires_grad(
+                        inp=inp, op=op_fw
+                    )
 
-                # Saving attn_bias is a bit complicated, as the
-                # torch part should go in `save_for_backward`
-                if isinstance(inp.attn_bias, torch.Tensor):
-                    attn_bias_tensor = inp.attn_bias
-                    attn_bias_ctx = None
-                else:
-                    attn_bias_tensor = None
-                    attn_bias_ctx = inp.attn_bias
+                    # Saving attn_bias is a bit complicated, as the
+                    # torch part should go in `save_for_backward`
+                    if isinstance(inp.attn_bias, torch.Tensor):
+                        attn_bias_tensor = inp.attn_bias
+                        attn_bias_ctx = None
+                    else:
+                        attn_bias_tensor = None
+                        attn_bias_ctx = inp.attn_bias
 
-                ctx.save_for_backward(
-                    inp.query,
-                    inp.key,
-                    inp.value,
-                    op_ctx.out,
-                    op_ctx.lse,
-                )
-                ctx.rng_state = op_ctx.rng_state
-                ctx.attn_bias_tensor = attn_bias_tensor
-                if op_ctx.op_bw is not None:
-                    if op_bw is not None and op_bw is not op_ctx.op_bw:
-                        raise ValueError(
-                            f"Specified op_bw={op_bw.NAME}, but forward op "
-                            f"can only run with op_bw={op_ctx.op_bw.NAME}. Please set op_bw=None."
-                        )
-                    op_bw = op_ctx.op_bw
-                ctx.op_fw = op_fw
-                ctx.op_bw = op_bw
-                ctx.p = inp.p
+                    ctx.save_for_backward(
+                        inp.query,
+                        inp.key,
+                        inp.value,
+                        op_ctx.out,
+                        op_ctx.lse,
+                    )
+                    ctx.rng_state = op_ctx.rng_state
+                    ctx.attn_bias_tensor = attn_bias_tensor
+                    if op_ctx.op_bw is not None:
+                        if op_bw is not None and op_bw is not op_ctx.op_bw:
+                            raise ValueError(
+                                f"Specified op_bw={op_bw.NAME}, but forward op "
+                                f"can only run with op_bw={op_ctx.op_bw.NAME}. Please set op_bw=None."
+                            )
+                        op_bw = op_ctx.op_bw
+                    ctx.op_fw = op_fw
+                    ctx.op_bw = op_bw
+                    ctx.p = inp.p
 
-                ctx.scale = inp.scale
-                ctx.attn_bias_ctx = attn_bias_ctx
-                return out, op_ctx.lse
+                    ctx.scale = inp.scale
+                    ctx.attn_bias_ctx = attn_bias_ctx
+                    return out, op_ctx.lse
 
-            @staticmethod
-            def deserialize_bias(
-                attn_bias_ctx, attn_bias_tensor: Optional[torch.Tensor]
-            ) -> Any:
-                if attn_bias_tensor is None:
-                    return attn_bias_ctx
-                return attn_bias_tensor
+                @staticmethod
+                def deserialize_bias(
+                    attn_bias_ctx, attn_bias_tensor: Optional[torch.Tensor]
+                ) -> Any:
+                    if attn_bias_tensor is None:
+                        return attn_bias_ctx
+                    return attn_bias_tensor
 
-            @classmethod
-            @torch.autograd.function.once_differentiable
-            def backward(cls, ctx, grad, dlse):
-                # Re-create context
-                query, key, value, out, lse = ctx.saved_tensors
-                attn_bias_tensor = ctx.attn_bias_tensor
-                rng_state = ctx.rng_state
-                inp = Inputs(
-                    query=query,
-                    key=key,
-                    value=value,
-                    attn_bias=cls.deserialize_bias(ctx.attn_bias_ctx, attn_bias_tensor),
-                    p=ctx.p,
-                    scale=ctx.scale,
-                )
-                op_ctx = Context(
-                    lse=lse,
-                    out=out,
-                    rng_state=rng_state,
-                )
-                grads = _memory_efficient_attention_backward(
-                    ctx=op_ctx, inp=inp, grad=grad, op=ctx.op_bw
-                )
-                return grads.dq, grads.dk, grads.dv, None, grads.db, None, None
-
-        flash_attn_func = FlashAttnFunc.apply
-except ModuleNotFoundError:
+                @classmethod
+                @torch.autograd.function.once_differentiable
+                def backward(cls, ctx, grad, dlse):
+                    # Re-create context
+                    query, key, value, out, lse = ctx.saved_tensors
+                    attn_bias_tensor = ctx.attn_bias_tensor
+                    rng_state = ctx.rng_state
+                    inp = Inputs(
+                        query=query,
+                        key=key,
+                        value=value,
+                        attn_bias=cls.deserialize_bias(ctx.attn_bias_ctx, attn_bias_tensor),
+                        p=ctx.p,
+                        scale=ctx.scale,
+                    )
+                    op_ctx = Context(
+                        lse=lse,
+                        out=out,
+                        rng_state=rng_state,
+                    )
+                    grads = _memory_efficient_attention_backward(
+                        ctx=op_ctx, inp=inp, grad=grad, op=ctx.op_bw
+                    )
+                    return grads.dq, grads.dk, grads.dv, None, grads.db, None, None
+
+            flash_attn_func = FlashAttnFunc.apply
+    except ModuleNotFoundError:
+        flash_attn_func = None
+else:
     flash_attn_func = None
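
For reference, the patch makes this module degrade gracefully: on CUDA devices with compute capability 8.x or newer it wraps flash-attn, on older CUDA devices it falls back to xformers' memory-efficient attention, and it now leaves flash_attn_func = None both when CUDA is unavailable and when the backend import fails. The sketch below is a minimal, hypothetical consumer of that symbol and is not part of the patch; the module name `attn_compat` and the helper names are assumptions, and the fused backends additionally expect half-precision CUDA tensors laid out as (batch, seq_len, num_heads, head_dim).

    # Hypothetical consumer of the patched module; `attn_compat` stands in for
    # the real file name, which is not shown in the diff.
    import math

    import torch

    from attn_compat import flash_attn_func


    def naive_attention(q, k, v, is_causal=False):
        # Pure-PyTorch fallback for when flash_attn_func is None
        # (no CUDA, or neither flash-attn nor xformers is installed).
        scale = 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum("bqhd,bkhd->bhqk", q, k) * scale
        if is_causal:
            causal_mask = torch.triu(
                torch.ones(scores.shape[-2:], dtype=torch.bool, device=q.device),
                diagonal=1,
            )
            scores = scores.masked_fill(causal_mask, float("-inf"))
        probs = scores.softmax(dim=-1)
        return torch.einsum("bhqk,bkhd->bqhd", probs, v)


    def attend(q, k, v, is_causal=False):
        # q, k, v: (batch, seq_len, num_heads, head_dim)
        if flash_attn_func is None:
            return naive_attention(q, k, v, is_causal=is_causal)
        # Both fused backends return (output, logsumexp); only the output is
        # needed here, and arguments are passed positionally.
        out, _lse = flash_attn_func(q, k, v, 0.0, None, None, is_causal)
        return out

Positional arguments are used for the fused call because on the xformers path flash_attn_func is FlashAttnFunc.apply, and torch.autograd.Function.apply does not reliably accept keyword arguments.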