diff --git a/config.go b/config.go index 3e55ecb..cea665c 100644 --- a/config.go +++ b/config.go @@ -30,6 +30,9 @@ const ( defaultCascadeCircuitOpenTimeout = 0 defaultCascadeCircuitCounterReset = 1 * time.Second + defaultDelegatedRoutingCacheControlSuccessHeader = "public, max-age=300, s-maxage=300, stale-while-revalidate=60, stale-if-error=120" + defaultDelegatedRoutingCacheControlNotFoundHeader = "public, max-age=300, s-maxage=300, stale-while-revalidate=60, stale-if-error=120" + defaultStatMaxProviders = 10 defaultStatProviderReportUpdate = 5 * time.Minute @@ -69,6 +72,10 @@ var config struct { OpenTimeout time.Duration CounterReset time.Duration } + DelegatedRouting struct { + CacheControlSuccessHeader string + CacheControlNotFoundHeader string + } } func init() { @@ -94,6 +101,9 @@ func init() { config.CascadeCircuit.HalfOpenSuccesses = getEnvOrDefault[int]("CASCADE_CIRCUIT_HALF_OPEN_SUCCESSES", defaultCascadeCircuitHalfOpenSuccesses) config.CascadeCircuit.OpenTimeout = getEnvOrDefault[time.Duration]("CASCADE_CIRCUIT_OPEN_TIMEOUT", defaultCascadeCircuitOpenTimeout) config.CascadeCircuit.CounterReset = getEnvOrDefault[time.Duration]("CASCADE_CIRCUIT_COUNTER_RESET", defaultCascadeCircuitCounterReset) + + config.DelegatedRouting.CacheControlSuccessHeader = getEnvOrDefault[string]("DELEGATED_ROUTING_CACHE_CONTROL_SUCCESS_HEADER", defaultDelegatedRoutingCacheControlSuccessHeader) + config.DelegatedRouting.CacheControlNotFoundHeader = getEnvOrDefault[string]("DELEGATED_ROUTING_CACHE_CONTROL_NOT_FOUND_HEADER", defaultDelegatedRoutingCacheControlNotFoundHeader) } func getEnvOrDefault[T any](key string, def T) T { diff --git a/delegated_translator.go b/delegated_translator.go index f83481d..53b3ff0 100644 --- a/delegated_translator.go +++ b/delegated_translator.go @@ -28,10 +28,17 @@ const ( type findFunc func(ctx context.Context, method, source string, req *url.URL, encrypted bool) (int, []byte) type findStreamFunc func(ctx context.Context, method string, req *url.URL, encrypted bool) (int, chan model.ProviderResult) -func NewDelegatedTranslator(backend findFunc, streamingBackend findStreamFunc) (http.Handler, error) { +func NewDelegatedTranslator( + backend findFunc, + streamingBackend findStreamFunc, + cacheControlSuccessHeader string, + cacheControlNotFoundHeader string, +) (http.Handler, error) { finder := delegatedTranslator{ - be: backend, - sbe: streamingBackend, + be: backend, + sbe: streamingBackend, + cacheControlSuccessHeader: cacheControlSuccessHeader, + cacheControlNotFoundHeader: cacheControlNotFoundHeader, } m := http.NewServeMux() m.HandleFunc("/providers", finder.provide) @@ -44,6 +51,9 @@ func NewDelegatedTranslator(backend findFunc, streamingBackend findStreamFunc) ( type delegatedTranslator struct { be findFunc sbe findStreamFunc + + cacheControlSuccessHeader string + cacheControlNotFoundHeader string } func (dt *delegatedTranslator) provide(w http.ResponseWriter, r *http.Request) { @@ -74,13 +84,19 @@ func (dt *delegatedTranslator) find(w http.ResponseWriter, r *http.Request, encr h := w.Header() h.Add("Access-Control-Allow-Origin", "*") h.Add("Access-Control-Allow-Methods", "GET, OPTIONS") + h.Add("X-Content-Type-Options", "nosniff") + h.Add("Vary", "Accept") + switch r.Method { case http.MethodGet: + // continue with the request + case http.MethodOptions: w.WriteHeader(http.StatusOK) return + default: - w.Header().Set("Allow", http.MethodGet) + h.Add("Allow", http.MethodGet+", "+http.MethodOptions) http.Error(w, "", http.StatusMethodNotAllowed) return } @@ -105,49 +121,58 @@ func (dt *delegatedTranslator) find(w http.ResponseWriter, r *http.Request, encr switch { case acc.ndjson: - rcode, respChan := dt.sbe(r.Context(), findMethodDelegated, uri, encrypted) - if rcode != http.StatusOK { - http.Error(w, "", rcode) - return + dt.findNDJSon(r, uri, encrypted, w, flt) + + default: + dt.findJSON(r, uri, encrypted, w, flt) + } +} + +func (dt *delegatedTranslator) findNDJSon(r *http.Request, uri *url.URL, encrypted bool, w http.ResponseWriter, flt *filters) { + rcode, respChan := dt.sbe(r.Context(), findMethodDelegated, uri, encrypted) + if rcode != http.StatusOK { + http.Error(w, "", rcode) + return + } + + w.Header().Set("Content-Type", mediaTypeNDJson) + w.Header().Set("Cache-Control", dt.cacheControlSuccessHeader) + + out := &drResp{} + encoder := json.NewEncoder(w) + + for rcrd := range respChan { + prov := drProvFromResult(rcrd) + + if !flt.apply(prov) { + // provider does not pass the filters, skip it. + continue } - out := &drResp{} - hasWritten := false - encoder := json.NewEncoder(w) - - for rcrd := range respChan { - prov := drProvFromResult(rcrd) - - if !flt.apply(prov) { - // provider does not pass the filters, skip it. - continue - } - - // if new - if out.append(prov) { - if !hasWritten { - w.Header().Set("Content-Type", mediaTypeNDJson) - w.Header().Set("Connection", "Keep-Alive") - w.Header().Set("X-Content-Type-Options", "nosniff") - w.WriteHeader(200) - hasWritten = true - } - - if err := encoder.Encode(prov); err != nil { - return - } - } + if !out.append(prov) { + // duplicate provider, skip it. + continue } - if len(out.seenProviders) == 0 { - // no response. - w.WriteHeader(http.StatusNotFound) + + if err := encoder.Encode(prov); err != nil { + return } - return - default: } + if len(out.seenProviders) == 0 { + // no response. + w.Header().Set("Cache-Control", dt.cacheControlNotFoundHeader) + http.Error(w, "", http.StatusNotFound) + } +} + +func (dt *delegatedTranslator) findJSON(r *http.Request, uri *url.URL, encrypted bool, w http.ResponseWriter, flt *filters) { rcode, resp := dt.be(r.Context(), http.MethodGet, findMethodDelegated, uri, encrypted) - if rcode != http.StatusOK { + if rcode == http.StatusNotFound { + w.Header().Set("Cache-Control", dt.cacheControlNotFoundHeader) + http.Error(w, "", http.StatusNotFound) + return + } else if rcode != http.StatusOK { http.Error(w, "", rcode) return } @@ -196,6 +221,7 @@ func (dt *delegatedTranslator) find(w http.ResponseWriter, r *http.Request, encr http.Error(w, "", http.StatusInternalServerError) } + w.Header().Set("Cache-Control", dt.cacheControlSuccessHeader) writeJsonResponse(w, http.StatusOK, outBytes) } diff --git a/server.go b/server.go index e94bf6a..7e3a2ef 100644 --- a/server.go +++ b/server.go @@ -279,7 +279,12 @@ func (s *server) Serve() chan error { mux.Handle("/metrics", metrics.Start(nil)) ec := make(chan error) - delegated, err := NewDelegatedTranslator(s.doFind, s.doFindStreaming) + delegated, err := NewDelegatedTranslator( + s.doFind, + s.doFindStreaming, + config.DelegatedRouting.CacheControlSuccessHeader, + config.DelegatedRouting.CacheControlNotFoundHeader, + ) if err != nil { ec <- err close(ec) diff --git a/server_test.go b/server_test.go index 65188db..ed35557 100644 --- a/server_test.go +++ b/server_test.go @@ -1,6 +1,7 @@ package main import ( + "bytes" "context" "encoding/json" "fmt" @@ -232,7 +233,7 @@ func (s *serverTestSuite) TestStreamingFindMalformedBackend() { data, err := io.ReadAll(resp.Body) require.NoError(t, err) - require.Empty(t, data) + require.Empty(t, bytes.TrimSpace(data)) } } @@ -306,3 +307,171 @@ func (s *serverTestSuite) TestLargeJSONResponse() { require.NoError(t, err) require.Len(t, response.Providers, 100) } + +func (s *serverTestSuite) TestDelegatedRoutingResponseHeaders() { + t := s.T() + + const cidStr = "QmeLvFK9dBLhC3kbfc58mLntUei6s7fZUGWsm1xJhczm1S" + + for _, dd := range []struct { + Name string + + RequestNDJson bool + RequestMethod string + + EmptyResponse bool + NoCacheControl bool + + ExpectedContentType string + ExpectedStatusCode int + ExpectedAllowedMethods []string + }{ + { + Name: "JSON response", + RequestNDJson: false, + RequestMethod: http.MethodGet, + EmptyResponse: false, + ExpectedContentType: "application/json", + ExpectedStatusCode: http.StatusOK, + }, + { + Name: "NDJSON response", + RequestNDJson: true, + RequestMethod: http.MethodGet, + EmptyResponse: false, + ExpectedContentType: "application/x-ndjson", + ExpectedStatusCode: http.StatusOK, + }, + { + Name: "Empty JSON response", + RequestNDJson: false, + RequestMethod: http.MethodGet, + EmptyResponse: true, + ExpectedContentType: "text/plain", + ExpectedStatusCode: http.StatusNotFound, + }, + { + Name: "Empty NDJSON response", + RequestNDJson: true, + RequestMethod: http.MethodGet, + EmptyResponse: true, + ExpectedContentType: "text/plain", + ExpectedStatusCode: http.StatusNotFound, + }, + { + Name: "Bad method for JSON", + RequestNDJson: false, + RequestMethod: http.MethodPost, + EmptyResponse: false, + NoCacheControl: true, + ExpectedContentType: "text/plain", + ExpectedStatusCode: http.StatusMethodNotAllowed, + ExpectedAllowedMethods: []string{"GET", "OPTIONS"}, + }, + { + Name: "Bad method for NDJSON", + RequestNDJson: true, + RequestMethod: http.MethodPost, + EmptyResponse: false, + NoCacheControl: true, + ExpectedContentType: "text/plain", + ExpectedStatusCode: http.StatusMethodNotAllowed, + ExpectedAllowedMethods: []string{"GET", "OPTIONS"}, + }, + } { + t.Run(dd.Name, func(t *testing.T) { + s.backendHandler = func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, `/cid/`+cidStr, r.URL.Path) + + if dd.EmptyResponse { + http.Error(w, "", http.StatusNotFound) + return + } + + if dd.RequestNDJson { + require.Equal(t, r.Header.Get("Accept"), "application/x-ndjson") + w.Header().Set("Content-Type", "application/x-ndjson") + writeOneLineJSON(t, w, ` + { + "ContextID":"ctx1", + "Metadata":"gBI=", + "Provider":{ + "ID":"12D3KooWAGjvuFgSMiSdivCnxifF23ovdqb8j8nzYiEcdy6quL6a", + "Addrs":[ + "/ip4/1.2.3.4/tcp/30000" + ] + } + } + `) + } else { + require.Equal(t, r.Header.Get("Accept"), "application/json") + w.Header().Set("Content-Type", "application/json") + writeOneLineJSON(t, w, ` + { + "MultihashResults": [ + { + "Multihash": "EiDtzI9MECNeznPpXjjXnrCpZ/Te+679GWm43DnGecaDIQ==", + "ProviderResults": [ + { + "ContextID": "AXESIFBXwfY5v1krna9B2bzjlxEoRTG4avb/uIGFHJbGjtL4", + "Metadata": "oBIA", + "Provider": { + "ID": "12D3KooWAGjvuFgSMiSdivCnxifF23ovdqb8j8nzYiEcdy6quL6a", + "Addrs": [ + "/ip4/1.2.3.4/tcp/30000" + ] + } + } + ] + } + ] + } + `) + } + } + + req, err := http.NewRequest( + dd.RequestMethod, + fmt.Sprintf("http://%s/routing/v1/providers/%s", s.srvListener.Addr(), cidStr), + nil, + ) + require.NoError(t, err) + + if dd.RequestNDJson { + req.Header.Set("Accept", "application/x-ndjson") + } else { + req.Header.Set("Accept", "application/json") + } + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, resp.Header.Get("Access-Control-Allow-Origin"), "*") + require.Equal(t, resp.Header.Get("Access-Control-Allow-Methods"), "GET, OPTIONS") + require.Equal(t, resp.Header.Get("X-Content-Type-Options"), "nosniff") + require.Equal(t, resp.Header.Get("Vary"), "Accept") + + if !dd.NoCacheControl { + cc := resp.Header.Get("Cache-Control") + require.Contains(t, cc, "public") + require.Contains(t, cc, "max-age") + require.Contains(t, cc, "s-maxage") + require.Contains(t, cc, "stale-while-revalidate") + require.Contains(t, cc, "stale-if-error") + } + + require.Equal(t, dd.ExpectedStatusCode, resp.StatusCode) + require.Contains(t, resp.Header.Get("Content-Type"), dd.ExpectedContentType) + + allowedMethods := []string{} + for _, method := range strings.Split(resp.Header.Get("Allow"), ",") { + if method = strings.TrimSpace(method); method != "" { + allowedMethods = append(allowedMethods, method) + } + } + + require.ElementsMatch(t, dd.ExpectedAllowedMethods, allowedMethods) + }) + } +}