diff --git a/ai-logic/firebase-ai/CHANGELOG.md b/ai-logic/firebase-ai/CHANGELOG.md index 473576c4af6..8bc8748ae68 100644 --- a/ai-logic/firebase-ai/CHANGELOG.md +++ b/ai-logic/firebase-ai/CHANGELOG.md @@ -19,6 +19,9 @@ - [fixed] Fixed an issue causing the SDK to throw an exception if an unknown message was received from the LiveAPI model, instead of ignoring it (#7975) +- [fixed] Fixed `LiveGenerativeModel.connect()` not attaching the `X-Firebase-AppCheck` + header, causing Live API requests to be rejected when App Check is enforced on AI Logic. (#8060) + # 17.10.1 - [fixed] Fixed an issue causing Live API to fail when using the `GoogleAI` backend (#7880) diff --git a/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/APIController.kt b/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/APIController.kt index 90ef24420bb..f3d67e14850 100644 --- a/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/APIController.kt +++ b/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/APIController.kt @@ -255,8 +255,19 @@ internal constructor( "wss://firebasevertexai.googleapis.com/ws/google.firebase.vertexai.v1beta.GenerativeService/BidiGenerateContent?key=$key" } - suspend fun getWebSocketSession(location: String): DefaultClientWebSocketSession = - client.webSocketSession(getBidiEndpoint(location)) { applyCommonHeaders() } + suspend fun getWebSocketSession(location: String): DefaultClientWebSocketSession { + // applyHeaderProvider() is suspend; Ktor's webSocketSession { } config lambda is not. + // Pre-fetch headers (including X-Firebase-AppCheck) in the outer suspend context using + // the same timeout-protected path as HTTP methods, then set them synchronously inside the + // lambda. + val extraHeaders = extractHeaders(headerProvider) + return client.webSocketSession(getBidiEndpoint(location)) { + applyCommonHeaders() + for ((tag, value) in extraHeaders) { + header(tag, value) + } + } + } fun generateContentStream( request: GenerateContentRequest @@ -306,16 +317,18 @@ internal constructor( } private suspend fun HttpRequestBuilder.applyHeaderProvider() { - if (headerProvider != null) { - try { - withTimeout(headerProvider.timeout) { - for ((tag, value) in headerProvider.generateHeaders()) { - header(tag, value) - } - } - } catch (e: TimeoutCancellationException) { - Log.w(TAG, "HeaderProvided timed out without generating headers, ignoring") - } + for ((tag, value) in extractHeaders(headerProvider)) { + header(tag, value) + } + } + + private suspend fun extractHeaders(headerProvider: HeaderProvider?): Map { + if (headerProvider == null) return emptyMap() + return try { + withTimeout(headerProvider.timeout) { headerProvider.generateHeaders() } + } catch (e: TimeoutCancellationException) { + Log.w(TAG, "HeaderProvided timed out without generating headers, ignoring", e) + emptyMap() } } diff --git a/ai-logic/firebase-ai/src/test/java/com/google/firebase/ai/common/APIControllerTests.kt b/ai-logic/firebase-ai/src/test/java/com/google/firebase/ai/common/APIControllerTests.kt index 3ca82c4a1b1..ab54be25619 100644 --- a/ai-logic/firebase-ai/src/test/java/com/google/firebase/ai/common/APIControllerTests.kt +++ b/ai-logic/firebase-ai/src/test/java/com/google/firebase/ai/common/APIControllerTests.kt @@ -417,6 +417,42 @@ internal class RequestFormatTests { mockEngine.requestHistory.first().headers.contains("header1") shouldBe false } + @Test + fun `headers from HeaderProvider are added to the WebSocket handshake`() = doBlocking { + val mockEngine = MockEngine { + // MockEngine isn't designed to complete a WebSocket upgrade handshake, but the + // outgoing request is recorded in requestHistory before the handshake attempt, + // so we can still assert on its headers. + respond("", HttpStatusCode.OK) + } + + val testHeaderProvider = + object : HeaderProvider { + override val timeout: Duration + get() = 5.seconds + + override suspend fun generateHeaders(): Map = + mapOf("X-Firebase-AppCheck" to "test-token") + } + + val controller = + APIController( + "super_cool_test_key", + "gemini-pro-2.5", + RequestOptions(), + mockEngine, + TEST_CLIENT_ID, + mockFirebaseApp, + TEST_VERSION, + TEST_APP_ID, + testHeaderProvider, + ) + + runCatching { withTimeout(5.seconds) { controller.getWebSocketSession("us-central1") } } + + mockEngine.requestHistory.first().headers["X-Firebase-AppCheck"] shouldBe "test-token" + } + @Test fun `code execution tool serialization contains correct keys`() = doBlocking { val channel = ByteChannel(autoFlush = true)