Skip to content

fix: MCPToolset does not include authentication information during initialization and tool listing #2173

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 84 additions & 1 deletion src/google/adk/tools/mcp_tool/mcp_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@

from __future__ import annotations

import base64
import logging
import sys
from typing import List
from typing import Optional
from typing import TextIO
from typing import Union

from fastapi.openapi.models import APIKeyIn

from ...agents.readonly_context import ReadonlyContext
from ...auth.auth_credential import AuthCredential
from ...auth.auth_schemes import AuthScheme
Expand Down Expand Up @@ -146,8 +149,15 @@ async def get_tools(
Returns:
List[BaseTool]: A list of tools available under the specified context.
"""
# Get authentication headers based on the auth scheme and credential
authentication_headers = await self._get_headers(
self._auth_credential, self._auth_scheme
)

# Get session from session manager
session = await self._mcp_session_manager.create_session()
session = await self._mcp_session_manager.create_session(
authentication_headers
)

# Fetch available tools from the MCP server
tools_response: ListToolsResult = await session.list_tools()
Expand All @@ -166,6 +176,79 @@ async def get_tools(
tools.append(mcp_tool)
return tools

@staticmethod
async def _get_headers(
credential: AuthCredential, auth_scheme: AuthScheme
) -> Optional[dict[str, str]]:
"""Extracts authentication headers from credentials.
Args:
credential: The authentication credential to process.
Returns:
Dictionary of headers to add to the request, or None if no auth.
Raises:
ValueError: If API key authentication is configured for non-header location.
"""
headers: Optional[dict[str, str]] = None
if credential:
if credential.oauth2:
headers = {"Authorization": f"Bearer {credential.oauth2.access_token}"}
elif credential.http:
# Handle HTTP authentication schemes
if (
credential.http.scheme.lower() == "bearer"
and credential.http.credentials.token
):
headers = {
"Authorization": f"Bearer {credential.http.credentials.token}"
}
elif credential.http.scheme.lower() == "basic":
# Handle basic auth
if (
credential.http.credentials.username
and credential.http.credentials.password
):
credentials = f"{credential.http.credentials.username}:{credential.http.credentials.password}"
encoded_credentials = base64.b64encode(
credentials.encode()
).decode()
headers = {"Authorization": f"Basic {encoded_credentials}"}
elif credential.http.credentials.token:
# Handle other HTTP schemes with token
headers = {
"Authorization": (
f"{credential.http.scheme} {credential.http.credentials.token}"
)
}
elif credential.api_key:
if not auth_scheme or not credential:
error_msg = (
"Cannot find corresponding auth scheme for API key credential"
f" {credential}"
)
logger.error(error_msg)
raise ValueError(error_msg)
elif auth_scheme.in_ != APIKeyIn.header:
error_msg = (
"MCPTool only supports header-based API key authentication."
" Configured location:"
f" {auth_scheme.in_}"
)
logger.error(error_msg)
raise ValueError(error_msg)
else:
headers = {auth_scheme.name: credential.api_key}
elif credential.service_account:
# Service accounts should be exchanged for access tokens before reaching this point
logger.warning(
"Service account credentials should be exchanged before MCP"
" session creation"
)

return headers

async def close(self) -> None:
"""Performs cleanup and releases resources held by the toolset.
Expand Down
192 changes: 192 additions & 0 deletions tests/unittests/tools/mcp_tool/test_mcp_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@
from unittest.mock import Mock
from unittest.mock import patch

from google.adk.auth import AuthCredentialTypes
from google.adk.auth import OAuth2Auth
from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import HttpAuth
from google.adk.auth.auth_credential import HttpCredentials
from google.adk.auth.auth_credential import ServiceAccount
import pytest

# Skip all tests in this module if Python version is less than 3.10
Expand Down Expand Up @@ -162,6 +167,193 @@ def test_init_with_auth(self):
assert toolset._auth_scheme == auth_scheme
assert toolset._auth_credential == auth_credential

@pytest.mark.asyncio
async def test_get_headers_oauth2(self):
"""Test header generation for OAuth2 credentials."""
toolset = MCPToolset(
connection_params=self.mock_stdio_params,
)

oauth2_auth = OAuth2Auth(access_token="test_token")
credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2, oauth2=oauth2_auth
)

headers = await toolset._get_headers(credential, oauth2_auth)

assert headers == {"Authorization": "Bearer test_token"}

@pytest.mark.asyncio
async def test_get_headers_http_bearer(self):
"""Test header generation for HTTP Bearer credentials."""
toolset = MCPToolset(
connection_params=self.mock_stdio_params,
)

http_auth = HttpAuth(
scheme="bearer", credentials=HttpCredentials(token="bearer_token")
)
credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP, http=http_auth
)

headers = await toolset._get_headers(credential, http_auth)

assert headers == {"Authorization": "Bearer bearer_token"}

@pytest.mark.asyncio
async def test_get_headers_http_basic(self):
"""Test header generation for HTTP Basic credentials."""
toolset = MCPToolset(
connection_params=self.mock_stdio_params,
)

http_auth = HttpAuth(
scheme="basic",
credentials=HttpCredentials(username="user", password="pass"),
)
credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP, http=http_auth
)

headers = await toolset._get_headers(credential, http_auth)

# Should create Basic auth header with base64 encoded credentials
import base64

expected_encoded = base64.b64encode(b"user:pass").decode()
assert headers == {"Authorization": f"Basic {expected_encoded}"}

@pytest.mark.asyncio
async def test_get_headers_api_key_with_valid_header_scheme(self):
"""Test header generation for API Key credentials with header-based auth scheme."""
from fastapi.openapi.models import APIKey
from fastapi.openapi.models import APIKeyIn
from google.adk.auth.auth_schemes import AuthSchemeType

# Create auth scheme for header-based API key
auth_scheme = APIKey(**{
"type": AuthSchemeType.apiKey,
"in": APIKeyIn.header,
"name": "X-Custom-API-Key",
})
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
)

toolset = MCPToolset(
connection_params=self.mock_stdio_params,
)

headers = await toolset._get_headers(auth_credential, auth_scheme)

assert headers == {"X-Custom-API-Key": "my_api_key"}

@pytest.mark.asyncio
async def test_get_headers_api_key_with_query_scheme_raises_error(self):
"""Test that API Key with query-based auth scheme raises ValueError."""
from fastapi.openapi.models import APIKey
from fastapi.openapi.models import APIKeyIn
from google.adk.auth.auth_schemes import AuthSchemeType

# Create auth scheme for query-based API key (not supported)
auth_scheme = APIKey(**{
"type": AuthSchemeType.apiKey,
"in": APIKeyIn.query,
"name": "api_key",
})
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
)

toolset = MCPToolset(
connection_params=self.mock_stdio_params,
)

with pytest.raises(
ValueError,
match="MCPTool only supports header-based API key authentication",
):
await toolset._get_headers(auth_credential, auth_scheme)

@pytest.mark.asyncio
async def test_get_headers_api_key_with_cookie_scheme_raises_error(self):
"""Test that API Key with cookie-based auth scheme raises ValueError."""
from fastapi.openapi.models import APIKey
from fastapi.openapi.models import APIKeyIn
from google.adk.auth.auth_schemes import AuthSchemeType

# Create auth scheme for cookie-based API key (not supported)
auth_scheme = APIKey(**{
"type": AuthSchemeType.apiKey,
"in": APIKeyIn.cookie,
"name": "session_id",
})
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
)

toolset = MCPToolset(
connection_params=self.mock_stdio_params,
)

with pytest.raises(
ValueError,
match="MCPTool only supports header-based API key authentication",
):
await toolset._get_headers(auth_credential, auth_scheme)

@pytest.mark.asyncio
async def test_get_headers_api_key_without_auth_schema_raises_error(self):
"""Test that API Key without auth config raises ValueError."""
# Create tool without auth scheme/config
toolset = MCPToolset(
connection_params=self.mock_stdio_params,
)

credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
)

with pytest.raises(
ValueError,
match="Cannot find corresponding auth scheme for API key credential",
):
await toolset._get_headers(credential, None)

@pytest.mark.asyncio
async def test_get_headers_no_credential(self):
"""Test header generation with no credentials."""
toolset = MCPToolset(
connection_params=self.mock_stdio_params,
)
oauth2_auth = OAuth2Auth(access_token="test_token")

headers = await toolset._get_headers(None, oauth2_auth)

assert headers is None

@pytest.mark.asyncio
async def test_get_headers_service_account(self):
"""Test header generation for service account credentials."""
toolset = MCPToolset(
connection_params=self.mock_stdio_params,
)

# Create service account credential
service_account = ServiceAccount(scopes=["test"])
credential = AuthCredential(
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
service_account=service_account,
)

headers = await toolset._get_headers(
credential, AuthCredentialTypes.SERVICE_ACCOUNT
)

# Should return None as service account credentials are not supported for direct header generation
assert headers is None

def test_init_missing_connection_params(self):
"""Test initialization with missing connection params raises error."""
with pytest.raises(ValueError, match="Missing connection params"):
Expand Down