diff --git a/pkg/mongoproxy/plugins/schema/schema.go b/pkg/mongoproxy/plugins/schema/schema.go index a200533..2793fa1 100644 --- a/pkg/mongoproxy/plugins/schema/schema.go +++ b/pkg/mongoproxy/plugins/schema/schema.go @@ -6,14 +6,16 @@ import ( "io/ioutil" "log" "path" + "strings" "sync/atomic" + "go.mongodb.org/mongo-driver/bson" + "gopkg.in/fsnotify.v1" + "github.com/cespare/xxhash/v2" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" "github.com/sirupsen/logrus" - "go.mongodb.org/mongo-driver/bson" - "gopkg.in/fsnotify.v1" "github.com/wish/mongoproxy/pkg/bsonutil" "github.com/wish/mongoproxy/pkg/command" @@ -114,12 +116,20 @@ func (p *SchemaPlugin) Configure(d bson.D) error { if err := p.LoadSchema(); err != nil { return err } + // skip watcher for unit test + if strings.HasPrefix(p.conf.SchemaPath, "example.json") { + return nil + } + // start watch watcher, err := fsnotify.NewWatcher() if err != nil { log.Fatal(err) } + defer watcher.Close() + done := make(chan bool) + go func() { for { select { @@ -145,7 +155,7 @@ func (p *SchemaPlugin) Configure(d bson.D) error { if err := watcher.Add(path.Dir(p.conf.SchemaPath)); err != nil { return err } - + <-done return nil } @@ -168,7 +178,7 @@ func (p *SchemaPlugin) Process(ctx context.Context, r *plugins.Request, next plu case *command.FindAndModify: if len(cmd.Update) > 0 { schema := p.GetSchema() - logrus.Infof("command findAndModify: %v", cmd.Update) + logrus.Debugf("command findAndModify: %v", cmd.Update) if err := schema.ValidateUpdate(ctx, cmd.Database, cmd.Collection, cmd.Update, bsonutil.GetBoolDefault(cmd.Upsert, false)); err != nil { schemaDeny.WithLabelValues(cmd.Database, cmd.Collection, r.CommandName).Inc() if !p.conf.EnforceSchemaLogOnly { @@ -182,7 +192,7 @@ func (p *SchemaPlugin) Process(ctx context.Context, r *plugins.Request, next plu case *command.Update: schema := p.GetSchema() for _, updateDoc := range cmd.Updates { - logrus.Infof("print command Update: %v", updateDoc) + logrus.Debugf("print command Update: %v", updateDoc) if err := schema.ValidateUpdate(ctx, cmd.Database, cmd.Collection, updateDoc.U, bsonutil.GetBoolDefault(updateDoc.Upsert, false)); err != nil { schemaDeny.WithLabelValues(cmd.Database, cmd.Collection, r.CommandName).Inc() if !p.conf.EnforceSchemaLogOnly { diff --git a/pkg/mongoproxy/plugins/schema/type_test.go b/pkg/mongoproxy/plugins/schema/type_test.go index 5e3ee8f..78aaf52 100644 --- a/pkg/mongoproxy/plugins/schema/type_test.go +++ b/pkg/mongoproxy/plugins/schema/type_test.go @@ -199,6 +199,10 @@ var ( // push extra field {DB: "testdb", Collection: "requireonlysuba", In: bson.D{{"$push", bson.D{{"doc.a", "name"}, {"doc.b", 1}}}}, Err: true}, {DB: "testdb", Collection: "requireonlysuba", In: bson.D{{"$push", bson.D{{"a", "name"}, {"doc.b", 1}}}}, Err: true}, + //test with each + {DB: "testdb", Collection: "nonrequire", In: bson.D{{"$push", bson.D{{"luckynumbers", bson.D{{"$each", bson.A{1, 2, 3}}}}}}}}, + //test with each + {DB: "testdb", Collection: "nonrequire", In: bson.D{{"$push", bson.D{{"luckynumbers", bson.E{"$each", bson.A{1, 2, 3}}}}}}}, // // pull tests @@ -337,6 +341,10 @@ var ( // addToSet extra field {DB: "testdb", Collection: "requireonlysuba", In: bson.D{{"$addToSet", bson.D{{"doc.a", "name"}, {"doc.b", 1}}}}, Err: true}, {DB: "testdb", Collection: "requireonlysuba", In: bson.D{{"$addToSet", bson.D{{"a", "name"}, {"doc.b", 1}}}}, Err: true}, + //test with each + {DB: "testdb", Collection: "nonrequire", In: bson.D{{"$addToSet", bson.D{{"luckynumbers", bson.D{{"$each", bson.A{1, 2, 3}}}}}}}}, + //test with each + {DB: "testdb", Collection: "nonrequire", In: bson.D{{"$addToSet", bson.D{{"luckynumbers", bson.E{"$each", bson.A{1, 2, 3}}}}}}}, // // rename tests diff --git a/pkg/mongoproxy/plugins/schema/types.go b/pkg/mongoproxy/plugins/schema/types.go index dfff3d6..c707fcf 100644 --- a/pkg/mongoproxy/plugins/schema/types.go +++ b/pkg/mongoproxy/plugins/schema/types.go @@ -267,7 +267,7 @@ func (c *Collection) ValidateUpdate(ctx context.Context, obj bson.D, upsert bool } case "$rename": renameFields = e.Value.(bson.D).Map() - case "$set", "$pull", "$push", "$addToSet", "$pullAll": + case "$set", "$pull", "$pullAll": if setFields == nil { setFields = Mapify(e.Value.(bson.D)) } else { @@ -276,6 +276,12 @@ func (c *Collection) ValidateUpdate(ctx context.Context, obj bson.D, upsert bool setFields[item.Key] = item.Value } } + case "$addToSet", "$push": + if setFields == nil { + setFields = make(bson.M, len(e.Value.(bson.D))) + } + setFields = MapifyWithOp(e.Value.(bson.D), setFields) + case "$setOnInsert": insertFields = Mapify(e.Value.(bson.D)) case "$unset": diff --git a/pkg/mongoproxy/plugins/schema/util.go b/pkg/mongoproxy/plugins/schema/util.go index f226a68..de78a01 100644 --- a/pkg/mongoproxy/plugins/schema/util.go +++ b/pkg/mongoproxy/plugins/schema/util.go @@ -6,6 +6,9 @@ import ( "reflect" "regexp" + "go.mongodb.org/mongo-driver/bson/primitive" + + "github.com/sirupsen/logrus" "go.mongodb.org/mongo-driver/bson" ) @@ -141,6 +144,27 @@ func Mapify(d bson.D) bson.M { return m } +// Map creates a map from the elements of the D with operator +// It makes additional process for arrays +func MapifyWithOp(d bson.D, m bson.M) bson.M { + for _, e := range d { + e := processArray(e) + if _, ok := e.Value.(primitive.D); ok { + itemValueSet := e.Value.(bson.D).Map() + if val, ok := itemValueSet["$each"]; ok { + m[e.Key] = val + continue + } + } else if _, ok := e.Value.(primitive.E); ok && e.Value.(bson.E).Key == "$each" { + m[e.Key] = e.Value.(bson.E).Value + continue + } + m[e.Key] = e.Value + logrus.Debugf("Add %s type element to set", fmt.Sprint(reflect.TypeOf(e.Value))) + } + return m +} + // looping and process elements in object func handleObj(obj bson.D, m bson.M) bson.M { for _, e := range obj {