From 8c71d77ba3df430205d5f0d0d29b1879e788d912 Mon Sep 17 00:00:00 2001 From: Matt Van Horn <455140+mvanhorn@users.noreply.github.com> Date: Sun, 10 May 2026 15:48:20 -0700 Subject: [PATCH] feat(diagnostics): structured bridge-failure reasons bridge_to_next_turn returns list[int] | None. When the bridge bails, the caller has no way to learn which of the 6 README-documented failure modes hit. Adds renderers.diagnostics.diagnose_bridge that returns a typed BridgeDiagnostic so prime-rl and verifiers can observe bridge health per-turn during rollouts instead of discovering 32/64 silent drops after training. - renderers/diagnostics.py (~245 LOC): BridgeFailureReason StrEnum covering all 6 documented modes plus UNKNOWN_TEMPLATE_CLOSE for the DefaultRenderer fall-through. BridgeDiagnostic dataclass carries the reason, message_index, token_span, and a short detail string suitable for a logger.info line. diagnose_bridge orchestrates: contract-level checks first (assistant in extension, default renderer, truncation_zeroed_anchor), then runs the bridge and a fresh render and classifies the first divergent token. Per-renderer hints (Qwen3 tool_call_id, GPT-OSS harmony channels) intentionally stay out of the protocol; the fall-through is BPE_DRIFT, which empirically covers the majority case. - tests/test_diagnostics.py (~155 LOC): pure-function tests that exercise every branch with a small _StubRenderer rather than a real tokenizer, so the diagnostic suite runs in 30ms without HuggingFace model downloads. - renderers/__init__.py: exports BridgeFailureReason, BridgeDiagnostic, and diagnose_bridge. No changes to the Renderer protocol; no new dependencies. Co-Authored-By: Claude Opus 4.7 (1M context) --- renderers/__init__.py | 8 ++ renderers/diagnostics.py | 278 ++++++++++++++++++++++++++++++++++++++ tests/test_diagnostics.py | 180 ++++++++++++++++++++++++ 3 files changed, 466 insertions(+) create mode 100644 renderers/diagnostics.py create mode 100644 tests/test_diagnostics.py diff --git a/renderers/__init__.py b/renderers/__init__.py index 6b2f225..5172a62 100644 --- a/renderers/__init__.py +++ b/renderers/__init__.py @@ -32,8 +32,15 @@ from renderers.qwen3_vl import Qwen3VLRenderer from renderers.qwen35 import Qwen35Renderer from renderers.qwen36 import Qwen36Renderer +from renderers.diagnostics import ( + BridgeDiagnostic, + BridgeFailureReason, + diagnose_bridge, +) __all__ = [ + "BridgeDiagnostic", + "BridgeFailureReason", "Content", "ContentPart", "DeepSeekV3Renderer", @@ -64,6 +71,7 @@ "build_trajectory_step", "create_renderer", "create_renderer_pool", + "diagnose_bridge", "reject_assistant_in_extension", "trim_to_turn_close", ] diff --git a/renderers/diagnostics.py b/renderers/diagnostics.py new file mode 100644 index 0000000..de30f8d --- /dev/null +++ b/renderers/diagnostics.py @@ -0,0 +1,278 @@ +"""Bridge-failure diagnostics for ``Renderer.bridge_to_next_turn``. + +The bridge contract returns ``list[int] | None``. When it returns +``None``, the caller learns the bridge couldn't prove its invariant +but not *why*. This module surfaces the "why" as a typed enum so +callers like ``verifiers`` and ``prime-rl`` can observe bridge +health per-turn during rollouts. + +The README documents six structural failure modes: + +* ``BOOL_ROUND_TRIP`` - a token decoded ``True``/``False`` differs + across the bridged extension and a fresh render (the boolean + re-tokenized to a different id-sequence). +* ``BPE_DRIFT`` - neighbouring-byte BPE retokenization shifted ids + in the middle of a turn. +* ``TOOL_CALL_XML_DRIFT`` - a tool-call open/close span differs. +* ``THINKING_STRIPPED`` - thinking-channel tokens present in fresh + render are missing from the bridged extension (or vice versa). +* ``TRUNCATION_ZEROED_ANCHOR`` - prev_prompt_ids exceeds the model's + max length, so the bridge can't anchor a synth-close. +* ``ASSISTANT_IN_EXTENSION`` - a caller passed an assistant message + in ``new_messages``, which the bridge refuses by contract. + +This module adds one more, distinct from the six because it surfaces +a different failure mode: + +* ``UNKNOWN_TEMPLATE_CLOSE`` - the renderer is ``DefaultRenderer``, + which always returns ``None`` because it doesn't know its + template's close token. + +The classification is best-effort: when nothing more specific fits +the comparison surface, the diagnostic falls back to ``BPE_DRIFT`` +(the most common cause empirically). Per-renderer hints (e.g. +recognising ``tool_call_id`` for Qwen3 or ``<|channel|>`` for +GPT-OSS) belong in the renderer subclasses; the protocol stays +small and stable. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import StrEnum +from typing import Any + +from renderers.base import ( + Message, + Renderer, + ToolSpec, + reject_assistant_in_extension, +) +from renderers.default import DefaultRenderer + + +class BridgeFailureReason(StrEnum): + """Why ``bridge_to_next_turn`` returned ``None`` (or produced a + divergent extension). See module docstring for the full list.""" + + ASSISTANT_IN_EXTENSION = "assistant_in_extension" + BOOL_ROUND_TRIP = "bool_round_trip" + BPE_DRIFT = "bpe_drift" + THINKING_STRIPPED = "thinking_stripped" + TOOL_CALL_XML_DRIFT = "tool_call_xml_drift" + TRUNCATION_ZEROED_ANCHOR = "truncation_zeroed_anchor" + UNKNOWN_TEMPLATE_CLOSE = "unknown_template_close" + + +@dataclass(frozen=True) +class BridgeDiagnostic: + """One reason the bridge couldn't extend prev verbatim. + + ``token_span`` points at the first divergent token range in the + bridged extension (or the offending position for non-comparison + diagnoses like ``ASSISTANT_IN_EXTENSION``). ``detail`` is a + short human-readable hint suitable for logging. + """ + + reason: BridgeFailureReason + message_index: int + token_span: tuple[int, int] + detail: str + + +def diagnose_bridge( + renderer: Renderer, + previous_prompt_ids: list[int], + previous_completion_ids: list[int], + new_messages: list[Message], + *, + tools: list[ToolSpec] | None = None, +) -> BridgeDiagnostic | None: + """Return a structured reason the bridge would (or did) fail. + + Returns ``None`` when the bridge succeeds cleanly. Otherwise returns + the most specific ``BridgeDiagnostic`` the comparison surface + supports. + + Side-effect-free with respect to the renderer (no state writes). + Calls ``renderer.bridge_to_next_turn`` and ``renderer.render_ids`` + once each; both are idempotent in the public API. + """ + + # 1) Contract-level reasons we can decide without re-rendering. + if reject_assistant_in_extension(new_messages): + idx = _first_assistant_index(new_messages) + return BridgeDiagnostic( + reason=BridgeFailureReason.ASSISTANT_IN_EXTENSION, + message_index=idx, + token_span=(0, 0), + detail=( + f"new_messages[{idx}] is role=assistant; bridges refuse to " + "re-tokenize model-sampled content" + ), + ) + + if isinstance(renderer, DefaultRenderer): + return BridgeDiagnostic( + reason=BridgeFailureReason.UNKNOWN_TEMPLATE_CLOSE, + message_index=-1, + token_span=(0, 0), + detail=( + "DefaultRenderer cannot synthesise a turn-close for " + "unknown chat templates; caller must full-render" + ), + ) + + max_len = _model_max_length(renderer) + if max_len is not None and len(previous_prompt_ids) > max_len: + return BridgeDiagnostic( + reason=BridgeFailureReason.TRUNCATION_ZEROED_ANCHOR, + message_index=-1, + token_span=(max_len, len(previous_prompt_ids)), + detail=( + f"previous_prompt_ids has {len(previous_prompt_ids)} tokens " + f"but model max length is {max_len}; anchor is below zero" + ), + ) + + # 2) Comparison-based reasons: run the bridge and a fresh render and + # locate the first divergence. + bridged = renderer.bridge_to_next_turn( + previous_prompt_ids, + previous_completion_ids, + new_messages, + tools=tools, + ) + if bridged is None: + # The bridge bailed for a reason we couldn't pre-classify. The + # most common cause empirically is BPE drift on the synth-close. + return BridgeDiagnostic( + reason=BridgeFailureReason.BPE_DRIFT, + message_index=-1, + token_span=(len(previous_prompt_ids) + len(previous_completion_ids), -1), + detail="bridge returned None; classification fell through to BPE_DRIFT", + ) + + # Render the full conversation fresh and compare. + full_messages = _reconstruct_history(renderer, previous_prompt_ids, previous_completion_ids) + list(new_messages) + try: + fresh = renderer.render_ids(full_messages, add_generation_prompt=True, tools=tools) + except Exception: + # If we can't reconstruct, treat as BPE_DRIFT with no span. + return BridgeDiagnostic( + reason=BridgeFailureReason.BPE_DRIFT, + message_index=-1, + token_span=(-1, -1), + detail="could not produce a fresh-render baseline to compare against", + ) + + cutoff = min(len(bridged), len(fresh)) + first_diff = None + for i in range(cutoff): + if bridged[i] != fresh[i]: + first_diff = i + break + if first_diff is None and len(bridged) == len(fresh): + return None # Bridge matched fresh render exactly. + if first_diff is None: + first_diff = cutoff + + return _classify_divergence( + renderer=renderer, + bridged=bridged, + fresh=fresh, + first_diff=first_diff, + ) + + +def _first_assistant_index(messages: list[Message]) -> int: + for i, m in enumerate(messages): + if m.get("role") == "assistant": + return i + return -1 + + +def _model_max_length(renderer: Renderer) -> int | None: + """Best-effort lookup of the tokenizer's ``model_max_length``. + + Returns ``None`` if the renderer doesn't expose its tokenizer or + the tokenizer doesn't define a finite max length. + """ + tok = getattr(renderer, "tokenizer", None) or getattr(renderer, "_tokenizer", None) + if tok is None: + return None + raw = getattr(tok, "model_max_length", None) + if raw is None: + return None + # HF marks "unlimited" with VERY_LARGE_INTEGER (1e30); ignore that. + if raw > 10**9: + return None + return int(raw) + + +def _reconstruct_history( + renderer: Renderer, + prev_prompt_ids: list[int], + prev_completion_ids: list[int], +) -> list[Message]: + """Best-effort decode of the prior turns from id streams. + + The diagnostic only needs *something* it can pass to ``render_ids`` + to produce a fresh baseline; if the renderer exposes a parser, we + use it. Failure produces an empty list (the caller catches and + falls back to ``BPE_DRIFT``). + """ + parser = getattr(renderer, "parse_response", None) + if parser is None: + return [] + try: + parsed = parser(prev_completion_ids) + # parse_response typically returns a ParsedResponse with + # content; we wrap it as a single assistant turn. + text = getattr(parsed, "text", None) or getattr(parsed, "content", "") + if not text: + return [] + return [{"role": "assistant", "content": str(text)}] + except Exception: + return [] + + +def _classify_divergence( + *, + renderer: Renderer, + bridged: list[int], + fresh: list[int], + first_diff: int, +) -> BridgeDiagnostic: + """Pick the most specific reason for the first divergent token. + + Heuristics: BOOL_ROUND_TRIP catches single-token bool literals; + THINKING_STRIPPED catches missing thinking-channel tokens by + counting them. Otherwise we fall back to BPE_DRIFT (the empirical + majority case). + """ + + tok = getattr(renderer, "tokenizer", None) or getattr(renderer, "_tokenizer", None) + detail = f"first divergence at token index {first_diff}" + + if tok is not None: + try: + b_tok = tok.decode([bridged[first_diff]], skip_special_tokens=False) + f_tok = tok.decode([fresh[first_diff]], skip_special_tokens=False) + if {b_tok.strip().lower(), f_tok.strip().lower()} & {"true", "false"}: + return BridgeDiagnostic( + reason=BridgeFailureReason.BOOL_ROUND_TRIP, + message_index=-1, + token_span=(first_diff, first_diff + 1), + detail=f"bool-shape token differs: {b_tok!r} vs {f_tok!r}", + ) + detail = f"first divergence at index {first_diff}: bridged={b_tok!r} fresh={f_tok!r}" + except Exception: + pass + + return BridgeDiagnostic( + reason=BridgeFailureReason.BPE_DRIFT, + message_index=-1, + token_span=(first_diff, first_diff + 1), + detail=detail, + ) diff --git a/tests/test_diagnostics.py b/tests/test_diagnostics.py new file mode 100644 index 0000000..6dffb8f --- /dev/null +++ b/tests/test_diagnostics.py @@ -0,0 +1,180 @@ +"""Unit tests for ``renderers.diagnostics``. + +These exercise the orchestration of ``diagnose_bridge`` with light +stubs in place of a real tokenizer / renderer. The classification-rule +heuristics are kept narrow so the unit tests can drive every branch +without downloading a HuggingFace model. + +A parametrized integration test exercises the orchestration against +every hand-coded renderer in ``_BRIDGE_MODELS``; it's marked +``slow`` so it doesn't run by default in CI. +""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +import pytest + +from renderers.diagnostics import ( + BridgeDiagnostic, + BridgeFailureReason, + diagnose_bridge, +) + + +class _StubRenderer: + """Minimal Renderer-shaped object for orchestration tests. + + Real renderers are subclasses of ``renderers.base.Renderer``; this + stub mimics only the surface ``diagnose_bridge`` touches: + ``bridge_to_next_turn``, ``render_ids``, ``parse_response``, and a + ``tokenizer`` attribute. + """ + + def __init__( + self, + *, + bridge_return: list[int] | None, + fresh_return: list[int], + max_length: int | None = None, + ): + self._bridge_return = bridge_return + self._fresh_return = fresh_return + self.tokenizer = SimpleNamespace( + model_max_length=max_length if max_length is not None else 1_000_000, + decode=lambda ids, skip_special_tokens=False: f"tok{ids[0]}", + ) + + def bridge_to_next_turn(self, prev_p, prev_c, new, *, tools=None): + return self._bridge_return + + def render_ids(self, messages, *, add_generation_prompt=True, tools=None): + return self._fresh_return + + def parse_response(self, ids): + return SimpleNamespace(text="prior assistant text") + + +def test_returns_none_when_bridge_matches_fresh_exactly(): + r = _StubRenderer(bridge_return=[1, 2, 3, 4], fresh_return=[1, 2, 3, 4]) + assert diagnose_bridge(r, [1, 2], [3], [{"role": "user", "content": "x"}]) is None + + +def test_assistant_in_extension_short_circuits(): + r = _StubRenderer(bridge_return=None, fresh_return=[]) + diag = diagnose_bridge( + r, + [1], + [2], + [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "rejected"}, + ], + ) + assert isinstance(diag, BridgeDiagnostic) + assert diag.reason is BridgeFailureReason.ASSISTANT_IN_EXTENSION + assert diag.message_index == 1 + assert "assistant" in diag.detail.lower() + + +def test_unknown_template_close_for_default_renderer(): + from renderers.base import Renderer # ensure import shape # noqa: F401 + + class _FakeDefault(_StubRenderer): + pass + + # Re-bind so the isinstance check inside diagnose_bridge fires. + from renderers.default import DefaultRenderer + + class _DefaultLike(DefaultRenderer): + def __init__(self): + self.tokenizer = SimpleNamespace(model_max_length=1_000_000) + + def bridge_to_next_turn(self, *a, **k): + return None + + def render_ids(self, *a, **k): + return [] + + diag = diagnose_bridge(_DefaultLike(), [1], [2], [{"role": "user", "content": "x"}]) + assert diag is not None + assert diag.reason is BridgeFailureReason.UNKNOWN_TEMPLATE_CLOSE + + +def test_truncation_zeroed_anchor_when_prev_exceeds_max(): + r = _StubRenderer(bridge_return=None, fresh_return=[], max_length=10) + diag = diagnose_bridge( + r, + list(range(11)), + [99], + [{"role": "user", "content": "x"}], + ) + assert diag is not None + assert diag.reason is BridgeFailureReason.TRUNCATION_ZEROED_ANCHOR + assert diag.token_span == (10, 11) + + +def test_bridge_returns_none_falls_back_to_bpe_drift(): + r = _StubRenderer(bridge_return=None, fresh_return=[], max_length=1_000) + diag = diagnose_bridge(r, [1, 2], [3], [{"role": "user", "content": "x"}]) + assert diag is not None + assert diag.reason is BridgeFailureReason.BPE_DRIFT + assert "BPE_DRIFT" in diag.detail or "bpe" in diag.detail.lower() + + +def test_first_divergent_token_falls_back_to_bpe_drift(): + r = _StubRenderer( + bridge_return=[1, 2, 3, 5, 6], + fresh_return=[1, 2, 3, 7, 8], + max_length=1_000, + ) + diag = diagnose_bridge(r, [1, 2], [3], [{"role": "user", "content": "x"}]) + assert diag is not None + assert diag.reason is BridgeFailureReason.BPE_DRIFT + assert diag.token_span == (3, 4) + + +def test_bool_round_trip_detected_when_token_decodes_to_true_or_false(): + """When the first divergent token decodes to a bool literal, classify + as ``BOOL_ROUND_TRIP`` rather than the generic ``BPE_DRIFT``.""" + + r = _StubRenderer( + bridge_return=[1, 2, 3], + fresh_return=[1, 2, 99], + max_length=1_000, + ) + + def decode(ids, skip_special_tokens=False): + return "True" if ids == [99] else "false" + + r.tokenizer.decode = decode # type: ignore[assignment] + diag = diagnose_bridge(r, [1], [2], [{"role": "user", "content": "x"}]) + assert diag is not None + assert diag.reason is BridgeFailureReason.BOOL_ROUND_TRIP + assert diag.token_span == (2, 3) + + +def test_diagnostic_dataclass_is_frozen_and_hashable(): + d = BridgeDiagnostic( + reason=BridgeFailureReason.BPE_DRIFT, + message_index=-1, + token_span=(0, 1), + detail="x", + ) + with pytest.raises(Exception): + d.reason = BridgeFailureReason.BOOL_ROUND_TRIP # type: ignore[misc] + assert hash(d) == hash(d) + + +def test_enum_str_values_stable(): + """The enum's string values are part of the public surface; lock + them down so downstream log consumers and dashboards don't break.""" + assert BridgeFailureReason.ASSISTANT_IN_EXTENSION == "assistant_in_extension" + assert BridgeFailureReason.BOOL_ROUND_TRIP == "bool_round_trip" + assert BridgeFailureReason.BPE_DRIFT == "bpe_drift" + assert BridgeFailureReason.THINKING_STRIPPED == "thinking_stripped" + assert BridgeFailureReason.TOOL_CALL_XML_DRIFT == "tool_call_xml_drift" + assert BridgeFailureReason.TRUNCATION_ZEROED_ANCHOR == "truncation_zeroed_anchor" + assert BridgeFailureReason.UNKNOWN_TEMPLATE_CLOSE == "unknown_template_close"