Skip to content

Commit 760a925

Browse files
Add recoverer on transformer (#180)
* Add recoverer on transformer * Add test on transformer server's recoverer * Remove redundancy of setting http status code
1 parent 463b440 commit 760a925

File tree

2 files changed

+62
-3
lines changed

2 files changed

+62
-3
lines changed

api/pkg/transformer/server/server.go

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"net/http/pprof"
1010
"os"
1111
"os/signal"
12+
"runtime/debug"
1213
"strings"
1314
"syscall"
1415
"time"
@@ -31,8 +32,10 @@ import (
3132

3233
const MerlinLogIdHeader = "X-Merlin-Log-Id"
3334

34-
var shutdownSignals = []os.Signal{os.Interrupt, syscall.SIGTERM}
35-
var onlyOneSignalHandler = make(chan struct{})
35+
var (
36+
shutdownSignals = []os.Signal{os.Interrupt, syscall.SIGTERM}
37+
onlyOneSignalHandler = make(chan struct{})
38+
)
3639

3740
var hystrixCommandName = "model_predict"
3841

@@ -247,9 +250,11 @@ func (s *Server) predict(ctx context.Context, r *http.Request, request []byte) (
247250

248251
// Run serves the HTTP endpoints.
249252
func (s *Server) Run() {
250-
// use default mux
253+
s.router.Use(recoveryHandler)
254+
251255
health := healthcheck.NewHandler()
252256
s.router.Handle("/", health)
257+
253258
s.router.Handle("/metrics", promhttp.Handler())
254259
s.router.PathPrefix("/debug/pprof/profile").HandlerFunc(pprof.Profile)
255260
s.router.PathPrefix("/debug/pprof/trace").HandlerFunc(pprof.Trace)
@@ -295,6 +300,19 @@ func (s *Server) Run() {
295300
}
296301
}
297302

303+
func recoveryHandler(next http.Handler) http.Handler {
304+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
305+
defer func() {
306+
if err := recover(); err != nil {
307+
debug.PrintStack()
308+
response.NewError(http.StatusInternalServerError, fmt.Errorf("panic: %v", err)).Write(w)
309+
}
310+
}()
311+
312+
next.ServeHTTP(w, r)
313+
})
314+
}
315+
298316
// setupSignalHandler registered for SIGTERM and SIGINT. A stop channel is returned
299317
// which is closed on one of these signals. If a second signal is caught, the program
300318
// is terminated with exit code 1.

api/pkg/transformer/server/server_test.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,17 @@ import (
55
"context"
66
"fmt"
77
"io/ioutil"
8+
"net"
89
"net/http"
910
"net/http/httptest"
11+
"strings"
1012
"testing"
1113
"time"
1214

1315
feastSdk "github.com/feast-dev/feast/sdk/go"
1416
"github.com/feast-dev/feast/sdk/go/protos/feast/serving"
1517
feastTypes "github.com/feast-dev/feast/sdk/go/protos/feast/types"
18+
"github.com/gorilla/mux"
1619
"github.com/stretchr/testify/assert"
1720
"github.com/stretchr/testify/mock"
1821
"go.uber.org/zap"
@@ -764,3 +767,41 @@ func respBody(t *testing.T, response *http.Response) string {
764767

765768
return string(respBody)
766769
}
770+
771+
func Test_recoveryHandler(t *testing.T) {
772+
router := mux.NewRouter()
773+
logger, _ := zap.NewDevelopment()
774+
775+
ts := httptest.NewServer(nil)
776+
defer ts.Close()
777+
778+
port := fmt.Sprint(ts.Listener.Addr().(*net.TCPAddr).Port)
779+
modelName := "test-panic"
780+
781+
s := &Server{
782+
router: router,
783+
logger: logger,
784+
options: &Options{
785+
Port: port,
786+
ModelName: modelName,
787+
},
788+
PreprocessHandler: func(ctx context.Context, rawRequest []byte, rawRequestHeaders map[string]string) ([]byte, error) {
789+
panic("panic at preprocess")
790+
return nil, nil
791+
},
792+
}
793+
go s.Run()
794+
795+
// Give some time for the server to run.
796+
time.Sleep(1 * time.Second)
797+
798+
resp, err := http.Post(fmt.Sprintf("http://localhost:%s/v1/models/%s:predict", port, modelName), "", strings.NewReader("{}"))
799+
assert.Nil(t, err)
800+
assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
801+
802+
respBody, err := ioutil.ReadAll(resp.Body)
803+
defer resp.Body.Close()
804+
805+
assert.Nil(t, err)
806+
assert.Equal(t, `{"code":500,"message":"panic: panic at preprocess"}`, string(respBody))
807+
}

0 commit comments

Comments
 (0)