Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ const (
defaultCascadeCircuitOpenTimeout = 0
defaultCascadeCircuitCounterReset = 1 * time.Second

defaultDelegatedRoutingCacheControlSuccessHeader = "public, max-age=300, s-maxage=300, stale-while-revalidate=60, stale-if-error=120"
defaultDelegatedRoutingCacheControlNotFoundHeader = "public, max-age=300, s-maxage=300, stale-while-revalidate=60, stale-if-error=120"

defaultStatMaxProviders = 10
defaultStatProviderReportUpdate = 5 * time.Minute

Expand Down Expand Up @@ -69,6 +72,10 @@ var config struct {
OpenTimeout time.Duration
CounterReset time.Duration
}
DelegatedRouting struct {
CacheControlSuccessHeader string
CacheControlNotFoundHeader string
}
}

func init() {
Expand All @@ -94,6 +101,9 @@ func init() {
config.CascadeCircuit.HalfOpenSuccesses = getEnvOrDefault[int]("CASCADE_CIRCUIT_HALF_OPEN_SUCCESSES", defaultCascadeCircuitHalfOpenSuccesses)
config.CascadeCircuit.OpenTimeout = getEnvOrDefault[time.Duration]("CASCADE_CIRCUIT_OPEN_TIMEOUT", defaultCascadeCircuitOpenTimeout)
config.CascadeCircuit.CounterReset = getEnvOrDefault[time.Duration]("CASCADE_CIRCUIT_COUNTER_RESET", defaultCascadeCircuitCounterReset)

config.DelegatedRouting.CacheControlSuccessHeader = getEnvOrDefault[string]("DELEGATED_ROUTING_CACHE_CONTROL_SUCCESS_HEADER", defaultDelegatedRoutingCacheControlSuccessHeader)
config.DelegatedRouting.CacheControlNotFoundHeader = getEnvOrDefault[string]("DELEGATED_ROUTING_CACHE_CONTROL_NOT_FOUND_HEADER", defaultDelegatedRoutingCacheControlNotFoundHeader)
}

func getEnvOrDefault[T any](key string, def T) T {
Expand Down
106 changes: 66 additions & 40 deletions delegated_translator.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,17 @@ const (
type findFunc func(ctx context.Context, method, source string, req *url.URL, encrypted bool) (int, []byte)
type findStreamFunc func(ctx context.Context, method string, req *url.URL, encrypted bool) (int, chan model.ProviderResult)

func NewDelegatedTranslator(backend findFunc, streamingBackend findStreamFunc) (http.Handler, error) {
func NewDelegatedTranslator(
backend findFunc,
streamingBackend findStreamFunc,
cacheControlSuccessHeader string,
cacheControlNotFoundHeader string,
) (http.Handler, error) {
finder := delegatedTranslator{
be: backend,
sbe: streamingBackend,
be: backend,
sbe: streamingBackend,
cacheControlSuccessHeader: cacheControlSuccessHeader,
cacheControlNotFoundHeader: cacheControlNotFoundHeader,
}
m := http.NewServeMux()
m.HandleFunc("/providers", finder.provide)
Expand All @@ -44,6 +51,9 @@ func NewDelegatedTranslator(backend findFunc, streamingBackend findStreamFunc) (
type delegatedTranslator struct {
be findFunc
sbe findStreamFunc

cacheControlSuccessHeader string
cacheControlNotFoundHeader string
}

func (dt *delegatedTranslator) provide(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -74,13 +84,19 @@ func (dt *delegatedTranslator) find(w http.ResponseWriter, r *http.Request, encr
h := w.Header()
h.Add("Access-Control-Allow-Origin", "*")
h.Add("Access-Control-Allow-Methods", "GET, OPTIONS")
h.Add("X-Content-Type-Options", "nosniff")
h.Add("Vary", "Accept")

Comment thread
byo marked this conversation as resolved.
switch r.Method {
case http.MethodGet:
// continue with the request

case http.MethodOptions:
w.WriteHeader(http.StatusOK)
return

default:
w.Header().Set("Allow", http.MethodGet)
h.Add("Allow", http.MethodGet+", "+http.MethodOptions)
http.Error(w, "", http.StatusMethodNotAllowed)
return
}
Expand All @@ -105,49 +121,58 @@ func (dt *delegatedTranslator) find(w http.ResponseWriter, r *http.Request, encr

switch {
case acc.ndjson:
rcode, respChan := dt.sbe(r.Context(), findMethodDelegated, uri, encrypted)
if rcode != http.StatusOK {
http.Error(w, "", rcode)
return
dt.findNDJSon(r, uri, encrypted, w, flt)

default:
dt.findJSON(r, uri, encrypted, w, flt)
}
}

func (dt *delegatedTranslator) findNDJSon(r *http.Request, uri *url.URL, encrypted bool, w http.ResponseWriter, flt *filters) {
rcode, respChan := dt.sbe(r.Context(), findMethodDelegated, uri, encrypted)
if rcode != http.StatusOK {
http.Error(w, "", rcode)
return
}

w.Header().Set("Content-Type", mediaTypeNDJson)
w.Header().Set("Cache-Control", dt.cacheControlSuccessHeader)

out := &drResp{}
encoder := json.NewEncoder(w)

for rcrd := range respChan {
prov := drProvFromResult(rcrd)

if !flt.apply(prov) {
// provider does not pass the filters, skip it.
continue
}

out := &drResp{}
hasWritten := false
encoder := json.NewEncoder(w)

for rcrd := range respChan {
prov := drProvFromResult(rcrd)

if !flt.apply(prov) {
// provider does not pass the filters, skip it.
continue
}

// if new
if out.append(prov) {
if !hasWritten {
w.Header().Set("Content-Type", mediaTypeNDJson)
w.Header().Set("Connection", "Keep-Alive")
w.Header().Set("X-Content-Type-Options", "nosniff")
w.WriteHeader(200)
hasWritten = true
}

if err := encoder.Encode(prov); err != nil {
return
}
}
if !out.append(prov) {
// duplicate provider, skip it.
continue
}
if len(out.seenProviders) == 0 {
// no response.
w.WriteHeader(http.StatusNotFound)

if err := encoder.Encode(prov); err != nil {
return
}
return
default:
}

if len(out.seenProviders) == 0 {
// no response.
w.Header().Set("Cache-Control", dt.cacheControlNotFoundHeader)
http.Error(w, "", http.StatusNotFound)
}
}

func (dt *delegatedTranslator) findJSON(r *http.Request, uri *url.URL, encrypted bool, w http.ResponseWriter, flt *filters) {
rcode, resp := dt.be(r.Context(), http.MethodGet, findMethodDelegated, uri, encrypted)
if rcode != http.StatusOK {
if rcode == http.StatusNotFound {
w.Header().Set("Cache-Control", dt.cacheControlNotFoundHeader)
http.Error(w, "", http.StatusNotFound)
return
} else if rcode != http.StatusOK {
http.Error(w, "", rcode)
return
}
Expand Down Expand Up @@ -196,6 +221,7 @@ func (dt *delegatedTranslator) find(w http.ResponseWriter, r *http.Request, encr
http.Error(w, "", http.StatusInternalServerError)
}

w.Header().Set("Cache-Control", dt.cacheControlSuccessHeader)
writeJsonResponse(w, http.StatusOK, outBytes)
}

Expand Down
7 changes: 6 additions & 1 deletion server.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,12 @@ func (s *server) Serve() chan error {
mux.Handle("/metrics", metrics.Start(nil))

ec := make(chan error)
delegated, err := NewDelegatedTranslator(s.doFind, s.doFindStreaming)
delegated, err := NewDelegatedTranslator(
s.doFind,
s.doFindStreaming,
config.DelegatedRouting.CacheControlSuccessHeader,
config.DelegatedRouting.CacheControlNotFoundHeader,
)
if err != nil {
ec <- err
close(ec)
Expand Down
171 changes: 170 additions & 1 deletion server_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"bytes"
"context"
"encoding/json"
"fmt"
Expand Down Expand Up @@ -232,7 +233,7 @@ func (s *serverTestSuite) TestStreamingFindMalformedBackend() {

data, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Empty(t, data)
require.Empty(t, bytes.TrimSpace(data))
}
}

Expand Down Expand Up @@ -306,3 +307,171 @@ func (s *serverTestSuite) TestLargeJSONResponse() {
require.NoError(t, err)
require.Len(t, response.Providers, 100)
}

func (s *serverTestSuite) TestDelegatedRoutingResponseHeaders() {
t := s.T()

const cidStr = "QmeLvFK9dBLhC3kbfc58mLntUei6s7fZUGWsm1xJhczm1S"

for _, dd := range []struct {
Name string

RequestNDJson bool
RequestMethod string

EmptyResponse bool
NoCacheControl bool

ExpectedContentType string
ExpectedStatusCode int
ExpectedAllowedMethods []string
}{
{
Name: "JSON response",
RequestNDJson: false,
RequestMethod: http.MethodGet,
EmptyResponse: false,
ExpectedContentType: "application/json",
ExpectedStatusCode: http.StatusOK,
},
{
Name: "NDJSON response",
RequestNDJson: true,
RequestMethod: http.MethodGet,
EmptyResponse: false,
ExpectedContentType: "application/x-ndjson",
ExpectedStatusCode: http.StatusOK,
},
{
Name: "Empty JSON response",
RequestNDJson: false,
RequestMethod: http.MethodGet,
EmptyResponse: true,
ExpectedContentType: "text/plain",
ExpectedStatusCode: http.StatusNotFound,
},
{
Name: "Empty NDJSON response",
RequestNDJson: true,
RequestMethod: http.MethodGet,
EmptyResponse: true,
ExpectedContentType: "text/plain",
ExpectedStatusCode: http.StatusNotFound,
},
{
Name: "Bad method for JSON",
RequestNDJson: false,
RequestMethod: http.MethodPost,
EmptyResponse: false,
NoCacheControl: true,
ExpectedContentType: "text/plain",
ExpectedStatusCode: http.StatusMethodNotAllowed,
ExpectedAllowedMethods: []string{"GET", "OPTIONS"},
},
{
Name: "Bad method for NDJSON",
RequestNDJson: true,
RequestMethod: http.MethodPost,
EmptyResponse: false,
NoCacheControl: true,
ExpectedContentType: "text/plain",
ExpectedStatusCode: http.StatusMethodNotAllowed,
ExpectedAllowedMethods: []string{"GET", "OPTIONS"},
},
} {
t.Run(dd.Name, func(t *testing.T) {
s.backendHandler = func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, `/cid/`+cidStr, r.URL.Path)

if dd.EmptyResponse {
http.Error(w, "", http.StatusNotFound)
return
}

if dd.RequestNDJson {
require.Equal(t, r.Header.Get("Accept"), "application/x-ndjson")
w.Header().Set("Content-Type", "application/x-ndjson")
writeOneLineJSON(t, w, `
{
"ContextID":"ctx1",
"Metadata":"gBI=",
"Provider":{
"ID":"12D3KooWAGjvuFgSMiSdivCnxifF23ovdqb8j8nzYiEcdy6quL6a",
"Addrs":[
"/ip4/1.2.3.4/tcp/30000"
]
}
}
`)
} else {
require.Equal(t, r.Header.Get("Accept"), "application/json")
w.Header().Set("Content-Type", "application/json")
writeOneLineJSON(t, w, `
{
"MultihashResults": [
{
"Multihash": "EiDtzI9MECNeznPpXjjXnrCpZ/Te+679GWm43DnGecaDIQ==",
"ProviderResults": [
{
"ContextID": "AXESIFBXwfY5v1krna9B2bzjlxEoRTG4avb/uIGFHJbGjtL4",
"Metadata": "oBIA",
"Provider": {
"ID": "12D3KooWAGjvuFgSMiSdivCnxifF23ovdqb8j8nzYiEcdy6quL6a",
"Addrs": [
"/ip4/1.2.3.4/tcp/30000"
]
}
}
]
}
]
}
`)
}
}

req, err := http.NewRequest(
dd.RequestMethod,
fmt.Sprintf("http://%s/routing/v1/providers/%s", s.srvListener.Addr(), cidStr),
nil,
)
require.NoError(t, err)

if dd.RequestNDJson {
req.Header.Set("Accept", "application/x-ndjson")
} else {
req.Header.Set("Accept", "application/json")
}

resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

require.Equal(t, resp.Header.Get("Access-Control-Allow-Origin"), "*")
require.Equal(t, resp.Header.Get("Access-Control-Allow-Methods"), "GET, OPTIONS")
require.Equal(t, resp.Header.Get("X-Content-Type-Options"), "nosniff")
require.Equal(t, resp.Header.Get("Vary"), "Accept")

if !dd.NoCacheControl {
cc := resp.Header.Get("Cache-Control")
require.Contains(t, cc, "public")
require.Contains(t, cc, "max-age")
require.Contains(t, cc, "s-maxage")
require.Contains(t, cc, "stale-while-revalidate")
require.Contains(t, cc, "stale-if-error")
}

require.Equal(t, dd.ExpectedStatusCode, resp.StatusCode)
require.Contains(t, resp.Header.Get("Content-Type"), dd.ExpectedContentType)

allowedMethods := []string{}
for _, method := range strings.Split(resp.Header.Get("Allow"), ",") {
if method = strings.TrimSpace(method); method != "" {
allowedMethods = append(allowedMethods, method)
}
}

require.ElementsMatch(t, dd.ExpectedAllowedMethods, allowedMethods)
})
}
}
Loading