diff --git a/server/channels/store/context.go b/server/channels/store/context.go new file mode 100644 index 00000000000..7b68d01d3f0 --- /dev/null +++ b/server/channels/store/context.go @@ -0,0 +1,46 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package store + +import ( + "context" + + "github.com/mattermost/mattermost/server/public/shared/request" +) + +// storeContextKey is the base type for all context keys for the store. +type storeContextKey string + +// contextValue is a type to hold some pre-determined context values. +type contextValue string + +// Different possible values of contextValue. +const ( + useMaster contextValue = "useMaster" +) + +// WithMaster adds the context value that master DB should be selected for this request. +// +// Deprecated: This method is deprecated and there's ongoing change to use `request.CTX` across +// instead of `context.Context`. Please use `RequestContextWithMaster` instead. +func WithMaster(ctx context.Context) context.Context { + return context.WithValue(ctx, storeContextKey(useMaster), true) +} + +// RequestContextWithMaster adds the context value that master DB should be selected for this request. +func RequestContextWithMaster(c request.CTX) request.CTX { + ctx := WithMaster(c.Context()) + c = c.WithContext(ctx) + return c +} + +// HasMaster is a helper function to check whether master DB should be selected or not. +func HasMaster(ctx context.Context) bool { + if v := ctx.Value(storeContextKey(useMaster)); v != nil { + if res, ok := v.(bool); ok && res { + return true + } + } + return false +} diff --git a/server/channels/store/context_test.go b/server/channels/store/context_test.go new file mode 100644 index 00000000000..44c4e17257c --- /dev/null +++ b/server/channels/store/context_test.go @@ -0,0 +1,55 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package store + +import ( + "context" + "testing" + + "github.com/mattermost/mattermost/server/public/shared/request" + "github.com/stretchr/testify/assert" +) + +func TestContextMaster(t *testing.T) { + ctx := context.Background() + + m := WithMaster(ctx) + assert.True(t, HasMaster(m)) +} + +func TestRequestContextWithMaster(t *testing.T) { + t.Run("set and get", func(t *testing.T) { + var rctx request.CTX = request.TestContext(t) + + rctx = RequestContextWithMaster(rctx) + assert.True(t, HasMaster(rctx.Context())) + }) + + t.Run("values get copied from original context", func(t *testing.T) { + var rctx request.CTX = request.TestContext(t) + rctx = RequestContextWithMaster(rctx) + rctxCopy := rctx + + assert.True(t, HasMaster(rctx.Context())) + assert.True(t, HasMaster(rctxCopy.Context())) + }) + + t.Run("directly assigning does not cause the copy to alter the original context", func(t *testing.T) { + var rctx request.CTX = request.TestContext(t) + rctxCopy := rctx + rctxCopy = RequestContextWithMaster(rctxCopy) + + assert.False(t, HasMaster(rctx.Context())) + assert.True(t, HasMaster(rctxCopy.Context())) + }) + + t.Run("directly assigning does not cause the original context to alter the copy", func(t *testing.T) { + var rctx request.CTX = request.TestContext(t) + rctxCopy := rctx + rctx = RequestContextWithMaster(rctx) + + assert.True(t, HasMaster(rctx.Context())) + assert.False(t, HasMaster(rctxCopy.Context())) + }) +} diff --git a/server/channels/store/retrylayer/retrylayer.go b/server/channels/store/retrylayer/retrylayer.go index bb4f76fe2fe..464f1867658 100644 --- a/server/channels/store/retrylayer/retrylayer.go +++ b/server/channels/store/retrylayer/retrylayer.go @@ -2472,11 +2472,11 @@ func (s *RetryLayerChannelStore) GetTeamForChannel(channelID string) (*model.Tea } -func (s *RetryLayerChannelStore) GetTeamMembersForChannel(channelID string) ([]string, error) { +func (s *RetryLayerChannelStore) GetTeamMembersForChannel(rctx request.CTX, channelID string) ([]string, error) { tries := 0 for { - result, err := s.ChannelStore.GetTeamMembersForChannel(channelID) + result, err := s.ChannelStore.GetTeamMembersForChannel(rctx, channelID) if err == nil { return result, nil } diff --git a/server/channels/store/searchlayer/channel_layer.go b/server/channels/store/searchlayer/channel_layer.go index 60db3a8d323..b06fab1e644 100644 --- a/server/channels/store/searchlayer/channel_layer.go +++ b/server/channels/store/searchlayer/channel_layer.go @@ -47,7 +47,7 @@ func (c *SearchChannelStore) indexChannel(rctx request.CTX, channel *model.Chann } } - teamMemberIDs, err = c.GetTeamMembersForChannel(channel.Id) + teamMemberIDs, err = c.GetTeamMembersForChannel(rctx, channel.Id) if err != nil { rctx.Logger().Warn("Encountered error while indexing channel", mlog.String("channel_id", channel.Id), mlog.Err(err)) return @@ -66,6 +66,30 @@ func (c *SearchChannelStore) indexChannel(rctx request.CTX, channel *model.Chann } } +func (c *SearchChannelStore) bulkIndexChannels(rctx request.CTX, channels []*model.Channel, teamMemberIDs []string) { + // Util function to get userIDs, only for private channels + getUserIDsForPrivateChannel := func(channel *model.Channel) ([]string, error) { + if channel.Type != model.ChannelTypePrivate { + return []string{}, nil + } + return c.GetAllChannelMemberIdsByChannelId(channel.Id) + } + + for _, engine := range c.rootStore.searchEngine.GetActiveEngines() { + if !engine.IsIndexingEnabled() { + continue + } + + runIndexFn(rctx, engine, func(engineCopy searchengine.SearchEngineInterface) { + appErr := engineCopy.SyncBulkIndexChannels(rctx, channels, getUserIDsForPrivateChannel, teamMemberIDs) + if appErr != nil { + rctx.Logger().Error("Failed to synchronously bulk-index channels.", mlog.String("search_engine", engineCopy.GetName()), mlog.Err(appErr)) + return + } + }) + } +} + func (c *SearchChannelStore) Save(rctx request.CTX, channel *model.Channel, maxChannels int64, channelOptions ...model.ChannelOption) (*model.Channel, error) { newChannel, err := c.ChannelStore.Save(rctx, channel, maxChannels, channelOptions...) if err == nil { diff --git a/server/channels/store/searchlayer/layer.go b/server/channels/store/searchlayer/layer.go index 4ed2a5d6cfe..f30aa84f09e 100644 --- a/server/channels/store/searchlayer/layer.go +++ b/server/channels/store/searchlayer/layer.go @@ -9,6 +9,7 @@ import ( "github.com/mattermost/mattermost/server/public/model" "github.com/mattermost/mattermost/server/public/shared/mlog" "github.com/mattermost/mattermost/server/public/shared/request" + "github.com/mattermost/mattermost/server/public/utils" "github.com/mattermost/mattermost/server/v8/channels/store" "github.com/mattermost/mattermost/server/v8/platform/services/searchengine" ) @@ -111,6 +112,35 @@ func (s *SearchStore) indexUser(rctx request.CTX, user *model.User) { } } +func (s *SearchStore) indexChannelsForTeam(rctx request.CTX, teamID string) { + const perPage = 100 + var ( + channels []*model.Channel + ) + + channels, err := utils.Pager(func(page int) ([]*model.Channel, error) { + return s.channel.GetPublicChannelsForTeam(teamID, page*perPage, perPage) + }, perPage) + if err != nil { + rctx.Logger().Warn("Encountered error while retrieving public channels for indexing", mlog.String("team_id", teamID), mlog.Err(err)) + return + } + + if len(channels) == 0 { + return + } + + // Use master context to avoid replica lag issues when reading team members + masterRctx := store.RequestContextWithMaster(rctx) + teamMemberIDs, err := s.channel.GetTeamMembersForChannel(masterRctx, channels[0].Id) + if err != nil { + rctx.Logger().Warn("Encountered error while retrieving team members for channel", mlog.String("channel_id", channels[0].Id), mlog.Err(err)) + return + } + + s.channel.bulkIndexChannels(rctx, channels, teamMemberIDs) +} + // Runs an indexing function synchronously or asynchronously depending on the engine func runIndexFn(rctx request.CTX, engine searchengine.SearchEngineInterface, indexFn func(searchengine.SearchEngineInterface)) { if engine.IsIndexingSync() { diff --git a/server/channels/store/searchlayer/team_layer.go b/server/channels/store/searchlayer/team_layer.go index 46601e1dbb7..d0a2b9f1c50 100644 --- a/server/channels/store/searchlayer/team_layer.go +++ b/server/channels/store/searchlayer/team_layer.go @@ -17,7 +17,11 @@ type SearchTeamStore struct { func (s SearchTeamStore) SaveMember(rctx request.CTX, teamMember *model.TeamMember, maxUsersPerTeam int) (*model.TeamMember, error) { member, err := s.TeamStore.SaveMember(rctx, teamMember, maxUsersPerTeam) if err == nil { - s.rootStore.indexUserFromID(rctx, member.UserId) + // Nothing to do if search engine is not active + if s.rootStore.searchEngine.ActiveEngine() != "database" && s.rootStore.searchEngine.ActiveEngine() != "none" { + s.rootStore.indexUserFromID(rctx, member.UserId) + s.rootStore.indexChannelsForTeam(rctx, member.TeamId) + } } return member, err } @@ -33,15 +37,31 @@ func (s SearchTeamStore) UpdateMember(rctx request.CTX, teamMember *model.TeamMe func (s SearchTeamStore) RemoveMember(rctx request.CTX, teamId string, userId string) error { err := s.TeamStore.RemoveMember(rctx, teamId, userId) if err == nil { - s.rootStore.indexUserFromID(rctx, userId) + // Nothing to do if search engine is not active + if s.rootStore.searchEngine.ActiveEngine() != "database" && s.rootStore.searchEngine.ActiveEngine() != "none" { + s.rootStore.indexUserFromID(rctx, userId) + s.rootStore.indexChannelsForTeam(rctx, teamId) + } } return err } func (s SearchTeamStore) RemoveAllMembersByUser(rctx request.CTX, userId string) error { + if s.rootStore.searchEngine.ActiveEngine() != "database" && s.rootStore.searchEngine.ActiveEngine() != "none" { + memberships, err := s.TeamStore.GetTeamsForUser(rctx, userId, "", true) + if err != nil { + return err + } + for _, membership := range memberships { + s.rootStore.indexChannelsForTeam(rctx, membership.TeamId) + } + } + err := s.TeamStore.RemoveAllMembersByUser(rctx, userId) if err == nil { - s.rootStore.indexUserFromID(rctx, userId) + if s.rootStore.searchEngine.ActiveEngine() != "database" && s.rootStore.searchEngine.ActiveEngine() != "none" { + s.rootStore.indexUserFromID(rctx, userId) + } } return err } diff --git a/server/channels/store/sqlstore/channel_store.go b/server/channels/store/sqlstore/channel_store.go index 1c252f2af21..59d069aa5c4 100644 --- a/server/channels/store/sqlstore/channel_store.go +++ b/server/channels/store/sqlstore/channel_store.go @@ -3046,9 +3046,9 @@ func (s SqlChannelStore) GetMembersForUserWithCursorPagination(userId string, pe return dbMembers.ToModel(), nil } -func (s SqlChannelStore) GetTeamMembersForChannel(channelID string) ([]string, error) { +func (s SqlChannelStore) GetTeamMembersForChannel(rctx request.CTX, channelID string) ([]string, error) { teamMemberIDs := []string{} - if err := s.GetReplica().Select(&teamMemberIDs, `SELECT tm.UserId + if err := s.DBXFromContext(rctx.Context()).Select(&teamMemberIDs, `SELECT tm.UserId FROM Channels c, Teams t, TeamMembers tm WHERE c.TeamId=t.Id diff --git a/server/channels/store/sqlstore/context.go b/server/channels/store/sqlstore/context.go index d35694202ae..7509804e44b 100644 --- a/server/channels/store/sqlstore/context.go +++ b/server/channels/store/sqlstore/context.go @@ -7,17 +7,7 @@ import ( "context" "github.com/mattermost/mattermost/server/public/shared/request" -) - -// storeContextKey is the base type for all context keys for the store. -type storeContextKey string - -// contextValue is a type to hold some pre-determined context values. -type contextValue string - -// Different possible values of contextValue. -const ( - useMaster contextValue = "useMaster" + "github.com/mattermost/mattermost/server/v8/channels/store" ) // WithMaster adds the context value that master DB should be selected for this request. @@ -25,24 +15,17 @@ const ( // Deprecated: This method is deprecated and there's ongoing change to use `request.CTX` across // instead of `context.Context`. Please use `RequestContextWithMaster` instead. func WithMaster(ctx context.Context) context.Context { - return context.WithValue(ctx, storeContextKey(useMaster), true) + return store.WithMaster(ctx) } // RequestContextWithMaster adds the context value that master DB should be selected for this request. func RequestContextWithMaster(c request.CTX) request.CTX { - ctx := WithMaster(c.Context()) - c = c.WithContext(ctx) - return c + return store.RequestContextWithMaster(c) } // HasMaster is a helper function to check whether master DB should be selected or not. func HasMaster(ctx context.Context) bool { - if v := ctx.Value(storeContextKey(useMaster)); v != nil { - if res, ok := v.(bool); ok && res { - return true - } - } - return false + return store.HasMaster(ctx) } // DBXFromContext is a helper utility that returns the sqlx DB handle from a given context. diff --git a/server/channels/store/store.go b/server/channels/store/store.go index c57c4ae5d8e..7290f45c1b2 100644 --- a/server/channels/store/store.go +++ b/server/channels/store/store.go @@ -266,7 +266,7 @@ type ChannelStore interface { AnalyticsDeletedTypeCount(teamID string, channelType model.ChannelType) (int64, error) AnalyticsCountAll(teamID string) (map[model.ChannelType]int64, error) GetMembersForUser(teamID string, userID string) (model.ChannelMembers, error) - GetTeamMembersForChannel(channelID string) ([]string, error) + GetTeamMembersForChannel(rctx request.CTX, channelID string) ([]string, error) GetMembersForUserWithPagination(userID string, page, perPage int) (model.ChannelMembersWithTeamData, error) GetMembersForUserWithCursorPagination(userId string, perPage int, fromChanneID string) (model.ChannelMembersWithTeamData, error) Autocomplete(rctx request.CTX, userID, term string, includeDeleted, isGuest bool) (model.ChannelListWithTeamData, error) diff --git a/server/channels/store/storetest/mocks/ChannelStore.go b/server/channels/store/storetest/mocks/ChannelStore.go index f271d219fcd..7f970734b6c 100644 --- a/server/channels/store/storetest/mocks/ChannelStore.go +++ b/server/channels/store/storetest/mocks/ChannelStore.go @@ -2148,9 +2148,9 @@ func (_m *ChannelStore) GetTeamForChannel(channelID string) (*model.Team, error) return r0, r1 } -// GetTeamMembersForChannel provides a mock function with given fields: channelID -func (_m *ChannelStore) GetTeamMembersForChannel(channelID string) ([]string, error) { - ret := _m.Called(channelID) +// GetTeamMembersForChannel provides a mock function with given fields: rctx, channelID +func (_m *ChannelStore) GetTeamMembersForChannel(rctx request.CTX, channelID string) ([]string, error) { + ret := _m.Called(rctx, channelID) if len(ret) == 0 { panic("no return value specified for GetTeamMembersForChannel") @@ -2158,19 +2158,19 @@ func (_m *ChannelStore) GetTeamMembersForChannel(channelID string) ([]string, er var r0 []string var r1 error - if rf, ok := ret.Get(0).(func(string) ([]string, error)); ok { - return rf(channelID) + if rf, ok := ret.Get(0).(func(request.CTX, string) ([]string, error)); ok { + return rf(rctx, channelID) } - if rf, ok := ret.Get(0).(func(string) []string); ok { - r0 = rf(channelID) + if rf, ok := ret.Get(0).(func(request.CTX, string) []string); ok { + r0 = rf(rctx, channelID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]string) } } - if rf, ok := ret.Get(1).(func(string) error); ok { - r1 = rf(channelID) + if rf, ok := ret.Get(1).(func(request.CTX, string) error); ok { + r1 = rf(rctx, channelID) } else { r1 = ret.Error(1) } diff --git a/server/channels/store/timerlayer/timerlayer.go b/server/channels/store/timerlayer/timerlayer.go index 5130bf3ce70..6887d45059e 100644 --- a/server/channels/store/timerlayer/timerlayer.go +++ b/server/channels/store/timerlayer/timerlayer.go @@ -2033,10 +2033,10 @@ func (s *TimerLayerChannelStore) GetTeamForChannel(channelID string) (*model.Tea return result, err } -func (s *TimerLayerChannelStore) GetTeamMembersForChannel(channelID string) ([]string, error) { +func (s *TimerLayerChannelStore) GetTeamMembersForChannel(rctx request.CTX, channelID string) ([]string, error) { start := time.Now() - result, err := s.ChannelStore.GetTeamMembersForChannel(channelID) + result, err := s.ChannelStore.GetTeamMembersForChannel(rctx, channelID) elapsed := float64(time.Since(start)) / float64(time.Second) if s.Root.Metrics != nil { diff --git a/server/enterprise/elasticsearch/common/common.go b/server/enterprise/elasticsearch/common/common.go index 4dcdcbf29c0..3070f212739 100644 --- a/server/enterprise/elasticsearch/common/common.go +++ b/server/enterprise/elasticsearch/common/common.go @@ -34,8 +34,17 @@ const ( // At the moment, this number is hardcoded. If needed, we can expose // this to the config. BulkFlushInterval = 5 * time.Second + + // Size of the largest request to be done, in bytes + BulkFlushBytes = 10 * 1024 * 1024 // 10 MiB ) +type BulkSettings struct { + FlushBytes int + FlushInterval time.Duration + FlushNumReqs int +} + var ( urlRe = regexp.MustCompile(URLRegexpRE) markdownLinkRe = regexp.MustCompile(URLMarkdownLinkRE) diff --git a/server/enterprise/elasticsearch/common/indexing_job.go b/server/enterprise/elasticsearch/common/indexing_job.go index 5482ccd81ec..983521268ce 100644 --- a/server/enterprise/elasticsearch/common/indexing_job.go +++ b/server/enterprise/elasticsearch/common/indexing_job.go @@ -569,7 +569,8 @@ func BulkIndexChannels(config *model.Config, } } - teamMemberIDs, err := store.Channel().GetTeamMembersForChannel(channel.Id) + rctx := request.EmptyContext(logger) + teamMemberIDs, err := store.Channel().GetTeamMembersForChannel(rctx, channel.Id) if err != nil { return nil, model.NewAppError("IndexerWorker.BulkIndexChannels", "ent.elasticsearch.getAllTeamMembers.error", nil, "", http.StatusInternalServerError).Wrap(err) } diff --git a/server/enterprise/elasticsearch/common/indexing_job_test.go b/server/enterprise/elasticsearch/common/indexing_job_test.go index 24c302f1914..a8c9597459e 100644 --- a/server/enterprise/elasticsearch/common/indexing_job_test.go +++ b/server/enterprise/elasticsearch/common/indexing_job_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/mattermost/mattermost/server/public/model" @@ -37,8 +38,8 @@ func TestBulkIndexChannelsWithDeletedChannels(t *testing.T) { // Since these are open channels, GetAllChannelMemberIdsByChannelId won't be called // But GetTeamMembersForChannel will be called for both channels - mockChannelStore.On("GetTeamMembersForChannel", "ch1").Return([]string{"team1"}, nil) - mockChannelStore.On("GetTeamMembersForChannel", "ch2").Return([]string{"team1"}, nil) + mockChannelStore.On("GetTeamMembersForChannel", mock.AnythingOfType("*request.Context"), "ch1").Return([]string{"team1"}, nil) + mockChannelStore.On("GetTeamMembersForChannel", mock.AnythingOfType("*request.Context"), "ch2").Return([]string{"team1"}, nil) // Track which channels were actually indexed indexedChannels := make(map[string]bool) diff --git a/server/enterprise/elasticsearch/elasticsearch/bulk.go b/server/enterprise/elasticsearch/elasticsearch/bulk.go index 9488e68e8f1..342dd15d291 100644 --- a/server/enterprise/elasticsearch/elasticsearch/bulk.go +++ b/server/enterprise/elasticsearch/elasticsearch/bulk.go @@ -4,130 +4,58 @@ package elasticsearch import ( - "context" - "sync" + "fmt" "time" elastic "github.com/elastic/go-elasticsearch/v8" - "github.com/elastic/go-elasticsearch/v8/typedapi/core/bulk" - "github.com/elastic/go-elasticsearch/v8/typedapi/types" - "github.com/mattermost/mattermost/server/public/model" + esTypes "github.com/elastic/go-elasticsearch/v8/typedapi/types" "github.com/mattermost/mattermost/server/public/shared/mlog" "github.com/mattermost/mattermost/server/v8/enterprise/elasticsearch/common" ) -type Bulk struct { - mut sync.Mutex - - logger mlog.LoggerIFace - client *elastic.TypedClient - bulkClient *bulk.Bulk - settings model.ElasticsearchSettings - - quitFlusher chan struct{} - quitFlusherWg sync.WaitGroup - - pendingRequests int -} - -func NewBulk(settings model.ElasticsearchSettings, +type BulkClient interface { + IndexOp(op esTypes.IndexOperation, doc any) error + DeleteOp(op esTypes.DeleteOperation) error + Flush() error + Stop() error +} + +// NewBulk returns a BulkClient, with the specific implementation depending on +// the specified thresholds in bulkSettings. +// NewBulk will return an error if bulkSettings.FlushNumReqs and +// bulkSettings.FlushBytes are both non-zero: the support of these thresholds +// by the implementations of BulkClient is mutually exclusive. +func NewBulk(bulkSettings common.BulkSettings, + client *elastic.TypedClient, + reqTimeout time.Duration, logger mlog.LoggerIFace, - client *elastic.TypedClient) *Bulk { - b := &Bulk{ - settings: settings, - logger: logger, - client: client, - bulkClient: client.Bulk(), - quitFlusher: make(chan struct{}), - } - - b.quitFlusherWg.Add(1) - go b.periodicFlusher() - - return b -} - -// IndexOp is a helper function to add an IndexOperation to the current bulk request. -// doc argument can be a []byte, json.RawMessage or a struct. -func (r *Bulk) IndexOp(op types.IndexOperation, doc any) error { - r.mut.Lock() - defer r.mut.Unlock() - - if err := r.bulkClient.IndexOp(op, doc); err != nil { - return err - } - - return r.flushIfNecessary() -} - -// DeleteOp is a helper function to add a DeleteOperation to the current bulk request. -func (r *Bulk) DeleteOp(op types.DeleteOperation) error { - r.mut.Lock() - defer r.mut.Unlock() - - if err := r.bulkClient.DeleteOp(op); err != nil { - return err - } - - return r.flushIfNecessary() -} - -// flushIfNecessary flushes the pending buffer if needed. -// It MUST be called with an already acquired mutex. -func (r *Bulk) flushIfNecessary() error { - r.pendingRequests++ - - if r.pendingRequests > *r.settings.LiveIndexingBatchSize { - return r._flush() +) (BulkClient, error) { + if bulkSettings.FlushBytes == 0 && + bulkSettings.FlushInterval == 0 && + bulkSettings.FlushNumReqs == 0 { + return nil, fmt.Errorf("at least one of FlushBytes, FlushInterval or FlushNumReqs should be non-zero") } - - return nil -} - -func (r *Bulk) Stop() error { - r.mut.Lock() - defer r.mut.Unlock() - r.logger.Info("Stopping Bulk processor") - - if r.pendingRequests > 0 { - return r._flush() + if bulkSettings.FlushBytes > 0 && bulkSettings.FlushNumReqs > 0 { + return nil, fmt.Errorf( + "one of bulkSettings.FlushBytes (set to %d) or bulkSettings.FlushNumReqs (set to %d) should be zero", + bulkSettings.FlushBytes, + bulkSettings.FlushNumReqs, + ) } - close(r.quitFlusher) - r.quitFlusherWg.Wait() - - return nil -} - -func (r *Bulk) periodicFlusher() { - defer r.quitFlusherWg.Done() - - for { - select { - case <-time.After(common.BulkFlushInterval): - r.mut.Lock() - if r.pendingRequests > 0 { - if err := r._flush(); err != nil { - r.logger.Warn("Error flushing live indexing buffer", mlog.Err(err)) - } - } - r.mut.Unlock() - case <-r.quitFlusher: - return + var bulkClient BulkClient + var err error + if bulkSettings.FlushBytes > 0 { + bulkClient, err = NewDataBulkClient(bulkSettings, client, reqTimeout, logger) + if err != nil { + return nil, err + } + } else { + bulkClient, err = NewReqBulkClient(bulkSettings, client, reqTimeout, logger) + if err != nil { + return nil, err } } -} - -// _flush MUST be called with an acquired lock. -func (r *Bulk) _flush() error { - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*r.settings.RequestTimeoutSeconds)*time.Second) - defer cancel() - - _, err := r.bulkClient.Do(ctx) - if err != nil { - return err - } - r.pendingRequests = 0 - return nil + return bulkClient, nil } diff --git a/server/enterprise/elasticsearch/elasticsearch/bulk_client_data.go b/server/enterprise/elasticsearch/elasticsearch/bulk_client_data.go new file mode 100644 index 00000000000..2da7664335c --- /dev/null +++ b/server/enterprise/elasticsearch/elasticsearch/bulk_client_data.go @@ -0,0 +1,184 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.enterprise for license information. + +package elasticsearch + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "sync" + "time" + + elastic "github.com/elastic/go-elasticsearch/v8" + "github.com/elastic/go-elasticsearch/v8/esutil" + esTypes "github.com/elastic/go-elasticsearch/v8/typedapi/types" + "github.com/mattermost/mattermost/server/public/shared/mlog" + "github.com/mattermost/mattermost/server/v8/enterprise/elasticsearch/common" +) + +// DataBulkClient is an Elasticsearch bulk client based on the +// go-elasticsearch/v8/esutil.BulkIndexer type. +// It supports time- and size-based thresholds, but not a threshold on number +// of requests. +type DataBulkClient struct { + mut sync.Mutex + + indexer esutil.BulkIndexer + client *elastic.TypedClient + bulkSettings common.BulkSettings + reqTimeout time.Duration + logger mlog.LoggerIFace +} + +func NewDataBulkClient(bulkSettings common.BulkSettings, + client *elastic.TypedClient, + reqTimeout time.Duration, + logger mlog.LoggerIFace, +) (*DataBulkClient, error) { + if bulkSettings.FlushNumReqs > 0 { + return nil, fmt.Errorf("DataBulkClient does not support a threshold on number of requests") + } + + indexer, err := newIndexer(client, bulkSettings, logger) + if err != nil { + return nil, err + } + + return &DataBulkClient{ + indexer: indexer, + client: client, + bulkSettings: bulkSettings, + reqTimeout: reqTimeout, + logger: logger, + }, nil +} + +func newIndexer(client *elastic.TypedClient, bulkSettings common.BulkSettings, logger mlog.LoggerIFace) (esutil.BulkIndexer, error) { + // A zeroed FlushInterval means that there should be no time-based flush, + // but esutil.BulkIndexer defaults to 30 seconds if the interval is zero, + // so we pick a large enough interval + interval := bulkSettings.FlushInterval + if interval == 0 { + interval = 1 * time.Hour + } + + return esutil.NewBulkIndexer(esutil.BulkIndexerConfig{ + FlushBytes: bulkSettings.FlushBytes, + FlushInterval: interval, + Client: client, + OnError: func(ctx context.Context, err error) { + logger.Error("indexer error", mlog.Err(err)) + }, + OnFlushStart: func(ctx context.Context) context.Context { + logger.Debug("elasticsearch bulk indexer flush started") + return ctx + }, + OnFlushEnd: func(context.Context) { + logger.Debug("elasticsearch bulk indexer flush ended") + }, + }) +} + +func (b *DataBulkClient) onSuccess(_ context.Context, item esutil.BulkIndexerItem, _ esutil.BulkIndexerResponseItem) { + b.logger.Info("successfully added new bulk operation", + mlog.String("action", item.Action), + mlog.String("index", item.Index), + mlog.String("document_id", item.DocumentID), + ) +} + +func (b *DataBulkClient) onFailure(_ context.Context, item esutil.BulkIndexerItem, _ esutil.BulkIndexerResponseItem, err error) { + b.logger.Info("failed to add new bulk operation", + mlog.String("action", item.Action), + mlog.String("index", item.Index), + mlog.String("document_id", item.DocumentID), + mlog.Err(err), + ) +} + +func (b *DataBulkClient) IndexOp(op esTypes.IndexOperation, doc any) error { + b.mut.Lock() + defer b.mut.Unlock() + + var bodyReader io.ReadSeeker + switch v := doc.(type) { + case []byte: + bodyReader = bytes.NewReader(v) + case json.RawMessage: + bodyReader = bytes.NewReader(v) + default: + body, err := json.Marshal(doc) + if err != nil { + return err + } + bodyReader = bytes.NewReader(body) + } + + ctx, cancel := context.WithTimeout(context.Background(), b.reqTimeout) + defer cancel() + + return b.indexer.Add(ctx, esutil.BulkIndexerItem{ + Index: *op.Index_, + Action: "index", + DocumentID: *op.Id_, + Body: bodyReader, + OnSuccess: b.onSuccess, + OnFailure: b.onFailure, + }) +} +func (b *DataBulkClient) DeleteOp(op esTypes.DeleteOperation) error { + b.mut.Lock() + defer b.mut.Unlock() + + ctx, cancel := context.WithTimeout(context.Background(), b.reqTimeout) + defer cancel() + + return b.indexer.Add(ctx, esutil.BulkIndexerItem{ + Index: *op.Index_, + Action: "delete", + DocumentID: *op.Id_, + Body: nil, + OnSuccess: b.onSuccess, + OnFailure: b.onFailure, + }) +} + +func (b *DataBulkClient) _stop() error { + ctx, cancel := context.WithTimeout(context.Background(), b.reqTimeout) + defer cancel() + + return b.indexer.Close(ctx) +} + +func (b *DataBulkClient) Flush() error { + b.mut.Lock() + defer b.mut.Unlock() + + // The esutil.BulkIndexer cannot be manually flushed, but it can be closed, + // which does flush all the contents. + if err := b._stop(); err != nil { + return fmt.Errorf("failed to close the BulkIndexer: %w", err) + } + + // But calling Close essentially kills all the running processes, so we have + // to create a new one in order to restart it + indexer, err := newIndexer(b.client, b.bulkSettings, b.logger) + if err != nil { + return fmt.Errorf("failed to restart the BulkIndexer: %w", err) + } + b.indexer = indexer + + return nil +} + +func (b *DataBulkClient) Stop() error { + b.mut.Lock() + defer b.mut.Unlock() + + b.logger.Info("Stopping Bulk processor") + + return b._stop() +} diff --git a/server/enterprise/elasticsearch/elasticsearch/bulk_client_data_test.go b/server/enterprise/elasticsearch/elasticsearch/bulk_client_data_test.go new file mode 100644 index 00000000000..1ad6af56a40 --- /dev/null +++ b/server/enterprise/elasticsearch/elasticsearch/bulk_client_data_test.go @@ -0,0 +1,355 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.enterprise for license information. + +package elasticsearch + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/elastic/go-elasticsearch/v8/esutil" + "github.com/elastic/go-elasticsearch/v8/typedapi/types" + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/v8/channels/api4" + "github.com/mattermost/mattermost/server/v8/enterprise/elasticsearch/common" + "github.com/stretchr/testify/require" +) + +// setupDataBulkClient creates a test data bulk client with common setup +func setupDataBulkClient(t *testing.T, flushBytes int, flushInterval time.Duration) (*DataBulkClient, *api4.TestHelper) { + th := api4.SetupEnterprise(t) + + client := createTestClient(t, th.Context, th.App.Config(), th.App.FileBackend()) + bulkClient, err := NewDataBulkClient( + common.BulkSettings{ + FlushBytes: flushBytes, + FlushInterval: flushInterval, + FlushNumReqs: 0, // DataBulkClient doesn't support FlushNumReqs + }, + client, + time.Duration(*th.App.Config().ElasticsearchSettings.RequestTimeoutSeconds)*time.Second, + th.Server.Platform().Log()) + require.NoError(t, err) + + return bulkClient, th +} + +func flushAndGetStats(t *testing.T, b *DataBulkClient) esutil.BulkIndexerStats { + t.Helper() + + // Close the indexer to flush + err := b.indexer.Close(context.Background()) + require.NoError(t, err) + + // Get the stats + stats := b.indexer.Stats() + + // Restart the indexer + newIdxr, err := newIndexer(b.client, b.bulkSettings, b.logger) + require.NoError(t, err) + b.indexer = newIdxr + + return stats +} + +func TestDataIndexOp(t *testing.T) { + t.Run("single index operation", func(t *testing.T) { + bulkClient, th := setupDataBulkClient(t, 1024*1024, 0) // 1MB flush threshold + defer th.TearDown() + t.Cleanup(func() { + err := bulkClient.Stop() + require.NoError(t, err) + }) + + post := createTestPost(t, "test message") + + err := bulkClient.IndexOp(types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + + // Check that the request got added + stats := bulkClient.indexer.Stats() + require.Equal(t, 1, int(stats.NumAdded)) + require.Equal(t, 0, int(stats.NumIndexed)) + + // Flush, and check that the document was indexed + stats = flushAndGetStats(t, bulkClient) + require.Equal(t, 1, int(stats.NumIndexed)) + }) + + t.Run("multiple index operations", func(t *testing.T) { + bulkClient, th := setupDataBulkClient(t, 1024*1024, 0) // 1MB flush threshold + defer th.TearDown() + t.Cleanup(func() { + err := bulkClient.Stop() + require.NoError(t, err) + }) + + for range 5 { + post := createTestPost(t, "test message") + err := bulkClient.IndexOp(types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + } + + // Check that the requests got added + stats := bulkClient.indexer.Stats() + require.Equal(t, 5, int(stats.NumAdded)) + + // Flush, and check that the documents were indexed + stats = flushAndGetStats(t, bulkClient) + require.Equal(t, 5, int(stats.NumIndexed)) + }) + + t.Run("index operation with json.RawMessage", func(t *testing.T) { + bulkClient, th := setupDataBulkClient(t, 1024*1024, 0) // 1MB flush threshold + defer th.TearDown() + t.Cleanup(func() { + err := bulkClient.Stop() + require.NoError(t, err) + }) + + docId := model.NewId() + jsonData := []byte(`{"message": "test raw message"}`) + + err := bulkClient.IndexOp(types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(docId), + }, jsonData) + require.NoError(t, err) + + // Check that the request got added + stats := bulkClient.indexer.Stats() + require.Equal(t, 1, int(stats.NumAdded)) + + // Flush, and check that the document was indexed + stats = flushAndGetStats(t, bulkClient) + require.Equal(t, 1, int(stats.NumIndexed)) + }) + + t.Run("index operation with byte slice", func(t *testing.T) { + bulkClient, th := setupDataBulkClient(t, 1024*1024, 0) // 1MB flush threshold + defer th.TearDown() + t.Cleanup(func() { + err := bulkClient.Stop() + require.NoError(t, err) + }) + + docId := model.NewId() + data := []byte(`{"message": "test byte slice"}`) + + err := bulkClient.IndexOp(types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(docId), + }, data) + require.NoError(t, err) + + // Check that the request got added + stats := bulkClient.indexer.Stats() + require.Equal(t, 1, int(stats.NumAdded)) + + // Flush, and check that the document was indexed + stats = flushAndGetStats(t, bulkClient) + require.Equal(t, 1, int(stats.NumIndexed)) + }) +} + +func TestDataDeleteOp(t *testing.T) { + t.Run("single delete operation", func(t *testing.T) { + bulkClient, th := setupDataBulkClient(t, 1024*1024, 0) // 1MB flush threshold + defer th.TearDown() + t.Cleanup(func() { + err := bulkClient.Stop() + require.NoError(t, err) + }) + + // Index a new post and flush + post := createTestPost(t, "test message") + err := bulkClient.IndexOp(types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + + require.NoError(t, bulkClient.Flush()) + + err = bulkClient.DeleteOp(types.DeleteOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }) + require.NoError(t, err) + + // Check that the request got added + stats := bulkClient.indexer.Stats() + require.Equal(t, 1, int(stats.NumAdded)) + require.Equal(t, 0, int(stats.NumDeleted)) + + // Flush, and check that the document was deleted + stats = flushAndGetStats(t, bulkClient) + fmt.Println(stats) + require.Equal(t, 1, int(stats.NumDeleted)) + }) + + t.Run("multiple delete operations", func(t *testing.T) { + bulkClient, th := setupDataBulkClient(t, 1024*1024, 0) // 1MB flush threshold + defer th.TearDown() + t.Cleanup(func() { + err := bulkClient.Stop() + require.NoError(t, err) + }) + + posts := make([]string, 3) + + // Index three new posts and flush + for i := range 3 { + post := createTestPost(t, "test message") + err := bulkClient.IndexOp(types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + posts[i] = post.Id + } + require.NoError(t, bulkClient.Flush()) + + for _, id := range posts { + err := bulkClient.DeleteOp(types.DeleteOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(id), + }) + require.NoError(t, err) + } + + // Check that the requests got added + stats := bulkClient.indexer.Stats() + require.Equal(t, 3, int(stats.NumAdded)) + require.Equal(t, 0, int(stats.NumDeleted)) + + // Flush, and check that the documents were deleted + stats = flushAndGetStats(t, bulkClient) + fmt.Println(stats) + require.Equal(t, 3, int(stats.NumDeleted)) + }) +} + +func TestDataFlush(t *testing.T) { + t.Run("flush with pending operations", func(t *testing.T) { + bulkClient, th := setupDataBulkClient(t, 1024*1024, 0) // 1MB flush threshold + defer th.TearDown() + t.Cleanup(func() { + err := bulkClient.Stop() + require.NoError(t, err) + }) + + post := createTestPost(t, "test message") + + err := bulkClient.IndexOp(types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + + err = bulkClient.Flush() + require.NoError(t, err) + }) + + t.Run("flush with no pending operations", func(t *testing.T) { + bulkClient, th := setupDataBulkClient(t, 1024*1024, 0) // 1MB flush threshold + defer th.TearDown() + t.Cleanup(func() { + err := bulkClient.Stop() + require.NoError(t, err) + }) + + err := bulkClient.Flush() + require.NoError(t, err) + }) +} + +func TestDataStop(t *testing.T) { + t.Run("stop with pending operations", func(t *testing.T) { + bulkClient, th := setupDataBulkClient(t, 1024*1024, 0) // 1MB flush threshold + defer th.TearDown() + + post := createTestPost(t, "test message") + + err := bulkClient.IndexOp(types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + + err = bulkClient.Stop() + require.NoError(t, err) + }) + + t.Run("stop with no pending operations", func(t *testing.T) { + bulkClient, th := setupDataBulkClient(t, 1024*1024, 0) // 1MB flush threshold + defer th.TearDown() + + err := bulkClient.Stop() + require.NoError(t, err) + }) + + t.Run("stop with periodic flusher", func(t *testing.T) { + bulkClient, th := setupDataBulkClient(t, 1024*1024, 100*time.Millisecond) + defer th.TearDown() + + post := createTestPost(t, "test message") + + err := bulkClient.IndexOp(types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + + // Stop should flush pending operations and stop the periodic flusher + err = bulkClient.Stop() + require.NoError(t, err) + }) +} + +func TestDataNewDataBulkClient(t *testing.T) { + th := api4.SetupEnterprise(t) + defer th.TearDown() + + client := createTestClient(t, th.Context, th.App.Config(), th.App.FileBackend()) + + t.Run("valid configuration", func(t *testing.T) { + bulkClient, err := NewDataBulkClient( + common.BulkSettings{ + FlushBytes: 1024, + FlushInterval: 100 * time.Millisecond, + FlushNumReqs: 0, + }, + client, + time.Duration(*th.App.Config().ElasticsearchSettings.RequestTimeoutSeconds)*time.Second, + th.Server.Platform().Log()) + require.NoError(t, err) + require.NotNil(t, bulkClient) + + err = bulkClient.Stop() + require.NoError(t, err) + }) + + t.Run("invalid configuration with FlushNumReqs", func(t *testing.T) { + bulkClient, err := NewDataBulkClient( + common.BulkSettings{ + FlushBytes: 1024, + FlushInterval: 100 * time.Millisecond, + FlushNumReqs: 10, // This should cause an error + }, + client, + time.Duration(*th.App.Config().ElasticsearchSettings.RequestTimeoutSeconds)*time.Second, + th.Server.Platform().Log()) + require.Error(t, err) + require.Nil(t, bulkClient) + require.Contains(t, err.Error(), "DataBulkClient does not support a threshold on number of requests") + }) +} diff --git a/server/enterprise/elasticsearch/elasticsearch/bulk_client_req.go b/server/enterprise/elasticsearch/elasticsearch/bulk_client_req.go new file mode 100644 index 00000000000..80bf5df94fd --- /dev/null +++ b/server/enterprise/elasticsearch/elasticsearch/bulk_client_req.go @@ -0,0 +1,164 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.enterprise for license information. + +package elasticsearch + +import ( + "context" + "fmt" + "sync" + "time" + + elastic "github.com/elastic/go-elasticsearch/v8" + "github.com/elastic/go-elasticsearch/v8/typedapi/core/bulk" + "github.com/elastic/go-elasticsearch/v8/typedapi/types" + "github.com/mattermost/mattermost/server/public/shared/mlog" + "github.com/mattermost/mattermost/server/v8/enterprise/elasticsearch/common" +) + +// ReqBulkClient is an Elasticsearch bulk client based on the +// go-elasticsearch/v8/typedapi/code/bulk.Bulk type. +// It supports time- and number-of-requests-based thresholds, but not a +// threshold on the size of the request. +type ReqBulkClient struct { + mut sync.Mutex + + indexer *bulk.Bulk + client *elastic.TypedClient + bulkSettings common.BulkSettings + reqTimeout time.Duration + logger mlog.LoggerIFace + + quitFlusher chan struct{} + quitFlusherWg sync.WaitGroup + pendingRequests int +} + +func NewReqBulkClient(bulkSettings common.BulkSettings, + client *elastic.TypedClient, + reqTimeout time.Duration, + logger mlog.LoggerIFace, +) (*ReqBulkClient, error) { + if bulkSettings.FlushBytes > 0 { + return nil, fmt.Errorf("BulkClientBasic does not support a threshold on bytes") + } + + b := &ReqBulkClient{ + indexer: client.Bulk(), + client: client, + bulkSettings: bulkSettings, + reqTimeout: reqTimeout, + logger: logger, + + quitFlusher: make(chan struct{}), + } + + if bulkSettings.FlushInterval > 0 { + b.quitFlusherWg.Add(1) + go b.periodicFlusher() + } + + return b, nil +} + +// IndexOp is a helper function to add an IndexOperation to the current bulk request. +// doc argument can be a []byte, json.RawMessage or a struct. +func (r *ReqBulkClient) IndexOp(op types.IndexOperation, doc any) error { + r.mut.Lock() + defer r.mut.Unlock() + + if err := r.indexer.IndexOp(op, doc); err != nil { + return err + } + + return r.flushIfNecessary() +} + +// DeleteOp is a helper function to add a DeleteOperation to the current bulk request. +func (r *ReqBulkClient) DeleteOp(op types.DeleteOperation) error { + r.mut.Lock() + defer r.mut.Unlock() + + if err := r.indexer.DeleteOp(op); err != nil { + return err + } + + return r.flushIfNecessary() +} + +// flushIfNecessary flushes the pending buffer if needed. +// It MUST be called with an already acquired mutex. +func (r *ReqBulkClient) flushIfNecessary() error { + r.pendingRequests++ + + // Check number of requests threshold, only if specified + if r.bulkSettings.FlushNumReqs > 0 { + if r.pendingRequests > r.bulkSettings.FlushNumReqs { + return r._flush() + } + } + + return nil +} + +func (r *ReqBulkClient) Stop() error { + r.mut.Lock() + defer r.mut.Unlock() + + r.logger.Info("Stopping Bulk processor") + + if r.pendingRequests > 0 { + return r._flush() + } + + if r.bulkSettings.FlushInterval > 0 { + close(r.quitFlusher) + r.quitFlusherWg.Wait() + } + + return nil +} + +func (r *ReqBulkClient) periodicFlusher() { + defer r.quitFlusherWg.Done() + + for { + select { + case <-time.After(r.bulkSettings.FlushInterval): + r.mut.Lock() + if r.pendingRequests > 0 { + if err := r._flush(); err != nil { + r.logger.Warn("Error flushing live indexing buffer", mlog.Err(err)) + } + } + r.mut.Unlock() + case <-r.quitFlusher: + return + } + } +} + +// _flush MUST be called with an acquired lock. +func (r *ReqBulkClient) _flush() error { + if r.pendingRequests == 0 { + return nil + } + + ctx, cancel := context.WithTimeout(context.Background(), r.reqTimeout) + defer cancel() + + _, err := r.indexer.Do(ctx) + if err != nil { + return err + } + r.pendingRequests = 0 + + return nil +} + +func (r *ReqBulkClient) Flush() error { + r.mut.Lock() + defer r.mut.Unlock() + + return r._flush() +} diff --git a/server/enterprise/elasticsearch/elasticsearch/bulk_client_req_test.go b/server/enterprise/elasticsearch/elasticsearch/bulk_client_req_test.go new file mode 100644 index 00000000000..d2efe807121 --- /dev/null +++ b/server/enterprise/elasticsearch/elasticsearch/bulk_client_req_test.go @@ -0,0 +1,263 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.enterprise for license information. + +package elasticsearch + +import ( + "testing" + "time" + + "github.com/elastic/go-elasticsearch/v8/typedapi/types" + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/v8/channels/api4" + "github.com/mattermost/mattermost/server/v8/enterprise/elasticsearch/common" + "github.com/stretchr/testify/require" +) + +// setupBulkClient creates a test bulk client with common setup +func setupBulkClient(t *testing.T, flushNumReqs int, flushInterval time.Duration) (*ReqBulkClient, *api4.TestHelper) { + th := api4.SetupEnterprise(t) + + client := createTestClient(t, th.Context, th.App.Config(), th.App.FileBackend()) + bulkClient, err := NewReqBulkClient( + common.BulkSettings{ + FlushBytes: 0, + FlushInterval: flushInterval, + FlushNumReqs: flushNumReqs, + }, + client, + time.Duration(*th.App.Config().ElasticsearchSettings.RequestTimeoutSeconds)*time.Second, + th.Server.Platform().Log()) + require.NoError(t, err) + + return bulkClient, th +} + +// createTestPost creates a test post for indexing +func createTestPost(t *testing.T, message string) *common.ESPost { + post, err := common.ESPostFromPost(&model.Post{ + Id: model.NewId(), + Message: message, + }, "myteam") + require.NoError(t, err) + return post +} + +func TestBulkProcessor(t *testing.T) { + th := api4.SetupEnterprise(t) + defer th.TearDown() + + bulkClient, _ := setupBulkClient(t, *th.App.Config().ElasticsearchSettings.LiveIndexingBatchSize, 0) + + post := createTestPost(t, "hello world") + + err := bulkClient.IndexOp(types.IndexOperation{ + Index_: model.NewPointer("myindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + + require.Equal(t, 1, bulkClient.pendingRequests) + + err = bulkClient.Stop() + require.NoError(t, err) + + require.Equal(t, 0, bulkClient.pendingRequests) +} + +func TestIndexOp(t *testing.T) { + bulkClient, th := setupBulkClient(t, 10, 0) + defer th.TearDown() + + t.Run("single index operation", func(t *testing.T) { + post := createTestPost(t, "test message") + + err := bulkClient.IndexOp(types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + require.Equal(t, 1, bulkClient.pendingRequests) + }) + + t.Run("multiple index operations", func(t *testing.T) { + initialRequests := bulkClient.pendingRequests + + for range 5 { + post := createTestPost(t, "test message") + err := bulkClient.IndexOp(types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + } + + require.Equal(t, initialRequests+5, bulkClient.pendingRequests) + }) + + t.Run("auto flush on threshold", func(t *testing.T) { + // Create a new client with low flush threshold + bulkClient2, th2 := setupBulkClient(t, 2, 0) + defer th2.TearDown() + + post1 := createTestPost(t, "first message") + err := bulkClient2.IndexOp(types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post1.Id), + }, post1) + require.NoError(t, err) + require.Equal(t, 1, bulkClient2.pendingRequests) + + post2 := createTestPost(t, "second message") + err = bulkClient2.IndexOp(types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post2.Id), + }, post2) + require.NoError(t, err) + require.Equal(t, 2, bulkClient2.pendingRequests) + + // Third operation should trigger flush + post3 := createTestPost(t, "third message") + err = bulkClient2.IndexOp(types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post3.Id), + }, post3) + require.NoError(t, err) + require.Equal(t, 0, bulkClient2.pendingRequests) + }) +} + +func TestDeleteOp(t *testing.T) { + bulkClient, th := setupBulkClient(t, 10, 0) + defer th.TearDown() + + t.Run("single delete operation", func(t *testing.T) { + docId := model.NewId() + + err := bulkClient.DeleteOp(types.DeleteOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(docId), + }) + require.NoError(t, err) + require.Equal(t, 1, bulkClient.pendingRequests) + }) + + t.Run("multiple delete operations", func(t *testing.T) { + initialRequests := bulkClient.pendingRequests + + for range 3 { + docId := model.NewId() + err := bulkClient.DeleteOp(types.DeleteOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(docId), + }) + require.NoError(t, err) + } + + require.Equal(t, initialRequests+3, bulkClient.pendingRequests) + }) + + t.Run("auto flush on threshold", func(t *testing.T) { + // Create a new client with low flush threshold + bulkClient2, th2 := setupBulkClient(t, 2, 0) + defer th2.TearDown() + + // Add two delete operations + for range 2 { + docId := model.NewId() + err := bulkClient2.DeleteOp(types.DeleteOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(docId), + }) + require.NoError(t, err) + } + require.Equal(t, 2, bulkClient2.pendingRequests) + + // Third operation should trigger flush + docId := model.NewId() + err := bulkClient2.DeleteOp(types.DeleteOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(docId), + }) + require.NoError(t, err) + require.Equal(t, 0, bulkClient2.pendingRequests) + }) +} + +func TestFlush(t *testing.T) { + bulkClient, th := setupBulkClient(t, 10, 0) + defer th.TearDown() + + t.Run("flush with pending requests", func(t *testing.T) { + post := createTestPost(t, "test message") + + err := bulkClient.IndexOp(types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + require.Equal(t, 1, bulkClient.pendingRequests) + + err = bulkClient.Flush() + require.NoError(t, err) + require.Equal(t, 0, bulkClient.pendingRequests) + }) + + t.Run("flush with no pending requests", func(t *testing.T) { + require.Equal(t, 0, bulkClient.pendingRequests) + + err := bulkClient.Flush() + require.NoError(t, err) + require.Equal(t, 0, bulkClient.pendingRequests) + }) +} + +func TestStop(t *testing.T) { + t.Run("stop with pending requests", func(t *testing.T) { + bulkClient, th := setupBulkClient(t, 10, 0) + defer th.TearDown() + + post := createTestPost(t, "test message") + + err := bulkClient.IndexOp(types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + require.Equal(t, 1, bulkClient.pendingRequests) + + err = bulkClient.Stop() + require.NoError(t, err) + require.Equal(t, 0, bulkClient.pendingRequests) + }) + + t.Run("stop with no pending requests", func(t *testing.T) { + bulkClient, th := setupBulkClient(t, 10, 0) + defer th.TearDown() + + require.Equal(t, 0, bulkClient.pendingRequests) + + err := bulkClient.Stop() + require.NoError(t, err) + require.Equal(t, 0, bulkClient.pendingRequests) + }) + + t.Run("stop with periodic flusher", func(t *testing.T) { + bulkClient, th := setupBulkClient(t, 10, 100*time.Millisecond) + defer th.TearDown() + + post := createTestPost(t, "test message") + + err := bulkClient.IndexOp(types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + require.Equal(t, 1, bulkClient.pendingRequests) + + // Stop should flush pending requests and stop the periodic flusher + err = bulkClient.Stop() + require.NoError(t, err) + require.Equal(t, 0, bulkClient.pendingRequests) + }) +} diff --git a/server/enterprise/elasticsearch/elasticsearch/bulk_test.go b/server/enterprise/elasticsearch/elasticsearch/bulk_test.go index 0970dbddb29..663227de8b0 100644 --- a/server/enterprise/elasticsearch/elasticsearch/bulk_test.go +++ b/server/enterprise/elasticsearch/elasticsearch/bulk_test.go @@ -5,39 +5,75 @@ package elasticsearch import ( "testing" + "time" - "github.com/elastic/go-elasticsearch/v8/typedapi/types" - "github.com/mattermost/mattermost/server/public/model" "github.com/mattermost/mattermost/server/v8/channels/api4" "github.com/mattermost/mattermost/server/v8/enterprise/elasticsearch/common" "github.com/stretchr/testify/require" ) -func TestBulkProcessor(t *testing.T) { +func TestNewBulk(t *testing.T) { th := api4.SetupEnterprise(t) defer th.TearDown() - client := createTestClient(t, th.Context, th.App.Config(), th.App.FileBackend()) - bulk := NewBulk(th.App.Config().ElasticsearchSettings, - th.Server.Platform().Log(), - client) - post, err := common.ESPostFromPost(&model.Post{ - Id: model.NewId(), - Message: "hello world", - }, "myteam") - require.NoError(t, err) + t.Run("zeroed bulksettings", func(t *testing.T) { + _, err := NewBulk( + common.BulkSettings{}, + client, + time.Duration(*th.App.Config().ElasticsearchSettings.RequestTimeoutSeconds)*time.Second, + th.Server.Platform().Log(), + ) + + require.Error(t, err) + }) + + t.Run("incompatible bulkSettings", func(t *testing.T) { + _, err := NewBulk( + common.BulkSettings{ + FlushBytes: 100, + FlushInterval: 5 * time.Second, + FlushNumReqs: 10, + }, + client, + time.Duration(*th.App.Config().ElasticsearchSettings.RequestTimeoutSeconds)*time.Second, + th.Server.Platform().Log(), + ) + + require.Error(t, err) + }) - err = bulk.IndexOp(types.IndexOperation{ - Index_: model.NewPointer("myindex"), - Id_: model.NewPointer(post.Id), - }, post) - require.NoError(t, err) + t.Run("data-based bulk client", func(t *testing.T) { + client, err := NewBulk( + common.BulkSettings{ + FlushBytes: 100, + FlushInterval: 5 * time.Second, + FlushNumReqs: 0, + }, + client, + time.Duration(*th.App.Config().ElasticsearchSettings.RequestTimeoutSeconds)*time.Second, + th.Server.Platform().Log(), + ) + require.NoError(t, err) - require.Equal(t, 1, bulk.pendingRequests) + _, ok := client.(*DataBulkClient) + require.True(t, ok) + }) - err = bulk.Stop() - require.NoError(t, err) + t.Run("requests-based bulk client", func(t *testing.T) { + client, err := NewBulk( + common.BulkSettings{ + FlushBytes: 0, + FlushInterval: 5 * time.Second, + FlushNumReqs: 100, + }, + client, + time.Duration(*th.App.Config().ElasticsearchSettings.RequestTimeoutSeconds)*time.Second, + th.Server.Platform().Log(), + ) + require.NoError(t, err) - require.Equal(t, 0, bulk.pendingRequests) + _, ok := client.(*ReqBulkClient) + require.True(t, ok) + }) } diff --git a/server/enterprise/elasticsearch/elasticsearch/elasticsearch.go b/server/enterprise/elasticsearch/elasticsearch/elasticsearch.go index e4443808d70..6524d1b4015 100644 --- a/server/enterprise/elasticsearch/elasticsearch/elasticsearch.go +++ b/server/enterprise/elasticsearch/elasticsearch/elasticsearch.go @@ -43,8 +43,9 @@ type ElasticsearchInterfaceImpl struct { fullVersion string plugins []string - bulkProcessor *Bulk - Platform *platform.PlatformService + bulkProcessor BulkClient + syncBulkProcessor BulkClient + Platform *platform.PlatformService } func getJSONOrErrorStr(obj any) string { @@ -128,10 +129,39 @@ func (es *ElasticsearchInterfaceImpl) Start() *model.AppError { ctx := context.Background() - if *es.Platform.Config().ElasticsearchSettings.LiveIndexingBatchSize > 1 { - es.bulkProcessor = NewBulk(es.Platform.Config().ElasticsearchSettings, - es.Platform.Log(), - es.client) + esSettings := es.Platform.Config().ElasticsearchSettings + if *esSettings.LiveIndexingBatchSize > 1 { + es.bulkProcessor, err = NewBulk( + common.BulkSettings{ + FlushBytes: 0, + FlushInterval: common.BulkFlushInterval, + FlushNumReqs: *esSettings.LiveIndexingBatchSize, + }, + es.client, + time.Duration(*esSettings.RequestTimeoutSeconds)*time.Second, + es.Platform.Log()) + if err != nil { + return model.NewAppError("elasticsearch.start", + "ent.elasticsearch.create_processor.bulk_processor_create_failed", + nil, "", + http.StatusInternalServerError).Wrap(err) + } + } + + es.syncBulkProcessor, err = NewBulk( + common.BulkSettings{ + FlushBytes: common.BulkFlushBytes, + FlushInterval: 0, + FlushNumReqs: 0, + }, + es.client, + time.Duration(*esSettings.RequestTimeoutSeconds)*time.Second, + es.Platform.Log()) + if err != nil { + return model.NewAppError("elasticsearch.start", + "ent.elasticsearch.create_processor.sync_bulk_processor_create_failed", + nil, "", + http.StatusInternalServerError).Wrap(err) } // Set up posts index template. @@ -749,6 +779,49 @@ func (es *ElasticsearchInterfaceImpl) IndexChannel(rctx request.CTX, channel *mo return nil } +func (es *ElasticsearchInterfaceImpl) SyncBulkIndexChannels(rctx request.CTX, channels []*model.Channel, getUserIDsForChannel func(channel *model.Channel) ([]string, error), teamMemberIDs []string) *model.AppError { + if len(channels) == 0 { + return nil + } + + es.mutex.RLock() + defer es.mutex.RUnlock() + + if atomic.LoadInt32(&es.ready) == 0 { + return model.NewAppError("Elasticsearch.SyncBulkIndexChannels", "ent.elasticsearch.not_started.error", map[string]any{"Backend": model.ElasticsearchSettingsESBackend}, "", http.StatusInternalServerError) + } + + indexName := *es.Platform.Config().ElasticsearchSettings.IndexPrefix + common.IndexBaseChannels + metrics := es.Platform.Metrics() + + for _, channel := range channels { + userIDs, err := getUserIDsForChannel(channel) + if err != nil { + return model.NewAppError("Elasticsearch.SyncBulkIndexChannels", model.NoTranslation, nil, "", http.StatusInternalServerError).Wrap(err) + } + + searchChannel := common.ESChannelFromChannel(channel, userIDs, teamMemberIDs) + + err = es.syncBulkProcessor.IndexOp(types.IndexOperation{ + Index_: model.NewPointer(indexName), + Id_: model.NewPointer(searchChannel.Id), + }, searchChannel) + if err != nil { + return model.NewAppError("Elasticsearch.SyncBulkIndexChannels", model.NoTranslation, nil, "", http.StatusInternalServerError).Wrap(err) + } + + if metrics != nil { + metrics.IncrementChannelIndexCounter() + } + } + + if err := es.syncBulkProcessor.Flush(); err != nil { + return model.NewAppError("Elasticsearch.SyncBulkIndexChannels", model.NoTranslation, nil, "", http.StatusInternalServerError).Wrap(err) + } + + return nil +} + func (es *ElasticsearchInterfaceImpl) SearchChannels(teamId, userID string, term string, isGuest, includeDeleted bool) ([]string, *model.AppError) { es.mutex.RLock() defer es.mutex.RUnlock() diff --git a/server/enterprise/elasticsearch/elasticsearch/elasticsearch_test.go b/server/enterprise/elasticsearch/elasticsearch/elasticsearch_test.go index 62edffeea7c..21baa62a1b0 100644 --- a/server/enterprise/elasticsearch/elasticsearch/elasticsearch_test.go +++ b/server/enterprise/elasticsearch/elasticsearch/elasticsearch_test.go @@ -98,3 +98,76 @@ func (s *ElasticsearchInterfaceTestSuite) SetupTest() { s.Nil(s.CommonTestSuite.ESImpl.PurgeIndexes(s.th.Context)) } + +func (s *ElasticsearchInterfaceTestSuite) TestSyncBulkIndexChannels() { + s.Run("Should index multiple channels successfully", func() { + // Create test channels + channel1 := &model.Channel{ + TeamId: s.th.BasicTeam.Id, + Type: model.ChannelTypeOpen, + Name: "test-channel-1", + DisplayName: "Test Channel 1", + } + channel1.PreSave() + + channel2 := &model.Channel{ + TeamId: s.th.BasicTeam.Id, + Type: model.ChannelTypePrivate, + Name: "test-channel-2", + DisplayName: "Test Channel 2", + } + channel2.PreSave() + + channels := []*model.Channel{channel1, channel2} + + // Mock getUserIDsForChannel function + getUserIDsForChannel := func(channel *model.Channel) ([]string, error) { + return []string{s.th.BasicUser.Id, s.th.BasicUser2.Id}, nil + } + + teamMemberIDs := []string{s.th.BasicUser.Id, s.th.BasicUser2.Id} + + // Test the bulk indexing + appErr := s.CommonTestSuite.ESImpl.SyncBulkIndexChannels(s.th.Context, channels, getUserIDsForChannel, teamMemberIDs) + s.Require().Nil(appErr) + + // Refresh the index to ensure data is searchable + s.Require().NoError(s.CommonTestSuite.RefreshIndexFn()) + + // Verify both channels are indexed + found, _, err := s.CommonTestSuite.GetDocumentFn("channels", channel1.Id) + s.Require().NoError(err) + s.Require().True(found) + + found, _, err = s.CommonTestSuite.GetDocumentFn("channels", channel2.Id) + s.Require().NoError(err) + s.Require().True(found) + }) + + s.Run("Should handle empty channels list", func() { + getUserIDsForChannel := func(channel *model.Channel) ([]string, error) { + return []string{}, nil + } + + appErr := s.CommonTestSuite.ESImpl.SyncBulkIndexChannels(s.th.Context, []*model.Channel{}, getUserIDsForChannel, []string{}) + s.Require().Nil(appErr) + }) + + s.Run("Should handle getUserIDsForChannel error", func() { + channel := &model.Channel{ + TeamId: s.th.BasicTeam.Id, + Type: model.ChannelTypeOpen, + Name: "test-channel-error", + DisplayName: "Test Channel Error", + } + channel.PreSave() + + getUserIDsForChannel := func(channel *model.Channel) ([]string, error) { + return nil, model.NewAppError("TestError", "test.error", nil, "", 500) + } + + appErr := s.CommonTestSuite.ESImpl.SyncBulkIndexChannels(s.th.Context, []*model.Channel{channel}, getUserIDsForChannel, []string{}) + s.Require().NotNil(appErr) + s.Require().Contains(appErr.Error(), "test.error") + }) +} diff --git a/server/enterprise/elasticsearch/elasticsearch/sync_bulk.go b/server/enterprise/elasticsearch/elasticsearch/sync_bulk.go new file mode 100644 index 00000000000..298d97cfddd --- /dev/null +++ b/server/enterprise/elasticsearch/elasticsearch/sync_bulk.go @@ -0,0 +1,94 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.enterprise for license information. + +package elasticsearch + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "time" + + elastic "github.com/elastic/go-elasticsearch/v8" + "github.com/elastic/go-elasticsearch/v8/esutil" + "github.com/elastic/go-elasticsearch/v8/typedapi/types" + "github.com/mattermost/mattermost/server/v8/enterprise/elasticsearch/common" +) + +type SyncBulk struct { + client *elastic.TypedClient + bulkIndexer esutil.BulkIndexer +} + +func NewSyncBulk(client *elastic.TypedClient) (*SyncBulk, error) { + bulkIndexer, err := newBulkIndexer(client) + if err != nil { + return nil, err + } + + return &SyncBulk{client, bulkIndexer}, nil +} + +func newBulkIndexer(client *elastic.TypedClient) (esutil.BulkIndexer, error) { + return esutil.NewBulkIndexer(esutil.BulkIndexerConfig{ + Client: client, // The Elasticsearch client + FlushBytes: common.BulkFlushBytes, // The flush threshold in bytes + FlushInterval: 30 * time.Second, // The periodic flush interval + }) +} + +// IndexOp is a helper function to add an IndexOperation to the current bulk request. +// doc argument can be a []byte, json.RawMessage or a struct. +func (r *SyncBulk) IndexOp(op types.IndexOperation, doc any) error { + var body io.ReadSeeker + switch v := doc.(type) { + case []byte: + body = bytes.NewReader(v) + case json.RawMessage: + body = bytes.NewReader(v) + default: + data, err := json.Marshal(doc) + if err != nil { + return err + } + body = bytes.NewReader(data) + } + + return r.bulkIndexer.Add(context.Background(), esutil.BulkIndexerItem{ + Index: *op.Index_, + Action: "index", + DocumentID: *op.Id_, + Body: body, + }) +} + +// DeleteOp is a helper function to add a DeleteOperation to the current bulk request. +func (r *SyncBulk) DeleteOp(op types.DeleteOperation) error { + return r.bulkIndexer.Add(context.Background(), esutil.BulkIndexerItem{ + Index: *op.Index_, + Action: "delete", + DocumentID: *op.Id_, + }) +} + +func (r *SyncBulk) Stop() error { + return r.bulkIndexer.Close(context.Background()) +} + +func (r *SyncBulk) Flush() error { + // Flush by closing the indexer: there is no manual Flush method + if err := r.bulkIndexer.Close(context.Background()); err != nil { + return err + } + + // Restart the indexer so that we can keep using it + bulkIndexer, err := newBulkIndexer(r.client) + if err != nil { + return fmt.Errorf("unable to restart bulk indexer: %w", err) + } + + r.bulkIndexer = bulkIndexer + return nil +} diff --git a/server/enterprise/elasticsearch/opensearch/bulk.go b/server/enterprise/elasticsearch/opensearch/bulk.go index ca40793b95b..dbb26d80029 100644 --- a/server/enterprise/elasticsearch/opensearch/bulk.go +++ b/server/enterprise/elasticsearch/opensearch/bulk.go @@ -11,7 +11,6 @@ import ( "time" "github.com/elastic/go-elasticsearch/v8/typedapi/types" - "github.com/mattermost/mattermost/server/public/model" "github.com/mattermost/mattermost/server/public/shared/mlog" "github.com/mattermost/mattermost/server/v8/enterprise/elasticsearch/common" "github.com/opensearch-project/opensearch-go/v4/opensearchapi" @@ -21,9 +20,10 @@ type Bulk struct { mut sync.Mutex buf *bytes.Buffer - logger mlog.LoggerIFace - client *opensearchapi.Client - settings model.ElasticsearchSettings + client *opensearchapi.Client + bulkSettings common.BulkSettings + reqTimeout time.Duration + logger mlog.LoggerIFace quitFlusher chan struct{} quitFlusherWg sync.WaitGroup @@ -31,19 +31,25 @@ type Bulk struct { pendingRequests int } -func NewBulk(settings model.ElasticsearchSettings, +func NewBulk(bulkSettings common.BulkSettings, + client *opensearchapi.Client, + reqTimeout time.Duration, logger mlog.LoggerIFace, - client *opensearchapi.Client) *Bulk { +) *Bulk { b := &Bulk{ - settings: settings, - logger: logger, - client: client, - quitFlusher: make(chan struct{}), - buf: &bytes.Buffer{}, + bulkSettings: bulkSettings, + reqTimeout: reqTimeout, + logger: logger, + client: client, + quitFlusher: make(chan struct{}), + buf: &bytes.Buffer{}, } - b.quitFlusherWg.Add(1) - go b.periodicFlusher() + // Start the timer only if a flush interval was specified + if bulkSettings.FlushInterval > 0 { + b.quitFlusherWg.Add(1) + go b.periodicFlusher() + } return b } @@ -101,10 +107,20 @@ func (r *Bulk) DeleteOp(op *types.DeleteOperation) error { // flushIfNecessary flushes the pending buffer if needed. // It MUST be called with an already acquired mutex. func (r *Bulk) flushIfNecessary() error { + // Check data threshold, only if specified + if r.bulkSettings.FlushBytes > 0 { + if r.buf.Len() >= r.bulkSettings.FlushBytes { + return r._flush() + } + } + r.pendingRequests++ - if r.pendingRequests > *r.settings.LiveIndexingBatchSize { - return r._flush() + // Check number of requests threshold, only if specified + if r.bulkSettings.FlushNumReqs > 0 { + if r.pendingRequests > r.bulkSettings.FlushNumReqs { + return r._flush() + } } return nil @@ -119,8 +135,11 @@ func (r *Bulk) Stop() error { return r._flush() } - close(r.quitFlusher) - r.quitFlusherWg.Wait() + // Cleanup the timer if the flush interval was specified + if r.bulkSettings.FlushInterval > 0 { + close(r.quitFlusher) + r.quitFlusherWg.Wait() + } return nil } @@ -130,7 +149,7 @@ func (r *Bulk) periodicFlusher() { for { select { - case <-time.After(common.BulkFlushInterval): + case <-time.After(r.bulkSettings.FlushInterval): r.mut.Lock() if r.pendingRequests > 0 { if err := r._flush(); err != nil { @@ -146,7 +165,11 @@ func (r *Bulk) periodicFlusher() { // _flush MUST be called with an acquired lock. func (r *Bulk) _flush() error { - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*r.settings.RequestTimeoutSeconds)*time.Second) + if r.pendingRequests == 0 { + return nil + } + + ctx, cancel := context.WithTimeout(context.Background(), r.reqTimeout) defer cancel() _, err := r.client.Bulk(ctx, opensearchapi.BulkReq{ @@ -160,3 +183,9 @@ func (r *Bulk) _flush() error { return nil } + +func (r *Bulk) Flush() error { + r.mut.Lock() + defer r.mut.Unlock() + return r._flush() +} diff --git a/server/enterprise/elasticsearch/opensearch/bulk_test.go b/server/enterprise/elasticsearch/opensearch/bulk_test.go index 57466e9beb1..a7590444ae0 100644 --- a/server/enterprise/elasticsearch/opensearch/bulk_test.go +++ b/server/enterprise/elasticsearch/opensearch/bulk_test.go @@ -4,8 +4,10 @@ package opensearch import ( + "fmt" "os" "testing" + "time" "github.com/elastic/go-elasticsearch/v8/typedapi/types" "github.com/mattermost/mattermost/server/public/model" @@ -14,22 +16,15 @@ import ( "github.com/stretchr/testify/require" ) -func TestBulkProcessor(t *testing.T) { +// setupBulkClient creates a test bulk client with common setup +func setupBulkClient(t *testing.T, flushBytes int, flushNumReqs int, flushInterval time.Duration) (*Bulk, *api4.TestHelper) { th := api4.SetupEnterprise(t) - defer th.TearDown() if os.Getenv("IS_CI") == "true" { os.Setenv("MM_ELASTICSEARCHSETTINGS_CONNECTIONURL", "http://opensearch:9201") os.Setenv("MM_ELASTICSEARCHSETTINGS_BACKEND", "opensearch") } - defer func() { - if os.Getenv("IS_CI") == "true" { - os.Setenv("MM_ELASTICSEARCHSETTINGS_CONNECTIONURL", "http://elasticsearch:9201") - os.Unsetenv("MM_ELASTICSEARCHSETTINGS_BACKEND") - } - }() - th.App.UpdateConfig(func(cfg *model.Config) { if os.Getenv("IS_CI") == "true" { *cfg.ElasticsearchSettings.ConnectionURL = "http://opensearch:9201" @@ -43,17 +38,42 @@ func TestBulkProcessor(t *testing.T) { }) client := createTestClient(t, th.Context, th.App.Config(), th.App.FileBackend()) - bulk := NewBulk(th.App.Config().ElasticsearchSettings, - th.Server.Platform().Log(), - client) + bulk := NewBulk( + common.BulkSettings{ + FlushBytes: flushBytes, + FlushInterval: flushInterval, + FlushNumReqs: flushNumReqs, + }, + client, + time.Duration(*th.App.Config().ElasticsearchSettings.RequestTimeoutSeconds)*time.Second, + th.Server.Platform().Log()) + return bulk, th +} + +// createTestPost creates a test post for indexing +func createTestPost(t *testing.T, message string) *common.ESPost { post, err := common.ESPostFromPost(&model.Post{ Id: model.NewId(), - Message: "hello world", + Message: message, }, "myteam") require.NoError(t, err) + return post +} + +func TestBulkProcessor(t *testing.T) { + bulk, th := setupBulkClient(t, 0, 10, 0) + defer th.TearDown() + defer func() { + if os.Getenv("IS_CI") == "true" { + os.Setenv("MM_ELASTICSEARCHSETTINGS_CONNECTIONURL", "http://elasticsearch:9201") + os.Unsetenv("MM_ELASTICSEARCHSETTINGS_BACKEND") + } + }() - err = bulk.IndexOp(&types.IndexOperation{ + post := createTestPost(t, "hello world") + + err := bulk.IndexOp(&types.IndexOperation{ Index_: model.NewPointer("myindex"), Id_: model.NewPointer(post.Id), }, post) @@ -66,3 +86,397 @@ func TestBulkProcessor(t *testing.T) { require.Equal(t, 0, bulk.pendingRequests) } + +func TestNewBulk(t *testing.T) { + bulk, th := setupBulkClient(t, 1024, 10, 0) + defer th.TearDown() + defer func() { + if os.Getenv("IS_CI") == "true" { + os.Setenv("MM_ELASTICSEARCHSETTINGS_CONNECTIONURL", "http://elasticsearch:9201") + os.Unsetenv("MM_ELASTICSEARCHSETTINGS_BACKEND") + } + }() + + t.Run("creates bulk client without periodic flusher", func(t *testing.T) { + require.NotNil(t, bulk) + require.NotNil(t, bulk.client) + require.NotNil(t, bulk.logger) + require.NotNil(t, bulk.buf) + require.Equal(t, 0, bulk.pendingRequests) + require.Equal(t, 1024, bulk.bulkSettings.FlushBytes) + require.Equal(t, 10, bulk.bulkSettings.FlushNumReqs) + }) + + t.Run("creates bulk client with periodic flusher", func(t *testing.T) { + bulkWithTimer, th2 := setupBulkClient(t, 1024, 10, 100*time.Millisecond) + defer th2.TearDown() + defer func() { + if os.Getenv("IS_CI") == "true" { + os.Setenv("MM_ELASTICSEARCHSETTINGS_CONNECTIONURL", "http://elasticsearch:9201") + os.Unsetenv("MM_ELASTICSEARCHSETTINGS_BACKEND") + } + }() + + require.NotNil(t, bulkWithTimer) + require.Equal(t, 100*time.Millisecond, bulkWithTimer.bulkSettings.FlushInterval) + + err := bulkWithTimer.Stop() + require.NoError(t, err) + }) + + err := bulk.Stop() + require.NoError(t, err) +} + +func TestIndexOp(t *testing.T) { + bulk, th := setupBulkClient(t, 0, 10, 0) + defer th.TearDown() + defer func() { + if os.Getenv("IS_CI") == "true" { + os.Setenv("MM_ELASTICSEARCHSETTINGS_CONNECTIONURL", "http://elasticsearch:9201") + os.Unsetenv("MM_ELASTICSEARCHSETTINGS_BACKEND") + } + }() + + t.Run("single index operation with struct", func(t *testing.T) { + post := createTestPost(t, "test message") + + err := bulk.IndexOp(&types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + require.Equal(t, 1, bulk.pendingRequests) + + // Verify buffer has content + require.Greater(t, bulk.buf.Len(), 0) + }) + + t.Run("index operation with []byte", func(t *testing.T) { + initialRequests := bulk.pendingRequests + docId := model.NewId() + data := []byte(`{"message": "test byte slice"}`) + + err := bulk.IndexOp(&types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(docId), + }, data) + require.NoError(t, err) + require.Equal(t, initialRequests+1, bulk.pendingRequests) + }) + + t.Run("index operation with json.RawMessage", func(t *testing.T) { + initialRequests := bulk.pendingRequests + docId := model.NewId() + jsonData := []byte(`{"message": "test raw message"}`) + + err := bulk.IndexOp(&types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(docId), + }, jsonData) + require.NoError(t, err) + require.Equal(t, initialRequests+1, bulk.pendingRequests) + }) + + t.Run("multiple index operations", func(t *testing.T) { + initialRequests := bulk.pendingRequests + + for i := range 5 { + post := createTestPost(t, fmt.Sprintf("test message %d", i)) + err := bulk.IndexOp(&types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + } + + require.Equal(t, initialRequests+5, bulk.pendingRequests) + }) + + t.Run("auto flush on request threshold", func(t *testing.T) { + // Create a new client with low flush threshold + bulk2, th2 := setupBulkClient(t, 0, 2, 0) + defer th2.TearDown() + defer func() { + if os.Getenv("IS_CI") == "true" { + os.Setenv("MM_ELASTICSEARCHSETTINGS_CONNECTIONURL", "http://elasticsearch:9201") + os.Unsetenv("MM_ELASTICSEARCHSETTINGS_BACKEND") + } + }() + + post1 := createTestPost(t, "first message") + err := bulk2.IndexOp(&types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post1.Id), + }, post1) + require.NoError(t, err) + require.Equal(t, 1, bulk2.pendingRequests) + + post2 := createTestPost(t, "second message") + err = bulk2.IndexOp(&types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post2.Id), + }, post2) + require.NoError(t, err) + require.Equal(t, 2, bulk2.pendingRequests) + + // Third operation should trigger flush + post3 := createTestPost(t, "third message") + err = bulk2.IndexOp(&types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post3.Id), + }, post3) + require.NoError(t, err) + require.Equal(t, 0, bulk2.pendingRequests) + + err = bulk2.Stop() + require.NoError(t, err) + }) + + err := bulk.Stop() + require.NoError(t, err) +} + +func TestDeleteOp(t *testing.T) { + bulk, th := setupBulkClient(t, 0, 10, 0) + defer th.TearDown() + defer func() { + if os.Getenv("IS_CI") == "true" { + os.Setenv("MM_ELASTICSEARCHSETTINGS_CONNECTIONURL", "http://elasticsearch:9201") + os.Unsetenv("MM_ELASTICSEARCHSETTINGS_BACKEND") + } + }() + + t.Run("single delete operation", func(t *testing.T) { + docId := model.NewId() + + err := bulk.DeleteOp(&types.DeleteOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(docId), + }) + require.NoError(t, err) + require.Equal(t, 1, bulk.pendingRequests) + + // Verify buffer has content + require.Greater(t, bulk.buf.Len(), 0) + }) + + t.Run("multiple delete operations", func(t *testing.T) { + initialRequests := bulk.pendingRequests + + for range 3 { + docId := model.NewId() + err := bulk.DeleteOp(&types.DeleteOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(docId), + }) + require.NoError(t, err) + } + + require.Equal(t, initialRequests+3, bulk.pendingRequests) + }) + + t.Run("auto flush on request threshold", func(t *testing.T) { + // Create a new client with low flush threshold + bulk2, th2 := setupBulkClient(t, 0, 2, 0) + defer th2.TearDown() + defer func() { + if os.Getenv("IS_CI") == "true" { + os.Setenv("MM_ELASTICSEARCHSETTINGS_CONNECTIONURL", "http://elasticsearch:9201") + os.Unsetenv("MM_ELASTICSEARCHSETTINGS_BACKEND") + } + }() + + // Add two delete operations + for range 2 { + docId := model.NewId() + err := bulk2.DeleteOp(&types.DeleteOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(docId), + }) + require.NoError(t, err) + } + require.Equal(t, 2, bulk2.pendingRequests) + + // Third operation should trigger flush + docId := model.NewId() + err := bulk2.DeleteOp(&types.DeleteOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(docId), + }) + require.NoError(t, err) + require.Equal(t, 0, bulk2.pendingRequests) + + err = bulk2.Stop() + require.NoError(t, err) + }) + + err := bulk.Stop() + require.NoError(t, err) +} + +func TestFlush(t *testing.T) { + bulk, th := setupBulkClient(t, 0, 10, 0) + defer th.TearDown() + defer func() { + if os.Getenv("IS_CI") == "true" { + os.Setenv("MM_ELASTICSEARCHSETTINGS_CONNECTIONURL", "http://elasticsearch:9201") + os.Unsetenv("MM_ELASTICSEARCHSETTINGS_BACKEND") + } + }() + + t.Run("flush with pending operations", func(t *testing.T) { + post := createTestPost(t, "test message") + + err := bulk.IndexOp(&types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + require.Equal(t, 1, bulk.pendingRequests) + + err = bulk.Flush() + require.NoError(t, err) + require.Equal(t, 0, bulk.pendingRequests) + + // Verify buffer is empty after flush + require.Equal(t, 0, bulk.buf.Len()) + }) + + t.Run("flush with no pending operations", func(t *testing.T) { + require.Equal(t, 0, bulk.pendingRequests) + + err := bulk.Flush() + require.NoError(t, err) + require.Equal(t, 0, bulk.pendingRequests) + }) + + err := bulk.Stop() + require.NoError(t, err) +} + +func TestStop(t *testing.T) { + t.Run("stop with pending operations", func(t *testing.T) { + bulk, th := setupBulkClient(t, 0, 10, 0) + defer th.TearDown() + defer func() { + if os.Getenv("IS_CI") == "true" { + os.Setenv("MM_ELASTICSEARCHSETTINGS_CONNECTIONURL", "http://elasticsearch:9201") + os.Unsetenv("MM_ELASTICSEARCHSETTINGS_BACKEND") + } + }() + + post := createTestPost(t, "test message") + + err := bulk.IndexOp(&types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + require.Equal(t, 1, bulk.pendingRequests) + + err = bulk.Stop() + require.NoError(t, err) + require.Equal(t, 0, bulk.pendingRequests) + }) + + t.Run("stop with no pending operations", func(t *testing.T) { + bulk, th := setupBulkClient(t, 0, 10, 0) + defer th.TearDown() + defer func() { + if os.Getenv("IS_CI") == "true" { + os.Setenv("MM_ELASTICSEARCHSETTINGS_CONNECTIONURL", "http://elasticsearch:9201") + os.Unsetenv("MM_ELASTICSEARCHSETTINGS_BACKEND") + } + }() + + require.Equal(t, 0, bulk.pendingRequests) + + err := bulk.Stop() + require.NoError(t, err) + require.Equal(t, 0, bulk.pendingRequests) + }) + + t.Run("stop with periodic flusher", func(t *testing.T) { + bulk, th := setupBulkClient(t, 0, 10, 100*time.Millisecond) + defer th.TearDown() + defer func() { + if os.Getenv("IS_CI") == "true" { + os.Setenv("MM_ELASTICSEARCHSETTINGS_CONNECTIONURL", "http://elasticsearch:9201") + os.Unsetenv("MM_ELASTICSEARCHSETTINGS_BACKEND") + } + }() + + post := createTestPost(t, "test message") + + err := bulk.IndexOp(&types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + require.Equal(t, 1, bulk.pendingRequests) + + // Stop should flush pending operations and stop the periodic flusher + err = bulk.Stop() + require.NoError(t, err) + require.Equal(t, 0, bulk.pendingRequests) + }) +} + +func TestFlushThresholds(t *testing.T) { + t.Run("flush on bytes threshold", func(t *testing.T) { + // Create a client with very small byte threshold + bulk, th := setupBulkClient(t, 100, 0, 0) // 100 bytes threshold + defer th.TearDown() + defer func() { + if os.Getenv("IS_CI") == "true" { + os.Setenv("MM_ELASTICSEARCHSETTINGS_CONNECTIONURL", "http://elasticsearch:9201") + os.Unsetenv("MM_ELASTICSEARCHSETTINGS_BACKEND") + } + }() + + // Add operations that should exceed the byte threshold + for range 5 { + post := createTestPost(t, "This is a long message that should help us exceed the byte threshold for testing") + err := bulk.IndexOp(&types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + } + + // Should have been flushed due to byte threshold + require.Equal(t, 0, bulk.pendingRequests) + + err := bulk.Stop() + require.NoError(t, err) + }) + + t.Run("no flush when thresholds not met", func(t *testing.T) { + bulk, th := setupBulkClient(t, 100000, 10, 0) // High thresholds + defer th.TearDown() + defer func() { + if os.Getenv("IS_CI") == "true" { + os.Setenv("MM_ELASTICSEARCHSETTINGS_CONNECTIONURL", "http://elasticsearch:9201") + os.Unsetenv("MM_ELASTICSEARCHSETTINGS_BACKEND") + } + }() + + // Add a few operations that shouldn't trigger flush + for range 3 { + post := createTestPost(t, "short") + err := bulk.IndexOp(&types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + fmt.Println("PENDING REQS:", bulk.pendingRequests) + } + + // Should not have been flushed + require.Equal(t, 3, bulk.pendingRequests) + + err := bulk.Stop() + require.NoError(t, err) + }) +} diff --git a/server/enterprise/elasticsearch/opensearch/opensearch.go b/server/enterprise/elasticsearch/opensearch/opensearch.go index 69bb30b8758..1a542fa450d 100644 --- a/server/enterprise/elasticsearch/opensearch/opensearch.go +++ b/server/enterprise/elasticsearch/opensearch/opensearch.go @@ -45,8 +45,9 @@ type OpensearchInterfaceImpl struct { fullVersion string plugins []string - bulkProcessor *Bulk - Platform *platform.PlatformService + bulkProcessor *Bulk + syncBulkProcessor *Bulk + Platform *platform.PlatformService } func getJSONOrErrorStr(obj any) string { @@ -130,11 +131,27 @@ func (os *OpensearchInterfaceImpl) Start() *model.AppError { ctx := context.Background() - if *os.Platform.Config().ElasticsearchSettings.LiveIndexingBatchSize > 1 { - os.bulkProcessor = NewBulk(os.Platform.Config().ElasticsearchSettings, - os.Platform.Log(), - os.client) - } + esSettings := os.Platform.Config().ElasticsearchSettings + if *esSettings.LiveIndexingBatchSize > 1 { + os.bulkProcessor = NewBulk( + common.BulkSettings{ + FlushBytes: 0, + FlushInterval: common.BulkFlushInterval, + FlushNumReqs: *esSettings.LiveIndexingBatchSize, + }, + os.client, + time.Duration(*esSettings.RequestTimeoutSeconds)*time.Second, + os.Platform.Log()) + } + os.syncBulkProcessor = NewBulk( + common.BulkSettings{ + FlushBytes: common.BulkFlushBytes, + FlushInterval: 0, + FlushNumReqs: 0, + }, + os.client, + time.Duration(*esSettings.RequestTimeoutSeconds)*time.Second, + os.Platform.Log()) // Set up posts index template. templateBuf, err := json.Marshal(common.GetPostTemplate(os.Platform.Config())) @@ -831,6 +848,49 @@ func (os *OpensearchInterfaceImpl) IndexChannel(rctx request.CTX, channel *model return nil } +func (os *OpensearchInterfaceImpl) SyncBulkIndexChannels(rctx request.CTX, channels []*model.Channel, getUserIDsForChannel func(channel *model.Channel) ([]string, error), teamMemberIDs []string) *model.AppError { + if len(channels) == 0 { + return nil + } + + os.mutex.RLock() + defer os.mutex.RUnlock() + + if atomic.LoadInt32(&os.ready) == 0 { + return model.NewAppError("Opensearch.SyncBulkIndexChannels", "ent.elasticsearch.not_started.error", map[string]any{"Backend": model.ElasticsearchSettingsOSBackend}, "", http.StatusInternalServerError) + } + + indexName := *os.Platform.Config().ElasticsearchSettings.IndexPrefix + common.IndexBaseChannels + metrics := os.Platform.Metrics() + + for _, channel := range channels { + userIDs, err := getUserIDsForChannel(channel) + if err != nil { + return model.NewAppError("Opensearch.SyncBulkIndexChannels", model.NoTranslation, nil, "", http.StatusInternalServerError).Wrap(err) + } + + searchChannel := common.ESChannelFromChannel(channel, userIDs, teamMemberIDs) + + err = os.syncBulkProcessor.IndexOp(&types.IndexOperation{ + Index_: model.NewPointer(indexName), + Id_: model.NewPointer(searchChannel.Id), + }, searchChannel) + if err != nil { + return model.NewAppError("Opensearch.SyncBulkIndexChannels", model.NoTranslation, nil, "", http.StatusInternalServerError).Wrap(err) + } + + if metrics != nil { + metrics.IncrementChannelIndexCounter() + } + } + + if err := os.syncBulkProcessor.Flush(); err != nil { + return model.NewAppError("Opensearch.SyncBulkIndexChannels", model.NoTranslation, nil, "", http.StatusInternalServerError).Wrap(err) + } + + return nil +} + func (os *OpensearchInterfaceImpl) SearchChannels(teamId, userID string, term string, isGuest, includeDeleted bool) ([]string, *model.AppError) { os.mutex.RLock() defer os.mutex.RUnlock() diff --git a/server/enterprise/elasticsearch/opensearch/opensearch_test.go b/server/enterprise/elasticsearch/opensearch/opensearch_test.go index 5cfada6612b..60b2f82710c 100644 --- a/server/enterprise/elasticsearch/opensearch/opensearch_test.go +++ b/server/enterprise/elasticsearch/opensearch/opensearch_test.go @@ -124,3 +124,76 @@ func (s *OpensearchInterfaceTestSuite) SetupTest() { s.Nil(s.CommonTestSuite.ESImpl.PurgeIndexes(s.th.Context)) } + +func (s *OpensearchInterfaceTestSuite) TestSyncBulkIndexChannels() { + s.Run("Should index multiple channels successfully", func() { + // Create test channels + channel1 := &model.Channel{ + TeamId: s.th.BasicTeam.Id, + Type: model.ChannelTypeOpen, + Name: "test-channel-1", + DisplayName: "Test Channel 1", + } + channel1.PreSave() + + channel2 := &model.Channel{ + TeamId: s.th.BasicTeam.Id, + Type: model.ChannelTypePrivate, + Name: "test-channel-2", + DisplayName: "Test Channel 2", + } + channel2.PreSave() + + channels := []*model.Channel{channel1, channel2} + + // Mock getUserIDsForChannel function + getUserIDsForChannel := func(channel *model.Channel) ([]string, error) { + return []string{s.th.BasicUser.Id, s.th.BasicUser2.Id}, nil + } + + teamMemberIDs := []string{s.th.BasicUser.Id, s.th.BasicUser2.Id} + + // Test the bulk indexing + appErr := s.CommonTestSuite.ESImpl.SyncBulkIndexChannels(s.th.Context, channels, getUserIDsForChannel, teamMemberIDs) + s.Require().Nil(appErr) + + // Refresh the index to ensure data is searchable + s.Require().NoError(s.CommonTestSuite.RefreshIndexFn()) + + // Verify both channels are indexed + found, _, err := s.CommonTestSuite.GetDocumentFn("channels", channel1.Id) + s.Require().NoError(err) + s.Require().True(found) + + found, _, err = s.CommonTestSuite.GetDocumentFn("channels", channel2.Id) + s.Require().NoError(err) + s.Require().True(found) + }) + + s.Run("Should handle empty channels list", func() { + getUserIDsForChannel := func(channel *model.Channel) ([]string, error) { + return []string{}, nil + } + + appErr := s.CommonTestSuite.ESImpl.SyncBulkIndexChannels(s.th.Context, []*model.Channel{}, getUserIDsForChannel, []string{}) + s.Require().Nil(appErr) + }) + + s.Run("Should handle getUserIDsForChannel error", func() { + channel := &model.Channel{ + TeamId: s.th.BasicTeam.Id, + Type: model.ChannelTypeOpen, + Name: "test-channel-error", + DisplayName: "Test Channel Error", + } + channel.PreSave() + + getUserIDsForChannel := func(channel *model.Channel) ([]string, error) { + return nil, model.NewAppError("TestError", "test.error", nil, "", 500) + } + + appErr := s.CommonTestSuite.ESImpl.SyncBulkIndexChannels(s.th.Context, []*model.Channel{channel}, getUserIDsForChannel, []string{}) + s.Require().NotNil(appErr) + s.Require().Contains(appErr.Error(), "test.error") + }) +} diff --git a/server/i18n/en.json b/server/i18n/en.json index b79de1db8ce..8007ab01782 100644 --- a/server/i18n/en.json +++ b/server/i18n/en.json @@ -7912,6 +7912,14 @@ "id": "ent.elasticsearch.create_client.connect_failed", "translation": "Setting up {{.Backend}} Client Failed" }, + { + "id": "ent.elasticsearch.create_processor.bulk_processor_create_failed", + "translation": "Failed to create Elasticsearch bulk processor" + }, + { + "id": "ent.elasticsearch.create_processor.sync_bulk_processor_create_failed", + "translation": "Failed to create Elasticsearch sync bulk processor" + }, { "id": "ent.elasticsearch.create_template_channels_if_not_exists.template_create_failed", "translation": "Failed to create {{.Backend}} template for channels" diff --git a/server/platform/services/searchengine/interface.go b/server/platform/services/searchengine/interface.go index 69740171d55..40f24196efb 100644 --- a/server/platform/services/searchengine/interface.go +++ b/server/platform/services/searchengine/interface.go @@ -33,6 +33,7 @@ type SearchEngineInterface interface { // IndexChannel indexes a given channel. The userIDs are only populated // for private channels. IndexChannel(rctx request.CTX, channel *model.Channel, userIDs, teamMemberIDs []string) *model.AppError + SyncBulkIndexChannels(rctx request.CTX, channels []*model.Channel, getUserIDsForChannel func(channel *model.Channel) ([]string, error), teamMemberIDs []string) *model.AppError SearchChannels(teamId, userID, term string, isGuest, includeDeleted bool) ([]string, *model.AppError) DeleteChannel(channel *model.Channel) *model.AppError IndexUser(rctx request.CTX, user *model.User, teamsIds, channelsIds []string) *model.AppError diff --git a/server/platform/services/searchengine/mocks/SearchEngineInterface.go b/server/platform/services/searchengine/mocks/SearchEngineInterface.go index 7aa7ab9f8ec..1c49890a27a 100644 --- a/server/platform/services/searchengine/mocks/SearchEngineInterface.go +++ b/server/platform/services/searchengine/mocks/SearchEngineInterface.go @@ -757,6 +757,26 @@ func (_m *SearchEngineInterface) Stop() *model.AppError { return r0 } +// SyncBulkIndexChannels provides a mock function with given fields: rctx, channels, getUserIDsForChannel, teamMemberIDs +func (_m *SearchEngineInterface) SyncBulkIndexChannels(rctx request.CTX, channels []*model.Channel, getUserIDsForChannel func(*model.Channel) ([]string, error), teamMemberIDs []string) *model.AppError { + ret := _m.Called(rctx, channels, getUserIDsForChannel, teamMemberIDs) + + if len(ret) == 0 { + panic("no return value specified for SyncBulkIndexChannels") + } + + var r0 *model.AppError + if rf, ok := ret.Get(0).(func(request.CTX, []*model.Channel, func(*model.Channel) ([]string, error), []string) *model.AppError); ok { + r0 = rf(rctx, channels, getUserIDsForChannel, teamMemberIDs) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.AppError) + } + } + + return r0 +} + // TestConfig provides a mock function with given fields: rctx, cfg func (_m *SearchEngineInterface) TestConfig(rctx request.CTX, cfg *model.Config) *model.AppError { ret := _m.Called(rctx, cfg) diff --git a/server/public/utils/page.go b/server/public/utils/page.go new file mode 100644 index 00000000000..c9e000036e8 --- /dev/null +++ b/server/public/utils/page.go @@ -0,0 +1,43 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package utils + +// Pager fetches all items from a paginated API. +// Pager is a generic function that fetches and aggregates paginated data. +// It takes a fetch function and a perPage parameter as arguments. +// +// The fetch function is responsible for retrieving a slice of items of type T +// for a given page number. It returns the fetched items and an error, if any. +// Ideally a developer may want to use a closure to create a fetch function. +// +// The perPage parameter specifies the number of items to fetch per page. +// +// Example usage: +// +// items, err := Pager(fetchFunc, 10) +// if err != nil { +// // handle error +// } +// // process items +func Pager[T any](fetch func(page int) ([]T, error), perPage int) ([]T, error) { + var list []T + var page int + + for { + fetched, err := fetch(page) + if err != nil { + return list, err + } + + list = append(list, fetched...) + + if len(fetched) < perPage { + break + } + + page++ + } + + return list, nil +} diff --git a/server/public/utils/page_test.go b/server/public/utils/page_test.go new file mode 100644 index 00000000000..e5a8adbe48d --- /dev/null +++ b/server/public/utils/page_test.go @@ -0,0 +1,69 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package utils + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPager(t *testing.T) { + tests := []struct { + name string + fetch func(page int) ([]int, error) + perPage int + expected []int + expectErr bool + }{ + { + name: "successful fetch", + fetch: func(page int) ([]int, error) { + if page > 2 { + return nil, nil + } + return []int{page*10 + 1, page*10 + 2, page*10 + 3}, nil + }, + perPage: 3, + expected: []int{1, 2, 3, 11, 12, 13, 21, 22, 23}, + }, + { + name: "fetch with error", + fetch: func(page int) ([]int, error) { + if page == 1 { + return nil, errors.New("fetch error") + } + return []int{page*10 + 1, page*10 + 2, page*10 + 3}, nil + }, + perPage: 3, + expected: []int{1, 2, 3}, + expectErr: true, + }, + { + name: "fetch with fewer items than perPage", + fetch: func(page int) ([]int, error) { + if page > 0 { + return []int{11, 12}, nil + } + return []int{page*10 + 1, page*10 + 2, page*10 + 3}, nil + }, + perPage: 3, + expected: []int{1, 2, 3, 11, 12}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := Pager(tt.fetch, tt.perPage) + if tt.expectErr { + assert.Error(t, err) + assert.Equal(t, tt.expected, result) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} diff --git a/webapp/channels/src/actions/views/channel.ts b/webapp/channels/src/actions/views/channel.ts index 2974a2a53af..ce7d15f6979 100644 --- a/webapp/channels/src/actions/views/channel.ts +++ b/webapp/channels/src/actions/views/channel.ts @@ -365,17 +365,17 @@ export interface LoadPostsParameters { channelId: string; postId: string; type: CanLoadMorePosts; + perPage: number; } export function loadPosts({ channelId, postId, type, + perPage, }: LoadPostsParameters): ThunkActionFunc> { //type here can be BEFORE_ID or AFTER_ID return async (dispatch) => { - const POST_INCREASE_AMOUNT = Constants.POST_CHUNK_SIZE / 2; - dispatch({ type: ActionTypes.LOADING_POSTS, data: true, @@ -385,9 +385,9 @@ export function loadPosts({ const page = 0; let result; if (type === PostRequestTypes.BEFORE_ID) { - result = await dispatch(PostActions.getPostsBefore(channelId, postId, page, POST_INCREASE_AMOUNT)); + result = await dispatch(PostActions.getPostsBefore(channelId, postId, page, perPage)); } else { - result = await dispatch(PostActions.getPostsAfter(channelId, postId, page, POST_INCREASE_AMOUNT)); + result = await dispatch(PostActions.getPostsAfter(channelId, postId, page, perPage)); } const {data} = result; diff --git a/webapp/channels/src/components/__snapshots__/setting_item_min.test.tsx.snap b/webapp/channels/src/components/__snapshots__/setting_item_min.test.tsx.snap deleted file mode 100644 index 7c79e274d13..00000000000 --- a/webapp/channels/src/components/__snapshots__/setting_item_min.test.tsx.snap +++ /dev/null @@ -1,75 +0,0 @@ -// Jest Snapshot v1, https://goo.gl/fbAQLP - -exports[`components/SettingItemMin should match snapshot 1`] = ` -
-
-

- title -

- -
-
- describe -
-
-`; - -exports[`components/SettingItemMin should match snapshot, on disableOpen to true 1`] = ` -
-
-

- title -

- -
-
- describe -
-
-`; diff --git a/webapp/channels/src/components/external_link/__snapshots__/external_link.test.tsx.snap b/webapp/channels/src/components/external_link/__snapshots__/external_link.test.tsx.snap deleted file mode 100644 index a342da968df..00000000000 --- a/webapp/channels/src/components/external_link/__snapshots__/external_link.test.tsx.snap +++ /dev/null @@ -1,31 +0,0 @@ -// Jest Snapshot v1, https://goo.gl/fbAQLP - -exports[`components/external_link should match snapshot 1`] = ` - - - - Click Me - - - -`; diff --git a/webapp/channels/src/components/external_link/external_link.test.tsx b/webapp/channels/src/components/external_link/external_link.test.tsx index 862688651a5..77ae0f1ee6b 100644 --- a/webapp/channels/src/components/external_link/external_link.test.tsx +++ b/webapp/channels/src/components/external_link/external_link.test.tsx @@ -1,14 +1,11 @@ // Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. // See LICENSE.txt for license information. -import {mount} from 'enzyme'; import React from 'react'; -import {Provider} from 'react-redux'; import type {DeepPartial} from '@mattermost/types/utilities'; import {renderWithContext, screen} from 'tests/react_testing_utils'; -import mockStore from 'tests/test_store'; import type {GlobalState} from 'types/store'; @@ -29,19 +26,23 @@ describe('components/external_link', () => { }, }; - it('should match snapshot', () => { - const store = mockStore(initialState); - const wrapper = mount( - - {'Click Me'} - , + it('should render external link with correct attributes', () => { + renderWithContext( + + {'Click Me'} + , + initialState, ); - expect(wrapper).toMatchSnapshot(); + const linkElement = screen.getByRole('link', {name: 'Click Me'}); + + expect(linkElement).toBeInTheDocument(); + expect(linkElement).toHaveAttribute('target', '_blank'); + expect(linkElement).toHaveAttribute('rel', 'noopener noreferrer'); + expect(linkElement).toHaveAttribute('href', expect.stringContaining('https://mattermost.com')); }); it('should attach parameters', () => { @@ -67,7 +68,9 @@ describe('components/external_link', () => { state, ); - expect(screen.queryByText('Click Me')).toHaveAttribute( + const linkElement = screen.getByRole('link', {name: 'Click Me'}); + + expect(linkElement).toHaveAttribute( 'href', expect.stringMatching('utm_source=mattermost&utm_medium=in-product-cloud&utm_content=test&uid=currentUserId&sid='), ); diff --git a/webapp/channels/src/components/post_view/post_list/post_list.test.tsx b/webapp/channels/src/components/post_view/post_list/post_list.test.tsx index 31c645f8ad7..d397b39add5 100644 --- a/webapp/channels/src/components/post_view/post_list/post_list.test.tsx +++ b/webapp/channels/src/components/post_view/post_list/post_list.test.tsx @@ -94,14 +94,14 @@ describe('components/post_view/post_list', () => { wrapper.find(VirtPostList).prop('actions').loadOlderPosts(); expect(wrapper.state('loadingOlderPosts')).toEqual(true); - expect(actionsProp.loadPosts).toHaveBeenCalledWith({channelId: baseProps.channelId, postId: postIds[postIds.length - 1], type: PostRequestTypes.BEFORE_ID}); - await wrapper.instance().callLoadPosts('undefined', 'undefined', undefined); + expect(actionsProp.loadPosts).toHaveBeenCalledWith({channelId: baseProps.channelId, postId: postIds[postIds.length - 1], type: PostRequestTypes.BEFORE_ID, perPage: 30}); + await wrapper.instance().callLoadPosts('undefined', 'undefined', undefined, 30); expect(wrapper.state('loadingOlderPosts')).toBe(false); wrapper.find(VirtPostList).prop('actions').loadNewerPosts(); expect(wrapper.state('loadingNewerPosts')).toEqual(true); - expect(actionsProp.loadPosts).toHaveBeenCalledWith({channelId: baseProps.channelId, postId: postIds[0], type: PostRequestTypes.AFTER_ID}); - await wrapper.instance().callLoadPosts('undefined', 'undefined', undefined); + expect(actionsProp.loadPosts).toHaveBeenCalledWith({channelId: baseProps.channelId, postId: postIds[0], type: PostRequestTypes.AFTER_ID, perPage: 30}); + await wrapper.instance().callLoadPosts('undefined', 'undefined', undefined, 30); expect(wrapper.state('loadingNewerPosts')).toBe(false); }); @@ -192,7 +192,7 @@ describe('components/post_view/post_list', () => { const wrapper = shallow(); wrapper.setProps({atOldestPost: false}); wrapper.find(VirtPostList).prop('actions').canLoadMorePosts(undefined); - expect(actionsProp.loadPosts).toHaveBeenCalledWith({channelId: baseProps.channelId, postId: postIds[postIds.length - 1], type: PostRequestTypes.BEFORE_ID}); + expect(actionsProp.loadPosts).toHaveBeenCalledWith({channelId: baseProps.channelId, postId: postIds[postIds.length - 1], type: PostRequestTypes.BEFORE_ID, perPage: 200}); }); test('Should call getPostsAfter if all older posts are loaded and not newerPosts', async () => { @@ -200,14 +200,14 @@ describe('components/post_view/post_list', () => { const wrapper = shallow(); wrapper.setProps({atOldestPost: true}); wrapper.find(VirtPostList).prop('actions').canLoadMorePosts(undefined); - expect(actionsProp.loadPosts).toHaveBeenCalledWith({channelId: baseProps.channelId, postId: postIds[0], type: PostRequestTypes.AFTER_ID}); + expect(actionsProp.loadPosts).toHaveBeenCalledWith({channelId: baseProps.channelId, postId: postIds[0], type: PostRequestTypes.AFTER_ID, perPage: 30}); }); test('Should call getPostsAfter canLoadMorePosts is requested with AFTER_ID', async () => { const postIds = createFakePosIds(2); const wrapper = shallow(); wrapper.find(VirtPostList).prop('actions').canLoadMorePosts(PostRequestTypes.AFTER_ID); - expect(actionsProp.loadPosts).toHaveBeenCalledWith({channelId: baseProps.channelId, postId: postIds[0], type: PostRequestTypes.AFTER_ID}); + expect(actionsProp.loadPosts).toHaveBeenCalledWith({channelId: baseProps.channelId, postId: postIds[0], type: PostRequestTypes.AFTER_ID, perPage: 30}); }); }); @@ -231,7 +231,7 @@ describe('components/post_view/post_list', () => { wrapper.find(VirtPostList).prop('actions').loadOlderPosts(); expect(wrapper.state('loadingOlderPosts')).toEqual(true); expect(loadPosts).toHaveBeenCalledTimes(1); - expect(loadPosts).toHaveBeenCalledWith({channelId: baseProps.channelId, postId: postIds[postIds.length - 1], type: PostRequestTypes.BEFORE_ID}); + expect(loadPosts).toHaveBeenCalledWith({channelId: baseProps.channelId, postId: postIds[postIds.length - 1], type: PostRequestTypes.BEFORE_ID, perPage: 30}); await loadPosts(); expect(wrapper.state('loadingOlderPosts')).toBe(false); expect(loadPosts).toHaveBeenCalledTimes(3); @@ -260,4 +260,114 @@ describe('components/post_view/post_list', () => { expect(actionsProp.markChannelAsRead).not.toHaveBeenCalled(); }); }); + + describe('Differentiated page sizes', () => { + test('Should use 30 posts for user scroll (getPostsBefore)', async () => { + const postIds = createFakePosIds(2); + const loadPosts = jest.fn().mockImplementation(() => Promise.resolve({moreToLoad: true})); + const props = { + ...baseProps, + postListIds: postIds, + actions: { + ...actionsProp, + loadPosts, + }, + }; + + const wrapper = shallow( + , + ); + + // Trigger user scroll up + wrapper.find(VirtPostList).prop('actions').loadOlderPosts(); + + expect(loadPosts).toHaveBeenCalledWith({ + channelId: baseProps.channelId, + postId: postIds[postIds.length - 1], + type: PostRequestTypes.BEFORE_ID, + perPage: 30, + }); + }); + + test('Should use 200 posts for auto-loading (getPostsBeforeAutoLoad)', async () => { + const postIds = createFakePosIds(2); + const loadPosts = jest.fn().mockImplementation(() => Promise.resolve({moreToLoad: true})); + const props = { + ...baseProps, + postListIds: postIds, + atOldestPost: false, + actions: { + ...actionsProp, + loadPosts, + }, + }; + + const wrapper = shallow( + , + ); + + // Trigger auto-loading via canLoadMorePosts + await wrapper.find(VirtPostList).prop('actions').canLoadMorePosts(PostRequestTypes.BEFORE_ID); + + expect(loadPosts).toHaveBeenCalledWith({ + channelId: baseProps.channelId, + postId: postIds[postIds.length - 1], + type: PostRequestTypes.BEFORE_ID, + perPage: 200, // AUTO_LOAD_POSTS_PER_PAGE + }); + }); + + test('Should use 30 posts for user scroll down (getPostsAfter)', async () => { + const postIds = createFakePosIds(2); + const loadPosts = jest.fn().mockImplementation(() => Promise.resolve({moreToLoad: true})); + const props = { + ...baseProps, + postListIds: postIds, + actions: { + ...actionsProp, + loadPosts, + }, + }; + + const wrapper = shallow( + , + ); + + // Trigger user scroll down + wrapper.find(VirtPostList).prop('actions').loadNewerPosts(); + + expect(loadPosts).toHaveBeenCalledWith({ + channelId: baseProps.channelId, + postId: postIds[0], + type: PostRequestTypes.AFTER_ID, + perPage: 30, // USER_SCROLL_POSTS_PER_PAGE + }); + }); + + test('Should use 200 posts when canLoadMorePosts is triggered with BEFORE_ID', async () => { + const postIds = createFakePosIds(2); + const loadPosts = jest.fn().mockImplementation(() => Promise.resolve({moreToLoad: true})); + const props = { + ...baseProps, + postListIds: postIds, + atOldestPost: false, + actions: { + ...actionsProp, + loadPosts, + }, + }; + + const wrapper = shallow( + , + ); + + // Trigger auto-loading via canLoadMorePosts + wrapper.find(VirtPostList).prop('actions').canLoadMorePosts(PostRequestTypes.BEFORE_ID); + + // Should use auto-load page size (200 posts) + expect(loadPosts).toHaveBeenCalledWith(expect.objectContaining({ + perPage: 200, + })); + }); + }); }); diff --git a/webapp/channels/src/components/post_view/post_list/post_list.tsx b/webapp/channels/src/components/post_view/post_list/post_list.tsx index a441ca5d329..86f8bee98f9 100644 --- a/webapp/channels/src/components/post_view/post_list/post_list.tsx +++ b/webapp/channels/src/components/post_view/post_list/post_list.tsx @@ -12,12 +12,16 @@ import type {LoadPostsParameters, LoadPostsReturnValue, CanLoadMorePosts} from ' import LoadingScreen from 'components/loading_screen'; import VirtPostList from 'components/post_view/post_list_virtualized/post_list_virtualized'; -import {PostRequestTypes} from 'utils/constants'; +import {PostRequestTypes, Constants} from 'utils/constants'; import {Mark, Measure, measureAndReport} from 'utils/performance_telemetry'; import {getOldestPostId, getLatestPostId} from 'utils/post_utils'; const MAX_NUMBER_OF_AUTO_RETRIES = 3; -export const MAX_EXTRA_PAGES_LOADED = 10; +export const MAX_EXTRA_PAGES_LOADED = 30; + +// Post loading page sizes +const USER_SCROLL_POSTS_PER_PAGE = Constants.POST_CHUNK_SIZE / 2; // 30 posts for user scroll +const AUTO_LOAD_POSTS_PER_PAGE = 200; // Maximum server limit for auto-loading // Measures the time between channel or team switch started and the post list component rendering posts. // Set "fresh" to true when the posts have not been loaded before. @@ -263,11 +267,12 @@ export default class PostList extends React.PureComponent { } }; - callLoadPosts = async (channelId: string, postId: string, type: CanLoadMorePosts) => { + callLoadPosts = async (channelId: string, postId: string, type: CanLoadMorePosts, perPage: number) => { const {error} = await this.props.actions.loadPosts({ channelId, postId, type, + perPage, }); if (type === PostRequestTypes.BEFORE_ID) { @@ -279,7 +284,7 @@ export default class PostList extends React.PureComponent { if (error) { if (this.autoRetriesCount < MAX_NUMBER_OF_AUTO_RETRIES) { this.autoRetriesCount++; - await this.callLoadPosts(channelId, postId, type); + await this.callLoadPosts(channelId, postId, type, perPage); } else if (this.mounted) { this.setState({autoRetryEnable: false}); } @@ -327,7 +332,7 @@ export default class PostList extends React.PureComponent { } if (!this.props.atOldestPost && type === PostRequestTypes.BEFORE_ID) { - await this.getPostsBefore(); + await this.getPostsBeforeAutoLoad(); } else if (!this.props.atLatestPost) { // if all olderPosts are loaded load new ones await this.getPostsAfter(); @@ -348,7 +353,7 @@ export default class PostList extends React.PureComponent { const oldestPostId = this.getOldestVisiblePostId(); this.setState({loadingOlderPosts: true}); - await this.callLoadPosts(this.props.channelId, oldestPostId, PostRequestTypes.BEFORE_ID); + await this.callLoadPosts(this.props.channelId, oldestPostId, PostRequestTypes.BEFORE_ID, USER_SCROLL_POSTS_PER_PAGE); }; getPostsAfter = async () => { @@ -363,7 +368,17 @@ export default class PostList extends React.PureComponent { const latestPostId = this.getLatestVisiblePostId(); this.setState({loadingNewerPosts: true}); - await this.callLoadPosts(this.props.channelId, latestPostId, PostRequestTypes.AFTER_ID); + await this.callLoadPosts(this.props.channelId, latestPostId, PostRequestTypes.AFTER_ID, USER_SCROLL_POSTS_PER_PAGE); + }; + + getPostsBeforeAutoLoad = async () => { + if (this.state.loadingOlderPosts) { + return; + } + + const oldestPostId = this.getOldestVisiblePostId(); + this.setState({loadingOlderPosts: true}); + await this.callLoadPosts(this.props.channelId, oldestPostId, PostRequestTypes.BEFORE_ID, AUTO_LOAD_POSTS_PER_PAGE); }; render() { diff --git a/webapp/channels/src/components/setting_item_min.test.tsx b/webapp/channels/src/components/setting_item_min.test.tsx index 29a627dc461..32e9f79d8ff 100644 --- a/webapp/channels/src/components/setting_item_min.test.tsx +++ b/webapp/channels/src/components/setting_item_min.test.tsx @@ -1,62 +1,111 @@ // Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. // See LICENSE.txt for license information. -import {shallow} from 'enzyme'; import React from 'react'; +import {renderWithContext, screen, userEvent} from 'tests/react_testing_utils'; + import SettingItemMin from './setting_item_min'; describe('components/SettingItemMin', () => { const baseProps = { - title: 'title', - disableOpen: false, - section: 'section', + title: 'Test Title', + isDisabled: false, + section: 'test-section', updateSection: jest.fn(), - describe: 'describe', - isMobileView: false, - actions: { - updateActiveSection: jest.fn(), - }, + describe: 'Test description', }; - test('should match snapshot', () => { - const wrapper = shallow( - , - ); + test('should render with default props', () => { + renderWithContext(); + + expect(screen.getByText('Test Title')).toBeInTheDocument(); + expect(screen.getByText('Test description')).toBeInTheDocument(); + expect(screen.getByRole('button', {name: 'Test Title Edit'})).toBeInTheDocument(); + }); + + test('should render without edit button when disabled', () => { + const props = {...baseProps, isDisabled: true}; + renderWithContext(); - expect(wrapper).toMatchSnapshot(); + expect(screen.getByText('Test Title')).toBeInTheDocument(); + expect(screen.getByText('Test description')).toBeInTheDocument(); + expect(screen.queryByRole('button')).not.toBeInTheDocument(); }); - test('should match snapshot, on disableOpen to true', () => { - const props = {...baseProps, disableOpen: true}; - const wrapper = shallow( - , - ); + test('should render custom disabled edit button when provided', () => { + const customEditButton = {'Custom Edit Button'}; + const props = { + ...baseProps, + isDisabled: true, + collapsedEditButtonWhenDisabled: customEditButton, + }; + renderWithContext(); - expect(wrapper).toMatchSnapshot(); + expect(screen.getByText('Custom Edit Button')).toBeInTheDocument(); + expect(screen.queryByRole('button')).not.toBeInTheDocument(); }); - test('should have called updateSection on handleClick with section', () => { + test('should call updateSection when edit button is clicked', async () => { const updateSection = jest.fn(); const props = {...baseProps, updateSection}; - const wrapper = shallow( - , - ); + renderWithContext(); + + const editButton = screen.getByRole('button', {name: 'Test Title Edit'}); + await userEvent.click(editButton); + + expect(updateSection).toHaveBeenCalledTimes(1); + expect(updateSection).toHaveBeenCalledWith('test-section'); + }); + + test('should call updateSection when container div is clicked', async () => { + const updateSection = jest.fn(); + const props = {...baseProps, updateSection}; + renderWithContext(); + + const container = screen.getByText('Test Title').closest('.section-min'); + await userEvent.click(container!); - wrapper.instance().handleClick({preventDefault: jest.fn()} as any); - expect(updateSection).toHaveBeenCalled(); - expect(updateSection).toHaveBeenCalledWith('section'); + expect(updateSection).toHaveBeenCalledTimes(1); + expect(updateSection).toHaveBeenCalledWith('test-section'); }); - test('should have called updateSection on handleClick with empty string', () => { + test('should not call updateSection when disabled and edit button area is clicked', async () => { const updateSection = jest.fn(); - const props = {...baseProps, updateSection, section: ''}; - const wrapper = shallow( - , - ); - - wrapper.instance().handleClick({preventDefault: jest.fn()} as any); - expect(updateSection).toHaveBeenCalled(); - expect(updateSection).toHaveBeenCalledWith(''); + const props = {...baseProps, updateSection, isDisabled: true}; + renderWithContext(); + + const container = screen.getByText('Test Title').closest('.section-min'); + await userEvent.click(container!); + + expect(updateSection).not.toHaveBeenCalled(); + }); + + test('should have correct accessibility attributes', () => { + renderWithContext(); + + const editButton = screen.getByRole('button', {name: 'Test Title Edit'}); + expect(editButton).toHaveAttribute('aria-expanded', 'false'); + expect(editButton).toHaveAttribute('id', 'test-sectionEdit'); + expect(editButton).toHaveAttribute('aria-labelledby', 'test-sectionTitle test-sectionEdit'); + + const title = screen.getByText('Test Title'); + expect(title).toHaveAttribute('id', 'test-sectionTitle'); + + const description = screen.getByText('Test description'); + expect(description).toHaveAttribute('id', 'test-sectionDesc'); + }); + + test('should apply disabled styling when isDisabled is true', () => { + const props = {...baseProps, isDisabled: true}; + renderWithContext(); + + const container = screen.getByText('Test Title').closest('.section-min'); + const title = screen.getByText('Test Title'); + const description = screen.getByText('Test description'); + + expect(container).toHaveClass('isDisabled'); + expect(title).toHaveClass('isDisabled'); + expect(description).toHaveClass('isDisabled'); }); }); diff --git a/webapp/channels/src/components/setting_item_min.tsx b/webapp/channels/src/components/setting_item_min.tsx index ba3af422b5e..fd3546d6ff2 100644 --- a/webapp/channels/src/components/setting_item_min.tsx +++ b/webapp/channels/src/components/setting_item_min.tsx @@ -59,6 +59,7 @@ export default class SettingItemMin extends React.PureComponent { } e.preventDefault(); + e.stopPropagation(); this.props.updateSection(this.props.section); }; @@ -96,7 +97,7 @@ export default class SettingItemMin extends React.PureComponent { onClick={this.handleClick} >

.secion-min__header { + > .section-min__header { display: flex; flex-direction: row; justify-content: space-between;