From 0a119a9f4af4d6265d164ea90ef71cdc1cf5c807 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pablo=20V=C3=A9lez?= Date: Mon, 25 Aug 2025 17:24:58 +0200 Subject: [PATCH 1/3] migrate two components tests from enzyme to RTL (#33610) Co-authored-by: Mattermost Build --- .../setting_item_min.test.tsx.snap | 75 ----------- .../__snapshots__/external_link.test.tsx.snap | 31 ----- .../external_link/external_link.test.tsx | 33 ++--- .../src/components/setting_item_min.test.tsx | 121 ++++++++++++------ .../src/components/setting_item_min.tsx | 3 +- .../user_settings_notifications.test.tsx.snap | 56 ++++---- .../__snapshots__/index.test.tsx.snap | 2 +- .../channels/src/sass/routes/_settings.scss | 2 +- 8 files changed, 135 insertions(+), 188 deletions(-) delete mode 100644 webapp/channels/src/components/__snapshots__/setting_item_min.test.tsx.snap delete mode 100644 webapp/channels/src/components/external_link/__snapshots__/external_link.test.tsx.snap diff --git a/webapp/channels/src/components/__snapshots__/setting_item_min.test.tsx.snap b/webapp/channels/src/components/__snapshots__/setting_item_min.test.tsx.snap deleted file mode 100644 index 7c79e274d13..00000000000 --- a/webapp/channels/src/components/__snapshots__/setting_item_min.test.tsx.snap +++ /dev/null @@ -1,75 +0,0 @@ -// Jest Snapshot v1, https://goo.gl/fbAQLP - -exports[`components/SettingItemMin should match snapshot 1`] = ` -
-
-

- title -

- -
-
- describe -
-
-`; - -exports[`components/SettingItemMin should match snapshot, on disableOpen to true 1`] = ` -
-
-

- title -

- -
-
- describe -
-
-`; diff --git a/webapp/channels/src/components/external_link/__snapshots__/external_link.test.tsx.snap b/webapp/channels/src/components/external_link/__snapshots__/external_link.test.tsx.snap deleted file mode 100644 index a342da968df..00000000000 --- a/webapp/channels/src/components/external_link/__snapshots__/external_link.test.tsx.snap +++ /dev/null @@ -1,31 +0,0 @@ -// Jest Snapshot v1, https://goo.gl/fbAQLP - -exports[`components/external_link should match snapshot 1`] = ` - - - - Click Me - - - -`; diff --git a/webapp/channels/src/components/external_link/external_link.test.tsx b/webapp/channels/src/components/external_link/external_link.test.tsx index 862688651a5..77ae0f1ee6b 100644 --- a/webapp/channels/src/components/external_link/external_link.test.tsx +++ b/webapp/channels/src/components/external_link/external_link.test.tsx @@ -1,14 +1,11 @@ // Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. // See LICENSE.txt for license information. -import {mount} from 'enzyme'; import React from 'react'; -import {Provider} from 'react-redux'; import type {DeepPartial} from '@mattermost/types/utilities'; import {renderWithContext, screen} from 'tests/react_testing_utils'; -import mockStore from 'tests/test_store'; import type {GlobalState} from 'types/store'; @@ -29,19 +26,23 @@ describe('components/external_link', () => { }, }; - it('should match snapshot', () => { - const store = mockStore(initialState); - const wrapper = mount( - - {'Click Me'} - , + it('should render external link with correct attributes', () => { + renderWithContext( + + {'Click Me'} + , + initialState, ); - expect(wrapper).toMatchSnapshot(); + const linkElement = screen.getByRole('link', {name: 'Click Me'}); + + expect(linkElement).toBeInTheDocument(); + expect(linkElement).toHaveAttribute('target', '_blank'); + expect(linkElement).toHaveAttribute('rel', 'noopener noreferrer'); + expect(linkElement).toHaveAttribute('href', expect.stringContaining('https://mattermost.com')); }); it('should attach parameters', () => { @@ -67,7 +68,9 @@ describe('components/external_link', () => { state, ); - expect(screen.queryByText('Click Me')).toHaveAttribute( + const linkElement = screen.getByRole('link', {name: 'Click Me'}); + + expect(linkElement).toHaveAttribute( 'href', expect.stringMatching('utm_source=mattermost&utm_medium=in-product-cloud&utm_content=test&uid=currentUserId&sid='), ); diff --git a/webapp/channels/src/components/setting_item_min.test.tsx b/webapp/channels/src/components/setting_item_min.test.tsx index 29a627dc461..32e9f79d8ff 100644 --- a/webapp/channels/src/components/setting_item_min.test.tsx +++ b/webapp/channels/src/components/setting_item_min.test.tsx @@ -1,62 +1,111 @@ // Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. // See LICENSE.txt for license information. -import {shallow} from 'enzyme'; import React from 'react'; +import {renderWithContext, screen, userEvent} from 'tests/react_testing_utils'; + import SettingItemMin from './setting_item_min'; describe('components/SettingItemMin', () => { const baseProps = { - title: 'title', - disableOpen: false, - section: 'section', + title: 'Test Title', + isDisabled: false, + section: 'test-section', updateSection: jest.fn(), - describe: 'describe', - isMobileView: false, - actions: { - updateActiveSection: jest.fn(), - }, + describe: 'Test description', }; - test('should match snapshot', () => { - const wrapper = shallow( - , - ); + test('should render with default props', () => { + renderWithContext(); + + expect(screen.getByText('Test Title')).toBeInTheDocument(); + expect(screen.getByText('Test description')).toBeInTheDocument(); + expect(screen.getByRole('button', {name: 'Test Title Edit'})).toBeInTheDocument(); + }); + + test('should render without edit button when disabled', () => { + const props = {...baseProps, isDisabled: true}; + renderWithContext(); - expect(wrapper).toMatchSnapshot(); + expect(screen.getByText('Test Title')).toBeInTheDocument(); + expect(screen.getByText('Test description')).toBeInTheDocument(); + expect(screen.queryByRole('button')).not.toBeInTheDocument(); }); - test('should match snapshot, on disableOpen to true', () => { - const props = {...baseProps, disableOpen: true}; - const wrapper = shallow( - , - ); + test('should render custom disabled edit button when provided', () => { + const customEditButton = {'Custom Edit Button'}; + const props = { + ...baseProps, + isDisabled: true, + collapsedEditButtonWhenDisabled: customEditButton, + }; + renderWithContext(); - expect(wrapper).toMatchSnapshot(); + expect(screen.getByText('Custom Edit Button')).toBeInTheDocument(); + expect(screen.queryByRole('button')).not.toBeInTheDocument(); }); - test('should have called updateSection on handleClick with section', () => { + test('should call updateSection when edit button is clicked', async () => { const updateSection = jest.fn(); const props = {...baseProps, updateSection}; - const wrapper = shallow( - , - ); + renderWithContext(); + + const editButton = screen.getByRole('button', {name: 'Test Title Edit'}); + await userEvent.click(editButton); + + expect(updateSection).toHaveBeenCalledTimes(1); + expect(updateSection).toHaveBeenCalledWith('test-section'); + }); + + test('should call updateSection when container div is clicked', async () => { + const updateSection = jest.fn(); + const props = {...baseProps, updateSection}; + renderWithContext(); + + const container = screen.getByText('Test Title').closest('.section-min'); + await userEvent.click(container!); - wrapper.instance().handleClick({preventDefault: jest.fn()} as any); - expect(updateSection).toHaveBeenCalled(); - expect(updateSection).toHaveBeenCalledWith('section'); + expect(updateSection).toHaveBeenCalledTimes(1); + expect(updateSection).toHaveBeenCalledWith('test-section'); }); - test('should have called updateSection on handleClick with empty string', () => { + test('should not call updateSection when disabled and edit button area is clicked', async () => { const updateSection = jest.fn(); - const props = {...baseProps, updateSection, section: ''}; - const wrapper = shallow( - , - ); - - wrapper.instance().handleClick({preventDefault: jest.fn()} as any); - expect(updateSection).toHaveBeenCalled(); - expect(updateSection).toHaveBeenCalledWith(''); + const props = {...baseProps, updateSection, isDisabled: true}; + renderWithContext(); + + const container = screen.getByText('Test Title').closest('.section-min'); + await userEvent.click(container!); + + expect(updateSection).not.toHaveBeenCalled(); + }); + + test('should have correct accessibility attributes', () => { + renderWithContext(); + + const editButton = screen.getByRole('button', {name: 'Test Title Edit'}); + expect(editButton).toHaveAttribute('aria-expanded', 'false'); + expect(editButton).toHaveAttribute('id', 'test-sectionEdit'); + expect(editButton).toHaveAttribute('aria-labelledby', 'test-sectionTitle test-sectionEdit'); + + const title = screen.getByText('Test Title'); + expect(title).toHaveAttribute('id', 'test-sectionTitle'); + + const description = screen.getByText('Test description'); + expect(description).toHaveAttribute('id', 'test-sectionDesc'); + }); + + test('should apply disabled styling when isDisabled is true', () => { + const props = {...baseProps, isDisabled: true}; + renderWithContext(); + + const container = screen.getByText('Test Title').closest('.section-min'); + const title = screen.getByText('Test Title'); + const description = screen.getByText('Test description'); + + expect(container).toHaveClass('isDisabled'); + expect(title).toHaveClass('isDisabled'); + expect(description).toHaveClass('isDisabled'); }); }); diff --git a/webapp/channels/src/components/setting_item_min.tsx b/webapp/channels/src/components/setting_item_min.tsx index ba3af422b5e..fd3546d6ff2 100644 --- a/webapp/channels/src/components/setting_item_min.tsx +++ b/webapp/channels/src/components/setting_item_min.tsx @@ -59,6 +59,7 @@ export default class SettingItemMin extends React.PureComponent { } e.preventDefault(); + e.stopPropagation(); this.props.updateSection(this.props.section); }; @@ -96,7 +97,7 @@ export default class SettingItemMin extends React.PureComponent { onClick={this.handleClick} >

.secion-min__header { + > .section-min__header { display: flex; flex-direction: row; justify-content: space-between; From d5aa8211a2bf1e2bc2a9d438deb7f1a15af57666 Mon Sep 17 00:00:00 2001 From: Scott Bishel Date: Mon, 25 Aug 2025 10:52:25 -0600 Subject: [PATCH 2/3] MM-64978 Implement differentiated page sizes for auto loading vs scroll loading (#33607) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Implement differentiated page sizes for post loading Add perPage parameter to loadPosts function to optimize loading performance based on context: - User scroll: 30 posts per request (responsive UX) - Auto-loading: 200 posts per request (efficient for sparse channels) Changes: - Make perPage required parameter in LoadPostsParameters interface - Add USER_SCROLL_POSTS_PER_PAGE (30) and AUTO_LOAD_POSTS_PER_PAGE (200) constants - Create separate getPostsBeforeAutoLoad() method for high-volume auto-loading - Update canLoadMorePosts() to use auto-load method for maximum efficiency This reduces server round-trips by 6.7x when auto-loading content in channels with heavy join/leave message activity while maintaining responsive user-initiated scrolling. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude * drop pages to 20 * drop pages to 30 * lint fixes * lint fixes, add tests * more fixes --------- Co-authored-by: Claude Co-authored-by: Mattermost Build --- webapp/channels/src/actions/views/channel.ts | 8 +- .../post_view/post_list/post_list.test.tsx | 126 ++++++++++++++++-- .../post_view/post_list/post_list.tsx | 29 +++- 3 files changed, 144 insertions(+), 19 deletions(-) diff --git a/webapp/channels/src/actions/views/channel.ts b/webapp/channels/src/actions/views/channel.ts index 2974a2a53af..ce7d15f6979 100644 --- a/webapp/channels/src/actions/views/channel.ts +++ b/webapp/channels/src/actions/views/channel.ts @@ -365,17 +365,17 @@ export interface LoadPostsParameters { channelId: string; postId: string; type: CanLoadMorePosts; + perPage: number; } export function loadPosts({ channelId, postId, type, + perPage, }: LoadPostsParameters): ThunkActionFunc> { //type here can be BEFORE_ID or AFTER_ID return async (dispatch) => { - const POST_INCREASE_AMOUNT = Constants.POST_CHUNK_SIZE / 2; - dispatch({ type: ActionTypes.LOADING_POSTS, data: true, @@ -385,9 +385,9 @@ export function loadPosts({ const page = 0; let result; if (type === PostRequestTypes.BEFORE_ID) { - result = await dispatch(PostActions.getPostsBefore(channelId, postId, page, POST_INCREASE_AMOUNT)); + result = await dispatch(PostActions.getPostsBefore(channelId, postId, page, perPage)); } else { - result = await dispatch(PostActions.getPostsAfter(channelId, postId, page, POST_INCREASE_AMOUNT)); + result = await dispatch(PostActions.getPostsAfter(channelId, postId, page, perPage)); } const {data} = result; diff --git a/webapp/channels/src/components/post_view/post_list/post_list.test.tsx b/webapp/channels/src/components/post_view/post_list/post_list.test.tsx index 31c645f8ad7..d397b39add5 100644 --- a/webapp/channels/src/components/post_view/post_list/post_list.test.tsx +++ b/webapp/channels/src/components/post_view/post_list/post_list.test.tsx @@ -94,14 +94,14 @@ describe('components/post_view/post_list', () => { wrapper.find(VirtPostList).prop('actions').loadOlderPosts(); expect(wrapper.state('loadingOlderPosts')).toEqual(true); - expect(actionsProp.loadPosts).toHaveBeenCalledWith({channelId: baseProps.channelId, postId: postIds[postIds.length - 1], type: PostRequestTypes.BEFORE_ID}); - await wrapper.instance().callLoadPosts('undefined', 'undefined', undefined); + expect(actionsProp.loadPosts).toHaveBeenCalledWith({channelId: baseProps.channelId, postId: postIds[postIds.length - 1], type: PostRequestTypes.BEFORE_ID, perPage: 30}); + await wrapper.instance().callLoadPosts('undefined', 'undefined', undefined, 30); expect(wrapper.state('loadingOlderPosts')).toBe(false); wrapper.find(VirtPostList).prop('actions').loadNewerPosts(); expect(wrapper.state('loadingNewerPosts')).toEqual(true); - expect(actionsProp.loadPosts).toHaveBeenCalledWith({channelId: baseProps.channelId, postId: postIds[0], type: PostRequestTypes.AFTER_ID}); - await wrapper.instance().callLoadPosts('undefined', 'undefined', undefined); + expect(actionsProp.loadPosts).toHaveBeenCalledWith({channelId: baseProps.channelId, postId: postIds[0], type: PostRequestTypes.AFTER_ID, perPage: 30}); + await wrapper.instance().callLoadPosts('undefined', 'undefined', undefined, 30); expect(wrapper.state('loadingNewerPosts')).toBe(false); }); @@ -192,7 +192,7 @@ describe('components/post_view/post_list', () => { const wrapper = shallow(); wrapper.setProps({atOldestPost: false}); wrapper.find(VirtPostList).prop('actions').canLoadMorePosts(undefined); - expect(actionsProp.loadPosts).toHaveBeenCalledWith({channelId: baseProps.channelId, postId: postIds[postIds.length - 1], type: PostRequestTypes.BEFORE_ID}); + expect(actionsProp.loadPosts).toHaveBeenCalledWith({channelId: baseProps.channelId, postId: postIds[postIds.length - 1], type: PostRequestTypes.BEFORE_ID, perPage: 200}); }); test('Should call getPostsAfter if all older posts are loaded and not newerPosts', async () => { @@ -200,14 +200,14 @@ describe('components/post_view/post_list', () => { const wrapper = shallow(); wrapper.setProps({atOldestPost: true}); wrapper.find(VirtPostList).prop('actions').canLoadMorePosts(undefined); - expect(actionsProp.loadPosts).toHaveBeenCalledWith({channelId: baseProps.channelId, postId: postIds[0], type: PostRequestTypes.AFTER_ID}); + expect(actionsProp.loadPosts).toHaveBeenCalledWith({channelId: baseProps.channelId, postId: postIds[0], type: PostRequestTypes.AFTER_ID, perPage: 30}); }); test('Should call getPostsAfter canLoadMorePosts is requested with AFTER_ID', async () => { const postIds = createFakePosIds(2); const wrapper = shallow(); wrapper.find(VirtPostList).prop('actions').canLoadMorePosts(PostRequestTypes.AFTER_ID); - expect(actionsProp.loadPosts).toHaveBeenCalledWith({channelId: baseProps.channelId, postId: postIds[0], type: PostRequestTypes.AFTER_ID}); + expect(actionsProp.loadPosts).toHaveBeenCalledWith({channelId: baseProps.channelId, postId: postIds[0], type: PostRequestTypes.AFTER_ID, perPage: 30}); }); }); @@ -231,7 +231,7 @@ describe('components/post_view/post_list', () => { wrapper.find(VirtPostList).prop('actions').loadOlderPosts(); expect(wrapper.state('loadingOlderPosts')).toEqual(true); expect(loadPosts).toHaveBeenCalledTimes(1); - expect(loadPosts).toHaveBeenCalledWith({channelId: baseProps.channelId, postId: postIds[postIds.length - 1], type: PostRequestTypes.BEFORE_ID}); + expect(loadPosts).toHaveBeenCalledWith({channelId: baseProps.channelId, postId: postIds[postIds.length - 1], type: PostRequestTypes.BEFORE_ID, perPage: 30}); await loadPosts(); expect(wrapper.state('loadingOlderPosts')).toBe(false); expect(loadPosts).toHaveBeenCalledTimes(3); @@ -260,4 +260,114 @@ describe('components/post_view/post_list', () => { expect(actionsProp.markChannelAsRead).not.toHaveBeenCalled(); }); }); + + describe('Differentiated page sizes', () => { + test('Should use 30 posts for user scroll (getPostsBefore)', async () => { + const postIds = createFakePosIds(2); + const loadPosts = jest.fn().mockImplementation(() => Promise.resolve({moreToLoad: true})); + const props = { + ...baseProps, + postListIds: postIds, + actions: { + ...actionsProp, + loadPosts, + }, + }; + + const wrapper = shallow( + , + ); + + // Trigger user scroll up + wrapper.find(VirtPostList).prop('actions').loadOlderPosts(); + + expect(loadPosts).toHaveBeenCalledWith({ + channelId: baseProps.channelId, + postId: postIds[postIds.length - 1], + type: PostRequestTypes.BEFORE_ID, + perPage: 30, + }); + }); + + test('Should use 200 posts for auto-loading (getPostsBeforeAutoLoad)', async () => { + const postIds = createFakePosIds(2); + const loadPosts = jest.fn().mockImplementation(() => Promise.resolve({moreToLoad: true})); + const props = { + ...baseProps, + postListIds: postIds, + atOldestPost: false, + actions: { + ...actionsProp, + loadPosts, + }, + }; + + const wrapper = shallow( + , + ); + + // Trigger auto-loading via canLoadMorePosts + await wrapper.find(VirtPostList).prop('actions').canLoadMorePosts(PostRequestTypes.BEFORE_ID); + + expect(loadPosts).toHaveBeenCalledWith({ + channelId: baseProps.channelId, + postId: postIds[postIds.length - 1], + type: PostRequestTypes.BEFORE_ID, + perPage: 200, // AUTO_LOAD_POSTS_PER_PAGE + }); + }); + + test('Should use 30 posts for user scroll down (getPostsAfter)', async () => { + const postIds = createFakePosIds(2); + const loadPosts = jest.fn().mockImplementation(() => Promise.resolve({moreToLoad: true})); + const props = { + ...baseProps, + postListIds: postIds, + actions: { + ...actionsProp, + loadPosts, + }, + }; + + const wrapper = shallow( + , + ); + + // Trigger user scroll down + wrapper.find(VirtPostList).prop('actions').loadNewerPosts(); + + expect(loadPosts).toHaveBeenCalledWith({ + channelId: baseProps.channelId, + postId: postIds[0], + type: PostRequestTypes.AFTER_ID, + perPage: 30, // USER_SCROLL_POSTS_PER_PAGE + }); + }); + + test('Should use 200 posts when canLoadMorePosts is triggered with BEFORE_ID', async () => { + const postIds = createFakePosIds(2); + const loadPosts = jest.fn().mockImplementation(() => Promise.resolve({moreToLoad: true})); + const props = { + ...baseProps, + postListIds: postIds, + atOldestPost: false, + actions: { + ...actionsProp, + loadPosts, + }, + }; + + const wrapper = shallow( + , + ); + + // Trigger auto-loading via canLoadMorePosts + wrapper.find(VirtPostList).prop('actions').canLoadMorePosts(PostRequestTypes.BEFORE_ID); + + // Should use auto-load page size (200 posts) + expect(loadPosts).toHaveBeenCalledWith(expect.objectContaining({ + perPage: 200, + })); + }); + }); }); diff --git a/webapp/channels/src/components/post_view/post_list/post_list.tsx b/webapp/channels/src/components/post_view/post_list/post_list.tsx index a441ca5d329..86f8bee98f9 100644 --- a/webapp/channels/src/components/post_view/post_list/post_list.tsx +++ b/webapp/channels/src/components/post_view/post_list/post_list.tsx @@ -12,12 +12,16 @@ import type {LoadPostsParameters, LoadPostsReturnValue, CanLoadMorePosts} from ' import LoadingScreen from 'components/loading_screen'; import VirtPostList from 'components/post_view/post_list_virtualized/post_list_virtualized'; -import {PostRequestTypes} from 'utils/constants'; +import {PostRequestTypes, Constants} from 'utils/constants'; import {Mark, Measure, measureAndReport} from 'utils/performance_telemetry'; import {getOldestPostId, getLatestPostId} from 'utils/post_utils'; const MAX_NUMBER_OF_AUTO_RETRIES = 3; -export const MAX_EXTRA_PAGES_LOADED = 10; +export const MAX_EXTRA_PAGES_LOADED = 30; + +// Post loading page sizes +const USER_SCROLL_POSTS_PER_PAGE = Constants.POST_CHUNK_SIZE / 2; // 30 posts for user scroll +const AUTO_LOAD_POSTS_PER_PAGE = 200; // Maximum server limit for auto-loading // Measures the time between channel or team switch started and the post list component rendering posts. // Set "fresh" to true when the posts have not been loaded before. @@ -263,11 +267,12 @@ export default class PostList extends React.PureComponent { } }; - callLoadPosts = async (channelId: string, postId: string, type: CanLoadMorePosts) => { + callLoadPosts = async (channelId: string, postId: string, type: CanLoadMorePosts, perPage: number) => { const {error} = await this.props.actions.loadPosts({ channelId, postId, type, + perPage, }); if (type === PostRequestTypes.BEFORE_ID) { @@ -279,7 +284,7 @@ export default class PostList extends React.PureComponent { if (error) { if (this.autoRetriesCount < MAX_NUMBER_OF_AUTO_RETRIES) { this.autoRetriesCount++; - await this.callLoadPosts(channelId, postId, type); + await this.callLoadPosts(channelId, postId, type, perPage); } else if (this.mounted) { this.setState({autoRetryEnable: false}); } @@ -327,7 +332,7 @@ export default class PostList extends React.PureComponent { } if (!this.props.atOldestPost && type === PostRequestTypes.BEFORE_ID) { - await this.getPostsBefore(); + await this.getPostsBeforeAutoLoad(); } else if (!this.props.atLatestPost) { // if all olderPosts are loaded load new ones await this.getPostsAfter(); @@ -348,7 +353,7 @@ export default class PostList extends React.PureComponent { const oldestPostId = this.getOldestVisiblePostId(); this.setState({loadingOlderPosts: true}); - await this.callLoadPosts(this.props.channelId, oldestPostId, PostRequestTypes.BEFORE_ID); + await this.callLoadPosts(this.props.channelId, oldestPostId, PostRequestTypes.BEFORE_ID, USER_SCROLL_POSTS_PER_PAGE); }; getPostsAfter = async () => { @@ -363,7 +368,17 @@ export default class PostList extends React.PureComponent { const latestPostId = this.getLatestVisiblePostId(); this.setState({loadingNewerPosts: true}); - await this.callLoadPosts(this.props.channelId, latestPostId, PostRequestTypes.AFTER_ID); + await this.callLoadPosts(this.props.channelId, latestPostId, PostRequestTypes.AFTER_ID, USER_SCROLL_POSTS_PER_PAGE); + }; + + getPostsBeforeAutoLoad = async () => { + if (this.state.loadingOlderPosts) { + return; + } + + const oldestPostId = this.getOldestVisiblePostId(); + this.setState({loadingOlderPosts: true}); + await this.callLoadPosts(this.props.channelId, oldestPostId, PostRequestTypes.BEFORE_ID, AUTO_LOAD_POSTS_PER_PAGE); }; render() { From 553f99612e59344ddc703d16afbc3bfd5be25861 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alejandro=20Garc=C3=ADa=20Montoro?= Date: Mon, 25 Aug 2025 19:28:19 +0200 Subject: [PATCH 3/3] MM-60441: Re-index public channels when a user joins a team (#33400) * Index all public channels when a user joins a team * Precompute team members for indexChannelsForTeam * Refactor RequestContextWithMaster to store package This way, we can import it from both the sqlstore and the searchlayer packages. The alternative for this is duplicating the code in those two packages, but that will *not* work: The context package expects custom types for the keys stored in it, so that different packages never clash with each other when trying to register a new key. See the docs for the WithValue function: https://pkg.go.dev/context#WithValue If we try to duplicate the storeContextKey type in both the sqlstore and searchlayer packages, although they *look* the same, they are not, and HasMaster will fail to get the value of the storeContextKey(useMaster) key if it's from the other package. * Use master in call to GetTeamMembersForChannel In GetTeamMembersForChannel, use the DB from the newly passed context, which will be the receiving context everywhere except in the call done from indexChannelsForTeam, to avoid the read after write issue when saving a team member. * Fix GetPublicChannelsForTeam paging We were using the page and perPage arguments as is in the call to GetPublicChannelsForTeam, but that function expects and offset and a limit as understood by SQL. Although perPage and limit are interchangeable, offset is not equal to page, but to page * perPage. * Add a synchronous bulk indexer for Opensearch * Implement Opensearch's SyncBulkIndexChannels * Add a synchronous bulk indexer for Elasticsearch * Implement Elasticsearch's SynkBulkIndexChannels * Test SyncBulkIndexChannels * make mocks * Bulk index channels on indexChannelsForTeam * Handle error from SyncBulkIndexChannels * Fix style * Revert indexChannelWithTeamMembers refactor * Remove defensive code on sync bulk processor * Revert "Add a synchronous bulk indexer for Opensearch" This reverts commit bfe4671d96bffa9ca27ed3c655fc5527b72bbafb. * Revert "Add a synchronous bulk indexer for Elasticsearch" This reverts commit 6643ae3f30c461544d0861aec7dee1f24e507c37. * Refactor bulk indexers with a common interface * Test all the different implementations Assisted by Claude * Remove debug statements * Refactor common code into _stop * Rename getUserIDsFor{,Private}Channel * Wrap error * Make perPage a const * Fix typos * Call GetTeamsForUser only if needed * Differentiate errors for sync/async processors --------- Co-authored-by: Ibrahim Serdar Acikgoz Co-authored-by: Mattermost Build --- server/channels/store/context.go | 46 ++ server/channels/store/context_test.go | 55 +++ .../channels/store/retrylayer/retrylayer.go | 4 +- .../store/searchlayer/channel_layer.go | 26 +- server/channels/store/searchlayer/layer.go | 30 ++ .../channels/store/searchlayer/team_layer.go | 26 +- .../channels/store/sqlstore/channel_store.go | 4 +- server/channels/store/sqlstore/context.go | 25 +- server/channels/store/store.go | 2 +- .../store/storetest/mocks/ChannelStore.go | 18 +- .../channels/store/timerlayer/timerlayer.go | 4 +- .../enterprise/elasticsearch/common/common.go | 9 + .../elasticsearch/common/indexing_job.go | 3 +- .../elasticsearch/common/indexing_job_test.go | 5 +- .../elasticsearch/elasticsearch/bulk.go | 152 ++---- .../elasticsearch/bulk_client_data.go | 184 ++++++++ .../elasticsearch/bulk_client_data_test.go | 355 ++++++++++++++ .../elasticsearch/bulk_client_req.go | 164 +++++++ .../elasticsearch/bulk_client_req_test.go | 263 +++++++++++ .../elasticsearch/elasticsearch/bulk_test.go | 78 +++- .../elasticsearch/elasticsearch.go | 85 +++- .../elasticsearch/elasticsearch_test.go | 73 +++ .../elasticsearch/elasticsearch/sync_bulk.go | 94 ++++ .../elasticsearch/opensearch/bulk.go | 67 ++- .../elasticsearch/opensearch/bulk_test.go | 442 +++++++++++++++++- .../elasticsearch/opensearch/opensearch.go | 74 ++- .../opensearch/opensearch_test.go | 73 +++ server/i18n/en.json | 8 + .../services/searchengine/interface.go | 1 + .../mocks/SearchEngineInterface.go | 20 + server/public/utils/page.go | 43 ++ server/public/utils/page_test.go | 69 +++ 32 files changed, 2279 insertions(+), 223 deletions(-) create mode 100644 server/channels/store/context.go create mode 100644 server/channels/store/context_test.go create mode 100644 server/enterprise/elasticsearch/elasticsearch/bulk_client_data.go create mode 100644 server/enterprise/elasticsearch/elasticsearch/bulk_client_data_test.go create mode 100644 server/enterprise/elasticsearch/elasticsearch/bulk_client_req.go create mode 100644 server/enterprise/elasticsearch/elasticsearch/bulk_client_req_test.go create mode 100644 server/enterprise/elasticsearch/elasticsearch/sync_bulk.go create mode 100644 server/public/utils/page.go create mode 100644 server/public/utils/page_test.go diff --git a/server/channels/store/context.go b/server/channels/store/context.go new file mode 100644 index 00000000000..7b68d01d3f0 --- /dev/null +++ b/server/channels/store/context.go @@ -0,0 +1,46 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package store + +import ( + "context" + + "github.com/mattermost/mattermost/server/public/shared/request" +) + +// storeContextKey is the base type for all context keys for the store. +type storeContextKey string + +// contextValue is a type to hold some pre-determined context values. +type contextValue string + +// Different possible values of contextValue. +const ( + useMaster contextValue = "useMaster" +) + +// WithMaster adds the context value that master DB should be selected for this request. +// +// Deprecated: This method is deprecated and there's ongoing change to use `request.CTX` across +// instead of `context.Context`. Please use `RequestContextWithMaster` instead. +func WithMaster(ctx context.Context) context.Context { + return context.WithValue(ctx, storeContextKey(useMaster), true) +} + +// RequestContextWithMaster adds the context value that master DB should be selected for this request. +func RequestContextWithMaster(c request.CTX) request.CTX { + ctx := WithMaster(c.Context()) + c = c.WithContext(ctx) + return c +} + +// HasMaster is a helper function to check whether master DB should be selected or not. +func HasMaster(ctx context.Context) bool { + if v := ctx.Value(storeContextKey(useMaster)); v != nil { + if res, ok := v.(bool); ok && res { + return true + } + } + return false +} diff --git a/server/channels/store/context_test.go b/server/channels/store/context_test.go new file mode 100644 index 00000000000..44c4e17257c --- /dev/null +++ b/server/channels/store/context_test.go @@ -0,0 +1,55 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package store + +import ( + "context" + "testing" + + "github.com/mattermost/mattermost/server/public/shared/request" + "github.com/stretchr/testify/assert" +) + +func TestContextMaster(t *testing.T) { + ctx := context.Background() + + m := WithMaster(ctx) + assert.True(t, HasMaster(m)) +} + +func TestRequestContextWithMaster(t *testing.T) { + t.Run("set and get", func(t *testing.T) { + var rctx request.CTX = request.TestContext(t) + + rctx = RequestContextWithMaster(rctx) + assert.True(t, HasMaster(rctx.Context())) + }) + + t.Run("values get copied from original context", func(t *testing.T) { + var rctx request.CTX = request.TestContext(t) + rctx = RequestContextWithMaster(rctx) + rctxCopy := rctx + + assert.True(t, HasMaster(rctx.Context())) + assert.True(t, HasMaster(rctxCopy.Context())) + }) + + t.Run("directly assigning does not cause the copy to alter the original context", func(t *testing.T) { + var rctx request.CTX = request.TestContext(t) + rctxCopy := rctx + rctxCopy = RequestContextWithMaster(rctxCopy) + + assert.False(t, HasMaster(rctx.Context())) + assert.True(t, HasMaster(rctxCopy.Context())) + }) + + t.Run("directly assigning does not cause the original context to alter the copy", func(t *testing.T) { + var rctx request.CTX = request.TestContext(t) + rctxCopy := rctx + rctx = RequestContextWithMaster(rctx) + + assert.True(t, HasMaster(rctx.Context())) + assert.False(t, HasMaster(rctxCopy.Context())) + }) +} diff --git a/server/channels/store/retrylayer/retrylayer.go b/server/channels/store/retrylayer/retrylayer.go index bb4f76fe2fe..464f1867658 100644 --- a/server/channels/store/retrylayer/retrylayer.go +++ b/server/channels/store/retrylayer/retrylayer.go @@ -2472,11 +2472,11 @@ func (s *RetryLayerChannelStore) GetTeamForChannel(channelID string) (*model.Tea } -func (s *RetryLayerChannelStore) GetTeamMembersForChannel(channelID string) ([]string, error) { +func (s *RetryLayerChannelStore) GetTeamMembersForChannel(rctx request.CTX, channelID string) ([]string, error) { tries := 0 for { - result, err := s.ChannelStore.GetTeamMembersForChannel(channelID) + result, err := s.ChannelStore.GetTeamMembersForChannel(rctx, channelID) if err == nil { return result, nil } diff --git a/server/channels/store/searchlayer/channel_layer.go b/server/channels/store/searchlayer/channel_layer.go index 60db3a8d323..b06fab1e644 100644 --- a/server/channels/store/searchlayer/channel_layer.go +++ b/server/channels/store/searchlayer/channel_layer.go @@ -47,7 +47,7 @@ func (c *SearchChannelStore) indexChannel(rctx request.CTX, channel *model.Chann } } - teamMemberIDs, err = c.GetTeamMembersForChannel(channel.Id) + teamMemberIDs, err = c.GetTeamMembersForChannel(rctx, channel.Id) if err != nil { rctx.Logger().Warn("Encountered error while indexing channel", mlog.String("channel_id", channel.Id), mlog.Err(err)) return @@ -66,6 +66,30 @@ func (c *SearchChannelStore) indexChannel(rctx request.CTX, channel *model.Chann } } +func (c *SearchChannelStore) bulkIndexChannels(rctx request.CTX, channels []*model.Channel, teamMemberIDs []string) { + // Util function to get userIDs, only for private channels + getUserIDsForPrivateChannel := func(channel *model.Channel) ([]string, error) { + if channel.Type != model.ChannelTypePrivate { + return []string{}, nil + } + return c.GetAllChannelMemberIdsByChannelId(channel.Id) + } + + for _, engine := range c.rootStore.searchEngine.GetActiveEngines() { + if !engine.IsIndexingEnabled() { + continue + } + + runIndexFn(rctx, engine, func(engineCopy searchengine.SearchEngineInterface) { + appErr := engineCopy.SyncBulkIndexChannels(rctx, channels, getUserIDsForPrivateChannel, teamMemberIDs) + if appErr != nil { + rctx.Logger().Error("Failed to synchronously bulk-index channels.", mlog.String("search_engine", engineCopy.GetName()), mlog.Err(appErr)) + return + } + }) + } +} + func (c *SearchChannelStore) Save(rctx request.CTX, channel *model.Channel, maxChannels int64, channelOptions ...model.ChannelOption) (*model.Channel, error) { newChannel, err := c.ChannelStore.Save(rctx, channel, maxChannels, channelOptions...) if err == nil { diff --git a/server/channels/store/searchlayer/layer.go b/server/channels/store/searchlayer/layer.go index 4ed2a5d6cfe..f30aa84f09e 100644 --- a/server/channels/store/searchlayer/layer.go +++ b/server/channels/store/searchlayer/layer.go @@ -9,6 +9,7 @@ import ( "github.com/mattermost/mattermost/server/public/model" "github.com/mattermost/mattermost/server/public/shared/mlog" "github.com/mattermost/mattermost/server/public/shared/request" + "github.com/mattermost/mattermost/server/public/utils" "github.com/mattermost/mattermost/server/v8/channels/store" "github.com/mattermost/mattermost/server/v8/platform/services/searchengine" ) @@ -111,6 +112,35 @@ func (s *SearchStore) indexUser(rctx request.CTX, user *model.User) { } } +func (s *SearchStore) indexChannelsForTeam(rctx request.CTX, teamID string) { + const perPage = 100 + var ( + channels []*model.Channel + ) + + channels, err := utils.Pager(func(page int) ([]*model.Channel, error) { + return s.channel.GetPublicChannelsForTeam(teamID, page*perPage, perPage) + }, perPage) + if err != nil { + rctx.Logger().Warn("Encountered error while retrieving public channels for indexing", mlog.String("team_id", teamID), mlog.Err(err)) + return + } + + if len(channels) == 0 { + return + } + + // Use master context to avoid replica lag issues when reading team members + masterRctx := store.RequestContextWithMaster(rctx) + teamMemberIDs, err := s.channel.GetTeamMembersForChannel(masterRctx, channels[0].Id) + if err != nil { + rctx.Logger().Warn("Encountered error while retrieving team members for channel", mlog.String("channel_id", channels[0].Id), mlog.Err(err)) + return + } + + s.channel.bulkIndexChannels(rctx, channels, teamMemberIDs) +} + // Runs an indexing function synchronously or asynchronously depending on the engine func runIndexFn(rctx request.CTX, engine searchengine.SearchEngineInterface, indexFn func(searchengine.SearchEngineInterface)) { if engine.IsIndexingSync() { diff --git a/server/channels/store/searchlayer/team_layer.go b/server/channels/store/searchlayer/team_layer.go index 46601e1dbb7..d0a2b9f1c50 100644 --- a/server/channels/store/searchlayer/team_layer.go +++ b/server/channels/store/searchlayer/team_layer.go @@ -17,7 +17,11 @@ type SearchTeamStore struct { func (s SearchTeamStore) SaveMember(rctx request.CTX, teamMember *model.TeamMember, maxUsersPerTeam int) (*model.TeamMember, error) { member, err := s.TeamStore.SaveMember(rctx, teamMember, maxUsersPerTeam) if err == nil { - s.rootStore.indexUserFromID(rctx, member.UserId) + // Nothing to do if search engine is not active + if s.rootStore.searchEngine.ActiveEngine() != "database" && s.rootStore.searchEngine.ActiveEngine() != "none" { + s.rootStore.indexUserFromID(rctx, member.UserId) + s.rootStore.indexChannelsForTeam(rctx, member.TeamId) + } } return member, err } @@ -33,15 +37,31 @@ func (s SearchTeamStore) UpdateMember(rctx request.CTX, teamMember *model.TeamMe func (s SearchTeamStore) RemoveMember(rctx request.CTX, teamId string, userId string) error { err := s.TeamStore.RemoveMember(rctx, teamId, userId) if err == nil { - s.rootStore.indexUserFromID(rctx, userId) + // Nothing to do if search engine is not active + if s.rootStore.searchEngine.ActiveEngine() != "database" && s.rootStore.searchEngine.ActiveEngine() != "none" { + s.rootStore.indexUserFromID(rctx, userId) + s.rootStore.indexChannelsForTeam(rctx, teamId) + } } return err } func (s SearchTeamStore) RemoveAllMembersByUser(rctx request.CTX, userId string) error { + if s.rootStore.searchEngine.ActiveEngine() != "database" && s.rootStore.searchEngine.ActiveEngine() != "none" { + memberships, err := s.TeamStore.GetTeamsForUser(rctx, userId, "", true) + if err != nil { + return err + } + for _, membership := range memberships { + s.rootStore.indexChannelsForTeam(rctx, membership.TeamId) + } + } + err := s.TeamStore.RemoveAllMembersByUser(rctx, userId) if err == nil { - s.rootStore.indexUserFromID(rctx, userId) + if s.rootStore.searchEngine.ActiveEngine() != "database" && s.rootStore.searchEngine.ActiveEngine() != "none" { + s.rootStore.indexUserFromID(rctx, userId) + } } return err } diff --git a/server/channels/store/sqlstore/channel_store.go b/server/channels/store/sqlstore/channel_store.go index 1c252f2af21..59d069aa5c4 100644 --- a/server/channels/store/sqlstore/channel_store.go +++ b/server/channels/store/sqlstore/channel_store.go @@ -3046,9 +3046,9 @@ func (s SqlChannelStore) GetMembersForUserWithCursorPagination(userId string, pe return dbMembers.ToModel(), nil } -func (s SqlChannelStore) GetTeamMembersForChannel(channelID string) ([]string, error) { +func (s SqlChannelStore) GetTeamMembersForChannel(rctx request.CTX, channelID string) ([]string, error) { teamMemberIDs := []string{} - if err := s.GetReplica().Select(&teamMemberIDs, `SELECT tm.UserId + if err := s.DBXFromContext(rctx.Context()).Select(&teamMemberIDs, `SELECT tm.UserId FROM Channels c, Teams t, TeamMembers tm WHERE c.TeamId=t.Id diff --git a/server/channels/store/sqlstore/context.go b/server/channels/store/sqlstore/context.go index d35694202ae..7509804e44b 100644 --- a/server/channels/store/sqlstore/context.go +++ b/server/channels/store/sqlstore/context.go @@ -7,17 +7,7 @@ import ( "context" "github.com/mattermost/mattermost/server/public/shared/request" -) - -// storeContextKey is the base type for all context keys for the store. -type storeContextKey string - -// contextValue is a type to hold some pre-determined context values. -type contextValue string - -// Different possible values of contextValue. -const ( - useMaster contextValue = "useMaster" + "github.com/mattermost/mattermost/server/v8/channels/store" ) // WithMaster adds the context value that master DB should be selected for this request. @@ -25,24 +15,17 @@ const ( // Deprecated: This method is deprecated and there's ongoing change to use `request.CTX` across // instead of `context.Context`. Please use `RequestContextWithMaster` instead. func WithMaster(ctx context.Context) context.Context { - return context.WithValue(ctx, storeContextKey(useMaster), true) + return store.WithMaster(ctx) } // RequestContextWithMaster adds the context value that master DB should be selected for this request. func RequestContextWithMaster(c request.CTX) request.CTX { - ctx := WithMaster(c.Context()) - c = c.WithContext(ctx) - return c + return store.RequestContextWithMaster(c) } // HasMaster is a helper function to check whether master DB should be selected or not. func HasMaster(ctx context.Context) bool { - if v := ctx.Value(storeContextKey(useMaster)); v != nil { - if res, ok := v.(bool); ok && res { - return true - } - } - return false + return store.HasMaster(ctx) } // DBXFromContext is a helper utility that returns the sqlx DB handle from a given context. diff --git a/server/channels/store/store.go b/server/channels/store/store.go index c57c4ae5d8e..7290f45c1b2 100644 --- a/server/channels/store/store.go +++ b/server/channels/store/store.go @@ -266,7 +266,7 @@ type ChannelStore interface { AnalyticsDeletedTypeCount(teamID string, channelType model.ChannelType) (int64, error) AnalyticsCountAll(teamID string) (map[model.ChannelType]int64, error) GetMembersForUser(teamID string, userID string) (model.ChannelMembers, error) - GetTeamMembersForChannel(channelID string) ([]string, error) + GetTeamMembersForChannel(rctx request.CTX, channelID string) ([]string, error) GetMembersForUserWithPagination(userID string, page, perPage int) (model.ChannelMembersWithTeamData, error) GetMembersForUserWithCursorPagination(userId string, perPage int, fromChanneID string) (model.ChannelMembersWithTeamData, error) Autocomplete(rctx request.CTX, userID, term string, includeDeleted, isGuest bool) (model.ChannelListWithTeamData, error) diff --git a/server/channels/store/storetest/mocks/ChannelStore.go b/server/channels/store/storetest/mocks/ChannelStore.go index f271d219fcd..7f970734b6c 100644 --- a/server/channels/store/storetest/mocks/ChannelStore.go +++ b/server/channels/store/storetest/mocks/ChannelStore.go @@ -2148,9 +2148,9 @@ func (_m *ChannelStore) GetTeamForChannel(channelID string) (*model.Team, error) return r0, r1 } -// GetTeamMembersForChannel provides a mock function with given fields: channelID -func (_m *ChannelStore) GetTeamMembersForChannel(channelID string) ([]string, error) { - ret := _m.Called(channelID) +// GetTeamMembersForChannel provides a mock function with given fields: rctx, channelID +func (_m *ChannelStore) GetTeamMembersForChannel(rctx request.CTX, channelID string) ([]string, error) { + ret := _m.Called(rctx, channelID) if len(ret) == 0 { panic("no return value specified for GetTeamMembersForChannel") @@ -2158,19 +2158,19 @@ func (_m *ChannelStore) GetTeamMembersForChannel(channelID string) ([]string, er var r0 []string var r1 error - if rf, ok := ret.Get(0).(func(string) ([]string, error)); ok { - return rf(channelID) + if rf, ok := ret.Get(0).(func(request.CTX, string) ([]string, error)); ok { + return rf(rctx, channelID) } - if rf, ok := ret.Get(0).(func(string) []string); ok { - r0 = rf(channelID) + if rf, ok := ret.Get(0).(func(request.CTX, string) []string); ok { + r0 = rf(rctx, channelID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]string) } } - if rf, ok := ret.Get(1).(func(string) error); ok { - r1 = rf(channelID) + if rf, ok := ret.Get(1).(func(request.CTX, string) error); ok { + r1 = rf(rctx, channelID) } else { r1 = ret.Error(1) } diff --git a/server/channels/store/timerlayer/timerlayer.go b/server/channels/store/timerlayer/timerlayer.go index 5130bf3ce70..6887d45059e 100644 --- a/server/channels/store/timerlayer/timerlayer.go +++ b/server/channels/store/timerlayer/timerlayer.go @@ -2033,10 +2033,10 @@ func (s *TimerLayerChannelStore) GetTeamForChannel(channelID string) (*model.Tea return result, err } -func (s *TimerLayerChannelStore) GetTeamMembersForChannel(channelID string) ([]string, error) { +func (s *TimerLayerChannelStore) GetTeamMembersForChannel(rctx request.CTX, channelID string) ([]string, error) { start := time.Now() - result, err := s.ChannelStore.GetTeamMembersForChannel(channelID) + result, err := s.ChannelStore.GetTeamMembersForChannel(rctx, channelID) elapsed := float64(time.Since(start)) / float64(time.Second) if s.Root.Metrics != nil { diff --git a/server/enterprise/elasticsearch/common/common.go b/server/enterprise/elasticsearch/common/common.go index 4dcdcbf29c0..3070f212739 100644 --- a/server/enterprise/elasticsearch/common/common.go +++ b/server/enterprise/elasticsearch/common/common.go @@ -34,8 +34,17 @@ const ( // At the moment, this number is hardcoded. If needed, we can expose // this to the config. BulkFlushInterval = 5 * time.Second + + // Size of the largest request to be done, in bytes + BulkFlushBytes = 10 * 1024 * 1024 // 10 MiB ) +type BulkSettings struct { + FlushBytes int + FlushInterval time.Duration + FlushNumReqs int +} + var ( urlRe = regexp.MustCompile(URLRegexpRE) markdownLinkRe = regexp.MustCompile(URLMarkdownLinkRE) diff --git a/server/enterprise/elasticsearch/common/indexing_job.go b/server/enterprise/elasticsearch/common/indexing_job.go index 5482ccd81ec..983521268ce 100644 --- a/server/enterprise/elasticsearch/common/indexing_job.go +++ b/server/enterprise/elasticsearch/common/indexing_job.go @@ -569,7 +569,8 @@ func BulkIndexChannels(config *model.Config, } } - teamMemberIDs, err := store.Channel().GetTeamMembersForChannel(channel.Id) + rctx := request.EmptyContext(logger) + teamMemberIDs, err := store.Channel().GetTeamMembersForChannel(rctx, channel.Id) if err != nil { return nil, model.NewAppError("IndexerWorker.BulkIndexChannels", "ent.elasticsearch.getAllTeamMembers.error", nil, "", http.StatusInternalServerError).Wrap(err) } diff --git a/server/enterprise/elasticsearch/common/indexing_job_test.go b/server/enterprise/elasticsearch/common/indexing_job_test.go index 24c302f1914..a8c9597459e 100644 --- a/server/enterprise/elasticsearch/common/indexing_job_test.go +++ b/server/enterprise/elasticsearch/common/indexing_job_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/mattermost/mattermost/server/public/model" @@ -37,8 +38,8 @@ func TestBulkIndexChannelsWithDeletedChannels(t *testing.T) { // Since these are open channels, GetAllChannelMemberIdsByChannelId won't be called // But GetTeamMembersForChannel will be called for both channels - mockChannelStore.On("GetTeamMembersForChannel", "ch1").Return([]string{"team1"}, nil) - mockChannelStore.On("GetTeamMembersForChannel", "ch2").Return([]string{"team1"}, nil) + mockChannelStore.On("GetTeamMembersForChannel", mock.AnythingOfType("*request.Context"), "ch1").Return([]string{"team1"}, nil) + mockChannelStore.On("GetTeamMembersForChannel", mock.AnythingOfType("*request.Context"), "ch2").Return([]string{"team1"}, nil) // Track which channels were actually indexed indexedChannels := make(map[string]bool) diff --git a/server/enterprise/elasticsearch/elasticsearch/bulk.go b/server/enterprise/elasticsearch/elasticsearch/bulk.go index 9488e68e8f1..342dd15d291 100644 --- a/server/enterprise/elasticsearch/elasticsearch/bulk.go +++ b/server/enterprise/elasticsearch/elasticsearch/bulk.go @@ -4,130 +4,58 @@ package elasticsearch import ( - "context" - "sync" + "fmt" "time" elastic "github.com/elastic/go-elasticsearch/v8" - "github.com/elastic/go-elasticsearch/v8/typedapi/core/bulk" - "github.com/elastic/go-elasticsearch/v8/typedapi/types" - "github.com/mattermost/mattermost/server/public/model" + esTypes "github.com/elastic/go-elasticsearch/v8/typedapi/types" "github.com/mattermost/mattermost/server/public/shared/mlog" "github.com/mattermost/mattermost/server/v8/enterprise/elasticsearch/common" ) -type Bulk struct { - mut sync.Mutex - - logger mlog.LoggerIFace - client *elastic.TypedClient - bulkClient *bulk.Bulk - settings model.ElasticsearchSettings - - quitFlusher chan struct{} - quitFlusherWg sync.WaitGroup - - pendingRequests int -} - -func NewBulk(settings model.ElasticsearchSettings, +type BulkClient interface { + IndexOp(op esTypes.IndexOperation, doc any) error + DeleteOp(op esTypes.DeleteOperation) error + Flush() error + Stop() error +} + +// NewBulk returns a BulkClient, with the specific implementation depending on +// the specified thresholds in bulkSettings. +// NewBulk will return an error if bulkSettings.FlushNumReqs and +// bulkSettings.FlushBytes are both non-zero: the support of these thresholds +// by the implementations of BulkClient is mutually exclusive. +func NewBulk(bulkSettings common.BulkSettings, + client *elastic.TypedClient, + reqTimeout time.Duration, logger mlog.LoggerIFace, - client *elastic.TypedClient) *Bulk { - b := &Bulk{ - settings: settings, - logger: logger, - client: client, - bulkClient: client.Bulk(), - quitFlusher: make(chan struct{}), - } - - b.quitFlusherWg.Add(1) - go b.periodicFlusher() - - return b -} - -// IndexOp is a helper function to add an IndexOperation to the current bulk request. -// doc argument can be a []byte, json.RawMessage or a struct. -func (r *Bulk) IndexOp(op types.IndexOperation, doc any) error { - r.mut.Lock() - defer r.mut.Unlock() - - if err := r.bulkClient.IndexOp(op, doc); err != nil { - return err - } - - return r.flushIfNecessary() -} - -// DeleteOp is a helper function to add a DeleteOperation to the current bulk request. -func (r *Bulk) DeleteOp(op types.DeleteOperation) error { - r.mut.Lock() - defer r.mut.Unlock() - - if err := r.bulkClient.DeleteOp(op); err != nil { - return err - } - - return r.flushIfNecessary() -} - -// flushIfNecessary flushes the pending buffer if needed. -// It MUST be called with an already acquired mutex. -func (r *Bulk) flushIfNecessary() error { - r.pendingRequests++ - - if r.pendingRequests > *r.settings.LiveIndexingBatchSize { - return r._flush() +) (BulkClient, error) { + if bulkSettings.FlushBytes == 0 && + bulkSettings.FlushInterval == 0 && + bulkSettings.FlushNumReqs == 0 { + return nil, fmt.Errorf("at least one of FlushBytes, FlushInterval or FlushNumReqs should be non-zero") } - - return nil -} - -func (r *Bulk) Stop() error { - r.mut.Lock() - defer r.mut.Unlock() - r.logger.Info("Stopping Bulk processor") - - if r.pendingRequests > 0 { - return r._flush() + if bulkSettings.FlushBytes > 0 && bulkSettings.FlushNumReqs > 0 { + return nil, fmt.Errorf( + "one of bulkSettings.FlushBytes (set to %d) or bulkSettings.FlushNumReqs (set to %d) should be zero", + bulkSettings.FlushBytes, + bulkSettings.FlushNumReqs, + ) } - close(r.quitFlusher) - r.quitFlusherWg.Wait() - - return nil -} - -func (r *Bulk) periodicFlusher() { - defer r.quitFlusherWg.Done() - - for { - select { - case <-time.After(common.BulkFlushInterval): - r.mut.Lock() - if r.pendingRequests > 0 { - if err := r._flush(); err != nil { - r.logger.Warn("Error flushing live indexing buffer", mlog.Err(err)) - } - } - r.mut.Unlock() - case <-r.quitFlusher: - return + var bulkClient BulkClient + var err error + if bulkSettings.FlushBytes > 0 { + bulkClient, err = NewDataBulkClient(bulkSettings, client, reqTimeout, logger) + if err != nil { + return nil, err + } + } else { + bulkClient, err = NewReqBulkClient(bulkSettings, client, reqTimeout, logger) + if err != nil { + return nil, err } } -} - -// _flush MUST be called with an acquired lock. -func (r *Bulk) _flush() error { - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*r.settings.RequestTimeoutSeconds)*time.Second) - defer cancel() - - _, err := r.bulkClient.Do(ctx) - if err != nil { - return err - } - r.pendingRequests = 0 - return nil + return bulkClient, nil } diff --git a/server/enterprise/elasticsearch/elasticsearch/bulk_client_data.go b/server/enterprise/elasticsearch/elasticsearch/bulk_client_data.go new file mode 100644 index 00000000000..2da7664335c --- /dev/null +++ b/server/enterprise/elasticsearch/elasticsearch/bulk_client_data.go @@ -0,0 +1,184 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.enterprise for license information. + +package elasticsearch + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "sync" + "time" + + elastic "github.com/elastic/go-elasticsearch/v8" + "github.com/elastic/go-elasticsearch/v8/esutil" + esTypes "github.com/elastic/go-elasticsearch/v8/typedapi/types" + "github.com/mattermost/mattermost/server/public/shared/mlog" + "github.com/mattermost/mattermost/server/v8/enterprise/elasticsearch/common" +) + +// DataBulkClient is an Elasticsearch bulk client based on the +// go-elasticsearch/v8/esutil.BulkIndexer type. +// It supports time- and size-based thresholds, but not a threshold on number +// of requests. +type DataBulkClient struct { + mut sync.Mutex + + indexer esutil.BulkIndexer + client *elastic.TypedClient + bulkSettings common.BulkSettings + reqTimeout time.Duration + logger mlog.LoggerIFace +} + +func NewDataBulkClient(bulkSettings common.BulkSettings, + client *elastic.TypedClient, + reqTimeout time.Duration, + logger mlog.LoggerIFace, +) (*DataBulkClient, error) { + if bulkSettings.FlushNumReqs > 0 { + return nil, fmt.Errorf("DataBulkClient does not support a threshold on number of requests") + } + + indexer, err := newIndexer(client, bulkSettings, logger) + if err != nil { + return nil, err + } + + return &DataBulkClient{ + indexer: indexer, + client: client, + bulkSettings: bulkSettings, + reqTimeout: reqTimeout, + logger: logger, + }, nil +} + +func newIndexer(client *elastic.TypedClient, bulkSettings common.BulkSettings, logger mlog.LoggerIFace) (esutil.BulkIndexer, error) { + // A zeroed FlushInterval means that there should be no time-based flush, + // but esutil.BulkIndexer defaults to 30 seconds if the interval is zero, + // so we pick a large enough interval + interval := bulkSettings.FlushInterval + if interval == 0 { + interval = 1 * time.Hour + } + + return esutil.NewBulkIndexer(esutil.BulkIndexerConfig{ + FlushBytes: bulkSettings.FlushBytes, + FlushInterval: interval, + Client: client, + OnError: func(ctx context.Context, err error) { + logger.Error("indexer error", mlog.Err(err)) + }, + OnFlushStart: func(ctx context.Context) context.Context { + logger.Debug("elasticsearch bulk indexer flush started") + return ctx + }, + OnFlushEnd: func(context.Context) { + logger.Debug("elasticsearch bulk indexer flush ended") + }, + }) +} + +func (b *DataBulkClient) onSuccess(_ context.Context, item esutil.BulkIndexerItem, _ esutil.BulkIndexerResponseItem) { + b.logger.Info("successfully added new bulk operation", + mlog.String("action", item.Action), + mlog.String("index", item.Index), + mlog.String("document_id", item.DocumentID), + ) +} + +func (b *DataBulkClient) onFailure(_ context.Context, item esutil.BulkIndexerItem, _ esutil.BulkIndexerResponseItem, err error) { + b.logger.Info("failed to add new bulk operation", + mlog.String("action", item.Action), + mlog.String("index", item.Index), + mlog.String("document_id", item.DocumentID), + mlog.Err(err), + ) +} + +func (b *DataBulkClient) IndexOp(op esTypes.IndexOperation, doc any) error { + b.mut.Lock() + defer b.mut.Unlock() + + var bodyReader io.ReadSeeker + switch v := doc.(type) { + case []byte: + bodyReader = bytes.NewReader(v) + case json.RawMessage: + bodyReader = bytes.NewReader(v) + default: + body, err := json.Marshal(doc) + if err != nil { + return err + } + bodyReader = bytes.NewReader(body) + } + + ctx, cancel := context.WithTimeout(context.Background(), b.reqTimeout) + defer cancel() + + return b.indexer.Add(ctx, esutil.BulkIndexerItem{ + Index: *op.Index_, + Action: "index", + DocumentID: *op.Id_, + Body: bodyReader, + OnSuccess: b.onSuccess, + OnFailure: b.onFailure, + }) +} +func (b *DataBulkClient) DeleteOp(op esTypes.DeleteOperation) error { + b.mut.Lock() + defer b.mut.Unlock() + + ctx, cancel := context.WithTimeout(context.Background(), b.reqTimeout) + defer cancel() + + return b.indexer.Add(ctx, esutil.BulkIndexerItem{ + Index: *op.Index_, + Action: "delete", + DocumentID: *op.Id_, + Body: nil, + OnSuccess: b.onSuccess, + OnFailure: b.onFailure, + }) +} + +func (b *DataBulkClient) _stop() error { + ctx, cancel := context.WithTimeout(context.Background(), b.reqTimeout) + defer cancel() + + return b.indexer.Close(ctx) +} + +func (b *DataBulkClient) Flush() error { + b.mut.Lock() + defer b.mut.Unlock() + + // The esutil.BulkIndexer cannot be manually flushed, but it can be closed, + // which does flush all the contents. + if err := b._stop(); err != nil { + return fmt.Errorf("failed to close the BulkIndexer: %w", err) + } + + // But calling Close essentially kills all the running processes, so we have + // to create a new one in order to restart it + indexer, err := newIndexer(b.client, b.bulkSettings, b.logger) + if err != nil { + return fmt.Errorf("failed to restart the BulkIndexer: %w", err) + } + b.indexer = indexer + + return nil +} + +func (b *DataBulkClient) Stop() error { + b.mut.Lock() + defer b.mut.Unlock() + + b.logger.Info("Stopping Bulk processor") + + return b._stop() +} diff --git a/server/enterprise/elasticsearch/elasticsearch/bulk_client_data_test.go b/server/enterprise/elasticsearch/elasticsearch/bulk_client_data_test.go new file mode 100644 index 00000000000..1ad6af56a40 --- /dev/null +++ b/server/enterprise/elasticsearch/elasticsearch/bulk_client_data_test.go @@ -0,0 +1,355 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.enterprise for license information. + +package elasticsearch + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/elastic/go-elasticsearch/v8/esutil" + "github.com/elastic/go-elasticsearch/v8/typedapi/types" + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/v8/channels/api4" + "github.com/mattermost/mattermost/server/v8/enterprise/elasticsearch/common" + "github.com/stretchr/testify/require" +) + +// setupDataBulkClient creates a test data bulk client with common setup +func setupDataBulkClient(t *testing.T, flushBytes int, flushInterval time.Duration) (*DataBulkClient, *api4.TestHelper) { + th := api4.SetupEnterprise(t) + + client := createTestClient(t, th.Context, th.App.Config(), th.App.FileBackend()) + bulkClient, err := NewDataBulkClient( + common.BulkSettings{ + FlushBytes: flushBytes, + FlushInterval: flushInterval, + FlushNumReqs: 0, // DataBulkClient doesn't support FlushNumReqs + }, + client, + time.Duration(*th.App.Config().ElasticsearchSettings.RequestTimeoutSeconds)*time.Second, + th.Server.Platform().Log()) + require.NoError(t, err) + + return bulkClient, th +} + +func flushAndGetStats(t *testing.T, b *DataBulkClient) esutil.BulkIndexerStats { + t.Helper() + + // Close the indexer to flush + err := b.indexer.Close(context.Background()) + require.NoError(t, err) + + // Get the stats + stats := b.indexer.Stats() + + // Restart the indexer + newIdxr, err := newIndexer(b.client, b.bulkSettings, b.logger) + require.NoError(t, err) + b.indexer = newIdxr + + return stats +} + +func TestDataIndexOp(t *testing.T) { + t.Run("single index operation", func(t *testing.T) { + bulkClient, th := setupDataBulkClient(t, 1024*1024, 0) // 1MB flush threshold + defer th.TearDown() + t.Cleanup(func() { + err := bulkClient.Stop() + require.NoError(t, err) + }) + + post := createTestPost(t, "test message") + + err := bulkClient.IndexOp(types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + + // Check that the request got added + stats := bulkClient.indexer.Stats() + require.Equal(t, 1, int(stats.NumAdded)) + require.Equal(t, 0, int(stats.NumIndexed)) + + // Flush, and check that the document was indexed + stats = flushAndGetStats(t, bulkClient) + require.Equal(t, 1, int(stats.NumIndexed)) + }) + + t.Run("multiple index operations", func(t *testing.T) { + bulkClient, th := setupDataBulkClient(t, 1024*1024, 0) // 1MB flush threshold + defer th.TearDown() + t.Cleanup(func() { + err := bulkClient.Stop() + require.NoError(t, err) + }) + + for range 5 { + post := createTestPost(t, "test message") + err := bulkClient.IndexOp(types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + } + + // Check that the requests got added + stats := bulkClient.indexer.Stats() + require.Equal(t, 5, int(stats.NumAdded)) + + // Flush, and check that the documents were indexed + stats = flushAndGetStats(t, bulkClient) + require.Equal(t, 5, int(stats.NumIndexed)) + }) + + t.Run("index operation with json.RawMessage", func(t *testing.T) { + bulkClient, th := setupDataBulkClient(t, 1024*1024, 0) // 1MB flush threshold + defer th.TearDown() + t.Cleanup(func() { + err := bulkClient.Stop() + require.NoError(t, err) + }) + + docId := model.NewId() + jsonData := []byte(`{"message": "test raw message"}`) + + err := bulkClient.IndexOp(types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(docId), + }, jsonData) + require.NoError(t, err) + + // Check that the request got added + stats := bulkClient.indexer.Stats() + require.Equal(t, 1, int(stats.NumAdded)) + + // Flush, and check that the document was indexed + stats = flushAndGetStats(t, bulkClient) + require.Equal(t, 1, int(stats.NumIndexed)) + }) + + t.Run("index operation with byte slice", func(t *testing.T) { + bulkClient, th := setupDataBulkClient(t, 1024*1024, 0) // 1MB flush threshold + defer th.TearDown() + t.Cleanup(func() { + err := bulkClient.Stop() + require.NoError(t, err) + }) + + docId := model.NewId() + data := []byte(`{"message": "test byte slice"}`) + + err := bulkClient.IndexOp(types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(docId), + }, data) + require.NoError(t, err) + + // Check that the request got added + stats := bulkClient.indexer.Stats() + require.Equal(t, 1, int(stats.NumAdded)) + + // Flush, and check that the document was indexed + stats = flushAndGetStats(t, bulkClient) + require.Equal(t, 1, int(stats.NumIndexed)) + }) +} + +func TestDataDeleteOp(t *testing.T) { + t.Run("single delete operation", func(t *testing.T) { + bulkClient, th := setupDataBulkClient(t, 1024*1024, 0) // 1MB flush threshold + defer th.TearDown() + t.Cleanup(func() { + err := bulkClient.Stop() + require.NoError(t, err) + }) + + // Index a new post and flush + post := createTestPost(t, "test message") + err := bulkClient.IndexOp(types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + + require.NoError(t, bulkClient.Flush()) + + err = bulkClient.DeleteOp(types.DeleteOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }) + require.NoError(t, err) + + // Check that the request got added + stats := bulkClient.indexer.Stats() + require.Equal(t, 1, int(stats.NumAdded)) + require.Equal(t, 0, int(stats.NumDeleted)) + + // Flush, and check that the document was deleted + stats = flushAndGetStats(t, bulkClient) + fmt.Println(stats) + require.Equal(t, 1, int(stats.NumDeleted)) + }) + + t.Run("multiple delete operations", func(t *testing.T) { + bulkClient, th := setupDataBulkClient(t, 1024*1024, 0) // 1MB flush threshold + defer th.TearDown() + t.Cleanup(func() { + err := bulkClient.Stop() + require.NoError(t, err) + }) + + posts := make([]string, 3) + + // Index three new posts and flush + for i := range 3 { + post := createTestPost(t, "test message") + err := bulkClient.IndexOp(types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + posts[i] = post.Id + } + require.NoError(t, bulkClient.Flush()) + + for _, id := range posts { + err := bulkClient.DeleteOp(types.DeleteOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(id), + }) + require.NoError(t, err) + } + + // Check that the requests got added + stats := bulkClient.indexer.Stats() + require.Equal(t, 3, int(stats.NumAdded)) + require.Equal(t, 0, int(stats.NumDeleted)) + + // Flush, and check that the documents were deleted + stats = flushAndGetStats(t, bulkClient) + fmt.Println(stats) + require.Equal(t, 3, int(stats.NumDeleted)) + }) +} + +func TestDataFlush(t *testing.T) { + t.Run("flush with pending operations", func(t *testing.T) { + bulkClient, th := setupDataBulkClient(t, 1024*1024, 0) // 1MB flush threshold + defer th.TearDown() + t.Cleanup(func() { + err := bulkClient.Stop() + require.NoError(t, err) + }) + + post := createTestPost(t, "test message") + + err := bulkClient.IndexOp(types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + + err = bulkClient.Flush() + require.NoError(t, err) + }) + + t.Run("flush with no pending operations", func(t *testing.T) { + bulkClient, th := setupDataBulkClient(t, 1024*1024, 0) // 1MB flush threshold + defer th.TearDown() + t.Cleanup(func() { + err := bulkClient.Stop() + require.NoError(t, err) + }) + + err := bulkClient.Flush() + require.NoError(t, err) + }) +} + +func TestDataStop(t *testing.T) { + t.Run("stop with pending operations", func(t *testing.T) { + bulkClient, th := setupDataBulkClient(t, 1024*1024, 0) // 1MB flush threshold + defer th.TearDown() + + post := createTestPost(t, "test message") + + err := bulkClient.IndexOp(types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + + err = bulkClient.Stop() + require.NoError(t, err) + }) + + t.Run("stop with no pending operations", func(t *testing.T) { + bulkClient, th := setupDataBulkClient(t, 1024*1024, 0) // 1MB flush threshold + defer th.TearDown() + + err := bulkClient.Stop() + require.NoError(t, err) + }) + + t.Run("stop with periodic flusher", func(t *testing.T) { + bulkClient, th := setupDataBulkClient(t, 1024*1024, 100*time.Millisecond) + defer th.TearDown() + + post := createTestPost(t, "test message") + + err := bulkClient.IndexOp(types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + + // Stop should flush pending operations and stop the periodic flusher + err = bulkClient.Stop() + require.NoError(t, err) + }) +} + +func TestDataNewDataBulkClient(t *testing.T) { + th := api4.SetupEnterprise(t) + defer th.TearDown() + + client := createTestClient(t, th.Context, th.App.Config(), th.App.FileBackend()) + + t.Run("valid configuration", func(t *testing.T) { + bulkClient, err := NewDataBulkClient( + common.BulkSettings{ + FlushBytes: 1024, + FlushInterval: 100 * time.Millisecond, + FlushNumReqs: 0, + }, + client, + time.Duration(*th.App.Config().ElasticsearchSettings.RequestTimeoutSeconds)*time.Second, + th.Server.Platform().Log()) + require.NoError(t, err) + require.NotNil(t, bulkClient) + + err = bulkClient.Stop() + require.NoError(t, err) + }) + + t.Run("invalid configuration with FlushNumReqs", func(t *testing.T) { + bulkClient, err := NewDataBulkClient( + common.BulkSettings{ + FlushBytes: 1024, + FlushInterval: 100 * time.Millisecond, + FlushNumReqs: 10, // This should cause an error + }, + client, + time.Duration(*th.App.Config().ElasticsearchSettings.RequestTimeoutSeconds)*time.Second, + th.Server.Platform().Log()) + require.Error(t, err) + require.Nil(t, bulkClient) + require.Contains(t, err.Error(), "DataBulkClient does not support a threshold on number of requests") + }) +} diff --git a/server/enterprise/elasticsearch/elasticsearch/bulk_client_req.go b/server/enterprise/elasticsearch/elasticsearch/bulk_client_req.go new file mode 100644 index 00000000000..80bf5df94fd --- /dev/null +++ b/server/enterprise/elasticsearch/elasticsearch/bulk_client_req.go @@ -0,0 +1,164 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.enterprise for license information. + +package elasticsearch + +import ( + "context" + "fmt" + "sync" + "time" + + elastic "github.com/elastic/go-elasticsearch/v8" + "github.com/elastic/go-elasticsearch/v8/typedapi/core/bulk" + "github.com/elastic/go-elasticsearch/v8/typedapi/types" + "github.com/mattermost/mattermost/server/public/shared/mlog" + "github.com/mattermost/mattermost/server/v8/enterprise/elasticsearch/common" +) + +// ReqBulkClient is an Elasticsearch bulk client based on the +// go-elasticsearch/v8/typedapi/code/bulk.Bulk type. +// It supports time- and number-of-requests-based thresholds, but not a +// threshold on the size of the request. +type ReqBulkClient struct { + mut sync.Mutex + + indexer *bulk.Bulk + client *elastic.TypedClient + bulkSettings common.BulkSettings + reqTimeout time.Duration + logger mlog.LoggerIFace + + quitFlusher chan struct{} + quitFlusherWg sync.WaitGroup + pendingRequests int +} + +func NewReqBulkClient(bulkSettings common.BulkSettings, + client *elastic.TypedClient, + reqTimeout time.Duration, + logger mlog.LoggerIFace, +) (*ReqBulkClient, error) { + if bulkSettings.FlushBytes > 0 { + return nil, fmt.Errorf("BulkClientBasic does not support a threshold on bytes") + } + + b := &ReqBulkClient{ + indexer: client.Bulk(), + client: client, + bulkSettings: bulkSettings, + reqTimeout: reqTimeout, + logger: logger, + + quitFlusher: make(chan struct{}), + } + + if bulkSettings.FlushInterval > 0 { + b.quitFlusherWg.Add(1) + go b.periodicFlusher() + } + + return b, nil +} + +// IndexOp is a helper function to add an IndexOperation to the current bulk request. +// doc argument can be a []byte, json.RawMessage or a struct. +func (r *ReqBulkClient) IndexOp(op types.IndexOperation, doc any) error { + r.mut.Lock() + defer r.mut.Unlock() + + if err := r.indexer.IndexOp(op, doc); err != nil { + return err + } + + return r.flushIfNecessary() +} + +// DeleteOp is a helper function to add a DeleteOperation to the current bulk request. +func (r *ReqBulkClient) DeleteOp(op types.DeleteOperation) error { + r.mut.Lock() + defer r.mut.Unlock() + + if err := r.indexer.DeleteOp(op); err != nil { + return err + } + + return r.flushIfNecessary() +} + +// flushIfNecessary flushes the pending buffer if needed. +// It MUST be called with an already acquired mutex. +func (r *ReqBulkClient) flushIfNecessary() error { + r.pendingRequests++ + + // Check number of requests threshold, only if specified + if r.bulkSettings.FlushNumReqs > 0 { + if r.pendingRequests > r.bulkSettings.FlushNumReqs { + return r._flush() + } + } + + return nil +} + +func (r *ReqBulkClient) Stop() error { + r.mut.Lock() + defer r.mut.Unlock() + + r.logger.Info("Stopping Bulk processor") + + if r.pendingRequests > 0 { + return r._flush() + } + + if r.bulkSettings.FlushInterval > 0 { + close(r.quitFlusher) + r.quitFlusherWg.Wait() + } + + return nil +} + +func (r *ReqBulkClient) periodicFlusher() { + defer r.quitFlusherWg.Done() + + for { + select { + case <-time.After(r.bulkSettings.FlushInterval): + r.mut.Lock() + if r.pendingRequests > 0 { + if err := r._flush(); err != nil { + r.logger.Warn("Error flushing live indexing buffer", mlog.Err(err)) + } + } + r.mut.Unlock() + case <-r.quitFlusher: + return + } + } +} + +// _flush MUST be called with an acquired lock. +func (r *ReqBulkClient) _flush() error { + if r.pendingRequests == 0 { + return nil + } + + ctx, cancel := context.WithTimeout(context.Background(), r.reqTimeout) + defer cancel() + + _, err := r.indexer.Do(ctx) + if err != nil { + return err + } + r.pendingRequests = 0 + + return nil +} + +func (r *ReqBulkClient) Flush() error { + r.mut.Lock() + defer r.mut.Unlock() + + return r._flush() +} diff --git a/server/enterprise/elasticsearch/elasticsearch/bulk_client_req_test.go b/server/enterprise/elasticsearch/elasticsearch/bulk_client_req_test.go new file mode 100644 index 00000000000..d2efe807121 --- /dev/null +++ b/server/enterprise/elasticsearch/elasticsearch/bulk_client_req_test.go @@ -0,0 +1,263 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.enterprise for license information. + +package elasticsearch + +import ( + "testing" + "time" + + "github.com/elastic/go-elasticsearch/v8/typedapi/types" + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/v8/channels/api4" + "github.com/mattermost/mattermost/server/v8/enterprise/elasticsearch/common" + "github.com/stretchr/testify/require" +) + +// setupBulkClient creates a test bulk client with common setup +func setupBulkClient(t *testing.T, flushNumReqs int, flushInterval time.Duration) (*ReqBulkClient, *api4.TestHelper) { + th := api4.SetupEnterprise(t) + + client := createTestClient(t, th.Context, th.App.Config(), th.App.FileBackend()) + bulkClient, err := NewReqBulkClient( + common.BulkSettings{ + FlushBytes: 0, + FlushInterval: flushInterval, + FlushNumReqs: flushNumReqs, + }, + client, + time.Duration(*th.App.Config().ElasticsearchSettings.RequestTimeoutSeconds)*time.Second, + th.Server.Platform().Log()) + require.NoError(t, err) + + return bulkClient, th +} + +// createTestPost creates a test post for indexing +func createTestPost(t *testing.T, message string) *common.ESPost { + post, err := common.ESPostFromPost(&model.Post{ + Id: model.NewId(), + Message: message, + }, "myteam") + require.NoError(t, err) + return post +} + +func TestBulkProcessor(t *testing.T) { + th := api4.SetupEnterprise(t) + defer th.TearDown() + + bulkClient, _ := setupBulkClient(t, *th.App.Config().ElasticsearchSettings.LiveIndexingBatchSize, 0) + + post := createTestPost(t, "hello world") + + err := bulkClient.IndexOp(types.IndexOperation{ + Index_: model.NewPointer("myindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + + require.Equal(t, 1, bulkClient.pendingRequests) + + err = bulkClient.Stop() + require.NoError(t, err) + + require.Equal(t, 0, bulkClient.pendingRequests) +} + +func TestIndexOp(t *testing.T) { + bulkClient, th := setupBulkClient(t, 10, 0) + defer th.TearDown() + + t.Run("single index operation", func(t *testing.T) { + post := createTestPost(t, "test message") + + err := bulkClient.IndexOp(types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + require.Equal(t, 1, bulkClient.pendingRequests) + }) + + t.Run("multiple index operations", func(t *testing.T) { + initialRequests := bulkClient.pendingRequests + + for range 5 { + post := createTestPost(t, "test message") + err := bulkClient.IndexOp(types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + } + + require.Equal(t, initialRequests+5, bulkClient.pendingRequests) + }) + + t.Run("auto flush on threshold", func(t *testing.T) { + // Create a new client with low flush threshold + bulkClient2, th2 := setupBulkClient(t, 2, 0) + defer th2.TearDown() + + post1 := createTestPost(t, "first message") + err := bulkClient2.IndexOp(types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post1.Id), + }, post1) + require.NoError(t, err) + require.Equal(t, 1, bulkClient2.pendingRequests) + + post2 := createTestPost(t, "second message") + err = bulkClient2.IndexOp(types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post2.Id), + }, post2) + require.NoError(t, err) + require.Equal(t, 2, bulkClient2.pendingRequests) + + // Third operation should trigger flush + post3 := createTestPost(t, "third message") + err = bulkClient2.IndexOp(types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post3.Id), + }, post3) + require.NoError(t, err) + require.Equal(t, 0, bulkClient2.pendingRequests) + }) +} + +func TestDeleteOp(t *testing.T) { + bulkClient, th := setupBulkClient(t, 10, 0) + defer th.TearDown() + + t.Run("single delete operation", func(t *testing.T) { + docId := model.NewId() + + err := bulkClient.DeleteOp(types.DeleteOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(docId), + }) + require.NoError(t, err) + require.Equal(t, 1, bulkClient.pendingRequests) + }) + + t.Run("multiple delete operations", func(t *testing.T) { + initialRequests := bulkClient.pendingRequests + + for range 3 { + docId := model.NewId() + err := bulkClient.DeleteOp(types.DeleteOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(docId), + }) + require.NoError(t, err) + } + + require.Equal(t, initialRequests+3, bulkClient.pendingRequests) + }) + + t.Run("auto flush on threshold", func(t *testing.T) { + // Create a new client with low flush threshold + bulkClient2, th2 := setupBulkClient(t, 2, 0) + defer th2.TearDown() + + // Add two delete operations + for range 2 { + docId := model.NewId() + err := bulkClient2.DeleteOp(types.DeleteOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(docId), + }) + require.NoError(t, err) + } + require.Equal(t, 2, bulkClient2.pendingRequests) + + // Third operation should trigger flush + docId := model.NewId() + err := bulkClient2.DeleteOp(types.DeleteOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(docId), + }) + require.NoError(t, err) + require.Equal(t, 0, bulkClient2.pendingRequests) + }) +} + +func TestFlush(t *testing.T) { + bulkClient, th := setupBulkClient(t, 10, 0) + defer th.TearDown() + + t.Run("flush with pending requests", func(t *testing.T) { + post := createTestPost(t, "test message") + + err := bulkClient.IndexOp(types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + require.Equal(t, 1, bulkClient.pendingRequests) + + err = bulkClient.Flush() + require.NoError(t, err) + require.Equal(t, 0, bulkClient.pendingRequests) + }) + + t.Run("flush with no pending requests", func(t *testing.T) { + require.Equal(t, 0, bulkClient.pendingRequests) + + err := bulkClient.Flush() + require.NoError(t, err) + require.Equal(t, 0, bulkClient.pendingRequests) + }) +} + +func TestStop(t *testing.T) { + t.Run("stop with pending requests", func(t *testing.T) { + bulkClient, th := setupBulkClient(t, 10, 0) + defer th.TearDown() + + post := createTestPost(t, "test message") + + err := bulkClient.IndexOp(types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + require.Equal(t, 1, bulkClient.pendingRequests) + + err = bulkClient.Stop() + require.NoError(t, err) + require.Equal(t, 0, bulkClient.pendingRequests) + }) + + t.Run("stop with no pending requests", func(t *testing.T) { + bulkClient, th := setupBulkClient(t, 10, 0) + defer th.TearDown() + + require.Equal(t, 0, bulkClient.pendingRequests) + + err := bulkClient.Stop() + require.NoError(t, err) + require.Equal(t, 0, bulkClient.pendingRequests) + }) + + t.Run("stop with periodic flusher", func(t *testing.T) { + bulkClient, th := setupBulkClient(t, 10, 100*time.Millisecond) + defer th.TearDown() + + post := createTestPost(t, "test message") + + err := bulkClient.IndexOp(types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + require.Equal(t, 1, bulkClient.pendingRequests) + + // Stop should flush pending requests and stop the periodic flusher + err = bulkClient.Stop() + require.NoError(t, err) + require.Equal(t, 0, bulkClient.pendingRequests) + }) +} diff --git a/server/enterprise/elasticsearch/elasticsearch/bulk_test.go b/server/enterprise/elasticsearch/elasticsearch/bulk_test.go index 0970dbddb29..663227de8b0 100644 --- a/server/enterprise/elasticsearch/elasticsearch/bulk_test.go +++ b/server/enterprise/elasticsearch/elasticsearch/bulk_test.go @@ -5,39 +5,75 @@ package elasticsearch import ( "testing" + "time" - "github.com/elastic/go-elasticsearch/v8/typedapi/types" - "github.com/mattermost/mattermost/server/public/model" "github.com/mattermost/mattermost/server/v8/channels/api4" "github.com/mattermost/mattermost/server/v8/enterprise/elasticsearch/common" "github.com/stretchr/testify/require" ) -func TestBulkProcessor(t *testing.T) { +func TestNewBulk(t *testing.T) { th := api4.SetupEnterprise(t) defer th.TearDown() - client := createTestClient(t, th.Context, th.App.Config(), th.App.FileBackend()) - bulk := NewBulk(th.App.Config().ElasticsearchSettings, - th.Server.Platform().Log(), - client) - post, err := common.ESPostFromPost(&model.Post{ - Id: model.NewId(), - Message: "hello world", - }, "myteam") - require.NoError(t, err) + t.Run("zeroed bulksettings", func(t *testing.T) { + _, err := NewBulk( + common.BulkSettings{}, + client, + time.Duration(*th.App.Config().ElasticsearchSettings.RequestTimeoutSeconds)*time.Second, + th.Server.Platform().Log(), + ) + + require.Error(t, err) + }) + + t.Run("incompatible bulkSettings", func(t *testing.T) { + _, err := NewBulk( + common.BulkSettings{ + FlushBytes: 100, + FlushInterval: 5 * time.Second, + FlushNumReqs: 10, + }, + client, + time.Duration(*th.App.Config().ElasticsearchSettings.RequestTimeoutSeconds)*time.Second, + th.Server.Platform().Log(), + ) + + require.Error(t, err) + }) - err = bulk.IndexOp(types.IndexOperation{ - Index_: model.NewPointer("myindex"), - Id_: model.NewPointer(post.Id), - }, post) - require.NoError(t, err) + t.Run("data-based bulk client", func(t *testing.T) { + client, err := NewBulk( + common.BulkSettings{ + FlushBytes: 100, + FlushInterval: 5 * time.Second, + FlushNumReqs: 0, + }, + client, + time.Duration(*th.App.Config().ElasticsearchSettings.RequestTimeoutSeconds)*time.Second, + th.Server.Platform().Log(), + ) + require.NoError(t, err) - require.Equal(t, 1, bulk.pendingRequests) + _, ok := client.(*DataBulkClient) + require.True(t, ok) + }) - err = bulk.Stop() - require.NoError(t, err) + t.Run("requests-based bulk client", func(t *testing.T) { + client, err := NewBulk( + common.BulkSettings{ + FlushBytes: 0, + FlushInterval: 5 * time.Second, + FlushNumReqs: 100, + }, + client, + time.Duration(*th.App.Config().ElasticsearchSettings.RequestTimeoutSeconds)*time.Second, + th.Server.Platform().Log(), + ) + require.NoError(t, err) - require.Equal(t, 0, bulk.pendingRequests) + _, ok := client.(*ReqBulkClient) + require.True(t, ok) + }) } diff --git a/server/enterprise/elasticsearch/elasticsearch/elasticsearch.go b/server/enterprise/elasticsearch/elasticsearch/elasticsearch.go index e4443808d70..6524d1b4015 100644 --- a/server/enterprise/elasticsearch/elasticsearch/elasticsearch.go +++ b/server/enterprise/elasticsearch/elasticsearch/elasticsearch.go @@ -43,8 +43,9 @@ type ElasticsearchInterfaceImpl struct { fullVersion string plugins []string - bulkProcessor *Bulk - Platform *platform.PlatformService + bulkProcessor BulkClient + syncBulkProcessor BulkClient + Platform *platform.PlatformService } func getJSONOrErrorStr(obj any) string { @@ -128,10 +129,39 @@ func (es *ElasticsearchInterfaceImpl) Start() *model.AppError { ctx := context.Background() - if *es.Platform.Config().ElasticsearchSettings.LiveIndexingBatchSize > 1 { - es.bulkProcessor = NewBulk(es.Platform.Config().ElasticsearchSettings, - es.Platform.Log(), - es.client) + esSettings := es.Platform.Config().ElasticsearchSettings + if *esSettings.LiveIndexingBatchSize > 1 { + es.bulkProcessor, err = NewBulk( + common.BulkSettings{ + FlushBytes: 0, + FlushInterval: common.BulkFlushInterval, + FlushNumReqs: *esSettings.LiveIndexingBatchSize, + }, + es.client, + time.Duration(*esSettings.RequestTimeoutSeconds)*time.Second, + es.Platform.Log()) + if err != nil { + return model.NewAppError("elasticsearch.start", + "ent.elasticsearch.create_processor.bulk_processor_create_failed", + nil, "", + http.StatusInternalServerError).Wrap(err) + } + } + + es.syncBulkProcessor, err = NewBulk( + common.BulkSettings{ + FlushBytes: common.BulkFlushBytes, + FlushInterval: 0, + FlushNumReqs: 0, + }, + es.client, + time.Duration(*esSettings.RequestTimeoutSeconds)*time.Second, + es.Platform.Log()) + if err != nil { + return model.NewAppError("elasticsearch.start", + "ent.elasticsearch.create_processor.sync_bulk_processor_create_failed", + nil, "", + http.StatusInternalServerError).Wrap(err) } // Set up posts index template. @@ -749,6 +779,49 @@ func (es *ElasticsearchInterfaceImpl) IndexChannel(rctx request.CTX, channel *mo return nil } +func (es *ElasticsearchInterfaceImpl) SyncBulkIndexChannels(rctx request.CTX, channels []*model.Channel, getUserIDsForChannel func(channel *model.Channel) ([]string, error), teamMemberIDs []string) *model.AppError { + if len(channels) == 0 { + return nil + } + + es.mutex.RLock() + defer es.mutex.RUnlock() + + if atomic.LoadInt32(&es.ready) == 0 { + return model.NewAppError("Elasticsearch.SyncBulkIndexChannels", "ent.elasticsearch.not_started.error", map[string]any{"Backend": model.ElasticsearchSettingsESBackend}, "", http.StatusInternalServerError) + } + + indexName := *es.Platform.Config().ElasticsearchSettings.IndexPrefix + common.IndexBaseChannels + metrics := es.Platform.Metrics() + + for _, channel := range channels { + userIDs, err := getUserIDsForChannel(channel) + if err != nil { + return model.NewAppError("Elasticsearch.SyncBulkIndexChannels", model.NoTranslation, nil, "", http.StatusInternalServerError).Wrap(err) + } + + searchChannel := common.ESChannelFromChannel(channel, userIDs, teamMemberIDs) + + err = es.syncBulkProcessor.IndexOp(types.IndexOperation{ + Index_: model.NewPointer(indexName), + Id_: model.NewPointer(searchChannel.Id), + }, searchChannel) + if err != nil { + return model.NewAppError("Elasticsearch.SyncBulkIndexChannels", model.NoTranslation, nil, "", http.StatusInternalServerError).Wrap(err) + } + + if metrics != nil { + metrics.IncrementChannelIndexCounter() + } + } + + if err := es.syncBulkProcessor.Flush(); err != nil { + return model.NewAppError("Elasticsearch.SyncBulkIndexChannels", model.NoTranslation, nil, "", http.StatusInternalServerError).Wrap(err) + } + + return nil +} + func (es *ElasticsearchInterfaceImpl) SearchChannels(teamId, userID string, term string, isGuest, includeDeleted bool) ([]string, *model.AppError) { es.mutex.RLock() defer es.mutex.RUnlock() diff --git a/server/enterprise/elasticsearch/elasticsearch/elasticsearch_test.go b/server/enterprise/elasticsearch/elasticsearch/elasticsearch_test.go index 62edffeea7c..21baa62a1b0 100644 --- a/server/enterprise/elasticsearch/elasticsearch/elasticsearch_test.go +++ b/server/enterprise/elasticsearch/elasticsearch/elasticsearch_test.go @@ -98,3 +98,76 @@ func (s *ElasticsearchInterfaceTestSuite) SetupTest() { s.Nil(s.CommonTestSuite.ESImpl.PurgeIndexes(s.th.Context)) } + +func (s *ElasticsearchInterfaceTestSuite) TestSyncBulkIndexChannels() { + s.Run("Should index multiple channels successfully", func() { + // Create test channels + channel1 := &model.Channel{ + TeamId: s.th.BasicTeam.Id, + Type: model.ChannelTypeOpen, + Name: "test-channel-1", + DisplayName: "Test Channel 1", + } + channel1.PreSave() + + channel2 := &model.Channel{ + TeamId: s.th.BasicTeam.Id, + Type: model.ChannelTypePrivate, + Name: "test-channel-2", + DisplayName: "Test Channel 2", + } + channel2.PreSave() + + channels := []*model.Channel{channel1, channel2} + + // Mock getUserIDsForChannel function + getUserIDsForChannel := func(channel *model.Channel) ([]string, error) { + return []string{s.th.BasicUser.Id, s.th.BasicUser2.Id}, nil + } + + teamMemberIDs := []string{s.th.BasicUser.Id, s.th.BasicUser2.Id} + + // Test the bulk indexing + appErr := s.CommonTestSuite.ESImpl.SyncBulkIndexChannels(s.th.Context, channels, getUserIDsForChannel, teamMemberIDs) + s.Require().Nil(appErr) + + // Refresh the index to ensure data is searchable + s.Require().NoError(s.CommonTestSuite.RefreshIndexFn()) + + // Verify both channels are indexed + found, _, err := s.CommonTestSuite.GetDocumentFn("channels", channel1.Id) + s.Require().NoError(err) + s.Require().True(found) + + found, _, err = s.CommonTestSuite.GetDocumentFn("channels", channel2.Id) + s.Require().NoError(err) + s.Require().True(found) + }) + + s.Run("Should handle empty channels list", func() { + getUserIDsForChannel := func(channel *model.Channel) ([]string, error) { + return []string{}, nil + } + + appErr := s.CommonTestSuite.ESImpl.SyncBulkIndexChannels(s.th.Context, []*model.Channel{}, getUserIDsForChannel, []string{}) + s.Require().Nil(appErr) + }) + + s.Run("Should handle getUserIDsForChannel error", func() { + channel := &model.Channel{ + TeamId: s.th.BasicTeam.Id, + Type: model.ChannelTypeOpen, + Name: "test-channel-error", + DisplayName: "Test Channel Error", + } + channel.PreSave() + + getUserIDsForChannel := func(channel *model.Channel) ([]string, error) { + return nil, model.NewAppError("TestError", "test.error", nil, "", 500) + } + + appErr := s.CommonTestSuite.ESImpl.SyncBulkIndexChannels(s.th.Context, []*model.Channel{channel}, getUserIDsForChannel, []string{}) + s.Require().NotNil(appErr) + s.Require().Contains(appErr.Error(), "test.error") + }) +} diff --git a/server/enterprise/elasticsearch/elasticsearch/sync_bulk.go b/server/enterprise/elasticsearch/elasticsearch/sync_bulk.go new file mode 100644 index 00000000000..298d97cfddd --- /dev/null +++ b/server/enterprise/elasticsearch/elasticsearch/sync_bulk.go @@ -0,0 +1,94 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.enterprise for license information. + +package elasticsearch + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "time" + + elastic "github.com/elastic/go-elasticsearch/v8" + "github.com/elastic/go-elasticsearch/v8/esutil" + "github.com/elastic/go-elasticsearch/v8/typedapi/types" + "github.com/mattermost/mattermost/server/v8/enterprise/elasticsearch/common" +) + +type SyncBulk struct { + client *elastic.TypedClient + bulkIndexer esutil.BulkIndexer +} + +func NewSyncBulk(client *elastic.TypedClient) (*SyncBulk, error) { + bulkIndexer, err := newBulkIndexer(client) + if err != nil { + return nil, err + } + + return &SyncBulk{client, bulkIndexer}, nil +} + +func newBulkIndexer(client *elastic.TypedClient) (esutil.BulkIndexer, error) { + return esutil.NewBulkIndexer(esutil.BulkIndexerConfig{ + Client: client, // The Elasticsearch client + FlushBytes: common.BulkFlushBytes, // The flush threshold in bytes + FlushInterval: 30 * time.Second, // The periodic flush interval + }) +} + +// IndexOp is a helper function to add an IndexOperation to the current bulk request. +// doc argument can be a []byte, json.RawMessage or a struct. +func (r *SyncBulk) IndexOp(op types.IndexOperation, doc any) error { + var body io.ReadSeeker + switch v := doc.(type) { + case []byte: + body = bytes.NewReader(v) + case json.RawMessage: + body = bytes.NewReader(v) + default: + data, err := json.Marshal(doc) + if err != nil { + return err + } + body = bytes.NewReader(data) + } + + return r.bulkIndexer.Add(context.Background(), esutil.BulkIndexerItem{ + Index: *op.Index_, + Action: "index", + DocumentID: *op.Id_, + Body: body, + }) +} + +// DeleteOp is a helper function to add a DeleteOperation to the current bulk request. +func (r *SyncBulk) DeleteOp(op types.DeleteOperation) error { + return r.bulkIndexer.Add(context.Background(), esutil.BulkIndexerItem{ + Index: *op.Index_, + Action: "delete", + DocumentID: *op.Id_, + }) +} + +func (r *SyncBulk) Stop() error { + return r.bulkIndexer.Close(context.Background()) +} + +func (r *SyncBulk) Flush() error { + // Flush by closing the indexer: there is no manual Flush method + if err := r.bulkIndexer.Close(context.Background()); err != nil { + return err + } + + // Restart the indexer so that we can keep using it + bulkIndexer, err := newBulkIndexer(r.client) + if err != nil { + return fmt.Errorf("unable to restart bulk indexer: %w", err) + } + + r.bulkIndexer = bulkIndexer + return nil +} diff --git a/server/enterprise/elasticsearch/opensearch/bulk.go b/server/enterprise/elasticsearch/opensearch/bulk.go index ca40793b95b..dbb26d80029 100644 --- a/server/enterprise/elasticsearch/opensearch/bulk.go +++ b/server/enterprise/elasticsearch/opensearch/bulk.go @@ -11,7 +11,6 @@ import ( "time" "github.com/elastic/go-elasticsearch/v8/typedapi/types" - "github.com/mattermost/mattermost/server/public/model" "github.com/mattermost/mattermost/server/public/shared/mlog" "github.com/mattermost/mattermost/server/v8/enterprise/elasticsearch/common" "github.com/opensearch-project/opensearch-go/v4/opensearchapi" @@ -21,9 +20,10 @@ type Bulk struct { mut sync.Mutex buf *bytes.Buffer - logger mlog.LoggerIFace - client *opensearchapi.Client - settings model.ElasticsearchSettings + client *opensearchapi.Client + bulkSettings common.BulkSettings + reqTimeout time.Duration + logger mlog.LoggerIFace quitFlusher chan struct{} quitFlusherWg sync.WaitGroup @@ -31,19 +31,25 @@ type Bulk struct { pendingRequests int } -func NewBulk(settings model.ElasticsearchSettings, +func NewBulk(bulkSettings common.BulkSettings, + client *opensearchapi.Client, + reqTimeout time.Duration, logger mlog.LoggerIFace, - client *opensearchapi.Client) *Bulk { +) *Bulk { b := &Bulk{ - settings: settings, - logger: logger, - client: client, - quitFlusher: make(chan struct{}), - buf: &bytes.Buffer{}, + bulkSettings: bulkSettings, + reqTimeout: reqTimeout, + logger: logger, + client: client, + quitFlusher: make(chan struct{}), + buf: &bytes.Buffer{}, } - b.quitFlusherWg.Add(1) - go b.periodicFlusher() + // Start the timer only if a flush interval was specified + if bulkSettings.FlushInterval > 0 { + b.quitFlusherWg.Add(1) + go b.periodicFlusher() + } return b } @@ -101,10 +107,20 @@ func (r *Bulk) DeleteOp(op *types.DeleteOperation) error { // flushIfNecessary flushes the pending buffer if needed. // It MUST be called with an already acquired mutex. func (r *Bulk) flushIfNecessary() error { + // Check data threshold, only if specified + if r.bulkSettings.FlushBytes > 0 { + if r.buf.Len() >= r.bulkSettings.FlushBytes { + return r._flush() + } + } + r.pendingRequests++ - if r.pendingRequests > *r.settings.LiveIndexingBatchSize { - return r._flush() + // Check number of requests threshold, only if specified + if r.bulkSettings.FlushNumReqs > 0 { + if r.pendingRequests > r.bulkSettings.FlushNumReqs { + return r._flush() + } } return nil @@ -119,8 +135,11 @@ func (r *Bulk) Stop() error { return r._flush() } - close(r.quitFlusher) - r.quitFlusherWg.Wait() + // Cleanup the timer if the flush interval was specified + if r.bulkSettings.FlushInterval > 0 { + close(r.quitFlusher) + r.quitFlusherWg.Wait() + } return nil } @@ -130,7 +149,7 @@ func (r *Bulk) periodicFlusher() { for { select { - case <-time.After(common.BulkFlushInterval): + case <-time.After(r.bulkSettings.FlushInterval): r.mut.Lock() if r.pendingRequests > 0 { if err := r._flush(); err != nil { @@ -146,7 +165,11 @@ func (r *Bulk) periodicFlusher() { // _flush MUST be called with an acquired lock. func (r *Bulk) _flush() error { - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*r.settings.RequestTimeoutSeconds)*time.Second) + if r.pendingRequests == 0 { + return nil + } + + ctx, cancel := context.WithTimeout(context.Background(), r.reqTimeout) defer cancel() _, err := r.client.Bulk(ctx, opensearchapi.BulkReq{ @@ -160,3 +183,9 @@ func (r *Bulk) _flush() error { return nil } + +func (r *Bulk) Flush() error { + r.mut.Lock() + defer r.mut.Unlock() + return r._flush() +} diff --git a/server/enterprise/elasticsearch/opensearch/bulk_test.go b/server/enterprise/elasticsearch/opensearch/bulk_test.go index 57466e9beb1..a7590444ae0 100644 --- a/server/enterprise/elasticsearch/opensearch/bulk_test.go +++ b/server/enterprise/elasticsearch/opensearch/bulk_test.go @@ -4,8 +4,10 @@ package opensearch import ( + "fmt" "os" "testing" + "time" "github.com/elastic/go-elasticsearch/v8/typedapi/types" "github.com/mattermost/mattermost/server/public/model" @@ -14,22 +16,15 @@ import ( "github.com/stretchr/testify/require" ) -func TestBulkProcessor(t *testing.T) { +// setupBulkClient creates a test bulk client with common setup +func setupBulkClient(t *testing.T, flushBytes int, flushNumReqs int, flushInterval time.Duration) (*Bulk, *api4.TestHelper) { th := api4.SetupEnterprise(t) - defer th.TearDown() if os.Getenv("IS_CI") == "true" { os.Setenv("MM_ELASTICSEARCHSETTINGS_CONNECTIONURL", "http://opensearch:9201") os.Setenv("MM_ELASTICSEARCHSETTINGS_BACKEND", "opensearch") } - defer func() { - if os.Getenv("IS_CI") == "true" { - os.Setenv("MM_ELASTICSEARCHSETTINGS_CONNECTIONURL", "http://elasticsearch:9201") - os.Unsetenv("MM_ELASTICSEARCHSETTINGS_BACKEND") - } - }() - th.App.UpdateConfig(func(cfg *model.Config) { if os.Getenv("IS_CI") == "true" { *cfg.ElasticsearchSettings.ConnectionURL = "http://opensearch:9201" @@ -43,17 +38,42 @@ func TestBulkProcessor(t *testing.T) { }) client := createTestClient(t, th.Context, th.App.Config(), th.App.FileBackend()) - bulk := NewBulk(th.App.Config().ElasticsearchSettings, - th.Server.Platform().Log(), - client) + bulk := NewBulk( + common.BulkSettings{ + FlushBytes: flushBytes, + FlushInterval: flushInterval, + FlushNumReqs: flushNumReqs, + }, + client, + time.Duration(*th.App.Config().ElasticsearchSettings.RequestTimeoutSeconds)*time.Second, + th.Server.Platform().Log()) + return bulk, th +} + +// createTestPost creates a test post for indexing +func createTestPost(t *testing.T, message string) *common.ESPost { post, err := common.ESPostFromPost(&model.Post{ Id: model.NewId(), - Message: "hello world", + Message: message, }, "myteam") require.NoError(t, err) + return post +} + +func TestBulkProcessor(t *testing.T) { + bulk, th := setupBulkClient(t, 0, 10, 0) + defer th.TearDown() + defer func() { + if os.Getenv("IS_CI") == "true" { + os.Setenv("MM_ELASTICSEARCHSETTINGS_CONNECTIONURL", "http://elasticsearch:9201") + os.Unsetenv("MM_ELASTICSEARCHSETTINGS_BACKEND") + } + }() - err = bulk.IndexOp(&types.IndexOperation{ + post := createTestPost(t, "hello world") + + err := bulk.IndexOp(&types.IndexOperation{ Index_: model.NewPointer("myindex"), Id_: model.NewPointer(post.Id), }, post) @@ -66,3 +86,397 @@ func TestBulkProcessor(t *testing.T) { require.Equal(t, 0, bulk.pendingRequests) } + +func TestNewBulk(t *testing.T) { + bulk, th := setupBulkClient(t, 1024, 10, 0) + defer th.TearDown() + defer func() { + if os.Getenv("IS_CI") == "true" { + os.Setenv("MM_ELASTICSEARCHSETTINGS_CONNECTIONURL", "http://elasticsearch:9201") + os.Unsetenv("MM_ELASTICSEARCHSETTINGS_BACKEND") + } + }() + + t.Run("creates bulk client without periodic flusher", func(t *testing.T) { + require.NotNil(t, bulk) + require.NotNil(t, bulk.client) + require.NotNil(t, bulk.logger) + require.NotNil(t, bulk.buf) + require.Equal(t, 0, bulk.pendingRequests) + require.Equal(t, 1024, bulk.bulkSettings.FlushBytes) + require.Equal(t, 10, bulk.bulkSettings.FlushNumReqs) + }) + + t.Run("creates bulk client with periodic flusher", func(t *testing.T) { + bulkWithTimer, th2 := setupBulkClient(t, 1024, 10, 100*time.Millisecond) + defer th2.TearDown() + defer func() { + if os.Getenv("IS_CI") == "true" { + os.Setenv("MM_ELASTICSEARCHSETTINGS_CONNECTIONURL", "http://elasticsearch:9201") + os.Unsetenv("MM_ELASTICSEARCHSETTINGS_BACKEND") + } + }() + + require.NotNil(t, bulkWithTimer) + require.Equal(t, 100*time.Millisecond, bulkWithTimer.bulkSettings.FlushInterval) + + err := bulkWithTimer.Stop() + require.NoError(t, err) + }) + + err := bulk.Stop() + require.NoError(t, err) +} + +func TestIndexOp(t *testing.T) { + bulk, th := setupBulkClient(t, 0, 10, 0) + defer th.TearDown() + defer func() { + if os.Getenv("IS_CI") == "true" { + os.Setenv("MM_ELASTICSEARCHSETTINGS_CONNECTIONURL", "http://elasticsearch:9201") + os.Unsetenv("MM_ELASTICSEARCHSETTINGS_BACKEND") + } + }() + + t.Run("single index operation with struct", func(t *testing.T) { + post := createTestPost(t, "test message") + + err := bulk.IndexOp(&types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + require.Equal(t, 1, bulk.pendingRequests) + + // Verify buffer has content + require.Greater(t, bulk.buf.Len(), 0) + }) + + t.Run("index operation with []byte", func(t *testing.T) { + initialRequests := bulk.pendingRequests + docId := model.NewId() + data := []byte(`{"message": "test byte slice"}`) + + err := bulk.IndexOp(&types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(docId), + }, data) + require.NoError(t, err) + require.Equal(t, initialRequests+1, bulk.pendingRequests) + }) + + t.Run("index operation with json.RawMessage", func(t *testing.T) { + initialRequests := bulk.pendingRequests + docId := model.NewId() + jsonData := []byte(`{"message": "test raw message"}`) + + err := bulk.IndexOp(&types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(docId), + }, jsonData) + require.NoError(t, err) + require.Equal(t, initialRequests+1, bulk.pendingRequests) + }) + + t.Run("multiple index operations", func(t *testing.T) { + initialRequests := bulk.pendingRequests + + for i := range 5 { + post := createTestPost(t, fmt.Sprintf("test message %d", i)) + err := bulk.IndexOp(&types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + } + + require.Equal(t, initialRequests+5, bulk.pendingRequests) + }) + + t.Run("auto flush on request threshold", func(t *testing.T) { + // Create a new client with low flush threshold + bulk2, th2 := setupBulkClient(t, 0, 2, 0) + defer th2.TearDown() + defer func() { + if os.Getenv("IS_CI") == "true" { + os.Setenv("MM_ELASTICSEARCHSETTINGS_CONNECTIONURL", "http://elasticsearch:9201") + os.Unsetenv("MM_ELASTICSEARCHSETTINGS_BACKEND") + } + }() + + post1 := createTestPost(t, "first message") + err := bulk2.IndexOp(&types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post1.Id), + }, post1) + require.NoError(t, err) + require.Equal(t, 1, bulk2.pendingRequests) + + post2 := createTestPost(t, "second message") + err = bulk2.IndexOp(&types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post2.Id), + }, post2) + require.NoError(t, err) + require.Equal(t, 2, bulk2.pendingRequests) + + // Third operation should trigger flush + post3 := createTestPost(t, "third message") + err = bulk2.IndexOp(&types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post3.Id), + }, post3) + require.NoError(t, err) + require.Equal(t, 0, bulk2.pendingRequests) + + err = bulk2.Stop() + require.NoError(t, err) + }) + + err := bulk.Stop() + require.NoError(t, err) +} + +func TestDeleteOp(t *testing.T) { + bulk, th := setupBulkClient(t, 0, 10, 0) + defer th.TearDown() + defer func() { + if os.Getenv("IS_CI") == "true" { + os.Setenv("MM_ELASTICSEARCHSETTINGS_CONNECTIONURL", "http://elasticsearch:9201") + os.Unsetenv("MM_ELASTICSEARCHSETTINGS_BACKEND") + } + }() + + t.Run("single delete operation", func(t *testing.T) { + docId := model.NewId() + + err := bulk.DeleteOp(&types.DeleteOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(docId), + }) + require.NoError(t, err) + require.Equal(t, 1, bulk.pendingRequests) + + // Verify buffer has content + require.Greater(t, bulk.buf.Len(), 0) + }) + + t.Run("multiple delete operations", func(t *testing.T) { + initialRequests := bulk.pendingRequests + + for range 3 { + docId := model.NewId() + err := bulk.DeleteOp(&types.DeleteOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(docId), + }) + require.NoError(t, err) + } + + require.Equal(t, initialRequests+3, bulk.pendingRequests) + }) + + t.Run("auto flush on request threshold", func(t *testing.T) { + // Create a new client with low flush threshold + bulk2, th2 := setupBulkClient(t, 0, 2, 0) + defer th2.TearDown() + defer func() { + if os.Getenv("IS_CI") == "true" { + os.Setenv("MM_ELASTICSEARCHSETTINGS_CONNECTIONURL", "http://elasticsearch:9201") + os.Unsetenv("MM_ELASTICSEARCHSETTINGS_BACKEND") + } + }() + + // Add two delete operations + for range 2 { + docId := model.NewId() + err := bulk2.DeleteOp(&types.DeleteOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(docId), + }) + require.NoError(t, err) + } + require.Equal(t, 2, bulk2.pendingRequests) + + // Third operation should trigger flush + docId := model.NewId() + err := bulk2.DeleteOp(&types.DeleteOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(docId), + }) + require.NoError(t, err) + require.Equal(t, 0, bulk2.pendingRequests) + + err = bulk2.Stop() + require.NoError(t, err) + }) + + err := bulk.Stop() + require.NoError(t, err) +} + +func TestFlush(t *testing.T) { + bulk, th := setupBulkClient(t, 0, 10, 0) + defer th.TearDown() + defer func() { + if os.Getenv("IS_CI") == "true" { + os.Setenv("MM_ELASTICSEARCHSETTINGS_CONNECTIONURL", "http://elasticsearch:9201") + os.Unsetenv("MM_ELASTICSEARCHSETTINGS_BACKEND") + } + }() + + t.Run("flush with pending operations", func(t *testing.T) { + post := createTestPost(t, "test message") + + err := bulk.IndexOp(&types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + require.Equal(t, 1, bulk.pendingRequests) + + err = bulk.Flush() + require.NoError(t, err) + require.Equal(t, 0, bulk.pendingRequests) + + // Verify buffer is empty after flush + require.Equal(t, 0, bulk.buf.Len()) + }) + + t.Run("flush with no pending operations", func(t *testing.T) { + require.Equal(t, 0, bulk.pendingRequests) + + err := bulk.Flush() + require.NoError(t, err) + require.Equal(t, 0, bulk.pendingRequests) + }) + + err := bulk.Stop() + require.NoError(t, err) +} + +func TestStop(t *testing.T) { + t.Run("stop with pending operations", func(t *testing.T) { + bulk, th := setupBulkClient(t, 0, 10, 0) + defer th.TearDown() + defer func() { + if os.Getenv("IS_CI") == "true" { + os.Setenv("MM_ELASTICSEARCHSETTINGS_CONNECTIONURL", "http://elasticsearch:9201") + os.Unsetenv("MM_ELASTICSEARCHSETTINGS_BACKEND") + } + }() + + post := createTestPost(t, "test message") + + err := bulk.IndexOp(&types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + require.Equal(t, 1, bulk.pendingRequests) + + err = bulk.Stop() + require.NoError(t, err) + require.Equal(t, 0, bulk.pendingRequests) + }) + + t.Run("stop with no pending operations", func(t *testing.T) { + bulk, th := setupBulkClient(t, 0, 10, 0) + defer th.TearDown() + defer func() { + if os.Getenv("IS_CI") == "true" { + os.Setenv("MM_ELASTICSEARCHSETTINGS_CONNECTIONURL", "http://elasticsearch:9201") + os.Unsetenv("MM_ELASTICSEARCHSETTINGS_BACKEND") + } + }() + + require.Equal(t, 0, bulk.pendingRequests) + + err := bulk.Stop() + require.NoError(t, err) + require.Equal(t, 0, bulk.pendingRequests) + }) + + t.Run("stop with periodic flusher", func(t *testing.T) { + bulk, th := setupBulkClient(t, 0, 10, 100*time.Millisecond) + defer th.TearDown() + defer func() { + if os.Getenv("IS_CI") == "true" { + os.Setenv("MM_ELASTICSEARCHSETTINGS_CONNECTIONURL", "http://elasticsearch:9201") + os.Unsetenv("MM_ELASTICSEARCHSETTINGS_BACKEND") + } + }() + + post := createTestPost(t, "test message") + + err := bulk.IndexOp(&types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + require.Equal(t, 1, bulk.pendingRequests) + + // Stop should flush pending operations and stop the periodic flusher + err = bulk.Stop() + require.NoError(t, err) + require.Equal(t, 0, bulk.pendingRequests) + }) +} + +func TestFlushThresholds(t *testing.T) { + t.Run("flush on bytes threshold", func(t *testing.T) { + // Create a client with very small byte threshold + bulk, th := setupBulkClient(t, 100, 0, 0) // 100 bytes threshold + defer th.TearDown() + defer func() { + if os.Getenv("IS_CI") == "true" { + os.Setenv("MM_ELASTICSEARCHSETTINGS_CONNECTIONURL", "http://elasticsearch:9201") + os.Unsetenv("MM_ELASTICSEARCHSETTINGS_BACKEND") + } + }() + + // Add operations that should exceed the byte threshold + for range 5 { + post := createTestPost(t, "This is a long message that should help us exceed the byte threshold for testing") + err := bulk.IndexOp(&types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + } + + // Should have been flushed due to byte threshold + require.Equal(t, 0, bulk.pendingRequests) + + err := bulk.Stop() + require.NoError(t, err) + }) + + t.Run("no flush when thresholds not met", func(t *testing.T) { + bulk, th := setupBulkClient(t, 100000, 10, 0) // High thresholds + defer th.TearDown() + defer func() { + if os.Getenv("IS_CI") == "true" { + os.Setenv("MM_ELASTICSEARCHSETTINGS_CONNECTIONURL", "http://elasticsearch:9201") + os.Unsetenv("MM_ELASTICSEARCHSETTINGS_BACKEND") + } + }() + + // Add a few operations that shouldn't trigger flush + for range 3 { + post := createTestPost(t, "short") + err := bulk.IndexOp(&types.IndexOperation{ + Index_: model.NewPointer("testindex"), + Id_: model.NewPointer(post.Id), + }, post) + require.NoError(t, err) + fmt.Println("PENDING REQS:", bulk.pendingRequests) + } + + // Should not have been flushed + require.Equal(t, 3, bulk.pendingRequests) + + err := bulk.Stop() + require.NoError(t, err) + }) +} diff --git a/server/enterprise/elasticsearch/opensearch/opensearch.go b/server/enterprise/elasticsearch/opensearch/opensearch.go index 69bb30b8758..1a542fa450d 100644 --- a/server/enterprise/elasticsearch/opensearch/opensearch.go +++ b/server/enterprise/elasticsearch/opensearch/opensearch.go @@ -45,8 +45,9 @@ type OpensearchInterfaceImpl struct { fullVersion string plugins []string - bulkProcessor *Bulk - Platform *platform.PlatformService + bulkProcessor *Bulk + syncBulkProcessor *Bulk + Platform *platform.PlatformService } func getJSONOrErrorStr(obj any) string { @@ -130,11 +131,27 @@ func (os *OpensearchInterfaceImpl) Start() *model.AppError { ctx := context.Background() - if *os.Platform.Config().ElasticsearchSettings.LiveIndexingBatchSize > 1 { - os.bulkProcessor = NewBulk(os.Platform.Config().ElasticsearchSettings, - os.Platform.Log(), - os.client) - } + esSettings := os.Platform.Config().ElasticsearchSettings + if *esSettings.LiveIndexingBatchSize > 1 { + os.bulkProcessor = NewBulk( + common.BulkSettings{ + FlushBytes: 0, + FlushInterval: common.BulkFlushInterval, + FlushNumReqs: *esSettings.LiveIndexingBatchSize, + }, + os.client, + time.Duration(*esSettings.RequestTimeoutSeconds)*time.Second, + os.Platform.Log()) + } + os.syncBulkProcessor = NewBulk( + common.BulkSettings{ + FlushBytes: common.BulkFlushBytes, + FlushInterval: 0, + FlushNumReqs: 0, + }, + os.client, + time.Duration(*esSettings.RequestTimeoutSeconds)*time.Second, + os.Platform.Log()) // Set up posts index template. templateBuf, err := json.Marshal(common.GetPostTemplate(os.Platform.Config())) @@ -831,6 +848,49 @@ func (os *OpensearchInterfaceImpl) IndexChannel(rctx request.CTX, channel *model return nil } +func (os *OpensearchInterfaceImpl) SyncBulkIndexChannels(rctx request.CTX, channels []*model.Channel, getUserIDsForChannel func(channel *model.Channel) ([]string, error), teamMemberIDs []string) *model.AppError { + if len(channels) == 0 { + return nil + } + + os.mutex.RLock() + defer os.mutex.RUnlock() + + if atomic.LoadInt32(&os.ready) == 0 { + return model.NewAppError("Opensearch.SyncBulkIndexChannels", "ent.elasticsearch.not_started.error", map[string]any{"Backend": model.ElasticsearchSettingsOSBackend}, "", http.StatusInternalServerError) + } + + indexName := *os.Platform.Config().ElasticsearchSettings.IndexPrefix + common.IndexBaseChannels + metrics := os.Platform.Metrics() + + for _, channel := range channels { + userIDs, err := getUserIDsForChannel(channel) + if err != nil { + return model.NewAppError("Opensearch.SyncBulkIndexChannels", model.NoTranslation, nil, "", http.StatusInternalServerError).Wrap(err) + } + + searchChannel := common.ESChannelFromChannel(channel, userIDs, teamMemberIDs) + + err = os.syncBulkProcessor.IndexOp(&types.IndexOperation{ + Index_: model.NewPointer(indexName), + Id_: model.NewPointer(searchChannel.Id), + }, searchChannel) + if err != nil { + return model.NewAppError("Opensearch.SyncBulkIndexChannels", model.NoTranslation, nil, "", http.StatusInternalServerError).Wrap(err) + } + + if metrics != nil { + metrics.IncrementChannelIndexCounter() + } + } + + if err := os.syncBulkProcessor.Flush(); err != nil { + return model.NewAppError("Opensearch.SyncBulkIndexChannels", model.NoTranslation, nil, "", http.StatusInternalServerError).Wrap(err) + } + + return nil +} + func (os *OpensearchInterfaceImpl) SearchChannels(teamId, userID string, term string, isGuest, includeDeleted bool) ([]string, *model.AppError) { os.mutex.RLock() defer os.mutex.RUnlock() diff --git a/server/enterprise/elasticsearch/opensearch/opensearch_test.go b/server/enterprise/elasticsearch/opensearch/opensearch_test.go index 5cfada6612b..60b2f82710c 100644 --- a/server/enterprise/elasticsearch/opensearch/opensearch_test.go +++ b/server/enterprise/elasticsearch/opensearch/opensearch_test.go @@ -124,3 +124,76 @@ func (s *OpensearchInterfaceTestSuite) SetupTest() { s.Nil(s.CommonTestSuite.ESImpl.PurgeIndexes(s.th.Context)) } + +func (s *OpensearchInterfaceTestSuite) TestSyncBulkIndexChannels() { + s.Run("Should index multiple channels successfully", func() { + // Create test channels + channel1 := &model.Channel{ + TeamId: s.th.BasicTeam.Id, + Type: model.ChannelTypeOpen, + Name: "test-channel-1", + DisplayName: "Test Channel 1", + } + channel1.PreSave() + + channel2 := &model.Channel{ + TeamId: s.th.BasicTeam.Id, + Type: model.ChannelTypePrivate, + Name: "test-channel-2", + DisplayName: "Test Channel 2", + } + channel2.PreSave() + + channels := []*model.Channel{channel1, channel2} + + // Mock getUserIDsForChannel function + getUserIDsForChannel := func(channel *model.Channel) ([]string, error) { + return []string{s.th.BasicUser.Id, s.th.BasicUser2.Id}, nil + } + + teamMemberIDs := []string{s.th.BasicUser.Id, s.th.BasicUser2.Id} + + // Test the bulk indexing + appErr := s.CommonTestSuite.ESImpl.SyncBulkIndexChannels(s.th.Context, channels, getUserIDsForChannel, teamMemberIDs) + s.Require().Nil(appErr) + + // Refresh the index to ensure data is searchable + s.Require().NoError(s.CommonTestSuite.RefreshIndexFn()) + + // Verify both channels are indexed + found, _, err := s.CommonTestSuite.GetDocumentFn("channels", channel1.Id) + s.Require().NoError(err) + s.Require().True(found) + + found, _, err = s.CommonTestSuite.GetDocumentFn("channels", channel2.Id) + s.Require().NoError(err) + s.Require().True(found) + }) + + s.Run("Should handle empty channels list", func() { + getUserIDsForChannel := func(channel *model.Channel) ([]string, error) { + return []string{}, nil + } + + appErr := s.CommonTestSuite.ESImpl.SyncBulkIndexChannels(s.th.Context, []*model.Channel{}, getUserIDsForChannel, []string{}) + s.Require().Nil(appErr) + }) + + s.Run("Should handle getUserIDsForChannel error", func() { + channel := &model.Channel{ + TeamId: s.th.BasicTeam.Id, + Type: model.ChannelTypeOpen, + Name: "test-channel-error", + DisplayName: "Test Channel Error", + } + channel.PreSave() + + getUserIDsForChannel := func(channel *model.Channel) ([]string, error) { + return nil, model.NewAppError("TestError", "test.error", nil, "", 500) + } + + appErr := s.CommonTestSuite.ESImpl.SyncBulkIndexChannels(s.th.Context, []*model.Channel{channel}, getUserIDsForChannel, []string{}) + s.Require().NotNil(appErr) + s.Require().Contains(appErr.Error(), "test.error") + }) +} diff --git a/server/i18n/en.json b/server/i18n/en.json index b79de1db8ce..8007ab01782 100644 --- a/server/i18n/en.json +++ b/server/i18n/en.json @@ -7912,6 +7912,14 @@ "id": "ent.elasticsearch.create_client.connect_failed", "translation": "Setting up {{.Backend}} Client Failed" }, + { + "id": "ent.elasticsearch.create_processor.bulk_processor_create_failed", + "translation": "Failed to create Elasticsearch bulk processor" + }, + { + "id": "ent.elasticsearch.create_processor.sync_bulk_processor_create_failed", + "translation": "Failed to create Elasticsearch sync bulk processor" + }, { "id": "ent.elasticsearch.create_template_channels_if_not_exists.template_create_failed", "translation": "Failed to create {{.Backend}} template for channels" diff --git a/server/platform/services/searchengine/interface.go b/server/platform/services/searchengine/interface.go index 69740171d55..40f24196efb 100644 --- a/server/platform/services/searchengine/interface.go +++ b/server/platform/services/searchengine/interface.go @@ -33,6 +33,7 @@ type SearchEngineInterface interface { // IndexChannel indexes a given channel. The userIDs are only populated // for private channels. IndexChannel(rctx request.CTX, channel *model.Channel, userIDs, teamMemberIDs []string) *model.AppError + SyncBulkIndexChannels(rctx request.CTX, channels []*model.Channel, getUserIDsForChannel func(channel *model.Channel) ([]string, error), teamMemberIDs []string) *model.AppError SearchChannels(teamId, userID, term string, isGuest, includeDeleted bool) ([]string, *model.AppError) DeleteChannel(channel *model.Channel) *model.AppError IndexUser(rctx request.CTX, user *model.User, teamsIds, channelsIds []string) *model.AppError diff --git a/server/platform/services/searchengine/mocks/SearchEngineInterface.go b/server/platform/services/searchengine/mocks/SearchEngineInterface.go index 7aa7ab9f8ec..1c49890a27a 100644 --- a/server/platform/services/searchengine/mocks/SearchEngineInterface.go +++ b/server/platform/services/searchengine/mocks/SearchEngineInterface.go @@ -757,6 +757,26 @@ func (_m *SearchEngineInterface) Stop() *model.AppError { return r0 } +// SyncBulkIndexChannels provides a mock function with given fields: rctx, channels, getUserIDsForChannel, teamMemberIDs +func (_m *SearchEngineInterface) SyncBulkIndexChannels(rctx request.CTX, channels []*model.Channel, getUserIDsForChannel func(*model.Channel) ([]string, error), teamMemberIDs []string) *model.AppError { + ret := _m.Called(rctx, channels, getUserIDsForChannel, teamMemberIDs) + + if len(ret) == 0 { + panic("no return value specified for SyncBulkIndexChannels") + } + + var r0 *model.AppError + if rf, ok := ret.Get(0).(func(request.CTX, []*model.Channel, func(*model.Channel) ([]string, error), []string) *model.AppError); ok { + r0 = rf(rctx, channels, getUserIDsForChannel, teamMemberIDs) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.AppError) + } + } + + return r0 +} + // TestConfig provides a mock function with given fields: rctx, cfg func (_m *SearchEngineInterface) TestConfig(rctx request.CTX, cfg *model.Config) *model.AppError { ret := _m.Called(rctx, cfg) diff --git a/server/public/utils/page.go b/server/public/utils/page.go new file mode 100644 index 00000000000..c9e000036e8 --- /dev/null +++ b/server/public/utils/page.go @@ -0,0 +1,43 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package utils + +// Pager fetches all items from a paginated API. +// Pager is a generic function that fetches and aggregates paginated data. +// It takes a fetch function and a perPage parameter as arguments. +// +// The fetch function is responsible for retrieving a slice of items of type T +// for a given page number. It returns the fetched items and an error, if any. +// Ideally a developer may want to use a closure to create a fetch function. +// +// The perPage parameter specifies the number of items to fetch per page. +// +// Example usage: +// +// items, err := Pager(fetchFunc, 10) +// if err != nil { +// // handle error +// } +// // process items +func Pager[T any](fetch func(page int) ([]T, error), perPage int) ([]T, error) { + var list []T + var page int + + for { + fetched, err := fetch(page) + if err != nil { + return list, err + } + + list = append(list, fetched...) + + if len(fetched) < perPage { + break + } + + page++ + } + + return list, nil +} diff --git a/server/public/utils/page_test.go b/server/public/utils/page_test.go new file mode 100644 index 00000000000..e5a8adbe48d --- /dev/null +++ b/server/public/utils/page_test.go @@ -0,0 +1,69 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package utils + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPager(t *testing.T) { + tests := []struct { + name string + fetch func(page int) ([]int, error) + perPage int + expected []int + expectErr bool + }{ + { + name: "successful fetch", + fetch: func(page int) ([]int, error) { + if page > 2 { + return nil, nil + } + return []int{page*10 + 1, page*10 + 2, page*10 + 3}, nil + }, + perPage: 3, + expected: []int{1, 2, 3, 11, 12, 13, 21, 22, 23}, + }, + { + name: "fetch with error", + fetch: func(page int) ([]int, error) { + if page == 1 { + return nil, errors.New("fetch error") + } + return []int{page*10 + 1, page*10 + 2, page*10 + 3}, nil + }, + perPage: 3, + expected: []int{1, 2, 3}, + expectErr: true, + }, + { + name: "fetch with fewer items than perPage", + fetch: func(page int) ([]int, error) { + if page > 0 { + return []int{11, 12}, nil + } + return []int{page*10 + 1, page*10 + 2, page*10 + 3}, nil + }, + perPage: 3, + expected: []int{1, 2, 3, 11, 12}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := Pager(tt.fetch, tt.perPage) + if tt.expectErr { + assert.Error(t, err) + assert.Equal(t, tt.expected, result) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +}