41 changes: 29 additions & 12 deletions tensorrt_llm/serve/openai_disagg_server.py
@@ -251,18 +251,35 @@ async def merge_streaming_responses(self, ctx_response,
             if ctx_response is not None and ctx_response.choices[0].finish_reason not in ["length", "not_finished"]:
                 yield "data: [DONE]\n\n".encode('utf-8')
             else:
-                # Then yield the generation responses
-                await self._increment_metric("gen_total_requests")
-                if isinstance(gen_req, CompletionRequest):
-                    gen_response = await self.send_completion_request(gen_server, gen_req)
-                elif isinstance(gen_req, ChatCompletionRequest):
-                    gen_response = await self.send_chat_request(gen_server, gen_req)
-                else:
-                    raise TypeError("Invalid request type: {type(gen_req).__name__}")
-
-                async for chunk in gen_response.body_iterator:
-                    yield chunk
-                await self._increment_metric("gen_completed_requests")
+                token_count = 0
+                for attempt in range(self.max_retries + 1):
+                    # A retry here is needed in addition to the retry inside send_request:
+                    # streaming through `body_iterator` can fail after the request has been
+                    # sent (e.g. the connection is lost mid-stream), in which case the
+                    # request must be re-sent to the same server.
+                    try:
+                        # Then yield the generation responses
+                        await self._increment_metric("gen_total_requests")
+                        if isinstance(gen_req, CompletionRequest):
+                            gen_response = await self.send_completion_request(gen_server, gen_req)
+                        elif isinstance(gen_req, ChatCompletionRequest):
+                            gen_response = await self.send_chat_request(gen_server, gen_req)
+                        else:
+                            raise TypeError(f"Invalid request type: {type(gen_req).__name__}")
+                        async for chunk in gen_response.body_iterator:
+                            # Counts bytes of the streamed chunks (a proxy for emitted
+                            # tokens), used to detect whether output has reached the client.
+                            token_count += len(chunk)
+                            yield chunk
+                        await self._increment_metric("gen_completed_requests")
+                        break  # stream completed successfully; stop retrying
+                    except (aiohttp.ClientError, OSError) as e:
+                        # Only retry while nothing has been yielded yet; once chunks have
+                        # reached the client, a retry would duplicate them.
+                        if attempt == self.max_retries or token_count > 0:
+                            raise HTTPException(status_code=HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Internal server error after {token_count} tokens") from e
+                        logger.warning(f"Client error: {e} - retry {attempt} of {self.max_retries}")
+                        # TODO: make the retry interval configurable
+                        await asyncio.sleep(1)
+                    except Exception as e:
+                        logger.error(f"Error encountered while streaming generation response: {e}")
+                        raise
 
         finally:
             await self.gen_router.finish_request(gen_req)
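
The pattern the added code implements — retry the upstream request only while nothing has been streamed back — can be shown in isolation. Below is a minimal sketch, assuming a hypothetical `fetch_stream()` coroutine that opens a fresh upstream request on each call and yields its byte chunks; the names and the fixed retry interval are illustrative, not part of this PR:

```python
import asyncio

import aiohttp


async def stream_with_retry(fetch_stream, max_retries: int,
                            retry_interval: float = 1.0):
    """Yield chunks from fetch_stream(), retrying only before the first chunk."""
    bytes_yielded = 0
    for attempt in range(max_retries + 1):
        try:
            # fetch_stream() is assumed to return a fresh async iterator of
            # bytes and to raise aiohttp.ClientError / OSError on network loss.
            async for chunk in fetch_stream():
                bytes_yielded += len(chunk)
                yield chunk
            return  # stream finished cleanly; stop retrying
        except (aiohttp.ClientError, OSError):
            # Once any chunk has reached the consumer, a retry would replay the
            # stream from the beginning and duplicate output, so re-raise.
            if attempt == max_retries or bytes_yielded > 0:
                raise
            await asyncio.sleep(retry_interval)
```

The asymmetry is deliberate: before the first chunk a retry is invisible to the client, while after it the already-sent events cannot be recalled, so failing fast is the only consistent option.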