Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions src/auth/oauth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,32 @@ impl OAuthConfig {
OAuthProviderType::Anthropic
}
}

/// Rewrites a `localhost`/`127.0.0.1`/`[::1]` `redirect_uri` to use `port`.
///
/// Leaves non-loopback redirect URIs untouched (e.g. Anthropic's
/// `console.anthropic.com` callback). Used to keep the OAuth callback
/// `redirect_uri` in sync with the actual port chosen by the local
/// callback server when the configured port was busy.
///
/// # Examples
///
/// ```ignore
/// let cfg = OAuthConfig::openai_codex().with_callback_port(1456);
/// assert!(cfg.redirect_uri.contains(":1456/"));
/// ```
#[must_use]
pub fn with_callback_port(mut self, port: u16) -> Self {
if !is_localhost_url(&self.redirect_uri) {
return self;
}
if let Ok(mut url) = url::Url::parse(&self.redirect_uri) {
// Ignore set_port errors (cannot-be-base URLs); fall through unchanged.
let _ = url.set_port(Some(port));
self.redirect_uri = url.to_string();
}
self
}
}

impl OAuthConfig {
Expand Down Expand Up @@ -619,4 +645,26 @@ mod tests {
assert!(auth_url.url.contains("code_challenge_method=S256"));
assert!(auth_url.url.contains("scope="));
}

#[test]
fn test_with_callback_port_rewrites_openai_codex_localhost() {
let cfg = OAuthConfig::openai_codex().with_callback_port(1456);
assert_eq!(cfg.redirect_uri, "http://localhost:1456/auth/callback");
}

#[test]
fn test_with_callback_port_rewrites_gemini_localhost() {
let cfg = OAuthConfig::gemini().with_callback_port(13_460);
assert_eq!(
cfg.redirect_uri,
"http://localhost:13460/api/oauth/callback"
);
}

#[test]
fn test_with_callback_port_leaves_remote_redirect_alone() {
let original = OAuthConfig::anthropic().redirect_uri.clone();
let cfg = OAuthConfig::anthropic().with_callback_port(9999);
assert_eq!(cfg.redirect_uri, original);
}
}
57 changes: 45 additions & 12 deletions src/server/lifecycle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use super::{oauth_handlers, AppState};
use crate::models::config::AppConfig;
use axum::{routing::get, Router as AxumRouter};
use std::sync::Arc;
use tokio::net::TcpListener;
use tracing::{error, info, warn};

/// Binds the server socket and serves with optional TLS and graceful shutdown.
Expand Down Expand Up @@ -114,32 +113,66 @@ pub(super) async fn drain_in_flight(state: &Arc<AppState>) {
}
}

/// Spawn the OAuth callback server (required for OpenAI Codex OAuth)
/// Spawn the OAuth callback server (required for OpenAI Codex OAuth).
///
/// Tries the configured `oauth_callback_port` first, then falls back to up
/// to [`OAUTH_CALLBACK_BIND_ATTEMPTS`] adjacent ports if the default is busy.
/// The actually-bound port is stored in
/// [`AppState::actual_oauth_callback_port`] so handlers can build `redirect_uri`
/// values that match the live listener.
pub(super) fn spawn_oauth_callback(oauth_state: Arc<AppState>) {
let port = oauth_state.snapshot().config.server.oauth_callback_port;
let configured_port = oauth_state.snapshot().config.server.oauth_callback_port;
tokio::spawn(async move {
let oauth_callback_app = AxumRouter::new()
.route("/auth/callback", get(oauth_handlers::oauth_callback))
.with_state(oauth_state);
.with_state(oauth_state.clone());

let oauth_addr = format!("127.0.0.1:{}", port);
match TcpListener::bind(&oauth_addr).await {
Ok(oauth_listener) => {
info!("OAuth callback server listening on {}", oauth_addr);
// NOTE: Bind to 127.0.0.1 because OAuth providers (OpenAI Codex, Gemini)
// redirect to a literal `localhost`/`127.0.0.1` callback URL.
match crate::shared::net::bind_with_port_retry(
"127.0.0.1",
configured_port,
OAUTH_CALLBACK_BIND_ATTEMPTS,
)
.await
{
Ok((oauth_listener, actual_port)) => {
oauth_state
.actual_oauth_callback_port
.store(actual_port, std::sync::atomic::Ordering::Relaxed);
if actual_port == configured_port {
info!(
"OAuth callback server listening on 127.0.0.1:{}",
actual_port
);
} else {
warn!(
"OAuth callback port {} busy; bound on 127.0.0.1:{} instead",
configured_port, actual_port
);
}
if let Err(e) = axum::serve(oauth_listener, oauth_callback_app).await {
error!("OAuth callback server error: {}", e);
}
}
Err(e) => {
let last_port =
configured_port.saturating_add(OAUTH_CALLBACK_BIND_ATTEMPTS.saturating_sub(1));
error!(
"Failed to bind OAuth callback server on {}: {}",
oauth_addr, e
"Failed to bind OAuth callback server on 127.0.0.1 in port range {}..={}: {:#}",
configured_port, last_port, e
);
error!(
"OpenAI Codex OAuth will not work. Port {} must be available.",
port
"OpenAI Codex / Gemini OAuth will not work. Free a port in {}..={} or set server.oauth_callback_port.",
configured_port, last_port
);
}
}
});
}

/// Maximum number of adjacent ports tried when binding the OAuth callback server.
///
/// The configured port is the first attempt; subsequent attempts increment by 1.
pub(super) const OAUTH_CALLBACK_BIND_ATTEMPTS: u16 =
crate::shared::net::DEFAULT_PORT_RETRY_ATTEMPTS;
9 changes: 9 additions & 0 deletions src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,14 @@ pub struct AppState {
pub active_requests: std::sync::atomic::AtomicU64,
/// Server start time (for health/upgrade coordination)
pub started_at: chrono::DateTime<chrono::Utc>,
/// Actual TCP port the OAuth callback server bound to.
///
/// The configured port (`server.oauth_callback_port`) is used as the base
/// and the listener falls back to adjacent ports if it is busy. Handlers
/// that build OAuth `redirect_uri` values read this to stay in sync with
/// the live callback listener. `0` means the callback server has not yet
/// bound (or failed to bind).
pub actual_oauth_callback_port: std::sync::atomic::AtomicU16,

/// Metrics, tracing, spend tracking
pub observability: ObservabilityState,
Expand Down Expand Up @@ -274,6 +282,7 @@ pub async fn start_server(
config_source,
active_requests: std::sync::atomic::AtomicU64::new(0),
started_at: chrono::Utc::now(),
actual_oauth_callback_port: std::sync::atomic::AtomicU16::new(0),
event_bus,
log_exporter,
#[cfg(feature = "mcp")]
Expand Down
36 changes: 29 additions & 7 deletions src/server/oauth_handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,17 +80,35 @@ pub struct TokenInfo {
pub needs_refresh: bool,
}

/// Returns the actual OAuth callback port the local server bound to.
///
/// Falls back to the configured `oauth_callback_port` when the callback
/// server has not yet recorded its actual port (e.g. before [`spawn_oauth_callback`]
/// completes its first bind attempt).
fn live_oauth_callback_port(state: &AppState) -> u16 {
let actual = state
.actual_oauth_callback_port
.load(std::sync::atomic::Ordering::Relaxed);
if actual != 0 {
actual
} else {
state.snapshot().config.server.oauth_callback_port
}
}

/// Get authorization URL
pub async fn oauth_authorize(
State(state): State<Arc<AppState>>,
Json(req): Json<OAuthAuthorizeRequest>,
) -> Result<Json<OAuthAuthorizeResponse>, (StatusCode, String)> {
let callback_port = live_oauth_callback_port(&state);

// Create OAuth config based on type
let config = match req.oauth_type.as_str() {
"max" => OAuthConfig::anthropic(),
"console" => OAuthConfig::anthropic_console(),
"openai-codex" => OAuthConfig::openai_codex(),
"gemini" => OAuthConfig::gemini(),
"openai-codex" => OAuthConfig::openai_codex().with_callback_port(callback_port),
"gemini" => OAuthConfig::gemini().with_callback_port(callback_port),
_ => {
return Err((
StatusCode::BAD_REQUEST,
Expand Down Expand Up @@ -134,11 +152,13 @@ pub async fn oauth_exchange(
req.oauth_type
);

let callback_port = live_oauth_callback_port(&state);

// Determine OAuth config based on oauth_type if provided, otherwise fall back to provider_id
let config = if let Some(ref oauth_type) = req.oauth_type {
match oauth_type.as_str() {
"openai-codex" => OAuthConfig::openai_codex(),
"gemini" => OAuthConfig::gemini(),
"openai-codex" => OAuthConfig::openai_codex().with_callback_port(callback_port),
"gemini" => OAuthConfig::gemini().with_callback_port(callback_port),
"console" => OAuthConfig::anthropic_console(),
"max" => OAuthConfig::anthropic(),
_ => {
Expand All @@ -152,7 +172,7 @@ pub async fn oauth_exchange(
|| req.provider_id.to_lowercase().contains("codex")
|| req.provider_id.to_lowercase().contains("chatgpt")
{
OAuthConfig::openai_codex()
OAuthConfig::openai_codex().with_callback_port(callback_port)
} else {
OAuthConfig::anthropic()
};
Expand Down Expand Up @@ -265,16 +285,18 @@ pub async fn oauth_refresh_token(
State(state): State<Arc<AppState>>,
Json(req): Json<DeleteTokenRequest>,
) -> Result<Json<OAuthExchangeResponse>, (StatusCode, String)> {
let callback_port = live_oauth_callback_port(&state);

// Determine OAuth config based on provider_id
let config = if req.provider_id.to_lowercase().contains("openai")
|| req.provider_id.to_lowercase().contains("codex")
|| req.provider_id.to_lowercase().contains("chatgpt")
{
OAuthConfig::openai_codex()
OAuthConfig::openai_codex().with_callback_port(callback_port)
} else if req.provider_id.to_lowercase().contains("gemini")
|| req.provider_id.to_lowercase().contains("google")
{
OAuthConfig::gemini()
OAuthConfig::gemini().with_callback_port(callback_port)
} else {
OAuthConfig::anthropic()
};
Expand Down
123 changes: 123 additions & 0 deletions src/shared/net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,57 @@ pub fn bind_reuseport_std(addr: &str) -> Result<std::net::TcpListener> {
TcpListener::bind(addr).with_context(|| format!("Failed to bind to {}", addr))
}

/// Default maximum bind attempts for [`bind_with_port_retry`] (base port plus 9 fallbacks).
pub const DEFAULT_PORT_RETRY_ATTEMPTS: u16 = 10;

/// Binds to `host:base_port`, falling back to adjacent ports up to `max_attempts` total.
///
/// Tries `base_port`, `base_port + 1`, ..., `base_port + max_attempts - 1` in order
/// and returns the first listener that binds successfully along with its actual port.
/// Uses plain `tokio::net::TcpListener::bind` (no `SO_REUSEPORT`) so a busy port is
/// reliably detected and the next candidate is tried. Useful for ephemeral local
/// servers — like the OAuth callback — where a stable port is preferred but a free
/// adjacent port is acceptable when the default is occupied.
///
/// # Errors
///
/// Returns an error if every port in `base_port..base_port + max_attempts` is busy
/// or otherwise unbindable, or if `max_attempts` is zero.
pub async fn bind_with_port_retry(
host: &str,
base_port: u16,
max_attempts: u16,
) -> Result<(tokio::net::TcpListener, u16)> {
if max_attempts == 0 {
anyhow::bail!("bind_with_port_retry: max_attempts must be > 0");
}

let end_port = base_port.saturating_add(max_attempts);
let mut last_err: Option<anyhow::Error> = None;
for offset in 0..max_attempts {
let Some(port) = base_port.checked_add(offset) else {
// NOTE: Stop early if we would overflow u16 (e.g. base_port=65535).
break;
};
let addr = crate::cli::format_bind_addr(host, port);
match tokio::net::TcpListener::bind(&addr).await {
Ok(listener) => {
return Ok((listener, port));
}
Err(e) => {
last_err = Some(anyhow::Error::new(e).context(format!("bind {}", addr)));
}
}
}

Err(last_err
.unwrap_or_else(|| anyhow::anyhow!("bind_with_port_retry: no attempts made for {}", host))
.context(format!(
"Could not bind any port in range {}..{}",
base_port, end_port
)))
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -135,4 +186,76 @@ mod tests {
fn test_invalid_addr() {
assert!(bind_reuseport_std("not_an_addr").is_err());
}

#[tokio::test]
async fn test_bind_with_port_retry_uses_base_port_when_free() {
// Pick an ephemeral port to use as our "base", then immediately drop it
// so the retry helper can claim it on the first try.
let probe = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
let base_port = probe.local_addr().unwrap().port();
drop(probe);

let (_listener, port) = bind_with_port_retry("127.0.0.1", base_port, 5)
.await
.expect("bind should succeed when base port is free");
assert_eq!(port, base_port);
}

#[tokio::test]
async fn test_bind_with_port_retry_falls_back_to_next_port() {
// Hold the base port so the retry helper has to advance to base+1.
let blocker = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
let base_port = blocker.local_addr().unwrap().port();
// Keep the blocker alive to ensure base_port stays busy for the test.

let (_listener, port) = bind_with_port_retry("127.0.0.1", base_port, 5)
.await
.expect("bind should succeed on a fallback port");
assert!(
port > base_port && port < base_port.saturating_add(5),
"expected fallback port in range, got {port} (base {base_port})"
);
drop(blocker);
}

#[tokio::test]
async fn test_bind_with_port_retry_zero_attempts_errors() {
let err = bind_with_port_retry("127.0.0.1", 50_000, 0)
.await
.unwrap_err();
assert!(
err.to_string().contains("max_attempts must be > 0"),
"unexpected error: {err}"
);
}

#[tokio::test]
async fn test_bind_with_port_retry_reports_range_when_all_busy() {
// Hold a contiguous block of ports so every retry attempt must fail.
let blockers: Vec<std::net::TcpListener> = (0..3)
.map(|_| std::net::TcpListener::bind("127.0.0.1:0").unwrap())
.collect();

// Find a contiguous run of busy ports we already hold.
let ports: Vec<u16> = blockers
.iter()
.map(|l| l.local_addr().unwrap().port())
.collect();
let mut sorted_ports = ports.clone();
sorted_ports.sort_unstable();

// NOTE: We can't guarantee the OS gave us contiguous ports, so we
// verify the error path by trying to bind a single busy port with
// attempts=1, which deterministically fails with the range message.
let busy_port = sorted_ports[0];
let err = bind_with_port_retry("127.0.0.1", busy_port, 1)
.await
.unwrap_err();
let msg = format!("{err:#}");
assert!(
msg.contains("Could not bind any port in range"),
"unexpected error: {msg}"
);
drop(blockers);
}
}
Loading