Skip to content

Commit 5e09554

Browse files
authored
fix(gemini): support output_dimension for Gemini (#1046)
* fix(gemini): support `output_dimension` for Gemini
* fix: validate output vector dimension
1 parent 7d46ee4 commit 5e09554

File tree

3 files changed

+48
-11
lines changed

3 files changed

+48
-11
lines changed

src/llm/gemini.rs

Lines changed: 17 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -181,18 +181,22 @@ impl LlmEmbeddingClient for AiStudioClient {
181181
if let Some(task_type) = request.task_type {
182182
payload["taskType"] = serde_json::Value::String(task_type.into());
183183
}
184+
if let Some(output_dimension) = request.output_dimension {
185+
payload["outputDimensionality"] = serde_json::Value::Number(output_dimension.into());
186+
}
184187
let resp = retryable::run(
185-
|| self.client.post(&url).json(&payload).send(),
188+
|| async {
189+
self.client
190+
.post(&url)
191+
.json(&payload)
192+
.send()
193+
.await?
194+
.error_for_status()
195+
},
186196
&retryable::HEAVY_LOADED_OPTIONS,
187197
)
188-
.await?;
189-
if !resp.status().is_success() {
190-
bail!(
191-
"Gemini API error: {:?}\n{}\n",
192-
resp.status(),
193-
resp.text().await?
194-
);
195-
}
198+
.await
199+
.context("Gemini API error")?;
196200
let embedding_resp: EmbedContentResponse = resp.json().await.context("Invalid JSON")?;
197201
Ok(super::LlmEmbeddingResponse {
198202
embedding: embedding_resp.embedding.values,
@@ -202,6 +206,10 @@ impl LlmEmbeddingClient for AiStudioClient {
202206
fn get_default_embedding_dimension(&self, model: &str) -> Option<u32> {
203207
get_embedding_dimension(model)
204208
}
209+
210+
fn behavior_version(&self) -> Option<u32> {
211+
Some(2)
212+
}
205213
}
206214

207215
pub struct VertexAiClient {

src/llm/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -99,6 +99,10 @@ pub trait LlmEmbeddingClient: Send + Sync {
9999
) -> Result<LlmEmbeddingResponse>;
100100

101101
fn get_default_embedding_dimension(&self, model: &str) -> Option<u32>;
102+
103+
fn behavior_version(&self) -> Option<u32> {
104+
Some(1)
105+
}
102106
}
103107

104108
mod anthropic;

src/ops/functions/embed_text.rs

Lines changed: 27 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@ struct Spec {
1818
struct Args {
1919
client: Box<dyn LlmEmbeddingClient>,
2020
text: ResolvedOpArg,
21+
expected_output_dimension: usize,
2122
}
2223

2324
struct Executor {
@@ -28,7 +29,7 @@ struct Executor {
2829
#[async_trait]
2930
impl SimpleFunctionExecutor for Executor {
3031
fn behavior_version(&self) -> Option<u32> {
31-
Some(1)
32+
self.args.client.behavior_version()
3233
}
3334

3435
fn enable_cache(&self) -> bool {
@@ -48,6 +49,23 @@ impl SimpleFunctionExecutor for Executor {
4849
.map(|s| Cow::Borrowed(s.as_str())),
4950
};
5051
let embedding = self.args.client.embed_text(req).await?;
52+
if embedding.embedding.len() != self.args.expected_output_dimension {
53+
if self.spec.output_dimension.is_some() {
54+
api_bail!(
55+
"Expected output dimension {expected} but got {actual} from the embedding API. \
56+
Consider setting `output_dimension` to {actual} or leave it unset to use the default.",
57+
expected = self.args.expected_output_dimension,
58+
actual = embedding.embedding.len()
59+
);
60+
} else {
61+
bail!(
62+
"Expected output dimension {expected} but got {actual} from the embedding API. \
63+
Consider setting `output_dimension` to {actual} as a workaround.",
64+
expected = self.args.expected_output_dimension,
65+
actual = embedding.embedding.len()
66+
)
67+
}
68+
}
5169
Ok(embedding.embedding.into())
5270
}
5371
}
@@ -87,7 +105,14 @@ impl SimpleFunctionFactoryBase for Factory {
87105
dimension: Some(output_dimension as usize),
88106
element_type: Box::new(BasicValueType::Float32),
89107
}));
90-
Ok((Args { client, text }, output_schema))
108+
Ok((
109+
Args {
110+
client,
111+
text,
112+
expected_output_dimension: output_dimension as usize,
113+
},
114+
output_schema,
115+
))
91116
}
92117

93118
async fn build_executor(

0 commit comments

Comments
 (0)