diff --git a/internal/integration/collection_test.go b/internal/integration/collection_test.go index 10a3b803f5..58e2575b57 100644 --- a/internal/integration/collection_test.go +++ b/internal/integration/collection_test.go @@ -2028,8 +2028,8 @@ func TestCollection(t *testing.T) { }) } -func initCollection(mt *mtest.T, coll *mongo.Collection) { - mt.Helper() +func initCollection(tb testing.TB, coll *mongo.Collection) { + tb.Helper() var docs []interface{} for i := 1; i <= 5; i++ { @@ -2037,7 +2037,7 @@ func initCollection(mt *mtest.T, coll *mongo.Collection) { } _, err := coll.InsertMany(context.Background(), docs) - assert.Nil(mt, err, "InsertMany error for initial data: %v", err) + assert.NoError(tb, err, "InsertMany error for initial data: %v", err) } func testAggregateWithOptions(mt *mtest.T, createIndex bool, opts options.Lister[options.AggregateOptions]) { diff --git a/internal/integration/cursor_test.go b/internal/integration/cursor_test.go index 6376e78e74..8bba46df53 100644 --- a/internal/integration/cursor_test.go +++ b/internal/integration/cursor_test.go @@ -14,6 +14,7 @@ import ( "time" "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/event" "go.mongodb.org/mongo-driver/v2/internal/assert" "go.mongodb.org/mongo-driver/v2/internal/failpoint" "go.mongodb.org/mongo-driver/v2/internal/integration/mtest" @@ -304,77 +305,248 @@ func TestCursor(t *testing.T) { batchSize = sizeVal.Int32() assert.Equal(mt, int32(4), batchSize, "expected batchSize 4, got %v", batchSize) }) +} - tailableAwaitDataCursorOpts := mtest.NewOptions().MinServerVersion("4.4"). 
- Topologies(mtest.ReplicaSet, mtest.Sharded, mtest.LoadBalanced, mtest.Single) +func parseMaxAwaitTime(mt *mtest.T, evt *event.CommandStartedEvent) int64 { + mt.Helper() - mt.RunOpts("tailable awaitData cursor", tailableAwaitDataCursorOpts, func(mt *mtest.T) { - mt.Run("apply remaining timeoutMS if less than maxAwaitTimeMS", func(mt *mtest.T) { - initCollection(mt, mt.Coll) - mt.ClearEvents() + maxTimeMSRaw, err := evt.Command.LookupErr("maxTimeMS") + require.NoError(mt, err) - // Create a find cursor - opts := options.Find().SetBatchSize(1).SetMaxAwaitTime(100 * time.Millisecond) + got, ok := maxTimeMSRaw.AsInt64OK() + require.True(mt, ok) - cursor, err := mt.Coll.Find(context.Background(), bson.D{}, opts) - require.NoError(mt, err) + return got +} - _ = mt.GetStartedEvent() // Empty find from started list. +func TestCursor_tailableAwaitData(t *testing.T) { + mt := mtest.New(t, mtest.NewOptions().CreateClient(false)) - defer cursor.Close(context.Background()) + cappedOpts := options.CreateCollection().SetCapped(true). + SetSizeInBytes(1024 * 64) - ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) - defer cancel() + // TODO(SERVER-96344): mongos doesn't honor a failpoint's full blockTimeMS. + mtOpts := mtest.NewOptions().MinServerVersion("4.4"). + Topologies(mtest.ReplicaSet, mtest.LoadBalanced, mtest.Single). + CollectionCreateOptions(cappedOpts) - // Iterate twice to force a getMore - cursor.Next(ctx) - cursor.Next(ctx) + mt.RunOpts("apply remaining timeoutMS if less than maxAwaitTimeMS", mtOpts, func(mt *mtest.T) { + initCollection(mt, mt.Coll) - cmd := mt.GetStartedEvent().Command + // Create a 30ms failpoint for getMore. 
+ mt.SetFailPoint(failpoint.FailPoint{ + ConfigureFailPoint: "failCommand", + Mode: failpoint.Mode{ + Times: 1, + }, + Data: failpoint.Data{ + FailCommands: []string{"getMore"}, + BlockConnection: true, + BlockTimeMS: 30, + }, + }) + + // Create a find cursor with a 100ms maxAwaitTimeMS and a tailable awaitData + // cursor type. + opts := options.Find(). + SetBatchSize(1). + SetMaxAwaitTime(100 * time.Millisecond). + SetCursorType(options.TailableAwait) - got, ok := maxTimeMSRaw.AsInt64OK() - require.True(mt, ok) + cursor, err := mt.Coll.Find(context.Background(), bson.D{{Key: "x", Value: 2}}, opts) + require.NoError(mt, err) - assert.LessOrEqual(mt, got, int64(50)) - }) + defer cursor.Close(context.Background()) - mt.RunOpts("apply maxAwaitTimeMS if less than remaining timeout", tailableAwaitDataCursorOpts, func(mt *mtest.T) { - initCollection(mt, mt.Coll) - mt.ClearEvents() + // Use a 200ms timeout that caps the lifetime of cursor.Next. The underlying + // getMore loop should run at least two times: the first getMore will block + // for 30ms on the getMore and then an additional 100ms for the + // maxAwaitTimeMS. The second getMore will then use the remaining ~70ms + // left on the timeout. + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() - // Create a find cursor - opts := options.Find().SetBatchSize(1).SetMaxAwaitTime(50 * time.Millisecond) + // Iterate twice to force a getMore + cursor.Next(ctx) - cursor, err := mt.Coll.Find(context.Background(), bson.D{}, opts) - require.NoError(mt, err) + mt.ClearEvents() + cursor.Next(ctx) - _ = mt.GetStartedEvent() // Empty find from started list. 
+ require.Error(mt, cursor.Err(), "expected error from cursor.Next") + assert.ErrorIs(mt, cursor.Err(), context.DeadlineExceeded, "expected context deadline exceeded error") - defer cursor.Close(context.Background()) + // Collect all started events to find the getMore commands. + startedEvents := mt.GetAllStartedEvents() - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() + var getMoreStartedEvents []*event.CommandStartedEvent + for _, evt := range startedEvents { + if evt.CommandName == "getMore" { + getMoreStartedEvents = append(getMoreStartedEvents, evt) + } + } - // Iterate twice to force a getMore - cursor.Next(ctx) - cursor.Next(ctx) + // The first getMore should have a maxTimeMS of <= 100ms. + assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreStartedEvents[0]), int64(100)) - cmd := mt.GetStartedEvent().Command + // The second getMore should have a maxTimeMS of <=71, indicating that we + // are using the time remaining in the context rather than the + // maxAwaitTimeMS. + assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreStartedEvents[1]), int64(71)) + }) - maxTimeMSRaw, err := cmd.LookupErr("maxTimeMS") - require.NoError(mt, err) + mtOpts.Topologies(mtest.ReplicaSet, mtest.Sharded, mtest.LoadBalanced, mtest.Single) - got, ok := maxTimeMSRaw.AsInt64OK() - require.True(mt, ok) + mt.RunOpts("apply maxAwaitTimeMS if less than remaining timeout", mtOpts, func(mt *mtest.T) { + initCollection(mt, mt.Coll) + mt.ClearEvents() - assert.LessOrEqual(mt, got, int64(50)) - }) + // Create a find cursor + opts := options.Find().SetBatchSize(1).SetMaxAwaitTime(50 * time.Millisecond) + + cursor, err := mt.Coll.Find(context.Background(), bson.D{}, opts) + require.NoError(mt, err) + + _ = mt.GetStartedEvent() // Empty find from started list. 
+ + defer cursor.Close(context.Background()) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + // Iterate twice to force a getMore + cursor.Next(ctx) + cursor.Next(ctx) + + cmd := mt.GetStartedEvent().Command + + maxTimeMSRaw, err := cmd.LookupErr("maxTimeMS") + require.NoError(mt, err) + + got, ok := maxTimeMSRaw.AsInt64OK() + require.True(mt, ok) + + assert.LessOrEqual(mt, got, int64(50)) }) } +func TestCursor_tailableAwaitData_ShortCircuitingGetMore(t *testing.T) { + mt := mtest.New(t, mtest.NewOptions().CreateClient(false)) + + cappedOpts := options.CreateCollection().SetCapped(true). + SetSizeInBytes(1024 * 64) + + mtOpts := mtest.NewOptions().CollectionCreateOptions(cappedOpts) + tests := []struct { + name string + deadline time.Duration + maxAwaitTime time.Duration + wantShortCircuit bool + }{ + { + name: "maxAwaitTime less than operation timeout", + deadline: 200 * time.Millisecond, + maxAwaitTime: 100 * time.Millisecond, + wantShortCircuit: false, + }, + { + name: "maxAwaitTime equal to operation timeout", + deadline: 200 * time.Millisecond, + maxAwaitTime: 200 * time.Millisecond, + wantShortCircuit: true, + }, + { + name: "maxAwaitTime greater than operation timeout", + deadline: 200 * time.Millisecond, + maxAwaitTime: 300 * time.Millisecond, + wantShortCircuit: true, + }, + } + + for _, tt := range tests { + mt.Run(tt.name, func(mt *mtest.T) { + mt.RunOpts("find", mtOpts, func(mt *mtest.T) { + initCollection(mt, mt.Coll) + + // Create a find cursor + opts := options.Find(). + SetBatchSize(1). + SetMaxAwaitTime(tt.maxAwaitTime). + SetCursorType(options.TailableAwait) + + ctx, cancel := context.WithTimeout(context.Background(), tt.deadline) + defer cancel() + + cur, err := mt.Coll.Find(ctx, bson.D{{Key: "x", Value: 3}}, opts) + require.NoError(mt, err, "Find error: %v", err) + + // Close to return the session to the pool. 
+ defer cur.Close(context.Background()) + + ok := cur.Next(ctx) + if tt.wantShortCircuit { + assert.False(mt, ok, "expected Next to return false, got true") + assert.EqualError(mt, cur.Err(), "MaxAwaitTime must be less than the operation timeout") + } else { + assert.True(mt, ok, "expected Next to return true, got false") + assert.NoError(mt, cur.Err(), "expected no error, got %v", cur.Err()) + } + }) + + mt.RunOpts("aggregate", mtOpts, func(mt *mtest.T) { + initCollection(mt, mt.Coll) + + // Create an aggregate cursor + opts := options.Aggregate(). + SetBatchSize(1). + SetMaxAwaitTime(tt.maxAwaitTime) + + ctx, cancel := context.WithTimeout(context.Background(), tt.deadline) + defer cancel() + + cur, err := mt.Coll.Aggregate(ctx, []bson.D{}, opts) + require.NoError(mt, err, "Aggregate error: %v", err) + + // Close to return the session to the pool. + defer cur.Close(context.Background()) + + ok := cur.Next(ctx) + if tt.wantShortCircuit { + assert.False(mt, ok, "expected Next to return false, got true") + assert.EqualError(mt, cur.Err(), "MaxAwaitTime must be less than the operation timeout") + } else { + assert.True(mt, ok, "expected Next to return true, got false") + assert.NoError(mt, cur.Err(), "expected no error, got %v", cur.Err()) + } + }) + + // The $changeStream stage is only supported on replica sets and sharded clusters. + watchOpts := mtest.NewOptions().CollectionCreateOptions(cappedOpts).Topologies(mtest.ReplicaSet, mtest.Sharded) + mt.RunOpts("watch", watchOpts, func(mt *mtest.T) { + initCollection(mt, mt.Coll) + + // Create a change stream + opts := options.ChangeStream().SetMaxAwaitTime(tt.maxAwaitTime) + + ctx, cancel := context.WithTimeout(context.Background(), tt.deadline) + defer cancel() + + cur, err := mt.Coll.Watch(ctx, []bson.D{}, opts) + require.NoError(mt, err, "Watch error: %v", err) + + // Close to return the session to the pool. 
+ defer cur.Close(context.Background()) + + if tt.wantShortCircuit { + ok := cur.Next(ctx) + + assert.False(mt, ok, "expected Next to return false, got true") + assert.EqualError(mt, cur.Err(), "MaxAwaitTime must be less than the operation timeout") + } + }) + }) + } +} + type tryNextCursor interface { TryNext(context.Context) bool Err() error diff --git a/internal/mongoutil/mongoutil.go b/internal/mongoutil/mongoutil.go index 0345b96e8f..be58d38bf4 100644 --- a/internal/mongoutil/mongoutil.go +++ b/internal/mongoutil/mongoutil.go @@ -7,7 +7,9 @@ package mongoutil import ( + "context" "reflect" + "time" "go.mongodb.org/mongo-driver/v2/mongo/options" ) @@ -83,3 +85,17 @@ func HostsFromURI(uri string) ([]string, error) { return opts.Hosts, nil } + +// TimeoutWithinContext will return true if the provided timeout is less than +// the time remaining until the context deadline. It also returns true if the +// context has no deadline or if the deadline has already passed. +func TimeoutWithinContext(ctx context.Context, timeout time.Duration) bool { + deadline, ok := ctx.Deadline() + if !ok { + return true + } + + ctxTimeout := time.Until(deadline) + + return ctxTimeout <= 0 || timeout < ctxTimeout +} diff --git a/internal/mongoutil/mongoutil_test.go b/internal/mongoutil/mongoutil_test.go index 661ee5f5bb..919d2089ca 100644 --- a/internal/mongoutil/mongoutil_test.go +++ b/internal/mongoutil/mongoutil_test.go @@ -7,9 +7,13 @@ package mongoutil import ( + "context" "strings" "testing" + "time" + "go.mongodb.org/mongo-driver/v2/internal/assert" + "go.mongodb.org/mongo-driver/v2/internal/ptrutil" "go.mongodb.org/mongo-driver/v2/mongo/options" ) @@ -32,3 +36,71 @@ func BenchmarkNewOptions(b *testing.B) { } }) } + +func TestValidChangeStreamTimeouts(t *testing.T) { + tests := []struct { + name string + ctxTimeout *time.Duration + timeout time.Duration + wantTimeout time.Duration + want bool + }{ + { + name: "Timeout shorter than context deadline", + ctxTimeout: ptrutil.Ptr(10 * time.Second), + timeout: 1 * 
time.Second, + want: true, + }, + { + name: "Timeout equal to context deadline", + ctxTimeout: ptrutil.Ptr(1 * time.Second), + timeout: 1 * time.Second, + want: false, + }, + { + name: "Timeout greater than context deadline", + ctxTimeout: ptrutil.Ptr(1 * time.Second), + timeout: 10 * time.Second, + want: false, + }, + { + name: "Context deadline already expired", + ctxTimeout: ptrutil.Ptr(-1 * time.Second), + timeout: 1 * time.Second, + want: true, // *timeout <= 0 branch in code + }, + { + name: "Timeout is zero, context deadline in future", + ctxTimeout: ptrutil.Ptr(10 * time.Second), + timeout: 0 * time.Second, + want: true, + }, + { + name: "Timeout is negative, context deadline in future", + ctxTimeout: ptrutil.Ptr(10 * time.Second), + timeout: -1 * time.Second, + want: true, + }, + { + name: "Timeout provided, context has no deadline", + ctxTimeout: nil, + timeout: 1 * time.Second, + want: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx := context.Background() + if test.ctxTimeout != nil { + var cancel context.CancelFunc + + ctx, cancel = context.WithTimeout(ctx, *test.ctxTimeout) + defer cancel() + } + + got := TimeoutWithinContext(ctx, test.timeout) + assert.Equal(t, test.want, got) + }) + } +} diff --git a/internal/spectest/skip.go b/internal/spectest/skip.go index 4f6273c96e..e212a7ee4f 100644 --- a/internal/spectest/skip.go +++ b/internal/spectest/skip.go @@ -567,7 +567,6 @@ var skipTests = map[string][]string{ "TestUnifiedSpec/client-side-operations-timeout/tests/sessions-override-operation-timeoutMS.json/timeoutMS_applied_to_withTransaction", "TestUnifiedSpec/client-side-operations-timeout/tests/sessions-override-timeoutMS.json", "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/error_if_timeoutMode_is_cursor_lifetime", - "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/error_if_maxAwaitTimeMS_is_greater_than_timeoutMS", 
"TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_applied_to_find", "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_is_refreshed_for_getMore_if_maxAwaitTimeMS_is_not_set", "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_is_refreshed_for_getMore_if_maxAwaitTimeMS_is_set", @@ -802,19 +801,25 @@ var skipTests = map[string][]string{ "TestUnifiedSpec/transactions-convenient-api/tests/unified/commit.json/withTransaction_commits_after_callback_returns", }, - // GODRIVER-3473: the implementation of DRIVERS-2868 makes it clear that the - // Go Driver does not correctly implement the following validation for - // tailable awaitData cursors: + "Address CSOT Compliance Issue in Timeout Handling for Cursor Constructors (GODRIVER-3480)": { + "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/apply_remaining_timeoutMS_if_less_than_maxAwaitTimeMS", + }, + + // The Go Driver does not support "iteration" mode for cursors. That is, + // we do not apply the timeout used to construct the cursor when using the + // cursor, rather we apply the context-level timeout if one is provided. It's + // doubtful that we will ever support this mode, so we skip these tests. // - // Drivers MUST error if this option is set, timeoutMS is set to a - // non-zero value, and maxAwaitTimeMS is greater than or equal to - // timeoutMS. + // If we do ever support this mode, it will be done as part of DRIVERS-2722 + // which does not currently have driver-specific tickets. // - // Once GODRIVER-3473 is completed, we can continue running these tests. - "When constructing tailable awaitData cusors must validate, timeoutMS is set to a non-zero value, and maxAwaitTimeMS is greater than or equal to timeoutMS (GODRIVER-3473)": { + // Note that we have integration tests that cover the cases described in these + // tests upto what is supported in the Go Driver. 
See GODRIVER-3473 + "Change CSOT default cursor timeout mode to ITERATION (DRIVERS-2772)": { "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/apply_remaining_timeoutMS_if_less_than_maxAwaitTimeMS", "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/error_if_maxAwaitTimeMS_is_equal_to_timeoutMS", "TestUnifiedSpec/client-side-operations-timeout/tests/change-streams.json/error_if_maxAwaitTimeMS_is_equal_to_timeoutMS", + "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/error_if_maxAwaitTimeMS_is_greater_than_timeoutMS", }, } diff --git a/mongo/batch_cursor.go b/mongo/batch_cursor.go index 148a627530..4781812dac 100644 --- a/mongo/batch_cursor.go +++ b/mongo/batch_cursor.go @@ -51,6 +51,11 @@ type batchCursor interface { // SetComment will set a user-configurable comment that can be used to // identify the operation in server logs. SetComment(interface{}) + + // MaxAwaitTime returns the maximum amount of time the server will allow + // the operations to execute. This is only valid for tailable awaitData + // cursors. + MaxAwaitTime() *time.Duration } // changeStreamCursor is the interface implemented by batch cursors that also provide the functionality for retrieving diff --git a/mongo/change_stream.go b/mongo/change_stream.go index 8b3be5aad9..6f0ca8a084 100644 --- a/mongo/change_stream.go +++ b/mongo/change_stream.go @@ -12,7 +12,6 @@ import ( "fmt" "reflect" "strconv" - "time" "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/internal/csot" @@ -103,33 +102,6 @@ type changeStreamConfig struct { crypt driver.Crypt } -// validChangeStreamTimeouts will return "false" if maxAwaitTimeMS is set, -// timeoutMS is set to a non-zero value, and maxAwaitTimeMS is greater than or -// equal to timeoutMS. Otherwise, the timeouts are valid. 
-func validChangeStreamTimeouts(ctx context.Context, cs *ChangeStream) bool { - if cs.options == nil || cs.client == nil { - return true - } - - maxAwaitTime := cs.options.MaxAwaitTime - timeout := cs.client.timeout - - if maxAwaitTime == nil { - return true - } - - if deadline, ok := ctx.Deadline(); ok { - ctxTimeout := time.Until(deadline) - timeout = &ctxTimeout - } - - if timeout == nil { - return true - } - - return *timeout <= 0 || *maxAwaitTime < *timeout -} - func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline interface{}, opts ...options.Lister[options.ChangeStreamOptions]) (*ChangeStream, error) { if ctx == nil { @@ -696,10 +668,33 @@ func (cs *ChangeStream) next(ctx context.Context, nonBlocking bool) bool { } func (cs *ChangeStream) loopNext(ctx context.Context, nonBlocking bool) { - if !validChangeStreamTimeouts(ctx, cs) { - cs.err = fmt.Errorf("MaxAwaitTime must be less than the operation timeout") + // To avoid unnecessary socket timeouts, we attempt to short-circuit tailable + // awaitData "getMore" operations by ensuring that the maxAwaitTimeMS is less + // than the operation timeout. + // + // The specifications assume that drivers iteratively apply the timeout + // provided at the constructor level (e.g., (*collection).Find) for tailable + // awaitData cursors: + // + // If set, drivers MUST apply the timeoutMS option to the initial aggregate + // operation. Drivers MUST also apply the original timeoutMS value to each + // next call on the change stream but MUST NOT use it to derive a maxTimeMS + // field for getMore commands. + // + // The Go Driver might decide to support the above behavior with DRIVERS-2722. + // The principal concern is that it would be unexpected for users to apply an + // operation-level timeout via contexts to a constructor and then that timeout + // later be applied while working with a resulting cursor. 
Instead, it is more + // idiomatic to apply the timeout to the context passed to Next or TryNext. + if cs.options != nil && !nonBlocking { + maxAwaitTime := cs.cursorOptions.MaxAwaitTime + + // If maxAwaitTime is not set, this check is unnecessary. + if maxAwaitTime != nil && !mongoutil.TimeoutWithinContext(ctx, *maxAwaitTime) { + cs.err = fmt.Errorf("MaxAwaitTime must be less than the operation timeout") - return + return + } } // Apply the client-level timeout if the operation-level timeout is not set. diff --git a/mongo/change_stream_test.go b/mongo/change_stream_test.go index 8e722764a8..c2752a2c16 100644 --- a/mongo/change_stream_test.go +++ b/mongo/change_stream_test.go @@ -7,12 +7,9 @@ package mongo import ( - "context" "testing" - "time" "go.mongodb.org/mongo-driver/v2/internal/assert" - "go.mongodb.org/mongo-driver/v2/mongo/options" ) func TestChangeStream(t *testing.T) { @@ -30,96 +27,3 @@ func TestChangeStream(t *testing.T) { assert.Nil(t, err, "Close error: %v", err) }) } - -func TestValidChangeStreamTimeouts(t *testing.T) { - t.Parallel() - - newDurPtr := func(dur time.Duration) *time.Duration { - return &dur - } - - tests := []struct { - name string - parent context.Context - maxAwaitTimeout, timeout *time.Duration - wantTimeout time.Duration - want bool - }{ - { - name: "no context deadline and no timeouts", - parent: context.Background(), - maxAwaitTimeout: nil, - timeout: nil, - wantTimeout: 0, - want: true, - }, - { - name: "no context deadline and maxAwaitTimeout", - parent: context.Background(), - maxAwaitTimeout: newDurPtr(1), - timeout: nil, - wantTimeout: 0, - want: true, - }, - { - name: "no context deadline and timeout", - parent: context.Background(), - maxAwaitTimeout: nil, - timeout: newDurPtr(1), - wantTimeout: 0, - want: true, - }, - { - name: "no context deadline and maxAwaitTime gt timeout", - parent: context.Background(), - maxAwaitTimeout: newDurPtr(2), - timeout: newDurPtr(1), - wantTimeout: 0, - want: false, - }, - { - name: "no 
context deadline and maxAwaitTime lt timeout", - parent: context.Background(), - maxAwaitTimeout: newDurPtr(1), - timeout: newDurPtr(2), - wantTimeout: 0, - want: true, - }, - { - name: "no context deadline and maxAwaitTime eq timeout", - parent: context.Background(), - maxAwaitTimeout: newDurPtr(1), - timeout: newDurPtr(1), - wantTimeout: 0, - want: false, - }, - { - name: "no context deadline and maxAwaitTime with negative timeout", - parent: context.Background(), - maxAwaitTimeout: newDurPtr(1), - timeout: newDurPtr(-1), - wantTimeout: 0, - want: true, - }, - } - - for _, test := range tests { - test := test // Capture the range variable - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - cs := &ChangeStream{ - options: &options.ChangeStreamOptions{ - MaxAwaitTime: test.maxAwaitTimeout, - }, - client: &Client{ - timeout: test.timeout, - }, - } - - got := validChangeStreamTimeouts(test.parent, cs) - assert.Equal(t, test.want, got) - }) - } -} diff --git a/mongo/cursor.go b/mongo/cursor.go index e8ab9caa11..583cbd00df 100644 --- a/mongo/cursor.go +++ b/mongo/cursor.go @@ -16,6 +16,7 @@ import ( "time" "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/internal/mongoutil" "go.mongodb.org/mongo-driver/v2/mongo/options" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/v2/x/mongo/driver" @@ -175,6 +176,32 @@ func (c *Cursor) next(ctx context.Context, nonBlocking bool) bool { if ctx == nil { ctx = context.Background() } + + // To avoid unnecessary socket timeouts, we attempt to short-circuit tailable + // awaitData "getMore" operations by ensuring that the maxAwaitTimeMS is less + // than the operation timeout. + // + // The specifications assume that drivers iteratively apply the timeout + // provided at the constructor level (e.g., (*collection).Find) for tailable + // awaitData cursors: + // + // If set, drivers MUST apply the timeoutMS option to the initial aggregate + // operation. 
Drivers MUST also apply the original timeoutMS value to each + // next call on the change stream but MUST NOT use it to derive a maxTimeMS + // field for getMore commands. + // + // The Go Driver might decide to support the above behavior with DRIVERS-2722. + // The principal concern is that it would be unexpected for users to apply an + // operation-level timeout via contexts to a constructor and then that timeout + // later be applied while working with a resulting cursor. Instead, it is more + // idiomatic to apply the timeout to the context passed to Next or TryNext. + maxAwaitTime := c.bc.MaxAwaitTime() + if maxAwaitTime != nil && !nonBlocking && !mongoutil.TimeoutWithinContext(ctx, *maxAwaitTime) { + c.err = fmt.Errorf("MaxAwaitTime must be less than the operation timeout") + + return false + } + val, err := c.batch.Next() switch { case err == nil: diff --git a/mongo/cursor_test.go b/mongo/cursor_test.go index 6601e94a62..4573c9f9ac 100644 --- a/mongo/cursor_test.go +++ b/mongo/cursor_test.go @@ -27,6 +27,8 @@ type testBatchCursor struct { closed bool } +var _ batchCursor = (*testBatchCursor)(nil) + func newTestBatchCursor(numBatches, batchSize int) *testBatchCursor { batches := make([]*bsoncore.Iterator, 0, numBatches) @@ -99,6 +101,7 @@ func (tbc *testBatchCursor) Close(context.Context) error { func (tbc *testBatchCursor) SetBatchSize(int32) {} func (tbc *testBatchCursor) SetComment(interface{}) {} func (tbc *testBatchCursor) SetMaxAwaitTime(time.Duration) {} +func (tbc *testBatchCursor) MaxAwaitTime() *time.Duration { return nil } func TestCursor(t *testing.T) { t.Run("TestAll", func(t *testing.T) { diff --git a/x/mongo/driver/batch_cursor.go b/x/mongo/driver/batch_cursor.go index 6f563eb5c6..00674fbd2d 100644 --- a/x/mongo/driver/batch_cursor.go +++ b/x/mongo/driver/batch_cursor.go @@ -545,6 +545,12 @@ func (bc *BatchCursor) getOperationDeployment() Deployment { return SingleServerDeployment{bc.server} } +// MaxAwaitTime returns the maximum amount of 
time the server will allow +// the operations to execute. This is only valid for tailable awaitData cursors. +func (bc *BatchCursor) MaxAwaitTime() *time.Duration { + return bc.maxAwaitTime +} + // loadBalancedCursorDeployment is used as a Deployment for getMore and killCursors commands when pinning to a // connection in load balanced mode. This type also functions as an ErrorProcessor to ensure that SDAM errors are // handled for these commands in this mode.