
Commit b60cec8

Add system prompt to chat script (#334)

* Added a system prompt option to the chat script. When testing locally
  fine-tuned models, being able to set a system prompt makes evaluation much
  easier. The generate script already has the same feature.
* keep linter gods happy

1 parent b26c608 · commit b60cec8

File tree (2 files changed: +176 −1 lines changed)

mlx_lm/chat.py
tests/test_chat.py

mlx_lm/chat.py

Lines changed: 9 additions & 1 deletion
@@ -69,6 +69,11 @@ def setup_arg_parser():
         default=DEFAULT_MAX_TOKENS,
         help="Maximum number of tokens to generate",
     )
+    parser.add_argument(
+        "--system-prompt",
+        default=None,
+        help="System prompt to be used for the chat template",
+    )
     return parser


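For readers less familiar with argparse, a standalone sketch (not part of the commit) of why this dashed flag surfaces later as args.system_prompt: argparse converts dashes in long option names to underscores, and the None default means the attribute stays unset unless the flag is passed.

# Minimal reproduction of the new option in isolation.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--system-prompt",
    default=None,
    help="System prompt to be used for the chat template",
)

print(parser.parse_args([]).system_prompt)  # None
print(parser.parse_args(["--system-prompt", "Be terse."]).system_prompt)  # Be terse.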
@@ -104,7 +109,10 @@ def print_help():
         if query == "h":
             print_help()
             continue
-        messages = [{"role": "user", "content": query}]
+        messages = []
+        if args.system_prompt is not None:
+            messages.append({"role": "system", "content": args.system_prompt})
+        messages.append({"role": "user", "content": query})
         prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
         for response in stream_generate(
             model,
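Taken together, the new loop body can be sketched as a small pure function (a sketch, not code from the commit; build_messages is a hypothetical name standing in for the inline statements above):

# Hypothetical helper mirroring the patched loop body.
def build_messages(system_prompt, query):
    messages = []
    if system_prompt is not None:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": query})
    return messages

print(build_messages(None, "hi"))
# [{'role': 'user', 'content': 'hi'}]

print(build_messages("You are a weather assistant.", "What is the weather?"))
# [{'role': 'system', 'content': 'You are a weather assistant.'},
#  {'role': 'user', 'content': 'What is the weather?'}]

With the flag unset, the list is identical to the old single-message behavior, so existing chat sessions render exactly as before; either way, the list is then passed to tokenizer.apply_chat_template(messages, add_generation_prompt=True).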

tests/test_chat.py

Lines changed: 167 additions & 0 deletions
@@ -0,0 +1,167 @@
+import argparse
+import unittest
+from unittest.mock import MagicMock, patch
+
+from mlx_lm.chat import setup_arg_parser
+
+
+class TestChat(unittest.TestCase):
+
+    def test_setup_arg_parser_system_prompt(self):
+        parser = setup_arg_parser()
+
+        # Test default (no system prompt)
+        args = parser.parse_args([])
+        self.assertIsNone(args.system_prompt)
+
+        # Test with system prompt
+        args = parser.parse_args(["--system-prompt", "You are a helpful assistant."])
+        self.assertEqual(args.system_prompt, "You are a helpful assistant.")
+
+    def test_setup_arg_parser_all_args(self):
+        parser = setup_arg_parser()
+        args = parser.parse_args(
+            [
+                "--model",
+                "test-model",
+                "--adapter-path",
+                "/path/to/adapter",
+                "--temp",
+                "0.7",
+                "--top-p",
+                "0.9",
+                "--xtc-probability",
+                "0.1",
+                "--xtc-threshold",
+                "0.1",
+                "--seed",
+                "42",
+                "--max-kv-size",
+                "1024",
+                "--max-tokens",
+                "512",
+                "--system-prompt",
+                "You are a helpful assistant.",
+            ]
+        )
+
+        self.assertEqual(args.model, "test-model")
+        self.assertEqual(args.adapter_path, "/path/to/adapter")
+        self.assertEqual(args.temp, 0.7)
+        self.assertEqual(args.top_p, 0.9)
+        self.assertEqual(args.xtc_probability, 0.1)
+        self.assertEqual(args.xtc_threshold, 0.1)
+        self.assertEqual(args.seed, 42)
+        self.assertEqual(args.max_kv_size, 1024)
+        self.assertEqual(args.max_tokens, 512)
+        self.assertEqual(args.system_prompt, "You are a helpful assistant.")
+
+    @patch("mlx_lm.chat.load")
+    @patch("mlx_lm.chat.make_prompt_cache")
+    @patch("mlx_lm.chat.stream_generate")
+    @patch("builtins.input")
+    @patch("builtins.print")
+    def test_system_prompt_in_messages(
+        self,
+        mock_print,
+        mock_input,
+        mock_stream_generate,
+        mock_make_prompt_cache,
+        mock_load,
+    ):
+        from mlx_lm.chat import main
+
+        # Mock the model and tokenizer
+        mock_model = MagicMock()
+        mock_tokenizer = MagicMock()
+        mock_tokenizer.apply_chat_template.return_value = "processed_prompt"
+        mock_load.return_value = (mock_model, mock_tokenizer)
+
+        # Mock prompt cache
+        mock_prompt_cache = MagicMock()
+        mock_make_prompt_cache.return_value = mock_prompt_cache
+
+        # Mock stream_generate to return some responses
+        mock_response = MagicMock()
+        mock_response.text = "Hello there!"
+        mock_stream_generate.return_value = [mock_response]
+
+        # Mock user input: first a question, then 'q' to quit
+        mock_input.side_effect = ["What is the weather?", "q"]
+
+        # Test with system prompt
+        with patch(
+            "sys.argv", ["chat.py", "--system-prompt", "You are a weather assistant."]
+        ):
+            try:
+                main()
+            except SystemExit:
+                pass
+
+        # Verify that apply_chat_template was called with system prompt
+        mock_tokenizer.apply_chat_template.assert_called()
+        call_args = mock_tokenizer.apply_chat_template.call_args[0][
+            0
+        ]  # First positional arg (messages)
+
+        # Check that the messages contain both system and user messages
+        self.assertEqual(len(call_args), 2)
+        self.assertEqual(call_args[0]["role"], "system")
+        self.assertEqual(call_args[0]["content"], "You are a weather assistant.")
+        self.assertEqual(call_args[1]["role"], "user")
+        self.assertEqual(call_args[1]["content"], "What is the weather?")
+
+    @patch("mlx_lm.chat.load")
+    @patch("mlx_lm.chat.make_prompt_cache")
+    @patch("mlx_lm.chat.stream_generate")
+    @patch("builtins.input")
+    @patch("builtins.print")
+    def test_no_system_prompt_in_messages(
+        self,
+        mock_print,
+        mock_input,
+        mock_stream_generate,
+        mock_make_prompt_cache,
+        mock_load,
+    ):
+        from mlx_lm.chat import main
+
+        # Mock the model and tokenizer
+        mock_model = MagicMock()
+        mock_tokenizer = MagicMock()
+        mock_tokenizer.apply_chat_template.return_value = "processed_prompt"
+        mock_load.return_value = (mock_model, mock_tokenizer)
+
+        # Mock prompt cache
+        mock_prompt_cache = MagicMock()
+        mock_make_prompt_cache.return_value = mock_prompt_cache
+
+        # Mock stream_generate to return some responses
+        mock_response = MagicMock()
+        mock_response.text = "Hello there!"
+        mock_stream_generate.return_value = [mock_response]
+
+        # Mock user input: first a question, then 'q' to quit
+        mock_input.side_effect = ["What is the weather?", "q"]
+
+        # Test without system prompt
+        with patch("sys.argv", ["chat.py"]):
+            try:
+                main()
+            except SystemExit:
+                pass
+
+        # Verify that apply_chat_template was called without system prompt
+        mock_tokenizer.apply_chat_template.assert_called()
+        call_args = mock_tokenizer.apply_chat_template.call_args[0][
+            0
+        ]  # First positional arg (messages)
+
+        # Check that the messages contain only user message
+        self.assertEqual(len(call_args), 1)
+        self.assertEqual(call_args[0]["role"], "user")
+        self.assertEqual(call_args[0]["content"], "What is the weather?")
+
+
+if __name__ == "__main__":
+    unittest.main()
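The two main() tests drive the interactive loop with no real model: builtins.input is patched with a side_effect list, so each input() call consumes the next scripted turn, while load, make_prompt_cache, and stream_generate are replaced by mocks. A toy illustration of that input-patching pattern (tiny_repl is a stand-in, not the real main):

# Toy illustration of the pattern used above; tiny_repl stands in
# for mlx_lm.chat.main's read-eval loop.
from unittest.mock import patch

def tiny_repl():
    while (query := input(">> ")) != "q":
        print(f"echo: {query}")

with patch("builtins.input", side_effect=["What is the weather?", "q"]):
    tiny_repl()  # prints "echo: What is the weather?", then exits on "q"

The suite should run from the repository root with a standard unittest invocation, e.g. python -m unittest tests.test_chat.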
