diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index dbea2e6f..3166fcc6 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -19,5 +19,8 @@ jobs: with: go-version: 1.21.x + - name: Lint + run: make lint + - name: Test run: make test diff --git a/.gitignore b/.gitignore index 6e9c3ae5..860f8a3e 100644 --- a/.gitignore +++ b/.gitignore @@ -22,7 +22,10 @@ go.work # Build output directories build/ +.build/ # Visual Studio Code .vscode/ +# Jetbrains +.idea/ diff --git a/.golangci-lint.yaml b/.golangci-lint.yaml new file mode 100644 index 00000000..b219848c --- /dev/null +++ b/.golangci-lint.yaml @@ -0,0 +1,18 @@ +linters: + enable: + - bodyclose + - whitespace + - revive + +issues: + exclude-rules: + # var-naming conflicts are huge + - path: '(.+)\.go' + text: "var-naming" + # unused-parameter conflicts are sizeable + - path: '(.+)\.go' + text: "unused-parameter" + - path: '(.+)_test\.go' + text: ".*GetPaymentHistory is deprecated: Payment history has migrated to chats" + - path: '(.+)testutil\.go' + text: ".*GetPaymentHistory is deprecated: Payment history has migrated to chats" diff --git a/Makefile b/Makefile index c8d8faf8..8112551d 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ .PHONY: all -all: test +all: lint test .PHONY: test test: @@ -17,3 +17,17 @@ clean-integration-containers: @docker ps | grep -E "etcd-test-[0-9a-z]{8}-[0-9]+" | awk '{print $$1}' | xargs docker rm -f 2>/dev/null || true @echo Removing etcd cluster networks... @docker network ls | grep -E "etcd-test-[0-9a-z]{8}-network" | awk '{print $$1}' | xargs docker network remove 2>/dev/null || true + +.PHONY: lint +lint: tools.golangci-lint + @golangci-lint --timeout=3m --config .golangci-lint.yaml run ./... + +.PHONY: tools +tools: tools.golangci-lint + +tools.golangci-lint: .build/markers/golangci-lint_installed +.build/markers/golangci-lint_installed: + @command -v golangci-lint >/dev/null ; if [ $$? -ne 0 ]; then \ + CGO_ENABLED=0 go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.57.2; \ + fi + @mkdir -p $(shell dirname $@) && touch "$@" diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go index f96b59c9..7b60dc27 100644 --- a/pkg/cache/cache.go +++ b/pkg/cache/cache.go @@ -1,7 +1,6 @@ package cache import ( - "errors" "log" "sync" ) @@ -11,7 +10,7 @@ type Cache interface { SetVerbose(verbose bool) GetWeight() int GetBudget() int - Insert(key string, value interface{}, weight int) error + Insert(key string, value interface{}, weight int) (inserted bool) Retrieve(key string) (interface{}, bool) Clear() } @@ -61,12 +60,12 @@ func (c *cache) GetBudget() int { } // Insert inserts an object into the cache -func (c *cache) Insert(key string, value interface{}, weight int) error { +func (c *cache) Insert(key string, value interface{}, weight int) (inserted bool) { c.mutex.Lock() defer c.mutex.Unlock() if _, found := c.lookup[key]; found { - return errors.New("key already exists in cache") + return false } node := &cacheNode{ @@ -106,7 +105,7 @@ func (c *cache) Insert(key string, value interface{}, weight int) error { delete(c.lookup, c.tail.key) } - return nil + return true } // Retrieve gets an object out of the cache diff --git a/pkg/cache/cache_test.go b/pkg/cache/cache_test.go index 34da141d..525fd1a9 100644 --- a/pkg/cache/cache_test.go +++ b/pkg/cache/cache_test.go @@ -2,24 +2,18 @@ package cache import ( "testing" + + "github.com/stretchr/testify/require" ) func TestCacheInsert(t *testing.T) { cache := NewCache(1) - insertError := cache.Insert("A", "", 1) - - if insertError != nil { - t.Fatalf("Cache insert resulted in unexpected error: %s", insertError) - } + require.True(t, cache.Insert("A", "", 1)) } func TestCacheInsertWithinBudget(t *testing.T) { cache := NewCache(1) - insertError := cache.Insert("A", "", 2) - - if insertError != nil { - t.Fatalf("Cache insert resulted in unexpected error: %s", insertError) - } + require.True(t, cache.Insert("A", "", 2)) } func TestCacheInsertUpdatesWeight(t *testing.T) { @@ -28,19 +22,13 @@ func TestCacheInsertUpdatesWeight(t *testing.T) { _ = cache.Insert("B", "", 1) _ = cache.Insert("budget_exceeded", "", 1) - if cache.GetWeight() != 2 { - t.Fatal("Cache with budget 2 did not correctly set weight after evicting one of three nodes") - } + require.Equal(t, 2, cache.GetWeight()) } func TestCacheInsertDuplicateRejected(t *testing.T) { cache := NewCache(2) - _ = cache.Insert("dupe", "", 1) - dupeError := cache.Insert("dupe", "", 1) - - if dupeError == nil { - t.Fatal("Cache insert of duplicate key did not result in any err") - } + require.True(t, cache.Insert("dupe", "", 1)) + require.False(t, cache.Insert("dupe", "", 1)) } func TestCacheInsertEvictsLeastRecentlyUsed(t *testing.T) { @@ -51,17 +39,13 @@ func TestCacheInsertEvictsLeastRecentlyUsed(t *testing.T) { _ = cache.Insert("B", "", 1) _, foundEvicted := cache.Retrieve("evicted") - if foundEvicted { - t.Fatal("Cache insert did not trigger eviction after weight exceedance") - } + require.False(t, foundEvicted) // double check that only 1 one was evicted and not any extra _, foundA := cache.Retrieve("A") + require.True(t, foundA) _, foundB := cache.Retrieve("B") - - if !foundA || !foundB { - t.Fatal("Cache insert evicted more than necessary") - } + require.True(t, foundB) } func TestCacheInsertEvictsLeastRecentlyRetrieved(t *testing.T) { @@ -69,16 +53,14 @@ func TestCacheInsertEvictsLeastRecentlyRetrieved(t *testing.T) { _ = cache.Insert("A", "", 1) _ = cache.Insert("evicted", "", 1) - // retrieve the oldest node, promoting it head so it is not evicted + // retrieve the oldest node, promoting it head, so it is not evicted cache.Retrieve("A") // insert once more, exceeding weight capacity _ = cache.Insert("B", "", 1) // now the least recently used key should be evicted _, foundEvicted := cache.Retrieve("evicted") - if foundEvicted { - t.Fatal("Cache insert did not evict least recently used after weight exceedance") - } + require.False(t, foundEvicted) } func TestClear(t *testing.T) { @@ -86,8 +68,5 @@ func TestClear(t *testing.T) { _ = cache.Insert("cleared", "", 1) cache.Clear() _, found := cache.Retrieve("cleared") - - if found { - t.Fatal("Still able to retrieve nodes after cache was cleared") - } + require.False(t, found) } diff --git a/pkg/code/antispam/airdrop.go b/pkg/code/antispam/airdrop.go index 5701e301..3ef594c0 100644 --- a/pkg/code/antispam/airdrop.go +++ b/pkg/code/antispam/airdrop.go @@ -220,7 +220,6 @@ func (g *Guard) AllowReferralBonus( recordDenialEvent(ctx, actionReferralBonus, "region restricted") return false, nil } - } // Deny from mobile networks where we're currently under attack diff --git a/pkg/code/antispam/guard_test.go b/pkg/code/antispam/guard_test.go index 5449e3f8..9b71e6fb 100644 --- a/pkg/code/antispam/guard_test.go +++ b/pkg/code/antispam/guard_test.go @@ -214,7 +214,7 @@ func TestAllowSendPayment_StaffUser(t *testing.T) { ownerAccount2 := testutil.NewRandomAccount(t) require.NoError(t, env.data.PutUser(env.ctx, &identity.Record{ - ID: user.NewUserID(), + ID: user.NewID(), View: &user.View{ PhoneNumber: &phoneNumber, }, @@ -366,7 +366,7 @@ func TestAllowReceivePayments_StaffUser(t *testing.T) { ownerAccount2 := testutil.NewRandomAccount(t) require.NoError(t, env.data.PutUser(env.ctx, &identity.Record{ - ID: user.NewUserID(), + ID: user.NewID(), View: &user.View{ PhoneNumber: &phoneNumber, }, @@ -375,7 +375,6 @@ func TestAllowReceivePayments_StaffUser(t *testing.T) { })) for _, ownerAccount := range []*common.Account{ownerAccount1, ownerAccount2} { - verification := &phone.Verification{ PhoneNumber: phoneNumber, OwnerAccount: ownerAccount.PublicKey().ToBase58(), @@ -499,7 +498,7 @@ func TestAllowOpenAccounts_StaffUser(t *testing.T) { ownerAccount2 := testutil.NewRandomAccount(t) require.NoError(t, env.data.PutUser(env.ctx, &identity.Record{ - ID: user.NewUserID(), + ID: user.NewID(), View: &user.View{ PhoneNumber: &phoneNumber, }, @@ -508,7 +507,6 @@ func TestAllowOpenAccounts_StaffUser(t *testing.T) { })) for _, ownerAccount := range []*common.Account{ownerAccount1, ownerAccount2} { - verification := &phone.Verification{ PhoneNumber: phoneNumber, OwnerAccount: ownerAccount.PublicKey().ToBase58(), @@ -599,7 +597,7 @@ func TestAllowEstablishNewRelationship_StaffUser(t *testing.T) { ownerAccount2 := testutil.NewRandomAccount(t) require.NoError(t, env.data.PutUser(env.ctx, &identity.Record{ - ID: user.NewUserID(), + ID: user.NewID(), View: &user.View{ PhoneNumber: &phoneNumber, }, @@ -692,7 +690,7 @@ func TestAllowNewPhoneVerification_StaffUser(t *testing.T) { phoneNumber := "+12223334444" require.NoError(t, env.data.PutUser(env.ctx, &identity.Record{ - ID: user.NewUserID(), + ID: user.NewID(), View: &user.View{ PhoneNumber: &phoneNumber, }, diff --git a/pkg/code/async/account/config.go b/pkg/code/async/account/config.go index e8d8dbc2..ae245f72 100644 --- a/pkg/code/async/account/config.go +++ b/pkg/code/async/account/config.go @@ -3,7 +3,7 @@ package async_account // todo: setup configs const ( - envConfigPrefix = "ACCOUNT_SERVICE_" + envConfigPrefix = "ACCOUNT_SERVICE_" //nolint:unused ) type conf struct { diff --git a/pkg/code/async/account/gift_card.go b/pkg/code/async/account/gift_card.go index 37e37ba1..b1f0ac45 100644 --- a/pkg/code/async/account/gift_card.go +++ b/pkg/code/async/account/gift_card.go @@ -41,7 +41,7 @@ func (p *service) giftCardAutoReturnWorker(serviceCtx context.Context, interval func() (err error) { time.Sleep(delay) - nr := serviceCtx.Value(metrics.NewRelicContextKey).(*newrelic.Application) + nr := serviceCtx.Value(metrics.NewRelicContextKey{}).(*newrelic.Application) m := nr.StartTransaction("async__account_service__handle_gift_card_auto_return") defer m.End() tracedCtx := newrelic.NewContext(serviceCtx, m) @@ -198,12 +198,18 @@ func (p *service) initiateProcessToAutoReturnGiftCard(ctx context.Context, giftC } // Finally, update the user by best-effort sending them a push - go push.SendGiftCardReturnedPushNotification( - ctx, - p.data, - p.pusher, - giftCardVaultAccount, - ) + go func() { + err := push.SendGiftCardReturnedPushNotification( + ctx, + p.data, + p.pusher, + giftCardVaultAccount, + ) + if err != nil { + p.log.WithError(err).Warn("failed to send gift card return push notification (best effort)") + } + }() + return nil } diff --git a/pkg/code/async/account/service.go b/pkg/code/async/account/service.go index 15c7b32a..5741527a 100644 --- a/pkg/code/async/account/service.go +++ b/pkg/code/async/account/service.go @@ -42,8 +42,6 @@ func (p *service) Start(ctx context.Context, interval time.Duration) error { } }() - select { - case <-ctx.Done(): - return ctx.Err() - } + <-ctx.Done() + return ctx.Err() } diff --git a/pkg/code/async/commitment/service.go b/pkg/code/async/commitment/service.go index ab324e10..530e8e84 100644 --- a/pkg/code/async/commitment/service.go +++ b/pkg/code/async/commitment/service.go @@ -24,7 +24,6 @@ func New(data code_data.Provider) async.Service { } func (p *service) Start(ctx context.Context, interval time.Duration) error { - // Setup workers to watch for commitment state changes on the Solana side for _, item := range []commitment.State{ commitment.StateReadyToOpen, @@ -38,7 +37,6 @@ func (p *service) Start(ctx context.Context, interval time.Duration) error { if err != nil && err != context.Canceled { p.log.WithError(err).Warnf("commitment processing loop terminated unexpectedly for state %d", state) } - }(item) } @@ -49,8 +47,6 @@ func (p *service) Start(ctx context.Context, interval time.Duration) error { } }() - select { - case <-ctx.Done(): - return ctx.Err() - } + <-ctx.Done() + return ctx.Err() } diff --git a/pkg/code/async/commitment/testutil.go b/pkg/code/async/commitment/testutil.go index 12490f92..5d17cbd6 100644 --- a/pkg/code/async/commitment/testutil.go +++ b/pkg/code/async/commitment/testutil.go @@ -3,10 +3,11 @@ package async_commitment import ( "context" "crypto/ed25519" + "crypto/rand" "encoding/hex" "fmt" "math" - "math/rand" + mrand "math/rand" "testing" "time" @@ -73,7 +74,7 @@ func setup(t *testing.T) testEnv { SolanaBlock: 123, - State: treasury.TreasuryPoolStateAvailable, + State: treasury.PoolStateAvailable, } merkleTree, err := db.InitializeNewMerkleTree( @@ -123,7 +124,7 @@ func (e testEnv) simulateCommitment(t *testing.T, recentRoot string, state commi Amount: kin.ToQuarks(1), Intent: testutil.NewRandomAccount(t).PublicKey().ToBase58(), - ActionId: rand.Uint32(), + ActionId: mrand.Uint32(), Owner: testutil.NewRandomAccount(t).PublicKey().ToBase58(), @@ -178,7 +179,7 @@ func (e testEnv) simulateCommitment(t *testing.T, recentRoot string, state commi FulfillmentType: fulfillment.TemporaryPrivacyTransferWithAuthority, Data: []byte("data"), - Signature: pointer.String(fmt.Sprintf("sig%d", rand.Uint64())), + Signature: pointer.String(fmt.Sprintf("sig%d", mrand.Uint64())), Nonce: pointer.String(testutil.NewRandomAccount(t).PublicKey().ToBase58()), Blockhash: pointer.String("bh"), @@ -192,7 +193,7 @@ func (e testEnv) simulateCommitment(t *testing.T, recentRoot string, state commi timelockRecord := timelockAccounts.ToDBRecord() timelockRecord.VaultState = timelock_token_v1.StateLocked - timelockRecord.Block += 1 + timelockRecord.Block++ require.NoError(t, e.data.SaveTimelock(e.ctx, timelockRecord)) return commitmentRecord @@ -214,7 +215,7 @@ func (e testEnv) simulateSourceAccountUnlocked(t *testing.T, commitmentRecord *c require.NoError(t, err) timelockRecord.VaultState = timelock_token_v1.StateUnlocked - timelockRecord.Block += 1 + timelockRecord.Block++ require.NoError(t, e.data.SaveTimelock(e.ctx, timelockRecord)) } @@ -235,7 +236,7 @@ func (e testEnv) simulateAddingLeaves(t *testing.T, commitmentRecords []*commitm e.treasuryPool.CurrentIndex = (e.treasuryPool.CurrentIndex + 1) % e.treasuryPool.HistoryListSize e.treasuryPool.HistoryList[e.treasuryPool.CurrentIndex] = hex.EncodeToString(rootNode.Hash) - e.treasuryPool.SolanaBlock += 1 + e.treasuryPool.SolanaBlock++ require.NoError(t, e.data.SaveTreasuryPool(e.ctx, e.treasuryPool)) } @@ -276,7 +277,7 @@ func (e testEnv) simulateCommitmentBeingUpgraded(t *testing.T, upgradeFrom, upgr permanentPrivacyFulfillment := fulfillmentRecords[0].Clone() permanentPrivacyFulfillment.Id = 0 - permanentPrivacyFulfillment.Signature = pointer.String(fmt.Sprintf("txn%d", rand.Uint64())) + permanentPrivacyFulfillment.Signature = pointer.String(fmt.Sprintf("txn%d", mrand.Uint64())) permanentPrivacyFulfillment.FulfillmentType = fulfillment.PermanentPrivacyTransferWithAuthority permanentPrivacyFulfillment.Destination = &upgradeTo.Vault require.NoError(t, e.data.PutAllFulfillments(e.ctx, &permanentPrivacyFulfillment)) @@ -492,7 +493,8 @@ func (e *testEnv) generateAvailableNonce(t *testing.T) *nonce.Record { nonceAccount := testutil.NewRandomAccount(t) var bh solana.Blockhash - rand.Read(bh[:]) + _, err := rand.Read(bh[:]) + require.NoError(t, err) nonceKey := &vault.Record{ PublicKey: nonceAccount.PublicKey().ToBase58(), diff --git a/pkg/code/async/commitment/worker.go b/pkg/code/async/commitment/worker.go index 496503d6..713d635f 100644 --- a/pkg/code/async/commitment/worker.go +++ b/pkg/code/async/commitment/worker.go @@ -4,18 +4,15 @@ import ( "context" "database/sql" "errors" + "fmt" "math" "sync" "time" "github.com/mr-tron/base58" "github.com/newrelic/go-agent/v3/newrelic" + "github.com/sirupsen/logrus" - "github.com/code-payments/code-server/pkg/database/query" - "github.com/code-payments/code-server/pkg/metrics" - "github.com/code-payments/code-server/pkg/pointer" - "github.com/code-payments/code-server/pkg/retry" - "github.com/code-payments/code-server/pkg/solana" "github.com/code-payments/code-server/pkg/code/common" "github.com/code-payments/code-server/pkg/code/data/action" "github.com/code-payments/code-server/pkg/code/data/commitment" @@ -24,6 +21,11 @@ import ( "github.com/code-payments/code-server/pkg/code/data/nonce" "github.com/code-payments/code-server/pkg/code/data/treasury" "github.com/code-payments/code-server/pkg/code/transaction" + "github.com/code-payments/code-server/pkg/database/query" + "github.com/code-payments/code-server/pkg/metrics" + "github.com/code-payments/code-server/pkg/pointer" + "github.com/code-payments/code-server/pkg/retry" + "github.com/code-payments/code-server/pkg/solana" ) // @@ -58,7 +60,7 @@ func (p *service) worker(serviceCtx context.Context, state commitment.State, int func() (err error) { time.Sleep(delay) - nr := serviceCtx.Value(metrics.NewRelicContextKey).(*newrelic.Application) + nr := serviceCtx.Value(metrics.NewRelicContextKey{}).(*newrelic.Application) m := nr.StartTransaction("async__commitment_service__handle_" + state.String()) defer m.End() tracedCtx := newrelic.NewContext(serviceCtx, m) @@ -424,8 +426,28 @@ func (p *service) injectCommitmentVaultManagementFulfillments(ctx context.Contex if err != nil { return err } + + // note: defer() will only run when the outer function returns, and therefore + // all of the defer()'s in this loop will be run all at once at the end, rather + // than at the end of each iteration. + // + // Since we are not committing (and therefore consuming) the nonce's until the + // end of the function, this is desirable. If we released at the end of each + // iteration, we could potentially acquire the same nonce multiple times for + // different transactions, which would fail. defer func() { - selectedNonce.ReleaseIfNotReserved() + if err := selectedNonce.ReleaseIfNotReserved(); err != nil { + p.log. + WithFields(logrus.Fields{ + "method": "injectCommitmentVaultManagementFulfillments", + "nonce_account": selectedNonce.Account.PublicKey().ToBase58(), + "blockhash": selectedNonce.Blockhash.ToBase58(), + }). + WithError(err). + Warn("failed to release nonce") + } + + // This is idempotent regardless of whether the above selectedNonce.Unlock() }() @@ -433,7 +455,10 @@ func (p *service) injectCommitmentVaultManagementFulfillments(ctx context.Contex if err != nil { return err } - txn.Sign(common.GetSubsidizer().PrivateKey().ToBytes()) + + if err := txn.Sign(common.GetSubsidizer().PrivateKey().ToBytes()); err != nil { + return fmt.Errorf("failed to sign transaction: %w", err) + } intentOrderingIndex := uint64(0) fulfillmentOrderingIndex := uint32(i) diff --git a/pkg/code/async/currency/exchange_rate.go b/pkg/code/async/currency/exchange_rate.go index 13f59523..7ca662e0 100644 --- a/pkg/code/async/currency/exchange_rate.go +++ b/pkg/code/async/currency/exchange_rate.go @@ -35,7 +35,7 @@ func (p *exchangeRateService) Start(serviceCtx context.Context, interval time.Du func() error { p.log.Trace("updating exchange rates") - nr := serviceCtx.Value(metrics.NewRelicContextKey).(*newrelic.Application) + nr := serviceCtx.Value(metrics.NewRelicContextKey{}).(*newrelic.Application) m := nr.StartTransaction("async__currency_service") defer m.End() tracedCtx := newrelic.NewContext(serviceCtx, m) diff --git a/pkg/code/async/geyser/backup.go b/pkg/code/async/geyser/backup.go index 96d47a7c..262b3754 100644 --- a/pkg/code/async/geyser/backup.go +++ b/pkg/code/async/geyser/backup.go @@ -41,7 +41,7 @@ func (p *service) backupTimelockStateWorker(serviceCtx context.Context, interval start := time.Now() func() { - nr := serviceCtx.Value(metrics.NewRelicContextKey).(*newrelic.Application) + nr := serviceCtx.Value(metrics.NewRelicContextKey{}).(*newrelic.Application) m := nr.StartTransaction("async__geyser_consumer_service__backup_timelock_state_worker") defer m.End() tracedCtx := newrelic.NewContext(serviceCtx, m) @@ -113,7 +113,7 @@ func (p *service) backupExternalDepositWorker(serviceCtx context.Context, interv select { case <-time.After(interval): func() { - nr := serviceCtx.Value(metrics.NewRelicContextKey).(*newrelic.Application) + nr := serviceCtx.Value(metrics.NewRelicContextKey{}).(*newrelic.Application) m := nr.StartTransaction("async__geyser_consumer_service__backup_external_deposit_worker") defer m.End() tracedCtx := newrelic.NewContext(serviceCtx, m) @@ -186,7 +186,7 @@ func (p *service) backupMessagingWorker(serviceCtx context.Context, interval tim start := time.Now() func() { - nr := serviceCtx.Value(metrics.NewRelicContextKey).(*newrelic.Application) + nr := serviceCtx.Value(metrics.NewRelicContextKey{}).(*newrelic.Application) m := nr.StartTransaction("async__geyser_consumer_service__backup_messaging_worker") defer m.End() tracedCtx := newrelic.NewContext(serviceCtx, m) diff --git a/pkg/code/async/geyser/consumer.go b/pkg/code/async/geyser/consumer.go index 7f6093d1..c5c02ab3 100644 --- a/pkg/code/async/geyser/consumer.go +++ b/pkg/code/async/geyser/consumer.go @@ -8,8 +8,8 @@ import ( "github.com/pkg/errors" "github.com/sirupsen/logrus" - "github.com/code-payments/code-server/pkg/metrics" "github.com/code-payments/code-server/pkg/code/common" + "github.com/code-payments/code-server/pkg/metrics" ) func (p *service) consumeGeyserProgramUpdateEvents(ctx context.Context) error { @@ -76,7 +76,7 @@ func (p *service) programUpdateWorker(serviceCtx context.Context, id int) { for update := range p.programUpdatesChan { func() { - nr := serviceCtx.Value(metrics.NewRelicContextKey).(*newrelic.Application) + nr := serviceCtx.Value(metrics.NewRelicContextKey{}).(*newrelic.Application) m := nr.StartTransaction("async__geyser_consumer_service__program_update_worker") defer m.End() tracedCtx := newrelic.NewContext(serviceCtx, m) @@ -124,7 +124,7 @@ func (p *service) programUpdateWorker(serviceCtx context.Context, id int) { } p.metricStatusLock.Lock() - p.programUpdateWorkerMetrics[id].eventsProcessed += 1 + p.programUpdateWorkerMetrics[id].eventsProcessed++ p.metricStatusLock.Unlock() }() } diff --git a/pkg/code/async/geyser/external_deposit.go b/pkg/code/async/geyser/external_deposit.go index 82740129..680c69f5 100644 --- a/pkg/code/async/geyser/external_deposit.go +++ b/pkg/code/async/geyser/external_deposit.go @@ -10,6 +10,7 @@ import ( "github.com/mr-tron/base58" "github.com/pkg/errors" + "github.com/sirupsen/logrus" commonpb "github.com/code-payments/code-protobuf-api/generated/go/common/v1" transactionpb "github.com/code-payments/code-protobuf-api/generated/go/transaction/v2" @@ -75,7 +76,6 @@ func findPotentialExternalDeposits(ctx context.Context, data code_data.Provider, var cursor []byte var totalTransactionsFound int for { - history, err := data.GetBlockchainHistory( ctx, vault.PublicKey().ToBase58(), @@ -200,7 +200,6 @@ func processPotentialExternalDeposit(ctx context.Context, conf *conf, data code_ // todo: Below logic is beginning to get messy and might be in need of a // refactor soon switch accountInfoRecord.AccountType { - case commonpb.AccountType_PRIMARY, commonpb.AccountType_RELATIONSHIP: // Check whether we've previously processed this external deposit _, err = data.GetExternalDeposit(ctx, signature, tokenAccount.PublicKey().ToBase58()) @@ -274,21 +273,23 @@ func processPotentialExternalDeposit(ctx context.Context, conf *conf, data code_ } canPush, err := chat_util.SendKinPurchasesMessage(ctx, data, chatMessageReceiver, chatMessage) - switch err { - case nil: - if canPush { - push.SendChatMessagePushNotification( - ctx, - data, - pusher, - chat_util.KinPurchasesName, - chatMessageReceiver, - chatMessage, - ) + if err != nil && !errors.Is(err, chat.ErrMessageAlreadyExists) { + return fmt.Errorf("failed to send chat message: %w", err) + } else if err == nil && canPush { + pushErr := push.SendChatMessagePushNotification( + ctx, + data, + pusher, + chat_util.KinPurchasesName, + chatMessageReceiver, + chatMessage, + ) + if pushErr != nil { + logrus.StandardLogger(). + WithError(pushErr). + WithField("method", "processPotentialExternalDeposit"). + Warn("failed to send chat message push notification (best effort)") } - case chat.ErrMessageAlreadyExists: - default: - return errors.Wrap(err, "error sending chat message") } } else { err = chat_util.SendCashTransactionsExchangeMessage(ctx, data, intentRecord) @@ -300,7 +301,12 @@ func processPotentialExternalDeposit(ctx context.Context, conf *conf, data code_ return errors.Wrap(err, "error updating merchant chat") } - push.SendDepositPushNotification(ctx, data, pusher, tokenAccount, uint64(deltaQuarks)) + if pushErr := push.SendDepositPushNotification(ctx, data, pusher, tokenAccount, uint64(deltaQuarks)); err != nil { + logrus.StandardLogger(). + WithError(pushErr). + WithField("method", "processPotentialExternalDeposit"). + Warn("failed to send deposit push notification (best effort)") + } } // For tracking in balances @@ -693,36 +699,47 @@ func delayedUsdcDepositProcessing( } canPush, err := chat_util.SendKinPurchasesMessage(ctx, data, ownerAccount, chatMessage) - switch err { - case nil: - if canPush { - push.SendChatMessagePushNotification( - ctx, - data, - pusher, - chat_util.KinPurchasesName, - ownerAccount, - chatMessage, - ) + if err == nil && canPush { + pushErr := push.SendChatMessagePushNotification( + ctx, + data, + pusher, + chat_util.KinPurchasesName, + ownerAccount, + chatMessage, + ) + if pushErr != nil { + logrus.StandardLogger(). + WithField("method", "delayedUsdcDepositProcessing"). + WithError(err). + Warn("failed to send chat message push notification (best effort)") } - case chat.ErrMessageAlreadyExists: - default: - return } } // Optimistically tries to cache a balance for an external account not managed // Code. It doesn't need to be perfect and will be lazily corrected on the next -// balance fetch with a newer state returned by a RPC node. +// balance fetch with a newer state returned by an RPC node. func bestEffortCacheExternalAccountBalance(ctx context.Context, data code_data.Provider, tokenAccount *common.Account, tokenBalances *solana.TransactionTokenBalances) { postBalance, err := getPostQuarkBalance(tokenAccount, tokenBalances) - if err == nil { - checkpointRecord := &balance.Record{ - TokenAccount: tokenAccount.PublicKey().ToBase58(), - Quarks: postBalance, - SlotCheckpoint: tokenBalances.Slot, - } - data.SaveBalanceCheckpoint(ctx, checkpointRecord) + if err != nil { + return + } + + checkpointRecord := &balance.Record{ + TokenAccount: tokenAccount.PublicKey().ToBase58(), + Quarks: postBalance, + SlotCheckpoint: tokenBalances.Slot, + } + + if err := data.SaveBalanceCheckpoint(ctx, checkpointRecord); err != nil { + logrus.StandardLogger(). + WithFields(logrus.Fields{ + "method": "bestEffortCacheExternalAccountBalance", + "account": tokenAccount.PublicKey().ToBase58(), + }). + WithError(err). + Warn("failed to save balance checkpoint (best effort)") } } diff --git a/pkg/code/async/geyser/handler.go b/pkg/code/async/geyser/handler.go index f44e8f68..34815fd0 100644 --- a/pkg/code/async/geyser/handler.go +++ b/pkg/code/async/geyser/handler.go @@ -102,7 +102,6 @@ func (h *TokenProgramAccountHandler) Handle(ctx context.Context, update *geyserp } switch mintAccount.PublicKey().ToBase58() { - case common.KinMintAccount.PublicKey().ToBase58(): // Not a program vault account, so filter it out. It cannot be a Timelock // account. diff --git a/pkg/code/async/geyser/handler_test.go b/pkg/code/async/geyser/handler_test.go index 935680ca..725c8bc9 100644 --- a/pkg/code/async/geyser/handler_test.go +++ b/pkg/code/async/geyser/handler_test.go @@ -16,12 +16,12 @@ func TestTimelockV1ProgramAccountHandler(t *testing.T) { } -type testEnv struct { +type testEnv struct { //nolint:unused data code_data.Provider handlers map[string]ProgramAccountUpdateHandler } -func setup(t *testing.T) *testEnv { +func setup(t *testing.T) *testEnv { //nolint:unused data := code_data.NewTestDataProvider() return &testEnv{ data: data, diff --git a/pkg/code/async/geyser/messenger.go b/pkg/code/async/geyser/messenger.go index 5c94153c..b4cd3f7b 100644 --- a/pkg/code/async/geyser/messenger.go +++ b/pkg/code/async/geyser/messenger.go @@ -7,6 +7,7 @@ import ( "github.com/mr-tron/base58" "github.com/pkg/errors" + "github.com/sirupsen/logrus" commonpb "github.com/code-payments/code-protobuf-api/generated/go/common/v1" @@ -169,7 +170,7 @@ func processPotentialBlockchainMessage(ctx context.Context, data code_data.Provi ctx, data, asciiBaseDomain, - chat.ChatTypeExternalApp, + chat.TypeExternalApp, true, recipientOwner, chatMessage, @@ -181,7 +182,7 @@ func processPotentialBlockchainMessage(ctx context.Context, data code_data.Provi if canPush { // Best-effort send a push - push.SendChatMessagePushNotification( + pushErr := push.SendChatMessagePushNotification( ctx, data, pusher, @@ -189,6 +190,10 @@ func processPotentialBlockchainMessage(ctx context.Context, data code_data.Provi recipientOwner, chatMessage, ) + logrus.StandardLogger(). + WithField("method", "processPotentialBlockchainMessage"). + WithError(pushErr). + Warn("failed to send chat message push notification (best effort)") } } diff --git a/pkg/code/async/geyser/metrics.go b/pkg/code/async/geyser/metrics.go index 65c48d5e..ed6d8dba 100644 --- a/pkg/code/async/geyser/metrics.go +++ b/pkg/code/async/geyser/metrics.go @@ -80,7 +80,7 @@ func (p *service) recordEventWorkerStatusPollingEvent(ctx context.Context) { var numActive int for _, workerMetrics := range p.programUpdateWorkerMetrics { if workerMetrics.active { - numActive += 1 + numActive++ } eventsProcessed += workerMetrics.eventsProcessed workerMetrics.eventsProcessed = 0 diff --git a/pkg/code/async/geyser/retry.go b/pkg/code/async/geyser/retry.go index 88b74923..a346bfef 100644 --- a/pkg/code/async/geyser/retry.go +++ b/pkg/code/async/geyser/retry.go @@ -2,18 +2,12 @@ package async_geyser import ( "context" - "errors" "time" "github.com/code-payments/code-server/pkg/retry" "github.com/code-payments/code-server/pkg/retry/backoff" ) -var ( - errSignatureNotConfirmed = errors.New("signature is not confirmed") - errSignatureNotFinalized = errors.New("signature is not finalized") -) - var waitForFinalizationRetryStrategies = []retry.Strategy{ retry.NonRetriableErrors(context.Canceled), retry.Limit(30), diff --git a/pkg/code/async/geyser/service.go b/pkg/code/async/geyser/service.go index b22450ba..577942d4 100644 --- a/pkg/code/async/geyser/service.go +++ b/pkg/code/async/geyser/service.go @@ -111,9 +111,7 @@ func (p *service) Start(ctx context.Context, _ time.Duration) error { }() // Wait for the service to stop - select { - case <-ctx.Done(): - } + <-ctx.Done() // Gracefully shutdown close(p.programUpdatesChan) diff --git a/pkg/code/async/nonce/allocator.go b/pkg/code/async/nonce/allocator.go index cebee7be..2a311393 100644 --- a/pkg/code/async/nonce/allocator.go +++ b/pkg/code/async/nonce/allocator.go @@ -12,13 +12,12 @@ import ( ) func (p *service) generateNonceAccounts(serviceCtx context.Context) error { - hasWarnedUser := false err := retry.Loop( func() (err error) { time.Sleep(time.Second) - nr := serviceCtx.Value(metrics.NewRelicContextKey).(*newrelic.Application) + nr := serviceCtx.Value(metrics.NewRelicContextKey{}).(*newrelic.Application) m := nr.StartTransaction("async__nonce_service__nonce_accounts") defer m.End() tracedCtx := newrelic.NewContext(serviceCtx, m) @@ -70,7 +69,6 @@ func (p *service) generateNonceAccounts(serviceCtx context.Context) error { } return nil - }, retry.NonRetriableErrors(context.Canceled, ErrInvalidNonceLimitExceeded), ) diff --git a/pkg/code/async/nonce/keys.go b/pkg/code/async/nonce/keys.go index 287b008d..dd9241f0 100644 --- a/pkg/code/async/nonce/keys.go +++ b/pkg/code/async/nonce/keys.go @@ -37,7 +37,7 @@ func (p *service) generateKeys(ctx context.Context) error { // Give the server some time to breath. time.Sleep(time.Second * 15) - nr := ctx.Value(metrics.NewRelicContextKey).(*newrelic.Application) + nr := ctx.Value(metrics.NewRelicContextKey{}).(*newrelic.Application) m := nr.StartTransaction("async__nonce_service__vault_keys") defer func() { m.End() diff --git a/pkg/code/async/nonce/pool.go b/pkg/code/async/nonce/pool.go index 11b83bd4..9042d3fb 100644 --- a/pkg/code/async/nonce/pool.go +++ b/pkg/code/async/nonce/pool.go @@ -25,7 +25,7 @@ func (p *service) worker(serviceCtx context.Context, state nonce.State, interval func() (err error) { time.Sleep(delay) - nr := serviceCtx.Value(metrics.NewRelicContextKey).(*newrelic.Application) + nr := serviceCtx.Value(metrics.NewRelicContextKey{}).(*newrelic.Application) m := nr.StartTransaction("async__nonce_service__handle_" + state.String()) defer m.End() tracedCtx := newrelic.NewContext(serviceCtx, m) @@ -74,7 +74,6 @@ func (p *service) worker(serviceCtx context.Context, state nonce.State, interval } func (p *service) handle(ctx context.Context, record *nonce.Record) error { - /* Finite state machine: States: diff --git a/pkg/code/async/nonce/service.go b/pkg/code/async/nonce/service.go index 4c62270e..875b2e8f 100644 --- a/pkg/code/async/nonce/service.go +++ b/pkg/code/async/nonce/service.go @@ -2,6 +2,7 @@ package async_nonce import ( "context" + "fmt" "os" "strconv" "time" @@ -66,11 +67,21 @@ func (p *service) Start(ctx context.Context, interval time.Duration) error { p.size = size } + errCh := make(chan error, 2+2+1) + // Generate vault keys until we have at least 10 in reserve to use for the pool - go p.generateKeys(ctx) + go func() { + if err := p.generateKeys(ctx); err != nil { + errCh <- fmt.Errorf("failed to generate keys: %w", err) + } + }() // Watch the size of the nonce pool and create accounts if necessary - go p.generateNonceAccounts(ctx) + go func() { + if err := p.generateNonceAccounts(ctx); err != nil { + errCh <- fmt.Errorf("failed to generate nonce accounts: %w", err) + } + }() // Setup workers to watch for nonce state changes on the Solana side for _, item := range []nonce.State{ @@ -78,23 +89,21 @@ func (p *service) Start(ctx context.Context, interval time.Duration) error { nonce.StateReleased, } { go func(state nonce.State) { - - err := p.worker(ctx, state, interval) - if err != nil && err != context.Canceled { - p.log.WithError(err).Warnf("nonce processing loop terminated unexpectedly for state %d", state) + if err := p.worker(ctx, state, interval); err != nil && errors.Is(err, context.Canceled) { + errCh <- err } - }(item) } go func() { - err := p.metricsGaugeWorker(ctx) - if err != nil && err != context.Canceled { - p.log.WithError(err).Warn("nonce metrics gauge loop terminated unexpectedly") + if err := p.metricsGaugeWorker(ctx); err != nil && !errors.Is(err, context.Canceled) { + errCh <- err } }() select { + case err := <-errCh: + return err case <-ctx.Done(): return ctx.Err() } diff --git a/pkg/code/async/sequencer/fulfillment_handler.go b/pkg/code/async/sequencer/fulfillment_handler.go index b0d0c25a..602a642a 100644 --- a/pkg/code/async/sequencer/fulfillment_handler.go +++ b/pkg/code/async/sequencer/fulfillment_handler.go @@ -8,6 +8,8 @@ import ( "sync" "time" + "github.com/sirupsen/logrus" + commonpb "github.com/code-payments/code-protobuf-api/generated/go/common/v1" commitment_worker "github.com/code-payments/code-server/pkg/code/async/commitment" @@ -1454,7 +1456,15 @@ func isTokenAccountOnBlockchain(ctx context.Context, data code_data.Provider, ad return existsOnBlockchain, nil } - markFulfillmentAsActivelyScheduled(ctx, data, initializeFulfillmentRecord) + if err := markFulfillmentAsActivelyScheduled(ctx, data, initializeFulfillmentRecord); err != nil { + logrus.StandardLogger(). + WithFields(logrus.Fields{ + "method": "isTokenAccountOnBlockchain", + "address": address, + }). + WithError(err). + Warn("failed to mark fulfillment as actively scheduled (best effort)") + } } return existsOnBlockchain, nil diff --git a/pkg/code/async/sequencer/fulfillment_handler_test.go b/pkg/code/async/sequencer/fulfillment_handler_test.go index 235466b1..c4a01758 100644 --- a/pkg/code/async/sequencer/fulfillment_handler_test.go +++ b/pkg/code/async/sequencer/fulfillment_handler_test.go @@ -3,10 +3,11 @@ package async_sequencer import ( "context" "crypto/ed25519" + "crypto/rand" "crypto/sha256" "encoding/hex" "fmt" - "math/rand" + mrand "math/rand" "strings" "testing" "time" @@ -235,7 +236,7 @@ func TestCloseDormantTimelockAccountFulfillmentHandler_IsRevoked(t *testing.T) { timelock_token_v1.StateUnlocked, } { timelockRecord.VaultState = state - timelockRecord.Block += 1 + timelockRecord.Block++ require.NoError(t, env.data.SaveTimelock(env.ctx, timelockRecord)) revoked, nonceUsed, err := handler.IsRevoked(env.ctx, fulfillmentRecord) @@ -247,7 +248,7 @@ func TestCloseDormantTimelockAccountFulfillmentHandler_IsRevoked(t *testing.T) { } timelockRecord.VaultState = timelock_token_v1.StateClosed - timelockRecord.Block += 1 + timelockRecord.Block++ require.NoError(t, env.data.SaveTimelock(env.ctx, timelockRecord)) revoked, nonceUsed, err := handler.IsRevoked(env.ctx, fulfillmentRecord) @@ -831,7 +832,7 @@ func TestIsTokenAccountOnBlockchain_CodeAccount(t *testing.T) { timelock_token_v1.StateClosed, } { timelockRecord.VaultState = state - timelockRecord.Block += 1 + timelockRecord.Block++ require.NoError(t, env.data.SaveTimelock(env.ctx, timelockRecord)) actual, err := isTokenAccountOnBlockchain(env.ctx, env.data, timelockRecord.VaultAddress) @@ -966,7 +967,7 @@ func (e *fulfillmentHandlerTestEnv) setupTimelockRecord(t *testing.T, owner *com func (e *fulfillmentHandlerTestEnv) setupCommitmentInState(t *testing.T, intentId string, actionId uint32, state commitment.State) *commitment.Record { hasher := sha256.New() - hasher.Write([]byte(fmt.Sprintf("recent-root%d", rand.Uint64()))) + hasher.Write([]byte(fmt.Sprintf("recent-root%d", mrand.Uint64()))) recentRoot := hex.EncodeToString(hasher.Sum(nil)) hasher = sha256.New() @@ -1034,7 +1035,7 @@ func (e *fulfillmentHandlerTestEnv) setupForPayment(t *testing.T, fulfillmentRec } if fulfillmentRecord.Signature == nil { - fulfillmentRecord.Signature = pointer.String(fmt.Sprintf("txn%d", rand.Uint64())) + fulfillmentRecord.Signature = pointer.String(fmt.Sprintf("txn%d", mrand.Uint64())) } if fulfillmentRecord.Destination == nil { @@ -1044,7 +1045,7 @@ func (e *fulfillmentHandlerTestEnv) setupForPayment(t *testing.T, fulfillmentRec fulfillmentRecord.Data = []byte("data") - quantity := rand.Uint64() + quantity := mrand.Uint64() actionRecord := &action.Record{ Source: fulfillmentRecord.Source, Destination: fulfillmentRecord.Destination, @@ -1090,7 +1091,7 @@ func (e *fulfillmentHandlerTestEnv) simulatePrivacyUpgrade(t *testing.T, fulfill cloned := fulfillmentRecord.Clone() cloned.Id = 0 cloned.FulfillmentType = fulfillment.PermanentPrivacyTransferWithAuthority - cloned.Signature = pointer.String(fmt.Sprintf("txn%d", rand.Uint64())) + cloned.Signature = pointer.String(fmt.Sprintf("txn%d", mrand.Uint64())) require.NoError(t, e.data.PutAllFulfillments(e.ctx, &cloned)) nonceRecord, err := e.data.GetNonce(e.ctx, *fulfillmentRecord.Nonce) @@ -1176,7 +1177,8 @@ func (e *fulfillmentHandlerTestEnv) generateAvailableNonce(t *testing.T) *nonce. nonceAccount := testutil.NewRandomAccount(t) var bh solana.Blockhash - rand.Read(bh[:]) + _, err := rand.Read(bh[:]) + require.NoError(t, err) nonceKey := &vault.Record{ PublicKey: nonceAccount.PublicKey().ToBase58(), diff --git a/pkg/code/async/sequencer/scheduler_test.go b/pkg/code/async/sequencer/scheduler_test.go index 6c2dd1aa..9806dd89 100644 --- a/pkg/code/async/sequencer/scheduler_test.go +++ b/pkg/code/async/sequencer/scheduler_test.go @@ -605,7 +605,7 @@ func TestContextualScheduler_NoPrivacyWithdraw(t *testing.T) { for _, fulfillmentRecord := range fulfillmentRecords { if fulfillmentRecord.FulfillmentType == fulfillment.TransferWithCommitment && *fulfillmentRecord.Destination == transfer.Source { fulfillmentRecord.State = fulfillment.StateConfirmed - env.data.UpdateFulfillment(env.ctx, fulfillmentRecord) + require.NoError(t, env.data.UpdateFulfillment(env.ctx, fulfillmentRecord)) } } } @@ -614,7 +614,7 @@ func TestContextualScheduler_NoPrivacyWithdraw(t *testing.T) { for _, fulfillmentRecord := range fulfillmentRecords { if fulfillmentRecord.FulfillmentType == fulfillment.NoPrivacyTransferWithAuthority && fulfillmentRecord.Source == transfer.Source { fulfillmentRecord.State = fulfillment.StateConfirmed - env.data.UpdateFulfillment(env.ctx, fulfillmentRecord) + require.NoError(t, env.data.UpdateFulfillment(env.ctx, fulfillmentRecord)) break } } @@ -624,7 +624,7 @@ func TestContextualScheduler_NoPrivacyWithdraw(t *testing.T) { for _, fulfillmentRecord := range fulfillmentRecords { if fulfillmentRecord.FulfillmentType == fulfillment.NoPrivacyTransferWithAuthority && fulfillmentRecord.Source == transfer.Source { fulfillmentRecord.State = fulfillment.StateConfirmed - env.data.UpdateFulfillment(env.ctx, fulfillmentRecord) + require.NoError(t, env.data.UpdateFulfillment(env.ctx, fulfillmentRecord)) } } } @@ -767,7 +767,7 @@ func TestContextualScheduler_CloseDormantTimelockAccount(t *testing.T) { for _, fulfillmentRecord := range fulfillmentRecords { if fulfillmentRecord.FulfillmentType == fulfillment.TransferWithCommitment && *fulfillmentRecord.Destination == closeDormantAccount.Source { fulfillmentRecord.State = fulfillment.StateConfirmed - env.data.UpdateFulfillment(env.ctx, fulfillmentRecord) + require.NoError(t, env.data.UpdateFulfillment(env.ctx, fulfillmentRecord)) } } } @@ -776,7 +776,7 @@ func TestContextualScheduler_CloseDormantTimelockAccount(t *testing.T) { for _, fulfillmentRecord := range fulfillmentRecords { if fulfillmentRecord.FulfillmentType == fulfillment.TemporaryPrivacyTransferWithAuthority && fulfillmentRecord.Source == closeDormantAccount.Source { fulfillmentRecord.State = fulfillment.StateConfirmed - env.data.UpdateFulfillment(env.ctx, fulfillmentRecord) + require.NoError(t, env.data.UpdateFulfillment(env.ctx, fulfillmentRecord)) break } } @@ -786,7 +786,7 @@ func TestContextualScheduler_CloseDormantTimelockAccount(t *testing.T) { for _, fulfillmentRecord := range fulfillmentRecords { if fulfillmentRecord.FulfillmentType == fulfillment.TemporaryPrivacyTransferWithAuthority && fulfillmentRecord.Source == closeDormantAccount.Source { fulfillmentRecord.State = fulfillment.StateConfirmed - env.data.UpdateFulfillment(env.ctx, fulfillmentRecord) + require.NoError(t, env.data.UpdateFulfillment(env.ctx, fulfillmentRecord)) } } } @@ -926,7 +926,7 @@ func TestContextualScheduler_PrivateTransfer_TemporaryPrivacyFlow(t *testing.T) require.NoError(t, err) timelockRecord.VaultState = timelock_token_v1.StateUnlocked - timelockRecord.Block += 1 + timelockRecord.Block++ require.NoError(t, env.data.SaveTimelock(env.ctx, timelockRecord)) } @@ -1103,7 +1103,7 @@ func TestContextualScheduler_PrivateTransfer_PermanentPrivacyFlow(t *testing.T) for _, fulfillmentRecord := range fulfillmentRecords { if fulfillmentRecord.FulfillmentType == fulfillment.TransferWithCommitment && *fulfillmentRecord.Destination == transfer.Source { fulfillmentRecord.State = fulfillment.StateConfirmed - env.data.UpdateFulfillment(env.ctx, fulfillmentRecord) + require.NoError(t, env.data.UpdateFulfillment(env.ctx, fulfillmentRecord)) } } } @@ -2180,7 +2180,7 @@ func (e *schedulerTestEnv) setupSchedulerTest(t *testing.T, intentRecords []*int HistoryListSize: 1, HistoryList: []string{"unused"}, SolanaBlock: 123, - State: treasury.TreasuryPoolStateAvailable, + State: treasury.PoolStateAvailable, } require.NoError(t, e.data.SaveTreasuryPool(e.ctx, treasuryPoolRecord)) @@ -2380,7 +2380,7 @@ func (e *schedulerTestEnv) setupSchedulerTest(t *testing.T, intentRecords []*int closeNewOutgoing, ) - currentOutgoingByUser[intentRecord.InitiatorOwnerAccount] += 1 + currentOutgoingByUser[intentRecord.InitiatorOwnerAccount]++ case intent.ReceivePaymentsPrivately: assert.True(t, intentRecord.ReceivePaymentsPrivatelyMetadata.Quantity < kin.ToQuarks(1000)) @@ -2472,7 +2472,7 @@ func (e *schedulerTestEnv) setupSchedulerTest(t *testing.T, intentRecords []*int closeNewIncoming, ) - currentIncomingByUser[intentRecord.InitiatorOwnerAccount] += 1 + currentIncomingByUser[intentRecord.InitiatorOwnerAccount]++ case intent.SendPublicPayment: newActionRecords = append( newActionRecords, @@ -3020,7 +3020,7 @@ func (e *schedulerTestEnv) assertReservedTreasuryFunds(t *testing.T, expected ui } func (e *schedulerTestEnv) getNextSlot() uint64 { - e.nextSlot += 1 + e.nextSlot++ return e.nextSlot } diff --git a/pkg/code/async/sequencer/service.go b/pkg/code/async/sequencer/service.go index 32d031eb..be5cf4f6 100644 --- a/pkg/code/async/sequencer/service.go +++ b/pkg/code/async/sequencer/service.go @@ -43,7 +43,6 @@ func New(data code_data.Provider, scheduler Scheduler, configProvider ConfigProv } func (p *service) Start(ctx context.Context, interval time.Duration) error { - // Setup workers to watch for fulfillment state changes on the Solana side for _, item := range []fulfillment.State{ fulfillment.StateUnknown, @@ -55,7 +54,6 @@ func (p *service) Start(ctx context.Context, interval time.Duration) error { // fulfillment.StateRevoked, } { go func(state fulfillment.State) { - // todo: Note to our future selves that there are some components of // the scheduler (ie. subsidizer balance checks) that won't // work perfectly in a multi-threaded or multi-node environment. @@ -63,7 +61,6 @@ func (p *service) Start(ctx context.Context, interval time.Duration) error { if err != nil && err != context.Canceled { p.log.WithError(err).Warnf("fulfillment processing loop terminated unexpectedly for state %d", state) } - }(item) } @@ -74,8 +71,6 @@ func (p *service) Start(ctx context.Context, interval time.Duration) error { } }() - select { - case <-ctx.Done(): - return ctx.Err() - } + <-ctx.Done() + return ctx.Err() } diff --git a/pkg/code/async/sequencer/testutil.go b/pkg/code/async/sequencer/testutil.go index eb2081f7..ece94ec7 100644 --- a/pkg/code/async/sequencer/testutil.go +++ b/pkg/code/async/sequencer/testutil.go @@ -4,12 +4,12 @@ import ( "context" "errors" - "github.com/code-payments/code-server/pkg/solana" - "github.com/code-payments/code-server/pkg/solana/memo" "github.com/code-payments/code-server/pkg/code/common" "github.com/code-payments/code-server/pkg/code/data/fulfillment" "github.com/code-payments/code-server/pkg/code/data/transaction" transaction_util "github.com/code-payments/code-server/pkg/code/transaction" + "github.com/code-payments/code-server/pkg/solana" + "github.com/code-payments/code-server/pkg/solana/memo" ) type mockScheduler struct { @@ -48,7 +48,10 @@ func (h *mockFulfillmentHandler) MakeOnDemandTransaction(ctx context.Context, fu } txn := solana.NewTransaction(common.GetSubsidizer().PublicKey().ToBytes(), memo.Instruction(selectedNonce.Account.PublicKey().ToBase58())) - txn.Sign(common.GetSubsidizer().PrivateKey().ToBytes()) + if err := txn.Sign(common.GetSubsidizer().PrivateKey().ToBytes()); err != nil { + return nil, err + } + return &txn, nil } diff --git a/pkg/code/async/sequencer/worker.go b/pkg/code/async/sequencer/worker.go index eda24ef0..6a6d479d 100644 --- a/pkg/code/async/sequencer/worker.go +++ b/pkg/code/async/sequencer/worker.go @@ -9,15 +9,16 @@ import ( "github.com/mr-tron/base58" "github.com/newrelic/go-agent/v3/newrelic" "github.com/pkg/errors" + "github.com/sirupsen/logrus" - "github.com/code-payments/code-server/pkg/database/query" - "github.com/code-payments/code-server/pkg/metrics" - "github.com/code-payments/code-server/pkg/pointer" - "github.com/code-payments/code-server/pkg/retry" "github.com/code-payments/code-server/pkg/code/data/fulfillment" "github.com/code-payments/code-server/pkg/code/data/nonce" "github.com/code-payments/code-server/pkg/code/data/transaction" transaction_util "github.com/code-payments/code-server/pkg/code/transaction" + "github.com/code-payments/code-server/pkg/database/query" + "github.com/code-payments/code-server/pkg/metrics" + "github.com/code-payments/code-server/pkg/pointer" + "github.com/code-payments/code-server/pkg/retry" ) func (p *service) worker(serviceCtx context.Context, state fulfillment.State, interval time.Duration) error { @@ -28,7 +29,7 @@ func (p *service) worker(serviceCtx context.Context, state fulfillment.State, in func() (err error) { time.Sleep(delay) - nr := serviceCtx.Value(metrics.NewRelicContextKey).(*newrelic.Application) + nr := serviceCtx.Value(metrics.NewRelicContextKey{}).(*newrelic.Application) m := nr.StartTransaction("async__sequencer_service__handle_" + state.String()) defer m.End() tracedCtx := newrelic.NewContext(serviceCtx, m) @@ -228,8 +229,26 @@ func (p *service) handlePending(ctx context.Context, record *fulfillment.Record) if err != nil { return err } + + // note: defer() will only run when the outer function returns, and + // therefore all of the defer()'s in this loop will be run all at once + // at the end, rather than at the end of each iteration. + // + // Since we are not committing (and therefore consuming) the nonce's until + // the end of the function, this is desirable. If we released at the end of + // each iteration, we could potentially acquire the same nonce multiple times + // for different transactions, which would fail. defer func() { - selectedNonce.ReleaseIfNotReserved() + if err := selectedNonce.ReleaseIfNotReserved(); err != nil { + p.log. + WithFields(logrus.Fields{ + "method": "handlePending", + "nonce_account": selectedNonce.Account.PublicKey().ToBase58(), + "blockhash": selectedNonce.Blockhash.ToBase58(), + }). + WithError(err). + Warn("failed to release nonce") + } selectedNonce.Unlock() }() diff --git a/pkg/code/async/sequencer/worker_test.go b/pkg/code/async/sequencer/worker_test.go index d8b23ca6..1164d18b 100644 --- a/pkg/code/async/sequencer/worker_test.go +++ b/pkg/code/async/sequencer/worker_test.go @@ -2,7 +2,7 @@ package async_sequencer import ( "context" - "math/rand" + "crypto/rand" "testing" "time" @@ -10,11 +10,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/code-payments/code-server/pkg/pointer" - "github.com/code-payments/code-server/pkg/solana" - "github.com/code-payments/code-server/pkg/solana/memo" - "github.com/code-payments/code-server/pkg/solana/system" - "github.com/code-payments/code-server/pkg/testutil" "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" "github.com/code-payments/code-server/pkg/code/data/action" @@ -24,6 +19,11 @@ import ( "github.com/code-payments/code-server/pkg/code/data/transaction" "github.com/code-payments/code-server/pkg/code/data/vault" transaction_util "github.com/code-payments/code-server/pkg/code/transaction" + "github.com/code-payments/code-server/pkg/pointer" + "github.com/code-payments/code-server/pkg/solana" + "github.com/code-payments/code-server/pkg/solana/memo" + "github.com/code-payments/code-server/pkg/solana/system" + "github.com/code-payments/code-server/pkg/testutil" ) func TestFulfillmentWorker_StateUnknown_RemainInStateUnknown(t *testing.T) { @@ -266,7 +266,7 @@ func (e *workerTestEnv) createAnyFulfillmentInState(t *testing.T, state fulfillm copy(typedBlockhash[:], untypedBlockhash) txn.SetBlockhash(typedBlockhash) - txn.Sign(fakeCodeAccouht.PrivateKey().ToBytes()) + require.NoError(t, txn.Sign(fakeCodeAccouht.PrivateKey().ToBytes())) fulfillmentRecord := &fulfillment.Record{ Intent: testutil.NewRandomAccount(t).PublicKey().ToBase58(), @@ -344,7 +344,7 @@ func (e *workerTestEnv) assertFulfillmentCreatedOnDemand(t *testing.T, id uint64 require.NotEmpty(t, fulfillmentRecord.Data) expectedTxn := solana.NewTransaction(common.GetSubsidizer().PublicKey().ToBytes(), memo.Instruction(nonceAddress)) - expectedTxn.Sign(e.subsidizer.PrivateKey().ToBytes()) + require.NoError(t, expectedTxn.Sign(e.subsidizer.PrivateKey().ToBytes())) expectedSignature := base58.Encode(expectedTxn.Signature()) assert.Equal(t, expectedSignature, *fulfillmentRecord.Signature) @@ -359,7 +359,8 @@ func (e *workerTestEnv) generateAvailableNonce(t *testing.T) *nonce.Record { nonceAccount := testutil.NewRandomAccount(t) var bh solana.Blockhash - rand.Read(bh[:]) + _, err := rand.Read(bh[:]) + require.NoError(t, err) nonceKey := &vault.Record{ PublicKey: nonceAccount.PublicKey().ToBase58(), diff --git a/pkg/code/async/treasury/merkle_tree.go b/pkg/code/async/treasury/merkle_tree.go index 448e7f06..0d8c4477 100644 --- a/pkg/code/async/treasury/merkle_tree.go +++ b/pkg/code/async/treasury/merkle_tree.go @@ -8,9 +8,6 @@ import ( "github.com/pkg/errors" "github.com/sirupsen/logrus" - "github.com/code-payments/code-server/pkg/database/query" - "github.com/code-payments/code-server/pkg/solana" - splitter_token "github.com/code-payments/code-server/pkg/solana/splitter" "github.com/code-payments/code-server/pkg/code/data/commitment" "github.com/code-payments/code-server/pkg/code/data/fulfillment" "github.com/code-payments/code-server/pkg/code/data/intent" @@ -18,6 +15,9 @@ import ( "github.com/code-payments/code-server/pkg/code/data/payment" "github.com/code-payments/code-server/pkg/code/data/transaction" "github.com/code-payments/code-server/pkg/code/data/treasury" + "github.com/code-payments/code-server/pkg/database/query" + "github.com/code-payments/code-server/pkg/solana" + splitter_token "github.com/code-payments/code-server/pkg/solana/splitter" ) func (p *service) syncMerkleTree(ctx context.Context, treasuryPoolRecord *treasury.Record) error { @@ -214,7 +214,7 @@ func (p *service) syncMerkleTree(ctx context.Context, treasuryPoolRecord *treasu endingBlockToQuery := endingBlock + 1 startingBlockToQuery := startingBlock if startingBlockToQuery > 0 { - startingBlockToQuery -= 1 + startingBlockToQuery-- } paymentRecords, err := p.data.GetPaymentHistoryWithinBlockRange( @@ -222,7 +222,7 @@ func (p *service) syncMerkleTree(ctx context.Context, treasuryPoolRecord *treasu treasuryPoolRecord.Vault, startingBlockToQuery, endingBlockToQuery, - query.WithFilter(query.Filter{Value: uint64(payment.PaymentType_Send), Valid: true}), + query.WithFilter(query.Filter{Value: uint64(payment.TypeSend), Valid: true}), query.WithLimit(1000), query.WithCursor(cursor), ) diff --git a/pkg/code/async/treasury/metrics.go b/pkg/code/async/treasury/metrics.go index 626fbdf6..0e2a6f9c 100644 --- a/pkg/code/async/treasury/metrics.go +++ b/pkg/code/async/treasury/metrics.go @@ -4,10 +4,10 @@ import ( "context" "time" - "github.com/code-payments/code-server/pkg/kin" - "github.com/code-payments/code-server/pkg/metrics" code_data "github.com/code-payments/code-server/pkg/code/data" "github.com/code-payments/code-server/pkg/code/data/treasury" + "github.com/code-payments/code-server/pkg/kin" + "github.com/code-payments/code-server/pkg/metrics" ) const ( @@ -26,7 +26,7 @@ func (p *service) metricsGaugeWorker(ctx context.Context) error { case <-time.After(delay): start := time.Now() - treasuryPoolRecords, err := p.data.GetAllTreasuryPoolsByState(ctx, treasury.TreasuryPoolStateAvailable) + treasuryPoolRecords, err := p.data.GetAllTreasuryPoolsByState(ctx, treasury.PoolStateAvailable) if err != nil { continue } diff --git a/pkg/code/async/treasury/recent_root.go b/pkg/code/async/treasury/recent_root.go index 9df41c7b..221f7a62 100644 --- a/pkg/code/async/treasury/recent_root.go +++ b/pkg/code/async/treasury/recent_root.go @@ -3,16 +3,13 @@ package async_treasury import ( "context" "database/sql" + "fmt" "math" "time" "github.com/mr-tron/base58" "github.com/sirupsen/logrus" - "github.com/code-payments/code-server/pkg/database/query" - "github.com/code-payments/code-server/pkg/pointer" - "github.com/code-payments/code-server/pkg/solana" - splitter_token "github.com/code-payments/code-server/pkg/solana/splitter" "github.com/code-payments/code-server/pkg/code/common" "github.com/code-payments/code-server/pkg/code/data/action" "github.com/code-payments/code-server/pkg/code/data/fulfillment" @@ -22,6 +19,10 @@ import ( "github.com/code-payments/code-server/pkg/code/data/payment" "github.com/code-payments/code-server/pkg/code/data/treasury" "github.com/code-payments/code-server/pkg/code/transaction" + "github.com/code-payments/code-server/pkg/database/query" + "github.com/code-payments/code-server/pkg/pointer" + "github.com/code-payments/code-server/pkg/solana" + splitter_token "github.com/code-payments/code-server/pkg/solana/splitter" ) // This method is expected to be extremely safe due to the implications of saving @@ -326,7 +327,9 @@ func makeSaveRecentRootTransaction(selectedNonce *transaction.SelectedNonce, rec return solana.Transaction{}, err } - txn.Sign(common.GetSubsidizer().PrivateKey().ToBytes()) + if err := txn.Sign(common.GetSubsidizer().PrivateKey().ToBytes()); err != nil { + return solana.Transaction{}, fmt.Errorf("failed to sign transaction: %w", err) + } return txn, nil } @@ -364,7 +367,7 @@ func (p *service) anyFinalizedTreasuryAdvancesAfterLastSaveRecentRoot(ctx contex treasuryPoolRecord.Vault, lowerBoundBlock+1, math.MaxInt64, - query.WithFilter(query.Filter{Value: uint64(payment.PaymentType_Send), Valid: true}), + query.WithFilter(query.Filter{Value: uint64(payment.TypeSend), Valid: true}), query.WithLimit(1), ) if err == payment.ErrNotFound || len(paymentRecords) == 0 { diff --git a/pkg/code/async/treasury/service.go b/pkg/code/async/treasury/service.go index abf395b7..952e7426 100644 --- a/pkg/code/async/treasury/service.go +++ b/pkg/code/async/treasury/service.go @@ -27,20 +27,18 @@ func New(data code_data.Provider, configProvider ConfigProvider) async.Service { func (p *service) Start(ctx context.Context, interval time.Duration) error { // Setup workers to watch for updates to pools - for _, item := range []treasury.TreasuryPoolState{ - treasury.TreasuryPoolStateAvailable, + for _, item := range []treasury.PoolState{ + treasury.PoolStateAvailable, // Below states have no executable logic - // treasury.TreasuryPoolStateUnknown, - // treasury.TreasuryPoolStateDeprecated, + // treasury.PoolStateUnknown, + // treasury.PoolStateDeprecated, } { - go func(state treasury.TreasuryPoolState) { - + go func(state treasury.PoolState) { err := p.worker(ctx, state, interval) if err != nil && err != context.Canceled { p.log.WithError(err).Warnf("pool processing loop terminated unexpectedly for state %d", state) } - }(item) } @@ -51,8 +49,6 @@ func (p *service) Start(ctx context.Context, interval time.Duration) error { } }() - select { - case <-ctx.Done(): - return ctx.Err() - } + <-ctx.Done() + return ctx.Err() } diff --git a/pkg/code/async/treasury/testutil.go b/pkg/code/async/treasury/testutil.go index ccd48671..8bfa5a0b 100644 --- a/pkg/code/async/treasury/testutil.go +++ b/pkg/code/async/treasury/testutil.go @@ -3,9 +3,10 @@ package async_treasury import ( "context" "crypto/ed25519" + "crypto/rand" "encoding/hex" "fmt" - "math/rand" + mrand "math/rand" "testing" "time" @@ -13,12 +14,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/code-payments/code-server/pkg/kin" - "github.com/code-payments/code-server/pkg/pointer" - "github.com/code-payments/code-server/pkg/solana" - splitter_token "github.com/code-payments/code-server/pkg/solana/splitter" - "github.com/code-payments/code-server/pkg/solana/system" - "github.com/code-payments/code-server/pkg/testutil" "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" "github.com/code-payments/code-server/pkg/code/data/action" @@ -31,6 +26,12 @@ import ( "github.com/code-payments/code-server/pkg/code/data/transaction" "github.com/code-payments/code-server/pkg/code/data/treasury" "github.com/code-payments/code-server/pkg/code/data/vault" + "github.com/code-payments/code-server/pkg/kin" + "github.com/code-payments/code-server/pkg/pointer" + "github.com/code-payments/code-server/pkg/solana" + splitter_token "github.com/code-payments/code-server/pkg/solana/splitter" + "github.com/code-payments/code-server/pkg/solana/system" + "github.com/code-payments/code-server/pkg/testutil" ) type testEnv struct { @@ -71,7 +72,7 @@ func setup(t *testing.T, testOverrides *testOverrides) *testEnv { SolanaBlock: 1, - State: treasury.TreasuryPoolStateAvailable, + State: treasury.PoolStateAvailable, } merkleTree, err := db.InitializeNewMerkleTree( @@ -126,7 +127,7 @@ func (e *testEnv) simulateMostRecentRoot(t *testing.T, intentState intent.State, FulfillmentType: fulfillment.SaveRecentRoot, Data: []byte("data"), - Signature: pointer.String(fmt.Sprintf("sig%d", rand.Uint64())), + Signature: pointer.String(fmt.Sprintf("sig%d", mrand.Uint64())), Nonce: pointer.String(testutil.NewRandomAccount(t).PublicKey().ToBase58()), Blockhash: pointer.String("bh"), @@ -190,7 +191,7 @@ func (e *testEnv) simulateCommitments(t *testing.T, count int, recentRoot string Amount: kin.ToQuarks(1), Intent: testutil.NewRandomAccount(t).PublicKey().ToBase58(), - ActionId: rand.Uint32(), + ActionId: mrand.Uint32(), Owner: testutil.NewRandomAccount(t).PublicKey().ToBase58(), @@ -208,7 +209,7 @@ func (e *testEnv) simulateCommitments(t *testing.T, count int, recentRoot string FulfillmentType: fulfillment.TransferWithCommitment, Data: []byte("data"), - Signature: pointer.String(fmt.Sprintf("sig%d", rand.Uint64())), + Signature: pointer.String(fmt.Sprintf("sig%d", mrand.Uint64())), Nonce: pointer.String(testutil.NewRandomAccount(t).PublicKey().ToBase58()), Blockhash: pointer.String("bh"), @@ -427,7 +428,8 @@ func (e *testEnv) generateAvailableNonce(t *testing.T) *nonce.Record { nonceAccount := testutil.NewRandomAccount(t) var bh solana.Blockhash - rand.Read(bh[:]) + _, err := rand.Read(bh[:]) + require.NoError(t, err) nonceKey := &vault.Record{ PublicKey: nonceAccount.PublicKey().ToBase58(), @@ -456,6 +458,6 @@ func (e *testEnv) generateAvailableNonces(t *testing.T, count int) []*nonce.Reco } func (e *testEnv) getNextBlock() uint64 { - e.nextBlock += 1 + e.nextBlock++ return e.nextBlock } diff --git a/pkg/code/async/treasury/worker.go b/pkg/code/async/treasury/worker.go index e9f1db29..c8c33370 100644 --- a/pkg/code/async/treasury/worker.go +++ b/pkg/code/async/treasury/worker.go @@ -8,11 +8,11 @@ import ( "github.com/newrelic/go-agent/v3/newrelic" "github.com/pkg/errors" + "github.com/code-payments/code-server/pkg/code/data/treasury" "github.com/code-payments/code-server/pkg/database/query" "github.com/code-payments/code-server/pkg/metrics" "github.com/code-payments/code-server/pkg/retry" splitter_token "github.com/code-payments/code-server/pkg/solana/splitter" - "github.com/code-payments/code-server/pkg/code/data/treasury" ) const ( @@ -24,7 +24,7 @@ var ( treasuryPoolLock sync.Mutex ) -func (p *service) worker(serviceCtx context.Context, state treasury.TreasuryPoolState, interval time.Duration) error { +func (p *service) worker(serviceCtx context.Context, state treasury.PoolState, interval time.Duration) error { delay := interval var cursor query.Cursor @@ -32,7 +32,7 @@ func (p *service) worker(serviceCtx context.Context, state treasury.TreasuryPool func() (err error) { time.Sleep(delay) - nr := serviceCtx.Value(metrics.NewRelicContextKey).(*newrelic.Application) + nr := serviceCtx.Value(metrics.NewRelicContextKey{}).(*newrelic.Application) m := nr.StartTransaction("async__treasury_pool_service__handle_" + state.String()) defer m.End() tracedCtx := newrelic.NewContext(serviceCtx, m) @@ -82,7 +82,7 @@ func (p *service) worker(serviceCtx context.Context, state treasury.TreasuryPool func (p *service) handle(ctx context.Context, record *treasury.Record) error { switch record.State { - case treasury.TreasuryPoolStateAvailable: + case treasury.PoolStateAvailable: return p.handleAvailable(ctx, record) default: return nil diff --git a/pkg/code/async/user/service.go b/pkg/code/async/user/service.go index 9230704c..a6e5279a 100644 --- a/pkg/code/async/user/service.go +++ b/pkg/code/async/user/service.go @@ -47,8 +47,6 @@ func (p *service) Start(ctx context.Context, interval time.Duration) error { } }() - select { - case <-ctx.Done(): - return ctx.Err() - } + <-ctx.Done() + return ctx.Err() } diff --git a/pkg/code/async/user/twitter.go b/pkg/code/async/user/twitter.go index c8612e9f..ecdd02b7 100644 --- a/pkg/code/async/user/twitter.go +++ b/pkg/code/async/user/twitter.go @@ -43,7 +43,7 @@ func (p *service) twitterRegistrationWorker(serviceCtx context.Context, interval func() (err error) { time.Sleep(delay) - nr := serviceCtx.Value(metrics.NewRelicContextKey).(*newrelic.Application) + nr := serviceCtx.Value(metrics.NewRelicContextKey{}).(*newrelic.Application) m := nr.StartTransaction("async__user_service__handle_twitter_registration") defer m.End() tracedCtx := newrelic.NewContext(serviceCtx, m) @@ -70,7 +70,7 @@ func (p *service) twitterUserInfoUpdateWorker(serviceCtx context.Context, interv func() (err error) { time.Sleep(delay) - nr := serviceCtx.Value(metrics.NewRelicContextKey).(*newrelic.Application) + nr := serviceCtx.Value(metrics.NewRelicContextKey{}).(*newrelic.Application) m := nr.StartTransaction("async__user_service__handle_twitter_user_info_update") defer m.End() tracedCtx := newrelic.NewContext(serviceCtx, m) @@ -140,7 +140,12 @@ func (p *service) processNewTwitterRegistrations(ctx context.Context) error { switch err { case nil: - go push_util.SendTwitterAccountConnectedPushNotification(ctx, p.data, p.pusher, tipAccount) + go func() { + err := push_util.SendTwitterAccountConnectedPushNotification(ctx, p.data, p.pusher, tipAccount) + if err != nil { + p.log.WithError(err).Warn("failed to send twitter account connected push notification (best effort)") + } + }() case twitter.ErrDuplicateTipAddress: err = p.data.ExecuteInTx(ctx, sql.LevelDefault, func(ctx context.Context) error { err = p.data.MarkTwitterNonceAsUsed(ctx, tweet.ID, *registrationNonce) diff --git a/pkg/code/async/webhook/service.go b/pkg/code/async/webhook/service.go index 61518849..6cdc2bd9 100644 --- a/pkg/code/async/webhook/service.go +++ b/pkg/code/async/webhook/service.go @@ -50,8 +50,6 @@ func (p *service) Start(ctx context.Context, interval time.Duration) error { } }() - select { - case <-ctx.Done(): - return ctx.Err() - } + <-ctx.Done() + return ctx.Err() } diff --git a/pkg/code/async/webhook/worker.go b/pkg/code/async/webhook/worker.go index 1de6d99b..baf9bbac 100644 --- a/pkg/code/async/webhook/worker.go +++ b/pkg/code/async/webhook/worker.go @@ -9,11 +9,11 @@ import ( "github.com/newrelic/go-agent/v3/newrelic" "github.com/sirupsen/logrus" + "github.com/code-payments/code-server/pkg/code/data/webhook" + webhook_util "github.com/code-payments/code-server/pkg/code/webhook" "github.com/code-payments/code-server/pkg/metrics" "github.com/code-payments/code-server/pkg/pointer" "github.com/code-payments/code-server/pkg/retry" - "github.com/code-payments/code-server/pkg/code/data/webhook" - webhook_util "github.com/code-payments/code-server/pkg/code/webhook" ) var ( @@ -59,7 +59,7 @@ func (p *service) worker(serviceCtx context.Context, interval time.Duration) err wg.Add(1) go func(record *webhook.Record) { - nr := serviceCtx.Value(metrics.NewRelicContextKey).(*newrelic.Application) + nr := serviceCtx.Value(metrics.NewRelicContextKey{}).(*newrelic.Application) m := nr.StartTransaction("async__webhook_service__handle_" + webhook.StatePending.String()) defer m.End() tracedCtx := newrelic.NewContext(serviceCtx, m) @@ -140,8 +140,8 @@ func (p *service) setupNextAttempt(ctx context.Context, record *webhook.Record) return false, p.updateWebhookRecord(ctx, &cloned) } - record.Attempts += 1 - cloned.Attempts += 1 + record.Attempts++ + cloned.Attempts++ cloned.NextAttemptAt = pointer.Time(time.Now().Add(delay)) return true, p.updateWebhookRecord(ctx, &cloned) @@ -155,7 +155,7 @@ func (p *service) onWebhookExecuted(ctx context.Context, record *webhook.Record, // Save success state immediately if isSuccess { p.metricsMu.Lock() - p.successfulWebhooks += 1 + p.successfulWebhooks++ p.metricsMu.Unlock() record.State = webhook.StateConfirmed @@ -164,7 +164,7 @@ func (p *service) onWebhookExecuted(ctx context.Context, record *webhook.Record, } p.metricsMu.Lock() - p.failedWebhooks += 1 + p.failedWebhooks++ p.metricsMu.Unlock() // Otherwise, save failure state only if we're on the last attempt diff --git a/pkg/code/balance/calculator.go b/pkg/code/balance/calculator.go index 001b3ded..586fb59a 100644 --- a/pkg/code/balance/calculator.go +++ b/pkg/code/balance/calculator.go @@ -171,7 +171,12 @@ func CalculateFromBlockchain(ctx context.Context, data code_data.Provider, token Quarks: quarks, SlotCheckpoint: slot, } - data.SaveBalanceCheckpoint(ctx, newCheckpointRecord) + if err := data.SaveBalanceCheckpoint(ctx, newCheckpointRecord); err != nil { + logrus.StandardLogger(). + WithField("method", "CalculateFromBlockchain"). + WithError(err). + Warn("failed to save checkpoint record (best effort)") + } } return quarks, BlockchainSource, nil diff --git a/pkg/code/balance/calculator_test.go b/pkg/code/balance/calculator_test.go index 6f9153af..73914d8c 100644 --- a/pkg/code/balance/calculator_test.go +++ b/pkg/code/balance/calculator_test.go @@ -346,7 +346,7 @@ func TestDefaultCalculationMethods_NotManagedByCode(t *testing.T) { timelockRecord, err := env.data.GetTimelockByVault(env.ctx, tokenAccount.PublicKey().ToBase58()) require.NoError(t, err) timelockRecord.VaultState = timelock_token_v1.StateWaitingForTimeout - timelockRecord.Block += 1 + timelockRecord.Block++ require.NoError(t, env.data.SaveTimelock(env.ctx, timelockRecord)) accountRecords, err := common.GetLatestTokenAccountRecordsForOwner(env.ctx, env.data, ownerAccount) @@ -494,7 +494,7 @@ func setupBalanceTestData(t *testing.T, env balanceTestEnv, data *balanceTestDat require.NoError(t, err) timelockRecord := timelockAccounts.ToDBRecord() timelockRecord.VaultState = timelock_token_v1.StateLocked - timelockRecord.Block += 1 + timelockRecord.Block++ require.NoError(t, env.data.SaveTimelock(env.ctx, timelockRecord)) if !conf.useLegacyIntents { diff --git a/pkg/code/chat/message_cash_transactions.go b/pkg/code/chat/message_cash_transactions.go index 1976d9d5..2947994f 100644 --- a/pkg/code/chat/message_cash_transactions.go +++ b/pkg/code/chat/message_cash_transactions.go @@ -152,7 +152,7 @@ func SendCashTransactionsExchangeMessage(ctx context.Context, data code_data.Pro ctx, data, CashTransactionsName, - chat.ChatTypeInternal, + chat.TypeInternal, true, receiver, protoMessage, diff --git a/pkg/code/chat/message_code_team.go b/pkg/code/chat/message_code_team.go index fe8a7049..9637159e 100644 --- a/pkg/code/chat/message_code_team.go +++ b/pkg/code/chat/message_code_team.go @@ -20,7 +20,7 @@ func SendCodeTeamMessage(ctx context.Context, data code_data.Provider, receiver ctx, data, CodeTeamName, - chat.ChatTypeInternal, + chat.TypeInternal, true, receiver, chatMessage, diff --git a/pkg/code/chat/message_kin_purchases.go b/pkg/code/chat/message_kin_purchases.go index 1ec10b68..036187e8 100644 --- a/pkg/code/chat/message_kin_purchases.go +++ b/pkg/code/chat/message_kin_purchases.go @@ -17,7 +17,7 @@ import ( // GetKinPurchasesChatId returns the chat ID for the Kin Purchases chat for a // given owner account -func GetKinPurchasesChatId(owner *common.Account) chat.ChatId { +func GetKinPurchasesChatId(owner *common.Account) chat.Id { return chat.GetChatId(KinPurchasesName, owner.PublicKey().ToBase58(), true) } @@ -27,7 +27,7 @@ func SendKinPurchasesMessage(ctx context.Context, data code_data.Provider, recei ctx, data, KinPurchasesName, - chat.ChatTypeInternal, + chat.TypeInternal, true, receiver, chatMessage, diff --git a/pkg/code/chat/message_merchant.go b/pkg/code/chat/message_merchant.go index b4504c39..4f7c79b6 100644 --- a/pkg/code/chat/message_merchant.go +++ b/pkg/code/chat/message_merchant.go @@ -36,7 +36,7 @@ func SendMerchantExchangeMessage(ctx context.Context, data code_data.Provider, i // the merchant. Representation in the UI may differ (ie. 2 and 3 are grouped), // but this is the most flexible solution with the chat model. chatTitle := PaymentsName - chatType := chat.ChatTypeInternal + chatType := chat.TypeInternal isVerifiedChat := false exchangeData, ok := getExchangeDataFromIntent(intentRecord) @@ -59,7 +59,7 @@ func SendMerchantExchangeMessage(ctx context.Context, data code_data.Provider, i if paymentRequestRecord.Domain != nil { chatTitle = *paymentRequestRecord.Domain - chatType = chat.ChatTypeExternalApp + chatType = chat.TypeExternalApp isVerifiedChat = paymentRequestRecord.IsVerified } @@ -87,7 +87,7 @@ func SendMerchantExchangeMessage(ctx context.Context, data code_data.Provider, i // and will have merchant payments appear in the verified merchant // chat. chatTitle = *destinationAccountInfoRecord.RelationshipTo - chatType = chat.ChatTypeExternalApp + chatType = chat.TypeExternalApp isVerifiedChat = true verbAndExchangeDataByMessageReceiver[intentRecord.SendPrivatePaymentMetadata.DestinationOwnerAccount] = &verbAndExchangeData{ verb: chatpb.ExchangeDataContent_RECEIVED, @@ -107,7 +107,7 @@ func SendMerchantExchangeMessage(ctx context.Context, data code_data.Provider, i // and will have merchant payments appear in the verified merchant // chat. chatTitle = *destinationAccountInfoRecord.RelationshipTo - chatType = chat.ChatTypeExternalApp + chatType = chat.TypeExternalApp isVerifiedChat = true verbAndExchangeDataByMessageReceiver[intentRecord.SendPublicPaymentMetadata.DestinationOwnerAccount] = &verbAndExchangeData{ verb: chatpb.ExchangeDataContent_RECEIVED, @@ -126,7 +126,7 @@ func SendMerchantExchangeMessage(ctx context.Context, data code_data.Provider, i // and will have merchant payments appear in the verified merchant // chat. chatTitle = *destinationAccountInfoRecord.RelationshipTo - chatType = chat.ChatTypeExternalApp + chatType = chat.TypeExternalApp isVerifiedChat = true verbAndExchangeDataByMessageReceiver[intentRecord.ExternalDepositMetadata.DestinationOwnerAccount] = &verbAndExchangeData{ verb: chatpb.ExchangeDataContent_RECEIVED, diff --git a/pkg/code/chat/message_tips.go b/pkg/code/chat/message_tips.go index b9984a9b..03d04467 100644 --- a/pkg/code/chat/message_tips.go +++ b/pkg/code/chat/message_tips.go @@ -70,7 +70,7 @@ func SendTipsExchangeMessage(ctx context.Context, data code_data.Provider, inten ctx, data, TipsName, - chat.ChatTypeInternal, + chat.TypeInternal, true, receiver, protoMessage, diff --git a/pkg/code/chat/sender.go b/pkg/code/chat/sender.go index 41da0902..a39d9a0e 100644 --- a/pkg/code/chat/sender.go +++ b/pkg/code/chat/sender.go @@ -24,7 +24,7 @@ func SendChatMessage( ctx context.Context, data code_data.Provider, chatTitle string, - chatType chat.ChatType, + chatType chat.Type, isVerifiedChat bool, receiver *common.Account, protoMessage *chatpb.ChatMessage, diff --git a/pkg/code/chat/sender_test.go b/pkg/code/chat/sender_test.go index 7625a1f3..f92fa725 100644 --- a/pkg/code/chat/sender_test.go +++ b/pkg/code/chat/sender_test.go @@ -31,9 +31,9 @@ func TestSendChatMessage_HappyPath(t *testing.T) { var expectedBadgeCount int for i := 0; i < 10; i++ { chatMessage := newRandomChatMessage(t, i+1) - expectedBadgeCount += 1 + expectedBadgeCount++ - canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat.ChatTypeInternal, true, receiver, chatMessage, false) + canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat.TypeInternal, true, receiver, chatMessage, false) require.NoError(t, err) assert.True(t, canPush) @@ -56,7 +56,7 @@ func TestSendChatMessage_VerifiedChat(t *testing.T) { for _, isVerified := range []bool{true, false} { chatMessage := newRandomChatMessage(t, 1) - _, err := SendChatMessage(env.ctx, env.data, chatTitle, chat.ChatTypeInternal, isVerified, receiver, chatMessage, true) + _, err := SendChatMessage(env.ctx, env.data, chatTitle, chat.TypeInternal, isVerified, receiver, chatMessage, true) require.NoError(t, err) env.assertChatRecordSaved(t, chatTitle, receiver, isVerified) } @@ -71,7 +71,7 @@ func TestSendChatMessage_SilentMessage(t *testing.T) { for i, isSilent := range []bool{true, false} { chatMessage := newRandomChatMessage(t, 1) - canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat.ChatTypeInternal, true, receiver, chatMessage, isSilent) + canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat.TypeInternal, true, receiver, chatMessage, isSilent) require.NoError(t, err) assert.Equal(t, !isSilent, canPush) env.assertChatMessageRecordSaved(t, chatId, chatMessage, isSilent) @@ -92,7 +92,7 @@ func TestSendChatMessage_MuteState(t *testing.T) { } chatMessage := newRandomChatMessage(t, 1) - canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat.ChatTypeInternal, true, receiver, chatMessage, false) + canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat.TypeInternal, true, receiver, chatMessage, false) require.NoError(t, err) assert.Equal(t, !isMuted, canPush) env.assertChatMessageRecordSaved(t, chatId, chatMessage, false) @@ -113,7 +113,7 @@ func TestSendChatMessage_SubscriptionState(t *testing.T) { } chatMessage := newRandomChatMessage(t, 1) - canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat.ChatTypeInternal, true, receiver, chatMessage, false) + canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat.TypeInternal, true, receiver, chatMessage, false) require.NoError(t, err) assert.Equal(t, !isUnsubscribed, canPush) if isUnsubscribed { @@ -135,7 +135,7 @@ func TestSendChatMessage_InvalidProtoMessage(t *testing.T) { chatMessage := newRandomChatMessage(t, 1) chatMessage.Content = nil - canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat.ChatTypeInternal, true, receiver, chatMessage, false) + canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat.TypeInternal, true, receiver, chatMessage, false) assert.Error(t, err) assert.False(t, canPush) env.assertChatRecordNotSaved(t, chatId) @@ -179,7 +179,7 @@ func (e *testEnv) assertChatRecordSaved(t *testing.T, chatTitle string, receiver require.NoError(t, err) assert.Equal(t, chatId[:], chatRecord.ChatId[:]) - assert.Equal(t, chat.ChatTypeInternal, chatRecord.ChatType) + assert.Equal(t, chat.TypeInternal, chatRecord.ChatType) assert.Equal(t, chatTitle, chatRecord.ChatTitle) assert.Equal(t, isVerified, chatRecord.IsVerified) assert.Equal(t, receiver.PublicKey().ToBase58(), chatRecord.CodeUser) @@ -188,7 +188,7 @@ func (e *testEnv) assertChatRecordSaved(t *testing.T, chatTitle string, receiver assert.False(t, chatRecord.IsUnsubscribed) } -func (e *testEnv) assertChatMessageRecordSaved(t *testing.T, chatId chat.ChatId, protoMessage *chatpb.ChatMessage, isSilent bool) { +func (e *testEnv) assertChatMessageRecordSaved(t *testing.T, chatId chat.Id, protoMessage *chatpb.ChatMessage, isSilent bool) { messageRecord, err := e.data.GetChatMessage(e.ctx, chatId, base58.Encode(protoMessage.GetMessageId().Value)) require.NoError(t, err) @@ -218,22 +218,20 @@ func (e *testEnv) assertBadgeCount(t *testing.T, owner *common.Account, expected assert.EqualValues(t, expected, badgeCountRecord.BadgeCount) } -func (e *testEnv) assertChatRecordNotSaved(t *testing.T, chatId chat.ChatId) { +func (e *testEnv) assertChatRecordNotSaved(t *testing.T, chatId chat.Id) { _, err := e.data.GetChatById(e.ctx, chatId) assert.Equal(t, chat.ErrChatNotFound, err) - } -func (e *testEnv) assertChatMessageRecordNotSaved(t *testing.T, chatId chat.ChatId, messageId *chatpb.ChatMessageId) { +func (e *testEnv) assertChatMessageRecordNotSaved(t *testing.T, chatId chat.Id, messageId *chatpb.ChatMessageId) { _, err := e.data.GetChatMessage(e.ctx, chatId, base58.Encode(messageId.Value)) assert.Equal(t, chat.ErrMessageNotFound, err) - } -func (e *testEnv) muteChat(t *testing.T, chatId chat.ChatId) { +func (e *testEnv) muteChat(t *testing.T, chatId chat.Id) { require.NoError(t, e.data.SetChatMuteState(e.ctx, chatId, true)) } -func (e *testEnv) unsubscribeFromChat(t *testing.T, chatId chat.ChatId) { +func (e *testEnv) unsubscribeFromChat(t *testing.T, chatId chat.Id) { require.NoError(t, e.data.SetChatSubscriptionState(e.ctx, chatId, false)) } diff --git a/pkg/code/common/account_test.go b/pkg/code/common/account_test.go index f0f1ccd5..64209775 100644 --- a/pkg/code/common/account_test.go +++ b/pkg/code/common/account_test.go @@ -184,7 +184,7 @@ func TestIsAccountManagedByCode_TimelockState_V1Program(t *testing.T) { assert.True(t, result) timelockRecord.VaultState = timelock_token_v1.StateLocked - timelockRecord.Block += 1 + timelockRecord.Block++ require.NoError(t, data.SaveTimelock(ctx, timelockRecord)) result, err = timelockAccounts.Vault.IsManagedByCode(ctx, data) @@ -193,7 +193,7 @@ func TestIsAccountManagedByCode_TimelockState_V1Program(t *testing.T) { // The timelock account is waiting for timeout timelockRecord.VaultState = timelock_token_v1.StateWaitingForTimeout - timelockRecord.Block += 1 + timelockRecord.Block++ require.NoError(t, data.SaveTimelock(ctx, timelockRecord)) result, err = timelockAccounts.Vault.IsManagedByCode(ctx, data) @@ -202,7 +202,7 @@ func TestIsAccountManagedByCode_TimelockState_V1Program(t *testing.T) { // The timelock account is unlocked timelockRecord.VaultState = timelock_token_v1.StateUnlocked - timelockRecord.Block += 1 + timelockRecord.Block++ require.NoError(t, data.SaveTimelock(ctx, timelockRecord)) result, err = timelockAccounts.Vault.IsManagedByCode(ctx, data) @@ -569,7 +569,7 @@ func TestIsAccountManagedByCode_TimelockState_Legacy2022Program(t *testing.T) { assert.True(t, result) timelockRecord.VaultState = timelock_token_v1.StateLocked - timelockRecord.Block += 1 + timelockRecord.Block++ require.NoError(t, data.SaveTimelock(ctx, timelockRecord)) result, err = timelockAccounts.Vault.IsManagedByCode(ctx, data) @@ -578,7 +578,7 @@ func TestIsAccountManagedByCode_TimelockState_Legacy2022Program(t *testing.T) { // The timelock account is waiting for timeout timelockRecord.VaultState = timelock_token_v1.StateWaitingForTimeout - timelockRecord.Block += 1 + timelockRecord.Block++ require.NoError(t, data.SaveTimelock(ctx, timelockRecord)) result, err = timelockAccounts.Vault.IsManagedByCode(ctx, data) @@ -587,7 +587,7 @@ func TestIsAccountManagedByCode_TimelockState_Legacy2022Program(t *testing.T) { // The timelock account is unlocked timelockRecord.VaultState = timelock_token_v1.StateUnlocked - timelockRecord.Block += 1 + timelockRecord.Block++ require.NoError(t, data.SaveTimelock(ctx, timelockRecord)) result, err = timelockAccounts.Vault.IsManagedByCode(ctx, data) diff --git a/pkg/code/common/owner_test.go b/pkg/code/common/owner_test.go index 3f08ba7f..49174016 100644 --- a/pkg/code/common/owner_test.go +++ b/pkg/code/common/owner_test.go @@ -98,7 +98,7 @@ func TestGetOwnerMetadata_User12Words(t *testing.T) { // Unlock a Timelock account timelockRecord.VaultState = timelock_token_v1.StateWaitingForTimeout - timelockRecord.Block += 1 + timelockRecord.Block++ require.NoError(t, data.SaveTimelock(ctx, timelockRecord)) actual, err = GetOwnerMetadata(ctx, data, owner) diff --git a/pkg/code/common/subsidizer.go b/pkg/code/common/subsidizer.go index 903bfc60..d1270d42 100644 --- a/pkg/code/common/subsidizer.go +++ b/pkg/code/common/subsidizer.go @@ -7,11 +7,11 @@ import ( "github.com/newrelic/go-agent/v3/newrelic" - "github.com/code-payments/code-server/pkg/metrics" - "github.com/code-payments/code-server/pkg/solana" code_data "github.com/code-payments/code-server/pkg/code/data" "github.com/code-payments/code-server/pkg/code/data/fulfillment" "github.com/code-payments/code-server/pkg/code/data/nonce" + "github.com/code-payments/code-server/pkg/metrics" + "github.com/code-payments/code-server/pkg/solana" ) const ( @@ -180,7 +180,7 @@ func EnforceMinimumSubsidizerBalance(ctx context.Context, data code_data.Provide return nil } - nr, ok := ctx.Value(metrics.NewRelicContextKey).(*newrelic.Application) + nr, ok := ctx.Value(metrics.NewRelicContextKey{}).(*newrelic.Application) if ok { nr.RecordCustomMetric("Subsidizer/min_balance_enforced", 1) } diff --git a/pkg/code/data/account/memory/store.go b/pkg/code/data/account/memory/store.go index f11e4bae..a7fd6c7c 100644 --- a/pkg/code/data/account/memory/store.go +++ b/pkg/code/data/account/memory/store.go @@ -404,7 +404,7 @@ func (s *store) CountRequiringAutoReturnCheck(ctx context.Context) (uint64, erro return uint64(len(items)), nil } -func cloneRecords(items []*account.Record) []*account.Record { +func cloneRecords(items []*account.Record) []*account.Record { //nolint:unused res := make([]*account.Record, len(items)) for i, item := range items { diff --git a/pkg/code/data/action/memory/store.go b/pkg/code/data/action/memory/store.go index ef3352e9..01c2dd70 100644 --- a/pkg/code/data/action/memory/store.go +++ b/pkg/code/data/action/memory/store.go @@ -7,9 +7,9 @@ import ( "sync" "time" + "github.com/code-payments/code-server/pkg/code/data/action" "github.com/code-payments/code-server/pkg/database/query" "github.com/code-payments/code-server/pkg/pointer" - "github.com/code-payments/code-server/pkg/code/data/action" ) type ById []*action.Record @@ -90,7 +90,7 @@ func (s *store) findBySource(source string) []*action.Record { return res } -func (s *store) filter(items []*action.Record, cursor query.Cursor, limit uint64, direction query.Ordering) []*action.Record { +func (s *store) filter(items []*action.Record, cursor query.Cursor, limit uint64, direction query.Ordering) []*action.Record { //nolint:unused var start uint64 start = 0 @@ -231,12 +231,12 @@ func (s *store) GetAllByIntent(ctx context.Context, intent string) ([]*action.Re return nil, action.ErrActionNotFound } - copy := make([]*action.Record, len(items)) + cpy := make([]*action.Record, len(items)) for i, item := range items { cloned := item.Clone() - copy[i] = &cloned + cpy[i] = &cloned } - return copy, nil + return cpy, nil } // GetAllByAddress implements action.store.GetAllByAddress @@ -249,12 +249,12 @@ func (s *store) GetAllByAddress(ctx context.Context, address string) ([]*action. return nil, action.ErrActionNotFound } - copy := make([]*action.Record, len(items)) + cpy := make([]*action.Record, len(items)) for i, item := range items { cloned := item.Clone() - copy[i] = &cloned + cpy[i] = &cloned } - return copy, nil + return cpy, nil } // GetNetBalance implements action.store.GetNetBalance diff --git a/pkg/code/data/balance/tests/tests.go b/pkg/code/data/balance/tests/tests.go index a354f134..78ce6fc7 100644 --- a/pkg/code/data/balance/tests/tests.go +++ b/pkg/code/data/balance/tests/tests.go @@ -65,7 +65,7 @@ func testHappyPath(t *testing.T, s balance.Store) { require.NoError(t, err) assertEquivalentRecords(t, actual, &cloned) - expected.SlotCheckpoint -= 1 + expected.SlotCheckpoint-- assert.Equal(t, balance.ErrStaleCheckpoint, s.SaveCheckpoint(ctx, expected)) actual, err = s.GetCheckpoint(ctx, "token_account") diff --git a/pkg/code/data/blockchain.go b/pkg/code/data/blockchain.go index deabb706..06133700 100644 --- a/pkg/code/data/blockchain.go +++ b/pkg/code/data/blockchain.go @@ -3,6 +3,7 @@ package data import ( "context" "crypto/ed25519" + "fmt" "github.com/mr-tron/base58" @@ -203,11 +204,13 @@ func (dp *BlockchainProvider) GetBlockchainHistory(ctx context.Context, account tracer := metrics.TraceMethodCall(ctx, blockchainProviderMetricsName, "GetBlockchainHistory") defer tracer.End() - req := query.QueryOptions{ + req := query.Options{ Limit: 1000, Supported: query.CanLimitResults | query.CanQueryByCursor, } - req.Apply(opts...) + if err := req.Apply(opts...); err != nil { + return nil, fmt.Errorf("%w: %w", query.ErrQueryNotSupported, err) + } var cursor = "" if len(req.Cursor) > 0 { diff --git a/pkg/code/data/chat/memory/store.go b/pkg/code/data/chat/memory/store.go index 1b869007..ac9e890f 100644 --- a/pkg/code/data/chat/memory/store.go +++ b/pkg/code/data/chat/memory/store.go @@ -7,8 +7,8 @@ import ( "sync" "time" - "github.com/code-payments/code-server/pkg/database/query" "github.com/code-payments/code-server/pkg/code/data/chat" + "github.com/code-payments/code-server/pkg/database/query" ) type ChatsById []*chat.Chat @@ -71,7 +71,7 @@ func (s *store) PutChat(ctx context.Context, data *chat.Chat) error { } // GetChatById implements chat.Store.GetChatById -func (s *store) GetChatById(ctx context.Context, id chat.ChatId) (*chat.Chat, error) { +func (s *store) GetChatById(ctx context.Context, id chat.Id) (*chat.Chat, error) { s.mu.Lock() defer s.mu.Unlock() @@ -121,7 +121,7 @@ func (s *store) PutMessage(ctx context.Context, data *chat.Message) error { } // DeleteMessage implements chat.Store.DeleteMessage -func (s *store) DeleteMessage(ctx context.Context, chatId chat.ChatId, messageId string) error { +func (s *store) DeleteMessage(ctx context.Context, chatId chat.Id, messageId string) error { s.mu.Lock() defer s.mu.Unlock() @@ -136,7 +136,7 @@ func (s *store) DeleteMessage(ctx context.Context, chatId chat.ChatId, messageId } // GetMessageById implements chat.Store.GetMessageById -func (s *store) GetMessageById(ctx context.Context, chatId chat.ChatId, messageId string) (*chat.Message, error) { +func (s *store) GetMessageById(ctx context.Context, chatId chat.Id, messageId string) (*chat.Message, error) { s.mu.Lock() defer s.mu.Unlock() @@ -150,7 +150,7 @@ func (s *store) GetMessageById(ctx context.Context, chatId chat.ChatId, messageI } // GetAllMessagesByChat implements chat.Store.GetAllMessagesByChat -func (s *store) GetAllMessagesByChat(ctx context.Context, chatId chat.ChatId, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*chat.Message, error) { +func (s *store) GetAllMessagesByChat(ctx context.Context, chatId chat.Id, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*chat.Message, error) { s.mu.Lock() defer s.mu.Unlock() @@ -167,7 +167,7 @@ func (s *store) GetAllMessagesByChat(ctx context.Context, chatId chat.ChatId, cu } // AdvancePointer implements chat.Store.AdvancePointer -func (s *store) AdvancePointer(_ context.Context, chatId chat.ChatId, pointer string) error { +func (s *store) AdvancePointer(_ context.Context, chatId chat.Id, pointer string) error { s.mu.Lock() defer s.mu.Unlock() @@ -182,7 +182,7 @@ func (s *store) AdvancePointer(_ context.Context, chatId chat.ChatId, pointer st } // GetUnreadCount implements chat.Store.GetUnreadCount -func (s *store) GetUnreadCount(ctx context.Context, chatId chat.ChatId) (uint32, error) { +func (s *store) GetUnreadCount(ctx context.Context, chatId chat.Id) (uint32, error) { s.mu.Lock() defer s.mu.Unlock() @@ -206,7 +206,7 @@ func (s *store) GetUnreadCount(ctx context.Context, chatId chat.ChatId) (uint32, } // SetMuteState implements chat.Store.SetMuteState -func (s *store) SetMuteState(ctx context.Context, chatId chat.ChatId, isMuted bool) error { +func (s *store) SetMuteState(ctx context.Context, chatId chat.Id, isMuted bool) error { s.mu.Lock() defer s.mu.Unlock() @@ -221,7 +221,7 @@ func (s *store) SetMuteState(ctx context.Context, chatId chat.ChatId, isMuted bo } // SetSubscriptionState implements chat.Store.SetSubscriptionState -func (s *store) SetSubscriptionState(ctx context.Context, chatId chat.ChatId, isSubscribed bool) error { +func (s *store) SetSubscriptionState(ctx context.Context, chatId chat.Id, isSubscribed bool) error { s.mu.Lock() defer s.mu.Unlock() @@ -257,7 +257,7 @@ func (s *store) findChat(data *chat.Chat) *chat.Chat { return nil } -func (s *store) findChatById(id chat.ChatId) *chat.Chat { +func (s *store) findChatById(id chat.Id) *chat.Chat { for _, item := range s.chatRecords { if bytes.Equal(id[:], item.ChatId[:]) { return item @@ -289,7 +289,7 @@ func (s *store) findMessage(data *chat.Message) *chat.Message { return nil } -func (s *store) findMessageById(chatId chat.ChatId, messageId string) *chat.Message { +func (s *store) findMessageById(chatId chat.Id, messageId string) *chat.Message { for _, item := range s.messageRecords { if bytes.Equal(item.ChatId[:], chatId[:]) && item.MessageId == messageId { return item @@ -298,7 +298,7 @@ func (s *store) findMessageById(chatId chat.ChatId, messageId string) *chat.Mess return nil } -func (s *store) findMessagesByChatId(chatId chat.ChatId) []*chat.Message { +func (s *store) findMessagesByChatId(chatId chat.Id) []*chat.Message { var res []*chat.Message for _, item := range s.messageRecords { if bytes.Equal(chatId[:], item.ChatId[:]) { @@ -413,7 +413,7 @@ func (s *store) filterPagedMessagesByChat(items []*chat.Message, cursor query.Cu return res, nil } -func (s *store) sumContentLengths(items []*chat.Message) uint32 { +func (s *store) sumContentLengths(items []*chat.Message) uint32 { //nolint:unused var res uint32 for _, item := range items { res += uint32(item.ContentLength) diff --git a/pkg/code/data/chat/model.go b/pkg/code/data/chat/model.go index d8fe7432..38b42a16 100644 --- a/pkg/code/data/chat/model.go +++ b/pkg/code/data/chat/model.go @@ -14,21 +14,21 @@ import ( "github.com/code-payments/code-server/pkg/pointer" ) -type ChatType uint8 +type Type uint8 const ( - ChatTypeUnknown ChatType = iota - ChatTypeInternal // todo: better name, or split into the various buckets (eg. Code Team vs Cash Transactions) - ChatTypeExternalApp + TypeUnknown Type = iota + TypeInternal // todo: better name, or split into the various buckets (eg. Code Team vs Cash Transactions) + TypeExternalApp ) -type ChatId [32]byte +type Id [32]byte type Chat struct { Id uint64 - ChatId ChatId - ChatType ChatType + ChatId Id + ChatType Type ChatTitle string // The message sender IsVerified bool @@ -44,7 +44,7 @@ type Chat struct { type Message struct { Id uint64 - ChatId ChatId + ChatId Id MessageId string Data []byte @@ -55,7 +55,7 @@ type Message struct { Timestamp time.Time } -func GetChatId(sender, receiver string, isVerified bool) ChatId { +func GetChatId(sender, receiver string, isVerified bool) Id { combined := []byte(fmt.Sprintf("%s:%s:%v", sender, receiver, isVerified)) if strings.Compare(sender, receiver) > 0 { combined = []byte(fmt.Sprintf("%s:%s:%v", receiver, sender, isVerified)) @@ -63,19 +63,19 @@ func GetChatId(sender, receiver string, isVerified bool) ChatId { return sha256.Sum256(combined) } -func (c ChatId) ToProto() *chatpb.ChatId { +func (c Id) ToProto() *chatpb.ChatId { return &chatpb.ChatId{ Value: c[:], } } -func ChatIdFromProto(proto *chatpb.ChatId) ChatId { - var chatId ChatId +func IdFromProto(proto *chatpb.ChatId) Id { + var chatId Id copy(chatId[:], proto.Value) return chatId } -func (c ChatId) String() string { +func (c Id) String() string { return hex.EncodeToString(c[:]) } @@ -85,7 +85,7 @@ func (r *Chat) Validate() error { return errors.New("chat id is invalid") } - if r.ChatType == ChatTypeUnknown { + if r.ChatType == TypeUnknown { return errors.New("chat type is required") } @@ -105,7 +105,7 @@ func (r *Chat) Validate() error { } func (r *Chat) Clone() Chat { - var chatIdCopy ChatId + var chatIdCopy Id copy(chatIdCopy[:], r.ChatId[:]) return Chat{ @@ -166,7 +166,7 @@ func (r *Message) Validate() error { } func (r *Message) Clone() Message { - var chatIdCopy ChatId + var chatIdCopy Id copy(chatIdCopy[:], r.ChatId[:]) dataCopy := make([]byte, len(r.Data)) diff --git a/pkg/code/data/chat/postgres/model.go b/pkg/code/data/chat/postgres/model.go index 07158095..aa7b41b3 100644 --- a/pkg/code/data/chat/postgres/model.go +++ b/pkg/code/data/chat/postgres/model.go @@ -8,10 +8,10 @@ import ( "github.com/jmoiron/sqlx" + "github.com/code-payments/code-server/pkg/code/data/chat" pgutil "github.com/code-payments/code-server/pkg/database/postgres" q "github.com/code-payments/code-server/pkg/database/query" "github.com/code-payments/code-server/pkg/pointer" - "github.com/code-payments/code-server/pkg/code/data/chat" ) const ( @@ -75,14 +75,14 @@ func toChatModel(obj *chat.Chat) (*chatModel, error) { } func fromChatModel(obj *chatModel) *chat.Chat { - var chatId chat.ChatId + var chatId chat.Id copy(chatId[:], obj.ChatId) return &chat.Chat{ Id: uint64(obj.Id.Int64), ChatId: chatId, - ChatType: chat.ChatType(obj.ChatType), + ChatType: chat.Type(obj.ChatType), IsVerified: obj.IsVerified, CodeUser: obj.Member1, @@ -115,7 +115,7 @@ func toMessageModel(obj *chat.Message) (*messageModel, error) { } func fromMessageModel(obj *messageModel) *chat.Message { - var chatId chat.ChatId + var chatId chat.Id copy(chatId[:], obj.ChatId) return &chat.Message{ @@ -184,7 +184,7 @@ func (m *messageModel) dbPut(ctx context.Context, db *sqlx.DB) error { return pgutil.CheckUniqueViolation(err, chat.ErrMessageAlreadyExists) } -func dbGetChatById(ctx context.Context, db *sqlx.DB, id chat.ChatId) (*chatModel, error) { +func dbGetChatById(ctx context.Context, db *sqlx.DB, id chat.Id) (*chatModel, error) { res := &chatModel{} query := `SELECT id, chat_id, chat_type, is_verified, member1, member2, read_pointer, is_muted, is_unsubscribed, created_at FROM ` + chatTableName + ` @@ -202,7 +202,7 @@ func dbGetChatById(ctx context.Context, db *sqlx.DB, id chat.ChatId) (*chatModel return res, nil } -func dbDeleteMessage(ctx context.Context, db *sqlx.DB, chatId chat.ChatId, messageId string) error { +func dbDeleteMessage(ctx context.Context, db *sqlx.DB, chatId chat.Id, messageId string) error { return pgutil.ExecuteInTx(ctx, db, sql.LevelDefault, func(tx *sqlx.Tx) error { query := `DELETE FROM ` + messageTableName + ` WHERE chat_id = $1 AND message_id = $2 @@ -218,7 +218,7 @@ func dbDeleteMessage(ctx context.Context, db *sqlx.DB, chatId chat.ChatId, messa }) } -func dbGetMessageById(ctx context.Context, db *sqlx.DB, chatId chat.ChatId, messageId string) (*messageModel, error) { +func dbGetMessageById(ctx context.Context, db *sqlx.DB, chatId chat.Id, messageId string) (*messageModel, error) { res := &messageModel{} query := `SELECT id, chat_id, message_id, data, is_silent, content_length, timestamp FROM ` + messageTableName + ` @@ -260,7 +260,7 @@ func dbGetAllChatsForUser(ctx context.Context, db *sqlx.DB, user string, cursor return res, nil } -func dbGetAllMessagesByChat(ctx context.Context, db *sqlx.DB, chatId chat.ChatId, cursor q.Cursor, direction q.Ordering, limit uint64) ([]*messageModel, error) { +func dbGetAllMessagesByChat(ctx context.Context, db *sqlx.DB, chatId chat.Id, cursor q.Cursor, direction q.Ordering, limit uint64) ([]*messageModel, error) { res := []*messageModel{} query := `SELECT id, chat_id, message_id, data, is_silent, content_length, timestamp FROM ` + messageTableName + ` @@ -311,7 +311,7 @@ func dbGetAllMessagesByChat(ctx context.Context, db *sqlx.DB, chatId chat.ChatId return res, nil } -func dbAdvancePointer(ctx context.Context, db *sqlx.DB, chatId chat.ChatId, pointer string) error { +func dbAdvancePointer(ctx context.Context, db *sqlx.DB, chatId chat.Id, pointer string) error { query := `UPDATE ` + chatTableName + ` SET read_pointer = $2 WHERE chat_id = $1 @@ -337,7 +337,7 @@ func dbAdvancePointer(ctx context.Context, db *sqlx.DB, chatId chat.ChatId, poin return nil } -func dbGetUnreadCount(ctx context.Context, db *sqlx.DB, chatId chat.ChatId) (uint32, error) { +func dbGetUnreadCount(ctx context.Context, db *sqlx.DB, chatId chat.Id) (uint32, error) { res := &struct { UnreadCount sql.NullInt64 `db:"unread_count"` }{} @@ -360,7 +360,7 @@ func dbGetUnreadCount(ctx context.Context, db *sqlx.DB, chatId chat.ChatId) (uin return uint32(res.UnreadCount.Int64), nil } -func dbSetMuteState(ctx context.Context, db *sqlx.DB, chatId chat.ChatId, isMuted bool) error { +func dbSetMuteState(ctx context.Context, db *sqlx.DB, chatId chat.Id, isMuted bool) error { query := `UPDATE ` + chatTableName + ` SET is_muted = $2 WHERE chat_id = $1 @@ -386,7 +386,7 @@ func dbSetMuteState(ctx context.Context, db *sqlx.DB, chatId chat.ChatId, isMute return nil } -func dbSetSubscriptionState(ctx context.Context, db *sqlx.DB, chatId chat.ChatId, isSubscribed bool) error { +func dbSetSubscriptionState(ctx context.Context, db *sqlx.DB, chatId chat.Id, isSubscribed bool) error { query := `UPDATE ` + chatTableName + ` SET is_unsubscribed = $2 WHERE chat_id = $1 diff --git a/pkg/code/data/chat/postgres/store.go b/pkg/code/data/chat/postgres/store.go index 943a1935..c8ce2b4c 100644 --- a/pkg/code/data/chat/postgres/store.go +++ b/pkg/code/data/chat/postgres/store.go @@ -6,8 +6,8 @@ import ( "github.com/jmoiron/sqlx" - "github.com/code-payments/code-server/pkg/database/query" "github.com/code-payments/code-server/pkg/code/data/chat" + "github.com/code-payments/code-server/pkg/database/query" ) type store struct { @@ -39,7 +39,7 @@ func (s *store) PutChat(ctx context.Context, record *chat.Chat) error { } // GetChatById implements chat.Store.GetChatById -func (s *store) GetChatById(ctx context.Context, id chat.ChatId) (*chat.Chat, error) { +func (s *store) GetChatById(ctx context.Context, id chat.Id) (*chat.Chat, error) { model, err := dbGetChatById(ctx, s.db, id) if err != nil { return nil, err @@ -80,12 +80,12 @@ func (s *store) PutMessage(ctx context.Context, record *chat.Message) error { } // DeleteMessage implements chat.Store.DeleteMessage -func (s *store) DeleteMessage(ctx context.Context, chatId chat.ChatId, messageId string) error { +func (s *store) DeleteMessage(ctx context.Context, chatId chat.Id, messageId string) error { return dbDeleteMessage(ctx, s.db, chatId, messageId) } // GetMessageById implements chat.Store.GetMessageById -func (s *store) GetMessageById(ctx context.Context, chatId chat.ChatId, messageId string) (*chat.Message, error) { +func (s *store) GetMessageById(ctx context.Context, chatId chat.Id, messageId string) (*chat.Message, error) { model, err := dbGetMessageById(ctx, s.db, chatId, messageId) if err != nil { return nil, err @@ -95,7 +95,7 @@ func (s *store) GetMessageById(ctx context.Context, chatId chat.ChatId, messageI } // GetAllMessagesByChat implements chat.Store.GetAllMessagesByChat -func (s *store) GetAllMessagesByChat(ctx context.Context, chatId chat.ChatId, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*chat.Message, error) { +func (s *store) GetAllMessagesByChat(ctx context.Context, chatId chat.Id, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*chat.Message, error) { models, err := dbGetAllMessagesByChat(ctx, s.db, chatId, cursor, direction, limit) if err != nil { return nil, err @@ -109,21 +109,21 @@ func (s *store) GetAllMessagesByChat(ctx context.Context, chatId chat.ChatId, cu } // AdvancePointer implements chat.Store.AdvancePointer -func (s *store) AdvancePointer(ctx context.Context, chatId chat.ChatId, pointer string) error { +func (s *store) AdvancePointer(ctx context.Context, chatId chat.Id, pointer string) error { return dbAdvancePointer(ctx, s.db, chatId, pointer) } // GetUnreadCount implements chat.Store.GetUnreadCount -func (s *store) GetUnreadCount(ctx context.Context, chatId chat.ChatId) (uint32, error) { +func (s *store) GetUnreadCount(ctx context.Context, chatId chat.Id) (uint32, error) { return dbGetUnreadCount(ctx, s.db, chatId) } // SetMuteState implements chat.Store.SetMuteState -func (s *store) SetMuteState(ctx context.Context, chatId chat.ChatId, isMuted bool) error { +func (s *store) SetMuteState(ctx context.Context, chatId chat.Id, isMuted bool) error { return dbSetMuteState(ctx, s.db, chatId, isMuted) } // SetSubscriptionState implements chat.Store.SetSubscriptionState -func (s *store) SetSubscriptionState(ctx context.Context, chatId chat.ChatId, isSubscribed bool) error { +func (s *store) SetSubscriptionState(ctx context.Context, chatId chat.Id, isSubscribed bool) error { return dbSetSubscriptionState(ctx, s.db, chatId, isSubscribed) } diff --git a/pkg/code/data/chat/store.go b/pkg/code/data/chat/store.go index 2e79a228..42b9de18 100644 --- a/pkg/code/data/chat/store.go +++ b/pkg/code/data/chat/store.go @@ -20,7 +20,7 @@ type Store interface { PutChat(ctx context.Context, record *Chat) error // GetChatById gets a chat by its chat ID - GetChatById(ctx context.Context, chatId ChatId) (*Chat, error) + GetChatById(ctx context.Context, chatId Id) (*Chat, error) // GetAllChatsForUser gets all chats for a given user // @@ -32,25 +32,25 @@ type Store interface { // Delete message deletes a message within a chat. The call is idempotent // and will not fail if the message doesn't exist. - DeleteMessage(ctx context.Context, chatId ChatId, messageId string) error + DeleteMessage(ctx context.Context, chatId Id, messageId string) error // GetMessageById gets a chat message by its message ID within a chat - GetMessageById(ctx context.Context, chatId ChatId, messageId string) (*Message, error) + GetMessageById(ctx context.Context, chatId Id, messageId string) (*Message, error) // GetAllMessagesByChat gets all messages for a given chat // // Note: Cursor is a message ID - GetAllMessagesByChat(ctx context.Context, chatId ChatId, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*Message, error) + GetAllMessagesByChat(ctx context.Context, chatId Id, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*Message, error) // AdvancePointer advances a chat pointer - AdvancePointer(ctx context.Context, chatId ChatId, pointer string) error + AdvancePointer(ctx context.Context, chatId Id, pointer string) error // GetUnreadCount gets the unread message count for a chat ID - GetUnreadCount(ctx context.Context, chatId ChatId) (uint32, error) + GetUnreadCount(ctx context.Context, chatId Id) (uint32, error) // SetMuteState updates the mute state for a chat - SetMuteState(ctx context.Context, chatId ChatId, isMuted bool) error + SetMuteState(ctx context.Context, chatId Id, isMuted bool) error // SetSubscriptionState updates the subscription state for a chat - SetSubscriptionState(ctx context.Context, chatId ChatId, isSubscribed bool) error + SetSubscriptionState(ctx context.Context, chatId Id, isSubscribed bool) error } diff --git a/pkg/code/data/chat/tests/tests.go b/pkg/code/data/chat/tests/tests.go index f9eaaadf..90f152ea 100644 --- a/pkg/code/data/chat/tests/tests.go +++ b/pkg/code/data/chat/tests/tests.go @@ -10,9 +10,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/code-payments/code-server/pkg/code/data/chat" "github.com/code-payments/code-server/pkg/database/query" "github.com/code-payments/code-server/pkg/pointer" - "github.com/code-payments/code-server/pkg/code/data/chat" ) func RunTests(t *testing.T, s chat.Store, teardown func()) { @@ -45,7 +45,7 @@ func testChatRoundTrip(t *testing.T, s chat.Store) { expected := &chat.Chat{ ChatId: chatId, - ChatType: chat.ChatTypeExternalApp, + ChatType: chat.TypeExternalApp, ChatTitle: "domain", IsVerified: true, @@ -154,7 +154,7 @@ func testAdvancePointer(t *testing.T, s chat.Store) { record := &chat.Chat{ ChatId: chatId, - ChatType: chat.ChatTypeExternalApp, + ChatType: chat.TypeExternalApp, ChatTitle: "domain", IsVerified: true, @@ -185,7 +185,7 @@ func testGetUnreadCount(t *testing.T, s chat.Store) { chatRecord := &chat.Chat{ ChatId: chatId, - ChatType: chat.ChatTypeExternalApp, + ChatType: chat.TypeExternalApp, ChatTitle: "domain", IsVerified: true, @@ -222,7 +222,7 @@ func testGetUnreadCount(t *testing.T, s chat.Store) { var deltaRead int for i := 1; i <= 3; i++ { - deltaRead += 1 + deltaRead++ for j := 0; j < 2; j++ { require.NoError(t, s.AdvancePointer(ctx, chatId, fmt.Sprintf("message%d%d", i, j))) @@ -243,7 +243,7 @@ func testMuteState(t *testing.T, s chat.Store) { require.NoError(t, s.PutChat(ctx, &chat.Chat{ ChatId: chatId, - ChatType: chat.ChatTypeExternalApp, + ChatType: chat.TypeExternalApp, ChatTitle: "domain", IsVerified: true, @@ -272,7 +272,7 @@ func tesSubscriptionState(t *testing.T, s chat.Store) { require.NoError(t, s.PutChat(ctx, &chat.Chat{ ChatId: chatId, - ChatType: chat.ChatTypeExternalApp, + ChatType: chat.TypeExternalApp, ChatTitle: "domain", IsVerified: true, @@ -305,7 +305,7 @@ func testGetAllChatsByUserPaging(t *testing.T, s chat.Store) { record := &chat.Chat{ ChatId: chat.GetChatId(merchant, user, true), - ChatType: chat.ChatTypeExternalApp, + ChatType: chat.TypeExternalApp, ChatTitle: merchant, IsVerified: true, CodeUser: user, diff --git a/pkg/code/data/currency/memory/store.go b/pkg/code/data/currency/memory/store.go index 750dbc29..4207ad75 100644 --- a/pkg/code/data/currency/memory/store.go +++ b/pkg/code/data/currency/memory/store.go @@ -6,8 +6,8 @@ import ( "sync" "time" - "github.com/code-payments/code-server/pkg/database/query" "github.com/code-payments/code-server/pkg/code/data/currency" + "github.com/code-payments/code-server/pkg/database/query" ) const ( @@ -36,7 +36,7 @@ func New() currency.Store { } } -func (s *store) reset() { +func (s *store) reset() { //nolint:unused s.currencyStoreMu.Lock() s.currencyStore = make([]*currency.ExchangeRateRecord, 0) s.currencyStoreMu.Unlock() diff --git a/pkg/code/data/currency/postgres/model.go b/pkg/code/data/currency/postgres/model.go index 70e4c2c9..c44abf0e 100644 --- a/pkg/code/data/currency/postgres/model.go +++ b/pkg/code/data/currency/postgres/model.go @@ -7,8 +7,8 @@ import ( "github.com/jmoiron/sqlx" - q "github.com/code-payments/code-server/pkg/database/query" "github.com/code-payments/code-server/pkg/code/data/currency" + q "github.com/code-payments/code-server/pkg/database/query" pgutil "github.com/code-payments/code-server/pkg/database/postgres" ) @@ -81,14 +81,14 @@ func makeRangeQuery(condition string, ordering q.Ordering, interval q.Interval) return query } -func (self *model) txSave(ctx context.Context, tx *sqlx.Tx) error { +func (m *model) txSave(ctx context.Context, tx *sqlx.Tx) error { err := tx.QueryRowxContext(ctx, makeInsertQuery(), - self.ForDate, - self.ForTimestamp, - self.CurrencyCode, - self.CurrencyRate, - ).StructScan(self) + m.ForDate, + m.ForTimestamp, + m.CurrencyCode, + m.CurrencyRate, + ).StructScan(m) return pgutil.CheckUniqueViolation(err, currency.ErrExists) } diff --git a/pkg/code/data/currency/tests/tests.go b/pkg/code/data/currency/tests/tests.go index b7042130..90e03d15 100644 --- a/pkg/code/data/currency/tests/tests.go +++ b/pkg/code/data/currency/tests/tests.go @@ -8,8 +8,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/code-payments/code-server/pkg/database/query" "github.com/code-payments/code-server/pkg/code/data/currency" + "github.com/code-payments/code-server/pkg/database/query" ) func RunTests(t *testing.T, s currency.Store, teardown func()) { @@ -103,12 +103,12 @@ func testGetRange(t *testing.T, s currency.Store) { assert.EqualValues(t, rates[i].Rates["usd"], item.Rate) } - result, err = s.GetRange(context.Background(), "usd", query.IntervalHour, rates[0].Time, rates[99].Time, query.Ascending) + _, err = s.GetRange(context.Background(), "usd", query.IntervalHour, rates[0].Time, rates[99].Time, query.Ascending) require.NoError(t, err) - result, err = s.GetRange(context.Background(), "usd", query.IntervalDay, rates[0].Time, rates[99].Time, query.Ascending) + _, err = s.GetRange(context.Background(), "usd", query.IntervalDay, rates[0].Time, rates[99].Time, query.Ascending) require.NoError(t, err) - result, err = s.GetRange(context.Background(), "usd", query.IntervalWeek, rates[0].Time, rates[99].Time, query.Ascending) + _, err = s.GetRange(context.Background(), "usd", query.IntervalWeek, rates[0].Time, rates[99].Time, query.Ascending) require.NoError(t, err) - result, err = s.GetRange(context.Background(), "usd", query.IntervalMonth, rates[0].Time, rates[99].Time, query.Ascending) + _, err = s.GetRange(context.Background(), "usd", query.IntervalMonth, rates[0].Time, rates[99].Time, query.Ascending) require.NoError(t, err) } diff --git a/pkg/code/data/deposit/tests/tests.go b/pkg/code/data/deposit/tests/tests.go index 256afad4..eefadf24 100644 --- a/pkg/code/data/deposit/tests/tests.go +++ b/pkg/code/data/deposit/tests/tests.go @@ -57,7 +57,6 @@ func testRoundTrip(t *testing.T, s deposit.Store) { actual, err = s.Get(ctx, cloned.Signature, cloned.Destination) require.NoError(t, err) assertEquivalentRecords(t, &cloned, actual) - }) } diff --git a/pkg/code/data/estimated.go b/pkg/code/data/estimated.go index b40c319a..e0116cfd 100644 --- a/pkg/code/data/estimated.go +++ b/pkg/code/data/estimated.go @@ -49,10 +49,10 @@ func (p *EstimatedProvider) TestForKnownAccount(ctx context.Context, account []b tracer := metrics.TraceMethodCall(ctx, estimatedProviderMetricsName, "TestForKnownAccount") defer tracer.End() - if len(account) > 0 { - } else { + if len(account) == 0 { return false, ErrInvalidAccount } + return p.knownAccounts.Test(account), nil } diff --git a/pkg/code/data/event/postgres/model.go b/pkg/code/data/event/postgres/model.go index 080b61fd..6208a7e1 100644 --- a/pkg/code/data/event/postgres/model.go +++ b/pkg/code/data/event/postgres/model.go @@ -7,9 +7,9 @@ import ( "github.com/jmoiron/sqlx" + "github.com/code-payments/code-server/pkg/code/data/event" pgutil "github.com/code-payments/code-server/pkg/database/postgres" "github.com/code-payments/code-server/pkg/pointer" - "github.com/code-payments/code-server/pkg/code/data/event" ) const ( @@ -109,7 +109,7 @@ func fromModel(obj *model) *event.Record { Id: uint64(obj.Id.Int64), EventId: obj.EventId, - EventType: event.EventType(obj.EventType), + EventType: event.Type(obj.EventType), SourceCodeAccount: obj.SourceCodeAccount, DestinationCodeAccount: pointer.StringIfValid(obj.DestinationCodeAccount.Valid, obj.DestinationCodeAccount.String), diff --git a/pkg/code/data/event/record.go b/pkg/code/data/event/record.go index 26fef093..a58bc668 100644 --- a/pkg/code/data/event/record.go +++ b/pkg/code/data/event/record.go @@ -7,10 +7,10 @@ import ( "github.com/code-payments/code-server/pkg/pointer" ) -type EventType uint32 +type Type uint32 const ( - UnknownEvent EventType = iota + UnknownEvent Type = iota AccountCreated WelcomeBonusClaimed InPersonGrab @@ -26,7 +26,7 @@ type Record struct { // multiple intents may want to standardize on which intent ID is used to add // additional metadata as it becomes available. EventId string - EventType EventType + EventType Type // Involved accounts SourceCodeAccount string diff --git a/pkg/code/data/fulfillment/memory/store.go b/pkg/code/data/fulfillment/memory/store.go index 4618d465..8465c58e 100644 --- a/pkg/code/data/fulfillment/memory/store.go +++ b/pkg/code/data/fulfillment/memory/store.go @@ -7,9 +7,9 @@ import ( "sync" "time" + "github.com/code-payments/code-server/pkg/code/data/fulfillment" "github.com/code-payments/code-server/pkg/database/query" "github.com/code-payments/code-server/pkg/pointer" - "github.com/code-payments/code-server/pkg/code/data/fulfillment" ) type store struct { @@ -120,7 +120,7 @@ func (s *store) findByStateAndAddress(state fulfillment.State, address string) [ return res } -func (s *store) findByStateAndAddressAsSource(state fulfillment.State, address string) []*fulfillment.Record { +func (s *store) findByStateAndAddressAsSource(state fulfillment.State, address string) []*fulfillment.Record { //nolint:unused res := make([]*fulfillment.Record, 0) for _, item := range s.records { if item.State != state { @@ -358,7 +358,7 @@ func (s *store) CountByStateGroupedByType(ctx context.Context, state fulfillment res := make(map[fulfillment.Type]uint64) for _, item := range items { - res[item.FulfillmentType] += 1 + res[item.FulfillmentType]++ } return res, nil } @@ -423,7 +423,7 @@ func (s *store) CountPendingByType(ctx context.Context) (map[fulfillment.Type]ui res := make(map[fulfillment.Type]uint64) for _, item := range items { - res[item.FulfillmentType] += 1 + res[item.FulfillmentType]++ } return res, nil } @@ -551,7 +551,7 @@ func (s *store) ActivelyScheduleTreasuryAdvances(ctx context.Context, treasury s data.DisableActiveScheduling = false - updateCount += 1 + updateCount++ if updateCount >= uint64(limit) { return updateCount, nil } diff --git a/pkg/code/data/fulfillment/tests/tests.go b/pkg/code/data/fulfillment/tests/tests.go index 1c33db06..f89c9974 100644 --- a/pkg/code/data/fulfillment/tests/tests.go +++ b/pkg/code/data/fulfillment/tests/tests.go @@ -10,11 +10,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/code-payments/code-server/pkg/database/query" - "github.com/code-payments/code-server/pkg/pointer" "github.com/code-payments/code-server/pkg/code/data/action" "github.com/code-payments/code-server/pkg/code/data/fulfillment" "github.com/code-payments/code-server/pkg/code/data/intent" + "github.com/code-payments/code-server/pkg/database/query" + "github.com/code-payments/code-server/pkg/pointer" ) func RunTests(t *testing.T, s fulfillment.Store, teardown func()) { @@ -1010,7 +1010,7 @@ func testTreasuryQueries(t *testing.T, s fulfillment.Store) { require.NoError(t, err) if !actual.DisableActiveScheduling { - totalUpdated += 1 + totalUpdated++ } } assert.Equal(t, 2, totalUpdated) diff --git a/pkg/code/data/intent/memory/store.go b/pkg/code/data/intent/memory/store.go index 94293dbd..a9c092b8 100644 --- a/pkg/code/data/intent/memory/store.go +++ b/pkg/code/data/intent/memory/store.go @@ -6,8 +6,8 @@ import ( "sync" "time" - "github.com/code-payments/code-server/pkg/database/query" "github.com/code-payments/code-server/pkg/code/data/intent" + "github.com/code-payments/code-server/pkg/database/query" ) type store struct { @@ -63,7 +63,7 @@ func (s *store) findIntent(intentID string) *intent.Record { return nil } -func (s *store) findByState(state intent.State) []*intent.Record { +func (s *store) findByState(state intent.State) []*intent.Record { //nolint:unused res := make([]*intent.Record, 0) for _, item := range s.records { if item.State == state { diff --git a/pkg/code/data/intent/tests/tests.go b/pkg/code/data/intent/tests/tests.go index c17af4cb..7b55c8e6 100644 --- a/pkg/code/data/intent/tests/tests.go +++ b/pkg/code/data/intent/tests/tests.go @@ -652,7 +652,6 @@ func testGetLatestByInitiatorAndType(t *testing.T, s intent.Store) { require.NoError(t, err) assert.Equal(t, "t4", actual.IntentId) }) - } func testGetCountForAntispam(t *testing.T, s intent.Store) { diff --git a/pkg/code/data/internal.go b/pkg/code/data/internal.go index b106a093..7ee223ad 100644 --- a/pkg/code/data/internal.go +++ b/pkg/code/data/internal.go @@ -69,7 +69,7 @@ import ( intent_memory_client "github.com/code-payments/code-server/pkg/code/data/intent/memory" login_memory_client "github.com/code-payments/code-server/pkg/code/data/login/memory" merkletree_memory_client "github.com/code-payments/code-server/pkg/code/data/merkletree/memory" - messaging "github.com/code-payments/code-server/pkg/code/data/messaging" + "github.com/code-payments/code-server/pkg/code/data/messaging" messaging_memory_client "github.com/code-payments/code-server/pkg/code/data/messaging/memory" nonce_memory_client "github.com/code-payments/code-server/pkg/code/data/nonce/memory" onramp_memory_client "github.com/code-payments/code-server/pkg/code/data/onramp/memory" @@ -286,7 +286,7 @@ type DatabaseData interface { // User Identity // -------------------------------------------------------------------------------- PutUser(ctx context.Context, record *identity.Record) error - GetUserByID(ctx context.Context, id *user.UserID) (*identity.Record, error) + GetUserByID(ctx context.Context, id *user.Id) (*identity.Record, error) GetUserByPhoneView(ctx context.Context, phoneNumber string) (*identity.Record, error) // User Storage Management @@ -330,7 +330,7 @@ type DatabaseData interface { GetTreasuryPoolByName(ctx context.Context, name string) (*treasury.Record, error) GetTreasuryPoolByAddress(ctx context.Context, address string) (*treasury.Record, error) GetTreasuryPoolByVault(ctx context.Context, vault string) (*treasury.Record, error) - GetAllTreasuryPoolsByState(ctx context.Context, state treasury.TreasuryPoolState, opts ...query.Option) ([]*treasury.Record, error) + GetAllTreasuryPoolsByState(ctx context.Context, state treasury.PoolState, opts ...query.Option) ([]*treasury.Record, error) SaveTreasuryPoolFunding(ctx context.Context, record *treasury.FundingHistoryRecord) error GetTotalAvailableTreasuryPoolFunds(ctx context.Context, vault string) (uint64, error) @@ -379,16 +379,16 @@ type DatabaseData interface { // Chat // -------------------------------------------------------------------------------- PutChat(ctx context.Context, record *chat.Chat) error - GetChatById(ctx context.Context, chatId chat.ChatId) (*chat.Chat, error) + GetChatById(ctx context.Context, chatId chat.Id) (*chat.Chat, error) GetAllChatsForUser(ctx context.Context, user string, opts ...query.Option) ([]*chat.Chat, error) PutChatMessage(ctx context.Context, record *chat.Message) error - DeleteChatMessage(ctx context.Context, chatId chat.ChatId, messageId string) error - GetChatMessage(ctx context.Context, chatId chat.ChatId, messageId string) (*chat.Message, error) - GetAllChatMessages(ctx context.Context, chatId chat.ChatId, opts ...query.Option) ([]*chat.Message, error) - AdvanceChatPointer(ctx context.Context, chatId chat.ChatId, pointer string) error - GetChatUnreadCount(ctx context.Context, chatId chat.ChatId) (uint32, error) - SetChatMuteState(ctx context.Context, chatId chat.ChatId, isMuted bool) error - SetChatSubscriptionState(ctx context.Context, chatId chat.ChatId, isSubscribed bool) error + DeleteChatMessage(ctx context.Context, chatId chat.Id, messageId string) error + GetChatMessage(ctx context.Context, chatId chat.Id, messageId string) (*chat.Message, error) + GetAllChatMessages(ctx context.Context, chatId chat.Id, opts ...query.Option) ([]*chat.Message, error) + AdvanceChatPointer(ctx context.Context, chatId chat.Id, pointer string) error + GetChatUnreadCount(ctx context.Context, chatId chat.Id) (uint32, error) + SetChatMuteState(ctx context.Context, chatId chat.Id, isMuted bool) error + SetChatSubscriptionState(ctx context.Context, chatId chat.Id, isSubscribed bool) error // Badge Count // -------------------------------------------------------------------------------- @@ -663,13 +663,15 @@ func (dp *DatabaseProvider) GetAllExchangeRates(ctx context.Context, t time.Time return rates, nil } func (dp *DatabaseProvider) GetExchangeRateHistory(ctx context.Context, code currency_lib.Code, opts ...query.Option) ([]*currency.ExchangeRateRecord, error) { - req := query.QueryOptions{ + req := query.Options{ Limit: maxCurrencyHistoryReqSize, End: time.Now(), SortBy: query.Ascending, Supported: query.CanLimitResults | query.CanSortBy | query.CanBucketBy | query.CanQueryByStartTime | query.CanQueryByEndTime, } - req.Apply(opts...) + if err := req.Apply(opts...); err != nil { + return nil, fmt.Errorf("%w: %w", query.ErrQueryNotSupported, err) + } if req.Start.IsZero() { return nil, query.ErrQueryNotSupported @@ -929,12 +931,14 @@ func (dp *DatabaseProvider) UpdateOrCreatePayment(ctx context.Context, record *p return dp.CreatePayment(ctx, record) } func (dp *DatabaseProvider) GetPaymentHistory(ctx context.Context, account string, opts ...query.Option) ([]*payment.Record, error) { - req := query.QueryOptions{ + req := query.Options{ Limit: maxPaymentHistoryReqSize, SortBy: query.Ascending, Supported: query.CanLimitResults | query.CanSortBy | query.CanQueryByCursor | query.CanFilterBy, } - req.Apply(opts...) + if err := req.Apply(opts...); err != nil { + return nil, fmt.Errorf("%w: %w", query.ErrQueryNotSupported, err) + } if req.Limit > maxPaymentHistoryReqSize { return nil, query.ErrQueryNotSupported @@ -948,18 +952,20 @@ func (dp *DatabaseProvider) GetPaymentHistory(ctx context.Context, account strin } if req.FilterBy.Valid { - return dp.payments.GetAllForAccountByType(ctx, account, cursor, uint(req.Limit), req.SortBy, payment.PaymentType(req.FilterBy.Value)) + return dp.payments.GetAllForAccountByType(ctx, account, cursor, uint(req.Limit), req.SortBy, payment.Type(req.FilterBy.Value)) } return dp.payments.GetAllForAccount(ctx, account, cursor, uint(req.Limit), req.SortBy) } func (dp *DatabaseProvider) GetPaymentHistoryWithinBlockRange(ctx context.Context, account string, lowerBound, upperBound uint64, opts ...query.Option) ([]*payment.Record, error) { - req := query.QueryOptions{ + req := query.Options{ Limit: maxPaymentHistoryReqSize, SortBy: query.Ascending, Supported: query.CanLimitResults | query.CanSortBy | query.CanQueryByCursor | query.CanFilterBy, } - req.Apply(opts...) + if err := req.Apply(opts...); err != nil { + return nil, fmt.Errorf("%w: %w", query.ErrQueryNotSupported, err) + } if req.Limit > maxPaymentHistoryReqSize { return nil, query.ErrQueryNotSupported @@ -973,7 +979,7 @@ func (dp *DatabaseProvider) GetPaymentHistoryWithinBlockRange(ctx context.Contex } if req.FilterBy.Valid { - return dp.payments.GetAllForAccountByTypeWithinBlockRange(ctx, account, lowerBound, upperBound, cursor, uint(req.Limit), req.SortBy, payment.PaymentType(req.FilterBy.Value)) + return dp.payments.GetAllForAccountByTypeWithinBlockRange(ctx, account, lowerBound, upperBound, cursor, uint(req.Limit), req.SortBy, payment.Type(req.FilterBy.Value)) } return nil, query.ErrQueryNotSupported @@ -1101,7 +1107,7 @@ func (dp *DatabaseProvider) BatchRemoveContacts(ctx context.Context, owner *user func (dp *DatabaseProvider) PutUser(ctx context.Context, record *identity.Record) error { return dp.userIdentity.Put(ctx, record) } -func (dp *DatabaseProvider) GetUserByID(ctx context.Context, id *user.UserID) (*identity.Record, error) { +func (dp *DatabaseProvider) GetUserByID(ctx context.Context, id *user.Id) (*identity.Record, error) { return dp.userIdentity.GetByID(ctx, id) } func (dp *DatabaseProvider) GetUserByPhoneView(ctx context.Context, phoneNumber string) (*identity.Record, error) { @@ -1293,7 +1299,7 @@ func (dp *DatabaseProvider) GetTreasuryPoolByAddress(ctx context.Context, addres func (dp *DatabaseProvider) GetTreasuryPoolByVault(ctx context.Context, vault string) (*treasury.Record, error) { return dp.treasury.GetByVault(ctx, vault) } -func (dp *DatabaseProvider) GetAllTreasuryPoolsByState(ctx context.Context, state treasury.TreasuryPoolState, opts ...query.Option) ([]*treasury.Record, error) { +func (dp *DatabaseProvider) GetAllTreasuryPoolsByState(ctx context.Context, state treasury.PoolState, opts ...query.Option) ([]*treasury.Record, error) { req, err := query.DefaultPaginationHandler(opts...) if err != nil { return nil, err @@ -1396,7 +1402,7 @@ func (dp *DatabaseProvider) GetAllPendingWebhooksReadyToSend(ctx context.Context func (dp *DatabaseProvider) PutChat(ctx context.Context, record *chat.Chat) error { return dp.chat.PutChat(ctx, record) } -func (dp *DatabaseProvider) GetChatById(ctx context.Context, chatId chat.ChatId) (*chat.Chat, error) { +func (dp *DatabaseProvider) GetChatById(ctx context.Context, chatId chat.Id) (*chat.Chat, error) { return dp.chat.GetChatById(ctx, chatId) } func (dp *DatabaseProvider) GetAllChatsForUser(ctx context.Context, user string, opts ...query.Option) ([]*chat.Chat, error) { @@ -1409,29 +1415,29 @@ func (dp *DatabaseProvider) GetAllChatsForUser(ctx context.Context, user string, func (dp *DatabaseProvider) PutChatMessage(ctx context.Context, record *chat.Message) error { return dp.chat.PutMessage(ctx, record) } -func (dp *DatabaseProvider) DeleteChatMessage(ctx context.Context, chatId chat.ChatId, messageId string) error { +func (dp *DatabaseProvider) DeleteChatMessage(ctx context.Context, chatId chat.Id, messageId string) error { return dp.chat.DeleteMessage(ctx, chatId, messageId) } -func (dp *DatabaseProvider) GetChatMessage(ctx context.Context, chatId chat.ChatId, messageId string) (*chat.Message, error) { +func (dp *DatabaseProvider) GetChatMessage(ctx context.Context, chatId chat.Id, messageId string) (*chat.Message, error) { return dp.chat.GetMessageById(ctx, chatId, messageId) } -func (dp *DatabaseProvider) GetAllChatMessages(ctx context.Context, chatId chat.ChatId, opts ...query.Option) ([]*chat.Message, error) { +func (dp *DatabaseProvider) GetAllChatMessages(ctx context.Context, chatId chat.Id, opts ...query.Option) ([]*chat.Message, error) { req, err := query.DefaultPaginationHandler(opts...) if err != nil { return nil, err } return dp.chat.GetAllMessagesByChat(ctx, chatId, req.Cursor, req.SortBy, req.Limit) } -func (dp *DatabaseProvider) AdvanceChatPointer(ctx context.Context, chatId chat.ChatId, pointer string) error { +func (dp *DatabaseProvider) AdvanceChatPointer(ctx context.Context, chatId chat.Id, pointer string) error { return dp.chat.AdvancePointer(ctx, chatId, pointer) } -func (dp *DatabaseProvider) GetChatUnreadCount(ctx context.Context, chatId chat.ChatId) (uint32, error) { +func (dp *DatabaseProvider) GetChatUnreadCount(ctx context.Context, chatId chat.Id) (uint32, error) { return dp.chat.GetUnreadCount(ctx, chatId) } -func (dp *DatabaseProvider) SetChatMuteState(ctx context.Context, chatId chat.ChatId, isMuted bool) error { +func (dp *DatabaseProvider) SetChatMuteState(ctx context.Context, chatId chat.Id, isMuted bool) error { return dp.chat.SetMuteState(ctx, chatId, isMuted) } -func (dp *DatabaseProvider) SetChatSubscriptionState(ctx context.Context, chatId chat.ChatId, isSubscribed bool) error { +func (dp *DatabaseProvider) SetChatSubscriptionState(ctx context.Context, chatId chat.Id, isSubscribed bool) error { return dp.chat.SetSubscriptionState(ctx, chatId, isSubscribed) } diff --git a/pkg/code/data/invite/v2/memory/store.go b/pkg/code/data/invite/v2/memory/store.go index 0bbed49b..23298081 100644 --- a/pkg/code/data/invite/v2/memory/store.go +++ b/pkg/code/data/invite/v2/memory/store.go @@ -6,8 +6,8 @@ import ( "sync" "time" - "github.com/code-payments/code-server/pkg/phone" "github.com/code-payments/code-server/pkg/code/data/invite/v2" + "github.com/code-payments/code-server/pkg/phone" ) type store struct { @@ -45,7 +45,7 @@ func (s *store) PutUser(ctx context.Context, user *invite.User) error { s.mu.Lock() defer s.mu.Unlock() - copy := &invite.User{ + cpy := &invite.User{ PhoneNumber: user.PhoneNumber, InvitedBy: user.InvitedBy, Invited: user.Invited, @@ -55,22 +55,21 @@ func (s *store) PutUser(ctx context.Context, user *invite.User) error { IsRevoked: user.IsRevoked, } - _, alreadyExists := s.usersByPhoneNumber[copy.PhoneNumber] + _, alreadyExists := s.usersByPhoneNumber[cpy.PhoneNumber] if alreadyExists { return invite.ErrAlreadyExists } - if copy.InvitedBy != nil { - - if phone.IsE164Format(*copy.InvitedBy) { - sender, ok := s.usersByPhoneNumber[*copy.InvitedBy] + if cpy.InvitedBy != nil { + if phone.IsE164Format(*cpy.InvitedBy) { + sender, ok := s.usersByPhoneNumber[*cpy.InvitedBy] if !ok || sender.InvitesSent >= sender.InviteCount || sender.IsRevoked { return invite.ErrInviteCountExceeded } sender.InvitesSent++ } else { - influencer, ok := s.influencerCodeByCode[*copy.InvitedBy] + influencer, ok := s.influencerCodeByCode[*cpy.InvitedBy] if !ok || influencer.InvitesSent >= influencer.InviteCount || influencer.IsRevoked || influencer.ExpiresAt.Before(time.Now()) { return invite.ErrInviteCountExceeded } @@ -79,7 +78,7 @@ func (s *store) PutUser(ctx context.Context, user *invite.User) error { } } - s.usersByPhoneNumber[copy.PhoneNumber] = copy + s.usersByPhoneNumber[cpy.PhoneNumber] = cpy return nil } @@ -158,7 +157,6 @@ func (s *store) FilterInvitedNumbers(ctx context.Context, phoneNumbers []string) } return filtered, nil - } // PutOnWaitlist implements invite.v2.Store.PutOnWaitlist diff --git a/pkg/code/data/invite/v2/postgres/model.go b/pkg/code/data/invite/v2/postgres/model.go index 05428403..987eddc2 100644 --- a/pkg/code/data/invite/v2/postgres/model.go +++ b/pkg/code/data/invite/v2/postgres/model.go @@ -114,16 +114,25 @@ func (m *userModel) dbPut(ctx context.Context, db *sqlx.DB) error { result, err := tx.ExecContext(ctx, updateQuery, m.InvitedBy.String) if err != nil { - tx.Rollback() + if rollbackErr := tx.Rollback(); rollbackErr != nil { + return fmt.Errorf("failed to rollback transaction: %w", rollbackErr) + } + return err } rowsAffected, err := result.RowsAffected() if err != nil { - tx.Rollback() + if rollbackErr := tx.Rollback(); rollbackErr != nil { + return fmt.Errorf("failed to rollback transaction: %w", rollbackErr) + } + return err } else if rowsAffected == 0 { - tx.Rollback() + if rollbackErr := tx.Rollback(); rollbackErr != nil { + return fmt.Errorf("failed to rollback transaction: %w", rollbackErr) + } + return invite.ErrInviteCountExceeded } } @@ -144,7 +153,10 @@ func (m *userModel) dbPut(ctx context.Context, db *sqlx.DB) error { m.IsRevoked, ) if err != nil { - tx.Rollback() + if rollbackErr := tx.Rollback(); rollbackErr != nil { + return fmt.Errorf("failed to rollback transaction: %w", rollbackErr) + } + return pgutil.CheckUniqueViolation(err, invite.ErrAlreadyExists) } diff --git a/pkg/code/data/invite/v2/tests/tests.go b/pkg/code/data/invite/v2/tests/tests.go index cc52241a..125335b4 100644 --- a/pkg/code/data/invite/v2/tests/tests.go +++ b/pkg/code/data/invite/v2/tests/tests.go @@ -319,7 +319,6 @@ func testInfluencerCodeClaim(t *testing.T, s invite.Store) { // Test multiple claims for i := uint32(influencerCode.InvitesSent + 1); i <= influencerCode.InviteCount; i++ { - // Claim the influencer code err = s.ClaimInfluencerCode(ctx, inviteCode) assert.NoError(t, err) @@ -425,7 +424,6 @@ func testInfluencerCodePutUser(t *testing.T, s invite.Store) { // Test multiple claims for i := uint32(influencerCode.InvitesSent + 1); i <= influencerCode.InviteCount; i++ { - // Put a user to claim the code phoneNumber := fmt.Sprintf("+1800555000%d", i) invitedUser := &invite.User{ diff --git a/pkg/code/data/merkletree/merkletree.go b/pkg/code/data/merkletree/merkletree.go index cf1aac40..2fac79fe 100644 --- a/pkg/code/data/merkletree/merkletree.go +++ b/pkg/code/data/merkletree/merkletree.go @@ -155,7 +155,7 @@ func (t *MerkleTree) AddLeaf(ctx context.Context, leaf Leaf) error { // Maintain a temporary copy of the next state nextVersion := t.mtdt.NextIndex + 1 nextMtdt := t.mtdt.Clone() - nextMtdt.NextIndex += 1 + nextMtdt.NextIndex++ nextFilledSubtrees := make([]Hash, t.mtdt.Levels) leafNode := &Node{ diff --git a/pkg/code/data/nonce/nonce.go b/pkg/code/data/nonce/nonce.go index 69cec6d6..17c5ec3a 100644 --- a/pkg/code/data/nonce/nonce.go +++ b/pkg/code/data/nonce/nonce.go @@ -75,16 +75,16 @@ func (r *Record) CopyTo(dst *Record) { dst.Signature = r.Signature } -func (v *Record) Validate() error { - if len(v.Address) == 0 { +func (r *Record) Validate() error { + if len(r.Address) == 0 { return errors.New("nonce account address is required") } - if len(v.Authority) == 0 { + if len(r.Authority) == 0 { return errors.New("authority address is required") } - if v.Purpose == PurposeUnknown { + if r.Purpose == PurposeUnknown { return errors.New("nonce purpose must be set") } return nil diff --git a/pkg/code/data/onramp/memory/store.go b/pkg/code/data/onramp/memory/store.go index de4d5ce9..33c08eff 100644 --- a/pkg/code/data/onramp/memory/store.go +++ b/pkg/code/data/onramp/memory/store.go @@ -34,17 +34,17 @@ func (s *store) Put(_ context.Context, data *onramp.Record) error { s.last++ if item := s.find(data); item != nil { return onramp.ErrPurchaseAlreadyExists - } else { - if data.Id == 0 { - data.Id = s.last - } - if data.CreatedAt.IsZero() { - data.CreatedAt = time.Now() - } - c := data.Clone() - s.records = append(s.records, &c) } + if data.Id == 0 { + data.Id = s.last + } + if data.CreatedAt.IsZero() { + data.CreatedAt = time.Now() + } + c := data.Clone() + s.records = append(s.records, &c) + return nil } diff --git a/pkg/code/data/payment/memory/store.go b/pkg/code/data/payment/memory/store.go index e5ea0dad..4a976d41 100644 --- a/pkg/code/data/payment/memory/store.go +++ b/pkg/code/data/payment/memory/store.go @@ -7,9 +7,9 @@ import ( "strings" "sync" - "github.com/code-payments/code-server/pkg/database/query" "github.com/code-payments/code-server/pkg/code/data/payment" "github.com/code-payments/code-server/pkg/code/data/transaction" + "github.com/code-payments/code-server/pkg/database/query" ) type store struct { @@ -55,7 +55,7 @@ func (s *store) Put(ctx context.Context, data *payment.Record) error { data.Id = s.lastIndex data.ExchangeCurrency = strings.ToLower(data.ExchangeCurrency) s.paymentRecords[pk] = data - s.lastIndex += 1 + s.lastIndex++ return nil } @@ -157,7 +157,7 @@ func (s *store) GetAllForAccount(ctx context.Context, account string, cursor uin return all[:limit], nil } -func (s *store) GetAllForAccountByType(ctx context.Context, account string, cursor uint64, limit uint, ordering query.Ordering, paymentType payment.PaymentType) (result []*payment.Record, err error) { +func (s *store) GetAllForAccountByType(ctx context.Context, account string, cursor uint64, limit uint, ordering query.Ordering, paymentType payment.Type) (result []*payment.Record, err error) { s.paymentRecordMu.Lock() defer s.paymentRecordMu.Unlock() @@ -168,9 +168,9 @@ func (s *store) GetAllForAccountByType(ctx context.Context, account string, curs // not ideal, but this is for testing purposes and s.paymentRecord should be small all := make([]*payment.Record, 0) for _, record := range s.paymentRecords { - if (paymentType == payment.PaymentType_Send && record.Source == account) || - (paymentType == payment.PaymentType_Receive && record.Destination == account) { - + isSender := paymentType == payment.TypeSend && record.Source == account + isReceiver := paymentType == payment.TypeReceive && record.Destination == account + if isSender || isReceiver { if ordering == query.Ascending { if record.Id > cursor { all = append(all, record) @@ -202,7 +202,7 @@ func (s *store) GetAllForAccountByType(ctx context.Context, account string, curs return all[:limit], nil } -func (s *store) GetAllForAccountByTypeAfterBlock(ctx context.Context, account string, block uint64, cursor uint64, limit uint, ordering query.Ordering, paymentType payment.PaymentType) (result []*payment.Record, err error) { +func (s *store) GetAllForAccountByTypeAfterBlock(ctx context.Context, account string, block uint64, cursor uint64, limit uint, ordering query.Ordering, paymentType payment.Type) (result []*payment.Record, err error) { s.paymentRecordMu.Lock() defer s.paymentRecordMu.Unlock() @@ -217,9 +217,9 @@ func (s *store) GetAllForAccountByTypeAfterBlock(ctx context.Context, account st continue } - if (paymentType == payment.PaymentType_Send && record.Source == account) || - (paymentType == payment.PaymentType_Receive && record.Destination == account) { - + isSender := paymentType == payment.TypeSend && record.Source == account + isReceiver := paymentType == payment.TypeReceive && record.Destination == account + if isSender || isReceiver { if ordering == query.Ascending { if record.Id > cursor { all = append(all, record) @@ -251,7 +251,7 @@ func (s *store) GetAllForAccountByTypeAfterBlock(ctx context.Context, account st return all[:limit], nil } -func (s *store) GetAllForAccountByTypeWithinBlockRange(ctx context.Context, account string, lowerBound, upperBound uint64, cursor uint64, limit uint, ordering query.Ordering, paymentType payment.PaymentType) ([]*payment.Record, error) { +func (s *store) GetAllForAccountByTypeWithinBlockRange(ctx context.Context, account string, lowerBound, upperBound uint64, cursor uint64, limit uint, ordering query.Ordering, paymentType payment.Type) ([]*payment.Record, error) { s.paymentRecordMu.Lock() defer s.paymentRecordMu.Unlock() @@ -266,9 +266,9 @@ func (s *store) GetAllForAccountByTypeWithinBlockRange(ctx context.Context, acco continue } - if (paymentType == payment.PaymentType_Send && record.Source == account) || - (paymentType == payment.PaymentType_Receive && record.Destination == account) { - + isSender := paymentType == payment.TypeSend && record.Source == account + isReceiver := paymentType == payment.TypeReceive && record.Destination == account + if isSender || isReceiver { if ordering == query.Ascending { if record.Id > cursor { all = append(all, record) diff --git a/pkg/code/data/payment/payment.go b/pkg/code/data/payment/payment.go index 71d92c07..ed3cd114 100644 --- a/pkg/code/data/payment/payment.go +++ b/pkg/code/data/payment/payment.go @@ -4,9 +4,9 @@ import ( "bytes" "time" + "github.com/code-payments/code-server/pkg/code/data/transaction" "github.com/code-payments/code-server/pkg/kin" "github.com/code-payments/code-server/pkg/solana/token" - "github.com/code-payments/code-server/pkg/code/data/transaction" "github.com/mr-tron/base58" "github.com/pkg/errors" ) @@ -53,11 +53,11 @@ type Record struct { CreatedAt time.Time } -type PaymentType uint32 +type Type uint32 const ( - PaymentType_Send PaymentType = iota - PaymentType_Receive + TypeSend Type = iota + TypeReceive ) func NewFromTransfer(transfer *token.DecompiledTransfer, sig string, index int, rate float64, now time.Time) *Record { diff --git a/pkg/code/data/payment/postgres/model.go b/pkg/code/data/payment/postgres/model.go index a9d5969a..314776fb 100644 --- a/pkg/code/data/payment/postgres/model.go +++ b/pkg/code/data/payment/postgres/model.go @@ -8,10 +8,10 @@ import ( "github.com/jmoiron/sqlx" - pgutil "github.com/code-payments/code-server/pkg/database/postgres" - q "github.com/code-payments/code-server/pkg/database/query" "github.com/code-payments/code-server/pkg/code/data/payment" "github.com/code-payments/code-server/pkg/code/data/transaction" + pgutil "github.com/code-payments/code-server/pkg/database/postgres" + q "github.com/code-payments/code-server/pkg/database/query" ) const ( @@ -116,33 +116,33 @@ func fromModel(obj *model) *payment.Record { return record } -func (self *model) dbSave(ctx context.Context, db *sqlx.DB) error { +func (m *model) dbSave(ctx context.Context, db *sqlx.DB) error { query := `INSERT INTO ` + tableName + ` (` + tableColumns + `) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16) RETURNING *;` err := db.QueryRowxContext(ctx, query, - self.BlockId, - self.BlockTime, - self.TransactionId, - self.TransactionIndex, - self.Rendezvous, - self.IsExternal, - self.SourceId, - self.DestinationId, - self.Quantity, - self.ExchangeCurrency, - self.Region, - self.ExchangeRate, - self.UsdMarketValue, - self.IsWithdraw, - self.ConfirmationState, - self.CreatedAt, - ).StructScan(self) + m.BlockId, + m.BlockTime, + m.TransactionId, + m.TransactionIndex, + m.Rendezvous, + m.IsExternal, + m.SourceId, + m.DestinationId, + m.Quantity, + m.ExchangeCurrency, + m.Region, + m.ExchangeRate, + m.UsdMarketValue, + m.IsWithdraw, + m.ConfirmationState, + m.CreatedAt, + ).StructScan(m) return pgutil.CheckUniqueViolation(err, payment.ErrExists) } -func (self *model) dbUpdate(ctx context.Context, db *sqlx.DB) error { +func (m *model) dbUpdate(ctx context.Context, db *sqlx.DB) error { query := `UPDATE ` + tableName + ` SET block_id = $2, @@ -155,15 +155,15 @@ func (self *model) dbUpdate(ctx context.Context, db *sqlx.DB) error { WHERE id = $1 RETURNING *;` err := db.QueryRowxContext(ctx, query, - self.Id, - self.BlockId, - self.BlockTime, - self.ExchangeCurrency, - self.Region, - self.ExchangeRate, - self.UsdMarketValue, - self.ConfirmationState, - ).StructScan(self) + m.Id, + m.BlockId, + m.BlockTime, + m.ExchangeCurrency, + m.Region, + m.ExchangeRate, + m.UsdMarketValue, + m.ConfirmationState, + ).StructScan(m) return pgutil.CheckNoRows(err, payment.ErrNotFound) } @@ -245,11 +245,11 @@ func dbGetAllForAccount(ctx context.Context, db *sqlx.DB, account string, cursor return res, nil } -func dbGetAllForAccountByType(ctx context.Context, db *sqlx.DB, account string, cursor uint64, limit uint, ordering q.Ordering, paymentType payment.PaymentType) ([]*model, error) { +func dbGetAllForAccountByType(ctx context.Context, db *sqlx.DB, account string, cursor uint64, limit uint, ordering q.Ordering, paymentType payment.Type) ([]*model, error) { res := []*model{} var condition string - if paymentType == payment.PaymentType_Send { + if paymentType == payment.TypeSend { condition = "source = $1" } else { condition = "destination = $1" @@ -270,11 +270,11 @@ func dbGetAllForAccountByType(ctx context.Context, db *sqlx.DB, account string, return res, nil } -func dbGetAllForAccountByTypeAfterBlock(ctx context.Context, db *sqlx.DB, account string, block uint64, cursor uint64, limit uint, ordering q.Ordering, paymentType payment.PaymentType) ([]*model, error) { +func dbGetAllForAccountByTypeAfterBlock(ctx context.Context, db *sqlx.DB, account string, block uint64, cursor uint64, limit uint, ordering q.Ordering, paymentType payment.Type) ([]*model, error) { res := []*model{} var condition string - if paymentType == payment.PaymentType_Send { + if paymentType == payment.TypeSend { condition = "source = $1" } else { condition = "destination = $1" @@ -297,11 +297,11 @@ func dbGetAllForAccountByTypeAfterBlock(ctx context.Context, db *sqlx.DB, accoun return res, nil } -func dbGetAllForAccountByTypeWithinBlockRange(ctx context.Context, db *sqlx.DB, account string, lowerBound, upperBound uint64, cursor uint64, limit uint, ordering q.Ordering, paymentType payment.PaymentType) ([]*model, error) { +func dbGetAllForAccountByTypeWithinBlockRange(ctx context.Context, db *sqlx.DB, account string, lowerBound, upperBound uint64, cursor uint64, limit uint, ordering q.Ordering, paymentType payment.Type) ([]*model, error) { res := []*model{} var condition string - if paymentType == payment.PaymentType_Send { + if paymentType == payment.TypeSend { condition = "source = $1" } else { condition = "destination = $1" diff --git a/pkg/code/data/payment/postgres/store.go b/pkg/code/data/payment/postgres/store.go index acfa8ae8..c38cf436 100644 --- a/pkg/code/data/payment/postgres/store.go +++ b/pkg/code/data/payment/postgres/store.go @@ -4,8 +4,8 @@ import ( "context" "database/sql" - "github.com/code-payments/code-server/pkg/database/query" "github.com/code-payments/code-server/pkg/code/data/payment" + "github.com/code-payments/code-server/pkg/database/query" "github.com/jmoiron/sqlx" ) @@ -64,7 +64,7 @@ func (s *store) GetAllForAccount(ctx context.Context, account string, cursor uin return res, nil } -func (s *store) GetAllForAccountByType(ctx context.Context, account string, cursor uint64, limit uint, ordering query.Ordering, paymentType payment.PaymentType) ([]*payment.Record, error) { +func (s *store) GetAllForAccountByType(ctx context.Context, account string, cursor uint64, limit uint, ordering query.Ordering, paymentType payment.Type) ([]*payment.Record, error) { list, err := dbGetAllForAccountByType(ctx, s.db, account, cursor, limit, ordering, paymentType) if err != nil { return nil, err @@ -78,7 +78,7 @@ func (s *store) GetAllForAccountByType(ctx context.Context, account string, curs return res, nil } -func (s *store) GetAllForAccountByTypeAfterBlock(ctx context.Context, account string, block uint64, cursor uint64, limit uint, ordering query.Ordering, paymentType payment.PaymentType) ([]*payment.Record, error) { +func (s *store) GetAllForAccountByTypeAfterBlock(ctx context.Context, account string, block uint64, cursor uint64, limit uint, ordering query.Ordering, paymentType payment.Type) ([]*payment.Record, error) { list, err := dbGetAllForAccountByTypeAfterBlock(ctx, s.db, account, block, cursor, limit, ordering, paymentType) if err != nil { return nil, err @@ -92,7 +92,7 @@ func (s *store) GetAllForAccountByTypeAfterBlock(ctx context.Context, account st return res, nil } -func (s *store) GetAllForAccountByTypeWithinBlockRange(ctx context.Context, account string, lowerBound, upperBound uint64, cursor uint64, limit uint, ordering query.Ordering, paymentType payment.PaymentType) ([]*payment.Record, error) { +func (s *store) GetAllForAccountByTypeWithinBlockRange(ctx context.Context, account string, lowerBound, upperBound uint64, cursor uint64, limit uint, ordering query.Ordering, paymentType payment.Type) ([]*payment.Record, error) { list, err := dbGetAllForAccountByTypeWithinBlockRange(ctx, s.db, account, lowerBound, upperBound, cursor, limit, ordering, paymentType) if err != nil { return nil, err diff --git a/pkg/code/data/payment/store.go b/pkg/code/data/payment/store.go index 558f5410..f71c05ce 100644 --- a/pkg/code/data/payment/store.go +++ b/pkg/code/data/payment/store.go @@ -36,21 +36,21 @@ type Store interface { // "limit" results. // // ErrNotFound is returned if no rows are found. - GetAllForAccountByType(ctx context.Context, account string, cursor uint64, limit uint, ordering query.Ordering, paymentType PaymentType) ([]*Record, error) + GetAllForAccountByType(ctx context.Context, account string, cursor uint64, limit uint, ordering query.Ordering, paymentType Type) ([]*Record, error) // GetAllForAccountByTypeAfterBlock returns payment records in the store for a // given "account" after a "block" after a provided "cursor" value and limited // to at most "limit" results. // // ErrNotFound is returned if no rows are found. - GetAllForAccountByTypeAfterBlock(ctx context.Context, account string, block uint64, cursor uint64, limit uint, ordering query.Ordering, paymentType PaymentType) ([]*Record, error) + GetAllForAccountByTypeAfterBlock(ctx context.Context, account string, block uint64, cursor uint64, limit uint, ordering query.Ordering, paymentType Type) ([]*Record, error) // GetAllForAccountByTypeWithinBlockRange returns payment records in the store // for a given "account" within a "block" range (lowerBound, upperBOund) after a // provided "cursor" value and limited to at most "limit" results. // // ErrNotFound is returned if no rows are found. - GetAllForAccountByTypeWithinBlockRange(ctx context.Context, account string, lowerBound, upperBound uint64, cursor uint64, limit uint, ordering query.Ordering, paymentType PaymentType) ([]*Record, error) + GetAllForAccountByTypeWithinBlockRange(ctx context.Context, account string, lowerBound, upperBound uint64, cursor uint64, limit uint, ordering query.Ordering, paymentType Type) ([]*Record, error) // GetExternalDepositAmount gets the total amount of Kin in quarks deposited to // an account via a deposit from an external account. diff --git a/pkg/code/data/payment/tests/tests.go b/pkg/code/data/payment/tests/tests.go index d354a4f1..01facbee 100644 --- a/pkg/code/data/payment/tests/tests.go +++ b/pkg/code/data/payment/tests/tests.go @@ -9,8 +9,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/code-payments/code-server/pkg/database/query" "github.com/code-payments/code-server/pkg/code/data/payment" + "github.com/code-payments/code-server/pkg/database/query" ) type TestData struct { @@ -144,5 +144,4 @@ func testGetRange(t *testing.T, s payment.Store) { assert.Equal(t, 2, len(results)) CompareTestDataToPayment(t, testPayments[4], results[0]) CompareTestDataToPayment(t, testPayments[5], results[1]) - } diff --git a/pkg/code/data/paymentrequest/memory/store.go b/pkg/code/data/paymentrequest/memory/store.go index 8c86ca20..7dfaf2c8 100644 --- a/pkg/code/data/paymentrequest/memory/store.go +++ b/pkg/code/data/paymentrequest/memory/store.go @@ -40,25 +40,25 @@ func (s *store) Put(_ context.Context, data *paymentrequest.Record) error { s.last++ if item := s.find(data); item != nil { return paymentrequest.ErrPaymentRequestAlreadyExists - } else { - seenDestinations := make(map[string]any) - for _, fee := range data.Fees { - _, ok := seenDestinations[fee.DestinationTokenAccount] - if ok { - return paymentrequest.ErrInvalidPaymentRequest - } - seenDestinations[fee.DestinationTokenAccount] = true - } + } - if data.Id == 0 { - data.Id = s.last - } - if data.CreatedAt.IsZero() { - data.CreatedAt = time.Now() + seenDestinations := make(map[string]any) + for _, fee := range data.Fees { + _, ok := seenDestinations[fee.DestinationTokenAccount] + if ok { + return paymentrequest.ErrInvalidPaymentRequest } - c := data.Clone() - s.records = append(s.records, &c) + seenDestinations[fee.DestinationTokenAccount] = true + } + + if data.Id == 0 { + data.Id = s.last + } + if data.CreatedAt.IsZero() { + data.CreatedAt = time.Now() } + c := data.Clone() + s.records = append(s.records, &c) return nil } diff --git a/pkg/code/data/paywall/memory/store.go b/pkg/code/data/paywall/memory/store.go index 0d6d8a45..184b49d2 100644 --- a/pkg/code/data/paywall/memory/store.go +++ b/pkg/code/data/paywall/memory/store.go @@ -40,17 +40,17 @@ func (s *store) Put(_ context.Context, data *paywall.Record) error { s.last++ if item := s.find(data); item != nil { return paywall.ErrPaywallExists - } else { - if data.Id == 0 { - data.Id = s.last - } - if data.CreatedAt.IsZero() { - data.CreatedAt = time.Now() - } - c := data.Clone() - s.records = append(s.records, &c) } + if data.Id == 0 { + data.Id = s.last + } + if data.CreatedAt.IsZero() { + data.CreatedAt = time.Now() + } + c := data.Clone() + s.records = append(s.records, &c) + return nil } diff --git a/pkg/code/data/phone/memory/store.go b/pkg/code/data/phone/memory/store.go index 2404d16f..66c56915 100644 --- a/pkg/code/data/phone/memory/store.go +++ b/pkg/code/data/phone/memory/store.go @@ -64,14 +64,14 @@ func (s *store) SaveVerification(ctx context.Context, newVerification *phone.Ver break } if !alreadyExists { - copy := &phone.Verification{ + cpy := &phone.Verification{ OwnerAccount: newVerification.OwnerAccount, PhoneNumber: newVerification.PhoneNumber, CreatedAt: newVerification.CreatedAt, LastVerifiedAt: newVerification.LastVerifiedAt, } - s.verificationsByAccount[copy.OwnerAccount] = append(s.verificationsByAccount[copy.OwnerAccount], copy) - s.verificationsByNumber[copy.PhoneNumber] = append(s.verificationsByNumber[copy.PhoneNumber], copy) + s.verificationsByAccount[cpy.OwnerAccount] = append(s.verificationsByAccount[cpy.OwnerAccount], cpy) + s.verificationsByNumber[cpy.PhoneNumber] = append(s.verificationsByNumber[cpy.PhoneNumber], cpy) } currentByAccount = s.verificationsByAccount[newVerification.OwnerAccount] @@ -154,14 +154,14 @@ func (s *store) SaveLinkingToken(ctx context.Context, token *phone.LinkingToken) s.mu.Lock() defer s.mu.Unlock() - copy := &phone.LinkingToken{ + cpy := &phone.LinkingToken{ PhoneNumber: token.PhoneNumber, Code: token.Code, ExpiresAt: token.ExpiresAt, CurrentCheckCount: token.CurrentCheckCount, MaxCheckCount: token.MaxCheckCount, } - s.linkingTokensByNumber[copy.PhoneNumber] = copy + s.linkingTokensByNumber[cpy.PhoneNumber] = cpy return nil } @@ -231,17 +231,17 @@ func (s *store) SaveOwnerAccountSetting(ctx context.Context, phoneNumber string, s.mu.Lock() defer s.mu.Unlock() - copy := &phone.Settings{ + cpy := &phone.Settings{ PhoneNumber: phoneNumber, ByOwnerAccount: make(map[string]*phone.OwnerAccountSetting), } phoneSettings, ok := s.settingsByNumber[phoneNumber] if ok { - copy = phoneSettings + cpy = phoneSettings } - currentSettings, ok := copy.ByOwnerAccount[newSettings.OwnerAccount] + currentSettings, ok := cpy.ByOwnerAccount[newSettings.OwnerAccount] if ok { if newSettings.IsUnlinked != nil { flagCopy := *newSettings.IsUnlinked @@ -251,13 +251,13 @@ func (s *store) SaveOwnerAccountSetting(ctx context.Context, phoneNumber string, return nil } - copy.ByOwnerAccount[newSettings.OwnerAccount] = &phone.OwnerAccountSetting{ + cpy.ByOwnerAccount[newSettings.OwnerAccount] = &phone.OwnerAccountSetting{ OwnerAccount: newSettings.OwnerAccount, IsUnlinked: newSettings.IsUnlinked, CreatedAt: newSettings.CreatedAt, LastUpdatedAt: newSettings.LastUpdatedAt, } - s.settingsByNumber[phoneNumber] = copy + s.settingsByNumber[phoneNumber] = cpy return nil } @@ -274,7 +274,7 @@ func (s *store) PutEvent(ctx context.Context, event *phone.Event) error { eventsByNumber := s.eventsByNumber[event.PhoneNumber] eventsByVerification := s.eventsByVerification[event.VerificationId] - copy := &phone.Event{ + cpy := &phone.Event{ Type: event.Type, VerificationId: event.VerificationId, PhoneNumber: event.PhoneNumber, @@ -282,10 +282,10 @@ func (s *store) PutEvent(ctx context.Context, event *phone.Event) error { CreatedAt: event.CreatedAt, } - eventsByNumber = append(eventsByNumber, copy) + eventsByNumber = append(eventsByNumber, cpy) s.eventsByNumber[event.PhoneNumber] = eventsByNumber - eventsByVerification = append(eventsByVerification, copy) + eventsByVerification = append(eventsByVerification, cpy) s.eventsByVerification[event.VerificationId] = eventsByVerification return nil @@ -327,7 +327,7 @@ func (s *store) CountEventsForVerificationByType(ctx context.Context, verificati for _, event := range s.eventsByVerification[verification] { if event.Type == eventType { - count += 1 + count++ } } @@ -350,7 +350,7 @@ func (s *store) CountEventsForNumberByTypeSinceTimestamp(ctx context.Context, ph continue } - count += 1 + count++ } return count, nil diff --git a/pkg/code/data/provider.go b/pkg/code/data/provider.go index 9dc46d57..2ab696ed 100644 --- a/pkg/code/data/provider.go +++ b/pkg/code/data/provider.go @@ -26,7 +26,7 @@ type Provider interface { GetEstimatedDataProvider() EstimatedData } -type DataProvider struct { +type provider struct { *BlockchainProvider *DatabaseProvider *WebProvider @@ -44,7 +44,7 @@ func NewDataProvider(dbConfig *pg.Config, solanaEnv string, configProvider Confi return nil, err } - provider := p.(*DataProvider) + provider := p.(*provider) provider.BlockchainProvider = blockchain.(*BlockchainProvider) return provider, nil @@ -66,7 +66,7 @@ func NewDataProviderWithoutBlockchain(dbConfig *pg.Config, configProvider Config return nil, err } - provider := &DataProvider{ + provider := &provider{ DatabaseProvider: db.(*DatabaseProvider), WebProvider: web.(*WebProvider), EstimatedProvider: estimated.(*EstimatedProvider), @@ -84,21 +84,21 @@ func NewTestDataProvider() Provider { panic(err) } - return &DataProvider{ + return &provider{ DatabaseProvider: NewTestDatabaseProvider().(*DatabaseProvider), BlockchainProvider: blockchain.(*BlockchainProvider), } } -func (p *DataProvider) GetBlockchainDataProvider() BlockchainData { +func (p *provider) GetBlockchainDataProvider() BlockchainData { return p.BlockchainProvider } -func (p *DataProvider) GetWebDataProvider() WebData { +func (p *provider) GetWebDataProvider() WebData { return p.WebProvider } -func (p *DataProvider) GetDatabaseDataProvider() DatabaseData { +func (p *provider) GetDatabaseDataProvider() DatabaseData { return p.DatabaseProvider } -func (p *DataProvider) GetEstimatedDataProvider() EstimatedData { +func (p *provider) GetEstimatedDataProvider() EstimatedData { return p.EstimatedProvider } diff --git a/pkg/code/data/push/memory/store.go b/pkg/code/data/push/memory/store.go index faf81a65..c3293b34 100644 --- a/pkg/code/data/push/memory/store.go +++ b/pkg/code/data/push/memory/store.go @@ -32,11 +32,11 @@ func (s *store) Put(_ context.Context, record *push.Record) error { if item := s.find(record); item != nil { return push.ErrTokenExists - } else { - record.Id = s.last - s.records = append(s.records, record.Clone()) } + record.Id = s.last + s.records = append(s.records, record.Clone()) + return nil } diff --git a/pkg/code/data/transaction/memory/store.go b/pkg/code/data/transaction/memory/store.go index 23ce994b..ed1b1709 100644 --- a/pkg/code/data/transaction/memory/store.go +++ b/pkg/code/data/transaction/memory/store.go @@ -5,8 +5,8 @@ import ( "sort" "sync" - "github.com/code-payments/code-server/pkg/database/query" "github.com/code-payments/code-server/pkg/code/data/transaction" + "github.com/code-payments/code-server/pkg/database/query" ) type store struct { @@ -45,8 +45,7 @@ func (s *store) Put(ctx context.Context, data *transaction.Record) error { defer s.transactionStoreMu.Unlock() s.lastIndex++ - var clone transaction.Record - clone = *data + clone := *data for index, item := range s.transactionStore { if item.Signature == data.Signature { diff --git a/pkg/code/data/transaction/postgres/model.go b/pkg/code/data/transaction/postgres/model.go index 07220a61..281ce0c2 100644 --- a/pkg/code/data/transaction/postgres/model.go +++ b/pkg/code/data/transaction/postgres/model.go @@ -7,9 +7,9 @@ import ( "github.com/jmoiron/sqlx" + "github.com/code-payments/code-server/pkg/code/data/transaction" pg "github.com/code-payments/code-server/pkg/database/postgres" q "github.com/code-payments/code-server/pkg/database/query" - "github.com/code-payments/code-server/pkg/code/data/transaction" ) const ( @@ -90,7 +90,7 @@ func makeAllQuery(table, columns, condition string, ordering q.Ordering) string return query } -func (self *transactionModel) txSaveTx(ctx context.Context, tx *sqlx.Tx) error { +func (m *transactionModel) txSaveTx(ctx context.Context, tx *sqlx.Tx) error { query := `INSERT INTO ` + tableNameTx + ` ( signature, block_id, block_time, raw_data, fee, has_errors, @@ -109,16 +109,16 @@ func (self *transactionModel) txSaveTx(ctx context.Context, tx *sqlx.Tx) error { RETURNING *;` err := tx.QueryRowxContext(ctx, query, - self.Signature, - self.Slot, - self.BlockTime, - self.Data, - self.Fee, - self.HasErrors, - self.ConfirmationState, - self.Confirmations, - self.CreatedAt, - ).StructScan(self) + m.Signature, + m.Slot, + m.BlockTime, + m.Data, + m.Fee, + m.HasErrors, + m.ConfirmationState, + m.Confirmations, + m.CreatedAt, + ).StructScan(m) return err } @@ -168,8 +168,23 @@ func dbGetFirstPending(ctx context.Context, db *sqlx.DB, address string) (*trans postBalance uint64 //`db:"post_balance"` ) - rows.Scan(&txId, &signature, &slot, &blockTime, &data, &fee, &hasErrors, - &confirmationState, &confirmations, &createdAt, &account, &preBalance, &postBalance) + err = rows.Scan( + &txId, + &signature, + &slot, + &blockTime, + &data, &fee, + &hasErrors, + &confirmationState, + &confirmations, + &createdAt, + &account, + &preBalance, + &postBalance, + ) + if err != nil { + return nil, err + } tb := &transaction.TokenBalance{ Account: account, @@ -242,8 +257,24 @@ func dbGetLatestByState(ctx context.Context, db *sqlx.DB, address string, state postBalance uint64 //`db:"post_balance"` ) - rows.Scan(&txId, &signature, &slot, &blockTime, &data, &fee, &hasErrors, - &confirmationState, &confirmations, &createdAt, &account, &preBalance, &postBalance) + err = rows.Scan( + &txId, + &signature, + &slot, + &blockTime, + &data, + &fee, + &hasErrors, + &confirmationState, + &confirmations, + &createdAt, + &account, + &preBalance, + &postBalance, + ) + if err != nil { + return nil, err + } tb := &transaction.TokenBalance{ Account: account, @@ -337,8 +368,24 @@ func dbGetAllByAddress(ctx context.Context, db *sqlx.DB, address string, cursor postBalance uint64 //`db:"post_balance"` ) - rows.Scan(&txId, &signature, &slot, &blockTime, &data, &fee, &hasErrors, - &confirmationState, &confirmations, &createdAt, &account, &preBalance, &postBalance) + err = rows.Scan( + &txId, + &signature, + &slot, + &blockTime, + &data, + &fee, + &hasErrors, + &confirmationState, + &confirmations, + &createdAt, + &account, + &preBalance, + &postBalance, + ) + if err != nil { + return nil, err + } tb := &transaction.TokenBalance{ Account: account, @@ -418,7 +465,7 @@ func fromTxBalanceModel(obj *tokenBalanceModel) *transaction.TokenBalance { } } -func (self *tokenBalanceModel) txSaveTxBalance(ctx context.Context, tx *sqlx.Tx) error { +func (m *tokenBalanceModel) txSaveTxBalance(ctx context.Context, tx *sqlx.Tx) error { query := `INSERT INTO ` + tableNameTxBalance + ` (transaction_id, account, pre_balance, post_balance) VALUES ($1,$2,$3,$4) @@ -429,11 +476,11 @@ func (self *tokenBalanceModel) txSaveTxBalance(ctx context.Context, tx *sqlx.Tx) RETURNING *;` err := tx.QueryRowxContext(ctx, query, - self.TransactionId, - self.Account, - self.PreBalance, - self.PostBalance, - ).StructScan(self) + m.TransactionId, + m.Account, + m.PreBalance, + m.PostBalance, + ).StructScan(m) return err } diff --git a/pkg/code/data/treasury/memory/store.go b/pkg/code/data/treasury/memory/store.go index bb8d16bb..218beb37 100644 --- a/pkg/code/data/treasury/memory/store.go +++ b/pkg/code/data/treasury/memory/store.go @@ -6,8 +6,8 @@ import ( "sync" "time" - "github.com/code-payments/code-server/pkg/database/query" "github.com/code-payments/code-server/pkg/code/data/treasury" + "github.com/code-payments/code-server/pkg/database/query" ) type ById []*treasury.Record @@ -99,7 +99,7 @@ func (s *store) GetByVault(_ context.Context, vault string) (*treasury.Record, e } // GetAllByState implements treasury.Store.GetAllByState -func (s *store) GetAllByState(_ context.Context, state treasury.TreasuryPoolState, cursor query.Cursor, limit uint64, direction query.Ordering) ([]*treasury.Record, error) { +func (s *store) GetAllByState(_ context.Context, state treasury.PoolState, cursor query.Cursor, limit uint64, direction query.Ordering) ([]*treasury.Record, error) { s.mu.Lock() defer s.mu.Unlock() @@ -203,7 +203,7 @@ func (s *store) findTreasuryPoolByVault(vault string) *treasury.Record { return nil } -func (s *store) findTreasuryPoolByState(state treasury.TreasuryPoolState) []*treasury.Record { +func (s *store) findTreasuryPoolByState(state treasury.PoolState) []*treasury.Record { res := make([]*treasury.Record, 0) for _, item := range s.treasuryPoolRecords { if item.State == state { diff --git a/pkg/code/data/treasury/postgres/model.go b/pkg/code/data/treasury/postgres/model.go index adb3b0b9..14cca990 100644 --- a/pkg/code/data/treasury/postgres/model.go +++ b/pkg/code/data/treasury/postgres/model.go @@ -8,10 +8,10 @@ import ( "github.com/jmoiron/sqlx" + "github.com/code-payments/code-server/pkg/code/data/treasury" pgutil "github.com/code-payments/code-server/pkg/database/postgres" q "github.com/code-payments/code-server/pkg/database/query" splitter_token "github.com/code-payments/code-server/pkg/solana/splitter" - "github.com/code-payments/code-server/pkg/code/data/treasury" ) const ( @@ -136,7 +136,7 @@ func fromTreasuryPoolModel(obj *treasuryPoolModel) *treasury.Record { SolanaBlock: obj.SolanaBlock, - State: treasury.TreasuryPoolState(obj.State), + State: treasury.PoolState(obj.State), LastUpdatedAt: obj.LastUpdatedAt, } @@ -317,7 +317,7 @@ func dbGetByVault(ctx context.Context, db *sqlx.DB, vault string) (*treasuryPool return &res, nil } -func dbGetAllByState(ctx context.Context, db *sqlx.DB, state treasury.TreasuryPoolState, cursor q.Cursor, limit uint64, direction q.Ordering) ([]*treasuryPoolModel, error) { +func dbGetAllByState(ctx context.Context, db *sqlx.DB, state treasury.PoolState, cursor q.Cursor, limit uint64, direction q.Ordering) ([]*treasuryPoolModel, error) { res := []*treasuryPoolModel{} query := `SELECT id, data_version, name, address, bump, vault, vault_bump, authority, merkle_tree_levels, current_index, history_list_size, solana_block, state, last_updated_at diff --git a/pkg/code/data/treasury/postgres/store.go b/pkg/code/data/treasury/postgres/store.go index 39f9770f..07d9fb4d 100644 --- a/pkg/code/data/treasury/postgres/store.go +++ b/pkg/code/data/treasury/postgres/store.go @@ -6,8 +6,8 @@ import ( "github.com/jmoiron/sqlx" - "github.com/code-payments/code-server/pkg/database/query" "github.com/code-payments/code-server/pkg/code/data/treasury" + "github.com/code-payments/code-server/pkg/database/query" ) type store struct { @@ -69,7 +69,7 @@ func (s *store) GetByVault(ctx context.Context, vault string) (*treasury.Record, } // GetAllByState implements treasury.Store.GetAllByState -func (s *store) GetAllByState(ctx context.Context, state treasury.TreasuryPoolState, cursor query.Cursor, limit uint64, direction query.Ordering) ([]*treasury.Record, error) { +func (s *store) GetAllByState(ctx context.Context, state treasury.PoolState, cursor query.Cursor, limit uint64, direction query.Ordering) ([]*treasury.Record, error) { models, err := dbGetAllByState(ctx, s.db, state, cursor, limit, direction) if err != nil { return nil, err diff --git a/pkg/code/data/treasury/store.go b/pkg/code/data/treasury/store.go index 31497f71..8921250a 100644 --- a/pkg/code/data/treasury/store.go +++ b/pkg/code/data/treasury/store.go @@ -28,7 +28,7 @@ type Store interface { GetByVault(ctx context.Context, vault string) (*Record, error) // GetAllByState gets all treasury pool accounts in the provided state - GetAllByState(ctx context.Context, state TreasuryPoolState, cursor query.Cursor, limit uint64, direction query.Ordering) ([]*Record, error) + GetAllByState(ctx context.Context, state PoolState, cursor query.Cursor, limit uint64, direction query.Ordering) ([]*Record, error) // SaveFunding saves a funding history record for a treasury pool vault SaveFunding(ctx context.Context, record *FundingHistoryRecord) error diff --git a/pkg/code/data/treasury/tests/tests.go b/pkg/code/data/treasury/tests/tests.go index 1694e053..e8d84936 100644 --- a/pkg/code/data/treasury/tests/tests.go +++ b/pkg/code/data/treasury/tests/tests.go @@ -8,9 +8,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/code-payments/code-server/pkg/code/data/treasury" "github.com/code-payments/code-server/pkg/database/query" splitter_token "github.com/code-payments/code-server/pkg/solana/splitter" - "github.com/code-payments/code-server/pkg/code/data/treasury" ) func RunTests(t *testing.T, s treasury.Store, teardown func()) { @@ -51,7 +51,7 @@ func testTreasuryPoolHappyPath(t *testing.T, s treasury.Store) { SolanaBlock: 3, - State: treasury.TreasuryPoolStateAvailable, + State: treasury.PoolStateAvailable, } cloned := expected.Clone() @@ -102,7 +102,7 @@ func testTreasuryPoolHappyPath(t *testing.T, s treasury.Store) { } expected.CurrentIndex = 2 - expected.SolanaBlock += 1 + expected.SolanaBlock++ cloned = expected.Clone() require.NoError(t, s.Save(ctx, expected)) assertEquivalentTreasuryPoolRecords(t, cloned, expected) @@ -130,33 +130,33 @@ func testGetAllByState(t *testing.T, s treasury.Store) { t.Run("testGetAllByState", func(t *testing.T) { ctx := context.Background() - _, err := s.GetAllByState(ctx, treasury.TreasuryPoolStateAvailable, query.EmptyCursor, 10, query.Ascending) + _, err := s.GetAllByState(ctx, treasury.PoolStateAvailable, query.EmptyCursor, 10, query.Ascending) assert.Equal(t, treasury.ErrTreasuryPoolNotFound, err) expected := []*treasury.Record{ - {DataVersion: splitter_token.DataVersion1, Name: "name1", Address: "treasury1", Vault: "vault1", Authority: "code", MerkleTreeLevels: 32, CurrentIndex: 0, HistoryListSize: 1, HistoryList: []string{"root1"}, SolanaBlock: 1, State: treasury.TreasuryPoolStateAvailable}, - {DataVersion: splitter_token.DataVersion1, Name: "name2", Address: "treasury2", Vault: "vault2", Authority: "code", MerkleTreeLevels: 32, CurrentIndex: 0, HistoryListSize: 1, HistoryList: []string{"root2"}, SolanaBlock: 2, State: treasury.TreasuryPoolStateAvailable}, - {DataVersion: splitter_token.DataVersion1, Name: "name3", Address: "treasury3", Vault: "vault3", Authority: "code", MerkleTreeLevels: 32, CurrentIndex: 0, HistoryListSize: 1, HistoryList: []string{"root3"}, SolanaBlock: 3, State: treasury.TreasuryPoolStateAvailable}, - {DataVersion: splitter_token.DataVersion1, Name: "name4", Address: "treasury4", Vault: "vault4", Authority: "code", MerkleTreeLevels: 32, CurrentIndex: 0, HistoryListSize: 1, HistoryList: []string{"root4"}, SolanaBlock: 4, State: treasury.TreasuryPoolStateDeprecated}, - {DataVersion: splitter_token.DataVersion1, Name: "name5", Address: "treasury5", Vault: "vault5", Authority: "code", MerkleTreeLevels: 32, CurrentIndex: 0, HistoryListSize: 1, HistoryList: []string{"root5"}, SolanaBlock: 5, State: treasury.TreasuryPoolStateDeprecated}, + {DataVersion: splitter_token.DataVersion1, Name: "name1", Address: "treasury1", Vault: "vault1", Authority: "code", MerkleTreeLevels: 32, CurrentIndex: 0, HistoryListSize: 1, HistoryList: []string{"root1"}, SolanaBlock: 1, State: treasury.PoolStateAvailable}, + {DataVersion: splitter_token.DataVersion1, Name: "name2", Address: "treasury2", Vault: "vault2", Authority: "code", MerkleTreeLevels: 32, CurrentIndex: 0, HistoryListSize: 1, HistoryList: []string{"root2"}, SolanaBlock: 2, State: treasury.PoolStateAvailable}, + {DataVersion: splitter_token.DataVersion1, Name: "name3", Address: "treasury3", Vault: "vault3", Authority: "code", MerkleTreeLevels: 32, CurrentIndex: 0, HistoryListSize: 1, HistoryList: []string{"root3"}, SolanaBlock: 3, State: treasury.PoolStateAvailable}, + {DataVersion: splitter_token.DataVersion1, Name: "name4", Address: "treasury4", Vault: "vault4", Authority: "code", MerkleTreeLevels: 32, CurrentIndex: 0, HistoryListSize: 1, HistoryList: []string{"root4"}, SolanaBlock: 4, State: treasury.PoolStateDeprecated}, + {DataVersion: splitter_token.DataVersion1, Name: "name5", Address: "treasury5", Vault: "vault5", Authority: "code", MerkleTreeLevels: 32, CurrentIndex: 0, HistoryListSize: 1, HistoryList: []string{"root5"}, SolanaBlock: 5, State: treasury.PoolStateDeprecated}, } for _, record := range expected { require.NoError(t, s.Save(ctx, record)) } - _, err = s.GetAllByState(ctx, treasury.TreasuryPoolStateUnknown, query.EmptyCursor, 10, query.Ascending) + _, err = s.GetAllByState(ctx, treasury.PoolStateUnknown, query.EmptyCursor, 10, query.Ascending) assert.Equal(t, treasury.ErrTreasuryPoolNotFound, err) - actual, err := s.GetAllByState(ctx, treasury.TreasuryPoolStateAvailable, query.EmptyCursor, 10, query.Ascending) + actual, err := s.GetAllByState(ctx, treasury.PoolStateAvailable, query.EmptyCursor, 10, query.Ascending) require.NoError(t, err) assert.Len(t, actual, 3) - actual, err = s.GetAllByState(ctx, treasury.TreasuryPoolStateDeprecated, query.EmptyCursor, 10, query.Ascending) + actual, err = s.GetAllByState(ctx, treasury.PoolStateDeprecated, query.EmptyCursor, 10, query.Ascending) require.NoError(t, err) assert.Len(t, actual, 2) // Check items (asc) - actual, err = s.GetAllByState(ctx, treasury.TreasuryPoolStateAvailable, query.EmptyCursor, 5, query.Ascending) + actual, err = s.GetAllByState(ctx, treasury.PoolStateAvailable, query.EmptyCursor, 5, query.Ascending) require.NoError(t, err) require.Len(t, actual, 3) assert.Equal(t, "treasury1", actual[0].Address) @@ -164,7 +164,7 @@ func testGetAllByState(t *testing.T, s treasury.Store) { assert.Equal(t, "treasury3", actual[2].Address) // Check items (desc) - actual, err = s.GetAllByState(ctx, treasury.TreasuryPoolStateAvailable, query.EmptyCursor, 5, query.Descending) + actual, err = s.GetAllByState(ctx, treasury.PoolStateAvailable, query.EmptyCursor, 5, query.Descending) require.NoError(t, err) require.Len(t, actual, 3) assert.Equal(t, "treasury3", actual[0].Address) @@ -172,28 +172,28 @@ func testGetAllByState(t *testing.T, s treasury.Store) { assert.Equal(t, "treasury1", actual[2].Address) // Check items (asc + limit) - actual, err = s.GetAllByState(ctx, treasury.TreasuryPoolStateAvailable, query.EmptyCursor, 2, query.Ascending) + actual, err = s.GetAllByState(ctx, treasury.PoolStateAvailable, query.EmptyCursor, 2, query.Ascending) require.NoError(t, err) require.Len(t, actual, 2) assert.Equal(t, "treasury1", actual[0].Address) assert.Equal(t, "treasury2", actual[1].Address) // Check items (desc + limit) - actual, err = s.GetAllByState(ctx, treasury.TreasuryPoolStateAvailable, query.EmptyCursor, 2, query.Descending) + actual, err = s.GetAllByState(ctx, treasury.PoolStateAvailable, query.EmptyCursor, 2, query.Descending) require.NoError(t, err) require.Len(t, actual, 2) assert.Equal(t, "treasury3", actual[0].Address) assert.Equal(t, "treasury2", actual[1].Address) // Check items (asc + cursor) - actual, err = s.GetAllByState(ctx, treasury.TreasuryPoolStateAvailable, query.ToCursor(1), 5, query.Ascending) + actual, err = s.GetAllByState(ctx, treasury.PoolStateAvailable, query.ToCursor(1), 5, query.Ascending) require.NoError(t, err) require.Len(t, actual, 2) assert.Equal(t, "treasury2", actual[0].Address) assert.Equal(t, "treasury3", actual[1].Address) // Check items (desc + cursor) - actual, err = s.GetAllByState(ctx, treasury.TreasuryPoolStateAvailable, query.ToCursor(3), 5, query.Descending) + actual, err = s.GetAllByState(ctx, treasury.PoolStateAvailable, query.ToCursor(3), 5, query.Descending) require.NoError(t, err) require.Len(t, actual, 2) assert.Equal(t, "treasury2", actual[0].Address) diff --git a/pkg/code/data/treasury/treasury.go b/pkg/code/data/treasury/treasury.go index 7e65928c..096a7d05 100644 --- a/pkg/code/data/treasury/treasury.go +++ b/pkg/code/data/treasury/treasury.go @@ -10,13 +10,13 @@ import ( splitter_token "github.com/code-payments/code-server/pkg/solana/splitter" ) -type TreasuryPoolState uint8 +type PoolState uint8 type FundingState uint8 const ( - TreasuryPoolStateUnknown TreasuryPoolState = iota - TreasuryPoolStateAvailable - TreasuryPoolStateDeprecated + PoolStateUnknown PoolState = iota + PoolStateAvailable + PoolStateDeprecated ) const ( @@ -49,7 +49,7 @@ type Record struct { SolanaBlock uint64 - State TreasuryPoolState // currently managed manually + State PoolState // currently managed manually LastUpdatedAt time.Time } @@ -256,11 +256,11 @@ func (r *Record) CopyTo(dst *Record) { dst.LastUpdatedAt = r.LastUpdatedAt } -func (s TreasuryPoolState) String() string { +func (s PoolState) String() string { switch s { - case TreasuryPoolStateAvailable: + case PoolStateAvailable: return "available" - case TreasuryPoolStateDeprecated: + case PoolStateDeprecated: return "deprecated" } return "unknown" diff --git a/pkg/code/data/twitter/tests/tests.go b/pkg/code/data/twitter/tests/tests.go index 11230d06..4d715e85 100644 --- a/pkg/code/data/twitter/tests/tests.go +++ b/pkg/code/data/twitter/tests/tests.go @@ -190,7 +190,6 @@ func testGetStaleUsers(t *testing.T, s twitter.Store) { require.NoError(t, err) require.Len(t, res, 1) assert.Equal(t, "username0", res[0].Username) - }) } diff --git a/pkg/code/data/user/identity/memory/store.go b/pkg/code/data/user/identity/memory/store.go index f08f0a82..0b3c0676 100644 --- a/pkg/code/data/user/identity/memory/store.go +++ b/pkg/code/data/user/identity/memory/store.go @@ -41,7 +41,7 @@ func (s *store) Put(ctx context.Context, record *user_identity.Record) error { return user_identity.ErrAlreadyExists } - copy := &user_identity.Record{ + cpy := &user_identity.Record{ ID: record.ID, View: &user.View{ PhoneNumber: record.View.PhoneNumber, @@ -51,14 +51,14 @@ func (s *store) Put(ctx context.Context, record *user_identity.Record) error { CreatedAt: record.CreatedAt, } - s.usersByID[record.ID.String()] = copy - s.usersByPhoneNumber[*record.View.PhoneNumber] = copy + s.usersByID[record.ID.String()] = cpy + s.usersByPhoneNumber[*record.View.PhoneNumber] = cpy return nil } // GetByID implements user_identity.Store.GetByID -func (s *store) GetByID(ctx context.Context, id *user.UserID) (*user_identity.Record, error) { +func (s *store) GetByID(ctx context.Context, id *user.Id) (*user_identity.Record, error) { s.mu.RLock() defer s.mu.RUnlock() diff --git a/pkg/code/data/user/identity/postgres/model.go b/pkg/code/data/user/identity/postgres/model.go index a4c252f9..5a119ca6 100644 --- a/pkg/code/data/user/identity/postgres/model.go +++ b/pkg/code/data/user/identity/postgres/model.go @@ -78,7 +78,7 @@ func (m *model) dbSave(ctx context.Context, db *sqlx.DB) error { return pgutil.CheckUniqueViolation(err, user_identity.ErrAlreadyExists) } -func dbGetByID(ctx context.Context, db *sqlx.DB, id *user.UserID) (*model, error) { +func dbGetByID(ctx context.Context, db *sqlx.DB, id *user.Id) (*model, error) { query := `SELECT id, user_id, phone_number, is_staff_user, is_banned, created_at FROM ` + tableName + ` WHERE user_id = $1` diff --git a/pkg/code/data/user/identity/postgres/model_test.go b/pkg/code/data/user/identity/postgres/model_test.go index 497af2c9..19942baf 100644 --- a/pkg/code/data/user/identity/postgres/model_test.go +++ b/pkg/code/data/user/identity/postgres/model_test.go @@ -14,7 +14,7 @@ import ( func TestModelConversion(t *testing.T) { phoneNumber := "+12223334444" record := &user_identity.Record{ - ID: user.NewUserID(), + ID: user.NewID(), View: &user.View{ PhoneNumber: &phoneNumber, }, diff --git a/pkg/code/data/user/identity/postgres/store.go b/pkg/code/data/user/identity/postgres/store.go index fd2fef2b..c37bab54 100644 --- a/pkg/code/data/user/identity/postgres/store.go +++ b/pkg/code/data/user/identity/postgres/store.go @@ -31,7 +31,7 @@ func (s *store) Put(ctx context.Context, record *user_identity.Record) error { } // GetByID implements user_identity.Store.GetByID -func (s *store) GetByID(ctx context.Context, id *user.UserID) (*user_identity.Record, error) { +func (s *store) GetByID(ctx context.Context, id *user.Id) (*user_identity.Record, error) { model, err := dbGetByID(ctx, s.db, id) if err != nil { return nil, err diff --git a/pkg/code/data/user/identity/store.go b/pkg/code/data/user/identity/store.go index 76b9e9c2..c9e856e1 100644 --- a/pkg/code/data/user/identity/store.go +++ b/pkg/code/data/user/identity/store.go @@ -19,7 +19,7 @@ var ( // User is the highest order of a form of identity. type Record struct { - ID *user.UserID + ID *user.Id View *user.View IsStaffUser bool IsBanned bool @@ -32,7 +32,7 @@ type Store interface { Put(ctx context.Context, record *Record) error // GetByID fetches a user by its ID. - GetByID(ctx context.Context, id *user.UserID) (*Record, error) + GetByID(ctx context.Context, id *user.Id) (*Record, error) // GetByView fetches a user by a view. GetByView(ctx context.Context, view *user.View) (*Record, error) diff --git a/pkg/code/data/user/identity/tests/tests.go b/pkg/code/data/user/identity/tests/tests.go index 8ebeecac..9814d7df 100644 --- a/pkg/code/data/user/identity/tests/tests.go +++ b/pkg/code/data/user/identity/tests/tests.go @@ -28,7 +28,7 @@ func testHappyPath(t *testing.T, s user_identity.Store) { phoneNumber := "+12223334444" record := &user_identity.Record{ - ID: user.NewUserID(), + ID: user.NewID(), View: &user.View{ PhoneNumber: &phoneNumber, }, diff --git a/pkg/code/data/user/model.go b/pkg/code/data/user/model.go index 96973642..b4fe4d0b 100644 --- a/pkg/code/data/user/model.go +++ b/pkg/code/data/user/model.go @@ -10,8 +10,8 @@ import ( "github.com/code-payments/code-server/pkg/phone" ) -// UserID uniquely identifies a user -type UserID struct { +// Id uniquely identifies a user +type Id struct { id uuid.UUID } @@ -34,34 +34,34 @@ type View struct { PhoneNumber *string } -// NewUserID returns a new random UserID -func NewUserID() *UserID { - return &UserID{ +// NewID returns a new random Id +func NewID() *Id { + return &Id{ id: uuid.New(), } } -// GetUserIDFromProto returns a UserID from the protobuf message -func GetUserIDFromProto(proto *commonpb.UserId) (*UserID, error) { +// GetUserIDFromProto returns a Id from the protobuf message +func GetUserIDFromProto(proto *commonpb.UserId) (*Id, error) { id, err := uuid.FromBytes(proto.Value) if err != nil { return nil, err } - return &UserID{id}, nil + return &Id{id}, nil } -// GetUserIDFromString parses a UserID from a string value -func GetUserIDFromString(value string) (*UserID, error) { +// GetUserIDFromString parses a Id from a string value +func GetUserIDFromString(value string) (*Id, error) { id, err := uuid.Parse(value) if err != nil { return nil, err } - return &UserID{id}, nil + return &Id{id}, nil } -// Validate validate a UserID -func (id *UserID) Validate() error { +// Validate validate a Id +func (id *Id) Validate() error { if id == nil { return errors.New("user id is nil") } @@ -74,13 +74,13 @@ func (id *UserID) Validate() error { return nil } -// String returns the string form of a UserID -func (id *UserID) String() string { +// String returns the string form of a Id +func (id *Id) String() string { return id.id.String() } -// Proto returns a UserID into its protobuf message form -func (id *UserID) Proto() *commonpb.UserId { +// Proto returns a Id into its protobuf message form +func (id *Id) Proto() *commonpb.UserId { return &commonpb.UserId{ Value: id.id[:], } @@ -93,7 +93,7 @@ func NewDataContainerID() *DataContainerID { } } -// GetDataContainerIDFromProto returns a UserID from the protobuf message +// GetDataContainerIDFromProto returns a Id from the protobuf message func GetDataContainerIDFromProto(proto *commonpb.DataContainerId) (*DataContainerID, error) { id, err := uuid.FromBytes(proto.Value) if err != nil { @@ -131,7 +131,7 @@ func (id *DataContainerID) String() string { return id.id.String() } -// Proto returns a UserID into its protobuf message form +// Proto returns a Id into its protobuf message form func (id *DataContainerID) Proto() *commonpb.DataContainerId { return &commonpb.DataContainerId{ Value: id.id[:], diff --git a/pkg/code/data/user/model_test.go b/pkg/code/data/user/model_test.go index bf758549..d06c3931 100644 --- a/pkg/code/data/user/model_test.go +++ b/pkg/code/data/user/model_test.go @@ -28,11 +28,11 @@ func TestUserIDProtoConverstion(t *testing.T) { } func TestUserIDValidation(t *testing.T) { - var nilUserID *UserID - emptyUserID := &UserID{} + var nilUserID *Id + emptyUserID := &Id{} assert.Error(t, nilUserID.Validate()) assert.Error(t, emptyUserID.Validate()) - assert.NoError(t, NewUserID().Validate()) + assert.NoError(t, NewID().Validate()) } func TestDataContainerIDStringConverstion(t *testing.T) { diff --git a/pkg/code/data/user/storage/memory/store.go b/pkg/code/data/user/storage/memory/store.go index 8390a6bb..177918ff 100644 --- a/pkg/code/data/user/storage/memory/store.go +++ b/pkg/code/data/user/storage/memory/store.go @@ -43,7 +43,7 @@ func (s *store) Put(ctx context.Context, container *user_storage.Record) error { return user_storage.ErrAlreadyExists } - copy := &user_storage.Record{ + cpy := &user_storage.Record{ ID: container.ID, OwnerAccount: container.OwnerAccount, IdentifyingFeatures: &user.IdentifyingFeatures{ @@ -52,8 +52,8 @@ func (s *store) Put(ctx context.Context, container *user_storage.Record) error { CreatedAt: container.CreatedAt, } - s.dataContainersByID[container.ID.String()] = copy - s.dataContainersByFeatures[key] = copy + s.dataContainersByID[container.ID.String()] = cpy + s.dataContainersByFeatures[key] = cpy return nil } diff --git a/pkg/code/data/vault/key_test.go b/pkg/code/data/vault/key_test.go index dbbb8ce7..f6d4c719 100644 --- a/pkg/code/data/vault/key_test.go +++ b/pkg/code/data/vault/key_test.go @@ -10,7 +10,6 @@ import ( ) func TestCreateKey(t *testing.T) { - for i := 0; i < 100; i++ { actual, err := vault.CreateKey() require.NoError(t, err) @@ -22,7 +21,6 @@ func TestCreateKey(t *testing.T) { _, err = base58.Decode(actual.PrivateKey) require.NoError(t, err) } - } func TestGrindKey(t *testing.T) { diff --git a/pkg/code/data/vault/memory/store.go b/pkg/code/data/vault/memory/store.go index 956ff86a..1d1ba987 100644 --- a/pkg/code/data/vault/memory/store.go +++ b/pkg/code/data/vault/memory/store.go @@ -5,8 +5,8 @@ import ( "sort" "sync" - "github.com/code-payments/code-server/pkg/database/query" "github.com/code-payments/code-server/pkg/code/data/vault" + "github.com/code-payments/code-server/pkg/database/query" ) type store struct { @@ -146,7 +146,6 @@ func (s *store) Get(ctx context.Context, sig string) (*vault.Record, error) { defer s.mu.Unlock() if item := s.findPublicKey(sig); item != nil { - val, err := vault.Decrypt(item.PrivateKey, item.PublicKey) if err != nil { return nil, err diff --git a/pkg/code/data/vault/postgres/model.go b/pkg/code/data/vault/postgres/model.go index 94078827..e0e01d4c 100644 --- a/pkg/code/data/vault/postgres/model.go +++ b/pkg/code/data/vault/postgres/model.go @@ -44,7 +44,6 @@ func toKeyModel(obj *vault.Record) (*vaultModel, error) { } func fromKeyModel(obj *vaultModel) *vault.Record { - return &vault.Record{ Id: uint64(obj.Id.Int64), PublicKey: obj.PublicKey, diff --git a/pkg/code/data/vault/postgres/store.go b/pkg/code/data/vault/postgres/store.go index 2ca627c5..4e707d63 100644 --- a/pkg/code/data/vault/postgres/store.go +++ b/pkg/code/data/vault/postgres/store.go @@ -4,8 +4,8 @@ import ( "context" "database/sql" - "github.com/code-payments/code-server/pkg/database/query" "github.com/code-payments/code-server/pkg/code/data/vault" + "github.com/code-payments/code-server/pkg/database/query" "github.com/jmoiron/sqlx" ) @@ -81,7 +81,6 @@ func (s *store) GetAllByState(ctx context.Context, state vault.State, cursor que keys := make([]*vault.Record, len(models)) for i, model := range models { - plaintext, err := vault.Decrypt(model.PrivateKey, model.PublicKey) if err != nil { return nil, err diff --git a/pkg/code/data/vault/tests/tests.go b/pkg/code/data/vault/tests/tests.go index 49f2a74a..bfffe0dc 100644 --- a/pkg/code/data/vault/tests/tests.go +++ b/pkg/code/data/vault/tests/tests.go @@ -5,8 +5,8 @@ import ( "testing" "time" - "github.com/code-payments/code-server/pkg/database/query" "github.com/code-payments/code-server/pkg/code/data/vault" + "github.com/code-payments/code-server/pkg/database/query" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -16,6 +16,7 @@ func RunTests(t *testing.T, s vault.Store, teardown func()) { testRoundTrip, testUpdate, testGetAllByState, + testGetCount, } { tf(t, s) teardown() diff --git a/pkg/code/data/webhook/memory/store.go b/pkg/code/data/webhook/memory/store.go index 754ae807..36f807ec 100644 --- a/pkg/code/data/webhook/memory/store.go +++ b/pkg/code/data/webhook/memory/store.go @@ -5,8 +5,8 @@ import ( "sync" "time" - "github.com/code-payments/code-server/pkg/pointer" "github.com/code-payments/code-server/pkg/code/data/webhook" + "github.com/code-payments/code-server/pkg/pointer" ) type store struct { @@ -32,15 +32,15 @@ func (s *store) Put(_ context.Context, data *webhook.Record) error { s.last++ if item := s.find(data); item != nil { return webhook.ErrAlreadyExists - } else { - if data.Id == 0 { - data.Id = s.last - } - data.CreatedAt = time.Now() + } - cloned := data.Clone() - s.records = append(s.records, &cloned) + if data.Id == 0 { + data.Id = s.last } + data.CreatedAt = time.Now() + + cloned := data.Clone() + s.records = append(s.records, &cloned) return nil } diff --git a/pkg/code/lawenforcement/anti_money_laundering_test.go b/pkg/code/lawenforcement/anti_money_laundering_test.go index 59a13197..bdb393f0 100644 --- a/pkg/code/lawenforcement/anti_money_laundering_test.go +++ b/pkg/code/lawenforcement/anti_money_laundering_test.go @@ -340,19 +340,20 @@ func setupAmlTest(t *testing.T) (env amlTestEnv) { testutil.SetupRandomSubsidizer(t, env.data) - env.data.ImportExchangeRates(env.ctx, ¤cy.MultiRateRecord{ + err := env.data.ImportExchangeRates(env.ctx, ¤cy.MultiRateRecord{ Time: time.Now(), Rates: map[string]float64{ string(currency_lib.USD): 0.1, }, }) + require.NoError(t, err) return env } func setupPhoneUser(t *testing.T, env amlTestEnv, phoneNumber string) { require.NoError(t, env.data.PutUser(env.ctx, &identity.Record{ - ID: user.NewUserID(), + ID: user.NewID(), View: &user.View{ PhoneNumber: &phoneNumber, }, diff --git a/pkg/code/localization/keys.go b/pkg/code/localization/keys.go index 10a78abc..771c1170 100644 --- a/pkg/code/localization/keys.go +++ b/pkg/code/localization/keys.go @@ -112,7 +112,7 @@ func LoadKeys(directory string) error { } // LoadTestKeys is a utility for injecting test localization keys -func LoadTestKeys(kvsByLocale map[language.Tag]map[string]string) { +func LoadTestKeys(kvsByLocale map[language.Tag]map[string]string) error { bundleMu.Lock() defer bundleMu.Unlock() @@ -126,10 +126,14 @@ func LoadTestKeys(kvsByLocale map[language.Tag]map[string]string) { Other: v, }) } - newBundle.AddMessages(locale, messages...) + + if err := newBundle.AddMessages(locale, messages...); err != nil { + return err + } } bundle = newBundle + return nil } // ResetKeys resets localization to an empty mapping diff --git a/pkg/code/push/badge_count.go b/pkg/code/push/badge_count.go index b2465be4..31405c80 100644 --- a/pkg/code/push/badge_count.go +++ b/pkg/code/push/badge_count.go @@ -2,8 +2,8 @@ package push import ( "context" + "fmt" - "github.com/pkg/errors" "github.com/sirupsen/logrus" "github.com/code-payments/code-server/pkg/code/common" @@ -84,12 +84,19 @@ func UpdateBadgeCount( ) if pushErr != nil { - log.WithError(err).Warn("failure sending push notification") - isValid, err := onPushError(ctx, data, pusher, pushTokenRecord) - if isValid { - return errors.Wrap(pushErr, "error pushing to valid token") - } else if err != nil { - return errors.Wrap(err, "error handling push error") + isValid, onPushErr := onPushError(ctx, data, pusher, pushTokenRecord) + + log.WithError(pushErr). + WithFields(logrus.Fields{ + "on_push_error": onPushErr, + "is_valid": isValid, + }). + Warn("failure sending push notification") + + if onPushErr != nil { + return fmt.Errorf("failed to handle push error (%w): %w", pushErr, onPushErr) + } else if isValid { + return fmt.Errorf("failed to push to valid token: %w", pushErr) } } } diff --git a/pkg/code/push/data.go b/pkg/code/push/data.go index bfe7fc7f..dd573d10 100644 --- a/pkg/code/push/data.go +++ b/pkg/code/push/data.go @@ -2,7 +2,6 @@ package push import ( "context" - "github.com/sirupsen/logrus" "github.com/code-payments/code-server/pkg/code/common" @@ -19,57 +18,6 @@ const ( chatMessageDataPush dataPushType = "ChatMessage" ) -// sendRawDataPushNotificationToOwner is a generic utility for sending raw data push -// notification to the devices linked to an owner account. -// -// todo: Duplicated code with other send push utitilies -func sendRawDataPushNotificationToOwner( - ctx context.Context, - data code_data.Provider, - pusher push_lib.Provider, - owner *common.Account, - notificationType dataPushType, - kvs map[string]string, -) error { - log := logrus.StandardLogger().WithFields(logrus.Fields{ - "method": "sendRawDataPushNotificationToOwner", - "owner": owner.PublicKey().ToBase58(), - }) - - kvs[dataPushTypeKey] = string(notificationType) - - pushTokenRecords, err := getPushTokensForOwner(ctx, data, owner) - if err != nil { - log.WithError(err).Warn("failure getting push tokens for owner") - return err - } - - seenPushTokens := make(map[string]struct{}) - for _, pushTokenRecord := range pushTokenRecords { - // Dedup push tokens, since they may appear more than once per app install - if _, ok := seenPushTokens[pushTokenRecord.PushToken]; ok { - continue - } - - log := log.WithField("push_token", pushTokenRecord.PushToken) - - // Try push - err := pusher.SendDataPush( - ctx, - pushTokenRecord.PushToken, - kvs, - ) - - if err != nil { - log.WithError(err).Warn("failure sending push notification") - onPushError(ctx, data, pusher, pushTokenRecord) - } - - seenPushTokens[pushTokenRecord.PushToken] = struct{}{} - } - return nil -} - // sendMutableNotificationToOwner is a generic utility for sending mutable // push notification to the devices linked to an owner account. It's a // special data push where the notification content is replaced by the contents @@ -129,8 +77,13 @@ func sendMutableNotificationToOwner( } if err != nil { - log.WithError(err).Warn("failure sending push notification") - onPushError(ctx, data, pusher, pushTokenRecord) + isValid, onPushErr := onPushError(ctx, data, pusher, pushTokenRecord) + log.WithError(err). + WithFields(logrus.Fields{ + "on_push_error": onPushErr, + "is_valid": isValid, + }). + Warn("failure sending push notification") } seenPushTokens[pushTokenRecord.PushToken] = struct{}{} diff --git a/pkg/code/push/text.go b/pkg/code/push/text.go index 80930bce..e53b1e00 100644 --- a/pkg/code/push/text.go +++ b/pkg/code/push/text.go @@ -50,8 +50,12 @@ func sendBasicPushNotificationToOwner( ) if err != nil { - log.WithError(err).Warn("failure sending push notification") - onPushError(ctx, data, pusher, pushTokenRecord) + isValid, onPushErr := onPushError(ctx, data, pusher, pushTokenRecord) + log.WithError(err). + WithFields(logrus.Fields{ + "cleanup_error": onPushErr, + "is_valid": isValid, + }).Warn("failed to send push notification (best effort)") } seenPushTokens[pushTokenRecord.PushToken] = struct{}{} diff --git a/pkg/code/push/util.go b/pkg/code/push/util.go index 55e33556..b72da9b5 100644 --- a/pkg/code/push/util.go +++ b/pkg/code/push/util.go @@ -2,6 +2,8 @@ package push import ( "context" + "fmt" + "github.com/sirupsen/logrus" "github.com/pkg/errors" @@ -34,8 +36,18 @@ func getPushTokensForOwner(ctx context.Context, data code_data.Provider, owner * func onPushError(ctx context.Context, data code_data.Provider, pusher push_lib.Provider, pushTokenRecord *push_data.Record) (bool, error) { // On failure, verify token validity, and cleanup if necessary isValid, err := pusher.IsValidPushToken(ctx, pushTokenRecord.PushToken) - if err == nil && !isValid { - data.DeletePushToken(ctx, pushTokenRecord.PushToken) + if isValid { + return true, nil + } else if err != nil { + return false, fmt.Errorf("failed to check if push token is valid: %w", err) + } + + if err := data.DeletePushToken(ctx, pushTokenRecord.PushToken); err != nil { + logrus.StandardLogger().WithFields(logrus.Fields{ + "method": "onPushError", + "token": pushTokenRecord.PushToken, + }).WithError(err).Warn("failed to cleanup invalid push token (best effort)") } - return isValid, err + + return false, nil } diff --git a/pkg/code/server/grpc/account/server.go b/pkg/code/server/grpc/account/server.go index 2d994c52..c3326197 100644 --- a/pkg/code/server/grpc/account/server.go +++ b/pkg/code/server/grpc/account/server.go @@ -22,7 +22,6 @@ import ( "github.com/code-payments/code-server/pkg/code/data/action" "github.com/code-payments/code-server/pkg/code/data/intent" "github.com/code-payments/code-server/pkg/grpc/client" - "github.com/code-payments/code-server/pkg/kin" timelock_token_v1 "github.com/code-payments/code-server/pkg/solana/timelock/v1" ) @@ -581,7 +580,3 @@ func (s *server) getOriginalGiftCardExchangeData(ctx context.Context, records *c Quarks: intentRecord.SendPrivatePaymentMetadata.Quantity, }, nil } - -func hideDust(quarks uint64) uint64 { - return kin.ToQuarks(kin.FromQuarks(quarks)) -} diff --git a/pkg/code/server/grpc/account/server_test.go b/pkg/code/server/grpc/account/server_test.go index a1274a79..50c2ffda 100644 --- a/pkg/code/server/grpc/account/server_test.go +++ b/pkg/code/server/grpc/account/server_test.go @@ -8,11 +8,11 @@ import ( "testing" "time" - "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/protobuf/proto" accountpb "github.com/code-payments/code-protobuf-api/generated/go/account/v1" commonpb "github.com/code-payments/code-protobuf-api/generated/go/common/v1" @@ -133,7 +133,7 @@ func TestIsCodeAccount_LegacyPrimary2022Migration_HappyPath(t *testing.T) { assert.Equal(t, accountpb.IsCodeAccountResponse_OK, resp.Result) legacyAccountRecords.Timelock.VaultState = timelock_token_v1.StateClosed - legacyAccountRecords.Timelock.Block += 1 + legacyAccountRecords.Timelock.Block++ require.NoError(t, env.data.SaveTimelock(env.ctx, legacyAccountRecords.Timelock)) resp, err = env.client.IsCodeAccount(env.ctx, req) @@ -177,7 +177,7 @@ func TestIsCodeAccount_NotManagedByCode(t *testing.T) { assert.Equal(t, accountpb.IsCodeAccountResponse_OK, resp.Result) allAccountRecords[i].Timelock.VaultState = unmanagedState - allAccountRecords[i].Timelock.Block += 1 + allAccountRecords[i].Timelock.Block++ require.NoError(t, env.data.SaveTimelock(env.ctx, allAccountRecords[i].Timelock)) resp, err = env.client.IsCodeAccount(env.ctx, req) @@ -456,7 +456,7 @@ func TestGetTokenAccountInfos_RemoteSendGiftCard_HappyPath(t *testing.T) { } userIdentityRecord := &user_identity.Record{ - ID: user.NewUserID(), + ID: user.NewID(), View: &user.View{ PhoneNumber: &phoneNumber, }, @@ -534,7 +534,7 @@ func TestGetTokenAccountInfos_RemoteSendGiftCard_HappyPath(t *testing.T) { } accountRecords.Timelock.VaultState = tc.timelockState - accountRecords.Timelock.Block += 1 + accountRecords.Timelock.Block++ require.NoError(t, env.data.SaveTimelock(env.ctx, accountRecords.Timelock)) resp, err := env.client.GetTokenAccountInfos(env.ctx, req) @@ -608,7 +608,7 @@ func TestGetTokenAccountInfos_BlockchainState(t *testing.T) { accountRecords := getDefaultTestAccountRecords(t, env, ownerAccount, ownerAccount, 0, commonpb.AccountType_PRIMARY) accountRecords.Timelock.VaultState = tc.timelockState - accountRecords.Timelock.Block += 1 + accountRecords.Timelock.Block++ require.NoError(t, env.data.CreateAccountInfo(env.ctx, accountRecords.General)) require.NoError(t, env.data.SaveTimelock(env.ctx, accountRecords.Timelock)) @@ -889,7 +889,7 @@ func TestGetTokenAccountInfos_LegacyPrimary2022Migration_AccountClosed(t *testin assert.Len(t, resp.TokenAccountInfos, 1) accountRecords.Timelock.VaultState = timelock_token_v1.StateClosed - accountRecords.Timelock.Block += 1 + accountRecords.Timelock.Block++ require.NoError(t, env.data.SaveTimelock(env.ctx, accountRecords.Timelock)) resp, err = env.client.GetTokenAccountInfos(env.ctx, req) @@ -1115,7 +1115,7 @@ func setupAccountRecords(t *testing.T, env testEnv, ownerAccount, authorityAccou if accountRecords.IsTimelock() { accountRecords.Timelock.VaultState = timelock_token_v1.StateLocked - accountRecords.Timelock.Block += 1 + accountRecords.Timelock.Block++ require.NoError(t, env.data.SaveTimelock(env.ctx, accountRecords.Timelock)) } diff --git a/pkg/code/server/grpc/badge/server_test.go b/pkg/code/server/grpc/badge/server_test.go index 70b02d30..e3443a1d 100644 --- a/pkg/code/server/grpc/badge/server_test.go +++ b/pkg/code/server/grpc/badge/server_test.go @@ -14,8 +14,6 @@ import ( badgepb "github.com/code-payments/code-protobuf-api/generated/go/badge/v1" commonpb "github.com/code-payments/code-protobuf-api/generated/go/common/v1" - memory_push "github.com/code-payments/code-server/pkg/push/memory" - "github.com/code-payments/code-server/pkg/testutil" auth_util "github.com/code-payments/code-server/pkg/code/auth" "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" @@ -25,6 +23,8 @@ import ( "github.com/code-payments/code-server/pkg/code/data/user" user_identity "github.com/code-payments/code-server/pkg/code/data/user/identity" user_storage "github.com/code-payments/code-server/pkg/code/data/user/storage" + memory_push "github.com/code-payments/code-server/pkg/push/memory" + "github.com/code-payments/code-server/pkg/testutil" ) func TestResetBadgeCount_HappyPath(t *testing.T) { @@ -45,7 +45,7 @@ func TestResetBadgeCount_HappyPath(t *testing.T) { assert.Equal(t, resp.Result, badgepb.ResetBadgeCountResponse_OK) env.assertBadgeCount(t, owner, 0) - env.data.AddToBadgeCount(env.ctx, owner.PublicKey().ToBase58(), 5) + require.NoError(t, env.data.AddToBadgeCount(env.ctx, owner.PublicKey().ToBase58(), 5)) env.assertBadgeCount(t, owner, 5) resp, err = env.client.ResetBadgeCount(env.ctx, req) @@ -108,7 +108,7 @@ func (e *testEnv) createUser(t *testing.T, owner *common.Account, phoneNumber st require.NoError(t, e.data.SavePhoneVerification(e.ctx, phoneVerificationRecord)) userIdentityRecord := &user_identity.Record{ - ID: user.NewUserID(), + ID: user.NewID(), View: &user.View{ PhoneNumber: &phoneNumber, }, diff --git a/pkg/code/server/grpc/chat/server.go b/pkg/code/server/grpc/chat/server.go index 362c0fd7..04662e4d 100644 --- a/pkg/code/server/grpc/chat/server.go +++ b/pkg/code/server/grpc/chat/server.go @@ -131,7 +131,7 @@ func (s *server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*ch skipUnreadCountQuery := chatRecord.IsUnsubscribed switch chatRecord.ChatType { - case chat.ChatTypeInternal: + case chat.TypeInternal: chatProperties, ok := chat_util.InternalChatProperties[chatRecord.ChatTitle] if !ok { log.Warnf("%s chat doesn't have properties defined", chatRecord.ChatTitle) @@ -157,7 +157,7 @@ func (s *server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*ch } protoMetadata.CanMute = chatProperties.CanMute protoMetadata.CanUnsubscribe = chatProperties.CanUnsubscribe - case chat.ChatTypeExternalApp: + case chat.TypeExternalApp: protoMetadata.Title = &chatpb.ChatMetadata_Domain{ Domain: &commonpb.Domain{ Value: chatRecord.ChatTitle, @@ -220,7 +220,7 @@ func (s *server) GetMessages(ctx context.Context, req *chatpb.GetMessagesRequest } log = log.WithField("owner_account", owner.PublicKey().ToBase58()) - chatId := chat.ChatIdFromProto(req.ChatId) + chatId := chat.IdFromProto(req.ChatId) log = log.WithField("chat_id", chatId.String()) signature := req.Signature @@ -347,7 +347,7 @@ func (s *server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerR } log = log.WithField("owner_account", owner.PublicKey().ToBase58()) - chatId := chat.ChatIdFromProto(req.ChatId) + chatId := chat.IdFromProto(req.ChatId) log = log.WithField("chat_id", chatId.String()) messageId := base58.Encode(req.Pointer.Value.Value) @@ -425,7 +425,7 @@ func (s *server) SetMuteState(ctx context.Context, req *chatpb.SetMuteStateReque } log = log.WithField("owner_account", owner.PublicKey().ToBase58()) - chatId := chat.ChatIdFromProto(req.ChatId) + chatId := chat.IdFromProto(req.ChatId) log = log.WithField("chat_id", chatId.String()) signature := req.Signature @@ -483,7 +483,7 @@ func (s *server) SetSubscriptionState(ctx context.Context, req *chatpb.SetSubscr } log = log.WithField("owner_account", owner.PublicKey().ToBase58()) - chatId := chat.ChatIdFromProto(req.ChatId) + chatId := chat.IdFromProto(req.ChatId) log = log.WithField("chat_id", chatId.String()) signature := req.Signature diff --git a/pkg/code/server/grpc/chat/server_test.go b/pkg/code/server/grpc/chat/server_test.go index f6dd6eff..d38035ca 100644 --- a/pkg/code/server/grpc/chat/server_test.go +++ b/pkg/code/server/grpc/chat/server_test.go @@ -6,12 +6,12 @@ import ( "testing" "time" - "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/text/language" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v1" @@ -41,12 +41,13 @@ func TestGetChatsAndMessages_HappyPath(t *testing.T) { owner := testutil.NewRandomAccount(t) env.setupUserWithLocale(t, owner, language.English) - localization.LoadTestKeys(map[language.Tag]map[string]string{ + err := localization.LoadTestKeys(map[language.Tag]map[string]string{ language.English: { localization.ChatTitleCodeTeam: "Code Team", "msg.body.key": "localized message body content", }, }) + require.NoError(t, err) defer localization.ResetKeys() testExternalAppDomain := "test.com" @@ -893,12 +894,12 @@ func setup(t *testing.T) (env *testEnv, cleanup func()) { } func (e *testEnv) sendExternalAppChatMessage(t *testing.T, msg *chatpb.ChatMessage, domain string, isVerified bool, recipient *common.Account) { - _, err := chat_util.SendChatMessage(e.ctx, e.data, domain, chat.ChatTypeExternalApp, isVerified, recipient, msg, false) + _, err := chat_util.SendChatMessage(e.ctx, e.data, domain, chat.TypeExternalApp, isVerified, recipient, msg, false) require.NoError(t, err) } func (e *testEnv) sendInternalChatMessage(t *testing.T, msg *chatpb.ChatMessage, chatTitle string, recipient *common.Account) { - _, err := chat_util.SendChatMessage(e.ctx, e.data, chatTitle, chat.ChatTypeInternal, true, recipient, msg, false) + _, err := chat_util.SendChatMessage(e.ctx, e.data, chatTitle, chat.TypeInternal, true, recipient, msg, false) require.NoError(t, err) } diff --git a/pkg/code/server/grpc/currency/currency.go b/pkg/code/server/grpc/currency/currency.go index 86f2151d..95affa54 100644 --- a/pkg/code/server/grpc/currency/currency.go +++ b/pkg/code/server/grpc/currency/currency.go @@ -41,7 +41,7 @@ func (s *currencyServer) GetAllRates(ctx context.Context, req *currencypb.GetAll var record *currency.MultiRateRecord if req.Timestamp != nil && req.Timestamp.AsTime().Before(time.Now().Add(-15*time.Minute)) { record, err = s.LoadExchangeRatesForTime(ctx, req.Timestamp.AsTime()) - } else if req.Timestamp == nil || req.Timestamp.AsTime().Sub(time.Now()) < time.Hour { + } else if req.Timestamp == nil || time.Until(req.Timestamp.AsTime()) < time.Hour { record, err = s.LoadExchangeRatesLatest(ctx) } else { return nil, status.Error(codes.InvalidArgument, "timestamp too far in the future") diff --git a/pkg/code/server/grpc/messaging/internal.go b/pkg/code/server/grpc/messaging/internal.go index cee25d81..6e19c9b9 100644 --- a/pkg/code/server/grpc/messaging/internal.go +++ b/pkg/code/server/grpc/messaging/internal.go @@ -27,7 +27,7 @@ const ( ) // todo: Similar to the common push package, we should put message creation and -// a proper client (ie. not tied to the server) in a common messaging package. +// a proper client (ie. not tied to the Server) in a common messaging package. type InternalMessageClient interface { // InternallyCreateMessage creates and forwards a message on a stream @@ -35,12 +35,12 @@ type InternalMessageClient interface { InternallyCreateMessage(ctx context.Context, rendezvousKey *common.Account, message *messagingpb.Message) (uuid.UUID, error) } -// Note: Assumes messages are generated in a RPC server where the messaging +// Note: Assumes messages are generated in a RPC Server where the messaging // service exists. This likely won't be a good assumption (eg. message generated // in a worker), but is good enough to enable some initial use cases (eg. payment // requests). This is mostly an optimization around not needing to create a gRPC -// client if the stream and message generation are on the same server. -func (s *server) InternallyCreateMessage(ctx context.Context, rendezvousKey *common.Account, message *messagingpb.Message) (uuid.UUID, error) { +// client if the stream and message generation are on the same Server. +func (s *Server) InternallyCreateMessage(ctx context.Context, rendezvousKey *common.Account, message *messagingpb.Message) (uuid.UUID, error) { if message.Id != nil { return uuid.Nil, errors.New("message.id is generated in InternallyCreateMessage") } @@ -86,7 +86,7 @@ func (s *server) InternallyCreateMessage(ctx context.Context, rendezvousKey *com } // Best effort attempt to forward the message to the active stream - retry.Retry( + attempts, err := retry.Retry( func() error { return s.internallyForwardMessage(ctx, &messagingpb.SendMessageRequest{ RendezvousKey: &messagingpb.RendezvousKey{ @@ -96,7 +96,7 @@ func (s *server) InternallyCreateMessage(ctx context.Context, rendezvousKey *com Signature: &commonpb.Signature{ // Needs to be set to pass validation, but won't be used. This // is only required for client-initiated messages. Rendezvous - // private keys are typically hidden from server. + // private keys are typically hidden from Server. // // todo: Different RPCs for public versus internal message sending. Value: make([]byte, 64), @@ -106,11 +106,20 @@ func (s *server) InternallyCreateMessage(ctx context.Context, rendezvousKey *com retry.Limit(5), retry.Backoff(backoff.BinaryExponential(100*time.Millisecond), 500*time.Millisecond), ) + if err != nil { + s.log. + WithError(err). + WithFields(logrus.Fields{ + "method": "InternallyCreateMessage", + "attempts": attempts, + }). + Warn("failed to forward message (best effort)") + } return id, nil } -func (s *server) internallyForwardMessage(ctx context.Context, req *messagingpb.SendMessageRequest) error { +func (s *Server) internallyForwardMessage(ctx context.Context, req *messagingpb.SendMessageRequest) error { streamKey := base58.Encode(req.RendezvousKey.Value) log := s.log.WithFields(logrus.Fields{ @@ -122,10 +131,10 @@ func (s *server) internallyForwardMessage(ctx context.Context, req *messagingpb. if err == nil { log := log.WithField("receiver_location", rendezvousRecord.Location) - // We got lucky and the receiver's stream is on the same RPC server as + // We got lucky and the receiver's stream is on the same RPC Server as // where the message is created. No forwarding between servers is required. // Note that we always use the rendezvous record as the source of truth - // instead of checking for an active stream on this server. This server's + // instead of checking for an active stream on this Server. This Server's // active stream may not be holding the lock, which can only be determined // by who set the location in the rendezvous record. if rendezvousRecord.Location == s.broadcastAddress { @@ -153,7 +162,11 @@ func (s *server) internallyForwardMessage(ctx context.Context, req *messagingpb. log.WithError(err).Warn("failure creating internal grpc messaging client") return err } - defer cleanup() + defer func() { + if err := cleanup(); err != nil { + log.WithError(err).Warn("failed to cleanup internal messaging client") + } + }() reqBytes, err := proto.Marshal(req) if err != nil { @@ -189,7 +202,7 @@ func (s *server) internallyForwardMessage(ctx context.Context, req *messagingpb. return nil } -func (s *server) verifyForwardedSendMessageRequest(ctx context.Context, req *messagingpb.SendMessageRequest) (bool, error) { +func (s *Server) verifyForwardedSendMessageRequest(ctx context.Context, req *messagingpb.SendMessageRequest) (bool, error) { signature, _ := headers.GetASCIIHeaderByName(ctx, internalSignatureHeaderName) if len(signature) == 0 { return false, nil diff --git a/pkg/code/server/grpc/messaging/server.go b/pkg/code/server/grpc/messaging/server.go index c89f69ce..84762111 100644 --- a/pkg/code/server/grpc/messaging/server.go +++ b/pkg/code/server/grpc/messaging/server.go @@ -43,7 +43,7 @@ const ( rendezvousRecordMaxAge = messageStreamWithoutKeepAliveTimeout ) -type server struct { +type Server struct { log *logrus.Entry conf *conf data code_data.Provider @@ -54,7 +54,7 @@ type server struct { domainVerifier thirdparty.DomainVerifier - rendezvousFirstSeenAtCache cache.Cache // todo: Back with something like Redis when we go multi-server + rendezvousFirstSeenAtCache cache.Cache // todo: Back with something like Redis when we go multi-Server rpcSignatureVerifier *auth.RPCSignatureVerifier @@ -65,35 +65,35 @@ type server struct { // NewMessagingClient returns a new internal messaging client // -// todo: Proper separation of internal client and server +// todo: Proper separation of internal client and Server func NewMessagingClient( data code_data.Provider, ) InternalMessageClient { - return &server{ + return &Server{ log: logrus.StandardLogger().WithField("type", "messaging/client"), data: data, } } -// NewMessagingClientAndServer returns a new messaging client and server bundle. +// NewMessagingClientAndServer returns a new messaging client and Server bundle. // // These are currently highly coupled atm due to need to detect an active stream -// on a local server and avoiding the network call. +// on a local Server and avoiding the network call. // -// Note: The multi-server implementation of this server is not perfect, and it +// Note: The multi-Server implementation of this Server is not perfect, and it // doesn't need to be initially. We're mostly acting as a notification system // where updates can be sent out-of-order. If we need stronger guarantees, // resurrecting the Black Marlin KikX project might be necessary. Ideally we'd // avoid this due to the step level increase in complexity. // -// todo: Proper separation of internal client and server +// todo: Proper separation of internal client and Server func NewMessagingClientAndServer( data code_data.Provider, rpcSignatureVerifier *auth.RPCSignatureVerifier, broadcastAddress string, configProvider ConfigProvider, -) *server { - return &server{ +) *Server { + return &Server{ log: logrus.StandardLogger().WithField("type", "messaging/client_and_server"), conf: configProvider(), data: data, @@ -109,7 +109,7 @@ func NewMessagingClientAndServer( // OpenMessageStreamWithKeepAlive implements messagingpb.MessagingServer.OpenMessageStreamWithKeepAlive. // // todo: Majority of message streaming logic is duplicated here and in OpenMessageStream -func (s *server) OpenMessageStreamWithKeepAlive(streamer messagingpb.Messaging_OpenMessageStreamWithKeepAliveServer) error { +func (s *Server) OpenMessageStreamWithKeepAlive(streamer messagingpb.Messaging_OpenMessageStreamWithKeepAliveServer) error { ctx := streamer.Context() req, err := s.boundedRecv(ctx, streamer, 250*time.Millisecond) @@ -152,9 +152,9 @@ func (s *server) OpenMessageStreamWithKeepAlive(streamer messagingpb.Messaging_O ms, exists := s.streams[streamKey] if exists { s.streamsMu.Unlock() - // There's an existing stream on this server that must be terminated first. + // There's an existing stream on this Server that must be terminated first. // Warn to see how often this happens in practice - log.Warnf("existing stream detected on this server (stream=%p) ; aborting", ms) + log.Warnf("existing stream detected on this Server (stream=%p) ; aborting", ms) return status.Error(codes.Aborted, "stream already exists") } @@ -183,7 +183,7 @@ func (s *server) OpenMessageStreamWithKeepAlive(streamer messagingpb.Messaging_O // necessary since a striped lock could cause cross-payment delays // due to the long-lived nature of a stream. // todo: This can potentially hang for a really long time. This isn't a - // problem until we're actually multi-server and we have a real + // problem until we're actually multi-Server and we have a real // distributed lock where we can timeout the lock acquire process. myStreamMu.Lock() @@ -203,7 +203,7 @@ func (s *server) OpenMessageStreamWithKeepAlive(streamer messagingpb.Messaging_O s.streamsMu.Unlock() // Delete the rendezvous record after killing the stream. This will allow - // another stream to "queue up" on the same server without failing with a + // another stream to "queue up" on the same Server without failing with a // duplication check while we wait for this slower DB operation. ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond) err := s.data.DeleteRendezvous(ctx, streamKey) @@ -308,7 +308,7 @@ func (s *server) OpenMessageStreamWithKeepAlive(streamer messagingpb.Messaging_O } } -func (s *server) boundedRecv( +func (s *Server) boundedRecv( ctx context.Context, streamer messagingpb.Messaging_OpenMessageStreamWithKeepAliveServer, timeout time.Duration, @@ -330,7 +330,7 @@ func (s *server) boundedRecv( } // Very naive implementation to start -func (s *server) monitorOpenMessageStreamHealth( +func (s *Server) monitorOpenMessageStreamHealth( ctx context.Context, log *logrus.Entry, ssRef string, @@ -364,7 +364,7 @@ func (s *server) monitorOpenMessageStreamHealth( // such by having a hard upper bound time that it can be opened. // // todo: Majority of message streaming logic is duplicated here and in OpenMessageStream -func (s *server) OpenMessageStream(req *messagingpb.OpenMessageStreamRequest, streamer messagingpb.Messaging_OpenMessageStreamServer) error { +func (s *Server) OpenMessageStream(req *messagingpb.OpenMessageStreamRequest, streamer messagingpb.Messaging_OpenMessageStreamServer) error { ctx := streamer.Context() streamKey := base58.Encode(req.RendezvousKey.Value) @@ -396,9 +396,9 @@ func (s *server) OpenMessageStream(req *messagingpb.OpenMessageStreamRequest, st ms, exists := s.streams[streamKey] if exists { s.streamsMu.Unlock() - // There's an existing stream on this server that must be terminated first. + // There's an existing stream on this Server that must be terminated first. // Warn to see how often this happens in practice - log.Warnf("existing stream detected on this server (stream=%p) ; aborting", ms) + log.Warnf("existing stream detected on this Server (stream=%p) ; aborting", ms) return status.Error(codes.Aborted, "stream already exists") } @@ -427,7 +427,7 @@ func (s *server) OpenMessageStream(req *messagingpb.OpenMessageStreamRequest, st // necessary since a striped lock could cause cross-payment delays // due to the long-lived nature of a stream. // todo: This can potentially hang for 60s. This isn't a problem until - // we're actually multi-server and we have a real distributed lock + // we're actually multi-Server and we have a real distributed lock // where we can timeout the lock acquire process. myStreamMu.Lock() @@ -447,7 +447,7 @@ func (s *server) OpenMessageStream(req *messagingpb.OpenMessageStreamRequest, st s.streamsMu.Unlock() // Delete the rendezvous record after killing the stream. This will allow - // another stream to "queue up" on the same server without failing with a + // another stream to "queue up" on the same Server without failing with a // duplication check while we wait for this slower DB operation. ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond) err := s.data.DeleteRendezvous(ctx, streamKey) @@ -523,7 +523,7 @@ func (s *server) OpenMessageStream(req *messagingpb.OpenMessageStreamRequest, st } // PollMessages implements messagingpb.MessagingServer.PollMessages. -func (s *server) PollMessages(ctx context.Context, req *messagingpb.PollMessagesRequest) (*messagingpb.PollMessagesResponse, error) { +func (s *Server) PollMessages(ctx context.Context, req *messagingpb.PollMessagesRequest) (*messagingpb.PollMessagesResponse, error) { log := s.log.WithFields(logrus.Fields{ "method": "PollMessages", "rendezvous_key": base58.Encode(req.RendezvousKey.Value), @@ -576,7 +576,7 @@ func (s *server) PollMessages(ctx context.Context, req *messagingpb.PollMessages } // AckMessages implements messagingpb.MessagingServer.AckMessages. -func (s *server) AckMessages(ctx context.Context, req *messagingpb.AckMessagesRequest) (*messagingpb.AckMesssagesResponse, error) { +func (s *Server) AckMessages(ctx context.Context, req *messagingpb.AckMessagesRequest) (*messagingpb.AckMesssagesResponse, error) { log := s.log.WithFields(logrus.Fields{ "method": "AckMessages", "acks": len(req.MessageIds), @@ -605,7 +605,7 @@ func (s *server) AckMessages(ctx context.Context, req *messagingpb.AckMessagesRe } // SendMessage implements messagingpb.MessagingServer.SendMessage. -func (s *server) SendMessage(ctx context.Context, req *messagingpb.SendMessageRequest) (*messagingpb.SendMessageResponse, error) { +func (s *Server) SendMessage(ctx context.Context, req *messagingpb.SendMessageRequest) (*messagingpb.SendMessageResponse, error) { streamKey := base58.Encode(req.RendezvousKey.Value) log := s.log.WithFields(logrus.Fields{ @@ -621,7 +621,7 @@ func (s *server) SendMessage(ctx context.Context, req *messagingpb.SendMessageRe } // The request message has a message ID, which implies it was forwarded by - // another RPC server. All we need to do is to verify the request, and attempt + // another RPC Server. All we need to do is to verify the request, and attempt // to forward it to receiver's open message stream, if it exists. // // todo: Long term, we need public and internal APIs properly separated. For now, @@ -656,7 +656,7 @@ func (s *server) SendMessage(ctx context.Context, req *messagingpb.SendMessageRe // Otherwise, handle the request as a brand new message that must both be // created and sent to the receiver's stream, possibly by forwarding it to - // another RPC server. + // another RPC Server. if req.Message.SendMessageRequestSignature != nil { return nil, status.Error(codes.InvalidArgument, "message.send_message_request_signature cannot be set by clients") @@ -664,11 +664,9 @@ func (s *server) SendMessage(ctx context.Context, req *messagingpb.SendMessageRe var messageHandler MessageHandler switch req.Message.Kind.(type) { - // // Section: Cash // - case *messagingpb.Message_RequestToGrabBill: log = log.WithField("message_type", "request_to_grab_bill") messageHandler = NewRequestToGrabBillMessageHandler(s.data) @@ -676,7 +674,6 @@ func (s *server) SendMessage(ctx context.Context, req *messagingpb.SendMessageRe // // Section: Payment Request // - case *messagingpb.Message_RequestToReceiveBill: log = log.WithField("message_type", "request_to_receive_bill") messageHandler = NewRequestToReceiveBillMessageHandler(s.conf, s.data, s.rpcSignatureVerifier, s.domainVerifier) @@ -694,7 +691,6 @@ func (s *server) SendMessage(ctx context.Context, req *messagingpb.SendMessageRe // // Section: Login // - case *messagingpb.Message_RequestToLogin: log = log.WithField("message_type", "request_to_login") messageHandler = NewRequestToLoginMessageHandler(s.data, s.rpcSignatureVerifier, s.domainVerifier) @@ -705,7 +701,6 @@ func (s *server) SendMessage(ctx context.Context, req *messagingpb.SendMessageRe // // Section: Airdrops // - case *messagingpb.Message_AirdropReceived: return nil, status.Error(codes.InvalidArgument, "message.kind cannot be airdrop_received") @@ -815,7 +810,7 @@ func (s *server) SendMessage(ctx context.Context, req *messagingpb.SendMessageRe // scaled out, crashed, etc. The rendezvous record also hasn't been // cleaned up and we're within the stream timeout. This is an edge case, // and we won't consider it a failure. It's effectively the same as - // forwarding it to a server where the stream doesn't exist. The + // forwarding it to a Server where the stream doesn't exist. The // message will be picked up on the next stream open. isRpcFailure = false } @@ -832,7 +827,7 @@ func (s *server) SendMessage(ctx context.Context, req *messagingpb.SendMessageRe }, nil } -func (s *server) flush(ctx context.Context, accountID *messagingpb.RendezvousKey, stream *messageStream) { +func (s *Server) flush(ctx context.Context, accountID *messagingpb.RendezvousKey, stream *messageStream) { accountStr := base58.Encode(accountID.Value) log := s.log.WithFields(logrus.Fields{ @@ -862,7 +857,7 @@ func (s *server) flush(ctx context.Context, accountID *messagingpb.RendezvousKey } } -func (s *server) markRendezvousKeyAsSeen(rendezvousAccount *common.Account) { +func (s *Server) markRendezvousKeyAsSeen(rendezvousAccount *common.Account) { _, ok := s.rendezvousFirstSeenAtCache.Retrieve(rendezvousAccount.PublicKey().ToBase58()) if !ok { s.rendezvousFirstSeenAtCache.Insert(rendezvousAccount.PublicKey().ToBase58(), time.Now(), 1) diff --git a/pkg/code/server/grpc/messaging/server_test.go b/pkg/code/server/grpc/messaging/server_test.go index e6fc5d90..ddebc4ad 100644 --- a/pkg/code/server/grpc/messaging/server_test.go +++ b/pkg/code/server/grpc/messaging/server_test.go @@ -221,7 +221,6 @@ func TestSendMessage_RequestToGrabBill_Validation(t *testing.T) { sendMessageCall = env.client1.sendRequestToGrabBillMessage(t, rendezvousKey) sendMessageCall.assertInvalidMessageError(t, "requestor account must be latest temporary incoming account") env.server1.assertNoMessages(t, rendezvousKey) - } func TestSendMessage_RequestToReceiveBill_KinValue_HappyPath(t *testing.T) { diff --git a/pkg/code/server/grpc/messaging/testutil.go b/pkg/code/server/grpc/messaging/testutil.go index f1e7cf97..5db78ede 100644 --- a/pkg/code/server/grpc/messaging/testutil.go +++ b/pkg/code/server/grpc/messaging/testutil.go @@ -115,7 +115,7 @@ func setup(t *testing.T, enableMultiServer bool) (env testEnv, cleanup func()) { type serverEnv struct { ctx context.Context - server *server + server *Server subsidizer *common.Account } @@ -215,7 +215,7 @@ func (s *serverEnv) assertInitialRendezvousRecordSaved(t *testing.T, rendezvousK require.NoError(t, err) assert.Equal(t, rendezvousKey.PublicKey().ToBase58(), rendezvousRecord.Key) - assert.Equal(t, s.server.broadcastAddress, rendezvousRecord.Location) // Note: assertion must be called on the expected server + assert.Equal(t, s.server.broadcastAddress, rendezvousRecord.Location) // Note: assertion must be called on the expected Server assert.True(t, start.Sub(rendezvousRecord.CreatedAt) <= 50*time.Millisecond) assert.True(t, start.Sub(rendezvousRecord.CreatedAt) >= -50*time.Millisecond) assert.Equal(t, rendezvousRecord.CreatedAt.Unix(), rendezvousRecord.LastUpdatedAt.Unix()) @@ -461,7 +461,7 @@ func (c *clientEnv) waitUntilStreamTerminationOrTimeout(t *testing.T, rendezvous assert.True(t, time.Since(lastPingTs) <= messageStreamPingDelay+50*time.Millisecond) } - pingCount += 1 + pingCount++ lastPingTs = time.Now() if keepStreamAlive { diff --git a/pkg/code/server/grpc/micropayment/server_test.go b/pkg/code/server/grpc/micropayment/server_test.go index 39f4487f..eca18851 100644 --- a/pkg/code/server/grpc/micropayment/server_test.go +++ b/pkg/code/server/grpc/micropayment/server_test.go @@ -7,13 +7,13 @@ import ( "strings" "testing" - "github.com/golang/protobuf/proto" "github.com/google/uuid" "github.com/mr-tron/base58" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" commonpb "github.com/code-payments/code-protobuf-api/generated/go/common/v1" diff --git a/pkg/code/server/grpc/phone/server_test.go b/pkg/code/server/grpc/phone/server_test.go index 5a7e1eec..68fd4c0e 100644 --- a/pkg/code/server/grpc/phone/server_test.go +++ b/pkg/code/server/grpc/phone/server_test.go @@ -606,7 +606,7 @@ func TestGetAssociatedPhoneNumber_UnlockedTimelockAccount(t *testing.T) { assert.Equal(t, phonepb.GetAssociatedPhoneNumberResponse_OK, resp.Result) timelockRecord.VaultState = timelock_token.StateUnlocked - timelockRecord.Block += 1 + timelockRecord.Block++ require.NoError(t, env.data.SaveTimelock(env.ctx, timelockRecord)) resp, err = env.client.GetAssociatedPhoneNumber(env.ctx, req) diff --git a/pkg/code/server/grpc/push/server_test.go b/pkg/code/server/grpc/push/server_test.go index 2d94cfb8..71592a4f 100644 --- a/pkg/code/server/grpc/push/server_test.go +++ b/pkg/code/server/grpc/push/server_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/require" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" "google.golang.org/protobuf/proto" commonpb "github.com/code-payments/code-protobuf-api/generated/go/common/v1" @@ -360,21 +361,29 @@ func generateNewDataContainer(t *testing.T, env testEnv, ownerAccount *common.Ac // todo: integrate below client utilities with the main testutil package func newClientWithoutUserAgent(t *testing.T, env testEnv) pushpb.PushClient { - conn, err := grpc.Dial(env.target, grpc.WithInsecure()) + conn, err := grpc.Dial(env.target, grpc.WithTransportCredentials(insecure.NewCredentials())) require.NoError(t, err) return pushpb.NewPushClient(conn) } func newAndroidClient(t *testing.T, env testEnv) pushpb.PushClient { - conn, err := grpc.Dial(env.target, grpc.WithInsecure(), grpc.WithUserAgent("Code/Android/1.0.0")) + conn, err := grpc.Dial( + env.target, + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithUserAgent("Code/Android/1.0.0"), + ) require.NoError(t, err) return pushpb.NewPushClient(conn) } func newIOSClient(t *testing.T, env testEnv) pushpb.PushClient { - conn, err := grpc.Dial(env.target, grpc.WithInsecure(), grpc.WithUserAgent("Code/iOS/1.0.0")) + conn, err := grpc.Dial( + env.target, + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithUserAgent("Code/iOS/1.0.0"), + ) require.NoError(t, err) return pushpb.NewPushClient(conn) diff --git a/pkg/code/server/grpc/transaction/v2/action_handler.go b/pkg/code/server/grpc/transaction/v2/action_handler.go index 34a69625..0e988d0d 100644 --- a/pkg/code/server/grpc/transaction/v2/action_handler.go +++ b/pkg/code/server/grpc/transaction/v2/action_handler.go @@ -28,7 +28,7 @@ import ( ) // todo: a better name for this lol? -type makeSolanaTransactionResult struct { +type MakeSolanaTransactionResult struct { isCreatedOnDemand bool txn *solana.Transaction // Can be null if the transaction is on-demand created at scheduling time @@ -88,7 +88,7 @@ type CreateActionHandler interface { index int, nonce *common.Account, bh solana.Blockhash, - ) (*makeSolanaTransactionResult, error) + ) (*MakeSolanaTransactionResult, error) } // UpgradeActionHandler is an interface for upgrading existing actions. It's @@ -104,7 +104,7 @@ type UpgradeActionHandler interface { MakeUpgradedSolanaTransaction( nonce *common.Account, bh solana.Blockhash, - ) (*makeSolanaTransactionResult, error) + ) (*MakeSolanaTransactionResult, error) } type OpenAccountActionHandler struct { @@ -193,10 +193,10 @@ func (h *OpenAccountActionHandler) MakeNewSolanaTransaction( index int, nonce *common.Account, bh solana.Blockhash, -) (*makeSolanaTransactionResult, error) { +) (*MakeSolanaTransactionResult, error) { switch index { case 0: - return &makeSolanaTransactionResult{ + return &MakeSolanaTransactionResult{ isCreatedOnDemand: true, txn: nil, @@ -274,7 +274,7 @@ func (h *CloseEmptyAccountActionHandler) MakeNewSolanaTransaction( index int, nonce *common.Account, bh solana.Blockhash, -) (*makeSolanaTransactionResult, error) { +) (*MakeSolanaTransactionResult, error) { switch index { case 0: txn, err := transaction_util.MakeCloseEmptyAccountTransaction( @@ -286,7 +286,7 @@ func (h *CloseEmptyAccountActionHandler) MakeNewSolanaTransaction( return nil, err } - return &makeSolanaTransactionResult{ + return &MakeSolanaTransactionResult{ txn: &txn, fulfillmentType: fulfillment.CloseEmptyTimelockAccount, @@ -380,7 +380,7 @@ func (h *CloseDormantAccountActionHandler) MakeNewSolanaTransaction( index int, nonce *common.Account, bh solana.Blockhash, -) (*makeSolanaTransactionResult, error) { +) (*MakeSolanaTransactionResult, error) { switch index { case 0: txn, err := transaction_util.MakeCloseAccountWithBalanceTransaction( @@ -397,7 +397,7 @@ func (h *CloseDormantAccountActionHandler) MakeNewSolanaTransaction( intentOrderingIndex := uint64(math.MaxInt64) actionOrderingIndex := uint32(0) - return &makeSolanaTransactionResult{ + return &MakeSolanaTransactionResult{ txn: &txn, fulfillmentType: fulfillment.CloseDormantTimelockAccount, @@ -528,7 +528,7 @@ func (h *NoPrivacyTransferActionHandler) MakeNewSolanaTransaction( index int, nonce *common.Account, bh solana.Blockhash, -) (*makeSolanaTransactionResult, error) { +) (*MakeSolanaTransactionResult, error) { switch index { case 0: txn, err := transaction_util.MakeTransferWithAuthorityTransaction( @@ -542,7 +542,7 @@ func (h *NoPrivacyTransferActionHandler) MakeNewSolanaTransaction( return nil, err } - return &makeSolanaTransactionResult{ + return &MakeSolanaTransactionResult{ txn: &txn, fulfillmentType: fulfillment.NoPrivacyTransferWithAuthority, @@ -646,7 +646,7 @@ func (h *NoPrivacyWithdrawActionHandler) MakeNewSolanaTransaction( index int, nonce *common.Account, bh solana.Blockhash, -) (*makeSolanaTransactionResult, error) { +) (*MakeSolanaTransactionResult, error) { switch index { case 0: txn, err := transaction_util.MakeCloseAccountWithBalanceTransaction( @@ -660,7 +660,7 @@ func (h *NoPrivacyWithdrawActionHandler) MakeNewSolanaTransaction( return nil, err } - return &makeSolanaTransactionResult{ + return &MakeSolanaTransactionResult{ txn: &txn, fulfillmentType: fulfillment.NoPrivacyWithdraw, @@ -892,10 +892,10 @@ func (h *TemporaryPrivacyTransferActionHandler) MakeNewSolanaTransaction( index int, nonce *common.Account, bh solana.Blockhash, -) (*makeSolanaTransactionResult, error) { +) (*MakeSolanaTransactionResult, error) { switch index { case 0: - return &makeSolanaTransactionResult{ + return &MakeSolanaTransactionResult{ isCreatedOnDemand: true, txn: nil, @@ -917,7 +917,7 @@ func (h *TemporaryPrivacyTransferActionHandler) MakeNewSolanaTransaction( return nil, err } - return &makeSolanaTransactionResult{ + return &MakeSolanaTransactionResult{ txn: &txn, fulfillmentType: fulfillment.TemporaryPrivacyTransferWithAuthority, @@ -956,7 +956,7 @@ func NewPermanentPrivacyUpgradeActionHandler( data code_data.Provider, intentRecord *intent.Record, protoAction *transactionpb.PermanentPrivacyUpgradeAction, - cachedUpgradeTarget *privacyUpgradeCandidate, + cachedUpgradeTarget *PrivacyUpgradeCandidate, ) (UpgradeActionHandler, error) { h := &PermanentPrivacyUpgradeActionHandler{ data: data, @@ -1050,7 +1050,7 @@ func (h *PermanentPrivacyUpgradeActionHandler) getFulfillmentBeingUpgraded(ctx c func (h *PermanentPrivacyUpgradeActionHandler) MakeUpgradedSolanaTransaction( nonce *common.Account, bh solana.Blockhash, -) (*makeSolanaTransactionResult, error) { +) (*MakeSolanaTransactionResult, error) { txn, err := transaction_util.MakeTransferWithAuthorityTransaction( nonce, bh, @@ -1062,7 +1062,7 @@ func (h *PermanentPrivacyUpgradeActionHandler) MakeUpgradedSolanaTransaction( return nil, err } - return &makeSolanaTransactionResult{ + return &MakeSolanaTransactionResult{ txn: &txn, fulfillmentType: fulfillment.PermanentPrivacyTransferWithAuthority, diff --git a/pkg/code/server/grpc/transaction/v2/airdrop.go b/pkg/code/server/grpc/transaction/v2/airdrop.go index 122c4832..db434e6c 100644 --- a/pkg/code/server/grpc/transaction/v2/airdrop.go +++ b/pkg/code/server/grpc/transaction/v2/airdrop.go @@ -182,7 +182,7 @@ func (s *transactionServer) Airdrop(ctx context.Context, req *transactionpb.Aird func (s *transactionServer) maybeAirdropForSubmittingIntent(ctx context.Context, intentRecord *intent.Record, submitActionsOwnerMetadata *common.OwnerMetadata) { if false { // Disabled - s.maybeAirdropForSendingUserTheirFirstKin(ctx, intentRecord) + _ = s.maybeAirdropForSendingUserTheirFirstKin(ctx, intentRecord) } } @@ -443,8 +443,26 @@ func (s *transactionServer) airdrop(ctx context.Context, intentId string, owner log.WithError(err).Warn("failure selecting available nonce") return nil, err } + + // note: defer() will only run when the outer function returns, and therefore + // all of the defer()'s in // this loop will be run all at once at the end, + // rather than at the end of each iteration. + // + // Since we are not committing (and therefore consuming) the nonce's until the + // end of the function, this is desirable. If we released at the end of each + // iteration, we could potentially acquire the same nonce multiple times for + // different transactions, which would fail. defer func() { - selectedNonce.ReleaseIfNotReserved() + if err := selectedNonce.ReleaseIfNotReserved(); err != nil { + s.log. + WithFields(logrus.Fields{ + "method": "airdrop", + "nonce_account": selectedNonce.Account.PublicKey().ToBase58(), + "blockhash": selectedNonce.Blockhash.ToBase58(), + }). + WithError(err). + Warn("failed to release nonce") + } selectedNonce.Unlock() }() @@ -535,7 +553,7 @@ func (s *transactionServer) airdrop(ctx context.Context, intentId string, owner CreatedAt: time.Now(), } - var eventType event.EventType + var eventType event.Type switch airdropType { case AirdropTypeGetFirstKin: eventType = event.WelcomeBonusClaimed @@ -620,7 +638,7 @@ func (s *transactionServer) airdrop(ctx context.Context, intentId string, owner if canPushChatMessage { // Best-effort send a push - push_util.SendChatMessagePushNotification( + pushErr := push_util.SendChatMessagePushNotification( ctx, s.data, s.pusher, @@ -628,6 +646,9 @@ func (s *transactionServer) airdrop(ctx context.Context, intentId string, owner owner, chatMessage, ) + if pushErr != nil { + log.WithError(err).Warn("failed to send chat message push notification (best effort)") + } } recordAirdropEvent(ctx, owner, airdropType, usdValue) diff --git a/pkg/code/server/grpc/transaction/v2/errors.go b/pkg/code/server/grpc/transaction/v2/errors.go index 5c4d2d66..e2355dda 100644 --- a/pkg/code/server/grpc/transaction/v2/errors.go +++ b/pkg/code/server/grpc/transaction/v2/errors.go @@ -77,10 +77,6 @@ func newIntentDeniedError(message string) IntentDeniedError { } } -func newIntentDeniedErrorf(format string, args ...any) IntentDeniedError { - return newIntentDeniedError(fmt.Sprintf(format, args...)) -} - func (e IntentDeniedError) Error() string { return e.message } @@ -113,10 +109,6 @@ func newSwapDeniedError(message string) SwapDeniedError { } } -func newSwapDeniedErrorf(format string, args ...any) SwapDeniedError { - return newSwapDeniedError(fmt.Sprintf(format, args...)) -} - func (e SwapDeniedError) Error() string { return e.message } diff --git a/pkg/code/server/grpc/transaction/v2/intent.go b/pkg/code/server/grpc/transaction/v2/intent.go index 9faffc47..f0e89ba0 100644 --- a/pkg/code/server/grpc/transaction/v2/intent.go +++ b/pkg/code/server/grpc/transaction/v2/intent.go @@ -210,7 +210,7 @@ func (s *transactionServer) SubmitIntent(streamer transactionpb.Transaction_Subm log.Warnf("unhandled owner account type %s", submitActionsOwnerMetadata.Type) return handleSubmitIntentError(streamer, errors.New("unhandled owner account type")) } - } else if err == common.ErrOwnerNotFound { + } else if errors.Is(err, common.ErrOwnerNotFound) { //nolint:revive // Caught by later error } else if err != nil { log.WithError(err).Warn("failure getting owner account metadata") @@ -556,7 +556,7 @@ func (s *transactionServer) SubmitIntent(streamer transactionpb.Transaction_Subm } for j := 0; j < transactionCount; j++ { - var makeTxnResult *makeSolanaTransactionResult + var makeTxnResult *MakeSolanaTransactionResult var selectedNonce *transaction.SelectedNonce var actionId uint32 if isUpgradeActionOperation { @@ -602,14 +602,34 @@ func (s *transactionServer) SubmitIntent(streamer transactionpb.Transaction_Subm log.WithError(err).Warn("failure selecting available nonce") return handleSubmitIntentError(streamer, err) } + + // If we never assign the nonce a signature in the action creation flow, + // it's safe to put it back in the available pool. The client will have + // caused a failed RPC call, and we want to avoid malicious or erroneous + // clients from consuming our nonce pool! + // + // note: defer() will only run when the outer function returns, and + // therefore all of the defer()'s in this loop will be run all at once at + // the end, rather than at the end of each iteration. + // + // Since we are not committing (and therefore consuming) the nonce's until + // the end of the function, this is desirable. If we released at the end of + // each iteration, we could potentially acquire the same nonce multiple times + // for different transactions, which would fail. defer func() { - // If we never assign the nonce a signature in the action creation flow, - // it's safe to put it back in the available pool. The client will have - // caused a failed RPC call, and we want to avoid malicious or erroneous - // clients from consuming our nonce pool! - selectedNonce.ReleaseIfNotReserved() + if err := selectedNonce.ReleaseIfNotReserved(); err != nil { + s.log. + WithFields(logrus.Fields{ + "method": "SubmitIntent", + "nonce_account": selectedNonce.Account.PublicKey().ToBase58(), + "blockhash": selectedNonce.Blockhash.ToBase58(), + }). + WithError(err). + Warn("failed to release nonce") + } selectedNonce.Unlock() }() + nonceAccount = selectedNonce.Account nonceBlockchash = selectedNonce.Blockhash } else { @@ -961,7 +981,7 @@ func (s *transactionServer) SubmitIntent(streamer transactionpb.Transaction_Subm if len(chatMessagesToPush) > 0 { go func() { for _, chatMessageToPush := range chatMessagesToPush { - push.SendChatMessagePushNotification( + pushErr := push.SendChatMessagePushNotification( context.TODO(), s.data, s.pusher, @@ -969,6 +989,9 @@ func (s *transactionServer) SubmitIntent(streamer transactionpb.Transaction_Subm chatMessageToPush.Owner, chatMessageToPush.Message, ) + if pushErr != nil { + log.WithError(err).Warn("failure sending chat message push notification") + } } }() } @@ -1004,9 +1027,9 @@ func (s *transactionServer) SubmitIntent(streamer transactionpb.Transaction_Subm backgroundCtx := context.Background() // todo: generic metrics utility for this - nr, ok := ctx.Value(metrics.NewRelicContextKey).(*newrelic.Application) + nr, ok := ctx.Value(metrics.NewRelicContextKey{}).(*newrelic.Application) if ok { - backgroundCtx = context.WithValue(backgroundCtx, metrics.NewRelicContextKey, nr) + backgroundCtx = context.WithValue(backgroundCtx, metrics.NewRelicContextKey{}, nr) } // todo: We likely want to put this in a worker if this is a long term feature @@ -1290,10 +1313,10 @@ func (s *transactionServer) CanWithdrawToAccount(ctx context.Context, req *trans AccountType: transactionpb.CanWithdrawToAccountResponse_TokenAccount, IsValidPaymentDestination: accountInfoRecord.AccountType == commonpb.AccountType_PRIMARY || accountInfoRecord.AccountType == commonpb.AccountType_RELATIONSHIP, }, nil - } else { - log.WithError(err).Warn("failure checking account info db") - return nil, status.Error(codes.Internal, "") } + + log.WithError(err).Warn("failure checking account info db") + return nil, status.Error(codes.Internal, "") } // diff --git a/pkg/code/server/grpc/transaction/v2/intent_handler.go b/pkg/code/server/grpc/transaction/v2/intent_handler.go index 22f0d331..7b09b2cf 100644 --- a/pkg/code/server/grpc/transaction/v2/intent_handler.go +++ b/pkg/code/server/grpc/transaction/v2/intent_handler.go @@ -73,7 +73,7 @@ func init() { } } -type lockableAccounts struct { +type LockableAccounts struct { DestinationOwner *common.Account RemoteSendGiftCardVault *common.Account } @@ -98,7 +98,7 @@ type CreateIntentHandler interface { // // Note: Assumes relevant information is contained in the intent record after // calling PopulateMetadata. - GetAdditionalAccountsToLock(ctx context.Context, intentRecord *intent.Record) (*lockableAccounts, error) + GetAdditionalAccountsToLock(ctx context.Context, intentRecord *intent.Record) (*LockableAccounts, error) // AllowCreation determines whether the new intent creation should be allowed. AllowCreation(ctx context.Context, intentRecord *intent.Record, metadata *transactionpb.Metadata, actions []*transactionpb.Action, deviceToken *string) error @@ -169,8 +169,8 @@ func (h *OpenAccountsIntentHandler) IsNoop(ctx context.Context, intentRecord *in return false, nil } -func (h *OpenAccountsIntentHandler) GetAdditionalAccountsToLock(ctx context.Context, intentRecord *intent.Record) (*lockableAccounts, error) { - return &lockableAccounts{}, nil +func (h *OpenAccountsIntentHandler) GetAdditionalAccountsToLock(ctx context.Context, intentRecord *intent.Record) (*LockableAccounts, error) { + return &LockableAccounts{}, nil } func (h *OpenAccountsIntentHandler) AllowCreation(ctx context.Context, intentRecord *intent.Record, metadata *transactionpb.Metadata, actions []*transactionpb.Action, deviceToken *string) error { @@ -499,7 +499,7 @@ func (h *SendPrivatePaymentIntentHandler) IsNoop(ctx context.Context, intentReco return false, nil } -func (h *SendPrivatePaymentIntentHandler) GetAdditionalAccountsToLock(ctx context.Context, intentRecord *intent.Record) (*lockableAccounts, error) { +func (h *SendPrivatePaymentIntentHandler) GetAdditionalAccountsToLock(ctx context.Context, intentRecord *intent.Record) (*LockableAccounts, error) { var destinationOwnerAccount, giftCardVaultAccount *common.Account var err error @@ -517,7 +517,7 @@ func (h *SendPrivatePaymentIntentHandler) GetAdditionalAccountsToLock(ctx contex } } - return &lockableAccounts{ + return &LockableAccounts{ DestinationOwner: destinationOwnerAccount, RemoteSendGiftCardVault: giftCardVaultAccount, }, nil @@ -779,7 +779,7 @@ func (h *SendPrivatePaymentIntentHandler) validateActions( return err } } - } else if metadata.IsRemoteSend { + } else if metadata.IsRemoteSend { //nolint:revive // No validation needed here. Open validation is handled later. } else { // The client is trying a Code->Code payment since withdrawal and remote @@ -1069,8 +1069,8 @@ func (h *ReceivePaymentsPrivatelyIntentHandler) IsNoop(ctx context.Context, inte return false, nil } -func (h *ReceivePaymentsPrivatelyIntentHandler) GetAdditionalAccountsToLock(ctx context.Context, intentRecord *intent.Record) (*lockableAccounts, error) { - return &lockableAccounts{}, nil +func (h *ReceivePaymentsPrivatelyIntentHandler) GetAdditionalAccountsToLock(ctx context.Context, intentRecord *intent.Record) (*LockableAccounts, error) { + return &LockableAccounts{}, nil } func (h *ReceivePaymentsPrivatelyIntentHandler) AllowCreation(ctx context.Context, intentRecord *intent.Record, untypedMetadata *transactionpb.Metadata, actions []*transactionpb.Action, deviceToken *string) error { @@ -1232,7 +1232,6 @@ func (h *ReceivePaymentsPrivatelyIntentHandler) validateActions( return err } } else { - // There's one opened account, and it must be the new temporary incoming account. if len(openedAccounts) != 1 { return newIntentValidationError("must open one account") @@ -1380,7 +1379,7 @@ func (h *ReceivePaymentsPrivatelyIntentHandler) OnCommittedToDB(ctx context.Cont type UpgradePrivacyIntentHandler struct { conf *conf data code_data.Provider - cachedUpgradeTargets map[uint32]*privacyUpgradeCandidate + cachedUpgradeTargets map[uint32]*PrivacyUpgradeCandidate } func NewUpgradePrivacyIntentHandler(conf *conf, data code_data.Provider) UpdateIntentHandler { @@ -1395,7 +1394,7 @@ func (h *UpgradePrivacyIntentHandler) AllowUpdate(ctx context.Context, existingI return errors.New("unexpected metadata proto message") } - cachedUpgradeTargets := make(map[uint32]*privacyUpgradeCandidate) + cachedUpgradeTargets := make(map[uint32]*PrivacyUpgradeCandidate) for _, untypedAction := range actions { var actionId uint32 switch typedAction := untypedAction.Type.(type) { @@ -1426,7 +1425,7 @@ func (h *UpgradePrivacyIntentHandler) AllowUpdate(ctx context.Context, existingI return nil } -func (h *UpgradePrivacyIntentHandler) GetCachedUpgradeTarget(protoAction *transactionpb.PermanentPrivacyUpgradeAction) (*privacyUpgradeCandidate, bool) { +func (h *UpgradePrivacyIntentHandler) GetCachedUpgradeTarget(protoAction *transactionpb.PermanentPrivacyUpgradeAction) (*PrivacyUpgradeCandidate, bool) { upgradeTo, ok := h.cachedUpgradeTargets[protoAction.ActionId] return upgradeTo, ok } @@ -1461,8 +1460,8 @@ func (h *MigrateToPrivacy2022IntentHandler) IsNoop(ctx context.Context, intentRe return false, nil } -func (h *MigrateToPrivacy2022IntentHandler) GetAdditionalAccountsToLock(ctx context.Context, intentRecord *intent.Record) (*lockableAccounts, error) { - return &lockableAccounts{}, nil +func (h *MigrateToPrivacy2022IntentHandler) GetAdditionalAccountsToLock(ctx context.Context, intentRecord *intent.Record) (*LockableAccounts, error) { + return &LockableAccounts{}, nil } // Note: Most validation helper functions (eg. LocalSimulation) assume DataVersion1 @@ -1719,9 +1718,9 @@ func (h *SendPublicPaymentIntentHandler) IsNoop(ctx context.Context, intentRecor return false, nil } -func (h *SendPublicPaymentIntentHandler) GetAdditionalAccountsToLock(ctx context.Context, intentRecord *intent.Record) (*lockableAccounts, error) { +func (h *SendPublicPaymentIntentHandler) GetAdditionalAccountsToLock(ctx context.Context, intentRecord *intent.Record) (*LockableAccounts, error) { if len(intentRecord.SendPublicPaymentMetadata.DestinationOwnerAccount) == 0 { - return &lockableAccounts{}, nil + return &LockableAccounts{}, nil } destinationOwnerAccount, err := common.NewAccountFromPublicKeyString(intentRecord.SendPublicPaymentMetadata.DestinationOwnerAccount) @@ -1729,7 +1728,7 @@ func (h *SendPublicPaymentIntentHandler) GetAdditionalAccountsToLock(ctx context return nil, err } - return &lockableAccounts{ + return &LockableAccounts{ DestinationOwner: destinationOwnerAccount, }, nil } @@ -2056,9 +2055,9 @@ func (h *ReceivePaymentsPubliclyIntentHandler) IsNoop(ctx context.Context, inten return false, nil } -func (h *ReceivePaymentsPubliclyIntentHandler) GetAdditionalAccountsToLock(ctx context.Context, intentRecord *intent.Record) (*lockableAccounts, error) { +func (h *ReceivePaymentsPubliclyIntentHandler) GetAdditionalAccountsToLock(ctx context.Context, intentRecord *intent.Record) (*LockableAccounts, error) { if !intentRecord.ReceivePaymentsPubliclyMetadata.IsRemoteSend { - return &lockableAccounts{}, nil + return &LockableAccounts{}, nil } giftCardVaultAccount, err := common.NewAccountFromPublicKeyString(intentRecord.ReceivePaymentsPubliclyMetadata.Source) @@ -2066,7 +2065,7 @@ func (h *ReceivePaymentsPubliclyIntentHandler) GetAdditionalAccountsToLock(ctx c return nil, err } - return &lockableAccounts{ + return &LockableAccounts{ RemoteSendGiftCardVault: giftCardVaultAccount, }, nil } @@ -2340,8 +2339,8 @@ func (h *EstablishRelationshipIntentHandler) IsNoop(ctx context.Context, intentR return false, nil } -func (h *EstablishRelationshipIntentHandler) GetAdditionalAccountsToLock(ctx context.Context, intentRecord *intent.Record) (*lockableAccounts, error) { - return &lockableAccounts{}, nil +func (h *EstablishRelationshipIntentHandler) GetAdditionalAccountsToLock(ctx context.Context, intentRecord *intent.Record) (*LockableAccounts, error) { + return &LockableAccounts{}, nil } func (h *EstablishRelationshipIntentHandler) AllowCreation(ctx context.Context, intentRecord *intent.Record, metadata *transactionpb.Metadata, actions []*transactionpb.Action, deviceToken *string) error { @@ -2967,7 +2966,6 @@ func validateExchangeDataWithinIntent(ctx context.Context, data code_data.Provid // exact exchange data requirements, which has already been validated at time // of payment intent creation. We do leave the ability open to reserve an exchange // rate, but no use cases warrant that atm. - } else if err != paymentrequest.ErrPaymentRequestNotFound { return err } @@ -3121,7 +3119,7 @@ func validateMinimalTempIncomingAccountUsage(ctx context.Context, data code_data continue } - paymentCount += 1 + paymentCount++ } // Should be coordinated with MustRotate flag in GetTokenAccountInfos @@ -3151,7 +3149,7 @@ func validateClaimedGiftCard(ctx context.Context, data code_data.Provider, giftC _, err = data.GetGiftCardClaimedAction(ctx, giftCardVaultAccount.PublicKey().ToBase58()) if err == nil { return newStaleStateError("gift card balance has already been claimed") - } else if err == action.ErrActionNotFound { + } else if errors.Is(err, action.ErrActionNotFound) { //nolint:revive // No action to claim it, so we can proceed } else if err != nil { return err diff --git a/pkg/code/server/grpc/transaction/v2/intent_test.go b/pkg/code/server/grpc/transaction/v2/intent_test.go index 8fdf7bbb..8e28f0d7 100644 --- a/pkg/code/server/grpc/transaction/v2/intent_test.go +++ b/pkg/code/server/grpc/transaction/v2/intent_test.go @@ -1209,7 +1209,6 @@ func TestSubmitIntent_SendPublicPayment_Validation_ExchangeData(t *testing.T) { submitIntentCall = sendingPhone.publiclyWithdraw777KinToCodeUserBetweenPrimaryAccounts(t, receivingPhone) submitIntentCall.assertInvalidIntentResponse(t, "payment native amount and quark value mismatch") server.assertIntentNotSubmitted(t, submitIntentCall.intentId) - } func TestSubmitIntent_SendPublicPayment_Validation_Balances(t *testing.T) { @@ -1227,7 +1226,6 @@ func TestSubmitIntent_SendPublicPayment_Validation_Balances(t *testing.T) { submitIntentCall := sendingPhone.publiclyWithdraw123KinToExternalWallet(t) submitIntentCall.assertInvalidIntentResponse(t, "insufficient balance to perform action") server.assertIntentNotSubmitted(t, submitIntentCall.intentId) - } func TestSubmitIntent_SendPublicPayment_Validation_Actions(t *testing.T) { @@ -1463,7 +1461,6 @@ func TestSubmitIntent_ReceivePaymentsPrivately_FromDeposit_AntiMoneyLaunderingGu submitIntentCall := sendingPhone.depositMillionDollarsIntoOrganizer(t) submitIntentCall.assertDeniedResponse(t, "dollar value exceeds limit") server.assertIntentNotSubmitted(t, submitIntentCall.intentId) - } func TestSubmitIntent_ReceivePaymentsPrivately_Validation_ManagedByCode(t *testing.T) { @@ -1485,7 +1482,6 @@ func TestSubmitIntent_ReceivePaymentsPrivately_Validation_ManagedByCode(t *testi submitIntentCall := receivingPhone.deposit777KinIntoOrganizer(t) submitIntentCall.assertDeniedResponse(t, "at least one account is no longer managed by code") server.assertIntentNotSubmitted(t, submitIntentCall.intentId) - } server, _, receivingPhone, cleanup := setupTestEnv(t, &testOverrides{}) @@ -1522,7 +1518,6 @@ func TestSubmitIntent_ReceivePaymentsPrivately_Validation_Balances(t *testing.T) submitIntentCall := receivingPhone.deposit777KinIntoOrganizer(t) submitIntentCall.assertInvalidIntentResponse(t, "insufficient balance to perform action") server.assertIntentNotSubmitted(t, submitIntentCall.intentId) - } func TestSubmitIntent_ReceivePaymentsPrivately_Validation_Actions(t *testing.T) { @@ -1935,7 +1930,6 @@ func TestSubmitIntent_ReceivePaymentsPublicly_RemoteSend_ClaimGiftCardTwice(t *t submitIntentCall := receivingPhone.receive42KinFromGiftCard(t, giftCardAccount, false) submitIntentCall.assertStaleStateResponse(t, "gift card balance has already been claimed") server.assertIntentNotSubmitted(t, submitIntentCall.intentId) - } func TestSubmitIntent_ReceivePaymentsPublicly_RemoteSend_ExpiredGiftCard(t *testing.T) { @@ -1955,7 +1949,6 @@ func TestSubmitIntent_ReceivePaymentsPublicly_RemoteSend_ExpiredGiftCard(t *test submitIntentCall := receivingPhone.receive42KinFromGiftCard(t, giftCardAccount, false) submitIntentCall.assertStaleStateResponse(t, "gift card is expired") server.assertIntentNotSubmitted(t, submitIntentCall.intentId) - } func TestSubmitIntent_ReceivePaymentsPublicly_RemoteSend_NonIssuerAttepmtsToVoid(t *testing.T) { @@ -1975,7 +1968,6 @@ func TestSubmitIntent_ReceivePaymentsPublicly_RemoteSend_NonIssuerAttepmtsToVoid submitIntentCall := receivingPhone.receive42KinFromGiftCard(t, giftCardAccount, true) submitIntentCall.assertInvalidIntentResponse(t, "only the issuer can void the gift card") server.assertIntentNotSubmitted(t, submitIntentCall.intentId) - } func TestSubmitIntent_ReceivePaymentsPublicly_RemoteSend_ClaimInvalidGiftCardBalance(t *testing.T) { @@ -2001,7 +1993,6 @@ func TestSubmitIntent_ReceivePaymentsPublicly_RemoteSend_ClaimInvalidGiftCardBal submitIntentCall = receivingPhone.receive42KinFromGiftCard(t, giftCardAccount, false) submitIntentCall.assertInvalidIntentResponse(t, "must receive entire gift card balance of 4200000 quarks") server.assertIntentNotSubmitted(t, submitIntentCall.intentId) - } func TestSubmitIntent_ReceivePaymentsPublicly_AntispamGuard(t *testing.T) { @@ -2056,7 +2047,6 @@ func TestSubmitIntent_ReceivePaymentsPublicly_Validation_ManagedByCode(t *testin submitIntentCall := receivingPhone.receive42KinFromGiftCard(t, giftCardAccount, false) submitIntentCall.assertDeniedResponse(t, "at least one account is no longer managed by code") server.assertIntentNotSubmitted(t, submitIntentCall.intentId) - } } @@ -2086,7 +2076,6 @@ func TestSubmitIntent_ReceivePaymentsPublicly_Validation_ManagedByCode(t *testin submitIntentCall.assertDeniedResponse(t, "at least one account is no longer managed by code") } server.assertIntentNotSubmitted(t, submitIntentCall.intentId) - } } @@ -2107,7 +2096,6 @@ func TestSubmitIntent_ReceivePaymentsPublicly_Validation_Balances(t *testing.T) submitIntentCall := receivingPhone.receive42KinFromGiftCard(t, giftCardAccount, false) submitIntentCall.assertInvalidIntentResponse(t, "actions[0]: insufficient balance to perform action") server.assertIntentNotSubmitted(t, submitIntentCall.intentId) - } func TestSubmitIntent_ReceivePaymentsPublicly_Validation_OnlyRemoteSendSupported(t *testing.T) { @@ -2127,7 +2115,6 @@ func TestSubmitIntent_ReceivePaymentsPublicly_Validation_OnlyRemoteSendSupported submitIntentCall := receivingPhone.receive42KinFromGiftCard(t, giftCardAccount, false) submitIntentCall.assertGrpcError(t, codes.InvalidArgument) server.assertIntentNotSubmitted(t, submitIntentCall.intentId) - } func TestSubmitIntent_ReceivePaymentsPublicly_Validation_Actions(t *testing.T) { @@ -2609,7 +2596,6 @@ func TestSubmitIntent_SubmitIntentDisabled(t *testing.T) { submitIntentCall := phone.openAccounts(t) submitIntentCall.assertGrpcError(t, codes.Unavailable) server.assertIntentNotSubmitted(t, submitIntentCall.intentId) - } func TestSubmitIntent_ReuseIntentId(t *testing.T) { @@ -2649,7 +2635,6 @@ func TestSubmitIntent_NoAvailableNonces(t *testing.T) { submitIntentCall := phone.openAccounts(t) submitIntentCall.assertGrpcError(t, codes.Unavailable) server.assertIntentNotSubmitted(t, submitIntentCall.intentId) - } func TestSubmitIntent_NotPhoneVerified(t *testing.T) { @@ -2664,7 +2649,6 @@ func TestSubmitIntent_NotPhoneVerified(t *testing.T) { submitIntentCall := phone.openAccounts(t) submitIntentCall.assertDeniedResponse(t, "not phone verified") server.assertIntentNotSubmitted(t, submitIntentCall.intentId) - } func TestSubmitIntent_UnauthenticatedAccess(t *testing.T) { @@ -2681,7 +2665,6 @@ func TestSubmitIntent_UnauthenticatedAccess(t *testing.T) { submitIntentCall = phone.openAccounts(t) submitIntentCall.assertGrpcError(t, codes.Unauthenticated) server.assertIntentNotSubmitted(t, submitIntentCall.intentId) - } func TestSubmitIntent_InvalidActionId(t *testing.T) { @@ -2692,7 +2675,6 @@ func TestSubmitIntent_InvalidActionId(t *testing.T) { submitIntentCall := phone.openAccounts(t) submitIntentCall.assertGrpcError(t, codes.InvalidArgument) server.assertIntentNotSubmitted(t, submitIntentCall.intentId) - } func TestSubmitIntent_InvalidOpenAccountOwner(t *testing.T) { @@ -2710,7 +2692,6 @@ func TestSubmitIntent_InvalidOpenAccountOwner(t *testing.T) { submitIntentCall = phone.send42KinToGiftCardAccount(t, testutil.NewRandomAccount(t)) submitIntentCall.assertInvalidIntentResponse(t, "actions[0]: owner must be") server.assertIntentNotSubmitted(t, submitIntentCall.intentId) - } func TestSubmitIntent_InvalidSignatureValueSubmitted(t *testing.T) { @@ -2723,7 +2704,6 @@ func TestSubmitIntent_InvalidSignatureValueSubmitted(t *testing.T) { submitIntentCall := phone.openAccounts(t) submitIntentCall.assertInvalidSignatureValueResponse(t) server.assertIntentNotSubmitted(t, submitIntentCall.intentId) - } func TestSubmitIntent_InvalidNumberOfSignaturesSubmitted(t *testing.T) { @@ -2742,7 +2722,6 @@ func TestSubmitIntent_InvalidNumberOfSignaturesSubmitted(t *testing.T) { submitIntentCall = phone.openAccounts(t) submitIntentCall.assertSignatureErrorResponse(t, "at least one signature is missing") server.assertIntentNotSubmitted(t, submitIntentCall.intentId) - } func TestSubmitIntent_TimeBoundedRequestSend(t *testing.T) { @@ -2764,7 +2743,6 @@ func TestSubmitIntent_TimeBoundedRequestSend(t *testing.T) { submitIntentCall = phone.openAccounts(t) submitIntentCall.assertGrpcError(t, codes.DeadlineExceeded) server.assertIntentNotSubmitted(t, submitIntentCall.intentId) - } func TestSubmitIntent_TreasuryPoolUsage(t *testing.T) { diff --git a/pkg/code/server/grpc/transaction/v2/local_simulation_test.go b/pkg/code/server/grpc/transaction/v2/local_simulation_test.go index 9a52d26b..53541d08 100644 --- a/pkg/code/server/grpc/transaction/v2/local_simulation_test.go +++ b/pkg/code/server/grpc/transaction/v2/local_simulation_test.go @@ -640,7 +640,6 @@ func TestLocalSimulation_InvalidTimelockVault(t *testing.T) { require.Error(t, err) assert.True(t, strings.Contains(err.Error(), "token must be")) } - } type localSimulationTestEnv struct { @@ -662,7 +661,7 @@ func (env localSimulationTestEnv) setupTimelockRecord(t *testing.T, authority *c require.NoError(t, err) timelockRecord := timelockAccounts.ToDBRecord() timelockRecord.VaultState = state - timelockRecord.Block += 1 + timelockRecord.Block++ require.NoError(t, env.data.SaveTimelock(env.ctx, timelockRecord)) } diff --git a/pkg/code/server/grpc/transaction/v2/proof.go b/pkg/code/server/grpc/transaction/v2/proof.go index 5336b781..65e01ab3 100644 --- a/pkg/code/server/grpc/transaction/v2/proof.go +++ b/pkg/code/server/grpc/transaction/v2/proof.go @@ -42,7 +42,7 @@ type privacyUpgradeProof struct { newCommitmentRoot merkletree.Hash } -type privacyUpgradeCandidate struct { +type PrivacyUpgradeCandidate struct { newCommitmentRecord *commitment.Record forLeafHash merkletree.Hash @@ -66,7 +66,7 @@ func canUpgradeCommitmentAction(ctx context.Context, data code_data.Provider, co // Note: How we get select commmitments for proofs plays into how we decide to // close commitment vaults. Updates to logic should be in sync. -func selectCandidateForPrivacyUpgrade(ctx context.Context, data code_data.Provider, intentId string, actionId uint32) (*privacyUpgradeCandidate, error) { +func selectCandidateForPrivacyUpgrade(ctx context.Context, data code_data.Provider, intentId string, actionId uint32) (*PrivacyUpgradeCandidate, error) { actionRecord, err := data.GetActionById(ctx, intentId, actionId) if err != nil { return nil, err @@ -167,7 +167,7 @@ func selectCandidateForPrivacyUpgrade(ctx context.Context, data code_data.Provid return nil, ErrWaitForNextBlock } - return &privacyUpgradeCandidate{ + return &PrivacyUpgradeCandidate{ newCommitmentRecord: latestCommitmentRecord, forLeafHash: commitmentAddressBytes, @@ -180,7 +180,7 @@ func selectCandidateForPrivacyUpgrade(ctx context.Context, data code_data.Provid // Note: How we get proofs plays into how we decide to close commitment vaults. Updates to // logic should be in sync. -func getProofForPrivacyUpgrade(ctx context.Context, data code_data.Provider, upgradingTo *privacyUpgradeCandidate) (*privacyUpgradeProof, error) { +func getProofForPrivacyUpgrade(ctx context.Context, data code_data.Provider, upgradingTo *PrivacyUpgradeCandidate) (*privacyUpgradeProof, error) { merkleTree, err := getCachedMerkleTreeForTreasury(ctx, data, upgradingTo.newCommitmentRecord.Pool) if err != nil { return nil, err diff --git a/pkg/code/server/grpc/transaction/v2/swap.go b/pkg/code/server/grpc/transaction/v2/swap.go index 26309beb..fda5f533 100644 --- a/pkg/code/server/grpc/transaction/v2/swap.go +++ b/pkg/code/server/grpc/transaction/v2/swap.go @@ -5,6 +5,7 @@ import ( "context" "crypto/ed25519" "encoding/base64" + "fmt" "time" "github.com/mr-tron/base58" @@ -357,7 +358,9 @@ func (s *transactionServer) Swap(streamer transactionpb.Transaction_SwapServer) } copy(txn.Signatures[clientSignatureIndex][:], submitSignatureReq.Signature.Value) - txn.Sign(s.swapSubsidizer.PrivateKey().ToBytes()) + if err := txn.Sign(s.swapSubsidizer.PrivateKey().ToBytes()); err != nil { + return fmt.Errorf("failed to sign transaction with swap subsidizer: %w", err) + } log = log.WithField("transaction_id", base58.Encode(txn.Signature())) @@ -532,7 +535,7 @@ func (s *transactionServer) bestEffortNotifyUserOfSwapInProgress(ctx context.Con } if canPush { - push_util.SendChatMessagePushNotification( + pushErr := push_util.SendChatMessagePushNotification( ctx, s.data, s.pusher, @@ -540,6 +543,12 @@ func (s *transactionServer) bestEffortNotifyUserOfSwapInProgress(ctx context.Con owner, chatMessage, ) + if pushErr != nil { + s.log. + WithField("method", "bestEffortNotifyUserOfSwapInProgress"). + WithError(err). + Warn("failed to push chat message notification (best effort)") + } } return nil diff --git a/pkg/code/server/grpc/transaction/v2/testutil.go b/pkg/code/server/grpc/transaction/v2/testutil.go index 5b2826e9..69e43b93 100644 --- a/pkg/code/server/grpc/transaction/v2/testutil.go +++ b/pkg/code/server/grpc/transaction/v2/testutil.go @@ -4,12 +4,13 @@ import ( "bytes" "context" "crypto/ed25519" + "crypto/rand" "crypto/sha256" "encoding/hex" "fmt" "io" "math" - "math/rand" + mrand "math/rand" "strconv" "strings" "testing" @@ -21,6 +22,7 @@ import ( "github.com/stretchr/testify/require" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" @@ -134,7 +136,7 @@ func setupTestEnv(t *testing.T, serverOverrides *testOverrides) (serverTestEnv, SolanaBlock: 123, - State: treasury.TreasuryPoolStateAvailable, + State: treasury.PoolStateAvailable, } serverEnv.treasuryPoolByAddress[treasuryPoolRecord.Address] = treasuryPoolRecord serverEnv.treasuryPoolByBucket[bucket] = treasuryPoolRecord @@ -201,7 +203,11 @@ func setupTestEnv(t *testing.T, serverOverrides *testOverrides) (serverTestEnv, var phoneEnvs []phoneTestEnv for i := 0; i < 2; i++ { // Force iOS user agent to pass airdrop tests - iosGrpcClientConn, err := grpc.Dial(grpcClientConn.Target(), grpc.WithInsecure(), grpc.WithUserAgent("Code/iOS/1.0.0")) + iosGrpcClientConn, err := grpc.Dial( + grpcClientConn.Target(), + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithUserAgent("Code/iOS/1.0.0"), + ) require.NoError(t, err) phoneEnv := phoneTestEnv{ @@ -238,7 +244,7 @@ func setupTestEnv(t *testing.T, serverOverrides *testOverrides) (serverTestEnv, require.NoError(t, serverEnv.data.SavePhoneVerification(serverEnv.ctx, verificationRecord)) userIdentityRecord := &user_identity.Record{ - ID: user.NewUserID(), + ID: user.NewID(), View: &user.View{ PhoneNumber: &phoneEnv.verifiedPhoneNumber, }, @@ -262,7 +268,7 @@ func setupTestEnv(t *testing.T, serverOverrides *testOverrides) (serverTestEnv, require.NoError(t, err) timelockRecord := legacyTimelockAccounts.ToDBRecord() timelockRecord.VaultState = timelock_token_v1.StateLocked - timelockRecord.Block += 1 + timelockRecord.Block++ require.NoError(t, serverEnv.data.SaveTimelock(serverEnv.ctx, timelockRecord)) // Simulate a swap account being created @@ -351,7 +357,7 @@ func (s *serverTestEnv) fundAccount(t *testing.T, account *common.Account, quark Rendezvous: "", IsExternal: true, - TransactionId: fmt.Sprintf("txn%d", rand.Uint64()), + TransactionId: fmt.Sprintf("txn%d", mrand.Uint64()), ConfirmationState: transaction.ConfirmationFinalized, @@ -366,7 +372,7 @@ func (s *serverTestEnv) fundAccount(t *testing.T, account *common.Account, quark require.NoError(t, s.data.CreatePayment(s.ctx, paymentRecord)) depositRecord := &deposit.Record{ - Signature: fmt.Sprintf("txn%d", rand.Uint64()), + Signature: fmt.Sprintf("txn%d", mrand.Uint64()), Destination: account.PublicKey().ToBase58(), Amount: quarks, UsdMarketValue: 0.1 * float64(quarks) / float64(kin.QuarksPerKin), @@ -490,7 +496,8 @@ func (s *serverTestEnv) generateAvailableNonce(t *testing.T) *nonce.Record { nonceAccount := testutil.NewRandomAccount(t) var bh solana.Blockhash - rand.Read(bh[:]) + _, err := rand.Read(bh[:]) + require.NoError(t, err) nonceKey := &vault.Record{ PublicKey: nonceAccount.PublicKey().ToBase58(), @@ -533,7 +540,7 @@ func (s *serverTestEnv) simulateTimelockAccountInState(t *testing.T, vault *comm require.NoError(t, err) timelockRecord.VaultState = state - timelockRecord.Block += 1 + timelockRecord.Block++ require.NoError(t, s.data.SaveTimelock(s.ctx, timelockRecord)) } @@ -1086,28 +1093,17 @@ func (s *serverTestEnv) setupAirdropper(t *testing.T, initialFunds uint64) *comm return owner } -func (s *serverTestEnv) assertAirdroppedFirstKin(t *testing.T, phone phoneTestEnv) { +func (s serverTestEnv) assertAirdroppedFirstKin(t *testing.T, phone phoneTestEnv) { airdropIntentId := GetNewAirdropIntentId(AirdropTypeGetFirstKin, phone.parentAccount.PublicKey().ToBase58()) s.assertAirdropped(t, phone, AirdropTypeGetFirstKin, airdropIntentId, 1.0) } -func (s *serverTestEnv) assertNotAirdroppedFirstKin(t *testing.T, phone phoneTestEnv) { +func (s serverTestEnv) assertNotAirdroppedFirstKin(t *testing.T, phone phoneTestEnv) { airdropIntentId := GetNewAirdropIntentId(AirdropTypeGetFirstKin, phone.parentAccount.PublicKey().ToBase58()) _, err := s.data.GetIntent(s.ctx, airdropIntentId) assert.Equal(t, intent.ErrIntentNotFound, err) } -func (s *serverTestEnv) assertAirdroppedForGivingFirstKin(t *testing.T, phone phoneTestEnv, intentId string) { - airdropIntentId := GetNewAirdropIntentId(AirdropTypeGiveFirstKin, intentId) - s.assertAirdropped(t, phone, AirdropTypeGiveFirstKin, airdropIntentId, 5.0) -} - -func (s *serverTestEnv) assertNotAirdroppedForGivingFirstKin(t *testing.T, intentId string) { - airdropIntentId := GetNewAirdropIntentId(AirdropTypeGiveFirstKin, intentId) - _, err := s.data.GetIntent(s.ctx, airdropIntentId) - assert.Equal(t, intent.ErrIntentNotFound, err) -} - func (s serverTestEnv) assertAirdropped(t *testing.T, phone phoneTestEnv, airdropType AirdropType, intentId string, usdValue float64) { airdropIntentRecord, err := s.data.GetIntent(s.ctx, intentId) require.NoError(t, err) @@ -1194,12 +1190,6 @@ func (s serverTestEnv) assertAirdropped(t *testing.T, phone phoneTestEnv, airdro } } -func (s serverTestEnv) assertNoNoncesReserved(t *testing.T) { - count, err := s.data.GetNonceCountByState(s.ctx, nonce.StateReserved) - require.NoError(t, err) - assert.EqualValues(t, 0, count) -} - func (s serverTestEnv) assertNoNoncesReservedForIntent(t *testing.T, intentId string) { var cursor query.Cursor for { @@ -2173,7 +2163,6 @@ func (p *phoneTestEnv) openAccounts(t *testing.T) submitIntentCallMetadata { } func (p *phoneTestEnv) send42KinToGiftCardAccount(t *testing.T, giftCardAccount *common.Account) submitIntentCallMetadata { - // Generate a new random gift card account (no derivation logic, index, etc...) p.allGiftCardAccounts = append(p.allGiftCardAccounts, giftCardAccount) @@ -3803,9 +3792,9 @@ func (p *phoneTestEnv) submitIntent(t *testing.T, intentId string, metadata *tra isPrivateTransferIntent = true if p.conf.simulateSendingTooLittle { - typed.SendPrivatePayment.ExchangeData.Quarks += 1 + typed.SendPrivatePayment.ExchangeData.Quarks++ } else if p.conf.simulateSendingTooMuch { - typed.SendPrivatePayment.ExchangeData.Quarks -= 1 + typed.SendPrivatePayment.ExchangeData.Quarks-- } if p.conf.simulateInvalidExchangeRate { @@ -4149,9 +4138,9 @@ func (p *phoneTestEnv) submitIntent(t *testing.T, intentId string, metadata *tra isPrivateTransferIntent = true if p.conf.simulateReceivingTooLittle { - typed.ReceivePaymentsPrivately.Quarks += 1 + typed.ReceivePaymentsPrivately.Quarks++ } else if p.conf.simulateReceivingTooMuch { - typed.ReceivePaymentsPrivately.Quarks -= 1 + typed.ReceivePaymentsPrivately.Quarks-- } if p.conf.simulateFundingTempAccountTooMuch { @@ -4324,7 +4313,7 @@ func (p *phoneTestEnv) submitIntent(t *testing.T, intentId string, metadata *tra for _, action := range actions { switch typed := action.Type.(type) { case *transactionpb.Action_NoPrivacyWithdraw: - typed.NoPrivacyWithdraw.Amount += 1 + typed.NoPrivacyWithdraw.Amount++ } } } @@ -4384,9 +4373,9 @@ func (p *phoneTestEnv) submitIntent(t *testing.T, intentId string, metadata *tra } case *transactionpb.Metadata_SendPublicPayment: if p.conf.simulateSendingTooLittle { - typed.SendPublicPayment.ExchangeData.Quarks += 1 + typed.SendPublicPayment.ExchangeData.Quarks++ } else if p.conf.simulateSendingTooMuch { - typed.SendPublicPayment.ExchangeData.Quarks -= 1 + typed.SendPublicPayment.ExchangeData.Quarks-- } if p.conf.simulateInvalidExchangeRate { @@ -4458,15 +4447,15 @@ func (p *phoneTestEnv) submitIntent(t *testing.T, intentId string, metadata *tra } if typed.ReceivePaymentsPublicly.IsRemoteSend && p.conf.simulateClaimingTooLittleFromGiftCard { - typed.ReceivePaymentsPublicly.Quarks += 1 + typed.ReceivePaymentsPublicly.Quarks++ } else if typed.ReceivePaymentsPublicly.IsRemoteSend && p.conf.simulateClaimingTooMuchFromGiftCard { - typed.ReceivePaymentsPublicly.Quarks -= 1 + typed.ReceivePaymentsPublicly.Quarks-- } if p.conf.simulateReceivingTooLittle { - actions[0].GetNoPrivacyWithdraw().Amount -= 1 + actions[0].GetNoPrivacyWithdraw().Amount-- } else if p.conf.simulateReceivingTooMuch { - actions[0].GetNoPrivacyWithdraw().Amount += 1 + actions[0].GetNoPrivacyWithdraw().Amount++ } if p.conf.simulateNotReceivingFromSource { @@ -4586,7 +4575,7 @@ func (p *phoneTestEnv) submitIntent(t *testing.T, intentId string, metadata *tra } if p.conf.simulateInvalidIndexForOpenPrimaryAccountAction && typed.OpenAccount.AccountType == commonpb.AccountType_PRIMARY { - typed.OpenAccount.Index += 1 + typed.OpenAccount.Index++ } if p.conf.simulateInvalidAccountTypeForOpenNonPrimaryAccountAction && typed.OpenAccount.AccountType != commonpb.AccountType_PRIMARY { @@ -4609,7 +4598,7 @@ func (p *phoneTestEnv) submitIntent(t *testing.T, intentId string, metadata *tra } if p.conf.simulateInvalidIndexForOpenNonPrimaryAccountAction && typed.OpenAccount.AccountType != commonpb.AccountType_PRIMARY { - typed.OpenAccount.Index += 1 + typed.OpenAccount.Index++ } if p.conf.simulateOpeningWrongTempAccount && typed.OpenAccount.AccountType == commonpb.AccountType_TEMPORARY_INCOMING { @@ -4927,8 +4916,8 @@ func (p *phoneTestEnv) submitIntent(t *testing.T, intentId string, metadata *tra if p.conf.simulateInvalidActionId { require.True(t, len(actions) >= 2) - actions[0].Id += 1 - actions[1].Id -= 1 + actions[0].Id++ + actions[1].Id-- } if p.conf.simulateReusingIntentId { @@ -4980,7 +4969,7 @@ func (p *phoneTestEnv) submitIntent(t *testing.T, intentId string, metadata *tra bps := defaultTestThirdPartyFeeBps if p.conf.simulateInvalidThirdPartyFeeAmount { - bps += 1 + bps++ } paymentRequestRecord.Fees = append(paymentRequestRecord.Fees, &paymentrequest.Fee{ @@ -5279,7 +5268,7 @@ func (p *phoneTestEnv) submitIntent(t *testing.T, intentId string, metadata *tra } if p.conf.simulateInvalidSignatureValueSubmitted { - protoSignatures[0].Value[0] += 1 + protoSignatures[0].Value[0]++ } if p.conf.simulateDelayForSubmittingSignatures { @@ -6063,7 +6052,7 @@ func (p *phoneTestEnv) assertAirdropCount(t *testing.T, expected int) { var actual int for _, historyItem := range history { if historyItem.IsAirdrop { - actual += 1 + actual++ } } assert.Equal(t, expected, actual) @@ -6152,12 +6141,6 @@ func (m submitIntentCallMetadata) assertGrpcError(t *testing.T, code codes.Code) testutil.AssertStatusErrorWithCode(t, m.err, code) } -func (m submitIntentCallMetadata) assertGrpcErrorWithMessage(t *testing.T, code codes.Code, message string) { - m.assertGrpcError(t, code) - - assert.True(t, strings.Contains(m.err.Error(), message)) -} - func (m submitIntentCallMetadata) isError(t *testing.T) bool { return m.err != nil || m.resp.GetError() != nil } diff --git a/pkg/code/server/grpc/transaction/v2/treasury.go b/pkg/code/server/grpc/transaction/v2/treasury.go index 99ad3c0d..88373435 100644 --- a/pkg/code/server/grpc/transaction/v2/treasury.go +++ b/pkg/code/server/grpc/transaction/v2/treasury.go @@ -8,10 +8,10 @@ import ( "github.com/sirupsen/logrus" - "github.com/code-payments/code-server/pkg/kin" "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" "github.com/code-payments/code-server/pkg/code/data/treasury" + "github.com/code-payments/code-server/pkg/kin" ) // todo: any other treasury-related things we can put here? @@ -126,7 +126,9 @@ func (s *transactionServer) treasuryPoolMonitor(ctx context.Context, name string defer cancel() start := time.Now() - defer time.Sleep(s.conf.treasuryPoolStatsRefreshInterval.Get(ctx) - time.Since(start)) + defer func() { + time.Sleep(s.conf.treasuryPoolStatsRefreshInterval.Get(ctx) - time.Since(start)) + }() if treasuryPoolRecord == nil { treasuryPoolRecord, err = s.data.GetTreasuryPoolByName(ctx, name) diff --git a/pkg/code/server/grpc/user/server.go b/pkg/code/server/grpc/user/server.go index 00176fa9..ec3519b0 100644 --- a/pkg/code/server/grpc/user/server.go +++ b/pkg/code/server/grpc/user/server.go @@ -98,7 +98,7 @@ func (s *identityServer) LinkAccount(ctx context.Context, req *userpb.LinkAccoun } var result userpb.LinkAccountResponse_Result - var userID *user.UserID + var userID *user.Id var dataContainerID *user.DataContainerID var metadata *userpb.PhoneMetadata @@ -156,7 +156,7 @@ func (s *identityServer) LinkAccount(ctx context.Context, req *userpb.LinkAccoun } newUser := identity.Record{ - ID: user.NewUserID(), + ID: user.NewID(), View: &user.View{ PhoneNumber: &token.Phone.PhoneNumber.Value, }, @@ -302,7 +302,7 @@ func (s *identityServer) GetUser(ctx context.Context, req *userpb.GetUserRequest } var result userpb.GetUserResponse_Result - var userID *user.UserID + var userID *user.Id var isStaff bool var dataContainerID *user.DataContainerID var metadata *userpb.PhoneMetadata @@ -717,7 +717,6 @@ func (s *identityServer) GetTwitterUser(ctx context.Context, req *userpb.GetTwit log.WithError(err).Warn("failure getting twitter user info") return nil, status.Error(codes.Internal, "") } - } func (s *identityServer) markWebhookAsPending(ctx context.Context, id string) error { diff --git a/pkg/code/server/grpc/user/server_test.go b/pkg/code/server/grpc/user/server_test.go index 2319aa38..8632614c 100644 --- a/pkg/code/server/grpc/user/server_test.go +++ b/pkg/code/server/grpc/user/server_test.go @@ -7,7 +7,6 @@ import ( "testing" "time" - "github.com/golang/protobuf/proto" "github.com/mr-tron/base58" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -15,6 +14,8 @@ import ( xrate "golang.org/x/time/rate" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/protobuf/proto" commonpb "github.com/code-payments/code-protobuf-api/generated/go/common/v1" messagingpb "github.com/code-payments/code-protobuf-api/generated/go/messaging/v1" @@ -59,7 +60,11 @@ func setup(t *testing.T) (env testEnv, cleanup func()) { require.NoError(t, err) // Force iOS user agent to pass airdrop tests - iosConn, err := grpc.Dial(conn.Target(), grpc.WithInsecure(), grpc.WithUserAgent("Code/iOS/1.0.0")) + iosConn, err := grpc.Dial( + conn.Target(), + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithUserAgent("Code/iOS/1.0.0"), + ) require.NoError(t, err) env.ctx = context.Background() @@ -226,7 +231,7 @@ func TestLinkAccount_UserAlreadyExists(t *testing.T) { })) userRecord := &identity.Record{ - ID: user.NewUserID(), + ID: user.NewID(), View: &user.View{ PhoneNumber: &phoneNumber, }, @@ -331,7 +336,7 @@ func TestUnlinkAccount_PhoneNeverAssociated(t *testing.T) { invalidPhoneNumber := "+18005550000" userRecord := &identity.Record{ - ID: user.NewUserID(), + ID: user.NewID(), View: &user.View{ PhoneNumber: &validPhoneNumber, }, @@ -439,7 +444,7 @@ func TestGetUser_UnlockedTimelockAccount(t *testing.T) { phoneNumber := "+12223334444" userRecord := &identity.Record{ - ID: user.NewUserID(), + ID: user.NewID(), View: &user.View{ PhoneNumber: &phoneNumber, }, @@ -501,7 +506,7 @@ func TestGetUser_UnlockedTimelockAccount(t *testing.T) { assert.Equal(t, userpb.GetUserResponse_OK, resp.Result) timelockRecord.VaultState = timelock_token.StateUnlocked - timelockRecord.Block += 1 + timelockRecord.Block++ require.NoError(t, env.data.SaveTimelock(env.ctx, timelockRecord)) resp, err = env.client.GetUser(env.ctx, req) @@ -524,7 +529,7 @@ func TestGetUser_LinkStatus(t *testing.T) { for _, phoneNumber := range phoneNumbers { userRecord := &identity.Record{ - ID: user.NewUserID(), + ID: user.NewID(), View: &user.View{ PhoneNumber: &phoneNumber, }, @@ -657,7 +662,7 @@ func TestGetUser_FeatureFlags(t *testing.T) { phoneNumber := "+12223334444" userRecord := &identity.Record{ - ID: user.NewUserID(), + ID: user.NewID(), View: &user.View{ PhoneNumber: &phoneNumber, }, @@ -719,7 +724,7 @@ func TestGetUser_AirdropStatus(t *testing.T) { phoneNumber := "+12223334444" userRecord := &identity.Record{ - ID: user.NewUserID(), + ID: user.NewID(), View: &user.View{ PhoneNumber: &phoneNumber, }, @@ -1271,7 +1276,6 @@ func TestGetLoginForThirdPartyApp_HappyPath(t *testing.T) { {paymentRequestRecord, paymentIntentRecord}, {loginRequestRecord, loginIntentRecord}, } { - env, cleanup := setup(t) defer cleanup() diff --git a/pkg/code/server/web/request/server.go b/pkg/code/server/web/request/server.go index 95730471..f2158e32 100644 --- a/pkg/code/server/web/request/server.go +++ b/pkg/code/server/web/request/server.go @@ -60,7 +60,9 @@ func (s *Server) createIntentHandler(path string) func(w http.ResponseWriter, r w.Header().Set(contentTypeHeaderName, jsonContentTypeHeaderValue) w.WriteHeader(statusCode) - w.Write([]byte(body.ToString())) + if _, err := w.Write([]byte(body.ToString())); err != nil { + log.WithError(err).Info("failed to write body") + } } } @@ -101,7 +103,9 @@ func (s *Server) getStatusHandler(path string) func(w http.ResponseWriter, r *ht w.Header().Set(contentTypeHeaderName, jsonContentTypeHeaderValue) w.WriteHeader(statusCode) - w.Write([]byte(body.ToString())) + if _, err := w.Write([]byte(body.ToString())); err != nil { + log.WithError(err).Warn("failed to write body") + } } } @@ -137,7 +141,9 @@ func (s *Server) getUserIdHandler(path string) func(w http.ResponseWriter, r *ht w.Header().Set(contentTypeHeaderName, jsonContentTypeHeaderValue) w.WriteHeader(statusCode) - w.Write([]byte(body.ToString())) + if _, err := w.Write([]byte(body.ToString())); err != nil { + log.WithError(err).Warn("failed to write body") + } } } diff --git a/pkg/code/thirdparty/message.go b/pkg/code/thirdparty/message.go index 03ed2aee..964db750 100644 --- a/pkg/code/thirdparty/message.go +++ b/pkg/code/thirdparty/message.go @@ -4,6 +4,7 @@ import ( "crypto/ed25519" "crypto/rand" "encoding/binary" + "fmt" "strings" "time" @@ -72,7 +73,10 @@ func NewNaclBoxBlockchainMessage( return nil, errors.New("sender account private key unavailable") } - encryptedMessage, nonce := encryptMessageUsingNaclBox(sender, receiver, plaintextMessage) + encryptedMessage, nonce, err := encryptMessageUsingNaclBox(sender, receiver, plaintextMessage) + if err != nil { + return nil, fmt.Errorf("failed to encrypt with nacl box: %w", err) + } if len(encryptedMessage)+len(senderDomain) > maxNaclBoxDynamicContentSize { return nil, errors.New("encrypted message length exceeds limit") @@ -223,10 +227,12 @@ func DecodeNaclBoxBlockchainMessage(payload []byte) (*NaclBoxBlockchainMessage, }, nil } -func encryptMessageUsingNaclBox(sender, receiver *common.Account, plaintextMessage string) ([]byte, naclBoxNonce) { +func encryptMessageUsingNaclBox(sender, receiver *common.Account, plaintextMessage string) ([]byte, naclBoxNonce, error) { var nonce naclBoxNonce - rand.Read(nonce[:]) - return encryptMessageUsingNaclBoxWithProvidedNonce(sender, receiver, plaintextMessage, nonce), nonce + if _, err := rand.Read(nonce[:]); err != nil { + return nil, nonce, err + } + return encryptMessageUsingNaclBoxWithProvidedNonce(sender, receiver, plaintextMessage, nonce), nonce, nil } // Nonce should always be random. Use encryptMessageUsingNaclBox, unless testing diff --git a/pkg/code/transaction/nonce.go b/pkg/code/transaction/nonce.go index 1f6fc7f3..e041afaf 100644 --- a/pkg/code/transaction/nonce.go +++ b/pkg/code/transaction/nonce.go @@ -8,16 +8,17 @@ import ( "github.com/mr-tron/base58" - "github.com/code-payments/code-server/pkg/retry" - "github.com/code-payments/code-server/pkg/solana" "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" "github.com/code-payments/code-server/pkg/code/data/fulfillment" "github.com/code-payments/code-server/pkg/code/data/nonce" + "github.com/code-payments/code-server/pkg/retry" + "github.com/code-payments/code-server/pkg/solana" ) var ( ErrNoAvailableNonces = errors.New("no available nonces") + ErrUnlocked = errors.New("nonce is unlocked") ) var ( @@ -66,7 +67,7 @@ func SelectAvailableNonce(ctx context.Context, data code_data.Provider, useCase defer globalNonceLock.Unlock() randomRecord, err := data.GetRandomAvailableNonceByPurpose(ctx, useCase) - if err == nonce.ErrNonceNotFound { + if errors.Is(err, nonce.ErrNonceNotFound) { return ErrNoAvailableNonces } else if err != nil { return err @@ -257,7 +258,7 @@ func (n *SelectedNonce) ReleaseIfNotReserved() error { defer n.localLock.Unlock() if n.isUnlocked { - return errors.New("nonce is unlocked") + return ErrUnlocked } if n.record.State == nonce.StateAvailable { diff --git a/pkg/code/transaction/nonce_test.go b/pkg/code/transaction/nonce_test.go index ca8d7f28..ffb6e860 100644 --- a/pkg/code/transaction/nonce_test.go +++ b/pkg/code/transaction/nonce_test.go @@ -10,14 +10,14 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/code-payments/code-server/pkg/pointer" - "github.com/code-payments/code-server/pkg/solana" - "github.com/code-payments/code-server/pkg/testutil" "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" "github.com/code-payments/code-server/pkg/code/data/fulfillment" "github.com/code-payments/code-server/pkg/code/data/nonce" "github.com/code-payments/code-server/pkg/code/data/vault" + "github.com/code-payments/code-server/pkg/pointer" + "github.com/code-payments/code-server/pkg/solana" + "github.com/code-payments/code-server/pkg/testutil" ) func TestNonce_SelectAvailableNonce(t *testing.T) { @@ -226,7 +226,8 @@ func generateAvailableNonce(t *testing.T, env nonceTestEnv, useCase nonce.Purpos nonceAccount := testutil.NewRandomAccount(t) var bh solana.Blockhash - rand.Read(bh[:]) + _, err := rand.Read(bh[:]) + require.NoError(t, err) nonceKey := &vault.Record{ PublicKey: nonceAccount.PublicKey().ToBase58(), diff --git a/pkg/code/webhook/execution.go b/pkg/code/webhook/execution.go index 7573e7f6..2364e4b1 100644 --- a/pkg/code/webhook/execution.go +++ b/pkg/code/webhook/execution.go @@ -13,11 +13,11 @@ import ( messagingpb "github.com/code-payments/code-protobuf-api/generated/go/messaging/v1" - "github.com/code-payments/code-server/pkg/metrics" "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" "github.com/code-payments/code-server/pkg/code/data/webhook" "github.com/code-payments/code-server/pkg/code/server/grpc/messaging" + "github.com/code-payments/code-server/pkg/metrics" ) const ( @@ -95,7 +95,10 @@ func Execute( resp, err := http.DefaultClient.Do(webhookReq) if err != nil { return errors.Wrap(err, "error executing http post request") - } else if resp.StatusCode != http.StatusOK { + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { return errors.Errorf("%d status code returned", resp.StatusCode) } diff --git a/pkg/config/wrapper/wrappers.go b/pkg/config/wrapper/wrappers.go index f64ca561..c47e85fb 100644 --- a/pkg/config/wrapper/wrappers.go +++ b/pkg/config/wrapper/wrappers.go @@ -437,10 +437,10 @@ func (c *DurationConfig) GetSafe(ctx context.Context) (time.Duration, error) { } else if err != nil { return lastValue, err } - switch override.(type) { + switch override := override.(type) { case []byte: var newValue time.Duration - strValue := string(override.([]byte)) + strValue := string(override) newValue, err = time.ParseDuration(strValue) if err != nil { @@ -451,7 +451,7 @@ func (c *DurationConfig) GetSafe(ctx context.Context) (time.Duration, error) { c.stateMu.Unlock() return newValue, nil case time.Duration: - newValue := override.(time.Duration) + newValue := override c.stateMu.Lock() c.lastValue = newValue c.stateMu.Unlock() diff --git a/pkg/currency/coingecko/client.go b/pkg/currency/coingecko/client.go index 99171cf0..1b674973 100644 --- a/pkg/currency/coingecko/client.go +++ b/pkg/currency/coingecko/client.go @@ -89,7 +89,9 @@ func (c *client) submitRequest(ctx context.Context, url string, body io.Reader, var httpResp *http.Response _, err = c.retrier.Retry( func() error { - httpResp, err = c.httpClient.Do(req) + // Retry only occurs if err != nil, in which case the body does not need to be closed. + // The body itself is closed below + httpResp, err = c.httpClient.Do(req) //nolint:bodyclose return err }, ) diff --git a/pkg/currency/fixer/client.go b/pkg/currency/fixer/client.go index d0220e7f..1ac3f3a8 100644 --- a/pkg/currency/fixer/client.go +++ b/pkg/currency/fixer/client.go @@ -113,7 +113,9 @@ func (c *client) submitRequest(ctx context.Context, url string, resp interface{} var httpResp *http.Response _, err = c.retrier.Retry( func() error { - httpResp, err = c.httpClient.Do(req) + // Retry only occurs if err != nil, in which case the body does not need to be closed. + // The body itself is closed below + httpResp, err = c.httpClient.Do(req) //nolint:bodyclose return err }, ) diff --git a/pkg/database/postgres/db.go b/pkg/database/postgres/db.go index 19d689a4..37800c1b 100644 --- a/pkg/database/postgres/db.go +++ b/pkg/database/postgres/db.go @@ -4,16 +4,15 @@ import ( "context" "database/sql" "errors" + "fmt" "github.com/jackc/pgconn" "github.com/jackc/pgerrcode" "github.com/jmoiron/sqlx" ) -const ( - txStructContextKey = "code-sqlx-tx-struct" - txIsolationContextKey = "code-sqlx-isolation" -) +type txStructContextKey struct{} +type txIsolationContextKey struct{} var ( ErrAlreadyInTx = errors.New("already executing in existing db tx") @@ -41,7 +40,7 @@ func ExecuteTxWithinCtx(ctx context.Context, db *sqlx.DB, isolation sql.Isolatio isolation = sql.LevelReadCommitted // Postgres default } - existing := ctx.Value(txStructContextKey) + existing := ctx.Value(txStructContextKey{}) if existing != nil { return ErrAlreadyInTx } @@ -53,13 +52,16 @@ func ExecuteTxWithinCtx(ctx context.Context, db *sqlx.DB, isolation sql.Isolatio return err } - ctx = context.WithValue(ctx, txStructContextKey, tx) - ctx = context.WithValue(ctx, txIsolationContextKey, isolation) + ctx = context.WithValue(ctx, txStructContextKey{}, tx) + ctx = context.WithValue(ctx, txIsolationContextKey{}, isolation) err = fn(ctx) if err != nil { // We always need to execute a Rollback() so sql.DB releases the connection. - tx.Rollback() + if rollbackErr := tx.Rollback(); rollbackErr != nil { + return fmt.Errorf("failed to rollback transaction: %w", rollbackErr) + } + return err } return tx.Commit() @@ -94,7 +96,9 @@ func ExecuteInTx(ctx context.Context, db *sqlx.DB, isolation sql.IsolationLevel, if err != nil { if startedNewTx { // We always need to execute a Rollback() so sql.DB releases the connection. - tx.Rollback() + if rollBackErr := tx.Rollback(); rollBackErr != nil { + return fmt.Errorf("failed to rollback transaction: %w", rollBackErr) + } } return err } @@ -105,12 +109,12 @@ func ExecuteInTx(ctx context.Context, db *sqlx.DB, isolation sql.IsolationLevel, } func getTxFromCtx(ctx context.Context, desiredIsolation sql.IsolationLevel) (*sqlx.Tx, error) { - txFromCtx := ctx.Value(txStructContextKey) + txFromCtx := ctx.Value(txStructContextKey{}) if txFromCtx == nil { return nil, ErrNotInTx } - isolationFromCtx := ctx.Value(txIsolationContextKey) + isolationFromCtx := ctx.Value(txIsolationContextKey{}) if isolationFromCtx == nil { return nil, errors.New("unexpectedly don't have isolation level set") } diff --git a/pkg/database/postgres/test/util.go b/pkg/database/postgres/test/util.go index 1880ccc9..7d220519 100644 --- a/pkg/database/postgres/test/util.go +++ b/pkg/database/postgres/test/util.go @@ -10,7 +10,7 @@ import ( "github.com/pkg/errors" - _ "github.com/jackc/pgx/v4/stdlib" + _ "github.com/jackc/pgx/v4/stdlib" //nolint:revive "github.com/code-payments/code-server/pkg/retry" "github.com/code-payments/code-server/pkg/retry/backoff" @@ -19,7 +19,7 @@ import ( const ( containerName = "postgres" containerVersion = "10.4" - containerAutoKill = 120 // seconds + containerAutoKill = 120 * time.Second port = 5432 user = "localtest" @@ -82,8 +82,12 @@ func StartPostgresDB(pool *dockertest.Pool) (db *sql.DB, closeFunc func(), err e // logrus.StandardLogger().Println("Connecting to database on url: ", databaseUrl) // logrus.StandardLogger().Println("Setting container auto-kill to: ", containerAutoKill, " seconds") + // Tell docker to expire the container (kill) after containerAutoKill (120 sec). + // // You may need to adjust this number if it is too low for your test environment. - resource.Expire(containerAutoKill) // Tell docker to hard kill the container in 120 seconds + // + // 2024/04/11: Expire() _never_ returns an error. + _ = resource.Expire(uint(containerAutoKill.Seconds())) _, err = retry.Retry( func() error { diff --git a/pkg/database/query/cursor.go b/pkg/database/query/cursor.go index 533ef49f..08e29085 100644 --- a/pkg/database/query/cursor.go +++ b/pkg/database/query/cursor.go @@ -9,7 +9,7 @@ import ( type Cursor []byte var ( - EmptyCursor Cursor = Cursor([]byte{}) + EmptyCursor = Cursor([]byte{}) ) func ToCursor(val uint64) Cursor { diff --git a/pkg/database/query/query.go b/pkg/database/query/query.go index cfaf9112..165faf8a 100644 --- a/pkg/database/query/query.go +++ b/pkg/database/query/query.go @@ -21,7 +21,7 @@ const ( CanFilterBy SupportedOptions = 0x01 << 6 ) -type QueryOptions struct { +type Options struct { Supported SupportedOptions Start time.Time @@ -34,13 +34,13 @@ type QueryOptions struct { FilterBy Filter } -type Option func(*QueryOptions) error +type Option func(*Options) error -func (qo *QueryOptions) check(cap SupportedOptions) bool { +func (qo *Options) check(cap SupportedOptions) bool { return qo.Supported&cap != cap } -func (qo *QueryOptions) Apply(opts ...Option) error { +func (qo *Options) Apply(opts ...Option) error { for _, o := range opts { err := o(qo) if err != nil { @@ -51,7 +51,7 @@ func (qo *QueryOptions) Apply(opts ...Option) error { } func WithInterval(val Interval) Option { - return func(qo *QueryOptions) error { + return func(qo *Options) error { if qo.check(CanBucketBy) { return ErrQueryNotSupported } @@ -61,7 +61,7 @@ func WithInterval(val Interval) Option { } func WithFilter(val Filter) Option { - return func(qo *QueryOptions) error { + return func(qo *Options) error { if qo.check(CanFilterBy) { return ErrQueryNotSupported } @@ -71,7 +71,7 @@ func WithFilter(val Filter) Option { } func WithDirection(val Ordering) Option { - return func(qo *QueryOptions) error { + return func(qo *Options) error { if qo.check(CanSortBy) { return ErrQueryNotSupported } @@ -81,7 +81,7 @@ func WithDirection(val Ordering) Option { } func WithLimit(val uint64) Option { - return func(qo *QueryOptions) error { + return func(qo *Options) error { if qo.check(CanLimitResults) { return ErrQueryNotSupported } @@ -91,7 +91,7 @@ func WithLimit(val uint64) Option { } func WithCursor(val []byte) Option { - return func(qo *QueryOptions) error { + return func(qo *Options) error { if qo.check(CanQueryByCursor) { return ErrQueryNotSupported } @@ -101,7 +101,7 @@ func WithCursor(val []byte) Option { } func WithStartTime(val time.Time) Option { - return func(qo *QueryOptions) error { + return func(qo *Options) error { if qo.check(CanQueryByStartTime) { return ErrQueryNotSupported } @@ -111,7 +111,7 @@ func WithStartTime(val time.Time) Option { } func WithEndTime(val time.Time) Option { - return func(qo *QueryOptions) error { + return func(qo *Options) error { if qo.check(CanQueryByEndTime) { return ErrQueryNotSupported } diff --git a/pkg/database/query/utils.go b/pkg/database/query/utils.go index 23904d81..ba90efc9 100644 --- a/pkg/database/query/utils.go +++ b/pkg/database/query/utils.go @@ -9,26 +9,33 @@ const ( // PaginateQuery returns a paginated query string for the given input options. // // The input query string is expected as follows: -// "SELECT ... WHERE (...)" <- these brackets are not optional +// +// "SELECT ... WHERE (...)" <- these brackets are not optional // // The output query string would be as follows: -// "SELECT ... WHERE (...) AND id > ? ORDER BY ? LIMIT ?" -// -or- -// "SELECT ... WHERE (...) AND id < ? ORDER BY ? LIMIT ?" +// +// "SELECT ... WHERE (...) AND id > ? ORDER BY ? LIMIT ?" +// -or- +// "SELECT ... WHERE (...) AND id < ? ORDER BY ? LIMIT ?" // // Example: -// query := "SELECT * FROM table WHERE (state = $1 OR age > $2)" // -// opts := []interface{}{ state: 123, age: 45, } -// cursor := 123 -// limit := 10 -// direction := Ascending +// query := "SELECT * FROM table WHERE (state = $1 OR age > $2)" // -// PaginateQuery(query, opts, cursor, limit, direction) -// > "SELECT * FROM table WHERE (state = $1 OR age > $2) AND cursor > $3 ORDER BY id ASC LIMIT 10" -func PaginateQuery(query string, opts []interface{}, - cursor Cursor, limit uint64, direction Ordering) (string, []interface{}) { - +// opts := []interface{}{ state: 123, age: 45, } +// cursor := 123 +// limit := 10 +// direction := Ascending +// +// PaginateQuery(query, opts, cursor, limit, direction) +// > "SELECT * FROM table WHERE (state = $1 OR age > $2) AND cursor > $3 ORDER BY id ASC LIMIT 10" +func PaginateQuery( + query string, + opts []interface{}, + cursor Cursor, + limit uint64, + direction Ordering, +) (string, []interface{}) { if len(cursor) > 0 { v := strconv.Itoa(len(opts) + 1) @@ -58,13 +65,15 @@ func PaginateQuery(query string, opts []interface{}, return query, opts } -func DefaultPaginationHandler(opts ...Option) (*QueryOptions, error) { - req := QueryOptions{ +func DefaultPaginationHandler(opts ...Option) (*Options, error) { + req := Options{ Limit: defaultPagingLimit, SortBy: Ascending, Supported: CanLimitResults | CanSortBy | CanQueryByCursor, } - req.Apply(opts...) + if err := req.Apply(opts...); err != nil { + return nil, ErrQueryNotSupported + } if req.Limit > defaultPagingLimit { return nil, ErrQueryNotSupported @@ -73,13 +82,15 @@ func DefaultPaginationHandler(opts ...Option) (*QueryOptions, error) { return &req, nil } -func DefaultPaginationHandlerWithLimit(limit uint64, opts ...Option) (*QueryOptions, error) { - req := QueryOptions{ +func DefaultPaginationHandlerWithLimit(limit uint64, opts ...Option) (*Options, error) { + req := Options{ Limit: limit, SortBy: Ascending, Supported: CanLimitResults | CanSortBy | CanQueryByCursor, } - req.Apply(opts...) + if err := req.Apply(opts...); err != nil { + return nil, ErrQueryNotSupported + } if req.Limit > limit { return nil, ErrQueryNotSupported diff --git a/pkg/grpc/headers/headers_context.go b/pkg/grpc/headers/headers_context.go index 66e7d8a4..62516ee9 100644 --- a/pkg/grpc/headers/headers_context.go +++ b/pkg/grpc/headers/headers_context.go @@ -5,9 +5,9 @@ import ( "fmt" "strings" - "github.com/golang/protobuf/proto" "github.com/pkg/errors" "github.com/sirupsen/logrus" + "google.golang.org/protobuf/proto" ) // HeaderKey is the key to store all the header information in the context @@ -61,18 +61,6 @@ func setHeader(ctx context.Context, data proto.Message, name string, headerType return err } -// GetHeader takes the inbound header and unmarshal the data into the destination message. -// If no header exists, destination will have all default values. -func GetHeader(ctx context.Context, destination proto.Message) error { - return getHeader(ctx, destination, getBinaryHeaderName(destination), Inbound) -} - -// GetHeaderByName takes the inbound header and unmarshal the data into the destination messaging. -// If no header exists, destination will have all default values. -func GetHeaderByName(ctx context.Context, destination proto.Message, name string) error { - return getHeader(ctx, destination, name, Inbound) -} - // GetStringHeaderByName takes the inbound header and returns the data, if it is a string. // If no header exists or the header value is not a string, an empty string will be returned. func GetASCIIHeaderByName(ctx context.Context, name string) (string, error) { @@ -98,29 +86,11 @@ func GetStringHeaderByName(ctx context.Context, name string) (string, error) { return getStringHeader(ctx, name, Inbound) } -// GetPropagatingHeader takes the propagating header and unmarshal the data into the destination message. -// If no header exists, destination will have all default values -func GetPropagatingHeader(ctx context.Context, destination proto.Message) error { - return getHeader(ctx, destination, getBinaryHeaderName(destination), Propagating) -} - -// GetRootHeader takes the root header and unmarshal the data into the destination message. -// If no header exists, destination will have all default values. -func GetRootHeader(ctx context.Context, destination proto.Message) error { - return getHeader(ctx, destination, getBinaryHeaderName(destination), Root) -} - -// GetRootHeaderByName takes the root header and unmarshal the data into the destination message. -// If no header exists, destination will have all default values. -func GetRootHeaderByName(ctx context.Context, destination proto.Message, name string) error { - return getHeader(ctx, destination, name, Root) -} - // GetAnyBinaryHeader will try to find the header by checking all the binary headers, in the following order. // -// 1. GetRootHeader() -// 2. GetHeader() -// 3. GetPropagatingHeader() +// 1. Root +// 2. Inbound +// 3. Propagating // // Note, it is not recommended to to use this method to retrieve headers, as it may induce // some ambiguity. For example, XiUUid may have a different meaning in each Type. @@ -130,20 +100,20 @@ func GetAnyBinaryHeader(ctx context.Context, destination proto.Message) error { // GetAnyBinaryHeaderByName will try to find the header by checking all the binary headers, in the following order. // -// 1. GetRootHeader() -// 2. GetHeader() -// 3. GetPropagatingHeader() +// 1. Root +// 2. Inbound +// 3. Propagating // -// Note, it is not recommended to to use this method to retrieve headers, as it may induce +// Note, it is not recommended to use this method to retrieve headers, as it may induce // some ambiguity. For example, XiUUid may have a different meaning in each Type. func GetAnyBinaryHeaderByName(ctx context.Context, destination proto.Message, name string) error { headers := []Type{Root, Inbound, Propagating} for _, headerType := range headers { - err := getHeader(ctx, destination, name, headerType) - if err == nil && destination.String() != "" { - return nil - } else if !strings.Contains(err.Error(), "header not found") { + found, err := getHeader(ctx, destination, name, headerType) + if err != nil { return errors.Wrapf(err, "issue finding header %s in type %s", proto.MessageName(destination), headerType) + } else if found { + return nil } } @@ -152,23 +122,23 @@ func GetAnyBinaryHeaderByName(ctx context.Context, destination proto.Message, na // getHeader takes the given header and unmarshal the data into the destination message. // If no header exists, destination will have all default values. -func getHeader(ctx context.Context, destination proto.Message, name string, headerType Type) error { +func getHeader(ctx context.Context, destination proto.Message, name string, headerType Type) (bool, error) { selectedHeader, err := getHeadersFromContext(ctx, headerType) if err != nil { - return err + return false, err } headerName := getPrefixedHeaderName(name, headerType) data, exists := selectedHeader[headerName] if !exists { logrus.StandardLogger().Tracef("Header %s not found in type %s. Proto %s will have default values", headerName, headerType, proto.MessageName(destination)) - return fmt.Errorf("%s %s header not found", headerType, proto.MessageName(destination)) + return false, nil } if rawBytes, ok := data.([]byte); ok { err := proto.Unmarshal(rawBytes, destination) - return errors.Wrapf(err, "failed to unmarshal %s header %s", headerType, proto.MessageName(destination)) + return true, errors.Wrapf(err, "failed to unmarshal %s header %s", headerType, proto.MessageName(destination)) } - return fmt.Errorf("header %s in type %s does not have a binary value (%T)", proto.MessageName(destination), headerType, data) + return true, fmt.Errorf("header %s in type %s does not have a binary value (%T)", proto.MessageName(destination), headerType, data) } // getStringHeader takes the given header and returns the data, if it is a string. @@ -219,7 +189,7 @@ func getHeadersFromContext(ctx context.Context, headerType Type) (Headers, error // getBinaryHeaderName returns the name of the header using data's ProtoName() func and appending "-bin" to symbolize it is a binary header. // An absence of "-bin" symbolizes a ascii header. func getBinaryHeaderName(data proto.Message) string { - return proto.MessageName(data) + "-bin" + return string(proto.MessageName(data)) + "-bin" } // getPrefixedHeaderName returns the prefixed header name. diff --git a/pkg/grpc/metrics/new_relic_server_interceptor.go b/pkg/grpc/metrics/new_relic_server_interceptor.go index 78182b96..7fde97f1 100644 --- a/pkg/grpc/metrics/new_relic_server_interceptor.go +++ b/pkg/grpc/metrics/new_relic_server_interceptor.go @@ -148,7 +148,7 @@ func CustomNewRelicUnaryServerInterceptor(app *newrelic.Application) grpc_core.U return func(ctx context.Context, req interface{}, info *grpc_core.UnaryServerInfo, handler grpc_core.UnaryHandler) (interface{}, error) { // Inject the application to allow for any custom metrics, events, etc // in downstream code. - ctx = context.WithValue(ctx, metrics.NewRelicContextKey, app) + ctx = context.WithValue(ctx, metrics.NewRelicContextKey{}, app) m := startTransaction(ctx, app, info.FullMethod) defer m.End() @@ -181,7 +181,7 @@ func CustomNewRelicStreamServerInterceptor(app *newrelic.Application) grpc_core. return func(srv interface{}, ss grpc_core.ServerStream, info *grpc_core.StreamServerInfo, handler grpc_core.StreamHandler) error { // Inject the application to allow for any custom metrics, events, etc // in downstream code. - ctx := context.WithValue(ss.Context(), metrics.NewRelicContextKey, app) + ctx := context.WithValue(ss.Context(), metrics.NewRelicContextKey{}, app) m := startTransaction(ctx, app, info.FullMethod) defer m.End() diff --git a/pkg/grpc/util.go b/pkg/grpc/util.go index cd5113bf..8c9188a5 100644 --- a/pkg/grpc/util.go +++ b/pkg/grpc/util.go @@ -13,7 +13,7 @@ import ( var ( healthCheckEndpoint = "/grpc.health.v1.Health/Check" - fullMethodNameRegex = regexp.MustCompile("/([a-zA-Z0-9]+\\.)+[a-zA-Z0-9]+/[a-zA-Z0-9]+") + fullMethodNameRegex = regexp.MustCompile(`/([a-zA-Z0-9]+\.)+[a-zA-Z0-9]+/[a-zA-Z0-9]+`) ) // ParseFullMethodName parses a gRPC full method name into its components diff --git a/pkg/kikcode/payload.go b/pkg/kikcode/payload.go index 14613bd6..ac0fc41a 100644 --- a/pkg/kikcode/payload.go +++ b/pkg/kikcode/payload.go @@ -136,9 +136,9 @@ func (p *Payload) ToQrCodeDescription(dimension float64) (*Description, error) { return nil, err } - kikCodePayload := CreateKikCodePayload(viewPayload) + scanPayload := CreateScanPayload(viewPayload) - return GenerateDescription(dimension, kikCodePayload) + return GenerateDescription(dimension, scanPayload) } func (p *Payload) GetIdempotencyKey() IdempotencyKey { @@ -149,10 +149,10 @@ func (p *Payload) ToRendezvousKey() ed25519.PrivateKey { return DeriveRendezvousPrivateKey(p) } -func GenerateRandomIdempotencyKey() IdempotencyKey { +func GenerateRandomIdempotencyKey() (IdempotencyKey, error) { var buffer [nonceSize]byte - rand.Read(buffer[:]) - return buffer + _, err := rand.Read(buffer[:]) + return buffer, err } type amountBuffer interface { diff --git a/pkg/kikcode/qr.go b/pkg/kikcode/qr.go index 3e45fc39..a25b1e05 100644 --- a/pkg/kikcode/qr.go +++ b/pkg/kikcode/qr.go @@ -39,9 +39,9 @@ type Description struct { dotDimension float64 } -type KikCodePayload []byte +type ScanPayload []byte -func GenerateDescription(dimension float64, data KikCodePayload) (*Description, error) { +func GenerateDescription(dimension float64, data ScanPayload) (*Description, error) { if dimension <= 0 { return nil, ErrInvalidSize } @@ -129,7 +129,7 @@ func GenerateDescription(dimension float64, data KikCodePayload) (*Description, } } - offset += 1 + offset++ } } @@ -142,7 +142,7 @@ func GenerateDescription(dimension float64, data KikCodePayload) (*Description, }, nil } -func CreateKikCodePayload(data []byte) KikCodePayload { +func CreateScanPayload(data []byte) ScanPayload { finderBytes := []byte{0xb2, 0xcb, 0x25, 0xc6} return append(finderBytes, data...) } diff --git a/pkg/merkletree/tree.go b/pkg/merkletree/tree.go index 03eccec7..79345079 100644 --- a/pkg/merkletree/tree.go +++ b/pkg/merkletree/tree.go @@ -81,14 +81,14 @@ func (t *MerkleTree) AddLeaf(leaf Leaf) error { } t.root = currentLevelHash - t.nextIndex += 1 + t.nextIndex++ return nil } func (t *MerkleTree) GetRoot() Hash { - var copy Hash - return append(copy, t.root...) + var cpy Hash + return append(cpy, t.root...) } func (t *MerkleTree) GetLeafHash(leaf Leaf) Hash { @@ -115,11 +115,11 @@ func (t *MerkleTree) GetExpectedHashFromPair(h1, h2 Hash) Hash { } func (t *MerkleTree) GetZeroValues() []Hash { - copy := make([]Hash, len(t.zeroValues)) + cpy := make([]Hash, len(t.zeroValues)) for i, zeroValue := range t.zeroValues { - copy[i] = append(copy[i], zeroValue...) + cpy[i] = append(cpy[i], zeroValue...) } - return copy + return cpy } // todo: We'll need a more efficient version of this method in production diff --git a/pkg/metrics/constants.go b/pkg/metrics/constants.go index 3cb25229..fcb471f9 100644 --- a/pkg/metrics/constants.go +++ b/pkg/metrics/constants.go @@ -1,5 +1,3 @@ package metrics -const ( - NewRelicContextKey = "newrelic_context" -) +type NewRelicContextKey struct{} diff --git a/pkg/metrics/events.go b/pkg/metrics/events.go index a7fce392..df9bca63 100644 --- a/pkg/metrics/events.go +++ b/pkg/metrics/events.go @@ -8,7 +8,7 @@ import ( // RecordEvent records a new event with a name and set of key-value pairs func RecordEvent(ctx context.Context, eventName string, kvPairs map[string]interface{}) { - nr, ok := ctx.Value(NewRelicContextKey).(*newrelic.Application) + nr, ok := ctx.Value(NewRelicContextKey{}).(*newrelic.Application) if ok { nr.RecordCustomEvent(eventName, kvPairs) } diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go index 3ad8a2af..f9066cc3 100644 --- a/pkg/metrics/metrics.go +++ b/pkg/metrics/metrics.go @@ -9,7 +9,7 @@ import ( // RecordCount records a count metric func RecordCount(ctx context.Context, metricName string, count uint64) { - nr, ok := ctx.Value(NewRelicContextKey).(*newrelic.Application) + nr, ok := ctx.Value(NewRelicContextKey{}).(*newrelic.Application) if ok { nr.RecordCustomMetric(metricName, float64(count)) } @@ -17,7 +17,7 @@ func RecordCount(ctx context.Context, metricName string, count uint64) { // RecordDuration records a duration metric func RecordDuration(ctx context.Context, metricName string, duration time.Duration) { - nr, ok := ctx.Value(NewRelicContextKey).(*newrelic.Application) + nr, ok := ctx.Value(NewRelicContextKey{}).(*newrelic.Application) if ok { nr.RecordCustomMetric(metricName, float64(duration/time.Millisecond)) } diff --git a/pkg/netutil/url.go b/pkg/netutil/url.go index edd586ec..42b9f855 100644 --- a/pkg/netutil/url.go +++ b/pkg/netutil/url.go @@ -51,7 +51,9 @@ func ValidateHttpUrl( var resp *http.Response _, err = retry.Retry( func() error { - resp, err = http.Get(value) + // Retry only occurs if err != nil, in which case the body does not need to be closed. + // The body itself is closed below + resp, err = http.Get(value) //nolint:bodyclose return err }, retry.Limit(5), @@ -60,6 +62,7 @@ func ValidateHttpUrl( if err != nil { return err } + defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return errors.Errorf("%d status code fetching content", resp.StatusCode) diff --git a/pkg/osutil/memory.go b/pkg/osutil/memory.go index 84f7b125..a5c8fec3 100644 --- a/pkg/osutil/memory.go +++ b/pkg/osutil/memory.go @@ -1,7 +1,7 @@ package osutil import ( - "io/ioutil" + "os" "strconv" "strings" @@ -24,7 +24,7 @@ const ( func GetTotalMemory() uint64 { totalMemory := memory.TotalMemory() - cgroupLimit, err := ioutil.ReadFile(dockerMemoryLimitLocation) + cgroupLimit, err := os.ReadFile(dockerMemoryLimitLocation) if err == nil { dockerMemoryLimit, err := strconv.ParseUint(strings.Replace(string(cgroupLimit), "\n", "", 1), 10, 64) if err == nil && dockerMemoryLimit != unrestrictedMemoryLimit { diff --git a/pkg/phone/validation.go b/pkg/phone/validation.go index d34b50ae..84f7588e 100644 --- a/pkg/phone/validation.go +++ b/pkg/phone/validation.go @@ -4,7 +4,7 @@ import "regexp" var ( // E.164 phone number format regex provided by Twilio: https://www.twilio.com/docs/glossary/what-e164#regex-matching-for-e164 - phonePattern = regexp.MustCompile("^\\+[1-9]\\d{1,14}$") + phonePattern = regexp.MustCompile(`^\+[1-9]\d{1,14}$`) // A verification code must be a 4-10 digit string verificationCodePattern = regexp.MustCompile("^[0-9]{4,10}$") diff --git a/pkg/solana/errors.go b/pkg/solana/errors.go index 24e07965..a3f5b177 100644 --- a/pkg/solana/errors.go +++ b/pkg/solana/errors.go @@ -123,7 +123,8 @@ func parseInstructionError(v interface{}) (e InstructionError, err error) { var k string var v interface{} - for k, v = range t { + for k, v = range t { //nolint:revive + // Retrieve the only KV from the map } if k != "Custom" { @@ -225,7 +226,8 @@ func ParseTransactionError(raw interface{}) (*TransactionError, error) { var k string var v interface{} - for k, v = range t { + for k, v = range t { //nolint:revive + // Retrieve the only KV from the map } if k != "InstructionError" { diff --git a/pkg/solana/shortvec/shortvec.go b/pkg/solana/shortvec/shortvec.go index fd3b45eb..1cc95bad 100644 --- a/pkg/solana/shortvec/shortvec.go +++ b/pkg/solana/shortvec/shortvec.go @@ -9,18 +9,18 @@ import ( // EncodeLen encodes the specified len into the writer. // // If len > math.MaxUint16, an error is returned. -func EncodeLen(w io.Writer, len int) (n int, err error) { - if len > math.MaxUint16 { - return 0, fmt.Errorf("len exceeds %d", math.MaxUint16) +func EncodeLen(w io.Writer, length int) (n int, err error) { + if length > math.MaxUint16 { + return 0, fmt.Errorf("length exceeds %d", math.MaxUint16) } written := 0 valBuf := make([]byte, 1) for { - valBuf[0] = byte(len & 0x7f) - len >>= 7 - if len == 0 { + valBuf[0] = byte(length & 0x7f) + length >>= 7 + if length == 0 { n, err := w.Write(valBuf) written += n diff --git a/pkg/solana/splitter/types_merkletree.go b/pkg/solana/splitter/types_merkletree.go index de1441f7..0aad02a3 100644 --- a/pkg/solana/splitter/types_merkletree.go +++ b/pkg/solana/splitter/types_merkletree.go @@ -90,7 +90,7 @@ func (obj *MerkleTree) ToString() string { func (obj *MerkleTree) Marshal() []byte { data := make([]byte, getMerkleTreeSize(obj.Levels)) - var offset int = 0 + var offset int putUint8(data, obj.Levels, &offset) putUint64(data, obj.NextIndex, &offset) @@ -112,7 +112,7 @@ func (obj *MerkleTree) Marshal() []byte { // Deserializes the {@link MerkleTree} from the provided data Buffer. // @returns an error if the deserialize operation was unsuccessful. func (obj *MerkleTree) Unmarshal(data []byte) error { - var offset int = 0 + var offset int getUint8(data, &obj.Levels, &offset) diff --git a/pkg/solana/splitter/utils.go b/pkg/solana/splitter/utils.go index 6b7f66fd..35cdaef1 100644 --- a/pkg/solana/splitter/utils.go +++ b/pkg/solana/splitter/utils.go @@ -8,7 +8,7 @@ import ( "github.com/mr-tron/base58" ) -const optionalSize = 1 +const optionalSize = 1 //nolint:unused func putDiscriminator(dst []byte, src []byte, offset *int) { copy(dst[*offset:], src) @@ -36,7 +36,7 @@ func putBool(dst []byte, v bool, offset *int) { } else { dst[*offset] = 0 } - *offset += 1 + *offset++ } func getBool(src []byte, dst *bool, offset *int) { if src[*offset] == 1 { @@ -44,16 +44,16 @@ func getBool(src []byte, dst *bool, offset *int) { } else { *dst = false } - *offset += 1 + *offset++ } func putUint8(dst []byte, v uint8, offset *int) { dst[*offset] = v - *offset += 1 + *offset++ } func getUint8(src []byte, dst *uint8, offset *int) { *dst = src[*offset] - *offset += 1 + *offset++ } func putUint32(dst []byte, v uint32, offset *int) { @@ -74,7 +74,7 @@ func getUint64(src []byte, dst *uint64, offset *int) { *offset += 8 } -func putOptionalKey(dst []byte, src []byte, offset *int) { +func putOptionalKey(dst []byte, src []byte, offset *int) { //nolint:unused if len(src) > 0 { dst[*offset] = 1 copy(dst[*offset+optionalSize:], src) @@ -84,7 +84,7 @@ func putOptionalKey(dst []byte, src []byte, offset *int) { *offset += optionalSize } } -func getOptionalKey(src []byte, dst *ed25519.PublicKey, offset *int) { +func getOptionalKey(src []byte, dst *ed25519.PublicKey, offset *int) { //nolint:unused if src[*offset] == 1 { *dst = make([]byte, ed25519.PublicKeySize) copy(*dst, src[*offset+optionalSize:]) @@ -94,7 +94,7 @@ func getOptionalKey(src []byte, dst *ed25519.PublicKey, offset *int) { } } -func putOptionalUint32(dst []byte, v *uint32, offset *int) { +func putOptionalUint32(dst []byte, v *uint32, offset *int) { //nolint:unused if v != nil { dst[*offset] = 1 binary.LittleEndian.PutUint32(dst[*offset+optionalSize:], *v) @@ -104,7 +104,7 @@ func putOptionalUint32(dst []byte, v *uint32, offset *int) { *offset += optionalSize } } -func getOptionalUint32(src []byte, dst **uint32, offset *int) { +func getOptionalUint32(src []byte, dst **uint32, offset *int) { //nolint:unused if src[*offset] == 1 { val := binary.LittleEndian.Uint32(src[*offset+optionalSize:]) *dst = &val @@ -114,7 +114,7 @@ func getOptionalUint32(src []byte, dst **uint32, offset *int) { } } -func putOptionalUint64(dst []byte, v *uint64, offset *int) { +func putOptionalUint64(dst []byte, v *uint64, offset *int) { //nolint:unused if v != nil { dst[*offset] = 1 binary.LittleEndian.PutUint64(dst[*offset+optionalSize:], *v) @@ -124,7 +124,7 @@ func putOptionalUint64(dst []byte, v *uint64, offset *int) { *offset += optionalSize } } -func getOptionalUint64(src []byte, dst **uint64, offset *int) { +func getOptionalUint64(src []byte, dst **uint64, offset *int) { //nolint:unused if src[*offset] == 1 { val := binary.LittleEndian.Uint64(src[*offset+optionalSize:]) *dst = &val diff --git a/pkg/solana/swapvalidator/utils.go b/pkg/solana/swapvalidator/utils.go index 1477ed28..cbc0a829 100644 --- a/pkg/solana/swapvalidator/utils.go +++ b/pkg/solana/swapvalidator/utils.go @@ -13,17 +13,17 @@ func putDiscriminator(dst []byte, src []byte, offset *int) { copy(dst[*offset:], src) *offset += 8 } -func getDiscriminator(src []byte, dst *[]byte, offset *int) { +func getDiscriminator(src []byte, dst *[]byte, offset *int) { //nolint:unused *dst = make([]byte, 8) copy(*dst, src[*offset:]) *offset += 8 } -func putKey(dst []byte, src []byte, offset *int) { +func putKey(dst []byte, src []byte, offset *int) { //nolint:unused copy(dst[*offset:], src) *offset += ed25519.PublicKeySize } -func getKey(src []byte, dst *ed25519.PublicKey, offset *int) { +func getKey(src []byte, dst *ed25519.PublicKey, offset *int) { //nolint:unused *dst = make([]byte, ed25519.PublicKeySize) copy(*dst, src[*offset:]) *offset += ed25519.PublicKeySize @@ -31,18 +31,18 @@ func getKey(src []byte, dst *ed25519.PublicKey, offset *int) { func putUint8(dst []byte, v uint8, offset *int) { dst[*offset] = v - *offset += 1 + *offset++ } -func getUint8(src []byte, dst *uint8, offset *int) { +func getUint8(src []byte, dst *uint8, offset *int) { //nolint:unused *dst = src[*offset] - *offset += 1 + *offset++ } func putUint64(dst []byte, v uint64, offset *int) { binary.LittleEndian.PutUint64(dst[*offset:], v) *offset += 8 } -func getUint64(src []byte, dst *uint64, offset *int) { +func getUint64(src []byte, dst *uint64, offset *int) { //nolint:unused *dst = binary.LittleEndian.Uint64(src[*offset:]) *offset += 8 } diff --git a/pkg/solana/timelock/legacy_2022/utils.go b/pkg/solana/timelock/legacy_2022/utils.go index e1ded361..3c1f4876 100644 --- a/pkg/solana/timelock/legacy_2022/utils.go +++ b/pkg/solana/timelock/legacy_2022/utils.go @@ -31,18 +31,18 @@ func getKey(src []byte, dst *ed25519.PublicKey, offset *int) { func putUint8(dst []byte, v uint8, offset *int) { dst[*offset] = v - *offset += 1 + *offset++ } func getUint8(src []byte, dst *uint8, offset *int) { *dst = src[*offset] - *offset += 1 + *offset++ } -func putUint32(dst []byte, v uint32, offset *int) { +func putUint32(dst []byte, v uint32, offset *int) { //nolint:unused binary.LittleEndian.PutUint32(dst[*offset:], v) *offset += 4 } -func getUint32(src []byte, dst *uint32, offset *int) { +func getUint32(src []byte, dst *uint32, offset *int) { //nolint:unused *dst = binary.LittleEndian.Uint32(src[*offset:]) *offset += 4 } @@ -56,7 +56,7 @@ func getUint64(src []byte, dst *uint64, offset *int) { *offset += 8 } -func putOptionalKey(dst []byte, src []byte, offset *int) { +func putOptionalKey(dst []byte, src []byte, offset *int) { //nolint:unused if len(src) > 0 { dst[*offset] = 1 copy(dst[*offset+optionalSize:], src) @@ -66,7 +66,7 @@ func putOptionalKey(dst []byte, src []byte, offset *int) { *offset += optionalSize } } -func getOptionalKey(src []byte, dst *ed25519.PublicKey, offset *int) { +func getOptionalKey(src []byte, dst *ed25519.PublicKey, offset *int) { //nolint:unused if src[*offset] == 1 { *dst = make([]byte, ed25519.PublicKeySize) copy(*dst, src[*offset+optionalSize:]) @@ -76,7 +76,7 @@ func getOptionalKey(src []byte, dst *ed25519.PublicKey, offset *int) { } } -func putOptionalUint32(dst []byte, v *uint32, offset *int) { +func putOptionalUint32(dst []byte, v *uint32, offset *int) { //nolint:unused if v != nil { dst[*offset] = 1 binary.LittleEndian.PutUint32(dst[*offset+optionalSize:], *v) @@ -86,7 +86,7 @@ func putOptionalUint32(dst []byte, v *uint32, offset *int) { *offset += optionalSize } } -func getOptionalUint32(src []byte, dst **uint32, offset *int) { +func getOptionalUint32(src []byte, dst **uint32, offset *int) { //nolint:unused if src[*offset] == 1 { val := binary.LittleEndian.Uint32(src[*offset+optionalSize:]) *dst = &val diff --git a/pkg/solana/timelock/v1/utils.go b/pkg/solana/timelock/v1/utils.go index e1ded361..3c1f4876 100644 --- a/pkg/solana/timelock/v1/utils.go +++ b/pkg/solana/timelock/v1/utils.go @@ -31,18 +31,18 @@ func getKey(src []byte, dst *ed25519.PublicKey, offset *int) { func putUint8(dst []byte, v uint8, offset *int) { dst[*offset] = v - *offset += 1 + *offset++ } func getUint8(src []byte, dst *uint8, offset *int) { *dst = src[*offset] - *offset += 1 + *offset++ } -func putUint32(dst []byte, v uint32, offset *int) { +func putUint32(dst []byte, v uint32, offset *int) { //nolint:unused binary.LittleEndian.PutUint32(dst[*offset:], v) *offset += 4 } -func getUint32(src []byte, dst *uint32, offset *int) { +func getUint32(src []byte, dst *uint32, offset *int) { //nolint:unused *dst = binary.LittleEndian.Uint32(src[*offset:]) *offset += 4 } @@ -56,7 +56,7 @@ func getUint64(src []byte, dst *uint64, offset *int) { *offset += 8 } -func putOptionalKey(dst []byte, src []byte, offset *int) { +func putOptionalKey(dst []byte, src []byte, offset *int) { //nolint:unused if len(src) > 0 { dst[*offset] = 1 copy(dst[*offset+optionalSize:], src) @@ -66,7 +66,7 @@ func putOptionalKey(dst []byte, src []byte, offset *int) { *offset += optionalSize } } -func getOptionalKey(src []byte, dst *ed25519.PublicKey, offset *int) { +func getOptionalKey(src []byte, dst *ed25519.PublicKey, offset *int) { //nolint:unused if src[*offset] == 1 { *dst = make([]byte, ed25519.PublicKeySize) copy(*dst, src[*offset+optionalSize:]) @@ -76,7 +76,7 @@ func getOptionalKey(src []byte, dst *ed25519.PublicKey, offset *int) { } } -func putOptionalUint32(dst []byte, v *uint32, offset *int) { +func putOptionalUint32(dst []byte, v *uint32, offset *int) { //nolint:unused if v != nil { dst[*offset] = 1 binary.LittleEndian.PutUint32(dst[*offset+optionalSize:], *v) @@ -86,7 +86,7 @@ func putOptionalUint32(dst []byte, v *uint32, offset *int) { *offset += optionalSize } } -func getOptionalUint32(src []byte, dst **uint32, offset *int) { +func getOptionalUint32(src []byte, dst **uint32, offset *int) { //nolint:unused if src[*offset] == 1 { val := binary.LittleEndian.Uint32(src[*offset+optionalSize:]) *dst = &val diff --git a/pkg/solana/transaction.go b/pkg/solana/transaction.go index e5768801..79e1e1d9 100644 --- a/pkg/solana/transaction.go +++ b/pkg/solana/transaction.go @@ -20,6 +20,10 @@ const ( type Signature [ed25519.SignatureSize]byte type Blockhash [sha256.Size]byte +func (b Blockhash) ToBase58() string { + return base58.Encode(b[:]) +} + type Header struct { NumSignatures byte NumReadonlySigned byte diff --git a/pkg/sync/striped_channel_test.go b/pkg/sync/striped_channel_test.go index 6b8d3afb..369f1834 100644 --- a/pkg/sync/striped_channel_test.go +++ b/pkg/sync/striped_channel_test.go @@ -22,15 +22,8 @@ func TestStripedChannel_HappyPath(t *testing.T) { wg.Done() }() - for { - select { - case val, ok := <-c: - if !ok { - return - } - - results[id][val.(int)]++ - } + for val := range c { + results[id][val.(int)]++ } } diff --git a/pkg/sync/striped_lock_test.go b/pkg/sync/striped_lock_test.go index 1c92d5eb..c849e5f7 100644 --- a/pkg/sync/striped_lock_test.go +++ b/pkg/sync/striped_lock_test.go @@ -16,7 +16,7 @@ func TestStripedLock_HappyPath(t *testing.T) { l := NewStripedLock(4) var workerWg base.WaitGroup - startChan := make(chan struct{}, 0) + startChan := make(chan struct{}) data := make([]int, workerCount) for i := 0; i < workerCount; i++ { @@ -33,9 +33,7 @@ func TestStripedLock_HappyPath(t *testing.T) { go func() { defer opWg.Done() - select { - case <-startChan: - } + <-startChan mu := l.Get([]byte(key)) mu.Lock() diff --git a/pkg/testutil/logging.go b/pkg/testutil/logging.go index 4567e096..c9a1d893 100644 --- a/pkg/testutil/logging.go +++ b/pkg/testutil/logging.go @@ -1,7 +1,7 @@ package testutil import ( - "io/ioutil" + "io" "os" "github.com/sirupsen/logrus" @@ -18,13 +18,13 @@ func init() { logrus.SetLevel(logrus.TraceLevel) if !isVerbose { - logrus.StandardLogger().Out = ioutil.Discard + logrus.StandardLogger().Out = io.Discard } } func DisableLogging() (reset func()) { originalLogOutput := logrus.StandardLogger().Out - logrus.StandardLogger().Out = ioutil.Discard + logrus.StandardLogger().Out = io.Discard return func() { logrus.StandardLogger().Out = originalLogOutput } diff --git a/pkg/testutil/server.go b/pkg/testutil/server.go index 8e577b04..13b230f6 100644 --- a/pkg/testutil/server.go +++ b/pkg/testutil/server.go @@ -11,6 +11,7 @@ import ( "github.com/pkg/errors" "github.com/sirupsen/logrus" "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/health" "google.golang.org/grpc/health/grpc_health_v1" @@ -69,7 +70,7 @@ func NewServer(opts ...ServerOption) (*grpc.ClientConn, *Server, error) { // note: this is safe since we don't specify grpc.WithBlock() conn, err := grpc.Dial( fmt.Sprintf("localhost:%d", port), - grpc.WithInsecure(), + grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithUnaryInterceptor(grpc_middleware.ChainUnaryClient(o.unaryClientInterceptors...)), grpc.WithStreamInterceptor(grpc_middleware.ChainStreamClient(o.streamClientInterceptors...)), )