14 changes: 14 additions & 0 deletions README.md
@@ -113,6 +113,20 @@ To see a description of all the arguments you can do:
>>> help(convert)
```

### Continuous batching (experimental)

`mlx_lm.server` supports iteration-level continuous batching when launched with `--enable-continuous-batching`. The scheduler admits new requests between decode steps, so concurrent arrivals see lower time-to-first-token (TTFT) and higher aggregate throughput.

Flags:
- `--enable-continuous-batching`: turn on the continuous batching scheduler.
- `--max-num-seqs` (default `16`): maximum active sequences decoded per scheduler iteration.
- `--max-tokens-per-step` (default `4096`): token budget per scheduler tick for prefill work.
- `--prefill-chunk` (default `1024`): maximum prompt tokens ingested per prefill step.

Requests that require tool-calling or per-token logprobs currently fall back to the legacy streaming path. Persistent KV cache, prefix caching, and Metal PagedAttention will arrive in future PRs.

A Poisson-arrival benchmark comparing the new runtime against `batch_generate` is available in `bench/bench_continuous_vs_static.py`.
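
The runtime can also be driven directly from Python. The following is a minimal sketch that mirrors the calls made in `bench/bench_continuous_vs_static.py`; the `ModelRunner`/`ContinuousBatchingRuntime` constructors and the `submit_request` signature are taken from that script, the API is experimental, and only the `generation_tokens` and `finish_reason` response fields are assumed here.

```python
from mlx_lm import load
from mlx_lm.server_batched.engine import ModelRunner
from mlx_lm.server_batched.runtime import ContinuousBatchingRuntime

model, tokenizer = load(path_or_hf_repo="mlx-community/Llama-3.2-3B-Instruct-4bit")

# Wrap the model in the batched runner and start the scheduler.
runner = ModelRunner(model, tokenizer, draft_model=None)
runtime = ContinuousBatchingRuntime(
    runner,
    max_num_seqs=16,
    max_tokens_per_step=4096,
    prefill_chunk=1024,
)

# Greedy sampling settings, matching the benchmark script.
sampler_settings = {
    "temp": 0.0,
    "top_p": 1.0,
    "min_p": 0.0,
    "top_k": 0,
    "xtc_probability": 0.0,
    "xtc_threshold": 0.0,
    "xtc_special_tokens": [tokenizer.eos_token_id] + tokenizer.encode("\n"),
}

_, generator = runtime.submit_request(
    prompt_tokens=tokenizer.encode("Write a haiku about GPUs."),
    max_new_tokens=64,
    sampler_settings=sampler_settings,
    stopping_settings={"eos_token_id": tokenizer.eos_token_id},
    logit_bias=None,
    repetition_penalty=None,
    repetition_context_size=None,
)

# Drain the per-request generator; it yields incremental responses.
generated = 0
for response in generator:
    generated = response.generation_tokens
    if response.finish_reason:
        break
print(f"finished after {generated} generated tokens")

runtime.shutdown()
```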

#### Streaming

For streaming generation, use the `stream_generate` function. This yields
244 changes: 244 additions & 0 deletions bench/bench_continuous_vs_static.py
@@ -0,0 +1,244 @@
# ABOUTME: Benchmarks static batch_generate against continuous batching runtime.
# ABOUTME: Uses wall-clock throughput under Poisson arrivals to track gains.

"""
Benchmark static batch_generate vs continuous batching runtime under Poisson arrivals.
Run: python bench/bench_continuous_vs_static.py --repo mlx-community/Llama-3.2-3B-Instruct-4bit
"""

import argparse
import json
import logging
import math
import random
import threading
import time
from dataclasses import dataclass
from statistics import mean, median
from typing import List, Optional

import mlx.core as mx

from mlx_lm import batch_generate, load
from mlx_lm.server_batched.engine import ModelRunner
from mlx_lm.server_batched.runtime import ContinuousBatchingRuntime

random.seed(0)

logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")


@dataclass
class Result:
    submit_ns: int
    first_token_ns: Optional[int]
    finish_ns: int
    gen_tokens: int


def poisson_arrivals(lmbda: float, n: int):
    # Draw exponential inter-arrival gaps (inverse-CDF sampling) and return
    # the cumulative arrival times for n requests.
    t = 0.0
    out = []
    for _ in range(n):
        u = random.random()
        gap = -math.log1p(-u) / lmbda
        t += gap
        out.append(t)
    return out


def summarize(name: str, results: List[Result]):
    if not results:
        print(
            f"[{name}] n=0 tokens=0 wall=0.00s tokens/s=0.00 ttft mean=0.0ms median=0.0ms p95=0.0ms"
        )
        return 0.0, 0.0, 0

    start_ns = min(r.submit_ns for r in results)
    end_ns = max(r.finish_ns for r in results)
    wall_seconds = max((end_ns - start_ns) / 1e9, 1e-9)
    total_tokens = sum(r.gen_tokens for r in results)
    tokens_per_sec = total_tokens / wall_seconds

    ttfts = []
    for r in results:
        first = r.first_token_ns if r.first_token_ns is not None else r.finish_ns
        ttft_ms = max((first - r.submit_ns) / 1e6, 0.0)
        ttfts.append(ttft_ms)

    ttfts.sort()
    ttft_mean = mean(ttfts) if ttfts else 0.0
    ttft_median = median(ttfts) if ttfts else 0.0
    idx = max(int(math.ceil(0.95 * len(ttfts))) - 1, 0) if ttfts else 0
    ttft_p95 = ttfts[idx] if ttfts else 0.0

    print(
        f"[{name}] n={len(results)} tokens={total_tokens} wall={wall_seconds:.2f}s "
        f"tokens/s={tokens_per_sec:.2f} "
        f"ttft mean={ttft_mean:.1f}ms median={ttft_median:.1f}ms p95={ttft_p95:.1f}ms"
    )
    return tokens_per_sec, wall_seconds, total_tokens


def run_static(model, tokenizer, prompts, max_tokens):
    prompt_token_batches = [tokenizer.encode(p) for p in prompts]
    start_ns = time.perf_counter_ns()
    response = batch_generate(
        model,
        tokenizer,
        prompt_token_batches,
        max_tokens=max_tokens,
        verbose=False,
    )
    end_ns = time.perf_counter_ns()
    results = []
    for text in response.texts:
        tokens = len(tokenizer.encode(text, add_special_tokens=False))
        results.append(
            Result(
                submit_ns=start_ns,
                first_token_ns=end_ns,
                finish_ns=end_ns,
                gen_tokens=tokens,
            )
        )
    return results


def run_continuous(model, tokenizer, prompts, max_tokens, lmbda):
    runner = ModelRunner(model, tokenizer, draft_model=None)
    runtime = ContinuousBatchingRuntime(
        runner,
        max_num_seqs=16,
        max_tokens_per_step=4096,
        prefill_chunk=1024,
        debug_metrics=True,
    )
    results: List[Result] = []
    lock = threading.Lock()
    arrivals = poisson_arrivals(lmbda, len(prompts))
    start = time.perf_counter()

    def submit(idx, arrival_s):
        # Wait until the scheduled arrival time.
        while True:
            elapsed = time.perf_counter() - start
            if elapsed >= arrival_s:
                break
            time.sleep(min(0.001, arrival_s - elapsed))

        prompt = prompts[idx]
        prompt_tokens = tokenizer.encode(prompt)
        sampler_settings = {
            "temp": 0.0,
            "top_p": 1.0,
            "min_p": 0.0,
            "top_k": 0,
            "xtc_probability": 0.0,
            "xtc_threshold": 0.0,
            "xtc_special_tokens": [tokenizer.eos_token_id] + tokenizer.encode("\n"),
        }
        stopping_settings = {"eos_token_id": tokenizer.eos_token_id}

        submit_ns = time.perf_counter_ns()
        _, generator = runtime.submit_request(
            prompt_tokens=prompt_tokens,
            max_new_tokens=max_tokens,
            sampler_settings=sampler_settings,
            stopping_settings=stopping_settings,
            logit_bias=None,
            repetition_penalty=None,
            repetition_context_size=None,
        )

        first_ns = None
        finish_ns = None
        final_tokens = 0
        deadline = time.perf_counter() + 60.0  # guard against hung generators
        try:
            for response in generator:
                if time.perf_counter() > deadline:
                    break
                if first_ns is None and response.generation_tokens > 0:
                    first_ns = time.perf_counter_ns()
                final_tokens = response.generation_tokens
                if response.finish_reason:
                    finish_ns = time.perf_counter_ns()
                    break
        except Exception:
            finish_ns = time.perf_counter_ns()
        finally:
            # Record a finish time even if the generator ended without a
            # finish_reason (e.g. deadline hit or StopIteration).
            if finish_ns is None:
                finish_ns = time.perf_counter_ns()
            with lock:
                results.append(
                    Result(
                        submit_ns=submit_ns,
                        first_token_ns=first_ns,
                        finish_ns=finish_ns,
                        gen_tokens=max(final_tokens, 0),
                    )
                )

    threads = []
    for idx, arrival in enumerate(arrivals):
        th = threading.Thread(target=submit, args=(idx, arrival), daemon=True)
        th.start()
        threads.append(th)

    for th in threads:
        th.join()

    runtime.shutdown()
    return results


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--repo", default="mlx-community/Llama-3.2-3B-Instruct-4bit")
    parser.add_argument("--n", type=int, default=32)
    parser.add_argument("--concurrency", type=int, default=8)
    parser.add_argument("--max_tokens", type=int, default=64)
    parser.add_argument("--prompt_len", type=int, default=64)
    args = parser.parse_args()

    model, tokenizer = load(path_or_hf_repo=args.repo)
    base_prompt = "Tell me a haiku about mac GPUs." * 4
    # Note: prompt_len truncates by characters, not tokens.
    prompts = [base_prompt[: args.prompt_len] for _ in range(args.n)]

    static_results = run_static(model, tokenizer, prompts, args.max_tokens)
    static_tps, static_wall, static_tokens = summarize(
        "static_batch_generate", static_results
    )

    # Pick a Poisson arrival rate that keeps roughly `concurrency` requests
    # in flight, given a rough estimate of per-request service time.
    est_service = max(args.max_tokens / 200.0, 0.01)
    lmbda = max(0.1, args.concurrency / est_service)
    continuous_results = run_continuous(
        model, tokenizer, prompts, args.max_tokens, lmbda
    )
    cont_tps, cont_wall, cont_tokens = summarize(
        "continuous_runtime", continuous_results
    )

    print(
        json.dumps(
            {
                "static_tokens_per_sec": static_tps,
                "static_total_tokens": static_tokens,
                "static_wall_seconds": static_wall,
                "continuous_tokens_per_sec": cont_tps,
                "continuous_total_tokens": cont_tokens,
                "continuous_wall_seconds": cont_wall,
            },
            indent=2,
        )
    )

    if args.concurrency >= 4 and cont_tps < 1.5 * static_tps:
        raise SystemExit(
            f"FAIL: continuous tokens/s {cont_tps:.2f} < 1.5x static {static_tps:.2f}"
        )


if __name__ == "__main__":
    main()
10 changes: 10 additions & 0 deletions mlx_lm/SERVER.md
@@ -98,6 +98,16 @@ curl localhost:8080/v1/chat/completions \
- `num_draft_tokens`: (Optional) The number of draft tokens the draft model
should predict at once. Defaults to `3`.

### Continuous batching flags (experimental)

The HTTP server supports iteration-level continuous batching when started with `--enable-continuous-batching`. Additional flags:

- `--max-num-seqs`: maximum active sequences decoded per scheduler iteration (default `16`).
- `--max-tokens-per-step`: token budget per scheduler tick for prefill work (default `4096`).
- `--prefill-chunk`: maximum prompt tokens ingested per prefill step (default `1024`).

When continuous batching is enabled, requests that require tool-calling or per-token logprobs fall back to the legacy streaming path. Persistent KV cache and prefix caching will arrive in subsequent releases.
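
Continuous batching pays off when several requests are in flight at once. The following sketch assumes a server already running on the default `localhost:8080` with `--enable-continuous-batching`, and uses the third-party `requests` package (not part of `mlx_lm`) to fire a handful of concurrent chat completions so the scheduler can interleave them:

```python
from concurrent.futures import ThreadPoolExecutor

import requests

URL = "http://localhost:8080/v1/chat/completions"


def ask(question: str) -> str:
    # Non-streaming chat completion; concurrent requests share decode steps.
    payload = {
        "messages": [{"role": "user", "content": question}],
        "max_tokens": 64,
        "temperature": 0.0,
    }
    response = requests.post(URL, json=payload, timeout=120)
    response.raise_for_status()
    return response.json()["choices"][0]["message"]["content"]


questions = [f"Give me one fact about the number {i}." for i in range(8)]
with ThreadPoolExecutor(max_workers=8) as pool:
    for answer in pool.map(ask, questions):
        print(answer)
```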

### Response Fields

- `id`: A unique identifier for the chat.
38 changes: 24 additions & 14 deletions mlx_lm/generate.py
@@ -1026,16 +1026,27 @@ def _process_prompts(self, prompts):
        )

    def _step(self, input_tokens: mx.array, prompt_cache: List[Any]):
+        # Align caches to the current batch size before attention masks are built.
+        batch_size = input_tokens.shape[0]
+        for c in prompt_cache or []:
+            rebatch = getattr(c, "rebatch", None)
+            if rebatch is not None:
+                rebatch(batch_size)
+
        logits = self.model(input_tokens, cache=prompt_cache)
        logits = logits[:, -1, :]
        logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
        sampled = self.sampler(logprobs)
        return sampled, logprobs

    def stats(self):
-        self._stats.prompt_tps = self._stats.prompt_tokens / self._stats.prompt_time
+        prompt_time = self._stats.prompt_time or 0.0
+        gen_time = self._stats.generation_time or 0.0
+        self._stats.prompt_tps = (
+            self._stats.prompt_tokens / prompt_time if prompt_time > 0 else 0.0
+        )
        self._stats.generation_tps = (
-            self._stats.generation_tokens / self._stats.generation_time
+            self._stats.generation_tokens / gen_time if gen_time > 0 else 0.0
        )
        self._stats.peak_memory = mx.get_peak_memory() / 1e9
        return self._stats
@@ -1047,15 +1058,13 @@ def _next(self):
        batch = self.active_batch
        num_active = len(batch) if batch else 0
        num_to_add = self.completion_batch_size - num_active
-        while num_to_add >= self.prefill_batch_size:
-            prompts = self.unprocessed_prompts[: self.prefill_batch_size]
-            # Finish processing the last examples of the last batch
-            if len(prompts) == 0 and num_active > 0:
+        while num_to_add > 0 and self.unprocessed_prompts:
+            take = min(
+                self.prefill_batch_size, num_to_add, len(self.unprocessed_prompts)
+            )
+            prompts = self.unprocessed_prompts[:take]
+            if not prompts:
                break
-            # No more prompts and no more completions, all done
-            elif len(prompts) == 0:
-                self.active_batch = None
-                return []
            # Process prompts
            if batch is not None and not prompt_processing:
                # Finish any active completion tokens
@@ -1064,9 +1073,7 @@ def _next(self):
            tic = time.perf_counter()

            batch = self._process_prompts(prompts)
-            self.unprocessed_prompts = self.unprocessed_prompts[
-                self.prefill_batch_size :
-            ]
+            self.unprocessed_prompts = self.unprocessed_prompts[take:]
            prompt_processing = True
            # If there was no active batch, set it
            if self.active_batch is None:
@@ -1075,7 +1082,10 @@ def _next(self):
                self.active_batch.extend(batch)

            num_active = len(self.active_batch)
-            num_to_add -= len(batch)
+            num_to_add = self.completion_batch_size - num_active
+
+        if self.active_batch is None and not self.unprocessed_prompts:
+            return []

        batch = self.active_batch
        y, logprobs = batch.y, batch.logprobs