diff --git a/iterable_channel.go b/iterable_channel.go index 1f29ff39..436b9b76 100644 --- a/iterable_channel.go +++ b/iterable_channel.go @@ -5,10 +5,16 @@ import ( "sync" ) +type subscription struct { + ctx context.Context + ch chan Item +} + type channelIterable struct { next <-chan Item opts []Option - subscribers []chan Item + nextSubscriberID uint64 + subscribers map[uint64]*subscription mutex sync.RWMutex producerAlreadyCreated bool } @@ -16,7 +22,7 @@ type channelIterable struct { func newChannelIterable(next <-chan Item, opts ...Option) Iterable { return &channelIterable{ next: next, - subscribers: make([]chan Item, 0), + subscribers: make(map[uint64]*subscription), opts: opts, } } @@ -34,9 +40,21 @@ func (i *channelIterable) Observe(opts ...Option) <-chan Item { return nil } + ch := i.createSubscription(option) + return ch +} + +func (i *channelIterable) createSubscription(option Option) chan Item { ch := option.buildChannel() + sctx := option.buildContext(emptyContext) + i.mutex.Lock() - i.subscribers = append(i.subscribers, ch) + sid := i.nextSubscriberID + i.nextSubscriberID++ + i.subscribers[sid] = &subscription{ + ctx: sctx, + ch: ch, + } i.mutex.Unlock() return ch } @@ -54,7 +72,7 @@ func (i *channelIterable) produce(ctx context.Context) { defer func() { i.mutex.RLock() for _, subscriber := range i.subscribers { - close(subscriber) + close(subscriber.ch) } i.mutex.RUnlock() }() @@ -67,11 +85,26 @@ func (i *channelIterable) produce(ctx context.Context) { if !ok { return } + toBeCleaned := make([]uint64, 0) i.mutex.RLock() - for _, subscriber := range i.subscribers { - subscriber <- item + for sid, subscriber := range i.subscribers { + select { + case <-subscriber.ctx.Done(): + toBeCleaned = append(toBeCleaned, sid) + case subscriber.ch <- item: + } } i.mutex.RUnlock() + + i.removeSubscriptions(toBeCleaned) } } } + +func (i *channelIterable) removeSubscriptions(sids []uint64) { + i.mutex.Lock() + defer i.mutex.Unlock() + for _, sid := range sids { + delete(i.subscribers, sid) + } +}