diff --git a/backend/src/handlers/user_tokens.rs b/backend/src/handlers/user_tokens.rs index daaaf1a7..8bde73eb 100644 --- a/backend/src/handlers/user_tokens.rs +++ b/backend/src/handlers/user_tokens.rs @@ -382,6 +382,31 @@ async fn generic_oauth_callback_impl( // Handle OAuth provider errors if let Some(ref error) = query.error { let msg = query.error_description.as_deref().unwrap_or(error.as_str()); + + let mut revoked_placeholders = 0_u64; + let mut state_lookup_error: Option = None; + if let Some(state_param) = query.state.as_deref().filter(|s| !s.is_empty()) { + match user_token_service::peek_oauth_state(&state.db, state_param).await { + Ok(oauth_state) => { + let owner_id = oauth_state + .target_user_id + .as_deref() + .unwrap_or(&oauth_state.user_id); + match user_api_key_service::revoke_pending_placeholders_for_provider( + &state.db, + owner_id, + &oauth_state.provider_config_id, + ) + .await + { + Ok(count) => revoked_placeholders = count, + Err(e) => state_lookup_error = Some(e.to_string()), + } + } + Err(e) => state_lookup_error = Some(e.to_string()), + } + } + audit_service::log_async( state.db.clone(), auth_user.as_ref().map(|u| u.user_id.to_string()), @@ -389,6 +414,8 @@ async fn generic_oauth_callback_impl( Some(serde_json::json!({ "error": error, "error_description": &query.error_description, + "revoked_placeholders": revoked_placeholders, + "state_lookup_error": state_lookup_error, })), None, None, @@ -913,15 +940,19 @@ fn ensure_callback_user_matches_state( #[cfg(test)] mod tests { use super::*; + use crate::models::oauth_state::{COLLECTION_NAME as OAUTH_STATES, OAuthState}; use crate::models::org_membership::COLLECTION_NAME as ORG_MEMBERSHIPS; use crate::models::provider_config::{COLLECTION_NAME as PROVIDER_CONFIGS, ProviderConfig}; use crate::models::user::{COLLECTION_NAME as USERS, UserType}; + use crate::models::user_api_key::{COLLECTION_NAME as USER_API_KEYS, UserApiKey}; use crate::models::user_provider_token::{ COLLECTION_NAME as USER_PROVIDER_TOKENS, UserProviderToken, }; use crate::mw::auth::AuthMethod; use crate::test_utils::{connect_test_database, test_app_state, test_membership, test_user}; - use chrono::Utc; + use axum::http::header::LOCATION; + use axum::response::IntoResponse; + use chrono::{Duration, Utc}; use uuid::Uuid; fn test_auth_user() -> AuthUser { @@ -1004,6 +1035,69 @@ mod tests { } } + fn test_oauth_state(state_id: &str, user_id: &str, provider_id: &str) -> OAuthState { + let now = Utc::now(); + OAuthState { + id: state_id.to_string(), + user_id: user_id.to_string(), + provider_config_id: provider_id.to_string(), + code_verifier: None, + device_code_encrypted: None, + user_code_encrypted: None, + poll_interval: None, + target_user_id: None, + credential_user_id: None, + redirect_path: None, + expires_at: now + Duration::minutes(10), + created_at: now, + } + } + + fn test_pending_oauth_api_key(key_id: &str, user_id: &str, provider_id: &str) -> UserApiKey { + let now = Utc::now(); + UserApiKey { + id: key_id.to_string(), + user_id: user_id.to_string(), + label: "GitHub OAuth".to_string(), + credential_type: "oauth2".to_string(), + credential_encrypted: None, + access_token_encrypted: None, + refresh_token_encrypted: None, + token_scopes: None, + expires_at: None, + provider_config_id: Some(provider_id.to_string()), + user_oauth_client_id_encrypted: None, + user_oauth_client_secret_encrypted: None, + status: "pending_auth".to_string(), + last_used_at: None, + error_message: None, + source: Some("user_created".to_string()), + source_id: None, + created_at: now, + updated_at: now, + } + } + + async fn get_api_key(db: &mongodb::Database, key_id: &str) -> UserApiKey { + db.collection::(USER_API_KEYS) + .find_one(mongodb::bson::doc! { "_id": key_id }) + .await + .unwrap() + .unwrap() + } + + fn redirect_location(redirect: axum::response::Redirect) -> String { + let response = redirect.into_response(); + assert!(response.status().is_redirection()); + response + .headers() + .get(LOCATION) + .expect("redirect location") + .to_str() + .expect("valid redirect location") + .to_string() + } + #[test] fn callback_state_allows_missing_session_cookie() { assert!(ensure_callback_user_matches_state(None, "user-123").is_ok()); @@ -1135,4 +1229,119 @@ mod tests { assert_eq!(response.tokens[0].provider_id, provider_id); assert_eq!(response.tokens[0].provider_name, "GitHub"); } + + #[tokio::test] + async fn generic_oauth_callback_denial_revokes_placeholder() { + let Some(db) = connect_test_database("oauth_callback_denial_revokes_placeholder").await + else { + eprintln!( + "skipping provider token handler integration test: no local MongoDB available" + ); + return; + }; + let state = test_app_state(db.clone()); + let user_id = Uuid::new_v4().to_string(); + let provider_id = Uuid::new_v4().to_string(); + let state_id = Uuid::new_v4().to_string(); + let key_id = Uuid::new_v4().to_string(); + + db.collection::(OAUTH_STATES) + .insert_one(test_oauth_state(&state_id, &user_id, &provider_id)) + .await + .unwrap(); + db.collection::(USER_API_KEYS) + .insert_one(test_pending_oauth_api_key(&key_id, &user_id, &provider_id)) + .await + .unwrap(); + + let redirect = generic_oauth_callback_impl( + state, + None, + GenericOAuthCallbackQuery { + code: None, + state: Some(state_id), + error: Some("access_denied".to_string()), + error_description: None, + }, + ) + .await; + + let location = redirect_location(redirect); + assert!(location.contains("/providers/callback")); + assert!(location.contains("status=error")); + assert!(location.contains("message=access_denied")); + assert_eq!(get_api_key(&db, &key_id).await.status, "revoked"); + } + + #[tokio::test] + async fn generic_oauth_callback_denial_without_state_redirects_only() { + let Some(db) = connect_test_database("oauth_callback_denial_without_state").await else { + eprintln!( + "skipping provider token handler integration test: no local MongoDB available" + ); + return; + }; + let state = test_app_state(db.clone()); + let user_id = Uuid::new_v4().to_string(); + let provider_id = Uuid::new_v4().to_string(); + let key_id = Uuid::new_v4().to_string(); + + db.collection::(USER_API_KEYS) + .insert_one(test_pending_oauth_api_key(&key_id, &user_id, &provider_id)) + .await + .unwrap(); + + let redirect = generic_oauth_callback_impl( + state, + None, + GenericOAuthCallbackQuery { + code: None, + state: None, + error: Some("access_denied".to_string()), + error_description: None, + }, + ) + .await; + + let location = redirect_location(redirect); + assert!(location.contains("status=error")); + assert!(location.contains("message=access_denied")); + assert_eq!(get_api_key(&db, &key_id).await.status, "pending_auth"); + } + + #[tokio::test] + async fn generic_oauth_callback_denial_with_invalid_state_redirects_only() { + let Some(db) = connect_test_database("oauth_callback_denial_invalid_state").await else { + eprintln!( + "skipping provider token handler integration test: no local MongoDB available" + ); + return; + }; + let state = test_app_state(db.clone()); + let user_id = Uuid::new_v4().to_string(); + let provider_id = Uuid::new_v4().to_string(); + let key_id = Uuid::new_v4().to_string(); + + db.collection::(USER_API_KEYS) + .insert_one(test_pending_oauth_api_key(&key_id, &user_id, &provider_id)) + .await + .unwrap(); + + let redirect = generic_oauth_callback_impl( + state, + None, + GenericOAuthCallbackQuery { + code: None, + state: Some("bogus-state".to_string()), + error: Some("access_denied".to_string()), + error_description: None, + }, + ) + .await; + + let location = redirect_location(redirect); + assert!(location.contains("status=error")); + assert!(location.contains("message=access_denied")); + assert_eq!(get_api_key(&db, &key_id).await.status, "pending_auth"); + } } diff --git a/backend/src/services/user_api_key_service.rs b/backend/src/services/user_api_key_service.rs index e5717036..bb3c0066 100644 --- a/backend/src/services/user_api_key_service.rs +++ b/backend/src/services/user_api_key_service.rs @@ -302,6 +302,39 @@ pub async fn sync_provider_token_to_api_keys( Ok(()) } +/// Revoke any placeholder UserApiKey rows tied to a denied or failed OAuth +/// flow so the wizard's polling can exit immediately instead of waiting for +/// the 5-minute deadline. +/// +/// Each match is revoked through `revoke_api_key_if_pending` so a callback or +/// activation race cannot flip a non-pending key back to revoked. +pub async fn revoke_pending_placeholders_for_provider( + db: &mongodb::Database, + user_id: &str, + provider_config_id: &str, +) -> AppResult { + let placeholders: Vec = db + .collection::(COLLECTION_NAME) + .find(doc! { + "user_id": user_id, + "provider_config_id": provider_config_id, + "status": "pending_auth", + "credential_type": { "$ne": "node_managed" }, + }) + .await? + .try_collect() + .await?; + + let mut revoked = 0_u64; + for key in placeholders { + if revoke_api_key_if_pending(db, user_id, &key.id).await? { + revoked += 1; + } + } + + Ok(revoked) +} + pub async fn activate_node_managed_api_key( db: &mongodb::Database, user_id: &str, @@ -698,7 +731,10 @@ mod tests { use chrono::Utc; use mongodb::bson::doc; - use super::{USER_PROVIDER_TOKENS, has_server_credential, sync_provider_token_to_api_keys}; + use super::{ + USER_PROVIDER_TOKENS, has_server_credential, revoke_pending_placeholders_for_provider, + sync_provider_token_to_api_keys, + }; use crate::models::user_api_key::UserApiKey; use crate::models::user_provider_token::UserProviderToken; use crate::test_utils::connect_test_database; @@ -727,6 +763,29 @@ mod tests { } } + fn provider_key( + key_id: &str, + user_id: &str, + provider_id: &str, + status: &str, + credential_type: &str, + ) -> UserApiKey { + let mut key = sample_key(credential_type); + key.id = key_id.to_string(); + key.user_id = user_id.to_string(); + key.provider_config_id = Some(provider_id.to_string()); + key.status = status.to_string(); + key + } + + async fn get_key(db: &mongodb::Database, key_id: &str) -> UserApiKey { + db.collection::(super::COLLECTION_NAME) + .find_one(doc! { "_id": key_id }) + .await + .unwrap() + .unwrap() + } + #[test] fn detects_server_credential_for_oauth_keys() { let mut key = sample_key("oauth2"); @@ -837,4 +896,167 @@ mod tests { Some(vec![4, 5, 6]) ); } + + #[tokio::test] + async fn revoke_pending_placeholders_for_provider_revokes_pending_match() { + let Some(db) = connect_test_database("user_api_key_revoke_provider_pending_matches").await + else { + eprintln!("skipping user_api_key_service integration test: no local MongoDB available"); + return; + }; + + let user_id = uuid::Uuid::new_v4().to_string(); + let provider_id = uuid::Uuid::new_v4().to_string(); + let key_1 = uuid::Uuid::new_v4().to_string(); + let key_2 = uuid::Uuid::new_v4().to_string(); + + db.collection::(super::COLLECTION_NAME) + .insert_many(vec![ + provider_key(&key_1, &user_id, &provider_id, "pending_auth", "oauth2"), + provider_key(&key_2, &user_id, &provider_id, "pending_auth", "oauth2"), + ]) + .await + .unwrap(); + + let revoked = revoke_pending_placeholders_for_provider(&db, &user_id, &provider_id) + .await + .unwrap(); + + assert_eq!(revoked, 2); + assert_eq!(get_key(&db, &key_1).await.status, "revoked"); + assert_eq!(get_key(&db, &key_2).await.status, "revoked"); + } + + #[tokio::test] + async fn revoke_pending_placeholders_for_provider_skips_active() { + let Some(db) = connect_test_database("user_api_key_revoke_provider_skips_active").await + else { + eprintln!("skipping user_api_key_service integration test: no local MongoDB available"); + return; + }; + + let user_id = uuid::Uuid::new_v4().to_string(); + let provider_id = uuid::Uuid::new_v4().to_string(); + let active_key = uuid::Uuid::new_v4().to_string(); + let pending_key = uuid::Uuid::new_v4().to_string(); + + db.collection::(super::COLLECTION_NAME) + .insert_many(vec![ + provider_key(&active_key, &user_id, &provider_id, "active", "oauth2"), + provider_key( + &pending_key, + &user_id, + &provider_id, + "pending_auth", + "oauth2", + ), + ]) + .await + .unwrap(); + + let revoked = revoke_pending_placeholders_for_provider(&db, &user_id, &provider_id) + .await + .unwrap(); + + assert_eq!(revoked, 1); + assert_eq!(get_key(&db, &active_key).await.status, "active"); + assert_eq!(get_key(&db, &pending_key).await.status, "revoked"); + } + + #[tokio::test] + async fn revoke_pending_placeholders_for_provider_skips_node_managed() { + let Some(db) = + connect_test_database("user_api_key_revoke_provider_skips_node_managed").await + else { + eprintln!("skipping user_api_key_service integration test: no local MongoDB available"); + return; + }; + + let user_id = uuid::Uuid::new_v4().to_string(); + let provider_id = uuid::Uuid::new_v4().to_string(); + let node_key = uuid::Uuid::new_v4().to_string(); + let oauth_key = uuid::Uuid::new_v4().to_string(); + + db.collection::(super::COLLECTION_NAME) + .insert_many(vec![ + provider_key( + &node_key, + &user_id, + &provider_id, + "pending_auth", + "node_managed", + ), + provider_key(&oauth_key, &user_id, &provider_id, "pending_auth", "oauth2"), + ]) + .await + .unwrap(); + + let revoked = revoke_pending_placeholders_for_provider(&db, &user_id, &provider_id) + .await + .unwrap(); + + assert_eq!(revoked, 1); + assert_eq!(get_key(&db, &node_key).await.status, "pending_auth"); + assert_eq!(get_key(&db, &oauth_key).await.status, "revoked"); + } + + #[tokio::test] + async fn revoke_pending_placeholders_for_provider_no_matches_returns_zero() { + let Some(db) = connect_test_database("user_api_key_revoke_provider_no_matches").await + else { + eprintln!("skipping user_api_key_service integration test: no local MongoDB available"); + return; + }; + + let user_id = uuid::Uuid::new_v4().to_string(); + let provider_id = uuid::Uuid::new_v4().to_string(); + + let revoked = revoke_pending_placeholders_for_provider(&db, &user_id, &provider_id) + .await + .unwrap(); + + assert_eq!(revoked, 0); + } + + #[tokio::test] + async fn revoke_pending_placeholders_for_provider_scopes_by_provider() { + let Some(db) = connect_test_database("user_api_key_revoke_provider_scopes").await else { + eprintln!("skipping user_api_key_service integration test: no local MongoDB available"); + return; + }; + + let user_id = uuid::Uuid::new_v4().to_string(); + let provider_id = uuid::Uuid::new_v4().to_string(); + let other_provider_id = uuid::Uuid::new_v4().to_string(); + let matching_key = uuid::Uuid::new_v4().to_string(); + let other_key = uuid::Uuid::new_v4().to_string(); + + db.collection::(super::COLLECTION_NAME) + .insert_many(vec![ + provider_key( + &matching_key, + &user_id, + &provider_id, + "pending_auth", + "oauth2", + ), + provider_key( + &other_key, + &user_id, + &other_provider_id, + "pending_auth", + "oauth2", + ), + ]) + .await + .unwrap(); + + let revoked = revoke_pending_placeholders_for_provider(&db, &user_id, &provider_id) + .await + .unwrap(); + + assert_eq!(revoked, 1); + assert_eq!(get_key(&db, &matching_key).await.status, "revoked"); + assert_eq!(get_key(&db, &other_key).await.status, "pending_auth"); + } }