diff --git a/spdy/listener.go b/spdy/listener.go index 570b69a..f531f7e 100644 --- a/spdy/listener.go +++ b/spdy/listener.go @@ -2,6 +2,9 @@ package spdy import ( "net" + "time" + "errors" + "fmt" ) // TransportListener is a listener which accepts new @@ -9,6 +12,7 @@ import ( type TransportListener struct { listener net.Listener auth Authenticator + timeout *time.Duration } // NewTransportListener creates a new listen transport using @@ -32,11 +36,36 @@ func (l *TransportListener) Close() error { // and creates a new stream. Connections which fail // authentication will not be returned. func (l *TransportListener) AcceptTransport() (*Transport, error) { + tranChan := make(chan interface{}) for { - conn, err := l.listener.Accept() - if err != nil { - return nil, err + // The timeout channel has a buffer of 1 + // to allow the timeout goroutine to exit + // if nothing is listening anymore. This prevents + // it form hanging forever waiting on a receiver. + timeoutChan := make(chan bool, 1) + // Launch listener wait inside a goroutine passing + // transport channel + go l.waitForAccept(tranChan) + // If timeout provided launch timeout goroutine with + // duration to wait and timeout channel + if l.timeout != nil { + go waitForTimeout(*l.timeout, timeoutChan) } + + // Wait for new connection channel or timeout channel + var conn net.Conn + select { + case x := <-tranChan: + if x, ok := x.(net.Conn); ok { + conn = x + } else if x, ok := x.(error); ok { + return nil, x + } + case <-timeoutChan: + // We have timed out + return nil, errors.New(fmt.Sprintf("listener timed out (%s)", l.timeout.String())) + } + authErr := l.auth(conn) if authErr != nil { // TODO log @@ -47,3 +76,25 @@ func (l *TransportListener) AcceptTransport() (*Transport, error) { return newSession(conn, true) } } + +func (l *TransportListener) waitForAccept(tranChan chan interface{}) { + conn, err := l.listener.Accept() + if err != nil { + tranChan<-err + } + tranChan<-conn + return +} + +// Sets the timeout for this listener. AcceptTransport() will return if no +// connection is opened for t amount of time. +func (l *TransportListener) SetTimeout(t time.Duration) { + l.timeout = &t +} + +// Function to wait for a timeout condition and then signal to a channel +func waitForTimeout(d time.Duration, timeoutChan chan bool) { + time.Sleep(d) + timeoutChan<-true +} + diff --git a/spdy/listener_test.go b/spdy/listener_test.go new file mode 100644 index 0000000..21399bf --- /dev/null +++ b/spdy/listener_test.go @@ -0,0 +1,57 @@ +package spdy + +import ( + "testing" + "time" + "net" +) + +// Test that server detects client session is dead while waiting for receive channel +func TestListenerTimeout(t *testing.T) { + // Test default behavior without timeout + noWait(t) + // Test enabling timeout behavior + withWait(t) +} + +func noWait(t *testing.T) { + // Start listener, ensure it doesn't throw error after 100 ms of no connecton + timeoutChan := make(chan bool) + go func() { + time.Sleep(time.Millisecond * 200) + close(timeoutChan) + }() + // Start listener, ensure it does throw error after 100 ms of no connection + listener, _ := net.Listen("tcp", "localhost:12945") + go func() { + transportListener, _ := NewTransportListener(listener, NoAuthenticator) + _, err := transportListener.AcceptTransport() + t.Fatal(err) + }() + <-timeoutChan +} + +func withWait(t *testing.T) { + timeoutChan := make(chan bool) + go func() { + time.Sleep(time.Millisecond * 200) + timeoutChan<-false + }() + // Start listener, ensure it does throw error after 100 ms of no connection + listener, _ := net.Listen("tcp", "localhost:12946") + go func() { + transportListener, _ := NewTransportListener(listener, NoAuthenticator) + transportListener.SetTimeout(time.Millisecond * 100) + _, err := transportListener.AcceptTransport() + if err.Error() != "listener timed out (100ms)" { + t.Fatal(err.Error() + ", should have timed out at (100ms)") + } + timeoutChan<-true + }() + select { + case ok := <-timeoutChan: + if !ok { + t.Fatal("timeout expected and did not occur") + } + } +} diff --git a/spdy/session.go b/spdy/session.go index dea1991..a1f9434 100644 --- a/spdy/session.go +++ b/spdy/session.go @@ -7,6 +7,8 @@ import ( "net/http" "strconv" "sync" + "time" + "fmt" "github.com/dmcgowan/go/codec" "github.com/docker/libchan" @@ -18,6 +20,18 @@ type direction uint8 const ( outbound = direction(0x01) inbound = direction(0x02) + // Defaults for heartbeat. + // Currently ping every 30 seconds and fail if no pings respond + // + // The frequency of pinging the client to + // detect a dead session. This is the default and + // can be overridden via property on the *Transport + defaultHeartbeatInterval = time.Second * 30 + // The amount of times the heartbeat fails + // before the session is considered dead. This is + // the default and can be overridden via property + // on the *Transport. + defaultHeartbeatLimit = 3 ) var ( @@ -29,9 +43,15 @@ var ( // Transport is a transport session on top of a network // connection using spdy. type Transport struct { + HeartbeatInterval time.Duration + HeartbeatLimit int + conn *spdystream.Connection handler codec.Handle + deadSessionChan chan struct{} + deadSessionFlag bool + receiverChan chan *channel channelC *sync.Cond channels map[uint64]*channel @@ -76,15 +96,18 @@ func newSession(conn net.Conn, server bool) (*Transport, error) { referenceCounter = 1 } session := &Transport{ - receiverChan: make(chan *channel), - channelC: sync.NewCond(new(sync.Mutex)), - channels: make(map[uint64]*channel), - referenceCounter: referenceCounter, - byteStreamC: sync.NewCond(new(sync.Mutex)), - byteStreams: make(map[uint64]*byteStream), - netConnC: sync.NewCond(new(sync.Mutex)), - netConns: make(map[byte]map[string]net.Conn), - networks: make(map[string]byte), + deadSessionChan: make(chan struct{}), + receiverChan: make(chan *channel), + channelC: sync.NewCond(new(sync.Mutex)), + channels: make(map[uint64]*channel), + referenceCounter: referenceCounter, + byteStreamC: sync.NewCond(new(sync.Mutex)), + byteStreams: make(map[uint64]*byteStream), + netConnC: sync.NewCond(new(sync.Mutex)), + netConns: make(map[byte]map[string]net.Conn), + networks: make(map[string]byte), + HeartbeatInterval: defaultHeartbeatInterval, + HeartbeatLimit: defaultHeartbeatLimit, } spdyConn, spdyErr := spdystream.NewConnection(conn, server) @@ -96,9 +119,57 @@ func newSession(conn net.Conn, server bool) (*Transport, error) { session.conn = spdyConn session.handler = session.initializeHandler() + // Looping heartbeat monitor. Pings the client to + // determine if it has lost connection without sending + // a close. + go session.monitorHeartbeat() + return session, nil } +// errDeadSession occurs when heartbeat is enabled and +// a ping returns an error trying to contact the client. +// This is useful for managing long runnning connections +// that may die form network failure. This is a method +// rather than a var to allow insertion of time elapsed +// dynamically. +func (s *Transport) errDeadSession() error { + return errors.New(fmt.Sprintf("session appears dead no response after %v", s.HeartbeatInterval*time.Duration(s.HeartbeatLimit))) +} + +func (s *Transport) monitorHeartbeat() { + var hbFailures int = 0 + for { + // Only loop after waiting for the heartbeatInterval + time.Sleep(s.HeartbeatInterval) + _, err := s.conn.Ping() + if err != nil { + // Increase heartbeat failure count + hbFailures++ + // If we have hit out limit on failures we trigger marking + // the session as dead. + if hbFailures >= s.HeartbeatLimit { + // Set the deadSessionFlag to true. This is used to + // check for a dead session before starting a blocking + // op using a channel. + s.deadSessionFlag = true + // Uses the closing of a channel trick to + // broadcast to all waiting threads that + // the session is dead. + // Any thread that needs to wait on a blocking + // op that is dependant on the session being live + // should implement a select that includes this + // channel closing as a signal. + close(s.deadSessionChan) + return + } + } else { + // Reset heartbeat failure count + hbFailures = 0 + } + } +} + func (s *Transport) newStreamHandler(stream *spdystream.Stream) { referenceIDString := stream.Headers().Get("libchan-ref") parentIDString := stream.Headers().Get("libchan-parent-ref") @@ -192,7 +263,7 @@ func (s *Transport) dial(referenceID uint64) (*byteStream, error) { func (s *Transport) nextReferenceID() uint64 { s.referenceLock.Lock() referenceID := s.referenceCounter - s.referenceCounter = referenceID + 2 + s.referenceCounter = referenceID+2 s.referenceLock.Unlock() return referenceID } @@ -306,12 +377,25 @@ func (s *Transport) NewSendChannel() (libchan.Sender, error) { // WaitReceiveChannel waits for a new channel be created by a remote // call to NewSendChannel. func (s *Transport) WaitReceiveChannel() (libchan.Receiver, error) { - r, ok := <-s.receiverChan - if !ok { - return nil, io.EOF - } + for { + // Safety check to see if session is dead before starting select + if s.deadSessionFlag { + return nil, s.errDeadSession() + } + // We use a select to wait for either the receiver channel + // or a dead session channel. + select { + case <-s.deadSessionChan: + // Return nil and ErrDeadSession + return nil, s.errDeadSession() + case r, ok := <-s.receiverChan: + if !ok { + return nil, io.EOF + } + return r, nil + } - return r, nil + } } func (c *channel) createSubChannel(direction direction) (libchan.Sender, libchan.Receiver, error) { @@ -387,19 +471,41 @@ func (c *channel) Receive(message interface{}) error { if c.direction == outbound { return ErrWrongDirection } - buf, readErr := c.stream.ReadData() - if readErr != nil { - if readErr == io.EOF { - c.stream.Close() + // Use a goroutine and channel to ReadData from channel + buffChan := make(chan interface{}) + go c.handleReadData(buffChan) + // Wait for channel response or signal that session is dead + select { + case <-c.session.deadSessionChan: + // Dead session + return c.session.errDeadSession() + case b := <-buffChan: + switch b.(type) { + case error: + if b.(error) == io.EOF { + c.stream.Close() + } + return b.(error) + case []byte: + decoder := codec.NewDecoderBytes(b.([]byte), c.session.handler) + decodeErr := decoder.Decode(message) + if decodeErr != nil { + return decodeErr + } + return nil + default: + panic("unknown type") } - return readErr } - decoder := codec.NewDecoderBytes(buf, c.session.handler) - decodeErr := decoder.Decode(message) - if decodeErr != nil { - return decodeErr +} + +func (c *channel) handleReadData(buffChan chan interface{}) { + buf, err := c.stream.ReadData() + if err != nil { + buffChan<-err + } else { + buffChan<-buf } - return nil } // Close closes the underlying stream, causing any subsequent diff --git a/spdy/session_test.go b/spdy/session_test.go index 75e40c8..4aaf150 100644 --- a/spdy/session_test.go +++ b/spdy/session_test.go @@ -399,3 +399,159 @@ func SpawnClientServerTest(t *testing.T, host string, client ClientRoutine, serv } } + +// Test that server detects client session is dead while waiting for receive channel +func TestHeartbeatWaitReceiveChannel(t *testing.T) { + testChan := make(chan struct {bool; string}) + + // Open connection and channel then close connection without closing transport + client := func() { + sleepChan := make(chan struct {}) + conn, _ := net.Dial("tcp", "localhost:12943") + NewClientTransport(conn) + go func() { + time.Sleep(time.Millisecond * 200) + close(sleepChan) + }() + <-sleepChan + conn.Close() + } + + // Check that dead session is detected from WaitReceiveChannel() + server := func() { + listener, _ := net.Listen("tcp", "localhost:12943") + transportListener, _ := NewTransportListener(listener, NoAuthenticator) + transport, err1 := transportListener.AcceptTransport() + if err1 != nil { + t.Fatal(err1) + } + // Shorten heartbeat for test speed reasons + transport.HeartbeatInterval = time.Millisecond*100 + transport.HeartbeatLimit = 3 + _, err2 := transport.WaitReceiveChannel() + if err2 != nil { + if err2.Error() == "session appears dead no response after 300ms" { + testChan<-struct{bool; string}{true, err2.Error()} + } else { + testChan<-struct{bool; string}{false, err2.Error()} + } + } else { + testChan<-struct{bool; string}{false, "No error thrown as expected"} + } + } + + go server() + time.Sleep(time.Millisecond * 100) + go client() + + x := <-testChan + if !x.bool { + t.Fatal(x.string) + } +} + +// Test that the dead session flag errors out +func TestDeadSessionFlagWaitReceiveChannel(t *testing.T) { + testChan := make(chan struct {bool; string}) + + // Open connection and channel then close connection without closing transport + client := func() { + sleepChan := make(chan struct {}) + conn, _ := net.Dial("tcp", "localhost:12950") + NewClientTransport(conn) + go func() { + time.Sleep(time.Millisecond * 200) + close(sleepChan) + }() + <-sleepChan + conn.Close() + } + + // Check that dead session is detected from WaitReceiveChannel() + server := func() { + listener, _ := net.Listen("tcp", "localhost:12950") + transportListener, _ := NewTransportListener(listener, NoAuthenticator) + transport, err1 := transportListener.AcceptTransport() + if err1 != nil { + t.Fatal(err1) + } + // Shorten heartbeat for test speed reasons + transport.HeartbeatInterval = time.Millisecond*100 + transport.HeartbeatLimit = 3 + // Need to use a goroutine to not block the transport + go func() { + time.Sleep(time.Millisecond * 500) + _, err2 := transport.WaitReceiveChannel() + if err2 != nil { + if err2.Error() == "session appears dead no response after 300ms" { + testChan<-struct{bool; string}{true, err2.Error()} + } else { + testChan<-struct{bool; string}{false, err2.Error()} + } + } else { + testChan<-struct{bool; string}{false, "No error thrown as expected"} + } + }() + } + + go server() + time.Sleep(time.Millisecond * 100) + go client() + + x := <-testChan + if !x.bool { + t.Fatal(x.string) + } +} + +// Test that server detects client session is dead while waiting for receiving data +func TestHeartbeatReceive(t *testing.T) { + testChan := make(chan struct {bool; string}) + + // Open connection and channel then close connection without closing transport + client := func() { + sleepChan := make(chan struct {}) + conn, _ := net.Dial("tcp", "localhost:12452") + transport, _ := NewClientTransport(conn) + // sender, _ := + transport.NewSendChannel() + go func() { + time.Sleep(time.Millisecond * 300) + close(sleepChan) + }() + <-sleepChan + conn.Close() + } + + // Check that dead session is detected from WaitReceiveChannel() + server := func() { + listener, _ := net.Listen("tcp", "localhost:12452") + transportListener, _ := NewTransportListener(listener, NoAuthenticator) + transport, err1 := transportListener.AcceptTransport() + if err1 != nil { + t.Fatal(err1) + } + // Shorten heartbeat for test speed reasons + transport.HeartbeatInterval = time.Millisecond*100 + transport.HeartbeatLimit = 3 + receiver, _ := transport.WaitReceiveChannel() + foo := &SimpleMessage{} + rerr := receiver.Receive(foo) + if rerr.Error() == "session appears dead no response after 300ms" { + testChan<-struct{bool; string}{true, rerr.Error()} + } else { + testChan<-struct{bool; string}{false, "No error thrown as expected"} + } + } + + go server() + time.Sleep(time.Millisecond * 100) + go client() + + x := <-testChan + if !x.bool { + t.Fatal(x.string) + } +} + +