diff --git a/src/auth/oauth.rs b/src/auth/oauth.rs index 94f3a18d..5cda3f02 100644 --- a/src/auth/oauth.rs +++ b/src/auth/oauth.rs @@ -119,6 +119,32 @@ impl OAuthConfig { OAuthProviderType::Anthropic } } + + /// Rewrites a `localhost`/`127.0.0.1`/`[::1]` `redirect_uri` to use `port`. + /// + /// Leaves non-loopback redirect URIs untouched (e.g. Anthropic's + /// `console.anthropic.com` callback). Used to keep the OAuth callback + /// `redirect_uri` in sync with the actual port chosen by the local + /// callback server when the configured port was busy. + /// + /// # Examples + /// + /// ```ignore + /// let cfg = OAuthConfig::openai_codex().with_callback_port(1456); + /// assert!(cfg.redirect_uri.contains(":1456/")); + /// ``` + #[must_use] + pub fn with_callback_port(mut self, port: u16) -> Self { + if !is_localhost_url(&self.redirect_uri) { + return self; + } + if let Ok(mut url) = url::Url::parse(&self.redirect_uri) { + // Ignore set_port errors (cannot-be-base URLs); fall through unchanged. + let _ = url.set_port(Some(port)); + self.redirect_uri = url.to_string(); + } + self + } } impl OAuthConfig { @@ -619,4 +645,26 @@ mod tests { assert!(auth_url.url.contains("code_challenge_method=S256")); assert!(auth_url.url.contains("scope=")); } + + #[test] + fn test_with_callback_port_rewrites_openai_codex_localhost() { + let cfg = OAuthConfig::openai_codex().with_callback_port(1456); + assert_eq!(cfg.redirect_uri, "http://localhost:1456/auth/callback"); + } + + #[test] + fn test_with_callback_port_rewrites_gemini_localhost() { + let cfg = OAuthConfig::gemini().with_callback_port(13_460); + assert_eq!( + cfg.redirect_uri, + "http://localhost:13460/api/oauth/callback" + ); + } + + #[test] + fn test_with_callback_port_leaves_remote_redirect_alone() { + let original = OAuthConfig::anthropic().redirect_uri.clone(); + let cfg = OAuthConfig::anthropic().with_callback_port(9999); + assert_eq!(cfg.redirect_uri, original); + } } diff --git a/src/server/lifecycle.rs b/src/server/lifecycle.rs index 1d848117..7ddb6c89 100644 --- a/src/server/lifecycle.rs +++ b/src/server/lifecycle.rs @@ -8,7 +8,6 @@ use super::{oauth_handlers, AppState}; use crate::models::config::AppConfig; use axum::{routing::get, Router as AxumRouter}; use std::sync::Arc; -use tokio::net::TcpListener; use tracing::{error, info, warn}; /// Binds the server socket and serves with optional TLS and graceful shutdown. @@ -114,32 +113,66 @@ pub(super) async fn drain_in_flight(state: &Arc) { } } -/// Spawn the OAuth callback server (required for OpenAI Codex OAuth) +/// Spawn the OAuth callback server (required for OpenAI Codex OAuth). +/// +/// Tries the configured `oauth_callback_port` first, then falls back to up +/// to [`OAUTH_CALLBACK_BIND_ATTEMPTS`] adjacent ports if the default is busy. +/// The actually-bound port is stored in +/// [`AppState::actual_oauth_callback_port`] so handlers can build `redirect_uri` +/// values that match the live listener. pub(super) fn spawn_oauth_callback(oauth_state: Arc) { - let port = oauth_state.snapshot().config.server.oauth_callback_port; + let configured_port = oauth_state.snapshot().config.server.oauth_callback_port; tokio::spawn(async move { let oauth_callback_app = AxumRouter::new() .route("/auth/callback", get(oauth_handlers::oauth_callback)) - .with_state(oauth_state); + .with_state(oauth_state.clone()); - let oauth_addr = format!("127.0.0.1:{}", port); - match TcpListener::bind(&oauth_addr).await { - Ok(oauth_listener) => { - info!("OAuth callback server listening on {}", oauth_addr); + // NOTE: Bind to 127.0.0.1 because OAuth providers (OpenAI Codex, Gemini) + // redirect to a literal `localhost`/`127.0.0.1` callback URL. + match crate::shared::net::bind_with_port_retry( + "127.0.0.1", + configured_port, + OAUTH_CALLBACK_BIND_ATTEMPTS, + ) + .await + { + Ok((oauth_listener, actual_port)) => { + oauth_state + .actual_oauth_callback_port + .store(actual_port, std::sync::atomic::Ordering::Relaxed); + if actual_port == configured_port { + info!( + "OAuth callback server listening on 127.0.0.1:{}", + actual_port + ); + } else { + warn!( + "OAuth callback port {} busy; bound on 127.0.0.1:{} instead", + configured_port, actual_port + ); + } if let Err(e) = axum::serve(oauth_listener, oauth_callback_app).await { error!("OAuth callback server error: {}", e); } } Err(e) => { + let last_port = + configured_port.saturating_add(OAUTH_CALLBACK_BIND_ATTEMPTS.saturating_sub(1)); error!( - "Failed to bind OAuth callback server on {}: {}", - oauth_addr, e + "Failed to bind OAuth callback server on 127.0.0.1 in port range {}..={}: {:#}", + configured_port, last_port, e ); error!( - "OpenAI Codex OAuth will not work. Port {} must be available.", - port + "OpenAI Codex / Gemini OAuth will not work. Free a port in {}..={} or set server.oauth_callback_port.", + configured_port, last_port ); } } }); } + +/// Maximum number of adjacent ports tried when binding the OAuth callback server. +/// +/// The configured port is the first attempt; subsequent attempts increment by 1. +pub(super) const OAUTH_CALLBACK_BIND_ATTEMPTS: u16 = + crate::shared::net::DEFAULT_PORT_RETRY_ATTEMPTS; diff --git a/src/server/mod.rs b/src/server/mod.rs index e815fa3d..3be15db4 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -172,6 +172,14 @@ pub struct AppState { pub active_requests: std::sync::atomic::AtomicU64, /// Server start time (for health/upgrade coordination) pub started_at: chrono::DateTime, + /// Actual TCP port the OAuth callback server bound to. + /// + /// The configured port (`server.oauth_callback_port`) is used as the base + /// and the listener falls back to adjacent ports if it is busy. Handlers + /// that build OAuth `redirect_uri` values read this to stay in sync with + /// the live callback listener. `0` means the callback server has not yet + /// bound (or failed to bind). + pub actual_oauth_callback_port: std::sync::atomic::AtomicU16, /// Metrics, tracing, spend tracking pub observability: ObservabilityState, @@ -274,6 +282,7 @@ pub async fn start_server( config_source, active_requests: std::sync::atomic::AtomicU64::new(0), started_at: chrono::Utc::now(), + actual_oauth_callback_port: std::sync::atomic::AtomicU16::new(0), event_bus, log_exporter, #[cfg(feature = "mcp")] diff --git a/src/server/oauth_handlers.rs b/src/server/oauth_handlers.rs index 4aeda363..f614fb21 100644 --- a/src/server/oauth_handlers.rs +++ b/src/server/oauth_handlers.rs @@ -80,17 +80,35 @@ pub struct TokenInfo { pub needs_refresh: bool, } +/// Returns the actual OAuth callback port the local server bound to. +/// +/// Falls back to the configured `oauth_callback_port` when the callback +/// server has not yet recorded its actual port (e.g. before [`spawn_oauth_callback`] +/// completes its first bind attempt). +fn live_oauth_callback_port(state: &AppState) -> u16 { + let actual = state + .actual_oauth_callback_port + .load(std::sync::atomic::Ordering::Relaxed); + if actual != 0 { + actual + } else { + state.snapshot().config.server.oauth_callback_port + } +} + /// Get authorization URL pub async fn oauth_authorize( State(state): State>, Json(req): Json, ) -> Result, (StatusCode, String)> { + let callback_port = live_oauth_callback_port(&state); + // Create OAuth config based on type let config = match req.oauth_type.as_str() { "max" => OAuthConfig::anthropic(), "console" => OAuthConfig::anthropic_console(), - "openai-codex" => OAuthConfig::openai_codex(), - "gemini" => OAuthConfig::gemini(), + "openai-codex" => OAuthConfig::openai_codex().with_callback_port(callback_port), + "gemini" => OAuthConfig::gemini().with_callback_port(callback_port), _ => { return Err(( StatusCode::BAD_REQUEST, @@ -134,11 +152,13 @@ pub async fn oauth_exchange( req.oauth_type ); + let callback_port = live_oauth_callback_port(&state); + // Determine OAuth config based on oauth_type if provided, otherwise fall back to provider_id let config = if let Some(ref oauth_type) = req.oauth_type { match oauth_type.as_str() { - "openai-codex" => OAuthConfig::openai_codex(), - "gemini" => OAuthConfig::gemini(), + "openai-codex" => OAuthConfig::openai_codex().with_callback_port(callback_port), + "gemini" => OAuthConfig::gemini().with_callback_port(callback_port), "console" => OAuthConfig::anthropic_console(), "max" => OAuthConfig::anthropic(), _ => { @@ -152,7 +172,7 @@ pub async fn oauth_exchange( || req.provider_id.to_lowercase().contains("codex") || req.provider_id.to_lowercase().contains("chatgpt") { - OAuthConfig::openai_codex() + OAuthConfig::openai_codex().with_callback_port(callback_port) } else { OAuthConfig::anthropic() }; @@ -265,16 +285,18 @@ pub async fn oauth_refresh_token( State(state): State>, Json(req): Json, ) -> Result, (StatusCode, String)> { + let callback_port = live_oauth_callback_port(&state); + // Determine OAuth config based on provider_id let config = if req.provider_id.to_lowercase().contains("openai") || req.provider_id.to_lowercase().contains("codex") || req.provider_id.to_lowercase().contains("chatgpt") { - OAuthConfig::openai_codex() + OAuthConfig::openai_codex().with_callback_port(callback_port) } else if req.provider_id.to_lowercase().contains("gemini") || req.provider_id.to_lowercase().contains("google") { - OAuthConfig::gemini() + OAuthConfig::gemini().with_callback_port(callback_port) } else { OAuthConfig::anthropic() }; diff --git a/src/shared/net.rs b/src/shared/net.rs index 030e2ba9..51b8cf6e 100644 --- a/src/shared/net.rs +++ b/src/shared/net.rs @@ -94,6 +94,57 @@ pub fn bind_reuseport_std(addr: &str) -> Result { TcpListener::bind(addr).with_context(|| format!("Failed to bind to {}", addr)) } +/// Default maximum bind attempts for [`bind_with_port_retry`] (base port plus 9 fallbacks). +pub const DEFAULT_PORT_RETRY_ATTEMPTS: u16 = 10; + +/// Binds to `host:base_port`, falling back to adjacent ports up to `max_attempts` total. +/// +/// Tries `base_port`, `base_port + 1`, ..., `base_port + max_attempts - 1` in order +/// and returns the first listener that binds successfully along with its actual port. +/// Uses plain `tokio::net::TcpListener::bind` (no `SO_REUSEPORT`) so a busy port is +/// reliably detected and the next candidate is tried. Useful for ephemeral local +/// servers — like the OAuth callback — where a stable port is preferred but a free +/// adjacent port is acceptable when the default is occupied. +/// +/// # Errors +/// +/// Returns an error if every port in `base_port..base_port + max_attempts` is busy +/// or otherwise unbindable, or if `max_attempts` is zero. +pub async fn bind_with_port_retry( + host: &str, + base_port: u16, + max_attempts: u16, +) -> Result<(tokio::net::TcpListener, u16)> { + if max_attempts == 0 { + anyhow::bail!("bind_with_port_retry: max_attempts must be > 0"); + } + + let end_port = base_port.saturating_add(max_attempts); + let mut last_err: Option = None; + for offset in 0..max_attempts { + let Some(port) = base_port.checked_add(offset) else { + // NOTE: Stop early if we would overflow u16 (e.g. base_port=65535). + break; + }; + let addr = crate::cli::format_bind_addr(host, port); + match tokio::net::TcpListener::bind(&addr).await { + Ok(listener) => { + return Ok((listener, port)); + } + Err(e) => { + last_err = Some(anyhow::Error::new(e).context(format!("bind {}", addr))); + } + } + } + + Err(last_err + .unwrap_or_else(|| anyhow::anyhow!("bind_with_port_retry: no attempts made for {}", host)) + .context(format!( + "Could not bind any port in range {}..{}", + base_port, end_port + ))) +} + #[cfg(test)] mod tests { use super::*; @@ -135,4 +186,76 @@ mod tests { fn test_invalid_addr() { assert!(bind_reuseport_std("not_an_addr").is_err()); } + + #[tokio::test] + async fn test_bind_with_port_retry_uses_base_port_when_free() { + // Pick an ephemeral port to use as our "base", then immediately drop it + // so the retry helper can claim it on the first try. + let probe = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + let base_port = probe.local_addr().unwrap().port(); + drop(probe); + + let (_listener, port) = bind_with_port_retry("127.0.0.1", base_port, 5) + .await + .expect("bind should succeed when base port is free"); + assert_eq!(port, base_port); + } + + #[tokio::test] + async fn test_bind_with_port_retry_falls_back_to_next_port() { + // Hold the base port so the retry helper has to advance to base+1. + let blocker = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + let base_port = blocker.local_addr().unwrap().port(); + // Keep the blocker alive to ensure base_port stays busy for the test. + + let (_listener, port) = bind_with_port_retry("127.0.0.1", base_port, 5) + .await + .expect("bind should succeed on a fallback port"); + assert!( + port > base_port && port < base_port.saturating_add(5), + "expected fallback port in range, got {port} (base {base_port})" + ); + drop(blocker); + } + + #[tokio::test] + async fn test_bind_with_port_retry_zero_attempts_errors() { + let err = bind_with_port_retry("127.0.0.1", 50_000, 0) + .await + .unwrap_err(); + assert!( + err.to_string().contains("max_attempts must be > 0"), + "unexpected error: {err}" + ); + } + + #[tokio::test] + async fn test_bind_with_port_retry_reports_range_when_all_busy() { + // Hold a contiguous block of ports so every retry attempt must fail. + let blockers: Vec = (0..3) + .map(|_| std::net::TcpListener::bind("127.0.0.1:0").unwrap()) + .collect(); + + // Find a contiguous run of busy ports we already hold. + let ports: Vec = blockers + .iter() + .map(|l| l.local_addr().unwrap().port()) + .collect(); + let mut sorted_ports = ports.clone(); + sorted_ports.sort_unstable(); + + // NOTE: We can't guarantee the OS gave us contiguous ports, so we + // verify the error path by trying to bind a single busy port with + // attempts=1, which deterministically fails with the range message. + let busy_port = sorted_ports[0]; + let err = bind_with_port_retry("127.0.0.1", busy_port, 1) + .await + .unwrap_err(); + let msg = format!("{err:#}"); + assert!( + msg.contains("Could not bind any port in range"), + "unexpected error: {msg}" + ); + drop(blockers); + } }