Skip to content

Commit d451c48

Browse files
committed
watchset: Wait for multiple channels to close until timeout
We need to be able to wait for many channels to close to avoid doing many Wait() calls back-to-back when multiple channels close in a WatchSet. Signed-off-by: Jussi Maki <[email protected]>
1 parent d1bfba1 commit d451c48

File tree

2 files changed

+111
-87
lines changed

2 files changed

+111
-87
lines changed

watchset.go

Lines changed: 87 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@ package statedb
55

66
import (
77
"context"
8+
"fmt"
89
"maps"
910
"slices"
1011
"sync"
12+
"time"
1113
)
1214

1315
const watchSetChunkSize = 16
@@ -50,6 +52,18 @@ func (ws *WatchSet) Has(ch <-chan struct{}) bool {
5052
return found
5153
}
5254

55+
// HasAny returns true if the WatchSet has any of the given channels
56+
func (ws *WatchSet) HasAny(chans []<-chan struct{}) bool {
57+
ws.mu.Lock()
58+
defer ws.mu.Unlock()
59+
for _, ch := range chans {
60+
if _, found := ws.chans[ch]; found {
61+
return true
62+
}
63+
}
64+
return false
65+
}
66+
5367
// Merge channels from another WatchSet
5468
func (ws *WatchSet) Merge(other *WatchSet) {
5569
other.mu.Lock()
@@ -61,18 +75,23 @@ func (ws *WatchSet) Merge(other *WatchSet) {
6175
}
6276
}
6377

64-
// Wait for any channel in the watch set to close. The
65-
// watch set is cleared when this method returns.
66-
func (ws *WatchSet) Wait(ctx context.Context) (<-chan struct{}, error) {
78+
// Wait for channels in the watch set to close until context is cancelled or timeout reached.
79+
// Returns the closed channels and removes them from the set.
80+
func (ws *WatchSet) Wait(ctx context.Context, timeout time.Duration) ([]<-chan struct{}, error) {
81+
if timeout <= 0 {
82+
return nil, fmt.Errorf("bad timeout %d, must be >0", timeout)
83+
}
84+
innerCtx, cancel := context.WithTimeout(ctx, timeout)
85+
defer cancel()
86+
6787
ws.mu.Lock()
68-
defer func() {
69-
clear(ws.chans)
70-
ws.mu.Unlock()
71-
}()
88+
defer ws.mu.Unlock()
89+
90+
closedChannels := &closedChannelsSlice{}
7291

7392
// No channels to watch? Just watch the context.
7493
if len(ws.chans) == 0 {
75-
<-ctx.Done()
94+
<-innerCtx.Done()
7695
return nil, ctx.Err()
7796
}
7897

@@ -84,77 +103,78 @@ func (ws *WatchSet) Wait(ctx context.Context) (<-chan struct{}, error) {
84103
chans = slices.Grow(chans, roundedSize)[:roundedSize]
85104

86105
if len(ws.chans) <= chunkSize {
87-
ch := watch16(ctx.Done(), chans)
88-
return ch, ctx.Err()
106+
watch16(closedChannels, innerCtx.Done(), chans)
107+
return closedChannels.chans, ctx.Err()
89108
}
90109

91-
// More than one chunk. Fork goroutines to watch each chunk. The first chunk
92-
// that completes will cancel the context and stop the other goroutines.
93-
innerCtx, cancel := context.WithCancel(ctx)
94-
defer cancel()
95-
96-
closedChan := make(chan (<-chan struct{}), 1)
97-
defer close(closedChan)
98110
var wg sync.WaitGroup
99-
100111
for chunk := range slices.Chunk(chans, chunkSize) {
101112
wg.Add(1)
102113
go func() {
103-
defer cancel()
104114
defer wg.Done()
105-
chunk = slices.Clone(chunk)
106-
if ch := watch16(innerCtx.Done(), chunk); ch != nil {
107-
select {
108-
case closedChan <- ch:
109-
default:
110-
}
111-
}
115+
watch16(closedChannels, innerCtx.Done(), chunk)
112116
}()
113117
}
114118
wg.Wait()
115-
select {
116-
case <-ctx.Done():
117-
return nil, ctx.Err()
118-
case ch := <-closedChan:
119-
return ch, nil
119+
120+
for _, ch := range closedChannels.chans {
121+
delete(ws.chans, ch)
120122
}
123+
124+
return closedChannels.chans, ctx.Err()
121125
}
122126

123-
func watch16(stop <-chan struct{}, chans []<-chan struct{}) <-chan struct{} {
124-
select {
125-
case <-stop:
126-
return nil
127-
case <-chans[0]:
128-
return chans[0]
129-
case <-chans[1]:
130-
return chans[1]
131-
case <-chans[2]:
132-
return chans[2]
133-
case <-chans[3]:
134-
return chans[3]
135-
case <-chans[4]:
136-
return chans[4]
137-
case <-chans[5]:
138-
return chans[5]
139-
case <-chans[6]:
140-
return chans[6]
141-
case <-chans[7]:
142-
return chans[7]
143-
case <-chans[8]:
144-
return chans[8]
145-
case <-chans[9]:
146-
return chans[9]
147-
case <-chans[10]:
148-
return chans[10]
149-
case <-chans[11]:
150-
return chans[11]
151-
case <-chans[12]:
152-
return chans[12]
153-
case <-chans[13]:
154-
return chans[13]
155-
case <-chans[14]:
156-
return chans[14]
157-
case <-chans[15]:
158-
return chans[15]
127+
func watch16(closedChannels *closedChannelsSlice, stop <-chan struct{}, chans []<-chan struct{}) {
128+
for {
129+
closedIndex := -1
130+
select {
131+
case <-stop:
132+
return
133+
case <-chans[0]:
134+
closedIndex = 0
135+
case <-chans[1]:
136+
closedIndex = 1
137+
case <-chans[2]:
138+
closedIndex = 2
139+
case <-chans[3]:
140+
closedIndex = 3
141+
case <-chans[4]:
142+
closedIndex = 4
143+
case <-chans[5]:
144+
closedIndex = 5
145+
case <-chans[6]:
146+
closedIndex = 6
147+
case <-chans[7]:
148+
closedIndex = 7
149+
case <-chans[8]:
150+
closedIndex = 8
151+
case <-chans[9]:
152+
closedIndex = 9
153+
case <-chans[10]:
154+
closedIndex = 10
155+
case <-chans[11]:
156+
closedIndex = 11
157+
case <-chans[12]:
158+
closedIndex = 12
159+
case <-chans[13]:
160+
closedIndex = 13
161+
case <-chans[14]:
162+
closedIndex = 14
163+
case <-chans[15]:
164+
closedIndex = 15
165+
}
166+
closedChannels.append(chans[closedIndex])
167+
chans[closedIndex] = nil
159168
}
160169
}
170+
171+
type closedChannelsSlice struct {
172+
mu sync.Mutex
173+
chans []<-chan struct{}
174+
}
175+
176+
func (ccs *closedChannelsSlice) append(ch <-chan struct{}) {
177+
ccs.mu.Lock()
178+
ccs.chans = append(ccs.chans, ch)
179+
ccs.mu.Unlock()
180+
}

watchset_test.go

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ func TestWatchSet(t *testing.T) {
2121
// Empty watch set, cancelled context.
2222
ctx, cancel := context.WithCancel(context.Background())
2323
go cancel()
24-
ch, err := ws.Wait(ctx)
24+
ch, err := ws.Wait(ctx, time.Second)
2525
require.ErrorIs(t, err, context.Canceled)
2626
require.Nil(t, ch)
2727

@@ -32,12 +32,12 @@ func TestWatchSet(t *testing.T) {
3232
ws.Add(ch1, ch2, ch3)
3333
ctx, cancel = context.WithCancel(context.Background())
3434
go cancel()
35-
ch, err = ws.Wait(ctx)
35+
ch, err = ws.Wait(ctx, time.Second)
3636
require.ErrorIs(t, err, context.Canceled)
3737
require.Nil(t, ch)
3838

3939
// Many channels
40-
for _, numChans := range []int{0, 1, 8, 12, 16, 31, 32, 61, 64, 121} {
40+
for _, numChans := range []int{0, 1, 16, 31, 61, 64} {
4141
for i := range numChans {
4242
var chans []chan struct{}
4343
var rchans []<-chan struct{}
@@ -46,14 +46,16 @@ func TestWatchSet(t *testing.T) {
4646
chans = append(chans, ch)
4747
rchans = append(rchans, ch)
4848
}
49+
ws.Clear()
4950
ws.Add(rchans...)
5051

5152
close(chans[i])
52-
ctx, cancel = context.WithCancel(context.Background())
53-
ch, err := ws.Wait(ctx)
53+
closed, err := ws.Wait(context.Background(), time.Millisecond)
5454
require.NoError(t, err)
55-
require.True(t, ch == chans[i])
55+
require.Len(t, closed, 1)
56+
require.True(t, closed[0] == chans[i])
5657
cancel()
58+
5759
}
5860
}
5961
}
@@ -68,10 +70,10 @@ func TestWatchSetInQueries(t *testing.T) {
6870

6971
// Should timeout as watches should not have closed yet.
7072
ws.Add(watchAll)
71-
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
72-
ch, err := ws.Wait(ctx)
73+
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
74+
closed, err := ws.Wait(ctx, time.Second)
7375
require.ErrorIs(t, err, context.DeadlineExceeded)
74-
require.Nil(t, ch)
76+
require.Empty(t, closed)
7577
cancel()
7678

7779
// Insert some objects
@@ -83,20 +85,20 @@ func TestWatchSetInQueries(t *testing.T) {
8385

8486
// The 'watchAll' channel should now have closed and Wait() returns.
8587
ws.Add(watchAll)
86-
ch, err = ws.Wait(context.Background())
88+
closed, err = ws.Wait(context.Background(), time.Millisecond)
8789
require.NoError(t, err)
88-
require.Equal(t, ch, watchAll)
90+
require.Len(t, closed, 1)
91+
require.True(t, closed[0] == watchAll)
92+
ws.Clear()
8993

9094
// Try watching specific objects for changes.
9195
_, _, watch1, _ := table.GetWatch(txn, idIndex.Query(1))
9296
_, _, watch2, _ := table.GetWatch(txn, idIndex.Query(2))
9397
_, _, watch3, _ := table.GetWatch(txn, idIndex.Query(3))
9498

95-
ctx, cancel = context.WithTimeout(context.Background(), 5*time.Millisecond)
96-
ch, err = ws.Wait(ctx)
97-
require.ErrorIs(t, err, context.DeadlineExceeded)
98-
require.Nil(t, ch)
99-
cancel()
99+
closed, err = ws.Wait(context.Background(), time.Millisecond)
100+
require.NoError(t, err)
101+
require.Empty(t, closed)
100102

101103
wtxn = db.WriteTxn(table)
102104
table.Insert(wtxn, testObject{ID: 1, Tags: part.NewSet("foo")})
@@ -111,11 +113,13 @@ func TestWatchSetInQueries(t *testing.T) {
111113
// in ws2.
112114
ws.Merge(ws2)
113115

114-
ch, err = ws.Wait(context.Background())
116+
closed, err = ws.Wait(context.Background(), time.Millisecond)
115117
require.NoError(t, err)
116-
require.True(t, ch == watch1)
117-
require.True(t, ws2.Has(ch))
118+
require.Len(t, closed, 1)
119+
require.True(t, closed[0] == watch1)
120+
require.True(t, ws2.Has(closed[0]))
121+
require.True(t, ws2.HasAny(closed))
118122

119123
ws2.Clear()
120-
require.False(t, ws2.Has(ch))
124+
require.False(t, ws2.Has(closed[0]))
121125
}

0 commit comments

Comments
 (0)