Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
16d870b
feat(crypto): introduce backward compatible keyset v3 (prefix 02) wit…
a1denvalu3 May 5, 2026
68efbf2
fix(crypto): replace keyset v3 prefix checks with general >= 02 logic
a1denvalu3 May 5, 2026
0c11684
perf(crypto): optimize BLS pairing verification
a1denvalu3 May 20, 2026
767b4c5
fix(crypto): include 02 prefix in is_base64_keyset_id check
a1denvalu3 May 20, 2026
e9f548f
fixes
a1denvalu3 May 20, 2026
8dba59c
test vectors from NUT
a1denvalu3 May 20, 2026
95ac573
feat(crypto): add BLS12-381 (v3) test vectors and debug tracing
a1denvalu3 May 20, 2026
7142c90
chore: remove accidental junk files
a1denvalu3 May 20, 2026
7a000c5
feat(bls): Add subgroup checks for public points and deterministic ra…
a1denvalu3 May 21, 2026
542d09c
test: add BLS12-381 (v3) test vectors
a1denvalu3 May 21, 2026
0dc6dd2
fix: update tests for BLS12-381 test vectors
a1denvalu3 May 21, 2026
2fd4809
refactor(crypto): global G2 generator caching for BLS
a1denvalu3 May 22, 2026
d997da1
refactor(crypto): formally verify BLS point at infinity using pyblst
a1denvalu3 May 22, 2026
69c370a
fix: resolve mypy errors for BLS12-381 keysets
a1denvalu3 May 25, 2026
99dff86
fix: bls12-381-v3-keyset implementation
a1denvalu3 May 26, 2026
4f02151
fix: refactor duck typing to explicit isinstance checks
a1denvalu3 May 26, 2026
09045b4
fix: secure BLS signature verification and prevent Mint server DoS
a1denvalu3 Jun 5, 2026
c61f7d0
refactor(crypto): improve BLS derivation and error handling
a1denvalu3 Jun 8, 2026
4da3e2a
refactor(crypto): replace custom mod_inverse with built-in pow()
a1denvalu3 Jun 8, 2026
4ca3be7
refactor(crypto): simplify return statement in keyed_verification
a1denvalu3 Jun 8, 2026
6a6a964
refactor(crypto): hoist nested imports to module level and update tes…
a1denvalu3 Jun 8, 2026
c31c68b
fix(mint): resolve db operations and init failures with BLS keysets
a1denvalu3 Jun 8, 2026
ce56376
refactor(tests): hoist dynamic imports to top level in test_mint_db_o…
a1denvalu3 Jun 12, 2026
4230932
fix(crypto): ensure unit string is explicitly lowercased in v2 and v3…
a1denvalu3 Jun 12, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 33 additions & 7 deletions cashu/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,27 @@
from ..mint.events.event_model import LedgerEvent
from .crypto.aes import AESCipher
from .crypto.b_dhke import hash_to_curve
from .crypto.bls import PrivateKey as BlsPrivateKey
from .crypto.bls import PublicKey as BlsPublicKey
from .crypto.keys import (
derive_keys,
derive_keys_deprecated_pre_0_15,
derive_keys_v3,
derive_keyset_id,
derive_keyset_id_deprecated,
derive_keyset_id_v2,
derive_keyset_id_v3,
derive_pubkeys,
is_bls_keyset,
)
from .crypto.secp import PrivateKey, PublicKey
from .crypto.secp import PrivateKey as SecpPrivateKey
from .crypto.secp import PublicKey as SecpPublicKey
from .legacy import derive_keys_backwards_compatible_insecure_pre_0_12
from .settings import settings

PrivateKey = Union[SecpPrivateKey, BlsPrivateKey]
PublicKey = Union[SecpPublicKey, BlsPublicKey]


class DLEQ(BaseModel):
"""
Expand Down Expand Up @@ -769,12 +778,16 @@ def serialize(self):
)

@classmethod
def from_row(cls, row: Row):
def from_row(cls, row: RowMapping):
def deserialize(serialized: str) -> Dict[int, PublicKey]:
return {
int(amount): PublicKey(bytes.fromhex(hex_key))
for amount, hex_key in dict(json.loads(serialized)).items()
}
is_v3 = is_bls_keyset(row["id"])
pub_keys: Dict[int, PublicKey] = {}
for amount, hex_key in dict(json.loads(serialized)).items():
if is_v3:
pub_keys[int(amount)] = BlsPublicKey(bytes.fromhex(hex_key), group="G2")
else:
pub_keys[int(amount)] = SecpPublicKey(bytes.fromhex(hex_key))
return pub_keys

return cls(
id=row["id"],
Expand Down Expand Up @@ -984,7 +997,7 @@ def generate_keys(self):
assert self.public_keys is not None
self.id = derive_keyset_id(self.public_keys)
logger.info(f"Generated keyset v1 ID: {self.id}")
else:
elif self.version_tuple < (0, 21):
self.private_keys = derive_keys(
self.seed, self.derivation_path, self.amounts
)
Expand All @@ -998,6 +1011,19 @@ def generate_keys(self):
assert self.public_keys is not None
self.id = derive_keyset_id_v2(self.public_keys, self.unit.name, self.final_expiry, self.input_fee_ppk)
logger.info(f"Generated keyset v2 ID: {self.id}")
else:
self.private_keys = derive_keys_v3(
self.seed, self.derivation_path, self.amounts
) # type: ignore[assignment]
self.public_keys = derive_pubkeys(self.private_keys, self.amounts) # type: ignore

# KEYSETS V3: BLS12-381 cryptography
if id_in_db:
self.id = id_in_db
else:
assert self.public_keys is not None
self.id = derive_keyset_id_v3(self.public_keys, self.unit.name, self.final_expiry, self.input_fee_ppk) # type: ignore[arg-type]
logger.info(f"Generated keyset v3 (BLS) ID: {self.id}")


# ------- TOKEN -------
Expand Down
2 changes: 1 addition & 1 deletion cashu/core/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Maximum lengths for Pydantic string fields
MAX_UNIT_LEN = 64
MAX_PUBKEY_LEN = 66
MAX_PUBKEY_LEN = 96
MAX_SIG_LEN = 130
MAX_QUOTE_ID_LEN = 256
MAX_INVOICE_DESC_LEN = 1024
Expand Down
5 changes: 4 additions & 1 deletion cashu/core/crypto/b_dhke.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,10 @@ def step2_bob_dleq(
A = a.public_key
assert A
e = hash_e(R1, R2, A, C_) # e = hash(R1, R2, A, C_)
s = p.add(bytes.fromhex(a.multiply(e).to_hex())) # s = p + ek
if isinstance(a, PrivateKey):
s = p.add(bytes.fromhex(a.multiply(e).to_hex())) # s = p + ek
else:
raise TypeError(f"Expected SecpPrivateKey, got {type(a)}")
spk = PrivateKey(bytes.fromhex(s.to_hex()))
epk = PrivateKey(e)
return epk, spk
Expand Down
75 changes: 75 additions & 0 deletions cashu/core/crypto/bls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import os
from typing import Optional

import pyblst

curve_order = 52435875175126190479447740508185965837690552500527637822603658699938581184513
_G2_HEX = '93e02b6052719f607dacd3a088274f65596bd0d09920b61ab5da61bbdc7f5049334cf11213945d57e5ac7d055d042b7e024aa2b2f08f0a91260805272dc51051c6e47ad4fa403b02b4510b647ae3d1770bac0326a805bbefd48056c8c121bdb8'
G2 = pyblst.BlstP2Element().uncompress(bytes.fromhex(_G2_HEX))


class PrivateKey:
def __init__(self, privkey: bytes = b"", scalar: Optional[int] = None):
if scalar is not None:
self.scalar = scalar % curve_order
elif privkey:
self.scalar = int.from_bytes(privkey, "big") % curve_order
else:
self.scalar = int.from_bytes(os.urandom(32), "big") % curve_order

@property
def private_key(self) -> bytes:
return self.scalar.to_bytes(32, "big")

def to_hex(self) -> str:
return self.private_key.hex()

def get_g2_public_key(self) -> "PublicKey":
pt = G2.scalar_mul(self.scalar)
return PublicKey(point=pt, group="G2")

@property
def public_key(self) -> "PublicKey":
return self.get_g2_public_key()


class PublicKey:
def __init__(self, compressed: bytes = b"", point=None, group="G1"):
self.group = group
try:
if point is not None:
self.point = point
elif compressed:
if self.group == "G1":
self.point = pyblst.BlstP1Element().uncompress(compressed)
else:
self.point = pyblst.BlstP2Element().uncompress(compressed)
else:
raise ValueError("Must provide point or compressed bytes")
except Exception:
raise ValueError("The public key could not be parsed or is invalid.")

def format(self, compressed: bool = True) -> bytes:
return self.point.compress()

def serialize(self) -> bytes:
return self.format()

def is_infinity(self) -> bool:
"""Check if the point is the point at infinity (additive identity)."""
if self.group == "G1":
return self.point == pyblst.BlstP1Element()
else:
return self.point == pyblst.BlstP2Element()

def __eq__(self, other):
if isinstance(other, PublicKey):
return self.point == other.point
return False

def __mul__(self, scalar):
if isinstance(scalar, PrivateKey):
return PublicKey(point=self.point.scalar_mul(scalar.scalar), group=self.group)
elif isinstance(scalar, int):
return PublicKey(point=self.point.scalar_mul(scalar), group=self.group)
raise TypeError("Can't multiply with non-scalar")
149 changes: 149 additions & 0 deletions cashu/core/crypto/bls_dhke.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import hashlib
from typing import Optional, Tuple

import pyblst
from loguru import logger

from .bls import G2, PrivateKey, PublicKey, curve_order

# Cashu specific domain separation tag for BLS12-381 G1
DST = b"CASHU_BLS12_381_G1_XMD:SHA-256_SSWU_RO_"
BLS_BATCH_DST = b"Cashu_BLS_Batch_v1"

def hash_to_curve(message: bytes) -> PublicKey:
"""
Hash a message to a point on G1 using SSWU.
"""
pt = pyblst.BlstP1Element().hash_to_group(message, DST)
return PublicKey(point=pt, group="G1")

def step1_alice(
secret_msg: str, blinding_factor: Optional[PrivateKey] = None
) -> tuple[PublicKey, PrivateKey]:
"""
Alice blinds the message: B' = Y * r
where Y = hash_to_curve(secret_msg)
"""
Y: PublicKey = hash_to_curve(secret_msg.encode("utf-8"))
r = blinding_factor or PrivateKey()
B_: PublicKey = Y * r
logger.trace(f"BLS step1: secret='{secret_msg}' -> Y={Y.format().hex()} B_={B_.format().hex()} r={r.to_hex()}")
return B_, r

def step2_bob(B_: PublicKey, a: PrivateKey) -> Tuple[PublicKey, PrivateKey, PrivateKey]:
"""
Bob signs the blinded message: C' = B' * a
Returns C' and dummy DLEQ values since BLS12-381 pairings make DLEQ proofs redundant.
"""
if B_.is_infinity():
raise ValueError("Invalid blinded message: point at infinity")

# The point was already checked to be in G1 during uncompression
# pyblst.BlstP1Element().uncompress() performs the subgroup check
# and throws BLST_POINT_NOT_IN_GROUP if the point is not in G1

C_: PublicKey = B_ * a
logger.trace(f"BLS step2: B_={B_.format().hex()} a={a.to_hex()} C_={C_.format().hex()}")
# Return dummy private keys for backwards compatibility with DLEQ logic elsewhere
return C_, PrivateKey(scalar=1), PrivateKey(scalar=1)

def step3_alice(C_: PublicKey, r: PrivateKey, A: PublicKey) -> PublicKey:
"""
Alice unblinds the signature: C = C' * (1/r)
"""
r_inv = pow(r.scalar, -1, curve_order)
C: PublicKey = C_ * r_inv
logger.trace(f"BLS step3: C_={C_.format().hex()} C={C.format().hex()} r={r.to_hex()}")
return C

def keyed_verification(a: PrivateKey, C: PublicKey, secret_msg: str) -> bool:
"""
Mint verification: checks C == Y * a
"""
Y: PublicKey = hash_to_curve(secret_msg.encode("utf-8"))
return C == Y * a

def pairing_verification(K2: PublicKey, C: PublicKey, secret_msg: str) -> bool:
"""
Verify the BLS signature using pairings.
e(C, G2) == e(Y, K2)
"""
Y = hash_to_curve(secret_msg.encode("utf-8"))

p1 = pyblst.miller_loop(-C.point, G2)
p2 = pyblst.miller_loop(Y.point, K2.point)
return pyblst.final_verify(p1 * p2, pyblst.BlstFP12Element())

def derive_batch_random_scalars(K2s: list[PublicKey], Cs: list[PublicKey], secret_msgs: list[str]) -> list[int]:
"""
Derives deterministic random scalars for batch verification using the Fiat-Shamir heuristic
and rejection sampling to ensure scalars are uniformly distributed over Fr*.
"""
n = len(Cs)
transcript = BLS_BATCH_DST
for i in range(n):
secret_bytes = secret_msgs[i].encode("utf-8")
transcript += Cs[i].format()
transcript += K2s[i].format()
transcript += len(secret_bytes).to_bytes(4, "big")
transcript += secret_bytes

challenge = hashlib.sha256(transcript).digest()

rs = []
for i in range(n):
ctr = 0
while True:
h = hashlib.sha256(challenge + i.to_bytes(4, "big") + ctr.to_bytes(4, "big")).digest()
x = int.from_bytes(h, "big")
if x != 0 and x < curve_order:
rs.append(x)
break
ctr += 1

return rs

def batch_pairing_verification(K2s: list[PublicKey], Cs: list[PublicKey], secret_msgs: list[str]) -> bool:
"""
Batch verifies BLS12-381 signatures using random linear combinations.
This significantly improves performance over checking each signature individually.
"""
n = len(Cs)
if n == 0:
return True

rs = derive_batch_random_scalars(K2s, Cs, secret_msgs)

Ys = [hash_to_curve(msg.encode("utf-8")) for msg in secret_msgs]

# Left side: sum(r_i * C_i)
sum_C = Cs[0].point.scalar_mul(rs[0])
for i in range(1, n):
sum_C = sum_C + Cs[i].point.scalar_mul(rs[i])

# Right side: prod(e(sum(r_i * Y_i), K2_j)) grouped by unique K2
# Group the Y points by their corresponding K2 point
grouped_Ys = {}
for i in range(n):
k2_hex = K2s[i].format().hex()
y_r = Ys[i].point.scalar_mul(rs[i])

if k2_hex not in grouped_Ys:
grouped_Ys[k2_hex] = {"k2": K2s[i].point, "sum_y": y_r}
else:
grouped_Ys[k2_hex]["sum_y"] = grouped_Ys[k2_hex]["sum_y"] + y_r

# Now compute the pairings for each unique K2
miller = pyblst.miller_loop(-sum_C, G2)
for group in grouped_Ys.values():
miller = miller * pyblst.miller_loop(group["sum_y"], group["k2"])

return pyblst.final_verify(miller, pyblst.BlstFP12Element())

def hash_e(*publickeys: PublicKey) -> bytes:
"""Dummy for backwards compatibility"""
e_ = ""
for p in publickeys:
_p = p.format(compressed=True).hex()
e_ += str(_p)
return hashlib.sha256(e_.encode("utf-8")).digest()
Loading
Loading