
Commit 8354497

bottler authored and xFormers Bot committed
remove merge_attentions backward (fairinternal/xformers#1402)
__original_commit__ = fairinternal/xformers@601197a
1 parent 9eb546b commit 8354497

File tree

4 files changed: +87 -179 lines

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
 ## [0.0.32] - 2025-??-??
+### Removed
+- Removed autograd backward pass for merge_attentions as it is easy to use incorrectly.
 
 ## [0.0.31] - 2025-06-25
 Pre-built binary wheels are available for PyTorch 2.7.1.

tests/test_mem_eff_attention.py

Lines changed: 2 additions & 115 deletions
@@ -8,7 +8,7 @@
 import math
 import random
 from contextlib import nullcontext
-from typing import Any, List, Optional, Sequence, Tuple, Type, TypeVar, Union
+from typing import Any, List, Optional, Sequence, Tuple, Type, TypeVar
 
 import pytest
 import torch
@@ -2814,68 +2814,6 @@ def test_merge_attentions_nobias(
     assert lse is None
 
 
-@disable_on_rocm
-@sm80_or_better_only
-@pytest.mark.parametrize(
-    "op",
-    [
-        pytest.param(fmha.flash.FwOp, id="flashfwd"),
-        pytest.param((fmha.flash.FwOp, fmha.cutlass.BwOp), id="flashcutlass"),
-        # pytest.param((fmha.triton_splitk.FwOp, fmha.cutlass.BwOp), id="splitk"), # XXX
-        pytest.param(fmha.MemoryEfficientAttentionFlashAttentionOp, id="flash"),
-        None,
-    ],
-)
-def test_merge_attentions_nobias_bwd(
-    op: Union[Type[AttentionFwOpBase], fmha.AttentionOp],
-):
-    B, M, Mq, H, K = 13, 5, 5, 4, 128
-    dtype = torch.bfloat16
-    nparts = 3
-    torch.manual_seed(1)
-    q = 3 * torch.rand(B, Mq, H, K, dtype=dtype, device="cuda")
-    kv = [
-        [3 * (torch.rand(B, M, H, K, dtype=dtype, device="cuda")) for _ in range(2)]
-        for _ in range(nparts)
-    ]
-    q = q.requires_grad_(True)
-    kv = [[j.requires_grad_(True) for j in i] for i in kv]
-    out_parts = [fmha.memory_efficient_attention_partial(q, k, v, op=op) for k, v in kv]
-    attn_split, lse_split = [list(x) for x in zip(*out_parts)]
-    out_merged = fmha.merge_attentions(attn_split, lse_split, write_lse=True)[0]
-    grad_out = torch.rand_like(q)
-    out_merged.backward(grad_out)
-    grad_q_out = q.grad
-    assert q.grad is not None
-    grad_kv_out = [[j.grad for j in i] for i in kv]
-    q = q.detach().requires_grad_(True)
-    kv = [[j.detach().requires_grad_(True) for j in i] for i in kv]
-
-    k2, v2 = [torch.cat([i[j] for i in kv], dim=1) for j in range(2)]
-
-    if op is None or isinstance(op, tuple):
-        full_op = op
-    else:
-        full_op = (op, None)
-    out_full = fmha.memory_efficient_attention(q, k2, v2, op=full_op)  # type: ignore
-    out_full.backward(grad_out)
-    assert_allclose(
-        out_merged, out_full.to(out_merged.dtype), rtol=1e-2, atol=2e-2, msg="out"
-    )
-    atol = fmha.AttentionBwOpBase.ERROR_ATOL[dtype] * 1.5
-    rtol = fmha.AttentionBwOpBase.ERROR_RTOL[dtype]
-    assert_allclose(grad_q_out, q.grad, rtol=rtol, atol=atol, msg="qgrad")
-    for i in range(nparts):
-        for j in range(2):
-            assert_allclose(
-                grad_kv_out[i][j],
-                kv[i][j].grad,
-                rtol=rtol,
-                atol=atol,
-                msg=f"kvgrad {i} {j}",
-            )
-
-
 @disable_on_rocm
 @sm80_or_better_only
 @pytest.mark.parametrize(
@@ -3221,15 +3159,7 @@ def test_merge_attentions_sharedinput(
 
 @sm80_or_better_only
 @pytest.mark.parametrize("bmghk", (False, True))
-@pytest.mark.parametrize(
-    "stack_inputs", (False, True), ids=lambda x: "stack_inputs" if x else ""
-)
-@pytest.mark.parametrize(
-    "grad_var", ("lse", "attn", None)
-)  # Gradient with respect to attention, LSE, or neither
-def test_merge_attentions_against_ref(
-    bmghk: bool, stack_inputs: bool, grad_var: Optional[str]
-):
+def test_merge_attentions_against_ref(bmghk: bool):
     split_k = 16
     B = 12
     M = 137
@@ -3245,55 +3175,12 @@ def test_merge_attentions_against_ref(
         attn_split = attn_split[:, :, :, 0]
         lse_split = lse_split[:, :, 0]
 
-    if grad_var is not None:
-        attn_split.requires_grad_(True)
-        lse_split.requires_grad_(True)
-
     attn_out_ref, lse_out_ref = _merge_attentions_ref(attn_split, lse_split)
-    if grad_var is not None:
-        if grad_var == "attn":
-            out_grad = torch.randn_like(attn_out_ref)
-            attn_out_ref.backward(out_grad)
-        else:
-            out_grad = torch.randn_like(lse_out_ref)
-            lse_out_ref.backward(out_grad)
-
-        attn_grad_ref, lse_grad_ref = attn_split.grad, lse_split.grad
-
-        attn_split = attn_split.detach().unbind(0)  # type: ignore
-        lse_split = lse_split.detach().unbind(0)  # type: ignore
-
-        for x in attn_split + lse_split:
-            x.requires_grad_(True)
-            x.retain_grad()
-
     attn_out, lse_out = fmha.merge_attentions(attn_split, lse_split)
 
     torch.testing.assert_close(lse_out, lse_out_ref, rtol=1e-4, atol=1e-4)
     torch.testing.assert_close(attn_out, attn_out_ref, rtol=1e-4, atol=1e-4)
 
-    if grad_var is not None:
-        if grad_var == "attn":
-            attn_out.backward(out_grad)
-        else:
-            assert lse_out is not None
-            lse_out.backward(out_grad)
-
-        attn_grads = [x.grad for x in attn_split]
-        lse_grads = [x.grad for x in lse_split]
-        attn_grad_concat = torch.stack(attn_grads, dim=0)
-        lse_grad_concat = torch.stack(lse_grads, dim=0)
-
-        if grad_var == "lse":
-            # LSE doesn't depend on attn_split, so when only gradient with respect to LSE is provided as input,
-            # the output gradient with respect to attn_split is zero.
-            # The reference implementation produced None instead of zero in this case
-            attn_grad_ref = torch.zeros_like(attn_grad_concat)
-        torch.testing.assert_close(lse_grad_concat, lse_grad_ref, rtol=1e-4, atol=1e-4)
-        torch.testing.assert_close(
-            attn_grad_concat, attn_grad_ref, rtol=1e-4, atol=1e-4
-        )
-
 
 def _merge_attentions_ref(attn_split, lse_split):
     """

xformers/ops/fmha/__init__.py

Lines changed: 21 additions & 56 deletions
@@ -794,31 +794,34 @@ def merge_attentions(
     attn_dtype = attn_split[0].dtype
     lse_dtype = lse_split[0].dtype
 
-    attn_out = torch.empty(
-        B,
-        M,
-        G,
-        H,
-        Kq,
-        device=device,
-        dtype=output_dtype or attn_dtype,
-    )
-    if write_lse:
-        lse_out = torch.empty(
+    if concat_path:
+        attn_out = torch.empty(
             B,
+            M,
             G,
             H,
-            M,
+            Kq,
             device=device,
-            dtype=lse_dtype,
+            dtype=output_dtype or attn_dtype,
         )
-    else:
-        lse_out = None
-
-    if concat_path:
+        if write_lse:
+            lse_out = torch.empty(
+                B,
+                G,
+                H,
+                M,
+                device=device,
+                dtype=lse_dtype,
+            )
+        else:
+            lse_out = None
         triton_splitk.merge_attentions(attn_out, lse_out, attn_split, lse_split)  # type: ignore
     else:
-        attn_out, lse_out = _MergeAttentions.apply(attn_out, lse_out, *attn_split, *lse_split)  # type: ignore
+        outs = triton_splitk.merge_attentions_varargs(
+            attn_split, lse_split, write_lse, output_dtype, B, M, G, H, Kq
+        )  # type: ignore
+        attn_out = outs[0]
+        lse_out = outs[1] if write_lse else None
 
     if is_bmhk:
         attn_out = attn_out[:, :, 0]
@@ -828,44 +831,6 @@
     return attn_out, lse_out
 
 
-class _MergeAttentions(torch.autograd.Function):
-    @staticmethod
-    # type: ignore
-    def forward(
-        ctx, attn_out: torch.Tensor, lse_out: torch.Tensor, *inputs: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        num_chunks = len(inputs) // 2
-        attn_split, lse_split = inputs[:num_chunks], inputs[num_chunks:]
-
-        triton_splitk.merge_attentions_varargs(attn_out, lse_out, attn_split, lse_split)
-
-        ctx.save_for_backward(
-            attn_out,
-            lse_out,
-            *inputs,
-        )
-        return attn_out, lse_out
-
-    @staticmethod
-    # type: ignore
-    def backward(
-        ctx, grad_attn: torch.Tensor, grad_lse: torch.Tensor
-    ) -> Tuple[Optional[torch.Tensor], ...]:
-        out, lse, *inputs = ctx.saved_tensors
-        num_chunks = len(inputs) // 2
-        attn_split, lse_split = inputs[:num_chunks], inputs[num_chunks:]
-        dattn, dlse = triton_splitk.merge_attentions_varargs_backward(
-            attn_split,
-            lse_split,
-            out,
-            lse,
-            grad_attn,
-            grad_lse,
-        )
-        ret = [None, None] + dattn + dlse
-        return tuple(ret)
-
-
 ALL_FW_OPS: List[Type[AttentionFwOpBase]] = [
     cutlass.FwOp if torch.version.cuda else ck.FwOp,
     flash.FwOp,
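
With _MergeAttentions gone, merge_attentions is forward-only: the non-concat path calls the merge_attentions_varargs custom op, which now allocates and returns its outputs instead of mutating caller-provided buffers. A hedged usage sketch of the public API (shapes and the CUDA/bf16 setup are illustrative, adapted from the removed test; gradients should be obtained from a full memory_efficient_attention over the concatenated K/V rather than by backpropagating through the merge):

import torch
from xformers.ops import fmha

B, Mq, Mkv, H, K = 2, 16, 32, 4, 128
q = torch.randn(B, Mq, H, K, device="cuda", dtype=torch.bfloat16)
kv_chunks = [
    (
        torch.randn(B, Mkv, H, K, device="cuda", dtype=torch.bfloat16),
        torch.randn(B, Mkv, H, K, device="cuda", dtype=torch.bfloat16),
    )
    for _ in range(3)
]

# Attend to each K/V chunk separately, keeping the per-chunk outputs and LSEs...
out_parts = [fmha.memory_efficient_attention_partial(q, k, v) for k, v in kv_chunks]
attn_split, lse_split = [list(x) for x in zip(*out_parts)]

# ...then merge the partial results (forward only; no autograd through the merge).
attn, lse = fmha.merge_attentions(attn_split, lse_split, write_lse=True)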

xformers/ops/fmha/triton_splitk.py

Lines changed: 62 additions & 8 deletions
@@ -1047,19 +1047,37 @@ def merge_attentions(
 
 @torch.library.custom_op(
     "xformers::fmha_merge_attentions_varargs",
-    mutates_args=("attn_out", "lse_out"),
+    mutates_args=(),
     device_types=["cuda"],
 )
 def merge_attentions_varargs(
-    attn_out: torch.Tensor,
-    lse_out: Optional[torch.Tensor],
     attn_split: Sequence[torch.Tensor],
     lse_split: Sequence[torch.Tensor],
-) -> None:
+    write_lse: bool,
+    output_dtype: Optional[torch.dtype],
+    B: int,
+    M: int,
+    G: int,
+    H: int,
+    Kq: int,
+) -> List[torch.Tensor]:
     from xformers.triton.vararg_kernel import unroll_varargs
 
     from ._triton.splitk_kernels import _splitK_reduce_varargs
 
+    attn_out = torch.empty(
+        (B, M, G, H, Kq),
+        device=attn_split[0].device,
+        dtype=output_dtype or attn_split[0].dtype,
+    )
+    if write_lse:
+        lse_out = torch.empty(
+            (B, G, H, M),
+            device=attn_split[0].device,
+            dtype=lse_split[0].dtype,
+        )
+    else:
+        lse_out = None
     kernel_args, grid = _prepare_reduce_kernel_params(
         attn_out, lse_out, attn_split, lse_split
     )
@@ -1073,16 +1091,52 @@ def merge_attentions_varargs(
         BLOCK_SIZE=attn_out.shape[-1],
         WRITE_LSE=lse_out is not None,
     )
+    if write_lse:
+        assert lse_out is not None
+        return [attn_out, lse_out]
+    return [attn_out]
 
 
 @torch.library.register_fake("xformers::fmha_merge_attentions_varargs")
 def merge_attentions_varargs_fake(
-    attn_out: torch.Tensor,
-    lse_out: Optional[torch.Tensor],
     attn_split: Sequence[torch.Tensor],
     lse_split: Sequence[torch.Tensor],
-) -> None:
-    return
+    write_lse: bool,
+    output_dtype: Optional[torch.dtype],
+    B: int,
+    M: int,
+    G: int,
+    H: int,
+    Kq: int,
+) -> List[torch.Tensor]:
+    attn_out = torch.empty(
+        (B, M, G, H, Kq),
+        device=attn_split[0].device,
+        dtype=output_dtype or attn_split[0].dtype,
+    )
+    if write_lse:
+        lse_out = torch.empty(
+            (B, G, H, M),
+            device=attn_split[0].device,
+            dtype=lse_split[0].dtype,
+        )
+        return [attn_out, lse_out]
+    return [attn_out]
+
+
+def _merge_attentions_backward(
+    ctx: torch.autograd.function.FunctionCtx,
+    grad: List[torch.Tensor],
+) -> Tuple[None, ...]:
+    raise NotImplementedError(
+        "Backward pass is not implemented for merge_attentions. "
+        "If it was, it would be easy to get wrong attention gradients, "
+        "because the gradients of the LSEs "
+        "don't get propagated by attention backward."
+    )
+
+
+merge_attentions_varargs.register_autograd(_merge_attentions_backward)
 
 
 @torch.library.custom_op(
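
The torch.autograd.Function wrapper removed from __init__.py is replaced by the torch.library registration above: a custom op with mutates_args=(), a fake kernel for tracing, and register_autograd with a backward that fails loudly. A minimal sketch of the same pattern on a toy op (requires a PyTorch recent enough to provide torch.library.custom_op; the demo::scale_rows op is invented for illustration and is not part of xformers):

import torch


@torch.library.custom_op("demo::scale_rows", mutates_args=())
def scale_rows(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Real kernel: allocates and returns a fresh tensor instead of mutating a
    # caller-provided buffer (mirroring mutates_args=() above).
    return x * factor


@torch.library.register_fake("demo::scale_rows")
def scale_rows_fake(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Fake/meta kernel: only describes output shape and dtype, for tracing/compile.
    return torch.empty_like(x)


def _scale_rows_backward(ctx, grad: torch.Tensor):
    # Same choice as the xformers change: fail loudly rather than risk
    # silently-wrong gradients.
    raise NotImplementedError("Backward is intentionally not implemented.")


scale_rows.register_autograd(_scale_rows_backward)

x = torch.randn(4, 8, requires_grad=True)
y = scale_rows(x, 2.0)   # forward works as usual
# y.sum().backward()     # would raise NotImplementedError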
