diff --git a/tensorrt_llm/serve/openai_disagg_server.py b/tensorrt_llm/serve/openai_disagg_server.py index 2b5f7dc59c0..d59e2c92e22 100644 --- a/tensorrt_llm/serve/openai_disagg_server.py +++ b/tensorrt_llm/serve/openai_disagg_server.py @@ -251,18 +251,35 @@ async def merge_streaming_responses(self, ctx_response, if ctx_response is not None and ctx_response.choices[0].finish_reason not in ["length", "not_finished"]: yield "data: [DONE]\n\n".encode('utf-8') else: - # Then yield the generation responses - await self._increment_metric("gen_total_requests") - if isinstance(gen_req, CompletionRequest): - gen_response = await self.send_completion_request(gen_server, gen_req) - elif isinstance(gen_req, ChatCompletionRequest): - gen_response = await self.send_chat_request(gen_server, gen_req) - else: - raise TypeError("Invalid request type: {type(gen_req).__name__}") - - async for chunk in gen_response.body_iterator: - yield chunk - await self._increment_metric("gen_completed_requests") + token_count = 0 + for attempt in range(self.max_retries + 1): + # Note that the retry here is needed as well as the retry in send_request function. + # This is because we are using the `body_iterator` to stream the response, and it may raise an exception if the connection is lost. + # The retry is needed so we send the request again to the same server. + try: + # Then yield the generation responses + await self._increment_metric("gen_total_requests") + if isinstance(gen_req, CompletionRequest): + gen_response = await self.send_completion_request(gen_server, gen_req) + elif isinstance(gen_req, ChatCompletionRequest): + gen_response = await self.send_chat_request(gen_server, gen_req) + else: + raise TypeError(f"Invalid request type: {type(gen_req).__name__}") + async for chunk in gen_response.body_iterator: + token_count += len(chunk) + yield chunk + await self._increment_metric("gen_completed_requests") + except (aiohttp.ClientError, OSError) as e: + # We will retry if no tokens have been yielded as otherwise we will need to discard the tokens + # that have been yielded. + if attempt == self.max_retries or token_count > 0: + raise HTTPException(status_code=HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Internal server error after {token_count} tokens") from e + logger.warning(f"Client error: {e} - retry {attempt} of {self.max_retries}") + # TODO : add a configurable retry interval + await asyncio.sleep(1) + except Exception as e: + logger.error(f"Error encountered while streaming generation response: {e}") + raise finally: await self.gen_router.finish_request(gen_req)