Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 119 additions & 33 deletions dealer/dealer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package dealer
import (
"context"
"encoding/json"
"errors"
"fmt"
"math"
"net/http"
Expand All @@ -19,6 +20,8 @@ const (
timeout = 10 * time.Second
)

var ErrDealerClosed = errors.New("dealer closed")

type Dealer struct {
log librespot.Logger

Expand All @@ -29,6 +32,7 @@ type Dealer struct {

conn *websocket.Conn

closed bool
stop bool
pingTickerStop chan struct{}
recvLoopStop chan struct{}
Expand Down Expand Up @@ -58,6 +62,8 @@ func NewDealer(log librespot.Logger, client *http.Client, dealerAddr librespot.G
log: log,
addr: dealerAddr,
accessToken: accessToken,
pingTickerStop: make(chan struct{}, 1),
recvLoopStop: make(chan struct{}, 1),
requestReceivers: map[string]requestReceiver{},
}
}
Expand All @@ -66,6 +72,10 @@ func (d *Dealer) Connect(ctx context.Context) error {
d.connMu.Lock()
defer d.connMu.Unlock()

if d.closed {
return ErrDealerClosed
}

if d.conn != nil && !d.stop {
d.log.Debugf("dealer connection already opened")
return nil
Expand All @@ -75,8 +85,6 @@ func (d *Dealer) Connect(ctx context.Context) error {
}

func (d *Dealer) connect(ctx context.Context) error {
d.recvLoopStop = make(chan struct{}, 1)
d.pingTickerStop = make(chan struct{}, 1)
d.stop = false

accessToken, err := d.accessToken(ctx, false)
Expand Down Expand Up @@ -106,27 +114,25 @@ func (d *Dealer) connect(ctx context.Context) error {

func (d *Dealer) Close() {
d.connMu.Lock()
defer d.connMu.Unlock()

d.closed = true
d.stop = true
conn := d.conn
d.connMu.Unlock()

if d.conn == nil {
return
}
d.signalStop()

d.recvLoopStop <- struct{}{}
d.pingTickerStop <- struct{}{}
_ = d.conn.Close(websocket.StatusGoingAway, "")
if conn != nil {
_ = conn.Close(websocket.StatusGoingAway, "")
}
}

func (d *Dealer) startReceiving() {
d.recvLoopOnce.Do(func() {
d.clearStopSignals()
d.log.Tracef("starting dealer recv loop")
go d.recvLoop()

// set last pong in the future
d.lastPong = time.Now().Add(pingInterval)
d.resetPongDeadline()
go d.pingTicker()
go d.recvLoop()
})
}

Expand All @@ -139,27 +145,23 @@ loop:
case <-d.pingTickerStop:
break loop
case <-ticker.C:
d.lastPongLock.Lock()
timePassed := time.Since(d.lastPong)
d.lastPongLock.Unlock()
timePassed := d.timeSinceLastPong()
if timePassed > pingInterval+timeout {
d.log.Errorf("did not receive last pong from dealer, %.0fs passed", timePassed.Seconds())

// closing the connection should make the read on the "recvLoop" fail,
// continue hoping for a new connection
_ = d.conn.Close(websocket.StatusServiceRestart, "")
d.closeConn(websocket.StatusServiceRestart)
continue
}

ctx, cancel := context.WithTimeout(context.Background(), timeout)
d.connMu.RLock()
err := d.conn.Write(ctx, websocket.MessageText, []byte("{\"type\":\"ping\"}"))
d.connMu.RUnlock()
conn, err := d.writeConn(ctx, websocket.MessageText, []byte("{\"type\":\"ping\"}"))
cancel()
d.log.Tracef("sent dealer ping")

if err != nil {
if d.stop {
if d.isStopped() {
// break early without logging if we should stop
break loop
}
Expand All @@ -168,7 +170,7 @@ loop:

// closing the connection should make the read on the "recvLoop" fail,
// continue hoping for a new connection
_ = d.conn.Close(websocket.StatusServiceRestart, "")
d.closeConnRef(conn, websocket.StatusServiceRestart)
continue
}
}
Expand All @@ -185,10 +187,10 @@ loop:
break loop
default:
// no need to hold the connMu since reconnection happens in this routine
msgType, messageBytes, err := d.conn.Read(context.Background())
msgType, messageBytes, err := d.readConn(context.Background())

// don't log closed error if we're stopping
if d.stop && websocket.CloseStatus(err) == websocket.StatusGoingAway {
if d.isStopped() && websocket.CloseStatus(err) == websocket.StatusGoingAway {
d.log.Debugf("dealer connection closed")
break loop
} else if err != nil {
Expand Down Expand Up @@ -229,10 +231,10 @@ loop:
}

// always close as we might end up here because of application errors
_ = d.conn.Close(websocket.StatusInternalError, "")
d.closeConn(websocket.StatusInternalError)

// if we shouldn't stop, try to reconnect
if !d.stop {
if !d.isStopped() {
d.connMu.Lock()
if err := backoff.Retry(d.reconnect, backoff.NewExponentialBackOff()); err != nil {
d.log.WithError(err).Errorf("failed reconnecting dealer")
Expand Down Expand Up @@ -273,9 +275,7 @@ func (d *Dealer) sendReply(key string, success bool) error {
}

ctx, cancel := context.WithTimeout(context.Background(), timeout)
d.connMu.RLock()
err = d.conn.Write(ctx, websocket.MessageText, replyBytes)
d.connMu.RUnlock()
_, err = d.writeConn(ctx, websocket.MessageText, replyBytes)
cancel()
if err != nil {
return fmt.Errorf("failed sending dealer reply: %w", err)
Expand All @@ -289,12 +289,98 @@ func (d *Dealer) reconnect() error {
return err
}

d.lastPongLock.Lock()
d.lastPong = time.Now()
d.lastPongLock.Unlock()
d.resetPongDeadline()
// restart the recv loop
go d.recvLoop()

d.log.Debugf("re-established dealer connection")
return nil
}

func (d *Dealer) resetPongDeadline() {
d.lastPongLock.Lock()
d.lastPong = time.Now().Add(pingInterval)
d.lastPongLock.Unlock()
}

func (d *Dealer) timeSinceLastPong() time.Duration {
d.lastPongLock.Lock()
defer d.lastPongLock.Unlock()
return time.Since(d.lastPong)
}

func (d *Dealer) closeConn(status websocket.StatusCode) {
d.connMu.RLock()
conn := d.conn
d.connMu.RUnlock()

d.closeConnRef(conn, status)
}

func (d *Dealer) closeConnRef(conn *websocket.Conn, status websocket.StatusCode) {
if conn != nil {
_ = conn.Close(status, "")
}
}

func (d *Dealer) writeConn(ctx context.Context, typ websocket.MessageType, payload []byte) (*websocket.Conn, error) {
d.connMu.RLock()

if d.closed {
d.connMu.RUnlock()
return nil, ErrDealerClosed
}

conn := d.conn

if conn == nil {
d.connMu.RUnlock()
return nil, fmt.Errorf("dealer connection not established")
}

err := conn.Write(ctx, typ, payload)
d.connMu.RUnlock()
return conn, err
}

func (d *Dealer) readConn(ctx context.Context) (websocket.MessageType, []byte, error) {
d.connMu.RLock()
conn := d.conn
d.connMu.RUnlock()

if conn == nil {
return 0, nil, fmt.Errorf("dealer connection not established")
}

return conn.Read(ctx)
}

func (d *Dealer) signalStop() {
select {
case d.recvLoopStop <- struct{}{}:
default:
}

select {
case d.pingTickerStop <- struct{}{}:
default:
}
}

func (d *Dealer) clearStopSignals() {
select {
case <-d.recvLoopStop:
default:
}

select {
case <-d.pingTickerStop:
default:
}
}

func (d *Dealer) isStopped() bool {
d.connMu.RLock()
defer d.connMu.RUnlock()
return d.stop
}
94 changes: 94 additions & 0 deletions dealer/dealer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package dealer

import (
"context"
"errors"
"testing"
"testing/synctest"
"time"

"github.com/coder/websocket"
librespot "github.com/devgianlu/go-librespot"
)

func TestPingTickerDoesNotPanicWhenConnNil(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
d := &Dealer{
log: &librespot.NullLogger{},
pingTickerStop: make(chan struct{}, 1),
}

panicCh := make(chan any, 1)
go func() {
defer func() {
panicCh <- recover()
}()
d.pingTicker()
}()

time.Sleep(pingInterval + timeout + time.Nanosecond)
synctest.Wait()

select {
case p := <-panicCh:
if p != nil {
t.Fatalf("pingTicker panicked when conn was nil: %v", p)
}
default:
}

d.pingTickerStop <- struct{}{}
synctest.Wait()

select {
case p := <-panicCh:
if p != nil {
t.Fatalf("pingTicker panicked when conn was nil: %v", p)
}
default:
t.Fatal("pingTicker did not stop")
}
})
}

func TestCloseStopsPingTickerWhenConnNil(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
d := &Dealer{
log: &librespot.NullLogger{},
pingTickerStop: make(chan struct{}, 1),
}

done := make(chan struct{})
go func() {
defer close(done)
d.pingTicker()
}()

synctest.Wait()
d.Close()
synctest.Wait()

stopped := false
select {
case <-done:
stopped = true
default:
}

d.pingTickerStop <- struct{}{}
synctest.Wait()

if !stopped {
t.Fatal("pingTicker did not stop when closing with nil conn")
}
})
}

func TestWriteConnRejectsClosedDealer(t *testing.T) {
d := &Dealer{closed: true}

_, err := d.writeConn(context.Background(), websocket.MessageText, nil)
if !errors.Is(err, ErrDealerClosed) {
t.Fatalf("expected ErrDealerClosed, got %v", err)
}
}
Loading
Loading