diff --git a/BitFun-Installer/src-tauri/Cargo.toml b/BitFun-Installer/src-tauri/Cargo.toml index 7f7c6431..562602c0 100644 --- a/BitFun-Installer/src-tauri/Cargo.toml +++ b/BitFun-Installer/src-tauri/Cargo.toml @@ -23,6 +23,7 @@ tauri-plugin-dialog = "2" serde = { version = "1", features = ["derive"] } serde_json = "1" tokio = { version = "1", features = ["full"] } +tokio-stream = "0.1" anyhow = "1.0" log = "0.4" dirs = "5.0" @@ -30,8 +31,11 @@ zip = "0.6" flate2 = "1.0" tar = "0.4" chrono = "0.4" -reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] } +reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls", "stream"] } urlencoding = "2" +futures = "0.3" +eventsource-stream = "0.2" +installer-ai-stream = { path = "crates/installer-ai-stream" } [target.'cfg(windows)'.dependencies] winreg = "0.52" diff --git a/BitFun-Installer/src-tauri/crates/installer-ai-stream/Cargo.toml b/BitFun-Installer/src-tauri/crates/installer-ai-stream/Cargo.toml new file mode 100644 index 00000000..d9d0c05a --- /dev/null +++ b/BitFun-Installer/src-tauri/crates/installer-ai-stream/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "installer-ai-stream" +version = "0.2.1" +edition = "2021" + +[lib] +path = "lib.rs" + +[dependencies] +anyhow = "1.0" +chrono = "0.4" +eventsource-stream = "0.2" +futures = "0.3" +reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls", "stream"] } +serde = { version = "1", features = ["derive"] } +serde_json = "1" +tokio = { version = "1", features = ["rt-multi-thread", "macros", "io-util", "sync", "time"] } +log = "0.4" diff --git a/BitFun-Installer/src-tauri/crates/installer-ai-stream/lib.rs b/BitFun-Installer/src-tauri/crates/installer-ai-stream/lib.rs new file mode 100644 index 00000000..d10cee11 --- /dev/null +++ b/BitFun-Installer/src-tauri/crates/installer-ai-stream/lib.rs @@ -0,0 +1,8 @@ +mod stream_handler; +mod types; + +pub use 
stream_handler::handle_anthropic_stream; +pub use stream_handler::handle_gemini_stream; +pub use stream_handler::handle_openai_stream; +pub use stream_handler::handle_responses_stream; +pub use types::unified::{UnifiedResponse, UnifiedTokenUsage, UnifiedToolCall}; diff --git a/BitFun-Installer/src-tauri/crates/installer-ai-stream/stream_handler/anthropic.rs b/BitFun-Installer/src-tauri/crates/installer-ai-stream/stream_handler/anthropic.rs new file mode 100644 index 00000000..60593ae5 --- /dev/null +++ b/BitFun-Installer/src-tauri/crates/installer-ai-stream/stream_handler/anthropic.rs @@ -0,0 +1,172 @@ +use super::stream_stats::StreamStats; +use crate::types::anthropic::{ + AnthropicSSEError, ContentBlock, ContentBlockDelta, ContentBlockStart, MessageDelta, + MessageStart, Usage, +}; +use crate::types::unified::UnifiedResponse; +use anyhow::{anyhow, Result}; +use eventsource_stream::Eventsource; +use futures::StreamExt; +use log::{error, trace}; +use reqwest::Response; +use std::time::Duration; +use tokio::sync::mpsc; +use tokio::time::timeout; + +/// Convert a byte stream into a structured response stream +/// +/// # Arguments +/// * `response` - HTTP response +/// * `tx_event` - parsed event sender +/// * `tx_raw_sse` - optional raw SSE sender (collect raw data for diagnostics) +pub async fn handle_anthropic_stream( + response: Response, + tx_event: mpsc::UnboundedSender>, + tx_raw_sse: Option>, +) { + let mut stream = response.bytes_stream().eventsource(); + let idle_timeout = Duration::from_secs(600); + let mut usage = Usage::default(); + let mut stats = StreamStats::new("Anthropic"); + + loop { + let sse_event = timeout(idle_timeout, stream.next()).await; + let sse = match sse_event { + Ok(Some(Ok(sse))) => sse, + Ok(None) => { + let error_msg = "SSE Error: stream closed before response completed"; + stats.log_summary("stream_closed_before_completion"); + error!("{}", error_msg); + let _ = tx_event.send(Err(anyhow!(error_msg))); + return; + } + 
Ok(Some(Err(e))) => { + let error_msg = format!("SSE Error: {}", e); + stats.log_summary("sse_stream_error"); + error!("{}", error_msg); + let _ = tx_event.send(Err(anyhow!(error_msg))); + return; + } + Err(_) => { + let error_msg = "SSE Timeout: idle timeout waiting for SSE"; + stats.log_summary("sse_stream_timeout"); + error!("{}", error_msg); + let _ = tx_event.send(Err(anyhow!(error_msg))); + return; + } + }; + + trace!("Anthropic SSE: {:?}", sse); + let event_type = sse.event; + let data = sse.data; + stats.record_sse_event(&event_type); + + if let Some(ref tx) = tx_raw_sse { + let _ = tx.send(format!("[{}] {}", event_type, data)); + } + + match event_type.as_str() { + "message_start" => { + let message_start: MessageStart = match serde_json::from_str(&data) { + Ok(message_start) => message_start, + Err(e) => { + stats.increment("error:sse_parsing"); + let err_str = format!("SSE Parsing Error: {e}, data: {}", &data); + error!("{}", err_str); + continue; + } + }; + if let Some(message_usage) = message_start.message.usage { + usage.update(&message_usage); + } + } + "content_block_start" => { + let content_block_start: ContentBlockStart = match serde_json::from_str(&data) { + Ok(content_block_start) => content_block_start, + Err(e) => { + stats.increment("error:sse_parsing"); + let err_str = format!("SSE Parsing Error: {e}, data: {}", &data); + error!("{}", err_str); + continue; + } + }; + if matches!( + content_block_start.content_block, + ContentBlock::ToolUse { .. 
} + ) { + let unified_response = UnifiedResponse::from(content_block_start); + trace!("Anthropic unified response: {:?}", unified_response); + stats.record_unified_response(&unified_response); + let _ = tx_event.send(Ok(unified_response)); + } + } + "content_block_delta" => { + let content_block_delta: ContentBlockDelta = match serde_json::from_str(&data) { + Ok(content_block_delta) => content_block_delta, + Err(e) => { + stats.increment("error:sse_parsing"); + let err_str = format!("SSE Parsing Error: {e}, data: {}", &data); + error!("{}", err_str); + continue; + } + }; + match UnifiedResponse::try_from(content_block_delta) { + Ok(unified_response) => { + trace!("Anthropic unified response: {:?}", unified_response); + stats.record_unified_response(&unified_response); + let _ = tx_event.send(Ok(unified_response)); + } + Err(e) => { + stats.increment("skip:invalid_content_block_delta"); + error!("Skipping invalid content_block_delta: {}", e); + } + }; + } + "message_delta" => { + let mut message_delta: MessageDelta = match serde_json::from_str(&data) { + Ok(message_delta) => message_delta, + Err(e) => { + stats.increment("error:sse_parsing"); + let err_str = format!("SSE Parsing Error: {e}, data: {}", &data); + error!("{}", err_str); + continue; + } + }; + if let Some(delta_usage) = message_delta.usage.as_ref() { + usage.update(delta_usage); + } + message_delta.usage = if usage.is_empty() { + None + } else { + Some(usage.clone()) + }; + let unified_response = UnifiedResponse::from(message_delta); + trace!("Anthropic unified response: {:?}", unified_response); + stats.record_unified_response(&unified_response); + let _ = tx_event.send(Ok(unified_response)); + } + "error" => { + let sse_error: AnthropicSSEError = match serde_json::from_str(&data) { + Ok(message_delta) => message_delta, + Err(e) => { + stats.increment("error:sse_parsing"); + let err_str = format!("SSE Parsing Error: {e}, data: {}", &data); + stats.log_summary("sse_parsing_error"); + error!("{}", 
err_str); + let _ = tx_event.send(Err(anyhow!(err_str))); + return; + } + }; + stats.increment("error:api"); + stats.log_summary("error_event_received"); + let _ = tx_event.send(Err(anyhow!(String::from(sse_error.error)))); + return; + } + "message_stop" => { + stats.log_summary("message_stop"); + return; + } + _ => {} + } + } +} diff --git a/BitFun-Installer/src-tauri/crates/installer-ai-stream/stream_handler/gemini.rs b/BitFun-Installer/src-tauri/crates/installer-ai-stream/stream_handler/gemini.rs new file mode 100644 index 00000000..957ea6bb --- /dev/null +++ b/BitFun-Installer/src-tauri/crates/installer-ai-stream/stream_handler/gemini.rs @@ -0,0 +1,267 @@ +use super::stream_stats::StreamStats; +use crate::types::gemini::GeminiSSEData; +use crate::types::unified::UnifiedResponse; +use anyhow::{anyhow, Result}; +use eventsource_stream::Eventsource; +use futures::StreamExt; +use log::{error, trace}; +use reqwest::Response; +use serde_json::Value; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::time::Duration; +use tokio::sync::mpsc; +use tokio::time::timeout; + +static GEMINI_STREAM_ID_SEQ: AtomicU64 = AtomicU64::new(1); + +#[derive(Debug)] +struct GeminiToolCallState { + active_name: Option, + active_id: Option, + stream_id: u64, + next_index: usize, +} + +impl GeminiToolCallState { + fn new() -> Self { + Self { + active_name: None, + active_id: None, + stream_id: GEMINI_STREAM_ID_SEQ.fetch_add(1, Ordering::Relaxed), + next_index: 0, + } + } + + fn on_non_tool_response(&mut self) { + self.active_name = None; + self.active_id = None; + } + + fn assign_id(&mut self, tool_call: &mut crate::types::unified::UnifiedToolCall) { + if let Some(existing_id) = tool_call.id.as_ref().filter(|value| !value.is_empty()) { + self.active_id = Some(existing_id.clone()); + self.active_name = tool_call.name.clone().filter(|value| !value.is_empty()); + return; + } + + let tool_name = tool_call.name.clone().filter(|value| !value.is_empty()); + let is_same_active_call = 
self.active_id.is_some() && self.active_name == tool_name; + + if is_same_active_call { + tool_call.id = None; + return; + } + + self.next_index += 1; + let generated_id = format!("gemini_call_{}_{}", self.stream_id, self.next_index); + tool_call.id = Some(generated_id.clone()); + self.active_id = Some(generated_id); + self.active_name = tool_name; + } +} + +fn extract_api_error_message(event_json: &Value) -> Option { + let error = event_json.get("error")?; + if let Some(message) = error.get("message").and_then(Value::as_str) { + return Some(message.to_string()); + } + if let Some(message) = error.as_str() { + return Some(message.to_string()); + } + Some("Gemini streaming request failed".to_string()) +} + +pub async fn handle_gemini_stream( + response: Response, + tx_event: mpsc::UnboundedSender>, + tx_raw_sse: Option>, +) { + let mut stream = response.bytes_stream().eventsource(); + let idle_timeout = Duration::from_secs(600); + let mut received_finish_reason = false; + let mut tool_call_state = GeminiToolCallState::new(); + let mut stats = StreamStats::new("Gemini"); + + loop { + let sse_event = timeout(idle_timeout, stream.next()).await; + let sse = match sse_event { + Ok(Some(Ok(sse))) => sse, + Ok(None) => { + if received_finish_reason { + stats.log_summary("stream_closed_after_finish_reason"); + return; + } + let error_msg = "Gemini SSE stream closed before response completed"; + stats.log_summary("stream_closed_before_completion"); + error!("{}", error_msg); + let _ = tx_event.send(Err(anyhow!(error_msg))); + return; + } + Ok(Some(Err(e))) => { + let error_msg = format!("Gemini SSE stream error: {}", e); + stats.log_summary("sse_stream_error"); + error!("{}", error_msg); + let _ = tx_event.send(Err(anyhow!(error_msg))); + return; + } + Err(_) => { + let error_msg = format!( + "Gemini SSE stream timeout after {}s", + idle_timeout.as_secs() + ); + stats.log_summary("sse_stream_timeout"); + error!("{}", error_msg); + let _ = 
tx_event.send(Err(anyhow!(error_msg))); + return; + } + }; + + let raw = sse.data; + stats.record_sse_event("data"); + trace!("Gemini SSE: {:?}", raw); + + if let Some(ref tx) = tx_raw_sse { + let _ = tx.send(raw.clone()); + } + + if raw == "[DONE]" { + stats.increment("marker:done"); + stats.log_summary("done_marker_received"); + return; + } + + let event_json: Value = match serde_json::from_str(&raw) { + Ok(json) => json, + Err(e) => { + let error_msg = format!("Gemini SSE parsing error: {}, data: {}", e, raw); + stats.increment("error:sse_parsing"); + stats.log_summary("sse_parsing_error"); + error!("{}", error_msg); + let _ = tx_event.send(Err(anyhow!(error_msg))); + return; + } + }; + + if let Some(message) = extract_api_error_message(&event_json) { + let error_msg = format!("Gemini SSE API error: {}, data: {}", message, raw); + stats.increment("error:api"); + stats.log_summary("sse_api_error"); + error!("{}", error_msg); + let _ = tx_event.send(Err(anyhow!(error_msg))); + return; + } + + let sse_data: GeminiSSEData = match serde_json::from_value(event_json) { + Ok(data) => data, + Err(e) => { + let error_msg = format!("Gemini SSE data schema error: {}, data: {}", e, raw); + stats.increment("error:schema"); + stats.log_summary("sse_data_schema_error"); + error!("{}", error_msg); + let _ = tx_event.send(Err(anyhow!(error_msg))); + return; + } + }; + + let mut unified_responses = sse_data.into_unified_responses(); + if unified_responses.is_empty() { + stats.increment("skip:empty_unified_responses"); + } + for unified_response in &mut unified_responses { + if let Some(tool_call) = unified_response.tool_call.as_mut() { + tool_call_state.assign_id(tool_call); + } else { + tool_call_state.on_non_tool_response(); + } + + if unified_response.finish_reason.is_some() { + received_finish_reason = true; + tool_call_state.on_non_tool_response(); + } + } + + for unified_response in unified_responses { + stats.record_unified_response(&unified_response); + let _ = 
tx_event.send(Ok(unified_response)); + } + } +} + +#[cfg(test)] +mod tests { + use super::GeminiToolCallState; + use crate::types::unified::UnifiedToolCall; + + #[test] + fn reuses_active_tool_id_by_omitting_follow_up_ids() { + let mut state = GeminiToolCallState::new(); + + let mut first = UnifiedToolCall { + id: None, + name: Some("get_weather".to_string()), + arguments: Some("{\"city\":".to_string()), + }; + state.assign_id(&mut first); + + let mut second = UnifiedToolCall { + id: None, + name: Some("get_weather".to_string()), + arguments: Some("\"Paris\"}".to_string()), + }; + state.assign_id(&mut second); + + assert!(first + .id + .as_deref() + .is_some_and(|id| id.starts_with("gemini_call_"))); + assert!(second.id.is_none()); + } + + #[test] + fn clears_active_tool_after_non_tool_response() { + let mut state = GeminiToolCallState::new(); + + let mut first = UnifiedToolCall { + id: None, + name: Some("get_weather".to_string()), + arguments: Some("{}".to_string()), + }; + state.assign_id(&mut first); + state.on_non_tool_response(); + + let mut second = UnifiedToolCall { + id: None, + name: Some("get_weather".to_string()), + arguments: Some("{}".to_string()), + }; + state.assign_id(&mut second); + + let first_id = first.id.expect("first id"); + let second_id = second.id.expect("second id"); + assert!(first_id.starts_with("gemini_call_")); + assert!(second_id.starts_with("gemini_call_")); + assert_ne!(first_id, second_id); + } + + #[test] + fn generates_unique_prefixes_across_streams() { + let mut first_state = GeminiToolCallState::new(); + let mut second_state = GeminiToolCallState::new(); + + let mut first = UnifiedToolCall { + id: None, + name: Some("grep".to_string()), + arguments: Some("{}".to_string()), + }; + let mut second = UnifiedToolCall { + id: None, + name: Some("read".to_string()), + arguments: Some("{}".to_string()), + }; + + first_state.assign_id(&mut first); + second_state.assign_id(&mut second); + + assert_ne!(first.id, second.id); + } +} diff 
--git a/BitFun-Installer/src-tauri/crates/installer-ai-stream/stream_handler/mod.rs b/BitFun-Installer/src-tauri/crates/installer-ai-stream/stream_handler/mod.rs new file mode 100644 index 00000000..31f2b8c4 --- /dev/null +++ b/BitFun-Installer/src-tauri/crates/installer-ai-stream/stream_handler/mod.rs @@ -0,0 +1,10 @@ +mod stream_stats; +mod anthropic; +mod gemini; +mod openai; +mod responses; + +pub use anthropic::handle_anthropic_stream; +pub use gemini::handle_gemini_stream; +pub use openai::handle_openai_stream; +pub use responses::handle_responses_stream; diff --git a/BitFun-Installer/src-tauri/crates/installer-ai-stream/stream_handler/openai.rs b/BitFun-Installer/src-tauri/crates/installer-ai-stream/stream_handler/openai.rs new file mode 100644 index 00000000..113498e3 --- /dev/null +++ b/BitFun-Installer/src-tauri/crates/installer-ai-stream/stream_handler/openai.rs @@ -0,0 +1,770 @@ +use super::stream_stats::StreamStats; +use crate::types::openai::OpenAISSEData; +use crate::types::unified::{UnifiedResponse, UnifiedTokenUsage}; +use anyhow::{anyhow, Result}; +use eventsource_stream::Eventsource; +use futures::StreamExt; +use log::{error, trace, warn}; +use reqwest::Response; +use serde_json::Value; +use std::collections::HashSet; +use std::mem; +use std::time::Duration; +use tokio::sync::mpsc; +use tokio::time::timeout; + +const OPENAI_CHAT_COMPLETION_CHUNK_OBJECT: &str = "chat.completion.chunk"; +const INLINE_THINK_OPEN_TAG: &str = ""; +const INLINE_THINK_CLOSE_TAG: &str = ""; + +#[derive(Debug, Default)] +struct OpenAIToolCallFilter { + seen_tool_call_ids: HashSet, +} + +impl OpenAIToolCallFilter { + fn normalize_response(&mut self, mut response: UnifiedResponse) -> Option { + let Some(tool_call) = response.tool_call.as_ref() else { + return Some(response); + }; + + let tool_id = tool_call.id.as_ref().filter(|value| !value.is_empty()).cloned(); + let has_name = tool_call + .name + .as_ref() + .is_some_and(|value| !value.is_empty()); + let has_arguments = 
tool_call + .arguments + .as_ref() + .is_some_and(|value| !value.is_empty()); + + if let Some(tool_id) = tool_id { + let seen_before = self.seen_tool_call_ids.contains(&tool_id); + self.seen_tool_call_ids.insert(tool_id); + + // OpenAI-compatible providers may emit a trailing chunk that only repeats an + // already-seen tool id after the arguments have completed. It carries no new + // information and should not reopen a fresh tool-call buffer downstream. + if seen_before && !has_name && !has_arguments { + response.tool_call = None; + return Self::keep_if_non_empty(response); + } + } else if !has_name && !has_arguments { + response.tool_call = None; + return Self::keep_if_non_empty(response); + } + + Some(response) + } + + fn keep_if_non_empty(response: UnifiedResponse) -> Option { + if response.text.is_some() + || response.reasoning_content.is_some() + || response.thinking_signature.is_some() + || response.tool_call.is_some() + || response.usage.is_some() + || response.finish_reason.is_some() + || response.provider_metadata.is_some() + { + Some(response) + } else { + None + } + } +} + +#[derive(Debug, Default)] +struct DeferredResponseMeta { + usage: Option, + finish_reason: Option, + provider_metadata: Option, +} + +impl DeferredResponseMeta { + fn from_response(response: &mut UnifiedResponse) -> Self { + Self { + usage: response.usage.take(), + finish_reason: response.finish_reason.take(), + provider_metadata: response.provider_metadata.take(), + } + } + + fn merge(&mut self, other: Self) { + if other.usage.is_some() { + self.usage = other.usage; + } + if other.finish_reason.is_some() { + self.finish_reason = other.finish_reason; + } + if other.provider_metadata.is_some() { + self.provider_metadata = other.provider_metadata; + } + } + + fn apply_to(self, response: &mut UnifiedResponse) { + if response.usage.is_none() { + response.usage = self.usage; + } + if response.finish_reason.is_none() { + response.finish_reason = self.finish_reason; + } + if 
response.provider_metadata.is_none() { + response.provider_metadata = self.provider_metadata; + } + } + + fn is_empty(&self) -> bool { + self.usage.is_none() && self.finish_reason.is_none() && self.provider_metadata.is_none() + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum InlineThinkActivation { + Unknown, + Enabled, + Disabled, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum InlineThinkMode { + Text, + Thinking, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum InlineThinkSegment { + Text(String), + Thinking(String), +} + +#[derive(Debug)] +struct OpenAIInlineThinkParser { + enabled: bool, + activation: InlineThinkActivation, + mode: InlineThinkMode, + pending_tail: String, + initial_probe: String, + deferred_meta: DeferredResponseMeta, +} + +impl OpenAIInlineThinkParser { + fn new(enabled: bool) -> Self { + Self { + enabled, + activation: InlineThinkActivation::Unknown, + mode: InlineThinkMode::Text, + pending_tail: String::new(), + initial_probe: String::new(), + deferred_meta: DeferredResponseMeta::default(), + } + } + + fn normalize_response(&mut self, mut response: UnifiedResponse) -> Vec { + if !self.enabled { + return vec![response]; + } + + let Some(text) = response.text.take() else { + return vec![response]; + }; + + // Respect providers that already emit native reasoning chunks. 
+ if response.reasoning_content.is_some() + || response.tool_call.is_some() + || response.thinking_signature.is_some() + { + response.text = Some(text); + return vec![response]; + } + + let current_meta = DeferredResponseMeta::from_response(&mut response); + let segments = match self.activation { + InlineThinkActivation::Unknown => self.consume_unknown_text(text), + InlineThinkActivation::Enabled => self.parse_enabled_text(text), + InlineThinkActivation::Disabled => vec![InlineThinkSegment::Text(text)], + }; + + self.attach_meta_to_segments(segments, current_meta) + } + + fn flush(&mut self) -> Vec { + if !self.enabled { + return Vec::new(); + } + + let segments = match self.activation { + InlineThinkActivation::Unknown => { + let pending = mem::take(&mut self.initial_probe); + if pending.is_empty() { + Vec::new() + } else { + vec![InlineThinkSegment::Text(pending)] + } + } + InlineThinkActivation::Enabled => { + let pending = mem::take(&mut self.pending_tail); + if pending.is_empty() { + Vec::new() + } else if self.mode == InlineThinkMode::Thinking { + vec![InlineThinkSegment::Thinking(pending)] + } else { + vec![InlineThinkSegment::Text(pending)] + } + } + InlineThinkActivation::Disabled => Vec::new(), + }; + + self.attach_meta_to_segments(segments, DeferredResponseMeta::default()) + } + + fn consume_unknown_text(&mut self, text: String) -> Vec { + self.initial_probe.push_str(&text); + + let trimmed = self.initial_probe.trim_start_matches(char::is_whitespace); + if trimmed.is_empty() { + return Vec::new(); + } + + if trimmed.starts_with(INLINE_THINK_OPEN_TAG) { + self.activation = InlineThinkActivation::Enabled; + let buffered = mem::take(&mut self.initial_probe); + return self.parse_enabled_text(buffered); + } + + if INLINE_THINK_OPEN_TAG.starts_with(trimmed) { + return Vec::new(); + } + + self.activation = InlineThinkActivation::Disabled; + vec![InlineThinkSegment::Text(mem::take(&mut self.initial_probe))] + } + + fn parse_enabled_text(&mut self, text: String) 
-> Vec { + let mut data = mem::take(&mut self.pending_tail); + data.push_str(&text); + + let mut segments = Vec::new(); + + loop { + let marker = match self.mode { + InlineThinkMode::Text => INLINE_THINK_OPEN_TAG, + InlineThinkMode::Thinking => INLINE_THINK_CLOSE_TAG, + }; + + if let Some(marker_idx) = data.find(marker) { + let before_marker = data[..marker_idx].to_string(); + self.push_segment(&mut segments, before_marker); + + data = data[marker_idx + marker.len()..].to_string(); + self.mode = match self.mode { + InlineThinkMode::Text => InlineThinkMode::Thinking, + InlineThinkMode::Thinking => InlineThinkMode::Text, + }; + continue; + } + + let tail_len = longest_suffix_prefix_len(&data, marker); + let flush_len = data.len() - tail_len; + let ready = data[..flush_len].to_string(); + self.push_segment(&mut segments, ready); + self.pending_tail = data[flush_len..].to_string(); + break; + } + + segments + } + + fn push_segment(&self, segments: &mut Vec, content: String) { + if content.is_empty() { + return; + } + + match self.mode { + InlineThinkMode::Text => segments.push(InlineThinkSegment::Text(content)), + InlineThinkMode::Thinking => segments.push(InlineThinkSegment::Thinking(content)), + } + } + + fn attach_meta_to_segments( + &mut self, + segments: Vec, + current_meta: DeferredResponseMeta, + ) -> Vec { + let mut merged_meta = mem::take(&mut self.deferred_meta); + merged_meta.merge(current_meta); + + let mut responses: Vec = segments + .into_iter() + .map(|segment| match segment { + InlineThinkSegment::Text(text) => UnifiedResponse { + text: Some(text), + ..Default::default() + }, + InlineThinkSegment::Thinking(reasoning_content) => UnifiedResponse { + reasoning_content: Some(reasoning_content), + ..Default::default() + }, + }) + .collect(); + + if let Some(last_response) = responses.last_mut() { + merged_meta.apply_to(last_response); + } else if !merged_meta.is_empty() { + self.deferred_meta = merged_meta; + } + + responses + } +} + +#[derive(Debug)] 
+struct OpenAIResponseNormalizer { + tool_call_filter: OpenAIToolCallFilter, + inline_think_parser: OpenAIInlineThinkParser, +} + +impl OpenAIResponseNormalizer { + fn new(inline_think_in_text: bool) -> Self { + Self { + tool_call_filter: OpenAIToolCallFilter::default(), + inline_think_parser: OpenAIInlineThinkParser::new(inline_think_in_text), + } + } + + fn normalize_response(&mut self, response: UnifiedResponse) -> Vec { + let Some(response) = self.tool_call_filter.normalize_response(response) else { + return Vec::new(); + }; + + self.inline_think_parser.normalize_response(response) + } + + fn flush(&mut self) -> Vec { + self.inline_think_parser.flush() + } +} + +fn longest_suffix_prefix_len(value: &str, marker: &str) -> usize { + let max_len = value.len().min(marker.len().saturating_sub(1)); + (1..=max_len) + .rev() + .find(|&len| value.ends_with(&marker[..len])) + .unwrap_or(0) +} + +fn is_valid_chat_completion_chunk_weak(event_json: &Value) -> bool { + matches!( + event_json.get("object").and_then(|value| value.as_str()), + Some(OPENAI_CHAT_COMPLETION_CHUNK_OBJECT) + ) +} + +fn extract_sse_api_error_message(event_json: &Value) -> Option { + let error = event_json.get("error")?; + if let Some(message) = error.get("message").and_then(|value| value.as_str()) { + return Some(message.to_string()); + } + if let Some(message) = error.as_str() { + return Some(message.to_string()); + } + Some("An error occurred during streaming".to_string()) +} + +/// Convert a byte stream into a structured response stream +/// +/// # Arguments +/// * `response` - HTTP response +/// * `tx_event` - parsed event sender +/// * `tx_raw_sse` - optional raw SSE sender (collect raw data for diagnostics) +pub async fn handle_openai_stream( + response: Response, + tx_event: mpsc::UnboundedSender>, + tx_raw_sse: Option>, + inline_think_in_text: bool, +) { + let mut stream = response.bytes_stream().eventsource(); + let idle_timeout = Duration::from_secs(600); + let mut stats = 
StreamStats::new("OpenAI"); + // Track whether a chunk with `finish_reason` was received. + // Some providers (e.g. MiniMax) close the stream after the final chunk + // without sending `[DONE]`, so we treat `Ok(None)` as a normal termination + // when a finish_reason has already been seen. + let mut received_finish_reason = false; + let mut normalizer = OpenAIResponseNormalizer::new(inline_think_in_text); + + loop { + let sse_event = timeout(idle_timeout, stream.next()).await; + let sse = match sse_event { + Ok(Some(Ok(sse))) => sse, + Ok(None) => { + if received_finish_reason { + for normalized_response in normalizer.flush() { + stats.record_unified_response(&normalized_response); + let _ = tx_event.send(Ok(normalized_response)); + } + stats.log_summary("stream_closed_after_finish_reason"); + return; + } + let error_msg = "SSE stream closed before response completed"; + stats.log_summary("stream_closed_before_completion"); + error!("{}", error_msg); + let _ = tx_event.send(Err(anyhow!(error_msg))); + return; + } + Ok(Some(Err(e))) => { + let error_msg = format!("SSE stream error: {}", e); + stats.log_summary("sse_stream_error"); + error!("{}", error_msg); + let _ = tx_event.send(Err(anyhow!(error_msg))); + return; + } + Err(_) => { + let error_msg = format!("SSE stream timeout after {}s", idle_timeout.as_secs()); + stats.log_summary("sse_stream_timeout"); + error!("{}", error_msg); + let _ = tx_event.send(Err(anyhow!(error_msg))); + return; + } + }; + + let raw = sse.data; + stats.record_sse_event("data"); + trace!("OpenAI SSE: {:?}", raw); + if let Some(ref tx) = tx_raw_sse { + let _ = tx.send(raw.clone()); + } + if raw == "[DONE]" { + for normalized_response in normalizer.flush() { + stats.record_unified_response(&normalized_response); + let _ = tx_event.send(Ok(normalized_response)); + } + stats.increment("marker:done"); + stats.log_summary("done_marker_received"); + return; + } + + let event_json: Value = match serde_json::from_str(&raw) { + Ok(json) => json, 
+ Err(e) => { + let error_msg = format!("SSE parsing error: {}, data: {}", e, &raw); + stats.increment("error:sse_parsing"); + stats.log_summary("sse_parsing_error"); + error!("{}", error_msg); + let _ = tx_event.send(Err(anyhow!(error_msg))); + return; + } + }; + + if let Some(api_error_message) = extract_sse_api_error_message(&event_json) { + let error_msg = format!("SSE API error: {}, data: {}", api_error_message, raw); + stats.increment("error:api"); + stats.log_summary("sse_api_error"); + error!("{}", error_msg); + let _ = tx_event.send(Err(anyhow!(error_msg))); + return; + } + + if !is_valid_chat_completion_chunk_weak(&event_json) { + stats.increment("skip:non_standard_event"); + warn!( + "Skipping non-standard OpenAI SSE event; object={}", + event_json + .get("object") + .and_then(|value| value.as_str()) + .unwrap_or("") + ); + continue; + } + + stats.increment("chunk:chat_completion"); + let sse_data: OpenAISSEData = match serde_json::from_value(event_json) { + Ok(event) => event, + Err(e) => { + let error_msg = format!("SSE data schema error: {}, data: {}", e, &raw); + stats.increment("error:schema"); + stats.log_summary("sse_data_schema_error"); + error!("{}", error_msg); + let _ = tx_event.send(Err(anyhow!(error_msg))); + return; + } + }; + + let tool_call_count = sse_data.first_choice_tool_call_count(); + if tool_call_count > 1 { + stats.increment("chunk:multi_tool_call"); + warn!( + "OpenAI SSE chunk contains {} tool calls in the first choice; splitting and sending sequentially", + tool_call_count + ); + } + + let has_empty_choices = sse_data.is_choices_empty(); + let unified_responses = sse_data.into_unified_responses(); + trace!("OpenAI unified responses: {:?}", unified_responses); + if unified_responses.is_empty() { + if has_empty_choices { + stats.increment("skip:empty_choices_no_usage"); + warn!( + "Ignoring OpenAI SSE chunk with empty choices and no usage payload: {}", + raw + ); + // Ignore keepalive/metadata chunks with empty choices and no 
usage payload. + continue; + } + // Defensive fallback: this should be unreachable if OpenAISSEData::into_unified_responses + // keeps returning at least one event for all non-empty-choices chunks. + let error_msg = format!("OpenAI SSE chunk produced no unified events, data: {}", raw); + stats.increment("error:no_unified_events"); + stats.log_summary("no_unified_events"); + error!("{}", error_msg); + let _ = tx_event.send(Err(anyhow!(error_msg))); + return; + } + + for unified_response in unified_responses { + let normalized_responses = normalizer.normalize_response(unified_response); + if normalized_responses.is_empty() { + continue; + } + + for normalized_response in normalized_responses { + if normalized_response.finish_reason.is_some() { + received_finish_reason = true; + } + stats.record_unified_response(&normalized_response); + let _ = tx_event.send(Ok(normalized_response)); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::{ + extract_sse_api_error_message, is_valid_chat_completion_chunk_weak, + longest_suffix_prefix_len, InlineThinkActivation, InlineThinkMode, OpenAIInlineThinkParser, + OpenAIToolCallFilter, + }; + use crate::types::unified::{UnifiedResponse, UnifiedToolCall}; + + #[test] + fn weak_filter_accepts_chat_completion_chunk() { + let event = serde_json::json!({ + "object": "chat.completion.chunk" + }); + assert!(is_valid_chat_completion_chunk_weak(&event)); + } + + #[test] + fn weak_filter_rejects_non_standard_object() { + let event = serde_json::json!({ + "object": "" + }); + assert!(!is_valid_chat_completion_chunk_weak(&event)); + } + + #[test] + fn weak_filter_rejects_missing_object() { + let event = serde_json::json!({ + "id": "chatcmpl_test" + }); + assert!(!is_valid_chat_completion_chunk_weak(&event)); + } + + #[test] + fn extracts_api_error_message_from_object_shape() { + let event = serde_json::json!({ + "error": { + "message": "provider error" + } + }); + assert_eq!( + extract_sse_api_error_message(&event).as_deref(), + 
Some("provider error") + ); + } + + #[test] + fn extracts_api_error_message_from_string_shape() { + let event = serde_json::json!({ + "error": "provider error" + }); + assert_eq!( + extract_sse_api_error_message(&event).as_deref(), + Some("provider error") + ); + } + + #[test] + fn returns_none_when_no_error_payload_exists() { + let event = serde_json::json!({ + "object": "chat.completion.chunk" + }); + assert!(extract_sse_api_error_message(&event).is_none()); + } + + #[test] + fn drops_redundant_empty_tool_call_after_same_id_was_seen() { + let mut filter = OpenAIToolCallFilter::default(); + + let first = UnifiedResponse { + tool_call: Some(UnifiedToolCall { + id: Some("call_1".to_string()), + name: Some("read_file".to_string()), + arguments: Some("{\"path\":\"a.txt\"}".to_string()), + }), + ..Default::default() + }; + let trailing_empty = UnifiedResponse { + tool_call: Some(UnifiedToolCall { + id: Some("call_1".to_string()), + name: None, + arguments: Some(String::new()), + }), + ..Default::default() + }; + + assert!(filter.normalize_response(first).is_some()); + assert!(filter.normalize_response(trailing_empty).is_none()); + } + + #[test] + fn keeps_finish_reason_when_redundant_tool_call_is_stripped() { + let mut filter = OpenAIToolCallFilter::default(); + + let first = UnifiedResponse { + tool_call: Some(UnifiedToolCall { + id: Some("call_1".to_string()), + name: Some("read_file".to_string()), + arguments: Some("{\"path\":\"a.txt\"}".to_string()), + }), + ..Default::default() + }; + let trailing_empty = UnifiedResponse { + tool_call: Some(UnifiedToolCall { + id: Some("call_1".to_string()), + name: None, + arguments: None, + }), + finish_reason: Some("tool_calls".to_string()), + ..Default::default() + }; + + assert!(filter.normalize_response(first).is_some()); + let normalized = filter + .normalize_response(trailing_empty) + .expect("finish_reason should be preserved"); + assert!(normalized.tool_call.is_none()); + assert_eq!(normalized.finish_reason.as_deref(), 
Some("tool_calls")); + } + + #[test] + fn longest_suffix_prefix_len_detects_partial_tag_boundary() { + assert_eq!(longest_suffix_prefix_len(""), 4); + assert_eq!(longest_suffix_prefix_len("answer", ""), 0); + } + + #[test] + fn inline_think_parser_streams_thinking_and_text_per_chunk() { + let mut parser = OpenAIInlineThinkParser::new(true); + + let chunk1 = parser.normalize_response(UnifiedResponse { + text: Some("abc".to_string()), + ..Default::default() + }); + let chunk2 = parser.normalize_response(UnifiedResponse { + text: Some("defghi".to_string()), + ..Default::default() + }); + + assert_eq!(chunk1.len(), 1); + assert_eq!(chunk1[0].reasoning_content.as_deref(), Some("abc")); + assert_eq!(chunk2.len(), 2); + assert_eq!(chunk2[0].reasoning_content.as_deref(), Some("def")); + assert_eq!(chunk2[1].text.as_deref(), Some("ghi")); + } + + #[test] + fn inline_think_parser_handles_split_opening_tag() { + let mut parser = OpenAIInlineThinkParser::new(true); + + let first = parser.normalize_response(UnifiedResponse { + text: Some("hello".to_string()), + ..Default::default() + }); + + assert!(first.is_empty()); + assert_eq!(second.len(), 1); + assert_eq!(second[0].reasoning_content.as_deref(), Some("hello")); + } + + #[test] + fn inline_think_parser_disables_when_first_text_is_not_think_tag() { + let mut parser = OpenAIInlineThinkParser::new(true); + + let first = parser.normalize_response(UnifiedResponse { + text: Some("hello literal".to_string()), + ..Default::default() + }); + let second = parser.normalize_response(UnifiedResponse { + text: Some(" world".to_string()), + ..Default::default() + }); + + assert_eq!(first.len(), 1); + assert_eq!(first[0].text.as_deref(), Some("hello literal")); + assert_eq!(second.len(), 1); + assert_eq!(second[0].text.as_deref(), Some(" world")); + assert_eq!(parser.activation, InlineThinkActivation::Disabled); + assert_eq!(parser.mode, InlineThinkMode::Text); + } + + #[test] + fn 
inline_think_parser_preserves_finish_reason_on_last_segment() { + let mut parser = OpenAIInlineThinkParser::new(true); + + let responses = parser.normalize_response(UnifiedResponse { + text: Some("abcdone".to_string()), + finish_reason: Some("stop".to_string()), + ..Default::default() + }); + + assert_eq!(responses.len(), 2); + assert_eq!(responses[0].reasoning_content.as_deref(), Some("abc")); + assert_eq!(responses[1].text.as_deref(), Some("done")); + assert_eq!(responses[1].finish_reason.as_deref(), Some("stop")); + } + + #[test] + fn inline_think_parser_flushes_unclosed_thinking_at_stream_end() { + let mut parser = OpenAIInlineThinkParser::new(true); + + let first = parser.normalize_response(UnifiedResponse { + text: Some("abc".to_string()), + ..Default::default() + }); + let flushed = parser.flush(); + + assert_eq!(first.len(), 1); + assert_eq!(first[0].reasoning_content.as_deref(), Some("abc")); + assert!(flushed.is_empty()); + } + + #[test] + fn inline_think_parser_passthrough_when_feature_disabled() { + let mut parser = OpenAIInlineThinkParser::new(false); + + let responses = parser.normalize_response(UnifiedResponse { + text: Some("abcdone".to_string()), + ..Default::default() + }); + + assert_eq!(responses.len(), 1); + assert_eq!(responses[0].text.as_deref(), Some("abcdone")); + assert!(responses[0].reasoning_content.is_none()); + } +} diff --git a/BitFun-Installer/src-tauri/crates/installer-ai-stream/stream_handler/responses.rs b/BitFun-Installer/src-tauri/crates/installer-ai-stream/stream_handler/responses.rs new file mode 100644 index 00000000..48bd55a3 --- /dev/null +++ b/BitFun-Installer/src-tauri/crates/installer-ai-stream/stream_handler/responses.rs @@ -0,0 +1,613 @@ +use super::stream_stats::StreamStats; +use crate::types::responses::{ + parse_responses_output_item, ResponsesCompleted, ResponsesDone, ResponsesStreamEvent, +}; +use crate::types::unified::UnifiedResponse; +use anyhow::{anyhow, Result}; +use eventsource_stream::Eventsource; +use 
futures::StreamExt; +use log::{error, trace}; +use reqwest::Response; +use serde_json::Value; +use std::collections::HashMap; +use std::time::Duration; +use tokio::sync::mpsc; +use tokio::time::timeout; + +#[derive(Debug, Default, Clone)] +struct InProgressToolCall { + call_id: Option, + name: Option, + args_so_far: String, + saw_any_delta: bool, + sent_header: bool, +} + +impl InProgressToolCall { + fn from_item_value(item: &Value) -> Option { + if item.get("type").and_then(Value::as_str) != Some("function_call") { + return None; + } + Some(Self { + call_id: item + .get("call_id") + .and_then(Value::as_str) + .map(ToString::to_string), + name: item + .get("name") + .and_then(Value::as_str) + .map(ToString::to_string), + args_so_far: String::new(), + saw_any_delta: false, + sent_header: false, + }) + } +} + +fn emit_tool_call_item( + tx_event: &mpsc::UnboundedSender>, + stats: &mut StreamStats, + item_value: Value, +) { + if let Some(unified_response) = parse_responses_output_item(item_value) { + if unified_response.tool_call.is_some() { + stats.record_unified_response(&unified_response); + let _ = tx_event.send(Ok(unified_response)); + } + } +} + +fn cleanup_tool_call_tracking( + output_index: usize, + tool_calls_by_output_index: &mut HashMap, + tool_call_index_by_id: &mut HashMap, +) { + if let Some(tc) = tool_calls_by_output_index.remove(&output_index) { + if let Some(call_id) = tc.call_id { + tool_call_index_by_id.remove(&call_id); + } + } +} + +fn handle_function_call_output_item_done( + tx_event: &mpsc::UnboundedSender>, + stats: &mut StreamStats, + event_output_index: Option, + item_value: Value, + tool_calls_by_output_index: &mut HashMap, + tool_call_index_by_id: &mut HashMap, +) { + // Resolve output_index either directly or via call_id mapping. 
+ let output_index = event_output_index.or_else(|| { + item_value + .get("call_id") + .and_then(Value::as_str) + .and_then(|id| tool_call_index_by_id.get(id).copied()) + }); + + let Some(output_index) = output_index else { + emit_tool_call_item(tx_event, stats, item_value); + return; + }; + + let Some(tc) = tool_calls_by_output_index.get_mut(&output_index) else { + // The provider may send `output_item.done` with an output_index even when the + // earlier `output_item.added` event was omitted or missed. Fall back to the full item. + emit_tool_call_item(tx_event, stats, item_value); + return; + }; + + let full_args = item_value + .get("arguments") + .and_then(Value::as_str) + .unwrap_or_default(); + let need_fallback_full = !tc.saw_any_delta; + let need_tail = tc.saw_any_delta + && tc.args_so_far.len() < full_args.len() + && full_args.starts_with(&tc.args_so_far); + + if need_fallback_full || need_tail { + let delta = if need_fallback_full { + full_args.to_string() + } else { + full_args[tc.args_so_far.len()..].to_string() + }; + + if !delta.is_empty() { + tc.args_so_far.push_str(&delta); + let (id, name) = if tc.sent_header { + (None, None) + } else { + tc.sent_header = true; + (tc.call_id.clone(), tc.name.clone()) + }; + let unified_response = UnifiedResponse { + tool_call: Some(crate::types::unified::UnifiedToolCall { + id, + name, + arguments: Some(delta), + }), + ..Default::default() + }; + stats.record_unified_response(&unified_response); + let _ = tx_event.send(Ok(unified_response)); + } + } + + cleanup_tool_call_tracking( + output_index, + tool_calls_by_output_index, + tool_call_index_by_id, + ); +} + +fn extract_api_error_message(event_json: &Value) -> Option { + let response = event_json.get("response")?; + let error = response.get("error")?; + + if error.is_null() { + return None; + } + + if let Some(message) = error.get("message").and_then(Value::as_str) { + return Some(message.to_string()); + } + if let Some(message) = error.as_str() { + return 
Some(message.to_string()); + } + + Some("An error occurred during responses streaming".to_string()) +} + +pub async fn handle_responses_stream( + response: Response, + tx_event: mpsc::UnboundedSender>, + tx_raw_sse: Option>, +) { + let mut stream = response.bytes_stream().eventsource(); + let idle_timeout = Duration::from_secs(600); + // Some providers close the stream after emitting the terminal event and may not send `[DONE]`. + let mut received_finish_reason = false; + let mut received_text_delta = false; + let mut tool_calls_by_output_index: HashMap = HashMap::new(); + let mut tool_call_index_by_id: HashMap = HashMap::new(); + let mut stats = StreamStats::new("Responses"); + + loop { + let sse_event = timeout(idle_timeout, stream.next()).await; + let sse = match sse_event { + Ok(Some(Ok(sse))) => sse, + Ok(None) => { + if received_finish_reason { + stats.log_summary("stream_closed_after_finish_reason"); + return; + } + let error_msg = "Responses SSE stream closed before response completed"; + stats.log_summary("stream_closed_before_completion"); + error!("{}", error_msg); + let _ = tx_event.send(Err(anyhow!(error_msg))); + return; + } + Ok(Some(Err(e))) => { + let error_msg = format!("Responses SSE stream error: {}", e); + stats.log_summary("sse_stream_error"); + error!("{}", error_msg); + let _ = tx_event.send(Err(anyhow!(error_msg))); + return; + } + Err(_) => { + let error_msg = format!( + "Responses SSE stream timeout after {}s", + idle_timeout.as_secs() + ); + stats.log_summary("sse_stream_timeout"); + error!("{}", error_msg); + let _ = tx_event.send(Err(anyhow!(error_msg))); + return; + } + }; + + let raw = sse.data; + stats.record_sse_event("data"); + trace!("Responses SSE: {:?}", raw); + if let Some(ref tx) = tx_raw_sse { + let _ = tx.send(raw.clone()); + } + if raw == "[DONE]" { + stats.increment("marker:done"); + stats.log_summary("done_marker_received"); + return; + } + + let event_json: Value = match serde_json::from_str(&raw) { + Ok(json) => json, 
+ Err(e) => { + let error_msg = format!("Responses SSE parsing error: {}, data: {}", e, &raw); + stats.increment("error:sse_parsing"); + stats.log_summary("sse_parsing_error"); + error!("{}", error_msg); + let _ = tx_event.send(Err(anyhow!(error_msg))); + return; + } + }; + + if let Some(api_error_message) = extract_api_error_message(&event_json) { + let error_msg = format!( + "Responses SSE API error: {}, data: {}", + api_error_message, raw + ); + stats.increment("error:api"); + stats.log_summary("sse_api_error"); + error!("{}", error_msg); + let _ = tx_event.send(Err(anyhow!(error_msg))); + return; + } + + let event: ResponsesStreamEvent = match serde_json::from_value(event_json) { + Ok(event) => event, + Err(e) => { + let error_msg = format!("Responses SSE schema error: {}, data: {}", e, &raw); + stats.increment("error:schema"); + stats.log_summary("sse_schema_error"); + error!("{}", error_msg); + let _ = tx_event.send(Err(anyhow!(error_msg))); + return; + } + }; + stats.increment(format!("event:{}", event.kind)); + + match event.kind.as_str() { + "response.output_item.added" => { + // Track tool calls so we can stream arguments via `response.function_call_arguments.delta`. 
+ if let (Some(output_index), Some(item)) = (event.output_index, event.item.as_ref()) + { + if let Some(tc) = InProgressToolCall::from_item_value(item) { + if let Some(ref call_id) = tc.call_id { + tool_call_index_by_id.insert(call_id.clone(), output_index); + } + tool_calls_by_output_index.insert(output_index, tc); + } + } + } + "response.output_text.delta" => { + if let Some(delta) = event.delta.filter(|delta| !delta.is_empty()) { + received_text_delta = true; + let unified_response = UnifiedResponse { + text: Some(delta), + ..Default::default() + }; + stats.record_unified_response(&unified_response); + let _ = tx_event.send(Ok(unified_response)); + } + } + "response.reasoning_text.delta" | "response.reasoning_summary_text.delta" => { + if let Some(delta) = event.delta.filter(|delta| !delta.is_empty()) { + let unified_response = UnifiedResponse { + reasoning_content: Some(delta), + ..Default::default() + }; + stats.record_unified_response(&unified_response); + let _ = tx_event.send(Ok(unified_response)); + } + } + "response.function_call_arguments.delta" => { + let Some(delta) = event.delta.filter(|delta| !delta.is_empty()) else { + continue; + }; + let Some(output_index) = event.output_index else { + continue; + }; + let Some(tc) = tool_calls_by_output_index.get_mut(&output_index) else { + continue; + }; + + tc.saw_any_delta = true; + tc.args_so_far.push_str(&delta); + + // Some consumers treat `id` as a "new tool call" marker and reset buffers when it repeats. + // Only send id/name once per tool call; deltas that follow carry arguments only. 
+ let (id, name) = if tc.sent_header { + (None, None) + } else { + tc.sent_header = true; + (tc.call_id.clone(), tc.name.clone()) + }; + + let unified_response = UnifiedResponse { + tool_call: Some(crate::types::unified::UnifiedToolCall { + id, + name, + arguments: Some(delta), + }), + ..Default::default() + }; + stats.record_unified_response(&unified_response); + let _ = tx_event.send(Ok(unified_response)); + } + "response.output_item.done" => { + let Some(item_value) = event.item else { + continue; + }; + + // For tool calls, prefer streaming deltas and only use item.done as a tail-filler / fallback. + if item_value.get("type").and_then(Value::as_str) == Some("function_call") { + handle_function_call_output_item_done( + &tx_event, + &mut stats, + event.output_index, + item_value, + &mut tool_calls_by_output_index, + &mut tool_call_index_by_id, + ); + continue; + } + + if let Some(mut unified_response) = parse_responses_output_item(item_value) { + if received_text_delta && unified_response.text.is_some() { + unified_response.text = None; + } + if unified_response.text.is_some() || unified_response.tool_call.is_some() { + stats.record_unified_response(&unified_response); + let _ = tx_event.send(Ok(unified_response)); + } + } + } + "response.completed" => { + if received_finish_reason { + continue; + } + // Best-effort: use the final response object to fill any missing tool-call argument tail. 
+ if let Some(response_val) = event.response.as_ref() { + if let Some(output) = response_val.get("output").and_then(Value::as_array) { + for (idx, item) in output.iter().enumerate() { + if item.get("type").and_then(Value::as_str) != Some("function_call") { + continue; + } + let Some(tc) = tool_calls_by_output_index.get_mut(&idx) else { + continue; + }; + let full_args = item + .get("arguments") + .and_then(Value::as_str) + .unwrap_or_default(); + if tc.args_so_far.len() < full_args.len() + && full_args.starts_with(&tc.args_so_far) + { + let delta = full_args[tc.args_so_far.len()..].to_string(); + if !delta.is_empty() { + tc.args_so_far.push_str(&delta); + let (id, name) = if tc.sent_header { + (None, None) + } else { + tc.sent_header = true; + (tc.call_id.clone(), tc.name.clone()) + }; + let unified_response = UnifiedResponse { + tool_call: Some(crate::types::unified::UnifiedToolCall { + id, + name, + arguments: Some(delta), + }), + ..Default::default() + }; + stats.record_unified_response(&unified_response); + let _ = tx_event.send(Ok(unified_response)); + } + } + } + } + } + match event + .response + .map(serde_json::from_value::) + { + Some(Ok(response)) => { + received_finish_reason = true; + let unified_response = UnifiedResponse { + usage: response.usage.map(Into::into), + finish_reason: Some("stop".to_string()), + ..Default::default() + }; + stats.record_unified_response(&unified_response); + let _ = tx_event.send(Ok(unified_response)); + continue; + } + Some(Err(e)) => { + let error_msg = + format!("Failed to parse response.completed payload: {}", e); + stats.increment("error:completed_payload"); + stats.log_summary("response_completed_parse_error"); + error!("{}", error_msg); + let _ = tx_event.send(Err(anyhow!(error_msg))); + return; + } + None => { + received_finish_reason = true; + let unified_response = UnifiedResponse { + finish_reason: Some("stop".to_string()), + ..Default::default() + }; + stats.record_unified_response(&unified_response); + let _ = 
tx_event.send(Ok(unified_response)); + continue; + } + } + } + "response.done" => { + if received_finish_reason { + continue; + } + match event.response.map(serde_json::from_value::) { + Some(Ok(response)) => { + received_finish_reason = true; + let unified_response = UnifiedResponse { + usage: response.usage.map(Into::into), + finish_reason: Some("stop".to_string()), + ..Default::default() + }; + stats.record_unified_response(&unified_response); + let _ = tx_event.send(Ok(unified_response)); + continue; + } + Some(Err(e)) => { + let error_msg = format!("Failed to parse response.done payload: {}", e); + stats.increment("error:done_payload"); + stats.log_summary("response_done_parse_error"); + error!("{}", error_msg); + let _ = tx_event.send(Err(anyhow!(error_msg))); + return; + } + None => { + received_finish_reason = true; + let unified_response = UnifiedResponse { + finish_reason: Some("stop".to_string()), + ..Default::default() + }; + stats.record_unified_response(&unified_response); + let _ = tx_event.send(Ok(unified_response)); + continue; + } + } + } + "response.failed" => { + let error_msg = event + .response + .as_ref() + .and_then(|response| response.get("error")) + .and_then(|error| error.get("message")) + .and_then(Value::as_str) + .unwrap_or("Responses API returned response.failed") + .to_string(); + stats.increment("error:failed"); + stats.log_summary("response_failed"); + error!("{}", error_msg); + let _ = tx_event.send(Err(anyhow!(error_msg))); + return; + } + "response.incomplete" => { + // Prefer returning partial output (rust-genai behavior) instead of hard-failing the round. + // Still mark finish_reason so the caller can decide how to handle it. 
+ if received_finish_reason { + continue; + } + let reason = event + .response + .as_ref() + .and_then(|response| response.get("incomplete_details")) + .and_then(|details| details.get("reason")) + .and_then(Value::as_str) + .map(|s| s.to_string()); + + let finish_reason = reason + .as_deref() + .map(|r| format!("incomplete:{r}")) + .unwrap_or_else(|| "incomplete".to_string()); + + let usage = event + .response + .clone() + .and_then(|v| serde_json::from_value::(v).ok()) + .and_then(|r| r.usage) + .map(Into::into); + + received_finish_reason = true; + let unified_response = UnifiedResponse { + usage, + finish_reason: Some(finish_reason), + ..Default::default() + }; + stats.record_unified_response(&unified_response); + let _ = tx_event.send(Ok(unified_response)); + continue; + } + _ => {} + } + } +} + +#[cfg(test)] +mod tests { + use super::{ + super::stream_stats::StreamStats, + extract_api_error_message, handle_function_call_output_item_done, InProgressToolCall, + }; + use serde_json::json; + use std::collections::HashMap; + use tokio::sync::mpsc; + + #[test] + fn extracts_api_error_message_from_response_error() { + let event = json!({ + "type": "response.failed", + "response": { + "error": { + "message": "provider error" + } + } + }); + + assert_eq!( + extract_api_error_message(&event).as_deref(), + Some("provider error") + ); + } + + #[test] + fn returns_none_when_no_response_error_exists() { + let event = json!({ + "type": "response.created", + "response": { + "id": "resp_1" + } + }); + + assert!(extract_api_error_message(&event).is_none()); + } + + #[test] + fn returns_none_when_response_error_is_null() { + let event = json!({ + "type": "response.created", + "response": { + "id": "resp_1", + "error": null + } + }); + + assert!(extract_api_error_message(&event).is_none()); + } + + #[test] + fn output_item_done_falls_back_when_output_index_is_untracked() { + let (tx_event, mut rx_event) = mpsc::unbounded_channel(); + let mut tool_calls_by_output_index: HashMap = 
HashMap::new(); + let mut tool_call_index_by_id: HashMap = HashMap::new(); + let mut stats = StreamStats::new("Responses"); + + handle_function_call_output_item_done( + &tx_event, + &mut stats, + Some(3), + json!({ + "type": "function_call", + "call_id": "call_1", + "name": "get_weather", + "arguments": "{\"city\":\"Beijing\"}" + }), + &mut tool_calls_by_output_index, + &mut tool_call_index_by_id, + ); + + let response = rx_event + .try_recv() + .expect("tool call event") + .expect("ok response"); + let tool_call = response.tool_call.expect("tool call"); + assert_eq!(tool_call.id.as_deref(), Some("call_1")); + assert_eq!(tool_call.name.as_deref(), Some("get_weather")); + assert_eq!( + tool_call.arguments.as_deref(), + Some("{\"city\":\"Beijing\"}") + ); + } +} diff --git a/BitFun-Installer/src-tauri/crates/installer-ai-stream/stream_handler/stream_stats.rs b/BitFun-Installer/src-tauri/crates/installer-ai-stream/stream_handler/stream_stats.rs new file mode 100644 index 00000000..ecad7abd --- /dev/null +++ b/BitFun-Installer/src-tauri/crates/installer-ai-stream/stream_handler/stream_stats.rs @@ -0,0 +1,148 @@ +use crate::types::unified::UnifiedResponse; +use chrono::{DateTime, Local}; +use log::debug; +use std::collections::BTreeMap; +use std::time::Instant; + +#[derive(Debug)] +pub(super) struct StreamStats { + provider: &'static str, + started_at: Instant, + started_at_wall: DateTime, + first_event_at: Option, + first_event_at_wall: Option>, + last_event_at: Option, + last_event_at_wall: Option>, + total_sse_events: usize, + total_unified_responses: usize, + counters: BTreeMap, +} + +impl StreamStats { + pub(super) fn new(provider: &'static str) -> Self { + Self { + provider, + started_at: Instant::now(), + started_at_wall: Local::now(), + first_event_at: None, + first_event_at_wall: None, + last_event_at: None, + last_event_at_wall: None, + total_sse_events: 0, + total_unified_responses: 0, + counters: BTreeMap::new(), + } + } + + pub(super) fn 
record_sse_event(&mut self, event_kind: impl AsRef) { + let now = Instant::now(); + let now_wall = Local::now(); + if self.first_event_at.is_none() { + self.first_event_at = Some(now); + self.first_event_at_wall = Some(now_wall); + } + self.last_event_at = Some(now); + self.last_event_at_wall = Some(now_wall); + self.total_sse_events += 1; + self.increment(format!("sse:{}", event_kind.as_ref())); + } + + pub(super) fn increment(&mut self, label: impl Into) { + *self.counters.entry(label.into()).or_insert(0) += 1; + } + + pub(super) fn record_unified_response(&mut self, response: &UnifiedResponse) { + self.total_unified_responses += 1; + + let mut classified = false; + + if response.text.is_some() { + self.increment("out:text"); + classified = true; + } + if response.reasoning_content.is_some() { + self.increment("out:reasoning"); + classified = true; + } + if response.tool_call.is_some() { + self.increment("out:tool_call"); + classified = true; + } + if response.usage.is_some() { + self.increment("out:usage"); + classified = true; + } + if response.finish_reason.is_some() { + self.increment("out:finish_reason"); + classified = true; + } + if response.thinking_signature.is_some() { + self.increment("out:thinking_signature"); + classified = true; + } + if response.provider_metadata.is_some() { + self.increment("out:provider_metadata"); + classified = true; + } + + if !classified { + self.increment("out:other"); + } + } + + pub(super) fn log_summary(&self, reason: &str) { + let ended_at_wall = Local::now(); + let wall_elapsed = self.started_at.elapsed(); + let wall_elapsed_ms = wall_elapsed.as_millis(); + let first_event_latency_ms = self + .first_event_at + .map(|instant| instant.duration_since(self.started_at).as_millis()) + .unwrap_or(0); + let receive_elapsed_secs = match (self.first_event_at, self.last_event_at) { + (Some(first), Some(last)) => last.duration_since(first).as_secs_f64(), + _ => 0.0, + }; + let receive_elapsed_ms = (receive_elapsed_secs * 
1000.0).round() as u128; + let unified_response_rate_per_sec = if receive_elapsed_secs > 0.0 { + self.total_unified_responses as f64 / receive_elapsed_secs + } else { + 0.0 + }; + let started_at = self.started_at_wall.format("%Y-%m-%d %H:%M:%S%.3f"); + let first_event_at = self + .first_event_at_wall + .map(|value| value.format("%Y-%m-%d %H:%M:%S%.3f").to_string()) + .unwrap_or_else(|| "none".to_string()); + let last_event_at = self + .last_event_at_wall + .map(|value| value.format("%Y-%m-%d %H:%M:%S%.3f").to_string()) + .unwrap_or_else(|| "none".to_string()); + let ended_at = ended_at_wall.format("%Y-%m-%d %H:%M:%S%.3f"); + let counter_lines = if self.counters.is_empty() { + "counter.none=0".to_string() + } else { + self.counters + .iter() + .map(|(label, count)| format!("counter.{}={}", label, count)) + .collect::>() + .join("\n") + }; + + debug!( + "{} stream stats\nreason={}\nstarted_at={}\nfirst_event_at={}\nlast_event_at={}\nended_at={}\ntotal_sse_events={}\ntotal_unified_responses={}\nfirst_event_latency_ms={}\nreceive_elapsed_ms={}\nwall_elapsed_ms={}\nunified_response_rate_per_sec={:.2}\n{}", + self.provider, + reason, + started_at, + first_event_at, + last_event_at, + ended_at, + self.total_sse_events, + self.total_unified_responses, + first_event_latency_ms, + receive_elapsed_ms, + wall_elapsed_ms, + unified_response_rate_per_sec, + counter_lines + ); + } +} diff --git a/BitFun-Installer/src-tauri/crates/installer-ai-stream/types/anthropic.rs b/BitFun-Installer/src-tauri/crates/installer-ai-stream/types/anthropic.rs new file mode 100644 index 00000000..4f101ab1 --- /dev/null +++ b/BitFun-Installer/src-tauri/crates/installer-ai-stream/types/anthropic.rs @@ -0,0 +1,207 @@ +use super::unified::{UnifiedResponse, UnifiedTokenUsage, UnifiedToolCall}; +use serde::Deserialize; + +#[derive(Debug, Deserialize)] +pub struct MessageStart { + pub message: Message, +} + +#[derive(Debug, Deserialize)] +pub struct Message { + pub usage: Option, +} + +#[derive(Debug, 
Clone, Deserialize)] +pub struct Usage { + input_tokens: Option, + output_tokens: Option, + cache_read_input_tokens: Option, + cache_creation_input_tokens: Option, +} + +impl Default for Usage { + fn default() -> Self { + Self { + input_tokens: None, + output_tokens: None, + cache_read_input_tokens: None, + cache_creation_input_tokens: None, + } + } +} + +impl Usage { + pub fn update(&mut self, other: &Usage) { + if other.input_tokens.is_some() { + self.input_tokens = other.input_tokens; + } + if other.output_tokens.is_some() { + self.output_tokens = other.output_tokens; + } + if other.cache_read_input_tokens.is_some() { + self.cache_read_input_tokens = other.cache_read_input_tokens; + } + if other.cache_creation_input_tokens.is_some() { + self.cache_creation_input_tokens = other.cache_creation_input_tokens; + } + } + + pub fn is_empty(&self) -> bool { + self.input_tokens.is_none() + && self.output_tokens.is_none() + && self.cache_read_input_tokens.is_none() + && self.cache_creation_input_tokens.is_none() + } +} + +impl From for UnifiedTokenUsage { + fn from(value: Usage) -> Self { + let cache_read = value.cache_read_input_tokens.unwrap_or(0); + let cache_creation = value.cache_creation_input_tokens.unwrap_or(0); + let prompt_token_count = value.input_tokens.unwrap_or(0) + cache_read + cache_creation; + let candidates_token_count = value.output_tokens.unwrap_or(0); + Self { + prompt_token_count, + candidates_token_count, + total_token_count: prompt_token_count + candidates_token_count, + reasoning_token_count: None, + cached_content_token_count: match ( + value.cache_read_input_tokens, + value.cache_creation_input_tokens, + ) { + (None, None) => None, + (read, creation) => Some(read.unwrap_or(0) + creation.unwrap_or(0)), + }, + } + } +} + +#[derive(Debug, Deserialize)] +pub struct MessageDelta { + pub delta: MessageDeltaDelta, + pub usage: Option, +} + +#[derive(Debug, Deserialize)] +pub struct MessageDeltaDelta { + pub stop_reason: Option, + pub stop_sequence: 
Option, +} + +impl From for UnifiedResponse { + fn from(value: MessageDelta) -> Self { + Self { + text: None, + reasoning_content: None, + thinking_signature: None, + tool_call: None, + usage: value.usage.map(UnifiedTokenUsage::from), + finish_reason: value.delta.stop_reason, + provider_metadata: None, + } + } +} + +#[derive(Debug, Deserialize)] +pub struct ContentBlockStart { + pub content_block: ContentBlock, +} + +#[derive(Debug, Deserialize)] +#[serde(tag = "type")] +pub enum ContentBlock { + #[serde(rename = "thinking")] + Thinking, + #[serde(rename = "text")] + Text, + #[serde(rename = "tool_use")] + ToolUse { id: String, name: String }, + #[serde(other)] + Unknown, +} + +impl From for UnifiedResponse { + fn from(value: ContentBlockStart) -> Self { + let mut result = UnifiedResponse::default(); + match value.content_block { + ContentBlock::ToolUse { id, name } => { + let tool_call = UnifiedToolCall { + id: Some(id), + name: Some(name), + arguments: None, + }; + result.tool_call = Some(tool_call); + } + _ => {} + } + result + } +} + +#[derive(Debug, Deserialize)] +pub struct ContentBlockDelta { + delta: Delta, +} + +#[derive(Debug, Deserialize)] +#[serde(tag = "type")] +pub enum Delta { + #[serde(rename = "thinking_delta")] + ThinkingDelta { thinking: String }, + #[serde(rename = "text_delta")] + TextDelta { text: String }, + #[serde(rename = "input_json_delta")] + InputJsonDelta { partial_json: String }, + #[serde(rename = "signature_delta")] + SignatureDelta { signature: String }, + #[serde(other)] + Unknown, +} + +impl TryFrom for UnifiedResponse { + type Error = String; + fn try_from(value: ContentBlockDelta) -> Result { + let mut result = UnifiedResponse::default(); + match value.delta { + Delta::ThinkingDelta { thinking } => { + result.reasoning_content = Some(thinking); + } + Delta::TextDelta { text } => { + result.text = Some(text); + } + Delta::InputJsonDelta { partial_json } => { + let tool_call = UnifiedToolCall { + id: None, + name: None, + 
arguments: Some(partial_json), + }; + result.tool_call = Some(tool_call); + } + Delta::SignatureDelta { signature } => { + result.thinking_signature = Some(signature); + } + Delta::Unknown => { + return Err("Unsupported anthropic delta type".to_string()); + } + } + Ok(result) + } +} + +#[derive(Debug, Deserialize)] +pub struct AnthropicSSEError { + pub error: AnthropicSSEErrorDetails, +} + +#[derive(Debug, Deserialize)] +pub struct AnthropicSSEErrorDetails { + #[serde(rename = "type")] + pub error_type: String, + pub message: String, +} + +impl From for String { + fn from(value: AnthropicSSEErrorDetails) -> Self { + format!("{}: {}", value.error_type, value.message) + } +} diff --git a/BitFun-Installer/src-tauri/crates/installer-ai-stream/types/gemini.rs b/BitFun-Installer/src-tauri/crates/installer-ai-stream/types/gemini.rs new file mode 100644 index 00000000..c2e26719 --- /dev/null +++ b/BitFun-Installer/src-tauri/crates/installer-ai-stream/types/gemini.rs @@ -0,0 +1,700 @@ +use crate::types::unified::{UnifiedResponse, UnifiedTokenUsage, UnifiedToolCall}; +use serde::Deserialize; +use serde_json::{json, Value}; + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct GeminiSSEData { + #[serde(default)] + pub candidates: Vec, + #[serde(default)] + pub usage_metadata: Option, + #[serde(default)] + pub prompt_feedback: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct GeminiCandidate { + #[serde(default)] + pub content: Option, + #[serde(default)] + pub finish_reason: Option, + #[serde(default)] + pub grounding_metadata: Option, + #[serde(default)] + pub safety_ratings: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct GeminiContent { + #[serde(default)] + pub parts: Vec, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct GeminiPart { + #[serde(default)] + pub text: Option, + #[serde(default)] + pub thought: Option, + 
#[serde(default)] + pub thought_signature: Option, + #[serde(default)] + pub function_call: Option, + #[serde(default)] + pub executable_code: Option, + #[serde(default)] + pub code_execution_result: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct GeminiFunctionCall { + #[serde(default)] + pub name: Option, + #[serde(default)] + pub args: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct GeminiExecutableCode { + #[serde(default)] + pub language: Option, + #[serde(default)] + pub code: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct GeminiCodeExecutionResult { + #[serde(default)] + pub outcome: Option, + #[serde(default)] + pub output: Option, +} + +#[derive(Debug, Deserialize, Clone)] +#[serde(rename_all = "camelCase")] +pub struct GeminiUsageMetadata { + #[serde(default)] + pub prompt_token_count: u32, + #[serde(default)] + pub candidates_token_count: u32, + #[serde(default)] + pub total_token_count: u32, + #[serde(default)] + pub thoughts_token_count: Option, + #[serde(default)] + pub cached_content_token_count: Option, +} + +impl From for UnifiedTokenUsage { + fn from(usage: GeminiUsageMetadata) -> Self { + let reasoning_token_count = usage.thoughts_token_count; + let candidates_token_count = usage + .candidates_token_count + .saturating_add(reasoning_token_count.unwrap_or(0)); + Self { + prompt_token_count: usage.prompt_token_count, + candidates_token_count, + total_token_count: usage.total_token_count, + reasoning_token_count, + cached_content_token_count: usage.cached_content_token_count, + } + } +} + +impl GeminiSSEData { + fn render_executable_code(executable_code: &GeminiExecutableCode) -> Option { + let code = executable_code.code.as_deref()?.trim(); + if code.is_empty() { + return None; + } + + let language = executable_code + .language + .as_deref() + .map(|language| language.to_ascii_lowercase()) + .unwrap_or_else(|| 
"text".to_string()); + + Some(format!( + "Gemini code execution generated code:\n```{}\n{}\n```", + language, code + )) + } + + fn render_code_execution_result(result: &GeminiCodeExecutionResult) -> Option { + let output = result.output.as_deref()?.trim(); + if output.is_empty() { + return None; + } + + let outcome = result.outcome.as_deref().unwrap_or("OUTCOME_UNKNOWN"); + Some(format!( + "Gemini code execution result ({}):\n{}", + outcome, output + )) + } + + fn grounding_summary(metadata: &Value) -> Option { + let mut lines = Vec::new(); + + let queries = metadata + .get("webSearchQueries") + .and_then(Value::as_array) + .map(|queries| { + queries + .iter() + .filter_map(Value::as_str) + .filter(|query| !query.trim().is_empty()) + .collect::>() + }) + .unwrap_or_default(); + + if !queries.is_empty() { + lines.push(format!("Search queries: {}", queries.join(" | "))); + } + + let sources = metadata + .get("groundingChunks") + .and_then(Value::as_array) + .map(|chunks| { + chunks + .iter() + .filter_map(|chunk| { + let web = chunk.get("web")?; + let uri = web.get("uri").and_then(Value::as_str)?.trim(); + if uri.is_empty() { + return None; + } + let title = web + .get("title") + .and_then(Value::as_str) + .map(str::trim) + .filter(|title| !title.is_empty()) + .unwrap_or(uri); + Some((title.to_string(), uri.to_string())) + }) + .collect::>() + }) + .unwrap_or_default(); + + if !sources.is_empty() { + lines.push("Sources:".to_string()); + for (index, (title, uri)) in sources.into_iter().enumerate() { + lines.push(format!("{}. 
{} - {}", index + 1, title, uri)); + } + } + + let supports = metadata + .get("groundingSupports") + .and_then(Value::as_array) + .map(|supports| { + supports + .iter() + .filter_map(|support| { + let segment_text = support + .get("segment") + .and_then(Value::as_object) + .and_then(|segment| segment.get("text")) + .and_then(Value::as_str) + .map(str::trim) + .filter(|text| !text.is_empty())?; + + let chunk_indices = support + .get("groundingChunkIndices") + .and_then(Value::as_array) + .map(|indices| { + indices + .iter() + .filter_map(Value::as_u64) + .map(|index| (index + 1).to_string()) + .collect::>() + }) + .unwrap_or_default(); + + if chunk_indices.is_empty() { + None + } else { + Some((segment_text.to_string(), chunk_indices.join(", "))) + } + }) + .collect::>() + }) + .unwrap_or_default(); + + if !supports.is_empty() { + lines.push("Citations:".to_string()); + for (segment, indices) in supports.into_iter().take(5) { + lines.push(format!("- \"{}\" -> [{}]", segment, indices)); + } + } + + if lines.is_empty() { + None + } else { + Some(lines.join("\n")) + } + } + + fn safety_summary( + prompt_feedback: Option<&Value>, + safety_ratings: Option<&Value>, + ) -> Option { + let mut lines = Vec::new(); + + if let Some(prompt_feedback) = prompt_feedback { + if let Some(blocked_reason) = prompt_feedback + .get("blockReason") + .and_then(Value::as_str) + .filter(|reason| !reason.trim().is_empty()) + { + lines.push(format!("Prompt blocked reason: {}", blocked_reason)); + } + + if let Some(block_reason_message) = prompt_feedback + .get("blockReasonMessage") + .and_then(Value::as_str) + .filter(|message| !message.trim().is_empty()) + { + lines.push(format!("Prompt block message: {}", block_reason_message)); + } + } + + let ratings = safety_ratings + .and_then(Value::as_array) + .map(|ratings| { + ratings + .iter() + .filter_map(|rating| { + let category = rating.get("category").and_then(Value::as_str)?; + let probability = rating + .get("probability") + 
.and_then(Value::as_str) + .unwrap_or("UNKNOWN"); + let blocked = rating + .get("blocked") + .and_then(Value::as_bool) + .unwrap_or(false); + + if blocked || probability != "NEGLIGIBLE" { + Some(format!( + "{} (probability={}, blocked={})", + category, probability, blocked + )) + } else { + None + } + }) + .collect::>() + }) + .unwrap_or_default(); + + if !ratings.is_empty() { + lines.push("Safety ratings:".to_string()); + lines.extend(ratings.into_iter().map(|rating| format!("- {}", rating))); + } + + if lines.is_empty() { + None + } else { + Some(lines.join("\n")) + } + } + + fn provider_metadata_summary(metadata: &Value) -> Option { + let prompt_feedback = metadata.get("promptFeedback"); + let grounding_metadata = metadata.get("groundingMetadata"); + let safety_ratings = metadata.get("safetyRatings"); + + let mut sections = Vec::new(); + if let Some(safety) = Self::safety_summary(prompt_feedback, safety_ratings) { + sections.push(safety); + } + if let Some(grounding) = grounding_metadata.and_then(Self::grounding_summary) { + sections.push(grounding); + } + + if sections.is_empty() { + None + } else { + Some(sections.join("\n\n")) + } + } + + pub fn into_unified_responses(self) -> Vec { + let mut usage = self.usage_metadata.map(Into::into); + let prompt_feedback = self.prompt_feedback; + let Some(candidate) = self.candidates.into_iter().next() else { + return usage + .take() + .map(|usage| { + vec![UnifiedResponse { + usage: Some(usage), + ..Default::default() + }] + }) + .unwrap_or_default(); + }; + + let mut responses = Vec::new(); + let mut finish_reason = candidate.finish_reason; + let grounding_metadata = candidate.grounding_metadata; + let safety_ratings = candidate.safety_ratings; + + if let Some(content) = candidate.content { + for part in content.parts { + let has_function_call = part.function_call.is_some(); + let text = part.text.filter(|text| !text.is_empty()); + let is_thought = part.thought.unwrap_or(false); + let thinking_signature = 
part.thought_signature.filter(|value| !value.is_empty()); + + if let Some(function_call) = part.function_call { + let arguments = function_call.args.unwrap_or_else(|| json!({})); + responses.push(UnifiedResponse { + text: None, + reasoning_content: None, + thinking_signature, + tool_call: Some(UnifiedToolCall { + id: None, + name: function_call.name, + arguments: serde_json::to_string(&arguments).ok(), + }), + usage: usage.take(), + finish_reason: finish_reason.take(), + provider_metadata: None, + }); + continue; + } + + if let Some(executable_code) = part.executable_code.as_ref() { + if let Some(reasoning_content) = Self::render_executable_code(executable_code) { + responses.push(UnifiedResponse { + text: None, + reasoning_content: Some(reasoning_content), + thinking_signature, + tool_call: None, + usage: usage.take(), + finish_reason: finish_reason.take(), + provider_metadata: None, + }); + continue; + } + } + + if let Some(code_execution_result) = part.code_execution_result.as_ref() { + if let Some(reasoning_content) = + Self::render_code_execution_result(code_execution_result) + { + responses.push(UnifiedResponse { + text: None, + reasoning_content: Some(reasoning_content), + thinking_signature, + tool_call: None, + usage: usage.take(), + finish_reason: finish_reason.take(), + provider_metadata: None, + }); + continue; + } + } + + if let Some(text) = text { + responses.push(UnifiedResponse { + text: if is_thought { None } else { Some(text.clone()) }, + reasoning_content: if is_thought { Some(text) } else { None }, + thinking_signature, + tool_call: None, + usage: usage.take(), + finish_reason: finish_reason.take(), + provider_metadata: None, + }); + continue; + } + + if thinking_signature.is_some() && !has_function_call { + responses.push(UnifiedResponse { + text: None, + reasoning_content: None, + thinking_signature, + tool_call: None, + usage: usage.take(), + finish_reason: finish_reason.take(), + provider_metadata: None, + }); + } + } + } + + let 
provider_metadata = { + let mut metadata = serde_json::Map::new(); + if let Some(prompt_feedback) = prompt_feedback { + metadata.insert("promptFeedback".to_string(), prompt_feedback); + } + if let Some(grounding_metadata) = grounding_metadata { + metadata.insert("groundingMetadata".to_string(), grounding_metadata); + } + if let Some(safety_ratings) = safety_ratings { + metadata.insert("safetyRatings".to_string(), safety_ratings); + } + + if metadata.is_empty() { + None + } else { + Some(Value::Object(metadata)) + } + }; + + if let Some(provider_metadata) = provider_metadata { + let summary = Self::provider_metadata_summary(&provider_metadata); + responses.push(UnifiedResponse { + text: summary, + reasoning_content: None, + thinking_signature: None, + tool_call: None, + usage: usage.take(), + finish_reason: finish_reason.take(), + provider_metadata: Some(provider_metadata), + }); + } + + if responses.is_empty() { + responses.push(UnifiedResponse { + usage, + finish_reason, + ..Default::default() + }); + } + + responses + } +} + +#[cfg(test)] +mod tests { + use super::GeminiSSEData; + + #[test] + fn converts_text_thought_and_usage() { + let payload = serde_json::json!({ + "candidates": [{ + "content": { + "parts": [ + { "text": "thinking", "thought": true, "thoughtSignature": "sig_1" }, + { "text": "answer" } + ] + }, + "finishReason": "STOP" + }], + "usageMetadata": { + "promptTokenCount": 10, + "candidatesTokenCount": 4, + "thoughtsTokenCount": 2, + "totalTokenCount": 14 + } + }); + + let data: GeminiSSEData = serde_json::from_value(payload).expect("gemini payload"); + let responses = data.into_unified_responses(); + + assert_eq!(responses.len(), 2); + assert_eq!(responses[0].reasoning_content.as_deref(), Some("thinking")); + assert_eq!(responses[0].thinking_signature.as_deref(), Some("sig_1")); + assert_eq!( + responses[0] + .usage + .as_ref() + .and_then(|usage| usage.reasoning_token_count), + Some(2) + ); + assert_eq!( + responses[0] + .usage + .as_ref() + 
.map(|usage| usage.candidates_token_count), + Some(6) + ); + assert_eq!( + responses[0] + .usage + .as_ref() + .map(|usage| usage.total_token_count), + Some(14) + ); + assert_eq!(responses[1].text.as_deref(), Some("answer")); + } + + #[test] + fn keeps_thought_signature_on_function_call_parts() { + let payload = serde_json::json!({ + "candidates": [{ + "content": { + "parts": [ + { + "thoughtSignature": "sig_tool", + "functionCall": { + "name": "get_weather", + "args": { "city": "Paris" } + } + } + ] + } + }] + }); + + let data: GeminiSSEData = serde_json::from_value(payload).expect("gemini payload"); + let responses = data.into_unified_responses(); + + assert_eq!(responses.len(), 1); + assert_eq!(responses[0].thinking_signature.as_deref(), Some("sig_tool")); + assert_eq!( + responses[0] + .tool_call + .as_ref() + .and_then(|tool_call| tool_call.name.as_deref()), + Some("get_weather") + ); + } + + #[test] + fn keeps_standalone_thought_signature_parts() { + let payload = serde_json::json!({ + "candidates": [{ + "content": { + "parts": [ + { "thoughtSignature": "sig_only" } + ] + } + }] + }); + + let data: GeminiSSEData = serde_json::from_value(payload).expect("gemini payload"); + let responses = data.into_unified_responses(); + + assert_eq!(responses.len(), 1); + assert_eq!(responses[0].thinking_signature.as_deref(), Some("sig_only")); + assert!(responses[0].tool_call.is_none()); + assert!(responses[0].text.is_none()); + assert!(responses[0].reasoning_content.is_none()); + } + + #[test] + fn converts_code_execution_parts_to_reasoning_chunks() { + let payload = serde_json::json!({ + "candidates": [{ + "content": { + "parts": [ + { + "executableCode": { + "language": "PYTHON", + "code": "print(1 + 1)" + } + }, + { + "codeExecutionResult": { + "outcome": "OUTCOME_OK", + "output": "2" + } + } + ] + } + }] + }); + + let data: GeminiSSEData = serde_json::from_value(payload).expect("gemini payload"); + let responses = data.into_unified_responses(); + + 
assert_eq!(responses.len(), 2); + assert!(responses[0] + .reasoning_content + .as_deref() + .is_some_and(|text| text.contains("print(1 + 1)"))); + assert!(responses[1] + .reasoning_content + .as_deref() + .is_some_and(|text| text.contains("OUTCOME_OK") && text.contains("2"))); + } + + #[test] + fn emits_grounding_summary_and_provider_metadata() { + let payload = serde_json::json!({ + "candidates": [{ + "content": { + "parts": [ + { "text": "answer" } + ] + }, + "groundingMetadata": { + "webSearchQueries": ["latest rust release"], + "groundingChunks": [ + { + "web": { + "uri": "https://www.rust-lang.org", + "title": "Rust" + } + } + ] + } + }] + }); + + let data: GeminiSSEData = serde_json::from_value(payload).expect("gemini payload"); + let responses = data.into_unified_responses(); + + assert_eq!(responses.len(), 2); + assert_eq!(responses[0].text.as_deref(), Some("answer")); + assert!(responses[1] + .text + .as_deref() + .is_some_and(|text| text.contains("Sources:") && text.contains("rust-lang.org"))); + assert!(responses[1] + .provider_metadata + .as_ref() + .and_then(|metadata| metadata.get("groundingMetadata")) + .is_some()); + } + + #[test] + fn emits_prompt_feedback_and_safety_summary() { + let payload = serde_json::json!({ + "candidates": [{ + "content": { "parts": [] }, + "finishReason": "SAFETY", + "safetyRatings": [ + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "MEDIUM", + "blocked": true + } + ] + }], + "promptFeedback": { + "blockReason": "SAFETY", + "blockReasonMessage": "Blocked by safety system" + } + }); + + let data: GeminiSSEData = serde_json::from_value(payload).expect("gemini payload"); + let responses = data.into_unified_responses(); + + assert_eq!(responses.len(), 1); + assert_eq!(responses[0].finish_reason.as_deref(), Some("SAFETY")); + assert!(responses[0] + .text + .as_deref() + .is_some_and(|text| text.contains("Prompt blocked reason: SAFETY"))); + assert!(responses[0] + .text + .as_deref() + .is_some_and(|text| 
text.contains("HARM_CATEGORY_DANGEROUS_CONTENT")));
+        assert!(responses[0]
+            .provider_metadata
+            .as_ref()
+            .and_then(|metadata| metadata.get("promptFeedback"))
+            .is_some());
+    }
+}
diff --git a/BitFun-Installer/src-tauri/crates/installer-ai-stream/types/mod.rs b/BitFun-Installer/src-tauri/crates/installer-ai-stream/types/mod.rs
new file mode 100644
index 00000000..39693a3a
--- /dev/null
+++ b/BitFun-Installer/src-tauri/crates/installer-ai-stream/types/mod.rs
@@ -0,0 +1,5 @@
+pub mod anthropic;
+pub mod gemini;
+pub mod openai;
+pub mod responses;
+pub mod unified;
diff --git a/BitFun-Installer/src-tauri/crates/installer-ai-stream/types/openai.rs b/BitFun-Installer/src-tauri/crates/installer-ai-stream/types/openai.rs
new file mode 100644
index 00000000..ed2bdedc
--- /dev/null
+++ b/BitFun-Installer/src-tauri/crates/installer-ai-stream/types/openai.rs
@@ -0,0 +1,388 @@
+use super::unified::{UnifiedResponse, UnifiedTokenUsage, UnifiedToolCall};
+use serde::Deserialize;
+
+#[derive(Debug, Deserialize)]
+struct PromptTokensDetails {
+    cached_tokens: Option<u32>,
+}
+
+#[derive(Debug, Deserialize)]
+struct OpenAIUsage {
+    #[serde(default)]
+    prompt_tokens: u32,
+    #[serde(default)]
+    completion_tokens: u32,
+    #[serde(default)]
+    total_tokens: u32,
+    prompt_tokens_details: Option<PromptTokensDetails>,
+}
+
+impl From<OpenAIUsage> for UnifiedTokenUsage {
+    fn from(usage: OpenAIUsage) -> Self {
+        Self {
+            prompt_token_count: usage.prompt_tokens,
+            candidates_token_count: usage.completion_tokens,
+            total_token_count: usage.total_tokens,
+            reasoning_token_count: None,
+            cached_content_token_count: usage
+                .prompt_tokens_details
+                .and_then(|prompt_tokens_details| prompt_tokens_details.cached_tokens),
+        }
+    }
+}
+
+#[derive(Debug, Deserialize)]
+struct Choice {
+    #[allow(dead_code)]
+    index: usize,
+    delta: Delta,
+    finish_reason: Option<String>,
+}
+
+/// MiniMax `reasoning_details` array element.
+/// Only elements with `type == "reasoning.text"` carry thinking text. 
+#[derive(Debug, Deserialize)]
+struct ReasoningDetail {
+    #[serde(rename = "type")]
+    detail_type: Option<String>,
+    text: Option<String>,
+}
+
+#[derive(Debug, Deserialize)]
+struct Delta {
+    #[allow(dead_code)]
+    role: Option<String>,
+    /// Standard OpenAI-compatible reasoning field (DeepSeek, Qwen, etc.)
+    reasoning_content: Option<String>,
+    /// MiniMax-specific reasoning field; used as fallback when `reasoning_content` is absent.
+    reasoning_details: Option<Vec<ReasoningDetail>>,
+    content: Option<String>,
+    tool_calls: Option<Vec<OpenAIToolCall>>,
+}
+
+#[derive(Debug, Deserialize, Clone)]
+struct OpenAIToolCall {
+    #[allow(dead_code)]
+    index: usize,
+    #[allow(dead_code)]
+    id: Option<String>,
+    #[allow(dead_code)]
+    #[serde(rename = "type")]
+    tool_type: Option<String>,
+    function: Option<FunctionCall>,
+}
+
+impl From<OpenAIToolCall> for UnifiedToolCall {
+    fn from(tool_call: OpenAIToolCall) -> Self {
+        Self {
+            id: tool_call.id,
+            name: tool_call.function.as_ref().and_then(|f| f.name.clone()),
+            arguments: tool_call
+                .function
+                .as_ref()
+                .and_then(|f| f.arguments.clone()),
+        }
+    }
+}
+
+#[derive(Debug, Deserialize, Clone)]
+struct FunctionCall {
+    name: Option<String>,
+    arguments: Option<String>,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct OpenAISSEData {
+    #[allow(dead_code)]
+    id: String,
+    #[allow(dead_code)]
+    created: u64,
+    #[allow(dead_code)]
+    model: String,
+    choices: Vec<Choice>,
+    usage: Option<OpenAIUsage>,
+}
+
+impl OpenAISSEData {
+    pub fn is_choices_empty(&self) -> bool {
+        self.choices.is_empty()
+    }
+
+    pub fn first_choice_tool_call_count(&self) -> usize {
+        self.choices
+            .first()
+            .and_then(|choice| choice.delta.tool_calls.as_ref())
+            .map(|tool_calls| tool_calls.len())
+            .unwrap_or(0)
+    }
+
+    pub fn into_unified_responses(self) -> Vec<UnifiedResponse> {
+        let mut usage = self.usage.map(|usage| usage.into());
+
+        let Some(first_choice) = self.choices.into_iter().next() else {
+            // OpenAI can emit `choices: []` for the final usage chunk. 
+ return usage + .map(|usage_data| { + vec![UnifiedResponse { + usage: Some(usage_data), + ..Default::default() + }] + }) + .unwrap_or_default(); + }; + + let Choice { + delta, + finish_reason, + .. + } = first_choice; + let mut finish_reason = finish_reason; + let Delta { + reasoning_content, + reasoning_details, + content, + tool_calls, + .. + } = delta; + + // Treat empty strings the same as absent fields (MiniMax sends `content: ""` in + // reasoning-only chunks). + let content = content.filter(|s| !s.is_empty()); + let reasoning_content = reasoning_content.filter(|s| !s.is_empty()); + + // MiniMax uses `reasoning_details` instead of `reasoning_content`. + // Collect all "reasoning.text" entries and join them as a fallback. + let reasoning_content = reasoning_content.or_else(|| { + reasoning_details.and_then(|details| { + let text: String = details + .into_iter() + .filter(|d| d.detail_type.as_deref() == Some("reasoning.text")) + .filter_map(|d| d.text) + .collect(); + if text.is_empty() { + None + } else { + Some(text) + } + }) + }); + + let mut responses = Vec::new(); + + if content.is_some() || reasoning_content.is_some() { + responses.push(UnifiedResponse { + text: content, + reasoning_content, + thinking_signature: None, + tool_call: None, + usage: usage.take(), + finish_reason: finish_reason.take(), + provider_metadata: None, + }); + } + + if let Some(tool_calls) = tool_calls { + for tool_call in tool_calls { + let is_first_event = responses.is_empty(); + responses.push(UnifiedResponse { + text: None, + reasoning_content: None, + thinking_signature: None, + tool_call: Some(UnifiedToolCall::from(tool_call)), + usage: if is_first_event { usage.take() } else { None }, + finish_reason: if is_first_event { + finish_reason.take() + } else { + None + }, + provider_metadata: None, + }); + } + } + + if responses.is_empty() { + responses.push(UnifiedResponse { + text: None, + reasoning_content: None, + thinking_signature: None, + tool_call: None, + usage, + 
finish_reason, + provider_metadata: None, + }); + } + + responses + } +} + +impl From for UnifiedResponse { + fn from(data: OpenAISSEData) -> Self { + data.into_unified_responses() + .into_iter() + .next() + .unwrap_or_default() + } +} + +#[cfg(test)] +mod tests { + use super::OpenAISSEData; + + #[test] + fn splits_multiple_tool_calls_in_first_choice() { + let raw = r#"{ + "id": "chatcmpl_test", + "created": 123, + "model": "gpt-test", + "choices": [{ + "index": 0, + "delta": { + "tool_calls": [ + { + "index": 0, + "id": "call_1", + "type": "function", + "function": { + "name": "tool_a", + "arguments": "{\"a\":1}" + } + }, + { + "index": 1, + "id": "call_2", + "type": "function", + "function": { + "name": "tool_b", + "arguments": "{\"b\":2}" + } + } + ] + }, + "finish_reason": "tool_calls" + }], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15, + "prompt_tokens_details": { + "cached_tokens": 3 + } + } + }"#; + + let sse_data: OpenAISSEData = serde_json::from_str(raw).expect("valid openai sse data"); + let responses = sse_data.into_unified_responses(); + + assert_eq!(responses.len(), 2); + assert_eq!( + responses[0] + .tool_call + .as_ref() + .and_then(|tool| tool.id.as_deref()), + Some("call_1") + ); + assert_eq!( + responses[1] + .tool_call + .as_ref() + .and_then(|tool| tool.id.as_deref()), + Some("call_2") + ); + assert_eq!(responses[0].finish_reason.as_deref(), Some("tool_calls")); + assert!(responses[1].finish_reason.is_none()); + assert!(responses[0].usage.is_some()); + assert!(responses[1].usage.is_none()); + } + + #[test] + fn handles_empty_choices_with_usage_chunk() { + let raw = r#"{ + "id": "chatcmpl_test", + "created": 123, + "model": "gpt-test", + "choices": [], + "usage": { + "prompt_tokens": 7, + "completion_tokens": 3, + "total_tokens": 10 + } + }"#; + + let sse_data: OpenAISSEData = serde_json::from_str(raw).expect("valid openai sse data"); + let responses = sse_data.into_unified_responses(); + + 
assert_eq!(responses.len(), 1); + assert!(responses[0].usage.is_some()); + assert!(responses[0].text.is_none()); + assert!(responses[0].tool_call.is_none()); + } + + #[test] + fn handles_empty_choices_without_usage_chunk() { + let raw = r#"{ + "id": "chatcmpl_test", + "created": 123, + "model": "gpt-test", + "choices": [], + "usage": null + }"#; + + let sse_data: OpenAISSEData = serde_json::from_str(raw).expect("valid openai sse data"); + let responses = sse_data.into_unified_responses(); + + assert!(responses.is_empty()); + } + + #[test] + fn preserves_text_when_tool_calls_exist_in_same_chunk() { + let raw = r#"{ + "id": "chatcmpl_test", + "created": 123, + "model": "gpt-test", + "choices": [{ + "index": 0, + "delta": { + "content": "hello", + "tool_calls": [ + { + "index": 0, + "id": "call_1", + "type": "function", + "function": { + "name": "tool_a", + "arguments": "{\"a\":1}" + } + } + ] + }, + "finish_reason": "tool_calls" + }], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15 + } + }"#; + + let sse_data: OpenAISSEData = serde_json::from_str(raw).expect("valid openai sse data"); + let responses = sse_data.into_unified_responses(); + + assert_eq!(responses.len(), 2); + assert_eq!(responses[0].text.as_deref(), Some("hello")); + assert!(responses[0].tool_call.is_none()); + assert!(responses[0].usage.is_some()); + assert_eq!(responses[0].finish_reason.as_deref(), Some("tool_calls")); + + assert!(responses[1].text.is_none()); + assert_eq!( + responses[1] + .tool_call + .as_ref() + .and_then(|tool| tool.id.as_deref()), + Some("call_1") + ); + assert!(responses[1].usage.is_none()); + assert!(responses[1].finish_reason.is_none()); + } +} diff --git a/BitFun-Installer/src-tauri/crates/installer-ai-stream/types/responses.rs b/BitFun-Installer/src-tauri/crates/installer-ai-stream/types/responses.rs new file mode 100644 index 00000000..a12ef79b --- /dev/null +++ b/BitFun-Installer/src-tauri/crates/installer-ai-stream/types/responses.rs @@ 
-0,0 +1,208 @@ +use super::unified::{UnifiedResponse, UnifiedTokenUsage, UnifiedToolCall}; +use serde::Deserialize; +use serde_json::Value; + +#[derive(Debug, Deserialize)] +pub struct ResponsesStreamEvent { + #[serde(rename = "type")] + pub kind: String, + /// Output item index in the `response.output` array. + #[serde(default)] + pub output_index: Option, + /// Content part index within an output item (for content-part events). + #[allow(dead_code)] + #[serde(default)] + pub content_index: Option, + #[serde(default)] + pub response: Option, + #[serde(default)] + pub item: Option, + #[serde(default)] + pub delta: Option, +} + +#[derive(Debug, Deserialize)] +pub struct ResponsesCompleted { + #[allow(dead_code)] + pub id: String, + #[serde(default)] + pub usage: Option, +} + +#[derive(Debug, Deserialize)] +pub struct ResponsesDone { + #[serde(default)] + #[allow(dead_code)] + pub id: Option, + #[serde(default)] + pub usage: Option, +} + +#[derive(Debug, Deserialize)] +pub struct ResponsesUsage { + pub input_tokens: u32, + #[serde(default)] + pub input_tokens_details: Option, + pub output_tokens: u32, + pub total_tokens: u32, +} + +#[derive(Debug, Deserialize)] +pub struct ResponsesInputTokensDetails { + pub cached_tokens: u32, +} + +impl From for UnifiedTokenUsage { + fn from(usage: ResponsesUsage) -> Self { + Self { + prompt_token_count: usage.input_tokens, + candidates_token_count: usage.output_tokens, + total_token_count: usage.total_tokens, + reasoning_token_count: None, + cached_content_token_count: usage + .input_tokens_details + .map(|details| details.cached_tokens), + } + } +} + +pub fn parse_responses_output_item(item_value: Value) -> Option { + let item_type = item_value.get("type")?.as_str()?; + + match item_type { + "function_call" => Some(UnifiedResponse { + text: None, + reasoning_content: None, + thinking_signature: None, + tool_call: Some(UnifiedToolCall { + id: item_value + .get("call_id") + .and_then(Value::as_str) + .map(ToString::to_string), + 
name: item_value + .get("name") + .and_then(Value::as_str) + .map(ToString::to_string), + arguments: item_value + .get("arguments") + .and_then(Value::as_str) + .map(ToString::to_string), + }), + usage: None, + finish_reason: None, + provider_metadata: None, + }), + "message" => { + let text = item_value + .get("content") + .and_then(Value::as_array) + .map(|content| { + content + .iter() + .filter(|item| { + item.get("type").and_then(Value::as_str) == Some("output_text") + }) + .filter_map(|item| item.get("text").and_then(Value::as_str)) + .collect::() + }) + .filter(|text| !text.is_empty()); + + text.map(|text| UnifiedResponse { + text: Some(text), + reasoning_content: None, + thinking_signature: None, + tool_call: None, + usage: None, + finish_reason: None, + provider_metadata: None, + }) + } + _ => None, + } +} + +#[cfg(test)] +mod tests { + use super::{parse_responses_output_item, ResponsesCompleted, ResponsesStreamEvent}; + use serde_json::json; + + #[test] + fn parses_output_text_message_item() { + let response = parse_responses_output_item(json!({ + "type": "message", + "role": "assistant", + "content": [ + { + "type": "output_text", + "text": "hello" + } + ] + })) + .expect("message item"); + + assert_eq!(response.text.as_deref(), Some("hello")); + } + + #[test] + fn parses_function_call_item() { + let response = parse_responses_output_item(json!({ + "type": "function_call", + "call_id": "call_1", + "name": "get_weather", + "arguments": "{\"city\":\"Beijing\"}" + })) + .expect("function call item"); + + let tool_call = response.tool_call.expect("tool call"); + assert_eq!(tool_call.id.as_deref(), Some("call_1")); + assert_eq!(tool_call.name.as_deref(), Some("get_weather")); + } + + #[test] + fn parses_completed_payload_usage() { + let event: ResponsesStreamEvent = serde_json::from_value(json!({ + "type": "response.completed", + "response": { + "id": "resp_1", + "usage": { + "input_tokens": 10, + "input_tokens_details": { "cached_tokens": 2 }, + 
"output_tokens": 4, + "total_tokens": 14 + } + } + })) + .expect("event"); + + let completed: ResponsesCompleted = + serde_json::from_value(event.response.expect("response")).expect("completed"); + assert_eq!(completed.id, "resp_1"); + assert_eq!(completed.usage.expect("usage").total_tokens, 14); + } + + #[test] + fn parses_output_item_added_indices() { + let event: ResponsesStreamEvent = serde_json::from_value(json!({ + "type": "response.output_item.added", + "output_index": 3, + "item": { "type": "function_call", "call_id": "call_1", "name": "tool", "arguments": "" } + })) + .expect("event"); + + assert_eq!(event.output_index, Some(3)); + assert!(event.item.is_some()); + } + + #[test] + fn parses_function_call_arguments_delta_indices() { + let event: ResponsesStreamEvent = serde_json::from_value(json!({ + "type": "response.function_call_arguments.delta", + "output_index": 1, + "delta": "{\"a\":" + })) + .expect("event"); + + assert_eq!(event.output_index, Some(1)); + assert_eq!(event.delta.as_deref(), Some("{\"a\":")); + } +} diff --git a/BitFun-Installer/src-tauri/crates/installer-ai-stream/types/unified.rs b/BitFun-Installer/src-tauri/crates/installer-ai-stream/types/unified.rs new file mode 100644 index 00000000..309a3501 --- /dev/null +++ b/BitFun-Installer/src-tauri/crates/installer-ai-stream/types/unified.rs @@ -0,0 +1,50 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UnifiedToolCall { + pub id: Option, + pub name: Option, + pub arguments: Option, +} + +/// Unified AI response format +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UnifiedResponse { + pub text: Option, + pub reasoning_content: Option, + /// Signature for Anthropic extended thinking (returned in multi-turn conversations) + #[serde(skip_serializing_if = "Option::is_none")] + pub thinking_signature: Option, + pub tool_call: Option, + pub usage: Option, + pub finish_reason: Option, + 
#[serde(skip_serializing_if = "Option::is_none")]
+    pub provider_metadata: Option<Value>,
+}
+
+impl Default for UnifiedResponse {
+    fn default() -> Self {
+        Self {
+            text: None,
+            reasoning_content: None,
+            thinking_signature: None,
+            tool_call: None,
+            usage: None,
+            finish_reason: None,
+            provider_metadata: None,
+        }
+    }
+}
+
+/// Unified token usage statistics
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct UnifiedTokenUsage {
+    pub prompt_token_count: u32,
+    pub candidates_token_count: u32,
+    pub total_token_count: u32,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub reasoning_token_count: Option<u32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub cached_content_token_count: Option<u32>,
+}
diff --git a/BitFun-Installer/src-tauri/src/connection_test/client.rs b/BitFun-Installer/src-tauri/src/connection_test/client.rs
new file mode 100644
index 00000000..34503557
--- /dev/null
+++ b/BitFun-Installer/src-tauri/src/connection_test/client.rs
@@ -0,0 +1,2307 @@
+//! AI client implementation - refactored version
+//!
+//! 
Uses a modular architecture to separate provider-specific logic into the providers module + +use crate::connection_test::json_checker::JsonChecker; +use crate::connection_test::providers::anthropic::AnthropicMessageConverter; +use crate::connection_test::providers::gemini::GeminiMessageConverter; +use crate::connection_test::providers::openai::OpenAIMessageConverter; +use crate::connection_test::proxy::ProxyConfig; +use crate::connection_test::types::*; +use installer_ai_stream::{ + handle_anthropic_stream, handle_gemini_stream, handle_openai_stream, handle_responses_stream, + UnifiedResponse, +}; +use anyhow::{anyhow, Result}; +use futures::StreamExt; +use log::{debug, error, info, warn}; +use reqwest::{Client, Proxy}; +use serde::Deserialize; +use std::collections::HashMap; +use tokio::sync::mpsc; + +/// Streamed response result with the parsed stream and optional raw SSE receiver +pub struct StreamResponse { + /// Parsed response stream + pub stream: std::pin::Pin> + Send>>, + /// Raw SSE receiver (for error diagnostics) + pub raw_sse_rx: Option>, +} + +#[derive(Debug, Clone)] +pub struct AIClient { + client: Client, + pub config: AIConfig, +} + +#[derive(Debug, Deserialize)] +struct OpenAIModelsResponse { + data: Vec, +} + +#[derive(Debug, Deserialize)] +struct OpenAIModelEntry { + id: String, +} + +#[derive(Debug, Deserialize)] +struct AnthropicModelsResponse { + data: Vec, +} + +#[derive(Debug, Deserialize)] +struct AnthropicModelEntry { + id: String, + #[serde(default)] + display_name: Option, +} + +#[derive(Debug, Deserialize)] +struct GeminiModelsResponse { + #[serde(default)] + models: Vec, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +struct GeminiModelEntry { + name: String, + #[serde(default)] + display_name: Option, + #[serde(default, deserialize_with = "deserialize_null_as_default")] + supported_generation_methods: Vec, +} + +fn deserialize_null_as_default<'de, D, T>(deserializer: D) -> std::result::Result +where + D: 
serde::Deserializer<'de>, + T: Default + serde::Deserialize<'de>, +{ + Option::::deserialize(deserializer).map(|v| v.unwrap_or_default()) +} + +impl AIClient { + const TEST_IMAGE_EXPECTED_CODE: &'static str = "BYGR"; + const TEST_IMAGE_PNG_BASE64: &'static str = + "iVBORw0KGgoAAAANSUhEUgAAAQAAAAEACAIAAADTED8xAAACBklEQVR42u3ZsREAIAwDMYf9dw4txwJupI7Wua+YZEPBfO91h4ZjAgQAAgABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABIAAQAAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQAAgABIAAQAAgABAACAAEAAIAAYAAQAAgABAACAAEAAIAAYAAQAAgABAAAAAAAEDRZI3QGf7jDvEPAAIAAYAAQAAgABAACAAEAAIAAYAAQAAgABAACAAEAAIAAYAAQAAgABAACAABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABgABAACAAEAAIAAQAAgABgABAAAjABAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQAAgABAACAAGAAEAAIAAQALwuLkoG8OSfau4AAAAASUVORK5CYII="; + const STREAM_CONNECT_TIMEOUT_SECS: u64 = 10; + const HTTP_POOL_IDLE_TIMEOUT_SECS: u64 = 30; + const HTTP_TCP_KEEPALIVE_SECS: u64 = 60; + + fn image_test_response_matches_expected(response: &str) -> bool { + let upper = response.to_ascii_uppercase(); + + // Accept contiguous letters even when separated by spaces/punctuation. + let letters_only: String = upper.chars().filter(|c| c.is_ascii_alphabetic()).collect(); + if letters_only.contains(Self::TEST_IMAGE_EXPECTED_CODE) { + return true; + } + + let tokens: Vec<&str> = upper + .split(|c: char| !c.is_ascii_alphabetic()) + .filter(|s| !s.is_empty()) + .collect(); + + if tokens + .iter() + .any(|token| *token == Self::TEST_IMAGE_EXPECTED_CODE) + { + return true; + } + + // Accept outputs like: "B Y G R". 
+ let single_letter_stream: String = tokens + .iter() + .filter_map(|token| { + if token.len() == 1 { + let ch = token.chars().next()?; + if matches!(ch, 'R' | 'G' | 'B' | 'Y') { + return Some(ch); + } + } + None + }) + .collect(); + if single_letter_stream.contains(Self::TEST_IMAGE_EXPECTED_CODE) { + return true; + } + + // Accept outputs like: "Blue, Yellow, Green, Red". + let color_word_stream: String = tokens + .iter() + .filter_map(|token| match *token { + "RED" => Some('R'), + "GREEN" => Some('G'), + "BLUE" => Some('B'), + "YELLOW" => Some('Y'), + _ => None, + }) + .collect(); + if color_word_stream.contains(Self::TEST_IMAGE_EXPECTED_CODE) { + return true; + } + + // Last fallback: keep only RGBY letters and search code. + let color_letter_stream: String = upper + .chars() + .filter(|c| matches!(*c, 'R' | 'G' | 'B' | 'Y')) + .collect(); + color_letter_stream.contains(Self::TEST_IMAGE_EXPECTED_CODE) + } + + fn is_responses_api_format(api_format: &str) -> bool { + matches!( + api_format.to_ascii_lowercase().as_str(), + "response" | "responses" + ) + } + + fn is_gemini_api_format(api_format: &str) -> bool { + matches!( + api_format.to_ascii_lowercase().as_str(), + "gemini" | "google" + ) + } + + fn normalize_base_url_for_discovery(base_url: &str) -> String { + base_url + .trim() + .trim_end_matches('#') + .trim_end_matches('/') + .to_string() + } + + fn resolve_openai_models_url(&self) -> String { + let mut base = Self::normalize_base_url_for_discovery(&self.config.base_url); + + for suffix in ["/chat/completions", "/responses", "/models"] { + if base.ends_with(suffix) { + base.truncate(base.len() - suffix.len()); + break; + } + } + + if base.is_empty() { + return "models".to_string(); + } + + format!("{}/models", base) + } + + fn resolve_anthropic_models_url(&self) -> String { + let mut base = Self::normalize_base_url_for_discovery(&self.config.base_url); + + if base.ends_with("/v1/messages") { + base.truncate(base.len() - "/v1/messages".len()); + return 
format!("{}/v1/models", base); + } + + if base.ends_with("/v1/models") { + return base; + } + + if base.ends_with("/v1") { + return format!("{}/models", base); + } + + if base.is_empty() { + return "v1/models".to_string(); + } + + format!("{}/v1/models", base) + } + + fn dedupe_remote_models(models: Vec) -> Vec { + let mut seen = std::collections::HashSet::new(); + let mut deduped = Vec::new(); + + for model in models { + if seen.insert(model.id.clone()) { + deduped.push(model); + } + } + + deduped + } + + async fn list_openai_models(&self) -> Result> { + let url = self.resolve_openai_models_url(); + let response = self + .apply_openai_headers(self.client.get(&url)) + .send() + .await? + .error_for_status()?; + + let payload: OpenAIModelsResponse = response.json().await?; + Ok(Self::dedupe_remote_models( + payload + .data + .into_iter() + .map(|model| RemoteModelInfo { + id: model.id, + display_name: None, + }) + .collect(), + )) + } + + async fn list_anthropic_models(&self) -> Result> { + let url = self.resolve_anthropic_models_url(); + let response = self + .apply_anthropic_headers(self.client.get(&url), &url) + .send() + .await? + .error_for_status()?; + + let payload: AnthropicModelsResponse = response.json().await?; + Ok(Self::dedupe_remote_models( + payload + .data + .into_iter() + .map(|model| RemoteModelInfo { + id: model.id, + display_name: model.display_name, + }) + .collect(), + )) + } + + fn resolve_gemini_models_url(&self) -> String { + let base = Self::normalize_base_url_for_discovery(&self.config.base_url); + let base = Self::gemini_base_url(&base); + format!("{}/v1beta/models", base) + } + + async fn list_gemini_models(&self) -> Result> { + let url = self.resolve_gemini_models_url(); + debug!("Gemini models list URL: {}", url); + + let response = self + .apply_gemini_headers(self.client.get(&url)) + .send() + .await? 
+ .error_for_status()?; + + let payload: GeminiModelsResponse = response.json().await?; + Ok(Self::dedupe_remote_models( + payload + .models + .into_iter() + .filter(|m| { + m.supported_generation_methods.is_empty() + || m.supported_generation_methods + .iter() + .any(|method| method == "generateContent") + }) + .map(|model| { + let id = model + .name + .strip_prefix("models/") + .unwrap_or(&model.name) + .to_string(); + RemoteModelInfo { + id, + display_name: model.display_name, + } + }) + .collect(), + )) + } + + /// Create an AIClient without proxy (backward compatible) + pub fn new(config: AIConfig) -> Self { + let skip_ssl_verify = config.skip_ssl_verify; + let client = Self::create_http_client(None, skip_ssl_verify); + Self { client, config } + } + + /// Create an AIClient with proxy configuration + pub fn new_with_proxy(config: AIConfig, proxy_config: Option) -> Self { + let skip_ssl_verify = config.skip_ssl_verify; + let client = Self::create_http_client(proxy_config, skip_ssl_verify); + Self { client, config } + } + + /// Create an HTTP client (supports proxy config and SSL verification control) + fn create_http_client(proxy_config: Option, skip_ssl_verify: bool) -> Client { + let mut builder = Client::builder() + // SSE requests can legitimately stay open for a long time while the model + // thinks or executes tools. Keep only connect timeout here and let the + // stream handlers enforce idle timeouts between chunks. 
+ .connect_timeout(std::time::Duration::from_secs( + Self::STREAM_CONNECT_TIMEOUT_SECS, + )) + .user_agent("BitFun/1.0") + .pool_idle_timeout(std::time::Duration::from_secs( + Self::HTTP_POOL_IDLE_TIMEOUT_SECS, + )) + .pool_max_idle_per_host(4) + .tcp_keepalive(Some(std::time::Duration::from_secs( + Self::HTTP_TCP_KEEPALIVE_SECS, + ))) + .danger_accept_invalid_certs(skip_ssl_verify); + + if skip_ssl_verify { + warn!("SSL certificate verification disabled - security risk, use only in test environments"); + } + + // rustls mode does not support http2_keep_alive_interval/http2_keep_alive_timeout. + if let Some(proxy_cfg) = proxy_config { + if proxy_cfg.enabled && !proxy_cfg.url.is_empty() { + match Self::build_proxy(&proxy_cfg) { + Ok(proxy) => { + info!("Using proxy: {}", proxy_cfg.url); + builder = builder.proxy(proxy); + } + Err(e) => { + error!( + "Proxy configuration failed: {}, proceeding without proxy", + e + ); + builder = builder.no_proxy(); + } + } + } else { + builder = builder.no_proxy(); + } + } else { + builder = builder.no_proxy(); + } + + match builder.build() { + Ok(client) => client, + Err(e) => { + error!( + "HTTP client initialization failed: {}, using default client", + e + ); + Client::new() + } + } + } + + fn build_proxy(config: &ProxyConfig) -> Result { + let mut proxy = + Proxy::all(&config.url).map_err(|e| anyhow!("Failed to create proxy: {}", e))?; + + if let (Some(username), Some(password)) = (&config.username, &config.password) { + if !username.is_empty() && !password.is_empty() { + proxy = proxy.basic_auth(username, password); + debug!("Proxy authentication configured for user: {}", username); + } + } + + Ok(proxy) + } + + fn get_api_format(&self) -> &str { + &self.config.format + } + + /// Whether the URL is Alibaba DashScope API. + /// Alibaba DashScope uses `enable_thinking`=true/false for thinking, not the `thinking` object. 
+ fn is_dashscope_url(url: &str) -> bool { + url.contains("dashscope.aliyuncs.com") + } + + /// Whether the URL is MiniMax API. + /// MiniMax (api.minimaxi.com) uses `reasoning_split=true` to enable streamed thinking content + /// delivered via `delta.reasoning_details` rather than the standard `reasoning_content` field. + fn is_minimax_url(url: &str) -> bool { + url.contains("api.minimaxi.com") + } + + /// Apply thinking-related fields onto the request body (mutates `request_body`). + /// + /// * `enable` - whether thinking process is enabled + /// * `url` - request URL + /// * `model_name` - model name (e.g. for Claude budget_tokens in Anthropic format) + /// * `api_format` - "openai" or "anthropic" + /// * `max_tokens` - optional max_tokens (for Anthropic Claude budget_tokens) + fn apply_thinking_fields( + request_body: &mut serde_json::Value, + enable: bool, + url: &str, + model_name: &str, + api_format: &str, + max_tokens: Option, + ) { + if Self::is_dashscope_url(url) && api_format.eq_ignore_ascii_case("openai") { + request_body["enable_thinking"] = serde_json::json!(enable); + return; + } + if Self::is_minimax_url(url) && api_format.eq_ignore_ascii_case("openai") { + if enable { + request_body["reasoning_split"] = serde_json::json!(true); + } + return; + } + let thinking_value = if enable { + if api_format.eq_ignore_ascii_case("anthropic") && model_name.starts_with("claude") { + let mut obj = serde_json::map::Map::new(); + obj.insert( + "type".to_string(), + serde_json::Value::String("enabled".to_string()), + ); + if let Some(m) = max_tokens { + obj.insert( + "budget_tokens".to_string(), + serde_json::json!(10000u32.min(m * 3 / 4)), + ); + } + serde_json::Value::Object(obj) + } else { + serde_json::json!({ "type": "enabled" }) + } + } else { + serde_json::json!({ "type": "disabled" }) + }; + request_body["thinking"] = thinking_value; + } + + /// Whether to append the `tool_stream` request field. 
+ /// + /// Only Zhipu (https://open.bigmodel.cn) uses this field; and only for GLM models (pure version >= 4.6). + /// Adding this parameter for non-Zhipu APIs may cause abnormal behavior: + /// 1) incomplete output; (Aliyun Coding Plan, 2026-02-28) + /// 2) extra `` prefix on some tool names. (Aliyun Coding Plan, 2026-02-28) + fn should_append_tool_stream(url: &str, model_name: &str) -> bool { + if !url.contains("open.bigmodel.cn") { + return false; + } + Self::parse_glm_major_minor(model_name) + .map(|(major, minor)| major > 4 || (major == 4 && minor >= 6)) + .unwrap_or(false) + } + + /// Parse strict `glm-[.]` from model names like: + /// - glm-4.6 + /// - glm-5 + /// + /// Models with non-numeric suffixes are treated as not requiring this GLM-specific field, e.g.: + /// - glm-4.6-flash + /// - glm-4.5v + fn parse_glm_major_minor(model_name: &str) -> Option<(u32, u32)> { + let version_part = model_name.strip_prefix("glm-")?; + + if version_part.is_empty() { + return None; + } + + let mut parts = version_part.split('.'); + let major: u32 = parts.next()?.parse().ok()?; + let minor: u32 = match parts.next() { + Some(v) => v.parse().ok()?, + None => 0, + }; + + // Only allow one numeric segment after the decimal point. 
+ if parts.next().is_some() { + return None; + } + + Some((major, minor)) + } + + /// Determine whether to use merge mode + /// + /// true: apply default headers first, then custom headers (custom can override) + /// false: if custom headers exist, replace defaults entirely + /// Default is merge mode + fn is_merge_headers_mode(&self) -> bool { + // Default to merge mode; use replace mode only when explicitly set to "replace" + self.config.custom_headers_mode.as_deref() != Some("replace") + } + + /// Apply custom headers to the builder + fn apply_custom_headers( + &self, + mut builder: reqwest::RequestBuilder, + ) -> reqwest::RequestBuilder { + if let Some(custom_headers) = &self.config.custom_headers { + if !custom_headers.is_empty() { + for (key, value) in custom_headers { + builder = builder.header(key.as_str(), value.as_str()); + } + } + } + builder + } + + /// Apply OpenAI-style request headers (merge/replace). + fn apply_openai_headers( + &self, + mut builder: reqwest::RequestBuilder, + ) -> reqwest::RequestBuilder { + let has_custom_headers = self + .config + .custom_headers + .as_ref() + .map_or(false, |h| !h.is_empty()); + let is_merge_mode = self.is_merge_headers_mode(); + + if has_custom_headers && !is_merge_mode { + return self.apply_custom_headers(builder); + } + + builder = builder + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", self.config.api_key)); + + if self.config.base_url.contains("openbitfun.com") { + builder = builder.header("X-Verification-Code", "from_bitfun"); + } + + if has_custom_headers && is_merge_mode { + builder = self.apply_custom_headers(builder); + } + + builder + } + + /// Apply Anthropic-style request headers (merge/replace). 
+ fn apply_anthropic_headers( + &self, + mut builder: reqwest::RequestBuilder, + url: &str, + ) -> reqwest::RequestBuilder { + let has_custom_headers = self + .config + .custom_headers + .as_ref() + .map_or(false, |h| !h.is_empty()); + let is_merge_mode = self.is_merge_headers_mode(); + + if has_custom_headers && !is_merge_mode { + return self.apply_custom_headers(builder); + } + + builder = builder.header("Content-Type", "application/json"); + + if url.contains("bigmodel.cn") { + builder = builder.header("Authorization", format!("Bearer {}", self.config.api_key)); + } else { + builder = builder + .header("x-api-key", &self.config.api_key) + .header("anthropic-version", "2023-06-01"); + } + + if url.contains("openbitfun.com") { + builder = builder.header("X-Verification-Code", "from_bitfun"); + } + + if has_custom_headers && is_merge_mode { + builder = self.apply_custom_headers(builder); + } + + builder + } + + /// Apply Gemini-style request headers (merge/replace). + fn apply_gemini_headers( + &self, + mut builder: reqwest::RequestBuilder, + ) -> reqwest::RequestBuilder { + let has_custom_headers = self + .config + .custom_headers + .as_ref() + .map_or(false, |h| !h.is_empty()); + let is_merge_mode = self.is_merge_headers_mode(); + + if has_custom_headers && !is_merge_mode { + return self.apply_custom_headers(builder); + } + + builder = builder + .header("Content-Type", "application/json") + .header("x-goog-api-key", &self.config.api_key) + .header( + "Authorization", + format!("Bearer {}", self.config.api_key), + ); + + if self.config.base_url.contains("openbitfun.com") { + builder = builder.header("X-Verification-Code", "from_bitfun"); + } + + if has_custom_headers && is_merge_mode { + builder = self.apply_custom_headers(builder); + } + + builder + } + + fn merge_json_value(target: &mut serde_json::Value, overlay: serde_json::Value) { + match (target, overlay) { + (serde_json::Value::Object(target_map), serde_json::Value::Object(overlay_map)) => { + for (key, 
value) in overlay_map { + let entry = target_map.entry(key).or_insert(serde_json::Value::Null); + Self::merge_json_value(entry, value); + } + } + (target_slot, overlay_value) => { + *target_slot = overlay_value; + } + } + } + + fn ensure_gemini_generation_config( + request_body: &mut serde_json::Value, + ) -> &mut serde_json::Map { + if !request_body + .get("generationConfig") + .is_some_and(serde_json::Value::is_object) + { + request_body["generationConfig"] = serde_json::json!({}); + } + + request_body["generationConfig"] + .as_object_mut() + .expect("generationConfig must be an object") + } + + fn insert_gemini_generation_field( + request_body: &mut serde_json::Value, + key: &str, + value: serde_json::Value, + ) { + Self::ensure_gemini_generation_config(request_body).insert(key.to_string(), value); + } + + fn normalize_gemini_stop_sequences(value: &serde_json::Value) -> Option { + match value { + serde_json::Value::String(sequence) => { + Some(serde_json::Value::Array(vec![serde_json::Value::String( + sequence.clone(), + )])) + } + serde_json::Value::Array(items) => { + let sequences = items + .iter() + .filter_map(|item| item.as_str().map(|sequence| sequence.to_string())) + .map(serde_json::Value::String) + .collect::>(); + + if sequences.is_empty() { + None + } else { + Some(serde_json::Value::Array(sequences)) + } + } + _ => None, + } + } + + fn apply_gemini_response_format_translation( + request_body: &mut serde_json::Value, + response_format: &serde_json::Value, + ) -> bool { + match response_format { + serde_json::Value::String(kind) if matches!(kind.as_str(), "json" | "json_object") => { + Self::insert_gemini_generation_field( + request_body, + "responseMimeType", + serde_json::Value::String("application/json".to_string()), + ); + true + } + serde_json::Value::Object(map) => { + let Some(kind) = map.get("type").and_then(serde_json::Value::as_str) else { + return false; + }; + + match kind { + "json" | "json_object" => { + 
Self::insert_gemini_generation_field( + request_body, + "responseMimeType", + serde_json::Value::String("application/json".to_string()), + ); + true + } + "json_schema" => { + Self::insert_gemini_generation_field( + request_body, + "responseMimeType", + serde_json::Value::String("application/json".to_string()), + ); + + if let Some(schema) = map + .get("json_schema") + .and_then(serde_json::Value::as_object) + .and_then(|json_schema| json_schema.get("schema")) + .or_else(|| map.get("schema")) + { + Self::insert_gemini_generation_field( + request_body, + "responseJsonSchema", + GeminiMessageConverter::sanitize_schema(schema.clone()), + ); + } + + true + } + _ => false, + } + } + _ => false, + } + } + + fn translate_gemini_extra_body( + request_body: &mut serde_json::Value, + extra_obj: &mut serde_json::Map, + ) { + if let Some(max_tokens) = extra_obj.remove("max_tokens") { + Self::insert_gemini_generation_field(request_body, "maxOutputTokens", max_tokens); + } + + if let Some(temperature) = extra_obj.remove("temperature") { + Self::insert_gemini_generation_field(request_body, "temperature", temperature); + } + + let top_p = extra_obj + .remove("top_p") + .or_else(|| extra_obj.remove("topP")); + if let Some(top_p) = top_p { + Self::insert_gemini_generation_field(request_body, "topP", top_p); + } + + if let Some(stop_sequences) = extra_obj + .get("stop") + .and_then(Self::normalize_gemini_stop_sequences) + { + extra_obj.remove("stop"); + Self::insert_gemini_generation_field(request_body, "stopSequences", stop_sequences); + } + + if let Some(response_mime_type) = extra_obj + .remove("responseMimeType") + .or_else(|| extra_obj.remove("response_mime_type")) + { + Self::insert_gemini_generation_field( + request_body, + "responseMimeType", + response_mime_type, + ); + } + + if let Some(response_schema) = extra_obj + .remove("responseJsonSchema") + .or_else(|| extra_obj.remove("responseSchema")) + .or_else(|| extra_obj.remove("response_schema")) + { + 
Self::insert_gemini_generation_field( + request_body, + "responseJsonSchema", + GeminiMessageConverter::sanitize_schema(response_schema), + ); + } + + if let Some(response_format) = extra_obj.get("response_format").cloned() { + if Self::apply_gemini_response_format_translation(request_body, &response_format) { + extra_obj.remove("response_format"); + } + } + } + + fn unified_usage_to_gemini_usage(usage: installer_ai_stream::UnifiedTokenUsage) -> GeminiUsage { + GeminiUsage { + prompt_token_count: usage.prompt_token_count, + candidates_token_count: usage.candidates_token_count, + total_token_count: usage.total_token_count, + reasoning_token_count: usage.reasoning_token_count, + cached_content_token_count: usage.cached_content_token_count, + } + } + + /// Build an OpenAI-format request body + fn build_openai_request_body( + &self, + url: &str, + openai_messages: Vec, + openai_tools: Option>, + extra_body: Option, + ) -> serde_json::Value { + let mut request_body = serde_json::json!({ + "model": self.config.model, + "messages": openai_messages, + "stream": true + }); + + let model_name = self.config.model.to_lowercase(); + + if Self::should_append_tool_stream(url, &model_name) { + request_body["tool_stream"] = serde_json::Value::Bool(true); + } + + Self::apply_thinking_fields( + &mut request_body, + self.config.enable_thinking_process, + url, + &model_name, + "openai", + self.config.max_tokens, + ); + + if let Some(max_tokens) = self.config.max_tokens { + request_body["max_tokens"] = serde_json::json!(max_tokens); + } + + if let Some(extra) = extra_body { + if let Some(extra_obj) = extra.as_object() { + for (key, value) in extra_obj { + request_body[key] = value.clone(); + } + debug!(target: "ai::openai_stream_request", "Applied extra_body overrides: {:?}", extra_obj.keys().collect::>()); + } + } + + // This client currently consumes only the first choice in stream handling. + // Remove custom n override and keep provider defaults. 
+ if let Some(request_obj) = request_body.as_object_mut() { + if let Some(existing_n) = request_obj.remove("n") { + warn!( + target: "ai::openai_stream_request", + "Removed custom request field n={} because the stream processor only handles the first choice", + existing_n + ); + } + } + + debug!(target: "ai::openai_stream_request", + "OpenAI stream request body (excluding tools):\n{}", + serde_json::to_string_pretty(&request_body).unwrap_or_else(|_| "serialization failed".to_string()) + ); + + if let Some(tools) = openai_tools { + let tool_names = tools + .iter() + .map(|tool| Self::extract_openai_tool_name(tool)) + .collect::>(); + debug!(target: "ai::openai_stream_request", "\ntools: {:?}", tool_names); + if !tools.is_empty() { + request_body["tools"] = serde_json::Value::Array(tools); + // Respect `extra_body` overrides (e.g. tool_choice="required") when present. + let has_tool_choice = request_body + .get("tool_choice") + .is_some_and(|v| !v.is_null()); + if !has_tool_choice { + request_body["tool_choice"] = serde_json::Value::String("auto".to_string()); + } + } + } + + request_body + } + + /// Build a Responses API request body. 
+ fn build_responses_request_body( + &self, + instructions: Option, + response_input: Vec, + openai_tools: Option>, + extra_body: Option, + ) -> serde_json::Value { + let mut request_body = serde_json::json!({ + "model": self.config.model, + "input": response_input, + "stream": true + }); + + if let Some(instructions) = instructions.filter(|value| !value.trim().is_empty()) { + request_body["instructions"] = serde_json::Value::String(instructions); + } + + if let Some(max_tokens) = self.config.max_tokens { + request_body["max_output_tokens"] = serde_json::json!(max_tokens); + } + + if let Some(ref effort) = self.config.reasoning_effort { + request_body["reasoning"] = serde_json::json!({ + "effort": effort, + "summary": "auto" + }); + } + + if let Some(extra) = extra_body { + if let Some(extra_obj) = extra.as_object() { + for (key, value) in extra_obj { + request_body[key] = value.clone(); + } + debug!( + target: "ai::responses_stream_request", + "Applied extra_body overrides: {:?}", + extra_obj.keys().collect::>() + ); + } + } + + debug!( + target: "ai::responses_stream_request", + "Responses stream request body (excluding tools):\n{}", + serde_json::to_string_pretty(&request_body) + .unwrap_or_else(|_| "serialization failed".to_string()) + ); + + if let Some(tools) = openai_tools { + let tool_names = tools + .iter() + .map(|tool| Self::extract_openai_tool_name(tool)) + .collect::>(); + debug!(target: "ai::responses_stream_request", "\ntools: {:?}", tool_names); + if !tools.is_empty() { + request_body["tools"] = serde_json::Value::Array(tools); + // Respect `extra_body` overrides (e.g. tool_choice="required") when present. 
+ let has_tool_choice = request_body + .get("tool_choice") + .is_some_and(|v| !v.is_null()); + if !has_tool_choice { + request_body["tool_choice"] = serde_json::Value::String("auto".to_string()); + } + } + } + + request_body + } + + /// Build an Anthropic-format request body + fn build_anthropic_request_body( + &self, + url: &str, + system_message: Option, + anthropic_messages: Vec, + anthropic_tools: Option>, + extra_body: Option, + ) -> serde_json::Value { + let max_tokens = self.config.max_tokens.unwrap_or(8192); + + let mut request_body = serde_json::json!({ + "model": self.config.model, + "messages": anthropic_messages, + "max_tokens": max_tokens, + "stream": true + }); + + let model_name = self.config.model.to_lowercase(); + + // Zhipu extension: only set `tool_stream` for open.bigmodel.cn. + if Self::should_append_tool_stream(url, &model_name) { + request_body["tool_stream"] = serde_json::Value::Bool(true); + } + + Self::apply_thinking_fields( + &mut request_body, + self.config.enable_thinking_process, + url, + &model_name, + "anthropic", + Some(max_tokens), + ); + + if let Some(system) = system_message { + request_body["system"] = serde_json::Value::String(system); + } + + if let Some(extra) = extra_body { + if let Some(extra_obj) = extra.as_object() { + for (key, value) in extra_obj { + request_body[key] = value.clone(); + } + debug!(target: "ai::anthropic_stream_request", "Applied extra_body overrides: {:?}", extra_obj.keys().collect::>()); + } + } + + debug!(target: "ai::anthropic_stream_request", + "Anthropic stream request body (excluding tools):\n{}", + serde_json::to_string_pretty(&request_body).unwrap_or_else(|_| "serialization failed".to_string()) + ); + + if let Some(tools) = anthropic_tools { + let tool_names = tools + .iter() + .map(|tool| Self::extract_anthropic_tool_name(tool)) + .collect::>(); + debug!(target: "ai::anthropic_stream_request", "\ntools: {:?}", tool_names); + if !tools.is_empty() { + request_body["tools"] = 
serde_json::Value::Array(tools); + } + } + + request_body + } + + /// Build a Gemini-format request body. + fn build_gemini_request_body( + &self, + system_instruction: Option, + contents: Vec, + gemini_tools: Option>, + extra_body: Option, + ) -> serde_json::Value { + let mut request_body = serde_json::json!({ + "contents": contents, + }); + + if let Some(system_instruction) = system_instruction { + request_body["systemInstruction"] = system_instruction; + } + + if let Some(max_tokens) = self.config.max_tokens { + Self::insert_gemini_generation_field( + &mut request_body, + "maxOutputTokens", + serde_json::json!(max_tokens), + ); + } + + if let Some(temperature) = self.config.temperature { + Self::insert_gemini_generation_field( + &mut request_body, + "temperature", + serde_json::json!(temperature), + ); + } + + if let Some(top_p) = self.config.top_p { + Self::insert_gemini_generation_field( + &mut request_body, + "topP", + serde_json::json!(top_p), + ); + } + + if self.config.enable_thinking_process { + Self::insert_gemini_generation_field( + &mut request_body, + "thinkingConfig", + serde_json::json!({ + "includeThoughts": true, + }), + ); + } + + if let Some(tools) = gemini_tools { + let tool_names = tools + .iter() + .flat_map(|tool| { + if let Some(declarations) = tool + .get("functionDeclarations") + .and_then(|value| value.as_array()) + { + declarations + .iter() + .filter_map(|declaration| { + declaration + .get("name") + .and_then(|value| value.as_str()) + .map(str::to_string) + }) + .collect::>() + } else { + tool.as_object() + .into_iter() + .flat_map(|map| map.keys().cloned()) + .collect::>() + } + }) + .collect::>(); + debug!(target: "ai::gemini_stream_request", "\ntools: {:?}", tool_names); + + if !tools.is_empty() { + request_body["tools"] = serde_json::Value::Array(tools); + let has_function_declarations = request_body["tools"] + .as_array() + .map(|tools| { + tools + .iter() + .any(|tool| tool.get("functionDeclarations").is_some()) + }) + 
.unwrap_or(false); + + if has_function_declarations { + request_body["toolConfig"] = serde_json::json!({ + "functionCallingConfig": { + "mode": "AUTO" + } + }); + } + } + } + + if let Some(extra) = extra_body { + if let Some(mut extra_obj) = extra.as_object().cloned() { + Self::translate_gemini_extra_body(&mut request_body, &mut extra_obj); + let override_keys = extra_obj.keys().cloned().collect::>(); + + for (key, value) in extra_obj { + if let Some(request_obj) = request_body.as_object_mut() { + let target = request_obj.entry(key).or_insert(serde_json::Value::Null); + Self::merge_json_value(target, value); + } + } + debug!( + target: "ai::gemini_stream_request", + "Applied extra_body overrides: {:?}", + override_keys + ); + } + } + + debug!( + target: "ai::gemini_stream_request", + "Gemini stream request body:\n{}", + serde_json::to_string_pretty(&request_body) + .unwrap_or_else(|_| "serialization failed".to_string()) + ); + + request_body + } + + fn resolve_gemini_request_url(base_url: &str, model_name: &str) -> String { + let trimmed = base_url.trim().trim_end_matches('/'); + if trimmed.is_empty() { + return String::new(); + } + + let base = Self::gemini_base_url(trimmed); + let encoded_model = urlencoding::encode(model_name.trim()); + format!( + "{}/v1beta/models/{}:streamGenerateContent?alt=sse", + base, encoded_model + ) + } + + /// Strip /v1beta, /models/... and similar suffixes from a gemini URL, + /// returning only the bare host root (e.g. https://generativelanguage.googleapis.com). 
+ fn gemini_base_url(url: &str) -> &str { + let mut u = url; + if let Some(pos) = u.find("/v1beta") { + u = &u[..pos]; + } + if let Some(pos) = u.find("/models/") { + u = &u[..pos]; + } + u.trim_end_matches('/') + } + + fn extract_openai_tool_name(tool: &serde_json::Value) -> String { + tool.get("function") + .and_then(|f| f.get("name")) + .and_then(|n| n.as_str()) + .unwrap_or("unknown") + .to_string() + } + + fn extract_anthropic_tool_name(tool: &serde_json::Value) -> String { + tool.get("name") + .and_then(|n| n.as_str()) + .unwrap_or("unknown") + .to_string() + } + + /// Send a streaming message request + /// + /// Returns `StreamResponse` with: + /// - `stream`: parsed response stream + /// - `raw_sse_rx`: raw SSE receiver (for collecting data during error diagnostics) + pub async fn send_message_stream( + &self, + messages: Vec, + tools: Option>, + ) -> Result { + let custom_body = self.config.custom_request_body.clone(); + self.send_message_stream_with_extra_body(messages, tools, custom_body) + .await + } + + /// Send a streaming message request with extra request body overrides + /// + /// Returns `StreamResponse` with: + /// - `stream`: parsed response stream + /// - `raw_sse_rx`: raw SSE receiver (for collecting data during error diagnostics) + pub async fn send_message_stream_with_extra_body( + &self, + messages: Vec, + tools: Option>, + extra_body: Option, + ) -> Result { + let max_tries = 3; + match self.get_api_format().to_lowercase().as_str() { + "openai" => { + self.send_openai_stream(messages, tools, extra_body, max_tries) + .await + } + format if Self::is_gemini_api_format(format) => { + self.send_gemini_stream(messages, tools, extra_body, max_tries) + .await + } + format if Self::is_responses_api_format(format) => { + self.send_responses_stream(messages, tools, extra_body, max_tries) + .await + } + "anthropic" => { + self.send_anthropic_stream(messages, tools, extra_body, max_tries) + .await + } + _ => Err(anyhow!("Unknown API format: {}", 
self.get_api_format())), + } + } + + /// Send an OpenAI streaming request with retries + /// + /// # Parameters + /// - `messages`: message list + /// - `tools`: tool definitions + /// - `extra_body`: extra request body parameters + /// - `max_tries`: max attempts (including the first) + async fn send_openai_stream( + &self, + messages: Vec, + tools: Option>, + extra_body: Option, + max_tries: usize, + ) -> Result { + let url = self.config.request_url.clone(); + debug!( + "OpenAI config: model={}, request_url={}, max_tries={}", + self.config.model, self.config.request_url, max_tries + ); + + // Use OpenAI message converter + let openai_messages = OpenAIMessageConverter::convert_messages(messages); + let openai_tools = OpenAIMessageConverter::convert_tools(tools); + + // Build request body + let request_body = + self.build_openai_request_body(&url, openai_messages, openai_tools, extra_body); + + let mut last_error = None; + let base_wait_time_ms = 500; + + for attempt in 0..max_tries { + let request_start_time = std::time::Instant::now(); + + // Send request - apply request headers + let request_builder = self.apply_openai_headers(self.client.post(&url)); + let response_result = request_builder.json(&request_body).send().await; + + let response = match response_result { + Ok(resp) => { + let connect_time = request_start_time.elapsed().as_millis(); + let status = resp.status(); + + if status.is_client_error() { + let error_text = resp + .text() + .await + .unwrap_or_else(|e| format!("Failed to read error response: {}", e)); + error!( + "OpenAI Streaming API client error {}: {}", + status, error_text + ); + return Err(anyhow!( + "OpenAI Streaming API client error {}: {}", + status, + error_text + )); + } + + if status.is_success() { + debug!( + "Stream request connected: {}ms, status: {}, attempt: {}/{}", + connect_time, + status, + attempt + 1, + max_tries + ); + resp + } else { + let error_text = resp + .text() + .await + .unwrap_or_else(|e| format!("Failed to read 
error response: {}", e)); + let error = + anyhow!("OpenAI Streaming API error {}: {}", status, error_text); + warn!( + "Stream request failed (attempt {}/{}): {}", + attempt + 1, + max_tries, + error + ); + last_error = Some(error); + + if attempt < max_tries - 1 { + let delay_ms = base_wait_time_ms * (1 << attempt.min(3)); + debug!("Retrying after {}ms (attempt {})", delay_ms, attempt + 2); + tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; + } + continue; + } + } + Err(e) => { + let connect_time = request_start_time.elapsed().as_millis(); + let error = anyhow!("Stream request connection failed: {}", e); + warn!( + "Stream request connection failed: {}ms, attempt {}/{}, error: {}", + connect_time, + attempt + 1, + max_tries, + e + ); + last_error = Some(error); + + if attempt < max_tries - 1 { + let delay_ms = base_wait_time_ms * (1 << attempt.min(3)); + debug!("Retrying after {}ms (attempt {})", delay_ms, attempt + 2); + tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; + } + continue; + } + }; + + // Success: create channels and return + let (tx, rx) = mpsc::unbounded_channel(); + let (tx_raw, rx_raw) = mpsc::unbounded_channel(); + + tokio::spawn(handle_openai_stream( + response, + tx, + Some(tx_raw), + self.config.inline_think_in_text, + )); + + return Ok(StreamResponse { + stream: Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx)), + raw_sse_rx: Some(rx_raw), + }); + } + + let error_msg = format!( + "Stream request failed after {} attempts: {}", + max_tries, + last_error.unwrap_or_else(|| anyhow!("Unknown error")) + ); + error!("{}", error_msg); + Err(anyhow!(error_msg)) + } + + /// Send a Gemini streaming request with retries. 
+ async fn send_gemini_stream( + &self, + messages: Vec, + tools: Option>, + extra_body: Option, + max_tries: usize, + ) -> Result { + let url = Self::resolve_gemini_request_url(&self.config.request_url, &self.config.model); + debug!( + "Gemini config: model={}, request_url={}, max_tries={}", + self.config.model, url, max_tries + ); + + let (system_instruction, contents) = + GeminiMessageConverter::convert_messages(messages, &self.config.model); + let gemini_tools = GeminiMessageConverter::convert_tools(tools); + let request_body = + self.build_gemini_request_body(system_instruction, contents, gemini_tools, extra_body); + + let mut last_error = None; + let base_wait_time_ms = 500; + + for attempt in 0..max_tries { + let request_start_time = std::time::Instant::now(); + let request_builder = self.apply_gemini_headers(self.client.post(&url)); + let response_result = request_builder.json(&request_body).send().await; + + let response = match response_result { + Ok(resp) => { + let connect_time = request_start_time.elapsed().as_millis(); + let status = resp.status(); + + if status.is_client_error() { + let error_text = resp + .text() + .await + .unwrap_or_else(|e| format!("Failed to read error response: {}", e)); + error!( + "Gemini Streaming API client error {}: {}", + status, error_text + ); + return Err(anyhow!( + "Gemini Streaming API client error {}: {}", + status, + error_text + )); + } + + if status.is_success() { + debug!( + "Gemini stream request connected: {}ms, status: {}, attempt: {}/{}", + connect_time, + status, + attempt + 1, + max_tries + ); + resp + } else { + let error_text = resp + .text() + .await + .unwrap_or_else(|e| format!("Failed to read error response: {}", e)); + let error = + anyhow!("Gemini Streaming API error {}: {}", status, error_text); + warn!( + "Gemini stream request failed: {}ms, attempt {}/{}, error: {}", + connect_time, + attempt + 1, + max_tries, + error + ); + last_error = Some(error); + + if attempt < max_tries - 1 { + let 
delay_ms = base_wait_time_ms * (1 << attempt.min(3)); + debug!( + "Retrying Gemini after {}ms (attempt {})", + delay_ms, + attempt + 2 + ); + tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; + } + continue; + } + } + Err(e) => { + let connect_time = request_start_time.elapsed().as_millis(); + let error = anyhow!("Gemini stream request connection failed: {}", e); + warn!( + "Gemini stream request connection failed: {}ms, attempt {}/{}, error: {}", + connect_time, + attempt + 1, + max_tries, + e + ); + last_error = Some(error); + + if attempt < max_tries - 1 { + let delay_ms = base_wait_time_ms * (1 << attempt.min(3)); + debug!( + "Retrying Gemini after {}ms (attempt {})", + delay_ms, + attempt + 2 + ); + tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; + } + continue; + } + }; + + let (tx, rx) = mpsc::unbounded_channel(); + let (tx_raw, rx_raw) = mpsc::unbounded_channel(); + + tokio::spawn(handle_gemini_stream(response, tx, Some(tx_raw))); + + return Ok(StreamResponse { + stream: Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx)), + raw_sse_rx: Some(rx_raw), + }); + } + + let error_msg = format!( + "Gemini stream request failed after {} attempts: {}", + max_tries, + last_error.unwrap_or_else(|| anyhow!("Unknown error")) + ); + error!("{}", error_msg); + Err(anyhow!(error_msg)) + } + + /// Send a Responses API streaming request with retries. 
+ async fn send_responses_stream( + &self, + messages: Vec, + tools: Option>, + extra_body: Option, + max_tries: usize, + ) -> Result { + let url = self.config.request_url.clone(); + debug!( + "Responses config: model={}, request_url={}, max_tries={}", + self.config.model, self.config.request_url, max_tries + ); + + let (instructions, response_input) = + OpenAIMessageConverter::convert_messages_to_responses_input(messages); + let openai_tools = OpenAIMessageConverter::convert_tools(tools); + let request_body = self.build_responses_request_body( + instructions, + response_input, + openai_tools, + extra_body, + ); + + let mut last_error = None; + let base_wait_time_ms = 500; + + for attempt in 0..max_tries { + let request_start_time = std::time::Instant::now(); + let request_builder = self.apply_openai_headers(self.client.post(&url)); + let response_result = request_builder.json(&request_body).send().await; + + let response = match response_result { + Ok(resp) => { + let connect_time = request_start_time.elapsed().as_millis(); + let status = resp.status(); + + if status.is_client_error() { + let error_text = resp + .text() + .await + .unwrap_or_else(|e| format!("Failed to read error response: {}", e)); + error!("Responses API client error {}: {}", status, error_text); + return Err(anyhow!( + "Responses API client error {}: {}", + status, + error_text + )); + } + + if status.is_success() { + debug!( + "Responses request connected: {}ms, status: {}, attempt: {}/{}", + connect_time, + status, + attempt + 1, + max_tries + ); + resp + } else { + let error_text = resp + .text() + .await + .unwrap_or_else(|e| format!("Failed to read error response: {}", e)); + let error = anyhow!("Responses API error {}: {}", status, error_text); + warn!( + "Responses request failed (attempt {}/{}): {}", + attempt + 1, + max_tries, + error + ); + last_error = Some(error); + + if attempt < max_tries - 1 { + let delay_ms = base_wait_time_ms * (1 << attempt.min(3)); + debug!("Retrying after 
{}ms (attempt {})", delay_ms, attempt + 2); + tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; + } + continue; + } + } + Err(e) => { + let connect_time = request_start_time.elapsed().as_millis(); + let error = anyhow!("Responses request connection failed: {}", e); + warn!( + "Responses request connection failed: {}ms, attempt {}/{}, error: {}", + connect_time, + attempt + 1, + max_tries, + e + ); + last_error = Some(error); + + if attempt < max_tries - 1 { + let delay_ms = base_wait_time_ms * (1 << attempt.min(3)); + debug!("Retrying after {}ms (attempt {})", delay_ms, attempt + 2); + tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; + } + continue; + } + }; + + let (tx, rx) = mpsc::unbounded_channel(); + let (tx_raw, rx_raw) = mpsc::unbounded_channel(); + + tokio::spawn(handle_responses_stream(response, tx, Some(tx_raw))); + + return Ok(StreamResponse { + stream: Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx)), + raw_sse_rx: Some(rx_raw), + }); + } + + let error_msg = format!( + "Responses request failed after {} attempts: {}", + max_tries, + last_error.unwrap_or_else(|| anyhow!("Unknown error")) + ); + error!("{}", error_msg); + Err(anyhow!(error_msg)) + } + + /// Send an Anthropic streaming request with retries + /// + /// # Parameters + /// - `messages`: message list + /// - `tools`: tool definitions + /// - `extra_body`: extra request body parameters + /// - `max_tries`: max attempts (including the first) + async fn send_anthropic_stream( + &self, + messages: Vec, + tools: Option>, + extra_body: Option, + max_tries: usize, + ) -> Result { + let url = self.config.request_url.clone(); + debug!( + "Anthropic config: model={}, request_url={}, max_tries={}", + self.config.model, self.config.request_url, max_tries + ); + + // Use Anthropic message converter + let (system_message, anthropic_messages) = + AnthropicMessageConverter::convert_messages(messages); + let anthropic_tools = 
AnthropicMessageConverter::convert_tools(tools); + + // Build request body + let request_body = self.build_anthropic_request_body( + &url, + system_message, + anthropic_messages, + anthropic_tools, + extra_body, + ); + + let mut last_error = None; + let base_wait_time_ms = 500; + + for attempt in 0..max_tries { + let request_start_time = std::time::Instant::now(); + + // Send request - apply Anthropic-style request headers + let request_builder = self.apply_anthropic_headers(self.client.post(&url), &url); + let response_result = request_builder.json(&request_body).send().await; + + let response = match response_result { + Ok(resp) => { + let connect_time = request_start_time.elapsed().as_millis(); + let status = resp.status(); + + if status.is_client_error() { + let error_text = resp + .text() + .await + .unwrap_or_else(|e| format!("Failed to read error response: {}", e)); + error!( + "Anthropic Streaming API client error {}: {}", + status, error_text + ); + return Err(anyhow!( + "Anthropic Streaming API client error {}: {}", + status, + error_text + )); + } + + if status.is_success() { + debug!( + "Stream request connected: {}ms, status: {}, attempt: {}/{}", + connect_time, + status, + attempt + 1, + max_tries + ); + resp + } else { + let error_text = resp + .text() + .await + .unwrap_or_else(|e| format!("Failed to read error response: {}", e)); + let error = + anyhow!("Anthropic Streaming API error {}: {}", status, error_text); + warn!( + "Stream request failed (attempt {}/{}): {}", + attempt + 1, + max_tries, + error + ); + last_error = Some(error); + + if attempt < max_tries - 1 { + let delay_ms = base_wait_time_ms * (1 << attempt.min(3)); + debug!("Retrying after {}ms (attempt {})", delay_ms, attempt + 2); + tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; + } + continue; + } + } + Err(e) => { + let connect_time = request_start_time.elapsed().as_millis(); + let error = anyhow!("Stream request connection failed: {}", e); + warn!( + "Stream 
request connection failed: {}ms, attempt {}/{}, error: {}", + connect_time, + attempt + 1, + max_tries, + e + ); + last_error = Some(error); + + if attempt < max_tries - 1 { + let delay_ms = base_wait_time_ms * (1 << attempt.min(3)); + debug!("Retrying after {}ms (attempt {})", delay_ms, attempt + 2); + tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; + } + continue; + } + }; + + // Success: create channels and return + let (tx, rx) = mpsc::unbounded_channel(); + let (tx_raw, rx_raw) = mpsc::unbounded_channel(); + + tokio::spawn(handle_anthropic_stream(response, tx, Some(tx_raw))); + + return Ok(StreamResponse { + stream: Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx)), + raw_sse_rx: Some(rx_raw), + }); + } + + let error_msg = format!( + "Stream request failed after {} attempts: {}", + max_tries, + last_error.unwrap_or_else(|| anyhow!("Unknown error")) + ); + error!("{}", error_msg); + Err(anyhow!(error_msg)) + } + + /// Send a message and wait for the full response (non-streaming) + pub async fn send_message( + &self, + messages: Vec, + tools: Option>, + ) -> Result { + let custom_body = self.config.custom_request_body.clone(); + self.send_message_with_extra_body(messages, tools, custom_body) + .await + } + + /// Send a message and wait for the full response (non-streaming, with extra body overrides) + pub async fn send_message_with_extra_body( + &self, + messages: Vec, + tools: Option>, + extra_body: Option, + ) -> Result { + let stream_response = self + .send_message_stream_with_extra_body(messages, tools, extra_body) + .await?; + let mut stream = stream_response.stream; + + let mut full_text = String::new(); + let mut full_reasoning = String::new(); + let mut finish_reason = None; + let mut usage = None; + let mut provider_metadata: Option = None; + + let mut tool_calls: Vec = Vec::new(); + let mut cur_tool_call_id = String::new(); + let mut cur_tool_call_name = String::new(); + let mut json_checker = JsonChecker::new(); 
+ + while let Some(chunk_result) = stream.next().await { + match chunk_result { + Ok(chunk) => { + if let Some(text) = chunk.text { + full_text.push_str(&text); + } + + if let Some(reasoning_content) = chunk.reasoning_content { + full_reasoning.push_str(&reasoning_content); + } + + if let Some(finish_reason_) = chunk.finish_reason { + finish_reason = Some(finish_reason_); + } + + if let Some(chunk_usage) = chunk.usage { + usage = Some(Self::unified_usage_to_gemini_usage(chunk_usage)); + } + + if let Some(chunk_provider_metadata) = chunk.provider_metadata { + match provider_metadata.as_mut() { + Some(existing) => { + Self::merge_json_value(existing, chunk_provider_metadata); + } + None => provider_metadata = Some(chunk_provider_metadata), + } + } + + if let Some(tool_call) = chunk.tool_call { + if let Some(tool_call_id) = tool_call.id { + if !tool_call_id.is_empty() { + // Some providers repeat the tool id on every delta. Only reset when the id changes. + let is_new_tool = cur_tool_call_id != tool_call_id; + if is_new_tool { + cur_tool_call_id = tool_call_id; + cur_tool_call_name = tool_call.name.unwrap_or_default(); + json_checker.reset(); + debug!( + "[send_message] Detected tool call: {}", + cur_tool_call_name + ); + } else if cur_tool_call_name.is_empty() { + // Best-effort: keep name if provider repeats it. 
+ cur_tool_call_name = tool_call.name.unwrap_or_default(); + } + } + } + + if let Some(ref tool_call_arguments) = tool_call.arguments { + json_checker.append(tool_call_arguments); + } + + if json_checker.is_valid() { + let arguments_string = json_checker.get_buffer(); + let arguments: HashMap = + serde_json::from_str(&arguments_string).unwrap_or_else(|e| { + error!( + "[send_message] Failed to parse tool arguments: {}, arguments: {}", + e, + arguments_string + ); + HashMap::new() + }); + tool_calls.push(ToolCall { + id: cur_tool_call_id.clone(), + name: cur_tool_call_name.clone(), + arguments, + }); + debug!( + "[send_message] Tool call arguments complete: {}", + cur_tool_call_name + ); + json_checker.reset(); + } + } + } + Err(e) => return Err(e), + } + } + + let reasoning_content = if full_reasoning.is_empty() { + None + } else { + Some(full_reasoning) + }; + + let tool_calls_result = if tool_calls.is_empty() { + None + } else { + Some(tool_calls) + }; + + let response = GeminiResponse { + text: full_text, + reasoning_content, + tool_calls: tool_calls_result, + usage, + finish_reason, + provider_metadata, + }; + + Ok(response) + } + + pub async fn test_connection(&self) -> Result { + let start_time = std::time::Instant::now(); + + // Reuse the normal chat request path so the test matches real conversations, even when + // a provider rejects stricter tool_choice settings such as "required". + let test_messages = vec![Message::user( + "Call the get_weather tool for city=Beijing. Do not answer with plain text." 
+ .to_string(), + )]; + let tools = Some(vec![ToolDefinition { + name: "get_weather".to_string(), + description: "Get the weather of a city".to_string(), + parameters: serde_json::json!({ + "type": "object", + "properties": { + "city": { "type": "string", "description": "The city to get the weather for" } + }, + "required": ["city"], + "additionalProperties": false + }), + }]); + + let result = self.send_message(test_messages, tools).await; + + match result { + Ok(response) => { + let response_time_ms = start_time.elapsed().as_millis() as u64; + if response.tool_calls.is_some() { + Ok(ConnectionTestResult { + success: true, + response_time_ms, + model_response: Some(response.text), + message_code: None, + error_details: None, + }) + } else { + Ok(ConnectionTestResult { + success: true, + response_time_ms, + model_response: Some(response.text), + message_code: Some(ConnectionTestMessageCode::ToolCallsNotDetected), + error_details: None, + }) + } + } + Err(e) => { + let response_time_ms = start_time.elapsed().as_millis() as u64; + let error_msg = format!("{}", e); + debug!("test connection failed: {}", error_msg); + Ok(ConnectionTestResult { + success: false, + response_time_ms, + model_response: None, + message_code: None, + error_details: Some(error_msg), + }) + } + } + } + + pub async fn test_image_input_connection(&self) -> Result { + let start_time = std::time::Instant::now(); + let provider = self.config.format.to_ascii_lowercase(); + let prompt = "Inspect the attached image and reply with exactly one 4-letter code for quadrant colors in TL,TR,BL,BR order using letters R,G,B,Y (R=red, G=green, B=blue, Y=yellow)."; + + let content = if provider == "anthropic" { + serde_json::json!([ + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": Self::TEST_IMAGE_PNG_BASE64 + } + }, + { + "type": "text", + "text": prompt + } + ]) + } else { + serde_json::json!([ + { + "type": "image_url", + "image_url": { + "url": 
format!("data:image/png;base64,{}", Self::TEST_IMAGE_PNG_BASE64) + } + }, + { + "type": "text", + "text": prompt + } + ]) + }; + + let test_messages = vec![Message { + role: "user".to_string(), + content: Some(content.to_string()), + reasoning_content: None, + thinking_signature: None, + tool_calls: None, + tool_call_id: None, + name: None, + tool_image_attachments: None, + }]; + + match self.send_message(test_messages, None).await { + Ok(response) => { + let matched = Self::image_test_response_matches_expected(&response.text); + + if matched { + Ok(ConnectionTestResult { + success: true, + response_time_ms: start_time.elapsed().as_millis() as u64, + model_response: Some(response.text), + message_code: None, + error_details: None, + }) + } else { + let detail = format!( + "Image understanding verification failed: expected code '{}', got response '{}'", + Self::TEST_IMAGE_EXPECTED_CODE, response.text + ); + debug!("test image input connection failed: {}", detail); + Ok(ConnectionTestResult { + success: false, + response_time_ms: start_time.elapsed().as_millis() as u64, + model_response: Some(response.text), + message_code: Some(ConnectionTestMessageCode::ImageInputCheckFailed), + error_details: Some(detail), + }) + } + } + Err(e) => { + let error_msg = format!("{}", e); + debug!("test image input connection failed: {}", error_msg); + Ok(ConnectionTestResult { + success: false, + response_time_ms: start_time.elapsed().as_millis() as u64, + model_response: None, + message_code: None, + error_details: Some(error_msg), + }) + } + } + } + + pub async fn list_models(&self) -> Result> { + match self.get_api_format().to_ascii_lowercase().as_str() { + "openai" | "response" | "responses" => self.list_openai_models().await, + "anthropic" => self.list_anthropic_models().await, + format if Self::is_gemini_api_format(format) => self.list_gemini_models().await, + unsupported => Err(anyhow!( + "Listing models is not supported for API format: {}", + unsupported + )), + } + } +} + 
+#[cfg(test)] +mod tests { + use super::AIClient; + use crate::connection_test::providers::gemini::GeminiMessageConverter; + use crate::connection_test::types::{AIConfig, ToolDefinition}; + use serde_json::json; + + fn make_test_client(format: &str, custom_request_body: Option) -> AIClient { + AIClient::new(AIConfig { + name: "test".to_string(), + base_url: "https://example.com/v1".to_string(), + request_url: "https://example.com/v1/chat/completions".to_string(), + api_key: "test-key".to_string(), + model: "test-model".to_string(), + format: format.to_string(), + context_window: 128000, + max_tokens: Some(8192), + temperature: None, + top_p: None, + enable_thinking_process: false, + support_preserved_thinking: false, + inline_think_in_text: false, + custom_headers: None, + custom_headers_mode: None, + skip_ssl_verify: false, + reasoning_effort: None, + custom_request_body, + }) + } + + #[test] + fn resolves_openai_models_url_from_completion_endpoint() { + let client = AIClient::new(AIConfig { + name: "test".to_string(), + base_url: "https://api.openai.com/v1/chat/completions".to_string(), + request_url: "https://api.openai.com/v1/chat/completions".to_string(), + api_key: "test-key".to_string(), + model: "gpt-4.1".to_string(), + format: "openai".to_string(), + context_window: 128000, + max_tokens: Some(8192), + temperature: None, + top_p: None, + enable_thinking_process: false, + support_preserved_thinking: false, + inline_think_in_text: false, + custom_headers: None, + custom_headers_mode: None, + skip_ssl_verify: false, + reasoning_effort: None, + custom_request_body: None, + }); + + assert_eq!( + client.resolve_openai_models_url(), + "https://api.openai.com/v1/models" + ); + } + + #[test] + fn resolves_anthropic_models_url_from_messages_endpoint() { + let client = AIClient::new(AIConfig { + name: "test".to_string(), + base_url: "https://api.anthropic.com/v1/messages".to_string(), + request_url: "https://api.anthropic.com/v1/messages".to_string(), + api_key: 
"test-key".to_string(), + model: "claude-sonnet-4-5".to_string(), + format: "anthropic".to_string(), + context_window: 200000, + max_tokens: Some(8192), + temperature: None, + top_p: None, + enable_thinking_process: false, + support_preserved_thinking: false, + inline_think_in_text: false, + custom_headers: None, + custom_headers_mode: None, + skip_ssl_verify: false, + reasoning_effort: None, + custom_request_body: None, + }); + + assert_eq!( + client.resolve_anthropic_models_url(), + "https://api.anthropic.com/v1/models" + ); + } + + #[test] + fn build_gemini_request_body_translates_response_format_and_merges_generation_config() { + let client = AIClient::new(AIConfig { + name: "gemini".to_string(), + base_url: "https://example.com".to_string(), + request_url: "https://example.com/models/gemini-2.5-pro:streamGenerateContent?alt=sse" + .to_string(), + api_key: "test-key".to_string(), + model: "gemini-2.5-pro".to_string(), + format: "gemini".to_string(), + context_window: 128000, + max_tokens: Some(4096), + temperature: Some(0.2), + top_p: Some(0.8), + enable_thinking_process: true, + support_preserved_thinking: true, + inline_think_in_text: false, + custom_headers: None, + custom_headers_mode: None, + skip_ssl_verify: false, + reasoning_effort: None, + custom_request_body: None, + }); + + let request_body = client.build_gemini_request_body( + None, + vec![json!({ + "role": "user", + "parts": [{ "text": "hello" }] + })], + None, + Some(json!({ + "response_format": { + "type": "json_schema", + "json_schema": { + "schema": { + "type": "object", + "properties": { + "answer": { "type": "string" } + }, + "required": ["answer"], + "additionalProperties": false + } + } + }, + "stop": ["END"], + "generationConfig": { + "candidateCount": 1 + } + })), + ); + + assert_eq!(request_body["generationConfig"]["maxOutputTokens"], 4096); + assert_eq!(request_body["generationConfig"]["temperature"], 0.2); + assert_eq!(request_body["generationConfig"]["topP"], 0.8); + assert_eq!( + 
request_body["generationConfig"]["thinkingConfig"]["includeThoughts"], + true + ); + assert_eq!( + request_body["generationConfig"]["responseMimeType"], + "application/json" + ); + assert_eq!(request_body["generationConfig"]["candidateCount"], 1); + assert_eq!( + request_body["generationConfig"]["stopSequences"], + json!(["END"]) + ); + assert_eq!( + request_body["generationConfig"]["responseJsonSchema"]["required"], + json!(["answer"]) + ); + assert!(request_body["generationConfig"]["responseJsonSchema"] + .get("additionalProperties") + .is_none()); + assert!(request_body.get("response_format").is_none()); + assert!(request_body.get("stop").is_none()); + } + + #[test] + fn build_gemini_request_body_omits_function_calling_config_for_native_only_tools() { + let client = AIClient::new(AIConfig { + name: "gemini".to_string(), + base_url: "https://example.com".to_string(), + request_url: "https://example.com/models/gemini-2.5-pro:streamGenerateContent?alt=sse" + .to_string(), + api_key: "test-key".to_string(), + model: "gemini-2.5-pro".to_string(), + format: "gemini".to_string(), + context_window: 128000, + max_tokens: Some(4096), + temperature: None, + top_p: None, + enable_thinking_process: false, + support_preserved_thinking: true, + inline_think_in_text: false, + custom_headers: None, + custom_headers_mode: None, + skip_ssl_verify: false, + reasoning_effort: None, + custom_request_body: None, + }); + + let gemini_tools = GeminiMessageConverter::convert_tools(Some(vec![ToolDefinition { + name: "WebSearch".to_string(), + description: "Search the web".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "query": { "type": "string" } + } + }), + }])); + + let request_body = client.build_gemini_request_body( + None, + vec![json!({ + "role": "user", + "parts": [{ "text": "hello" }] + })], + gemini_tools, + None, + ); + + assert_eq!(request_body["tools"][0]["googleSearch"], json!({})); + assert!(request_body.get("toolConfig").is_none()); + } + + 
/// JSON integrity checker - detect whether streamed JSON is complete.
///
/// Primarily used to check whether tool-parameter JSON in AI streaming responses
/// has been fully received. Tolerates leading non-JSON content (e.g. spaces sent
/// by some models) by discarding everything before the first '{'.
///
/// The checker is an incremental scanner: it tracks brace depth, whether the
/// cursor is currently inside a string literal, and whether the previous char
/// was an unconsumed backslash — so escape sequences split across stream chunks
/// (the ByteDance/Doubao chunking bug) are handled correctly.
#[derive(Debug)]
pub struct JsonChecker {
    /// Everything received since the first '{'; leading garbage is dropped.
    buffer: String,
    /// One '{' entry per currently-open object; empty stack means balanced.
    stack: Vec<char>,
    /// True while scanning inside a double-quoted JSON string.
    in_string: bool,
    /// True when the previous char was a '\' inside a string; persists across
    /// `append` calls so a chunk boundary cannot break an escape sequence.
    escape_next: bool,
    /// Set once the first '{' has been seen; input before it is discarded.
    seen_left_brace: bool,
}

impl JsonChecker {
    /// Create a checker with empty state.
    pub fn new() -> Self {
        Self {
            buffer: String::new(),
            stack: Vec::new(),
            in_string: false,
            escape_next: false,
            seen_left_brace: false,
        }
    }

    /// Feed the next stream chunk. Empty chunks are harmless no-ops.
    pub fn append(&mut self, s: &str) {
        for ch in s.chars() {
            // Discard everything before the first '{'
            if !self.seen_left_brace {
                if ch == '{' {
                    self.seen_left_brace = true;
                    self.stack.push('{');
                    self.buffer.push(ch);
                }
                continue;
            }

            self.buffer.push(ch);

            // Previous char was a backslash inside a string: this char is
            // escaped and must not be interpreted structurally (e.g. an
            // escaped '"' must not close the string).
            if self.escape_next {
                self.escape_next = false;
                continue;
            }

            match ch {
                '\\' if self.in_string => {
                    self.escape_next = true;
                }
                '"' => {
                    self.in_string = !self.in_string;
                }
                '{' if !self.in_string => {
                    self.stack.push('{');
                }
                '}' if !self.in_string => {
                    // Tolerate stray closers instead of underflowing the stack.
                    if !self.stack.is_empty() {
                        self.stack.pop();
                    }
                }
                _ => {}
            }
        }
    }

    /// Return a copy of the accumulated JSON text (leading garbage removed).
    pub fn get_buffer(&self) -> String {
        self.buffer.clone()
    }

    /// True once at least one '{' was seen and all open braces are closed.
    pub fn is_valid(&self) -> bool {
        self.stack.is_empty() && self.seen_left_brace
    }

    /// Clear all state so the checker can be reused for the next tool call.
    pub fn reset(&mut self) {
        self.buffer.clear();
        self.stack.clear();
        self.in_string = false;
        self.escape_next = false;
        self.seen_left_brace = false;
    }
}

impl Default for JsonChecker {
    fn default() -> Self {
        Self::new()
    }
}
check_one_shot(" \n\t {\"a\": 1}"); + assert!(valid); + assert_eq!(buf, "{\"a\": 1}"); + } + + #[test] + fn leading_random_text_before_brace() { + let (valid, buf) = check_one_shot("some garbage {\"ok\": true}"); + assert!(valid); + assert_eq!(buf, "{\"ok\": true}"); + } + + #[test] + fn only_spaces_no_brace() { + let (valid, _) = check_one_shot(" "); + assert!(!valid); + } + + // ── Escape handling ── + + #[test] + fn escaped_quote_in_string() { + // JSON: {"msg": "say \"hello\""} + let input = r#"{"msg": "say \"hello\""}"#; + let (valid, _) = check_one_shot(input); + assert!(valid); + } + + #[test] + fn escaped_backslash_before_quote() { + // JSON: {"path": "C:\\"} — value is C:\, the \\ is an escaped backslash + let input = r#"{"path": "C:\\"}"#; + let (valid, _) = check_one_shot(input); + assert!(valid); + } + + #[test] + fn escaped_backslash_followed_by_quote_char_by_char() { + // Ensure escape state survives across single-char chunks + let input = r#"{"path": "C:\\"}"#; + let (valid, buf) = check_char_by_char(input); + assert!(valid); + assert_eq!(buf, input); + } + + #[test] + fn braces_inside_string_are_ignored() { + let input = r#"{"code": "fn main() { println!(\"hi\"); }"}"#; + let (valid, _) = check_one_shot(input); + assert!(valid); + } + + #[test] + fn braces_inside_string_char_by_char() { + let input = r#"{"code": "fn main() { println!(\"hi\"); }"}"#; + let (valid, _) = check_char_by_char(input); + assert!(valid); + } + + // ── Cross-chunk escape: the exact ByteDance bug scenario ── + + #[test] + fn escape_split_across_chunks() { + // Simulates: {"new_string": "fn main() {\n println!(\"Hello, World!\");\n}"} + // The backslash and the quote land in different chunks + let mut c = JsonChecker::new(); + c.append(r#"{"new_string": "fn main() {\n println!(\"Hello, World!"#); + assert!(!c.is_valid()); + + // chunk ends with backslash + c.append("\\"); + assert!(!c.is_valid()); + + // next chunk starts with escaped quote — must NOT end the string + 
c.append("\""); + assert!(!c.is_valid()); + + c.append(r#");\n}"}"#); + assert!(c.is_valid()); + } + + #[test] + fn escape_at_chunk_boundary_does_not_leak() { + // After the escaped char is consumed, escape_next should be false + let mut c = JsonChecker::new(); + c.append(r#"{"a": "x\"#); // ends with backslash inside string + assert!(!c.is_valid()); + + c.append("n"); // \n escape sequence complete + assert!(!c.is_valid()); + + c.append(r#""}"#); // close string and object + assert!(c.is_valid()); + } + + // ── Realistic streaming simulation ── + + #[test] + fn bytedance_doubao_streaming_simulation() { + // Reproduces the exact chunking pattern from the bug report + let mut c = JsonChecker::new(); + c.append(""); // empty first arguments chunk + c.append(" {\""); // leading space + opening brace + assert!(!c.is_valid()); + + c.append("city"); + c.append("\":"); + c.append(" \""); + c.append("Beijing"); + c.append("\"}"); + assert!(c.is_valid()); + assert_eq!(c.get_buffer(), r#"{"city": "Beijing"}"#); + } + + #[test] + fn edit_tool_streaming_simulation() { + // Reproduces the Edit tool call from the second bug report + let mut c = JsonChecker::new(); + c.append("{\"file_path\": \"E:/Projects/ForTest/basic-rust/src/main.rs\", \"new_string\": \"fn main() {\\n println!(\\\"Hello,"); + c.append(" World"); + c.append("!\\"); // backslash at chunk end + c.append("\");"); // escaped quote at chunk start — must stay in string + assert!(!c.is_valid()); + + c.append("\\"); // another backslash at chunk end + c.append("n"); // \n escape + c.append("}\","); // closing brace inside string, then close string, comma + assert!(!c.is_valid()); // object not yet closed + + c.append(" \"old_string\": \"\""); + c.append("}"); + assert!(c.is_valid()); + } + + // ── Reset ── + + #[test] + fn reset_clears_all_state() { + let mut c = JsonChecker::new(); + c.append(r#" {"key": "val"#); // leading space, incomplete + assert!(!c.is_valid()); + + c.reset(); + assert!(!c.is_valid()); + 
assert_eq!(c.get_buffer(), ""); + + // Should work fresh after reset + c.append(r#"{"ok": true}"#); + assert!(c.is_valid()); + } + + #[test] + fn reset_clears_escape_state() { + let mut c = JsonChecker::new(); + c.append(r#"{"a": "\"#); // ends mid-escape + c.reset(); + + // The stale escape_next must not affect the new input + c.append(r#"{"b": "x"}"#); + assert!(c.is_valid()); + } + + // ── Edge cases ── + + #[test] + fn multiple_top_level_objects_first_wins() { + // After the first object completes, is_valid becomes true; + // subsequent data keeps appending but re-opens the stack + let mut c = JsonChecker::new(); + c.append("{}"); + assert!(c.is_valid()); + + c.append("{}"); + // stack opens and closes again, still valid + assert!(c.is_valid()); + } + + #[test] + fn deeply_nested_objects() { + let input = r#"{"a":{"b":{"c":{"d":{"e":{}}}}}}"#; + let (valid, _) = check_one_shot(input); + assert!(valid); + } + + #[test] + fn string_with_unicode_escapes() { + let input = r#"{"emoji": "\u0048\u0065\u006C\u006C\u006F"}"#; + let (valid, _) = check_one_shot(input); + assert!(valid); + } + + #[test] + fn string_with_newlines_and_tabs() { + let input = r#"{"text": "line1\nline2\ttab"}"#; + let (valid, _) = check_one_shot(input); + assert!(valid); + } + + #[test] + fn consecutive_escaped_backslashes() { + // JSON value: a\\b — two backslashes, meaning literal backslash in value + let input = r#"{"p": "a\\\\b"}"#; + let (valid, _) = check_one_shot(input); + assert!(valid); + } + + #[test] + fn consecutive_escaped_backslashes_char_by_char() { + let input = r#"{"p": "a\\\\b"}"#; + let (valid, _) = check_char_by_char(input); + assert!(valid); + } + + #[test] + fn default_trait_works() { + let c = JsonChecker::default(); + assert!(!c.is_valid()); + assert_eq!(c.get_buffer(), ""); + } + + // ── Streaming: no premature is_valid() ── + + #[test] + fn never_valid_during_progressive_append() { + // Feed a complete JSON object token-by-token, assert is_valid() is false + // at 
every step except after the final '}' + let chunks = vec![ + "{", "\"", "k", "e", "y", "\"", ":", " ", "\"", "v", "a", "l", "\"", "}", + ]; + let mut c = JsonChecker::new(); + for (i, chunk) in chunks.iter().enumerate() { + c.append(chunk); + if i < chunks.len() - 1 { + assert!( + !c.is_valid(), + "premature valid at chunk index {}: {:?}", + i, + c.get_buffer() + ); + } + } + assert!(c.is_valid()); + assert_eq!(c.get_buffer(), r#"{"key": "val"}"#); + } + + #[test] + fn never_valid_during_nested_object_streaming() { + // {"a": {"b": 1}} streamed in realistic chunks + let chunks = vec!["{\"a\"", ": ", "{\"b\"", ": 1", "}", "}"]; + let mut c = JsonChecker::new(); + for (i, chunk) in chunks.iter().enumerate() { + c.append(chunk); + if i < chunks.len() - 1 { + assert!( + !c.is_valid(), + "premature valid at chunk index {}: {:?}", + i, + c.get_buffer() + ); + } + } + assert!(c.is_valid()); + } + + #[test] + fn string_with_braces_never_premature_valid() { + // {"code": "{ } { }"} — braces inside string must not close the object + let chunks = vec!["{\"code\": \"", "{ ", "} ", "{ ", "}", "\"", "}"]; + let mut c = JsonChecker::new(); + for (i, chunk) in chunks.iter().enumerate() { + c.append(chunk); + if i < chunks.len() - 1 { + assert!( + !c.is_valid(), + "premature valid at chunk index {}: {:?}", + i, + c.get_buffer() + ); + } + } + assert!(c.is_valid()); + } + + // ── Streaming: empty chunks interspersed ── + + #[test] + fn empty_chunks_between_data() { + let mut c = JsonChecker::new(); + c.append(""); + assert!(!c.is_valid()); + c.append("{"); + assert!(!c.is_valid()); + c.append(""); + assert!(!c.is_valid()); + c.append("\"a\""); + c.append(""); + c.append(": 1"); + c.append(""); + c.append(""); + c.append("}"); + assert!(c.is_valid()); + assert_eq!(c.get_buffer(), r#"{"a": 1}"#); + } + + #[test] + fn empty_chunks_before_first_brace() { + let mut c = JsonChecker::new(); + c.append(""); + c.append(""); + c.append(""); + assert!(!c.is_valid()); + c.append(" "); + 
assert!(!c.is_valid()); + c.append("{}"); + assert!(c.is_valid()); + } + + // ── Streaming: \\\" sequence split at different positions ── + + #[test] + fn escaped_backslash_then_escaped_quote_split_1() { + // JSON: {"a": "x\\\"y"} — value is x\"y (backslash, quote, y) + // Split: `{"a": "x\` | `\` | `\` | `"` | `y"}` + // Char-by-char through the \\\" sequence + let mut c = JsonChecker::new(); + c.append(r#"{"a": "x"#); + assert!(!c.is_valid()); + c.append("\\"); // first \ of \\, sets escape_next + assert!(!c.is_valid()); + c.append("\\"); // consumed by escape (it's the escaped backslash), then done + assert!(!c.is_valid()); + c.append("\\"); // first \ of \", sets escape_next + assert!(!c.is_valid()); + c.append("\""); // consumed by escape (it's the escaped quote) + assert!(!c.is_valid()); // still inside string! + c.append("y"); + assert!(!c.is_valid()); + c.append("\"}"); + assert!(c.is_valid()); + } + + #[test] + fn escaped_backslash_then_escaped_quote_split_2() { + // Same JSON: {"a": "x\\\"y"} but split as: `...x\\` | `\"y"}` + let mut c = JsonChecker::new(); + c.append(r#"{"a": "x\\"#); // \\ = escaped backslash, escape_next consumed + assert!(!c.is_valid()); + c.append(r#"\"y"}"#); // \" = escaped quote, y, close string, close object + assert!(c.is_valid()); + } + + #[test] + fn escaped_backslash_then_escaped_quote_split_3() { + // Same JSON but split as: `...x\` | `\\` | `"y"}` + let mut c = JsonChecker::new(); + c.append(r#"{"a": "x\"#); // \ sets escape_next + assert!(!c.is_valid()); + c.append("\\\\"); // first \ consumed by escape, second \ sets escape_next + assert!(!c.is_valid()); + c.append("\"y\"}"); // " consumed by escape, y normal, " closes string, } closes object + assert!(c.is_valid()); + } + + // ── Streaming: escaped backslash + closing quote ── + + #[test] + fn escaped_backslash_then_closing_quote_split_at_boundary() { + // JSON: {"a": "x\\"} — value is x\ (escaped backslash), then " closes string + // Split as: `{"a": "x\` | `\"}` — \ 
crosses chunk boundary + let mut c = JsonChecker::new(); + c.append(r#"{"a": "x\"#); // \ sets escape_next + assert!(!c.is_valid()); + c.append("\\\"}"); // \ consumed by escape, " closes string, } closes object + assert!(c.is_valid()); + assert_eq!(c.get_buffer(), r#"{"a": "x\\"}"#); + } + + #[test] + fn escaped_backslash_then_closing_quote_split_after_pair() { + // Same JSON: {"a": "x\\"} — split as: `{"a": "x\\` | `"}` + let mut c = JsonChecker::new(); + c.append(r#"{"a": "x\\"#); // \\ pair complete, escape_next = false + assert!(!c.is_valid()); + c.append("\"}"); // " closes string, } closes object + assert!(c.is_valid()); + } + + // ── Streaming: multiple tool calls with reset (full lifecycle) ── + + #[test] + fn lifecycle_multiple_tool_calls_with_reset() { + let mut c = JsonChecker::new(); + + // --- Tool call 1: simple --- + c.append(" "); // leading space (ByteDance) + c.append("{\""); + c.append("city\": \"Beijing\"}"); + assert!(c.is_valid()); + assert_eq!(c.get_buffer(), r#"{"city": "Beijing"}"#); + + // --- Reset for tool call 2 --- + c.reset(); + assert!(!c.is_valid()); + assert_eq!(c.get_buffer(), ""); + + // --- Tool call 2: with escapes --- + c.append("{\"code\": \""); + assert!(!c.is_valid()); + c.append("fn main() {\\n"); + assert!(!c.is_valid()); + c.append(" println!(\\\"hi\\\");"); + assert!(!c.is_valid()); + c.append("\\n}\"}"); + assert!(c.is_valid()); + + // --- Reset for tool call 3 --- + c.reset(); + assert!(!c.is_valid()); + + // --- Tool call 3: empty object --- + c.append("{}"); + assert!(c.is_valid()); + } + + #[test] + fn lifecycle_reset_mid_escape_then_new_tool_call() { + let mut c = JsonChecker::new(); + + // Tool call 1: interrupted mid-escape + c.append("{\"a\": \"x\\"); // ends with pending escape + assert!(!c.is_valid()); + + // Reset before completion (e.g. 
stream error) + c.reset(); + + // Tool call 2: must work cleanly with no stale escape state + c.append("{\"b\": \"y\"}"); + assert!(c.is_valid()); + assert_eq!(c.get_buffer(), r#"{"b": "y"}"#); + } + + #[test] + fn lifecycle_reset_mid_string_then_new_tool_call() { + let mut c = JsonChecker::new(); + + // Tool call 1: interrupted inside string + c.append("{\"a\": \"some text"); + assert!(!c.is_valid()); + + c.reset(); + + // Tool call 2: must not think it's still in a string + c.append("{\"b\": \"{}\"}"); // braces inside string value + assert!(c.is_valid()); + } +} diff --git a/BitFun-Installer/src-tauri/src/connection_test/mod.rs b/BitFun-Installer/src-tauri/src/connection_test/mod.rs new file mode 100644 index 00000000..be72a7b4 --- /dev/null +++ b/BitFun-Installer/src-tauri/src/connection_test/mod.rs @@ -0,0 +1,10 @@ +//! Full copy of `bitfun_core` AI client + stream stack for installer connection tests (no `bitfun_core` dependency). +#![allow(dead_code)] + +pub mod client; +pub mod json_checker; +pub mod proxy; +pub mod providers; +pub mod types; + +pub use client::AIClient; diff --git a/BitFun-Installer/src-tauri/src/connection_test/providers/anthropic/message_converter.rs b/BitFun-Installer/src-tauri/src/connection_test/providers/anthropic/message_converter.rs new file mode 100644 index 00000000..750ea4db --- /dev/null +++ b/BitFun-Installer/src-tauri/src/connection_test/providers/anthropic/message_converter.rs @@ -0,0 +1,226 @@ +//! Anthropic message format converter +//! +//! 
Converts the unified message format to Anthropic Claude API format + +use crate::connection_test::types::{Message, ToolDefinition}; +use log::warn; +use serde_json::{json, Value}; + +pub struct AnthropicMessageConverter; + +impl AnthropicMessageConverter { + /// Convert unified message format to Anthropic format + /// + /// Note: Anthropic requires system messages to be handled separately, not in the messages array + pub fn convert_messages(messages: Vec) -> (Option, Vec) { + let mut system_message = None; + let mut anthropic_messages = Vec::new(); + + for msg in messages { + match msg.role.as_str() { + "system" => { + if let Some(content) = msg.content { + system_message = Some(content); + } + } + "user" => { + anthropic_messages.push(Self::convert_user_message(msg)); + } + "assistant" => { + if let Some(converted) = Self::convert_assistant_message(msg) { + anthropic_messages.push(converted); + } + } + "tool" => { + anthropic_messages.push(Self::convert_tool_result_message(msg)); + } + _ => { + warn!("Unknown message role: {}", msg.role); + } + } + } + + // Anthropic requires user/assistant messages to alternate + let merged_messages = Self::merge_consecutive_messages(anthropic_messages); + + (system_message, merged_messages) + } + + /// Merge consecutive same-role messages to keep user/assistant alternating + fn merge_consecutive_messages(messages: Vec) -> Vec { + let mut merged: Vec = Vec::new(); + + for msg in messages { + let role = msg.get("role").and_then(|r| r.as_str()).unwrap_or(""); + + if let Some(last) = merged.last_mut() { + let last_role = last.get("role").and_then(|r| r.as_str()).unwrap_or(""); + + if last_role == role && role == "user" { + let current_content = msg.get("content"); + let last_content = last.get_mut("content"); + + match (last_content, current_content) { + (Some(Value::Array(last_arr)), Some(Value::Array(curr_arr))) => { + last_arr.extend(curr_arr.clone()); + continue; + } + (Some(Value::Array(last_arr)), 
Some(Value::String(curr_str))) => { + last_arr.push(json!({ + "type": "text", + "text": curr_str + })); + continue; + } + (Some(Value::String(last_str)), Some(Value::Array(curr_arr))) => { + let mut new_content = vec![json!({ + "type": "text", + "text": last_str + })]; + new_content.extend(curr_arr.clone()); + *last = json!({ + "role": "user", + "content": new_content + }); + continue; + } + (Some(Value::String(last_str)), Some(Value::String(curr_str))) => { + let merged_text = if last_str.is_empty() { + curr_str.to_string() + } else { + format!("{}\n\n{}", last_str, curr_str) + }; + *last = json!({ + "role": "user", + "content": merged_text + }); + continue; + } + _ => {} + } + } + } + + merged.push(msg); + } + + merged + } + + fn convert_user_message(msg: Message) -> Value { + let content = msg.content.unwrap_or_default(); + + if let Ok(parsed) = serde_json::from_str::(&content) { + if parsed.is_array() { + return json!({ + "role": "user", + "content": parsed + }); + } + } + + json!({ + "role": "user", + "content": content + }) + } + + /// Convert assistant messages; return None when empty. 
+ fn convert_assistant_message(msg: Message) -> Option { + let mut content = Vec::new(); + + if let Some(thinking) = msg.reasoning_content.as_ref() { + if !thinking.is_empty() { + let mut thinking_block = json!({ + "type": "thinking", + "thinking": thinking + }); + + thinking_block["signature"] = + json!(msg.thinking_signature.as_deref().unwrap_or("")); + + content.push(thinking_block); + } + } + + if let Some(text) = msg.content { + if !text.is_empty() { + content.push(json!({ + "type": "text", + "text": text + })); + } + } + + if let Some(tool_calls) = msg.tool_calls { + for tc in tool_calls { + content.push(json!({ + "type": "tool_use", + "id": tc.id, + "name": tc.name, + "input": tc.arguments + })); + } + } + + if content.is_empty() { + None + } else { + Some(json!({ + "role": "assistant", + "content": content + })) + } + } + + fn convert_tool_result_message(msg: Message) -> Value { + let tool_call_id = msg.tool_call_id.unwrap_or_default(); + let text = msg.content.unwrap_or_default(); + + let tool_content: Value = + if let Some(attachments) = msg.tool_image_attachments.filter(|a| !a.is_empty()) { + let mut blocks: Vec = attachments + .into_iter() + .map(|att| { + json!({ + "type": "image", + "source": { + "type": "base64", + "media_type": att.mime_type, + "data": att.data_base64, + } + }) + }) + .collect(); + blocks.push(json!({ "type": "text", "text": text })); + json!(blocks) + } else { + json!(text) + }; + + json!({ + "role": "user", + "content": [{ + "type": "tool_result", + "tool_use_id": tool_call_id, + "content": tool_content + }] + }) + } + + /// Convert tool definitions to Anthropic format + pub fn convert_tools(tools: Option>) -> Option> { + tools.map(|tool_defs| { + tool_defs + .into_iter() + .map(|tool| { + json!({ + "name": tool.name, + "description": tool.description, + "input_schema": tool.parameters + }) + }) + .collect() + }) + } +} diff --git a/BitFun-Installer/src-tauri/src/connection_test/providers/anthropic/mod.rs 
b/BitFun-Installer/src-tauri/src/connection_test/providers/anthropic/mod.rs new file mode 100644 index 00000000..e01d6710 --- /dev/null +++ b/BitFun-Installer/src-tauri/src/connection_test/providers/anthropic/mod.rs @@ -0,0 +1,7 @@ +//! Anthropic Claude API provider +//! +//! Implements interaction with Anthropic Claude models + +pub mod message_converter; + +pub use message_converter::AnthropicMessageConverter; diff --git a/BitFun-Installer/src-tauri/src/connection_test/providers/gemini/message_converter.rs b/BitFun-Installer/src-tauri/src/connection_test/providers/gemini/message_converter.rs new file mode 100644 index 00000000..313333bf --- /dev/null +++ b/BitFun-Installer/src-tauri/src/connection_test/providers/gemini/message_converter.rs @@ -0,0 +1,919 @@ +//! Gemini message format converter + +use crate::connection_test::types::{Message, ToolDefinition}; +use log::warn; +use serde_json::{json, Map, Value}; + +pub struct GeminiMessageConverter; + +impl GeminiMessageConverter { + pub fn convert_messages( + messages: Vec, + model_name: &str, + ) -> (Option, Vec) { + let mut system_texts = Vec::new(); + let mut contents = Vec::new(); + let is_gemini_3 = model_name.contains("gemini-3"); + + for msg in messages { + match msg.role.as_str() { + "system" => { + if let Some(content) = msg.content.filter(|content| !content.trim().is_empty()) + { + system_texts.push(content); + } + } + "user" => { + let parts = Self::convert_content_parts(msg.content.as_deref(), false); + Self::push_content(&mut contents, "user", parts); + } + "assistant" => { + let mut parts = Vec::new(); + + let mut pending_thought_signature = msg + .thinking_signature + .filter(|value| !value.trim().is_empty()); + let has_tool_calls = msg + .tool_calls + .as_ref() + .map(|tool_calls| !tool_calls.is_empty()) + .unwrap_or(false); + + if let Some(content) = msg + .content + .as_deref() + .filter(|value| !value.trim().is_empty()) + { + if !has_tool_calls { + if let Some(signature) = 
pending_thought_signature.take() { + parts.push(json!({ + "thoughtSignature": signature, + })); + } + } + parts.extend(Self::convert_content_parts(Some(content), true)); + } + + if let Some(tool_calls) = msg.tool_calls { + for (tool_call_index, tool_call) in tool_calls.into_iter().enumerate() { + let mut part = Map::new(); + part.insert( + "functionCall".to_string(), + json!({ + "name": tool_call.name, + "args": tool_call.arguments, + }), + ); + + match pending_thought_signature.take() { + Some(signature) => { + part.insert( + "thoughtSignature".to_string(), + Value::String(signature), + ); + } + None if is_gemini_3 && tool_call_index == 0 => { + part.insert( + "thoughtSignature".to_string(), + Value::String( + "skip_thought_signature_validator".to_string(), + ), + ); + } + None => {} + } + + parts.push(Value::Object(part)); + } + } + + if let Some(signature) = pending_thought_signature { + parts.push(json!({ + "thoughtSignature": signature, + })); + } + + Self::push_content(&mut contents, "model", parts); + } + "tool" => { + let tool_name = msg.name.unwrap_or_default(); + if tool_name.is_empty() { + warn!("Skipping Gemini tool response without tool name"); + continue; + } + + let response = Self::parse_tool_response(msg.content.as_deref()); + let parts = vec![json!({ + "functionResponse": { + "name": tool_name, + "response": response, + } + })]; + + Self::push_content(&mut contents, "user", parts); + } + _ => { + warn!("Unknown Gemini message role: {}", msg.role); + } + } + } + + let system_instruction = if system_texts.is_empty() { + None + } else { + Some(json!({ + "parts": [{ + "text": system_texts.join("\n\n") + }] + })) + }; + + (system_instruction, contents) + } + + pub fn convert_tools(tools: Option>) -> Option> { + tools.and_then(|tool_defs| { + let mut native_tools = Vec::new(); + let mut custom_tools = Vec::new(); + + for tool in tool_defs { + if let Some(native_tool) = Self::convert_native_tool(&tool) { + native_tools.push(native_tool); + } else { + 
custom_tools.push(tool); + } + } + + // Gemini providers such as AIHubMix reject requests that mix built-in tools + // with custom function declarations. When custom tools are present, keep all + // tools in function-calling mode so BitFun's local tool pipeline still works. + let should_fallback_to_function_calling = + !native_tools.is_empty() && !custom_tools.is_empty(); + + let declarations: Vec = if should_fallback_to_function_calling { + custom_tools + .into_iter() + .chain( + native_tools + .iter() + .cloned() + .filter_map(Self::convert_native_tool_to_custom_definition), + ) + .map(Self::convert_custom_tool) + .collect() + } else { + custom_tools + .into_iter() + .map(Self::convert_custom_tool) + .collect() + }; + + let mut result_tools = if should_fallback_to_function_calling { + Vec::new() + } else { + native_tools + }; + + if !declarations.is_empty() { + result_tools.push(json!({ + "functionDeclarations": declarations, + })); + } + + if result_tools.is_empty() { + None + } else { + Some(result_tools) + } + }) + } + + pub fn sanitize_schema(value: Value) -> Value { + Self::strip_unsupported_schema_fields(value) + } + + fn convert_native_tool(tool: &ToolDefinition) -> Option { + let native_name = Self::native_tool_name(&tool.name)?; + let config = Self::native_tool_config(&tool.parameters); + Some(json!({ + native_name: config, + })) + } + + fn convert_native_tool_to_custom_definition(native_tool: Value) -> Option { + let map = native_tool.as_object()?; + let (name, _config) = map.iter().next()?; + + Some(ToolDefinition { + name: Self::native_tool_fallback_name(name).to_string(), + description: Self::native_tool_fallback_description(name).to_string(), + parameters: Self::native_tool_fallback_schema(name), + }) + } + + fn convert_custom_tool(tool: ToolDefinition) -> Value { + let parameters = Self::sanitize_schema(tool.parameters); + json!({ + "name": tool.name, + "description": tool.description, + "parameters": parameters, + }) + } + + fn 
native_tool_name(tool_name: &str) -> Option<&'static str> { + match tool_name { + "WebSearch" | "googleSearch" | "GoogleSearch" => Some("googleSearch"), + "WebFetch" | "urlContext" | "UrlContext" | "URLContext" => Some("urlContext"), + "googleSearchRetrieval" | "GoogleSearchRetrieval" => Some("googleSearchRetrieval"), + "codeExecution" | "CodeExecution" => Some("codeExecution"), + _ => None, + } + } + + fn native_tool_fallback_name(native_name: &str) -> &'static str { + match native_name { + "googleSearch" => "WebSearch", + "urlContext" => "WebFetch", + "googleSearchRetrieval" => "googleSearchRetrieval", + "codeExecution" => "codeExecution", + _ => "unknown_native_tool", + } + } + + fn native_tool_fallback_description(native_name: &str) -> &'static str { + match native_name { + "googleSearch" => "Search the web for up-to-date information.", + "urlContext" => "Fetch content from a URL for context.", + "googleSearchRetrieval" => "Retrieve grounded results from Google Search.", + "codeExecution" => "Execute model-generated code and return the result.", + _ => "Gemini native tool fallback.", + } + } + + fn native_tool_fallback_schema(native_name: &str) -> Value { + match native_name { + "googleSearch" | "googleSearchRetrieval" => json!({ + "type": "object", + "properties": { + "query": { + "type": "string", + } + }, + "required": ["query"] + }), + "urlContext" => json!({ + "type": "object", + "properties": { + "url": { + "type": "string", + } + }, + "required": ["url"] + }), + "codeExecution" => json!({ + "type": "object", + "properties": {} + }), + _ => json!({ + "type": "object", + "properties": {} + }), + } + } + + fn native_tool_config(parameters: &Value) -> Value { + if Self::looks_like_schema(parameters) { + json!({}) + } else { + match parameters { + Value::Object(map) if !map.is_empty() => parameters.clone(), + _ => json!({}), + } + } + } + + fn looks_like_schema(parameters: &Value) -> bool { + let Some(map) = parameters.as_object() else { + return false; + }; 
+ + map.contains_key("type") + || map.contains_key("properties") + || map.contains_key("required") + || map.contains_key("$schema") + || map.contains_key("items") + || map.contains_key("allOf") + || map.contains_key("anyOf") + || map.contains_key("oneOf") + || map.contains_key("enum") + || map.contains_key("nullable") + || map.contains_key("format") + } + + fn push_content(contents: &mut Vec, role: &str, parts: Vec) { + if parts.is_empty() { + return; + } + + if let Some(last) = contents.last_mut() { + let last_role = last.get("role").and_then(Value::as_str).unwrap_or_default(); + if last_role == role { + if let Some(existing_parts) = last.get_mut("parts").and_then(Value::as_array_mut) { + existing_parts.extend(parts); + return; + } + } + } + + contents.push(json!({ + "role": role, + "parts": parts, + })); + } + + fn convert_content_parts(content: Option<&str>, is_model_role: bool) -> Vec { + let Some(content) = content else { + return Vec::new(); + }; + + if content.trim().is_empty() { + return Vec::new(); + } + + let parsed = match serde_json::from_str::(content) { + Ok(parsed) if parsed.is_array() => parsed, + _ => return vec![json!({ "text": content })], + }; + + let mut parts = Vec::new(); + + if let Some(items) = parsed.as_array() { + for item in items { + let item_type = item.get("type").and_then(Value::as_str); + match item_type { + Some("text") | Some("input_text") | Some("output_text") => { + if let Some(text) = item.get("text").and_then(Value::as_str) { + if !text.is_empty() { + parts.push(json!({ "text": text })); + } + } + } + Some("image_url") if !is_model_role => { + if let Some(url) = item.get("image_url").and_then(|value| { + value + .get("url") + .and_then(Value::as_str) + .or_else(|| value.as_str()) + }) { + if let Some(part) = Self::convert_image_url_to_part(url) { + parts.push(part); + } + } + } + Some("image") if !is_model_role => { + let source = item.get("source"); + let mime_type = source + .and_then(|value| value.get("media_type")) + 
.and_then(Value::as_str); + let data = source + .and_then(|value| value.get("data")) + .and_then(Value::as_str); + + if let (Some(mime_type), Some(data)) = (mime_type, data) { + parts.push(json!({ + "inlineData": { + "mimeType": mime_type, + "data": data, + } + })); + } + } + _ => {} + } + } + } + + if parts.is_empty() { + vec![json!({ "text": content })] + } else { + parts + } + } + + fn convert_image_url_to_part(url: &str) -> Option { + let prefix = "data:"; + if !url.starts_with(prefix) { + warn!("Gemini currently supports inline data URLs for image parts; skipping unsupported image URL"); + return None; + } + + let rest = &url[prefix.len()..]; + let (mime_type, data) = rest.split_once(";base64,")?; + if mime_type.is_empty() || data.is_empty() { + return None; + } + + Some(json!({ + "inlineData": { + "mimeType": mime_type, + "data": data, + } + })) + } + + fn parse_tool_response(content: Option<&str>) -> Value { + let Some(content) = content.filter(|value| !value.trim().is_empty()) else { + return json!({ "content": "Tool execution completed" }); + }; + + match serde_json::from_str::(content) { + Ok(Value::Object(map)) => Value::Object(map), + Ok(value) => json!({ "content": value }), + Err(_) => json!({ "content": content }), + } + } + + fn strip_unsupported_schema_fields(value: Value) -> Value { + match value { + Value::Object(mut map) => { + let all_of = map.remove("allOf"); + let any_of = map.remove("anyOf"); + let one_of = map.remove("oneOf"); + let (normalized_type, nullable_from_type) = + Self::normalize_schema_type(map.remove("type")); + + let mut sanitized = Map::new(); + for (key, value) in map { + if key == "properties" { + if let Value::Object(properties) = value { + sanitized.insert( + key, + Value::Object( + properties + .into_iter() + .map(|(name, schema)| { + (name, Self::strip_unsupported_schema_fields(schema)) + }) + .collect(), + ), + ); + } + continue; + } + + if Self::is_supported_schema_key(&key) { + sanitized.insert(key, 
Self::strip_unsupported_schema_fields(value)); + } + } + + if let Some(all_of) = all_of { + Self::merge_schema_variants(&mut sanitized, all_of, true); + } + + let mut nullable = nullable_from_type; + if let Some(any_of) = any_of { + nullable |= Self::merge_union_variants(&mut sanitized, any_of); + } + if let Some(one_of) = one_of { + nullable |= Self::merge_union_variants(&mut sanitized, one_of); + } + + if let Some(schema_type) = normalized_type { + sanitized.insert("type".to_string(), Value::String(schema_type)); + } + if nullable { + sanitized.insert("nullable".to_string(), Value::Bool(true)); + } + + Value::Object(sanitized) + } + Value::Array(items) => Value::Array( + items + .into_iter() + .map(Self::strip_unsupported_schema_fields) + .collect(), + ), + other => other, + } + } + + fn is_supported_schema_key(key: &str) -> bool { + matches!( + key, + "type" + | "format" + | "description" + | "nullable" + | "enum" + | "items" + | "properties" + | "required" + | "minItems" + | "maxItems" + | "minimum" + | "maximum" + | "minLength" + | "maxLength" + | "pattern" + ) + } + + fn normalize_schema_type(type_value: Option) -> (Option, bool) { + match type_value { + Some(Value::String(value)) if value != "null" => (Some(value), false), + Some(Value::String(_)) => (None, true), + Some(Value::Array(values)) => { + let mut types = values + .into_iter() + .filter_map(|value| value.as_str().map(str::to_string)); + let mut nullable = false; + let mut selected = None; + + for value in types.by_ref() { + if value == "null" { + nullable = true; + } else if selected.is_none() { + selected = Some(value); + } + } + + (selected, nullable) + } + _ => (None, false), + } + } + + fn merge_union_variants(target: &mut Map, variants: Value) -> bool { + let mut nullable = false; + + if let Value::Array(variants) = variants { + for variant in variants { + let sanitized = Self::strip_unsupported_schema_fields(variant); + match sanitized { + Value::Object(map) => { + let is_null_only = map + 
.get("type") + .and_then(Value::as_str) + .map(|value| value == "null") + .unwrap_or(false) + && map.len() == 1; + + if is_null_only { + nullable = true; + continue; + } + + Self::merge_schema_map(target, map, false); + } + Value::String(value) if value == "null" => nullable = true, + _ => {} + } + } + } + + nullable + } + + fn merge_schema_variants( + target: &mut Map, + variants: Value, + preserve_required: bool, + ) { + if let Value::Array(variants) = variants { + for variant in variants { + if let Value::Object(map) = Self::strip_unsupported_schema_fields(variant) { + Self::merge_schema_map(target, map, preserve_required); + } + } + } + } + + fn merge_schema_map( + target: &mut Map, + source: Map, + preserve_required: bool, + ) { + for (key, value) in source { + match key.as_str() { + "properties" => { + if let Value::Object(source_props) = value { + let target_props = target + .entry(key) + .or_insert_with(|| Value::Object(Map::new())); + if let Value::Object(target_props) = target_props { + for (prop_key, prop_value) in source_props { + target_props.entry(prop_key).or_insert(prop_value); + } + } + } + } + "required" if preserve_required => { + if let Value::Array(source_required) = value { + let target_required = target + .entry(key) + .or_insert_with(|| Value::Array(Vec::new())); + if let Value::Array(target_required) = target_required { + for item in source_required { + if !target_required.contains(&item) { + target_required.push(item); + } + } + } + } + } + "nullable" => { + if value.as_bool().unwrap_or(false) { + target.insert(key, Value::Bool(true)); + } + } + "type" => { + target.entry(key).or_insert(value); + } + _ => { + target.entry(key).or_insert(value); + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::GeminiMessageConverter; + use crate::connection_test::types::{Message, ToolCall, ToolDefinition}; + use serde_json::json; + use std::collections::HashMap; + + #[test] + fn converts_messages_to_gemini_format() { + let mut args = 
HashMap::new(); + args.insert("city".to_string(), json!("Beijing")); + + let messages = vec![ + Message::system("You are helpful".to_string()), + Message::user("Hello".to_string()), + Message { + role: "assistant".to_string(), + content: Some("Working on it".to_string()), + reasoning_content: Some("Let me think".to_string()), + thinking_signature: Some("sig_1".to_string()), + tool_calls: Some(vec![ToolCall { + id: "call_1".to_string(), + name: "get_weather".to_string(), + arguments: args.clone(), + }]), + tool_call_id: None, + name: None, + tool_image_attachments: None, + }, + Message { + role: "tool".to_string(), + content: Some("Sunny".to_string()), + reasoning_content: None, + thinking_signature: None, + tool_calls: None, + tool_call_id: Some("call_1".to_string()), + name: Some("get_weather".to_string()), + tool_image_attachments: None, + }, + ]; + + let (system_instruction, contents) = + GeminiMessageConverter::convert_messages(messages, "gemini-2.5-pro"); + + assert_eq!( + system_instruction.unwrap()["parts"][0]["text"], + json!("You are helpful") + ); + assert_eq!(contents.len(), 3); + assert_eq!(contents[0]["role"], json!("user")); + assert_eq!(contents[1]["role"], json!("model")); + assert_eq!(contents[1]["parts"][0]["text"], json!("Working on it")); + assert_eq!( + contents[1]["parts"][1]["functionCall"]["name"], + json!("get_weather") + ); + assert_eq!(contents[1]["parts"][1]["thoughtSignature"], json!("sig_1")); + assert_eq!( + contents[2]["parts"][0]["functionResponse"]["name"], + json!("get_weather") + ); + } + + #[test] + fn injects_skip_signature_for_first_synthetic_gemini_3_tool_call() { + let mut args = HashMap::new(); + args.insert("city".to_string(), json!("Paris")); + + let messages = vec![Message { + role: "assistant".to_string(), + content: None, + reasoning_content: None, + thinking_signature: None, + tool_calls: Some(vec![ToolCall { + id: "call_1".to_string(), + name: "get_weather".to_string(), + arguments: args, + }]), + tool_call_id: None, 
+ name: None, + tool_image_attachments: None, + }]; + + let (_, contents) = + GeminiMessageConverter::convert_messages(messages, "gemini-3-flash-preview"); + + assert_eq!(contents.len(), 1); + assert_eq!( + contents[0]["parts"][0]["thoughtSignature"], + json!("skip_thought_signature_validator") + ); + } + + #[test] + fn converts_data_url_images_to_inline_data() { + let messages = vec![Message { + role: "user".to_string(), + content: Some( + json!([ + { + "type": "image_url", + "image_url": { + "url": "data:image/png;base64,abc" + } + }, + { + "type": "text", + "text": "Describe this image" + } + ]) + .to_string(), + ), + reasoning_content: None, + thinking_signature: None, + tool_calls: None, + tool_call_id: None, + name: None, + tool_image_attachments: None, + }]; + + let (_, contents) = GeminiMessageConverter::convert_messages(messages, "gemini-2.5-pro"); + + assert_eq!( + contents[0]["parts"][0]["inlineData"]["mimeType"], + json!("image/png") + ); + assert_eq!( + contents[0]["parts"][1]["text"], + json!("Describe this image") + ); + } + + #[test] + fn strips_unsupported_fields_from_tool_schema() { + let tools = Some(vec![ToolDefinition { + name: "get_weather".to_string(), + description: "Get weather".to_string(), + parameters: json!({ + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": { + "city": { "type": "string" }, + "timezone": { + "type": ["string", "null"] + }, + "link": { + "anyOf": [ + { + "type": "object", + "properties": { + "url": { "type": "string" } + }, + "required": ["url"] + }, + { "type": "null" } + ] + }, + "items": { + "allOf": [ + { + "type": "object", + "properties": { + "name": { "type": "string" } + }, + "required": ["name"] + }, + { + "type": "object", + "properties": { + "count": { "type": "integer" } + }, + "required": ["count"] + } + ] + } + }, + "required": ["city"], + "additionalProperties": false, + "items": { + "type": "object", + "additionalProperties": false + } + }), + }]); + + let 
converted = GeminiMessageConverter::convert_tools(tools).expect("converted tools"); + let schema = &converted[0]["functionDeclarations"][0]["parameters"]; + + assert!(schema.get("$schema").is_none()); + assert!(schema.get("additionalProperties").is_none()); + assert!(schema["items"].get("additionalProperties").is_none()); + assert_eq!(schema["properties"]["timezone"]["type"], json!("string")); + assert_eq!(schema["properties"]["timezone"]["nullable"], json!(true)); + assert_eq!(schema["properties"]["link"]["type"], json!("object")); + assert_eq!(schema["properties"]["link"]["nullable"], json!(true)); + assert_eq!(schema["properties"]["items"]["type"], json!("object")); + assert_eq!( + schema["properties"]["items"]["required"], + json!(["name", "count"]) + ); + } + + #[test] + fn maps_web_search_to_native_google_search_tool() { + let tools = Some(vec![ToolDefinition { + name: "WebSearch".to_string(), + description: "Search the web".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "query": { "type": "string" } + }, + "required": ["query"] + }), + }]); + + let converted = GeminiMessageConverter::convert_tools(tools).expect("converted tools"); + assert_eq!(converted.len(), 1); + assert_eq!(converted[0]["googleSearch"], json!({})); + assert!(converted[0].get("functionDeclarations").is_none()); + } + + #[test] + fn falls_back_to_function_declarations_when_native_and_custom_tools_mix() { + let tools = Some(vec![ + ToolDefinition { + name: "WebSearch".to_string(), + description: "Search the web".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "query": { "type": "string" } + } + }), + }, + ToolDefinition { + name: "get_weather".to_string(), + description: "Get weather".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "city": { "type": "string" } + }, + "required": ["city"] + }), + }, + ]); + + let converted = GeminiMessageConverter::convert_tools(tools).expect("converted tools"); + 
assert_eq!(converted.len(), 1); + assert!(converted[0].get("googleSearch").is_none()); + assert_eq!( + converted[0]["functionDeclarations"][0]["name"], + json!("get_weather") + ); + assert_eq!( + converted[0]["functionDeclarations"][1]["name"], + json!("WebSearch") + ); + } + + #[test] + fn maps_web_fetch_to_native_url_context_tool() { + let tools = Some(vec![ToolDefinition { + name: "WebFetch".to_string(), + description: "Fetch a URL".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "url": { "type": "string" } + }, + "required": ["url"] + }), + }]); + + let converted = GeminiMessageConverter::convert_tools(tools).expect("converted tools"); + assert_eq!(converted.len(), 1); + assert_eq!(converted[0]["urlContext"], json!({})); + } +} diff --git a/BitFun-Installer/src-tauri/src/connection_test/providers/gemini/mod.rs b/BitFun-Installer/src-tauri/src/connection_test/providers/gemini/mod.rs new file mode 100644 index 00000000..ee6d89d2 --- /dev/null +++ b/BitFun-Installer/src-tauri/src/connection_test/providers/gemini/mod.rs @@ -0,0 +1,5 @@ +//! Gemini provider module + +pub mod message_converter; + +pub use message_converter::GeminiMessageConverter; diff --git a/BitFun-Installer/src-tauri/src/connection_test/providers/mod.rs b/BitFun-Installer/src-tauri/src/connection_test/providers/mod.rs new file mode 100644 index 00000000..d927ece1 --- /dev/null +++ b/BitFun-Installer/src-tauri/src/connection_test/providers/mod.rs @@ -0,0 +1,7 @@ +//! AI provider module +//! +//! Provides a unified interface for different AI providers + +pub mod anthropic; +pub mod gemini; +pub mod openai; diff --git a/BitFun-Installer/src-tauri/src/connection_test/providers/openai/message_converter.rs b/BitFun-Installer/src-tauri/src/connection_test/providers/openai/message_converter.rs new file mode 100644 index 00000000..cef78db6 --- /dev/null +++ b/BitFun-Installer/src-tauri/src/connection_test/providers/openai/message_converter.rs @@ -0,0 +1,476 @@ +//! 
OpenAI message format converter + +use crate::connection_test::types::{Message, ToolDefinition}; +use log::{error, warn}; +use serde_json::{json, Value}; + +pub struct OpenAIMessageConverter; + +impl OpenAIMessageConverter { + pub fn convert_messages_to_responses_input( + messages: Vec, + ) -> (Option, Vec) { + let mut instructions = Vec::new(); + let mut input = Vec::new(); + + for msg in messages { + match msg.role.as_str() { + "system" => { + if let Some(content) = msg.content.filter(|content| !content.trim().is_empty()) + { + instructions.push(content); + } + } + "tool" => { + if let Some(tool_item) = Self::convert_tool_message_to_responses_item(msg) { + input.push(tool_item); + } + } + "assistant" => { + if let Some(content_items) = Self::convert_message_content_to_responses_items( + &msg.role, + msg.content.as_deref(), + ) { + input.push(json!({ + "type": "message", + "role": "assistant", + "content": content_items, + })); + } + + if let Some(tool_calls) = msg.tool_calls { + for tool_call in tool_calls { + input.push(json!({ + "type": "function_call", + "call_id": tool_call.id, + "name": tool_call.name, + "arguments": serde_json::to_string(&tool_call.arguments) + .unwrap_or_else(|_| "{}".to_string()), + })); + } + } + } + role => { + if let Some(content_items) = Self::convert_message_content_to_responses_items( + role, + msg.content.as_deref(), + ) { + input.push(json!({ + "type": "message", + "role": role, + "content": content_items, + })); + } + } + } + } + + let instructions = if instructions.is_empty() { + None + } else { + Some(instructions.join("\n\n")) + }; + + (instructions, input) + } + + pub fn convert_messages(messages: Vec) -> Vec { + messages + .into_iter() + .map(Self::convert_single_message) + .collect() + } + + fn convert_tool_message_to_responses_item(msg: Message) -> Option { + let call_id = msg.tool_call_id?; + let text = msg.content.unwrap_or_default(); + + // Responses API: `output` may be a string or a list of input_text / input_image / 
input_file + // (see OpenAI FunctionCallOutput schema). + let output: Value = if let Some(attachments) = msg.tool_image_attachments.filter(|a| !a.is_empty()) { + let mut parts: Vec = attachments + .into_iter() + .map(|att| { + let data_url = format!("data:{};base64,{}", att.mime_type, att.data_base64); + json!({ + "type": "input_image", + "image_url": data_url + }) + }) + .collect(); + parts.push(json!({ + "type": "input_text", + "text": if text.is_empty() { + "Tool execution completed".to_string() + } else { + text + } + })); + json!(parts) + } else { + json!(if text.is_empty() { + "Tool execution completed".to_string() + } else { + text + }) + }; + + Some(json!({ + "type": "function_call_output", + "call_id": call_id, + "output": output, + })) + } + + fn convert_message_content_to_responses_items( + role: &str, + content: Option<&str>, + ) -> Option> { + let content = content?; + let text_item_type = Self::responses_text_item_type(role); + + if content.trim().is_empty() { + return Some(vec![json!({ + "type": text_item_type, + "text": " ", + })]); + } + + let parsed = match serde_json::from_str::(content) { + Ok(parsed) if parsed.is_array() => parsed, + _ => { + return Some(vec![json!({ + "type": text_item_type, + "text": content, + })]); + } + }; + + let mut content_items = Vec::new(); + + if let Some(items) = parsed.as_array() { + for item in items { + let item_type = item.get("type").and_then(Value::as_str); + match item_type { + Some("text") | Some("input_text") | Some("output_text") => { + if let Some(text) = item.get("text").and_then(Value::as_str) { + content_items.push(json!({ + "type": text_item_type, + "text": text, + })); + } + } + Some("image_url") if role != "assistant" => { + let image_url = item.get("image_url").and_then(|value| { + value + .get("url") + .and_then(Value::as_str) + .or_else(|| value.as_str()) + }); + + if let Some(image_url) = image_url { + content_items.push(json!({ + "type": "input_image", + "image_url": image_url, + })); + } + } + 
_ => {} + } + } + } + + if content_items.is_empty() { + Some(vec![json!({ + "type": text_item_type, + "text": content, + })]) + } else { + Some(content_items) + } + } + + fn responses_text_item_type(role: &str) -> &'static str { + if role == "assistant" { + "output_text" + } else { + "input_text" + } + } + + fn convert_single_message(msg: Message) -> Value { + // Chat Completions: multimodal tool message (e.g. GPT-4o vision + tools) — image parts + text. + if msg.role == "tool" { + if let Some(ref attachments) = msg.tool_image_attachments { + if !attachments.is_empty() { + let mut parts: Vec = attachments + .iter() + .map(|att| { + let url = format!("data:{};base64,{}", att.mime_type, att.data_base64); + json!({ + "type": "image_url", + "image_url": { "url": url, "detail": "auto" } + }) + }) + .collect(); + let text = msg.content.clone().unwrap_or_default(); + if text.trim().is_empty() { + parts.push(json!({ + "type": "text", + "text": "Tool execution completed" + })); + } else { + parts.push(json!({ "type": "text", "text": text })); + } + let mut openai_msg = json!({ + "role": "tool", + "content": Value::Array(parts), + }); + if let Some(id) = msg.tool_call_id { + openai_msg["tool_call_id"] = Value::String(id); + } + if let Some(name) = msg.name { + openai_msg["name"] = Value::String(name); + } + return openai_msg; + } + } + } + + let mut openai_msg = json!({ + "role": msg.role, + }); + + let has_tool_calls = msg.tool_calls.is_some(); + + if let Some(content) = msg.content { + if content.trim().is_empty() { + if msg.role == "assistant" && has_tool_calls { + // OpenAI requires the content field; use a space for tool-call cases. 
+ openai_msg["content"] = Value::String(" ".to_string()); + } else if msg.role == "tool" { + openai_msg["content"] = Value::String("Tool execution completed".to_string()); + warn!( + "[OpenAI] Tool response content is empty: name={:?}", + msg.name + ); + } else { + openai_msg["content"] = Value::String(" ".to_string()); + warn!("[OpenAI] Message content is empty: role={}", msg.role); + } + } else { + if let Ok(parsed) = serde_json::from_str::(&content) { + if parsed.is_array() { + openai_msg["content"] = parsed; + } else { + openai_msg["content"] = Value::String(content); + } + } else { + openai_msg["content"] = Value::String(content); + } + } + } else { + if msg.role == "assistant" && has_tool_calls { + // OpenAI requires the content field; use a space for tool-call cases. + openai_msg["content"] = Value::String(" ".to_string()); + } else if msg.role == "tool" { + openai_msg["content"] = Value::String("Tool execution completed".to_string()); + + warn!( + "[OpenAI] Tool response message content is empty, set to default: name={:?}", + msg.name + ); + } else { + error!( + "[OpenAI] Message content is empty and violates API spec: role={}, has_tool_calls={}", + msg.role, + has_tool_calls + ); + + openai_msg["content"] = Value::String(" ".to_string()); + } + } + + if let Some(reasoning) = msg.reasoning_content { + if !reasoning.is_empty() { + openai_msg["reasoning_content"] = Value::String(reasoning); + } + } + + if let Some(tool_calls) = msg.tool_calls { + let openai_tool_calls: Vec = tool_calls + .into_iter() + .map(|tc| { + json!({ + "id": tc.id, + "type": "function", + "function": { + "name": tc.name, + "arguments": serde_json::to_string(&tc.arguments) + .unwrap_or_default() + } + }) + }) + .collect(); + openai_msg["tool_calls"] = Value::Array(openai_tool_calls); + } + + if let Some(tool_call_id) = msg.tool_call_id { + openai_msg["tool_call_id"] = Value::String(tool_call_id); + } + + if let Some(name) = msg.name { + openai_msg["name"] = Value::String(name); + } + + 
openai_msg + } + + pub fn convert_tools(tools: Option>) -> Option> { + tools.map(|tool_defs| { + tool_defs + .into_iter() + .map(|tool| { + json!({ + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.parameters + } + }) + }) + .collect() + }) + } +} + +#[cfg(test)] +mod tests { + use super::OpenAIMessageConverter; + use crate::connection_test::types::{Message, ToolCall, ToolImageAttachment}; + use serde_json::json; + use std::collections::HashMap; + + #[test] + fn converts_messages_to_responses_input() { + let mut args = HashMap::new(); + args.insert("city".to_string(), json!("Beijing")); + + let messages = vec![ + Message::system("You are helpful".to_string()), + Message::user("Hello".to_string()), + Message::assistant_with_tools(vec![ToolCall { + id: "call_1".to_string(), + name: "get_weather".to_string(), + arguments: args.clone(), + }]), + Message { + role: "tool".to_string(), + content: Some("Sunny".to_string()), + reasoning_content: None, + thinking_signature: None, + tool_calls: None, + tool_call_id: Some("call_1".to_string()), + name: Some("get_weather".to_string()), + tool_image_attachments: None, + }, + ]; + + let (instructions, input) = + OpenAIMessageConverter::convert_messages_to_responses_input(messages); + + assert_eq!(instructions.as_deref(), Some("You are helpful")); + assert_eq!(input.len(), 3); + assert_eq!(input[0]["type"], json!("message")); + assert_eq!(input[1]["type"], json!("function_call")); + assert_eq!(input[2]["type"], json!("function_call_output")); + } + + #[test] + fn converts_openai_style_image_content_to_responses_input() { + let messages = vec![Message { + role: "user".to_string(), + content: Some( + json!([ + { + "type": "image_url", + "image_url": { + "url": "data:image/png;base64,abc" + } + }, + { + "type": "text", + "text": "Describe this image" + } + ]) + .to_string(), + ), + reasoning_content: None, + thinking_signature: None, + tool_calls: None, + tool_call_id: 
None,
            name: None,
            tool_image_attachments: None,
        }];

        let (_, input) = OpenAIMessageConverter::convert_messages_to_responses_input(messages);
        let content = input[0]["content"].as_array().expect("content array");

        assert_eq!(content[0]["type"], json!("input_image"));
        assert_eq!(content[1]["type"], json!("input_text"));
    }

    #[test]
    fn converts_tool_message_with_images_to_responses_function_call_output() {
        let messages = vec![Message {
            role: "tool".to_string(),
            content: Some("Screen captured".to_string()),
            reasoning_content: None,
            thinking_signature: None,
            tool_calls: None,
            tool_call_id: Some("call_cu_1".to_string()),
            name: Some("computer_use".to_string()),
            tool_image_attachments: Some(vec![ToolImageAttachment {
                mime_type: "image/jpeg".to_string(),
                data_base64: "AAA".to_string(),
            }]),
        }];

        let (_, input) = OpenAIMessageConverter::convert_messages_to_responses_input(messages);
        let out = &input[0];
        assert_eq!(out["type"], json!("function_call_output"));
        assert_eq!(out["call_id"], json!("call_cu_1"));
        let output = out["output"].as_array().expect("multimodal output");
        assert_eq!(output[0]["type"], json!("input_image"));
        assert!(output[0]["image_url"]
            .as_str()
            .unwrap()
            .starts_with("data:image/jpeg;base64,"));
        assert_eq!(output[1]["type"], json!("input_text"));
        assert_eq!(output[1]["text"], json!("Screen captured"));
    }

    #[test]
    fn converts_tool_message_with_images_to_chat_completions_content_parts() {
        let msg = Message {
            role: "tool".to_string(),
            content: Some("ok".to_string()),
            reasoning_content: None,
            thinking_signature: None,
            tool_calls: None,
            tool_call_id: Some("call_1".to_string()),
            name: Some("computer_use".to_string()),
            tool_image_attachments: Some(vec![ToolImageAttachment {
                mime_type: "image/jpeg".to_string(),
                data_base64: "YmFi".to_string(),
            }]),
        };

        let openai = OpenAIMessageConverter::convert_messages(vec![msg]);
        let content = openai[0]["content"].as_array().expect("content parts");
        assert_eq!(content[0]["type"], json!("image_url"));
        assert_eq!(content[1]["type"], json!("text"));
        assert_eq!(content[1]["text"], json!("ok"));
    }
}

diff --git a/BitFun-Installer/src-tauri/src/connection_test/providers/openai/mod.rs b/BitFun-Installer/src-tauri/src/connection_test/providers/openai/mod.rs
new file mode 100644
index 00000000..44ad1060
--- /dev/null
+++ b/BitFun-Installer/src-tauri/src/connection_test/providers/openai/mod.rs
@@ -0,0 +1,5 @@
+//! OpenAI provider module
+
+pub mod message_converter;
+
+pub use message_converter::OpenAIMessageConverter;

diff --git a/BitFun-Installer/src-tauri/src/connection_test/proxy.rs b/BitFun-Installer/src-tauri/src/connection_test/proxy.rs
new file mode 100644
index 00000000..f483994f
--- /dev/null
+++ b/BitFun-Installer/src-tauri/src/connection_test/proxy.rs
@@ -0,0 +1,31 @@
+//! Copied from `bitfun_core::service::config::ProxyConfig` for standalone installer AI client.

use serde::{Deserialize, Serialize};

/// Proxy configuration.
///
/// The derived `Default` is identical to the previous hand-written impl:
/// disabled, empty URL, no credentials.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct ProxyConfig {
    /// Whether the proxy is enabled.
    pub enabled: bool,

    /// Proxy URL (format: http://host:port or socks5://host:port).
    pub url: String,

    /// Proxy username (optional).
    pub username: Option<String>,

    /// Proxy password (optional).
    pub password: Option<String>,
}

diff --git a/BitFun-Installer/src-tauri/src/connection_test/types/ai.rs b/BitFun-Installer/src-tauri/src/connection_test/types/ai.rs
new file mode 100644
index 00000000..c2ae3a71
--- /dev/null
+++ b/BitFun-Installer/src-tauri/src/connection_test/types/ai.rs
@@ -0,0 +1,72 @@
+use serde::{Deserialize, Serialize};
+use serde_json::Value;

/// Gemini API response
// NOTE(review): the `Option<...>` payload types below were reconstructed —
// the diff source had generic arguments stripped. Confirm against the
// bitfun_core originals (`tool_calls` in particular).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeminiResponse {
    pub text: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<Value>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<GeminiUsage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider_metadata: Option<Value>,
}

/// Gemini usage stats
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeminiUsage {
    #[serde(rename = "promptTokenCount")]
    pub prompt_token_count: u32,
    #[serde(rename = "candidatesTokenCount")]
    pub candidates_token_count: u32,
    #[serde(rename = "totalTokenCount")]
    pub total_token_count: u32,
    #[serde(rename = "reasoningTokenCount")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_token_count: Option<u32>,
    #[serde(rename = "cachedContentTokenCount")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_content_token_count: Option<u32>,
}

/// Structured message codes for localized connection test messaging.
+#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub enum ConnectionTestMessageCode { + ToolCallsNotDetected, + ImageInputCheckFailed, +} + +/// AI connection test result +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ConnectionTestResult { + /// Whether the test succeeded + pub success: bool, + /// Response time (ms) + pub response_time_ms: u64, + /// Model response content (if successful) + #[serde(skip_serializing_if = "Option::is_none")] + pub model_response: Option, + /// Structured message code for localized frontend messaging + #[serde(skip_serializing_if = "Option::is_none")] + pub message_code: Option, + /// Raw error or diagnostic details + #[serde(skip_serializing_if = "Option::is_none")] + pub error_details: Option, +} + +/// Remote model info discovered from a provider API. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RemoteModelInfo { + /// Provider model identifier (used as the actual model_name). + pub id: String, + /// Optional human-readable display name returned by the provider. + #[serde(skip_serializing_if = "Option::is_none")] + pub display_name: Option, +} diff --git a/BitFun-Installer/src-tauri/src/connection_test/types/config.rs b/BitFun-Installer/src-tauri/src/connection_test/types/config.rs new file mode 100644 index 00000000..ab45cbcb --- /dev/null +++ b/BitFun-Installer/src-tauri/src/connection_test/types/config.rs @@ -0,0 +1,160 @@ +//! Copied from `bitfun_core::util::types::config` (installer-local; no bitfun_core). 
+ +use serde::{Deserialize, Serialize}; + +fn append_endpoint(base_url: &str, endpoint: &str) -> String { + let base = base_url.trim(); + if base.is_empty() { + return endpoint.to_string(); + } + if base.ends_with(endpoint) { + return base.to_string(); + } + format!("{}/{}", base.trim_end_matches('/'), endpoint) +} + +fn gemini_base_url(url: &str) -> &str { + let mut u = url; + if let Some(pos) = u.find("/v1beta") { + u = &u[..pos]; + } + if let Some(pos) = u.find("/models/") { + u = &u[..pos]; + } + u.trim_end_matches('/') +} + +fn resolve_gemini_request_url(base_url: &str, model_name: &str) -> String { + let trimmed = base_url.trim().trim_end_matches('/'); + if trimmed.is_empty() { + return String::new(); + } + + if let Some(stripped) = trimmed.strip_suffix('#') { + return stripped.trim_end_matches('/').to_string(); + } + + let model = model_name.trim(); + if model.is_empty() { + return trimmed.to_string(); + } + + let base = gemini_base_url(trimmed); + format!( + "{}/v1beta/models/{}:streamGenerateContent?alt=sse", + base, model + ) +} + +/// Same rules as `bitfun_core::util::types::config::resolve_request_url`. 
+pub fn resolve_request_url(base_url: &str, provider: &str, model_name: &str) -> String { + let trimmed = base_url.trim().trim_end_matches('/').to_string(); + if trimmed.is_empty() { + return String::new(); + } + + if let Some(stripped) = trimmed.strip_suffix('#') { + return stripped.trim_end_matches('/').to_string(); + } + + match provider.trim().to_ascii_lowercase().as_str() { + "openai" | "nvidia" | "openrouter" => append_endpoint(&trimmed, "chat/completions"), + "response" | "responses" => append_endpoint(&trimmed, "responses"), + "anthropic" => append_endpoint(&trimmed, "v1/messages"), + "gemini" | "google" => resolve_gemini_request_url(&trimmed, model_name), + _ => trimmed, + } +} + +/// AI client configuration (for AI requests) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AIConfig { + pub name: String, + pub base_url: String, + /// Actual request URL + /// Falls back to base_url when absent + pub request_url: String, + pub api_key: String, + pub model: String, + pub format: String, + pub context_window: u32, + pub max_tokens: Option, + pub temperature: Option, + pub top_p: Option, + pub enable_thinking_process: bool, + pub support_preserved_thinking: bool, + pub inline_think_in_text: bool, + pub custom_headers: Option>, + /// "replace" (default) or "merge" (defaults first, then custom) + pub custom_headers_mode: Option, + pub skip_ssl_verify: bool, + /// Reasoning effort for OpenAI Responses API ("low", "medium", "high", "xhigh") + pub reasoning_effort: Option, + /// Custom JSON overriding default request body fields + pub custom_request_body: Option, +} + +#[cfg(test)] +mod tests { + use super::resolve_request_url; + + #[test] + fn resolves_openai_request_url() { + assert_eq!( + resolve_request_url("https://api.openai.com/v1", "openai", ""), + "https://api.openai.com/v1/chat/completions" + ); + } + + #[test] + fn resolves_responses_request_url() { + assert_eq!( + resolve_request_url("https://api.openai.com/v1", "responses", ""), + 
"https://api.openai.com/v1/responses" + ); + } + + #[test] + fn resolves_response_alias_request_url() { + assert_eq!( + resolve_request_url("https://api.openai.com/v1", "response", ""), + "https://api.openai.com/v1/responses" + ); + } + + #[test] + fn keeps_forced_request_url() { + assert_eq!( + resolve_request_url("https://api.openai.com/v1/responses#", "responses", ""), + "https://api.openai.com/v1/responses" + ); + } + + #[test] + fn resolves_gemini_request_url_with_v1beta() { + assert_eq!( + resolve_request_url( + "https://generativelanguage.googleapis.com/v1beta", + "gemini", + "gemini-2.5-pro" + ), + "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro:streamGenerateContent?alt=sse" + ); + } + + #[test] + fn resolves_nvidia_request_url() { + assert_eq!( + resolve_request_url("https://integrate.api.nvidia.com/v1", "nvidia", ""), + "https://integrate.api.nvidia.com/v1/chat/completions" + ); + } + + #[test] + fn resolves_openrouter_request_url() { + assert_eq!( + resolve_request_url("https://openrouter.ai/api/v1", "openrouter", ""), + "https://openrouter.ai/api/v1/chat/completions" + ); + } +} diff --git a/BitFun-Installer/src-tauri/src/connection_test/types/core.rs b/BitFun-Installer/src-tauri/src/connection_test/types/core.rs new file mode 100644 index 00000000..33081888 --- /dev/null +++ b/BitFun-Installer/src-tauri/src/connection_test/types/core.rs @@ -0,0 +1,8 @@ +use serde::{Deserialize, Serialize}; + +/// Basic error type +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StructuredError { + pub message: String, + pub status: Option, +} diff --git a/BitFun-Installer/src-tauri/src/connection_test/types/message.rs b/BitFun-Installer/src-tauri/src/connection_test/types/message.rs new file mode 100644 index 00000000..390b092d --- /dev/null +++ b/BitFun-Installer/src-tauri/src/connection_test/types/message.rs @@ -0,0 +1,79 @@ +use super::tool::ToolCall; +use super::tool_image_attachment::ToolImageAttachment; +use 
serde::{Deserialize, Serialize};
+
+/// Internal message representation
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Message {
+    pub role: String, // "user", "assistant", "tool", "system"
+    pub content: Option<String>,
+    /// Reasoning content (for interleaved thinking mode)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub reasoning_content: Option<String>,
+    /// Signature for Anthropic extended thinking
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub thinking_signature: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_calls: Option<Vec<ToolCall>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_call_id: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub name: Option<String>,
+    /// Images attached to a tool result (Anthropic multimodal tool_result).
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub tool_image_attachments: Option<Vec<ToolImageAttachment>>,
+}
+
+impl Message {
+    pub fn user(content: String) -> Self {
+        Self {
+            role: "user".to_string(),
+            content: Some(content),
+            reasoning_content: None,
+            thinking_signature: None,
+            tool_calls: None,
+            tool_call_id: None,
+            name: None,
+            tool_image_attachments: None,
+        }
+    }
+
+    pub fn assistant(content: String) -> Self {
+        Self {
+            role: "assistant".to_string(),
+            content: Some(content),
+            reasoning_content: None,
+            thinking_signature: None,
+            tool_calls: None,
+            tool_call_id: None,
+            name: None,
+            tool_image_attachments: None,
+        }
+    }
+
+    pub fn assistant_with_tools(tool_calls: Vec<ToolCall>) -> Self {
+        Self {
+            role: "assistant".to_string(),
+            content: None,
+            reasoning_content: None,
+            thinking_signature: None,
+            tool_calls: Some(tool_calls),
+            tool_call_id: None,
+            name: None,
+            tool_image_attachments: None,
+        }
+    }
+
+    pub fn system(content: String) -> Self {
+        Self {
+            role: "system".to_string(),
+            content: Some(content),
+            reasoning_content: None,
+            thinking_signature: None,
+            tool_calls: None,
+            tool_call_id: None,
+            name: None,
+            tool_image_attachments: None,
+        }
+    }
+}
diff --git a/BitFun-Installer/src-tauri/src/connection_test/types/mod.rs b/BitFun-Installer/src-tauri/src/connection_test/types/mod.rs
new file mode 100644
index 00000000..bb5b8329
--- /dev/null
+++ b/BitFun-Installer/src-tauri/src/connection_test/types/mod.rs
@@ -0,0 +1,13 @@
+pub mod ai;
+pub mod config;
+pub mod core;
+pub mod message;
+pub mod tool;
+pub mod tool_image_attachment;
+
+pub use ai::*;
+pub use config::*;
+pub use message::*;
+pub use tool::*;
+#[cfg(test)]
+pub use tool_image_attachment::ToolImageAttachment;
diff --git a/BitFun-Installer/src-tauri/src/connection_test/types/tool.rs b/BitFun-Installer/src-tauri/src/connection_test/types/tool.rs
new file mode 100644
index 00000000..a5c336b0
--- /dev/null
+++ b/BitFun-Installer/src-tauri/src/connection_test/types/tool.rs
@@ -0,0 +1,46 @@
+use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ToolCall {
+    pub id: String,
+    pub name: String,
+    pub arguments: HashMap<String, serde_json::Value>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ToolDefinition {
+    pub name: String,
+    pub description: String,
+    pub parameters: serde_json::Value,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ToolCallConfirmationDetails {
+    pub request: ToolCallRequestInfo,
+    #[serde(rename = "type")]
+    pub confirmation_type: String, // 'edit' | 'execute' | 'confirm'
+    pub message: Option<String>,
+    pub file_diff: Option<String>,
+    pub file_name: Option<String>,
+    pub original_content: Option<String>,
+    pub new_content: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ToolCallRequestInfo {
+    pub call_id: String,
+    pub name: String,
+    pub args: HashMap<String, serde_json::Value>,
+    pub is_client_initiated: bool,
+    pub prompt_id: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ToolCallResponseInfo {
+    pub call_id: String,
+    pub response_parts: serde_json::Value,
+    pub result_display: Option<String>,
+    pub error: Option<String>,
+    pub error_type: Option<String>,
+}
diff
--git a/BitFun-Installer/src-tauri/src/connection_test/types/tool_image_attachment.rs b/BitFun-Installer/src-tauri/src/connection_test/types/tool_image_attachment.rs new file mode 100644 index 00000000..cbffc849 --- /dev/null +++ b/BitFun-Installer/src-tauri/src/connection_test/types/tool_image_attachment.rs @@ -0,0 +1,9 @@ +//! Image payload attached to tool results (multimodal tool messages). + +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct ToolImageAttachment { + pub mime_type: String, + pub data_base64: String, +} diff --git a/BitFun-Installer/src-tauri/src/installer/ai_config.rs b/BitFun-Installer/src-tauri/src/installer/ai_config.rs new file mode 100644 index 00000000..38c6719b --- /dev/null +++ b/BitFun-Installer/src-tauri/src/installer/ai_config.rs @@ -0,0 +1,74 @@ +//! Map installer `ModelConfig` to copied `AIConfig` (mirrors `bitfun_core` `TryFrom`). + +use crate::connection_test::types::{resolve_request_url, AIConfig}; +use crate::installer::types::ModelConfig; +use log::warn; + +/// Build `AIConfig` for the copied `AIClient`. 
+pub fn ai_config_from_installer_model(m: &ModelConfig) -> Result<AIConfig, String> {
+    let custom_request_body = if let Some(body_str) = &m.custom_request_body {
+        let t = body_str.trim();
+        if t.is_empty() {
+            None
+        } else {
+            match serde_json::from_str::<serde_json::Value>(t) {
+                Ok(value) => Some(value),
+                Err(e) => {
+                    warn!("Failed to parse custom_request_body: {}", e);
+                    None
+                }
+            }
+        }
+    } else {
+        None
+    };
+
+    let format_key = m.format.trim();
+    if format_key.is_empty() {
+        return Err("Model format is required".to_string());
+    }
+
+    let request_url = resolve_request_url(m.base_url.trim(), format_key, m.model_name.trim());
+
+    Ok(AIConfig {
+        name: m
+            .config_name
+            .as_deref()
+            .map(str::trim)
+            .filter(|s| !s.is_empty())
+            .map(|s| s.to_string())
+            .unwrap_or_else(|| format!("{} - {}", m.provider.trim(), m.model_name.trim())),
+        base_url: m.base_url.trim().to_string(),
+        request_url,
+        api_key: m.api_key.trim().to_string(),
+        model: m.model_name.trim().to_string(),
+        format: format_key.to_string(),
+        context_window: 128_128,
+        max_tokens: None,
+        temperature: None,
+        top_p: None,
+        enable_thinking_process: false,
+        support_preserved_thinking: false,
+        inline_think_in_text: false,
+        custom_headers: m.custom_headers.clone(),
+        custom_headers_mode: m.custom_headers_mode.clone(),
+        skip_ssl_verify: m.skip_ssl_verify.unwrap_or(false),
+        reasoning_effort: None,
+        custom_request_body,
+    })
+}
+
+/// Whether to run the image-input check (same rules as desktop `test_ai_config_connection`).
+pub fn supports_image_input(m: &ModelConfig) -> bool { + m.capabilities + .as_ref() + .map(|c| { + c.iter() + .any(|x| x.eq_ignore_ascii_case("image_understanding")) + }) + .unwrap_or(false) + || m.category + .as_deref() + .map(|c| c.eq_ignore_ascii_case("multimodal")) + .unwrap_or(false) +} diff --git a/BitFun-Installer/src-tauri/src/installer/commands.rs b/BitFun-Installer/src-tauri/src/installer/commands.rs index 268c71a1..6dc06bd7 100644 --- a/BitFun-Installer/src-tauri/src/installer/commands.rs +++ b/BitFun-Installer/src-tauri/src/installer/commands.rs @@ -5,13 +5,11 @@ use super::model_list; use super::types::{ ConnectionTestResult, DiskSpaceInfo, InstallOptions, InstallProgress, ModelConfig, RemoteModelInfo, }; -use reqwest::header::{HeaderMap, HeaderName, HeaderValue, ACCEPT, AUTHORIZATION, CONTENT_TYPE}; use serde::{Deserialize, Serialize}; use serde_json::{Map, Value}; use std::fs::File; use std::io::{Cursor, Read}; use std::path::{Path, PathBuf}; -use std::time::Duration; use tauri::{Emitter, Manager, Window}; #[cfg(target_os = "windows")] @@ -646,11 +644,9 @@ pub fn set_model_config(model_config: ModelConfig) -> Result<(), String> { apply_first_launch_model(&model_config) } -/// Validate model configuration connectivity from installer. +/// Validate model configuration connectivity from installer (same stack as desktop `test_ai_config_connection`). 
#[tauri::command] pub async fn test_model_config_connection(model_config: ModelConfig) -> Result { - let started_at = std::time::Instant::now(); - let required_fields = [ ("baseUrl", model_config.base_url.trim()), ("apiKey", model_config.api_key.trim()), @@ -660,29 +656,88 @@ pub async fn test_model_config_connection(model_config: ModelConfig) -> Result Ok(ConnectionTestResult { - success: true, - response_time_ms: elapsed_ms, - model_response, - error_details: None, - }), - Err(error_details) => Ok(ConnectionTestResult { - success: false, - response_time_ms: elapsed_ms, - model_response: None, - error_details: Some(error_details), - }), + let ai_config = super::ai_config::ai_config_from_installer_model(&model_config) + .map_err(|e| e.to_string())?; + let model_name = ai_config.name.clone(); + let supports_image_input = super::ai_config::supports_image_input(&model_config); + + let ai_client = crate::connection_test::AIClient::new(ai_config); + + match ai_client.test_connection().await { + Ok(result) => { + if !result.success { + log::info!( + "Installer AI config connection test: model={}, success={}, response_time={}ms", + model_name, result.success, result.response_time_ms + ); + return Ok(result); + } + + if supports_image_input { + match ai_client.test_image_input_connection().await { + Ok(image_result) => { + let response_time_ms = + result.response_time_ms + image_result.response_time_ms; + + if !image_result.success { + let merged = ConnectionTestResult { + success: false, + response_time_ms, + model_response: image_result.model_response.or(result.model_response), + message_code: image_result.message_code, + error_details: image_result.error_details, + }; + log::info!( + "Installer AI config connection test: model={}, success={}, response_time={}ms", + model_name, merged.success, merged.response_time_ms + ); + return Ok(merged); + } + + let merged = ConnectionTestResult { + success: true, + response_time_ms, + model_response: 
image_result.model_response.or(result.model_response), + message_code: result.message_code, + error_details: result.error_details, + }; + log::info!( + "Installer AI config connection test: model={}, success={}, response_time={}ms", + model_name, merged.success, merged.response_time_ms + ); + return Ok(merged); + } + Err(e) => { + log::error!( + "Installer multimodal image test failed unexpectedly: model={}, error={}", + model_name, e + ); + return Err(format!("Connection test failed: {}", e)); + } + } + } + + log::info!( + "Installer AI config connection test: model={}, success={}, response_time={}ms", + model_name, result.success, result.response_time_ms + ); + Ok(result) + } + Err(e) => { + log::error!( + "Installer AI config connection test failed: model={}, error={}", + model_name, e + ); + Err(format!("Connection test failed: {}", e)) + } } } @@ -700,17 +755,6 @@ pub async fn list_model_config_models(model_config: ModelConfig) -> Result String { - let normalized = model.format.trim().to_ascii_lowercase(); - match normalized.as_str() { - "anthropic" => "anthropic".to_string(), - "gemini" | "google" => "gemini".to_string(), - "responses" | "response" => "responses".to_string(), - _ => "openai".to_string(), - } -} - fn storage_format(model: &ModelConfig) -> String { model.format.trim().to_ascii_lowercase() } @@ -761,39 +805,6 @@ fn gemini_installer_base_url(url: &str) -> &str { u.trim_end_matches('/') } -fn append_endpoint(base_url: &str, endpoint: &str) -> String { - let base = base_url.trim(); - if base.is_empty() { - return endpoint.to_string(); - } - if base.ends_with(endpoint) { - return base.to_string(); - } - format!("{}/{}", base.trim_end_matches('/'), endpoint) -} - -fn resolve_request_url(base_url: &str, format: &str, model_name: &str) -> String { - let trimmed = base_url.trim().trim_end_matches('/').to_string(); - if trimmed.is_empty() { - return String::new(); - } - - if let Some(stripped) = trimmed.strip_suffix('#') { - return 
stripped.trim_end_matches('/').to_string(); - } - - match format { - "anthropic" => append_endpoint(&trimmed, "v1/messages"), - "openai" | "responses" => append_endpoint(&trimmed, "chat/completions"), - "gemini" => { - let base = gemini_installer_base_url(&trimmed); - let encoded = urlencoding::encode(model_name.trim()); - format!("{}/v1beta/models/{}:generateContent", base, encoded) - } - _ => trimmed, - } -} - fn parse_custom_request_body(raw: &Option) -> Result>, String> { let Some(raw_value) = raw else { return Ok(None); @@ -812,173 +823,6 @@ fn parse_custom_request_body(raw: &Option) -> Result, source: &Map) { - for (key, value) in source { - target.insert(key.clone(), value.clone()); - } -} - -fn build_request_headers(model: &ModelConfig, format: &str) -> Result { - let mode = model - .custom_headers_mode - .as_deref() - .unwrap_or("merge") - .trim() - .to_ascii_lowercase(); - if mode != "merge" && mode != "replace" { - return Err("customHeadersMode must be 'merge' or 'replace'".to_string()); - } - - let mut headers = HeaderMap::new(); - if mode != "replace" { - headers.insert(ACCEPT, HeaderValue::from_static("application/json")); - headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); - if format == "anthropic" { - let api_key = HeaderValue::from_str(model.api_key.trim()) - .map_err(|_| "apiKey contains unsupported header characters".to_string())?; - headers.insert(HeaderName::from_static("x-api-key"), api_key); - headers.insert( - HeaderName::from_static("anthropic-version"), - HeaderValue::from_static("2023-06-01"), - ); - } else if format == "gemini" { - let api_key = HeaderValue::from_str(model.api_key.trim()) - .map_err(|_| "apiKey contains unsupported header characters".to_string())?; - headers.insert(HeaderName::from_static("x-goog-api-key"), api_key.clone()); - let bearer = format!("Bearer {}", model.api_key.trim()); - let auth = HeaderValue::from_str(&bearer) - .map_err(|_| "apiKey contains unsupported header 
characters".to_string())?; - headers.insert(AUTHORIZATION, auth); - } else { - let bearer = format!("Bearer {}", model.api_key.trim()); - let auth = HeaderValue::from_str(&bearer) - .map_err(|_| "apiKey contains unsupported header characters".to_string())?; - headers.insert(AUTHORIZATION, auth); - } - } - - if let Some(custom_headers) = &model.custom_headers { - for (key, value) in custom_headers { - let key_trimmed = key.trim(); - if key_trimmed.is_empty() { - continue; - } - let header_name = HeaderName::from_bytes(key_trimmed.as_bytes()) - .map_err(|_| format!("Invalid custom header name: {}", key_trimmed))?; - let header_value = HeaderValue::from_str(value.trim()) - .map_err(|_| format!("Invalid custom header value for '{}'", key_trimmed))?; - headers.insert(header_name, header_value); - } - } - - Ok(headers) -} - -fn truncate_error_text(raw: &str, limit: usize) -> String { - let compact = raw.replace('\n', " ").replace('\r', " ").trim().to_string(); - if compact.chars().count() <= limit { - return compact; - } - compact.chars().take(limit).collect::() + "..." 
-} - -async fn run_model_connection_test(model: &ModelConfig) -> Result, String> { - let format = normalize_api_format(model); - let endpoint = resolve_request_url(&model.base_url, &format, model.model_name.trim()); - let headers = build_request_headers(model, &format)?; - let custom_request_body = parse_custom_request_body(&model.custom_request_body)?; - - let mut payload = Map::new(); - if format == "anthropic" { - payload.insert("model".to_string(), Value::String(model.model_name.trim().to_string())); - payload.insert("max_tokens".to_string(), Value::Number(16_u64.into())); - payload.insert( - "messages".to_string(), - serde_json::json!([{ "role": "user", "content": "hello" }]), - ); - } else if format == "gemini" { - payload.insert( - "contents".to_string(), - serde_json::json!([{ "parts": [{ "text": "hello" }] }]), - ); - let mut gen = Map::new(); - gen.insert("maxOutputTokens".to_string(), Value::Number(32_u64.into())); - payload.insert("generationConfig".to_string(), Value::Object(gen)); - } else { - payload.insert("model".to_string(), Value::String(model.model_name.trim().to_string())); - payload.insert("max_tokens".to_string(), Value::Number(16_u64.into())); - payload.insert("temperature".to_string(), serde_json::json!(0.1)); - payload.insert( - "messages".to_string(), - serde_json::json!([{ "role": "user", "content": "hello" }]), - ); - } - if let Some(extra) = custom_request_body.as_ref() { - merge_json_object(&mut payload, extra); - } - - let client = reqwest::Client::builder() - .timeout(Duration::from_secs(20)) - .danger_accept_invalid_certs(model.skip_ssl_verify.unwrap_or(false)) - .build() - .map_err(|e| format!("Failed to create HTTP client: {}", e))?; - - let response = client - .post(endpoint) - .headers(headers) - .json(&Value::Object(payload)) - .send() - .await - .map_err(|e| format!("Request failed: {}", e))?; - - let status = response.status(); - let response_body = response - .text() - .await - .map_err(|e| format!("Failed to read response 
body: {}", e))?; - if !status.is_success() { - return Err(format!( - "HTTP {}: {}", - status.as_u16(), - truncate_error_text(&response_body, 260) - )); - } - - let parsed_json = serde_json::from_str::(&response_body).unwrap_or(Value::Null); - let model_response = if format == "anthropic" { - parsed_json - .get("content") - .and_then(|v| v.as_array()) - .and_then(|arr| arr.first()) - .and_then(|item| item.get("text")) - .and_then(|v| v.as_str()) - .map(|s| s.to_string()) - } else if format == "gemini" { - parsed_json - .get("candidates") - .and_then(|v| v.as_array()) - .and_then(|arr| arr.first()) - .and_then(|c| c.get("content")) - .and_then(|c| c.get("parts")) - .and_then(|p| p.as_array()) - .and_then(|arr| arr.first()) - .and_then(|part| part.get("text")) - .and_then(|v| v.as_str()) - .map(|s| s.to_string()) - } else { - parsed_json - .get("choices") - .and_then(|v| v.as_array()) - .and_then(|arr| arr.first()) - .and_then(|item| item.get("message")) - .and_then(|msg| msg.get("content")) - .and_then(|v| v.as_str()) - .map(|s| s.to_string()) - }; - - Ok(model_response) -} - fn emit_progress(window: &Window, step: &str, percent: u32, message: &str) { let progress = InstallProgress { step: step.to_string(), diff --git a/BitFun-Installer/src-tauri/src/installer/mod.rs b/BitFun-Installer/src-tauri/src/installer/mod.rs index 137ecebc..604ee848 100644 --- a/BitFun-Installer/src-tauri/src/installer/mod.rs +++ b/BitFun-Installer/src-tauri/src/installer/mod.rs @@ -1,3 +1,4 @@ +pub mod ai_config; pub mod commands; pub mod extract; pub mod model_list; diff --git a/BitFun-Installer/src-tauri/src/installer/types.rs b/BitFun-Installer/src-tauri/src/installer/types.rs index ba9c58a6..bbd3ec5f 100644 --- a/BitFun-Installer/src-tauri/src/installer/types.rs +++ b/BitFun-Installer/src-tauri/src/installer/types.rs @@ -44,6 +44,12 @@ pub struct ModelConfig { pub custom_headers: Option>, #[serde(default)] pub custom_headers_mode: Option, + /// Optional capability ids (e.g. 
`image_understanding`) — aligns with main app when set.
+    #[serde(default)]
+    pub capabilities: Option<Vec<String>>,
+    /// Optional model category (e.g. `multimodal`) — aligns with main app when set.
+    #[serde(default)]
+    pub category: Option<String>,
 }
 
 /// One entry from provider model discovery (installer-local; mirrors main app shape).
@@ -55,16 +61,7 @@ pub struct RemoteModelInfo {
     pub display_name: Option<String>,
 }
 
-#[derive(Debug, Clone, Serialize, Deserialize)]
-#[serde(rename_all = "camelCase")]
-pub struct ConnectionTestResult {
-    pub success: bool,
-    pub response_time_ms: u64,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub model_response: Option<String>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub error_details: Option<String>,
-}
+pub use crate::connection_test::types::ConnectionTestResult;
 
 /// Progress update sent to the frontend
 #[derive(Debug, Clone, Serialize, Deserialize)]
diff --git a/BitFun-Installer/src-tauri/src/lib.rs b/BitFun-Installer/src-tauri/src/lib.rs
index 79d41206..c797076f 100644
--- a/BitFun-Installer/src-tauri/src/lib.rs
+++ b/BitFun-Installer/src-tauri/src/lib.rs
@@ -1,3 +1,4 @@
+mod connection_test;
 mod installer;
 
 use installer::commands;
diff --git a/BitFun-Installer/src/types/installer.ts b/BitFun-Installer/src/types/installer.ts
index e136c83b..0050b443 100644
--- a/BitFun-Installer/src/types/installer.ts
+++ b/BitFun-Installer/src/types/installer.ts
@@ -36,12 +36,20 @@ export interface ModelConfig {
   skipSslVerify?: boolean;
   customHeaders?: Record<string, string>;
   customHeadersMode?: 'merge' | 'replace';
+  /** Aligns with main app model capabilities when testing image input. */
+  capabilities?: string[];
+  /** Aligns with main app model category (e.g. multimodal). */
+  category?: string;
 }
 
+/** Matches backend `ConnectionTestMessageCode` (camelCase JSON).
*/ +export type ConnectionTestMessageCode = 'toolCallsNotDetected' | 'imageInputCheckFailed'; + export interface ConnectionTestResult { success: boolean; responseTimeMs: number; modelResponse?: string; + messageCode?: ConnectionTestMessageCode; errorDetails?: string; }