@Sohailm25

Proposed changes

  • Add minimal Metal paged attention support with a Python fallback and keep SDPA as the default. This enables downstream consumers (e.g., mlx-lm) to opt into paged attention without changing behavior for
    existing users.
  • Align SDPA bindings with current mlx main (add is_training / output_logsumexp, restore VJP / is_equivalent symbols) to fix build/link drift.
  • Tests cover paged KV allocation/copy, fallback parity, and SDPA fast-path (python/tests/test_paged_kv.py, python/tests/test_fast.py).
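
For reviewers unfamiliar with the idea, here is a rough, framework-free sketch of what "paged KV allocation" and "fallback parity" mean. It is not the code in this PR and does not use MLX: `PagedKVCache`, `attention`, and `paged_attention` are hypothetical names invented for illustration, and vectors are plain Python lists. Keys and values are stored in fixed-size blocks tracked by a per-sequence block table, and the fallback simply gathers those blocks back into order and runs ordinary softmax attention over them.

```python
import math


class PagedKVCache:
    """Toy paged KV store: fixed-size blocks plus a per-sequence block table.

    Hypothetical illustration only; not the data structure used in this PR.
    """

    def __init__(self, num_blocks, block_size):
        self.block_size = block_size
        self.free = list(range(num_blocks))  # indices of unused blocks
        self.k_blocks = [[] for _ in range(num_blocks)]
        self.v_blocks = [[] for _ in range(num_blocks)]
        self.tables = {}  # seq_id -> ordered list of block indices

    def append(self, seq_id, k, v):
        table = self.tables.setdefault(seq_id, [])
        # Allocate a block when the sequence has none or its last block is full.
        if not table or len(self.k_blocks[table[-1]]) == self.block_size:
            if not self.free:
                raise MemoryError("no free KV blocks")
            table.append(self.free.pop())
        self.k_blocks[table[-1]].append(k)
        self.v_blocks[table[-1]].append(v)

    def gather(self, seq_id):
        # Walk the block table in order to rebuild the logical K/V sequence.
        keys, values = [], []
        for b in self.tables.get(seq_id, []):
            keys.extend(self.k_blocks[b])
            values.extend(self.v_blocks[b])
        return keys, values

    def release(self, seq_id):
        # Return all of a sequence's blocks to the free list.
        for b in self.tables.pop(seq_id, []):
            self.k_blocks[b] = []
            self.v_blocks[b] = []
            self.free.append(b)


def attention(q, keys, values):
    """Single-query scaled dot-product attention over contiguous K/V."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[i] for w, v in zip(weights, values)) / total
            for i in range(dim)]


def paged_attention(cache, seq_id, q):
    """Fallback path: gather the pages, then reuse the contiguous kernel."""
    keys, values = cache.gather(seq_id)
    return attention(q, keys, values)
```

Parity in this sketch means `paged_attention(cache, seq_id, q)` returns the same result as calling `attention` on the same keys and values held contiguously. A dedicated kernel would read blocks through the table directly instead of gathering them first.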

Checklist

  • I have read the CONTRIBUTING (https://github.com/ml-explore/mlx/blob/main/CONTRIBUTING.md) document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@Sohailm25
Author

This adds a future pathway for the continuous batching change in mlx-lm.
