diff --git a/streaming/cmd/chatter/http.go b/streaming/cmd/chatter/http.go index 4c9c3d37a..f4243d60a 100644 --- a/streaming/cmd/chatter/http.go +++ b/streaming/cmd/chatter/http.go @@ -55,7 +55,9 @@ func handleHTTPServer(ctx context.Context, u *url.URL, chatterEndpoints *chatter { eh := errorHandler(logger) upgrader := &websocket.Upgrader{} - chatterServer = chattersvr.New(chatterEndpoints, mux, dec, enc, eh, nil, upgrader, nil) + chatterConfigurer := chattersvr.NewConnConfigurer(nil) + chatterConfigurer.SubscribeFn = pingPonger(logger) + chatterServer = chattersvr.New(chatterEndpoints, mux, dec, enc, eh, nil, upgrader, chatterConfigurer) if debug { servers := goahttp.Servers{ chatterServer, @@ -111,7 +113,86 @@ func handleHTTPServer(ctx context.Context, u *url.URL, chatterEndpoints *chatter func errorHandler(logger *log.Logger) func(context.Context, http.ResponseWriter, error) { return func(ctx context.Context, w http.ResponseWriter, err error) { id := ctx.Value(middleware.RequestIDKey).(string) - _, _ = w.Write([]byte("[" + id + "] encoding: " + err.Error())) + _, writeErr := w.Write([]byte("[" + id + "] encoding: " + err.Error())) + if writeErr != nil { + logger.Printf("[%s] ERROR: failed to write error response: %s", id, writeErr.Error()) + } logger.Printf("[%s] ERROR: %s", id, err.Error()) } } + +// pingPonger configures the websocket connection to check the health of the +// connection between client and server. It periodically sends a ping message +// to the client and if the client does not respond with a pong within a +// specified time, it closes the websocket connection and cancels the request +// context. +// +// NOTE: This is suitable for use only in server-side streaming endpoints +// (i.e. client does NOT send any messages through the stream), because it +// reads the websocket connection for pong messages from the client. If this is +// used in any endpoint where the client streams, it will result in lost +// messages from the client which is undesirable. +func pingPonger(logger *log.Logger) goahttp.ConnConfigureFunc { + pingInterval := 3 * time.Second + return goahttp.ConnConfigureFunc(func(conn *websocket.Conn, cancel context.CancelFunc) *websocket.Conn { + // errc is the channel read by ping-ponger to check if there were any + // errors when reading messages sent by the client from the websocket. + errc := make(chan error) + + // Start a goroutine to read messages sent by the client from the + // websocket connection. This will pick up any pong message sent + // by the client. Send any errors to errc. + go func() { + for { + if _, _, err := conn.ReadMessage(); err != nil { + logger.Printf("error reading messages from client: %v", err) + errc <- err + return + } + } + }() + + // Start the pinger in a separate goroutine. Read any errors in the + // error channel and stop the goroutine when error received. Close the + // websocket connection and cancel the request when client when error + // received. + go func() { + ticker := time.NewTicker(pingInterval) + defer func() { + ticker.Stop() + logger.Printf("client did not respond with pong") + // cancel the request context when timer expires + cancel() + }() + + // Set a read deadline to read pong messages from the client. + // If a client fails to send a pong before the deadline any + // further connection reads will result in an error. We exit the + // goroutine if connection reads error out. + conn.SetReadDeadline(time.Now().Add(pingInterval * 2)) + + // set a custom pong handler + pongFn := conn.PongHandler() + conn.SetPongHandler(func(appData string) error { + logger.Printf("client says pong") + // Reset the read deadline + conn.SetReadDeadline(time.Now().Add(pingInterval * 2)) + return pongFn(appData) + }) + + for { + select { + case <-errc: + return + case <-ticker.C: + // send periodic ping message + if err := conn.WriteControl(websocket.PingMessage, []byte("ping"), time.Now().Add(time.Second)); err != nil { + return + } + logger.Printf("pinged client") + } + } + }() + return conn + }) +}