Hello there,
I'm comparing SDPA with the efficient backend against flex attention on a 300M model. Sequences are typically 70 to 1400 frames long and vary in length (so batches contain padding). According to my measurements, flex attention is only roughly 7% faster than SDPA, which is a bit ... disappointing :'(
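For reference, the SDPA baseline is set up roughly along these lines (a simplified sketch with dummy shapes and a dense boolean key-padding mask, not my exact benchmark code):

# Simplified sketch of the SDPA baseline (efficient backend) with a key-padding mask.
# Shapes and tensors below are dummies for illustration only.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

B, H, T, D = 4, 8, 1400, 64
q = torch.randn(B, H, T, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, T, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, T, D, device="cuda", dtype=torch.float16)
pads = torch.zeros(B, T, dtype=torch.bool, device="cuda")      # True = padded frame

attn_mask = ~pads[:, None, None, :]                            # [B, 1, 1, T], True = attend
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)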
Env info:
- 4090
- Torch 2.7
- CUDA 12.6
At a high level, this is what happens. In a Transformer.py file, I create a padding block mask function:
def create_padding_mask(pads):
    def padding(b, h, q_idx, kv_idx):
        return ~pads[b, kv_idx]
    return padding
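Here `pads` is a [B, T] boolean tensor that is True on padded frames; purely for illustration, it could be built from per-sequence lengths like this:

import torch

# Illustrative only: build `pads` from per-sequence lengths (True = padded frame).
lengths = torch.tensor([70, 512, 1400])                  # example frame counts per sequence
T_max = int(lengths.max())
pads = torch.arange(T_max)[None, :] >= lengths[:, None]  # [B, T_max] bool, True where padded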
Then, in the forward, I follow this approach:
# from torch.nn.attention.flex_attention import create_block_mask
Ba, Ti, Fe = x.shape
masks_fn = create_padding_mask(a_boolean_tensor)
padding_mask_fct = create_block_mask(masks_fn, B=Ba, H=None, Q_LEN=Ti, KV_LEN=Ti, _compile=True)
for mha in all_layers:
    mha(x, padding_mask_fct)
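As a sanity check, the resulting BlockMask can be inspected (sketch; I'm assuming the BlockMask.sparsity() helper here):

# Sanity check (sketch): how sparse is the block mask actually?
# Padding is the only source of sparsity here, so lightly padded batches
# leave flex attention little work to skip relative to dense SDPA.
print(padding_mask_fct.sparsity())   # percentage of blocks masked out entirely
print(padding_mask_fct)              # prints a per-batch block layout summary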
The MHA layers come from another Python file where, at the top, there is

flex_attention = torch.compile(flex_attention, dynamic=True)

and then, within the forward of the MHA class:
x = flex_attention(
    query=q.permute(0, 2, 1, 3),   # permute to the [B, H, seq, head_dim] layout flex_attention expects
    key=k.permute(0, 2, 1, 3),
    value=v.permute(0, 2, 1, 3),
    block_mask=padding_mask_fct,
)
The permutations are needed because of prior transformations applied to the q, k, v tensors.
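For completeness, here is the flex path in isolation as a minimal, self-contained sketch (dummy shapes and tensors, not the real model code):

# Minimal end-to-end version of the flex path described above (dummy shapes).
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

flex_attention = torch.compile(flex_attention, dynamic=True)

B, H, T, D = 4, 8, 1400, 64
q = torch.randn(B, T, H, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, T, H, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, T, H, D, device="cuda", dtype=torch.float16)
pads = torch.zeros(B, T, dtype=torch.bool, device="cuda")     # True = padded frame

def padding(b, h, q_idx, kv_idx):
    return ~pads[b, kv_idx]

block_mask = create_block_mask(padding, B=B, H=None, Q_LEN=T, KV_LEN=T, device="cuda", _compile=True)

out = flex_attention(
    query=q.permute(0, 2, 1, 3),   # [B, T, H, D] -> [B, H, T, D]
    key=k.permute(0, 2, 1, 3),
    value=v.permute(0, 2, 1, 3),
    block_mask=block_mask,
)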
Does my implementation seem to utilise flex attention properly?