Add a prompt cache that can hold multiple prompts #625
base: main
Conversation
Nice work! I just tested this and ran into a few issues in server.py for streaming/non-streaming:
cache, rest = self.prompt_cache.fetch_nearest_cache(
    self.model_provider.model_key, prompt
)
cache_key = prompt[: len(prompt) - len(rest)]
Suggested change:
    cache_key = prompt[: len(prompt) - len(rest)]
    if len(rest) == 0 and len(cache_key) > 0:
        rest = [cache_key[-1]]
        cache_key = cache_key[:-1]
We probably need something like this, like the old implementation had, since stream_generate in generate.py has a len(prompt) == 0 check.
Well, it is the exact same behavior as before though. If we do indeed want exact cache matches to be used, then we need to try trimming the cache by 1 if possible.
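For reference, here is a rough sketch of what that trim-by-one path could look like. It assumes the can_trim_prompt_cache / trim_prompt_cache helpers from mlx_lm.models.cache (which the previous server implementation used) and the fetch_nearest_cache interface quoted above; the helper name and the fallback behavior are illustrative, not the PR's actual code.

from mlx_lm.models.cache import can_trim_prompt_cache, trim_prompt_cache

def split_prompt_for_cache(prompt_cache, model_key, prompt):
    # Illustrative helper: split `prompt` into the cached part and the part that
    # still has to be fed to the model, making sure the latter is never empty.
    cache, rest = prompt_cache.fetch_nearest_cache(model_key, prompt)
    cache_key = prompt[: len(prompt) - len(rest)]
    if len(rest) == 0 and len(cache_key) > 0:
        # Exact cache hit: hand the last token back to the generator, since
        # stream_generate / generate_step refuse an empty prompt ...
        rest = [cache_key[-1]]
        cache_key = cache_key[:-1]
        # ... and trim that token off the KV cache so it is not processed twice.
        if can_trim_prompt_cache(cache):
            trim_prompt_cache(cache, 1)
        else:
            # Assumed fallback: if this cache type cannot be trimmed, skip reuse
            # rather than re-feeding a token the cache has already seen.
            cache, cache_key, rest = None, [], list(prompt)
    return cache, cache_key, rest

The fallback branch is only one way to keep the KV cache and the token bookkeeping consistent; it trades a full re-prefill for correctness when a cache type cannot be trimmed.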
Sorry, I probably should have been more clear. I ran into a crash and so this was one way to fix it. The original implementation had:
# Leave at least one token in the prompt
com_prefix_len = min(com_prefix_len, len(prompt) - 1)
You can reproduce with the following:
mlx_lm.server --model mlx-community/Qwen3-4B-Instruct-2507-4bit --host 0.0.0.0 --port 8080
Then run the following twice:
curl -X POST http://localhost:8080/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "default_model",
"messages": [{"role": "user", "content": "Hello, how are you?"}],
"max_tokens": 10
}'
Result:
mlx_lm.server --model /Volumes/WD_EXTRA/models/catalyst/Qwen3-4B-Instruct-2507-8bit --host 0.0.0.0 --port 8080 --max-tokens 16384
/Users/optimus/repo/mlx-lm/mlx_lm/server.py:1052: UserWarning: mlx_lm.server is not recommended for production as it only implements basic security checks.
warnings.warn(
2025-11-21 16:23:12,284 - INFO - Starting httpd at 0.0.0.0 on port 8080...
127.0.0.1 - - [21/Nov/2025 16:23:15] "POST /v1/chat/completions HTTP/1.1" 200 -
2025-11-21 16:23:15,421 - INFO - Prompt processing progress: 0/14
2025-11-21 16:23:15,507 - INFO - Prompt processing progress: 13/14
2025-11-21 16:23:15,540 - INFO - Prompt processing progress: 14/14
127.0.0.1 - - [21/Nov/2025 16:23:17] "POST /v1/chat/completions HTTP/1.1" 200 -
----------------------------------------
Exception occurred during processing of request from ('127.0.0.1', 64738)
Traceback (most recent call last):
File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/socketserver.py", line 318, in _handle_request_noblock
self.process_request(request, client_address)
File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/socketserver.py", line 349, in process_request
self.finish_request(request, client_address)
File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/socketserver.py", line 362, in finish_request
self.RequestHandlerClass(request, client_address, self)
File "/Users/optimus/repo/mlx-lm/mlx_lm/server.py", line 1044, in <lambda>
lambda *args, **kwargs: handler_class(
^^^^^^^^^^^^^^
File "/Users/optimus/repo/mlx-lm/mlx_lm/server.py", line 393, in __init__
super().__init__(*args, **kwargs)
File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/socketserver.py", line 761, in __init__
self.handle()
File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/http/server.py", line 436, in handle
self.handle_one_request()
File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/http/server.py", line 424, in handle_one_request
method()
File "/Users/optimus/repo/mlx-lm/mlx_lm/server.py", line 511, in do_POST
self.handle_completion(prompt, stop_id_sequences)
File "/Users/optimus/repo/mlx-lm/mlx_lm/server.py", line 779, in handle_completion
for gen_response in stream_generate(
^^^^^^^^^^^^^^^^
File "/Users/optimus/repo/mlx-lm/mlx_lm/generate.py", line 698, in stream_generate
for n, (token, logprobs, from_draft) in enumerate(token_generator):
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/optimus/repo/mlx-lm/mlx_lm/generate.py", line 688, in <genexpr>
(token, logprobs, False) for token, logprobs in token_generator
^^^^^^^^^^^^^^^
File "/Users/optimus/repo/mlx-lm/mlx_lm/generate.py", line 355, in generate_step
raise ValueError(
ValueError: Either input_embeddings or prompt (or both) must be provided.
----------------------------------------
Yep, I ran into that later and fixed it in a063c39; let me know if you run into more issues. I should probably add a test for that as well...
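A rough sketch of what such a regression test could look like. The FakePromptCache below is a stand-in written for this example, not the PR's actual cache class; only the fetch_nearest_cache(model_key, prompt) -> (cache, rest) shape is taken from the snippets in this thread.

import unittest

class FakePromptCache:
    # Minimal stand-in: remembers one prompt per model key and returns the
    # stored cache object plus the suffix of the new prompt that did not match.
    def __init__(self):
        self._store = {}

    def insert(self, model_key, tokens, cache_obj):
        self._store[model_key] = (list(tokens), cache_obj)

    def fetch_nearest_cache(self, model_key, tokens):
        stored, cache_obj = self._store.get(model_key, ([], None))
        n = 0
        for a, b in zip(stored, tokens):
            if a != b:
                break
            n += 1
        return cache_obj, list(tokens[n:])

class TestExactCacheMatch(unittest.TestCase):
    def test_exact_match_leaves_one_token_for_generation(self):
        prompt_cache = FakePromptCache()
        prompt = [1, 2, 3, 4]
        prompt_cache.insert("model", prompt, cache_obj=object())
        _, rest = prompt_cache.fetch_nearest_cache("model", prompt)
        cache_key = prompt[: len(prompt) - len(rest)]
        if len(rest) == 0 and len(cache_key) > 0:
            rest = [cache_key[-1]]
            cache_key = cache_key[:-1]
        # stream_generate raises ValueError on an empty prompt (see the
        # traceback above), so at least one token must be left to generate from.
        self.assertGreater(len(rest), 0)

if __name__ == "__main__":
    unittest.main()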
Damn, sorry I missed those. I only tested non-streaming, admittedly (stupid of me). Will fix these by tomorrow.
No worries. I was curious because I have my own implementation doing this, but it is much simpler and uses a session-id-based approach instead (since I'm using my own client). This approach is much better for general use. Perhaps there should be a way to also configure the maximum number of caches via a …
model, tokens = self._lru.popleft()
self._delete(model, tokens)
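That snippet is the eviction path; for context, here is a minimal sketch of the bookkeeping around it, assuming an LRU policy and a configurable maximum number of entries. The class name, max_entries, and insert are illustrative, not the PR's API.

from collections import deque

class BoundedPromptCache:
    # Illustrative LRU-style container holding at most max_entries prompt caches.
    def __init__(self, max_entries=8):
        self.max_entries = max_entries
        self._caches = {}    # (model, token tuple) -> KV cache object
        self._lru = deque()  # least recently used key on the left

    def _delete(self, model, tokens):
        self._caches.pop((model, tokens), None)

    def insert(self, model, tokens, cache):
        key = (model, tuple(tokens))
        if key in self._caches:
            # Refresh recency for an existing entry instead of duplicating it.
            self._lru.remove(key)
        self._caches[key] = cache
        self._lru.append(key)
        # Evict the oldest entries, mirroring the popleft()/_delete() pattern above.
        while len(self._lru) > self.max_entries:
            model, tokens = self._lru.popleft()
            self._delete(model, tokens)

A server option for the maximum number of caches, as suggested in the comment above, would then only need to set max_entries.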
We could use radix tree techniques, which are featured in SGLang. They make multi-round LLM use, i.e. agentic LLM environments, more efficient. This simply generalizes the logic in get_prompt_cache() to many caches simultaneously.

There is a decision to be made on whether we want to always grow the cache, or whether we can do something smarter, like leaving it in the cache if it is a commonly used series of tokens (that would take care of system prompts, for instance). For now this implements the exact same thing as we had before: always grow the cache if possible and try trimming it if not.
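For context on the radix-tree suggestion, here is a toy sketch of the idea: index cached prompts by token prefix so the longest shared prefix over many cached prompts can be found in a single walk. This is a simplified, uncompressed trie written for illustration, not SGLang's implementation or anything in this PR.

class TokenTrieNode:
    __slots__ = ("children", "cache")

    def __init__(self):
        self.children = {}   # token id -> TokenTrieNode
        self.cache = None    # KV cache stored at this prefix, if any

class TokenTrie:
    # Toy prefix index over token sequences (uncompressed trie).
    def __init__(self):
        self.root = TokenTrieNode()

    def insert(self, tokens, cache):
        node = self.root
        for tok in tokens:
            node = node.children.setdefault(tok, TokenTrieNode())
        node.cache = cache

    def longest_prefix(self, tokens):
        # Return (matched_len, cache) for the longest cached prefix of `tokens`.
        node = self.root
        best_len, best_cache = 0, None
        for i, tok in enumerate(tokens):
            node = node.children.get(tok)
            if node is None:
                break
            if node.cache is not None:
                best_len, best_cache = i + 1, node.cache
        return best_len, best_cache

A fetch_nearest_cache built on top of this would return the cache at the deepest matching node plus the unmatched suffix; a real radix tree additionally compresses single-child chains so memory scales with the number of distinct prompts rather than the total number of tokens.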