diff --git a/.github/workflows/server-ci.yml b/.github/workflows/server-ci.yml index f6e6cf138bbd..5b1a48688b5e 100644 --- a/.github/workflows/server-ci.yml +++ b/.github/workflows/server-ci.yml @@ -5,6 +5,7 @@ # If you rename this workflow, be sure to update those workflows as well. name: Server CI on: + workflow_dispatch: # Allow manual/API triggering for linked plugin CI push: branches: - master diff --git a/server/channels/api4/channel.go b/server/channels/api4/channel.go index 344827d0557a..815f556da332 100644 --- a/server/channels/api4/channel.go +++ b/server/channels/api4/channel.go @@ -100,6 +100,8 @@ func (api *API) InitChannel() { api.BaseRoutes.ChannelModerations.Handle("", api.APISessionRequired(getChannelModerations)).Methods(http.MethodGet) api.BaseRoutes.ChannelModerations.Handle("/patch", api.APISessionRequired(patchChannelModerations)).Methods(http.MethodPut) + + api.initChannelJoinRequestRoutes() } func createChannel(c *Context, w http.ResponseWriter, r *http.Request) { @@ -144,6 +146,24 @@ func createChannel(c *Context, w http.ResponseWriter, r *http.Request) { return } + if channel.Discoverable { + if !c.App.Config().FeatureFlags.DiscoverableChannels { + c.Err = model.NewAppError("createChannel", "api.channel.discoverable_join_request.feature_disabled.app_error", nil, "", http.StatusBadRequest) + return + } + if channel.Type != model.ChannelTypePrivate { + c.Err = model.NewAppError("createChannel", "model.channel.is_valid.discoverable.app_error", nil, "", http.StatusBadRequest) + return + } + // The team-scoped check is the closest analog to "would this user + // have permission to manage discoverability after the channel is + // created" — channel-scope grants don't exist yet at creation time. + if !c.App.SessionHasPermissionToTeam(*c.AppContext.Session(), channel.TeamId, model.PermissionManagePrivateChannelDiscoverability) { + c.SetPermissionError(model.PermissionManagePrivateChannelDiscoverability) + return + } + } + sc, appErr := c.App.CreateChannelWithUser(c.AppContext, channel, c.AppContext.Session().UserId) if appErr != nil { c.Err = appErr @@ -377,12 +397,36 @@ func patchChannel(c *Context, w http.ResponseWriter, r *http.Request) { updatingProperties := patch.DisplayName != nil || patch.Name != nil || patch.Header != nil || patch.Purpose != nil || patch.GroupConstrained != nil || patch.DefaultCategoryName != nil updatingAutoTranslation := patch.AutoTranslation != nil updatingManagedCategory := patch.ManagedCategoryName != nil + updatingDiscoverable := patch.Discoverable != nil - if !updatingProperties && !updatingAutoTranslation && patch.BannerInfo == nil && !updatingManagedCategory { + if !updatingProperties && !updatingAutoTranslation && patch.BannerInfo == nil && !updatingManagedCategory && !updatingDiscoverable { c.Err = model.NewAppError("patchChannel", "api.channel.patch_update_channel.no_changes.app_error", nil, "", http.StatusBadRequest) return } + if updatingDiscoverable { + if !c.App.Config().FeatureFlags.DiscoverableChannels { + c.Err = model.NewAppError("patchChannel", "api.channel.discoverable_join_request.feature_disabled.app_error", nil, "", http.StatusBadRequest) + return + } + if oldChannel.Type != model.ChannelTypePrivate { + c.Err = model.NewAppError("patchChannel", "model.channel.is_valid.discoverable.app_error", nil, "", http.StatusBadRequest) + return + } + if oldChannel.DeleteAt != 0 { + c.Err = model.NewAppError("patchChannel", "api.channel.update_channel.deleted.app_error", nil, "", http.StatusBadRequest) + return + } + if oldChannel.IsShared() { + c.Err = model.NewAppError("patchChannel", "api.channel.discoverable_join_request.shared.app_error", nil, "", http.StatusBadRequest) + return + } + if ok, _ := c.App.SessionHasPermissionToChannel(c.AppContext, *c.AppContext.Session(), c.Params.ChannelId, model.PermissionManagePrivateChannelDiscoverability); !ok { + c.SetPermissionError(model.PermissionManagePrivateChannelDiscoverability) + return + } + } + if updatingAutoTranslation && (c.App.AutoTranslation() == nil || !c.App.AutoTranslation().IsFeatureAvailable()) { c.Err = model.NewAppError("patchChannel", "api.channel.patch_update_channel.feature_not_available.app_error", nil, "", http.StatusForbidden) return @@ -806,6 +850,9 @@ func getChannel(c *Context, w http.ResponseWriter, r *http.Request) { } } } else if ok, _ := c.App.SessionHasPermissionToChannel(c.AppContext, *c.AppContext.Session(), c.Params.ChannelId, model.PermissionReadChannel); !ok { + if served := serveDiscoverableNonMember(c, w, channel); served { + return + } c.SetPermissionError(model.PermissionReadChannel) return } @@ -822,6 +869,80 @@ func getChannel(c *Context, w http.ResponseWriter, r *http.Request) { } } +// sanitizeDiscoverableChannel returns a copy of `channel` containing only the +// fields safe to expose to a non-member who can see the channel through the +// discoverable surface. Cell-level secrets such as Props or per-channel +// scheme identifiers are stripped so this view is strictly read-only metadata. +func sanitizeDiscoverableChannel(channel *model.Channel) *model.Channel { + if channel == nil { + return nil + } + return &model.Channel{ + Id: channel.Id, + TeamId: channel.TeamId, + Type: channel.Type, + DisplayName: channel.DisplayName, + Name: channel.Name, + Header: channel.Header, + Purpose: channel.Purpose, + Discoverable: channel.Discoverable, + PolicyEnforced: channel.PolicyEnforced, + CreateAt: channel.CreateAt, + UpdateAt: channel.UpdateAt, + DeleteAt: channel.DeleteAt, + } +} + +// discoverableNonMemberView returns a sanitized non-member view of `channel` +// when the calling user qualifies under the discoverable visibility rules, +// or (nil, nil) when the channel must remain hidden — the caller should +// emit its own permission-denied response. Errors from the discoverable +// lookup are returned for the caller to assign to c.Err. When the feature +// flag is off, this returns (nil, nil) and the caller falls through to its +// default 403/404 path so the existing read contract is preserved. +func discoverableNonMemberView(c *Context, channel *model.Channel) (*model.Channel, *model.AppError) { + if !c.App.Config().FeatureFlags.DiscoverableChannels { + return nil, nil + } + user, userErr := c.App.GetUser(c.AppContext.Session().UserId) + if userErr != nil { + return nil, userErr + } + allowed, allowedErr := c.App.IsDiscoverableJoinAllowed(c.AppContext, user, channel) + if allowedErr != nil { + return nil, allowedErr + } + if !allowed { + return nil, nil + } + return sanitizeDiscoverableChannel(channel), nil +} + +// serveDiscoverableNonMember writes the sanitized non-member discoverable +// view of `channel` to `w` and returns true when the request was handled +// here (either the response was written, or c.Err was set on a lookup +// failure). Returns false without touching the response when the caller +// should emit its own permission-denied response (the channel is hidden +// from this non-member, or the feature flag is off). +// +// Centralising this here means every read endpoint that previously emitted +// 403/404 to a non-member can keep its prior failure shape while opting in +// to the discoverable surface with a single `if served { return }` guard. +func serveDiscoverableNonMember(c *Context, w http.ResponseWriter, channel *model.Channel) bool { + sanitized, err := discoverableNonMemberView(c, channel) + if err != nil { + c.Err = err + return true + } + if sanitized == nil { + return false + } + if encErr := json.NewEncoder(w).Encode(sanitized); encErr != nil { + c.Logger.Warn("Error while writing response", mlog.Err(encErr)) + } + return true +} + func getChannelUnread(c *Context, w http.ResponseWriter, r *http.Request) { c.RequireChannelId().RequireUserId() if c.Err != nil { @@ -1646,6 +1767,9 @@ func getChannelByName(c *Context, w http.ResponseWriter, r *http.Request) { // allows team admins to access private channel if !c.App.SessionHasPermissionToTeam(*c.AppContext.Session(), channel.TeamId, model.PermissionManageTeam) { if ok, _ := c.App.SessionHasPermissionToChannel(c.AppContext, *c.AppContext.Session(), channel.Id, model.PermissionReadChannel); !ok { + if served := serveDiscoverableNonMember(c, w, channel); served { + return + } c.Err = model.NewAppError("getChannelByName", "app.channel.get_by_name.missing.app_error", nil, "teamId="+channel.TeamId+", "+"name="+channel.Name+"", http.StatusNotFound) return } @@ -1686,6 +1810,9 @@ func getChannelByNameForTeamName(c *Context, w http.ResponseWriter, r *http.Requ } else if !channelOk { // allows team admins to access private channel if !c.App.SessionHasPermissionToTeam(*c.AppContext.Session(), channel.TeamId, model.PermissionManageTeam) { + if served := serveDiscoverableNonMember(c, w, channel); served { + return + } c.Err = model.NewAppError("getChannelByNameForTeamName", "app.channel.get_by_name.missing.app_error", nil, "teamId="+channel.TeamId+", "+"name="+channel.Name+"", http.StatusNotFound) return } @@ -2252,9 +2379,25 @@ func addChannelMember(c *Context, w http.ResponseWriter, r *http.Request) { if channel.Type == model.ChannelTypePrivate { if hasPermission, _ := c.App.SessionHasPermissionToChannel(c.AppContext, *c.AppContext.Session(), channel.Id, model.PermissionManagePrivateChannelMembers); !hasPermission { + // Allow the user to self-add to a discoverable private channel only + // through the request flow — the discoverable toggle does not + // implicitly grant PermissionManagePrivateChannelMembers, and the + // existing addChannelMember API would otherwise let any caller + // bypass the queue by issuing a direct POST. c.SetPermissionError(model.PermissionManagePrivateChannelMembers) return } + + // Discoverable + no policy: the request flow is the only path. Even + // admins use it to ensure the audit trail. We exempt the case where + // the requester is adding someone other than themselves so admin + // invites still work. + for _, userId := range userIds { + if c.App.IsDiscoverableSelfAddBlocked(c.AppContext, channel, c.AppContext.Session().UserId, userId) { + c.Err = model.NewAppError("addChannelMember", "api.channel.discoverable_join_request.discoverable_requires_approval.app_error", nil, "channel_id="+channel.Id, http.StatusForbidden) + return + } + } } if channel.IsGroupConstrained() { diff --git a/server/channels/api4/channel_join_request.go b/server/channels/api4/channel_join_request.go new file mode 100644 index 000000000000..8f928d3c3b9d --- /dev/null +++ b/server/channels/api4/channel_join_request.go @@ -0,0 +1,293 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package api4 + +import ( + "encoding/json" + "net/http" + "strconv" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/shared/mlog" +) + +// initChannelJoinRequestRoutes registers the discoverable-private-channel +// join request endpoints. The route group is split into its own file so the +// handlers stay isolated from the rest of api4/channel.go. +func (api *API) initChannelJoinRequestRoutes() { + if !api.srv.Config().FeatureFlags.DiscoverableChannels { + return + } + + api.BaseRoutes.Channel.Handle("/join_request", api.APISessionRequired(requestJoinChannel)).Methods(http.MethodPost) + api.BaseRoutes.Channel.Handle("/join_request", api.APISessionRequired(getMyChannelJoinRequest)).Methods(http.MethodGet) + api.BaseRoutes.Channel.Handle("/join_request", api.APISessionRequired(withdrawMyChannelJoinRequest)).Methods(http.MethodDelete) + + api.BaseRoutes.Channel.Handle("/join_requests", api.APISessionRequired(getChannelJoinRequests)).Methods(http.MethodGet) + api.BaseRoutes.Channel.Handle("/join_requests/count", api.APISessionRequired(countPendingChannelJoinRequests)).Methods(http.MethodGet) + api.BaseRoutes.Channel.Handle("/join_requests/{request_id:[A-Za-z0-9]+}", api.APISessionRequired(patchChannelJoinRequest)).Methods(http.MethodPatch) + + api.BaseRoutes.User.Handle("/channel_join_requests", api.APISessionRequired(getMyChannelJoinRequests)).Methods(http.MethodGet) +} + +// channelJoinRequestBody is the POST body shape for /channels/{id}/join_request. +type channelJoinRequestBody struct { + Message string `json:"message"` +} + +func requireDiscoverableChannelsEnabled(c *Context, where string) bool { + if !c.App.Config().FeatureFlags.DiscoverableChannels { + c.Err = model.NewAppError(where, "api.channel.discoverable_join_request.feature_disabled.app_error", nil, "", http.StatusNotFound) + return false + } + return true +} + +func requestJoinChannel(c *Context, w http.ResponseWriter, r *http.Request) { + c.RequireChannelId() + if c.Err != nil { + return + } + if !requireDiscoverableChannelsEnabled(c, "requestJoinChannel") { + return + } + + var body channelJoinRequestBody + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + c.SetInvalidParamWithErr("body", err) + return + } + + auditRec := c.MakeAuditRecord(model.AuditEventCreateChannelJoinRequest, model.AuditStatusFail) + defer c.LogAuditRec(auditRec) + model.AddEventParameterToAuditRec(auditRec, "channel_id", c.Params.ChannelId) + model.AddEventParameterToAuditRec(auditRec, "user_id", c.AppContext.Session().UserId) + + joined, req, appErr := c.App.RequestJoinChannel(c.AppContext, c.AppContext.Session().UserId, c.Params.ChannelId, body.Message) + if appErr != nil { + c.Err = appErr + return + } + + auditRec.Success() + if req != nil { + auditRec.AddEventResultState(req) + } + + if joined { + // Mirror the membership endpoint's "no body, just status" semantics + // when the user was added directly via the ABAC fast path. + w.WriteHeader(http.StatusCreated) + if err := json.NewEncoder(w).Encode(map[string]string{"status": model.ChannelJoinRequestStatusApproved}); err != nil { + c.Logger.Warn("Error while writing response", mlog.Err(err)) + } + return + } + + w.WriteHeader(http.StatusCreated) + if err := json.NewEncoder(w).Encode(req); err != nil { + c.Logger.Warn("Error while writing response", mlog.Err(err)) + } +} + +func getMyChannelJoinRequest(c *Context, w http.ResponseWriter, r *http.Request) { + c.RequireChannelId() + if c.Err != nil { + return + } + if !requireDiscoverableChannelsEnabled(c, "getMyChannelJoinRequest") { + return + } + + req, appErr := c.App.GetMyChannelJoinRequest(c.AppContext, c.AppContext.Session().UserId, c.Params.ChannelId) + if appErr != nil { + c.Err = appErr + return + } + + if req == nil { + // Mirror REST conventions: not-found instead of an explicit `null` + // so clients can distinguish "no pending request" from "service down". + w.WriteHeader(http.StatusNotFound) + return + } + + if err := json.NewEncoder(w).Encode(req); err != nil { + c.Logger.Warn("Error while writing response", mlog.Err(err)) + } +} + +func withdrawMyChannelJoinRequest(c *Context, w http.ResponseWriter, r *http.Request) { + c.RequireChannelId() + if c.Err != nil { + return + } + if !requireDiscoverableChannelsEnabled(c, "withdrawMyChannelJoinRequest") { + return + } + + auditRec := c.MakeAuditRecord(model.AuditEventWithdrawChannelJoinRequest, model.AuditStatusFail) + defer c.LogAuditRec(auditRec) + model.AddEventParameterToAuditRec(auditRec, "channel_id", c.Params.ChannelId) + model.AddEventParameterToAuditRec(auditRec, "user_id", c.AppContext.Session().UserId) + + req, appErr := c.App.GetMyChannelJoinRequest(c.AppContext, c.AppContext.Session().UserId, c.Params.ChannelId) + if appErr != nil { + c.Err = appErr + return + } + if req == nil { + c.Err = model.NewAppError("withdrawMyChannelJoinRequest", "app.channel.join_request.not_found.app_error", nil, "channel_id="+c.Params.ChannelId, http.StatusNotFound) + return + } + + updated, appErr := c.App.WithdrawChannelJoinRequest(c.AppContext, req.Id, c.AppContext.Session().UserId) + if appErr != nil { + c.Err = appErr + return + } + + auditRec.Success() + auditRec.AddEventResultState(updated) + + if err := json.NewEncoder(w).Encode(updated); err != nil { + c.Logger.Warn("Error while writing response", mlog.Err(err)) + } +} + +func getChannelJoinRequests(c *Context, w http.ResponseWriter, r *http.Request) { + c.RequireChannelId() + if c.Err != nil { + return + } + if !requireDiscoverableChannelsEnabled(c, "getChannelJoinRequests") { + return + } + + if ok, _ := c.App.SessionHasPermissionToChannel(c.AppContext, *c.AppContext.Session(), c.Params.ChannelId, model.PermissionManageChannelJoinRequests); !ok { + c.SetPermissionError(model.PermissionManageChannelJoinRequests) + return + } + + opts := model.GetChannelJoinRequestsOpts{ + Status: r.URL.Query().Get("status"), + Page: c.Params.Page, + PerPage: c.Params.PerPage, + } + + list, appErr := c.App.GetChannelJoinRequests(c.AppContext, c.Params.ChannelId, opts) + if appErr != nil { + c.Err = appErr + return + } + + if err := json.NewEncoder(w).Encode(list); err != nil { + c.Logger.Warn("Error while writing response", mlog.Err(err)) + } +} + +func countPendingChannelJoinRequests(c *Context, w http.ResponseWriter, r *http.Request) { + c.RequireChannelId() + if c.Err != nil { + return + } + if !requireDiscoverableChannelsEnabled(c, "countPendingChannelJoinRequests") { + return + } + + if ok, _ := c.App.SessionHasPermissionToChannel(c.AppContext, *c.AppContext.Session(), c.Params.ChannelId, model.PermissionManageChannelJoinRequests); !ok { + c.SetPermissionError(model.PermissionManageChannelJoinRequests) + return + } + + count, appErr := c.App.CountPendingChannelJoinRequests(c.AppContext, c.Params.ChannelId) + if appErr != nil { + c.Err = appErr + return + } + + if err := json.NewEncoder(w).Encode(map[string]int64{"count": count}); err != nil { + c.Logger.Warn("Error while writing response", mlog.Err(err)) + } +} + +func patchChannelJoinRequest(c *Context, w http.ResponseWriter, r *http.Request) { + c.RequireChannelId() + if c.Err != nil { + return + } + if !requireDiscoverableChannelsEnabled(c, "patchChannelJoinRequest") { + return + } + if !model.IsValidId(c.Params.RequestId) { + c.SetInvalidURLParam("request_id") + return + } + + if ok, _ := c.App.SessionHasPermissionToChannel(c.AppContext, *c.AppContext.Session(), c.Params.ChannelId, model.PermissionManageChannelJoinRequests); !ok { + c.SetPermissionError(model.PermissionManageChannelJoinRequests) + return + } + + var patch model.ChannelJoinRequestPatch + if err := json.NewDecoder(r.Body).Decode(&patch); err != nil { + c.SetInvalidParamWithErr("channel_join_request_patch", err) + return + } + + auditRec := c.MakeAuditRecord(model.AuditEventUpdateChannelJoinRequest, model.AuditStatusFail) + defer c.LogAuditRec(auditRec) + model.AddEventParameterToAuditRec(auditRec, "channel_id", c.Params.ChannelId) + model.AddEventParameterToAuditRec(auditRec, "request_id", c.Params.RequestId) + model.AddEventParameterToAuditRec(auditRec, "status", patch.Status) + // Capture only the presence of a denial reason in the audit log; the + // free-text contents are intentionally excluded. + model.AddEventParameterToAuditRec(auditRec, "has_denial_reason", strconv.FormatBool(patch.DenialReason != nil && *patch.DenialReason != "")) + + updated, appErr := c.App.UpdateChannelJoinRequest(c.AppContext, c.Params.RequestId, c.Params.ChannelId, &patch, c.AppContext.Session().UserId) + if appErr != nil { + c.Err = appErr + return + } + + auditRec.Success() + auditRec.AddEventResultState(updated) + + if err := json.NewEncoder(w).Encode(updated); err != nil { + c.Logger.Warn("Error while writing response", mlog.Err(err)) + } +} + +func getMyChannelJoinRequests(c *Context, w http.ResponseWriter, r *http.Request) { + c.RequireUserId() + if c.Err != nil { + return + } + if !requireDiscoverableChannelsEnabled(c, "getMyChannelJoinRequests") { + return + } + + // Only the calling user can list their own requests; admins should use + // the per-channel queue endpoint. + if c.Params.UserId != c.AppContext.Session().UserId { + c.SetPermissionError(model.PermissionEditOtherUsers) + return + } + + opts := model.GetChannelJoinRequestsOpts{ + Status: r.URL.Query().Get("status"), + Page: c.Params.Page, + PerPage: c.Params.PerPage, + } + + list, appErr := c.App.GetMyChannelJoinRequests(c.AppContext, c.AppContext.Session().UserId, opts) + if appErr != nil { + c.Err = appErr + return + } + + if err := json.NewEncoder(w).Encode(list); err != nil { + c.Logger.Warn("Error while writing response", mlog.Err(err)) + } +} diff --git a/server/channels/api4/channel_join_request_test.go b/server/channels/api4/channel_join_request_test.go new file mode 100644 index 000000000000..57bfc586854b --- /dev/null +++ b/server/channels/api4/channel_join_request_test.go @@ -0,0 +1,154 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package api4 + +import ( + "context" + "encoding/json" + "io" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/mattermost/mattermost/server/public/model" +) + +// setupDiscoverableTH spins up an api4 fixture with the discoverable channels +// feature flag enabled so the new routes are registered. +func setupDiscoverableTH(t *testing.T) *TestHelper { + t.Helper() + return SetupConfig(t, func(cfg *model.Config) { + cfg.FeatureFlags.DiscoverableChannels = true + }).InitBasic(t) +} + +// markDiscoverableViaAdmin patches `channel` to discoverable=true using the +// SystemAdminClient so the permission check is satisfied without needing to +// rebind the channel-admin role on the test fixture. +func markDiscoverableViaAdmin(t *testing.T, th *TestHelper, channel *model.Channel) *model.Channel { + t.Helper() + on := true + patched, _, err := th.SystemAdminClient.PatchChannel(context.Background(), channel.Id, &model.ChannelPatch{Discoverable: &on}) + require.NoError(t, err) + require.True(t, patched.Discoverable) + return patched +} + +func TestRequestJoinChannelAPI_HappyPath(t *testing.T) { + mainHelper.Parallel(t) + th := setupDiscoverableTH(t) + + channel := th.CreatePrivateChannel(t) + channel = markDiscoverableViaAdmin(t, th, channel) + + other := th.CreateUser(t) + th.LinkUserToTeam(t, other, th.BasicTeam) + _, _, err := th.Client.Login(context.Background(), other.Email, other.Password) + require.NoError(t, err) + + body := []byte(`{"message":"hi"}`) + resp, err := th.Client.DoAPIPost(context.Background(), "/channels/"+channel.Id+"/join_request", string(body)) + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, http.StatusCreated, resp.StatusCode) + + var req model.ChannelJoinRequest + require.NoError(t, json.NewDecoder(resp.Body).Decode(&req)) + assert.Equal(t, model.ChannelJoinRequestStatusPending, req.Status) + assert.Equal(t, channel.Id, req.ChannelId) + assert.Equal(t, other.Id, req.UserId) +} + +func TestRequestJoinChannelAPI_FeatureDisabled(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + channel := th.CreatePrivateChannel(t) + body := []byte(`{"message":"hi"}`) + resp, err := th.Client.DoAPIPost(context.Background(), "/channels/"+channel.Id+"/join_request", string(body)) + defer closeBodyOrNil(resp) + require.Error(t, err) + require.NotNil(t, resp) + assert.Equal(t, http.StatusNotFound, resp.StatusCode, "route must be unregistered when feature flag is off") +} + +func TestPatchChannelDiscoverable_RejectsNonPrivate(t *testing.T) { + mainHelper.Parallel(t) + th := setupDiscoverableTH(t) + + publicChannel := th.CreatePublicChannel(t) + on := true + _, resp, err := th.SystemAdminClient.PatchChannel(context.Background(), publicChannel.Id, &model.ChannelPatch{Discoverable: &on}) + require.Error(t, err) + require.NotNil(t, resp) + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) +} + +func TestAddChannelMember_BlocksSelfAddOnDiscoverable(t *testing.T) { + mainHelper.Parallel(t) + th := setupDiscoverableTH(t) + + channel := th.CreatePrivateChannel(t) + channel = markDiscoverableViaAdmin(t, th, channel) + + // Add a user that has manage-private-channel-members on a different + // channel but not this one. Use Client (BasicUser2) - they're a team + // member but not yet a channel member here. + _, _, err := th.Client.Login(context.Background(), th.BasicUser2.Email, th.BasicUser2.Password) + require.NoError(t, err) + + _, resp, err := th.Client.AddChannelMember(context.Background(), channel.Id, th.BasicUser2.Id) + require.Error(t, err) + require.NotNil(t, resp) + // Without channel admin permission the underlying permission check + // fails first; either way the request flow is what they need to use. + assert.True(t, resp.StatusCode == http.StatusForbidden || resp.StatusCode == http.StatusUnauthorized, + "got %d", resp.StatusCode) +} + +func TestGetChannelByName_HiddenForNonQualifyingNonMember(t *testing.T) { + mainHelper.Parallel(t) + th := setupDiscoverableTH(t) + + // Plain (non-discoverable) private channel: a non-member must still get + // 404 — this guards against a regression in the existing read paths. + channel := th.CreatePrivateChannel(t) + + _, _, err := th.Client.Login(context.Background(), th.BasicUser2.Email, th.BasicUser2.Password) + require.NoError(t, err) + + _, resp, err := th.Client.GetChannelByName(context.Background(), channel.Name, th.BasicTeam.Id, "") + require.Error(t, err) + require.NotNil(t, resp) + assert.Equal(t, http.StatusNotFound, resp.StatusCode) +} + +func TestGetChannelByName_VisibleForQualifyingNonMemberOnDiscoverable(t *testing.T) { + mainHelper.Parallel(t) + th := setupDiscoverableTH(t) + + channel := th.CreatePrivateChannel(t) + channel = markDiscoverableViaAdmin(t, th, channel) + + _, _, err := th.Client.Login(context.Background(), th.BasicUser2.Email, th.BasicUser2.Password) + require.NoError(t, err) + + got, _, err := th.Client.GetChannelByName(context.Background(), channel.Name, th.BasicTeam.Id, "") + require.NoError(t, err) + require.NotNil(t, got) + assert.Equal(t, channel.Id, got.Id) + assert.True(t, got.Discoverable) +} + +// closeBodyOrNil is a tiny helper so the negative-path tests don't need to +// branch on a nil response body before deferring Close. +func closeBodyOrNil(resp *http.Response) { + if resp == nil || resp.Body == nil { + return + } + _, _ = io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() +} diff --git a/server/channels/app/channel.go b/server/channels/app/channel.go index 713e1cbea461..10bc89acb816 100644 --- a/server/channels/app/channel.go +++ b/server/channels/app/channel.go @@ -737,6 +737,18 @@ func (a *App) GetGroupChannel(rctx request.CTX, userIDs []string) (*model.Channe // UpdateChannel updates a given channel by its Id. It also publishes the CHANNEL_UPDATED event. func (a *App) UpdateChannel(rctx request.CTX, channel *model.Channel) (*model.Channel, *model.AppError) { + oldChannel, getErr := a.Srv().Store().Channel().Get(channel.Id, true) + if getErr != nil { + errCtx := map[string]any{"channel_id": channel.Id} + var nfErr *store.ErrNotFound + switch { + case errors.As(getErr, &nfErr): + return nil, model.NewAppError("UpdateChannel", "app.channel.get.existing.app_error", errCtx, "", http.StatusNotFound).Wrap(getErr) + default: + return nil, model.NewAppError("UpdateChannel", "app.channel.get.find.app_error", errCtx, "", http.StatusInternalServerError).Wrap(getErr) + } + } + enforced, appErr := a.ChannelAccessControlled(rctx, channel.Id) if appErr != nil { return nil, appErr @@ -752,17 +764,19 @@ func (a *App) UpdateChannel(rctx request.CTX, channel *model.Channel) (*model.Ch // silent type flip would change what the existing policy actually // does to members. The admin must remove the policy first and // re-apply it after the conversion if they still want it. - current, getErr := a.Srv().Store().Channel().Get(channel.Id, true) - if getErr != nil { - return nil, model.NewAppError("UpdateChannel", "app.channel.get.find.app_error", nil, "", http.StatusInternalServerError).Wrap(getErr) - } - if current.Type != channel.Type { + if oldChannel.Type != channel.Type { return nil, model.NewAppError("UpdateChannel", "api.channel.update_channel.policy_enforced_type_conversion.app_error", nil, "channel has an active ABAC policy; remove the policy before converting between public and private", http.StatusBadRequest) } } + var channelErr *model.AppError + channel, channelErr = a.runGuardedChannelWillBeUpdated(rctx, channel, oldChannel) + if channelErr != nil { + return nil, channelErr + } + _, err := a.Srv().Store().Channel().Update(rctx, channel) if err != nil { var appErr *model.AppError @@ -835,6 +849,14 @@ func (a *App) UpdateChannelScheme(rctx request.CTX, channel *model.Channel) (*mo } func (a *App) UpdateChannelPrivacy(rctx request.CTX, oldChannel *model.Channel, user *model.User) (*model.Channel, *model.AppError) { + wasDiscoverable := oldChannel.Discoverable + // Public channels are inherently joinable; the discoverable flag only + // has meaning for private channels. Clear it eagerly so callers reading + // the row mid-conversion don't see an inconsistent state. + if oldChannel.Type == model.ChannelTypeOpen { + oldChannel.Discoverable = false + } + channel, err := a.UpdateChannel(rctx, oldChannel) if err != nil { return channel, err @@ -844,6 +866,11 @@ func (a *App) UpdateChannelPrivacy(rctx request.CTX, oldChannel *model.Channel, if postErr != nil { if channel.Type == model.ChannelTypeOpen { channel.Type = model.ChannelTypePrivate + // Restore the discoverable flag we eagerly cleared above so + // the rollback fully undoes the conversion. Without this the + // caller would see a private channel with discoverable=false + // (and would have to re-toggle it). + channel.Discoverable = wasDiscoverable } else { channel.Type = model.ChannelTypeOpen } @@ -854,6 +881,19 @@ func (a *App) UpdateChannelPrivacy(rctx request.CTX, oldChannel *model.Channel, return channel, postErr } + // Now that the conversion is fully committed, cancel pending join + // requests for the formerly discoverable private channel — the WS + // broadcast inside the helper updates each requester's My Pending + // Requests list in real-time. Doing this after the privacy-message + // step ensures a transient post failure (which triggers the rollback + // above) cannot leave requests cancelled against a still-private + // channel. + if wasDiscoverable && channel.Type == model.ChannelTypeOpen { + a.Srv().Go(func() { + a.CancelPendingChannelJoinRequestsOnConvert(rctx, channel) + }) + } + a.Srv().Platform().InvalidateCacheForChannel(channel) messageWs := model.NewWebSocketEvent(model.WebsocketEventChannelConverted, channel.TeamId, "", "", nil, "") @@ -906,6 +946,10 @@ func (a *App) RestoreChannel(rctx request.CTX, channel *model.Channel, userID st return nil, model.NewAppError("restoreChannel", "api.channel.restore_channel.restored.app_error", nil, "", http.StatusBadRequest) } + if appErr := a.runGuardedChannelWillBeRestored(rctx, channel); appErr != nil { + return nil, appErr + } + if err := a.Srv().Store().Channel().Restore(channel.Id, model.GetMillis()); err != nil { return nil, model.NewAppError("RestoreChannel", "app.channel.restore.app_error", nil, "", http.StatusInternalServerError).Wrap(err) } @@ -1810,23 +1854,10 @@ func (a *App) addUserToChannel(rctx request.CTX, user *model.User, channel *mode } } - var rejectionReason string - pluginContext := pluginContext(rctx) - a.ch.RunMultiHook(func(hooks plugin.Hooks, _ *model.Manifest) bool { - updatedMember, reason := hooks.ChannelMemberWillBeAdded(pluginContext, newMember) - if reason != "" { - rejectionReason = reason - return false - } - if updatedMember != nil { - newMember = updatedMember - } - return true - }, plugin.ChannelMemberWillBeAddedID) - - if rejectionReason != "" { - return nil, model.NewAppError("AddUserToChannel", "app.channel.add_user.to.channel.rejected_by_plugin", - map[string]any{"Reason": rejectionReason}, "", http.StatusBadRequest) + var channelMemberErr *model.AppError + newMember, channelMemberErr = a.runGuardedChannelMemberWillBeAdded(rctx, channel.Id, newMember) + if channelMemberErr != nil { + return nil, channelMemberErr } newMember, nErr = a.Srv().Store().Channel().SaveMember(rctx, newMember) @@ -3229,6 +3260,10 @@ func (a *App) AutocompleteChannels(rctx request.CTX, userID, term string) (model return nil, model.NewAppError("AutocompleteChannels", "app.channel.search.app_error", nil, "", http.StatusInternalServerError).Wrap(err) } + channelList, _, appErr = a.FilterChannelListWithTeamDataForUserVisibility(rctx, channelList, userID) + if appErr != nil { + return nil, appErr + } return channelList, nil } @@ -3246,7 +3281,7 @@ func (a *App) AutocompleteChannelsForTeam(rctx request.CTX, teamID, userID, term return nil, model.NewAppError("AutocompleteChannels", "app.channel.search.app_error", nil, "", http.StatusInternalServerError).Wrap(err) } - return channelList, nil + return a.FilterChannelListForUserVisibility(rctx, channelList, userID) } func (a *App) AutocompleteChannelsForTeamFiltered(rctx request.CTX, teamID, userID, term string, privateOnly, excludeGroupConstrained bool) (model.ChannelList, *model.AppError) { @@ -3263,7 +3298,7 @@ func (a *App) AutocompleteChannelsForTeamFiltered(rctx request.CTX, teamID, user return nil, model.NewAppError("AutocompleteChannelsForTeamFiltered", "app.channel.search.app_error", nil, "", http.StatusInternalServerError).Wrap(err) } - return channelList, nil + return a.FilterChannelListForUserVisibility(rctx, channelList, userID) } func (a *App) AutocompleteChannelsForSearch(rctx request.CTX, teamID string, userID string, term string) (model.ChannelList, *model.AppError) { diff --git a/server/channels/app/channel_discoverable_visibility.go b/server/channels/app/channel_discoverable_visibility.go new file mode 100644 index 000000000000..95cf7a2c012a --- /dev/null +++ b/server/channels/app/channel_discoverable_visibility.go @@ -0,0 +1,384 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package app + +import ( + "context" + "sync" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/shared/mlog" + "github.com/mattermost/mattermost/server/public/shared/request" +) + +// channelVisibilityCacheKey is the per-request request.CTX value key used to +// memoise PDP membership decisions across N+1 channel filtering work in a +// single Browse Channels load. +type channelVisibilityCacheKey struct{} + +type channelVisibilityCache struct { + mu sync.Mutex + decisions map[string]bool +} + +func getChannelVisibilityCache(rctx request.CTX) *channelVisibilityCache { + if v := rctx.Context().Value(channelVisibilityCacheKey{}); v != nil { + if cache, ok := v.(*channelVisibilityCache); ok { + return cache + } + } + return nil +} + +// withChannelVisibilityCache returns a request context that memoises PDP +// membership decisions across the visibility filter calls in a single request. +// It's safe to call this multiple times — only the outermost installation +// allocates a cache. +func withChannelVisibilityCache(rctx request.CTX) request.CTX { + if getChannelVisibilityCache(rctx) != nil { + return rctx + } + cache := &channelVisibilityCache{decisions: map[string]bool{}} + return rctx.WithContext(context.WithValue(rctx.Context(), channelVisibilityCacheKey{}, cache)) +} + +func (c *channelVisibilityCache) get(channelID string) (bool, bool) { + c.mu.Lock() + defer c.mu.Unlock() + v, ok := c.decisions[channelID] + return v, ok +} + +func (c *channelVisibilityCache) set(channelID string, allow bool) { + c.mu.Lock() + defer c.mu.Unlock() + c.decisions[channelID] = allow +} + +// FilterDiscoverableChannelsByPolicy removes from `channels` any +// policy-enforced private channel that the user fails to satisfy — the +// security-critical visibility invariant in plan §6c. Channels without an +// active policy are returned untouched. Callers that need the additional +// "non-member private must be discoverable" gate should use +// FilterChannelsForUserVisibility instead. +// +// Failure modes are fail-secure: a missing AccessControl service, a +// subject-build failure, or any PDP error drops the offending channel from +// the result so a non-qualifying user can never be inadvertently shown a +// gated channel. Decisions are cached per-request via the request.CTX value +// bag installed by withChannelVisibilityCache. +func (a *App) FilterDiscoverableChannelsByPolicy(rctx request.CTX, channels []*model.Channel, userID string) ([]*model.Channel, *model.AppError) { + if len(channels) == 0 { + return channels, nil + } + + if !a.Config().FeatureFlags.DiscoverableChannels { + return channels, nil + } + + rctx = withChannelVisibilityCache(rctx) + cache := getChannelVisibilityCache(rctx) + + var ( + user *model.User + userErr *model.AppError + userOnce sync.Once + filtered = make([]*model.Channel, 0, len(channels)) + dropCount int + ) + + for _, channel := range channels { + if channel == nil { + continue + } + + if !channel.PolicyEnforced || channel.Type != model.ChannelTypePrivate || !channel.Discoverable { + filtered = append(filtered, channel) + continue + } + + if cached, ok := cache.get(channel.Id); ok { + if cached { + filtered = append(filtered, channel) + } else { + dropCount++ + } + continue + } + + userOnce.Do(func() { + user, userErr = a.GetUser(userID) + }) + if userErr != nil { + return nil, userErr + } + + // Guests are never permitted to see discoverable private channels. + if user.IsGuest() { + cache.set(channel.Id, false) + dropCount++ + continue + } + + decision, evalErr := a.evaluateChannelMembership(rctx, user, channel) + if evalErr != nil { + rctx.Logger().Warn("FilterDiscoverableChannelsByPolicy: PDP error, hiding channel (fail-secure)", + mlog.String("user_id", userID), + mlog.String("channel_id", channel.Id), + mlog.Err(evalErr), + ) + cache.set(channel.Id, false) + dropCount++ + continue + } + cache.set(channel.Id, decision) + if decision { + filtered = append(filtered, channel) + } else { + dropCount++ + } + } + + return filtered, nil +} + +// FilterChannelsForUserVisibility wraps FilterDiscoverableChannelsByPolicy with +// the secondary invariant: a non-member private channel must be discoverable +// to be visible at all. The caller is expected to scope `channels` to results +// where the user is a non-member; member channels should not be passed +// through this filter (their visibility is governed by membership alone). +// +// In practice the search/autocomplete store paths return a mix of member and +// non-member rows; callers should pass the full list because the helper +// detects membership-implying fields. The current implementation only checks +// the discoverability gate (the SQL-level membership join already excluded +// unaffiliated channels). +func (a *App) FilterChannelsForUserVisibility(rctx request.CTX, channels []*model.Channel, userID string) ([]*model.Channel, *model.AppError) { + return a.FilterDiscoverableChannelsByPolicy(rctx, channels, userID) +} + +// FilterChannelListForUserVisibility is the convenience overload for +// model.ChannelList callers (the standard list shape returned by app-layer +// search functions). +func (a *App) FilterChannelListForUserVisibility(rctx request.CTX, channels model.ChannelList, userID string) (model.ChannelList, *model.AppError) { + filtered, err := a.FilterChannelsForUserVisibility(rctx, channels, userID) + if err != nil { + return nil, err + } + return model.ChannelList(filtered), nil +} + +// FilterChannelListWithTeamDataForUserVisibility filters the team-data list +// shape used by Autocomplete and SearchAllChannels. The function preserves +// the embedded TeamDisplayName / TeamName fields. Returns the post-filter +// total adjustment so paginated callers can shrink TotalCount alongside the +// trimmed result set. +func (a *App) FilterChannelListWithTeamDataForUserVisibility(rctx request.CTX, channels model.ChannelListWithTeamData, userID string) (model.ChannelListWithTeamData, int, *model.AppError) { + if len(channels) == 0 { + return channels, 0, nil + } + + if !a.Config().FeatureFlags.DiscoverableChannels { + return channels, 0, nil + } + + rctx = withChannelVisibilityCache(rctx) + cache := getChannelVisibilityCache(rctx) + + var ( + user *model.User + userErr *model.AppError + userOnce sync.Once + out = make(model.ChannelListWithTeamData, 0, len(channels)) + dropped int + ) + + for i := range channels { + ch := channels[i] + if !ch.PolicyEnforced || ch.Type != model.ChannelTypePrivate || !ch.Discoverable { + out = append(out, ch) + continue + } + + if cached, ok := cache.get(ch.Id); ok { + if cached { + out = append(out, ch) + } else { + dropped++ + } + continue + } + + userOnce.Do(func() { + user, userErr = a.GetUser(userID) + }) + if userErr != nil { + return nil, 0, userErr + } + + if user.IsGuest() { + cache.set(ch.Id, false) + dropped++ + continue + } + + decision, evalErr := a.evaluateChannelMembership(rctx, user, &ch.Channel) + if evalErr != nil { + rctx.Logger().Warn("FilterChannelListWithTeamDataForUserVisibility: PDP error, hiding channel (fail-secure)", + mlog.String("user_id", userID), + mlog.String("channel_id", ch.Id), + mlog.Err(evalErr), + ) + cache.set(ch.Id, false) + dropped++ + continue + } + cache.set(ch.Id, decision) + if decision { + out = append(out, ch) + } else { + dropped++ + } + } + + return out, dropped, nil +} + +// IsDiscoverableJoinAllowed reports whether `user` may view `channel` as a +// non-member through the discoverable-channels surface. Returns 404 (mapped +// by callers) when the channel is hidden from this user — matching the +// "indistinguishable from a non-existent channel" requirement so the policy +// cannot act as an existence oracle. +func (a *App) IsDiscoverableJoinAllowed(rctx request.CTX, user *model.User, channel *model.Channel) (bool, *model.AppError) { + if channel == nil { + return false, nil + } + if channel.Type != model.ChannelTypePrivate || !channel.Discoverable { + return false, nil + } + if user == nil || user.IsGuest() || user.DeleteAt != 0 { + return false, nil + } + if channel.DeleteAt != 0 || channel.IsShared() { + return false, nil + } + if !channel.PolicyEnforced { + return true, nil + } + decision, evalErr := a.evaluateChannelMembership(rctx, user, channel) + if evalErr != nil { + // Fail-secure: PDP failure hides the channel rather than leak it. + rctx.Logger().Warn("IsDiscoverableJoinAllowed: PDP error, hiding channel (fail-secure)", + mlog.String("user_id", user.Id), + mlog.String("channel_id", channel.Id), + mlog.Err(evalErr), + ) + return false, nil + } + return decision, nil +} + +// CancelPendingChannelJoinRequestsOnConvert transitions every pending request +// for a channel to the withdrawn state — used when the channel is converted +// to public (open channels are inherently joinable, so a pending queue is +// nonsensical) and when the channel is archived. Failures are logged because +// the conversion / archive must not be blocked. +func (a *App) CancelPendingChannelJoinRequestsOnConvert(rctx request.CTX, channel *model.Channel) { + if channel == nil { + return + } + + const ( + pageSize = 200 + maxIterations = 50 // hard cap at ~10k requests per channel + ) + for range maxIterations { + opts := model.GetChannelJoinRequestsOpts{ + Status: model.ChannelJoinRequestStatusPending, + Page: 0, + PerPage: pageSize, + } + rows, _, err := a.Srv().Store().ChannelJoinRequest().GetForChannel(channel.Id, opts) + if err != nil { + rctx.Logger().Warn("CancelPendingChannelJoinRequestsOnConvert: failed to list pending requests", + mlog.String("channel_id", channel.Id), + mlog.Err(err), + ) + return + } + if len(rows) == 0 { + return + } + failed := 0 + for _, row := range rows { + row.Status = model.ChannelJoinRequestStatusWithdrawn + row.Message = "" + updated, updateErr := a.Srv().Store().ChannelJoinRequest().Update(row) + if updateErr != nil { + failed++ + rctx.Logger().Warn("CancelPendingChannelJoinRequestsOnConvert: failed to withdraw pending request", + mlog.String("channel_id", channel.Id), + mlog.String("request_id", row.Id), + mlog.Err(updateErr), + ) + continue + } + a.broadcastChannelJoinRequestUpdated(rctx, channel, updated) + } + // If every row in the batch failed to update, the next iteration + // would re-fetch the same rows and loop forever. Break out and + // surface the situation in the log — the operator can re-run the + // cleanup manually after addressing the underlying store error. + if failed == len(rows) { + rctx.Logger().Warn("CancelPendingChannelJoinRequestsOnConvert: every row in batch failed to update, aborting to avoid infinite loop", + mlog.String("channel_id", channel.Id), + mlog.Int("failed", failed), + ) + return + } + // Standard exit when the last page is partial: every remaining + // pending row was successfully withdrawn (or logged as failed). + if len(rows) < pageSize { + return + } + } + // maxIterations safety net — this should be effectively unreachable + // because the per-batch all-failed check above already aborts on + // systemic update failures. Fire a higher-severity log if we hit it. + rctx.Logger().Error("CancelPendingChannelJoinRequestsOnConvert: hit maxIterations, aborting", + mlog.String("channel_id", channel.Id), + mlog.Int("max_iterations", maxIterations), + ) +} + +// IsDiscoverableSelfAddBlocked reports whether a user trying to self-add to +// `channel` via POST /channels/{id}/members must instead go through the +// request flow. The block applies only when: +// - the channel is private, +// - it is discoverable but does NOT have an active ABAC policy +// (channels with a policy use the existing PDP gate inside +// addUserToChannel — admins can still add others by policy), +// - the user is not yet a member, +// - and the requester is the user themselves. +// +// Other paths (admin invites, API by reviewer ID) are unaffected: the request +// flow exists to give admins a queue, not to block invites. +func (a *App) IsDiscoverableSelfAddBlocked(rctx request.CTX, channel *model.Channel, requesterUserID, targetUserID string) bool { + if channel == nil || channel.Type != model.ChannelTypePrivate { + return false + } + if !channel.Discoverable { + return false + } + if channel.PolicyEnforced { + return false + } + if requesterUserID != targetUserID { + return false + } + if !a.Config().FeatureFlags.DiscoverableChannels { + return false + } + return true +} diff --git a/server/channels/app/channel_discoverable_visibility_test.go b/server/channels/app/channel_discoverable_visibility_test.go new file mode 100644 index 000000000000..ada36a94d7fc --- /dev/null +++ b/server/channels/app/channel_discoverable_visibility_test.go @@ -0,0 +1,85 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package app + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestDiscoverableVisibilityInvariant_NonGuestSeesNoPolicy verifies that a +// discoverable + no-policy private channel is returned through the +// non-member autocomplete path for a non-guest user. +// +// The complementary policy-enforced + non-qualifying user case is covered +// by TestFilterDiscoverableChannelsByPolicy_PolicyEnforcedFailSecure (which +// checks the fail-secure path) and the dedicated guest case is in +// TestFilterDiscoverableChannelsByPolicy_GuestHidden. +func TestDiscoverableVisibilityInvariant_NonGuestSeesNoPolicy(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + withDiscoverableChannelsFlag(t, th, true) + + channel := markDiscoverable(t, th, th.CreatePrivateChannel(t, th.BasicTeam)) + + // BasicUser2 is a member of the team but NOT of `channel`. The + // autocomplete query must still surface the channel because of the + // discoverable OR-branch (post-query ABAC filter is a no-op since the + // channel has no policy). + results, appErr := th.App.AutocompleteChannelsForTeam(th.Context, th.BasicTeam.Id, th.BasicUser2.Id, channel.Name) + require.Nil(t, appErr) + + found := false + for _, c := range results { + if c.Id == channel.Id { + found = true + break + } + } + assert.True(t, found, "discoverable + no-policy private channel must appear in autocomplete for a non-member non-guest") +} + +// TestDiscoverableVisibilityInvariant_NonDiscoverableHidden ensures that the +// store-level OR-branch we added does not inadvertently leak private +// channels with discoverable=false to non-members. The new OR clause must be +// gated on `Discoverable=true`. +func TestDiscoverableVisibilityInvariant_NonDiscoverableHidden(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + withDiscoverableChannelsFlag(t, th, true) + + plain := th.CreatePrivateChannel(t, th.BasicTeam) + + results, appErr := th.App.AutocompleteChannelsForTeam(th.Context, th.BasicTeam.Id, th.BasicUser2.Id, plain.Name) + require.Nil(t, appErr) + + for _, c := range results { + assert.NotEqual(t, plain.Id, c.Id, "non-discoverable private channel must remain hidden from non-members") + } +} + +// TestDiscoverableVisibilityInvariant_GuestHidden re-verifies the guest path +// at the autocomplete level (the unit-level guest case lives in +// TestFilterDiscoverableChannelsByPolicy_GuestHidden, but this test exercises +// the full app+store integration so we don't accidentally rely on the +// in-memory filter alone). +func TestDiscoverableVisibilityInvariant_GuestHidden(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + withDiscoverableChannelsFlag(t, th, true) + + channel := markDiscoverable(t, th, th.CreatePrivateChannel(t, th.BasicTeam)) + + guest := th.CreateGuest(t) + th.LinkUserToTeam(t, guest, th.BasicTeam) + + results, appErr := th.App.AutocompleteChannelsForTeam(th.Context, th.BasicTeam.Id, guest.Id, channel.Name) + require.Nil(t, appErr) + + for _, c := range results { + assert.NotEqual(t, channel.Id, c.Id, "guests must never see discoverable private channels in autocomplete") + } +} diff --git a/server/channels/app/channel_guards.go b/server/channels/app/channel_guards.go new file mode 100644 index 000000000000..0bb8e4dc0ef7 --- /dev/null +++ b/server/channels/app/channel_guards.go @@ -0,0 +1,216 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package app + +import ( + "net/http" + "sync" + "time" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/shared/mlog" + "github.com/mattermost/mattermost/server/public/shared/request" + "github.com/mattermost/mattermost/server/v8/channels/store" +) + +// Backoff bounds for the guard-cache reload retry. Package vars (not consts) so tests can shrink +// them via t.Cleanup-restored override. +var ( + guardCacheRetryInitialDelay = 1 * time.Second + guardCacheRetryMaxDelay = 5 * time.Minute +) + +const clusterEventInvalidateChannelGuardCache = model.ClusterEvent("inv_channel_guards") + +// reloadGuardCache scans the ChannelGuards table and atomically replaces the in-memory cache with +// the result. Used both at startup (from NewChannels) and from the cluster invalidation handler. +// Forces a master read because all callers (post-write reload, cluster invalidation) can race with +// replica lag. +func (ch *Channels) reloadGuardCache(rctx request.CTX, s store.Store) error { + guards, err := s.ChannelGuard().GetAll(store.RequestContextWithMaster(rctx)) + if err != nil { + return err + } + + fresh := &sync.Map{} + grouped := map[string][]*store.ChannelGuard{} + for _, g := range guards { + grouped[g.ChannelId] = append(grouped[g.ChannelId], g) + } + for channelID, slice := range grouped { + fresh.Store(channelID, slice) + } + + ch.guardCache.Store(fresh) + return nil +} + +// getGuardsForChannel returns the cached guard slice for a channel, or nil if none. +func (ch *Channels) getGuardsForChannel(channelID string) []*store.ChannelGuard { + m := ch.guardCache.Load() + if m == nil { + return nil + } + v, ok := m.Load(channelID) + if !ok { + return nil + } + guards, _ := v.([]*store.ChannelGuard) + return guards +} + +// clusterInvalidateGuardCacheHandler is registered as the receive-side handler for +// clusterEventInvalidateChannelGuardCache. The handler refetches the entire table. +func (ch *Channels) clusterInvalidateGuardCacheHandler(msg *model.ClusterMessage) { + rctx := request.EmptyContext(ch.srv.Log()) + if err := ch.reloadGuardCache(rctx, ch.srv.Store()); err != nil { + ch.srv.Log().Warn( + "Failed to reload channel guard cache after cluster invalidation; retry scheduled", + mlog.String("event", string(msg.Event)), + mlog.Err(err), + ) + ch.scheduleGuardCacheReloadRetry() + } +} + +// broadcastChannelGuardInvalidation tells the rest of the cluster to refetch their guard caches. +// The payload is intentionally empty. +func (ch *Channels) broadcastChannelGuardInvalidation() { + cluster := ch.srv.platform.Cluster() + if cluster == nil { + return + } + + msg := &model.ClusterMessage{ + Event: clusterEventInvalidateChannelGuardCache, + SendType: model.ClusterSendReliable, + WaitForAllToSend: true, + } + cluster.SendClusterMessage(msg) +} + +// RegisterChannelGuard records that pluginID claims channelID. The caller's pluginID is expected to +// be lowercased. +func (a *App) RegisterChannelGuard(rctx request.CTX, channelID, pluginID string) *model.AppError { + if channelID == "" { + return model.NewAppError("RegisterChannelGuard", "app.channel_guard.register.empty_channel.app_error", nil, "", http.StatusBadRequest) + } + if !model.IsValidId(channelID) { + return model.NewAppError("RegisterChannelGuard", "app.channel_guard.invalid_channel.app_error", nil, "", http.StatusBadRequest) + } + + guard := &store.ChannelGuard{ + ChannelId: channelID, + PluginId: pluginID, + CreatedAt: model.GetMillis(), + } + if err := a.Srv().Store().ChannelGuard().Save(rctx, guard); err != nil { + return model.NewAppError("RegisterChannelGuard", "app.channel_guard.register.app_error", nil, err.Error(), http.StatusInternalServerError).Wrap(err) + } + + ch := a.Channels() + if err := ch.reloadGuardCache(rctx, a.Srv().Store()); err != nil { + a.Srv().Log().Warn( + "Failed to reload channel guard cache after Register; retry scheduled", + mlog.String("channel_id", channelID), + mlog.String("plugin_id", pluginID), + mlog.Err(err), + ) + ch.scheduleGuardCacheReloadRetry() + } + ch.broadcastChannelGuardInvalidation() + return nil +} + +// UnregisterChannelGuard removes pluginID's claim on channelID. If pluginID has no claim on the +// channel, this is a no-op (returns nil). The store-level DELETE matches by both ChannelId and +// PluginId, so other plugins' claims on the same channel are left untouched. +func (a *App) UnregisterChannelGuard(rctx request.CTX, channelID, pluginID string) *model.AppError { + if channelID == "" { + return model.NewAppError("UnregisterChannelGuard", "app.channel_guard.unregister.empty_channel.app_error", nil, "", http.StatusBadRequest) + } + if !model.IsValidId(channelID) { + return model.NewAppError("UnregisterChannelGuard", "app.channel_guard.invalid_channel.app_error", nil, "", http.StatusBadRequest) + } + + rowsAffected, err := a.Srv().Store().ChannelGuard().Delete(rctx, channelID, pluginID) + if err != nil { + return model.NewAppError("UnregisterChannelGuard", "app.channel_guard.unregister.app_error", nil, err.Error(), http.StatusInternalServerError).Wrap(err) + } + if rowsAffected == 0 { + a.Srv().Log().Warn( + "UnregisterChannelGuard removed no rows; pluginID does not match any guard for this channel", + mlog.String("error_id", "unregister_no_matching_guard"), + mlog.String("channel_id", channelID), + mlog.String("plugin_id", pluginID), + ) + } + + ch := a.Channels() + if err := ch.reloadGuardCache(rctx, a.Srv().Store()); err != nil { + a.Srv().Log().Warn( + "Failed to reload channel guard cache after Unregister; retry scheduled", + mlog.String("channel_id", channelID), + mlog.String("plugin_id", pluginID), + mlog.Err(err), + ) + ch.scheduleGuardCacheReloadRetry() + } + ch.broadcastChannelGuardInvalidation() + return nil +} + +// scheduleGuardCacheReloadRetry kicks off a single in-flight retry goroutine that calls +// reloadGuardCache with exponential backoff until success or until the server is shutting down. +// Multiple concurrent calls collapse to a single retry — useful when Register, Unregister, the +// cluster handler, and the startup loader can all see the same DB outage simultaneously. +// +// Returns true if a new retry goroutine was scheduled, false if one was already in flight. Call +// sites can ignore the return value; tests use it to assert single-flight semantics. +func (ch *Channels) scheduleGuardCacheReloadRetry() bool { + if !ch.guardCacheRetryInFlight.CompareAndSwap(false, true) { + return false + } + go ch.runGuardCacheReloadRetry() + return true +} + +func (ch *Channels) runGuardCacheReloadRetry() { + defer ch.guardCacheRetryInFlight.Store(false) + rctx := request.EmptyContext(ch.srv.Log()) + + delay := guardCacheRetryInitialDelay + for attempt := 1; ; attempt++ { + timer := time.NewTimer(delay) + select { + case <-ch.interruptQuitChan: + timer.Stop() + ch.srv.Log().Info( + "Channel guard cache reload retry cancelled by shutdown", + mlog.Int("attempt", attempt), + ) + return + case <-timer.C: + } + + if err := ch.reloadGuardCache(rctx, ch.srv.Store()); err != nil { + ch.srv.Log().Info( + "Channel guard cache reload retry attempt failed; will retry", + mlog.Int("attempt", attempt), + mlog.Err(err), + ) + delay *= 2 + if delay > guardCacheRetryMaxDelay { + delay = guardCacheRetryMaxDelay + } + continue + } + + ch.srv.Log().Info( + "Channel guard cache reload retry succeeded", + mlog.Int("attempt", attempt), + ) + return + } +} diff --git a/server/channels/app/channel_guards_test.go b/server/channels/app/channel_guards_test.go new file mode 100644 index 000000000000..cee183000e9c --- /dev/null +++ b/server/channels/app/channel_guards_test.go @@ -0,0 +1,381 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package app + +import ( + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/shared/request" + "github.com/mattermost/mattermost/server/v8/channels/store" + "github.com/mattermost/mattermost/server/v8/einterfaces" +) + +// captureClusterMock records every SendClusterMessage call made during a test +// so the test can assert what was broadcast. +type captureClusterMock struct { + mu sync.Mutex + captured []*model.ClusterMessage +} + +func (c *captureClusterMock) SendClusterMessage(msg *model.ClusterMessage) { + c.mu.Lock() + defer c.mu.Unlock() + c.captured = append(c.captured, msg) +} + +func (c *captureClusterMock) SendClusterMessageToNode(nodeID string, msg *model.ClusterMessage) error { + return nil +} + +func (c *captureClusterMock) snapshot() []*model.ClusterMessage { + c.mu.Lock() + defer c.mu.Unlock() + out := make([]*model.ClusterMessage, len(c.captured)) + copy(out, c.captured) + return out +} + +// reset drops everything captured so far. Call this after TestHelper setup +// completes so the test only sees messages produced by the code under test +// (TestHelper init produces ~1000 unrelated cluster messages). +func (c *captureClusterMock) reset() { + c.mu.Lock() + defer c.mu.Unlock() + c.captured = nil +} + +func (c *captureClusterMock) StartInterNodeCommunication() {} +func (c *captureClusterMock) StopInterNodeCommunication() {} +func (c *captureClusterMock) RegisterClusterMessageHandler(event model.ClusterEvent, crm einterfaces.ClusterMessageHandler) { +} +func (c *captureClusterMock) GetClusterId() string { return "capture_cluster_mock" } +func (c *captureClusterMock) IsLeader() bool { return false } +func (c *captureClusterMock) GetMyClusterInfo() *model.ClusterInfo { return nil } +func (c *captureClusterMock) GetClusterInfos() ([]*model.ClusterInfo, error) { return nil, nil } +func (c *captureClusterMock) NotifyMsg(buf []byte) {} +func (c *captureClusterMock) GetClusterStats(rctx request.CTX) ([]*model.ClusterStats, *model.AppError) { + return nil, nil +} +func (c *captureClusterMock) GetLogs(rctx request.CTX, page, perPage int) ([]string, *model.AppError) { + return nil, nil +} +func (c *captureClusterMock) QueryLogs(rctx request.CTX, page, perPage int) (map[string][]string, *model.AppError) { + return nil, nil +} +func (c *captureClusterMock) GenerateSupportPacket(rctx request.CTX, options *model.SupportPacketOptions) (map[string][]model.FileData, error) { + return nil, nil +} +func (c *captureClusterMock) GetPluginStatuses() (model.PluginStatuses, *model.AppError) { + return nil, nil +} +func (c *captureClusterMock) ConfigChanged(previousConfig *model.Config, newConfig *model.Config, sendToOtherServer bool) *model.AppError { + return nil +} +func (c *captureClusterMock) HealthScore() int { return 0 } +func (c *captureClusterMock) WebConnCountForUser(userID string) (int, *model.AppError) { + return 0, nil +} +func (c *captureClusterMock) GetWSQueues(userID, connectionID string, seqNum int64) (map[string]*model.WSQueues, error) { + return nil, nil +} + +func TestChannelGuardCacheBroadcastShape(t *testing.T) { + mainHelper.Parallel(t) + cluster := &captureClusterMock{} + th := SetupWithClusterMock(t, cluster) + cluster.reset() // drop init-time noise; only inspect messages from code under test + + th.App.Channels().broadcastChannelGuardInvalidation() + + captured := cluster.snapshot() + require.Len(t, captured, 1) + msg := captured[0] + assert.Equal(t, clusterEventInvalidateChannelGuardCache, msg.Event) + assert.Equal(t, model.ClusterSendReliable, msg.SendType) + assert.Empty(t, msg.Data, "broadcast payload should be empty (D9: receiver does a full reload)") + assert.True(t, msg.WaitForAllToSend, "guard invalidation must wait for cluster ack (matches access_control precedent)") +} + +func TestChannelGuardRegisterTriggersBroadcast(t *testing.T) { + mainHelper.Parallel(t) + cluster := &captureClusterMock{} + th := SetupWithClusterMock(t, cluster) + cluster.reset() // drop init-time noise; only inspect messages from code under test + + channelID := model.NewId() + pluginID := "com.example.register-broadcast" + rctx := request.EmptyContext(th.App.Srv().Log()) + require.Nil(t, th.App.RegisterChannelGuard(rctx, channelID, pluginID)) + + guardEvents := filterGuardCacheEvents(cluster.snapshot()) + require.Len(t, guardEvents, 1, "Register must produce exactly one guard-cache invalidation") +} + +func filterGuardCacheEvents(msgs []*model.ClusterMessage) []*model.ClusterMessage { + out := []*model.ClusterMessage{} + for _, m := range msgs { + if m.Event == clusterEventInvalidateChannelGuardCache { + out = append(out, m) + } + } + return out +} + +func TestChannelGuardUnregisterTriggersBroadcast(t *testing.T) { + mainHelper.Parallel(t) + cluster := &captureClusterMock{} + th := SetupWithClusterMock(t, cluster) + + channelID := model.NewId() + pluginID := "com.example.unregister-broadcast" + rctx := request.EmptyContext(th.App.Srv().Log()) + // Register first (this also broadcasts), then drop captured noise so we + // only see the Unregister-side broadcast. + require.Nil(t, th.App.RegisterChannelGuard(rctx, channelID, pluginID)) + cluster.reset() + + require.Nil(t, th.App.UnregisterChannelGuard(rctx, channelID, pluginID)) + + guardEvents := filterGuardCacheEvents(cluster.snapshot()) + require.Len(t, guardEvents, 1, "Unregister must produce exactly one guard-cache invalidation") +} + +func TestChannelGuardCacheMultiChannelRefetch(t *testing.T) { + mainHelper.Parallel(t) + cluster := &captureClusterMock{} + th := SetupWithClusterMock(t, cluster) + + channelA := model.NewId() + channelB := model.NewId() + pluginA := "com.example.multi-a" + pluginB := "com.example.multi-b" + + rctx := request.EmptyContext(th.App.Srv().Log()) + require.NoError(t, th.App.Srv().Store().ChannelGuard().Save(rctx, &store.ChannelGuard{ChannelId: channelA, PluginId: pluginA, CreatedAt: 1})) + require.NoError(t, th.App.Srv().Store().ChannelGuard().Save(rctx, &store.ChannelGuard{ChannelId: channelA, PluginId: pluginB, CreatedAt: 2})) + require.NoError(t, th.App.Srv().Store().ChannelGuard().Save(rctx, &store.ChannelGuard{ChannelId: channelB, PluginId: pluginA, CreatedAt: 3})) + + // Force the cache to be empty (simulate a node that just started or had its cache cleared). + th.App.Channels().guardCache.Store(&sync.Map{}) + + th.App.Channels().clusterInvalidateGuardCacheHandler(&model.ClusterMessage{ + Event: clusterEventInvalidateChannelGuardCache, + }) + + gotA := th.App.Channels().getGuardsForChannel(channelA) + gotB := th.App.Channels().getGuardsForChannel(channelB) + assert.Len(t, gotA, 2, "channel A should have two claims after refetch") + assert.Len(t, gotB, 1, "channel B should have one claim after refetch") +} + +// TestChannelGuardRegisterUnregisterNilClusterIsSafe verifies that the +// App-level Register/Unregister methods don't panic when Cluster() is nil. +// They reach broadcastChannelGuardInvalidation, so this also covers the nil +// guard inside that helper. +func TestChannelGuardRegisterUnregisterNilClusterIsSafe(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + require.Nil(t, th.App.Srv().platform.Cluster(), "expected nil cluster in a single-node test setup") + + channelID := th.BasicChannel.Id + pluginID := "com.example.nil-cluster-rt" + + rctx := request.EmptyContext(th.App.Srv().Log()) + require.Nil(t, th.App.RegisterChannelGuard(rctx, channelID, pluginID)) + got := th.App.Channels().getGuardsForChannel(channelID) + require.Len(t, got, 1) + assert.Equal(t, pluginID, got[0].PluginId) + + require.Nil(t, th.App.UnregisterChannelGuard(rctx, channelID, pluginID)) + assert.Empty(t, th.App.Channels().getGuardsForChannel(channelID)) +} + +func TestChannelGuardLowercaseNormalization(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + channelID := th.BasicChannel.Id + mixedCaseID := "MixedCase.Plugin.ID" + expectedID := "mixedcase.plugin.id" + + // Build a PluginAPI directly with a mixed-case manifest. This bypasses the + // real plugin activation path (which we don't need for the lowercasing + // check) and exercises only the api.id -> App.RegisterChannelGuard handoff. + rctx := request.EmptyContext(th.App.Srv().Log()) + api := &PluginAPI{ + id: mixedCaseID, + app: th.App, + ctx: rctx, + } + + require.Nil(t, api.RegisterChannelGuard(channelID)) + guards, err := th.App.Srv().Store().ChannelGuard().GetForChannel(rctx, channelID) + require.NoError(t, err) + require.Len(t, guards, 1) + assert.Equal(t, expectedID, guards[0].PluginId, "PluginId must be normalized to lowercase before reaching the store") + + require.Nil(t, api.UnregisterChannelGuard(channelID)) + guards, err = th.App.Srv().Store().ChannelGuard().GetForChannel(rctx, channelID) + require.NoError(t, err) + assert.Empty(t, guards, "Unregister with the same mixed-case id must hit the lowercased row") +} + +func TestChannelGuardEmptyChannelIDRejected(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t) + + rctx := request.EmptyContext(th.App.Srv().Log()) + appErr := th.App.RegisterChannelGuard(rctx, "", "com.example.plugin") + require.NotNil(t, appErr) + assert.Equal(t, "app.channel_guard.register.empty_channel.app_error", appErr.Id) + assert.Equal(t, 400, appErr.StatusCode) + + appErr = th.App.UnregisterChannelGuard(rctx, "", "com.example.plugin") + require.NotNil(t, appErr) + assert.Equal(t, "app.channel_guard.unregister.empty_channel.app_error", appErr.Id) + assert.Equal(t, 400, appErr.StatusCode) +} + +// TestUnregisterChannelGuardWarnsOnNoMatchingRow verifies that calling UnregisterChannelGuard with +// a pluginID that has no claim on the channel returns nil (no error) and leaves the existing guard +// row untouched. The Warn log emitted when rowsAffected==0 is operator-facing and is not asserted +// here; the behavioral contract (nil return + row unchanged) is the check. +func TestUnregisterChannelGuardWarnsOnNoMatchingRow(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + channelID := th.BasicChannel.Id + pluginA := "com.example.plugin-a" + pluginB := "com.example.plugin-b" + + rctx := request.EmptyContext(th.App.Srv().Log()) + + // Register pluginA's guard on the channel. + require.Nil(t, th.App.RegisterChannelGuard(rctx, channelID, pluginA)) + + // Unregister with a different pluginID — must return nil (no-op). + appErr := th.App.UnregisterChannelGuard(rctx, channelID, pluginB) + require.Nil(t, appErr, "cross-plugin Unregister must return nil") + + // pluginA's guard row must be untouched. + guards, err := th.App.Srv().Store().ChannelGuard().GetForChannel(rctx, channelID) + require.NoError(t, err) + require.Len(t, guards, 1, "pluginA guard row must remain after cross-plugin Unregister") + assert.Equal(t, pluginA, guards[0].PluginId) +} + +// failingGuardStore wraps a real ChannelGuardStore but forces GetAll to error, +// so tests can exercise reload-failure branches deterministically. +type failingGuardStore struct { + store.ChannelGuardStore + err error +} + +func (f *failingGuardStore) GetAll(rctx request.CTX) ([]*store.ChannelGuard, error) { + return nil, f.err +} + +// guardFailingStoreWrapper decorates a real Store, swapping ChannelGuard() for +// a failing implementation. All other store calls pass through to the embedded +// Store so the rest of the app stays functional. +type guardFailingStoreWrapper struct { + store.Store + failing *failingGuardStore +} + +func (w *guardFailingStoreWrapper) ChannelGuard() store.ChannelGuardStore { + return w.failing +} + +func TestChannelGuardCacheClusterInvalidationHandlesStoreFailure(t *testing.T) { + // No t.Parallel(): mutates package-level guardCacheRetryInitialDelay. + originalInitial := guardCacheRetryInitialDelay + guardCacheRetryInitialDelay = 30 * time.Second + t.Cleanup(func() { guardCacheRetryInitialDelay = originalInitial }) + + th := Setup(t) + ch := th.App.Channels() + + // Pre-populate the cache with a known row by writing through the real store + // then doing a successful reload. + channelID := model.NewId() + pluginID := "com.example.cluster-fail-test" + rctx := request.EmptyContext(th.App.Srv().Log()) + require.NoError(t, th.App.Srv().Store().ChannelGuard().Save(rctx, &store.ChannelGuard{ + ChannelId: channelID, + PluginId: pluginID, + CreatedAt: 1, + })) + require.NoError(t, ch.reloadGuardCache(rctx, th.App.Srv().Store())) + require.Len(t, ch.getGuardsForChannel(channelID), 1, "precondition: cache should hold the seeded row") + + // Swap in a wrapped store that fails on GetAll. + originalStore := th.App.Srv().Store() + wrapped := &guardFailingStoreWrapper{ + Store: originalStore, + failing: &failingGuardStore{ChannelGuardStore: originalStore.ChannelGuard(), err: assert.AnError}, + } + th.App.Srv().SetStore(wrapped) + t.Cleanup(func() { th.App.Srv().SetStore(originalStore) }) + + // Sanity: confirm the wrapped store actually fails, otherwise the test is meaningless. + _, err := th.App.Srv().Store().ChannelGuard().GetAll(rctx) + require.Error(t, err, "test wrapper must surface GetAll failure") + + // Calling the handler with a failing store must: + // - not panic + // - leave the existing cache untouched + // - schedule a retry (atomic.Bool flips to true) + require.NotPanics(t, func() { + ch.clusterInvalidateGuardCacheHandler(&model.ClusterMessage{ + Event: clusterEventInvalidateChannelGuardCache, + }) + }) + + assert.Len(t, ch.getGuardsForChannel(channelID), 1, "cache must be unchanged when reload fails") + assert.True(t, ch.guardCacheRetryInFlight.Load(), "failed reload from cluster handler must schedule a retry") +} + +// TestScheduleGuardCacheReloadRetrySingleFlight verifies that concurrent calls to +// scheduleGuardCacheReloadRetry collapse to a single in-flight retry goroutine. The retry goroutine +// is parked in its initial timer wait by shrinking nothing — instead we override the initial delay +// to a very long value so the test window stays inside the timer wait, then verify the second call +// returns false (no new goroutine scheduled). Test cleanup tears down the server which closes +// interruptQuitChan and lets the parked goroutine exit cleanly. No t.Parallel() because it mutates +// a package-level var. +func TestScheduleGuardCacheReloadRetrySingleFlight(t *testing.T) { + originalInitial := guardCacheRetryInitialDelay + guardCacheRetryInitialDelay = 30 * time.Second + t.Cleanup(func() { guardCacheRetryInitialDelay = originalInitial }) + + th := Setup(t) + + ch := th.App.Channels() + require.True(t, ch.scheduleGuardCacheReloadRetry(), "first call should schedule a retry") + require.False(t, ch.scheduleGuardCacheReloadRetry(), "second call should be a no-op while one is in flight") + require.False(t, ch.scheduleGuardCacheReloadRetry(), "additional concurrent calls should also be no-ops") +} + +func TestChannelGuardInvalidChannelIDRejected(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t) + + rctx := request.EmptyContext(th.App.Srv().Log()) + appErr := th.App.RegisterChannelGuard(rctx, "not-a-real-id", "com.example.plugin") + require.NotNil(t, appErr) + assert.Equal(t, "app.channel_guard.invalid_channel.app_error", appErr.Id) + assert.Equal(t, 400, appErr.StatusCode) + + appErr = th.App.UnregisterChannelGuard(rctx, "not-a-real-id", "com.example.plugin") + require.NotNil(t, appErr) + assert.Equal(t, "app.channel_guard.invalid_channel.app_error", appErr.Id) + assert.Equal(t, 400, appErr.StatusCode) +} diff --git a/server/channels/app/channel_join_request.go b/server/channels/app/channel_join_request.go new file mode 100644 index 000000000000..05b8c98e5034 --- /dev/null +++ b/server/channels/app/channel_join_request.go @@ -0,0 +1,447 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package app + +import ( + "encoding/json" + "errors" + "net/http" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/shared/mlog" + "github.com/mattermost/mattermost/server/public/shared/request" + "github.com/mattermost/mattermost/server/v8/channels/store" +) + +// channelJoinRequestPaginationDefaultPerPage matches the public /api/v4 default +// for paginated endpoints. +const channelJoinRequestPaginationDefaultPerPage = 60 + +// channelJoinRequestPaginationMaxPerPage caps a single page's size; mirrors the +// 200 cap shared by other public list endpoints. +const channelJoinRequestPaginationMaxPerPage = 200 + +// requestJoinChannelGuard validates that a user is allowed to express interest +// in joining `channel` and returns a sanitized result for `channel`. Callers +// are expected to look up `channel` via the store before calling this helper. +func (a *App) requestJoinChannelGuard(rctx request.CTX, user *model.User, channel *model.Channel) *model.AppError { + if channel == nil { + return model.NewAppError("RequestJoinChannel", "app.channel.get.existing.app_error", nil, "", http.StatusNotFound) + } + + if channel.DeleteAt != 0 { + return model.NewAppError("RequestJoinChannel", "api.channel.discoverable_join_request.archived.app_error", nil, "channel_id="+channel.Id, http.StatusBadRequest) + } + + if channel.Type != model.ChannelTypePrivate { + return model.NewAppError("RequestJoinChannel", "api.channel.discoverable_join_request.not_private.app_error", nil, "channel_id="+channel.Id, http.StatusBadRequest) + } + + if !channel.Discoverable { + return model.NewAppError("RequestJoinChannel", "api.channel.discoverable_join_request.not_discoverable.app_error", nil, "channel_id="+channel.Id, http.StatusForbidden) + } + + // Shared channels join through their own remote-cluster sync mechanism. + if channel.IsShared() { + return model.NewAppError("RequestJoinChannel", "api.channel.discoverable_join_request.shared.app_error", nil, "channel_id="+channel.Id, http.StatusBadRequest) + } + + if user.IsGuest() { + return model.NewAppError("RequestJoinChannel", "api.channel.discoverable_join_request.guest.app_error", nil, "user_id="+user.Id, http.StatusForbidden) + } + + if user.DeleteAt != 0 { + return model.NewAppError("RequestJoinChannel", "app.channel.add_member.deleted_user.app_error", nil, "", http.StatusForbidden) + } + + return nil +} + +// RequestJoinChannel decides between an immediate ABAC-gated auto-join and an +// asynchronous request-to-join row. +// +// Returns the persisted ChannelJoinRequest when the user must wait for an +// admin review, or nil when the user was added directly to the channel (the +// caller can detect this via the `joined` return value). +func (a *App) RequestJoinChannel(rctx request.CTX, userID, channelID, message string) (joined bool, req *model.ChannelJoinRequest, appErr *model.AppError) { + user, appErr := a.GetUser(userID) + if appErr != nil { + return false, nil, appErr + } + + channel, appErr := a.GetChannel(rctx, channelID) + if appErr != nil { + return false, nil, appErr + } + + if guardErr := a.requestJoinChannelGuard(rctx, user, channel); guardErr != nil { + return false, nil, guardErr + } + + _, memberErr := a.Srv().Store().Channel().GetMember(rctx, channel.Id, user.Id) + if memberErr == nil { + return false, nil, model.NewAppError("RequestJoinChannel", "api.channel.discoverable_join_request.already_member.app_error", nil, "channel_id="+channel.Id, http.StatusBadRequest) + } + var nfErr *store.ErrNotFound + if !errors.As(memberErr, &nfErr) { + return false, nil, model.NewAppError("RequestJoinChannel", "app.channel.get_member.app_error", nil, "", http.StatusInternalServerError).Wrap(memberErr) + } + + enforced, appErr := a.ChannelAccessControlled(rctx, channel.Id) + if appErr != nil { + return false, nil, appErr + } + + // ABAC gate: when an active policy is attached and the user qualifies, add + // the member directly. AddChannelMember re-runs the PDP gate inside + // addUserToChannel, so a denial here is authoritative; a non-allow result + // falls through to the request-row path below ONLY when there is no policy. + if enforced { + decision, evalErr := a.evaluateChannelMembership(rctx, user, channel) + if evalErr != nil { + return false, nil, evalErr + } + if !decision { + return false, nil, model.NewAppError("RequestJoinChannel", "api.channel.discoverable_join_request.policy_denied.app_error", nil, "channel_id="+channel.Id, http.StatusForbidden) + } + + if _, err := a.AddChannelMember(rctx, user.Id, channel, ChannelMemberOpts{UserRequestorID: user.Id}); err != nil { + return false, nil, err + } + return true, nil, nil + } + + pending := &model.ChannelJoinRequest{ + ChannelId: channel.Id, + UserId: user.Id, + Message: message, + } + + saved, err := a.Srv().Store().ChannelJoinRequest().Save(pending) + if err != nil { + var conflict *store.ErrConflict + if errors.As(err, &conflict) { + existing, getErr := a.Srv().Store().ChannelJoinRequest().GetPendingForChannelAndUser(channel.Id, user.Id) + if getErr == nil { + return false, existing, nil + } + return false, nil, model.NewAppError("RequestJoinChannel", "api.channel.discoverable_join_request.duplicate.app_error", nil, "channel_id="+channel.Id, http.StatusConflict) + } + if appErr, ok := err.(*model.AppError); ok { + return false, nil, appErr + } + return false, nil, model.NewAppError("RequestJoinChannel", "app.channel.join_request.save.app_error", nil, "", http.StatusInternalServerError).Wrap(err) + } + + a.broadcastChannelJoinRequestCreated(rctx, channel, saved) + return false, saved, nil +} + +// WithdrawChannelJoinRequest flips a pending request the calling user owns to +// the withdrawn state. Non-owners receive a 404 (no oracle on existence) and +// already-terminal rows return 409. +func (a *App) WithdrawChannelJoinRequest(rctx request.CTX, requestID, userID string) (*model.ChannelJoinRequest, *model.AppError) { + current, err := a.Srv().Store().ChannelJoinRequest().Get(requestID) + if err != nil { + var nfErr *store.ErrNotFound + if errors.As(err, &nfErr) { + return nil, model.NewAppError("WithdrawChannelJoinRequest", "app.channel.join_request.not_found.app_error", nil, "request_id="+requestID, http.StatusNotFound) + } + return nil, model.NewAppError("WithdrawChannelJoinRequest", "app.channel.join_request.get.app_error", nil, "", http.StatusInternalServerError).Wrap(err) + } + + if current.UserId != userID { + // Hide the row from non-owners by returning the same not-found + // response. The reviewer flow uses different endpoints. + return nil, model.NewAppError("WithdrawChannelJoinRequest", "app.channel.join_request.not_found.app_error", nil, "request_id="+requestID, http.StatusNotFound) + } + + if current.Status != model.ChannelJoinRequestStatusPending { + return nil, model.NewAppError("WithdrawChannelJoinRequest", "api.channel.discoverable_join_request.not_pending.app_error", nil, "request_id="+requestID, http.StatusConflict) + } + + current.Status = model.ChannelJoinRequestStatusWithdrawn + current.Message = "" + + updated, err := a.Srv().Store().ChannelJoinRequest().Update(current) + if err != nil { + if appErr, ok := err.(*model.AppError); ok { + return nil, appErr + } + return nil, model.NewAppError("WithdrawChannelJoinRequest", "app.channel.join_request.update.app_error", nil, "", http.StatusInternalServerError).Wrap(err) + } + + channel, channelErr := a.GetChannel(rctx, updated.ChannelId) + if channelErr != nil { + // Channel went away mid-flight — still report the update; we just + // can't broadcast to the admin queue. + rctx.Logger().Warn("WithdrawChannelJoinRequest: failed to load channel for broadcast", mlog.String("channel_id", updated.ChannelId), mlog.Err(channelErr)) + return updated, nil + } + a.broadcastChannelJoinRequestUpdated(rctx, channel, updated) + return updated, nil +} + +// GetMyChannelJoinRequest returns the calling user's active pending request for +// `channelID`, or nil if none exists. It never returns an error for a missing +// row — that's the non-pending state and is expected. +func (a *App) GetMyChannelJoinRequest(rctx request.CTX, userID, channelID string) (*model.ChannelJoinRequest, *model.AppError) { + req, err := a.Srv().Store().ChannelJoinRequest().GetPendingForChannelAndUser(channelID, userID) + if err != nil { + var nfErr *store.ErrNotFound + if errors.As(err, &nfErr) { + return nil, nil + } + return nil, model.NewAppError("GetMyChannelJoinRequest", "app.channel.join_request.get.app_error", nil, "", http.StatusInternalServerError).Wrap(err) + } + return req, nil +} + +// GetMyChannelJoinRequests lists the calling user's join requests across all +// channels. The "My Pending Requests" tab filters by `Status="pending"` (the +// default when opts.Status is empty). +func (a *App) GetMyChannelJoinRequests(rctx request.CTX, userID string, opts model.GetChannelJoinRequestsOpts) (*model.ChannelJoinRequestList, *model.AppError) { + opts = sanitizeJoinRequestListOpts(opts) + rows, total, err := a.Srv().Store().ChannelJoinRequest().GetForUser(userID, opts) + if err != nil { + return nil, model.NewAppError("GetMyChannelJoinRequests", "app.channel.join_request.get.app_error", nil, "", http.StatusInternalServerError).Wrap(err) + } + return &model.ChannelJoinRequestList{Requests: rows, TotalCount: total}, nil +} + +// GetChannelJoinRequests lists the join requests targeting `channelID` for the +// admin queue UI. The visibility check is performed by the API layer via the +// PermissionManageChannelJoinRequests permission. +func (a *App) GetChannelJoinRequests(rctx request.CTX, channelID string, opts model.GetChannelJoinRequestsOpts) (*model.ChannelJoinRequestList, *model.AppError) { + opts = sanitizeJoinRequestListOpts(opts) + rows, total, err := a.Srv().Store().ChannelJoinRequest().GetForChannel(channelID, opts) + if err != nil { + return nil, model.NewAppError("GetChannelJoinRequests", "app.channel.join_request.get.app_error", nil, "", http.StatusInternalServerError).Wrap(err) + } + return &model.ChannelJoinRequestList{Requests: rows, TotalCount: total}, nil +} + +// CountPendingChannelJoinRequests returns the number of pending join requests +// for `channelID`, used by the channel-header badge. +func (a *App) CountPendingChannelJoinRequests(rctx request.CTX, channelID string) (int64, *model.AppError) { + count, err := a.Srv().Store().ChannelJoinRequest().CountPending(channelID) + if err != nil { + return 0, model.NewAppError("CountPendingChannelJoinRequests", "app.channel.join_request.get.app_error", nil, "", http.StatusInternalServerError).Wrap(err) + } + return count, nil +} + +// UpdateChannelJoinRequest applies an admin review (approve / deny) to a +// pending request. When approving, the user is added via AddChannelMember so +// the existing PDP gate inside addUserToChannel re-runs — admins cannot bypass +// an active ABAC policy. The store row is only updated after a successful add +// to keep the audit trail consistent. +func (a *App) UpdateChannelJoinRequest(rctx request.CTX, requestID, channelID string, patch *model.ChannelJoinRequestPatch, reviewerID string) (*model.ChannelJoinRequest, *model.AppError) { + if patch == nil { + return nil, model.NewAppError("UpdateChannelJoinRequest", "api.channel.discoverable_join_request.invalid_patch.app_error", nil, "", http.StatusBadRequest) + } + + switch patch.Status { + case model.ChannelJoinRequestStatusApproved, model.ChannelJoinRequestStatusDenied: + default: + return nil, model.NewAppError("UpdateChannelJoinRequest", "api.channel.discoverable_join_request.invalid_patch.app_error", nil, "status="+patch.Status, http.StatusBadRequest) + } + + current, err := a.Srv().Store().ChannelJoinRequest().Get(requestID) + if err != nil { + var nfErr *store.ErrNotFound + if errors.As(err, &nfErr) { + return nil, model.NewAppError("UpdateChannelJoinRequest", "app.channel.join_request.not_found.app_error", nil, "request_id="+requestID, http.StatusNotFound) + } + return nil, model.NewAppError("UpdateChannelJoinRequest", "app.channel.join_request.get.app_error", nil, "", http.StatusInternalServerError).Wrap(err) + } + + // Defense in depth: refuse cross-channel updates so a forged request id + // can't be reviewed against a channel the admin happens to own. + if current.ChannelId != channelID { + return nil, model.NewAppError("UpdateChannelJoinRequest", "app.channel.join_request.not_found.app_error", nil, "request_id="+requestID, http.StatusNotFound) + } + + if current.Status != model.ChannelJoinRequestStatusPending { + return nil, model.NewAppError("UpdateChannelJoinRequest", "api.channel.discoverable_join_request.not_pending.app_error", nil, "request_id="+requestID, http.StatusConflict) + } + + channel, appErr := a.GetChannel(rctx, current.ChannelId) + if appErr != nil { + return nil, appErr + } + + if patch.Status == model.ChannelJoinRequestStatusApproved { + if _, err := a.AddChannelMember(rctx, current.UserId, channel, ChannelMemberOpts{UserRequestorID: reviewerID}); err != nil { + return nil, err + } + } + + current.Status = patch.Status + current.ReviewedBy = reviewerID + current.ReviewedAt = model.GetMillis() + current.DenialReason = "" + if patch.Status == model.ChannelJoinRequestStatusDenied && patch.DenialReason != nil { + current.DenialReason = *patch.DenialReason + } + // Drop the original message from the response; it served its purpose + // during review and keeping it would leak free-text into the audit trail. + current.Message = "" + + updated, err := a.Srv().Store().ChannelJoinRequest().Update(current) + if err != nil { + if appErr, ok := err.(*model.AppError); ok { + return nil, appErr + } + return nil, model.NewAppError("UpdateChannelJoinRequest", "app.channel.join_request.update.app_error", nil, "", http.StatusInternalServerError).Wrap(err) + } + + a.broadcastChannelJoinRequestUpdated(rctx, channel, updated) + return updated, nil +} + +// sanitizeJoinRequestListOpts clamps user-provided pagination + status options +// so the store sees a normalized request. +func sanitizeJoinRequestListOpts(opts model.GetChannelJoinRequestsOpts) model.GetChannelJoinRequestsOpts { + if opts.Status == "" { + opts.Status = model.ChannelJoinRequestStatusPending + } else if !model.IsValidChannelJoinRequestStatus(opts.Status) { + opts.Status = model.ChannelJoinRequestStatusPending + } + if opts.Page < 0 { + opts.Page = 0 + } + if opts.PerPage <= 0 { + opts.PerPage = channelJoinRequestPaginationDefaultPerPage + } else if opts.PerPage > channelJoinRequestPaginationMaxPerPage { + opts.PerPage = channelJoinRequestPaginationMaxPerPage + } + return opts +} + +// evaluateChannelMembership runs the access-control PDP for `user` against the +// `membership` action on `channel`, returning the boolean decision. Errors +// from the PDP are returned to callers so they can choose between the +// "channel is invisible" (visibility filter) or "channel cannot be joined" +// (request flow) fail-secure semantics. Callers must have already verified +// that `channel.PolicyEnforced` is true before invoking the PDP. +func (a *App) evaluateChannelMembership(rctx request.CTX, user *model.User, channel *model.Channel) (bool, *model.AppError) { + acs := a.Srv().Channels().AccessControl + if acs == nil { + // No ABAC service → fail-secure. The channel acts as if the user did + // not satisfy the policy. + return false, nil + } + + subject, appErr := a.BuildAccessControlSubject(rctx, user.Id, user.Roles, channel.Id) + if appErr != nil { + return false, appErr + } + + decision, evalErr := acs.AccessEvaluation(rctx, model.AccessRequest{ + Subject: *subject, + Resource: model.Resource{ + Type: model.AccessControlPolicyTypeChannel, + ID: channel.Id, + }, + Action: "membership", + }) + if evalErr != nil { + return false, evalErr + } + return decision.Decision, nil +} + +// channelAdminUserIDs returns the user ids of channel members with the +// scheme-admin role on `channelID`. Used to scope WS broadcasts of join-request +// events to the queue audience. Failures bubble up because broadcasting to no +// one would silently break the admin UI. +func (a *App) channelAdminUserIDs(rctx request.CTX, channelID string) ([]string, *model.AppError) { + const channelMembersPageSize = 200 + + admins := []string{} + page := 0 + for { + members, err := a.Srv().Store().Channel().GetMembers(model.ChannelMembersGetOptions{ + ChannelID: channelID, + Offset: page * channelMembersPageSize, + Limit: channelMembersPageSize, + }) + if err != nil { + return nil, model.NewAppError("channelAdminUserIDs", "app.channel.get_members.app_error", nil, "", http.StatusInternalServerError).Wrap(err) + } + for _, m := range members { + if m.SchemeAdmin { + admins = append(admins, m.UserId) + } + } + if len(members) < channelMembersPageSize { + break + } + page++ + } + return admins, nil +} + +// broadcastChannelJoinRequestCreated fires a channel_join_request_created event +// scoped to the channel admin set, using the OnlyChannelAdmins broadcast hook +// to filter out non-admin members the channel-id broadcast would otherwise +// reach. +func (a *App) broadcastChannelJoinRequestCreated(rctx request.CTX, channel *model.Channel, req *model.ChannelJoinRequest) { + a.publishChannelJoinRequestEvent(rctx, channel, req, model.WebsocketEventChannelJoinRequestCreated, true /* adminsOnly */) +} + +// broadcastChannelJoinRequestUpdated fires a channel_join_request_updated event +// to the channel admin set + the requesting user (so their My Pending Requests +// list reacts in real-time). +func (a *App) broadcastChannelJoinRequestUpdated(rctx request.CTX, channel *model.Channel, req *model.ChannelJoinRequest) { + // Send a dedicated copy to the requester so an offline-but-then-reconnected + // requester gets their own row update even when they are not a channel + // member yet (the channel-id broadcast wouldn't reach them otherwise). + if req.UserId != "" { + userMessage := model.NewWebSocketEvent(model.WebsocketEventChannelJoinRequestUpdated, "", "", req.UserId, nil, "") + userMessage.Add("request", marshalChannelJoinRequest(rctx, req)) + userMessage.Add("channel_id", channel.Id) + a.Publish(userMessage) + } + a.publishChannelJoinRequestEvent(rctx, channel, req, model.WebsocketEventChannelJoinRequestUpdated, true /* adminsOnly */) +} + +func (a *App) publishChannelJoinRequestEvent(rctx request.CTX, channel *model.Channel, req *model.ChannelJoinRequest, event model.WebsocketEventType, adminsOnly bool) { + message := model.NewWebSocketEvent(event, "", channel.Id, "", nil, "") + message.Add("request", marshalChannelJoinRequest(rctx, req)) + message.Add("channel_id", channel.Id) + + if adminsOnly { + admins, appErr := a.channelAdminUserIDs(rctx, channel.Id) + if appErr != nil { + rctx.Logger().Warn("Failed to compute channel admin set for join request broadcast", + mlog.String("channel_id", channel.Id), + mlog.Err(appErr), + ) + return + } + useOnlyChannelAdminsHook(message, admins) + } + a.Publish(message) +} + +// marshalChannelJoinRequest returns the request as a JSON string for the WS +// payload. JSON encoding errors are logged and the payload is delivered as an +// empty string so the event still arrives (clients can tolerate a missing +// request body and refetch). +func marshalChannelJoinRequest(rctx request.CTX, req *model.ChannelJoinRequest) string { + if req == nil { + return "" + } + buf, err := json.Marshal(req) + if err != nil { + rctx.Logger().Warn("Failed to marshal ChannelJoinRequest for WS broadcast", + mlog.String("request_id", req.Id), + mlog.Err(err), + ) + return "" + } + return string(buf) +} diff --git a/server/channels/app/channel_join_request_test.go b/server/channels/app/channel_join_request_test.go new file mode 100644 index 000000000000..6cda6adb46e6 --- /dev/null +++ b/server/channels/app/channel_join_request_test.go @@ -0,0 +1,379 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package app + +import ( + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/mattermost/mattermost/server/public/model" +) + +// withDiscoverableChannelsFlag toggles the FeatureFlag for the duration of a +// test and restores it on cleanup. Feature flags are read-only by default in +// the test config store; flipping SetReadOnlyFF lets the UpdateConfig call +// land. We deliberately do NOT restore SetReadOnlyFF(true) afterward — the +// underlying store is per-test and disposed on cleanup. +func withDiscoverableChannelsFlag(t *testing.T, th *TestHelper, on bool) { + t.Helper() + th.ConfigStore.SetReadOnlyFF(false) + previous := th.App.Config().FeatureFlags.DiscoverableChannels + th.App.UpdateConfig(func(cfg *model.Config) { cfg.FeatureFlags.DiscoverableChannels = on }) + t.Cleanup(func() { + th.App.UpdateConfig(func(cfg *model.Config) { cfg.FeatureFlags.DiscoverableChannels = previous }) + }) +} + +// markDiscoverable flips the channel's discoverable flag in the store via +// PatchChannel so the model invariants run alongside the test scenario. +func markDiscoverable(t *testing.T, th *TestHelper, channel *model.Channel) *model.Channel { + t.Helper() + on := true + patched, err := th.App.PatchChannel(th.Context, channel, &model.ChannelPatch{Discoverable: &on}, th.BasicUser.Id) + require.Nil(t, err) + require.True(t, patched.Discoverable) + return patched +} + +func TestRequestJoinChannel_RejectsNonDiscoverable(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + withDiscoverableChannelsFlag(t, th, true) + + channel := th.CreatePrivateChannel(t, th.BasicTeam) + + other := th.CreateUser(t) + th.LinkUserToTeam(t, other, th.BasicTeam) + + joined, req, appErr := th.App.RequestJoinChannel(th.Context, other.Id, channel.Id, "please") + require.NotNil(t, appErr) + assert.Equal(t, http.StatusForbidden, appErr.StatusCode) + assert.Equal(t, "api.channel.discoverable_join_request.not_discoverable.app_error", appErr.Id) + assert.False(t, joined) + assert.Nil(t, req) +} + +func TestRequestJoinChannel_RejectsExistingMember(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + withDiscoverableChannelsFlag(t, th, true) + + channel := th.CreatePrivateChannel(t, th.BasicTeam) + channel = markDiscoverable(t, th, channel) + + // BasicUser is the channel creator → already a member. + _, _, appErr := th.App.RequestJoinChannel(th.Context, th.BasicUser.Id, channel.Id, "") + require.NotNil(t, appErr) + assert.Equal(t, http.StatusBadRequest, appErr.StatusCode) + assert.Equal(t, "api.channel.discoverable_join_request.already_member.app_error", appErr.Id) +} + +func TestRequestJoinChannel_PendingHappyPath(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + withDiscoverableChannelsFlag(t, th, true) + + channel := th.CreatePrivateChannel(t, th.BasicTeam) + channel = markDiscoverable(t, th, channel) + + other := th.CreateUser(t) + th.LinkUserToTeam(t, other, th.BasicTeam) + + joined, req, appErr := th.App.RequestJoinChannel(th.Context, other.Id, channel.Id, "let me in") + require.Nil(t, appErr) + assert.False(t, joined, "should not auto-join when no policy is enforced") + require.NotNil(t, req) + assert.Equal(t, model.ChannelJoinRequestStatusPending, req.Status) + assert.Equal(t, channel.Id, req.ChannelId) + assert.Equal(t, other.Id, req.UserId) + assert.Equal(t, "let me in", req.Message) + + // Submitting again returns the existing pending row (idempotent on + // partial-unique conflict). + joined, req2, appErr := th.App.RequestJoinChannel(th.Context, other.Id, channel.Id, "again") + require.Nil(t, appErr) + assert.False(t, joined) + require.NotNil(t, req2) + assert.Equal(t, req.Id, req2.Id) +} + +func TestRequestJoinChannel_RejectsGuest(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + withDiscoverableChannelsFlag(t, th, true) + + channel := th.CreatePrivateChannel(t, th.BasicTeam) + channel = markDiscoverable(t, th, channel) + + guest := th.CreateGuest(t) + th.LinkUserToTeam(t, guest, th.BasicTeam) + + _, _, appErr := th.App.RequestJoinChannel(th.Context, guest.Id, channel.Id, "") + require.NotNil(t, appErr) + assert.Equal(t, http.StatusForbidden, appErr.StatusCode) + assert.Equal(t, "api.channel.discoverable_join_request.guest.app_error", appErr.Id) +} + +func TestUpdateChannelJoinRequest_ApproveAddsMember(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + withDiscoverableChannelsFlag(t, th, true) + + channel := th.CreatePrivateChannel(t, th.BasicTeam) + channel = markDiscoverable(t, th, channel) + + other := th.CreateUser(t) + th.LinkUserToTeam(t, other, th.BasicTeam) + + _, req, appErr := th.App.RequestJoinChannel(th.Context, other.Id, channel.Id, "") + require.Nil(t, appErr) + require.NotNil(t, req) + + patch := &model.ChannelJoinRequestPatch{Status: model.ChannelJoinRequestStatusApproved} + updated, appErr := th.App.UpdateChannelJoinRequest(th.Context, req.Id, channel.Id, patch, th.BasicUser.Id) + require.Nil(t, appErr) + assert.Equal(t, model.ChannelJoinRequestStatusApproved, updated.Status) + assert.Equal(t, th.BasicUser.Id, updated.ReviewedBy) + assert.NotZero(t, updated.ReviewedAt) + assert.Empty(t, updated.Message, "message should be redacted from the response after review") + + member, mErr := th.App.GetChannelMember(th.Context, channel.Id, other.Id) + require.Nil(t, mErr) + require.NotNil(t, member) + assert.Equal(t, other.Id, member.UserId) +} + +func TestUpdateChannelJoinRequest_DenyKeepsReason(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + withDiscoverableChannelsFlag(t, th, true) + + channel := th.CreatePrivateChannel(t, th.BasicTeam) + channel = markDiscoverable(t, th, channel) + + other := th.CreateUser(t) + th.LinkUserToTeam(t, other, th.BasicTeam) + _, req, appErr := th.App.RequestJoinChannel(th.Context, other.Id, channel.Id, "please") + require.Nil(t, appErr) + require.NotNil(t, req) + + reason := "team-internal channel" + patch := &model.ChannelJoinRequestPatch{ + Status: model.ChannelJoinRequestStatusDenied, + DenialReason: &reason, + } + updated, appErr := th.App.UpdateChannelJoinRequest(th.Context, req.Id, channel.Id, patch, th.BasicUser.Id) + require.Nil(t, appErr) + assert.Equal(t, model.ChannelJoinRequestStatusDenied, updated.Status) + assert.Equal(t, reason, updated.DenialReason) + + // Member must NOT have been added. + _, mErr := th.App.GetChannelMember(th.Context, channel.Id, other.Id) + require.NotNil(t, mErr) + assert.Equal(t, MissingChannelMemberError, mErr.Id) +} + +func TestUpdateChannelJoinRequest_RejectsCrossChannel(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + withDiscoverableChannelsFlag(t, th, true) + + channelA := markDiscoverable(t, th, th.CreatePrivateChannel(t, th.BasicTeam)) + channelB := markDiscoverable(t, th, th.CreatePrivateChannel(t, th.BasicTeam)) + + other := th.CreateUser(t) + th.LinkUserToTeam(t, other, th.BasicTeam) + _, req, appErr := th.App.RequestJoinChannel(th.Context, other.Id, channelA.Id, "") + require.Nil(t, appErr) + require.NotNil(t, req) + + patch := &model.ChannelJoinRequestPatch{Status: model.ChannelJoinRequestStatusApproved} + _, appErr = th.App.UpdateChannelJoinRequest(th.Context, req.Id, channelB.Id, patch, th.BasicUser.Id) + require.NotNil(t, appErr) + assert.Equal(t, http.StatusNotFound, appErr.StatusCode) +} + +func TestWithdrawChannelJoinRequest_OwnerOnly(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + withDiscoverableChannelsFlag(t, th, true) + + channel := markDiscoverable(t, th, th.CreatePrivateChannel(t, th.BasicTeam)) + + other := th.CreateUser(t) + th.LinkUserToTeam(t, other, th.BasicTeam) + _, req, appErr := th.App.RequestJoinChannel(th.Context, other.Id, channel.Id, "") + require.Nil(t, appErr) + require.NotNil(t, req) + + stranger := th.CreateUser(t) + _, appErr = th.App.WithdrawChannelJoinRequest(th.Context, req.Id, stranger.Id) + require.NotNil(t, appErr) + assert.Equal(t, http.StatusNotFound, appErr.StatusCode) + + updated, appErr := th.App.WithdrawChannelJoinRequest(th.Context, req.Id, other.Id) + require.Nil(t, appErr) + assert.Equal(t, model.ChannelJoinRequestStatusWithdrawn, updated.Status) + + // A second withdrawal is rejected with 409. + _, appErr = th.App.WithdrawChannelJoinRequest(th.Context, req.Id, other.Id) + require.NotNil(t, appErr) + assert.Equal(t, http.StatusConflict, appErr.StatusCode) +} + +func TestGetMyChannelJoinRequests(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + withDiscoverableChannelsFlag(t, th, true) + + channelA := markDiscoverable(t, th, th.CreatePrivateChannel(t, th.BasicTeam)) + channelB := markDiscoverable(t, th, th.CreatePrivateChannel(t, th.BasicTeam)) + + other := th.CreateUser(t) + th.LinkUserToTeam(t, other, th.BasicTeam) + _, _, appErr := th.App.RequestJoinChannel(th.Context, other.Id, channelA.Id, "") + require.Nil(t, appErr) + _, _, appErr = th.App.RequestJoinChannel(th.Context, other.Id, channelB.Id, "") + require.Nil(t, appErr) + + list, appErr := th.App.GetMyChannelJoinRequests(th.Context, other.Id, model.GetChannelJoinRequestsOpts{}) + require.Nil(t, appErr) + require.NotNil(t, list) + assert.EqualValues(t, 2, list.TotalCount) + assert.Len(t, list.Requests, 2) +} + +func TestCountPendingChannelJoinRequests(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + withDiscoverableChannelsFlag(t, th, true) + + channel := markDiscoverable(t, th, th.CreatePrivateChannel(t, th.BasicTeam)) + + other := th.CreateUser(t) + th.LinkUserToTeam(t, other, th.BasicTeam) + _, _, appErr := th.App.RequestJoinChannel(th.Context, other.Id, channel.Id, "") + require.Nil(t, appErr) + + count, appErr := th.App.CountPendingChannelJoinRequests(th.Context, channel.Id) + require.Nil(t, appErr) + assert.EqualValues(t, 1, count) +} + +func TestUpdateChannelPrivacy_CancelsPendingRequestsOnConvertToPublic(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + withDiscoverableChannelsFlag(t, th, true) + + channel := markDiscoverable(t, th, th.CreatePrivateChannel(t, th.BasicTeam)) + + other := th.CreateUser(t) + th.LinkUserToTeam(t, other, th.BasicTeam) + _, req, appErr := th.App.RequestJoinChannel(th.Context, other.Id, channel.Id, "") + require.Nil(t, appErr) + require.NotNil(t, req) + + channel.Type = model.ChannelTypeOpen + converted, appErr := th.App.UpdateChannelPrivacy(th.Context, channel, th.BasicUser) + require.Nil(t, appErr) + + // Discoverable must be reset on convert-to-public — the model invariant + // (Channel.IsValid) rejects (type=O, discoverable=true), so leaving it + // true would also break the next channel save. + assert.False(t, converted.Discoverable, "Discoverable must be reset to false after convert-to-public") + persisted, getErr := th.App.GetChannel(th.Context, channel.Id) + require.Nil(t, getErr) + assert.False(t, persisted.Discoverable, "Discoverable must be persisted as false after convert-to-public") + + // The cancellation side-effect is dispatched on a goroutine; poll for + // the withdrawn state instead of sleeping. + require.Eventually(t, func() bool { + row, err := th.App.Srv().Store().ChannelJoinRequest().Get(req.Id) + if err != nil { + return false + } + return row.Status == model.ChannelJoinRequestStatusWithdrawn + }, 2*time.Second, 50*time.Millisecond) +} + +func TestIsDiscoverableSelfAddBlocked(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + withDiscoverableChannelsFlag(t, th, true) + + channel := markDiscoverable(t, th, th.CreatePrivateChannel(t, th.BasicTeam)) + + other := th.CreateUser(t) + assert.True(t, th.App.IsDiscoverableSelfAddBlocked(th.Context, channel, other.Id, other.Id), "self-add to discoverable + no-policy private must be blocked") + assert.False(t, th.App.IsDiscoverableSelfAddBlocked(th.Context, channel, th.BasicUser.Id, other.Id), "admin invite must not be blocked") + + // Toggle off the flag → guard is inert. + withDiscoverableChannelsFlag(t, th, false) + assert.False(t, th.App.IsDiscoverableSelfAddBlocked(th.Context, channel, other.Id, other.Id)) +} + +func TestFilterDiscoverableChannelsByPolicy_FlagOff(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + // Flag off → filter is a no-op even when channels look discoverable. + + channel := markDiscoverableInMemory(t, th.CreatePrivateChannel(t, th.BasicTeam)) + channel.PolicyEnforced = true + out, appErr := th.App.FilterDiscoverableChannelsByPolicy(th.Context, []*model.Channel{channel}, th.BasicUser2.Id) + require.Nil(t, appErr) + require.Len(t, out, 1) +} + +func TestFilterDiscoverableChannelsByPolicy_NoPolicyPasses(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + withDiscoverableChannelsFlag(t, th, true) + + channel := markDiscoverableInMemory(t, th.CreatePrivateChannel(t, th.BasicTeam)) + out, appErr := th.App.FilterDiscoverableChannelsByPolicy(th.Context, []*model.Channel{channel}, th.BasicUser2.Id) + require.Nil(t, appErr) + require.Len(t, out, 1, "no-policy discoverable channels are visible without ABAC evaluation") +} + +func TestFilterDiscoverableChannelsByPolicy_PolicyEnforcedFailSecure(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + withDiscoverableChannelsFlag(t, th, true) + + // PolicyEnforced + Discoverable + no AccessControl service wired ⇒ hidden. + channel := markDiscoverableInMemory(t, th.CreatePrivateChannel(t, th.BasicTeam)) + channel.PolicyEnforced = true + + require.Nil(t, th.App.Srv().Channels().AccessControl, "test fixture must not have ABAC wired") + + out, appErr := th.App.FilterDiscoverableChannelsByPolicy(th.Context, []*model.Channel{channel}, th.BasicUser2.Id) + require.Nil(t, appErr) + assert.Len(t, out, 0, "fail-secure must hide policy-enforced channels when ABAC is unavailable") +} + +func TestFilterDiscoverableChannelsByPolicy_GuestHidden(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + withDiscoverableChannelsFlag(t, th, true) + + channel := markDiscoverableInMemory(t, th.CreatePrivateChannel(t, th.BasicTeam)) + channel.PolicyEnforced = true + + guest := th.CreateGuest(t) + out, appErr := th.App.FilterDiscoverableChannelsByPolicy(th.Context, []*model.Channel{channel}, guest.Id) + require.Nil(t, appErr) + assert.Empty(t, out, "guests must never see discoverable + policy-enforced channels") +} + +// markDiscoverableInMemory is a no-DB helper for visibility filter tests that +// don't care about persistence — they only exercise the in-memory list filter. +func markDiscoverableInMemory(t *testing.T, channel *model.Channel) *model.Channel { + t.Helper() + channel.Discoverable = true + return channel +} diff --git a/server/channels/app/channel_test.go b/server/channels/app/channel_test.go index d4a9e1f6a658..dce2f6e298c1 100644 --- a/server/channels/app/channel_test.go +++ b/server/channels/app/channel_test.go @@ -3641,6 +3641,12 @@ func TestCheckIfChannelIsRestrictedDM(t *testing.T) { func TestUpdateChannel(t *testing.T) { th := Setup(t).InitBasic(t) + t.Run("returns 404 for non-existent channel id", func(t *testing.T) { + _, appErr := th.App.UpdateChannel(th.Context, &model.Channel{Id: model.NewId()}) + require.NotNil(t, appErr) + assert.Equal(t, http.StatusNotFound, appErr.StatusCode) + }) + t.Run("should be able to update banner info", func(t *testing.T) { channel := th.createChannel(t, th.BasicTeam, model.ChannelTypeOpen) diff --git a/server/channels/app/channels.go b/server/channels/app/channels.go index eaa21d4ccd3f..07754530bd4c 100644 --- a/server/channels/app/channels.go +++ b/server/channels/app/channels.go @@ -10,6 +10,7 @@ import ( "runtime" "strings" "sync" + "sync/atomic" "syscall" "time" @@ -50,6 +51,13 @@ type Channels struct { pluginConfigListenerID string pluginClusterLeaderListenerID string + // guardCache caches ChannelGuards rows by ChannelId -> []*store.ChannelGuard. + guardCache atomic.Pointer[sync.Map] + + // guardCacheRetryInFlight collapses concurrent reload-failure retries to a single goroutine. + // See scheduleGuardCacheReloadRetry. + guardCacheRetryInFlight atomic.Bool + imageProxy *imageproxy.ImageProxy agentsBridge AgentsBridge @@ -107,6 +115,7 @@ func NewChannels(s *Server) (*Channels, error) { cfgSvc: s.Platform(), interruptQuitChan: make(chan struct{}), } + ch.guardCache.Store(&sync.Map{}) if s.agentsBridgeOverride != nil { ch.agentsBridge = s.agentsBridgeOverride @@ -231,6 +240,15 @@ func NewChannels(s *Server) (*Channels, error) { pluginsRoute.HandleFunc("/public/{public_file:.*}", ch.ServePluginPublicRequest) pluginsRoute.HandleFunc("/{anything:.*}", ch.ServePluginRequest) + if err := ch.reloadGuardCache(request.EmptyContext(s.Log()), s.Store()); err != nil { + s.Log().Warn( + "Failed to load channel guard cache at startup; retry scheduled", + mlog.Bool("clustered", s.platform.Cluster() != nil), + mlog.Err(err), + ) + ch.scheduleGuardCacheReloadRetry() + } + return ch, nil } @@ -325,6 +343,14 @@ func (ch *Channels) RunMultiHook(hookRunnerFunc func(hooks plugin.Hooks, manifes } } +// RunMultiHookExcluding is like RunMultiHook but skips plugins whose IDs appear in excludePluginIDs. +// Fail-open semantics are preserved. +func (ch *Channels) RunMultiHookExcluding(excludePluginIDs []string, hookRunnerFunc func(plugin.Hooks, *model.Manifest) bool, hookId int) { + if env := ch.GetPluginsEnvironment(); env != nil { + env.RunMultiPluginHookExcluding(excludePluginIDs, hookRunnerFunc, hookId) + } +} + // RunMultiHookWithRPCErr dispatches a hook closure across active plugins, surfacing RPC transport // errors. Returns nil in two cases that callers must distinguish themselves: (a) the plugin // environment is unavailable (plugins disabled, or not yet initialized), so the closure was never @@ -351,3 +377,13 @@ func (ch *Channels) HooksForPlugin(id string) (plugin.Hooks, error) { return hooks, nil } + +// HooksForPluginWithRPCErr returns the full *WithRPCErr hook surface for the named plugin. +// Returns an error if the plugin environment is unavailable, the plugin is not found, or not active. +func (ch *Channels) HooksForPluginWithRPCErr(id string) (plugin.HooksWithRPCErr, error) { + env := ch.GetPluginsEnvironment() + if env == nil { + return nil, errors.New("plugin environment not available") + } + return env.HooksForPluginWithRPCErr(id) +} diff --git a/server/channels/app/cluster_handlers.go b/server/channels/app/cluster_handlers.go index 720ce9ea0782..5a4a17faeea5 100644 --- a/server/channels/app/cluster_handlers.go +++ b/server/channels/app/cluster_handlers.go @@ -62,6 +62,7 @@ func (s *Server) registerClusterHandlers() { s.platform.RegisterClusterMessageHandler(model.ClusterEventInstallPlugin, s.clusterInstallPluginHandler) s.platform.RegisterClusterMessageHandler(model.ClusterEventRemovePlugin, s.clusterRemovePluginHandler) s.platform.RegisterClusterMessageHandler(model.ClusterEventPluginEvent, s.clusterPluginEventHandler) + s.platform.RegisterClusterMessageHandler(clusterEventInvalidateChannelGuardCache, s.Channels().clusterInvalidateGuardCacheHandler) s.platform.RegisterClusterHandlers() } diff --git a/server/channels/app/draft.go b/server/channels/app/draft.go index 275a718f544a..05d76184efdf 100644 --- a/server/channels/app/draft.go +++ b/server/channels/app/draft.go @@ -10,6 +10,7 @@ import ( "net/http" "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin" "github.com/mattermost/mattermost/server/public/shared/mlog" "github.com/mattermost/mattermost/server/public/shared/request" "github.com/mattermost/mattermost/server/v8/channels/store" @@ -72,9 +73,27 @@ func (a *App) UpsertDraft(rctx request.CTX, draft *model.Draft, connectionID str if deleteErr != nil { return nil, model.NewAppError("CreateDraft", "app.draft.save.app_error", nil, "", http.StatusInternalServerError).Wrap(deleteErr) } + rctx.Logger().Debug("Draft deleted via empty-message upsert", mlog.String("user_id", draft.UserId), mlog.String("channel_id", draft.ChannelId), mlog.String("root_id", draft.RootId)) return nil, nil } + var rejectionReason string + pluginContext := pluginContext(rctx) + a.ch.RunMultiHook(func(hooks plugin.Hooks, _ *model.Manifest) bool { + replacement, reason := hooks.DraftWillBeUpserted(pluginContext, draft) + if reason != "" { + rejectionReason = reason + return false + } + if replacement != nil { + draft = replacement + } + return true + }, plugin.DraftWillBeUpsertedID) + if rejectionReason != "" { + return nil, model.NewAppError("UpsertDraft", "app.draft.upsert.rejected_by_plugin", map[string]any{"Reason": rejectionReason}, "", http.StatusBadRequest) + } + dt, nErr := a.Srv().Store().Draft().Upsert(draft) if nErr != nil { return nil, model.NewAppError("CreateDraft", "app.draft.save.app_error", nil, "", http.StatusInternalServerError).Wrap(nErr) diff --git a/server/channels/app/email/mocks/ServiceInterface.go b/server/channels/app/email/mocks/ServiceInterface.go index 14517938d319..0e790ce309c1 100644 --- a/server/channels/app/email/mocks/ServiceInterface.go +++ b/server/channels/app/email/mocks/ServiceInterface.go @@ -7,17 +7,12 @@ package mocks import ( io "io" - i18n "github.com/mattermost/mattermost/server/public/shared/i18n" - - mock "github.com/stretchr/testify/mock" - model "github.com/mattermost/mattermost/server/public/model" - + i18n "github.com/mattermost/mattermost/server/public/shared/i18n" request "github.com/mattermost/mattermost/server/public/shared/request" - store "github.com/mattermost/mattermost/server/v8/channels/store" - templates "github.com/mattermost/mattermost/server/v8/platform/shared/templates" + mock "github.com/stretchr/testify/mock" ) // ServiceInterface is an autogenerated mock type for the ServiceInterface type diff --git a/server/channels/app/guarded_hooks.go b/server/channels/app/guarded_hooks.go new file mode 100644 index 000000000000..05b51a2f9e76 --- /dev/null +++ b/server/channels/app/guarded_hooks.go @@ -0,0 +1,411 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +// Channel-guard dispatch helpers. +// +// Each runGuarded helper implements two-phase plugin dispatch: Phase A fans out to non-guard +// plugins via RunMultiHookExcluding (fail-open, preserving RunMultiHook semantics — when guards is +// empty the exclude list is empty and the iteration is identical to plain RunMultiHook); Phase B +// calls each guard claimant in PluginId-sorted order via the *WithRPCErr companion, and fail-closed +// on transport errors. Phase B's for-range is a no-op when there are no guards, so unguarded +// channels traverse the same single linear flow with zero extra work beyond the Phase A dispatch. +// +// Allow-by-default for non-implementing claimants: a plugin may register a channel guard without +// implementing every guarded hook. When Phase B reaches such a claimant, the *WithRPCErr +// companion's g.implemented[] gate skips the RPC call entirely and returns zero values with +// a nil error. The helper's three guard branches all skip in that case, so the claimant contributes +// nothing, basically: "this plugin had no opinion on this hook." Iteration continues to the next +// claimant. +package app + +import ( + "net/http" + "sort" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/shared/mlog" + "github.com/mattermost/mattermost/server/public/shared/request" + "github.com/mattermost/mattermost/server/v8/channels/store" +) + +// resolveGuards returns the (sorted-by-PluginId) guard slice for channelID along with a +// non-nil rejectErr when the request must fail-close (plugin system disabled, or a specific +// claimant is inactive). The helper picks the right operator-facing log message internally. +// (nil, nil) means the channel is unguarded — Phase A still runs (with no exclusions) and +// Phase B's loop becomes a no-op. (guards, nil) means proceed with two-phase dispatch. +func (a *App) resolveGuards(rctx request.CTX, channelID, callerName string) (guards []*store.ChannelGuard, rejectErr *model.AppError) { + ch := a.Channels() + raw := ch.getGuardsForChannel(channelID) + if len(raw) == 0 { + return nil, nil + } + sorted := append([]*store.ChannelGuard(nil), raw...) + sort.Slice(sorted, func(i, j int) bool { return sorted[i].PluginId < sorted[j].PluginId }) + env := ch.GetPluginsEnvironment() + if env == nil { + // Plugin system disabled in config or not yet initialized, but guards exist for this + // channel. Operator action: flip PluginSettings.Enable on, or remove the guards. + return sorted, logAndErrPluginsDisabled(rctx, channelID, callerName) + } + var inactive []string + for _, g := range sorted { + if !env.IsActive(g.PluginId) { + inactive = append(inactive, g.PluginId) + } + } + if len(inactive) > 0 { + return sorted, logAndErrPluginInactive(rctx, channelID, inactive, callerName) + } + return sorted, nil +} + +// logAndErrPluginInactive emits an operator-facing Error log identifying the specific guard +// plugins that are currently inactive, then returns a generic 503 AppError. A guard plugin +// being down is an operational failure: the request must be rejected, but internal plugin IDs +// do not belong in the user-facing response. Operators read the log to diagnose which plugin +// to recover. +func logAndErrPluginInactive(rctx request.CTX, channelID string, pluginIDs []string, callerName string) *model.AppError { + rctx.Logger().Error("Channel guard rejected operation: claiming plugin is not active", + mlog.String("error_id", "guard_plugin_inactive"), + mlog.String("channel_id", channelID), + mlog.Array("plugin_ids", pluginIDs), + mlog.String("caller", callerName), + ) + return model.NewAppError(callerName, "app.plugin.inactive_guard.app_error", nil, "", http.StatusServiceUnavailable) +} + +// logAndErrPluginsDisabled emits an operator-facing Error log when the plugin system is off +// (PluginSettings.Enable == false or not yet initialized) but guards are still cached for the +// channel. Distinct from logAndErrPluginInactive: the cause is the global plugin switch, not +// a specific plugin failure. Returns the same generic 503 to the user. +func logAndErrPluginsDisabled(rctx request.CTX, channelID, callerName string) *model.AppError { + rctx.Logger().Error("Channel guard rejected operation: plugin system is disabled but guards exist for this channel", + mlog.String("error_id", "plugins_disabled_with_guards"), + mlog.String("channel_id", channelID), + mlog.String("caller", callerName), + ) + return model.NewAppError(callerName, "app.plugin.inactive_guard.app_error", nil, "", http.StatusServiceUnavailable) +} + +func appErrHookFailed(pluginID, callerName string, err error) *model.AppError { + appErr := model.NewAppError(callerName, "app.plugin.guard_hook_failed.app_error", + map[string]any{"PluginID": pluginID}, "", http.StatusServiceUnavailable) + if err != nil { + return appErr.Wrap(err) + } + return appErr +} + +func pluginIDsOf(guards []*store.ChannelGuard) []string { + ids := make([]string, len(guards)) + for i, g := range guards { + ids[i] = g.PluginId + } + return ids +} + +// runGuardedMessageWillBePosted dispatches MessageWillBePosted. Returns the (possibly +// replaced) post, or an AppError on rejection or RPC failure. +func (a *App) runGuardedMessageWillBePosted(rctx request.CTX, post *model.Post) (*model.Post, *model.AppError) { + guards, rejectErr := a.resolveGuards(rctx, post.ChannelId, "createPost") + + // Guard plugin is unavailable — fail-closed (logged with attribution). + if rejectErr != nil { + return nil, rejectErr + } + + var metadata *model.PostMetadata + if post.Metadata != nil { + metadata = post.Metadata.Copy() + } + + // Phase A: fan out to non-guard plugins, fail-open. With empty guards the exclude list is + // empty and behavior is identical to plain RunMultiHook. + var rejectionError *model.AppError + pCtx := pluginContext(rctx) + a.ch.RunMultiHookExcluding(pluginIDsOf(guards), func(hooks plugin.Hooks, _ *model.Manifest) bool { + replacementPost, rejectionReason := hooks.MessageWillBePosted(pCtx, post.ForPlugin()) + if rejectionReason != "" { + id := "Post rejected by plugin. " + rejectionReason + if rejectionReason == plugin.DismissPostError { + id = plugin.DismissPostError + } + rejectionError = model.NewAppError("createPost", id, nil, "", http.StatusBadRequest) + return false + } + if replacementPost != nil { + post = replacementPost + if post.Metadata != nil && metadata != nil { + post.Metadata.Priority = metadata.Priority + } else { + post.Metadata = metadata + } + } + return true + }, plugin.MessageWillBePostedID) + if rejectionError != nil { + return nil, rejectionError + } + + // Phase B: call each guard claimant in PluginId-sorted order, fail-closed. + for _, g := range guards { + hooks, err := a.Channels().HooksForPluginWithRPCErr(g.PluginId) + if err != nil { + // Active→inactive race: plugin deactivated between resolveGuards and now. + return nil, logAndErrPluginInactive(rctx, post.ChannelId, []string{g.PluginId}, "CreatePost") + } + replacement, reason, rpcErr := hooks.MessageWillBePostedWithRPCErr(pCtx, post.ForPlugin()) + if rpcErr != nil { + return nil, appErrHookFailed(g.PluginId, "CreatePost", rpcErr) + } + if reason != "" { + id := "Post rejected by plugin. " + reason + if reason == plugin.DismissPostError { + id = plugin.DismissPostError + } + return nil, model.NewAppError("createPost", id, nil, "", http.StatusBadRequest) + } + if replacement != nil { + post = replacement + if post.Metadata != nil && metadata != nil { + post.Metadata.Priority = metadata.Priority + } else { + post.Metadata = metadata + } + } + } + + return post, nil +} + +// runGuardedMessageWillBeUpdated dispatches MessageWillBeUpdated. In the non-guarded +// hook variant, either newPost == nil OR rejectionReason != "" signals rejection. +func (a *App) runGuardedMessageWillBeUpdated(rctx request.CTX, newPost, oldPost *model.Post) (*model.Post, *model.AppError) { + guards, rejectErr := a.resolveGuards(rctx, oldPost.ChannelId, "UpdatePost") + + // Guard plugin is unavailable — fail-closed (logged with attribution). + if rejectErr != nil { + return nil, rejectErr + } + + // buildUpdateRejectionErr mirrors the legacy error shape at post.go UpdatePost. + buildUpdateRejectionErr := func(reason string) *model.AppError { + id := "Post rejected by plugin. " + reason + if reason == plugin.DismissPostError { + id = plugin.DismissPostError + } + return model.NewAppError("UpdatePost", id, nil, "", http.StatusBadRequest) + } + + // Phase A: fan out to non-guard plugins, fail-open. With empty guards the exclude list is + // empty and behavior is identical to plain RunMultiHook. + var rejectionReason string + pCtx := pluginContext(rctx) + a.ch.RunMultiHookExcluding(pluginIDsOf(guards), func(hooks plugin.Hooks, _ *model.Manifest) bool { + newPost, rejectionReason = hooks.MessageWillBeUpdated(pCtx, newPost.ForPlugin(), oldPost.ForPlugin()) + return newPost != nil + }, plugin.MessageWillBeUpdatedID) + if newPost == nil { + return nil, buildUpdateRejectionErr(rejectionReason) + } + + // Phase B: call each guard claimant in PluginId-sorted order, fail-closed. + for _, g := range guards { + hooks, err := a.Channels().HooksForPluginWithRPCErr(g.PluginId) + if err != nil { + // Active→inactive race: plugin deactivated between resolveGuards and now. + return nil, logAndErrPluginInactive(rctx, oldPost.ChannelId, []string{g.PluginId}, "UpdatePost") + } + replacement, reason, rpcErr := hooks.MessageWillBeUpdatedWithRPCErr(pCtx, newPost.ForPlugin(), oldPost.ForPlugin()) + if rpcErr != nil { + return nil, appErrHookFailed(g.PluginId, "UpdatePost", rpcErr) + } + if reason != "" { + return nil, buildUpdateRejectionErr(reason) + } + // If replacement == nil && reason == "" && rpcErr == nil, the claimant had no opinion + // (did not implement the hook). Do not treat as rejection — continue iterating. + if replacement != nil { + newPost = replacement + } + } + + return newPost, nil +} + +// runGuardedChannelMemberWillBeAdded dispatches ChannelMemberWillBeAdded. Returns the (possibly +// replaced) member, or an AppError on rejection or RPC failure. +func (a *App) runGuardedChannelMemberWillBeAdded(rctx request.CTX, channelID string, member *model.ChannelMember) (*model.ChannelMember, *model.AppError) { + guards, rejectErr := a.resolveGuards(rctx, channelID, "AddUserToChannel") + + // Guard plugin is unavailable — fail-closed (logged with attribution). + if rejectErr != nil { + return nil, rejectErr + } + + buildMemberRejectionErr := func(reason string) *model.AppError { + return model.NewAppError("AddUserToChannel", "app.channel.add_user.to.channel.rejected_by_plugin", + map[string]any{"Reason": reason}, "", http.StatusBadRequest) + } + + // Phase A: fan out to non-guard plugins, fail-open. With empty guards the exclude list is + // empty and behavior is identical to plain RunMultiHook. + var rejectionError *model.AppError + pCtx := pluginContext(rctx) + a.ch.RunMultiHookExcluding(pluginIDsOf(guards), func(hooks plugin.Hooks, _ *model.Manifest) bool { + updatedMember, reason := hooks.ChannelMemberWillBeAdded(pCtx, member) + if reason != "" { + rejectionError = buildMemberRejectionErr(reason) + return false + } + if updatedMember != nil { + member = updatedMember + } + return true + }, plugin.ChannelMemberWillBeAddedID) + if rejectionError != nil { + return nil, rejectionError + } + + // Phase B: call each guard claimant in PluginId-sorted order, fail-closed. + for _, g := range guards { + hooks, err := a.Channels().HooksForPluginWithRPCErr(g.PluginId) + if err != nil { + // Active→inactive race: plugin deactivated between resolveGuards and now. + return nil, logAndErrPluginInactive(rctx, channelID, []string{g.PluginId}, "addUserToChannel") + } + replacement, reason, rpcErr := hooks.ChannelMemberWillBeAddedWithRPCErr(pCtx, member) + if rpcErr != nil { + return nil, appErrHookFailed(g.PluginId, "addUserToChannel", rpcErr) + } + if reason != "" { + return nil, buildMemberRejectionErr(reason) + } + // If replacement == nil && reason == "" && rpcErr == nil, the claimant had no opinion + // (did not implement the hook). Do not treat as rejection — continue iterating. + if replacement != nil { + member = replacement + } + } + + return member, nil +} + +// runGuardedChannelWillBeUpdated dispatches ChannelWillBeUpdated. Guard plugins may not mutate +// Channel.Type — type changes must go through dedicated paths (e.g., UpdateChannelPrivacy). The +// check applies only to guarded channels; unguarded callers retain RunMultiHook's permissive behavior. +func (a *App) runGuardedChannelWillBeUpdated(rctx request.CTX, newChannel, oldChannel *model.Channel) (*model.Channel, *model.AppError) { + guards, rejectErr := a.resolveGuards(rctx, newChannel.Id, "UpdateChannel") + + // Guard plugin is unavailable — fail-closed (logged with attribution). + if rejectErr != nil { + return nil, rejectErr + } + + buildUpdateRejectionErr := func(reason string) *model.AppError { + return model.NewAppError("UpdateChannel", "app.channel.update_channel.rejected_by_plugin", + map[string]any{"Reason": reason}, "", http.StatusBadRequest) + } + + buildTypeMutationErr := func(offendingPluginID string) *model.AppError { + return model.NewAppError("UpdateChannel", "app.channel.update_channel.plugin_type_mutation.app_error", + map[string]any{"PluginID": offendingPluginID}, "", http.StatusBadRequest) + } + + // Phase A: fan out to non-guard plugins, fail-open. With empty guards the exclude list is + // empty and behavior is identical to plain RunMultiHook. + // Track the last replacing plugin ID for type-mutation attribution (used only when guarded). + var rejectionReason string + var lastReplacingPluginID string + pCtx := pluginContext(rctx) + a.ch.RunMultiHookExcluding(pluginIDsOf(guards), func(hooks plugin.Hooks, manifest *model.Manifest) bool { + replacement, reason := hooks.ChannelWillBeUpdated(pCtx, newChannel, oldChannel) + if reason != "" { + rejectionReason = reason + return false + } + if replacement != nil { + newChannel = replacement + lastReplacingPluginID = manifest.Id + } + return true + }, plugin.ChannelWillBeUpdatedID) + if rejectionReason != "" { + return nil, buildUpdateRejectionErr(rejectionReason) + } + // Type-mutation check applies only to guarded channels; unguarded callers retain + // RunMultiHook's permissive semantics. + if len(guards) > 0 && lastReplacingPluginID != "" && newChannel.Type != oldChannel.Type { + return nil, buildTypeMutationErr(lastReplacingPluginID) + } + + // Phase B: call each guard claimant in PluginId-sorted order, fail-closed. + for _, g := range guards { + hooks, err := a.Channels().HooksForPluginWithRPCErr(g.PluginId) + if err != nil { + // Active→inactive race: plugin deactivated between resolveGuards and now. + return nil, logAndErrPluginInactive(rctx, newChannel.Id, []string{g.PluginId}, "UpdateChannel") + } + replacement, reason, rpcErr := hooks.ChannelWillBeUpdatedWithRPCErr(pCtx, newChannel, oldChannel) + if rpcErr != nil { + return nil, appErrHookFailed(g.PluginId, "UpdateChannel", rpcErr) + } + if reason != "" { + return nil, buildUpdateRejectionErr(reason) + } + // If replacement == nil && reason == "" && rpcErr == nil, the claimant had no opinion + // (did not implement the hook). Do not treat as rejection — continue iterating. + if replacement != nil { + newChannel = replacement + // Check immediately after each Phase B replacement. + if newChannel.Type != oldChannel.Type { + return nil, buildTypeMutationErr(g.PluginId) + } + } + } + + return newChannel, nil +} + +// runGuardedChannelWillBeRestored dispatches ChannelWillBeRestored. Reject-only — no replacement. +func (a *App) runGuardedChannelWillBeRestored(rctx request.CTX, channel *model.Channel) *model.AppError { + guards, rejectErr := a.resolveGuards(rctx, channel.Id, "RestoreChannel") + + // Guard plugin is unavailable — fail-closed (logged with attribution). + if rejectErr != nil { + return rejectErr + } + + // Phase A: fan out to non-guard plugins, fail-open. With empty guards the exclude list is + // empty and behavior is identical to plain RunMultiHook. + var rejectionReason string + pCtx := pluginContext(rctx) + a.ch.RunMultiHookExcluding(pluginIDsOf(guards), func(hooks plugin.Hooks, _ *model.Manifest) bool { + rejectionReason = hooks.ChannelWillBeRestored(pCtx, channel) + return rejectionReason == "" + }, plugin.ChannelWillBeRestoredID) + if rejectionReason != "" { + return model.NewAppError("RestoreChannel", "app.channel.restore_channel.rejected_by_plugin", + map[string]any{"Reason": rejectionReason}, "", http.StatusBadRequest) + } + + // Phase B: call each guard claimant in PluginId-sorted order, fail-closed. + for _, g := range guards { + hooks, err := a.Channels().HooksForPluginWithRPCErr(g.PluginId) + if err != nil { + // Active→inactive race: plugin deactivated between resolveGuards and now. + return logAndErrPluginInactive(rctx, channel.Id, []string{g.PluginId}, "RestoreChannel") + } + reason, rpcErr := hooks.ChannelWillBeRestoredWithRPCErr(pCtx, channel) + if rpcErr != nil { + return appErrHookFailed(g.PluginId, "RestoreChannel", rpcErr) + } + if reason != "" { + return model.NewAppError("RestoreChannel", "app.channel.restore_channel.rejected_by_plugin", + map[string]any{"Reason": reason}, "", http.StatusBadRequest) + } + } + + return nil +} diff --git a/server/channels/app/guarded_hooks_test.go b/server/channels/app/guarded_hooks_test.go new file mode 100644 index 000000000000..0e5a84189cba --- /dev/null +++ b/server/channels/app/guarded_hooks_test.go @@ -0,0 +1,224 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package app + +import ( + "errors" + "net/http" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/shared/request" + "github.com/mattermost/mattermost/server/v8/channels/store" +) + +// seedGuardCache directly populates the Channels guard cache for unit tests that +// need guards without going through the full DB round-trip. +func seedGuardCache(th *TestHelper, channelID string, guards []*store.ChannelGuard) { + m := &sync.Map{} + if len(guards) > 0 { + m.Store(channelID, guards) + } + th.App.Channels().guardCache.Store(m) +} + +func TestResolveGuards(t *testing.T) { + t.Run("no guards returns nil nil", func(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + // Empty cache — channel has no guard rows. + seedGuardCache(th, th.BasicChannel.Id, nil) + + rctx := request.EmptyContext(th.App.Srv().Log()) + guards, rejectErr := th.App.resolveGuards(rctx, th.BasicChannel.Id, "test") + require.Nil(t, rejectErr) + require.Nil(t, guards) + }) + + t.Run("cache uninitialized returns nil nil", func(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + // Store a nil *sync.Map — models the brief window before the first reload. + th.App.Channels().guardCache.Store((*sync.Map)(nil)) + + rctx := request.EmptyContext(th.App.Srv().Log()) + guards, rejectErr := th.App.resolveGuards(rctx, th.BasicChannel.Id, "test") + require.Nil(t, rejectErr) + require.Nil(t, guards) + }) + + t.Run("guards are sorted by PluginId", func(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + // Insert guards in reverse alphabetical order; resolveGuards must return them sorted. + unsorted := []*store.ChannelGuard{ + {ChannelId: th.BasicChannel.Id, PluginId: "zzz.plugin"}, + {ChannelId: th.BasicChannel.Id, PluginId: "aaa.plugin"}, + {ChannelId: th.BasicChannel.Id, PluginId: "mmm.plugin"}, + } + seedGuardCache(th, th.BasicChannel.Id, unsorted) + + // All plugin IDs are unknown to the environment → IsActive returns false for each. + // Disable plugins so resolveGuards hits the env==nil branch instead. + // We only want to test sort order, so use a trick: temporarily disable plugins to + // get through the env==nil fast-path and confirm the sorted slice is built before + // the env check. Actually env==nil returns early with the sorted slice — that's + // correct behaviour to assert sort order. + th.App.UpdateConfig(func(cfg *model.Config) { *cfg.PluginSettings.Enable = false }) + + rctx := request.EmptyContext(th.App.Srv().Log()) + guards, rejectErr := th.App.resolveGuards(rctx, th.BasicChannel.Id, "test") + // env==nil → reject is non-nil, but guards slice must still be sorted. + require.NotNil(t, rejectErr, "plugins disabled + guards exist → expect reject error") + require.Len(t, guards, 3) + assert.Equal(t, "aaa.plugin", guards[0].PluginId) + assert.Equal(t, "mmm.plugin", guards[1].PluginId) + assert.Equal(t, "zzz.plugin", guards[2].PluginId) + }) + + t.Run("single inactive plugin returns reject error", func(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + // Seed one guard with a plugin ID that is not active in the environment. + fakePlugin := "com.example.inactive-single" + seedGuardCache(th, th.BasicChannel.Id, []*store.ChannelGuard{ + {ChannelId: th.BasicChannel.Id, PluginId: fakePlugin}, + }) + + rctx := request.EmptyContext(th.App.Srv().Log()) + guards, rejectErr := th.App.resolveGuards(rctx, th.BasicChannel.Id, "callerA") + require.NotNil(t, rejectErr) + assert.Equal(t, "app.plugin.inactive_guard.app_error", rejectErr.Id) + assert.Equal(t, http.StatusServiceUnavailable, rejectErr.StatusCode) + // Guards slice is returned even on reject so callers can log the full context. + require.Len(t, guards, 1) + assert.Equal(t, fakePlugin, guards[0].PluginId) + }) + + t.Run("multiple inactive plugins returns reject error", func(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + // Two inactive guards — exercises the mlog.Array path in logAndErrPluginInactive. + seedGuardCache(th, th.BasicChannel.Id, []*store.ChannelGuard{ + {ChannelId: th.BasicChannel.Id, PluginId: "com.example.inactive-a"}, + {ChannelId: th.BasicChannel.Id, PluginId: "com.example.inactive-b"}, + }) + + rctx := request.EmptyContext(th.App.Srv().Log()) + guards, rejectErr := th.App.resolveGuards(rctx, th.BasicChannel.Id, "callerB") + require.NotNil(t, rejectErr) + assert.Equal(t, "app.plugin.inactive_guard.app_error", rejectErr.Id) + assert.Equal(t, http.StatusServiceUnavailable, rejectErr.StatusCode) + require.Len(t, guards, 2) + }) + + t.Run("env nil branch returns reject error", func(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + fakePlugin := "com.example.env-nil" + seedGuardCache(th, th.BasicChannel.Id, []*store.ChannelGuard{ + {ChannelId: th.BasicChannel.Id, PluginId: fakePlugin}, + }) + + // Disable the plugin system so GetPluginsEnvironment returns nil. + th.App.UpdateConfig(func(cfg *model.Config) { *cfg.PluginSettings.Enable = false }) + + rctx := request.EmptyContext(th.App.Srv().Log()) + guards, rejectErr := th.App.resolveGuards(rctx, th.BasicChannel.Id, "callerC") + require.NotNil(t, rejectErr) + assert.Equal(t, "app.plugin.inactive_guard.app_error", rejectErr.Id) + assert.Equal(t, http.StatusServiceUnavailable, rejectErr.StatusCode) + // Guards slice is still populated with the sorted rows. + require.Len(t, guards, 1) + }) +} + +func TestPluginIDsOf(t *testing.T) { + t.Run("nil input returns empty slice", func(t *testing.T) { + ids := pluginIDsOf(nil) + assert.Empty(t, ids) + }) + + t.Run("empty input returns empty slice", func(t *testing.T) { + ids := pluginIDsOf([]*store.ChannelGuard{}) + assert.Empty(t, ids) + }) + + t.Run("multiple guards returns IDs in input order", func(t *testing.T) { + guards := []*store.ChannelGuard{ + {PluginId: "aaa"}, + {PluginId: "bbb"}, + {PluginId: "ccc"}, + } + ids := pluginIDsOf(guards) + require.Equal(t, []string{"aaa", "bbb", "ccc"}, ids) + }) +} + +func TestAppErrHookFailed(t *testing.T) { + t.Run("without error sets correct fields", func(t *testing.T) { + appErr := appErrHookFailed("com.example.plugin", "CreatePost", nil) + require.NotNil(t, appErr) + assert.Equal(t, "app.plugin.guard_hook_failed.app_error", appErr.Id) + assert.Equal(t, http.StatusServiceUnavailable, appErr.StatusCode) + // err==nil branch: no Wrap, so Unwrap returns nil. + assert.NoError(t, appErr.Unwrap()) + }) + + t.Run("with error wraps it", func(t *testing.T) { + cause := errors.New("rpc transport failure") + appErr := appErrHookFailed("com.example.plugin", "UpdatePost", cause) + require.NotNil(t, appErr) + assert.Equal(t, "app.plugin.guard_hook_failed.app_error", appErr.Id) + assert.Equal(t, http.StatusServiceUnavailable, appErr.StatusCode) + // err!=nil branch: Wrap stores it; errors.Is traverses via Unwrap. + assert.ErrorIs(t, appErr, cause) + }) +} + +func TestLogAndErrPluginInactive(t *testing.T) { + t.Run("single plugin ID returns correct AppError", func(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + rctx := request.EmptyContext(th.App.Srv().Log()) + + appErr := logAndErrPluginInactive(rctx, "ch-id-1", []string{"com.example.only"}, "callerX") + require.NotNil(t, appErr) + assert.Equal(t, "app.plugin.inactive_guard.app_error", appErr.Id) + assert.Equal(t, http.StatusServiceUnavailable, appErr.StatusCode) + }) + + t.Run("multiple plugin IDs returns correct AppError", func(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + rctx := request.EmptyContext(th.App.Srv().Log()) + + appErr := logAndErrPluginInactive(rctx, "ch-id-2", []string{"com.a", "com.b", "com.c"}, "callerY") + require.NotNil(t, appErr) + assert.Equal(t, "app.plugin.inactive_guard.app_error", appErr.Id) + assert.Equal(t, http.StatusServiceUnavailable, appErr.StatusCode) + }) +} + +func TestLogAndErrPluginsDisabled(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + rctx := request.EmptyContext(th.App.Srv().Log()) + + appErr := logAndErrPluginsDisabled(rctx, "ch-id-3", "callerZ") + require.NotNil(t, appErr) + // Same user-visible error ID as inactive_guard (internal cause differs). + assert.Equal(t, "app.plugin.inactive_guard.app_error", appErr.Id) + assert.Equal(t, http.StatusServiceUnavailable, appErr.StatusCode) +} diff --git a/server/channels/app/plugin_api.go b/server/channels/app/plugin_api.go index b57b3a2ef2fe..e3689b57bb27 100644 --- a/server/channels/app/plugin_api.go +++ b/server/channels/app/plugin_api.go @@ -537,6 +537,14 @@ func (api *PluginAPI) UpdateChannel(channel *model.Channel) (*model.Channel, *mo return api.app.UpdateChannel(api.ctx, channel) } +func (api *PluginAPI) RegisterChannelGuard(channelID string) *model.AppError { + return api.app.RegisterChannelGuard(api.ctx, channelID, strings.ToLower(api.id)) +} + +func (api *PluginAPI) UnregisterChannelGuard(channelID string) *model.AppError { + return api.app.UnregisterChannelGuard(api.ctx, channelID, strings.ToLower(api.id)) +} + func (api *PluginAPI) SearchChannels(teamID string, term string) ([]*model.Channel, *model.AppError) { channels, err := api.app.SearchChannels(api.ctx, teamID, term) if err != nil { diff --git a/server/channels/app/plugin_hooks_test.go b/server/channels/app/plugin_hooks_test.go index 0f05fd9c8da7..df483b555e1a 100644 --- a/server/channels/app/plugin_hooks_test.go +++ b/server/channels/app/plugin_hooks_test.go @@ -6,12 +6,15 @@ package app import ( "bytes" _ "embed" + "encoding/json" "fmt" "io" "net/http" "net/http/httptest" "os" "path/filepath" + "sort" + "strconv" "strings" "sync" "testing" @@ -1846,6 +1849,171 @@ func TestHookMessagesWillBeConsumed(t *testing.T) { }) } +func TestUpdatePostFiresConsumeHook(t *testing.T) { + mainHelper.Parallel(t) + + th := SetupConfig(t, func(cfg *model.Config) { + cfg.FeatureFlags.ConsumePostHook = true + }).InitBasic(t) + + var mockAPI plugintest.API + mockAPI.On("LoadPluginConfiguration", mock.Anything).Return(nil) + mockAPI.On("LogDebug", mock.Anything).Return(nil) + + tearDown, _, _ := SetAppEnvironmentWithPlugins(t, []string{` + package main + + import ( + "strings" + + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" + ) + + type MyPlugin struct { + plugin.MattermostPlugin + } + + func (p *MyPlugin) MessagesWillBeConsumed(posts []*model.Post) []*model.Post { + for _, post := range posts { + post.Message = strings.ToUpper(post.Message) + } + return posts + } + + func main() { + plugin.ClientMain(&MyPlugin{}) + } + `}, th.App, func(*model.Manifest) plugin.API { return &mockAPI }) + t.Cleanup(tearDown) + + wsMessages, closeWS := connectFakeWebSocket(t, th, th.BasicUser.Id, "", []model.WebsocketEventType{ + model.WebsocketEventPosted, + model.WebsocketEventPostEdited, + }) + defer closeWS() + + basePost, _, err := th.App.CreatePost(th.Context, &model.Post{ + UserId: th.BasicUser.Id, + ChannelId: th.BasicChannel.Id, + Message: "original body", + }, th.BasicChannel, model.CreatePostFlags{SetOnline: false}) + require.Nil(t, err) + + drainTimeout := time.After(500 * time.Millisecond) +drainLoop: + for { + select { + case <-wsMessages: + case <-drainTimeout: + break drainLoop + } + } + + editedMessage := "edited body" + patchedPost, _, err := th.App.PatchPost(th.Context, basePost.Id, &model.PostPatch{ + Message: &editedMessage, + }, nil) + require.Nil(t, err) + + require.Equal(t, "EDITED BODY", patchedPost.Message) + + timeout := time.After(5 * time.Second) + for { + select { + case ev := <-wsMessages: + if ev.EventType() != model.WebsocketEventPostEdited { + continue + } + postJSON, ok := ev.GetData()["post"].(string) + require.True(t, ok, "post field in websocket event should be a JSON string") + var wsPost model.Post + require.NoError(t, json.Unmarshal([]byte(postJSON), &wsPost)) + assert.Equal(t, "EDITED BODY", wsPost.Message) + return + case <-timeout: + require.Fail(t, "timed out waiting for post_edited websocket event") + } + } +} + +func TestUpdatePostNoConsumeHookWhenFlagDisabled(t *testing.T) { + mainHelper.Parallel(t) + + th := SetupConfig(t, func(cfg *model.Config) { + cfg.FeatureFlags.ConsumePostHook = false + }).InitBasic(t) + + var mockAPI plugintest.API + mockAPI.On("LoadPluginConfiguration", mock.Anything).Return(nil) + mockAPI.On("LogDebug", mock.Anything).Return(nil) + + tearDown, _, _ := SetAppEnvironmentWithPlugins(t, []string{` + package main + + import ( + "strings" + + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" + ) + + type MyPlugin struct { + plugin.MattermostPlugin + } + + func (p *MyPlugin) MessagesWillBeConsumed(posts []*model.Post) []*model.Post { + for _, post := range posts { + post.Message = strings.ToUpper(post.Message) + } + return posts + } + + func main() { + plugin.ClientMain(&MyPlugin{}) + } + `}, th.App, func(*model.Manifest) plugin.API { return &mockAPI }) + t.Cleanup(tearDown) + + basePost, _, err := th.App.CreatePost(th.Context, &model.Post{ + UserId: th.BasicUser.Id, + ChannelId: th.BasicChannel.Id, + Message: "original body", + }, th.BasicChannel, model.CreatePostFlags{SetOnline: false}) + require.Nil(t, err) + + editedMessage := "edited body" + patchedPost, _, err := th.App.PatchPost(th.Context, basePost.Id, &model.PostPatch{ + Message: &editedMessage, + }, nil) + require.Nil(t, err) + + assert.Equal(t, "edited body", patchedPost.Message) +} + +func TestUpdatePostNoOpWhenNoPlugin(t *testing.T) { + mainHelper.Parallel(t) + + th := SetupConfig(t, func(cfg *model.Config) { + cfg.FeatureFlags.ConsumePostHook = true + }).InitBasic(t) + + basePost, _, err := th.App.CreatePost(th.Context, &model.Post{ + UserId: th.BasicUser.Id, + ChannelId: th.BasicChannel.Id, + Message: "original body", + }, th.BasicChannel, model.CreatePostFlags{SetOnline: false}) + require.Nil(t, err) + + editedMessage := "edited body" + patchedPost, _, err := th.App.PatchPost(th.Context, basePost.Id, &model.PostPatch{ + Message: &editedMessage, + }, nil) + require.Nil(t, err) + + assert.Equal(t, "edited body", patchedPost.Message) +} + func TestHookPreferencesHaveChanged(t *testing.T) { mainHelper.Parallel(t) t.Run("should be called when preferences are changed by non-plugin code", func(t *testing.T) { @@ -3230,3 +3398,3003 @@ func TestHookChannelWillBeArchived(t *testing.T) { assert.NotEqual(t, int64(0), ch.DeleteAt) }) } + +func TestHookRPCChannelWillBeUpdated(t *testing.T) { + mainHelper.Parallel(t) + + t.Run("rejected", func(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + tearDown, pluginIDs, _ := SetAppEnvironmentWithPlugins(t, []string{ + ` + package main + + import ( + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" + ) + + type MyPlugin struct { + plugin.MattermostPlugin + } + + func (p *MyPlugin) ChannelWillBeUpdated(c *plugin.Context, newChannel, oldChannel *model.Channel) (*model.Channel, string) { + return nil, "rpc test rejected" + } + + func main() { + plugin.ClientMain(&MyPlugin{}) + } + `, + }, th.App, th.NewPluginAPI) + defer tearDown() + + require.Len(t, pluginIDs, 1) + hooks, err := th.App.GetPluginsEnvironment().HooksForPlugin(pluginIDs[0]) + require.NoError(t, err) + + newCh := &model.Channel{Id: model.NewId(), TeamId: th.BasicTeam.Id, Type: model.ChannelTypePrivate, DisplayName: "new"} + oldCh := &model.Channel{Id: newCh.Id, TeamId: th.BasicTeam.Id, Type: model.ChannelTypeOpen, DisplayName: "old"} + replacement, reason := hooks.ChannelWillBeUpdated(&plugin.Context{}, newCh, oldCh) + require.Equal(t, "rpc test rejected", reason) + require.Nil(t, replacement) + }) + + t.Run("modify", func(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + tearDown, pluginIDs, _ := SetAppEnvironmentWithPlugins(t, []string{ + ` + package main + + import ( + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" + ) + + type MyPlugin struct { + plugin.MattermostPlugin + } + + func (p *MyPlugin) ChannelWillBeUpdated(c *plugin.Context, newChannel, oldChannel *model.Channel) (*model.Channel, string) { + newChannel.DisplayName = "modified-by-plugin" + return newChannel, "" + } + + func main() { + plugin.ClientMain(&MyPlugin{}) + } + `, + }, th.App, th.NewPluginAPI) + defer tearDown() + + require.Len(t, pluginIDs, 1) + hooks, err := th.App.GetPluginsEnvironment().HooksForPlugin(pluginIDs[0]) + require.NoError(t, err) + + newCh := &model.Channel{Id: model.NewId(), TeamId: th.BasicTeam.Id, Type: model.ChannelTypePrivate, DisplayName: "new"} + oldCh := &model.Channel{Id: newCh.Id, TeamId: th.BasicTeam.Id, Type: model.ChannelTypeOpen, DisplayName: "old"} + replacement, reason := hooks.ChannelWillBeUpdated(&plugin.Context{}, newCh, oldCh) + require.Equal(t, "", reason) + require.NotNil(t, replacement) + require.Equal(t, "modified-by-plugin", replacement.DisplayName) + }) +} + +func TestHookRPCChannelWillBeRestored(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + tearDown, pluginIDs, _ := SetAppEnvironmentWithPlugins(t, []string{ + ` + package main + + import ( + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" + ) + + type MyPlugin struct { + plugin.MattermostPlugin + } + + func (p *MyPlugin) ChannelWillBeRestored(c *plugin.Context, channel *model.Channel) string { + return "rpc test rejected" + } + + func main() { + plugin.ClientMain(&MyPlugin{}) + } + `, + }, th.App, th.NewPluginAPI) + defer tearDown() + + require.Len(t, pluginIDs, 1) + hooks, err := th.App.GetPluginsEnvironment().HooksForPlugin(pluginIDs[0]) + require.NoError(t, err) + + ch := &model.Channel{Id: model.NewId(), TeamId: th.BasicTeam.Id, Type: model.ChannelTypePrivate, DisplayName: "restore"} + reason := hooks.ChannelWillBeRestored(&plugin.Context{}, ch) + require.Equal(t, "rpc test rejected", reason) +} + +func TestHookRPCScheduledPostWillBeCreated(t *testing.T) { + mainHelper.Parallel(t) + + t.Run("rejected", func(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + tearDown, pluginIDs, _ := SetAppEnvironmentWithPlugins(t, []string{ + ` + package main + + import ( + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" + ) + + type MyPlugin struct { + plugin.MattermostPlugin + } + + func (p *MyPlugin) ScheduledPostWillBeCreated(c *plugin.Context, scheduledPost *model.ScheduledPost) (*model.ScheduledPost, string) { + return nil, "rpc test rejected" + } + + func main() { + plugin.ClientMain(&MyPlugin{}) + } + `, + }, th.App, th.NewPluginAPI) + defer tearDown() + + require.Len(t, pluginIDs, 1) + hooks, err := th.App.GetPluginsEnvironment().HooksForPlugin(pluginIDs[0]) + require.NoError(t, err) + + sp := &model.ScheduledPost{ + Draft: model.Draft{ + UserId: model.NewId(), + ChannelId: model.NewId(), + Message: "scheduled hi", + }, + Id: model.NewId(), + ScheduledAt: 1234567890, + } + replacement, reason := hooks.ScheduledPostWillBeCreated(&plugin.Context{}, sp) + require.Equal(t, "rpc test rejected", reason) + require.Nil(t, replacement) + }) + + t.Run("modify", func(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + tearDown, pluginIDs, _ := SetAppEnvironmentWithPlugins(t, []string{ + ` + package main + + import ( + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" + ) + + type MyPlugin struct { + plugin.MattermostPlugin + } + + func (p *MyPlugin) ScheduledPostWillBeCreated(c *plugin.Context, scheduledPost *model.ScheduledPost) (*model.ScheduledPost, string) { + scheduledPost.Message = "modified-by-plugin" + return scheduledPost, "" + } + + func main() { + plugin.ClientMain(&MyPlugin{}) + } + `, + }, th.App, th.NewPluginAPI) + defer tearDown() + + require.Len(t, pluginIDs, 1) + hooks, err := th.App.GetPluginsEnvironment().HooksForPlugin(pluginIDs[0]) + require.NoError(t, err) + + sp := &model.ScheduledPost{ + Draft: model.Draft{ + UserId: model.NewId(), + ChannelId: model.NewId(), + Message: "original", + }, + Id: model.NewId(), + ScheduledAt: 1234567890, + } + replacement, reason := hooks.ScheduledPostWillBeCreated(&plugin.Context{}, sp) + require.Equal(t, "", reason) + require.NotNil(t, replacement) + require.Equal(t, "modified-by-plugin", replacement.Message) + }) +} + +func TestHookRPCDraftWillBeUpserted(t *testing.T) { + mainHelper.Parallel(t) + + t.Run("rejected", func(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + tearDown, pluginIDs, _ := SetAppEnvironmentWithPlugins(t, []string{ + ` + package main + + import ( + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" + ) + + type MyPlugin struct { + plugin.MattermostPlugin + } + + func (p *MyPlugin) DraftWillBeUpserted(c *plugin.Context, draft *model.Draft) (*model.Draft, string) { + return nil, "rpc test rejected" + } + + func main() { + plugin.ClientMain(&MyPlugin{}) + } + `, + }, th.App, th.NewPluginAPI) + defer tearDown() + + require.Len(t, pluginIDs, 1) + hooks, err := th.App.GetPluginsEnvironment().HooksForPlugin(pluginIDs[0]) + require.NoError(t, err) + + draft := &model.Draft{ + UserId: model.NewId(), + ChannelId: model.NewId(), + Message: "draft hi", + } + replacement, reason := hooks.DraftWillBeUpserted(&plugin.Context{}, draft) + require.Equal(t, "rpc test rejected", reason) + require.Nil(t, replacement) + }) + + t.Run("modify", func(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + tearDown, pluginIDs, _ := SetAppEnvironmentWithPlugins(t, []string{ + ` + package main + + import ( + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" + ) + + type MyPlugin struct { + plugin.MattermostPlugin + } + + func (p *MyPlugin) DraftWillBeUpserted(c *plugin.Context, draft *model.Draft) (*model.Draft, string) { + draft.Message = "modified-by-plugin" + return draft, "" + } + + func main() { + plugin.ClientMain(&MyPlugin{}) + } + `, + }, th.App, th.NewPluginAPI) + defer tearDown() + + require.Len(t, pluginIDs, 1) + hooks, err := th.App.GetPluginsEnvironment().HooksForPlugin(pluginIDs[0]) + require.NoError(t, err) + + draft := &model.Draft{ + UserId: model.NewId(), + ChannelId: model.NewId(), + Message: "original", + } + replacement, reason := hooks.DraftWillBeUpserted(&plugin.Context{}, draft) + require.Equal(t, "", reason) + require.NotNil(t, replacement) + require.Equal(t, "modified-by-plugin", replacement.Message) + }) +} + +func TestRegisterChannelGuardIdempotent(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + channelID := th.BasicChannel.Id + + tearDown, pluginIDs, _ := SetAppEnvironmentWithPlugins(t, []string{ + ` + package main + + import ( + "github.com/mattermost/mattermost/server/public/plugin" + ) + + type MyPlugin struct { + plugin.MattermostPlugin + } + + func (p *MyPlugin) OnActivate() error { + channelID := "` + channelID + `" + if appErr := p.API.RegisterChannelGuard(channelID); appErr != nil { + return appErr + } + // Second call must be idempotent. + return p.API.RegisterChannelGuard(channelID) + } + + func main() { + plugin.ClientMain(&MyPlugin{}) + } + `, + }, th.App, th.NewPluginAPI) + defer tearDown() + + require.Len(t, pluginIDs, 1) + + rctx := request.EmptyContext(th.App.Srv().Log()) + guards, err := th.App.Srv().Store().ChannelGuard().GetForChannel(rctx, channelID) + require.NoError(t, err) + require.Len(t, guards, 1, "second Register call must be a no-op (DO NOTHING)") + + cached := th.App.Channels().getGuardsForChannel(channelID) + require.Len(t, cached, 1, "cache should match the store") + assert.Equal(t, strings.ToLower(pluginIDs[0]), cached[0].PluginId) +} + +func TestRegisterChannelGuardMultiClaim(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + channelID := th.BasicChannel.Id + + pluginCode := func() string { + return ` + package main + + import ( + "github.com/mattermost/mattermost/server/public/plugin" + ) + + type MyPlugin struct { + plugin.MattermostPlugin + } + + func (p *MyPlugin) OnActivate() error { + return p.API.RegisterChannelGuard("` + channelID + `") + } + + func main() { + plugin.ClientMain(&MyPlugin{}) + } + ` + } + + tearDown, pluginIDs, _ := SetAppEnvironmentWithPlugins(t, []string{ + pluginCode(), + pluginCode(), + }, th.App, th.NewPluginAPI) + defer tearDown() + + require.Len(t, pluginIDs, 2) + + rctx := request.EmptyContext(th.App.Srv().Log()) + guards, err := th.App.Srv().Store().ChannelGuard().GetForChannel(rctx, channelID) + require.NoError(t, err) + require.Len(t, guards, 2, "two distinct plugins must produce two rows") + + pluginAID := strings.ToLower(pluginIDs[0]) + pluginBID := strings.ToLower(pluginIDs[1]) + + cached := th.App.Channels().getGuardsForChannel(channelID) + require.Len(t, cached, 2) + cachedIDs := []string{cached[0].PluginId, cached[1].PluginId} + assert.Contains(t, cachedIDs, pluginAID) + assert.Contains(t, cachedIDs, pluginBID) + + // Unregister plugin A's claim via the App-level method; B's claim must remain. + require.Nil(t, th.App.UnregisterChannelGuard(rctx, channelID, pluginAID)) + + guards, err = th.App.Srv().Store().ChannelGuard().GetForChannel(rctx, channelID) + require.NoError(t, err) + require.Len(t, guards, 1) + assert.Equal(t, pluginBID, guards[0].PluginId) + + cached = th.App.Channels().getGuardsForChannel(channelID) + require.Len(t, cached, 1) + assert.Equal(t, pluginBID, cached[0].PluginId) +} + +func TestChannelGuardSurvivesArchive(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + channelID := th.BasicChannel.Id + + tearDown, pluginIDs, _ := SetAppEnvironmentWithPlugins(t, []string{ + ` + package main + + import ( + "github.com/mattermost/mattermost/server/public/plugin" + ) + + type MyPlugin struct { + plugin.MattermostPlugin + } + + func (p *MyPlugin) OnActivate() error { + return p.API.RegisterChannelGuard("` + channelID + `") + } + + func main() { + plugin.ClientMain(&MyPlugin{}) + } + `, + }, th.App, th.NewPluginAPI) + defer tearDown() + + require.Len(t, pluginIDs, 1) + + // Archive the channel. + require.Nil(t, th.App.DeleteChannel(th.Context, th.BasicChannel, th.BasicUser.Id)) + + // Guard row must persist (no FK, no cascade). + rctx := request.EmptyContext(th.App.Srv().Log()) + guards, err := th.App.Srv().Store().ChannelGuard().GetForChannel(rctx, channelID) + require.NoError(t, err) + require.Len(t, guards, 1) + assert.Equal(t, strings.ToLower(pluginIDs[0]), guards[0].PluginId) + + cached := th.App.Channels().getGuardsForChannel(channelID) + require.Len(t, cached, 1) +} + +func TestHookChannelWillBeUpdated(t *testing.T) { + mainHelper.Parallel(t) + + t.Run("rejected", func(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + tearDown, _, _ := SetAppEnvironmentWithPlugins(t, []string{ + ` + package main + + import ( + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" + ) + + type MyPlugin struct { + plugin.MattermostPlugin + } + + func (p *MyPlugin) ChannelWillBeUpdated(c *plugin.Context, newChannel, oldChannel *model.Channel) (*model.Channel, string) { + return nil, "update not permitted" + } + + func main() { + plugin.ClientMain(&MyPlugin{}) + } + `, + }, th.App, th.NewPluginAPI) + defer tearDown() + + original := th.BasicChannel.DisplayName + updated := th.BasicChannel.DeepCopy() + updated.DisplayName = "Should Not Persist" + + _, appErr := th.App.UpdateChannel(th.Context, updated) + require.NotNil(t, appErr) + assert.Contains(t, appErr.Id, "rejected_by_plugin") + + fetched, err := th.App.GetChannel(th.Context, th.BasicChannel.Id) + require.Nil(t, err) + assert.Equal(t, original, fetched.DisplayName) + }) + + t.Run("modified", func(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + tearDown, _, _ := SetAppEnvironmentWithPlugins(t, []string{ + ` + package main + + import ( + "strings" + + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" + ) + + type MyPlugin struct { + plugin.MattermostPlugin + } + + func (p *MyPlugin) ChannelWillBeUpdated(c *plugin.Context, newChannel, oldChannel *model.Channel) (*model.Channel, string) { + newChannel.DisplayName = strings.ToUpper(newChannel.DisplayName) + return newChannel, "" + } + + func main() { + plugin.ClientMain(&MyPlugin{}) + } + `, + }, th.App, th.NewPluginAPI) + defer tearDown() + + updated := th.BasicChannel.DeepCopy() + updated.DisplayName = "lowercase name" + + _, appErr := th.App.UpdateChannel(th.Context, updated) + require.Nil(t, appErr) + + fetched, err := th.App.GetChannel(th.Context, th.BasicChannel.Id) + require.Nil(t, err) + assert.Equal(t, "LOWERCASE NAME", fetched.DisplayName) + }) + + t.Run("old vs new diff", func(t *testing.T) { + // Plugin rejects only when the DisplayName changed — proving that oldChannel carries the + // stored value, not a copy of newChannel. + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + tearDown, _, _ := SetAppEnvironmentWithPlugins(t, []string{ + ` + package main + + import ( + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" + ) + + type MyPlugin struct { + plugin.MattermostPlugin + } + + func (p *MyPlugin) ChannelWillBeUpdated(c *plugin.Context, newChannel, oldChannel *model.Channel) (*model.Channel, string) { + if oldChannel.DisplayName != newChannel.DisplayName { + return nil, "display name changed" + } + return nil, "" + } + + func main() { + plugin.ClientMain(&MyPlugin{}) + } + `, + }, th.App, th.NewPluginAPI) + defer tearDown() + + // Call with a changed DisplayName — plugin sees old != new and rejects. + changed := th.BasicChannel.DeepCopy() + changed.DisplayName = "Renamed Channel" + _, appErr := th.App.UpdateChannel(th.Context, changed) + require.NotNil(t, appErr) + assert.Contains(t, appErr.Id, "rejected_by_plugin") + + // Call with the same DisplayName — plugin sees old == new and allows. + same := th.BasicChannel.DeepCopy() + _, appErr = th.App.UpdateChannel(th.Context, same) + require.Nil(t, appErr) + }) + + t.Run("idempotent across repeat calls", func(t *testing.T) { + // UpdateChannelPrivacy may invoke UpdateChannel twice on the postChannelPrivacyMessage + // failure path (forward + revert). This test approximates that double-fire by calling + // UpdateChannel twice with the same plugin loaded — the hook must tolerate repeat invocations. + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + tearDown, _, _ := SetAppEnvironmentWithPlugins(t, []string{ + ` + package main + + import ( + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" + ) + + type MyPlugin struct { + plugin.MattermostPlugin + } + + func (p *MyPlugin) ChannelWillBeUpdated(c *plugin.Context, newChannel, oldChannel *model.Channel) (*model.Channel, string) { + return nil, "" + } + + func main() { + plugin.ClientMain(&MyPlugin{}) + } + `, + }, th.App, th.NewPluginAPI) + defer tearDown() + + first := th.BasicChannel.DeepCopy() + first.DisplayName = "First" + _, appErr := th.App.UpdateChannel(th.Context, first) + require.Nil(t, appErr) + + second := first.DeepCopy() + second.DisplayName = "Second" + _, appErr = th.App.UpdateChannel(th.Context, second) + require.Nil(t, appErr) + + fetched, err := th.App.GetChannel(th.Context, th.BasicChannel.Id) + require.Nil(t, err) + assert.Equal(t, "Second", fetched.DisplayName) + }) +} + +func TestHookChannelWillBeRestored(t *testing.T) { + mainHelper.Parallel(t) + + t.Run("rejected", func(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + // First archive the channel so RestoreChannel has something to do. + require.Nil(t, th.App.DeleteChannel(th.Context, th.BasicChannel, th.BasicUser.Id)) + + tearDown, _, _ := SetAppEnvironmentWithPlugins(t, []string{ + ` + package main + + import ( + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" + ) + + type MyPlugin struct { + plugin.MattermostPlugin + } + + func (p *MyPlugin) ChannelWillBeRestored(c *plugin.Context, channel *model.Channel) string { + return "restore not permitted" + } + + func main() { + plugin.ClientMain(&MyPlugin{}) + } + `, + }, th.App, th.NewPluginAPI) + defer tearDown() + + archived, err := th.App.GetChannel(th.Context, th.BasicChannel.Id) + require.Nil(t, err) + + _, appErr := th.App.RestoreChannel(th.Context, archived, th.BasicUser.Id) + require.NotNil(t, appErr) + assert.Contains(t, appErr.Id, "rejected_by_plugin") + + fetched, err := th.App.GetChannel(th.Context, th.BasicChannel.Id) + require.Nil(t, err) + assert.NotEqual(t, int64(0), fetched.DeleteAt) + }) + + t.Run("allowed", func(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + require.Nil(t, th.App.DeleteChannel(th.Context, th.BasicChannel, th.BasicUser.Id)) + + tearDown, _, _ := SetAppEnvironmentWithPlugins(t, []string{ + ` + package main + + import ( + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" + ) + + type MyPlugin struct { + plugin.MattermostPlugin + } + + func (p *MyPlugin) ChannelWillBeRestored(c *plugin.Context, channel *model.Channel) string { + return "" + } + + func main() { + plugin.ClientMain(&MyPlugin{}) + } + `, + }, th.App, th.NewPluginAPI) + defer tearDown() + + archived, err := th.App.GetChannel(th.Context, th.BasicChannel.Id) + require.Nil(t, err) + + _, appErr := th.App.RestoreChannel(th.Context, archived, th.BasicUser.Id) + require.Nil(t, appErr) + + fetched, err := th.App.GetChannel(th.Context, th.BasicChannel.Id) + require.Nil(t, err) + assert.Equal(t, int64(0), fetched.DeleteAt) + }) +} + +func TestHookScheduledPostWillBeCreated(t *testing.T) { + mainHelper.Parallel(t) + + t.Run("save rejected", func(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + tearDown, _, _ := SetAppEnvironmentWithPlugins(t, []string{ + ` + package main + + import ( + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" + ) + + type MyPlugin struct { + plugin.MattermostPlugin + } + + func (p *MyPlugin) ScheduledPostWillBeCreated(c *plugin.Context, scheduledPost *model.ScheduledPost) (*model.ScheduledPost, string) { + return nil, "scheduled post not permitted" + } + + func main() { + plugin.ClientMain(&MyPlugin{}) + } + `, + }, th.App, th.NewPluginAPI) + defer tearDown() + + sp := &model.ScheduledPost{ + Draft: model.Draft{ + UserId: th.BasicUser.Id, + ChannelId: th.BasicChannel.Id, + Message: "scheduled hi", + }, + ScheduledAt: model.GetMillis() + 60_000, + } + _, appErr := th.App.SaveScheduledPost(th.Context, sp, "") + require.NotNil(t, appErr) + assert.Contains(t, appErr.Id, "rejected_by_plugin") + }) + + t.Run("save modified", func(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + tearDown, _, _ := SetAppEnvironmentWithPlugins(t, []string{ + ` + package main + + import ( + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" + ) + + type MyPlugin struct { + plugin.MattermostPlugin + } + + func (p *MyPlugin) ScheduledPostWillBeCreated(c *plugin.Context, scheduledPost *model.ScheduledPost) (*model.ScheduledPost, string) { + scheduledPost.Message = "modified-by-plugin" + return scheduledPost, "" + } + + func main() { + plugin.ClientMain(&MyPlugin{}) + } + `, + }, th.App, th.NewPluginAPI) + defer tearDown() + + sp := &model.ScheduledPost{ + Draft: model.Draft{ + UserId: th.BasicUser.Id, + ChannelId: th.BasicChannel.Id, + Message: "original", + }, + ScheduledAt: model.GetMillis() + 60_000, + } + saved, appErr := th.App.SaveScheduledPost(th.Context, sp, "") + require.Nil(t, appErr) + require.NotNil(t, saved) + assert.Equal(t, "modified-by-plugin", saved.Message) + }) + + t.Run("update rejected", func(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + // First save (no plugin loaded yet so the hook is a no-op). + sp := &model.ScheduledPost{ + Draft: model.Draft{ + UserId: th.BasicUser.Id, + ChannelId: th.BasicChannel.Id, + Message: "original", + }, + ScheduledAt: model.GetMillis() + 60_000, + } + saved, appErr := th.App.SaveScheduledPost(th.Context, sp, "") + require.Nil(t, appErr) + require.NotNil(t, saved) + + tearDown, _, _ := SetAppEnvironmentWithPlugins(t, []string{ + ` + package main + + import ( + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" + ) + + type MyPlugin struct { + plugin.MattermostPlugin + } + + func (p *MyPlugin) ScheduledPostWillBeCreated(c *plugin.Context, scheduledPost *model.ScheduledPost) (*model.ScheduledPost, string) { + return nil, "update not permitted" + } + + func main() { + plugin.ClientMain(&MyPlugin{}) + } + `, + }, th.App, th.NewPluginAPI) + defer tearDown() + + saved.Message = "edited" + _, appErr = th.App.UpdateScheduledPost(th.Context, th.BasicUser.Id, saved, "") + require.NotNil(t, appErr) + assert.Contains(t, appErr.Id, "rejected_by_plugin") + }) +} + +func TestHookDraftWillBeUpserted(t *testing.T) { + mainHelper.Parallel(t) + + t.Run("rejected", func(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + th.Server.platform.SetConfigReadOnlyFF(false) + defer th.Server.platform.SetConfigReadOnlyFF(true) + th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.AllowSyncedDrafts = true }) + + tearDown, _, _ := SetAppEnvironmentWithPlugins(t, []string{ + ` + package main + + import ( + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" + ) + + type MyPlugin struct { + plugin.MattermostPlugin + } + + func (p *MyPlugin) DraftWillBeUpserted(c *plugin.Context, draft *model.Draft) (*model.Draft, string) { + return nil, "draft not permitted" + } + + func main() { + plugin.ClientMain(&MyPlugin{}) + } + `, + }, th.App, th.NewPluginAPI) + defer tearDown() + + draft := &model.Draft{ + UserId: th.BasicUser.Id, + ChannelId: th.BasicChannel.Id, + Message: "draft hi", + } + _, appErr := th.App.UpsertDraft(th.Context, draft, "") + require.NotNil(t, appErr) + assert.Contains(t, appErr.Id, "rejected_by_plugin") + + drafts, getErr := th.App.GetDraftsForUser(th.Context, th.BasicUser.Id, th.BasicTeam.Id) + require.Nil(t, getErr) + assert.Empty(t, drafts) + }) + + t.Run("modified", func(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + th.Server.platform.SetConfigReadOnlyFF(false) + defer th.Server.platform.SetConfigReadOnlyFF(true) + th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.AllowSyncedDrafts = true }) + + tearDown, _, _ := SetAppEnvironmentWithPlugins(t, []string{ + ` + package main + + import ( + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" + ) + + type MyPlugin struct { + plugin.MattermostPlugin + } + + func (p *MyPlugin) DraftWillBeUpserted(c *plugin.Context, draft *model.Draft) (*model.Draft, string) { + draft.Message = "modified-by-plugin" + return draft, "" + } + + func main() { + plugin.ClientMain(&MyPlugin{}) + } + `, + }, th.App, th.NewPluginAPI) + defer tearDown() + + draft := &model.Draft{ + UserId: th.BasicUser.Id, + ChannelId: th.BasicChannel.Id, + Message: "original", + } + saved, appErr := th.App.UpsertDraft(th.Context, draft, "") + require.Nil(t, appErr) + require.NotNil(t, saved) + assert.Equal(t, "modified-by-plugin", saved.Message) + }) + + t.Run("delete-empty does not fire hook", func(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + th.Server.platform.SetConfigReadOnlyFF(false) + defer th.Server.platform.SetConfigReadOnlyFF(true) + th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.AllowSyncedDrafts = true }) + + // Plugin rejects everything; if it fires on the delete path we will see an AppError. + tearDown, _, _ := SetAppEnvironmentWithPlugins(t, []string{ + ` + package main + + import ( + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" + ) + + type MyPlugin struct { + plugin.MattermostPlugin + } + + func (p *MyPlugin) DraftWillBeUpserted(c *plugin.Context, draft *model.Draft) (*model.Draft, string) { + return nil, "should not be called for empty-message delete" + } + + func main() { + plugin.ClientMain(&MyPlugin{}) + } + `, + }, th.App, th.NewPluginAPI) + defer tearDown() + + empty := &model.Draft{ + UserId: th.BasicUser.Id, + ChannelId: th.BasicChannel.Id, + Message: "", + } + _, appErr := th.App.UpsertDraft(th.Context, empty, "") + require.Nil(t, appErr) + }) +} + +func TestHooksNoOpWhenNoPlugin(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + // No plugin loaded — all hooks must be no-ops and the affected app calls must succeed + // (or fail for unrelated reasons). This guards against accidentally turning a no-op + // RunMultiHook into a hard requirement. + + updated := th.BasicChannel.DeepCopy() + updated.DisplayName = "renamed" + _, appErr := th.App.UpdateChannel(th.Context, updated) + require.Nil(t, appErr) + + require.Nil(t, th.App.DeleteChannel(th.Context, th.BasicChannel, th.BasicUser.Id)) + archived, err := th.App.GetChannel(th.Context, th.BasicChannel.Id) + require.Nil(t, err) + _, appErr = th.App.RestoreChannel(th.Context, archived, th.BasicUser.Id) + require.Nil(t, appErr) + + // UpsertDraft exercises the DraftWillBeUpserted hook path with no plugin loaded. + th.Server.platform.SetConfigReadOnlyFF(false) + defer th.Server.platform.SetConfigReadOnlyFF(true) + th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.AllowSyncedDrafts = true }) + draft := &model.Draft{ + UserId: th.BasicUser.Id, + ChannelId: th.BasicChannel.Id, + Message: "no-op draft", + } + _, appErr = th.App.UpsertDraft(th.Context, draft, "") + require.Nil(t, appErr) +} + +func TestChannelGuardBlocksPostWhenPluginInactive(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + // Compile and activate a plugin that implements MessageWillBePosted (allow all posts). + // The guard row is registered directly from the test using App.RegisterChannelGuard so + // the test is not coupled to a particular OnActivate implementation. + tearDown, pluginIDs, errs := SetAppEnvironmentWithPlugins(t, []string{ + ` + package main + + import ( + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" + ) + + type MyPlugin struct { + plugin.MattermostPlugin + } + + func (p *MyPlugin) MessageWillBePosted(c *plugin.Context, post *model.Post) (*model.Post, string) { + return nil, "" + } + + func main() { + plugin.ClientMain(&MyPlugin{}) + } + `, + }, th.App, th.NewPluginAPI) + defer tearDown() + + require.Len(t, errs, 1) + require.NoError(t, errs[0]) + require.Len(t, pluginIDs, 1) + pluginID := pluginIDs[0] + require.True(t, th.App.GetPluginsEnvironment().IsActive(pluginID)) + + // Register a channel guard for BasicChannel under this plugin's ID. + appErr := th.App.RegisterChannelGuard(th.Context, th.BasicChannel.Id, pluginID) + require.Nil(t, appErr, "RegisterChannelGuard must succeed") + + // Subtest (a): plugin active — CreatePost must succeed. + t.Run("plugin active allows post", func(t *testing.T) { + post := &model.Post{ + UserId: th.BasicUser.Id, + ChannelId: th.BasicChannel.Id, + Message: "should be allowed", + } + createdPost, _, appErr := th.App.CreatePost(th.Context, post, th.BasicChannel, model.CreatePostFlags{SetOnline: true}) + require.Nil(t, appErr) + require.NotNil(t, createdPost) + }) + + // Subtest (b): plugin deactivated — CreatePost must return 503 inactive_guard error + // and the post must not be persisted. + t.Run("plugin inactive rejects post", func(t *testing.T) { + require.True(t, th.App.GetPluginsEnvironment().Deactivate(pluginID)) + require.False(t, th.App.GetPluginsEnvironment().IsActive(pluginID)) + + post := &model.Post{ + UserId: th.BasicUser.Id, + ChannelId: th.BasicChannel.Id, + Message: "should be rejected", + } + createdPost, _, appErr := th.App.CreatePost(th.Context, post, th.BasicChannel, model.CreatePostFlags{SetOnline: true}) + require.NotNil(t, appErr, "expected error when guard plugin is inactive") + require.Nil(t, createdPost) + assert.Equal(t, "app.plugin.inactive_guard.app_error", appErr.Id) + assert.Equal(t, 503, appErr.StatusCode) + + // Verify the post was not persisted by fetching recent posts for the channel. + postList, storeErr := th.App.Srv().Store().Post().GetPosts(th.Context, model.GetPostsOptions{ + ChannelId: th.BasicChannel.Id, + Page: 0, + PerPage: 10, + }, false, nil) + require.NoError(t, storeErr) + for _, p := range postList.Posts { + assert.NotEqual(t, "should be rejected", p.Message, "rejected post must not be in the store") + } + }) +} + +func TestChannelGuardBlocksPostUpdateWhenPluginInactive(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + // Compile and activate a plugin that implements MessageWillBeUpdated (allow all updates). + tearDown, pluginIDs, errs := SetAppEnvironmentWithPlugins(t, []string{ + ` + package main + + import ( + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" + ) + + type MyPlugin struct { + plugin.MattermostPlugin + } + + func (p *MyPlugin) MessageWillBeUpdated(c *plugin.Context, newPost *model.Post, oldPost *model.Post) (*model.Post, string) { + return newPost, "" + } + + func main() { + plugin.ClientMain(&MyPlugin{}) + } + `, + }, th.App, th.NewPluginAPI) + defer tearDown() + + require.Len(t, errs, 1) + require.NoError(t, errs[0]) + require.Len(t, pluginIDs, 1) + pluginID := pluginIDs[0] + require.True(t, th.App.GetPluginsEnvironment().IsActive(pluginID)) + + // Register a channel guard for BasicChannel under this plugin's ID. + appErr := th.App.RegisterChannelGuard(th.Context, th.BasicChannel.Id, pluginID) + require.Nil(t, appErr, "RegisterChannelGuard must succeed") + + // Create the initial post that will be updated. + post := &model.Post{ + UserId: th.BasicUser.Id, + ChannelId: th.BasicChannel.Id, + Message: "original message", + } + createdPost, _, appErr := th.App.CreatePost(th.Context, post, th.BasicChannel, model.CreatePostFlags{SetOnline: true}) + require.Nil(t, appErr) + require.NotNil(t, createdPost) + + // Subtest (a): plugin active — UpdatePost must succeed. + t.Run("plugin active allows update", func(t *testing.T) { + updatedPost := createdPost.Clone() + updatedPost.Message = "updated message allowed" + result, _, appErr := th.App.UpdatePost(th.Context, updatedPost, nil) + require.Nil(t, appErr) + require.NotNil(t, result) + }) + + // Subtest (b): plugin deactivated — UpdatePost must return 503 inactive_guard error + // and the post must remain unchanged in the store. + t.Run("plugin inactive rejects update", func(t *testing.T) { + require.True(t, th.App.GetPluginsEnvironment().Deactivate(pluginID)) + require.False(t, th.App.GetPluginsEnvironment().IsActive(pluginID)) + + updatedPost := createdPost.Clone() + updatedPost.Message = "should be rejected" + result, _, appErr := th.App.UpdatePost(th.Context, updatedPost, nil) + require.NotNil(t, appErr, "expected error when guard plugin is inactive") + require.Nil(t, result) + assert.Equal(t, "app.plugin.inactive_guard.app_error", appErr.Id) + assert.Equal(t, 503, appErr.StatusCode) + + // Verify the post was not updated by fetching it from the store. + fetchedPost, storeErr := th.App.GetSinglePost(th.Context, createdPost.Id, false) + require.Nil(t, storeErr) + assert.NotEqual(t, "should be rejected", fetchedPost.Message, "rejected update must not be persisted") + }) +} + +// TestChannelGuardPostUpdateRejectionReasonPreserved locks in the legacy rejection-reason +// shape for UpdatePost. A plugin returning (nil, "blocked-by-policy") must surface as +// AppError with Id "Post rejected by plugin. blocked-by-policy". The unguarded path +// exercises the legacy AppError shape that existing tooling may grep for. +func TestChannelGuardPostUpdateRejectionReasonPreserved(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + tearDown, pluginIDs, errs := SetAppEnvironmentWithPlugins(t, []string{ + ` + package main + + import ( + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" + ) + + type MyPlugin struct { + plugin.MattermostPlugin + } + + func (p *MyPlugin) MessageWillBeUpdated(c *plugin.Context, newPost *model.Post, oldPost *model.Post) (*model.Post, string) { + return nil, "blocked-by-policy" + } + + func main() { + plugin.ClientMain(&MyPlugin{}) + } + `, + }, th.App, th.NewPluginAPI) + defer tearDown() + + require.Len(t, errs, 1) + require.NoError(t, errs[0]) + require.Len(t, pluginIDs, 1) + require.True(t, th.App.GetPluginsEnvironment().IsActive(pluginIDs[0])) + + // Create the initial post that will be updated. + post := &model.Post{ + UserId: th.BasicUser.Id, + ChannelId: th.BasicChannel.Id, + Message: "original message", + } + createdPost, _, appErr := th.App.CreatePost(th.Context, post, th.BasicChannel, model.CreatePostFlags{SetOnline: true}) + require.Nil(t, appErr) + require.NotNil(t, createdPost) + + // Unguarded path — no guard registered. The plugin returns (nil, "blocked-by-policy") and the + // rejection error must include the reason verbatim. + updatedPost := createdPost.Clone() + updatedPost.Message = "unguarded rejection" + result, _, appErr := th.App.UpdatePost(th.Context, updatedPost, nil) + require.NotNil(t, appErr, "expected rejection from plugin") + require.Nil(t, result) + assert.Equal(t, "Post rejected by plugin. blocked-by-policy", appErr.Id) +} + +func TestChannelGuardBlocksMemberAddWhenPluginInactive(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + // Compile and activate a plugin that implements ChannelMemberWillBeAdded (allow all). + tearDown, pluginIDs, errs := SetAppEnvironmentWithPlugins(t, []string{ + ` + package main + + import ( + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" + ) + + type MyPlugin struct { + plugin.MattermostPlugin + } + + func (p *MyPlugin) ChannelMemberWillBeAdded(c *plugin.Context, member *model.ChannelMember) (*model.ChannelMember, string) { + return member, "" + } + + func main() { + plugin.ClientMain(&MyPlugin{}) + } + `, + }, th.App, th.NewPluginAPI) + defer tearDown() + + require.Len(t, errs, 1) + require.NoError(t, errs[0]) + require.Len(t, pluginIDs, 1) + pluginID := pluginIDs[0] + require.True(t, th.App.GetPluginsEnvironment().IsActive(pluginID)) + + // Create a private channel to test member addition. + privateChannel := th.CreatePrivateChannel(t, th.BasicTeam) + + // Register a channel guard for this channel under this plugin's ID. + appErr := th.App.RegisterChannelGuard(th.Context, privateChannel.Id, pluginID) + require.Nil(t, appErr, "RegisterChannelGuard must succeed") + + // Subtest (a): plugin active — AddUserToChannel must succeed. + t.Run("plugin active allows member add", func(t *testing.T) { + _, appErr := th.App.AddUserToChannel(th.Context, th.BasicUser2, privateChannel, false) + // May already be a member from setup; either success or "already a member" is OK. + if appErr != nil { + assert.NotEqual(t, "app.plugin.inactive_guard.app_error", appErr.Id, "must not be a guard error when plugin is active") + } + }) + + // Subtest (b): plugin deactivated — AddUserToChannel must return 503 inactive_guard error. + t.Run("plugin inactive rejects member add", func(t *testing.T) { + require.True(t, th.App.GetPluginsEnvironment().Deactivate(pluginID)) + require.False(t, th.App.GetPluginsEnvironment().IsActive(pluginID)) + + // Use a new user who is definitely not yet a member; add them to the team first. + newUser := th.CreateUser(t) + _, _, teamErr := th.App.AddUserToTeam(th.Context, th.BasicTeam.Id, newUser.Id, "") + require.Nil(t, teamErr) + _, appErr := th.App.AddUserToChannel(th.Context, newUser, privateChannel, false) + require.NotNil(t, appErr, "expected error when guard plugin is inactive") + assert.Equal(t, "app.plugin.inactive_guard.app_error", appErr.Id) + assert.Equal(t, 503, appErr.StatusCode) + + // Verify the user was not added. + _, memberErr := th.App.GetChannelMember(th.Context, privateChannel.Id, newUser.Id) + require.NotNil(t, memberErr, "user must not be a member of the channel") + }) +} + +func TestChannelGuardBlocksChannelUpdateWhenPluginInactive(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + // Compile and activate a plugin that implements ChannelWillBeUpdated (allow all updates). + tearDown, pluginIDs, errs := SetAppEnvironmentWithPlugins(t, []string{ + ` + package main + + import ( + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" + ) + + type MyPlugin struct { + plugin.MattermostPlugin + } + + func (p *MyPlugin) ChannelWillBeUpdated(c *plugin.Context, newChannel *model.Channel, oldChannel *model.Channel) (*model.Channel, string) { + return newChannel, "" + } + + func main() { + plugin.ClientMain(&MyPlugin{}) + } + `, + }, th.App, th.NewPluginAPI) + defer tearDown() + + require.Len(t, errs, 1) + require.NoError(t, errs[0]) + require.Len(t, pluginIDs, 1) + pluginID := pluginIDs[0] + require.True(t, th.App.GetPluginsEnvironment().IsActive(pluginID)) + + // Register a channel guard for BasicChannel under this plugin's ID. + appErr := th.App.RegisterChannelGuard(th.Context, th.BasicChannel.Id, pluginID) + require.Nil(t, appErr, "RegisterChannelGuard must succeed") + + // Subtest (a): plugin active — UpdateChannel must succeed. + t.Run("plugin active allows update", func(t *testing.T) { + channelToUpdate := th.BasicChannel.DeepCopy() + channelToUpdate.DisplayName = "Updated Name Allowed" + result, appErr := th.App.UpdateChannel(th.Context, channelToUpdate) + require.Nil(t, appErr) + require.NotNil(t, result) + }) + + // Subtest (b): plugin deactivated — UpdateChannel must return 503 inactive_guard error + // and the channel must remain unchanged in the store. + t.Run("plugin inactive rejects update", func(t *testing.T) { + require.True(t, th.App.GetPluginsEnvironment().Deactivate(pluginID)) + require.False(t, th.App.GetPluginsEnvironment().IsActive(pluginID)) + + channelToUpdate := th.BasicChannel.DeepCopy() + channelToUpdate.DisplayName = "Should Be Rejected" + result, appErr := th.App.UpdateChannel(th.Context, channelToUpdate) + require.NotNil(t, appErr, "expected error when guard plugin is inactive") + require.Nil(t, result) + assert.Equal(t, "app.plugin.inactive_guard.app_error", appErr.Id) + assert.Equal(t, 503, appErr.StatusCode) + + // Verify the channel was not updated. + fetched, storeErr := th.App.GetChannel(th.Context, th.BasicChannel.Id) + require.Nil(t, storeErr) + assert.NotEqual(t, "Should Be Rejected", fetched.DisplayName, "rejected update must not be persisted") + }) +} + +func TestChannelGuardRejectsTypeMutationFromPlugin(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + // Compile and activate a plugin that flips the channel Type in its replacement. + tearDown, pluginIDs, errs := SetAppEnvironmentWithPlugins(t, []string{ + ` + package main + + import ( + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" + ) + + type MyPlugin struct { + plugin.MattermostPlugin + } + + func (p *MyPlugin) ChannelWillBeUpdated(c *plugin.Context, newChannel *model.Channel, oldChannel *model.Channel) (*model.Channel, string) { + mutated := newChannel.DeepCopy() + // Flip Open <-> Private. + if mutated.Type == model.ChannelTypeOpen { + mutated.Type = model.ChannelTypePrivate + } else { + mutated.Type = model.ChannelTypeOpen + } + return mutated, "" + } + + func main() { + plugin.ClientMain(&MyPlugin{}) + } + `, + }, th.App, th.NewPluginAPI) + defer tearDown() + + require.Len(t, errs, 1) + require.NoError(t, errs[0]) + require.Len(t, pluginIDs, 1) + pluginID := pluginIDs[0] + require.True(t, th.App.GetPluginsEnvironment().IsActive(pluginID)) + + // Register a channel guard so this goes through the guarded path. + appErr := th.App.RegisterChannelGuard(th.Context, th.BasicChannel.Id, pluginID) + require.Nil(t, appErr, "RegisterChannelGuard must succeed") + + originalType := th.BasicChannel.Type + + channelToUpdate := th.BasicChannel.DeepCopy() + channelToUpdate.DisplayName = "Type Mutation Attempt" + result, appErr := th.App.UpdateChannel(th.Context, channelToUpdate) + require.NotNil(t, appErr, "expected type-mutation error") + require.Nil(t, result) + assert.Equal(t, "app.channel.update_channel.plugin_type_mutation.app_error", appErr.Id) + assert.Equal(t, 400, appErr.StatusCode) + // The error string must include the offending plugin ID (from the i18n template). + assert.Contains(t, appErr.Error(), pluginID) + + // Verify the channel type was not changed. + fetched, storeErr := th.App.GetChannel(th.Context, th.BasicChannel.Id) + require.Nil(t, storeErr) + assert.Equal(t, originalType, fetched.Type, "type must not be mutated by plugin replacement") +} + +func TestChannelGuardAllowsNonTypeMutationFromPlugin(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + // Compile and activate a plugin that modifies DisplayName but not Type. + tearDown, pluginIDs, errs := SetAppEnvironmentWithPlugins(t, []string{ + ` + package main + + import ( + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" + ) + + type MyPlugin struct { + plugin.MattermostPlugin + } + + func (p *MyPlugin) ChannelWillBeUpdated(c *plugin.Context, newChannel *model.Channel, oldChannel *model.Channel) (*model.Channel, string) { + modified := newChannel.DeepCopy() + modified.DisplayName = "plugin-modified-name" + return modified, "" + } + + func main() { + plugin.ClientMain(&MyPlugin{}) + } + `, + }, th.App, th.NewPluginAPI) + defer tearDown() + + require.Len(t, errs, 1) + require.NoError(t, errs[0]) + require.Len(t, pluginIDs, 1) + pluginID := pluginIDs[0] + require.True(t, th.App.GetPluginsEnvironment().IsActive(pluginID)) + + // Register a channel guard so this goes through the guarded path. + appErr := th.App.RegisterChannelGuard(th.Context, th.BasicChannel.Id, pluginID) + require.Nil(t, appErr, "RegisterChannelGuard must succeed") + + channelToUpdate := th.BasicChannel.DeepCopy() + channelToUpdate.DisplayName = "Original Caller Name" + result, appErr := th.App.UpdateChannel(th.Context, channelToUpdate) + require.Nil(t, appErr, "non-type-mutation replacement must succeed") + require.NotNil(t, result) + + // Verify the DB has the plugin-modified DisplayName. + fetched, storeErr := th.App.GetChannel(th.Context, th.BasicChannel.Id) + require.Nil(t, storeErr) + assert.Equal(t, "plugin-modified-name", fetched.DisplayName, "plugin DisplayName replacement must be persisted") +} + +// Guard blocks RestoreChannel when the guard plugin is inactive. +func TestChannelGuardBlocksRestoreWhenPluginInactive(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + // Compile and activate a plugin that implements ChannelWillBeRestored (allow all). + tearDown, pluginIDs, errs := SetAppEnvironmentWithPlugins(t, []string{ + ` + package main + + import ( + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" + ) + + type MyPlugin struct { + plugin.MattermostPlugin + } + + func (p *MyPlugin) ChannelWillBeRestored(c *plugin.Context, channel *model.Channel) string { + return "" + } + + func main() { + plugin.ClientMain(&MyPlugin{}) + } + `, + }, th.App, th.NewPluginAPI) + defer tearDown() + + require.Len(t, errs, 1) + require.NoError(t, errs[0]) + require.Len(t, pluginIDs, 1) + pluginID := pluginIDs[0] + require.True(t, th.App.GetPluginsEnvironment().IsActive(pluginID)) + + // Archive BasicChannel so RestoreChannel has something to do. + require.Nil(t, th.App.DeleteChannel(th.Context, th.BasicChannel, th.BasicUser.Id)) + + // Register a channel guard for this channel under this plugin's ID. + appErr := th.App.RegisterChannelGuard(th.Context, th.BasicChannel.Id, pluginID) + require.Nil(t, appErr, "RegisterChannelGuard must succeed") + + // Subtest (a): plugin active — RestoreChannel must succeed. + t.Run("plugin active allows restore", func(t *testing.T) { + archived, err := th.App.GetChannel(th.Context, th.BasicChannel.Id) + require.Nil(t, err) + require.NotEqual(t, int64(0), archived.DeleteAt, "channel must be archived before restore") + + _, appErr := th.App.RestoreChannel(th.Context, archived, th.BasicUser.Id) + require.Nil(t, appErr, "expected no error when guard plugin is active") + + // Re-archive for the next subtest. + require.Nil(t, th.App.DeleteChannel(th.Context, th.BasicChannel, th.BasicUser.Id)) + }) + + // Subtest (b): plugin deactivated — RestoreChannel must return 503 inactive_guard error + // and the channel must remain archived. + t.Run("plugin inactive rejects restore", func(t *testing.T) { + require.True(t, th.App.GetPluginsEnvironment().Deactivate(pluginID)) + require.False(t, th.App.GetPluginsEnvironment().IsActive(pluginID)) + + archived, err := th.App.GetChannel(th.Context, th.BasicChannel.Id) + require.Nil(t, err) + require.NotEqual(t, int64(0), archived.DeleteAt, "channel must be archived for this subtest") + + result, appErr := th.App.RestoreChannel(th.Context, archived, th.BasicUser.Id) + require.NotNil(t, appErr, "expected error when guard plugin is inactive") + require.Nil(t, result) + assert.Equal(t, "app.plugin.inactive_guard.app_error", appErr.Id) + assert.Equal(t, 503, appErr.StatusCode) + + // Verify the channel was not restored (still archived). + fetched, storeErr := th.App.GetChannel(th.Context, th.BasicChannel.Id) + require.Nil(t, storeErr) + assert.NotEqual(t, int64(0), fetched.DeleteAt, "rejected restore must not change DeleteAt") + }) +} + +// --------------------------------------------------------------------------- +// Cross-cutting e2e tests for channel-guard dispatch +// --------------------------------------------------------------------------- + +// TestChannelGuardWrapperRejectsOnHookRPCError verifies that when a guard plugin's hook +// implementation panics (which net/rpc recovers and returns as a non-nil error from +// client.Call), the guarded site returns 503 app.plugin.guard_hook_failed.app_error. +// +// The first sub-test is a panic-discovery smoke test that proves the mechanism works before +// relying on it for all five sites. The remaining sub-tests cover each guarded site. +// +// Each sub-test also verifies that an unguarded channel with the same panicking plugin still +// succeeds (existing fail-open RunMultiHook swallows RPC errors per long-standing contract). +func TestChannelGuardWrapperRejectsOnHookRPCError(t *testing.T) { + mainHelper.Parallel(t) + + // panicAllPlugin is a single compiled plugin that panics in all five guarded hooks. + // One plugin, one compile — reused across every sub-test. + const panicAllPlugin = ` +package main + +import ( + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" +) + +type PanicPlugin struct { + plugin.MattermostPlugin +} + +func (p *PanicPlugin) MessageWillBePosted(c *plugin.Context, post *model.Post) (*model.Post, string) { + panic("forced RPC error") +} + +func (p *PanicPlugin) MessageWillBeUpdated(c *plugin.Context, newPost *model.Post, oldPost *model.Post) (*model.Post, string) { + panic("forced RPC error") +} + +func (p *PanicPlugin) ChannelMemberWillBeAdded(c *plugin.Context, member *model.ChannelMember) (*model.ChannelMember, string) { + panic("forced RPC error") +} + +func (p *PanicPlugin) ChannelWillBeUpdated(c *plugin.Context, newChannel *model.Channel, oldChannel *model.Channel) (*model.Channel, string) { + panic("forced RPC error") +} + +func (p *PanicPlugin) ChannelWillBeRestored(c *plugin.Context, channel *model.Channel) string { + panic("forced RPC error") +} + +func main() { + plugin.ClientMain(&PanicPlugin{}) +} +` + + // One sub-test per guarded site. Each registers the panicking guard plugin on a + // channel and asserts the guard wrapper returns 503 (Phase B fail-closed). Each also + // verifies the unguarded path with the same plugin returns no error (fail-open + // preservation for non-guarded callers). + + t.Run("MessageWillBePosted", func(t *testing.T) { + th := Setup(t).InitBasic(t) + + tearDown, pluginIDs, errs := SetAppEnvironmentWithPlugins(t, []string{panicAllPlugin}, th.App, th.NewPluginAPI) + defer tearDown() + + require.Len(t, errs, 1) + require.NoError(t, errs[0]) + pluginID := pluginIDs[0] + + guardedCh := th.CreateChannel(t, th.BasicTeam) + appErr := th.App.RegisterChannelGuard(th.Context, guardedCh.Id, pluginID) + require.Nil(t, appErr) + + _, _, appErr = th.App.CreatePost(th.Context, &model.Post{ + UserId: th.BasicUser.Id, + ChannelId: guardedCh.Id, + Message: "msg", + }, guardedCh, model.CreatePostFlags{SetOnline: true}) + require.NotNil(t, appErr) + assert.Equal(t, "app.plugin.guard_hook_failed.app_error", appErr.Id) + assert.Equal(t, 503, appErr.StatusCode) + + // Unguarded: fail-open. + unguardedCh := th.CreateChannel(t, th.BasicTeam) + _, _, appErr2 := th.App.CreatePost(th.Context, &model.Post{ + UserId: th.BasicUser.Id, + ChannelId: unguardedCh.Id, + Message: "unguarded", + }, unguardedCh, model.CreatePostFlags{SetOnline: true}) + require.Nil(t, appErr2) + }) + + t.Run("MessageWillBeUpdated", func(t *testing.T) { + th := Setup(t).InitBasic(t) + + tearDown, pluginIDs, errs := SetAppEnvironmentWithPlugins(t, []string{panicAllPlugin}, th.App, th.NewPluginAPI) + defer tearDown() + + require.Len(t, errs, 1) + require.NoError(t, errs[0]) + pluginID := pluginIDs[0] + + guardedCh := th.CreateChannel(t, th.BasicTeam) + appErr := th.App.RegisterChannelGuard(th.Context, guardedCh.Id, pluginID) + require.Nil(t, appErr) + + // Create a post to update (without the panicking plugin active on this channel yet). + // Create the initial post on BasicChannel (no guard) to avoid the guard. + initialPost := &model.Post{ + UserId: th.BasicUser.Id, + ChannelId: guardedCh.Id, + Message: "original", + } + // To create the initial post we need to temporarily bypass the guard. + // Remove guard, create post, re-add guard. + require.Nil(t, th.App.UnregisterChannelGuard(th.Context, guardedCh.Id, pluginID)) + created, _, err := th.App.CreatePost(th.Context, initialPost, guardedCh, model.CreatePostFlags{SetOnline: true}) + require.Nil(t, err) + require.Nil(t, th.App.RegisterChannelGuard(th.Context, guardedCh.Id, pluginID)) + + updated := created.Clone() + updated.Message = "updated" + _, _, appErr = th.App.UpdatePost(th.Context, updated, nil) + require.NotNil(t, appErr) + assert.Equal(t, "app.plugin.guard_hook_failed.app_error", appErr.Id) + assert.Equal(t, 503, appErr.StatusCode) + + // Unguarded: fail-open. + unguardedCh := th.CreateChannel(t, th.BasicTeam) + initial2 := &model.Post{ + UserId: th.BasicUser.Id, + ChannelId: unguardedCh.Id, + Message: "initial2", + } + created2, _, err2 := th.App.CreatePost(th.Context, initial2, unguardedCh, model.CreatePostFlags{SetOnline: true}) + require.Nil(t, err2) + updated2 := created2.Clone() + updated2.Message = "updated2" + _, _, appErr2 := th.App.UpdatePost(th.Context, updated2, nil) + require.Nil(t, appErr2) + }) + + t.Run("ChannelMemberWillBeAdded", func(t *testing.T) { + th := Setup(t).InitBasic(t) + + tearDown, pluginIDs, errs := SetAppEnvironmentWithPlugins(t, []string{panicAllPlugin}, th.App, th.NewPluginAPI) + defer tearDown() + + require.Len(t, errs, 1) + require.NoError(t, errs[0]) + pluginID := pluginIDs[0] + + guardedCh := th.CreatePrivateChannel(t, th.BasicTeam) + appErr := th.App.RegisterChannelGuard(th.Context, guardedCh.Id, pluginID) + require.Nil(t, appErr) + + newUser := th.CreateUser(t) + _, _, teamErr := th.App.AddUserToTeam(th.Context, th.BasicTeam.Id, newUser.Id, "") + require.Nil(t, teamErr) + + _, appErr = th.App.AddUserToChannel(th.Context, newUser, guardedCh, false) + require.NotNil(t, appErr) + assert.Equal(t, "app.plugin.guard_hook_failed.app_error", appErr.Id) + assert.Equal(t, 503, appErr.StatusCode) + + // Unguarded: fail-open. + unguardedCh := th.CreatePrivateChannel(t, th.BasicTeam) + newUser2 := th.CreateUser(t) + _, _, teamErr2 := th.App.AddUserToTeam(th.Context, th.BasicTeam.Id, newUser2.Id, "") + require.Nil(t, teamErr2) + _, appErr2 := th.App.AddUserToChannel(th.Context, newUser2, unguardedCh, false) + require.Nil(t, appErr2) + }) + + t.Run("ChannelWillBeUpdated", func(t *testing.T) { + th := Setup(t).InitBasic(t) + + tearDown, pluginIDs, errs := SetAppEnvironmentWithPlugins(t, []string{panicAllPlugin}, th.App, th.NewPluginAPI) + defer tearDown() + + require.Len(t, errs, 1) + require.NoError(t, errs[0]) + pluginID := pluginIDs[0] + + guardedCh := th.CreateChannel(t, th.BasicTeam) + appErr := th.App.RegisterChannelGuard(th.Context, guardedCh.Id, pluginID) + require.Nil(t, appErr) + + ch := guardedCh.DeepCopy() + ch.DisplayName = "Panic Test" + _, appErr = th.App.UpdateChannel(th.Context, ch) + require.NotNil(t, appErr) + assert.Equal(t, "app.plugin.guard_hook_failed.app_error", appErr.Id) + assert.Equal(t, 503, appErr.StatusCode) + + // Unguarded: fail-open. + unguardedCh := th.CreateChannel(t, th.BasicTeam) + ch2 := unguardedCh.DeepCopy() + ch2.DisplayName = "Unguarded Update" + _, appErr2 := th.App.UpdateChannel(th.Context, ch2) + require.Nil(t, appErr2) + }) + + t.Run("ChannelWillBeRestored", func(t *testing.T) { + th := Setup(t).InitBasic(t) + + tearDown, pluginIDs, errs := SetAppEnvironmentWithPlugins(t, []string{panicAllPlugin}, th.App, th.NewPluginAPI) + defer tearDown() + + require.Len(t, errs, 1) + require.NoError(t, errs[0]) + pluginID := pluginIDs[0] + + guardedCh := th.CreateChannel(t, th.BasicTeam) + require.Nil(t, th.App.DeleteChannel(th.Context, guardedCh, th.BasicUser.Id)) + appErr := th.App.RegisterChannelGuard(th.Context, guardedCh.Id, pluginID) + require.Nil(t, appErr) + + archived, err := th.App.GetChannel(th.Context, guardedCh.Id) + require.Nil(t, err) + _, appErr = th.App.RestoreChannel(th.Context, archived, th.BasicUser.Id) + require.NotNil(t, appErr) + assert.Equal(t, "app.plugin.guard_hook_failed.app_error", appErr.Id) + assert.Equal(t, 503, appErr.StatusCode) + + // Unguarded: fail-open. + unguardedCh := th.CreateChannel(t, th.BasicTeam) + require.Nil(t, th.App.DeleteChannel(th.Context, unguardedCh, th.BasicUser.Id)) + archived2, err2 := th.App.GetChannel(th.Context, unguardedCh.Id) + require.Nil(t, err2) + _, appErr2 := th.App.RestoreChannel(th.Context, archived2, th.BasicUser.Id) + require.Nil(t, appErr2) + }) +} + +// TestChannelGuardAllowsAllOpsWhenPluginActiveNoRejection registers a guard whose plugin +// allows every hook and exercises all five guarded sites to confirm no regression. +func TestChannelGuardAllowsAllOpsWhenPluginActiveNoRejection(t *testing.T) { + mainHelper.Parallel(t) + + th := Setup(t).InitBasic(t) + + const allowAllPlugin = ` +package main + +import ( + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" +) + +type AllowPlugin struct { + plugin.MattermostPlugin +} + +func (p *AllowPlugin) MessageWillBePosted(c *plugin.Context, post *model.Post) (*model.Post, string) { + return nil, "" +} + +func (p *AllowPlugin) MessageWillBeUpdated(c *plugin.Context, newPost *model.Post, oldPost *model.Post) (*model.Post, string) { + return newPost, "" +} + +func (p *AllowPlugin) ChannelMemberWillBeAdded(c *plugin.Context, member *model.ChannelMember) (*model.ChannelMember, string) { + return member, "" +} + +func (p *AllowPlugin) ChannelWillBeUpdated(c *plugin.Context, newChannel *model.Channel, oldChannel *model.Channel) (*model.Channel, string) { + return newChannel, "" +} + +func (p *AllowPlugin) ChannelWillBeRestored(c *plugin.Context, channel *model.Channel) string { + return "" +} + +func main() { + plugin.ClientMain(&AllowPlugin{}) +} +` + + tearDown, pluginIDs, errs := SetAppEnvironmentWithPlugins(t, []string{allowAllPlugin}, th.App, th.NewPluginAPI) + defer tearDown() + + require.Len(t, errs, 1) + require.NoError(t, errs[0]) + pluginID := pluginIDs[0] + require.True(t, th.App.GetPluginsEnvironment().IsActive(pluginID)) + + // All five sites share the same channel so one guard covers all. + ch := th.BasicChannel + require.Nil(t, th.App.RegisterChannelGuard(th.Context, ch.Id, pluginID)) + + // Site 1: MessageWillBePosted (CreatePost). + t.Run("MessageWillBePosted", func(t *testing.T) { + _, _, appErr := th.App.CreatePost(th.Context, &model.Post{ + UserId: th.BasicUser.Id, + ChannelId: ch.Id, + Message: "allow all test", + }, ch, model.CreatePostFlags{SetOnline: true}) + require.Nil(t, appErr) + }) + + // Site 2: MessageWillBeUpdated (UpdatePost). Create a post first on BasicChannel (no guard + // conflict — guard already registered, plugin allows). + var createdPost *model.Post + t.Run("MessageWillBeUpdated_setup", func(t *testing.T) { + var appErr *model.AppError + createdPost, _, appErr = th.App.CreatePost(th.Context, &model.Post{ + UserId: th.BasicUser.Id, + ChannelId: ch.Id, + Message: "original for update", + }, ch, model.CreatePostFlags{SetOnline: true}) + require.Nil(t, appErr) + }) + + t.Run("MessageWillBeUpdated", func(t *testing.T) { + require.NotNil(t, createdPost) + up := createdPost.Clone() + up.Message = "updated by allow-all guard" + _, _, appErr := th.App.UpdatePost(th.Context, up, nil) + require.Nil(t, appErr) + }) + + // Site 3: ChannelMemberWillBeAdded. Use a fresh user to guarantee AddUserToChannel + // reaches the hook (existing-membership early-return would silently skip it). + t.Run("ChannelMemberWillBeAdded", func(t *testing.T) { + newUser := th.CreateUser(t) + _, _, teamErr := th.App.AddUserToTeam(th.Context, th.BasicTeam.Id, newUser.Id, "") + require.Nil(t, teamErr) + _, appErr := th.App.AddUserToChannel(th.Context, newUser, ch, false) + require.Nil(t, appErr) + }) + + // Site 4: ChannelWillBeUpdated. + t.Run("ChannelWillBeUpdated", func(t *testing.T) { + update := ch.DeepCopy() + update.DisplayName = "Allow-All Guard Test" + result, appErr := th.App.UpdateChannel(th.Context, update) + require.Nil(t, appErr) + require.NotNil(t, result) + }) + + // Site 5: ChannelWillBeRestored. Archive then restore. + t.Run("ChannelWillBeRestored", func(t *testing.T) { + restoreCh := th.CreateChannel(t, th.BasicTeam) + require.Nil(t, th.App.RegisterChannelGuard(th.Context, restoreCh.Id, pluginID)) + require.Nil(t, th.App.DeleteChannel(th.Context, restoreCh, th.BasicUser.Id)) + archived, err := th.App.GetChannel(th.Context, restoreCh.Id) + require.Nil(t, err) + _, appErr := th.App.RestoreChannel(th.Context, archived, th.BasicUser.Id) + require.Nil(t, appErr) + }) +} + +// TestChannelGuardFiresHookWhenPluginActive confirms that for each of the five guarded sites, +// when a guard plugin's hook returns a rejection, the rejection comes from the hook (not from +// the guard inactive pre-check). The error reason matches the plugin-returned string. +func TestChannelGuardFiresHookWhenPluginActive(t *testing.T) { + mainHelper.Parallel(t) + + const rejectPlugin = ` +package main + +import ( + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" +) + +type RejectPlugin struct { + plugin.MattermostPlugin +} + +func (p *RejectPlugin) MessageWillBePosted(c *plugin.Context, post *model.Post) (*model.Post, string) { + return nil, "guard-rejected-post" +} + +func (p *RejectPlugin) MessageWillBeUpdated(c *plugin.Context, newPost *model.Post, oldPost *model.Post) (*model.Post, string) { + return nil, "guard-rejected-update" +} + +func (p *RejectPlugin) ChannelMemberWillBeAdded(c *plugin.Context, member *model.ChannelMember) (*model.ChannelMember, string) { + return nil, "guard-rejected-member" +} + +func (p *RejectPlugin) ChannelWillBeUpdated(c *plugin.Context, newChannel *model.Channel, oldChannel *model.Channel) (*model.Channel, string) { + return nil, "guard-rejected-channel-update" +} + +func (p *RejectPlugin) ChannelWillBeRestored(c *plugin.Context, channel *model.Channel) string { + return "guard-rejected-restore" +} + +func main() { + plugin.ClientMain(&RejectPlugin{}) +} +` + + t.Run("MessageWillBePosted", func(t *testing.T) { + th := Setup(t).InitBasic(t) + tearDown, pluginIDs, errs := SetAppEnvironmentWithPlugins(t, []string{rejectPlugin}, th.App, th.NewPluginAPI) + defer tearDown() + require.Len(t, errs, 1) + require.NoError(t, errs[0]) + pluginID := pluginIDs[0] + require.True(t, th.App.GetPluginsEnvironment().IsActive(pluginID)) + + ch := th.BasicChannel + require.Nil(t, th.App.RegisterChannelGuard(th.Context, ch.Id, pluginID)) + + _, _, appErr := th.App.CreatePost(th.Context, &model.Post{ + UserId: th.BasicUser.Id, + ChannelId: ch.Id, + Message: "msg", + }, ch, model.CreatePostFlags{SetOnline: true}) + require.NotNil(t, appErr, "plugin rejection must return error") + // The error comes from the hook (plugin active) — Id must contain the rejection reason. + assert.NotEqual(t, "app.plugin.inactive_guard.app_error", appErr.Id, "must not be inactive-guard error") + assert.Contains(t, appErr.Id, "guard-rejected-post") + }) + + t.Run("MessageWillBeUpdated", func(t *testing.T) { + th := Setup(t).InitBasic(t) + + // Create a post BEFORE activating the reject plugin (the plugin also rejects + // MessageWillBePosted, so CreatePost would fail if the plugin were active). + initialPost, _, err := th.App.CreatePost(th.Context, &model.Post{ + UserId: th.BasicUser.Id, + ChannelId: th.BasicChannel.Id, + Message: "original", + }, th.BasicChannel, model.CreatePostFlags{SetOnline: true}) + require.Nil(t, err) + + tearDown, pluginIDs, errs := SetAppEnvironmentWithPlugins(t, []string{rejectPlugin}, th.App, th.NewPluginAPI) + defer tearDown() + require.Len(t, errs, 1) + require.NoError(t, errs[0]) + pluginID := pluginIDs[0] + require.True(t, th.App.GetPluginsEnvironment().IsActive(pluginID)) + + require.Nil(t, th.App.RegisterChannelGuard(th.Context, th.BasicChannel.Id, pluginID)) + + updated := initialPost.Clone() + updated.Message = "attempt" + _, _, appErr := th.App.UpdatePost(th.Context, updated, nil) + require.NotNil(t, appErr) + assert.NotEqual(t, "app.plugin.inactive_guard.app_error", appErr.Id) + assert.Contains(t, appErr.Id, "guard-rejected-update") + }) + + t.Run("ChannelMemberWillBeAdded", func(t *testing.T) { + th := Setup(t).InitBasic(t) + tearDown, pluginIDs, errs := SetAppEnvironmentWithPlugins(t, []string{rejectPlugin}, th.App, th.NewPluginAPI) + defer tearDown() + require.Len(t, errs, 1) + require.NoError(t, errs[0]) + pluginID := pluginIDs[0] + require.True(t, th.App.GetPluginsEnvironment().IsActive(pluginID)) + + ch := th.CreatePrivateChannel(t, th.BasicTeam) + require.Nil(t, th.App.RegisterChannelGuard(th.Context, ch.Id, pluginID)) + + newUser := th.CreateUser(t) + _, _, teamErr := th.App.AddUserToTeam(th.Context, th.BasicTeam.Id, newUser.Id, "") + require.Nil(t, teamErr) + + _, appErr := th.App.AddUserToChannel(th.Context, newUser, ch, false) + require.NotNil(t, appErr) + assert.NotEqual(t, "app.plugin.inactive_guard.app_error", appErr.Id) + // ChannelMemberWillBeAdded rejection wraps the reason via app.channel.add_user.to.channel.rejected_by_plugin + assert.Equal(t, "app.channel.add_user.to.channel.rejected_by_plugin", appErr.Id) + }) + + t.Run("ChannelWillBeUpdated", func(t *testing.T) { + th := Setup(t).InitBasic(t) + tearDown, pluginIDs, errs := SetAppEnvironmentWithPlugins(t, []string{rejectPlugin}, th.App, th.NewPluginAPI) + defer tearDown() + require.Len(t, errs, 1) + require.NoError(t, errs[0]) + pluginID := pluginIDs[0] + require.True(t, th.App.GetPluginsEnvironment().IsActive(pluginID)) + + ch := th.BasicChannel + require.Nil(t, th.App.RegisterChannelGuard(th.Context, ch.Id, pluginID)) + + update := ch.DeepCopy() + update.DisplayName = "Rejected" + _, appErr := th.App.UpdateChannel(th.Context, update) + require.NotNil(t, appErr) + assert.NotEqual(t, "app.plugin.inactive_guard.app_error", appErr.Id) + assert.Equal(t, "app.channel.update_channel.rejected_by_plugin", appErr.Id) + }) + + t.Run("ChannelWillBeRestored", func(t *testing.T) { + th := Setup(t).InitBasic(t) + tearDown, pluginIDs, errs := SetAppEnvironmentWithPlugins(t, []string{rejectPlugin}, th.App, th.NewPluginAPI) + defer tearDown() + require.Len(t, errs, 1) + require.NoError(t, errs[0]) + pluginID := pluginIDs[0] + require.True(t, th.App.GetPluginsEnvironment().IsActive(pluginID)) + + ch := th.CreateChannel(t, th.BasicTeam) + require.Nil(t, th.App.DeleteChannel(th.Context, ch, th.BasicUser.Id)) + require.Nil(t, th.App.RegisterChannelGuard(th.Context, ch.Id, pluginID)) + + archived, err := th.App.GetChannel(th.Context, ch.Id) + require.Nil(t, err) + _, appErr := th.App.RestoreChannel(th.Context, archived, th.BasicUser.Id) + require.NotNil(t, appErr) + assert.NotEqual(t, "app.plugin.inactive_guard.app_error", appErr.Id) + assert.Equal(t, "app.channel.restore_channel.rejected_by_plugin", appErr.Id) + }) +} + +// TestChannelGuardTwoPhaseDispatchOrdering installs two plugins: a guard plugin G and a +// non-guard plugin N. N uppercases the message in Phase A; G sees the uppercased message in +// Phase B. When N rejects, Phase B is not invoked. +func TestChannelGuardTwoPhaseDispatchOrdering(t *testing.T) { + mainHelper.Parallel(t) + + // Guard plugin G: allow everything; records the message it received. + // The destination file path is baked into the source at compile time so the + // plugin doesn't need to read it from the environment — process-global env + // mutation is incompatible with t.Parallel(). + makeGuardSrc := func(receivedFile string) string { + return fmt.Sprintf(` +package main + +import ( + "os" + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" +) + +type GuardPlugin struct { + plugin.MattermostPlugin +} + +func (p *GuardPlugin) MessageWillBePosted(c *plugin.Context, post *model.Post) (*model.Post, string) { + _ = os.WriteFile(%q, []byte(post.Message), 0644) + return nil, "" +} + +func main() { + plugin.ClientMain(&GuardPlugin{}) +} +`, receivedFile) + } + + // Non-guard plugin N: uppercases the message. + const srcN = ` +package main + +import ( + "strings" + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" +) + +type NPlugin struct { + plugin.MattermostPlugin +} + +func (p *NPlugin) MessageWillBePosted(c *plugin.Context, post *model.Post) (*model.Post, string) { + modified := post.Clone() + modified.Message = strings.ToUpper(post.Message) + return modified, "" +} + +func main() { + plugin.ClientMain(&NPlugin{}) +} +` + + // Non-guard plugin N_reject: rejects all posts. + const srcNReject = ` +package main + +import ( + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" +) + +type NRejectPlugin struct { + plugin.MattermostPlugin +} + +func (p *NRejectPlugin) MessageWillBePosted(c *plugin.Context, post *model.Post) (*model.Post, string) { + return nil, "n-rejected" +} + +func main() { + plugin.ClientMain(&NRejectPlugin{}) +} +` + + // Sub-test (a): N uppercases, G receives the uppercased message. + t.Run("Phase_A_composes_into_Phase_B_input", func(t *testing.T) { + th := Setup(t).InitBasic(t) + + // Temp file for the guard plugin to write the received message. + receivedFile, err := os.CreateTemp("", "guard_received_*.txt") + require.NoError(t, err) + receivedFile.Close() + defer os.Remove(receivedFile.Name()) + + tearDown, pluginIDs, errs := SetAppEnvironmentWithPlugins(t, []string{makeGuardSrc(receivedFile.Name()), srcN}, th.App, th.NewPluginAPI) + defer tearDown() + + require.Len(t, errs, 2) + require.NoError(t, errs[0]) + require.NoError(t, errs[1]) + + // Determine which ID belongs to G vs N based on position. + gID := pluginIDs[0] + nID := pluginIDs[1] + _ = nID // N is not registered as a guard. + + require.Nil(t, th.App.RegisterChannelGuard(th.Context, th.BasicChannel.Id, gID)) + + _, _, appErr := th.App.CreatePost(th.Context, &model.Post{ + UserId: th.BasicUser.Id, + ChannelId: th.BasicChannel.Id, + Message: "hello", + }, th.BasicChannel, model.CreatePostFlags{SetOnline: true}) + require.Nil(t, appErr) + + // Read the message that the guard plugin received; it must be uppercased. + received, readErr := os.ReadFile(receivedFile.Name()) + require.NoError(t, readErr) + assert.Equal(t, "HELLO", string(received), "Phase B guard must see Phase A's output (uppercased)") + }) + + // Sub-test (b): N rejects → Phase B (guard) is not invoked. + t.Run("Phase_A_rejection_skips_Phase_B", func(t *testing.T) { + th := Setup(t).InitBasic(t) + + receivedFile, err := os.CreateTemp("", "guard_received_*.txt") + require.NoError(t, err) + receivedFile.Close() + defer os.Remove(receivedFile.Name()) + + tearDown, pluginIDs, errs := SetAppEnvironmentWithPlugins(t, []string{makeGuardSrc(receivedFile.Name()), srcNReject}, th.App, th.NewPluginAPI) + defer tearDown() + + require.Len(t, errs, 2) + require.NoError(t, errs[0]) + require.NoError(t, errs[1]) + + gID := pluginIDs[0] + require.Nil(t, th.App.RegisterChannelGuard(th.Context, th.BasicChannel.Id, gID)) + + _, _, appErr := th.App.CreatePost(th.Context, &model.Post{ + UserId: th.BasicUser.Id, + ChannelId: th.BasicChannel.Id, + Message: "msg", + }, th.BasicChannel, model.CreatePostFlags{SetOnline: true}) + require.NotNil(t, appErr, "N_reject must reject") + assert.Contains(t, appErr.Id, "n-rejected") + + // Guard plugin must NOT have been called (file stays empty). + received, readErr := os.ReadFile(receivedFile.Name()) + require.NoError(t, readErr) + assert.Empty(t, string(received), "Phase B guard must not be invoked when Phase A rejects") + }) +} + +// TestChannelGuardMultiClaimAllMustBeActive installs two guard plugins G1 and G2 on the +// same channel. Both active → CreatePost succeeds. Deactivate either → 503. Re-activate → +// success. The plugin ID is logged server-side (operator attribution) but intentionally +// omitted from the user-facing AppError, so this test only asserts the generic 503 shape. +func TestChannelGuardMultiClaimAllMustBeActive(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + const allowPlugin = ` +package main + +import ( + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" +) + +type AllowPlugin struct { + plugin.MattermostPlugin +} + +func (p *AllowPlugin) MessageWillBePosted(c *plugin.Context, post *model.Post) (*model.Post, string) { + return nil, "" +} + +func main() { + plugin.ClientMain(&AllowPlugin{}) +} +` + + tearDown, pluginIDs, errs := SetAppEnvironmentWithPlugins(t, []string{allowPlugin, allowPlugin}, th.App, th.NewPluginAPI) + defer tearDown() + + require.Len(t, errs, 2) + require.NoError(t, errs[0]) + require.NoError(t, errs[1]) + + g1ID := pluginIDs[0] + g2ID := pluginIDs[1] + + ch := th.BasicChannel + require.Nil(t, th.App.RegisterChannelGuard(th.Context, ch.Id, g1ID)) + require.Nil(t, th.App.RegisterChannelGuard(th.Context, ch.Id, g2ID)) + + // Both active: must succeed. + t.Run("both_active_succeeds", func(t *testing.T) { + _, _, appErr := th.App.CreatePost(th.Context, &model.Post{ + UserId: th.BasicUser.Id, + ChannelId: ch.Id, + Message: "both active", + }, ch, model.CreatePostFlags{SetOnline: true}) + require.Nil(t, appErr) + }) + + // Deactivate G1: must get the generic 503 (plugin ID is in the server log, not the AppError). + t.Run("g1_inactive_returns_503", func(t *testing.T) { + require.True(t, th.App.GetPluginsEnvironment().Deactivate(g1ID)) + + _, _, appErr := th.App.CreatePost(th.Context, &model.Post{ + UserId: th.BasicUser.Id, + ChannelId: ch.Id, + Message: "g1 inactive", + }, ch, model.CreatePostFlags{SetOnline: true}) + require.NotNil(t, appErr) + assert.Equal(t, "app.plugin.inactive_guard.app_error", appErr.Id) + assert.Equal(t, 503, appErr.StatusCode) + + // Re-activate G1. + _, _, activateErr := th.App.GetPluginsEnvironment().Activate(g1ID) + require.NoError(t, activateErr) + }) + + // Deactivate G2: must get 503. + t.Run("g2_inactive_returns_503", func(t *testing.T) { + require.True(t, th.App.GetPluginsEnvironment().Deactivate(g2ID)) + + _, _, appErr := th.App.CreatePost(th.Context, &model.Post{ + UserId: th.BasicUser.Id, + ChannelId: ch.Id, + Message: "g2 inactive", + }, ch, model.CreatePostFlags{SetOnline: true}) + require.NotNil(t, appErr) + assert.Equal(t, "app.plugin.inactive_guard.app_error", appErr.Id) + assert.Equal(t, 503, appErr.StatusCode) + + // Re-activate G2. + _, _, activateErr := th.App.GetPluginsEnvironment().Activate(g2ID) + require.NoError(t, activateErr) + }) + + // Both re-activated: must succeed again. + t.Run("both_reactivated_succeeds", func(t *testing.T) { + _, _, appErr := th.App.CreatePost(th.Context, &model.Post{ + UserId: th.BasicUser.Id, + ChannelId: ch.Id, + Message: "both reactivated", + }, ch, model.CreatePostFlags{SetOnline: true}) + require.Nil(t, appErr) + }) +} + +// TestChannelGuardMultiClaimPhaseBSequence verifies Phase B composition and sequencing with two +// guard plugins G1 and G2. Plugin IDs are random UUIDs at test time, so the test does not pin which +// guard sorts first; it asserts properties that hold regardless of order. +// +// a) Both allow: each prepends its tag to the message → final message contains both tags in +// PluginId-sorted-call order, proving Phase B composes left-to-right. +// +// b) Whichever guard runs first rejects → the second guard is NOT invoked (test reads +// either possible counter file and asserts at least one is empty, allowing 0 or 1 +// invocations of the second to satisfy the short-circuit contract). +// +// c) Phase A's RunMultiHookExcluding skips both guards: a third non-guard plugin N runs +// exactly once per CreatePost, while G1/G2's counters do not increment during Phase A. +func TestChannelGuardMultiClaimPhaseBSequence(t *testing.T) { + mainHelper.Parallel(t) + + // Each plugin source is built per-subtest with its counter file path baked + // in as a Go literal. Reading the path from the environment instead would + // require t.Setenv, which panics under t.Parallel. + + // G1: prepends "G1:" to the message; writes its call count to a file. + makeG1PrependSrc := func(countFile string) string { + return fmt.Sprintf(` +package main + +import ( + "fmt" + "os" + "strconv" + "strings" + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" +) + +type G1Plugin struct { + plugin.MattermostPlugin +} + +func (p *G1Plugin) MessageWillBePosted(c *plugin.Context, post *model.Post) (*model.Post, string) { + countFile := %q + count := 0 + if data, err := os.ReadFile(countFile); err == nil { + count, _ = strconv.Atoi(strings.TrimSpace(string(data))) + } + count++ + _ = os.WriteFile(countFile, []byte(fmt.Sprintf("%%d", count)), 0644) + + modified := post.Clone() + modified.Message = "G1:" + post.Message + return modified, "" +} + +func main() { + plugin.ClientMain(&G1Plugin{}) +} +`, countFile) + } + + // G1 that rejects. + makeG1RejectSrc := func(countFile string) string { + return fmt.Sprintf(` +package main + +import ( + "fmt" + "os" + "strconv" + "strings" + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" +) + +type G1RejectPlugin struct { + plugin.MattermostPlugin +} + +func (p *G1RejectPlugin) MessageWillBePosted(c *plugin.Context, post *model.Post) (*model.Post, string) { + countFile := %q + count := 0 + if data, err := os.ReadFile(countFile); err == nil { + count, _ = strconv.Atoi(strings.TrimSpace(string(data))) + } + count++ + _ = os.WriteFile(countFile, []byte(fmt.Sprintf("%%d", count)), 0644) + return nil, "g1-rejected" +} + +func main() { + plugin.ClientMain(&G1RejectPlugin{}) +} +`, countFile) + } + + // G2: prepends "G2:" to the message; writes its call count to a file. + makeG2Src := func(countFile string) string { + return fmt.Sprintf(` +package main + +import ( + "fmt" + "os" + "strconv" + "strings" + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" +) + +type G2Plugin struct { + plugin.MattermostPlugin +} + +func (p *G2Plugin) MessageWillBePosted(c *plugin.Context, post *model.Post) (*model.Post, string) { + countFile := %q + count := 0 + if data, err := os.ReadFile(countFile); err == nil { + count, _ = strconv.Atoi(strings.TrimSpace(string(data))) + } + count++ + _ = os.WriteFile(countFile, []byte(fmt.Sprintf("%%d", count)), 0644) + + modified := post.Clone() + modified.Message = "G2:" + post.Message + return modified, "" +} + +func main() { + plugin.ClientMain(&G2Plugin{}) +} +`, countFile) + } + + // G3: counts in a temp file but never rejects (used as the third guard in phase-b tests). + makeG3Src := func(countFile string) string { + return fmt.Sprintf(` +package main + +import ( + "fmt" + "os" + "strconv" + "strings" + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" +) + +type G3Plugin struct { + plugin.MattermostPlugin +} + +func (p *G3Plugin) MessageWillBePosted(c *plugin.Context, post *model.Post) (*model.Post, string) { + countFile := %q + count := 0 + if data, err := os.ReadFile(countFile); err == nil { + count, _ = strconv.Atoi(strings.TrimSpace(string(data))) + } + count++ + _ = os.WriteFile(countFile, []byte(fmt.Sprintf("%%d", count)), 0644) + return nil, "" +} + +func main() { + plugin.ClientMain(&G3Plugin{}) +} +`, countFile) + } + + // Non-guard plugin N: writes its call count to a file. + makeNSrc := func(countFile string) string { + return fmt.Sprintf(` +package main + +import ( + "fmt" + "os" + "strconv" + "strings" + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" +) + +type NPlugin struct { + plugin.MattermostPlugin +} + +func (p *NPlugin) MessageWillBePosted(c *plugin.Context, post *model.Post) (*model.Post, string) { + countFile := %q + count := 0 + if data, err := os.ReadFile(countFile); err == nil { + count, _ = strconv.Atoi(strings.TrimSpace(string(data))) + } + count++ + _ = os.WriteFile(countFile, []byte(fmt.Sprintf("%%d", count)), 0644) + return nil, "" +} + +func main() { + plugin.ClientMain(&NPlugin{}) +} +`, countFile) + } + + // Helper to read a counter file. + readCount := func(t *testing.T, path string) int { + t.Helper() + data, err := os.ReadFile(path) + require.NoError(t, err) + s := strings.TrimSpace(string(data)) + if s == "" { + return 0 + } + n, err := strconv.Atoi(s) + require.NoError(t, err) + return n + } + + // Sub-test (a): both allow, modifications compose left-to-right. + // G1 prepends "G1:", G2 prepends "G2:" → "G2:G1:". + // Phase B order is determined by PluginId alphabetical order (resolveGuards sorts). + t.Run("composition_left_to_right", func(t *testing.T) { + th := Setup(t).InitBasic(t) + + g1CountFile, _ := os.CreateTemp("", "g1_count_*.txt") + g1CountFile.Close() + defer os.Remove(g1CountFile.Name()) + g2CountFile, _ := os.CreateTemp("", "g2_count_*.txt") + g2CountFile.Close() + defer os.Remove(g2CountFile.Name()) + + tearDown, pluginIDs, errs := SetAppEnvironmentWithPlugins(t, []string{makeG1PrependSrc(g1CountFile.Name()), makeG2Src(g2CountFile.Name())}, th.App, th.NewPluginAPI) + defer tearDown() + require.Len(t, errs, 2) + require.NoError(t, errs[0]) + require.NoError(t, errs[1]) + + // pluginIDs[0] → G1Prepend (prepends "G1:"), pluginIDs[1] → G2 (prepends "G2:"). + // resolveGuards fires Phase B in PluginId alphabetical order. Walk the sorted IDs to + // predict the expected final message and assert exact equality. + id0, id1 := pluginIDs[0], pluginIDs[1] + require.Nil(t, th.App.RegisterChannelGuard(th.Context, th.BasicChannel.Id, id0)) + require.Nil(t, th.App.RegisterChannelGuard(th.Context, th.BasicChannel.Id, id1)) + + sortedIDs := []string{id0, id1} + sort.Strings(sortedIDs) + // Each plugin prepends its tag to whatever message it receives. Walking in + // sorted order: the first plugin sees "original" and produces "G?:original"; + // the second plugin sees that and prepends its own tag. Build the expected + // result by walking backwards through the sorted list (each plugin wraps the prior). + pluginTag := map[string]string{id0: "G1:", id1: "G2:"} + expected := "original" + for _, id := range sortedIDs { + expected = pluginTag[id] + expected + } + + created, _, appErr := th.App.CreatePost(th.Context, &model.Post{ + UserId: th.BasicUser.Id, + ChannelId: th.BasicChannel.Id, + Message: "original", + }, th.BasicChannel, model.CreatePostFlags{SetOnline: true}) + require.Nil(t, appErr) + require.NotNil(t, created) + + // Exact equality: confirms both that both plugins ran AND that they ran in + // PluginId-sorted order. Contains would accept the wrong order. + require.Equal(t, expected, created.Message) + }) + + // Sub-test (b): a guard's rejection propagates and stops Phase B iteration. + // + // Three guard plugins are used so that the rejecter can be in the middle of the + // sorted order (two plugins cannot detect a missing short-circuit: the loop ends + // naturally after two iterations regardless). The rejecter is G1Reject (pluginIDs[0]); + // G2 (pluginIDs[1]) and G3 (pluginIDs[2]) are plain counters. After sorting the + // three plugin IDs, any plugin whose sorted position is after the rejecter MUST have a + // count of 0 (Phase B short-circuited). Any plugin before the rejecter must have count 1. + // The rejecter itself must have count 1. + t.Run("guard_rejection_stops_phase_b", func(t *testing.T) { + th := Setup(t).InitBasic(t) + + g1CountFile, _ := os.CreateTemp("", "g1_count_*.txt") + g1CountFile.Close() + defer os.Remove(g1CountFile.Name()) + g2CountFile, _ := os.CreateTemp("", "g2_count_*.txt") + g2CountFile.Close() + defer os.Remove(g2CountFile.Name()) + g3CountFile, _ := os.CreateTemp("", "g3_count_*.txt") + g3CountFile.Close() + defer os.Remove(g3CountFile.Name()) + + // pluginIDs[0] → G1Reject (rejecter, writes to g1CountFile) + // pluginIDs[1] → G2 (counter, writes to g2CountFile) + // pluginIDs[2] → G3 (counter, writes to g3CountFile) + tearDown, pluginIDs, errs := SetAppEnvironmentWithPlugins(t, []string{makeG1RejectSrc(g1CountFile.Name()), makeG2Src(g2CountFile.Name()), makeG3Src(g3CountFile.Name())}, th.App, th.NewPluginAPI) + defer tearDown() + require.Len(t, errs, 3) + require.NoError(t, errs[0]) + require.NoError(t, errs[1]) + require.NoError(t, errs[2]) + + rejecterID := pluginIDs[0] + for _, id := range pluginIDs { + require.Nil(t, th.App.RegisterChannelGuard(th.Context, th.BasicChannel.Id, id)) + } + + _, _, appErr := th.App.CreatePost(th.Context, &model.Post{ + UserId: th.BasicUser.Id, + ChannelId: th.BasicChannel.Id, + Message: "msg", + }, th.BasicChannel, model.CreatePostFlags{SetOnline: true}) + require.NotNil(t, appErr, "rejection from a guard in Phase B must propagate") + assert.Contains(t, appErr.Id, "g1-rejected", "the reject plugin must be the source of the error") + + // Map each plugin ID to its count file so we can check by sorted position. + countFile := map[string]string{ + pluginIDs[0]: g1CountFile.Name(), + pluginIDs[1]: g2CountFile.Name(), + pluginIDs[2]: g3CountFile.Name(), + } + sortedIDs := []string{pluginIDs[0], pluginIDs[1], pluginIDs[2]} + sort.Strings(sortedIDs) + + // Find rejecter's index in the sorted order. + rejecterIdx := -1 + for i, id := range sortedIDs { + if id == rejecterID { + rejecterIdx = i + break + } + } + require.NotEqual(t, -1, rejecterIdx) + + // Rejecter must have run exactly once. + rejecterCount := readCount(t, countFile[rejecterID]) + assert.Equal(t, 1, rejecterCount, "rejecter plugin must have been invoked exactly once") + + // Plugins sorted before the rejecter: each must have run exactly once. + for _, id := range sortedIDs[:rejecterIdx] { + c := readCount(t, countFile[id]) + assert.Equal(t, 1, c, "plugin sorted before rejecter must have run once") + } + + // Plugins sorted after the rejecter: Phase B must have short-circuited; count must be 0. + for _, id := range sortedIDs[rejecterIdx+1:] { + c := readCount(t, countFile[id]) + assert.Equal(t, 0, c, "plugin sorted after rejecter must not have been invoked (short-circuit)") + } + }) + + // Sub-test (c): Phase A's RunMultiHookExcluding skips guards; non-guard N runs once. + t.Run("phase_a_excludes_guards", func(t *testing.T) { + th := Setup(t).InitBasic(t) + + g1CountFile, _ := os.CreateTemp("", "g1_count_*.txt") + g1CountFile.Close() + defer os.Remove(g1CountFile.Name()) + g2CountFile, _ := os.CreateTemp("", "g2_count_*.txt") + g2CountFile.Close() + defer os.Remove(g2CountFile.Name()) + nCountFile, _ := os.CreateTemp("", "n_count_*.txt") + nCountFile.Close() + defer os.Remove(nCountFile.Name()) + + tearDown, pluginIDs, errs := SetAppEnvironmentWithPlugins(t, []string{makeG1PrependSrc(g1CountFile.Name()), makeG2Src(g2CountFile.Name()), makeNSrc(nCountFile.Name())}, th.App, th.NewPluginAPI) + defer tearDown() + require.Len(t, errs, 3) + require.NoError(t, errs[0]) + require.NoError(t, errs[1]) + require.NoError(t, errs[2]) + + g1RegID := pluginIDs[0] + g2RegID := pluginIDs[1] + // pluginIDs[2] is N — not registered as a guard. + + require.Nil(t, th.App.RegisterChannelGuard(th.Context, th.BasicChannel.Id, g1RegID)) + require.Nil(t, th.App.RegisterChannelGuard(th.Context, th.BasicChannel.Id, g2RegID)) + + _, _, appErr := th.App.CreatePost(th.Context, &model.Post{ + UserId: th.BasicUser.Id, + ChannelId: th.BasicChannel.Id, + Message: "phase-a-test", + }, th.BasicChannel, model.CreatePostFlags{SetOnline: true}) + require.Nil(t, appErr) + + // N (non-guard) runs exactly once during Phase A. + nCount := readCount(t, nCountFile.Name()) + assert.Equal(t, 1, nCount, "non-guard plugin N must run once in Phase A") + + // G1 and G2 each run exactly once during Phase B (not in Phase A). + g1Count := readCount(t, g1CountFile.Name()) + g2Count := readCount(t, g2CountFile.Name()) + assert.Equal(t, 1, g1Count, "G1 must run once in Phase B only") + assert.Equal(t, 1, g2Count, "G2 must run once in Phase B only") + }) +} + +// TestChannelGuardNoCheckWhenNoRow confirms that channels with no guard registered +// proceed normally and no guard-related error IDs fire. +func TestChannelGuardNoCheckWhenNoRow(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + // No plugin installed. Channels have no guard rows. + // CreatePost must succeed without any guard-related error. + post := &model.Post{ + UserId: th.BasicUser.Id, + ChannelId: th.BasicChannel.Id, + Message: "no guard test", + } + created, _, appErr := th.App.CreatePost(th.Context, post, th.BasicChannel, model.CreatePostFlags{SetOnline: true}) + require.Nil(t, appErr, "CreatePost on unguarded channel must succeed") + require.NotNil(t, created) + assert.NotEqual(t, "", created.Id, "created post must have an ID") + + // UpdatePost must also succeed. + updated := created.Clone() + updated.Message = "updated no guard" + result, _, appErr2 := th.App.UpdatePost(th.Context, updated, nil) + require.Nil(t, appErr2, "UpdatePost on unguarded channel must succeed") + require.NotNil(t, result) + + // UpdateChannel must succeed. + ch := th.BasicChannel.DeepCopy() + ch.DisplayName = "No Guard Channel Update" + updatedCh, appErr3 := th.App.UpdateChannel(th.Context, ch) + require.Nil(t, appErr3, "UpdateChannel on unguarded channel must succeed") + require.NotNil(t, updatedCh) +} + +// TestChannelGuardFailsClosedWhenPluginsDisabled covers the resolveGuards branch where the +// plugin system is off (PluginSettings.Enable == false) but a guard row still exists for the +// channel. The user-facing AppError shape is the same generic 503 used for inactive guards +// (the distinguishing operator-facing error_id lives in the server log via +// logAndErrPluginsDisabled), so this test verifies fail-closed enforcement, not log content. +func TestChannelGuardFailsClosedWhenPluginsDisabled(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + tearDown, pluginIDs, errs := SetAppEnvironmentWithPlugins(t, []string{ + ` + package main + + import ( + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" + ) + + type MyPlugin struct { + plugin.MattermostPlugin + } + + func (p *MyPlugin) MessageWillBePosted(c *plugin.Context, post *model.Post) (*model.Post, string) { + return nil, "" + } + + func main() { + plugin.ClientMain(&MyPlugin{}) + } + `, + }, th.App, th.NewPluginAPI) + defer tearDown() + + require.Len(t, errs, 1) + require.NoError(t, errs[0]) + pluginID := pluginIDs[0] + + require.Nil(t, th.App.RegisterChannelGuard(th.Context, th.BasicChannel.Id, pluginID)) + + // Disable the plugin system globally. resolveGuards now sees env == nil while + // guards remain in the cache, taking the logAndErrPluginsDisabled branch. + th.App.UpdateConfig(func(cfg *model.Config) { + *cfg.PluginSettings.Enable = false + }) + + _, _, appErr := th.App.CreatePost(th.Context, &model.Post{ + UserId: th.BasicUser.Id, + ChannelId: th.BasicChannel.Id, + Message: "plugins disabled", + }, th.BasicChannel, model.CreatePostFlags{SetOnline: true}) + require.NotNil(t, appErr, "guarded channel must fail-closed when plugin system is disabled") + assert.Equal(t, "app.plugin.inactive_guard.app_error", appErr.Id) + assert.Equal(t, 503, appErr.StatusCode) +} + +// TestChannelGuardAllowByDefaultForUnimplementedHook covers the contract documented in +// guarded_hooks.go: a plugin may register a channel guard without implementing every +// guarded hook. When Phase B reaches such a claimant, the *WithRPCErr companion's +// g.implemented[] gate skips the RPC entirely and returns zero values with a nil +// error — which the helper treats as "no opinion" rather than rejection. The op succeeds. +func TestChannelGuardAllowByDefaultForUnimplementedHook(t *testing.T) { + mainHelper.Parallel(t) + + // partialPlugin implements ChannelMemberWillBeAdded only; all other guarded-hook + // companions return "not implemented" (zero values, nil error), which the helpers + // treat as allow-by-default. + const partialPlugin = ` +package main + +import ( + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" +) + +type PartialPlugin struct { + plugin.MattermostPlugin +} + +func (p *PartialPlugin) ChannelMemberWillBeAdded(c *plugin.Context, member *model.ChannelMember) (*model.ChannelMember, string) { + return nil, "" +} + +func main() { + plugin.ClientMain(&PartialPlugin{}) +} +` + + th := Setup(t).InitBasic(t) + + tearDown, pluginIDs, errs := SetAppEnvironmentWithPlugins(t, []string{partialPlugin}, th.App, th.NewPluginAPI) + defer tearDown() + + require.Len(t, errs, 1) + require.NoError(t, errs[0]) + pluginID := pluginIDs[0] + + require.Nil(t, th.App.RegisterChannelGuard(th.Context, th.BasicChannel.Id, pluginID)) + + // CreatePost: plugin does not implement MessageWillBePosted → allow-by-default. + created, _, appErr := th.App.CreatePost(th.Context, &model.Post{ + UserId: th.BasicUser.Id, + ChannelId: th.BasicChannel.Id, + Message: "allow by default", + }, th.BasicChannel, model.CreatePostFlags{SetOnline: true}) + require.Nil(t, appErr, "CreatePost must succeed when guard plugin doesn't implement MessageWillBePosted") + require.NotNil(t, created) + + // UpdatePost: plugin does not implement MessageWillBeUpdated → allow-by-default. + updated := created.Clone() + updated.Message = "allow by default updated" + result, _, appErr2 := th.App.UpdatePost(th.Context, updated, nil) + require.Nil(t, appErr2, "UpdatePost must succeed when guard plugin doesn't implement MessageWillBeUpdated") + require.NotNil(t, result) + + // UpdateChannel: plugin does not implement ChannelWillBeUpdated → allow-by-default. + chCopy := th.BasicChannel.DeepCopy() + chCopy.DisplayName = "Allow by Default Update" + updatedCh, appErr3 := th.App.UpdateChannel(th.Context, chCopy) + require.Nil(t, appErr3, "UpdateChannel must succeed when guard plugin doesn't implement ChannelWillBeUpdated") + require.NotNil(t, updatedCh) + + // RestoreChannel: plugin does not implement ChannelWillBeRestored → allow-by-default. + t.Run("RestoreChannel", func(t *testing.T) { + th2 := Setup(t).InitBasic(t) + tearDown2, pluginIDs2, errs2 := SetAppEnvironmentWithPlugins(t, []string{partialPlugin}, th2.App, th2.NewPluginAPI) + defer tearDown2() + require.Len(t, errs2, 1) + require.NoError(t, errs2[0]) + pluginID2 := pluginIDs2[0] + + restoreCh := th2.CreateChannel(t, th2.BasicTeam) + require.Nil(t, th2.App.RegisterChannelGuard(th2.Context, restoreCh.Id, pluginID2)) + require.Nil(t, th2.App.DeleteChannel(th2.Context, restoreCh, th2.BasicUser.Id)) + + archived, err := th2.App.GetChannel(th2.Context, restoreCh.Id) + require.Nil(t, err) + _, appErr := th2.App.RestoreChannel(th2.Context, archived, th2.BasicUser.Id) + require.Nil(t, appErr, "RestoreChannel must succeed when guard plugin doesn't implement ChannelWillBeRestored") + }) +} + +// TestChannelGuardRejectsTypeMutationFromPhaseAPlugin covers the type-mutation guard at +// guarded_hooks.go line ~339: when one or more guards exist for a channel, a non-guard +// (Phase A) plugin that mutates Channel.Type must be rejected. This is the Phase A branch +// of the type-mutation check, distinct from the Phase B branch covered by +// TestChannelGuardRejectsTypeMutationFromPlugin. +func TestChannelGuardRejectsTypeMutationFromPhaseAPlugin(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic(t) + + // Plugin G: passive guard (allows everything). Phase B has nothing to do. + const guardPlugin = ` +package main + +import ( + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" +) + +type GuardPlugin struct { + plugin.MattermostPlugin +} + +func (p *GuardPlugin) ChannelWillBeUpdated(c *plugin.Context, newChannel *model.Channel, oldChannel *model.Channel) (*model.Channel, string) { + return nil, "" +} + +func main() { + plugin.ClientMain(&GuardPlugin{}) +} +` + + // Plugin N: non-guard plugin that mutates Channel.Type in ChannelWillBeUpdated. + // On a guarded channel, this must be rejected with the type-mutation AppError. + const mutatorPlugin = ` +package main + +import ( + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/model" +) + +type MutatorPlugin struct { + plugin.MattermostPlugin +} + +func (p *MutatorPlugin) ChannelWillBeUpdated(c *plugin.Context, newChannel *model.Channel, oldChannel *model.Channel) (*model.Channel, string) { + mutated := newChannel + mutated.Type = model.ChannelTypePrivate + return mutated, "" +} + +func main() { + plugin.ClientMain(&MutatorPlugin{}) +} +` + + tearDown, pluginIDs, errs := SetAppEnvironmentWithPlugins(t, []string{guardPlugin, mutatorPlugin}, th.App, th.NewPluginAPI) + defer tearDown() + + require.Len(t, errs, 2) + require.NoError(t, errs[0]) + require.NoError(t, errs[1]) + guardID := pluginIDs[0] + // Mutator plugin (pluginIDs[1]) is intentionally NOT registered as a guard. + + // Use a public channel so type mutation Public → Private is observable. + require.Equal(t, model.ChannelTypeOpen, th.BasicChannel.Type) + require.Nil(t, th.App.RegisterChannelGuard(th.Context, th.BasicChannel.Id, guardID)) + + chCopy := th.BasicChannel.DeepCopy() + chCopy.DisplayName = "Phase A type mutation" + _, appErr := th.App.UpdateChannel(th.Context, chCopy) + require.NotNil(t, appErr, "Phase A plugin mutating Channel.Type on a guarded channel must be rejected") + assert.Equal(t, "app.channel.update_channel.plugin_type_mutation.app_error", appErr.Id) + assert.Equal(t, 400, appErr.StatusCode) +} diff --git a/server/channels/app/post.go b/server/channels/app/post.go index 71ede54b87ad..2b8e05506f76 100644 --- a/server/channels/app/post.go +++ b/server/channels/app/post.go @@ -331,39 +331,14 @@ func (a *App) CreatePost(rctx request.CTX, post *model.Post, channel *model.Chan } } - var metadata *model.PostMetadata - if post.Metadata != nil { - metadata = post.Metadata.Copy() - } - var rejectionError *model.AppError pluginContext := pluginContext(rctx) if post.Type != model.PostTypeBurnOnRead { - a.ch.RunMultiHook(func(hooks plugin.Hooks, _ *model.Manifest) bool { - replacementPost, rejectionReason := hooks.MessageWillBePosted(pluginContext, post.ForPlugin()) - if rejectionReason != "" { - id := "Post rejected by plugin. " + rejectionReason - if rejectionReason == plugin.DismissPostError { - id = plugin.DismissPostError - } - rejectionError = model.NewAppError("createPost", id, nil, "", http.StatusBadRequest) - return false - } - if replacementPost != nil { - post = replacementPost - if post.Metadata != nil && metadata != nil { - post.Metadata.Priority = metadata.Priority - } else { - post.Metadata = metadata - } - } - - return true - }, plugin.MessageWillBePostedID) - - if rejectionError != nil { - return nil, false, rejectionError + newPost, guardErr := a.runGuardedMessageWillBePosted(rctx, post) + if guardErr != nil { + return nil, false, guardErr } + post = newPost } // Pre-fill the CreateAt field for link previews to get the correct timestamp. @@ -930,15 +905,11 @@ func (a *App) UpdatePost(rctx request.CTX, receivedUpdatedPost *model.Post, upda oldPost.RemoteId = new(*receivedUpdatedPost.RemoteId) } - var rejectionReason string - pluginContext := pluginContext(rctx) if newPost.Type != model.PostTypeBurnOnRead { - a.ch.RunMultiHook(func(hooks plugin.Hooks, _ *model.Manifest) bool { - newPost, rejectionReason = hooks.MessageWillBeUpdated(pluginContext, newPost.ForPlugin(), oldPost.ForPlugin()) - return newPost != nil - }, plugin.MessageWillBeUpdatedID) - if newPost == nil { - return nil, false, model.NewAppError("UpdatePost", "Post rejected by plugin. "+rejectionReason, nil, "", http.StatusBadRequest) + var appErr2 *model.AppError + newPost, appErr2 = a.runGuardedMessageWillBeUpdated(rctx, newPost, oldPost) + if appErr2 != nil { + return nil, false, appErr2 } } @@ -963,12 +934,13 @@ func (a *App) UpdatePost(rctx request.CTX, receivedUpdatedPost *model.Post, upda } } + pCtx := pluginContext(rctx) pluginOldPost := oldPost.ForPlugin() pluginNewPost := newPost.ForPlugin() if newPost.Type != model.PostTypeBurnOnRead { a.Srv().Go(func() { a.ch.RunMultiHook(func(hooks plugin.Hooks, _ *model.Manifest) bool { - hooks.MessageHasBeenUpdated(pluginContext, pluginNewPost, pluginOldPost) + hooks.MessageHasBeenUpdated(pCtx, pluginNewPost, pluginOldPost) return true }, plugin.MessageHasBeenUpdatedID) }) @@ -1011,6 +983,8 @@ func (a *App) UpdatePost(rctx request.CTX, receivedUpdatedPost *model.Post, upda } } + a.applyPostWillBeConsumedHook(&rpost) + message := model.NewWebSocketEvent(model.WebsocketEventPostEdited, "", rpost.ChannelId, "", nil, "") appErr = a.publishWebsocketEventForPost(rctx, rpost, message) diff --git a/server/channels/app/scheduled_post.go b/server/channels/app/scheduled_post.go index 3e23c43780e9..872c6c7229a1 100644 --- a/server/channels/app/scheduled_post.go +++ b/server/channels/app/scheduled_post.go @@ -8,6 +8,7 @@ import ( "net/http" "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin" "github.com/mattermost/mattermost/server/public/shared/mlog" "github.com/mattermost/mattermost/server/public/shared/request" ) @@ -39,6 +40,23 @@ func (a *App) SaveScheduledPost(rctx request.CTX, scheduledPost *model.Scheduled return nil, model.NewAppError("App.scheduledPostPreSaveChecks", "app.save_scheduled_post.channel_deleted.app_error", map[string]any{"user_id": scheduledPost.UserId, "channel_id": scheduledPost.ChannelId}, "", http.StatusBadRequest) } + var rejectionReason string + pluginContext := pluginContext(rctx) + a.ch.RunMultiHook(func(hooks plugin.Hooks, _ *model.Manifest) bool { + replacement, reason := hooks.ScheduledPostWillBeCreated(pluginContext, scheduledPost) + if reason != "" { + rejectionReason = reason + return false + } + if replacement != nil { + scheduledPost = replacement + } + return true + }, plugin.ScheduledPostWillBeCreatedID) + if rejectionReason != "" { + return nil, model.NewAppError("SaveScheduledPost", "app.scheduled_post.save.rejected_by_plugin", map[string]any{"Reason": rejectionReason}, "", http.StatusBadRequest) + } + savedScheduledPost, err := a.Srv().Store().ScheduledPost().CreateScheduledPost(scheduledPost) if err != nil { return nil, model.NewAppError("App.ScheduledPost", "app.save_scheduled_post.save.app_error", map[string]any{"user_id": scheduledPost.UserId, "channel_id": scheduledPost.ChannelId}, "", http.StatusBadRequest).Wrap(err) @@ -86,6 +104,23 @@ func (a *App) UpdateScheduledPost(rctx request.CTX, userId string, scheduledPost // updated scheduled post. It's better to do this before calling update than after. scheduledPost.RestoreNonUpdatableFields(existingScheduledPost) + var rejectionReason string + pluginContext := pluginContext(rctx) + a.ch.RunMultiHook(func(hooks plugin.Hooks, _ *model.Manifest) bool { + replacement, reason := hooks.ScheduledPostWillBeCreated(pluginContext, scheduledPost) + if reason != "" { + rejectionReason = reason + return false + } + if replacement != nil { + scheduledPost = replacement + } + return true + }, plugin.ScheduledPostWillBeCreatedID) + if rejectionReason != "" { + return nil, model.NewAppError("UpdateScheduledPost", "app.scheduled_post.update.rejected_by_plugin", map[string]any{"Reason": rejectionReason}, "", http.StatusBadRequest) + } + if err := a.Srv().Store().ScheduledPost().UpdatedScheduledPost(scheduledPost); err != nil { return nil, model.NewAppError("app.UpdateScheduledPost", "app.update_scheduled_post.update.error", map[string]any{"user_id": userId, "scheduled_post_id": scheduledPost.Id}, "", http.StatusInternalServerError).Wrap(err) } diff --git a/server/channels/app/web_broadcast_hooks.go b/server/channels/app/web_broadcast_hooks.go index 3db2fa7fcbde..f8ba6713a487 100644 --- a/server/channels/app/web_broadcast_hooks.go +++ b/server/channels/app/web_broadcast_hooks.go @@ -25,6 +25,7 @@ const ( broadcastBurnOnRead = "burn_on_read" broadcastBurnOnReadReaction = "burn_on_read_reaction" broadcastAbacFiles = "abac_files" + broadcastOnlyChannelAdmins = "only_channel_admins" ) func (s *Server) makeBroadcastHooks() map[string]platform.BroadcastHook { @@ -37,6 +38,7 @@ func (s *Server) makeBroadcastHooks() map[string]platform.BroadcastHook { broadcastBurnOnRead: &burnOnReadBroadcastHook{}, broadcastBurnOnReadReaction: &burnOnReadReactionBroadcastHook{}, broadcastAbacFiles: &abacFilesBroadcastHook{}, + broadcastOnlyChannelAdmins: &onlyChannelAdminsBroadcastHook{}, } } @@ -505,6 +507,34 @@ func (h *abacFilesBroadcastHook) stripFilesFromMessage(msg *platform.HookedWebSo return nil } +// onlyChannelAdminsBroadcastHook narrows a channel-scoped broadcast to the +// channel-admin subset of the channel's members. The hook arg +// `channel_admin_user_ids` is the precomputed list of admin user ids at publish +// time; recipients not in that set have the event rejected. +// +// Pair with `Broadcast{ChannelId: channelId}` so the platform's existing +// channel-member fan-out is the outer bound and this hook simply filters +// non-admin members out. +type onlyChannelAdminsBroadcastHook struct{} + +func useOnlyChannelAdminsHook(message *model.WebSocketEvent, channelAdminUserIds []string) { + message.GetBroadcast().AddHook(broadcastOnlyChannelAdmins, map[string]any{ + "channel_admin_user_ids": model.StringArray(channelAdminUserIds), + }) +} + +func (h *onlyChannelAdminsBroadcastHook) Process(msg *platform.HookedWebSocketEvent, webConn *platform.WebConn, args map[string]any) error { + adminUserIDs, err := getTypedArg[model.StringArray](args, "channel_admin_user_ids") + if err != nil { + return errors.Wrap(err, "Invalid channel_admin_user_ids value passed to onlyChannelAdminsBroadcastHook") + } + + if !slices.Contains(adminUserIDs, webConn.UserId) { + msg.Event().Reject() + } + return nil +} + func incrementWebsocketCounter(wc *platform.WebConn) { if wc.Platform.Metrics() == nil { return diff --git a/server/channels/db/migrations/migrations.list b/server/channels/db/migrations/migrations.list index 73a7ed10fd13..d85e7155fcd7 100644 --- a/server/channels/db/migrations/migrations.list +++ b/server/channels/db/migrations/migrations.list @@ -365,3 +365,7 @@ channels/db/migrations/postgres/000183_create_channel_join_requests_user_status_ channels/db/migrations/postgres/000183_create_channel_join_requests_user_status_index.up.sql channels/db/migrations/postgres/000184_add_lastused_to_incoming_webhooks.down.sql channels/db/migrations/postgres/000184_add_lastused_to_incoming_webhooks.up.sql +channels/db/migrations/postgres/000185_create_channel_guards.down.sql +channels/db/migrations/postgres/000185_create_channel_guards.up.sql +channels/db/migrations/postgres/000186_create_channel_guards_plugin_id_index.down.sql +channels/db/migrations/postgres/000186_create_channel_guards_plugin_id_index.up.sql diff --git a/server/channels/db/migrations/postgres/000185_create_channel_guards.down.sql b/server/channels/db/migrations/postgres/000185_create_channel_guards.down.sql new file mode 100644 index 000000000000..0a327247b93a --- /dev/null +++ b/server/channels/db/migrations/postgres/000185_create_channel_guards.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS ChannelGuards; diff --git a/server/channels/db/migrations/postgres/000185_create_channel_guards.up.sql b/server/channels/db/migrations/postgres/000185_create_channel_guards.up.sql new file mode 100644 index 000000000000..141f1d3f5bc3 --- /dev/null +++ b/server/channels/db/migrations/postgres/000185_create_channel_guards.up.sql @@ -0,0 +1,6 @@ +CREATE TABLE IF NOT EXISTS ChannelGuards ( + ChannelId varchar(26) NOT NULL, + PluginId varchar(190) NOT NULL, + CreatedAt bigint NOT NULL, + PRIMARY KEY (ChannelId, PluginId) +); diff --git a/server/channels/db/migrations/postgres/000186_create_channel_guards_plugin_id_index.down.sql b/server/channels/db/migrations/postgres/000186_create_channel_guards_plugin_id_index.down.sql new file mode 100644 index 000000000000..e523ed0283ae --- /dev/null +++ b/server/channels/db/migrations/postgres/000186_create_channel_guards_plugin_id_index.down.sql @@ -0,0 +1,2 @@ +-- morph:nontransactional +DROP INDEX CONCURRENTLY IF EXISTS idx_channel_guards_plugin_id; diff --git a/server/channels/db/migrations/postgres/000186_create_channel_guards_plugin_id_index.up.sql b/server/channels/db/migrations/postgres/000186_create_channel_guards_plugin_id_index.up.sql new file mode 100644 index 000000000000..2b196e9e1ebf --- /dev/null +++ b/server/channels/db/migrations/postgres/000186_create_channel_guards_plugin_id_index.up.sql @@ -0,0 +1,2 @@ +-- morph:nontransactional +CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_channel_guards_plugin_id ON ChannelGuards(PluginId); diff --git a/server/channels/store/layer_generators/main.go b/server/channels/store/layer_generators/main.go index 9c244cef93e9..796b4152891b 100644 --- a/server/channels/store/layer_generators/main.go +++ b/server/channels/store/layer_generators/main.go @@ -184,6 +184,8 @@ func generateLayer(name, templateFile string) ([]byte, error) { switch result { case "*PostReminderMetadata": returns = append(returns, fmt.Sprintf("*store.%s", strings.TrimPrefix(result, "*"))) + case "[]*ChannelGuard": + returns = append(returns, fmt.Sprintf("[]*store.%s", strings.TrimPrefix(result, "[]*"))) default: returns = append(returns, result) } @@ -243,7 +245,7 @@ func generateLayer(name, templateFile string) ([]byte, error) { switch param.Type { case "ChannelSearchOpts", "UserGetByIdsOpts", "ThreadMembershipOpts", "GetPolicyOptions": paramsWithType = append(paramsWithType, fmt.Sprintf("%s store.%s", param.Name, param.Type)) - case "*UserGetByIdsOpts", "*SidebarCategorySearchOpts": + case "*UserGetByIdsOpts", "*SidebarCategorySearchOpts", "*ChannelGuard": paramsWithType = append(paramsWithType, fmt.Sprintf("%s *store.%s", param.Name, strings.TrimPrefix(param.Type, "*"))) default: paramsWithType = append(paramsWithType, fmt.Sprintf("%s %s", param.Name, param.Type)) @@ -257,7 +259,7 @@ func generateLayer(name, templateFile string) ([]byte, error) { switch param.Type { case "ChannelSearchOpts", "UserGetByIdsOpts", "ThreadMembershipOpts", "GetPolicyOptions": paramsWithType = append(paramsWithType, fmt.Sprintf("%s store.%s", param.Name, param.Type)) - case "*UserGetByIdsOpts", "*SidebarCategorySearchOpts": + case "*UserGetByIdsOpts", "*SidebarCategorySearchOpts", "*ChannelGuard": paramsWithType = append(paramsWithType, fmt.Sprintf("%s *store.%s", param.Name, strings.TrimPrefix(param.Type, "*"))) default: paramsWithType = append(paramsWithType, fmt.Sprintf("%s %s", param.Name, param.Type)) diff --git a/server/channels/store/retrylayer/retrylayer.go b/server/channels/store/retrylayer/retrylayer.go index d097e33f9300..3267d96082cd 100644 --- a/server/channels/store/retrylayer/retrylayer.go +++ b/server/channels/store/retrylayer/retrylayer.go @@ -27,6 +27,7 @@ type RetryLayer struct { BotStore store.BotStore ChannelStore store.ChannelStore ChannelBookmarkStore store.ChannelBookmarkStore + ChannelGuardStore store.ChannelGuardStore ChannelJoinRequestStore store.ChannelJoinRequestStore ChannelMemberHistoryStore store.ChannelMemberHistoryStore ClusterDiscoveryStore store.ClusterDiscoveryStore @@ -108,6 +109,10 @@ func (s *RetryLayer) ChannelBookmark() store.ChannelBookmarkStore { return s.ChannelBookmarkStore } +func (s *RetryLayer) ChannelGuard() store.ChannelGuardStore { + return s.ChannelGuardStore +} + func (s *RetryLayer) ChannelJoinRequest() store.ChannelJoinRequestStore { return s.ChannelJoinRequestStore } @@ -347,6 +352,11 @@ type RetryLayerChannelBookmarkStore struct { Root *RetryLayer } +type RetryLayerChannelGuardStore struct { + store.ChannelGuardStore + Root *RetryLayer +} + type RetryLayerChannelJoinRequestStore struct { store.ChannelJoinRequestStore Root *RetryLayer @@ -3952,6 +3962,90 @@ func (s *RetryLayerChannelBookmarkStore) UpdateSortOrder(bookmarkID string, chan } +func (s *RetryLayerChannelGuardStore) Delete(rctx request.CTX, channelID string, pluginID string) (int64, error) { + + tries := 0 + for { + result, err := s.ChannelGuardStore.Delete(rctx, channelID, pluginID) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + timepkg.Sleep(100 * timepkg.Millisecond) + } + +} + +func (s *RetryLayerChannelGuardStore) GetAll(rctx request.CTX) ([]*store.ChannelGuard, error) { + + tries := 0 + for { + result, err := s.ChannelGuardStore.GetAll(rctx) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + timepkg.Sleep(100 * timepkg.Millisecond) + } + +} + +func (s *RetryLayerChannelGuardStore) GetForChannel(rctx request.CTX, channelID string) ([]*store.ChannelGuard, error) { + + tries := 0 + for { + result, err := s.ChannelGuardStore.GetForChannel(rctx, channelID) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + timepkg.Sleep(100 * timepkg.Millisecond) + } + +} + +func (s *RetryLayerChannelGuardStore) Save(rctx request.CTX, guard *store.ChannelGuard) error { + + tries := 0 + for { + err := s.ChannelGuardStore.Save(rctx, guard) + if err == nil { + return nil + } + if !isRepeatableError(err) { + return err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return err + } + timepkg.Sleep(100 * timepkg.Millisecond) + } + +} + func (s *RetryLayerChannelJoinRequestStore) CountPending(channelId string) (int64, error) { tries := 0 @@ -16201,6 +16295,27 @@ func (s *RetryLayerUserStore) DeactivateMagicLinkGuests() ([]string, error) { } +func (s *RetryLayerUserStore) DecrementFailedPasswordAttempts(userID string) error { + + tries := 0 + for { + err := s.UserStore.DecrementFailedPasswordAttempts(userID) + if err == nil { + return nil + } + if !isRepeatableError(err) { + return err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return err + } + timepkg.Sleep(100 * timepkg.Millisecond) + } + +} + func (s *RetryLayerUserStore) DemoteUserToGuest(userID string) (*model.User, error) { tries := 0 @@ -17455,11 +17570,11 @@ func (s *RetryLayerUserStore) StoreMfaUsedTimestamps(userID string, ts []int) er } -func (s *RetryLayerUserStore) Update(rctx request.CTX, user *model.User, allowRoleUpdate bool) (*model.UserUpdate, error) { +func (s *RetryLayerUserStore) TryIncrementFailedPasswordAttempts(userID string, maxAttempts int) (bool, error) { tries := 0 for { - result, err := s.UserStore.Update(rctx, user, allowRoleUpdate) + result, err := s.UserStore.TryIncrementFailedPasswordAttempts(userID, maxAttempts) if err == nil { return result, nil } @@ -17476,11 +17591,11 @@ func (s *RetryLayerUserStore) Update(rctx request.CTX, user *model.User, allowRo } -func (s *RetryLayerUserStore) UpdateAuthData(userID string, service string, authData *string, email string, resetMfa bool) (string, error) { +func (s *RetryLayerUserStore) Update(rctx request.CTX, user *model.User, allowRoleUpdate bool) (*model.UserUpdate, error) { tries := 0 for { - result, err := s.UserStore.UpdateAuthData(userID, service, authData, email, resetMfa) + result, err := s.UserStore.Update(rctx, user, allowRoleUpdate) if err == nil { return result, nil } @@ -17497,32 +17612,11 @@ func (s *RetryLayerUserStore) UpdateAuthData(userID string, service string, auth } -func (s *RetryLayerUserStore) UpdateFailedPasswordAttempts(userID string, attempts int) error { - - tries := 0 - for { - err := s.UserStore.UpdateFailedPasswordAttempts(userID, attempts) - if err == nil { - return nil - } - if !isRepeatableError(err) { - return err - } - tries++ - if tries >= 3 { - err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") - return err - } - timepkg.Sleep(100 * timepkg.Millisecond) - } - -} - -func (s *RetryLayerUserStore) TryIncrementFailedPasswordAttempts(userID string, maxAttempts int) (bool, error) { +func (s *RetryLayerUserStore) UpdateAuthData(userID string, service string, authData *string, email string, resetMfa bool) (string, error) { tries := 0 for { - result, err := s.UserStore.TryIncrementFailedPasswordAttempts(userID, maxAttempts) + result, err := s.UserStore.UpdateAuthData(userID, service, authData, email, resetMfa) if err == nil { return result, nil } @@ -17539,11 +17633,11 @@ func (s *RetryLayerUserStore) TryIncrementFailedPasswordAttempts(userID string, } -func (s *RetryLayerUserStore) DecrementFailedPasswordAttempts(userID string) error { +func (s *RetryLayerUserStore) UpdateFailedPasswordAttempts(userID string, attempts int) error { tries := 0 for { - err := s.UserStore.DecrementFailedPasswordAttempts(userID) + err := s.UserStore.UpdateFailedPasswordAttempts(userID, attempts) if err == nil { return nil } @@ -18733,6 +18827,7 @@ func New(childStore store.Store) *RetryLayer { newStore.BotStore = &RetryLayerBotStore{BotStore: childStore.Bot(), Root: &newStore} newStore.ChannelStore = &RetryLayerChannelStore{ChannelStore: childStore.Channel(), Root: &newStore} newStore.ChannelBookmarkStore = &RetryLayerChannelBookmarkStore{ChannelBookmarkStore: childStore.ChannelBookmark(), Root: &newStore} + newStore.ChannelGuardStore = &RetryLayerChannelGuardStore{ChannelGuardStore: childStore.ChannelGuard(), Root: &newStore} newStore.ChannelJoinRequestStore = &RetryLayerChannelJoinRequestStore{ChannelJoinRequestStore: childStore.ChannelJoinRequest(), Root: &newStore} newStore.ChannelMemberHistoryStore = &RetryLayerChannelMemberHistoryStore{ChannelMemberHistoryStore: childStore.ChannelMemberHistory(), Root: &newStore} newStore.ClusterDiscoveryStore = &RetryLayerClusterDiscoveryStore{ClusterDiscoveryStore: childStore.ClusterDiscovery(), Root: &newStore} diff --git a/server/channels/store/retrylayer/retrylayer_test.go b/server/channels/store/retrylayer/retrylayer_test.go index 7cb965e53b38..0010f1d3a2b0 100644 --- a/server/channels/store/retrylayer/retrylayer_test.go +++ b/server/channels/store/retrylayer/retrylayer_test.go @@ -19,6 +19,7 @@ func genStore() *mocks.Store { mock.On("Audit").Return(&mocks.AuditStore{}) mock.On("Bot").Return(&mocks.BotStore{}) mock.On("Channel").Return(&mocks.ChannelStore{}) + mock.On("ChannelGuard").Return(&mocks.ChannelGuardStore{}) mock.On("ChannelMemberHistory").Return(&mocks.ChannelMemberHistoryStore{}) mock.On("ChannelBookmark").Return(&mocks.ChannelBookmarkStore{}) mock.On("ClusterDiscovery").Return(&mocks.ClusterDiscoveryStore{}) diff --git a/server/channels/store/sqlstore/channel_guard_store.go b/server/channels/store/sqlstore/channel_guard_store.go new file mode 100644 index 000000000000..abbc9fbdea55 --- /dev/null +++ b/server/channels/store/sqlstore/channel_guard_store.go @@ -0,0 +1,83 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package sqlstore + +import ( + sq "github.com/mattermost/squirrel" + "github.com/pkg/errors" + + "github.com/mattermost/mattermost/server/public/shared/request" + "github.com/mattermost/mattermost/server/v8/channels/store" +) + +type SqlChannelGuardStore struct { + *SqlStore + + channelGuardSelectQuery sq.SelectBuilder +} + +func newSqlChannelGuardStore(sqlStore *SqlStore) store.ChannelGuardStore { + s := &SqlChannelGuardStore{SqlStore: sqlStore} + + s.channelGuardSelectQuery = s.getQueryBuilder(). + Select("ChannelId", "PluginId", "CreatedAt"). + From("ChannelGuards") + + return s +} + +func (s *SqlChannelGuardStore) Save(rctx request.CTX, guard *store.ChannelGuard) error { + builder := s.getQueryBuilder(). + Insert("ChannelGuards"). + Columns("ChannelId", "PluginId", "CreatedAt"). + Values(guard.ChannelId, guard.PluginId, guard.CreatedAt). + SuffixExpr(sq.Expr("ON CONFLICT (ChannelId, PluginId) DO NOTHING")) + + if _, err := s.GetMaster().ExecBuilder(builder); err != nil { + return errors.Wrapf(err, "failed to save channel guard for channel=%s plugin=%s", guard.ChannelId, guard.PluginId) + } + + return nil +} + +func (s *SqlChannelGuardStore) Delete(rctx request.CTX, channelID, pluginID string) (int64, error) { + builder := s.getQueryBuilder(). + Delete("ChannelGuards"). + Where(sq.Eq{ + "ChannelId": channelID, + "PluginId": pluginID, + }) + + result, err := s.GetMaster().ExecBuilder(builder) + if err != nil { + return 0, errors.Wrapf(err, "failed to delete channel guard for channel=%s plugin=%s", channelID, pluginID) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return 0, errors.Wrapf(err, "failed to get rows affected for channel guard delete channel=%s plugin=%s", channelID, pluginID) + } + + return rowsAffected, nil +} + +func (s *SqlChannelGuardStore) GetForChannel(rctx request.CTX, channelID string) ([]*store.ChannelGuard, error) { + query := s.channelGuardSelectQuery.Where(sq.Eq{"ChannelId": channelID}) + + guards := []*store.ChannelGuard{} + if err := s.DBXFromContext(rctx.Context()).SelectBuilder(&guards, query); err != nil { + return nil, errors.Wrapf(err, "failed to get channel guards for channel=%s", channelID) + } + + return guards, nil +} + +func (s *SqlChannelGuardStore) GetAll(rctx request.CTX) ([]*store.ChannelGuard, error) { + guards := []*store.ChannelGuard{} + if err := s.DBXFromContext(rctx.Context()).SelectBuilder(&guards, s.channelGuardSelectQuery); err != nil { + return nil, errors.Wrap(err, "failed to get all channel guards") + } + + return guards, nil +} diff --git a/server/channels/store/sqlstore/channel_guard_store_test.go b/server/channels/store/sqlstore/channel_guard_store_test.go new file mode 100644 index 000000000000..253771560501 --- /dev/null +++ b/server/channels/store/sqlstore/channel_guard_store_test.go @@ -0,0 +1,14 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package sqlstore + +import ( + "testing" + + "github.com/mattermost/mattermost/server/v8/channels/store/storetest" +) + +func TestChannelGuardStore(t *testing.T) { + StoreTest(t, storetest.TestChannelGuardStore) +} diff --git a/server/channels/store/sqlstore/channel_store.go b/server/channels/store/sqlstore/channel_store.go index d98f215dacbe..9b4cbf7805c0 100644 --- a/server/channels/store/sqlstore/channel_store.go +++ b/server/channels/store/sqlstore/channel_store.go @@ -3260,6 +3260,9 @@ func (s SqlChannelStore) Autocomplete(rctx request.CTX, userID, term string, inc From("ChannelMembers"). Where(sq.Eq{"UserId": userID}))) } else { + // Non-guests see public channels, private channels they're a member of, and + // discoverable private channels (subject to a post-query ABAC visibility filter + // applied at the app layer for policy-enforced channels). query = query.Where(sq.Or{ sq.NotEq{"c.Type": model.ChannelTypePrivate}, sq.And{ @@ -3268,6 +3271,10 @@ func (s SqlChannelStore) Autocomplete(rctx request.CTX, userID, term string, inc From("ChannelMembers"). Where(sq.Eq{"UserId": userID})), }, + sq.And{ + sq.Eq{"c.Type": model.ChannelTypePrivate}, + sq.Eq{"c.Discoverable": true}, + }, }) } @@ -3311,12 +3318,19 @@ func (s SqlChannelStore) buildAutocompleteInTeamQuery(teamID, userID, term strin if isGuest { query = query.Where(sq.Expr("c.Id IN (?)", memberSubQuery)) } else { + // Non-guests see public channels, private channels they're a member of, and + // discoverable private channels (subject to a post-query ABAC visibility filter + // applied at the app layer for policy-enforced channels). query = query.Where(sq.Or{ sq.NotEq{"c.Type": model.ChannelTypePrivate}, sq.And{ sq.Eq{"c.Type": model.ChannelTypePrivate}, sq.Expr("c.Id IN (?)", memberSubQuery), }, + sq.And{ + sq.Eq{"c.Type": model.ChannelTypePrivate}, + sq.Eq{"c.Discoverable": true}, + }, }) } diff --git a/server/channels/store/sqlstore/migration_000172_test.go b/server/channels/store/sqlstore/migration_000185_test.go similarity index 98% rename from server/channels/store/sqlstore/migration_000172_test.go rename to server/channels/store/sqlstore/migration_000185_test.go index 7dba1969f714..bb6a37fa9904 100644 --- a/server/channels/store/sqlstore/migration_000172_test.go +++ b/server/channels/store/sqlstore/migration_000185_test.go @@ -23,7 +23,7 @@ func readMigrationSQL(t *testing.T, filename string) string { return string(data) } -func TestMigration000172(t *testing.T) { +func TestMigration000185(t *testing.T) { logger := mlog.CreateTestLogger(t) settings, err := makeSqlSettings(model.DatabaseDriverPostgres) @@ -206,7 +206,7 @@ func TestMigration000172(t *testing.T) { assert.Equal(t, groupID, val.GroupID, "value GroupID should remain unchanged after down") } -func TestMigration000172DownPreservesNonUserFields(t *testing.T) { +func TestMigration000185DownPreservesNonUserFields(t *testing.T) { logger := mlog.CreateTestLogger(t) settings, err := makeSqlSettings(model.DatabaseDriverPostgres) @@ -299,7 +299,7 @@ func TestMigration000172DownPreservesNonUserFields(t *testing.T) { assert.Equal(t, "sysadmin", channelField.PermissionOptions.String) } -func TestMigration000172NoOpOnFreshDB(t *testing.T) { +func TestMigration000185NoOpOnFreshDB(t *testing.T) { logger := mlog.CreateTestLogger(t) settings, err := makeSqlSettings(model.DatabaseDriverPostgres) diff --git a/server/channels/store/sqlstore/store.go b/server/channels/store/sqlstore/store.go index 571435c5c115..f252c26ff93a 100644 --- a/server/channels/store/sqlstore/store.go +++ b/server/channels/store/sqlstore/store.go @@ -105,6 +105,7 @@ type SqlStoreStores struct { postPersistentNotification store.PostPersistentNotificationStore desktopTokens store.DesktopTokensStore channelBookmarks store.ChannelBookmarkStore + channelGuard store.ChannelGuardStore scheduledPost store.ScheduledPostStore view store.ViewStore propertyGroup store.PropertyGroupStore @@ -292,6 +293,7 @@ func New(settings model.SqlSettings, logger mlog.LoggerIFace, metrics einterface store.stores.postPersistentNotification = newSqlPostPersistentNotificationStore(store) store.stores.desktopTokens = newSqlDesktopTokensStore(store, metrics) store.stores.channelBookmarks = newSqlChannelBookmarkStore(store) + store.stores.channelGuard = newSqlChannelGuardStore(store) store.stores.scheduledPost = newScheduledPostStore(store) store.stores.view = newSqlViewStore(store) store.stores.propertyGroup = newPropertyGroupStore(store) @@ -913,6 +915,10 @@ func (ss *SqlStore) ChannelBookmark() store.ChannelBookmarkStore { return ss.stores.channelBookmarks } +func (ss *SqlStore) ChannelGuard() store.ChannelGuardStore { + return ss.stores.channelGuard +} + func (ss *SqlStore) View() store.ViewStore { return ss.stores.view } diff --git a/server/channels/store/store.go b/server/channels/store/store.go index b6798551d78e..816d69e330c8 100644 --- a/server/channels/store/store.go +++ b/server/channels/store/store.go @@ -62,6 +62,7 @@ type Store interface { LinkMetadata() LinkMetadataStore SharedChannel() SharedChannelStore Draft() DraftStore + ChannelGuard() ChannelGuardStore MarkSystemRanUnitTests() Close() LockToMaster() @@ -1076,6 +1077,21 @@ type PostPriorityStore interface { Delete(postID string) error } +// ChannelGuard is a single claim row asserting that a plugin has registered as a guard for a given +// channel. Plugins may co-claim a channel; one row per (ChannelId, PluginId) pair. +type ChannelGuard struct { + ChannelId string + PluginId string + CreatedAt int64 +} + +type ChannelGuardStore interface { + Save(rctx request.CTX, guard *ChannelGuard) error + Delete(rctx request.CTX, channelID, pluginID string) (rowsAffected int64, err error) + GetForChannel(rctx request.CTX, channelID string) ([]*ChannelGuard, error) + GetAll(rctx request.CTX) ([]*ChannelGuard, error) +} + type DraftStore interface { Upsert(d *model.Draft) (*model.Draft, error) Get(userID, channelID, rootID string, includeDeleted bool) (*model.Draft, error) diff --git a/server/channels/store/storetest/channel_guard_store.go b/server/channels/store/storetest/channel_guard_store.go new file mode 100644 index 000000000000..5b961eb4da06 --- /dev/null +++ b/server/channels/store/storetest/channel_guard_store.go @@ -0,0 +1,141 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package storetest + +import ( + "testing" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/shared/request" + "github.com/mattermost/mattermost/server/v8/channels/store" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestChannelGuardStore(t *testing.T, rctx request.CTX, ss store.Store) { + t.Run("SaveAndGetForChannel", func(t *testing.T) { testChannelGuardSaveAndGetForChannel(t, rctx, ss) }) + t.Run("SaveIdempotentSamePlugin", func(t *testing.T) { testChannelGuardSaveIdempotentSamePlugin(t, rctx, ss) }) + t.Run("SaveTwoPluginsSameChannel", func(t *testing.T) { testChannelGuardSaveTwoPluginsSameChannel(t, rctx, ss) }) + t.Run("Delete", func(t *testing.T) { testChannelGuardDelete(t, rctx, ss) }) + t.Run("DeleteRowsAffected", func(t *testing.T) { testChannelGuardDeleteRowsAffected(t, rctx, ss) }) + t.Run("GetAll", func(t *testing.T) { testChannelGuardGetAll(t, rctx, ss) }) +} + +func testChannelGuardSaveAndGetForChannel(t *testing.T, rctx request.CTX, ss store.Store) { + channelID := model.NewId() + pluginID := "com.example.plugin-a" + + guard := &store.ChannelGuard{ + ChannelId: channelID, + PluginId: pluginID, + CreatedAt: 1000, + } + + err := ss.ChannelGuard().Save(rctx, guard) + require.NoError(t, err) + + got, err := ss.ChannelGuard().GetForChannel(rctx, channelID) + require.NoError(t, err) + require.Len(t, got, 1) + assert.Equal(t, channelID, got[0].ChannelId) + assert.Equal(t, pluginID, got[0].PluginId) + assert.Equal(t, int64(1000), got[0].CreatedAt) +} + +func testChannelGuardSaveIdempotentSamePlugin(t *testing.T, rctx request.CTX, ss store.Store) { + channelID := model.NewId() + pluginID := "com.example.plugin-a" + + first := &store.ChannelGuard{ChannelId: channelID, PluginId: pluginID, CreatedAt: 1000} + require.NoError(t, ss.ChannelGuard().Save(rctx, first)) + + second := &store.ChannelGuard{ChannelId: channelID, PluginId: pluginID, CreatedAt: 2000} + require.NoError(t, ss.ChannelGuard().Save(rctx, second)) + + got, err := ss.ChannelGuard().GetForChannel(rctx, channelID) + require.NoError(t, err) + require.Len(t, got, 1, "second save should be a no-op (DO NOTHING)") + assert.Equal(t, int64(1000), got[0].CreatedAt, "original CreatedAt should be preserved") +} + +func testChannelGuardSaveTwoPluginsSameChannel(t *testing.T, rctx request.CTX, ss store.Store) { + channelID := model.NewId() + pluginA := "com.example.plugin-a" + pluginB := "com.example.plugin-b" + + require.NoError(t, ss.ChannelGuard().Save(rctx, &store.ChannelGuard{ChannelId: channelID, PluginId: pluginA, CreatedAt: 1000})) + require.NoError(t, ss.ChannelGuard().Save(rctx, &store.ChannelGuard{ChannelId: channelID, PluginId: pluginB, CreatedAt: 2000})) + + got, err := ss.ChannelGuard().GetForChannel(rctx, channelID) + require.NoError(t, err) + require.Len(t, got, 2) + + pluginIDs := []string{got[0].PluginId, got[1].PluginId} + assert.Contains(t, pluginIDs, pluginA) + assert.Contains(t, pluginIDs, pluginB) +} + +func testChannelGuardDelete(t *testing.T, rctx request.CTX, ss store.Store) { + channelID := model.NewId() + pluginA := "com.example.plugin-a" + pluginB := "com.example.plugin-b" + + require.NoError(t, ss.ChannelGuard().Save(rctx, &store.ChannelGuard{ChannelId: channelID, PluginId: pluginA, CreatedAt: 1000})) + require.NoError(t, ss.ChannelGuard().Save(rctx, &store.ChannelGuard{ChannelId: channelID, PluginId: pluginB, CreatedAt: 2000})) + + n, err := ss.ChannelGuard().Delete(rctx, channelID, pluginA) + require.NoError(t, err) + assert.Equal(t, int64(1), n, "expected 1 row deleted") + + got, err := ss.ChannelGuard().GetForChannel(rctx, channelID) + require.NoError(t, err) + require.Len(t, got, 1, "only plugin-A's row should be deleted") + assert.Equal(t, pluginB, got[0].PluginId) + + // Deleting an already-removed (channel, plugin) pair is a no-op, not an error. + n, err = ss.ChannelGuard().Delete(rctx, channelID, pluginA) + require.NoError(t, err) + assert.Equal(t, int64(0), n, "expected 0 rows deleted for already-removed row") +} + +func testChannelGuardDeleteRowsAffected(t *testing.T, rctx request.CTX, ss store.Store) { + channelID := model.NewId() + pluginA := "com.example.plugin-a" + pluginB := "com.example.plugin-b" + + require.NoError(t, ss.ChannelGuard().Save(rctx, &store.ChannelGuard{ChannelId: channelID, PluginId: pluginA, CreatedAt: 1000})) + + // Cross-plugin delete: pluginB has no claim on the channel; returns (0, nil). + n, err := ss.ChannelGuard().Delete(rctx, channelID, pluginB) + require.NoError(t, err) + assert.Equal(t, int64(0), n, "cross-plugin delete must return 0 rows affected") + + // pluginA's row must be untouched. + got, err := ss.ChannelGuard().GetForChannel(rctx, channelID) + require.NoError(t, err) + require.Len(t, got, 1, "pluginA row must remain after cross-plugin delete") + assert.Equal(t, pluginA, got[0].PluginId) +} + +func testChannelGuardGetAll(t *testing.T, rctx request.CTX, ss store.Store) { + channelA := model.NewId() + channelB := model.NewId() + pluginA := "com.example.plugin-a-" + model.NewId() + pluginB := "com.example.plugin-b-" + model.NewId() + + require.NoError(t, ss.ChannelGuard().Save(rctx, &store.ChannelGuard{ChannelId: channelA, PluginId: pluginA, CreatedAt: 1000})) + require.NoError(t, ss.ChannelGuard().Save(rctx, &store.ChannelGuard{ChannelId: channelA, PluginId: pluginB, CreatedAt: 1100})) + require.NoError(t, ss.ChannelGuard().Save(rctx, &store.ChannelGuard{ChannelId: channelB, PluginId: pluginA, CreatedAt: 1200})) + + all, err := ss.ChannelGuard().GetAll(rctx) + require.NoError(t, err) + + count := 0 + for _, g := range all { + if g.PluginId == pluginA || g.PluginId == pluginB { + count++ + } + } + assert.Equal(t, 3, count, "expected 3 rows from this test fixture") +} diff --git a/server/channels/store/storetest/mocks/ChannelGuardStore.go b/server/channels/store/storetest/mocks/ChannelGuardStore.go new file mode 100644 index 000000000000..7606ebaf885b --- /dev/null +++ b/server/channels/store/storetest/mocks/ChannelGuardStore.go @@ -0,0 +1,136 @@ +// Code generated by mockery v2.53.4. DO NOT EDIT. + +// Regenerate this file using `make store-mocks`. + +package mocks + +import ( + request "github.com/mattermost/mattermost/server/public/shared/request" + store "github.com/mattermost/mattermost/server/v8/channels/store" + mock "github.com/stretchr/testify/mock" +) + +// ChannelGuardStore is an autogenerated mock type for the ChannelGuardStore type +type ChannelGuardStore struct { + mock.Mock +} + +// Delete provides a mock function with given fields: rctx, channelID, pluginID +func (_m *ChannelGuardStore) Delete(rctx request.CTX, channelID string, pluginID string) (int64, error) { + ret := _m.Called(rctx, channelID, pluginID) + + if len(ret) == 0 { + panic("no return value specified for Delete") + } + + var r0 int64 + var r1 error + if rf, ok := ret.Get(0).(func(request.CTX, string, string) (int64, error)); ok { + return rf(rctx, channelID, pluginID) + } + if rf, ok := ret.Get(0).(func(request.CTX, string, string) int64); ok { + r0 = rf(rctx, channelID, pluginID) + } else { + r0 = ret.Get(0).(int64) + } + + if rf, ok := ret.Get(1).(func(request.CTX, string, string) error); ok { + r1 = rf(rctx, channelID, pluginID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetAll provides a mock function with given fields: rctx +func (_m *ChannelGuardStore) GetAll(rctx request.CTX) ([]*store.ChannelGuard, error) { + ret := _m.Called(rctx) + + if len(ret) == 0 { + panic("no return value specified for GetAll") + } + + var r0 []*store.ChannelGuard + var r1 error + if rf, ok := ret.Get(0).(func(request.CTX) ([]*store.ChannelGuard, error)); ok { + return rf(rctx) + } + if rf, ok := ret.Get(0).(func(request.CTX) []*store.ChannelGuard); ok { + r0 = rf(rctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*store.ChannelGuard) + } + } + + if rf, ok := ret.Get(1).(func(request.CTX) error); ok { + r1 = rf(rctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetForChannel provides a mock function with given fields: rctx, channelID +func (_m *ChannelGuardStore) GetForChannel(rctx request.CTX, channelID string) ([]*store.ChannelGuard, error) { + ret := _m.Called(rctx, channelID) + + if len(ret) == 0 { + panic("no return value specified for GetForChannel") + } + + var r0 []*store.ChannelGuard + var r1 error + if rf, ok := ret.Get(0).(func(request.CTX, string) ([]*store.ChannelGuard, error)); ok { + return rf(rctx, channelID) + } + if rf, ok := ret.Get(0).(func(request.CTX, string) []*store.ChannelGuard); ok { + r0 = rf(rctx, channelID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*store.ChannelGuard) + } + } + + if rf, ok := ret.Get(1).(func(request.CTX, string) error); ok { + r1 = rf(rctx, channelID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Save provides a mock function with given fields: rctx, guard +func (_m *ChannelGuardStore) Save(rctx request.CTX, guard *store.ChannelGuard) error { + ret := _m.Called(rctx, guard) + + if len(ret) == 0 { + panic("no return value specified for Save") + } + + var r0 error + if rf, ok := ret.Get(0).(func(request.CTX, *store.ChannelGuard) error); ok { + r0 = rf(rctx, guard) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// NewChannelGuardStore creates a new instance of ChannelGuardStore. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewChannelGuardStore(t interface { + mock.TestingT + Cleanup(func()) +}) *ChannelGuardStore { + mock := &ChannelGuardStore{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/server/channels/store/storetest/mocks/Store.go b/server/channels/store/storetest/mocks/Store.go index 62ef539c03e8..16ed34d912f6 100644 --- a/server/channels/store/storetest/mocks/Store.go +++ b/server/channels/store/storetest/mocks/Store.go @@ -160,6 +160,26 @@ func (_m *Store) ChannelBookmark() store.ChannelBookmarkStore { return r0 } +// ChannelGuard provides a mock function with no fields +func (_m *Store) ChannelGuard() store.ChannelGuardStore { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ChannelGuard") + } + + var r0 store.ChannelGuardStore + if rf, ok := ret.Get(0).(func() store.ChannelGuardStore); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(store.ChannelGuardStore) + } + } + + return r0 +} + // ChannelJoinRequest provides a mock function with no fields func (_m *Store) ChannelJoinRequest() store.ChannelJoinRequestStore { ret := _m.Called() diff --git a/server/channels/store/storetest/mocks/UserStore.go b/server/channels/store/storetest/mocks/UserStore.go index 3beef15a46fc..312e0a9b88ba 100644 --- a/server/channels/store/storetest/mocks/UserStore.go +++ b/server/channels/store/storetest/mocks/UserStore.go @@ -355,6 +355,24 @@ func (_m *UserStore) DeactivateMagicLinkGuests() ([]string, error) { return r0, r1 } +// DecrementFailedPasswordAttempts provides a mock function with given fields: userID +func (_m *UserStore) DecrementFailedPasswordAttempts(userID string) error { + ret := _m.Called(userID) + + if len(ret) == 0 { + panic("no return value specified for DecrementFailedPasswordAttempts") + } + + var r0 error + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(userID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // DemoteUserToGuest provides a mock function with given fields: userID func (_m *UserStore) DemoteUserToGuest(userID string) (*model.User, error) { ret := _m.Called(userID) @@ -2078,6 +2096,34 @@ func (_m *UserStore) StoreMfaUsedTimestamps(userID string, ts []int) error { return r0 } +// TryIncrementFailedPasswordAttempts provides a mock function with given fields: userID, maxAttempts +func (_m *UserStore) TryIncrementFailedPasswordAttempts(userID string, maxAttempts int) (bool, error) { + ret := _m.Called(userID, maxAttempts) + + if len(ret) == 0 { + panic("no return value specified for TryIncrementFailedPasswordAttempts") + } + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(string, int) (bool, error)); ok { + return rf(userID, maxAttempts) + } + if rf, ok := ret.Get(0).(func(string, int) bool); ok { + r0 = rf(userID, maxAttempts) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func(string, int) error); ok { + r1 = rf(userID, maxAttempts) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // Update provides a mock function with given fields: rctx, user, allowRoleUpdate func (_m *UserStore) Update(rctx request.CTX, user *model.User, allowRoleUpdate bool) (*model.UserUpdate, error) { ret := _m.Called(rctx, user, allowRoleUpdate) @@ -2154,52 +2200,6 @@ func (_m *UserStore) UpdateFailedPasswordAttempts(userID string, attempts int) e return r0 } -// TryIncrementFailedPasswordAttempts provides a mock function with given fields: userID, maxAttempts -func (_m *UserStore) TryIncrementFailedPasswordAttempts(userID string, maxAttempts int) (bool, error) { - ret := _m.Called(userID, maxAttempts) - - if len(ret) == 0 { - panic("no return value specified for TryIncrementFailedPasswordAttempts") - } - - var r0 bool - var r1 error - if rf, ok := ret.Get(0).(func(string, int) (bool, error)); ok { - return rf(userID, maxAttempts) - } - if rf, ok := ret.Get(0).(func(string, int) bool); ok { - r0 = rf(userID, maxAttempts) - } else { - r0 = ret.Get(0).(bool) - } - - if rf, ok := ret.Get(1).(func(string, int) error); ok { - r1 = rf(userID, maxAttempts) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// DecrementFailedPasswordAttempts provides a mock function with given fields: userID -func (_m *UserStore) DecrementFailedPasswordAttempts(userID string) error { - ret := _m.Called(userID) - - if len(ret) == 0 { - panic("no return value specified for DecrementFailedPasswordAttempts") - } - - var r0 error - if rf, ok := ret.Get(0).(func(string) error); ok { - r0 = rf(userID) - } else { - r0 = ret.Error(0) - } - - return r0 -} - // UpdateLastLogin provides a mock function with given fields: userID, lastLogin func (_m *UserStore) UpdateLastLogin(userID string, lastLogin int64) error { ret := _m.Called(userID, lastLogin) diff --git a/server/channels/store/storetest/store.go b/server/channels/store/storetest/store.go index bfc113a828cd..999754b46b8b 100644 --- a/server/channels/store/storetest/store.go +++ b/server/channels/store/storetest/store.go @@ -64,6 +64,7 @@ type Store struct { PostPersistentNotificationStore mocks.PostPersistentNotificationStore DesktopTokensStore mocks.DesktopTokensStore ChannelBookmarkStore mocks.ChannelBookmarkStore + ChannelGuardStore mocks.ChannelGuardStore ScheduledPostStore mocks.ScheduledPostStore PropertyGroupStore mocks.PropertyGroupStore PropertyFieldStore mocks.PropertyFieldStore @@ -121,6 +122,7 @@ func (s *Store) ChannelMemberHistory() store.ChannelMemberHistoryStore { return &s.ChannelMemberHistoryStore } func (s *Store) ChannelBookmark() store.ChannelBookmarkStore { return &s.ChannelBookmarkStore } +func (s *Store) ChannelGuard() store.ChannelGuardStore { return &s.ChannelGuardStore } func (s *Store) DesktopTokens() store.DesktopTokensStore { return &s.DesktopTokensStore } func (s *Store) NotifyAdmin() store.NotifyAdminStore { return &s.NotifyAdminStore } func (s *Store) Group() store.GroupStore { return &s.GroupStore } @@ -239,6 +241,7 @@ func (s *Store) AssertExpectations(t mock.TestingT) bool { &s.PostPersistentNotificationStore, &s.DesktopTokensStore, &s.ChannelBookmarkStore, + &s.ChannelGuardStore, &s.ScheduledPostStore, &s.AccessControlPolicyStore, &s.AttributesStore, diff --git a/server/channels/store/timerlayer/timerlayer.go b/server/channels/store/timerlayer/timerlayer.go index 29c520c06e64..db5c7249039b 100644 --- a/server/channels/store/timerlayer/timerlayer.go +++ b/server/channels/store/timerlayer/timerlayer.go @@ -26,6 +26,7 @@ type TimerLayer struct { BotStore store.BotStore ChannelStore store.ChannelStore ChannelBookmarkStore store.ChannelBookmarkStore + ChannelGuardStore store.ChannelGuardStore ChannelJoinRequestStore store.ChannelJoinRequestStore ChannelMemberHistoryStore store.ChannelMemberHistoryStore ClusterDiscoveryStore store.ClusterDiscoveryStore @@ -107,6 +108,10 @@ func (s *TimerLayer) ChannelBookmark() store.ChannelBookmarkStore { return s.ChannelBookmarkStore } +func (s *TimerLayer) ChannelGuard() store.ChannelGuardStore { + return s.ChannelGuardStore +} + func (s *TimerLayer) ChannelJoinRequest() store.ChannelJoinRequestStore { return s.ChannelJoinRequestStore } @@ -346,6 +351,11 @@ type TimerLayerChannelBookmarkStore struct { Root *TimerLayer } +type TimerLayerChannelGuardStore struct { + store.ChannelGuardStore + Root *TimerLayer +} + type TimerLayerChannelJoinRequestStore struct { store.ChannelJoinRequestStore Root *TimerLayer @@ -3292,6 +3302,70 @@ func (s *TimerLayerChannelBookmarkStore) UpdateSortOrder(bookmarkID string, chan return result, err } +func (s *TimerLayerChannelGuardStore) Delete(rctx request.CTX, channelID string, pluginID string) (int64, error) { + start := time.Now() + + result, err := s.ChannelGuardStore.Delete(rctx, channelID, pluginID) + + elapsed := float64(time.Since(start)) / float64(time.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("ChannelGuardStore.Delete", success, elapsed) + } + return result, err +} + +func (s *TimerLayerChannelGuardStore) GetAll(rctx request.CTX) ([]*store.ChannelGuard, error) { + start := time.Now() + + result, err := s.ChannelGuardStore.GetAll(rctx) + + elapsed := float64(time.Since(start)) / float64(time.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("ChannelGuardStore.GetAll", success, elapsed) + } + return result, err +} + +func (s *TimerLayerChannelGuardStore) GetForChannel(rctx request.CTX, channelID string) ([]*store.ChannelGuard, error) { + start := time.Now() + + result, err := s.ChannelGuardStore.GetForChannel(rctx, channelID) + + elapsed := float64(time.Since(start)) / float64(time.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("ChannelGuardStore.GetForChannel", success, elapsed) + } + return result, err +} + +func (s *TimerLayerChannelGuardStore) Save(rctx request.CTX, guard *store.ChannelGuard) error { + start := time.Now() + + err := s.ChannelGuardStore.Save(rctx, guard) + + elapsed := float64(time.Since(start)) / float64(time.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("ChannelGuardStore.Save", success, elapsed) + } + return err +} + func (s *TimerLayerChannelJoinRequestStore) CountPending(channelId string) (int64, error) { start := time.Now() @@ -12784,6 +12858,22 @@ func (s *TimerLayerUserStore) DeactivateMagicLinkGuests() ([]string, error) { return result, err } +func (s *TimerLayerUserStore) DecrementFailedPasswordAttempts(userID string) error { + start := time.Now() + + err := s.UserStore.DecrementFailedPasswordAttempts(userID) + + elapsed := float64(time.Since(start)) / float64(time.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("UserStore.DecrementFailedPasswordAttempts", success, elapsed) + } + return err +} + func (s *TimerLayerUserStore) DemoteUserToGuest(userID string) (*model.User, error) { start := time.Now() @@ -13805,10 +13895,10 @@ func (s *TimerLayerUserStore) StoreMfaUsedTimestamps(userID string, ts []int) er return err } -func (s *TimerLayerUserStore) Update(rctx request.CTX, user *model.User, allowRoleUpdate bool) (*model.UserUpdate, error) { +func (s *TimerLayerUserStore) TryIncrementFailedPasswordAttempts(userID string, maxAttempts int) (bool, error) { start := time.Now() - result, err := s.UserStore.Update(rctx, user, allowRoleUpdate) + result, err := s.UserStore.TryIncrementFailedPasswordAttempts(userID, maxAttempts) elapsed := float64(time.Since(start)) / float64(time.Second) if s.Root.Metrics != nil { @@ -13816,15 +13906,15 @@ func (s *TimerLayerUserStore) Update(rctx request.CTX, user *model.User, allowRo if err == nil { success = "true" } - s.Root.Metrics.ObserveStoreMethodDuration("UserStore.Update", success, elapsed) + s.Root.Metrics.ObserveStoreMethodDuration("UserStore.TryIncrementFailedPasswordAttempts", success, elapsed) } return result, err } -func (s *TimerLayerUserStore) UpdateAuthData(userID string, service string, authData *string, email string, resetMfa bool) (string, error) { +func (s *TimerLayerUserStore) Update(rctx request.CTX, user *model.User, allowRoleUpdate bool) (*model.UserUpdate, error) { start := time.Now() - result, err := s.UserStore.UpdateAuthData(userID, service, authData, email, resetMfa) + result, err := s.UserStore.Update(rctx, user, allowRoleUpdate) elapsed := float64(time.Since(start)) / float64(time.Second) if s.Root.Metrics != nil { @@ -13832,31 +13922,15 @@ func (s *TimerLayerUserStore) UpdateAuthData(userID string, service string, auth if err == nil { success = "true" } - s.Root.Metrics.ObserveStoreMethodDuration("UserStore.UpdateAuthData", success, elapsed) + s.Root.Metrics.ObserveStoreMethodDuration("UserStore.Update", success, elapsed) } return result, err } -func (s *TimerLayerUserStore) UpdateFailedPasswordAttempts(userID string, attempts int) error { - start := time.Now() - - err := s.UserStore.UpdateFailedPasswordAttempts(userID, attempts) - - elapsed := float64(time.Since(start)) / float64(time.Second) - if s.Root.Metrics != nil { - success := "false" - if err == nil { - success = "true" - } - s.Root.Metrics.ObserveStoreMethodDuration("UserStore.UpdateFailedPasswordAttempts", success, elapsed) - } - return err -} - -func (s *TimerLayerUserStore) TryIncrementFailedPasswordAttempts(userID string, maxAttempts int) (bool, error) { +func (s *TimerLayerUserStore) UpdateAuthData(userID string, service string, authData *string, email string, resetMfa bool) (string, error) { start := time.Now() - result, err := s.UserStore.TryIncrementFailedPasswordAttempts(userID, maxAttempts) + result, err := s.UserStore.UpdateAuthData(userID, service, authData, email, resetMfa) elapsed := float64(time.Since(start)) / float64(time.Second) if s.Root.Metrics != nil { @@ -13864,15 +13938,15 @@ func (s *TimerLayerUserStore) TryIncrementFailedPasswordAttempts(userID string, if err == nil { success = "true" } - s.Root.Metrics.ObserveStoreMethodDuration("UserStore.TryIncrementFailedPasswordAttempts", success, elapsed) + s.Root.Metrics.ObserveStoreMethodDuration("UserStore.UpdateAuthData", success, elapsed) } return result, err } -func (s *TimerLayerUserStore) DecrementFailedPasswordAttempts(userID string) error { +func (s *TimerLayerUserStore) UpdateFailedPasswordAttempts(userID string, attempts int) error { start := time.Now() - err := s.UserStore.DecrementFailedPasswordAttempts(userID) + err := s.UserStore.UpdateFailedPasswordAttempts(userID, attempts) elapsed := float64(time.Since(start)) / float64(time.Second) if s.Root.Metrics != nil { @@ -13880,7 +13954,7 @@ func (s *TimerLayerUserStore) DecrementFailedPasswordAttempts(userID string) err if err == nil { success = "true" } - s.Root.Metrics.ObserveStoreMethodDuration("UserStore.DecrementFailedPasswordAttempts", success, elapsed) + s.Root.Metrics.ObserveStoreMethodDuration("UserStore.UpdateFailedPasswordAttempts", success, elapsed) } return err } @@ -14812,6 +14886,7 @@ func New(childStore store.Store, metrics einterfaces.MetricsInterface) *TimerLay newStore.BotStore = &TimerLayerBotStore{BotStore: childStore.Bot(), Root: &newStore} newStore.ChannelStore = &TimerLayerChannelStore{ChannelStore: childStore.Channel(), Root: &newStore} newStore.ChannelBookmarkStore = &TimerLayerChannelBookmarkStore{ChannelBookmarkStore: childStore.ChannelBookmark(), Root: &newStore} + newStore.ChannelGuardStore = &TimerLayerChannelGuardStore{ChannelGuardStore: childStore.ChannelGuard(), Root: &newStore} newStore.ChannelJoinRequestStore = &TimerLayerChannelJoinRequestStore{ChannelJoinRequestStore: childStore.ChannelJoinRequest(), Root: &newStore} newStore.ChannelMemberHistoryStore = &TimerLayerChannelMemberHistoryStore{ChannelMemberHistoryStore: childStore.ChannelMemberHistory(), Root: &newStore} newStore.ClusterDiscoveryStore = &TimerLayerClusterDiscoveryStore{ClusterDiscoveryStore: childStore.ClusterDiscovery(), Root: &newStore} diff --git a/server/channels/testlib/store.go b/server/channels/testlib/store.go index 5d01e25c6ab0..c8fe7f149572 100644 --- a/server/channels/testlib/store.go +++ b/server/channels/testlib/store.go @@ -145,6 +145,9 @@ func GetMockStoreForSetupFunctions() *mocks.Store { propertyFieldStore := mocks.PropertyFieldStore{} propertyValueStore := mocks.PropertyValueStore{} + channelGuardStore := mocks.ChannelGuardStore{} + channelGuardStore.On("GetAll", mock.Anything).Return([]*store.ChannelGuard{}, nil) + groupsByName := map[string]*model.PropertyGroup{} accessControlGroup := &model.PropertyGroup{ID: model.NewId(), Name: model.AccessControlPropertyGroupName, Version: model.PropertyGroupVersionV2} @@ -218,6 +221,7 @@ func GetMockStoreForSetupFunctions() *mocks.Store { mockStore.On("PropertyGroup").Return(&propertyGroupStore) mockStore.On("PropertyField").Return(&propertyFieldStore) mockStore.On("PropertyValue").Return(&propertyValueStore) + mockStore.On("ChannelGuard").Return(&channelGuardStore) return &mockStore } diff --git a/server/channels/web/params.go b/server/channels/web/params.go index b57cbf6a2a3a..1bd741bde3dd 100644 --- a/server/channels/web/params.go +++ b/server/channels/web/params.go @@ -129,6 +129,9 @@ type Params struct { GroupName string ObjectType string TargetId string + + // Channel join requests + RequestId string } var getChannelMembersForUserRegex = regexp.MustCompile("/api/v4/users/[A-Za-z0-9]{26}/channel_members") @@ -205,6 +208,7 @@ func ParamsFromRequest(r *http.Request) *Params { params.GroupName = props["group_name"] params.ObjectType = props["object_type"] params.TargetId = props["target_id"] + params.RequestId = props["request_id"] params.Scope = query.Get("scope") if val, err := strconv.Atoi(query.Get("page")); err != nil || (val < 0 && params.UserId == "" && !getChannelMembersForUserRegex.MatchString(r.URL.Path)) { diff --git a/server/cmd/mattermost/commands/db_ping.go b/server/cmd/mattermost/commands/db_ping.go new file mode 100644 index 000000000000..80729a24e407 --- /dev/null +++ b/server/cmd/mattermost/commands/db_ping.go @@ -0,0 +1,181 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package commands + +import ( + "context" + dbsql "database/sql" + stdErrors "errors" + "net/url" + "time" + + "github.com/pkg/errors" + "github.com/spf13/cobra" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/shared/mlog" + "github.com/mattermost/mattermost/server/v8/config" +) + +const ( + dbPingDefaultTimeout = 5 * time.Minute + dbPingDefaultRetryInterval = 2 * time.Second + // dbPingAttemptTimeout caps a single PingContext call so a hung connection + // doesn't block the whole timeout budget on one attempt. + dbPingAttemptTimeout = 10 * time.Second +) + +var DBPingCmd = &cobra.Command{ + Use: "ping", + Short: "Wait for the database to become reachable", + Long: `Pings the configured Mattermost database, retrying until --timeout expires. +Exits 0 once the database accepts a ping. Exits non-zero on timeout or fatal error. + +Intended for use as a readiness probe (e.g. a Kubernetes init container). +Resolves the DSN exactly like 'mattermost db migrate' / 'mattermost db init': +the --config flag, then MM_CONFIG, then config.json (which is then loaded as +a config store and SqlSettings.DataSource is used).`, + Example: ` # Database DSN passed via --config (preferred for readiness probes) + $ mattermost db ping --config postgres://mmuser:mostest@localhost/mattermost --timeout 2m + + # Or via MM_CONFIG + $ MM_CONFIG=postgres://localhost/mattermost mattermost db ping`, + Args: cobra.NoArgs, + RunE: dbPingCmdF, +} + +func init() { + DBPingCmd.Flags().Duration("timeout", dbPingDefaultTimeout, + "Maximum total time to wait for the DB to become reachable.") + DBPingCmd.Flags().Duration("retry-interval", dbPingDefaultRetryInterval, + "Sleep between ping attempts.") + DbCmd.AddCommand(DBPingCmd) +} + +func dbPingCmdF(command *cobra.Command, _ []string) error { + logger := mlog.CreateConsoleLogger() + defer func() { + _ = logger.Shutdown() + }() + + timeout, _ := command.Flags().GetDuration("timeout") + retryInterval, _ := command.Flags().GetDuration("retry-interval") + if timeout <= 0 { + return errors.New("--timeout must be > 0") + } + if retryInterval <= 0 { + return errors.New("--retry-interval must be > 0") + } + + dsn, err := resolvePingDataSource(command) + if err != nil { + return err + } + + sanitized, err := sanitizePingDataSource(dsn) + if err != nil { + return err + } + + db, err := dbsql.Open(model.DatabaseDriverPostgres, dsn) + if err != nil { + return errors.Wrap(err, "failed to open SQL connection") + } + defer db.Close() + + // Minimal pool — this is a one-shot readiness probe. + db.SetMaxOpenConns(1) + db.SetMaxIdleConns(1) + + ctx, cancel := context.WithTimeout(command.Context(), timeout) + defer cancel() + + return pingWithRetry(ctx, db, retryInterval, logger.With( + mlog.String("dataSource", sanitized), + )) +} + +func sanitizePingDataSource(dsn string) (string, error) { + sanitized, err := model.SanitizeDataSource(model.DatabaseDriverPostgres, dsn) + if err != nil { + return "", safeDataSourceSanitizationError(err) + } + + return sanitized, nil +} + +func safeDataSourceSanitizationError(err error) error { + var urlErr *url.Error + if stdErrors.As(err, &urlErr) { + if urlErr.Err != nil { + return errors.Errorf("invalid database DSN: %v", urlErr.Err) + } + return errors.New("invalid database DSN") + } + + return errors.New("invalid database DSN") +} + +// resolvePingDataSource returns a postgres DSN to ping. +// +// If the configured DSN is a postgres:// / postgresql:// URL it is returned as-is +// (fast path: no config store load required). Otherwise it is treated as a file +// path: a config.Store is loaded read-only (createFileIfNotExist=false so the +// command never has a side effect of creating a config file) and +// SqlSettings.DataSource is returned. +func resolvePingDataSource(command *cobra.Command) (string, error) { + cfgDSN := getConfigDSN(command, config.GetEnvironment()) + + if config.IsDatabaseDSN(cfgDSN) { + return cfgDSN, nil + } + + cfgStore, err := config.NewStoreFromDSN(cfgDSN, true /*readOnly*/, nil /*customDefaults*/, false /*createFileIfNotExist*/) + if err != nil { + return "", errors.Wrapf(err, "failed to load configuration from %q", cfgDSN) + } + defer cfgStore.Close() + + sqlSettings := cfgStore.Get().SqlSettings + if sqlSettings.DataSource == nil || *sqlSettings.DataSource == "" { + return "", errors.New("no database DSN configured: set --config or MM_CONFIG to a postgres:// URL, or ensure SqlSettings.DataSource is set in your configuration") + } + if !config.IsDatabaseDSN(*sqlSettings.DataSource) { + // Defensive: the loaded config has a non-postgres DataSource. Mattermost is postgres-only. + return "", errors.New("configured SqlSettings.DataSource is not a postgres DSN") + } + return *sqlSettings.DataSource, nil +} + +// pingWithRetry pings db every retryInterval until it succeeds or ctx is done. +// Each individual PingContext call is capped at dbPingAttemptTimeout so a hung +// network connection cannot consume the entire timeout budget on a single try. +func pingWithRetry(ctx context.Context, db *dbsql.DB, retryInterval time.Duration, logger mlog.LoggerIFace) error { + attempt := 0 + for { + attempt++ + attemptCtx, cancel := context.WithTimeout(ctx, dbPingAttemptTimeout) + err := db.PingContext(attemptCtx) + cancel() + if err == nil { + logger.Info("Database is reachable", mlog.Int("attempt", attempt)) + return nil + } + + // Surface progress on every attempt so operators can see the probe is alive. + // Intentionally omit the raw error: lib/pq error strings can echo DSN fragments. + logger.Info("Waiting for database", + mlog.Int("attempt", attempt), + mlog.Duration("retry_interval", retryInterval), + mlog.String("status", "ping_failed"), + ) + + // Wait retryInterval, but bail early if ctx is done. + select { + case <-ctx.Done(): + return errors.Wrapf(ctx.Err(), "timed out waiting for database after %d attempts", attempt) + case <-time.After(retryInterval): + } + } +} diff --git a/server/cmd/mattermost/commands/db_ping_test.go b/server/cmd/mattermost/commands/db_ping_test.go new file mode 100644 index 000000000000..005dfc343d22 --- /dev/null +++ b/server/cmd/mattermost/commands/db_ping_test.go @@ -0,0 +1,322 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package commands + +import ( + "context" + dbsql "database/sql" + "errors" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/shared/mlog" +) + +// dsnFromHelper builds a postgres:// DSN from the test main helper's SqlSettings. +// The helper itself stores the live test postgres DSN. +func dsnFromHelper(t *testing.T) string { + t.Helper() + require.NotNil(t, mainHelper, "mainHelper must be initialized; do not run with -short") + settings := mainHelper.GetSQLSettings() + require.NotNil(t, settings.DataSource) + require.NotEmpty(t, *settings.DataSource) + return *settings.DataSource +} + +// --- subprocess (CLI integration) tests --- + +func TestDBPingHappyPath(t *testing.T) { + if testing.Short() { + t.Skip("requires live test database") + } + + th := SetupWithStoreMock(t) + output := th.CheckCommand(t, "db", "ping", "--timeout", "30s") + require.Contains(t, output, "Database is reachable", + "expected success log line in command output, got: %s", output) +} + +func TestDBPingDirectDSN(t *testing.T) { + if testing.Short() { + t.Skip("requires live test database") + } + + th := SetupWithStoreMock(t) + th.SetAutoConfig(false) + + dsn := dsnFromHelper(t) + output, err := th.RunCommandWithOutput(t, "--config", dsn, "db", "ping", "--timeout", "30s") + require.NoError(t, err, "command should succeed when DSN is direct postgres URL; output: %s", output) + require.Contains(t, output, "Database is reachable") +} + +func TestDBPingTimeoutOnUnreachableDB(t *testing.T) { + th := SetupWithStoreMock(t) + th.SetAutoConfig(false) + + // Loopback to a port nothing listens on; connect_timeout=1 keeps each + // attempt short. We allow 2s total with 500ms between attempts so we get + // multiple "Waiting for database" lines. + dsn := "postgres://nobody@127.0.0.1:1/mattermost?sslmode=disable&connect_timeout=1" + + start := time.Now() + output, err := th.RunCommandWithOutput(t, "--config", dsn, "db", "ping", + "--timeout", "2s", "--retry-interval", "500ms") + elapsed := time.Since(start) + + require.Error(t, err, "command should fail on unreachable DB; output: %s", output) + require.Contains(t, output, "timed out waiting for database", + "expected timeout message in output, got: %s", output) + require.LessOrEqual(t, elapsed, 30*time.Second, + "command should not exceed a generous upper bound; took %s", elapsed) +} + +func TestDBPingInvalidDSN(t *testing.T) { + th := SetupWithStoreMock(t) + th.SetAutoConfig(false) + + // Passes IsDatabaseDSN (postgres:// prefix) so it takes the direct path. + dsn := "postgres://leakyuser:supersecret@[invalid" + + output, err := th.RunCommandWithOutput(t, "--config", dsn, "db", "ping", + "--timeout", "2s", "--retry-interval", "500ms") + require.Error(t, err, "command should fail on malformed DSN; output: %s", output) + require.Contains(t, output, "invalid database DSN", + "expected sanitized DSN parse error; got: %s", output) + require.Contains(t, output, "missing ']' in host", + "expected malformed DSN reason; got: %s", output) + require.NotContains(t, output, "supersecret", + "malformed DSN errors must not leak credentials; got: %s", output) + require.NotContains(t, output, "leakyuser", + "malformed DSN errors must not leak credentials; got: %s", output) +} + +func TestDBPingMissingConfigFile(t *testing.T) { + th := SetupWithStoreMock(t) + th.SetAutoConfig(false) + + // Point --config at a path that does not exist; createFileIfNotExist=false + // inside resolvePingDataSource means NewStoreFromDSN will return an error. + missing := th.TemporaryDirectory() + "/does-not-exist.json" + + output, err := th.RunCommandWithOutput(t, "--config", missing, "db", "ping", + "--timeout", "2s", "--retry-interval", "500ms") + require.Error(t, err, "command should fail when --config file does not exist; output: %s", output) + require.Contains(t, output, "failed to load configuration", + "expected config-load error message; got: %s", output) +} + +func TestDBPingFlagValidation(t *testing.T) { + th := SetupWithStoreMock(t) + th.SetAutoConfig(false) + dsn := "postgres://localhost:1/mattermost?sslmode=disable&connect_timeout=1" + + t.Run("zero timeout", func(t *testing.T) { + output, err := th.RunCommandWithOutput(t, "--config", dsn, "db", "ping", + "--timeout", "0s") + require.Error(t, err) + require.Contains(t, output, "--timeout must be > 0", + "expected timeout validation error; got: %s", output) + }) + + t.Run("zero retry interval", func(t *testing.T) { + output, err := th.RunCommandWithOutput(t, "--config", dsn, "db", "ping", + "--timeout", "1s", "--retry-interval", "0s") + require.Error(t, err) + require.Contains(t, output, "--retry-interval must be > 0", + "expected retry-interval validation error; got: %s", output) + }) + + t.Run("negative timeout", func(t *testing.T) { + output, err := th.RunCommandWithOutput(t, "--config", dsn, "db", "ping", + "--timeout", "-1s") + require.Error(t, err) + require.Contains(t, output, "--timeout must be > 0", + "expected timeout validation error for negative value; got: %s", output) + }) + + t.Run("garbage timeout value", func(t *testing.T) { + // cobra refuses to parse "garbage" as a duration; subcommand never runs. + output, err := th.RunCommandWithOutput(t, "--config", dsn, "db", "ping", + "--timeout", "garbage") + require.Error(t, err) + // Don't pin to exact text — cobra owns the error string here. Just + // confirm the subcommand's success log is absent. + require.NotContains(t, output, "Database is reachable", + "command should not have run successfully; got: %s", output) + }) +} + +func TestDBPingRetryIntervalHonored(t *testing.T) { + th := SetupWithStoreMock(t) + th.SetAutoConfig(false) + + dsn := "postgres://nobody@127.0.0.1:1/mattermost?sslmode=disable&connect_timeout=1" + + start := time.Now() + output, err := th.RunCommandWithOutput(t, "--config", dsn, "db", "ping", + "--timeout", "2s", "--retry-interval", "500ms") + elapsed := time.Since(start) + + require.Error(t, err) + // Loose lower bound: at 500ms intervals we expect at least 2 retries + // (~1s wall clock minimum) before the 2s timeout strikes. + require.GreaterOrEqual(t, elapsed, 1*time.Second, + "expected retries to span at least 1s; got %s", elapsed) + // Loose upper bound: don't exceed several multiples of the configured + // timeout — accommodates CI variance. + require.LessOrEqual(t, elapsed, 30*time.Second, + "expected command to bail close to --timeout; took %s", elapsed) + + waitingCount := strings.Count(output, "Waiting for database") + require.GreaterOrEqual(t, waitingCount, 2, + "expected at least 2 'Waiting for database' lines; got %d in output:\n%s", + waitingCount, output) +} + +// TestDBPingShortRetryIntervalProducesMoreAttempts verifies that the +// --retry-interval flag actually controls the cadence (not just the timeout). +func TestDBPingShortRetryIntervalProducesMoreAttempts(t *testing.T) { + th := SetupWithStoreMock(t) + th.SetAutoConfig(false) + + dsn := "postgres://nobody@127.0.0.1:1/mattermost?sslmode=disable&connect_timeout=1" + + output, err := th.RunCommandWithOutput(t, "--config", dsn, "db", "ping", + "--timeout", "3s", "--retry-interval", "200ms") + require.Error(t, err) + + waitingCount := strings.Count(output, "Waiting for database") + // At 200ms intervals over 3s we expect well more than 3 attempts even + // accounting for per-attempt connection overhead. + require.GreaterOrEqual(t, waitingCount, 3, + "expected several retries with short interval; got %d in output:\n%s", + waitingCount, output) +} + +// TestDBPingCmdRegistered confirms the new subcommand is wired into the +// existing DbCmd group, so users actually get `mattermost db ping`. +func TestDBPingCmdRegistered(t *testing.T) { + require.Contains(t, DbCmd.Commands(), DBPingCmd, + "DBPingCmd should be registered as a subcommand of DbCmd") + require.Equal(t, "ping", DBPingCmd.Use) + + // Flags exist with sensible defaults. + timeoutFlag := DBPingCmd.Flags().Lookup("timeout") + require.NotNil(t, timeoutFlag) + require.Equal(t, dbPingDefaultTimeout.String(), timeoutFlag.DefValue) + + intervalFlag := DBPingCmd.Flags().Lookup("retry-interval") + require.NotNil(t, intervalFlag) + require.Equal(t, dbPingDefaultRetryInterval.String(), intervalFlag.DefValue) +} + +// --- in-process tests of pingWithRetry / resolvePingDataSource --- + +func TestPingWithRetry_SuccessOnFirstAttempt(t *testing.T) { + if testing.Short() { + t.Skip("requires live test database") + } + + dsn := dsnFromHelper(t) + db, err := dbsql.Open(model.DatabaseDriverPostgres, dsn) + require.NoError(t, err) + defer db.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + logger := mlog.CreateConsoleTestLogger(t) + err = pingWithRetry(ctx, db, 100*time.Millisecond, logger) + require.NoError(t, err) +} + +func TestPingWithRetry_TimeoutAgainstUnreachable(t *testing.T) { + dsn := "postgres://nobody@127.0.0.1:1/mattermost?sslmode=disable&connect_timeout=1" + db, err := dbsql.Open(model.DatabaseDriverPostgres, dsn) + require.NoError(t, err) + defer db.Close() + + start := time.Now() + ctx, cancel := context.WithTimeout(context.Background(), 1500*time.Millisecond) + defer cancel() + + logger := mlog.CreateConsoleTestLogger(t) + err = pingWithRetry(ctx, db, 200*time.Millisecond, logger) + elapsed := time.Since(start) + + require.Error(t, err) + require.True(t, + errors.Is(err, context.DeadlineExceeded) || + strings.Contains(err.Error(), "timed out waiting for database"), + "expected deadline-exceeded or timeout error; got %v", err) + // Must have honored the timeout, not just returned immediately. + require.LessOrEqual(t, elapsed, 30*time.Second, + "expected reasonable upper bound; took %s", elapsed) +} + +func TestPingWithRetry_ContextCancelImmediately(t *testing.T) { + dsn := "postgres://nobody@127.0.0.1:1/mattermost?sslmode=disable&connect_timeout=1" + db, err := dbsql.Open(model.DatabaseDriverPostgres, dsn) + require.NoError(t, err) + defer db.Close() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // cancel before we even start + + logger := mlog.CreateConsoleTestLogger(t) + err = pingWithRetry(ctx, db, 1*time.Second, logger) + require.Error(t, err, "cancelled context should produce an error") +} + +// In-process tests of resolvePingDataSource. We drive DSN selection via the +// MM_CONFIG environment variable rather than the --config persistent flag +// because the persistent flag is only merged into a subcommand's local +// flagset during cobra's Execute() pipeline; calling resolvePingDataSource +// directly outside Execute means the flag would not be visible. +// MM_CONFIG is consumed by getConfigDSN as the second-precedence source. + +func TestResolvePingDataSource_DirectDSN(t *testing.T) { + wanted := "postgres://user:pw@example.invalid:5432/mm?sslmode=disable" + t.Setenv("MM_CONFIG", wanted) + + got, err := resolvePingDataSource(DBPingCmd) + require.NoError(t, err) + require.Equal(t, wanted, got) +} + +func TestResolvePingDataSource_DirectDSN_PostgresqlScheme(t *testing.T) { + wanted := "postgresql://user:pw@example.invalid:5432/mm?sslmode=disable" + t.Setenv("MM_CONFIG", wanted) + + got, err := resolvePingDataSource(DBPingCmd) + require.NoError(t, err) + require.Equal(t, wanted, got) +} + +func TestResolvePingDataSource_MissingFile(t *testing.T) { + missing := t.TempDir() + "/no-such-config.json" + t.Setenv("MM_CONFIG", missing) + + _, err := resolvePingDataSource(DBPingCmd) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to load configuration") +} + +// TestResolvePingDataSource_PointsAtDirectory verifies that pointing --config +// at a directory (not a JSON file) surfaces a clear, wrapped error. Catches +// regressions where we silently fall through instead of returning the load error. +func TestResolvePingDataSource_PointsAtDirectory(t *testing.T) { + dir := t.TempDir() + t.Setenv("MM_CONFIG", dir) + + _, err := resolvePingDataSource(DBPingCmd) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to load configuration", + "expected wrapped error; got %v", err) +} diff --git a/server/einterfaces/mocks/AccessControlSyncJobInterface.go b/server/einterfaces/mocks/AccessControlSyncJobInterface.go index e4b82fc3f56f..799df2fe44c6 100644 --- a/server/einterfaces/mocks/AccessControlSyncJobInterface.go +++ b/server/einterfaces/mocks/AccessControlSyncJobInterface.go @@ -5,10 +5,9 @@ package mocks import ( + model "github.com/mattermost/mattermost/server/public/model" jobs "github.com/mattermost/mattermost/server/v8/einterfaces/jobs" mock "github.com/stretchr/testify/mock" - - model "github.com/mattermost/mattermost/server/public/model" ) // AccessControlSyncJobInterface is an autogenerated mock type for the AccessControlSyncJobInterface type diff --git a/server/einterfaces/mocks/AutoTranslationInterface.go b/server/einterfaces/mocks/AutoTranslationInterface.go index 7cedb0885612..e1c7087dabfc 100644 --- a/server/einterfaces/mocks/AutoTranslationInterface.go +++ b/server/einterfaces/mocks/AutoTranslationInterface.go @@ -7,11 +7,9 @@ package mocks import ( context "context" + model "github.com/mattermost/mattermost/server/public/model" jobs "github.com/mattermost/mattermost/server/v8/einterfaces/jobs" - mock "github.com/stretchr/testify/mock" - - model "github.com/mattermost/mattermost/server/public/model" ) // AutoTranslationInterface is an autogenerated mock type for the AutoTranslationInterface type diff --git a/server/einterfaces/mocks/CloudJobInterface.go b/server/einterfaces/mocks/CloudJobInterface.go index ebe5a5efc7e6..d0c35f15abac 100644 --- a/server/einterfaces/mocks/CloudJobInterface.go +++ b/server/einterfaces/mocks/CloudJobInterface.go @@ -5,10 +5,9 @@ package mocks import ( + model "github.com/mattermost/mattermost/server/public/model" jobs "github.com/mattermost/mattermost/server/v8/einterfaces/jobs" mock "github.com/stretchr/testify/mock" - - model "github.com/mattermost/mattermost/server/public/model" ) // CloudJobInterface is an autogenerated mock type for the CloudJobInterface type diff --git a/server/einterfaces/mocks/ClusterInterface.go b/server/einterfaces/mocks/ClusterInterface.go index 3a0ef3df5f21..f8aa96d40a19 100644 --- a/server/einterfaces/mocks/ClusterInterface.go +++ b/server/einterfaces/mocks/ClusterInterface.go @@ -5,12 +5,10 @@ package mocks import ( - einterfaces "github.com/mattermost/mattermost/server/v8/einterfaces" - mock "github.com/stretchr/testify/mock" - model "github.com/mattermost/mattermost/server/public/model" - request "github.com/mattermost/mattermost/server/public/shared/request" + einterfaces "github.com/mattermost/mattermost/server/v8/einterfaces" + mock "github.com/stretchr/testify/mock" ) // ClusterInterface is an autogenerated mock type for the ClusterInterface type diff --git a/server/einterfaces/mocks/DataRetentionJobInterface.go b/server/einterfaces/mocks/DataRetentionJobInterface.go index 0876690eaa45..962a7bfbb09b 100644 --- a/server/einterfaces/mocks/DataRetentionJobInterface.go +++ b/server/einterfaces/mocks/DataRetentionJobInterface.go @@ -5,10 +5,9 @@ package mocks import ( + model "github.com/mattermost/mattermost/server/public/model" jobs "github.com/mattermost/mattermost/server/v8/einterfaces/jobs" mock "github.com/stretchr/testify/mock" - - model "github.com/mattermost/mattermost/server/public/model" ) // DataRetentionJobInterface is an autogenerated mock type for the DataRetentionJobInterface type diff --git a/server/einterfaces/mocks/ElasticsearchAggregatorInterface.go b/server/einterfaces/mocks/ElasticsearchAggregatorInterface.go index b33cef03908d..dd8774b6e856 100644 --- a/server/einterfaces/mocks/ElasticsearchAggregatorInterface.go +++ b/server/einterfaces/mocks/ElasticsearchAggregatorInterface.go @@ -5,10 +5,9 @@ package mocks import ( + model "github.com/mattermost/mattermost/server/public/model" jobs "github.com/mattermost/mattermost/server/v8/einterfaces/jobs" mock "github.com/stretchr/testify/mock" - - model "github.com/mattermost/mattermost/server/public/model" ) // ElasticsearchAggregatorInterface is an autogenerated mock type for the ElasticsearchAggregatorInterface type diff --git a/server/einterfaces/mocks/LdapSyncInterface.go b/server/einterfaces/mocks/LdapSyncInterface.go index 7582f6050db8..ad2537af01d8 100644 --- a/server/einterfaces/mocks/LdapSyncInterface.go +++ b/server/einterfaces/mocks/LdapSyncInterface.go @@ -5,10 +5,9 @@ package mocks import ( + model "github.com/mattermost/mattermost/server/public/model" jobs "github.com/mattermost/mattermost/server/v8/einterfaces/jobs" mock "github.com/stretchr/testify/mock" - - model "github.com/mattermost/mattermost/server/public/model" ) // LdapSyncInterface is an autogenerated mock type for the LdapSyncInterface type diff --git a/server/einterfaces/mocks/MessageExportJobInterface.go b/server/einterfaces/mocks/MessageExportJobInterface.go index 7c52631fca2a..d12b733292c6 100644 --- a/server/einterfaces/mocks/MessageExportJobInterface.go +++ b/server/einterfaces/mocks/MessageExportJobInterface.go @@ -5,10 +5,9 @@ package mocks import ( + model "github.com/mattermost/mattermost/server/public/model" jobs "github.com/mattermost/mattermost/server/v8/einterfaces/jobs" mock "github.com/stretchr/testify/mock" - - model "github.com/mattermost/mattermost/server/public/model" ) // MessageExportJobInterface is an autogenerated mock type for the MessageExportJobInterface type diff --git a/server/einterfaces/mocks/MetricsInterface.go b/server/einterfaces/mocks/MetricsInterface.go index 610883b11615..fb6bbda40806 100644 --- a/server/einterfaces/mocks/MetricsInterface.go +++ b/server/einterfaces/mocks/MetricsInterface.go @@ -5,12 +5,11 @@ package mocks import ( - logr "github.com/mattermost/logr/v2" - mock "github.com/stretchr/testify/mock" + sql "database/sql" + logr "github.com/mattermost/logr/v2" model "github.com/mattermost/mattermost/server/public/model" - - sql "database/sql" + mock "github.com/stretchr/testify/mock" ) // MetricsInterface is an autogenerated mock type for the MetricsInterface type diff --git a/server/einterfaces/mocks/OAuthProvider.go b/server/einterfaces/mocks/OAuthProvider.go index 869a703de895..4c6766adac2e 100644 --- a/server/einterfaces/mocks/OAuthProvider.go +++ b/server/einterfaces/mocks/OAuthProvider.go @@ -8,9 +8,8 @@ import ( io "io" model "github.com/mattermost/mattermost/server/public/model" - mock "github.com/stretchr/testify/mock" - request "github.com/mattermost/mattermost/server/public/shared/request" + mock "github.com/stretchr/testify/mock" ) // OAuthProvider is an autogenerated mock type for the OAuthProvider type diff --git a/server/einterfaces/mocks/PushProxyInterface.go b/server/einterfaces/mocks/PushProxyInterface.go index d35fe5403131..09660c51ac9c 100644 --- a/server/einterfaces/mocks/PushProxyInterface.go +++ b/server/einterfaces/mocks/PushProxyInterface.go @@ -5,10 +5,9 @@ package mocks import ( + model "github.com/mattermost/mattermost/server/public/model" jobs "github.com/mattermost/mattermost/server/v8/einterfaces/jobs" mock "github.com/stretchr/testify/mock" - - model "github.com/mattermost/mattermost/server/public/model" ) // PushProxyInterface is an autogenerated mock type for the PushProxyInterface type diff --git a/server/einterfaces/mocks/SamlInterface.go b/server/einterfaces/mocks/SamlInterface.go index 6066bc3cc6fb..bedec7f46ce5 100644 --- a/server/einterfaces/mocks/SamlInterface.go +++ b/server/einterfaces/mocks/SamlInterface.go @@ -5,11 +5,10 @@ package mocks import ( + saml2 "github.com/mattermost/gosaml2" model "github.com/mattermost/mattermost/server/public/model" request "github.com/mattermost/mattermost/server/public/shared/request" mock "github.com/stretchr/testify/mock" - - saml2 "github.com/mattermost/gosaml2" ) // SamlInterface is an autogenerated mock type for the SamlInterface type diff --git a/server/einterfaces/mocks/Scheduler.go b/server/einterfaces/mocks/Scheduler.go index 88024464dae8..f459db138ca1 100644 --- a/server/einterfaces/mocks/Scheduler.go +++ b/server/einterfaces/mocks/Scheduler.go @@ -5,11 +5,11 @@ package mocks import ( + time "time" + model "github.com/mattermost/mattermost/server/public/model" request "github.com/mattermost/mattermost/server/public/shared/request" mock "github.com/stretchr/testify/mock" - - time "time" ) // Scheduler is an autogenerated mock type for the Scheduler type diff --git a/server/i18n/en.json b/server/i18n/en.json index 043fa6f30a2c..dfe59e867ca6 100644 --- a/server/i18n/en.json +++ b/server/i18n/en.json @@ -427,6 +427,54 @@ "id": "api.channel.delete_channel.type.invalid", "translation": "Unable to delete direct or group message channels" }, + { + "id": "api.channel.discoverable_join_request.already_member.app_error", + "translation": "You are already a member of this channel." + }, + { + "id": "api.channel.discoverable_join_request.archived.app_error", + "translation": "Cannot request to join an archived channel." + }, + { + "id": "api.channel.discoverable_join_request.discoverable_requires_approval.app_error", + "translation": "This channel requires admin approval to join. Please send a request from the Browse Channels modal." + }, + { + "id": "api.channel.discoverable_join_request.duplicate.app_error", + "translation": "You already have a pending request to join this channel." + }, + { + "id": "api.channel.discoverable_join_request.feature_disabled.app_error", + "translation": "Discoverable channels are not enabled on this server." + }, + { + "id": "api.channel.discoverable_join_request.guest.app_error", + "translation": "Guests cannot request to join discoverable private channels." + }, + { + "id": "api.channel.discoverable_join_request.invalid_patch.app_error", + "translation": "Invalid update for the channel join request." + }, + { + "id": "api.channel.discoverable_join_request.not_discoverable.app_error", + "translation": "This channel is not discoverable." + }, + { + "id": "api.channel.discoverable_join_request.not_pending.app_error", + "translation": "The join request is no longer pending." + }, + { + "id": "api.channel.discoverable_join_request.not_private.app_error", + "translation": "Only private channels accept join requests." + }, + { + "id": "api.channel.discoverable_join_request.policy_denied.app_error", + "translation": "You do not satisfy the access rules required to join this channel." + }, + { + "id": "api.channel.discoverable_join_request.shared.app_error", + "translation": "Shared channels do not accept discoverable join requests." + }, { "id": "api.channel.get_channel.flagged_post_mismatch.app_error", "translation": "Channel ID does not match the channel ID of the flagged post." @@ -5634,6 +5682,22 @@ "id": "app.channel.group_message_conversion.post_message.error", "translation": "Failed to create group message to channel conversion post" }, + { + "id": "app.channel.join_request.get.app_error", + "translation": "Failed to load channel join request." + }, + { + "id": "app.channel.join_request.not_found.app_error", + "translation": "Channel join request not found." + }, + { + "id": "app.channel.join_request.save.app_error", + "translation": "Failed to save channel join request." + }, + { + "id": "app.channel.join_request.update.app_error", + "translation": "Failed to update channel join request." + }, { "id": "app.channel.migrate_channel_members.select.app_error", "translation": "Failed to select the batch of channel members." @@ -5698,6 +5762,10 @@ "id": "app.channel.restore.app_error", "translation": "Unable to restore the channel." }, + { + "id": "app.channel.restore_channel.rejected_by_plugin", + "translation": "Channel restore rejected by plugin: {{.Reason}}" + }, { "id": "app.channel.save_member.app_error", "translation": "Unable to save channel member." @@ -5742,6 +5810,14 @@ "id": "app.channel.update_channel.internal_error", "translation": "Unable to update channel." }, + { + "id": "app.channel.update_channel.plugin_type_mutation.app_error", + "translation": "Plugin {{.PluginID}} attempted to mutate channel type via ChannelWillBeUpdated; type changes must go through the dedicated type-change path" + }, + { + "id": "app.channel.update_channel.rejected_by_plugin", + "translation": "Channel update rejected by plugin: {{.Reason}}" + }, { "id": "app.channel.update_last_viewed_at.app_error", "translation": "Unable to update the last viewed at time." @@ -5762,6 +5838,26 @@ "id": "app.channel.user_belongs_to_channels.app_error", "translation": "Unable to determine if the user belongs to a list of channels." }, + { + "id": "app.channel_guard.invalid_channel.app_error", + "translation": "Channel ID is not a valid channel identifier." + }, + { + "id": "app.channel_guard.register.app_error", + "translation": "Unable to register the channel guard." + }, + { + "id": "app.channel_guard.register.empty_channel.app_error", + "translation": "Channel ID is required to register a channel guard." + }, + { + "id": "app.channel_guard.unregister.app_error", + "translation": "Unable to unregister the channel guard." + }, + { + "id": "app.channel_guard.unregister.empty_channel.app_error", + "translation": "Channel ID is required to unregister a channel guard." + }, { "id": "app.channel_member_history.log_join_event.internal_error", "translation": "Failed to record channel member history." @@ -6312,6 +6408,10 @@ "id": "app.draft.save.app_error", "translation": "Unable to save the Draft." }, + { + "id": "app.draft.upsert.rejected_by_plugin", + "translation": "Draft rejected by plugin: {{.Reason}}" + }, { "id": "app.drafts.permanent_delete_by_user.app_error", "translation": "Unable to delete drafts for user." @@ -7736,6 +7836,14 @@ "id": "app.plugin.get_statuses.app_error", "translation": "Unable to get plugin statuses." }, + { + "id": "app.plugin.guard_hook_failed.app_error", + "translation": "Operation rejected: claiming plugin {{.PluginID}} hook call failed" + }, + { + "id": "app.plugin.inactive_guard.app_error", + "translation": "Operation rejected: a required plugin is not active" + }, { "id": "app.plugin.install.app_error", "translation": "Unable to install plugin." @@ -8702,10 +8810,18 @@ "id": "app.scheduled_post.private_channel", "translation": "Private channel" }, + { + "id": "app.scheduled_post.save.rejected_by_plugin", + "translation": "Scheduled post rejected by plugin: {{.Reason}}" + }, { "id": "app.scheduled_post.unknown_channel", "translation": "Unknown Channel" }, + { + "id": "app.scheduled_post.update.rejected_by_plugin", + "translation": "Scheduled post update rejected by plugin: {{.Reason}}" + }, { "id": "app.scheme.delete.app_error", "translation": "Unable to delete this scheme." diff --git a/server/platform/services/cache/mocks/Provider.go b/server/platform/services/cache/mocks/Provider.go index 3c53f2562e32..596754e049db 100644 --- a/server/platform/services/cache/mocks/Provider.go +++ b/server/platform/services/cache/mocks/Provider.go @@ -7,7 +7,6 @@ package mocks import ( einterfaces "github.com/mattermost/mattermost/server/v8/einterfaces" cache "github.com/mattermost/mattermost/server/v8/platform/services/cache" - mock "github.com/stretchr/testify/mock" ) diff --git a/server/platform/services/searchengine/mocks/SearchEngineInterface.go b/server/platform/services/searchengine/mocks/SearchEngineInterface.go index beba45ebfa1a..d8d047b18f2f 100644 --- a/server/platform/services/searchengine/mocks/SearchEngineInterface.go +++ b/server/platform/services/searchengine/mocks/SearchEngineInterface.go @@ -8,9 +8,8 @@ import ( context "context" model "github.com/mattermost/mattermost/server/public/model" - mock "github.com/stretchr/testify/mock" - request "github.com/mattermost/mattermost/server/public/shared/request" + mock "github.com/stretchr/testify/mock" time "time" ) diff --git a/server/platform/services/sharedchannel/mock_AppIface_test.go b/server/platform/services/sharedchannel/mock_AppIface_test.go index 250d5808a7d1..763d9c9ec35f 100644 --- a/server/platform/services/sharedchannel/mock_AppIface_test.go +++ b/server/platform/services/sharedchannel/mock_AppIface_test.go @@ -5,12 +5,10 @@ package sharedchannel import ( - filestore "github.com/mattermost/mattermost/server/v8/platform/shared/filestore" - mock "github.com/stretchr/testify/mock" - model "github.com/mattermost/mattermost/server/public/model" - request "github.com/mattermost/mattermost/server/public/shared/request" + filestore "github.com/mattermost/mattermost/server/v8/platform/shared/filestore" + mock "github.com/stretchr/testify/mock" ) // MockAppIface is an autogenerated mock type for the AppIface type diff --git a/server/platform/services/sharedchannel/mock_ServerIface_test.go b/server/platform/services/sharedchannel/mock_ServerIface_test.go index d223f1df0763..5930e0795922 100644 --- a/server/platform/services/sharedchannel/mock_ServerIface_test.go +++ b/server/platform/services/sharedchannel/mock_ServerIface_test.go @@ -5,16 +5,12 @@ package sharedchannel import ( + model "github.com/mattermost/mattermost/server/public/model" mlog "github.com/mattermost/mattermost/server/public/shared/mlog" + store "github.com/mattermost/mattermost/server/v8/channels/store" einterfaces "github.com/mattermost/mattermost/server/v8/einterfaces" - - mock "github.com/stretchr/testify/mock" - - model "github.com/mattermost/mattermost/server/public/model" - remotecluster "github.com/mattermost/mattermost/server/v8/platform/services/remotecluster" - - store "github.com/mattermost/mattermost/server/v8/channels/store" + mock "github.com/stretchr/testify/mock" ) // MockServerIface is an autogenerated mock type for the ServerIface type diff --git a/server/platform/shared/filestore/mocks/FileBackend.go b/server/platform/shared/filestore/mocks/FileBackend.go index 46bfefae0501..823ac7d26b78 100644 --- a/server/platform/shared/filestore/mocks/FileBackend.go +++ b/server/platform/shared/filestore/mocks/FileBackend.go @@ -6,12 +6,10 @@ package mocks import ( io "io" + time "time" filestore "github.com/mattermost/mattermost/server/v8/platform/shared/filestore" - mock "github.com/stretchr/testify/mock" - - time "time" ) // FileBackend is an autogenerated mock type for the FileBackend type diff --git a/server/public/model/audit_events.go b/server/public/model/audit_events.go index a82ec6b7d834..489a4e1a3dd5 100644 --- a/server/public/model/audit_events.go +++ b/server/public/model/audit_events.go @@ -84,6 +84,9 @@ const ( AuditEventAddChannelMember = "addChannelMember" // add member to channel AuditEventConvertGroupMessageToChannel = "convertGroupMessageToChannel" // convert group message to private channel AuditEventCreateChannel = "createChannel" // create public or private channel + AuditEventCreateChannelJoinRequest = "createChannelJoinRequest" // request to join a discoverable private channel + AuditEventUpdateChannelJoinRequest = "updateChannelJoinRequest" // approve or deny a channel join request + AuditEventWithdrawChannelJoinRequest = "withdrawChannelJoinRequest" // requester cancels their channel join request AuditEventCreateDirectChannel = "createDirectChannel" // create direct message channel between two users AuditEventCreateGroupChannel = "createGroupChannel" // create group message channel with multiple users AuditEventDeleteChannel = "deleteChannel" // delete channel diff --git a/server/public/plugin/api.go b/server/public/plugin/api.go index 97b06dfd84a6..23eabd0bed85 100644 --- a/server/public/plugin/api.go +++ b/server/public/plugin/api.go @@ -510,6 +510,26 @@ type API interface { // Minimum server version: 5.2 UpdateChannel(channel *model.Channel) (*model.Channel, *model.AppError) + // RegisterChannelGuard claims the channel for this plugin, signaling to the server that the + // channel has plugin-managed semantics and that the server's default behaviors are unsafe + // without plugin involvement. + // + // The calling plugin's ID is implicit. Multiple plugins may co-guard the same channel; each + // claim is an independent row. Subsequent calls from the same plugin are idempotent; calls from + // a different plugin add a new claim. + // + // @tag Channel + // Minimum server version: 11.8 + RegisterChannelGuard(channelID string) *model.AppError + + // UnregisterChannelGuard releases this plugin's claim on the channel. Only the registering + // plugin can unregister its own claim; other plugins' claims on the same channel are + // unaffected. + // + // @tag Channel + // Minimum server version: 11.8 + UnregisterChannelGuard(channelID string) *model.AppError + // SearchChannels returns the channels on a team matching the provided search term. // // @tag Channel diff --git a/server/public/plugin/api_timer_layer_generated.go b/server/public/plugin/api_timer_layer_generated.go index 35818b5f6aca..c4301202d4e0 100644 --- a/server/public/plugin/api_timer_layer_generated.go +++ b/server/public/plugin/api_timer_layer_generated.go @@ -560,6 +560,20 @@ func (api *apiTimerLayer) UpdateChannel(channel *model.Channel) (*model.Channel, return _returnsA, _returnsB } +func (api *apiTimerLayer) RegisterChannelGuard(channelID string) *model.AppError { + startTime := timePkg.Now() + _returnsA := api.apiImpl.RegisterChannelGuard(channelID) + api.recordTime(startTime, "RegisterChannelGuard", _returnsA == nil) + return _returnsA +} + +func (api *apiTimerLayer) UnregisterChannelGuard(channelID string) *model.AppError { + startTime := timePkg.Now() + _returnsA := api.apiImpl.UnregisterChannelGuard(channelID) + api.recordTime(startTime, "UnregisterChannelGuard", _returnsA == nil) + return _returnsA +} + func (api *apiTimerLayer) SearchChannels(teamID string, term string) ([]*model.Channel, *model.AppError) { startTime := timePkg.Now() _returnsA, _returnsB := api.apiImpl.SearchChannels(teamID, term) diff --git a/server/public/plugin/client_rpc.go b/server/public/plugin/client_rpc.go index 196b43ed2b3e..76ce6c8825df 100644 --- a/server/public/plugin/client_rpc.go +++ b/server/public/plugin/client_rpc.go @@ -1460,6 +1460,83 @@ func (s *hooksRPCServer) ChannelMemberWillBeAdded(args *Z_ChannelMemberWillBeAdd return nil } +// MessageWillBePostedWithRPCErr returns the same values as MessageWillBePosted, with an additional +// trailing error for the RPC transport — always the LAST return slot. This hand-written companion +// exists because MessageWillBePosted is in excludedPluginHooks and therefore absent from the +// auto-generated HooksWithRPCErrGenerated interface in client_rpc_generated.go. +func (g *hooksRPCClient) MessageWillBePostedWithRPCErr(c *Context, post *model.Post) (*model.Post, string, error) { + _args := &Z_MessageWillBePostedArgs{c, post} + _returns := &Z_MessageWillBePostedReturns{} + var _err error + if g.implemented[MessageWillBePostedID] { + _err = g.client.Call("Plugin.MessageWillBePosted", _args, _returns) + if _err != nil { + // Reset _returns so partial gob decoding can't leak non-zero + // values past a transport failure (HooksWithRPCErrGenerated contract). + _returns = &Z_MessageWillBePostedReturns{} + g.log.Debug("RPC call MessageWillBePosted to plugin failed.", mlog.Err(_err)) + } + } + return _returns.A, _returns.B, _err +} + +// MessageWillBeUpdatedWithRPCErr returns the same values as MessageWillBeUpdated, with an additional +// trailing error for the RPC transport — always the LAST return slot. This hand-written companion +// exists because MessageWillBeUpdated is in excludedPluginHooks and therefore absent from the +// auto-generated HooksWithRPCErrGenerated interface in client_rpc_generated.go. +func (g *hooksRPCClient) MessageWillBeUpdatedWithRPCErr(c *Context, newPost, oldPost *model.Post) (*model.Post, string, error) { + _args := &Z_MessageWillBeUpdatedArgs{c, newPost, oldPost} + _returns := &Z_MessageWillBeUpdatedReturns{} + var _err error + if g.implemented[MessageWillBeUpdatedID] { + _err = g.client.Call("Plugin.MessageWillBeUpdated", _args, _returns) + if _err != nil { + // Reset _returns so partial gob decoding can't leak non-zero + // values past a transport failure (HooksWithRPCErrGenerated contract). + _returns = &Z_MessageWillBeUpdatedReturns{} + g.log.Debug("RPC call MessageWillBeUpdated to plugin failed.", mlog.Err(_err)) + } + } + return _returns.A, _returns.B, _err +} + +// ChannelMemberWillBeAddedWithRPCErr returns the same values as ChannelMemberWillBeAdded, with an +// additional trailing error for the RPC transport — always the LAST return slot. This hand-written +// companion exists because ChannelMemberWillBeAdded is in excludedPluginHooks and therefore absent +// from the auto-generated HooksWithRPCErrGenerated interface in client_rpc_generated.go. +func (g *hooksRPCClient) ChannelMemberWillBeAddedWithRPCErr(c *Context, channelMember *model.ChannelMember) (*model.ChannelMember, string, error) { + _args := &Z_ChannelMemberWillBeAddedArgs{c, channelMember} + _returns := &Z_ChannelMemberWillBeAddedReturns{} + var _err error + if g.implemented[ChannelMemberWillBeAddedID] { + _err = g.client.Call("Plugin.ChannelMemberWillBeAdded", _args, _returns) + if _err != nil { + // Reset _returns so partial gob decoding can't leak non-zero + // values past a transport failure (HooksWithRPCErrGenerated contract). + _returns = &Z_ChannelMemberWillBeAddedReturns{} + g.log.Debug("RPC call ChannelMemberWillBeAdded to plugin failed.", mlog.Err(_err)) + } + } + return _returns.A, _returns.B, _err +} + +// HooksWithRPCErr extends HooksWithRPCErrGenerated with *WithRPCErr companions for the three hooks whose +// base stubs are hand-written in this file. The auto-generated HooksWithRPCErrGenerated in +// client_rpc_generated.go cannot include these because the generator skips excluded hooks. +// Returned by Environment.HooksForPluginWithRPCErr so callers can invoke any *WithRPCErr method +// without a type assertion. +type HooksWithRPCErr interface { + HooksWithRPCErrGenerated + MessageWillBePostedWithRPCErr(c *Context, post *model.Post) (*model.Post, string, error) + MessageWillBeUpdatedWithRPCErr(c *Context, newPost, oldPost *model.Post) (*model.Post, string, error) + ChannelMemberWillBeAddedWithRPCErr(c *Context, channelMember *model.ChannelMember) (*model.ChannelMember, string, error) +} + +var ( + _ HooksWithRPCErr = (*hooksRPCClient)(nil) + _ HooksWithRPCErr = (*hooksTimerLayer)(nil) +) + // TeamMemberWillBeAdded is hand-written to preserve the original TeamMember as the default // return value, avoiding unintentional field removal by older plugins. func init() { diff --git a/server/public/plugin/client_rpc_generated.go b/server/public/plugin/client_rpc_generated.go index 5cf2fd522065..412646599d8c 100644 --- a/server/public/plugin/client_rpc_generated.go +++ b/server/public/plugin/client_rpc_generated.go @@ -47,7 +47,7 @@ func (g *hooksRPCClient) OnDeactivateWithRPCErr() (error, error) { _err = g.client.Call("Plugin.OnDeactivate", _args, _returns) if _err != nil { // Reset _returns so partial gob decoding can't leak non-zero - // values past a transport failure (HooksWithRPCErr contract). + // values past a transport failure (HooksWithRPCErrGenerated contract). _returns = &Z_OnDeactivateReturns{} g.log.Debug("RPC call OnDeactivate to plugin failed.", mlog.Err(_err)) } @@ -99,7 +99,7 @@ func (g *hooksRPCClient) OnConfigurationChangeWithRPCErr() (error, error) { _err = g.client.Call("Plugin.OnConfigurationChange", _args, _returns) if _err != nil { // Reset _returns so partial gob decoding can't leak non-zero - // values past a transport failure (HooksWithRPCErr contract). + // values past a transport failure (HooksWithRPCErrGenerated contract). _returns = &Z_OnConfigurationChangeReturns{} g.log.Debug("RPC call OnConfigurationChange to plugin failed.", mlog.Err(_err)) } @@ -154,7 +154,7 @@ func (g *hooksRPCClient) ExecuteCommandWithRPCErr(c *Context, args *model.Comman _err = g.client.Call("Plugin.ExecuteCommand", _args, _returns) if _err != nil { // Reset _returns so partial gob decoding can't leak non-zero - // values past a transport failure (HooksWithRPCErr contract). + // values past a transport failure (HooksWithRPCErrGenerated contract). _returns = &Z_ExecuteCommandReturns{} g.log.Debug("RPC call ExecuteCommand to plugin failed.", mlog.Err(_err)) } @@ -206,7 +206,7 @@ func (g *hooksRPCClient) UserHasBeenCreatedWithRPCErr(c *Context, user *model.Us _err = g.client.Call("Plugin.UserHasBeenCreated", _args, _returns) if _err != nil { // Reset _returns so partial gob decoding can't leak non-zero - // values past a transport failure (HooksWithRPCErr contract). + // values past a transport failure (HooksWithRPCErrGenerated contract). _returns = &Z_UserHasBeenCreatedReturns{} g.log.Debug("RPC call UserHasBeenCreated to plugin failed.", mlog.Err(_err)) } @@ -259,7 +259,7 @@ func (g *hooksRPCClient) UserWillLogInWithRPCErr(c *Context, user *model.User) ( _err = g.client.Call("Plugin.UserWillLogIn", _args, _returns) if _err != nil { // Reset _returns so partial gob decoding can't leak non-zero - // values past a transport failure (HooksWithRPCErr contract). + // values past a transport failure (HooksWithRPCErrGenerated contract). _returns = &Z_UserWillLogInReturns{} g.log.Debug("RPC call UserWillLogIn to plugin failed.", mlog.Err(_err)) } @@ -311,7 +311,7 @@ func (g *hooksRPCClient) UserHasLoggedInWithRPCErr(c *Context, user *model.User) _err = g.client.Call("Plugin.UserHasLoggedIn", _args, _returns) if _err != nil { // Reset _returns so partial gob decoding can't leak non-zero - // values past a transport failure (HooksWithRPCErr contract). + // values past a transport failure (HooksWithRPCErrGenerated contract). _returns = &Z_UserHasLoggedInReturns{} g.log.Debug("RPC call UserHasLoggedIn to plugin failed.", mlog.Err(_err)) } @@ -363,7 +363,7 @@ func (g *hooksRPCClient) MessageHasBeenPostedWithRPCErr(c *Context, post *model. _err = g.client.Call("Plugin.MessageHasBeenPosted", _args, _returns) if _err != nil { // Reset _returns so partial gob decoding can't leak non-zero - // values past a transport failure (HooksWithRPCErr contract). + // values past a transport failure (HooksWithRPCErrGenerated contract). _returns = &Z_MessageHasBeenPostedReturns{} g.log.Debug("RPC call MessageHasBeenPosted to plugin failed.", mlog.Err(_err)) } @@ -416,7 +416,7 @@ func (g *hooksRPCClient) MessageHasBeenUpdatedWithRPCErr(c *Context, newPost, ol _err = g.client.Call("Plugin.MessageHasBeenUpdated", _args, _returns) if _err != nil { // Reset _returns so partial gob decoding can't leak non-zero - // values past a transport failure (HooksWithRPCErr contract). + // values past a transport failure (HooksWithRPCErrGenerated contract). _returns = &Z_MessageHasBeenUpdatedReturns{} g.log.Debug("RPC call MessageHasBeenUpdated to plugin failed.", mlog.Err(_err)) } @@ -468,7 +468,7 @@ func (g *hooksRPCClient) MessageHasBeenDeletedWithRPCErr(c *Context, post *model _err = g.client.Call("Plugin.MessageHasBeenDeleted", _args, _returns) if _err != nil { // Reset _returns so partial gob decoding can't leak non-zero - // values past a transport failure (HooksWithRPCErr contract). + // values past a transport failure (HooksWithRPCErrGenerated contract). _returns = &Z_MessageHasBeenDeletedReturns{} g.log.Debug("RPC call MessageHasBeenDeleted to plugin failed.", mlog.Err(_err)) } @@ -520,7 +520,7 @@ func (g *hooksRPCClient) ChannelHasBeenCreatedWithRPCErr(c *Context, channel *mo _err = g.client.Call("Plugin.ChannelHasBeenCreated", _args, _returns) if _err != nil { // Reset _returns so partial gob decoding can't leak non-zero - // values past a transport failure (HooksWithRPCErr contract). + // values past a transport failure (HooksWithRPCErrGenerated contract). _returns = &Z_ChannelHasBeenCreatedReturns{} g.log.Debug("RPC call ChannelHasBeenCreated to plugin failed.", mlog.Err(_err)) } @@ -573,7 +573,7 @@ func (g *hooksRPCClient) ChannelWillBeArchivedWithRPCErr(c *Context, channel *mo _err = g.client.Call("Plugin.ChannelWillBeArchived", _args, _returns) if _err != nil { // Reset _returns so partial gob decoding can't leak non-zero - // values past a transport failure (HooksWithRPCErr contract). + // values past a transport failure (HooksWithRPCErrGenerated contract). _returns = &Z_ChannelWillBeArchivedReturns{} g.log.Debug("RPC call ChannelWillBeArchived to plugin failed.", mlog.Err(_err)) } @@ -626,7 +626,7 @@ func (g *hooksRPCClient) UserHasJoinedChannelWithRPCErr(c *Context, channelMembe _err = g.client.Call("Plugin.UserHasJoinedChannel", _args, _returns) if _err != nil { // Reset _returns so partial gob decoding can't leak non-zero - // values past a transport failure (HooksWithRPCErr contract). + // values past a transport failure (HooksWithRPCErrGenerated contract). _returns = &Z_UserHasJoinedChannelReturns{} g.log.Debug("RPC call UserHasJoinedChannel to plugin failed.", mlog.Err(_err)) } @@ -679,7 +679,7 @@ func (g *hooksRPCClient) UserHasLeftChannelWithRPCErr(c *Context, channelMember _err = g.client.Call("Plugin.UserHasLeftChannel", _args, _returns) if _err != nil { // Reset _returns so partial gob decoding can't leak non-zero - // values past a transport failure (HooksWithRPCErr contract). + // values past a transport failure (HooksWithRPCErrGenerated contract). _returns = &Z_UserHasLeftChannelReturns{} g.log.Debug("RPC call UserHasLeftChannel to plugin failed.", mlog.Err(_err)) } @@ -732,7 +732,7 @@ func (g *hooksRPCClient) UserHasJoinedTeamWithRPCErr(c *Context, teamMember *mod _err = g.client.Call("Plugin.UserHasJoinedTeam", _args, _returns) if _err != nil { // Reset _returns so partial gob decoding can't leak non-zero - // values past a transport failure (HooksWithRPCErr contract). + // values past a transport failure (HooksWithRPCErrGenerated contract). _returns = &Z_UserHasJoinedTeamReturns{} g.log.Debug("RPC call UserHasJoinedTeam to plugin failed.", mlog.Err(_err)) } @@ -785,7 +785,7 @@ func (g *hooksRPCClient) UserHasLeftTeamWithRPCErr(c *Context, teamMember *model _err = g.client.Call("Plugin.UserHasLeftTeam", _args, _returns) if _err != nil { // Reset _returns so partial gob decoding can't leak non-zero - // values past a transport failure (HooksWithRPCErr contract). + // values past a transport failure (HooksWithRPCErrGenerated contract). _returns = &Z_UserHasLeftTeamReturns{} g.log.Debug("RPC call UserHasLeftTeam to plugin failed.", mlog.Err(_err)) } @@ -840,7 +840,7 @@ func (g *hooksRPCClient) FileWillBeDownloadedWithRPCErr(c *Context, fileInfo *mo _err = g.client.Call("Plugin.FileWillBeDownloaded", _args, _returns) if _err != nil { // Reset _returns so partial gob decoding can't leak non-zero - // values past a transport failure (HooksWithRPCErr contract). + // values past a transport failure (HooksWithRPCErrGenerated contract). _returns = &Z_FileWillBeDownloadedReturns{} g.log.Debug("RPC call FileWillBeDownloaded to plugin failed.", mlog.Err(_err)) } @@ -892,7 +892,7 @@ func (g *hooksRPCClient) ReactionHasBeenAddedWithRPCErr(c *Context, reaction *mo _err = g.client.Call("Plugin.ReactionHasBeenAdded", _args, _returns) if _err != nil { // Reset _returns so partial gob decoding can't leak non-zero - // values past a transport failure (HooksWithRPCErr contract). + // values past a transport failure (HooksWithRPCErrGenerated contract). _returns = &Z_ReactionHasBeenAddedReturns{} g.log.Debug("RPC call ReactionHasBeenAdded to plugin failed.", mlog.Err(_err)) } @@ -944,7 +944,7 @@ func (g *hooksRPCClient) ReactionHasBeenRemovedWithRPCErr(c *Context, reaction * _err = g.client.Call("Plugin.ReactionHasBeenRemoved", _args, _returns) if _err != nil { // Reset _returns so partial gob decoding can't leak non-zero - // values past a transport failure (HooksWithRPCErr contract). + // values past a transport failure (HooksWithRPCErrGenerated contract). _returns = &Z_ReactionHasBeenRemovedReturns{} g.log.Debug("RPC call ReactionHasBeenRemoved to plugin failed.", mlog.Err(_err)) } @@ -996,7 +996,7 @@ func (g *hooksRPCClient) OnPluginClusterEventWithRPCErr(c *Context, ev model.Plu _err = g.client.Call("Plugin.OnPluginClusterEvent", _args, _returns) if _err != nil { // Reset _returns so partial gob decoding can't leak non-zero - // values past a transport failure (HooksWithRPCErr contract). + // values past a transport failure (HooksWithRPCErrGenerated contract). _returns = &Z_OnPluginClusterEventReturns{} g.log.Debug("RPC call OnPluginClusterEvent to plugin failed.", mlog.Err(_err)) } @@ -1048,7 +1048,7 @@ func (g *hooksRPCClient) OnWebSocketConnectWithRPCErr(webConnID, userID string) _err = g.client.Call("Plugin.OnWebSocketConnect", _args, _returns) if _err != nil { // Reset _returns so partial gob decoding can't leak non-zero - // values past a transport failure (HooksWithRPCErr contract). + // values past a transport failure (HooksWithRPCErrGenerated contract). _returns = &Z_OnWebSocketConnectReturns{} g.log.Debug("RPC call OnWebSocketConnect to plugin failed.", mlog.Err(_err)) } @@ -1100,7 +1100,7 @@ func (g *hooksRPCClient) OnWebSocketDisconnectWithRPCErr(webConnID, userID strin _err = g.client.Call("Plugin.OnWebSocketDisconnect", _args, _returns) if _err != nil { // Reset _returns so partial gob decoding can't leak non-zero - // values past a transport failure (HooksWithRPCErr contract). + // values past a transport failure (HooksWithRPCErrGenerated contract). _returns = &Z_OnWebSocketDisconnectReturns{} g.log.Debug("RPC call OnWebSocketDisconnect to plugin failed.", mlog.Err(_err)) } @@ -1153,7 +1153,7 @@ func (g *hooksRPCClient) WebSocketMessageHasBeenPostedWithRPCErr(webConnID, user _err = g.client.Call("Plugin.WebSocketMessageHasBeenPosted", _args, _returns) if _err != nil { // Reset _returns so partial gob decoding can't leak non-zero - // values past a transport failure (HooksWithRPCErr contract). + // values past a transport failure (HooksWithRPCErrGenerated contract). _returns = &Z_WebSocketMessageHasBeenPostedReturns{} g.log.Debug("RPC call WebSocketMessageHasBeenPosted to plugin failed.", mlog.Err(_err)) } @@ -1207,7 +1207,7 @@ func (g *hooksRPCClient) RunDataRetentionWithRPCErr(nowTime, batchSize int64) (i _err = g.client.Call("Plugin.RunDataRetention", _args, _returns) if _err != nil { // Reset _returns so partial gob decoding can't leak non-zero - // values past a transport failure (HooksWithRPCErr contract). + // values past a transport failure (HooksWithRPCErrGenerated contract). _returns = &Z_RunDataRetentionReturns{} g.log.Debug("RPC call RunDataRetention to plugin failed.", mlog.Err(_err)) } @@ -1261,7 +1261,7 @@ func (g *hooksRPCClient) OnInstallWithRPCErr(c *Context, event model.OnInstallEv _err = g.client.Call("Plugin.OnInstall", _args, _returns) if _err != nil { // Reset _returns so partial gob decoding can't leak non-zero - // values past a transport failure (HooksWithRPCErr contract). + // values past a transport failure (HooksWithRPCErrGenerated contract). _returns = &Z_OnInstallReturns{} g.log.Debug("RPC call OnInstall to plugin failed.", mlog.Err(_err)) } @@ -1312,7 +1312,7 @@ func (g *hooksRPCClient) OnSendDailyTelemetryWithRPCErr() error { _err = g.client.Call("Plugin.OnSendDailyTelemetry", _args, _returns) if _err != nil { // Reset _returns so partial gob decoding can't leak non-zero - // values past a transport failure (HooksWithRPCErr contract). + // values past a transport failure (HooksWithRPCErrGenerated contract). _returns = &Z_OnSendDailyTelemetryReturns{} g.log.Debug("RPC call OnSendDailyTelemetry to plugin failed.", mlog.Err(_err)) } @@ -1363,7 +1363,7 @@ func (g *hooksRPCClient) OnCloudLimitsUpdatedWithRPCErr(limits *model.ProductLim _err = g.client.Call("Plugin.OnCloudLimitsUpdated", _args, _returns) if _err != nil { // Reset _returns so partial gob decoding can't leak non-zero - // values past a transport failure (HooksWithRPCErr contract). + // values past a transport failure (HooksWithRPCErrGenerated contract). _returns = &Z_OnCloudLimitsUpdatedReturns{} g.log.Debug("RPC call OnCloudLimitsUpdated to plugin failed.", mlog.Err(_err)) } @@ -1416,7 +1416,7 @@ func (g *hooksRPCClient) ConfigurationWillBeSavedWithRPCErr(newCfg *model.Config _err = g.client.Call("Plugin.ConfigurationWillBeSaved", _args, _returns) if _err != nil { // Reset _returns so partial gob decoding can't leak non-zero - // values past a transport failure (HooksWithRPCErr contract). + // values past a transport failure (HooksWithRPCErrGenerated contract). _returns = &Z_ConfigurationWillBeSavedReturns{} g.log.Debug("RPC call ConfigurationWillBeSaved to plugin failed.", mlog.Err(_err)) } @@ -1470,7 +1470,7 @@ func (g *hooksRPCClient) EmailNotificationWillBeSentWithRPCErr(emailNotification _err = g.client.Call("Plugin.EmailNotificationWillBeSent", _args, _returns) if _err != nil { // Reset _returns so partial gob decoding can't leak non-zero - // values past a transport failure (HooksWithRPCErr contract). + // values past a transport failure (HooksWithRPCErrGenerated contract). _returns = &Z_EmailNotificationWillBeSentReturns{} g.log.Debug("RPC call EmailNotificationWillBeSent to plugin failed.", mlog.Err(_err)) } @@ -1524,7 +1524,7 @@ func (g *hooksRPCClient) NotificationWillBePushedWithRPCErr(pushNotification *mo _err = g.client.Call("Plugin.NotificationWillBePushed", _args, _returns) if _err != nil { // Reset _returns so partial gob decoding can't leak non-zero - // values past a transport failure (HooksWithRPCErr contract). + // values past a transport failure (HooksWithRPCErrGenerated contract). _returns = &Z_NotificationWillBePushedReturns{} g.log.Debug("RPC call NotificationWillBePushed to plugin failed.", mlog.Err(_err)) } @@ -1576,7 +1576,7 @@ func (g *hooksRPCClient) UserHasBeenDeactivatedWithRPCErr(c *Context, user *mode _err = g.client.Call("Plugin.UserHasBeenDeactivated", _args, _returns) if _err != nil { // Reset _returns so partial gob decoding can't leak non-zero - // values past a transport failure (HooksWithRPCErr contract). + // values past a transport failure (HooksWithRPCErrGenerated contract). _returns = &Z_UserHasBeenDeactivatedReturns{} g.log.Debug("RPC call UserHasBeenDeactivated to plugin failed.", mlog.Err(_err)) } @@ -1630,7 +1630,7 @@ func (g *hooksRPCClient) OnSharedChannelsSyncMsgWithRPCErr(msg *model.SyncMsg, r _err = g.client.Call("Plugin.OnSharedChannelsSyncMsg", _args, _returns) if _err != nil { // Reset _returns so partial gob decoding can't leak non-zero - // values past a transport failure (HooksWithRPCErr contract). + // values past a transport failure (HooksWithRPCErrGenerated contract). _returns = &Z_OnSharedChannelsSyncMsgReturns{} g.log.Debug("RPC call OnSharedChannelsSyncMsg to plugin failed.", mlog.Err(_err)) } @@ -1683,7 +1683,7 @@ func (g *hooksRPCClient) OnSharedChannelsPingWithRPCErr(rc *model.RemoteCluster) _err = g.client.Call("Plugin.OnSharedChannelsPing", _args, _returns) if _err != nil { // Reset _returns so partial gob decoding can't leak non-zero - // values past a transport failure (HooksWithRPCErr contract). + // values past a transport failure (HooksWithRPCErrGenerated contract). _returns = &Z_OnSharedChannelsPingReturns{} g.log.Debug("RPC call OnSharedChannelsPing to plugin failed.", mlog.Err(_err)) } @@ -1735,7 +1735,7 @@ func (g *hooksRPCClient) PreferencesHaveChangedWithRPCErr(c *Context, preference _err = g.client.Call("Plugin.PreferencesHaveChanged", _args, _returns) if _err != nil { // Reset _returns so partial gob decoding can't leak non-zero - // values past a transport failure (HooksWithRPCErr contract). + // values past a transport failure (HooksWithRPCErrGenerated contract). _returns = &Z_PreferencesHaveChangedReturns{} g.log.Debug("RPC call PreferencesHaveChanged to plugin failed.", mlog.Err(_err)) } @@ -1789,7 +1789,7 @@ func (g *hooksRPCClient) OnSharedChannelsAttachmentSyncMsgWithRPCErr(fi *model.F _err = g.client.Call("Plugin.OnSharedChannelsAttachmentSyncMsg", _args, _returns) if _err != nil { // Reset _returns so partial gob decoding can't leak non-zero - // values past a transport failure (HooksWithRPCErr contract). + // values past a transport failure (HooksWithRPCErrGenerated contract). _returns = &Z_OnSharedChannelsAttachmentSyncMsgReturns{} g.log.Debug("RPC call OnSharedChannelsAttachmentSyncMsg to plugin failed.", mlog.Err(_err)) } @@ -1843,7 +1843,7 @@ func (g *hooksRPCClient) OnSharedChannelsProfileImageSyncMsgWithRPCErr(user *mod _err = g.client.Call("Plugin.OnSharedChannelsProfileImageSyncMsg", _args, _returns) if _err != nil { // Reset _returns so partial gob decoding can't leak non-zero - // values past a transport failure (HooksWithRPCErr contract). + // values past a transport failure (HooksWithRPCErrGenerated contract). _returns = &Z_OnSharedChannelsProfileImageSyncMsgReturns{} g.log.Debug("RPC call OnSharedChannelsProfileImageSyncMsg to plugin failed.", mlog.Err(_err)) } @@ -1897,7 +1897,7 @@ func (g *hooksRPCClient) GenerateSupportDataWithRPCErr(c *Context) ([]*model.Fil _err = g.client.Call("Plugin.GenerateSupportData", _args, _returns) if _err != nil { // Reset _returns so partial gob decoding can't leak non-zero - // values past a transport failure (HooksWithRPCErr contract). + // values past a transport failure (HooksWithRPCErrGenerated contract). _returns = &Z_GenerateSupportDataReturns{} g.log.Debug("RPC call GenerateSupportData to plugin failed.", mlog.Err(_err)) } @@ -1952,7 +1952,7 @@ func (g *hooksRPCClient) OnSAMLLoginWithRPCErr(c *Context, user *model.User, ass _err = g.client.Call("Plugin.OnSAMLLogin", _args, _returns) if _err != nil { // Reset _returns so partial gob decoding can't leak non-zero - // values past a transport failure (HooksWithRPCErr contract). + // values past a transport failure (HooksWithRPCErrGenerated contract). _returns = &Z_OnSAMLLoginReturns{} g.log.Debug("RPC call OnSAMLLogin to plugin failed.", mlog.Err(_err)) } @@ -1972,7 +1972,223 @@ func (s *hooksRPCServer) OnSAMLLogin(args *Z_OnSAMLLoginArgs, returns *Z_OnSAMLL return nil } -// HooksWithRPCErr provides a WithRPCErr variant for every generated hook. The last error return +func init() { + hookNameToId["ChannelWillBeUpdated"] = ChannelWillBeUpdatedID +} + +type Z_ChannelWillBeUpdatedArgs struct { + A *Context + B *model.Channel + C *model.Channel +} + +type Z_ChannelWillBeUpdatedReturns struct { + A *model.Channel + B string +} + +func (g *hooksRPCClient) ChannelWillBeUpdated(c *Context, newChannel, oldChannel *model.Channel) (*model.Channel, string) { + _args := &Z_ChannelWillBeUpdatedArgs{c, newChannel, oldChannel} + _returns := &Z_ChannelWillBeUpdatedReturns{} + if g.implemented[ChannelWillBeUpdatedID] { + if err := g.client.Call("Plugin.ChannelWillBeUpdated", _args, _returns); err != nil { + g.log.Error("RPC call ChannelWillBeUpdated to plugin failed.", mlog.Err(err)) + } + } + return _returns.A, _returns.B +} + +// ChannelWillBeUpdatedWithRPCErr returns the same values as ChannelWillBeUpdated, with an additional trailing error +// for the RPC transport — always the LAST return slot. +func (g *hooksRPCClient) ChannelWillBeUpdatedWithRPCErr(c *Context, newChannel, oldChannel *model.Channel) (*model.Channel, string, error) { + _args := &Z_ChannelWillBeUpdatedArgs{c, newChannel, oldChannel} + _returns := &Z_ChannelWillBeUpdatedReturns{} + var _err error + if g.implemented[ChannelWillBeUpdatedID] { + _err = g.client.Call("Plugin.ChannelWillBeUpdated", _args, _returns) + if _err != nil { + // Reset _returns so partial gob decoding can't leak non-zero + // values past a transport failure (HooksWithRPCErrGenerated contract). + _returns = &Z_ChannelWillBeUpdatedReturns{} + g.log.Debug("RPC call ChannelWillBeUpdated to plugin failed.", mlog.Err(_err)) + } + } + return _returns.A, _returns.B, _err +} + +func (s *hooksRPCServer) ChannelWillBeUpdated(args *Z_ChannelWillBeUpdatedArgs, returns *Z_ChannelWillBeUpdatedReturns) error { + if hook, ok := s.impl.(interface { + ChannelWillBeUpdated(c *Context, newChannel, oldChannel *model.Channel) (*model.Channel, string) + }); ok { + returns.A, returns.B = hook.ChannelWillBeUpdated(args.A, args.B, args.C) + } else { + return encodableError(fmt.Errorf("Hook ChannelWillBeUpdated called but not implemented.")) + } + return nil +} + +func init() { + hookNameToId["ChannelWillBeRestored"] = ChannelWillBeRestoredID +} + +type Z_ChannelWillBeRestoredArgs struct { + A *Context + B *model.Channel +} + +type Z_ChannelWillBeRestoredReturns struct { + A string +} + +func (g *hooksRPCClient) ChannelWillBeRestored(c *Context, channel *model.Channel) string { + _args := &Z_ChannelWillBeRestoredArgs{c, channel} + _returns := &Z_ChannelWillBeRestoredReturns{} + if g.implemented[ChannelWillBeRestoredID] { + if err := g.client.Call("Plugin.ChannelWillBeRestored", _args, _returns); err != nil { + g.log.Error("RPC call ChannelWillBeRestored to plugin failed.", mlog.Err(err)) + } + } + return _returns.A +} + +// ChannelWillBeRestoredWithRPCErr returns the same values as ChannelWillBeRestored, with an additional trailing error +// for the RPC transport — always the LAST return slot. +func (g *hooksRPCClient) ChannelWillBeRestoredWithRPCErr(c *Context, channel *model.Channel) (string, error) { + _args := &Z_ChannelWillBeRestoredArgs{c, channel} + _returns := &Z_ChannelWillBeRestoredReturns{} + var _err error + if g.implemented[ChannelWillBeRestoredID] { + _err = g.client.Call("Plugin.ChannelWillBeRestored", _args, _returns) + if _err != nil { + // Reset _returns so partial gob decoding can't leak non-zero + // values past a transport failure (HooksWithRPCErrGenerated contract). + _returns = &Z_ChannelWillBeRestoredReturns{} + g.log.Debug("RPC call ChannelWillBeRestored to plugin failed.", mlog.Err(_err)) + } + } + return _returns.A, _err +} + +func (s *hooksRPCServer) ChannelWillBeRestored(args *Z_ChannelWillBeRestoredArgs, returns *Z_ChannelWillBeRestoredReturns) error { + if hook, ok := s.impl.(interface { + ChannelWillBeRestored(c *Context, channel *model.Channel) string + }); ok { + returns.A = hook.ChannelWillBeRestored(args.A, args.B) + } else { + return encodableError(fmt.Errorf("Hook ChannelWillBeRestored called but not implemented.")) + } + return nil +} + +func init() { + hookNameToId["ScheduledPostWillBeCreated"] = ScheduledPostWillBeCreatedID +} + +type Z_ScheduledPostWillBeCreatedArgs struct { + A *Context + B *model.ScheduledPost +} + +type Z_ScheduledPostWillBeCreatedReturns struct { + A *model.ScheduledPost + B string +} + +func (g *hooksRPCClient) ScheduledPostWillBeCreated(c *Context, scheduledPost *model.ScheduledPost) (*model.ScheduledPost, string) { + _args := &Z_ScheduledPostWillBeCreatedArgs{c, scheduledPost} + _returns := &Z_ScheduledPostWillBeCreatedReturns{} + if g.implemented[ScheduledPostWillBeCreatedID] { + if err := g.client.Call("Plugin.ScheduledPostWillBeCreated", _args, _returns); err != nil { + g.log.Error("RPC call ScheduledPostWillBeCreated to plugin failed.", mlog.Err(err)) + } + } + return _returns.A, _returns.B +} + +// ScheduledPostWillBeCreatedWithRPCErr returns the same values as ScheduledPostWillBeCreated, with an additional trailing error +// for the RPC transport — always the LAST return slot. +func (g *hooksRPCClient) ScheduledPostWillBeCreatedWithRPCErr(c *Context, scheduledPost *model.ScheduledPost) (*model.ScheduledPost, string, error) { + _args := &Z_ScheduledPostWillBeCreatedArgs{c, scheduledPost} + _returns := &Z_ScheduledPostWillBeCreatedReturns{} + var _err error + if g.implemented[ScheduledPostWillBeCreatedID] { + _err = g.client.Call("Plugin.ScheduledPostWillBeCreated", _args, _returns) + if _err != nil { + // Reset _returns so partial gob decoding can't leak non-zero + // values past a transport failure (HooksWithRPCErrGenerated contract). + _returns = &Z_ScheduledPostWillBeCreatedReturns{} + g.log.Debug("RPC call ScheduledPostWillBeCreated to plugin failed.", mlog.Err(_err)) + } + } + return _returns.A, _returns.B, _err +} + +func (s *hooksRPCServer) ScheduledPostWillBeCreated(args *Z_ScheduledPostWillBeCreatedArgs, returns *Z_ScheduledPostWillBeCreatedReturns) error { + if hook, ok := s.impl.(interface { + ScheduledPostWillBeCreated(c *Context, scheduledPost *model.ScheduledPost) (*model.ScheduledPost, string) + }); ok { + returns.A, returns.B = hook.ScheduledPostWillBeCreated(args.A, args.B) + } else { + return encodableError(fmt.Errorf("Hook ScheduledPostWillBeCreated called but not implemented.")) + } + return nil +} + +func init() { + hookNameToId["DraftWillBeUpserted"] = DraftWillBeUpsertedID +} + +type Z_DraftWillBeUpsertedArgs struct { + A *Context + B *model.Draft +} + +type Z_DraftWillBeUpsertedReturns struct { + A *model.Draft + B string +} + +func (g *hooksRPCClient) DraftWillBeUpserted(c *Context, draft *model.Draft) (*model.Draft, string) { + _args := &Z_DraftWillBeUpsertedArgs{c, draft} + _returns := &Z_DraftWillBeUpsertedReturns{} + if g.implemented[DraftWillBeUpsertedID] { + if err := g.client.Call("Plugin.DraftWillBeUpserted", _args, _returns); err != nil { + g.log.Error("RPC call DraftWillBeUpserted to plugin failed.", mlog.Err(err)) + } + } + return _returns.A, _returns.B +} + +// DraftWillBeUpsertedWithRPCErr returns the same values as DraftWillBeUpserted, with an additional trailing error +// for the RPC transport — always the LAST return slot. +func (g *hooksRPCClient) DraftWillBeUpsertedWithRPCErr(c *Context, draft *model.Draft) (*model.Draft, string, error) { + _args := &Z_DraftWillBeUpsertedArgs{c, draft} + _returns := &Z_DraftWillBeUpsertedReturns{} + var _err error + if g.implemented[DraftWillBeUpsertedID] { + _err = g.client.Call("Plugin.DraftWillBeUpserted", _args, _returns) + if _err != nil { + // Reset _returns so partial gob decoding can't leak non-zero + // values past a transport failure (HooksWithRPCErrGenerated contract). + _returns = &Z_DraftWillBeUpsertedReturns{} + g.log.Debug("RPC call DraftWillBeUpserted to plugin failed.", mlog.Err(_err)) + } + } + return _returns.A, _returns.B, _err +} + +func (s *hooksRPCServer) DraftWillBeUpserted(args *Z_DraftWillBeUpsertedArgs, returns *Z_DraftWillBeUpsertedReturns) error { + if hook, ok := s.impl.(interface { + DraftWillBeUpserted(c *Context, draft *model.Draft) (*model.Draft, string) + }); ok { + returns.A, returns.B = hook.DraftWillBeUpserted(args.A, args.B) + } else { + return encodableError(fmt.Errorf("Hook DraftWillBeUpserted called but not implemented.")) + } + return nil +} + +// HooksWithRPCErrGenerated provides a WithRPCErr variant for every generated hook. The last error return // is always the RPC transport error — if non-nil, the plugin's other return values are zero. For // hooks whose base signature already returns error, the tuple is (originalReturns..., rpcErr) // where the final slot is always transport. @@ -1981,8 +2197,8 @@ func (s *hooksRPCServer) OnSAMLLogin(args *Z_OnSAMLLoginArgs, returns *Z_OnSAMLL // indistinguishable from a successful invocation that returned zeros. Callers MUST gate on // supervisor.Implements() (or use Environment.RunMultiPluginHookWithRPCErr, which gates // by the iteration's hook ID — note that any *WithRPCErr method called on the closure's -// HooksWithRPCErr is independently subject to its own implemented-gate). -type HooksWithRPCErr interface { +// HooksWithRPCErrGenerated is independently subject to its own implemented-gate). +type HooksWithRPCErrGenerated interface { OnDeactivateWithRPCErr() (error, error) OnConfigurationChangeWithRPCErr() (error, error) @@ -2056,6 +2272,14 @@ type HooksWithRPCErr interface { GenerateSupportDataWithRPCErr(c *Context) ([]*model.FileData, error, error) OnSAMLLoginWithRPCErr(c *Context, user *model.User, assertion *saml2.AssertionInfo) (error, error) + + ChannelWillBeUpdatedWithRPCErr(c *Context, newChannel, oldChannel *model.Channel) (*model.Channel, string, error) + + ChannelWillBeRestoredWithRPCErr(c *Context, channel *model.Channel) (string, error) + + ScheduledPostWillBeCreatedWithRPCErr(c *Context, scheduledPost *model.ScheduledPost) (*model.ScheduledPost, string, error) + + DraftWillBeUpsertedWithRPCErr(c *Context, draft *model.Draft) (*model.Draft, string, error) } type Z_RegisterCommandArgs struct { @@ -4240,6 +4464,62 @@ func (s *apiRPCServer) UpdateChannel(args *Z_UpdateChannelArgs, returns *Z_Updat return nil } +type Z_RegisterChannelGuardArgs struct { + A string +} + +type Z_RegisterChannelGuardReturns struct { + A *model.AppError +} + +func (g *apiRPCClient) RegisterChannelGuard(channelID string) *model.AppError { + _args := &Z_RegisterChannelGuardArgs{channelID} + _returns := &Z_RegisterChannelGuardReturns{} + if err := g.client.Call("Plugin.RegisterChannelGuard", _args, _returns); err != nil { + log.Printf("RPC call to RegisterChannelGuard API failed: %s", err.Error()) + } + return _returns.A +} + +func (s *apiRPCServer) RegisterChannelGuard(args *Z_RegisterChannelGuardArgs, returns *Z_RegisterChannelGuardReturns) error { + if hook, ok := s.impl.(interface { + RegisterChannelGuard(channelID string) *model.AppError + }); ok { + returns.A = hook.RegisterChannelGuard(args.A) + } else { + return encodableError(fmt.Errorf("API RegisterChannelGuard called but not implemented.")) + } + return nil +} + +type Z_UnregisterChannelGuardArgs struct { + A string +} + +type Z_UnregisterChannelGuardReturns struct { + A *model.AppError +} + +func (g *apiRPCClient) UnregisterChannelGuard(channelID string) *model.AppError { + _args := &Z_UnregisterChannelGuardArgs{channelID} + _returns := &Z_UnregisterChannelGuardReturns{} + if err := g.client.Call("Plugin.UnregisterChannelGuard", _args, _returns); err != nil { + log.Printf("RPC call to UnregisterChannelGuard API failed: %s", err.Error()) + } + return _returns.A +} + +func (s *apiRPCServer) UnregisterChannelGuard(args *Z_UnregisterChannelGuardArgs, returns *Z_UnregisterChannelGuardReturns) error { + if hook, ok := s.impl.(interface { + UnregisterChannelGuard(channelID string) *model.AppError + }); ok { + returns.A = hook.UnregisterChannelGuard(args.A) + } else { + return encodableError(fmt.Errorf("API UnregisterChannelGuard called but not implemented.")) + } + return nil +} + type Z_SearchChannelsArgs struct { A string B string diff --git a/server/public/plugin/environment.go b/server/public/plugin/environment.go index e37e50565f75..bdb135fefe69 100644 --- a/server/public/plugin/environment.go +++ b/server/public/plugin/environment.go @@ -8,6 +8,7 @@ import ( "hash/fnv" "os" "path/filepath" + "slices" "sync" "time" @@ -595,6 +596,19 @@ func (env *Environment) HooksForPlugin(id string) (Hooks, error) { return nil, fmt.Errorf("plugin not found: %v", id) } +// HooksForPluginWithRPCErr returns the full *WithRPCErr hook surface for the named plugin. +// Returns an error if the plugin is not found or not active. +func (env *Environment) HooksForPluginWithRPCErr(id string) (HooksWithRPCErr, error) { + if p, ok := env.registeredPlugins.Load(id); ok { + rp := p.(registeredPlugin) + if rp.supervisor != nil && env.IsActive(id) { + return rp.supervisor.HooksWithRPCErr(), nil + } + } + + return nil, fmt.Errorf("plugin not found: %v", id) +} + // RunMultiPluginHook invokes hookRunnerFunc for each active plugin that implements the given hookId. // // If hookRunnerFunc returns false, iteration will not continue. The iteration order among active @@ -626,9 +640,47 @@ func (env *Environment) RunMultiPluginHook(hookRunnerFunc func(hooks Hooks, mani } } +// RunMultiPluginHookExcluding is like RunMultiPluginHook but skips plugins whose IDs appear in +// excludePluginIDs, otherwise the semantics are the same as RunMultiPluginHook. The exclusion check +// is a linear scan. +func (env *Environment) RunMultiPluginHookExcluding( + excludePluginIDs []string, + hookRunnerFunc func(hooks Hooks, manifest *model.Manifest) bool, + hookId int, +) { + startTime := time.Now() + + env.registeredPlugins.Range(func(key, value any) bool { + rp := value.(registeredPlugin) + id := rp.BundleInfo.Manifest.Id + if slices.Contains(excludePluginIDs, id) { + return true + } + + if rp.supervisor == nil || !rp.supervisor.Implements(hookId) || !env.IsActive(id) { + return true + } + + hookStartTime := time.Now() + cont := hookRunnerFunc(rp.supervisor.Hooks(), rp.BundleInfo.Manifest) + + if env.metrics != nil { + elapsedTime := float64(time.Since(hookStartTime)) / float64(time.Second) + env.metrics.ObservePluginMultiHookIterationDuration(id, elapsedTime) + } + + return cont + }) + + if env.metrics != nil { + elapsedTime := float64(time.Since(startTime)) / float64(time.Second) + env.metrics.ObservePluginMultiHookDuration(elapsedTime) + } +} + // RunMultiPluginHookWithRPCErr is like RunMultiPluginHook but surfaces RPC transport errors. The -// closure receives a HooksWithRPCErr so it can call *WithRPCErr variants. Iteration stops on the first -// non-nil error returned by the closure. +// closure receives a HooksWithRPCErr so it can call any *WithRPCErr variant. Iteration stops on the +// first non-nil error returned by the closure. func (env *Environment) RunMultiPluginHookWithRPCErr(hookRunnerFunc func(hooks HooksWithRPCErr, manifest *model.Manifest) (bool, error), hookId int) error { startTime := time.Now() var retErr error diff --git a/server/public/plugin/environment_with_rpcerr_test.go b/server/public/plugin/environment_with_rpcerr_test.go index 05e200ec79ec..aa4a3856badc 100644 --- a/server/public/plugin/environment_with_rpcerr_test.go +++ b/server/public/plugin/environment_with_rpcerr_test.go @@ -18,14 +18,6 @@ import ( "github.com/mattermost/mattermost/server/public/shared/mlog" ) -// Both the wire-level client and the metrics-wrapping layer returned by -// supervisor.Hooks() must implement HooksWithRPCErr — RunMultiPluginHookWithRPCErr's -// type assertion targets the latter. -var ( - _ HooksWithRPCErr = (*hooksRPCClient)(nil) - _ HooksWithRPCErr = (*hooksTimerLayer)(nil) -) - func TestRunMultiPluginHookWithRPCErr(t *testing.T) { pluginDir, err := os.MkdirTemp("", "mm-rpcerr-plugin") require.NoError(t, err) diff --git a/server/public/plugin/hooks.go b/server/public/plugin/hooks.go index 58f77e5f559d..2ea23301b6e2 100644 --- a/server/public/plugin/hooks.go +++ b/server/public/plugin/hooks.go @@ -68,6 +68,10 @@ const ( ChannelMemberWillBeAddedID = 49 TeamMemberWillBeAddedID = 50 ChannelWillBeArchivedID = 51 + ChannelWillBeUpdatedID = 52 + ChannelWillBeRestoredID = 53 + ScheduledPostWillBeCreatedID = 54 + DraftWillBeUpsertedID = 55 TotalHooksID = iota ) @@ -465,4 +469,42 @@ type Hooks interface { // // Minimum server version: 10.7 OnSAMLLogin(c *Context, user *model.User, assertion *saml2.AssertionInfo) error + + // ChannelWillBeUpdated is invoked before a channel update is committed, allowing plugins to + // modify the channel or reject the update. + // + // To reject the update, return a non-empty string describing why. To modify the channel, return + // the replacement *model.Channel and an empty string. To allow the update without modification, + // return nil and an empty string. + // + // Fires from the app-layer UpdateChannel and PatchChannel paths so REST, local API, plugin API, + // import, and bulk callers all hit it. + // + // Minimum server version: 11.8 + ChannelWillBeUpdated(c *Context, newChannel, oldChannel *model.Channel) (*model.Channel, string) + + // ChannelWillBeRestored is invoked before an archived channel is un-archived. Fires from + // app.RestoreChannel before the store's Channel().Restore call. Sibling of + // ChannelWillBeArchived for the inverse operation. + // + // To reject, return a non-empty string. Empty string allows the restore. + // + // Minimum server version: 11.8 + ChannelWillBeRestored(c *Context, channel *model.Channel) string + + // ScheduledPostWillBeCreated is invoked before a scheduled post is committed. Fires from the + // app-layer SaveScheduledPost and UpdateScheduledPost paths. + // + // Return value semantics match MessageWillBePosted. + // + // Minimum server version: 11.8 + ScheduledPostWillBeCreated(c *Context, scheduledPost *model.ScheduledPost) (*model.ScheduledPost, string) + + // DraftWillBeUpserted is invoked before a draft is committed. Fires from the app-layer + // UpsertDraft path. + // + // Return value semantics match MessageWillBePosted. + // + // Minimum server version: 11.8 + DraftWillBeUpserted(c *Context, draft *model.Draft) (*model.Draft, string) } diff --git a/server/public/plugin/hooks_timer_layer_generated.go b/server/public/plugin/hooks_timer_layer_generated.go index dee068b975f1..a219828db2b4 100644 --- a/server/public/plugin/hooks_timer_layer_generated.go +++ b/server/public/plugin/hooks_timer_layer_generated.go @@ -336,6 +336,34 @@ func (hooks *hooksTimerLayer) OnSAMLLogin(c *Context, user *model.User, assertio return _returnsA } +func (hooks *hooksTimerLayer) ChannelWillBeUpdated(c *Context, newChannel, oldChannel *model.Channel) (*model.Channel, string) { + startTime := timePkg.Now() + _returnsA, _returnsB := hooks.hooksImpl.ChannelWillBeUpdated(c, newChannel, oldChannel) + hooks.recordTime(startTime, "ChannelWillBeUpdated", true) + return _returnsA, _returnsB +} + +func (hooks *hooksTimerLayer) ChannelWillBeRestored(c *Context, channel *model.Channel) string { + startTime := timePkg.Now() + _returnsA := hooks.hooksImpl.ChannelWillBeRestored(c, channel) + hooks.recordTime(startTime, "ChannelWillBeRestored", true) + return _returnsA +} + +func (hooks *hooksTimerLayer) ScheduledPostWillBeCreated(c *Context, scheduledPost *model.ScheduledPost) (*model.ScheduledPost, string) { + startTime := timePkg.Now() + _returnsA, _returnsB := hooks.hooksImpl.ScheduledPostWillBeCreated(c, scheduledPost) + hooks.recordTime(startTime, "ScheduledPostWillBeCreated", true) + return _returnsA, _returnsB +} + +func (hooks *hooksTimerLayer) DraftWillBeUpserted(c *Context, draft *model.Draft) (*model.Draft, string) { + startTime := timePkg.Now() + _returnsA, _returnsB := hooks.hooksImpl.DraftWillBeUpserted(c, draft) + hooks.recordTime(startTime, "DraftWillBeUpserted", true) + return _returnsA, _returnsB +} + func (hooks *hooksTimerLayer) OnDeactivateWithRPCErr() (error, error) { startTime := timePkg.Now() _returnsA, _returnsRPCErr := hooks.hooksWithRPCErrImpl.OnDeactivateWithRPCErr() @@ -594,3 +622,31 @@ func (hooks *hooksTimerLayer) OnSAMLLoginWithRPCErr(c *Context, user *model.User hooks.recordTime(startTime, "OnSAMLLoginWithRPCErr", _returnsRPCErr == nil && _returnsA == nil) return _returnsA, _returnsRPCErr } + +func (hooks *hooksTimerLayer) ChannelWillBeUpdatedWithRPCErr(c *Context, newChannel, oldChannel *model.Channel) (*model.Channel, string, error) { + startTime := timePkg.Now() + _returnsA, _returnsB, _returnsRPCErr := hooks.hooksWithRPCErrImpl.ChannelWillBeUpdatedWithRPCErr(c, newChannel, oldChannel) + hooks.recordTime(startTime, "ChannelWillBeUpdatedWithRPCErr", _returnsRPCErr == nil) + return _returnsA, _returnsB, _returnsRPCErr +} + +func (hooks *hooksTimerLayer) ChannelWillBeRestoredWithRPCErr(c *Context, channel *model.Channel) (string, error) { + startTime := timePkg.Now() + _returnsA, _returnsRPCErr := hooks.hooksWithRPCErrImpl.ChannelWillBeRestoredWithRPCErr(c, channel) + hooks.recordTime(startTime, "ChannelWillBeRestoredWithRPCErr", _returnsRPCErr == nil) + return _returnsA, _returnsRPCErr +} + +func (hooks *hooksTimerLayer) ScheduledPostWillBeCreatedWithRPCErr(c *Context, scheduledPost *model.ScheduledPost) (*model.ScheduledPost, string, error) { + startTime := timePkg.Now() + _returnsA, _returnsB, _returnsRPCErr := hooks.hooksWithRPCErrImpl.ScheduledPostWillBeCreatedWithRPCErr(c, scheduledPost) + hooks.recordTime(startTime, "ScheduledPostWillBeCreatedWithRPCErr", _returnsRPCErr == nil) + return _returnsA, _returnsB, _returnsRPCErr +} + +func (hooks *hooksTimerLayer) DraftWillBeUpsertedWithRPCErr(c *Context, draft *model.Draft) (*model.Draft, string, error) { + startTime := timePkg.Now() + _returnsA, _returnsB, _returnsRPCErr := hooks.hooksWithRPCErrImpl.DraftWillBeUpsertedWithRPCErr(c, draft) + hooks.recordTime(startTime, "DraftWillBeUpsertedWithRPCErr", _returnsRPCErr == nil) + return _returnsA, _returnsB, _returnsRPCErr +} diff --git a/server/public/plugin/hooks_timer_layer_manual.go b/server/public/plugin/hooks_timer_layer_manual.go new file mode 100644 index 000000000000..4bfd4a6efa4f --- /dev/null +++ b/server/public/plugin/hooks_timer_layer_manual.go @@ -0,0 +1,41 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +// Hand-written timer-layer wrappers for the three hooks excluded from the code generator. +// The auto-generated hooks_timer_layer_generated.go ranges over HooksMethodsRPCErr which +// omits excluded hooks; these three fill that gap so hooksTimerLayer satisfies HooksWithRPCErr. + +package plugin + +import ( + timePkg "time" + + "github.com/mattermost/mattermost/server/public/model" +) + +// MessageWillBePostedWithRPCErr wraps the underlying implementation's MessageWillBePostedWithRPCErr +// and records timing metrics. +func (hooks *hooksTimerLayer) MessageWillBePostedWithRPCErr(c *Context, post *model.Post) (*model.Post, string, error) { + startTime := timePkg.Now() + _returnsA, _returnsB, _returnsRPCErr := hooks.hooksWithRPCErrImpl.MessageWillBePostedWithRPCErr(c, post) + hooks.recordTime(startTime, "MessageWillBePostedWithRPCErr", _returnsRPCErr == nil) + return _returnsA, _returnsB, _returnsRPCErr +} + +// MessageWillBeUpdatedWithRPCErr wraps the underlying implementation's MessageWillBeUpdatedWithRPCErr +// and records timing metrics. +func (hooks *hooksTimerLayer) MessageWillBeUpdatedWithRPCErr(c *Context, newPost, oldPost *model.Post) (*model.Post, string, error) { + startTime := timePkg.Now() + _returnsA, _returnsB, _returnsRPCErr := hooks.hooksWithRPCErrImpl.MessageWillBeUpdatedWithRPCErr(c, newPost, oldPost) + hooks.recordTime(startTime, "MessageWillBeUpdatedWithRPCErr", _returnsRPCErr == nil) + return _returnsA, _returnsB, _returnsRPCErr +} + +// ChannelMemberWillBeAddedWithRPCErr wraps the underlying implementation's ChannelMemberWillBeAddedWithRPCErr +// and records timing metrics. +func (hooks *hooksTimerLayer) ChannelMemberWillBeAddedWithRPCErr(c *Context, channelMember *model.ChannelMember) (*model.ChannelMember, string, error) { + startTime := timePkg.Now() + _returnsA, _returnsB, _returnsRPCErr := hooks.hooksWithRPCErrImpl.ChannelMemberWillBeAddedWithRPCErr(c, channelMember) + hooks.recordTime(startTime, "ChannelMemberWillBeAddedWithRPCErr", _returnsRPCErr == nil) + return _returnsA, _returnsB, _returnsRPCErr +} diff --git a/server/public/plugin/interface_generator/main.go b/server/public/plugin/interface_generator/main.go index fe7773fdb777..2df8f73a5d5e 100644 --- a/server/public/plugin/interface_generator/main.go +++ b/server/public/plugin/interface_generator/main.go @@ -391,7 +391,7 @@ func (g *hooksRPCClient) {{.Name}}WithRPCErr{{funcStyle .Params}} {{funcStyleApp _err = g.client.Call("Plugin.{{.Name}}", _args, _returns) if _err != nil { // Reset _returns so partial gob decoding can't leak non-zero - // values past a transport failure (HooksWithRPCErr contract). + // values past a transport failure (HooksWithRPCErrGenerated contract). _returns = &{{.Name | obscure}}Returns{} g.log.Debug("RPC call {{.Name}} to plugin failed.", mlog.Err(_err)) } @@ -412,7 +412,7 @@ func (s *hooksRPCServer) {{.Name}}(args *{{.Name | obscure}}Args, returns *{{.Na } {{end}} -// HooksWithRPCErr provides a WithRPCErr variant for every generated hook. The last error return +// HooksWithRPCErrGenerated provides a WithRPCErr variant for every generated hook. The last error return // is always the RPC transport error — if non-nil, the plugin's other return values are zero. For // hooks whose base signature already returns error, the tuple is (originalReturns..., rpcErr) // where the final slot is always transport. @@ -421,8 +421,8 @@ func (s *hooksRPCServer) {{.Name}}(args *{{.Name | obscure}}Args, returns *{{.Na // indistinguishable from a successful invocation that returned zeros. Callers MUST gate on // supervisor.Implements() (or use Environment.RunMultiPluginHookWithRPCErr, which gates // by the iteration's hook ID — note that any *WithRPCErr method called on the closure's -// HooksWithRPCErr is independently subject to its own implemented-gate). -type HooksWithRPCErr interface { +// HooksWithRPCErrGenerated is independently subject to its own implemented-gate). +type HooksWithRPCErrGenerated interface { {{range .HooksMethods}} {{.Name}}WithRPCErr{{funcStyle .Params}} {{funcStyleAppendErr .Return}} {{end}} @@ -646,7 +646,7 @@ func generatePluginTimerLayer(info *PluginInterfaceInfo) { // Prepare template params. The timer layer wraps the full Hooks interface, so // HooksMethods includes excluded hooks too. *WithRPCErr companions only exist - // for non-excluded hooks (see HooksWithRPCErr in client_rpc_generated.go), so the + // for non-excluded hooks (see HooksWithRPCErrGenerated in client_rpc_generated.go), so the // excluded subset is filtered into HooksMethodsRPCErr for that loop. excluded := func(name string) bool { return slices.Contains(excludedPluginHooks, name) } templateParams := HooksTemplateParams{} diff --git a/server/public/plugin/plugintest/api.go b/server/public/plugin/plugintest/api.go index 338fef74f2d7..e37102f60e98 100644 --- a/server/public/plugin/plugintest/api.go +++ b/server/public/plugin/plugintest/api.go @@ -9,10 +9,8 @@ import ( http "net/http" logr "github.com/mattermost/logr/v2" - - mock "github.com/stretchr/testify/mock" - model "github.com/mattermost/mattermost/server/public/model" + mock "github.com/stretchr/testify/mock" ) // API is an autogenerated mock type for the API type @@ -4675,6 +4673,26 @@ func (_m *API) ReceiveSharedChannelSyncMsg(remoteID string, msg *model.SyncMsg) return r0, r1 } +// RegisterChannelGuard provides a mock function with given fields: channelID +func (_m *API) RegisterChannelGuard(channelID string) *model.AppError { + ret := _m.Called(channelID) + + if len(ret) == 0 { + panic("no return value specified for RegisterChannelGuard") + } + + var r0 *model.AppError + if rf, ok := ret.Get(0).(func(string) *model.AppError); ok { + r0 = rf(channelID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.AppError) + } + } + + return r0 +} + // RegisterCollectionAndTopic provides a mock function with given fields: collectionType, topicType func (_m *API) RegisterCollectionAndTopic(collectionType string, topicType string) error { ret := _m.Called(collectionType, topicType) @@ -5457,6 +5475,26 @@ func (_m *API) UninviteRemoteFromChannel(channelID string, remoteID string) erro return r0 } +// UnregisterChannelGuard provides a mock function with given fields: channelID +func (_m *API) UnregisterChannelGuard(channelID string) *model.AppError { + ret := _m.Called(channelID) + + if len(ret) == 0 { + panic("no return value specified for UnregisterChannelGuard") + } + + var r0 *model.AppError + if rf, ok := ret.Get(0).(func(string) *model.AppError); ok { + r0 = rf(channelID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.AppError) + } + } + + return r0 +} + // UnregisterCommand provides a mock function with given fields: teamID, trigger func (_m *API) UnregisterCommand(teamID string, trigger string) error { ret := _m.Called(teamID, trigger) diff --git a/server/public/plugin/plugintest/driver.go b/server/public/plugin/plugintest/driver.go index 4c158856c393..db9288b93120 100644 --- a/server/public/plugin/plugintest/driver.go +++ b/server/public/plugin/plugintest/driver.go @@ -7,9 +7,8 @@ package plugintest import ( driver "database/sql/driver" - mock "github.com/stretchr/testify/mock" - plugin "github.com/mattermost/mattermost/server/public/plugin" + mock "github.com/stretchr/testify/mock" ) // Driver is an autogenerated mock type for the Driver type diff --git a/server/public/plugin/plugintest/hooks.go b/server/public/plugin/plugintest/hooks.go index 90206542a388..c33ee478943a 100644 --- a/server/public/plugin/plugintest/hooks.go +++ b/server/public/plugin/plugintest/hooks.go @@ -8,13 +8,10 @@ import ( io "io" http "net/http" - mock "github.com/stretchr/testify/mock" - + saml2 "github.com/mattermost/gosaml2" model "github.com/mattermost/mattermost/server/public/model" - plugin "github.com/mattermost/mattermost/server/public/plugin" - - saml2 "github.com/mattermost/gosaml2" + mock "github.com/stretchr/testify/mock" ) // Hooks is an autogenerated mock type for the Hooks type @@ -75,6 +72,54 @@ func (_m *Hooks) ChannelWillBeArchived(c *plugin.Context, channel *model.Channel return r0 } +// ChannelWillBeRestored provides a mock function with given fields: c, channel +func (_m *Hooks) ChannelWillBeRestored(c *plugin.Context, channel *model.Channel) string { + ret := _m.Called(c, channel) + + if len(ret) == 0 { + panic("no return value specified for ChannelWillBeRestored") + } + + var r0 string + if rf, ok := ret.Get(0).(func(*plugin.Context, *model.Channel) string); ok { + r0 = rf(c, channel) + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// ChannelWillBeUpdated provides a mock function with given fields: c, newChannel, oldChannel +func (_m *Hooks) ChannelWillBeUpdated(c *plugin.Context, newChannel *model.Channel, oldChannel *model.Channel) (*model.Channel, string) { + ret := _m.Called(c, newChannel, oldChannel) + + if len(ret) == 0 { + panic("no return value specified for ChannelWillBeUpdated") + } + + var r0 *model.Channel + var r1 string + if rf, ok := ret.Get(0).(func(*plugin.Context, *model.Channel, *model.Channel) (*model.Channel, string)); ok { + return rf(c, newChannel, oldChannel) + } + if rf, ok := ret.Get(0).(func(*plugin.Context, *model.Channel, *model.Channel) *model.Channel); ok { + r0 = rf(c, newChannel, oldChannel) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.Channel) + } + } + + if rf, ok := ret.Get(1).(func(*plugin.Context, *model.Channel, *model.Channel) string); ok { + r1 = rf(c, newChannel, oldChannel) + } else { + r1 = ret.Get(1).(string) + } + + return r0, r1 +} + // ConfigurationWillBeSaved provides a mock function with given fields: newCfg func (_m *Hooks) ConfigurationWillBeSaved(newCfg *model.Config) (*model.Config, error) { ret := _m.Called(newCfg) @@ -105,6 +150,36 @@ func (_m *Hooks) ConfigurationWillBeSaved(newCfg *model.Config) (*model.Config, return r0, r1 } +// DraftWillBeUpserted provides a mock function with given fields: c, draft +func (_m *Hooks) DraftWillBeUpserted(c *plugin.Context, draft *model.Draft) (*model.Draft, string) { + ret := _m.Called(c, draft) + + if len(ret) == 0 { + panic("no return value specified for DraftWillBeUpserted") + } + + var r0 *model.Draft + var r1 string + if rf, ok := ret.Get(0).(func(*plugin.Context, *model.Draft) (*model.Draft, string)); ok { + return rf(c, draft) + } + if rf, ok := ret.Get(0).(func(*plugin.Context, *model.Draft) *model.Draft); ok { + r0 = rf(c, draft) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.Draft) + } + } + + if rf, ok := ret.Get(1).(func(*plugin.Context, *model.Draft) string); ok { + r1 = rf(c, draft) + } else { + r1 = ret.Get(1).(string) + } + + return r0, r1 +} + // EmailNotificationWillBeSent provides a mock function with given fields: emailNotification func (_m *Hooks) EmailNotificationWillBeSent(emailNotification *model.EmailNotification) (*model.EmailNotificationContent, string) { ret := _m.Called(emailNotification) @@ -640,6 +715,36 @@ func (_m *Hooks) RunDataRetention(nowTime int64, batchSize int64) (int64, error) return r0, r1 } +// ScheduledPostWillBeCreated provides a mock function with given fields: c, scheduledPost +func (_m *Hooks) ScheduledPostWillBeCreated(c *plugin.Context, scheduledPost *model.ScheduledPost) (*model.ScheduledPost, string) { + ret := _m.Called(c, scheduledPost) + + if len(ret) == 0 { + panic("no return value specified for ScheduledPostWillBeCreated") + } + + var r0 *model.ScheduledPost + var r1 string + if rf, ok := ret.Get(0).(func(*plugin.Context, *model.ScheduledPost) (*model.ScheduledPost, string)); ok { + return rf(c, scheduledPost) + } + if rf, ok := ret.Get(0).(func(*plugin.Context, *model.ScheduledPost) *model.ScheduledPost); ok { + r0 = rf(c, scheduledPost) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.ScheduledPost) + } + } + + if rf, ok := ret.Get(1).(func(*plugin.Context, *model.ScheduledPost) string); ok { + r1 = rf(c, scheduledPost) + } else { + r1 = ret.Get(1).(string) + } + + return r0, r1 +} + // ServeHTTP provides a mock function with given fields: c, w, r func (_m *Hooks) ServeHTTP(c *plugin.Context, w http.ResponseWriter, r *http.Request) { _m.Called(c, w, r) diff --git a/server/public/plugin/plugintest/hooks_with_rpcerr.go b/server/public/plugin/plugintest/hooks_with_rpcerr.go new file mode 100644 index 000000000000..f89efe7f8cdf --- /dev/null +++ b/server/public/plugin/plugintest/hooks_with_rpcerr.go @@ -0,0 +1,114 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +// Hand-written *WithRPCErr mock methods for the three hooks excluded from the code generator. +// The auto-generated hooks.go (regenerated by `make plugin-mocks`) covers the base Hooks interface; +// this file adds the extra *WithRPCErr companions so *Hooks satisfies plugin.HooksWithRPCErr. +// This file is not overwritten by `make plugin-mocks` because mockery writes only hooks.go for the +// Hooks interface (filename: "{{.InterfaceNameLower}}.go" in .mockery.yaml). + +package plugintest + +import ( + model "github.com/mattermost/mattermost/server/public/model" + plugin "github.com/mattermost/mattermost/server/public/plugin" +) + +// MessageWillBePostedWithRPCErr provides a mock function with given fields: c, post +func (_m *Hooks) MessageWillBePostedWithRPCErr(c *plugin.Context, post *model.Post) (*model.Post, string, error) { + ret := _m.Called(c, post) + + var r0 *model.Post + var r1 string + var r2 error + if rf, ok := ret.Get(0).(func(*plugin.Context, *model.Post) (*model.Post, string, error)); ok { + return rf(c, post) + } + if rf, ok := ret.Get(0).(func(*plugin.Context, *model.Post) *model.Post); ok { + r0 = rf(c, post) + } else if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.Post) + } + + if rf, ok := ret.Get(1).(func(*plugin.Context, *model.Post) string); ok { + r1 = rf(c, post) + } else { + r1 = ret.Get(1).(string) + } + + if rf, ok := ret.Get(2).(func(*plugin.Context, *model.Post) error); ok { + r2 = rf(c, post) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MessageWillBeUpdatedWithRPCErr provides a mock function with given fields: c, newPost, oldPost +func (_m *Hooks) MessageWillBeUpdatedWithRPCErr(c *plugin.Context, newPost *model.Post, oldPost *model.Post) (*model.Post, string, error) { + ret := _m.Called(c, newPost, oldPost) + + var r0 *model.Post + var r1 string + var r2 error + if rf, ok := ret.Get(0).(func(*plugin.Context, *model.Post, *model.Post) (*model.Post, string, error)); ok { + return rf(c, newPost, oldPost) + } + if rf, ok := ret.Get(0).(func(*plugin.Context, *model.Post, *model.Post) *model.Post); ok { + r0 = rf(c, newPost, oldPost) + } else if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.Post) + } + + if rf, ok := ret.Get(1).(func(*plugin.Context, *model.Post, *model.Post) string); ok { + r1 = rf(c, newPost, oldPost) + } else { + r1 = ret.Get(1).(string) + } + + if rf, ok := ret.Get(2).(func(*plugin.Context, *model.Post, *model.Post) error); ok { + r2 = rf(c, newPost, oldPost) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// ChannelMemberWillBeAddedWithRPCErr provides a mock function with given fields: c, channelMember +func (_m *Hooks) ChannelMemberWillBeAddedWithRPCErr(c *plugin.Context, channelMember *model.ChannelMember) (*model.ChannelMember, string, error) { + ret := _m.Called(c, channelMember) + + var r0 *model.ChannelMember + var r1 string + var r2 error + if rf, ok := ret.Get(0).(func(*plugin.Context, *model.ChannelMember) (*model.ChannelMember, string, error)); ok { + return rf(c, channelMember) + } + if rf, ok := ret.Get(0).(func(*plugin.Context, *model.ChannelMember) *model.ChannelMember); ok { + r0 = rf(c, channelMember) + } else if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.ChannelMember) + } + + if rf, ok := ret.Get(1).(func(*plugin.Context, *model.ChannelMember) string); ok { + r1 = rf(c, channelMember) + } else { + r1 = ret.Get(1).(string) + } + + if rf, ok := ret.Get(2).(func(*plugin.Context, *model.ChannelMember) error); ok { + r2 = rf(c, channelMember) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// Note: plugintest.Hooks is a mock for the base Hooks interface only. The auto-generated +// hooks.go does not include *WithRPCErr methods, so *Hooks cannot satisfy HooksWithRPCErr +// in full. The production compile-time assertions for HooksWithRPCErr live in client_rpc.go +// (for *hooksRPCClient and *hooksTimerLayer). Tests that need a HooksWithRPCErr double +// should embed *Hooks and add the needed *WithRPCErr stubs directly.