Skip to content

Commit 681ecdd

Browse files
committed
Include session id in fireworks request for better prompt caching
1 parent b110a59 commit 681ecdd

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

web/src/llm-api/fireworks.ts

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,9 @@ function createFireworksRequest(params: {
9292
originalModel: string
9393
fetch: typeof globalThis.fetch
9494
modelIdOverride?: string
95+
sessionId: string
9596
}) {
96-
const { body, originalModel, fetch, modelIdOverride } = params
97+
const { body, originalModel, fetch, modelIdOverride, sessionId } = params
9798
const fireworksBody: Record<string, unknown> = {
9899
...body,
99100
model: modelIdOverride ?? getFireworksModelId(originalModel),
@@ -115,6 +116,7 @@ function createFireworksRequest(params: {
115116
headers: {
116117
Authorization: `Bearer ${env.FIREWORKS_API_KEY}`,
117118
'Content-Type': 'application/json',
119+
'x-session-affinity': sessionId
118120
},
119121
body: JSON.stringify(fireworksBody),
120122
// @ts-expect-error - dispatcher is a valid undici option not in fetch types
@@ -168,7 +170,7 @@ export async function handleFireworksNonStream({
168170
const startTime = new Date()
169171
const { clientId, clientRequestId, costMode } = extractRequestMetadata({ body, logger })
170172

171-
const response = await createFireworksRequestWithFallback({ body, originalModel, fetch, logger })
173+
const response = await createFireworksRequestWithFallback({ body, originalModel, fetch, logger, sessionId: userId })
172174

173175
if (!response.ok) {
174176
throw await parseFireworksError(response)
@@ -244,7 +246,7 @@ export async function handleFireworksStream({
244246
const startTime = new Date()
245247
const { clientId, clientRequestId, costMode } = extractRequestMetadata({ body, logger })
246248

247-
const response = await createFireworksRequestWithFallback({ body, originalModel, fetch, logger })
249+
const response = await createFireworksRequestWithFallback({ body, originalModel, fetch, logger, sessionId: userId })
248250

249251
if (!response.ok) {
250252
throw await parseFireworksError(response)
@@ -657,8 +659,9 @@ export async function createFireworksRequestWithFallback(params: {
657659
fetch: typeof globalThis.fetch
658660
logger: Logger
659661
useCustomDeployment?: boolean
662+
sessionId: string
660663
}): Promise<Response> {
661-
const { body, originalModel, fetch, logger } = params
664+
const { body, originalModel, fetch, logger, sessionId } = params
662665
const useCustomDeployment = params.useCustomDeployment ?? FIREWORKS_USE_CUSTOM_DEPLOYMENT
663666
const deploymentModelId = FIREWORKS_DEPLOYMENT_MAP[originalModel]
664667
const shouldTryDeployment =
@@ -677,6 +680,7 @@ export async function createFireworksRequestWithFallback(params: {
677680
originalModel,
678681
fetch,
679682
modelIdOverride: deploymentModelId,
683+
sessionId,
680684
})
681685

682686
if (response.status === 503) {
@@ -697,7 +701,7 @@ export async function createFireworksRequestWithFallback(params: {
697701
}
698702
}
699703

700-
return createFireworksRequest({ body, originalModel, fetch })
704+
return createFireworksRequest({ body, originalModel, fetch, sessionId })
701705
}
702706

703707
function creditsToFakeCost(credits: number): number {

0 commit comments

Comments
 (0)