Commit 4b4a584

feat: raw response accumulation for streaming
Parent: facc630

13 files changed (+211, -139 lines)

framework/changelog.md

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+feat: support raw response accumulation in stream accumulator

framework/streaming/accumulator.go

Lines changed: 4 additions & 0 deletions
@@ -43,6 +43,7 @@ func (a *Accumulator) putChatStreamChunk(chunk *ChatStreamChunk) {
 	chunk.ErrorDetails = nil
 	chunk.FinishReason = nil
 	chunk.TokenUsage = nil
+	chunk.RawResponse = nil
 	a.chatStreamChunkPool.Put(chunk)
 }

@@ -60,6 +61,7 @@ func (a *Accumulator) putAudioStreamChunk(chunk *AudioStreamChunk) {
 	chunk.ErrorDetails = nil
 	chunk.FinishReason = nil
 	chunk.TokenUsage = nil
+	chunk.RawResponse = nil
 	a.audioStreamChunkPool.Put(chunk)
 }

@@ -77,6 +79,7 @@ func (a *Accumulator) putTranscriptionStreamChunk(chunk *TranscriptionStreamChun
 	chunk.ErrorDetails = nil
 	chunk.FinishReason = nil
 	chunk.TokenUsage = nil
+	chunk.RawResponse = nil
 	a.transcriptionStreamChunkPool.Put(chunk)
 }

@@ -94,6 +97,7 @@ func (a *Accumulator) putResponsesStreamChunk(chunk *ResponsesStreamChunk) {
 	chunk.ErrorDetails = nil
 	chunk.FinishReason = nil
 	chunk.TokenUsage = nil
+	chunk.RawResponse = nil
 	a.responsesStreamChunkPool.Put(chunk)
 }
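
Each put helper resets every pointer field, including the new RawResponse, before returning the chunk to its pool; otherwise a recycled chunk could surface the previous request's raw payload. A minimal sketch of that reset-before-Put pattern, using a simplified chunk type rather than the actual Bifrost structs:

package main

import (
	"fmt"
	"sync"
)

// chunk is a stand-in for the framework's stream chunk structs.
type chunk struct {
	FinishReason *string
	RawResponse  *string
}

var chunkPool = sync.Pool{New: func() any { return &chunk{} }}

// putChunk clears all pointer fields so the next Get never sees stale data.
func putChunk(c *chunk) {
	c.FinishReason = nil
	c.RawResponse = nil
	chunkPool.Put(c)
}

func main() {
	raw := `{"delta":"hi"}`
	c := chunkPool.Get().(*chunk)
	c.RawResponse = &raw
	putChunk(c)

	reused := chunkPool.Get().(*chunk)
	fmt.Println(reused.RawResponse == nil) // true: nothing leaks across requests
}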

framework/streaming/audio.go

Lines changed: 19 additions & 0 deletions
@@ -91,6 +91,22 @@ func (a *Accumulator) processAccumulatedAudioStreamingChunks(requestID string, b
 			data.CacheDebug = lastChunk.SemanticCacheDebug
 		}
 	}
+	// Accumulate raw response
+	if len(accumulator.AudioStreamChunks) > 0 {
+		// Sort chunks by chunk index
+		sort.Slice(accumulator.AudioStreamChunks, func(i, j int) bool {
+			return accumulator.AudioStreamChunks[i].ChunkIndex < accumulator.AudioStreamChunks[j].ChunkIndex
+		})
+		for _, chunk := range accumulator.AudioStreamChunks {
+			if chunk.RawResponse != nil {
+				if data.RawResponse == nil {
+					data.RawResponse = bifrost.Ptr(*chunk.RawResponse + "\n\n")
+				} else {
+					*data.RawResponse += *chunk.RawResponse + "\n\n"
+				}
+			}
+		}
+	}
 	return data, nil
 }

@@ -118,6 +134,9 @@ func (a *Accumulator) processAudioStreamingResponse(ctx *schemas.BifrostContext,
 			Audio: result.SpeechStreamResponse.Audio,
 		}
 		chunk.Delta = newDelta
+		if result.SpeechStreamResponse.ExtraFields.RawResponse != nil {
+			chunk.RawResponse = bifrost.Ptr(fmt.Sprintf("%v", result.SpeechStreamResponse.ExtraFields.RawResponse))
+		}
 		if result.SpeechStreamResponse.Usage != nil {
			chunk.TokenUsage = result.SpeechStreamResponse.Usage
 		}
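
Each per-modality accumulator (audio, chat, transcription, responses) applies the same join step: sort the retained chunks by ChunkIndex, then concatenate their RawResponse payloads, each followed by a blank line. A hedged, self-contained sketch of that step with simplified types (not the framework's chunk structs):

package main

import (
	"fmt"
	"sort"
)

type rawChunk struct {
	ChunkIndex  int
	RawResponse *string
}

// joinRawResponses concatenates non-nil raw payloads in chunk-index order,
// appending "\n\n" after each entry, mirroring the accumulator logic above.
func joinRawResponses(chunks []rawChunk) *string {
	sort.Slice(chunks, func(i, j int) bool { return chunks[i].ChunkIndex < chunks[j].ChunkIndex })
	var out *string
	for _, c := range chunks {
		if c.RawResponse == nil {
			continue
		}
		if out == nil {
			s := *c.RawResponse + "\n\n"
			out = &s
		} else {
			*out += *c.RawResponse + "\n\n"
		}
	}
	return out
}

func main() {
	a, b := `{"id":1}`, `{"id":2}`
	chunks := []rawChunk{{ChunkIndex: 1, RawResponse: &b}, {ChunkIndex: 0, RawResponse: &a}}
	fmt.Println(*joinRawResponses(chunks)) // chunk 0 first, each payload followed by a blank line
}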

framework/streaming/chat.go

Lines changed: 19 additions & 0 deletions
@@ -179,6 +179,22 @@ func (a *Accumulator) processAccumulatedChatStreamingChunks(requestID string, re
 		}
 		data.FinishReason = lastChunk.FinishReason
 	}
+	// Accumulate raw response
+	if len(accumulator.ChatStreamChunks) > 0 {
+		// Sort chunks by chunk index
+		sort.Slice(accumulator.ChatStreamChunks, func(i, j int) bool {
+			return accumulator.ChatStreamChunks[i].ChunkIndex < accumulator.ChatStreamChunks[j].ChunkIndex
+		})
+		for _, chunk := range accumulator.ChatStreamChunks {
+			if chunk.RawResponse != nil {
+				if data.RawResponse == nil {
+					data.RawResponse = bifrost.Ptr(*chunk.RawResponse + "\n\n")
+				} else {
+					*data.RawResponse += *chunk.RawResponse + "\n\n"
+				}
+			}
+		}
+	}
 	return data, nil
 }

@@ -227,6 +243,9 @@ func (a *Accumulator) processChatStreamingResponse(ctx *schemas.BifrostContext,
 		chunk.TokenUsage = result.ChatResponse.Usage
 	}
 	chunk.ChunkIndex = result.ChatResponse.ExtraFields.ChunkIndex
+	if result.ChatResponse.ExtraFields.RawResponse != nil {
+		chunk.RawResponse = bifrost.Ptr(fmt.Sprintf("%v", result.ChatResponse.ExtraFields.RawResponse))
+	}
 	if isFinalChunk {
 		if a.pricingManager != nil {
 			cost := a.pricingManager.CalculateCostWithCacheDebug(result)
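
On the per-chunk side, ExtraFields.RawResponse is an untyped value, so the chat, audio, and transcription paths capture it with fmt.Sprintf("%v", ...), while the OpenAI-compatible branch in responses.go below uses a plain string type assertion instead. A rough sketch of the two capture strategies; the helper names and field layout here are illustrative, not Bifrost APIs:

package main

import "fmt"

// captureSprintf stringifies whatever the provider attached, even maps or structs.
func captureSprintf(raw any) *string {
	if raw == nil {
		return nil
	}
	s := fmt.Sprintf("%v", raw)
	return &s
}

// captureString keeps the payload only when it is already a string,
// avoiding Go's map/struct formatting in the stored value.
func captureString(raw any) *string {
	if s, ok := raw.(string); ok {
		return &s
	}
	return nil
}

func main() {
	fmt.Println(*captureSprintf(map[string]int{"tokens": 3})) // map[tokens:3]
	fmt.Println(captureString(map[string]int{"tokens": 3}))   // <nil>
	fmt.Println(*captureString(`{"id":"evt_1"}`))             // {"id":"evt_1"}
}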

framework/streaming/responses.go

Lines changed: 97 additions & 37 deletions
@@ -660,6 +660,23 @@ func (a *Accumulator) processAccumulatedResponsesStreamingChunks(requestID strin
 		data.FinishReason = lastChunk.FinishReason
 	}

+	// Accumulate raw response
+	if len(accumulator.ResponsesStreamChunks) > 0 {
+		// Sort chunks by chunk index
+		sort.Slice(accumulator.ResponsesStreamChunks, func(i, j int) bool {
+			return accumulator.ResponsesStreamChunks[i].ChunkIndex < accumulator.ResponsesStreamChunks[j].ChunkIndex
+		})
+		for _, chunk := range accumulator.ResponsesStreamChunks {
+			if chunk.RawResponse != nil {
+				if data.RawResponse == nil {
+					data.RawResponse = bifrost.Ptr(*chunk.RawResponse + "\n\n")
+				} else {
+					*data.RawResponse += *chunk.RawResponse + "\n\n"
+				}
+			}
+		}
+	}
+
 	return data, nil
 }

@@ -683,54 +700,94 @@ func (a *Accumulator) processResponsesStreamingResponse(ctx *schemas.BifrostCont

 	// For OpenAI-compatible providers, the last chunk already contains the whole accumulated response
 	// so just return it as is
+	// We maintain the accumulator only for raw response accumulation
 	if provider == schemas.OpenAI || provider == schemas.OpenRouter || (provider == schemas.Azure && !schemas.IsAnthropicModel(model)) {
 		isFinalChunk := bifrost.IsFinalChunk(ctx)
+		chunk := a.getResponsesStreamChunk()
+		chunk.Timestamp = time.Now()
+		chunk.ErrorDetails = bifrostErr
+		if bifrostErr != nil {
+			chunk.FinishReason = bifrost.Ptr("error")
+		} else if result != nil && result.ResponsesStreamResponse != nil {
+			if result.ResponsesStreamResponse.ExtraFields.RawResponse != nil {
+				rawResponse, ok := result.ResponsesStreamResponse.ExtraFields.RawResponse.(string)
+				if ok {
+					chunk.RawResponse = bifrost.Ptr(rawResponse)
+				}
+			}
+		}
+		if addErr := a.addResponsesStreamChunk(requestID, chunk, isFinalChunk); addErr != nil {
+			return nil, fmt.Errorf("failed to add responses stream chunk for request %s: %w", requestID, addErr)
+		}
 		if isFinalChunk {
-			// For OpenAI, the final chunk contains the complete response
-			// Extract the complete response and return it
-			if result != nil && result.ResponsesStreamResponse != nil {
-				// Build the complete response from the final chunk
-				data := &AccumulatedData{
-					RequestID:      requestID,
-					Status:         "success",
-					Stream:         true,
-					StartTimestamp: startTimestamp,
-					EndTimestamp:   endTimestamp,
-					Latency:        result.GetExtraFields().Latency,
-					ErrorDetails:   bifrostErr,
+			shouldProcess := false
+			// Get the accumulator to check if processing has already been triggered
+			accumulator := a.getOrCreateStreamAccumulator(requestID)
+			accumulator.mu.Lock()
+			shouldProcess = !accumulator.IsComplete
+			// Mark as complete when we're about to process
+			if shouldProcess {
+				accumulator.IsComplete = true
+			}
+			accumulator.mu.Unlock()
+
+			if shouldProcess {
+				accumulatedData, processErr := a.processAccumulatedResponsesStreamingChunks(requestID, bifrostErr, isFinalChunk)
+				if processErr != nil {
+					a.logger.Error("failed to process accumulated responses chunks for request %s: %v", requestID, processErr)
+					return nil, processErr
 				}

-				if bifrostErr != nil {
-					data.Status = "error"
-				}
+				// For OpenAI, the final chunk contains the complete response
+				// Extract the complete response and return it
+				if result != nil && result.ResponsesStreamResponse != nil {
+					// Build the complete response from the final chunk
+					data := &AccumulatedData{
+						RequestID:      requestID,
+						Status:         "success",
+						Stream:         true,
+						StartTimestamp: startTimestamp,
+						EndTimestamp:   endTimestamp,
+						Latency:        result.GetExtraFields().Latency,
+						ErrorDetails:   bifrostErr,
+						RawResponse:    accumulatedData.RawResponse,
+					}

-				// Extract the complete response from the stream response
-				if result.ResponsesStreamResponse.Response != nil {
-					data.OutputMessages = result.ResponsesStreamResponse.Response.Output
-					if result.ResponsesStreamResponse.Response.Usage != nil {
-						// Convert ResponsesResponseUsage to schemas.LLMUsage
-						data.TokenUsage = &schemas.BifrostLLMUsage{
-							PromptTokens:     result.ResponsesStreamResponse.Response.Usage.InputTokens,
-							CompletionTokens: result.ResponsesStreamResponse.Response.Usage.OutputTokens,
-							TotalTokens:      result.ResponsesStreamResponse.Response.Usage.TotalTokens,
+					if bifrostErr != nil {
+						data.Status = "error"
+					}
+
+					// Extract the complete response from the stream response
+					if result.ResponsesStreamResponse.Response != nil {
+						data.OutputMessages = result.ResponsesStreamResponse.Response.Output
+						if result.ResponsesStreamResponse.Response.Usage != nil {
+							// Convert ResponsesResponseUsage to schemas.LLMUsage
+							data.TokenUsage = &schemas.BifrostLLMUsage{
+								PromptTokens:     result.ResponsesStreamResponse.Response.Usage.InputTokens,
+								CompletionTokens: result.ResponsesStreamResponse.Response.Usage.OutputTokens,
+								TotalTokens:      result.ResponsesStreamResponse.Response.Usage.TotalTokens,
+							}
 						}
 					}
-				}
-			}

-				if a.pricingManager != nil {
-					cost := a.pricingManager.CalculateCostWithCacheDebug(result)
-					data.Cost = bifrost.Ptr(cost)
-				}
+					if a.pricingManager != nil {
+						cost := a.pricingManager.CalculateCostWithCacheDebug(result)
+						data.Cost = bifrost.Ptr(cost)
+					}

-				return &ProcessedStreamResponse{
-					Type:       StreamResponseTypeFinal,
-					RequestID:  requestID,
-					StreamType: StreamTypeResponses,
-					Provider:   provider,
-					Model:      model,
-					Data:       data,
-				}, nil
+					return &ProcessedStreamResponse{
+						Type:       StreamResponseTypeFinal,
+						RequestID:  requestID,
+						StreamType: StreamTypeResponses,
+						Provider:   provider,
+						Model:      model,
+						Data:       data,
+					}, nil
+				} else {
+					return nil, nil
+				}
 			}
+			return nil, nil
 		}

 	// For non-final chunks from OpenAI, just pass through

@@ -753,6 +810,9 @@ func (a *Accumulator) processResponsesStreamingResponse(ctx *schemas.BifrostCont
 	if bifrostErr != nil {
 		chunk.FinishReason = bifrost.Ptr("error")
 	} else if result != nil && result.ResponsesStreamResponse != nil {
+		if result.ResponsesStreamResponse.ExtraFields.RawResponse != nil {
+			chunk.RawResponse = bifrost.Ptr(fmt.Sprintf("%v", result.ResponsesStreamResponse.ExtraFields.RawResponse))
+		}
 		// Store a deep copy of the stream response to prevent shared data mutation between plugins
 		chunk.StreamResponse = deepCopyResponsesStreamResponse(result.ResponsesStreamResponse)
 		// Extract token usage from stream response if available
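
The restructured OpenAI-compatible branch also adds a process-once guard: the final chunk flips IsComplete under the accumulator's mutex, and only the caller that performed the flip goes on to build the final response. A hedged sketch of that check-and-set pattern with a simplified accumulator (not the framework's type):

package main

import (
	"fmt"
	"sync"
)

// streamAccumulator is a stand-in for the framework's per-request accumulator.
type streamAccumulator struct {
	mu         sync.Mutex
	IsComplete bool
}

// tryComplete returns true for exactly one caller; later callers see false
// and skip the final processing step.
func (a *streamAccumulator) tryComplete() bool {
	a.mu.Lock()
	defer a.mu.Unlock()
	if a.IsComplete {
		return false
	}
	a.IsComplete = true
	return true
}

func main() {
	acc := &streamAccumulator{}
	fmt.Println(acc.tryComplete()) // true: this caller processes the accumulated chunks
	fmt.Println(acc.tryComplete()) // false: duplicate final-chunk notifications are ignored
}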

framework/streaming/transcription.go

Lines changed: 28 additions & 8 deletions
@@ -103,6 +103,22 @@ func (a *Accumulator) processAccumulatedTranscriptionStreamingChunks(requestID s
 			data.CacheDebug = lastChunk.SemanticCacheDebug
 		}
 	}
+	// Accumulate raw response
+	if len(accumulator.TranscriptionStreamChunks) > 0 {
+		// Sort chunks by chunk index
+		sort.Slice(accumulator.TranscriptionStreamChunks, func(i, j int) bool {
+			return accumulator.TranscriptionStreamChunks[i].ChunkIndex < accumulator.TranscriptionStreamChunks[j].ChunkIndex
+		})
+		for _, chunk := range accumulator.TranscriptionStreamChunks {
+			if chunk.RawResponse != nil {
+				if data.RawResponse == nil {
+					data.RawResponse = bifrost.Ptr(*chunk.RawResponse + "\n\n")
+				} else {
+					*data.RawResponse += *chunk.RawResponse + "\n\n"
+				}
+			}
+		}
+	}
 	return data, nil
 }

@@ -123,18 +139,22 @@ func (a *Accumulator) processTranscriptionStreamingResponse(ctx *schemas.Bifrost
 	if bifrostErr != nil {
 		chunk.FinishReason = bifrost.Ptr("error")
 	} else if result != nil && result.TranscriptionStreamResponse != nil {
+		// Set delta for all chunks (not just final chunks with usage)
+		// We create a deep copy of the delta to avoid pointing to stack memory
+		newDelta := &schemas.BifrostTranscriptionStreamResponse{
+			Type:  result.TranscriptionStreamResponse.Type,
+			Delta: result.TranscriptionStreamResponse.Delta,
+		}
+		chunk.Delta = newDelta
+
+		// Set token usage if available (typically only in final chunk)
 		if result.TranscriptionStreamResponse.Usage != nil {
 			chunk.TokenUsage = result.TranscriptionStreamResponse.Usage
-
-			// For Transcription, entire delta is sent in the final chunk which also has usage information
-			// We create a deep copy of the delta to avoid pointing to stack memory
-			newDelta := &schemas.BifrostTranscriptionStreamResponse{
-				Type:  result.TranscriptionStreamResponse.Type,
-				Delta: result.TranscriptionStreamResponse.Delta,
-			}
-			chunk.Delta = newDelta
 		}
 		chunk.ChunkIndex = result.TranscriptionStreamResponse.ExtraFields.ChunkIndex
+		if result.TranscriptionStreamResponse.ExtraFields.RawResponse != nil {
+			chunk.RawResponse = bifrost.Ptr(fmt.Sprintf("%v", result.TranscriptionStreamResponse.ExtraFields.RawResponse))
+		}
 		if isFinalChunk {
 			if a.pricingManager != nil {
 				cost := a.pricingManager.CalculateCostWithCacheDebug(result)
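
The transcription hunk moves the delta copy out of the usage-only branch, so every chunk now stores a fresh copy of the incoming delta, not only the final one. The copy matters because the stored chunk outlives the callback; a rough sketch of the aliasing issue it avoids, with simplified types rather than the Bifrost schema:

package main

import "fmt"

// delta is a stand-in for the transcription stream response payload.
type delta struct {
	Type string
	Text string
}

type storedChunk struct {
	Delta *delta
}

// capture copies the incoming delta into a new allocation before the chunk is
// retained, so later mutation of the caller's struct cannot change what was stored.
func capture(incoming *delta) storedChunk {
	copied := &delta{Type: incoming.Type, Text: incoming.Text}
	return storedChunk{Delta: copied}
}

func main() {
	in := &delta{Type: "transcript.text.delta", Text: "hel"}
	chunk := capture(in)

	in.Text = "lo"                // provider reuses its event struct for the next chunk
	fmt.Println(chunk.Delta.Text) // still "hel"
}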

framework/streaming/types.go

Lines changed: 5 additions & 0 deletions
@@ -43,6 +43,7 @@ type AccumulatedData struct {
 	AudioOutput         *schemas.BifrostSpeechResponse
 	TranscriptionOutput *schemas.BifrostTranscriptionResponse
 	FinishReason        *string
+	RawResponse         *string
 }

 // AudioStreamChunk represents a single streaming chunk
@@ -55,6 +56,7 @@ type AudioStreamChunk struct {
 	Cost         *float64              // Cost in dollars from pricing plugin
 	ErrorDetails *schemas.BifrostError // Error if any
 	ChunkIndex   int                   // Index of the chunk in the stream
+	RawResponse  *string
 }

 // TranscriptionStreamChunk represents a single transcription streaming chunk
@@ -67,6 +69,7 @@ type TranscriptionStreamChunk struct {
 	Cost         *float64              // Cost in dollars from pricing plugin
 	ErrorDetails *schemas.BifrostError // Error if any
 	ChunkIndex   int                   // Index of the chunk in the stream
+	RawResponse  *string
 }

 // ChatStreamChunk represents a single streaming chunk
@@ -79,6 +82,7 @@ type ChatStreamChunk struct {
 	Cost         *float64              // Cost in dollars from pricing plugin
 	ErrorDetails *schemas.BifrostError // Error if any
 	ChunkIndex   int                   // Index of the chunk in the stream
+	RawResponse  *string // Raw response if available
 }

 // ResponsesStreamChunk represents a single responses streaming chunk
@@ -91,6 +95,7 @@ type ResponsesStreamChunk struct {
 	Cost         *float64              // Cost in dollars from pricing plugin
 	ErrorDetails *schemas.BifrostError // Error if any
 	ChunkIndex   int                   // Index of the chunk in the stream
+	RawResponse  *string
 }

 // StreamAccumulator manages accumulation of streaming chunks
plugins/logging/operations.go

Lines changed: 4 additions & 0 deletions
@@ -293,6 +293,10 @@ func (p *LoggerPlugin) updateStreamingLogEntry(
 				updates["responses_output"] = tempEntry.ResponsesOutput
 			}
 		}
+		// Handle raw response from stream updates
+		if streamResponse.Data.RawResponse != nil {
+			updates["raw_response"] = *streamResponse.Data.RawResponse
+		}
 	}
 	// Only perform update if there's something to update
 	if len(updates) > 0 {
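
On the logging side, the accumulated value is written only when present, since AccumulatedData.RawResponse is a *string and nil means no raw payload was captured. A small sketch of that nil-guarded update; the map keys here mirror the diff but the surrounding storage schema is illustrative:

package main

import "fmt"

type accumulatedData struct {
	RawResponse *string
}

// buildUpdates mirrors the guard above: dereference the pointer only when it is set.
func buildUpdates(data accumulatedData) map[string]any {
	updates := map[string]any{}
	if data.RawResponse != nil {
		updates["raw_response"] = *data.RawResponse
	}
	return updates
}

func main() {
	raw := `data: {"object":"chat.completion.chunk"}`
	fmt.Println(buildUpdates(accumulatedData{RawResponse: &raw})) // one key to persist
	fmt.Println(buildUpdates(accumulatedData{}))                  // nothing to persist
}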

transports/changelog.md

Lines changed: 2 additions & 1 deletion
@@ -1,4 +1,5 @@
 fix: vertex and bedrock usage aggregation improvements for streaming
 fix: choice index fixed to 0 for anthropic and bedrock streaming
 feat: model field added to responses api response
-feat: check allowed models and deployments of key for list models
+feat: check allowed models and deployments of key for list models
+feat: support for raw response accumulation for streaming
