Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 153 additions & 19 deletions dflash/scripts/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import argparse
import json
import os
import re
import struct
import subprocess
import sys
Expand Down Expand Up @@ -62,6 +63,8 @@ def resolve_draft(root: Path) -> Path:
"Qwen3.5-27B": "Qwen/Qwen3.5-27B",
"Qwen3.6-27B": "Qwen/Qwen3.6-27B",
}
THINK_OPEN_TAG = "<think>"
THINK_CLOSE_TAG = "</think>"


def _tokenizer_id_from_gguf(gguf_path: Path) -> str:
Expand Down Expand Up @@ -89,6 +92,74 @@ def _tokenizer_id_from_gguf(gguf_path: Path) -> str:
return default


def parse_reasoning(
    text: str,
    thinking_enabled: bool = True,
    started_in_thinking: bool = False,
) -> tuple[str, str | None]:
    """Split a finished generation into ``(content, reasoning)``.

    Qwen chat templates may prefill an opening ``<think>`` tag into the
    prompt, so the generated text can contain only the reasoning body
    followed by ``</think>`` with no opening tag of its own.
    ``started_in_thinking`` tells us that case applies.
    """
    _, open_tag, after_open = text.partition(THINK_OPEN_TAG)
    body = after_open if open_tag else text
    if THINK_CLOSE_TAG in body:
        reasoning, _, content = body.partition(THINK_CLOSE_TAG)
        return content.strip(), (reasoning.strip() or None)
    # No close tag: if we know we are inside a thinking block, the whole
    # remainder is reasoning; otherwise treat it all as plain content.
    if thinking_enabled and (started_in_thinking or bool(open_tag)):
        return "", (body.strip() or None)
    return body.strip(), None


def prompt_starts_in_thinking(prompt: str) -> bool:
    """Return True when the rendered prompt ends with ``<think>`` followed
    only by optional trailing whitespace — i.e. the chat template prefilled
    the reasoning block and generation begins inside it."""
    return prompt.rstrip().endswith("<think>")


def consume_stream_piece(window: str, mode: str, piece: str):
    """Advance the incremental ``<think>``-tag splitter by one decoded piece.

    Returns ``(deltas, window, mode)`` where *deltas* is a list of
    ``(field, text)`` pairs with field ``"reasoning_content"`` or
    ``"content"``, *window* is the unemitted tail (kept back because it may
    end in a partial tag split across token boundaries), and *mode* is the
    updated parser state (``"reasoning"`` or ``"content"``).
    """
    deltas = []
    window += piece
    # Never emit the last `reserve` characters: they could be the start of
    # a tag whose remainder arrives in a later piece.
    reserve = max(len(THINK_OPEN_TAG), len(THINK_CLOSE_TAG))
    while True:
        in_reasoning = mode == "reasoning"
        tag = THINK_CLOSE_TAG if in_reasoning else THINK_OPEN_TAG
        field = "reasoning_content" if in_reasoning else "content"
        cut = window.find(tag)
        if cut == -1:
            if len(window) > reserve:
                deltas.append((field, window[:-reserve]))
                window = window[-reserve:]
            return deltas, window, mode
        if cut:
            deltas.append((field, window[:cut]))
        window = window[cut + len(tag):]
        mode = "content" if in_reasoning else "reasoning"


def flush_stream_deltas(window: str, mode: str):
    """Emit whatever text is still held back in *window* at end-of-stream,
    tagged with the field matching the current parser *mode*."""
    if not window:
        return []
    field = "reasoning_content" if mode == "reasoning" else "content"
    return [(field, window)]


# FIX 2: _content_to_str helper used for BOTH OpenAI and Anthropic message
# content fields (str | list[dict]). Previously OpenAI list[dict] content
# was passed raw to the tokenizer and caused a crash.
Expand Down Expand Up @@ -286,12 +357,18 @@ def _render_messages(msgs_list: list[dict],
ids = tokenizer.encode(prompt, add_special_tokens=False)
return _ids_to_bin(ids), ids, prompt

def _thinking_enabled(kwargs: dict | None) -> bool:
if kwargs:
return kwargs.get("enable_thinking", True)
return True

# FIX 2 applied: always call _content_to_str on message content
def _tokenize_prompt(req: ChatRequest) -> tuple[Path, list[int], list[dict]]:
def _tokenize_prompt(req: ChatRequest) -> tuple[Path, list[int], list[dict], bool]:
    """Render an OpenAI-style chat request.

    Returns ``(prompt_bin_path, token_ids, normalized_messages,
    started_in_thinking)``. Message content is always normalized via
    _content_to_str so list[dict] content never reaches the tokenizer raw.
    """
    msgs = [{"role": m.role, "content": _content_to_str(m.content)}
            for m in req.messages]
    path, ids, prompt = _render_messages(msgs, req.chat_template_kwargs)
    # The template may prefill `<think>`; record it so streaming starts in
    # reasoning mode even though the output never repeats the open tag.
    think = (_thinking_enabled(req.chat_template_kwargs)
             and prompt_starts_in_thinking(prompt))
    return path, ids, msgs, think

def _maybe_compress(msgs: list[dict], prompt_bin: Path, prompt_ids: list[int],
template_kwargs: dict | None = None
Expand Down Expand Up @@ -423,7 +500,7 @@ def _gen_len_for(prompt_len: int, max_tokens: int) -> int:

@app.post("/v1/chat/completions")
async def chat_completions(req: ChatRequest):
prompt_bin, prompt_ids, raw_msgs = _tokenize_prompt(req)
prompt_bin, prompt_ids, raw_msgs, started_in_thinking = _tokenize_prompt(req)
completion_id = "chatcmpl-" + uuid.uuid4().hex[:24]
created = int(time.time())

Expand Down Expand Up @@ -487,15 +564,29 @@ async def sse() -> AsyncIterator[str]:
"finish_reason": None}],
}
yield f"data: {json.dumps(head)}\n\n"
window, mode = "", ("reasoning" if started_in_thinking else "content")

try:
async for tok_id in _astream_tokens(r_pipe, gen_len):
outputs, window, mode = consume_stream_piece(
window, mode, tokenizer.decode([tok_id]))
for kind, text in outputs:
chunk = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created, "model": MODEL_NAME,
"choices": [{"index": 0,
"delta": {kind: text},
"finish_reason": None}],
}
yield f"data: {json.dumps(chunk)}\n\n"
for kind, text in flush_stream_deltas(window, mode):
chunk = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created, "model": MODEL_NAME,
"choices": [{"index": 0,
"delta": {"content": tokenizer.decode([tok_id])},
"delta": {kind: text},
"finish_reason": None}],
}
yield f"data: {json.dumps(chunk)}\n\n"
Expand Down Expand Up @@ -584,14 +675,22 @@ async def sse() -> AsyncIterator[str]:
except Exception: pass

text = tokenizer.decode(tokens, skip_special_tokens=True)
cleaned, reasoning = parse_reasoning(
text,
thinking_enabled=_thinking_enabled(req.chat_template_kwargs),
started_in_thinking=started_in_thinking,
)
msg = {"role": "assistant", "content": cleaned}
if reasoning:
msg["reasoning_content"] = reasoning
return JSONResponse({
"id": completion_id,
"object": "chat.completion",
"created": created,
"model": MODEL_NAME,
"choices": [{
"index": 0,
"message": {"role": "assistant", "content": text},
"message": msg,
"finish_reason": "stop",
}],
"usage": {"prompt_tokens": prompt_len,
Expand All @@ -602,19 +701,20 @@ async def sse() -> AsyncIterator[str]:
# ── Anthropic Messages API ──────────────────────────────────────────────

def _tokenize_anthropic(req: AnthropicMessagesRequest
                        ) -> tuple[Path, list[int], list[dict], bool]:
    """Render an Anthropic Messages request.

    Returns ``(prompt_bin_path, token_ids, normalized_messages,
    started_in_thinking)``. The optional system prompt becomes a leading
    system message; all content is normalized via _content_to_str.
    """
    msgs = []
    system_text = _content_to_str(req.system) if req.system else None
    if system_text:
        msgs.append({"role": "system", "content": system_text})
    for m in req.messages:
        msgs.append({"role": m.role, "content": _content_to_str(m.content)})
    path, ids, prompt = _render_messages(msgs, req.chat_template_kwargs)
    # Detect a template-prefilled `<think>` so downstream parsing knows the
    # generation begins inside a reasoning block.
    think = (_thinking_enabled(req.chat_template_kwargs)
             and prompt_starts_in_thinking(prompt))
    return path, ids, msgs, think

@app.post("/v1/messages")
async def anthropic_messages(req: AnthropicMessagesRequest):
prompt_bin, prompt_ids, raw_msgs = _tokenize_anthropic(req)
prompt_bin, prompt_ids, raw_msgs, started_in_thinking = _tokenize_anthropic(req)
msg_id = "msg_" + uuid.uuid4().hex[:24]

if req.stream:
Expand Down Expand Up @@ -668,7 +768,6 @@ async def sse() -> AsyncIterator[str]:
},
}
yield f"event: message_start\ndata: {json.dumps(message_start)}\n\n"
yield f"event: content_block_start\ndata: {json.dumps({'type': 'content_block_start', 'index': 0, 'content_block': {'type': 'text', 'text': ''}})}\n\n"

try:
_write_cmd(cmd_line)
Expand All @@ -677,15 +776,42 @@ async def sse() -> AsyncIterator[str]:
return

out_tokens = 0
window, mode = "", ("reasoning" if started_in_thinking else "content")
block_index = 0
active_kind = "thinking" if mode == "reasoning" else "text"
block = {"type": active_kind}
if active_kind == "thinking":
block["thinking"] = ""
else:
block["text"] = ""
yield f"event: content_block_start\ndata: {json.dumps({'type': 'content_block_start', 'index': block_index, 'content_block': block})}\n\n"
try:
async for tok_id in _astream_tokens(r_pipe, gen_len):
out_tokens += 1
delta = {
"type": "content_block_delta", "index": 0,
"delta": {"type": "text_delta",
"text": tokenizer.decode([tok_id])},
}
yield f"event: content_block_delta\ndata: {json.dumps(delta)}\n\n"
outputs, window, mode = consume_stream_piece(
window, mode, tokenizer.decode([tok_id]))
for kind, text in outputs:
target_kind = "thinking" if kind == "reasoning_content" else "text"
if target_kind != active_kind:
yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': block_index})}\n\n"
block_index += 1
active_kind = target_kind
new_block = {"type": active_kind, active_kind: ""}
yield f"event: content_block_start\ndata: {json.dumps({'type': 'content_block_start', 'index': block_index, 'content_block': new_block})}\n\n"
delta_type = "thinking_delta" if target_kind == "thinking" else "text_delta"
delta_key = "thinking" if target_kind == "thinking" else "text"
yield f"event: content_block_delta\ndata: {json.dumps({'type': 'content_block_delta', 'index': block_index, 'delta': {'type': delta_type, delta_key: text}})}\n\n"
for kind, text in flush_stream_deltas(window, mode):
target_kind = "thinking" if kind == "reasoning_content" else "text"
if target_kind != active_kind:
yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': block_index})}\n\n"
block_index += 1
active_kind = target_kind
new_block = {"type": active_kind, active_kind: ""}
yield f"event: content_block_start\ndata: {json.dumps({'type': 'content_block_start', 'index': block_index, 'content_block': new_block})}\n\n"
delta_type = "thinking_delta" if target_kind == "thinking" else "text_delta"
delta_key = "thinking" if target_kind == "thinking" else "text"
yield f"event: content_block_delta\ndata: {json.dumps({'type': 'content_block_delta', 'index': block_index, 'delta': {'type': delta_type, delta_key: text}})}\n\n"
finally:
if full_hit is None:
try: cur_bin.unlink()
Expand All @@ -701,7 +827,7 @@ async def sse() -> AsyncIterator[str]:
elif snap_prep:
prefix_cache.confirm_inline_snap(*snap_prep, cur_ids)

yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': 0})}\n\n"
yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': block_index})}\n\n"
msg_delta = {
"type": "message_delta",
"delta": {"stop_reason": "end_turn", "stop_sequence": None},
Expand Down Expand Up @@ -774,12 +900,20 @@ async def sse() -> AsyncIterator[str]:
except Exception: pass

text = tokenizer.decode(tokens, skip_special_tokens=True)
cleaned, reasoning = parse_reasoning(
text,
thinking_enabled=_thinking_enabled(req.chat_template_kwargs),
started_in_thinking=started_in_thinking,
)
content = [{"type": "text", "text": cleaned}]
if reasoning:
content.insert(0, {"type": "thinking", "thinking": reasoning})
return JSONResponse({
"id": msg_id,
"type": "message",
"role": "assistant",
"model": req.model or MODEL_NAME,
"content": [{"type": "text", "text": text}],
"content": content,
"stop_reason": "end_turn",
"stop_sequence": None,
"usage": {"input_tokens": prompt_len,
Expand Down
Loading