Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions renderers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from renderers.qwen3_vl import Qwen3VLRenderer
from renderers.qwen35 import Qwen35Renderer
from renderers.qwen36 import Qwen36Renderer
from renderers.zaya1 import Zaya1Renderer

__all__ = [
"Content",
Expand Down Expand Up @@ -87,6 +88,7 @@
"ToolCallParseStatus",
"ToolSpec",
"VideoPart",
"Zaya1Renderer",
"__version__",
"build_training_sample",
"build_trajectory_step",
Expand Down
6 changes: 5 additions & 1 deletion renderers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,8 @@ def bridge_to_next_turn(self, *args: Any, **kwargs: Any) -> "RenderedTokens | No
# GPT-OSS.
"openai/gpt-oss-20b": "gpt-oss",
"openai/gpt-oss-120b": "gpt-oss",
# ZAYA1.
"Zyphra/ZAYA1-8B": "zaya1",
}


Expand Down Expand Up @@ -1026,6 +1028,7 @@ def _populate_registry():
from renderers.qwen3_vl import Qwen3VLRenderer
from renderers.qwen35 import Qwen35Renderer
from renderers.qwen36 import Qwen36Renderer
from renderers.zaya1 import Zaya1Renderer

RENDERER_REGISTRY.update(
{
Expand All @@ -1044,6 +1047,7 @@ def _populate_registry():
"laguna-xs.2": LagunaXS2Renderer,
"nemotron-3": Nemotron3Renderer,
"gpt-oss": GptOssRenderer,
"zaya1": Zaya1Renderer,
}
)

Expand Down Expand Up @@ -1107,7 +1111,7 @@ def create_renderer(
renderer: Renderer name ('qwen3', 'qwen3-vl', 'qwen3.5', 'qwen3.6',
'glm-5', 'glm-5.1', 'glm-4.5', 'minimax-m2', 'deepseek-v3',
'kimi-k2', 'kimi-k2.5', 'laguna-xs.2', 'nemotron-3',
'gpt-oss', 'default') or 'auto' to detect from model name.
'gpt-oss', 'zaya1', 'default') or 'auto' to detect from model name.
tool_parser: Name of a tool parser registered in ``renderers.parsers``.
Only consumed by DefaultRenderer. Model-specific renderers
have their own parsing wired in.
Expand Down
75 changes: 75 additions & 0 deletions renderers/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,81 @@ def _decode(tokenizer, ids: list[int]) -> str:
return tokenizer.decode(ids, skip_special_tokens=False)


# ── ZAYA1: <zyphra_tool_call> <function=name> ... ────────────────


def parse_zaya1(
    tokenizer,
    token_ids: list[int],
    *,
    stop_ids: set[int],
    tools: list[ToolSpec] | None = None,
) -> ParsedResponse:
    """Turn a ZAYA1 completion into reasoning, content and tool calls.

    The completion is decoded to text a single time (after stop tokens are
    stripped) and then scanned for the plain-text tags the template emits:
    an optional ``<think>...</think>`` prefix followed by any number of
    ``<zyphra_tool_call><function=NAME>...</function></zyphra_tool_call>``
    blocks, each holding ``<parameter=ARG>value</parameter>`` entries.

    Args:
        tokenizer: Tokenizer handed to the shared ``_decode`` helper.
        token_ids: Completion token ids.
        stop_ids: Ids passed to ``_strip_stop_tokens`` before decoding.
        tools: Optional tool specs; used to build the per-function parameter
            index that is passed to ``_coerce_arg_value``.

    Returns:
        A ``ParsedResponse`` whose ``content`` is the text left once reasoning
        and tool-call blocks are removed, whose ``reasoning_content`` is the
        text before ``</think>`` (``None`` when absent or empty), and whose
        ``tool_calls`` lists every complete tool-call block in order.
    """
    import re

    decoded = _decode(tokenizer, _strip_stop_tokens(token_ids, stop_ids))

    # Everything before the first "</think>" is reasoning; a leading "<think>"
    # opener is optional and dropped when present.
    thinking = None
    head, close_tag, tail = decoded.partition("</think>")
    if close_tag:
        _, open_tag, inner = head.partition("<think>")
        thinking = (inner if open_tag else head).strip("\n")
        decoded = tail.lstrip("\n")

    type_index = _build_param_type_index(tools)

    call_re = re.compile(
        r"<zyphra_tool_call>\s*<function=([^>\n]+)>\s*(.*?)\s*</function>\s*</zyphra_tool_call>",
        re.DOTALL,
    )
    param_re = re.compile(r"<parameter=([^>\n]+)>\n?(.*?)\n?</parameter>", re.DOTALL)

    calls: list[ParsedToolCall] = []
    kept: list[str] = []  # text fragments lying between tool-call blocks
    cursor = 0
    for call in call_re.finditer(decoded):
        kept.append(decoded[cursor:call.start()])
        cursor = call.end()

        fn_name = call.group(1).strip()
        schema = type_index.get(fn_name, {})

        parsed_args: dict[str, Any] = {}
        fell_back = False
        for param in param_re.finditer(call.group(2)):
            key = param.group(1).strip()
            coerced, used_fallback = _coerce_arg_value(
                param.group(2).strip(), schema.get(key)
            )
            parsed_args[key] = coerced
            if used_fallback:
                fell_back = True

        # A coercion fallback outranks a missing name when choosing the status.
        if fell_back:
            outcome = ToolCallParseStatus.INVALID_JSON
        elif fn_name:
            outcome = ToolCallParseStatus.OK
        else:
            outcome = ToolCallParseStatus.MISSING_NAME

        calls.append(
            ParsedToolCall(
                raw=call.group(0),
                name=fn_name or None,
                arguments=parsed_args,
                status=outcome,
            )
        )
    kept.append(decoded[cursor:])

    return ParsedResponse(
        content="".join(kept).strip(),
        reasoning_content=thinking or None,
        tool_calls=calls,
    )


# ── Qwen3: <tool_call> JSON </tool_call> ────────────────────────────


Expand Down
Loading