feat(sft): default loss_mask to renderer's sampled_mask#2644
Conversation
The existing LossMaskConfig was per-role booleans, AND'd against the renderer's sampled_mask in build_training_sample. With renderer fixes landing in PrimeIntellect-ai/renderers#66, the per-token sampled_mask is now the authoritative "what the model would produce at inference" signal — including chat-template stop tokens (e.g. GLM's <|user|> / <|observation|> after a tool-calling assistant turn) whose structural attribution lives in the *next* message's span. AND-ing those with a role filter that says "assistant only" silently masks them out of the loss, which is exactly the bug that caused GLM-4.5-Air SFT to dump 50-283 parallel tool_calls per turn on SWE-bench Verified. Refactor LossMaskConfig to a discriminated union mirroring prime-rl's existing ``Literal | dict-class`` pattern (cf. ``fused_lm_head_token_chunk_size: int | Literal["auto", "disabled"]``): - ``"sampled"`` (new default, recommended with renderers): every token the renderer marks ``is_sampled=True`` is trainable, regardless of role. Stop tokens whose attribution lives in the next message's span are correctly trained. - ``"all"``: every renderer token contributes to the loss. Debugging. - ``LossMaskRolesConfig`` (renamed from the old per-role LossMaskConfig): AND the renderer's ``is_sampled`` with per-role booleans. Strict opt-in for callers who want to restrict supervision to a subset of roles even when other roles' tokens are model-sampled. Existing TOML configs with ``[data.loss_mask] assistant = true`` parse as ``LossMaskRolesConfig`` via the union — old behaviour preserved. Configs that don't override ``loss_mask`` pick up the new ``"sampled"`` default, which is the fix. The chat-template fallback path (no renderer) has no ``sampled_mask`` signal so it always uses ``LossMaskRolesConfig()`` defaults regardless of the configured mode. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
When a renderer is configured its token stream is authoritative — it mirrors the model's chat template (which carries no EOS by design) and trains the real stop signals via sampled_mask (e.g. GLM's turn-closing <|observation|> / <|user|>). Stop injecting an EOS the renderer didn't emit, and don't assert EOS-in-targets on the renderer path. The chat-template fallback (no sampled_mask) is unchanged: it still appends EOS so the model learns to stop.
| STACKING_DATASET_BUCKET_TIMEOUT = 10 | ||
|
|
||
|
|
||
| def _always_true_role_filter(message: dict) -> bool: |
| return True | ||
|
|
||
|
|
||
| def _role_filter_for(cfg: LossMaskRolesConfig): |
| # boolean lookup. The chat-template fallback below has no | ||
| # ``is_sampled`` signal so we treat sentinel modes as the | ||
| # ``LossMaskRolesConfig()`` defaults (assistant-only). | ||
| loss_mask_cfg = self.loss_mask_config |
There was a problem hiding this comment.
I feel like all this logic, plus the always_true_role_filter, should be in role_filter_for which just returns a tuple of the renderer_role_filter and fallback_role_filter. I found this split in logic between the role filter helper and this dataset quite confusing.
There was a problem hiding this comment.
Cursor Bugbot has reviewed your changes and found 1 potential issue.
❌ Bugbot Autofix is OFF. To automatically fix reported issues with cloud agents, enable autofix in the Cursor dashboard.
Reviewed by Cursor Bugbot for commit de439e3. Configure here.
| fallback_role_filter = _role_filter_for(LossMaskRolesConfig()) | ||
| elif loss_mask_cfg == "all": | ||
| renderer_role_filter = _always_true_role_filter | ||
| fallback_role_filter = _role_filter_for(LossMaskRolesConfig()) |
There was a problem hiding this comment.
"all" fallback silently restricts loss to assistant-only tokens
Medium Severity
When loss_mask = "all" and no renderer is configured, fallback_role_filter is set to _role_filter_for(LossMaskRolesConfig()), which defaults to assistant-only masking. The "all" mode is documented as training on "every…token…regardless of role," but this fallback silently contradicts that. The chat-template build_incremental_token_mask uses role_to_mask directly (no is_sampled dependency), so _always_true_role_filter would correctly honor the "all" semantic in the fallback path.
Additional Locations (1)
Reviewed by Cursor Bugbot for commit de439e3. Configure here.
| @@ -246,15 +290,15 @@ def should_mask(message: dict) -> bool: | |||
| input_ids, loss_mask = build_training_sample( | |||
There was a problem hiding this comment.
i dislike the extremely different names between build_training_sample and build_incremental_token_mask even though they do the same thing, just with a different backend. Not necessarily important for this PR, but I didn't want to forget it


Motivation
Some chat templates put the per-turn stop signal in the next message's span rather than inside the assistant turn itself. GLM-4.5 and GLM-5 are the immediate cases: after a tool-calling assistant message, the byte that closes the assistant's turn is the
<|observation|>(or<|user|>) at the start of the next message — there is no<|im_end|>-style terminator inside the assistant span. vLLM uses that token as a stop id at inference (get_stop_token_ids), but the renderer attributes it to the next (tool / user) message structurally.The previous SFT loss-mask pipeline AND'd the renderer's
sampled_maskwith a per-role filter (assistant=True, tool=False, user=False, system=Falseby default). For the GLM family this meant: the model would-sample<|observation|>at inference (sampled_mask=True), but it's attributed to the tool message (role_to_mask(tool)=False), so it was masked out of the loss and the model never learned to emit it. At inference the natural continuation of</tool_call>is another<tool_call>, so the model dumps parallel tool_calls instead of closing its turn. In a recent GLM-4.5-Air-Base SFT run, 26% of assistant turns had >1 tool call (max observed: 283 parallel calls in a single turn); on SWE-bench Verified the model'spass@1measurably suffered — rollouts with extreme multi-tool dumps solved at roughly half the rate of the bounded ones.The companion renderer fix (PrimeIntellect-ai/renderers#66) updates
glm45.py/glm5.pyto mark that stop opener asis_sampled=True. This PR teachesbuild_training_sample's caller (SFTDataset) to actually use that signal — by making thesampled_mask-only path the default rather than always AND-ing with a role filter.Changes
Refactor
LossMaskConfigfrom "a class of per-role booleans" to a discriminated union, mirroring prime-rl's existingLiteral | dict-classpattern (cf.fused_lm_head_token_chunk_size: int | Literal[\"auto\", \"disabled\"]inconfigs/trainer.py):\"sampled\"(new default, recommended with renderers): every token the renderer marksis_sampled=Trueis trainable, regardless of role. Stop tokens that the chat template attributes to the next message's span are correctly trained.\"all\": every renderer token contributes to the loss. Mostly useful for debugging.LossMaskRolesConfig(renamed from the old per-roleLossMaskConfigbody): AND the renderer'sis_sampledwith per-role booleans. Use this when you specifically want to restrict supervision to a subset of roles even when other roles' tokens are model-sampled.SFTDatasetdispatches:\"sampled\"→role_to_mask=None(relies onbuild_training_sample's newrole_to_maskdefault added in renderers#66, which falls back tosampled_mask).\"all\"→role_to_mask=lambda m: True.LossMaskRolesConfig→ callable that looks up the per-role bool, same as before.The chat-template fallback path (no renderer registered) has no
sampled_masksignal so it falls back toLossMaskRolesConfig()defaults regardless of the configured mode.Backward compatibility
Existing TOML configs of the form:
still parse: pydantic's union dispatch tries
Literal[\"sampled\", \"all\"]first (fails for a dict), thenLossMaskRolesConfig(succeeds). Per-role behavior is preserved verbatim.Configs that don't set
loss_maskat all pick up the new\"sampled\"default. This is a deliberate behavior change for SFT runs that relied on the implicit assistant-only filter — but the new default is what fixes the bug for renderers whose stop signal lives outside the assistant span. For configs that want to keep the old behavior, set[data.loss_mask] assistant = trueexplicitly.Test plan
tests/unit/train/sft) and orchestrator SFT-trajectory tests pass under both old (per-role) and new (\"sampled\") configs.\"sampled\",\"all\", and{\"assistant\": true, \"tool\": true}all parse to the expected variant.SFTDataConfig().loss_mask == \"sampled\".Companion PRs
<|user|>/<|observation|>openers asis_sampled=Truewhen they close an assistant turn, and makesrole_to_maskoptional inbuild_training_sample(falls back tosampled_mask).🤖 Generated with Claude Code
Note
Medium Risk
Default
loss_maskbehavior changes for configs that omit it (supervision can include non-assistant sampled tokens), which directly affects what the model learns during SFT; explicit per-role TOML remains backward compatible.Overview
SFT loss masking is refactored so renderer-based training can supervise every token the renderer marks as model-sampled (
is_sampled), not only tokens inside assistant spans. That fixes templates (e.g. GLM) where turn-ending stop tokens live on the next message’s role.Config:
loss_maskbecomesLiteral["sampled", "all"] | LossMaskRolesConfig(per-role settings moved toLossMaskRolesConfig). The default changes from implicit assistant-only role masking to"sampled". Existing TOML[data.loss_mask]tables still parse asLossMaskRolesConfig.Data pipeline:
SFTDatasetmaps"sampled"→role_to_mask=Noneforbuild_training_sample,"all"→ train all roles, and role config → the previous per-role AND. The tokenizer fallback (no renderer) still uses assistant-only role defaults for"sampled"/"all"because there is nois_sampledsignal. EOS is only auto-appended and required intarget_idson the chat-template path; renderer streams are left as emitted.Reviewed by Cursor Bugbot for commit de439e3. Bugbot is set up for automated code reviews on this repo. Configure here.