diff --git a/go.mod b/go.mod index 8701d46c..4df1f59d 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( firebase.google.com/go/v4 v4.8.0 github.com/aws/aws-sdk-go-v2 v0.17.0 github.com/bits-and-blooms/bloom/v3 v3.1.0 - github.com/code-payments/code-protobuf-api v1.16.6 + github.com/code-payments/code-protobuf-api v1.17.0 github.com/emirpasic/gods v1.12.0 github.com/envoyproxy/protoc-gen-validate v1.0.4 github.com/golang-jwt/jwt/v5 v5.0.0 diff --git a/go.sum b/go.sum index ca244823..90a0e4b7 100644 --- a/go.sum +++ b/go.sum @@ -121,8 +121,8 @@ github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= -github.com/code-payments/code-protobuf-api v1.16.6 h1:QCot0U+4Ar5SdSX4v955FORMsd3Qcf0ZgkoqlGJZzu0= -github.com/code-payments/code-protobuf-api v1.16.6/go.mod h1:pHQm75vydD6Cm2qHAzlimW6drysm489Z4tVxC2zHSsU= +github.com/code-payments/code-protobuf-api v1.17.0 h1:zqLrhm54purzsKYb+5CZ3fJZCIVzmuZ31yeARUkuyWE= +github.com/code-payments/code-protobuf-api v1.17.0/go.mod h1:pHQm75vydD6Cm2qHAzlimW6drysm489Z4tVxC2zHSsU= github.com/containerd/continuity v0.0.0-20190827140505-75bee3e2ccb6 h1:NmTXa/uVnDyp0TY5MKi197+3HWcnYWfnHGyaFthlnGw= github.com/containerd/continuity v0.0.0-20190827140505-75bee3e2ccb6/go.mod h1:GL3xCUCBDV3CZiTSEKksMWbLE66hEyuu9qyDOOqM47Y= github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk= diff --git a/pkg/code/async/geyser/external_deposit.go b/pkg/code/async/geyser/external_deposit.go index 1b5939ac..48685818 100644 --- a/pkg/code/async/geyser/external_deposit.go +++ b/pkg/code/async/geyser/external_deposit.go @@ -20,7 +20,7 @@ import ( code_data "github.com/code-payments/code-server/pkg/code/data" "github.com/code-payments/code-server/pkg/code/data/account" "github.com/code-payments/code-server/pkg/code/data/balance" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/code/data/deposit" "github.com/code-payments/code-server/pkg/code/data/fulfillment" "github.com/code-payments/code-server/pkg/code/data/intent" @@ -299,7 +299,7 @@ func processPotentialExternalDeposit(ctx context.Context, conf *conf, data code_ chatMessage, ) } - case chat.ErrMessageAlreadyExists: + case chat_v1.ErrMessageAlreadyExists: default: return errors.Wrap(err, "error sending chat message") } @@ -772,7 +772,7 @@ func delayedUsdcDepositProcessing( chatMessage, ) } - case chat.ErrMessageAlreadyExists: + case chat_v1.ErrMessageAlreadyExists: default: return } diff --git a/pkg/code/async/geyser/messenger.go b/pkg/code/async/geyser/messenger.go index 5c94153c..afb8e999 100644 --- a/pkg/code/async/geyser/messenger.go +++ b/pkg/code/async/geyser/messenger.go @@ -15,7 +15,7 @@ import ( "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" "github.com/code-payments/code-server/pkg/code/data/account" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/code/push" "github.com/code-payments/code-server/pkg/code/thirdparty" "github.com/code-payments/code-server/pkg/database/query" @@ -169,13 +169,13 @@ func processPotentialBlockchainMessage(ctx context.Context, data code_data.Provi ctx, data, asciiBaseDomain, - chat.ChatTypeExternalApp, + chat_v1.ChatTypeExternalApp, true, recipientOwner, chatMessage, false, ) - if err != nil && err != chat.ErrMessageAlreadyExists { + if err != nil && err != chat_v1.ErrMessageAlreadyExists { return errors.Wrap(err, "error persisting chat message") } diff --git a/pkg/code/chat/message_cash_transactions.go b/pkg/code/chat/message_cash_transactions.go index 1976d9d5..8a94361f 100644 --- a/pkg/code/chat/message_cash_transactions.go +++ b/pkg/code/chat/message_cash_transactions.go @@ -11,7 +11,7 @@ import ( "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/code/data/intent" ) @@ -93,9 +93,9 @@ func SendCashTransactionsExchangeMessage(ctx context.Context, data code_data.Pro return errors.Wrap(err, "error getting original gift card issued intent") } - chatId := chat.GetChatId(CashTransactionsName, giftCardIssuedIntentRecord.InitiatorOwnerAccount, true) + chatId := chat_v1.GetChatId(CashTransactionsName, giftCardIssuedIntentRecord.InitiatorOwnerAccount, true) - err = data.DeleteChatMessage(ctx, chatId, giftCardIssuedIntentRecord.IntentId) + err = data.DeleteChatMessageV1(ctx, chatId, giftCardIssuedIntentRecord.IntentId) if err != nil { return errors.Wrap(err, "error deleting chat message") } @@ -152,13 +152,13 @@ func SendCashTransactionsExchangeMessage(ctx context.Context, data code_data.Pro ctx, data, CashTransactionsName, - chat.ChatTypeInternal, + chat_v1.ChatTypeInternal, true, receiver, protoMessage, true, ) - if err != nil && err != chat.ErrMessageAlreadyExists { + if err != nil && err != chat_v1.ErrMessageAlreadyExists { return errors.Wrap(err, "error persisting chat message") } } diff --git a/pkg/code/chat/message_code_team.go b/pkg/code/chat/message_code_team.go index fe8a7049..d24e2305 100644 --- a/pkg/code/chat/message_code_team.go +++ b/pkg/code/chat/message_code_team.go @@ -9,7 +9,7 @@ import ( "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/code/data/intent" "github.com/code-payments/code-server/pkg/code/localization" ) @@ -20,7 +20,7 @@ func SendCodeTeamMessage(ctx context.Context, data code_data.Provider, receiver ctx, data, CodeTeamName, - chat.ChatTypeInternal, + chat_v1.ChatTypeInternal, true, receiver, chatMessage, @@ -48,8 +48,8 @@ func newIncentiveMessage(localizedTextKey string, intentRecord *intent.Record) ( content := []*chatpb.Content{ { - Type: &chatpb.Content_Localized{ - Localized: &chatpb.LocalizedContent{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ KeyOrText: localizedTextKey, }, }, diff --git a/pkg/code/chat/message_kin_purchases.go b/pkg/code/chat/message_kin_purchases.go index 1ec10b68..4377c247 100644 --- a/pkg/code/chat/message_kin_purchases.go +++ b/pkg/code/chat/message_kin_purchases.go @@ -11,14 +11,14 @@ import ( "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/code/localization" ) // GetKinPurchasesChatId returns the chat ID for the Kin Purchases chat for a // given owner account -func GetKinPurchasesChatId(owner *common.Account) chat.ChatId { - return chat.GetChatId(KinPurchasesName, owner.PublicKey().ToBase58(), true) +func GetKinPurchasesChatId(owner *common.Account) chat_v1.ChatId { + return chat_v1.GetChatId(KinPurchasesName, owner.PublicKey().ToBase58(), true) } // SendKinPurchasesMessage sends a message to the Kin Purchases chat. @@ -27,7 +27,7 @@ func SendKinPurchasesMessage(ctx context.Context, data code_data.Provider, recei ctx, data, KinPurchasesName, - chat.ChatTypeInternal, + chat_v1.ChatTypeInternal, true, receiver, chatMessage, @@ -40,8 +40,8 @@ func SendKinPurchasesMessage(ctx context.Context, data code_data.Provider, recei func ToUsdcDepositedMessage(signature string, ts time.Time) (*chatpb.ChatMessage, error) { content := []*chatpb.Content{ { - Type: &chatpb.Content_Localized{ - Localized: &chatpb.LocalizedContent{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ KeyOrText: localization.ChatMessageUsdcDeposited, }, }, @@ -60,8 +60,8 @@ func NewUsdcBeingConvertedMessage(ts time.Time) (*chatpb.ChatMessage, error) { content := []*chatpb.Content{ { - Type: &chatpb.Content_Localized{ - Localized: &chatpb.LocalizedContent{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ KeyOrText: localization.ChatMessageUsdcBeingConverted, }, }, @@ -79,8 +79,8 @@ func ToKinAvailableForUseMessage(signature string, ts time.Time, purchases ...*t content := []*chatpb.Content{ { - Type: &chatpb.Content_Localized{ - Localized: &chatpb.LocalizedContent{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ KeyOrText: localization.ChatMessageKinAvailableForUse, }, }, diff --git a/pkg/code/chat/message_merchant.go b/pkg/code/chat/message_merchant.go index b4504c39..a340f8bd 100644 --- a/pkg/code/chat/message_merchant.go +++ b/pkg/code/chat/message_merchant.go @@ -13,7 +13,7 @@ import ( "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" "github.com/code-payments/code-server/pkg/code/data/action" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/code/data/intent" ) @@ -36,7 +36,7 @@ func SendMerchantExchangeMessage(ctx context.Context, data code_data.Provider, i // the merchant. Representation in the UI may differ (ie. 2 and 3 are grouped), // but this is the most flexible solution with the chat model. chatTitle := PaymentsName - chatType := chat.ChatTypeInternal + chatType := chat_v1.ChatTypeInternal isVerifiedChat := false exchangeData, ok := getExchangeDataFromIntent(intentRecord) @@ -59,7 +59,7 @@ func SendMerchantExchangeMessage(ctx context.Context, data code_data.Provider, i if paymentRequestRecord.Domain != nil { chatTitle = *paymentRequestRecord.Domain - chatType = chat.ChatTypeExternalApp + chatType = chat_v1.ChatTypeExternalApp isVerifiedChat = paymentRequestRecord.IsVerified } @@ -87,7 +87,7 @@ func SendMerchantExchangeMessage(ctx context.Context, data code_data.Provider, i // and will have merchant payments appear in the verified merchant // chat. chatTitle = *destinationAccountInfoRecord.RelationshipTo - chatType = chat.ChatTypeExternalApp + chatType = chat_v1.ChatTypeExternalApp isVerifiedChat = true verbAndExchangeDataByMessageReceiver[intentRecord.SendPrivatePaymentMetadata.DestinationOwnerAccount] = &verbAndExchangeData{ verb: chatpb.ExchangeDataContent_RECEIVED, @@ -107,7 +107,7 @@ func SendMerchantExchangeMessage(ctx context.Context, data code_data.Provider, i // and will have merchant payments appear in the verified merchant // chat. chatTitle = *destinationAccountInfoRecord.RelationshipTo - chatType = chat.ChatTypeExternalApp + chatType = chat_v1.ChatTypeExternalApp isVerifiedChat = true verbAndExchangeDataByMessageReceiver[intentRecord.SendPublicPaymentMetadata.DestinationOwnerAccount] = &verbAndExchangeData{ verb: chatpb.ExchangeDataContent_RECEIVED, @@ -126,7 +126,7 @@ func SendMerchantExchangeMessage(ctx context.Context, data code_data.Provider, i // and will have merchant payments appear in the verified merchant // chat. chatTitle = *destinationAccountInfoRecord.RelationshipTo - chatType = chat.ChatTypeExternalApp + chatType = chat_v1.ChatTypeExternalApp isVerifiedChat = true verbAndExchangeDataByMessageReceiver[intentRecord.ExternalDepositMetadata.DestinationOwnerAccount] = &verbAndExchangeData{ verb: chatpb.ExchangeDataContent_RECEIVED, @@ -171,7 +171,7 @@ func SendMerchantExchangeMessage(ctx context.Context, data code_data.Provider, i protoMessage, verbAndExchangeData.verb != chatpb.ExchangeDataContent_RECEIVED || !isVerifiedChat, ) - if err != nil && err != chat.ErrMessageAlreadyExists { + if err != nil && err != chat_v1.ErrMessageAlreadyExists { return nil, errors.Wrap(err, "error persisting chat message") } diff --git a/pkg/code/chat/message_tips.go b/pkg/code/chat/message_tips.go index b9984a9b..5752205a 100644 --- a/pkg/code/chat/message_tips.go +++ b/pkg/code/chat/message_tips.go @@ -9,7 +9,7 @@ import ( "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/code/data/intent" ) @@ -70,13 +70,13 @@ func SendTipsExchangeMessage(ctx context.Context, data code_data.Provider, inten ctx, data, TipsName, - chat.ChatTypeInternal, + chat_v1.ChatTypeInternal, true, receiver, protoMessage, verb != chatpb.ExchangeDataContent_RECEIVED_TIP, ) - if err != nil && err != chat.ErrMessageAlreadyExists { + if err != nil && err != chat_v1.ErrMessageAlreadyExists { return nil, errors.Wrap(err, "error persisting chat message") } diff --git a/pkg/code/chat/sender.go b/pkg/code/chat/sender.go index 41da0902..412e656b 100644 --- a/pkg/code/chat/sender.go +++ b/pkg/code/chat/sender.go @@ -12,7 +12,7 @@ import ( "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" ) // SendChatMessage sends a chat message to a receiving owner account. @@ -24,13 +24,13 @@ func SendChatMessage( ctx context.Context, data code_data.Provider, chatTitle string, - chatType chat.ChatType, + chatType chat_v1.ChatType, isVerifiedChat bool, receiver *common.Account, protoMessage *chatpb.ChatMessage, isSilentMessage bool, ) (canPushMessage bool, err error) { - chatId := chat.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), isVerifiedChat) + chatId := chat_v1.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), isVerifiedChat) if protoMessage.Cursor != nil { // Let the utilities and GetMessages RPC handle cursors @@ -58,13 +58,13 @@ func SendChatMessage( canPersistMessage := true canPushMessage = !isSilentMessage - existingChatRecord, err := data.GetChatById(ctx, chatId) + existingChatRecord, err := data.GetChatByIdV1(ctx, chatId) switch err { case nil: canPersistMessage = !existingChatRecord.IsUnsubscribed canPushMessage = canPushMessage && canPersistMessage && !existingChatRecord.IsMuted - case chat.ErrChatNotFound: - chatRecord := &chat.Chat{ + case chat_v1.ErrChatNotFound: + chatRecord := &chat_v1.Chat{ ChatId: chatId, ChatType: chatType, ChatTitle: chatTitle, @@ -79,8 +79,8 @@ func SendChatMessage( CreatedAt: time.Now(), } - err = data.PutChat(ctx, chatRecord) - if err != nil && err != chat.ErrChatAlreadyExists { + err = data.PutChatV1(ctx, chatRecord) + if err != nil && err != chat_v1.ErrChatAlreadyExists { return false, err } default: @@ -88,7 +88,7 @@ func SendChatMessage( } if canPersistMessage { - messageRecord := &chat.Message{ + messageRecord := &chat_v1.Message{ ChatId: chatId, MessageId: base58.Encode(messageId), @@ -100,7 +100,7 @@ func SendChatMessage( Timestamp: ts.AsTime(), } - err = data.PutChatMessage(ctx, messageRecord) + err = data.PutChatMessageV1(ctx, messageRecord) if err != nil { return false, err } diff --git a/pkg/code/chat/sender_test.go b/pkg/code/chat/sender_test.go index 7625a1f3..ac767fd9 100644 --- a/pkg/code/chat/sender_test.go +++ b/pkg/code/chat/sender_test.go @@ -17,7 +17,7 @@ import ( "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" "github.com/code-payments/code-server/pkg/code/data/badgecount" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/testutil" ) @@ -26,14 +26,14 @@ func TestSendChatMessage_HappyPath(t *testing.T) { chatTitle := CodeTeamName receiver := testutil.NewRandomAccount(t) - chatId := chat.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), true) + chatId := chat_v1.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), true) var expectedBadgeCount int for i := 0; i < 10; i++ { chatMessage := newRandomChatMessage(t, i+1) expectedBadgeCount += 1 - canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat.ChatTypeInternal, true, receiver, chatMessage, false) + canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat_v1.ChatTypeInternal, true, receiver, chatMessage, false) require.NoError(t, err) assert.True(t, canPush) @@ -56,7 +56,7 @@ func TestSendChatMessage_VerifiedChat(t *testing.T) { for _, isVerified := range []bool{true, false} { chatMessage := newRandomChatMessage(t, 1) - _, err := SendChatMessage(env.ctx, env.data, chatTitle, chat.ChatTypeInternal, isVerified, receiver, chatMessage, true) + _, err := SendChatMessage(env.ctx, env.data, chatTitle, chat_v1.ChatTypeInternal, isVerified, receiver, chatMessage, true) require.NoError(t, err) env.assertChatRecordSaved(t, chatTitle, receiver, isVerified) } @@ -67,11 +67,11 @@ func TestSendChatMessage_SilentMessage(t *testing.T) { chatTitle := CodeTeamName receiver := testutil.NewRandomAccount(t) - chatId := chat.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), true) + chatId := chat_v1.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), true) for i, isSilent := range []bool{true, false} { chatMessage := newRandomChatMessage(t, 1) - canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat.ChatTypeInternal, true, receiver, chatMessage, isSilent) + canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat_v1.ChatTypeInternal, true, receiver, chatMessage, isSilent) require.NoError(t, err) assert.Equal(t, !isSilent, canPush) env.assertChatMessageRecordSaved(t, chatId, chatMessage, isSilent) @@ -84,7 +84,7 @@ func TestSendChatMessage_MuteState(t *testing.T) { chatTitle := CodeTeamName receiver := testutil.NewRandomAccount(t) - chatId := chat.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), true) + chatId := chat_v1.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), true) for _, isMuted := range []bool{false, true} { if isMuted { @@ -92,7 +92,7 @@ func TestSendChatMessage_MuteState(t *testing.T) { } chatMessage := newRandomChatMessage(t, 1) - canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat.ChatTypeInternal, true, receiver, chatMessage, false) + canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat_v1.ChatTypeInternal, true, receiver, chatMessage, false) require.NoError(t, err) assert.Equal(t, !isMuted, canPush) env.assertChatMessageRecordSaved(t, chatId, chatMessage, false) @@ -105,7 +105,7 @@ func TestSendChatMessage_SubscriptionState(t *testing.T) { chatTitle := CodeTeamName receiver := testutil.NewRandomAccount(t) - chatId := chat.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), true) + chatId := chat_v1.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), true) for _, isUnsubscribed := range []bool{false, true} { if isUnsubscribed { @@ -113,7 +113,7 @@ func TestSendChatMessage_SubscriptionState(t *testing.T) { } chatMessage := newRandomChatMessage(t, 1) - canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat.ChatTypeInternal, true, receiver, chatMessage, false) + canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat_v1.ChatTypeInternal, true, receiver, chatMessage, false) require.NoError(t, err) assert.Equal(t, !isUnsubscribed, canPush) if isUnsubscribed { @@ -130,12 +130,12 @@ func TestSendChatMessage_InvalidProtoMessage(t *testing.T) { chatTitle := CodeTeamName receiver := testutil.NewRandomAccount(t) - chatId := chat.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), true) + chatId := chat_v1.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), true) chatMessage := newRandomChatMessage(t, 1) chatMessage.Content = nil - canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat.ChatTypeInternal, true, receiver, chatMessage, false) + canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat_v1.ChatTypeInternal, true, receiver, chatMessage, false) assert.Error(t, err) assert.False(t, canPush) env.assertChatRecordNotSaved(t, chatId) @@ -159,8 +159,8 @@ func newRandomChatMessage(t *testing.T, contentLength int) *chatpb.ChatMessage { var content []*chatpb.Content for i := 0; i < contentLength; i++ { content = append(content, &chatpb.Content{ - Type: &chatpb.Content_Localized{ - Localized: &chatpb.LocalizedContent{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ KeyOrText: fmt.Sprintf("key%d", rand.Uint32()), }, }, @@ -173,13 +173,13 @@ func newRandomChatMessage(t *testing.T, contentLength int) *chatpb.ChatMessage { } func (e *testEnv) assertChatRecordSaved(t *testing.T, chatTitle string, receiver *common.Account, isVerified bool) { - chatId := chat.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), isVerified) + chatId := chat_v1.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), isVerified) - chatRecord, err := e.data.GetChatById(e.ctx, chatId) + chatRecord, err := e.data.GetChatByIdV1(e.ctx, chatId) require.NoError(t, err) assert.Equal(t, chatId[:], chatRecord.ChatId[:]) - assert.Equal(t, chat.ChatTypeInternal, chatRecord.ChatType) + assert.Equal(t, chat_v1.ChatTypeInternal, chatRecord.ChatType) assert.Equal(t, chatTitle, chatRecord.ChatTitle) assert.Equal(t, isVerified, chatRecord.IsVerified) assert.Equal(t, receiver.PublicKey().ToBase58(), chatRecord.CodeUser) @@ -188,8 +188,8 @@ func (e *testEnv) assertChatRecordSaved(t *testing.T, chatTitle string, receiver assert.False(t, chatRecord.IsUnsubscribed) } -func (e *testEnv) assertChatMessageRecordSaved(t *testing.T, chatId chat.ChatId, protoMessage *chatpb.ChatMessage, isSilent bool) { - messageRecord, err := e.data.GetChatMessage(e.ctx, chatId, base58.Encode(protoMessage.GetMessageId().Value)) +func (e *testEnv) assertChatMessageRecordSaved(t *testing.T, chatId chat_v1.ChatId, protoMessage *chatpb.ChatMessage, isSilent bool) { + messageRecord, err := e.data.GetChatMessageV1(e.ctx, chatId, base58.Encode(protoMessage.GetMessageId().Value)) require.NoError(t, err) cloned := proto.Clone(protoMessage).(*chatpb.ChatMessage) @@ -218,22 +218,22 @@ func (e *testEnv) assertBadgeCount(t *testing.T, owner *common.Account, expected assert.EqualValues(t, expected, badgeCountRecord.BadgeCount) } -func (e *testEnv) assertChatRecordNotSaved(t *testing.T, chatId chat.ChatId) { - _, err := e.data.GetChatById(e.ctx, chatId) - assert.Equal(t, chat.ErrChatNotFound, err) +func (e *testEnv) assertChatRecordNotSaved(t *testing.T, chatId chat_v1.ChatId) { + _, err := e.data.GetChatByIdV1(e.ctx, chatId) + assert.Equal(t, chat_v1.ErrChatNotFound, err) } -func (e *testEnv) assertChatMessageRecordNotSaved(t *testing.T, chatId chat.ChatId, messageId *chatpb.ChatMessageId) { - _, err := e.data.GetChatMessage(e.ctx, chatId, base58.Encode(messageId.Value)) - assert.Equal(t, chat.ErrMessageNotFound, err) +func (e *testEnv) assertChatMessageRecordNotSaved(t *testing.T, chatId chat_v1.ChatId, messageId *chatpb.ChatMessageId) { + _, err := e.data.GetChatMessageV1(e.ctx, chatId, base58.Encode(messageId.Value)) + assert.Equal(t, chat_v1.ErrMessageNotFound, err) } -func (e *testEnv) muteChat(t *testing.T, chatId chat.ChatId) { - require.NoError(t, e.data.SetChatMuteState(e.ctx, chatId, true)) +func (e *testEnv) muteChat(t *testing.T, chatId chat_v1.ChatId) { + require.NoError(t, e.data.SetChatMuteStateV1(e.ctx, chatId, true)) } -func (e *testEnv) unsubscribeFromChat(t *testing.T, chatId chat.ChatId) { - require.NoError(t, e.data.SetChatSubscriptionState(e.ctx, chatId, false)) +func (e *testEnv) unsubscribeFromChat(t *testing.T, chatId chat_v1.ChatId) { + require.NoError(t, e.data.SetChatSubscriptionStateV1(e.ctx, chatId, false)) } diff --git a/pkg/code/data/chat/memory/store.go b/pkg/code/data/chat/v1/memory/store.go similarity index 99% rename from pkg/code/data/chat/memory/store.go rename to pkg/code/data/chat/v1/memory/store.go index 1b869007..1d923071 100644 --- a/pkg/code/data/chat/memory/store.go +++ b/pkg/code/data/chat/v1/memory/store.go @@ -7,8 +7,8 @@ import ( "sync" "time" + chat "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/database/query" - "github.com/code-payments/code-server/pkg/code/data/chat" ) type ChatsById []*chat.Chat diff --git a/pkg/code/data/chat/memory/store_test.go b/pkg/code/data/chat/v1/memory/store_test.go similarity index 74% rename from pkg/code/data/chat/memory/store_test.go rename to pkg/code/data/chat/v1/memory/store_test.go index 5d2c18a5..c27859e6 100644 --- a/pkg/code/data/chat/memory/store_test.go +++ b/pkg/code/data/chat/v1/memory/store_test.go @@ -3,7 +3,7 @@ package memory import ( "testing" - "github.com/code-payments/code-server/pkg/code/data/chat/tests" + "github.com/code-payments/code-server/pkg/code/data/chat/v1/tests" ) func TestChatMemoryStore(t *testing.T) { diff --git a/pkg/code/data/chat/model.go b/pkg/code/data/chat/v1/model.go similarity index 99% rename from pkg/code/data/chat/model.go rename to pkg/code/data/chat/v1/model.go index d8fe7432..4d156996 100644 --- a/pkg/code/data/chat/model.go +++ b/pkg/code/data/chat/v1/model.go @@ -1,4 +1,4 @@ -package chat +package chat_v1 import ( "bytes" diff --git a/pkg/code/data/chat/model_test.go b/pkg/code/data/chat/v1/model_test.go similarity index 97% rename from pkg/code/data/chat/model_test.go rename to pkg/code/data/chat/v1/model_test.go index 7774d286..062f372b 100644 --- a/pkg/code/data/chat/model_test.go +++ b/pkg/code/data/chat/v1/model_test.go @@ -1,4 +1,4 @@ -package chat +package chat_v1 import ( "testing" diff --git a/pkg/code/data/chat/postgres/model.go b/pkg/code/data/chat/v1/postgres/model.go similarity index 99% rename from pkg/code/data/chat/postgres/model.go rename to pkg/code/data/chat/v1/postgres/model.go index 07158095..8987df2b 100644 --- a/pkg/code/data/chat/postgres/model.go +++ b/pkg/code/data/chat/v1/postgres/model.go @@ -8,10 +8,10 @@ import ( "github.com/jmoiron/sqlx" + chat "github.com/code-payments/code-server/pkg/code/data/chat/v1" pgutil "github.com/code-payments/code-server/pkg/database/postgres" q "github.com/code-payments/code-server/pkg/database/query" "github.com/code-payments/code-server/pkg/pointer" - "github.com/code-payments/code-server/pkg/code/data/chat" ) const ( diff --git a/pkg/code/data/chat/postgres/store.go b/pkg/code/data/chat/v1/postgres/store.go similarity index 98% rename from pkg/code/data/chat/postgres/store.go rename to pkg/code/data/chat/v1/postgres/store.go index 943a1935..bfb2b14f 100644 --- a/pkg/code/data/chat/postgres/store.go +++ b/pkg/code/data/chat/v1/postgres/store.go @@ -6,8 +6,8 @@ import ( "github.com/jmoiron/sqlx" + chat "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/database/query" - "github.com/code-payments/code-server/pkg/code/data/chat" ) type store struct { diff --git a/pkg/code/data/chat/postgres/store_test.go b/pkg/code/data/chat/v1/postgres/store_test.go similarity index 95% rename from pkg/code/data/chat/postgres/store_test.go rename to pkg/code/data/chat/v1/postgres/store_test.go index 49143ad7..4d72fc4e 100644 --- a/pkg/code/data/chat/postgres/store_test.go +++ b/pkg/code/data/chat/v1/postgres/store_test.go @@ -8,8 +8,8 @@ import ( "github.com/ory/dockertest/v3" "github.com/sirupsen/logrus" - "github.com/code-payments/code-server/pkg/code/data/chat" - "github.com/code-payments/code-server/pkg/code/data/chat/tests" + chat "github.com/code-payments/code-server/pkg/code/data/chat/v1" + "github.com/code-payments/code-server/pkg/code/data/chat/v1/tests" postgrestest "github.com/code-payments/code-server/pkg/database/postgres/test" diff --git a/pkg/code/data/chat/store.go b/pkg/code/data/chat/v1/store.go similarity index 99% rename from pkg/code/data/chat/store.go rename to pkg/code/data/chat/v1/store.go index 2e79a228..c21471b5 100644 --- a/pkg/code/data/chat/store.go +++ b/pkg/code/data/chat/v1/store.go @@ -1,4 +1,4 @@ -package chat +package chat_v1 import ( "context" diff --git a/pkg/code/data/chat/tests/tests.go b/pkg/code/data/chat/v1/tests/tests.go similarity index 99% rename from pkg/code/data/chat/tests/tests.go rename to pkg/code/data/chat/v1/tests/tests.go index f9eaaadf..1453c133 100644 --- a/pkg/code/data/chat/tests/tests.go +++ b/pkg/code/data/chat/v1/tests/tests.go @@ -10,9 +10,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + chat "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/database/query" "github.com/code-payments/code-server/pkg/pointer" - "github.com/code-payments/code-server/pkg/code/data/chat" ) func RunTests(t *testing.T, s chat.Store, teardown func()) { diff --git a/pkg/code/data/chat/v2/id.go b/pkg/code/data/chat/v2/id.go new file mode 100644 index 00000000..de912e48 --- /dev/null +++ b/pkg/code/data/chat/v2/id.go @@ -0,0 +1,275 @@ +package chat_v2 + +import ( + "bytes" + "encoding/hex" + "time" + + "github.com/google/uuid" + "github.com/pkg/errors" + + chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v2" +) + +type ChatId [32]byte + +// GetChatIdFromBytes gets a chat ID from a byte buffer +func GetChatIdFromBytes(buffer []byte) (ChatId, error) { + if len(buffer) != 32 { + return ChatId{}, errors.New("chat id must be 32 bytes in length") + } + + var typed ChatId + copy(typed[:], buffer[:]) + + if err := typed.Validate(); err != nil { + return ChatId{}, errors.Wrap(err, "invalid chat id") + } + + return typed, nil +} + +// GetChatIdFromBytes gets a chat ID from the string representation +func GetChatIdFromString(value string) (ChatId, error) { + decoded, err := hex.DecodeString(value) + if err != nil { + return ChatId{}, errors.Wrap(err, "value is not a hexadecimal string") + } + + return GetChatIdFromBytes(decoded) +} + +// GetChatIdFromProto gets a chat ID from the protobuf variant +func GetChatIdFromProto(proto *chatpb.ChatId) (ChatId, error) { + if err := proto.Validate(); err != nil { + return ChatId{}, errors.Wrap(err, "proto validation failed") + } + + return GetChatIdFromBytes(proto.Value) +} + +// ToProto converts a chat ID to its protobuf variant +func (c ChatId) ToProto() *chatpb.ChatId { + return &chatpb.ChatId{Value: c[:]} +} + +// Validate validates a chat ID +func (c ChatId) Validate() error { + return nil +} + +// Clone clones a chat ID +func (c ChatId) Clone() ChatId { + var cloned ChatId + copy(cloned[:], c[:]) + return cloned +} + +// String returns the string representation of a ChatId +func (c ChatId) String() string { + return hex.EncodeToString(c[:]) +} + +// Random UUIDv4 ID for chat members +type MemberId uuid.UUID + +// GenerateMemberId generates a new random chat member ID +func GenerateMemberId() MemberId { + return MemberId(uuid.New()) +} + +// GetMemberIdFromBytes gets a member ID from a byte buffer +func GetMemberIdFromBytes(buffer []byte) (MemberId, error) { + if len(buffer) != 16 { + return MemberId{}, errors.New("member id must be 16 bytes in length") + } + + var typed MemberId + copy(typed[:], buffer[:]) + + if err := typed.Validate(); err != nil { + return MemberId{}, errors.Wrap(err, "invalid member id") + } + + return typed, nil +} + +// GetMemberIdFromString gets a chat member ID from the string representation +func GetMemberIdFromString(value string) (MemberId, error) { + decoded, err := uuid.Parse(value) + if err != nil { + return MemberId{}, errors.Wrap(err, "value is not a uuid string") + } + + return GetMemberIdFromBytes(decoded[:]) +} + +// GetMemberIdFromProto gets a member ID from the protobuf variant +func GetMemberIdFromProto(proto *chatpb.ChatMemberId) (MemberId, error) { + if err := proto.Validate(); err != nil { + return MemberId{}, errors.Wrap(err, "proto validation failed") + } + + return GetMemberIdFromBytes(proto.Value) +} + +// ToProto converts a message ID to its protobuf variant +func (m MemberId) ToProto() *chatpb.ChatMemberId { + return &chatpb.ChatMemberId{Value: m[:]} +} + +// Validate validates a chat member ID +func (m MemberId) Validate() error { + casted := uuid.UUID(m) + + if casted.Version() != 4 { + return errors.Errorf("invalid uuid version: %s", casted.Version().String()) + } + + return nil +} + +// Clone clones a chat member ID +func (m MemberId) Clone() MemberId { + var cloned MemberId + copy(cloned[:], m[:]) + return cloned +} + +// String returns the string representation of a MemberId +func (m MemberId) String() string { + return uuid.UUID(m).String() +} + +// Time-based UUIDv7 ID for chat messages +type MessageId uuid.UUID + +// GenerateMessageId generates a UUIDv7 message ID using the current time +func GenerateMessageId() MessageId { + return GenerateMessageIdAtTime(time.Now()) +} + +// GenerateMessageIdAtTime generates a UUIDv7 message ID using the provided timestamp +func GenerateMessageIdAtTime(ts time.Time) MessageId { + // Convert timestamp to milliseconds since Unix epoch + millis := ts.UnixNano() / int64(time.Millisecond) + + // Create a byte slice to hold the UUID + var uuidBytes [16]byte + + // Populate the first 6 bytes with the timestamp (42 bits for timestamp) + uuidBytes[0] = byte((millis >> 40) & 0xff) + uuidBytes[1] = byte((millis >> 32) & 0xff) + uuidBytes[2] = byte((millis >> 24) & 0xff) + uuidBytes[3] = byte((millis >> 16) & 0xff) + uuidBytes[4] = byte((millis >> 8) & 0xff) + uuidBytes[5] = byte(millis & 0xff) + + // Set the version to 7 (UUIDv7) + uuidBytes[6] = (uuidBytes[6] & 0x0f) | (0x7 << 4) + + // Populate the remaining bytes with random values + randomUUID := uuid.New() + copy(uuidBytes[7:], randomUUID[7:]) + + return MessageId(uuidBytes) +} + +// GetMessageIdFromBytes gets a message ID from a byte buffer +func GetMessageIdFromBytes(buffer []byte) (MessageId, error) { + if len(buffer) != 16 { + return MessageId{}, errors.New("message id must be 16 bytes in length") + } + + var typed MessageId + copy(typed[:], buffer[:]) + + if err := typed.Validate(); err != nil { + return MessageId{}, errors.Wrap(err, "invalid message id") + } + + return typed, nil +} + +// GetMessageIdFromString gets a chat message ID from the string representation +func GetMessageIdFromString(value string) (MessageId, error) { + decoded, err := uuid.Parse(value) + if err != nil { + return MessageId{}, errors.Wrap(err, "value is not a uuid string") + } + + return GetMessageIdFromBytes(decoded[:]) +} + +// GetMessageIdFromProto gets a message ID from the protobuf variant +func GetMessageIdFromProto(proto *chatpb.ChatMessageId) (MessageId, error) { + if err := proto.Validate(); err != nil { + return MessageId{}, errors.Wrap(err, "proto validation failed") + } + + return GetMessageIdFromBytes(proto.Value) +} + +// ToProto converts a message ID to its protobuf variant +func (m MessageId) ToProto() *chatpb.ChatMessageId { + return &chatpb.ChatMessageId{Value: m[:]} +} + +// GetTimestamp gets the encoded timestamp in the message ID +func (m MessageId) GetTimestamp() (time.Time, error) { + if err := m.Validate(); err != nil { + return time.Time{}, errors.Wrap(err, "invalid message id") + } + + // Extract the first 6 bytes as the timestamp + millis := (int64(m[0]) << 40) | (int64(m[1]) << 32) | (int64(m[2]) << 24) | + (int64(m[3]) << 16) | (int64(m[4]) << 8) | int64(m[5]) + + // Convert milliseconds since Unix epoch to time.Time + timestamp := time.Unix(0, millis*int64(time.Millisecond)) + + return timestamp, nil +} + +// Equals returns whether two message IDs are equal +func (m MessageId) Equals(other MessageId) bool { + return m.Compare(other) == 0 +} + +// Before returns whether the message ID is before the provided value +func (m MessageId) Before(other MessageId) bool { + return m.Compare(other) < 0 +} + +// Before returns whether the message ID is after the provided value +func (m MessageId) After(other MessageId) bool { + return m.Compare(other) > 0 +} + +// Compare returns the byte comparison of the message ID +func (m MessageId) Compare(other MessageId) int { + return bytes.Compare(m[:], other[:]) +} + +// Validate validates a message ID +func (m MessageId) Validate() error { + casted := uuid.UUID(m) + + if casted.Version() != 7 { + return errors.Errorf("invalid uuid version: %s", casted.Version().String()) + } + + return nil +} + +// Clone clones a chat message ID +func (m MessageId) Clone() MessageId { + var cloned MessageId + copy(cloned[:], m[:]) + return cloned +} + +// String returns the string representation of a MessageId +func (m MessageId) String() string { + return uuid.UUID(m).String() +} diff --git a/pkg/code/data/chat/v2/id_test.go b/pkg/code/data/chat/v2/id_test.go new file mode 100644 index 00000000..27f1f359 --- /dev/null +++ b/pkg/code/data/chat/v2/id_test.go @@ -0,0 +1,55 @@ +package chat_v2 + +import ( + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGenerateMemberId_Validation(t *testing.T) { + valid := GenerateMemberId() + assert.NoError(t, valid.Validate()) + + invalid := MemberId(GenerateMessageId()) + assert.Error(t, invalid.Validate()) +} + +func TestGenerateMessageId_Validation(t *testing.T) { + valid := GenerateMessageId() + assert.NoError(t, valid.Validate()) + + invalid := MessageId(uuid.New()) + assert.Error(t, invalid.Validate()) +} + +func TestGenerateMessageId_TimestampExtraction(t *testing.T) { + expectedTs := time.Now() + + messageId := GenerateMessageIdAtTime(expectedTs) + actualTs, err := messageId.GetTimestamp() + require.NoError(t, err) + assert.Equal(t, expectedTs.UnixMilli(), actualTs.UnixMilli()) +} + +func TestGenerateMessageId_Ordering(t *testing.T) { + now := time.Now() + messageIds := make([]MessageId, 0) + for i := 0; i < 10; i++ { + messageId := GenerateMessageIdAtTime(now.Add(time.Duration(i * int(time.Millisecond)))) + messageIds = append(messageIds, messageId) + } + + for i := 0; i < len(messageIds)-1; i++ { + assert.True(t, messageIds[i].Equals(messageIds[i])) + assert.False(t, messageIds[i].Equals(messageIds[i+1])) + + assert.True(t, messageIds[i].Before(messageIds[i+1])) + assert.False(t, messageIds[i].After(messageIds[i+1])) + + assert.False(t, messageIds[i+1].Before(messageIds[i])) + assert.True(t, messageIds[i+1].After(messageIds[i])) + } +} diff --git a/pkg/code/data/chat/v2/memory/store.go b/pkg/code/data/chat/v2/memory/store.go new file mode 100644 index 00000000..ff54f9ac --- /dev/null +++ b/pkg/code/data/chat/v2/memory/store.go @@ -0,0 +1,541 @@ +package memory + +import ( + "bytes" + "context" + "sort" + "sync" + "time" + + "github.com/pkg/errors" + + chat "github.com/code-payments/code-server/pkg/code/data/chat/v2" + "github.com/code-payments/code-server/pkg/database/query" +) + +// todo: finish implementing me +type store struct { + mu sync.Mutex + + chatRecords []*chat.ChatRecord + memberRecords []*chat.MemberRecord + messageRecords []*chat.MessageRecord + + lastChatId int64 + lastMemberId int64 + lastMessageId int64 +} + +// New returns a new in memory chat.Store +func New() chat.Store { + return &store{} +} + +// GetChatById implements chat.Store.GetChatById +func (s *store) GetChatById(_ context.Context, chatId chat.ChatId) (*chat.ChatRecord, error) { + s.mu.Lock() + defer s.mu.Unlock() + + item := s.findChatById(chatId) + if item == nil { + return nil, chat.ErrChatNotFound + } + + cloned := item.Clone() + return &cloned, nil +} + +// GetMemberById implements chat.Store.GetMemberById +func (s *store) GetMemberById(_ context.Context, chatId chat.ChatId, memberId chat.MemberId) (*chat.MemberRecord, error) { + s.mu.Lock() + defer s.mu.Unlock() + + item := s.findMemberById(chatId, memberId) + if item == nil { + return nil, chat.ErrMemberNotFound + } + + cloned := item.Clone() + return &cloned, nil +} + +// GetMessageById implements chat.Store.GetMessageById +func (s *store) GetMessageById(_ context.Context, chatId chat.ChatId, messageId chat.MessageId) (*chat.MessageRecord, error) { + s.mu.Lock() + defer s.mu.Unlock() + + item := s.findMessageById(chatId, messageId) + if item == nil { + return nil, chat.ErrMessageNotFound + } + + cloned := item.Clone() + return &cloned, nil +} + +// GetAllMembersByChatId implements chat.Store.GetAllMembersByChatId +func (s *store) GetAllMembersByChatId(_ context.Context, chatId chat.ChatId) ([]*chat.MemberRecord, error) { + items := s.findMembersByChatId(chatId) + if len(items) == 0 { + return nil, chat.ErrMemberNotFound + } + return cloneMemberRecords(items), nil +} + +// GetAllMembersByPlatformIds implements chat.store.GetAllMembersByPlatformIds +func (s *store) GetAllMembersByPlatformIds(_ context.Context, idByPlatform map[chat.Platform]string, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*chat.MemberRecord, error) { + s.mu.Lock() + defer s.mu.Unlock() + + items := s.findMembersByPlatformIds(idByPlatform) + items, err := s.getMemberRecordPage(items, cursor, direction, limit) + if err != nil { + return nil, err + } + + if len(items) == 0 { + return nil, chat.ErrMemberNotFound + } + return cloneMemberRecords(items), nil +} + +// GetUnreadCount implements chat.store.GetUnreadCount +func (s *store) GetUnreadCount(_ context.Context, chatId chat.ChatId, memberId chat.MemberId, readPointer chat.MessageId) (uint32, error) { + s.mu.Lock() + defer s.mu.Unlock() + + items := s.findMessagesByChatId(chatId) + items = s.filterMessagesAfter(items, readPointer) + items = s.filterMessagesNotSentBy(items, memberId) + items = s.filterNotifiedMessages(items) + return uint32(len(items)), nil +} + +// GetAllMessagesByChatId implements chat.Store.GetAllMessagesByChatId +func (s *store) GetAllMessagesByChatId(_ context.Context, chatId chat.ChatId, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*chat.MessageRecord, error) { + s.mu.Lock() + defer s.mu.Unlock() + + items := s.findMessagesByChatId(chatId) + items, err := s.getMessageRecordPage(items, cursor, direction, limit) + if err != nil { + return nil, err + } + + if len(items) == 0 { + return nil, chat.ErrMessageNotFound + } + return cloneMessageRecords(items), nil +} + +// PutChat creates a new chat +func (s *store) PutChat(_ context.Context, record *chat.ChatRecord) error { + if err := record.Validate(); err != nil { + return err + } + + s.mu.Lock() + defer s.mu.Unlock() + + s.lastChatId++ + + if item := s.findChat(record); item != nil { + return chat.ErrChatExists + } + + record.Id = s.lastChatId + if record.CreatedAt.IsZero() { + record.CreatedAt = time.Now() + } + + cloned := record.Clone() + s.chatRecords = append(s.chatRecords, &cloned) + + return nil +} + +// PutMember creates a new chat member +func (s *store) PutMember(_ context.Context, record *chat.MemberRecord) error { + if err := record.Validate(); err != nil { + return err + } + + s.mu.Lock() + defer s.mu.Unlock() + + s.lastMemberId++ + + if item := s.findMember(record); item != nil { + return chat.ErrMemberExists + } + + record.Id = s.lastMemberId + if record.JoinedAt.IsZero() { + record.JoinedAt = time.Now() + } + + cloned := record.Clone() + s.memberRecords = append(s.memberRecords, &cloned) + + return nil +} + +// PutMessage implements chat.Store.PutMessage +func (s *store) PutMessage(_ context.Context, record *chat.MessageRecord) error { + if err := record.Validate(); err != nil { + return err + } + + s.mu.Lock() + defer s.mu.Unlock() + + s.lastMessageId++ + + if item := s.findMessage(record); item != nil { + return chat.ErrMessageExsits + } + + record.Id = s.lastMessageId + + cloned := record.Clone() + s.messageRecords = append(s.messageRecords, &cloned) + + return nil +} + +// AdvancePointer implements chat.Store.AdvancePointer +func (s *store) AdvancePointer(_ context.Context, chatId chat.ChatId, memberId chat.MemberId, pointerType chat.PointerType, pointer chat.MessageId) (bool, error) { + switch pointerType { + case chat.PointerTypeDelivered, chat.PointerTypeRead: + default: + return false, chat.ErrInvalidPointerType + } + + s.mu.Lock() + defer s.mu.Unlock() + + item := s.findMemberById(chatId, memberId) + if item == nil { + return false, chat.ErrMemberNotFound + } + + var currentPointer *chat.MessageId + switch pointerType { + case chat.PointerTypeDelivered: + currentPointer = item.DeliveryPointer + case chat.PointerTypeRead: + currentPointer = item.ReadPointer + } + + if currentPointer == nil || currentPointer.Before(pointer) { + switch pointerType { + case chat.PointerTypeDelivered: + cloned := pointer.Clone() + item.DeliveryPointer = &cloned + case chat.PointerTypeRead: + cloned := pointer.Clone() + item.ReadPointer = &cloned + } + + return true, nil + } + return false, nil +} + +// UpgradeIdentity implements chat.Store.UpgradeIdentity +func (s *store) UpgradeIdentity(_ context.Context, chatId chat.ChatId, memberId chat.MemberId, platform chat.Platform, platformId string) error { + switch platform { + case chat.PlatformTwitter: + default: + return errors.Errorf("platform not supported for identity upgrades: %s", platform.String()) + } + + s.mu.Lock() + defer s.mu.Unlock() + + item := s.findMemberById(chatId, memberId) + if item == nil { + return chat.ErrMemberNotFound + } + if item.Platform != chat.PlatformCode { + return chat.ErrMemberIdentityAlreadyUpgraded + } + + item.Platform = platform + item.PlatformId = platformId + + return nil +} + +// SetMuteState implements chat.Store.SetMuteState +func (s *store) SetMuteState(_ context.Context, chatId chat.ChatId, memberId chat.MemberId, isMuted bool) error { + s.mu.Lock() + defer s.mu.Unlock() + + item := s.findMemberById(chatId, memberId) + if item == nil { + return chat.ErrMemberNotFound + } + + item.IsMuted = isMuted + + return nil +} + +// SetSubscriptionState implements chat.Store.SetSubscriptionState +func (s *store) SetSubscriptionState(_ context.Context, chatId chat.ChatId, memberId chat.MemberId, isSubscribed bool) error { + s.mu.Lock() + defer s.mu.Unlock() + + item := s.findMemberById(chatId, memberId) + if item == nil { + return chat.ErrMemberNotFound + } + + item.IsUnsubscribed = !isSubscribed + + return nil +} + +func (s *store) findChat(data *chat.ChatRecord) *chat.ChatRecord { + for _, item := range s.chatRecords { + if data.Id == item.Id { + return item + } + + if bytes.Equal(data.ChatId[:], item.ChatId[:]) { + return item + } + } + return nil +} + +func (s *store) findChatById(chatId chat.ChatId) *chat.ChatRecord { + for _, item := range s.chatRecords { + if bytes.Equal(chatId[:], item.ChatId[:]) { + return item + } + } + return nil +} + +func (s *store) findMember(data *chat.MemberRecord) *chat.MemberRecord { + for _, item := range s.memberRecords { + if data.Id == item.Id { + return item + } + + if bytes.Equal(data.ChatId[:], item.ChatId[:]) && bytes.Equal(data.MemberId[:], item.MemberId[:]) { + return item + } + } + return nil +} + +func (s *store) findMemberById(chatId chat.ChatId, memberId chat.MemberId) *chat.MemberRecord { + for _, item := range s.memberRecords { + if bytes.Equal(chatId[:], item.ChatId[:]) && bytes.Equal(memberId[:], item.MemberId[:]) { + return item + } + } + return nil +} + +func (s *store) findMembersByChatId(chatId chat.ChatId) []*chat.MemberRecord { + var res []*chat.MemberRecord + for _, item := range s.memberRecords { + if bytes.Equal(chatId[:], item.ChatId[:]) { + res = append(res, item) + } + } + return res +} + +func (s *store) findMembersByPlatformIds(idByPlatform map[chat.Platform]string) []*chat.MemberRecord { + var res []*chat.MemberRecord + for _, item := range s.memberRecords { + platformId, ok := idByPlatform[item.Platform] + if !ok { + continue + } + + if platformId == item.PlatformId { + res = append(res, item) + } + } + return res +} + +func (s *store) getMemberRecordPage(items []*chat.MemberRecord, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*chat.MemberRecord, error) { + if len(items) == 0 { + return nil, nil + } + + var memberIdCursor *uint64 + if len(cursor) > 0 { + cursorValue := query.FromCursor(cursor) + memberIdCursor = &cursorValue + } + + var res []*chat.MemberRecord + if memberIdCursor == nil { + res = items + } else { + for _, item := range items { + if item.Id > int64(*memberIdCursor) && direction == query.Ascending { + res = append(res, item) + } + + if item.Id < int64(*memberIdCursor) && direction == query.Descending { + res = append(res, item) + } + } + } + + if direction == query.Ascending { + sort.Sort(chat.MembersById(res)) + } else { + sort.Sort(sort.Reverse(chat.MembersById(res))) + } + + if len(res) >= int(limit) { + return res[:limit], nil + } + + return res, nil +} + +func (s *store) findMessage(data *chat.MessageRecord) *chat.MessageRecord { + for _, item := range s.messageRecords { + if data.Id == item.Id { + return item + } + + if bytes.Equal(data.ChatId[:], item.ChatId[:]) && bytes.Equal(data.MessageId[:], item.MessageId[:]) { + return item + } + } + return nil +} + +func (s *store) findMessageById(chatId chat.ChatId, messageId chat.MessageId) *chat.MessageRecord { + for _, item := range s.messageRecords { + if bytes.Equal(chatId[:], item.ChatId[:]) && bytes.Equal(messageId[:], item.MessageId[:]) { + return item + } + } + return nil +} + +func (s *store) findMessagesByChatId(chatId chat.ChatId) []*chat.MessageRecord { + var res []*chat.MessageRecord + for _, item := range s.messageRecords { + if bytes.Equal(chatId[:], item.ChatId[:]) { + res = append(res, item) + } + } + return res +} + +func (s *store) filterMessagesAfter(items []*chat.MessageRecord, pointer chat.MessageId) []*chat.MessageRecord { + var res []*chat.MessageRecord + for _, item := range items { + if item.MessageId.After(pointer) { + res = append(res, item) + } + } + return res +} + +func (s *store) filterMessagesNotSentBy(items []*chat.MessageRecord, sender chat.MemberId) []*chat.MessageRecord { + var res []*chat.MessageRecord + for _, item := range items { + if item.Sender == nil || !bytes.Equal(item.Sender[:], sender[:]) { + res = append(res, item) + } + } + return res +} + +func (s *store) filterNotifiedMessages(items []*chat.MessageRecord) []*chat.MessageRecord { + var res []*chat.MessageRecord + for _, item := range items { + if !item.IsSilent { + res = append(res, item) + } + } + return res +} + +func (s *store) getMessageRecordPage(items []*chat.MessageRecord, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*chat.MessageRecord, error) { + if len(items) == 0 { + return nil, nil + } + + var messageIdCursor *chat.MessageId + if len(cursor) > 0 { + messageId, err := chat.GetMessageIdFromBytes(cursor) + if err != nil { + return nil, err + } + messageIdCursor = &messageId + } + + var res []*chat.MessageRecord + if messageIdCursor == nil { + res = items + } else { + for _, item := range items { + if item.MessageId.After(*messageIdCursor) && direction == query.Ascending { + res = append(res, item) + } + + if item.MessageId.Before(*messageIdCursor) && direction == query.Descending { + res = append(res, item) + } + } + } + + if direction == query.Ascending { + sort.Sort(chat.MessagesByMessageId(res)) + } else { + sort.Sort(sort.Reverse(chat.MessagesByMessageId(res))) + } + + if len(res) >= int(limit) { + return res[:limit], nil + } + + return res, nil +} + +func (s *store) reset() { + s.mu.Lock() + defer s.mu.Unlock() + + s.chatRecords = nil + s.memberRecords = nil + s.messageRecords = nil + + s.lastChatId = 0 + s.lastMemberId = 0 + s.lastMessageId = 0 +} + +func cloneMemberRecords(items []*chat.MemberRecord) []*chat.MemberRecord { + res := make([]*chat.MemberRecord, len(items)) + for i, item := range items { + cloned := item.Clone() + res[i] = &cloned + } + return res +} + +func cloneMessageRecords(items []*chat.MessageRecord) []*chat.MessageRecord { + res := make([]*chat.MessageRecord, len(items)) + for i, item := range items { + cloned := item.Clone() + res[i] = &cloned + } + return res +} diff --git a/pkg/code/data/chat/v2/memory/store_test.go b/pkg/code/data/chat/v2/memory/store_test.go new file mode 100644 index 00000000..cd61dfa4 --- /dev/null +++ b/pkg/code/data/chat/v2/memory/store_test.go @@ -0,0 +1,15 @@ +package memory + +import ( + "testing" + + "github.com/code-payments/code-server/pkg/code/data/chat/v2/tests" +) + +func TestChatMemoryStore(t *testing.T) { + testStore := New() + teardown := func() { + testStore.(*store).reset() + } + tests.RunTests(t, testStore, teardown) +} diff --git a/pkg/code/data/chat/v2/model.go b/pkg/code/data/chat/v2/model.go new file mode 100644 index 00000000..ef3c7071 --- /dev/null +++ b/pkg/code/data/chat/v2/model.go @@ -0,0 +1,505 @@ +package chat_v2 + +import ( + "time" + + "github.com/mr-tron/base58" + "github.com/pkg/errors" + + "github.com/code-payments/code-server/pkg/pointer" + + chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v2" +) + +type ChatType uint8 + +const ( + ChatTypeUnknown ChatType = iota + ChatTypeNotification + ChatTypeTwoWay + // ChatTypeGroup +) + +type ReferenceType uint8 + +const ( + ReferenceTypeUnknown ReferenceType = iota + ReferenceTypeIntent + ReferenceTypeSignature +) + +type PointerType uint8 + +const ( + PointerTypeUnknown PointerType = iota + PointerTypeSent + PointerTypeDelivered + PointerTypeRead +) + +type Platform uint8 + +const ( + PlatformUnknown Platform = iota + PlatformCode + PlatformTwitter +) + +type ChatRecord struct { + Id int64 + ChatId ChatId + + ChatType ChatType + + // Presence determined by ChatType: + // * Notification: Present, and may be a localization key + // * Two Way: Not present and generated dynamically based on chat members + // * Group: Present, and will not be a localization key + ChatTitle *string + + IsVerified bool + + CreatedAt time.Time +} + +type MemberRecord struct { + Id int64 + ChatId ChatId + MemberId MemberId + + Platform Platform + PlatformId string + + DeliveryPointer *MessageId + ReadPointer *MessageId + + IsMuted bool + IsUnsubscribed bool + + JoinedAt time.Time +} + +type MessageRecord struct { + Id int64 + ChatId ChatId + MessageId MessageId + + // Not present for notification-style chats + Sender *MemberId + + Data []byte + + ReferenceType *ReferenceType + Reference *string + + IsSilent bool + + // Note: No timestamp field, since it's encoded in MessageId +} + +type MembersById []*MemberRecord + +func (a MembersById) Len() int { return len(a) } +func (a MembersById) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a MembersById) Less(i, j int) bool { + return a[i].Id < a[j].Id +} + +type MessagesByMessageId []*MessageRecord + +func (a MessagesByMessageId) Len() int { return len(a) } +func (a MessagesByMessageId) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a MessagesByMessageId) Less(i, j int) bool { + return a[i].MessageId.Before(a[j].MessageId) +} + +// GetChatTypeFromProto gets a chat type from the protobuf variant +func GetChatTypeFromProto(proto chatpb.ChatType) ChatType { + switch proto { + case chatpb.ChatType_NOTIFICATION: + return ChatTypeNotification + case chatpb.ChatType_TWO_WAY: + return ChatTypeTwoWay + default: + return ChatTypeUnknown + } +} + +// ToProto returns the proto representation of the chat type +func (c ChatType) ToProto() chatpb.ChatType { + switch c { + case ChatTypeNotification: + return chatpb.ChatType_NOTIFICATION + case ChatTypeTwoWay: + return chatpb.ChatType_TWO_WAY + default: + return chatpb.ChatType_UNKNOWN_CHAT_TYPE + } +} + +// String returns the string representation of the chat type +func (c ChatType) String() string { + switch c { + case ChatTypeNotification: + return "notification" + case ChatTypeTwoWay: + return "two-way" + default: + return "unknown" + } +} + +// GetPointerTypeFromProto gets a chat ID from the protobuf variant +func GetPointerTypeFromProto(proto chatpb.PointerType) PointerType { + switch proto { + case chatpb.PointerType_SENT: + return PointerTypeSent + case chatpb.PointerType_DELIVERED: + return PointerTypeDelivered + case chatpb.PointerType_READ: + return PointerTypeRead + default: + return PointerTypeUnknown + } +} + +// ToProto returns the proto representation of the pointer type +func (p PointerType) ToProto() chatpb.PointerType { + switch p { + case PointerTypeSent: + return chatpb.PointerType_SENT + case PointerTypeDelivered: + return chatpb.PointerType_DELIVERED + case PointerTypeRead: + return chatpb.PointerType_READ + default: + return chatpb.PointerType_UNKNOWN_POINTER_TYPE + } +} + +// String returns the string representation of the pointer type +func (p PointerType) String() string { + switch p { + case PointerTypeSent: + return "sent" + case PointerTypeDelivered: + return "delivered" + case PointerTypeRead: + return "read" + default: + return "unknown" + } +} + +// ToProto returns the proto representation of the platform +func GetPlatformFromProto(proto chatpb.Platform) Platform { + switch proto { + case chatpb.Platform_TWITTER: + return PlatformTwitter + default: + return PlatformUnknown + } +} + +// ToProto returns the proto representation of the platform +func (p Platform) ToProto() chatpb.Platform { + switch p { + case PlatformTwitter: + return chatpb.Platform_TWITTER + default: + return chatpb.Platform_UNKNOWN_PLATFORM + } +} + +// String returns the string representation of the platform +func (p Platform) String() string { + switch p { + case PlatformCode: + return "code" + case PlatformTwitter: + return "twitter" + default: + return "unknown" + } +} + +// Validate validates a chat Record +func (r *ChatRecord) Validate() error { + if err := r.ChatId.Validate(); err != nil { + return errors.Wrap(err, "invalid chat id") + } + + switch r.ChatType { + case ChatTypeNotification: + if r.ChatTitle == nil || len(*r.ChatTitle) == 0 { + return errors.New("chat title is required for notification chats") + } + case ChatTypeTwoWay: + if r.ChatTitle != nil { + return errors.New("chat title cannot be set for two way chats") + } + default: + return errors.Errorf("invalid chat type: %d", r.ChatType) + } + + if r.CreatedAt.IsZero() { + return errors.New("creation timestamp is required") + } + + return nil +} + +// Clone clones a chat record +func (r *ChatRecord) Clone() ChatRecord { + return ChatRecord{ + Id: r.Id, + ChatId: r.ChatId, + + ChatType: r.ChatType, + + ChatTitle: pointer.StringCopy(r.ChatTitle), + + IsVerified: r.IsVerified, + + CreatedAt: r.CreatedAt, + } +} + +// CopyTo copies a chat record to the provided destination +func (r *ChatRecord) CopyTo(dst *ChatRecord) { + dst.Id = r.Id + dst.ChatId = r.ChatId + + dst.ChatType = r.ChatType + + dst.ChatTitle = pointer.StringCopy(r.ChatTitle) + + dst.IsVerified = r.IsVerified + + dst.CreatedAt = r.CreatedAt +} + +// Validate validates a member Record +func (r *MemberRecord) Validate() error { + if err := r.ChatId.Validate(); err != nil { + return errors.Wrap(err, "invalid chat id") + } + + if err := r.MemberId.Validate(); err != nil { + return errors.Wrap(err, "invalid member id") + } + + if len(r.PlatformId) == 0 { + return errors.New("platform id is required") + } + + switch r.Platform { + case PlatformCode: + decoded, err := base58.Decode(r.PlatformId) + if err != nil { + return errors.Wrap(err, "invalid base58 plaftorm id") + } + + if len(decoded) != 32 { + return errors.Wrap(err, "platform id is not a 32 byte buffer") + } + case PlatformTwitter: + if len(r.PlatformId) > 15 { + return errors.New("platform id must have at most 15 characters") + } + default: + return errors.Errorf("invalid plaftorm: %d", r.Platform) + } + + if r.DeliveryPointer != nil { + if err := r.DeliveryPointer.Validate(); err != nil { + return errors.Wrap(err, "invalid delivery pointer") + } + } + + if r.ReadPointer != nil { + if err := r.ReadPointer.Validate(); err != nil { + return errors.Wrap(err, "invalid read pointer") + } + } + + if r.JoinedAt.IsZero() { + return errors.New("joined timestamp is required") + } + + return nil +} + +// Clone clones a member record +func (r *MemberRecord) Clone() MemberRecord { + var deliveryPointerCopy *MessageId + if r.DeliveryPointer != nil { + cloned := r.DeliveryPointer.Clone() + deliveryPointerCopy = &cloned + } + + var readPointerCopy *MessageId + if r.ReadPointer != nil { + cloned := r.ReadPointer.Clone() + readPointerCopy = &cloned + } + + return MemberRecord{ + Id: r.Id, + ChatId: r.ChatId, + MemberId: r.MemberId, + + Platform: r.Platform, + PlatformId: r.PlatformId, + + DeliveryPointer: deliveryPointerCopy, + ReadPointer: readPointerCopy, + + IsMuted: r.IsMuted, + IsUnsubscribed: r.IsUnsubscribed, + + JoinedAt: r.JoinedAt, + } +} + +// CopyTo copies a member record to the provided destination +func (r *MemberRecord) CopyTo(dst *MemberRecord) { + dst.Id = r.Id + dst.ChatId = r.ChatId + dst.MemberId = r.MemberId + + dst.Platform = r.Platform + dst.PlatformId = r.PlatformId + + if r.DeliveryPointer != nil { + cloned := r.DeliveryPointer.Clone() + dst.DeliveryPointer = &cloned + } + if r.ReadPointer != nil { + cloned := r.ReadPointer.Clone() + dst.ReadPointer = &cloned + } + + dst.IsMuted = r.IsMuted + dst.IsUnsubscribed = r.IsUnsubscribed + + dst.JoinedAt = r.JoinedAt +} + +// Validate validates a message Record +func (r *MessageRecord) Validate() error { + if err := r.ChatId.Validate(); err != nil { + return errors.Wrap(err, "invalid chat id") + } + + if err := r.MessageId.Validate(); err != nil { + return errors.Wrap(err, "invalid message id") + } + + if r.Sender != nil { + if err := r.Sender.Validate(); err != nil { + return errors.Wrap(err, "invalid sender id") + } + } + + if len(r.Data) == 0 { + return errors.New("message data is required") + } + + if r.Reference == nil && r.ReferenceType != nil { + return errors.New("reference is required when reference type is provided") + } + + if r.Reference != nil && r.ReferenceType == nil { + return errors.New("reference cannot be set when reference type is missing") + } + + if r.ReferenceType != nil { + switch *r.ReferenceType { + case ReferenceTypeIntent: + decoded, err := base58.Decode(*r.Reference) + if err != nil { + return errors.Wrap(err, "invalid base58 intent id reference") + } + + if len(decoded) != 32 { + return errors.Wrap(err, "reference is not a 32 byte buffer") + } + case ReferenceTypeSignature: + decoded, err := base58.Decode(*r.Reference) + if err != nil { + return errors.Wrap(err, "invalid base58 signature reference") + } + + if len(decoded) != 64 { + return errors.Wrap(err, "reference is not a 64 byte buffer") + } + default: + return errors.Errorf("invalid reference type: %d", *r.ReferenceType) + } + } + + return nil +} + +// Clone clones a message record +func (r *MessageRecord) Clone() MessageRecord { + var senderCopy *MemberId + if r.Sender != nil { + cloned := r.Sender.Clone() + senderCopy = &cloned + } + + dataCopy := make([]byte, len(r.Data)) + copy(dataCopy, r.Data) + + var referenceTypeCopy *ReferenceType + if r.ReferenceType != nil { + cloned := *r.ReferenceType + referenceTypeCopy = &cloned + } + + return MessageRecord{ + Id: r.Id, + ChatId: r.ChatId, + MessageId: r.MessageId, + + Sender: senderCopy, + + Data: dataCopy, + + ReferenceType: referenceTypeCopy, + Reference: pointer.StringCopy(r.Reference), + + IsSilent: r.IsSilent, + } +} + +// CopyTo copies a message record to the provided destination +func (r *MessageRecord) CopyTo(dst *MessageRecord) { + dst.Id = r.Id + dst.ChatId = r.ChatId + dst.MessageId = r.MessageId + + if r.Sender != nil { + cloned := r.Sender.Clone() + dst.Sender = &cloned + } + + dataCopy := make([]byte, len(r.Data)) + copy(dataCopy, r.Data) + dst.Data = dataCopy + + if r.ReferenceType != nil { + cloned := *r.ReferenceType + dst.ReferenceType = &cloned + } + dst.Reference = pointer.StringCopy(r.Reference) + + dst.IsSilent = r.IsSilent +} + +// GetTimestamp gets the timestamp for a message record +func (r *MessageRecord) GetTimestamp() (time.Time, error) { + return r.MessageId.GetTimestamp() +} diff --git a/pkg/code/data/chat/v2/store.go b/pkg/code/data/chat/v2/store.go new file mode 100644 index 00000000..a3fc4b43 --- /dev/null +++ b/pkg/code/data/chat/v2/store.go @@ -0,0 +1,68 @@ +package chat_v2 + +import ( + "context" + "errors" + + "github.com/code-payments/code-server/pkg/database/query" +) + +var ( + ErrChatExists = errors.New("chat already exists") + ErrChatNotFound = errors.New("chat not found") + ErrMemberExists = errors.New("chat member already exists") + ErrMemberNotFound = errors.New("chat member not found") + ErrMemberIdentityAlreadyUpgraded = errors.New("chat member identity already upgraded") + ErrMessageExsits = errors.New("chat message already exists") + ErrMessageNotFound = errors.New("chat message not found") + ErrInvalidPointerType = errors.New("invalid pointer type") +) + +// todo: Define interface methods +type Store interface { + // GetChatById gets a chat by its chat ID + GetChatById(ctx context.Context, chatId ChatId) (*ChatRecord, error) + + // GetMemberById gets a chat member by the chat and member IDs + GetMemberById(ctx context.Context, chatId ChatId, memberId MemberId) (*MemberRecord, error) + + // GetMessageById gets a chat message by the chat and message IDs + GetMessageById(ctx context.Context, chatId ChatId, messageId MessageId) (*MessageRecord, error) + + // GetAllMembersByChatId gets all members for a given chat + // + // todo: Add paging when we introduce group chats + GetAllMembersByChatId(ctx context.Context, chatId ChatId) ([]*MemberRecord, error) + + // GetAllMembersByPlatformIds gets all members for platform users across all chats + GetAllMembersByPlatformIds(ctx context.Context, idByPlatform map[Platform]string, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*MemberRecord, error) + + // GetAllMessagesByChatId gets all messages for a given chat + // + // Note: Cursor is a message ID + GetAllMessagesByChatId(ctx context.Context, chatId ChatId, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*MessageRecord, error) + + // GetUnreadCount gets the unread message count for a chat ID at a read pointer for a given chat member + GetUnreadCount(ctx context.Context, chatId ChatId, memberId MemberId, readPointer MessageId) (uint32, error) + + // PutChat creates a new chat + PutChat(ctx context.Context, record *ChatRecord) error + + // PutMember creates a new chat member + PutMember(ctx context.Context, record *MemberRecord) error + + // PutMessage creates a new chat message + PutMessage(ctx context.Context, record *MessageRecord) error + + // AdvancePointer advances a chat pointer for a chat member + AdvancePointer(ctx context.Context, chatId ChatId, memberId MemberId, pointerType PointerType, pointer MessageId) (bool, error) + + // UpgradeIdentity upgrades a chat member's identity from an anonymous state + UpgradeIdentity(ctx context.Context, chatId ChatId, memberId MemberId, platform Platform, platformId string) error + + // SetMuteState updates the mute state for a chat member + SetMuteState(ctx context.Context, chatId ChatId, memberId MemberId, isMuted bool) error + + // SetSubscriptionState updates the subscription state for a chat member + SetSubscriptionState(ctx context.Context, chatId ChatId, memberId MemberId, isSubscribed bool) error +} diff --git a/pkg/code/data/chat/v2/tests/tests.go b/pkg/code/data/chat/v2/tests/tests.go new file mode 100644 index 00000000..94c85a89 --- /dev/null +++ b/pkg/code/data/chat/v2/tests/tests.go @@ -0,0 +1,14 @@ +package tests + +import ( + "testing" + + chat "github.com/code-payments/code-server/pkg/code/data/chat/v2" +) + +func RunTests(t *testing.T, s chat.Store, teardown func()) { + for _, tf := range []func(t *testing.T, s chat.Store){} { + tf(t, s) + teardown() + } +} diff --git a/pkg/code/data/internal.go b/pkg/code/data/internal.go index 49479c84..c043a8f2 100644 --- a/pkg/code/data/internal.go +++ b/pkg/code/data/internal.go @@ -25,7 +25,8 @@ import ( "github.com/code-payments/code-server/pkg/code/data/airdrop" "github.com/code-payments/code-server/pkg/code/data/badgecount" "github.com/code-payments/code-server/pkg/code/data/balance" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" + chat_v2 "github.com/code-payments/code-server/pkg/code/data/chat/v2" "github.com/code-payments/code-server/pkg/code/data/commitment" "github.com/code-payments/code-server/pkg/code/data/contact" "github.com/code-payments/code-server/pkg/code/data/currency" @@ -59,7 +60,8 @@ import ( airdrop_memory_client "github.com/code-payments/code-server/pkg/code/data/airdrop/memory" badgecount_memory_client "github.com/code-payments/code-server/pkg/code/data/badgecount/memory" balance_memory_client "github.com/code-payments/code-server/pkg/code/data/balance/memory" - chat_memory_client "github.com/code-payments/code-server/pkg/code/data/chat/memory" + chat_v1_memory_client "github.com/code-payments/code-server/pkg/code/data/chat/v1/memory" + chat_v2_memory_client "github.com/code-payments/code-server/pkg/code/data/chat/v2/memory" commitment_memory_client "github.com/code-payments/code-server/pkg/code/data/commitment/memory" contact_memory_client "github.com/code-payments/code-server/pkg/code/data/contact/memory" currency_memory_client "github.com/code-payments/code-server/pkg/code/data/currency/memory" @@ -94,7 +96,7 @@ import ( airdrop_postgres_client "github.com/code-payments/code-server/pkg/code/data/airdrop/postgres" badgecount_postgres_client "github.com/code-payments/code-server/pkg/code/data/badgecount/postgres" balance_postgres_client "github.com/code-payments/code-server/pkg/code/data/balance/postgres" - chat_postgres_client "github.com/code-payments/code-server/pkg/code/data/chat/postgres" + chat_v1_postgres_client "github.com/code-payments/code-server/pkg/code/data/chat/v1/postgres" commitment_postgres_client "github.com/code-payments/code-server/pkg/code/data/commitment/postgres" contact_postgres_client "github.com/code-payments/code-server/pkg/code/data/contact/postgres" currency_postgres_client "github.com/code-payments/code-server/pkg/code/data/currency/postgres" @@ -378,19 +380,36 @@ type DatabaseData interface { CountWebhookByState(ctx context.Context, state webhook.State) (uint64, error) GetAllPendingWebhooksReadyToSend(ctx context.Context, limit uint64) ([]*webhook.Record, error) - // Chat + // Chat V1 // -------------------------------------------------------------------------------- - PutChat(ctx context.Context, record *chat.Chat) error - GetChatById(ctx context.Context, chatId chat.ChatId) (*chat.Chat, error) - GetAllChatsForUser(ctx context.Context, user string, opts ...query.Option) ([]*chat.Chat, error) - PutChatMessage(ctx context.Context, record *chat.Message) error - DeleteChatMessage(ctx context.Context, chatId chat.ChatId, messageId string) error - GetChatMessage(ctx context.Context, chatId chat.ChatId, messageId string) (*chat.Message, error) - GetAllChatMessages(ctx context.Context, chatId chat.ChatId, opts ...query.Option) ([]*chat.Message, error) - AdvanceChatPointer(ctx context.Context, chatId chat.ChatId, pointer string) error - GetChatUnreadCount(ctx context.Context, chatId chat.ChatId) (uint32, error) - SetChatMuteState(ctx context.Context, chatId chat.ChatId, isMuted bool) error - SetChatSubscriptionState(ctx context.Context, chatId chat.ChatId, isSubscribed bool) error + PutChatV1(ctx context.Context, record *chat_v1.Chat) error + GetChatByIdV1(ctx context.Context, chatId chat_v1.ChatId) (*chat_v1.Chat, error) + GetAllChatsForUserV1(ctx context.Context, user string, opts ...query.Option) ([]*chat_v1.Chat, error) + PutChatMessageV1(ctx context.Context, record *chat_v1.Message) error + DeleteChatMessageV1(ctx context.Context, chatId chat_v1.ChatId, messageId string) error + GetChatMessageV1(ctx context.Context, chatId chat_v1.ChatId, messageId string) (*chat_v1.Message, error) + GetAllChatMessagesV1(ctx context.Context, chatId chat_v1.ChatId, opts ...query.Option) ([]*chat_v1.Message, error) + AdvanceChatPointerV1(ctx context.Context, chatId chat_v1.ChatId, pointer string) error + GetChatUnreadCountV1(ctx context.Context, chatId chat_v1.ChatId) (uint32, error) + SetChatMuteStateV1(ctx context.Context, chatId chat_v1.ChatId, isMuted bool) error + SetChatSubscriptionStateV1(ctx context.Context, chatId chat_v1.ChatId, isSubscribed bool) error + + // Chat V2 + // -------------------------------------------------------------------------------- + GetChatByIdV2(ctx context.Context, chatId chat_v2.ChatId) (*chat_v2.ChatRecord, error) + GetChatMemberByIdV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId) (*chat_v2.MemberRecord, error) + GetChatMessageByIdV2(ctx context.Context, chatId chat_v2.ChatId, messageId chat_v2.MessageId) (*chat_v2.MessageRecord, error) + GetAllChatMembersV2(ctx context.Context, chatId chat_v2.ChatId) ([]*chat_v2.MemberRecord, error) + GetPlatformUserChatMembershipV2(ctx context.Context, idByPlatform map[chat_v2.Platform]string, opts ...query.Option) ([]*chat_v2.MemberRecord, error) + GetAllChatMessagesV2(ctx context.Context, chatId chat_v2.ChatId, opts ...query.Option) ([]*chat_v2.MessageRecord, error) + GetChatUnreadCountV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, readPointer chat_v2.MessageId) (uint32, error) + PutChatV2(ctx context.Context, record *chat_v2.ChatRecord) error + PutChatMemberV2(ctx context.Context, record *chat_v2.MemberRecord) error + PutChatMessageV2(ctx context.Context, record *chat_v2.MessageRecord) error + AdvanceChatPointerV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, pointerType chat_v2.PointerType, pointer chat_v2.MessageId) (bool, error) + UpgradeChatMemberIdentityV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, platform chat_v2.Platform, platformId string) error + SetChatMuteStateV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, isMuted bool) error + SetChatSubscriptionStateV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, isSubscribed bool) error // Badge Count // -------------------------------------------------------------------------------- @@ -470,7 +489,8 @@ type DatabaseProvider struct { paywall paywall.Store event event.Store webhook webhook.Store - chat chat.Store + chatv1 chat_v1.Store + chatv2 chat_v2.Store badgecount badgecount.Store login login.Store balance balance.Store @@ -532,7 +552,8 @@ func NewDatabaseProvider(dbConfig *pg.Config) (DatabaseData, error) { paywall: paywall_postgres_client.New(db), event: event_postgres_client.New(db), webhook: webhook_postgres_client.New(db), - chat: chat_postgres_client.New(db), + chatv1: chat_v1_postgres_client.New(db), + chatv2: chat_v2_memory_client.New(), // todo: Postgres version for production after PoC badgecount: badgecount_postgres_client.New(db), login: login_postgres_client.New(db), balance: balance_postgres_client.New(db), @@ -575,7 +596,8 @@ func NewTestDatabaseProvider() DatabaseData { paywall: paywall_memory_client.New(), event: event_memory_client.New(), webhook: webhook_memory_client.New(), - chat: chat_memory_client.New(), + chatv1: chat_v1_memory_client.New(), + chatv2: chat_v2_memory_client.New(), badgecount: badgecount_memory_client.New(), login: login_memory_client.New(), balance: balance_memory_client.New(), @@ -1399,48 +1421,101 @@ func (dp *DatabaseProvider) GetAllPendingWebhooksReadyToSend(ctx context.Context return dp.webhook.GetAllPendingReadyToSend(ctx, limit) } -// Chat +// Chat V1 // -------------------------------------------------------------------------------- -func (dp *DatabaseProvider) PutChat(ctx context.Context, record *chat.Chat) error { - return dp.chat.PutChat(ctx, record) +func (dp *DatabaseProvider) PutChatV1(ctx context.Context, record *chat_v1.Chat) error { + return dp.chatv1.PutChat(ctx, record) } -func (dp *DatabaseProvider) GetChatById(ctx context.Context, chatId chat.ChatId) (*chat.Chat, error) { - return dp.chat.GetChatById(ctx, chatId) +func (dp *DatabaseProvider) GetChatByIdV1(ctx context.Context, chatId chat_v1.ChatId) (*chat_v1.Chat, error) { + return dp.chatv1.GetChatById(ctx, chatId) } -func (dp *DatabaseProvider) GetAllChatsForUser(ctx context.Context, user string, opts ...query.Option) ([]*chat.Chat, error) { +func (dp *DatabaseProvider) GetAllChatsForUserV1(ctx context.Context, user string, opts ...query.Option) ([]*chat_v1.Chat, error) { req, err := query.DefaultPaginationHandler(opts...) if err != nil { return nil, err } - return dp.chat.GetAllChatsForUser(ctx, user, req.Cursor, req.SortBy, req.Limit) + return dp.chatv1.GetAllChatsForUser(ctx, user, req.Cursor, req.SortBy, req.Limit) } -func (dp *DatabaseProvider) PutChatMessage(ctx context.Context, record *chat.Message) error { - return dp.chat.PutMessage(ctx, record) +func (dp *DatabaseProvider) PutChatMessageV1(ctx context.Context, record *chat_v1.Message) error { + return dp.chatv1.PutMessage(ctx, record) } -func (dp *DatabaseProvider) DeleteChatMessage(ctx context.Context, chatId chat.ChatId, messageId string) error { - return dp.chat.DeleteMessage(ctx, chatId, messageId) +func (dp *DatabaseProvider) DeleteChatMessageV1(ctx context.Context, chatId chat_v1.ChatId, messageId string) error { + return dp.chatv1.DeleteMessage(ctx, chatId, messageId) } -func (dp *DatabaseProvider) GetChatMessage(ctx context.Context, chatId chat.ChatId, messageId string) (*chat.Message, error) { - return dp.chat.GetMessageById(ctx, chatId, messageId) +func (dp *DatabaseProvider) GetChatMessageV1(ctx context.Context, chatId chat_v1.ChatId, messageId string) (*chat_v1.Message, error) { + return dp.chatv1.GetMessageById(ctx, chatId, messageId) } -func (dp *DatabaseProvider) GetAllChatMessages(ctx context.Context, chatId chat.ChatId, opts ...query.Option) ([]*chat.Message, error) { +func (dp *DatabaseProvider) GetAllChatMessagesV1(ctx context.Context, chatId chat_v1.ChatId, opts ...query.Option) ([]*chat_v1.Message, error) { req, err := query.DefaultPaginationHandler(opts...) if err != nil { return nil, err } - return dp.chat.GetAllMessagesByChat(ctx, chatId, req.Cursor, req.SortBy, req.Limit) + return dp.chatv1.GetAllMessagesByChat(ctx, chatId, req.Cursor, req.SortBy, req.Limit) +} +func (dp *DatabaseProvider) AdvanceChatPointerV1(ctx context.Context, chatId chat_v1.ChatId, pointer string) error { + return dp.chatv1.AdvancePointer(ctx, chatId, pointer) +} +func (dp *DatabaseProvider) GetChatUnreadCountV1(ctx context.Context, chatId chat_v1.ChatId) (uint32, error) { + return dp.chatv1.GetUnreadCount(ctx, chatId) +} +func (dp *DatabaseProvider) SetChatMuteStateV1(ctx context.Context, chatId chat_v1.ChatId, isMuted bool) error { + return dp.chatv1.SetMuteState(ctx, chatId, isMuted) +} +func (dp *DatabaseProvider) SetChatSubscriptionStateV1(ctx context.Context, chatId chat_v1.ChatId, isSubscribed bool) error { + return dp.chatv1.SetSubscriptionState(ctx, chatId, isSubscribed) +} + +// Chat V2 +// -------------------------------------------------------------------------------- +func (dp *DatabaseProvider) GetChatByIdV2(ctx context.Context, chatId chat_v2.ChatId) (*chat_v2.ChatRecord, error) { + return dp.chatv2.GetChatById(ctx, chatId) +} +func (dp *DatabaseProvider) GetChatMemberByIdV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId) (*chat_v2.MemberRecord, error) { + return dp.chatv2.GetMemberById(ctx, chatId, memberId) +} +func (dp *DatabaseProvider) GetChatMessageByIdV2(ctx context.Context, chatId chat_v2.ChatId, messageId chat_v2.MessageId) (*chat_v2.MessageRecord, error) { + return dp.chatv2.GetMessageById(ctx, chatId, messageId) +} +func (dp *DatabaseProvider) GetAllChatMembersV2(ctx context.Context, chatId chat_v2.ChatId) ([]*chat_v2.MemberRecord, error) { + return dp.chatv2.GetAllMembersByChatId(ctx, chatId) +} +func (dp *DatabaseProvider) GetPlatformUserChatMembershipV2(ctx context.Context, idByPlatform map[chat_v2.Platform]string, opts ...query.Option) ([]*chat_v2.MemberRecord, error) { + req, err := query.DefaultPaginationHandler(opts...) + if err != nil { + return nil, err + } + return dp.chatv2.GetAllMembersByPlatformIds(ctx, idByPlatform, req.Cursor, req.SortBy, req.Limit) +} +func (dp *DatabaseProvider) GetAllChatMessagesV2(ctx context.Context, chatId chat_v2.ChatId, opts ...query.Option) ([]*chat_v2.MessageRecord, error) { + req, err := query.DefaultPaginationHandler(opts...) + if err != nil { + return nil, err + } + return dp.chatv2.GetAllMessagesByChatId(ctx, chatId, req.Cursor, req.SortBy, req.Limit) +} +func (dp *DatabaseProvider) GetChatUnreadCountV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, readPointer chat_v2.MessageId) (uint32, error) { + return dp.chatv2.GetUnreadCount(ctx, chatId, memberId, readPointer) +} +func (dp *DatabaseProvider) PutChatV2(ctx context.Context, record *chat_v2.ChatRecord) error { + return dp.chatv2.PutChat(ctx, record) +} +func (dp *DatabaseProvider) PutChatMemberV2(ctx context.Context, record *chat_v2.MemberRecord) error { + return dp.chatv2.PutMember(ctx, record) +} +func (dp *DatabaseProvider) PutChatMessageV2(ctx context.Context, record *chat_v2.MessageRecord) error { + return dp.chatv2.PutMessage(ctx, record) } -func (dp *DatabaseProvider) AdvanceChatPointer(ctx context.Context, chatId chat.ChatId, pointer string) error { - return dp.chat.AdvancePointer(ctx, chatId, pointer) +func (dp *DatabaseProvider) AdvanceChatPointerV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, pointerType chat_v2.PointerType, pointer chat_v2.MessageId) (bool, error) { + return dp.chatv2.AdvancePointer(ctx, chatId, memberId, pointerType, pointer) } -func (dp *DatabaseProvider) GetChatUnreadCount(ctx context.Context, chatId chat.ChatId) (uint32, error) { - return dp.chat.GetUnreadCount(ctx, chatId) +func (dp *DatabaseProvider) UpgradeChatMemberIdentityV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, platform chat_v2.Platform, platformId string) error { + return dp.chatv2.UpgradeIdentity(ctx, chatId, memberId, platform, platformId) } -func (dp *DatabaseProvider) SetChatMuteState(ctx context.Context, chatId chat.ChatId, isMuted bool) error { - return dp.chat.SetMuteState(ctx, chatId, isMuted) +func (dp *DatabaseProvider) SetChatMuteStateV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, isMuted bool) error { + return dp.chatv2.SetMuteState(ctx, chatId, memberId, isMuted) } -func (dp *DatabaseProvider) SetChatSubscriptionState(ctx context.Context, chatId chat.ChatId, isSubscribed bool) error { - return dp.chat.SetSubscriptionState(ctx, chatId, isSubscribed) +func (dp *DatabaseProvider) SetChatSubscriptionStateV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, isSubscribed bool) error { + return dp.chatv2.SetSubscriptionState(ctx, chatId, memberId, isSubscribed) } // Badge Count diff --git a/pkg/code/push/notifications.go b/pkg/code/push/notifications.go index ccc361fb..bfe4db7c 100644 --- a/pkg/code/push/notifications.go +++ b/pkg/code/push/notifications.go @@ -14,7 +14,7 @@ import ( chat_util "github.com/code-payments/code-server/pkg/code/chat" "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/code/localization" "github.com/code-payments/code-server/pkg/code/thirdparty" currency_lib "github.com/code-payments/code-server/pkg/currency" @@ -59,13 +59,13 @@ func SendDepositPushNotification( // Legacy push notification still considers chat mute state // // todo: Proper migration to chat system - chatRecord, err := data.GetChatById(ctx, chat.GetChatId(chat_util.CashTransactionsName, owner.PublicKey().ToBase58(), true)) + chatRecord, err := data.GetChatByIdV1(ctx, chat_v1.GetChatId(chat_util.CashTransactionsName, owner.PublicKey().ToBase58(), true)) switch err { case nil: if chatRecord.IsMuted { return nil } - case chat.ErrChatNotFound: + case chat_v1.ErrChatNotFound: default: log.WithError(err).Warn("failure getting chat record") return errors.Wrap(err, "error getting chat record") @@ -139,13 +139,13 @@ func SendGiftCardReturnedPushNotification( // Legacy push notification still considers chat mute state // // todo: Proper migration to chat system - chatRecord, err := data.GetChatById(ctx, chat.GetChatId(chat_util.CashTransactionsName, owner.PublicKey().ToBase58(), true)) + chatRecord, err := data.GetChatByIdV1(ctx, chat_v1.GetChatId(chat_util.CashTransactionsName, owner.PublicKey().ToBase58(), true)) switch err { case nil: if chatRecord.IsMuted { return nil } - case chat.ErrChatNotFound: + case chat_v1.ErrChatNotFound: default: log.WithError(err).Warn("failure getting chat record") return errors.Wrap(err, "error getting chat record") @@ -320,15 +320,15 @@ func SendChatMessagePushNotification( for _, content := range chatMessage.Content { var contentToPush *chatpb.Content switch typedContent := content.Type.(type) { - case *chatpb.Content_Localized: - localizedPushBody, err := localization.Localize(locale, typedContent.Localized.KeyOrText) + case *chatpb.Content_ServerLocalized: + localizedPushBody, err := localization.Localize(locale, typedContent.ServerLocalized.KeyOrText) if err != nil { continue } contentToPush = &chatpb.Content{ - Type: &chatpb.Content_Localized{ - Localized: &chatpb.LocalizedContent{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ KeyOrText: localizedPushBody, }, }, @@ -358,14 +358,23 @@ func SendChatMessagePushNotification( } contentToPush = &chatpb.Content{ - Type: &chatpb.Content_Localized{ - Localized: &chatpb.LocalizedContent{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ KeyOrText: localizedPushBody, }, }, } - case *chatpb.Content_NaclBox: + case *chatpb.Content_NaclBox, *chatpb.Content_Text: contentToPush = content + case *chatpb.Content_ThankYou: + contentToPush = &chatpb.Content{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ + // todo: localize this + KeyOrText: "🙏 They thanked you for their tip", + }, + }, + } } if contentToPush == nil { diff --git a/pkg/code/server/grpc/chat/server.go b/pkg/code/server/grpc/chat/v1/server.go similarity index 60% rename from pkg/code/server/grpc/chat/server.go rename to pkg/code/server/grpc/chat/v1/server.go index 362c0fd7..5fef09f9 100644 --- a/pkg/code/server/grpc/chat/server.go +++ b/pkg/code/server/grpc/chat/v1/server.go @@ -1,14 +1,20 @@ -package chat +package chat_v1 import ( + "bytes" "context" + "fmt" "math" + "strings" + "sync" + "time" "github.com/mr-tron/base58" "github.com/sirupsen/logrus" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/timestamppb" chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v1" @@ -18,30 +24,55 @@ import ( chat_util "github.com/code-payments/code-server/pkg/code/chat" "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/code/localization" + push_util "github.com/code-payments/code-server/pkg/code/push" "github.com/code-payments/code-server/pkg/database/query" "github.com/code-payments/code-server/pkg/grpc/client" + push_lib "github.com/code-payments/code-server/pkg/push" + sync_util "github.com/code-payments/code-server/pkg/sync" ) const ( maxPageSize = 100 ) +var ( + mockTwoWayChat = chat.GetChatId("user1", "user2", true).ToProto() +) + +// todo: Resolve duplication of streaming logic with messaging service. The latest and greatest will live here. type server struct { - log *logrus.Entry - data code_data.Provider - auth *auth_util.RPCSignatureVerifier + log *logrus.Entry + data code_data.Provider + auth *auth_util.RPCSignatureVerifier + pusher push_lib.Provider + + streamsMu sync.RWMutex + streams map[string]*chatEventStream + + chatLocks *sync_util.StripedLock + chatEventChans *sync_util.StripedChannel chatpb.UnimplementedChatServer } -func NewChatServer(data code_data.Provider, auth *auth_util.RPCSignatureVerifier) chatpb.ChatServer { - return &server{ - log: logrus.StandardLogger().WithField("type", "chat/server"), - data: data, - auth: auth, +func NewChatServer(data code_data.Provider, auth *auth_util.RPCSignatureVerifier, pusher push_lib.Provider) chatpb.ChatServer { + s := &server{ + log: logrus.StandardLogger().WithField("type", "chat/v1/server"), + data: data, + auth: auth, + pusher: pusher, + streams: make(map[string]*chatEventStream), + chatLocks: sync_util.NewStripedLock(64), // todo: configurable parameters + chatEventChans: sync_util.NewStripedChannel(64, 100_000), // todo: configurable parameters } + + for i, channel := range s.chatEventChans.GetChannels() { + go s.asyncChatEventStreamNotifier(i, channel) + } + + return s } func (s *server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*chatpb.GetChatsResponse, error) { @@ -88,7 +119,7 @@ func (s *server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*ch } } - chatRecords, err := s.data.GetAllChatsForUser( + chatRecords, err := s.data.GetAllChatsForUserV1( ctx, owner.PublicKey().ToBase58(), query.WithCursor(cursor), @@ -147,7 +178,7 @@ func (s *server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*ch } protoMetadata.Title = &chatpb.ChatMetadata_Localized{ - Localized: &chatpb.LocalizedContent{ + Localized: &chatpb.ServerLocalizedContent{ KeyOrText: localization.LocalizeWithFallback( locale, localization.GetLocalizationKeyForUserAgent(ctx, chatProperties.TitleLocalizationKey), @@ -189,7 +220,7 @@ func (s *server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*ch if !skipUnreadCountQuery && !chatRecord.IsMuted && !chatRecord.IsUnsubscribed { // todo: will need batching when users have a large number of chats - unreadCount, err := s.data.GetChatUnreadCount(ctx, chatRecord.ChatId) + unreadCount, err := s.data.GetChatUnreadCountV1(ctx, chatRecord.ChatId) if err != nil { log.WithError(err).Warn("failure getting unread count") } @@ -229,7 +260,7 @@ func (s *server) GetMessages(ctx context.Context, req *chatpb.GetMessagesRequest return nil, err } - chatRecord, err := s.data.GetChatById(ctx, chatId) + chatRecord, err := s.data.GetChatByIdV1(ctx, chatId) if err == chat.ErrChatNotFound { return &chatpb.GetMessagesResponse{ Result: chatpb.GetMessagesResponse_NOT_FOUND, @@ -265,7 +296,7 @@ func (s *server) GetMessages(ctx context.Context, req *chatpb.GetMessagesRequest cursor = req.Cursor.Value } - messageRecords, err := s.data.GetAllChatMessages( + messageRecords, err := s.data.GetAllChatMessagesV1( ctx, chatId, query.WithCursor(cursor), @@ -298,11 +329,11 @@ func (s *server) GetMessages(ctx context.Context, req *chatpb.GetMessagesRequest for _, content := range protoChatMessage.Content { switch typed := content.Type.(type) { - case *chatpb.Content_Localized: - typed.Localized.KeyOrText = localization.LocalizeWithFallback( + case *chatpb.Content_ServerLocalized: + typed.ServerLocalized.KeyOrText = localization.LocalizeWithFallback( locale, - localization.GetLocalizationKeyForUserAgent(ctx, typed.Localized.KeyOrText), - typed.Localized.KeyOrText, + localization.GetLocalizationKeyForUserAgent(ctx, typed.ServerLocalized.KeyOrText), + typed.ServerLocalized.KeyOrText, ) } } @@ -347,26 +378,45 @@ func (s *server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerR } log = log.WithField("owner_account", owner.PublicKey().ToBase58()) - chatId := chat.ChatIdFromProto(req.ChatId) - log = log.WithField("chat_id", chatId.String()) + signature := req.Signature + req.Signature = nil + if err := s.auth.Authenticate(ctx, owner, req, signature); err != nil { + return nil, err + } + chatId := chat.ChatIdFromProto(req.ChatId) messageId := base58.Encode(req.Pointer.Value.Value) log = log.WithFields(logrus.Fields{ + "chat_id": chatId.String(), "message_id": messageId, "pointer_type": req.Pointer.Kind, }) - if req.Pointer.Kind != chatpb.Pointer_READ { - return nil, status.Error(codes.InvalidArgument, "Pointer.Kind must be READ") + // todo: Temporary code to simluate real-time + if req.Pointer.User != nil { + return nil, status.Error(codes.InvalidArgument, "pointer.user cannot be set by clients") } + if bytes.Equal(mockTwoWayChat.Value, req.ChatId.Value) { + req.Pointer.User = &chatpb.ChatMemberId{Value: req.Owner.Value} - signature := req.Signature - req.Signature = nil - if err := s.auth.Authenticate(ctx, owner, req, signature); err != nil { - return nil, err + event := &chatpb.ChatStreamEvent{ + Pointers: []*chatpb.Pointer{req.Pointer}, + } + + if err := s.asyncNotifyAll(chatId, owner, event); err != nil { + log.WithError(err).Warn("failure notifying chat event") + } + + return &chatpb.AdvancePointerResponse{ + Result: chatpb.AdvancePointerResponse_OK, + }, nil + } + + if req.Pointer.Kind != chatpb.Pointer_READ { + return nil, status.Error(codes.InvalidArgument, "Pointer.Kind must be READ") } - chatRecord, err := s.data.GetChatById(ctx, chatId) + chatRecord, err := s.data.GetChatByIdV1(ctx, chatId) if err == chat.ErrChatNotFound { return &chatpb.AdvancePointerResponse{ Result: chatpb.AdvancePointerResponse_CHAT_NOT_FOUND, @@ -380,7 +430,7 @@ func (s *server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerR return nil, status.Error(codes.PermissionDenied, "") } - newPointerRecord, err := s.data.GetChatMessage(ctx, chatId, messageId) + newPointerRecord, err := s.data.GetChatMessageV1(ctx, chatId, messageId) if err == chat.ErrMessageNotFound { return &chatpb.AdvancePointerResponse{ Result: chatpb.AdvancePointerResponse_MESSAGE_NOT_FOUND, @@ -391,7 +441,7 @@ func (s *server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerR } if chatRecord.ReadPointer != nil { - oldPointerRecord, err := s.data.GetChatMessage(ctx, chatId, *chatRecord.ReadPointer) + oldPointerRecord, err := s.data.GetChatMessageV1(ctx, chatId, *chatRecord.ReadPointer) if err != nil { log.WithError(err).Warn("failure getting chat message record for old pointer value") return nil, status.Error(codes.Internal, "") @@ -404,7 +454,7 @@ func (s *server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerR } } - err = s.data.AdvanceChatPointer(ctx, chatId, messageId) + err = s.data.AdvanceChatPointerV1(ctx, chatId, messageId) if err != nil { log.WithError(err).Warn("failure advancing pointer") return nil, status.Error(codes.Internal, "") @@ -434,7 +484,7 @@ func (s *server) SetMuteState(ctx context.Context, req *chatpb.SetMuteStateReque return nil, err } - chatRecord, err := s.data.GetChatById(ctx, chatId) + chatRecord, err := s.data.GetChatByIdV1(ctx, chatId) if err == chat.ErrChatNotFound { return &chatpb.SetMuteStateResponse{ Result: chatpb.SetMuteStateResponse_CHAT_NOT_FOUND, @@ -461,7 +511,7 @@ func (s *server) SetMuteState(ctx context.Context, req *chatpb.SetMuteStateReque }, nil } - err = s.data.SetChatMuteState(ctx, chatId, req.IsMuted) + err = s.data.SetChatMuteStateV1(ctx, chatId, req.IsMuted) if err != nil { log.WithError(err).Warn("failure setting mute status") return nil, status.Error(codes.Internal, "") @@ -492,7 +542,7 @@ func (s *server) SetSubscriptionState(ctx context.Context, req *chatpb.SetSubscr return nil, err } - chatRecord, err := s.data.GetChatById(ctx, chatId) + chatRecord, err := s.data.GetChatByIdV1(ctx, chatId) if err == chat.ErrChatNotFound { return &chatpb.SetSubscriptionStateResponse{ Result: chatpb.SetSubscriptionStateResponse_CHAT_NOT_FOUND, @@ -519,7 +569,7 @@ func (s *server) SetSubscriptionState(ctx context.Context, req *chatpb.SetSubscr }, nil } - err = s.data.SetChatSubscriptionState(ctx, chatId, req.IsSubscribed) + err = s.data.SetChatSubscriptionStateV1(ctx, chatId, req.IsSubscribed) if err != nil { log.WithError(err).Warn("failure setting subcription status") return nil, status.Error(codes.Internal, "") @@ -529,3 +579,232 @@ func (s *server) SetSubscriptionState(ctx context.Context, req *chatpb.SetSubscr Result: chatpb.SetSubscriptionStateResponse_OK, }, nil } + +// +// Experimental PoC two-way chat APIs below +// + +func (s *server) StreamChatEvents(streamer chatpb.Chat_StreamChatEventsServer) error { + ctx := streamer.Context() + + log := s.log.WithField("method", "StreamChatEvents") + log = client.InjectLoggingMetadata(ctx, log) + + req, err := boundedStreamChatEventsRecv(ctx, streamer, 250*time.Millisecond) + if err != nil { + return err + } + + if req.GetOpenStream() == nil { + return status.Error(codes.InvalidArgument, "open_stream is nil") + } + + if req.GetOpenStream().Signature == nil { + return status.Error(codes.InvalidArgument, "signature is nil") + } + + if !bytes.Equal(req.GetOpenStream().ChatId.Value, mockTwoWayChat.Value) { + return status.Error(codes.Unimplemented, "") + } + chatId := chat.ChatIdFromProto(req.GetOpenStream().ChatId) + log = log.WithField("chat_id", chatId.String()) + + owner, err := common.NewAccountFromProto(req.GetOpenStream().Owner) + if err != nil { + log.WithError(err).Warn("invalid owner account") + return status.Error(codes.Internal, "") + } + log = log.WithField("owner", owner.PublicKey().ToBase58()) + + signature := req.GetOpenStream().Signature + req.GetOpenStream().Signature = nil + if err = s.auth.Authenticate(streamer.Context(), owner, req.GetOpenStream(), signature); err != nil { + return err + } + + streamKey := fmt.Sprintf("%s:%s", chatId.String(), owner.PublicKey().ToBase58()) + + s.streamsMu.Lock() + + stream, exists := s.streams[streamKey] + if exists { + s.streamsMu.Unlock() + // There's an existing stream on this server that must be terminated first. + // Warn to see how often this happens in practice + log.Warnf("existing stream detected on this server (stream=%p) ; aborting", stream) + return status.Error(codes.Aborted, "stream already exists") + } + + stream = newChatEventStream(streamBufferSize) + + // The race detector complains when reading the stream pointer ref outside of the lock. + streamRef := fmt.Sprintf("%p", stream) + log.Tracef("setting up new stream (stream=%s)", streamRef) + s.streams[streamKey] = stream + + s.streamsMu.Unlock() + + defer func() { + s.streamsMu.Lock() + + log.Tracef("closing streamer (stream=%s)", streamRef) + + // We check to see if the current active stream is the one that we created. + // If it is, we can just remove it since it's closed. Otherwise, we leave it + // be, as another OpenMessageStream() call is handling it. + liveStream, exists := s.streams[streamKey] + if exists && liveStream == stream { + delete(s.streams, streamKey) + } + + s.streamsMu.Unlock() + }() + + sendPingCh := time.After(0) + streamHealthCh := monitorChatEventStreamHealth(ctx, log, streamRef, streamer) + + for { + select { + case event, ok := <-stream.streamCh: + if !ok { + log.Tracef("stream closed ; ending stream (stream=%s)", streamRef) + return status.Error(codes.Aborted, "stream closed") + } + + err := streamer.Send(&chatpb.StreamChatEventsResponse{ + Type: &chatpb.StreamChatEventsResponse_Events{ + Events: &chatpb.ChatStreamEventBatch{ + Events: []*chatpb.ChatStreamEvent{event}, + }, + }, + }) + if err != nil { + log.WithError(err).Info("failed to forward chat message") + return err + } + case <-sendPingCh: + log.Tracef("sending ping to client (stream=%s)", streamRef) + + sendPingCh = time.After(streamPingDelay) + + err := streamer.Send(&chatpb.StreamChatEventsResponse{ + Type: &chatpb.StreamChatEventsResponse_Ping{ + Ping: &commonpb.ServerPing{ + Timestamp: timestamppb.Now(), + PingDelay: durationpb.New(streamPingDelay), + }, + }, + }) + if err != nil { + log.Tracef("stream is unhealthy ; aborting (stream=%s)", streamRef) + return status.Error(codes.Aborted, "terminating unhealthy stream") + } + case <-streamHealthCh: + log.Tracef("stream is unhealthy ; aborting (stream=%s)", streamRef) + return status.Error(codes.Aborted, "terminating unhealthy stream") + case <-ctx.Done(): + log.Tracef("stream context cancelled ; ending stream (stream=%s)", streamRef) + return status.Error(codes.Canceled, "") + } + } +} + +func (s *server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest) (*chatpb.SendMessageResponse, error) { + log := s.log.WithField("method", "SendMessage") + log = client.InjectLoggingMetadata(ctx, log) + + if !bytes.Equal(req.ChatId.Value, mockTwoWayChat.Value) { + return nil, status.Error(codes.Unimplemented, "") + } + chatId := chat.ChatIdFromProto(req.ChatId) + log = log.WithField("chat_id", chatId.String()) + + owner, err := common.NewAccountFromProto(req.Owner) + if err != nil { + log.WithError(err).Warn("invalid owner account") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("owner", owner.PublicKey().ToBase58()) + + signature := req.Signature + req.Signature = nil + if err = s.auth.Authenticate(ctx, owner, req, signature); err != nil { + return nil, err + } + + switch req.Content[0].Type.(type) { + case *chatpb.Content_Text, *chatpb.Content_ThankYou: + default: + return nil, status.Error(codes.InvalidArgument, "content[0] must be Text or ThankYou") + } + + chatLock := s.chatLocks.Get(chatId[:]) + chatLock.Lock() + defer chatLock.Unlock() + + // todo: Revisit message IDs + messageId, err := common.NewRandomAccount() + if err != nil { + log.WithError(err).Warn("failure generating random message id") + return nil, status.Error(codes.Internal, "") + } + + chatMessage := &chatpb.ChatMessage{ + MessageId: &chatpb.ChatMessageId{Value: messageId.ToProto().Value}, + Ts: timestamppb.Now(), + Content: req.Content, + Sender: &chatpb.ChatMemberId{Value: req.Owner.Value}, + Cursor: nil, // todo: Don't have cursor until we save it to the DB + } + + // todo: Save the message to the DB + + event := &chatpb.ChatStreamEvent{ + Messages: []*chatpb.ChatMessage{chatMessage}, + } + + if err := s.asyncNotifyAll(chatId, owner, event); err != nil { + log.WithError(err).Warn("failure notifying chat event") + } + + s.asyncPushChatMessage(owner, chatId, chatMessage) + + return &chatpb.SendMessageResponse{ + Result: chatpb.SendMessageResponse_OK, + Message: chatMessage, + }, nil +} + +// todo: doesn't respect mute/unsubscribe rules +// todo: only sends pushes to active stream listeners instead of all message recipients +func (s *server) asyncPushChatMessage(sender *common.Account, chatId chat.ChatId, chatMessage *chatpb.ChatMessage) { + ctx := context.TODO() + + go func() { + s.streamsMu.RLock() + for key := range s.streams { + if !strings.HasPrefix(key, chatId.String()) { + continue + } + + receiver, err := common.NewAccountFromPublicKeyString(strings.Split(key, ":")[1]) + if err != nil { + continue + } + + if bytes.Equal(sender.PublicKey().ToBytes(), receiver.PublicKey().ToBytes()) { + continue + } + + go push_util.SendChatMessagePushNotification( + ctx, + s.data, + s.pusher, + "TontonTwitch", + receiver, + chatMessage, + ) + } + s.streamsMu.RUnlock() + }() +} diff --git a/pkg/code/server/grpc/chat/server_test.go b/pkg/code/server/grpc/chat/v1/server_test.go similarity index 97% rename from pkg/code/server/grpc/chat/server_test.go rename to pkg/code/server/grpc/chat/v1/server_test.go index f6dd6eff..0a9abad4 100644 --- a/pkg/code/server/grpc/chat/server_test.go +++ b/pkg/code/server/grpc/chat/v1/server_test.go @@ -1,4 +1,4 @@ -package chat +package chat_v1 import ( "context" @@ -22,13 +22,14 @@ import ( chat_util "github.com/code-payments/code-server/pkg/code/chat" "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/code/data/phone" "github.com/code-payments/code-server/pkg/code/data/preferences" "github.com/code-payments/code-server/pkg/code/data/user" "github.com/code-payments/code-server/pkg/code/data/user/storage" "github.com/code-payments/code-server/pkg/code/localization" "github.com/code-payments/code-server/pkg/kin" + memory_push "github.com/code-payments/code-server/pkg/push/memory" "github.com/code-payments/code-server/pkg/testutil" ) @@ -102,8 +103,8 @@ func TestGetChatsAndMessages_HappyPath(t *testing.T) { Ts: timestamppb.Now(), Content: []*chatpb.Content{ { - Type: &chatpb.Content_Localized{ - Localized: &chatpb.LocalizedContent{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ KeyOrText: "msg.body.key", }, }, @@ -242,7 +243,7 @@ func TestGetChatsAndMessages_HappyPath(t *testing.T) { require.Len(t, getMessagesResp.Messages, 1) assert.Equal(t, expectedCodeTeamMessage.MessageId.Value, getMessagesResp.Messages[0].Cursor.Value) getMessagesResp.Messages[0].Cursor = nil - expectedCodeTeamMessage.Content[0].GetLocalized().KeyOrText = "localized message body content" + expectedCodeTeamMessage.Content[0].GetServerLocalized().KeyOrText = "localized message body content" assert.True(t, proto.Equal(expectedCodeTeamMessage, getMessagesResp.Messages[0])) getMessagesResp, err = env.client.GetMessages(env.ctx, getCashTransactionsMessagesReq) @@ -288,8 +289,8 @@ func TestChatHistoryReadState_HappyPath(t *testing.T) { Ts: timestamppb.Now(), Content: []*chatpb.Content{ { - Type: &chatpb.Content_Localized{ - Localized: &chatpb.LocalizedContent{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ KeyOrText: fmt.Sprintf("msg.body.key%d", i), }, }, @@ -346,8 +347,8 @@ func TestChatHistoryReadState_NegativeProgress(t *testing.T) { Ts: timestamppb.Now(), Content: []*chatpb.Content{ { - Type: &chatpb.Content_Localized{ - Localized: &chatpb.LocalizedContent{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ KeyOrText: fmt.Sprintf("msg.body.key%d", i), }, }, @@ -429,8 +430,8 @@ func TestChatHistoryReadState_MessageNotFound(t *testing.T) { Ts: timestamppb.Now(), Content: []*chatpb.Content{ { - Type: &chatpb.Content_Localized{ - Localized: &chatpb.LocalizedContent{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ KeyOrText: "msg.body.key", }, }, @@ -743,8 +744,8 @@ func TestUnauthorizedAccess(t *testing.T) { Ts: timestamppb.Now(), Content: []*chatpb.Content{ { - Type: &chatpb.Content_Localized{ - Localized: &chatpb.LocalizedContent{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ KeyOrText: "msg.body.key", }, }, @@ -880,7 +881,7 @@ func setup(t *testing.T) (env *testEnv, cleanup func()) { data: code_data.NewTestDataProvider(), } - s := NewChatServer(env.data, auth_util.NewRPCSignatureVerifier(env.data)) + s := NewChatServer(env.data, auth_util.NewRPCSignatureVerifier(env.data), memory_push.NewPushProvider()) env.server = s.(*server) serv.RegisterService(func(server *grpc.Server) { diff --git a/pkg/code/server/grpc/chat/v1/stream.go b/pkg/code/server/grpc/chat/v1/stream.go new file mode 100644 index 00000000..05b63235 --- /dev/null +++ b/pkg/code/server/grpc/chat/v1/stream.go @@ -0,0 +1,177 @@ +package chat_v1 + +import ( + "context" + "strings" + "sync" + "time" + + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + + chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v1" + "github.com/code-payments/code-server/pkg/code/common" + chat "github.com/code-payments/code-server/pkg/code/data/chat/v1" +) + +const ( + // todo: configurable + streamBufferSize = 64 + streamPingDelay = 5 * time.Second + streamKeepAliveRecvTimeout = 10 * time.Second + streamNotifyTimeout = time.Second +) + +type chatEventStream struct { + sync.Mutex + + closed bool + streamCh chan *chatpb.ChatStreamEvent +} + +func newChatEventStream(bufferSize int) *chatEventStream { + return &chatEventStream{ + streamCh: make(chan *chatpb.ChatStreamEvent, bufferSize), + } +} + +func (s *chatEventStream) notify(event *chatpb.ChatStreamEvent, timeout time.Duration) error { + m := proto.Clone(event).(*chatpb.ChatStreamEvent) + + s.Lock() + + if s.closed { + s.Unlock() + return errors.New("cannot notify closed stream") + } + + select { + case s.streamCh <- m: + case <-time.After(timeout): + s.Unlock() + s.close() + return errors.New("timed out sending message to streamCh") + } + + s.Unlock() + return nil +} + +func (s *chatEventStream) close() { + s.Lock() + defer s.Unlock() + + if s.closed { + return + } + + s.closed = true + close(s.streamCh) +} + +func boundedStreamChatEventsRecv( + ctx context.Context, + streamer chatpb.Chat_StreamChatEventsServer, + timeout time.Duration, +) (req *chatpb.StreamChatEventsRequest, err error) { + done := make(chan struct{}) + go func() { + req, err = streamer.Recv() + close(done) + }() + + select { + case <-done: + return req, err + case <-ctx.Done(): + return nil, status.Error(codes.Canceled, "") + case <-time.After(timeout): + return nil, status.Error(codes.DeadlineExceeded, "timed out receiving message") + } +} + +type chatEventNotification struct { + chatId chat.ChatId + owner *common.Account + event *chatpb.ChatStreamEvent + ts time.Time +} + +func (s *server) asyncNotifyAll(chatId chat.ChatId, owner *common.Account, event *chatpb.ChatStreamEvent) error { + m := proto.Clone(event).(*chatpb.ChatStreamEvent) + ok := s.chatEventChans.Send(chatId[:], &chatEventNotification{chatId, owner, m, time.Now()}) + if !ok { + return errors.New("chat event channel is full") + } + return nil +} + +func (s *server) asyncChatEventStreamNotifier(workerId int, channel <-chan interface{}) { + log := s.log.WithFields(logrus.Fields{ + "method": "asyncChatEventStreamNotifier", + "worker": workerId, + }) + + for value := range channel { + typedValue, ok := value.(*chatEventNotification) + if !ok { + log.Warn("channel did not receive expected struct") + continue + } + + log := log.WithField("chat_id", typedValue.chatId.String()) + + if time.Since(typedValue.ts) > time.Second { + log.Warn("channel notification latency is elevated") + } + + s.streamsMu.RLock() + for key, stream := range s.streams { + if !strings.HasPrefix(key, typedValue.chatId.String()) { + continue + } + + if strings.HasSuffix(key, typedValue.owner.PublicKey().ToBase58()) { + continue + } + + if err := stream.notify(typedValue.event, streamNotifyTimeout); err != nil { + log.WithError(err).Warnf("failed to notify session stream, closing streamer (stream=%p)", stream) + } + } + s.streamsMu.RUnlock() + } +} + +// Very naive implementation to start +func monitorChatEventStreamHealth( + ctx context.Context, + log *logrus.Entry, + ssRef string, + streamer chatpb.Chat_StreamChatEventsServer, +) <-chan struct{} { + streamHealthChan := make(chan struct{}) + go func() { + defer close(streamHealthChan) + + for { + // todo: configurable timeout + req, err := boundedStreamChatEventsRecv(ctx, streamer, streamKeepAliveRecvTimeout) + if err != nil { + return + } + + switch req.Type.(type) { + case *chatpb.StreamChatEventsRequest_Pong: + log.Tracef("received pong from client (stream=%s)", ssRef) + default: + // Client sent something unexpected. Terminate the stream + return + } + } + }() + return streamHealthChan +} diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go new file mode 100644 index 00000000..bc5351a2 --- /dev/null +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -0,0 +1,1466 @@ +package chat_v2 + +import ( + "context" + "crypto/rand" + "database/sql" + "fmt" + "math" + "sync" + "time" + + "github.com/mr-tron/base58" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "golang.org/x/text/language" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/timestamppb" + + chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v2" + commonpb "github.com/code-payments/code-protobuf-api/generated/go/common/v1" + transactionpb "github.com/code-payments/code-protobuf-api/generated/go/transaction/v2" + + auth_util "github.com/code-payments/code-server/pkg/code/auth" + "github.com/code-payments/code-server/pkg/code/common" + code_data "github.com/code-payments/code-server/pkg/code/data" + chat "github.com/code-payments/code-server/pkg/code/data/chat/v2" + "github.com/code-payments/code-server/pkg/code/data/intent" + "github.com/code-payments/code-server/pkg/code/data/twitter" + "github.com/code-payments/code-server/pkg/code/localization" + "github.com/code-payments/code-server/pkg/database/query" + "github.com/code-payments/code-server/pkg/grpc/client" + timelock_token "github.com/code-payments/code-server/pkg/solana/timelock/v1" + sync_util "github.com/code-payments/code-server/pkg/sync" +) + +// todo: resolve some common code for sending chat messages across RPCs + +const ( + maxGetChatsPageSize = 100 + maxGetMessagesPageSize = 100 + flushMessageCount = 100 +) + +type server struct { + log *logrus.Entry + + data code_data.Provider + auth *auth_util.RPCSignatureVerifier + + streamsMu sync.RWMutex + streams map[string]*chatEventStream + + chatLocks *sync_util.StripedLock + chatEventChans *sync_util.StripedChannel + + chatpb.UnimplementedChatServer +} + +func NewChatServer(data code_data.Provider, auth *auth_util.RPCSignatureVerifier) chatpb.ChatServer { + s := &server{ + log: logrus.StandardLogger().WithField("type", "chat/v2/server"), + + data: data, + auth: auth, + + streams: make(map[string]*chatEventStream), + + chatLocks: sync_util.NewStripedLock(64), // todo: configurable parameters + chatEventChans: sync_util.NewStripedChannel(64, 100_000), // todo: configurable parameters + } + + for i, channel := range s.chatEventChans.GetChannels() { + go s.asyncChatEventStreamNotifier(i, channel) + } + + return s +} + +// todo: This will require a lot of optimizations since we iterate and make several DB calls for each chat membership +func (s *server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*chatpb.GetChatsResponse, error) { + log := s.log.WithField("method", "GetChats") + log = client.InjectLoggingMetadata(ctx, log) + + owner, err := common.NewAccountFromProto(req.Owner) + if err != nil { + log.WithError(err).Warn("invalid owner account") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("owner_account", owner.PublicKey().ToBase58()) + + signature := req.Signature + req.Signature = nil + if err := s.auth.Authenticate(ctx, owner, req, signature); err != nil { + return nil, err + } + + var limit uint64 + if req.PageSize > 0 { + limit = uint64(req.PageSize) + } else { + limit = maxGetChatsPageSize + } + if limit > maxGetChatsPageSize { + limit = maxGetChatsPageSize + } + + var direction query.Ordering + if req.Direction == chatpb.GetChatsRequest_ASC { + direction = query.Ascending + } else { + direction = query.Descending + } + + var cursor query.Cursor + if req.Cursor != nil { + cursor = req.Cursor.Value + } else { + cursor = query.ToCursor(0) + if direction == query.Descending { + cursor = query.ToCursor(math.MaxInt64 - 1) + } + } + + myIdentities, err := s.getAllIdentities(ctx, owner) + if err != nil { + log.WithError(err).Warn("failure getting identities for owner account") + return nil, status.Error(codes.Internal, "") + } + + // todo: Use a better query that returns chat IDs. This will result in duplicate + // chat results if the user is in the chat multiple times across many identities. + patformUserMemberRecords, err := s.data.GetPlatformUserChatMembershipV2( + ctx, + myIdentities, + query.WithCursor(cursor), + query.WithDirection(direction), + query.WithLimit(limit), + ) + if err == chat.ErrMemberNotFound { + return &chatpb.GetChatsResponse{ + Result: chatpb.GetChatsResponse_NOT_FOUND, + }, nil + } else if err != nil { + log.WithError(err).Warn("failure getting chat members for platform user") + return nil, status.Error(codes.Internal, "") + } + + var protoChats []*chatpb.ChatMetadata + for _, platformUserMemberRecord := range patformUserMemberRecords { + log := log.WithField("chat_id", platformUserMemberRecord.ChatId.String()) + + chatRecord, err := s.data.GetChatByIdV2(ctx, platformUserMemberRecord.ChatId) + if err != nil { + log.WithError(err).Warn("failure getting chat record") + return nil, status.Error(codes.Internal, "") + } + + memberRecords, err := s.data.GetAllChatMembersV2(ctx, chatRecord.ChatId) + if err != nil { + log.WithError(err).Warn("failure getting chat members") + return nil, status.Error(codes.Internal, "") + } + + protoChat, err := s.toProtoChat(ctx, chatRecord, memberRecords, myIdentities) + if err != nil { + log.WithError(err).Warn("failure constructing proto chat message") + return nil, status.Error(codes.Internal, "") + } + protoChat.Cursor = &chatpb.Cursor{Value: query.ToCursor(uint64(platformUserMemberRecord.Id))} + + protoChats = append(protoChats, protoChat) + } + + return &chatpb.GetChatsResponse{ + Result: chatpb.GetChatsResponse_OK, + Chats: protoChats, + }, nil +} + +func (s *server) GetMessages(ctx context.Context, req *chatpb.GetMessagesRequest) (*chatpb.GetMessagesResponse, error) { + log := s.log.WithField("method", "GetMessages") + log = client.InjectLoggingMetadata(ctx, log) + + owner, err := common.NewAccountFromProto(req.Owner) + if err != nil { + log.WithError(err).Warn("invalid owner account") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("owner_account", owner.PublicKey().ToBase58()) + + chatId, err := chat.GetChatIdFromProto(req.ChatId) + if err != nil { + log.WithError(err).Warn("invalid chat id") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("chat_id", chatId.String()) + + memberId, err := chat.GetMemberIdFromProto(req.MemberId) + if err != nil { + log.WithError(err).Warn("invalid member id") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("member_id", memberId.String()) + + signature := req.Signature + req.Signature = nil + if err := s.auth.Authenticate(ctx, owner, req, signature); err != nil { + return nil, err + } + + _, err = s.data.GetChatByIdV2(ctx, chatId) + switch err { + case nil: + case chat.ErrChatNotFound: + return &chatpb.GetMessagesResponse{ + Result: chatpb.GetMessagesResponse_CHAT_NOT_FOUND, + }, nil + default: + log.WithError(err).Warn("failure getting chat record") + return nil, status.Error(codes.Internal, "") + } + + ownsChatMember, err := s.ownsChatMemberWithoutRecord(ctx, chatId, memberId, owner) + if err != nil { + log.WithError(err).Warn("failure determing chat member ownership") + return nil, status.Error(codes.Internal, "") + } else if !ownsChatMember { + return &chatpb.GetMessagesResponse{ + Result: chatpb.GetMessagesResponse_DENIED, + }, nil + } + + var limit uint64 + if req.PageSize > 0 { + limit = uint64(req.PageSize) + } else { + limit = maxGetMessagesPageSize + } + if limit > maxGetMessagesPageSize { + limit = maxGetMessagesPageSize + } + + var direction query.Ordering + if req.Direction == chatpb.GetMessagesRequest_ASC { + direction = query.Ascending + } else { + direction = query.Descending + } + + var cursor query.Cursor + if req.Cursor != nil { + cursor = req.Cursor.Value + } + + protoChatMessages, err := s.getProtoChatMessages( + ctx, + chatId, + owner, + query.WithCursor(cursor), + query.WithDirection(direction), + query.WithLimit(limit), + ) + if err == chat.ErrMessageNotFound { + return &chatpb.GetMessagesResponse{ + Result: chatpb.GetMessagesResponse_MESSAGE_NOT_FOUND, + }, nil + } else if err != nil { + log.WithError(err).Warn("failure getting chat messages") + return nil, status.Error(codes.Internal, "") + } + + if len(protoChatMessages) == 0 { + return &chatpb.GetMessagesResponse{ + Result: chatpb.GetMessagesResponse_MESSAGE_NOT_FOUND, + }, nil + } + return &chatpb.GetMessagesResponse{ + Result: chatpb.GetMessagesResponse_OK, + Messages: protoChatMessages, + }, nil +} + +func (s *server) StreamChatEvents(streamer chatpb.Chat_StreamChatEventsServer) error { + ctx := streamer.Context() + + log := s.log.WithField("method", "StreamChatEvents") + log = client.InjectLoggingMetadata(ctx, log) + + req, err := boundedStreamChatEventsRecv(ctx, streamer, 250*time.Millisecond) + if err != nil { + return err + } + + if req.GetOpenStream() == nil { + return status.Error(codes.InvalidArgument, "StreamChatEventsRequest.Type must be OpenStreamRequest") + } + + owner, err := common.NewAccountFromProto(req.GetOpenStream().Owner) + if err != nil { + log.WithError(err).Warn("invalid owner account") + return status.Error(codes.Internal, "") + } + log = log.WithField("owner", owner.PublicKey().ToBase58()) + + chatId, err := chat.GetChatIdFromProto(req.GetOpenStream().ChatId) + if err != nil { + log.WithError(err).Warn("invalid chat id") + return status.Error(codes.Internal, "") + } + log = log.WithField("chat_id", chatId.String()) + + memberId, err := chat.GetMemberIdFromProto(req.GetOpenStream().MemberId) + if err != nil { + log.WithError(err).Warn("invalid member id") + return status.Error(codes.Internal, "") + } + log = log.WithField("member_id", memberId.String()) + + signature := req.GetOpenStream().Signature + req.GetOpenStream().Signature = nil + if err := s.auth.Authenticate(streamer.Context(), owner, req.GetOpenStream(), signature); err != nil { + return err + } + + _, err = s.data.GetChatByIdV2(ctx, chatId) + switch err { + case nil: + case chat.ErrChatNotFound: + return streamer.Send(&chatpb.StreamChatEventsResponse{ + Type: &chatpb.StreamChatEventsResponse_Error{ + Error: &chatpb.ChatStreamEventError{Code: chatpb.ChatStreamEventError_CHAT_NOT_FOUND}, + }, + }) + default: + log.WithError(err).Warn("failure getting chat record") + return status.Error(codes.Internal, "") + } + + ownsChatMember, err := s.ownsChatMemberWithoutRecord(ctx, chatId, memberId, owner) + if err != nil { + log.WithError(err).Warn("failure determing chat member ownership") + return status.Error(codes.Internal, "") + } else if !ownsChatMember { + return streamer.Send(&chatpb.StreamChatEventsResponse{ + Type: &chatpb.StreamChatEventsResponse_Error{ + Error: &chatpb.ChatStreamEventError{Code: chatpb.ChatStreamEventError_DENIED}, + }, + }) + } + + streamKey := fmt.Sprintf("%s:%s", chatId.String(), memberId.String()) + + s.streamsMu.Lock() + + stream, exists := s.streams[streamKey] + if exists { + s.streamsMu.Unlock() + // There's an existing stream on this server that must be terminated first. + // Warn to see how often this happens in practice + log.Warnf("existing stream detected on this server (stream=%p) ; aborting", stream) + return status.Error(codes.Aborted, "stream already exists") + } + + stream = newChatEventStream(streamBufferSize) + + // The race detector complains when reading the stream pointer ref outside of the lock. + streamRef := fmt.Sprintf("%p", stream) + log.Tracef("setting up new stream (stream=%s)", streamRef) + s.streams[streamKey] = stream + + s.streamsMu.Unlock() + + defer func() { + s.streamsMu.Lock() + + log.Tracef("closing streamer (stream=%s)", streamRef) + + // We check to see if the current active stream is the one that we created. + // If it is, we can just remove it since it's closed. Otherwise, we leave it + // be, as another StreamChatEvents() call is handling it. + liveStream, exists := s.streams[streamKey] + if exists && liveStream == stream { + delete(s.streams, streamKey) + } + + s.streamsMu.Unlock() + }() + + sendPingCh := time.After(0) + streamHealthCh := monitorChatEventStreamHealth(ctx, log, streamRef, streamer) + + go s.flushMessages(ctx, chatId, owner, stream) + go s.flushPointers(ctx, chatId, stream) + + for { + select { + case event, ok := <-stream.streamCh: + if !ok { + log.Tracef("stream closed ; ending stream (stream=%s)", streamRef) + return status.Error(codes.Aborted, "stream closed") + } + + err := streamer.Send(&chatpb.StreamChatEventsResponse{ + Type: &chatpb.StreamChatEventsResponse_Events{ + Events: &chatpb.ChatStreamEventBatch{ + Events: []*chatpb.ChatStreamEvent{event}, + }, + }, + }) + if err != nil { + log.WithError(err).Info("failed to forward chat message") + return err + } + case <-sendPingCh: + log.Tracef("sending ping to client (stream=%s)", streamRef) + + sendPingCh = time.After(streamPingDelay) + + err := streamer.Send(&chatpb.StreamChatEventsResponse{ + Type: &chatpb.StreamChatEventsResponse_Ping{ + Ping: &commonpb.ServerPing{ + Timestamp: timestamppb.Now(), + PingDelay: durationpb.New(streamPingDelay), + }, + }, + }) + if err != nil { + log.Tracef("stream is unhealthy ; aborting (stream=%s)", streamRef) + return status.Error(codes.Aborted, "terminating unhealthy stream") + } + case <-streamHealthCh: + log.Tracef("stream is unhealthy ; aborting (stream=%s)", streamRef) + return status.Error(codes.Aborted, "terminating unhealthy stream") + case <-ctx.Done(): + log.Tracef("stream context cancelled ; ending stream (stream=%s)", streamRef) + return status.Error(codes.Canceled, "") + } + } +} + +func (s *server) flushMessages(ctx context.Context, chatId chat.ChatId, owner *common.Account, stream *chatEventStream) { + log := s.log.WithFields(logrus.Fields{ + "method": "flushMessages", + "chat_id": chatId.String(), + "owner_account": owner.PublicKey().ToBase58(), + }) + + protoChatMessages, err := s.getProtoChatMessages( + ctx, + chatId, + owner, + query.WithCursor(query.EmptyCursor), + query.WithDirection(query.Descending), + query.WithLimit(flushMessageCount), + ) + if err == chat.ErrMessageNotFound { + return + } else if err != nil { + log.WithError(err).Warn("failure getting chat messages") + return + } + + for _, protoChatMessage := range protoChatMessages { + event := &chatpb.ChatStreamEvent{ + Type: &chatpb.ChatStreamEvent_Message{ + Message: protoChatMessage, + }, + } + if err := stream.notify(event, streamNotifyTimeout); err != nil { + log.WithError(err).Warnf("failed to notify session stream, closing streamer (stream=%p)", stream) + return + } + } +} + +func (s *server) flushPointers(ctx context.Context, chatId chat.ChatId, stream *chatEventStream) { + log := s.log.WithFields(logrus.Fields{ + "method": "flushPointers", + "chat_id": chatId.String(), + }) + + memberRecords, err := s.data.GetAllChatMembersV2(ctx, chatId) + if err == chat.ErrMemberNotFound { + return + } else if err != nil { + log.WithError(err).Warn("failure getting chat members") + return + } + + for _, memberRecord := range memberRecords { + for _, optionalPointer := range []struct { + kind chat.PointerType + value *chat.MessageId + }{ + {chat.PointerTypeDelivered, memberRecord.DeliveryPointer}, + {chat.PointerTypeRead, memberRecord.ReadPointer}, + } { + if optionalPointer.value == nil { + continue + } + + event := &chatpb.ChatStreamEvent{ + Type: &chatpb.ChatStreamEvent_Pointer{ + Pointer: &chatpb.Pointer{ + Type: optionalPointer.kind.ToProto(), + Value: optionalPointer.value.ToProto(), + MemberId: memberRecord.MemberId.ToProto(), + }, + }, + } + if err := stream.notify(event, streamNotifyTimeout); err != nil { + log.WithError(err).Warnf("failed to notify session stream, closing streamer (stream=%p)", stream) + return + } + } + } +} + +func (s *server) StartChat(ctx context.Context, req *chatpb.StartChatRequest) (*chatpb.StartChatResponse, error) { + log := s.log.WithField("method", "StartChat") + log = client.InjectLoggingMetadata(ctx, log) + + owner, err := common.NewAccountFromProto(req.Owner) + if err != nil { + log.WithError(err).Warn("invalid owner account") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("owner_account", owner.PublicKey().ToBase58()) + + signature := req.Signature + req.Signature = nil + if err := s.auth.Authenticate(ctx, owner, req, signature); err != nil { + return nil, err + } + + switch typed := req.Parameters.(type) { + case *chatpb.StartChatRequest_TipChat: + intentId := base58.Encode(typed.TipChat.IntentId.Value) + log = log.WithField("intent", intentId) + + intentRecord, err := s.data.GetIntent(ctx, intentId) + if err == intent.ErrIntentNotFound { + return &chatpb.StartChatResponse{ + Result: chatpb.StartChatResponse_INVALID_PARAMETER, + Chat: nil, + }, nil + } else if err != nil { + log.WithError(err).Warn("failure getting intent record") + return nil, status.Error(codes.Internal, "") + } + + // The intent was not for a tip. + if intentRecord.SendPrivatePaymentMetadata == nil || !intentRecord.SendPrivatePaymentMetadata.IsTip { + return &chatpb.StartChatResponse{ + Result: chatpb.StartChatResponse_INVALID_PARAMETER, + Chat: nil, + }, nil + } + + tipper, err := common.NewAccountFromPublicKeyString(intentRecord.InitiatorOwnerAccount) + if err != nil { + log.WithError(err).Warn("invalid tipper owner account") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("tipper", tipper.PublicKey().ToBase58()) + + tippee, err := common.NewAccountFromPublicKeyString(intentRecord.SendPrivatePaymentMetadata.DestinationOwnerAccount) + if err != nil { + log.WithError(err).Warn("invalid tippee owner account") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("tippee", tippee.PublicKey().ToBase58()) + + // For now, don't allow chats where you tipped yourself. + // + // todo: How do we want to handle this case? + if owner.PublicKey().ToBase58() == tipper.PublicKey().ToBase58() { + return &chatpb.StartChatResponse{ + Result: chatpb.StartChatResponse_INVALID_PARAMETER, + Chat: nil, + }, nil + } + + // Only the owner of the platform user at the time of tipping can initiate the chat. + if owner.PublicKey().ToBase58() != tippee.PublicKey().ToBase58() { + return &chatpb.StartChatResponse{ + Result: chatpb.StartChatResponse_DENIED, + Chat: nil, + }, nil + } + + // todo: This will require a refactor when we allow creation of other types of chats + switch intentRecord.SendPrivatePaymentMetadata.TipMetadata.Platform { + case transactionpb.TippedUser_TWITTER: + twitterUsername := intentRecord.SendPrivatePaymentMetadata.TipMetadata.Username + + // The owner must still own the Twitter username + ownsUsername, err := s.ownsTwitterUsername(ctx, owner, twitterUsername) + if err != nil { + log.WithError(err).Warn("failure determing twitter username ownership") + return nil, status.Error(codes.Internal, "") + } else if !ownsUsername { + return &chatpb.StartChatResponse{ + Result: chatpb.StartChatResponse_DENIED, + }, nil + } + + // todo: try to find an existing chat, but for now always create a new completely random one + var chatId chat.ChatId + rand.Read(chatId[:]) + + creationTs := time.Now() + + chatRecord := &chat.ChatRecord{ + ChatId: chatId, + ChatType: chat.ChatTypeTwoWay, + + IsVerified: true, + + CreatedAt: creationTs, + } + + memberRecords := []*chat.MemberRecord{ + { + ChatId: chatId, + MemberId: chat.GenerateMemberId(), + + Platform: chat.PlatformTwitter, + PlatformId: twitterUsername, + + JoinedAt: creationTs, + }, + { + ChatId: chatId, + MemberId: chat.GenerateMemberId(), + + Platform: chat.PlatformCode, + PlatformId: tipper.PublicKey().ToBase58(), + + JoinedAt: creationTs, + }, + } + + err = s.data.ExecuteInTx(ctx, sql.LevelDefault, func(ctx context.Context) error { + err := s.data.PutChatV2(ctx, chatRecord) + if err != nil { + return errors.Wrap(err, "error creating chat record") + } + + for _, memberRecord := range memberRecords { + err := s.data.PutChatMemberV2(ctx, memberRecord) + if err != nil { + return errors.Wrap(err, "error creating member record") + } + } + + return nil + }) + if err != nil { + log.WithError(err).Warn("failure creating new chat") + return nil, status.Error(codes.Internal, "") + } + + protoChat, err := s.toProtoChat( + ctx, + chatRecord, + memberRecords, + map[chat.Platform]string{ + chat.PlatformCode: owner.PublicKey().ToBase58(), + chat.PlatformTwitter: twitterUsername, + }, + ) + if err != nil { + log.WithError(err).Warn("failure constructing proto chat message") + return nil, status.Error(codes.Internal, "") + } + + return &chatpb.StartChatResponse{ + Result: chatpb.StartChatResponse_OK, + Chat: protoChat, + }, nil + default: + return &chatpb.StartChatResponse{ + Result: chatpb.StartChatResponse_INVALID_PARAMETER, + Chat: nil, + }, nil + } + + default: + return nil, status.Error(codes.InvalidArgument, "StartChatRequest.Parameters is nil") + } +} + +func (s *server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest) (*chatpb.SendMessageResponse, error) { + log := s.log.WithField("method", "SendMessage") + log = client.InjectLoggingMetadata(ctx, log) + + owner, err := common.NewAccountFromProto(req.Owner) + if err != nil { + log.WithError(err).Warn("invalid owner account") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("owner_account", owner.PublicKey().ToBase58()) + + chatId, err := chat.GetChatIdFromProto(req.ChatId) + if err != nil { + log.WithError(err).Warn("invalid chat id") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("chat_id", chatId.String()) + + memberId, err := chat.GetMemberIdFromProto(req.MemberId) + if err != nil { + log.WithError(err).Warn("invalid member id") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("member_id", memberId.String()) + + switch req.Content[0].Type.(type) { + case *chatpb.Content_Text, *chatpb.Content_ThankYou: + default: + return &chatpb.SendMessageResponse{ + Result: chatpb.SendMessageResponse_INVALID_CONTENT_TYPE, + }, nil + } + + signature := req.Signature + req.Signature = nil + if err := s.auth.Authenticate(ctx, owner, req, signature); err != nil { + return nil, err + } + + chatRecord, err := s.data.GetChatByIdV2(ctx, chatId) + switch err { + case nil: + case chat.ErrChatNotFound: + return &chatpb.SendMessageResponse{ + Result: chatpb.SendMessageResponse_CHAT_NOT_FOUND, + }, nil + default: + log.WithError(err).Warn("failure getting chat record") + return nil, status.Error(codes.Internal, "") + } + + switch chatRecord.ChatType { + case chat.ChatTypeTwoWay: + default: + return &chatpb.SendMessageResponse{ + Result: chatpb.SendMessageResponse_INVALID_CHAT_TYPE, + }, nil + } + + ownsChatMember, err := s.ownsChatMemberWithoutRecord(ctx, chatId, memberId, owner) + if err != nil { + log.WithError(err).Warn("failure determing chat member ownership") + return nil, status.Error(codes.Internal, "") + } else if !ownsChatMember { + return &chatpb.SendMessageResponse{ + Result: chatpb.SendMessageResponse_DENIED, + }, nil + } + + chatLock := s.chatLocks.Get(chatId[:]) + chatLock.Lock() + defer chatLock.Unlock() + + chatMessage := newProtoChatMessage(memberId, req.Content...) + + err = s.persistChatMessage(ctx, chatId, chatMessage) + if err != nil { + log.WithError(err).Warn("failure persisting chat message") + return nil, status.Error(codes.Internal, "") + } + + s.onPersistChatMessage(log, chatId, chatMessage) + + return &chatpb.SendMessageResponse{ + Result: chatpb.SendMessageResponse_OK, + Message: chatMessage, + }, nil +} + +// todo: This belongs in the common chat utility, which currently only operates on v1 chats +func (s *server) persistChatMessage(ctx context.Context, chatId chat.ChatId, protoChatMessage *chatpb.ChatMessage) error { + if err := protoChatMessage.Validate(); err != nil { + return errors.Wrap(err, "proto chat message failed validation") + } + + messageId, err := chat.GetMessageIdFromProto(protoChatMessage.MessageId) + if err != nil { + return errors.Wrap(err, "invalid message id") + } + + var senderId *chat.MemberId + if protoChatMessage.SenderId != nil { + convertedSenderId, err := chat.GetMemberIdFromProto(protoChatMessage.SenderId) + if err != nil { + return errors.Wrap(err, "invalid member id") + } + senderId = &convertedSenderId + } + + // Clear out extracted metadata as a space optimization + cloned := proto.Clone(protoChatMessage).(*chatpb.ChatMessage) + cloned.MessageId = nil + cloned.SenderId = nil + cloned.Ts = nil + cloned.Cursor = nil + + marshalled, err := proto.Marshal(cloned) + if err != nil { + return errors.Wrap(err, "error marshalling proto chat message") + } + + // todo: Doesn't incoroporate reference. We might want to promote the proto a level above the content. + messageRecord := &chat.MessageRecord{ + ChatId: chatId, + MessageId: messageId, + + Sender: senderId, + + Data: marshalled, + + IsSilent: false, + } + + err = s.data.PutChatMessageV2(ctx, messageRecord) + if err != nil { + return errors.Wrap(err, "error persiting chat message") + } + return nil +} + +func (s *server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerRequest) (*chatpb.AdvancePointerResponse, error) { + log := s.log.WithField("method", "AdvancePointer") + log = client.InjectLoggingMetadata(ctx, log) + + owner, err := common.NewAccountFromProto(req.Owner) + if err != nil { + log.WithError(err).Warn("invalid owner account") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("owner_account", owner.PublicKey().ToBase58()) + + chatId, err := chat.GetChatIdFromProto(req.ChatId) + if err != nil { + log.WithError(err).Warn("invalid chat id") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("chat_id", chatId.String()) + + memberId, err := chat.GetMemberIdFromProto(req.Pointer.MemberId) + if err != nil { + log.WithError(err).Warn("invalid member id") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("member_id", memberId.String()) + + pointerType := chat.GetPointerTypeFromProto(req.Pointer.Type) + log = log.WithField("pointer_type", pointerType.String()) + switch pointerType { + case chat.PointerTypeDelivered, chat.PointerTypeRead: + default: + return &chatpb.AdvancePointerResponse{ + Result: chatpb.AdvancePointerResponse_INVALID_POINTER_TYPE, + }, nil + } + + pointerValue, err := chat.GetMessageIdFromProto(req.Pointer.Value) + if err != nil { + log.WithError(err).Warn("invalid pointer value") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("pointer_value", pointerValue.String()) + + signature := req.Signature + req.Signature = nil + if err := s.auth.Authenticate(ctx, owner, req, signature); err != nil { + return nil, err + } + + _, err = s.data.GetChatByIdV2(ctx, chatId) + switch err { + case nil: + case chat.ErrChatNotFound: + return &chatpb.AdvancePointerResponse{ + Result: chatpb.AdvancePointerResponse_CHAT_NOT_FOUND, + }, nil + default: + log.WithError(err).Warn("failure getting chat record") + return nil, status.Error(codes.Internal, "") + } + + ownsChatMember, err := s.ownsChatMemberWithoutRecord(ctx, chatId, memberId, owner) + if err != nil { + log.WithError(err).Warn("failure determing chat member ownership") + return nil, status.Error(codes.Internal, "") + } else if !ownsChatMember { + return &chatpb.AdvancePointerResponse{ + Result: chatpb.AdvancePointerResponse_DENIED, + }, nil + } + + _, err = s.data.GetChatMessageByIdV2(ctx, chatId, pointerValue) + switch err { + case nil: + case chat.ErrMessageNotFound: + return &chatpb.AdvancePointerResponse{ + Result: chatpb.AdvancePointerResponse_MESSAGE_NOT_FOUND, + }, nil + default: + log.WithError(err).Warn("failure getting chat message record") + return nil, status.Error(codes.Internal, "") + } + + isAdvanced, err := s.data.AdvanceChatPointerV2(ctx, chatId, memberId, pointerType, pointerValue) + if err != nil { + log.WithError(err).Warn("failure advancing chat pointer") + return nil, status.Error(codes.Internal, "") + } + + if isAdvanced { + event := &chatpb.ChatStreamEvent{ + Type: &chatpb.ChatStreamEvent_Pointer{ + Pointer: req.Pointer, + }, + } + if err := s.asyncNotifyAll(chatId, event); err != nil { + log.WithError(err).Warn("failure notifying chat event") + } + } + + return &chatpb.AdvancePointerResponse{ + Result: chatpb.AdvancePointerResponse_OK, + }, nil +} + +func (s *server) RevealIdentity(ctx context.Context, req *chatpb.RevealIdentityRequest) (*chatpb.RevealIdentityResponse, error) { + log := s.log.WithField("method", "RevealIdentity") + log = client.InjectLoggingMetadata(ctx, log) + + owner, err := common.NewAccountFromProto(req.Owner) + if err != nil { + log.WithError(err).Warn("invalid owner account") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("owner_account", owner.PublicKey().ToBase58()) + + chatId, err := chat.GetChatIdFromProto(req.ChatId) + if err != nil { + log.WithError(err).Warn("invalid chat id") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("chat_id", chatId.String()) + + memberId, err := chat.GetMemberIdFromProto(req.MemberId) + if err != nil { + log.WithError(err).Warn("invalid member id") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("member_id", memberId.String()) + + signature := req.Signature + req.Signature = nil + if err := s.auth.Authenticate(ctx, owner, req, signature); err != nil { + return nil, err + } + + platform := chat.GetPlatformFromProto(req.Identity.Platform) + + log = log.WithFields(logrus.Fields{ + "platform": platform.String(), + "username": req.Identity.Username, + }) + + _, err = s.data.GetChatByIdV2(ctx, chatId) + switch err { + case nil: + case chat.ErrChatNotFound: + return &chatpb.RevealIdentityResponse{ + Result: chatpb.RevealIdentityResponse_CHAT_NOT_FOUND, + }, nil + default: + log.WithError(err).Warn("failure getting chat record") + return nil, status.Error(codes.Internal, "") + } + + memberRecord, err := s.data.GetChatMemberByIdV2(ctx, chatId, memberId) + switch err { + case nil: + case chat.ErrMemberNotFound: + return &chatpb.RevealIdentityResponse{ + Result: chatpb.RevealIdentityResponse_DENIED, + }, nil + default: + log.WithError(err).Warn("failure getting member record") + return nil, status.Error(codes.Internal, "") + } + + ownsChatMember, err := s.ownsChatMemberWithRecord(ctx, chatId, memberRecord, owner) + if err != nil { + log.WithError(err).Warn("failure determing chat member ownership") + return nil, status.Error(codes.Internal, "") + } else if !ownsChatMember { + return &chatpb.RevealIdentityResponse{ + Result: chatpb.RevealIdentityResponse_DENIED, + }, nil + } + + switch platform { + case chat.PlatformTwitter: + ownsUsername, err := s.ownsTwitterUsername(ctx, owner, req.Identity.Username) + if err != nil { + log.WithError(err).Warn("failure determing twitter username ownership") + return nil, status.Error(codes.Internal, "") + } else if !ownsUsername { + return &chatpb.RevealIdentityResponse{ + Result: chatpb.RevealIdentityResponse_DENIED, + }, nil + } + default: + return nil, status.Error(codes.InvalidArgument, "RevealIdentityRequest.Identity.Platform must be TWITTER") + } + + // Idempotent RPC call using the same platform and username + if memberRecord.Platform == platform && memberRecord.PlatformId == req.Identity.Username { + return &chatpb.RevealIdentityResponse{ + Result: chatpb.RevealIdentityResponse_OK, + }, nil + } + + // Identity was already revealed, and it isn't the specified platform and username + if memberRecord.Platform != chat.PlatformCode { + return &chatpb.RevealIdentityResponse{ + Result: chatpb.RevealIdentityResponse_DIFFERENT_IDENTITY_REVEALED, + }, nil + } + + chatLock := s.chatLocks.Get(chatId[:]) + chatLock.Lock() + defer chatLock.Unlock() + + chatMessage := newProtoChatMessage( + memberId, + &chatpb.Content{ + Type: &chatpb.Content_IdentityRevealed{ + IdentityRevealed: &chatpb.IdentityRevealedContent{ + MemberId: req.MemberId, + Identity: req.Identity, + }, + }, + }, + ) + + err = s.data.ExecuteInTx(ctx, sql.LevelDefault, func(ctx context.Context) error { + err = s.data.UpgradeChatMemberIdentityV2(ctx, chatId, memberId, platform, req.Identity.Username) + switch err { + case nil: + case chat.ErrMemberIdentityAlreadyUpgraded: + return err + default: + return errors.Wrap(err, "error updating chat member identity") + } + + err := s.persistChatMessage(ctx, chatId, chatMessage) + if err != nil { + return errors.Wrap(err, "error persisting chat message") + } + return nil + }) + + if err == nil { + s.onPersistChatMessage(log, chatId, chatMessage) + } + + switch err { + case nil: + return &chatpb.RevealIdentityResponse{ + Result: chatpb.RevealIdentityResponse_OK, + Message: chatMessage, + }, nil + case chat.ErrMemberIdentityAlreadyUpgraded: + return &chatpb.RevealIdentityResponse{ + Result: chatpb.RevealIdentityResponse_DIFFERENT_IDENTITY_REVEALED, + }, nil + default: + log.WithError(err).Warn("failure upgrading chat member identity") + return nil, status.Error(codes.Internal, "") + } +} + +func (s *server) SetMuteState(ctx context.Context, req *chatpb.SetMuteStateRequest) (*chatpb.SetMuteStateResponse, error) { + log := s.log.WithField("method", "SetMuteState") + log = client.InjectLoggingMetadata(ctx, log) + + owner, err := common.NewAccountFromProto(req.Owner) + if err != nil { + log.WithError(err).Warn("invalid owner account") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("owner_account", owner.PublicKey().ToBase58()) + + chatId, err := chat.GetChatIdFromProto(req.ChatId) + if err != nil { + log.WithError(err).Warn("invalid chat id") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("chat_id", chatId.String()) + + memberId, err := chat.GetMemberIdFromProto(req.MemberId) + if err != nil { + log.WithError(err).Warn("invalid member id") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("member_id", memberId.String()) + + signature := req.Signature + req.Signature = nil + if err := s.auth.Authenticate(ctx, owner, req, signature); err != nil { + return nil, err + } + + // todo: Use chat record to determine if muting is allowed + _, err = s.data.GetChatByIdV2(ctx, chatId) + switch err { + case nil: + case chat.ErrChatNotFound: + return &chatpb.SetMuteStateResponse{ + Result: chatpb.SetMuteStateResponse_CHAT_NOT_FOUND, + }, nil + default: + log.WithError(err).Warn("failure getting chat record") + return nil, status.Error(codes.Internal, "") + } + + ownsChatMember, err := s.ownsChatMemberWithoutRecord(ctx, chatId, memberId, owner) + if err != nil { + log.WithError(err).Warn("failure determing chat member ownership") + return nil, status.Error(codes.Internal, "") + } else if !ownsChatMember { + return &chatpb.SetMuteStateResponse{ + Result: chatpb.SetMuteStateResponse_DENIED, + }, nil + } + + err = s.data.SetChatMuteStateV2(ctx, chatId, memberId, req.IsMuted) + if err != nil { + log.WithError(err).Warn("failure setting mute state") + return nil, status.Error(codes.Internal, "") + } + + return &chatpb.SetMuteStateResponse{ + Result: chatpb.SetMuteStateResponse_OK, + }, nil +} + +func (s *server) SetSubscriptionState(ctx context.Context, req *chatpb.SetSubscriptionStateRequest) (*chatpb.SetSubscriptionStateResponse, error) { + log := s.log.WithField("method", "SetSubscriptionState") + log = client.InjectLoggingMetadata(ctx, log) + + owner, err := common.NewAccountFromProto(req.Owner) + if err != nil { + log.WithError(err).Warn("invalid owner account") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("owner_account", owner.PublicKey().ToBase58()) + + chatId, err := chat.GetChatIdFromProto(req.ChatId) + if err != nil { + log.WithError(err).Warn("invalid chat id") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("chat_id", chatId.String()) + + memberId, err := chat.GetMemberIdFromProto(req.MemberId) + if err != nil { + log.WithError(err).Warn("invalid member id") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("member_id", memberId.String()) + + signature := req.Signature + req.Signature = nil + if err := s.auth.Authenticate(ctx, owner, req, signature); err != nil { + return nil, err + } + + // todo: Use chat record to determine if unsubscribing is allowed + _, err = s.data.GetChatByIdV2(ctx, chatId) + switch err { + case nil: + case chat.ErrChatNotFound: + return &chatpb.SetSubscriptionStateResponse{ + Result: chatpb.SetSubscriptionStateResponse_CHAT_NOT_FOUND, + }, nil + default: + log.WithError(err).Warn("failure getting chat record") + return nil, status.Error(codes.Internal, "") + } + + ownsChatMember, err := s.ownsChatMemberWithoutRecord(ctx, chatId, memberId, owner) + if err != nil { + log.WithError(err).Warn("failure determing chat member ownership") + return nil, status.Error(codes.Internal, "") + } else if !ownsChatMember { + return &chatpb.SetSubscriptionStateResponse{ + Result: chatpb.SetSubscriptionStateResponse_DENIED, + }, nil + } + + err = s.data.SetChatSubscriptionStateV2(ctx, chatId, memberId, req.IsSubscribed) + if err != nil { + log.WithError(err).Warn("failure setting mute state") + return nil, status.Error(codes.Internal, "") + } + + return &chatpb.SetSubscriptionStateResponse{ + Result: chatpb.SetSubscriptionStateResponse_OK, + }, nil +} + +func (s *server) toProtoChat(ctx context.Context, chatRecord *chat.ChatRecord, memberRecords []*chat.MemberRecord, myIdentitiesByPlatform map[chat.Platform]string) (*chatpb.ChatMetadata, error) { + protoChat := &chatpb.ChatMetadata{ + ChatId: chatRecord.ChatId.ToProto(), + Type: chatRecord.ChatType.ToProto(), + Cursor: &chatpb.Cursor{Value: query.ToCursor(uint64(chatRecord.Id))}, + } + + switch chatRecord.ChatType { + case chat.ChatTypeTwoWay: + protoChat.Title = "Tip Chat" // todo: proper title with localization + + protoChat.CanMute = true + protoChat.CanUnsubscribe = true + default: + return nil, errors.Errorf("unsupported chat type: %s", chatRecord.ChatType.String()) + } + + for _, memberRecord := range memberRecords { + var isSelf bool + var identity *chatpb.ChatMemberIdentity + switch memberRecord.Platform { + case chat.PlatformCode: + myPublicKey, ok := myIdentitiesByPlatform[chat.PlatformCode] + isSelf = ok && myPublicKey == memberRecord.PlatformId + case chat.PlatformTwitter: + myTwitterUsername, ok := myIdentitiesByPlatform[chat.PlatformTwitter] + isSelf = ok && myTwitterUsername == memberRecord.PlatformId + + identity = &chatpb.ChatMemberIdentity{ + Platform: memberRecord.Platform.ToProto(), + Username: memberRecord.PlatformId, + } + default: + return nil, errors.Errorf("unsupported platform type: %s", memberRecord.Platform.String()) + } + + var pointers []*chatpb.Pointer + for _, optionalPointer := range []struct { + kind chat.PointerType + value *chat.MessageId + }{ + {chat.PointerTypeDelivered, memberRecord.DeliveryPointer}, + {chat.PointerTypeRead, memberRecord.ReadPointer}, + } { + if optionalPointer.value == nil { + continue + } + + pointers = append(pointers, &chatpb.Pointer{ + Type: optionalPointer.kind.ToProto(), + Value: optionalPointer.value.ToProto(), + MemberId: memberRecord.MemberId.ToProto(), + }) + } + + protoMember := &chatpb.ChatMember{ + MemberId: memberRecord.MemberId.ToProto(), + IsSelf: isSelf, + Identity: identity, + Pointers: pointers, + } + if protoMember.IsSelf { + protoMember.IsMuted = memberRecord.IsMuted + protoMember.IsSubscribed = !memberRecord.IsUnsubscribed + + if !memberRecord.IsUnsubscribed { + readPointer := chat.GenerateMessageIdAtTime(time.Unix(0, 0)) + if memberRecord.ReadPointer != nil { + readPointer = *memberRecord.ReadPointer + } + unreadCount, err := s.data.GetChatUnreadCountV2(ctx, chatRecord.ChatId, memberRecord.MemberId, readPointer) + if err != nil { + return nil, errors.Wrap(err, "error calculating unread count") + } + protoMember.NumUnread = unreadCount + } + } + + protoChat.Members = append(protoChat.Members, protoMember) + } + + return protoChat, nil +} + +func (s *server) getProtoChatMessages(ctx context.Context, chatId chat.ChatId, owner *common.Account, queryOptions ...query.Option) ([]*chatpb.ChatMessage, error) { + messageRecords, err := s.data.GetAllChatMessagesV2( + ctx, + chatId, + queryOptions..., + ) + if err == chat.ErrMessageNotFound { + return nil, err + } + + var userLocale *language.Tag // Loaded lazily when required + var res []*chatpb.ChatMessage + for _, messageRecord := range messageRecords { + var protoChatMessage chatpb.ChatMessage + err = proto.Unmarshal(messageRecord.Data, &protoChatMessage) + if err != nil { + return nil, errors.Wrap(err, "error unmarshalling proto chat message") + } + + ts, err := messageRecord.GetTimestamp() + if err != nil { + return nil, errors.Wrap(err, "error getting message timestamp") + } + + for _, content := range protoChatMessage.Content { + switch typed := content.Type.(type) { + case *chatpb.Content_Localized: + if userLocale == nil { + loadedUserLocale, err := s.data.GetUserLocale(ctx, owner.PublicKey().ToBase58()) + if err != nil { + return nil, errors.Wrap(err, "error getting user locale") + } + userLocale = &loadedUserLocale + } + + typed.Localized.KeyOrText = localization.LocalizeWithFallback( + *userLocale, + localization.GetLocalizationKeyForUserAgent(ctx, typed.Localized.KeyOrText), + typed.Localized.KeyOrText, + ) + } + } + + protoChatMessage.MessageId = messageRecord.MessageId.ToProto() + if messageRecord.Sender != nil { + protoChatMessage.SenderId = messageRecord.Sender.ToProto() + } + protoChatMessage.Ts = timestamppb.New(ts) + protoChatMessage.Cursor = &chatpb.Cursor{Value: messageRecord.MessageId[:]} + + res = append(res, &protoChatMessage) + } + + return res, nil +} + +func (s *server) onPersistChatMessage(log *logrus.Entry, chatId chat.ChatId, chatMessage *chatpb.ChatMessage) { + event := &chatpb.ChatStreamEvent{ + Type: &chatpb.ChatStreamEvent_Message{ + Message: chatMessage, + }, + } + if err := s.asyncNotifyAll(chatId, event); err != nil { + log.WithError(err).Warn("failure notifying chat event") + } + + // todo: send the push +} + +func (s *server) getAllIdentities(ctx context.Context, owner *common.Account) (map[chat.Platform]string, error) { + identities := map[chat.Platform]string{ + chat.PlatformCode: owner.PublicKey().ToBase58(), + } + + twitterUserame, ok, err := s.getOwnedTwitterUsername(ctx, owner) + if err != nil { + return nil, err + } + if ok { + identities[chat.PlatformTwitter] = twitterUserame + } + + return identities, nil +} + +func (s *server) ownsChatMemberWithoutRecord(ctx context.Context, chatId chat.ChatId, memberId chat.MemberId, owner *common.Account) (bool, error) { + memberRecord, err := s.data.GetChatMemberByIdV2(ctx, chatId, memberId) + switch err { + case nil: + case chat.ErrMemberNotFound: + return false, nil + default: + return false, errors.Wrap(err, "error getting member record") + } + + return s.ownsChatMemberWithRecord(ctx, chatId, memberRecord, owner) +} + +func (s *server) ownsChatMemberWithRecord(ctx context.Context, chatId chat.ChatId, memberRecord *chat.MemberRecord, owner *common.Account) (bool, error) { + switch memberRecord.Platform { + case chat.PlatformCode: + return memberRecord.PlatformId == owner.PublicKey().ToBase58(), nil + case chat.PlatformTwitter: + return s.ownsTwitterUsername(ctx, owner, memberRecord.PlatformId) + default: + return false, nil + } +} + +// todo: This logic should live elsewhere in somewhere more common +func (s *server) ownsTwitterUsername(ctx context.Context, owner *common.Account, username string) (bool, error) { + ownerTipAccount, err := owner.ToTimelockVault(timelock_token.DataVersion1, common.KinMintAccount) + if err != nil { + return false, errors.Wrap(err, "error deriving twitter tip address") + } + + twitterRecord, err := s.data.GetTwitterUserByUsername(ctx, username) + switch err { + case nil: + case twitter.ErrUserNotFound: + return false, nil + default: + return false, errors.Wrap(err, "error getting twitter user") + } + + return twitterRecord.TipAddress == ownerTipAccount.PublicKey().ToBase58(), nil +} + +// todo: This logic should live elsewhere in somewhere more common +func (s *server) getOwnedTwitterUsername(ctx context.Context, owner *common.Account) (string, bool, error) { + ownerTipAccount, err := owner.ToTimelockVault(timelock_token.DataVersion1, common.KinMintAccount) + if err != nil { + return "", false, errors.Wrap(err, "error deriving twitter tip address") + } + + twitterRecord, err := s.data.GetTwitterUserByTipAddress(ctx, ownerTipAccount.PublicKey().ToBase58()) + switch err { + case nil: + return twitterRecord.Username, true, nil + case twitter.ErrUserNotFound: + return "", false, nil + default: + return "", false, errors.Wrap(err, "error getting twitter user") + } +} + +func newProtoChatMessage(sender chat.MemberId, content ...*chatpb.Content) *chatpb.ChatMessage { + messageId := chat.GenerateMessageId() + ts, _ := messageId.GetTimestamp() + + return &chatpb.ChatMessage{ + MessageId: messageId.ToProto(), + SenderId: sender.ToProto(), + Content: content, + Ts: timestamppb.New(ts), + Cursor: &chatpb.Cursor{Value: messageId[:]}, + } +} diff --git a/pkg/code/server/grpc/chat/v2/server_test.go b/pkg/code/server/grpc/chat/v2/server_test.go new file mode 100644 index 00000000..aacc4f95 --- /dev/null +++ b/pkg/code/server/grpc/chat/v2/server_test.go @@ -0,0 +1 @@ +package chat_v2 diff --git a/pkg/code/server/grpc/chat/v2/stream.go b/pkg/code/server/grpc/chat/v2/stream.go new file mode 100644 index 00000000..3d39428d --- /dev/null +++ b/pkg/code/server/grpc/chat/v2/stream.go @@ -0,0 +1,171 @@ +package chat_v2 + +import ( + "context" + "strings" + "sync" + "time" + + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + + chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v2" + chat "github.com/code-payments/code-server/pkg/code/data/chat/v2" +) + +const ( + // todo: configurable + streamBufferSize = 64 + streamPingDelay = 5 * time.Second + streamKeepAliveRecvTimeout = 10 * time.Second + streamNotifyTimeout = time.Second +) + +type chatEventStream struct { + sync.Mutex + + closed bool + streamCh chan *chatpb.ChatStreamEvent +} + +func newChatEventStream(bufferSize int) *chatEventStream { + return &chatEventStream{ + streamCh: make(chan *chatpb.ChatStreamEvent, bufferSize), + } +} + +func (s *chatEventStream) notify(event *chatpb.ChatStreamEvent, timeout time.Duration) error { + m := proto.Clone(event).(*chatpb.ChatStreamEvent) + + s.Lock() + + if s.closed { + s.Unlock() + return errors.New("cannot notify closed stream") + } + + select { + case s.streamCh <- m: + case <-time.After(timeout): + s.Unlock() + s.close() + return errors.New("timed out sending message to streamCh") + } + + s.Unlock() + return nil +} + +func (s *chatEventStream) close() { + s.Lock() + defer s.Unlock() + + if s.closed { + return + } + + s.closed = true + close(s.streamCh) +} + +func boundedStreamChatEventsRecv( + ctx context.Context, + streamer chatpb.Chat_StreamChatEventsServer, + timeout time.Duration, +) (req *chatpb.StreamChatEventsRequest, err error) { + done := make(chan struct{}) + go func() { + req, err = streamer.Recv() + close(done) + }() + + select { + case <-done: + return req, err + case <-ctx.Done(): + return nil, status.Error(codes.Canceled, "") + case <-time.After(timeout): + return nil, status.Error(codes.DeadlineExceeded, "timed out receiving message") + } +} + +type chatEventNotification struct { + chatId chat.ChatId + event *chatpb.ChatStreamEvent + ts time.Time +} + +func (s *server) asyncNotifyAll(chatId chat.ChatId, event *chatpb.ChatStreamEvent) error { + m := proto.Clone(event).(*chatpb.ChatStreamEvent) + ok := s.chatEventChans.Send(chatId[:], &chatEventNotification{chatId, m, time.Now()}) + if !ok { + return errors.New("chat event channel is full") + } + return nil +} + +func (s *server) asyncChatEventStreamNotifier(workerId int, channel <-chan interface{}) { + log := s.log.WithFields(logrus.Fields{ + "method": "asyncChatEventStreamNotifier", + "worker": workerId, + }) + + for value := range channel { + typedValue, ok := value.(*chatEventNotification) + if !ok { + log.Warn("channel did not receive expected struct") + continue + } + + log := log.WithField("chat_id", typedValue.chatId.String()) + + if time.Since(typedValue.ts) > time.Second { + log.Warn("channel notification latency is elevated") + } + + s.streamsMu.RLock() + for key, stream := range s.streams { + if !strings.HasPrefix(key, typedValue.chatId.String()) { + continue + } + + if err := stream.notify(typedValue.event, streamNotifyTimeout); err != nil { + log.WithError(err).Warnf("failed to notify session stream, closing streamer (stream=%p)", stream) + } + } + s.streamsMu.RUnlock() + } +} + +// Very naive implementation to start +func monitorChatEventStreamHealth( + ctx context.Context, + log *logrus.Entry, + ssRef string, + streamer chatpb.Chat_StreamChatEventsServer, +) <-chan struct{} { + streamHealthChan := make(chan struct{}) + go func() { + defer close(streamHealthChan) + + for { + // todo: configurable timeout + req, err := boundedStreamChatEventsRecv(ctx, streamer, streamKeepAliveRecvTimeout) + if err != nil { + return + } + + switch req.Type.(type) { + case *chatpb.StreamChatEventsRequest_Pong: + log.Tracef("received pong from client (stream=%s)", ssRef) + default: + // Client sent something unexpected. Terminate the stream + return + } + } + }() + return streamHealthChan +} diff --git a/pkg/code/server/grpc/messaging/server.go b/pkg/code/server/grpc/messaging/server.go index c89f69ce..87a35c3e 100644 --- a/pkg/code/server/grpc/messaging/server.go +++ b/pkg/code/server/grpc/messaging/server.go @@ -17,19 +17,19 @@ import ( "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/timestamppb" + commonpb "github.com/code-payments/code-protobuf-api/generated/go/common/v1" messagingpb "github.com/code-payments/code-protobuf-api/generated/go/messaging/v1" "github.com/code-payments/code-server/pkg/cache" - "github.com/code-payments/code-server/pkg/grpc/client" - "github.com/code-payments/code-server/pkg/retry" - "github.com/code-payments/code-server/pkg/retry/backoff" - "github.com/code-payments/code-server/pkg/code/auth" "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" "github.com/code-payments/code-server/pkg/code/data/messaging" "github.com/code-payments/code-server/pkg/code/data/rendezvous" "github.com/code-payments/code-server/pkg/code/thirdparty" + "github.com/code-payments/code-server/pkg/grpc/client" + "github.com/code-payments/code-server/pkg/retry" + "github.com/code-payments/code-server/pkg/retry/backoff" ) const ( @@ -285,7 +285,7 @@ func (s *server) OpenMessageStreamWithKeepAlive(streamer messagingpb.Messaging_O err := streamer.Send(&messagingpb.OpenMessageStreamWithKeepAliveResponse{ ResponseOrPing: &messagingpb.OpenMessageStreamWithKeepAliveResponse_Ping{ - Ping: &messagingpb.ServerPing{ + Ping: &commonpb.ServerPing{ Timestamp: timestamppb.Now(), PingDelay: durationpb.New(messageStreamPingDelay), }, diff --git a/pkg/code/server/grpc/messaging/testutil.go b/pkg/code/server/grpc/messaging/testutil.go index f1e7cf97..4968ab2a 100644 --- a/pkg/code/server/grpc/messaging/testutil.go +++ b/pkg/code/server/grpc/messaging/testutil.go @@ -373,7 +373,7 @@ func (c *clientEnv) receiveMessagesInRealTime(t *testing.T, rendezvousKey *commo case *messagingpb.OpenMessageStreamWithKeepAliveResponse_Ping: require.NoError(t, streamer.streamWithKeepAlives.Send(&messagingpb.OpenMessageStreamWithKeepAliveRequest{ RequestOrPong: &messagingpb.OpenMessageStreamWithKeepAliveRequest_Pong{ - Pong: &messagingpb.ClientPong{ + Pong: &commonpb.ClientPong{ Timestamp: timestamppb.Now(), }, }, @@ -467,7 +467,7 @@ func (c *clientEnv) waitUntilStreamTerminationOrTimeout(t *testing.T, rendezvous if keepStreamAlive { require.NoError(t, streamer.streamWithKeepAlives.Send(&messagingpb.OpenMessageStreamWithKeepAliveRequest{ RequestOrPong: &messagingpb.OpenMessageStreamWithKeepAliveRequest_Pong{ - Pong: &messagingpb.ClientPong{ + Pong: &commonpb.ClientPong{ Timestamp: timestamppb.Now(), }, }, diff --git a/pkg/code/server/grpc/transaction/v2/history_test.go b/pkg/code/server/grpc/transaction/v2/history_test.go index f2cddee5..80ae0442 100644 --- a/pkg/code/server/grpc/transaction/v2/history_test.go +++ b/pkg/code/server/grpc/transaction/v2/history_test.go @@ -12,7 +12,7 @@ import ( chat_util "github.com/code-payments/code-server/pkg/code/chat" "github.com/code-payments/code-server/pkg/code/common" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" currency_lib "github.com/code-payments/code-server/pkg/currency" "github.com/code-payments/code-server/pkg/kin" timelock_token_v1 "github.com/code-payments/code-server/pkg/solana/timelock/v1" @@ -142,7 +142,7 @@ func TestPaymentHistory_HappyPath(t *testing.T) { sendingPhone.tip456KinToCodeUser(t, receivingPhone, twitterUsername).requireSuccess(t) - chatMessageRecords, err := server.data.GetAllChatMessages(server.ctx, chat.GetChatId(chat_util.CashTransactionsName, sendingPhone.parentAccount.PublicKey().ToBase58(), true)) + chatMessageRecords, err := server.data.GetAllChatMessagesV1(server.ctx, chat_v1.GetChatId(chat_util.CashTransactionsName, sendingPhone.parentAccount.PublicKey().ToBase58(), true)) require.NoError(t, err) require.Len(t, chatMessageRecords, 10) @@ -236,7 +236,7 @@ func TestPaymentHistory_HappyPath(t *testing.T) { assert.Equal(t, 32.1, protoChatMessage.Content[0].GetExchangeData().GetExact().NativeAmount) assert.Equal(t, kin.ToQuarks(321), protoChatMessage.Content[0].GetExchangeData().GetExact().Quarks) - chatMessageRecords, err = server.data.GetAllChatMessages(server.ctx, chat.GetChatId("example.com", sendingPhone.parentAccount.PublicKey().ToBase58(), true)) + chatMessageRecords, err = server.data.GetAllChatMessagesV1(server.ctx, chat_v1.GetChatId("example.com", sendingPhone.parentAccount.PublicKey().ToBase58(), true)) require.NoError(t, err) require.Len(t, chatMessageRecords, 3) @@ -267,7 +267,7 @@ func TestPaymentHistory_HappyPath(t *testing.T) { assert.Equal(t, 123.0, protoChatMessage.Content[0].GetExchangeData().GetExact().NativeAmount) assert.Equal(t, kin.ToQuarks(123), protoChatMessage.Content[0].GetExchangeData().GetExact().Quarks) - chatMessageRecords, err = server.data.GetAllChatMessages(server.ctx, chat.GetChatId(chat_util.CashTransactionsName, receivingPhone.parentAccount.PublicKey().ToBase58(), true)) + chatMessageRecords, err = server.data.GetAllChatMessagesV1(server.ctx, chat_v1.GetChatId(chat_util.CashTransactionsName, receivingPhone.parentAccount.PublicKey().ToBase58(), true)) require.NoError(t, err) require.Len(t, chatMessageRecords, 7) @@ -334,7 +334,7 @@ func TestPaymentHistory_HappyPath(t *testing.T) { assert.Equal(t, 2.1, protoChatMessage.Content[0].GetExchangeData().GetExact().NativeAmount) assert.Equal(t, kin.ToQuarks(42), protoChatMessage.Content[0].GetExchangeData().GetExact().Quarks) - chatMessageRecords, err = server.data.GetAllChatMessages(server.ctx, chat.GetChatId(chat_util.TipsName, sendingPhone.parentAccount.PublicKey().ToBase58(), true)) + chatMessageRecords, err = server.data.GetAllChatMessagesV1(server.ctx, chat_v1.GetChatId(chat_util.TipsName, sendingPhone.parentAccount.PublicKey().ToBase58(), true)) require.NoError(t, err) require.Len(t, chatMessageRecords, 1) @@ -347,7 +347,7 @@ func TestPaymentHistory_HappyPath(t *testing.T) { assert.Equal(t, 45.6, protoChatMessage.Content[0].GetExchangeData().GetExact().NativeAmount) assert.Equal(t, kin.ToQuarks(456), protoChatMessage.Content[0].GetExchangeData().GetExact().Quarks) - chatMessageRecords, err = server.data.GetAllChatMessages(server.ctx, chat.GetChatId("example.com", receivingPhone.parentAccount.PublicKey().ToBase58(), true)) + chatMessageRecords, err = server.data.GetAllChatMessagesV1(server.ctx, chat_v1.GetChatId("example.com", receivingPhone.parentAccount.PublicKey().ToBase58(), true)) require.NoError(t, err) require.Len(t, chatMessageRecords, 5) @@ -396,7 +396,7 @@ func TestPaymentHistory_HappyPath(t *testing.T) { assert.Equal(t, 12_345.0, protoChatMessage.Content[0].GetExchangeData().GetExact().NativeAmount) assert.Equal(t, kin.ToQuarks(12_345), protoChatMessage.Content[0].GetExchangeData().GetExact().Quarks) - chatMessageRecords, err = server.data.GetAllChatMessages(server.ctx, chat.GetChatId("example.com", receivingPhone.parentAccount.PublicKey().ToBase58(), false)) + chatMessageRecords, err = server.data.GetAllChatMessagesV1(server.ctx, chat_v1.GetChatId("example.com", receivingPhone.parentAccount.PublicKey().ToBase58(), false)) require.NoError(t, err) require.Len(t, chatMessageRecords, 1) @@ -409,7 +409,7 @@ func TestPaymentHistory_HappyPath(t *testing.T) { assert.Equal(t, 77.69, protoChatMessage.Content[0].GetExchangeData().GetExact().NativeAmount) assert.EqualValues(t, 77690000, protoChatMessage.Content[0].GetExchangeData().GetExact().Quarks) - chatMessageRecords, err = server.data.GetAllChatMessages(server.ctx, chat.GetChatId(chat_util.TipsName, receivingPhone.parentAccount.PublicKey().ToBase58(), true)) + chatMessageRecords, err = server.data.GetAllChatMessagesV1(server.ctx, chat_v1.GetChatId(chat_util.TipsName, receivingPhone.parentAccount.PublicKey().ToBase58(), true)) require.NoError(t, err) require.Len(t, chatMessageRecords, 1) diff --git a/pkg/code/server/grpc/transaction/v2/swap.go b/pkg/code/server/grpc/transaction/v2/swap.go index bfc76556..8bb6e230 100644 --- a/pkg/code/server/grpc/transaction/v2/swap.go +++ b/pkg/code/server/grpc/transaction/v2/swap.go @@ -22,7 +22,7 @@ import ( chat_util "github.com/code-payments/code-server/pkg/code/chat" "github.com/code-payments/code-server/pkg/code/common" "github.com/code-payments/code-server/pkg/code/data/account" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/code/localization" push_util "github.com/code-payments/code-server/pkg/code/push" currency_lib "github.com/code-payments/code-server/pkg/currency" @@ -511,7 +511,7 @@ func (s *transactionServer) bestEffortNotifyUserOfSwapInProgress(ctx context.Con // Inspect the chat history for a USDC deposited message. If that message // doesn't exist, then avoid sending the swap in progress chat message, since // it can lead to user confusion. - chatMessageRecords, err := s.data.GetAllChatMessages(ctx, chatId, query.WithDirection(query.Descending), query.WithLimit(1)) + chatMessageRecords, err := s.data.GetAllChatMessagesV1(ctx, chatId, query.WithDirection(query.Descending), query.WithLimit(1)) switch err { case nil: var protoChatMessage chatpb.ChatMessage @@ -521,12 +521,12 @@ func (s *transactionServer) bestEffortNotifyUserOfSwapInProgress(ctx context.Con } switch typed := protoChatMessage.Content[0].Type.(type) { - case *chatpb.Content_Localized: - if typed.Localized.KeyOrText != localization.ChatMessageUsdcDeposited { + case *chatpb.Content_ServerLocalized: + if typed.ServerLocalized.KeyOrText != localization.ChatMessageUsdcDeposited { return nil } } - case chat.ErrMessageNotFound: + case chat_v1.ErrMessageNotFound: default: return errors.Wrap(err, "error fetching chat messages") } diff --git a/pkg/code/server/grpc/transaction/v2/testutil.go b/pkg/code/server/grpc/transaction/v2/testutil.go index 38c42626..1da54d65 100644 --- a/pkg/code/server/grpc/transaction/v2/testutil.go +++ b/pkg/code/server/grpc/transaction/v2/testutil.go @@ -35,7 +35,7 @@ import ( code_data "github.com/code-payments/code-server/pkg/code/data" "github.com/code-payments/code-server/pkg/code/data/account" "github.com/code-payments/code-server/pkg/code/data/action" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/code/data/commitment" "github.com/code-payments/code-server/pkg/code/data/currency" "github.com/code-payments/code-server/pkg/code/data/deposit" @@ -6173,7 +6173,7 @@ func isSubmitIntentError(resp *transactionpb.SubmitIntentResponse, err error) bo return err != nil || resp.GetError() != nil } -func getProtoChatMessage(t *testing.T, record *chat.Message) *chatpb.ChatMessage { +func getProtoChatMessage(t *testing.T, record *chat_v1.Message) *chatpb.ChatMessage { var protoMessage chatpb.ChatMessage require.NoError(t, proto.Unmarshal(record.Data, &protoMessage)) return &protoMessage diff --git a/pkg/pointer/pointer.go b/pkg/pointer/pointer.go index a3f8da02..a353d347 100644 --- a/pkg/pointer/pointer.go +++ b/pkg/pointer/pointer.go @@ -32,6 +32,36 @@ func StringCopy(value *string) *string { return String(*value) } +// Uint8 returns a pointer to the provided uint8 value +func Uint8(value uint8) *uint8 { + return &value +} + +// Uint8OrDefault returns the pointer if not nil, otherwise the default value +func Uint8OrDefault(value *uint8, defaultValue uint8) *uint8 { + if value != nil { + return value + } + return &defaultValue +} + +// Uint8IfValid returns a pointer to the value if it's valid, otherwise nil +func Uint8IfValid(valid bool, value uint8) *uint8 { + if valid { + return &value + } + return nil +} + +// Uint8Copy returns a pointer that's a copy of the provided value +func Uint8Copy(value *uint8) *uint8 { + if value == nil { + return nil + } + + return Uint8(*value) +} + // Uint64 returns a pointer to the provided uint64 value func Uint64(value uint64) *uint64 { return &value