Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
274 changes: 229 additions & 45 deletions renderers/client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,20 @@
"""Renderer-based generate client for vLLM 0.20's /inference/v1/generate.
"""Renderer-based generate client for vLLM 0.20 + Dynamo.

messages → Renderer.render_ids() → token IDs → POST /inference/v1/generate
→ completion tokens → Renderer.parse_response() → structured message
Two transports, selected per-call via ``transport=`` parameter:

"prime_vllm_generate" (default)
messages → Renderer.render_ids() → token IDs → POST /inference/v1/generate
→ completion tokens → Renderer.parse_response() → structured message
vLLM's TITO surface (server.py mounts the route in prime-rl).

"dynamo_chat_nvext"
messages → Renderer.render_ids() → token IDs → POST /v1/chat/completions
with ``nvext.token_data`` + ``nvext.extra_fields=["engine_data"]``
→ completion tokens via ``nvext.engine_data.completion_token_ids``
→ Renderer.parse_response() → structured message
Dynamo has no ``/inference/v1/generate`` route; this branch posts to
the standard OpenAI chat-completions surface and reads the engine
token IDs back via the PR #8119 ``nvext.engine_data`` channel.

When a RendererPool is passed instead of a single Renderer, the sync tokenization
and parsing work is offloaded to threads for parallel execution across rollouts.
Expand All @@ -14,7 +27,7 @@
import asyncio
import base64
import logging
from typing import Any, cast
from typing import Any, Literal, cast

import numpy as np
from openai import AsyncOpenAI, BadRequestError
Expand All @@ -23,6 +36,9 @@

_request_logger = logging.getLogger("renderers.client")

# Public type alias; matches verifiers.types.RendererTransport string set.
RendererTransport = Literal["prime_vllm_generate", "dynamo_chat_nvext"]


async def _maybe_offload(renderer: Renderer | RendererPool, fn):
"""Run sync renderer work on a thread iff ``renderer`` is a pool.
Expand All @@ -48,6 +64,7 @@ async def generate(
multi_modal_data: MultiModalData | None = None,
tools: list[ToolSpec] | None = None,
sampling_params: dict[str, Any] | None = None,
transport: RendererTransport = "prime_vllm_generate",
cache_salt: str | None = None,
priority: int | None = None,
extra_headers: dict[str, str] | None = None,
Expand Down Expand Up @@ -94,65 +111,120 @@ def _prepare():
sp["logprobs"] = 1
sp.setdefault("skip_special_tokens", False)

body: dict[str, Any] = {
"model": model,
"token_ids": prompt_ids,
"sampling_params": sp,
}
features = (
_build_mm_features(renderer, mm_data)
if mm_data and not mm_data.is_empty()
else None
)
if features is not None:
body["features"] = features
if cache_salt is not None:
body["cache_salt"] = cache_salt
if priority is not None:
body["priority"] = priority

# /inference/v1/generate is mounted at the server root, not under /v1
# like the OpenAI-compatible endpoints. Build an absolute URL so the
# AsyncOpenAI client doesn't prepend its automatic /v1.
base = str(client.base_url).rstrip("/").removesuffix("/v1")
endpoint = f"{base}/inference/v1/generate"
_request_logger.debug(
"POST %s prompt_len=%d max_tokens=%s",
endpoint,
len(prompt_ids),
sp.get("max_tokens"),
)
post_kwargs: dict[str, Any] = {
"cast_to": cast(Any, dict[str, Any]),
"body": body,
}
if extra_headers:
post_kwargs["options"] = cast(Any, {"headers": extra_headers})
try:
data = await client.post(endpoint, **post_kwargs)
except BadRequestError as exc:
_log_overlong_prompt_diagnostic(

if transport == "dynamo_chat_nvext":
# Dynamo branch: POST /v1/chat/completions with nvext.token_data.
# Dynamo has no /inference/v1/generate route; the equivalent TITO
# surface lives on chat-completions via the ``nvext`` envelope
# (PR #8119: response token IDs come back under
# ``nvext.engine_data.completion_token_ids``).
if features is not None:
# Multimodal renderers ship a vLLM-shaped ``features`` payload
# to /inference/v1/generate. Dynamo's chat-completions surface
# doesn't accept that shape today; the renderer needs a
# different per-image transfer path for Dynamo. Until that
# ships, refuse rather than silently drop the image data.
raise NotImplementedError(
"Multimodal renderers are not yet supported on the "
"dynamo_chat_nvext transport. Use prime_vllm_generate or "
"stay on the token-client TITO path for VLMs."
)
data = await _post_dynamo_chat_nvext(
client=client,
model=model,
prompt_ids=prompt_ids,
sp=sp,
tools=tools,
cache_salt=cache_salt,
priority=priority,
extra_headers=extra_headers,
messages=messages,
max_tokens=sp.get("max_tokens"),
exc=exc,
)
raise
else:
# vLLM-native branch: POST /inference/v1/generate (vLLM 0.20 TITO).
body: dict[str, Any] = {
"model": model,
"token_ids": prompt_ids,
"sampling_params": sp,
}
if features is not None:
body["features"] = features
if cache_salt is not None:
body["cache_salt"] = cache_salt
if priority is not None:
body["priority"] = priority

# /inference/v1/generate is mounted at the server root, not under /v1
# like the OpenAI-compatible endpoints. Build an absolute URL so the
# AsyncOpenAI client doesn't prepend its automatic /v1.
base = str(client.base_url).rstrip("/").removesuffix("/v1")
endpoint = f"{base}/inference/v1/generate"
_request_logger.debug(
"POST %s prompt_len=%d max_tokens=%s",
endpoint,
len(prompt_ids),
sp.get("max_tokens"),
)
post_kwargs: dict[str, Any] = {
"cast_to": cast(Any, dict[str, Any]),
"body": body,
}
if extra_headers:
post_kwargs["options"] = cast(Any, {"headers": extra_headers})
try:
data = await client.post(endpoint, **post_kwargs)
except BadRequestError as exc:
_log_overlong_prompt_diagnostic(
prompt_ids=prompt_ids,
messages=messages,
max_tokens=sp.get("max_tokens"),
exc=exc,
)
raise

choice = (data.get("choices") or [{}])[0]
completion_ids = choice.get("token_ids") or []
# Dynamo emits engine token IDs under ``nvext.engine_data.completion_token_ids``
# (PR #8119 channel) rather than ``choice.token_ids``. Try both — vLLM's
# /inference/v1/generate writes the top-level shape; Dynamo's
# /v1/chat/completions writes the nested one. The first present wins.
completion_ids = choice.get("token_ids")
if not completion_ids:
nvext_resp = data.get("nvext") or {}
engine_data = nvext_resp.get("engine_data") or {}
completion_ids = (
engine_data.get("completion_token_ids")
or nvext_resp.get("completion_token_ids")
or []
)
completion_ids = list(completion_ids or [])

parsed = await _maybe_offload(
renderer, lambda: renderer.parse_response(completion_ids)
)

# ChatCompletionLogProbs flatten: {"content": [{"logprob": ...}, ...]}
# ChatCompletionLogProbs flatten: {"content": [{"logprob": ...}, ...]}.
# Same shape on both transports (Dynamo aliases the standard OpenAI
# logprobs field). engine_data.completion_logprobs is a fallback when
# the OpenAI-style logprobs array is absent.
raw_logprobs = choice.get("logprobs") or {}
content_lp = raw_logprobs.get("content") if isinstance(raw_logprobs, dict) else None
completion_logprobs = [float(c.get("logprob") or 0.0) for c in content_lp or []]
if not completion_logprobs:
nvext_resp = data.get("nvext") or {}
engine_data = nvext_resp.get("engine_data") or {}
engine_lp = engine_data.get("completion_logprobs") or []
if engine_lp:
completion_logprobs = [float(x) for x in engine_lp]

routed_experts = None
raw_re = choice.get("routed_experts")
raw_re = choice.get("routed_experts") or (data.get("nvext") or {}).get(
"routed_experts"
)
if isinstance(raw_re, dict) and "data" in raw_re and "shape" in raw_re:
routed_experts = (
np.frombuffer(base64.b85decode(raw_re["data"]), dtype=np.int32)
Expand All @@ -164,13 +236,14 @@ def _prepare():
# never "tool_calls" (a chat-completions concept). Promote stop→tool_calls
# when we extracted tool calls client-side, so OpenAI-compatible agent
# loops continue past the tool turn instead of treating the response as
# final.
# final. Dynamo's chat-completions surface CAN return "tool_calls"
# directly, so this promotion is a no-op there.
finish_reason = choice.get("finish_reason")
if parsed.tool_calls and finish_reason == "stop":
finish_reason = "tool_calls"

return {
"request_id": data.get("request_id") or "",
"request_id": data.get("request_id") or data.get("id") or "",
"prompt_ids": list(prompt_ids),
"completion_ids": list(completion_ids),
"completion_logprobs": completion_logprobs,
Expand All @@ -186,6 +259,117 @@ def _prepare():
}


async def _post_dynamo_chat_nvext(
*,
client: AsyncOpenAI,
model: str,
prompt_ids: list[int],
sp: dict[str, Any],
tools: list[ToolSpec] | None,
cache_salt: str | None,
priority: int | None,
extra_headers: dict[str, str] | None,
messages: list[Message],
) -> dict[str, Any]:
"""POST ``prompt_ids`` to Dynamo's ``/v1/chat/completions`` route.

Mirrors ``verifiers.clients.openai_chat_completions_token_client._post_dynamo_chat_nvext``
in shape, so the wire payload is identical whether the rollout goes
through the token client or the renderer client. Anything that lands
on Dynamo's chat-completions surface, lands here.

Wire shape:

- ``nvext.token_data``: pre-tokenized prompt; Dynamo's preprocessor
skips tokenization when present.
- ``nvext.extra_fields = ["engine_data", "routed_experts"]``: opt-in
to Dynamo's engine metadata and router replay channels.
- ``messages``: placeholder (single user message). Dynamo ignores
when ``token_data`` is present, but the OpenAI schema requires
a non-empty messages array, so we send a 1-token stub.
- ``stop_token_ids`` / ``cache_salt`` / ``logprobs`` / backend sampling
hints ride as passthrough fields accepted by Dynamo's
``PASSTHROUGH_EXTRA_FIELDS`` allowlist.
"""
# Standard OpenAI fields that map 1:1 onto Dynamo's chat-completions
# request schema (validate.rs accepts them natively).
body: dict[str, Any] = {
"model": model,
# Single placeholder user message; ignored when token_data is set.
"messages": [{"role": "user", "content": ""}],
"stream": False,
"nvext": {
"token_data": list(prompt_ids),
"extra_fields": ["engine_data", "routed_experts"],
},
}
if tools:
body["tools"] = tools
if cache_salt is not None:
body["nvext"]["cache_salt"] = cache_salt
if priority is not None:
body["nvext"]["agent_hints"] = {"priority": priority}

# Surface standard sampling params at top level (Dynamo's schema
# recognizes them natively, so they flow into SamplingOptions cleanly).
promotable = (
"max_tokens",
"temperature",
"top_p",
"top_k",
"min_p",
"seed",
"n",
"repetition_penalty",
"min_tokens",
"logprobs",
"skip_special_tokens",
)
for key in promotable:
value = sp.get(key)
if value is None:
continue
if key == "max_tokens":
body["max_completion_tokens"] = value
elif key == "logprobs":
# Standard OpenAI shape: logprobs=true + top_logprobs=N. The
# vLLM TITO surface accepts ``logprobs=N`` (int); Dynamo's
# chat-completions schema requires the bool+top_logprobs split.
body["logprobs"] = True
if isinstance(value, int) and value > 1:
body["top_logprobs"] = value
else:
body[key] = value

# Pass-through hints that Dynamo's PASSTHROUGH_EXTRA_FIELDS allowlist
# accepts (stop_token_ids, token constraints, backend sampling toggles).
for key in (
"stop_token_ids",
"bad_words_token_ids",
"allowed_token_ids",
"detokenize",
):
if sp.get(key) is not None:
body[key] = sp[key]

post_kwargs: dict[str, Any] = {
"cast_to": cast(Any, dict[str, Any]),
"body": body,
}
if extra_headers:
post_kwargs["options"] = cast(Any, {"headers": extra_headers})
try:
return await client.post("/chat/completions", **post_kwargs)
except BadRequestError as exc:
_log_overlong_prompt_diagnostic(
prompt_ids=prompt_ids,
messages=messages,
max_tokens=sp.get("max_tokens"),
exc=exc,
)
raise


def _build_mm_features(
renderer: Renderer | RendererPool,
mm_data: MultiModalData,
Expand Down
Loading