diff --git a/go.mod b/go.mod index c039cef9..4c8b07e0 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( filippo.io/edwards25519 v1.1.0 github.com/aws/aws-sdk-go-v2 v0.17.0 github.com/bits-and-blooms/bloom/v3 v3.1.0 - github.com/code-payments/code-protobuf-api v1.19.1-0.20250605155512-63da5d11d58a + github.com/code-payments/code-protobuf-api v1.19.1-0.20250610140050-4cadbcc86f16 github.com/code-payments/code-vm-indexer v0.1.11-0.20241028132209-23031e814fba github.com/emirpasic/gods v1.12.0 github.com/envoyproxy/protoc-gen-validate v1.2.1 diff --git a/go.sum b/go.sum index 1bc0ece8..9f8ddbb8 100644 --- a/go.sum +++ b/go.sum @@ -80,8 +80,8 @@ github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnht github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= -github.com/code-payments/code-protobuf-api v1.19.1-0.20250605155512-63da5d11d58a h1:h5AFZjmn+Zzkkd0u2Y+h9msj7HYBOSI3l4i5CD0ls34= -github.com/code-payments/code-protobuf-api v1.19.1-0.20250605155512-63da5d11d58a/go.mod h1:ee6TzKbgMS42ZJgaFEMG3c4R3dGOiffHSu6MrY7WQvs= +github.com/code-payments/code-protobuf-api v1.19.1-0.20250610140050-4cadbcc86f16 h1:drAMKRdbyObW8E4H6xc1pKIDxoFYgpaTdMlEnIKBIJ0= +github.com/code-payments/code-protobuf-api v1.19.1-0.20250610140050-4cadbcc86f16/go.mod h1:ee6TzKbgMS42ZJgaFEMG3c4R3dGOiffHSu6MrY7WQvs= github.com/code-payments/code-vm-indexer v0.1.11-0.20241028132209-23031e814fba h1:Bkp+gmeb6Y2PWXfkSCTMBGWkb2P1BujRDSjWeI+0j5I= github.com/code-payments/code-vm-indexer v0.1.11-0.20241028132209-23031e814fba/go.mod h1:jSiifpiBpyBQ8q2R0MGEbkSgWC6sbdRTyDBntmW+j1E= github.com/containerd/continuity v0.0.0-20190827140505-75bee3e2ccb6 h1:NmTXa/uVnDyp0TY5MKi197+3HWcnYWfnHGyaFthlnGw= diff --git a/pkg/code/async/account/gift_card.go b/pkg/code/async/account/gift_card.go index 7fca28fc..8d224f15 100644 --- a/pkg/code/async/account/gift_card.go +++ b/pkg/code/async/account/gift_card.go @@ -15,6 +15,7 @@ import ( commonpb "github.com/code-payments/code-protobuf-api/generated/go/common/v1" + "github.com/code-payments/code-server/pkg/code/balance" "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" "github.com/code-payments/code-server/pkg/code/data/account" @@ -93,6 +94,12 @@ func (p *service) maybeInitiateGiftCardAutoReturn(ctx context.Context, accountIn return err } + balanceLock, err := balance.GetOptimisticVersionLock(ctx, p.data, giftCardVaultAccount) + if err != nil { + log.WithError(err).Warn("failure getting balance lock") + return err + } + _, err = p.data.GetGiftCardClaimedAction(ctx, giftCardVaultAccount.PublicKey().ToBase58()) if err == nil { log.Trace("gift card is claimed and will be removed from worker queue") @@ -124,7 +131,7 @@ func (p *service) maybeInitiateGiftCardAutoReturn(ctx context.Context, accountIn // There's no action to claim the gift card and the expiry window has been met. // It's time to initiate the process of auto-returning the funds back to the // issuer. - err = InitiateProcessToAutoReturnGiftCard(ctx, p.data, giftCardVaultAccount, false) + err = InitiateProcessToAutoReturnGiftCard(ctx, p.data, giftCardVaultAccount, false, balanceLock) if err != nil { log.WithError(err).Warn("failure initiating process to return gift card balance to issuer") return err @@ -138,7 +145,7 @@ func (p *service) maybeInitiateGiftCardAutoReturn(ctx context.Context, accountIn // a good guide for similar actions in the future. // // todo: This probably belongs somewhere more common -func InitiateProcessToAutoReturnGiftCard(ctx context.Context, data code_data.Provider, giftCardVaultAccount *common.Account, isVoidedByUser bool) error { +func InitiateProcessToAutoReturnGiftCard(ctx context.Context, data code_data.Provider, giftCardVaultAccount *common.Account, isVoidedByUser bool, balanceLock *balance.OptimisticVersionLock) error { return data.ExecuteInTx(ctx, sql.LevelDefault, func(ctx context.Context) error { giftCardIssuedIntent, err := data.GetOriginalGiftCardIssuedIntent(ctx, giftCardVaultAccount.PublicKey().ToBase58()) if err != nil { @@ -199,7 +206,12 @@ func InitiateProcessToAutoReturnGiftCard(ctx context.Context, data code_data.Pro // This will trigger the fulfillment worker to poll for the fulfillment. This // should be the very last DB update called. - return markFulfillmentAsActivelyScheduled(ctx, data, autoReturnFulfillment[0]) + err = markFulfillmentAsActivelyScheduled(ctx, data, autoReturnFulfillment[0]) + if err != nil { + return err + } + + return balanceLock.OnCommit(ctx, data) }) } diff --git a/pkg/code/balance/calculator.go b/pkg/code/balance/calculator.go index f6fc3e9f..70af9954 100644 --- a/pkg/code/balance/calculator.go +++ b/pkg/code/balance/calculator.go @@ -55,6 +55,28 @@ type State struct { current int64 } +// Calculate calculates a token account's balance using a starting point and a set +// of strategies. Each may be incomplete individually, but in total must form a +// complete balance calculation. +func Calculate(ctx context.Context, tokenAccount *common.Account, initialBalance uint64, strategies ...Strategy) (balance uint64, err error) { + balanceState := &State{ + current: int64(initialBalance), + } + + for _, strategy := range strategies { + balanceState, err = strategy(ctx, tokenAccount, balanceState) + if err != nil { + return 0, err + } + } + + if balanceState.current < 0 { + return 0, ErrNegativeBalance + } + + return uint64(balanceState.current), nil +} + // CalculateFromCache is the default and recommended strategy for reliably estimating // a token account's balance using cached values. // @@ -168,28 +190,6 @@ func CalculateFromBlockchain(ctx context.Context, data code_data.Provider, token return quarks, BlockchainSource, nil } -// Calculate calculates a token account's balance using a starting point and a set -// of strategies. Each may be incomplete individually, but in total must form a -// complete balance calculation. -func Calculate(ctx context.Context, tokenAccount *common.Account, initialBalance uint64, strategies ...Strategy) (balance uint64, err error) { - balanceState := &State{ - current: int64(initialBalance), - } - - for _, strategy := range strategies { - balanceState, err = strategy(ctx, tokenAccount, balanceState) - if err != nil { - return 0, err - } - } - - if balanceState.current < 0 { - return 0, ErrNegativeBalance - } - - return uint64(balanceState.current), nil -} - // NetBalanceFromIntentActions is a balance calculation strategy that incorporates // the net balance by applying payment intents to the current balance. func NetBalanceFromIntentActions(ctx context.Context, data code_data.Provider) Strategy { @@ -243,14 +243,39 @@ type BatchState struct { current map[string]int64 } +// CalculateBatch calculates a set of token accounts' balance using a starting point +// and a set of strategies. Each may be incomplete individually, but in total must +// form a complete balance calculation. +func CalculateBatch(ctx context.Context, tokenAccounts []string, strategies ...BatchStrategy) (balanceByTokenAccount map[string]uint64, err error) { + balanceState := &BatchState{ + current: make(map[string]int64), + } + + for _, strategy := range strategies { + balanceState, err = strategy(ctx, tokenAccounts, balanceState) + if err != nil { + return nil, err + } + } + + res := make(map[string]uint64) + for tokenAccount, balance := range balanceState.current { + if balance < 0 { + return nil, ErrNegativeBalance + } + + res[tokenAccount] = uint64(balance) + } + + return res, nil +} + // BatchCalculateFromCacheWithAccountRecords is the default and recommended batch strategy // or reliably estimating a set of token accounts' balance when common.AccountRecords are // available. // // Note: Use this method when calculating balances for accounts that are managed by // Code (ie. Timelock account) and operate within the L2 system. -// -// Note: This only supports post-privacy accounts. Use CalculateFromCache instead. func BatchCalculateFromCacheWithAccountRecords(ctx context.Context, data code_data.Provider, accountRecordsBatch ...*common.AccountRecords) (map[string]uint64, error) { tracer := metrics.TraceMethodCall(ctx, metricsPackageName, "BatchCalculateFromCacheWithAccountRecords") defer tracer.End() @@ -279,8 +304,6 @@ func BatchCalculateFromCacheWithAccountRecords(ctx context.Context, data code_da // // Note: Use this method when calculating balances for accounts that are managed by // Code (ie. Timelock account) and operate within the L2 system. -// -// Note: This only supports post-privacy accounts. Use CalculateFromCache instead. func BatchCalculateFromCacheWithTokenAccounts(ctx context.Context, data code_data.Provider, tokenAccounts ...*common.Account) (map[string]uint64, error) { tracer := metrics.TraceMethodCall(ctx, metricsPackageName, "BatchCalculateFromCacheWithTokenAccounts") defer tracer.End() @@ -333,33 +356,6 @@ func defaultBatchCalculationFromCache(ctx context.Context, data code_data.Provid ) } -// CalculateBatch calculates a set of token accounts' balance using a starting point -// and a set of strategies. Each may be incomplete individually, but in total must -// form a complete balance calculation. -func CalculateBatch(ctx context.Context, tokenAccounts []string, strategies ...BatchStrategy) (balanceByTokenAccount map[string]uint64, err error) { - balanceState := &BatchState{ - current: make(map[string]int64), - } - - for _, strategy := range strategies { - balanceState, err = strategy(ctx, tokenAccounts, balanceState) - if err != nil { - return nil, err - } - } - - res := make(map[string]uint64) - for tokenAccount, balance := range balanceState.current { - if balance < 0 { - return nil, ErrNegativeBalance - } - - res[tokenAccount] = uint64(balance) - } - - return res, nil -} - // NetBalanceFromIntentActionsBatch is a balance calculation strategy that incorporates // the net balance by applying payment intents to the current balance. func NetBalanceFromIntentActionsBatch(ctx context.Context, data code_data.Provider) BatchStrategy { diff --git a/pkg/code/balance/lock.go b/pkg/code/balance/lock.go new file mode 100644 index 00000000..8b47a838 --- /dev/null +++ b/pkg/code/balance/lock.go @@ -0,0 +1,34 @@ +package balance + +import ( + "context" + + "github.com/code-payments/code-server/pkg/code/common" + code_data "github.com/code-payments/code-server/pkg/code/data" +) + +// OptimisticVersionLock is an optimistic version lock on an account's cached +// balance, which can be paired with DB updates against balances that need to +// be protected against race conditions. +type OptimisticVersionLock struct { + vault *common.Account + currentVersion uint64 +} + +// GetOptimisticVersionLock gets an optimistic version lock for the vault account's +// cached balance +func GetOptimisticVersionLock(ctx context.Context, data code_data.Provider, vault *common.Account) (*OptimisticVersionLock, error) { + version, err := data.GetCachedBalanceVersion(ctx, vault.PublicKey().ToBase58()) + if err != nil { + return nil, err + } + return &OptimisticVersionLock{ + vault: vault, + currentVersion: version, + }, nil +} + +// OnCommit is called in the DB transaction updating the account's cached balance +func (l *OptimisticVersionLock) OnCommit(ctx context.Context, data code_data.Provider) error { + return data.AdvanceCachedBalanceVersion(ctx, l.vault.PublicKey().ToBase58(), l.currentVersion) +} diff --git a/pkg/code/data/balance/memory/store.go b/pkg/code/data/balance/memory/store.go index 21ba7537..826b486d 100644 --- a/pkg/code/data/balance/memory/store.go +++ b/pkg/code/data/balance/memory/store.go @@ -9,14 +9,54 @@ import ( ) type store struct { - mu sync.Mutex - externalCheckpointRecords []*balance.ExternalCheckpointRecord - last uint64 + mu sync.Mutex + cachedBalanceVersionsByAccount map[string]uint64 + externalCheckpointRecords []*balance.ExternalCheckpointRecord + last uint64 } // New returns a new in memory balance.Store func New() balance.Store { - return &store{} + return &store{ + cachedBalanceVersionsByAccount: make(map[string]uint64), + } +} + +// GetCachedVersion implements balance.Store.GetCachedVersion +func (s *store) GetCachedVersion(_ context.Context, account string) (uint64, error) { + s.mu.Lock() + defer s.mu.Unlock() + + current, ok := s.cachedBalanceVersionsByAccount[account] + if !ok { + return 0, nil + } + return current, nil +} + +// AdvanceCachedVersion implements balance.Store.AdvanceCachedVersion +func (s *store) AdvanceCachedVersion(_ context.Context, account string, currentVersion uint64) error { + s.mu.Lock() + defer s.mu.Unlock() + + actualVersion, ok := s.cachedBalanceVersionsByAccount[account] + if !ok { + if currentVersion != 0 { + return balance.ErrStaleCachedBalanceVersion + } + + s.cachedBalanceVersionsByAccount[account] = 1 + + return nil + } + + if actualVersion != currentVersion { + return balance.ErrStaleCachedBalanceVersion + } + + s.cachedBalanceVersionsByAccount[account]++ + + return nil } // SaveExternalCheckpoint implements balance.Store.SaveExternalCheckpoint @@ -87,6 +127,7 @@ func (s *store) reset() { s.mu.Lock() defer s.mu.Unlock() + s.cachedBalanceVersionsByAccount = make(map[string]uint64) s.externalCheckpointRecords = nil s.last = 0 } diff --git a/pkg/code/data/balance/postgres/model.go b/pkg/code/data/balance/postgres/model.go index 49f3f32d..514779ab 100644 --- a/pkg/code/data/balance/postgres/model.go +++ b/pkg/code/data/balance/postgres/model.go @@ -8,11 +8,13 @@ import ( "github.com/jmoiron/sqlx" "github.com/code-payments/code-server/pkg/code/data/balance" + pg "github.com/code-payments/code-server/pkg/database/postgres" pgutil "github.com/code-payments/code-server/pkg/database/postgres" ) const ( - externalCheckpointTableName = "codewallet__core_externalbalancecheckpoint" + cachedBalanceVersionTableName = "codewallet__core_cachedbalanceversion" + externalCheckpointTableName = "codewallet__core_externalbalancecheckpoint" ) type externalCheckpointModel struct { @@ -25,6 +27,49 @@ type externalCheckpointModel struct { LastUpdatedAt time.Time `db:"last_updated_at"` } +func dbGetCachedVersion(ctx context.Context, db *sqlx.DB, account string) (uint64, error) { + var res uint64 + query := `SELECT version FROM ` + cachedBalanceVersionTableName + ` + WHERE token_account = $1` + err := db.GetContext(ctx, &res, query, account) + if pg.IsNoRows(err) { + return 0, nil + } else if err != nil { + return 0, err + } + return res, nil +} + +func dbAdvanceCachedVersion(ctx context.Context, db *sqlx.DB, account string, currentVersion uint64) error { + return pgutil.ExecuteInTx(ctx, db, sql.LevelDefault, func(tx *sqlx.Tx) error { + query := `INSERT INTO ` + cachedBalanceVersionTableName + ` + (token_account, version) + VALUES ($1, 1) + RETURNING version + ` + params := []any{account} + if currentVersion > 0 { + query = `UPDATE ` + cachedBalanceVersionTableName + ` + SET version = version + 1 + WHERE token_account = $1 AND version = $2 + RETURNING version + ` + params = append(params, currentVersion) + } + + var res uint64 + err := tx.GetContext(ctx, &res, query, params...) + if pg.IsNoRows(err) || pg.IsUniqueViolation(err) { + return balance.ErrStaleCachedBalanceVersion + } + if err != nil { + return err + } + return nil + }) + +} + func toExternalCheckpointModel(obj *balance.ExternalCheckpointRecord) (*externalCheckpointModel, error) { if err := obj.Validate(); err != nil { return nil, err @@ -80,9 +125,7 @@ func (m *externalCheckpointModel) dbSave(ctx context.Context, db *sqlx.DB) error func dbGetExternalCheckpoint(ctx context.Context, db *sqlx.DB, account string) (*externalCheckpointModel, error) { res := &externalCheckpointModel{} - query := `SELECT - id, token_account, quarks, slot_checkpoint, last_updated_at - FROM ` + externalCheckpointTableName + ` + query := `SELECT id, token_account, quarks, slot_checkpoint, last_updated_at FROM ` + externalCheckpointTableName + ` WHERE token_account = $1 LIMIT 1` diff --git a/pkg/code/data/balance/postgres/store.go b/pkg/code/data/balance/postgres/store.go index 2426601a..0ce42546 100644 --- a/pkg/code/data/balance/postgres/store.go +++ b/pkg/code/data/balance/postgres/store.go @@ -20,6 +20,16 @@ func New(db *sql.DB) balance.Store { } } +// GetCachedVersion implements balance.Store.GetCachedVersion +func (s *store) GetCachedVersion(ctx context.Context, account string) (uint64, error) { + return dbGetCachedVersion(ctx, s.db, account) +} + +// AdvanceCachedVersion implements balance.Store.AdvanceCachedVersion +func (s *store) AdvanceCachedVersion(ctx context.Context, account string, currentVersion uint64) error { + return dbAdvanceCachedVersion(ctx, s.db, account, currentVersion) +} + // SaveExternalCheckpoint implements balance.Store.SaveExternalCheckpoint func (s *store) SaveExternalCheckpoint(ctx context.Context, record *balance.ExternalCheckpointRecord) error { model, err := toExternalCheckpointModel(record) diff --git a/pkg/code/data/balance/postgres/store_test.go b/pkg/code/data/balance/postgres/store_test.go index 46bc54e7..862b330f 100644 --- a/pkg/code/data/balance/postgres/store_test.go +++ b/pkg/code/data/balance/postgres/store_test.go @@ -24,6 +24,15 @@ var ( const ( // Used for testing ONLY, the table and migrations are external to this repository tableCreate = ` + CREATE TABLE codewallet__core_cachedbalanceversion ( + id SERIAL NOT NULL PRIMARY KEY, + + token_account TEXT NOT NULL, + version INTEGER NOT NULL, + + CONSTRAINT codewallet__core_cachedbalanceversion__unique__token_account UNIQUE (token_account) + ); + CREATE TABLE codewallet__core_externalbalancecheckpoint ( id SERIAL NOT NULL PRIMARY KEY, @@ -39,6 +48,7 @@ const ( // Used for testing ONLY, the table and migrations are external to this repository tableDestroy = ` + DROP TABLE codewallet__core_cachedbalanceversion; DROP TABLE codewallet__core_externalbalancecheckpoint; ` ) diff --git a/pkg/code/data/balance/store.go b/pkg/code/data/balance/store.go index 01330a4b..52f0cd9e 100644 --- a/pkg/code/data/balance/store.go +++ b/pkg/code/data/balance/store.go @@ -6,12 +6,21 @@ import ( ) var ( - ErrCheckpointNotFound = errors.New("checkpoint not found") + ErrStaleCachedBalanceVersion = errors.New("cached balance version is stale") - ErrStaleCheckpoint = errors.New("checkpoint is stale") + ErrCheckpointNotFound = errors.New("checkpoint not found") + ErrStaleCheckpoint = errors.New("checkpoint is stale") ) type Store interface { + // GetCachedVersion gets the current cached balance version, which can be used + // for optimistic locking cached balances for operations with outgoing transfers. + GetCachedVersion(ctx context.Context, account string) (uint64, error) + + // AdvanceCachedVersion advances an account's cached balance version. + // ErrStaleCachedBalanceVersion is returned if the currentVersion is out of date. + AdvanceCachedVersion(ctx context.Context, account string, currentVersion uint64) error + // SaveExternalCheckpoint saves an external balance at a checkpoint. // ErrStaleCheckpoint is returned if the checkpoint is outdated SaveExternalCheckpoint(ctx context.Context, record *ExternalCheckpointRecord) error diff --git a/pkg/code/data/balance/tests/tests.go b/pkg/code/data/balance/tests/tests.go index a4202ff3..d50240d9 100644 --- a/pkg/code/data/balance/tests/tests.go +++ b/pkg/code/data/balance/tests/tests.go @@ -13,6 +13,7 @@ import ( func RunTests(t *testing.T, s balance.Store, teardown func()) { for _, tf := range []func(t *testing.T, s balance.Store){ + testCachedBalanceVersionHappyPath, testExternalCheckpointHappyPath, } { tf(t, s) @@ -20,6 +21,29 @@ func RunTests(t *testing.T, s balance.Store, teardown func()) { } } +func testCachedBalanceVersionHappyPath(t *testing.T, s balance.Store) { + t.Run("testCachedBalanceVersionHappyPath", func(t *testing.T) { + ctx := context.Background() + + for i := range 100 { + if i > 0 { + assert.Equal(t, balance.ErrStaleCachedBalanceVersion, s.AdvanceCachedVersion(ctx, "token_account_1", uint64(i-1))) + } + assert.Equal(t, balance.ErrStaleCachedBalanceVersion, s.AdvanceCachedVersion(ctx, "token_account_1", uint64(i+1))) + + currentVersion, err := s.GetCachedVersion(ctx, "token_account_1") + require.NoError(t, err) + assert.EqualValues(t, i, currentVersion) + + require.NoError(t, s.AdvanceCachedVersion(ctx, "token_account_1", currentVersion)) + } + + currentVersion, err := s.GetCachedVersion(ctx, "token_account_2") + require.NoError(t, err) + assert.EqualValues(t, 0, currentVersion) + }) +} + func testExternalCheckpointHappyPath(t *testing.T, s balance.Store) { t.Run("testExternalCheckpointHappyPath", func(t *testing.T) { ctx := context.Background() diff --git a/pkg/code/data/internal.go b/pkg/code/data/internal.go index d14d4889..62e6fdd7 100644 --- a/pkg/code/data/internal.go +++ b/pkg/code/data/internal.go @@ -119,6 +119,8 @@ type DatabaseData interface { // Balance // -------------------------------------------------------------------------------- + GetCachedBalanceVersion(ctx context.Context, account string) (uint64, error) + AdvanceCachedBalanceVersion(ctx context.Context, account string, currentVersion uint64) error SaveExternalBalanceCheckpoint(ctx context.Context, record *balance.ExternalCheckpointRecord) error GetExternalBalanceCheckpoint(ctx context.Context, account string) (*balance.ExternalCheckpointRecord, error) @@ -428,6 +430,12 @@ func (dp *DatabaseProvider) HasFeeAction(ctx context.Context, intent string, fee // Balance // -------------------------------------------------------------------------------- +func (dp *DatabaseProvider) GetCachedBalanceVersion(ctx context.Context, account string) (uint64, error) { + return dp.balance.GetCachedVersion(ctx, account) +} +func (dp *DatabaseProvider) AdvanceCachedBalanceVersion(ctx context.Context, account string, currentVersion uint64) error { + return dp.balance.AdvanceCachedVersion(ctx, account, currentVersion) +} func (dp *DatabaseProvider) SaveExternalBalanceCheckpoint(ctx context.Context, record *balance.ExternalCheckpointRecord) error { return dp.balance.SaveExternalCheckpoint(ctx, record) } diff --git a/pkg/code/server/account/server.go b/pkg/code/server/account/server.go index f460296c..f1b37070 100644 --- a/pkg/code/server/account/server.go +++ b/pkg/code/server/account/server.go @@ -379,7 +379,7 @@ func (s *server) getProtoAccountInfo(ctx context.Context, records *common.Accoun // Gift cards that are close to the auto-return window are marked as expired in // a consistent manner as SubmitIntent to avoid race conditions with the auto-return. - if time.Since(records.General.CreatedAt) >= async_account.GiftCardExpiry { + if time.Since(records.General.CreatedAt) >= async_account.GiftCardExpiry-time.Minute { claimState = accountpb.TokenAccountInfo_CLAIM_STATE_EXPIRED } @@ -467,7 +467,7 @@ func (s *server) updateCachedResponse(resp *accountpb.GetTokenAccountInfosRespon switch ai.AccountType { case commonpb.AccountType_REMOTE_SEND_GIFT_CARD: // Transition any gift card records to expired if we elapsed the expiry window - if time.Since(ai.CreatedAt.AsTime()) >= async_account.GiftCardExpiry { + if time.Since(ai.CreatedAt.AsTime()) >= async_account.GiftCardExpiry-time.Minute { ai.ClaimState = accountpb.TokenAccountInfo_CLAIM_STATE_EXPIRED ai.BalanceSource = accountpb.TokenAccountInfo_BALANCE_SOURCE_CACHE ai.Balance = 0 diff --git a/pkg/code/server/transaction/airdrop.go b/pkg/code/server/transaction/airdrop.go index dcd0440a..4f420c9f 100644 --- a/pkg/code/server/transaction/airdrop.go +++ b/pkg/code/server/transaction/airdrop.go @@ -96,10 +96,6 @@ func (s *transactionServer) Airdrop(ctx context.Context, req *transactionpb.Aird }, nil } - ownerLock := s.ownerLocks.Get(owner.PublicKey().ToBytes()) - ownerLock.Lock() - defer ownerLock.Unlock() - intentId := GetAirdropIntentId(AirdropTypeWelcomeBonus, owner.PublicKey().ToBase58()) _, err = s.data.GetIntent(ctx, intentId) if err == nil { @@ -164,6 +160,8 @@ func (s *transactionServer) Airdrop(ctx context.Context, req *transactionpb.Aird } // Note: this function is idempotent with the given intent ID. +// +// todo: This function needs to be more resilient to failures due to balance races func (s *transactionServer) airdrop(ctx context.Context, intentId string, owner *common.Account, airdropType AirdropType) (*intent.Record, error) { log := s.log.WithFields(logrus.Fields{ "method": "airdrop", @@ -260,6 +258,12 @@ func (s *transactionServer) airdrop(ctx context.Context, intentId string, owner return nil, err } + balanceLock, err := balance.GetOptimisticVersionLock(ctx, s.data, s.airdropper.Vault) + if err != nil { + log.WithError(err).Warn("failure getting balance lock") + return nil, err + } + // Do a balance check. If there's insufficient balance, the feature is considered // to be over with until we get more funding. balance, err := balance.CalculateFromCache(ctx, s.data, s.airdropper.Vault) @@ -390,6 +394,11 @@ func (s *transactionServer) airdrop(ctx context.Context, intentId string, owner return err } + err = balanceLock.OnCommit(ctx, s.data) + if err != nil { + return err + } + return nil }) if err != nil { diff --git a/pkg/code/server/transaction/config.go b/pkg/code/server/transaction/config.go index 439e4594..614790d7 100644 --- a/pkg/code/server/transaction/config.go +++ b/pkg/code/server/transaction/config.go @@ -49,7 +49,6 @@ type conf struct { enableAirdrops config.Bool airdropperOwnerPublicKey config.String maxAirdropUsdValue config.Float64 - stripedLockParallelization config.Uint64 } // ConfigProvider defines how config values are pulled @@ -70,7 +69,6 @@ func WithEnvConfigs() ConfigProvider { enableAirdrops: env.NewBoolConfig(EnableAirdropsConfigEnvName, defaultEnableAirdrops), airdropperOwnerPublicKey: env.NewStringConfig(AirdropperOwnerPublicKeyEnvName, defaultAirdropperOwnerPublicKey), maxAirdropUsdValue: env.NewFloat64Config(MaxAirdropUsdValueEnvName, defaultMaxAirdropUsdValue), - stripedLockParallelization: wrapper.NewUint64Config(memory.NewConfig(8192), 8192), } } } @@ -98,7 +96,6 @@ func withManualTestOverrides(overrides *testOverrides) ConfigProvider { enableAirdrops: wrapper.NewBoolConfig(memory.NewConfig(overrides.enableAirdrops), false), airdropperOwnerPublicKey: wrapper.NewStringConfig(memory.NewConfig(defaultAirdropperOwnerPublicKey), defaultAirdropperOwnerPublicKey), maxAirdropUsdValue: wrapper.NewFloat64Config(memory.NewConfig(defaultMaxAirdropUsdValue), defaultMaxAirdropUsdValue), - stripedLockParallelization: wrapper.NewUint64Config(memory.NewConfig(4), 4), } } } diff --git a/pkg/code/server/transaction/intent.go b/pkg/code/server/transaction/intent.go index 07fbd6d2..198d0408 100644 --- a/pkg/code/server/transaction/intent.go +++ b/pkg/code/server/transaction/intent.go @@ -7,6 +7,7 @@ import ( "database/sql" "encoding/base64" "strings" + "sync" "time" "github.com/mr-tron/base58/base58" @@ -20,6 +21,7 @@ import ( transactionpb "github.com/code-payments/code-protobuf-api/generated/go/transaction/v2" async_account "github.com/code-payments/code-server/pkg/code/async/account" + "github.com/code-payments/code-server/pkg/code/balance" "github.com/code-payments/code-server/pkg/code/common" "github.com/code-payments/code-server/pkg/code/data/account" "github.com/code-payments/code-server/pkg/code/data/action" @@ -215,12 +217,6 @@ func (s *transactionServer) SubmitIntent(streamer transactionpb.Transaction_Subm CreatedAt: time.Now(), } - initiatorOwnerLock := s.ownerLocks.Get(initiatorOwnerAccount.PublicKey().ToBytes()) - initiatorOwnerLock.Lock() - defer func() { - initiatorOwnerLock.Unlock() - }() - existingIntentRecord, err := s.data.GetIntent(ctx, intentId) if err != intent.ErrIntentNotFound && err != nil { log.WithError(err).Warn("failure checking for existing intent record") @@ -240,6 +236,7 @@ func (s *transactionServer) SubmitIntent(streamer transactionpb.Transaction_Subm return handleSubmitIntentError(streamer, err) } + // Check whether the intent is a no-op isNoop, err := intentHandler.IsNoop(ctx, intentRecord, submitActionsReq.Metadata, submitActionsReq.Actions) if err != nil { log.WithError(err).Warn("failure checking if intent is a no-op") @@ -251,18 +248,33 @@ func (s *transactionServer) SubmitIntent(streamer transactionpb.Transaction_Subm return nil } - // Distributed locking on additional accounts possibly not known until - // populating intent metadata. Importantly, this must be done prior to - // doing validation checks in AllowCreation. - additionalAccountsToLock, err := intentHandler.GetAdditionalAccountsToLock(ctx, intentRecord) + // Lock any acccounts with outgoing transfers of funds: + // 1. Optimistic version lock at the DB layer to guarantee balance consistency + // 2. Local in memory lock to avoid over consumption of local resources (eg. + // nonces) when we're likely to encounter a race resulting in DB txn rollback + // (eg. mass attempt to claim gift card). + accountBalancesToLock, err := intentHandler.GetAccountsWithBalancesToLock(ctx, intentRecord, submitActionsReq.Metadata) if err != nil { + log.WithError(err).Warn("failure getting accounts with balances to lock") return handleSubmitIntentError(streamer, err) } + localAccountLocks := make([]*sync.Mutex, len(accountBalancesToLock)) + globalBalanceLocks := make([]*balance.OptimisticVersionLock, len(accountBalancesToLock)) + for i, account := range accountBalancesToLock { + log := log.WithField("account", account.PublicKey().ToBase58()) + + localAccountLocks[i] = s.getLocalAccountLock(account) - if additionalAccountsToLock.RemoteSendGiftCardVault != nil { - giftCardLock := s.giftCardLocks.Get(additionalAccountsToLock.RemoteSendGiftCardVault.PublicKey().ToBytes()) - giftCardLock.Lock() - defer giftCardLock.Unlock() + globalBalanceLock, err := balance.GetOptimisticVersionLock(ctx, s.data, account) + if err != nil { + log.WithError(err).Warn("failure getting balance lock") + return handleSubmitIntentError(streamer, err) + } + globalBalanceLocks[i] = globalBalanceLock + } + for _, localAccountLock := range localAccountLocks { + localAccountLock.Lock() + defer localAccountLock.Unlock() } // Validate the new intent with intent-specific logic @@ -599,9 +611,21 @@ func (s *transactionServer) SubmitIntent(streamer transactionpb.Transaction_Subm return err } + for _, balanceLock := range globalBalanceLocks { + err = balanceLock.OnCommit(ctx, s.data) + if err != nil { + log.WithError(err).Warn("failure commiting balance update") + return err + } + } + return nil }) if err != nil { + if strings.Contains(err.Error(), "stale") || strings.Contains(err.Error(), "exist") { + log.WithError(err).Info("race condition detected") + return handleSubmitIntentError(streamer, newStaleStateErrorf("race detected: %s", err.Error())) + } return handleSubmitIntentError(streamer, err) } @@ -957,9 +981,15 @@ func (s *transactionServer) VoidGiftCard(ctx context.Context, req *transactionpb }, nil } - giftCardLock := s.giftCardLocks.Get(giftCardVault.PublicKey().ToBytes()) - giftCardLock.Lock() - defer giftCardLock.Unlock() + globalBalanceLock, err := balance.GetOptimisticVersionLock(ctx, s.data, giftCardVault) + if err != nil { + log.WithError(err).Warn("failure getting balance lock") + return nil, status.Error(codes.Internal, "") + } + + localAccountLock := s.getLocalAccountLock(giftCardVault) + localAccountLock.Lock() + defer localAccountLock.Unlock() claimedActionRecord, err := s.data.GetGiftCardClaimedAction(ctx, giftCardVault.PublicKey().ToBase58()) if err == nil { @@ -982,7 +1012,7 @@ func (s *transactionServer) VoidGiftCard(ctx context.Context, req *transactionpb return nil, status.Error(codes.Internal, "") } - err = async_account.InitiateProcessToAutoReturnGiftCard(ctx, s.data, giftCardVault, true) + err = async_account.InitiateProcessToAutoReturnGiftCard(ctx, s.data, giftCardVault, true, globalBalanceLock) if err != nil { log.WithError(err).Warn("failure scheduling auto-return action") return nil, status.Error(codes.Internal, "") @@ -996,3 +1026,14 @@ func (s *transactionServer) VoidGiftCard(ctx context.Context, req *transactionpb Result: transactionpb.VoidGiftCardResponse_OK, }, nil } + +func (s *transactionServer) getLocalAccountLock(account *common.Account) *sync.Mutex { + s.localAccountLocksMu.Lock() + lock, ok := s.localAccountLocks[account.PublicKey().ToBase58()] + if !ok { + lock = &sync.Mutex{} + s.localAccountLocks[account.PublicKey().ToBase58()] = lock + } + s.localAccountLocksMu.Unlock() + return lock +} diff --git a/pkg/code/server/transaction/intent_handler.go b/pkg/code/server/transaction/intent_handler.go index 9b8f642b..060519c9 100644 --- a/pkg/code/server/transaction/intent_handler.go +++ b/pkg/code/server/transaction/intent_handler.go @@ -29,10 +29,6 @@ var accountTypesToOpen = []commonpb.AccountType{ commonpb.AccountType_PRIMARY, } -type lockableAccounts struct { - RemoteSendGiftCardVault *common.Account -} - // CreateIntentHandler is an interface for handling new intent creations type CreateIntentHandler interface { // PopulateMetadata adds intent metadata to the provided intent record @@ -48,12 +44,9 @@ type CreateIntentHandler interface { // error. IsNoop(ctx context.Context, intentRecord *intent.Record, metadata *transactionpb.Metadata, actions []*transactionpb.Action) (bool, error) - // GetAdditionalAccountsToLock gets additional accounts to apply distributed - // locking on that are specific to an intent. - // - // Note: Assumes relevant information is contained in the intent record after - // calling PopulateMetadata. - GetAdditionalAccountsToLock(ctx context.Context, intentRecord *intent.Record) (*lockableAccounts, error) + // GetAccountsWithBalancesToLock gets a set of accounts with balances that need + // to be locked. + GetAccountsWithBalancesToLock(ctx context.Context, intentRecord *intent.Record, metadata *transactionpb.Metadata) ([]*common.Account, error) // AllowCreation determines whether the new intent creation should be allowed. AllowCreation(ctx context.Context, intentRecord *intent.Record, metadata *transactionpb.Metadata, actions []*transactionpb.Action) error @@ -99,6 +92,10 @@ func (h *OpenAccountsIntentHandler) PopulateMetadata(ctx context.Context, intent return nil } +func (h *OpenAccountsIntentHandler) GetAccountsWithBalancesToLock(ctx context.Context, intentRecord *intent.Record, metadata *transactionpb.Metadata) ([]*common.Account, error) { + return nil, nil +} + func (h *OpenAccountsIntentHandler) IsNoop(ctx context.Context, intentRecord *intent.Record, metadata *transactionpb.Metadata, actions []*transactionpb.Action) (bool, error) { initiatiorOwnerAccount, err := common.NewAccountFromPublicKeyString(intentRecord.InitiatorOwnerAccount) if err != nil { @@ -115,10 +112,6 @@ func (h *OpenAccountsIntentHandler) IsNoop(ctx context.Context, intentRecord *in return false, nil } -func (h *OpenAccountsIntentHandler) GetAdditionalAccountsToLock(ctx context.Context, intentRecord *intent.Record) (*lockableAccounts, error) { - return &lockableAccounts{}, nil -} - func (h *OpenAccountsIntentHandler) AllowCreation(ctx context.Context, intentRecord *intent.Record, metadata *transactionpb.Metadata, actions []*transactionpb.Action) error { typedMetadata := metadata.GetOpenAccounts() if typedMetadata == nil { @@ -315,19 +308,15 @@ func (h *SendPublicPaymentIntentHandler) IsNoop(ctx context.Context, intentRecor return false, nil } -func (h *SendPublicPaymentIntentHandler) GetAdditionalAccountsToLock(ctx context.Context, intentRecord *intent.Record) (*lockableAccounts, error) { - if !intentRecord.SendPublicPaymentMetadata.IsRemoteSend { - return &lockableAccounts{}, nil - } +func (h *SendPublicPaymentIntentHandler) GetAccountsWithBalancesToLock(ctx context.Context, intentRecord *intent.Record, metadata *transactionpb.Metadata) ([]*common.Account, error) { + typedMetadata := metadata.GetSendPublicPayment() - giftCardVaultAccount, err := common.NewAccountFromPublicKeyString(intentRecord.SendPublicPaymentMetadata.DestinationTokenAccount) + sourceVault, err := common.NewAccountFromProto(typedMetadata.Source) if err != nil { return nil, err } - return &lockableAccounts{ - RemoteSendGiftCardVault: giftCardVaultAccount, - }, nil + return []*common.Account{sourceVault}, nil } func (h *SendPublicPaymentIntentHandler) AllowCreation(ctx context.Context, intentRecord *intent.Record, untypedMetadata *transactionpb.Metadata, actions []*transactionpb.Action) error { @@ -445,20 +434,9 @@ func (h *SendPublicPaymentIntentHandler) validateActions( actions []*transactionpb.Action, simResult *LocalSimulationResult, ) error { - var source *common.Account - var err error - if metadata.Source != nil { - source, err = common.NewAccountFromProto(metadata.Source) - if err != nil { - return err - } - } else { - // Backwards compat for old clients using metadata without source. It was - // always assumed to be from the primary account - source, err = common.NewAccountFromPublicKeyString(initiatorAccountsByType[commonpb.AccountType_PRIMARY][0].General.TokenAccount) - if err != nil { - return err - } + source, err := common.NewAccountFromProto(metadata.Source) + if err != nil { + return err } destination, err := common.NewAccountFromProto(metadata.Destination) @@ -766,19 +744,12 @@ func (h *ReceivePaymentsPubliclyIntentHandler) IsNoop(ctx context.Context, inten return false, nil } -func (h *ReceivePaymentsPubliclyIntentHandler) GetAdditionalAccountsToLock(ctx context.Context, intentRecord *intent.Record) (*lockableAccounts, error) { - if !intentRecord.ReceivePaymentsPubliclyMetadata.IsRemoteSend { - return &lockableAccounts{}, nil - } - - giftCardVaultAccount, err := common.NewAccountFromPublicKeyString(intentRecord.ReceivePaymentsPubliclyMetadata.Source) +func (h *ReceivePaymentsPubliclyIntentHandler) GetAccountsWithBalancesToLock(ctx context.Context, intentRecord *intent.Record, metadata *transactionpb.Metadata) ([]*common.Account, error) { + giftCardVault, err := common.NewAccountFromPublicKeyString(intentRecord.ReceivePaymentsPubliclyMetadata.Source) if err != nil { return nil, err } - - return &lockableAccounts{ - RemoteSendGiftCardVault: giftCardVaultAccount, - }, nil + return []*common.Account{giftCardVault}, nil } func (h *ReceivePaymentsPubliclyIntentHandler) AllowCreation(ctx context.Context, intentRecord *intent.Record, untypedMetadata *transactionpb.Metadata, actions []*transactionpb.Action) error { @@ -1315,7 +1286,7 @@ func validateClaimedGiftCard(ctx context.Context, data code_data.Provider, giftC // Part 6: Are we within the threshold for auto-return back to the issuer? // - if time.Since(accountInfoRecord.CreatedAt) >= async_account.GiftCardExpiry-15*time.Minute { + if time.Since(accountInfoRecord.CreatedAt) >= async_account.GiftCardExpiry-time.Minute { return newStaleStateError("gift card is expired") } diff --git a/pkg/code/server/transaction/server.go b/pkg/code/server/transaction/server.go index 28a7a989..06708b73 100644 --- a/pkg/code/server/transaction/server.go +++ b/pkg/code/server/transaction/server.go @@ -15,7 +15,6 @@ import ( code_data "github.com/code-payments/code-server/pkg/code/data" "github.com/code-payments/code-server/pkg/code/data/nonce" "github.com/code-payments/code-server/pkg/code/transaction" - sync_util "github.com/code-payments/code-server/pkg/sync" ) type transactionServer struct { @@ -33,14 +32,14 @@ type transactionServer struct { noncePool *transaction.LocalNoncePool + localAccountLocksMu sync.Mutex + localAccountLocks map[string]*sync.Mutex + airdropperLock sync.Mutex airdropper *common.TimelockAccounts feeCollector *common.Account - ownerLocks *sync_util.StripedLock - giftCardLocks *sync_util.StripedLock - transactionpb.UnimplementedTransactionServer } @@ -58,8 +57,6 @@ func NewTransactionServer( conf := configProvider() - stripedLockParallelization := uint(conf.stripedLockParallelization.Get(ctx)) - if err := noncePool.Validate(nonce.EnvironmentCvm, common.CodeVmAccount.PublicKey().ToBase58(), nonce.PurposeClientTransaction); err != nil { return nil, err } @@ -79,8 +76,7 @@ func NewTransactionServer( noncePool: noncePool, - ownerLocks: sync_util.NewStripedLock(stripedLockParallelization), - giftCardLocks: sync_util.NewStripedLock(stripedLockParallelization), + localAccountLocks: make(map[string]*sync.Mutex), } s.feeCollector, err = common.NewAccountFromPublicKeyString(s.conf.feeCollectorTokenPublicKey.Get(ctx)) diff --git a/pkg/database/postgres/errors.go b/pkg/database/postgres/errors.go index 38f8155a..a975d59b 100644 --- a/pkg/database/postgres/errors.go +++ b/pkg/database/postgres/errors.go @@ -9,12 +9,19 @@ import ( ) func CheckNoRows(inErr, outErr error) error { - if inErr == sql.ErrNoRows { + if IsNoRows(inErr) { return outErr } return inErr } +func IsNoRows(err error) bool { + if err == nil { + return false + } + return err == sql.ErrNoRows +} + func CheckUniqueViolation(inErr, outErr error) error { if inErr != nil { var pgErr *pgconn.PgError