Conversation


@yiakwy-xpu-ml-framework-team yiakwy-xpu-ml-framework-team commented Nov 13, 2025

Description

This follows the latest paper: where previous 1-bit schemes packed each weight into 2 bits, we show that sub-1-bit quantization (0.7 ~ 0.8 bits per weight on average) can also produce very good results, and only 1 bit is needed to pack each weight.
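As a rough illustration of the packing claim (a NumPy sketch with hypothetical helper names, not the PR's Metal kernel), sign weights can be stored at 1 bit each; the sub-1-bit average in the paper comes from codebook/LUT compression on top of this:

```python
import numpy as np

def pack_binary_weights(w_sign):
    """Pack {-1, +1} weights into bits: 8 weights per byte (1 bit each)."""
    bits = (w_sign > 0).astype(np.uint8)   # map {-1, +1} -> {0, 1}
    return np.packbits(bits)

def unpack_binary_weights(packed, n):
    """Recover the first n sign weights from the packed bytes."""
    bits = np.unpackbits(packed)[:n]
    return (bits.astype(np.int8) * 2 - 1)  # map {0, 1} -> {-1, +1}

w = np.array([1, -1, -1, 1, 1, 1, -1, 1], dtype=np.int8)
packed = pack_binary_weights(w)
assert packed.nbytes == 1                  # 8 weights fit in one byte
assert np.array_equal(unpack_binary_weights(packed, len(w)), w)
```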

Config:

MLX: 0.29.3
PyTorch: 2.9
System:

  • M4 Pro (20 Metal 3 GPU cores + 16 ANE cores),
  • M3 Ultra (80 Metal 3 GPU cores + 32 ANE cores),

Test

Simple Case
[Screenshot 2025-11-14, 8:42 PM]

We found that Apple's fp16 AccT has precision problems, so we fall back to the fp32 datatype for accFrag.

Complex Case
  • Llama 3 classic layout M x K x N = 4096 x 16384 x 4096
    • baseline :
      • M4 Pro :
        [Screenshot 2025-11-15, 6:47 PM]
    • metal kernel vectorized loads
      • M4 Pro :
        [Screenshot 2025-11-16, 2:20 PM]
      • M3 Ultra (thanks to the Apple team for their support):
        [Image: mlx_test_img_2]
TO DO LIST

[ ] Double buffer
[ ] WASP
[ ] Prefetch
[ ] Add a simple tuner
[x] stream-k
[x] split-k with atomic_fetch_add
[ ] SIMD_SUM_ACC, SIMD_ADD_MULTIPLY_ACC (to improve the flops/bytes ratio)
[x] Faster unpack
[ ] Faster SIMD unpack
[ ] Memory interleave preprocessing
[x] Vector load (64-bit)
[WIP] Generalized N-bank, M-bit swizzle strategy (guess: 32 banks, 4 bytes per bank; 64-bit vectorization size on the Apple platform). E.g. with BLOCK_SIZE_K = 128 x 16-bit data and a 4 x 8 thread-group configuration, 32x128 / (8x64) = 8 rows; with BLOCK_SIZE_M at least 16, this gives 2-way conflicts.
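The swizzle item above can be illustrated with a small model of the assumed 32-bank, 4-bytes-per-bank shared memory (a hypothetical XOR-swizzle sketch in Python, not the kernel's actual mapping):

```python
# Toy model: 32 banks, one 4-byte word per bank per cycle.
BANKS = 32

def bank_of(addr_words):
    """Bank index of a word-aligned shared-memory address."""
    return addr_words % BANKS

def swizzled_col(row, col, cols_per_row=32):
    """XOR the column with the row index so that a column of a tile
    maps to distinct banks instead of piling onto one."""
    return col ^ (row % cols_per_row)

# Without swizzle: column 0 of every row of a 32-wide tile hits bank 0
# -> a 32-way conflict when 32 threads read down a column.
naive_banks = {bank_of(r * 32 + 0) for r in range(32)}
assert naive_banks == {0}

# With the XOR swizzle, column 0 spreads across all 32 banks.
swizzled_banks = {bank_of(r * 32 + swizzled_col(r, 0)) for r in range(32)}
assert len(swizzled_banks) == BANKS
```

The same accounting explains the 2-way conflict noted above: wider (64-bit) vector accesses occupy two banks per thread, halving the number of distinct banks available per access.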

Apple Metal Performance Ablation Study

Pending

@yiakwy-xpu-ml-framework-team yiakwy-xpu-ml-framework-team marked this pull request as draft November 13, 2025 11:17
@yiakwy-xpu-ml-framework-team
Author

@awni Could you have a look at it ? Thanks !

@awni
Member

awni commented Nov 13, 2025

I'm a bit confused about this. Is there a specific model it should be used with? Or what is the intended usage?

@yiakwy-xpu-ml-framework-team
Author

yiakwy-xpu-ml-framework-team commented Nov 14, 2025

> I'm a bit confused about this. Is there a specific model it should be used with? Or what is the intended usage?

Yes, there is a paper on this topic, and a codebook LUT kernel will be added later for the Metal platform.

https://openreview.net/pdf?id=yBDBCpEzsO


Just realized that Apple's notion of a grid is different from a CUDA grid.

Ref :

[1] https://www.shashankshekhar.com/blog/apple-metal-vs-nvidia-cuda
[2] https://ml-explore.github.io/mlx/build/html/dev/custom_metal_kernels.html#custom-metal-kernels

@yiakwy-xpu-ml-framework-team yiakwy-xpu-ml-framework-team marked this pull request as ready for review November 15, 2025 09:13
@yiakwy-xpu-ml-framework-team
Author

Hi @awni, it is ready to be merged.

Optimization will continue in parallel, and I would be very happy to get any input from you. Thank you!

@awni
Member

awni commented Nov 15, 2025

Hi @yiakwy-xpu-ml-framework-team appreciate the contribution, but I don't think it makes sense to merge this. We'd need some evidence that this is useful: is there a corresponding model? Is it fast? Does it work? etc.

@yiakwy-xpu-ml-framework-team
Author

yiakwy-xpu-ml-framework-team commented Nov 16, 2025

@awni

> is there a corresponding model? Does it work?

I think a model should be provided; let me check.

> Is it fast?

This is a follow-up to the 1.58-bit model in #219, where end-to-end performance is evaluated.

As for the kernel part, I believe the file is self-explanatory; it is an extension to https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/bitlinear_layers.py by @Blaizzy.

Benefits of sub-1-bit packing:

  1. BitNet (1.58-bit) actually uses 2 bits per weight, while sub-1-bit models use less than 1 bit per weight on average, hence less memory I/O from global memory to on-chip SRAM.

  2. We have extensively benchmarked on various platforms how 1-bit matmul should be accelerated; low-bit kernels, including 4-bit ones (Marlin, BitNet), are very hard-pressed to beat SOTA fp16 matmul due to the extensive shift operations.

Here we do better by moving the m loop to the innermost position, so the values obtained from shifting are cached and reused.
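That loop-order idea can be sketched in NumPy (a toy model with a hypothetical helper name, not the Metal kernel): each packed weight is unpacked (shifted) once, and the m loop sits innermost so the unpacked row is reused across all rows of A.

```python
import numpy as np

def matmul_bitpacked(a, packed_b, K, N):
    """Toy GEMM against a 1-bit-packed B: unpack each weight once, then
    keep m as the inner loop so the unpacked row is reused M times."""
    M = a.shape[0]
    out = np.zeros((M, N), dtype=np.float32)
    bits = np.unpackbits(packed_b)[:K * N].reshape(K, N)
    b = bits.astype(np.float32) * 2 - 1    # unpack/shift done once per weight
    for k in range(K):
        bk = b[k]                          # cached unpacked row of B
        for m in range(M):                 # m is the inner loop: bk reused M times
            out[m] += a[m, k] * bk
    return out
```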

More techniques will be added soon:

  1. Stream-k added: various older implementations are built on split-k variants, especially on the Metal platform. Perhaps this is the first stream-k low-bit GEMM (or even the first stream-k GEMM) on Metal? With stream-k, workloads are distributed more evenly across the 80 GPU cores of the M3 Ultra. This is a practical contribution.

  2. Reducing model size can be orthogonal to accelerating inference. For example, MoE models are large but activate only a few of their parameters; with model size reduction, models like gpt-oss-120b-mxfp4 become affordable to run on a single H100 GPU (of an 8x H100 80 GB DGX).
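The stream-k scheduling in point 1 can be sketched as an even split of the total K-iteration work across GPU cores (a toy Python model with hypothetical names, not the Metal implementation): a core's range may cross tile boundaries, with partial tiles combined afterwards (e.g. via atomic_fetch_add).

```python
def streamk_partition(num_tiles, iters_per_tile, num_cores):
    """Split the total K-loop work of all output tiles evenly over cores.
    Returns one (start, end) global-iteration range per active core."""
    total = num_tiles * iters_per_tile
    per_core = (total + num_cores - 1) // num_cores  # ceil division
    ranges = []
    for c in range(num_cores):
        start = c * per_core
        end = min(start + per_core, total)
        if start < end:
            ranges.append((start, end))
    return ranges

# 7 tiles x 10 K-iters over 4 cores: each core gets ~18 iterations, crossing
# tile boundaries, instead of the 2/2/2/1-tile imbalance of tile-per-core
# scheduling.
parts = streamk_partition(7, 10, 4)
assert sum(e - s for s, e in parts) == 70
assert max(e - s for s, e in parts) - min(e - s for s, e in parts) <= 2
```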

One drawback of BitNet is that it is based on a dense model; previous work (gpt-oss mxfp4) has shown that large MoE models with MoE FFN layers are resilient to quantization error.

When this change is adapted to low-bit MoE, it will be very useful for model size reduction (for large MoE models).
