From 45909c6b3ec2e31cbe90d8196d874839e5b7b826 Mon Sep 17 00:00:00 2001 From: jlebensold Date: Mon, 25 Aug 2025 12:50:00 -0400 Subject: [PATCH] timeout fix --- src/swerex/runtime/remote.py | 36 +++++++++++++++++++++++++++++++++--- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/src/swerex/runtime/remote.py b/src/swerex/runtime/remote.py index 3177423..77bf412 100644 --- a/src/swerex/runtime/remote.py +++ b/src/swerex/runtime/remote.py @@ -114,6 +114,27 @@ def _handle_transfer_exception(self, exc_transfer: _ExceptionTransfer) -> None: exception.extra_info = exc_transfer.extra_info raise exception from None + def _compute_http_timeout(self, payload: BaseModel | None) -> aiohttp.ClientTimeout: + """Compute the HTTP client timeout based on the caller-provided timeout. + + Rules: + - If payload has a numeric `timeout`, use 1.5x of that. + - Else if payload has `startup_timeout`, use 1.5x of that. + - Else fall back to 1.5x of the runtime config timeout. + - If the selected base timeout is None, disable the total HTTP timeout. + """ + base_timeout: float | None + if payload is not None: + base_timeout = getattr(payload, "timeout", None) + if base_timeout is None: + base_timeout = getattr(payload, "startup_timeout", None) + else: + base_timeout = self._config.timeout + + if base_timeout is None: + return aiohttp.ClientTimeout(total=None) + return aiohttp.ClientTimeout(total=base_timeout * 1.5) + async def _handle_response_errors(self, response: aiohttp.ClientResponse) -> None: """Raise exceptions found in the request response.""" if response.status == 511: @@ -137,7 +158,7 @@ async def is_alive(self, *, timeout: float | None = None) -> IsAliveResponse: async with session.get( f"{self._api_url}/is_alive", headers=self._headers, - timeout=aiohttp.ClientTimeout(total=timeout_value), + timeout=aiohttp.ClientTimeout(total=(timeout_value * 1.5) if timeout_value else None), ) as response: if response.status == 200: data = await response.json() @@ -181,6 +202,7 @@ async def _request(self, endpoint: str, payload: BaseModel | None, output_class: request_url, json=payload.model_dump() if payload else None, headers=headers, + timeout=self._compute_http_timeout(payload), ) as resp: await self._handle_response_errors(resp) return output_class(**await resp.json()) @@ -240,7 +262,10 @@ async def upload(self, request: UploadRequest) -> UploadResponse: data.add_field("unzip", "true") async with session.post( - f"{self._api_url}/upload", data=data, headers=self._headers + f"{self._api_url}/upload", + data=data, + headers=self._headers, + timeout=self._compute_http_timeout(request), ) as response: await self._handle_response_errors(response) return UploadResponse(**(await response.json())) @@ -253,7 +278,12 @@ async def upload(self, request: UploadRequest) -> UploadResponse: data.add_field("target_path", request.target_path) data.add_field("unzip", "false") - async with session.post(f"{self._api_url}/upload", data=data, headers=self._headers) as response: + async with session.post( + f"{self._api_url}/upload", + data=data, + headers=self._headers, + timeout=self._compute_http_timeout(request), + ) as response: await self._handle_response_errors(response) return UploadResponse(**(await response.json())) else: