From 67f95bffa5de3e84da7444cb31fd740aeafe43c0 Mon Sep 17 00:00:00 2001 From: nreHieW Date: Sun, 17 May 2026 22:37:58 +0800 Subject: [PATCH 1/2] zaya renderer --- renderers/__init__.py | 2 + renderers/base.py | 6 +- renderers/parsing.py | 75 ++++++ renderers/zaya1.py | 315 +++++++++++++++++++++++ tests/conftest.py | 1 + tests/test_preserve_thinking.py | 1 + tests/test_roundtrip.py | 1 + tests/test_tool_arg_type_preservation.py | 1 + 8 files changed, 401 insertions(+), 1 deletion(-) create mode 100644 renderers/zaya1.py diff --git a/renderers/__init__.py b/renderers/__init__.py index 62bc666..e8e6b24 100644 --- a/renderers/__init__.py +++ b/renderers/__init__.py @@ -50,6 +50,7 @@ from renderers.qwen3_vl import Qwen3VLRenderer from renderers.qwen35 import Qwen35Renderer from renderers.qwen36 import Qwen36Renderer +from renderers.zaya1 import Zaya1Renderer __all__ = [ "Content", @@ -87,6 +88,7 @@ "ToolCallParseStatus", "ToolSpec", "VideoPart", + "Zaya1Renderer", "__version__", "build_training_sample", "build_trajectory_step", diff --git a/renderers/base.py b/renderers/base.py index de0f2b9..a7c23a5 100644 --- a/renderers/base.py +++ b/renderers/base.py @@ -815,6 +815,8 @@ def bridge_to_next_turn(self, *args: Any, **kwargs: Any) -> "RenderedTokens | No # GPT-OSS. "openai/gpt-oss-20b": "gpt-oss", "openai/gpt-oss-120b": "gpt-oss", + # ZAYA1. + "Zyphra/ZAYA1-8B": "zaya1", } @@ -1026,6 +1028,7 @@ def _populate_registry(): from renderers.qwen3_vl import Qwen3VLRenderer from renderers.qwen35 import Qwen35Renderer from renderers.qwen36 import Qwen36Renderer + from renderers.zaya1 import Zaya1Renderer RENDERER_REGISTRY.update( { @@ -1044,6 +1047,7 @@ def _populate_registry(): "laguna-xs.2": LagunaXS2Renderer, "nemotron-3": Nemotron3Renderer, "gpt-oss": GptOssRenderer, + "zaya1": Zaya1Renderer, } ) @@ -1107,7 +1111,7 @@ def create_renderer( renderer: Renderer name ('qwen3', 'qwen3-vl', 'qwen3.5', 'qwen3.6', 'glm-5', 'glm-5.1', 'glm-4.5', 'minimax-m2', 'deepseek-v3', 'kimi-k2', 'kimi-k2.5', 'laguna-xs.2', 'nemotron-3', - 'gpt-oss', 'default') or 'auto' to detect from model name. + 'gpt-oss', 'zaya1', 'default') or 'auto' to detect from model name. tool_parser: Name of a tool parser registered in ``renderers.parsers``. Only consumed by DefaultRenderer. Model-specific renderers have their own parsing wired in. diff --git a/renderers/parsing.py b/renderers/parsing.py index 528f122..47ac657 100644 --- a/renderers/parsing.py +++ b/renderers/parsing.py @@ -127,6 +127,81 @@ def _decode(tokenizer, ids: list[int]) -> str: return tokenizer.decode(ids, skip_special_tokens=False) +# ── ZAYA1: ... ──────────────── + + +def parse_zaya1( + tokenizer, + token_ids: list[int], + *, + stop_ids: set[int], + tools: list[ToolSpec] | None = None, +) -> ParsedResponse: + """Parse ZAYA1 completion tokens. + + ZAYA1 uses plain-text XML-ish tool-call tags rather than dedicated + special-token boundaries, so this parser decodes once and then parses the + template's native tags. Argument coercion still uses the shared schema + helpers used by the token-level XML parsers. + """ + import re + + ids = _strip_stop_tokens(token_ids, stop_ids) + text = _decode(tokenizer, ids) + + reasoning = None + if "" in text: + before, after = text.split("", 1) + if "" in before: + reasoning = before.split("", 1)[-1].strip("\n") + else: + reasoning = before.strip("\n") + text = after.lstrip("\n") + + param_index = _build_param_type_index(tools) + tool_calls: list[ParsedToolCall] = [] + + pattern = re.compile( + r"\s*\n]+)>\s*(.*?)\s*\s*", + re.DOTALL, + ) + + def remove_call(match: re.Match[str]) -> str: + name = match.group(1).strip() + block = match.group(2) + params = param_index.get(name, {}) + arguments: dict[str, Any] = {} + any_json_fallback = False + for pm in re.finditer(r"\n]+)>\n?(.*?)\n?", block, re.DOTALL): + arg_name = pm.group(1).strip() + raw_value = pm.group(2).strip() + value, used_fallback = _coerce_arg_value(raw_value, params.get(arg_name)) + arguments[arg_name] = value + any_json_fallback = any_json_fallback or used_fallback + tool_calls.append( + ParsedToolCall( + raw=match.group(0), + name=name or None, + arguments=arguments, + status=( + ToolCallParseStatus.INVALID_JSON + if any_json_fallback + else ToolCallParseStatus.OK + if name + else ToolCallParseStatus.MISSING_NAME + ), + ) + ) + return "" + + content = pattern.sub(remove_call, text).strip() + return ParsedResponse( + content=content, + reasoning_content=reasoning or None, + tool_calls=tool_calls, + ) + + # ── Qwen3: JSON ──────────────────────────── diff --git a/renderers/zaya1.py b/renderers/zaya1.py new file mode 100644 index 0000000..8901827 --- /dev/null +++ b/renderers/zaya1.py @@ -0,0 +1,315 @@ +"""ZAYA1 renderer — hard-coded Python mirroring Zyphra/ZAYA1-8B's Jinja template.""" + +from __future__ import annotations + +import json + +from transformers.tokenization_utils import PreTrainedTokenizer + +from renderers.base import ( + Message, + ParsedResponse, + RenderedTokens, + ToolSpec, + reject_assistant_in_extension, + should_preserve_past_thinking, + trim_to_turn_close, +) +from renderers.parsing import parse_zaya1 + + +_TOOL_INSTRUCTIONS = """ + +If you choose to call a function ONLY reply in the following format with NO suffix: + + + + +value_1 + + +This is the value for the second parameter +that can span +multiple lines + + + + + +Reminder: +- Function calls MUST follow the specified format: an inner block must be nested within XML tags +- Required parameters MUST be specified +- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after +- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls +""" + + +class Zaya1Renderer: + """Deterministic message → token renderer for ``Zyphra/ZAYA1-8B``.""" + + def __init__( + self, + tokenizer: PreTrainedTokenizer, + *, + enable_thinking: bool = True, + truncate_history_thinking: bool = False, + preserve_all_thinking: bool = False, + preserve_thinking_between_tool_calls: bool = False, + ): + self._tokenizer = tokenizer + self._enable_thinking = enable_thinking + self._truncate_history_thinking = truncate_history_thinking + self._preserve_all_thinking = preserve_all_thinking + self._preserve_thinking_between_tool_calls = ( + preserve_thinking_between_tool_calls + ) + self._bos = tokenizer.bos_token_id + self._eos = tokenizer.eos_token_id + self._im_start = self._token_id("<|im_start|>") + self._im_end = self._token_id("<|im_end|>") + + @property + def supports_tools(self) -> bool: + return True + + def _token_id(self, token: str) -> int: + tid = self._tokenizer.convert_tokens_to_ids(token) + assert isinstance(tid, int) and tid != self._tokenizer.unk_token_id, ( + f"Special token {token!r} not found in tokenizer vocabulary" + ) + return tid + + def _encode(self, text: str) -> list[int]: + if not text: + return [] + return self._tokenizer.encode(text, add_special_tokens=False) + + def render( + self, + messages: list[Message], + *, + tools: list[ToolSpec] | None = None, + add_generation_prompt: bool = False, + ) -> RenderedTokens: + if not messages: + raise ValueError("No messages provided.") + + tokens: list[int] = [] + indices: list[int] = [] + sampled: list[bool] = [] + + def emit_ids(ids: list[int], msg_idx: int, *, is_sampled: bool) -> None: + tokens.extend(ids) + indices.extend([msg_idx] * len(ids)) + sampled.extend([is_sampled] * len(ids)) + + def emit_special(token_id: int, msg_idx: int, *, is_sampled: bool) -> None: + tokens.append(token_id) + indices.append(msg_idx) + sampled.append(is_sampled) + + def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None: + emit_ids(self._encode(text), msg_idx, is_sampled=is_sampled) + + if self._bos is not None: + emit_special(self._bos, -1, is_sampled=False) + + first_is_system = messages[0].get("role") == "system" + system_message = str(messages[0].get("content") or "") if first_is_system else "" + loop_messages = messages[1:] if first_is_system else messages + loop_offset = 1 if first_is_system else 0 + last_user_idx = max( + (j for j, m in enumerate(loop_messages) if m.get("role") == "user"), + default=-1, + ) + + # The upstream template always defines system_message, so it always + # emits a system block, even when the caller did not supply one. + sys_idx = 0 if first_is_system else -1 + emit_special(self._im_start, sys_idx, is_sampled=False) + emit_text("system\n" + system_message, sys_idx, is_sampled=False) + if tools: + if system_message: + emit_text("\n\n", sys_idx, is_sampled=False) + emit_text(self._render_tools(tools), sys_idx, is_sampled=False) + emit_special(self._im_end, sys_idx, is_sampled=False) + emit_text("\n", sys_idx, is_sampled=False) + + for rel_i, msg in enumerate(loop_messages): + i = rel_i + loop_offset + role = msg.get("role") + content = self._string_content(msg.get("content") or "") + if role == "assistant": + preserve_thinking = should_preserve_past_thinking( + messages, + i, + preserve_all_thinking=self._preserve_all_thinking, + preserve_thinking_between_tool_calls=self._preserve_thinking_between_tool_calls, + ) + include_content = not ( + self._truncate_history_thinking and rel_i < last_user_idx + ) or preserve_thinking + self._render_assistant( + msg, + i, + content, + include_content=include_content, + emit_special=emit_special, + emit_text=emit_text, + ) + elif role in {"user", "system"}: + emit_special(self._im_start, i, is_sampled=False) + emit_text(f"{role}\n{content}", i, is_sampled=False) + emit_special(self._im_end, i, is_sampled=False) + emit_text("\n", i, is_sampled=False) + elif role == "tool": + self._render_tool(loop_messages, rel_i, i, content, emit_special, emit_text) + else: + emit_special(self._im_start, i, is_sampled=False) + emit_text(f"{role}\n{content}", i, is_sampled=False) + emit_special(self._im_end, i, is_sampled=False) + emit_text("\n", i, is_sampled=False) + + if add_generation_prompt: + emit_special(self._im_start, -1, is_sampled=False) + if self._enable_thinking: + emit_text("assistant\n\n", -1, is_sampled=False) + else: + emit_text("assistant\n\n\n\n", -1, is_sampled=False) + + return RenderedTokens( + token_ids=tokens, + message_indices=indices, + sampled_mask=sampled, + message_roles=[m.get("role") or "" for m in messages], + ) + + def render_ids(self, messages, *, tools=None, add_generation_prompt=False) -> list[int]: + return self.render(messages, tools=tools, add_generation_prompt=add_generation_prompt).token_ids + + def parse_response(self, token_ids: list[int], *, tools: list[ToolSpec] | None = None) -> ParsedResponse: + return parse_zaya1(self._tokenizer, token_ids, stop_ids={self._im_end, self._eos} if self._eos is not None else {self._im_end}, tools=tools) + + def get_stop_token_ids(self) -> list[int]: + return [self._im_end] + ([] if self._eos is None else [self._eos]) + + def bridge_to_next_turn(self, previous_prompt_ids, previous_completion_ids, new_messages, *, tools=None): + if not previous_prompt_ids or not new_messages or reject_assistant_in_extension(new_messages): + return None + previous_ids = trim_to_turn_close( + previous_prompt_ids, + previous_completion_ids, + set(self.get_stop_token_ids()), + synthesize_close=self._im_end, + ) + if previous_ids is None: + return None + rendered = self.render(new_messages, tools=None, add_generation_prompt=True) + # Drop BOS + the template's synthetic empty system block for extensions. + prefix = [] if self._bos is None else [self._bos] + empty_system = prefix + self._encode("<|im_start|>system\n<|im_end|>\n") + ext = rendered.token_ids[len(empty_system) :] if rendered.token_ids[: len(empty_system)] == empty_system else rendered.token_ids + total_len = len(previous_ids) + len(ext) + return RenderedTokens( + token_ids=previous_ids + ext, + message_indices=[-1] * len(previous_ids) + rendered.message_indices[-len(ext) :], + sampled_mask=[False] * total_len, + message_roles=[m.get("role") or "" for m in new_messages], + ) + + def _render_assistant(self, msg, msg_idx, content, *, include_content, emit_special, emit_text) -> None: + reasoning = msg.get("reasoning_content") + if isinstance(reasoning, str) and reasoning.strip(): + rendered_content = f"\n{reasoning}\n\n\n{content}" + elif isinstance(content, str) and "" not in content and "" not in content: + rendered_content = f"\n\n\n{content}" + else: + rendered_content = content + + if not include_content: + rendered_content = self._truncate_thinking(rendered_content) + + tool_calls = msg.get("tool_calls") or [] + emit_special(self._im_start, msg_idx, is_sampled=False) + emit_text("assistant\n", msg_idx, is_sampled=False) + if tool_calls: + body = rendered_content.strip() + if body: + emit_text(body + "\n\n", msg_idx, is_sampled=True) + else: + emit_text("\n\n\n", msg_idx, is_sampled=True) + for tc in tool_calls: + emit_text(self._render_tool_call(tc), msg_idx, is_sampled=True) + emit_special(self._im_end, msg_idx, is_sampled=True) + emit_text("\n", msg_idx, is_sampled=False) + else: + emit_text(rendered_content.strip(), msg_idx, is_sampled=True) + emit_special(self._im_end, msg_idx, is_sampled=True) + emit_text("\n", msg_idx, is_sampled=False) + + def _render_tool(self, loop_messages, rel_i, msg_idx, content, emit_special, emit_text) -> None: + prev_is_tool = rel_i > 0 and loop_messages[rel_i - 1].get("role") == "tool" + next_is_tool = rel_i + 1 < len(loop_messages) and loop_messages[rel_i + 1].get("role") == "tool" + if not prev_is_tool: + emit_special(self._im_start, msg_idx, is_sampled=False) + emit_text("user\n", msg_idx, is_sampled=False) + emit_text(f"\n{content}\n\n", msg_idx, is_sampled=False) + if not next_is_tool: + emit_special(self._im_end, msg_idx, is_sampled=False) + emit_text("\n", msg_idx, is_sampled=False) + + def _render_tools(self, tools: list[ToolSpec]) -> str: + text = "# Tools\n\nYou have access to the following functions:\n\n" + for raw_tool in tools: + tool = raw_tool.get("function", raw_tool) if isinstance(raw_tool, dict) else raw_tool + text += f"\n\n{tool.get('name', '')}" + if tool.get("description") is not None: + text += f"\n{str(tool['description']).strip()}" + params = tool.get("parameters") or {} + text += "\n" + props = params.get("properties") if isinstance(params, dict) else None + if isinstance(props, dict): + for name, fields in props.items(): + text += f"\n\n{name}" + if fields.get("type") is not None: + text += f"\n{fields['type']}" + if fields.get("description") is not None: + text += f"\n{str(fields['description']).strip()}" + if fields.get("enum") is not None: + text += "\n" + json.dumps(fields["enum"], ensure_ascii=False) + "" + text += "\n" + if isinstance(params, dict) and params.get("required") is not None: + text += "\n" + json.dumps(params["required"], ensure_ascii=False) + "" + text += "\n\n" + return text + "\n" + _TOOL_INSTRUCTIONS + + def _render_tool_call(self, tc) -> str: + func = tc.get("function") or tc + text = f"\n\n" + arguments = func.get("arguments") or {} + if isinstance(arguments, str): + try: + arguments = json.loads(arguments) + except json.JSONDecodeError: + arguments = {} + for name, value in arguments.items(): + if isinstance(value, (dict, list)): + value_text = json.dumps(value, ensure_ascii=False) + else: + value_text = str(value) + text += f"\n{value_text}\n\n" + return text + "\n\n" + + @staticmethod + def _string_content(content) -> str: + if isinstance(content, list): + return "".join(str(p.get("text", "")) for p in content if isinstance(p, dict)) + return str(content) + + @staticmethod + def _truncate_thinking(content: str) -> str: + if "" in content: + content = content.split("")[-1] + elif "" in content: + content = content.split("")[0] + return ("\n\n\n" + content.strip()).strip() diff --git a/tests/conftest.py b/tests/conftest.py index 8eea97b..0a8d879 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -34,6 +34,7 @@ ("nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16", "auto"), ("poolside/Laguna-XS.2", "auto"), ("openai/gpt-oss-20b", "gpt-oss"), + ("Zyphra/ZAYA1-8B", "auto"), ("Qwen/Qwen2.5-0.5B-Instruct", "default"), ] diff --git a/tests/test_preserve_thinking.py b/tests/test_preserve_thinking.py index 661d577..8e9ee2c 100644 --- a/tests/test_preserve_thinking.py +++ b/tests/test_preserve_thinking.py @@ -43,6 +43,7 @@ def _make(tokenizer, renderer_name, **flags): "Qwen/Qwen3-VL-8B-Instruct", "Qwen/Qwen3-VL-30B-A3B-Instruct", "poolside/Laguna-XS.2", + "Zyphra/ZAYA1-8B", } diff --git a/tests/test_roundtrip.py b/tests/test_roundtrip.py index a4577fd..2934f51 100644 --- a/tests/test_roundtrip.py +++ b/tests/test_roundtrip.py @@ -45,6 +45,7 @@ ("nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16", "auto"), ("poolside/Laguna-XS.2", "auto"), ("openai/gpt-oss-20b", "gpt-oss"), + ("Zyphra/ZAYA1-8B", "auto"), ("Qwen/Qwen2.5-0.5B-Instruct", "default"), ] diff --git a/tests/test_tool_arg_type_preservation.py b/tests/test_tool_arg_type_preservation.py index 61dbb9b..75bce67 100644 --- a/tests/test_tool_arg_type_preservation.py +++ b/tests/test_tool_arg_type_preservation.py @@ -35,6 +35,7 @@ ("zai-org/GLM-5", "auto"), # XML ("MiniMaxAI/MiniMax-M2.5", "auto"), # XML ("poolside/Laguna-XS.2", "auto"), # XML + ("Zyphra/ZAYA1-8B", "auto"), # XML ] From 6f9b76cb3c3876b8c26243e1938bd05a761bd349 Mon Sep 17 00:00:00 2001 From: nreHieW Date: Tue, 19 May 2026 00:00:16 +0800 Subject: [PATCH 2/2] add notes --- renderers/zaya1.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/renderers/zaya1.py b/renderers/zaya1.py index 8901827..44ccc60 100644 --- a/renderers/zaya1.py +++ b/renderers/zaya1.py @@ -1,4 +1,11 @@ -"""ZAYA1 renderer — hard-coded Python mirroring Zyphra/ZAYA1-8B's Jinja template.""" +"""ZAYA1 renderer — hard-coded Python mirroring Zyphra/ZAYA1-8B's Jinja template. + +Notes: +- The template always emits an empty system prelude, even when the caller did not pass one. +- Multi-turn bridging must strip that synthetic BOS + empty system prelude from subsequent turns. +- Tool calls use Zyphra's XML-ish ```` format. +- Thinking can be optionally truncated from history. +""" from __future__ import annotations