@@ -5,7 +5,6 @@ package statedb
55
66import (
77 "context"
8- "fmt"
98 "maps"
109 "slices"
1110 "sync"
@@ -75,13 +74,12 @@ func (ws *WatchSet) Merge(other *WatchSet) {
7574 }
7675}
7776
78- // Wait for channels in the watch set to close until context is cancelled or timeout reached.
77+ // Wait for channels in the watch set to close or the context is cancelled.
78+ // After the first closed channel is seen Wait will wait [settleTime] for
79+ // more closed channels.
7980// Returns the closed channels and removes them from the set.
80- func (ws * WatchSet ) Wait (ctx context.Context , timeout time.Duration ) ([]<- chan struct {}, error ) {
81- if timeout <= 0 {
82- return nil , fmt .Errorf ("bad timeout %d, must be >0" , timeout )
83- }
84- innerCtx , cancel := context .WithTimeout (ctx , timeout )
81+ func (ws * WatchSet ) Wait (ctx context.Context , settleTime time.Duration ) ([]<- chan struct {}, error ) {
82+ innerCtx , cancel := context .WithCancel (ctx )
8583 defer cancel ()
8684
8785 ws .mu .Lock ()
@@ -91,7 +89,7 @@ func (ws *WatchSet) Wait(ctx context.Context, timeout time.Duration) ([]<-chan s
9189
9290 // No channels to watch? Just watch the context.
9391 if len (ws .chans ) == 0 {
94- <- innerCtx .Done ()
92+ <- ctx .Done ()
9593 return nil , ctx .Err ()
9694 }
9795
@@ -101,30 +99,44 @@ func (ws *WatchSet) Wait(ctx context.Context, timeout time.Duration) ([]<-chan s
10199 chunkSize := 16
102100 roundedSize := len (chans ) + (chunkSize - len (chans )% chunkSize )
103101 chans = slices .Grow (chans , roundedSize )[:roundedSize ]
104-
105- if len (ws .chans ) <= chunkSize {
106- watch16 (closedChannels , innerCtx .Done (), chans )
107- return closedChannels .chans , ctx .Err ()
108- }
102+ haveResult := make (chan struct {}, 1 )
109103
110104 var wg sync.WaitGroup
111- for chunk := range slices .Chunk (chans , chunkSize ) {
105+ chunks := slices .Chunk (chans , chunkSize )
106+ for chunk := range chunks {
112107 wg .Add (1 )
113108 go func () {
114109 defer wg .Done ()
115- watch16 (closedChannels , innerCtx .Done (), chunk )
110+ watch16 (haveResult , closedChannels , innerCtx .Done (), chunk )
116111 }()
117112 }
113+
114+ // Wait for the first closed channel to be seen. If [settleTime] is set,
115+ // then wait a bit longer for more.
116+ select {
117+ case <- haveResult :
118+ if settleTime > 0 {
119+ select {
120+ case <- time .After (settleTime ):
121+ case <- ctx .Done ():
122+ }
123+ }
124+ case <- ctx .Done ():
125+ }
126+
127+ // Stop waiting for more channels to close
128+ cancel ()
118129 wg .Wait ()
119130
131+ // Remove the closed channels from the watch set.
120132 for _ , ch := range closedChannels .chans {
121133 delete (ws .chans , ch )
122134 }
123135
124136 return closedChannels .chans , ctx .Err ()
125137}
126138
127- func watch16 (closedChannels * closedChannelsSlice , stop <- chan struct {}, chans []<- chan struct {}) {
139+ func watch16 (haveClosed chan struct {}, closedChannels * closedChannelsSlice , stop <- chan struct {}, chans []<- chan struct {}) {
128140 for {
129141 closedIndex := - 1
130142 select {
@@ -165,6 +177,13 @@ func watch16(closedChannels *closedChannelsSlice, stop <-chan struct{}, chans []
165177 }
166178 closedChannels .append (chans [closedIndex ])
167179 chans [closedIndex ] = nil
180+ if haveClosed != nil {
181+ select {
182+ case haveClosed <- struct {}{}:
183+ haveClosed = nil
184+ default :
185+ }
186+ }
168187 }
169188}
170189
0 commit comments