diff --git a/go.mod b/go.mod index 8701d46c..5eff3863 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,8 @@ require ( firebase.google.com/go/v4 v4.8.0 github.com/aws/aws-sdk-go-v2 v0.17.0 github.com/bits-and-blooms/bloom/v3 v3.1.0 - github.com/code-payments/code-protobuf-api v1.16.6 + github.com/code-payments/code-protobuf-api v1.19.0 + github.com/dghubble/oauth1 v0.7.3 github.com/emirpasic/gods v1.12.0 github.com/envoyproxy/protoc-gen-validate v1.0.4 github.com/golang-jwt/jwt/v5 v5.0.0 @@ -41,6 +42,7 @@ require ( golang.org/x/crypto v0.21.0 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 golang.org/x/net v0.22.0 + golang.org/x/sync v0.7.0 golang.org/x/text v0.14.0 golang.org/x/time v0.5.0 google.golang.org/api v0.170.0 @@ -119,7 +121,6 @@ require ( go.uber.org/atomic v1.7.0 // indirect go.uber.org/multierr v1.6.0 // indirect golang.org/x/oauth2 v0.18.0 // indirect - golang.org/x/sync v0.7.0 // indirect golang.org/x/sys v0.18.0 // indirect google.golang.org/appengine v1.6.8 // indirect google.golang.org/appengine/v2 v2.0.1 // indirect @@ -130,3 +131,5 @@ require ( gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) + +replace github.com/code-payments/code-protobuf-api => github.com/mfycheng/code-protobuf-api v0.0.0-20241010162320-5dac31db232d diff --git a/go.sum b/go.sum index ca244823..0db72051 100644 --- a/go.sum +++ b/go.sum @@ -121,8 +121,6 @@ github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= -github.com/code-payments/code-protobuf-api v1.16.6 h1:QCot0U+4Ar5SdSX4v955FORMsd3Qcf0ZgkoqlGJZzu0= -github.com/code-payments/code-protobuf-api v1.16.6/go.mod h1:pHQm75vydD6Cm2qHAzlimW6drysm489Z4tVxC2zHSsU= github.com/containerd/continuity v0.0.0-20190827140505-75bee3e2ccb6 h1:NmTXa/uVnDyp0TY5MKi197+3HWcnYWfnHGyaFthlnGw= github.com/containerd/continuity v0.0.0-20190827140505-75bee3e2ccb6/go.mod h1:GL3xCUCBDV3CZiTSEKksMWbLE66hEyuu9qyDOOqM47Y= github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk= @@ -141,6 +139,8 @@ github.com/creack/pty v1.1.11/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dghubble/oauth1 v0.7.3 h1:EkEM/zMDMp3zOsX2DC/ZQ2vnEX3ELK0/l9kb+vs4ptE= +github.com/dghubble/oauth1 v0.7.3/go.mod h1:oxTe+az9NSMIucDPDCCtzJGsPhciJV33xocHfcR2sVY= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no= github.com/docker/cli v20.10.7+incompatible h1:pv/3NqibQKphWZiAskMzdz8w0PRbtTaEB+f6NwdU7Is= @@ -423,6 +423,8 @@ github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Ky github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg= github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= +github.com/mfycheng/code-protobuf-api v0.0.0-20241010162320-5dac31db232d h1:pOwndvvkUvWXzoiJIIo5wiPT/IP67J5AJqF4sLPdKcY= +github.com/mfycheng/code-protobuf-api v0.0.0-20241010162320-5dac31db232d/go.mod h1:pHQm75vydD6Cm2qHAzlimW6drysm489Z4tVxC2zHSsU= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= diff --git a/pkg/code/async/geyser/external_deposit.go b/pkg/code/async/geyser/external_deposit.go index 1b5939ac..48685818 100644 --- a/pkg/code/async/geyser/external_deposit.go +++ b/pkg/code/async/geyser/external_deposit.go @@ -20,7 +20,7 @@ import ( code_data "github.com/code-payments/code-server/pkg/code/data" "github.com/code-payments/code-server/pkg/code/data/account" "github.com/code-payments/code-server/pkg/code/data/balance" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/code/data/deposit" "github.com/code-payments/code-server/pkg/code/data/fulfillment" "github.com/code-payments/code-server/pkg/code/data/intent" @@ -299,7 +299,7 @@ func processPotentialExternalDeposit(ctx context.Context, conf *conf, data code_ chatMessage, ) } - case chat.ErrMessageAlreadyExists: + case chat_v1.ErrMessageAlreadyExists: default: return errors.Wrap(err, "error sending chat message") } @@ -772,7 +772,7 @@ func delayedUsdcDepositProcessing( chatMessage, ) } - case chat.ErrMessageAlreadyExists: + case chat_v1.ErrMessageAlreadyExists: default: return } diff --git a/pkg/code/async/geyser/messenger.go b/pkg/code/async/geyser/messenger.go index 5c94153c..c6b7f005 100644 --- a/pkg/code/async/geyser/messenger.go +++ b/pkg/code/async/geyser/messenger.go @@ -15,7 +15,7 @@ import ( "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" "github.com/code-payments/code-server/pkg/code/data/account" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/code/push" "github.com/code-payments/code-server/pkg/code/thirdparty" "github.com/code-payments/code-server/pkg/database/query" @@ -165,17 +165,17 @@ func processPotentialBlockchainMessage(ctx context.Context, data code_data.Provi return errors.Wrap(err, "error creating proto message") } - canPush, err := chat_util.SendChatMessage( + canPush, err := chat_util.SendNotificationChatMessageV1( ctx, data, asciiBaseDomain, - chat.ChatTypeExternalApp, + chat_v1.ChatTypeExternalApp, true, recipientOwner, chatMessage, false, ) - if err != nil && err != chat.ErrMessageAlreadyExists { + if err != nil && err != chat_v1.ErrMessageAlreadyExists { return errors.Wrap(err, "error persisting chat message") } diff --git a/pkg/code/async/user/twitter.go b/pkg/code/async/user/twitter.go index c8612e9f..eb1b6f5d 100644 --- a/pkg/code/async/user/twitter.go +++ b/pkg/code/async/user/twitter.go @@ -4,6 +4,7 @@ import ( "context" "crypto/ed25519" "database/sql" + "fmt" "strings" "time" @@ -11,6 +12,7 @@ import ( "github.com/mr-tron/base58" "github.com/newrelic/go-agent/v3/newrelic" "github.com/pkg/errors" + "github.com/sirupsen/logrus" commonpb "github.com/code-payments/code-protobuf-api/generated/go/common/v1" userpb "github.com/code-payments/code-protobuf-api/generated/go/user/v1" @@ -103,6 +105,8 @@ func (p *service) twitterUserInfoUpdateWorker(serviceCtx context.Context, interv } func (p *service) processNewTwitterRegistrations(ctx context.Context) error { + log := p.log.WithField("method", "processNewTwitterRegistrations") + tweets, err := p.findNewRegistrationTweets(ctx) if err != nil { return errors.Wrap(err, "error finding new registration tweets") @@ -113,6 +117,11 @@ func (p *service) processNewTwitterRegistrations(ctx context.Context) error { return errors.Errorf("author missing in tweet %s", tweet.ID) } + log := log.WithFields(logrus.Fields{ + "tweet": tweet.ID, + "username": tweet.AdditionalMetadata.Author, + }) + // Attempt to find a verified tip account from the registration tweet tipAccount, registrationNonce, err := p.findVerifiedTipAccountRegisteredInTweet(ctx, tweet) switch err { @@ -140,7 +149,21 @@ func (p *service) processNewTwitterRegistrations(ctx context.Context) error { switch err { case nil: - go push_util.SendTwitterAccountConnectedPushNotification(ctx, p.data, p.pusher, tipAccount) + // todo: all of these success handlers are fire and forget best-effort delivery + + go func() { + err := push_util.SendTwitterAccountConnectedPushNotification(ctx, p.data, p.pusher, tipAccount) + if err != nil { + log.WithError(err).Warn("failure sending success push") + } + }() + + go func() { + err := p.sendRegistrationSuccessReply(ctx, tweet.ID, tweet.AdditionalMetadata.Author.Username) + if err != nil { + log.WithError(err).Warn("failure sending success reply") + } + }() case twitter.ErrDuplicateTipAddress: err = p.data.ExecuteInTx(ctx, sql.LevelDefault, func(ctx context.Context) error { err = p.data.MarkTwitterNonceAsUsed(ctx, tweet.ID, *registrationNonce) @@ -171,6 +194,14 @@ func (p *service) processNewTwitterRegistrations(ctx context.Context) error { func (p *service) refreshTwitterUserInfo(ctx context.Context, username string) error { user, err := p.twitterClient.GetUserByUsername(ctx, username) if err != nil { + if strings.Contains(strings.ToLower(err.Error()), "could not find user with username") || strings.Contains(strings.ToLower(err.Error()), "user has been suspended") { + err = p.onTwitterUsernameNotFound(ctx, username) + if err != nil { + return errors.Wrap(err, "error updating cached user state") + } + return nil + } + return errors.Wrap(err, "error getting user info from twitter") } @@ -334,6 +365,36 @@ func (p *service) findVerifiedTipAccountRegisteredInTweet(ctx context.Context, t return nil, nil, errTwitterRegistrationNotFound } +func (p *service) sendRegistrationSuccessReply(ctx context.Context, regristrationTweetId, username string) error { + // todo: localize this + message := fmt.Sprintf( + "@%s your X account is now connected. Share this link to receive tips: https://tipcard.getcode.com/x/%s", + username, + username, + ) + _, err := p.twitterClient.SendReply(ctx, regristrationTweetId, message) + return err +} + +func (p *service) onTwitterUsernameNotFound(ctx context.Context, username string) error { + record, err := p.data.GetTwitterUserByUsername(ctx, username) + switch err { + case nil: + case twitter.ErrUserNotFound: + return nil + default: + return errors.Wrap(err, "error getting cached twitter user") + } + + record.LastUpdatedAt = time.Now() + + err = p.data.SaveTwitterUser(ctx, record) + if err != nil { + return errors.Wrap(err, "error updating cached twitter user") + } + return nil +} + func toProtoVerifiedType(value string) userpb.TwitterUser_VerifiedType { switch value { case "blue": diff --git a/pkg/code/chat/chat.go b/pkg/code/chat/chat.go index e23f6555..2a5b20d0 100644 --- a/pkg/code/chat/chat.go +++ b/pkg/code/chat/chat.go @@ -8,6 +8,7 @@ const ( KinPurchasesName = "Kin Purchases" PaymentsName = "Payments" // Renamed to Web Payments on client TipsName = "Tips" + TwoWayChatName = "Two Way Chat" // Test chats used for unit/integration testing only TestCantMuteName = "TestCantMute" @@ -45,6 +46,11 @@ var ( CanMute: true, CanUnsubscribe: false, }, + TwoWayChatName: { + TitleLocalizationKey: localization.ChatTitleTwoWay, + CanMute: true, + CanUnsubscribe: false, + }, TestCantMuteName: { TitleLocalizationKey: "n/a", diff --git a/pkg/code/chat/message_cash_transactions.go b/pkg/code/chat/message_cash_transactions.go index 1976d9d5..620f32fb 100644 --- a/pkg/code/chat/message_cash_transactions.go +++ b/pkg/code/chat/message_cash_transactions.go @@ -11,7 +11,7 @@ import ( "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/code/data/intent" ) @@ -93,9 +93,9 @@ func SendCashTransactionsExchangeMessage(ctx context.Context, data code_data.Pro return errors.Wrap(err, "error getting original gift card issued intent") } - chatId := chat.GetChatId(CashTransactionsName, giftCardIssuedIntentRecord.InitiatorOwnerAccount, true) + chatId := chat_v1.GetChatId(CashTransactionsName, giftCardIssuedIntentRecord.InitiatorOwnerAccount, true) - err = data.DeleteChatMessage(ctx, chatId, giftCardIssuedIntentRecord.IntentId) + err = data.DeleteChatMessageV1(ctx, chatId, giftCardIssuedIntentRecord.IntentId) if err != nil { return errors.Wrap(err, "error deleting chat message") } @@ -148,17 +148,17 @@ func SendCashTransactionsExchangeMessage(ctx context.Context, data code_data.Pro return errors.Wrap(err, "error creating proto chat message") } - _, err = SendChatMessage( + _, err = SendNotificationChatMessageV1( ctx, data, CashTransactionsName, - chat.ChatTypeInternal, + chat_v1.ChatTypeInternal, true, receiver, protoMessage, true, ) - if err != nil && err != chat.ErrMessageAlreadyExists { + if err != nil && err != chat_v1.ErrMessageAlreadyExists { return errors.Wrap(err, "error persisting chat message") } } diff --git a/pkg/code/chat/message_code_team.go b/pkg/code/chat/message_code_team.go index fe8a7049..423781cf 100644 --- a/pkg/code/chat/message_code_team.go +++ b/pkg/code/chat/message_code_team.go @@ -9,18 +9,18 @@ import ( "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/code/data/intent" "github.com/code-payments/code-server/pkg/code/localization" ) // SendCodeTeamMessage sends a message to the Code Team chat. func SendCodeTeamMessage(ctx context.Context, data code_data.Provider, receiver *common.Account, chatMessage *chatpb.ChatMessage) (bool, error) { - return SendChatMessage( + return SendNotificationChatMessageV1( ctx, data, CodeTeamName, - chat.ChatTypeInternal, + chat_v1.ChatTypeInternal, true, receiver, chatMessage, @@ -48,8 +48,8 @@ func newIncentiveMessage(localizedTextKey string, intentRecord *intent.Record) ( content := []*chatpb.Content{ { - Type: &chatpb.Content_Localized{ - Localized: &chatpb.LocalizedContent{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ KeyOrText: localizedTextKey, }, }, diff --git a/pkg/code/chat/message_kin_purchases.go b/pkg/code/chat/message_kin_purchases.go index 1ec10b68..1e637355 100644 --- a/pkg/code/chat/message_kin_purchases.go +++ b/pkg/code/chat/message_kin_purchases.go @@ -11,23 +11,23 @@ import ( "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/code/localization" ) // GetKinPurchasesChatId returns the chat ID for the Kin Purchases chat for a // given owner account -func GetKinPurchasesChatId(owner *common.Account) chat.ChatId { - return chat.GetChatId(KinPurchasesName, owner.PublicKey().ToBase58(), true) +func GetKinPurchasesChatId(owner *common.Account) chat_v1.ChatId { + return chat_v1.GetChatId(KinPurchasesName, owner.PublicKey().ToBase58(), true) } // SendKinPurchasesMessage sends a message to the Kin Purchases chat. func SendKinPurchasesMessage(ctx context.Context, data code_data.Provider, receiver *common.Account, chatMessage *chatpb.ChatMessage) (bool, error) { - return SendChatMessage( + return SendNotificationChatMessageV1( ctx, data, KinPurchasesName, - chat.ChatTypeInternal, + chat_v1.ChatTypeInternal, true, receiver, chatMessage, @@ -40,8 +40,8 @@ func SendKinPurchasesMessage(ctx context.Context, data code_data.Provider, recei func ToUsdcDepositedMessage(signature string, ts time.Time) (*chatpb.ChatMessage, error) { content := []*chatpb.Content{ { - Type: &chatpb.Content_Localized{ - Localized: &chatpb.LocalizedContent{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ KeyOrText: localization.ChatMessageUsdcDeposited, }, }, @@ -60,8 +60,8 @@ func NewUsdcBeingConvertedMessage(ts time.Time) (*chatpb.ChatMessage, error) { content := []*chatpb.Content{ { - Type: &chatpb.Content_Localized{ - Localized: &chatpb.LocalizedContent{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ KeyOrText: localization.ChatMessageUsdcBeingConverted, }, }, @@ -79,8 +79,8 @@ func ToKinAvailableForUseMessage(signature string, ts time.Time, purchases ...*t content := []*chatpb.Content{ { - Type: &chatpb.Content_Localized{ - Localized: &chatpb.LocalizedContent{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ KeyOrText: localization.ChatMessageKinAvailableForUse, }, }, diff --git a/pkg/code/chat/message_merchant.go b/pkg/code/chat/message_merchant.go index b4504c39..3acf3ed4 100644 --- a/pkg/code/chat/message_merchant.go +++ b/pkg/code/chat/message_merchant.go @@ -13,7 +13,7 @@ import ( "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" "github.com/code-payments/code-server/pkg/code/data/action" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/code/data/intent" ) @@ -36,7 +36,7 @@ func SendMerchantExchangeMessage(ctx context.Context, data code_data.Provider, i // the merchant. Representation in the UI may differ (ie. 2 and 3 are grouped), // but this is the most flexible solution with the chat model. chatTitle := PaymentsName - chatType := chat.ChatTypeInternal + chatType := chat_v1.ChatTypeInternal isVerifiedChat := false exchangeData, ok := getExchangeDataFromIntent(intentRecord) @@ -59,7 +59,7 @@ func SendMerchantExchangeMessage(ctx context.Context, data code_data.Provider, i if paymentRequestRecord.Domain != nil { chatTitle = *paymentRequestRecord.Domain - chatType = chat.ChatTypeExternalApp + chatType = chat_v1.ChatTypeExternalApp isVerifiedChat = paymentRequestRecord.IsVerified } @@ -87,7 +87,7 @@ func SendMerchantExchangeMessage(ctx context.Context, data code_data.Provider, i // and will have merchant payments appear in the verified merchant // chat. chatTitle = *destinationAccountInfoRecord.RelationshipTo - chatType = chat.ChatTypeExternalApp + chatType = chat_v1.ChatTypeExternalApp isVerifiedChat = true verbAndExchangeDataByMessageReceiver[intentRecord.SendPrivatePaymentMetadata.DestinationOwnerAccount] = &verbAndExchangeData{ verb: chatpb.ExchangeDataContent_RECEIVED, @@ -107,7 +107,7 @@ func SendMerchantExchangeMessage(ctx context.Context, data code_data.Provider, i // and will have merchant payments appear in the verified merchant // chat. chatTitle = *destinationAccountInfoRecord.RelationshipTo - chatType = chat.ChatTypeExternalApp + chatType = chat_v1.ChatTypeExternalApp isVerifiedChat = true verbAndExchangeDataByMessageReceiver[intentRecord.SendPublicPaymentMetadata.DestinationOwnerAccount] = &verbAndExchangeData{ verb: chatpb.ExchangeDataContent_RECEIVED, @@ -126,7 +126,7 @@ func SendMerchantExchangeMessage(ctx context.Context, data code_data.Provider, i // and will have merchant payments appear in the verified merchant // chat. chatTitle = *destinationAccountInfoRecord.RelationshipTo - chatType = chat.ChatTypeExternalApp + chatType = chat_v1.ChatTypeExternalApp isVerifiedChat = true verbAndExchangeDataByMessageReceiver[intentRecord.ExternalDepositMetadata.DestinationOwnerAccount] = &verbAndExchangeData{ verb: chatpb.ExchangeDataContent_RECEIVED, @@ -161,7 +161,7 @@ func SendMerchantExchangeMessage(ctx context.Context, data code_data.Provider, i return nil, errors.Wrap(err, "error creating proto chat message") } - canPush, err := SendChatMessage( + canPush, err := SendNotificationChatMessageV1( ctx, data, chatTitle, @@ -171,7 +171,7 @@ func SendMerchantExchangeMessage(ctx context.Context, data code_data.Provider, i protoMessage, verbAndExchangeData.verb != chatpb.ExchangeDataContent_RECEIVED || !isVerifiedChat, ) - if err != nil && err != chat.ErrMessageAlreadyExists { + if err != nil && err != chat_v1.ErrMessageAlreadyExists { return nil, errors.Wrap(err, "error persisting chat message") } diff --git a/pkg/code/chat/message_tips.go b/pkg/code/chat/message_tips.go index b9984a9b..5500bc89 100644 --- a/pkg/code/chat/message_tips.go +++ b/pkg/code/chat/message_tips.go @@ -2,15 +2,12 @@ package chat import ( "context" - - "github.com/pkg/errors" - chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v1" - "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/code/data/intent" + "github.com/pkg/errors" ) // SendTipsExchangeMessage sends a message to the Tips chat with exchange data @@ -18,7 +15,7 @@ import ( // Tips chat will be ignored. // // Note: Tests covered in SubmitIntent history tests -func SendTipsExchangeMessage(ctx context.Context, data code_data.Provider, intentRecord *intent.Record) ([]*MessageWithOwner, error) { +func SendTipsExchangeMessage(ctx context.Context, data code_data.Provider, notifier Notifier, intentRecord *intent.Record) ([]*MessageWithOwner, error) { messageId := intentRecord.IntentId exchangeData, ok := getExchangeDataFromIntent(intentRecord) @@ -30,7 +27,6 @@ func SendTipsExchangeMessage(ctx context.Context, data code_data.Provider, inten switch intentRecord.IntentType { case intent.SendPrivatePayment: if !intentRecord.SendPrivatePaymentMetadata.IsTip { - // Not a tip return nil, nil } @@ -61,30 +57,31 @@ func SendTipsExchangeMessage(ctx context.Context, data code_data.Provider, inten }, }, } - protoMessage, err := newProtoChatMessage(messageId, content, intentRecord.CreatedAt) + + v1Message, err := newProtoChatMessage(messageId, content, intentRecord.CreatedAt) if err != nil { return nil, errors.Wrap(err, "error creating proto chat message") } - canPush, err := SendChatMessage( + canPush, err := SendNotificationChatMessageV1( ctx, data, TipsName, - chat.ChatTypeInternal, + chat_v1.ChatTypeInternal, true, receiver, - protoMessage, + v1Message, verb != chatpb.ExchangeDataContent_RECEIVED_TIP, ) - if err != nil && err != chat.ErrMessageAlreadyExists { - return nil, errors.Wrap(err, "error persisting chat message") + if err != nil && !errors.Is(err, chat_v1.ErrMessageAlreadyExists) { + return nil, errors.Wrap(err, "error persisting v1 chat message") } if canPush { messagesToPush = append(messagesToPush, &MessageWithOwner{ Owner: receiver, Title: TipsName, - Message: protoMessage, + Message: v1Message, }) } } diff --git a/pkg/code/chat/notifier.go b/pkg/code/chat/notifier.go new file mode 100644 index 00000000..7fa4c1ff --- /dev/null +++ b/pkg/code/chat/notifier.go @@ -0,0 +1,22 @@ +package chat + +import ( + "context" + + chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v2" + + chat "github.com/code-payments/code-server/pkg/code/data/chat/v2" +) + +type Notifier interface { + NotifyMessage(ctx context.Context, chatID chat.ChatId, message *chatpb.Message) +} + +type NoopNotifier struct{} + +func NewNoopNotifier() *NoopNotifier { + return &NoopNotifier{} +} + +func (n *NoopNotifier) NotifyMessage(_ context.Context, _ chat.ChatId, _ *chatpb.Message) { +} diff --git a/pkg/code/chat/sender.go b/pkg/code/chat/sender.go index 41da0902..f63436a8 100644 --- a/pkg/code/chat/sender.go +++ b/pkg/code/chat/sender.go @@ -9,28 +9,27 @@ import ( "google.golang.org/protobuf/proto" chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v1" - "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" ) -// SendChatMessage sends a chat message to a receiving owner account. +// SendNotificationChatMessageV1 sends a chat message to a receiving owner account. // // Note: This function is not responsible for push notifications. This method // might be called within the context of a DB transaction, which might have -// unrelated failures. A hint as to whether a push should be sent is provided. -func SendChatMessage( +// unrelated failures. A hint whether a push should be sent is provided. +func SendNotificationChatMessageV1( ctx context.Context, data code_data.Provider, chatTitle string, - chatType chat.ChatType, + chatType chat_v1.ChatType, isVerifiedChat bool, receiver *common.Account, protoMessage *chatpb.ChatMessage, isSilentMessage bool, ) (canPushMessage bool, err error) { - chatId := chat.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), isVerifiedChat) + chatId := chat_v1.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), isVerifiedChat) if protoMessage.Cursor != nil { // Let the utilities and GetMessages RPC handle cursors @@ -58,13 +57,13 @@ func SendChatMessage( canPersistMessage := true canPushMessage = !isSilentMessage - existingChatRecord, err := data.GetChatById(ctx, chatId) + existingChatRecord, err := data.GetChatByIdV1(ctx, chatId) switch err { case nil: canPersistMessage = !existingChatRecord.IsUnsubscribed canPushMessage = canPushMessage && canPersistMessage && !existingChatRecord.IsMuted - case chat.ErrChatNotFound: - chatRecord := &chat.Chat{ + case chat_v1.ErrChatNotFound: + chatRecord := &chat_v1.Chat{ ChatId: chatId, ChatType: chatType, ChatTitle: chatTitle, @@ -79,8 +78,8 @@ func SendChatMessage( CreatedAt: time.Now(), } - err = data.PutChat(ctx, chatRecord) - if err != nil && err != chat.ErrChatAlreadyExists { + err = data.PutChatV1(ctx, chatRecord) + if err != nil && !errors.Is(err, chat_v1.ErrChatAlreadyExists) { return false, err } default: @@ -88,7 +87,7 @@ func SendChatMessage( } if canPersistMessage { - messageRecord := &chat.Message{ + messageRecord := &chat_v1.Message{ ChatId: chatId, MessageId: base58.Encode(messageId), @@ -100,7 +99,7 @@ func SendChatMessage( Timestamp: ts.AsTime(), } - err = data.PutChatMessage(ctx, messageRecord) + err = data.PutChatMessageV1(ctx, messageRecord) if err != nil { return false, err } diff --git a/pkg/code/chat/sender_test.go b/pkg/code/chat/sender_test.go index 7625a1f3..18028c43 100644 --- a/pkg/code/chat/sender_test.go +++ b/pkg/code/chat/sender_test.go @@ -17,7 +17,7 @@ import ( "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" "github.com/code-payments/code-server/pkg/code/data/badgecount" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/testutil" ) @@ -26,16 +26,15 @@ func TestSendChatMessage_HappyPath(t *testing.T) { chatTitle := CodeTeamName receiver := testutil.NewRandomAccount(t) - chatId := chat.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), true) + chatId := chat_v1.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), true) var expectedBadgeCount int for i := 0; i < 10; i++ { chatMessage := newRandomChatMessage(t, i+1) expectedBadgeCount += 1 - canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat.ChatTypeInternal, true, receiver, chatMessage, false) + canPush, err := SendNotificationChatMessageV1(env.ctx, env.data, chatTitle, chat_v1.ChatTypeInternal, true, receiver, chatMessage, false) require.NoError(t, err) - assert.True(t, canPush) assert.NotNil(t, chatMessage.MessageId) @@ -56,7 +55,7 @@ func TestSendChatMessage_VerifiedChat(t *testing.T) { for _, isVerified := range []bool{true, false} { chatMessage := newRandomChatMessage(t, 1) - _, err := SendChatMessage(env.ctx, env.data, chatTitle, chat.ChatTypeInternal, isVerified, receiver, chatMessage, true) + _, err := SendNotificationChatMessageV1(env.ctx, env.data, chatTitle, chat_v1.ChatTypeInternal, isVerified, receiver, chatMessage, true) require.NoError(t, err) env.assertChatRecordSaved(t, chatTitle, receiver, isVerified) } @@ -67,11 +66,11 @@ func TestSendChatMessage_SilentMessage(t *testing.T) { chatTitle := CodeTeamName receiver := testutil.NewRandomAccount(t) - chatId := chat.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), true) + chatId := chat_v1.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), true) for i, isSilent := range []bool{true, false} { chatMessage := newRandomChatMessage(t, 1) - canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat.ChatTypeInternal, true, receiver, chatMessage, isSilent) + canPush, err := SendNotificationChatMessageV1(env.ctx, env.data, chatTitle, chat_v1.ChatTypeInternal, true, receiver, chatMessage, isSilent) require.NoError(t, err) assert.Equal(t, !isSilent, canPush) env.assertChatMessageRecordSaved(t, chatId, chatMessage, isSilent) @@ -84,7 +83,7 @@ func TestSendChatMessage_MuteState(t *testing.T) { chatTitle := CodeTeamName receiver := testutil.NewRandomAccount(t) - chatId := chat.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), true) + chatId := chat_v1.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), true) for _, isMuted := range []bool{false, true} { if isMuted { @@ -92,7 +91,7 @@ func TestSendChatMessage_MuteState(t *testing.T) { } chatMessage := newRandomChatMessage(t, 1) - canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat.ChatTypeInternal, true, receiver, chatMessage, false) + canPush, err := SendNotificationChatMessageV1(env.ctx, env.data, chatTitle, chat_v1.ChatTypeInternal, true, receiver, chatMessage, false) require.NoError(t, err) assert.Equal(t, !isMuted, canPush) env.assertChatMessageRecordSaved(t, chatId, chatMessage, false) @@ -105,7 +104,7 @@ func TestSendChatMessage_SubscriptionState(t *testing.T) { chatTitle := CodeTeamName receiver := testutil.NewRandomAccount(t) - chatId := chat.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), true) + chatId := chat_v1.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), true) for _, isUnsubscribed := range []bool{false, true} { if isUnsubscribed { @@ -113,7 +112,7 @@ func TestSendChatMessage_SubscriptionState(t *testing.T) { } chatMessage := newRandomChatMessage(t, 1) - canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat.ChatTypeInternal, true, receiver, chatMessage, false) + canPush, err := SendNotificationChatMessageV1(env.ctx, env.data, chatTitle, chat_v1.ChatTypeInternal, true, receiver, chatMessage, false) require.NoError(t, err) assert.Equal(t, !isUnsubscribed, canPush) if isUnsubscribed { @@ -130,12 +129,12 @@ func TestSendChatMessage_InvalidProtoMessage(t *testing.T) { chatTitle := CodeTeamName receiver := testutil.NewRandomAccount(t) - chatId := chat.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), true) + chatId := chat_v1.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), true) chatMessage := newRandomChatMessage(t, 1) chatMessage.Content = nil - canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat.ChatTypeInternal, true, receiver, chatMessage, false) + canPush, err := SendNotificationChatMessageV1(env.ctx, env.data, chatTitle, chat_v1.ChatTypeInternal, true, receiver, chatMessage, false) assert.Error(t, err) assert.False(t, canPush) env.assertChatRecordNotSaved(t, chatId) @@ -159,8 +158,8 @@ func newRandomChatMessage(t *testing.T, contentLength int) *chatpb.ChatMessage { var content []*chatpb.Content for i := 0; i < contentLength; i++ { content = append(content, &chatpb.Content{ - Type: &chatpb.Content_Localized{ - Localized: &chatpb.LocalizedContent{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ KeyOrText: fmt.Sprintf("key%d", rand.Uint32()), }, }, @@ -173,13 +172,13 @@ func newRandomChatMessage(t *testing.T, contentLength int) *chatpb.ChatMessage { } func (e *testEnv) assertChatRecordSaved(t *testing.T, chatTitle string, receiver *common.Account, isVerified bool) { - chatId := chat.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), isVerified) + chatId := chat_v1.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), isVerified) - chatRecord, err := e.data.GetChatById(e.ctx, chatId) + chatRecord, err := e.data.GetChatByIdV1(e.ctx, chatId) require.NoError(t, err) assert.Equal(t, chatId[:], chatRecord.ChatId[:]) - assert.Equal(t, chat.ChatTypeInternal, chatRecord.ChatType) + assert.Equal(t, chat_v1.ChatTypeInternal, chatRecord.ChatType) assert.Equal(t, chatTitle, chatRecord.ChatTitle) assert.Equal(t, isVerified, chatRecord.IsVerified) assert.Equal(t, receiver.PublicKey().ToBase58(), chatRecord.CodeUser) @@ -188,8 +187,8 @@ func (e *testEnv) assertChatRecordSaved(t *testing.T, chatTitle string, receiver assert.False(t, chatRecord.IsUnsubscribed) } -func (e *testEnv) assertChatMessageRecordSaved(t *testing.T, chatId chat.ChatId, protoMessage *chatpb.ChatMessage, isSilent bool) { - messageRecord, err := e.data.GetChatMessage(e.ctx, chatId, base58.Encode(protoMessage.GetMessageId().Value)) +func (e *testEnv) assertChatMessageRecordSaved(t *testing.T, chatId chat_v1.ChatId, protoMessage *chatpb.ChatMessage, isSilent bool) { + messageRecord, err := e.data.GetChatMessageV1(e.ctx, chatId, base58.Encode(protoMessage.GetMessageId().Value)) require.NoError(t, err) cloned := proto.Clone(protoMessage).(*chatpb.ChatMessage) @@ -218,22 +217,22 @@ func (e *testEnv) assertBadgeCount(t *testing.T, owner *common.Account, expected assert.EqualValues(t, expected, badgeCountRecord.BadgeCount) } -func (e *testEnv) assertChatRecordNotSaved(t *testing.T, chatId chat.ChatId) { - _, err := e.data.GetChatById(e.ctx, chatId) - assert.Equal(t, chat.ErrChatNotFound, err) +func (e *testEnv) assertChatRecordNotSaved(t *testing.T, chatId chat_v1.ChatId) { + _, err := e.data.GetChatByIdV1(e.ctx, chatId) + assert.Equal(t, chat_v1.ErrChatNotFound, err) } -func (e *testEnv) assertChatMessageRecordNotSaved(t *testing.T, chatId chat.ChatId, messageId *chatpb.ChatMessageId) { - _, err := e.data.GetChatMessage(e.ctx, chatId, base58.Encode(messageId.Value)) - assert.Equal(t, chat.ErrMessageNotFound, err) +func (e *testEnv) assertChatMessageRecordNotSaved(t *testing.T, chatId chat_v1.ChatId, messageId *chatpb.ChatMessageId) { + _, err := e.data.GetChatMessageV1(e.ctx, chatId, base58.Encode(messageId.Value)) + assert.Equal(t, chat_v1.ErrMessageNotFound, err) } -func (e *testEnv) muteChat(t *testing.T, chatId chat.ChatId) { - require.NoError(t, e.data.SetChatMuteState(e.ctx, chatId, true)) +func (e *testEnv) muteChat(t *testing.T, chatId chat_v1.ChatId) { + require.NoError(t, e.data.SetChatMuteStateV1(e.ctx, chatId, true)) } -func (e *testEnv) unsubscribeFromChat(t *testing.T, chatId chat.ChatId) { - require.NoError(t, e.data.SetChatSubscriptionState(e.ctx, chatId, false)) +func (e *testEnv) unsubscribeFromChat(t *testing.T, chatId chat_v1.ChatId) { + require.NoError(t, e.data.SetChatSubscriptionStateV1(e.ctx, chatId, false)) } diff --git a/pkg/code/common/account.go b/pkg/code/common/account.go index d48d971d..417e3f4a 100644 --- a/pkg/code/common/account.go +++ b/pkg/code/common/account.go @@ -5,6 +5,7 @@ import ( "context" "crypto/ed25519" "fmt" + chat "github.com/code-payments/code-server/pkg/code/data/chat/v2" "github.com/pkg/errors" "github.com/sirupsen/logrus" @@ -196,6 +197,27 @@ func (a *Account) ToTimelockVault(dataVersion timelock_token_v1.TimelockDataVers return timelockAccounts.Vault, nil } +func (a *Account) ToMessagingAccount(mint *Account) (*Account, error) { + return a.ToTimelockVault(timelock_token_v1.DataVersion1, mint) +} + +func (a *Account) ToChatMemberId() (chat.MemberId, error) { + messagingAccount, err := a.ToMessagingAccount(KinMintAccount) + if err != nil { + return nil, err + } + + return messagingAccount.PublicKey().ToBytes(), nil +} + +func (a *Account) MustToChatMemberId() chat.MemberId { + id, err := a.ToChatMemberId() + if err != nil { + panic(err) + } + return id +} + func (a *Account) ToAssociatedTokenAccount(mint *Account) (*Account, error) { if err := a.Validate(); err != nil { return nil, errors.Wrap(err, "error validating owner account") diff --git a/pkg/code/data/chat/memory/store.go b/pkg/code/data/chat/v1/memory/store.go similarity index 99% rename from pkg/code/data/chat/memory/store.go rename to pkg/code/data/chat/v1/memory/store.go index 1b869007..1d923071 100644 --- a/pkg/code/data/chat/memory/store.go +++ b/pkg/code/data/chat/v1/memory/store.go @@ -7,8 +7,8 @@ import ( "sync" "time" + chat "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/database/query" - "github.com/code-payments/code-server/pkg/code/data/chat" ) type ChatsById []*chat.Chat diff --git a/pkg/code/data/chat/memory/store_test.go b/pkg/code/data/chat/v1/memory/store_test.go similarity index 74% rename from pkg/code/data/chat/memory/store_test.go rename to pkg/code/data/chat/v1/memory/store_test.go index 5d2c18a5..c27859e6 100644 --- a/pkg/code/data/chat/memory/store_test.go +++ b/pkg/code/data/chat/v1/memory/store_test.go @@ -3,7 +3,7 @@ package memory import ( "testing" - "github.com/code-payments/code-server/pkg/code/data/chat/tests" + "github.com/code-payments/code-server/pkg/code/data/chat/v1/tests" ) func TestChatMemoryStore(t *testing.T) { diff --git a/pkg/code/data/chat/model.go b/pkg/code/data/chat/v1/model.go similarity index 99% rename from pkg/code/data/chat/model.go rename to pkg/code/data/chat/v1/model.go index d8fe7432..4d156996 100644 --- a/pkg/code/data/chat/model.go +++ b/pkg/code/data/chat/v1/model.go @@ -1,4 +1,4 @@ -package chat +package chat_v1 import ( "bytes" diff --git a/pkg/code/data/chat/model_test.go b/pkg/code/data/chat/v1/model_test.go similarity index 97% rename from pkg/code/data/chat/model_test.go rename to pkg/code/data/chat/v1/model_test.go index 7774d286..062f372b 100644 --- a/pkg/code/data/chat/model_test.go +++ b/pkg/code/data/chat/v1/model_test.go @@ -1,4 +1,4 @@ -package chat +package chat_v1 import ( "testing" diff --git a/pkg/code/data/chat/postgres/model.go b/pkg/code/data/chat/v1/postgres/model.go similarity index 99% rename from pkg/code/data/chat/postgres/model.go rename to pkg/code/data/chat/v1/postgres/model.go index 07158095..8987df2b 100644 --- a/pkg/code/data/chat/postgres/model.go +++ b/pkg/code/data/chat/v1/postgres/model.go @@ -8,10 +8,10 @@ import ( "github.com/jmoiron/sqlx" + chat "github.com/code-payments/code-server/pkg/code/data/chat/v1" pgutil "github.com/code-payments/code-server/pkg/database/postgres" q "github.com/code-payments/code-server/pkg/database/query" "github.com/code-payments/code-server/pkg/pointer" - "github.com/code-payments/code-server/pkg/code/data/chat" ) const ( diff --git a/pkg/code/data/chat/postgres/store.go b/pkg/code/data/chat/v1/postgres/store.go similarity index 98% rename from pkg/code/data/chat/postgres/store.go rename to pkg/code/data/chat/v1/postgres/store.go index 943a1935..bfb2b14f 100644 --- a/pkg/code/data/chat/postgres/store.go +++ b/pkg/code/data/chat/v1/postgres/store.go @@ -6,8 +6,8 @@ import ( "github.com/jmoiron/sqlx" + chat "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/database/query" - "github.com/code-payments/code-server/pkg/code/data/chat" ) type store struct { diff --git a/pkg/code/data/chat/postgres/store_test.go b/pkg/code/data/chat/v1/postgres/store_test.go similarity index 95% rename from pkg/code/data/chat/postgres/store_test.go rename to pkg/code/data/chat/v1/postgres/store_test.go index 49143ad7..4d72fc4e 100644 --- a/pkg/code/data/chat/postgres/store_test.go +++ b/pkg/code/data/chat/v1/postgres/store_test.go @@ -8,8 +8,8 @@ import ( "github.com/ory/dockertest/v3" "github.com/sirupsen/logrus" - "github.com/code-payments/code-server/pkg/code/data/chat" - "github.com/code-payments/code-server/pkg/code/data/chat/tests" + chat "github.com/code-payments/code-server/pkg/code/data/chat/v1" + "github.com/code-payments/code-server/pkg/code/data/chat/v1/tests" postgrestest "github.com/code-payments/code-server/pkg/database/postgres/test" diff --git a/pkg/code/data/chat/store.go b/pkg/code/data/chat/v1/store.go similarity index 99% rename from pkg/code/data/chat/store.go rename to pkg/code/data/chat/v1/store.go index 2e79a228..c21471b5 100644 --- a/pkg/code/data/chat/store.go +++ b/pkg/code/data/chat/v1/store.go @@ -1,4 +1,4 @@ -package chat +package chat_v1 import ( "context" diff --git a/pkg/code/data/chat/tests/tests.go b/pkg/code/data/chat/v1/tests/tests.go similarity index 99% rename from pkg/code/data/chat/tests/tests.go rename to pkg/code/data/chat/v1/tests/tests.go index f9eaaadf..1453c133 100644 --- a/pkg/code/data/chat/tests/tests.go +++ b/pkg/code/data/chat/v1/tests/tests.go @@ -10,9 +10,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + chat "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/database/query" "github.com/code-payments/code-server/pkg/pointer" - "github.com/code-payments/code-server/pkg/code/data/chat" ) func RunTests(t *testing.T, s chat.Store, teardown func()) { diff --git a/pkg/code/data/chat/v2/id.go b/pkg/code/data/chat/v2/id.go new file mode 100644 index 00000000..2aeed440 --- /dev/null +++ b/pkg/code/data/chat/v2/id.go @@ -0,0 +1,277 @@ +package chat_v2 + +import ( + "bytes" + "crypto/sha256" + "encoding/hex" + "fmt" + "time" + + "github.com/google/uuid" + "github.com/mr-tron/base58" + "github.com/pkg/errors" + + chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v2" + commonpb "github.com/code-payments/code-protobuf-api/generated/go/common/v1" +) + +type ChatId [32]byte + +// GetChatIdFromBytes gets a chat ID from a byte buffer +func GetChatIdFromBytes(buffer []byte) (ChatId, error) { + if len(buffer) != 32 { + return ChatId{}, errors.New("chat id must be 32 bytes in length") + } + + var typed ChatId + copy(typed[:], buffer[:]) + + if err := typed.Validate(); err != nil { + return ChatId{}, errors.Wrap(err, "invalid chat id") + } + + return typed, nil +} + +// GetTwoWayChatId returns the ChatId for two users. +func GetTwoWayChatId(sender, receiver []byte) ChatId { + var a, b []byte + if bytes.Compare(sender, receiver) <= 0 { + a, b = sender, receiver + } else { + a, b = receiver, sender + } + + combined := make([]byte, len(a)+len(b)) + copy(combined, a) + copy(combined[len(a):], b) + + return sha256.Sum256(combined) +} + +// GetChatIdFromProto gets a chat ID from the protobuf variant +func GetChatIdFromProto(proto *commonpb.ChatId) (ChatId, error) { + if err := proto.Validate(); err != nil { + return ChatId{}, errors.Wrap(err, "proto validation failed") + } + + return GetChatIdFromBytes(proto.Value) +} + +// ToProto converts a chat ID to its protobuf variant +func (c ChatId) ToProto() *commonpb.ChatId { + return &commonpb.ChatId{Value: c[:]} +} + +// Validate validates a chat ID +func (c ChatId) Validate() error { + return nil +} + +// Clone clones a chat ID +func (c ChatId) Clone() ChatId { + var cloned ChatId + copy(cloned[:], c[:]) + return cloned +} + +// String returns the string representation of a ChatId +func (c ChatId) String() string { + return hex.EncodeToString(c[:]) +} + +type MemberId []byte + +// GetMemberIdFromBytes gets a member ID from a byte buffer +func GetMemberIdFromBytes(buffer []byte) (MemberId, error) { + if len(buffer) != 32 { + return MemberId{}, errors.New("member id must be 32 bytes in length") + } + + typed := make(MemberId, len(buffer)) + copy(typed[:], buffer[:]) + + if err := typed.Validate(); err != nil { + return MemberId{}, errors.Wrap(err, "invalid member id") + } + + return typed, nil +} + +// GetMemberIdFromString gets a chat member ID from the string representation +func GetMemberIdFromString(value string) (MemberId, error) { + b, err := base58.Decode(value) + if err != nil { + return MemberId{}, errors.Wrap(err, "invalid member id") + } + + return GetMemberIdFromBytes(b) +} + +// GetMemberIdFromProto gets a member ID from the protobuf variant +func GetMemberIdFromProto(proto *chatpb.MemberId) (MemberId, error) { + if err := proto.Validate(); err != nil { + return MemberId{}, errors.Wrap(err, "proto validation failed") + } + + return GetMemberIdFromBytes(proto.Value) +} + +// ToProto converts a message ID to its protobuf variant +func (m MemberId) ToProto() *chatpb.MemberId { + return &chatpb.MemberId{Value: m[:]} +} + +// Validate validates a chat member ID +func (m MemberId) Validate() error { + if l := len(m); l < 0 || l > 32 { + return fmt.Errorf("member id must be in range 0-32, got: %d", l) + } + + return nil +} + +// Clone clones a chat member ID +func (m MemberId) Clone() MemberId { + cloned := make(MemberId, len(m)) + copy(cloned[:], m[:]) + return cloned +} + +// String returns the string representation of a MemberId +func (m MemberId) String() string { + return base58.Encode(m[:]) +} + +// MessageId is a time-based UUIDv7 ID for chat messages +type MessageId uuid.UUID + +// GenerateMessageId generates a UUIDv7 message ID using the current time +func GenerateMessageId() MessageId { + return GenerateMessageIdAtTime(time.Now()) +} + +// GenerateMessageIdAtTime generates a UUIDv7 message ID using the provided timestamp +func GenerateMessageIdAtTime(ts time.Time) MessageId { + // Convert timestamp to milliseconds since Unix epoch + millis := ts.UnixNano() / int64(time.Millisecond) + + // Create a byte slice to hold the UUID + var uuidBytes [16]byte + + // Populate the first 6 bytes with the timestamp (42 bits for timestamp) + uuidBytes[0] = byte((millis >> 40) & 0xff) + uuidBytes[1] = byte((millis >> 32) & 0xff) + uuidBytes[2] = byte((millis >> 24) & 0xff) + uuidBytes[3] = byte((millis >> 16) & 0xff) + uuidBytes[4] = byte((millis >> 8) & 0xff) + uuidBytes[5] = byte(millis & 0xff) + + // Set the version to 7 (UUIDv7) + uuidBytes[6] = (uuidBytes[6] & 0x0f) | (0x7 << 4) + + // Populate the remaining bytes with random values + randomUUID := uuid.New() + copy(uuidBytes[7:], randomUUID[7:]) + + return uuidBytes +} + +// GetMessageIdFromBytes gets a message ID from a byte buffer +func GetMessageIdFromBytes(buffer []byte) (MessageId, error) { + if len(buffer) != 16 { + return MessageId{}, errors.New("message id must be 16 bytes in length") + } + + var typed MessageId + copy(typed[:], buffer[:]) + + if err := typed.Validate(); err != nil { + return MessageId{}, errors.Wrap(err, "invalid message id") + } + + return typed, nil +} + +// GetMessageIdFromString gets a chat message ID from the string representation +func GetMessageIdFromString(value string) (MessageId, error) { + decoded, err := uuid.Parse(value) + if err != nil { + return MessageId{}, errors.Wrap(err, "value is not a uuid string") + } + + return GetMessageIdFromBytes(decoded[:]) +} + +// GetMessageIdFromProto gets a message ID from the protobuf variant +func GetMessageIdFromProto(proto *chatpb.MessageId) (MessageId, error) { + if err := proto.Validate(); err != nil { + return MessageId{}, errors.Wrap(err, "proto validation failed") + } + + return GetMessageIdFromBytes(proto.Value) +} + +// ToProto converts a message ID to its protobuf variant +func (m MessageId) ToProto() *chatpb.MessageId { + return &chatpb.MessageId{Value: m[:]} +} + +// GetTimestamp gets the encoded timestamp in the message ID +func (m MessageId) GetTimestamp() (time.Time, error) { + if err := m.Validate(); err != nil { + return time.Time{}, errors.Wrap(err, "invalid message id") + } + + // Extract the first 6 bytes as the timestamp + millis := (int64(m[0]) << 40) | (int64(m[1]) << 32) | (int64(m[2]) << 24) | + (int64(m[3]) << 16) | (int64(m[4]) << 8) | int64(m[5]) + + // Convert milliseconds since Unix epoch to time.Time + timestamp := time.Unix(0, millis*int64(time.Millisecond)) + + return timestamp, nil +} + +// Equals returns whether two message IDs are equal +func (m MessageId) Equals(other MessageId) bool { + return m.Compare(other) == 0 +} + +// Before returns whether the message ID is before the provided value +func (m MessageId) Before(other MessageId) bool { + return m.Compare(other) < 0 +} + +// After returns whether the message ID is after the provided value +func (m MessageId) After(other MessageId) bool { + return m.Compare(other) > 0 +} + +// Compare returns the byte comparison of the message ID +func (m MessageId) Compare(other MessageId) int { + return bytes.Compare(m[:], other[:]) +} + +// Validate validates a message ID +func (m MessageId) Validate() error { + casted := uuid.UUID(m) + + if casted.Version() != 7 { + return errors.Errorf("invalid uuid version: %s", casted.Version().String()) + } + + return nil +} + +// Clone clones a chat message ID +func (m MessageId) Clone() MessageId { + var cloned MessageId + copy(cloned[:], m[:]) + return cloned +} + +// String returns the string representation of a MessageId +func (m MessageId) String() string { + return uuid.UUID(m).String() +} diff --git a/pkg/code/data/chat/v2/id_test.go b/pkg/code/data/chat/v2/id_test.go new file mode 100644 index 00000000..3a9b710c --- /dev/null +++ b/pkg/code/data/chat/v2/id_test.go @@ -0,0 +1,47 @@ +package chat_v2 + +import ( + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGenerateMessageId_Validation(t *testing.T) { + valid := GenerateMessageId() + assert.NoError(t, valid.Validate()) + + invalid := MessageId(uuid.New()) + assert.Error(t, invalid.Validate()) +} + +func TestGenerateMessageId_TimestampExtraction(t *testing.T) { + expectedTs := time.Now() + + messageId := GenerateMessageIdAtTime(expectedTs) + actualTs, err := messageId.GetTimestamp() + require.NoError(t, err) + assert.Equal(t, expectedTs.UnixMilli(), actualTs.UnixMilli()) +} + +func TestGenerateMessageId_Ordering(t *testing.T) { + now := time.Now() + messageIds := make([]MessageId, 0) + for i := 0; i < 10; i++ { + messageId := GenerateMessageIdAtTime(now.Add(time.Duration(i * int(time.Millisecond)))) + messageIds = append(messageIds, messageId) + } + + for i := 0; i < len(messageIds)-1; i++ { + assert.True(t, messageIds[i].Equals(messageIds[i])) + assert.False(t, messageIds[i].Equals(messageIds[i+1])) + + assert.True(t, messageIds[i].Before(messageIds[i+1])) + assert.False(t, messageIds[i].After(messageIds[i+1])) + + assert.False(t, messageIds[i+1].Before(messageIds[i])) + assert.True(t, messageIds[i+1].After(messageIds[i])) + } +} diff --git a/pkg/code/data/chat/v2/memory/store.go b/pkg/code/data/chat/v2/memory/store.go new file mode 100644 index 00000000..164cdc9c --- /dev/null +++ b/pkg/code/data/chat/v2/memory/store.go @@ -0,0 +1,418 @@ +package memory + +import ( + "bytes" + "context" + "slices" + "sort" + "strings" + "sync" + + chat "github.com/code-payments/code-server/pkg/code/data/chat/v2" + "github.com/code-payments/code-server/pkg/database/query" +) + +type InMemoryStore struct { + mu sync.RWMutex + chats map[string]*chat.MetadataRecord + members map[string]map[string]*chat.MemberRecord + messages map[string][]*chat.MessageRecord +} + +func New() *InMemoryStore { + return &InMemoryStore{ + chats: make(map[string]*chat.MetadataRecord), + members: make(map[string]map[string]*chat.MemberRecord), + messages: make(map[string][]*chat.MessageRecord), + } +} + +// GetChatMetadata retrieves the metadata record for a specific chat +func (s *InMemoryStore) GetChatMetadata(_ context.Context, chatId chat.ChatId) (*chat.MetadataRecord, error) { + if err := chatId.Validate(); err != nil { + return nil, err + } + + s.mu.RLock() + defer s.mu.RUnlock() + + if md, exists := s.chats[string(chatId[:])]; exists { + cloned := md.Clone() + return &cloned, nil + } + + return nil, chat.ErrChatNotFound +} + +// GetChatMessageV2 retrieves a specific message from a chat +func (s *InMemoryStore) GetChatMessageV2(_ context.Context, chatId chat.ChatId, messageId chat.MessageId) (*chat.MessageRecord, error) { + if err := chatId.Validate(); err != nil { + return nil, err + } + if err := messageId.Validate(); err != nil { + return nil, err + } + + s.mu.RLock() + defer s.mu.RUnlock() + + if messages, exists := s.messages[string(chatId[:])]; exists { + for _, message := range messages { + if bytes.Equal(message.MessageId[:], messageId[:]) { + clone := message.Clone() + return &clone, nil + } + } + } + + return nil, chat.ErrMessageNotFound +} + +// GetAllChatsForUserV2 retrieves all chat IDs that a given user belongs to +func (s *InMemoryStore) GetAllChatsForUserV2(_ context.Context, user chat.MemberId, opts ...query.Option) ([]chat.ChatId, error) { + if err := user.Validate(); err != nil { + return nil, err + } + + qo := &query.QueryOptions{ + Supported: query.CanQueryByCursor | query.CanLimitResults | query.CanSortBy, + } + err := qo.Apply(opts...) + if err != nil { + return nil, err + } + + s.mu.RLock() + defer s.mu.RUnlock() + + var chatIds []chat.ChatId + for chatIdStr, members := range s.members { + if _, exists := members[user.String()]; exists { + chatId, _ := chat.GetChatIdFromBytes([]byte(chatIdStr)) + chatIds = append(chatIds, chatId) + } + } + + // Sort the chatIds + sort.Slice(chatIds, func(i, j int) bool { + if qo.SortBy == query.Descending { + return bytes.Compare(chatIds[i][:], chatIds[j][:]) > 0 + } + return bytes.Compare(chatIds[i][:], chatIds[j][:]) < 0 + }) + + // Apply cursor if provided + if qo.Cursor != nil { + cursorChatId, err := chat.GetChatIdFromBytes(qo.Cursor) + if err != nil { + return nil, err + } + var filteredChatIds []chat.ChatId + for _, chatId := range chatIds { + if qo.SortBy == query.Descending { + if bytes.Compare(chatId[:], cursorChatId[:]) < 0 { + filteredChatIds = append(filteredChatIds, chatId) + } + } else { + if bytes.Compare(chatId[:], cursorChatId[:]) > 0 { + filteredChatIds = append(filteredChatIds, chatId) + } + } + } + chatIds = filteredChatIds + } + + // Apply limit if provided + if qo.Limit > 0 && uint64(len(chatIds)) > qo.Limit { + chatIds = chatIds[:qo.Limit] + } + + return chatIds, nil +} + +// GetAllChatMessagesV2 retrieves all messages for a specific chat +func (s *InMemoryStore) GetAllChatMessagesV2(_ context.Context, chatId chat.ChatId, opts ...query.Option) ([]*chat.MessageRecord, error) { + if err := chatId.Validate(); err != nil { + return nil, err + } + + qo := &query.QueryOptions{ + Supported: query.CanLimitResults | query.CanSortBy | query.CanQueryByCursor, + } + if err := qo.Apply(opts...); err != nil { + return nil, err + } + + s.mu.RLock() + defer s.mu.RUnlock() + + messages, exists := s.messages[string(chatId[:])] + if !exists { + return nil, nil + } + + var result []*chat.MessageRecord + for _, msg := range messages { + cloned := msg.Clone() + result = append(result, &cloned) + } + + // Sort the messages + sort.Slice(result, func(i, j int) bool { + if qo.SortBy == query.Descending { + return bytes.Compare(result[i].MessageId[:], result[j].MessageId[:]) > 0 + } + return bytes.Compare(result[i].MessageId[:], result[j].MessageId[:]) < 0 + }) + + // Apply cursor if provided + if len(qo.Cursor) > 0 { + cursorMessageId, err := chat.GetMessageIdFromBytes(qo.Cursor) + if err != nil { + return nil, err + } + var filteredMessages []*chat.MessageRecord + for _, msg := range result { + if qo.SortBy == query.Descending { + if bytes.Compare(msg.MessageId[:], cursorMessageId[:]) < 0 { + filteredMessages = append(filteredMessages, msg) + } + } else { + if bytes.Compare(msg.MessageId[:], cursorMessageId[:]) > 0 { + filteredMessages = append(filteredMessages, msg) + } + } + } + result = filteredMessages + } + + // Apply limit if provided + if qo.Limit > 0 && uint64(len(result)) > qo.Limit { + result = result[:qo.Limit] + } + + return result, nil +} + +// GetChatMembersV2 retrieves all members of a specific chat +func (s *InMemoryStore) GetChatMembersV2(_ context.Context, chatId chat.ChatId) ([]*chat.MemberRecord, error) { + if err := chatId.Validate(); err != nil { + return nil, err + } + + s.mu.RLock() + defer s.mu.RUnlock() + + members, exists := s.members[string(chatId[:])] + if !exists { + return nil, chat.ErrChatNotFound + } + + var result []*chat.MemberRecord + for _, member := range members { + cloned := member.Clone() + result = append(result, &cloned) + } + + slices.SortFunc(result, func(a, b *chat.MemberRecord) int { + return strings.Compare(a.MemberId, b.MemberId) + }) + + return result, nil +} + +// IsChatMember checks if a given member is part of a specific chat +func (s *InMemoryStore) IsChatMember(_ context.Context, chatId chat.ChatId, memberId chat.MemberId) (bool, error) { + if err := chatId.Validate(); err != nil { + return false, err + } + if err := memberId.Validate(); err != nil { + return false, err + } + + s.mu.RLock() + defer s.mu.RUnlock() + + if members, exists := s.members[string(chatId[:])]; exists { + _, exists = members[memberId.String()] + return exists, nil + } + + return false, nil +} + +// PutChatV2 stores or updates the metadata for a specific chat +func (s *InMemoryStore) PutChatV2(_ context.Context, record *chat.MetadataRecord) error { + if err := record.Validate(); err != nil { + return err + } + + s.mu.Lock() + defer s.mu.Unlock() + + if _, exists := s.chats[string(record.ChatId[:])]; exists { + return chat.ErrChatExists + } + + cloned := record.Clone() + s.chats[string(record.ChatId[:])] = &cloned + + return nil +} + +// PutChatMemberV2 stores or updates a member record for a specific chat +func (s *InMemoryStore) PutChatMemberV2(_ context.Context, record *chat.MemberRecord) error { + if err := record.Validate(); err != nil { + return err + } + + s.mu.Lock() + defer s.mu.Unlock() + + members, exists := s.members[string(record.ChatId[:])] + if !exists { + members = make(map[string]*chat.MemberRecord) + s.members[string(record.ChatId[:])] = members + } + + if _, exists = members[record.MemberId]; exists { + return chat.ErrMemberExists + } + + cloned := record.Clone() + members[record.MemberId] = &cloned + + return nil +} + +// PutChatMessageV2 stores or updates a message record in a specific chat +func (s *InMemoryStore) PutChatMessageV2(_ context.Context, record *chat.MessageRecord) error { + if err := record.Validate(); err != nil { + return err + } + + s.mu.Lock() + defer s.mu.Unlock() + + messages := s.messages[string(record.ChatId[:])] + if messages == nil { + messages = make([]*chat.MessageRecord, 0) + s.messages[string(record.ChatId[:])] = messages + } + + i, found := sort.Find(len(messages), func(i int) int { + return bytes.Compare(record.MessageId[:], messages[i].MessageId[:]) + }) + if found { + return chat.ErrMessageExists + } + + cloned := record.Clone() + messages = slices.Insert(messages, i, &cloned) + s.messages[string(record.ChatId[:])] = messages + + return nil +} + +// SetChatMuteStateV2 sets the mute state for a specific chat member +func (s *InMemoryStore) SetChatMuteStateV2(ctx context.Context, chatId chat.ChatId, memberId chat.MemberId, isMuted bool) error { + if err := chatId.Validate(); err != nil { + return err + } + if err := memberId.Validate(); err != nil { + return err + } + + s.mu.Lock() + defer s.mu.Unlock() + + if members, exists := s.members[string(chatId[:])]; exists { + if member, exists := members[memberId.String()]; exists { + member.IsMuted = isMuted + return nil + } + } + return chat.ErrMemberNotFound +} + +// AdvanceChatPointerV2 advances a pointer for a chat member +func (s *InMemoryStore) AdvanceChatPointerV2(ctx context.Context, chatId chat.ChatId, memberId chat.MemberId, pointerType chat.PointerType, pointer chat.MessageId) (bool, error) { + if err := chatId.Validate(); err != nil { + return false, err + } + if err := memberId.Validate(); err != nil { + return false, err + } + if err := pointer.Validate(); err != nil { + return false, err + } + + s.mu.Lock() + defer s.mu.Unlock() + + members, exists := s.members[string(chatId[:])] + if !exists { + return false, chat.ErrMemberNotFound + } + + member, exists := members[memberId.String()] + if !exists { + return false, chat.ErrMemberNotFound + } + + switch pointerType { + case chat.PointerTypeSent: + case chat.PointerTypeDelivered: + if member.DeliveryPointer == nil || bytes.Compare(pointer[:], member.DeliveryPointer[:]) > 0 { + newPtr := pointer.Clone() + member.DeliveryPointer = &newPtr + return true, nil + } + case chat.PointerTypeRead: + if member.ReadPointer == nil || bytes.Compare(pointer[:], member.ReadPointer[:]) > 0 { + newPtr := pointer.Clone() + member.ReadPointer = &newPtr + return true, nil + } + default: + return false, chat.ErrInvalidPointerType + } + + return false, nil +} + +// GetChatUnreadCountV2 calculates and returns the unread message count +func (s *InMemoryStore) GetChatUnreadCountV2(ctx context.Context, chatId chat.ChatId, memberId chat.MemberId, readPointer *chat.MessageId) (uint32, error) { + if err := chatId.Validate(); err != nil { + return 0, err + } + if err := memberId.Validate(); err != nil { + return 0, err + } + if readPointer != nil { + if err := readPointer.Validate(); err != nil { + return 0, err + } + } + + s.mu.RLock() + defer s.mu.RUnlock() + + unread := uint32(0) + messages := s.messages[string(chatId[:])] + for _, message := range messages { + if readPointer != nil { + if bytes.Compare(message.MessageId[:], readPointer[:]) <= 0 { + continue + } + } + + if message.Sender.String() == memberId.String() { + continue + } + + unread++ + } + + return unread, nil +} diff --git a/pkg/code/data/chat/v2/memory/store_test.go b/pkg/code/data/chat/v2/memory/store_test.go new file mode 100644 index 00000000..48f10ff1 --- /dev/null +++ b/pkg/code/data/chat/v2/memory/store_test.go @@ -0,0 +1,535 @@ +package memory + +import ( + "bytes" + "context" + "fmt" + "math/rand/v2" + "slices" + "testing" + "time" + + "github.com/stretchr/testify/require" + + chat "github.com/code-payments/code-server/pkg/code/data/chat/v2" + "github.com/code-payments/code-server/pkg/database/query" + "github.com/code-payments/code-server/pkg/pointer" +) + +func TestInMemoryStore_GetChatMetadata(t *testing.T) { + store := New() + + chatId := chat.ChatId{1, 2, 3} + metadata := &chat.MetadataRecord{ + Id: 0, + ChatId: chatId, + ChatType: chat.ChatTypeTwoWay, + CreatedAt: time.Now(), + ChatTitle: pointer.String("hello"), + } + + result, err := store.GetChatMetadata(context.Background(), chatId) + require.ErrorIs(t, err, chat.ErrChatNotFound) + require.Nil(t, result) + + require.NoError(t, store.PutChatV2(context.Background(), metadata)) + require.ErrorIs(t, store.PutChatV2(context.Background(), metadata), chat.ErrChatExists) + + result, err = store.GetChatMetadata(context.Background(), chatId) + require.NoError(t, err) + require.Equal(t, metadata.Clone(), result.Clone()) +} + +func TestInMemoryStore_GetAllChatsForUserV2(t *testing.T) { + store := New() + + memberId := chat.MemberId("user123") + + chatIds, err := store.GetAllChatsForUserV2(context.Background(), memberId) + require.NoError(t, err) + require.Empty(t, chatIds) + + var expectedChatIds []chat.ChatId + for i := 0; i < 10; i++ { + chatId := chat.ChatId(bytes.Repeat([]byte{byte(i)}, 32)) + expectedChatIds = append(expectedChatIds, chatId) + + require.NoError(t, store.PutChatV2(context.Background(), &chat.MetadataRecord{ + Id: 0, + ChatId: chatId, + ChatType: chat.ChatTypeTwoWay, + CreatedAt: time.Now(), + })) + + require.NoError(t, store.PutChatMemberV2(context.Background(), &chat.MemberRecord{ + ChatId: chatId, + MemberId: memberId.String(), + Platform: chat.PlatformTwitter, + PlatformId: "user", + JoinedAt: time.Now(), + })) + require.ErrorIs(t, store.PutChatMemberV2(context.Background(), &chat.MemberRecord{ + ChatId: chatId, + MemberId: memberId.String(), + Platform: chat.PlatformTwitter, + PlatformId: "user", + JoinedAt: time.Now(), + }), chat.ErrMemberExists) + } + + chatIds, err = store.GetAllChatsForUserV2(context.Background(), memberId) + require.NoError(t, err) + require.Equal(t, expectedChatIds, chatIds) +} + +func TestInMemoryStore_GetAllChatsForUserV2_Pagination(t *testing.T) { + store := New() + + memberId := chat.MemberId("user123") + + // Create 10 chats + var chatIds []chat.ChatId + for i := 0; i < 10; i++ { + chatId := chat.ChatId(bytes.Repeat([]byte{byte(i)}, 32)) + chatIds = append(chatIds, chatId) + + require.NoError(t, store.PutChatV2(context.Background(), &chat.MetadataRecord{ + ChatId: chatId, + ChatType: chat.ChatTypeTwoWay, + CreatedAt: time.Now(), + })) + + require.NoError(t, store.PutChatMemberV2(context.Background(), &chat.MemberRecord{ + ChatId: chatId, + MemberId: memberId.String(), + Platform: chat.PlatformTwitter, + PlatformId: "user", + JoinedAt: time.Now(), + })) + } + + reversedChatIds := slices.Clone(chatIds) + slices.Reverse(reversedChatIds) + + t.Run("Ascending Order", func(t *testing.T) { + result, err := store.GetAllChatsForUserV2(context.Background(), memberId, query.WithDirection(query.Ascending)) + require.NoError(t, err) + require.Equal(t, chatIds, result) + }) + + t.Run("Descending Order", func(t *testing.T) { + result, err := store.GetAllChatsForUserV2(context.Background(), memberId, query.WithDirection(query.Descending)) + require.NoError(t, err) + require.Equal(t, reversedChatIds, result) + }) + + t.Run("With Cursor", func(t *testing.T) { + cursor := chatIds[3][:] + result, err := store.GetAllChatsForUserV2(context.Background(), memberId, query.WithDirection(query.Ascending), query.WithCursor(cursor)) + require.NoError(t, err) + require.Equal(t, chatIds[4:], result) + }) + + t.Run("With Cursor (Descending)", func(t *testing.T) { + cursor := reversedChatIds[6][:] + result, err := store.GetAllChatsForUserV2(context.Background(), memberId, query.WithDirection(query.Descending), query.WithCursor(cursor)) + require.NoError(t, err) + require.Equal(t, reversedChatIds[7:], result) + }) + + t.Run("With Limit", func(t *testing.T) { + result, err := store.GetAllChatsForUserV2(context.Background(), memberId, query.WithLimit(5)) + require.NoError(t, err) + require.Equal(t, chatIds[:5], result) + }) + + t.Run("With Limit (Descending)", func(t *testing.T) { + cursor := reversedChatIds[4][:] + result, err := store.GetAllChatsForUserV2(context.Background(), memberId, query.WithDirection(query.Descending), query.WithCursor(cursor), query.WithLimit(3)) + require.NoError(t, err) + require.Equal(t, reversedChatIds[5:8], result) + }) +} + +func TestInMemoryStore_GetChatMessageV2(t *testing.T) { + store := New() + + chatId := chat.ChatId{1, 2, 3} + messageId := chat.GenerateMessageId() + message := &chat.MessageRecord{ + ChatId: chatId, + MessageId: messageId, + Payload: []byte("payload"), + } + + err := store.PutChatMessageV2(context.Background(), message) + require.NoError(t, err) + + result, err := store.GetChatMessageV2(context.Background(), chatId, messageId) + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, bytes.Equal(result.MessageId[:], messageId[:])) +} + +// TODO: Need proper pagination tests +func TestInMemoryStore_GetAllChatMessagesV2(t *testing.T) { + store := New() + + chatId := chat.ChatId{1, 2, 3} + + var expectedMessages []*chat.MessageRecord + for i := 0; i < 10; i++ { + message := &chat.MessageRecord{ + ChatId: chatId, + MessageId: chat.GenerateMessageId(), + Payload: []byte(fmt.Sprintf("payload-%d", i)), + } + expectedMessages = append(expectedMessages, message) + + // TODO: We might need a way to address this longer term. + time.Sleep(time.Millisecond) + + require.NoError(t, store.PutChatMessageV2(context.Background(), message)) + require.ErrorIs(t, store.PutChatMessageV2(context.Background(), message), chat.ErrMessageExists) + } + + isSorted := slices.IsSortedFunc(expectedMessages, func(a, b *chat.MessageRecord) int { + return bytes.Compare(a.MessageId[:], b.MessageId[:]) + }) + require.True(t, isSorted) + + messages, err := store.GetAllChatMessagesV2(context.Background(), chatId) + require.NoError(t, err) + require.Equal(t, len(expectedMessages), len(messages)) + + for i := 0; i < len(messages); i++ { + require.Equal(t, expectedMessages[i].ChatId, messages[i].ChatId) + require.Equal(t, expectedMessages[i].MessageId, messages[i].MessageId) + require.Equal(t, expectedMessages[i].Sender, messages[i].Sender) + require.Equal(t, expectedMessages[i].Payload, messages[i].Payload) + require.Equal(t, expectedMessages[i].IsSilent, messages[i].IsSilent) + } +} + +func TestInMemoryStore_GetAllChatMessagesV2_Pagination(t *testing.T) { + store := New() + + chatId := chat.ChatId{1, 2, 3} + + var expectedMessages []*chat.MessageRecord + for i := 0; i < 10; i++ { + message := &chat.MessageRecord{ + ChatId: chatId, + MessageId: chat.GenerateMessageId(), + Payload: []byte(fmt.Sprintf("payload-%d", i)), + } + expectedMessages = append(expectedMessages, message) + time.Sleep(time.Millisecond) + require.NoError(t, store.PutChatMessageV2(context.Background(), message)) + } + + reversedMessages := slices.Clone(expectedMessages) + slices.Reverse(reversedMessages) + + t.Run("Ascending order", func(t *testing.T) { + messages, err := store.GetAllChatMessagesV2(context.Background(), chatId, query.WithDirection(query.Ascending)) + require.NoError(t, err) + require.Equal(t, expectedMessages, messages) + }) + + t.Run("Descending order", func(t *testing.T) { + messages, err := store.GetAllChatMessagesV2(context.Background(), chatId, query.WithDirection(query.Descending)) + require.NoError(t, err) + require.Equal(t, reversedMessages, messages) + }) + + t.Run("With limit", func(t *testing.T) { + limit := uint64(5) + messages, err := store.GetAllChatMessagesV2(context.Background(), chatId, query.WithDirection(query.Ascending), query.WithLimit(limit)) + require.NoError(t, err) + require.Equal(t, expectedMessages[:limit], messages) + }) + + t.Run("With cursor", func(t *testing.T) { + cursor := expectedMessages[3].MessageId[:] + messages, err := store.GetAllChatMessagesV2(context.Background(), chatId, query.WithDirection(query.Ascending), query.WithCursor(cursor)) + require.NoError(t, err) + require.Equal(t, expectedMessages[4:], messages) + }) + + t.Run("With cursor and limit", func(t *testing.T) { + cursor := reversedMessages[3].MessageId[:] + limit := uint64(3) + messages, err := store.GetAllChatMessagesV2(context.Background(), chatId, query.WithDirection(query.Descending), query.WithCursor(cursor), query.WithLimit(limit)) + require.NoError(t, err) + require.Equal(t, reversedMessages[4:7], messages) + }) +} + +// TODO: Need proper pagination tests +func TestInMemoryStore_GetChatMembersV2(t *testing.T) { + store := New() + + chatId := chat.ChatId{1, 2, 3} + + var expectedMembers []*chat.MemberRecord + for i := 0; i < 10; i++ { + member := &chat.MemberRecord{ + ChatId: chatId, + MemberId: fmt.Sprintf("user%d", i), + Owner: fmt.Sprintf("owner%d", i), + Platform: chat.PlatformTwitter, + PlatformId: fmt.Sprintf("twitter%d", i), + IsMuted: true, + JoinedAt: time.Now(), + } + + dPtr := chat.GenerateMessageId() + time.Sleep(time.Millisecond) + rPtr := chat.GenerateMessageId() + + member.DeliveryPointer = &dPtr + member.ReadPointer = &rPtr + + expectedMembers = append(expectedMembers, member) + + require.NoError(t, store.PutChatMemberV2(context.Background(), member)) + require.ErrorIs(t, store.PutChatMemberV2(context.Background(), member), chat.ErrMemberExists) + } + + members, err := store.GetChatMembersV2(context.Background(), chatId) + require.NoError(t, err) + require.Equal(t, expectedMembers, members) +} + +func TestInMemoryStore_IsChatMember(t *testing.T) { + store := New() + + chatId := chat.ChatId{1, 2, 3} + memberId := chat.MemberId("user123") + + isMember, err := store.IsChatMember(context.Background(), chatId, memberId) + require.NoError(t, err) + require.False(t, isMember) + + require.NoError(t, store.PutChatMemberV2(context.Background(), &chat.MemberRecord{ + ChatId: chatId, + MemberId: memberId.String(), + Platform: chat.PlatformTwitter, + PlatformId: "user", + JoinedAt: time.Now(), + })) + + isMember, err = store.IsChatMember(context.Background(), chatId, memberId) + require.NoError(t, err) + require.True(t, isMember) +} + +func TestInMemoryStore_PutChatV2(t *testing.T) { + store := New() + + for i, expected := range []*chat.MetadataRecord{ + { + ChatType: chat.ChatTypeTwoWay, + CreatedAt: time.Now(), + }, + { + ChatType: chat.ChatTypeTwoWay, + CreatedAt: time.Now(), + ChatTitle: pointer.String("hello"), + }, + } { + expected.ChatId = chat.ChatId(bytes.Repeat([]byte{byte(i)}, 32)) + + require.NoError(t, store.PutChatV2(context.Background(), expected)) + + other := expected.Clone() + other.ChatTitle = pointer.String("mutated") + require.ErrorIs(t, store.PutChatV2(context.Background(), &other), chat.ErrChatExists) + + actual, err := store.GetChatMetadata(context.Background(), expected.ChatId) + require.NoError(t, err) + require.Equal(t, expected, actual) + } + + for _, invalid := range []*chat.MetadataRecord{ + {}, + { + ChatId: chat.ChatId{1, 2, 3}, + }, + { + ChatId: chat.ChatId{1, 2, 3}, + CreatedAt: time.Now(), + }, + { + ChatId: chat.ChatId{1, 2, 3}, + ChatType: chat.ChatTypeTwoWay, + }, + } { + require.Error(t, store.PutChatV2(context.Background(), invalid)) + } +} + +func TestInMemoryStore_SetChatMuteStateV2(t *testing.T) { + store := New() + + chatId := chat.ChatId{1, 2, 3} + memberId := chat.MemberId("user123") + + require.NoError(t, store.PutChatMemberV2(context.Background(), &chat.MemberRecord{ + ChatId: chatId, + Platform: chat.PlatformTwitter, + PlatformId: "user", + MemberId: memberId.String(), + JoinedAt: time.Now(), + })) + + members, err := store.GetChatMembersV2(context.Background(), chatId) + require.NoError(t, err) + require.False(t, members[0].IsMuted) + + require.NoError(t, store.SetChatMuteStateV2(context.Background(), chatId, memberId, true)) + + members, err = store.GetChatMembersV2(context.Background(), chatId) + require.NoError(t, err) + require.True(t, members[0].IsMuted) +} + +func TestInMemoryStore_GetChatUnreadCountV2(t *testing.T) { + store := New() + + // Create multiple chats + chats := []chat.ChatId{ + {1, 2, 3}, + {4, 5, 6}, + {7, 8, 9}, + } + counts := []int{0, 0, 0} + + ourMemberId := chat.MemberId("our_user") + otherMemberId := chat.MemberId("other_user") + + for chatIdx, chatId := range chats { + // Add members to the chat + require.NoError(t, store.PutChatMemberV2(context.Background(), &chat.MemberRecord{ + ChatId: chatId, + Platform: chat.PlatformTwitter, + PlatformId: "our_user", + MemberId: ourMemberId.String(), + JoinedAt: time.Now(), + })) + require.NoError(t, store.PutChatMemberV2(context.Background(), &chat.MemberRecord{ + ChatId: chatId, + Platform: chat.PlatformTwitter, + PlatformId: "other_user", + MemberId: otherMemberId.String(), + JoinedAt: time.Now(), + })) + + // Generate N messages for each chat + N := 10 + for i := 0; i < N; i++ { + sender := ourMemberId + if rand.IntN(100) < 50 { // Approximately 50% chance for a message to be from the other user + sender = otherMemberId + counts[chatIdx]++ + } + + require.NoError(t, store.PutChatMessageV2(context.Background(), &chat.MessageRecord{ + ChatId: chatId, + MessageId: chat.GenerateMessageId(), + Sender: &sender, + Payload: []byte(fmt.Sprintf("Message %d for chat %v", i, chatId)), + })) + + time.Sleep(time.Millisecond) + } + } + + // Verify that each chat has a distinct unread count + for chatIdx, chatId := range chats { + ptr := chat.GenerateMessageIdAtTime(time.Now().Add(-time.Hour)) + count, err := store.GetChatUnreadCountV2(context.Background(), chatId, ourMemberId, &ptr) + require.NoError(t, err) + require.EqualValues(t, counts[chatIdx], count) + + if count == 0 { + continue + } + + messages, err := store.GetAllChatMessagesV2(context.Background(), chatId) + require.NoError(t, err) + + var offset *chat.MessageId + for _, message := range messages { + if message.Sender != nil && !bytes.Equal(*message.Sender, ourMemberId) { + offset = &message.MessageId + break + } + } + require.NotNil(t, offset) + + newCount, err := store.GetChatUnreadCountV2(context.Background(), chatId, ourMemberId, offset) + require.NoError(t, err) + require.Equal(t, count-1, newCount) + } +} + +func TestInMemoryStore_AdvanceChatPointerV2(t *testing.T) { + store := New() + + chatId := chat.ChatId{1, 2, 3} + memberId := chat.MemberId("user123") + + // Create a chat and add a member + metadata := &chat.MetadataRecord{ + Id: 0, + ChatId: chatId, + ChatType: chat.ChatTypeTwoWay, + CreatedAt: time.Now(), + } + require.NoError(t, store.PutChatV2(context.Background(), metadata)) + + member := &chat.MemberRecord{ + ChatId: chatId, + MemberId: memberId.String(), + Platform: chat.PlatformTwitter, + PlatformId: "user", + JoinedAt: time.Now(), + + DeliveryPointer: nil, + ReadPointer: nil, + } + require.NoError(t, store.PutChatMemberV2(context.Background(), member)) + + // Test advancing delivery pointer + message1 := chat.GenerateMessageId() + advanced, err := store.AdvanceChatPointerV2(context.Background(), chatId, memberId, chat.PointerTypeDelivered, message1) + require.NoError(t, err) + require.True(t, advanced) + + // Test advancing read pointer + message2 := chat.GenerateMessageId() + advanced, err = store.AdvanceChatPointerV2(context.Background(), chatId, memberId, chat.PointerTypeRead, message2) + require.NoError(t, err) + require.True(t, advanced) + + // Test advancing to an earlier message (should not advance) + advanced, err = store.AdvanceChatPointerV2(context.Background(), chatId, memberId, chat.PointerTypeDelivered, message1) + require.NoError(t, err) + require.False(t, advanced) + + // Test with invalid pointer type + _, err = store.AdvanceChatPointerV2(context.Background(), chatId, memberId, chat.PointerType(8), message2) + require.ErrorIs(t, err, chat.ErrInvalidPointerType) + + // Test with non-existent chat + nonExistentChatId := chat.ChatId{4, 5, 6} + _, err = store.AdvanceChatPointerV2(context.Background(), nonExistentChatId, memberId, chat.PointerTypeDelivered, message2) + require.ErrorIs(t, err, chat.ErrMemberNotFound) + + // Test with non-existent member + nonExistentMemberId := chat.MemberId("nonexistent") + _, err = store.AdvanceChatPointerV2(context.Background(), chatId, nonExistentMemberId, chat.PointerTypeDelivered, message2) + require.ErrorIs(t, err, chat.ErrMemberNotFound) +} diff --git a/pkg/code/data/chat/v2/model.go b/pkg/code/data/chat/v2/model.go new file mode 100644 index 00000000..dd80ff2a --- /dev/null +++ b/pkg/code/data/chat/v2/model.go @@ -0,0 +1,417 @@ +package chat_v2 + +import ( + "fmt" + "time" + + "github.com/pkg/errors" + + "github.com/code-payments/code-server/pkg/pointer" + + chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v2" +) + +type ChatType uint8 + +const ( + ChatTypeUnknown ChatType = iota + ChatTypeTwoWay + // ChatTypeGroup +) + +// GetChatTypeFromProto gets a chat type from the protobuf variant +func GetChatTypeFromProto(proto chatpb.ChatType) ChatType { + switch proto { + case chatpb.ChatType_TWO_WAY: + return ChatTypeTwoWay + default: + return ChatTypeUnknown + } +} + +// ToProto returns the proto representation of the chat type +func (c ChatType) ToProto() chatpb.ChatType { + switch c { + case ChatTypeTwoWay: + return chatpb.ChatType_TWO_WAY + default: + return chatpb.ChatType_UNKNOWN_CHAT_TYPE + } +} + +// String returns the string representation of the chat type +func (c ChatType) String() string { + switch c { + case ChatTypeTwoWay: + return "two-way" + default: + return "unknown" + } +} + +type PointerType uint8 + +const ( + PointerTypeUnknown PointerType = iota + PointerTypeSent + PointerTypeDelivered + PointerTypeRead +) + +// GetPointerTypeFromProto gets a chat ID from the protobuf variant +func GetPointerTypeFromProto(proto chatpb.PointerType) PointerType { + switch proto { + case chatpb.PointerType_SENT: + return PointerTypeSent + case chatpb.PointerType_DELIVERED: + return PointerTypeDelivered + case chatpb.PointerType_READ: + return PointerTypeRead + default: + return PointerTypeUnknown + } +} + +// ToProto returns the proto representation of the pointer type +func (p PointerType) ToProto() chatpb.PointerType { + switch p { + case PointerTypeSent: + return chatpb.PointerType_SENT + case PointerTypeDelivered: + return chatpb.PointerType_DELIVERED + case PointerTypeRead: + return chatpb.PointerType_READ + default: + return chatpb.PointerType_UNKNOWN_POINTER_TYPE + } +} + +// String returns the string representation of the pointer type +func (p PointerType) String() string { + switch p { + case PointerTypeSent: + return "sent" + case PointerTypeDelivered: + return "delivered" + case PointerTypeRead: + return "read" + default: + return "unknown" + } +} + +type Platform uint8 + +const ( + PlatformUnknown Platform = iota + PlatformTwitter +) + +// GetPlatformFromProto returns the proto representation of the platform +func GetPlatformFromProto(proto chatpb.Platform) Platform { + switch proto { + case chatpb.Platform_TWITTER: + return PlatformTwitter + default: + return PlatformUnknown + } +} + +// ToProto returns the proto representation of the platform +func (p Platform) ToProto() chatpb.Platform { + switch p { + case PlatformTwitter: + return chatpb.Platform_TWITTER + default: + return chatpb.Platform_UNKNOWN_PLATFORM + } +} + +// String returns the string representation of the platform +func (p Platform) String() string { + switch p { + case PlatformTwitter: + return "twitter" + default: + return "unknown" + } +} + +type MetadataRecord struct { + Id int64 + ChatId ChatId + ChatType ChatType + CreatedAt time.Time + + ChatTitle *string +} + +// Validate validates a chat Record +func (r *MetadataRecord) Validate() error { + if err := r.ChatId.Validate(); err != nil { + return errors.Wrap(err, "invalid chat id") + } + + switch r.ChatType { + case ChatTypeTwoWay: + default: + return errors.Errorf("invalid chat type: %d", r.ChatType) + } + + if r.CreatedAt.IsZero() { + return errors.New("creation timestamp is required") + } + + return nil +} + +// Clone clones a chat record +func (r *MetadataRecord) Clone() MetadataRecord { + return MetadataRecord{ + Id: r.Id, + ChatId: r.ChatId, + ChatType: r.ChatType, + CreatedAt: r.CreatedAt, + + ChatTitle: pointer.StringCopy(r.ChatTitle), + } +} + +// CopyTo copies a chat record to the provided destination +func (r *MetadataRecord) CopyTo(dst *MetadataRecord) { + dst.Id = r.Id + dst.ChatId = r.ChatId + dst.ChatType = r.ChatType + dst.CreatedAt = r.CreatedAt + + dst.ChatTitle = pointer.StringCopy(r.ChatTitle) +} + +type MemberRecord struct { + Id int64 + ChatId ChatId + + // MemberId is derived from Owner (using account.ToMessagingAccount) + // + // It is stored to allow indexed lookups when only MemberId is available. + // We must also store Owner so server can lookup proper push tokens. + MemberId string + + // Owner is required to be able to send push notifications. + // + // Currently, it is _optional_, as we don't have a way to reverse lookup. + // However, we _will_ want to make it mandatory. + Owner string + + // Identity. + // + // Currently, assumes single. + Platform Platform + PlatformId string + + DeliveryPointer *MessageId + ReadPointer *MessageId + + IsMuted bool + JoinedAt time.Time +} + +// Validate validates a member Record +func (r *MemberRecord) Validate() error { + if err := r.ChatId.Validate(); err != nil { + return fmt.Errorf("invalid chat id: %w", err) + } + + if len(r.MemberId) == 0 { + return fmt.Errorf("missing member id") + } + + if len(r.PlatformId) == 0 { + return fmt.Errorf("missing platform id") + } + + switch r.Platform { + case PlatformTwitter: + if len(r.PlatformId) > 15 { + return errors.New("platform id must have at most 15 characters") + } + default: + return errors.Errorf("invalid plaftorm: %d", r.Platform) + } + + if r.DeliveryPointer != nil { + if err := r.DeliveryPointer.Validate(); err != nil { + return errors.Wrap(err, "invalid delivery pointer") + } + } + + if r.ReadPointer != nil { + if err := r.ReadPointer.Validate(); err != nil { + return errors.Wrap(err, "invalid read pointer") + } + } + + if r.JoinedAt.IsZero() { + return errors.New("joined timestamp is required") + } + + return nil +} + +// Clone clones a member record +func (r *MemberRecord) Clone() MemberRecord { + var deliveryPointerCopy *MessageId + if r.DeliveryPointer != nil { + cloned := r.DeliveryPointer.Clone() + deliveryPointerCopy = &cloned + } + + var readPointerCopy *MessageId + if r.ReadPointer != nil { + cloned := r.ReadPointer.Clone() + readPointerCopy = &cloned + } + + return MemberRecord{ + Id: r.Id, + ChatId: r.ChatId, + MemberId: r.MemberId, + Owner: r.Owner, + + Platform: r.Platform, + PlatformId: r.PlatformId, + + DeliveryPointer: deliveryPointerCopy, + ReadPointer: readPointerCopy, + + IsMuted: r.IsMuted, + JoinedAt: r.JoinedAt, + } +} + +// CopyTo copies a member record to the provided destination +func (r *MemberRecord) CopyTo(dst *MemberRecord) { + dst.Id = r.Id + dst.ChatId = r.ChatId + dst.Owner = r.Owner + dst.MemberId = r.MemberId + + dst.Platform = r.Platform + dst.PlatformId = r.PlatformId + + if r.DeliveryPointer != nil { + cloned := r.DeliveryPointer.Clone() + dst.DeliveryPointer = &cloned + } + if r.ReadPointer != nil { + cloned := r.ReadPointer.Clone() + dst.ReadPointer = &cloned + } + + dst.IsMuted = r.IsMuted + dst.JoinedAt = r.JoinedAt +} + +type MembersById []*MemberRecord + +func (a MembersById) Len() int { return len(a) } +func (a MembersById) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a MembersById) Less(i, j int) bool { + return a[i].Id < a[j].Id +} + +type MessageRecord struct { + Id int64 + ChatId ChatId + MessageId MessageId + + Sender *MemberId + + Payload []byte + + IsSilent bool + + // Note: No timestamp field, since it's encoded in MessageId + // Note: Maybe a timestamp field, because it's maybe better? +} + +// Validate validates a message Record +func (r *MessageRecord) Validate() error { + if err := r.ChatId.Validate(); err != nil { + return errors.Wrap(err, "invalid chat id") + } + + if err := r.MessageId.Validate(); err != nil { + return errors.Wrap(err, "invalid message id") + } + + if r.Sender != nil { + if err := r.Sender.Validate(); err != nil { + return errors.Wrap(err, "invalid sender id") + } + } + + if len(r.Payload) == 0 { + return errors.New("message payload is required") + } + + return nil +} + +type MessagesByMessageId []*MessageRecord + +func (a MessagesByMessageId) Len() int { return len(a) } +func (a MessagesByMessageId) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a MessagesByMessageId) Less(i, j int) bool { + return a[i].MessageId.Before(a[j].MessageId) +} + +// Clone clones a message record +func (r *MessageRecord) Clone() MessageRecord { + var senderCopy *MemberId + if r.Sender != nil { + cloned := r.Sender.Clone() + senderCopy = &cloned + } + + var payloadCopy []byte + if len(r.Payload) > 0 { + payloadCopy = make([]byte, len(r.Payload)) + copy(payloadCopy, r.Payload) + } + + return MessageRecord{ + Id: r.Id, + ChatId: r.ChatId, + MessageId: r.MessageId, + + Sender: senderCopy, + + Payload: payloadCopy, + + IsSilent: r.IsSilent, + } +} + +// CopyTo copies a message record to the provided destination +func (r *MessageRecord) CopyTo(dst *MessageRecord) { + dst.Id = r.Id + dst.ChatId = r.ChatId + dst.MessageId = r.MessageId + + if r.Sender != nil { + cloned := r.Sender.Clone() + dst.Sender = &cloned + } + + payloadCopy := make([]byte, len(r.Payload)) + copy(payloadCopy, r.Payload) + dst.Payload = payloadCopy + + dst.IsSilent = r.IsSilent +} + +// GetTimestamp gets the timestamp for a message record +func (r *MessageRecord) GetTimestamp() (time.Time, error) { + return r.MessageId.GetTimestamp() +} diff --git a/pkg/code/data/chat/v2/store.go b/pkg/code/data/chat/v2/store.go new file mode 100644 index 00000000..fa4fbb6e --- /dev/null +++ b/pkg/code/data/chat/v2/store.go @@ -0,0 +1,70 @@ +package chat_v2 + +import ( + "context" + "errors" + "github.com/code-payments/code-server/pkg/database/query" +) + +var ( + ErrChatExists = errors.New("chat already exists") + ErrChatNotFound = errors.New("chat not found") + ErrMemberExists = errors.New("chat member already exists") + ErrMemberNotFound = errors.New("chat member not found") + ErrMessageExists = errors.New("chat message already exists") + ErrMessageNotFound = errors.New("chat message not found") + ErrInvalidPointerType = errors.New("invalid pointer type") +) + +type Store interface { + // GetChatMetadata retrieves the metadata record for a specific chat, identified by chatId. + // + // It returns ErrChatNotFound if the chat doesn't exist. + GetChatMetadata(ctx context.Context, chatId ChatId) (*MetadataRecord, error) + + // GetChatMessageV2 retrieves a specific message from a chat, identified by chatId and messageId. + GetChatMessageV2(ctx context.Context, chatId ChatId, messageId MessageId) (*MessageRecord, error) + + // GetAllChatsForUserV2 retrieves all chat IDs that a given user (where user is the messaging address). + GetAllChatsForUserV2(ctx context.Context, user MemberId, opts ...query.Option) ([]ChatId, error) + + // GetAllChatMessagesV2 retrieves all messages for a specific chat, identified by chatId. + GetAllChatMessagesV2(ctx context.Context, chatId ChatId, opts ...query.Option) ([]*MessageRecord, error) + + // GetChatMembersV2 retrieves all members of a specific chat, identified by chatId. + GetChatMembersV2(ctx context.Context, chatId ChatId) ([]*MemberRecord, error) + + // IsChatMember checks if a given member, identified by memberId, is part of a specific chat, identified by chatId. + IsChatMember(ctx context.Context, chatId ChatId, memberId MemberId) (bool, error) + + // PutChatV2 stores or updates the metadata for a specific chat. + // + // ErrChatExists is returned if the chat with the same ID already exists. + PutChatV2(ctx context.Context, record *MetadataRecord) error + + // PutChatMemberV2 stores or updates a member record for a specific chat. + // + // ErrMemberExists is returned if the member already exists. + // Updating should be done with specific DB calls. + PutChatMemberV2(ctx context.Context, record *MemberRecord) error + + // PutChatMessageV2 stores or updates a message record in a specific chat. + // + // ErrMessageExists is returned if the message already exists. + PutChatMessageV2(ctx context.Context, record *MessageRecord) error + + // SetChatMuteStateV2 sets the mute state for a specific chat member, identified by chatId and memberId. + // + // ErrMemberNotFound if the member does not exist. + SetChatMuteStateV2(ctx context.Context, chatId ChatId, memberId MemberId, isMuted bool) error + + // AdvanceChatPointerV2 advances a pointer for a chat member, identified by chatId and memberId. + // + // It returns whether the pointer was advanced. If no member exists, ErrMemberNotFound is returned. + AdvanceChatPointerV2(ctx context.Context, chatId ChatId, memberId MemberId, pointerType PointerType, pointer MessageId) (bool, error) + + // GetChatUnreadCountV2 calculates and returns the unread message count for a specific chat member, + // + // Existence checks are not performed. + GetChatUnreadCountV2(ctx context.Context, chatId ChatId, memberId MemberId, readPointer *MessageId) (uint32, error) +} diff --git a/pkg/code/data/intent/intent.go b/pkg/code/data/intent/intent.go index 70a840cc..6c7f5206 100644 --- a/pkg/code/data/intent/intent.go +++ b/pkg/code/data/intent/intent.go @@ -109,9 +109,13 @@ type SendPrivatePaymentMetadata struct { IsRemoteSend bool IsMicroPayment bool IsTip bool + IsChat bool // Set when IsTip = true TipMetadata *TipMetadata + + // Set when IsChat = true + ChatId string } type TipMetadata struct { @@ -578,8 +582,10 @@ func (m *SendPrivatePaymentMetadata) Clone() SendPrivatePaymentMetadata { IsRemoteSend: m.IsRemoteSend, IsMicroPayment: m.IsMicroPayment, IsTip: m.IsTip, + IsChat: m.IsChat, TipMetadata: tipMetadata, + ChatId: m.ChatId, } } @@ -605,8 +611,10 @@ func (m *SendPrivatePaymentMetadata) CopyTo(dst *SendPrivatePaymentMetadata) { dst.IsRemoteSend = m.IsRemoteSend dst.IsMicroPayment = m.IsMicroPayment dst.IsTip = m.IsTip + dst.IsChat = m.IsChat dst.TipMetadata = tipMetadata + dst.ChatId = m.ChatId } func (m *SendPrivatePaymentMetadata) Validate() error { @@ -650,6 +658,14 @@ func (m *SendPrivatePaymentMetadata) Validate() error { return errors.New("tip metadata can only be set for tips") } + if m.IsChat { + if len(m.ChatId) == 0 { + return errors.New("chat_id required for chat") + } + } else if m.ChatId != "" { + return errors.New("chat_id can only be set for chats") + } + return nil } diff --git a/pkg/code/data/intent/postgres/model.go b/pkg/code/data/intent/postgres/model.go index 77b8e7a0..4aa62649 100644 --- a/pkg/code/data/intent/postgres/model.go +++ b/pkg/code/data/intent/postgres/model.go @@ -49,6 +49,8 @@ type intentModel struct { IsTip bool `db:"is_tip"` TipPlatform sql.NullInt16 `db:"tip_platform"` TippedUsername sql.NullString `db:"tipped_username"` + IsChat bool `db:"is_chat"` + ChatId sql.NullString `db:"chat_id"` RelationshipTo sql.NullString `db:"relationship_to"` InitiatorPhoneNumber sql.NullString `db:"phone_number"` // todo: rename the DB field to initiator_phone_number State uint `db:"state"` @@ -106,6 +108,7 @@ func toIntentModel(obj *intent.Record) (*intentModel, error) { m.IsRemoteSend = obj.SendPrivatePaymentMetadata.IsRemoteSend m.IsMicroPayment = obj.SendPrivatePaymentMetadata.IsMicroPayment m.IsTip = obj.SendPrivatePaymentMetadata.IsTip + m.IsChat = obj.SendPrivatePaymentMetadata.IsChat if m.IsTip { m.TipPlatform = sql.NullInt16{ @@ -117,6 +120,13 @@ func toIntentModel(obj *intent.Record) (*intentModel, error) { String: obj.SendPrivatePaymentMetadata.TipMetadata.Username, } } + + if m.IsChat { + m.ChatId = sql.NullString{ + Valid: true, + String: obj.SendPrivatePaymentMetadata.ChatId, + } + } case intent.ReceivePaymentsPrivately: m.Source = obj.ReceivePaymentsPrivatelyMetadata.Source m.Quantity = obj.ReceivePaymentsPrivatelyMetadata.Quantity @@ -224,6 +234,7 @@ func fromIntentModel(obj *intentModel) *intent.Record { IsRemoteSend: obj.IsRemoteSend, IsMicroPayment: obj.IsMicroPayment, IsTip: obj.IsTip, + IsChat: obj.IsChat, } if record.SendPrivatePaymentMetadata.IsTip { @@ -232,6 +243,11 @@ func fromIntentModel(obj *intentModel) *intent.Record { Username: obj.TippedUsername.String, } } + + if record.SendPrivatePaymentMetadata.IsChat { + record.SendPrivatePaymentMetadata.ChatId = obj.ChatId.String + } + case intent.ReceivePaymentsPrivately: record.ReceivePaymentsPrivatelyMetadata = &intent.ReceivePaymentsPrivatelyMetadata{ Source: obj.Source, @@ -300,16 +316,16 @@ func fromIntentModel(obj *intentModel) *intent.Record { func (m *intentModel) dbSave(ctx context.Context, db *sqlx.DB) error { return pgutil.ExecuteInTx(ctx, db, sql.LevelDefault, func(tx *sqlx.Tx) error { query := `INSERT INTO ` + intentTableName + ` - (intent_id, intent_type, owner, source, destination_owner, destination, quantity, treasury_pool, recent_root, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, is_tip, relationship_to, tip_platform, tipped_username, phone_number, state, created_at) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26) + (intent_id, intent_type, owner, source, destination_owner, destination, quantity, treasury_pool, recent_root, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, is_tip, is_chat, relationship_to, tip_platform, tipped_username, chat_id, phone_number, state, created_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28) ON CONFLICT (intent_id) DO UPDATE - SET state = $25 + SET state = $27 WHERE ` + intentTableName + `.intent_id = $1 RETURNING - id, intent_id, intent_type, source, destination_owner, destination, quantity, treasury_pool, recent_root, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, is_tip, relationship_to, tip_platform, tipped_username, phone_number, state, created_at` + id, intent_id, intent_type, source, destination_owner, destination, quantity, treasury_pool, recent_root, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, is_tip, is_chat, relationship_to, tip_platform, tipped_username, chat_id, phone_number, state, created_at` err := tx.QueryRowxContext( ctx, @@ -334,9 +350,11 @@ func (m *intentModel) dbSave(ctx context.Context, db *sqlx.DB) error { m.IsIssuerVoidingGiftCard, m.IsMicroPayment, m.IsTip, + m.IsChat, m.RelationshipTo, m.TipPlatform, m.TippedUsername, + m.ChatId, m.InitiatorPhoneNumber, m.State, m.CreatedAt, @@ -349,7 +367,7 @@ func (m *intentModel) dbSave(ctx context.Context, db *sqlx.DB) error { func dbGetIntent(ctx context.Context, db *sqlx.DB, intentID string) (*intentModel, error) { res := &intentModel{} - query := `SELECT id, intent_id, intent_type, owner, source, destination_owner, destination, quantity, treasury_pool, recent_root, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, is_tip, relationship_to, tip_platform, tipped_username, phone_number, state, created_at + query := `SELECT id, intent_id, intent_type, owner, source, destination_owner, destination, quantity, treasury_pool, recent_root, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, is_tip, is_chat, relationship_to, tip_platform, tipped_username, chat_id, phone_number, state, created_at FROM ` + intentTableName + ` WHERE intent_id = $1 LIMIT 1` @@ -364,7 +382,7 @@ func dbGetIntent(ctx context.Context, db *sqlx.DB, intentID string) (*intentMode func dbGetLatestByInitiatorAndType(ctx context.Context, db *sqlx.DB, intentType intent.Type, owner string) (*intentModel, error) { res := &intentModel{} - query := `SELECT id, intent_id, intent_type, owner, source, destination_owner, destination, quantity, treasury_pool, recent_root, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, is_tip, relationship_to, tip_platform, tipped_username, phone_number, state, created_at + query := `SELECT id, intent_id, intent_type, owner, source, destination_owner, destination, quantity, treasury_pool, recent_root, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, is_tip, is_chat, relationship_to, tip_platform, tipped_username, chat_id, phone_number, state, created_at FROM ` + intentTableName + ` WHERE owner = $1 AND intent_type = $2 ORDER BY created_at DESC @@ -381,7 +399,7 @@ func dbGetLatestByInitiatorAndType(ctx context.Context, db *sqlx.DB, intentType func dbGetAllByOwner(ctx context.Context, db *sqlx.DB, owner string, cursor q.Cursor, limit uint64, direction q.Ordering) ([]*intentModel, error) { res := []*intentModel{} - query := `SELECT id, intent_id, intent_type, owner, source, destination_owner, destination, quantity, treasury_pool, recent_root, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, is_tip, relationship_to, tip_platform, tipped_username, phone_number, state, created_at + query := `SELECT id, intent_id, intent_type, owner, source, destination_owner, destination, quantity, treasury_pool, recent_root, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, is_tip, is_chat, relationship_to, tip_platform, tipped_username, chat_id, phone_number, state, created_at FROM ` + intentTableName + ` WHERE (owner = $1 OR destination_owner = $1) AND (intent_type != $2 AND intent_type != $3) ` @@ -542,7 +560,7 @@ func dbGetNetBalanceFromPrePrivacy2022Intents(ctx context.Context, db *sqlx.DB, func dbGetLatestSaveRecentRootIntentForTreasury(ctx context.Context, db *sqlx.DB, treasury string) (*intentModel, error) { res := &intentModel{} - query := `SELECT id, intent_id, intent_type, owner, source, destination_owner, destination, quantity, treasury_pool, recent_root, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, is_tip, relationship_to, tip_platform, tipped_username, phone_number, state, created_at + query := `SELECT id, intent_id, intent_type, owner, source, destination_owner, destination, quantity, treasury_pool, recent_root, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, is_tip, is_chat, relationship_to, tip_platform, tipped_username, chat_id, phone_number, state, created_at FROM ` + intentTableName + ` WHERE treasury_pool = $1 and intent_type = $2 ORDER BY id DESC @@ -559,7 +577,7 @@ func dbGetLatestSaveRecentRootIntentForTreasury(ctx context.Context, db *sqlx.DB func dbGetOriginalGiftCardIssuedIntent(ctx context.Context, db *sqlx.DB, giftCardVault string) (*intentModel, error) { res := []*intentModel{} - query := `SELECT id, intent_id, intent_type, owner, source, destination_owner, destination, quantity, treasury_pool, recent_root, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, is_tip, relationship_to, tip_platform, tipped_username, phone_number, state, created_at + query := `SELECT id, intent_id, intent_type, owner, source, destination_owner, destination, quantity, treasury_pool, recent_root, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, is_tip, is_chat, relationship_to, tip_platform, tipped_username, chat_id, phone_number, state, created_at FROM ` + intentTableName + ` WHERE destination = $1 and intent_type = $2 AND state != $3 AND is_remote_send IS TRUE LIMIT 2 @@ -591,7 +609,7 @@ func dbGetOriginalGiftCardIssuedIntent(ctx context.Context, db *sqlx.DB, giftCar func dbGetGiftCardClaimedIntent(ctx context.Context, db *sqlx.DB, giftCardVault string) (*intentModel, error) { res := []*intentModel{} - query := `SELECT id, intent_id, intent_type, owner, source, destination_owner, destination, quantity, treasury_pool, recent_root, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, is_tip, relationship_to, tip_platform, tipped_username, phone_number, state, created_at + query := `SELECT id, intent_id, intent_type, owner, source, destination_owner, destination, quantity, treasury_pool, recent_root, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, is_tip, is_chat, relationship_to, tip_platform, tipped_username, chat_id, phone_number, state, created_at FROM ` + intentTableName + ` WHERE source = $1 and intent_type = $2 AND state != $3 AND is_remote_send IS TRUE LIMIT 2 diff --git a/pkg/code/data/intent/postgres/store_test.go b/pkg/code/data/intent/postgres/store_test.go index fdaae93b..18a7d1c9 100644 --- a/pkg/code/data/intent/postgres/store_test.go +++ b/pkg/code/data/intent/postgres/store_test.go @@ -47,13 +47,15 @@ const ( is_issuer_voiding_gift_card BOOL NOT NULL, is_micro_payment BOOL NOT NULL, is_tip BOOL NOT NULL, + is_chat BOOL NOT NULL, relationship_to TEXT NULL, tip_platform INTEGER NULL, tipped_username TEXT NULL, - phone_number text NULL, + chat_id TEXT NULL, + phone_number TEXT NULL, state integer NOT NULL, diff --git a/pkg/code/data/intent/tests/tests.go b/pkg/code/data/intent/tests/tests.go index c17af4cb..2769b092 100644 --- a/pkg/code/data/intent/tests/tests.go +++ b/pkg/code/data/intent/tests/tests.go @@ -37,6 +37,7 @@ func RunTests(t *testing.T, s intent.Store, teardown func()) { testGetLatestSaveRecentRootIntentForTreasury, testGetOriginalGiftCardIssuedIntent, testGetGiftCardClaimedIntent, + testChatPayment, } { tf(t, s) teardown() @@ -1011,3 +1012,54 @@ func testGetGiftCardClaimedIntent(t *testing.T, s intent.Store) { assert.Equal(t, "i9", actual.IntentId) }) } + +func testChatPayment(t *testing.T, s intent.Store) { + t.Run("testChatPayment", func(t *testing.T) { + record := &intent.Record{ + IntentId: "i1", + IntentType: intent.SendPrivatePayment, + InitiatorOwnerAccount: "init1", + SendPrivatePaymentMetadata: &intent.SendPrivatePaymentMetadata{ + DestinationOwnerAccount: "do", + DestinationTokenAccount: "dt", + Quantity: 1, + ExchangeCurrency: "USD", + ExchangeRate: 1, + NativeAmount: 1, + UsdMarketValue: 1, + IsChat: true, + ChatId: "chatId", + }, + } + require.NoError(t, s.Save(context.Background(), record)) + + saved, err := s.Get(context.Background(), record.IntentId) + require.NoError(t, err) + require.Equal(t, record, saved) + }) + + t.Run("testChatPayment invalid", func(t *testing.T) { + base := &intent.Record{ + IntentId: "i1", + IntentType: intent.SendPrivatePayment, + InitiatorOwnerAccount: "init1", + SendPrivatePaymentMetadata: &intent.SendPrivatePaymentMetadata{ + DestinationOwnerAccount: "do", + DestinationTokenAccount: "dt", + Quantity: 1, + ExchangeCurrency: "USD", + ExchangeRate: 1, + NativeAmount: 1, + UsdMarketValue: 1, + }, + } + + r := base.Clone() + r.SendPrivatePaymentMetadata.IsChat = true + require.Error(t, s.Save(context.Background(), &r)) + + r = base.Clone() + r.SendPrivatePaymentMetadata.ChatId = "chatId" + require.Error(t, s.Save(context.Background(), &r)) + }) +} diff --git a/pkg/code/data/internal.go b/pkg/code/data/internal.go index 49479c84..758e1b5c 100644 --- a/pkg/code/data/internal.go +++ b/pkg/code/data/internal.go @@ -25,7 +25,8 @@ import ( "github.com/code-payments/code-server/pkg/code/data/airdrop" "github.com/code-payments/code-server/pkg/code/data/badgecount" "github.com/code-payments/code-server/pkg/code/data/balance" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" + chat_v2 "github.com/code-payments/code-server/pkg/code/data/chat/v2" "github.com/code-payments/code-server/pkg/code/data/commitment" "github.com/code-payments/code-server/pkg/code/data/contact" "github.com/code-payments/code-server/pkg/code/data/currency" @@ -59,7 +60,8 @@ import ( airdrop_memory_client "github.com/code-payments/code-server/pkg/code/data/airdrop/memory" badgecount_memory_client "github.com/code-payments/code-server/pkg/code/data/badgecount/memory" balance_memory_client "github.com/code-payments/code-server/pkg/code/data/balance/memory" - chat_memory_client "github.com/code-payments/code-server/pkg/code/data/chat/memory" + chat_v1_memory_client "github.com/code-payments/code-server/pkg/code/data/chat/v1/memory" + chat_v2_memory_client "github.com/code-payments/code-server/pkg/code/data/chat/v2/memory" commitment_memory_client "github.com/code-payments/code-server/pkg/code/data/commitment/memory" contact_memory_client "github.com/code-payments/code-server/pkg/code/data/contact/memory" currency_memory_client "github.com/code-payments/code-server/pkg/code/data/currency/memory" @@ -69,7 +71,7 @@ import ( intent_memory_client "github.com/code-payments/code-server/pkg/code/data/intent/memory" login_memory_client "github.com/code-payments/code-server/pkg/code/data/login/memory" merkletree_memory_client "github.com/code-payments/code-server/pkg/code/data/merkletree/memory" - messaging "github.com/code-payments/code-server/pkg/code/data/messaging" + "github.com/code-payments/code-server/pkg/code/data/messaging" messaging_memory_client "github.com/code-payments/code-server/pkg/code/data/messaging/memory" nonce_memory_client "github.com/code-payments/code-server/pkg/code/data/nonce/memory" onramp_memory_client "github.com/code-payments/code-server/pkg/code/data/onramp/memory" @@ -94,7 +96,7 @@ import ( airdrop_postgres_client "github.com/code-payments/code-server/pkg/code/data/airdrop/postgres" badgecount_postgres_client "github.com/code-payments/code-server/pkg/code/data/badgecount/postgres" balance_postgres_client "github.com/code-payments/code-server/pkg/code/data/balance/postgres" - chat_postgres_client "github.com/code-payments/code-server/pkg/code/data/chat/postgres" + chat_v1_postgres_client "github.com/code-payments/code-server/pkg/code/data/chat/v1/postgres" commitment_postgres_client "github.com/code-payments/code-server/pkg/code/data/commitment/postgres" contact_postgres_client "github.com/code-payments/code-server/pkg/code/data/contact/postgres" currency_postgres_client "github.com/code-payments/code-server/pkg/code/data/currency/postgres" @@ -378,19 +380,34 @@ type DatabaseData interface { CountWebhookByState(ctx context.Context, state webhook.State) (uint64, error) GetAllPendingWebhooksReadyToSend(ctx context.Context, limit uint64) ([]*webhook.Record, error) - // Chat + // Chat V1 // -------------------------------------------------------------------------------- - PutChat(ctx context.Context, record *chat.Chat) error - GetChatById(ctx context.Context, chatId chat.ChatId) (*chat.Chat, error) - GetAllChatsForUser(ctx context.Context, user string, opts ...query.Option) ([]*chat.Chat, error) - PutChatMessage(ctx context.Context, record *chat.Message) error - DeleteChatMessage(ctx context.Context, chatId chat.ChatId, messageId string) error - GetChatMessage(ctx context.Context, chatId chat.ChatId, messageId string) (*chat.Message, error) - GetAllChatMessages(ctx context.Context, chatId chat.ChatId, opts ...query.Option) ([]*chat.Message, error) - AdvanceChatPointer(ctx context.Context, chatId chat.ChatId, pointer string) error - GetChatUnreadCount(ctx context.Context, chatId chat.ChatId) (uint32, error) - SetChatMuteState(ctx context.Context, chatId chat.ChatId, isMuted bool) error - SetChatSubscriptionState(ctx context.Context, chatId chat.ChatId, isSubscribed bool) error + PutChatV1(ctx context.Context, record *chat_v1.Chat) error + GetChatByIdV1(ctx context.Context, chatId chat_v1.ChatId) (*chat_v1.Chat, error) + GetAllChatsForUserV1(ctx context.Context, user string, opts ...query.Option) ([]*chat_v1.Chat, error) + PutChatMessageV1(ctx context.Context, record *chat_v1.Message) error + DeleteChatMessageV1(ctx context.Context, chatId chat_v1.ChatId, messageId string) error + GetChatMessageV1(ctx context.Context, chatId chat_v1.ChatId, messageId string) (*chat_v1.Message, error) + GetAllChatMessagesV1(ctx context.Context, chatId chat_v1.ChatId, opts ...query.Option) ([]*chat_v1.Message, error) + AdvanceChatPointerV1(ctx context.Context, chatId chat_v1.ChatId, pointer string) error + GetChatUnreadCountV1(ctx context.Context, chatId chat_v1.ChatId) (uint32, error) + SetChatMuteStateV1(ctx context.Context, chatId chat_v1.ChatId, isMuted bool) error + SetChatSubscriptionStateV1(ctx context.Context, chatId chat_v1.ChatId, isSubscribed bool) error + + // Chat V2 + // -------------------------------------------------------------------------------- + GetChatMetadata(ctx context.Context, chatId chat_v2.ChatId) (*chat_v2.MetadataRecord, error) + GetChatMessageV2(ctx context.Context, chatId chat_v2.ChatId, messageId chat_v2.MessageId) (*chat_v2.MessageRecord, error) + GetAllChatsForUserV2(ctx context.Context, user chat_v2.MemberId, opts ...query.Option) ([]chat_v2.ChatId, error) + GetAllChatMessagesV2(ctx context.Context, chatId chat_v2.ChatId, opts ...query.Option) ([]*chat_v2.MessageRecord, error) + GetChatMembersV2(ctx context.Context, chatId chat_v2.ChatId) ([]*chat_v2.MemberRecord, error) + IsChatMember(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId) (bool, error) + PutChatV2(ctx context.Context, record *chat_v2.MetadataRecord) error + PutChatMemberV2(ctx context.Context, record *chat_v2.MemberRecord) error + PutChatMessageV2(ctx context.Context, record *chat_v2.MessageRecord) error + SetChatMuteStateV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, isMuted bool) error + AdvanceChatPointerV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, pointerType chat_v2.PointerType, pointer chat_v2.MessageId) (bool, error) + GetChatUnreadCountV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, readPointer *chat_v2.MessageId) (uint32, error) // Badge Count // -------------------------------------------------------------------------------- @@ -470,7 +487,8 @@ type DatabaseProvider struct { paywall paywall.Store event event.Store webhook webhook.Store - chat chat.Store + chatv1 chat_v1.Store + chatv2 chat_v2.Store badgecount badgecount.Store login login.Store balance balance.Store @@ -532,7 +550,8 @@ func NewDatabaseProvider(dbConfig *pg.Config) (DatabaseData, error) { paywall: paywall_postgres_client.New(db), event: event_postgres_client.New(db), webhook: webhook_postgres_client.New(db), - chat: chat_postgres_client.New(db), + chatv1: chat_v1_postgres_client.New(db), + chatv2: chat_v2_memory_client.New(), // todo: Postgres version for production after PoC badgecount: badgecount_postgres_client.New(db), login: login_postgres_client.New(db), balance: balance_postgres_client.New(db), @@ -575,7 +594,8 @@ func NewTestDatabaseProvider() DatabaseData { paywall: paywall_memory_client.New(), event: event_memory_client.New(), webhook: webhook_memory_client.New(), - chat: chat_memory_client.New(), + chatv1: chat_v1_memory_client.New(), + chatv2: chat_v2_memory_client.New(), badgecount: badgecount_memory_client.New(), login: login_memory_client.New(), balance: balance_memory_client.New(), @@ -1399,48 +1419,87 @@ func (dp *DatabaseProvider) GetAllPendingWebhooksReadyToSend(ctx context.Context return dp.webhook.GetAllPendingReadyToSend(ctx, limit) } -// Chat +// Chat V1 // -------------------------------------------------------------------------------- -func (dp *DatabaseProvider) PutChat(ctx context.Context, record *chat.Chat) error { - return dp.chat.PutChat(ctx, record) +func (dp *DatabaseProvider) PutChatV1(ctx context.Context, record *chat_v1.Chat) error { + return dp.chatv1.PutChat(ctx, record) } -func (dp *DatabaseProvider) GetChatById(ctx context.Context, chatId chat.ChatId) (*chat.Chat, error) { - return dp.chat.GetChatById(ctx, chatId) +func (dp *DatabaseProvider) GetChatByIdV1(ctx context.Context, chatId chat_v1.ChatId) (*chat_v1.Chat, error) { + return dp.chatv1.GetChatById(ctx, chatId) } -func (dp *DatabaseProvider) GetAllChatsForUser(ctx context.Context, user string, opts ...query.Option) ([]*chat.Chat, error) { +func (dp *DatabaseProvider) GetAllChatsForUserV1(ctx context.Context, user string, opts ...query.Option) ([]*chat_v1.Chat, error) { req, err := query.DefaultPaginationHandler(opts...) if err != nil { return nil, err } - return dp.chat.GetAllChatsForUser(ctx, user, req.Cursor, req.SortBy, req.Limit) + return dp.chatv1.GetAllChatsForUser(ctx, user, req.Cursor, req.SortBy, req.Limit) } -func (dp *DatabaseProvider) PutChatMessage(ctx context.Context, record *chat.Message) error { - return dp.chat.PutMessage(ctx, record) +func (dp *DatabaseProvider) PutChatMessageV1(ctx context.Context, record *chat_v1.Message) error { + return dp.chatv1.PutMessage(ctx, record) } -func (dp *DatabaseProvider) DeleteChatMessage(ctx context.Context, chatId chat.ChatId, messageId string) error { - return dp.chat.DeleteMessage(ctx, chatId, messageId) +func (dp *DatabaseProvider) DeleteChatMessageV1(ctx context.Context, chatId chat_v1.ChatId, messageId string) error { + return dp.chatv1.DeleteMessage(ctx, chatId, messageId) } -func (dp *DatabaseProvider) GetChatMessage(ctx context.Context, chatId chat.ChatId, messageId string) (*chat.Message, error) { - return dp.chat.GetMessageById(ctx, chatId, messageId) +func (dp *DatabaseProvider) GetChatMessageV1(ctx context.Context, chatId chat_v1.ChatId, messageId string) (*chat_v1.Message, error) { + return dp.chatv1.GetMessageById(ctx, chatId, messageId) } -func (dp *DatabaseProvider) GetAllChatMessages(ctx context.Context, chatId chat.ChatId, opts ...query.Option) ([]*chat.Message, error) { +func (dp *DatabaseProvider) GetAllChatMessagesV1(ctx context.Context, chatId chat_v1.ChatId, opts ...query.Option) ([]*chat_v1.Message, error) { req, err := query.DefaultPaginationHandler(opts...) if err != nil { return nil, err } - return dp.chat.GetAllMessagesByChat(ctx, chatId, req.Cursor, req.SortBy, req.Limit) + return dp.chatv1.GetAllMessagesByChat(ctx, chatId, req.Cursor, req.SortBy, req.Limit) +} +func (dp *DatabaseProvider) AdvanceChatPointerV1(ctx context.Context, chatId chat_v1.ChatId, pointer string) error { + return dp.chatv1.AdvancePointer(ctx, chatId, pointer) +} +func (dp *DatabaseProvider) GetChatUnreadCountV1(ctx context.Context, chatId chat_v1.ChatId) (uint32, error) { + return dp.chatv1.GetUnreadCount(ctx, chatId) +} +func (dp *DatabaseProvider) SetChatMuteStateV1(ctx context.Context, chatId chat_v1.ChatId, isMuted bool) error { + return dp.chatv1.SetMuteState(ctx, chatId, isMuted) +} +func (dp *DatabaseProvider) SetChatSubscriptionStateV1(ctx context.Context, chatId chat_v1.ChatId, isSubscribed bool) error { + return dp.chatv1.SetSubscriptionState(ctx, chatId, isSubscribed) +} + +// Chat V2 +// -------------------------------------------------------------------------------- +func (dp *DatabaseProvider) GetChatMetadata(ctx context.Context, chatId chat_v2.ChatId) (*chat_v2.MetadataRecord, error) { + return dp.chatv2.GetChatMetadata(ctx, chatId) +} +func (dp *DatabaseProvider) GetChatMessageV2(ctx context.Context, chatId chat_v2.ChatId, messageId chat_v2.MessageId) (*chat_v2.MessageRecord, error) { + return dp.chatv2.GetChatMessageV2(ctx, chatId, messageId) +} +func (dp *DatabaseProvider) GetAllChatsForUserV2(ctx context.Context, user chat_v2.MemberId, opts ...query.Option) ([]chat_v2.ChatId, error) { + return dp.chatv2.GetAllChatsForUserV2(ctx, user, opts...) +} +func (dp *DatabaseProvider) GetAllChatMessagesV2(ctx context.Context, chatId chat_v2.ChatId, opts ...query.Option) ([]*chat_v2.MessageRecord, error) { + return dp.chatv2.GetAllChatMessagesV2(ctx, chatId, opts...) +} +func (dp *DatabaseProvider) GetChatMembersV2(ctx context.Context, chatId chat_v2.ChatId) ([]*chat_v2.MemberRecord, error) { + return dp.chatv2.GetChatMembersV2(ctx, chatId) +} +func (dp *DatabaseProvider) IsChatMember(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId) (bool, error) { + return dp.chatv2.IsChatMember(ctx, chatId, memberId) +} +func (dp *DatabaseProvider) PutChatV2(ctx context.Context, record *chat_v2.MetadataRecord) error { + return dp.chatv2.PutChatV2(ctx, record) +} +func (dp *DatabaseProvider) PutChatMemberV2(ctx context.Context, record *chat_v2.MemberRecord) error { + return dp.chatv2.PutChatMemberV2(ctx, record) } -func (dp *DatabaseProvider) AdvanceChatPointer(ctx context.Context, chatId chat.ChatId, pointer string) error { - return dp.chat.AdvancePointer(ctx, chatId, pointer) +func (dp *DatabaseProvider) PutChatMessageV2(ctx context.Context, record *chat_v2.MessageRecord) error { + return dp.chatv2.PutChatMessageV2(ctx, record) } -func (dp *DatabaseProvider) GetChatUnreadCount(ctx context.Context, chatId chat.ChatId) (uint32, error) { - return dp.chat.GetUnreadCount(ctx, chatId) +func (dp *DatabaseProvider) SetChatMuteStateV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, isMuted bool) error { + return dp.chatv2.SetChatMuteStateV2(ctx, chatId, memberId, isMuted) } -func (dp *DatabaseProvider) SetChatMuteState(ctx context.Context, chatId chat.ChatId, isMuted bool) error { - return dp.chat.SetMuteState(ctx, chatId, isMuted) +func (dp *DatabaseProvider) AdvanceChatPointerV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, pointerType chat_v2.PointerType, pointer chat_v2.MessageId) (bool, error) { + return dp.chatv2.AdvanceChatPointerV2(ctx, chatId, memberId, pointerType, pointer) } -func (dp *DatabaseProvider) SetChatSubscriptionState(ctx context.Context, chatId chat.ChatId, isSubscribed bool) error { - return dp.chat.SetSubscriptionState(ctx, chatId, isSubscribed) +func (dp *DatabaseProvider) GetChatUnreadCountV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, readPointer *chat_v2.MessageId) (uint32, error) { + return dp.chatv2.GetChatUnreadCountV2(ctx, chatId, memberId, readPointer) } // Badge Count diff --git a/pkg/code/localization/keys.go b/pkg/code/localization/keys.go index edd557df..926acc64 100644 --- a/pkg/code/localization/keys.go +++ b/pkg/code/localization/keys.go @@ -43,6 +43,7 @@ const ( ChatTitleKinPurchases = "title.chat.kinPurchases" ChatTitlePayments = "title.chat.payments" ChatTitleTips = "title.chat.tips" + ChatTitleTwoWay = "title.chat.twoWay" // Message Bodies diff --git a/pkg/code/push/notifications.go b/pkg/code/push/notifications.go index ccc361fb..e7f5ba33 100644 --- a/pkg/code/push/notifications.go +++ b/pkg/code/push/notifications.go @@ -9,12 +9,14 @@ import ( "google.golang.org/protobuf/proto" chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v1" + chatv2pb "github.com/code-payments/code-protobuf-api/generated/go/chat/v2" commonpb "github.com/code-payments/code-protobuf-api/generated/go/common/v1" chat_util "github.com/code-payments/code-server/pkg/code/chat" "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" + chat_v2 "github.com/code-payments/code-server/pkg/code/data/chat/v2" "github.com/code-payments/code-server/pkg/code/localization" "github.com/code-payments/code-server/pkg/code/thirdparty" currency_lib "github.com/code-payments/code-server/pkg/currency" @@ -59,13 +61,13 @@ func SendDepositPushNotification( // Legacy push notification still considers chat mute state // // todo: Proper migration to chat system - chatRecord, err := data.GetChatById(ctx, chat.GetChatId(chat_util.CashTransactionsName, owner.PublicKey().ToBase58(), true)) + chatRecord, err := data.GetChatByIdV1(ctx, chat_v1.GetChatId(chat_util.CashTransactionsName, owner.PublicKey().ToBase58(), true)) switch err { case nil: if chatRecord.IsMuted { return nil } - case chat.ErrChatNotFound: + case chat_v1.ErrChatNotFound: default: log.WithError(err).Warn("failure getting chat record") return errors.Wrap(err, "error getting chat record") @@ -139,13 +141,13 @@ func SendGiftCardReturnedPushNotification( // Legacy push notification still considers chat mute state // // todo: Proper migration to chat system - chatRecord, err := data.GetChatById(ctx, chat.GetChatId(chat_util.CashTransactionsName, owner.PublicKey().ToBase58(), true)) + chatRecord, err := data.GetChatByIdV1(ctx, chat_v1.GetChatId(chat_util.CashTransactionsName, owner.PublicKey().ToBase58(), true)) switch err { case nil: if chatRecord.IsMuted { return nil } - case chat.ErrChatNotFound: + case chat_v1.ErrChatNotFound: default: log.WithError(err).Warn("failure getting chat record") return errors.Wrap(err, "error getting chat record") @@ -320,15 +322,15 @@ func SendChatMessagePushNotification( for _, content := range chatMessage.Content { var contentToPush *chatpb.Content switch typedContent := content.Type.(type) { - case *chatpb.Content_Localized: - localizedPushBody, err := localization.Localize(locale, typedContent.Localized.KeyOrText) + case *chatpb.Content_ServerLocalized: + localizedPushBody, err := localization.Localize(locale, typedContent.ServerLocalized.KeyOrText) if err != nil { continue } contentToPush = &chatpb.Content{ - Type: &chatpb.Content_Localized{ - Localized: &chatpb.LocalizedContent{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ KeyOrText: localizedPushBody, }, }, @@ -358,13 +360,157 @@ func SendChatMessagePushNotification( } contentToPush = &chatpb.Content{ - Type: &chatpb.Content_Localized{ - Localized: &chatpb.LocalizedContent{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ + KeyOrText: localizedPushBody, + }, + }, + } + case *chatpb.Content_NaclBox, *chatpb.Content_Text: + contentToPush = content + case *chatpb.Content_ThankYou: + contentToPush = &chatpb.Content{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ + // todo: localize this + KeyOrText: "🙏 They thanked you for their tip", + }, + }, + } + } + + if contentToPush == nil { + continue + } + + marshalledContent, err := proto.Marshal(contentToPush) + if err != nil { + log.WithError(err).Warn("failure marshalling chat content") + return err + } + + kvs := map[string]string{ + "chat_title": localizedPushTitle, + "message_content": base64.StdEncoding.EncodeToString(marshalledContent), + } + + err = sendMutableNotificationToOwner( + ctx, + data, + pusher, + owner, + chatMessageDataPush, + chatTitle, + kvs, + ) + if err != nil { + anyErrorPushingContent = true + log.WithError(err).Warn("failure sending data push notification") + } + } + + if anyErrorPushingContent { + return errors.New("at least one piece of content failed to push") + } + return nil +} + +func SendChatMessagePushNotificationV2( + ctx context.Context, + data code_data.Provider, + pusher push_lib.Provider, + chatId chat_v2.ChatId, + chatTitle string, + owner *common.Account, + chatMessage *chatv2pb.Message, +) error { + log := logrus.StandardLogger().WithFields(logrus.Fields{ + "method": "SendChatMessagePushNotificationV2", + "owner": owner.PublicKey().ToBase58(), + "chat": chatTitle, + }) + + // Best-effort try to update the badge count before pushing message content + // + // Note: Only chat messages generate badge counts + err := UpdateBadgeCount(ctx, data, pusher, owner) + if err != nil { + log.WithError(err).Warn("failure updating badge count on device") + } + + locale, err := data.GetUserLocale(ctx, owner.PublicKey().ToBase58()) + if err != nil { + log.WithError(err).Warn("failure getting user locale") + return err + } + + var localizedPushTitle string + + chatProperties, ok := chat_util.InternalChatProperties[chatTitle] + if ok { + localized, err := localization.Localize(locale, chatProperties.TitleLocalizationKey) + if err != nil { + return nil + } + localizedPushTitle = localized + } else { + domainDisplayName, err := thirdparty.GetDomainDisplayName(chatTitle) + if err == nil { + localizedPushTitle = domainDisplayName + } else { + return nil + } + } + + var anyErrorPushingContent bool + for _, content := range chatMessage.Content { + var contentToPush *chatv2pb.Content + switch typedContent := content.Type.(type) { + case *chatv2pb.Content_Localized: + localizedPushBody, err := localization.Localize(locale, typedContent.Localized.KeyOrText) + if err != nil { + continue + } + + contentToPush = &chatv2pb.Content{ + Type: &chatv2pb.Content_Localized{ + Localized: &chatv2pb.LocalizedContent{ KeyOrText: localizedPushBody, }, }, } - case *chatpb.Content_NaclBox: + case *chatv2pb.Content_ExchangeData: + var currencyCode currency_lib.Code + var nativeAmount float64 + if typedContent.ExchangeData.GetExact() != nil { + exchangeData := typedContent.ExchangeData.GetExact() + currencyCode = currency_lib.Code(exchangeData.Currency) + nativeAmount = exchangeData.NativeAmount + } else { + exchangeData := typedContent.ExchangeData.GetPartial() + currencyCode = currency_lib.Code(exchangeData.Currency) + nativeAmount = exchangeData.NativeAmount + } + + localizedPushBody, err := localization.LocalizeFiatWithVerb( + locale, + chatpb.ExchangeDataContent_Verb(typedContent.ExchangeData.Verb), + currencyCode, + nativeAmount, + true, + ) + if err != nil { + continue + } + + contentToPush = &chatv2pb.Content{ + Type: &chatv2pb.Content_Localized{ + Localized: &chatv2pb.LocalizedContent{ + KeyOrText: localizedPushBody, + }, + }, + } + case *chatv2pb.Content_NaclBox, *chatv2pb.Content_Text: contentToPush = content } @@ -380,6 +526,7 @@ func SendChatMessagePushNotification( kvs := map[string]string{ "chat_title": localizedPushTitle, + "chat_id": chatId.String(), "message_content": base64.StdEncoding.EncodeToString(marshalledContent), } @@ -401,5 +548,6 @@ func SendChatMessagePushNotification( if anyErrorPushingContent { return errors.New("at least one piece of content failed to push") } + return nil } diff --git a/pkg/code/server/grpc/chat/stream.go b/pkg/code/server/grpc/chat/stream.go new file mode 100644 index 00000000..9d79969b --- /dev/null +++ b/pkg/code/server/grpc/chat/stream.go @@ -0,0 +1,121 @@ +package chat + +import ( + "context" + "sync" + "time" + + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + + chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v1" +) + +const ( + // todo: configurable + streamBufferSize = 64 + streamPingDelay = 5 * time.Second + streamKeepAliveRecvTimeout = 10 * time.Second + streamNotifyTimeout = 10 * time.Second +) + +type chatEventStream struct { + sync.Mutex + + closed bool + streamCh chan *chatpb.ChatStreamEvent +} + +func newChatEventStream(bufferSize int) *chatEventStream { + return &chatEventStream{ + streamCh: make(chan *chatpb.ChatStreamEvent, bufferSize), + } +} + +func (s *chatEventStream) notify(event *chatpb.ChatStreamEvent, timeout time.Duration) error { + m := proto.Clone(event).(*chatpb.ChatStreamEvent) + + s.Lock() + + if s.closed { + s.Unlock() + return errors.New("cannot notify closed stream") + } + + select { + case s.streamCh <- m: + case <-time.After(timeout): + s.Unlock() + s.close() + return errors.New("timed out sending message to streamCh") + } + + s.Unlock() + return nil +} + +func (s *chatEventStream) close() { + s.Lock() + defer s.Unlock() + + if s.closed { + return + } + + s.closed = true + close(s.streamCh) +} + +func boundedStreamChatEventsRecv( + ctx context.Context, + streamer chatpb.Chat_StreamChatEventsServer, + timeout time.Duration, +) (req *chatpb.StreamChatEventsRequest, err error) { + done := make(chan struct{}) + go func() { + req, err = streamer.Recv() + close(done) + }() + + select { + case <-done: + return req, err + case <-ctx.Done(): + return nil, status.Error(codes.Canceled, "") + case <-time.After(timeout): + return nil, status.Error(codes.DeadlineExceeded, "timed out receiving message") + } +} + +// Very naive implementation to start +func monitorChatEventStreamHealth( + ctx context.Context, + log *logrus.Entry, + ssRef string, + streamer chatpb.Chat_StreamChatEventsServer, +) <-chan struct{} { + streamHealthChan := make(chan struct{}) + go func() { + defer close(streamHealthChan) + + for { + // todo: configurable timeout + req, err := boundedStreamChatEventsRecv(ctx, streamer, streamKeepAliveRecvTimeout) + if err != nil { + return + } + + switch req.Type.(type) { + case *chatpb.StreamChatEventsRequest_Pong: + log.Tracef("received pong from client (stream=%s)", ssRef) + default: + // Client sent something unexpected. Terminate the stream + return + } + } + }() + return streamHealthChan +} diff --git a/pkg/code/server/grpc/chat/server.go b/pkg/code/server/grpc/chat/v1/server.go similarity index 60% rename from pkg/code/server/grpc/chat/server.go rename to pkg/code/server/grpc/chat/v1/server.go index 362c0fd7..5fef09f9 100644 --- a/pkg/code/server/grpc/chat/server.go +++ b/pkg/code/server/grpc/chat/v1/server.go @@ -1,14 +1,20 @@ -package chat +package chat_v1 import ( + "bytes" "context" + "fmt" "math" + "strings" + "sync" + "time" "github.com/mr-tron/base58" "github.com/sirupsen/logrus" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/timestamppb" chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v1" @@ -18,30 +24,55 @@ import ( chat_util "github.com/code-payments/code-server/pkg/code/chat" "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/code/localization" + push_util "github.com/code-payments/code-server/pkg/code/push" "github.com/code-payments/code-server/pkg/database/query" "github.com/code-payments/code-server/pkg/grpc/client" + push_lib "github.com/code-payments/code-server/pkg/push" + sync_util "github.com/code-payments/code-server/pkg/sync" ) const ( maxPageSize = 100 ) +var ( + mockTwoWayChat = chat.GetChatId("user1", "user2", true).ToProto() +) + +// todo: Resolve duplication of streaming logic with messaging service. The latest and greatest will live here. type server struct { - log *logrus.Entry - data code_data.Provider - auth *auth_util.RPCSignatureVerifier + log *logrus.Entry + data code_data.Provider + auth *auth_util.RPCSignatureVerifier + pusher push_lib.Provider + + streamsMu sync.RWMutex + streams map[string]*chatEventStream + + chatLocks *sync_util.StripedLock + chatEventChans *sync_util.StripedChannel chatpb.UnimplementedChatServer } -func NewChatServer(data code_data.Provider, auth *auth_util.RPCSignatureVerifier) chatpb.ChatServer { - return &server{ - log: logrus.StandardLogger().WithField("type", "chat/server"), - data: data, - auth: auth, +func NewChatServer(data code_data.Provider, auth *auth_util.RPCSignatureVerifier, pusher push_lib.Provider) chatpb.ChatServer { + s := &server{ + log: logrus.StandardLogger().WithField("type", "chat/v1/server"), + data: data, + auth: auth, + pusher: pusher, + streams: make(map[string]*chatEventStream), + chatLocks: sync_util.NewStripedLock(64), // todo: configurable parameters + chatEventChans: sync_util.NewStripedChannel(64, 100_000), // todo: configurable parameters } + + for i, channel := range s.chatEventChans.GetChannels() { + go s.asyncChatEventStreamNotifier(i, channel) + } + + return s } func (s *server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*chatpb.GetChatsResponse, error) { @@ -88,7 +119,7 @@ func (s *server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*ch } } - chatRecords, err := s.data.GetAllChatsForUser( + chatRecords, err := s.data.GetAllChatsForUserV1( ctx, owner.PublicKey().ToBase58(), query.WithCursor(cursor), @@ -147,7 +178,7 @@ func (s *server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*ch } protoMetadata.Title = &chatpb.ChatMetadata_Localized{ - Localized: &chatpb.LocalizedContent{ + Localized: &chatpb.ServerLocalizedContent{ KeyOrText: localization.LocalizeWithFallback( locale, localization.GetLocalizationKeyForUserAgent(ctx, chatProperties.TitleLocalizationKey), @@ -189,7 +220,7 @@ func (s *server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*ch if !skipUnreadCountQuery && !chatRecord.IsMuted && !chatRecord.IsUnsubscribed { // todo: will need batching when users have a large number of chats - unreadCount, err := s.data.GetChatUnreadCount(ctx, chatRecord.ChatId) + unreadCount, err := s.data.GetChatUnreadCountV1(ctx, chatRecord.ChatId) if err != nil { log.WithError(err).Warn("failure getting unread count") } @@ -229,7 +260,7 @@ func (s *server) GetMessages(ctx context.Context, req *chatpb.GetMessagesRequest return nil, err } - chatRecord, err := s.data.GetChatById(ctx, chatId) + chatRecord, err := s.data.GetChatByIdV1(ctx, chatId) if err == chat.ErrChatNotFound { return &chatpb.GetMessagesResponse{ Result: chatpb.GetMessagesResponse_NOT_FOUND, @@ -265,7 +296,7 @@ func (s *server) GetMessages(ctx context.Context, req *chatpb.GetMessagesRequest cursor = req.Cursor.Value } - messageRecords, err := s.data.GetAllChatMessages( + messageRecords, err := s.data.GetAllChatMessagesV1( ctx, chatId, query.WithCursor(cursor), @@ -298,11 +329,11 @@ func (s *server) GetMessages(ctx context.Context, req *chatpb.GetMessagesRequest for _, content := range protoChatMessage.Content { switch typed := content.Type.(type) { - case *chatpb.Content_Localized: - typed.Localized.KeyOrText = localization.LocalizeWithFallback( + case *chatpb.Content_ServerLocalized: + typed.ServerLocalized.KeyOrText = localization.LocalizeWithFallback( locale, - localization.GetLocalizationKeyForUserAgent(ctx, typed.Localized.KeyOrText), - typed.Localized.KeyOrText, + localization.GetLocalizationKeyForUserAgent(ctx, typed.ServerLocalized.KeyOrText), + typed.ServerLocalized.KeyOrText, ) } } @@ -347,26 +378,45 @@ func (s *server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerR } log = log.WithField("owner_account", owner.PublicKey().ToBase58()) - chatId := chat.ChatIdFromProto(req.ChatId) - log = log.WithField("chat_id", chatId.String()) + signature := req.Signature + req.Signature = nil + if err := s.auth.Authenticate(ctx, owner, req, signature); err != nil { + return nil, err + } + chatId := chat.ChatIdFromProto(req.ChatId) messageId := base58.Encode(req.Pointer.Value.Value) log = log.WithFields(logrus.Fields{ + "chat_id": chatId.String(), "message_id": messageId, "pointer_type": req.Pointer.Kind, }) - if req.Pointer.Kind != chatpb.Pointer_READ { - return nil, status.Error(codes.InvalidArgument, "Pointer.Kind must be READ") + // todo: Temporary code to simluate real-time + if req.Pointer.User != nil { + return nil, status.Error(codes.InvalidArgument, "pointer.user cannot be set by clients") } + if bytes.Equal(mockTwoWayChat.Value, req.ChatId.Value) { + req.Pointer.User = &chatpb.ChatMemberId{Value: req.Owner.Value} - signature := req.Signature - req.Signature = nil - if err := s.auth.Authenticate(ctx, owner, req, signature); err != nil { - return nil, err + event := &chatpb.ChatStreamEvent{ + Pointers: []*chatpb.Pointer{req.Pointer}, + } + + if err := s.asyncNotifyAll(chatId, owner, event); err != nil { + log.WithError(err).Warn("failure notifying chat event") + } + + return &chatpb.AdvancePointerResponse{ + Result: chatpb.AdvancePointerResponse_OK, + }, nil + } + + if req.Pointer.Kind != chatpb.Pointer_READ { + return nil, status.Error(codes.InvalidArgument, "Pointer.Kind must be READ") } - chatRecord, err := s.data.GetChatById(ctx, chatId) + chatRecord, err := s.data.GetChatByIdV1(ctx, chatId) if err == chat.ErrChatNotFound { return &chatpb.AdvancePointerResponse{ Result: chatpb.AdvancePointerResponse_CHAT_NOT_FOUND, @@ -380,7 +430,7 @@ func (s *server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerR return nil, status.Error(codes.PermissionDenied, "") } - newPointerRecord, err := s.data.GetChatMessage(ctx, chatId, messageId) + newPointerRecord, err := s.data.GetChatMessageV1(ctx, chatId, messageId) if err == chat.ErrMessageNotFound { return &chatpb.AdvancePointerResponse{ Result: chatpb.AdvancePointerResponse_MESSAGE_NOT_FOUND, @@ -391,7 +441,7 @@ func (s *server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerR } if chatRecord.ReadPointer != nil { - oldPointerRecord, err := s.data.GetChatMessage(ctx, chatId, *chatRecord.ReadPointer) + oldPointerRecord, err := s.data.GetChatMessageV1(ctx, chatId, *chatRecord.ReadPointer) if err != nil { log.WithError(err).Warn("failure getting chat message record for old pointer value") return nil, status.Error(codes.Internal, "") @@ -404,7 +454,7 @@ func (s *server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerR } } - err = s.data.AdvanceChatPointer(ctx, chatId, messageId) + err = s.data.AdvanceChatPointerV1(ctx, chatId, messageId) if err != nil { log.WithError(err).Warn("failure advancing pointer") return nil, status.Error(codes.Internal, "") @@ -434,7 +484,7 @@ func (s *server) SetMuteState(ctx context.Context, req *chatpb.SetMuteStateReque return nil, err } - chatRecord, err := s.data.GetChatById(ctx, chatId) + chatRecord, err := s.data.GetChatByIdV1(ctx, chatId) if err == chat.ErrChatNotFound { return &chatpb.SetMuteStateResponse{ Result: chatpb.SetMuteStateResponse_CHAT_NOT_FOUND, @@ -461,7 +511,7 @@ func (s *server) SetMuteState(ctx context.Context, req *chatpb.SetMuteStateReque }, nil } - err = s.data.SetChatMuteState(ctx, chatId, req.IsMuted) + err = s.data.SetChatMuteStateV1(ctx, chatId, req.IsMuted) if err != nil { log.WithError(err).Warn("failure setting mute status") return nil, status.Error(codes.Internal, "") @@ -492,7 +542,7 @@ func (s *server) SetSubscriptionState(ctx context.Context, req *chatpb.SetSubscr return nil, err } - chatRecord, err := s.data.GetChatById(ctx, chatId) + chatRecord, err := s.data.GetChatByIdV1(ctx, chatId) if err == chat.ErrChatNotFound { return &chatpb.SetSubscriptionStateResponse{ Result: chatpb.SetSubscriptionStateResponse_CHAT_NOT_FOUND, @@ -519,7 +569,7 @@ func (s *server) SetSubscriptionState(ctx context.Context, req *chatpb.SetSubscr }, nil } - err = s.data.SetChatSubscriptionState(ctx, chatId, req.IsSubscribed) + err = s.data.SetChatSubscriptionStateV1(ctx, chatId, req.IsSubscribed) if err != nil { log.WithError(err).Warn("failure setting subcription status") return nil, status.Error(codes.Internal, "") @@ -529,3 +579,232 @@ func (s *server) SetSubscriptionState(ctx context.Context, req *chatpb.SetSubscr Result: chatpb.SetSubscriptionStateResponse_OK, }, nil } + +// +// Experimental PoC two-way chat APIs below +// + +func (s *server) StreamChatEvents(streamer chatpb.Chat_StreamChatEventsServer) error { + ctx := streamer.Context() + + log := s.log.WithField("method", "StreamChatEvents") + log = client.InjectLoggingMetadata(ctx, log) + + req, err := boundedStreamChatEventsRecv(ctx, streamer, 250*time.Millisecond) + if err != nil { + return err + } + + if req.GetOpenStream() == nil { + return status.Error(codes.InvalidArgument, "open_stream is nil") + } + + if req.GetOpenStream().Signature == nil { + return status.Error(codes.InvalidArgument, "signature is nil") + } + + if !bytes.Equal(req.GetOpenStream().ChatId.Value, mockTwoWayChat.Value) { + return status.Error(codes.Unimplemented, "") + } + chatId := chat.ChatIdFromProto(req.GetOpenStream().ChatId) + log = log.WithField("chat_id", chatId.String()) + + owner, err := common.NewAccountFromProto(req.GetOpenStream().Owner) + if err != nil { + log.WithError(err).Warn("invalid owner account") + return status.Error(codes.Internal, "") + } + log = log.WithField("owner", owner.PublicKey().ToBase58()) + + signature := req.GetOpenStream().Signature + req.GetOpenStream().Signature = nil + if err = s.auth.Authenticate(streamer.Context(), owner, req.GetOpenStream(), signature); err != nil { + return err + } + + streamKey := fmt.Sprintf("%s:%s", chatId.String(), owner.PublicKey().ToBase58()) + + s.streamsMu.Lock() + + stream, exists := s.streams[streamKey] + if exists { + s.streamsMu.Unlock() + // There's an existing stream on this server that must be terminated first. + // Warn to see how often this happens in practice + log.Warnf("existing stream detected on this server (stream=%p) ; aborting", stream) + return status.Error(codes.Aborted, "stream already exists") + } + + stream = newChatEventStream(streamBufferSize) + + // The race detector complains when reading the stream pointer ref outside of the lock. + streamRef := fmt.Sprintf("%p", stream) + log.Tracef("setting up new stream (stream=%s)", streamRef) + s.streams[streamKey] = stream + + s.streamsMu.Unlock() + + defer func() { + s.streamsMu.Lock() + + log.Tracef("closing streamer (stream=%s)", streamRef) + + // We check to see if the current active stream is the one that we created. + // If it is, we can just remove it since it's closed. Otherwise, we leave it + // be, as another OpenMessageStream() call is handling it. + liveStream, exists := s.streams[streamKey] + if exists && liveStream == stream { + delete(s.streams, streamKey) + } + + s.streamsMu.Unlock() + }() + + sendPingCh := time.After(0) + streamHealthCh := monitorChatEventStreamHealth(ctx, log, streamRef, streamer) + + for { + select { + case event, ok := <-stream.streamCh: + if !ok { + log.Tracef("stream closed ; ending stream (stream=%s)", streamRef) + return status.Error(codes.Aborted, "stream closed") + } + + err := streamer.Send(&chatpb.StreamChatEventsResponse{ + Type: &chatpb.StreamChatEventsResponse_Events{ + Events: &chatpb.ChatStreamEventBatch{ + Events: []*chatpb.ChatStreamEvent{event}, + }, + }, + }) + if err != nil { + log.WithError(err).Info("failed to forward chat message") + return err + } + case <-sendPingCh: + log.Tracef("sending ping to client (stream=%s)", streamRef) + + sendPingCh = time.After(streamPingDelay) + + err := streamer.Send(&chatpb.StreamChatEventsResponse{ + Type: &chatpb.StreamChatEventsResponse_Ping{ + Ping: &commonpb.ServerPing{ + Timestamp: timestamppb.Now(), + PingDelay: durationpb.New(streamPingDelay), + }, + }, + }) + if err != nil { + log.Tracef("stream is unhealthy ; aborting (stream=%s)", streamRef) + return status.Error(codes.Aborted, "terminating unhealthy stream") + } + case <-streamHealthCh: + log.Tracef("stream is unhealthy ; aborting (stream=%s)", streamRef) + return status.Error(codes.Aborted, "terminating unhealthy stream") + case <-ctx.Done(): + log.Tracef("stream context cancelled ; ending stream (stream=%s)", streamRef) + return status.Error(codes.Canceled, "") + } + } +} + +func (s *server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest) (*chatpb.SendMessageResponse, error) { + log := s.log.WithField("method", "SendMessage") + log = client.InjectLoggingMetadata(ctx, log) + + if !bytes.Equal(req.ChatId.Value, mockTwoWayChat.Value) { + return nil, status.Error(codes.Unimplemented, "") + } + chatId := chat.ChatIdFromProto(req.ChatId) + log = log.WithField("chat_id", chatId.String()) + + owner, err := common.NewAccountFromProto(req.Owner) + if err != nil { + log.WithError(err).Warn("invalid owner account") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("owner", owner.PublicKey().ToBase58()) + + signature := req.Signature + req.Signature = nil + if err = s.auth.Authenticate(ctx, owner, req, signature); err != nil { + return nil, err + } + + switch req.Content[0].Type.(type) { + case *chatpb.Content_Text, *chatpb.Content_ThankYou: + default: + return nil, status.Error(codes.InvalidArgument, "content[0] must be Text or ThankYou") + } + + chatLock := s.chatLocks.Get(chatId[:]) + chatLock.Lock() + defer chatLock.Unlock() + + // todo: Revisit message IDs + messageId, err := common.NewRandomAccount() + if err != nil { + log.WithError(err).Warn("failure generating random message id") + return nil, status.Error(codes.Internal, "") + } + + chatMessage := &chatpb.ChatMessage{ + MessageId: &chatpb.ChatMessageId{Value: messageId.ToProto().Value}, + Ts: timestamppb.Now(), + Content: req.Content, + Sender: &chatpb.ChatMemberId{Value: req.Owner.Value}, + Cursor: nil, // todo: Don't have cursor until we save it to the DB + } + + // todo: Save the message to the DB + + event := &chatpb.ChatStreamEvent{ + Messages: []*chatpb.ChatMessage{chatMessage}, + } + + if err := s.asyncNotifyAll(chatId, owner, event); err != nil { + log.WithError(err).Warn("failure notifying chat event") + } + + s.asyncPushChatMessage(owner, chatId, chatMessage) + + return &chatpb.SendMessageResponse{ + Result: chatpb.SendMessageResponse_OK, + Message: chatMessage, + }, nil +} + +// todo: doesn't respect mute/unsubscribe rules +// todo: only sends pushes to active stream listeners instead of all message recipients +func (s *server) asyncPushChatMessage(sender *common.Account, chatId chat.ChatId, chatMessage *chatpb.ChatMessage) { + ctx := context.TODO() + + go func() { + s.streamsMu.RLock() + for key := range s.streams { + if !strings.HasPrefix(key, chatId.String()) { + continue + } + + receiver, err := common.NewAccountFromPublicKeyString(strings.Split(key, ":")[1]) + if err != nil { + continue + } + + if bytes.Equal(sender.PublicKey().ToBytes(), receiver.PublicKey().ToBytes()) { + continue + } + + go push_util.SendChatMessagePushNotification( + ctx, + s.data, + s.pusher, + "TontonTwitch", + receiver, + chatMessage, + ) + } + s.streamsMu.RUnlock() + }() +} diff --git a/pkg/code/server/grpc/chat/server_test.go b/pkg/code/server/grpc/chat/v1/server_test.go similarity index 96% rename from pkg/code/server/grpc/chat/server_test.go rename to pkg/code/server/grpc/chat/v1/server_test.go index f6dd6eff..fe9eaf2b 100644 --- a/pkg/code/server/grpc/chat/server_test.go +++ b/pkg/code/server/grpc/chat/v1/server_test.go @@ -1,4 +1,4 @@ -package chat +package chat_v1 import ( "context" @@ -22,13 +22,14 @@ import ( chat_util "github.com/code-payments/code-server/pkg/code/chat" "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/code/data/phone" "github.com/code-payments/code-server/pkg/code/data/preferences" "github.com/code-payments/code-server/pkg/code/data/user" "github.com/code-payments/code-server/pkg/code/data/user/storage" "github.com/code-payments/code-server/pkg/code/localization" "github.com/code-payments/code-server/pkg/kin" + memory_push "github.com/code-payments/code-server/pkg/push/memory" "github.com/code-payments/code-server/pkg/testutil" ) @@ -102,8 +103,8 @@ func TestGetChatsAndMessages_HappyPath(t *testing.T) { Ts: timestamppb.Now(), Content: []*chatpb.Content{ { - Type: &chatpb.Content_Localized{ - Localized: &chatpb.LocalizedContent{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ KeyOrText: "msg.body.key", }, }, @@ -242,7 +243,7 @@ func TestGetChatsAndMessages_HappyPath(t *testing.T) { require.Len(t, getMessagesResp.Messages, 1) assert.Equal(t, expectedCodeTeamMessage.MessageId.Value, getMessagesResp.Messages[0].Cursor.Value) getMessagesResp.Messages[0].Cursor = nil - expectedCodeTeamMessage.Content[0].GetLocalized().KeyOrText = "localized message body content" + expectedCodeTeamMessage.Content[0].GetServerLocalized().KeyOrText = "localized message body content" assert.True(t, proto.Equal(expectedCodeTeamMessage, getMessagesResp.Messages[0])) getMessagesResp, err = env.client.GetMessages(env.ctx, getCashTransactionsMessagesReq) @@ -288,8 +289,8 @@ func TestChatHistoryReadState_HappyPath(t *testing.T) { Ts: timestamppb.Now(), Content: []*chatpb.Content{ { - Type: &chatpb.Content_Localized{ - Localized: &chatpb.LocalizedContent{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ KeyOrText: fmt.Sprintf("msg.body.key%d", i), }, }, @@ -346,8 +347,8 @@ func TestChatHistoryReadState_NegativeProgress(t *testing.T) { Ts: timestamppb.Now(), Content: []*chatpb.Content{ { - Type: &chatpb.Content_Localized{ - Localized: &chatpb.LocalizedContent{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ KeyOrText: fmt.Sprintf("msg.body.key%d", i), }, }, @@ -429,8 +430,8 @@ func TestChatHistoryReadState_MessageNotFound(t *testing.T) { Ts: timestamppb.Now(), Content: []*chatpb.Content{ { - Type: &chatpb.Content_Localized{ - Localized: &chatpb.LocalizedContent{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ KeyOrText: "msg.body.key", }, }, @@ -743,8 +744,8 @@ func TestUnauthorizedAccess(t *testing.T) { Ts: timestamppb.Now(), Content: []*chatpb.Content{ { - Type: &chatpb.Content_Localized{ - Localized: &chatpb.LocalizedContent{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ KeyOrText: "msg.body.key", }, }, @@ -880,7 +881,7 @@ func setup(t *testing.T) (env *testEnv, cleanup func()) { data: code_data.NewTestDataProvider(), } - s := NewChatServer(env.data, auth_util.NewRPCSignatureVerifier(env.data)) + s := NewChatServer(env.data, auth_util.NewRPCSignatureVerifier(env.data), memory_push.NewPushProvider()) env.server = s.(*server) serv.RegisterService(func(server *grpc.Server) { @@ -893,12 +894,12 @@ func setup(t *testing.T) (env *testEnv, cleanup func()) { } func (e *testEnv) sendExternalAppChatMessage(t *testing.T, msg *chatpb.ChatMessage, domain string, isVerified bool, recipient *common.Account) { - _, err := chat_util.SendChatMessage(e.ctx, e.data, domain, chat.ChatTypeExternalApp, isVerified, recipient, msg, false) + _, err := chat_util.SendNotificationChatMessageV1(e.ctx, e.data, domain, chat.ChatTypeExternalApp, isVerified, recipient, msg, false) require.NoError(t, err) } func (e *testEnv) sendInternalChatMessage(t *testing.T, msg *chatpb.ChatMessage, chatTitle string, recipient *common.Account) { - _, err := chat_util.SendChatMessage(e.ctx, e.data, chatTitle, chat.ChatTypeInternal, true, recipient, msg, false) + _, err := chat_util.SendNotificationChatMessageV1(e.ctx, e.data, chatTitle, chat.ChatTypeInternal, true, recipient, msg, false) require.NoError(t, err) } diff --git a/pkg/code/server/grpc/chat/v1/stream.go b/pkg/code/server/grpc/chat/v1/stream.go new file mode 100644 index 00000000..05b63235 --- /dev/null +++ b/pkg/code/server/grpc/chat/v1/stream.go @@ -0,0 +1,177 @@ +package chat_v1 + +import ( + "context" + "strings" + "sync" + "time" + + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + + chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v1" + "github.com/code-payments/code-server/pkg/code/common" + chat "github.com/code-payments/code-server/pkg/code/data/chat/v1" +) + +const ( + // todo: configurable + streamBufferSize = 64 + streamPingDelay = 5 * time.Second + streamKeepAliveRecvTimeout = 10 * time.Second + streamNotifyTimeout = time.Second +) + +type chatEventStream struct { + sync.Mutex + + closed bool + streamCh chan *chatpb.ChatStreamEvent +} + +func newChatEventStream(bufferSize int) *chatEventStream { + return &chatEventStream{ + streamCh: make(chan *chatpb.ChatStreamEvent, bufferSize), + } +} + +func (s *chatEventStream) notify(event *chatpb.ChatStreamEvent, timeout time.Duration) error { + m := proto.Clone(event).(*chatpb.ChatStreamEvent) + + s.Lock() + + if s.closed { + s.Unlock() + return errors.New("cannot notify closed stream") + } + + select { + case s.streamCh <- m: + case <-time.After(timeout): + s.Unlock() + s.close() + return errors.New("timed out sending message to streamCh") + } + + s.Unlock() + return nil +} + +func (s *chatEventStream) close() { + s.Lock() + defer s.Unlock() + + if s.closed { + return + } + + s.closed = true + close(s.streamCh) +} + +func boundedStreamChatEventsRecv( + ctx context.Context, + streamer chatpb.Chat_StreamChatEventsServer, + timeout time.Duration, +) (req *chatpb.StreamChatEventsRequest, err error) { + done := make(chan struct{}) + go func() { + req, err = streamer.Recv() + close(done) + }() + + select { + case <-done: + return req, err + case <-ctx.Done(): + return nil, status.Error(codes.Canceled, "") + case <-time.After(timeout): + return nil, status.Error(codes.DeadlineExceeded, "timed out receiving message") + } +} + +type chatEventNotification struct { + chatId chat.ChatId + owner *common.Account + event *chatpb.ChatStreamEvent + ts time.Time +} + +func (s *server) asyncNotifyAll(chatId chat.ChatId, owner *common.Account, event *chatpb.ChatStreamEvent) error { + m := proto.Clone(event).(*chatpb.ChatStreamEvent) + ok := s.chatEventChans.Send(chatId[:], &chatEventNotification{chatId, owner, m, time.Now()}) + if !ok { + return errors.New("chat event channel is full") + } + return nil +} + +func (s *server) asyncChatEventStreamNotifier(workerId int, channel <-chan interface{}) { + log := s.log.WithFields(logrus.Fields{ + "method": "asyncChatEventStreamNotifier", + "worker": workerId, + }) + + for value := range channel { + typedValue, ok := value.(*chatEventNotification) + if !ok { + log.Warn("channel did not receive expected struct") + continue + } + + log := log.WithField("chat_id", typedValue.chatId.String()) + + if time.Since(typedValue.ts) > time.Second { + log.Warn("channel notification latency is elevated") + } + + s.streamsMu.RLock() + for key, stream := range s.streams { + if !strings.HasPrefix(key, typedValue.chatId.String()) { + continue + } + + if strings.HasSuffix(key, typedValue.owner.PublicKey().ToBase58()) { + continue + } + + if err := stream.notify(typedValue.event, streamNotifyTimeout); err != nil { + log.WithError(err).Warnf("failed to notify session stream, closing streamer (stream=%p)", stream) + } + } + s.streamsMu.RUnlock() + } +} + +// Very naive implementation to start +func monitorChatEventStreamHealth( + ctx context.Context, + log *logrus.Entry, + ssRef string, + streamer chatpb.Chat_StreamChatEventsServer, +) <-chan struct{} { + streamHealthChan := make(chan struct{}) + go func() { + defer close(streamHealthChan) + + for { + // todo: configurable timeout + req, err := boundedStreamChatEventsRecv(ctx, streamer, streamKeepAliveRecvTimeout) + if err != nil { + return + } + + switch req.Type.(type) { + case *chatpb.StreamChatEventsRequest_Pong: + log.Tracef("received pong from client (stream=%s)", ssRef) + default: + // Client sent something unexpected. Terminate the stream + return + } + } + }() + return streamHealthChan +} diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go new file mode 100644 index 00000000..a339b24d --- /dev/null +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -0,0 +1,991 @@ +package chat_v2 + +import ( + "bytes" + "context" + "database/sql" + "fmt" + "github.com/code-payments/code-server/pkg/code/data/account" + "github.com/code-payments/code-server/pkg/pointer" + "sync" + "time" + + "github.com/mr-tron/base58" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "golang.org/x/sync/errgroup" + "golang.org/x/text/language" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/timestamppb" + + chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v2" + auth_util "github.com/code-payments/code-server/pkg/code/auth" + "github.com/code-payments/code-server/pkg/code/common" + code_data "github.com/code-payments/code-server/pkg/code/data" + chat "github.com/code-payments/code-server/pkg/code/data/chat/v2" + "github.com/code-payments/code-server/pkg/code/data/intent" + "github.com/code-payments/code-server/pkg/code/data/twitter" + "github.com/code-payments/code-server/pkg/code/localization" + push_util "github.com/code-payments/code-server/pkg/code/push" + "github.com/code-payments/code-server/pkg/database/query" + "github.com/code-payments/code-server/pkg/grpc/client" + "github.com/code-payments/code-server/pkg/push" + sync_util "github.com/code-payments/code-server/pkg/sync" +) + +// todo: resolve some common code for sending chat messages across RPCs + +const ( + maxGetChatsPageSize = 100 + maxGetMessagesPageSize = 100 + flushMessageCount = 100 +) + +type Server struct { + log *logrus.Entry + + data code_data.Provider + auth *auth_util.RPCSignatureVerifier + push push.Provider + + streamsMu sync.RWMutex + streams map[string]eventStream + + chatLocks *sync_util.StripedLock + chatEventChans *sync_util.StripedChannel + + chatpb.UnimplementedChatServer +} + +func NewChatServer( + data code_data.Provider, + auth *auth_util.RPCSignatureVerifier, + push push.Provider, +) *Server { + s := &Server{ + log: logrus.StandardLogger().WithField("type", "chat/v2/Server"), + + data: data, + auth: auth, + push: push, + + streams: make(map[string]eventStream), + + chatLocks: sync_util.NewStripedLock(64), // todo: configurable parameters + chatEventChans: sync_util.NewStripedChannel(64, 100_000), // todo: configurable parameters + } + + for i, channel := range s.chatEventChans.GetChannels() { + go s.asyncChatEventStreamNotifier(i, channel) + } + + return s +} + +func (s *Server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*chatpb.GetChatsResponse, error) { + // todo: This will require a lot of optimizations since we iterate and make several DB calls for each chat membership + log := s.log.WithField("method", "GetChats") + log = client.InjectLoggingMetadata(ctx, log) + + owner, err := common.NewAccountFromProto(req.Owner) + if err != nil { + log.WithError(err).Warn("invalid owner account") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("owner_account", owner.PublicKey().ToBase58()) + + signature := req.Signature + req.Signature = nil + if err := s.auth.Authenticate(ctx, owner, req, signature); err != nil { + return nil, err + } + + var limit uint64 + if req.PageSize > 0 { + limit = uint64(req.PageSize) + } else { + limit = maxGetChatsPageSize + } + if limit > maxGetChatsPageSize { + limit = maxGetChatsPageSize + } + + var direction query.Ordering + if req.Direction == chatpb.GetChatsRequest_ASC { + direction = query.Ascending + } else { + direction = query.Descending + } + + var cursor query.Cursor + if req.Cursor != nil { + cursor = req.Cursor.Value + } + + memberID, err := owner.ToChatMemberId() + if err != nil { + log.WithError(err).Warn("Failed to derive messaging account") + return nil, status.Error(codes.Internal, "") + } + + chats, err := s.data.GetAllChatsForUserV2( + ctx, + memberID, + query.WithCursor(cursor), + query.WithDirection(direction), + query.WithLimit(limit), + ) + + log.WithField("chats", len(chats)).Info("Retrieved chatlist for user") + metadata := make([]*chatpb.Metadata, 0, len(chats)) + for _, id := range chats { + md, err := s.getMetadata(ctx, memberID, id) + if err != nil { + return nil, nil + } + + metadata = append(metadata, md) + } + + return &chatpb.GetChatsResponse{ + Result: chatpb.GetChatsResponse_OK, + Chats: metadata, + }, nil +} + +func (s *Server) GetMessages(ctx context.Context, req *chatpb.GetMessagesRequest) (*chatpb.GetMessagesResponse, error) { + log := s.log.WithField("method", "GetMessages") + log = client.InjectLoggingMetadata(ctx, log) + + owner, err := common.NewAccountFromProto(req.Owner) + if err != nil { + log.WithError(err).Warn("invalid owner account") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("owner_account", owner.PublicKey().ToBase58()) + + memberId, err := owner.ToChatMemberId() + if err != nil { + log.WithError(err).Warn("failed to derive messaging account") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("member_id", memberId.String()) + + chatId, err := chat.GetChatIdFromProto(req.ChatId) + if err != nil { + log.WithError(err).Warn("invalid chat id") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("chat_id", chatId.String()) + + signature := req.Signature + req.Signature = nil + if err := s.auth.Authenticate(ctx, owner, req, signature); err != nil { + return nil, err + } + + isChatMember, err := s.data.IsChatMember(ctx, chatId, memberId) + if err != nil { + log.WithError(err).Warn("failed to check if chat member") + return nil, status.Error(codes.Internal, "") + } + if !isChatMember { + return &chatpb.GetMessagesResponse{ + Result: chatpb.GetMessagesResponse_DENIED, + }, nil + } + + var limit uint64 + if req.PageSize > 0 { + limit = uint64(req.PageSize) + } else { + limit = maxGetMessagesPageSize + } + if limit > maxGetMessagesPageSize { + limit = maxGetMessagesPageSize + } + + var direction query.Ordering + if req.Direction == chatpb.GetMessagesRequest_ASC { + direction = query.Ascending + } else { + direction = query.Descending + } + + var cursor query.Cursor + if req.Cursor != nil { + cursor = req.Cursor.Value + } + + protoChatMessages, err := s.getProtoChatMessages( + ctx, + chatId, + owner, + query.WithCursor(cursor), + query.WithDirection(direction), + query.WithLimit(limit), + ) + if err != nil { + log.WithError(err).Warn("failure getting chat messages") + return nil, status.Error(codes.Internal, "") + } + + return &chatpb.GetMessagesResponse{ + Result: chatpb.GetMessagesResponse_OK, + Messages: protoChatMessages, + }, nil +} + +func (s *Server) StartChat(ctx context.Context, req *chatpb.StartChatRequest) (*chatpb.StartChatResponse, error) { + log := s.log.WithField("method", "StartChat") + log = client.InjectLoggingMetadata(ctx, log) + + owner, err := common.NewAccountFromProto(req.Owner) + if err != nil { + log.WithError(err).Warn("invalid owner account") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("owner_account", owner.PublicKey().ToBase58()) + + memberId, err := owner.ToChatMemberId() + if err != nil { + log.WithError(err).Warn("failed to derive member id") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("member_id", memberId) + + signature := req.Signature + req.Signature = nil + if err := s.auth.Authenticate(ctx, owner, req, signature); err != nil { + return nil, err + } + + creator, err := s.data.GetTwitterUserByTipAddress(ctx, memberId.String()) + if errors.Is(err, twitter.ErrUserNotFound) { + log.WithField("memberId", memberId).Info("User has no twitter account") + return &chatpb.StartChatResponse{Result: chatpb.StartChatResponse_MISSING_IDENTITY}, nil + } else if err != nil { + log.WithError(err).Warn("failed to get twitter user") + return nil, status.Error(codes.Internal, "") + } + + switch typed := req.Parameters.(type) { + case *chatpb.StartChatRequest_TwoWayChat: + chatId := chat.GetTwoWayChatId(memberId, typed.TwoWayChat.OtherUser.Value) + + metadata, err := s.getMetadata(ctx, memberId, chatId) + if err == nil { + return &chatpb.StartChatResponse{ + Chat: metadata, + }, nil + + } else if err != nil && !errors.Is(err, chat.ErrChatNotFound) { + log.WithError(err).Warn("failed to get chat metadata") + return nil, status.Error(codes.Internal, "") + } + + if typed.TwoWayChat.IntentId == nil { + return &chatpb.StartChatResponse{Result: chatpb.StartChatResponse_INVALID_PARAMETER}, nil + } + + intentId := base58.Encode(typed.TwoWayChat.IntentId.Value) + log = log.WithField("intent", intentId) + + intentRecord, err := s.data.GetIntent(ctx, intentId) + if errors.Is(err, intent.ErrIntentNotFound) { + log.WithError(err).Info("Intent not found") + return &chatpb.StartChatResponse{ + Result: chatpb.StartChatResponse_INVALID_PARAMETER, + Chat: nil, + }, nil + } else if err != nil { + log.WithError(err).Warn("failure getting intent record") + return nil, status.Error(codes.Internal, "") + } + + switch intentRecord.State { + case intent.StatePending: + log.Info("Payment intent is pending") + case intent.StateConfirmed: + default: + log.WithField("state", intentRecord.State).Info("PayToChat intent did not succeed") + return &chatpb.StartChatResponse{Result: chatpb.StartChatResponse_DENIED}, nil + } + + if intentRecord.SendPrivatePaymentMetadata == nil { + log.Info("intent missing private payment meta") + //return &chatpb.StartChatResponse{Result: chatpb.StartChatResponse_DENIED}, nil + } + + if !intentRecord.SendPrivatePaymentMetadata.IsChat { + log.Info("intent is not for chat") + //return &chatpb.StartChatResponse{Result: chatpb.StartChatResponse_DENIED}, nil + } + + expectedChatId := base58.Encode(chatId[:]) + if intentRecord.SendPrivatePaymentMetadata.ChatId != expectedChatId { + log.WithField("expected", expectedChatId).WithField("actual", intentRecord.SendPrivatePaymentMetadata.ChatId).Warn("chat id mismatch") + } + + otherMessagingAddress := base58.Encode(typed.TwoWayChat.OtherUser.Value) + + otherTwitter, err := s.data.GetTwitterUserByTipAddress(ctx, otherMessagingAddress) + if errors.Is(err, twitter.ErrUserNotFound) { + return &chatpb.StartChatResponse{Result: chatpb.StartChatResponse_USER_NOT_FOUND}, nil + } else if err != nil { + log.WithError(err).Warn("failure checking twitter user") + return nil, status.Error(codes.Internal, "") + } + + otherAccount, err := s.data.GetAccountInfoByTokenAddress(ctx, otherMessagingAddress) + if errors.Is(err, account.ErrAccountInfoNotFound) { + return &chatpb.StartChatResponse{Result: chatpb.StartChatResponse_USER_NOT_FOUND}, nil + } else if err != nil { + log.WithError(err).Warn("failure checking account info") + return nil, status.Error(codes.Internal, "") + } + + // At this point, we assume the relationship is valid, and can proceed to recover or create + // the chat record. + creationTs := time.Now() + chatRecord := &chat.MetadataRecord{ + ChatId: chatId, + ChatType: chat.ChatTypeTwoWay, + CreatedAt: creationTs, + ChatTitle: nil, + } + + memberRecords := []*chat.MemberRecord{ + { + ChatId: chatId, + MemberId: memberId.String(), + Owner: owner.PublicKey().ToBase58(), + + Platform: chat.PlatformTwitter, + PlatformId: creator.Username, + + JoinedAt: creationTs, + }, + { + ChatId: chatId, + MemberId: otherMessagingAddress, + Owner: otherAccount.OwnerAccount, + + Platform: chat.PlatformTwitter, + PlatformId: otherTwitter.Username, + + JoinedAt: time.Now(), + }, + } + + // Note: this should almost _always_ succeed in the happy path, since we check + // for existence earlier! + // + // The only time we have to rollback and query is on race of creation. + err = s.data.ExecuteInTx(ctx, sql.LevelDefault, func(ctx context.Context) error { + existingChatRecord, err := s.data.GetChatMetadata(ctx, chatId) + if err != nil && !errors.Is(err, chat.ErrChatNotFound) { + return fmt.Errorf("failed to check existing chat: %w", err) + } + + if existingChatRecord != nil { + chatRecord = existingChatRecord + memberRecords, err = s.data.GetChatMembersV2(ctx, chatId) + if err != nil { + return fmt.Errorf("failed to check existing chat members: %w", err) + } + + return nil + } + + if err = s.data.PutChatV2(ctx, chatRecord); err != nil { + return fmt.Errorf("failed to save new chat: %w", err) + } + for _, m := range memberRecords { + if err = s.data.PutChatMemberV2(ctx, m); err != nil { + return fmt.Errorf("failed to add member to chat: %w", err) + } + } + + return nil + }) + if err != nil { + log.WithError(err).Warn("failure creating chat") + return nil, status.Error(codes.Internal, "") + } + + md, err := s.populateMetadata(ctx, chatRecord, memberRecords, memberId) + if err != nil { + log.WithError(err).Warn("failure populating metadata") + return nil, status.Error(codes.Internal, "") + } + + event := &chatEventNotification{ + chatId: chatId, + chatUpdate: md, + } + if err = s.asyncNotifyAll(chatId, event); err != nil { + log.WithError(err).Warn("failed to notify event stream") + } + + return &chatpb.StartChatResponse{ + Result: chatpb.StartChatResponse_OK, + Chat: md, + }, nil + + default: + return nil, status.Error(codes.InvalidArgument, "StartChatRequest.Parameters is nil") + } +} + +func (s *Server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest) (*chatpb.SendMessageResponse, error) { + log := s.log.WithField("method", "SendMessage") + log = client.InjectLoggingMetadata(ctx, log) + + owner, err := common.NewAccountFromProto(req.Owner) + if err != nil { + log.WithError(err).Warn("invalid owner account") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("owner_account", owner.PublicKey().ToBase58()) + + memberId, err := owner.ToChatMemberId() + if err != nil { + log.WithError(err).Warn("failed to derive messaging account") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("member_id", memberId) + + chatId, err := chat.GetChatIdFromProto(req.ChatId) + if err != nil { + log.WithError(err).Warn("invalid chat id") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("chat_id", chatId.String()) + + switch req.Content[0].Type.(type) { + case *chatpb.Content_Text: + default: + return &chatpb.SendMessageResponse{ + Result: chatpb.SendMessageResponse_INVALID_CONTENT_TYPE, + }, nil + } + + signature := req.Signature + req.Signature = nil + if err := s.auth.Authenticate(ctx, owner, req, signature); err != nil { + return nil, err + } + + metadata, err := s.data.GetChatMetadata(ctx, chatId) + if errors.Is(err, chat.ErrChatNotFound) { + return &chatpb.SendMessageResponse{ + Result: chatpb.SendMessageResponse_DENIED, + }, nil + } else if err != nil { + log.WithError(err).Warn("failure getting chat record") + return nil, status.Error(codes.Internal, "") + } + + isMember, err := s.data.IsChatMember(ctx, chatId, memberId) + if err != nil { + log.WithError(err).Warn("failure checking member record") + return nil, status.Error(codes.Internal, "") + } + if !isMember { + return &chatpb.SendMessageResponse{ + Result: chatpb.SendMessageResponse_DENIED, + }, nil + } + + chatLock := s.chatLocks.Get(chatId[:]) + chatLock.Lock() + defer chatLock.Unlock() + + chatMessage := newProtoChatMessage(memberId, req.Content...) + + err = s.persistChatMessage(ctx, chatId, chatMessage) + if err != nil { + log.WithError(err).Warn("failure persisting chat message") + return nil, status.Error(codes.Internal, "") + } + + s.onPersistChatMessage(log, chatId, chatMessage) + s.sendPushNotifications(chatId, pointer.StringOrEmpty(metadata.ChatTitle), memberId, chatMessage) + + return &chatpb.SendMessageResponse{ + Result: chatpb.SendMessageResponse_OK, + Message: chatMessage, + }, nil +} + +func (s *Server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerRequest) (*chatpb.AdvancePointerResponse, error) { + log := s.log.WithField("method", "AdvancePointer") + log = client.InjectLoggingMetadata(ctx, log) + + chatId, err := chat.GetChatIdFromProto(req.ChatId) + if err != nil { + log.WithError(err).Warn("invalid chat id") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("chat_id", chatId.String()) + + owner, err := common.NewAccountFromProto(req.Owner) + if err != nil { + log.WithError(err).Warn("invalid owner account") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("owner_account", owner.PublicKey().ToBase58()) + + memberId, err := owner.ToChatMemberId() + if err != nil { + log.WithError(err).Warn("failed to derive messaging account") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("member_id", memberId) + + pointerType := chat.GetPointerTypeFromProto(req.Pointer.Type) + log = log.WithField("pointer_type", pointerType.String()) + if pointerType <= chat.PointerTypeUnknown || pointerType > chat.PointerTypeRead { + return nil, status.Error(codes.InvalidArgument, "invalid pointer type") + } + + pointerValue, err := chat.GetMessageIdFromProto(req.Pointer.Value) + if err != nil { + log.WithError(err).Warn("invalid pointer value") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("pointer_value", pointerValue.String()) + + // Force override whatever the user thought it should be. + req.Pointer.MemberId = memberId.ToProto() + + signature := req.Signature + req.Signature = nil + if err := s.auth.Authenticate(ctx, owner, req, signature); err != nil { + return nil, err + } + + isMember, err := s.data.IsChatMember(ctx, chatId, memberId) + if err != nil { + log.WithError(err).Warn("failure checking member record") + return nil, status.Error(codes.Internal, "") + } + if !isMember { + return &chatpb.AdvancePointerResponse{Result: chatpb.AdvancePointerResponse_DENIED}, nil + } + + _, err = s.data.GetChatMessageV2(ctx, chatId, pointerValue) + if errors.Is(err, chat.ErrChatNotFound) { + return &chatpb.AdvancePointerResponse{ + Result: chatpb.AdvancePointerResponse_MESSAGE_NOT_FOUND, + }, nil + } else if err != nil { + log.WithError(err).Warn("failure getting chat message record") + return nil, status.Error(codes.Internal, "") + } + + isAdvanced, err := s.data.AdvanceChatPointerV2(ctx, chatId, memberId, pointerType, pointerValue) + if err != nil { + log.WithError(err).Warn("failure advancing chat pointer") + return nil, status.Error(codes.Internal, "") + } + + if isAdvanced { + event := &chatEventNotification{ + chatId: chatId, + pointerUpdate: req.Pointer, + } + if err := s.asyncNotifyAll(chatId, event); err != nil { + log.WithError(err).Warn("failure notifying chat event") + } + } + + return &chatpb.AdvancePointerResponse{ + Result: chatpb.AdvancePointerResponse_OK, + }, nil +} + +func (s *Server) SetMuteState(ctx context.Context, req *chatpb.SetMuteStateRequest) (*chatpb.SetMuteStateResponse, error) { + log := s.log.WithField("method", "SetMuteState") + log = client.InjectLoggingMetadata(ctx, log) + + owner, err := common.NewAccountFromProto(req.Owner) + if err != nil { + log.WithError(err).Warn("invalid owner account") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("owner_account", owner.PublicKey().ToBase58()) + + memberId, err := owner.ToChatMemberId() + if err != nil { + log.WithError(err).Warn("failed to derive messaging account") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("member_id", memberId.String()) + + chatId, err := chat.GetChatIdFromProto(req.ChatId) + if err != nil { + log.WithError(err).Warn("invalid chat id") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("chat_id", chatId.String()) + + signature := req.Signature + req.Signature = nil + if err := s.auth.Authenticate(ctx, owner, req, signature); err != nil { + return nil, err + } + + isMember, err := s.data.IsChatMember(ctx, chatId, memberId) + if err != nil { + log.WithError(err).Warn("failure checking member record") + return nil, status.Error(codes.Internal, "") + } + if !isMember { + return &chatpb.SetMuteStateResponse{Result: chatpb.SetMuteStateResponse_DENIED}, nil + } + + // todo: Use chat record to determine if muting is allowed + + err = s.data.SetChatMuteStateV2(ctx, chatId, memberId, req.IsMuted) + if err != nil { + log.WithError(err).Warn("failure setting mute state") + return nil, status.Error(codes.Internal, "") + } + + return &chatpb.SetMuteStateResponse{ + Result: chatpb.SetMuteStateResponse_OK, + }, nil +} + +func (s *Server) NotifyIsTyping(ctx context.Context, req *chatpb.NotifyIsTypingRequest) (*chatpb.NotifyIsTypingResponse, error) { + log := s.log.WithField("method", "NotifyIsTyping") + log = client.InjectLoggingMetadata(ctx, log) + + owner, err := common.NewAccountFromProto(req.Owner) + if err != nil { + log.WithError(err).Warn("invalid owner account") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("owner_account", owner.PublicKey().ToBase58()) + + memberId, err := owner.ToChatMemberId() + if err != nil { + log.WithError(err).Warn("failed to derive messaging account") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("member_id", memberId) + + chatId, err := chat.GetChatIdFromProto(req.ChatId) + if err != nil { + log.WithError(err).Warn("invalid chat id") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("chat_id", chatId.String()) + + signature := req.Signature + req.Signature = nil + if err := s.auth.Authenticate(ctx, owner, req, signature); err != nil { + return nil, err + } + + isMember, err := s.data.IsChatMember(ctx, chatId, memberId) + if err != nil { + log.WithError(err).Warn("failure checking member record") + return nil, status.Error(codes.Internal, "") + } + if !isMember { + return &chatpb.NotifyIsTypingResponse{Result: chatpb.NotifyIsTypingResponse_DENIED}, nil + } + + // Internalize the event + // notifyAll sends to both(depending on type) + // notifyAll then determines the actual assembly + + event := &chatEventNotification{ + chatId: chatId, + isTyping: &chatpb.IsTyping{ + MemberId: memberId.ToProto(), + IsTyping: req.IsTyping, + }, + } + + if err := s.asyncNotifyAll(chatId, event); err != nil { + log.WithError(err).Warn("failure notifying chat event") + } + + return &chatpb.NotifyIsTypingResponse{}, nil +} + +func (s *Server) NotifyMessage(_ context.Context, _ chat.ChatId, _ *chatpb.Message) { + // TODO: Cleanup this up +} + +// todo: needs to have a 'fill' version +func (s *Server) getMetadata(ctx context.Context, asMember chat.MemberId, id chat.ChatId) (*chatpb.Metadata, error) { + mdRecord, err := s.data.GetChatMetadata(ctx, id) + if err != nil { + return nil, fmt.Errorf("failed to lookup metadata: %w", err) + } + + members, err := s.data.GetChatMembersV2(ctx, id) + if err != nil { + return nil, fmt.Errorf("failed to get members: %w", err) + } + + return s.populateMetadata(ctx, mdRecord, members, asMember) +} + +func (s *Server) populateMetadata(ctx context.Context, mdRecord *chat.MetadataRecord, members []*chat.MemberRecord, asMember chat.MemberId) (*chatpb.Metadata, error) { + md := &chatpb.Metadata{ + ChatId: mdRecord.ChatId.ToProto(), + Type: mdRecord.ChatType.ToProto(), + Cursor: &chatpb.Cursor{Value: mdRecord.ChatId[:]}, + IsMuted: false, + Muteable: false, + NumUnread: 0, + } + + if mdRecord.ChatTitle != nil { + md.Title = *mdRecord.ChatTitle + } + + for _, m := range members { + memberId, err := chat.GetMemberIdFromString(m.MemberId) + if err != nil { + return nil, fmt.Errorf("invalid member id %q: %w", m.MemberId, err) + } + + member := &chatpb.Member{ + MemberId: memberId.ToProto(), + } + md.Members = append(md.Members, member) + + twitterUser, err := s.data.GetTwitterUserByTipAddress(ctx, m.MemberId) + if errors.Is(err, twitter.ErrUserNotFound) { + s.log.WithField("member", m.MemberId).Info("Twitter user not found for existing user") + } else if err != nil { + // TODO: If client have caching, we could just not do this... + return nil, fmt.Errorf("failed to get twitter user: %w", err) + } else { + member.Identity = &chatpb.MemberIdentity{ + Platform: chatpb.Platform_TWITTER, + Username: twitterUser.Username, + DisplayName: twitterUser.Name, + ProfilePicUrl: twitterUser.ProfilePicUrl, + } + } + + if m.DeliveryPointer != nil { + member.Pointers = append(member.Pointers, &chatpb.Pointer{ + Type: chatpb.PointerType_DELIVERED, + Value: m.DeliveryPointer.ToProto(), + MemberId: memberId.ToProto(), + }) + } + if m.ReadPointer != nil { + member.Pointers = append(member.Pointers, &chatpb.Pointer{ + Type: chatpb.PointerType_READ, + Value: m.ReadPointer.ToProto(), + MemberId: memberId.ToProto(), + }) + } + + md.IsMuted = m.IsMuted + + // If the member is not the requestor, then we can skip further processing + if !bytes.Equal(asMember, memberId) { + continue + } + + member.IsSelf = true + + // TODO: Do we actually want to compute this feature? It's maybe non-trivial. + // Maybe should have a safety valve at minimum. + md.NumUnread, err = s.data.GetChatUnreadCountV2(ctx, mdRecord.ChatId, memberId, m.ReadPointer) + if err != nil { + return nil, fmt.Errorf("failed to get unread count: %w", err) + } + } + + return md, nil +} +func (s *Server) getProtoChatMessages(ctx context.Context, chatId chat.ChatId, owner *common.Account, queryOptions ...query.Option) ([]*chatpb.Message, error) { + messageRecords, err := s.data.GetAllChatMessagesV2(ctx, chatId, queryOptions...) + if err != nil { + return nil, fmt.Errorf("failure getting chat messages: %w", err) + } + + var userLocale *language.Tag // Loaded lazily when required + var res []*chatpb.Message + for _, messageRecord := range messageRecords { + var protoChatMessage chatpb.Message + err = proto.Unmarshal(messageRecord.Payload, &protoChatMessage) + if err != nil { + return nil, errors.Wrap(err, "error unmarshalling proto chat message") + } + + ts, err := messageRecord.GetTimestamp() + if err != nil { + return nil, errors.Wrap(err, "error getting message timestamp") + } + + for _, content := range protoChatMessage.Content { + switch typed := content.Type.(type) { + case *chatpb.Content_Localized: + if userLocale == nil { + loadedUserLocale, err := s.data.GetUserLocale(ctx, owner.PublicKey().ToBase58()) + if err != nil { + return nil, errors.Wrap(err, "error getting user locale") + } + userLocale = &loadedUserLocale + } + + typed.Localized.KeyOrText = localization.LocalizeWithFallback( + *userLocale, + localization.GetLocalizationKeyForUserAgent(ctx, typed.Localized.KeyOrText), + typed.Localized.KeyOrText, + ) + } + } + + protoChatMessage.MessageId = messageRecord.MessageId.ToProto() + if messageRecord.Sender != nil { + protoChatMessage.SenderId = messageRecord.Sender.ToProto() + } + protoChatMessage.Ts = timestamppb.New(ts) + protoChatMessage.Cursor = &chatpb.Cursor{Value: messageRecord.MessageId[:]} + + res = append(res, &protoChatMessage) + } + + return res, nil +} + +// todo: This belongs in the common chat utility, which currently only operates on v1 chats +func (s *Server) persistChatMessage(ctx context.Context, chatId chat.ChatId, protoChatMessage *chatpb.Message) error { + if err := protoChatMessage.Validate(); err != nil { + return errors.Wrap(err, "proto chat message failed validation") + } + + messageId, err := chat.GetMessageIdFromProto(protoChatMessage.MessageId) + if err != nil { + return errors.Wrap(err, "invalid message id") + } + + var senderId *chat.MemberId + if protoChatMessage.SenderId != nil { + convertedSenderId, err := chat.GetMemberIdFromProto(protoChatMessage.SenderId) + if err != nil { + return errors.Wrap(err, "invalid member id") + } + senderId = &convertedSenderId + } + + // Clear out extracted metadata as a space optimization + cloned := proto.Clone(protoChatMessage).(*chatpb.Message) + cloned.MessageId = nil + cloned.SenderId = nil + cloned.Ts = nil + cloned.Cursor = nil + + marshalled, err := proto.Marshal(cloned) + if err != nil { + return errors.Wrap(err, "error marshalling proto chat message") + } + + messageRecord := &chat.MessageRecord{ + ChatId: chatId, + MessageId: messageId, + Sender: senderId, + Payload: marshalled, + IsSilent: false, + } + + err = s.data.PutChatMessageV2(ctx, messageRecord) + if err != nil { + return errors.Wrap(err, "error persisting chat message") + } + return nil +} + +func (s *Server) onPersistChatMessage(log *logrus.Entry, chatId chat.ChatId, chatMessage *chatpb.Message) { + event := &chatEventNotification{ + chatId: chatId, + messageUpdate: chatMessage, + } + + if err := s.asyncNotifyAll(chatId, event); err != nil { + log.WithError(err).Warn("failure notifying chat event") + } +} + +func (s *Server) sendPushNotifications(chatId chat.ChatId, chatTitle string, sender chat.MemberId, message *chatpb.Message) { + log := s.log.WithFields(logrus.Fields{ + "method": "sendPushNotifications", + "sender": sender.String(), + "chat_id": chatId.String(), + }) + + // TODO: Callers might already have this loaded. + members, err := s.data.GetChatMembersV2(context.Background(), chatId) + if err != nil { + log.WithError(err).Warn("failure getting chat members") + return + } + + var eg errgroup.Group + eg.SetLimit(min(32, len(members))) + + for _, m := range members { + if m.Owner == "" || m.IsMuted { + continue + } + + owner, err := common.NewAccountFromPublicKeyString(m.Owner) + if err != nil { + log.WithError(err).WithField("member", m.MemberId).Warn("failure getting owner") + continue + } + + m := m + eg.Go(func() error { + log.WithField("member", m.MemberId).Info("sending push notification") + err = push_util.SendChatMessagePushNotificationV2( + context.Background(), + s.data, + s.push, + chatId, + chatTitle, + owner, + message, + ) + if err != nil { + log. + WithError(err). + WithField("member", m.MemberId). + Warn("failure sending push notification") + } + + return nil + }) + } + + _ = eg.Wait() +} + +func newProtoChatMessage(sender chat.MemberId, content ...*chatpb.Content) *chatpb.Message { + messageId := chat.GenerateMessageId() + ts, _ := messageId.GetTimestamp() + + return &chatpb.Message{ + MessageId: messageId.ToProto(), + SenderId: sender.ToProto(), + Content: content, + Ts: timestamppb.New(ts), + Cursor: &chatpb.Cursor{Value: messageId[:]}, + } +} diff --git a/pkg/code/server/grpc/chat/v2/server_test.go b/pkg/code/server/grpc/chat/v2/server_test.go new file mode 100644 index 00000000..edfc8311 --- /dev/null +++ b/pkg/code/server/grpc/chat/v2/server_test.go @@ -0,0 +1,452 @@ +package chat_v2 + +import ( + "bytes" + "context" + "fmt" + "github.com/sirupsen/logrus" + "slices" + "testing" + "time" + + "github.com/mr-tron/base58" + "google.golang.org/grpc" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/timestamppb" + + chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v2" + commonpb "github.com/code-payments/code-protobuf-api/generated/go/common/v1" + + auth_util "github.com/code-payments/code-server/pkg/code/auth" + "github.com/code-payments/code-server/pkg/code/common" + "github.com/code-payments/code-server/pkg/code/data" + "github.com/code-payments/code-server/pkg/code/data/account" + chat "github.com/code-payments/code-server/pkg/code/data/chat/v2" + "github.com/code-payments/code-server/pkg/code/data/intent" + "github.com/code-payments/code-server/pkg/code/data/twitter" + "github.com/code-payments/code-server/pkg/currency" + pushmemory "github.com/code-payments/code-server/pkg/push/memory" + "github.com/code-payments/code-server/pkg/testutil" + "github.com/stretchr/testify/require" +) + +func TestServerHappy(t *testing.T) { + env, cleanup := setup(t) + defer cleanup() + + userA := testutil.NewRandomAccount(t) + userB := testutil.NewRandomAccount(t) + + ctx := context.Background() + for i, u := range []*common.Account{userA, userB} { + tipAddr, err := u.ToMessagingAccount(common.KinMintAccount) + require.NoError(t, err) + + userSuffix := string(rune('a' + i)) + + err = env.data.SaveTwitterUser(ctx, &twitter.Record{ + Username: fmt.Sprintf("username-%s", userSuffix), + Name: fmt.Sprintf("name-%s", userSuffix), + ProfilePicUrl: fmt.Sprintf("pp-%s", userSuffix), + TipAddress: tipAddr.PublicKey().ToBase58(), + LastUpdatedAt: time.Now(), + CreatedAt: time.Now(), + }) + require.NoError(t, err) + + err = env.data.CreateAccountInfo(ctx, &account.Record{ + OwnerAccount: u.String(), + AuthorityAccount: u.String(), + TokenAccount: base58.Encode(u.MustToChatMemberId()), + MintAccount: common.KinMintAccount.String(), + AccountType: commonpb.AccountType_PRIMARY, + CreatedAt: time.Now(), + }) + require.NoError(t, err) + } + + chatId := chat.GetTwoWayChatId(userA.MustToChatMemberId(), userB.MustToChatMemberId()) + intentId := bytes.Repeat([]byte{1}, 32) + err := env.data.SaveIntent(ctx, &intent.Record{ + IntentId: base58.Encode(intentId), + IntentType: intent.SendPrivatePayment, + InitiatorOwnerAccount: userA.String(), + SendPrivatePaymentMetadata: &intent.SendPrivatePaymentMetadata{ + DestinationTokenAccount: userB.String(), + Quantity: 10, + ExchangeCurrency: currency.USD, + ExchangeRate: 10, + UsdMarketValue: 10.0, + NativeAmount: 1, + IsChat: true, + ChatId: base58.Encode(chatId[:]), + }, + State: intent.StateConfirmed, + CreatedAt: time.Now(), + }) + require.NoError(t, err) + + t.Run("Initial State", func(t *testing.T) { + req := &chatpb.GetChatsRequest{Owner: userA.ToProto()} + req.Signature = signProtoMessage(t, req, userA, false) + + chats, err := env.client.GetChats(ctx, req) + require.NoError(t, err) + require.Equal(t, chatpb.GetChatsResponse_OK, chats.Result) + require.Empty(t, chats.Chats) + }) + + eventCtx, eventCancel := context.WithTimeout(ctx, time.Minute) + defer eventCancel() + + eventClient, err := env.client.StreamChatEvents(eventCtx) + require.NoError(t, err) + + req := &chatpb.StreamChatEventsRequest_Params{ + Owner: userA.ToProto(), + } + req.Signature = signProtoMessage(t, req, userA, false) + + err = eventClient.Send(&chatpb.StreamChatEventsRequest{ + Type: &chatpb.StreamChatEventsRequest_Params_{ + Params: req, + }, + }) + eventCh := make(chan *chatpb.StreamChatEventsResponse_EventBatch, 1024) + + go func() { + defer close(eventCh) + + for { + msg, err := eventClient.Recv() + if err != nil { + env.log.WithError(err).Error("Failed to receive event stream") + return + } + + switch typed := msg.Type.(type) { + case *chatpb.StreamChatEventsResponse_Ping: + _ = eventClient.Send(&chatpb.StreamChatEventsRequest{ + Type: &chatpb.StreamChatEventsRequest_Pong{ + Pong: &commonpb.ClientPong{ + Timestamp: timestamppb.Now(), + }, + }, + }) + case *chatpb.StreamChatEventsResponse_Error: + env.log.WithError(err).WithField("code", typed.Error.Code).Warn("failed to receive update event") + case *chatpb.StreamChatEventsResponse_Events: + eventCh <- typed.Events + } + } + }() + + t.Run("StartChat", func(t *testing.T) { + req := &chatpb.StartChatRequest{ + Owner: userA.ToProto(), + Parameters: &chatpb.StartChatRequest_TwoWayChat{ + TwoWayChat: &chatpb.StartTwoWayChatParameters{ + OtherUser: &commonpb.SolanaAccountId{Value: userB.MustToChatMemberId()}, + IntentId: &commonpb.IntentId{Value: intentId}, + }, + }, + } + req.Signature = signProtoMessage(t, req, userA, false) + + resp, err := env.client.StartChat(ctx, req) + require.NoError(t, err) + require.Equal(t, chatpb.StartChatResponse_OK, resp.Result) + require.NotEmpty(t, resp.GetChat().GetChatId()) + + expectedMeta := &chatpb.Metadata{ + ChatId: resp.Chat.ChatId, + Type: chatpb.ChatType_TWO_WAY, + Cursor: &chatpb.Cursor{Value: resp.Chat.ChatId.Value}, + Title: "", + Members: []*chatpb.Member{ + { + MemberId: userA.MustToChatMemberId().ToProto(), + Identity: &chatpb.MemberIdentity{ + Platform: chatpb.Platform_TWITTER, + Username: "username-a", + DisplayName: "name-a", + ProfilePicUrl: "pp-a", + }, + IsSelf: true, + }, + { + MemberId: userB.MustToChatMemberId().ToProto(), + Identity: &chatpb.MemberIdentity{ + Platform: chatpb.Platform_TWITTER, + Username: "username-b", + DisplayName: "name-b", + ProfilePicUrl: "pp-b", + }, + }, + }, + } + + slices.SortFunc(expectedMeta.Members, func(a, b *chatpb.Member) int { + return bytes.Compare(a.MemberId.Value, b.MemberId.Value) + }) + slices.SortFunc(resp.Chat.Members, func(a, b *chatpb.Member) int { + return bytes.Compare(a.MemberId.Value, b.MemberId.Value) + }) + + require.NoError(t, testutil.ProtoEqual(expectedMeta, resp.Chat)) + + for _, u := range []*common.Account{userA, userB} { + getChats := &chatpb.GetChatsRequest{Owner: u.ToProto()} + getChats.Signature = signProtoMessage(t, getChats, u, false) + + for _, member := range resp.Chat.Members { + member.IsSelf = bytes.Equal(u.MustToChatMemberId(), member.MemberId.Value) + } + + chats, err := env.client.GetChats(ctx, getChats) + require.NoError(t, err) + require.Equal(t, chatpb.GetChatsResponse_OK, chats.Result) + require.Len(t, chats.Chats, 1) + + slices.SortFunc(chats.Chats[0].Members, func(a, b *chatpb.Member) int { + return bytes.Compare(a.MemberId.Value, b.MemberId.Value) + }) + + require.NoError(t, testutil.ProtoEqual(resp.Chat, chats.Chats[0])) + } + }) + + var messages []*chatpb.Message + t.Run("Send Messages", func(t *testing.T) { + for _, u := range []*common.Account{userA, userB} { + for i := 0; i < 5; i++ { + req := &chatpb.SendMessageRequest{ + ChatId: chatId.ToProto(), + Owner: u.ToProto(), + Content: []*chatpb.Content{ + { + Type: &chatpb.Content_Text{ + Text: &chatpb.TextContent{ + Text: fmt.Sprintf("message-%d", i), + }, + }, + }, + }, + } + req.Signature = signProtoMessage(t, req, u, false) + + resp, err := env.client.SendMessage(ctx, req) + require.NoError(t, err) + require.Equal(t, chatpb.SendMessageResponse_OK, resp.Result) + messages = append(messages, resp.GetMessage()) + + // TODO: Hack on message generation...again. + time.Sleep(time.Millisecond) + } + } + + for _, u := range []*common.Account{userA, userB} { + req := &chatpb.GetChatsRequest{Owner: u.ToProto()} + req.Signature = signProtoMessage(t, req, u, false) + + resp, err := env.client.GetChats(ctx, req) + require.NoError(t, err) + require.Equal(t, chatpb.GetChatsResponse_OK, resp.Result) + + // 5 unread _each_ + require.EqualValues(t, 5, resp.Chats[0].NumUnread) + } + }) + + t.Run("Get Messages", func(t *testing.T) { + for _, u := range []*common.Account{userA, userB} { + req := &chatpb.GetMessagesRequest{ + ChatId: chatId.ToProto(), + Owner: u.ToProto(), + } + req.Signature = signProtoMessage(t, req, u, false) + + resp, err := env.client.GetMessages(ctx, req) + require.NoError(t, err) + require.NoError(t, testutil.ProtoSliceEqual(messages, resp.GetMessages())) + + req.Cursor = resp.Messages[1].GetCursor() + req.Signature = nil + req.Signature = signProtoMessage(t, req, u, false) + + resp, err = env.client.GetMessages(ctx, req) + require.NoError(t, err) + require.NoError(t, testutil.ProtoSliceEqual(messages[2:], resp.GetMessages())) + } + }) + + t.Run("Advance Pointer", func(t *testing.T) { + for _, tc := range []struct { + offset int + user *common.Account + }{ + {offset: 5 + 2, user: userA}, + {offset: 0 + 2, user: userB}, + } { + req := &chatpb.AdvancePointerRequest{ + ChatId: chatId.ToProto(), + Pointer: &chatpb.Pointer{ + Type: chatpb.PointerType_READ, + Value: messages[tc.offset].MessageId, + MemberId: tc.user.MustToChatMemberId().ToProto(), + }, + Owner: tc.user.ToProto(), + } + req.Signature = signProtoMessage(t, req, tc.user, false) + + resp, err := env.client.AdvancePointer(ctx, req) + require.NoError(t, err) + require.Equal(t, chatpb.AdvancePointerResponse_OK, resp.Result) + + getChats := &chatpb.GetChatsRequest{Owner: tc.user.ToProto()} + getChats.Signature = signProtoMessage(t, getChats, tc.user, false) + + chats, err := env.client.GetChats(ctx, getChats) + require.NoError(t, err) + require.Equal(t, chatpb.GetChatsResponse_OK, chats.Result) + require.EqualValues(t, 2, chats.Chats[0].NumUnread) + } + }) + + eventCancel() + t.Run("Event Stream", func(t *testing.T) { + var events []*chatpb.StreamChatEventsResponse_ChatUpdate + for batch := range eventCh { + for _, e := range batch.Updates { + events = append(events, e) + } + } + + require.Equal(t, 13, len(events)) + + // Chat creation + require.NotNil(t, events[0].Metadata) + + // 10 messages + for i := 1; i < 10+1; i++ { + require.NotNil(t, events[i].LastMessage) + } + + // Pointer updates + for i := 1 + 10; i < 13; i++ { + require.NotNil(t, events[i].Pointer) + } + }) + + t.Run("Message Stream", func(t *testing.T) { + ctx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() + + client, err := env.client.StreamMessages(ctx) + require.NoError(t, err) + + req := &chatpb.StreamMessagesRequest_Params{ + ChatId: chatId.ToProto(), + Owner: userA.ToProto(), + Signature: nil, + } + req.Signature = signProtoMessage(t, req, userA, false) + + err = client.Send(&chatpb.StreamMessagesRequest{ + Type: &chatpb.StreamMessagesRequest_Params_{ + Params: req, + }, + }) + require.NoError(t, err) + + // expect some amount of flushes + var streamedMessages []*chatpb.Message + for { + resp, err := client.Recv() + require.NoError(t, err) + + switch typed := resp.Type.(type) { + case *chatpb.StreamMessagesResponse_Error: + require.FailNow(t, typed.Error.String()) + case *chatpb.StreamMessagesResponse_Ping: + _ = client.Send(&chatpb.StreamMessagesRequest{ + Type: &chatpb.StreamMessagesRequest_Pong{ + Pong: &commonpb.ClientPong{Timestamp: timestamppb.Now()}, + }, + }) + + case *chatpb.StreamMessagesResponse_Messages: + for _, m := range typed.Messages.Messages { + streamedMessages = append(streamedMessages, m) + if len(streamedMessages) == len(messages) { + break + } + } + + default: + } + + if len(streamedMessages) == len(messages) { + break + } + } + + require.True(t, slices.IsSortedFunc(streamedMessages, func(a, b *chatpb.Message) int { + return -1 * bytes.Compare(a.MessageId.Value, b.MessageId.Value) + })) + slices.Reverse(streamedMessages) + require.NoError(t, testutil.ProtoSliceEqual(messages, streamedMessages)) + }) +} + +type testEnv struct { + log *logrus.Logger + ctx context.Context + client chatpb.ChatClient + server *Server + data data.Provider +} + +func setup(t *testing.T) (env *testEnv, cleanup func()) { + conn, serv, err := testutil.NewServer() + require.NoError(t, err) + + env = &testEnv{ + log: logrus.StandardLogger(), + ctx: context.Background(), + client: chatpb.NewChatClient(conn), + data: data.NewTestDataProvider(), + } + + env.server = NewChatServer( + env.data, + auth_util.NewRPCSignatureVerifier(env.data), + pushmemory.NewPushProvider(), + ) + + serv.RegisterService(func(server *grpc.Server) { + chatpb.RegisterChatServer(server, env.server) + }) + + testutil.SetupRandomSubsidizer(t, env.data) + + cleanup, err = serv.Serve() + require.NoError(t, err) + return env, cleanup +} + +func signProtoMessage(t *testing.T, msg proto.Message, signer *common.Account, simulateInvalidSignature bool) *commonpb.Signature { + msgBytes, err := proto.Marshal(msg) + require.NoError(t, err) + + if simulateInvalidSignature { + signer = testutil.NewRandomAccount(t) + } + + signature, err := signer.Sign(msgBytes) + require.NoError(t, err) + + return &commonpb.Signature{ + Value: signature, + } +} diff --git a/pkg/code/server/grpc/chat/v2/stream.go b/pkg/code/server/grpc/chat/v2/stream.go new file mode 100644 index 00000000..ee2b859a --- /dev/null +++ b/pkg/code/server/grpc/chat/v2/stream.go @@ -0,0 +1,135 @@ +package chat_v2 + +import ( + "context" + "errors" + "sync" + "time" + + "github.com/sirupsen/logrus" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" +) + +const ( + // todo: configurable + streamBufferSize = 64 + streamPingDelay = 5 * time.Second + streamKeepAliveRecvTimeout = 10 * time.Second + streamNotifyTimeout = time.Second +) + +type eventStream interface { + notify(notification *chatEventNotification, timeout time.Duration) error +} + +type protoEventStream[T proto.Message] struct { + sync.Mutex + + closed bool + ch chan T + transform func(*chatEventNotification) (T, bool) +} + +func newEventStream[T proto.Message]( + bufferSize int, + selector func(notification *chatEventNotification) (T, bool), +) *protoEventStream[T] { + return &protoEventStream[T]{ + ch: make(chan T, bufferSize), + transform: selector, + } +} + +func (e *protoEventStream[T]) notify(event *chatEventNotification, timeout time.Duration) error { + msg, ok := e.transform(event) + if !ok { + return nil + } + + e.Lock() + if e.closed { + e.Unlock() + return errors.New("cannot notify closed stream") + } + + select { + case e.ch <- msg: + case <-time.After(timeout): + e.Unlock() + e.close() + return errors.New("timed out sending message to streamCh") + } + + e.Unlock() + return nil +} + +func (e *protoEventStream[T]) close() { + e.Lock() + defer e.Unlock() + + if e.closed { + return + } + + e.closed = true + close(e.ch) +} + +type ptr[T any] interface { + proto.Message + *T +} + +func boundedReceive[Req any, ReqPtr ptr[Req]]( + ctx context.Context, + stream grpc.ServerStream, + timeout time.Duration, +) (ReqPtr, error) { + var err error + var req = new(Req) + doneCh := make(chan struct{}) + + go func() { + err = stream.RecvMsg(req) + close(doneCh) + }() + + select { + case <-doneCh: + return req, err + case <-ctx.Done(): + return req, status.Error(codes.Canceled, "") + case <-time.After(timeout): + return req, status.Error(codes.DeadlineExceeded, "timeout receiving message") + } +} + +func monitorStreamHealth[Req any, ReqPtr ptr[Req]]( + ctx context.Context, + log *logrus.Entry, + ssRef string, + streamer grpc.ServerStream, + validFn func(ReqPtr) bool, +) <-chan struct{} { + healthCh := make(chan struct{}) + go func() { + defer close(healthCh) + + for { + req, err := boundedReceive[Req, ReqPtr](ctx, streamer, streamKeepAliveRecvTimeout) + if err != nil { + return + } + + if !validFn(req) { + return + } + log.Tracef("received pong from client (stream=%s)", ssRef) + } + }() + return healthCh +} diff --git a/pkg/code/server/grpc/chat/v2/streams.go b/pkg/code/server/grpc/chat/v2/streams.go new file mode 100644 index 00000000..5043e54b --- /dev/null +++ b/pkg/code/server/grpc/chat/v2/streams.go @@ -0,0 +1,522 @@ +package chat_v2 + +import ( + "bytes" + "context" + "errors" + "fmt" + "go.uber.org/zap" + "time" + + "github.com/sirupsen/logrus" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/timestamppb" + + chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v2" + commonpb "github.com/code-payments/code-protobuf-api/generated/go/common/v1" + + "github.com/code-payments/code-server/pkg/code/common" + chat "github.com/code-payments/code-server/pkg/code/data/chat/v2" + "github.com/code-payments/code-server/pkg/database/query" + "github.com/code-payments/code-server/pkg/grpc/client" +) + +func (s *Server) StreamMessages(stream chatpb.Chat_StreamMessagesServer) error { + ctx := stream.Context() + log := s.log.WithField("method", "StreamMessages") + log = client.InjectLoggingMetadata(ctx, log) + + req, err := boundedReceive[chatpb.StreamMessagesRequest, *chatpb.StreamMessagesRequest]( + ctx, + stream, + 250*time.Millisecond, + ) + if err != nil { + return err + } + + if req.GetParams() == nil { + return status.Error(codes.InvalidArgument, "StreamChatEventsRequest.Type must be OpenStreamRequest") + } + + owner, err := common.NewAccountFromProto(req.GetParams().Owner) + if err != nil { + log.WithError(err).Warn("invalid owner account") + return status.Error(codes.Internal, "") + } + log = log.WithField("owner", owner.PublicKey().ToBase58()) + + chatId, err := chat.GetChatIdFromProto(req.GetParams().ChatId) + if err != nil { + log.WithError(err).Warn("invalid chat id") + return status.Error(codes.Internal, "") + } + log = log.WithField("chat_id", chatId.String()) + + memberId, err := owner.ToChatMemberId() + if err != nil { + log.WithError(err).Warn("failed to derive messaging account") + return status.Error(codes.Internal, "") + } + log = log.WithField("member_id", memberId.String()) + + signature := req.GetParams().Signature + req.GetParams().Signature = nil + if err := s.auth.Authenticate(stream.Context(), owner, req.GetParams(), signature); err != nil { + return err + } + + isMember, err := s.data.IsChatMember(ctx, chatId, memberId) + if err != nil { + log.WithError(err).Warn("failed to derive messaging account") + return status.Error(codes.Internal, "") + } + if !isMember { + return stream.Send(&chatpb.StreamMessagesResponse{ + Type: &chatpb.StreamMessagesResponse_Error{ + Error: &chatpb.StreamError{Code: chatpb.StreamError_DENIED}, + }, + }) + } + + streamKey := fmt.Sprintf("%s:%s", chatId.String(), memberId.String()) + + s.streamsMu.Lock() + + if _, exists := s.streams[streamKey]; exists { + s.streamsMu.Unlock() + // There's an existing stream on this Server that must be terminated first. + // Warn to see how often this happens in practice + log.Warnf("existing stream detected on this Server (stream=%p) ; aborting", stream) + return status.Error(codes.Aborted, "stream already exists") + } + + ss := newEventStream[*chatpb.StreamMessagesResponse_MessageBatch]( + streamBufferSize, + func(notification *chatEventNotification) (*chatpb.StreamMessagesResponse_MessageBatch, bool) { + if notification.messageUpdate == nil { + return nil, false + } + if notification.chatId != chatId { + return nil, false + } + + return &chatpb.StreamMessagesResponse_MessageBatch{ + Messages: []*chatpb.Message{notification.messageUpdate}, + }, true + }, + ) + + // The race detector complains when reading the stream pointer ref outside of the lock. + streamRef := fmt.Sprintf("%p", stream) + log.Tracef("setting up new stream (stream=%s)", streamRef) + s.streams[streamKey] = ss + + s.streamsMu.Unlock() + + defer func() { + s.streamsMu.Lock() + + log.Tracef("closing streamer (stream=%s)", streamRef) + + // We check to see if the current active stream is the one that we created. + // If it is, we can just remove it since it's closed. Otherwise, we leave it + // be, as another StreamChatEvents() call is handling it. + liveStream, exists := s.streams[streamKey] + if exists && liveStream == ss { + delete(s.streams, streamKey) + } + + s.streamsMu.Unlock() + }() + + sendPingCh := time.After(0) + streamHealthCh := monitorStreamHealth(ctx, log, streamRef, stream, func(t *chatpb.StreamMessagesRequest) bool { + return t.GetPong() != nil + }) + + // TODO: Support pagination options (or just remove if not necessary). + go s.flushMessages(ctx, chatId, owner, ss) + go s.flushPointers(ctx, chatId, owner, ss) + + for { + select { + case batch, ok := <-ss.ch: + if !ok { + log.Tracef("stream closed ; ending stream (stream=%s)", streamRef) + return status.Error(codes.Aborted, "stream closed") + } + + resp := &chatpb.StreamMessagesResponse{ + Type: &chatpb.StreamMessagesResponse_Messages{ + Messages: batch, + }, + } + + if err = stream.Send(resp); err != nil { + log.WithError(err).Info("failed to forward chat message") + return err + } + case <-sendPingCh: + log.Tracef("sending ping to client (stream=%s)", streamRef) + + sendPingCh = time.After(streamPingDelay) + + err := stream.Send(&chatpb.StreamMessagesResponse{ + Type: &chatpb.StreamMessagesResponse_Ping{ + Ping: &commonpb.ServerPing{ + Timestamp: timestamppb.Now(), + PingDelay: durationpb.New(streamPingDelay), + }, + }, + }) + if err != nil { + log.Tracef("stream is unhealthy ; aborting (stream=%s)", streamRef) + return status.Error(codes.Aborted, "terminating unhealthy stream") + } + case <-streamHealthCh: + log.Tracef("stream is unhealthy ; aborting (stream=%s)", streamRef) + return status.Error(codes.Aborted, "terminating unhealthy stream") + case <-ctx.Done(): + log.Tracef("stream context cancelled ; ending stream (stream=%s)", streamRef) + return status.Error(codes.Canceled, "") + } + } +} + +func (s *Server) StreamChatEvents(stream chatpb.Chat_StreamChatEventsServer) error { + ctx := stream.Context() + log := s.log.WithField("method", "StreamChatEvents") + log = client.InjectLoggingMetadata(ctx, log) + + req, err := boundedReceive[chatpb.StreamChatEventsRequest, *chatpb.StreamChatEventsRequest]( + ctx, + stream, + 250*time.Millisecond, + ) + if err != nil { + return err + } + + if req.GetParams() == nil { + return status.Error(codes.InvalidArgument, "StreamChatEventsRequest.Type must be OpenStreamRequest") + } + + owner, err := common.NewAccountFromProto(req.GetParams().Owner) + if err != nil { + log.WithError(err).Warn("invalid owner account") + return status.Error(codes.Internal, "") + } + log = log.WithField("owner", owner.PublicKey().ToBase58()) + + memberId, err := owner.ToChatMemberId() + if err != nil { + log.WithError(err).Warn("failed to derive messaging account") + return status.Error(codes.Internal, "") + } + log = log.WithField("member_id", memberId.String()) + + signature := req.GetParams().Signature + req.GetParams().Signature = nil + if err := s.auth.Authenticate(stream.Context(), owner, req.GetParams(), signature); err != nil { + return err + } + + // This should be safe? The user would have to provide a pub key + // that derives to a collision on another stream key (i.e. messages) + streamKey := fmt.Sprintf("%s", memberId.String()) + + s.streamsMu.Lock() + + if _, exists := s.streams[streamKey]; exists { + s.streamsMu.Unlock() + // There's an existing stream on this Server that must be terminated first. + // Warn to see how often this happens in practice + log.Warnf("existing stream detected on this Server (stream=%p) ; aborting", stream) + return status.Error(codes.Aborted, "stream already exists") + } + + ss := newEventStream[*chatpb.StreamChatEventsResponse_EventBatch]( + streamBufferSize, + func(notification *chatEventNotification) (*chatpb.StreamChatEventsResponse_EventBatch, bool) { + // We need to check memberships here. + // + // TODO: This needs to be heavily cached + isMember, err := s.data.IsChatMember(ctx, notification.chatId, memberId) + if err != nil { + log.Warn("Failed to check if member for notification", zap.String("chat_id", notification.chatId.String())) + } else if !isMember { + log.Debug("Notification for chat not a member of, dropping", zap.String("chat_id", notification.chatId.String())) + return nil, false + } + + return &chatpb.StreamChatEventsResponse_EventBatch{ + Updates: []*chatpb.StreamChatEventsResponse_ChatUpdate{ + { + ChatId: notification.chatId.ToProto(), + Metadata: notification.chatUpdate, + LastMessage: notification.messageUpdate, + Pointer: notification.pointerUpdate, + IsTyping: notification.isTyping, + }, + }, + }, true + }, + ) + + // The race detector complains when reading the stream pointer ref outside of the lock. + streamRef := fmt.Sprintf("%p", stream) + log.Tracef("setting up new stream (stream=%s)", streamRef) + s.streams[streamKey] = ss + + s.streamsMu.Unlock() + + defer func() { + s.streamsMu.Lock() + + log.Tracef("closing streamer (stream=%s)", streamRef) + + // We check to see if the current active stream is the one that we created. + // If it is, we can just remove it since it's closed. Otherwise, we leave it + // be, as another StreamChatEvents() call is handling it. + liveStream, exists := s.streams[streamKey] + if exists && liveStream == ss { + delete(s.streams, streamKey) + } + + s.streamsMu.Unlock() + }() + + sendPingCh := time.After(0) + streamHealthCh := monitorStreamHealth(ctx, log, streamRef, stream, func(t *chatpb.StreamMessagesRequest) bool { + return t.GetPong() != nil + }) + + go s.flushChats(ctx, owner, memberId, ss) + + for { + select { + case batch, ok := <-ss.ch: + if !ok { + log.Tracef("stream closed ; ending stream (stream=%s)", streamRef) + return status.Error(codes.Aborted, "stream closed") + } + + resp := &chatpb.StreamChatEventsResponse{ + Type: &chatpb.StreamChatEventsResponse_Events{ + Events: batch, + }, + } + + if err = stream.Send(resp); err != nil { + log.WithError(err).Info("failed to forward chat message") + return err + } + case <-sendPingCh: + log.Tracef("sending ping to client (stream=%s)", streamRef) + + sendPingCh = time.After(streamPingDelay) + + err := stream.Send(&chatpb.StreamChatEventsResponse{ + Type: &chatpb.StreamChatEventsResponse_Ping{ + Ping: &commonpb.ServerPing{ + Timestamp: timestamppb.Now(), + PingDelay: durationpb.New(streamPingDelay), + }, + }, + }) + if err != nil { + log.Tracef("stream is unhealthy ; aborting (stream=%s)", streamRef) + return status.Error(codes.Aborted, "terminating unhealthy stream") + } + case <-streamHealthCh: + log.Tracef("stream is unhealthy ; aborting (stream=%s)", streamRef) + return status.Error(codes.Aborted, "terminating unhealthy stream") + case <-ctx.Done(): + log.Tracef("stream context cancelled ; ending stream (stream=%s)", streamRef) + return status.Error(codes.Canceled, "") + } + } +} + +func (s *Server) flushMessages(ctx context.Context, chatId chat.ChatId, owner *common.Account, stream eventStream) { + log := s.log.WithFields(logrus.Fields{ + "method": "flushMessages", + "chat_id": chatId.String(), + "owner_account": owner.PublicKey().ToBase58(), + }) + + protoChatMessages, err := s.getProtoChatMessages( + ctx, + chatId, + owner, + query.WithCursor(query.EmptyCursor), + query.WithDirection(query.Descending), + query.WithLimit(flushMessageCount), + ) + if err != nil { + log.WithError(err).Warn("failure getting chat messages") + return + } + + for _, protoChatMessage := range protoChatMessages { + event := &chatEventNotification{ + chatId: chatId, + messageUpdate: protoChatMessage, + } + + if err = stream.notify(event, streamNotifyTimeout); err != nil { + log.WithError(err).Warnf("failed to notify session stream, closing streamer (stream=%p)", stream) + return + } + } +} + +func (s *Server) flushPointers(ctx context.Context, chatId chat.ChatId, owner *common.Account, stream eventStream) { + log := s.log.WithFields(logrus.Fields{ + "method": "flushPointers", + "chat_id": chatId.String(), + }) + + callingMemberId, err := owner.ToChatMemberId() + if err != nil { + log.WithError(err).Warn("failure computing self") + return + } + + memberRecords, err := s.data.GetChatMembersV2(ctx, chatId) + if err != nil { + log.WithError(err).Warn("failure getting chat members") + return + } + + for _, memberRecord := range memberRecords { + for _, optionalPointer := range []struct { + kind chat.PointerType + value *chat.MessageId + }{ + {chat.PointerTypeDelivered, memberRecord.DeliveryPointer}, + {chat.PointerTypeRead, memberRecord.ReadPointer}, + } { + if optionalPointer.value == nil { + continue + } + + memberId, err := chat.GetMemberIdFromString(memberRecord.MemberId) + if err != nil { + log.WithError(err).Warnf("failure getting memberId from %s", memberRecord.MemberId) + return + } + + if bytes.Equal(memberId, callingMemberId) { + continue + } + + event := &chatEventNotification{ + pointerUpdate: &chatpb.Pointer{ + Type: optionalPointer.kind.ToProto(), + Value: optionalPointer.value.ToProto(), + MemberId: memberId.ToProto(), + }, + } + if err := stream.notify(event, streamNotifyTimeout); err != nil { + log.WithError(err).Warnf("failed to notify session stream, closing streamer (stream=%p)", stream) + return + } + } + } +} + +func (s *Server) flushChats(ctx context.Context, owner *common.Account, memberId chat.MemberId, stream eventStream) { + log := s.log.WithFields(logrus.Fields{ + "method": "flushChats", + "member_id": memberId.String(), + }) + + chats, err := s.data.GetAllChatsForUserV2(ctx, memberId) + if err != nil { + log.WithError(err).Warn("failed get chats") + return + } + + // TODO: This needs to be far safer. + for _, chatId := range chats { + go func(chatId chat.ChatId) { + md, err := s.getMetadata(ctx, memberId, chatId) + if err != nil { + log.WithError(err).Warn("failed get metadata", zap.String("chat_id", chatId.String())) + return + } + + messages, err := s.getProtoChatMessages( + ctx, + chatId, + owner, + query.WithLimit(1), + query.WithDirection(query.Descending), + ) + if err != nil { + log.WithError(err).Warn("failed get chat messages", zap.String("chat_id", chatId.String())) + } + + event := &chatEventNotification{ + chatId: chatId, + chatUpdate: md, + } + if len(messages) > 0 { + event.messageUpdate = messages[0] + } + }(chatId) + } +} + +type chatEventNotification struct { + chatId chat.ChatId + ts time.Time + + chatUpdate *chatpb.Metadata + pointerUpdate *chatpb.Pointer + messageUpdate *chatpb.Message + isTyping *chatpb.IsTyping +} + +func (s *Server) asyncNotifyAll(chatId chat.ChatId, event *chatEventNotification) error { + event.ts = time.Now() + ok := s.chatEventChans.Send(chatId[:], event) + if !ok { + return errors.New("chat event channel is full") + } + + return nil +} + +func (s *Server) asyncChatEventStreamNotifier(workerId int, channel <-chan any) { + log := s.log.WithFields(logrus.Fields{ + "method": "asyncChatEventStreamNotifier", + "worker": workerId, + }) + + for value := range channel { + typedValue, ok := value.(*chatEventNotification) + if !ok { + log.Warn("channel did not receive expected struct") + continue + } + + log = log.WithField("chat_id", typedValue.chatId.String()) + + if time.Since(typedValue.ts) > time.Second { + log.Warn("channel notification latency is elevated") + } + + s.streamsMu.RLock() + for _, stream := range s.streams { + if err := stream.notify(typedValue, streamNotifyTimeout); err != nil { + log.WithError(err).Warnf("failed to notify session stream, closing streamer (stream=%p)", stream) + } + } + s.streamsMu.RUnlock() + } +} diff --git a/pkg/code/server/grpc/messaging/server.go b/pkg/code/server/grpc/messaging/server.go index c89f69ce..87a35c3e 100644 --- a/pkg/code/server/grpc/messaging/server.go +++ b/pkg/code/server/grpc/messaging/server.go @@ -17,19 +17,19 @@ import ( "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/timestamppb" + commonpb "github.com/code-payments/code-protobuf-api/generated/go/common/v1" messagingpb "github.com/code-payments/code-protobuf-api/generated/go/messaging/v1" "github.com/code-payments/code-server/pkg/cache" - "github.com/code-payments/code-server/pkg/grpc/client" - "github.com/code-payments/code-server/pkg/retry" - "github.com/code-payments/code-server/pkg/retry/backoff" - "github.com/code-payments/code-server/pkg/code/auth" "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" "github.com/code-payments/code-server/pkg/code/data/messaging" "github.com/code-payments/code-server/pkg/code/data/rendezvous" "github.com/code-payments/code-server/pkg/code/thirdparty" + "github.com/code-payments/code-server/pkg/grpc/client" + "github.com/code-payments/code-server/pkg/retry" + "github.com/code-payments/code-server/pkg/retry/backoff" ) const ( @@ -285,7 +285,7 @@ func (s *server) OpenMessageStreamWithKeepAlive(streamer messagingpb.Messaging_O err := streamer.Send(&messagingpb.OpenMessageStreamWithKeepAliveResponse{ ResponseOrPing: &messagingpb.OpenMessageStreamWithKeepAliveResponse_Ping{ - Ping: &messagingpb.ServerPing{ + Ping: &commonpb.ServerPing{ Timestamp: timestamppb.Now(), PingDelay: durationpb.New(messageStreamPingDelay), }, diff --git a/pkg/code/server/grpc/messaging/testutil.go b/pkg/code/server/grpc/messaging/testutil.go index f1e7cf97..4968ab2a 100644 --- a/pkg/code/server/grpc/messaging/testutil.go +++ b/pkg/code/server/grpc/messaging/testutil.go @@ -373,7 +373,7 @@ func (c *clientEnv) receiveMessagesInRealTime(t *testing.T, rendezvousKey *commo case *messagingpb.OpenMessageStreamWithKeepAliveResponse_Ping: require.NoError(t, streamer.streamWithKeepAlives.Send(&messagingpb.OpenMessageStreamWithKeepAliveRequest{ RequestOrPong: &messagingpb.OpenMessageStreamWithKeepAliveRequest_Pong{ - Pong: &messagingpb.ClientPong{ + Pong: &commonpb.ClientPong{ Timestamp: timestamppb.Now(), }, }, @@ -467,7 +467,7 @@ func (c *clientEnv) waitUntilStreamTerminationOrTimeout(t *testing.T, rendezvous if keepStreamAlive { require.NoError(t, streamer.streamWithKeepAlives.Send(&messagingpb.OpenMessageStreamWithKeepAliveRequest{ RequestOrPong: &messagingpb.OpenMessageStreamWithKeepAliveRequest_Pong{ - Pong: &messagingpb.ClientPong{ + Pong: &commonpb.ClientPong{ Timestamp: timestamppb.Now(), }, }, diff --git a/pkg/code/server/grpc/transaction/v2/history_test.go b/pkg/code/server/grpc/transaction/v2/history_test.go index f2cddee5..80ae0442 100644 --- a/pkg/code/server/grpc/transaction/v2/history_test.go +++ b/pkg/code/server/grpc/transaction/v2/history_test.go @@ -12,7 +12,7 @@ import ( chat_util "github.com/code-payments/code-server/pkg/code/chat" "github.com/code-payments/code-server/pkg/code/common" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" currency_lib "github.com/code-payments/code-server/pkg/currency" "github.com/code-payments/code-server/pkg/kin" timelock_token_v1 "github.com/code-payments/code-server/pkg/solana/timelock/v1" @@ -142,7 +142,7 @@ func TestPaymentHistory_HappyPath(t *testing.T) { sendingPhone.tip456KinToCodeUser(t, receivingPhone, twitterUsername).requireSuccess(t) - chatMessageRecords, err := server.data.GetAllChatMessages(server.ctx, chat.GetChatId(chat_util.CashTransactionsName, sendingPhone.parentAccount.PublicKey().ToBase58(), true)) + chatMessageRecords, err := server.data.GetAllChatMessagesV1(server.ctx, chat_v1.GetChatId(chat_util.CashTransactionsName, sendingPhone.parentAccount.PublicKey().ToBase58(), true)) require.NoError(t, err) require.Len(t, chatMessageRecords, 10) @@ -236,7 +236,7 @@ func TestPaymentHistory_HappyPath(t *testing.T) { assert.Equal(t, 32.1, protoChatMessage.Content[0].GetExchangeData().GetExact().NativeAmount) assert.Equal(t, kin.ToQuarks(321), protoChatMessage.Content[0].GetExchangeData().GetExact().Quarks) - chatMessageRecords, err = server.data.GetAllChatMessages(server.ctx, chat.GetChatId("example.com", sendingPhone.parentAccount.PublicKey().ToBase58(), true)) + chatMessageRecords, err = server.data.GetAllChatMessagesV1(server.ctx, chat_v1.GetChatId("example.com", sendingPhone.parentAccount.PublicKey().ToBase58(), true)) require.NoError(t, err) require.Len(t, chatMessageRecords, 3) @@ -267,7 +267,7 @@ func TestPaymentHistory_HappyPath(t *testing.T) { assert.Equal(t, 123.0, protoChatMessage.Content[0].GetExchangeData().GetExact().NativeAmount) assert.Equal(t, kin.ToQuarks(123), protoChatMessage.Content[0].GetExchangeData().GetExact().Quarks) - chatMessageRecords, err = server.data.GetAllChatMessages(server.ctx, chat.GetChatId(chat_util.CashTransactionsName, receivingPhone.parentAccount.PublicKey().ToBase58(), true)) + chatMessageRecords, err = server.data.GetAllChatMessagesV1(server.ctx, chat_v1.GetChatId(chat_util.CashTransactionsName, receivingPhone.parentAccount.PublicKey().ToBase58(), true)) require.NoError(t, err) require.Len(t, chatMessageRecords, 7) @@ -334,7 +334,7 @@ func TestPaymentHistory_HappyPath(t *testing.T) { assert.Equal(t, 2.1, protoChatMessage.Content[0].GetExchangeData().GetExact().NativeAmount) assert.Equal(t, kin.ToQuarks(42), protoChatMessage.Content[0].GetExchangeData().GetExact().Quarks) - chatMessageRecords, err = server.data.GetAllChatMessages(server.ctx, chat.GetChatId(chat_util.TipsName, sendingPhone.parentAccount.PublicKey().ToBase58(), true)) + chatMessageRecords, err = server.data.GetAllChatMessagesV1(server.ctx, chat_v1.GetChatId(chat_util.TipsName, sendingPhone.parentAccount.PublicKey().ToBase58(), true)) require.NoError(t, err) require.Len(t, chatMessageRecords, 1) @@ -347,7 +347,7 @@ func TestPaymentHistory_HappyPath(t *testing.T) { assert.Equal(t, 45.6, protoChatMessage.Content[0].GetExchangeData().GetExact().NativeAmount) assert.Equal(t, kin.ToQuarks(456), protoChatMessage.Content[0].GetExchangeData().GetExact().Quarks) - chatMessageRecords, err = server.data.GetAllChatMessages(server.ctx, chat.GetChatId("example.com", receivingPhone.parentAccount.PublicKey().ToBase58(), true)) + chatMessageRecords, err = server.data.GetAllChatMessagesV1(server.ctx, chat_v1.GetChatId("example.com", receivingPhone.parentAccount.PublicKey().ToBase58(), true)) require.NoError(t, err) require.Len(t, chatMessageRecords, 5) @@ -396,7 +396,7 @@ func TestPaymentHistory_HappyPath(t *testing.T) { assert.Equal(t, 12_345.0, protoChatMessage.Content[0].GetExchangeData().GetExact().NativeAmount) assert.Equal(t, kin.ToQuarks(12_345), protoChatMessage.Content[0].GetExchangeData().GetExact().Quarks) - chatMessageRecords, err = server.data.GetAllChatMessages(server.ctx, chat.GetChatId("example.com", receivingPhone.parentAccount.PublicKey().ToBase58(), false)) + chatMessageRecords, err = server.data.GetAllChatMessagesV1(server.ctx, chat_v1.GetChatId("example.com", receivingPhone.parentAccount.PublicKey().ToBase58(), false)) require.NoError(t, err) require.Len(t, chatMessageRecords, 1) @@ -409,7 +409,7 @@ func TestPaymentHistory_HappyPath(t *testing.T) { assert.Equal(t, 77.69, protoChatMessage.Content[0].GetExchangeData().GetExact().NativeAmount) assert.EqualValues(t, 77690000, protoChatMessage.Content[0].GetExchangeData().GetExact().Quarks) - chatMessageRecords, err = server.data.GetAllChatMessages(server.ctx, chat.GetChatId(chat_util.TipsName, receivingPhone.parentAccount.PublicKey().ToBase58(), true)) + chatMessageRecords, err = server.data.GetAllChatMessagesV1(server.ctx, chat_v1.GetChatId(chat_util.TipsName, receivingPhone.parentAccount.PublicKey().ToBase58(), true)) require.NoError(t, err) require.Len(t, chatMessageRecords, 1) diff --git a/pkg/code/server/grpc/transaction/v2/intent.go b/pkg/code/server/grpc/transaction/v2/intent.go index 9faffc47..5f0db81e 100644 --- a/pkg/code/server/grpc/transaction/v2/intent.go +++ b/pkg/code/server/grpc/transaction/v2/intent.go @@ -888,7 +888,7 @@ func (s *transactionServer) SubmitIntent(streamer transactionpb.Transaction_Subm return err } - tipMessagesToPush, err := chat_util.SendTipsExchangeMessage(ctx, s.data, intentRecord) + tipMessagesToPush, err := chat_util.SendTipsExchangeMessage(ctx, s.data, s.notifier, intentRecord) if err != nil { log.WithError(err).Warn("failure updating tips chat") return err diff --git a/pkg/code/server/grpc/transaction/v2/intent_handler.go b/pkg/code/server/grpc/transaction/v2/intent_handler.go index b54cb7b8..85b8d4eb 100644 --- a/pkg/code/server/grpc/transaction/v2/intent_handler.go +++ b/pkg/code/server/grpc/transaction/v2/intent_handler.go @@ -475,6 +475,7 @@ func (h *SendPrivatePaymentIntentHandler) PopulateMetadata(ctx context.Context, IsRemoteSend: typedProtoMetadata.IsRemoteSend, IsMicroPayment: isMicroPayment, IsTip: typedProtoMetadata.IsTip, + IsChat: typedProtoMetadata.IsChat, } if typedProtoMetadata.IsTip { @@ -488,6 +489,14 @@ func (h *SendPrivatePaymentIntentHandler) PopulateMetadata(ctx context.Context, } } + if typedProtoMetadata.IsChat { + if typedProtoMetadata.ChatId == nil { + return newIntentValidationError("chat id is missing") + } + + intentRecord.SendPrivatePaymentMetadata.ChatId = base58.Encode(typedProtoMetadata.ChatId.GetValue()) + } + if destinationAccountInfo != nil { intentRecord.SendPrivatePaymentMetadata.DestinationOwnerAccount = destinationAccountInfo.OwnerAccount } diff --git a/pkg/code/server/grpc/transaction/v2/server.go b/pkg/code/server/grpc/transaction/v2/server.go index 3b852327..5ba0bee0 100644 --- a/pkg/code/server/grpc/transaction/v2/server.go +++ b/pkg/code/server/grpc/transaction/v2/server.go @@ -11,6 +11,7 @@ import ( "github.com/code-payments/code-server/pkg/code/antispam" auth_util "github.com/code-payments/code-server/pkg/code/auth" + "github.com/code-payments/code-server/pkg/code/chat" "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" "github.com/code-payments/code-server/pkg/code/lawenforcement" @@ -29,7 +30,8 @@ type transactionServer struct { auth *auth_util.RPCSignatureVerifier - pusher push_lib.Provider + pusher push_lib.Provider + notifier chat.Notifier jupiterClient *jupiter.Client @@ -65,6 +67,7 @@ type transactionServer struct { func NewTransactionServer( data code_data.Provider, pusher push_lib.Provider, + notifier chat.Notifier, jupiterClient *jupiter.Client, messagingClient messaging.InternalMessageClient, maxmind *maxminddb.Reader, @@ -85,7 +88,8 @@ func NewTransactionServer( auth: auth_util.NewRPCSignatureVerifier(data), - pusher: pusher, + pusher: pusher, + notifier: notifier, jupiterClient: jupiterClient, diff --git a/pkg/code/server/grpc/transaction/v2/swap.go b/pkg/code/server/grpc/transaction/v2/swap.go index bfc76556..8bb6e230 100644 --- a/pkg/code/server/grpc/transaction/v2/swap.go +++ b/pkg/code/server/grpc/transaction/v2/swap.go @@ -22,7 +22,7 @@ import ( chat_util "github.com/code-payments/code-server/pkg/code/chat" "github.com/code-payments/code-server/pkg/code/common" "github.com/code-payments/code-server/pkg/code/data/account" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/code/localization" push_util "github.com/code-payments/code-server/pkg/code/push" currency_lib "github.com/code-payments/code-server/pkg/currency" @@ -511,7 +511,7 @@ func (s *transactionServer) bestEffortNotifyUserOfSwapInProgress(ctx context.Con // Inspect the chat history for a USDC deposited message. If that message // doesn't exist, then avoid sending the swap in progress chat message, since // it can lead to user confusion. - chatMessageRecords, err := s.data.GetAllChatMessages(ctx, chatId, query.WithDirection(query.Descending), query.WithLimit(1)) + chatMessageRecords, err := s.data.GetAllChatMessagesV1(ctx, chatId, query.WithDirection(query.Descending), query.WithLimit(1)) switch err { case nil: var protoChatMessage chatpb.ChatMessage @@ -521,12 +521,12 @@ func (s *transactionServer) bestEffortNotifyUserOfSwapInProgress(ctx context.Con } switch typed := protoChatMessage.Content[0].Type.(type) { - case *chatpb.Content_Localized: - if typed.Localized.KeyOrText != localization.ChatMessageUsdcDeposited { + case *chatpb.Content_ServerLocalized: + if typed.ServerLocalized.KeyOrText != localization.ChatMessageUsdcDeposited { return nil } } - case chat.ErrMessageNotFound: + case chat_v1.ErrMessageNotFound: default: return errors.Wrap(err, "error fetching chat messages") } diff --git a/pkg/code/server/grpc/transaction/v2/testutil.go b/pkg/code/server/grpc/transaction/v2/testutil.go index 38c42626..106feb6a 100644 --- a/pkg/code/server/grpc/transaction/v2/testutil.go +++ b/pkg/code/server/grpc/transaction/v2/testutil.go @@ -30,12 +30,13 @@ import ( transactionpb "github.com/code-payments/code-protobuf-api/generated/go/transaction/v2" "github.com/code-payments/code-server/pkg/code/antispam" + "github.com/code-payments/code-server/pkg/code/chat" chat_util "github.com/code-payments/code-server/pkg/code/chat" "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" "github.com/code-payments/code-server/pkg/code/data/account" "github.com/code-payments/code-server/pkg/code/data/action" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/code/data/commitment" "github.com/code-payments/code-server/pkg/code/data/currency" "github.com/code-payments/code-server/pkg/code/data/deposit" @@ -184,6 +185,7 @@ func setupTestEnv(t *testing.T, serverOverrides *testOverrides) (serverTestEnv, testService := NewTransactionServer( db, memory_push.NewPushProvider(), + chat.NewNoopNotifier(), nil, messaging.NewMessagingClient(db), nil, @@ -6173,7 +6175,7 @@ func isSubmitIntentError(resp *transactionpb.SubmitIntentResponse, err error) bo return err != nil || resp.GetError() != nil } -func getProtoChatMessage(t *testing.T, record *chat.Message) *chatpb.ChatMessage { +func getProtoChatMessage(t *testing.T, record *chat_v1.Message) *chatpb.ChatMessage { var protoMessage chatpb.ChatMessage require.NoError(t, proto.Unmarshal(record.Data, &protoMessage)) return &protoMessage diff --git a/pkg/code/server/grpc/user/server.go b/pkg/code/server/grpc/user/server.go index 1b24adf6..35469244 100644 --- a/pkg/code/server/grpc/user/server.go +++ b/pkg/code/server/grpc/user/server.go @@ -22,6 +22,7 @@ import ( "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" "github.com/code-payments/code-server/pkg/code/data/account" + chat "github.com/code-payments/code-server/pkg/code/data/chat/v2" "github.com/code-payments/code-server/pkg/code/data/intent" "github.com/code-payments/code-server/pkg/code/data/paymentrequest" "github.com/code-payments/code-server/pkg/code/data/phone" @@ -701,6 +702,24 @@ func (s *identityServer) GetTwitterUser(ctx context.Context, req *userpb.GetTwit return nil, status.Error(codes.Internal, "") } + var friendChatId *commonpb.ChatId + if req.Requestor != nil { + // TODO: Validate the requestor + ownerAccount, err := common.NewAccountFromProto(req.Requestor) + if err != nil { + log.WithError(err).Warn("failed to get owner account") + return nil, status.Error(codes.Internal, "") + } + + ownerMessagingAccount, err := ownerAccount.ToMessagingAccount(common.KinMintAccount) + if err != nil { + log.WithError(err).Warn("failed to get owner messaging account") + return nil, status.Error(codes.Internal, "") + } + + friendChatId = chat.GetTwoWayChatId(ownerMessagingAccount.PublicKey().ToBytes(), tipAddress.PublicKey().ToBytes()).ToProto() + } + return &userpb.GetTwitterUserResponse{ Result: userpb.GetTwitterUserResponse_OK, TwitterUser: &userpb.TwitterUser{ @@ -710,6 +729,11 @@ func (s *identityServer) GetTwitterUser(ctx context.Context, req *userpb.GetTwit ProfilePicUrl: record.ProfilePicUrl, VerifiedType: record.VerifiedType, FollowerCount: record.FollowerCount, + FriendshipCost: &transactionpb.ExchangeDataWithoutRate{ + Currency: "usd", + NativeAmount: 1.0, + }, + FriendChatId: friendChatId, }, }, nil case twitter.ErrUserNotFound: @@ -720,7 +744,6 @@ func (s *identityServer) GetTwitterUser(ctx context.Context, req *userpb.GetTwit log.WithError(err).Warn("failure getting twitter user info") return nil, status.Error(codes.Internal, "") } - } func (s *identityServer) markWebhookAsPending(ctx context.Context, id string) error { diff --git a/pkg/database/query/cursor.go b/pkg/database/query/cursor.go index 533ef49f..08e29085 100644 --- a/pkg/database/query/cursor.go +++ b/pkg/database/query/cursor.go @@ -9,7 +9,7 @@ import ( type Cursor []byte var ( - EmptyCursor Cursor = Cursor([]byte{}) + EmptyCursor = Cursor([]byte{}) ) func ToCursor(val uint64) Cursor { diff --git a/pkg/pointer/pointer.go b/pkg/pointer/pointer.go index a3f8da02..e32fbf22 100644 --- a/pkg/pointer/pointer.go +++ b/pkg/pointer/pointer.go @@ -7,6 +7,15 @@ func String(value string) *string { return &value } +// StringOrEmpty returns the value of the string, if set. Otherwise, "". +func StringOrEmpty(value *string) string { + if value != nil { + return *value + } + + return "" +} + // StringOrDefault returns the pointer if not nil, otherwise the default value func StringOrDefault(value *string, defaultValue string) *string { if value != nil { @@ -32,6 +41,36 @@ func StringCopy(value *string) *string { return String(*value) } +// Uint8 returns a pointer to the provided uint8 value +func Uint8(value uint8) *uint8 { + return &value +} + +// Uint8OrDefault returns the pointer if not nil, otherwise the default value +func Uint8OrDefault(value *uint8, defaultValue uint8) *uint8 { + if value != nil { + return value + } + return &defaultValue +} + +// Uint8IfValid returns a pointer to the value if it's valid, otherwise nil +func Uint8IfValid(valid bool, value uint8) *uint8 { + if valid { + return &value + } + return nil +} + +// Uint8Copy returns a pointer that's a copy of the provided value +func Uint8Copy(value *uint8) *uint8 { + if value == nil { + return nil + } + + return Uint8(*value) +} + // Uint64 returns a pointer to the provided uint64 value func Uint64(value uint64) *uint64 { return &value diff --git a/pkg/testutil/proto.go b/pkg/testutil/proto.go new file mode 100644 index 00000000..a9ba5dce --- /dev/null +++ b/pkg/testutil/proto.go @@ -0,0 +1,33 @@ +package testutil + +import ( + "fmt" + + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" +) + +func ProtoEqual(a, b proto.Message) error { + if proto.Equal(a, b) { + return nil + } + + aJSON, _ := protojson.Marshal(a) + bJSON, _ := protojson.Marshal(b) + + return fmt.Errorf("expected: %v\nactual: %v", string(aJSON), string(bJSON)) +} + +func ProtoSliceEqual[T proto.Message](a, b []T) error { + if len(a) != len(b) { + return fmt.Errorf("len(%d) != len(%d)", len(a), len(b)) + } + + for i := range a { + if err := ProtoEqual(a[i], b[i]); err != nil { + return fmt.Errorf("element mismatch at %d\n%w", i, err) + } + } + + return nil +} diff --git a/pkg/twitter/client.go b/pkg/twitter/client.go index 9cf9bb08..6949d2ec 100644 --- a/pkg/twitter/client.go +++ b/pkg/twitter/client.go @@ -12,6 +12,7 @@ import ( "sync" "time" + "github.com/dghubble/oauth1" "github.com/pkg/errors" "github.com/code-payments/code-server/pkg/metrics" @@ -28,8 +29,10 @@ const ( type Client struct { httpClient *http.Client - clientId string - clientSecret string + clientId string + clientSecret string + accessToken string + accessTokenSecret string bearerTokenMu sync.RWMutex bearerToken string @@ -37,11 +40,13 @@ type Client struct { } // NewClient returns a new Twitter client -func NewClient(clientId, clientSecret string) *Client { +func NewClient(clientId, clientSecret, accessToken, accessTokenSecret string) *Client { return &Client{ - httpClient: http.DefaultClient, - clientId: clientId, - clientSecret: clientSecret, + httpClient: http.DefaultClient, + clientId: clientId, + clientSecret: clientSecret, + accessToken: accessToken, + accessTokenSecret: accessTokenSecret, } } @@ -143,8 +148,16 @@ func (c *Client) SearchRecentTweets(ctx context.Context, searchString string, ma return tweets, nextToken, err } +// SendReply sends a reply to the provided tweet +func (c *Client) SendReply(ctx context.Context, tweetId, text string) (string, error) { + tracer := metrics.TraceMethodCall(ctx, metricsStructName, "SendReply") + defer tracer.End() + + return c.sendTweet(ctx, text, &tweetId) +} + func (c *Client) getUser(ctx context.Context, fromUrl string) (*User, error) { - bearerToken, err := c.getBearerToken(c.clientId, c.clientSecret) + bearerToken, err := c.getBearerToken() if err != nil { return nil, err } @@ -189,7 +202,7 @@ func (c *Client) getUser(ctx context.Context, fromUrl string) (*User, error) { } func (c *Client) getTweets(ctx context.Context, fromUrl string) ([]*Tweet, *string, error) { - bearerToken, err := c.getBearerToken(c.clientId, c.clientSecret) + bearerToken, err := c.getBearerToken() if err != nil { return nil, nil, err } @@ -253,7 +266,77 @@ func (c *Client) getTweets(ctx context.Context, fromUrl string) ([]*Tweet, *stri return result.Data, result.Meta.NextToken, nil } -func (c *Client) getBearerToken(clientId, clientSecret string) (string, error) { +func (c *Client) sendTweet(ctx context.Context, text string, inReplyTo *string) (string, error) { + apiUrl := baseUrl + "tweets" + + type ReplyParams struct { + InReplyToTweetId string `json:"in_reply_to_tweet_id"` + } + type Request struct { + Text string `json:"text"` + Reply *ReplyParams `json:"reply"` + } + + reqPayload := Request{ + Text: text, + } + if inReplyTo != nil { + reqPayload.Reply = &ReplyParams{ + InReplyToTweetId: *inReplyTo, + } + } + + reqJson, err := json.Marshal(reqPayload) + if err != nil { + return "", err + } + + req, err := http.NewRequest("POST", apiUrl, bytes.NewBuffer(reqJson)) + if err != nil { + return "", err + } + + req = req.WithContext(ctx) + + req.Header.Set("Content-Type", "application/json") + + config := oauth1.NewConfig(c.clientId, c.clientSecret) + token := oauth1.NewToken(c.accessToken, c.accessTokenSecret) + httpClient := config.Client(oauth1.NoContext, token) + + resp, err := httpClient.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusCreated { + return "", fmt.Errorf("unexpected http status code: %d", resp.StatusCode) + } + + var result struct { + Data struct { + Id *string `json:"id"` + } `json:"data"` + Errors []*twitterError `json:"errors"` + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + + if err := json.Unmarshal(body, &result); err != nil { + return "", err + } + + if len(result.Errors) > 0 { + return "", result.Errors[0].toError() + } + return *result.Data.Id, nil +} + +func (c *Client) getBearerToken() (string, error) { c.bearerTokenMu.RLock() if time.Since(c.lastBearerTokenRefresh) < bearerTokenMaxAge { c.bearerTokenMu.RUnlock() @@ -275,7 +358,7 @@ func (c *Client) getBearerToken(clientId, clientSecret string) (string, error) { } req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - req.SetBasicAuth(clientId, clientSecret) + req.SetBasicAuth(c.clientId, c.clientSecret) resp, err := c.httpClient.Do(req) if err != nil {