diff --git a/workers/comfyui-json/server.py b/workers/comfyui-json/server.py index cd970a3..e51a425 100644 --- a/workers/comfyui-json/server.py +++ b/workers/comfyui-json/server.py @@ -33,33 +33,39 @@ async def generate_client_response( client_request: web.Request, model_response: ClientResponse ) -> Union[web.Response, web.StreamResponse]: - # Check if the response is actually streaming based on response headers/content-type - is_streaming_response = ( - model_response.content_type == "text/event-stream" - or model_response.content_type == "application/x-ndjson" - or model_response.headers.get("Transfer-Encoding") == "chunked" - or "stream" in model_response.content_type.lower() - ) - - if is_streaming_response: - log.debug("Detected streaming response...") - res = web.StreamResponse() - res.content_type = model_response.content_type - await res.prepare(client_request) - async for chunk in model_response.content: - await res.write(chunk) - await res.write_eof() - log.debug("Done streaming response") - return res - else: - log.debug("Detected non-streaming response...") - content = await model_response.read() - return web.Response( - body=content, - status=model_response.status, - content_type=model_response.content_type - ) - + match model_response.status: + case 200: + log.debug("SUCCESS") + # Check if the response is actually streaming based on response headers/content-type + is_streaming_response = ( + model_response.content_type == "text/event-stream" + or model_response.content_type == "application/x-ndjson" + or model_response.headers.get("Transfer-Encoding") == "chunked" + or "stream" in model_response.content_type.lower() + ) + + if is_streaming_response: + log.debug("Detected streaming response...") + res = web.StreamResponse() + res.content_type = model_response.content_type + await res.prepare(client_request) + async for chunk in model_response.content: + await res.write(chunk) + await res.write_eof() + log.debug("Done streaming response") + return res + else: + log.debug("Detected non-streaming response...") + content = await model_response.read() + return web.Response( + body=content, + status=model_response.status, + content_type=model_response.content_type + ) + case code: + log.debug(f"Model responded with error {code}") + return web.Response(status=code) + @dataclasses.dataclass class ComfyWorkflowHandler(EndpointHandler[ComfyWorkflowData]):