feat(glm): make next-message role marker trainable after assistant turns #66
Changes from all commits
0dbc5ff
d4dbe0a
cb07cc4
b9fae66
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1098,8 +1098,7 @@ def _patched_load(model_name_or_path: str, **kwargs): | |
| fastokens.patch_transformers() | ||
| if not _FASTOKENS_ANNOUNCED: | ||
| logger.info( | ||
| "fastokens enabled — tokenizers load through the Rust BPE " | ||
| "fast path (~10x encode speedup)." | ||
| "fastokens enabled — tokenizers load through the Rust BPE fast path (~10x encode speedup)." | ||
| ) | ||
| _FASTOKENS_ANNOUNCED = True | ||
| try: | ||
|
|
@@ -1169,8 +1168,8 @@ def load_tokenizer( | |
| def _populate_registry(): | ||
| if RENDERER_REGISTRY: | ||
| return | ||
| from renderers.default import DefaultRenderer | ||
| from renderers.deepseek_v3 import DeepSeekV3Renderer | ||
| from renderers.default import DefaultRenderer | ||
| from renderers.glm5 import GLM5Renderer, GLM51Renderer | ||
| from renderers.glm45 import GLM45Renderer | ||
| from renderers.gpt_oss import GptOssRenderer | ||
|
|
@@ -1271,8 +1270,7 @@ def create_renderer( | |
| cls = RENDERER_REGISTRY.get(config.name) | ||
| if cls is None: | ||
| raise ValueError( | ||
| f"Unknown renderer {config.name!r}. " | ||
| f"Available: {', '.join(sorted(RENDERER_REGISTRY))}" | ||
| f"Unknown renderer {config.name!r}. Available: {', '.join(sorted(RENDERER_REGISTRY))}" | ||
| ) | ||
| return cls(tokenizer, config) | ||
|
|
||
|
|
@@ -1345,7 +1343,7 @@ def build_training_sample( | |
| renderer: Renderer, | ||
| messages: list[Message], | ||
| *, | ||
| role_to_mask: Callable[[Message], bool], | ||
| role_to_mask: Callable[[Message], bool] | None = None, | ||
| tools: list[ToolSpec] | None = None, | ||
| content_sft_roles: "set[str] | frozenset[str] | None" = None, | ||
| ) -> tuple[list[int], list[bool]]: | ||
|
|
@@ -1354,15 +1352,31 @@ def build_training_sample( | |
| Single render() call + message_indices → per-token mask. | ||
| Replaces build_incremental_token_mask (O(N) renders → O(1)). | ||
|
|
||
| When the renderer populates ``rendered.sampled_mask``, the loss mask | ||
| is the AND of role-based attribution and the sampled signal: only | ||
| tokens the model would have produced at inference are trainable. | ||
| This keeps SFT byte-aligned with the RL trajectory mask (where the | ||
| prompt / completion split achieves the same effect structurally). | ||
| When ``role_to_mask`` is omitted, ``loss_mask`` is the renderer's | ||
| ``sampled_mask`` directly: every token the model would have | ||
| produced at inference is trainable, regardless of which message | ||
| it's attributed to. This is the recommended default for renderer | ||
| callers — the renderer owns the per-token "is this model output" | ||
| signal, so role-level filtering becomes a downstream constraint | ||
| rather than a precondition. (Some role markers — e.g. GLM | ||
| ``<|user|>`` / ``<|observation|>`` after a tool-calling assistant | ||
| turn — *are* sampled by the model at inference and live inside the | ||
| next message's span; ``sampled_mask`` captures that, but a | ||
| naive role filter would mask them out.) | ||
|
|
||
| When ``role_to_mask`` is provided, ``loss_mask`` is the AND of the | ||
| role-based attribution and the sampled signal: only tokens the | ||
| model would have produced at inference AND attributed to a | ||
| trainable role pass through. Useful when the caller needs to | ||
| restrict training to a specific role (e.g. assistant-only) even on | ||
| a renderer whose ``sampled_mask`` already covers other roles. | ||
|
|
||
| Renderers that don't populate ``sampled_mask`` (empty list) fall | ||
| back to attribution-only masking — every token attributed to a | ||
| trainable role is trained on, including template-injected | ||
| ``<|im_start|>role\\n`` openers. | ||
| ``<|im_start|>role\\n`` openers. In this fallback mode | ||
| ``role_to_mask`` is required; calling without it raises | ||
| ``ValueError``. | ||
|
|
||
| ``content_sft_roles`` opts in additional roles for "body-only" | ||
| supervision: for every message whose role is in this set, tokens | ||
|
|
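To make the GLM case in the docstring concrete, here is a minimal sketch on toy data. The token layout, the per-token `sampled` flags, and the helper `assistant_only` are assumptions made up for illustration; they are not real GLM token IDs or the renderer's actual output.

```python
# Toy three-message conversation (hypothetical data, not real GLM output).
messages = [
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "hello"},
    {"role": "user", "content": "bye"},
]

# Each entry: (token text, index of the message it is attributed to,
# whether the model would have sampled it at inference). Assumed layout.
rendered = [
    ("<|user|>", 0, False),       # opening marker, injected by the template
    ("hi", 0, False),
    ("<|assistant|>", 1, False),  # generation prompt, injected by the template
    ("hello", 1, True),
    ("<|user|>", 2, True),        # ends the assistant turn, but sits in message 2's span
    ("bye", 2, False),
]


def assistant_only(msg):
    """Hypothetical role filter: train on assistant messages only."""
    return msg["role"] == "assistant"


# Mode 1: no role filter, the sampled signal alone decides.
sampled_only = [sampled for _, _, sampled in rendered]

# Mode 2: sampled signal AND role-based attribution.
role_and_sampled = [
    sampled and assistant_only(messages[idx]) for _, idx, sampled in rendered
]

trainable_sampled = [tok for (tok, _, _), keep in zip(rendered, sampled_only) if keep]
trainable_filtered = [tok for (tok, _, _), keep in zip(rendered, role_and_sampled) if keep]

print(trainable_sampled)
print(trainable_filtered)
```

Under these assumptions the trailing `<|user|>` marker survives only in the first mode: it is sampled by the model but attributed to a user message, so an assistant-only filter drops it.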
@@ -1393,6 +1407,13 @@ def build_training_sample( | |
| else: | ||
| body_roles = frozenset() | ||
|
|
||
| if role_to_mask is None and not has_sampled_info: | ||
| raise ValueError( | ||
| "role_to_mask is required when the renderer does not populate " | ||
| "sampled_mask. Pass an explicit role filter (e.g. " | ||
| "lambda m: m['role'] == 'assistant') for this renderer." | ||
| ) | ||
|
|
||
| loss_mask: list[bool] = [] | ||
| for k, msg_idx in enumerate(rendered.message_indices): | ||
| if msg_idx < 0: | ||
|
|
@@ -1408,6 +1429,11 @@ def build_training_sample( | |
| continue | ||
| if has_sampled_info and not rendered.sampled_mask[k]: | ||
| loss_mask.append(False) | ||
| elif role_to_mask is None: | ||
| # sampled_mask alone gates the loss when no role filter is | ||
| # supplied. ``sampled_mask[k]`` is True here (handled by the | ||
| # branch above), so this token is trainable. | ||
| loss_mask.append(True) | ||
|
cursor[bot] marked this conversation as resolved.
Stop tokens still excluded from loss (High Severity). Reviewed by Cursor Bugbot for commit cb07cc4.
This is the natural consequence of GLM's chat-template design, where the per-turn stop signal lives inside the next message's span; there is no dedicated
The two specialized modes you flag fall out of this:
In short, the GLM chat template can't express "this token belongs to message N but counts as message N-1's output for training" cleanly, and the choice we made here (keep structural attribution, use |
||
| else: | ||
| loss_mask.append(role_to_mask(msg)) | ||
| return rendered.token_ids, loss_mask | ||
|
|
||
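For readers skimming the hunks, the mask-combination rule this PR describes can be sketched as a standalone function. This is a simplified reading of the diff, not the project's implementation: `combine_masks` is a hypothetical name, `content_sft_roles` handling is omitted, and treating tokens with a negative message index as untrainable is an assumption, since that branch is elided in the diff.

```python
from typing import Callable, Optional


def combine_masks(
    messages: list,
    message_indices: list,
    sampled_mask: list,
    role_to_mask: Optional[Callable[[dict], bool]] = None,
) -> list:
    """Per-token loss mask from attribution and the sampled signal (sketch)."""
    has_sampled_info = len(sampled_mask) > 0

    # Without a sampled signal there is nothing to fall back on,
    # so an explicit role filter is required.
    if role_to_mask is None and not has_sampled_info:
        raise ValueError("role_to_mask is required when sampled_mask is empty")

    loss_mask = []
    for k, msg_idx in enumerate(message_indices):
        if msg_idx < 0:
            # Assumption: tokens attributed to no message are never trained on.
            loss_mask.append(False)
            continue
        if has_sampled_info and not sampled_mask[k]:
            # Template-injected token: excluded in every mode.
            loss_mask.append(False)
        elif role_to_mask is None:
            # Sampled and no role filter supplied: trainable.
            loss_mask.append(True)
        else:
            # Role filter decides (ANDed with sampled when that signal exists).
            loss_mask.append(role_to_mask(messages[msg_idx]))
    return loss_mask


def assistant_only(msg):
    """Hypothetical role filter."""
    return msg["role"] == "assistant"


msgs = [{"role": "user"}, {"role": "assistant"}, {"role": "user"}]
indices = [-1, 0, 1, 1, 2]
sampled = [False, False, True, True, True]

mask_default = combine_masks(msgs, indices, sampled)
mask_filtered = combine_masks(msgs, indices, sampled, assistant_only)
mask_fallback = combine_masks(msgs, indices, [], assistant_only)
```

The last token in this toy input is sampled but attributed to a user message, so it is trainable in the default mode and masked out once `assistant_only` is supplied. Calling `combine_masks(msgs, indices, [])` raises `ValueError`, mirroring the guard added in this PR.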

