diff --git a/dflash/scripts/server.py b/dflash/scripts/server.py
index 177c7622..9e827a8e 100644
--- a/dflash/scripts/server.py
+++ b/dflash/scripts/server.py
@@ -16,6 +16,7 @@
 import argparse
 import json
 import os
+import re
 import struct
 import subprocess
 import sys
@@ -62,6 +63,8 @@ def resolve_draft(root: Path) -> Path:
     "Qwen3.5-27B": "Qwen/Qwen3.5-27B",
     "Qwen3.6-27B": "Qwen/Qwen3.6-27B",
 }
+THINK_OPEN_TAG = "<think>"
+THINK_CLOSE_TAG = "</think>"
 
 
 def _tokenizer_id_from_gguf(gguf_path: Path) -> str:
@@ -89,6 +92,74 @@ def _tokenizer_id_from_gguf(gguf_path: Path) -> str:
     return default
 
 
+def parse_reasoning(
+    text: str,
+    thinking_enabled: bool = True,
+    started_in_thinking: bool = False,
+) -> tuple[str, str | None]:
+    # Qwen chat templates can prefill `<think>\n` into the prompt, so the
+    # generated output contains only the reasoning body plus `</think>`.
+    parts = text.partition(THINK_OPEN_TAG)
+    saw_open_tag = bool(parts[1])
+    rest = parts[2] if saw_open_tag else parts[0]
+    if THINK_CLOSE_TAG not in rest:
+        if thinking_enabled and (started_in_thinking or saw_open_tag):
+            return "", (rest.strip() or None)
+        return rest.strip(), None
+    reasoning, _, content = rest.partition(THINK_CLOSE_TAG)
+    return content.strip(), (reasoning.strip() or None)
+
+
+def prompt_starts_in_thinking(prompt: str) -> bool:
+    return bool(re.search(r"<think>\s*$", prompt))
+
+
+def consume_stream_piece(window: str, mode: str, piece: str):
+    outputs = []
+    holdback = max(len(THINK_OPEN_TAG), len(THINK_CLOSE_TAG))
+    window += piece
+    while True:
+        if mode == "reasoning":
+            idx = window.find(THINK_CLOSE_TAG)
+            if idx != -1:
+                pre = window[:idx]
+                if pre:
+                    outputs.append(("reasoning_content", pre))
+                window = window[idx + len(THINK_CLOSE_TAG):]
+                mode = "content"
+                continue
+            if len(window) > holdback:
+                safe = window[:-holdback]
+                if safe:
+                    outputs.append(("reasoning_content", safe))
+                window = window[-holdback:]
+            break
+
+        idx = window.find(THINK_OPEN_TAG)
+        if idx != -1:
+            pre = window[:idx]
+            if pre:
+                outputs.append(("content", pre))
+            window = window[idx + len(THINK_OPEN_TAG):]
+            mode = "reasoning"
+            continue
+        if len(window) > holdback:
+            safe = window[:-holdback]
+            if safe:
+                outputs.append(("content", safe))
+            window = window[-holdback:]
+        break
+
+    return outputs, window, mode
+
+
+def flush_stream_deltas(window: str, mode: str):
+    if not window:
+        return []
+    kind = "reasoning_content" if mode == "reasoning" else "content"
+    return [(kind, window)]
+
+
 # FIX 2: _content_to_str helper used for BOTH OpenAI and Anthropic message
 # content fields (str | list[dict]). Previously OpenAI list[dict] content
 # was passed raw to the tokenizer and caused a crash.
@@ -286,12 +357,18 @@ def _render_messages(msgs_list: list[dict],
         ids = tokenizer.encode(prompt, add_special_tokens=False)
         return _ids_to_bin(ids), ids, prompt
 
+    def _thinking_enabled(kwargs: dict | None) -> bool:
+        if kwargs:
+            return kwargs.get("enable_thinking", True)
+        return True
+
     # FIX 2 applied: always call _content_to_str on message content
-    def _tokenize_prompt(req: ChatRequest) -> tuple[Path, list[int], list[dict]]:
+    def _tokenize_prompt(req: ChatRequest) -> tuple[Path, list[int], list[dict], bool]:
         msgs = [{"role": m.role, "content": _content_to_str(m.content)}
                 for m in req.messages]
-        path, ids, _prompt = _render_messages(msgs, req.chat_template_kwargs)
-        return path, ids, msgs
+        path, ids, prompt = _render_messages(msgs, req.chat_template_kwargs)
+        think = _thinking_enabled(req.chat_template_kwargs) and prompt_starts_in_thinking(prompt)
+        return path, ids, msgs, think
 
     def _maybe_compress(msgs: list[dict], prompt_bin: Path,
                         prompt_ids: list[int], template_kwargs: dict | None = None
@@ -423,7 +500,7 @@ def _gen_len_for(prompt_len: int, max_tokens: int) -> int:
 
     @app.post("/v1/chat/completions")
     async def chat_completions(req: ChatRequest):
-        prompt_bin, prompt_ids, raw_msgs = _tokenize_prompt(req)
+        prompt_bin, prompt_ids, raw_msgs, started_in_thinking = _tokenize_prompt(req)
         completion_id = "chatcmpl-" + uuid.uuid4().hex[:24]
         created = int(time.time())
@@ -487,15 +564,29 @@ async def sse() -> AsyncIterator[str]:
                              "finish_reason": None}],
             }
             yield f"data: {json.dumps(head)}\n\n"
+            window, mode = "", ("reasoning" if started_in_thinking else "content")
             try:
                 async for tok_id in _astream_tokens(r_pipe, gen_len):
+                    outputs, window, mode = consume_stream_piece(
+                        window, mode, tokenizer.decode([tok_id]))
+                    for kind, text in outputs:
+                        chunk = {
+                            "id": completion_id,
+                            "object": "chat.completion.chunk",
+                            "created": created, "model": MODEL_NAME,
+                            "choices": [{"index": 0,
+                                         "delta": {kind: text},
+                                         "finish_reason": None}],
+                        }
+                        yield f"data: {json.dumps(chunk)}\n\n"
+                for kind, text in flush_stream_deltas(window, mode):
                     chunk = {
                         "id": completion_id,
                         "object": "chat.completion.chunk",
                         "created": created, "model": MODEL_NAME,
                         "choices": [{"index": 0,
-                                     "delta": {"content": tokenizer.decode([tok_id])},
+                                     "delta": {kind: text},
                                      "finish_reason": None}],
                     }
                     yield f"data: {json.dumps(chunk)}\n\n"
@@ -584,6 +675,14 @@
             except Exception:
                 pass
         text = tokenizer.decode(tokens, skip_special_tokens=True)
+        cleaned, reasoning = parse_reasoning(
+            text,
+            thinking_enabled=_thinking_enabled(req.chat_template_kwargs),
+            started_in_thinking=started_in_thinking,
+        )
+        msg = {"role": "assistant", "content": cleaned}
+        if reasoning:
+            msg["reasoning_content"] = reasoning
         return JSONResponse({
             "id": completion_id,
             "object": "chat.completion",
@@ -591,7 +690,7 @@
             "model": MODEL_NAME,
             "choices": [{
                 "index": 0,
-                "message": {"role": "assistant", "content": text},
+                "message": msg,
                 "finish_reason": "stop",
             }],
             "usage": {"prompt_tokens": prompt_len,
@@ -602,19 +701,20 @@
     # ── Anthropic Messages API ──────────────────────────────────────────────
     def _tokenize_anthropic(req: AnthropicMessagesRequest
-                            ) -> tuple[Path, list[int], list[dict]]:
+                            ) -> tuple[Path, list[int], list[dict], bool]:
         msgs = []
         system_text = _content_to_str(req.system) if req.system else None
         if system_text:
             msgs.append({"role": "system", "content": system_text})
         for m in req.messages:
             msgs.append({"role": m.role, "content":
                          _content_to_str(m.content)})
-        path, ids, _prompt = _render_messages(msgs, req.chat_template_kwargs)
-        return path, ids, msgs
+        path, ids, prompt = _render_messages(msgs, req.chat_template_kwargs)
+        think = _thinking_enabled(req.chat_template_kwargs) and prompt_starts_in_thinking(prompt)
+        return path, ids, msgs, think
 
     @app.post("/v1/messages")
     async def anthropic_messages(req: AnthropicMessagesRequest):
-        prompt_bin, prompt_ids, raw_msgs = _tokenize_anthropic(req)
+        prompt_bin, prompt_ids, raw_msgs, started_in_thinking = _tokenize_anthropic(req)
         msg_id = "msg_" + uuid.uuid4().hex[:24]
 
         if req.stream:
@@ -668,7 +768,6 @@ async def sse() -> AsyncIterator[str]:
                 },
             }
             yield f"event: message_start\ndata: {json.dumps(message_start)}\n\n"
-            yield f"event: content_block_start\ndata: {json.dumps({'type': 'content_block_start', 'index': 0, 'content_block': {'type': 'text', 'text': ''}})}\n\n"
 
             try:
                 _write_cmd(cmd_line)
@@ -677,15 +776,42 @@
                     return
 
             out_tokens = 0
+            window, mode = "", ("reasoning" if started_in_thinking else "content")
+            block_index = 0
+            active_kind = "thinking" if mode == "reasoning" else "text"
+            block = {"type": active_kind}
+            if active_kind == "thinking":
+                block["thinking"] = ""
+            else:
+                block["text"] = ""
+            yield f"event: content_block_start\ndata: {json.dumps({'type': 'content_block_start', 'index': block_index, 'content_block': block})}\n\n"
             try:
                 async for tok_id in _astream_tokens(r_pipe, gen_len):
                     out_tokens += 1
-                    delta = {
-                        "type": "content_block_delta", "index": 0,
-                        "delta": {"type": "text_delta",
-                                  "text": tokenizer.decode([tok_id])},
-                    }
-                    yield f"event: content_block_delta\ndata: {json.dumps(delta)}\n\n"
+                    outputs, window, mode = consume_stream_piece(
+                        window, mode, tokenizer.decode([tok_id]))
+                    for kind, text in outputs:
+                        target_kind = "thinking" if kind == "reasoning_content" else "text"
+                        if target_kind != active_kind:
+                            yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': block_index})}\n\n"
+                            block_index += 1
+                            active_kind = target_kind
+                            new_block = {"type": active_kind, active_kind: ""}
+                            yield f"event: content_block_start\ndata: {json.dumps({'type': 'content_block_start', 'index': block_index, 'content_block': new_block})}\n\n"
+                        delta_type = "thinking_delta" if target_kind == "thinking" else "text_delta"
+                        delta_key = "thinking" if target_kind == "thinking" else "text"
+                        yield f"event: content_block_delta\ndata: {json.dumps({'type': 'content_block_delta', 'index': block_index, 'delta': {'type': delta_type, delta_key: text}})}\n\n"
+                for kind, text in flush_stream_deltas(window, mode):
+                    target_kind = "thinking" if kind == "reasoning_content" else "text"
+                    if target_kind != active_kind:
+                        yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': block_index})}\n\n"
+                        block_index += 1
+                        active_kind = target_kind
+                        new_block = {"type": active_kind, active_kind: ""}
+                        yield f"event: content_block_start\ndata: {json.dumps({'type': 'content_block_start', 'index': block_index, 'content_block': new_block})}\n\n"
+                    delta_type = "thinking_delta" if target_kind == "thinking" else "text_delta"
+                    delta_key = "thinking" if target_kind == "thinking" else "text"
+                    yield f"event: content_block_delta\ndata: {json.dumps({'type': 'content_block_delta', 'index': block_index, 'delta': {'type': delta_type, delta_key: text}})}\n\n"
             finally:
                 if full_hit is None:
                     try: cur_bin.unlink()
@@ -701,7 +827,7 @@
                 elif snap_prep:
                     prefix_cache.confirm_inline_snap(*snap_prep, cur_ids)
 
-            yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': 0})}\n\n"
+            yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': block_index})}\n\n"
             msg_delta = {
                 "type": "message_delta",
                 "delta": {"stop_reason": "end_turn", "stop_sequence": None},
@@ -774,12 +900,20 @@
             except Exception:
                 pass
         text = tokenizer.decode(tokens, skip_special_tokens=True)
+        cleaned, reasoning = parse_reasoning(
+            text,
+            thinking_enabled=_thinking_enabled(req.chat_template_kwargs),
+            started_in_thinking=started_in_thinking,
+        )
+        content = [{"type": "text", "text": cleaned}]
+        if reasoning:
+            content.insert(0, {"type": "thinking", "thinking": reasoning})
         return JSONResponse({
             "id": msg_id,
             "type": "message",
             "role": "assistant",
             "model": req.model or MODEL_NAME,
-            "content": [{"type": "text", "text": text}],
+            "content": content,
             "stop_reason": "end_turn",
             "stop_sequence": None,
             "usage": {"input_tokens": prompt_len,
diff --git a/dflash/scripts/test_server.py b/dflash/scripts/test_server.py
index a9926806..5592cf96 100644
--- a/dflash/scripts/test_server.py
+++ b/dflash/scripts/test_server.py
@@ -8,7 +8,10 @@
 import pytest
 
 from fastapi.testclient import TestClient
 
-from server import build_app, MODEL_NAME
+from server import (
+    build_app, MODEL_NAME, parse_reasoning,
+    consume_stream_piece, flush_stream_deltas,
+)
 
 @pytest.fixture
@@ -16,8 +19,108 @@ def mock_tokenizer():
     tokenizer = MagicMock()
     tokenizer.encode.return_value = [1]
     tokenizer.decode.return_value = "hello"
+    tokenizer.apply_chat_template.return_value = "prompt"
     return tokenizer
+
+
+def test_parse_reasoning_headless_think():
+    cleaned, reasoning = parse_reasoning("private chain of thought</think>\n\nvisible answer")
+    assert cleaned == "visible answer"
+    assert reasoning == "private chain of thought"
+
+
+def test_parse_reasoning_full_think_tags():
+    cleaned, reasoning = parse_reasoning("<think>my reasoning</think>\n\nthe answer")
+    assert cleaned == "the answer"
+    assert reasoning == "my reasoning"
+
+
+def test_parse_reasoning_plain_content_when_no_think_segment_present():
+    cleaned, reasoning = parse_reasoning("visible answer only")
+    assert cleaned == "visible answer only"
+    assert reasoning is None
+
+
+def test_parse_reasoning_truncated_when_prompt_started_in_thinking():
+    cleaned, reasoning = parse_reasoning(
+        "unfinished private chain of thought",
+        started_in_thinking=True,
+    )
+    assert cleaned == ""
+    assert reasoning == "unfinished private chain of thought"
+
+
+# -- consume_stream_piece / flush_stream_deltas -------------------------
+
+def test_consume_stream_piece_reasoning_to_content():
+    """Full transition: reasoning tokens, close tag, content tokens."""
+    window, mode = "", "reasoning"
+    assert mode == "reasoning"
+
+    all_outputs = []
+
+    # Feed reasoning text
+    outputs, window, mode = consume_stream_piece(window, mode, "deep thought")
+    all_outputs.extend(outputs)
+    assert mode == "reasoning"
+
+    # Feed close tag
+    outputs, window, mode = consume_stream_piece(window, mode, "</think>")
+    all_outputs.extend(outputs)
+    reasoning_parts = [t for k, t in all_outputs if k == "reasoning_content"]
+    assert "deep thought" in "".join(reasoning_parts)
+    assert mode == "content"
+
+    # Feed content
+    outputs, window, mode = consume_stream_piece(window, mode, "visible answer")
+    all_outputs.extend(outputs)
+    assert mode == "content"
+
+    # Flush remaining
+    flushed = flush_stream_deltas(window, mode)
+    all_content = ([t for k, t in all_outputs if k == "content"]
+                   + [t for k, t in flushed if k == "content"])
+    assert "visible answer" in "".join(all_content)
+
+
+def test_consume_stream_piece_tag_split_across_pieces():
+    """The </think> tag arrives split across two pieces."""
+    window, mode = "", "reasoning"
+    all_outputs = []
+
+    outputs, window, mode = consume_stream_piece(window, mode, "thought</th")
+    all_outputs.extend(outputs)
+    assert mode == "reasoning"
+
+    outputs, window, mode = consume_stream_piece(window, mode, "ink>answer")
+    all_outputs.extend(outputs)
+    # Now the tag is complete, should have transitioned
+    assert mode == "content"
+    # Collect everything emitted across both calls
+    all_reasoning = [t for k, t in all_outputs if k == "reasoning_content"]
+    all_content = [t for k, t in all_outputs if k == "content"]
+    flushed = flush_stream_deltas(window, mode)
+    all_content += [t for k, t in flushed if k == "content"]
+    assert "thought" in "".join(all_reasoning)
+    assert "answer" in "".join(all_content)
+
+
+def test_consume_stream_piece_content_mode_no_tags():
+    """Plain content with no think tags passes through."""
+    window, mode = "", "content"
+    assert mode == "content"
+
+    outputs, window, mode = consume_stream_piece(window, mode, "hello world")
+    assert mode == "content"
+    flushed = flush_stream_deltas(window, mode)
+    all_text = [t for k, t in outputs if k == "content"] + [t for k, t in flushed if k == "content"]
+    assert "hello world" in "".join(all_text)
+
+
+def test_flush_empty_window():
+    assert flush_stream_deltas("", "content") == []
+    assert flush_stream_deltas("", "reasoning") == []
+
 
 @patch("server.subprocess.Popen")
 def test_models_endpoint(mock_popen, mock_tokenizer):
     app = build_app(
@@ -42,6 +145,8 @@ def test_models_endpoint(mock_popen, mock_tokenizer):
 @patch("server.os.pipe")
 @patch("server.subprocess.Popen")
 @patch("server.os.read")
 def test_chat_completions_non_streaming(mock_os_read, mock_popen, mock_pipe, mock_tokenizer):
     mock_pipe.return_value = (1, 2)
+    mock_popen.return_value.poll.return_value = None  # daemon alive
+    mock_tokenizer.decode.return_value = "private chain of thought</think>\n\nvisible answer"
 
     app = build_app(
         target=Path("target.gguf"),
@@ -69,13 +174,21 @@ def test_chat_completions_non_streaming(mock_os_read, mock_popen, mock_pipe, moc
     assert response.status_code == 200
     data = response.json()
     assert data["object"] == "chat.completion"
-    assert data["choices"][0]["message"]["content"] == "hello"
+    assert data["choices"][0]["message"]["content"] == "visible answer"
+    assert data["choices"][0]["message"]["reasoning_content"] == "private chain of thought"
 
 
 @patch("server.os.pipe")
 @patch("server.subprocess.Popen")
 @patch("server.os.read")
 def test_chat_completions_streaming(mock_os_read, mock_popen, mock_pipe, mock_tokenizer):
     mock_pipe.return_value = (1, 2)
+    mock_popen.return_value.poll.return_value = None  # daemon alive
+    mock_tokenizer.apply_chat_template.return_value = "<think>\n"
+    mock_tokenizer.decode.side_effect = [
+        "private thought",
+        "</think>",
+        "visible answer",
+    ]
 
     app = build_app(
         target=Path("target.gguf"),
@@ -89,6 +202,8 @@ def test_chat_completions_streaming(mock_os_read, mock_popen, mock_pipe, mock_to
     mock_os_read.side_effect = [
         struct.pack("<
[...]
-    assert len(lines) >= 3
-    assert lines[-1] == "data: [DONE]"
\ No newline at end of file
+    assert '"reasoning_content"' in response.text
+    assert "</think>" not in response.text
+    assert "data: [DONE]" in response.text
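
For reviewers: a minimal standalone sketch (not part of the diff) of the streaming contract the new helpers implement. The holdback window means a `</think>` tag split across token boundaries never leaks into the emitted deltas. It assumes `dflash/scripts` is on `PYTHONPATH` so `server` is importable; the token pieces below are invented for the demo.

# demo_stream_parse.py -- illustrative sketch only, not part of the diff.
# Assumes dflash/scripts is on PYTHONPATH; the token pieces are invented.
from server import consume_stream_piece, flush_stream_deltas

# A reply whose </think> close tag is split across two token boundaries.
pieces = ["deep ", "thought</th", "ink>visible ", "answer"]

window, mode = "", "reasoning"  # prompt was prefilled with <think>, so start in reasoning
deltas = []
for piece in pieces:
    outputs, window, mode = consume_stream_piece(window, mode, piece)
    deltas.extend(outputs)        # (kind, text) pairs; tag text never appears
deltas.extend(flush_stream_deltas(window, mode))  # drain the holdback window

assert "".join(t for k, t in deltas if k == "reasoning_content") == "deep thought"
assert "".join(t for k, t in deltas if k == "content") == "visible answer"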