Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions server/channels/app/plugin_hooks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1929,6 +1929,93 @@ func TestHookMessagesWillBeConsumed(t *testing.T) {
})
}

func TestHookMessagesWillBeConsumedWithContext(t *testing.T) {
mainHelper.Parallel(t)

setupPlugin := func(t *testing.T, th *TestHelper) {
var mockAPI plugintest.API
mockAPI.On("LoadPluginConfiguration", mock.Anything).Return(nil)
mockAPI.On("LogDebug", "message").Return(nil)

// The plugin records whether it received a non-nil context to confirm the context is
// threaded all the way through to the hook.
tearDown, _, _ := SetAppEnvironmentWithPlugins(t, []string{`
package main

import (
"github.com/mattermost/mattermost/server/public/plugin"
"github.com/mattermost/mattermost/server/public/model"
)

type MyPlugin struct {
plugin.MattermostPlugin
}

func (p *MyPlugin) MessagesWillBeConsumedWithContext(c *plugin.Context, posts []*model.Post) []*model.Post {
prefix := "mwbcwc_plugin:"
if c == nil {
prefix = "nilctx:"
}
for _, post := range posts {
post.Message = prefix + post.Message
}
return posts
}

func main() {
plugin.ClientMain(&MyPlugin{})
}
`}, th.App, func(*model.Manifest) plugin.API { return &mockAPI })
t.Cleanup(tearDown)
}

t.Run("feature flag disabled", func(t *testing.T) {
mainHelper.Parallel(t)

th := SetupConfig(t, func(cfg *model.Config) {
cfg.FeatureFlags.ConsumePostHook = false
}).InitBasic(t)

setupPlugin(t, th)

newPost := &model.Post{
UserId: th.BasicUser.Id,
ChannelId: th.BasicChannel.Id,
Message: "message",
CreateAt: model.GetMillis() - 10000,
}
_, _, err := th.App.CreatePost(th.Context, newPost, th.BasicChannel, model.CreatePostFlags{SetOnline: true})
require.Nil(t, err)

post, err := th.App.GetSinglePost(th.Context, newPost.Id, true)
require.Nil(t, err)
assert.Equal(t, "message", post.Message)
})

t.Run("feature flag enabled", func(t *testing.T) {
mainHelper.Parallel(t)

th := SetupConfig(t, func(cfg *model.Config) {
cfg.FeatureFlags.ConsumePostHook = true
}).InitBasic(t)

setupPlugin(t, th)

newPost := &model.Post{
UserId: th.BasicUser.Id,
ChannelId: th.BasicChannel.Id,
Message: "message",
CreateAt: model.GetMillis() - 10000,
}
_, _, err := th.App.CreatePost(th.Context, newPost, th.BasicChannel, model.CreatePostFlags{SetOnline: true})
require.Nil(t, err)

post, err := th.App.GetSinglePost(th.Context, newPost.Id, true)
require.Nil(t, err)
assert.Equal(t, "mwbcwc_plugin:message", post.Message)
})
}

func TestUpdatePostFiresConsumeHook(t *testing.T) {
mainHelper.Parallel(t)

Expand Down
61 changes: 39 additions & 22 deletions server/channels/app/post.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,15 @@ import (
"encoding/json"
"errors"
"fmt"
"maps"
"net/http"
"regexp"
"slices"
"strconv"
"strings"
"sync"
"time"

"maps"
"slices"

"github.com/mattermost/mattermost/server/public/model"
"github.com/mattermost/mattermost/server/public/plugin"
"github.com/mattermost/mattermost/server/public/shared/i18n"
Expand Down Expand Up @@ -437,7 +436,7 @@ func (a *App) CreatePost(rctx request.CTX, post *model.Post, channel *model.Chan
}
}

a.applyPostWillBeConsumedHook(&rpost)
a.applyPostWillBeConsumedHook(rctx, &rpost)

if rpost.RootId != "" {
if appErr := a.ResolvePersistentNotification(rctx, parentPostList.Posts[post.RootId], rpost.UserId); appErr != nil {
Expand Down Expand Up @@ -982,7 +981,7 @@ func (a *App) UpdatePost(rctx request.CTX, receivedUpdatedPost *model.Post, upda
}
}

a.applyPostWillBeConsumedHook(&rpost)
a.applyPostWillBeConsumedHook(rctx, &rpost)

message := model.NewWebSocketEvent(model.WebsocketEventPostEdited, "", rpost.ChannelId, "", nil, "")

Expand Down Expand Up @@ -1271,7 +1270,7 @@ func (a *App) GetPostsPage(rctx request.CTX, options model.GetPostsOptions) (*mo
return nil, appErr
}

a.applyPostsWillBeConsumedHook(postList.Posts)
a.applyPostsWillBeConsumedHook(rctx, postList.Posts)

return postList, nil
}
Expand Down Expand Up @@ -1300,7 +1299,7 @@ func (a *App) GetPostsForView(rctx request.CTX, options model.GetPostsOptions) (
return nil, appErr
}

a.applyPostsWillBeConsumedHook(postList.Posts)
a.applyPostsWillBeConsumedHook(rctx, postList.Posts)

return postList, nil
}
Expand All @@ -1327,7 +1326,7 @@ func (a *App) GetPosts(rctx request.CTX, channelID string, offset int, limit int
return nil, appErr
}

a.applyPostsWillBeConsumedHook(postList.Posts)
a.applyPostsWillBeConsumedHook(rctx, postList.Posts)

return postList, nil
}
Expand Down Expand Up @@ -1364,7 +1363,7 @@ func (a *App) GetPostsSince(rctx request.CTX, options model.GetPostsSinceOptions
return nil, appErr
}

a.applyPostsWillBeConsumedHook(postList.Posts)
a.applyPostsWillBeConsumedHook(rctx, postList.Posts)

return postList, nil
}
Expand Down Expand Up @@ -1461,7 +1460,7 @@ func (a *App) GetSinglePost(rctx request.CTX, postID string, includeDeleted bool
return nil, model.NewAppError("GetSinglePost", "app.post.cloud.get.app_error", nil, "", http.StatusForbidden)
}

a.applyPostWillBeConsumedHook(&post)
a.applyPostWillBeConsumedHook(rctx, &post)

return post, nil
}
Expand Down Expand Up @@ -1499,7 +1498,7 @@ func (a *App) GetPostThread(rctx request.CTX, postID string, opts model.GetPosts
return nil, appErr
}

a.applyPostsWillBeConsumedHook(posts.Posts)
a.applyPostsWillBeConsumedHook(rctx, posts.Posts)

return posts, nil
}
Expand All @@ -1521,7 +1520,7 @@ func (a *App) GetFlaggedPosts(rctx request.CTX, userID string, offset int, limit
return nil, appErr
}

a.applyPostsWillBeConsumedHook(postList.Posts)
a.applyPostsWillBeConsumedHook(rctx, postList.Posts)

return postList, nil
}
Expand All @@ -1543,7 +1542,7 @@ func (a *App) GetFlaggedPostsForTeam(rctx request.CTX, userID, teamID string, of
return nil, appErr
}

a.applyPostsWillBeConsumedHook(postList.Posts)
a.applyPostsWillBeConsumedHook(rctx, postList.Posts)

return postList, nil
}
Expand All @@ -1565,7 +1564,7 @@ func (a *App) GetFlaggedPostsForChannel(rctx request.CTX, userID, channelID stri
return nil, appErr
}

a.applyPostsWillBeConsumedHook(postList.Posts)
a.applyPostsWillBeConsumedHook(rctx, postList.Posts)

return postList, nil
}
Expand Down Expand Up @@ -1609,7 +1608,7 @@ func (a *App) GetPermalinkPost(rctx request.CTX, postID string, userID string) (
return nil, appErr
}

a.applyPostsWillBeConsumedHook(list.Posts)
a.applyPostsWillBeConsumedHook(rctx, list.Posts)

return list, nil
}
Expand Down Expand Up @@ -1645,7 +1644,7 @@ func (a *App) GetPostsBeforePost(rctx request.CTX, options model.GetPostsOptions
return nil, appErr
}

a.applyPostsWillBeConsumedHook(postList.Posts)
a.applyPostsWillBeConsumedHook(rctx, postList.Posts)

return postList, nil
}
Expand Down Expand Up @@ -1681,7 +1680,7 @@ func (a *App) GetPostsAfterPost(rctx request.CTX, options model.GetPostsOptions)
return nil, appErr
}

a.applyPostsWillBeConsumedHook(postList.Posts)
a.applyPostsWillBeConsumedHook(rctx, postList.Posts)

return postList, nil
}
Expand Down Expand Up @@ -1725,18 +1724,18 @@ func (a *App) GetPostsAroundPost(rctx request.CTX, before bool, options model.Ge
return nil, appErr
}

a.applyPostsWillBeConsumedHook(postList.Posts)
a.applyPostsWillBeConsumedHook(rctx, postList.Posts)

return postList, nil
}

func (a *App) GetPostAfterTime(channelID string, time int64, collapsedThreads bool) (*model.Post, *model.AppError) {
func (a *App) GetPostAfterTime(rctx request.CTX, channelID string, time int64, collapsedThreads bool) (*model.Post, *model.AppError) {
post, err := a.Srv().Store().Post().GetPostAfterTime(channelID, time, collapsedThreads)
if err != nil {
return nil, model.NewAppError("GetPostAfterTime", "app.post.get_post_after_time.app_error", nil, "", http.StatusInternalServerError).Wrap(err)
}

a.applyPostWillBeConsumedHook(&post)
a.applyPostWillBeConsumedHook(rctx, &post)

return post, nil
}
Expand Down Expand Up @@ -2904,7 +2903,7 @@ func (a *App) GetPostInfo(rctx request.CTX, postID string, channel *model.Channe
return &info, nil
}

func (a *App) applyPostsWillBeConsumedHook(posts map[string]*model.Post) {
func (a *App) applyPostsWillBeConsumedHook(rctx request.CTX, posts map[string]*model.Post) {
if !a.Config().FeatureFlags.ConsumePostHook {
return
}
Expand All @@ -2921,9 +2920,18 @@ func (a *App) applyPostsWillBeConsumedHook(posts map[string]*model.Post) {
}
return true
}, plugin.MessagesWillBeConsumedID)

pluginContext := pluginContext(rctx)
a.ch.RunMultiHook(func(hooks plugin.Hooks, _ *model.Manifest) bool {
postReplacements := hooks.MessagesWillBeConsumedWithContext(pluginContext, postsSlice)
for _, postReplacement := range postReplacements {
posts[postReplacement.Id] = postReplacement
}
return true
}, plugin.MessagesWillBeConsumedWithContextID)
}

func (a *App) applyPostWillBeConsumedHook(post **model.Post) {
func (a *App) applyPostWillBeConsumedHook(rctx request.CTX, post **model.Post) {
if !a.Config().FeatureFlags.ConsumePostHook || (*post).Type == model.PostTypeBurnOnRead {
return
}
Expand All @@ -2936,6 +2944,15 @@ func (a *App) applyPostWillBeConsumedHook(post **model.Post) {
}
return true
}, plugin.MessagesWillBeConsumedID)

pluginContext := pluginContext(rctx)
a.ch.RunMultiHook(func(hooks plugin.Hooks, _ *model.Manifest) bool {
rp := hooks.MessagesWillBeConsumedWithContext(pluginContext, ps)
if len(rp) > 0 {
(*post) = rp[0]
}
return true
}, plugin.MessagesWillBeConsumedWithContextID)
}

func makePostLink(siteURL, teamName, postID string) string {
Expand Down
38 changes: 38 additions & 0 deletions server/public/plugin/client_rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -993,6 +993,44 @@ func (s *hooksRPCServer) MessagesWillBeConsumed(args *Z_MessagesWillBeConsumedAr
return nil
}

// MessagesWillBeConsumedWithContext is in this file because of the difficulty of identifying which fields
// need special behaviour. The special behaviour needed is decoding the returned post into the original one
// to avoid the unintentional removal of fields by older plugins.
func init() {
hookNameToId["MessagesWillBeConsumedWithContext"] = MessagesWillBeConsumedWithContextID
}

type Z_MessagesWillBeConsumedWithContextArgs struct {
A *Context
B []*model.Post
}

type Z_MessagesWillBeConsumedWithContextReturns struct {
A []*model.Post
}

func (g *hooksRPCClient) MessagesWillBeConsumedWithContext(c *Context, posts []*model.Post) []*model.Post {
_args := &Z_MessagesWillBeConsumedWithContextArgs{c, posts}
_returns := &Z_MessagesWillBeConsumedWithContextReturns{}
if g.implemented[MessagesWillBeConsumedWithContextID] {
if err := g.client.Call("Plugin.MessagesWillBeConsumedWithContext", _args, _returns); err != nil {
g.log.Error("RPC call MessagesWillBeConsumedWithContext to plugin failed.", mlog.Err(err))
}
}
return _returns.A
}

func (s *hooksRPCServer) MessagesWillBeConsumedWithContext(args *Z_MessagesWillBeConsumedWithContextArgs, returns *Z_MessagesWillBeConsumedWithContextReturns) error {
if hook, ok := s.impl.(interface {
MessagesWillBeConsumedWithContext(c *Context, posts []*model.Post) []*model.Post
}); ok {
returns.A = hook.MessagesWillBeConsumedWithContext(args.A, args.B)
} else {
return encodableError(fmt.Errorf("hook MessagesWillBeConsumedWithContext called but not implemented"))
}
return nil
}

type Z_LogDebugArgs struct {
A string
B []any
Expand Down
13 changes: 13 additions & 0 deletions server/public/plugin/hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ const (
ChannelWillBeRestoredID = 53
ScheduledPostWillBeCreatedID = 54
DraftWillBeUpsertedID = 55
MessagesWillBeConsumedWithContextID = 56
TotalHooksID = iota
)

Expand Down Expand Up @@ -199,6 +200,18 @@ type Hooks interface {
// Minimum server version: 9.3
MessagesWillBeConsumed(posts []*model.Post) []*model.Post

// MessagesWillBeConsumedWithContext is invoked when messages are requested by a client, before
// they are returned to the client. It is the context-aware variant of MessagesWillBeConsumed.
//
// To modify a post, return the replacement post; the returned posts are matched to the originals
// by ID. Posts that should be left unchanged may be omitted from the returned slice.
//
// Note that this method will be called for posts created by plugins, including the plugin that
// created the post.
//
// Minimum server version: 11.9
MessagesWillBeConsumedWithContext(c *Context, posts []*model.Post) []*model.Post

// MessageHasBeenDeleted is invoked after the message has been deleted from the database.
// Note that this method will be called for posts deleted by plugins, including the plugin that
// deleted the post.
Expand Down
7 changes: 7 additions & 0 deletions server/public/plugin/hooks_timer_layer_generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions server/public/plugin/interface_generator/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ var excludedPluginHooks = []string{
"MessageWillBePosted",
"MessageWillBeUpdated",
"MessagesWillBeConsumed",
"MessagesWillBeConsumedWithContext",
"OnActivate",
"PluginHTTP",
"ServeHTTP",
Expand Down
Loading
Loading