Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 210 additions & 1 deletion backend/src/handlers/user_tokens.rs
Original file line number Diff line number Diff line change
Expand Up @@ -382,13 +382,40 @@ async fn generic_oauth_callback_impl(
// Handle OAuth provider errors
if let Some(ref error) = query.error {
let msg = query.error_description.as_deref().unwrap_or(error.as_str());

let mut revoked_placeholders = 0_u64;
let mut state_lookup_error: Option<String> = None;
if let Some(state_param) = query.state.as_deref().filter(|s| !s.is_empty()) {
match user_token_service::peek_oauth_state(&state.db, state_param).await {
Ok(oauth_state) => {
let owner_id = oauth_state
.target_user_id
.as_deref()
.unwrap_or(&oauth_state.user_id);
match user_api_key_service::revoke_pending_placeholders_for_provider(
&state.db,
owner_id,
&oauth_state.provider_config_id,
)
.await
{
Ok(count) => revoked_placeholders = count,
Err(e) => state_lookup_error = Some(e.to_string()),
}
}
Err(e) => state_lookup_error = Some(e.to_string()),
}
}

audit_service::log_async(
state.db.clone(),
auth_user.as_ref().map(|u| u.user_id.to_string()),
"provider_oauth_callback_failed".to_string(),
Some(serde_json::json!({
"error": error,
"error_description": &query.error_description,
"revoked_placeholders": revoked_placeholders,
"state_lookup_error": state_lookup_error,
})),
None,
None,
Expand Down Expand Up @@ -913,15 +940,19 @@ fn ensure_callback_user_matches_state(
#[cfg(test)]
mod tests {
use super::*;
use crate::models::oauth_state::{COLLECTION_NAME as OAUTH_STATES, OAuthState};
use crate::models::org_membership::COLLECTION_NAME as ORG_MEMBERSHIPS;
use crate::models::provider_config::{COLLECTION_NAME as PROVIDER_CONFIGS, ProviderConfig};
use crate::models::user::{COLLECTION_NAME as USERS, UserType};
use crate::models::user_api_key::{COLLECTION_NAME as USER_API_KEYS, UserApiKey};
use crate::models::user_provider_token::{
COLLECTION_NAME as USER_PROVIDER_TOKENS, UserProviderToken,
};
use crate::mw::auth::AuthMethod;
use crate::test_utils::{connect_test_database, test_app_state, test_membership, test_user};
use chrono::Utc;
use axum::http::header::LOCATION;
use axum::response::IntoResponse;
use chrono::{Duration, Utc};
use uuid::Uuid;

fn test_auth_user() -> AuthUser {
Expand Down Expand Up @@ -1004,6 +1035,69 @@ mod tests {
}
}

fn test_oauth_state(state_id: &str, user_id: &str, provider_id: &str) -> OAuthState {
let now = Utc::now();
OAuthState {
id: state_id.to_string(),
user_id: user_id.to_string(),
provider_config_id: provider_id.to_string(),
code_verifier: None,
device_code_encrypted: None,
user_code_encrypted: None,
poll_interval: None,
target_user_id: None,
credential_user_id: None,
redirect_path: None,
expires_at: now + Duration::minutes(10),
created_at: now,
}
}

fn test_pending_oauth_api_key(key_id: &str, user_id: &str, provider_id: &str) -> UserApiKey {
let now = Utc::now();
UserApiKey {
id: key_id.to_string(),
user_id: user_id.to_string(),
label: "GitHub OAuth".to_string(),
credential_type: "oauth2".to_string(),
credential_encrypted: None,
access_token_encrypted: None,
refresh_token_encrypted: None,
token_scopes: None,
expires_at: None,
provider_config_id: Some(provider_id.to_string()),
user_oauth_client_id_encrypted: None,
user_oauth_client_secret_encrypted: None,
status: "pending_auth".to_string(),
last_used_at: None,
error_message: None,
source: Some("user_created".to_string()),
source_id: None,
created_at: now,
updated_at: now,
}
}

async fn get_api_key(db: &mongodb::Database, key_id: &str) -> UserApiKey {
db.collection::<UserApiKey>(USER_API_KEYS)
.find_one(mongodb::bson::doc! { "_id": key_id })
.await
.unwrap()
.unwrap()
}

fn redirect_location(redirect: axum::response::Redirect) -> String {
let response = redirect.into_response();
assert!(response.status().is_redirection());
response
.headers()
.get(LOCATION)
.expect("redirect location")
.to_str()
.expect("valid redirect location")
.to_string()
}

#[test]
fn callback_state_allows_missing_session_cookie() {
assert!(ensure_callback_user_matches_state(None, "user-123").is_ok());
Expand Down Expand Up @@ -1135,4 +1229,119 @@ mod tests {
assert_eq!(response.tokens[0].provider_id, provider_id);
assert_eq!(response.tokens[0].provider_name, "GitHub");
}

#[tokio::test]
async fn generic_oauth_callback_denial_revokes_placeholder() {
let Some(db) = connect_test_database("oauth_callback_denial_revokes_placeholder").await
else {
eprintln!(
"skipping provider token handler integration test: no local MongoDB available"
);
return;
};
let state = test_app_state(db.clone());
let user_id = Uuid::new_v4().to_string();
let provider_id = Uuid::new_v4().to_string();
let state_id = Uuid::new_v4().to_string();
let key_id = Uuid::new_v4().to_string();

db.collection::<OAuthState>(OAUTH_STATES)
.insert_one(test_oauth_state(&state_id, &user_id, &provider_id))
.await
.unwrap();
db.collection::<UserApiKey>(USER_API_KEYS)
.insert_one(test_pending_oauth_api_key(&key_id, &user_id, &provider_id))
.await
.unwrap();

let redirect = generic_oauth_callback_impl(
state,
None,
GenericOAuthCallbackQuery {
code: None,
state: Some(state_id),
error: Some("access_denied".to_string()),
error_description: None,
},
)
.await;

let location = redirect_location(redirect);
assert!(location.contains("/providers/callback"));
assert!(location.contains("status=error"));
assert!(location.contains("message=access_denied"));
assert_eq!(get_api_key(&db, &key_id).await.status, "revoked");
}

#[tokio::test]
async fn generic_oauth_callback_denial_without_state_redirects_only() {
let Some(db) = connect_test_database("oauth_callback_denial_without_state").await else {
eprintln!(
"skipping provider token handler integration test: no local MongoDB available"
);
return;
};
let state = test_app_state(db.clone());
let user_id = Uuid::new_v4().to_string();
let provider_id = Uuid::new_v4().to_string();
let key_id = Uuid::new_v4().to_string();

db.collection::<UserApiKey>(USER_API_KEYS)
.insert_one(test_pending_oauth_api_key(&key_id, &user_id, &provider_id))
.await
.unwrap();

let redirect = generic_oauth_callback_impl(
state,
None,
GenericOAuthCallbackQuery {
code: None,
state: None,
error: Some("access_denied".to_string()),
error_description: None,
},
)
.await;

let location = redirect_location(redirect);
assert!(location.contains("status=error"));
assert!(location.contains("message=access_denied"));
assert_eq!(get_api_key(&db, &key_id).await.status, "pending_auth");
}

#[tokio::test]
async fn generic_oauth_callback_denial_with_invalid_state_redirects_only() {
let Some(db) = connect_test_database("oauth_callback_denial_invalid_state").await else {
eprintln!(
"skipping provider token handler integration test: no local MongoDB available"
);
return;
};
let state = test_app_state(db.clone());
let user_id = Uuid::new_v4().to_string();
let provider_id = Uuid::new_v4().to_string();
let key_id = Uuid::new_v4().to_string();

db.collection::<UserApiKey>(USER_API_KEYS)
.insert_one(test_pending_oauth_api_key(&key_id, &user_id, &provider_id))
.await
.unwrap();

let redirect = generic_oauth_callback_impl(
state,
None,
GenericOAuthCallbackQuery {
code: None,
state: Some("bogus-state".to_string()),
error: Some("access_denied".to_string()),
error_description: None,
},
)
.await;

let location = redirect_location(redirect);
assert!(location.contains("status=error"));
assert!(location.contains("message=access_denied"));
assert_eq!(get_api_key(&db, &key_id).await.status, "pending_auth");
}
}
Loading
Loading