From 2a8fba0640d8be00071edb97c2d0f5f2138ab6c0 Mon Sep 17 00:00:00 2001 From: Christian Date: Tue, 26 May 2026 18:15:51 -0700 Subject: [PATCH] Copy per-modality mm_data lists when merging previous turn's data --- renderers/qwen35.py | 6 +-- renderers/qwen3_vl.py | 6 +-- tests/test_multimodal.py | 97 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 103 insertions(+), 6 deletions(-) diff --git a/renderers/qwen35.py b/renderers/qwen35.py index abcacec..32efd9f 100644 --- a/renderers/qwen35.py +++ b/renderers/qwen35.py @@ -814,17 +814,17 @@ def flush_buf() -> None: # Merge prev mm_data (images from earlier turns) with the new turn's. merged_hashes: dict[str, list[str]] = ( - dict(previous_multi_modal_data.mm_hashes) + {modality: list(vals) for modality, vals in previous_multi_modal_data.mm_hashes.items()} if previous_multi_modal_data else {} ) merged_placeholders: dict[str, list[PlaceholderRange]] = ( - dict(previous_multi_modal_data.mm_placeholders) + {modality: list(vals) for modality, vals in previous_multi_modal_data.mm_placeholders.items()} if previous_multi_modal_data else {} ) merged_items: dict[str, list[dict[str, Any]]] = ( - dict(previous_multi_modal_data.mm_items) + {modality: list(vals) for modality, vals in previous_multi_modal_data.mm_items.items()} if previous_multi_modal_data else {} ) diff --git a/renderers/qwen3_vl.py b/renderers/qwen3_vl.py index 7287159..ee483ab 100644 --- a/renderers/qwen3_vl.py +++ b/renderers/qwen3_vl.py @@ -804,17 +804,17 @@ def render_media_content(content: Any) -> None: # Merge prev mm_data with the new turn's items. 
merged_hashes = ( - dict(previous_multi_modal_data.mm_hashes) + {modality: list(vals) for modality, vals in previous_multi_modal_data.mm_hashes.items()} if previous_multi_modal_data else {} ) merged_placeholders = ( - dict(previous_multi_modal_data.mm_placeholders) + {modality: list(vals) for modality, vals in previous_multi_modal_data.mm_placeholders.items()} if previous_multi_modal_data else {} ) merged_items = ( - dict(previous_multi_modal_data.mm_items) + {modality: list(vals) for modality, vals in previous_multi_modal_data.mm_items.items()} if previous_multi_modal_data else {} ) diff --git a/tests/test_multimodal.py b/tests/test_multimodal.py index 28984e4..dd582f7 100644 --- a/tests/test_multimodal.py +++ b/tests/test_multimodal.py @@ -637,6 +637,103 @@ def test_multimodal_bridge_extends_and_carries_mm_data( ) +@pytest.mark.parametrize( + "mm_model_name,modality", _CASES, ids=[f"{m}|{mo}" for m, mo in _CASES] +) +def test_multimodal_bridge_does_not_mutate_previous_mm_data( + mm_model_name, modality, tiny_image +): + """``bridge_to_next_turn`` must not mutate ``previous_multi_modal_data``. + + Regression for a shallow-copy bug: ``dict(prev.mm_items)`` copies the + outer mapping but leaves each per-modality list aliased to the + original. The bridge then calls ``.extend(...)`` on that list, mutating + the prior turn's ``MultiModalData`` in place. Callers that retain + the prior ``RenderedTokens`` (e.g. trainers that keep per-step + snapshots for loss reconstruction) silently see their earlier + turns' image lists grow on every bridge. 
+ """ + if not _hf_snapshot_cached(mm_model_name): + pytest.skip(f"{mm_model_name}: HF snapshot not cached locally") + + kit = _modality_kit(modality, mm_model_name) + tokenizer, _, renderer = _load_processor_and_renderer(mm_model_name) + + initial = [ + { + "role": "user", + "content": [ + kit["make_part"](tiny_image), + {"type": "text", "text": "Turn one."}, + ], + } + ] + new = [ + { + "role": "user", + "content": [ + kit["make_part"](tiny_image), + {"type": "text", "text": "Turn two."}, + ], + } + ] + + initial_rendered = renderer.render(initial, add_generation_prompt=True) + assert initial_rendered.multi_modal_data is not None + prev_mm = initial_rendered.multi_modal_data + + # Snapshot identities and contents BEFORE bridging. NOTE(review): if the key is absent, .get() returns a fresh [] and the checks below pass vacuously. + prev_items_list = prev_mm.mm_items.get(modality, []) + prev_placeholders_list = prev_mm.mm_placeholders.get(modality, []) + prev_hashes_list = prev_mm.mm_hashes.get(modality, []) + items_id_before = id(prev_items_list) + placeholders_id_before = id(prev_placeholders_list) + hashes_id_before = id(prev_hashes_list) + items_snapshot = list(prev_items_list) + placeholders_snapshot = list(prev_placeholders_list) + hashes_snapshot = list(prev_hashes_list) + + im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>") + completion_ids = tokenizer.encode("Saw it.", add_special_tokens=False) + [im_end_id] + + bridged = renderer.bridge_to_next_turn( + previous_prompt_ids=initial_rendered.token_ids, + previous_completion_ids=completion_ids, + new_messages=new, + previous_multi_modal_data=prev_mm, + ) + assert bridged is not None and getattr(bridged, "multi_modal_data", None) is not None + + # The prior MultiModalData must be untouched. 
+ assert prev_mm.mm_items.get(modality, []) == items_snapshot, ( + f"{mm_model_name} / {modality}: bridge mutated previous mm_items list " + f"(expected len {len(items_snapshot)}, got {len(prev_mm.mm_items.get(modality, []))})" + ) + assert prev_mm.mm_placeholders.get(modality, []) == placeholders_snapshot, ( + f"{mm_model_name} / {modality}: bridge mutated previous mm_placeholders list" + ) + assert prev_mm.mm_hashes.get(modality, []) == hashes_snapshot, ( + f"{mm_model_name} / {modality}: bridge mutated previous mm_hashes list" + ) + + # The bridged data's inner lists must also be distinct objects from the + # prior turn's lists; a shared list would let an in-place change made + # through one turn's data show up in the other turn's data as well. + bridged_items = bridged.multi_modal_data.mm_items.get(modality, []) + bridged_placeholders = bridged.multi_modal_data.mm_placeholders.get(modality, []) + bridged_hashes = bridged.multi_modal_data.mm_hashes.get(modality, []) + assert id(bridged_items) != items_id_before, ( + f"{mm_model_name} / {modality}: bridged mm_items list is aliased to " + "previous_multi_modal_data.mm_items — outer-dict-only copy detected" + ) + assert id(bridged_placeholders) != placeholders_id_before, ( + f"{mm_model_name} / {modality}: bridged mm_placeholders list aliased to prior" + ) + assert id(bridged_hashes) != hashes_id_before, ( + f"{mm_model_name} / {modality}: bridged mm_hashes list aliased to prior" + ) + + def test_modality_registry_models_route_to_renderer(): """Every model in ``MULTIMODAL_MODELS`` resolves to a concrete renderer via ``create_renderer(renderer='auto')``. Guards against drift between