From 0dbc5ff27f2d7b1d70882c8a94f384f014db43ee Mon Sep 17 00:00:00 2001 From: hallerite Date: Wed, 27 May 2026 00:30:14 +0530 Subject: [PATCH 1/4] feat(glm): mark next-message role marker as trainable after assistant turns MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit GLM-4.5 / GLM-5 templates have no per-turn close token inside the assistant span — the next message's role marker (<|user|> / <|observation|>) doubles as the inference stop signal (see get_stop_token_ids). Before this change, that role marker was emitted with is_sampled=False, so build_training_sample never trained the model to predict it. At inference the natural continuation of </tool_call>\n is another <tool_call>, which is what we observed: GLM-Air SFT step_250 dumped 50-283 parallel tool_calls per turn on SWE-bench Verified (median max-tool-calls 35 for solved rollouts vs 55 for failed; pass@1 25.0%). Fix: - glm45.py / glm5.py: when the previous message is an assistant, emit the next message's role-opening token with is_sampled=True. The token stays attributed to that next message (msg_idx unchanged); only sampled_mask flips. Byte stream identical to apply_chat_template. - base.py: build_training_sample now accepts role_to_mask=None. With a renderer that populates sampled_mask, the default behaviour becomes loss_mask = sampled_mask — the renderer's per-token "model would have sampled this" signal is the truth, and role-level filtering is opt-in rather than required. Callers passing role_to_mask keep the old AND-behaviour; callers without it now train on every sampled token (including the assistant's closing role marker). system role only appears at the start of a GLM conversation, so its opener is never the closer of an assistant turn — no logic needed for it. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- renderers/base.py | 71 ++++++++++++++----------- renderers/glm45.py | 119 ++++++++++++++++++++---------------------- renderers/glm5.py | 126 +++++++++++++++++++++------------------------ 3 files changed, 155 insertions(+), 161 deletions(-) diff --git a/renderers/base.py b/renderers/base.py index 65edf68..82d136e 100644 --- a/renderers/base.py +++ b/renderers/base.py @@ -220,9 +220,7 @@ class RenderedTokens: message_roles: list[str] = field(default_factory=list) multi_modal_data: "MultiModalData | None" = None - def tokens_per_message( - self, n_messages: int | None = None, *, sampled_only: bool = False - ) -> list[int]: + def tokens_per_message(self, n_messages: int | None = None, *, sampled_only: bool = False) -> list[int]: """Count rendered tokens attributed to each caller-relative message. ``out[i]`` is the number of tokens with ``message_indices[k] == i``, @@ -1097,10 +1095,7 @@ def _patched_load(model_name_or_path: str, **kwargs): with contextlib.redirect_stdout(io.StringIO()): fastokens.patch_transformers() if not _FASTOKENS_ANNOUNCED: - logger.info( - "fastokens enabled — tokenizers load through the Rust BPE " - "fast path (~10x encode speedup)." 
- ) + logger.info("fastokens enabled — tokenizers load through the Rust BPE fast path (~10x encode speedup).") _FASTOKENS_ANNOUNCED = True try: return AutoTokenizer.from_pretrained(model_name_or_path, **kwargs) @@ -1169,8 +1164,8 @@ def load_tokenizer( def _populate_registry(): if RENDERER_REGISTRY: return - from renderers.default import DefaultRenderer from renderers.deepseek_v3 import DeepSeekV3Renderer + from renderers.default import DefaultRenderer from renderers.glm5 import GLM5Renderer, GLM51Renderer from renderers.glm45 import GLM45Renderer from renderers.gpt_oss import GptOssRenderer @@ -1270,10 +1265,7 @@ def create_renderer( if not isinstance(config, AutoRendererConfig): cls = RENDERER_REGISTRY.get(config.name) if cls is None: - raise ValueError( - f"Unknown renderer {config.name!r}. " - f"Available: {', '.join(sorted(RENDERER_REGISTRY))}" - ) + raise ValueError(f"Unknown renderer {config.name!r}. Available: {', '.join(sorted(RENDERER_REGISTRY))}") return cls(tokenizer, config) return _resolve_auto(tokenizer, config) @@ -1345,7 +1337,7 @@ def build_training_sample( renderer: Renderer, messages: list[Message], *, - role_to_mask: Callable[[Message], bool], + role_to_mask: Callable[[Message], bool] | None = None, tools: list[ToolSpec] | None = None, content_sft_roles: "set[str] | frozenset[str] | None" = None, ) -> tuple[list[int], list[bool]]: @@ -1354,15 +1346,31 @@ def build_training_sample( Single render() call + message_indices → per-token mask. Replaces build_incremental_token_mask (O(N) renders → O(1)). - When the renderer populates ``rendered.sampled_mask``, the loss mask - is the AND of role-based attribution and the sampled signal: only - tokens the model would have produced at inference are trainable. - This keeps SFT byte-aligned with the RL trajectory mask (where the - prompt / completion split achieves the same effect structurally). 
+ When ``role_to_mask`` is omitted, ``loss_mask`` is the renderer's + ``sampled_mask`` directly: every token the model would have + produced at inference is trainable, regardless of which message + it's attributed to. This is the recommended default for renderer + callers — the renderer owns the per-token "is this model output" + signal, so role-level filtering becomes a downstream constraint + rather than a precondition. (Some role markers — e.g. GLM + ``<|user|>`` / ``<|observation|>`` after a tool-calling assistant + turn — *are* sampled by the model at inference and live inside the + next message's span; ``sampled_mask`` captures that, but a + naive role filter would mask them out.) + + When ``role_to_mask`` is provided, ``loss_mask`` is the AND of the + role-based attribution and the sampled signal: only tokens the + model would have produced at inference AND attributed to a + trainable role pass through. Useful when the caller needs to + restrict training to a specific role (e.g. assistant-only) even on + a renderer whose ``sampled_mask`` already covers other roles. + Renderers that don't populate ``sampled_mask`` (empty list) fall back to attribution-only masking — every token attributed to a trainable role is trained on, including template-injected - ``<|im_start|>role\\n`` openers. + ``<|im_start|>role\\n`` openers. In this fallback mode + ``role_to_mask`` is required; calling without it raises + ``ValueError``. ``content_sft_roles`` opts in additional roles for "body-only" supervision: for every message whose role is in this set, tokens @@ -1393,6 +1401,13 @@ def build_training_sample( else: body_roles = frozenset() + if role_to_mask is None and not has_sampled_info: + raise ValueError( + "role_to_mask is required when the renderer does not populate " + "sampled_mask. Pass an explicit role filter (e.g. " + "lambda m: m['role'] == 'assistant') for this renderer." 
+ ) + loss_mask: list[bool] = [] for k, msg_idx in enumerate(rendered.message_indices): if msg_idx < 0: @@ -1408,6 +1423,11 @@ def build_training_sample( continue if has_sampled_info and not rendered.sampled_mask[k]: loss_mask.append(False) + elif role_to_mask is None: + # sampled_mask alone gates the loss when no role filter is + # supplied. ``sampled_mask[k]`` is True here (handled by the + # branch above), so this token is trainable. + loss_mask.append(True) else: loss_mask.append(role_to_mask(msg)) return rendered.token_ids, loss_mask @@ -1659,9 +1679,7 @@ def should_preserve_past_thinking( return False # The current segment must contain a tool response for it to count # as an in-flight tool cycle. - return any( - messages[j].get("role") == "tool" for j in range(last_user + 1, len(messages)) - ) + return any(messages[j].get("role") == "tool" for j in range(last_user + 1, len(messages))) def build_trajectory_step( @@ -1683,9 +1701,7 @@ def build_trajectory_step( the completion). """ has_completion = len(completion_messages) > 0 - prompt_ids = renderer.render_ids( - prompt_messages, tools=tools, add_generation_prompt=has_completion - ) + prompt_ids = renderer.render_ids(prompt_messages, tools=tools, add_generation_prompt=has_completion) full_rendered = renderer.render(prompt_messages + completion_messages, tools=tools) full_ids = full_rendered.token_ids @@ -1700,9 +1716,6 @@ def build_trajectory_step( "completion_logprobs": [0.0] * len(completion_ids), "routed_experts": None, } - if ( - full_rendered.multi_modal_data is not None - and not full_rendered.multi_modal_data.is_empty() - ): + if full_rendered.multi_modal_data is not None and not full_rendered.multi_modal_data.is_empty(): out["multi_modal_data"] = full_rendered.multi_modal_data return out diff --git a/renderers/glm45.py b/renderers/glm45.py index ed0e0b7..acb0944 100644 --- a/renderers/glm45.py +++ b/renderers/glm45.py @@ -125,26 +125,20 @@ def render( sampled: list[bool] = [] content_mask: list[bool] = 
[] - def emit_special( - token_id: int, msg_idx: int, *, is_sampled: bool, is_content: bool - ) -> None: + def emit_special(token_id: int, msg_idx: int, *, is_sampled: bool, is_content: bool) -> None: tokens.append(token_id) indices.append(msg_idx) sampled.append(is_sampled) content_mask.append(is_content) - def emit_text( - text: str, msg_idx: int, *, is_sampled: bool, is_content: bool - ) -> None: + def emit_text(text: str, msg_idx: int, *, is_sampled: bool, is_content: bool) -> None: ids = self._encode(text) tokens.extend(ids) indices.extend([msg_idx] * len(ids)) sampled.extend([is_sampled] * len(ids)) content_mask.extend([is_content] * len(ids)) - def emit_text_segments( - segments: list[tuple[str, bool]], msg_idx: int, *, is_sampled: bool - ) -> None: + def emit_text_segments(segments: list[tuple[str, bool]], msg_idx: int, *, is_sampled: bool) -> None: """Tokenize concatenated segments as one BPE pass; per-token ``is_content`` follows each token's source segment. @@ -152,9 +146,7 @@ def emit_text_segments( same way as the chat template, but attributed separately" without splitting the encode call (which could shift BPE merges at the boundary).""" - for tok_id, is_content in attribute_text_segments( - self._tokenizer, segments - ): + for tok_id, is_content in attribute_text_segments(self._tokenizer, segments): tokens.append(tok_id) indices.append(msg_idx) sampled.append(is_sampled) @@ -184,16 +176,34 @@ def emit_text_segments( role = msg["role"] content = self._visible_text(msg.get("content")) + # When the previous message is an assistant, this message's + # role-opening token (``<|user|>`` / ``<|observation|>``) is + # the inference-time stop signal that closes the assistant's + # turn (see ``get_stop_token_ids``). 
Mark it + # ``is_sampled=True`` and ``is_content=True`` so the + # loss-mask pipeline trains the model to emit it after + # ```` (instead of continuing with another + # ```` block) in both the default ``sampled`` + # path and the body-only ``content_sft_roles`` path. The + # token stays attributed to this message (msg_idx=i); the + # byte stream is unchanged. ``system`` only appears at the + # start of a GLM conversation, so its opener is never the + # closer of an assistant turn. + closes_assistant_turn = i > 0 and messages[i - 1]["role"] == "assistant" + if role == "system": emit_special(self._system, i, is_sampled=False, is_content=False) # ``\n`` is the scaffold separator after the role tag; # the body proper is the caller-provided content. - emit_text_segments( - [("\n", False), (content, True)], i, is_sampled=False - ) + emit_text_segments([("\n", False), (content, True)], i, is_sampled=False) elif role == "user": - emit_special(self._user, i, is_sampled=False, is_content=False) + emit_special( + self._user, + i, + is_sampled=closes_assistant_turn, + is_content=closes_assistant_turn, + ) # ``\n`` is scaffold; ``content`` is body; the optional # ``/nothink`` suffix is scaffold the renderer injects # when ``enable_thinking=False``. @@ -291,21 +301,14 @@ def bridge_to_next_turn( *, tools: list[ToolSpec] | None = None, ) -> RenderedTokens | None: - if ( - not previous_prompt_ids - or not new_messages - or reject_assistant_in_extension(new_messages) - ): + if not previous_prompt_ids or not new_messages or reject_assistant_in_extension(new_messages): return None # Same next-turn-marker scheme as GLM-5, but role markers are # followed by a literal ``\n`` in the prompt text. 
previous_ids = list(previous_prompt_ids) + list(previous_completion_ids) stop_ids = {self._endoftext, self._user, self._observation} - if ( - not previous_ids[len(previous_prompt_ids) :] - or previous_ids[-1] not in stop_ids - ): + if not previous_ids[len(previous_prompt_ids) :] or previous_ids[-1] not in stop_ids: previous_ids.append(self._endoftext) last_prev = previous_ids[-1] @@ -354,9 +357,7 @@ def emit_text_segments( *, is_sampled: bool = False, ) -> None: - for tok_id, is_content in attribute_text_segments( - self._tokenizer, segments - ): + for tok_id, is_content in attribute_text_segments(self._tokenizer, segments): ext.append(tok_id) ext_indices.append(msg_idx) ext_sampled.append(is_sampled) @@ -458,9 +459,7 @@ def _render_assistant( if (msg_idx > last_user_index or preserve_thinking) and reasoning_content: emit_special(self._think, msg_idx, is_sampled=True, is_content=True) - emit_text( - reasoning_content.strip(), msg_idx, is_sampled=True, is_content=True - ) + emit_text(reasoning_content.strip(), msg_idx, is_sampled=True, is_content=True) emit_special(self._think_end, msg_idx, is_sampled=True, is_content=True) else: emit_special(self._think, msg_idx, is_sampled=True, is_content=True) @@ -469,9 +468,7 @@ def _render_assistant( # Tool calls — keep content + \n contiguous to preserve BPE merges tool_calls = msg.get("tool_calls") or [] if content.strip() and tool_calls: - emit_text( - "\n" + content.strip() + "\n", msg_idx, is_sampled=True, is_content=True - ) + emit_text("\n" + content.strip() + "\n", msg_idx, is_sampled=True, is_content=True) elif content.strip(): emit_text("\n" + content.strip(), msg_idx, is_sampled=True, is_content=True) @@ -493,17 +490,11 @@ def _render_assistant( arguments = {} if isinstance(arguments, dict): for arg_name, arg_value in arguments.items(): - emit_special( - self._arg_key, msg_idx, is_sampled=True, is_content=True - ) + emit_special(self._arg_key, msg_idx, is_sampled=True, is_content=True) emit_text(arg_name, msg_idx, 
is_sampled=True, is_content=True) - emit_special( - self._arg_key_end, msg_idx, is_sampled=True, is_content=True - ) + emit_special(self._arg_key_end, msg_idx, is_sampled=True, is_content=True) emit_text("\n", msg_idx, is_sampled=True, is_content=True) - emit_special( - self._arg_value, msg_idx, is_sampled=True, is_content=True - ) + emit_special(self._arg_value, msg_idx, is_sampled=True, is_content=True) if isinstance(arg_value, str): emit_text(arg_value, msg_idx, is_sampled=True, is_content=True) else: @@ -513,13 +504,9 @@ def _render_assistant( is_sampled=True, is_content=True, ) - emit_special( - self._arg_value_end, msg_idx, is_sampled=True, is_content=True - ) + emit_special(self._arg_value_end, msg_idx, is_sampled=True, is_content=True) emit_text("\n", msg_idx, is_sampled=True, is_content=True) - emit_special( - self._tool_call_end_tok, msg_idx, is_sampled=True, is_content=True - ) + emit_special(self._tool_call_end_tok, msg_idx, is_sampled=True, is_content=True) def _render_tool( self, @@ -531,21 +518,25 @@ def _render_tool( emit_text, emit_text_segments, ) -> None: - # Tool messages are conversation history injected by the runtime - # between assistant turns — the model never samples any of these - # tokens, so every emission is is_sampled=False. The body bytes - # get ``is_content=True``; the ``\n\n`` / - # ``\n`` wraps and the ``<|observation|>`` role - # tag are scaffold so the SFT mask for tool body never trains - # the model to emit them. Single BPE pass over the joined text - # preserves boundary merges (the tool body's leading/trailing - # chars can merge with the wrap's ``\n``s if the tokenizer would - # do so; we route through ``emit_text_segments`` so the - # attribution is offset-driven and tokenizer-agnostic). 
- prev_is_tool = msg_idx > 0 and messages[msg_idx - 1]["role"] == "tool" - - if not prev_is_tool: - emit_special(self._observation, msg_idx, is_sampled=False, is_content=False) + # Tool body bytes get ``is_content=True``; the wraps are + # scaffold. The ``<|observation|>`` role tag is normally + # scaffold too, but when the previous message is an assistant + # it doubles as the inference stop signal for that assistant's + # turn — mark it ``is_sampled=True`` and ``is_content=True`` so + # SFT trains the model to emit it after ```` in + # both the default sampled path and the body-only + # ``content_sft_roles`` path. The token stays attributed to + # this tool message; byte stream is unchanged. + prev_role = messages[msg_idx - 1]["role"] if msg_idx > 0 else None + closes_assistant_turn = prev_role == "assistant" + + if prev_role != "tool": + emit_special( + self._observation, + msg_idx, + is_sampled=closes_assistant_turn, + is_content=closes_assistant_turn, + ) emit_text_segments( [ diff --git a/renderers/glm5.py b/renderers/glm5.py index f3e28e3..470b894 100644 --- a/renderers/glm5.py +++ b/renderers/glm5.py @@ -145,26 +145,20 @@ def render( sampled: list[bool] = [] content_mask: list[bool] = [] - def emit_special( - token_id: int, msg_idx: int, *, is_sampled: bool, is_content: bool - ) -> None: + def emit_special(token_id: int, msg_idx: int, *, is_sampled: bool, is_content: bool) -> None: tokens.append(token_id) indices.append(msg_idx) sampled.append(is_sampled) content_mask.append(is_content) - def emit_text( - text: str, msg_idx: int, *, is_sampled: bool, is_content: bool - ) -> None: + def emit_text(text: str, msg_idx: int, *, is_sampled: bool, is_content: bool) -> None: ids = self._encode(text) tokens.extend(ids) indices.extend([msg_idx] * len(ids)) sampled.extend([is_sampled] * len(ids)) content_mask.extend([is_content] * len(ids)) - def emit_text_segments( - segments: list[tuple[str, bool]], msg_idx: int, *, is_sampled: bool - ) -> None: + def 
emit_text_segments(segments: list[tuple[str, bool]], msg_idx: int, *, is_sampled: bool) -> None: """Tokenize concatenated segments as one BPE pass; per-token ``is_content`` follows each token's source segment. @@ -172,9 +166,7 @@ def emit_text_segments( same way as the chat template, but attributed separately" without splitting the encode call (which could shift BPE merges at the boundary).""" - for tok_id, is_content in attribute_text_segments( - self._tokenizer, segments - ): + for tok_id, is_content in attribute_text_segments(self._tokenizer, segments): tokens.append(tok_id) indices.append(msg_idx) sampled.append(is_sampled) @@ -207,12 +199,32 @@ def emit_text_segments( role = msg["role"] content = self._visible_text(msg.get("content")) + # When the previous message is an assistant, this message's + # role-opening token (``<|user|>`` / ``<|observation|>``) is + # the inference-time stop signal that closes the assistant's + # turn (see ``get_stop_token_ids``). Mark it + # ``is_sampled=True`` and ``is_content=True`` so the + # loss-mask pipeline trains the model to emit it after + # ```` (instead of continuing with another + # ```` block) in both the default ``sampled`` + # path and the body-only ``content_sft_roles`` path. The + # token stays attributed to this message (msg_idx=i); the + # byte stream is unchanged. ``system`` only appears at the + # start of a GLM conversation, so its opener is never the + # closer of an assistant turn. 
+ closes_assistant_turn = i > 0 and messages[i - 1]["role"] == "assistant" + if role == "system": emit_special(self._system, i, is_sampled=False, is_content=False) emit_text(content, i, is_sampled=False, is_content=True) elif role == "user": - emit_special(self._user, i, is_sampled=False, is_content=False) + emit_special( + self._user, + i, + is_sampled=closes_assistant_turn, + is_content=closes_assistant_turn, + ) emit_text(content, i, is_sampled=False, is_content=True) elif role == "assistant": @@ -307,11 +319,7 @@ def bridge_to_next_turn( *, tools: list[ToolSpec] | None = None, ) -> RenderedTokens | None: - if ( - not previous_prompt_ids - or not new_messages - or reject_assistant_in_extension(new_messages) - ): + if not previous_prompt_ids or not new_messages or reject_assistant_in_extension(new_messages): return None # GLM has no per-turn close token. An assistant turn ends when the @@ -321,10 +329,7 @@ def bridge_to_next_turn( # previous_completion_ids. Truncation means none is there yet. previous_ids = list(previous_prompt_ids) + list(previous_completion_ids) stop_ids = {self._endoftext, self._user, self._observation} - if ( - not previous_ids[len(previous_prompt_ids) :] - or previous_ids[-1] not in stop_ids - ): + if not previous_ids[len(previous_prompt_ids) :] or previous_ids[-1] not in stop_ids: # Truncation: synthesise <|endoftext|> as the canonical turn end. previous_ids.append(self._endoftext) @@ -374,9 +379,7 @@ def emit_text_segments( *, is_sampled: bool = False, ) -> None: - for tok_id, is_content in attribute_text_segments( - self._tokenizer, segments - ): + for tok_id, is_content in attribute_text_segments(self._tokenizer, segments): ext.append(tok_id) ext_indices.append(msg_idx) ext_sampled.append(is_sampled) @@ -467,9 +470,7 @@ def _render_assistant( # clear_thinking`` gate: a chat_template_kwarg surface for the # same behaviour, gated explicitly by the caller per render. 
include_thinking = ( - msg_idx > last_user_index - or preserve_thinking - or not self.config.clear_thinking + msg_idx > last_user_index or preserve_thinking or not self.config.clear_thinking ) and reasoning_content if include_thinking: @@ -478,15 +479,9 @@ def _render_assistant( # template-injected scaffolding. The reasoning text and the # closing ```` are what the model actually samples. emit_special(self._think, msg_idx, is_sampled=False, is_content=False) - emit_text( - reasoning_content.strip(), msg_idx, is_sampled=True, is_content=True - ) + emit_text(reasoning_content.strip(), msg_idx, is_sampled=True, is_content=True) emit_special(self._think_end, msg_idx, is_sampled=True, is_content=True) - elif ( - self.empty_think_on_last_assistant - and msg_idx > last_user_index - and self.config.enable_thinking - ): + elif self.empty_think_on_last_assistant and msg_idx > last_user_index and self.config.enable_thinking: # GLM-5.1: wrap the last assistant with an empty # even without reasoning, matching the Jinja template. 
With # ``enable_thinking=True`` the gen prompt already includes @@ -530,16 +525,10 @@ def _render_assistant( arguments = {} if isinstance(arguments, dict): for arg_name, arg_value in arguments.items(): - emit_special( - self._arg_key, msg_idx, is_sampled=True, is_content=True - ) + emit_special(self._arg_key, msg_idx, is_sampled=True, is_content=True) emit_text(arg_name, msg_idx, is_sampled=True, is_content=True) - emit_special( - self._arg_key_end, msg_idx, is_sampled=True, is_content=True - ) - emit_special( - self._arg_value, msg_idx, is_sampled=True, is_content=True - ) + emit_special(self._arg_key_end, msg_idx, is_sampled=True, is_content=True) + emit_special(self._arg_value, msg_idx, is_sampled=True, is_content=True) if isinstance(arg_value, str): emit_text(arg_value, msg_idx, is_sampled=True, is_content=True) else: @@ -549,12 +538,8 @@ def _render_assistant( is_sampled=True, is_content=True, ) - emit_special( - self._arg_value_end, msg_idx, is_sampled=True, is_content=True - ) - emit_special( - self._tool_call_end_tok, msg_idx, is_sampled=True, is_content=True - ) + emit_special(self._arg_value_end, msg_idx, is_sampled=True, is_content=True) + emit_special(self._tool_call_end_tok, msg_idx, is_sampled=True, is_content=True) def _render_tool( self, @@ -566,24 +551,29 @@ def _render_tool( emit_text, emit_text_segments, ) -> None: - # Tool messages are conversation history injected by the runtime - # between assistant turns — the model never samples any of these - # tokens, so every emission is is_sampled=False. The tool body - # bytes get ``is_content=True``; the ``<|observation|>`` / - # ```` wraps are scaffold so the SFT mask for - # tool body never trains the model to emit them. 
- prev_is_tool = msg_idx > 0 and messages[msg_idx - 1]["role"] == "tool" - - if not prev_is_tool: - emit_special(self._observation, msg_idx, is_sampled=False, is_content=False) - - emit_special( - self._tool_response_tok, msg_idx, is_sampled=False, is_content=False - ) + # Tool body bytes get ``is_content=True``; the wraps are + # scaffold. The ``<|observation|>`` role tag is normally + # scaffold too, but when the previous message is an assistant + # it doubles as the inference stop signal for that assistant's + # turn — mark it ``is_sampled=True`` and ``is_content=True`` so + # SFT trains the model to emit it after ```` in + # both the default sampled path and the body-only + # ``content_sft_roles`` path. The token stays attributed to + # this tool message; byte stream is unchanged. + prev_role = messages[msg_idx - 1]["role"] if msg_idx > 0 else None + closes_assistant_turn = prev_role == "assistant" + + if prev_role != "tool": + emit_special( + self._observation, + msg_idx, + is_sampled=closes_assistant_turn, + is_content=closes_assistant_turn, + ) + + emit_special(self._tool_response_tok, msg_idx, is_sampled=False, is_content=False) emit_text(content, msg_idx, is_sampled=False, is_content=True) - emit_special( - self._tool_response_end_tok, msg_idx, is_sampled=False, is_content=False - ) + emit_special(self._tool_response_end_tok, msg_idx, is_sampled=False, is_content=False) class GLM51Renderer(GLM5Renderer): From d4dbe0a3fe7bc96b968541e0e8a4900856adb983 Mon Sep 17 00:00:00 2001 From: hallerite Date: Wed, 27 May 2026 01:53:41 +0530 Subject: [PATCH 2/4] docs(glm): explain why bridge intentionally diverges from render MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The bridge synthesizes the next-step prompt by appending new messages (tool responses, follow-up user turns) on top of what the model actually sampled this round (``previous_completion_ids``). 
Tokens it emits are template scaffolding for the next generation, not model output from this step. When the model fails to sample its stop token and the bridge has to synthesize the boundary via a ``<|user|>`` / ``<|observation|>`` role-opener for the first new message, that opener stays ``is_sampled=False, is_content=False`` — even though :meth:`render` would mark the same token ``True, True`` on the SFT path. The discrepancy is intentional: render's masks describe what the model *should* produce given a complete conversation (SFT target); bridge's masks describe what it *actually* produced this step (RL signal). The RL loss operates on ``previous_completion_ids`` exclusively; bridge tokens belong to the subsequent prompt and must not be counted as sampled output by downstream mask consumers (e.g. per-step credit assignment in trajectory metrics). Co-Authored-By: Claude Opus 4.7 (1M context) --- renderers/glm45.py | 15 +++++++++++++++ renderers/glm5.py | 15 +++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/renderers/glm45.py b/renderers/glm45.py index acb0944..ac48d56 100644 --- a/renderers/glm45.py +++ b/renderers/glm45.py @@ -363,6 +363,21 @@ def emit_text_segments( ext_sampled.append(is_sampled) ext_content.append(is_content) + # The opener-token of the first new_message may also serve as + # the close of the previous assistant turn (when the model + # failed to sample the stop token itself and the bridge has to + # synthesize the boundary above). Unlike :meth:`render`, the + # bridge emits these with ``is_sampled=False, is_content=False`` + # — they are template scaffolding for the *next* step's prompt, + # not tokens the model produced *in this* step. The RL loss + # operates on ``previous_completion_ids`` (what the model + # actually sampled this round); bridge tokens belong to the + # subsequent prompt and must not be counted as "model output" + # by downstream mask consumers. 
This deliberate disagreement + # with ``render()`` reflects the SFT vs RL semantics: render's + # masks describe what the model *should* produce given a + # complete conversation; bridge's masks describe what it + # *actually* produced this step. for i, msg in enumerate(new_messages): role = msg.get("role") content = self._visible_text(msg.get("content")) diff --git a/renderers/glm5.py b/renderers/glm5.py index 470b894..1185965 100644 --- a/renderers/glm5.py +++ b/renderers/glm5.py @@ -385,6 +385,21 @@ def emit_text_segments( ext_sampled.append(is_sampled) ext_content.append(is_content) + # The opener-token of the first new_message may also serve as + # the close of the previous assistant turn (when the model + # failed to sample the stop token itself and the bridge has to + # synthesize the boundary above). Unlike :meth:`render`, the + # bridge emits these with ``is_sampled=False, is_content=False`` + # — they are template scaffolding for the *next* step's prompt, + # not tokens the model produced *in this* step. The RL loss + # operates on ``previous_completion_ids`` (what the model + # actually sampled this round); bridge tokens belong to the + # subsequent prompt and must not be counted as "model output" + # by downstream mask consumers. This deliberate disagreement + # with ``render()`` reflects the SFT vs RL semantics: render's + # masks describe what the model *should* produce given a + # complete conversation; bridge's masks describe what it + # *actually* produced this step. 
for i, msg in enumerate(new_messages): role = msg.get("role") content = self._visible_text(msg.get("content")) From cb07cc492d05db16cdb0a65eee4e700942a512b6 Mon Sep 17 00:00:00 2001 From: hallerite Date: Wed, 27 May 2026 02:11:44 +0530 Subject: [PATCH 3/4] fix(glm): keep closing role-opener is_content=False to preserve tool-body semantics The previous commit attributed ``<|user|>`` / ``<|observation|>`` tokens that close an assistant turn with ``is_content=True`` so the body-only SFT path (``content_sft_roles``) would also train on them. That bled the role-marker into ``content_mask_for_roles({"tool"})`` and ``content_token_spans_by_role()``, which downstream consumers expect to return the actual tool response body only. Keep ``is_sampled=True`` (the actual training signal that makes the model learn to emit the stop marker after ``</tool_call>``) but revert ``is_content`` to ``False``: the role-opener is scaffolding, not body. In the body-only SFT mode the assistant stop-signal training belongs to the RL path anyway, not to SFT on tool-response bodies. Co-Authored-By: Claude Opus 4.7 (1M context) --- renderers/glm45.py | 36 ++++++++++++++++++------------------ renderers/glm5.py | 36 ++++++++++++++++++------------------ 2 files changed, 36 insertions(+), 36 deletions(-) diff --git a/renderers/glm45.py b/renderers/glm45.py index ac48d56..ee3d5d8 100644 --- a/renderers/glm45.py +++ b/renderers/glm45.py @@ -180,15 +180,16 @@ def emit_text_segments(segments: list[tuple[str, bool]], msg_idx: int, *, is_sam # role-opening token (``<|user|>`` / ``<|observation|>``) is # the inference-time stop signal that closes the assistant's # turn (see ``get_stop_token_ids``). Mark it - # ``is_sampled=True`` and ``is_content=True`` so the - # loss-mask pipeline trains the model to emit it after - # ```` (instead of continuing with another - # ```` block) in both the default ``sampled`` - # path and the body-only ``content_sft_roles`` path. 
The - # token stays attributed to this message (msg_idx=i); the - # byte stream is unchanged. ``system`` only appears at the - # start of a GLM conversation, so its opener is never the - # closer of an assistant turn. + # ``is_sampled=True`` so the loss-mask pipeline trains the + # model to emit it after ```` (instead of + # continuing with another ```` block). The token + # stays attributed to this message (msg_idx=i) and remains + # ``is_content=False`` — it's a role-marker / scaffold, not + # body bytes, so ``content_mask_for_roles({"tool"})`` and + # ``content_token_spans_by_role()`` correctly exclude it + # from "tool body" views. Byte stream is unchanged. + # ``system`` only appears at the start of a GLM conversation, + # so its opener is never the closer of an assistant turn. closes_assistant_turn = i > 0 and messages[i - 1]["role"] == "assistant" if role == "system": @@ -202,7 +203,7 @@ def emit_text_segments(segments: list[tuple[str, bool]], msg_idx: int, *, is_sam self._user, i, is_sampled=closes_assistant_turn, - is_content=closes_assistant_turn, + is_content=False, ) # ``\n`` is scaffold; ``content`` is body; the optional # ``/nothink`` suffix is scaffold the renderer injects @@ -534,13 +535,12 @@ def _render_tool( emit_text_segments, ) -> None: # Tool body bytes get ``is_content=True``; the wraps are - # scaffold. The ``<|observation|>`` role tag is normally - # scaffold too, but when the previous message is an assistant - # it doubles as the inference stop signal for that assistant's - # turn — mark it ``is_sampled=True`` and ``is_content=True`` so - # SFT trains the model to emit it after ```` in - # both the default sampled path and the body-only - # ``content_sft_roles`` path. The token stays attributed to + # scaffold. The ``<|observation|>`` role tag is scaffold too + # (``is_content=False`` so ``content_mask_for_roles({"tool"})`` + # excludes it). 
When the previous message is an assistant it + # doubles as the inference stop signal for that assistant's + # turn — mark it ``is_sampled=True`` so SFT trains the model to + # emit it after ``</tool_call>``. The token stays attributed to # this tool message; byte stream is unchanged. prev_role = messages[msg_idx - 1]["role"] if msg_idx > 0 else None closes_assistant_turn = prev_role == "assistant" @@ -550,7 +550,7 @@ def _render_tool( self._observation, msg_idx, is_sampled=closes_assistant_turn, - is_content=closes_assistant_turn, + is_content=False, ) emit_text_segments( diff --git a/renderers/glm5.py b/renderers/glm5.py index 1185965..4b7e32e 100644 --- a/renderers/glm5.py +++ b/renderers/glm5.py @@ -203,15 +203,16 @@ def emit_text_segments(segments: list[tuple[str, bool]], msg_idx: int, *, is_sam # role-opening token (``<|user|>`` / ``<|observation|>``) is # the inference-time stop signal that closes the assistant's # turn (see ``get_stop_token_ids``). Mark it - # ``is_sampled=True`` and ``is_content=True`` so the - # loss-mask pipeline trains the model to emit it after - # ``</tool_call>`` (instead of continuing with another - # ``<tool_call>`` block) in both the default ``sampled`` - # path and the body-only ``content_sft_roles`` path. The - # token stays attributed to this message (msg_idx=i); the - # byte stream is unchanged. ``system`` only appears at the - # start of a GLM conversation, so its opener is never the - # closer of an assistant turn. + # ``is_sampled=True`` so the loss-mask pipeline trains the + # model to emit it after ``</tool_call>`` (instead of + # continuing with another ``<tool_call>`` block). The token + # stays attributed to this message (msg_idx=i) and remains + # ``is_content=False`` — it's a role-marker / scaffold, not + # body bytes, so ``content_mask_for_roles({"tool"})`` and + # ``content_token_spans_by_role()`` correctly exclude it + # from "tool body" views. Byte stream is unchanged. 
+ # ``system`` only appears at the start of a GLM conversation, + # so its opener is never the closer of an assistant turn. closes_assistant_turn = i > 0 and messages[i - 1]["role"] == "assistant" if role == "system": @@ -223,7 +224,7 @@ def emit_text_segments(segments: list[tuple[str, bool]], msg_idx: int, *, is_sam self._user, i, is_sampled=closes_assistant_turn, - is_content=closes_assistant_turn, + is_content=False, ) emit_text(content, i, is_sampled=False, is_content=True) @@ -567,13 +568,12 @@ def _render_tool( emit_text_segments, ) -> None: # Tool body bytes get ``is_content=True``; the wraps are - # scaffold. The ``<|observation|>`` role tag is normally - # scaffold too, but when the previous message is an assistant - # it doubles as the inference stop signal for that assistant's - # turn — mark it ``is_sampled=True`` and ``is_content=True`` so - # SFT trains the model to emit it after ```` in - # both the default sampled path and the body-only - # ``content_sft_roles`` path. The token stays attributed to + # scaffold. The ``<|observation|>`` role tag is scaffold too + # (``is_content=False`` so ``content_mask_for_roles({"tool"})`` + # excludes it). When the previous message is an assistant it + # doubles as the inference stop signal for that assistant's + # turn — mark it ``is_sampled=True`` so SFT trains the model to + # emit it after ````. The token stays attributed to # this tool message; byte stream is unchanged. 
prev_role = messages[msg_idx - 1]["role"] if msg_idx > 0 else None closes_assistant_turn = prev_role == "assistant" @@ -583,7 +583,7 @@ def _render_tool( self._observation, msg_idx, is_sampled=closes_assistant_turn, - is_content=closes_assistant_turn, + is_content=False, ) emit_special(self._tool_response_tok, msg_idx, is_sampled=False, is_content=False) From b9fae665bf0b93f994c591193902d36f0d8b3b11 Mon Sep 17 00:00:00 2001 From: hallerite Date: Wed, 27 May 2026 02:15:31 +0530 Subject: [PATCH 4/4] style: reformat to ruff default line-length (88) The renderers repo has no ``[tool.ruff]`` override for ``line-length``, so CI uses ruff's default of 88. Locally, ruff was picking up the parent prime-rl workspace's ``line-length = 120`` (via parent-dir config discovery), which is why ``ruff format --check`` passed locally but failed CI. Reformatting the three files this PR touches at the repo's actual line length. Co-Authored-By: Claude Opus 4.7 (1M context) --- renderers/base.py | 25 ++++++++++++---- renderers/glm45.py | 63 +++++++++++++++++++++++++++++---------- renderers/glm5.py | 73 +++++++++++++++++++++++++++++++++++----------- 3 files changed, 123 insertions(+), 38 deletions(-) diff --git a/renderers/base.py b/renderers/base.py index 82d136e..e9805c4 100644 --- a/renderers/base.py +++ b/renderers/base.py @@ -220,7 +220,9 @@ class RenderedTokens: message_roles: list[str] = field(default_factory=list) multi_modal_data: "MultiModalData | None" = None - def tokens_per_message(self, n_messages: int | None = None, *, sampled_only: bool = False) -> list[int]: + def tokens_per_message( + self, n_messages: int | None = None, *, sampled_only: bool = False + ) -> list[int]: """Count rendered tokens attributed to each caller-relative message. 
``out[i]`` is the number of tokens with ``message_indices[k] == i``, @@ -1095,7 +1097,9 @@ def _patched_load(model_name_or_path: str, **kwargs): with contextlib.redirect_stdout(io.StringIO()): fastokens.patch_transformers() if not _FASTOKENS_ANNOUNCED: - logger.info("fastokens enabled — tokenizers load through the Rust BPE fast path (~10x encode speedup).") + logger.info( + "fastokens enabled — tokenizers load through the Rust BPE fast path (~10x encode speedup)." + ) _FASTOKENS_ANNOUNCED = True try: return AutoTokenizer.from_pretrained(model_name_or_path, **kwargs) @@ -1265,7 +1269,9 @@ def create_renderer( if not isinstance(config, AutoRendererConfig): cls = RENDERER_REGISTRY.get(config.name) if cls is None: - raise ValueError(f"Unknown renderer {config.name!r}. Available: {', '.join(sorted(RENDERER_REGISTRY))}") + raise ValueError( + f"Unknown renderer {config.name!r}. Available: {', '.join(sorted(RENDERER_REGISTRY))}" + ) return cls(tokenizer, config) return _resolve_auto(tokenizer, config) @@ -1679,7 +1685,9 @@ def should_preserve_past_thinking( return False # The current segment must contain a tool response for it to count # as an in-flight tool cycle. - return any(messages[j].get("role") == "tool" for j in range(last_user + 1, len(messages))) + return any( + messages[j].get("role") == "tool" for j in range(last_user + 1, len(messages)) + ) def build_trajectory_step( @@ -1701,7 +1709,9 @@ def build_trajectory_step( the completion). 
""" has_completion = len(completion_messages) > 0 - prompt_ids = renderer.render_ids(prompt_messages, tools=tools, add_generation_prompt=has_completion) + prompt_ids = renderer.render_ids( + prompt_messages, tools=tools, add_generation_prompt=has_completion + ) full_rendered = renderer.render(prompt_messages + completion_messages, tools=tools) full_ids = full_rendered.token_ids @@ -1716,6 +1726,9 @@ def build_trajectory_step( "completion_logprobs": [0.0] * len(completion_ids), "routed_experts": None, } - if full_rendered.multi_modal_data is not None and not full_rendered.multi_modal_data.is_empty(): + if ( + full_rendered.multi_modal_data is not None + and not full_rendered.multi_modal_data.is_empty() + ): out["multi_modal_data"] = full_rendered.multi_modal_data return out diff --git a/renderers/glm45.py b/renderers/glm45.py index ee3d5d8..efea47b 100644 --- a/renderers/glm45.py +++ b/renderers/glm45.py @@ -125,20 +125,26 @@ def render( sampled: list[bool] = [] content_mask: list[bool] = [] - def emit_special(token_id: int, msg_idx: int, *, is_sampled: bool, is_content: bool) -> None: + def emit_special( + token_id: int, msg_idx: int, *, is_sampled: bool, is_content: bool + ) -> None: tokens.append(token_id) indices.append(msg_idx) sampled.append(is_sampled) content_mask.append(is_content) - def emit_text(text: str, msg_idx: int, *, is_sampled: bool, is_content: bool) -> None: + def emit_text( + text: str, msg_idx: int, *, is_sampled: bool, is_content: bool + ) -> None: ids = self._encode(text) tokens.extend(ids) indices.extend([msg_idx] * len(ids)) sampled.extend([is_sampled] * len(ids)) content_mask.extend([is_content] * len(ids)) - def emit_text_segments(segments: list[tuple[str, bool]], msg_idx: int, *, is_sampled: bool) -> None: + def emit_text_segments( + segments: list[tuple[str, bool]], msg_idx: int, *, is_sampled: bool + ) -> None: """Tokenize concatenated segments as one BPE pass; per-token ``is_content`` follows each token's source segment. 
@@ -146,7 +152,9 @@ def emit_text_segments(segments: list[tuple[str, bool]], msg_idx: int, *, is_sam same way as the chat template, but attributed separately" without splitting the encode call (which could shift BPE merges at the boundary).""" - for tok_id, is_content in attribute_text_segments(self._tokenizer, segments): + for tok_id, is_content in attribute_text_segments( + self._tokenizer, segments + ): tokens.append(tok_id) indices.append(msg_idx) sampled.append(is_sampled) @@ -196,7 +204,9 @@ def emit_text_segments(segments: list[tuple[str, bool]], msg_idx: int, *, is_sam emit_special(self._system, i, is_sampled=False, is_content=False) # ``\n`` is the scaffold separator after the role tag; # the body proper is the caller-provided content. - emit_text_segments([("\n", False), (content, True)], i, is_sampled=False) + emit_text_segments( + [("\n", False), (content, True)], i, is_sampled=False + ) elif role == "user": emit_special( @@ -302,14 +312,21 @@ def bridge_to_next_turn( *, tools: list[ToolSpec] | None = None, ) -> RenderedTokens | None: - if not previous_prompt_ids or not new_messages or reject_assistant_in_extension(new_messages): + if ( + not previous_prompt_ids + or not new_messages + or reject_assistant_in_extension(new_messages) + ): return None # Same next-turn-marker scheme as GLM-5, but role markers are # followed by a literal ``\n`` in the prompt text. 
previous_ids = list(previous_prompt_ids) + list(previous_completion_ids) stop_ids = {self._endoftext, self._user, self._observation} - if not previous_ids[len(previous_prompt_ids) :] or previous_ids[-1] not in stop_ids: + if ( + not previous_ids[len(previous_prompt_ids) :] + or previous_ids[-1] not in stop_ids + ): previous_ids.append(self._endoftext) last_prev = previous_ids[-1] @@ -358,7 +375,9 @@ def emit_text_segments( *, is_sampled: bool = False, ) -> None: - for tok_id, is_content in attribute_text_segments(self._tokenizer, segments): + for tok_id, is_content in attribute_text_segments( + self._tokenizer, segments + ): ext.append(tok_id) ext_indices.append(msg_idx) ext_sampled.append(is_sampled) @@ -475,7 +494,9 @@ def _render_assistant( if (msg_idx > last_user_index or preserve_thinking) and reasoning_content: emit_special(self._think, msg_idx, is_sampled=True, is_content=True) - emit_text(reasoning_content.strip(), msg_idx, is_sampled=True, is_content=True) + emit_text( + reasoning_content.strip(), msg_idx, is_sampled=True, is_content=True + ) emit_special(self._think_end, msg_idx, is_sampled=True, is_content=True) else: emit_special(self._think, msg_idx, is_sampled=True, is_content=True) @@ -484,7 +505,9 @@ def _render_assistant( # Tool calls — keep content + \n contiguous to preserve BPE merges tool_calls = msg.get("tool_calls") or [] if content.strip() and tool_calls: - emit_text("\n" + content.strip() + "\n", msg_idx, is_sampled=True, is_content=True) + emit_text( + "\n" + content.strip() + "\n", msg_idx, is_sampled=True, is_content=True + ) elif content.strip(): emit_text("\n" + content.strip(), msg_idx, is_sampled=True, is_content=True) @@ -506,11 +529,17 @@ def _render_assistant( arguments = {} if isinstance(arguments, dict): for arg_name, arg_value in arguments.items(): - emit_special(self._arg_key, msg_idx, is_sampled=True, is_content=True) + emit_special( + self._arg_key, msg_idx, is_sampled=True, is_content=True + ) emit_text(arg_name, msg_idx, 
is_sampled=True, is_content=True) - emit_special(self._arg_key_end, msg_idx, is_sampled=True, is_content=True) + emit_special( + self._arg_key_end, msg_idx, is_sampled=True, is_content=True + ) emit_text("\n", msg_idx, is_sampled=True, is_content=True) - emit_special(self._arg_value, msg_idx, is_sampled=True, is_content=True) + emit_special( + self._arg_value, msg_idx, is_sampled=True, is_content=True + ) if isinstance(arg_value, str): emit_text(arg_value, msg_idx, is_sampled=True, is_content=True) else: @@ -520,9 +549,13 @@ def _render_assistant( is_sampled=True, is_content=True, ) - emit_special(self._arg_value_end, msg_idx, is_sampled=True, is_content=True) + emit_special( + self._arg_value_end, msg_idx, is_sampled=True, is_content=True + ) emit_text("\n", msg_idx, is_sampled=True, is_content=True) - emit_special(self._tool_call_end_tok, msg_idx, is_sampled=True, is_content=True) + emit_special( + self._tool_call_end_tok, msg_idx, is_sampled=True, is_content=True + ) def _render_tool( self, diff --git a/renderers/glm5.py b/renderers/glm5.py index 4b7e32e..a42a0af 100644 --- a/renderers/glm5.py +++ b/renderers/glm5.py @@ -145,20 +145,26 @@ def render( sampled: list[bool] = [] content_mask: list[bool] = [] - def emit_special(token_id: int, msg_idx: int, *, is_sampled: bool, is_content: bool) -> None: + def emit_special( + token_id: int, msg_idx: int, *, is_sampled: bool, is_content: bool + ) -> None: tokens.append(token_id) indices.append(msg_idx) sampled.append(is_sampled) content_mask.append(is_content) - def emit_text(text: str, msg_idx: int, *, is_sampled: bool, is_content: bool) -> None: + def emit_text( + text: str, msg_idx: int, *, is_sampled: bool, is_content: bool + ) -> None: ids = self._encode(text) tokens.extend(ids) indices.extend([msg_idx] * len(ids)) sampled.extend([is_sampled] * len(ids)) content_mask.extend([is_content] * len(ids)) - def emit_text_segments(segments: list[tuple[str, bool]], msg_idx: int, *, is_sampled: bool) -> None: + def 
emit_text_segments( + segments: list[tuple[str, bool]], msg_idx: int, *, is_sampled: bool + ) -> None: """Tokenize concatenated segments as one BPE pass; per-token ``is_content`` follows each token's source segment. @@ -166,7 +172,9 @@ def emit_text_segments(segments: list[tuple[str, bool]], msg_idx: int, *, is_sam same way as the chat template, but attributed separately" without splitting the encode call (which could shift BPE merges at the boundary).""" - for tok_id, is_content in attribute_text_segments(self._tokenizer, segments): + for tok_id, is_content in attribute_text_segments( + self._tokenizer, segments + ): tokens.append(tok_id) indices.append(msg_idx) sampled.append(is_sampled) @@ -320,7 +328,11 @@ def bridge_to_next_turn( *, tools: list[ToolSpec] | None = None, ) -> RenderedTokens | None: - if not previous_prompt_ids or not new_messages or reject_assistant_in_extension(new_messages): + if ( + not previous_prompt_ids + or not new_messages + or reject_assistant_in_extension(new_messages) + ): return None # GLM has no per-turn close token. An assistant turn ends when the @@ -330,7 +342,10 @@ def bridge_to_next_turn( # previous_completion_ids. Truncation means none is there yet. previous_ids = list(previous_prompt_ids) + list(previous_completion_ids) stop_ids = {self._endoftext, self._user, self._observation} - if not previous_ids[len(previous_prompt_ids) :] or previous_ids[-1] not in stop_ids: + if ( + not previous_ids[len(previous_prompt_ids) :] + or previous_ids[-1] not in stop_ids + ): # Truncation: synthesise <|endoftext|> as the canonical turn end. 
previous_ids.append(self._endoftext) @@ -380,7 +395,9 @@ def emit_text_segments( *, is_sampled: bool = False, ) -> None: - for tok_id, is_content in attribute_text_segments(self._tokenizer, segments): + for tok_id, is_content in attribute_text_segments( + self._tokenizer, segments + ): ext.append(tok_id) ext_indices.append(msg_idx) ext_sampled.append(is_sampled) @@ -486,7 +503,9 @@ def _render_assistant( # clear_thinking`` gate: a chat_template_kwarg surface for the # same behaviour, gated explicitly by the caller per render. include_thinking = ( - msg_idx > last_user_index or preserve_thinking or not self.config.clear_thinking + msg_idx > last_user_index + or preserve_thinking + or not self.config.clear_thinking ) and reasoning_content if include_thinking: @@ -495,9 +514,15 @@ def _render_assistant( # template-injected scaffolding. The reasoning text and the # closing ```` are what the model actually samples. emit_special(self._think, msg_idx, is_sampled=False, is_content=False) - emit_text(reasoning_content.strip(), msg_idx, is_sampled=True, is_content=True) + emit_text( + reasoning_content.strip(), msg_idx, is_sampled=True, is_content=True + ) emit_special(self._think_end, msg_idx, is_sampled=True, is_content=True) - elif self.empty_think_on_last_assistant and msg_idx > last_user_index and self.config.enable_thinking: + elif ( + self.empty_think_on_last_assistant + and msg_idx > last_user_index + and self.config.enable_thinking + ): # GLM-5.1: wrap the last assistant with an empty # even without reasoning, matching the Jinja template. 
With # ``enable_thinking=True`` the gen prompt already includes @@ -541,10 +566,16 @@ def _render_assistant( arguments = {} if isinstance(arguments, dict): for arg_name, arg_value in arguments.items(): - emit_special(self._arg_key, msg_idx, is_sampled=True, is_content=True) + emit_special( + self._arg_key, msg_idx, is_sampled=True, is_content=True + ) emit_text(arg_name, msg_idx, is_sampled=True, is_content=True) - emit_special(self._arg_key_end, msg_idx, is_sampled=True, is_content=True) - emit_special(self._arg_value, msg_idx, is_sampled=True, is_content=True) + emit_special( + self._arg_key_end, msg_idx, is_sampled=True, is_content=True + ) + emit_special( + self._arg_value, msg_idx, is_sampled=True, is_content=True + ) if isinstance(arg_value, str): emit_text(arg_value, msg_idx, is_sampled=True, is_content=True) else: @@ -554,8 +585,12 @@ def _render_assistant( is_sampled=True, is_content=True, ) - emit_special(self._arg_value_end, msg_idx, is_sampled=True, is_content=True) - emit_special(self._tool_call_end_tok, msg_idx, is_sampled=True, is_content=True) + emit_special( + self._arg_value_end, msg_idx, is_sampled=True, is_content=True + ) + emit_special( + self._tool_call_end_tok, msg_idx, is_sampled=True, is_content=True + ) def _render_tool( self, @@ -586,9 +621,13 @@ def _render_tool( is_content=False, ) - emit_special(self._tool_response_tok, msg_idx, is_sampled=False, is_content=False) + emit_special( + self._tool_response_tok, msg_idx, is_sampled=False, is_content=False + ) emit_text(content, msg_idx, is_sampled=False, is_content=True) - emit_special(self._tool_response_end_tok, msg_idx, is_sampled=False, is_content=False) + emit_special( + self._tool_response_end_tok, msg_idx, is_sampled=False, is_content=False + ) class GLM51Renderer(GLM5Renderer):