_compile=True blows up GPU memory when seq length / block_mask change, but is best when both are fixed (FlexAttention) #159

Description

@wileewang

Hi, I’m using FlexAttention roughly like this:

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Build the block mask for the current sequence length; _compile=False is what I currently have to use (see below).
block_mask = create_block_mask(my_custom_mask_mod, B=B, H=None, Q_LEN=S, KV_LEN=S, _compile=False)

# Compile flex_attention itself with static shapes and autotuning.
flex_attention_op = torch.compile(
    flex_attention,
    dynamic=False,
    mode="max-autotune",
)

out = flex_attention_op(query=q, key=k, value=v, block_mask=block_mask)

Setup:

  • q comes from a few fixed sequence-length buckets (different batches can have different lengths). Within one batch the sequence length is the same, and the batch may or may not share a single block_mask (a rough sketch of this setup follows below).
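
For concreteness, the bucketed case where a batch can share one mask looks roughly like the sketch below (the mask_mod body, the bucket sizes, and the attend helper are illustrative placeholders, not my real code):

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Placeholder mask_mod for illustration; the real one is a custom pattern.
def my_custom_mask_mod(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx  # e.g. causal

SEQ_BUCKETS = [1024, 2048, 4096]  # example bucket sizes

# One block mask per fixed sequence-length bucket, built once up front
# (B=None / H=None so the mask broadcasts over batch and heads).
block_masks = {
    S: create_block_mask(my_custom_mask_mod, B=None, H=None, Q_LEN=S, KV_LEN=S, _compile=False)
    for S in SEQ_BUCKETS
}

flex_attention_op = torch.compile(flex_attention, dynamic=False, mode="max-autotune")

def attend(q, k, v):
    # Pick the cached mask for this batch's fixed sequence length.
    S = q.shape[-2]
    return flex_attention_op(query=q, key=k, value=v, block_mask=block_masks[S])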

Observations (a minimal memory-measurement sketch follows this list):

  1. Different sequence lengths across steps

    • _compile=False: memory stays low and stable, and the run fits on my GPU.
    • _compile=True: memory jumps a lot and I hit OOM.
  2. Same length, but different block masks across batches

    • _compile=False still uses less memory than _compile=True.
  3. Same length and same block mask for all steps

    • _compile=True gives the smallest memory footprint here.
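
To compare peak memory between the cases, something along these lines is enough (a minimal sketch; run_step stands in for one forward call under a given configuration):

import torch

def measure_peak_gib(run_step, steps=10):
    # Reset the CUDA peak-memory counter, run a few steps,
    # and read back the high-water mark in GiB.
    torch.cuda.reset_peak_memory_stats()
    for _ in range(steps):
        run_step()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**30

# e.g. measure_peak_gib(lambda: attend(q, k, v)) once per case and per _compile setting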

Currently I must use _compile=False to survive the first two cases. Any tips for reducing memory further, or for making _compile=True workable without the memory growth?

Thanks!
