Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 132 additions & 91 deletions apisix/plugins/ai-cache.lua
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ local key_mod = require("apisix.plugins.ai-cache.key")
local binding = require("apisix.plugins.ai-protocols.binding")
local redis_util = require("apisix.utils.redis")
local semantic = require("apisix.plugins.ai-cache.semantic")
local stream = require("apisix.plugins.ai-cache.stream")

local ngx = ngx
local ngx_null = ngx.null
Expand Down Expand Up @@ -81,6 +82,53 @@ local function release(conf, red)
end


-- Run fn(red) on a connection obtained from redis_util.new(conf).
--
-- fn is expected to follow the `result` / `nil, err` convention.
--   * fn returns without an error value: the connection is handed back via
--     release(conf, red) and fn's first return value is returned.
--   * fn returns a truthy second value, or throws: the connection is closed
--     (not released) and (nil, err) is returned.
--   * the connection cannot be obtained: (nil, err) from redis_util.new.
--
-- The error value is always truthy on failure, so callers can rely on a plain
-- `if err then` check to detect every failure mode.
local function with_redis(conf, fn)
    local red, err = redis_util.new(conf)
    if not red then
        return nil, err
    end

    local ok, res, ferr = pcall(fn, red)
    if not ok then
        -- On a throw pcall puts the raised value in `res`. A bare error() or
        -- error(nil) raises nil, which must not be reported as "no error":
        -- substitute a message so the failure stays visible to the caller.
        red:close()
        return nil, res or "callback raised an error without a message"
    end
    if ferr then
        -- fn reported a failure through its second return value.
        red:close()
        return nil, ferr
    end

    release(conf, red)
    return res
end


-- Fail-open helper: a cache-backend or embedding failure must never break the
-- request. Emits a warning naming what failed (`what`) and why (`err`), then
-- marks the lookup on `ctx` as a MISS.
local function fail_open(ctx, what, err)
    local warn = core.log.warn
    warn("ai-cache: ", what, ", fail-open as MISS: ", err)
    ctx.ai_cache_status = "MISS"
end


-- Serialise the value stored in L1. This is the single place the entry shape
-- { body, created_at, format } is spelled out; returns whatever
-- core.json.encode returns for that table.
local function encode_entry(body, created_at, format)
    local entry = {
        body = body,
        created_at = created_at,
        format = format,
    }
    return core.json.encode(entry)
end


-- Best-effort copy of an L2 hit into L1 under this request's key
-- (ctx.ai_cache_key). The hit's created_at and format are stored alongside the
-- body so either layer replays the hit identically. Failures are only logged:
-- nothing is returned and nothing is raised by this function itself.
local function backfill_l1(conf, ctx, red, hit)
    local payload = encode_entry(hit.body, hit.created_at, hit.format)
    if payload then
        local stored, set_err = red:set(ctx.ai_cache_key, payload,
                                        "EX", (conf.exact and conf.exact.ttl) or DEFAULT_TTL)
        if not stored then
            core.log.warn("ai-cache: L1 backfill SET failed: ", set_err)
        end
        return
    end

    -- encoding produced nothing to store; skip the SET entirely
    core.log.warn("ai-cache: L1 backfill skipped: json.encode returned nil")
end


local function serve_hit(conf, ctx, cached, similarity)
local status = "HIT"
ctx.ai_cache_status = status
Expand All @@ -93,7 +141,9 @@ local function serve_hit(conf, ctx, cached, similarity)
str_format("%.4f", similarity))
end
end
core.response.set_header("Content-Type", "application/json")
core.response.set_header("Content-Type",
cached.format == stream.FORMAT_SSE and "text/event-stream"
or "application/json")
return core.response.exit(200, cached.body)
end

Expand All @@ -111,10 +161,11 @@ function _M.access(conf, ctx)
return
end

-- Streaming responses are not cached in PR-1 (SSE replay is a later
-- increment). ai-proxy (higher priority) has already classified the
-- request, so bypass before doing any work.
if ctx.var.request_type == "ai_stream" then
-- A stream on a non-SSE wire framing (bedrock's aws-eventstream) can never
-- be captured or replayed, so the lookup would be a guaranteed-miss redis
-- GET on every request: bypass before doing any work.
if ctx.var.request_type == "ai_stream"
and not stream.provider_capturable(ctx.picked_ai_instance) then
ctx.ai_cache_status = "BYPASS"
return
end
Expand All @@ -137,78 +188,67 @@ function _M.access(conf, ctx)

ctx.ai_cache_fingerprint = key_mod.fingerprint(ctx, body)
ctx.ai_cache_key = key_mod.build(conf, ctx, ctx.ai_cache_fingerprint)
-- Remember which instance the fingerprint was computed for. ai-proxy-multi
-- may fall back to a different instance in before_proxy; the log phase uses
-- this to avoid writing that fallback response under the original key.
-- which instance the fingerprint was computed for; log() checks it so a
-- fallback instance's response is never written under this key
ctx.ai_cache_picked_at_access = ctx.picked_ai_instance

local red
red, err = redis_util.new(conf)
if not red then
-- fail-open: never let a cache-backend outage break the request.
core.log.warn("ai-cache: redis unavailable, fail-open as MISS: ", err)
ctx.ai_cache_status = "MISS"
return
end

local res
res, err = red:get(ctx.ai_cache_key)
if err then
red:close()
core.log.warn("ai-cache: redis get failed, fail-open as MISS: ", err)
ctx.ai_cache_status = "MISS"
return
end
if res ~= nil and res ~= ngx_null then
local cached = core.json.decode(res)
if cached and cached.body then
release(conf, red)
return serve_hit(conf, ctx, cached)
local cached
cached, err = with_redis(conf, function(red)
local res, gerr = red:get(ctx.ai_cache_key)
if gerr then
return nil, gerr
end
if res == nil or res == ngx_null then
return nil
end
local entry = core.json.decode(res)
if entry and entry.body then
return entry
end
core.log.warn("ai-cache: discarding malformed cache entry for ", ctx.ai_cache_key)
return nil
end)
if err then
return fail_open(ctx, "L1 lookup failed", err)
end
if cached then
return serve_hit(conf, ctx, cached)
end

-- L1 miss -> L2 semantic lookup. Release the L1 connection before
-- embed_query()'s HTTP call so the pool isn't pinned across the embedding
-- round-trip; re-acquire for the vector search. pcall keeps throws fail-open.
-- L1 miss -> L2 semantic lookup, in its own connection scope so the pool
-- isn't pinned across embed_query()'s HTTP round-trip.
if has_layer(conf, "semantic") and conf.semantic then
release(conf, red)

local ok, vec = pcall(semantic.embed_query, conf, ctx, body)
if not ok then
core.log.warn("ai-cache: semantic embed error, fail-open as MISS: ", vec)
vec = nil
fail_open(ctx, "semantic embed error", vec)
-- prevent log() from scheduling a write with partial/bad state
ctx.ai_cache_embedding = nil
return
end

if vec then
local sred
sred, err = redis_util.new(conf)
if not sred then
core.log.warn("ai-cache: redis unavailable for semantic search, ",
"fail-open as MISS: ", err)
else
local sok, hit = pcall(semantic.search, sred, conf, ctx, vec)
if not sok then
sred:close()
core.log.warn("ai-cache: semantic search error, fail-open as MISS: ", hit)
ctx.ai_cache_embedding = nil
else
release(conf, sred)
if hit then
return serve_hit(conf, ctx,
{ body = hit.body, created_at = hit.created_at }, hit.similarity)
local hit
hit, err = with_redis(conf, function(red)
local h = semantic.search(red, conf, ctx, vec)
if h then
local bok, berr = pcall(backfill_l1, conf, ctx, red, h)
if not bok then
core.log.warn("ai-cache: L1 backfill error: ", berr)
end
end
return h
end)
if err then
fail_open(ctx, "semantic search failed", err)
ctx.ai_cache_embedding = nil
return
end
if hit then
return serve_hit(conf, ctx, hit, hit.similarity)
end
end

ctx.ai_cache_status = "MISS"
return
end

release(conf, red)
ctx.ai_cache_status = "MISS"
end

Expand All @@ -222,7 +262,8 @@ end

function _M.body_filter(conf, ctx)
-- only a MISS gets written back; HIT exited in access, BYPASS opts out.
if ctx.ai_cache_status ~= "MISS" or ctx.ai_cache_oversized then
if ctx.ai_cache_status ~= "MISS" or ctx.ai_cache_oversized
or not stream.capturable(ctx) then
Comment thread
nic-6443 marked this conversation as resolved.
return
end
local chunk = ngx.arg[1]
Expand All @@ -244,64 +285,63 @@ function _M.body_filter(conf, ctx)
end


-- The response-capturing phases (body_filter / log) run in contexts where
-- cosockets are disabled, so the Redis write is deferred to a 0-delay timer
-- (timers run in a light thread where cosockets are allowed).
-- l2 (optional) = { partition, embedding, dim, fingerprint, ttl } for L2 write.
local function write_to_cache(premature, conf, cache_key, response_body, l2)
-- body_filter/log cannot use cosockets, so the Redis write runs in a 0-delay
-- timer. l2 (optional) = { partition, embedding, dim, fingerprint, ttl }.
local function write_to_cache(premature, conf, cache_key, response_body, l2, format)
if premature then
return
end
local red, err = redis_util.new(conf)
if not red then
core.log.warn("ai-cache: redis unavailable on write: ", err)
return
end
local envelope = core.json.encode({ body = response_body, created_at = ngx.time() })
local ttl = (conf.exact and conf.exact.ttl) or DEFAULT_TTL
local ok
ok, err = red:set(cache_key, envelope, "EX", ttl)
if not ok then
red:close()
core.log.warn("ai-cache: redis set failed: ", err)
return
end
if l2 then
l2.created_at = ngx.time()
local wok, werr = pcall(semantic.write, red, conf, l2, response_body)
if not wok then
core.log.warn("ai-cache: semantic write error: ", werr)
local now = ngx.time()
local envelope = encode_entry(response_body, now, format)
local _, err = with_redis(conf, function(red)
local ok, serr = red:set(cache_key, envelope, "EX",
(conf.exact and conf.exact.ttl) or DEFAULT_TTL)
if not ok then
return nil, serr
end
if l2 then
l2.created_at = now
l2.format = format
semantic.write(red, conf, l2, response_body)
end
return true
end)
if err then
core.log.warn("ai-cache: cache write failed: ", err)
end
release(conf, red)
end


function _M.log(conf, ctx)
if ctx.ai_cache_status ~= "MISS" or not ctx.ai_cache_fingerprint then
return
end
-- ai-proxy-multi may reassign the picked instance on fallback/retry during
-- before_proxy. The frozen fingerprint identifies the ORIGINAL instance, so a
-- response actually produced by a different (fallback) instance must not be
-- written under it -- that would replay the wrong instance's response on a
-- later hit.
-- the fingerprint identifies the instance picked at access time; a
-- fallback/retry response from another instance must not be cached under it
if ctx.picked_ai_instance ~= ctx.ai_cache_picked_at_access then
return
end
if ngx.status ~= 200 then
return
end
if ctx.ai_stream_aborted then
return
end
local buf = ctx.ai_cache_buf
if not buf or buf.bytes == 0 then
return
end
local response_body = concat(buf, "", 1, buf.n)

local format = stream.capture_format(ctx, response_body)
if not format then
return
end

local cache_key = key_mod.build(conf, ctx, ctx.ai_cache_fingerprint)

-- Build the L2 doc from ctx fields stashed by semantic.embed_query(); the
-- embedding is only set on a successful embed, so a nil check guards the write.
-- L2 doc from ctx fields stashed by semantic.embed_query(); the embedding
-- is only set on a successful embed.
local l2
if has_layer(conf, "semantic") and ctx.ai_cache_embedding then
l2 = {
Expand All @@ -313,7 +353,8 @@ function _M.log(conf, ctx)
}
end

local ok, err = ngx.timer.at(0, write_to_cache, conf, cache_key, response_body, l2)
local ok, err = ngx.timer.at(0, write_to_cache, conf, cache_key,
response_body, l2, format)
if not ok then
core.log.warn("ai-cache: failed to schedule cache write: ", err)
end
Expand Down
11 changes: 2 additions & 9 deletions apisix/plugins/ai-cache/key.lua
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ local function build_repr(ctx, body, messages)
params[k] = v
end
end
local proto = ctx.ai_client_protocol and protocols.get(ctx.ai_client_protocol)
params.stream = (proto and proto.is_streaming(body)) == true

return {
client = {
Expand Down Expand Up @@ -107,15 +109,6 @@ function _M.fingerprint(ctx, body)
end


-- Returns the SHA-256 hex digest of the effective context with message text
-- removed. Queries that differ only in phrasing (same model/params/instance)
-- share one fingerprint, enabling semantic deduplication without storing the
-- raw prompt.
function _M.context_fingerprint(ctx, body)
return hex_digest(core.json.canonical_encode(build_repr(ctx, body, nil)))
end


-- Percent-encode "%", ":" and "=" (in that order) in scope values so a request-controlled
-- include_vars value can't shift "name=value:" boundaries to forge another scope.
local function esc(v)
Expand Down
23 changes: 6 additions & 17 deletions apisix/plugins/ai-cache/semantic.lua
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,8 @@ end

-- Phase 2 of the L2 lookup: vector search over a caller-owned connection
-- acquired AFTER embed_query() (so the pool isn't pinned across embedding).
-- Returns a hit {body, created_at, similarity} on a >=threshold match, else nil.
-- Returns a hit {body, created_at, format, similarity} on a >=threshold match,
-- else nil. The L1 backfill of a hit is the caller's job (ai-cache.lua owns L1).
function _M.search(red, conf, ctx, vec)
local sem = conf.semantic
local target = redis_target(conf)
Expand All @@ -276,26 +277,13 @@ function _M.search(red, conf, ctx, vec)
return nil
end

-- L2 -> L1 backfill, carrying the L2 entry's original created_at so Age is
-- consistent whether the next hit is served from L1 or L2. A real semantic
-- hit must be served regardless — only the backfill SET is skipped on error.
local envelope = core.json.encode({ body = hit.response, created_at = hit.created_at })
if not envelope then
core.log.warn("ai-cache: L1 backfill skipped: json.encode returned nil")
else
local exact_ttl = (conf.exact and conf.exact.ttl) or 3600
local bok, berr = red:set(ctx.ai_cache_key, envelope, "EX", exact_ttl)
if not bok then
core.log.warn("ai-cache: L1 backfill SET failed: ", berr)
end
end

return { body = hit.response, created_at = hit.created_at, similarity = similarity }
return { body = hit.response, created_at = hit.created_at,
format = hit.format, similarity = similarity }
end


-- Called from the write-back timer (after the L1 SET) with a still-open `red`.
-- l2 = { partition, embedding, dim, fingerprint, ttl, created_at }
-- l2 = { partition, embedding, dim, fingerprint, ttl, created_at, format }
function _M.write(red, conf, l2, response_body)
if not l2 or not l2.embedding then
return
Expand All @@ -314,6 +302,7 @@ function _M.write(red, conf, l2, response_body)
embedding = vs.pack_float32(l2.embedding),
response = response_body,
created_at = l2.created_at,
format = l2.format,
}, l2.ttl)
if not ok then
core.log.warn("ai-cache: L2 upsert failed: ", err)
Expand Down
Loading
Loading