feat(sft): default loss_mask to renderer's sampled_mask#2644

Open
hallerite wants to merge 3 commits into main from feat/sft-loss-mask-sampled-default

@hallerite hallerite commented May 26, 2026

Motivation

Some chat templates put the per-turn stop signal in the next message's span rather than inside the assistant turn itself. GLM-4.5 and GLM-5 are the immediate cases: after a tool-calling assistant message, the token that closes the assistant's turn is the <|observation|> (or <|user|>) at the start of the next message; there is no <|im_end|>-style terminator inside the assistant span. vLLM uses that token as a stop id at inference (get_stop_token_ids), but the renderer structurally attributes it to the next (tool / user) message.

The previous SFT loss-mask pipeline AND'd the renderer's sampled_mask with a per-role filter (assistant=True, tool=False, user=False, system=False by default). For the GLM family this meant that <|observation|>, which the model would sample at inference (sampled_mask=True), is attributed to the tool message (role_to_mask(tool)=False), so it was masked out of the loss and the model never learned to emit it. Without that stop token, the natural continuation of </tool_call> at inference is another <tool_call>, so the model dumps parallel tool_calls instead of closing its turn. In a recent GLM-4.5-Air-Base SFT run, 26% of assistant turns had >1 tool call (max observed: 283 parallel calls in a single turn); on SWE-bench Verified the model's pass@1 measurably suffered, with rollouts containing extreme multi-tool dumps solving at roughly half the rate of the bounded ones.
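To make the failure concrete, here is a minimal sketch of the old masking rule on a toy GLM-style token stream. The `Token` tuple, the token strings, and the two helper functions are hypothetical illustrations rather than the renderers API; only the "AND with a per-role filter" rule and the role defaults come from the description above.

```python
from typing import List, NamedTuple


class Token(NamedTuple):
    text: str
    role: str         # role of the message whose span contains this token
    is_sampled: bool  # would the model emit this token at inference?


# Toy stream (assumed layout): an assistant tool call followed by a tool
# message. The <|observation|> opener sits in the tool span, yet it is the
# stop token the model has to emit to close its own turn.
STREAM = [
    Token("<|assistant|>", "assistant", False),
    Token("<tool_call>", "assistant", True),
    Token("</tool_call>", "assistant", True),
    Token("<|observation|>", "tool", True),
    Token("tool output", "tool", False),
]

# Per-role defaults quoted in the PR description.
ROLE_DEFAULTS = {"assistant": True, "tool": False, "user": False, "system": False}


def old_loss_mask(tokens: List[Token]) -> List[bool]:
    """Old behaviour: sampled_mask AND per-role filter."""
    return [t.is_sampled and ROLE_DEFAULTS[t.role] for t in tokens]


def sampled_loss_mask(tokens: List[Token]) -> List[bool]:
    """New "sampled" behaviour: trust the renderer's is_sampled flag alone."""
    return [t.is_sampled for t in tokens]


if __name__ == "__main__":
    for tok, old, new in zip(STREAM, old_loss_mask(STREAM), sampled_loss_mask(STREAM)):
        print(f"{tok.text:<16} old={old!s:<5} sampled={new}")
```

In this toy stream the only position where the two masks differ is the <|observation|> opener, which is exactly the token the old pipeline dropped from the loss.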

The companion renderer fix (PrimeIntellect-ai/renderers#66) updates glm45.py / glm5.py to mark that stop opener as is_sampled=True. This PR teaches build_training_sample's caller (SFTDataset) to actually use that signal — by making the sampled_mask-only path the default rather than always AND-ing with a role filter.

Changes

Refactor LossMaskConfig from "a class of per-role booleans" to a discriminated union, mirroring prime-rl's existing Literal | dict-class pattern (cf. fused_lm_head_token_chunk_size: int | Literal["auto", "disabled"] in configs/trainer.py):

LossMaskConfig: TypeAlias = Literal["sampled", "all"] | LossMaskRolesConfig

  • "sampled" (new default, recommended with renderers): every token the renderer marks is_sampled=True is trainable, regardless of role. Stop tokens that the chat template attributes to the next message's span are correctly trained.
  • "all": every renderer token contributes to the loss. Mostly useful for debugging.
  • LossMaskRolesConfig (renamed from the old per-role LossMaskConfig body): AND the renderer's is_sampled with per-role booleans. Use this when you specifically want to restrict supervision to a subset of roles even when other roles' tokens are model-sampled.
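For readers who want to see the shape of the union without pulling in pydantic, the following is a standard-library sketch. `parse_loss_mask` is a hypothetical hand-rolled stand-in for the validation that pydantic performs in the real config, and the field names on `LossMaskRolesConfig` are assumed from the role defaults listed in the Motivation section.

```python
from dataclasses import dataclass
from typing import Literal, Union


@dataclass(frozen=True)
class LossMaskRolesConfig:
    # Assumed fields; defaults follow the assistant-only filter described above.
    system: bool = False
    user: bool = False
    assistant: bool = True
    tool: bool = False


LossMaskConfig = Union[Literal["sampled", "all"], LossMaskRolesConfig]


def parse_loss_mask(raw: object = "sampled") -> LossMaskConfig:
    """Map a raw config value onto one of the three variants.

    Strings must be one of the sentinel modes; dicts are treated as
    per-role booleans. Omitting the value yields the new default.
    """
    if isinstance(raw, str):
        if raw in ("sampled", "all"):
            return raw  # sentinel mode
        raise ValueError(f"unknown loss_mask mode: {raw!r}")
    if isinstance(raw, dict):
        return LossMaskRolesConfig(**raw)  # per-role variant
    raise TypeError(f"unsupported loss_mask value: {raw!r}")
```

Usage: `parse_loss_mask()` returns `"sampled"`, while `parse_loss_mask({"assistant": True, "tool": True})` returns a `LossMaskRolesConfig` with the remaining roles left at their defaults.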

SFTDataset dispatches:

  • "sampled" → role_to_mask=None (relies on build_training_sample's new role_to_mask default added in renderers#66, which falls back to sampled_mask).
  • "all" → role_to_mask=lambda m: True.
  • LossMaskRolesConfig → callable that looks up the per-role bool, same as before.

The chat-template fallback path (no renderer registered) has no sampled_mask signal so it falls back to LossMaskRolesConfig() defaults regardless of the configured mode.
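The dispatch and its fallback can be sketched as a single helper that returns both filters. This is an illustration under stated assumptions, not the code in the diff: `resolve_role_filters` is a hypothetical name, a plain dict stands in for `LossMaskRolesConfig`, and reusing the configured per-role flags on the fallback path is assumed from the "same as before" wording.

```python
from typing import Callable, Dict, Optional, Tuple, Union

Message = Dict[str, str]
RoleFilter = Callable[[Message], bool]

# Stand-in for LossMaskRolesConfig() defaults (assistant-only).
ASSISTANT_ONLY: Dict[str, bool] = {
    "system": False,
    "user": False,
    "assistant": True,
    "tool": False,
}


def _lookup(flags: Dict[str, bool]) -> RoleFilter:
    """Build a filter that looks up the per-role bool of a message."""
    merged = {**ASSISTANT_ONLY, **flags}
    return lambda message: merged.get(message["role"], False)


def resolve_role_filters(
    loss_mask: Union[str, Dict[str, bool]],
) -> Tuple[Optional[RoleFilter], RoleFilter]:
    """Return (renderer_role_filter, fallback_role_filter)."""
    if isinstance(loss_mask, dict):
        per_role = _lookup(loss_mask)
        return per_role, per_role
    if loss_mask == "sampled":
        # None tells build_training_sample to fall back to sampled_mask.
        return None, _lookup(ASSISTANT_ONLY)
    if loss_mask == "all":
        # Fallback has no is_sampled signal, so it stays assistant-only.
        return (lambda message: True), _lookup(ASSISTANT_ONLY)
    raise ValueError(f"unknown loss_mask mode: {loss_mask!r}")
```

Returning the pair from one place keeps the renderer path and the chat-template path visibly in sync for each mode.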

Backward compatibility

Existing TOML configs of the form:

[data.loss_mask]
assistant = true
tool = false

still parse: pydantic's union dispatch tries Literal["sampled", "all"] first (fails for a dict), then LossMaskRolesConfig (succeeds). Per-role behavior is preserved verbatim.

Configs that don't set loss_mask at all pick up the new "sampled" default. This is a deliberate behavior change for SFT runs that relied on the implicit assistant-only filter, but the new default is what fixes the bug for renderers whose stop signal lives outside the assistant span. For configs that want to keep the old behavior, set [data.loss_mask] assistant = true explicitly.

Test plan

  • Existing SFT unit tests (tests/unit/train/sft) and orchestrator SFT-trajectory tests pass under both old (per-role) and new ("sampled") configs.
  • Type-adapter round-trip: "sampled", "all", and {"assistant": true, "tool": true} all parse to the expected variant.
  • Default SFTDataConfig().loss_mask == "sampled".

Companion PRs

  • PrimeIntellect-ai/renderers#66 — marks GLM <|user|> / <|observation|> openers as is_sampled=True when they close an assistant turn, and makes role_to_mask optional in build_training_sample (falls back to sampled_mask).

🤖 Generated with Claude Code


Note

Medium Risk
Default loss_mask behavior changes for configs that omit it (supervision can include non-assistant sampled tokens), which directly affects what the model learns during SFT; explicit per-role TOML remains backward compatible.

Overview
SFT loss masking is refactored so renderer-based training can supervise every token the renderer marks as model-sampled (is_sampled), not only tokens inside assistant spans. That fixes templates (e.g. GLM) where turn-ending stop tokens live on the next message’s role.

Config: loss_mask becomes Literal["sampled", "all"] | LossMaskRolesConfig (per-role settings moved to LossMaskRolesConfig). The default changes from implicit assistant-only role masking to "sampled". Existing TOML [data.loss_mask] tables still parse as LossMaskRolesConfig.

Data pipeline: SFTDataset maps "sampled" → role_to_mask=None for build_training_sample, "all" → train all roles, and role config → the previous per-role AND. The tokenizer fallback (no renderer) still uses assistant-only role defaults for "sampled" / "all" because there is no is_sampled signal. EOS is only auto-appended and required in target_ids on the chat-template path; renderer streams are left as emitted.

Reviewed by Cursor Bugbot for commit de439e3. Bugbot is set up for automated code reviews on this repo.

hallerite and others added 2 commits May 27, 2026 02:29
The existing LossMaskConfig was per-role booleans, AND'd against the
renderer's sampled_mask in build_training_sample. With renderer fixes
landing in PrimeIntellect-ai/renderers#66, the per-token sampled_mask
is now the authoritative "what the model would produce at inference"
signal — including chat-template stop tokens (e.g. GLM's <|user|> /
<|observation|> after a tool-calling assistant turn) whose structural
attribution lives in the *next* message's span. AND-ing those with a
role filter that says "assistant only" silently masks them out of the
loss, which is exactly the bug that caused GLM-4.5-Air SFT to dump
50-283 parallel tool_calls per turn on SWE-bench Verified.

Refactor LossMaskConfig to a discriminated union mirroring prime-rl's
existing ``Literal | dict-class`` pattern (cf.
``fused_lm_head_token_chunk_size: int | Literal["auto", "disabled"]``):

- ``"sampled"`` (new default, recommended with renderers): every token
  the renderer marks ``is_sampled=True`` is trainable, regardless of
  role. Stop tokens whose attribution lives in the next message's span
  are correctly trained.
- ``"all"``: every renderer token contributes to the loss. Debugging.
- ``LossMaskRolesConfig`` (renamed from the old per-role LossMaskConfig):
  AND the renderer's ``is_sampled`` with per-role booleans. Strict
  opt-in for callers who want to restrict supervision to a subset of
  roles even when other roles' tokens are model-sampled.

Existing TOML configs with ``[data.loss_mask] assistant = true`` parse
as ``LossMaskRolesConfig`` via the union — old behaviour preserved.
Configs that don't override ``loss_mask`` pick up the new ``"sampled"``
default, which is the fix.

The chat-template fallback path (no renderer) has no ``sampled_mask``
signal so it always uses ``LossMaskRolesConfig()`` defaults regardless
of the configured mode.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

When a renderer is configured its token stream is authoritative — it
mirrors the model's chat template (which carries no EOS by design) and
trains the real stop signals via sampled_mask (e.g. GLM's turn-closing
<|observation|> / <|user|>). Stop injecting an EOS the renderer didn't
emit, and don't assert EOS-in-targets on the renderer path. The
chat-template fallback (no sampled_mask) is unchanged: it still appends
EOS so the model learns to stop.
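
A rough sketch of the EOS rule this commit message describes, under stated assumptions: `finalize_sample` and its parameters are hypothetical names, and marking the appended EOS as trainable is inferred from "so the model learns to stop" rather than read from the diff.

```python
from typing import List, Sequence, Tuple


def finalize_sample(
    input_ids: Sequence[int],
    loss_mask: Sequence[bool],
    eos_token_id: int,
    used_renderer: bool,
) -> Tuple[List[int], List[bool]]:
    """Apply the EOS policy to one tokenized sample."""
    ids, mask = list(input_ids), list(loss_mask)
    if used_renderer:
        # Renderer stream is authoritative: leave it exactly as emitted.
        return ids, mask
    # Chat-template fallback: make sure the sample ends with EOS.
    if not ids or ids[-1] != eos_token_id:
        ids.append(eos_token_id)
        mask.append(True)  # assumed: the appended EOS is supervised
    return ids, mask
```

With `used_renderer=True` the sample comes back unchanged; with `used_renderer=False` a missing trailing EOS is appended along with a matching mask entry, so the two lists stay the same length.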
STACKING_DATASET_BUCKET_TIMEOUT = 10


def _always_true_role_filter(message: dict) -> bool:
Collaborator

why private

return True


def _role_filter_for(cfg: LossMaskRolesConfig):
Collaborator

why private

# boolean lookup. The chat-template fallback below has no
# ``is_sampled`` signal so we treat sentinel modes as the
# ``LossMaskRolesConfig()`` defaults (assistant-only).
loss_mask_cfg = self.loss_mask_config
Collaborator

I feel like all this logic, plus the always_true_role_filter, should be in role_filter_for which just returns a tuple of the renderer_role_filter and fallback_role_filter. I found this split in logic between the role filter helper and this dataset quite confusing.

@cursor cursor Bot left a comment

Cursor Bugbot has reviewed your changes and found 1 potential issue.

fallback_role_filter = _role_filter_for(LossMaskRolesConfig())
elif loss_mask_cfg == "all":
renderer_role_filter = _always_true_role_filter
fallback_role_filter = _role_filter_for(LossMaskRolesConfig())

"all" fallback silently restricts loss to assistant-only tokens

Medium Severity

When loss_mask = "all" and no renderer is configured, fallback_role_filter is set to _role_filter_for(LossMaskRolesConfig()), which defaults to assistant-only masking. The "all" mode is documented as training on "every…token…regardless of role," but this fallback silently contradicts that. The chat-template build_incremental_token_mask uses role_to_mask directly (no is_sampled dependency), so _always_true_role_filter would correctly honor the "all" semantic in the fallback path.

@@ -246,15 +290,15 @@ def should_mask(message: dict) -> bool:
input_ids, loss_mask = build_training_sample(
Collaborator

I dislike the extremely different names of build_training_sample and build_incremental_token_mask, even though they do the same thing with different backends. Not necessarily important for this PR, but I didn't want to forget it.
