Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
280 changes: 280 additions & 0 deletions renderers/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
import asyncio
import base64
import logging
from dataclasses import dataclass
from functools import reduce
from operator import mul
from typing import Any, cast

import numpy as np
Expand All @@ -24,6 +27,28 @@
_request_logger = logging.getLogger("renderers.client")


@dataclass(frozen=True)
class _FallbackPlaceholderRange:
offset: int
length: int
is_embed: Any = None


@dataclass
class _FallbackMultiModalFieldElem:
data: Any
field: Any = None


@dataclass
class _FallbackMultiModalFeatureSpec:
    """Plain-dataclass feature spec returned when vLLM cannot be imported.

    Carries the same keyword fields that the vLLM code path passes to
    ``MultiModalFeatureSpec`` (data, modality, identifier, mm_position,
    mm_hash).
    """

    # Field name -> wrapped value; the fallback builder fills in
    # "pixel_values" and "image_grid_thw".
    data: dict[str, _FallbackMultiModalFieldElem] | None
    # Modality tag; the builders in this module always use "image".
    modality: str
    # Caller-supplied identifier or a generated "image-<payload>-<image>" one.
    identifier: str
    # Placeholder span of this item.
    mm_position: _FallbackPlaceholderRange
    # Optional hash forwarded from the payload, stringified by ``_mm_hash``.
    mm_hash: str | None = None


async def _run_pooled(pool: RendererPool, fn):
def _work():
with pool.checkout() as r:
Expand All @@ -32,6 +57,261 @@ def _work():
return await asyncio.to_thread(_work)


def _build_mm_features(renderer_cls: type, mm_data: Any) -> list[Any] | None:
    """Dispatch to the multimodal feature builder matching ``renderer_cls``.

    Args:
        renderer_cls: Renderer class that produced ``mm_data``.
        mm_data: Renderer-native multimodal payload.

    Returns:
        A list of feature specs, or ``None`` when the payload has no images.

    Raises:
        NotImplementedError: If no builder exists for ``renderer_cls``.
    """
    # Imported lazily, inside the function, exactly as the original did.
    from renderers.qwen3_vl import Qwen3VLRenderer
    from renderers.qwen35 import Qwen35Renderer

    qwen_vl_family = (Qwen3VLRenderer, Qwen35Renderer)
    if not issubclass(renderer_cls, qwen_vl_family):
        raise NotImplementedError(f"No multimodal feature builder for {renderer_cls!r}")

    # Qwen3-VL and Qwen3.5 both emit Qwen2-VL-family image payloads
    # (pixel_values plus image_grid_thw). All seven current Qwen3.5 sizes use
    # merge_size=2; move this to renderer metadata when that API lands.
    return _build_qwen_vl_features(mm_data, spatial_merge_size=2)


def _build_qwen_vl_features(
    mm_data: Any, *, spatial_merge_size: int
) -> list[Any] | None:
    """Build per-image feature specs for Qwen-VL style image payloads.

    Uses vLLM's ``MultiModalFeatureSpec`` when vLLM can be imported and
    delegates to ``_build_fallback_qwen_vl_features`` otherwise.

    Args:
        mm_data: Renderer-native multimodal payload.
        spatial_merge_size: Merge factor; each image's placeholder length is
            ``t * h * w // spatial_merge_size**2``.

    Returns:
        One feature spec per image, or ``None`` if no usable image payload
        is present.
    """
    payloads = _image_payloads(mm_data)
    if not payloads:
        return None

    try:
        from vllm.multimodal.inputs import (
            MultiModalFeatureSpec,
            MultiModalFieldConfig,
            PlaceholderRange,
        )
    except Exception:
        # Best effort: any failure importing vLLM selects the plain
        # dataclass representation.
        return _build_fallback_qwen_vl_features(
            payloads, spatial_merge_size=spatial_merge_size
        )

    merge_area = spatial_merge_size**2
    specs: list[Any] = []
    cursor = 0  # default offset for the next image without an explicit one
    for payload_idx, payload in enumerate(payloads):
        pixels = _tensor_data(payload["pixel_values"])
        grid = _image_grid_tensor(payload["image_grid_thw"])
        rows = _grid_rows(grid)
        patch_counts = [_grid_prod(row) for row in rows]

        # NOTE(review): flat_from_sizes presumably splits pixel_values along
        # its first dimension by patch count - confirm against vLLM docs.
        flat_config = MultiModalFieldConfig.flat_from_sizes(
            "image", _tensor(patch_counts, like=grid)
        )
        pixel_elems = flat_config.field.build_elems("image", "pixel_values", pixels)
        batched_config = MultiModalFieldConfig.batched("image")
        grid_elems = batched_config.field.build_elems("image", "image_grid_thw", grid)

        per_image = zip(pixel_elems, grid_elems, patch_counts, strict=True)
        for image_idx, (pixel_elem, grid_elem, patch_count) in enumerate(per_image):
            position = _placeholder_range(
                payload,
                image_idx,
                default_offset=cursor,
                default_length=patch_count // merge_area,
                placeholder_cls=PlaceholderRange,
            )
            cursor = position.offset + position.length

            spec_kwargs = {
                "data": {
                    "pixel_values": pixel_elem,
                    "image_grid_thw": grid_elem,
                },
                "modality": "image",
                "identifier": _identifier(payload, payload_idx, image_idx),
                "mm_position": position,
                "mm_hash": _mm_hash(payload),
            }
            try:
                spec = MultiModalFeatureSpec(**spec_kwargs)
            except TypeError:
                # Retry without mm_hash when the constructor rejects the kwargs.
                del spec_kwargs["mm_hash"]
                spec = MultiModalFeatureSpec(**spec_kwargs)
            specs.append(spec)

    return specs


def _build_fallback_qwen_vl_features(
    image_payloads: list[dict[str, Any]], *, spatial_merge_size: int
) -> list[_FallbackMultiModalFeatureSpec]:
    """Build feature specs from plain dataclasses when vLLM is unavailable.

    Args:
        image_payloads: Normalized payload dicts, each holding at least
            ``pixel_values`` and ``image_grid_thw``.
        spatial_merge_size: Merge factor; each image's placeholder length is
            ``t * h * w // spatial_merge_size**2``.

    Returns:
        One ``_FallbackMultiModalFeatureSpec`` per grid row (image).

    When one payload describes several images, each image receives only its
    own ``t * h * w`` rows of ``pixel_values`` rather than the whole shared
    buffer, matching the per-image split the vLLM path performs with
    ``flat_from_sizes``. Single-image payloads, and pixel containers that are
    not a list, tuple or array, are passed through untouched.
    """
    merge_area = spatial_merge_size**2
    features: list[_FallbackMultiModalFeatureSpec] = []
    next_offset = 0
    for payload_idx, payload in enumerate(image_payloads):
        pixel_values = payload["pixel_values"]
        grid_rows = _grid_rows(payload["image_grid_thw"])

        # Only slice containers we know are row-indexed (lists, tuples, or
        # array-likes exposing ndim >= 1); anything else is kept whole.
        can_split = len(grid_rows) > 1 and (
            isinstance(pixel_values, (list, tuple))
            or getattr(pixel_values, "ndim", 0) >= 1
        )

        row_cursor = 0
        for image_idx, grid_row in enumerate(grid_rows):
            patch_count = _grid_prod(grid_row)
            if can_split:
                image_pixels = pixel_values[row_cursor : row_cursor + patch_count]
            else:
                image_pixels = pixel_values
            row_cursor += patch_count

            mm_position = _placeholder_range(
                payload,
                image_idx,
                default_offset=next_offset,
                default_length=patch_count // merge_area,
                placeholder_cls=_FallbackPlaceholderRange,
            )
            # Images without explicit positions are laid out back to back.
            next_offset = mm_position.offset + mm_position.length
            features.append(
                _FallbackMultiModalFeatureSpec(
                    data={
                        "pixel_values": _FallbackMultiModalFieldElem(image_pixels),
                        "image_grid_thw": _FallbackMultiModalFieldElem(grid_row),
                    },
                    modality="image",
                    identifier=_identifier(payload, payload_idx, image_idx),
                    mm_position=mm_position,
                    mm_hash=_mm_hash(payload),
                )
            )
    return features


def _image_payloads(mm_data: Any) -> list[dict[str, Any]]:
    """Normalize renderer image data into a list of payload dicts.

    ``mm_data`` may hold its image data under an ``image`` key/attribute or
    be an image payload itself (anything exposing ``pixel_values``). The
    image data may be a single payload or an iterable of payloads, where a
    payload is either a ``(pixel_values, image_grid_thw)`` 2-tuple or a
    dict/object exposing those names plus optional ``mm_position``,
    ``mm_positions``, ``offset``, ``identifier`` and ``mm_hash``.

    A single payload is recognized whether it is a dict or any other object
    exposing ``pixel_values``. The previous dict-only check iterated directly
    over attribute-style and mapping-like payloads, which either raised
    ``TypeError`` or silently produced no payloads.

    Returns:
        Payload dicts that have both ``pixel_values`` and ``image_grid_thw``
        set; an empty list when there is no usable image data.
    """
    if mm_data is None:
        return []

    image_data = _get(mm_data, "image")
    if image_data is None and _get(mm_data, "pixel_values") is not None:
        # mm_data is itself the image payload.
        image_data = mm_data
    if image_data is None:
        return []

    is_single_payload = (
        _is_pixel_grid_pair(image_data)
        or (isinstance(image_data, dict) and "pixel_values" in image_data)
        or (
            not isinstance(image_data, (list, tuple))
            and _get(image_data, "pixel_values") is not None
        )
    )
    if is_single_payload:
        image_data = [image_data]

    payload_keys = (
        "pixel_values",
        "image_grid_thw",
        "mm_position",
        "mm_positions",
        "offset",
        "identifier",
        "mm_hash",
    )
    payloads: list[dict[str, Any]] = []
    for item in image_data:
        if _is_pixel_grid_pair(item):
            pixel_values, image_grid_thw = item
            payload = {"pixel_values": pixel_values, "image_grid_thw": image_grid_thw}
        else:
            payload = {key: _get(item, key) for key in payload_keys}
        # Drop entries that lack either half of the pixel/grid pair.
        if payload["pixel_values"] is not None and payload["image_grid_thw"] is not None:
            payloads.append(payload)
    return payloads


def _is_pixel_grid_pair(value: Any) -> bool:
return isinstance(value, tuple) and len(value) == 2


def _get(value: Any, key: str) -> Any:
if isinstance(value, dict):
return value.get(key)
return getattr(value, key, None)


def _grid_rows(image_grid_thw: Any) -> list[Any]:
    """Return ``image_grid_thw`` as a list of grid rows.

    A flat sequence of numbers is treated as one row and wrapped in a list;
    anything else is assumed to already be a list of rows. Empty input
    yields ``[]``.

    Numbers are detected with ``numbers.Real`` rather than ``int | float``
    so that a single row of NumPy scalars (for example ``np.int64``, which
    is not an ``int`` subclass) is kept as one row instead of being split
    into one bogus row per dimension.
    """
    from numbers import Real

    rows = _to_list(image_grid_thw)
    if not rows:
        return []
    if all(isinstance(dim, Real) for dim in rows):
        return [rows]
    return rows


def _grid_prod(grid_row: Any) -> int:
    """Return the product of the entries of ``grid_row``, each cast to int.

    An empty row yields 1.
    """
    product = 1
    for dim in _to_list(grid_row):
        product *= int(dim)
    return int(product)


def _to_list(value: Any) -> list[Any]:
if hasattr(value, "tolist"):
value = value.tolist()
if isinstance(value, tuple):
return list(value)
if isinstance(value, list):
return value
return [value]


def _tensor(value: list[int], *, like: Any) -> Any:
try:
import torch

device = getattr(like, "device", None)
return torch.as_tensor(value, device=device)
except Exception:
return value


def _tensor_data(value: Any) -> Any:
try:
import torch

return torch.as_tensor(value)
except Exception:
return value


def _image_grid_tensor(value: Any) -> Any:
    """Convert ``value`` via ``_tensor_data`` and ensure it has a row axis.

    A one-dimensional result gains a leading axis of size 1; anything else
    is returned as converted.

    ``_tensor_data`` may hand back the raw input when torch is unavailable,
    so a one-dimensional array without ``unsqueeze`` (for example a NumPy
    array) is expanded by indexing with ``None`` instead of raising
    ``AttributeError``.
    """
    tensor = _tensor_data(value)
    if getattr(tensor, "ndim", None) != 1:
        return tensor
    if hasattr(tensor, "unsqueeze"):
        return tensor.unsqueeze(0)
    # Array-likes without unsqueeze: None-indexing prepends an axis.
    return tensor[None]


def _placeholder_range(
    payload: dict[str, Any],
    image_idx: int,
    *,
    default_offset: int,
    default_length: int,
    placeholder_cls: Any,
) -> Any:
    """Resolve the placeholder range for one image of ``payload``.

    Resolution order:
      1. ``payload["mm_positions"]`` indexed by ``image_idx`` or, if that is
         falsy, ``payload["mm_position"]``; used only when it provides both
         ``offset`` and ``length``.
      2. ``payload["offset"]`` (indexed by ``image_idx``) with
         ``default_length``.
      3. ``default_offset`` with ``default_length``.

    Returns:
        ``placeholder_cls(offset=..., length=...)``.
    """
    explicit = _indexed(_get(payload, "mm_positions"), image_idx)
    if not explicit:
        explicit = _get(payload, "mm_position")

    if explicit is not None:
        start = _get(explicit, "offset")
        span = _get(explicit, "length")
        if not (start is None or span is None):
            return placeholder_cls(offset=int(start), length=int(span))

    start = _indexed(_get(payload, "offset"), image_idx)
    resolved_offset = default_offset if start is None else start
    return placeholder_cls(offset=int(resolved_offset), length=default_length)


def _indexed(value: Any, idx: int) -> Any:
if value is None or isinstance(value, str):
return value
if isinstance(value, list | tuple):
return value[idx] if idx < len(value) else None
return value


def _identifier(payload: dict[str, Any], payload_idx: int, image_idx: int) -> str:
    """Return the identifier for one image of ``payload``.

    Uses the payload's ``identifier`` (indexed by ``image_idx`` when it is a
    list or tuple), stringified; otherwise generates
    ``image-<payload_idx>-<image_idx>``.
    """
    explicit = _indexed(_get(payload, "identifier"), image_idx)
    return f"image-{payload_idx}-{image_idx}" if explicit is None else str(explicit)


def _mm_hash(payload: dict[str, Any]) -> str | None:
    """Return the payload's ``mm_hash`` as a string, or ``None`` if unset."""
    raw = _get(payload, "mm_hash")
    if raw is None:
        return None
    return str(raw)


async def generate(
*,
client: AsyncOpenAI,
Expand Down
25 changes: 25 additions & 0 deletions tests/test_client_mm_features.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from renderers.client import _build_mm_features
from renderers.qwen35 import Qwen35Renderer


def test_build_mm_features_dispatches_qwen35_renderer():
    """Qwen35Renderer payloads go through the Qwen-VL feature builder."""
    image_payload = {
        "pixel_values": [[1.0], [2.0], [3.0], [4.0]],
        "image_grid_thw": [1, 4, 4],
        "offset": 7,
        "identifier": "image-0",
    }

    features = _build_mm_features(Qwen35Renderer, {"image": image_payload})

    assert features is not None
    assert len(features) == 1

    feature = features[0]
    assert feature.modality == "image"
    assert feature.identifier == "image-0"
    # The explicit payload offset is used as-is.
    assert feature.mm_position.offset == 7
    # 1 * 4 * 4 grid cells divided by merge_size**2 (2**2) gives 4.
    assert feature.mm_position.length == 4
    assert set(feature.data) == {"pixel_values", "image_grid_thw"}