Skip to content

Commit be54b7d

Browse files
committed
add global cache key and prefix
1 parent a2af5cd commit be54b7d

File tree

5 files changed

+241
-5
lines changed

5 files changed

+241
-5
lines changed

config/config.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ type CacheConfig struct {
3838
ShouldHashQuery bool
3939
HashQueryIgnore map[string]bool
4040
HashHeaders []string
41+
GlobalCacheKeys map[string]string // path pattern -> global key
4142
}
4243

4344
type Config struct {
@@ -63,6 +64,7 @@ func New() Config {
6364
ShouldHashQuery: getEnvAsBool("CACHE_SHOULD_HASH_QUERY", "true"),
6465
HashQueryIgnore: hashQueryIgnoreMap(getEnvAsSlice("CACHE_HASH_QUERY_IGNORE")),
6566
HashHeaders: getEnvAsSlice("CACHE_HASH_HEADERS"),
67+
GlobalCacheKeys: parseGlobalCacheKeys(getEnv("CACHE_GLOBAL_KEYS", "/public:static-files,/_next:nextjs-assets")),
6668
}
6769

6870
if strings.ToLower(serverConfig.Storage) == "memory" {

config/helpers.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,35 @@ func hashQueryIgnoreMap(queryIgnore []string) map[string]bool {
1818
return hashQueryIgnoreMap
1919
}
2020

21+
func parseGlobalCacheKeys(globalKeysStr string) map[string]string {
22+
globalKeys := make(map[string]string)
23+
if globalKeysStr == "" {
24+
return globalKeys
25+
}
26+
27+
// Expected format: "pattern1:key1,pattern2:key2"
28+
// Example: "/assets:static-assets,/_next:nextjs-assets,/static:static-files"
29+
pairs := strings.SplitSeq(globalKeysStr, ",")
30+
for pair := range pairs {
31+
trimmedPair := strings.TrimSpace(pair)
32+
if trimmedPair == "" {
33+
continue
34+
}
35+
36+
// Split on colon and ensure exactly one colon exists
37+
parts := strings.Split(trimmedPair, ":")
38+
if len(parts) == 2 {
39+
pattern := strings.TrimSpace(parts[0])
40+
key := strings.TrimSpace(parts[1])
41+
if pattern != "" && key != "" {
42+
globalKeys[pattern] = key
43+
}
44+
}
45+
// Ignore malformed pairs with 0, 1, or more than 2 parts
46+
}
47+
return globalKeys
48+
}
49+
2150
func getEnv(key, defaultVal string) string {
2251
if value, exists := os.LookupEnv(key); exists {
2352
return strings.TrimSpace(value)

config/helpers_test.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,4 +158,64 @@ func TestConfigHelpers(t *testing.T) {
158158
})
159159
})
160160
})
161+
t.Run("parseGlobalCacheKeys", func(t *testing.T) {
162+
t.Run("should parse valid global cache keys", func(t *testing.T) {
163+
input := "/assets:static-assets,/_next:nextjs-assets,/static:static-files"
164+
got := parseGlobalCacheKeys(input)
165+
166+
expected := map[string]string{
167+
"/assets": "static-assets",
168+
"/_next": "nextjs-assets",
169+
"/static": "static-files",
170+
}
171+
172+
assert.Equal(t, expected, got)
173+
})
174+
t.Run("should handle empty string", func(t *testing.T) {
175+
got := parseGlobalCacheKeys("")
176+
assert.Equal(t, map[string]string{}, got)
177+
})
178+
t.Run("should handle single pair", func(t *testing.T) {
179+
input := "/assets:static-assets"
180+
got := parseGlobalCacheKeys(input)
181+
182+
expected := map[string]string{
183+
"/assets": "static-assets",
184+
}
185+
186+
assert.Equal(t, expected, got)
187+
})
188+
t.Run("should handle spaces around delimiters", func(t *testing.T) {
189+
input := " /assets : static-assets , /_next : nextjs-assets "
190+
got := parseGlobalCacheKeys(input)
191+
192+
expected := map[string]string{
193+
"/assets": "static-assets",
194+
"/_next": "nextjs-assets",
195+
}
196+
197+
assert.Equal(t, expected, got)
198+
})
199+
t.Run("should ignore malformed pairs", func(t *testing.T) {
200+
input := "/assets:static-assets,invalid,/_next:nextjs-assets,also:invalid:format"
201+
got := parseGlobalCacheKeys(input)
202+
203+
expected := map[string]string{
204+
"/assets": "static-assets",
205+
"/_next": "nextjs-assets",
206+
}
207+
208+
assert.Equal(t, expected, got)
209+
})
210+
t.Run("should ignore empty keys or values", func(t *testing.T) {
211+
input := ":empty-key,empty-value:,/assets:static-assets"
212+
got := parseGlobalCacheKeys(input)
213+
214+
expected := map[string]string{
215+
"/assets": "static-assets",
216+
}
217+
218+
assert.Equal(t, expected, got)
219+
})
220+
})
161221
}

pkg/http/cache.go

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,13 @@ func cloneHeaders(src http.Header) http.Header {
9494
}
9595

9696
func (h handler) getCacheKey(req *http.Request) string {
97+
// Check for global cache keys first
98+
for pattern, globalKey := range h.cfg.GlobalCacheKeys {
99+
if strings.HasPrefix(req.URL.Path, pattern) {
100+
return globalKey
101+
}
102+
}
103+
97104
hash := sha256.New()
98105
hash.Write([]byte(req.Method))
99106
hash.Write([]byte(":"))
@@ -240,12 +247,13 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
240247
ctx := logger.WithContext(r.Context())
241248

242249
if strings.ToUpper(r.Method) != http.MethodGet || r.Header.Get("Range") != "" {
250+
logger.Debug().Msg("skipping cache")
243251
h.proxy.ServeHTTP(w, r)
244252
return
245253
}
246254

247255
if isWebSocket(r) {
248-
logger.Info().Msg("skip cache: websocket")
256+
logger.Debug().Msg("skipping cache: websocket")
249257
h.proxy.ServeHTTP(w, r)
250258
return
251259
}
@@ -282,10 +290,11 @@ func (h handler) cacheResponse(ctx context.Context, key string) func(*http.Respo
282290
if !(resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusNoContent) {
283291
return nil
284292
}
285-
cc := resp.Header.Get("Cache-Control")
286-
if strings.Contains(cc, "no-store") || strings.Contains(cc, "private") {
287-
return nil
288-
}
293+
// TODO: add cache control back in
294+
// cc := resp.Header.Get("Cache-Control")
295+
// if strings.Contains(cc, "no-store") || strings.Contains(cc, "private") {
296+
// return nil
297+
// }
289298
if resp.Header.Get("Set-Cookie") != "" {
290299
return nil
291300
}

pkg/http/cache_test.go

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,3 +465,139 @@ func TestGetCacheKey(t *testing.T) {
465465
})
466466
}
467467
}
468+
469+
func TestGetCacheKeyWithGlobalKeys(t *testing.T) {
470+
tests := []struct {
471+
name string
472+
path string
473+
method string
474+
query string
475+
headers map[string]string
476+
globalCacheKeys map[string]string
477+
want string
478+
}{
479+
{
480+
name: "global key for assets path",
481+
path: "/assets/script.js",
482+
method: "GET",
483+
globalCacheKeys: map[string]string{
484+
"/assets": "static-assets",
485+
},
486+
want: "static-assets",
487+
},
488+
{
489+
name: "global key for _next path",
490+
path: "/_next/static/chunk.js",
491+
method: "GET",
492+
globalCacheKeys: map[string]string{
493+
"/_next": "nextjs-assets",
494+
},
495+
want: "nextjs-assets",
496+
},
497+
{
498+
name: "global key for static path",
499+
path: "/static/images/logo.png",
500+
method: "GET",
501+
globalCacheKeys: map[string]string{
502+
"/static": "static-files",
503+
},
504+
want: "static-files",
505+
},
506+
{
507+
name: "multiple global keys - first match wins",
508+
path: "/assets/css/style.css",
509+
method: "GET",
510+
globalCacheKeys: map[string]string{
511+
"/assets": "assets-key",
512+
"/assets/css": "css-key",
513+
},
514+
want: "assets-key", // First match based on map iteration
515+
},
516+
{
517+
name: "no matching global key - falls back to hash",
518+
path: "/api/users",
519+
method: "GET",
520+
query: "id=1",
521+
globalCacheKeys: map[string]string{
522+
"/assets": "static-assets",
523+
"/_next": "nextjs-assets",
524+
},
525+
want: "hashed", // Will be calculated as hash
526+
},
527+
{
528+
name: "global key ignores query parameters and headers",
529+
path: "/assets/script.js",
530+
method: "GET",
531+
query: "version=1.0&cache=false",
532+
headers: map[string]string{
533+
"Authorization": "Bearer token123",
534+
},
535+
globalCacheKeys: map[string]string{
536+
"/assets": "static-assets",
537+
},
538+
want: "static-assets",
539+
},
540+
{
541+
name: "partial path match - should not match",
542+
path: "/asset", // Missing 's' from '/assets'
543+
method: "GET",
544+
globalCacheKeys: map[string]string{
545+
"/assets": "static-assets",
546+
},
547+
want: "hashed", // Will be calculated as hash
548+
},
549+
{
550+
name: "exact prefix match",
551+
path: "/assets",
552+
method: "GET",
553+
globalCacheKeys: map[string]string{
554+
"/assets": "static-assets",
555+
},
556+
want: "static-assets",
557+
},
558+
}
559+
560+
for _, tt := range tests {
561+
t.Run(tt.name, func(t *testing.T) {
562+
// Create request
563+
req, err := http.NewRequest(tt.method, tt.path, nil)
564+
if err != nil {
565+
t.Fatal(err)
566+
}
567+
568+
// Add query parameters if any
569+
if tt.query != "" {
570+
req.URL.RawQuery = tt.query
571+
}
572+
573+
// Add headers
574+
for key, value := range tt.headers {
575+
req.Header.Set(key, value)
576+
}
577+
578+
// Create handler with config
579+
h := handler{
580+
cfg: config.CacheConfig{
581+
ShouldHashQuery: true,
582+
HashQueryIgnore: make(map[string]bool),
583+
HashHeaders: []string{},
584+
GlobalCacheKeys: tt.globalCacheKeys,
585+
},
586+
}
587+
588+
// Get cache key
589+
got := h.getCacheKey(req)
590+
591+
if tt.want == "hashed" {
592+
// For non-global keys, verify it's a proper hash
593+
assert.Len(t, got, 64, "should be a SHA256 hash (64 chars)")
594+
assert.NotEqual(t, "static-assets", got)
595+
assert.NotEqual(t, "nextjs-assets", got)
596+
assert.NotEqual(t, "static-files", got)
597+
} else {
598+
// For global keys, should match exactly
599+
assert.Equal(t, tt.want, got, "cache key mismatch")
600+
}
601+
})
602+
}
603+
}

0 commit comments

Comments
 (0)