@@ -78,7 +78,7 @@ pub struct StreamRouter {
78
78
#[ allow( dead_code) ]
79
79
workers : Vec < Worker > ,
80
80
metrics : Arc < CollabStreamMetrics > ,
81
- channels : DashMap < StreamKey , StreamSender > ,
81
+ channels : Arc < DashMap < StreamKey , WeakStreamSender > > ,
82
82
buffer_capacity : usize ,
83
83
}
84
84
@@ -94,6 +94,7 @@ impl StreamRouter {
94
94
) -> Result < Self , RedisError > {
95
95
let alive = Arc :: new ( AtomicBool :: new ( true ) ) ;
96
96
let ( tx, rx) = loole:: unbounded ( ) ;
97
+ let channels = Arc :: new ( DashMap :: new ( ) ) ;
97
98
let mut workers = Vec :: with_capacity ( options. worker_count ) ;
98
99
for worker_id in 0 ..options. worker_count {
99
100
let conn = client. get_connection ( ) ?;
@@ -105,6 +106,7 @@ impl StreamRouter {
105
106
alive. clone ( ) ,
106
107
& options,
107
108
metrics. clone ( ) ,
109
+ channels. clone ( ) ,
108
110
) ;
109
111
workers. push ( worker) ;
110
112
}
@@ -114,7 +116,7 @@ impl StreamRouter {
114
116
workers,
115
117
alive,
116
118
metrics,
117
- channels : DashMap :: new ( ) ,
119
+ channels,
118
120
buffer_capacity : options. xread_count . unwrap_or ( 10_000 ) ,
119
121
} )
120
122
}
@@ -128,16 +130,28 @@ impl StreamRouter {
128
130
Entry :: Vacant ( e) => {
129
131
tracing:: trace!( "creating new stream channel for {}" , e. key( ) ) ;
130
132
let ( tx, rx) = tokio:: sync:: broadcast:: channel ( self . buffer_capacity ) ;
133
+ e. insert ( tx. downgrade ( ) ) ;
131
134
let last_id = last_id. unwrap_or_else ( || "0" . to_string ( ) ) ;
132
- let h = StreamHandle :: new ( stream_key. clone ( ) , last_id, tx. clone ( ) ) ;
135
+ let h = StreamHandle :: new ( stream_key. clone ( ) , last_id, tx) ;
133
136
self . buf . send ( h) . unwrap ( ) ;
134
137
self . metrics . reads_enqueued . inc ( ) ;
135
- e. insert ( tx) ;
136
138
rx
137
139
} ,
138
- Entry :: Occupied ( e) => {
139
- tracing:: trace!( "reusing existing stream channel for {}" , e. key( ) ) ;
140
- e. get ( ) . subscribe ( )
140
+ Entry :: Occupied ( mut e) => {
141
+ let sender = e. get ( ) ;
142
+ if let Some ( sender) = sender. upgrade ( ) {
143
+ tracing:: trace!( "reusing existing stream channel for {}" , e. key( ) ) ;
144
+ sender. subscribe ( )
145
+ } else {
146
+ tracing:: trace!( "creating new stream channel for {}" , e. key( ) ) ;
147
+ let ( tx, rx) = tokio:: sync:: broadcast:: channel ( self . buffer_capacity ) ;
148
+ e. insert ( tx. downgrade ( ) ) ;
149
+ let last_id = last_id. unwrap_or_else ( || "0" . to_string ( ) ) ;
150
+ let h = StreamHandle :: new ( stream_key. clone ( ) , last_id, tx) ;
151
+ self . buf . send ( h) . unwrap ( ) ;
152
+ self . metrics . reads_enqueued . inc ( ) ;
153
+ rx
154
+ }
141
155
} ,
142
156
} ;
143
157
StreamReader :: new ( rx)
@@ -188,6 +202,7 @@ struct Worker {
188
202
}
189
203
190
204
impl Worker {
205
+ #[ allow( clippy:: too_many_arguments) ]
191
206
fn new (
192
207
worker_id : usize ,
193
208
conn : Connection ,
@@ -196,6 +211,7 @@ impl Worker {
196
211
alive : Arc < AtomicBool > ,
197
212
options : & StreamRouterOptions ,
198
213
metrics : Arc < CollabStreamMetrics > ,
214
+ channels : Arc < DashMap < StreamKey , WeakStreamSender > > ,
199
215
) -> Self {
200
216
let mut xread_options = StreamReadOptions :: default ( ) ;
201
217
if let Some ( block_millis) = options. xread_block_millis {
@@ -206,13 +222,16 @@ impl Worker {
206
222
}
207
223
let count = options. xread_streams ;
208
224
let handle = std:: thread:: spawn ( move || {
209
- if let Err ( err) = Self :: process_streams ( conn, tx, rx, alive, xread_options, count, metrics) {
225
+ if let Err ( err) =
226
+ Self :: process_streams ( conn, tx, rx, alive, xread_options, count, metrics, channels)
227
+ {
210
228
tracing:: error!( "worker {} failed: {}" , worker_id, err) ;
211
229
}
212
230
} ) ;
213
231
Self { _handle : handle }
214
232
}
215
233
234
+ #[ allow( clippy:: too_many_arguments) ]
216
235
fn process_streams (
217
236
mut conn : Connection ,
218
237
tx : Sender < StreamHandle > ,
@@ -221,6 +240,7 @@ impl Worker {
221
240
options : StreamReadOptions ,
222
241
count : usize ,
223
242
metrics : Arc < CollabStreamMetrics > ,
243
+ channels : Arc < DashMap < StreamKey , WeakStreamSender > > ,
224
244
) -> RedisResult < ( ) > {
225
245
let mut stream_keys = Vec :: with_capacity ( count) ;
226
246
let mut message_ids = Vec :: with_capacity ( count) ;
@@ -259,6 +279,7 @@ impl Worker {
259
279
}
260
280
261
281
if remove_sender {
282
+ channels. remove ( & stream. key ) ;
262
283
senders. remove ( stream. key . as_str ( ) ) ;
263
284
}
264
285
}
@@ -270,7 +291,13 @@ impl Worker {
270
291
key_count
271
292
) ;
272
293
}
273
- let scheduled = Self :: schedule_back ( & tx, & mut stream_keys, & mut message_ids, & mut senders) ;
294
+ let scheduled = Self :: schedule_back (
295
+ & tx,
296
+ & mut stream_keys,
297
+ & mut message_ids,
298
+ & mut senders,
299
+ & channels,
300
+ ) ;
274
301
metrics. reads_enqueued . inc_by ( scheduled as u64 ) ;
275
302
}
276
303
Ok ( ( ) )
@@ -281,6 +308,7 @@ impl Worker {
281
308
keys : & mut Vec < StreamKey > ,
282
309
ids : & mut Vec < String > ,
283
310
senders : & mut HashMap < & str , ( StreamSender , usize ) > ,
311
+ channels : & DashMap < StreamKey , WeakStreamSender > ,
284
312
) -> usize {
285
313
let keys = keys. drain ( ..) ;
286
314
let mut ids = ids. drain ( ..) ;
@@ -289,6 +317,8 @@ impl Worker {
289
317
if let Some ( last_id) = ids. next ( ) {
290
318
if let Some ( ( sender, _) ) = senders. remove ( key. as_str ( ) ) {
291
319
if sender. receiver_count ( ) == 0 {
320
+ channels. remove ( key. as_str ( ) ) ;
321
+ tracing:: trace!( "no subscribers for {}, removing channel" , key) ;
292
322
continue ; // sender is already closed
293
323
}
294
324
let h = StreamHandle :: new ( key, last_id, sender) ;
@@ -345,6 +375,7 @@ impl Worker {
345
375
}
346
376
347
377
pub type RedisMap = HashMap < String , Value > ;
378
+ type WeakStreamSender = tokio:: sync:: broadcast:: WeakSender < Arc < ( String , RedisMap ) > > ;
348
379
type StreamSender = tokio:: sync:: broadcast:: Sender < Arc < ( String , RedisMap ) > > ;
349
380
type StreamReceiver = tokio:: sync:: broadcast:: Receiver < Arc < ( String , RedisMap ) > > ;
350
381
@@ -372,7 +403,9 @@ mod test {
372
403
use rand:: random;
373
404
use redis:: { Client , Commands , FromRedisValue } ;
374
405
use std:: sync:: Arc ;
406
+ use std:: time:: Duration ;
375
407
use tokio:: task:: JoinSet ;
408
+ use tokio:: time:: timeout;
376
409
377
410
struct TestMessage {
378
411
id : String ,
@@ -524,6 +557,52 @@ mod test {
524
557
tokio:: time:: sleep ( tokio:: time:: Duration :: from_millis ( 200 ) ) . await ;
525
558
assert_eq ! ( metrics. reads_enqueued. get( ) , enqueued, "unchanged enqueues" ) ;
526
559
assert_eq ! ( metrics. reads_dequeued. get( ) , dequeued, "unchanged dequeues" ) ;
560
+
561
+ assert ! ( router. channels. get( & key) . is_none( ) ) ;
562
+ }
563
+
564
+ #[ tokio:: test( flavor = "multi_thread" , worker_threads = 4 ) ]
565
+ async fn observe_unobserve_observe_again ( ) {
566
+ const ROUTES_COUNT : usize = 1 ;
567
+ const MSG_PER_ROUTE : usize = 10 ;
568
+
569
+ let mut client = Client :: open ( "redis://127.0.0.1/" ) . unwrap ( ) ;
570
+ let mut keys = init_streams ( & mut client, ROUTES_COUNT , MSG_PER_ROUTE ) ;
571
+ let metrics = Arc :: new ( CollabStreamMetrics :: default ( ) ) ;
572
+
573
+ let router = StreamRouter :: with_options (
574
+ & client,
575
+ metrics. clone ( ) ,
576
+ StreamRouterOptions {
577
+ worker_count : 2 ,
578
+ xread_streams : 100 ,
579
+ xread_block_millis : Some ( 50 ) ,
580
+ xread_count : Some ( MSG_PER_ROUTE / 2 ) ,
581
+ } ,
582
+ )
583
+ . unwrap ( ) ;
584
+
585
+ let key = keys. pop ( ) . unwrap ( ) ;
586
+ let mut observer = router. observe ( key. clone ( ) , None ) ;
587
+ // read half of the messages
588
+ for i in 0 ..MSG_PER_ROUTE {
589
+ let msg: TestMessage = observer. next ( ) . await . unwrap ( ) . unwrap ( ) ;
590
+ assert_eq ! ( msg. data, format!( "{}-{}" , key, i) ) ;
591
+ }
592
+ drop ( observer) ;
593
+
594
+ // try to overflow the tokio broadcast buffer by producing more messages
595
+ for i in 0 ..MSG_PER_ROUTE {
596
+ let data = format ! ( "{}-{}" , key, i) ;
597
+ let _: String = client. xadd ( & key, "*" , & [ ( "data" , data) ] ) . unwrap ( ) ;
598
+ }
599
+
600
+ let mut observer = router. observe ( key. clone ( ) , None ) ;
601
+ let t = Duration :: from_millis ( 100 ) ;
602
+ for i in 0 ..MSG_PER_ROUTE {
603
+ let msg: TestMessage = timeout ( t, observer. next ( ) ) . await . unwrap ( ) . unwrap ( ) . unwrap ( ) ;
604
+ assert_eq ! ( msg. data, format!( "{}-{}" , key, i) ) ;
605
+ }
527
606
}
528
607
529
608
fn init_streams ( client : & mut Client , stream_count : usize , msgs_per_stream : usize ) -> Vec < String > {
0 commit comments