diff --git a/cashu/lightning/base.py b/cashu/lightning/base.py index f8bd925ee..45946810f 100644 --- a/cashu/lightning/base.py +++ b/cashu/lightning/base.py @@ -119,6 +119,10 @@ def assert_unit_supported(self, unit: Unit): if unit not in self.supported_units: raise Unsupported(f"Unit {unit} is not supported") + async def cleanup(self) -> None: + """Cleanup method for backends to release resources like connections""" + pass + @abstractmethod def __init__(self, unit: Unit, **kwargs): pass diff --git a/cashu/lightning/blink.py b/cashu/lightning/blink.py index 94a1326ab..e1abfac7d 100644 --- a/cashu/lightning/blink.py +++ b/cashu/lightning/blink.py @@ -85,6 +85,12 @@ def __init__(self, unit: Unit = Unit.sat, **kwargs): timeout=None, ) + async def cleanup(self): + try: + await self.client.aclose() + except RuntimeError as e: + logger.warning(f"Error closing wallet connection: {e}") + async def status(self) -> StatusResponse: try: data = { diff --git a/cashu/lightning/fake.py b/cashu/lightning/fake.py index 54344a813..d2bf0b50d 100644 --- a/cashu/lightning/fake.py +++ b/cashu/lightning/fake.py @@ -63,6 +63,16 @@ class FakeWallet(LightningBackend): def __init__(self, unit: Unit = Unit.sat, **kwargs): self.assert_unit_supported(unit) self.unit = unit + self.tasks: set[asyncio.Task] = set() + + async def cleanup(self) -> None: + """Cancel any running background tasks when the backend is cleaned up.""" + for task in self.tasks: + if not task.done(): + task.cancel() + if self.tasks: + await asyncio.gather(*self.tasks, return_exceptions=True) + self.tasks.clear() async def status(self) -> StatusResponse: return StatusResponse( @@ -179,7 +189,9 @@ async def create_invoice( payment_request = encode(bolt11, self.privkey) if settings.fakewallet_brr: - asyncio.create_task(self.mark_invoice_paid(bolt11)) + task = asyncio.create_task(self.mark_invoice_paid(bolt11)) + self.tasks.add(task) + task.add_done_callback(self.tasks.discard) return InvoiceResponse( ok=True, checking_id=payment_hash, payment_request=payment_request diff --git a/cashu/lightning/lnbits.py b/cashu/lightning/lnbits.py index e4392aeb0..f161405d3 100644 --- a/cashu/lightning/lnbits.py +++ b/cashu/lightning/lnbits.py @@ -45,6 +45,12 @@ def __init__(self, unit: Unit = Unit.sat, **kwargs): self.ws_url = f"{self.endpoint.replace('http', 'ws', 1)}/api/v1/ws/{settings.mint_lnbits_key}" self.old_api = True + async def cleanup(self): + try: + await self.client.aclose() + except RuntimeError as e: + logger.warning(f"Error closing wallet connection: {e}") + async def status(self) -> StatusResponse: try: r = await self.client.get(url=f"{self.endpoint}/api/v1/wallet", timeout=15) diff --git a/cashu/lightning/lndrest.py b/cashu/lightning/lndrest.py index f8057fecc..8276eee23 100644 --- a/cashu/lightning/lndrest.py +++ b/cashu/lightning/lndrest.py @@ -105,6 +105,12 @@ def __init__(self, unit: Unit = Unit.sat, **kwargs): if self.supports_mpp: logger.info("LNDRestWallet enabling MPP feature") + async def cleanup(self): + try: + await self.client.aclose() + except RuntimeError as e: + logger.warning(f"Error closing wallet connection: {e}") + async def status(self) -> StatusResponse: try: r = await self.client.get("/v1/balance/channels") diff --git a/cashu/lightning/strike.py b/cashu/lightning/strike.py index b069f9adb..71febc013 100644 --- a/cashu/lightning/strike.py +++ b/cashu/lightning/strike.py @@ -2,6 +2,7 @@ from typing import AsyncGenerator, Dict, Optional, Union import httpx +from loguru import logger from pydantic import BaseModel from ..core.base import Amount, MeltQuote, Unit @@ -133,6 +134,12 @@ def __init__(self, unit: Unit, **kwargs): timeout=None, ) + async def cleanup(self): + try: + await self.client.aclose() + except RuntimeError as e: + logger.warning(f"Error closing wallet connection: {e}") + async def status(self) -> StatusResponse: try: r = await self.client.get(url=f"{self.endpoint}/v1/balances", timeout=15) diff --git a/cashu/mint/ledger.py b/cashu/mint/ledger.py index e292e1671..c69035bf6 100644 --- a/cashu/mint/ledger.py +++ b/cashu/mint/ledger.py @@ -190,8 +190,6 @@ async def _check_backends(self) -> None: logger.info(f"Data dir: {settings.cashu_dir}") async def shutdown_ledger(self) -> None: - logger.debug("Disconnecting from database") - await self.db.engine.dispose() logger.debug("Shutting down invoice listeners") for task in self.invoice_listener_tasks: task.cancel() @@ -201,6 +199,20 @@ async def shutdown_ledger(self) -> None: for task in self.regular_tasks: task.cancel() + # Wait for all background tasks to finish cancellation + tasks_to_wait = self.invoice_listener_tasks + self.watchdog_tasks + self.regular_tasks + if tasks_to_wait: + await asyncio.gather(*tasks_to_wait, return_exceptions=True) + + logger.debug("Shutting down backends") + for method, unitbackends in self.backends.items(): + for unit, backend in unitbackends.items(): + if hasattr(backend, "cleanup"): + await backend.cleanup() + + logger.debug("Disconnecting from database") + await self.db.engine.dispose() + async def _check_pending_proofs_and_melt_quotes(self): """Startup routine that checks all pending melt quotes and either invalidates their pending proofs for a successful melt or deletes them if the melt failed.