From 55137030539a3d4c9c944eb8fc9755f117b14876 Mon Sep 17 00:00:00 2001 From: b-l-u-e <8102260+blue@users.noreply.github.com> Date: Fri, 22 May 2026 01:11:36 +0300 Subject: [PATCH] fix(mint): handle websocket disconnect frames cleanly Signed-off-by: b-l-u-e <8102260+blue@users.noreply.github.com> --- cashu/mint/events/client.py | 2 ++ cashu/mint/router.py | 9 ++++----- tests/mint/test_mint_websocket_protocol.py | 12 ++++++++++++ 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/cashu/mint/events/client.py b/cashu/mint/events/client.py index 6c78e3385..3f1c2d1bd 100644 --- a/cashu/mint/events/client.py +++ b/cashu/mint/events/client.py @@ -51,6 +51,8 @@ async def start(self): self.websocket.receive(), timeout=settings.mint_websocket_read_timeout, ) + if message.get("type") == "websocket.disconnect": + raise WebSocketDisconnect(code=message.get("code", 1000)) message_text = message.get("text") # Check the rate limit diff --git a/cashu/mint/router.py b/cashu/mint/router.py index 55f0439eb..84867200d 100644 --- a/cashu/mint/router.py +++ b/cashu/mint/router.py @@ -212,7 +212,7 @@ async def get_mint_quote(request: Request, quote: str) -> PostMintQuoteResponse: @router.websocket("/v1/ws", name="Websocket endpoint for subscriptions") async def websocket_endpoint(websocket: WebSocket): limit_websocket(websocket) - disconnected = False + client = None try: client = ledger.events.add_client(websocket, ledger.db, ledger.crud) except Exception as e: @@ -225,13 +225,12 @@ async def websocket_endpoint(websocket: WebSocket): await client.start() except WebSocketDisconnect as e: logger.debug(f"Websocket disconnected: {e}") - disconnected = True - return except Exception as e: logger.debug(f"Exception: {e}") - ledger.events.remove_client(client) finally: - if not disconnected: + if client and client in ledger.events.clients: + ledger.events.remove_client(client) + if websocket.client_state.name != "DISCONNECTED": await asyncio.wait_for(websocket.close(), timeout=1) diff --git a/tests/mint/test_mint_websocket_protocol.py b/tests/mint/test_mint_websocket_protocol.py index e7b4ba073..e6c83ea50 100644 --- a/tests/mint/test_mint_websocket_protocol.py +++ b/tests/mint/test_mint_websocket_protocol.py @@ -92,6 +92,18 @@ async def test_websocket_start_returns_jsonrpc_errors(monkeypatch): assert parsed[3]["error"]["code"] == JSONRPCErrorCode.INTERNAL_ERROR.value +@pytest.mark.asyncio +async def test_websocket_start_exits_on_disconnect_message(monkeypatch): + websocket = FakeWebSocket([{"type": "websocket.disconnect", "code": 1012}]) + manager = _client_manager(websocket) + monkeypatch.setattr("cashu.mint.events.client.limit_websocket", lambda ws: None) + + with pytest.raises(WebSocketDisconnect) as exc_info: + await manager.start() + + assert exc_info.value.code == 1012 + + @pytest.mark.asyncio async def test_handle_request_subscribe_and_unsubscribe_roundtrip(monkeypatch): manager = _client_manager(FakeWebSocket())