diff --git a/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Candidate.kt b/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Candidate.kt index 29b3062927e..ef76b08195b 100644 --- a/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Candidate.kt +++ b/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Candidate.kt @@ -65,14 +65,15 @@ internal constructor( @OptIn(PublicPreviewAPI::class) internal fun toPublic(): Candidate { + val content = this.content?.toPublic() ?: content("model") {} val safetyRatings = safetyRatings?.mapNotNull { it.toPublic() }.orEmpty() - val citations = citationMetadata?.toPublic() + val citations = citationMetadata?.toPublic(content) val finishReason = finishReason?.toPublic() val groundingMetadata = groundingMetadata?.toPublic() val urlContextMetadata = urlContextMetadata?.toPublic() return Candidate( - this.content?.toPublic() ?: content("model") {}, + content, safetyRatings, citations, finishReason, @@ -163,7 +164,7 @@ public class CitationMetadata internal constructor(public val citations: List) { - internal fun toPublic() = CitationMetadata(citationSources.map { it.toPublic() }) + internal fun toPublic(content: Content) = CitationMetadata(citationSources.map { it.toPublic(content) }) } } @@ -203,7 +204,7 @@ internal constructor( val publicationDate: Date? = null, ) { - internal fun toPublic(): Citation { + internal fun toPublic(content: Content): Citation { val publicationDateAsCalendar = publicationDate?.let { val calendar = Calendar.getInstance() @@ -220,8 +221,8 @@ internal constructor( } return Citation( title = title, - startIndex = startIndex, - endIndex = endIndex, + startIndex = convertUtf8IndexToUtf16(content, startIndex), + endIndex = convertUtf8IndexToUtf16(content, endIndex), uri = uri, license = license, publicationDate = publicationDateAsCalendar @@ -635,3 +636,37 @@ private constructor(public val name: String, public val ordinal: Int) { @JvmField public val UNSAFE: UrlRetrievalStatus = UrlRetrievalStatus("UNSAFE", 4) } } + +internal fun convertUtf8IndexToUtf16(content: Content, originalIndex: Int): Int { + if (originalIndex == 0) { + return 0 + } + var sumIndex = 0 + var progress = 0 + for (part in content.parts) { + val text = part.asTextOrNull() ?: "" + var i = 0 + while (i < text.length) { + val c = text[i].code + progress += when { + c < 0x80 -> 1 // ASCII + c < 0x800 -> 2 // Two-byte codepoint + c in 0xD800 .. 0xDBFF -> 4 // High surrogate character + else -> 3 + } + if (c in 0xD800..0xDBFF) { + i++ // Skip the low surrogate + } + i++ + if (progress >= originalIndex) { + if (progress > originalIndex) { + throw StringIndexOutOfBoundsException("Desired index $originalIndex is between Unicode codepoints") + // Citation index was midway between a single codepoint?? + } + return sumIndex + i + } + } + sumIndex += text.length + } + throw StringIndexOutOfBoundsException("Desired index $originalIndex is higher than content size $progress") +} diff --git a/ai-logic/firebase-ai/src/test/java/com/google/firebase/ai/EncodingTests.kt b/ai-logic/firebase-ai/src/test/java/com/google/firebase/ai/EncodingTests.kt new file mode 100644 index 00000000000..05c9067b51b --- /dev/null +++ b/ai-logic/firebase-ai/src/test/java/com/google/firebase/ai/EncodingTests.kt @@ -0,0 +1,55 @@ +package com.google.firebase.ai + +import com.google.firebase.ai.type.Candidate +import com.google.firebase.ai.type.Citation +import com.google.firebase.ai.type.CitationMetadata +import com.google.firebase.ai.type.Content +import com.google.firebase.ai.type.PublicPreviewAPI +import com.google.firebase.ai.type.TextPart +import com.google.firebase.ai.type.content +import com.google.firebase.ai.type.convertUtf8IndexToUtf16 +import io.kotest.matchers.shouldBe +import kotlinx.serialization.ExperimentalSerializationApi +import org.junit.Test + +@OptIn(PublicPreviewAPI::class, ExperimentalSerializationApi::class) +class EncodingTests { + val testStrings = listOf( + "hello world", + "¡Sí! Tengo muchos años.", + "🙂🤝📩", + "速度を上げて", + "", + ) + + @Test + fun `UTF-8 to UFT-16 index mapping matches length`() { + for (string in testStrings) { + val content = content { + text(string) + } + val ba = string.toByteArray(Charsets.UTF_8) + val index = convertUtf8IndexToUtf16(content, ba.size) + index shouldBe string.length + } + } + + @Test + fun `CitationMetadata gets converted to UTF-16`() { + val internalCandidate = Candidate.Internal( + content = Content.Internal("", listOf(TextPart.Internal("í abc í"))), + citationMetadata = CitationMetadata.Internal( + listOf(Citation.Internal( + startIndex = 3, + endIndex = 6, + )) + ) + ) + val candidate = internalCandidate.toPublic() + val start = candidate.citationMetadata!!.citations.first().startIndex + val end = candidate.citationMetadata.citations.first().endIndex + (candidate.content.parts.first() as TextPart).text.substring(start, end) shouldBe "abc" + start shouldBe 2 + end shouldBe 5 + } +} \ No newline at end of file