diff --git a/src/google/adk/tools/mcp_tool/mcp_toolset.py b/src/google/adk/tools/mcp_tool/mcp_toolset.py index 2fc9d640a..89edcff94 100644 --- a/src/google/adk/tools/mcp_tool/mcp_toolset.py +++ b/src/google/adk/tools/mcp_tool/mcp_toolset.py @@ -14,6 +14,7 @@ from __future__ import annotations +import base64 import logging import sys from typing import List @@ -21,6 +22,8 @@ from typing import TextIO from typing import Union +from fastapi.openapi.models import APIKeyIn + from ...agents.readonly_context import ReadonlyContext from ...auth.auth_credential import AuthCredential from ...auth.auth_schemes import AuthScheme @@ -146,8 +149,15 @@ async def get_tools( Returns: List[BaseTool]: A list of tools available under the specified context. """ + # Get authentication headers based on the auth scheme and credential + authentication_headers = await self._get_headers( + self._auth_credential, self._auth_scheme + ) + # Get session from session manager - session = await self._mcp_session_manager.create_session() + session = await self._mcp_session_manager.create_session( + authentication_headers + ) # Fetch available tools from the MCP server tools_response: ListToolsResult = await session.list_tools() @@ -166,6 +176,79 @@ async def get_tools( tools.append(mcp_tool) return tools + @staticmethod + async def _get_headers( + credential: AuthCredential, auth_scheme: AuthScheme + ) -> Optional[dict[str, str]]: + """Extracts authentication headers from credentials. + + Args: + credential: The authentication credential to process. + + Returns: + Dictionary of headers to add to the request, or None if no auth. + + Raises: + ValueError: If API key authentication is configured for non-header location. + """ + headers: Optional[dict[str, str]] = None + if credential: + if credential.oauth2: + headers = {"Authorization": f"Bearer {credential.oauth2.access_token}"} + elif credential.http: + # Handle HTTP authentication schemes + if ( + credential.http.scheme.lower() == "bearer" + and credential.http.credentials.token + ): + headers = { + "Authorization": f"Bearer {credential.http.credentials.token}" + } + elif credential.http.scheme.lower() == "basic": + # Handle basic auth + if ( + credential.http.credentials.username + and credential.http.credentials.password + ): + credentials = f"{credential.http.credentials.username}:{credential.http.credentials.password}" + encoded_credentials = base64.b64encode( + credentials.encode() + ).decode() + headers = {"Authorization": f"Basic {encoded_credentials}"} + elif credential.http.credentials.token: + # Handle other HTTP schemes with token + headers = { + "Authorization": ( + f"{credential.http.scheme} {credential.http.credentials.token}" + ) + } + elif credential.api_key: + if not auth_scheme or not credential: + error_msg = ( + "Cannot find corresponding auth scheme for API key credential" + f" {credential}" + ) + logger.error(error_msg) + raise ValueError(error_msg) + elif auth_scheme.in_ != APIKeyIn.header: + error_msg = ( + "MCPTool only supports header-based API key authentication." + " Configured location:" + f" {auth_scheme.in_}" + ) + logger.error(error_msg) + raise ValueError(error_msg) + else: + headers = {auth_scheme.name: credential.api_key} + elif credential.service_account: + # Service accounts should be exchanged for access tokens before reaching this point + logger.warning( + "Service account credentials should be exchanged before MCP" + " session creation" + ) + + return headers + async def close(self) -> None: """Performs cleanup and releases resources held by the toolset. diff --git a/tests/unittests/tools/mcp_tool/test_mcp_toolset.py b/tests/unittests/tools/mcp_tool/test_mcp_toolset.py index d5e6ae243..20ae2882d 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_toolset.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_toolset.py @@ -19,7 +19,12 @@ from unittest.mock import Mock from unittest.mock import patch +from google.adk.auth import AuthCredentialTypes +from google.adk.auth import OAuth2Auth from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import HttpAuth +from google.adk.auth.auth_credential import HttpCredentials +from google.adk.auth.auth_credential import ServiceAccount import pytest # Skip all tests in this module if Python version is less than 3.10 @@ -162,6 +167,193 @@ def test_init_with_auth(self): assert toolset._auth_scheme == auth_scheme assert toolset._auth_credential == auth_credential + @pytest.mark.asyncio + async def test_get_headers_oauth2(self): + """Test header generation for OAuth2 credentials.""" + toolset = MCPToolset( + connection_params=self.mock_stdio_params, + ) + + oauth2_auth = OAuth2Auth(access_token="test_token") + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, oauth2=oauth2_auth + ) + + headers = await toolset._get_headers(credential, oauth2_auth) + + assert headers == {"Authorization": "Bearer test_token"} + + @pytest.mark.asyncio + async def test_get_headers_http_bearer(self): + """Test header generation for HTTP Bearer credentials.""" + toolset = MCPToolset( + connection_params=self.mock_stdio_params, + ) + + http_auth = HttpAuth( + scheme="bearer", credentials=HttpCredentials(token="bearer_token") + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, http=http_auth + ) + + headers = await toolset._get_headers(credential, http_auth) + + assert headers == {"Authorization": "Bearer bearer_token"} + + @pytest.mark.asyncio + async def test_get_headers_http_basic(self): + """Test header generation for HTTP Basic credentials.""" + toolset = MCPToolset( + connection_params=self.mock_stdio_params, + ) + + http_auth = HttpAuth( + scheme="basic", + credentials=HttpCredentials(username="user", password="pass"), + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, http=http_auth + ) + + headers = await toolset._get_headers(credential, http_auth) + + # Should create Basic auth header with base64 encoded credentials + import base64 + + expected_encoded = base64.b64encode(b"user:pass").decode() + assert headers == {"Authorization": f"Basic {expected_encoded}"} + + @pytest.mark.asyncio + async def test_get_headers_api_key_with_valid_header_scheme(self): + """Test header generation for API Key credentials with header-based auth scheme.""" + from fastapi.openapi.models import APIKey + from fastapi.openapi.models import APIKeyIn + from google.adk.auth.auth_schemes import AuthSchemeType + + # Create auth scheme for header-based API key + auth_scheme = APIKey(**{ + "type": AuthSchemeType.apiKey, + "in": APIKeyIn.header, + "name": "X-Custom-API-Key", + }) + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" + ) + + toolset = MCPToolset( + connection_params=self.mock_stdio_params, + ) + + headers = await toolset._get_headers(auth_credential, auth_scheme) + + assert headers == {"X-Custom-API-Key": "my_api_key"} + + @pytest.mark.asyncio + async def test_get_headers_api_key_with_query_scheme_raises_error(self): + """Test that API Key with query-based auth scheme raises ValueError.""" + from fastapi.openapi.models import APIKey + from fastapi.openapi.models import APIKeyIn + from google.adk.auth.auth_schemes import AuthSchemeType + + # Create auth scheme for query-based API key (not supported) + auth_scheme = APIKey(**{ + "type": AuthSchemeType.apiKey, + "in": APIKeyIn.query, + "name": "api_key", + }) + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" + ) + + toolset = MCPToolset( + connection_params=self.mock_stdio_params, + ) + + with pytest.raises( + ValueError, + match="MCPTool only supports header-based API key authentication", + ): + await toolset._get_headers(auth_credential, auth_scheme) + + @pytest.mark.asyncio + async def test_get_headers_api_key_with_cookie_scheme_raises_error(self): + """Test that API Key with cookie-based auth scheme raises ValueError.""" + from fastapi.openapi.models import APIKey + from fastapi.openapi.models import APIKeyIn + from google.adk.auth.auth_schemes import AuthSchemeType + + # Create auth scheme for cookie-based API key (not supported) + auth_scheme = APIKey(**{ + "type": AuthSchemeType.apiKey, + "in": APIKeyIn.cookie, + "name": "session_id", + }) + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" + ) + + toolset = MCPToolset( + connection_params=self.mock_stdio_params, + ) + + with pytest.raises( + ValueError, + match="MCPTool only supports header-based API key authentication", + ): + await toolset._get_headers(auth_credential, auth_scheme) + + @pytest.mark.asyncio + async def test_get_headers_api_key_without_auth_schema_raises_error(self): + """Test that API Key without auth config raises ValueError.""" + # Create tool without auth scheme/config + toolset = MCPToolset( + connection_params=self.mock_stdio_params, + ) + + credential = AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" + ) + + with pytest.raises( + ValueError, + match="Cannot find corresponding auth scheme for API key credential", + ): + await toolset._get_headers(credential, None) + + @pytest.mark.asyncio + async def test_get_headers_no_credential(self): + """Test header generation with no credentials.""" + toolset = MCPToolset( + connection_params=self.mock_stdio_params, + ) + oauth2_auth = OAuth2Auth(access_token="test_token") + + headers = await toolset._get_headers(None, oauth2_auth) + + assert headers is None + + @pytest.mark.asyncio + async def test_get_headers_service_account(self): + """Test header generation for service account credentials.""" + toolset = MCPToolset( + connection_params=self.mock_stdio_params, + ) + + # Create service account credential + service_account = ServiceAccount(scopes=["test"]) + credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=service_account, + ) + + headers = await toolset._get_headers( + credential, AuthCredentialTypes.SERVICE_ACCOUNT + ) + + # Should return None as service account credentials are not supported for direct header generation + assert headers is None + def test_init_missing_connection_params(self): """Test initialization with missing connection params raises error.""" with pytest.raises(ValueError, match="Missing connection params"):