Commit befa4cd

Applejyork03 authored and committed
fix mlx-server for chunked requests (to support one-api, curl)
Update code: add a body-size limit to prevent DoS attacks, clean up code, and add a unit test. Co-authored-by: Josh York <[email protected]>
1 parent 6c1a459 commit befa4cd

File tree

2 files changed: +178 −3 lines changed


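For reference: the HTTP/1.1 chunked framing (RFC 7230) that clients such as curl and one-api can send when the body length is not known up front, and that the new _read_chunked_body helper in this commit decodes. A minimal illustration in Python; the payload is made up:

# Each chunk is "<hex size>\r\n<data>\r\n"; a zero-size chunk ends the body.
payload = b'{"prompt": "hi"}'
wire = (
    b"10\r\n"          # 0x10 == 16 bytes of data follow
    + payload
    + b"\r\n"
    + b"0\r\n\r\n"     # last-chunk followed by an empty trailer section
)
assert len(payload) == 0x10
print(wire)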
mlx_lm/server.py

Lines changed: 128 additions & 3 deletions

@@ -32,6 +32,7 @@
 from .sample_utils import make_logits_processors, make_sampler
 from .utils import common_prefix_len, load

+import os

 def get_system_fingerprint():
     gpu_arch = mx.metal.device_info()["architecture"] if mx.metal.is_available() else ""
@@ -267,6 +268,78 @@ def _set_stream_headers(self, status_code: int = 200):
         self.send_header("Cache-Control", "no-cache")
         self._set_cors_headers()

+    def _read_chunked_body(self, max_bytes: int | None = None) -> bytes:
+        """
+        Read an HTTP/1.1 chunked transfer-encoded body from self.rfile and
+        return the concatenated payload bytes.
+
+        Raises ValueError on decoding errors.
+        """
+        out = bytearray()
+        rfile = self.rfile
+
+        def _readline_strict():
+            # Read up to CRLF (returns bytes without CRLF).
+            line = rfile.readline()
+            if not line:
+                raise ValueError("unexpected EOF while reading chunk size")
+            if not line.endswith(b"\r\n"):
+                # Allow some robustness: if the line ends with \n only, normalize.
+                if line.endswith(b"\n"):
+                    line = line.rstrip(b"\n")
+                else:
+                    raise ValueError("malformed chunk size line (missing CRLF)")
+            return line.rstrip(b"\r\n")
+
+        while True:
+            # Read the chunk-size line.
+            size_line = _readline_strict()
+            # Strip optional chunk extensions.
+            if b";" in size_line:
+                size_str = size_line.split(b";", 1)[0].strip()
+            else:
+                size_str = size_line.strip()
+            try:
+                size = int(size_str.decode("ascii"), 16)
+            except Exception:
+                raise ValueError(f"invalid chunk size: {size_str!r}")
+            if size == 0:
+                # Last chunk: read the CRLF after the 0 chunk if present,
+                # then consume optional trailers until a blank line.
+                trailer_line = rfile.readline()
+                if trailer_line and trailer_line.strip() != b"":
+                    # There are trailers; read until a blank line.
+                    while True:
+                        line = rfile.readline()
+                        if not line:
+                            break
+                        if line in (b"\r\n", b"\n", b""):
+                            break
+                return bytes(out)
+            # Read exactly `size` bytes of chunk data.
+            remaining = size
+            while remaining > 0:
+                chunk = rfile.read(remaining)
+                if not chunk:
+                    raise ValueError("incomplete chunk data")
+                out.extend(chunk)
+                if max_bytes is not None and len(out) > max_bytes:
+                    raise ValueError(f"payload too large: {len(out)} bytes > {max_bytes} bytes")
+                remaining -= len(chunk)
+            # After the chunk data there must be a CRLF.
+            crlf = rfile.read(2)
+            if crlf != b"\r\n":
+                # Some clients may send only '\n'; be permissive but strict enough.
+                if crlf == b"\n":
+                    # Accept a single LF.
+                    pass
+                else:
+                    raise ValueError("missing CRLF after chunk data")
+        # unreachable
+        return bytes(out)
+
     def do_OPTIONS(self):
         self._set_completion_headers(204)
         self.end_headers()
@@ -287,9 +360,48 @@ def do_POST(self):
             self.wfile.write(b"Not Found")
             return

-        # Fetch and parse request body
-        content_length = int(self.headers["Content-Length"])
-        raw_body = self.rfile.read(content_length)
+        # Maximum body size in bytes: a sane limit to prevent DoS attacks.
+        MAX_BODY_BYTES = 1024 * 1024 * 10  # 10 MB
+
+        transfer_encoding = self.headers.get("Transfer-Encoding", "")
+
+        if "chunked" in transfer_encoding.lower():
+            try:
+                raw_body = self._read_chunked_body(max_bytes=MAX_BODY_BYTES)
+            except ValueError as e:
+                self._set_completion_headers(413 if "payload too large" in str(e).lower() else 400)
+                self.end_headers()
+                self.wfile.write(json.dumps({"error": f"Invalid chunked body: {e}"}).encode())
+                return
+            except Exception:
+                self._set_completion_headers(500)
+                self.end_headers()
+                self.wfile.write(json.dumps({"error": "internal server error"}).encode())
+                return
+        else:
+            cl_header = self.headers.get("Content-Length")
+            if cl_header is None:
+                self._set_completion_headers(411)  # Length Required
+                self.end_headers()
+                self.wfile.write(json.dumps({"error": "Content-Length required"}).encode())
+                return
+
+            try:
+                content_length = int(cl_header)
+            except ValueError:
+                self._set_completion_headers(400)  # Bad Request
+                self.end_headers()
+                self.wfile.write(json.dumps({"error": "Invalid Content-Length"}).encode())
+                return
+
+            if content_length < 0 or content_length > MAX_BODY_BYTES:
+                self._set_completion_headers(413)  # Payload Too Large
+                self.end_headers()
+                self.wfile.write(json.dumps({"error": "Payload too large"}).encode())
+                return
+
+            raw_body = self.rfile.read(content_length)

         try:
             self.body = json.loads(raw_body.decode())
         except json.JSONDecodeError as e:
@@ -309,8 +421,21 @@ def do_POST(self):
         # Extract request parameters from the body
         self.stream = self.body.get("stream", False)
         self.stream_options = self.body.get("stream_options", None)
+
         self.requested_model = self.body.get("model", "default_model")
         self.requested_draft_model = self.body.get("draft_model", "default_model")
+
+        if os.environ.get("MLX_MODEL_PATH", None) is not None:
+            model_path = os.environ["MLX_MODEL_PATH"]
+            if not os.path.exists(model_path):
+                raise Exception(f"MLX_MODEL_PATH={model_path} is not a path")
+
+            if self.requested_model != "default_model":
+                self.requested_model = os.path.join(model_path, self.requested_model)
+
+            if self.requested_draft_model != "default_model":
+                self.requested_draft_model = os.path.join(model_path, self.requested_draft_model)
+
         self.num_draft_tokens = self.body.get(
             "num_draft_tokens", self.model_provider.cli_args.num_draft_tokens
         )

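To see what the new parser accepts, here is a small sketch that feeds a hand-built chunked stream through _read_chunked_body. Because the method only reads from self.rfile (per the diff above), a SimpleNamespace stand-in can take the place of a full APIHandler; the wire bytes and the fake_handler name are illustrative only.

import io
from types import SimpleNamespace

from mlx_lm.server import APIHandler

# Hand-built chunked stream: two data chunks, then the terminating 0-chunk.
wire = io.BytesIO(
    b"e\r\n"                 # 0xe == 14 bytes follow
    b'{"prompt": "hi'
    b"\r\n"
    b"2\r\n"                 # 2 more bytes
    b'"}'
    b"\r\n"
    b"0\r\n\r\n"             # last-chunk, no trailers
)

# _read_chunked_body only touches self.rfile, so a minimal stand-in suffices.
fake_handler = SimpleNamespace(rfile=wire)
body = APIHandler._read_chunked_body(fake_handler, max_bytes=1024)
print(body)  # b'{"prompt": "hi"}'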
tests/test_server.py

Lines changed: 50 additions & 0 deletions

@@ -7,6 +7,7 @@
 import unittest

 import requests
+import time

 from mlx_lm.server import APIHandler
 from mlx_lm.utils import load
@@ -69,6 +70,55 @@ def tearDownClass(cls):
         cls.httpd.server_close()
         cls.server_thread.join()

+    def test_handle_chunked_request(self):
+        url = f"http://localhost:{self.port}/v1/chat/completions"
+
+        post_data = {
+            "model": "default_model",
+            "prompt": "Once upon a time",
+            "max_tokens": 10,
+            "temperature": 0.0,
+            "stream": False,
+            "top_p": 1.0,
+        }
+
+        # Send the request body as chunked parts.
+        data_parts = [
+            b'{"model": "default_model", "messages": [{"role": "user", "content": "Once',
+            b' upon a times, Once upon ',
+            b'a time"}], "temperature": 0.8, "max_tokens": 1024, "stream": false}',
+        ]
+
+        max_length = 0
+        for part in data_parts:
+            max_length += len(part)
+
+        def data_generator():
+            for part in data_parts:
+                yield part
+                time.sleep(0.1)
+
+        try:
+            response = requests.post(
+                url,
+                data=data_generator(),
+                headers={
+                    "Transfer-Encoding": "chunked",
+                    "Content-Type": "application/json",
+                },
+            )
+            self.assertEqual(response.status_code, 200)
+        except requests.exceptions.RequestException:
+            self.assertTrue(False, "Chunked request failed")
+
+        response_body = json.loads(response.text)
+        self.assertIn("id", response_body)
+        self.assertIn("choices", response_body)
+        self.assertIn("usage", response_body)
+
+        # Check that tokens were generated
+        self.assertTrue(response_body["usage"]["completion_tokens"] > 0)
+
     def test_handle_completions(self):
         url = f"http://localhost:{self.port}/v1/completions"

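For completeness, a usage sketch outside the test suite, assuming an mlx_lm server is already running locally (8080 is the CLI's default port; adjust if needed): http.client with encode_chunked=True produces the same Transfer-Encoding: chunked framing that the test above triggers through requests' generator body.

import http.client
import json

conn = http.client.HTTPConnection("localhost", 8080)

# An iterable body with encode_chunked=True makes http.client emit
# Transfer-Encoding: chunked instead of a Content-Length header.
parts = iter([
    b'{"model": "default_model", "messages": ',
    b'[{"role": "user", "content": "Once upon a time"}], ',
    b'"max_tokens": 16, "stream": false}',
])

conn.request(
    "POST",
    "/v1/chat/completions",
    body=parts,
    headers={"Content-Type": "application/json"},
    encode_chunked=True,
)

resp = conn.getresponse()
print(resp.status, json.loads(resp.read())["usage"])
conn.close()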