Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 28 additions & 18 deletions cashu/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ def identifier(self) -> str:

@property
def kind(self) -> JSONRPCSubscriptionKinds:
return JSONRPCSubscriptionKinds.BOLT11_MELT_QUOTE
return JSONRPCSubscriptionKinds.MELT_QUOTE

@property
def unpaid(self) -> bool:
Expand Down Expand Up @@ -437,15 +437,21 @@ def from_row(cls, row: Row):
# SQLITE: row is timestamp (string)
created_time = int(row["created_time"]) if row["created_time"] else None
paid_time = int(row["paid_time"]) if row["paid_time"] else None
issued_time = int(row["issued_time"]) if "issued_time" in row.keys() and row["issued_time"] else None
issued_time = (
int(row["issued_time"])
if "issued_time" in row.keys() and row["issued_time"]
else None
)
except Exception:
# POSTGRES: row is datetime.datetime
created_time = (
int(row["created_time"].timestamp()) if row["created_time"] else None
)
paid_time = int(row["paid_time"].timestamp()) if row["paid_time"] else None
issued_time = (
int(row["issued_time"].timestamp()) if "issued_time" in row.keys() and row["issued_time"] else None
int(row["issued_time"].timestamp())
if "issued_time" in row.keys() and row["issued_time"]
else None
)
return cls(
quote=row["quote"],
Expand Down Expand Up @@ -487,7 +493,7 @@ def identifier(self) -> str:

@property
def kind(self) -> JSONRPCSubscriptionKinds:
return JSONRPCSubscriptionKinds.BOLT11_MINT_QUOTE
return JSONRPCSubscriptionKinds.MINT_QUOTE

@property
def unpaid(self) -> bool:
Expand Down Expand Up @@ -558,11 +564,11 @@ def str(self, amount: int | float) -> str:
elif self == Unit.msat:
return f"{amount} msat"
elif self == Unit.usd:
return f"${amount/100:.2f} USD"
return f"${amount / 100:.2f} USD"
elif self == Unit.eur:
return f"{amount/100:.2f} EUR"
return f"{amount / 100:.2f} EUR"
elif self == Unit.btc:
return f"{amount/1e8:.8f} BTC"
return f"{amount / 1e8:.8f} BTC"
elif self == Unit.auth:
return f"{amount} AUTH"
else:
Expand Down Expand Up @@ -623,18 +629,18 @@ def from_float(cls, amount: float, unit: Unit) -> "Amount":
def sat_to_btc(self) -> str:
if self.unit != Unit.sat:
raise Exception("Amount must be in satoshis")
return f"{self.amount/1e8:.8f}"
return f"{self.amount / 1e8:.8f}"

def msat_to_btc(self) -> str:
if self.unit != Unit.msat:
raise Exception("Amount must be in msat")
sat_amount = Amount(Unit.msat, self.amount).to(Unit.sat, round="up")
return f"{sat_amount.amount/1e8:.8f}"
return f"{sat_amount.amount / 1e8:.8f}"

def cents_to_usd(self) -> str:
if self.unit != Unit.usd and self.unit != Unit.eur:
raise Exception("Amount must be in cents")
return f"{self.amount/100:.2f}"
return f"{self.amount / 100:.2f}"

def str(self) -> str:
return self.unit.str(self.amount)
Expand Down Expand Up @@ -924,8 +930,7 @@ def from_row(cls, row: Row):
def public_keys_hex(self) -> Dict[int, str]:
assert self.public_keys, "public keys not set"
return {
int(amount): key.format().hex()
for amount, key in self.public_keys.items()
int(amount): key.format().hex() for amount, key in self.public_keys.items()
}

def generate_keys(self):
Expand Down Expand Up @@ -966,7 +971,7 @@ def generate_keys(self):
self.seed, self.derivation_path, self.amounts
)
self.public_keys = derive_pubkeys(self.private_keys, self.amounts) # type: ignore

if id_in_db:
# If loading from DB, preserve existing ID
self.id = id_in_db
Expand All @@ -979,14 +984,19 @@ def generate_keys(self):
self.seed, self.derivation_path, self.amounts
)
self.public_keys = derive_pubkeys(self.private_keys, self.amounts) # type: ignore

# KEYSETS V2: Use new keyset ID derivation
if id_in_db:
# If loading from DB, preserve existing ID
self.id = id_in_db
else:
assert self.public_keys is not None
self.id = derive_keyset_id_v2(self.public_keys, self.unit.name, self.final_expiry, self.input_fee_ppk)
self.id = derive_keyset_id_v2(
self.public_keys,
self.unit.name,
self.final_expiry,
self.input_fee_ppk,
)
logger.info(f"Generated keyset v2 ID: {self.id}")


Expand Down Expand Up @@ -1417,9 +1427,9 @@ def from_proof(cls, proof: Proof):
def to_base64(self):
serialize_dict = self.model_dump()
serialize_dict.pop("amount", None)
return (
self.prefix + base64.urlsafe_b64encode(json.dumps(serialize_dict).encode()).decode().rstrip("=")
)
return self.prefix + base64.urlsafe_b64encode(
json.dumps(serialize_dict).encode()
).decode().rstrip("=")

@classmethod
def from_base64(cls, base64_str: str):
Expand Down
29 changes: 26 additions & 3 deletions cashu/core/json_rpc/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from enum import Enum
from typing import Annotated, List
from typing import Annotated, List, Union

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator

from ..constants import MAX_QUOTE_ID_LEN
from ..settings import settings
Expand Down Expand Up @@ -58,9 +58,27 @@ class JSONRPCMethods(Enum):


class JSONRPCSubscriptionKinds(Enum):
MINT_QUOTE = "mint_quote"
MELT_QUOTE = "melt_quote"
PROOF_STATE = "proof_state"

# TODO: Remove these deprecated bolt11-specific aliases once old websocket
# clients have been migrated to the method-independent NUT-17 kinds.
BOLT11_MINT_QUOTE = "bolt11_mint_quote"
BOLT11_MELT_QUOTE = "bolt11_melt_quote"
PROOF_STATE = "proof_state"

@classmethod
def normalize(
cls, kind: Union["JSONRPCSubscriptionKinds", str]
) -> "JSONRPCSubscriptionKinds":
parsed_kind = kind if isinstance(kind, cls) else cls(kind)

if parsed_kind == cls.BOLT11_MINT_QUOTE:
return cls.MINT_QUOTE
if parsed_kind == cls.BOLT11_MELT_QUOTE:
return cls.MELT_QUOTE

return parsed_kind


class JSONRPCStatus(Enum):
Expand All @@ -74,6 +92,11 @@ class JSONRPCSubscribeParams(BaseModel):
)
subId: str

@field_validator("kind", mode="before")
@classmethod
def normalize_kind(cls, kind):
return JSONRPCSubscriptionKinds.normalize(kind)


class JSONRPCUnsubscribeParams(BaseModel):
subId: str
Expand Down
2 changes: 1 addition & 1 deletion cashu/core/mint_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def supports_websocket_mint_quote(self, method: Method, unit: Unit) -> bool:
websocket_supported = websocket_settings["supported"]
for entry in websocket_supported:
if entry["method"] == method.name and entry["unit"] == unit.name:
if "bolt11_mint_quote" in entry["commands"]:
if "mint_quote" in entry["commands"]:
return True
return False

Expand Down
23 changes: 17 additions & 6 deletions cashu/mint/events/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ def add_subscription(
filters: List[str],
subId: str,
) -> None:
kind = JSONRPCSubscriptionKinds.normalize(kind)

if kind not in self.subscriptions:
self.subscriptions[kind] = {}

Expand All @@ -187,7 +189,7 @@ def add_subscription(
for f in filters:
logger.debug(f"Adding subscription {subId} for filter {f}")
self.subscriptions[kind].setdefault(f, []).append(subId)

# Initialize the subscriptions in batch
asyncio.create_task(self._init_subscriptions(subId, filters, kind))

Expand Down Expand Up @@ -215,26 +217,35 @@ def serialize_event(self, event: LedgerEvent) -> dict:
async def _init_subscriptions(
self, subId: str, filters: List[str], kind: JSONRPCSubscriptionKinds
):
kind = JSONRPCSubscriptionKinds.normalize(kind)
results = []
async with self.db_read.db.connect() as conn:
if kind == JSONRPCSubscriptionKinds.BOLT11_MINT_QUOTE:
if kind == JSONRPCSubscriptionKinds.MINT_QUOTE:
for filter in filters:
mint_quote = await self.db_read.crud.get_mint_quote(
quote_id=filter, db=self.db_read.db, conn=conn
)
if mint_quote:
results.append(PostMintQuoteResponse.from_mint_quote(mint_quote).model_dump())
elif kind == JSONRPCSubscriptionKinds.BOLT11_MELT_QUOTE:
results.append(
PostMintQuoteResponse.from_mint_quote(
mint_quote
).model_dump()
)
elif kind == JSONRPCSubscriptionKinds.MELT_QUOTE:
for filter in filters:
melt_quote = await self.db_read.crud.get_melt_quote(
quote_id=filter, db=self.db_read.db, conn=conn
)
if melt_quote:
results.append(PostMeltQuoteResponse.from_melt_quote(melt_quote).model_dump())
results.append(
PostMeltQuoteResponse.from_melt_quote(
melt_quote
).model_dump()
)
elif kind == JSONRPCSubscriptionKinds.PROOF_STATE:
proofs = await self.db_read.get_proofs_states(Ys=filters, conn=conn)
for proof in proofs:
results.append(proof.model_dump())

for result in results:
await self._send_obj(result, subId)
10 changes: 5 additions & 5 deletions cashu/mint/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@
_BOLT11 = "bolt11"
_MPP = "mpp"
_COMMANDS = "commands"
_BOLT11_MINT_QUOTE = "bolt11_mint_quote"
_BOLT11_MELT_QUOTE = "bolt11_melt_quote"
_MINT_QUOTE = "mint_quote"
_MELT_QUOTE = "melt_quote"
_PROOF_STATE = "proof_state"
_PROTECTED_ENDPOINTS = "protected_endpoints"
_BAT_MAX_MINT = "bat_max_mint"
Expand Down Expand Up @@ -148,23 +148,23 @@ def add_websocket_features(
websocket_features: Dict[str, List[Dict[str, Union[str, List[str]]]]] = {
_SUPPORTED: []
}
# we check the backend to see if "bolt11_mint_quote" is supported as well
# we check the backend to see if "mint_quote" is supported as well
for method, unit_dict in self.backends.items():
if method == Method[_BOLT11]:
for unit in unit_dict.keys():
websocket_features[_SUPPORTED].append(
{
_METHOD: method.name,
_UNIT: unit.name,
_COMMANDS: [_BOLT11_MELT_QUOTE, _PROOF_STATE],
_COMMANDS: [_MELT_QUOTE, _PROOF_STATE],
}
)
if unit_dict[unit].supports_incoming_payment_stream:
supported_features: List[str] = list(
websocket_features[_SUPPORTED][-1][_COMMANDS]
)
websocket_features[_SUPPORTED][-1][_COMMANDS] = (
supported_features + [_BOLT11_MINT_QUOTE]
supported_features + [_MINT_QUOTE]
)

if websocket_features:
Expand Down
2 changes: 1 addition & 1 deletion cashu/wallet/mint_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def supports_websocket_mint_quote(self, method: Method, unit: Unit) -> bool:
websocket_supported = websocket_settings["supported"]
for entry in websocket_supported:
if entry["method"] == method.name and entry["unit"] == unit.name:
if "bolt11_mint_quote" in entry["commands"]:
if "mint_quote" in entry["commands"]:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If a new wallet talks to an old mint that still advertises bolt11_mint_quote, would we skip websockets and fall back to polling?
may be worth accepting both during migration
"mint_quote" in commands or "bolt11_mint_quote" in commands

return True
return False

Expand Down
28 changes: 14 additions & 14 deletions cashu/wallet/wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,9 @@ async def load_mint_info(self, reload=False, offline=False) -> MintInfo | None:
logger.debug("Updating mint info in db.")
await update_mint(
db=self.db,
mint=WalletMint(url=self.url, info=json.dumps(self.mint_info.model_dump())),
mint=WalletMint(
url=self.url, info=json.dumps(self.mint_info.model_dump())
),
)
return self.mint_info
else:
Expand Down Expand Up @@ -482,7 +484,7 @@ async def request_mint_with_callback(
target=subscriptions.connect, name="SubscriptionManager", daemon=True
).start()
subscriptions.subscribe(
kind=JSONRPCSubscriptionKinds.BOLT11_MINT_QUOTE,
kind=JSONRPCSubscriptionKinds.MINT_QUOTE,
filters=[mint_quote.quote],
callback=callback,
)
Expand Down Expand Up @@ -700,9 +702,9 @@ async def split(
send_outputs, keep_outputs, secret_lock
)

assert len(secrets) == len(
amounts
), "number of secrets does not match number of outputs"
assert len(secrets) == len(amounts), (
"number of secrets does not match number of outputs"
)
# verify that we didn't accidentally reuse a secret
await self._check_used_secrets(secrets)

Expand Down Expand Up @@ -947,9 +949,9 @@ def verify_proofs_dleq(self, proofs: List[Proof]):
return
logger.trace("Verifying DLEQ proof.")
assert proof.id
assert (
proof.id in self.keysets
), f"Keyset {proof.id} not known, can not verify DLEQ."
assert proof.id in self.keysets, (
f"Keyset {proof.id} not known, can not verify DLEQ."
)
if not b_dhke.carol_verify_dleq(
secret_msg=proof.secret,
C=PublicKey(bytes.fromhex(proof.C)),
Expand Down Expand Up @@ -1060,9 +1062,9 @@ def _construct_outputs(
Raises:
AssertionError: if len(amounts) != len(secrets)
"""
assert len(amounts) == len(
secrets
), f"len(amounts)={len(amounts)} not equal to len(secrets)={len(secrets)}"
assert len(amounts) == len(secrets), (
f"len(amounts)={len(amounts)} not equal to len(secrets)={len(secrets)}"
)
keyset_id = keyset_id or self.keyset_id
outputs: List[BlindedMessage] = []
rs_ = [None] * len(amounts) if not rs else rs
Expand All @@ -1077,9 +1079,7 @@ def _construct_outputs(

assert r
rs_return.append(r)
output = BlindedMessage(
amount=amount, B_=B_.format().hex(), id=keyset_id
)
output = BlindedMessage(amount=amount, B_=B_.format().hex(), id=keyset_id)
outputs.append(output)
logger.trace(f"Constructing output: {output}, r: {r.to_hex()}")

Expand Down
Loading
Loading