Skip to content

Commit 64f2acd

Browse files
add unit test
1 parent 7596cc2 commit 64f2acd

File tree

2 files changed

+60
-13
lines changed

2 files changed

+60
-13
lines changed

mlx_lm/server.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -406,17 +406,10 @@ def do_POST(self):
406406
self.body = json.loads(raw_body.decode())
407407
except json.JSONDecodeError as e:
408408
logging.error(f"JSONDecodeError: {e} - Raw body: {raw_body.decode()}")
409-
# Set appropriate headers based on streaming requirement
410-
if self.stream:
411-
self._set_stream_headers(400)
412-
self.wfile.write(
413-
f"data: {json.dumps({'error': f'Invalid JSON in request body: {e}'})}\n\n".encode()
414-
)
415-
else:
416-
self._set_completion_headers(400)
417-
self.wfile.write(
418-
json.dumps({"error": f"Invalid JSON in request body: {e}"}).encode()
419-
)
409+
self._set_completion_headers(400)
410+
self.wfile.write(
411+
json.dumps({"error": f"Invalid JSON in request body: {e}"}).encode()
412+
)
420413
return
421414

422415
indent = "\t" # Backslashes can't be inside of f-strings
@@ -436,8 +429,12 @@ def do_POST(self):
436429
model_path = os.environ['MLX_MODEL_PATH']
437430
if not os.path.exists(model_path):
438431
raise Exception(f"MLX_MODEL_PATH={model_path} is not a path")
439-
self.requested_model = os.path.join(model_path, self.requested_model)
440-
self.requested_draft_model = os.path.join(model_path, self.requested_draft_model)
432+
433+
if self.requested_model != "default_model":
434+
self.requested_model = os.path.join(model_path, self.requested_model)
435+
436+
if self.requested_draft_model != "default_model":
437+
self.requested_draft_model = os.path.join(model_path, self.requested_draft_model)
441438

442439
self.num_draft_tokens = self.body.get(
443440
"num_draft_tokens", self.model_provider.cli_args.num_draft_tokens

tests/test_server.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import unittest
88

99
import requests
10+
import time
1011

1112
from mlx_lm.server import APIHandler
1213
from mlx_lm.utils import load
@@ -69,6 +70,55 @@ def tearDownClass(cls):
6970
cls.httpd.server_close()
7071
cls.server_thread.join()
7172

73+
def test_handle_chunked_request(self):
74+
url = f"http://localhost:{self.port}/v1/chat/completions"
75+
76+
post_data = {
77+
"model": "default_model",
78+
"prompt": "Once upon a time",
79+
"max_tokens": 10,
80+
"temperature": 0.0,
81+
"stream": False,
82+
"top_p": 1.0,
83+
}
84+
85+
# chunked request
86+
data_parts = [
87+
b'{"model": "default_model", "messages": [{"role": "user", "content": "Once',
88+
b' upon a times, Once upon ',
89+
b'a time"}], "temperature": 0.8, "max_tokens": 1024, "stream": false}',
90+
]
91+
92+
max_length = 0
93+
for part in data_parts:
94+
max_length += len(part)
95+
96+
def data_generator():
97+
for part in data_parts:
98+
yield part
99+
time.sleep(0.1)
100+
101+
try:
102+
response = requests.post(
103+
url,
104+
data=data_generator(),
105+
headers={
106+
"Transfer-Encoding": "chunked",
107+
"Content-Type": "application/json",
108+
},
109+
)
110+
self.assertEqual(response.status_code, 200)
111+
except requests.exceptions.RequestException:
112+
self.assertTrue(False, "Chunked request failed")
113+
114+
response_body = json.loads(response.text)
115+
self.assertIn("id", response_body)
116+
self.assertIn("choices", response_body)
117+
self.assertIn("usage", response_body)
118+
119+
# Check that tokens were generated
120+
self.assertTrue(response_body["usage"]["completion_tokens"] > 0)
121+
72122
def test_handle_completions(self):
73123
url = f"http://localhost:{self.port}/v1/completions"
74124

0 commit comments

Comments (0)