From 2f2ec1231be861d753ca8c9bc79936586f4f2115 Mon Sep 17 00:00:00 2001 From: vietddude Date: Thu, 2 Oct 2025 17:21:55 +0700 Subject: [PATCH 01/21] wip: implement taurus lib --- e2e/go.mod | 14 +- e2e/go.sum | 34 +- examples/generate/main.go | 12 + go.mod | 8 +- go.sum | 29 +- peers.json | 5 + pkg/encoding/json.go | 13 + pkg/event/keygen.go | 21 +- pkg/eventconsumer/event_consumer.go | 133 ++++--- pkg/mpc/node.go | 58 +++ pkg/mpc/taurus/ecdsa_resharing_session.go | 351 +++++++++++++++++ pkg/mpc/taurus/eddsa_resharing_session.go | 345 +++++++++++++++++ pkg/mpc/taurus/keygen_session.go | 328 ++++++++++++++++ pkg/mpc/taurus/node.go | 242 ++++++++++++ pkg/mpc/taurus/node_test.go | 119 ++++++ pkg/mpc/taurus/registry.go | 213 ++++++++++ pkg/mpc/taurus/reshare_session.go | 32 ++ pkg/mpc/taurus/session.go | 184 +++++++++ pkg/mpc/taurus/signing_session.go | 332 ++++++++++++++++ pkg/protocol/cggmp21/adapter.go | 452 ++++++++++++++++++++++ pkg/protocol/frost/adapter.go | 445 +++++++++++++++++++++ pkg/protocol/interfaces.go | 91 +++++ pkg/types/taurus.go | 10 + pkg/utils/utils.go | 20 + setup_identities.sh | 3 + wallets.json | 3 + 26 files changed, 3424 insertions(+), 73 deletions(-) create mode 100644 peers.json create mode 100644 pkg/encoding/json.go create mode 100644 pkg/mpc/taurus/ecdsa_resharing_session.go create mode 100644 pkg/mpc/taurus/eddsa_resharing_session.go create mode 100644 pkg/mpc/taurus/keygen_session.go create mode 100644 pkg/mpc/taurus/node.go create mode 100644 pkg/mpc/taurus/node_test.go create mode 100644 pkg/mpc/taurus/registry.go create mode 100644 pkg/mpc/taurus/reshare_session.go create mode 100644 pkg/mpc/taurus/session.go create mode 100644 pkg/mpc/taurus/signing_session.go create mode 100644 pkg/protocol/cggmp21/adapter.go create mode 100644 pkg/protocol/frost/adapter.go create mode 100644 pkg/protocol/interfaces.go create mode 100644 pkg/types/taurus.go create mode 100644 pkg/utils/utils.go create mode 100644 wallets.json diff --git a/e2e/go.mod b/e2e/go.mod index 4112bb3..82c3de4 100644 --- a/e2e/go.mod +++ b/e2e/go.mod @@ -1,12 +1,12 @@ module github.com/fystack/mpcium/e2e -go 1.23.0 +go 1.23.8 require ( github.com/dgraph-io/badger/v4 v4.7.0 github.com/fystack/mpcium v0.0.0-00010101000000-000000000000 github.com/google/uuid v1.6.0 - github.com/hashicorp/consul/api v1.26.1 + github.com/hashicorp/consul/api v1.32.1 github.com/nats-io/nats.go v1.31.0 github.com/stretchr/testify v1.10.0 gopkg.in/yaml.v2 v2.4.0 @@ -37,13 +37,15 @@ require ( github.com/btcsuite/btcd/chaincfg/chainhash v1.1.0 // indirect github.com/btcsuite/btcutil v1.0.2 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/cronokirby/saferith v0.33.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/decred/dcrd/dcrec/edwards/v2 v2.0.3 // indirect - github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 // indirect + github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 // indirect github.com/dgraph-io/ristretto/v2 v2.2.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/fatih/color v1.18.0 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/fxamacker/cbor/v2 v2.4.0 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/gogo/protobuf v1.3.2 // indirect @@ -61,6 +63,7 @@ require ( github.com/ipfs/go-log v1.0.5 // indirect github.com/ipfs/go-log/v2 v2.1.3 // indirect github.com/klauspost/compress v1.18.0 // indirect + github.com/klauspost/cpuid/v2 v2.2.5 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect @@ -83,6 +86,9 @@ require ( github.com/spf13/pflag v1.0.6 // indirect github.com/spf13/viper v1.18.0 // indirect github.com/subosito/gotenv v1.6.0 // indirect + github.com/taurusgroup/multi-party-sig v0.7.0-alpha-2025-01-28 // indirect + github.com/x448/float16 v0.8.4 // indirect + github.com/zeebo/blake3 v0.2.3 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/otel v1.35.0 // indirect go.opentelemetry.io/otel/metric v1.35.0 // indirect @@ -91,7 +97,7 @@ require ( go.uber.org/multierr v1.9.0 // indirect go.uber.org/zap v1.21.0 // indirect golang.org/x/crypto v0.37.0 // indirect - golang.org/x/exp v0.0.0-20240112132812-db7319d0e0e3 // indirect + golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 // indirect golang.org/x/net v0.39.0 // indirect golang.org/x/sys v0.33.0 // indirect golang.org/x/term v0.31.0 // indirect diff --git a/e2e/go.sum b/e2e/go.sum index 804a2c4..cac7878 100644 --- a/e2e/go.sum +++ b/e2e/go.sum @@ -88,6 +88,8 @@ github.com/circonus-labs/circonus-gometrics v2.3.1+incompatible/go.mod h1:nmEj6D github.com/circonus-labs/circonusllhist v0.1.3/go.mod h1:kMXHVDlOchFAehlya5ePtbp5jckzBHf4XRpQvBOLI+I= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= +github.com/cronokirby/saferith v0.33.0 h1:TgoQlfsD4LIwx71+ChfRcIpjkw+RPOapDEVxa+LhwLo= +github.com/cronokirby/saferith v0.33.0/go.mod h1:QKJhjoqUtBsXCAVEjw38mFqoi7DebT7kthcD7UzbnoA= github.com/davecgh/go-spew v0.0.0-20171005155431-ecdeabc65495/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -96,8 +98,9 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8Yc github.com/decred/dcrd/crypto/blake256 v1.0.0/go.mod h1:sQl2p6Y26YV+ZOcSTP6thNdn47hh8kt6rqSlvmrXFAc= github.com/decred/dcrd/dcrec/edwards/v2 v2.0.3 h1:l/lhv2aJCUignzls81+wvga0TFlyoZx8QxRMQgXpZik= github.com/decred/dcrd/dcrec/edwards/v2 v2.0.3/go.mod h1:AKpV6+wZ2MfPRJnTbQ6NPgWrKzbe9RCIlCF/FKzMtM8= -github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 h1:YLtO71vCjJRCBcrPMtQ9nqBsqpA1m5sE92cU+pd5Mcc= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1/go.mod h1:hyedUtir6IdtD/7lIxGeCxkaw7y45JueMRL4DIyJDKs= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 h1:8UrgZ3GkP4i/CLijOJx79Yu+etlyjdBU4sfcs2WYQMs= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0= github.com/decred/dcrd/lru v1.0.0/go.mod h1:mxKOwFd7lFjN2GZYsiz/ecgqR6kkYAl+0pz0tEMk218= github.com/dgraph-io/badger/v4 v4.7.0 h1:Q+J8HApYAY7UMpL8d9owqiB+odzEc0zn/aqOD9jhc6Y= github.com/dgraph-io/badger/v4 v4.7.0/go.mod h1:He7TzG3YBy3j4f5baj5B7Zl2XyfNe5bl4Udl0aPemVA= @@ -118,6 +121,8 @@ github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMo github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/fxamacker/cbor/v2 v2.4.0 h1:ri0ArlOR+5XunOP8CRUowT0pSJOwhW098ZCUyskZD88= +github.com/fxamacker/cbor/v2 v2.4.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= @@ -158,10 +163,10 @@ github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/ github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/hashicorp/consul/api v1.26.1 h1:5oSXOO5fboPZeW5SN+TdGFP/BILDgBm19OrPZ/pICIM= -github.com/hashicorp/consul/api v1.26.1/go.mod h1:B4sQTeaSO16NtynqrAdwOlahJ7IUDZM9cj2420xYL8A= -github.com/hashicorp/consul/sdk v0.15.0 h1:2qK9nDrr4tiJKRoxPGhm6B7xJjLVIQqkjiab2M4aKjU= -github.com/hashicorp/consul/sdk v0.15.0/go.mod h1:r/OmRRPbHOe0yxNahLw7G9x5WG17E1BIECMtCjcPSNo= +github.com/hashicorp/consul/api v1.32.1 h1:0+osr/3t/aZNAdJX558crU3PEjVrG4x6715aZHRgceE= +github.com/hashicorp/consul/api v1.32.1/go.mod h1:mXUWLnxftwTmDv4W3lzxYCPD199iNLLUyLfLGFJbtl4= +github.com/hashicorp/consul/sdk v0.16.1 h1:V8TxTnImoPD5cj0U9Spl0TUxcytjcbbJeADFF07KdHg= +github.com/hashicorp/consul/sdk v0.16.1/go.mod h1:fSXvwxB2hmh1FMZCNl6PwX0Q/1wdWtHJcZ7Ea5tns0s= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= @@ -220,6 +225,9 @@ github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+o github.com/kkdai/bstream v0.0.0-20161212061736-f391b8402d23/go.mod h1:J+Gs4SYgM6CZQHDETBtE9HaSEkGmuNXF86RwHhHUvq4= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c= +github.com/klauspost/cpuid/v2 v2.2.5 h1:0E5MSMDEoAulmXNFquVs//DdoomxaoTY1kUhbc/qbZg= +github.com/klauspost/cpuid/v2 v2.2.5/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= @@ -361,12 +369,22 @@ github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7/go.mod h1:q4W45IWZaF22tdD+VEXcAWRA037jwmWEB5VWYORlTpc= +github.com/taurusgroup/multi-party-sig v0.7.0-alpha-2025-01-28 h1:rbyJpV3kH/aMxG7gUQ5ynveAEXuPiIG136Ld3HGNV7I= +github.com/taurusgroup/multi-party-sig v0.7.0-alpha-2025-01-28/go.mod h1:roZI3gaKCo15PUSB4LdJpTLTjq8TFsJiOH5kpcN1HpQ= github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM= github.com/urfave/cli v1.22.5/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= +github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= +github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/zeebo/assert v1.1.0 h1:hU1L1vLTHsnO8x8c9KAR5GmM5QscxHg5RNU5z5qbUWY= +github.com/zeebo/assert v1.1.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= +github.com/zeebo/blake3 v0.2.3 h1:TFoLXsjeXqRNFxSbk35Dk4YtszE/MQQGK10BH4ptoTg= +github.com/zeebo/blake3 v0.2.3/go.mod h1:mjJjZpnsyIVtVgTOSpJ9vmRE4wgDeyt2HU3qXvvKCaQ= +github.com/zeebo/pcg v1.0.1 h1:lyqfGeWiv4ahac6ttHs+I5hwtH/+1mrhlCtVNQM2kHo= +github.com/zeebo/pcg v1.0.1/go.mod h1:09F0S9iiKrwn9rlI5yjLkmrug154/YRW6KnnXVDM/l4= go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= go.opentelemetry.io/otel v1.35.0 h1:xKWKPxrxB6OtMCbmMY021CqC45J+3Onta9MqjhnusiQ= @@ -402,8 +420,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= -golang.org/x/exp v0.0.0-20240112132812-db7319d0e0e3 h1:hNQpMuAJe5CtcUqCXaWga3FHu+kQvCqcsoVaQgSV60o= -golang.org/x/exp v0.0.0-20240112132812-db7319d0e0e3/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08= +golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 h1:nDVHiLt8aIbd/VzvPWN6kSOPE7+F/fNFDSXLVYkE/Iw= +golang.org/x/exp v0.0.0-20250305212735-054e65f0b394/go.mod h1:sIifuuw/Yco/y6yb6+bDNfyeQ/MdPUy/hKEMYQV17cM= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= @@ -440,6 +458,8 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= +golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/examples/generate/main.go b/examples/generate/main.go index 3f0135f..a2cd085 100644 --- a/examples/generate/main.go +++ b/examples/generate/main.go @@ -92,9 +92,19 @@ func main() { walletIDsMu.Unlock() } + // Track processed results to prevent duplicate processing + processedResults := sync.Map{} + // STEP 2: Register the result handler AFTER all walletIDs are stored err = mpcClient.OnWalletCreationResult(func(event event.KeygenResultEvent) { logger.Info("Received wallet creation result", "event", event) + + // Check if we've already processed this result + if _, alreadyProcessed := processedResults.LoadOrStore(event.WalletID, true); alreadyProcessed { + logger.Warn("Duplicate wallet result received, ignoring", "walletID", event.WalletID) + return + } + now := time.Now() startTimeAny, ok := walletStartTimes.Load(event.WalletID) if ok { @@ -127,6 +137,8 @@ func main() { if err := mpcClient.CreateWallet(walletID); err != nil { logger.Error("CreateWallet failed", err) walletStartTimes.Delete(walletID) + // Mark this wallet as processed to prevent callback from processing it + processedResults.Store(walletID, true) wg.Done() // Now this is safe since we added 1 above continue } diff --git a/go.mod b/go.mod index d9cc99b..21eee92 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/kms v1.45.0 github.com/bnb-chain/tss-lib/v2 v2.0.2 github.com/decred/dcrd/dcrec/edwards/v2 v2.0.3 + github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 github.com/dgraph-io/badger/v4 v4.7.0 github.com/google/uuid v1.6.0 github.com/hashicorp/consul/api v1.32.1 @@ -21,6 +22,7 @@ require ( github.com/samber/lo v1.39.0 github.com/spf13/viper v1.18.0 github.com/stretchr/testify v1.10.0 + github.com/taurusgroup/multi-party-sig v0.7.0-alpha-2025-01-28 github.com/urfave/cli/v3 v3.3.2 golang.org/x/crypto v0.37.0 golang.org/x/term v0.31.0 @@ -45,12 +47,13 @@ require ( github.com/btcsuite/btcd/chaincfg/chainhash v1.1.0 // indirect github.com/btcsuite/btcutil v1.0.2 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/cronokirby/saferith v0.33.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect - github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 // indirect github.com/dgraph-io/ristretto/v2 v2.2.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/fatih/color v1.18.0 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/fxamacker/cbor/v2 v2.4.0 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/gogo/protobuf v1.3.2 // indirect @@ -68,6 +71,7 @@ require ( github.com/ipfs/go-log v1.0.5 // indirect github.com/ipfs/go-log/v2 v2.1.3 // indirect github.com/klauspost/compress v1.18.0 // indirect + github.com/klauspost/cpuid/v2 v2.2.5 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect @@ -88,6 +92,8 @@ require ( github.com/spf13/pflag v1.0.6 // indirect github.com/stretchr/objx v0.5.2 // indirect github.com/subosito/gotenv v1.6.0 // indirect + github.com/x448/float16 v0.8.4 // indirect + github.com/zeebo/blake3 v0.2.3 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/otel v1.35.0 // indirect go.opentelemetry.io/otel/metric v1.35.0 // indirect diff --git a/go.sum b/go.sum index da2768e..df0a58f 100644 --- a/go.sum +++ b/go.sum @@ -88,6 +88,8 @@ github.com/circonus-labs/circonus-gometrics v2.3.1+incompatible/go.mod h1:nmEj6D github.com/circonus-labs/circonusllhist v0.1.3/go.mod h1:kMXHVDlOchFAehlya5ePtbp5jckzBHf4XRpQvBOLI+I= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= +github.com/cronokirby/saferith v0.33.0 h1:TgoQlfsD4LIwx71+ChfRcIpjkw+RPOapDEVxa+LhwLo= +github.com/cronokirby/saferith v0.33.0/go.mod h1:QKJhjoqUtBsXCAVEjw38mFqoi7DebT7kthcD7UzbnoA= github.com/davecgh/go-spew v0.0.0-20171005155431-ecdeabc65495/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -96,8 +98,9 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8Yc github.com/decred/dcrd/crypto/blake256 v1.0.0/go.mod h1:sQl2p6Y26YV+ZOcSTP6thNdn47hh8kt6rqSlvmrXFAc= github.com/decred/dcrd/dcrec/edwards/v2 v2.0.3 h1:l/lhv2aJCUignzls81+wvga0TFlyoZx8QxRMQgXpZik= github.com/decred/dcrd/dcrec/edwards/v2 v2.0.3/go.mod h1:AKpV6+wZ2MfPRJnTbQ6NPgWrKzbe9RCIlCF/FKzMtM8= -github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 h1:YLtO71vCjJRCBcrPMtQ9nqBsqpA1m5sE92cU+pd5Mcc= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1/go.mod h1:hyedUtir6IdtD/7lIxGeCxkaw7y45JueMRL4DIyJDKs= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 h1:8UrgZ3GkP4i/CLijOJx79Yu+etlyjdBU4sfcs2WYQMs= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0= github.com/decred/dcrd/lru v1.0.0/go.mod h1:mxKOwFd7lFjN2GZYsiz/ecgqR6kkYAl+0pz0tEMk218= github.com/dgraph-io/badger/v4 v4.7.0 h1:Q+J8HApYAY7UMpL8d9owqiB+odzEc0zn/aqOD9jhc6Y= github.com/dgraph-io/badger/v4 v4.7.0/go.mod h1:He7TzG3YBy3j4f5baj5B7Zl2XyfNe5bl4Udl0aPemVA= @@ -118,6 +121,8 @@ github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMo github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/fxamacker/cbor/v2 v2.4.0 h1:ri0ArlOR+5XunOP8CRUowT0pSJOwhW098ZCUyskZD88= +github.com/fxamacker/cbor/v2 v2.4.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= @@ -158,13 +163,10 @@ github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/ github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/hashicorp/consul/api v1.26.1 h1:5oSXOO5fboPZeW5SN+TdGFP/BILDgBm19OrPZ/pICIM= -github.com/hashicorp/consul/api v1.26.1/go.mod h1:B4sQTeaSO16NtynqrAdwOlahJ7IUDZM9cj2420xYL8A= github.com/hashicorp/consul/api v1.32.1 h1:0+osr/3t/aZNAdJX558crU3PEjVrG4x6715aZHRgceE= github.com/hashicorp/consul/api v1.32.1/go.mod h1:mXUWLnxftwTmDv4W3lzxYCPD199iNLLUyLfLGFJbtl4= -github.com/hashicorp/consul/sdk v0.15.0 h1:2qK9nDrr4tiJKRoxPGhm6B7xJjLVIQqkjiab2M4aKjU= -github.com/hashicorp/consul/sdk v0.15.0/go.mod h1:r/OmRRPbHOe0yxNahLw7G9x5WG17E1BIECMtCjcPSNo= github.com/hashicorp/consul/sdk v0.16.1 h1:V8TxTnImoPD5cj0U9Spl0TUxcytjcbbJeADFF07KdHg= +github.com/hashicorp/consul/sdk v0.16.1/go.mod h1:fSXvwxB2hmh1FMZCNl6PwX0Q/1wdWtHJcZ7Ea5tns0s= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= @@ -223,6 +225,9 @@ github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+o github.com/kkdai/bstream v0.0.0-20161212061736-f391b8402d23/go.mod h1:J+Gs4SYgM6CZQHDETBtE9HaSEkGmuNXF86RwHhHUvq4= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c= +github.com/klauspost/cpuid/v2 v2.2.5 h1:0E5MSMDEoAulmXNFquVs//DdoomxaoTY1kUhbc/qbZg= +github.com/klauspost/cpuid/v2 v2.2.5/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= @@ -364,14 +369,24 @@ github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7/go.mod h1:q4W45IWZaF22tdD+VEXcAWRA037jwmWEB5VWYORlTpc= +github.com/taurusgroup/multi-party-sig v0.7.0-alpha-2025-01-28 h1:rbyJpV3kH/aMxG7gUQ5ynveAEXuPiIG136Ld3HGNV7I= +github.com/taurusgroup/multi-party-sig v0.7.0-alpha-2025-01-28/go.mod h1:roZI3gaKCo15PUSB4LdJpTLTjq8TFsJiOH5kpcN1HpQ= github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM= github.com/urfave/cli v1.22.5/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= github.com/urfave/cli/v3 v3.3.2 h1:BYFVnhhZ8RqT38DxEYVFPPmGFTEf7tJwySTXsVRrS/o= github.com/urfave/cli/v3 v3.3.2/go.mod h1:FJSKtM/9AiiTOJL4fJ6TbMUkxBXn7GO9guZqoZtpYpo= +github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= +github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/zeebo/assert v1.1.0 h1:hU1L1vLTHsnO8x8c9KAR5GmM5QscxHg5RNU5z5qbUWY= +github.com/zeebo/assert v1.1.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= +github.com/zeebo/blake3 v0.2.3 h1:TFoLXsjeXqRNFxSbk35Dk4YtszE/MQQGK10BH4ptoTg= +github.com/zeebo/blake3 v0.2.3/go.mod h1:mjJjZpnsyIVtVgTOSpJ9vmRE4wgDeyt2HU3qXvvKCaQ= +github.com/zeebo/pcg v1.0.1 h1:lyqfGeWiv4ahac6ttHs+I5hwtH/+1mrhlCtVNQM2kHo= +github.com/zeebo/pcg v1.0.1/go.mod h1:09F0S9iiKrwn9rlI5yjLkmrug154/YRW6KnnXVDM/l4= go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= go.opentelemetry.io/otel v1.35.0 h1:xKWKPxrxB6OtMCbmMY021CqC45J+3Onta9MqjhnusiQ= @@ -407,8 +422,6 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= -golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= -golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 h1:nDVHiLt8aIbd/VzvPWN6kSOPE7+F/fNFDSXLVYkE/Iw= golang.org/x/exp v0.0.0-20250305212735-054e65f0b394/go.mod h1:sIifuuw/Yco/y6yb6+bDNfyeQ/MdPUy/hKEMYQV17cM= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= @@ -447,6 +460,8 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= +golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/peers.json b/peers.json new file mode 100644 index 0000000..0670265 --- /dev/null +++ b/peers.json @@ -0,0 +1,5 @@ +{ + "node0": "aa4adaea-257d-4337-842a-1d3f966d85c2", + "node1": "21ac5259-ac9e-4b81-bd42-d05f584879e4", + "node2": "2fff5119-a1f1-4763-8f4c-d7d88c212608" +} \ No newline at end of file diff --git a/pkg/encoding/json.go b/pkg/encoding/json.go new file mode 100644 index 0000000..87ad4a3 --- /dev/null +++ b/pkg/encoding/json.go @@ -0,0 +1,13 @@ +package encoding + +import "encoding/json" + +// StructToJsonBytes converts a struct to JSON bytes +func StructToJsonBytes(v any) ([]byte, error) { + return json.Marshal(v) +} + +// JsonBytesToStruct converts JSON bytes to a struct +func JsonBytesToStruct(data []byte, v any) error { + return json.Unmarshal(data, v) +} diff --git a/pkg/event/keygen.go b/pkg/event/keygen.go index 78ab631..f72b1ae 100644 --- a/pkg/event/keygen.go +++ b/pkg/event/keygen.go @@ -7,11 +7,26 @@ const ( ) type KeygenResultEvent struct { - WalletID string `json:"wallet_id"` - ECDSAPubKey []byte `json:"ecdsa_pub_key"` - EDDSAPubKey []byte `json:"eddsa_pub_key"` + WalletID string `json:"wallet_id"` + ECDSAPubKey []byte `json:"ecdsa_pub_key"` + EDDSAPubKey []byte `json:"eddsa_pub_key"` + TaurusCMPPubKey []byte `json:"taurus_cmp_pub_key"` ResultType ResultType `json:"result_type"` ErrorReason string `json:"error_reason"` ErrorCode string `json:"error_code"` } + +// CreateKeygenFailureEvent creates a failed keygen event +func CreateKeygenFailureEvent(walletID string, metadata map[string]any) *KeygenResultEvent { + errorMsg := "" + if err, ok := metadata["error"].(string); ok { + errorMsg = err + } + return &KeygenResultEvent{ + WalletID: walletID, + ResultType: ResultTypeError, + ErrorReason: errorMsg, + ErrorCode: string(ErrorCodeKeygenFailure), + } +} diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index 691c712..7e08100 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -2,6 +2,7 @@ package eventconsumer import ( "context" + "encoding/hex" "encoding/json" "errors" "fmt" @@ -167,81 +168,111 @@ func (ec *eventConsumer) handleKeyGenEvent(natMsg *nats.Msg) { } walletID := msg.WalletID - ecdsaSession, err := ec.node.CreateKeyGenSession(mpc.SessionTypeECDSA, walletID, ec.mpcThreshold, ec.genKeyResultQueue) + // ecdsaSession, err := ec.node.CreateKeyGenSession(mpc.SessionTypeECDSA, walletID, ec.mpcThreshold, ec.genKeyResultQueue) + // if err != nil { + // ec.handleKeygenSessionError(walletID, err, "Failed to create ECDSA key generation session", natMsg) + // return + // } + // eddsaSession, err := ec.node.CreateKeyGenSession(mpc.SessionTypeEDDSA, walletID, ec.mpcThreshold, ec.genKeyResultQueue) + // if err != nil { + // ec.handleKeygenSessionError(walletID, err, "Failed to create EdDSA key generation session", natMsg) + // return + // } + taurusSession, err := ec.node.CreateCMPKeyGenSession(walletID, ec.mpcThreshold, ec.genKeyResultQueue) if err != nil { - ec.handleKeygenSessionError(walletID, err, "Failed to create ECDSA key generation session", natMsg) + logger.Error("Failed to create Taurus CMP session", err, "walletID", walletID) + ec.handleKeygenSessionError(walletID, err, "Failed to create Taurus CMP key generation session", natMsg) return } - eddsaSession, err := ec.node.CreateKeyGenSession(mpc.SessionTypeEDDSA, walletID, ec.mpcThreshold, ec.genKeyResultQueue) - if err != nil { - ec.handleKeygenSessionError(walletID, err, "Failed to create EdDSA key generation session", natMsg) - return - } - ecdsaSession.Init() - eddsaSession.Init() - ctxEcdsa, doneEcdsa := context.WithCancel(baseCtx) - ctxEddsa, doneEddsa := context.WithCancel(baseCtx) + // ecdsaSession.Init() + // eddsaSession.Init() + taurusSession.Init() + + // ctxEcdsa, doneEcdsa := context.WithCancel(baseCtx) + // ctxEddsa, doneEddsa := context.WithCancel(baseCtx) + ctxTaurus, doneTaurus := context.WithCancel(baseCtx) successEvent := &event.KeygenResultEvent{WalletID: walletID, ResultType: event.ResultTypeSuccess} var wg sync.WaitGroup - wg.Add(2) + wg.Add(1) // Channel to communicate errors from goroutines to main function - errorChan := make(chan error, 2) + errorChan := make(chan error, 1) + + // go func() { + // defer wg.Done() + // select { + // case <-ctxEcdsa.Done(): + // successEvent.ECDSAPubKey = ecdsaSession.GetPubKeyResult() + // case err := <-ecdsaSession.ErrChan(): + // logger.Error("ECDSA keygen session error", err) + // ec.handleKeygenSessionError(walletID, err, "ECDSA keygen session error", natMsg) + // errorChan <- err + // doneEcdsa() + // } + // }() + // go func() { + // defer wg.Done() + // select { + // case <-ctxEddsa.Done(): + // successEvent.EDDSAPubKey = eddsaSession.GetPubKeyResult() + // case err := <-eddsaSession.ErrChan(): + // logger.Error("EdDSA keygen session error", err) + // ec.handleKeygenSessionError(walletID, err, "EdDSA keygen session error", natMsg) + // errorChan <- err + // doneEddsa() + // } + // }() go func() { - defer wg.Done() - select { - case <-ctxEcdsa.Done(): - successEvent.ECDSAPubKey = ecdsaSession.GetPubKeyResult() - case err := <-ecdsaSession.ErrChan(): - logger.Error("ECDSA keygen session error", err) - ec.handleKeygenSessionError(walletID, err, "ECDSA keygen session error", natMsg) - errorChan <- err - doneEcdsa() - } - }() - go func() { - defer wg.Done() select { - case <-ctxEddsa.Done(): - successEvent.EDDSAPubKey = eddsaSession.GetPubKeyResult() - case err := <-eddsaSession.ErrChan(): - logger.Error("EdDSA keygen session error", err) - ec.handleKeygenSessionError(walletID, err, "EdDSA keygen session error", natMsg) - errorChan <- err - doneEddsa() + case <-ctxTaurus.Done(): + return + case err := <-taurusSession.ErrChan(): + if err != nil { + logger.Error("CGGMP21 keygen session error", err) + errorChan <- err + doneTaurus() + } } }() - ecdsaSession.ListenToIncomingMessageAsync() - eddsaSession.ListenToIncomingMessageAsync() - + // ecdsaSession.ListenToIncomingMessageAsync() + // eddsaSession.ListenToIncomingMessageAsync() + taurusSession.ListenToIncomingMessageAsync(taurusSession.ProcessInboundMessage) // Temporary delay for peer setup ec.warmUpSession() - go ecdsaSession.GenerateKey(doneEcdsa) - go eddsaSession.GenerateKey(doneEddsa) + // go ecdsaSession.GenerateKey(doneEcdsa) + // go eddsaSession.GenerateKey(doneEddsa) + go taurusSession.ProcessOutboundMessage() - // Wait for completion or timeout - doneAll := make(chan struct{}) + // Wait for the keygen to complete + completionChan := make(chan string, 1) go func() { - wg.Wait() - close(doneAll) + result := taurusSession.WaitForFinish() + completionChan <- result }() + // Wait for completion, error, or timeout select { - case <-doneAll: - // Check if any errors occurred during execution - select { - case <-errorChan: - // Error already handled by the goroutine, just return early - return - default: - // No errors, continue with success + case pubKeyHex := <-completionChan: + // Success - set the public key + if pubKeyHex != "" { + pubKeyBytes, err := hex.DecodeString(pubKeyHex) + if err == nil { + successEvent.TaurusCMPPubKey = pubKeyBytes + } } + doneTaurus() // Signal completion + + case err := <-errorChan: + // Error occurred + ec.handleKeygenSessionError(walletID, err, "CGGMP21 keygen error", natMsg) + return + case <-baseCtx.Done(): - // timeout occurred + // Timeout occurred logger.Warn("Key generation timed out", "walletID", walletID, "timeout", KeyGenTimeOut) ec.handleKeygenSessionError(walletID, fmt.Errorf("keygen session timed out after %v", KeyGenTimeOut), "Key generation timed out", natMsg) return diff --git a/pkg/mpc/node.go b/pkg/mpc/node.go index d615444..fea226d 100644 --- a/pkg/mpc/node.go +++ b/pkg/mpc/node.go @@ -14,6 +14,8 @@ import ( "github.com/fystack/mpcium/pkg/kvstore" "github.com/fystack/mpcium/pkg/logger" "github.com/fystack/mpcium/pkg/messaging" + "github.com/fystack/mpcium/pkg/mpc/taurus" + "github.com/taurusgroup/multi-party-sig/pkg/party" ) const ( @@ -140,6 +142,38 @@ func (p *Node) createEDDSAKeyGenSession(walletID string, threshold int, version return session, nil } +func (p *Node) CreateCMPKeyGenSession( + walletID string, + threshold int, + resultQueue messaging.MessageQueue, +) (taurus.KeyGenSession, error) { + if !p.peerRegistry.ArePeersReady() { + return nil, fmt.Errorf( + "peers are not ready yet. ready: %d, expected: %d", + p.peerRegistry.GetReadyPeersCount(), + len(p.peerIDs)+1, + ) + } + + readyPeerIDs := p.peerRegistry.GetReadyPeersIncludeSelf() + selfPartyID, allPartyIDs := p.generateTaurusPartyIDs(PurposeKeygen, readyPeerIDs, DefaultVersion) + + session := taurus.NewCGGMP21KeygenSession( + walletID, + p.pubSub, + selfPartyID, + allPartyIDs, + threshold, + p.kvstore, + p.keyinfoStore, + resultQueue, + p.identityStore, + ) + + session.Init() + return session, nil +} + func (p *Node) CreateSigningSession( sessionType SessionType, walletID string, @@ -471,3 +505,27 @@ func sessionKeyPrefix(sessionType SessionType) (string, error) { return "", fmt.Errorf("unsupported session type: %v", sessionType) } } + +func (p *Node) generateTaurusPartyIDs(purpose string, peerIDs []string, version int) (party.ID, []party.ID) { + partyIDs := make([]party.ID, len(peerIDs)) + var selfPartyID party.ID + + for i, peerID := range peerIDs { + partyID := createTaurusPartyID(peerID, purpose, version) + partyIDs[i] = partyID + if peerID == p.nodeID { + selfPartyID = partyID + } + } + + return selfPartyID, partyIDs +} + +func createTaurusPartyID(sessionID string, keyType string, version int) party.ID { + if version == 0 { + // Backward compatible version - just use sessionID + return party.ID(sessionID) + } + // Include version in party ID + return party.ID(fmt.Sprintf("%s:%s:%d", sessionID, keyType, version)) +} diff --git a/pkg/mpc/taurus/ecdsa_resharing_session.go b/pkg/mpc/taurus/ecdsa_resharing_session.go new file mode 100644 index 0000000..0aab2d1 --- /dev/null +++ b/pkg/mpc/taurus/ecdsa_resharing_session.go @@ -0,0 +1,351 @@ +package taurus + +import ( + "crypto/ecdsa" + "encoding/json" + "fmt" + "math/big" + "sync" + + "github.com/fystack/mpcium/pkg/keyinfo" + "github.com/fystack/mpcium/pkg/kvstore" + "github.com/fystack/mpcium/pkg/messaging" + "github.com/fystack/mpcium/pkg/protocol" + "github.com/fystack/mpcium/pkg/protocol/cggmp21" + "github.com/fystack/mpcium/pkg/utils" + "github.com/rs/zerolog" + "github.com/taurusgroup/multi-party-sig/pkg/party" +) + +// cggmp21ReshareSession implements ReshareSession for ECDSA using CGGMP21 +type cggmp21ReshareSession struct { + session + isNewPeer bool + pubKeyResult []byte + kvstore kvstore.KVStore + keyinfoStore keyinfo.Store + resultQueue messaging.MessageQueue + protocol protocol.Protocol + party protocol.Party + config protocol.KeyGenConfig + newThreshold int + newNodeIDs []string +} + +// newCGGMP21ReshareSession creates a new CGGMP21 reshare session +func newCGGMP21ReshareSession( + walletID string, + threshold int, + newThreshold int, + newNodeIDs []string, + isNewPeer bool, + pubSub messaging.PubSub, + kvstore kvstore.KVStore, + keyinfoStore keyinfo.Store, + resultQueue messaging.MessageQueue, + selfNodeID string, +) (*cggmp21ReshareSession, error) { + // Generate session ID for resharing + sessionID := fmt.Sprintf("reshare-%s", walletID) + + // For resharing, we need to determine the party IDs + var partyIDs []party.ID + + if !isNewPeer { + // For old peers, get the existing key info to find current parties + keyInfo, err := keyinfoStore.Get(walletID) + if err != nil { + return nil, fmt.Errorf("failed to get key info for resharing: %w", err) + } + + // Old peers use their existing party IDs + for _, id := range keyInfo.ParticipantPeerIDs { + partyIDs = append(partyIDs, party.ID(id)) + } + } else { + // New peers use the new node IDs + for _, id := range newNodeIDs { + partyIDs = append(partyIDs, party.ID(id)) + } + } + + // Create CGGMP21 protocol + protocol := cggmp21.NewCGGMP21Protocol() + + s := &cggmp21ReshareSession{ + session: session{ + walletID: walletID, + sessionID: sessionID, + pubSub: pubSub, + selfPartyID: party.ID(selfNodeID), + partyIDs: partyIDs, + subscriberList: []messaging.Subscription{}, + rounds: 5, // CGGMP21 has 5 rounds + outCh: make(chan msg, 100), + errCh: make(chan error, 10), + finishCh: make(chan bool, 1), + externalFinishChan: make(chan string, 1), + threshold: threshold, + kvstore: kvstore, + keyinfoStore: keyinfoStore, + resultQueue: resultQueue, + logger: zerolog.New(utils.ZerologConsoleWriter()).With().Timestamp().Logger(), + processing: make(map[string]bool), + processingLock: sync.Mutex{}, + topicComposer: &TopicComposer{ + ComposeBroadcastTopic: func() string { + return fmt.Sprintf("reshare:broadcast:cggmp21:%s", walletID) + }, + ComposeDirectTopic: func(nodeID string) string { + return fmt.Sprintf("reshare:direct:cggmp21:%s:%s", nodeID, walletID) + }, + }, + identityStore: nil, // Not needed for resharing + }, + isNewPeer: isNewPeer, + kvstore: kvstore, + keyinfoStore: keyinfoStore, + resultQueue: resultQueue, + protocol: protocol, + newThreshold: newThreshold, + newNodeIDs: newNodeIDs, + } + + // Load existing config for old peers + if !isNewPeer { + config, err := s.loadConfig(walletID) + if err != nil { + return nil, fmt.Errorf("failed to load existing config: %w", err) + } + s.config = config + } + + return s, nil +} + +// Init initializes the reshare session +func (s *cggmp21ReshareSession) Init() { + s.logger.Info(). + Str("sessionID", s.sessionID). + Bool("isNewPeer", s.isNewPeer). + Int("threshold", s.threshold). + Int("newThreshold", s.newThreshold). + Msg("Initializing CGGMP21 reshare session") +} + +// Reshare starts the resharing protocol +func (s *cggmp21ReshareSession) Reshare(done func()) { + defer done() + + s.logger.Info(). + Str("sessionID", s.sessionID). + Bool("isNewPeer", s.isNewPeer). + Int("threshold", s.threshold). + Msg("Starting CGGMP21 reshare session") + + // Create the protocol party + var err error + if s.isNewPeer { + // New peers participate in key generation with the new committee + // For new peers, this is essentially a new key generation + // but coordinated with the refresh protocol of old peers + s.party, err = s.protocol.KeyGen( + string(s.selfPartyID), + convertFromPartyIDs(s.partyIDs), + s.newThreshold, + ) + } else { + // Old peers run the refresh protocol + s.party, err = s.protocol.Refresh(s.config) + } + + if err != nil { + s.errCh <- fmt.Errorf("failed to create reshare party: %w", err) + return + } + + // Start listening for messages + s.ListenToIncomingMessageAsync(s.ProcessInboundMessage) + go s.ProcessOutboundMessage() + + // Wait for protocol to complete + <-s.finishCh + + // Process the result + if s.party.Done() { + result, err := s.party.Result() + if err != nil { + s.errCh <- fmt.Errorf("reshare protocol failed: %w", err) + return + } + + // Handle the result based on peer type + if newConfig, ok := result.(protocol.KeyGenConfig); ok { + // Save the new configuration + if err := s.saveConfig(newConfig); err != nil { + s.errCh <- fmt.Errorf("failed to save reshare result: %w", err) + return + } + + // Extract public key for result + pubKey := newConfig.GetPublicKey() + if pubKey != nil { + pubKeyBytes := append(pubKey.X.Bytes(), pubKey.Y.Bytes()...) + s.pubKeyResult = pubKeyBytes + } + + s.logger.Info(). + Str("sessionID", s.sessionID). + Bool("isNewPeer", s.isNewPeer). + Msg("CGGMP21 reshare completed successfully") + } else { + s.errCh <- fmt.Errorf("unexpected result type from reshare: %T", result) + } + } +} + +// ProcessInboundMessage handles incoming protocol messages +func (s *cggmp21ReshareSession) ProcessInboundMessage(msgBytes []byte) { + // Implementation similar to keygen session + // Convert message and send to protocol party +} + +// ProcessOutboundMessage handles outgoing protocol messages +func (s *cggmp21ReshareSession) ProcessOutboundMessage() { + // Implementation similar to keygen session +} + +// GetPubKeyResult returns the public key after successful resharing +func (s *cggmp21ReshareSession) GetPubKeyResult() []byte { + return s.pubKeyResult +} + +// IsNewPeer returns true if this node is joining as a new peer +func (s *cggmp21ReshareSession) IsNewPeer() bool { + return s.isNewPeer +} + +// ErrChan returns the error channel +func (s *cggmp21ReshareSession) ErrChan() <-chan error { + return s.errCh +} + +// Stop stops the session +func (s *cggmp21ReshareSession) Stop() { + // Protocol doesn't have Close method + close(s.outCh) + close(s.errCh) +} + +// WaitForFinish waits for the session to complete +func (s *cggmp21ReshareSession) WaitForFinish() string { + return <-s.externalFinishChan +} + +// loadConfig loads the existing key configuration +func (s *cggmp21ReshareSession) loadConfig(walletID string) (protocol.KeyGenConfig, error) { + // Get key info + keyInfo, err := s.keyinfoStore.Get(walletID) + if err != nil { + return nil, err + } + + // Load the key share data + keyShareData, err := s.kvstore.Get(walletID) + if err != nil { + return nil, err + } + + // Create a config adapter that implements protocol.KeyGenConfig + return &keyGenConfigAdapter{ + keyInfo: keyInfo, + keyShareData: keyShareData, + walletID: walletID, + }, nil +} + +// saveConfig saves the new key configuration after resharing +func (s *cggmp21ReshareSession) saveConfig(config protocol.KeyGenConfig) error { + // Serialize the config + configData, err := config.Serialize() + if err != nil { + return fmt.Errorf("failed to serialize config: %w", err) + } + + // Save to kvstore + if err := s.kvstore.Put(s.walletID, configData); err != nil { + return fmt.Errorf("failed to save share data: %w", err) + } + + // Update key info + keyInfo := &keyinfo.KeyInfo{ + ParticipantPeerIDs: s.newNodeIDs, + Threshold: s.newThreshold, + Version: 1, + } + + if err := s.keyinfoStore.Save(s.walletID, keyInfo); err != nil { + return fmt.Errorf("failed to save key info: %w", err) + } + + return nil +} + +// keyGenConfigAdapter adapts stored key data to protocol.KeyGenConfig interface +type keyGenConfigAdapter struct { + keyInfo *keyinfo.KeyInfo + keyShareData []byte + walletID string +} + +func (a *keyGenConfigAdapter) GetPartyID() string { + // Extract from the stored data - this is implementation specific + var data map[string]interface{} + if err := json.Unmarshal(a.keyShareData, &data); err != nil { + return "" + } + if id, ok := data["ID"].(string); ok { + return id + } + return "" +} + +func (a *keyGenConfigAdapter) GetThreshold() int { + return a.keyInfo.Threshold +} + +func (a *keyGenConfigAdapter) GetPublicKey() *ecdsa.PublicKey { + // Extract from stored data + var data map[string]interface{} + if err := json.Unmarshal(a.keyShareData, &data); err != nil { + return nil + } + + // This is a simplified version - actual implementation would need proper parsing + return nil +} + +func (a *keyGenConfigAdapter) GetShare() *big.Int { + // Extract from the stored data + var data map[string]interface{} + if err := json.Unmarshal(a.keyShareData, &data); err != nil { + return nil + } + + // This is a simplified version - actual implementation would need proper parsing + return nil +} + +func (a *keyGenConfigAdapter) GetSharePublicKey() *ecdsa.PublicKey { + // This would need to be extracted from the stored data + // For now, return nil as it's not critical for refresh + return nil +} + +func (a *keyGenConfigAdapter) GetPartyIDs() []string { + return a.keyInfo.ParticipantPeerIDs +} + +func (a *keyGenConfigAdapter) Serialize() ([]byte, error) { + return a.keyShareData, nil +} diff --git a/pkg/mpc/taurus/eddsa_resharing_session.go b/pkg/mpc/taurus/eddsa_resharing_session.go new file mode 100644 index 0000000..837d8fe --- /dev/null +++ b/pkg/mpc/taurus/eddsa_resharing_session.go @@ -0,0 +1,345 @@ +package taurus + +import ( + "crypto/ecdsa" + "crypto/ed25519" + "encoding/json" + "fmt" + "math/big" + "sync" + + "github.com/fystack/mpcium/pkg/keyinfo" + "github.com/fystack/mpcium/pkg/kvstore" + "github.com/fystack/mpcium/pkg/messaging" + "github.com/fystack/mpcium/pkg/protocol" + "github.com/fystack/mpcium/pkg/protocol/frost" + "github.com/fystack/mpcium/pkg/utils" + "github.com/rs/zerolog" + "github.com/taurusgroup/multi-party-sig/pkg/party" +) + +// eddsaReshareSession implements ReshareSession for EdDSA using FROST +type eddsaReshareSession struct { + session + isNewPeer bool + pubKeyResult []byte + kvstore kvstore.KVStore + keyinfoStore keyinfo.Store + resultQueue messaging.MessageQueue + protocol protocol.Protocol + party protocol.Party + config protocol.KeyGenConfig + newThreshold int + newNodeIDs []string +} + +// newEdDSAReshareSession creates a new EdDSA reshare session +func newEdDSAReshareSession( + walletID string, + threshold int, + newThreshold int, + newNodeIDs []string, + isNewPeer bool, + pubSub messaging.PubSub, + kvstore kvstore.KVStore, + keyinfoStore keyinfo.Store, + resultQueue messaging.MessageQueue, + selfNodeID string, +) (*eddsaReshareSession, error) { + // Generate session ID for resharing + sessionID := fmt.Sprintf("reshare-%s", walletID) + + // For resharing, we need to determine the party IDs + var partyIDs []party.ID + + if !isNewPeer { + // For old peers, get the existing key info to find current parties + keyInfo, err := keyinfoStore.Get(walletID) + if err != nil { + return nil, fmt.Errorf("failed to get key info for resharing: %w", err) + } + + // Old peers use their existing party IDs + for _, id := range keyInfo.ParticipantPeerIDs { + partyIDs = append(partyIDs, party.ID(id)) + } + } else { + // New peers use the new node IDs + for _, id := range newNodeIDs { + partyIDs = append(partyIDs, party.ID(id)) + } + } + + // Create FROST protocol + protocol := frost.NewFROSTProtocol() + + s := &eddsaReshareSession{ + session: session{ + walletID: walletID, + sessionID: sessionID, + pubSub: pubSub, + selfPartyID: party.ID(selfNodeID), + partyIDs: partyIDs, + subscriberList: []messaging.Subscription{}, + rounds: 3, // FROST has fewer rounds + outCh: make(chan msg, 100), + errCh: make(chan error, 10), + finishCh: make(chan bool, 1), + externalFinishChan: make(chan string, 1), + threshold: threshold, + kvstore: kvstore, + keyinfoStore: keyinfoStore, + resultQueue: resultQueue, + logger: zerolog.New(utils.ZerologConsoleWriter()).With().Timestamp().Logger(), + processing: make(map[string]bool), + processingLock: sync.Mutex{}, + topicComposer: &TopicComposer{ + ComposeBroadcastTopic: func() string { + return fmt.Sprintf("reshare:broadcast:frost:%s", walletID) + }, + ComposeDirectTopic: func(nodeID string) string { + return fmt.Sprintf("reshare:direct:frost:%s:%s", nodeID, walletID) + }, + }, + identityStore: nil, // Not needed for resharing + }, + isNewPeer: isNewPeer, + kvstore: kvstore, + keyinfoStore: keyinfoStore, + resultQueue: resultQueue, + protocol: protocol, + newThreshold: newThreshold, + newNodeIDs: newNodeIDs, + } + + // Load existing config for old peers + if !isNewPeer { + config, err := s.loadConfig(walletID) + if err != nil { + return nil, fmt.Errorf("failed to load existing config: %w", err) + } + s.config = config + } + + return s, nil +} + +// Init initializes the reshare session +func (s *eddsaReshareSession) Init() { + s.logger.Info(). + Str("sessionID", s.sessionID). + Bool("isNewPeer", s.isNewPeer). + Int("threshold", s.threshold). + Int("newThreshold", s.newThreshold). + Msg("Initializing EdDSA/FROST reshare session") +} + +// Reshare starts the resharing protocol +func (s *eddsaReshareSession) Reshare(done func()) { + defer done() + + s.logger.Info(). + Str("sessionID", s.sessionID). + Bool("isNewPeer", s.isNewPeer). + Int("threshold", s.threshold). + Msg("Starting EdDSA/FROST reshare session") + + // Create the protocol party + var err error + if s.isNewPeer { + // New peers participate in key generation with the new committee + s.party, err = s.protocol.KeyGen( + string(s.selfPartyID), + convertFromPartyIDs(s.partyIDs), + s.newThreshold, + ) + } else { + // Old peers run the refresh protocol + s.party, err = s.protocol.Refresh(s.config) + } + + if err != nil { + s.errCh <- fmt.Errorf("failed to create reshare party: %w", err) + return + } + + // Start listening for messages + s.ListenToIncomingMessageAsync(s.ProcessInboundMessage) + go s.ProcessOutboundMessage() + + // Wait for protocol to complete + <-s.finishCh + + // Process the result + if s.party.Done() { + result, err := s.party.Result() + if err != nil { + s.errCh <- fmt.Errorf("reshare protocol failed: %w", err) + return + } + + // Handle the result + if newConfig, ok := result.(protocol.KeyGenConfig); ok { + // Save the new configuration + if err := s.saveConfig(newConfig); err != nil { + s.errCh <- fmt.Errorf("failed to save reshare result: %w", err) + return + } + + // For EdDSA, we would extract the Ed25519 public key + // This is a placeholder - actual implementation would depend on the protocol + s.pubKeyResult = []byte{} // Placeholder + + s.logger.Info(). + Str("sessionID", s.sessionID). + Bool("isNewPeer", s.isNewPeer). + Msg("EdDSA/FROST reshare completed successfully") + } else { + s.errCh <- fmt.Errorf("unexpected result type from reshare: %T", result) + } + } +} + +// ProcessInboundMessage handles incoming protocol messages +func (s *eddsaReshareSession) ProcessInboundMessage(msgBytes []byte) { + // Implementation similar to keygen session +} + +// ProcessOutboundMessage handles outgoing protocol messages +func (s *eddsaReshareSession) ProcessOutboundMessage() { + // Implementation similar to keygen session +} + +// GetPubKeyResult returns the public key after successful resharing +func (s *eddsaReshareSession) GetPubKeyResult() []byte { + return s.pubKeyResult +} + +// IsNewPeer returns true if this node is joining as a new peer +func (s *eddsaReshareSession) IsNewPeer() bool { + return s.isNewPeer +} + +// ErrChan returns the error channel +func (s *eddsaReshareSession) ErrChan() <-chan error { + return s.errCh +} + +// Stop stops the session +func (s *eddsaReshareSession) Stop() { + // Protocol doesn't have Close method + close(s.outCh) + close(s.errCh) +} + +// WaitForFinish waits for the session to complete +func (s *eddsaReshareSession) WaitForFinish() string { + return <-s.externalFinishChan +} + +// loadConfig loads the existing key configuration +func (s *eddsaReshareSession) loadConfig(walletID string) (protocol.KeyGenConfig, error) { + // Get key info + keyInfo, err := s.keyinfoStore.Get(walletID) + if err != nil { + return nil, err + } + + // Load the key share data + keyShareData, err := s.kvstore.Get(walletID) + if err != nil { + return nil, err + } + + // Create a config adapter for EdDSA + return &eddsaKeyGenConfigAdapter{ + keyInfo: keyInfo, + keyShareData: keyShareData, + walletID: walletID, + }, nil +} + +// saveConfig saves the new key configuration after resharing +func (s *eddsaReshareSession) saveConfig(config protocol.KeyGenConfig) error { + // Serialize the config + configData, err := config.Serialize() + if err != nil { + return fmt.Errorf("failed to serialize config: %w", err) + } + + // Save to kvstore + if err := s.kvstore.Put(s.walletID, configData); err != nil { + return fmt.Errorf("failed to save share data: %w", err) + } + + // Update key info + keyInfo := &keyinfo.KeyInfo{ + ParticipantPeerIDs: s.newNodeIDs, + Threshold: s.newThreshold, + Version: 1, + } + + if err := s.keyinfoStore.Save(s.walletID, keyInfo); err != nil { + return fmt.Errorf("failed to save key info: %w", err) + } + + return nil +} + +// eddsaKeyGenConfigAdapter adapts stored key data to protocol.KeyGenConfig interface for EdDSA +type eddsaKeyGenConfigAdapter struct { + keyInfo *keyinfo.KeyInfo + keyShareData []byte + walletID string +} + +func (a *eddsaKeyGenConfigAdapter) GetPartyID() string { + // Extract from the stored data + var data map[string]interface{} + if err := json.Unmarshal(a.keyShareData, &data); err != nil { + return "" + } + if id, ok := data["ID"].(string); ok { + return id + } + return "" +} + +func (a *eddsaKeyGenConfigAdapter) GetThreshold() int { + return a.keyInfo.Threshold +} + +func (a *eddsaKeyGenConfigAdapter) GetPublicKey() *ecdsa.PublicKey { + // EdDSA doesn't use ECDSA public keys + return nil +} + +// GetPublicKeyEd25519 returns the Ed25519 public key +func (a *eddsaKeyGenConfigAdapter) GetPublicKeyEd25519() ed25519.PublicKey { + // Extract from stored data + var data map[string]interface{} + if err := json.Unmarshal(a.keyShareData, &data); err != nil { + return nil + } + + // This is a simplified version - actual implementation would need proper parsing + return nil +} + +func (a *eddsaKeyGenConfigAdapter) GetShare() *big.Int { + // EdDSA shares are handled differently + return nil +} + +func (a *eddsaKeyGenConfigAdapter) GetSharePublicKey() *ecdsa.PublicKey { + // EdDSA doesn't use ECDSA public keys + return nil +} + +func (a *eddsaKeyGenConfigAdapter) GetPartyIDs() []string { + return a.keyInfo.ParticipantPeerIDs +} + +func (a *eddsaKeyGenConfigAdapter) Serialize() ([]byte, error) { + return a.keyShareData, nil +} diff --git a/pkg/mpc/taurus/keygen_session.go b/pkg/mpc/taurus/keygen_session.go new file mode 100644 index 0000000..c87e1be --- /dev/null +++ b/pkg/mpc/taurus/keygen_session.go @@ -0,0 +1,328 @@ +package taurus + +import ( + "encoding/json" + "fmt" + "sync" + + "github.com/fystack/mpcium/pkg/identity" + "github.com/fystack/mpcium/pkg/keyinfo" + "github.com/fystack/mpcium/pkg/kvstore" + "github.com/fystack/mpcium/pkg/messaging" + "github.com/fystack/mpcium/pkg/types" + "github.com/fystack/mpcium/pkg/utils" + "github.com/rs/zerolog" + "github.com/taurusgroup/multi-party-sig/pkg/math/curve" + "github.com/taurusgroup/multi-party-sig/pkg/party" + "github.com/taurusgroup/multi-party-sig/pkg/pool" + "github.com/taurusgroup/multi-party-sig/pkg/protocol" + "github.com/taurusgroup/multi-party-sig/protocols/cmp" + "github.com/taurusgroup/multi-party-sig/protocols/cmp/config" +) + +type KeyGenSession interface { + Session +} + +type cggmp21KeygenSession struct { + session + handler *protocol.MultiHandler + pool *pool.Pool + config *config.Config + messagesCh chan *protocol.Message + resultMutex sync.Mutex + done bool + resultErr error +} + +func NewCGGMP21KeygenSession( + walletID string, + pubSub messaging.PubSub, + selfPartyID party.ID, + partyIDs []party.ID, + threshold int, + kvstore kvstore.KVStore, + keyinfoStore keyinfo.Store, + resultQueue messaging.MessageQueue, + identityStore identity.Store, +) *cggmp21KeygenSession { + // Create thread pool + threadPool := pool.NewPool(0) // Use max threads + + return &cggmp21KeygenSession{ + session: session{ + walletID: walletID, + pubSub: pubSub, + selfPartyID: selfPartyID, + partyIDs: partyIDs, + subscriberList: []messaging.Subscription{}, + rounds: 5, // CGGMP21 keygen has 5 rounds + outCh: make(chan msg, 100), + errCh: make(chan error, 10), + finishCh: make(chan bool, 1), + externalFinishChan: make(chan string, 1), + threshold: threshold, + kvstore: kvstore, + keyinfoStore: keyinfoStore, + resultQueue: resultQueue, + logger: zerolog.New(utils.ZerologConsoleWriter()).With().Timestamp().Logger(), + processing: make(map[string]bool), + processingLock: sync.Mutex{}, + topicComposer: &TopicComposer{ + ComposeBroadcastTopic: func() string { + return fmt.Sprintf("keygen:broadcast:cggmp21:%s", walletID) + }, + ComposeDirectTopic: func(nodeID string) string { + return fmt.Sprintf("keygen:direct:cggmp21:%s:%s", nodeID, walletID) + }, + }, + identityStore: identityStore, + }, + pool: threadPool, + messagesCh: make(chan *protocol.Message, 100), + done: false, + } +} + +func (s *cggmp21KeygenSession) Init() { + s.logger.Info(). + Int("threshold", s.threshold). + Interface("partyIDs", s.partyIDs). + Msg("Initializing CGGMP21 keygen session") + + // Create CGGMP21 keygen protocol + startFunc := cmp.Keygen(curve.Secp256k1{}, s.selfPartyID, s.partyIDs, s.threshold, s.pool) + + // Create handler + handler, err := protocol.NewMultiHandler(startFunc, nil) + if err != nil { + s.logger.Fatal().Err(err).Msg("Failed to create keygen handler") + return + } + + s.handler = handler + + // Start message handling goroutine + go s.handleProtocolMessages() + + s.logger.Info(). + Str("partyID", string(s.selfPartyID)). + Interface("peerIDs", s.partyIDs). + Str("walletID", s.walletID). + Msg("[INITIALIZED] CGGMP21 keygen session initialized successfully") +} + +func (s *cggmp21KeygenSession) handleProtocolMessages() { + for { + select { + case protoMsg, ok := <-s.handler.Listen(): + if !ok { + // Protocol finished + s.resultMutex.Lock() + s.done = true + result, err := s.handler.Result() + if err != nil { + s.resultErr = err + s.errCh <- err + } else { + s.config = result.(*config.Config) + } + s.resultMutex.Unlock() + s.finishCh <- true + return + } + + // Convert protocol message to our message format + var toPartyIDs []party.ID + if !protoMsg.Broadcast && protoMsg.To != "" { + toPartyIDs = []party.ID{protoMsg.To} + } + outMsg := msg{ + FromPartyID: protoMsg.From, + ToPartyIDs: toPartyIDs, + IsBroadcast: protoMsg.Broadcast, + Data: protoMsg.Data, + } + + s.outCh <- outMsg + + case protoMsg := <-s.messagesCh: + // Handle incoming message + if !s.handler.CanAccept(protoMsg) { + s.logger.Warn().Msgf("Handler cannot accept message from %s", protoMsg.From) + continue + } + + s.handler.Accept(protoMsg) + } + } +} + +func (s *cggmp21KeygenSession) ProcessInboundMessage(msgBytes []byte) { + s.processingLock.Lock() + defer s.processingLock.Unlock() + + inboundMessage := &types.TaurusMessage{} + if err := json.Unmarshal(msgBytes, inboundMessage); err != nil { + s.logger.Error().Err(err).Msg("ProcessInboundMessage unmarshal error") + return + } + + msgHashStr := fmt.Sprintf("%x", utils.GetMessageHash(msgBytes)) + if s.processing[msgHashStr] { + return + } + s.processing[msgHashStr] = true + + // Convert to protocol message + protoMsg := &protocol.Message{ + From: party.ID(inboundMessage.SenderID), + To: party.ID(""), // Single recipient for protocol messages + Data: inboundMessage.Body, + Broadcast: inboundMessage.IsBroadcast, + } + + // Send to handler + s.messagesCh <- protoMsg +} + +func (s *cggmp21KeygenSession) ProcessOutboundMessage() { + s.logger.Info().Msgf("ProcessOutboundMessage started: %s", s.walletID) + for { + select { + case m := <-s.outCh: + // Convert party IDs back to strings + recipientIDs := make([]string, len(m.ToPartyIDs)) + for i, pid := range m.ToPartyIDs { + recipientIDs[i] = string(pid) + } + + msgWireBytes := &types.TaurusMessage{ + SessionID: s.walletID, + SenderID: string(m.FromPartyID), + RecipientIDs: recipientIDs, + Body: m.Data, + IsBroadcast: m.IsBroadcast, + } + + s.sendMsg(msgWireBytes) + + case err := <-s.errCh: + s.logger.Error().Err(err).Msg("Received error during ProcessOutboundMessage") + + case <-s.finishCh: + s.logger.Info().Msg("Received finish message during ProcessOutboundMessage") + s.publishResult() + return + } + } +} + +func (s *cggmp21KeygenSession) publishResult() { + s.resultMutex.Lock() + defer s.resultMutex.Unlock() + + if s.resultErr != nil { + // failureEvent := event.CreateKeygenFailure( + // s.walletID, + // map[string]any{ + // "error": s.resultErr.Error(), + // }, + // ) + // evtData, _ := json.Marshal(failureEvent) + // if err := s.resultQueue.Enqueue(fmt.Sprintf("mpc.keygen_result.%s", s.walletID), evtData, nil); err != nil { + // s.logger.Error().Err(err).Msg("failed to publish keygen failure event") + // } + return + } + + if s.config == nil { + s.logger.Error().Msg("No config available after keygen completion") + return + } + + // Save key share + shareBytes, err := json.Marshal(s.config) + if err != nil { + s.logger.Error().Err(err).Msg("Failed to marshal key share") + return + } + + if err := s.kvstore.Put(s.walletID, shareBytes); err != nil { + s.logger.Error().Err(err).Msgf("Failed to save key share for wallet %s", s.walletID) + return + } + + // Convert public key to hex + // Use the X coordinate as a simple representation + var pubKeyHex string + if s.config != nil && s.config.PublicPoint() != nil { + if xScalar := s.config.PublicPoint().XScalar(); xScalar != nil { + xBytes, _ := xScalar.MarshalBinary() + pubKeyHex = fmt.Sprintf("%x", xBytes) + } + } + + // Save key info + keyInfo := &keyinfo.KeyInfo{ + ParticipantPeerIDs: convertFromPartyIDs(s.partyIDs), + Threshold: s.threshold, + Version: 1, + } + + if err := s.keyinfoStore.Save(s.walletID, keyInfo); err != nil { + s.logger.Error().Err(err).Msgf("Failed to save key info for wallet %s", s.walletID) + return + } + + // Publish success event + // successEvent := event.CreateKeygenSuccess( + // s.walletID, + // pubKeyHex, + // map[string]any{ + // "threshold": s.threshold, + // "parties": len(s.partyIDs), + // "protocol": "CGGMP21", + // }, + // ) + + // evtData, _ := json.Marshal(successEvent) + // if err := s.resultQueue.Enqueue(fmt.Sprintf("mpc.keygen_result.%s", s.walletID), evtData, nil); err != nil { + // s.logger.Error().Err(err).Msg("failed to publish keygen success event") + // } + + s.logger.Info(). + Str("walletID", s.walletID). + Str("publicKey", pubKeyHex). + Msg("CGGMP21 keygen completed successfully") +} + +func (s *cggmp21KeygenSession) Stop() { + if s.pool != nil { + s.pool.TearDown() + } + close(s.outCh) + close(s.errCh) + close(s.messagesCh) +} + +func (s *cggmp21KeygenSession) WaitForFinish() string { + return <-s.externalFinishChan +} + +// Helper functions +func convertToPartyIDs(ids []string) []party.ID { + result := make([]party.ID, len(ids)) + for i, id := range ids { + result[i] = party.ID(id) + } + return result +} + +func convertFromPartyIDs(ids []party.ID) []string { + result := make([]string, len(ids)) + for i, id := range ids { + result[i] = string(id) + } + return result +} diff --git a/pkg/mpc/taurus/node.go b/pkg/mpc/taurus/node.go new file mode 100644 index 0000000..c64fe15 --- /dev/null +++ b/pkg/mpc/taurus/node.go @@ -0,0 +1,242 @@ +package taurus + +import ( + "fmt" + "time" + + "github.com/fystack/mpcium/pkg/identity" + "github.com/fystack/mpcium/pkg/keyinfo" + "github.com/fystack/mpcium/pkg/kvstore" + "github.com/fystack/mpcium/pkg/logger" + "github.com/fystack/mpcium/pkg/messaging" + "github.com/taurusgroup/multi-party-sig/pkg/party" +) + +const ( + PurposeKeygen string = "keygen" + PurposeSign string = "sign" + PurposeReshare string = "reshare" + + DefaultVersion int = 1 +) + +type ID string + +type Node struct { + nodeID string + peerIDs []string + pubSub messaging.PubSub + kvstore kvstore.KVStore + keyinfoStore keyinfo.Store + identityStore identity.Store + peerRegistry PeerRegistry +} + +func ComposeReadyKey(nodeID string) string { + return fmt.Sprintf("ready/%s", nodeID) +} + +func NewNode( + nodeID string, + peerIDs []string, + pubSub messaging.PubSub, + kvstore kvstore.KVStore, + keyinfoStore keyinfo.Store, + peerRegistry PeerRegistry, + identityStore identity.Store, +) *Node { + start := time.Now() + elapsed := time.Since(start) + logger.Info("Starting new CGGMP21 node", "nodeID", nodeID, "elapsed", elapsed.Milliseconds()) + + node := &Node{ + nodeID: nodeID, + peerIDs: peerIDs, + pubSub: pubSub, + kvstore: kvstore, + keyinfoStore: keyinfoStore, + peerRegistry: peerRegistry, + identityStore: identityStore, + } + + go peerRegistry.WatchPeersReady() + return node +} + +func (p *Node) ID() string { + return p.nodeID +} + +func (p *Node) KeyInfoStore() keyinfo.Store { + return p.keyinfoStore +} + +// func (p *Node) CreateKeyGenSession( +// walletID string, +// threshold int, +// resultQueue messaging.MessageQueue, +// ) (KeyGenSession, error) { +// if !p.peerRegistry.ArePeersReady() { +// return nil, fmt.Errorf( +// "peers are not ready yet. ready: %d, expected: %d", +// p.peerRegistry.GetReadyPeersCount(), +// len(p.peerIDs)+1, +// ) +// } + +// readyPeerIDs := p.peerRegistry.GetReadyPeersIncludeSelf() +// selfPartyID, allPartyIDs := p.generatePartyIDs(PurposeKeygen, readyPeerIDs, DefaultVersion) + +// session := newCGGMP21KeygenSession( +// walletID, +// p.pubSub, +// selfPartyID, +// allPartyIDs, +// threshold, +// p.kvstore, +// p.keyinfoStore, +// resultQueue, +// p.identityStore, +// ) + +// session.Init() +// return session, nil +// } + +func (p *Node) CreateSignSession( + sessionID string, + walletID string, + messageHash []byte, + signerPeerIDs []string, + resultQueue messaging.MessageQueue, + useBroadcast bool, +) (SignSession, error) { + // Check if we have enough signers + keyInfo, err := p.keyinfoStore.Get(walletID) + if err != nil { + return nil, fmt.Errorf("failed to get key info: %w", err) + } + + if len(signerPeerIDs) < keyInfo.Threshold+1 { + return nil, ErrNotEnoughParticipants + } + + // Check if this node is in the signer list + if !contains(signerPeerIDs, p.nodeID) { + return nil, ErrNotInParticipantList + } + + // Generate party IDs for signers + version := p.getVersion(SessionTypeCGGMP21, walletID) + selfPartyID, signerPartyIDs := p.generatePartyIDs(PurposeSign, signerPeerIDs, version) + + session, err := newCGGMP21SigningSession( + sessionID, + walletID, + messageHash, + p.pubSub, + selfPartyID, + signerPartyIDs, + p.kvstore, + p.keyinfoStore, + resultQueue, + p.identityStore, + useBroadcast, + ) + if err != nil { + return nil, err + } + + session.Init() + return session, nil +} + +func (p *Node) generatePartyIDs(purpose string, peerIDs []string, version int) (party.ID, []party.ID) { + partyIDs := make([]party.ID, len(peerIDs)) + var selfPartyID party.ID + + for i, peerID := range peerIDs { + partyID := createPartyID(peerID, purpose, version) + partyIDs[i] = partyID + if peerID == p.nodeID { + selfPartyID = partyID + } + } + + return selfPartyID, partyIDs +} + +func createPartyID(sessionID string, keyType string, version int) party.ID { + if version == 0 { + // Backward compatible version - just use sessionID + return party.ID(sessionID) + } + // Include version in party ID + return party.ID(fmt.Sprintf("%s:%s:%d", sessionID, keyType, version)) +} + +func (p *Node) getVersion(sessionType SessionType, walletID string) int { + // In production, you might want to store and retrieve version info + // For now, always use the default version + return DefaultVersion +} + +func (p *Node) CreateReshareSession( + sessionType SessionType, + walletID string, + threshold int, + newThreshold int, + newNodeIDs []string, + isNewPeer bool, + resultQueue messaging.MessageQueue, +) (ReshareSession, error) { + logger.Info("Creating reshare session", + "sessionType", sessionType, + "walletID", walletID, + "threshold", threshold, + "newThreshold", newThreshold, + "newNodeIDs", newNodeIDs, + "isNewPeer", isNewPeer, + "nodeID", p.nodeID, + ) + + switch sessionType { + case SessionTypeECDSA: + return newCGGMP21ReshareSession( + walletID, + threshold, + newThreshold, + newNodeIDs, + isNewPeer, + p.pubSub, + p.kvstore, + p.keyinfoStore, + resultQueue, + p.nodeID, + ) + case SessionTypeEDDSA: + return newEdDSAReshareSession( + walletID, + threshold, + newThreshold, + newNodeIDs, + isNewPeer, + p.pubSub, + p.kvstore, + p.keyinfoStore, + resultQueue, + p.nodeID, + ) + default: + return nil, fmt.Errorf("unsupported session type for reshare: %v", sessionType) + } +} + +func contains(slice []string, item string) bool { + for _, s := range slice { + if s == item { + return true + } + } + return false +} diff --git a/pkg/mpc/taurus/node_test.go b/pkg/mpc/taurus/node_test.go new file mode 100644 index 0000000..805b3b2 --- /dev/null +++ b/pkg/mpc/taurus/node_test.go @@ -0,0 +1,119 @@ +package taurus + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPartyIDToNodeID(t *testing.T) { + partyID := createPartyID("4d8cb873-dc86-4776-b6f6-cf5c668f6468", "keygen", 1) + nodeID := PartyIDToRoutingDest(partyID) + assert.Equal(t, "4d8cb873-dc86-4776-b6f6-cf5c668f6468:keygen:1", nodeID, "NodeID should be equal") +} + +func TestCreatePartyID_Structure(t *testing.T) { + sessionID := "test-session-123" + keyType := "keygen" + version := 5 + + partyID := createPartyID(sessionID, keyType, version) + + assert.NotNil(t, partyID) + // The party ID should be in the format sessionID:keyType:version + expectedID := "test-session-123:keygen:5" + assert.Equal(t, expectedID, string(partyID)) +} + +func TestCreatePartyID_DifferentVersions(t *testing.T) { + sessionID := "test-session-456" + keyType := "keygen" + + // Test version 0 (backward compatible) + partyID0 := createPartyID(sessionID, keyType, 0) + assert.NotNil(t, partyID0) + // Version 0 should just be the sessionID + assert.Equal(t, sessionID, string(partyID0)) + + // Test version 1 (default) + partyID1 := createPartyID(sessionID, keyType, DefaultVersion) + assert.NotNil(t, partyID1) + // Version 1 should include version info + expectedID1 := "test-session-456:keygen:1" + assert.Equal(t, expectedID1, string(partyID1)) + + // Different versions should produce different party IDs + assert.NotEqual(t, partyID0, partyID1) +} + +func TestPartyIDToRoutingDest_BackwardCompatible(t *testing.T) { + sessionID := "test-session-789" + keyType := "signing" + + partyID := createPartyID(sessionID, keyType, 0) + nodeID := PartyIDToRoutingDest(partyID) + + // For backward compatible version, should just be the sessionID + assert.Equal(t, sessionID, nodeID) +} + +func TestPartyIDToRoutingDest_DefaultVersion(t *testing.T) { + sessionID := "test-session-999" + keyType := "signing" + + partyID := createPartyID(sessionID, keyType, DefaultVersion) + nodeID := PartyIDToRoutingDest(partyID) + + // For default version, should be the full party ID string + expected := "test-session-999:signing:1" + assert.Equal(t, expected, nodeID) +} + +func TestCreatePartyID_EmptyValues(t *testing.T) { + // Test with empty session ID + partyID := createPartyID("", "keygen", 0) + assert.NotNil(t, partyID) + // Version 0 should just return empty string + assert.Equal(t, "", string(partyID)) + + // Test with empty key type + partyID = createPartyID("session", "", 1) + assert.NotNil(t, partyID) + // Should still create the party ID with format + expectedID := "session::1" + assert.Equal(t, expectedID, string(partyID)) +} + +func TestPartyIDToRoutingDest_Consistency(t *testing.T) { + sessionID := "consistent-session" + keyType := "keygen" + version := 3 + + // Create the same party ID multiple times + partyID1 := createPartyID(sessionID, keyType, version) + partyID2 := createPartyID(sessionID, keyType, version) + + nodeID1 := PartyIDToRoutingDest(partyID1) + nodeID2 := PartyIDToRoutingDest(partyID2) + + // Should produce consistent results based on sessionID and version + assert.Equal(t, nodeID1, nodeID2, "Same parameters should produce same routing destinations") +} + +func TestCreatePartyID_SameParameters(t *testing.T) { + sessionID := "test-session" + keyType := "keygen" + version := 1 + + // Create multiple party IDs with same parameters + partyID1 := createPartyID(sessionID, keyType, version) + partyID2 := createPartyID(sessionID, keyType, version) + + // Party IDs with same parameters should be identical in the new implementation + assert.Equal(t, partyID1, partyID2, "Party IDs with same parameters should be equal") + + // Both should have the same format + expectedID := "test-session:keygen:1" + assert.Equal(t, expectedID, string(partyID1)) + assert.Equal(t, expectedID, string(partyID2)) +} diff --git a/pkg/mpc/taurus/registry.go b/pkg/mpc/taurus/registry.go new file mode 100644 index 0000000..cfe334d --- /dev/null +++ b/pkg/mpc/taurus/registry.go @@ -0,0 +1,213 @@ +package taurus + +import ( + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/fystack/mpcium/pkg/infra" + "github.com/fystack/mpcium/pkg/logger" + "github.com/hashicorp/consul/api" + "github.com/samber/lo" +) + +const ( + ReadinessCheckPeriod = 1 * time.Second +) + +type PeerRegistry interface { + Ready() error + ArePeersReady() bool + WatchPeersReady() + // Resign is called by the node when it is going to shutdown + Resign() error + GetReadyPeersCount() int64 + GetReadyPeersIncludeSelf() []string // get ready peers include self + GetTotalPeersCount() int64 +} + +type registry struct { + nodeID string + peerNodeIDs []string + readyMap map[string]bool + readyCount int64 + mu sync.RWMutex + ready bool // ready is true when all peers are ready + + consulKV infra.ConsulKV +} + +func NewRegistry( + nodeID string, + peerNodeIDs []string, + consulKV infra.ConsulKV, +) *registry { + return ®istry{ + consulKV: consulKV, + nodeID: nodeID, + peerNodeIDs: getPeerIDsExceptSelf(nodeID, peerNodeIDs), + readyMap: make(map[string]bool), + readyCount: 1, // self + } +} + +func getPeerIDsExceptSelf(nodeID string, peerNodeIDs []string) []string { + peerIDs := make([]string, 0, len(peerNodeIDs)) + for _, peerID := range peerNodeIDs { + if peerID != nodeID { + peerIDs = append(peerIDs, peerID) + } + } + return peerIDs +} + +func (r *registry) readyKey(nodeID string) string { + return fmt.Sprintf("ready/%s", nodeID) +} + +func (r *registry) registerReadyPairs(peerIDs []string) { + for _, peerID := range peerIDs { + ready, exist := r.readyMap[peerID] + if !exist { + atomic.AddInt64(&r.readyCount, 1) + logger.Info("Register", "peerID", peerID) + } else if !ready { + atomic.AddInt64(&r.readyCount, 1) + logger.Info("Reconnecting...", "peerID", peerID) + } + + r.readyMap[peerID] = true + } + + if len(peerIDs) == len(r.peerNodeIDs) && !r.ready { + r.mu.Lock() + r.ready = true + r.mu.Unlock() + logger.Info("ALL PEERS ARE READY! Starting to accept MPC requests") + } + +} + +// Ready is called by the node when it complete generate preparams and starting to accept +// incoming requests +func (r *registry) Ready() error { + k := r.readyKey(r.nodeID) + + kv := &api.KVPair{ + Key: k, + Value: []byte("true"), + } + + _, err := r.consulKV.Put(kv, nil) + if err != nil { + return fmt.Errorf("Put ready key failed: %w", err) + } + + return nil +} + +func (r *registry) WatchPeersReady() { + ticker := time.NewTicker(ReadinessCheckPeriod) + go r.logReadyStatus() + // first tick is executed immediately + for ; true; <-ticker.C { + pairs, _, err := r.consulKV.List("ready/", nil) + if err != nil { + logger.Error("List ready keys failed", err) + } + + newReadyPeerIDs := r.getReadyPeersFromKVStore(pairs) + if len(newReadyPeerIDs) != len(r.peerNodeIDs) { + r.mu.Lock() + r.ready = false + r.mu.Unlock() + + var readyPeerIDs []string + for peerID, isReady := range r.readyMap { + if isReady { + readyPeerIDs = append(readyPeerIDs, peerID) + } + } + + disconnecteds, _ := lo.Difference(readyPeerIDs, newReadyPeerIDs) + if len(disconnecteds) > 0 { + for _, peerID := range disconnecteds { + logger.Warn("Peer disconnected!", "peerID", peerID) + r.readyMap[peerID] = false + atomic.AddInt64(&r.readyCount, -1) + } + + } + + } + r.registerReadyPairs(newReadyPeerIDs) + } + +} + +func (r *registry) logReadyStatus() { + for { + time.Sleep(5 * time.Second) + if !r.ArePeersReady() { + logger.Info("Peers are not ready yet", "ready", r.GetReadyPeersCount(), "expected", len(r.peerNodeIDs)+1) + } + } +} + +func (r *registry) GetReadyPeersCount() int64 { + return atomic.LoadInt64(&r.readyCount) +} + +func (r *registry) GetReadyPeersIncludeSelf() []string { + var peerIDs []string + for peerID, isReady := range r.readyMap { + if isReady { + peerIDs = append(peerIDs, peerID) + } + } + + peerIDs = append(peerIDs, r.nodeID) // append self + return peerIDs +} + +func (r *registry) getReadyPeersFromKVStore(kvPairs api.KVPairs) []string { + var peers []string + for _, k := range kvPairs { + var peerNodeID string + _, err := fmt.Sscanf(k.Key, "ready/%s", &peerNodeID) + if err != nil { + logger.Error("Parse ready key failed", err) + } + if peerNodeID == r.nodeID { + continue + } + + peers = append(peers, peerNodeID) + } + + return peers +} + +func (r *registry) ArePeersReady() bool { + r.mu.RLock() + defer r.mu.RUnlock() + + return r.ready +} + +func (r *registry) GetTotalPeersCount() int64 { + var self int64 = 1 + return int64(len(r.peerNodeIDs)) + self +} + +func (r *registry) Resign() error { + k := r.readyKey(r.nodeID) + + _, err := r.consulKV.Delete(k, nil) + if err != nil { + return fmt.Errorf("Delete ready key failed: %w", err) + } + + return nil +} diff --git a/pkg/mpc/taurus/reshare_session.go b/pkg/mpc/taurus/reshare_session.go new file mode 100644 index 0000000..2d3d9cf --- /dev/null +++ b/pkg/mpc/taurus/reshare_session.go @@ -0,0 +1,32 @@ +package taurus + +// ReshareSession represents a threshold signature resharing session +type ReshareSession interface { + Session + + // Reshare starts the resharing protocol + Reshare(done func()) + + // GetPubKeyResult returns the public key after successful resharing + GetPubKeyResult() []byte + + // IsNewPeer returns true if this node is joining as a new peer + IsNewPeer() bool +} + +// BaseReshareSession provides common functionality for reshare sessions +type BaseReshareSession struct { + session + isNewPeer bool + pubKeyResult []byte +} + +// IsNewPeer returns true if this node is joining as a new peer +func (s *BaseReshareSession) IsNewPeer() bool { + return s.isNewPeer +} + +// GetPubKeyResult returns the public key after successful resharing +func (s *BaseReshareSession) GetPubKeyResult() []byte { + return s.pubKeyResult +} diff --git a/pkg/mpc/taurus/session.go b/pkg/mpc/taurus/session.go new file mode 100644 index 0000000..bfa5921 --- /dev/null +++ b/pkg/mpc/taurus/session.go @@ -0,0 +1,184 @@ +package taurus + +import ( + "sync" + + "github.com/nats-io/nats.go" + "github.com/rs/zerolog" + "github.com/taurusgroup/multi-party-sig/pkg/party" + + "github.com/fystack/mpcium/pkg/common/errors" + "github.com/fystack/mpcium/pkg/encoding" + "github.com/fystack/mpcium/pkg/identity" + "github.com/fystack/mpcium/pkg/keyinfo" + "github.com/fystack/mpcium/pkg/kvstore" + "github.com/fystack/mpcium/pkg/messaging" + "github.com/fystack/mpcium/pkg/types" +) + +type SessionType string + +const ( + TypeGenerateWalletResultFmt = "mpc.mpc_keygen_result.%s" + TypeReshareWalletResultFmt = "mpc.mpc_reshare_result.%s" + + SessionTypeCGGMP21 SessionType = "session_cggmp21" + SessionTypeECDSA SessionType = "ecdsa" + SessionTypeEDDSA SessionType = "eddsa" +) + +var ( + ErrNotEnoughParticipants = errors.New("Not enough participants to sign") + ErrNotInParticipantList = errors.New("Node is not in the participant list") +) + +type TopicComposer struct { + ComposeBroadcastTopic func() string + ComposeDirectTopic func(nodeID string) string +} + +type KeyComposerFn func(id string) string + +type Session interface { + ListenToIncomingMessageAsync(f func(msgBytes []byte)) + ErrChan() <-chan error + Init() + ProcessInboundMessage(msgBytes []byte) + ProcessOutboundMessage() + WaitForFinish() string +} + +type session struct { + walletID string + sessionID string + pubSub messaging.PubSub + selfPartyID party.ID + partyIDs []party.ID + subscriberList []messaging.Subscription + rounds int + outCh chan msg + errCh chan error + finishCh chan bool + externalFinishChan chan string + threshold int + kvstore kvstore.KVStore + keyinfoStore keyinfo.Store + resultQueue messaging.MessageQueue + logger zerolog.Logger + processing map[string]bool + processingLock sync.Mutex + topicComposer *TopicComposer + identityStore identity.Store +} + +type msg struct { + FromPartyID party.ID + ToPartyIDs []party.ID + IsBroadcast bool + Data []byte +} + +func (s *session) ProcessInboundMessage(msgBytes []byte) { + // This should be implemented by specific session types + // If this method is called directly on the base session type, it means + // the concrete type doesn't properly implement ProcessInboundMessage + panic("ProcessInboundMessage must be implemented by session type") +} + +func (s *session) ListenToIncomingMessageAsync(f func(msgBytes []byte)) { + // Subscribe to broadcast messages + broadcastTopic := s.topicComposer.ComposeBroadcastTopic() + broadcastSub, err := s.pubSub.Subscribe(broadcastTopic, func(m *nats.Msg) { + s.logger.Debug(). + Str("topic", broadcastTopic). + Int("size", len(m.Data)). + Msg("Received broadcast message") + f(m.Data) + }) + + if err != nil { + s.logger.Error().Err(err).Msgf("Failed to subscribe to broadcast topic %s", broadcastTopic) + s.errCh <- err + return + } + + s.subscriberList = append(s.subscriberList, broadcastSub) + + // Subscribe to direct messages + directTopic := s.topicComposer.ComposeDirectTopic(string(s.selfPartyID)) + directSub, err := s.pubSub.Subscribe(directTopic, func(m *nats.Msg) { + s.logger.Debug(). + Str("topic", directTopic). + Int("size", len(m.Data)). + Msg("Received direct message") + f(m.Data) + }) + + if err != nil { + s.logger.Error().Err(err).Msgf("Failed to subscribe to direct topic %s", directTopic) + s.errCh <- err + return + } + + s.subscriberList = append(s.subscriberList, directSub) + + s.logger.Info(). + Str("broadcast", broadcastTopic). + Str("direct", directTopic). + Msg("Listening to incoming messages") +} + +func (s *session) sendMsg(message *types.TaurusMessage) { + data, err := encoding.StructToJsonBytes(message) + if err != nil { + s.logger.Error().Err(err).Msg("Failed to marshal message") + return + } + + if message.IsBroadcast { + topic := s.topicComposer.ComposeBroadcastTopic() + if err := s.pubSub.Publish(topic, data); err != nil { + s.logger.Error().Err(err).Msgf("Failed to publish broadcast message to %s", topic) + } else { + s.logger.Debug().Str("topic", topic).Msg("Published broadcast message") + } + } else { + // Send to specific recipients + for _, recipient := range message.RecipientIDs { + topic := s.topicComposer.ComposeDirectTopic(recipient) + if err := s.pubSub.Publish(topic, data); err != nil { + s.logger.Error().Err(err).Msgf("Failed to publish direct message to %s", topic) + } else { + s.logger.Debug(). + Str("topic", topic). + Str("recipient", recipient). + Msg("Published direct message") + } + } + } +} + +func (s *session) ErrChan() <-chan error { + return s.errCh +} + +func (s *session) unsubscribe() { + for _, sub := range s.subscriberList { + if err := sub.Unsubscribe(); err != nil { + s.logger.Error().Err(err).Msg("Failed to unsubscribe") + } + } + s.subscriberList = nil +} + +func (s *session) Stop() { + s.unsubscribe() +} + +// Helper function to get party routing destination +func PartyIDToRoutingDest(partyID party.ID) string { + // Extract node ID from party ID if it contains version info + nodeID := string(partyID) + // Simple extraction - in production you'd have more robust parsing + return nodeID +} diff --git a/pkg/mpc/taurus/signing_session.go b/pkg/mpc/taurus/signing_session.go new file mode 100644 index 0000000..8b12ec3 --- /dev/null +++ b/pkg/mpc/taurus/signing_session.go @@ -0,0 +1,332 @@ +package taurus + +import ( + "encoding/hex" + "encoding/json" + "fmt" + "sync" + + "github.com/fystack/mpcium/pkg/encoding" + "github.com/fystack/mpcium/pkg/identity" + "github.com/fystack/mpcium/pkg/keyinfo" + "github.com/fystack/mpcium/pkg/kvstore" + "github.com/fystack/mpcium/pkg/messaging" + "github.com/fystack/mpcium/pkg/types" + "github.com/fystack/mpcium/pkg/utils" + "github.com/rs/zerolog" + "github.com/taurusgroup/multi-party-sig/pkg/ecdsa" + "github.com/taurusgroup/multi-party-sig/pkg/party" + "github.com/taurusgroup/multi-party-sig/pkg/pool" + "github.com/taurusgroup/multi-party-sig/pkg/protocol" + "github.com/taurusgroup/multi-party-sig/protocols/cmp" + "github.com/taurusgroup/multi-party-sig/protocols/cmp/config" +) + +type SignSession interface { + Session +} + +type cggmp21SigningSession struct { + session + handler *protocol.MultiHandler + pool *pool.Pool + config *config.Config + signature *ecdsa.Signature + messagesCh chan *protocol.Message + resultMutex sync.Mutex + done bool + resultErr error + messageHash []byte + signerIDs []party.ID + useBroadcast bool +} + +func newCGGMP21SigningSession( + sessionID string, + walletID string, + messageHash []byte, + pubSub messaging.PubSub, + selfPartyID party.ID, + signerIDs []party.ID, + kvstore kvstore.KVStore, + keyinfoStore keyinfo.Store, + resultQueue messaging.MessageQueue, + identityStore identity.Store, + useBroadcast bool, +) (*cggmp21SigningSession, error) { + // Load config from kvstore + shareBytes, err := kvstore.Get(walletID) + if err != nil { + return nil, fmt.Errorf("failed to get key share: %w", err) + } + + config := &config.Config{} + if err := json.Unmarshal(shareBytes, config); err != nil { + return nil, fmt.Errorf("failed to unmarshal key share: %w", err) + } + + // Create thread pool + threadPool := pool.NewPool(0) // Use max threads + + return &cggmp21SigningSession{ + session: session{ + walletID: walletID, + sessionID: sessionID, + pubSub: pubSub, + selfPartyID: selfPartyID, + partyIDs: signerIDs, + subscriberList: []messaging.Subscription{}, + rounds: 5, // CGGMP21 signing has 5 rounds + outCh: make(chan msg, 100), + errCh: make(chan error, 10), + finishCh: make(chan bool, 1), + externalFinishChan: make(chan string, 1), + threshold: config.Threshold, + kvstore: kvstore, + keyinfoStore: keyinfoStore, + resultQueue: resultQueue, + logger: zerolog.New(utils.ZerologConsoleWriter()).With().Timestamp().Logger(), + processing: make(map[string]bool), + processingLock: sync.Mutex{}, + topicComposer: &TopicComposer{ + ComposeBroadcastTopic: func() string { + return fmt.Sprintf("sign:broadcast:cggmp21:%s", sessionID) + }, + ComposeDirectTopic: func(nodeID string) string { + return fmt.Sprintf("sign:direct:cggmp21:%s:%s", nodeID, sessionID) + }, + }, + identityStore: identityStore, + }, + pool: threadPool, + config: config, + messagesCh: make(chan *protocol.Message, 100), + messageHash: messageHash, + signerIDs: signerIDs, + useBroadcast: useBroadcast, + done: false, + }, nil +} + +func (s *cggmp21SigningSession) Init() { + s.logger.Info(). + Str("sessionID", s.sessionID). + Str("walletID", s.walletID). + Hex("messageHash", s.messageHash). + Interface("signerIDs", s.signerIDs). + Bool("useBroadcast", s.useBroadcast). + Msg("Initializing CGGMP21 signing session") + + // Create CGGMP21 signing protocol + startFunc := cmp.Sign(s.config, s.signerIDs, s.messageHash, s.pool) + + // Create handler + handler, err := protocol.NewMultiHandler(startFunc, nil) + if err != nil { + s.logger.Fatal().Err(err).Msg("Failed to create signing handler") + return + } + + s.handler = handler + + // Start message handling goroutine + go s.handleProtocolMessages() + + s.logger.Info(). + Str("sessionID", s.sessionID). + Str("partyID", string(s.selfPartyID)). + Interface("signerIDs", s.signerIDs). + Msg("[INITIALIZED] CGGMP21 signing session initialized successfully") +} + +func (s *cggmp21SigningSession) handleProtocolMessages() { + for { + select { + case protoMsg, ok := <-s.handler.Listen(): + if !ok { + // Protocol finished + s.resultMutex.Lock() + s.done = true + result, err := s.handler.Result() + if err != nil { + s.resultErr = err + s.errCh <- err + } else { + s.signature = result.(*ecdsa.Signature) + } + s.resultMutex.Unlock() + s.finishCh <- true + return + } + + // Convert protocol message to our message format + var toPartyIDs []party.ID + if !protoMsg.Broadcast && protoMsg.To != "" { + toPartyIDs = []party.ID{protoMsg.To} + } + outMsg := msg{ + FromPartyID: protoMsg.From, + ToPartyIDs: toPartyIDs, + IsBroadcast: protoMsg.Broadcast, + Data: protoMsg.Data, + } + + s.outCh <- outMsg + + case protoMsg := <-s.messagesCh: + // Handle incoming message + if !s.handler.CanAccept(protoMsg) { + s.logger.Warn().Msgf("Handler cannot accept message from %s", protoMsg.From) + continue + } + + s.handler.Accept(protoMsg) + } + } +} + +func (s *cggmp21SigningSession) ProcessInboundMessage(msgBytes []byte) { + s.processingLock.Lock() + defer s.processingLock.Unlock() + + inboundMessage := &types.TaurusMessage{} + if err := encoding.JsonBytesToStruct(msgBytes, inboundMessage); err != nil { + s.logger.Error().Err(err).Msg("ProcessInboundMessage unmarshal error") + return + } + + msgHashStr := fmt.Sprintf("%x", utils.GetMessageHash(msgBytes)) + if s.processing[msgHashStr] { + return + } + s.processing[msgHashStr] = true + + // Convert to protocol message + protoMsg := &protocol.Message{ + From: party.ID(inboundMessage.SenderID), + To: party.ID(""), // Single recipient for protocol messages + Data: inboundMessage.Body, + Broadcast: inboundMessage.IsBroadcast, + } + + // Send to handler + s.messagesCh <- protoMsg +} + +func (s *cggmp21SigningSession) ProcessOutboundMessage() { + s.logger.Info().Msgf("ProcessOutboundMessage started: %s", s.sessionID) + for { + select { + case m := <-s.outCh: + // Convert party IDs back to strings + recipientIDs := make([]string, len(m.ToPartyIDs)) + for i, pid := range m.ToPartyIDs { + recipientIDs[i] = string(pid) + } + + msgWireBytes := &types.TaurusMessage{ + SessionID: s.sessionID, + SenderID: string(m.FromPartyID), + RecipientIDs: recipientIDs, + Body: m.Data, + IsBroadcast: m.IsBroadcast, + } + + s.sendMsg(msgWireBytes) + + case err := <-s.errCh: + s.logger.Error().Err(err).Msg("Received error during ProcessOutboundMessage") + + case <-s.finishCh: + s.logger.Info().Msg("Received finish message during ProcessOutboundMessage") + s.publishResult() + return + } + } +} + +func (s *cggmp21SigningSession) publishResult() { + s.resultMutex.Lock() + defer s.resultMutex.Unlock() + + if s.resultErr != nil { + // failureEvent := event.CreateSignFailure( + // s.sessionID, + // s.walletID, + // map[string]any{ + // "error": s.resultErr.Error(), + // }, + // ) + // evtData, _ := encoding.StructToJsonBytes(failureEvent) + // if err := s.resultQueue.Enqueue(fmt.Sprintf("%s.%s", event.SigningResultTopic, s.walletID), evtData, nil); err != nil { + // s.logger.Error().Err(err).Msg("failed to publish sign failure event") + // } + return + } + + if s.signature == nil { + s.logger.Error().Msg("No signature available after signing completion") + return + } + + // Verify signature + if !s.signature.Verify(s.config.PublicPoint(), s.messageHash) { + s.logger.Error().Msg("Failed to verify signature") + // failureEvent := event.CreateSignFailure( + // s.sessionID, + // s.walletID, + // map[string]any{ + // "error": "signature verification failed", + // }, + // ) + // evtData, _ := encoding.StructToJsonBytes(failureEvent) + // if err := s.resultQueue.Enqueue(fmt.Sprintf("%s.%s", event.SigningResultTopic, s.walletID), evtData, nil); err != nil { + // s.logger.Error().Err(err).Msg("failed to publish sign failure event") + // } + return + } + + // Convert signature to hex + sigRBytes, _ := s.signature.R.MarshalBinary() + sigSBytes, _ := s.signature.S.MarshalBinary() + sigR := hex.EncodeToString(sigRBytes) + sigS := hex.EncodeToString(sigSBytes) + + // Publish success event + // successEvent := event.CreateSignSuccess( + // s.sessionID, + // s.walletID, + // sigR, + // sigS, + // map[string]any{ + // "messageHash": hex.EncodeToString(s.messageHash), + // "signers": len(s.signerIDs), + // "protocol": "CGGMP21", + // }, + // ) + + // evtData, _ := encoding.StructToJsonBytes(successEvent) + // if err := s.resultQueue.Enqueue(fmt.Sprintf("%s.%s", event.SigningResultTopic, s.walletID), evtData, nil); err != nil { + // s.logger.Error().Err(err).Msg("failed to publish sign success event") + // } + + s.logger.Info(). + Str("sessionID", s.sessionID). + Str("walletID", s.walletID). + Str("sigR", sigR). + Str("sigS", sigS). + Msg("CGGMP21 signing completed successfully") +} + +func (s *cggmp21SigningSession) Stop() { + if s.pool != nil { + s.pool.TearDown() + } + close(s.outCh) + close(s.errCh) + close(s.messagesCh) +} + +func (s *cggmp21SigningSession) WaitForFinish() string { + return <-s.externalFinishChan +} diff --git a/pkg/protocol/cggmp21/adapter.go b/pkg/protocol/cggmp21/adapter.go new file mode 100644 index 0000000..7f4f710 --- /dev/null +++ b/pkg/protocol/cggmp21/adapter.go @@ -0,0 +1,452 @@ +package cggmp21 + +import ( + "crypto/ecdsa" + "encoding/json" + "errors" + "fmt" + "math/big" + "sync" + + "github.com/fystack/mpcium/pkg/protocol" + mpsEcdsa "github.com/taurusgroup/multi-party-sig/pkg/ecdsa" + "github.com/taurusgroup/multi-party-sig/pkg/math/curve" + "github.com/taurusgroup/multi-party-sig/pkg/party" + "github.com/taurusgroup/multi-party-sig/pkg/pool" + mpsProtocol "github.com/taurusgroup/multi-party-sig/pkg/protocol" + "github.com/taurusgroup/multi-party-sig/protocols/cmp" + "github.com/taurusgroup/multi-party-sig/protocols/cmp/config" +) + +// CGGMP21Protocol implements the Protocol interface using CGGMP21 +type CGGMP21Protocol struct { + pool *pool.Pool +} + +// NewCGGMP21Protocol creates a new CGGMP21 protocol adapter +func NewCGGMP21Protocol() *CGGMP21Protocol { + return &CGGMP21Protocol{ + pool: pool.NewPool(0), // Use max threads + } +} + +// Close cleans up resources +func (p *CGGMP21Protocol) Close() { + if p.pool != nil { + p.pool.TearDown() + } +} + +// Name returns the protocol name +func (p *CGGMP21Protocol) Name() string { + return "CGGMP21" +} + +// KeyGen starts a distributed key generation +func (p *CGGMP21Protocol) KeyGen(selfID string, partyIDs []string, threshold int) (protocol.Party, error) { + // Convert string IDs to party.ID + ids := make([]party.ID, len(partyIDs)) + for i, id := range partyIDs { + ids[i] = party.ID(id) + } + + // Create the keygen protocol + startFunc := cmp.Keygen(curve.Secp256k1{}, party.ID(selfID), ids, threshold, p.pool) + + // Create handler + handler, err := mpsProtocol.NewMultiHandler(startFunc, nil) + if err != nil { + return nil, fmt.Errorf("failed to create keygen handler: %w", err) + } + + return &partyAdapter{ + handler: handler, + selfID: selfID, + }, nil +} + +// Refresh refreshes shares from an existing config +func (p *CGGMP21Protocol) Refresh(cfg protocol.KeyGenConfig) (protocol.Party, error) { + // Convert to CGGMP21 config + cmpConfig, err := toCMPConfig(cfg) + if err != nil { + return nil, err + } + + // Create refresh protocol + startFunc := cmp.Refresh(cmpConfig, p.pool) + + // Create handler + handler, err := mpsProtocol.NewMultiHandler(startFunc, nil) + if err != nil { + return nil, fmt.Errorf("failed to create refresh handler: %w", err) + } + + return &partyAdapter{ + handler: handler, + selfID: cfg.GetPartyID(), + }, nil +} + +// Sign starts a signing protocol +func (p *CGGMP21Protocol) Sign(cfg protocol.KeyGenConfig, signers []string, messageHash []byte) (protocol.Party, error) { + // Convert to CGGMP21 config + cmpConfig, err := toCMPConfig(cfg) + if err != nil { + return nil, err + } + + // Convert signer IDs + signerIDs := make([]party.ID, len(signers)) + for i, id := range signers { + signerIDs[i] = party.ID(id) + } + + // Create sign protocol + startFunc := cmp.Sign(cmpConfig, signerIDs, messageHash, p.pool) + + // Create handler + handler, err := mpsProtocol.NewMultiHandler(startFunc, nil) + if err != nil { + return nil, fmt.Errorf("failed to create sign handler: %w", err) + } + + return &partyAdapter{ + handler: handler, + selfID: cfg.GetPartyID(), + }, nil +} + +// PreSign starts a presigning protocol +func (p *CGGMP21Protocol) PreSign(cfg protocol.KeyGenConfig, signers []string) (protocol.Party, error) { + // Convert to CGGMP21 config + cmpConfig, err := toCMPConfig(cfg) + if err != nil { + return nil, err + } + + // Convert signer IDs + signerIDs := make([]party.ID, len(signers)) + for i, id := range signers { + signerIDs[i] = party.ID(id) + } + + // Create presign protocol + startFunc := cmp.Presign(cmpConfig, signerIDs, p.pool) + + // Create handler + handler, err := mpsProtocol.NewMultiHandler(startFunc, nil) + if err != nil { + return nil, fmt.Errorf("failed to create presign handler: %w", err) + } + + return &partyAdapter{ + handler: handler, + selfID: cfg.GetPartyID(), + }, nil +} + +// PreSignOnline completes a signature with a presignature +func (p *CGGMP21Protocol) PreSignOnline(cfg protocol.KeyGenConfig, preSig protocol.PreSignature, messageHash []byte) (protocol.Party, error) { + // Convert to CGGMP21 types + cmpConfig, err := toCMPConfig(cfg) + if err != nil { + return nil, err + } + + cmpPreSig, ok := preSig.(*preSignatureAdapter) + if !ok { + return nil, errors.New("invalid presignature type") + } + + // Create presign online protocol + startFunc := cmp.PresignOnline(cmpConfig, cmpPreSig.preSig, messageHash, p.pool) + + // Create handler + handler, err := mpsProtocol.NewMultiHandler(startFunc, nil) + if err != nil { + return nil, fmt.Errorf("failed to create presign online handler: %w", err) + } + + return &partyAdapter{ + handler: handler, + selfID: cfg.GetPartyID(), + }, nil +} + +// partyAdapter adapts mpsProtocol.Handler to protocol.Party +type partyAdapter struct { + handler *mpsProtocol.MultiHandler + selfID string + mu sync.Mutex + done bool + result interface{} + err error +} + +func (p *partyAdapter) Update(msg protocol.Message) error { + // Convert to MPS message format + // If broadcast, To is nil. Otherwise, it's the first recipient + var to party.ID + if !msg.IsBroadcast() && len(msg.GetTo()) > 0 { + to = party.ID(msg.GetTo()[0]) + } + + mpsMsg := &mpsProtocol.Message{ + From: party.ID(msg.GetFrom()), + To: to, + Broadcast: msg.IsBroadcast(), + Data: msg.GetData(), + } + + // Check if handler can accept the message + if !p.handler.CanAccept(mpsMsg) { + return errors.New("message rejected by handler") + } + + // Update handler with message + // Note: MultiHandler doesn't have Update method, we need to send via Accept + p.handler.Accept(mpsMsg) + return nil +} + +func (p *partyAdapter) Messages() <-chan protocol.Message { + ch := make(chan protocol.Message) + + go func() { + defer close(ch) + + for { + select { + case msg, ok := <-p.handler.Listen(): + if !ok { + // Protocol finished + p.mu.Lock() + p.done = true + p.result, p.err = p.handler.Result() + p.mu.Unlock() + return + } + + // Convert and send message + var toList []string + if !msg.Broadcast && msg.To != "" { + toList = []string{string(msg.To)} + } + ch <- &messageAdapter{ + from: string(msg.From), + to: toList, + data: msg.Data, + broadcast: msg.Broadcast, + } + } + } + }() + + return ch +} + +func (p *partyAdapter) Errors() <-chan error { + // CGGMP21 doesn't have a separate error channel + // Errors are returned in Result() + ch := make(chan error) + close(ch) + return ch +} + +func (p *partyAdapter) Done() bool { + p.mu.Lock() + defer p.mu.Unlock() + return p.done +} + +func (p *partyAdapter) Result() (interface{}, error) { + p.mu.Lock() + defer p.mu.Unlock() + + if !p.done { + return nil, errors.New("protocol not finished") + } + + if p.err != nil { + return nil, p.err + } + + // Convert result to appropriate type + switch r := p.result.(type) { + case *config.Config: + return &configAdapter{config: r}, nil + case *mpsEcdsa.Signature: + return &signatureAdapter{sig: r}, nil + case *mpsEcdsa.PreSignature: + return &preSignatureAdapter{preSig: r}, nil + default: + return p.result, nil + } +} + +// messageAdapter implements protocol.Message +type messageAdapter struct { + from string + to []string + data []byte + broadcast bool +} + +func (m *messageAdapter) GetFrom() string { return m.from } +func (m *messageAdapter) GetTo() []string { return m.to } +func (m *messageAdapter) GetData() []byte { return m.data } +func (m *messageAdapter) IsBroadcast() bool { return m.broadcast } + +// configAdapter implements protocol.KeyGenConfig +type configAdapter struct { + config *config.Config +} + +func (c *configAdapter) GetPartyID() string { + return string(c.config.ID) +} + +func (c *configAdapter) GetThreshold() int { + return c.config.Threshold +} + +func (c *configAdapter) GetPublicKey() *ecdsa.PublicKey { + point := c.config.PublicPoint() + // Convert curve.Point to ecdsa.PublicKey + // Using XScalar to get X coordinate as big.Int + if point.XScalar() != nil { + xBytes, _ := point.XScalar().MarshalBinary() + x := new(big.Int).SetBytes(xBytes) + // For Y, we need to derive it from the point + // This is a limitation - we can't get Y directly + return &ecdsa.PublicKey{ + Curve: nil, // We can't convert curve.Curve to elliptic.Curve + X: x, + Y: new(big.Int), // Placeholder + } + } + return nil +} + +func (c *configAdapter) GetShare() *big.Int { + // Get ECDSA scalar share and convert to big.Int + if c.config.ECDSA != nil { + bytes, _ := c.config.ECDSA.MarshalBinary() + return new(big.Int).SetBytes(bytes) + } + return nil +} + +func (c *configAdapter) GetSharePublicKey() *ecdsa.PublicKey { + // Get this party's public share + if public, ok := c.config.Public[c.config.ID]; ok && public.ECDSA != nil { + // Convert curve.Point to ecdsa.PublicKey + if public.ECDSA.XScalar() != nil { + xBytes, _ := public.ECDSA.XScalar().MarshalBinary() + x := new(big.Int).SetBytes(xBytes) + return &ecdsa.PublicKey{ + Curve: nil, // We can't convert curve.Curve to elliptic.Curve + X: x, + Y: new(big.Int), // Placeholder + } + } + } + return nil +} + +func (c *configAdapter) GetPartyIDs() []string { + ids := c.config.PartyIDs() + result := make([]string, len(ids)) + for i, id := range ids { + result[i] = string(id) + } + return result +} + +func (c *configAdapter) Serialize() ([]byte, error) { + return json.Marshal(c.config) +} + +// signatureAdapter implements protocol.Signature +type signatureAdapter struct { + sig *mpsEcdsa.Signature +} + +func (s *signatureAdapter) GetR() *big.Int { + // Convert curve.Point R to big.Int + if s.sig.R != nil && s.sig.R.XScalar() != nil { + bytes, _ := s.sig.R.XScalar().MarshalBinary() + return new(big.Int).SetBytes(bytes) + } + return nil +} + +func (s *signatureAdapter) GetS() *big.Int { + // Convert curve.Scalar S to big.Int + if s.sig.S != nil { + bytes, _ := s.sig.S.MarshalBinary() + return new(big.Int).SetBytes(bytes) + } + return nil +} + +func (s *signatureAdapter) Verify(pubKey *ecdsa.PublicKey, message []byte) bool { + // Verification would require converting ecdsa.PublicKey to curve.Point + // This is complex without the proper curve conversion + // For now, return false + return false +} + +func (s *signatureAdapter) Serialize() ([]byte, error) { + return json.Marshal(s.sig) +} + +// preSignatureAdapter implements protocol.PreSignature +type preSignatureAdapter struct { + preSig *mpsEcdsa.PreSignature +} + +func (p *preSignatureAdapter) GetID() string { + // Convert RID (byte slice) to hex string + return fmt.Sprintf("%x", p.preSig.ID) +} + +func (p *preSignatureAdapter) Validate() error { + return p.preSig.Validate() +} + +// Helper functions + +func convertToPartyIDs(ids []string) []party.ID { + if ids == nil { + return nil + } + result := make([]party.ID, len(ids)) + for i, id := range ids { + result[i] = party.ID(id) + } + return result +} + +func convertFromPartyIDs(ids []party.ID) []string { + if ids == nil { + return nil + } + result := make([]string, len(ids)) + for i, id := range ids { + result[i] = string(id) + } + return result +} + +func toCMPConfig(cfg protocol.KeyGenConfig) (*config.Config, error) { + // Try to cast directly first + if adapter, ok := cfg.(*configAdapter); ok { + return adapter.config, nil + } + + // Otherwise, we need to reconstruct + // This is a simplified version - in production you'd need proper serialization + return nil, errors.New("config conversion not implemented for non-CGGMP21 configs") +} diff --git a/pkg/protocol/frost/adapter.go b/pkg/protocol/frost/adapter.go new file mode 100644 index 0000000..dc875ea --- /dev/null +++ b/pkg/protocol/frost/adapter.go @@ -0,0 +1,445 @@ +package frost + +import ( + "crypto/ecdsa" + "encoding/json" + "errors" + "fmt" + "math/big" + "sync" + + "github.com/fystack/mpcium/pkg/protocol" + "github.com/taurusgroup/multi-party-sig/pkg/math/curve" + "github.com/taurusgroup/multi-party-sig/pkg/party" + "github.com/taurusgroup/multi-party-sig/pkg/pool" + mpsProtocol "github.com/taurusgroup/multi-party-sig/pkg/protocol" + "github.com/taurusgroup/multi-party-sig/protocols/frost" +) + +// FROSTProtocol implements the Protocol interface using FROST for EdDSA +type FROSTProtocol struct { + pool *pool.Pool +} + +// NewFROSTProtocol creates a new FROST protocol adapter +func NewFROSTProtocol() *FROSTProtocol { + return &FROSTProtocol{ + pool: pool.NewPool(0), // Use max threads + } +} + +// Close cleans up resources +func (p *FROSTProtocol) Close() { + if p.pool != nil { + p.pool.TearDown() + } +} + +// Name returns the protocol name +func (p *FROSTProtocol) Name() string { + return "FROST" +} + +// KeyGen starts a distributed key generation for EdDSA +func (p *FROSTProtocol) KeyGen(selfID string, partyIDs []string, threshold int) (protocol.Party, error) { + // Convert string IDs to party.ID + ids := make([]party.ID, len(partyIDs)) + for i, id := range partyIDs { + ids[i] = party.ID(id) + } + + // Create the FROST keygen protocol for Ed25519/Taproot + startFunc := frost.KeygenTaproot(party.ID(selfID), ids, threshold) + + // Create handler + handler, err := mpsProtocol.NewMultiHandler(startFunc, nil) + if err != nil { + return nil, fmt.Errorf("failed to create FROST keygen handler: %w", err) + } + + return &frostPartyAdapter{ + handler: handler, + selfID: selfID, + isTaproot: true, + }, nil +} + +// Refresh refreshes shares from an existing config +func (p *FROSTProtocol) Refresh(cfg protocol.KeyGenConfig) (protocol.Party, error) { + // Convert to FROST config + frostConfig, err := toFROSTConfig(cfg) + if err != nil { + return nil, err + } + + // Get party IDs from config + partyIDs := cfg.GetPartyIDs() + ids := make([]party.ID, len(partyIDs)) + for i, id := range partyIDs { + ids[i] = party.ID(id) + } + + // Create refresh protocol + startFunc := frost.Refresh(frostConfig, ids) + + // Create handler + handler, err := mpsProtocol.NewMultiHandler(startFunc, nil) + if err != nil { + return nil, fmt.Errorf("failed to create FROST refresh handler: %w", err) + } + + return &frostPartyAdapter{ + handler: handler, + selfID: cfg.GetPartyID(), + isTaproot: false, + }, nil +} + +// Sign starts a signing protocol +func (p *FROSTProtocol) Sign(cfg protocol.KeyGenConfig, signers []string, messageHash []byte) (protocol.Party, error) { + // Convert to FROST config + frostConfig, err := toFROSTConfig(cfg) + if err != nil { + return nil, err + } + + // Convert signer IDs + signerIDs := make([]party.ID, len(signers)) + for i, id := range signers { + signerIDs[i] = party.ID(id) + } + + // Create sign protocol + startFunc := frost.Sign(frostConfig, signerIDs, messageHash) + + // Create handler + handler, err := mpsProtocol.NewMultiHandler(startFunc, nil) + if err != nil { + return nil, fmt.Errorf("failed to create FROST sign handler: %w", err) + } + + return &frostPartyAdapter{ + handler: handler, + selfID: cfg.GetPartyID(), + isTaproot: false, + }, nil +} + +// PreSign starts a presigning protocol +func (p *FROSTProtocol) PreSign(cfg protocol.KeyGenConfig, signers []string) (protocol.Party, error) { + // FROST doesn't support presigning in the same way as ECDSA protocols + return nil, errors.New("FROST protocol does not support presigning") +} + +// PreSignOnline completes a signature with a presignature +func (p *FROSTProtocol) PreSignOnline(cfg protocol.KeyGenConfig, preSig protocol.PreSignature, messageHash []byte) (protocol.Party, error) { + // FROST doesn't support presigning in the same way as ECDSA protocols + return nil, errors.New("FROST protocol does not support presigning") +} + +// frostPartyAdapter adapts mpsProtocol.Handler to protocol.Party +type frostPartyAdapter struct { + handler *mpsProtocol.MultiHandler + selfID string + isTaproot bool + mu sync.Mutex + done bool + result interface{} + err error +} + +func (p *frostPartyAdapter) Update(msg protocol.Message) error { + // Convert to MPS message format + // If broadcast, To is nil. Otherwise, it's the first recipient + var to party.ID + if !msg.IsBroadcast() && len(msg.GetTo()) > 0 { + to = party.ID(msg.GetTo()[0]) + } + + mpsMsg := &mpsProtocol.Message{ + From: party.ID(msg.GetFrom()), + To: to, + Broadcast: msg.IsBroadcast(), + Data: msg.GetData(), + } + + // Check if handler can accept the message + if !p.handler.CanAccept(mpsMsg) { + return errors.New("message rejected by handler") + } + + // Update handler with message + // Note: MultiHandler doesn't have Update method, we need to send via Accept + p.handler.Accept(mpsMsg) + return nil +} + +func (p *frostPartyAdapter) Messages() <-chan protocol.Message { + ch := make(chan protocol.Message) + + go func() { + defer close(ch) + + for { + select { + case msg, ok := <-p.handler.Listen(): + if !ok { + // Protocol finished + p.mu.Lock() + p.done = true + p.result, p.err = p.handler.Result() + p.mu.Unlock() + return + } + + // Convert and send message + var toList []string + if !msg.Broadcast && msg.To != "" { + toList = []string{string(msg.To)} + } + + ch <- &messageAdapter{ + from: string(msg.From), + to: toList, + data: msg.Data, + broadcast: msg.Broadcast, + } + } + } + }() + + return ch +} + +func (p *frostPartyAdapter) Errors() <-chan error { + // FROST doesn't have a separate error channel + // Errors are returned in Result() + ch := make(chan error) + close(ch) + return ch +} + +func (p *frostPartyAdapter) Done() bool { + p.mu.Lock() + defer p.mu.Unlock() + return p.done +} + +func (p *frostPartyAdapter) Result() (interface{}, error) { + p.mu.Lock() + defer p.mu.Unlock() + + if !p.done { + return nil, errors.New("protocol not finished") + } + + if p.err != nil { + return nil, p.err + } + + // Convert result to appropriate type + switch r := p.result.(type) { + case *frost.Signature: + return &frostSignatureAdapter{sig: r}, nil + case *frost.Config: + return &frostConfigAdapter{ + config: r, + isTaproot: false, + }, nil + case *frost.TaprootConfig: + return &frostConfigAdapter{ + taprootConfig: r, + isTaproot: true, + }, nil + default: + return nil, fmt.Errorf("unexpected result type: %T", r) + } +} + +// messageAdapter implements protocol.Message +type messageAdapter struct { + from string + to []string + data []byte + broadcast bool +} + +func (m *messageAdapter) GetFrom() string { return m.from } +func (m *messageAdapter) GetTo() []string { return m.to } +func (m *messageAdapter) GetData() []byte { return m.data } +func (m *messageAdapter) IsBroadcast() bool { return m.broadcast } + +// frostConfigAdapter implements protocol.KeyGenConfig for FROST +type frostConfigAdapter struct { + config *frost.Config + taprootConfig *frost.TaprootConfig + isTaproot bool +} + +func (c *frostConfigAdapter) GetPartyID() string { + if c.isTaproot && c.taprootConfig != nil { + return string(c.taprootConfig.ID) + } + if c.config != nil { + return string(c.config.ID) + } + return "" +} + +func (c *frostConfigAdapter) GetThreshold() int { + if c.isTaproot && c.taprootConfig != nil { + return c.taprootConfig.Threshold + } + if c.config != nil { + return c.config.Threshold + } + return 0 +} + +// GetPublicKey returns nil for EdDSA as it uses different key type +func (c *frostConfigAdapter) GetPublicKey() *ecdsa.PublicKey { + // FROST uses Ed25519/Schnorr, not ECDSA + // This is a limitation of the current interface design + // For Taproot, we could potentially convert but it's not standard ECDSA + return nil +} + +// GetPublicKeyBytes returns the public key as bytes +func (c *frostConfigAdapter) GetPublicKeyBytes() []byte { + if c.isTaproot && c.taprootConfig != nil { + return c.taprootConfig.PublicKey + } + if c.config != nil && c.config.PublicKey != nil { + bytes, _ := c.config.PublicKey.MarshalBinary() + return bytes + } + return nil +} + +func (c *frostConfigAdapter) GetShare() *big.Int { + if c.isTaproot && c.taprootConfig != nil { + bytes, _ := c.taprootConfig.PrivateShare.MarshalBinary() + return new(big.Int).SetBytes(bytes) + } + if c.config != nil && c.config.PrivateShare != nil { + bytes, _ := c.config.PrivateShare.MarshalBinary() + return new(big.Int).SetBytes(bytes) + } + return nil +} + +func (c *frostConfigAdapter) GetSharePublicKey() *ecdsa.PublicKey { + // FROST uses Ed25519/Schnorr, not ECDSA + return nil +} + +func (c *frostConfigAdapter) GetPartyIDs() []string { + if c.isTaproot && c.taprootConfig != nil { + ids := make([]string, 0, len(c.taprootConfig.VerificationShares)) + for id := range c.taprootConfig.VerificationShares { + ids = append(ids, string(id)) + } + return ids + } + if c.config != nil && c.config.VerificationShares != nil { + ids := make([]string, 0, len(c.config.VerificationShares.Points)) + for id := range c.config.VerificationShares.Points { + ids = append(ids, string(id)) + } + return ids + } + return nil +} + +func (c *frostConfigAdapter) Serialize() ([]byte, error) { + if c.isTaproot && c.taprootConfig != nil { + return json.Marshal(c.taprootConfig) + } + if c.config != nil { + return json.Marshal(c.config) + } + return nil, errors.New("no config to serialize") +} + +// frostSignatureAdapter implements protocol.Signature for FROST +type frostSignatureAdapter struct { + sig *frost.Signature +} + +func (s *frostSignatureAdapter) GetR() *big.Int { + // FROST signatures have an R point, convert X coordinate to big.Int + if s.sig.R != nil && s.sig.R.XScalar() != nil { + bytes, _ := s.sig.R.XScalar().MarshalBinary() + return new(big.Int).SetBytes(bytes) + } + return new(big.Int) +} + +func (s *frostSignatureAdapter) GetS() *big.Int { + // FROST signatures don't have a direct S component + // This is a limitation of the current interface + return new(big.Int) +} + +func (s *frostSignatureAdapter) Verify(pubKey *ecdsa.PublicKey, message []byte) bool { + // This adapter doesn't support ECDSA verification + // FROST uses Schnorr signatures, not ECDSA + return false +} + +func (s *frostSignatureAdapter) Serialize() ([]byte, error) { + // Marshal the signature using JSON for now + return json.Marshal(s.sig) +} + +// Helper functions + +func convertToPartyIDs(ids []string) []party.ID { + if ids == nil { + return nil + } + result := make([]party.ID, len(ids)) + for i, id := range ids { + result[i] = party.ID(id) + } + return result +} + +func convertFromPartyIDs(ids []party.ID) []string { + if ids == nil { + return nil + } + result := make([]string, len(ids)) + for i, id := range ids { + result[i] = string(id) + } + return result +} + +func toFROSTConfig(cfg protocol.KeyGenConfig) (*frost.Config, error) { + // Try to cast directly first + if adapter, ok := cfg.(*frostConfigAdapter); ok { + if adapter.config != nil { + return adapter.config, nil + } + // If it's a Taproot config, we need to convert it + if adapter.taprootConfig != nil { + // This would need proper conversion logic + return nil, errors.New("cannot convert Taproot config to regular FROST config") + } + } + + // Otherwise, deserialize if possible + data, err := cfg.Serialize() + if err != nil { + return nil, fmt.Errorf("failed to serialize config: %w", err) + } + + // Try to unmarshal as FROST config + config := frost.EmptyConfig(curve.Secp256k1{}) + if err := json.Unmarshal(data, config); err != nil { + return nil, fmt.Errorf("failed to unmarshal as FROST config: %w", err) + } + + return config, nil +} diff --git a/pkg/protocol/interfaces.go b/pkg/protocol/interfaces.go new file mode 100644 index 0000000..916dd11 --- /dev/null +++ b/pkg/protocol/interfaces.go @@ -0,0 +1,91 @@ +package protocol + +import ( + "crypto/ecdsa" + "math/big" +) + +// Message represents a protocol message +type Message interface { + // GetFrom returns the sender ID + GetFrom() string + // GetTo returns the recipient IDs (nil for broadcast) + GetTo() []string + // GetData returns the message data + GetData() []byte + // IsBroadcast returns true if this is a broadcast message + IsBroadcast() bool +} + +// Party represents a participant in the protocol +type Party interface { + // Update processes an incoming message + Update(msg Message) error + // Messages returns a channel of outgoing messages + Messages() <-chan Message + // Errors returns a channel of errors + Errors() <-chan error + // Done returns true when the protocol is complete + Done() bool + // Result returns the protocol result + Result() (interface{}, error) +} + +// KeyGenConfig represents the result of key generation +type KeyGenConfig interface { + // GetPartyID returns this party's ID + GetPartyID() string + // GetThreshold returns the threshold value + GetThreshold() int + // GetPublicKey returns the group's public key + GetPublicKey() *ecdsa.PublicKey + // GetShare returns this party's secret share + GetShare() *big.Int + // GetSharePublicKey returns this party's public share + GetSharePublicKey() *ecdsa.PublicKey + // GetPartyIDs returns all party IDs + GetPartyIDs() []string + // Serialize returns the config as bytes + Serialize() ([]byte, error) +} + +// Signature represents a signature +type Signature interface { + // GetR returns the R component + GetR() *big.Int + // GetS returns the S component + GetS() *big.Int + // Verify verifies the signature + Verify(pubKey *ecdsa.PublicKey, message []byte) bool + // Serialize returns the signature as bytes + Serialize() ([]byte, error) +} + +// PreSignature represents a preprocessed signature +type PreSignature interface { + // GetID returns the presignature ID + GetID() string + // Validate validates the presignature + Validate() error +} + +// Protocol represents a threshold signature protocol implementation +type Protocol interface { + // KeyGen starts a distributed key generation + KeyGen(selfID string, partyIDs []string, threshold int) (Party, error) + + // Refresh refreshes shares from an existing config + Refresh(config KeyGenConfig) (Party, error) + + // Sign starts a signing protocol + Sign(config KeyGenConfig, signers []string, messageHash []byte) (Party, error) + + // PreSign starts a presigning protocol + PreSign(config KeyGenConfig, signers []string) (Party, error) + + // PreSignOnline completes a signature with a presignature + PreSignOnline(config KeyGenConfig, preSignature PreSignature, messageHash []byte) (Party, error) + + // Name returns the protocol name (e.g., "GG20", "CGGMP21") + Name() string +} diff --git a/pkg/types/taurus.go b/pkg/types/taurus.go new file mode 100644 index 0000000..e224c93 --- /dev/null +++ b/pkg/types/taurus.go @@ -0,0 +1,10 @@ +package types + +// Message represents a protocol message +type TaurusMessage struct { + SessionID string `json:"session_id"` + SenderID string `json:"sender_id"` + RecipientIDs []string `json:"recipient_ids"` + Body []byte `json:"body"` + IsBroadcast bool `json:"is_broadcast"` +} diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go new file mode 100644 index 0000000..bd62682 --- /dev/null +++ b/pkg/utils/utils.go @@ -0,0 +1,20 @@ +package utils + +import ( + "crypto/sha256" + "io" + "os" + + "github.com/rs/zerolog" +) + +// GetMessageHash returns the SHA256 hash of the message +func GetMessageHash(msgBytes []byte) []byte { + hash := sha256.Sum256(msgBytes) + return hash[:] +} + +// ZerologConsoleWriter returns a console writer for zerolog +func ZerologConsoleWriter() io.Writer { + return zerolog.ConsoleWriter{Out: os.Stdout} +} diff --git a/setup_identities.sh b/setup_identities.sh index 708c961..e2f5979 100755 --- a/setup_identities.sh +++ b/setup_identities.sh @@ -1,5 +1,8 @@ #!/bin/bash +# Add Go bin directory to PATH to ensure mpcium-cli is available +export PATH="$HOME/go/bin:$PATH" + # Number of nodes to create (default is 3) NUM_NODES=3 diff --git a/wallets.json b/wallets.json new file mode 100644 index 0000000..0e05a2d --- /dev/null +++ b/wallets.json @@ -0,0 +1,3 @@ +[ + "fb89e64c-e2ee-4e1c-a04e-2fa728dae170" +] \ No newline at end of file From b94e76359947896b69737b0913eac3a1b940add9 Mon Sep 17 00:00:00 2001 From: vietddude Date: Fri, 3 Oct 2025 17:24:55 +0700 Subject: [PATCH 02/21] wip: done taurus in mem and pubsub --- pkg/eventconsumer/event_consumer.go | 56 +--- pkg/mpc/node.go | 35 +-- pkg/mpc/taurus/adapter.go | 87 ++++++ pkg/mpc/taurus/cmp.go | 152 ++++++++++ pkg/mpc/taurus/cmp_test.go | 93 ++++++ pkg/mpc/taurus/ecdsa_resharing_session.go | 351 ---------------------- pkg/mpc/taurus/eddsa_resharing_session.go | 345 --------------------- pkg/mpc/taurus/keygen_session.go | 328 -------------------- pkg/mpc/taurus/nats_transport.go | 146 +++++++++ pkg/mpc/taurus/node.go | 242 --------------- pkg/mpc/taurus/node_test.go | 119 -------- pkg/mpc/taurus/registry.go | 213 ------------- pkg/mpc/taurus/reshare_session.go | 32 -- pkg/mpc/taurus/session.go | 184 ------------ pkg/mpc/taurus/signing_session.go | 332 -------------------- pkg/mpc/taurus/transport.go | 83 +++++ pkg/types/taurus.go | 7 + 17 files changed, 587 insertions(+), 2218 deletions(-) create mode 100644 pkg/mpc/taurus/adapter.go create mode 100644 pkg/mpc/taurus/cmp.go create mode 100644 pkg/mpc/taurus/cmp_test.go delete mode 100644 pkg/mpc/taurus/ecdsa_resharing_session.go delete mode 100644 pkg/mpc/taurus/eddsa_resharing_session.go delete mode 100644 pkg/mpc/taurus/keygen_session.go create mode 100644 pkg/mpc/taurus/nats_transport.go delete mode 100644 pkg/mpc/taurus/node.go delete mode 100644 pkg/mpc/taurus/node_test.go delete mode 100644 pkg/mpc/taurus/registry.go delete mode 100644 pkg/mpc/taurus/reshare_session.go delete mode 100644 pkg/mpc/taurus/session.go delete mode 100644 pkg/mpc/taurus/signing_session.go create mode 100644 pkg/mpc/taurus/transport.go diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index 7e08100..9c325e5 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -2,7 +2,6 @@ package eventconsumer import ( "context" - "encoding/hex" "encoding/json" "errors" "fmt" @@ -187,7 +186,6 @@ func (ec *eventConsumer) handleKeyGenEvent(natMsg *nats.Msg) { // ecdsaSession.Init() // eddsaSession.Init() - taurusSession.Init() // ctxEcdsa, doneEcdsa := context.WithCancel(baseCtx) // ctxEddsa, doneEddsa := context.WithCancel(baseCtx) @@ -225,58 +223,24 @@ func (ec *eventConsumer) handleKeyGenEvent(natMsg *nats.Msg) { // } // }() - go func() { - select { - case <-ctxTaurus.Done(): - return - case err := <-taurusSession.ErrChan(): - if err != nil { - logger.Error("CGGMP21 keygen session error", err) - errorChan <- err - doneTaurus() - } - } - }() - // ecdsaSession.ListenToIncomingMessageAsync() // eddsaSession.ListenToIncomingMessageAsync() - taurusSession.ListenToIncomingMessageAsync(taurusSession.ProcessInboundMessage) // Temporary delay for peer setup ec.warmUpSession() // go ecdsaSession.GenerateKey(doneEcdsa) // go eddsaSession.GenerateKey(doneEddsa) - go taurusSession.ProcessOutboundMessage() - // Wait for the keygen to complete - completionChan := make(chan string, 1) go func() { - result := taurusSession.WaitForFinish() - completionChan <- result - }() - - // Wait for completion, error, or timeout - select { - case pubKeyHex := <-completionChan: - // Success - set the public key - if pubKeyHex != "" { - pubKeyBytes, err := hex.DecodeString(pubKeyHex) - if err == nil { - successEvent.TaurusCMPPubKey = pubKeyBytes - } + data, err := taurusSession.Keygen(ctxTaurus) + if err != nil { + logger.Error("Failed to generate key", err) + ec.handleKeygenSessionError(walletID, err, "Failed to generate key", natMsg) + errorChan <- err + doneTaurus() } - doneTaurus() // Signal completion - - case err := <-errorChan: - // Error occurred - ec.handleKeygenSessionError(walletID, err, "CGGMP21 keygen error", natMsg) - return - - case <-baseCtx.Done(): - // Timeout occurred - logger.Warn("Key generation timed out", "walletID", walletID, "timeout", KeyGenTimeOut) - ec.handleKeygenSessionError(walletID, fmt.Errorf("keygen session timed out after %v", KeyGenTimeOut), "Key generation timed out", natMsg) - return - } + successEvent.TaurusCMPPubKey = data.Payload + doneTaurus() + }() payload, err := json.Marshal(successEvent) if err != nil { @@ -284,7 +248,7 @@ func (ec *eventConsumer) handleKeyGenEvent(natMsg *nats.Msg) { ec.handleKeygenSessionError(walletID, err, "Failed to marshal keygen success event", natMsg) return } - + fmt.Println("payload", string(payload)) key := fmt.Sprintf(mpc.TypeGenerateWalletResultFmt, walletID) if err := ec.genKeyResultQueue.Enqueue( key, diff --git a/pkg/mpc/node.go b/pkg/mpc/node.go index fea226d..e2b4fa3 100644 --- a/pkg/mpc/node.go +++ b/pkg/mpc/node.go @@ -16,6 +16,7 @@ import ( "github.com/fystack/mpcium/pkg/messaging" "github.com/fystack/mpcium/pkg/mpc/taurus" "github.com/taurusgroup/multi-party-sig/pkg/party" + "github.com/taurusgroup/multi-party-sig/pkg/pool" ) const ( @@ -146,32 +147,14 @@ func (p *Node) CreateCMPKeyGenSession( walletID string, threshold int, resultQueue messaging.MessageQueue, -) (taurus.KeyGenSession, error) { - if !p.peerRegistry.ArePeersReady() { - return nil, fmt.Errorf( - "peers are not ready yet. ready: %d, expected: %d", - p.peerRegistry.GetReadyPeersCount(), - len(p.peerIDs)+1, - ) - } - +) (*taurus.CmpParty, error) { readyPeerIDs := p.peerRegistry.GetReadyPeersIncludeSelf() selfPartyID, allPartyIDs := p.generateTaurusPartyIDs(PurposeKeygen, readyPeerIDs, DefaultVersion) - - session := taurus.NewCGGMP21KeygenSession( - walletID, - p.pubSub, - selfPartyID, - allPartyIDs, - threshold, - p.kvstore, - p.keyinfoStore, - resultQueue, - p.identityStore, - ) - - session.Init() - return session, nil + tr := taurus.NewNATSTransport(walletID, selfPartyID, p.pubSub) + adapter := taurus.NewTaurusNetworkAdapter(walletID, selfPartyID, tr, allPartyIDs) + pl := pool.NewPool(0) + party := taurus.NewCmpParty(walletID, selfPartyID, allPartyIDs, threshold, pl, adapter) + return party, nil } func (p *Node) CreateSigningSession( @@ -506,8 +489,8 @@ func sessionKeyPrefix(sessionType SessionType) (string, error) { } } -func (p *Node) generateTaurusPartyIDs(purpose string, peerIDs []string, version int) (party.ID, []party.ID) { - partyIDs := make([]party.ID, len(peerIDs)) +func (p *Node) generateTaurusPartyIDs(purpose string, peerIDs []string, version int) (party.ID, party.IDSlice) { + partyIDs := make(party.IDSlice, len(peerIDs)) var selfPartyID party.ID for i, peerID := range peerIDs { diff --git a/pkg/mpc/taurus/adapter.go b/pkg/mpc/taurus/adapter.go new file mode 100644 index 0000000..c02c732 --- /dev/null +++ b/pkg/mpc/taurus/adapter.go @@ -0,0 +1,87 @@ +package taurus + +import ( + "encoding/json" + "log/slog" + + "github.com/taurusgroup/multi-party-sig/pkg/party" + "github.com/taurusgroup/multi-party-sig/pkg/protocol" +) + +type NetworkInterface interface { + Next() <-chan *protocol.Message + Send(msg *protocol.Message) + Done() <-chan struct{} +} + +type TaurusNetworkAdapter struct { + sid string + selfID party.ID + transport Transport + inbox chan *protocol.Message + done chan struct{} + peers party.IDSlice +} + +func NewTaurusNetworkAdapter( + sid string, + selfID party.ID, + t Transport, + peers party.IDSlice, +) *TaurusNetworkAdapter { + a := &TaurusNetworkAdapter{ + sid: sid, + selfID: selfID, + transport: t, + inbox: make(chan *protocol.Message, 100), + done: make(chan struct{}), + peers: peers, + } + go a.route() + return a +} + +func (a *TaurusNetworkAdapter) Next() <-chan *protocol.Message { return a.inbox } +func (a *TaurusNetworkAdapter) Done() <-chan struct{} { return a.done } + +func (a *TaurusNetworkAdapter) Send(msg *protocol.Message) { + wire, err := json.Marshal(msg) + if err != nil { + slog.Error("❌ marshal protocol msg", "err", err) + return + } + m := Msg{SID: a.sid, From: string(msg.From), IsBroadcast: msg.Broadcast, Data: wire} + for _, pid := range a.peers { + if pid == a.selfID { + continue + } + if msg.Broadcast || msg.IsFor(pid) { + _ = a.transport.Send(string(pid), m) + } + } +} + +func (a *TaurusNetworkAdapter) route() { + for { + select { + case tm, ok := <-a.transport.Inbox(): + if !ok { + close(a.done) + return + } + var pm protocol.Message + if err := json.Unmarshal(tm.Data, &pm); err != nil { + slog.Error("❌ unmarshal protocol msg", "err", err) + continue + } + select { + case a.inbox <- &pm: + default: + slog.Warn("⚠️ inbox full, drop msg", "self", a.selfID) + } + case <-a.transport.Done(): + close(a.done) + return + } + } +} diff --git a/pkg/mpc/taurus/cmp.go b/pkg/mpc/taurus/cmp.go new file mode 100644 index 0000000..b94440e --- /dev/null +++ b/pkg/mpc/taurus/cmp.go @@ -0,0 +1,152 @@ +package taurus + +import ( + "context" + "errors" + "fmt" + "math/big" + + "github.com/fystack/mpcium/pkg/logger" + "github.com/fystack/mpcium/pkg/types" + "github.com/taurusgroup/multi-party-sig/pkg/ecdsa" + "github.com/taurusgroup/multi-party-sig/pkg/math/curve" + "github.com/taurusgroup/multi-party-sig/pkg/party" + "github.com/taurusgroup/multi-party-sig/pkg/pool" + "github.com/taurusgroup/multi-party-sig/pkg/protocol" + "github.com/taurusgroup/multi-party-sig/protocols/cmp" +) + +type CmpParty struct { + sid string + id party.ID + ids party.IDSlice + threshold int + pl *pool.Pool + savedData *cmp.Config + network NetworkInterface +} + +func NewCmpParty( + sid string, + id party.ID, + ids party.IDSlice, + threshold int, + pl *pool.Pool, + network NetworkInterface, +) *CmpParty { + return &CmpParty{ + sid: sid, + id: id, + ids: ids, + threshold: threshold, + pl: pl, + network: network, + } +} + +func (p *CmpParty) LoadKey(data *types.KeyData) error { + cfg := cmp.EmptyConfig(curve.Secp256k1{}) + if err := cfg.UnmarshalBinary(data.Payload); err != nil { + return fmt.Errorf("decode key data: %w", err) + } + p.savedData = cfg + return nil +} + +func (p *CmpParty) Keygen(ctx context.Context) (types.KeyData, error) { + h, err := protocol.NewMultiHandler( + cmp.Keygen(curve.Secp256k1{}, p.id, p.ids, p.threshold, p.pl), + []byte(p.sid), + ) + if err != nil { + return types.KeyData{}, err + } + if err := p.executeProtocol(ctx, h); err != nil { + return types.KeyData{}, err + } + res, err := h.Result() + if err != nil { + return types.KeyData{}, err + } + cfg, ok := res.(*cmp.Config) + if !ok { + return types.KeyData{}, errors.New("unexpected result type") + } + p.savedData = cfg + packed, _ := cfg.MarshalBinary() + return types.KeyData{SID: p.sid, Type: "taurus_cmp", Payload: packed}, nil +} + +func (p *CmpParty) Sign(ctx context.Context, msg *big.Int) ([]byte, error) { + if p.savedData == nil { + return nil, errors.New("no key loaded") + } + h, err := protocol.NewMultiHandler( + cmp.Sign(p.savedData, p.ids, msg.Bytes(), p.pl), + []byte(p.sid), + ) + if err != nil { + return nil, err + } + if err := p.executeProtocol(ctx, h); err != nil { + return nil, err + } + res, err := h.Result() + if err != nil { + return nil, err + } + sig, ok := res.(*ecdsa.Signature) + if !ok { + return nil, errors.New("unexpected signature result") + } + if !sig.Verify(p.savedData.PublicPoint(), msg.Bytes()) { + return nil, errors.New("failed to verify cmp signature") + } + return sig.SigEthereum() +} + +func (p *CmpParty) Reshare(ctx context.Context) (types.KeyData, error) { + if p.savedData == nil { + return types.KeyData{}, errors.New("no key loaded") + } + h, err := protocol.NewMultiHandler(cmp.Refresh(p.savedData, p.pl), []byte(p.sid)) + if err != nil { + return types.KeyData{}, err + } + if err := p.executeProtocol(ctx, h); err != nil { + return types.KeyData{}, err + } + res, err := h.Result() + if err != nil { + return types.KeyData{}, err + } + cfg, ok := res.(*cmp.Config) + if !ok { + return types.KeyData{}, errors.New("unexpected result type") + } + p.savedData = cfg + packed, _ := cfg.MarshalBinary() + return types.KeyData{SID: p.sid, Type: "taurus_cmp", Payload: packed}, nil +} + +func (p *CmpParty) executeProtocol(ctx context.Context, h protocol.Handler) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + case msg, ok := <-h.Listen(): + if !ok { + return nil + } + p.network.Send(msg) + case msg := <-p.network.Next(): + if h.CanAccept(msg) { + h.Accept(msg) + } else { + logger.Warn("⚠️ Ignored invalid msg", "self", p.id) + } + case <-p.network.Done(): + return nil + } + } +} diff --git a/pkg/mpc/taurus/cmp_test.go b/pkg/mpc/taurus/cmp_test.go new file mode 100644 index 0000000..2f4979d --- /dev/null +++ b/pkg/mpc/taurus/cmp_test.go @@ -0,0 +1,93 @@ +package taurus + +import ( + "context" + "fmt" + "sync" + "testing" + + "github.com/fystack/mpcium/pkg/logger" + "github.com/fystack/mpcium/pkg/messaging" + "github.com/fystack/mpcium/pkg/types" + "github.com/nats-io/nats.go" + "github.com/taurusgroup/multi-party-sig/pkg/party" + "github.com/taurusgroup/multi-party-sig/pkg/pool" +) + +func TestCmpParty(t *testing.T) { + sid := "test-session-123" + parties := []string{"party1", "party2", "party3"} + ids := make([]party.ID, len(parties)) + for i, id := range parties { + ids[i] = party.ID(id) + } + pl := pool.NewPool(0) + + natsConn, err := nats.Connect("nats://localhost:4223") + if err != nil { + logger.Fatal("Failed to connect to NATS", err) + } + + pubsub := messaging.NewNATSPubSub(natsConn) + + // networks + adapters + network1 := NewNATSTransport(sid, party.ID("party1"), pubsub) + network2 := NewNATSTransport(sid, party.ID("party2"), pubsub) + network3 := NewNATSTransport(sid, party.ID("party3"), pubsub) + + adapter1 := NewTaurusNetworkAdapter(sid, "party1", network1, ids) + adapter2 := NewTaurusNetworkAdapter(sid, "party2", network2, ids) + adapter3 := NewTaurusNetworkAdapter(sid, "party3", network3, ids) + + party1 := NewCmpParty(sid, "party1", ids, 2, pl, adapter1) + party2 := NewCmpParty(sid, "party2", ids, 2, pl, adapter2) + party3 := NewCmpParty(sid, "party3", ids, 2, pl, adapter3) + + result1 := make(chan types.KeyData, 1) + result2 := make(chan types.KeyData, 1) + result3 := make(chan types.KeyData, 1) + + wg := sync.WaitGroup{} + wg.Add(3) + + go func() { + defer wg.Done() + res, err := party1.Keygen(context.Background()) + if err != nil { + t.Errorf("party1 keygen error: %v", err) + return + } + result1 <- res + }() + + go func() { + defer wg.Done() + res, err := party2.Keygen(context.Background()) + if err != nil { + t.Errorf("party2 keygen error: %v", err) + return + } + result2 <- res + }() + + go func() { + defer wg.Done() + res, err := party3.Keygen(context.Background()) + if err != nil { + t.Errorf("party3 keygen error: %v", err) + return + } + result3 <- res + }() + + wg.Wait() + + // Read the actual values from channels + r1 := <-result1 + r2 := <-result2 + r3 := <-result3 + + fmt.Println("party1 result:", len(r1.Payload)) + fmt.Println("party2 result:", len(r2.Payload)) + fmt.Println("party3 result:", len(r3.Payload)) +} diff --git a/pkg/mpc/taurus/ecdsa_resharing_session.go b/pkg/mpc/taurus/ecdsa_resharing_session.go deleted file mode 100644 index 0aab2d1..0000000 --- a/pkg/mpc/taurus/ecdsa_resharing_session.go +++ /dev/null @@ -1,351 +0,0 @@ -package taurus - -import ( - "crypto/ecdsa" - "encoding/json" - "fmt" - "math/big" - "sync" - - "github.com/fystack/mpcium/pkg/keyinfo" - "github.com/fystack/mpcium/pkg/kvstore" - "github.com/fystack/mpcium/pkg/messaging" - "github.com/fystack/mpcium/pkg/protocol" - "github.com/fystack/mpcium/pkg/protocol/cggmp21" - "github.com/fystack/mpcium/pkg/utils" - "github.com/rs/zerolog" - "github.com/taurusgroup/multi-party-sig/pkg/party" -) - -// cggmp21ReshareSession implements ReshareSession for ECDSA using CGGMP21 -type cggmp21ReshareSession struct { - session - isNewPeer bool - pubKeyResult []byte - kvstore kvstore.KVStore - keyinfoStore keyinfo.Store - resultQueue messaging.MessageQueue - protocol protocol.Protocol - party protocol.Party - config protocol.KeyGenConfig - newThreshold int - newNodeIDs []string -} - -// newCGGMP21ReshareSession creates a new CGGMP21 reshare session -func newCGGMP21ReshareSession( - walletID string, - threshold int, - newThreshold int, - newNodeIDs []string, - isNewPeer bool, - pubSub messaging.PubSub, - kvstore kvstore.KVStore, - keyinfoStore keyinfo.Store, - resultQueue messaging.MessageQueue, - selfNodeID string, -) (*cggmp21ReshareSession, error) { - // Generate session ID for resharing - sessionID := fmt.Sprintf("reshare-%s", walletID) - - // For resharing, we need to determine the party IDs - var partyIDs []party.ID - - if !isNewPeer { - // For old peers, get the existing key info to find current parties - keyInfo, err := keyinfoStore.Get(walletID) - if err != nil { - return nil, fmt.Errorf("failed to get key info for resharing: %w", err) - } - - // Old peers use their existing party IDs - for _, id := range keyInfo.ParticipantPeerIDs { - partyIDs = append(partyIDs, party.ID(id)) - } - } else { - // New peers use the new node IDs - for _, id := range newNodeIDs { - partyIDs = append(partyIDs, party.ID(id)) - } - } - - // Create CGGMP21 protocol - protocol := cggmp21.NewCGGMP21Protocol() - - s := &cggmp21ReshareSession{ - session: session{ - walletID: walletID, - sessionID: sessionID, - pubSub: pubSub, - selfPartyID: party.ID(selfNodeID), - partyIDs: partyIDs, - subscriberList: []messaging.Subscription{}, - rounds: 5, // CGGMP21 has 5 rounds - outCh: make(chan msg, 100), - errCh: make(chan error, 10), - finishCh: make(chan bool, 1), - externalFinishChan: make(chan string, 1), - threshold: threshold, - kvstore: kvstore, - keyinfoStore: keyinfoStore, - resultQueue: resultQueue, - logger: zerolog.New(utils.ZerologConsoleWriter()).With().Timestamp().Logger(), - processing: make(map[string]bool), - processingLock: sync.Mutex{}, - topicComposer: &TopicComposer{ - ComposeBroadcastTopic: func() string { - return fmt.Sprintf("reshare:broadcast:cggmp21:%s", walletID) - }, - ComposeDirectTopic: func(nodeID string) string { - return fmt.Sprintf("reshare:direct:cggmp21:%s:%s", nodeID, walletID) - }, - }, - identityStore: nil, // Not needed for resharing - }, - isNewPeer: isNewPeer, - kvstore: kvstore, - keyinfoStore: keyinfoStore, - resultQueue: resultQueue, - protocol: protocol, - newThreshold: newThreshold, - newNodeIDs: newNodeIDs, - } - - // Load existing config for old peers - if !isNewPeer { - config, err := s.loadConfig(walletID) - if err != nil { - return nil, fmt.Errorf("failed to load existing config: %w", err) - } - s.config = config - } - - return s, nil -} - -// Init initializes the reshare session -func (s *cggmp21ReshareSession) Init() { - s.logger.Info(). - Str("sessionID", s.sessionID). - Bool("isNewPeer", s.isNewPeer). - Int("threshold", s.threshold). - Int("newThreshold", s.newThreshold). - Msg("Initializing CGGMP21 reshare session") -} - -// Reshare starts the resharing protocol -func (s *cggmp21ReshareSession) Reshare(done func()) { - defer done() - - s.logger.Info(). - Str("sessionID", s.sessionID). - Bool("isNewPeer", s.isNewPeer). - Int("threshold", s.threshold). - Msg("Starting CGGMP21 reshare session") - - // Create the protocol party - var err error - if s.isNewPeer { - // New peers participate in key generation with the new committee - // For new peers, this is essentially a new key generation - // but coordinated with the refresh protocol of old peers - s.party, err = s.protocol.KeyGen( - string(s.selfPartyID), - convertFromPartyIDs(s.partyIDs), - s.newThreshold, - ) - } else { - // Old peers run the refresh protocol - s.party, err = s.protocol.Refresh(s.config) - } - - if err != nil { - s.errCh <- fmt.Errorf("failed to create reshare party: %w", err) - return - } - - // Start listening for messages - s.ListenToIncomingMessageAsync(s.ProcessInboundMessage) - go s.ProcessOutboundMessage() - - // Wait for protocol to complete - <-s.finishCh - - // Process the result - if s.party.Done() { - result, err := s.party.Result() - if err != nil { - s.errCh <- fmt.Errorf("reshare protocol failed: %w", err) - return - } - - // Handle the result based on peer type - if newConfig, ok := result.(protocol.KeyGenConfig); ok { - // Save the new configuration - if err := s.saveConfig(newConfig); err != nil { - s.errCh <- fmt.Errorf("failed to save reshare result: %w", err) - return - } - - // Extract public key for result - pubKey := newConfig.GetPublicKey() - if pubKey != nil { - pubKeyBytes := append(pubKey.X.Bytes(), pubKey.Y.Bytes()...) - s.pubKeyResult = pubKeyBytes - } - - s.logger.Info(). - Str("sessionID", s.sessionID). - Bool("isNewPeer", s.isNewPeer). - Msg("CGGMP21 reshare completed successfully") - } else { - s.errCh <- fmt.Errorf("unexpected result type from reshare: %T", result) - } - } -} - -// ProcessInboundMessage handles incoming protocol messages -func (s *cggmp21ReshareSession) ProcessInboundMessage(msgBytes []byte) { - // Implementation similar to keygen session - // Convert message and send to protocol party -} - -// ProcessOutboundMessage handles outgoing protocol messages -func (s *cggmp21ReshareSession) ProcessOutboundMessage() { - // Implementation similar to keygen session -} - -// GetPubKeyResult returns the public key after successful resharing -func (s *cggmp21ReshareSession) GetPubKeyResult() []byte { - return s.pubKeyResult -} - -// IsNewPeer returns true if this node is joining as a new peer -func (s *cggmp21ReshareSession) IsNewPeer() bool { - return s.isNewPeer -} - -// ErrChan returns the error channel -func (s *cggmp21ReshareSession) ErrChan() <-chan error { - return s.errCh -} - -// Stop stops the session -func (s *cggmp21ReshareSession) Stop() { - // Protocol doesn't have Close method - close(s.outCh) - close(s.errCh) -} - -// WaitForFinish waits for the session to complete -func (s *cggmp21ReshareSession) WaitForFinish() string { - return <-s.externalFinishChan -} - -// loadConfig loads the existing key configuration -func (s *cggmp21ReshareSession) loadConfig(walletID string) (protocol.KeyGenConfig, error) { - // Get key info - keyInfo, err := s.keyinfoStore.Get(walletID) - if err != nil { - return nil, err - } - - // Load the key share data - keyShareData, err := s.kvstore.Get(walletID) - if err != nil { - return nil, err - } - - // Create a config adapter that implements protocol.KeyGenConfig - return &keyGenConfigAdapter{ - keyInfo: keyInfo, - keyShareData: keyShareData, - walletID: walletID, - }, nil -} - -// saveConfig saves the new key configuration after resharing -func (s *cggmp21ReshareSession) saveConfig(config protocol.KeyGenConfig) error { - // Serialize the config - configData, err := config.Serialize() - if err != nil { - return fmt.Errorf("failed to serialize config: %w", err) - } - - // Save to kvstore - if err := s.kvstore.Put(s.walletID, configData); err != nil { - return fmt.Errorf("failed to save share data: %w", err) - } - - // Update key info - keyInfo := &keyinfo.KeyInfo{ - ParticipantPeerIDs: s.newNodeIDs, - Threshold: s.newThreshold, - Version: 1, - } - - if err := s.keyinfoStore.Save(s.walletID, keyInfo); err != nil { - return fmt.Errorf("failed to save key info: %w", err) - } - - return nil -} - -// keyGenConfigAdapter adapts stored key data to protocol.KeyGenConfig interface -type keyGenConfigAdapter struct { - keyInfo *keyinfo.KeyInfo - keyShareData []byte - walletID string -} - -func (a *keyGenConfigAdapter) GetPartyID() string { - // Extract from the stored data - this is implementation specific - var data map[string]interface{} - if err := json.Unmarshal(a.keyShareData, &data); err != nil { - return "" - } - if id, ok := data["ID"].(string); ok { - return id - } - return "" -} - -func (a *keyGenConfigAdapter) GetThreshold() int { - return a.keyInfo.Threshold -} - -func (a *keyGenConfigAdapter) GetPublicKey() *ecdsa.PublicKey { - // Extract from stored data - var data map[string]interface{} - if err := json.Unmarshal(a.keyShareData, &data); err != nil { - return nil - } - - // This is a simplified version - actual implementation would need proper parsing - return nil -} - -func (a *keyGenConfigAdapter) GetShare() *big.Int { - // Extract from the stored data - var data map[string]interface{} - if err := json.Unmarshal(a.keyShareData, &data); err != nil { - return nil - } - - // This is a simplified version - actual implementation would need proper parsing - return nil -} - -func (a *keyGenConfigAdapter) GetSharePublicKey() *ecdsa.PublicKey { - // This would need to be extracted from the stored data - // For now, return nil as it's not critical for refresh - return nil -} - -func (a *keyGenConfigAdapter) GetPartyIDs() []string { - return a.keyInfo.ParticipantPeerIDs -} - -func (a *keyGenConfigAdapter) Serialize() ([]byte, error) { - return a.keyShareData, nil -} diff --git a/pkg/mpc/taurus/eddsa_resharing_session.go b/pkg/mpc/taurus/eddsa_resharing_session.go deleted file mode 100644 index 837d8fe..0000000 --- a/pkg/mpc/taurus/eddsa_resharing_session.go +++ /dev/null @@ -1,345 +0,0 @@ -package taurus - -import ( - "crypto/ecdsa" - "crypto/ed25519" - "encoding/json" - "fmt" - "math/big" - "sync" - - "github.com/fystack/mpcium/pkg/keyinfo" - "github.com/fystack/mpcium/pkg/kvstore" - "github.com/fystack/mpcium/pkg/messaging" - "github.com/fystack/mpcium/pkg/protocol" - "github.com/fystack/mpcium/pkg/protocol/frost" - "github.com/fystack/mpcium/pkg/utils" - "github.com/rs/zerolog" - "github.com/taurusgroup/multi-party-sig/pkg/party" -) - -// eddsaReshareSession implements ReshareSession for EdDSA using FROST -type eddsaReshareSession struct { - session - isNewPeer bool - pubKeyResult []byte - kvstore kvstore.KVStore - keyinfoStore keyinfo.Store - resultQueue messaging.MessageQueue - protocol protocol.Protocol - party protocol.Party - config protocol.KeyGenConfig - newThreshold int - newNodeIDs []string -} - -// newEdDSAReshareSession creates a new EdDSA reshare session -func newEdDSAReshareSession( - walletID string, - threshold int, - newThreshold int, - newNodeIDs []string, - isNewPeer bool, - pubSub messaging.PubSub, - kvstore kvstore.KVStore, - keyinfoStore keyinfo.Store, - resultQueue messaging.MessageQueue, - selfNodeID string, -) (*eddsaReshareSession, error) { - // Generate session ID for resharing - sessionID := fmt.Sprintf("reshare-%s", walletID) - - // For resharing, we need to determine the party IDs - var partyIDs []party.ID - - if !isNewPeer { - // For old peers, get the existing key info to find current parties - keyInfo, err := keyinfoStore.Get(walletID) - if err != nil { - return nil, fmt.Errorf("failed to get key info for resharing: %w", err) - } - - // Old peers use their existing party IDs - for _, id := range keyInfo.ParticipantPeerIDs { - partyIDs = append(partyIDs, party.ID(id)) - } - } else { - // New peers use the new node IDs - for _, id := range newNodeIDs { - partyIDs = append(partyIDs, party.ID(id)) - } - } - - // Create FROST protocol - protocol := frost.NewFROSTProtocol() - - s := &eddsaReshareSession{ - session: session{ - walletID: walletID, - sessionID: sessionID, - pubSub: pubSub, - selfPartyID: party.ID(selfNodeID), - partyIDs: partyIDs, - subscriberList: []messaging.Subscription{}, - rounds: 3, // FROST has fewer rounds - outCh: make(chan msg, 100), - errCh: make(chan error, 10), - finishCh: make(chan bool, 1), - externalFinishChan: make(chan string, 1), - threshold: threshold, - kvstore: kvstore, - keyinfoStore: keyinfoStore, - resultQueue: resultQueue, - logger: zerolog.New(utils.ZerologConsoleWriter()).With().Timestamp().Logger(), - processing: make(map[string]bool), - processingLock: sync.Mutex{}, - topicComposer: &TopicComposer{ - ComposeBroadcastTopic: func() string { - return fmt.Sprintf("reshare:broadcast:frost:%s", walletID) - }, - ComposeDirectTopic: func(nodeID string) string { - return fmt.Sprintf("reshare:direct:frost:%s:%s", nodeID, walletID) - }, - }, - identityStore: nil, // Not needed for resharing - }, - isNewPeer: isNewPeer, - kvstore: kvstore, - keyinfoStore: keyinfoStore, - resultQueue: resultQueue, - protocol: protocol, - newThreshold: newThreshold, - newNodeIDs: newNodeIDs, - } - - // Load existing config for old peers - if !isNewPeer { - config, err := s.loadConfig(walletID) - if err != nil { - return nil, fmt.Errorf("failed to load existing config: %w", err) - } - s.config = config - } - - return s, nil -} - -// Init initializes the reshare session -func (s *eddsaReshareSession) Init() { - s.logger.Info(). - Str("sessionID", s.sessionID). - Bool("isNewPeer", s.isNewPeer). - Int("threshold", s.threshold). - Int("newThreshold", s.newThreshold). - Msg("Initializing EdDSA/FROST reshare session") -} - -// Reshare starts the resharing protocol -func (s *eddsaReshareSession) Reshare(done func()) { - defer done() - - s.logger.Info(). - Str("sessionID", s.sessionID). - Bool("isNewPeer", s.isNewPeer). - Int("threshold", s.threshold). - Msg("Starting EdDSA/FROST reshare session") - - // Create the protocol party - var err error - if s.isNewPeer { - // New peers participate in key generation with the new committee - s.party, err = s.protocol.KeyGen( - string(s.selfPartyID), - convertFromPartyIDs(s.partyIDs), - s.newThreshold, - ) - } else { - // Old peers run the refresh protocol - s.party, err = s.protocol.Refresh(s.config) - } - - if err != nil { - s.errCh <- fmt.Errorf("failed to create reshare party: %w", err) - return - } - - // Start listening for messages - s.ListenToIncomingMessageAsync(s.ProcessInboundMessage) - go s.ProcessOutboundMessage() - - // Wait for protocol to complete - <-s.finishCh - - // Process the result - if s.party.Done() { - result, err := s.party.Result() - if err != nil { - s.errCh <- fmt.Errorf("reshare protocol failed: %w", err) - return - } - - // Handle the result - if newConfig, ok := result.(protocol.KeyGenConfig); ok { - // Save the new configuration - if err := s.saveConfig(newConfig); err != nil { - s.errCh <- fmt.Errorf("failed to save reshare result: %w", err) - return - } - - // For EdDSA, we would extract the Ed25519 public key - // This is a placeholder - actual implementation would depend on the protocol - s.pubKeyResult = []byte{} // Placeholder - - s.logger.Info(). - Str("sessionID", s.sessionID). - Bool("isNewPeer", s.isNewPeer). - Msg("EdDSA/FROST reshare completed successfully") - } else { - s.errCh <- fmt.Errorf("unexpected result type from reshare: %T", result) - } - } -} - -// ProcessInboundMessage handles incoming protocol messages -func (s *eddsaReshareSession) ProcessInboundMessage(msgBytes []byte) { - // Implementation similar to keygen session -} - -// ProcessOutboundMessage handles outgoing protocol messages -func (s *eddsaReshareSession) ProcessOutboundMessage() { - // Implementation similar to keygen session -} - -// GetPubKeyResult returns the public key after successful resharing -func (s *eddsaReshareSession) GetPubKeyResult() []byte { - return s.pubKeyResult -} - -// IsNewPeer returns true if this node is joining as a new peer -func (s *eddsaReshareSession) IsNewPeer() bool { - return s.isNewPeer -} - -// ErrChan returns the error channel -func (s *eddsaReshareSession) ErrChan() <-chan error { - return s.errCh -} - -// Stop stops the session -func (s *eddsaReshareSession) Stop() { - // Protocol doesn't have Close method - close(s.outCh) - close(s.errCh) -} - -// WaitForFinish waits for the session to complete -func (s *eddsaReshareSession) WaitForFinish() string { - return <-s.externalFinishChan -} - -// loadConfig loads the existing key configuration -func (s *eddsaReshareSession) loadConfig(walletID string) (protocol.KeyGenConfig, error) { - // Get key info - keyInfo, err := s.keyinfoStore.Get(walletID) - if err != nil { - return nil, err - } - - // Load the key share data - keyShareData, err := s.kvstore.Get(walletID) - if err != nil { - return nil, err - } - - // Create a config adapter for EdDSA - return &eddsaKeyGenConfigAdapter{ - keyInfo: keyInfo, - keyShareData: keyShareData, - walletID: walletID, - }, nil -} - -// saveConfig saves the new key configuration after resharing -func (s *eddsaReshareSession) saveConfig(config protocol.KeyGenConfig) error { - // Serialize the config - configData, err := config.Serialize() - if err != nil { - return fmt.Errorf("failed to serialize config: %w", err) - } - - // Save to kvstore - if err := s.kvstore.Put(s.walletID, configData); err != nil { - return fmt.Errorf("failed to save share data: %w", err) - } - - // Update key info - keyInfo := &keyinfo.KeyInfo{ - ParticipantPeerIDs: s.newNodeIDs, - Threshold: s.newThreshold, - Version: 1, - } - - if err := s.keyinfoStore.Save(s.walletID, keyInfo); err != nil { - return fmt.Errorf("failed to save key info: %w", err) - } - - return nil -} - -// eddsaKeyGenConfigAdapter adapts stored key data to protocol.KeyGenConfig interface for EdDSA -type eddsaKeyGenConfigAdapter struct { - keyInfo *keyinfo.KeyInfo - keyShareData []byte - walletID string -} - -func (a *eddsaKeyGenConfigAdapter) GetPartyID() string { - // Extract from the stored data - var data map[string]interface{} - if err := json.Unmarshal(a.keyShareData, &data); err != nil { - return "" - } - if id, ok := data["ID"].(string); ok { - return id - } - return "" -} - -func (a *eddsaKeyGenConfigAdapter) GetThreshold() int { - return a.keyInfo.Threshold -} - -func (a *eddsaKeyGenConfigAdapter) GetPublicKey() *ecdsa.PublicKey { - // EdDSA doesn't use ECDSA public keys - return nil -} - -// GetPublicKeyEd25519 returns the Ed25519 public key -func (a *eddsaKeyGenConfigAdapter) GetPublicKeyEd25519() ed25519.PublicKey { - // Extract from stored data - var data map[string]interface{} - if err := json.Unmarshal(a.keyShareData, &data); err != nil { - return nil - } - - // This is a simplified version - actual implementation would need proper parsing - return nil -} - -func (a *eddsaKeyGenConfigAdapter) GetShare() *big.Int { - // EdDSA shares are handled differently - return nil -} - -func (a *eddsaKeyGenConfigAdapter) GetSharePublicKey() *ecdsa.PublicKey { - // EdDSA doesn't use ECDSA public keys - return nil -} - -func (a *eddsaKeyGenConfigAdapter) GetPartyIDs() []string { - return a.keyInfo.ParticipantPeerIDs -} - -func (a *eddsaKeyGenConfigAdapter) Serialize() ([]byte, error) { - return a.keyShareData, nil -} diff --git a/pkg/mpc/taurus/keygen_session.go b/pkg/mpc/taurus/keygen_session.go deleted file mode 100644 index c87e1be..0000000 --- a/pkg/mpc/taurus/keygen_session.go +++ /dev/null @@ -1,328 +0,0 @@ -package taurus - -import ( - "encoding/json" - "fmt" - "sync" - - "github.com/fystack/mpcium/pkg/identity" - "github.com/fystack/mpcium/pkg/keyinfo" - "github.com/fystack/mpcium/pkg/kvstore" - "github.com/fystack/mpcium/pkg/messaging" - "github.com/fystack/mpcium/pkg/types" - "github.com/fystack/mpcium/pkg/utils" - "github.com/rs/zerolog" - "github.com/taurusgroup/multi-party-sig/pkg/math/curve" - "github.com/taurusgroup/multi-party-sig/pkg/party" - "github.com/taurusgroup/multi-party-sig/pkg/pool" - "github.com/taurusgroup/multi-party-sig/pkg/protocol" - "github.com/taurusgroup/multi-party-sig/protocols/cmp" - "github.com/taurusgroup/multi-party-sig/protocols/cmp/config" -) - -type KeyGenSession interface { - Session -} - -type cggmp21KeygenSession struct { - session - handler *protocol.MultiHandler - pool *pool.Pool - config *config.Config - messagesCh chan *protocol.Message - resultMutex sync.Mutex - done bool - resultErr error -} - -func NewCGGMP21KeygenSession( - walletID string, - pubSub messaging.PubSub, - selfPartyID party.ID, - partyIDs []party.ID, - threshold int, - kvstore kvstore.KVStore, - keyinfoStore keyinfo.Store, - resultQueue messaging.MessageQueue, - identityStore identity.Store, -) *cggmp21KeygenSession { - // Create thread pool - threadPool := pool.NewPool(0) // Use max threads - - return &cggmp21KeygenSession{ - session: session{ - walletID: walletID, - pubSub: pubSub, - selfPartyID: selfPartyID, - partyIDs: partyIDs, - subscriberList: []messaging.Subscription{}, - rounds: 5, // CGGMP21 keygen has 5 rounds - outCh: make(chan msg, 100), - errCh: make(chan error, 10), - finishCh: make(chan bool, 1), - externalFinishChan: make(chan string, 1), - threshold: threshold, - kvstore: kvstore, - keyinfoStore: keyinfoStore, - resultQueue: resultQueue, - logger: zerolog.New(utils.ZerologConsoleWriter()).With().Timestamp().Logger(), - processing: make(map[string]bool), - processingLock: sync.Mutex{}, - topicComposer: &TopicComposer{ - ComposeBroadcastTopic: func() string { - return fmt.Sprintf("keygen:broadcast:cggmp21:%s", walletID) - }, - ComposeDirectTopic: func(nodeID string) string { - return fmt.Sprintf("keygen:direct:cggmp21:%s:%s", nodeID, walletID) - }, - }, - identityStore: identityStore, - }, - pool: threadPool, - messagesCh: make(chan *protocol.Message, 100), - done: false, - } -} - -func (s *cggmp21KeygenSession) Init() { - s.logger.Info(). - Int("threshold", s.threshold). - Interface("partyIDs", s.partyIDs). - Msg("Initializing CGGMP21 keygen session") - - // Create CGGMP21 keygen protocol - startFunc := cmp.Keygen(curve.Secp256k1{}, s.selfPartyID, s.partyIDs, s.threshold, s.pool) - - // Create handler - handler, err := protocol.NewMultiHandler(startFunc, nil) - if err != nil { - s.logger.Fatal().Err(err).Msg("Failed to create keygen handler") - return - } - - s.handler = handler - - // Start message handling goroutine - go s.handleProtocolMessages() - - s.logger.Info(). - Str("partyID", string(s.selfPartyID)). - Interface("peerIDs", s.partyIDs). - Str("walletID", s.walletID). - Msg("[INITIALIZED] CGGMP21 keygen session initialized successfully") -} - -func (s *cggmp21KeygenSession) handleProtocolMessages() { - for { - select { - case protoMsg, ok := <-s.handler.Listen(): - if !ok { - // Protocol finished - s.resultMutex.Lock() - s.done = true - result, err := s.handler.Result() - if err != nil { - s.resultErr = err - s.errCh <- err - } else { - s.config = result.(*config.Config) - } - s.resultMutex.Unlock() - s.finishCh <- true - return - } - - // Convert protocol message to our message format - var toPartyIDs []party.ID - if !protoMsg.Broadcast && protoMsg.To != "" { - toPartyIDs = []party.ID{protoMsg.To} - } - outMsg := msg{ - FromPartyID: protoMsg.From, - ToPartyIDs: toPartyIDs, - IsBroadcast: protoMsg.Broadcast, - Data: protoMsg.Data, - } - - s.outCh <- outMsg - - case protoMsg := <-s.messagesCh: - // Handle incoming message - if !s.handler.CanAccept(protoMsg) { - s.logger.Warn().Msgf("Handler cannot accept message from %s", protoMsg.From) - continue - } - - s.handler.Accept(protoMsg) - } - } -} - -func (s *cggmp21KeygenSession) ProcessInboundMessage(msgBytes []byte) { - s.processingLock.Lock() - defer s.processingLock.Unlock() - - inboundMessage := &types.TaurusMessage{} - if err := json.Unmarshal(msgBytes, inboundMessage); err != nil { - s.logger.Error().Err(err).Msg("ProcessInboundMessage unmarshal error") - return - } - - msgHashStr := fmt.Sprintf("%x", utils.GetMessageHash(msgBytes)) - if s.processing[msgHashStr] { - return - } - s.processing[msgHashStr] = true - - // Convert to protocol message - protoMsg := &protocol.Message{ - From: party.ID(inboundMessage.SenderID), - To: party.ID(""), // Single recipient for protocol messages - Data: inboundMessage.Body, - Broadcast: inboundMessage.IsBroadcast, - } - - // Send to handler - s.messagesCh <- protoMsg -} - -func (s *cggmp21KeygenSession) ProcessOutboundMessage() { - s.logger.Info().Msgf("ProcessOutboundMessage started: %s", s.walletID) - for { - select { - case m := <-s.outCh: - // Convert party IDs back to strings - recipientIDs := make([]string, len(m.ToPartyIDs)) - for i, pid := range m.ToPartyIDs { - recipientIDs[i] = string(pid) - } - - msgWireBytes := &types.TaurusMessage{ - SessionID: s.walletID, - SenderID: string(m.FromPartyID), - RecipientIDs: recipientIDs, - Body: m.Data, - IsBroadcast: m.IsBroadcast, - } - - s.sendMsg(msgWireBytes) - - case err := <-s.errCh: - s.logger.Error().Err(err).Msg("Received error during ProcessOutboundMessage") - - case <-s.finishCh: - s.logger.Info().Msg("Received finish message during ProcessOutboundMessage") - s.publishResult() - return - } - } -} - -func (s *cggmp21KeygenSession) publishResult() { - s.resultMutex.Lock() - defer s.resultMutex.Unlock() - - if s.resultErr != nil { - // failureEvent := event.CreateKeygenFailure( - // s.walletID, - // map[string]any{ - // "error": s.resultErr.Error(), - // }, - // ) - // evtData, _ := json.Marshal(failureEvent) - // if err := s.resultQueue.Enqueue(fmt.Sprintf("mpc.keygen_result.%s", s.walletID), evtData, nil); err != nil { - // s.logger.Error().Err(err).Msg("failed to publish keygen failure event") - // } - return - } - - if s.config == nil { - s.logger.Error().Msg("No config available after keygen completion") - return - } - - // Save key share - shareBytes, err := json.Marshal(s.config) - if err != nil { - s.logger.Error().Err(err).Msg("Failed to marshal key share") - return - } - - if err := s.kvstore.Put(s.walletID, shareBytes); err != nil { - s.logger.Error().Err(err).Msgf("Failed to save key share for wallet %s", s.walletID) - return - } - - // Convert public key to hex - // Use the X coordinate as a simple representation - var pubKeyHex string - if s.config != nil && s.config.PublicPoint() != nil { - if xScalar := s.config.PublicPoint().XScalar(); xScalar != nil { - xBytes, _ := xScalar.MarshalBinary() - pubKeyHex = fmt.Sprintf("%x", xBytes) - } - } - - // Save key info - keyInfo := &keyinfo.KeyInfo{ - ParticipantPeerIDs: convertFromPartyIDs(s.partyIDs), - Threshold: s.threshold, - Version: 1, - } - - if err := s.keyinfoStore.Save(s.walletID, keyInfo); err != nil { - s.logger.Error().Err(err).Msgf("Failed to save key info for wallet %s", s.walletID) - return - } - - // Publish success event - // successEvent := event.CreateKeygenSuccess( - // s.walletID, - // pubKeyHex, - // map[string]any{ - // "threshold": s.threshold, - // "parties": len(s.partyIDs), - // "protocol": "CGGMP21", - // }, - // ) - - // evtData, _ := json.Marshal(successEvent) - // if err := s.resultQueue.Enqueue(fmt.Sprintf("mpc.keygen_result.%s", s.walletID), evtData, nil); err != nil { - // s.logger.Error().Err(err).Msg("failed to publish keygen success event") - // } - - s.logger.Info(). - Str("walletID", s.walletID). - Str("publicKey", pubKeyHex). - Msg("CGGMP21 keygen completed successfully") -} - -func (s *cggmp21KeygenSession) Stop() { - if s.pool != nil { - s.pool.TearDown() - } - close(s.outCh) - close(s.errCh) - close(s.messagesCh) -} - -func (s *cggmp21KeygenSession) WaitForFinish() string { - return <-s.externalFinishChan -} - -// Helper functions -func convertToPartyIDs(ids []string) []party.ID { - result := make([]party.ID, len(ids)) - for i, id := range ids { - result[i] = party.ID(id) - } - return result -} - -func convertFromPartyIDs(ids []party.ID) []string { - result := make([]string, len(ids)) - for i, id := range ids { - result[i] = string(id) - } - return result -} diff --git a/pkg/mpc/taurus/nats_transport.go b/pkg/mpc/taurus/nats_transport.go new file mode 100644 index 0000000..b14bae7 --- /dev/null +++ b/pkg/mpc/taurus/nats_transport.go @@ -0,0 +1,146 @@ +package taurus + +import ( + "encoding/json" + "fmt" + "sync" + + "github.com/nats-io/nats.go" + "github.com/taurusgroup/multi-party-sig/pkg/party" + + "github.com/fystack/mpcium/pkg/encoding" + "github.com/fystack/mpcium/pkg/logger" + "github.com/fystack/mpcium/pkg/messaging" +) + +type NATSTransport struct { + selfID string + wallet string + pubsub messaging.PubSub + + composeBroadcast func() string + composeDirect func(nodeID string) string + + inbox chan Msg + doneCh chan struct{} + errCh chan error + + mu sync.Mutex + subs []messaging.Subscription + closeMu sync.Once +} + +// NewNATSTransport creates a transport bound to a walletID and party. +func NewNATSTransport(walletID string, self party.ID, pubsub messaging.PubSub) *NATSTransport { + t := &NATSTransport{ + selfID: string(self), + wallet: walletID, + pubsub: pubsub, + inbox: make(chan Msg, 128), + doneCh: make(chan struct{}), + errCh: make(chan error, 8), + + composeBroadcast: func() string { + return fmt.Sprintf("mpc:broadcast:%s", walletID) + }, + composeDirect: func(nodeID string) string { + return fmt.Sprintf("mpc:direct:%s:%s", nodeID, walletID) + }, + } + + // subscribe broadcast + bcastTopic := t.composeBroadcast() + bcast, err := pubsub.Subscribe(bcastTopic, func(m *nats.Msg) { + t.handleRaw(m.Data) + }) + if err == nil { + t.subs = append(t.subs, bcast) + } else { + t.pushErr(err) + } + + // subscribe direct + directTopic := t.composeDirect(t.selfID) + direct, err := pubsub.Subscribe(directTopic, func(m *nats.Msg) { + t.handleRaw(m.Data) + }) + if err == nil { + t.subs = append(t.subs, direct) + } else { + t.pushErr(err) + } + + logger.Info("✅ NATSTransport listening", + "wallet", walletID, + "broadcast", bcastTopic, + "direct", directTopic) + + return t +} + +// --- Transport interface --- + +func (t *NATSTransport) Send(to string, msg Msg) error { + // Marshal the message + data, err := encoding.StructToJsonBytes(&msg) + if err != nil { + return err + } + + if msg.IsBroadcast { + // publish to broadcast topic + topic := t.composeBroadcast() + return t.pubsub.Publish(topic, data) + } + + // unicast to "to" + topic := t.composeDirect(to) + return t.pubsub.Publish(topic, data) +} + +func (t *NATSTransport) Inbox() <-chan Msg { return t.inbox } +func (t *NATSTransport) Done() <-chan struct{} { return t.doneCh } + +func (t *NATSTransport) Close() error { + t.closeMu.Do(func() { + t.mu.Lock() + defer t.mu.Unlock() + for i, sub := range t.subs { + if sub != nil { + _ = sub.Unsubscribe() + logger.Debug("✅ unsubscribed", "index", i, "wallet", t.wallet) + } + } + close(t.inbox) + close(t.errCh) + close(t.doneCh) + logger.Info("🛑 NATSTransport closed", "wallet", t.wallet) + }) + return nil +} + +// --- Internal helpers --- + +func (t *NATSTransport) handleRaw(data []byte) { + var m Msg + if err := json.Unmarshal(data, &m); err != nil { + t.pushErr(fmt.Errorf("unmarshal inbound: %w", err)) + return + } + if m.From == t.selfID { + return // skip self + } + select { + case t.inbox <- m: + default: + logger.Warn("⚠️ dropping inbound message, inbox full", "wallet", t.wallet) + } +} + +func (t *NATSTransport) pushErr(err error) { + select { + case t.errCh <- err: + default: + logger.Warn("⚠️ dropping error (buffer full)", "wallet", t.wallet) + } +} diff --git a/pkg/mpc/taurus/node.go b/pkg/mpc/taurus/node.go deleted file mode 100644 index c64fe15..0000000 --- a/pkg/mpc/taurus/node.go +++ /dev/null @@ -1,242 +0,0 @@ -package taurus - -import ( - "fmt" - "time" - - "github.com/fystack/mpcium/pkg/identity" - "github.com/fystack/mpcium/pkg/keyinfo" - "github.com/fystack/mpcium/pkg/kvstore" - "github.com/fystack/mpcium/pkg/logger" - "github.com/fystack/mpcium/pkg/messaging" - "github.com/taurusgroup/multi-party-sig/pkg/party" -) - -const ( - PurposeKeygen string = "keygen" - PurposeSign string = "sign" - PurposeReshare string = "reshare" - - DefaultVersion int = 1 -) - -type ID string - -type Node struct { - nodeID string - peerIDs []string - pubSub messaging.PubSub - kvstore kvstore.KVStore - keyinfoStore keyinfo.Store - identityStore identity.Store - peerRegistry PeerRegistry -} - -func ComposeReadyKey(nodeID string) string { - return fmt.Sprintf("ready/%s", nodeID) -} - -func NewNode( - nodeID string, - peerIDs []string, - pubSub messaging.PubSub, - kvstore kvstore.KVStore, - keyinfoStore keyinfo.Store, - peerRegistry PeerRegistry, - identityStore identity.Store, -) *Node { - start := time.Now() - elapsed := time.Since(start) - logger.Info("Starting new CGGMP21 node", "nodeID", nodeID, "elapsed", elapsed.Milliseconds()) - - node := &Node{ - nodeID: nodeID, - peerIDs: peerIDs, - pubSub: pubSub, - kvstore: kvstore, - keyinfoStore: keyinfoStore, - peerRegistry: peerRegistry, - identityStore: identityStore, - } - - go peerRegistry.WatchPeersReady() - return node -} - -func (p *Node) ID() string { - return p.nodeID -} - -func (p *Node) KeyInfoStore() keyinfo.Store { - return p.keyinfoStore -} - -// func (p *Node) CreateKeyGenSession( -// walletID string, -// threshold int, -// resultQueue messaging.MessageQueue, -// ) (KeyGenSession, error) { -// if !p.peerRegistry.ArePeersReady() { -// return nil, fmt.Errorf( -// "peers are not ready yet. ready: %d, expected: %d", -// p.peerRegistry.GetReadyPeersCount(), -// len(p.peerIDs)+1, -// ) -// } - -// readyPeerIDs := p.peerRegistry.GetReadyPeersIncludeSelf() -// selfPartyID, allPartyIDs := p.generatePartyIDs(PurposeKeygen, readyPeerIDs, DefaultVersion) - -// session := newCGGMP21KeygenSession( -// walletID, -// p.pubSub, -// selfPartyID, -// allPartyIDs, -// threshold, -// p.kvstore, -// p.keyinfoStore, -// resultQueue, -// p.identityStore, -// ) - -// session.Init() -// return session, nil -// } - -func (p *Node) CreateSignSession( - sessionID string, - walletID string, - messageHash []byte, - signerPeerIDs []string, - resultQueue messaging.MessageQueue, - useBroadcast bool, -) (SignSession, error) { - // Check if we have enough signers - keyInfo, err := p.keyinfoStore.Get(walletID) - if err != nil { - return nil, fmt.Errorf("failed to get key info: %w", err) - } - - if len(signerPeerIDs) < keyInfo.Threshold+1 { - return nil, ErrNotEnoughParticipants - } - - // Check if this node is in the signer list - if !contains(signerPeerIDs, p.nodeID) { - return nil, ErrNotInParticipantList - } - - // Generate party IDs for signers - version := p.getVersion(SessionTypeCGGMP21, walletID) - selfPartyID, signerPartyIDs := p.generatePartyIDs(PurposeSign, signerPeerIDs, version) - - session, err := newCGGMP21SigningSession( - sessionID, - walletID, - messageHash, - p.pubSub, - selfPartyID, - signerPartyIDs, - p.kvstore, - p.keyinfoStore, - resultQueue, - p.identityStore, - useBroadcast, - ) - if err != nil { - return nil, err - } - - session.Init() - return session, nil -} - -func (p *Node) generatePartyIDs(purpose string, peerIDs []string, version int) (party.ID, []party.ID) { - partyIDs := make([]party.ID, len(peerIDs)) - var selfPartyID party.ID - - for i, peerID := range peerIDs { - partyID := createPartyID(peerID, purpose, version) - partyIDs[i] = partyID - if peerID == p.nodeID { - selfPartyID = partyID - } - } - - return selfPartyID, partyIDs -} - -func createPartyID(sessionID string, keyType string, version int) party.ID { - if version == 0 { - // Backward compatible version - just use sessionID - return party.ID(sessionID) - } - // Include version in party ID - return party.ID(fmt.Sprintf("%s:%s:%d", sessionID, keyType, version)) -} - -func (p *Node) getVersion(sessionType SessionType, walletID string) int { - // In production, you might want to store and retrieve version info - // For now, always use the default version - return DefaultVersion -} - -func (p *Node) CreateReshareSession( - sessionType SessionType, - walletID string, - threshold int, - newThreshold int, - newNodeIDs []string, - isNewPeer bool, - resultQueue messaging.MessageQueue, -) (ReshareSession, error) { - logger.Info("Creating reshare session", - "sessionType", sessionType, - "walletID", walletID, - "threshold", threshold, - "newThreshold", newThreshold, - "newNodeIDs", newNodeIDs, - "isNewPeer", isNewPeer, - "nodeID", p.nodeID, - ) - - switch sessionType { - case SessionTypeECDSA: - return newCGGMP21ReshareSession( - walletID, - threshold, - newThreshold, - newNodeIDs, - isNewPeer, - p.pubSub, - p.kvstore, - p.keyinfoStore, - resultQueue, - p.nodeID, - ) - case SessionTypeEDDSA: - return newEdDSAReshareSession( - walletID, - threshold, - newThreshold, - newNodeIDs, - isNewPeer, - p.pubSub, - p.kvstore, - p.keyinfoStore, - resultQueue, - p.nodeID, - ) - default: - return nil, fmt.Errorf("unsupported session type for reshare: %v", sessionType) - } -} - -func contains(slice []string, item string) bool { - for _, s := range slice { - if s == item { - return true - } - } - return false -} diff --git a/pkg/mpc/taurus/node_test.go b/pkg/mpc/taurus/node_test.go deleted file mode 100644 index 805b3b2..0000000 --- a/pkg/mpc/taurus/node_test.go +++ /dev/null @@ -1,119 +0,0 @@ -package taurus - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestPartyIDToNodeID(t *testing.T) { - partyID := createPartyID("4d8cb873-dc86-4776-b6f6-cf5c668f6468", "keygen", 1) - nodeID := PartyIDToRoutingDest(partyID) - assert.Equal(t, "4d8cb873-dc86-4776-b6f6-cf5c668f6468:keygen:1", nodeID, "NodeID should be equal") -} - -func TestCreatePartyID_Structure(t *testing.T) { - sessionID := "test-session-123" - keyType := "keygen" - version := 5 - - partyID := createPartyID(sessionID, keyType, version) - - assert.NotNil(t, partyID) - // The party ID should be in the format sessionID:keyType:version - expectedID := "test-session-123:keygen:5" - assert.Equal(t, expectedID, string(partyID)) -} - -func TestCreatePartyID_DifferentVersions(t *testing.T) { - sessionID := "test-session-456" - keyType := "keygen" - - // Test version 0 (backward compatible) - partyID0 := createPartyID(sessionID, keyType, 0) - assert.NotNil(t, partyID0) - // Version 0 should just be the sessionID - assert.Equal(t, sessionID, string(partyID0)) - - // Test version 1 (default) - partyID1 := createPartyID(sessionID, keyType, DefaultVersion) - assert.NotNil(t, partyID1) - // Version 1 should include version info - expectedID1 := "test-session-456:keygen:1" - assert.Equal(t, expectedID1, string(partyID1)) - - // Different versions should produce different party IDs - assert.NotEqual(t, partyID0, partyID1) -} - -func TestPartyIDToRoutingDest_BackwardCompatible(t *testing.T) { - sessionID := "test-session-789" - keyType := "signing" - - partyID := createPartyID(sessionID, keyType, 0) - nodeID := PartyIDToRoutingDest(partyID) - - // For backward compatible version, should just be the sessionID - assert.Equal(t, sessionID, nodeID) -} - -func TestPartyIDToRoutingDest_DefaultVersion(t *testing.T) { - sessionID := "test-session-999" - keyType := "signing" - - partyID := createPartyID(sessionID, keyType, DefaultVersion) - nodeID := PartyIDToRoutingDest(partyID) - - // For default version, should be the full party ID string - expected := "test-session-999:signing:1" - assert.Equal(t, expected, nodeID) -} - -func TestCreatePartyID_EmptyValues(t *testing.T) { - // Test with empty session ID - partyID := createPartyID("", "keygen", 0) - assert.NotNil(t, partyID) - // Version 0 should just return empty string - assert.Equal(t, "", string(partyID)) - - // Test with empty key type - partyID = createPartyID("session", "", 1) - assert.NotNil(t, partyID) - // Should still create the party ID with format - expectedID := "session::1" - assert.Equal(t, expectedID, string(partyID)) -} - -func TestPartyIDToRoutingDest_Consistency(t *testing.T) { - sessionID := "consistent-session" - keyType := "keygen" - version := 3 - - // Create the same party ID multiple times - partyID1 := createPartyID(sessionID, keyType, version) - partyID2 := createPartyID(sessionID, keyType, version) - - nodeID1 := PartyIDToRoutingDest(partyID1) - nodeID2 := PartyIDToRoutingDest(partyID2) - - // Should produce consistent results based on sessionID and version - assert.Equal(t, nodeID1, nodeID2, "Same parameters should produce same routing destinations") -} - -func TestCreatePartyID_SameParameters(t *testing.T) { - sessionID := "test-session" - keyType := "keygen" - version := 1 - - // Create multiple party IDs with same parameters - partyID1 := createPartyID(sessionID, keyType, version) - partyID2 := createPartyID(sessionID, keyType, version) - - // Party IDs with same parameters should be identical in the new implementation - assert.Equal(t, partyID1, partyID2, "Party IDs with same parameters should be equal") - - // Both should have the same format - expectedID := "test-session:keygen:1" - assert.Equal(t, expectedID, string(partyID1)) - assert.Equal(t, expectedID, string(partyID2)) -} diff --git a/pkg/mpc/taurus/registry.go b/pkg/mpc/taurus/registry.go deleted file mode 100644 index cfe334d..0000000 --- a/pkg/mpc/taurus/registry.go +++ /dev/null @@ -1,213 +0,0 @@ -package taurus - -import ( - "fmt" - "sync" - "sync/atomic" - "time" - - "github.com/fystack/mpcium/pkg/infra" - "github.com/fystack/mpcium/pkg/logger" - "github.com/hashicorp/consul/api" - "github.com/samber/lo" -) - -const ( - ReadinessCheckPeriod = 1 * time.Second -) - -type PeerRegistry interface { - Ready() error - ArePeersReady() bool - WatchPeersReady() - // Resign is called by the node when it is going to shutdown - Resign() error - GetReadyPeersCount() int64 - GetReadyPeersIncludeSelf() []string // get ready peers include self - GetTotalPeersCount() int64 -} - -type registry struct { - nodeID string - peerNodeIDs []string - readyMap map[string]bool - readyCount int64 - mu sync.RWMutex - ready bool // ready is true when all peers are ready - - consulKV infra.ConsulKV -} - -func NewRegistry( - nodeID string, - peerNodeIDs []string, - consulKV infra.ConsulKV, -) *registry { - return ®istry{ - consulKV: consulKV, - nodeID: nodeID, - peerNodeIDs: getPeerIDsExceptSelf(nodeID, peerNodeIDs), - readyMap: make(map[string]bool), - readyCount: 1, // self - } -} - -func getPeerIDsExceptSelf(nodeID string, peerNodeIDs []string) []string { - peerIDs := make([]string, 0, len(peerNodeIDs)) - for _, peerID := range peerNodeIDs { - if peerID != nodeID { - peerIDs = append(peerIDs, peerID) - } - } - return peerIDs -} - -func (r *registry) readyKey(nodeID string) string { - return fmt.Sprintf("ready/%s", nodeID) -} - -func (r *registry) registerReadyPairs(peerIDs []string) { - for _, peerID := range peerIDs { - ready, exist := r.readyMap[peerID] - if !exist { - atomic.AddInt64(&r.readyCount, 1) - logger.Info("Register", "peerID", peerID) - } else if !ready { - atomic.AddInt64(&r.readyCount, 1) - logger.Info("Reconnecting...", "peerID", peerID) - } - - r.readyMap[peerID] = true - } - - if len(peerIDs) == len(r.peerNodeIDs) && !r.ready { - r.mu.Lock() - r.ready = true - r.mu.Unlock() - logger.Info("ALL PEERS ARE READY! Starting to accept MPC requests") - } - -} - -// Ready is called by the node when it complete generate preparams and starting to accept -// incoming requests -func (r *registry) Ready() error { - k := r.readyKey(r.nodeID) - - kv := &api.KVPair{ - Key: k, - Value: []byte("true"), - } - - _, err := r.consulKV.Put(kv, nil) - if err != nil { - return fmt.Errorf("Put ready key failed: %w", err) - } - - return nil -} - -func (r *registry) WatchPeersReady() { - ticker := time.NewTicker(ReadinessCheckPeriod) - go r.logReadyStatus() - // first tick is executed immediately - for ; true; <-ticker.C { - pairs, _, err := r.consulKV.List("ready/", nil) - if err != nil { - logger.Error("List ready keys failed", err) - } - - newReadyPeerIDs := r.getReadyPeersFromKVStore(pairs) - if len(newReadyPeerIDs) != len(r.peerNodeIDs) { - r.mu.Lock() - r.ready = false - r.mu.Unlock() - - var readyPeerIDs []string - for peerID, isReady := range r.readyMap { - if isReady { - readyPeerIDs = append(readyPeerIDs, peerID) - } - } - - disconnecteds, _ := lo.Difference(readyPeerIDs, newReadyPeerIDs) - if len(disconnecteds) > 0 { - for _, peerID := range disconnecteds { - logger.Warn("Peer disconnected!", "peerID", peerID) - r.readyMap[peerID] = false - atomic.AddInt64(&r.readyCount, -1) - } - - } - - } - r.registerReadyPairs(newReadyPeerIDs) - } - -} - -func (r *registry) logReadyStatus() { - for { - time.Sleep(5 * time.Second) - if !r.ArePeersReady() { - logger.Info("Peers are not ready yet", "ready", r.GetReadyPeersCount(), "expected", len(r.peerNodeIDs)+1) - } - } -} - -func (r *registry) GetReadyPeersCount() int64 { - return atomic.LoadInt64(&r.readyCount) -} - -func (r *registry) GetReadyPeersIncludeSelf() []string { - var peerIDs []string - for peerID, isReady := range r.readyMap { - if isReady { - peerIDs = append(peerIDs, peerID) - } - } - - peerIDs = append(peerIDs, r.nodeID) // append self - return peerIDs -} - -func (r *registry) getReadyPeersFromKVStore(kvPairs api.KVPairs) []string { - var peers []string - for _, k := range kvPairs { - var peerNodeID string - _, err := fmt.Sscanf(k.Key, "ready/%s", &peerNodeID) - if err != nil { - logger.Error("Parse ready key failed", err) - } - if peerNodeID == r.nodeID { - continue - } - - peers = append(peers, peerNodeID) - } - - return peers -} - -func (r *registry) ArePeersReady() bool { - r.mu.RLock() - defer r.mu.RUnlock() - - return r.ready -} - -func (r *registry) GetTotalPeersCount() int64 { - var self int64 = 1 - return int64(len(r.peerNodeIDs)) + self -} - -func (r *registry) Resign() error { - k := r.readyKey(r.nodeID) - - _, err := r.consulKV.Delete(k, nil) - if err != nil { - return fmt.Errorf("Delete ready key failed: %w", err) - } - - return nil -} diff --git a/pkg/mpc/taurus/reshare_session.go b/pkg/mpc/taurus/reshare_session.go deleted file mode 100644 index 2d3d9cf..0000000 --- a/pkg/mpc/taurus/reshare_session.go +++ /dev/null @@ -1,32 +0,0 @@ -package taurus - -// ReshareSession represents a threshold signature resharing session -type ReshareSession interface { - Session - - // Reshare starts the resharing protocol - Reshare(done func()) - - // GetPubKeyResult returns the public key after successful resharing - GetPubKeyResult() []byte - - // IsNewPeer returns true if this node is joining as a new peer - IsNewPeer() bool -} - -// BaseReshareSession provides common functionality for reshare sessions -type BaseReshareSession struct { - session - isNewPeer bool - pubKeyResult []byte -} - -// IsNewPeer returns true if this node is joining as a new peer -func (s *BaseReshareSession) IsNewPeer() bool { - return s.isNewPeer -} - -// GetPubKeyResult returns the public key after successful resharing -func (s *BaseReshareSession) GetPubKeyResult() []byte { - return s.pubKeyResult -} diff --git a/pkg/mpc/taurus/session.go b/pkg/mpc/taurus/session.go deleted file mode 100644 index bfa5921..0000000 --- a/pkg/mpc/taurus/session.go +++ /dev/null @@ -1,184 +0,0 @@ -package taurus - -import ( - "sync" - - "github.com/nats-io/nats.go" - "github.com/rs/zerolog" - "github.com/taurusgroup/multi-party-sig/pkg/party" - - "github.com/fystack/mpcium/pkg/common/errors" - "github.com/fystack/mpcium/pkg/encoding" - "github.com/fystack/mpcium/pkg/identity" - "github.com/fystack/mpcium/pkg/keyinfo" - "github.com/fystack/mpcium/pkg/kvstore" - "github.com/fystack/mpcium/pkg/messaging" - "github.com/fystack/mpcium/pkg/types" -) - -type SessionType string - -const ( - TypeGenerateWalletResultFmt = "mpc.mpc_keygen_result.%s" - TypeReshareWalletResultFmt = "mpc.mpc_reshare_result.%s" - - SessionTypeCGGMP21 SessionType = "session_cggmp21" - SessionTypeECDSA SessionType = "ecdsa" - SessionTypeEDDSA SessionType = "eddsa" -) - -var ( - ErrNotEnoughParticipants = errors.New("Not enough participants to sign") - ErrNotInParticipantList = errors.New("Node is not in the participant list") -) - -type TopicComposer struct { - ComposeBroadcastTopic func() string - ComposeDirectTopic func(nodeID string) string -} - -type KeyComposerFn func(id string) string - -type Session interface { - ListenToIncomingMessageAsync(f func(msgBytes []byte)) - ErrChan() <-chan error - Init() - ProcessInboundMessage(msgBytes []byte) - ProcessOutboundMessage() - WaitForFinish() string -} - -type session struct { - walletID string - sessionID string - pubSub messaging.PubSub - selfPartyID party.ID - partyIDs []party.ID - subscriberList []messaging.Subscription - rounds int - outCh chan msg - errCh chan error - finishCh chan bool - externalFinishChan chan string - threshold int - kvstore kvstore.KVStore - keyinfoStore keyinfo.Store - resultQueue messaging.MessageQueue - logger zerolog.Logger - processing map[string]bool - processingLock sync.Mutex - topicComposer *TopicComposer - identityStore identity.Store -} - -type msg struct { - FromPartyID party.ID - ToPartyIDs []party.ID - IsBroadcast bool - Data []byte -} - -func (s *session) ProcessInboundMessage(msgBytes []byte) { - // This should be implemented by specific session types - // If this method is called directly on the base session type, it means - // the concrete type doesn't properly implement ProcessInboundMessage - panic("ProcessInboundMessage must be implemented by session type") -} - -func (s *session) ListenToIncomingMessageAsync(f func(msgBytes []byte)) { - // Subscribe to broadcast messages - broadcastTopic := s.topicComposer.ComposeBroadcastTopic() - broadcastSub, err := s.pubSub.Subscribe(broadcastTopic, func(m *nats.Msg) { - s.logger.Debug(). - Str("topic", broadcastTopic). - Int("size", len(m.Data)). - Msg("Received broadcast message") - f(m.Data) - }) - - if err != nil { - s.logger.Error().Err(err).Msgf("Failed to subscribe to broadcast topic %s", broadcastTopic) - s.errCh <- err - return - } - - s.subscriberList = append(s.subscriberList, broadcastSub) - - // Subscribe to direct messages - directTopic := s.topicComposer.ComposeDirectTopic(string(s.selfPartyID)) - directSub, err := s.pubSub.Subscribe(directTopic, func(m *nats.Msg) { - s.logger.Debug(). - Str("topic", directTopic). - Int("size", len(m.Data)). - Msg("Received direct message") - f(m.Data) - }) - - if err != nil { - s.logger.Error().Err(err).Msgf("Failed to subscribe to direct topic %s", directTopic) - s.errCh <- err - return - } - - s.subscriberList = append(s.subscriberList, directSub) - - s.logger.Info(). - Str("broadcast", broadcastTopic). - Str("direct", directTopic). - Msg("Listening to incoming messages") -} - -func (s *session) sendMsg(message *types.TaurusMessage) { - data, err := encoding.StructToJsonBytes(message) - if err != nil { - s.logger.Error().Err(err).Msg("Failed to marshal message") - return - } - - if message.IsBroadcast { - topic := s.topicComposer.ComposeBroadcastTopic() - if err := s.pubSub.Publish(topic, data); err != nil { - s.logger.Error().Err(err).Msgf("Failed to publish broadcast message to %s", topic) - } else { - s.logger.Debug().Str("topic", topic).Msg("Published broadcast message") - } - } else { - // Send to specific recipients - for _, recipient := range message.RecipientIDs { - topic := s.topicComposer.ComposeDirectTopic(recipient) - if err := s.pubSub.Publish(topic, data); err != nil { - s.logger.Error().Err(err).Msgf("Failed to publish direct message to %s", topic) - } else { - s.logger.Debug(). - Str("topic", topic). - Str("recipient", recipient). - Msg("Published direct message") - } - } - } -} - -func (s *session) ErrChan() <-chan error { - return s.errCh -} - -func (s *session) unsubscribe() { - for _, sub := range s.subscriberList { - if err := sub.Unsubscribe(); err != nil { - s.logger.Error().Err(err).Msg("Failed to unsubscribe") - } - } - s.subscriberList = nil -} - -func (s *session) Stop() { - s.unsubscribe() -} - -// Helper function to get party routing destination -func PartyIDToRoutingDest(partyID party.ID) string { - // Extract node ID from party ID if it contains version info - nodeID := string(partyID) - // Simple extraction - in production you'd have more robust parsing - return nodeID -} diff --git a/pkg/mpc/taurus/signing_session.go b/pkg/mpc/taurus/signing_session.go deleted file mode 100644 index 8b12ec3..0000000 --- a/pkg/mpc/taurus/signing_session.go +++ /dev/null @@ -1,332 +0,0 @@ -package taurus - -import ( - "encoding/hex" - "encoding/json" - "fmt" - "sync" - - "github.com/fystack/mpcium/pkg/encoding" - "github.com/fystack/mpcium/pkg/identity" - "github.com/fystack/mpcium/pkg/keyinfo" - "github.com/fystack/mpcium/pkg/kvstore" - "github.com/fystack/mpcium/pkg/messaging" - "github.com/fystack/mpcium/pkg/types" - "github.com/fystack/mpcium/pkg/utils" - "github.com/rs/zerolog" - "github.com/taurusgroup/multi-party-sig/pkg/ecdsa" - "github.com/taurusgroup/multi-party-sig/pkg/party" - "github.com/taurusgroup/multi-party-sig/pkg/pool" - "github.com/taurusgroup/multi-party-sig/pkg/protocol" - "github.com/taurusgroup/multi-party-sig/protocols/cmp" - "github.com/taurusgroup/multi-party-sig/protocols/cmp/config" -) - -type SignSession interface { - Session -} - -type cggmp21SigningSession struct { - session - handler *protocol.MultiHandler - pool *pool.Pool - config *config.Config - signature *ecdsa.Signature - messagesCh chan *protocol.Message - resultMutex sync.Mutex - done bool - resultErr error - messageHash []byte - signerIDs []party.ID - useBroadcast bool -} - -func newCGGMP21SigningSession( - sessionID string, - walletID string, - messageHash []byte, - pubSub messaging.PubSub, - selfPartyID party.ID, - signerIDs []party.ID, - kvstore kvstore.KVStore, - keyinfoStore keyinfo.Store, - resultQueue messaging.MessageQueue, - identityStore identity.Store, - useBroadcast bool, -) (*cggmp21SigningSession, error) { - // Load config from kvstore - shareBytes, err := kvstore.Get(walletID) - if err != nil { - return nil, fmt.Errorf("failed to get key share: %w", err) - } - - config := &config.Config{} - if err := json.Unmarshal(shareBytes, config); err != nil { - return nil, fmt.Errorf("failed to unmarshal key share: %w", err) - } - - // Create thread pool - threadPool := pool.NewPool(0) // Use max threads - - return &cggmp21SigningSession{ - session: session{ - walletID: walletID, - sessionID: sessionID, - pubSub: pubSub, - selfPartyID: selfPartyID, - partyIDs: signerIDs, - subscriberList: []messaging.Subscription{}, - rounds: 5, // CGGMP21 signing has 5 rounds - outCh: make(chan msg, 100), - errCh: make(chan error, 10), - finishCh: make(chan bool, 1), - externalFinishChan: make(chan string, 1), - threshold: config.Threshold, - kvstore: kvstore, - keyinfoStore: keyinfoStore, - resultQueue: resultQueue, - logger: zerolog.New(utils.ZerologConsoleWriter()).With().Timestamp().Logger(), - processing: make(map[string]bool), - processingLock: sync.Mutex{}, - topicComposer: &TopicComposer{ - ComposeBroadcastTopic: func() string { - return fmt.Sprintf("sign:broadcast:cggmp21:%s", sessionID) - }, - ComposeDirectTopic: func(nodeID string) string { - return fmt.Sprintf("sign:direct:cggmp21:%s:%s", nodeID, sessionID) - }, - }, - identityStore: identityStore, - }, - pool: threadPool, - config: config, - messagesCh: make(chan *protocol.Message, 100), - messageHash: messageHash, - signerIDs: signerIDs, - useBroadcast: useBroadcast, - done: false, - }, nil -} - -func (s *cggmp21SigningSession) Init() { - s.logger.Info(). - Str("sessionID", s.sessionID). - Str("walletID", s.walletID). - Hex("messageHash", s.messageHash). - Interface("signerIDs", s.signerIDs). - Bool("useBroadcast", s.useBroadcast). - Msg("Initializing CGGMP21 signing session") - - // Create CGGMP21 signing protocol - startFunc := cmp.Sign(s.config, s.signerIDs, s.messageHash, s.pool) - - // Create handler - handler, err := protocol.NewMultiHandler(startFunc, nil) - if err != nil { - s.logger.Fatal().Err(err).Msg("Failed to create signing handler") - return - } - - s.handler = handler - - // Start message handling goroutine - go s.handleProtocolMessages() - - s.logger.Info(). - Str("sessionID", s.sessionID). - Str("partyID", string(s.selfPartyID)). - Interface("signerIDs", s.signerIDs). - Msg("[INITIALIZED] CGGMP21 signing session initialized successfully") -} - -func (s *cggmp21SigningSession) handleProtocolMessages() { - for { - select { - case protoMsg, ok := <-s.handler.Listen(): - if !ok { - // Protocol finished - s.resultMutex.Lock() - s.done = true - result, err := s.handler.Result() - if err != nil { - s.resultErr = err - s.errCh <- err - } else { - s.signature = result.(*ecdsa.Signature) - } - s.resultMutex.Unlock() - s.finishCh <- true - return - } - - // Convert protocol message to our message format - var toPartyIDs []party.ID - if !protoMsg.Broadcast && protoMsg.To != "" { - toPartyIDs = []party.ID{protoMsg.To} - } - outMsg := msg{ - FromPartyID: protoMsg.From, - ToPartyIDs: toPartyIDs, - IsBroadcast: protoMsg.Broadcast, - Data: protoMsg.Data, - } - - s.outCh <- outMsg - - case protoMsg := <-s.messagesCh: - // Handle incoming message - if !s.handler.CanAccept(protoMsg) { - s.logger.Warn().Msgf("Handler cannot accept message from %s", protoMsg.From) - continue - } - - s.handler.Accept(protoMsg) - } - } -} - -func (s *cggmp21SigningSession) ProcessInboundMessage(msgBytes []byte) { - s.processingLock.Lock() - defer s.processingLock.Unlock() - - inboundMessage := &types.TaurusMessage{} - if err := encoding.JsonBytesToStruct(msgBytes, inboundMessage); err != nil { - s.logger.Error().Err(err).Msg("ProcessInboundMessage unmarshal error") - return - } - - msgHashStr := fmt.Sprintf("%x", utils.GetMessageHash(msgBytes)) - if s.processing[msgHashStr] { - return - } - s.processing[msgHashStr] = true - - // Convert to protocol message - protoMsg := &protocol.Message{ - From: party.ID(inboundMessage.SenderID), - To: party.ID(""), // Single recipient for protocol messages - Data: inboundMessage.Body, - Broadcast: inboundMessage.IsBroadcast, - } - - // Send to handler - s.messagesCh <- protoMsg -} - -func (s *cggmp21SigningSession) ProcessOutboundMessage() { - s.logger.Info().Msgf("ProcessOutboundMessage started: %s", s.sessionID) - for { - select { - case m := <-s.outCh: - // Convert party IDs back to strings - recipientIDs := make([]string, len(m.ToPartyIDs)) - for i, pid := range m.ToPartyIDs { - recipientIDs[i] = string(pid) - } - - msgWireBytes := &types.TaurusMessage{ - SessionID: s.sessionID, - SenderID: string(m.FromPartyID), - RecipientIDs: recipientIDs, - Body: m.Data, - IsBroadcast: m.IsBroadcast, - } - - s.sendMsg(msgWireBytes) - - case err := <-s.errCh: - s.logger.Error().Err(err).Msg("Received error during ProcessOutboundMessage") - - case <-s.finishCh: - s.logger.Info().Msg("Received finish message during ProcessOutboundMessage") - s.publishResult() - return - } - } -} - -func (s *cggmp21SigningSession) publishResult() { - s.resultMutex.Lock() - defer s.resultMutex.Unlock() - - if s.resultErr != nil { - // failureEvent := event.CreateSignFailure( - // s.sessionID, - // s.walletID, - // map[string]any{ - // "error": s.resultErr.Error(), - // }, - // ) - // evtData, _ := encoding.StructToJsonBytes(failureEvent) - // if err := s.resultQueue.Enqueue(fmt.Sprintf("%s.%s", event.SigningResultTopic, s.walletID), evtData, nil); err != nil { - // s.logger.Error().Err(err).Msg("failed to publish sign failure event") - // } - return - } - - if s.signature == nil { - s.logger.Error().Msg("No signature available after signing completion") - return - } - - // Verify signature - if !s.signature.Verify(s.config.PublicPoint(), s.messageHash) { - s.logger.Error().Msg("Failed to verify signature") - // failureEvent := event.CreateSignFailure( - // s.sessionID, - // s.walletID, - // map[string]any{ - // "error": "signature verification failed", - // }, - // ) - // evtData, _ := encoding.StructToJsonBytes(failureEvent) - // if err := s.resultQueue.Enqueue(fmt.Sprintf("%s.%s", event.SigningResultTopic, s.walletID), evtData, nil); err != nil { - // s.logger.Error().Err(err).Msg("failed to publish sign failure event") - // } - return - } - - // Convert signature to hex - sigRBytes, _ := s.signature.R.MarshalBinary() - sigSBytes, _ := s.signature.S.MarshalBinary() - sigR := hex.EncodeToString(sigRBytes) - sigS := hex.EncodeToString(sigSBytes) - - // Publish success event - // successEvent := event.CreateSignSuccess( - // s.sessionID, - // s.walletID, - // sigR, - // sigS, - // map[string]any{ - // "messageHash": hex.EncodeToString(s.messageHash), - // "signers": len(s.signerIDs), - // "protocol": "CGGMP21", - // }, - // ) - - // evtData, _ := encoding.StructToJsonBytes(successEvent) - // if err := s.resultQueue.Enqueue(fmt.Sprintf("%s.%s", event.SigningResultTopic, s.walletID), evtData, nil); err != nil { - // s.logger.Error().Err(err).Msg("failed to publish sign success event") - // } - - s.logger.Info(). - Str("sessionID", s.sessionID). - Str("walletID", s.walletID). - Str("sigR", sigR). - Str("sigS", sigS). - Msg("CGGMP21 signing completed successfully") -} - -func (s *cggmp21SigningSession) Stop() { - if s.pool != nil { - s.pool.TearDown() - } - close(s.outCh) - close(s.errCh) - close(s.messagesCh) -} - -func (s *cggmp21SigningSession) WaitForFinish() string { - return <-s.externalFinishChan -} diff --git a/pkg/mpc/taurus/transport.go b/pkg/mpc/taurus/transport.go new file mode 100644 index 0000000..7b1b496 --- /dev/null +++ b/pkg/mpc/taurus/transport.go @@ -0,0 +1,83 @@ +package taurus + +import "sync" + +type Msg struct { + SID string + From string + To []string + IsBroadcast bool + Data []byte +} + +type Transport interface { + Send(to string, msg Msg) error + Inbox() <-chan Msg + Done() <-chan struct{} + Close() error +} + +// Memory implements Transport for local testing (per-party instance) +type Memory struct { + selfID string + peers map[string]*Memory // reference tới các peer + mu sync.RWMutex + + inbox chan Msg + done chan struct{} +} + +// NewMemoryParty creates a new memory transport for a party +func NewMemoryParty(selfID string) *Memory { + return &Memory{ + selfID: selfID, + peers: make(map[string]*Memory), + inbox: make(chan Msg, 100), + done: make(chan struct{}), + } +} + +// LinkPeers links all parties together (must be called after all parties are created) +func LinkPeers(parties ...*Memory) { + for _, p := range parties { + for _, q := range parties { + if p.selfID == q.selfID { + continue + } + p.peers[q.selfID] = q + } + } +} + +func (m *Memory) SelfID() string { + return m.selfID +} + +func (m *Memory) Send(to string, msg Msg) error { + m.mu.RLock() + peer, ok := m.peers[to] + m.mu.RUnlock() + if !ok { + return nil + } + select { + case peer.inbox <- msg: + default: + // drop if inbox full + } + return nil +} + +func (m *Memory) Inbox() <-chan Msg { + return m.inbox +} + +func (m *Memory) Done() <-chan struct{} { + return m.done +} + +func (m *Memory) Close() error { + close(m.done) + close(m.inbox) + return nil +} diff --git a/pkg/types/taurus.go b/pkg/types/taurus.go index e224c93..a0b6385 100644 --- a/pkg/types/taurus.go +++ b/pkg/types/taurus.go @@ -8,3 +8,10 @@ type TaurusMessage struct { Body []byte `json:"body"` IsBroadcast bool `json:"is_broadcast"` } + +// KeyData represents the result of key generation +type KeyData struct { + SID string + Type string + Payload []byte +} From 3fff1500bb6727abf358b36ca0e7c26cfcdce6c3 Mon Sep 17 00:00:00 2001 From: vietddude Date: Mon, 6 Oct 2025 16:19:26 +0700 Subject: [PATCH 03/21] Implement Taurus CMP key generation and signing functionality in event consumer --- pkg/eventconsumer/event_consumer.go | 226 +++++++++++++++++++++------- pkg/identity/identity.go | 64 +++++++- pkg/mpc/node.go | 13 +- pkg/mpc/session.go | 5 +- pkg/mpc/taurus/adapter.go | 31 ++-- pkg/mpc/taurus/cmp.go | 198 ++++++++++++++++-------- pkg/mpc/taurus/cmp_test.go | 155 +++++++++++-------- pkg/mpc/taurus/nats_transport.go | 208 ++++++++++++++----------- pkg/mpc/taurus/transport.go | 26 ++-- pkg/types/initiator_msg.go | 1 + pkg/types/taurus.go | 39 ++++- 11 files changed, 649 insertions(+), 317 deletions(-) diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index 9c325e5..a469ae3 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -15,6 +15,7 @@ import ( "github.com/fystack/mpcium/pkg/logger" "github.com/fystack/mpcium/pkg/messaging" "github.com/fystack/mpcium/pkg/mpc" + "github.com/fystack/mpcium/pkg/mpc/taurus" "github.com/fystack/mpcium/pkg/types" "github.com/nats-io/nats.go" "github.com/spf13/viper" @@ -167,88 +168,115 @@ func (ec *eventConsumer) handleKeyGenEvent(natMsg *nats.Msg) { } walletID := msg.WalletID - // ecdsaSession, err := ec.node.CreateKeyGenSession(mpc.SessionTypeECDSA, walletID, ec.mpcThreshold, ec.genKeyResultQueue) - // if err != nil { - // ec.handleKeygenSessionError(walletID, err, "Failed to create ECDSA key generation session", natMsg) - // return - // } - // eddsaSession, err := ec.node.CreateKeyGenSession(mpc.SessionTypeEDDSA, walletID, ec.mpcThreshold, ec.genKeyResultQueue) - // if err != nil { - // ec.handleKeygenSessionError(walletID, err, "Failed to create EdDSA key generation session", natMsg) - // return - // } - taurusSession, err := ec.node.CreateCMPKeyGenSession(walletID, ec.mpcThreshold, ec.genKeyResultQueue) + ecdsaSession, err := ec.node.CreateKeyGenSession(mpc.SessionTypeECDSA, walletID, ec.mpcThreshold, ec.genKeyResultQueue) + if err != nil { + ec.handleKeygenSessionError(walletID, err, "Failed to create ECDSA key generation session", natMsg) + return + } + eddsaSession, err := ec.node.CreateKeyGenSession(mpc.SessionTypeEDDSA, walletID, ec.mpcThreshold, ec.genKeyResultQueue) + if err != nil { + ec.handleKeygenSessionError(walletID, err, "Failed to create EdDSA key generation session", natMsg) + return + } + taurusSession, err := ec.node.CreateCMPSession(walletID, ec.mpcThreshold, taurus.ActKeygen) if err != nil { logger.Error("Failed to create Taurus CMP session", err, "walletID", walletID) ec.handleKeygenSessionError(walletID, err, "Failed to create Taurus CMP key generation session", natMsg) return } - // ecdsaSession.Init() - // eddsaSession.Init() + ecdsaSession.Init() + eddsaSession.Init() - // ctxEcdsa, doneEcdsa := context.WithCancel(baseCtx) - // ctxEddsa, doneEddsa := context.WithCancel(baseCtx) + ctxEcdsa, doneEcdsa := context.WithCancel(baseCtx) + ctxEddsa, doneEddsa := context.WithCancel(baseCtx) ctxTaurus, doneTaurus := context.WithCancel(baseCtx) successEvent := &event.KeygenResultEvent{WalletID: walletID, ResultType: event.ResultTypeSuccess} var wg sync.WaitGroup - wg.Add(1) + wg.Add(3) // Channel to communicate errors from goroutines to main function - errorChan := make(chan error, 1) - - // go func() { - // defer wg.Done() - // select { - // case <-ctxEcdsa.Done(): - // successEvent.ECDSAPubKey = ecdsaSession.GetPubKeyResult() - // case err := <-ecdsaSession.ErrChan(): - // logger.Error("ECDSA keygen session error", err) - // ec.handleKeygenSessionError(walletID, err, "ECDSA keygen session error", natMsg) - // errorChan <- err - // doneEcdsa() - // } - // }() - // go func() { - // defer wg.Done() - // select { - // case <-ctxEddsa.Done(): - // successEvent.EDDSAPubKey = eddsaSession.GetPubKeyResult() - // case err := <-eddsaSession.ErrChan(): - // logger.Error("EdDSA keygen session error", err) - // ec.handleKeygenSessionError(walletID, err, "EdDSA keygen session error", natMsg) - // errorChan <- err - // doneEddsa() - // } - // }() - - // ecdsaSession.ListenToIncomingMessageAsync() - // eddsaSession.ListenToIncomingMessageAsync() - // Temporary delay for peer setup - ec.warmUpSession() - // go ecdsaSession.GenerateKey(doneEcdsa) - // go eddsaSession.GenerateKey(doneEddsa) + errorChan := make(chan error, 3) + + go func() { + defer wg.Done() + select { + case <-ctxEcdsa.Done(): + successEvent.ECDSAPubKey = ecdsaSession.GetPubKeyResult() + case err := <-ecdsaSession.ErrChan(): + logger.Error("ECDSA keygen session error", err) + ec.handleKeygenSessionError(walletID, err, "ECDSA keygen session error", natMsg) + errorChan <- err + doneEcdsa() + } + }() + go func() { + defer wg.Done() + select { + case <-ctxEddsa.Done(): + successEvent.EDDSAPubKey = eddsaSession.GetPubKeyResult() + case err := <-eddsaSession.ErrChan(): + logger.Error("EdDSA keygen session error", err) + ec.handleKeygenSessionError(walletID, err, "EdDSA keygen session error", natMsg) + errorChan <- err + doneEddsa() + } + }() go func() { + defer wg.Done() data, err := taurusSession.Keygen(ctxTaurus) if err != nil { logger.Error("Failed to generate key", err) - ec.handleKeygenSessionError(walletID, err, "Failed to generate key", natMsg) errorChan <- err - doneTaurus() + return } - successEvent.TaurusCMPPubKey = data.Payload + + logger.Info("Keygen completed successfully", "walletID", walletID, "payloadLength", len(data.Payload)) + successEvent.TaurusCMPPubKey = data.PubKeyBytes doneTaurus() }() + ecdsaSession.ListenToIncomingMessageAsync() + eddsaSession.ListenToIncomingMessageAsync() + + // Temporary delay for peer setup + ec.warmUpSession() + go ecdsaSession.GenerateKey(doneEcdsa) + go eddsaSession.GenerateKey(doneEddsa) + + // Wait for completion or timeout + doneAll := make(chan struct{}) + go func() { + wg.Wait() + close(doneAll) + }() + + // Check for errors + select { + case <-doneAll: + // Check if any errors occurred during execution + select { + case <-errorChan: + // Error already handled by the goroutine, just return early + return + default: + // No errors, continue with success + } + case <-baseCtx.Done(): + // timeout occurred + logger.Warn("Key generation timed out", "walletID", walletID, "timeout", KeyGenTimeOut) + ec.handleKeygenSessionError(walletID, fmt.Errorf("keygen session timed out after %v", KeyGenTimeOut), "Key generation timed out", natMsg) + return + } + payload, err := json.Marshal(successEvent) if err != nil { logger.Error("Failed to marshal keygen success event", err) ec.handleKeygenSessionError(walletID, err, "Failed to marshal keygen success event", natMsg) return } - fmt.Println("payload", string(payload)) key := fmt.Sprintf(mpc.TypeGenerateWalletResultFmt, walletID) if err := ec.genKeyResultQueue.Enqueue( key, @@ -396,6 +424,9 @@ func (ec *eventConsumer) handleSigningEvent(natMsg *nats.Msg) { ec.signingResultQueue, idempotentKey, ) + case types.KeyTypeTaurusCmp: + ec.handleCMPSigning(msg, natMsg) + return default: sessionErr = fmt.Errorf("unsupported key type: %v", msg.KeyType) } @@ -488,6 +519,93 @@ func (ec *eventConsumer) handleSigningEvent(natMsg *nats.Msg) { go session.Sign(onSuccess) } +// Add this method to handle CMP signing +func (ec *eventConsumer) handleCMPSigning(msg types.SignTxMessage, natMsg *nats.Msg) { + logger.Info("Starting CMP signing", "walletID", msg.WalletID, "txID", msg.TxID) + + // Create CMP session for signing + taurusSession, err := ec.node.CreateCMPSession(msg.WalletID, ec.mpcThreshold, taurus.ActSign) + if err != nil { + logger.Error("Failed to create Taurus CMP signing session", err, "walletID", msg.WalletID) + ec.handleSigningSessionError( + msg.WalletID, + msg.TxID, + msg.NetworkInternalCode, + err, + "Failed to create Taurus CMP signing session", + natMsg, + ) + return + } + + // Convert transaction bytes to big.Int + txBigInt := new(big.Int).SetBytes(msg.Tx) + + // Create context for signing + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Perform CMP signing + signature, err := taurusSession.Sign(ctx, txBigInt) + if err != nil { + logger.Error("CMP signing failed", err, "walletID", msg.WalletID, "txID", msg.TxID) + ec.handleSigningSessionError( + msg.WalletID, + msg.TxID, + msg.NetworkInternalCode, + err, + "CMP signing failed", + natMsg, + ) + return + } + + // Create signing result event + signingResult := event.SigningResultEvent{ + ResultType: event.ResultTypeSuccess, + NetworkInternalCode: msg.NetworkInternalCode, + WalletID: msg.WalletID, + TxID: msg.TxID, + Signature: signature, // CMP returns the full signature + } + + // Marshal and enqueue the result + signingResultBytes, err := json.Marshal(signingResult) + if err != nil { + logger.Error("Failed to marshal CMP signing result event", err, "walletID", msg.WalletID, "txID", msg.TxID) + ec.handleSigningSessionError( + msg.WalletID, + msg.TxID, + msg.NetworkInternalCode, + err, + "Failed to marshal CMP signing result", + natMsg, + ) + return + } + + // Enqueue the signing result + err = ec.signingResultQueue.Enqueue(event.SigningResultCompleteTopic, signingResultBytes, &messaging.EnqueueOptions{ + IdempotententKey: composeSigningIdempotentKey(msg.TxID, natMsg), + }) + if err != nil { + logger.Error("Failed to enqueue CMP signing result event", err, "walletID", msg.WalletID, "txID", msg.TxID) + ec.handleSigningSessionError( + msg.WalletID, + msg.TxID, + msg.NetworkInternalCode, + err, + "Failed to enqueue CMP signing result", + natMsg, + ) + return + } + + // Send reply and log success + ec.sendReplyToRemoveMsg(natMsg) + logger.Info("[COMPLETED CMP SIGN] CMP signing completed successfully", "walletID", msg.WalletID, "txID", msg.TxID) +} + func (ec *eventConsumer) consumeTxSigningEvent() error { sub, err := ec.pubsub.Subscribe(MPCSignEvent, func(natMsg *nats.Msg) { ec.signingMsgBuffer <- natMsg // Send to worker instead of processing directly @@ -868,6 +986,8 @@ func sessionTypeFromKeyType(keyType types.KeyType) (mpc.SessionType, error) { return mpc.SessionTypeECDSA, nil case types.KeyTypeEd25519: return mpc.SessionTypeEDDSA, nil + case types.KeyTypeTaurusCmp: + return mpc.SessionTypeTaurusCmp, nil default: logger.Warn("Unsupported key type", "keyType", keyType) return "", fmt.Errorf("unsupported key type: %v", keyType) diff --git a/pkg/identity/identity.go b/pkg/identity/identity.go index 2d2f1ee..473e6b5 100644 --- a/pkg/identity/identity.go +++ b/pkg/identity/identity.go @@ -42,6 +42,9 @@ type Store interface { SignMessage(msg *types.TssMessage) ([]byte, error) VerifyMessage(msg *types.TssMessage) error + SignTaurusMessage(msg *types.TaurusMessage) ([]byte, error) + VerifyTaurusMessage(msg *types.TaurusMessage) error + SignEcdhMessage(msg *types.ECDHMessage) ([]byte, error) VerifySignature(msg *types.ECDHMessage) error @@ -97,15 +100,19 @@ func NewFileStore(identityDir, nodeName string, decrypt bool, agePasswordFile st } // Load peers.json to validate all nodes have identity files - peersData, err := os.ReadFile("peers.json") - if err != nil { - return nil, fmt.Errorf("failed to read peers.json: %w", err) - } + // peersData, err := os.ReadFile("peers.json") + // if err != nil { + // return nil, fmt.Errorf("failed to read peers.json: %w", err) + // } - peers := make(map[string]string) - if err := json.Unmarshal(peersData, &peers); err != nil { - return nil, fmt.Errorf("failed to parse peers.json: %w", err) + peers := map[string]string{ + "node0": "aa4adaea-257d-4337-842a-1d3f966d85c2", + "node1": "21ac5259-ac9e-4b81-bd42-d05f584879e4", + "node2": "2fff5119-a1f1-4763-8f4c-d7d88c212608", } + // if err := json.Unmarshal(peersData, &peers); err != nil { + // return nil, fmt.Errorf("failed to parse peers.json: %w", err) + // } store := &fileStore{ identityDir: identityDir, @@ -427,6 +434,45 @@ func (s *fileStore) VerifyMessage(msg *types.TssMessage) error { return nil } +func (s *fileStore) SignTaurusMessage(msg *types.TaurusMessage) ([]byte, error) { + // Get deterministic bytes for signing + msgBytes, err := msg.MarshalForSigning() + if err != nil { + return nil, fmt.Errorf("failed to marshal message for signing: %w", err) + } + + signature := ed25519.Sign(s.privateKey, msgBytes) + return signature, nil +} + +func (s *fileStore) VerifyTaurusMessage(msg *types.TaurusMessage) error { + if msg.Signature == nil { + return fmt.Errorf("message has no signature") + } + + // Get the sender's NodeID + senderNodeID := taurusPartyIDToNodeID(msg.From) + + // Get the sender's public key + publicKey, err := s.GetPublicKey(senderNodeID) + if err != nil { + return fmt.Errorf("failed to get sender's public key: %w", err) + } + + // Get deterministic bytes for verification + msgBytes, err := msg.MarshalForSigning() + if err != nil { + return fmt.Errorf("failed to marshal message for verification: %w", err) + } + + // Verify the signature + if !ed25519.Verify(publicKey, msgBytes, msg.Signature) { + return fmt.Errorf("invalid signature") + } + + return nil +} + func (s *fileStore) EncryptMessage(plaintext []byte, peerID string) ([]byte, error) { key, err := s.GetSymmetricKey(peerID) if err != nil { @@ -536,3 +582,7 @@ func (s *fileStore) verifyP256(msg types.InitiatorMessage) error { func partyIDToNodeID(partyID *tss.PartyID) string { return strings.Split(string(partyID.KeyInt().Bytes()), ":")[0] } + +func taurusPartyIDToNodeID(partyID string) string { + return strings.Split(partyID, ":")[0] +} diff --git a/pkg/mpc/node.go b/pkg/mpc/node.go index e2b4fa3..bec6420 100644 --- a/pkg/mpc/node.go +++ b/pkg/mpc/node.go @@ -143,18 +143,21 @@ func (p *Node) createEDDSAKeyGenSession(walletID string, threshold int, version return session, nil } -func (p *Node) CreateCMPKeyGenSession( +func (p *Node) CreateCMPSession( walletID string, threshold int, - resultQueue messaging.MessageQueue, + act taurus.Act, ) (*taurus.CmpParty, error) { readyPeerIDs := p.peerRegistry.GetReadyPeersIncludeSelf() selfPartyID, allPartyIDs := p.generateTaurusPartyIDs(PurposeKeygen, readyPeerIDs, DefaultVersion) - tr := taurus.NewNATSTransport(walletID, selfPartyID, p.pubSub) + tr := taurus.NewNATSTransport(walletID, selfPartyID, act, p.pubSub, p.direct, p.identityStore) adapter := taurus.NewTaurusNetworkAdapter(walletID, selfPartyID, tr, allPartyIDs) pl := pool.NewPool(0) - party := taurus.NewCmpParty(walletID, selfPartyID, allPartyIDs, threshold, pl, adapter) - return party, nil + session := taurus.NewCmpParty(walletID, selfPartyID, allPartyIDs, threshold, pl, adapter, p.keyinfoStore, p.kvstore) + if act == taurus.ActSign { + session.LoadKey(walletID) + } + return session, nil } func (p *Node) CreateSigningSession( diff --git a/pkg/mpc/session.go b/pkg/mpc/session.go index b1a76b5..5f5c147 100644 --- a/pkg/mpc/session.go +++ b/pkg/mpc/session.go @@ -25,8 +25,9 @@ const ( TypeReshareWalletResultFmt = "mpc.mpc_reshare_result.%s" TypeSigningResultFmt = "mpc.mpc_signing_result.%s" - SessionTypeECDSA SessionType = "session_ecdsa" - SessionTypeEDDSA SessionType = "session_eddsa" + SessionTypeECDSA SessionType = "session_ecdsa" + SessionTypeEDDSA SessionType = "session_eddsa" + SessionTypeTaurusCmp SessionType = "session_taurus_cmp" ) var ( diff --git a/pkg/mpc/taurus/adapter.go b/pkg/mpc/taurus/adapter.go index c02c732..4c33ac5 100644 --- a/pkg/mpc/taurus/adapter.go +++ b/pkg/mpc/taurus/adapter.go @@ -1,9 +1,8 @@ package taurus import ( - "encoding/json" - "log/slog" - + "github.com/fystack/mpcium/pkg/logger" + "github.com/fystack/mpcium/pkg/types" "github.com/taurusgroup/multi-party-sig/pkg/party" "github.com/taurusgroup/multi-party-sig/pkg/protocol" ) @@ -45,42 +44,44 @@ func (a *TaurusNetworkAdapter) Next() <-chan *protocol.Message { return a.inbox func (a *TaurusNetworkAdapter) Done() <-chan struct{} { return a.done } func (a *TaurusNetworkAdapter) Send(msg *protocol.Message) { - wire, err := json.Marshal(msg) + wire, err := msg.MarshalBinary() if err != nil { - slog.Error("❌ marshal protocol msg", "err", err) + logger.Error("marshal protocol msg", err) return } - m := Msg{SID: a.sid, From: string(msg.From), IsBroadcast: msg.Broadcast, Data: wire} + m := types.TaurusMessage{ + SID: a.sid, + From: string(msg.From), + To: []string{string(msg.To)}, + IsBroadcast: msg.Broadcast, + Data: wire, + } for _, pid := range a.peers { - if pid == a.selfID { - continue - } - if msg.Broadcast || msg.IsFor(pid) { + if pid != a.selfID && (msg.Broadcast || msg.IsFor(pid)) { _ = a.transport.Send(string(pid), m) } } } func (a *TaurusNetworkAdapter) route() { + defer close(a.done) for { select { case tm, ok := <-a.transport.Inbox(): if !ok { - close(a.done) return } var pm protocol.Message - if err := json.Unmarshal(tm.Data, &pm); err != nil { - slog.Error("❌ unmarshal protocol msg", "err", err) + if err := pm.UnmarshalBinary(tm.Data); err != nil { + logger.Error("unmarshal protocol msg", err) continue } select { case a.inbox <- &pm: default: - slog.Warn("⚠️ inbox full, drop msg", "self", a.selfID) + logger.Warn("inbox full, drop msg", "self", a.selfID) } case <-a.transport.Done(): - close(a.done) return } } diff --git a/pkg/mpc/taurus/cmp.go b/pkg/mpc/taurus/cmp.go index b94440e..63c47d4 100644 --- a/pkg/mpc/taurus/cmp.go +++ b/pkg/mpc/taurus/cmp.go @@ -2,10 +2,16 @@ package taurus import ( "context" + cryptoEcdsa "crypto/ecdsa" "errors" "fmt" "math/big" + "github.com/btcsuite/btcd/btcec/v2" + "github.com/decred/dcrd/dcrec/secp256k1/v4" + "github.com/fystack/mpcium/pkg/encoding" + "github.com/fystack/mpcium/pkg/keyinfo" + "github.com/fystack/mpcium/pkg/kvstore" "github.com/fystack/mpcium/pkg/logger" "github.com/fystack/mpcium/pkg/types" "github.com/taurusgroup/multi-party-sig/pkg/ecdsa" @@ -17,13 +23,15 @@ import ( ) type CmpParty struct { - sid string - id party.ID - ids party.IDSlice - threshold int - pl *pool.Pool - savedData *cmp.Config - network NetworkInterface + sid string + id party.ID + ids party.IDSlice + threshold int + pl *pool.Pool + savedData *cmp.Config + kvstore kvstore.KVStore + keyinfoStore keyinfo.Store + network NetworkInterface } func NewCmpParty( @@ -33,74 +41,115 @@ func NewCmpParty( threshold int, pl *pool.Pool, network NetworkInterface, + keyinfoStore keyinfo.Store, + kvstore kvstore.KVStore, ) *CmpParty { return &CmpParty{ - sid: sid, - id: id, - ids: ids, - threshold: threshold, - pl: pl, - network: network, + sid: sid, + id: id, + ids: ids, + threshold: threshold, + pl: pl, + network: network, + keyinfoStore: keyinfoStore, + kvstore: kvstore, } } +func (p *CmpParty) LoadKey(sid string) error { + key := p.composeKey(sid) + + data, err := p.kvstore.Get(key) + if err != nil { + return fmt.Errorf("load key: %w", err) + } -func (p *CmpParty) LoadKey(data *types.KeyData) error { cfg := cmp.EmptyConfig(curve.Secp256k1{}) - if err := cfg.UnmarshalBinary(data.Payload); err != nil { - return fmt.Errorf("decode key data: %w", err) + if err := cfg.UnmarshalBinary(data); err != nil { + return fmt.Errorf("unmarshal key config: %w", err) } + p.savedData = cfg return nil } func (p *CmpParty) Keygen(ctx context.Context) (types.KeyData, error) { - h, err := protocol.NewMultiHandler( - cmp.Keygen(curve.Secp256k1{}, p.id, p.ids, p.threshold, p.pl), - []byte(p.sid), - ) + logger.Info("Starting to generate key Taurus CMP", "walletID", p.sid) + + result, err := p.run(ctx, cmp.Keygen(curve.Secp256k1{}, p.id, p.ids, p.threshold, p.pl)) if err != nil { - return types.KeyData{}, err + return types.KeyData{}, fmt.Errorf("cmp keygen: %w", err) } - if err := p.executeProtocol(ctx, h); err != nil { - return types.KeyData{}, err + + cfg, ok := result.(*cmp.Config) + if !ok { + return types.KeyData{}, fmt.Errorf("unexpected result type %T", result) } - res, err := h.Result() + p.savedData = cfg + + // Extract public key coordinates + x, y, err := ExtractXYFromPoint(cfg.PublicPoint()) if err != nil { - return types.KeyData{}, err + return types.KeyData{}, fmt.Errorf("extract pubkey: %w", err) } - cfg, ok := res.(*cmp.Config) - if !ok { - return types.KeyData{}, errors.New("unexpected result type") + + // Use secp256k1 curve, not P256 + pubKey := &cryptoEcdsa.PublicKey{ + Curve: btcec.S256(), + X: x, + Y: y, } - p.savedData = cfg - packed, _ := cfg.MarshalBinary() - return types.KeyData{SID: p.sid, Type: "taurus_cmp", Payload: packed}, nil + + pubKeyBytes, err := encoding.EncodeS256PubKey(pubKey) + if err != nil { + return types.KeyData{}, fmt.Errorf("encode pubkey: %w", err) + } + + packed, err := cfg.MarshalBinary() + if err != nil { + return types.KeyData{}, fmt.Errorf("marshal config: %w", err) + } + + key := p.composeKey(p.sid) + keyInfo := &keyinfo.KeyInfo{ + ParticipantPeerIDs: p.getParticipantPeerIDs(), + Threshold: p.threshold, + Version: 1, + } + + // Store both key and metadata if stores available + if p.kvstore != nil { + if err := p.kvstore.Put(key, packed); err != nil { + return types.KeyData{}, fmt.Errorf("store key: %w", err) + } + } + if p.keyinfoStore != nil { + if err := p.keyinfoStore.Save(key, keyInfo); err != nil { + return types.KeyData{}, fmt.Errorf("store key info: %w", err) + } + } + + return types.KeyData{ + SID: p.sid, + Type: "taurus_cmp", + PubKeyBytes: pubKeyBytes, + }, nil } func (p *CmpParty) Sign(ctx context.Context, msg *big.Int) ([]byte, error) { if p.savedData == nil { return nil, errors.New("no key loaded") } - h, err := protocol.NewMultiHandler( - cmp.Sign(p.savedData, p.ids, msg.Bytes(), p.pl), - []byte(p.sid), - ) - if err != nil { - return nil, err - } - if err := p.executeProtocol(ctx, h); err != nil { - return nil, err - } - res, err := h.Result() + logger.Info("Starting to sign message Taurus CMP", "walletID", p.sid) + cfg, err := p.run(ctx, cmp.Sign(p.savedData, p.ids, msg.Bytes(), p.pl)) if err != nil { return nil, err } - sig, ok := res.(*ecdsa.Signature) + sig, ok := cfg.(*ecdsa.Signature) if !ok { - return nil, errors.New("unexpected signature result") + return nil, errors.New("unexpected result type") } if !sig.Verify(p.savedData.PublicPoint(), msg.Bytes()) { - return nil, errors.New("failed to verify cmp signature") + return nil, errors.New("signature verification failed") } return sig.SigEthereum() } @@ -109,44 +158,71 @@ func (p *CmpParty) Reshare(ctx context.Context) (types.KeyData, error) { if p.savedData == nil { return types.KeyData{}, errors.New("no key loaded") } - h, err := protocol.NewMultiHandler(cmp.Refresh(p.savedData, p.pl), []byte(p.sid)) + cfg, err := p.run(ctx, cmp.Refresh(p.savedData, p.pl)) if err != nil { return types.KeyData{}, err } - if err := p.executeProtocol(ctx, h); err != nil { - return types.KeyData{}, err - } - res, err := h.Result() - if err != nil { - return types.KeyData{}, err - } - cfg, ok := res.(*cmp.Config) + savedData, ok := cfg.(*cmp.Config) if !ok { return types.KeyData{}, errors.New("unexpected result type") } - p.savedData = cfg - packed, _ := cfg.MarshalBinary() + p.savedData = savedData + packed, _ := p.savedData.MarshalBinary() return types.KeyData{SID: p.sid, Type: "taurus_cmp", Payload: packed}, nil } -func (p *CmpParty) executeProtocol(ctx context.Context, h protocol.Handler) error { +func (p *CmpParty) run(ctx context.Context, proto protocol.StartFunc) (any, error) { + logger.Info("Starting to run Taurus CMP", "walletID", p.sid) + h, err := protocol.NewMultiHandler(proto, []byte(p.sid)) + if err != nil { + return nil, err + } for { select { case <-ctx.Done(): - return ctx.Err() + return nil, ctx.Err() case msg, ok := <-h.Listen(): if !ok { - return nil + return h.Result() } p.network.Send(msg) case msg := <-p.network.Next(): if h.CanAccept(msg) { h.Accept(msg) } else { - logger.Warn("⚠️ Ignored invalid msg", "self", p.id) + logger.Debug("Ignored self broadcast msg", + "self", p.id, + "from", msg.From, + "to", msg.To, + "broadcast", msg.Broadcast, + ) } case <-p.network.Done(): - return nil + return h.Result() } } } + +func ExtractXYFromPoint(p curve.Point) (*big.Int, *big.Int, error) { + data, err := p.MarshalBinary() // compressed SEC1 form (33 bytes) + if err != nil { + return nil, nil, fmt.Errorf("marshal point: %w", err) + } + pk, err := secp256k1.ParsePubKey(data) + if err != nil { + return nil, nil, fmt.Errorf("parse secp256k1 pubkey: %w", err) + } + return pk.X(), pk.Y(), nil +} + +func (p *CmpParty) getParticipantPeerIDs() []string { + var ids []string + for _, id := range p.ids { + ids = append(ids, string(id)) + } + return ids +} + +func (p *CmpParty) composeKey(sid string) string { + return fmt.Sprintf("taurus_cmp:%s", sid) +} diff --git a/pkg/mpc/taurus/cmp_test.go b/pkg/mpc/taurus/cmp_test.go index 2f4979d..59fef80 100644 --- a/pkg/mpc/taurus/cmp_test.go +++ b/pkg/mpc/taurus/cmp_test.go @@ -1,93 +1,116 @@ package taurus import ( + "bytes" "context" - "fmt" + "math/big" "sync" "testing" "github.com/fystack/mpcium/pkg/logger" "github.com/fystack/mpcium/pkg/messaging" - "github.com/fystack/mpcium/pkg/types" "github.com/nats-io/nats.go" "github.com/taurusgroup/multi-party-sig/pkg/party" "github.com/taurusgroup/multi-party-sig/pkg/pool" ) -func TestCmpParty(t *testing.T) { - sid := "test-session-123" - parties := []string{"party1", "party2", "party3"} - ids := make([]party.ID, len(parties)) - for i, id := range parties { - ids[i] = party.ID(id) - } - pl := pool.NewPool(0) +type cmpTest struct { + parties []*CmpParty + results map[string]chan any +} - natsConn, err := nats.Connect("nats://localhost:4223") +func newCmpTest(sid string, ids []party.ID) *cmpTest { + pl := pool.NewPool(0) + nc, err := nats.Connect("nats://localhost:4223") if err != nil { logger.Fatal("Failed to connect to NATS", err) } - pubsub := messaging.NewNATSPubSub(natsConn) - - // networks + adapters - network1 := NewNATSTransport(sid, party.ID("party1"), pubsub) - network2 := NewNATSTransport(sid, party.ID("party2"), pubsub) - network3 := NewNATSTransport(sid, party.ID("party3"), pubsub) + pubsub := messaging.NewNATSPubSub(nc) + direct := messaging.NewNatsDirectMessaging(nc) - adapter1 := NewTaurusNetworkAdapter(sid, "party1", network1, ids) - adapter2 := NewTaurusNetworkAdapter(sid, "party2", network2, ids) - adapter3 := NewTaurusNetworkAdapter(sid, "party3", network3, ids) - - party1 := NewCmpParty(sid, "party1", ids, 2, pl, adapter1) - party2 := NewCmpParty(sid, "party2", ids, 2, pl, adapter2) - party3 := NewCmpParty(sid, "party3", ids, 2, pl, adapter3) - - result1 := make(chan types.KeyData, 1) - result2 := make(chan types.KeyData, 1) - result3 := make(chan types.KeyData, 1) + t := &cmpTest{ + results: map[string]chan any{ + "keygen": make(chan any, len(ids)), + "sign": make(chan any, len(ids)), + "reshare": make(chan any, len(ids)), + }, + } - wg := sync.WaitGroup{} - wg.Add(3) + for _, id := range ids { + net := NewNATSTransport(sid, id, ActKeygen, pubsub, direct, nil) + adapter := NewTaurusNetworkAdapter(sid, id, net, ids) + t.parties = append(t.parties, NewCmpParty(sid, id, ids, 2, pl, adapter, nil, nil)) + } - go func() { - defer wg.Done() - res, err := party1.Keygen(context.Background()) - if err != nil { - t.Errorf("party1 keygen error: %v", err) - return - } - result1 <- res - }() - - go func() { - defer wg.Done() - res, err := party2.Keygen(context.Background()) - if err != nil { - t.Errorf("party2 keygen error: %v", err) - return - } - result2 <- res - }() - - go func() { - defer wg.Done() - res, err := party3.Keygen(context.Background()) - if err != nil { - t.Errorf("party3 keygen error: %v", err) - return - } - result3 <- res - }() + return t +} +func (t *cmpTest) runAll(fn func(*CmpParty) (any, error), key string) { + var wg sync.WaitGroup + for _, p := range t.parties { + wg.Add(1) + go func(p *CmpParty) { + defer wg.Done() + res, err := fn(p) + if err != nil { + logger.Error("operation failed", err) + return + } + t.results[key] <- res + }(p) + } wg.Wait() +} - // Read the actual values from channels - r1 := <-result1 - r2 := <-result2 - r3 := <-result3 +func TestCmpParty(t *testing.T) { + sid := "test-session-123" + ids := []party.ID{"node0", "node1", "node2"} + test := newCmpTest(sid, ids) + + // --- Keygen --- + test.runAll(func(p *CmpParty) (any, error) { + return p.Keygen(context.Background()) + }, "keygen") + + // --- Sign 1 --- + msg := big.NewInt(1) + test.runAll(func(p *CmpParty) (any, error) { + return p.Sign(context.Background(), msg) + }, "sign") + + sigs := drain[[]byte](test.results["sign"]) + assertAllBytesEqual(t, sigs) + + // --- Reshare --- + test.runAll(func(p *CmpParty) (any, error) { + return p.Reshare(context.Background()) + }, "reshare") + + // --- Sign 2 --- + msg = big.NewInt(2) + test.runAll(func(p *CmpParty) (any, error) { + return p.Sign(context.Background(), msg) + }, "sign") +} - fmt.Println("party1 result:", len(r1.Payload)) - fmt.Println("party2 result:", len(r2.Payload)) - fmt.Println("party3 result:", len(r3.Payload)) +func drain[T any](ch chan any) []T { + n := len(ch) + out := make([]T, n) + for i := 0; i < n; i++ { + out[i] = (<-ch).(T) + } + return out +} + +func assertAllBytesEqual(t *testing.T, vals [][]byte) { + if len(vals) == 0 { + t.Fatal("no values to compare") + } + first := vals[0] + for i, v := range vals[1:] { + if !bytes.Equal(first, v) { + t.Fatalf("byte slices not equal at index %d", i+1) + } + } } diff --git a/pkg/mpc/taurus/nats_transport.go b/pkg/mpc/taurus/nats_transport.go index b14bae7..a8ff6c3 100644 --- a/pkg/mpc/taurus/nats_transport.go +++ b/pkg/mpc/taurus/nats_transport.go @@ -1,146 +1,184 @@ package taurus import ( - "encoding/json" "fmt" "sync" - - "github.com/nats-io/nats.go" - "github.com/taurusgroup/multi-party-sig/pkg/party" + "time" "github.com/fystack/mpcium/pkg/encoding" + "github.com/fystack/mpcium/pkg/identity" "github.com/fystack/mpcium/pkg/logger" "github.com/fystack/mpcium/pkg/messaging" + "github.com/fystack/mpcium/pkg/types" + "github.com/nats-io/nats.go" + "github.com/taurusgroup/multi-party-sig/pkg/party" ) -type NATSTransport struct { - selfID string - wallet string - pubsub messaging.PubSub +type Act string - composeBroadcast func() string - composeDirect func(nodeID string) string +const ( + ActKeygen Act = "keygen" + ActSign Act = "sign" + ActReshare Act = "reshare" +) - inbox chan Msg - doneCh chan struct{} - errCh chan error +type TopicComposer struct { + ComposeBroadcastTopic func() string + ComposeDirectTopic func(to string, walletID string) string +} - mu sync.Mutex - subs []messaging.Subscription - closeMu sync.Once +type NATSTransport struct { + selfID string + wallet string + act Act + topicComposer *TopicComposer + pubsub messaging.PubSub + direct messaging.DirectMessaging + identityStore identity.Store + inbox chan types.TaurusMessage + done chan struct{} + subs []messaging.Subscription + closeMu sync.Once } -// NewNATSTransport creates a transport bound to a walletID and party. -func NewNATSTransport(walletID string, self party.ID, pubsub messaging.PubSub) *NATSTransport { +func NewNATSTransport( + walletID string, + self party.ID, + act Act, + pubsub messaging.PubSub, + direct messaging.DirectMessaging, + identityStore identity.Store, +) *NATSTransport { t := &NATSTransport{ - selfID: string(self), - wallet: walletID, - pubsub: pubsub, - inbox: make(chan Msg, 128), - doneCh: make(chan struct{}), - errCh: make(chan error, 8), - - composeBroadcast: func() string { - return fmt.Sprintf("mpc:broadcast:%s", walletID) - }, - composeDirect: func(nodeID string) string { - return fmt.Sprintf("mpc:direct:%s:%s", nodeID, walletID) + selfID: string(self), + wallet: walletID, + act: act, + pubsub: pubsub, + direct: direct, + identityStore: identityStore, + topicComposer: &TopicComposer{ + ComposeBroadcastTopic: func() string { + return fmt.Sprintf("%s:broadcast:cmp:%s", act, walletID) + }, + ComposeDirectTopic: func(to string, walletID string) string { + return fmt.Sprintf("%s:direct:cmp:%s:%s", act, to, walletID) + }, }, + inbox: make(chan types.TaurusMessage, 128), + done: make(chan struct{}), } - // subscribe broadcast - bcastTopic := t.composeBroadcast() - bcast, err := pubsub.Subscribe(bcastTopic, func(m *nats.Msg) { - t.handleRaw(m.Data) - }) - if err == nil { - t.subs = append(t.subs, bcast) - } else { - t.pushErr(err) + bcastTopic := t.topicComposer.ComposeBroadcastTopic() + if sub, err := pubsub.Subscribe(bcastTopic, t.handle); err == nil { + t.subs = append(t.subs, sub) } - // subscribe direct - directTopic := t.composeDirect(t.selfID) - direct, err := pubsub.Subscribe(directTopic, func(m *nats.Msg) { - t.handleRaw(m.Data) - }) - if err == nil { - t.subs = append(t.subs, direct) - } else { - t.pushErr(err) + directTopic := t.topicComposer.ComposeDirectTopic(t.selfID, walletID) + if sub, err := direct.Listen(directTopic, t.handleDirect); err == nil { + t.subs = append(t.subs, sub) } - logger.Info("✅ NATSTransport listening", - "wallet", walletID, - "broadcast", bcastTopic, - "direct", directTopic) - + logger.Debug( + "NATS Transport listening", + "wallet", + walletID, + "broadcast", + bcastTopic, + "direct", + directTopic, + ) return t } -// --- Transport interface --- - -func (t *NATSTransport) Send(to string, msg Msg) error { - // Marshal the message +func (t *NATSTransport) Send(to string, msg types.TaurusMessage) error { + // use AEAD encryption for each message so NATs server learns nothing + if t.identityStore != nil { + cipher, err := t.identityStore.SignTaurusMessage(&msg) + if err != nil { + return err + } + msg.Signature = cipher + } data, err := encoding.StructToJsonBytes(&msg) if err != nil { return err } - if msg.IsBroadcast { - // publish to broadcast topic - topic := t.composeBroadcast() + topic := t.topicComposer.ComposeBroadcastTopic() return t.pubsub.Publish(topic, data) } - // unicast to "to" - topic := t.composeDirect(to) - return t.pubsub.Publish(topic, data) + // Use direct messaging for unicast with retry + topic := t.topicComposer.ComposeDirectTopic(to, t.wallet) + if to == t.selfID { + return t.direct.SendToSelf(topic, data) + } + + return t.direct.SendToOtherWithRetry(topic, data, messaging.RetryConfig{ + RetryAttempt: 3, + ExponentialBackoff: true, + Delay: 50 * time.Millisecond, + OnRetry: func(n uint, err error) { + logger.Warn("Retry sending", "to", to, "attempt", n+1, "err", err.Error()) + }, + }) } -func (t *NATSTransport) Inbox() <-chan Msg { return t.inbox } -func (t *NATSTransport) Done() <-chan struct{} { return t.doneCh } +func (t *NATSTransport) Inbox() <-chan types.TaurusMessage { return t.inbox } +func (t *NATSTransport) Done() <-chan struct{} { return t.done } func (t *NATSTransport) Close() error { t.closeMu.Do(func() { - t.mu.Lock() - defer t.mu.Unlock() - for i, sub := range t.subs { + for _, sub := range t.subs { if sub != nil { _ = sub.Unsubscribe() - logger.Debug("✅ unsubscribed", "index", i, "wallet", t.wallet) } } close(t.inbox) - close(t.errCh) - close(t.doneCh) - logger.Info("🛑 NATSTransport closed", "wallet", t.wallet) + close(t.done) + logger.Debug("NATSTransport closed", "wallet", t.wallet) }) return nil } -// --- Internal helpers --- - -func (t *NATSTransport) handleRaw(data []byte) { - var m Msg - if err := json.Unmarshal(data, &m); err != nil { - t.pushErr(fmt.Errorf("unmarshal inbound: %w", err)) +func (t *NATSTransport) handle(m *nats.Msg) { + var msg types.TaurusMessage + if err := encoding.JsonBytesToStruct(m.Data, &msg); err != nil { return } - if m.From == t.selfID { - return // skip self + if t.identityStore != nil { + if err := t.identityStore.VerifyTaurusMessage(&msg); err != nil { + logger.Warn("failed to verify message", "err", err.Error()) + return + } + } + if msg.From == t.selfID { + return } select { - case t.inbox <- m: + case t.inbox <- msg: default: - logger.Warn("⚠️ dropping inbound message, inbox full", "wallet", t.wallet) + logger.Warn("dropping inbound message, inbox full", "wallet", t.wallet) } } -func (t *NATSTransport) pushErr(err error) { +func (t *NATSTransport) handleDirect(data []byte) { + var msg types.TaurusMessage + if err := encoding.JsonBytesToStruct(data, &msg); err != nil { + return + } + if t.identityStore != nil { + if err := t.identityStore.VerifyTaurusMessage(&msg); err != nil { + logger.Warn("failed to verify message", "err", err.Error()) + return + } + } + if msg.From == t.selfID { + return + } select { - case t.errCh <- err: + case t.inbox <- msg: default: - logger.Warn("⚠️ dropping error (buffer full)", "wallet", t.wallet) + logger.Warn("dropping inbound message, inbox full", "wallet", t.wallet) } } diff --git a/pkg/mpc/taurus/transport.go b/pkg/mpc/taurus/transport.go index 7b1b496..1c4034d 100644 --- a/pkg/mpc/taurus/transport.go +++ b/pkg/mpc/taurus/transport.go @@ -1,18 +1,14 @@ package taurus -import "sync" +import ( + "sync" -type Msg struct { - SID string - From string - To []string - IsBroadcast bool - Data []byte -} + "github.com/fystack/mpcium/pkg/types" +) type Transport interface { - Send(to string, msg Msg) error - Inbox() <-chan Msg + Send(to string, msg types.TaurusMessage) error + Inbox() <-chan types.TaurusMessage Done() <-chan struct{} Close() error } @@ -20,10 +16,10 @@ type Transport interface { // Memory implements Transport for local testing (per-party instance) type Memory struct { selfID string - peers map[string]*Memory // reference tới các peer + peers map[string]*Memory // reference to peers mu sync.RWMutex - inbox chan Msg + inbox chan types.TaurusMessage done chan struct{} } @@ -32,7 +28,7 @@ func NewMemoryParty(selfID string) *Memory { return &Memory{ selfID: selfID, peers: make(map[string]*Memory), - inbox: make(chan Msg, 100), + inbox: make(chan types.TaurusMessage, 100), done: make(chan struct{}), } } @@ -53,7 +49,7 @@ func (m *Memory) SelfID() string { return m.selfID } -func (m *Memory) Send(to string, msg Msg) error { +func (m *Memory) Send(to string, msg types.TaurusMessage) error { m.mu.RLock() peer, ok := m.peers[to] m.mu.RUnlock() @@ -68,7 +64,7 @@ func (m *Memory) Send(to string, msg Msg) error { return nil } -func (m *Memory) Inbox() <-chan Msg { +func (m *Memory) Inbox() <-chan types.TaurusMessage { return m.inbox } diff --git a/pkg/types/initiator_msg.go b/pkg/types/initiator_msg.go index d770e79..951b81f 100644 --- a/pkg/types/initiator_msg.go +++ b/pkg/types/initiator_msg.go @@ -7,6 +7,7 @@ type KeyType string const ( KeyTypeSecp256k1 KeyType = "secp256k1" KeyTypeEd25519 KeyType = "ed25519" + KeyTypeTaurusCmp KeyType = "taurus_cmp" ) type EventInitiatorKeyType string diff --git a/pkg/types/taurus.go b/pkg/types/taurus.go index a0b6385..61667dd 100644 --- a/pkg/types/taurus.go +++ b/pkg/types/taurus.go @@ -1,17 +1,40 @@ package types +import "encoding/json" + // Message represents a protocol message type TaurusMessage struct { - SessionID string `json:"session_id"` - SenderID string `json:"sender_id"` - RecipientIDs []string `json:"recipient_ids"` - Body []byte `json:"body"` - IsBroadcast bool `json:"is_broadcast"` + SID string + From string + To []string + IsBroadcast bool + Data []byte + Signature []byte +} + +func (m *TaurusMessage) MarshalForSigning() ([]byte, error) { + // Exclude the Signature field from the signed payload to ensure deterministic signatures + type signPayload struct { + SID string `json:"sid"` + From string `json:"from"` + To []string `json:"to"` + IsBroadcast bool `json:"isBroadcast"` + Data []byte `json:"data"` + } + sp := signPayload{ + SID: m.SID, + From: m.From, + To: m.To, + IsBroadcast: m.IsBroadcast, + Data: m.Data, + } + return json.Marshal(sp) } // KeyData represents the result of key generation type KeyData struct { - SID string - Type string - Payload []byte + SID string + Type string + Payload []byte + PubKeyBytes []byte } From 69069cd537d6cc5549dea1e16e635f5de03a388a Mon Sep 17 00:00:00 2001 From: vietddude Date: Mon, 6 Oct 2025 16:24:24 +0700 Subject: [PATCH 04/21] Refactor NATSTransport message handling to unify handle functions and improve subscription logic --- pkg/mpc/taurus/nats_transport.go | 29 +++++------------------------ 1 file changed, 5 insertions(+), 24 deletions(-) diff --git a/pkg/mpc/taurus/nats_transport.go b/pkg/mpc/taurus/nats_transport.go index a8ff6c3..851ff01 100644 --- a/pkg/mpc/taurus/nats_transport.go +++ b/pkg/mpc/taurus/nats_transport.go @@ -69,12 +69,14 @@ func NewNATSTransport( } bcastTopic := t.topicComposer.ComposeBroadcastTopic() - if sub, err := pubsub.Subscribe(bcastTopic, t.handle); err == nil { + if sub, err := pubsub.Subscribe(bcastTopic, func(msg *nats.Msg) { + t.handle(msg.Data) + }); err == nil { t.subs = append(t.subs, sub) } directTopic := t.topicComposer.ComposeDirectTopic(t.selfID, walletID) - if sub, err := direct.Listen(directTopic, t.handleDirect); err == nil { + if sub, err := direct.Listen(directTopic, t.handle); err == nil { t.subs = append(t.subs, sub) } @@ -141,28 +143,7 @@ func (t *NATSTransport) Close() error { return nil } -func (t *NATSTransport) handle(m *nats.Msg) { - var msg types.TaurusMessage - if err := encoding.JsonBytesToStruct(m.Data, &msg); err != nil { - return - } - if t.identityStore != nil { - if err := t.identityStore.VerifyTaurusMessage(&msg); err != nil { - logger.Warn("failed to verify message", "err", err.Error()) - return - } - } - if msg.From == t.selfID { - return - } - select { - case t.inbox <- msg: - default: - logger.Warn("dropping inbound message, inbox full", "wallet", t.wallet) - } -} - -func (t *NATSTransport) handleDirect(data []byte) { +func (t *NATSTransport) handle(data []byte) { var msg types.TaurusMessage if err := encoding.JsonBytesToStruct(data, &msg); err != nil { return From 71062e6b52189e9a244d815f6d1930a28bc22b05 Mon Sep 17 00:00:00 2001 From: vietddude Date: Mon, 6 Oct 2025 16:49:34 +0700 Subject: [PATCH 05/21] Add CMP reshare handling in event consumer --- pkg/eventconsumer/event_consumer.go | 71 +++++++++++++++++++++++++++++ pkg/mpc/node.go | 3 +- pkg/mpc/taurus/cmp.go | 37 +++++++++++++-- pkg/types/taurus.go | 5 ++ 4 files changed, 110 insertions(+), 6 deletions(-) diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index a469ae3..1ea01e1 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -712,6 +712,11 @@ func (ec *eventConsumer) consumeReshareEvent() error { ec.handleReshareSessionError(walletID, keyType, msg.NewThreshold, err, "Failed to get session type", natMsg) return } + // Handle CMP reshare separately + if keyType == types.KeyTypeTaurusCmp { + ec.handleCMPReshare(msg, natMsg) + return + } createSession := func(isNewPeer bool) (mpc.ReshareSession, error) { return ec.node.CreateReshareSession( @@ -851,6 +856,72 @@ func (ec *eventConsumer) consumeReshareEvent() error { return err } +// NOTE: In CMP reshare, it just refresh the keyshare of each node but keep the same public key and threshold. +// Therefore, we don't need to create new party sessions for CMP reshare. +func (ec *eventConsumer) handleCMPReshare(msg types.ResharingMessage, natMsg *nats.Msg) { + logger.Info("Starting CMP reshare", "walletID", msg.WalletID, "sessionID", msg.SessionID) + + // Create CMP session for reshare + taurusSession, err := ec.node.CreateCMPSession(msg.WalletID, msg.NewThreshold, taurus.ActReshare) + if err != nil { + logger.Error("Failed to create Taurus CMP reshare session", err, "walletID", msg.WalletID) + ec.handleReshareSessionError(msg.WalletID, msg.KeyType, msg.NewThreshold, err, "Failed to create Taurus CMP reshare session", natMsg) + return + } + + // Load the existing key for reshare + if err := taurusSession.LoadKey(msg.WalletID); err != nil { + logger.Error("Failed to load key for CMP reshare", err, "walletID", msg.WalletID) + ec.handleReshareSessionError(msg.WalletID, msg.KeyType, msg.NewThreshold, err, "Failed to load key for CMP reshare", natMsg) + return + } + + // Create context for reshare + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) // Longer timeout for reshare + defer cancel() + + // Perform CMP reshare + keyData, err := taurusSession.Reshare(ctx) + if err != nil { + logger.Error("CMP reshare failed", err, "walletID", msg.WalletID, "sessionID", msg.SessionID) + ec.handleReshareSessionError(msg.WalletID, msg.KeyType, msg.NewThreshold, err, "CMP reshare failed", natMsg) + return + } + + // Create reshare result event + reshareResult := event.ResharingResultEvent{ + ResultType: event.ResultTypeSuccess, + WalletID: msg.WalletID, + NewThreshold: keyData.Threshold, + KeyType: msg.KeyType, + PubKey: keyData.PubKeyBytes, + } + + // Marshal and enqueue the result + reshareResultBytes, err := json.Marshal(reshareResult) + if err != nil { + logger.Error("Failed to marshal CMP reshare result event", err, "walletID", msg.WalletID, "sessionID", msg.SessionID) + ec.handleReshareSessionError(msg.WalletID, msg.KeyType, msg.NewThreshold, err, "Failed to marshal CMP reshare result", natMsg) + return + } + + // Enqueue the reshare result + key := fmt.Sprintf(mpc.TypeReshareWalletResultFmt, msg.SessionID) + err = ec.reshareResultQueue.Enqueue(key, reshareResultBytes, &messaging.EnqueueOptions{ + IdempotententKey: composeReshareIdempotentKey(msg.SessionID, natMsg), + }) + if err != nil { + logger.Error("Failed to enqueue CMP reshare result event", err, "walletID", msg.WalletID, "sessionID", msg.SessionID) + ec.handleReshareSessionError(msg.WalletID, msg.KeyType, msg.NewThreshold, err, "Failed to enqueue CMP reshare result", natMsg) + return + } + + // Remove this line - don't send reply for reshare messages + // ec.sendReplyToRemoveMsg(natMsg) + + logger.Info("[COMPLETED CMP RESHARE] CMP reshare completed successfully", "walletID", msg.WalletID, "sessionID", msg.SessionID) +} + // handleReshareSessionError handles errors that occur during reshare operations func (ec *eventConsumer) handleReshareSessionError( walletID string, diff --git a/pkg/mpc/node.go b/pkg/mpc/node.go index bec6420..0245e25 100644 --- a/pkg/mpc/node.go +++ b/pkg/mpc/node.go @@ -154,9 +154,10 @@ func (p *Node) CreateCMPSession( adapter := taurus.NewTaurusNetworkAdapter(walletID, selfPartyID, tr, allPartyIDs) pl := pool.NewPool(0) session := taurus.NewCmpParty(walletID, selfPartyID, allPartyIDs, threshold, pl, adapter, p.keyinfoStore, p.kvstore) - if act == taurus.ActSign { + if act == taurus.ActSign || act == taurus.ActReshare { session.LoadKey(walletID) } + return session, nil } diff --git a/pkg/mpc/taurus/cmp.go b/pkg/mpc/taurus/cmp.go index 63c47d4..d099d85 100644 --- a/pkg/mpc/taurus/cmp.go +++ b/pkg/mpc/taurus/cmp.go @@ -154,21 +154,48 @@ func (p *CmpParty) Sign(ctx context.Context, msg *big.Int) ([]byte, error) { return sig.SigEthereum() } -func (p *CmpParty) Reshare(ctx context.Context) (types.KeyData, error) { +func (p *CmpParty) Reshare(ctx context.Context) (res types.ReshareData, err error) { if p.savedData == nil { - return types.KeyData{}, errors.New("no key loaded") + return res, errors.New("no key loaded") } cfg, err := p.run(ctx, cmp.Refresh(p.savedData, p.pl)) if err != nil { - return types.KeyData{}, err + return res, err } savedData, ok := cfg.(*cmp.Config) if !ok { - return types.KeyData{}, errors.New("unexpected result type") + return res, errors.New("unexpected result type") } p.savedData = savedData packed, _ := p.savedData.MarshalBinary() - return types.KeyData{SID: p.sid, Type: "taurus_cmp", Payload: packed}, nil + + key := p.composeKey(p.sid) + // Store updated key share + if p.kvstore != nil { + if err := p.kvstore.Put(key, packed); err != nil { + return res, fmt.Errorf("store key: %w", err) + } + } + + // Extract public key coordinates + x, y, err := ExtractXYFromPoint(p.savedData.PublicPoint()) + if err != nil { + return res, fmt.Errorf("extract pubkey: %w", err) + } + + // Use secp256k1 curve, not P256 + pubKey := &cryptoEcdsa.PublicKey{ + Curve: btcec.S256(), + X: x, + Y: y, + } + + pubKeyBytes, err := encoding.EncodeS256PubKey(pubKey) + if err != nil { + return res, fmt.Errorf("encode pubkey: %w", err) + } + + return types.ReshareData{KeyData: types.KeyData{SID: p.sid, Type: "taurus_cmp", PubKeyBytes: pubKeyBytes}, Threshold: p.threshold}, nil } func (p *CmpParty) run(ctx context.Context, proto protocol.StartFunc) (any, error) { diff --git a/pkg/types/taurus.go b/pkg/types/taurus.go index 61667dd..d37b67d 100644 --- a/pkg/types/taurus.go +++ b/pkg/types/taurus.go @@ -38,3 +38,8 @@ type KeyData struct { Payload []byte PubKeyBytes []byte } + +type ReshareData struct { + KeyData + Threshold int +} From 4b417faf13a91a053a873dbab2df6fb0bf358da6 Mon Sep 17 00:00:00 2001 From: vietddude Date: Mon, 6 Oct 2025 17:07:04 +0700 Subject: [PATCH 06/21] refactor: avoid conflict --- pkg/eventconsumer/event_consumer.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index 1ea01e1..3caccaa 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -223,7 +223,6 @@ func (ec *eventConsumer) handleKeyGenEvent(natMsg *nats.Msg) { doneEddsa() } }() - go func() { defer wg.Done() data, err := taurusSession.Keygen(ctxTaurus) @@ -233,7 +232,7 @@ func (ec *eventConsumer) handleKeyGenEvent(natMsg *nats.Msg) { return } - logger.Info("Keygen completed successfully", "walletID", walletID, "payloadLength", len(data.Payload)) + logger.Info("CMP Keygen completed successfully", "walletID", walletID, "payloadLength", len(data.Payload)) successEvent.TaurusCMPPubKey = data.PubKeyBytes doneTaurus() }() @@ -253,7 +252,6 @@ func (ec *eventConsumer) handleKeyGenEvent(natMsg *nats.Msg) { close(doneAll) }() - // Check for errors select { case <-doneAll: // Check if any errors occurred during execution @@ -277,6 +275,7 @@ func (ec *eventConsumer) handleKeyGenEvent(natMsg *nats.Msg) { ec.handleKeygenSessionError(walletID, err, "Failed to marshal keygen success event", natMsg) return } + key := fmt.Sprintf(mpc.TypeGenerateWalletResultFmt, walletID) if err := ec.genKeyResultQueue.Enqueue( key, From 0c0e0dd9c46f146e2565478dc12d6a87cbf2ff4b Mon Sep 17 00:00:00 2001 From: vietddude Date: Mon, 6 Oct 2025 17:24:10 +0700 Subject: [PATCH 07/21] fix: lint --- pkg/mpc/node.go | 5 +- pkg/protocol/cggmp21/adapter.go | 452 -------------------------------- pkg/protocol/frost/adapter.go | 445 ------------------------------- pkg/protocol/interfaces.go | 91 ------- 4 files changed, 4 insertions(+), 989 deletions(-) delete mode 100644 pkg/protocol/cggmp21/adapter.go delete mode 100644 pkg/protocol/frost/adapter.go delete mode 100644 pkg/protocol/interfaces.go diff --git a/pkg/mpc/node.go b/pkg/mpc/node.go index 0245e25..a7e4b20 100644 --- a/pkg/mpc/node.go +++ b/pkg/mpc/node.go @@ -155,7 +155,10 @@ func (p *Node) CreateCMPSession( pl := pool.NewPool(0) session := taurus.NewCmpParty(walletID, selfPartyID, allPartyIDs, threshold, pl, adapter, p.keyinfoStore, p.kvstore) if act == taurus.ActSign || act == taurus.ActReshare { - session.LoadKey(walletID) + err := session.LoadKey(walletID) + if err != nil { + return nil, err + } } return session, nil diff --git a/pkg/protocol/cggmp21/adapter.go b/pkg/protocol/cggmp21/adapter.go deleted file mode 100644 index 7f4f710..0000000 --- a/pkg/protocol/cggmp21/adapter.go +++ /dev/null @@ -1,452 +0,0 @@ -package cggmp21 - -import ( - "crypto/ecdsa" - "encoding/json" - "errors" - "fmt" - "math/big" - "sync" - - "github.com/fystack/mpcium/pkg/protocol" - mpsEcdsa "github.com/taurusgroup/multi-party-sig/pkg/ecdsa" - "github.com/taurusgroup/multi-party-sig/pkg/math/curve" - "github.com/taurusgroup/multi-party-sig/pkg/party" - "github.com/taurusgroup/multi-party-sig/pkg/pool" - mpsProtocol "github.com/taurusgroup/multi-party-sig/pkg/protocol" - "github.com/taurusgroup/multi-party-sig/protocols/cmp" - "github.com/taurusgroup/multi-party-sig/protocols/cmp/config" -) - -// CGGMP21Protocol implements the Protocol interface using CGGMP21 -type CGGMP21Protocol struct { - pool *pool.Pool -} - -// NewCGGMP21Protocol creates a new CGGMP21 protocol adapter -func NewCGGMP21Protocol() *CGGMP21Protocol { - return &CGGMP21Protocol{ - pool: pool.NewPool(0), // Use max threads - } -} - -// Close cleans up resources -func (p *CGGMP21Protocol) Close() { - if p.pool != nil { - p.pool.TearDown() - } -} - -// Name returns the protocol name -func (p *CGGMP21Protocol) Name() string { - return "CGGMP21" -} - -// KeyGen starts a distributed key generation -func (p *CGGMP21Protocol) KeyGen(selfID string, partyIDs []string, threshold int) (protocol.Party, error) { - // Convert string IDs to party.ID - ids := make([]party.ID, len(partyIDs)) - for i, id := range partyIDs { - ids[i] = party.ID(id) - } - - // Create the keygen protocol - startFunc := cmp.Keygen(curve.Secp256k1{}, party.ID(selfID), ids, threshold, p.pool) - - // Create handler - handler, err := mpsProtocol.NewMultiHandler(startFunc, nil) - if err != nil { - return nil, fmt.Errorf("failed to create keygen handler: %w", err) - } - - return &partyAdapter{ - handler: handler, - selfID: selfID, - }, nil -} - -// Refresh refreshes shares from an existing config -func (p *CGGMP21Protocol) Refresh(cfg protocol.KeyGenConfig) (protocol.Party, error) { - // Convert to CGGMP21 config - cmpConfig, err := toCMPConfig(cfg) - if err != nil { - return nil, err - } - - // Create refresh protocol - startFunc := cmp.Refresh(cmpConfig, p.pool) - - // Create handler - handler, err := mpsProtocol.NewMultiHandler(startFunc, nil) - if err != nil { - return nil, fmt.Errorf("failed to create refresh handler: %w", err) - } - - return &partyAdapter{ - handler: handler, - selfID: cfg.GetPartyID(), - }, nil -} - -// Sign starts a signing protocol -func (p *CGGMP21Protocol) Sign(cfg protocol.KeyGenConfig, signers []string, messageHash []byte) (protocol.Party, error) { - // Convert to CGGMP21 config - cmpConfig, err := toCMPConfig(cfg) - if err != nil { - return nil, err - } - - // Convert signer IDs - signerIDs := make([]party.ID, len(signers)) - for i, id := range signers { - signerIDs[i] = party.ID(id) - } - - // Create sign protocol - startFunc := cmp.Sign(cmpConfig, signerIDs, messageHash, p.pool) - - // Create handler - handler, err := mpsProtocol.NewMultiHandler(startFunc, nil) - if err != nil { - return nil, fmt.Errorf("failed to create sign handler: %w", err) - } - - return &partyAdapter{ - handler: handler, - selfID: cfg.GetPartyID(), - }, nil -} - -// PreSign starts a presigning protocol -func (p *CGGMP21Protocol) PreSign(cfg protocol.KeyGenConfig, signers []string) (protocol.Party, error) { - // Convert to CGGMP21 config - cmpConfig, err := toCMPConfig(cfg) - if err != nil { - return nil, err - } - - // Convert signer IDs - signerIDs := make([]party.ID, len(signers)) - for i, id := range signers { - signerIDs[i] = party.ID(id) - } - - // Create presign protocol - startFunc := cmp.Presign(cmpConfig, signerIDs, p.pool) - - // Create handler - handler, err := mpsProtocol.NewMultiHandler(startFunc, nil) - if err != nil { - return nil, fmt.Errorf("failed to create presign handler: %w", err) - } - - return &partyAdapter{ - handler: handler, - selfID: cfg.GetPartyID(), - }, nil -} - -// PreSignOnline completes a signature with a presignature -func (p *CGGMP21Protocol) PreSignOnline(cfg protocol.KeyGenConfig, preSig protocol.PreSignature, messageHash []byte) (protocol.Party, error) { - // Convert to CGGMP21 types - cmpConfig, err := toCMPConfig(cfg) - if err != nil { - return nil, err - } - - cmpPreSig, ok := preSig.(*preSignatureAdapter) - if !ok { - return nil, errors.New("invalid presignature type") - } - - // Create presign online protocol - startFunc := cmp.PresignOnline(cmpConfig, cmpPreSig.preSig, messageHash, p.pool) - - // Create handler - handler, err := mpsProtocol.NewMultiHandler(startFunc, nil) - if err != nil { - return nil, fmt.Errorf("failed to create presign online handler: %w", err) - } - - return &partyAdapter{ - handler: handler, - selfID: cfg.GetPartyID(), - }, nil -} - -// partyAdapter adapts mpsProtocol.Handler to protocol.Party -type partyAdapter struct { - handler *mpsProtocol.MultiHandler - selfID string - mu sync.Mutex - done bool - result interface{} - err error -} - -func (p *partyAdapter) Update(msg protocol.Message) error { - // Convert to MPS message format - // If broadcast, To is nil. Otherwise, it's the first recipient - var to party.ID - if !msg.IsBroadcast() && len(msg.GetTo()) > 0 { - to = party.ID(msg.GetTo()[0]) - } - - mpsMsg := &mpsProtocol.Message{ - From: party.ID(msg.GetFrom()), - To: to, - Broadcast: msg.IsBroadcast(), - Data: msg.GetData(), - } - - // Check if handler can accept the message - if !p.handler.CanAccept(mpsMsg) { - return errors.New("message rejected by handler") - } - - // Update handler with message - // Note: MultiHandler doesn't have Update method, we need to send via Accept - p.handler.Accept(mpsMsg) - return nil -} - -func (p *partyAdapter) Messages() <-chan protocol.Message { - ch := make(chan protocol.Message) - - go func() { - defer close(ch) - - for { - select { - case msg, ok := <-p.handler.Listen(): - if !ok { - // Protocol finished - p.mu.Lock() - p.done = true - p.result, p.err = p.handler.Result() - p.mu.Unlock() - return - } - - // Convert and send message - var toList []string - if !msg.Broadcast && msg.To != "" { - toList = []string{string(msg.To)} - } - ch <- &messageAdapter{ - from: string(msg.From), - to: toList, - data: msg.Data, - broadcast: msg.Broadcast, - } - } - } - }() - - return ch -} - -func (p *partyAdapter) Errors() <-chan error { - // CGGMP21 doesn't have a separate error channel - // Errors are returned in Result() - ch := make(chan error) - close(ch) - return ch -} - -func (p *partyAdapter) Done() bool { - p.mu.Lock() - defer p.mu.Unlock() - return p.done -} - -func (p *partyAdapter) Result() (interface{}, error) { - p.mu.Lock() - defer p.mu.Unlock() - - if !p.done { - return nil, errors.New("protocol not finished") - } - - if p.err != nil { - return nil, p.err - } - - // Convert result to appropriate type - switch r := p.result.(type) { - case *config.Config: - return &configAdapter{config: r}, nil - case *mpsEcdsa.Signature: - return &signatureAdapter{sig: r}, nil - case *mpsEcdsa.PreSignature: - return &preSignatureAdapter{preSig: r}, nil - default: - return p.result, nil - } -} - -// messageAdapter implements protocol.Message -type messageAdapter struct { - from string - to []string - data []byte - broadcast bool -} - -func (m *messageAdapter) GetFrom() string { return m.from } -func (m *messageAdapter) GetTo() []string { return m.to } -func (m *messageAdapter) GetData() []byte { return m.data } -func (m *messageAdapter) IsBroadcast() bool { return m.broadcast } - -// configAdapter implements protocol.KeyGenConfig -type configAdapter struct { - config *config.Config -} - -func (c *configAdapter) GetPartyID() string { - return string(c.config.ID) -} - -func (c *configAdapter) GetThreshold() int { - return c.config.Threshold -} - -func (c *configAdapter) GetPublicKey() *ecdsa.PublicKey { - point := c.config.PublicPoint() - // Convert curve.Point to ecdsa.PublicKey - // Using XScalar to get X coordinate as big.Int - if point.XScalar() != nil { - xBytes, _ := point.XScalar().MarshalBinary() - x := new(big.Int).SetBytes(xBytes) - // For Y, we need to derive it from the point - // This is a limitation - we can't get Y directly - return &ecdsa.PublicKey{ - Curve: nil, // We can't convert curve.Curve to elliptic.Curve - X: x, - Y: new(big.Int), // Placeholder - } - } - return nil -} - -func (c *configAdapter) GetShare() *big.Int { - // Get ECDSA scalar share and convert to big.Int - if c.config.ECDSA != nil { - bytes, _ := c.config.ECDSA.MarshalBinary() - return new(big.Int).SetBytes(bytes) - } - return nil -} - -func (c *configAdapter) GetSharePublicKey() *ecdsa.PublicKey { - // Get this party's public share - if public, ok := c.config.Public[c.config.ID]; ok && public.ECDSA != nil { - // Convert curve.Point to ecdsa.PublicKey - if public.ECDSA.XScalar() != nil { - xBytes, _ := public.ECDSA.XScalar().MarshalBinary() - x := new(big.Int).SetBytes(xBytes) - return &ecdsa.PublicKey{ - Curve: nil, // We can't convert curve.Curve to elliptic.Curve - X: x, - Y: new(big.Int), // Placeholder - } - } - } - return nil -} - -func (c *configAdapter) GetPartyIDs() []string { - ids := c.config.PartyIDs() - result := make([]string, len(ids)) - for i, id := range ids { - result[i] = string(id) - } - return result -} - -func (c *configAdapter) Serialize() ([]byte, error) { - return json.Marshal(c.config) -} - -// signatureAdapter implements protocol.Signature -type signatureAdapter struct { - sig *mpsEcdsa.Signature -} - -func (s *signatureAdapter) GetR() *big.Int { - // Convert curve.Point R to big.Int - if s.sig.R != nil && s.sig.R.XScalar() != nil { - bytes, _ := s.sig.R.XScalar().MarshalBinary() - return new(big.Int).SetBytes(bytes) - } - return nil -} - -func (s *signatureAdapter) GetS() *big.Int { - // Convert curve.Scalar S to big.Int - if s.sig.S != nil { - bytes, _ := s.sig.S.MarshalBinary() - return new(big.Int).SetBytes(bytes) - } - return nil -} - -func (s *signatureAdapter) Verify(pubKey *ecdsa.PublicKey, message []byte) bool { - // Verification would require converting ecdsa.PublicKey to curve.Point - // This is complex without the proper curve conversion - // For now, return false - return false -} - -func (s *signatureAdapter) Serialize() ([]byte, error) { - return json.Marshal(s.sig) -} - -// preSignatureAdapter implements protocol.PreSignature -type preSignatureAdapter struct { - preSig *mpsEcdsa.PreSignature -} - -func (p *preSignatureAdapter) GetID() string { - // Convert RID (byte slice) to hex string - return fmt.Sprintf("%x", p.preSig.ID) -} - -func (p *preSignatureAdapter) Validate() error { - return p.preSig.Validate() -} - -// Helper functions - -func convertToPartyIDs(ids []string) []party.ID { - if ids == nil { - return nil - } - result := make([]party.ID, len(ids)) - for i, id := range ids { - result[i] = party.ID(id) - } - return result -} - -func convertFromPartyIDs(ids []party.ID) []string { - if ids == nil { - return nil - } - result := make([]string, len(ids)) - for i, id := range ids { - result[i] = string(id) - } - return result -} - -func toCMPConfig(cfg protocol.KeyGenConfig) (*config.Config, error) { - // Try to cast directly first - if adapter, ok := cfg.(*configAdapter); ok { - return adapter.config, nil - } - - // Otherwise, we need to reconstruct - // This is a simplified version - in production you'd need proper serialization - return nil, errors.New("config conversion not implemented for non-CGGMP21 configs") -} diff --git a/pkg/protocol/frost/adapter.go b/pkg/protocol/frost/adapter.go deleted file mode 100644 index dc875ea..0000000 --- a/pkg/protocol/frost/adapter.go +++ /dev/null @@ -1,445 +0,0 @@ -package frost - -import ( - "crypto/ecdsa" - "encoding/json" - "errors" - "fmt" - "math/big" - "sync" - - "github.com/fystack/mpcium/pkg/protocol" - "github.com/taurusgroup/multi-party-sig/pkg/math/curve" - "github.com/taurusgroup/multi-party-sig/pkg/party" - "github.com/taurusgroup/multi-party-sig/pkg/pool" - mpsProtocol "github.com/taurusgroup/multi-party-sig/pkg/protocol" - "github.com/taurusgroup/multi-party-sig/protocols/frost" -) - -// FROSTProtocol implements the Protocol interface using FROST for EdDSA -type FROSTProtocol struct { - pool *pool.Pool -} - -// NewFROSTProtocol creates a new FROST protocol adapter -func NewFROSTProtocol() *FROSTProtocol { - return &FROSTProtocol{ - pool: pool.NewPool(0), // Use max threads - } -} - -// Close cleans up resources -func (p *FROSTProtocol) Close() { - if p.pool != nil { - p.pool.TearDown() - } -} - -// Name returns the protocol name -func (p *FROSTProtocol) Name() string { - return "FROST" -} - -// KeyGen starts a distributed key generation for EdDSA -func (p *FROSTProtocol) KeyGen(selfID string, partyIDs []string, threshold int) (protocol.Party, error) { - // Convert string IDs to party.ID - ids := make([]party.ID, len(partyIDs)) - for i, id := range partyIDs { - ids[i] = party.ID(id) - } - - // Create the FROST keygen protocol for Ed25519/Taproot - startFunc := frost.KeygenTaproot(party.ID(selfID), ids, threshold) - - // Create handler - handler, err := mpsProtocol.NewMultiHandler(startFunc, nil) - if err != nil { - return nil, fmt.Errorf("failed to create FROST keygen handler: %w", err) - } - - return &frostPartyAdapter{ - handler: handler, - selfID: selfID, - isTaproot: true, - }, nil -} - -// Refresh refreshes shares from an existing config -func (p *FROSTProtocol) Refresh(cfg protocol.KeyGenConfig) (protocol.Party, error) { - // Convert to FROST config - frostConfig, err := toFROSTConfig(cfg) - if err != nil { - return nil, err - } - - // Get party IDs from config - partyIDs := cfg.GetPartyIDs() - ids := make([]party.ID, len(partyIDs)) - for i, id := range partyIDs { - ids[i] = party.ID(id) - } - - // Create refresh protocol - startFunc := frost.Refresh(frostConfig, ids) - - // Create handler - handler, err := mpsProtocol.NewMultiHandler(startFunc, nil) - if err != nil { - return nil, fmt.Errorf("failed to create FROST refresh handler: %w", err) - } - - return &frostPartyAdapter{ - handler: handler, - selfID: cfg.GetPartyID(), - isTaproot: false, - }, nil -} - -// Sign starts a signing protocol -func (p *FROSTProtocol) Sign(cfg protocol.KeyGenConfig, signers []string, messageHash []byte) (protocol.Party, error) { - // Convert to FROST config - frostConfig, err := toFROSTConfig(cfg) - if err != nil { - return nil, err - } - - // Convert signer IDs - signerIDs := make([]party.ID, len(signers)) - for i, id := range signers { - signerIDs[i] = party.ID(id) - } - - // Create sign protocol - startFunc := frost.Sign(frostConfig, signerIDs, messageHash) - - // Create handler - handler, err := mpsProtocol.NewMultiHandler(startFunc, nil) - if err != nil { - return nil, fmt.Errorf("failed to create FROST sign handler: %w", err) - } - - return &frostPartyAdapter{ - handler: handler, - selfID: cfg.GetPartyID(), - isTaproot: false, - }, nil -} - -// PreSign starts a presigning protocol -func (p *FROSTProtocol) PreSign(cfg protocol.KeyGenConfig, signers []string) (protocol.Party, error) { - // FROST doesn't support presigning in the same way as ECDSA protocols - return nil, errors.New("FROST protocol does not support presigning") -} - -// PreSignOnline completes a signature with a presignature -func (p *FROSTProtocol) PreSignOnline(cfg protocol.KeyGenConfig, preSig protocol.PreSignature, messageHash []byte) (protocol.Party, error) { - // FROST doesn't support presigning in the same way as ECDSA protocols - return nil, errors.New("FROST protocol does not support presigning") -} - -// frostPartyAdapter adapts mpsProtocol.Handler to protocol.Party -type frostPartyAdapter struct { - handler *mpsProtocol.MultiHandler - selfID string - isTaproot bool - mu sync.Mutex - done bool - result interface{} - err error -} - -func (p *frostPartyAdapter) Update(msg protocol.Message) error { - // Convert to MPS message format - // If broadcast, To is nil. Otherwise, it's the first recipient - var to party.ID - if !msg.IsBroadcast() && len(msg.GetTo()) > 0 { - to = party.ID(msg.GetTo()[0]) - } - - mpsMsg := &mpsProtocol.Message{ - From: party.ID(msg.GetFrom()), - To: to, - Broadcast: msg.IsBroadcast(), - Data: msg.GetData(), - } - - // Check if handler can accept the message - if !p.handler.CanAccept(mpsMsg) { - return errors.New("message rejected by handler") - } - - // Update handler with message - // Note: MultiHandler doesn't have Update method, we need to send via Accept - p.handler.Accept(mpsMsg) - return nil -} - -func (p *frostPartyAdapter) Messages() <-chan protocol.Message { - ch := make(chan protocol.Message) - - go func() { - defer close(ch) - - for { - select { - case msg, ok := <-p.handler.Listen(): - if !ok { - // Protocol finished - p.mu.Lock() - p.done = true - p.result, p.err = p.handler.Result() - p.mu.Unlock() - return - } - - // Convert and send message - var toList []string - if !msg.Broadcast && msg.To != "" { - toList = []string{string(msg.To)} - } - - ch <- &messageAdapter{ - from: string(msg.From), - to: toList, - data: msg.Data, - broadcast: msg.Broadcast, - } - } - } - }() - - return ch -} - -func (p *frostPartyAdapter) Errors() <-chan error { - // FROST doesn't have a separate error channel - // Errors are returned in Result() - ch := make(chan error) - close(ch) - return ch -} - -func (p *frostPartyAdapter) Done() bool { - p.mu.Lock() - defer p.mu.Unlock() - return p.done -} - -func (p *frostPartyAdapter) Result() (interface{}, error) { - p.mu.Lock() - defer p.mu.Unlock() - - if !p.done { - return nil, errors.New("protocol not finished") - } - - if p.err != nil { - return nil, p.err - } - - // Convert result to appropriate type - switch r := p.result.(type) { - case *frost.Signature: - return &frostSignatureAdapter{sig: r}, nil - case *frost.Config: - return &frostConfigAdapter{ - config: r, - isTaproot: false, - }, nil - case *frost.TaprootConfig: - return &frostConfigAdapter{ - taprootConfig: r, - isTaproot: true, - }, nil - default: - return nil, fmt.Errorf("unexpected result type: %T", r) - } -} - -// messageAdapter implements protocol.Message -type messageAdapter struct { - from string - to []string - data []byte - broadcast bool -} - -func (m *messageAdapter) GetFrom() string { return m.from } -func (m *messageAdapter) GetTo() []string { return m.to } -func (m *messageAdapter) GetData() []byte { return m.data } -func (m *messageAdapter) IsBroadcast() bool { return m.broadcast } - -// frostConfigAdapter implements protocol.KeyGenConfig for FROST -type frostConfigAdapter struct { - config *frost.Config - taprootConfig *frost.TaprootConfig - isTaproot bool -} - -func (c *frostConfigAdapter) GetPartyID() string { - if c.isTaproot && c.taprootConfig != nil { - return string(c.taprootConfig.ID) - } - if c.config != nil { - return string(c.config.ID) - } - return "" -} - -func (c *frostConfigAdapter) GetThreshold() int { - if c.isTaproot && c.taprootConfig != nil { - return c.taprootConfig.Threshold - } - if c.config != nil { - return c.config.Threshold - } - return 0 -} - -// GetPublicKey returns nil for EdDSA as it uses different key type -func (c *frostConfigAdapter) GetPublicKey() *ecdsa.PublicKey { - // FROST uses Ed25519/Schnorr, not ECDSA - // This is a limitation of the current interface design - // For Taproot, we could potentially convert but it's not standard ECDSA - return nil -} - -// GetPublicKeyBytes returns the public key as bytes -func (c *frostConfigAdapter) GetPublicKeyBytes() []byte { - if c.isTaproot && c.taprootConfig != nil { - return c.taprootConfig.PublicKey - } - if c.config != nil && c.config.PublicKey != nil { - bytes, _ := c.config.PublicKey.MarshalBinary() - return bytes - } - return nil -} - -func (c *frostConfigAdapter) GetShare() *big.Int { - if c.isTaproot && c.taprootConfig != nil { - bytes, _ := c.taprootConfig.PrivateShare.MarshalBinary() - return new(big.Int).SetBytes(bytes) - } - if c.config != nil && c.config.PrivateShare != nil { - bytes, _ := c.config.PrivateShare.MarshalBinary() - return new(big.Int).SetBytes(bytes) - } - return nil -} - -func (c *frostConfigAdapter) GetSharePublicKey() *ecdsa.PublicKey { - // FROST uses Ed25519/Schnorr, not ECDSA - return nil -} - -func (c *frostConfigAdapter) GetPartyIDs() []string { - if c.isTaproot && c.taprootConfig != nil { - ids := make([]string, 0, len(c.taprootConfig.VerificationShares)) - for id := range c.taprootConfig.VerificationShares { - ids = append(ids, string(id)) - } - return ids - } - if c.config != nil && c.config.VerificationShares != nil { - ids := make([]string, 0, len(c.config.VerificationShares.Points)) - for id := range c.config.VerificationShares.Points { - ids = append(ids, string(id)) - } - return ids - } - return nil -} - -func (c *frostConfigAdapter) Serialize() ([]byte, error) { - if c.isTaproot && c.taprootConfig != nil { - return json.Marshal(c.taprootConfig) - } - if c.config != nil { - return json.Marshal(c.config) - } - return nil, errors.New("no config to serialize") -} - -// frostSignatureAdapter implements protocol.Signature for FROST -type frostSignatureAdapter struct { - sig *frost.Signature -} - -func (s *frostSignatureAdapter) GetR() *big.Int { - // FROST signatures have an R point, convert X coordinate to big.Int - if s.sig.R != nil && s.sig.R.XScalar() != nil { - bytes, _ := s.sig.R.XScalar().MarshalBinary() - return new(big.Int).SetBytes(bytes) - } - return new(big.Int) -} - -func (s *frostSignatureAdapter) GetS() *big.Int { - // FROST signatures don't have a direct S component - // This is a limitation of the current interface - return new(big.Int) -} - -func (s *frostSignatureAdapter) Verify(pubKey *ecdsa.PublicKey, message []byte) bool { - // This adapter doesn't support ECDSA verification - // FROST uses Schnorr signatures, not ECDSA - return false -} - -func (s *frostSignatureAdapter) Serialize() ([]byte, error) { - // Marshal the signature using JSON for now - return json.Marshal(s.sig) -} - -// Helper functions - -func convertToPartyIDs(ids []string) []party.ID { - if ids == nil { - return nil - } - result := make([]party.ID, len(ids)) - for i, id := range ids { - result[i] = party.ID(id) - } - return result -} - -func convertFromPartyIDs(ids []party.ID) []string { - if ids == nil { - return nil - } - result := make([]string, len(ids)) - for i, id := range ids { - result[i] = string(id) - } - return result -} - -func toFROSTConfig(cfg protocol.KeyGenConfig) (*frost.Config, error) { - // Try to cast directly first - if adapter, ok := cfg.(*frostConfigAdapter); ok { - if adapter.config != nil { - return adapter.config, nil - } - // If it's a Taproot config, we need to convert it - if adapter.taprootConfig != nil { - // This would need proper conversion logic - return nil, errors.New("cannot convert Taproot config to regular FROST config") - } - } - - // Otherwise, deserialize if possible - data, err := cfg.Serialize() - if err != nil { - return nil, fmt.Errorf("failed to serialize config: %w", err) - } - - // Try to unmarshal as FROST config - config := frost.EmptyConfig(curve.Secp256k1{}) - if err := json.Unmarshal(data, config); err != nil { - return nil, fmt.Errorf("failed to unmarshal as FROST config: %w", err) - } - - return config, nil -} diff --git a/pkg/protocol/interfaces.go b/pkg/protocol/interfaces.go deleted file mode 100644 index 916dd11..0000000 --- a/pkg/protocol/interfaces.go +++ /dev/null @@ -1,91 +0,0 @@ -package protocol - -import ( - "crypto/ecdsa" - "math/big" -) - -// Message represents a protocol message -type Message interface { - // GetFrom returns the sender ID - GetFrom() string - // GetTo returns the recipient IDs (nil for broadcast) - GetTo() []string - // GetData returns the message data - GetData() []byte - // IsBroadcast returns true if this is a broadcast message - IsBroadcast() bool -} - -// Party represents a participant in the protocol -type Party interface { - // Update processes an incoming message - Update(msg Message) error - // Messages returns a channel of outgoing messages - Messages() <-chan Message - // Errors returns a channel of errors - Errors() <-chan error - // Done returns true when the protocol is complete - Done() bool - // Result returns the protocol result - Result() (interface{}, error) -} - -// KeyGenConfig represents the result of key generation -type KeyGenConfig interface { - // GetPartyID returns this party's ID - GetPartyID() string - // GetThreshold returns the threshold value - GetThreshold() int - // GetPublicKey returns the group's public key - GetPublicKey() *ecdsa.PublicKey - // GetShare returns this party's secret share - GetShare() *big.Int - // GetSharePublicKey returns this party's public share - GetSharePublicKey() *ecdsa.PublicKey - // GetPartyIDs returns all party IDs - GetPartyIDs() []string - // Serialize returns the config as bytes - Serialize() ([]byte, error) -} - -// Signature represents a signature -type Signature interface { - // GetR returns the R component - GetR() *big.Int - // GetS returns the S component - GetS() *big.Int - // Verify verifies the signature - Verify(pubKey *ecdsa.PublicKey, message []byte) bool - // Serialize returns the signature as bytes - Serialize() ([]byte, error) -} - -// PreSignature represents a preprocessed signature -type PreSignature interface { - // GetID returns the presignature ID - GetID() string - // Validate validates the presignature - Validate() error -} - -// Protocol represents a threshold signature protocol implementation -type Protocol interface { - // KeyGen starts a distributed key generation - KeyGen(selfID string, partyIDs []string, threshold int) (Party, error) - - // Refresh refreshes shares from an existing config - Refresh(config KeyGenConfig) (Party, error) - - // Sign starts a signing protocol - Sign(config KeyGenConfig, signers []string, messageHash []byte) (Party, error) - - // PreSign starts a presigning protocol - PreSign(config KeyGenConfig, signers []string) (Party, error) - - // PreSignOnline completes a signature with a presignature - PreSignOnline(config KeyGenConfig, preSignature PreSignature, messageHash []byte) (Party, error) - - // Name returns the protocol name (e.g., "GG20", "CGGMP21") - Name() string -} From 58feb9f7fe453b3c01a41ba853f75b09e9b77349 Mon Sep 17 00:00:00 2001 From: vietddude Date: Mon, 6 Oct 2025 17:29:15 +0700 Subject: [PATCH 08/21] fix: remove cmp test --- pkg/mpc/taurus/cmp_test.go | 116 ------------------------------------- 1 file changed, 116 deletions(-) delete mode 100644 pkg/mpc/taurus/cmp_test.go diff --git a/pkg/mpc/taurus/cmp_test.go b/pkg/mpc/taurus/cmp_test.go deleted file mode 100644 index 59fef80..0000000 --- a/pkg/mpc/taurus/cmp_test.go +++ /dev/null @@ -1,116 +0,0 @@ -package taurus - -import ( - "bytes" - "context" - "math/big" - "sync" - "testing" - - "github.com/fystack/mpcium/pkg/logger" - "github.com/fystack/mpcium/pkg/messaging" - "github.com/nats-io/nats.go" - "github.com/taurusgroup/multi-party-sig/pkg/party" - "github.com/taurusgroup/multi-party-sig/pkg/pool" -) - -type cmpTest struct { - parties []*CmpParty - results map[string]chan any -} - -func newCmpTest(sid string, ids []party.ID) *cmpTest { - pl := pool.NewPool(0) - nc, err := nats.Connect("nats://localhost:4223") - if err != nil { - logger.Fatal("Failed to connect to NATS", err) - } - - pubsub := messaging.NewNATSPubSub(nc) - direct := messaging.NewNatsDirectMessaging(nc) - - t := &cmpTest{ - results: map[string]chan any{ - "keygen": make(chan any, len(ids)), - "sign": make(chan any, len(ids)), - "reshare": make(chan any, len(ids)), - }, - } - - for _, id := range ids { - net := NewNATSTransport(sid, id, ActKeygen, pubsub, direct, nil) - adapter := NewTaurusNetworkAdapter(sid, id, net, ids) - t.parties = append(t.parties, NewCmpParty(sid, id, ids, 2, pl, adapter, nil, nil)) - } - - return t -} - -func (t *cmpTest) runAll(fn func(*CmpParty) (any, error), key string) { - var wg sync.WaitGroup - for _, p := range t.parties { - wg.Add(1) - go func(p *CmpParty) { - defer wg.Done() - res, err := fn(p) - if err != nil { - logger.Error("operation failed", err) - return - } - t.results[key] <- res - }(p) - } - wg.Wait() -} - -func TestCmpParty(t *testing.T) { - sid := "test-session-123" - ids := []party.ID{"node0", "node1", "node2"} - test := newCmpTest(sid, ids) - - // --- Keygen --- - test.runAll(func(p *CmpParty) (any, error) { - return p.Keygen(context.Background()) - }, "keygen") - - // --- Sign 1 --- - msg := big.NewInt(1) - test.runAll(func(p *CmpParty) (any, error) { - return p.Sign(context.Background(), msg) - }, "sign") - - sigs := drain[[]byte](test.results["sign"]) - assertAllBytesEqual(t, sigs) - - // --- Reshare --- - test.runAll(func(p *CmpParty) (any, error) { - return p.Reshare(context.Background()) - }, "reshare") - - // --- Sign 2 --- - msg = big.NewInt(2) - test.runAll(func(p *CmpParty) (any, error) { - return p.Sign(context.Background(), msg) - }, "sign") -} - -func drain[T any](ch chan any) []T { - n := len(ch) - out := make([]T, n) - for i := 0; i < n; i++ { - out[i] = (<-ch).(T) - } - return out -} - -func assertAllBytesEqual(t *testing.T, vals [][]byte) { - if len(vals) == 0 { - t.Fatal("no values to compare") - } - first := vals[0] - for i, v := range vals[1:] { - if !bytes.Equal(first, v) { - t.Fatalf("byte slices not equal at index %d", i+1) - } - } -} From df2bd3402ec835e51d740a6822ee7f870500cc9e Mon Sep 17 00:00:00 2001 From: vietddude Date: Tue, 7 Oct 2025 10:13:58 +0700 Subject: [PATCH 09/21] update: add cmp test --- pkg/mpc/taurus/cmp_test.go | 114 +++++++++++++++++++++++++++++++++++++ 1 file changed, 114 insertions(+) create mode 100644 pkg/mpc/taurus/cmp_test.go diff --git a/pkg/mpc/taurus/cmp_test.go b/pkg/mpc/taurus/cmp_test.go new file mode 100644 index 0000000..a338b2b --- /dev/null +++ b/pkg/mpc/taurus/cmp_test.go @@ -0,0 +1,114 @@ +package taurus + +import ( + "bytes" + "context" + "math/big" + "sync" + "testing" + + "github.com/fystack/mpcium/pkg/logger" + "github.com/taurusgroup/multi-party-sig/pkg/party" +) + +type cmpTest struct { + parties []*CmpParty + results map[string]chan any +} + +func newCmpTest(sid string, ids []party.ID) *cmpTest { + t := &cmpTest{ + results: map[string]chan any{ + "keygen": make(chan any, len(ids)), + "sign": make(chan any, len(ids)), + "reshare": make(chan any, len(ids)), + }, + } + + // Create all memory transports first + transports := make([]*Memory, len(ids)) + for i, id := range ids { + transports[i] = NewMemoryParty(string(id)) + } + + // Link all peers together + LinkPeers(transports...) + + // Create parties with linked transports + for i, id := range ids { + adapter := NewTaurusNetworkAdapter(sid, id, transports[i], ids) + t.parties = append(t.parties, NewCmpParty(sid, id, ids, 2, + nil, adapter, nil, nil)) + } + + return t +} + +func (t *cmpTest) runAll(fn func(*CmpParty) (any, error), key string) { + var wg sync.WaitGroup + for _, p := range t.parties { + wg.Add(1) + go func(p *CmpParty) { + defer wg.Done() + res, err := fn(p) + if err != nil { + logger.Error("operation failed", err) + return + } + t.results[key] <- res + }(p) + } + wg.Wait() +} + +func TestCmpParty(t *testing.T) { + sid := "test-session-123" + ids := []party.ID{"node0", "node1", "node2"} + test := newCmpTest(sid, ids) + + // --- Keygen --- + test.runAll(func(p *CmpParty) (any, error) { + return p.Keygen(context.Background()) + }, "keygen") + + // --- Sign 1 --- + msg := big.NewInt(1) + test.runAll(func(p *CmpParty) (any, error) { + return p.Sign(context.Background(), msg) + }, "sign") + + sigs := drain[[]byte](test.results["sign"]) + assertAllBytesEqual(t, sigs) + + // // --- Reshare --- + // test.runAll(func(p *CmpParty) (any, error) { + // return p.Reshare(context.Background()) + // }, "reshare") + + // // // --- Sign 2 --- + // msg = big.NewInt(2) + // test.runAll(func(p *CmpParty) (any, error) { + // return p.Sign(context.Background(), msg) + // }, "sign") +} + +func drain[T any](ch chan any) []T { + n := len(ch) + out := make([]T, n) + for i := 0; i < n; i++ { + out[i] = (<-ch).(T) + } + return out +} + +func assertAllBytesEqual(t *testing.T, vals [][]byte) { + if len(vals) == 0 { + t.Fatal("no values to compare") + } + first := vals[0] + for i, v := range vals[1:] { + if !bytes.Equal(first, v) { + t.Fatalf("byte slices not equal at index %d", i+1) + } + } +} From 56de4453445a9efe59b9a25a010d0d7335b6b905 Mon Sep 17 00:00:00 2001 From: vietddude Date: Tue, 7 Oct 2025 10:25:58 +0700 Subject: [PATCH 10/21] refactor: remove unused Done channel from TaurusNetworkAdapter --- pkg/mpc/taurus/adapter.go | 32 ++++++++++---------------------- pkg/mpc/taurus/cmp.go | 2 -- 2 files changed, 10 insertions(+), 24 deletions(-) diff --git a/pkg/mpc/taurus/adapter.go b/pkg/mpc/taurus/adapter.go index 4c33ac5..84a6b60 100644 --- a/pkg/mpc/taurus/adapter.go +++ b/pkg/mpc/taurus/adapter.go @@ -10,7 +10,6 @@ import ( type NetworkInterface interface { Next() <-chan *protocol.Message Send(msg *protocol.Message) - Done() <-chan struct{} } type TaurusNetworkAdapter struct { @@ -18,7 +17,6 @@ type TaurusNetworkAdapter struct { selfID party.ID transport Transport inbox chan *protocol.Message - done chan struct{} peers party.IDSlice } @@ -33,7 +31,6 @@ func NewTaurusNetworkAdapter( selfID: selfID, transport: t, inbox: make(chan *protocol.Message, 100), - done: make(chan struct{}), peers: peers, } go a.route() @@ -41,7 +38,6 @@ func NewTaurusNetworkAdapter( } func (a *TaurusNetworkAdapter) Next() <-chan *protocol.Message { return a.inbox } -func (a *TaurusNetworkAdapter) Done() <-chan struct{} { return a.done } func (a *TaurusNetworkAdapter) Send(msg *protocol.Message) { wire, err := msg.MarshalBinary() @@ -64,25 +60,17 @@ func (a *TaurusNetworkAdapter) Send(msg *protocol.Message) { } func (a *TaurusNetworkAdapter) route() { - defer close(a.done) - for { + for tm := range a.transport.Inbox() { + var pm protocol.Message + if err := pm.UnmarshalBinary(tm.Data); err != nil { + logger.Error("unmarshal protocol msg", err) + continue + } + select { - case tm, ok := <-a.transport.Inbox(): - if !ok { - return - } - var pm protocol.Message - if err := pm.UnmarshalBinary(tm.Data); err != nil { - logger.Error("unmarshal protocol msg", err) - continue - } - select { - case a.inbox <- &pm: - default: - logger.Warn("inbox full, drop msg", "self", a.selfID) - } - case <-a.transport.Done(): - return + case a.inbox <- &pm: + default: + logger.Warn("inbox full, drop msg", "self", a.selfID) } } } diff --git a/pkg/mpc/taurus/cmp.go b/pkg/mpc/taurus/cmp.go index d099d85..41a131f 100644 --- a/pkg/mpc/taurus/cmp.go +++ b/pkg/mpc/taurus/cmp.go @@ -224,8 +224,6 @@ func (p *CmpParty) run(ctx context.Context, proto protocol.StartFunc) (any, erro "broadcast", msg.Broadcast, ) } - case <-p.network.Done(): - return h.Result() } } } From 1f1aa9392dbe1c8bf8ce57f9c99817ec8e0752df Mon Sep 17 00:00:00 2001 From: vietddude Date: Wed, 15 Oct 2025 14:29:48 +0700 Subject: [PATCH 11/21] feat: add Frost, Taproot and change cmp to CGGMP21 --- pkg/event/keygen.go | 9 +- pkg/eventconsumer/event_consumer.go | 162 +++++++++++++------- pkg/mpc/node.go | 25 +++- pkg/mpc/session.go | 8 +- pkg/mpc/taurus/adapter.go | 23 ++- pkg/mpc/taurus/{cmp.go => cggmp21.go} | 153 ++++++------------- pkg/mpc/taurus/cmp_test.go | 114 -------------- pkg/mpc/taurus/common.go | 123 ++++++++++++++++ pkg/mpc/taurus/frost.go | 197 +++++++++++++++++++++++++ pkg/mpc/taurus/nats_transport.go | 7 +- pkg/mpc/taurus/taproot.go | 204 ++++++++++++++++++++++++++ pkg/mpc/taurus/taurus_test.go | 102 +++++++++++++ pkg/types/initiator_msg.go | 4 +- 13 files changed, 826 insertions(+), 305 deletions(-) rename pkg/mpc/taurus/{cmp.go => cggmp21.go} (50%) delete mode 100644 pkg/mpc/taurus/cmp_test.go create mode 100644 pkg/mpc/taurus/common.go create mode 100644 pkg/mpc/taurus/frost.go create mode 100644 pkg/mpc/taurus/taproot.go create mode 100644 pkg/mpc/taurus/taurus_test.go diff --git a/pkg/event/keygen.go b/pkg/event/keygen.go index f72b1ae..5cc9ec0 100644 --- a/pkg/event/keygen.go +++ b/pkg/event/keygen.go @@ -7,10 +7,11 @@ const ( ) type KeygenResultEvent struct { - WalletID string `json:"wallet_id"` - ECDSAPubKey []byte `json:"ecdsa_pub_key"` - EDDSAPubKey []byte `json:"eddsa_pub_key"` - TaurusCMPPubKey []byte `json:"taurus_cmp_pub_key"` + WalletID string `json:"wallet_id"` + ECDSAPubKey []byte `json:"ecdsa_pub_key"` + EDDSAPubKey []byte `json:"eddsa_pub_key"` + CGGMP21PubKey []byte `json:"cggmp21_pub_key"` + TaprootPubKey []byte `json:"taproot_pub_key"` ResultType ResultType `json:"result_type"` ErrorReason string `json:"error_reason"` diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index 3caccaa..c2daaf1 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -178,10 +178,16 @@ func (ec *eventConsumer) handleKeyGenEvent(natMsg *nats.Msg) { ec.handleKeygenSessionError(walletID, err, "Failed to create EdDSA key generation session", natMsg) return } - taurusSession, err := ec.node.CreateCMPSession(walletID, ec.mpcThreshold, taurus.ActKeygen) + cggmp21Session, err := ec.node.CreateTaurusSession(walletID, ec.mpcThreshold, types.KeyTypeCGGMP21, taurus.ActKeygen) if err != nil { - logger.Error("Failed to create Taurus CMP session", err, "walletID", walletID) - ec.handleKeygenSessionError(walletID, err, "Failed to create Taurus CMP key generation session", natMsg) + logger.Error("Failed to create CMP session", err, "walletID", walletID) + ec.handleKeygenSessionError(walletID, err, "Failed to create CMP key generation session", natMsg) + return + } + taprootSession, err := ec.node.CreateTaurusSession(walletID, ec.mpcThreshold, types.KeyTypeTaproot, taurus.ActKeygen) + if err != nil { + logger.Error("Failed to create Taproot session", err, "walletID", walletID) + ec.handleKeygenSessionError(walletID, err, "Failed to create Taproot key generation session", natMsg) return } @@ -190,14 +196,15 @@ func (ec *eventConsumer) handleKeyGenEvent(natMsg *nats.Msg) { ctxEcdsa, doneEcdsa := context.WithCancel(baseCtx) ctxEddsa, doneEddsa := context.WithCancel(baseCtx) - ctxTaurus, doneTaurus := context.WithCancel(baseCtx) + ctxCggmp21, doneCggmp21 := context.WithCancel(baseCtx) + ctxTaproot, doneTaproot := context.WithCancel(baseCtx) successEvent := &event.KeygenResultEvent{WalletID: walletID, ResultType: event.ResultTypeSuccess} var wg sync.WaitGroup - wg.Add(3) + wg.Add(4) // Channel to communicate errors from goroutines to main function - errorChan := make(chan error, 3) + errorChan := make(chan error, 4) go func() { defer wg.Done() @@ -225,16 +232,30 @@ func (ec *eventConsumer) handleKeyGenEvent(natMsg *nats.Msg) { }() go func() { defer wg.Done() - data, err := taurusSession.Keygen(ctxTaurus) + data, err := cggmp21Session.Keygen(ctxCggmp21) if err != nil { logger.Error("Failed to generate key", err) errorChan <- err return } - logger.Info("CMP Keygen completed successfully", "walletID", walletID, "payloadLength", len(data.Payload)) - successEvent.TaurusCMPPubKey = data.PubKeyBytes - doneTaurus() + logger.Info("CGGMP21 Keygen completed successfully", "walletID", walletID, "payloadLength", len(data.Payload)) + successEvent.CGGMP21PubKey = data.PubKeyBytes + doneCggmp21() + }() + + go func() { + defer wg.Done() + data, err := taprootSession.Keygen(ctxTaproot) + if err != nil { + logger.Error("Failed to generate key", err) + errorChan <- err + return + } + + logger.Info("Taproot Keygen completed successfully", "walletID", walletID, "payloadLength", len(data.Payload)) + successEvent.TaprootPubKey = data.PubKeyBytes + doneTaproot() }() ecdsaSession.ListenToIncomingMessageAsync() @@ -423,8 +444,8 @@ func (ec *eventConsumer) handleSigningEvent(natMsg *nats.Msg) { ec.signingResultQueue, idempotentKey, ) - case types.KeyTypeTaurusCmp: - ec.handleCMPSigning(msg, natMsg) + case types.KeyTypeCGGMP21, types.KeyTypeTaproot, types.KeyTypeFROST: + ec.handleTaurusSigning(msg.KeyType, msg, natMsg) return default: sessionErr = fmt.Errorf("unsupported key type: %v", msg.KeyType) @@ -518,20 +539,17 @@ func (ec *eventConsumer) handleSigningEvent(natMsg *nats.Msg) { go session.Sign(onSuccess) } -// Add this method to handle CMP signing -func (ec *eventConsumer) handleCMPSigning(msg types.SignTxMessage, natMsg *nats.Msg) { - logger.Info("Starting CMP signing", "walletID", msg.WalletID, "txID", msg.TxID) - - // Create CMP session for signing - taurusSession, err := ec.node.CreateCMPSession(msg.WalletID, ec.mpcThreshold, taurus.ActSign) +func (ec *eventConsumer) handleTaurusSigning(keyType types.KeyType, msg types.SignTxMessage, natMsg *nats.Msg) { + logger.Info("Starting signing", "walletID", msg.WalletID, "txID", msg.TxID, "keyType", keyType) + session, err := ec.node.CreateTaurusSession(msg.WalletID, ec.mpcThreshold, keyType, taurus.ActSign) if err != nil { - logger.Error("Failed to create Taurus CMP signing session", err, "walletID", msg.WalletID) + logger.Error("Failed to create session", err, "walletID", msg.WalletID) ec.handleSigningSessionError( msg.WalletID, msg.TxID, msg.NetworkInternalCode, err, - "Failed to create Taurus CMP signing session", + fmt.Sprintf("Failed to create %s session: %v", keyType, err), natMsg, ) return @@ -544,16 +562,15 @@ func (ec *eventConsumer) handleCMPSigning(msg types.SignTxMessage, natMsg *nats. ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - // Perform CMP signing - signature, err := taurusSession.Sign(ctx, txBigInt) + signature, err := session.Sign(ctx, txBigInt) if err != nil { - logger.Error("CMP signing failed", err, "walletID", msg.WalletID, "txID", msg.TxID) + logger.Error("signing failed", err, "keyType", keyType, "walletID", msg.WalletID, "txID", msg.TxID) ec.handleSigningSessionError( msg.WalletID, msg.TxID, msg.NetworkInternalCode, err, - "CMP signing failed", + fmt.Sprintf("%s signing failed", keyType), natMsg, ) return @@ -565,19 +582,19 @@ func (ec *eventConsumer) handleCMPSigning(msg types.SignTxMessage, natMsg *nats. NetworkInternalCode: msg.NetworkInternalCode, WalletID: msg.WalletID, TxID: msg.TxID, - Signature: signature, // CMP returns the full signature + Signature: signature, // Returns the full signature } // Marshal and enqueue the result signingResultBytes, err := json.Marshal(signingResult) if err != nil { - logger.Error("Failed to marshal CMP signing result event", err, "walletID", msg.WalletID, "txID", msg.TxID) + logger.Error("Failed to marshal signing result event", err, "keyType", keyType, "walletID", msg.WalletID, "txID", msg.TxID) ec.handleSigningSessionError( msg.WalletID, msg.TxID, msg.NetworkInternalCode, err, - "Failed to marshal CMP signing result", + fmt.Sprintf("Failed to marshal %s signing result", keyType), natMsg, ) return @@ -588,13 +605,13 @@ func (ec *eventConsumer) handleCMPSigning(msg types.SignTxMessage, natMsg *nats. IdempotententKey: composeSigningIdempotentKey(msg.TxID, natMsg), }) if err != nil { - logger.Error("Failed to enqueue CMP signing result event", err, "walletID", msg.WalletID, "txID", msg.TxID) + logger.Error("Failed to enqueue signing result event", err, "keyType", keyType, "walletID", msg.WalletID, "txID", msg.TxID) ec.handleSigningSessionError( msg.WalletID, msg.TxID, msg.NetworkInternalCode, err, - "Failed to enqueue CMP signing result", + fmt.Sprintf("Failed to enqueue %s signing result", keyType), natMsg, ) return @@ -602,7 +619,7 @@ func (ec *eventConsumer) handleCMPSigning(msg types.SignTxMessage, natMsg *nats. // Send reply and log success ec.sendReplyToRemoveMsg(natMsg) - logger.Info("[COMPLETED CMP SIGN] CMP signing completed successfully", "walletID", msg.WalletID, "txID", msg.TxID) + logger.Info("[COMPLETED SIGN] signing completed successfully", "keyType", keyType, "walletID", msg.WalletID, "txID", msg.TxID) } func (ec *eventConsumer) consumeTxSigningEvent() error { @@ -712,8 +729,8 @@ func (ec *eventConsumer) consumeReshareEvent() error { return } // Handle CMP reshare separately - if keyType == types.KeyTypeTaurusCmp { - ec.handleCMPReshare(msg, natMsg) + if keyType == types.KeyTypeCGGMP21 || keyType == types.KeyTypeTaproot || keyType == types.KeyTypeFROST { + ec.handleTaurusReshare(msg, natMsg) return } @@ -855,23 +872,37 @@ func (ec *eventConsumer) consumeReshareEvent() error { return err } -// NOTE: In CMP reshare, it just refresh the keyshare of each node but keep the same public key and threshold. +// NOTE: In Taurus reshare, it just refresh the keyshare of each node but keep the same public key and threshold. // Therefore, we don't need to create new party sessions for CMP reshare. -func (ec *eventConsumer) handleCMPReshare(msg types.ResharingMessage, natMsg *nats.Msg) { - logger.Info("Starting CMP reshare", "walletID", msg.WalletID, "sessionID", msg.SessionID) +func (ec *eventConsumer) handleTaurusReshare(msg types.ResharingMessage, natMsg *nats.Msg) { + logger.Info("Starting reshare", "walletID", msg.WalletID, "sessionID", msg.SessionID, "keyType", msg.KeyType) - // Create CMP session for reshare - taurusSession, err := ec.node.CreateCMPSession(msg.WalletID, msg.NewThreshold, taurus.ActReshare) + // Create Taurus session for reshare + session, err := ec.node.CreateTaurusSession(msg.WalletID, msg.NewThreshold, types.KeyTypeCGGMP21, taurus.ActReshare) if err != nil { - logger.Error("Failed to create Taurus CMP reshare session", err, "walletID", msg.WalletID) - ec.handleReshareSessionError(msg.WalletID, msg.KeyType, msg.NewThreshold, err, "Failed to create Taurus CMP reshare session", natMsg) + logger.Error("Failed to create reshare session", err, "walletID", msg.WalletID, "keyType", msg.KeyType) + ec.handleReshareSessionError( + msg.WalletID, + msg.KeyType, + msg.NewThreshold, + err, + fmt.Sprintf("Failed to create %s reshare session", msg.KeyType), + natMsg, + ) return } // Load the existing key for reshare - if err := taurusSession.LoadKey(msg.WalletID); err != nil { - logger.Error("Failed to load key for CMP reshare", err, "walletID", msg.WalletID) - ec.handleReshareSessionError(msg.WalletID, msg.KeyType, msg.NewThreshold, err, "Failed to load key for CMP reshare", natMsg) + if err := session.LoadKey(msg.WalletID); err != nil { + logger.Error("Failed to load key for reshare", err, "walletID", msg.WalletID, "keyType", msg.KeyType) + ec.handleReshareSessionError( + msg.WalletID, + msg.KeyType, + msg.NewThreshold, + err, + fmt.Sprintf("Failed to load key for %s reshare", msg.KeyType), + natMsg, + ) return } @@ -879,11 +910,18 @@ func (ec *eventConsumer) handleCMPReshare(msg types.ResharingMessage, natMsg *na ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) // Longer timeout for reshare defer cancel() - // Perform CMP reshare - keyData, err := taurusSession.Reshare(ctx) + // Perform reshare + keyData, err := session.Reshare(ctx) if err != nil { - logger.Error("CMP reshare failed", err, "walletID", msg.WalletID, "sessionID", msg.SessionID) - ec.handleReshareSessionError(msg.WalletID, msg.KeyType, msg.NewThreshold, err, "CMP reshare failed", natMsg) + logger.Error("Reshare failed", err, "walletID", msg.WalletID, "sessionID", msg.SessionID, "keyType", msg.KeyType) + ec.handleReshareSessionError( + msg.WalletID, + msg.KeyType, + msg.NewThreshold, + err, + fmt.Sprintf("Reshare failed for %s", msg.KeyType), + natMsg, + ) return } @@ -899,8 +937,15 @@ func (ec *eventConsumer) handleCMPReshare(msg types.ResharingMessage, natMsg *na // Marshal and enqueue the result reshareResultBytes, err := json.Marshal(reshareResult) if err != nil { - logger.Error("Failed to marshal CMP reshare result event", err, "walletID", msg.WalletID, "sessionID", msg.SessionID) - ec.handleReshareSessionError(msg.WalletID, msg.KeyType, msg.NewThreshold, err, "Failed to marshal CMP reshare result", natMsg) + logger.Error("Failed to marshal reshare result event", err, "walletID", msg.WalletID, "sessionID", msg.SessionID, "keyType", msg.KeyType) + ec.handleReshareSessionError( + msg.WalletID, + msg.KeyType, + msg.NewThreshold, + err, + fmt.Sprintf("Failed to marshal %s reshare result", msg.KeyType), + natMsg, + ) return } @@ -910,15 +955,22 @@ func (ec *eventConsumer) handleCMPReshare(msg types.ResharingMessage, natMsg *na IdempotententKey: composeReshareIdempotentKey(msg.SessionID, natMsg), }) if err != nil { - logger.Error("Failed to enqueue CMP reshare result event", err, "walletID", msg.WalletID, "sessionID", msg.SessionID) - ec.handleReshareSessionError(msg.WalletID, msg.KeyType, msg.NewThreshold, err, "Failed to enqueue CMP reshare result", natMsg) + logger.Error("Failed to enqueue reshare result event", err, "walletID", msg.WalletID, "sessionID", msg.SessionID, "keyType", msg.KeyType) + ec.handleReshareSessionError( + msg.WalletID, + msg.KeyType, + msg.NewThreshold, + err, + fmt.Sprintf("Failed to enqueue %s reshare result", msg.KeyType), + natMsg, + ) return } // Remove this line - don't send reply for reshare messages // ec.sendReplyToRemoveMsg(natMsg) - logger.Info("[COMPLETED CMP RESHARE] CMP reshare completed successfully", "walletID", msg.WalletID, "sessionID", msg.SessionID) + logger.Info("[COMPLETED RESHARE] CMP reshare completed successfully", "walletID", msg.WalletID, "sessionID", msg.SessionID) } // handleReshareSessionError handles errors that occur during reshare operations @@ -1056,8 +1108,12 @@ func sessionTypeFromKeyType(keyType types.KeyType) (mpc.SessionType, error) { return mpc.SessionTypeECDSA, nil case types.KeyTypeEd25519: return mpc.SessionTypeEDDSA, nil - case types.KeyTypeTaurusCmp: - return mpc.SessionTypeTaurusCmp, nil + case types.KeyTypeCGGMP21: + return mpc.SessionTypeCGGMP21, nil + case types.KeyTypeTaproot: + return mpc.SessionTypeTaproot, nil + case types.KeyTypeFROST: + return mpc.SessionTypeFROST, nil default: logger.Warn("Unsupported key type", "keyType", keyType) return "", fmt.Errorf("unsupported key type: %v", keyType) diff --git a/pkg/mpc/node.go b/pkg/mpc/node.go index a7e4b20..c9e9f8b 100644 --- a/pkg/mpc/node.go +++ b/pkg/mpc/node.go @@ -15,8 +15,8 @@ import ( "github.com/fystack/mpcium/pkg/logger" "github.com/fystack/mpcium/pkg/messaging" "github.com/fystack/mpcium/pkg/mpc/taurus" + "github.com/fystack/mpcium/pkg/types" "github.com/taurusgroup/multi-party-sig/pkg/party" - "github.com/taurusgroup/multi-party-sig/pkg/pool" ) const ( @@ -143,24 +143,33 @@ func (p *Node) createEDDSAKeyGenSession(walletID string, threshold int, version return session, nil } -func (p *Node) CreateCMPSession( +func (p *Node) CreateTaurusSession( walletID string, threshold int, + sessionType types.KeyType, act taurus.Act, -) (*taurus.CmpParty, error) { +) (taurus.TaurusSession, error) { readyPeerIDs := p.peerRegistry.GetReadyPeersIncludeSelf() selfPartyID, allPartyIDs := p.generateTaurusPartyIDs(PurposeKeygen, readyPeerIDs, DefaultVersion) - tr := taurus.NewNATSTransport(walletID, selfPartyID, act, p.pubSub, p.direct, p.identityStore) - adapter := taurus.NewTaurusNetworkAdapter(walletID, selfPartyID, tr, allPartyIDs) - pl := pool.NewPool(0) - session := taurus.NewCmpParty(walletID, selfPartyID, allPartyIDs, threshold, pl, adapter, p.keyinfoStore, p.kvstore) + var session taurus.TaurusSession + switch sessionType { + case types.KeyTypeCGGMP21: + tr := taurus.NewNATSTransport(walletID, selfPartyID, act, taurus.CGGMP21, p.pubSub, p.direct, p.identityStore) + session = taurus.NewCGGMP21Session(walletID, selfPartyID, allPartyIDs, threshold, tr, p.kvstore, p.keyinfoStore) + case types.KeyTypeTaproot: + tr := taurus.NewNATSTransport(walletID, selfPartyID, act, taurus.FROSTTaproot, p.pubSub, p.direct, p.identityStore) + session = taurus.NewTaprootSession(walletID, selfPartyID, allPartyIDs, threshold, tr, p.kvstore, p.keyinfoStore) + case types.KeyTypeFROST: + tr := taurus.NewNATSTransport(walletID, selfPartyID, act, taurus.FROST, p.pubSub, p.direct, p.identityStore) + session = taurus.NewFROSTSession(walletID, selfPartyID, allPartyIDs, threshold, tr, p.kvstore, p.keyinfoStore) + } + if act == taurus.ActSign || act == taurus.ActReshare { err := session.LoadKey(walletID) if err != nil { return nil, err } } - return session, nil } diff --git a/pkg/mpc/session.go b/pkg/mpc/session.go index 5f5c147..43e64ea 100644 --- a/pkg/mpc/session.go +++ b/pkg/mpc/session.go @@ -25,9 +25,11 @@ const ( TypeReshareWalletResultFmt = "mpc.mpc_reshare_result.%s" TypeSigningResultFmt = "mpc.mpc_signing_result.%s" - SessionTypeECDSA SessionType = "session_ecdsa" - SessionTypeEDDSA SessionType = "session_eddsa" - SessionTypeTaurusCmp SessionType = "session_taurus_cmp" + SessionTypeECDSA SessionType = "session_ecdsa" + SessionTypeEDDSA SessionType = "session_eddsa" + SessionTypeCGGMP21 SessionType = "session_cggmp21" + SessionTypeTaproot SessionType = "session_taproot" + SessionTypeFROST SessionType = "session_frost" ) var ( diff --git a/pkg/mpc/taurus/adapter.go b/pkg/mpc/taurus/adapter.go index 84a6b60..d46a398 100644 --- a/pkg/mpc/taurus/adapter.go +++ b/pkg/mpc/taurus/adapter.go @@ -7,39 +7,34 @@ import ( "github.com/taurusgroup/multi-party-sig/pkg/protocol" ) -type NetworkInterface interface { - Next() <-chan *protocol.Message - Send(msg *protocol.Message) -} - -type TaurusNetworkAdapter struct { +type NetworkAdapter struct { sid string selfID party.ID + peers party.IDSlice transport Transport inbox chan *protocol.Message - peers party.IDSlice } -func NewTaurusNetworkAdapter( +func NewNetworkAdapter( sid string, selfID party.ID, t Transport, peers party.IDSlice, -) *TaurusNetworkAdapter { - a := &TaurusNetworkAdapter{ +) *NetworkAdapter { + a := &NetworkAdapter{ sid: sid, selfID: selfID, + peers: peers, transport: t, inbox: make(chan *protocol.Message, 100), - peers: peers, } go a.route() return a } -func (a *TaurusNetworkAdapter) Next() <-chan *protocol.Message { return a.inbox } +func (a *NetworkAdapter) Next() <-chan *protocol.Message { return a.inbox } -func (a *TaurusNetworkAdapter) Send(msg *protocol.Message) { +func (a *NetworkAdapter) Send(msg *protocol.Message) { wire, err := msg.MarshalBinary() if err != nil { logger.Error("marshal protocol msg", err) @@ -59,7 +54,7 @@ func (a *TaurusNetworkAdapter) Send(msg *protocol.Message) { } } -func (a *TaurusNetworkAdapter) route() { +func (a *NetworkAdapter) route() { for tm := range a.transport.Inbox() { var pm protocol.Message if err := pm.UnmarshalBinary(tm.Data); err != nil { diff --git a/pkg/mpc/taurus/cmp.go b/pkg/mpc/taurus/cggmp21.go similarity index 50% rename from pkg/mpc/taurus/cmp.go rename to pkg/mpc/taurus/cggmp21.go index 41a131f..2b27cd5 100644 --- a/pkg/mpc/taurus/cmp.go +++ b/pkg/mpc/taurus/cggmp21.go @@ -8,7 +8,6 @@ import ( "math/big" "github.com/btcsuite/btcd/btcec/v2" - "github.com/decred/dcrd/dcrec/secp256k1/v4" "github.com/fystack/mpcium/pkg/encoding" "github.com/fystack/mpcium/pkg/keyinfo" "github.com/fystack/mpcium/pkg/kvstore" @@ -18,44 +17,33 @@ import ( "github.com/taurusgroup/multi-party-sig/pkg/math/curve" "github.com/taurusgroup/multi-party-sig/pkg/party" "github.com/taurusgroup/multi-party-sig/pkg/pool" - "github.com/taurusgroup/multi-party-sig/pkg/protocol" "github.com/taurusgroup/multi-party-sig/protocols/cmp" ) -type CmpParty struct { - sid string - id party.ID - ids party.IDSlice - threshold int - pl *pool.Pool - savedData *cmp.Config - kvstore kvstore.KVStore - keyinfoStore keyinfo.Store - network NetworkInterface +type CGGMP21Session struct { + *commonSession + workerPool *pool.Pool + savedData *cmp.Config } -func NewCmpParty( - sid string, - id party.ID, - ids party.IDSlice, +func NewCGGMP21Session( + sessionID string, + selfID party.ID, + peerIDs party.IDSlice, threshold int, - pl *pool.Pool, - network NetworkInterface, - keyinfoStore keyinfo.Store, + transport Transport, kvstore kvstore.KVStore, -) *CmpParty { - return &CmpParty{ - sid: sid, - id: id, - ids: ids, - threshold: threshold, - pl: pl, - network: network, - keyinfoStore: keyinfoStore, - kvstore: kvstore, + keyinfoStore keyinfo.Store, +) TaurusSession { + commonSession := NewCommonSession(sessionID, selfID, peerIDs, threshold, transport, kvstore, keyinfoStore) + return &CGGMP21Session{ + commonSession: commonSession, + workerPool: pool.NewPool(0), + savedData: nil, } } -func (p *CmpParty) LoadKey(sid string) error { + +func (p *CGGMP21Session) LoadKey(sid string) error { key := p.composeKey(sid) data, err := p.kvstore.Get(key) @@ -72,12 +60,12 @@ func (p *CmpParty) LoadKey(sid string) error { return nil } -func (p *CmpParty) Keygen(ctx context.Context) (types.KeyData, error) { - logger.Info("Starting to generate key Taurus CMP", "walletID", p.sid) +func (p *CGGMP21Session) Keygen(ctx context.Context) (types.KeyData, error) { + logger.Info("Starting to generate key CGGMP21", "walletID", p.sessionID) - result, err := p.run(ctx, cmp.Keygen(curve.Secp256k1{}, p.id, p.ids, p.threshold, p.pl)) + result, err := p.run(ctx, cmp.Keygen(curve.Secp256k1{}, p.selfID, p.peerIDs, p.threshold, p.workerPool)) if err != nil { - return types.KeyData{}, fmt.Errorf("cmp keygen: %w", err) + return types.KeyData{}, err } cfg, ok := result.(*cmp.Config) @@ -86,8 +74,12 @@ func (p *CmpParty) Keygen(ctx context.Context) (types.KeyData, error) { } p.savedData = cfg - // Extract public key coordinates - x, y, err := ExtractXYFromPoint(cfg.PublicPoint()) + packed, err := cfg.MarshalBinary() + if err != nil { + return types.KeyData{}, fmt.Errorf("marshal config: %w", err) + } + + x, y, err := extractPublicKey(cfg.PublicPoint()) if err != nil { return types.KeyData{}, fmt.Errorf("extract pubkey: %w", err) } @@ -104,12 +96,7 @@ func (p *CmpParty) Keygen(ctx context.Context) (types.KeyData, error) { return types.KeyData{}, fmt.Errorf("encode pubkey: %w", err) } - packed, err := cfg.MarshalBinary() - if err != nil { - return types.KeyData{}, fmt.Errorf("marshal config: %w", err) - } - - key := p.composeKey(p.sid) + key := p.composeKey(p.sessionID) keyInfo := &keyinfo.KeyInfo{ ParticipantPeerIDs: p.getParticipantPeerIDs(), Threshold: p.threshold, @@ -129,36 +116,37 @@ func (p *CmpParty) Keygen(ctx context.Context) (types.KeyData, error) { } return types.KeyData{ - SID: p.sid, - Type: "taurus_cmp", + SID: p.sessionID, + Type: CGGMP21.String(), PubKeyBytes: pubKeyBytes, }, nil } -func (p *CmpParty) Sign(ctx context.Context, msg *big.Int) ([]byte, error) { +func (p *CGGMP21Session) Sign(ctx context.Context, msg *big.Int) ([]byte, error) { if p.savedData == nil { return nil, errors.New("no key loaded") } - logger.Info("Starting to sign message Taurus CMP", "walletID", p.sid) - cfg, err := p.run(ctx, cmp.Sign(p.savedData, p.ids, msg.Bytes(), p.pl)) + logger.Info("Starting to sign message CGGMP21", "walletID", p.sessionID) + msgHash := msg.Bytes() + result, err := p.run(ctx, cmp.Sign(p.savedData, p.peerIDs, msgHash, p.workerPool)) if err != nil { return nil, err } - sig, ok := cfg.(*ecdsa.Signature) + sig, ok := result.(*ecdsa.Signature) if !ok { return nil, errors.New("unexpected result type") } - if !sig.Verify(p.savedData.PublicPoint(), msg.Bytes()) { + if !sig.Verify(p.savedData.PublicPoint(), msgHash) { return nil, errors.New("signature verification failed") } return sig.SigEthereum() } -func (p *CmpParty) Reshare(ctx context.Context) (res types.ReshareData, err error) { +func (p *CGGMP21Session) Reshare(ctx context.Context) (res types.ReshareData, err error) { if p.savedData == nil { return res, errors.New("no key loaded") } - cfg, err := p.run(ctx, cmp.Refresh(p.savedData, p.pl)) + cfg, err := p.run(ctx, cmp.Refresh(p.savedData, p.workerPool)) if err != nil { return res, err } @@ -169,7 +157,7 @@ func (p *CmpParty) Reshare(ctx context.Context) (res types.ReshareData, err erro p.savedData = savedData packed, _ := p.savedData.MarshalBinary() - key := p.composeKey(p.sid) + key := p.composeKey(p.sessionID) // Store updated key share if p.kvstore != nil { if err := p.kvstore.Put(key, packed); err != nil { @@ -178,7 +166,7 @@ func (p *CmpParty) Reshare(ctx context.Context) (res types.ReshareData, err erro } // Extract public key coordinates - x, y, err := ExtractXYFromPoint(p.savedData.PublicPoint()) + x, y, err := extractPublicKey(p.savedData.PublicPoint()) if err != nil { return res, fmt.Errorf("extract pubkey: %w", err) } @@ -195,59 +183,12 @@ func (p *CmpParty) Reshare(ctx context.Context) (res types.ReshareData, err erro return res, fmt.Errorf("encode pubkey: %w", err) } - return types.ReshareData{KeyData: types.KeyData{SID: p.sid, Type: "taurus_cmp", PubKeyBytes: pubKeyBytes}, Threshold: p.threshold}, nil -} - -func (p *CmpParty) run(ctx context.Context, proto protocol.StartFunc) (any, error) { - logger.Info("Starting to run Taurus CMP", "walletID", p.sid) - h, err := protocol.NewMultiHandler(proto, []byte(p.sid)) - if err != nil { - return nil, err - } - for { - select { - case <-ctx.Done(): - return nil, ctx.Err() - case msg, ok := <-h.Listen(): - if !ok { - return h.Result() - } - p.network.Send(msg) - case msg := <-p.network.Next(): - if h.CanAccept(msg) { - h.Accept(msg) - } else { - logger.Debug("Ignored self broadcast msg", - "self", p.id, - "from", msg.From, - "to", msg.To, - "broadcast", msg.Broadcast, - ) - } - } - } -} - -func ExtractXYFromPoint(p curve.Point) (*big.Int, *big.Int, error) { - data, err := p.MarshalBinary() // compressed SEC1 form (33 bytes) - if err != nil { - return nil, nil, fmt.Errorf("marshal point: %w", err) - } - pk, err := secp256k1.ParsePubKey(data) - if err != nil { - return nil, nil, fmt.Errorf("parse secp256k1 pubkey: %w", err) - } - return pk.X(), pk.Y(), nil -} - -func (p *CmpParty) getParticipantPeerIDs() []string { - var ids []string - for _, id := range p.ids { - ids = append(ids, string(id)) - } - return ids + return types.ReshareData{ + KeyData: types.KeyData{SID: p.sessionID, Type: CGGMP21.String(), PubKeyBytes: pubKeyBytes}, + Threshold: p.threshold, + }, nil } -func (p *CmpParty) composeKey(sid string) string { - return fmt.Sprintf("taurus_cmp:%s", sid) +func (p *CGGMP21Session) composeKey(sid string) string { + return fmt.Sprintf("cggmp21:%s", sid) } diff --git a/pkg/mpc/taurus/cmp_test.go b/pkg/mpc/taurus/cmp_test.go deleted file mode 100644 index a338b2b..0000000 --- a/pkg/mpc/taurus/cmp_test.go +++ /dev/null @@ -1,114 +0,0 @@ -package taurus - -import ( - "bytes" - "context" - "math/big" - "sync" - "testing" - - "github.com/fystack/mpcium/pkg/logger" - "github.com/taurusgroup/multi-party-sig/pkg/party" -) - -type cmpTest struct { - parties []*CmpParty - results map[string]chan any -} - -func newCmpTest(sid string, ids []party.ID) *cmpTest { - t := &cmpTest{ - results: map[string]chan any{ - "keygen": make(chan any, len(ids)), - "sign": make(chan any, len(ids)), - "reshare": make(chan any, len(ids)), - }, - } - - // Create all memory transports first - transports := make([]*Memory, len(ids)) - for i, id := range ids { - transports[i] = NewMemoryParty(string(id)) - } - - // Link all peers together - LinkPeers(transports...) - - // Create parties with linked transports - for i, id := range ids { - adapter := NewTaurusNetworkAdapter(sid, id, transports[i], ids) - t.parties = append(t.parties, NewCmpParty(sid, id, ids, 2, - nil, adapter, nil, nil)) - } - - return t -} - -func (t *cmpTest) runAll(fn func(*CmpParty) (any, error), key string) { - var wg sync.WaitGroup - for _, p := range t.parties { - wg.Add(1) - go func(p *CmpParty) { - defer wg.Done() - res, err := fn(p) - if err != nil { - logger.Error("operation failed", err) - return - } - t.results[key] <- res - }(p) - } - wg.Wait() -} - -func TestCmpParty(t *testing.T) { - sid := "test-session-123" - ids := []party.ID{"node0", "node1", "node2"} - test := newCmpTest(sid, ids) - - // --- Keygen --- - test.runAll(func(p *CmpParty) (any, error) { - return p.Keygen(context.Background()) - }, "keygen") - - // --- Sign 1 --- - msg := big.NewInt(1) - test.runAll(func(p *CmpParty) (any, error) { - return p.Sign(context.Background(), msg) - }, "sign") - - sigs := drain[[]byte](test.results["sign"]) - assertAllBytesEqual(t, sigs) - - // // --- Reshare --- - // test.runAll(func(p *CmpParty) (any, error) { - // return p.Reshare(context.Background()) - // }, "reshare") - - // // // --- Sign 2 --- - // msg = big.NewInt(2) - // test.runAll(func(p *CmpParty) (any, error) { - // return p.Sign(context.Background(), msg) - // }, "sign") -} - -func drain[T any](ch chan any) []T { - n := len(ch) - out := make([]T, n) - for i := 0; i < n; i++ { - out[i] = (<-ch).(T) - } - return out -} - -func assertAllBytesEqual(t *testing.T, vals [][]byte) { - if len(vals) == 0 { - t.Fatal("no values to compare") - } - first := vals[0] - for i, v := range vals[1:] { - if !bytes.Equal(first, v) { - t.Fatalf("byte slices not equal at index %d", i+1) - } - } -} diff --git a/pkg/mpc/taurus/common.go b/pkg/mpc/taurus/common.go new file mode 100644 index 0000000..8fbdc31 --- /dev/null +++ b/pkg/mpc/taurus/common.go @@ -0,0 +1,123 @@ +package taurus + +import ( + "context" + "errors" + "fmt" + "math/big" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/fystack/mpcium/pkg/keyinfo" + "github.com/fystack/mpcium/pkg/kvstore" + "github.com/fystack/mpcium/pkg/types" + "github.com/taurusgroup/multi-party-sig/pkg/math/curve" + "github.com/taurusgroup/multi-party-sig/pkg/party" + "github.com/taurusgroup/multi-party-sig/pkg/protocol" +) + +type Protocol string + +const ( + CGGMP21 Protocol = "cggmp21-ecdsa" // Canetti et al. 2021 (CMP) + FROST Protocol = "frost-schnorr" // FROST Schnorr signatures + FROSTTaproot Protocol = "frost-taproot" // FROST for Bitcoin Taproot +) + +func (p Protocol) String() string { + return string(p) +} + +type TaurusSession interface { + LoadKey(sid string) error + Keygen(ctx context.Context) (types.KeyData, error) + Sign(ctx context.Context, msg *big.Int) ([]byte, error) + Reshare(ctx context.Context) (types.ReshareData, error) +} + +type commonSession struct { + sessionID string + threshold int + selfID party.ID + peerIDs party.IDSlice + network *NetworkAdapter + kvstore kvstore.KVStore + keyinfoStore keyinfo.Store +} + +func NewCommonSession( + sessionID string, + selfID party.ID, + peerIDs party.IDSlice, + threshold int, + transport Transport, + kvstore kvstore.KVStore, + keyinfoStore keyinfo.Store, +) *commonSession { + net := NewNetworkAdapter(sessionID, selfID, transport, peerIDs) + return &commonSession{ + sessionID: sessionID, + selfID: selfID, + peerIDs: peerIDs, + threshold: threshold, + network: net, + kvstore: kvstore, + keyinfoStore: keyinfoStore, + } +} + +func (p *commonSession) run(ctx context.Context, proto protocol.StartFunc) (any, error) { + h, err := protocol.NewMultiHandler(proto, []byte(p.sessionID)) + if err != nil { + return nil, err + } + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case msg, ok := <-h.Listen(): + if !ok { + return h.Result() + } + go p.network.Send(msg) + case msg := <-p.network.Next(): + if h.CanAccept(msg) { + h.Accept(msg) + } + } + } +} + +func (p *commonSession) getParticipantPeerIDs() []string { + var ids []string + for _, id := range p.peerIDs { + ids = append(ids, string(id)) + } + return ids +} + +func extractPublicKey(pubPoint curve.Point) (*big.Int, *big.Int, error) { + if pubPoint == nil { + return nil, nil, errors.New("nil public point") + } + + data, err := pubPoint.MarshalBinary() + if err != nil { + return nil, nil, fmt.Errorf("marshal public key: %w", err) + } + + if len(data) == 0 { + return nil, nil, errors.New("empty public key data") + } + + // Use btcec's ParsePubKey which handles both compressed and uncompressed formats + pubKey, err := btcec.ParsePubKey(data) + if err != nil { + return nil, nil, fmt.Errorf("parse public key: %w", err) + } + + // Extract x and y coordinates + x := pubKey.X() + y := pubKey.Y() + + return x, y, nil +} diff --git a/pkg/mpc/taurus/frost.go b/pkg/mpc/taurus/frost.go new file mode 100644 index 0000000..dc2cc46 --- /dev/null +++ b/pkg/mpc/taurus/frost.go @@ -0,0 +1,197 @@ +package taurus + +import ( + "context" + cryptoEcdsa "crypto/ecdsa" + "encoding/json" + "errors" + "fmt" + "math/big" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/fystack/mpcium/pkg/encoding" + "github.com/fystack/mpcium/pkg/keyinfo" + "github.com/fystack/mpcium/pkg/kvstore" + "github.com/fystack/mpcium/pkg/logger" + "github.com/fystack/mpcium/pkg/types" + "github.com/taurusgroup/multi-party-sig/pkg/math/curve" + "github.com/taurusgroup/multi-party-sig/pkg/party" + "github.com/taurusgroup/multi-party-sig/protocols/frost" +) + +type FROSTSession struct { + *commonSession + savedData *frost.Config +} + +func NewFROSTSession( + sessionID string, + selfID party.ID, + peerIDs party.IDSlice, + threshold int, + transport Transport, + kvstore kvstore.KVStore, + keyinfoStore keyinfo.Store, +) TaurusSession { + commonSession := NewCommonSession(sessionID, selfID, peerIDs, threshold, transport, kvstore, keyinfoStore) + return &FROSTSession{ + commonSession: commonSession, + savedData: nil, + } +} + +func (p *FROSTSession) LoadKey(sid string) error { + key := p.composeKey(sid) + + data, err := p.kvstore.Get(key) + if err != nil { + return fmt.Errorf("load key: %w", err) + } + + cfg := frost.EmptyConfig(curve.Secp256k1{}) + if err := json.Unmarshal(data, cfg); err != nil { + return fmt.Errorf("unmarshal key config: %w", err) + } + + p.savedData = cfg + return nil +} + +func (p *FROSTSession) Keygen(ctx context.Context) (types.KeyData, error) { + logger.Info("Starting to generate key FROST", "walletID", p.sessionID) + + result, err := p.run(ctx, frost.Keygen(curve.Secp256k1{}, p.selfID, p.peerIDs, p.threshold)) + if err != nil { + return types.KeyData{}, err + } + + cfg, ok := result.(*frost.Config) + if !ok { + return types.KeyData{}, fmt.Errorf("unexpected result type %T", result) + } + p.savedData = cfg + + // Extract public key coordinates + x, y, err := extractPublicKey(cfg.PublicKey) + if err != nil { + return types.KeyData{}, fmt.Errorf("extract pubkey: %w", err) + } + + // Use secp256k1 curve, not P256 + pubKey := &cryptoEcdsa.PublicKey{ + Curve: btcec.S256(), + X: x, + Y: y, + } + + pubKeyBytes, err := encoding.EncodeS256PubKey(pubKey) + if err != nil { + return types.KeyData{}, fmt.Errorf("encode pubkey: %w", err) + } + + packed, err := json.Marshal(cfg) + if err != nil { + return types.KeyData{}, fmt.Errorf("marshal config: %w", err) + } + + key := p.composeKey(p.sessionID) + keyInfo := &keyinfo.KeyInfo{ + ParticipantPeerIDs: p.getParticipantPeerIDs(), + Threshold: p.threshold, + Version: 1, + } + + // Store both key and metadata if stores available + if p.kvstore != nil { + if err := p.kvstore.Put(key, packed); err != nil { + return types.KeyData{}, fmt.Errorf("store key: %w", err) + } + } + if p.keyinfoStore != nil { + if err := p.keyinfoStore.Save(key, keyInfo); err != nil { + return types.KeyData{}, fmt.Errorf("store key info: %w", err) + } + } + + return types.KeyData{ + SID: p.sessionID, + Type: FROST.String(), + PubKeyBytes: pubKeyBytes, + }, nil +} + +func (p *FROSTSession) Sign(ctx context.Context, msg *big.Int) ([]byte, error) { + if p.savedData == nil { + return nil, errors.New("no key loaded") + } + logger.Info("Starting to sign message FROST", "walletID", p.sessionID) + msgHash := msg.Bytes() + result, err := p.run(ctx, frost.Sign(p.savedData, p.peerIDs, msgHash)) + if err != nil { + return nil, err + } + + sig, ok := result.(frost.Signature) + if !ok { + return nil, fmt.Errorf("unexpected result type %T", result) + } + + if !sig.Verify(p.savedData.PublicKey, msgHash) { + return nil, errors.New("signature verification failed") + } + return sig.R.MarshalBinary() +} + +func (p *FROSTSession) Reshare(ctx context.Context) (res types.ReshareData, err error) { + if p.savedData == nil { + return res, errors.New("no key loaded") + } + cfg, err := p.run(ctx, frost.Refresh(p.savedData, p.peerIDs)) + if err != nil { + return res, err + } + savedData, ok := cfg.(*frost.Config) + if !ok { + return res, errors.New("unexpected result type") + } + p.savedData = savedData + packed, err := json.Marshal(p.savedData) + if err != nil { + return res, fmt.Errorf("marshal config: %w", err) + } + + key := p.composeKey(p.sessionID) + // Store updated key share + if p.kvstore != nil { + if err := p.kvstore.Put(key, packed); err != nil { + return res, fmt.Errorf("store key: %w", err) + } + } + + // Extract public key coordinates + x, y, err := extractPublicKey(p.savedData.PublicKey) + if err != nil { + return res, fmt.Errorf("extract pubkey: %w", err) + } + + // Use secp256k1 curve, not P256 + pubKey := &cryptoEcdsa.PublicKey{ + Curve: btcec.S256(), + X: x, + Y: y, + } + + pubKeyBytes, err := encoding.EncodeS256PubKey(pubKey) + if err != nil { + return res, fmt.Errorf("encode pubkey: %w", err) + } + + return types.ReshareData{ + KeyData: types.KeyData{SID: p.sessionID, Type: FROST.String(), PubKeyBytes: pubKeyBytes}, + Threshold: p.threshold, + }, nil +} + +func (p *FROSTSession) composeKey(sid string) string { + return fmt.Sprintf("frost:%s", sid) +} diff --git a/pkg/mpc/taurus/nats_transport.go b/pkg/mpc/taurus/nats_transport.go index 851ff01..fe1cf6c 100644 --- a/pkg/mpc/taurus/nats_transport.go +++ b/pkg/mpc/taurus/nats_transport.go @@ -31,6 +31,7 @@ type NATSTransport struct { selfID string wallet string act Act + proto Protocol topicComposer *TopicComposer pubsub messaging.PubSub direct messaging.DirectMessaging @@ -45,6 +46,7 @@ func NewNATSTransport( walletID string, self party.ID, act Act, + proto Protocol, pubsub messaging.PubSub, direct messaging.DirectMessaging, identityStore identity.Store, @@ -53,15 +55,16 @@ func NewNATSTransport( selfID: string(self), wallet: walletID, act: act, + proto: proto, pubsub: pubsub, direct: direct, identityStore: identityStore, topicComposer: &TopicComposer{ ComposeBroadcastTopic: func() string { - return fmt.Sprintf("%s:broadcast:cmp:%s", act, walletID) + return fmt.Sprintf("%s:broadcast:%s:%s", act, proto, walletID) }, ComposeDirectTopic: func(to string, walletID string) string { - return fmt.Sprintf("%s:direct:cmp:%s:%s", act, to, walletID) + return fmt.Sprintf("%s:direct:%s:%s:%s", act, proto, to, walletID) }, }, inbox: make(chan types.TaurusMessage, 128), diff --git a/pkg/mpc/taurus/taproot.go b/pkg/mpc/taurus/taproot.go new file mode 100644 index 0000000..6f7ac72 --- /dev/null +++ b/pkg/mpc/taurus/taproot.go @@ -0,0 +1,204 @@ +package taurus + +import ( + "context" + cryptoEcdsa "crypto/ecdsa" + "errors" + "fmt" + "math/big" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/fxamacker/cbor/v2" + "github.com/fystack/mpcium/pkg/encoding" + "github.com/fystack/mpcium/pkg/keyinfo" + "github.com/fystack/mpcium/pkg/kvstore" + "github.com/fystack/mpcium/pkg/logger" + "github.com/fystack/mpcium/pkg/types" + "github.com/taurusgroup/multi-party-sig/pkg/math/curve" + "github.com/taurusgroup/multi-party-sig/pkg/party" + "github.com/taurusgroup/multi-party-sig/pkg/taproot" + "github.com/taurusgroup/multi-party-sig/protocols/frost" +) + +type TaprootSession struct { + *commonSession + savedData *frost.TaprootConfig +} + +func NewTaprootSession( + sessionID string, + selfID party.ID, + peerIDs party.IDSlice, + threshold int, + transport Transport, + kvstore kvstore.KVStore, + keyinfoStore keyinfo.Store, +) TaurusSession { + commonSession := NewCommonSession(sessionID, selfID, peerIDs, threshold, transport, kvstore, keyinfoStore) + return &TaprootSession{ + commonSession: commonSession, + savedData: nil, + } +} + +func (p *TaprootSession) LoadKey(sid string) error { + key := p.composeKey(sid) + + data, err := p.kvstore.Get(key) + if err != nil { + return fmt.Errorf("load key: %w", err) + } + + cfg := &frost.TaprootConfig{} + if err := cbor.Unmarshal(data, &cfg); err != nil { + return fmt.Errorf("unmarshal key config: %w", err) + } + + p.savedData = cfg + return nil +} + +func (p *TaprootSession) Keygen(ctx context.Context) (types.KeyData, error) { + logger.Info("Starting to generate key Taproot", "walletID", p.sessionID) + + result, err := p.run(ctx, frost.KeygenTaproot(p.selfID, p.peerIDs, p.threshold)) + if err != nil { + return types.KeyData{}, err + } + + cfg, ok := result.(*frost.TaprootConfig) + if !ok { + return types.KeyData{}, fmt.Errorf("unexpected result type %T", result) + } + p.savedData = cfg + + pubPoint, err := curve.Secp256k1{}.LiftX(cfg.PublicKey) + if err != nil { + return types.KeyData{}, fmt.Errorf("lift pubkey: %w", err) + } + // Extract public key coordinates + x, y, err := extractPublicKey(pubPoint) + if err != nil { + return types.KeyData{}, fmt.Errorf("extract pubkey: %w", err) + } + + // Use secp256k1 curve, not P256 + pubKey := &cryptoEcdsa.PublicKey{ + Curve: btcec.S256(), + X: x, + Y: y, + } + + pubKeyBytes, err := encoding.EncodeS256PubKey(pubKey) + if err != nil { + return types.KeyData{}, fmt.Errorf("encode pubkey: %w", err) + } + + packed, err := cbor.Marshal(cfg) + if err != nil { + return types.KeyData{}, fmt.Errorf("marshal config: %w", err) + } + + key := p.composeKey(p.sessionID) + keyInfo := &keyinfo.KeyInfo{ + ParticipantPeerIDs: p.getParticipantPeerIDs(), + Threshold: p.threshold, + Version: 1, + } + + // Store both key and metadata if stores available + if p.kvstore != nil { + if err := p.kvstore.Put(key, packed); err != nil { + return types.KeyData{}, fmt.Errorf("store key: %w", err) + } + } + if p.keyinfoStore != nil { + if err := p.keyinfoStore.Save(key, keyInfo); err != nil { + return types.KeyData{}, fmt.Errorf("store key info: %w", err) + } + } + + return types.KeyData{ + SID: p.sessionID, + Type: FROSTTaproot.String(), + PubKeyBytes: pubKeyBytes, + }, nil +} + +func (p *TaprootSession) Sign(ctx context.Context, msg *big.Int) ([]byte, error) { + if p.savedData == nil { + return nil, errors.New("no key loaded") + } + logger.Info("Starting to sign message Taproot", "walletID", p.sessionID) + msgHash := msg.Bytes() + result, err := p.run(ctx, frost.SignTaproot(p.savedData, p.peerIDs, msgHash)) + if err != nil { + return nil, err + } + sig, ok := result.(taproot.Signature) + if !ok { + return nil, errors.New("unexpected result type") + } + if !p.savedData.PublicKey.Verify(sig, msgHash) { + return nil, errors.New("signature verification failed") + } + return []byte(sig), nil +} + +func (p *TaprootSession) Reshare(ctx context.Context) (res types.ReshareData, err error) { + if p.savedData == nil { + return res, errors.New("no key loaded") + } + cfg, err := p.run(ctx, frost.RefreshTaproot(p.savedData, p.peerIDs)) + if err != nil { + return res, err + } + savedData, ok := cfg.(*frost.TaprootConfig) + if !ok { + return res, errors.New("unexpected result type") + } + p.savedData = savedData + packed, err := cbor.Marshal(p.savedData) + if err != nil { + return res, fmt.Errorf("marshal config: %w", err) + } + + key := p.composeKey(p.sessionID) + // Store updated key share + if p.kvstore != nil { + if err := p.kvstore.Put(key, packed); err != nil { + return res, fmt.Errorf("store key: %w", err) + } + } + + // Extract public key coordinates + pubPoint, err := curve.Secp256k1{}.LiftX(p.savedData.PublicKey) + if err != nil { + return res, fmt.Errorf("lift pubkey: %w", err) + } + x, y, err := extractPublicKey(pubPoint) + if err != nil { + return res, fmt.Errorf("extract pubkey: %w", err) + } + + // Use secp256k1 curve, not P256 + pubKey := &cryptoEcdsa.PublicKey{ + Curve: btcec.S256(), + X: x, + Y: y, + } + + pubKeyBytes, err := encoding.EncodeS256PubKey(pubKey) + if err != nil { + return res, fmt.Errorf("encode pubkey: %w", err) + } + + return types.ReshareData{ + KeyData: types.KeyData{SID: p.sessionID, Type: CGGMP21.String(), PubKeyBytes: pubKeyBytes}, + Threshold: p.threshold, + }, nil +} + +func (p *TaprootSession) composeKey(sid string) string { + return fmt.Sprintf("taproot:%s", sid) +} diff --git a/pkg/mpc/taurus/taurus_test.go b/pkg/mpc/taurus/taurus_test.go new file mode 100644 index 0000000..0d09dc4 --- /dev/null +++ b/pkg/mpc/taurus/taurus_test.go @@ -0,0 +1,102 @@ +package taurus + +import ( + "bytes" + "context" + "math/big" + "sync" + "testing" + + "github.com/fystack/mpcium/pkg/logger" + "github.com/taurusgroup/multi-party-sig/pkg/party" +) + +// taurusTest represents a 2-party in-memory network for Taurus +type taurusTest struct { + parties []TaurusSession + results map[string]chan any +} + +func newTaurusTest(sid string, ids []party.ID) *taurusTest { + t := &taurusTest{ + results: map[string]chan any{ + "keygen": make(chan any, len(ids)), + "sign": make(chan any, len(ids)), + }, + } + + transports := make([]*Memory, len(ids)) + for i, id := range ids { + transports[i] = NewMemoryParty(string(id)) + } + LinkPeers(transports...) + + for i, id := range ids { + t.parties = append(t.parties, + NewTaprootSession(sid, id, ids, 1, transports[i], nil, nil)) + } + + return t +} + +func (t *taurusTest) runAll(fn func(TaurusSession) (any, error), key string) { + var wg sync.WaitGroup + for _, p := range t.parties { + wg.Add(1) + go func(p TaurusSession) { + defer wg.Done() + + res, err := fn(p) + if err != nil { + logger.Error("operation failed", err) + return + } + t.results[key] <- res + }(p) + } + wg.Wait() + close(t.results[key]) +} + +func drain[T any](ch chan any) []T { + out := make([]T, 0, len(ch)) + for v := range ch { + out = append(out, v.(T)) + } + return out +} + +func assertAllBytesEqual(t *testing.T, vals [][]byte) { + if len(vals) == 0 { + t.Fatal("no values to compare") + } + first := vals[0] + for i, v := range vals[1:] { + if !bytes.Equal(first, v) { + t.Fatalf("byte slices not equal at index %d", i+1) + } + } +} + +func TestTaurusParty(t *testing.T) { + t.Parallel() + + // quick test, 2 nodes only + ids := []party.ID{"node0", "node1"} + sid := "cggmp21-fast" + test := newTaurusTest(sid, ids) + + // --- Keygen (cached) --- + test.runAll(func(p TaurusSession) (any, error) { + return p.Keygen(context.Background()) + }, "keygen") + + // --- Sign --- + msg := big.NewInt(42) + test.runAll(func(p TaurusSession) (any, error) { + return p.Sign(context.Background(), msg) + }, "sign") + + sigs := drain[[]byte](test.results["sign"]) + assertAllBytesEqual(t, sigs) +} diff --git a/pkg/types/initiator_msg.go b/pkg/types/initiator_msg.go index 951b81f..30ae9f4 100644 --- a/pkg/types/initiator_msg.go +++ b/pkg/types/initiator_msg.go @@ -7,7 +7,9 @@ type KeyType string const ( KeyTypeSecp256k1 KeyType = "secp256k1" KeyTypeEd25519 KeyType = "ed25519" - KeyTypeTaurusCmp KeyType = "taurus_cmp" + KeyTypeCGGMP21 KeyType = "cggmp21" + KeyTypeFROST KeyType = "frost" + KeyTypeTaproot KeyType = "taproot" ) type EventInitiatorKeyType string From 17318b01645ee5027fe57f95281cfa14b3b1d83d Mon Sep 17 00:00:00 2001 From: vietddude Date: Wed, 15 Oct 2025 14:33:35 +0700 Subject: [PATCH 12/21] refactor: format taurus pkg --- pkg/mpc/taurus/cggmp21.go | 21 ++++++++++++++++++--- pkg/mpc/taurus/frost.go | 10 +++++++++- pkg/mpc/taurus/taproot.go | 16 ++++++++++++++-- 3 files changed, 41 insertions(+), 6 deletions(-) diff --git a/pkg/mpc/taurus/cggmp21.go b/pkg/mpc/taurus/cggmp21.go index 2b27cd5..221bb23 100644 --- a/pkg/mpc/taurus/cggmp21.go +++ b/pkg/mpc/taurus/cggmp21.go @@ -35,7 +35,15 @@ func NewCGGMP21Session( kvstore kvstore.KVStore, keyinfoStore keyinfo.Store, ) TaurusSession { - commonSession := NewCommonSession(sessionID, selfID, peerIDs, threshold, transport, kvstore, keyinfoStore) + commonSession := NewCommonSession( + sessionID, + selfID, + peerIDs, + threshold, + transport, + kvstore, + keyinfoStore, + ) return &CGGMP21Session{ commonSession: commonSession, workerPool: pool.NewPool(0), @@ -63,7 +71,10 @@ func (p *CGGMP21Session) LoadKey(sid string) error { func (p *CGGMP21Session) Keygen(ctx context.Context) (types.KeyData, error) { logger.Info("Starting to generate key CGGMP21", "walletID", p.sessionID) - result, err := p.run(ctx, cmp.Keygen(curve.Secp256k1{}, p.selfID, p.peerIDs, p.threshold, p.workerPool)) + result, err := p.run( + ctx, + cmp.Keygen(curve.Secp256k1{}, p.selfID, p.peerIDs, p.threshold, p.workerPool), + ) if err != nil { return types.KeyData{}, err } @@ -184,7 +195,11 @@ func (p *CGGMP21Session) Reshare(ctx context.Context) (res types.ReshareData, er } return types.ReshareData{ - KeyData: types.KeyData{SID: p.sessionID, Type: CGGMP21.String(), PubKeyBytes: pubKeyBytes}, + KeyData: types.KeyData{ + SID: p.sessionID, + Type: CGGMP21.String(), + PubKeyBytes: pubKeyBytes, + }, Threshold: p.threshold, }, nil } diff --git a/pkg/mpc/taurus/frost.go b/pkg/mpc/taurus/frost.go index dc2cc46..cac8f47 100644 --- a/pkg/mpc/taurus/frost.go +++ b/pkg/mpc/taurus/frost.go @@ -33,7 +33,15 @@ func NewFROSTSession( kvstore kvstore.KVStore, keyinfoStore keyinfo.Store, ) TaurusSession { - commonSession := NewCommonSession(sessionID, selfID, peerIDs, threshold, transport, kvstore, keyinfoStore) + commonSession := NewCommonSession( + sessionID, + selfID, + peerIDs, + threshold, + transport, + kvstore, + keyinfoStore, + ) return &FROSTSession{ commonSession: commonSession, savedData: nil, diff --git a/pkg/mpc/taurus/taproot.go b/pkg/mpc/taurus/taproot.go index 6f7ac72..5865436 100644 --- a/pkg/mpc/taurus/taproot.go +++ b/pkg/mpc/taurus/taproot.go @@ -34,7 +34,15 @@ func NewTaprootSession( kvstore kvstore.KVStore, keyinfoStore keyinfo.Store, ) TaurusSession { - commonSession := NewCommonSession(sessionID, selfID, peerIDs, threshold, transport, kvstore, keyinfoStore) + commonSession := NewCommonSession( + sessionID, + selfID, + peerIDs, + threshold, + transport, + kvstore, + keyinfoStore, + ) return &TaprootSession{ commonSession: commonSession, savedData: nil, @@ -194,7 +202,11 @@ func (p *TaprootSession) Reshare(ctx context.Context) (res types.ReshareData, er } return types.ReshareData{ - KeyData: types.KeyData{SID: p.sessionID, Type: CGGMP21.String(), PubKeyBytes: pubKeyBytes}, + KeyData: types.KeyData{ + SID: p.sessionID, + Type: CGGMP21.String(), + PubKeyBytes: pubKeyBytes, + }, Threshold: p.threshold, }, nil } From 93fbe505260f68b16be587fa1338fe3d0f79b076 Mon Sep 17 00:00:00 2001 From: vietddude Date: Sun, 26 Oct 2025 04:26:48 +0700 Subject: [PATCH 13/21] feat: implement presign CGGMP21 functionality in MPC client and event consumer --- cmd/mpcium/main.go | 4 + examples/presign/main.go | 101 +++++++++++++++++++ pkg/client/client.go | 48 +++++++++ pkg/event/presign.go | 19 ++++ pkg/eventconsumer/event_consumer.go | 146 ++++++++++++++++++++++++++++ pkg/mpc/node.go | 6 +- pkg/mpc/taurus/cggmp21.go | 64 ++++++++++-- pkg/mpc/taurus/common.go | 5 + pkg/mpc/taurus/nats_transport.go | 1 + pkg/mpc/taurus/presign.go | 47 +++++++++ pkg/types/initiator_msg.go | 35 +++++++ 11 files changed, 468 insertions(+), 8 deletions(-) create mode 100644 examples/presign/main.go create mode 100644 pkg/event/presign.go create mode 100644 pkg/mpc/taurus/presign.go diff --git a/cmd/mpcium/main.go b/cmd/mpcium/main.go index a68eb0d..5949bbf 100644 --- a/cmd/mpcium/main.go +++ b/cmd/mpcium/main.go @@ -177,6 +177,7 @@ func runNode(ctx context.Context, c *cli.Command) error { "mpc.mpc_keygen_result.*", event.SigningResultTopic, "mpc.mpc_reshare_result.*", + event.PresignResultTopic, }, natsConn) genKeyResultQueue := mqManager.NewMessageQueue("mpc_keygen_result") @@ -185,6 +186,8 @@ func runNode(ctx context.Context, c *cli.Command) error { defer singingResultQueue.Close() reshareResultQueue := mqManager.NewMessageQueue("mpc_reshare_result") defer reshareResultQueue.Close() + presignResultQueue := mqManager.NewMessageQueue("mpc_presign_result") + defer presignResultQueue.Close() logger.Info("Node is running", "ID", nodeID, "name", nodeName) @@ -209,6 +212,7 @@ func runNode(ctx context.Context, c *cli.Command) error { genKeyResultQueue, singingResultQueue, reshareResultQueue, + presignResultQueue, identityStore, ) eventConsumer.Run() diff --git a/examples/presign/main.go b/examples/presign/main.go new file mode 100644 index 0000000..bd07778 --- /dev/null +++ b/examples/presign/main.go @@ -0,0 +1,101 @@ +package main + +import ( + "fmt" + "os" + "os/signal" + "slices" + "syscall" + + "github.com/fystack/mpcium/pkg/client" + "github.com/fystack/mpcium/pkg/config" + "github.com/fystack/mpcium/pkg/event" + "github.com/fystack/mpcium/pkg/logger" + "github.com/fystack/mpcium/pkg/types" + "github.com/google/uuid" + "github.com/nats-io/nats.go" + "github.com/spf13/viper" +) + +func main() { + const environment = "dev" + config.InitViperConfig("") + logger.Init(environment, true) + + algorithm := viper.GetString("event_initiator_algorithm") + if algorithm == "" { + algorithm = string(types.EventInitiatorKeyTypeEd25519) + } + + // Validate algorithm + if !slices.Contains( + []string{ + string(types.EventInitiatorKeyTypeEd25519), + string(types.EventInitiatorKeyTypeP256), + }, + algorithm, + ) { + logger.Fatal( + fmt.Sprintf( + "invalid algorithm: %s. Must be %s or %s", + algorithm, + types.EventInitiatorKeyTypeEd25519, + types.EventInitiatorKeyTypeP256, + ), + nil, + ) + } + natsURL := viper.GetString("nats.url") + natsConn, err := nats.Connect(natsURL) + if err != nil { + logger.Fatal("Failed to connect to NATS", err) + } + defer natsConn.Drain() + defer natsConn.Close() + + localSigner, err := client.NewLocalSigner(types.EventInitiatorKeyType(algorithm), client.LocalSignerOptions{ + KeyPath: "./event_initiator.key", + }) + if err != nil { + logger.Fatal("Failed to create local signer", err) + } + + mpcClient := client.NewMPCClient(client.Options{ + NatsConn: natsConn, + Signer: localSigner, + }) + + // 2) Once wallet exists, immediately fire a SignTransaction + txID := uuid.New().String() + dummyTx := []byte("deadbeef") // replace with real transaction bytes + + txMsg := &types.PresignTxMessage{ + KeyType: types.KeyTypeCGGMP21, + WalletID: "196c6858-30de-4a49-9134-8bc825d40764", // Use the generated wallet ID + NetworkInternalCode: "solana-devnet", + TxID: txID, + Tx: dummyTx, + } + err = mpcClient.PresignTransaction(txMsg) + if err != nil { + logger.Fatal("PresignTransaction failed", err) + } + fmt.Printf("PresignTransaction(%q) sent, awaiting result...\n", txID) + + // 3) Listen for signing results + err = mpcClient.OnPresignResult(func(evt event.PresignResultEvent) { + logger.Info("Presign result received", + "txID", evt.TxID, + "status", evt.Status, + ) + }) + if err != nil { + logger.Fatal("Failed to subscribe to OnPresignResult", err) + } + + stop := make(chan os.Signal, 1) + signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM) + <-stop + + fmt.Println("Shutting down.") +} diff --git a/pkg/client/client.go b/pkg/client/client.go index 3121bdb..0413abf 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -27,6 +27,9 @@ type MPCClient interface { Resharing(msg *types.ResharingMessage) error OnResharingResult(callback func(event event.ResharingResultEvent)) error + + PresignTransaction(msg *types.PresignTxMessage) error + OnPresignResult(callback func(event event.PresignResultEvent)) error } type mpcClient struct { @@ -36,6 +39,7 @@ type mpcClient struct { genKeySuccessQueue messaging.MessageQueue signResultQueue messaging.MessageQueue reshareSuccessQueue messaging.MessageQueue + presignSuccessQueue messaging.MessageQueue signer Signer } @@ -85,11 +89,13 @@ func NewMPCClient(opts Options) MPCClient { "mpc.mpc_keygen_result.*", "mpc.mpc_signing_result.*", "mpc.mpc_reshare_result.*", + "mpc.mpc_presign_result.*", }, opts.NatsConn) genKeySuccessQueue := manager.NewMessageQueue("mpc_keygen_result") signResultQueue := manager.NewMessageQueue("mpc_signing_result") reshareSuccessQueue := manager.NewMessageQueue("mpc_reshare_result") + presignSuccessQueue := manager.NewMessageQueue("mpc_presign_result") return &mpcClient{ signingBroker: signingBroker, @@ -98,6 +104,7 @@ func NewMPCClient(opts Options) MPCClient { genKeySuccessQueue: genKeySuccessQueue, signResultQueue: signResultQueue, reshareSuccessQueue: reshareSuccessQueue, + presignSuccessQueue: presignSuccessQueue, signer: opts.Signer, } } @@ -235,3 +242,44 @@ func (c *mpcClient) OnResharingResult(callback func(event event.ResharingResultE return nil } + +func (c *mpcClient) PresignTransaction(msg *types.PresignTxMessage) error { + // compute the canonical raw bytes + raw, err := msg.Raw() + if err != nil { + return fmt.Errorf("PresignTransaction: raw payload error: %w", err) + } + signature, err := c.signer.Sign(raw) + if err != nil { + return fmt.Errorf("PresignTransaction: failed to sign message: %w", err) + } + msg.Signature = signature + + bytes, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("PresignTransaction: marshal error: %w", err) + } + + if err := c.pubsub.Publish(eventconsumer.MPCPresignEvent, bytes); err != nil { + return fmt.Errorf("PresignTransaction: publish error: %w", err) + } + return nil +} + +func (c *mpcClient) OnPresignResult(callback func(event event.PresignResultEvent)) error { + err := c.presignSuccessQueue.Dequeue(event.PresignResultTopic, func(msg []byte) error { + var event event.PresignResultEvent + err := json.Unmarshal(msg, &event) + if err != nil { + return err + } + callback(event) + return nil + }) + + if err != nil { + return fmt.Errorf("OnPresignResult: subscribe error: %w", err) + } + + return nil +} diff --git a/pkg/event/presign.go b/pkg/event/presign.go new file mode 100644 index 0000000..0f74d34 --- /dev/null +++ b/pkg/event/presign.go @@ -0,0 +1,19 @@ +package event + +const ( + PresignBrokerStream = "mpc-presign" + PresignConsumerStream = "mpc-presign-consumer" + PresignRequestTopic = "mpc.presign_request.*" + PresignResultTopic = "mpc.mpc_presign_result.*" +) + +type PresignResultEvent struct { + ResultType ResultType `json:"result_type"` + ErrorCode ErrorCode `json:"error_code"` + ErrorReason string `json:"error_reason"` + IsTimeout bool `json:"is_timeout"` + NetworkInternalCode string `json:"network_internal_code"` + WalletID string `json:"wallet_id"` + TxID string `json:"tx_id"` + Status string `json:"status"` +} diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index c2daaf1..7199306 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -25,6 +25,7 @@ const ( MPCGenerateEvent = "mpc:generate" MPCSignEvent = "mpc:sign" MPCReshareEvent = "mpc:reshare" + MPCPresignEvent = "mpc:presign" DefaultConcurrentKeygen = 2 DefaultConcurrentSigning = 20 @@ -46,10 +47,12 @@ type eventConsumer struct { genKeyResultQueue messaging.MessageQueue signingResultQueue messaging.MessageQueue reshareResultQueue messaging.MessageQueue + presignResultQueue messaging.MessageQueue keyGenerationSub messaging.Subscription signingSub messaging.Subscription reshareSub messaging.Subscription + presignSub messaging.Subscription identityStore identity.Store keygenMsgBuffer chan *nats.Msg @@ -72,6 +75,7 @@ func NewEventConsumer( genKeyResultQueue messaging.MessageQueue, signingResultQueue messaging.MessageQueue, reshareResultQueue messaging.MessageQueue, + presignResultQueue messaging.MessageQueue, identityStore identity.Store, ) EventConsumer { maxConcurrentKeygen := viper.GetInt("max_concurrent_keygen") @@ -105,6 +109,7 @@ func NewEventConsumer( genKeyResultQueue: genKeyResultQueue, signingResultQueue: signingResultQueue, reshareResultQueue: reshareResultQueue, + presignResultQueue: presignResultQueue, activeSessions: make(map[string]time.Time), cleanupInterval: 5 * time.Minute, // Run cleanup every 5 minutes sessionTimeout: 30 * time.Minute, // Consider sessions older than 30 minutes stale @@ -142,6 +147,11 @@ func (ec *eventConsumer) Run() { log.Fatal("Failed to consume reshare event", err) } + err = ec.consumePresignEvent() + if err != nil { + log.Fatal("Failed to consume presign event", err) + } + logger.Info("MPC Event consumer started...!") } @@ -634,6 +644,7 @@ func (ec *eventConsumer) consumeTxSigningEvent() error { return nil } + func (ec *eventConsumer) handleSigningSessionError(walletID, txID, networkInternalCode string, err error, contextMsg string, natMsg *nats.Msg) { fullErrMsg := fmt.Sprintf("%s: %v", contextMsg, err) errorCode := event.GetErrorCodeFromError(err) @@ -1023,6 +1034,141 @@ func (ec *eventConsumer) handleReshareSessionError( } } +func (ec *eventConsumer) consumePresignEvent() error { + sub, err := ec.pubsub.Subscribe(MPCPresignEvent, func(natMsg *nats.Msg) { + var msg types.PresignTxMessage + if err := json.Unmarshal(natMsg.Data, &msg); err != nil { + logger.Error("Failed to unmarshal presign message", err) + return + } + if err := ec.identityStore.VerifyInitiatorMessage(msg); err != nil { + logger.Error("Failed to verify initiator message", err) + return + } + + // Only CGGMP21 supports presign + if msg.KeyType != types.KeyTypeCGGMP21 { + ec.handlePresignSessionError(msg.WalletID, msg.TxID, msg.NetworkInternalCode, + fmt.Errorf("presign is only supported for CGGMP21 key type"), + "Unsupported key type for presign", natMsg) + return + } + + session, err := ec.node.CreateTaurusSession(msg.WalletID, ec.mpcThreshold, msg.KeyType, taurus.ActPresign) + if err != nil { + ec.handlePresignSessionError(msg.WalletID, msg.TxID, msg.NetworkInternalCode, + err, "Failed to create presign session", natMsg) + return + } + + ctx := context.Background() + success, err := session.Presign(ctx, msg.TxID) + if err != nil { + ec.handlePresignSessionError(msg.WalletID, msg.TxID, msg.NetworkInternalCode, + err, "Presign operation failed", natMsg) + return + } + + if success { + ec.handlePresignSessionSuccess(msg.WalletID, msg.TxID, msg.NetworkInternalCode, natMsg) + } else { + ec.handlePresignSessionError(msg.WalletID, msg.TxID, msg.NetworkInternalCode, + fmt.Errorf("presign operation returned false"), + "Presign operation failed", natMsg) + } + }) + if err != nil { + return err + } + + ec.presignSub = sub + return nil +} + +// handlePresignSessionSuccess handles successful presign operations +func (ec *eventConsumer) handlePresignSessionSuccess(walletID, txID, networkInternalCode string, natMsg *nats.Msg) { + presignResult := event.PresignResultEvent{ + ResultType: event.ResultTypeSuccess, + NetworkInternalCode: networkInternalCode, + WalletID: walletID, + TxID: txID, + Status: "success", + } + + presignResultBytes, err := json.Marshal(presignResult) + if err != nil { + logger.Error("Failed to marshal presign result event", err, + "walletID", walletID, + "txID", txID, + ) + return + } + + err = ec.presignResultQueue.Enqueue(event.PresignResultTopic, presignResultBytes, &messaging.EnqueueOptions{ + IdempotententKey: composePresignIdempotentKey(txID, natMsg), + }) + if err != nil { + logger.Error("Failed to enqueue presign result event", err, + "walletID", walletID, + "txID", txID, + "payload", string(presignResultBytes), + ) + } + // Presign events don't use reply inboxes, so no need to send reply + logger.Info("[COMPLETED PRESIGN] Presign completed successfully", "walletID", walletID, "txID", txID) +} + +// handlePresignSessionError handles errors that occur during presign operations +func (ec *eventConsumer) handlePresignSessionError(walletID, txID, networkInternalCode string, err error, contextMsg string, natMsg *nats.Msg) { + fullErrMsg := fmt.Sprintf("%s: %v", contextMsg, err) + errorCode := event.GetErrorCodeFromError(err) + + logger.Warn("Presign session error", + "walletID", walletID, + "txID", txID, + "networkInternalCode", networkInternalCode, + "error", err.Error(), + "errorCode", errorCode, + "context", contextMsg, + ) + + presignResult := event.PresignResultEvent{ + ResultType: event.ResultTypeError, + ErrorCode: errorCode, + NetworkInternalCode: networkInternalCode, + WalletID: walletID, + TxID: txID, + ErrorReason: fullErrMsg, + Status: "failed", + } + + presignResultBytes, err := json.Marshal(presignResult) + if err != nil { + logger.Error("Failed to marshal presign result event", err, + "walletID", walletID, + "txID", txID, + ) + return + } + + err = ec.presignResultQueue.Enqueue(event.PresignResultTopic, presignResultBytes, &messaging.EnqueueOptions{ + IdempotententKey: composePresignIdempotentKey(txID, natMsg), + }) + if err != nil { + logger.Error("Failed to enqueue presign result event", err, + "walletID", walletID, + "txID", txID, + "payload", string(presignResultBytes), + ) + } + // Presign events don't use reply inboxes, so no need to send reply +} + +// composePresignIdempotentKey creates an idempotent key for presign operations +func composePresignIdempotentKey(txID string, natMsg *nats.Msg) string { + return fmt.Sprintf("presign:%s:%s", txID, natMsg.Header.Get("Nats-Msg-Id")) +} + // Add a cleanup routine that runs periodically func (ec *eventConsumer) sessionCleanupRoutine() { ticker := time.NewTicker(ec.cleanupInterval) diff --git a/pkg/mpc/node.go b/pkg/mpc/node.go index c9e9f8b..51017f9 100644 --- a/pkg/mpc/node.go +++ b/pkg/mpc/node.go @@ -40,6 +40,7 @@ type Node struct { keyinfoStore keyinfo.Store ecdsaPreParams []*keygen.LocalPreParams identityStore identity.Store + presignCache *taurus.PresignCache peerRegistry PeerRegistry } @@ -67,6 +68,7 @@ func NewNode( keyinfoStore: keyinfoStore, peerRegistry: peerRegistry, identityStore: identityStore, + presignCache: taurus.NewPresignCache(), } node.ecdsaPreParams = node.generatePreParams() @@ -155,7 +157,7 @@ func (p *Node) CreateTaurusSession( switch sessionType { case types.KeyTypeCGGMP21: tr := taurus.NewNATSTransport(walletID, selfPartyID, act, taurus.CGGMP21, p.pubSub, p.direct, p.identityStore) - session = taurus.NewCGGMP21Session(walletID, selfPartyID, allPartyIDs, threshold, tr, p.kvstore, p.keyinfoStore) + session = taurus.NewCGGMP21Session(walletID, selfPartyID, allPartyIDs, threshold, p.presignCache, tr, p.kvstore, p.keyinfoStore) case types.KeyTypeTaproot: tr := taurus.NewNATSTransport(walletID, selfPartyID, act, taurus.FROSTTaproot, p.pubSub, p.direct, p.identityStore) session = taurus.NewTaprootSession(walletID, selfPartyID, allPartyIDs, threshold, tr, p.kvstore, p.keyinfoStore) @@ -164,7 +166,7 @@ func (p *Node) CreateTaurusSession( session = taurus.NewFROSTSession(walletID, selfPartyID, allPartyIDs, threshold, tr, p.kvstore, p.keyinfoStore) } - if act == taurus.ActSign || act == taurus.ActReshare { + if act == taurus.ActSign || act == taurus.ActReshare || act == taurus.ActPresign { err := session.LoadKey(walletID) if err != nil { return nil, err diff --git a/pkg/mpc/taurus/cggmp21.go b/pkg/mpc/taurus/cggmp21.go index 221bb23..6e3c94a 100644 --- a/pkg/mpc/taurus/cggmp21.go +++ b/pkg/mpc/taurus/cggmp21.go @@ -22,8 +22,9 @@ import ( type CGGMP21Session struct { *commonSession - workerPool *pool.Pool - savedData *cmp.Config + workerPool *pool.Pool + savedData *cmp.Config + presignCache *PresignCache } func NewCGGMP21Session( @@ -31,6 +32,7 @@ func NewCGGMP21Session( selfID party.ID, peerIDs party.IDSlice, threshold int, + presignCache *PresignCache, transport Transport, kvstore kvstore.KVStore, keyinfoStore keyinfo.Store, @@ -48,6 +50,7 @@ func NewCGGMP21Session( commonSession: commonSession, workerPool: pool.NewPool(0), savedData: nil, + presignCache: presignCache, } } @@ -137,12 +140,27 @@ func (p *CGGMP21Session) Sign(ctx context.Context, msg *big.Int) ([]byte, error) if p.savedData == nil { return nil, errors.New("no key loaded") } - logger.Info("Starting to sign message CGGMP21", "walletID", p.sessionID) + + logger.Info("starting CGGMP21 sign", "walletID", p.sessionID) msgHash := msg.Bytes() - result, err := p.run(ctx, cmp.Sign(p.savedData, p.peerIDs, msgHash, p.workerPool)) - if err != nil { - return nil, err + + var ( + result any + err error + ) + + if presign := p.getCachedPresign(); presign != nil { + result, err = p.run(ctx, cmp.PresignOnline(p.savedData, presign, msgHash, p.workerPool)) + if err != nil { + return nil, fmt.Errorf("presign online failed: %w", err) + } + } else { + result, err = p.run(ctx, cmp.Sign(p.savedData, p.peerIDs, msgHash, p.workerPool)) + if err != nil { + return nil, fmt.Errorf("full sign failed: %w", err) + } } + sig, ok := result.(*ecdsa.Signature) if !ok { return nil, errors.New("unexpected result type") @@ -153,6 +171,27 @@ func (p *CGGMP21Session) Sign(ctx context.Context, msg *big.Int) ([]byte, error) return sig.SigEthereum() } +func (p *CGGMP21Session) Presign(ctx context.Context, txID string) (bool, error) { + if p.savedData == nil { + return false, errors.New("no key loaded") + } + logger.Info("Starting to presign message CGGMP21", "walletID", p.sessionID, "txID", txID) + result, err := p.run(ctx, cmp.Presign(p.savedData, p.peerIDs, p.workerPool)) + if err != nil { + return false, err + } + presig, ok := result.(*ecdsa.PreSignature) + if !ok { + return false, errors.New("unexpected result type") + } + if err = presig.Validate(); err != nil { + return false, errors.New("presign validation failed") + } + + p.presignCache.Put(p.sessionID, txID, presig) + return true, nil +} + func (p *CGGMP21Session) Reshare(ctx context.Context) (res types.ReshareData, err error) { if p.savedData == nil { return res, errors.New("no key loaded") @@ -207,3 +246,16 @@ func (p *CGGMP21Session) Reshare(ctx context.Context) (res types.ReshareData, er func (p *CGGMP21Session) composeKey(sid string) string { return fmt.Sprintf("cggmp21:%s", sid) } + +func (p *CGGMP21Session) getCachedPresign() *ecdsa.PreSignature { + if p.presignCache == nil { + return nil + } + + presig, ok := p.presignCache.Get(p.sessionID) + if !ok || presig == nil { + return nil + } + + return presig +} diff --git a/pkg/mpc/taurus/common.go b/pkg/mpc/taurus/common.go index 8fbdc31..2ff3f18 100644 --- a/pkg/mpc/taurus/common.go +++ b/pkg/mpc/taurus/common.go @@ -32,6 +32,7 @@ type TaurusSession interface { Keygen(ctx context.Context) (types.KeyData, error) Sign(ctx context.Context, msg *big.Int) ([]byte, error) Reshare(ctx context.Context) (types.ReshareData, error) + Presign(ctx context.Context, txID string) (bool, error) } type commonSession struct { @@ -65,6 +66,10 @@ func NewCommonSession( } } +func (p *commonSession) Presign(ctx context.Context, txID string) (bool, error) { + return false, errors.New("not implemented") +} + func (p *commonSession) run(ctx context.Context, proto protocol.StartFunc) (any, error) { h, err := protocol.NewMultiHandler(proto, []byte(p.sessionID)) if err != nil { diff --git a/pkg/mpc/taurus/nats_transport.go b/pkg/mpc/taurus/nats_transport.go index fe1cf6c..88899df 100644 --- a/pkg/mpc/taurus/nats_transport.go +++ b/pkg/mpc/taurus/nats_transport.go @@ -20,6 +20,7 @@ const ( ActKeygen Act = "keygen" ActSign Act = "sign" ActReshare Act = "reshare" + ActPresign Act = "presign" ) type TopicComposer struct { diff --git a/pkg/mpc/taurus/presign.go b/pkg/mpc/taurus/presign.go new file mode 100644 index 0000000..9d033a8 --- /dev/null +++ b/pkg/mpc/taurus/presign.go @@ -0,0 +1,47 @@ +package taurus + +import ( + "sync" + "time" + + "github.com/taurusgroup/multi-party-sig/pkg/ecdsa" +) + +type PresignCache struct { + mu sync.Mutex + data map[string][]PresignEntry // walletID -> entries +} + +type PresignEntry struct { + SessionID string + Result *ecdsa.PreSignature + CreatedAt time.Time +} + +func NewPresignCache() *PresignCache { + return &PresignCache{ + data: make(map[string][]PresignEntry), + } +} + +func (c *PresignCache) Put(walletID, sessionID string, res *ecdsa.PreSignature) { + c.mu.Lock() + defer c.mu.Unlock() + c.data[walletID] = append(c.data[walletID], PresignEntry{ + SessionID: sessionID, + Result: res, + CreatedAt: time.Now(), + }) +} + +func (c *PresignCache) Get(walletID string) (*ecdsa.PreSignature, bool) { + c.mu.Lock() + defer c.mu.Unlock() + entries := c.data[walletID] + if len(entries) == 0 { + return nil, false + } + res := entries[0].Result + c.data[walletID] = entries[1:] + return res, true +} diff --git a/pkg/types/initiator_msg.go b/pkg/types/initiator_msg.go index 30ae9f4..f3e7e8d 100644 --- a/pkg/types/initiator_msg.go +++ b/pkg/types/initiator_msg.go @@ -52,6 +52,15 @@ type ResharingMessage struct { Signature []byte `json:"signature,omitempty"` } +type PresignTxMessage struct { + KeyType KeyType `json:"key_type"` + WalletID string `json:"wallet_id"` + NetworkInternalCode string `json:"network_internal_code"` + TxID string `json:"tx_id"` + Tx []byte `json:"tx"` + Signature []byte `json:"signature"` +} + func (m *SignTxMessage) Raw() ([]byte, error) { // omit the Signature field itself when computing the signed‐over data payload := struct { @@ -103,3 +112,29 @@ func (m *ResharingMessage) Sig() []byte { func (m *ResharingMessage) InitiatorID() string { return m.WalletID } + +func (m PresignTxMessage) Raw() ([]byte, error) { + // omit the Signature field itself when computing the signed‐over data + payload := struct { + KeyType KeyType `json:"key_type"` + WalletID string `json:"wallet_id"` + NetworkInternalCode string `json:"network_internal_code"` + TxID string `json:"tx_id"` + Tx []byte `json:"tx"` + }{ + KeyType: m.KeyType, + WalletID: m.WalletID, + NetworkInternalCode: m.NetworkInternalCode, + TxID: m.TxID, + Tx: m.Tx, + } + return json.Marshal(payload) +} + +func (m PresignTxMessage) Sig() []byte { + return m.Signature +} + +func (m PresignTxMessage) InitiatorID() string { + return m.TxID +} From 8e06092e620928a45afa853b2efc9bcb7c47c32c Mon Sep 17 00:00:00 2001 From: vietddude Date: Sun, 26 Oct 2025 04:36:57 +0700 Subject: [PATCH 14/21] feat: enhance PresignCache with TTL support and automatic cleanup --- pkg/mpc/node.go | 2 +- pkg/mpc/taurus/presign.go | 54 ++++++++++++++++++++++++++++++++++++--- 2 files changed, 52 insertions(+), 4 deletions(-) diff --git a/pkg/mpc/node.go b/pkg/mpc/node.go index 51017f9..aa78a68 100644 --- a/pkg/mpc/node.go +++ b/pkg/mpc/node.go @@ -68,7 +68,7 @@ func NewNode( keyinfoStore: keyinfoStore, peerRegistry: peerRegistry, identityStore: identityStore, - presignCache: taurus.NewPresignCache(), + presignCache: taurus.NewPresignCache(10 * time.Minute), } node.ecdsaPreParams = node.generatePreParams() diff --git a/pkg/mpc/taurus/presign.go b/pkg/mpc/taurus/presign.go index 9d033a8..a2f608e 100644 --- a/pkg/mpc/taurus/presign.go +++ b/pkg/mpc/taurus/presign.go @@ -7,9 +7,12 @@ import ( "github.com/taurusgroup/multi-party-sig/pkg/ecdsa" ) +// PresignCache provides an in-memory cache of pre-signature data +// with automatic TTL-based cleanup. type PresignCache struct { mu sync.Mutex data map[string][]PresignEntry // walletID -> entries + ttl time.Duration } type PresignEntry struct { @@ -18,15 +21,27 @@ type PresignEntry struct { CreatedAt time.Time } -func NewPresignCache() *PresignCache { - return &PresignCache{ +// NewPresignCache creates a new cache with optional TTL. +// If ttl <= 0, defaults to 10 minutes. +func NewPresignCache(ttl time.Duration) *PresignCache { + if ttl <= 0 { + ttl = 10 * time.Minute + } + + cache := &PresignCache{ data: make(map[string][]PresignEntry), + ttl: ttl, } + + go cache.startCleanup() + return cache } +// Put adds a new presign result for a wallet. func (c *PresignCache) Put(walletID, sessionID string, res *ecdsa.PreSignature) { c.mu.Lock() defer c.mu.Unlock() + c.data[walletID] = append(c.data[walletID], PresignEntry{ SessionID: sessionID, Result: res, @@ -34,14 +49,47 @@ func (c *PresignCache) Put(walletID, sessionID string, res *ecdsa.PreSignature) }) } +// Get retrieves and removes the oldest available presign for a wallet. func (c *PresignCache) Get(walletID string) (*ecdsa.PreSignature, bool) { c.mu.Lock() defer c.mu.Unlock() + entries := c.data[walletID] if len(entries) == 0 { return nil, false } + res := entries[0].Result - c.data[walletID] = entries[1:] + c.data[walletID] = entries[1:] // pop first entry + if len(c.data[walletID]) == 0 { + delete(c.data, walletID) + } return res, true } + +// startCleanup periodically removes expired presign entries based on TTL. +func (c *PresignCache) startCleanup() { + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + + for range ticker.C { + now := time.Now() + expireBefore := now.Add(-c.ttl) + + c.mu.Lock() + for walletID, entries := range c.data { + filtered := entries[:0] + for _, e := range entries { + if e.CreatedAt.After(expireBefore) { + filtered = append(filtered, e) + } + } + if len(filtered) == 0 { + delete(c.data, walletID) + } else { + c.data[walletID] = filtered + } + } + c.mu.Unlock() + } +} From bb8cf07260e36d4baa942041534b96ab216827d5 Mon Sep 17 00:00:00 2001 From: vietddude Date: Mon, 3 Nov 2025 10:01:31 +0700 Subject: [PATCH 15/21] refactor: update presign handling to use WalletID instead of TxID and remove wallets.json --- .gitignore | 1 + examples/presign/main.go | 16 ++----- pkg/event/presign.go | 14 +++--- pkg/eventconsumer/event_consumer.go | 71 ++++++++++++++--------------- pkg/types/initiator_msg.go | 25 ++++------ pkg/utils/utils.go | 20 -------- wallets.json | 3 -- 7 files changed, 53 insertions(+), 97 deletions(-) delete mode 100644 pkg/utils/utils.go delete mode 100644 wallets.json diff --git a/.gitignore b/.gitignore index 08c8a39..7698f3c 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,4 @@ node2 config.yaml .vscode .vagrant +wallets.json \ No newline at end of file diff --git a/examples/presign/main.go b/examples/presign/main.go index bd07778..16762c9 100644 --- a/examples/presign/main.go +++ b/examples/presign/main.go @@ -12,7 +12,6 @@ import ( "github.com/fystack/mpcium/pkg/event" "github.com/fystack/mpcium/pkg/logger" "github.com/fystack/mpcium/pkg/types" - "github.com/google/uuid" "github.com/nats-io/nats.go" "github.com/spf13/viper" ) @@ -65,27 +64,20 @@ func main() { Signer: localSigner, }) - // 2) Once wallet exists, immediately fire a SignTransaction - txID := uuid.New().String() - dummyTx := []byte("deadbeef") // replace with real transaction bytes - txMsg := &types.PresignTxMessage{ - KeyType: types.KeyTypeCGGMP21, - WalletID: "196c6858-30de-4a49-9134-8bc825d40764", // Use the generated wallet ID - NetworkInternalCode: "solana-devnet", - TxID: txID, - Tx: dummyTx, + KeyType: types.KeyTypeCGGMP21, + WalletID: "196c6858-30de-4a49-9134-8bc825d40764", // Use the generated wallet ID } err = mpcClient.PresignTransaction(txMsg) if err != nil { logger.Fatal("PresignTransaction failed", err) } - fmt.Printf("PresignTransaction(%q) sent, awaiting result...\n", txID) + fmt.Printf("PresignTransaction(%q) sent, awaiting result...\n", txMsg.WalletID) // 3) Listen for signing results err = mpcClient.OnPresignResult(func(evt event.PresignResultEvent) { logger.Info("Presign result received", - "txID", evt.TxID, + "walletID", evt.WalletID, "status", evt.Status, ) }) diff --git a/pkg/event/presign.go b/pkg/event/presign.go index 0f74d34..13b782e 100644 --- a/pkg/event/presign.go +++ b/pkg/event/presign.go @@ -8,12 +8,10 @@ const ( ) type PresignResultEvent struct { - ResultType ResultType `json:"result_type"` - ErrorCode ErrorCode `json:"error_code"` - ErrorReason string `json:"error_reason"` - IsTimeout bool `json:"is_timeout"` - NetworkInternalCode string `json:"network_internal_code"` - WalletID string `json:"wallet_id"` - TxID string `json:"tx_id"` - Status string `json:"status"` + ResultType ResultType `json:"result_type"` + ErrorCode ErrorCode `json:"error_code"` + ErrorReason string `json:"error_reason"` + IsTimeout bool `json:"is_timeout"` + WalletID string `json:"wallet_id"` + Status string `json:"status"` } diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index 7199306..742ce68 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -1041,40 +1041,47 @@ func (ec *eventConsumer) consumePresignEvent() error { logger.Error("Failed to unmarshal presign message", err) return } - if err := ec.identityStore.VerifyInitiatorMessage(msg); err != nil { + if err := ec.identityStore.VerifyInitiatorMessage(&msg); err != nil { logger.Error("Failed to verify initiator message", err) return } // Only CGGMP21 supports presign if msg.KeyType != types.KeyTypeCGGMP21 { - ec.handlePresignSessionError(msg.WalletID, msg.TxID, msg.NetworkInternalCode, + ec.handlePresignSessionError(msg.WalletID, fmt.Errorf("presign is only supported for CGGMP21 key type"), - "Unsupported key type for presign", natMsg) + "Unsupported key type for presign", + natMsg, + ) return } - session, err := ec.node.CreateTaurusSession(msg.WalletID, ec.mpcThreshold, msg.KeyType, taurus.ActPresign) if err != nil { - ec.handlePresignSessionError(msg.WalletID, msg.TxID, msg.NetworkInternalCode, - err, "Failed to create presign session", natMsg) + ec.handlePresignSessionError(msg.WalletID, + err, "Failed to create presign session", + natMsg, + ) return } ctx := context.Background() - success, err := session.Presign(ctx, msg.TxID) + success, err := session.Presign(ctx, msg.WalletID) if err != nil { - ec.handlePresignSessionError(msg.WalletID, msg.TxID, msg.NetworkInternalCode, - err, "Presign operation failed", natMsg) + ec.handlePresignSessionError(msg.WalletID, + err, "Presign operation failed", + natMsg, + ) return } if success { - ec.handlePresignSessionSuccess(msg.WalletID, msg.TxID, msg.NetworkInternalCode, natMsg) + ec.handlePresignSessionSuccess(msg.WalletID, natMsg) } else { - ec.handlePresignSessionError(msg.WalletID, msg.TxID, msg.NetworkInternalCode, + ec.handlePresignSessionError(msg.WalletID, fmt.Errorf("presign operation returned false"), - "Presign operation failed", natMsg) + "Presign operation failed", + natMsg, + ) } }) if err != nil { @@ -1086,78 +1093,68 @@ func (ec *eventConsumer) consumePresignEvent() error { } // handlePresignSessionSuccess handles successful presign operations -func (ec *eventConsumer) handlePresignSessionSuccess(walletID, txID, networkInternalCode string, natMsg *nats.Msg) { +func (ec *eventConsumer) handlePresignSessionSuccess(walletID string, natMsg *nats.Msg) { presignResult := event.PresignResultEvent{ - ResultType: event.ResultTypeSuccess, - NetworkInternalCode: networkInternalCode, - WalletID: walletID, - TxID: txID, - Status: "success", + ResultType: event.ResultTypeSuccess, + WalletID: walletID, + Status: "success", } presignResultBytes, err := json.Marshal(presignResult) if err != nil { logger.Error("Failed to marshal presign result event", err, "walletID", walletID, - "txID", txID, ) return } err = ec.presignResultQueue.Enqueue(event.PresignResultTopic, presignResultBytes, &messaging.EnqueueOptions{ - IdempotententKey: composePresignIdempotentKey(txID, natMsg), + IdempotententKey: composePresignIdempotentKey(walletID, natMsg), }) if err != nil { logger.Error("Failed to enqueue presign result event", err, "walletID", walletID, - "txID", txID, "payload", string(presignResultBytes), ) } // Presign events don't use reply inboxes, so no need to send reply - logger.Info("[COMPLETED PRESIGN] Presign completed successfully", "walletID", walletID, "txID", txID) + logger.Info("[COMPLETED PRESIGN] Presign completed successfully", "walletID", walletID) } // handlePresignSessionError handles errors that occur during presign operations -func (ec *eventConsumer) handlePresignSessionError(walletID, txID, networkInternalCode string, err error, contextMsg string, natMsg *nats.Msg) { +func (ec *eventConsumer) handlePresignSessionError(walletID string, err error, contextMsg string, natMsg *nats.Msg) { fullErrMsg := fmt.Sprintf("%s: %v", contextMsg, err) errorCode := event.GetErrorCodeFromError(err) logger.Warn("Presign session error", "walletID", walletID, - "txID", txID, - "networkInternalCode", networkInternalCode, "error", err.Error(), "errorCode", errorCode, "context", contextMsg, ) presignResult := event.PresignResultEvent{ - ResultType: event.ResultTypeError, - ErrorCode: errorCode, - NetworkInternalCode: networkInternalCode, - WalletID: walletID, - TxID: txID, - ErrorReason: fullErrMsg, - Status: "failed", + ResultType: event.ResultTypeError, + ErrorCode: errorCode, + WalletID: walletID, + ErrorReason: fullErrMsg, + Status: "failed", } presignResultBytes, err := json.Marshal(presignResult) if err != nil { logger.Error("Failed to marshal presign result event", err, "walletID", walletID, - "txID", txID, ) return } err = ec.presignResultQueue.Enqueue(event.PresignResultTopic, presignResultBytes, &messaging.EnqueueOptions{ - IdempotententKey: composePresignIdempotentKey(txID, natMsg), + IdempotententKey: composePresignIdempotentKey(walletID, natMsg), }) if err != nil { logger.Error("Failed to enqueue presign result event", err, "walletID", walletID, - "txID", txID, "payload", string(presignResultBytes), ) } @@ -1165,8 +1162,8 @@ func (ec *eventConsumer) handlePresignSessionError(walletID, txID, networkIntern } // composePresignIdempotentKey creates an idempotent key for presign operations -func composePresignIdempotentKey(txID string, natMsg *nats.Msg) string { - return fmt.Sprintf("presign:%s:%s", txID, natMsg.Header.Get("Nats-Msg-Id")) +func composePresignIdempotentKey(walletID string, natMsg *nats.Msg) string { + return fmt.Sprintf("presign:%s:%s", walletID, natMsg.Header.Get("Nats-Msg-Id")) } // Add a cleanup routine that runs periodically diff --git a/pkg/types/initiator_msg.go b/pkg/types/initiator_msg.go index f3e7e8d..a86118d 100644 --- a/pkg/types/initiator_msg.go +++ b/pkg/types/initiator_msg.go @@ -53,12 +53,9 @@ type ResharingMessage struct { } type PresignTxMessage struct { - KeyType KeyType `json:"key_type"` - WalletID string `json:"wallet_id"` - NetworkInternalCode string `json:"network_internal_code"` - TxID string `json:"tx_id"` - Tx []byte `json:"tx"` - Signature []byte `json:"signature"` + KeyType KeyType `json:"key_type"` + WalletID string `json:"wallet_id"` + Signature []byte `json:"signature"` } func (m *SignTxMessage) Raw() ([]byte, error) { @@ -116,17 +113,11 @@ func (m *ResharingMessage) InitiatorID() string { func (m PresignTxMessage) Raw() ([]byte, error) { // omit the Signature field itself when computing the signed‐over data payload := struct { - KeyType KeyType `json:"key_type"` - WalletID string `json:"wallet_id"` - NetworkInternalCode string `json:"network_internal_code"` - TxID string `json:"tx_id"` - Tx []byte `json:"tx"` + KeyType KeyType `json:"key_type"` + WalletID string `json:"wallet_id"` }{ - KeyType: m.KeyType, - WalletID: m.WalletID, - NetworkInternalCode: m.NetworkInternalCode, - TxID: m.TxID, - Tx: m.Tx, + KeyType: m.KeyType, + WalletID: m.WalletID, } return json.Marshal(payload) } @@ -136,5 +127,5 @@ func (m PresignTxMessage) Sig() []byte { } func (m PresignTxMessage) InitiatorID() string { - return m.TxID + return m.WalletID } diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go deleted file mode 100644 index bd62682..0000000 --- a/pkg/utils/utils.go +++ /dev/null @@ -1,20 +0,0 @@ -package utils - -import ( - "crypto/sha256" - "io" - "os" - - "github.com/rs/zerolog" -) - -// GetMessageHash returns the SHA256 hash of the message -func GetMessageHash(msgBytes []byte) []byte { - hash := sha256.Sum256(msgBytes) - return hash[:] -} - -// ZerologConsoleWriter returns a console writer for zerolog -func ZerologConsoleWriter() io.Writer { - return zerolog.ConsoleWriter{Out: os.Stdout} -} diff --git a/wallets.json b/wallets.json deleted file mode 100644 index 0e05a2d..0000000 --- a/wallets.json +++ /dev/null @@ -1,3 +0,0 @@ -[ - "fb89e64c-e2ee-4e1c-a04e-2fa728dae170" -] \ No newline at end of file From 07e8c7fc8273dc9bdcd1156f6be316e6a36d3079 Mon Sep 17 00:00:00 2001 From: vietddude Date: Thu, 6 Nov 2025 17:10:47 +0700 Subject: [PATCH 16/21] feat: update wallet creation and signing processes to support multiple protocols --- examples/generate/kms/main.go | 6 +- examples/presign/main.go | 5 +- examples/reshare/main.go | 1 + examples/sign/main.go | 7 +- go.mod | 7 +- pkg/client/client.go | 8 +- pkg/event/keygen.go | 8 +- pkg/event/presign.go | 1 + pkg/eventconsumer/event_consumer.go | 645 +++++----------------------- pkg/eventconsumer/keygen_runner.go | 153 +++++++ pkg/eventconsumer/reshare_runner.go | 390 +++++++++++++++++ pkg/eventconsumer/sign_runner.go | 274 ++++++++++++ pkg/mpc/node.go | 10 +- pkg/types/initiator_msg.go | 133 ++++-- 14 files changed, 1049 insertions(+), 599 deletions(-) create mode 100644 pkg/eventconsumer/keygen_runner.go create mode 100644 pkg/eventconsumer/reshare_runner.go create mode 100644 pkg/eventconsumer/sign_runner.go diff --git a/examples/generate/kms/main.go b/examples/generate/kms/main.go index 9bf4025..a27d54f 100644 --- a/examples/generate/kms/main.go +++ b/examples/generate/kms/main.go @@ -117,7 +117,11 @@ func main() { for _, walletID := range walletIDs { wg.Add(1) // Add to WaitGroup BEFORE attempting to create wallet - if err := mpcClient.CreateWallet(walletID); err != nil { + if err := mpcClient.CreateWallet(&types.GenerateKeyMessage{ + WalletID: walletID, + ECDSAProtocol: types.ProtocolCGGMP21, + EdDSAProtocol: types.ProtocolGG18, + }); err != nil { logger.Error("CreateWallet failed", err) walletStartTimes.Delete(walletID) wg.Done() // Now this is safe since we added 1 above diff --git a/examples/presign/main.go b/examples/presign/main.go index 16762c9..77de2e9 100644 --- a/examples/presign/main.go +++ b/examples/presign/main.go @@ -12,6 +12,7 @@ import ( "github.com/fystack/mpcium/pkg/event" "github.com/fystack/mpcium/pkg/logger" "github.com/fystack/mpcium/pkg/types" + "github.com/google/uuid" "github.com/nats-io/nats.go" "github.com/spf13/viper" ) @@ -65,8 +66,10 @@ func main() { }) txMsg := &types.PresignTxMessage{ - KeyType: types.KeyTypeCGGMP21, + KeyType: types.KeyTypeSecp256k1, + Protocol: types.ProtocolCGGMP21, WalletID: "196c6858-30de-4a49-9134-8bc825d40764", // Use the generated wallet ID + TxID: uuid.New().String(), } err = mpcClient.PresignTransaction(txMsg) if err != nil { diff --git a/examples/reshare/main.go b/examples/reshare/main.go index 47c4d85..03f6c5b 100644 --- a/examples/reshare/main.go +++ b/examples/reshare/main.go @@ -88,6 +88,7 @@ func main() { NewThreshold: 1, // t+1 <= len(NodeIDs) KeyType: types.KeyTypeEd25519, + Protocol: types.ProtocolFROST, } err = mpcClient.Resharing(resharingMsg) if err != nil { diff --git a/examples/sign/main.go b/examples/sign/main.go index 3424610..51768b9 100644 --- a/examples/sign/main.go +++ b/examples/sign/main.go @@ -70,8 +70,9 @@ func main() { dummyTx := []byte("deadbeef") // replace with real transaction bytes txMsg := &types.SignTxMessage{ - KeyType: types.KeyTypeEd25519, - WalletID: "ad24f678-b04b-4149-bcf6-bf9c90df8e63", // Use the generated wallet ID + KeyType: types.KeyTypeSecp256k1, + Protocol: types.ProtocolFROST, + WalletID: "6d553e80-a1dc-4894-9eaf-b81e3fe0c94a", // Use the generated wallet ID NetworkInternalCode: "solana-devnet", TxID: txID, Tx: dummyTx, @@ -87,6 +88,8 @@ func main() { logger.Info("Signing result received", "txID", evt.TxID, "signature", fmt.Sprintf("%x", evt.Signature), + "error", evt.ErrorReason, + "errorCode", evt.ErrorCode, ) }) if err != nil { diff --git a/go.mod b/go.mod index 21eee92..8f858c9 100644 --- a/go.mod +++ b/go.mod @@ -11,9 +11,10 @@ require ( github.com/aws/aws-sdk-go-v2/credentials v1.18.8 github.com/aws/aws-sdk-go-v2/service/kms v1.45.0 github.com/bnb-chain/tss-lib/v2 v2.0.2 + github.com/btcsuite/btcd/btcec/v2 v2.3.2 github.com/decred/dcrd/dcrec/edwards/v2 v2.0.3 - github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 github.com/dgraph-io/badger/v4 v4.7.0 + github.com/fxamacker/cbor/v2 v2.4.0 github.com/google/uuid v1.6.0 github.com/hashicorp/consul/api v1.32.1 github.com/mitchellh/mapstructure v1.5.0 @@ -25,6 +26,7 @@ require ( github.com/taurusgroup/multi-party-sig v0.7.0-alpha-2025-01-28 github.com/urfave/cli/v3 v3.3.2 golang.org/x/crypto v0.37.0 + golang.org/x/sync v0.13.0 golang.org/x/term v0.31.0 ) @@ -43,17 +45,16 @@ require ( github.com/aws/aws-sdk-go-v2/service/sts v1.38.1 // indirect github.com/aws/smithy-go v1.23.0 // indirect github.com/btcsuite/btcd v0.24.2 // indirect - github.com/btcsuite/btcd/btcec/v2 v2.3.2 // indirect github.com/btcsuite/btcd/chaincfg/chainhash v1.1.0 // indirect github.com/btcsuite/btcutil v1.0.2 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cronokirby/saferith v0.33.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 // indirect github.com/dgraph-io/ristretto/v2 v2.2.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/fatih/color v1.18.0 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect - github.com/fxamacker/cbor/v2 v2.4.0 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/gogo/protobuf v1.3.2 // indirect diff --git a/pkg/client/client.go b/pkg/client/client.go index 0413abf..2467091 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -19,7 +19,7 @@ const ( ) type MPCClient interface { - CreateWallet(walletID string) error + CreateWallet(msg *types.GenerateKeyMessage) error OnWalletCreationResult(callback func(event event.KeygenResultEvent)) error SignTransaction(msg *types.SignTxMessage) error @@ -110,11 +110,7 @@ func NewMPCClient(opts Options) MPCClient { } // CreateWallet generates a GenerateKeyMessage, signs it, and publishes it. -func (c *mpcClient) CreateWallet(walletID string) error { - // build the message - msg := &types.GenerateKeyMessage{ - WalletID: walletID, - } +func (c *mpcClient) CreateWallet(msg *types.GenerateKeyMessage) error { // compute the canonical raw bytes raw, err := msg.Raw() if err != nil { diff --git a/pkg/event/keygen.go b/pkg/event/keygen.go index 5cc9ec0..6e12da5 100644 --- a/pkg/event/keygen.go +++ b/pkg/event/keygen.go @@ -7,11 +7,9 @@ const ( ) type KeygenResultEvent struct { - WalletID string `json:"wallet_id"` - ECDSAPubKey []byte `json:"ecdsa_pub_key"` - EDDSAPubKey []byte `json:"eddsa_pub_key"` - CGGMP21PubKey []byte `json:"cggmp21_pub_key"` - TaprootPubKey []byte `json:"taproot_pub_key"` + WalletID string `json:"wallet_id"` + ECDSAPubKey []byte `json:"ecdsa_pub_key"` + EDDSAPubKey []byte `json:"eddsa_pub_key"` ResultType ResultType `json:"result_type"` ErrorReason string `json:"error_reason"` diff --git a/pkg/event/presign.go b/pkg/event/presign.go index 13b782e..7f5c0c0 100644 --- a/pkg/event/presign.go +++ b/pkg/event/presign.go @@ -13,5 +13,6 @@ type PresignResultEvent struct { ErrorReason string `json:"error_reason"` IsTimeout bool `json:"is_timeout"` WalletID string `json:"wallet_id"` + TxID string `json:"tx_id"` Status string `json:"status"` } diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index 742ce68..a3034aa 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "log" - "math/big" "sync" "time" @@ -163,140 +162,104 @@ func (ec *eventConsumer) handleKeyGenEvent(natMsg *nats.Msg) { baseCtx, baseCancel := context.WithTimeout(context.Background(), KeyGenTimeOut) defer baseCancel() - raw := natMsg.Data var msg types.GenerateKeyMessage - if err := json.Unmarshal(raw, &msg); err != nil { - logger.Error("Failed to unmarshal keygen message", err) + if err := json.Unmarshal(natMsg.Data, &msg); err != nil { ec.handleKeygenSessionError(msg.WalletID, err, "Failed to unmarshal keygen message", natMsg) return } - if err := ec.identityStore.VerifyInitiatorMessage(&msg); err != nil { - logger.Error("Failed to verify initiator message", err) ec.handleKeygenSessionError(msg.WalletID, err, "Failed to verify initiator message", natMsg) return } - walletID := msg.WalletID - ecdsaSession, err := ec.node.CreateKeyGenSession(mpc.SessionTypeECDSA, walletID, ec.mpcThreshold, ec.genKeyResultQueue) - if err != nil { - ec.handleKeygenSessionError(walletID, err, "Failed to create ECDSA key generation session", natMsg) - return - } - eddsaSession, err := ec.node.CreateKeyGenSession(mpc.SessionTypeEDDSA, walletID, ec.mpcThreshold, ec.genKeyResultQueue) - if err != nil { - ec.handleKeygenSessionError(walletID, err, "Failed to create EdDSA key generation session", natMsg) + if err := types.ValidateKeyProtocol(types.KeyTypeSecp256k1, msg.ECDSAProtocol); err != nil { + ec.handleKeygenSessionError(msg.WalletID, err, "Invalid ECDSA protocol", natMsg) return } - cggmp21Session, err := ec.node.CreateTaurusSession(walletID, ec.mpcThreshold, types.KeyTypeCGGMP21, taurus.ActKeygen) - if err != nil { - logger.Error("Failed to create CMP session", err, "walletID", walletID) - ec.handleKeygenSessionError(walletID, err, "Failed to create CMP key generation session", natMsg) - return - } - taprootSession, err := ec.node.CreateTaurusSession(walletID, ec.mpcThreshold, types.KeyTypeTaproot, taurus.ActKeygen) - if err != nil { - logger.Error("Failed to create Taproot session", err, "walletID", walletID) - ec.handleKeygenSessionError(walletID, err, "Failed to create Taproot key generation session", natMsg) + + if err := types.ValidateKeyProtocol(types.KeyTypeEd25519, msg.EdDSAProtocol); err != nil { + ec.handleKeygenSessionError(msg.WalletID, err, "Invalid EdDSA protocol", natMsg) return } - ecdsaSession.Init() - eddsaSession.Init() + walletID := msg.WalletID + logger.Info( + "[KEYGEN START]", + "walletID", + walletID, + "ecdsa_protocol", + msg.ECDSAProtocol, + "eddsa_protocol", + msg.EdDSAProtocol, + ) + + ctx, cancelAll := context.WithCancel(baseCtx) + defer cancelAll() - ctxEcdsa, doneEcdsa := context.WithCancel(baseCtx) - ctxEddsa, doneEddsa := context.WithCancel(baseCtx) - ctxCggmp21, doneCggmp21 := context.WithCancel(baseCtx) - ctxTaproot, doneTaproot := context.WithCancel(baseCtx) + successEvent := &event.KeygenResultEvent{ + WalletID: walletID, + ResultType: event.ResultTypeSuccess, + ECDSAPubKey: nil, + EDDSAPubKey: nil, + } - successEvent := &event.KeygenResultEvent{WalletID: walletID, ResultType: event.ResultTypeSuccess} + errCh := make(chan error, 2) var wg sync.WaitGroup - wg.Add(4) - // Channel to communicate errors from goroutines to main function - errorChan := make(chan error, 4) + wg.Add(2) + // run ECDSA keygen go func() { defer wg.Done() - select { - case <-ctxEcdsa.Done(): - successEvent.ECDSAPubKey = ecdsaSession.GetPubKeyResult() - case err := <-ecdsaSession.ErrChan(): - logger.Error("ECDSA keygen session error", err) - ec.handleKeygenSessionError(walletID, err, "ECDSA keygen session error", natMsg) - errorChan <- err - doneEcdsa() - } - }() - go func() { - defer wg.Done() - select { - case <-ctxEddsa.Done(): - successEvent.EDDSAPubKey = eddsaSession.GetPubKeyResult() - case err := <-eddsaSession.ErrChan(): - logger.Error("EdDSA keygen session error", err) - ec.handleKeygenSessionError(walletID, err, "EdDSA keygen session error", natMsg) - errorChan <- err - doneEddsa() - } - }() - go func() { - defer wg.Done() - data, err := cggmp21Session.Keygen(ctxCggmp21) + pub, err := ec.runECDSAKeygen(ctx, walletID, msg.ECDSAProtocol, natMsg) if err != nil { - logger.Error("Failed to generate key", err) - errorChan <- err + errCh <- err + cancelAll() return } - - logger.Info("CGGMP21 Keygen completed successfully", "walletID", walletID, "payloadLength", len(data.Payload)) - successEvent.CGGMP21PubKey = data.PubKeyBytes - doneCggmp21() + successEvent.ECDSAPubKey = pub }() + // run EdDSA keygen go func() { defer wg.Done() - data, err := taprootSession.Keygen(ctxTaproot) + pub, err := ec.runEdDSAKeygen(ctx, walletID, msg.EdDSAProtocol, natMsg) if err != nil { - logger.Error("Failed to generate key", err) - errorChan <- err + errCh <- err + cancelAll() return } - - logger.Info("Taproot Keygen completed successfully", "walletID", walletID, "payloadLength", len(data.Payload)) - successEvent.TaprootPubKey = data.PubKeyBytes - doneTaproot() + successEvent.EDDSAPubKey = pub }() - ecdsaSession.ListenToIncomingMessageAsync() - eddsaSession.ListenToIncomingMessageAsync() - - // Temporary delay for peer setup - ec.warmUpSession() - go ecdsaSession.GenerateKey(doneEcdsa) - go eddsaSession.GenerateKey(doneEddsa) - - // Wait for completion or timeout - doneAll := make(chan struct{}) + waitDone := make(chan struct{}) go func() { wg.Wait() - close(doneAll) + close(waitDone) }() select { - case <-doneAll: - // Check if any errors occurred during execution - select { - case <-errorChan: - // Error already handled by the goroutine, just return early - return - default: - // No errors, continue with success + case <-waitDone: + close(errCh) + for err := range errCh { + if err != nil { + return + } + } + case err := <-errCh: + cancelAll() + if err != nil { + logger.Error("keygen failed", err, "walletID", walletID) } + return case <-baseCtx.Done(): - // timeout occurred - logger.Warn("Key generation timed out", "walletID", walletID, "timeout", KeyGenTimeOut) - ec.handleKeygenSessionError(walletID, fmt.Errorf("keygen session timed out after %v", KeyGenTimeOut), "Key generation timed out", natMsg) + cancelAll() + ec.handleKeygenSessionError( + walletID, + fmt.Errorf("keygen timeout after %v", KeyGenTimeOut), + "Keygen timeout", + natMsg, + ) return } @@ -406,6 +369,18 @@ func (ec *eventConsumer) handleSigningEvent(natMsg *nats.Msg) { return } + if verr := types.ValidateKeyProtocol(msg.KeyType, msg.Protocol); verr != nil { + ec.handleSigningSessionError( + msg.WalletID, + msg.TxID, + msg.NetworkInternalCode, + verr, + verr.Error(), + natMsg, + ) + return + } + logger.Info( "Received signing event", "waleltID", @@ -420,216 +395,31 @@ func (ec *eventConsumer) handleSigningEvent(natMsg *nats.Msg) { // Check for duplicate session and track if new if ec.checkDuplicateSession(msg.WalletID, msg.TxID) { - duplicateErr := fmt.Errorf("duplicate signing request detected for walletID=%s txID=%s", msg.WalletID, msg.TxID) - ec.handleSigningSessionError( - msg.WalletID, - msg.TxID, - msg.NetworkInternalCode, - duplicateErr, - "Duplicate session", - natMsg, - ) - return - } - - var session mpc.SigningSession - idempotentKey := composeSigningIdempotentKey(msg.TxID, natMsg) - var sessionErr error - switch msg.KeyType { - case types.KeyTypeSecp256k1: - session, sessionErr = ec.node.CreateSigningSession( - mpc.SessionTypeECDSA, - msg.WalletID, - msg.TxID, - msg.NetworkInternalCode, - ec.signingResultQueue, - idempotentKey, - ) - case types.KeyTypeEd25519: - session, sessionErr = ec.node.CreateSigningSession( - mpc.SessionTypeEDDSA, - msg.WalletID, - msg.TxID, - msg.NetworkInternalCode, - ec.signingResultQueue, - idempotentKey, - ) - case types.KeyTypeCGGMP21, types.KeyTypeTaproot, types.KeyTypeFROST: - ec.handleTaurusSigning(msg.KeyType, msg, natMsg) - return - default: - sessionErr = fmt.Errorf("unsupported key type: %v", msg.KeyType) - } - if sessionErr != nil { - if errors.Is(sessionErr, mpc.ErrNotEnoughParticipants) { - logger.Info( - "RETRY LATER: Not enough participants to sign", - "walletID", msg.WalletID, - "txID", msg.TxID, - "nodeID", ec.node.ID(), - ) - //Return for retry later - return - } - - if errors.Is(sessionErr, mpc.ErrNotInParticipantList) { - logger.Info("Node is not in participant list for this wallet, skipping signing", - "walletID", msg.WalletID, - "txID", msg.TxID, - "nodeID", ec.node.ID(), - ) - // Skip signing instead of treating as error - return - } - - ec.handleSigningSessionError( - msg.WalletID, - msg.TxID, - msg.NetworkInternalCode, - sessionErr, - "Failed to create signing session", - natMsg, - ) - return - } - - txBigInt := new(big.Int).SetBytes(msg.Tx) - err = session.Init(txBigInt) - if err != nil { - ec.handleSigningSessionError( - msg.WalletID, - msg.TxID, - msg.NetworkInternalCode, - err, - "Failed to init signing session", - natMsg, - ) - return - } - - // Mark session as already processed - ec.addSession(msg.WalletID, msg.TxID) - - ctx, done := context.WithCancel(context.Background()) - go func() { - for { - select { - case <-ctx.Done(): - return - case err := <-session.ErrChan(): - if err != nil { - ec.handleSigningSessionError( - msg.WalletID, - msg.TxID, - msg.NetworkInternalCode, - err, - "Failed to sign tx", - natMsg, - ) - return - } - } - } - }() - - session.ListenToIncomingMessageAsync() - // TODO: use consul distributed lock here, only sign after all nodes has already completed listing to incoming message async - // The purpose of the sleep is to be ensuring that the node has properly set up its message listeners - // before it starts the signing process. If the signing process starts sending messages before other nodes - // have set up their listeners, those messages might be missed, potentially causing the signing process to fail. - // One solution: - // The messaging includes mechanisms for direct point-to-point communication (in point2point.go). - // The nodes could explicitly coordinate through request-response patterns before starting signing - ec.warmUpSession() - - onSuccess := func(data []byte) { - done() - ec.sendReplyToRemoveMsg(natMsg) - } - go session.Sign(onSuccess) -} - -func (ec *eventConsumer) handleTaurusSigning(keyType types.KeyType, msg types.SignTxMessage, natMsg *nats.Msg) { - logger.Info("Starting signing", "walletID", msg.WalletID, "txID", msg.TxID, "keyType", keyType) - session, err := ec.node.CreateTaurusSession(msg.WalletID, ec.mpcThreshold, keyType, taurus.ActSign) - if err != nil { - logger.Error("Failed to create session", err, "walletID", msg.WalletID) - ec.handleSigningSessionError( - msg.WalletID, - msg.TxID, - msg.NetworkInternalCode, - err, - fmt.Sprintf("Failed to create %s session: %v", keyType, err), - natMsg, - ) - return - } - - // Convert transaction bytes to big.Int - txBigInt := new(big.Int).SetBytes(msg.Tx) - - // Create context for signing - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - signature, err := session.Sign(ctx, txBigInt) - if err != nil { - logger.Error("signing failed", err, "keyType", keyType, "walletID", msg.WalletID, "txID", msg.TxID) - ec.handleSigningSessionError( + duplicateErr := fmt.Errorf( + "duplicate signing request detected for walletID=%s txID=%s", msg.WalletID, msg.TxID, - msg.NetworkInternalCode, - err, - fmt.Sprintf("%s signing failed", keyType), - natMsg, ) - return - } - - // Create signing result event - signingResult := event.SigningResultEvent{ - ResultType: event.ResultTypeSuccess, - NetworkInternalCode: msg.NetworkInternalCode, - WalletID: msg.WalletID, - TxID: msg.TxID, - Signature: signature, // Returns the full signature - } - - // Marshal and enqueue the result - signingResultBytes, err := json.Marshal(signingResult) - if err != nil { - logger.Error("Failed to marshal signing result event", err, "keyType", keyType, "walletID", msg.WalletID, "txID", msg.TxID) ec.handleSigningSessionError( msg.WalletID, msg.TxID, msg.NetworkInternalCode, - err, - fmt.Sprintf("Failed to marshal %s signing result", keyType), + duplicateErr, + "Duplicate session", natMsg, ) return } - // Enqueue the signing result - err = ec.signingResultQueue.Enqueue(event.SigningResultCompleteTopic, signingResultBytes, &messaging.EnqueueOptions{ - IdempotententKey: composeSigningIdempotentKey(msg.TxID, natMsg), - }) - if err != nil { - logger.Error("Failed to enqueue signing result event", err, "keyType", keyType, "walletID", msg.WalletID, "txID", msg.TxID) - ec.handleSigningSessionError( - msg.WalletID, - msg.TxID, - msg.NetworkInternalCode, - err, - fmt.Sprintf("Failed to enqueue %s signing result", keyType), - natMsg, - ) + // Route Taurus signing by algorithm (matches keygen behavior) + if msg.Protocol == types.ProtocolCGGMP21 || msg.Protocol == types.ProtocolTaproot || + msg.Protocol == types.ProtocolFROST { + ec.handleTaurusSigning(msg.Protocol, msg, natMsg) return } - // Send reply and log success - ec.sendReplyToRemoveMsg(natMsg) - logger.Info("[COMPLETED SIGN] signing completed successfully", "keyType", keyType, "walletID", msg.WalletID, "txID", msg.TxID) + // Classic signing (ECDSA/EDDSA) + ec.runClassicSigning(msg, natMsg) } func (ec *eventConsumer) consumeTxSigningEvent() error { @@ -726,7 +516,25 @@ func (ec *eventConsumer) consumeReshareEvent() error { if err := ec.identityStore.VerifyInitiatorMessage(&msg); err != nil { logger.Error("Failed to verify initiator message", err) - ec.handleReshareSessionError(msg.WalletID, msg.KeyType, msg.NewThreshold, err, "Failed to verify initiator message", natMsg) + ec.handleReshareSessionError( + msg.WalletID, + msg.KeyType, + msg.NewThreshold, + err, + "Failed to verify initiator message", + natMsg, + ) + return + } + if verr := types.ValidateKeyProtocol(msg.KeyType, msg.Protocol); verr != nil { + ec.handleReshareSessionError( + msg.WalletID, + msg.KeyType, + msg.NewThreshold, + verr, + verr.Error(), + natMsg, + ) return } @@ -736,254 +544,29 @@ func (ec *eventConsumer) consumeReshareEvent() error { sessionType, err := sessionTypeFromKeyType(keyType) if err != nil { logger.Error("Failed to get session type", err) - ec.handleReshareSessionError(walletID, keyType, msg.NewThreshold, err, "Failed to get session type", natMsg) - return - } - // Handle CMP reshare separately - if keyType == types.KeyTypeCGGMP21 || keyType == types.KeyTypeTaproot || keyType == types.KeyTypeFROST { - ec.handleTaurusReshare(msg, natMsg) - return - } - - createSession := func(isNewPeer bool) (mpc.ReshareSession, error) { - return ec.node.CreateReshareSession( - sessionType, + ec.handleReshareSessionError( walletID, + keyType, msg.NewThreshold, - msg.NodeIDs, - isNewPeer, - ec.reshareResultQueue, + err, + "Failed to get session type", + natMsg, ) - } - - oldSession, err := createSession(false) - if err != nil { - logger.Error("Failed to create old reshare session", err, "walletID", walletID) - ec.handleReshareSessionError(walletID, keyType, msg.NewThreshold, err, "Failed to create old reshare session", natMsg) - return - } - newSession, err := createSession(true) - if err != nil { - logger.Error("Failed to create new reshare session", err, "walletID", walletID) - ec.handleReshareSessionError(walletID, keyType, msg.NewThreshold, err, "Failed to create new reshare session", natMsg) return } - - if oldSession == nil && newSession == nil { - logger.Info("Node is not participating in this reshare (neither old nor new)", "walletID", walletID) + // Handle CMP reshare separately by algorithm + if msg.Protocol == types.ProtocolCGGMP21 || msg.Protocol == types.ProtocolTaproot || + msg.Protocol == types.ProtocolFROST { + ec.handleTaurusReshare(msg, natMsg) return } - ctx := context.Background() - var wg sync.WaitGroup - - successEvent := &event.ResharingResultEvent{ - WalletID: walletID, - NewThreshold: msg.NewThreshold, - KeyType: msg.KeyType, - ResultType: event.ResultTypeSuccess, - } - - if oldSession != nil { - err := oldSession.Init() - if err != nil { - ec.handleReshareSessionError(walletID, keyType, msg.NewThreshold, err, "Failed to init old reshare session", natMsg) - return - } - oldSession.ListenToIncomingMessageAsync() - } - - if newSession != nil { - err := newSession.Init() - if err != nil { - ec.handleReshareSessionError(walletID, keyType, msg.NewThreshold, err, "Failed to init new reshare session", natMsg) - return - } - newSession.ListenToIncomingMessageAsync() - // In resharing process, we need to ensure that the new session is aware of the old committee peers. - // Then new committee peers can start listening to the old committee peers - // and thus enable receiving direct messages from them. - extraOldCommiteePeers := newSession.GetLegacyCommitteePeers() - newSession.ListenToPeersAsync(extraOldCommiteePeers) - } - - ec.warmUpSession() - if oldSession != nil { - ctxOld, doneOld := context.WithCancel(ctx) - go oldSession.Reshare(doneOld) - - wg.Add(1) - go func() { - defer wg.Done() - for { - select { - case <-ctxOld.Done(): - return - case err := <-oldSession.ErrChan(): - logger.Error("Old reshare session error", err) - ec.handleReshareSessionError(walletID, keyType, msg.NewThreshold, err, "Old reshare session error", natMsg) - doneOld() - return - } - } - }() - } - - if newSession != nil { - ctxNew, doneNew := context.WithCancel(ctx) - go newSession.Reshare(doneNew) - wg.Add(1) - go func() { - defer wg.Done() - for { - select { - case <-ctxNew.Done(): - successEvent.PubKey = newSession.GetPubKeyResult() - return - case err := <-newSession.ErrChan(): - logger.Error("New reshare session error", err) - ec.handleReshareSessionError(walletID, keyType, msg.NewThreshold, err, "New reshare session error", natMsg) - doneNew() - return - } - } - }() - } - - wg.Wait() - logger.Info("Reshare session finished", "walletID", walletID, "pubKey", fmt.Sprintf("%x", successEvent.PubKey)) - - if newSession != nil { - successBytes, err := json.Marshal(successEvent) - if err != nil { - logger.Error("Failed to marshal reshare success event", err) - ec.handleReshareSessionError(walletID, keyType, msg.NewThreshold, err, "Failed to marshal reshare success event", natMsg) - return - } - - key := fmt.Sprintf(mpc.TypeReshareWalletResultFmt, msg.SessionID) - err = ec.reshareResultQueue.Enqueue( - key, - successBytes, - &messaging.EnqueueOptions{ - IdempotententKey: composeReshareIdempotentKey(msg.SessionID, natMsg), - }) - if err != nil { - logger.Error("Failed to publish reshare success message", err) - ec.handleReshareSessionError(walletID, keyType, msg.NewThreshold, err, "Failed to publish reshare success message", natMsg) - return - } - logger.Info("[COMPLETED RESHARE] Successfully published", "walletID", walletID) - } else { - logger.Info("[COMPLETED RESHARE] Done (not a new party)", "walletID", walletID) - } + ec.runClassicReshare(msg, natMsg, sessionType) }) - ec.reshareSub = sub return err } -// NOTE: In Taurus reshare, it just refresh the keyshare of each node but keep the same public key and threshold. -// Therefore, we don't need to create new party sessions for CMP reshare. -func (ec *eventConsumer) handleTaurusReshare(msg types.ResharingMessage, natMsg *nats.Msg) { - logger.Info("Starting reshare", "walletID", msg.WalletID, "sessionID", msg.SessionID, "keyType", msg.KeyType) - - // Create Taurus session for reshare - session, err := ec.node.CreateTaurusSession(msg.WalletID, msg.NewThreshold, types.KeyTypeCGGMP21, taurus.ActReshare) - if err != nil { - logger.Error("Failed to create reshare session", err, "walletID", msg.WalletID, "keyType", msg.KeyType) - ec.handleReshareSessionError( - msg.WalletID, - msg.KeyType, - msg.NewThreshold, - err, - fmt.Sprintf("Failed to create %s reshare session", msg.KeyType), - natMsg, - ) - return - } - - // Load the existing key for reshare - if err := session.LoadKey(msg.WalletID); err != nil { - logger.Error("Failed to load key for reshare", err, "walletID", msg.WalletID, "keyType", msg.KeyType) - ec.handleReshareSessionError( - msg.WalletID, - msg.KeyType, - msg.NewThreshold, - err, - fmt.Sprintf("Failed to load key for %s reshare", msg.KeyType), - natMsg, - ) - return - } - - // Create context for reshare - ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) // Longer timeout for reshare - defer cancel() - - // Perform reshare - keyData, err := session.Reshare(ctx) - if err != nil { - logger.Error("Reshare failed", err, "walletID", msg.WalletID, "sessionID", msg.SessionID, "keyType", msg.KeyType) - ec.handleReshareSessionError( - msg.WalletID, - msg.KeyType, - msg.NewThreshold, - err, - fmt.Sprintf("Reshare failed for %s", msg.KeyType), - natMsg, - ) - return - } - - // Create reshare result event - reshareResult := event.ResharingResultEvent{ - ResultType: event.ResultTypeSuccess, - WalletID: msg.WalletID, - NewThreshold: keyData.Threshold, - KeyType: msg.KeyType, - PubKey: keyData.PubKeyBytes, - } - - // Marshal and enqueue the result - reshareResultBytes, err := json.Marshal(reshareResult) - if err != nil { - logger.Error("Failed to marshal reshare result event", err, "walletID", msg.WalletID, "sessionID", msg.SessionID, "keyType", msg.KeyType) - ec.handleReshareSessionError( - msg.WalletID, - msg.KeyType, - msg.NewThreshold, - err, - fmt.Sprintf("Failed to marshal %s reshare result", msg.KeyType), - natMsg, - ) - return - } - - // Enqueue the reshare result - key := fmt.Sprintf(mpc.TypeReshareWalletResultFmt, msg.SessionID) - err = ec.reshareResultQueue.Enqueue(key, reshareResultBytes, &messaging.EnqueueOptions{ - IdempotententKey: composeReshareIdempotentKey(msg.SessionID, natMsg), - }) - if err != nil { - logger.Error("Failed to enqueue reshare result event", err, "walletID", msg.WalletID, "sessionID", msg.SessionID, "keyType", msg.KeyType) - ec.handleReshareSessionError( - msg.WalletID, - msg.KeyType, - msg.NewThreshold, - err, - fmt.Sprintf("Failed to enqueue %s reshare result", msg.KeyType), - natMsg, - ) - return - } - - // Remove this line - don't send reply for reshare messages - // ec.sendReplyToRemoveMsg(natMsg) - - logger.Info("[COMPLETED RESHARE] CMP reshare completed successfully", "walletID", msg.WalletID, "sessionID", msg.SessionID) -} - // handleReshareSessionError handles errors that occur during reshare operations func (ec *eventConsumer) handleReshareSessionError( walletID string, @@ -1047,7 +630,7 @@ func (ec *eventConsumer) consumePresignEvent() error { } // Only CGGMP21 supports presign - if msg.KeyType != types.KeyTypeCGGMP21 { + if msg.Protocol != types.ProtocolCGGMP21 { ec.handlePresignSessionError(msg.WalletID, fmt.Errorf("presign is only supported for CGGMP21 key type"), "Unsupported key type for presign", @@ -1055,7 +638,7 @@ func (ec *eventConsumer) consumePresignEvent() error { ) return } - session, err := ec.node.CreateTaurusSession(msg.WalletID, ec.mpcThreshold, msg.KeyType, taurus.ActPresign) + session, err := ec.node.CreateTaurusSession(msg.WalletID, ec.mpcThreshold, msg.Protocol, taurus.ActPresign) if err != nil { ec.handlePresignSessionError(msg.WalletID, err, "Failed to create presign session", @@ -1251,12 +834,6 @@ func sessionTypeFromKeyType(keyType types.KeyType) (mpc.SessionType, error) { return mpc.SessionTypeECDSA, nil case types.KeyTypeEd25519: return mpc.SessionTypeEDDSA, nil - case types.KeyTypeCGGMP21: - return mpc.SessionTypeCGGMP21, nil - case types.KeyTypeTaproot: - return mpc.SessionTypeTaproot, nil - case types.KeyTypeFROST: - return mpc.SessionTypeFROST, nil default: logger.Warn("Unsupported key type", "keyType", keyType) return "", fmt.Errorf("unsupported key type: %v", keyType) diff --git a/pkg/eventconsumer/keygen_runner.go b/pkg/eventconsumer/keygen_runner.go new file mode 100644 index 0000000..3d52659 --- /dev/null +++ b/pkg/eventconsumer/keygen_runner.go @@ -0,0 +1,153 @@ +package eventconsumer + +import ( + "context" + "fmt" + + "github.com/fystack/mpcium/pkg/mpc" + "github.com/fystack/mpcium/pkg/mpc/taurus" + "github.com/fystack/mpcium/pkg/types" + "github.com/nats-io/nats.go" +) + +func (ec *eventConsumer) runECDSAKeygen( + ctx context.Context, + walletID string, + algo types.Protocol, + natMsg *nats.Msg, +) ([]byte, error) { + switch algo { + case types.ProtocolCGGMP21: + ts, err := ec.node.CreateTaurusSession( + walletID, + ec.mpcThreshold, + types.ProtocolCGGMP21, + taurus.ActKeygen, + ) + if err != nil { + return nil, err + } + res, err := ts.Keygen(ctx) + if err != nil { + return nil, err + } + return res.PubKeyBytes, nil + + case types.ProtocolFROST: + ts, err := ec.node.CreateTaurusSession( + walletID, + ec.mpcThreshold, + types.ProtocolFROST, + taurus.ActKeygen, + ) + if err != nil { + return nil, err + } + res, err := ts.Keygen(ctx) + if err != nil { + return nil, err + } + return res.PubKeyBytes, nil + + case types.ProtocolTaproot: + ts, err := ec.node.CreateTaurusSession( + walletID, + ec.mpcThreshold, + types.ProtocolTaproot, + taurus.ActKeygen, + ) + if err != nil { + return nil, err + } + res, err := ts.Keygen(ctx) + if err != nil { + return nil, err + } + return res.PubKeyBytes, nil + case types.ProtocolGG18: + fallthrough + default: + // Fallback to GG18 ECDSA when algorithm is GG18 or unspecified/unknown + sess, err := ec.node.CreateKeyGenSession( + mpc.SessionTypeECDSA, + walletID, + ec.mpcThreshold, + ec.genKeyResultQueue, + ) + if err != nil { + ec.handleKeygenSessionError( + walletID, + err, + "Failed to create ECDSA (GG18) session", + natMsg, + ) + return nil, err + } + sess.Init() + sess.ListenToIncomingMessageAsync() + ec.warmUpSession() + + ctxLocal, cancel := context.WithCancel(ctx) + defer cancel() + go sess.GenerateKey(cancel) + + select { + case err := <-sess.ErrChan(): + if err != nil { + return nil, err + } + case <-ctxLocal.Done(): + // success + case <-ctx.Done(): + return nil, fmt.Errorf("ECDSA keygen cancelled") + } + return sess.GetPubKeyResult(), nil + } +} + +func (ec *eventConsumer) runEdDSAKeygen( + ctx context.Context, + walletID string, + algo types.Protocol, + natMsg *nats.Msg, +) ([]byte, error) { + switch algo { + case types.ProtocolGG18: + fallthrough + default: + sess, err := ec.node.CreateKeyGenSession( + mpc.SessionTypeEDDSA, + walletID, + ec.mpcThreshold, + ec.genKeyResultQueue, + ) + if err != nil { + ec.handleKeygenSessionError( + walletID, + err, + "Failed to create EdDSA keygen session", + natMsg, + ) + return nil, err + } + sess.Init() + sess.ListenToIncomingMessageAsync() + ec.warmUpSession() + + ctxLocal, cancel := context.WithCancel(ctx) + defer cancel() + go sess.GenerateKey(cancel) + + select { + case err := <-sess.ErrChan(): + if err != nil { + return nil, err + } + case <-ctxLocal.Done(): + // success + case <-ctx.Done(): + return nil, fmt.Errorf("EdDSA keygen cancelled") + } + return sess.GetPubKeyResult(), nil + } +} diff --git a/pkg/eventconsumer/reshare_runner.go b/pkg/eventconsumer/reshare_runner.go new file mode 100644 index 0000000..348db57 --- /dev/null +++ b/pkg/eventconsumer/reshare_runner.go @@ -0,0 +1,390 @@ +package eventconsumer + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "time" + + "github.com/fystack/mpcium/pkg/event" + "github.com/fystack/mpcium/pkg/logger" + "github.com/fystack/mpcium/pkg/messaging" + "github.com/fystack/mpcium/pkg/mpc" + "github.com/fystack/mpcium/pkg/mpc/taurus" + "github.com/fystack/mpcium/pkg/types" + "github.com/nats-io/nats.go" +) + +// NOTE: In Taurus reshare, it just refresh the keyshare of each node but keep the same public key and threshold. +// Therefore, we don't need to create new party sessions for CMP reshare. +func (ec *eventConsumer) handleTaurusReshare(msg types.ResharingMessage, natMsg *nats.Msg) { + logger.Info( + "Starting reshare", + "walletID", + msg.WalletID, + "sessionID", + msg.SessionID, + "keyType", + msg.KeyType, + ) + + // Create Taurus session for reshare + session, err := ec.node.CreateTaurusSession( + msg.WalletID, + msg.NewThreshold, + msg.Protocol, + taurus.ActReshare, + ) + if err != nil { + logger.Error( + "Failed to create reshare session", + err, + "walletID", + msg.WalletID, + "keyType", + msg.KeyType, + ) + ec.handleReshareSessionError( + msg.WalletID, + msg.KeyType, + msg.NewThreshold, + err, + fmt.Sprintf("Failed to create %s reshare session", msg.KeyType), + natMsg, + ) + return + } + + // Load the existing key for reshare + if err := session.LoadKey(msg.WalletID); err != nil { + logger.Error( + "Failed to load key for reshare", + err, + "walletID", + msg.WalletID, + "keyType", + msg.KeyType, + ) + ec.handleReshareSessionError( + msg.WalletID, + msg.KeyType, + msg.NewThreshold, + err, + fmt.Sprintf("Failed to load key for %s reshare", msg.KeyType), + natMsg, + ) + return + } + + // Create context for reshare + ctx, cancel := context.WithTimeout( + context.Background(), + 60*time.Second, + ) // Longer timeout for reshare + defer cancel() + + // Perform reshare + keyData, err := session.Reshare(ctx) + if err != nil { + logger.Error( + "Reshare failed", + err, + "walletID", + msg.WalletID, + "sessionID", + msg.SessionID, + "keyType", + msg.KeyType, + ) + ec.handleReshareSessionError( + msg.WalletID, + msg.KeyType, + msg.NewThreshold, + err, + fmt.Sprintf("Reshare failed for %s", msg.KeyType), + natMsg, + ) + return + } + + // Create reshare result event + reshareResult := event.ResharingResultEvent{ + ResultType: event.ResultTypeSuccess, + WalletID: msg.WalletID, + NewThreshold: keyData.Threshold, + KeyType: msg.KeyType, + PubKey: keyData.PubKeyBytes, + } + + // Marshal and enqueue the result + reshareResultBytes, err := json.Marshal(reshareResult) + if err != nil { + logger.Error( + "Failed to marshal reshare result event", + err, + "walletID", + msg.WalletID, + "sessionID", + msg.SessionID, + "keyType", + msg.KeyType, + ) + ec.handleReshareSessionError( + msg.WalletID, + msg.KeyType, + msg.NewThreshold, + err, + fmt.Sprintf("Failed to marshal %s reshare result", msg.KeyType), + natMsg, + ) + return + } + + // Enqueue the reshare result + key := fmt.Sprintf(mpc.TypeReshareWalletResultFmt, msg.SessionID) + err = ec.reshareResultQueue.Enqueue(key, reshareResultBytes, &messaging.EnqueueOptions{ + IdempotententKey: composeReshareIdempotentKey(msg.SessionID, natMsg), + }) + if err != nil { + logger.Error( + "Failed to enqueue reshare result event", + err, + "walletID", + msg.WalletID, + "sessionID", + msg.SessionID, + "keyType", + msg.KeyType, + ) + ec.handleReshareSessionError( + msg.WalletID, + msg.KeyType, + msg.NewThreshold, + err, + fmt.Sprintf("Failed to enqueue %s reshare result", msg.KeyType), + natMsg, + ) + return + } + + // Remove this line - don't send reply for reshare messages + // ec.sendReplyToRemoveMsg(natMsg) + + logger.Info( + "[COMPLETED RESHARE] CMP reshare completed successfully", + "walletID", + msg.WalletID, + "sessionID", + msg.SessionID, + ) +} + +// runClassicReshare handles non-Taurus reshare flows (ECDSA/EDDSA) +func (ec *eventConsumer) runClassicReshare( + msg types.ResharingMessage, + natMsg *nats.Msg, + sessionType mpc.SessionType, +) { + walletID := msg.WalletID + keyType := msg.KeyType + + createSession := func(isNewPeer bool) (mpc.ReshareSession, error) { + return ec.node.CreateReshareSession( + sessionType, + walletID, + msg.NewThreshold, + msg.NodeIDs, + isNewPeer, + ec.reshareResultQueue, + ) + } + + oldSession, err := createSession(false) + if err != nil { + logger.Error("Failed to create old reshare session", err, "walletID", walletID) + ec.handleReshareSessionError( + walletID, + keyType, + msg.NewThreshold, + err, + "Failed to create old reshare session", + natMsg, + ) + return + } + newSession, err := createSession(true) + if err != nil { + logger.Error("Failed to create new reshare session", err, "walletID", walletID) + ec.handleReshareSessionError( + walletID, + keyType, + msg.NewThreshold, + err, + "Failed to create new reshare session", + natMsg, + ) + return + } + + if oldSession == nil && newSession == nil { + logger.Info( + "Node is not participating in this reshare (neither old nor new)", + "walletID", + walletID, + ) + return + } + + ctx := context.Background() + var wg sync.WaitGroup + + successEvent := &event.ResharingResultEvent{ + WalletID: walletID, + NewThreshold: msg.NewThreshold, + KeyType: msg.KeyType, + ResultType: event.ResultTypeSuccess, + } + + if oldSession != nil { + err := oldSession.Init() + if err != nil { + ec.handleReshareSessionError( + walletID, + keyType, + msg.NewThreshold, + err, + "Failed to init old reshare session", + natMsg, + ) + return + } + oldSession.ListenToIncomingMessageAsync() + } + + if newSession != nil { + err := newSession.Init() + if err != nil { + ec.handleReshareSessionError( + walletID, + keyType, + msg.NewThreshold, + err, + "Failed to init new reshare session", + natMsg, + ) + return + } + newSession.ListenToIncomingMessageAsync() + // In resharing process, we need to ensure that the new session is aware of the old committee peers. + // Then new committee peers can start listening to the old committee peers + // and thus enable receiving direct messages from them. + extraOldCommiteePeers := newSession.GetLegacyCommitteePeers() + newSession.ListenToPeersAsync(extraOldCommiteePeers) + } + + ec.warmUpSession() + if oldSession != nil { + ctxOld, doneOld := context.WithCancel(ctx) + go oldSession.Reshare(doneOld) + + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-ctxOld.Done(): + return + case err := <-oldSession.ErrChan(): + logger.Error("Old reshare session error", err) + ec.handleReshareSessionError( + walletID, + keyType, + msg.NewThreshold, + err, + "Old reshare session error", + natMsg, + ) + doneOld() + return + } + } + }() + } + + if newSession != nil { + ctxNew, doneNew := context.WithCancel(ctx) + go newSession.Reshare(doneNew) + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-ctxNew.Done(): + successEvent.PubKey = newSession.GetPubKeyResult() + return + case err := <-newSession.ErrChan(): + logger.Error("New reshare session error", err) + ec.handleReshareSessionError( + walletID, + keyType, + msg.NewThreshold, + err, + "New reshare session error", + natMsg, + ) + doneNew() + return + } + } + }() + } + + wg.Wait() + logger.Info( + "Reshare session finished", + "walletID", + walletID, + "pubKey", + fmt.Sprintf("%x", successEvent.PubKey), + ) + + if newSession != nil { + successBytes, err := json.Marshal(successEvent) + if err != nil { + logger.Error("Failed to marshal reshare success event", err) + ec.handleReshareSessionError( + walletID, + keyType, + msg.NewThreshold, + err, + "Failed to marshal reshare success event", + natMsg, + ) + return + } + + key := fmt.Sprintf(mpc.TypeReshareWalletResultFmt, msg.SessionID) + err = ec.reshareResultQueue.Enqueue( + key, + successBytes, + &messaging.EnqueueOptions{ + IdempotententKey: composeReshareIdempotentKey(msg.SessionID, natMsg), + }) + if err != nil { + logger.Error("Failed to publish reshare success message", err) + ec.handleReshareSessionError( + walletID, + keyType, + msg.NewThreshold, + err, + "Failed to publish reshare success message", + natMsg, + ) + return + } + logger.Info("[COMPLETED RESHARE] Successfully published", "walletID", walletID) + } else { + logger.Info("[COMPLETED RESHARE] Done (not a new party)", "walletID", walletID) + } +} diff --git a/pkg/eventconsumer/sign_runner.go b/pkg/eventconsumer/sign_runner.go new file mode 100644 index 0000000..8ac1c7a --- /dev/null +++ b/pkg/eventconsumer/sign_runner.go @@ -0,0 +1,274 @@ +package eventconsumer + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "math/big" + "time" + + "github.com/fystack/mpcium/pkg/event" + "github.com/fystack/mpcium/pkg/logger" + "github.com/fystack/mpcium/pkg/messaging" + "github.com/fystack/mpcium/pkg/mpc" + "github.com/fystack/mpcium/pkg/mpc/taurus" + "github.com/fystack/mpcium/pkg/types" + "github.com/nats-io/nats.go" +) + +func (ec *eventConsumer) handleTaurusSigning( + algorithm types.Protocol, + msg types.SignTxMessage, + natMsg *nats.Msg, +) { + logger.Info( + "Starting signing", + "walletID", + msg.WalletID, + "txID", + msg.TxID, + "algorithm", + algorithm, + ) + session, err := ec.node.CreateTaurusSession( + msg.WalletID, + ec.mpcThreshold, + algorithm, + taurus.ActSign, + ) + if err != nil { + logger.Error("Failed to create session", err, "walletID", msg.WalletID) + ec.handleSigningSessionError( + msg.WalletID, + msg.TxID, + msg.NetworkInternalCode, + err, + fmt.Sprintf("Failed to create %s session: %v", algorithm, err), + natMsg, + ) + return + } + + // Convert transaction bytes to big.Int + txBigInt := new(big.Int).SetBytes(msg.Tx) + + // Create context for signing + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + signature, err := session.Sign(ctx, txBigInt) + if err != nil { + logger.Error( + "signing failed", + err, + "algorithm", + algorithm, + "walletID", + msg.WalletID, + "txID", + msg.TxID, + ) + ec.handleSigningSessionError( + msg.WalletID, + msg.TxID, + msg.NetworkInternalCode, + err, + fmt.Sprintf("%s signing failed", algorithm), + natMsg, + ) + return + } + + // Create signing result event + signingResult := event.SigningResultEvent{ + ResultType: event.ResultTypeSuccess, + NetworkInternalCode: msg.NetworkInternalCode, + WalletID: msg.WalletID, + TxID: msg.TxID, + Signature: signature, // Returns the full signature + } + + // Marshal and enqueue the result + signingResultBytes, err := json.Marshal(signingResult) + if err != nil { + logger.Error( + "Failed to marshal signing result event", + err, + "algorithm", + algorithm, + "walletID", + msg.WalletID, + "txID", + msg.TxID, + ) + ec.handleSigningSessionError( + msg.WalletID, + msg.TxID, + msg.NetworkInternalCode, + err, + fmt.Sprintf("Failed to marshal %s signing result", algorithm), + natMsg, + ) + return + } + + // Enqueue the signing result + err = ec.signingResultQueue.Enqueue( + event.SigningResultCompleteTopic, + signingResultBytes, + &messaging.EnqueueOptions{ + IdempotententKey: composeSigningIdempotentKey(msg.TxID, natMsg), + }, + ) + if err != nil { + logger.Error( + "Failed to enqueue signing result event", + err, + "algorithm", + algorithm, + "walletID", + msg.WalletID, + "txID", + msg.TxID, + ) + ec.handleSigningSessionError( + msg.WalletID, + msg.TxID, + msg.NetworkInternalCode, + err, + fmt.Sprintf("Failed to enqueue %s signing result", algorithm), + natMsg, + ) + return + } + + // Send reply and log success + ec.sendReplyToRemoveMsg(natMsg) + logger.Info( + "[COMPLETED SIGN] signing completed successfully", + "algorithm", + algorithm, + "walletID", + msg.WalletID, + "txID", + msg.TxID, + ) +} + +// runClassicSigning handles non-Taurus signing flows (ECDSA/EDDSA) +func (ec *eventConsumer) runClassicSigning(msg types.SignTxMessage, natMsg *nats.Msg) { + var session mpc.SigningSession + idempotentKey := composeSigningIdempotentKey(msg.TxID, natMsg) + + var sessionErr error + switch msg.KeyType { + case types.KeyTypeSecp256k1: + session, sessionErr = ec.node.CreateSigningSession( + mpc.SessionTypeECDSA, + msg.WalletID, + msg.TxID, + msg.NetworkInternalCode, + ec.signingResultQueue, + idempotentKey, + ) + case types.KeyTypeEd25519: + session, sessionErr = ec.node.CreateSigningSession( + mpc.SessionTypeEDDSA, + msg.WalletID, + msg.TxID, + msg.NetworkInternalCode, + ec.signingResultQueue, + idempotentKey, + ) + default: + sessionErr = fmt.Errorf("unsupported key type: %v", msg.KeyType) + } + if sessionErr != nil { + if errors.Is(sessionErr, mpc.ErrNotEnoughParticipants) { + logger.Info( + "RETRY LATER: Not enough participants to sign", + "walletID", msg.WalletID, + "txID", msg.TxID, + "nodeID", ec.node.ID(), + ) + //Return for retry later + return + } + + if errors.Is(sessionErr, mpc.ErrNotInParticipantList) { + logger.Info("Node is not in participant list for this wallet, skipping signing", + "walletID", msg.WalletID, + "txID", msg.TxID, + "nodeID", ec.node.ID(), + ) + // Skip signing instead of treating as error + return + } + + ec.handleSigningSessionError( + msg.WalletID, + msg.TxID, + msg.NetworkInternalCode, + sessionErr, + "Failed to create signing session", + natMsg, + ) + return + } + + txBigInt := new(big.Int).SetBytes(msg.Tx) + err := session.Init(txBigInt) + if err != nil { + ec.handleSigningSessionError( + msg.WalletID, + msg.TxID, + msg.NetworkInternalCode, + err, + "Failed to init signing session", + natMsg, + ) + return + } + + // Mark session as already processed + ec.addSession(msg.WalletID, msg.TxID) + + ctx, done := context.WithCancel(context.Background()) + go func() { + for { + select { + case <-ctx.Done(): + return + case err := <-session.ErrChan(): + if err != nil { + ec.handleSigningSessionError( + msg.WalletID, + msg.TxID, + msg.NetworkInternalCode, + err, + "Failed to sign tx", + natMsg, + ) + return + } + } + } + }() + + session.ListenToIncomingMessageAsync() + // TODO: use consul distributed lock here, only sign after all nodes has already completed listing to incoming message async + // The purpose of the sleep is to be ensuring that the node has properly set up its message listeners + // before it starts the signing process. If the signing process starts sending messages before other nodes + // have set up their listeners, those messages might be missed, potentially causing the signing process to fail. + // One solution: + // The messaging includes mechanisms for direct point-to-point communication (in point2point.go). + // The nodes could explicitly coordinate through request-response patterns before starting signing + ec.warmUpSession() + + onSuccess := func(data []byte) { + done() + ec.sendReplyToRemoveMsg(natMsg) + } + go session.Sign(onSuccess) +} diff --git a/pkg/mpc/node.go b/pkg/mpc/node.go index aa78a68..d94a9ee 100644 --- a/pkg/mpc/node.go +++ b/pkg/mpc/node.go @@ -148,20 +148,20 @@ func (p *Node) createEDDSAKeyGenSession(walletID string, threshold int, version func (p *Node) CreateTaurusSession( walletID string, threshold int, - sessionType types.KeyType, + protocol types.Protocol, act taurus.Act, ) (taurus.TaurusSession, error) { readyPeerIDs := p.peerRegistry.GetReadyPeersIncludeSelf() selfPartyID, allPartyIDs := p.generateTaurusPartyIDs(PurposeKeygen, readyPeerIDs, DefaultVersion) var session taurus.TaurusSession - switch sessionType { - case types.KeyTypeCGGMP21: + switch protocol { + case types.ProtocolCGGMP21: tr := taurus.NewNATSTransport(walletID, selfPartyID, act, taurus.CGGMP21, p.pubSub, p.direct, p.identityStore) session = taurus.NewCGGMP21Session(walletID, selfPartyID, allPartyIDs, threshold, p.presignCache, tr, p.kvstore, p.keyinfoStore) - case types.KeyTypeTaproot: + case types.ProtocolTaproot: tr := taurus.NewNATSTransport(walletID, selfPartyID, act, taurus.FROSTTaproot, p.pubSub, p.direct, p.identityStore) session = taurus.NewTaprootSession(walletID, selfPartyID, allPartyIDs, threshold, tr, p.kvstore, p.keyinfoStore) - case types.KeyTypeFROST: + case types.ProtocolFROST: tr := taurus.NewNATSTransport(walletID, selfPartyID, act, taurus.FROST, p.pubSub, p.direct, p.identityStore) session = taurus.NewFROSTSession(walletID, selfPartyID, allPartyIDs, threshold, tr, p.kvstore, p.keyinfoStore) } diff --git a/pkg/types/initiator_msg.go b/pkg/types/initiator_msg.go index a86118d..82b2d48 100644 --- a/pkg/types/initiator_msg.go +++ b/pkg/types/initiator_msg.go @@ -1,46 +1,81 @@ package types -import "encoding/json" +import ( + "encoding/json" + "errors" + "fmt" +) + +type EventInitiatorKeyType string + +const ( + EventInitiatorKeyTypeEd25519 EventInitiatorKeyType = "ed25519" + EventInitiatorKeyTypeP256 EventInitiatorKeyType = "p256" +) type KeyType string const ( KeyTypeSecp256k1 KeyType = "secp256k1" KeyTypeEd25519 KeyType = "ed25519" - KeyTypeCGGMP21 KeyType = "cggmp21" - KeyTypeFROST KeyType = "frost" - KeyTypeTaproot KeyType = "taproot" ) -type EventInitiatorKeyType string +type Protocol string const ( - EventInitiatorKeyTypeEd25519 EventInitiatorKeyType = "ed25519" - EventInitiatorKeyTypeP256 EventInitiatorKeyType = "p256" + ProtocolGG18 Protocol = "gg18" + ProtocolCGGMP21 Protocol = "cggmp21" + ProtocolFROST Protocol = "frost" + ProtocolTaproot Protocol = "taproot" ) +func (p Protocol) String() string { + return string(p) +} + +// ValidateKeyProtocol checks if a key type supports a given protocol. +func ValidateKeyProtocol(keyType KeyType, protocol Protocol) error { + if keyType == "" || protocol == "" { + return errors.New("key_type and protocol are required") + } + + switch keyType { + case KeyTypeSecp256k1: + if protocol != ProtocolGG18 && protocol != ProtocolCGGMP21 { + return fmt.Errorf("protocol %q not supported for key_type %q; expected gg18 or cggmp21", protocol, keyType) + } + case KeyTypeEd25519: + if protocol != ProtocolFROST && protocol != ProtocolTaproot { + return fmt.Errorf("protocol %q not supported for key_type %q; expected frost or taproot", protocol, keyType) + } + default: + return fmt.Errorf("unsupported key_type %q", keyType) + } + return nil +} + // InitiatorMessage is anything that carries a payload to verify and its signature. type InitiatorMessage interface { - // Raw returns the canonical byte‐slice that was signed. Raw() ([]byte, error) - // Sig returns the signature over Raw(). Sig() []byte - // InitiatorID returns the ID whose public key we have to look up. InitiatorID() string } type GenerateKeyMessage struct { - WalletID string `json:"wallet_id"` - Signature []byte `json:"signature"` + WalletID string `json:"wallet_id"` + ECDSAProtocol Protocol `json:"ecdsa_protocol,omitempty"` + EdDSAProtocol Protocol `json:"eddsa_protocol,omitempty"` + Signature []byte `json:"signature"` } type SignTxMessage struct { - KeyType KeyType `json:"key_type"` - WalletID string `json:"wallet_id"` - NetworkInternalCode string `json:"network_internal_code"` - TxID string `json:"tx_id"` - Tx []byte `json:"tx"` - Signature []byte `json:"signature"` + KeyType KeyType `json:"key_type"` + Protocol Protocol `json:"protocol,omitempty"` + WalletID string `json:"wallet_id"` + NetworkInternalCode string `json:"network_internal_code"` + TxID string `json:"tx_id"` + Tx []byte `json:"tx"` + Signature []byte `json:"signature"` } type ResharingMessage struct { @@ -48,18 +83,41 @@ type ResharingMessage struct { NodeIDs []string `json:"node_ids"` // new peer IDs NewThreshold int `json:"new_threshold"` KeyType KeyType `json:"key_type"` + Protocol Protocol `json:"protocol,omitempty"` WalletID string `json:"wallet_id"` Signature []byte `json:"signature,omitempty"` } type PresignTxMessage struct { - KeyType KeyType `json:"key_type"` - WalletID string `json:"wallet_id"` - Signature []byte `json:"signature"` + KeyType KeyType `json:"key_type"` + Protocol Protocol `json:"protocol"` + WalletID string `json:"wallet_id"` + TxID string `json:"tx_id"` + Signature []byte `json:"signature"` +} + +func (m *GenerateKeyMessage) Raw() ([]byte, error) { + payload := struct { + WalletID string `json:"wallet_id"` + ECDSAProtocol Protocol `json:"ecdsa_protocol,omitempty"` + EdDSAProtocol Protocol `json:"eddsa_protocol,omitempty"` + }{ + WalletID: m.WalletID, + ECDSAProtocol: m.ECDSAProtocol, + EdDSAProtocol: m.EdDSAProtocol, + } + return json.Marshal(payload) +} + +func (m *GenerateKeyMessage) Sig() []byte { + return m.Signature +} + +func (m *GenerateKeyMessage) InitiatorID() string { + return m.WalletID } func (m *SignTxMessage) Raw() ([]byte, error) { - // omit the Signature field itself when computing the signed‐over data payload := struct { KeyType KeyType `json:"key_type"` WalletID string `json:"wallet_id"` @@ -84,21 +142,9 @@ func (m *SignTxMessage) InitiatorID() string { return m.TxID } -func (m *GenerateKeyMessage) Raw() ([]byte, error) { - return []byte(m.WalletID), nil -} - -func (m *GenerateKeyMessage) Sig() []byte { - return m.Signature -} - -func (m *GenerateKeyMessage) InitiatorID() string { - return m.WalletID -} - func (m *ResharingMessage) Raw() ([]byte, error) { - copy := *m // create a shallow copy - copy.Signature = nil // modify only the copy + copy := *m + copy.Signature = nil return json.Marshal(©) } @@ -110,22 +156,25 @@ func (m *ResharingMessage) InitiatorID() string { return m.WalletID } -func (m PresignTxMessage) Raw() ([]byte, error) { - // omit the Signature field itself when computing the signed‐over data +func (m *PresignTxMessage) Raw() ([]byte, error) { payload := struct { - KeyType KeyType `json:"key_type"` - WalletID string `json:"wallet_id"` + KeyType KeyType `json:"key_type"` + Protocol Protocol `json:"protocol"` + WalletID string `json:"wallet_id"` + TxID string `json:"tx_id"` }{ KeyType: m.KeyType, + Protocol: m.Protocol, WalletID: m.WalletID, + TxID: m.TxID, } return json.Marshal(payload) } -func (m PresignTxMessage) Sig() []byte { +func (m *PresignTxMessage) Sig() []byte { return m.Signature } -func (m PresignTxMessage) InitiatorID() string { +func (m *PresignTxMessage) InitiatorID() string { return m.WalletID } From 01af87cef04227110be900dd2255d88c9d3d5ef8 Mon Sep 17 00:00:00 2001 From: vietddude Date: Thu, 6 Nov 2025 17:11:20 +0700 Subject: [PATCH 17/21] feat (WIP): implement presign pool worker and enhance presign handling with new data structures --- cmd/mpcium/main.go | 44 +++- examples/generate/main.go | 6 +- examples/sign/main.go | 4 +- pkg/eventconsumer/event_consumer.go | 55 ++++- pkg/mpc/node.go | 36 ++-- pkg/mpc/taurus/cggmp21.go | 112 ++++++++-- pkg/mpc/taurus/presign.go | 95 --------- pkg/presign/presign.go | 309 ++++++++++++++++++++++++++++ pkg/presigninfo/presigninfo.go | 108 ++++++++++ pkg/types/initiator_msg.go | 37 +++- 10 files changed, 655 insertions(+), 151 deletions(-) delete mode 100644 pkg/mpc/taurus/presign.go create mode 100644 pkg/presign/presign.go create mode 100644 pkg/presigninfo/presigninfo.go diff --git a/cmd/mpcium/main.go b/cmd/mpcium/main.go index 5949bbf..ef8385a 100644 --- a/cmd/mpcium/main.go +++ b/cmd/mpcium/main.go @@ -11,6 +11,7 @@ import ( "syscall" "time" + "github.com/fystack/mpcium/pkg/client" "github.com/fystack/mpcium/pkg/config" "github.com/fystack/mpcium/pkg/constant" "github.com/fystack/mpcium/pkg/event" @@ -22,7 +23,10 @@ import ( "github.com/fystack/mpcium/pkg/logger" "github.com/fystack/mpcium/pkg/messaging" "github.com/fystack/mpcium/pkg/mpc" + "github.com/fystack/mpcium/pkg/presign" + "github.com/fystack/mpcium/pkg/presigninfo" "github.com/fystack/mpcium/pkg/security" + "github.com/fystack/mpcium/pkg/types" "github.com/hashicorp/consul/api" "github.com/nats-io/nats.go" "github.com/spf13/viper" @@ -77,6 +81,11 @@ func main() { Aliases: []string{"k"}, Usage: "Path to file containing password for decrypting .age encrypted node private key", }, + &cli.BoolFlag{ + Name: "presign-pool-worker", + Usage: "Enable presign pool worker", + Value: false, + }, &cli.BoolFlag{ Name: "debug", Usage: "Enable debug logging", @@ -109,6 +118,7 @@ func runNode(ctx context.Context, c *cli.Command) error { usePrompts := c.Bool("prompt-credentials") passwordFile := c.String("password-file") agePasswordFile := c.String("identity-password-file") + presignPoolWorker := c.Bool("presign-pool-worker") debug := c.Bool("debug") viper.SetDefault("backup_enabled", true) @@ -193,6 +203,7 @@ func runNode(ctx context.Context, c *cli.Command) error { peerNodeIDs := GetPeerIDs(peers) peerRegistry := mpc.NewRegistry(nodeID, peerNodeIDs, consulClient.KV(), directMessaging, pubsub, identityStore) + presignInfoStore := presigninfo.NewStore(consulClient.KV()) mpcNode := mpc.NewNode( nodeID, @@ -201,6 +212,7 @@ func runNode(ctx context.Context, c *cli.Command) error { directMessaging, badgerKV, keyinfoStore, + presignInfoStore, peerRegistry, identityStore, ) @@ -293,6 +305,37 @@ func runNode(ctx context.Context, c *cli.Command) error { logger.Info("All consumers have finished") close(errChan) }() + + // Start presign pool worker before entering the blocking error loop + if presignPoolWorker { + presignPoolCtx, presignPoolCancel := context.WithCancel(appContext) + defer presignPoolCancel() + localSigner, err := client.NewLocalSigner(types.EventInitiatorKeyTypeEd25519, client.LocalSignerOptions{ + KeyPath: "./event_initiator.key", + }) + if err != nil { + logger.Fatal("Failed to create local signer", err) + } + mpcClient := client.NewMPCClient(client.Options{ + NatsConn: natsConn, + Signer: localSigner, + }) + presignPool := presign.NewPresignPool(nil, mpcClient, presignInfoStore) + + _, err = pubsub.Subscribe(eventconsumer.MPCHotWalletEvent, func(nm *nats.Msg) { + walletID := string(nm.Data) + if walletID != "" { + presignPool.TouchHot(walletID) + } + }) + if err != nil { + logger.Warn("Failed to subscribe to hot wallet events", "err", err) + } + + presignPool.Start(presignPoolCtx) + defer presignPool.Stop() + } + for err := range errChan { if err != nil { logger.Error("Consumer error received", err) @@ -300,7 +343,6 @@ func runNode(ctx context.Context, c *cli.Command) error { return err } } - return nil } diff --git a/examples/generate/main.go b/examples/generate/main.go index a2cd085..5e263c7 100644 --- a/examples/generate/main.go +++ b/examples/generate/main.go @@ -134,7 +134,11 @@ func main() { for _, walletID := range walletIDs { wg.Add(1) // Add to WaitGroup BEFORE attempting to create wallet - if err := mpcClient.CreateWallet(walletID); err != nil { + if err := mpcClient.CreateWallet(&types.GenerateKeyMessage{ + WalletID: walletID, + ECDSAProtocol: types.ProtocolCGGMP21, + EdDSAProtocol: types.ProtocolGG18, + }); err != nil { logger.Error("CreateWallet failed", err) walletStartTimes.Delete(walletID) // Mark this wallet as processed to prevent callback from processing it diff --git a/examples/sign/main.go b/examples/sign/main.go index 51768b9..fc9f593 100644 --- a/examples/sign/main.go +++ b/examples/sign/main.go @@ -71,8 +71,8 @@ func main() { txMsg := &types.SignTxMessage{ KeyType: types.KeyTypeSecp256k1, - Protocol: types.ProtocolFROST, - WalletID: "6d553e80-a1dc-4894-9eaf-b81e3fe0c94a", // Use the generated wallet ID + Protocol: types.ProtocolCGGMP21, + WalletID: "7ae6ae1c-7663-4dc4-b982-33fb0a3602c3", // Use the generated wallet ID NetworkInternalCode: "solana-devnet", TxID: txID, Tx: dummyTx, diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index a3034aa..fadaa97 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -26,6 +26,9 @@ const ( MPCReshareEvent = "mpc:reshare" MPCPresignEvent = "mpc:presign" + // Internal event to notify presign pool of a hot wallet + MPCHotWalletEvent = "mpc:wallet_hot" + DefaultConcurrentKeygen = 2 DefaultConcurrentSigning = 20 DefaultSessionWarmUpDelay = 200 @@ -66,6 +69,12 @@ type eventConsumer struct { cleanupInterval time.Duration // How often to run cleanup sessionTimeout time.Duration // How long before a session is considered stale cleanupStopChan chan struct{} // Signal to stop cleanup goroutine + + // Track recent signing activity to detect hot wallets + hotMu sync.Mutex + recentSigns map[string][]time.Time // key: walletID|keyType|protocol → timestamps within window + hotWindow time.Duration // window for counting signs (e.g., 5 minutes) + hotThreshold int // signs needed to mark as hot } func NewEventConsumer( @@ -120,6 +129,9 @@ func NewEventConsumer( keygenMsgBuffer: make(chan *nats.Msg, 100), signingMsgBuffer: make(chan *nats.Msg, 200), // Larger buffer for signing sessionWarmUpDelayMs: sessionWarmUpDelayMs, + recentSigns: make(map[string][]time.Time), + hotWindow: 5 * time.Minute, + hotThreshold: 2, } go ec.startKeyGenEventWorker() @@ -393,6 +405,9 @@ func (ec *eventConsumer) handleSigningEvent(natMsg *nats.Msg) { ec.node.ID(), ) + // Track activity to detect hot wallets + ec.trackAndMaybeNotifyHot(msg) + // Check for duplicate session and track if new if ec.checkDuplicateSession(msg.WalletID, msg.TxID) { duplicateErr := fmt.Errorf( @@ -658,7 +673,7 @@ func (ec *eventConsumer) consumePresignEvent() error { } if success { - ec.handlePresignSessionSuccess(msg.WalletID, natMsg) + ec.handlePresignSessionSuccess(msg.WalletID, msg.TxID, natMsg) } else { ec.handlePresignSessionError(msg.WalletID, fmt.Errorf("presign operation returned false"), @@ -676,10 +691,11 @@ func (ec *eventConsumer) consumePresignEvent() error { } // handlePresignSessionSuccess handles successful presign operations -func (ec *eventConsumer) handlePresignSessionSuccess(walletID string, natMsg *nats.Msg) { +func (ec *eventConsumer) handlePresignSessionSuccess(walletID string, txID string, natMsg *nats.Msg) { presignResult := event.PresignResultEvent{ ResultType: event.ResultTypeSuccess, WalletID: walletID, + TxID: txID, Status: "success", } @@ -863,3 +879,38 @@ func composeSigningIdempotentKey(txID string, natMsg *nats.Msg) string { func composeReshareIdempotentKey(sessionID string, natMsg *nats.Msg) string { return composeIdempotentKey(sessionID, natMsg, mpc.TypeReshareWalletResultFmt) } + +// trackAndMaybeNotifyHot records a signing event and publishes a hot wallet event +// if at least hotThreshold signs occur within hotWindow for the same +// (walletID, keyType, protocol) tuple. +func (ec *eventConsumer) trackAndMaybeNotifyHot(msg types.SignTxMessage) { + if msg.Protocol != types.ProtocolCGGMP21 { + return + } + key := fmt.Sprintf("%s:%s:%s", msg.WalletID, string(msg.KeyType), string(msg.Protocol)) + now := time.Now() + + ec.hotMu.Lock() + // prune old entries + list := ec.recentSigns[key] + pruned := list[:0] + cutoff := now.Add(-ec.hotWindow) + for _, t := range list { + if t.After(cutoff) { + pruned = append(pruned, t) + } + } + + ec.recentSigns[key] = append([]time.Time(nil), pruned...) + currentCount := len(ec.recentSigns[key]) + + // If this push reaches the threshold, publish hot wallet once + shouldPublish := currentCount+1 == ec.hotThreshold + ec.recentSigns[key] = append(ec.recentSigns[key], now) + ec.hotMu.Unlock() + + if shouldPublish { + _ = ec.pubsub.Publish(MPCHotWalletEvent, []byte(msg.WalletID)) + logger.Info("Published hot wallet event", "walletID", msg.WalletID) + } +} diff --git a/pkg/mpc/node.go b/pkg/mpc/node.go index d94a9ee..4e92200 100644 --- a/pkg/mpc/node.go +++ b/pkg/mpc/node.go @@ -15,6 +15,7 @@ import ( "github.com/fystack/mpcium/pkg/logger" "github.com/fystack/mpcium/pkg/messaging" "github.com/fystack/mpcium/pkg/mpc/taurus" + "github.com/fystack/mpcium/pkg/presigninfo" "github.com/fystack/mpcium/pkg/types" "github.com/taurusgroup/multi-party-sig/pkg/party" ) @@ -34,13 +35,13 @@ type Node struct { nodeID string peerIDs []string - pubSub messaging.PubSub - direct messaging.DirectMessaging - kvstore kvstore.KVStore - keyinfoStore keyinfo.Store - ecdsaPreParams []*keygen.LocalPreParams - identityStore identity.Store - presignCache *taurus.PresignCache + pubSub messaging.PubSub + direct messaging.DirectMessaging + kvstore kvstore.KVStore + keyinfoStore keyinfo.Store + presignInfoStore presigninfo.Store + ecdsaPreParams []*keygen.LocalPreParams + identityStore identity.Store peerRegistry PeerRegistry } @@ -52,6 +53,7 @@ func NewNode( direct messaging.DirectMessaging, kvstore kvstore.KVStore, keyinfoStore keyinfo.Store, + presignInfoStore presigninfo.Store, peerRegistry PeerRegistry, identityStore identity.Store, ) *Node { @@ -60,15 +62,15 @@ func NewNode( logger.Info("Starting new node, preparams is generated successfully!", "elapsed", elapsed.Milliseconds()) node := &Node{ - nodeID: nodeID, - peerIDs: peerIDs, - pubSub: pubSub, - direct: direct, - kvstore: kvstore, - keyinfoStore: keyinfoStore, - peerRegistry: peerRegistry, - identityStore: identityStore, - presignCache: taurus.NewPresignCache(10 * time.Minute), + nodeID: nodeID, + peerIDs: peerIDs, + pubSub: pubSub, + direct: direct, + kvstore: kvstore, + keyinfoStore: keyinfoStore, + presignInfoStore: presignInfoStore, + peerRegistry: peerRegistry, + identityStore: identityStore, } node.ecdsaPreParams = node.generatePreParams() @@ -157,7 +159,7 @@ func (p *Node) CreateTaurusSession( switch protocol { case types.ProtocolCGGMP21: tr := taurus.NewNATSTransport(walletID, selfPartyID, act, taurus.CGGMP21, p.pubSub, p.direct, p.identityStore) - session = taurus.NewCGGMP21Session(walletID, selfPartyID, allPartyIDs, threshold, p.presignCache, tr, p.kvstore, p.keyinfoStore) + session = taurus.NewCGGMP21Session(walletID, selfPartyID, allPartyIDs, threshold, nil, tr, p.kvstore, p.keyinfoStore) case types.ProtocolTaproot: tr := taurus.NewNATSTransport(walletID, selfPartyID, act, taurus.FROSTTaproot, p.pubSub, p.direct, p.identityStore) session = taurus.NewTaprootSession(walletID, selfPartyID, allPartyIDs, threshold, tr, p.kvstore, p.keyinfoStore) diff --git a/pkg/mpc/taurus/cggmp21.go b/pkg/mpc/taurus/cggmp21.go index 6e3c94a..adad007 100644 --- a/pkg/mpc/taurus/cggmp21.go +++ b/pkg/mpc/taurus/cggmp21.go @@ -3,15 +3,20 @@ package taurus import ( "context" cryptoEcdsa "crypto/ecdsa" + "crypto/sha256" "errors" "fmt" "math/big" + "sort" + "time" "github.com/btcsuite/btcd/btcec/v2" + "github.com/fxamacker/cbor/v2" "github.com/fystack/mpcium/pkg/encoding" "github.com/fystack/mpcium/pkg/keyinfo" "github.com/fystack/mpcium/pkg/kvstore" "github.com/fystack/mpcium/pkg/logger" + "github.com/fystack/mpcium/pkg/presigninfo" "github.com/fystack/mpcium/pkg/types" "github.com/taurusgroup/multi-party-sig/pkg/ecdsa" "github.com/taurusgroup/multi-party-sig/pkg/math/curve" @@ -22,9 +27,9 @@ import ( type CGGMP21Session struct { *commonSession - workerPool *pool.Pool - savedData *cmp.Config - presignCache *PresignCache + workerPool *pool.Pool + savedData *cmp.Config + presignInfoStore presigninfo.Store } func NewCGGMP21Session( @@ -32,7 +37,7 @@ func NewCGGMP21Session( selfID party.ID, peerIDs party.IDSlice, threshold int, - presignCache *PresignCache, + presignInfoStore presigninfo.Store, transport Transport, kvstore kvstore.KVStore, keyinfoStore keyinfo.Store, @@ -47,10 +52,10 @@ func NewCGGMP21Session( keyinfoStore, ) return &CGGMP21Session{ - commonSession: commonSession, - workerPool: pool.NewPool(0), - savedData: nil, - presignCache: presignCache, + commonSession: commonSession, + workerPool: pool.NewPool(0), + savedData: nil, + presignInfoStore: presignInfoStore, } } @@ -149,10 +154,22 @@ func (p *CGGMP21Session) Sign(ctx context.Context, msg *big.Int) ([]byte, error) err error ) - if presign := p.getCachedPresign(); presign != nil { - result, err = p.run(ctx, cmp.PresignOnline(p.savedData, presign, msgHash, p.workerPool)) - if err != nil { - return nil, fmt.Errorf("presign online failed: %w", err) + // Try deterministic presign selection across nodes + if p.presignInfoStore != nil { + presig, txID := p.selectAndLoadPresign(msgHash) + if presig != nil && txID != "" { + logger.Info("using presign for signing", "walletID", p.sessionID, "txID", txID) + result, err = p.run(ctx, cmp.PresignOnline(p.savedData, presig, msgHash, p.workerPool)) + if err != nil { + return nil, fmt.Errorf("presign online failed: %w", err) + } + // Mark used (best-effort) + _ = p.markPresignUsed(txID) + } else { + result, err = p.run(ctx, cmp.Sign(p.savedData, p.peerIDs, msgHash, p.workerPool)) + if err != nil { + return nil, fmt.Errorf("full sign failed: %w", err) + } } } else { result, err = p.run(ctx, cmp.Sign(p.savedData, p.peerIDs, msgHash, p.workerPool)) @@ -187,8 +204,14 @@ func (p *CGGMP21Session) Presign(ctx context.Context, txID string) (bool, error) if err = presig.Validate(); err != nil { return false, errors.New("presign validation failed") } - - p.presignCache.Put(p.sessionID, txID, presig) + // Store presign in KV using deterministic key including txID + packed, err := cbor.Marshal(presig) + if err != nil { + return false, fmt.Errorf("marshal presign: %w", err) + } + if err := p.kvstore.Put(p.composePresignKey(p.sessionID, txID), packed); err != nil { + return false, fmt.Errorf("store presign: %w", err) + } return true, nil } @@ -247,15 +270,60 @@ func (p *CGGMP21Session) composeKey(sid string) string { return fmt.Sprintf("cggmp21:%s", sid) } -func (p *CGGMP21Session) getCachedPresign() *ecdsa.PreSignature { - if p.presignCache == nil { - return nil - } +func (p *CGGMP21Session) composePresignKey(sid, txID string) string { + return fmt.Sprintf("cggmp21:%s:%s", sid, txID) +} - presig, ok := p.presignCache.Get(p.sessionID) - if !ok || presig == nil { - return nil +// selectAndLoadPresign deterministically chooses a presign txID and loads its PreSignature from KV. +// Selection: sort by CreatedAt asc, TxID asc; pick index = hash(msgHash) mod len(list). +func (p *CGGMP21Session) selectAndLoadPresign(msgHash []byte) (*ecdsa.PreSignature, string) { + infos, err := p.presignInfoStore.ListPendingPresigns(p.sessionID) + if err != nil || len(infos) == 0 { + return nil, "" + } + // filter by active + protocol/keytype + filtered := make([]*presigninfo.PresignInfo, 0, len(infos)) + for _, inf := range infos { + if inf.Status == presigninfo.PresignStatusActive && inf.Protocol == types.ProtocolCGGMP21 && inf.KeyType == types.KeyTypeSecp256k1 { + filtered = append(filtered, inf) + } } + if len(filtered) == 0 { + return nil, "" + } + // sort + sort.Slice(filtered, func(i, j int) bool { + if filtered[i].CreatedAt.Equal(filtered[j].CreatedAt) { + return filtered[i].TxID < filtered[j].TxID + } + return filtered[i].CreatedAt.Before(filtered[j].CreatedAt) + }) + // pick index via hash + h := sha256.Sum256(msgHash) + idx := int(h[0]) % len(filtered) + chosen := filtered[idx] + // load material + bytes, err := p.kvstore.Get(p.composePresignKey(p.sessionID, chosen.TxID)) + if err != nil || len(bytes) == 0 { + return nil, "" + } + presig := new(ecdsa.PreSignature) + if err := cbor.Unmarshal(bytes, presig); err != nil { + return nil, "" + } + if err := presig.Validate(); err != nil { + return nil, "" + } + return presig, chosen.TxID +} - return presig +func (p *CGGMP21Session) markPresignUsed(txID string) error { + info, err := p.presignInfoStore.Get(p.sessionID, txID) + if err != nil { + return err + } + now := time.Now() + info.Status = presigninfo.PresignStatusUsed + info.UsedAt = &now + return p.presignInfoStore.Save(p.sessionID, info) } diff --git a/pkg/mpc/taurus/presign.go b/pkg/mpc/taurus/presign.go deleted file mode 100644 index a2f608e..0000000 --- a/pkg/mpc/taurus/presign.go +++ /dev/null @@ -1,95 +0,0 @@ -package taurus - -import ( - "sync" - "time" - - "github.com/taurusgroup/multi-party-sig/pkg/ecdsa" -) - -// PresignCache provides an in-memory cache of pre-signature data -// with automatic TTL-based cleanup. -type PresignCache struct { - mu sync.Mutex - data map[string][]PresignEntry // walletID -> entries - ttl time.Duration -} - -type PresignEntry struct { - SessionID string - Result *ecdsa.PreSignature - CreatedAt time.Time -} - -// NewPresignCache creates a new cache with optional TTL. -// If ttl <= 0, defaults to 10 minutes. -func NewPresignCache(ttl time.Duration) *PresignCache { - if ttl <= 0 { - ttl = 10 * time.Minute - } - - cache := &PresignCache{ - data: make(map[string][]PresignEntry), - ttl: ttl, - } - - go cache.startCleanup() - return cache -} - -// Put adds a new presign result for a wallet. -func (c *PresignCache) Put(walletID, sessionID string, res *ecdsa.PreSignature) { - c.mu.Lock() - defer c.mu.Unlock() - - c.data[walletID] = append(c.data[walletID], PresignEntry{ - SessionID: sessionID, - Result: res, - CreatedAt: time.Now(), - }) -} - -// Get retrieves and removes the oldest available presign for a wallet. -func (c *PresignCache) Get(walletID string) (*ecdsa.PreSignature, bool) { - c.mu.Lock() - defer c.mu.Unlock() - - entries := c.data[walletID] - if len(entries) == 0 { - return nil, false - } - - res := entries[0].Result - c.data[walletID] = entries[1:] // pop first entry - if len(c.data[walletID]) == 0 { - delete(c.data, walletID) - } - return res, true -} - -// startCleanup periodically removes expired presign entries based on TTL. -func (c *PresignCache) startCleanup() { - ticker := time.NewTicker(1 * time.Minute) - defer ticker.Stop() - - for range ticker.C { - now := time.Now() - expireBefore := now.Add(-c.ttl) - - c.mu.Lock() - for walletID, entries := range c.data { - filtered := entries[:0] - for _, e := range entries { - if e.CreatedAt.After(expireBefore) { - filtered = append(filtered, e) - } - } - if len(filtered) == 0 { - delete(c.data, walletID) - } else { - c.data[walletID] = filtered - } - } - c.mu.Unlock() - } -} diff --git a/pkg/presign/presign.go b/pkg/presign/presign.go new file mode 100644 index 0000000..c198be9 --- /dev/null +++ b/pkg/presign/presign.go @@ -0,0 +1,309 @@ +package presign + +import ( + "context" + "sync" + "time" + + "github.com/fystack/mpcium/pkg/client" + "github.com/fystack/mpcium/pkg/event" + "github.com/fystack/mpcium/pkg/logger" + "github.com/fystack/mpcium/pkg/presigninfo" + "github.com/fystack/mpcium/pkg/types" + "github.com/google/uuid" + "golang.org/x/sync/semaphore" +) + +type Config struct { + MinPoolSize int + MaxPoolSize int + GlobalMaxConcurrency int + PerWalletMaxConcurrency int + HotWindowDuration time.Duration + RefillInterval time.Duration +} + +var DefaultConfig = Config{ + MinPoolSize: 5, + MaxPoolSize: 20, + GlobalMaxConcurrency: 20, + PerWalletMaxConcurrency: 5, + HotWindowDuration: 5 * time.Minute, + RefillInterval: 10 * time.Second, +} + +type PresignPool struct { + cfg *Config + ctx context.Context + cancel context.CancelFunc + client client.MPCClient + infoStore presigninfo.Store + + wg sync.WaitGroup + mu sync.RWMutex + hot map[string]time.Time + pending map[string]int + cache map[string][]*presigninfo.PresignInfo + globalSem *semaphore.Weighted +} + +func NewPresignPool(cfg *Config, client client.MPCClient, infoStore presigninfo.Store) *PresignPool { + if cfg == nil { + tmp := DefaultConfig + cfg = &tmp + } + + ctx, cancel := context.WithCancel(context.Background()) + p := &PresignPool{ + cfg: cfg, + client: client, + infoStore: infoStore, + ctx: ctx, + cancel: cancel, + hot: make(map[string]time.Time), + pending: make(map[string]int), + cache: make(map[string][]*presigninfo.PresignInfo), + globalSem: semaphore.NewWeighted(int64(cfg.GlobalMaxConcurrency)), + } + + // Subscribe to presign completion events + if err := p.client.OnPresignResult(func(evt event.PresignResultEvent) { + p.OnPresignCompleted(evt.WalletID, evt.TxID, evt.ResultType == event.ResultTypeSuccess) + }); err != nil { + logger.Error("subscribe presign handler failed", err) + } + + return p +} + +func (p *PresignPool) Start(ctx context.Context) { + logger.Info("[PRESIGN] Presign pool worker started") + + p.wg.Add(2) + go p.refillLoop() + go p.cleanupLoop() + + go func() { + <-ctx.Done() + p.Stop() + }() +} + +func (p *PresignPool) Stop() { + p.cancel() + p.wg.Wait() + logger.Info("presign pool stopped") +} + +func (p *PresignPool) TouchHot(walletID string) { + p.mu.Lock() + defer p.mu.Unlock() + p.hot[walletID] = time.Now() + logger.Info("hot wallet detected", "wallet", walletID) +} + +func (p *PresignPool) refillLoop() { + defer p.wg.Done() + ticker := time.NewTicker(p.cfg.RefillInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + p.refill() + case <-p.ctx.Done(): + return + } + } +} + +func (p *PresignPool) cleanupLoop() { + defer p.wg.Done() + ticker := time.NewTicker(10 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + p.cleanup() + p.syncCache() // refresh cache periodically + case <-p.ctx.Done(): + return + } + } +} + +func (p *PresignPool) refill() { + now := time.Now() + + p.mu.RLock() + wallets := make([]string, 0, len(p.hot)) + for w, t := range p.hot { + if now.Sub(t) < p.cfg.HotWindowDuration { + wallets = append(wallets, w) + } + } + p.mu.RUnlock() + + // refill sequentially, with delay between wallets + for i, wallet := range wallets { + if i > 0 { + time.Sleep(2 * time.Second) // spacing between wallets + } + if err := p.refillWallet(wallet); err != nil { + logger.Warn("refill failed", "wallet", wallet, "err", err) + } + } +} + +func (p *PresignPool) refillWallet(walletID string) error { + list := p.getPresignListCached(walletID) + + var ready int + for _, m := range list { + if m.Status == presigninfo.PresignStatusActive { + ready++ + } + } + + if ready >= p.cfg.MinPoolSize { + return nil + } + + // Skip if there is a pending request + p.mu.Lock() + if p.pending[walletID] > 0 { + p.mu.Unlock() + return nil + } + p.pending[walletID] = 1 + p.mu.Unlock() + + logger.Info("refilling presign", "wallet", walletID, "current", ready, "target", p.cfg.MinPoolSize) + go p.requestPresign(walletID) + return nil +} + +func (p *PresignPool) requestPresign(walletID string) { + if err := p.globalSem.Acquire(p.ctx, 1); err != nil { + p.decrement(walletID) + return + } + defer p.globalSem.Release(1) + + // throttle lightly to allow other operations to continue + time.Sleep(3 * time.Second) + + txID := "presign_" + uuid.NewString() + req := &types.PresignTxMessage{ + KeyType: types.KeyTypeSecp256k1, + Protocol: types.ProtocolCGGMP21, + WalletID: walletID, + TxID: txID, + } + + if err := p.client.PresignTransaction(req); err != nil { + logger.Error("publish presign failed", err, "wallet", walletID) + p.decrement(walletID) + return + } + + logger.Debug("presign request sent", "wallet", walletID, "tx_id", txID) +} + +func (p *PresignPool) OnPresignCompleted(walletID, txID string, success bool) { + p.decrement(walletID) + if !success { + logger.Warn("presign failed", "wallet", walletID, "tx_id", txID) + return + } + + info := &presigninfo.PresignInfo{ + TxID: txID, + WalletID: walletID, + KeyType: types.KeyTypeSecp256k1, + Protocol: types.ProtocolCGGMP21, + Status: presigninfo.PresignStatusActive, + CreatedAt: time.Now(), + } + + // update cache + p.mu.Lock() + p.cache[walletID] = append(p.cache[walletID], info) + p.mu.Unlock() + + // async write to Consul KV + go func() { + if err := p.infoStore.Save(walletID, info); err != nil { + logger.Error("save presign info failed", err, "wallet", walletID) + } + }() +} + +func (p *PresignPool) getPresignListCached(walletID string) []*presigninfo.PresignInfo { + p.mu.RLock() + cached := p.cache[walletID] + p.mu.RUnlock() + + if len(cached) > 0 { + return cached + } + + // cache miss → load from Consul + list, err := p.infoStore.ListPendingPresigns(walletID) + if err != nil { + logger.Warn("load presign list failed", "wallet", walletID, "err", err) + return nil + } + + p.mu.Lock() + p.cache[walletID] = list + p.mu.Unlock() + return list +} + +// sync cache periodically +func (p *PresignPool) syncCache() { + wallets := p.GetHotWalletsSnapshot() + for _, wallet := range wallets { + list, err := p.infoStore.ListPendingPresigns(wallet) + if err != nil { + logger.Warn("sync cache failed", "wallet", wallet, "err", err) + continue + } + p.mu.Lock() + p.cache[wallet] = list + p.mu.Unlock() + } +} + +func (p *PresignPool) decrement(walletID string) { + p.mu.Lock() + defer p.mu.Unlock() + if n := p.pending[walletID]; n > 1 { + p.pending[walletID] = n - 1 + } else { + delete(p.pending, walletID) + } +} + +func (p *PresignPool) cleanup() { + now := time.Now() + p.mu.Lock() + for w, t := range p.hot { + if now.Sub(t) >= p.cfg.HotWindowDuration { + delete(p.hot, w) + } + } + p.mu.Unlock() +} + +func (p *PresignPool) GetHotWalletsSnapshot() []string { + p.mu.RLock() + defer p.mu.RUnlock() + keys := make([]string, 0, len(p.hot)) + for k := range p.hot { + keys = append(keys, k) + } + return keys +} diff --git a/pkg/presigninfo/presigninfo.go b/pkg/presigninfo/presigninfo.go new file mode 100644 index 0000000..0489ff2 --- /dev/null +++ b/pkg/presigninfo/presigninfo.go @@ -0,0 +1,108 @@ +package presigninfo + +import ( + "encoding/json" + "fmt" + "time" + + "github.com/fystack/mpcium/pkg/infra" + "github.com/fystack/mpcium/pkg/types" + "github.com/hashicorp/consul/api" +) + +const ( + PresignStatusActive = "active" + PresignStatusUsed = "used" +) + +type PresignInfo struct { + TxID string `json:"tx_id"` + WalletID string `json:"wallet_id"` + KeyType types.KeyType `json:"key_type"` + Protocol types.Protocol `json:"protocol"` + Status string `json:"status"` + CreatedAt time.Time `json:"created_at"` + UsedAt *time.Time `json:"used_at,omitempty"` +} + +type store struct { + consulKV infra.ConsulKV +} + +func NewStore(consulKV infra.ConsulKV) *store { + return &store{consulKV: consulKV} +} + +type Store interface { + Get(walletID string, txID string) (*PresignInfo, error) + Save(walletID string, info *PresignInfo) error + ListPendingPresigns(walletID string) ([]*PresignInfo, error) +} + +func (s *store) Get(walletID string, txID string) (*PresignInfo, error) { + pair, _, err := s.consulKV.Get(s.composeKey(walletID, txID), nil) + if err != nil { + return nil, fmt.Errorf("Failed to get presign info: %w", err) + } + if pair == nil { + return nil, fmt.Errorf("Presign info not found") + } + + info := &PresignInfo{} + err = json.Unmarshal(pair.Value, info) + if err != nil { + return nil, fmt.Errorf("Failed to unmarshal presign info: %w", err) + } + + return info, nil +} + +func (s *store) Save(walletID string, info *PresignInfo) error { + bytes, err := json.Marshal(info) + if err != nil { + return fmt.Errorf("failed to marshal presign info: %w", err) + } + + pair := &api.KVPair{ + Key: s.composeKey(walletID, info.TxID), + Value: bytes, + } + + _, err = s.consulKV.Put(pair, nil) + if err != nil { + return fmt.Errorf("Failed to save presign info: %w", err) + } + + return nil +} + +func (s *store) Delete(walletID string, txID string) error { + _, err := s.consulKV.Delete(s.composeKey(walletID, txID), nil) + if err != nil { + return fmt.Errorf("Failed to delete presign info: %w", err) + } + return nil +} + +// ListPendingPresigns returns all pending presigns for a given wallet ID +func (s *store) ListPendingPresigns(walletID string) ([]*PresignInfo, error) { + entries, _, err := s.consulKV.List(s.composeKey(walletID, ""), nil) + if err != nil { + return nil, fmt.Errorf("Failed to list presign info: %w", err) + } + infos := make([]*PresignInfo, 0, len(entries)) + for _, entry := range entries { + info := &PresignInfo{} + if err := json.Unmarshal(entry.Value, info); err != nil { + return nil, fmt.Errorf("Failed to unmarshal presign info: %w", err) + } + if info.TxID != "" { + infos = append(infos, info) + } + } + return infos, nil +} + +func (s *store) composeKey(walletID string, txID string) string { + return fmt.Sprintf("presign_info/%s/%s", walletID, txID) +} diff --git a/pkg/types/initiator_msg.go b/pkg/types/initiator_msg.go index 82b2d48..8dd1f82 100644 --- a/pkg/types/initiator_msg.go +++ b/pkg/types/initiator_msg.go @@ -33,25 +33,40 @@ func (p Protocol) String() string { return string(p) } +// mapping of key types → supported protocols +var supportedProtocols = map[KeyType][]Protocol{ + KeyTypeSecp256k1: { + ProtocolGG18, + ProtocolCGGMP21, + ProtocolFROST, + ProtocolTaproot, + }, + KeyTypeEd25519: { + ProtocolGG18, + }, +} + // ValidateKeyProtocol checks if a key type supports a given protocol. func ValidateKeyProtocol(keyType KeyType, protocol Protocol) error { if keyType == "" || protocol == "" { return errors.New("key_type and protocol are required") } - switch keyType { - case KeyTypeSecp256k1: - if protocol != ProtocolGG18 && protocol != ProtocolCGGMP21 { - return fmt.Errorf("protocol %q not supported for key_type %q; expected gg18 or cggmp21", protocol, keyType) - } - case KeyTypeEd25519: - if protocol != ProtocolFROST && protocol != ProtocolTaproot { - return fmt.Errorf("protocol %q not supported for key_type %q; expected frost or taproot", protocol, keyType) - } - default: + supported, ok := supportedProtocols[keyType] + if !ok { return fmt.Errorf("unsupported key_type %q", keyType) } - return nil + + for _, p := range supported { + if p == protocol { + return nil // valid combo + } + } + + return fmt.Errorf( + "protocol %q not supported for key_type %q; expected one of %v", + protocol, keyType, supported, + ) } // InitiatorMessage is anything that carries a payload to verify and its signature. From bfd2ebf85c1194d7e8ee20b1ec20715d210751d0 Mon Sep 17 00:00:00 2001 From: vietddude Date: Fri, 7 Nov 2025 11:24:11 +0700 Subject: [PATCH 18/21] feat: cggmp21 presign pool with auto cleanup --- INSTALLATION.md | 45 +++++ pkg/eventconsumer/event_consumer.go | 6 +- pkg/mpc/node.go | 2 +- pkg/mpc/taurus/cggmp21.go | 118 ++++++++----- pkg/presign/presign.go | 253 +++++++++++----------------- pkg/presigninfo/presigninfo.go | 31 ++++ 6 files changed, 260 insertions(+), 195 deletions(-) diff --git a/INSTALLATION.md b/INSTALLATION.md index c13e662..00e82ae 100644 --- a/INSTALLATION.md +++ b/INSTALLATION.md @@ -168,6 +168,8 @@ Update `config.yaml`: event_initiator_pubkey: "09be5d070816aadaa1b6638cad33e819a8aed7101626f6bf1e0b427412c3408a" ``` +> 💡 **Note**: If you plan to use the presign pool worker (see [Presign Pool Worker](#presign-pool-worker) section), you'll need the `event_initiator.key` file (or `event_initiator.key.age` if encrypted) to be available. The private key file is generated alongside the identity file. + --- ## Configure Node Identities @@ -274,6 +276,49 @@ mpcium start -n node2 --- +## Presign Pool Worker + +The presign pool worker automatically maintains a pool of presignatures for hot wallets, ensuring they're ready for immediate use. + +### Setup + +To enable the presign pool worker on a node: + +1. **Copy the event initiator private key** to the node directory: + + If you generated the initiator with encryption: + ```bash + # Decrypt the key first + age --decrypt -o event_initiator.key event_initiator.key.age + ``` + + Then copy it to the node directory: + ```bash + cp event_initiator.key node0/ + ``` + + If you generated the initiator without encryption: + ```bash + cp event_initiator.key node0/ + ``` + +2. **Start the node with the `--presign-pool-worker` flag**: + + ```bash + cd node0 + mpcium start -n node0 --presign-pool-worker + ``` + +> ⚠️ **Important**: Only one node in the cluster should run the presign pool worker. The node running this worker must have the `event_initiator.key` file in its working directory. + +### How It Works + +- The worker monitors hot wallet activity and automatically generates presignatures when needed +- It maintains a pool of presignatures between `MinPoolSize` (default: 5) and `MaxPoolSize` (default: 20) +- The worker subscribes to hot wallet events and proactively refills the presignature pool + +--- + ## Production Deployment (High Security) 1. Use production-grade **NATS** and **Consul** clusters. diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index fadaa97..8e54a15 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -663,7 +663,7 @@ func (ec *eventConsumer) consumePresignEvent() error { } ctx := context.Background() - success, err := session.Presign(ctx, msg.WalletID) + success, err := session.Presign(ctx, msg.TxID) if err != nil { ec.handlePresignSessionError(msg.WalletID, err, "Presign operation failed", @@ -708,7 +708,7 @@ func (ec *eventConsumer) handlePresignSessionSuccess(walletID string, txID strin } err = ec.presignResultQueue.Enqueue(event.PresignResultTopic, presignResultBytes, &messaging.EnqueueOptions{ - IdempotententKey: composePresignIdempotentKey(walletID, natMsg), + IdempotententKey: composePresignIdempotentKey(txID, natMsg), }) if err != nil { logger.Error("Failed to enqueue presign result event", err, @@ -717,7 +717,7 @@ func (ec *eventConsumer) handlePresignSessionSuccess(walletID string, txID strin ) } // Presign events don't use reply inboxes, so no need to send reply - logger.Info("[COMPLETED PRESIGN] Presign completed successfully", "walletID", walletID) + logger.Info("[COMPLETED PRESIGN] Presign completed successfully", "walletID", walletID, "txID", txID) } // handlePresignSessionError handles errors that occur during presign operations diff --git a/pkg/mpc/node.go b/pkg/mpc/node.go index 4e92200..4cb2b49 100644 --- a/pkg/mpc/node.go +++ b/pkg/mpc/node.go @@ -159,7 +159,7 @@ func (p *Node) CreateTaurusSession( switch protocol { case types.ProtocolCGGMP21: tr := taurus.NewNATSTransport(walletID, selfPartyID, act, taurus.CGGMP21, p.pubSub, p.direct, p.identityStore) - session = taurus.NewCGGMP21Session(walletID, selfPartyID, allPartyIDs, threshold, nil, tr, p.kvstore, p.keyinfoStore) + session = taurus.NewCGGMP21Session(walletID, selfPartyID, allPartyIDs, threshold, p.presignInfoStore, tr, p.kvstore, p.keyinfoStore) case types.ProtocolTaproot: tr := taurus.NewNATSTransport(walletID, selfPartyID, act, taurus.FROSTTaproot, p.pubSub, p.direct, p.identityStore) session = taurus.NewTaprootSession(walletID, selfPartyID, allPartyIDs, threshold, tr, p.kvstore, p.keyinfoStore) diff --git a/pkg/mpc/taurus/cggmp21.go b/pkg/mpc/taurus/cggmp21.go index adad007..019cdad 100644 --- a/pkg/mpc/taurus/cggmp21.go +++ b/pkg/mpc/taurus/cggmp21.go @@ -4,6 +4,7 @@ import ( "context" cryptoEcdsa "crypto/ecdsa" "crypto/sha256" + "encoding/binary" "errors" "fmt" "math/big" @@ -145,40 +146,26 @@ func (p *CGGMP21Session) Sign(ctx context.Context, msg *big.Int) ([]byte, error) if p.savedData == nil { return nil, errors.New("no key loaded") } + logger.Info("Starting CGGMP21 sign", "walletID", p.sessionID) - logger.Info("starting CGGMP21 sign", "walletID", p.sessionID) msgHash := msg.Bytes() - var ( - result any - err error + sigResult any + err error ) - // Try deterministic presign selection across nodes + // Try presign path if store available if p.presignInfoStore != nil { - presig, txID := p.selectAndLoadPresign(msgHash) - if presig != nil && txID != "" { - logger.Info("using presign for signing", "walletID", p.sessionID, "txID", txID) - result, err = p.run(ctx, cmp.PresignOnline(p.savedData, presig, msgHash, p.workerPool)) - if err != nil { - return nil, fmt.Errorf("presign online failed: %w", err) - } - // Mark used (best-effort) - _ = p.markPresignUsed(txID) - } else { - result, err = p.run(ctx, cmp.Sign(p.savedData, p.peerIDs, msgHash, p.workerPool)) - if err != nil { - return nil, fmt.Errorf("full sign failed: %w", err) - } - } + sigResult, err = p.signWithPresign(ctx, msgHash) } else { - result, err = p.run(ctx, cmp.Sign(p.savedData, p.peerIDs, msgHash, p.workerPool)) - if err != nil { - return nil, fmt.Errorf("full sign failed: %w", err) - } + sigResult, err = p.signFull(ctx, msgHash) + } + if err != nil { + return nil, err } - sig, ok := result.(*ecdsa.Signature) + // Cast and verify + sig, ok := sigResult.(*ecdsa.Signature) if !ok { return nil, errors.New("unexpected result type") } @@ -274,46 +261,93 @@ func (p *CGGMP21Session) composePresignKey(sid, txID string) string { return fmt.Sprintf("cggmp21:%s:%s", sid, txID) } -// selectAndLoadPresign deterministically chooses a presign txID and loads its PreSignature from KV. -// Selection: sort by CreatedAt asc, TxID asc; pick index = hash(msgHash) mod len(list). -func (p *CGGMP21Session) selectAndLoadPresign(msgHash []byte) (*ecdsa.PreSignature, string) { +// signWithPresign tries to select and use an existing presign for signing. +// If no valid presign exists, falls back to full sign. +func (p *CGGMP21Session) signWithPresign(ctx context.Context, msgHash []byte) (any, error) { + presig, txID := p.selectAndLoadPresign() + if presig == nil || txID == "" { + logger.Debug("No presign found, fallback to full sign", "walletID", p.sessionID) + return p.signFull(ctx, msgHash) + } + + logger.Info("Using presign for signing", "walletID", p.sessionID, "txID", txID) + result, err := p.run(ctx, cmp.PresignOnline(p.savedData, presig, msgHash, p.workerPool)) + if err != nil { + return nil, fmt.Errorf("presign online failed: %w", err) + } + + // Mark and cleanup in background (best effort) + go func() { + if err := p.markPresignUsed(txID); err != nil { + logger.Warn("mark presign used failed", "walletID", p.sessionID, "txID", txID, "err", err) + } + if err := p.deletePresign(txID); err != nil { + logger.Warn("delete presign failed", "walletID", p.sessionID, "txID", txID, "err", err) + } + }() + + return result, nil +} + +// signFull executes a full CGGMP21 signing round. +func (p *CGGMP21Session) signFull(ctx context.Context, msgHash []byte) (any, error) { + logger.Info("Executing full CGGMP21 signing", "walletID", p.sessionID) + result, err := p.run(ctx, cmp.Sign(p.savedData, p.peerIDs, msgHash, p.workerPool)) + if err != nil { + return nil, fmt.Errorf("full sign failed: %w", err) + } + return result, nil +} + +func (p *CGGMP21Session) selectAndLoadPresign() (*ecdsa.PreSignature, string) { infos, err := p.presignInfoStore.ListPendingPresigns(p.sessionID) if err != nil || len(infos) == 0 { return nil, "" } - // filter by active + protocol/keytype - filtered := make([]*presigninfo.PresignInfo, 0, len(infos)) + + // Filter usable presigns + var filtered []*presigninfo.PresignInfo for _, inf := range infos { - if inf.Status == presigninfo.PresignStatusActive && inf.Protocol == types.ProtocolCGGMP21 && inf.KeyType == types.KeyTypeSecp256k1 { + if inf.Status == presigninfo.PresignStatusActive && + inf.Protocol == types.ProtocolCGGMP21 && + inf.KeyType == types.KeyTypeSecp256k1 { filtered = append(filtered, inf) } } if len(filtered) == 0 { return nil, "" } - // sort + + // Sort + pick deterministically sort.Slice(filtered, func(i, j int) bool { if filtered[i].CreatedAt.Equal(filtered[j].CreatedAt) { return filtered[i].TxID < filtered[j].TxID } return filtered[i].CreatedAt.Before(filtered[j].CreatedAt) }) - // pick index via hash - h := sha256.Sum256(msgHash) - idx := int(h[0]) % len(filtered) + h := sha256.Sum256([]byte(p.sessionID)) + idx := int(binary.BigEndian.Uint64(h[:8]) % uint64(len(filtered))) chosen := filtered[idx] - // load material - bytes, err := p.kvstore.Get(p.composePresignKey(p.sessionID, chosen.TxID)) - if err != nil || len(bytes) == 0 { + + // Load presign from KV + key := p.composePresignKey(p.sessionID, chosen.TxID) + data, err := p.kvstore.Get(key) + if err != nil || len(data) == 0 { + logger.Warn("presign missing", "walletID", p.sessionID, "txID", chosen.TxID, "err", err) return nil, "" } - presig := new(ecdsa.PreSignature) - if err := cbor.Unmarshal(bytes, presig); err != nil { + + presig := ecdsa.EmptyPreSignature(curve.Secp256k1{}) + if err := cbor.Unmarshal(data, presig); err != nil { + logger.Warn("unmarshal presign failed", "walletID", p.sessionID, "txID", chosen.TxID, "err", err) return nil, "" } if err := presig.Validate(); err != nil { + logger.Warn("presign invalid", "walletID", p.sessionID, "txID", chosen.TxID, "err", err) return nil, "" } + + logger.Debug("Presign chosen", "walletID", p.sessionID, "txID", chosen.TxID, "idx", idx) return presig, chosen.TxID } @@ -327,3 +361,7 @@ func (p *CGGMP21Session) markPresignUsed(txID string) error { info.UsedAt = &now return p.presignInfoStore.Save(p.sessionID, info) } + +func (p *CGGMP21Session) deletePresign(txID string) error { + return p.kvstore.Delete(p.composePresignKey(p.sessionID, txID)) +} diff --git a/pkg/presign/presign.go b/pkg/presign/presign.go index c198be9..59063db 100644 --- a/pkg/presign/presign.go +++ b/pkg/presign/presign.go @@ -15,21 +15,26 @@ import ( ) type Config struct { - MinPoolSize int - MaxPoolSize int - GlobalMaxConcurrency int - PerWalletMaxConcurrency int - HotWindowDuration time.Duration - RefillInterval time.Duration + MinPoolSize int + MaxPoolSize int + GlobalMaxConcurrency int + HotWindowDuration time.Duration + RefillInterval time.Duration + ThrottleDelay time.Duration } var DefaultConfig = Config{ - MinPoolSize: 5, - MaxPoolSize: 20, - GlobalMaxConcurrency: 20, - PerWalletMaxConcurrency: 5, - HotWindowDuration: 5 * time.Minute, - RefillInterval: 10 * time.Second, + MinPoolSize: 5, + MaxPoolSize: 20, + GlobalMaxConcurrency: 10, + HotWindowDuration: 5 * time.Minute, + RefillInterval: 15 * time.Second, + ThrottleDelay: 5 * time.Second, +} + +type walletState struct { + lastTouch time.Time + pendingCount int } type PresignPool struct { @@ -41,9 +46,7 @@ type PresignPool struct { wg sync.WaitGroup mu sync.RWMutex - hot map[string]time.Time - pending map[string]int - cache map[string][]*presigninfo.PresignInfo + wallets map[string]*walletState globalSem *semaphore.Weighted } @@ -52,7 +55,6 @@ func NewPresignPool(cfg *Config, client client.MPCClient, infoStore presigninfo. tmp := DefaultConfig cfg = &tmp } - ctx, cancel := context.WithCancel(context.Background()) p := &PresignPool{ cfg: cfg, @@ -60,29 +62,23 @@ func NewPresignPool(cfg *Config, client client.MPCClient, infoStore presigninfo. infoStore: infoStore, ctx: ctx, cancel: cancel, - hot: make(map[string]time.Time), - pending: make(map[string]int), - cache: make(map[string][]*presigninfo.PresignInfo), + wallets: make(map[string]*walletState), globalSem: semaphore.NewWeighted(int64(cfg.GlobalMaxConcurrency)), } - // Subscribe to presign completion events + // Subscribe to presign completion if err := p.client.OnPresignResult(func(evt event.PresignResultEvent) { - p.OnPresignCompleted(evt.WalletID, evt.TxID, evt.ResultType == event.ResultTypeSuccess) + p.handlePresignResult(evt.WalletID, evt.TxID, evt.ResultType == event.ResultTypeSuccess) }); err != nil { - logger.Error("subscribe presign handler failed", err) + logger.Error("[PRESIGN] subscribe handler failed", err) } - return p } func (p *PresignPool) Start(ctx context.Context) { - logger.Info("[PRESIGN] Presign pool worker started") - - p.wg.Add(2) - go p.refillLoop() - go p.cleanupLoop() - + logger.Info("[PRESIGN] Pool started") + p.wg.Add(1) + go p.mainLoop() go func() { <-ctx.Done() p.Stop() @@ -92,17 +88,20 @@ func (p *PresignPool) Start(ctx context.Context) { func (p *PresignPool) Stop() { p.cancel() p.wg.Wait() - logger.Info("presign pool stopped") + logger.Info("[PRESIGN] Pool stopped") } func (p *PresignPool) TouchHot(walletID string) { p.mu.Lock() defer p.mu.Unlock() - p.hot[walletID] = time.Now() - logger.Info("hot wallet detected", "wallet", walletID) + if state, ok := p.wallets[walletID]; ok { + state.lastTouch = time.Now() + } else { + p.wallets[walletID] = &walletState{lastTouch: time.Now()} + } } -func (p *PresignPool) refillLoop() { +func (p *PresignPool) mainLoop() { defer p.wg.Done() ticker := time.NewTicker(p.cfg.RefillInterval) defer ticker.Stop() @@ -110,89 +109,67 @@ func (p *PresignPool) refillLoop() { for { select { case <-ticker.C: - p.refill() + p.refillAll() case <-p.ctx.Done(): return } } } -func (p *PresignPool) cleanupLoop() { - defer p.wg.Done() - ticker := time.NewTicker(10 * time.Minute) - defer ticker.Stop() - - for { +func (p *PresignPool) refillAll() { + for _, walletID := range p.getHotWallets() { select { - case <-ticker.C: - p.cleanup() - p.syncCache() // refresh cache periodically case <-p.ctx.Done(): return + default: + p.refillWallet(walletID) + time.Sleep(2 * time.Second) } } } -func (p *PresignPool) refill() { - now := time.Now() - - p.mu.RLock() - wallets := make([]string, 0, len(p.hot)) - for w, t := range p.hot { - if now.Sub(t) < p.cfg.HotWindowDuration { - wallets = append(wallets, w) - } +func (p *PresignPool) refillWallet(walletID string) { + list, err := p.infoStore.ListPendingPresigns(walletID) + if err != nil { + logger.Warn("[PRESIGN] list presigns failed", "wallet", walletID, "err", err) + return } - p.mu.RUnlock() - // refill sequentially, with delay between wallets - for i, wallet := range wallets { - if i > 0 { - time.Sleep(2 * time.Second) // spacing between wallets - } - if err := p.refillWallet(wallet); err != nil { - logger.Warn("refill failed", "wallet", wallet, "err", err) + activeCount := 0 + for _, info := range list { + if info.Status == presigninfo.PresignStatusActive { + activeCount++ } } -} - -func (p *PresignPool) refillWallet(walletID string) error { - list := p.getPresignListCached(walletID) - var ready int - for _, m := range list { - if m.Status == presigninfo.PresignStatusActive { - ready++ - } + p.mu.RLock() + pendingCount := 0 + if st := p.wallets[walletID]; st != nil { + pendingCount = st.pendingCount } + p.mu.RUnlock() - if ready >= p.cfg.MinPoolSize { - return nil + total := activeCount + pendingCount + if total >= p.cfg.MinPoolSize { + return } - - // Skip if there is a pending request - p.mu.Lock() - if p.pending[walletID] > 0 { - p.mu.Unlock() - return nil + if pendingCount > 0 { + return } - p.pending[walletID] = 1 - p.mu.Unlock() - logger.Info("refilling presign", "wallet", walletID, "current", ready, "target", p.cfg.MinPoolSize) + p.incrementPending(walletID) go p.requestPresign(walletID) - return nil } func (p *PresignPool) requestPresign(walletID string) { if err := p.globalSem.Acquire(p.ctx, 1); err != nil { - p.decrement(walletID) + logger.Warn("[PRESIGN] semaphore acquire failed", "wallet", walletID, "err", err) + p.decrementPending(walletID) return } defer p.globalSem.Release(1) - // throttle lightly to allow other operations to continue - time.Sleep(3 * time.Second) + time.Sleep(p.cfg.ThrottleDelay) txID := "presign_" + uuid.NewString() req := &types.PresignTxMessage{ @@ -201,20 +178,20 @@ func (p *PresignPool) requestPresign(walletID string) { WalletID: walletID, TxID: txID, } - if err := p.client.PresignTransaction(req); err != nil { - logger.Error("publish presign failed", err, "wallet", walletID) - p.decrement(walletID) + logger.Warn("[PRESIGN] presign publish failed", "wallet", walletID, "tx", txID, "err", err) + p.decrementPending(walletID) return } - - logger.Debug("presign request sent", "wallet", walletID, "tx_id", txID) + logger.Debug("[PRESIGN] presign sent", "wallet", walletID, "tx", txID) } -func (p *PresignPool) OnPresignCompleted(walletID, txID string, success bool) { - p.decrement(walletID) +func (p *PresignPool) handlePresignResult(walletID, txID string, success bool) { + p.decrementPending(walletID) + if !success { - logger.Warn("presign failed", "wallet", walletID, "tx_id", txID) + logger.Warn("[PRESIGN] presign failed", "wallet", walletID, "tx", txID) + _ = p.infoStore.Delete(walletID, txID) // cleanup failed presign return } @@ -226,84 +203,58 @@ func (p *PresignPool) OnPresignCompleted(walletID, txID string, success bool) { Status: presigninfo.PresignStatusActive, CreatedAt: time.Now(), } + if err := p.infoStore.Save(walletID, info); err != nil { + logger.Warn("[PRESIGN] save failed", "wallet", walletID, "tx", txID, "err", err) + return + } - // update cache - p.mu.Lock() - p.cache[walletID] = append(p.cache[walletID], info) - p.mu.Unlock() - - // async write to Consul KV - go func() { - if err := p.infoStore.Save(walletID, info); err != nil { - logger.Error("save presign info failed", err, "wallet", walletID) - } - }() + // Clean up expired/used presigns + p.cleanupUsed(walletID) + logger.Debug("[PRESIGN] presign done", "wallet", walletID, "tx", txID) } -func (p *PresignPool) getPresignListCached(walletID string) []*presigninfo.PresignInfo { - p.mu.RLock() - cached := p.cache[walletID] - p.mu.RUnlock() - - if len(cached) > 0 { - return cached - } - - // cache miss → load from Consul +func (p *PresignPool) cleanupUsed(walletID string) { list, err := p.infoStore.ListPendingPresigns(walletID) if err != nil { - logger.Warn("load presign list failed", "wallet", walletID, "err", err) - return nil + return + } + for _, inf := range list { + if inf.Status == presigninfo.PresignStatusUsed { + _ = p.infoStore.Delete(walletID, inf.TxID) + logger.Debug("[PRESIGN] cleaned used presign", "wallet", walletID, "tx", inf.TxID) + } } - - p.mu.Lock() - p.cache[walletID] = list - p.mu.Unlock() - return list } -// sync cache periodically -func (p *PresignPool) syncCache() { - wallets := p.GetHotWalletsSnapshot() - for _, wallet := range wallets { - list, err := p.infoStore.ListPendingPresigns(wallet) - if err != nil { - logger.Warn("sync cache failed", "wallet", wallet, "err", err) - continue +func (p *PresignPool) getHotWallets() []string { + now := time.Now() + p.mu.RLock() + defer p.mu.RUnlock() + hot := make([]string, 0, len(p.wallets)) + for id, st := range p.wallets { + if now.Sub(st.lastTouch) < p.cfg.HotWindowDuration { + hot = append(hot, id) } - p.mu.Lock() - p.cache[wallet] = list - p.mu.Unlock() } + return hot } -func (p *PresignPool) decrement(walletID string) { +func (p *PresignPool) incrementPending(walletID string) { p.mu.Lock() defer p.mu.Unlock() - if n := p.pending[walletID]; n > 1 { - p.pending[walletID] = n - 1 + if st, ok := p.wallets[walletID]; ok { + st.pendingCount++ } else { - delete(p.pending, walletID) + p.wallets[walletID] = &walletState{lastTouch: time.Now(), pendingCount: 1} } } -func (p *PresignPool) cleanup() { - now := time.Now() +func (p *PresignPool) decrementPending(walletID string) { p.mu.Lock() - for w, t := range p.hot { - if now.Sub(t) >= p.cfg.HotWindowDuration { - delete(p.hot, w) - } + defer p.mu.Unlock() + if st, ok := p.wallets[walletID]; ok && st.pendingCount > 0 { + st.pendingCount-- } - p.mu.Unlock() } -func (p *PresignPool) GetHotWalletsSnapshot() []string { - p.mu.RLock() - defer p.mu.RUnlock() - keys := make([]string, 0, len(p.hot)) - for k := range p.hot { - keys = append(keys, k) - } - return keys -} +func (p *PresignPool) GetHotWalletsSnapshot() []string { return p.getHotWallets() } diff --git a/pkg/presigninfo/presigninfo.go b/pkg/presigninfo/presigninfo.go index 0489ff2..059b204 100644 --- a/pkg/presigninfo/presigninfo.go +++ b/pkg/presigninfo/presigninfo.go @@ -3,6 +3,7 @@ package presigninfo import ( "encoding/json" "fmt" + "strings" "time" "github.com/fystack/mpcium/pkg/infra" @@ -36,7 +37,9 @@ func NewStore(consulKV infra.ConsulKV) *store { type Store interface { Get(walletID string, txID string) (*PresignInfo, error) Save(walletID string, info *PresignInfo) error + Delete(walletID string, txID string) error ListPendingPresigns(walletID string) ([]*PresignInfo, error) + ListAllWallets() ([]string, error) } func (s *store) Get(walletID string, txID string) (*PresignInfo, error) { @@ -103,6 +106,34 @@ func (s *store) ListPendingPresigns(walletID string) ([]*PresignInfo, error) { return infos, nil } +// ListAllWallets returns all wallet IDs that have presigns in Consul KV +func (s *store) ListAllWallets() ([]string, error) { + entries, _, err := s.consulKV.List("presign_info/", nil) + if err != nil { + return nil, fmt.Errorf("Failed to list all presign info: %w", err) + } + + walletMap := make(map[string]bool) + for _, entry := range entries { + // Key format: presign_info/{walletID}/{txID} + // Extract walletID from the key + key := entry.Key + parts := strings.Split(key, "/") + if len(parts) >= 2 && parts[0] == "presign_info" { + walletID := parts[1] + if walletID != "" { + walletMap[walletID] = true + } + } + } + + wallets := make([]string, 0, len(walletMap)) + for walletID := range walletMap { + wallets = append(wallets, walletID) + } + return wallets, nil +} + func (s *store) composeKey(walletID string, txID string) string { return fmt.Sprintf("presign_info/%s/%s", walletID, txID) } From ed52f34faa3014effecfe3d62ccab27c29416b6a Mon Sep 17 00:00:00 2001 From: vietddude Date: Fri, 7 Nov 2025 11:28:27 +0700 Subject: [PATCH 19/21] =?UTF-8?q?fix:=20resolve=20gosec=20uint64=E2=86=92i?= =?UTF-8?q?nt=20conversion=20warning=20in=20deterministic=20presign=20sele?= =?UTF-8?q?ction?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/mpc/taurus/cggmp21.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pkg/mpc/taurus/cggmp21.go b/pkg/mpc/taurus/cggmp21.go index 019cdad..46842a5 100644 --- a/pkg/mpc/taurus/cggmp21.go +++ b/pkg/mpc/taurus/cggmp21.go @@ -326,8 +326,10 @@ func (p *CGGMP21Session) selectAndLoadPresign() (*ecdsa.PreSignature, string) { return filtered[i].CreatedAt.Before(filtered[j].CreatedAt) }) h := sha256.Sum256([]byte(p.sessionID)) - idx := int(binary.BigEndian.Uint64(h[:8]) % uint64(len(filtered))) - chosen := filtered[idx] + v := binary.BigEndian.Uint64(h[:8]) + n := uint64(len(filtered)) + idx := v % n + chosen := filtered[int(idx)] // Load presign from KV key := p.composePresignKey(p.sessionID, chosen.TxID) From bbfddde1cc91eef048a656a65cbed1759a3e0d33 Mon Sep 17 00:00:00 2001 From: vietddude Date: Fri, 7 Nov 2025 11:34:46 +0700 Subject: [PATCH 20/21] fix: update presign selection to use int32 for hash index calculation --- pkg/mpc/taurus/cggmp21.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pkg/mpc/taurus/cggmp21.go b/pkg/mpc/taurus/cggmp21.go index 46842a5..40818f4 100644 --- a/pkg/mpc/taurus/cggmp21.go +++ b/pkg/mpc/taurus/cggmp21.go @@ -326,10 +326,9 @@ func (p *CGGMP21Session) selectAndLoadPresign() (*ecdsa.PreSignature, string) { return filtered[i].CreatedAt.Before(filtered[j].CreatedAt) }) h := sha256.Sum256([]byte(p.sessionID)) - v := binary.BigEndian.Uint64(h[:8]) - n := uint64(len(filtered)) - idx := v % n - chosen := filtered[int(idx)] + hashVal := int64(binary.BigEndian.Uint32(h[:4])) + idx := int(hashVal % int64(len(filtered))) + chosen := filtered[idx] // Load presign from KV key := p.composePresignKey(p.sessionID, chosen.TxID) From 7851ea0bafa881b86311df68d075e51af43cb4b8 Mon Sep 17 00:00:00 2001 From: vietddude Date: Wed, 19 Nov 2025 10:23:32 +0700 Subject: [PATCH 21/21] refactor: update identity file loading to read and parse peers.json dynamically --- pkg/identity/identity.go | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/pkg/identity/identity.go b/pkg/identity/identity.go index 473e6b5..5509932 100644 --- a/pkg/identity/identity.go +++ b/pkg/identity/identity.go @@ -100,19 +100,14 @@ func NewFileStore(identityDir, nodeName string, decrypt bool, agePasswordFile st } // Load peers.json to validate all nodes have identity files - // peersData, err := os.ReadFile("peers.json") - // if err != nil { - // return nil, fmt.Errorf("failed to read peers.json: %w", err) - // } - - peers := map[string]string{ - "node0": "aa4adaea-257d-4337-842a-1d3f966d85c2", - "node1": "21ac5259-ac9e-4b81-bd42-d05f584879e4", - "node2": "2fff5119-a1f1-4763-8f4c-d7d88c212608", - } - // if err := json.Unmarshal(peersData, &peers); err != nil { - // return nil, fmt.Errorf("failed to parse peers.json: %w", err) - // } + peersData, err := os.ReadFile("peers.json") + if err != nil { + return nil, fmt.Errorf("failed to read peers.json: %w", err) + } + peers := make(map[string]string) + if err := json.Unmarshal(peersData, &peers); err != nil { + return nil, fmt.Errorf("failed to parse peers.json: %w", err) + } store := &fileStore{ identityDir: identityDir,