diff --git a/curl_cffi/_asyncio_selector.py b/curl_cffi/_asyncio_selector.py index 13b06a47..df180939 100644 --- a/curl_cffi/_asyncio_selector.py +++ b/curl_cffi/_asyncio_selector.py @@ -209,8 +209,8 @@ def _run_select(self) -> None: rs, _, _ = select.select([self._waker_r.fileno()], [], [], 0) if rs: ws = [] - else: - raise + # If we're here, the socket was probably closed + # Do not re-raise else: raise diff --git a/curl_cffi/requests/websockets.py b/curl_cffi/requests/websockets.py index 28373d24..645017a1 100644 --- a/curl_cffi/requests/websockets.py +++ b/curl_cffi/requests/websockets.py @@ -139,7 +139,7 @@ def _unpack_close_frame(frame: bytes) -> tuple[int, str]: "Invalid close frame", WsCloseCode.PROTOCOL_ERROR ) from e else: - if code < 3000 and (code not in WsCloseCode or code == 1005): + if code < 3000 and (code not in WsCloseCode._value2member_map_ or code == 1005): raise WebSocketError( "Invalid close code", WsCloseCode.PROTOCOL_ERROR ) @@ -567,12 +567,15 @@ async def recv_fragment( Args: timeout: how many seconds to wait before giving up. """ - if self.closed: - raise WebSocketClosed("WebSocket is closed") if self._recv_lock.locked(): raise TypeError("Concurrent call to recv_fragment() is not allowed") async with self._recv_lock: + # We must check the closed state after the last asyncio tick (i.e. the above async with call) + # as a race condition arises where the websocket is not yet closed until we're inside here + if self.closed: + raise WebSocketClosed("WebSocket is closed") + try: chunk, frame = await asyncio.wait_for( self.loop.run_in_executor(None, self.curl.ws_recv), timeout @@ -606,30 +609,45 @@ async def recv(self, *, timeout: Optional[float] = None) -> tuple[bytes, int]: timeout: how many seconds to wait before giving up. """ loop = self.loop - chunks = [] - flags = 0 - sock_fd = await loop.run_in_executor( - None, self.curl.getinfo, CurlInfo.ACTIVESOCKET - ) - if sock_fd == CURL_SOCKET_BAD: - raise WebSocketError( - "Invalid active socket", CurlECode.NO_CONNECTION_AVAILABLE + async def _inner_recv() -> tuple[bytes, int]: + chunks = [] + flags = 0 + + # We must check the closed state after the last asyncio tick (i.e. the above async with call) + # as a race condition arises where the websocket is not yet closed until we're inside here + if self.closed: + raise WebSocketClosed("WebSocket is closed") + + sock_fd = await loop.run_in_executor( + None, self.curl.getinfo, CurlInfo.ACTIVESOCKET ) - while True: - try: - chunk, frame = await self.recv_fragment(timeout=timeout) - flags = frame.flags - chunks.append(chunk) - if frame.bytesleft == 0 and flags & CurlWsFlag.CONT == 0: - break - except CurlError as e: - if e.code == CurlECode.AGAIN: - await aselect(sock_fd, loop=loop, timeout=timeout) - else: - raise + if sock_fd == CURL_SOCKET_BAD: + raise WebSocketError( + "Invalid active socket", CurlECode.NO_CONNECTION_AVAILABLE + ) - return b"".join(chunks), flags + while True: + try: + chunk, frame = await self.recv_fragment(timeout=timeout) + flags = frame.flags + chunks.append(chunk) + if frame.bytesleft == 0 and flags & CurlWsFlag.CONT == 0: + break + except CurlError as e: + if e.code == CurlECode.AGAIN: + # We don't use the timeout here because it deadlocks if + # the socket is closed while recv() is waiting + await aselect(sock_fd, loop=loop, timeout=0.5) + else: + raise + + return b"".join(chunks), flags + + if timeout: + return await asyncio.wait_for(_inner_recv(), timeout=timeout) + else: + return await _inner_recv() async def recv_str(self, *, timeout: Optional[float] = None) -> str: """Receive a text frame. @@ -663,15 +681,15 @@ async def send( payload: data to send. flags: flags for the frame. """ - if self.closed: - raise WebSocketClosed("WebSocket is closed") - # curl expects bytes if isinstance(payload, str): payload = payload.encode() - # TODO: Why does concurrently sending fail async with self._send_lock: + # We must check the closed state after the last asyncio tick (i.e. the above async with call) + # as a race condition arises where the websocket is not yet closed until we're inside here + if self.closed: + raise WebSocketClosed("WebSocket is closed") return await self.loop.run_in_executor( None, self.curl.ws_send, payload, flags )