mlx_lm/chat.py: 65 changes (51 additions, 14 deletions)
@@ -7,7 +7,7 @@
from .generate import stream_generate
from .models.cache import make_prompt_cache
from .sample_utils import make_sampler
from .utils import load
from .utils import does_model_support_prompt_cache, load

DEFAULT_TEMP = 0.0
DEFAULT_TOP_P = 1.0
@@ -16,6 +16,9 @@
DEFAULT_SEED = None
DEFAULT_MAX_TOKENS = 256
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
DEFAULT_BLOCK_LENGTH = 32
DEFAULT_STEPS = 32
DEFAULT_THRESHOLD = 0.95


def setup_arg_parser():
@@ -79,6 +82,24 @@ def setup_arg_parser():
default=None,
help="System prompt to be used for the chat template",
)
parser.add_argument(
"--block-length",
type=int,
default=DEFAULT_BLOCK_LENGTH,
help="[Diffusion models only] Number of tokens per block",
)
parser.add_argument(
"--steps",
type=int,
default=DEFAULT_STEPS,
help="[Diffusion models only] Number of denoising iterations per block",
)
parser.add_argument(
"--threshold",
type=float,
default=DEFAULT_THRESHOLD,
help="[Diffusion models only] Confidence threshold for token acceptance",
)
return parser


@@ -97,36 +118,43 @@ def main():
},
)

use_cache = does_model_support_prompt_cache(model)

def print_help():
print("The command list:")
print("- 'q' to exit")
print("- 'r' to reset the chat")
print("- 'h' to display these commands")

def reset_conversation():
"""Reset conversation history and prompt cache."""
cache = make_prompt_cache(model, args.max_kv_size) if use_cache else None
msgs = []
if args.system_prompt is not None:
msgs.append({"role": "system", "content": args.system_prompt})
return cache, msgs

print(f"[INFO] Starting chat session with {args.model}.")
print_help()
prompt_cache = make_prompt_cache(model, args.max_kv_size)
prompt_cache, messages = reset_conversation()

while True:
query = input(">> ")
if query == "q":
break
if query == "r":
prompt_cache = make_prompt_cache(model, args.max_kv_size)
prompt_cache, messages = reset_conversation()
continue
if query == "h":
print_help()
continue
messages = []
if args.system_prompt is not None:
messages.append({"role": "system", "content": args.system_prompt})

messages.append({"role": "user", "content": query})
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
for response in stream_generate(
model,
tokenizer,
prompt,
max_tokens=args.max_tokens,
sampler=make_sampler(

gen_kwargs = {
"max_tokens": args.max_tokens,
"sampler": make_sampler(
args.temp,
args.top_p,
xtc_threshold=args.xtc_threshold,
@@ -135,11 +163,20 @@ def print_help():
tokenizer.encode("\n") + list(tokenizer.eos_token_ids)
),
),
prompt_cache=prompt_cache,
):
"prompt_cache": prompt_cache,
"block_length": args.block_length,
"steps": args.steps,
"threshold": args.threshold,
}

assistant_response = ""
for response in stream_generate(model, tokenizer, prompt, **gen_kwargs):
print(response.text, flush=True, end="")
assistant_response += response.text
print()

messages.append({"role": "assistant", "content": assistant_response})


if __name__ == "__main__":
print(
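For anyone trying the branch outside the interactive REPL, below is a minimal, non-interactive sketch of how the new diffusion-specific knobs flow into stream_generate, mirroring the gen_kwargs wiring added in this diff. The model id is a placeholder, and the snippet assumes the companion changes elsewhere in this PR (block_length/steps/threshold accepted by generate.py, and does_model_support_prompt_cache provided by utils.py), none of which are shown in this file's hunks.

# Sketch only: mirrors the gen_kwargs construction added in chat.py above.
# Assumes this PR's generate.py accepts block_length/steps/threshold and
# that utils.py provides does_model_support_prompt_cache.
from mlx_lm.generate import stream_generate
from mlx_lm.models.cache import make_prompt_cache
from mlx_lm.sample_utils import make_sampler
from mlx_lm.utils import does_model_support_prompt_cache, load

# Placeholder model id; substitute a diffusion checkpoint supported by the PR.
model, tokenizer = load("some-org/some-diffusion-model-4bit")

# Diffusion models may not support a KV prompt cache, so gate it the way chat.py now does.
prompt_cache = make_prompt_cache(model) if does_model_support_prompt_cache(model) else None

messages = [{"role": "user", "content": "Hello!"}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

gen_kwargs = {
    "max_tokens": 256,
    "sampler": make_sampler(0.0, 1.0),  # temp, top_p (chat.py defaults)
    "prompt_cache": prompt_cache,
    "block_length": 32,   # tokens decoded per block
    "steps": 32,          # denoising iterations per block
    "threshold": 0.95,    # confidence required to accept a token
}

for response in stream_generate(model, tokenizer, prompt, **gen_kwargs):
    print(response.text, end="", flush=True)
print()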