diff --git a/crates/rmcp/src/transport.rs b/crates/rmcp/src/transport.rs index ec1b03ce..8ccf91f9 100644 --- a/crates/rmcp/src/transport.rs +++ b/crates/rmcp/src/transport.rs @@ -114,6 +114,13 @@ pub mod auth; #[cfg_attr(docsrs, doc(cfg(feature = "auth")))] pub use auth::{AuthError, AuthorizationManager, AuthorizationSession, AuthorizedHttpClient}; +#[cfg(feature = "auth")] +#[cfg_attr(docsrs, doc(cfg(feature = "auth")))] +pub mod auth_server; +#[cfg(feature = "auth")] +#[cfg_attr(docsrs, doc(cfg(feature = "auth")))] +pub use auth_server::{AuthServer, AuthServerConfig, ServerAuthError}; + // #[cfg(feature = "transport-ws")] // #[cfg_attr(docsrs, doc(cfg(feature = "transport-ws")))] // pub mod ws; diff --git a/crates/rmcp/src/transport/auth_server.rs b/crates/rmcp/src/transport/auth_server.rs new file mode 100644 index 00000000..e738e178 --- /dev/null +++ b/crates/rmcp/src/transport/auth_server.rs @@ -0,0 +1,591 @@ +use std::{ + collections::HashMap, + sync::Arc, + time::{Duration, Instant}, +}; + +use oauth2::{ + AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenUrl, + basic::{BasicClient, BasicTokenType}, +}; +use reqwest::{StatusCode, Url, header::AUTHORIZATION}; +use serde::{Deserialize, Serialize}; +use thiserror::Error; +use tokio::sync::RwLock; +use tracing::{debug, error, info}; + +use crate::transport::auth::AuthError; + +/// Server OAuth error types +#[derive(Debug, Error)] +pub enum ServerAuthError { + #[error("Invalid token: {0}")] + InvalidToken(String), + + #[error("Token expired")] + TokenExpired, + + #[error("Insufficient scope: {0}")] + InsufficientScope(String), + + #[error("Invalid client: {0}")] + InvalidClient(String), + + #[error("Internal error: {0}")] + InternalError(String), + + #[error("HTTP error: {0}")] + HttpError(#[from] reqwest::Error), + + #[error("URL parse error: {0}")] + UrlError(#[from] url::ParseError), + + #[error("Third party authorization error: {0}")] + ThirdPartyAuthError(String), +} + +/// Represents an OAuth server configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AuthServerConfig { + pub client_id: String, + pub client_secret: String, + pub authorize_endpoint: String, + pub token_endpoint: String, + pub registration_endpoint: String, + pub issuer: String, + pub supported_scopes: Vec, +} + +/// Third party authorization server configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ThirdPartyAuthConfig { + /// Client ID to use with the third-party auth server + pub client_id: String, + /// Client secret to use with the third-party auth server + pub client_secret: String, + /// URL of the third-party authorization server + pub auth_server_url: String, + /// Authorization endpoint at the third-party server + pub authorize_endpoint: String, + /// Token endpoint at the third-party server + pub token_endpoint: String, + /// Token introspection endpoint at the third-party server + pub introspection_endpoint: Option, + /// Token revocation endpoint at the third-party server + pub revocation_endpoint: Option, + /// Scopes supported by the third-party server + pub supported_scopes: Vec, + /// Additional parameters to include in auth requests + pub additional_params: HashMap, +} + +/// Represents OAuth server metadata +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AuthServerMetadata { + pub authorization_endpoint: String, + pub token_endpoint: String, + pub registration_endpoint: String, + pub issuer: String, + pub scopes_supported: Vec, + #[serde(flatten)] + pub additional_fields: HashMap, +} + +/// Dynamic client registration request +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ClientRegistrationRequest { + pub client_name: String, + pub redirect_uris: Vec, + pub grant_types: Vec, + pub token_endpoint_auth_method: String, + pub response_types: Vec, +} + +/// Dynamic client registration response +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ClientRegistrationResponse { + pub client_id: String, + pub client_secret: Option, + pub client_name: String, + pub redirect_uris: Vec, + #[serde(flatten)] + pub additional_fields: HashMap, +} + +/// Token info for validation +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TokenInfo { + pub active: bool, + pub scope: Option, + pub client_id: String, + pub username: Option, + pub exp: Option, + #[serde(flatten)] + pub additional_fields: HashMap, +} + +/// Third-party token introspection request +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TokenIntrospectionRequest { + pub token: String, + pub token_type_hint: Option, +} + +/// Third-party token introspection response +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TokenIntrospectionResponse { + pub active: bool, + pub scope: Option, + pub client_id: Option, + pub username: Option, + pub exp: Option, + #[serde(flatten)] + pub additional_fields: HashMap, +} + +/// OAuth server implementation +pub struct AuthServer { + config: AuthServerConfig, + third_party_config: Option, + registered_clients: RwLock>, + active_tokens: RwLock>, + server_base_url: Url, + http_client: reqwest::Client, +} + +impl AuthServer { + /// Create a new OAuth server + pub fn new(config: AuthServerConfig, server_base_url: Url) -> Self { + Self { + config, + third_party_config: None, + registered_clients: RwLock::new(HashMap::new()), + active_tokens: RwLock::new(HashMap::new()), + server_base_url, + http_client: reqwest::Client::new(), + } + } + + /// Create a new OAuth server with third-party authorization + pub fn with_third_party_auth( + config: AuthServerConfig, + third_party_config: ThirdPartyAuthConfig, + server_base_url: Url + ) -> Self { + Self { + config, + third_party_config: Some(third_party_config), + registered_clients: RwLock::new(HashMap::new()), + active_tokens: RwLock::new(HashMap::new()), + server_base_url, + http_client: reqwest::Client::new(), + } + } + + /// Check if server is using third-party auth + pub fn is_using_third_party_auth(&self) -> bool { + self.third_party_config.is_some() + } + + /// Generate metadata document + pub fn generate_metadata(&self) -> AuthServerMetadata { + AuthServerMetadata { + authorization_endpoint: self.config.authorize_endpoint.clone(), + token_endpoint: self.config.token_endpoint.clone(), + registration_endpoint: self.config.registration_endpoint.clone(), + issuer: self.config.issuer.clone(), + scopes_supported: self.config.supported_scopes.clone(), + additional_fields: HashMap::new(), + } + } + + /// Handle client registration + pub async fn handle_registration( + &self, + request: ClientRegistrationRequest, + ) -> Result { + debug!("Handling client registration request: {:?}", request); + + // Validate redirect URIs + for uri in &request.redirect_uris { + let url = Url::parse(uri).map_err(|e| { + ServerAuthError::InvalidClient(format!("Invalid redirect URI {}: {}", uri, e)) + })?; + + // Validate according to MCP spec (must be localhost or HTTPS) + if url.scheme() != "https" && !(url.host_str() == Some("localhost") || url.host_str() == Some("127.0.0.1")) { + return Err(ServerAuthError::InvalidClient( + format!("Redirect URI must use HTTPS or localhost: {}", uri) + )); + } + } + + // Generate client ID and optional secret + let client_id = format!("client_{}", uuid::Uuid::new_v4().to_string().replace("-", "")); + let client_secret = Some(format!("secret_{}", uuid::Uuid::new_v4().to_string().replace("-", ""))); + + let response = ClientRegistrationResponse { + client_id: client_id.clone(), + client_secret, + client_name: request.client_name, + redirect_uris: request.redirect_uris, + additional_fields: HashMap::new(), + }; + + // Store the registered client + self.registered_clients.write().await.insert(client_id, response.clone()); + + debug!("Client registered successfully: {}", response.client_id); + Ok(response) + } + + /// Validate token + pub async fn validate_token(&self, token: &str) -> Result { + // If third-party auth is configured, use that for validation + if let Some(third_party_config) = &self.third_party_config { + return self.validate_token_with_third_party(token, third_party_config).await; + } + + // Otherwise, use local validation + let active_tokens = self.active_tokens.read().await; + + if let Some(token_info) = active_tokens.get(token) { + // Check if token is expired + if let Some(exp) = token_info.exp { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + + if now > exp { + return Err(ServerAuthError::TokenExpired); + } + } + + if token_info.active { + Ok(token_info.clone()) + } else { + Err(ServerAuthError::InvalidToken("Token is inactive".to_string())) + } + } else { + Err(ServerAuthError::InvalidToken("Unknown token".to_string())) + } + } + + /// Validate token with third-party auth server + async fn validate_token_with_third_party( + &self, + token: &str, + config: &ThirdPartyAuthConfig, + ) -> Result { + // Check if introspection endpoint is available + let introspection_endpoint = config.introspection_endpoint.as_ref().ok_or_else(|| { + ServerAuthError::ThirdPartyAuthError("No introspection endpoint configured".to_string()) + })?; + + // Prepare introspection request + let introspection_request = TokenIntrospectionRequest { + token: token.to_string(), + token_type_hint: Some("access_token".to_string()), + }; + + // Make request to third-party server + let response = self.http_client + .post(introspection_endpoint) + .basic_auth(&config.client_id, Some(&config.client_secret)) + .form(&introspection_request) + .send() + .await + .map_err(|e| ServerAuthError::HttpError(e))?; + + if !response.status().is_success() { + let status = response.status(); + let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string()); + return Err(ServerAuthError::ThirdPartyAuthError( + format!("Introspection failed: HTTP {} - {}", status, error_text) + )); + } + + // Parse introspection response + let introspection: TokenIntrospectionResponse = response.json().await + .map_err(|e| ServerAuthError::ThirdPartyAuthError( + format!("Failed to parse introspection response: {}", e) + ))?; + + if !introspection.active { + return Err(ServerAuthError::InvalidToken("Token is inactive".to_string())); + } + + // Check expiration + if let Some(exp) = introspection.exp { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + + if now > exp { + return Err(ServerAuthError::TokenExpired); + } + } + + // Convert to internal TokenInfo + Ok(TokenInfo { + active: introspection.active, + scope: introspection.scope, + client_id: introspection.client_id.unwrap_or_else(|| "unknown".to_string()), + username: introspection.username, + exp: introspection.exp, + additional_fields: introspection.additional_fields, + }) + } + + /// Validate token for specific scope + pub async fn validate_token_scope(&self, token: &str, required_scope: &str) -> Result { + let token_info = self.validate_token(token).await?; + + // Check if token has the required scope + if let Some(scope) = &token_info.scope { + let scopes: Vec<&str> = scope.split(' ').collect(); + if scopes.contains(&required_scope) { + Ok(token_info) + } else { + Err(ServerAuthError::InsufficientScope(format!( + "Token does not have required scope: {}", required_scope + ))) + } + } else { + Err(ServerAuthError::InsufficientScope(format!( + "Token does not specify any scopes, required: {}", required_scope + ))) + } + } + + /// Extract token from Authorization header + pub fn extract_token_from_header(&self, auth_header: Option<&str>) -> Result { + match auth_header { + Some(header) => { + if header.starts_with("Bearer ") { + Ok(header[7..].to_string()) + } else { + Err(ServerAuthError::InvalidToken( + "Authorization header must use Bearer scheme".to_string() + )) + } + } + None => Err(ServerAuthError::InvalidToken("Missing Authorization header".to_string())), + } + } + + /// Register a token (for testing or custom token creation) + pub async fn register_token(&self, token: String, token_info: TokenInfo) -> Result<(), ServerAuthError> { + self.active_tokens.write().await.insert(token, token_info); + Ok(()) + } + + /// Revoke a token + pub async fn revoke_token(&self, token: &str) -> Result<(), ServerAuthError> { + // If third-party auth is configured and has revocation endpoint, use that + if let Some(config) = &self.third_party_config { + if let Some(revocation_endpoint) = &config.revocation_endpoint { + // Attempt to revoke with third-party + let form = [ + ("token", token), + ("token_type_hint", "access_token"), + ]; + + let response = self.http_client + .post(revocation_endpoint) + .basic_auth(&config.client_id, Some(&config.client_secret)) + .form(&form) + .send() + .await + .map_err(|e| ServerAuthError::HttpError(e))?; + + if !response.status().is_success() { + let status = response.status(); + let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string()); + + // Log the error but don't fail - per OAuth spec, token revocation + // should succeed even for unknown tokens + error!("Third-party token revocation failed: HTTP {} - {}", status, error_text); + } + } + } + + // Always remove from local storage + let mut active_tokens = self.active_tokens.write().await; + active_tokens.remove(token); + + Ok(()) + } +} + +/// Extension trait for axum handlers +#[cfg(feature = "axum")] +pub mod axum_ext { + use super::*; + use axum::{ + extract::State, + http::{HeaderMap, Request, StatusCode}, + middleware::Next, + response::{IntoResponse, Response}, + Json, + }; + + pub async fn oauth_metadata_handler( + State(server): State>, + ) -> impl IntoResponse { + let metadata = server.generate_metadata(); + Json(metadata) + } + + pub async fn client_registration_handler( + State(server): State>, + Json(request): Json, + ) -> Result { + match server.handle_registration(request).await { + Ok(response) => Ok((StatusCode::CREATED, Json(response))), + Err(e) => Err((StatusCode::BAD_REQUEST, e.to_string())), + } + } + + pub async fn auth_middleware( + State(server): State>, + headers: HeaderMap, + request: Request, + next: Next, + ) -> Response { + // Extract token from Authorization header + let auth_header = headers.get(AUTHORIZATION).and_then(|h| h.to_str().ok()); + + match server.extract_token_from_header(auth_header) { + Ok(token) => { + match server.validate_token(&token).await { + Ok(_) => next.run(request).await, + Err(e) => { + let status = match e { + ServerAuthError::TokenExpired => StatusCode::UNAUTHORIZED, + ServerAuthError::InvalidToken(_) => StatusCode::UNAUTHORIZED, + ServerAuthError::InsufficientScope(_) => StatusCode::FORBIDDEN, + ServerAuthError::ThirdPartyAuthError(_) => StatusCode::UNAUTHORIZED, + _ => StatusCode::INTERNAL_SERVER_ERROR, + }; + + (status, e.to_string()).into_response() + } + } + } + Err(e) => (StatusCode::UNAUTHORIZED, e.to_string()).into_response(), + } + } + + /// Middleware for requiring specific scopes + pub fn scope_middleware(scope: &'static str) -> impl Fn( + State>, + HeaderMap, + Request, + Next, + ) -> std::pin::Pin + Send>> + Clone { + move |State(server): State>, + headers: HeaderMap, + request: Request, + next: Next| { + let server = server.clone(); + Box::pin(async move { + // Extract token from Authorization header + let auth_header = headers.get(AUTHORIZATION).and_then(|h| h.to_str().ok()); + + match server.extract_token_from_header(auth_header) { + Ok(token) => { + match server.validate_token_scope(&token, scope).await { + Ok(_) => next.run(request).await, + Err(e) => { + let status = match e { + ServerAuthError::TokenExpired => StatusCode::UNAUTHORIZED, + ServerAuthError::InvalidToken(_) => StatusCode::UNAUTHORIZED, + ServerAuthError::InsufficientScope(_) => StatusCode::FORBIDDEN, + ServerAuthError::ThirdPartyAuthError(_) => StatusCode::UNAUTHORIZED, + _ => StatusCode::INTERNAL_SERVER_ERROR, + }; + + (status, e.to_string()).into_response() + } + } + } + Err(e) => (StatusCode::UNAUTHORIZED, e.to_string()).into_response(), + } + }) + } + } +} + +/// Example server setup with axum +#[cfg(feature = "axum")] +pub mod example { + use super::*; + use axum::{ + routing::{get, post}, + Router, + }; + + pub async fn create_oauth_router(server: Arc) -> Router { + Router::new() + .route("/.well-known/oauth-authorization-server", get(axum_ext::oauth_metadata_handler)) + .route("/register", post(axum_ext::client_registration_handler)) + .with_state(server) + } + + /// Create an OAuth server with third-party authorization + pub fn create_third_party_oauth_server(base_url: &str) -> Result { + let server_base_url = Url::parse(base_url)?; + + // Local server config (will delegate to third-party) + let config = AuthServerConfig { + client_id: "mcp_server".to_string(), + client_secret: "mcp_server_secret".to_string(), + authorize_endpoint: format!("{}/authorize", server_base_url), + token_endpoint: format!("{}/token", server_base_url), + registration_endpoint: format!("{}/register", server_base_url), + issuer: server_base_url.to_string(), + supported_scopes: vec!["mcp".to_string(), "profile".to_string()], + }; + + // Third-party auth server config + let third_party_config = ThirdPartyAuthConfig { + client_id: "mcp_client_at_third_party".to_string(), + client_secret: "mcp_client_secret_at_third_party".to_string(), + auth_server_url: "https://auth.example.com".to_string(), + authorize_endpoint: "https://auth.example.com/oauth2/authorize".to_string(), + token_endpoint: "https://auth.example.com/oauth2/token".to_string(), + introspection_endpoint: Some("https://auth.example.com/oauth2/introspect".to_string()), + revocation_endpoint: Some("https://auth.example.com/oauth2/revoke".to_string()), + supported_scopes: vec!["openid".to_string(), "profile".to_string(), "mcp".to_string()], + additional_params: HashMap::new(), + }; + + Ok(AuthServer::with_third_party_auth(config, third_party_config, server_base_url)) + } + + /// Create a router with protected resources requiring specific scopes + pub fn create_protected_router(server: Arc) -> Router { + Router::new() + // Path that requires any valid token + .route("/api/resource", get(|| async { "Protected resource" })) + .route_layer(axum::middleware::from_fn_with_state( + server.clone(), + axum_ext::auth_middleware, + )) + // Path that requires specific scope "mcp" + .route("/api/mcp", get(|| async { "MCP protected resource" })) + .route_layer(axum::middleware::from_fn_with_state( + server.clone(), + axum_ext::scope_middleware("mcp"), + )) + // Path that requires specific scope "profile" + .route("/api/profile", get(|| async { "Profile protected resource" })) + .route_layer(axum::middleware::from_fn_with_state( + server.clone(), + axum_ext::scope_middleware("profile"), + )) + } +} \ No newline at end of file diff --git a/docs/OAUTH_SUPPORT.md b/docs/OAUTH_SUPPORT.md index bac516e3..87699a38 100644 --- a/docs/OAUTH_SUPPORT.md +++ b/docs/OAUTH_SUPPORT.md @@ -1,6 +1,6 @@ # Model Context Protocol OAuth Authorization -This document describes the OAuth 2.1 authorization implementation for Model Context Protocol (MCP), following the [MCP 2025-03-26 Authorization Specification](https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/authorization/). +This document describes the OAuth 2.1 authorization implementation for Model Context Protocol (MCP), following the [MCP 2025-03-26 Authorization Specification](https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization/). ## Features @@ -11,7 +11,10 @@ This document describes the OAuth 2.1 authorization implementation for Model Con - Automatic token refresh - Authorized SSE transport implementation - Authorized HTTP Client implementation -## Usage Guide +- Server-side OAuth implementation with token validation +- Axum middleware for protecting API endpoints + +## Client-Side Implementation ### 1. Enable Features @@ -76,20 +79,183 @@ rmcp = { version = "0.1", features = ["auth", "transport-sse-client"] } let client = oauth_state.to_authorized_http_client().await?; ``` -## Complete Example -client: Please refer to `examples/clients/src/oauth_client.rs` for a complete usage example. -server: Please refer to `examples/servers/src/mcp_oauth_server.rs` for a complete usage example. -### Running the Example in server -```bash -# Run example -cargo run --example mcp_oauth_server +## Server-Side Implementation + +The MCP SDK also provides a server-side OAuth implementation for creating OAuth 2.1 compliant MCP servers. + +### 1. Enable Server Features + +```toml +[dependencies] +rmcp = { version = "0.1", features = ["auth", "axum"] } +``` + +### 2. Basic Server Setup + +```rust +use std::sync::Arc; +use axum::{routing::get, Router}; +use rmcp::transport::auth_server::{AuthServer, AuthServerConfig, axum_ext}; +use url::Url; + +async fn setup_oauth_server() -> anyhow::Result { + // Configure server + let server_base_url = Url::parse("https://api.example.com")?; + + // Create OAuth server configuration + let config = AuthServerConfig { + client_id: "mcp_server".to_string(), + client_secret: "mcp_server_secret".to_string(), + authorize_endpoint: format!("{}/authorize", server_base_url), + token_endpoint: format!("{}/token", server_base_url), + registration_endpoint: format!("{}/register", server_base_url), + issuer: server_base_url.to_string(), + supported_scopes: vec!["mcp".to_string(), "profile".to_string()], + }; + + // Create AuthServer instance + let auth_server = Arc::new(AuthServer::new(config, server_base_url)); + + // Create routes + let oauth_router = Router::new() + .route("/.well-known/oauth-authorization-server", get(axum_ext::oauth_metadata_handler)) + .route("/register", get(axum_ext::client_registration_handler)) + .with_state(auth_server.clone()); + + // Create protected API routes + let protected_api = Router::new() + .route("/api/resource", get(resource_handler)) + .route_layer(axum::middleware::from_fn_with_state( + auth_server.clone(), + axum_ext::auth_middleware, + )); + + // Combine routes + let app = Router::new() + .merge(oauth_router) + .merge(protected_api); + + Ok(app) +} +``` + +### 3. Using Third-Party Authorization Servers + +MCP servers can delegate authentication to third-party OAuth servers (identity providers): + +```rust +use std::collections::HashMap; +use rmcp::transport::auth_server::{AuthServer, AuthServerConfig, ThirdPartyAuthConfig}; + +async fn setup_server_with_third_party_auth() -> anyhow::Result { + // Local MCP server configuration + let server_base_url = Url::parse("https://api.example.com")?; + let config = AuthServerConfig { + client_id: "mcp_server".to_string(), + client_secret: "mcp_server_secret".to_string(), + authorize_endpoint: format!("{}/authorize", server_base_url), + token_endpoint: format!("{}/token", server_base_url), + registration_endpoint: format!("{}/register", server_base_url), + issuer: server_base_url.to_string(), + supported_scopes: vec!["mcp".to_string(), "profile".to_string()], + }; + + // Third-party auth server configuration (e.g., Keycloak, Auth0, etc.) + let third_party_config = ThirdPartyAuthConfig { + client_id: "mcp_client_at_idp".to_string(), + client_secret: "mcp_client_secret_at_idp".to_string(), + auth_server_url: "https://auth.example.com".to_string(), + authorize_endpoint: "https://auth.example.com/oauth2/authorize".to_string(), + token_endpoint: "https://auth.example.com/oauth2/token".to_string(), + introspection_endpoint: Some("https://auth.example.com/oauth2/introspect".to_string()), + revocation_endpoint: Some("https://auth.example.com/oauth2/revoke".to_string()), + supported_scopes: vec!["openid".to_string(), "profile".to_string(), "mcp".to_string()], + additional_params: HashMap::new(), + }; + + // Create AuthServer with third-party delegation + let auth_server = Arc::new(AuthServer::with_third_party_auth( + config, + third_party_config, + server_base_url + )); + + // Create routes as before + // ... +} +``` + +When using third-party authentication: +1. The MCP server will forward token validation to the third-party server +2. Token introspection will be performed using the third-party's introspection endpoint +3. Token revocation will be delegated to the third-party's revocation endpoint +4. The MCP server will respect the scope and expiration information from the third-party + +### 4. Token Validation + +```rust +async fn validate_token(auth_server: &AuthServer, token: &str, required_scope: &str) -> Result<(), ServerAuthError> { + // Validate token and check scope + let token_info = auth_server.validate_token_scope(token, required_scope).await?; + + // Access token information + let client_id = &token_info.client_id; + let username = token_info.username.as_deref().unwrap_or("anonymous"); + + println!("Token is valid for client {} and user {}", client_id, username); + Ok(()) +} +``` + +### 5. Server-Side Token Management + +The server-side implementation provides methods for token management: + +```rust +// Register a test token +let test_token = "test_token_123456".to_string(); +let token_info = TokenInfo { + active: true, + scope: Some("mcp profile".to_string()), + client_id: "test_client".to_string(), + username: Some("test_user".to_string()), + exp: Some( + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH)? + .as_secs() + 3600, // 1 hour expiration + ), + additional_fields: Default::default(), +}; +auth_server.register_token(test_token, token_info).await?; + +// Revoke a token +auth_server.revoke_token("some_token_to_revoke").await?; ``` -### Running the Example in client +## Complete Examples + +### Client Example +Please refer to `examples/clients/src/auth/oauth_client.rs` for a complete client usage example. + +### Server Examples +- Basic server: `examples/servers/auth_server.rs` +- Integration test: `examples/auth_integration/main.rs` +- Third-party auth: `examples/third_party_auth_server.rs` + +### Running the Examples ```bash -# Run example -cargo run --example oauth-client +# Run the client example +cargo run --example clients_oauth_client + +# Run the server example +cargo run --example auth_server + +# Run the integration test +cargo run --example auth_integration + +# Run the third-party auth server example +cargo run --example third_party_auth_server ``` ## Authorization Flow Description @@ -103,10 +269,20 @@ cargo run --example oauth-client ## Security Considerations -- All tokens are securely stored in memory +- Tokens are validated for expiration and appropriate scopes +- Bearer tokens are extracted securely from authorization headers +- Redirect URIs are validated according to MCP security requirements (HTTPS or localhost) - PKCE implementation prevents authorization code interception attacks +- Token revocation is supported for invalidating access - Automatic token refresh support reduces user intervention -- Only accepts HTTPS connections or secure local callback URIs + +## Best Practices + +1. Always use HTTPS in production +2. Implement proper token expiration +3. Use proper scope validation for each protected endpoint +4. Store client secrets securely +5. Implement rate limiting for registration and token endpoints ## Troubleshooting @@ -116,10 +292,11 @@ If you encounter authorization issues, check the following: 2. Verify callback URI matches server's allowed redirect URIs 3. Check network connection and firewall settings 4. Verify server supports metadata discovery or dynamic client registration +5. For server-side issues, check that token validation includes proper scopes ## References -- [MCP Authorization Specification](https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/authorization/) +- [MCP Authorization Specification](https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization/) - [OAuth 2.1 Specification Draft](https://oauth.net/2.1/) - [RFC 8414: OAuth 2.0 Authorization Server Metadata](https://datatracker.ietf.org/doc/html/rfc8414) - [RFC 7591: OAuth 2.0 Dynamic Client Registration Protocol](https://datatracker.ietf.org/doc/html/rfc7591) \ No newline at end of file diff --git a/examples/auth_integration/Cargo.toml b/examples/auth_integration/Cargo.toml new file mode 100644 index 00000000..a4a95ab0 --- /dev/null +++ b/examples/auth_integration/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "auth_integration" +version = "0.1.0" +edition = "2021" + +[dependencies] +tokio = { version = "1", features = ["full"] } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +reqwest = { version = "0.12", features = ["json"] } +anyhow = "1.0" +thiserror = "2.0" +async-trait = "0.1" +futures = "0.3" +toml = "0.8" +rmcp = { workspace = true, features = [ + "client", + "transport-child-process", + "transport-sse-client", + "reqwest" +], no-default-features = true } +clap = { version = "4.0", features = ["derive"] } diff --git a/examples/auth_integration/src/main.rs b/examples/auth_integration/src/main.rs new file mode 100644 index 00000000..61861b90 --- /dev/null +++ b/examples/auth_integration/src/main.rs @@ -0,0 +1,136 @@ +use std::sync::Arc; + +use axum::{routing::get, Router}; +use rmcp::transport::{ + auth::{OAuthState, AuthorizationManager}, + auth_server::{AuthServer, AuthServerConfig, TokenInfo, axum_ext}, + SseClientTransport, sse_client::SseClientConfig, +}; +use tokio::{net::TcpListener, time::sleep, time::Duration}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; +use url::Url; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "debug".into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + + tokio::spawn(start_oauth_server()); + + sleep(Duration::from_secs(1)).await; + + run_oauth_client().await?; + + Ok(()) +} + +async fn start_oauth_server() -> anyhow::Result<()> { + let server_base_url = Url::parse("http://localhost:3000")?; + + + let config = AuthServerConfig { + client_id: "mcp_server".to_string(), + client_secret: "mcp_server_secret".to_string(), + authorize_endpoint: format!("{}/authorize", server_base_url), + token_endpoint: format!("{}/token", server_base_url), + registration_endpoint: format!("{}/register", server_base_url), + issuer: server_base_url.to_string(), + supported_scopes: vec![ + "mcp".to_string(), + "profile".to_string(), + "email".to_string() + ], + }; + + + let auth_server = Arc::new(AuthServer::new(config, server_base_url.clone())); + + let test_token = "test_token_123456".to_string(); + let token_info = TokenInfo { + active: true, + scope: Some("mcp profile".to_string()), + client_id: "test_client".to_string(), + username: Some("test_user".to_string()), + exp: Some( + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH)? + .as_secs() + 3600, // 1小时后过期 + ), + additional_fields: Default::default(), + }; + auth_server.register_token(test_token, token_info).await?; + + + let oauth_router = Router::new() + .route("/.well-known/oauth-authorization-server", get(axum_ext::oauth_metadata_handler)) + .route("/register", get(axum_ext::client_registration_handler)) + .with_state(auth_server.clone()); + + let protected_api = Router::new() + .route("/api/user", get(user_handler)) + .route_layer(axum::middleware::from_fn_with_state( + auth_server.clone(), + axum_ext::auth_middleware, + )); + + + let app = Router::new() + .merge(oauth_router) + .merge(protected_api); + + let listener = TcpListener::bind("127.0.0.1:3000").await?; + println!("MCP OAuth server is running at http://localhost:3000"); + + axum::serve(listener, app).await?; + Ok(()) +} + + +async fn user_handler() -> axum::Json { + axum::Json(serde_json::json!({ + "id": "user_1", + "name": "Test User", + "email": "user@example.com" + })) +} + +async fn run_oauth_client() -> anyhow::Result<()> { + println!("Starting OAuth client..."); + + + let mut oauth_state = OAuthState::new("http://localhost:3000", None).await?; + + + let client_id = "test_client"; + let token = rmcp::transport::auth::oauth2::StandardTokenResponse::new( + rmcp::transport::auth::oauth2::AccessToken::new("test_token_123456".to_string()), + rmcp::transport::auth::oauth2::basic::BasicTokenType::Bearer, + rmcp::transport::auth::oauth2::EmptyExtraTokenFields {}, + ); + + oauth_state.set_credentials(client_id, token).await?; + + let auth_manager = oauth_state.into_authorization_manager() + .ok_or_else(|| anyhow::anyhow!("Failed to get authorization manager"))?; + + let client = rmcp::transport::auth::AuthClient::new(reqwest::Client::default(), auth_manager); + + let transport = SseClientTransport::start_with_client( + client, + SseClientConfig { + sse_endpoint: "http://localhost:3000/api/user".into(), + ..Default::default() + }, + ).await?; + + println!("OAuth client successfully connected to server"); + println!("This is a simplified example - in a real-world scenario, the client would go through the full OAuth flow"); + + Ok(()) +} \ No newline at end of file diff --git a/examples/servers/src/third_party_auth_server.rs b/examples/servers/src/third_party_auth_server.rs new file mode 100644 index 00000000..80eca04b --- /dev/null +++ b/examples/servers/src/third_party_auth_server.rs @@ -0,0 +1,146 @@ +use std::{collections::HashMap, sync::Arc}; + +use axum::{ + middleware, + routing::{get, post}, + Json, Router, +}; +use rmcp::transport::{ + auth_server::{ + AuthServer, AuthServerConfig, ThirdPartyAuthConfig, TokenInfo, + axum_ext, example::{create_oauth_router, create_protected_router}, + }, +}; +use tokio::net::TcpListener; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; +use url::Url; + + +#[derive(serde::Serialize)] +struct MockThirdPartyResponse { + active: bool, + scope: String, + client_id: String, + username: String, + exp: u64, +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "debug".into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + let server_base_url = Url::parse("http://localhost:3000")?; + + let config = AuthServerConfig { + client_id: "mcp_server".to_string(), + client_secret: "mcp_server_secret".to_string(), + authorize_endpoint: format!("{}/authorize", server_base_url), + token_endpoint: format!("{}/token", server_base_url), + registration_endpoint: format!("{}/register", server_base_url), + issuer: server_base_url.to_string(), + supported_scopes: vec![ + "mcp".to_string(), + "profile".to_string(), + "email".to_string() + ], + }; + + let third_party_config = ThirdPartyAuthConfig { + client_id: "mcp_client_at_idp".to_string(), + client_secret: "mcp_client_secret_at_idp".to_string(), + auth_server_url: "http://localhost:3001".to_string(), + authorize_endpoint: "http://localhost:3001/oauth2/authorize".to_string(), + token_endpoint: "http://localhost:3001/oauth2/token".to_string(), + introspection_endpoint: Some("http://localhost:3001/oauth2/introspect".to_string()), + revocation_endpoint: Some("http://localhost:3001/oauth2/revoke".to_string()), + supported_scopes: vec!["openid".to_string(), "profile".to_string(), "mcp".to_string()], + additional_params: HashMap::new(), + }; + + let auth_server = Arc::new(AuthServer::with_third_party_auth( + config, + third_party_config, + server_base_url + )); + + let oauth_router = create_oauth_router(auth_server.clone()).await; + + let protected_api = create_protected_router(auth_server.clone()); + + let mcp_app = Router::new() + .merge(oauth_router) + .merge(protected_api); + + tokio::spawn(start_mock_third_party_server()); + + let listener = TcpListener::bind("127.0.0.1:3000").await?; + println!("MCP OAuth server with third-party auth is running at http://localhost:3000"); + println!("This server delegates authentication to the third-party auth server at http://localhost:3001"); + println!("Metadata endpoint: http://localhost:3000/.well-known/oauth-authorization-server"); + println!("Registration endpoint: http://localhost:3000/register"); + println!("Protected API endpoints:"); + println!(" - http://localhost:3000/api/resource (requires any valid token)"); + println!(" - http://localhost:3000/api/mcp (requires 'mcp' scope)"); + println!(" - http://localhost:3000/api/profile (requires 'profile' scope)"); + println!("Test with: curl -H 'Authorization: Bearer test_token_123456' http://localhost:3000/api/resource"); + + axum::serve(listener, mcp_app).await?; + Ok(()) +} + + +async fn start_mock_third_party_server() -> anyhow::Result<()> { + + let app = Router::new() + .route("/oauth2/introspect", post(introspect_handler)) + .route("/oauth2/revoke", post(revoke_handler)); + + + let listener = TcpListener::bind("127.0.0.1:3001").await?; + println!("Mock third-party OAuth server is running at http://localhost:3001"); + + axum::serve(listener, app).await?; + Ok(()) +} + + +async fn introspect_handler( + axum::Form(params): axum::Form> +) -> Json { + let token = params.get("token").cloned().unwrap_or_default(); + + + if token == "test_token_123456" { + + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(); + + Json(serde_json::json!({ + "active": true, + "scope": "openid profile mcp", + "client_id": "test_client", + "username": "test_user", + "exp": now + 3600, + })) + } else { + + Json(serde_json::json!({ + "active": false + })) + } +} + + +async fn revoke_handler( + axum::Form(_params): axum::Form> +) -> &'static str { + "" +} \ No newline at end of file