Skip to content

Docstore/memdocstore: nested query #3508

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docstore/memdocstore/codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func decodeDoc(m storedDoc, ddoc driver.Document, fps [][]string) error {
// (We don't need the key field because ddoc must already have it.)
m2 = map[string]interface{}{}
for _, fp := range fps {
val, err := getAtFieldPath(m, fp)
val, err := getAtFieldPath(m, fp, false)
if err != nil {
if gcerrors.Code(err) == gcerrors.NotFound {
continue
Expand Down
59 changes: 42 additions & 17 deletions docstore/memdocstore/mem.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@ type Options struct {
// When the collection is closed, its contents are saved to the file.
Filename string

// AllowNestedSliceQueries allows querying into nested slices.
// If true, a query on a field path that points into a slice matches a
// document when any element of the slice satisfies the operator.
// This makes memdocstore more compatible with MongoDB,
// but other providers may not support this feature.
AllowNestedSliceQueries bool

// Call this function when the collection is closed.
// For internal use only.
onClose func()
Expand Down Expand Up @@ -397,18 +404,44 @@ func (c *collection) checkRevision(arg driver.Document, current storedDoc) error
return nil
}

// getAtFieldPath gets the value of m at fp. It returns an error if fp is invalid
// getAtFieldPath gets the value of m at fp. It returns an error if fp is invalid.
// If nested is true, slices along the path are traversed and the values found in
// all of their elements are collected into one slice; see Options.AllowNestedSliceQueries
// (see getParentMap).
func getAtFieldPath(m map[string]interface{}, fp []string) (interface{}, error) {
m2, err := getParentMap(m, fp, false)
if err != nil {
return nil, err
func getAtFieldPath(m map[string]any, fp []string, nested bool) (result any, err error) {
var get func(m any, name string) any
get = func(m any, name string) any {
switch m := m.(type) {
case map[string]any:
return m[name]
case []any:
if !nested {
return nil
}
var result []any
for _, e := range m {
next := get(e, name)
// If we have slices within slices the compare function does not see the nested slices.
// Changing the compare function to be recursive would be more effort than flattening the slices here.
sliced, ok := next.([]any)
if ok {
result = append(result, sliced...)
} else {
result = append(result, next)
}
}
return result
}
return nil
}
v, ok := m2[fp[len(fp)-1]]
if ok {
return v, nil
result = m
for _, k := range fp {
next := get(result, k)
if next == nil {
return nil, gcerr.Newf(gcerr.NotFound, nil, "field %s not found", strings.Join(fp, "."))
}
result = next
}
return nil, gcerr.Newf(gcerr.NotFound, nil, "field %s not found", fp)
return result, nil
}

// setAtFieldPath sets m's value at fp to val. It creates intermediate maps as
Expand All @@ -422,14 +455,6 @@ func setAtFieldPath(m map[string]interface{}, fp []string, val interface{}) erro
return nil
}

// deleteAtFieldPath removes the entry named by the last component of fp from
// the map that getParentMap returns for fp. It does nothing when no such map
// is returned.
func deleteAtFieldPath(m map[string]interface{}, fp []string) {
	// The error from getParentMap is deliberately dropped: a path that
	// cannot be resolved simply leaves nothing to delete.
	parent, _ := getParentMap(m, fp, false)
	if parent == nil {
		return
	}
	last := fp[len(fp)-1]
	delete(parent, last)
}

// getParentMap returns the map that directly contains the given field path;
// that is, the value of m at the field path that excludes the last component
// of fp. If a non-map is encountered along the way, an InvalidArgument error is
Expand Down
138 changes: 138 additions & 0 deletions docstore/memdocstore/mem_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ package memdocstore

import (
"context"
"io"
"os"
"path/filepath"
"slices"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -131,6 +133,142 @@ func TestUpdateAtomic(t *testing.T) {
}
}

// TestQueryNested checks Where filters against a collection opened with
// AllowNestedSliceQueries enabled. Each case queries a dotted field path
// (through maps, slices of maps, maps of slices, and slices nested in slices)
// and compares the sorted keys of the returned documents with wantKeys.
func TestQueryNested(t *testing.T) {
	ctx := context.Background()

	dc, err := newCollection(drivertest.KeyField, nil, &Options{AllowNestedSliceQueries: true})
	if err != nil {
		t.Fatal(err)
	}
	coll := docstore.NewCollection(dc)
	defer coll.Close()

	// Set up test documents.
	testDocs := []docmap{{
		drivertest.KeyField: "TestQueryNested",
		"list":              []any{docmap{"a": "A"}},
		"map":               docmap{"b": "B"},
		"listOfMaps":        []any{docmap{"id": "1"}, docmap{"id": "2"}, docmap{"id": "3"}},
		"mapOfLists":        docmap{"ids": []any{"1", "2", "3"}},
		"deep":              []any{docmap{"nesting": []any{docmap{"of": docmap{"elements": "yes"}}}}},
		"listOfLists":       []any{docmap{"items": []any{docmap{"price": 10}, docmap{"price": 20}}}},
		dc.RevisionField():  nil,
	}, {
		drivertest.KeyField: "CheapItems",
		"items":             []any{docmap{"price": 10}, docmap{"price": 1}},
		dc.RevisionField():  nil,
	}, {
		drivertest.KeyField: "ExpensiveItems",
		"items":             []any{docmap{"price": 50}, docmap{"price": 100}},
		dc.RevisionField():  nil,
	}}

	for _, testDoc := range testDocs {
		if err := coll.Put(ctx, testDoc); err != nil {
			t.Fatal(err)
		}
	}

	tests := []struct {
		name string
		// where holds the three Where arguments: field path (string),
		// operator (string) and comparison value.
		where []any
		// wantKeys lists the expected document keys in sorted order;
		// nil means the query must match nothing.
		wantKeys []string
	}{
		{
			name:     "list field match",
			where:    []any{"list.a", "=", "A"},
			wantKeys: []string{"TestQueryNested"},
		}, {
			name:  "list field no match",
			where: []any{"list.a", "=", "missing"},
		}, {
			name:     "map field match",
			where:    []any{"map.b", "=", "B"},
			wantKeys: []string{"TestQueryNested"},
		}, {
			name:     "list of maps field match",
			where:    []any{"listOfMaps.id", "=", "2"},
			wantKeys: []string{"TestQueryNested"},
		}, {
			name:     "map of lists field match",
			where:    []any{"mapOfLists.ids", "=", "1"},
			wantKeys: []string{"TestQueryNested"},
		}, {
			name:     "deep nested field match",
			where:    []any{"deep.nesting.of.elements", "=", "yes"},
			wantKeys: []string{"TestQueryNested"},
		}, {
			name:     "list of lists exact price 10",
			where:    []any{"listOfLists.items.price", "=", 10},
			wantKeys: []string{"TestQueryNested"},
		}, {
			name:     "list of lists exact price 20",
			where:    []any{"listOfLists.items.price", "=", 20},
			wantKeys: []string{"TestQueryNested"},
		}, {
			name:     "list of lists price less than or equal to 20",
			where:    []any{"listOfLists.items.price", "<=", 20},
			wantKeys: []string{"TestQueryNested"},
		}, {
			name:     "items price equals 1",
			where:    []any{"items.price", "=", 1},
			wantKeys: []string{"CheapItems"},
		}, {
			name:  "items price equals 5 (no match)",
			where: []any{"items.price", "=", 5},
		}, {
			name:     "items price greater than or equal to 1",
			where:    []any{"items.price", ">=", 1},
			wantKeys: []string{"CheapItems", "ExpensiveItems"},
		}, {
			name:     "items price greater than or equal to 5",
			where:    []any{"items.price", ">=", 5},
			wantKeys: []string{"CheapItems", "ExpensiveItems"},
		}, {
			name:     "items price greater than or equal to 10",
			where:    []any{"items.price", ">=", 10},
			wantKeys: []string{"CheapItems", "ExpensiveItems"},
		}, {
			name:     "items price less than or equal to 50",
			where:    []any{"items.price", "<=", 50},
			wantKeys: []string{"CheapItems", "ExpensiveItems"},
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			fp := docstore.FieldPath(tc.where[0].(string))
			op := tc.where[1].(string)
			iter := coll.Query().Where(fp, op, tc.where[2]).Get(ctx)
			// Stop the iterator on every exit path, including t.Fatal below.
			defer iter.Stop()

			// Drain the iterator, keeping only the key of each document.
			var gotKeys []string
			for {
				doc := docmap{}
				err := iter.Next(ctx, doc)
				if err == io.EOF {
					break
				}
				if err != nil {
					t.Fatal(err)
				}
				if key, ok := doc[drivertest.KeyField].(string); ok {
					gotKeys = append(gotKeys, key)
				}
			}
			// Sort so the comparison does not depend on the order in
			// which the query returns documents.
			slices.Sort(gotKeys)

			if diff := cmp.Diff(gotKeys, tc.wantKeys); diff != "" {
				t.Errorf("query results mismatch (-got +want):\n%s", diff)
			}
		})
	}
}

func TestSortDocs(t *testing.T) {
newDocs := func() []storedDoc {
return []storedDoc{
Expand Down
52 changes: 37 additions & 15 deletions docstore/memdocstore/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (c *collection) RunGetQuery(_ context.Context, q *driver.Query) (driver.Doc

var resultDocs []storedDoc
for _, doc := range c.docs {
if filtersMatch(q.Filters, doc) {
if filtersMatch(q.Filters, doc, c.opts.AllowNestedSliceQueries) {
resultDocs = append(resultDocs, doc)
}
}
Expand Down Expand Up @@ -74,22 +74,22 @@ func (c *collection) RunGetQuery(_ context.Context, q *driver.Query) (driver.Doc
}, nil
}

func filtersMatch(fs []driver.Filter, doc storedDoc) bool {
func filtersMatch(fs []driver.Filter, doc storedDoc, nested bool) bool {
for _, f := range fs {
if !filterMatches(f, doc) {
if !filterMatches(f, doc, nested) {
return false
}
}
return true
}

func filterMatches(f driver.Filter, doc storedDoc) bool {
docval, err := getAtFieldPath(doc, f.FieldPath)
func filterMatches(f driver.Filter, doc storedDoc, nested bool) bool {
docval, err := getAtFieldPath(doc, f.FieldPath, nested)
// missing or bad field path => no match
if err != nil {
return false
}
c, ok := compare(docval, f.Value)
c, ok := compare(docval, f.Value, f.Op)
if !ok {
return false
}
Expand Down Expand Up @@ -120,24 +120,46 @@ func applyComparison(op string, c int) bool {
}
}

func compare(x1, x2 interface{}) (int, bool) {
func compare(x1, x2 any, op string) (int, bool) {
v1 := reflect.ValueOf(x1)
v2 := reflect.ValueOf(x2)
// this is for in/not-in queries.
// return 0 if x1 is in slice x2, -1 if not.
// For in/not-in queries. Otherwise this should only be reached with AllowNestedSliceQueries set.
// Return 0 if x1 is in slice x2, -1 if not.
if v2.Kind() == reflect.Slice {
for i := 0; i < v2.Len(); i++ {
if c, ok := compare(x1, v2.Index(i).Interface()); ok {
if !ok {
return 0, false
}
for i := range v2.Len() {
if c, ok := compare(x1, v2.Index(i).Interface(), op); ok {
if c == 0 {
return 0, true
}
if op != "in" && op != "not-in" {
return c, true
}
}
}
return -1, true
}
// See Options.AllowNestedSliceQueries.
// When x1 (the value taken from the document) is a slice, only one of its
// elements needs to match x2; the comparison value returned depends on the operator.
if v1.Kind() == reflect.Slice {
v2Greater := false
v2Less := false
for i := range v1.Len() {
if c, ok := compare(x2, v1.Index(i).Interface(), op); ok {
if c == 0 {
return 0, true
}
v2Greater = v2Greater || c > 0
v2Less = v2Less || c < 0
}
}
if op[0] == '>' && v2Less {
return 1, true
} else if op[0] == '<' && v2Greater {
return -1, true
}
return 0, false
}
if v1.Kind() == reflect.String && v2.Kind() == reflect.String {
return strings.Compare(v1.String(), v2.String()), true
}
Expand All @@ -160,7 +182,7 @@ func compare(x1, x2 interface{}) (int, bool) {

func sortDocs(docs []storedDoc, field string, asc bool) {
sort.Slice(docs, func(i, j int) bool {
c, ok := compare(docs[i][field], docs[j][field])
c, ok := compare(docs[i][field], docs[j][field], ">")
if !ok {
return false
}
Expand Down
6 changes: 4 additions & 2 deletions docstore/memdocstore/urls.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,9 @@ func (o *URLOpener) OpenCollectionURL(ctx context.Context, u *url.URL) (*docstor
}

options := &Options{
RevisionField: q.Get("revision_field"),
Filename: q.Get("filename"),
RevisionField: q.Get("revision_field"),
Filename: q.Get("filename"),
AllowNestedSliceQueries: q.Get("allow_nested_slice_queries") == "true",
onClose: func() {
o.mu.Lock()
delete(o.collections, collName)
Expand All @@ -75,6 +76,7 @@ func (o *URLOpener) OpenCollectionURL(ctx context.Context, u *url.URL) (*docstor
}
q.Del("revision_field")
q.Del("filename")
q.Del("allow_nested_slice_queries")
for param := range q {
return nil, fmt.Errorf("open collection %v: invalid query parameter %q", u, param)
}
Expand Down