saw-int4

saw-int4 is the official implementation of
"SAW-INT4: System-Aware 4-Bit KV-Cache Quantization for Real-World LLM Serving".

This repository implements Block Diagonal Rotation (BDR) for KV-cache quantization, along with system-level optimizations that seamlessly integrate into SGLang. The resulting system achieves near-BF16 accuracy while preserving the end-to-end performance benefits of INT4.

Introduction

This work studies 4-bit KV-cache quantization under real serving constraints such as paged memory layouts, regular memory access, and fused attention execution. Our primary method, BDR (block-diagonal rotation), applies a block-diagonal Hadamard rotation to the KV cache before token-wise INT4 KV-cache quantization, implemented directly inside a fork of SGLang.
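To make the mechanism concrete, here is a minimal NumPy sketch of the idea, not the repo's fused kernel: rotate each contiguous block of channels with a normalized Hadamard matrix, then quantize each token symmetrically to INT4. The function names, block size default, and scale rule here are illustrative.

```python
import numpy as np

def hadamard(order: int) -> np.ndarray:
    """Normalized Sylvester Hadamard matrix (order must be a power of two).
    It is symmetric and orthogonal, so the transform is its own inverse."""
    H = np.array([[1.0]])
    while H.shape[0] < order:
        H = np.block([[H, H], [H, -H]])
    return H / np.sqrt(order)

def bdr_rotate(x: np.ndarray, block: int = 128) -> np.ndarray:
    """Apply the same Hadamard rotation to each `block`-sized channel group."""
    H = hadamard(block)
    d = x.shape[-1]
    return (x.reshape(-1, d // block, block) @ H).reshape(x.shape)

def int4_quant_dequant(x: np.ndarray) -> np.ndarray:
    """Symmetric per-token INT4 round trip (scale = per-token absmax / 7)."""
    scale = np.abs(x).max(axis=-1, keepdims=True) / 7.0
    return np.clip(np.round(x / scale), -8, 7) * scale
```

Because the rotation is orthogonal it preserves token norms; its benefit is spreading channel outliers across the block, which shrinks the per-token absmax scale and uses the INT4 grid more evenly.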

We ship two submodule branches on the same fork remote:

  • third_party/sglang-fast-rotation: our proposed BDR implementation, with a fused block-diagonal rotation + INT4 KV-cache write. Use this fork for both accuracy and throughput in the BF16, INT4, and BDR configurations (the main paper numbers).
  • third_party/sglang-kmeans: the ablation study, with KV dumping, k-means centroid fitting, and k-means + rotation variants. Not required to reproduce the core BDR vs. BF16 vs. INT4 story.

Pinned commits: SUBMODULE_VERSIONS.md.

How to run BDR

This section covers everything needed to run BDR on third_party/sglang-fast-rotation: get the code, install, and launch a server.

Get the code

git clone --recurse-submodules https://github.com/togethercomputer/saw-int4.git
cd saw-int4

If you cloned without submodules: git submodule update --init third_party/sglang-fast-rotation.

Server requirements

The BDR implementation is built on top of the SGLang codebase and currently assumes the following setup:

  • MHA models only: MLA and other non-MHA layouts are not supported with these KV / BDR settings.
  • Prefill backend: fa3.
  • Decode backend: triton.

Install BDR

cd third_party/sglang-fast-rotation/python
pip install -e ".[all]"
pip install --no-build-isolation "git+https://github.com/Dao-AILab/fast-hadamard-transform.git"

Run BDR

BF16 KV (baseline)

python -m sglang.launch_server \
  --prefill-attention-backend fa3 \
  --decode-attention-backend triton \
  --model-path "Qwen/Qwen3-4B-Thinking-2507" \
  --port 30000 \
  --kv-cache-dtype auto

Original INT4 KV

python -m sglang.launch_server \
  --prefill-attention-backend fa3 \
  --decode-attention-backend triton \
  --model-path "Qwen/Qwen3-4B-Thinking-2507" \
  --port 30000 \
  --kv-cache-dtype int4

BDR (block-diagonal rotation on K)

HADAMARD=1 HADAMARD_ORDER=128 python -m sglang.launch_server \
  --prefill-attention-backend fa3 \
  --decode-attention-backend triton \
  --model-path "Qwen/Qwen3-4B-Thinking-2507" \
  --port 30000 \
  --kv-cache-dtype int4

For the full environment-variable reference and the complete mode matrix, see docs/bdr_env_vars.md.

Quick demo (verify your install)

With the server running in any of the three modes above, run the smoke-test script from the repository root:

pip install openai   # if not already installed
python scripts/bdr_smoke_test.py --port 30000 --model Qwen/Qwen3-4B-Thinking-2507

The script sends a GPQA sample question to the server and streams the response.

Server : http://0.0.0.0:30000/v1
Model  : Qwen/Qwen3-4B-Thinking-2507

--- Prompt (GPQA sample) ---
Answer the following multiple choice question.....
...

--- Response ---
<model reasoning and answer streamed here>
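For reference, the core of such a smoke test is a single request to the server's OpenAI-compatible endpoint. The sketch below uses only the standard library; the real script lives at scripts/bdr_smoke_test.py and may differ, and the question text and helper names here are illustrative.

```python
import json
import urllib.request

def build_payload(model: str, question: str) -> dict:
    # Standard OpenAI chat-completions payload shape.
    return {
        "model": model,
        "messages": [{"role": "user", "content": question}],
        "max_tokens": 512,
    }

def ask(base_url: str, model: str, question: str) -> str:
    # POST to the OpenAI-compatible endpoint, e.g.
    # base_url = "http://0.0.0.0:30000/v1".
    req = urllib.request.Request(
        f"{base_url}/chat/completions",
        data=json.dumps(build_payload(model, question)).encode("utf-8"),
        headers={"Content-Type": "application/json",
                 "Authorization": "Bearer dummy"},
    )
    with urllib.request.urlopen(req) as resp:
        body = json.load(resp)
    return body["choices"][0]["message"]["content"]
```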

Primary accuracy and throughput

Accuracy (simple-evals / GPQA) and throughput (genai-bench) both use third_party/sglang-fast-rotation; server setup is covered in How to run BDR. Accuracy model: Qwen/Qwen3-4B-Thinking-2507. Throughput model: Qwen/Qwen3-8B (override MODEL_PATH in the scripts if you want to align checkpoints).

Accuracy (primary)

Prepare

Prerequisite (GPQA client): openai/simple-evals is included as a submodule at third_party/simple-evals.

git submodule update --init --checkout third_party/simple-evals
cd third_party/simple-evals
mkdir -p simple_evals
touch simple_evals/__init__.py
pip install openai pandas requests jinja2 tqdm numpy

Add a local model alias once in third_party/simple-evals/simple_evals.py, inside the models = { ... } dictionary, so simple-evals can address the local server; set max_tokens=32768:

"qwen3_4b": ChatCompletionSampler(
    model="Qwen/Qwen3-4B-Thinking-2507",
    system_message=OPENAI_SYSTEM_MESSAGE_API,
    max_tokens=32768,
),

Run GPQA

With simple-evals installed and the SGLang server already up (start it in the desired mode from Run BDR, using Qwen/Qwen3-4B-Thinking-2507 as the model), point the client at http://127.0.0.1:<port>/v1 and run GPQA:

cd third_party/simple-evals
export OPENAI_BASE_URL="http://127.0.0.1:30000/v1" 
export OPENAI_API_KEY="dummy"
python -m simple-evals.simple_evals --model qwen3_4b --eval gpqa --n-repeats 3

Accuracy results (primary; temperature=0.6, max sequence length 32k, top_p=0.95)

| Model | Method | Benchmark | Score |
|---|---|---|---|
| Qwen/Qwen3-4B-Thinking-2507 | BF16 | GPQA | 66.6667 |
| Qwen/Qwen3-4B-Thinking-2507 | INT4 | GPQA | 0 |
| Qwen/Qwen3-4B-Thinking-2507 | BDR (K-only) | GPQA | 65.8249 |

Throughput and latency (primary)

Speed results use sglang-fast-rotation (fused INT4 + BDR kernels) with Qwen/Qwen3-8B, driven by genai-bench against the server’s OpenAI-compatible HTTP API. Helper: scripts/run_genai_bench_example.sh (default MODEL_PATH). Full CLI, traffic scenarios, Excel/plots: GenAI Bench docs and Run benchmark.

Prepare (genai-bench)

Prerequisite (throughput client): install genai-bench (separate from the SGLang venv if you prefer):

pip install genai-bench

Optional (quieter HF logs during tokenizer load): export TRANSFORMERS_VERBOSITY=error. For Docker / dev installs, see the upstream installation guide.

Terminal 1 — server (example BF16 KV):

cd third_party/sglang-fast-rotation/python
python -m sglang.launch_server \
  --prefill-attention-backend fa3 \
  --decode-attention-backend triton \
  --model-path "Qwen/Qwen3-8B" \
  --port 30000 \
  --kv-cache-dtype auto

Terminal 2 — client (after pip install genai-bench; matches ~256 input / 32 output tokens and concurrency 16 — see traffic scenarios):

genai-bench benchmark --api-backend sglang \
  --api-base "http://127.0.0.1:30000" \
  --api-key "dummy" \
  --api-model-name "Qwen/Qwen3-8B" \
  --model-tokenizer "Qwen/Qwen3-8B" \
  --task text-to-text \
  --traffic-scenario "D(256,32)" \
  --num-concurrency 16 \
  --max-time-per-run 5 \
  --max-requests-per-run 200 \
  --server-engine "SGLang" \
  --server-gpu-type "local" \
  --server-version "custom" \
  --server-gpu-count 1

Tune --max-time-per-run, --max-requests-per-run, --num-concurrency, and --traffic-scenario using genai-bench benchmark --help and the docs above. Label runs with accurate --server-gpu-type / --server-version when publishing numbers.

Sweep BF16 vs INT4 vs BDR: restart the server with the right env and --kv-cache-dtype, then rerun genai-bench with identical client flags.

| Config | Env | --kv-cache-dtype |
|---|---|---|
| BF16 KV | HADAMARD=0 | auto |
| INT4 KV | HADAMARD=0 | int4 |
| BDR + INT4 | HADAMARD=1 ROTATE_V=0 HADAMARD_ORDER=128 | int4 |
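The sweep is easy to script. The helper below is a hypothetical sketch (not part of the repo) that turns each sweep point into an environment dict plus a launch command; the config keys are illustrative.

```python
# Hypothetical sweep table mirroring the env / --kv-cache-dtype matrix above.
SWEEP = {
    "bf16": ({"HADAMARD": "0"}, "auto"),
    "int4": ({"HADAMARD": "0"}, "int4"),
    "bdr":  ({"HADAMARD": "1", "ROTATE_V": "0", "HADAMARD_ORDER": "128"}, "int4"),
}

def server_invocation(config: str, model: str = "Qwen/Qwen3-8B",
                      port: int = 30000):
    """Return (extra_env, argv) for one sweep point."""
    extra_env, kv_dtype = SWEEP[config]
    argv = [
        "python", "-m", "sglang.launch_server",
        "--prefill-attention-backend", "fa3",
        "--decode-attention-backend", "triton",
        "--model-path", model,
        "--port", str(port),
        "--kv-cache-dtype", kv_dtype,
    ]
    return extra_env, argv
```

Launch each point with subprocess.Popen(argv, env={**os.environ, **extra_env}), wait until the port accepts connections, run the identical genai-bench client command, then terminate the server before the next point.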

SGLang’s built-in bench_serving is optional; this repo standardizes on genai-bench for comparable sweeps and reporting.

Hub: eval_speed/
Helper: scripts/run_genai_bench_example.sh

Speed results (primary)

Hardware: 1× H100 80 GB, TP=1. Model: Qwen/Qwen3-8B.
Client: genai-bench. Metric definitions: eval_speed/metrics.md.

Short context — D(256, 1024) (256 input / 1024 output tokens)
Cap: 5 min or 256 requests. Results: eval_speed/results/20260416_203040/

| KV config | Conc | output_tps (job) | mean_input_tps (req) | mean_output_tps (req) | mean_ttft (req, ms) | E2E mean (req, s) | E2E p75 (req, s) | E2E p90 (req, s) | Total requests | Wall (s) |
|---|---|---|---|---|---|---|---|---|---|---|
| BF16 | 32 | 3,795 | 1,573 | 122.1 | 196 | 8.57 | 8.60 | 8.62 | 256 | 69 |
| INT4 | 32 | 3,687 | 1,380 | 120.9 | 225 | 8.69 | 8.71 | 8.75 | 256 | 71 |
| INT4 + BDR (K-only, ord=128) | 32 | 3,689 | 1,379 | 120.2 | 226 | 8.74 | 8.74 | 8.76 | 256 | 71 |
| BF16 | 64 | 5,950 | 796 | 98.7 | 369 | 10.74 | 10.78 | 10.82 | 256 | 44 |
| INT4 | 64 | 6,371 | 774 | 105.0 | 370 | 10.11 | 10.16 | 10.20 | 256 | 41 |
| INT4 + BDR (K-only, ord=128) | 64 | 6,235 | 755 | 104.3 | 377 | 10.19 | 10.24 | 10.26 | 256 | 42 |
| BF16 | 128 | 8,410 | 455 | 71.8 | 657 | 14.92 | 15.00 | 15.11 | 256 | 31 |
| INT4 | 128 | 9,544 | 437 | 81.0 | 665 | 13.30 | 13.38 | 13.45 | 256 | 28 |
| INT4 + BDR (K-only, ord=128) | 128 | 9,350 | 458 | 80.1 | 655 | 13.43 | 13.51 | 13.60 | 256 | 28 |
| BF16 | 256 | 11,195 | 242 | 49.3 | 1,224 | 22.00 | 22.15 | 22.24 | 256 | 23 |
| INT4 | 256 | 11,624 | 225 | 51.1 | 1,237 | 21.25 | 21.50 | 21.57 | 256 | 23 |
| INT4 + BDR (K-only, ord=128) | 256 | 11,732 | 266 | 51.6 | 1,148 | 20.99 | 21.12 | 21.19 | 256 | 22 |

Long context — D(16384, 1024) (16,384 input / 1024 output tokens)
Cap: 20 min or 64–256 requests (varies by concurrency). Results: eval_speed/results/20260416_214449/ (conc 8–64), eval_speed/results/20260416_233035/ (conc 128)

| KV config | Conc | output_tps (job) | mean_input_tps (req) | mean_output_tps (req) | mean_ttft (req, ms) | E2E mean (req, s) | E2E p75 (req, s) | E2E p90 (req, s) | Total requests | Wall (s) |
|---|---|---|---|---|---|---|---|---|---|---|
| BF16 | 8 | 414 | 8,311 | 61.4 | 2,636 | 19.37 | 19.53 | 19.65 | 64 | 158 |
| INT4 | 8 | 458 | 8,391 | 69.2 | 2,631 | 17.50 | 17.67 | 17.77 | 64 | 143 |
| INT4 + BDR (K-only, ord=128) | 8 | 457 | 8,784 | 68.7 | 2,523 | 17.50 | 17.69 | 17.78 | 64 | 143 |
| BF16 | 16 | 481 | 4,413 | 36.7 | 5,104 | 33.14 | 33.48 | 33.65 | 64 | 136 |
| INT4 | 16 | 571 | 4,672 | 45.4 | 4,956 | 27.74 | 28.04 | 28.28 | 64 | 115 |
| INT4 + BDR (K-only, ord=128) | 16 | 568 | 4,083 | 44.8 | 4,875 | 27.94 | 28.30 | 28.54 | 64 | 116 |
| BF16 | 32 | 570 | 1,741 | 32.9 | 18,047 | 49.58 | 73.20 | 73.64 | 64 | 115 |
| INT4 | 32 | 618 | 2,147 | 25.4 | 9,568 | 50.45 | 51.11 | 51.49 | 64 | 106 |
| INT4 + BDR (K-only, ord=128) | 32 | 616 | 2,215 | 25.1 | 9,350 | 50.57 | 51.23 | 51.62 | 64 | 107 |
| BF16 | 64 | 471 | 806 | 32.7 | 44,798 | 76.91 | 112.33 | 113.22 | 64 | 139 |
| INT4 | 64 | 666 | 1,114 | 14.7 | 19,398 | 90.46 | 91.70 | 92.51 | 64 | 98 |
| INT4 + BDR (K-only, ord=128) | 64 | 663 | 1,150 | 14.4 | 18,371 | 90.78 | 92.06 | 92.83 | 64 | 99 |
| BF16 | 128 | 559 | 310 | 32.9 | 113,583 | 145.96 | 220.85 | 221.91 | 148 | 271 |
| INT4 | 128 | 701 | 527 | 12.3 | 57,654 | 142.19 | 208.11 | 210.82 | 153 | 224 |
| INT4 + BDR (K-only, ord=128) | 128 | 701 | 535 | 12.3 | 57,054 | 142.09 | 208.05 | 210.73 | 153 | 224 |

Ablation study (k-means, k-means + rotation)

Use third_party/sglang-kmeans. The workflow: dump KV activations for calibration, fit centroids with tools/fit_kv_centroids.py, then point SGLANG_KV_CENTROIDS_PATH at them for k-means + INT4 and k-means + BDR (optional HADAMARD / ROTATE_V). Accuracy evaluation still uses simple-evals from third_party/simple-evals (see Prepare above; run GPQA per the upstream docs).

Install sglang-kmeans

This step is not needed for the primary BF16 / INT4 / BDR runs (see How to run BDR). Initialize the submodule (not initialized by default), then install:

git submodule update --init third_party/sglang-kmeans
cd third_party/sglang-kmeans/python
pip install -e ".[all]"
pip install "flash-kmeans @ git+https://github.com/jindajia/flash-kmeans.git"

KV calibration (ablation only)

Primary BF16 / INT4 / BDR does not need this step.

1. Dump KV activations — run from sglang-kmeans with a BF16 KV cache (auto) so dumps are in calibration space:

cd third_party/sglang-kmeans/python

export DUMP_KVCACHE=true
export DUMP_KVCACHE_TOKENS=512
export DUMP_KVCACHE_DIR=/path/to/kv_dumps

python -m sglang.launch_server \
  --prefill-attention-backend fa3 \
  --decode-attention-backend triton \
  --model-path "Qwen/Qwen3-8B" \
  --port 30000 \
  --kv-cache-dtype auto

Drive enough traffic so each layer hits the threshold at least once. Files appear as kv_calibration_layer_<layer_id>.pt (dict with k, v, indices on CPU; see triton_backend.py in the submodule for selection logic).
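A quick sanity check on the dumps can look like the snippet below; it assumes only the dict layout described above (k, v, indices tensors on CPU) and is not a repo utility.

```python
import glob
import os
import torch

def summarize_kv_dumps(dump_dir: str) -> dict:
    """Map each kv_calibration_layer_<id>.pt file to the shapes of its tensors."""
    summary = {}
    pattern = os.path.join(dump_dir, "kv_calibration_layer_*.pt")
    for path in sorted(glob.glob(pattern)):
        d = torch.load(path, map_location="cpu")
        summary[os.path.basename(path)] = {k: tuple(t.shape) for k, t in d.items()}
    return summary
```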

2. Fit centroids offline — from the repository root:

python tools/fit_kv_centroids.py \
  --dump-dir /path/to/kv_dumps \
  --out-dir /path/to/centroids_out \
  --n-clusters 16 \
  --seed 0

This writes k_layer_L_clusters_<N>_centers.pt and v_layer_L_clusters_<N>_centers.pt per global layer L, shaped (N, num_kv_heads_global * head_dim), for loading in the submodule.
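Conceptually, the fitting step is plain k-means over the dumped vectors. The minimal NumPy version below (Lloyd's algorithm; illustrative names, not the actual tools/fit_kv_centroids.py) shows the shape contract: rows in, (N, num_kv_heads_global * head_dim) centers out.

```python
import numpy as np

def fit_centroids(x: np.ndarray, n_clusters: int = 16, iters: int = 25,
                  seed: int = 0) -> np.ndarray:
    """Lloyd's k-means. x: (num_tokens, num_kv_heads * head_dim).
    Returns centers of shape (n_clusters, num_kv_heads * head_dim)."""
    rng = np.random.default_rng(seed)
    centers = x[rng.choice(len(x), n_clusters, replace=False)].copy()
    for _ in range(iters):
        # Assign each vector to its nearest center, then recompute means.
        dists = ((x[:, None, :] - centers[None, :, :]) ** 2).sum(-1)
        assign = dists.argmin(1)
        for c in range(n_clusters):
            members = x[assign == c]
            if len(members):
                centers[c] = members.mean(0)
    return centers
```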

3. Run INT4 + k-means inference

export N_CLUSTERS=16
export SGLANG_KV_CENTROIDS_PATH=/path/to/centroids_out

python -m sglang.launch_server \
  --prefill-attention-backend fa3 \
  --decode-attention-backend triton \
  --model-path "Qwen/Qwen3-8B" \
  --port 30000 \
  --kv-cache-dtype int4

K-means + BDR: keep SGLANG_KV_CENTROIDS_PATH, set HADAMARD=1, optional ROTATE_V, and HADAMARD_ORDER consistent with head dimension (same as primary BDR).

Ablation method matrix

| Method | HADAMARD | ROTATE_V | HADAMARD_ORDER | --kv-cache-dtype | SGLANG_KV_CENTROIDS_PATH | N_CLUSTERS |
|---|---|---|---|---|---|---|
| K-means + INT4 | 0 | 0 | n/a | int4 | required | match files |
| K-means + BDR | 1 | 0 or 1 | set | int4 | required | match files |

K-means + INT4 example:

cd third_party/sglang-kmeans/python
export OPENAI_API_KEY=dummy
export N_CLUSTERS=16
export SGLANG_KV_CENTROIDS_PATH=/path/to/centroids_out
export HADAMARD=0
export ROTATE_V=0
python -m sglang.launch_server \
  --prefill-attention-backend fa3 \
  --decode-attention-backend triton \
  --model-path "Qwen/Qwen3-8B" --port 30000 --kv-cache-dtype int4

K-means + BDR example:

export HADAMARD=1
export ROTATE_V=0
export HADAMARD_ORDER=16
export N_CLUSTERS=16
export SGLANG_KV_CENTROIDS_PATH=/path/to/centroids_out
python -m sglang.launch_server \
  --prefill-attention-backend fa3 \
  --decode-attention-backend triton \
  --model-path "Qwen/Qwen3-8B" --port 30000 --kv-cache-dtype int4

Hub: eval_accuracy/
Helper: CENTROIDS=/path/to/centroids_out ./scripts/run_eval_matrix.sh kmeans (or kmeans_bdr).

Accuracy results (ablation)

| Model | Method | Benchmark | Score |
|---|---|---|---|
| | K-means + INT4 | | |
| | K-means + BDR | | |

Fill from eval_accuracy/results/.

Repository layout

| Path | Role |
|---|---|
| third_party/sglang-fast-rotation/ | Primary BF16 / INT4 / BDR: accuracy + speed |
| third_party/sglang-kmeans/ | Ablation: k-means KV + dump / centroids |
| third_party/simple-evals/ | GPQA accuracy client (openai/simple-evals submodule; no separate clone needed) |
| docs/bdr_env_vars.md | Full BDR env variable reference and mode matrix |
| scripts/ | bdr_smoke_test.py (install smoke test), run_primary_eval_matrix.sh, run_eval_matrix.sh, run_genai_bench_example.sh, clone_submodules.sh |
| tools/ | fit_kv_centroids.py (ablation calibration) |
| eval_primary/ | Primary accuracy logs / tables |
| eval_speed/ | Primary throughput logs / tables |
| eval_accuracy/ | Ablation accuracy logs / tables |

Full reproduction

Large raw bundles may live outside this repo.

  • Full reproduction bundle: TBD — add URL

Submodule SHAs: SUBMODULE_VERSIONS.md.

License

See LICENSE.
