Conversation

@angeloskath
Member

It simply generalizes the logic in get_prompt_cache() to many caches simultaneously.

There is a decision to be made on whether we always grow the cache, or do something smarter such as keeping an entry around when it holds a commonly used series of tokens (that would take care of system prompts, for instance). For now this implements exactly what we had before: always grow the cache if possible, and try trimming it if not.
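To make the idea concrete, here is a minimal sketch (an illustration only, not the PR's actual code) of a store that keeps several prompt caches keyed by model and cached token prefix, hands back the one sharing the longest common prefix with an incoming prompt, and evicts least-recently-used entries beyond a fixed capacity. The names MultiPromptCache, fetch_nearest and store are made up for the sketch, loosely modeled on the fetch_nearest_cache call quoted later in this thread:

# Minimal sketch (an illustration, not the PR's actual code): keep several
# prompt caches keyed by (model, cached token prefix), return the one sharing
# the longest common prefix with the incoming prompt, and evict the least
# recently used entries beyond a fixed capacity.
from collections import OrderedDict


def common_prefix_len(a, b):
    n = min(len(a), len(b))
    for i in range(n):
        if a[i] != b[i]:
            return i
    return n


class MultiPromptCache:
    def __init__(self, max_size=10):
        self.max_size = max_size
        # (model_key, tuple_of_cached_tokens) -> per-layer KV cache object
        self._entries = OrderedDict()

    def fetch_nearest(self, model_key, prompt):
        """Return (cache, rest): the best matching cache and the uncached suffix."""
        best_key, best_len = None, 0
        for key in self._entries:
            mk, tokens = key
            if mk != model_key:
                continue
            p = common_prefix_len(tokens, prompt)
            if p > best_len:
                best_key, best_len = key, p
        if best_key is None:
            return None, prompt
        self._entries.move_to_end(best_key)  # mark as most recently used
        return self._entries[best_key], prompt[best_len:]

    def store(self, model_key, tokens, cache):
        key = (model_key, tuple(tokens))
        self._entries[key] = cache
        self._entries.move_to_end(key)
        while len(self._entries) > self.max_size:
            self._entries.popitem(last=False)  # drop the least recently used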

@angeloskath requested a review from awni on November 20, 2025 19:48
@kernelpool
Contributor

Nice work! I just tested this and ran into a few issues in server.py for streaming/non-streaming:

864                self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
865                self.wfile.flush()
866                if self.stream_options is not None and self.stream_options["include_usage"]:
867 -                  original_prompt_length = (
868 -                      len(self.prompt_cache.tokens) - len(tokens) + len(prompt)
869 -                  )
867 +                  original_prompt_length = len(cache_key) - len(tokens)
868                    response = self.completion_usage_response(
869                        original_prompt_length, len(tokens)
870                    )

and

876                response = self.generate_response(
877                    text,
878                    finish_reason,
879 -                  len(prompt),
879 +                  len(cache_key) - len(tokens),
880                    len(tokens),
881                    token_logprobs=token_logprobs,
882                    top_tokens=top_tokens,

cache, rest = self.prompt_cache.fetch_nearest_cache(
    self.model_provider.model_key, prompt
)
cache_key = prompt[: len(prompt) - len(rest)]
Contributor


Suggested change
cache_key = prompt[: len(prompt) - len(rest)]
cache_key = prompt[: len(prompt) - len(rest)]
if len(rest) == 0 and len(cache_key) > 0:
    rest = [cache_key[-1]]
    cache_key = cache_key[:-1]

We probably need something like this, as the old implementation had, since stream_generate in generate.py has a len(prompt) == 0 check.

Member Author


Well it is the exact same behavior as before though. If we do indeed want exact cache matches to be used then we need to try trimming the cache by 1 if possible.
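For illustration, a rough sketch of that trim-by-one idea using mlx_lm's prompt-cache helpers; the split_prompt helper and its call site are assumptions for the sketch, not the code that ended up in the PR:

# Sketch only: if the cache already covers the whole prompt, drop the last
# cached token and hand it back to generation so the prompt is never empty.
from mlx_lm.models.cache import can_trim_prompt_cache, trim_prompt_cache


def split_prompt(cache, cache_key, rest):
    # `cache_key` is the cached prefix of the prompt, `rest` the uncached remainder.
    if len(rest) == 0 and len(cache_key) > 0:
        if can_trim_prompt_cache(cache):
            trim_prompt_cache(cache, 1)  # forget the last cached token...
            rest = [cache_key[-1]]       # ...and re-feed it to the model
            cache_key = cache_key[:-1]
        # else: fall back to not reusing this cache at all
    return cache_key, rest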

Contributor


Sorry, I probably should have been clearer. I ran into a crash, and this was one way to fix it. The original implementation had:

# Leave at least one token in the prompt
com_prefix_len = min(com_prefix_len, len(prompt) - 1)

You can reproduce with the following:

mlx_lm.server --model mlx-community/Qwen3-4B-Instruct-2507-4bit --host 0.0.0.0 --port 8080

Then do the following 2 times:

curl -X POST http://localhost:8080/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "default_model",
    "messages": [{"role": "user", "content": "Hello, how are you?"}],
    "max_tokens": 10
  }'

Result:

mlx_lm.server --model /Volumes/WD_EXTRA/models/catalyst/Qwen3-4B-Instruct-2507-8bit --host 0.0.0.0 --port 8080 --max-tokens 16384
/Users/optimus/repo/mlx-lm/mlx_lm/server.py:1052: UserWarning: mlx_lm.server is not recommended for production as it only implements basic security checks.
  warnings.warn(
2025-11-21 16:23:12,284 - INFO - Starting httpd at 0.0.0.0 on port 8080...
127.0.0.1 - - [21/Nov/2025 16:23:15] "POST /v1/chat/completions HTTP/1.1" 200 -
2025-11-21 16:23:15,421 - INFO - Prompt processing progress: 0/14
2025-11-21 16:23:15,507 - INFO - Prompt processing progress: 13/14
2025-11-21 16:23:15,540 - INFO - Prompt processing progress: 14/14
127.0.0.1 - - [21/Nov/2025 16:23:17] "POST /v1/chat/completions HTTP/1.1" 200 -
----------------------------------------
Exception occurred during processing of request from ('127.0.0.1', 64738)
Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/socketserver.py", line 318, in _handle_request_noblock
    self.process_request(request, client_address)
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/socketserver.py", line 349, in process_request
    self.finish_request(request, client_address)
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/socketserver.py", line 362, in finish_request
    self.RequestHandlerClass(request, client_address, self)
  File "/Users/optimus/repo/mlx-lm/mlx_lm/server.py", line 1044, in <lambda>
    lambda *args, **kwargs: handler_class(
                            ^^^^^^^^^^^^^^
  File "/Users/optimus/repo/mlx-lm/mlx_lm/server.py", line 393, in __init__
    super().__init__(*args, **kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/socketserver.py", line 761, in __init__
    self.handle()
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/http/server.py", line 436, in handle
    self.handle_one_request()
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/http/server.py", line 424, in handle_one_request
    method()
  File "/Users/optimus/repo/mlx-lm/mlx_lm/server.py", line 511, in do_POST
    self.handle_completion(prompt, stop_id_sequences)
  File "/Users/optimus/repo/mlx-lm/mlx_lm/server.py", line 779, in handle_completion
    for gen_response in stream_generate(
                        ^^^^^^^^^^^^^^^^
  File "/Users/optimus/repo/mlx-lm/mlx_lm/generate.py", line 698, in stream_generate
    for n, (token, logprobs, from_draft) in enumerate(token_generator):
                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/optimus/repo/mlx-lm/mlx_lm/generate.py", line 688, in <genexpr>
    (token, logprobs, False) for token, logprobs in token_generator
                                                    ^^^^^^^^^^^^^^^
  File "/Users/optimus/repo/mlx-lm/mlx_lm/generate.py", line 355, in generate_step
    raise ValueError(
ValueError: Either input_embeddings or prompt (or both) must be provided.
----------------------------------------

Member Author


Yep, I ran into that later and fixed it in a063c39. Let me know if you run into more issues. I should probably add a test for that as well...

@angeloskath
Member Author

Damn, sorry I missed those. Admittedly I only tested non-streaming (stupid of me). Will fix these by tomorrow.

@kernelpool
Contributor

No worries. I was curious because I have my own implementation doing this, but it is much simpler and uses a session-id-based approach instead (since I'm using my own client). This approach is much better for general use. Perhaps there should also be a way to configure the maximum number of caches via an mlx_lm.server argument (for more resource-constrained systems)? I see there's already a max_size set to 10.
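Something along these lines, purely as a sketch; the --max-prompt-caches flag name and the wiring are made up here, not options the server actually exposes:

import argparse

# Hypothetical excerpt of mlx_lm.server's argument parsing; the flag name is
# an assumption, not part of this PR.
parser = argparse.ArgumentParser(description="mlx_lm.server (excerpt)")
parser.add_argument(
    "--max-prompt-caches",
    type=int,
    default=10,  # matches the max_size of 10 mentioned above
    help="Maximum number of prompt caches kept in memory",
)
args = parser.parse_args()
# ...later, when the cache store is constructed:
# prompt_cache = MultiPromptCache(max_size=args.max_prompt_caches)  # from the sketch above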

model, tokens = self._lru.popleft()
self._delete(model, tokens)



We could use radix tree techniques, as featured in SGLang.

They make multi-round LLM workloads, e.g. agentic environments, more efficient.
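For reference, a rough sketch of the idea (neither SGLang's nor mlx-lm's code), simplified to one node per token rather than a compressed radix tree: shared prefixes such as a common system prompt map to shared nodes, so their cached state is stored once and reused across conversations.

# Illustrative sketch only: a prefix tree over token sequences, with one node
# per token. A real radix tree would compress runs of single-child nodes.
class PrefixTreeCache:
    def __init__(self):
        self.root = {"children": {}, "state": None}

    def insert(self, tokens, state):
        """Store `state` (e.g. a KV-cache snapshot) under the token prefix `tokens`."""
        node = self.root
        for t in tokens:
            node = node["children"].setdefault(t, {"children": {}, "state": None})
        node["state"] = state

    def longest_cached_prefix(self, tokens):
        """Return (prefix_length, state) for the longest cached prefix of `tokens`."""
        node, best = self.root, (0, None)
        for i, t in enumerate(tokens):
            node = node["children"].get(t)
            if node is None:
                break
            if node["state"] is not None:
                best = (i + 1, node["state"])
        return best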
