Skip to content

GODRIVER-3473 Short-cicruit cursor.next() on invalid timeouts #2135

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions internal/integration/collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2028,16 +2028,16 @@ func TestCollection(t *testing.T) {
})
}

func initCollection(mt *mtest.T, coll *mongo.Collection) {
mt.Helper()
func initCollection(tb testing.TB, coll *mongo.Collection) {
tb.Helper()

var docs []interface{}
for i := 1; i <= 5; i++ {
docs = append(docs, bson.D{{"x", int32(i)}})
}

_, err := coll.InsertMany(context.Background(), docs)
assert.Nil(mt, err, "InsertMany error for initial data: %v", err)
assert.NoError(tb, err, "InsertMany error for initial data: %v", err)
}

func testAggregateWithOptions(mt *mtest.T, createIndex bool, opts options.Lister[options.AggregateOptions]) {
Expand Down
262 changes: 217 additions & 45 deletions internal/integration/cursor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"time"

"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/event"
"go.mongodb.org/mongo-driver/v2/internal/assert"
"go.mongodb.org/mongo-driver/v2/internal/failpoint"
"go.mongodb.org/mongo-driver/v2/internal/integration/mtest"
Expand Down Expand Up @@ -304,77 +305,248 @@ func TestCursor(t *testing.T) {
batchSize = sizeVal.Int32()
assert.Equal(mt, int32(4), batchSize, "expected batchSize 4, got %v", batchSize)
})
}

tailableAwaitDataCursorOpts := mtest.NewOptions().MinServerVersion("4.4").
Topologies(mtest.ReplicaSet, mtest.Sharded, mtest.LoadBalanced, mtest.Single)
func parseMaxAwaitTime(mt *mtest.T, evt *event.CommandStartedEvent) int64 {
mt.Helper()

mt.RunOpts("tailable awaitData cursor", tailableAwaitDataCursorOpts, func(mt *mtest.T) {
mt.Run("apply remaining timeoutMS if less than maxAwaitTimeMS", func(mt *mtest.T) {
initCollection(mt, mt.Coll)
mt.ClearEvents()
maxTimeMSRaw, err := evt.Command.LookupErr("maxTimeMS")
require.NoError(mt, err)

// Create a find cursor
opts := options.Find().SetBatchSize(1).SetMaxAwaitTime(100 * time.Millisecond)
got, ok := maxTimeMSRaw.AsInt64OK()
require.True(mt, ok)

cursor, err := mt.Coll.Find(context.Background(), bson.D{}, opts)
require.NoError(mt, err)
return got
}

_ = mt.GetStartedEvent() // Empty find from started list.
func TestCursor_tailableAwaitData(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().CreateClient(false))

defer cursor.Close(context.Background())
cappedOpts := options.CreateCollection().SetCapped(true).
SetSizeInBytes(1024 * 64)

ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
// TODO(SERVER-96344): mongos doesn't honor a failpoint's full blockTimeMS.
mtOpts := mtest.NewOptions().MinServerVersion("4.4").
Topologies(mtest.ReplicaSet, mtest.LoadBalanced, mtest.Single).
CollectionCreateOptions(cappedOpts)

// Iterate twice to force a getMore
cursor.Next(ctx)
cursor.Next(ctx)
mt.RunOpts("apply remaining timeoutMS if less than maxAwaitTimeMS", mtOpts, func(mt *mtest.T) {
initCollection(mt, mt.Coll)

cmd := mt.GetStartedEvent().Command
// Create a 30ms failpoint for getMore.
mt.SetFailPoint(failpoint.FailPoint{
ConfigureFailPoint: "failCommand",
Mode: failpoint.Mode{
Times: 1,
},
Data: failpoint.Data{
FailCommands: []string{"getMore"},
BlockConnection: true,
BlockTimeMS: 30,
},
})

maxTimeMSRaw, err := cmd.LookupErr("maxTimeMS")
require.NoError(mt, err)
// Create a find cursor with a 100ms maxAwaitTimeMS and a tailable awaitData
// cursor type.
opts := options.Find().
SetBatchSize(1).
SetMaxAwaitTime(100 * time.Millisecond).
SetCursorType(options.TailableAwait)

got, ok := maxTimeMSRaw.AsInt64OK()
require.True(mt, ok)
cursor, err := mt.Coll.Find(context.Background(), bson.D{{"x", 2}}, opts)
require.NoError(mt, err)

assert.LessOrEqual(mt, got, int64(50))
})
defer cursor.Close(context.Background())

mt.RunOpts("apply maxAwaitTimeMS if less than remaining timeout", tailableAwaitDataCursorOpts, func(mt *mtest.T) {
initCollection(mt, mt.Coll)
mt.ClearEvents()
// Use a 200ms timeout that caps the lifetime of cursor.Next. The underlying
// getMore loop should run at least two times: the first getMore will block
// for 30ms on the getMore and then an additional 100ms for the
// maxAwaitTimeMS. The second getMore will then use the remaining ~70ms
// left on the timeout.
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()

// Create a find cursor
opts := options.Find().SetBatchSize(1).SetMaxAwaitTime(50 * time.Millisecond)
// Iterate twice to force a getMore
cursor.Next(ctx)

cursor, err := mt.Coll.Find(context.Background(), bson.D{}, opts)
require.NoError(mt, err)
mt.ClearEvents()
cursor.Next(ctx)

_ = mt.GetStartedEvent() // Empty find from started list.
require.Error(mt, cursor.Err(), "expected error from cursor.Next")
assert.ErrorIs(mt, cursor.Err(), context.DeadlineExceeded, "expected context deadline exceeded error")

defer cursor.Close(context.Background())
// Collect all started events to find the getMore commands.
startedEvents := mt.GetAllStartedEvents()

ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
var getMoreStartedEvents []*event.CommandStartedEvent
for _, evt := range startedEvents {
if evt.CommandName == "getMore" {
getMoreStartedEvents = append(getMoreStartedEvents, evt)
}
}

// Iterate twice to force a getMore
cursor.Next(ctx)
cursor.Next(ctx)
// The first getMore should have a maxTimeMS of <= 100ms.
assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreStartedEvents[0]), int64(100))

cmd := mt.GetStartedEvent().Command
// The second getMore should have a maxTimeMS of <=71, indicating that we
// are using the time remaining in the context rather than the
// maxAwaitTimeMS.
assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreStartedEvents[1]), int64(71))
})

maxTimeMSRaw, err := cmd.LookupErr("maxTimeMS")
require.NoError(mt, err)
mtOpts.Topologies(mtest.ReplicaSet, mtest.Sharded, mtest.LoadBalanced, mtest.Single)

got, ok := maxTimeMSRaw.AsInt64OK()
require.True(mt, ok)
mt.RunOpts("apply maxAwaitTimeMS if less than remaining timeout", mtOpts, func(mt *mtest.T) {
initCollection(mt, mt.Coll)
mt.ClearEvents()

assert.LessOrEqual(mt, got, int64(50))
})
// Create a find cursor
opts := options.Find().SetBatchSize(1).SetMaxAwaitTime(50 * time.Millisecond)

cursor, err := mt.Coll.Find(context.Background(), bson.D{}, opts)
require.NoError(mt, err)

_ = mt.GetStartedEvent() // Empty find from started list.

defer cursor.Close(context.Background())

ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()

// Iterate twice to force a getMore
cursor.Next(ctx)
cursor.Next(ctx)

cmd := mt.GetStartedEvent().Command

maxTimeMSRaw, err := cmd.LookupErr("maxTimeMS")
require.NoError(mt, err)

got, ok := maxTimeMSRaw.AsInt64OK()
require.True(mt, ok)

assert.LessOrEqual(mt, got, int64(50))
})
}

func TestCursor_tailableAwaitData_ShortCircuitingGetMore(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().CreateClient(false))

cappedOpts := options.CreateCollection().SetCapped(true).
SetSizeInBytes(1024 * 64)

mtOpts := mtest.NewOptions().CollectionCreateOptions(cappedOpts)
tests := []struct {
name string
deadline time.Duration
maxAwaitTime time.Duration
wantShortCircuit bool
}{
{
name: "maxAwaitTime less than operation timeout",
deadline: 200 * time.Millisecond,
maxAwaitTime: 100 * time.Millisecond,
wantShortCircuit: false,
},
{
name: "maxAwaitTime equal to operation timeout",
deadline: 200 * time.Millisecond,
maxAwaitTime: 200 * time.Millisecond,
wantShortCircuit: true,
},
{
name: "maxAwaitTime greater than operation timeout",
deadline: 200 * time.Millisecond,
maxAwaitTime: 300 * time.Millisecond,
wantShortCircuit: true,
},
}

for _, tt := range tests {
mt.Run(tt.name, func(mt *mtest.T) {
mt.RunOpts("find", mtOpts, func(mt *mtest.T) {
initCollection(mt, mt.Coll)

// Create a find cursor
opts := options.Find().
SetBatchSize(1).
SetMaxAwaitTime(tt.maxAwaitTime).
SetCursorType(options.TailableAwait)

ctx, cancel := context.WithTimeout(context.Background(), tt.deadline)
defer cancel()

cur, err := mt.Coll.Find(ctx, bson.D{{Key: "x", Value: 3}}, opts)
require.NoError(mt, err, "Find error: %v", err)

// Close to return the session to the pool.
defer cur.Close(context.Background())

ok := cur.Next(ctx)
if tt.wantShortCircuit {
assert.False(mt, ok, "expected Next to return false, got true")
assert.EqualError(t, cur.Err(), "MaxAwaitTime must be less than the operation timeout")
} else {
assert.True(mt, ok, "expected Next to return true, got false")
assert.NoError(mt, cur.Err(), "expected no error, got %v", cur.Err())
}
})

mt.RunOpts("aggregate", mtOpts, func(mt *mtest.T) {
initCollection(mt, mt.Coll)

// Create a find cursor
opts := options.Aggregate().
SetBatchSize(1).
SetMaxAwaitTime(tt.maxAwaitTime)

ctx, cancel := context.WithTimeout(context.Background(), tt.deadline)
defer cancel()

cur, err := mt.Coll.Aggregate(ctx, []bson.D{}, opts)
require.NoError(mt, err, "Aggregate error: %v", err)

// Close to return the session to the pool.
defer cur.Close(context.Background())

ok := cur.Next(ctx)
if tt.wantShortCircuit {
assert.False(mt, ok, "expected Next to return false, got true")
assert.EqualError(t, cur.Err(), "MaxAwaitTime must be less than the operation timeout")
} else {
assert.True(mt, ok, "expected Next to return true, got false")
assert.NoError(mt, cur.Err(), "expected no error, got %v", cur.Err())
}
})

// The $changeStream stage is only supported on replica sets.
watchOpts := mtOpts.Topologies(mtest.ReplicaSet, mtest.Sharded)
mt.RunOpts("watch", watchOpts, func(mt *mtest.T) {
initCollection(mt, mt.Coll)

// Create a find cursor
opts := options.ChangeStream().SetMaxAwaitTime(tt.maxAwaitTime)

ctx, cancel := context.WithTimeout(context.Background(), tt.deadline)
defer cancel()

cur, err := mt.Coll.Watch(ctx, []bson.D{}, opts)
require.NoError(mt, err, "Watch error: %v", err)

// Close to return the session to the pool.
defer cur.Close(context.Background())

if tt.wantShortCircuit {
ok := cur.Next(ctx)

assert.False(mt, ok, "expected Next to return false, got true")
assert.EqualError(mt, cur.Err(), "MaxAwaitTime must be less than the operation timeout")
}
})
})
}
}

type tryNextCursor interface {
TryNext(context.Context) bool
Err() error
Expand Down
16 changes: 16 additions & 0 deletions internal/mongoutil/mongoutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
package mongoutil

import (
"context"
"reflect"
"time"

"go.mongodb.org/mongo-driver/v2/mongo/options"
)
Expand Down Expand Up @@ -83,3 +85,17 @@ func HostsFromURI(uri string) ([]string, error) {

return opts.Hosts, nil
}

// TimeoutWithinContext will return true if the provided timeout is nil or if
// it is less than the context deadline. If the context does not have a
// deadline, it will return true.
func TimeoutWithinContext(ctx context.Context, timeout time.Duration) bool {
deadline, ok := ctx.Deadline()
if !ok {
return true
}

ctxTimeout := time.Until(deadline)

return ctxTimeout <= 0 || timeout < ctxTimeout
}
Loading
Loading