diff --git a/src/google/adk/examples/large_context_example.py b/src/google/adk/examples/large_context_example.py new file mode 100644 index 000000000..2c3ce18aa --- /dev/null +++ b/src/google/adk/examples/large_context_example.py @@ -0,0 +1,337 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Large Context Example - Demonstrating Efficient Context Management + +This example demonstrates how to use the context reference store and large context state +to efficiently handle very large context windows (1M-2M tokens) with ADK and Gemini. +""" + +import time +import sys +import json +import random +from typing import Dict, Any, List + +from google.adk.agents import LlmAgent +from google.adk.sessions import LargeContextState, ContextReferenceStore +from google.adk.tools import FunctionTool + + +def generate_large_document(token_size: int = 500000) -> Dict[str, Any]: + """ + Generate a mock document of approximately specified token size with realistic structure. + + Args: + token_size: Approximate size in tokens + + Returns: + A structured document + """ + # 1 token ≈ 4 characters in English + char_count = token_size * 4 + + # Create sample paragraphs to simulate document content + paragraphs = [ + "This is a sample document with information that might be relevant to the query.", + "It contains multiple paragraphs with different content to simulate a real document.", + "Some paragraphs are longer and contain more detailed information about specific topics.", + "Others are shorter and provide concise summaries or transitions between sections.", + "The document also includes sections with headings, lists, and other structured content.", + "This helps simulate the complexity of real documents processed by AI systems.", + "Technical documents often contain specialized terminology and references.", + "Research papers include citations, methodology descriptions, and analysis of results.", + "Legal documents have specific formatting, defined terms, and complex nested clauses.", + "Narrative text might include dialogue, character descriptions, and scene settings.", + ] + + # Create a complex structure to simulate real documents + document = { + "metadata": { + "title": "Large Context Processing Example Document", + "created_at": "2024-11-14T10:30:00Z", + "author": "ADK Context Management Example", + "version": "1.0", + "tags": ["example", "large_context", "gemini", "adk"], + "summary": "A synthetic document for demonstrating large context processing", + }, + "sections": [], + } + + # Generate enough sections and content to reach the desired token size + current_char_count = len(json.dumps(document)) + section_id = 1 + + while current_char_count < char_count: + # Create a section with subsections + section = { + "id": f"section-{section_id}", + "title": f"Section {section_id}: {random.choice(['Overview', 'Analysis', 'Results', 'Discussion', 'Methods'])}", + "content": "\n\n".join(random.sample(paragraphs, min(5, len(paragraphs)))), + "subsections": [], + } + + # Add 
subsections + for j in range(1, random.randint(3, 6)): + subsection = { + "id": f"section-{section_id}-{j}", + "title": f"Subsection {section_id}.{j}", + "content": "\n\n".join( + random.sample(paragraphs, min(3, len(paragraphs))) + ), + "paragraphs": [], + } + + # Add more detailed paragraphs to subsections + for k in range(1, random.randint(5, 15)): + paragraph = { + "id": f"para-{section_id}-{j}-{k}", + "text": random.choice(paragraphs), + "metadata": { + "relevance_score": round(random.random(), 2), + "keywords": random.sample( + [ + "ai", + "context", + "processing", + "testing", + "gemini", + "adk", + "efficiency", + ], + 3, + ), + }, + } + subsection["paragraphs"].append(paragraph) + + section["subsections"].append(subsection) + + document["sections"].append(section) + current_char_count = len(json.dumps(document)) + section_id += 1 + + # Safety check to avoid infinite loops + if section_id > 1000: + break + + print( + f"Generated document with approximate size: {current_char_count / 4:.0f} tokens" + ) + return document + + +def extract_section( + context_state: LargeContextState, section_id: str +) -> Dict[str, Any]: + """ + Extract a specific section from the document by ID. + + Args: + context_state: State with document reference + section_id: ID of the section to extract + + Returns: + The extracted section + """ + # Retrieve the document from the context store + document = context_state.get_context("document_ref") + + # Search for the section with the given ID + for section in document.get("sections", []): + if section.get("id") == section_id: + return section + + # Check subsections if not found at top level + for subsection in section.get("subsections", []): + if subsection.get("id") == section_id: + return subsection + + return {"error": f"Section with ID {section_id} not found"} + + +def search_document( + context_state: LargeContextState, keywords: List[str] +) -> List[Dict[str, Any]]: + """ + Search for keywords in the document and return matching paragraphs. + + Args: + context_state: State with document reference + keywords: List of keywords to search for + + Returns: + List of matching paragraphs with metadata + """ + # Retrieve the document from the context store + document = context_state.get_context("document_ref") + + # Normalize keywords for case-insensitive search + normalized_keywords = [k.lower() for k in keywords] + + # Search for matches + matches = [] + + for section in document.get("sections", []): + # Check section content + section_content = section.get("content", "").lower() + if any(k in section_content for k in normalized_keywords): + matches.append( + { + "id": section.get("id"), + "title": section.get("title"), + "match_type": "section", + "content_preview": ( + section.get("content")[:200] + "..." + if len(section.get("content", "")) > 200 + else section.get("content") + ), + } + ) + + # Check subsections + for subsection in section.get("subsections", []): + subsection_content = subsection.get("content", "").lower() + if any(k in subsection_content for k in normalized_keywords): + matches.append( + { + "id": subsection.get("id"), + "title": subsection.get("title"), + "match_type": "subsection", + "content_preview": ( + subsection.get("content")[:200] + "..." 
+ if len(subsection.get("content", "")) > 200 + else subsection.get("content") + ), + } + ) + + # Check paragraphs + for paragraph in subsection.get("paragraphs", []): + paragraph_text = paragraph.get("text", "").lower() + if any(k in paragraph_text for k in normalized_keywords): + matches.append( + { + "id": paragraph.get("id"), + "match_type": "paragraph", + "text": paragraph.get("text"), + "relevance_score": paragraph.get("metadata", {}).get( + "relevance_score" + ), + } + ) + + return matches + + +def run_example(): + """Run the large context example.""" + print("Starting Large Context Example...") + + # Create a context store + context_store = ContextReferenceStore() + + # Create a large context state + state = LargeContextState(context_store=context_store) + + # Generate a large document + print("Generating large test document...") + start_time = time.time() + document = generate_large_document(token_size=500000) + generation_time = time.time() - start_time + print(f"Document generation time: {generation_time:.2f} seconds") + + # Store the document in the context store + print("Storing document in context store...") + start_time = time.time() + document_ref = state.add_large_context( + document, + metadata={"content_type": "application/json", "cache_ttl": 3600}, + key="document_ref", + ) + store_time = time.time() - start_time + print(f"Document storage time: {store_time:.2f} seconds") + print(f"Document reference ID: {document_ref}") + + # Retrieve the document from the context store + print("Retrieving document from context store...") + start_time = time.time() + retrieved_document = state.get_context("document_ref") + retrieval_time = time.time() - start_time + print(f"Document retrieval time: {retrieval_time:.2f} seconds") + + # Create function tools that use the context + extract_section_tool = FunctionTool( + func=extract_section, + name="extract_section", + description="Extract a specific section from the document by ID", + ) + + search_document_tool = FunctionTool( + func=search_document, + name="search_document", + description="Search for keywords in the document and return matching paragraphs", + ) + + # Create an agent that uses the tools + agent = LlmAgent( + name="document_explorer", + model="gemini-1.5-pro-latest", + instruction=""" + You are a document explorer agent. You have access to a large document + through reference-based context management. You can: + + 1. Extract specific sections by ID using the extract_section tool + 2. Search for keywords in the document using the search_document tool + + Always use these tools to access the document rather than trying to + process the entire document at once. + """, + tools=[extract_section_tool, search_document_tool], + description="Agent for exploring large documents efficiently", + ) + + print("\nAgent created with tools for exploring the document.") + print("This example demonstrates how to use the context reference store") + print("to efficiently manage large contexts with ADK and Gemini.") + print( + "\nIn a real application, you would use the agent to interact with the document." 
+    )
+    print("For example, you could call:")
+    print('  agent.run({"user_input": "Find sections about AI", "state": state})')
+
+    # For this example, we just demonstrate searching for keywords directly
+    print("\nDemonstrating document search...")
+    search_results = search_document(state, ["ai", "context"])
+    print(f"Found {len(search_results)} matches for 'ai' or 'context'")
+
+    # Print a few results
+    for i, result in enumerate(search_results[:3]):
+        print(f"\nMatch {i+1}:")
+        print(f"  ID: {result.get('id')}")
+        print(f"  Type: {result.get('match_type')}")
+        if "title" in result:
+            print(f"  Title: {result.get('title')}")
+        if "text" in result:
+            print(f"  Text: {result.get('text')}")
+        if "content_preview" in result:
+            print(f"  Preview: {result.get('content_preview')}")
+
+    if len(search_results) > 3:
+        print(f"\n... and {len(search_results) - 3} more matches.")
+
+
+if __name__ == "__main__":
+    run_example()
diff --git a/src/google/adk/sessions/__init__.py b/src/google/adk/sessions/__init__.py
index 5583ac436..eaf7fc974 100644
--- a/src/google/adk/sessions/__init__.py
+++ b/src/google/adk/sessions/__init__.py
@@ -15,27 +15,38 @@
 from .base_session_service import BaseSessionService
 from .in_memory_session_service import InMemorySessionService
-from .session import Session
+from .session import Session, ExitResponse
 from .state import State
 from .vertex_ai_session_service import VertexAiSessionService
+from .context_reference_store import (
+    ContextReferenceStore,
+    ContextMetadata,
+)
+from .large_context_state import LargeContextState
 
-logger = logging.getLogger('google_adk.' + __name__)
+logger = logging.getLogger("google_adk." + __name__)
 
 __all__ = [
-    'BaseSessionService',
-    'InMemorySessionService',
-    'Session',
-    'State',
-    'VertexAiSessionService',
+    "BaseSessionService",
+    "InMemorySessionService",
+    "Session",
+    "State",
+    "ExitResponse",
+    "VertexAiSessionService",
+    "ContextReferenceStore",
+    "ContextMetadata",
+    "LargeContextState",
 ]
 
 try:
-  from .database_session_service import DatabaseSessionService
+    # Imported inside the guard so the package still works without sqlalchemy;
+    # DatabaseSessionService is added to __all__ only when the import succeeds.
+    from .database_session_service import DatabaseSessionService
 
-  __all__.append('DatabaseSessionService')
+    __all__.append("DatabaseSessionService")
 except ImportError:
-  logger.debug(
-      'DatabaseSessionService require sqlalchemy>=2.0, please ensure it is'
-      ' installed correctly.'
-  )
+    logger.debug(
+        "DatabaseSessionService requires sqlalchemy>=2.0; please ensure it is"
+        " installed correctly."
+    )
diff --git a/src/google/adk/sessions/context_reference_store.py b/src/google/adk/sessions/context_reference_store.py
new file mode 100644
index 000000000..75bba9484
--- /dev/null
+++ b/src/google/adk/sessions/context_reference_store.py
@@ -0,0 +1,217 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
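Reviewer note: to make the example's flow concrete, here is a minimal sketch of the round trip `run_example` performs, using only the APIs introduced in this patch; the payload and the `document_ref` key name are illustrative.

```python
from google.adk.sessions import ContextReferenceStore, LargeContextState

# Store the payload once; only a short reference ID lives in the state dict.
store = ContextReferenceStore()
state = LargeContextState(context_store=store)

payload = {"sections": [{"id": "section-1", "content": "hello"}]}
ref_id = state.add_large_context(payload, key="document_ref")

assert state["document_ref"] == ref_id                # state holds just the ID
assert state.get_context("document_ref") == payload   # content round-trips
```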
+ +""" +Context Reference Store for Efficient Management of Large Context Windows + +This module implements a solution for efficiently managing large context windows (1M-2M tokens) +by using a reference-based approach rather than direct context passing. +""" + +import time +import json +import uuid +import hashlib +from typing import Dict, Any, List, Optional +from dataclasses import dataclass, field + + +@dataclass +class ContextMetadata: + """Metadata for stored context.""" + + content_type: str = "text/plain" + token_count: int = 0 + created_at: float = field(default_factory=time.time) + last_accessed: float = field(default_factory=time.time) + access_count: int = 0 + tags: List[str] = field(default_factory=list) + cache_id: Optional[str] = None + cached_until: Optional[float] = None # Timestamp when cache expires + is_structured: bool = False # Whether this is JSON or not + + def update_access_stats(self): + """Update access statistics.""" + self.last_accessed = time.time() + self.access_count += 1 + + +class ContextReferenceStore: + """ + A store for large contexts that provides reference-based access. + + This class allows large contexts to be stored once and referenced by ID, + preventing unnecessary duplication and serialization of large data. + """ + + def __init__(self, cache_size: int = 50): + """ + Args: + cache_size: Maximum number of contexts to keep in memory + """ + self._contexts: Dict[str, str] = {} + self._metadata: Dict[str, ContextMetadata] = {} + self._lru_cache_size = cache_size + + def store(self, content: Any, metadata: Optional[Dict[str, Any]] = None) -> str: + """ + Store context and return a reference ID. + + Args: + content: The context content to store (string or structured data) + metadata: Optional metadata about the context + + Returns: + A reference ID for the stored context + """ + # Handle both string and structured data (like JSON objects) + is_structured = not isinstance(content, str) + + # Convert structured data to string for storage + if is_structured: + content_str = json.dumps(content) + content_hash = hashlib.md5(content_str.encode()).hexdigest() + else: + content_str = content + content_hash = hashlib.md5(content.encode()).hexdigest() + + # Check if we already have this content + for context_id, existing_content in self._contexts.items(): + existing_hash = hashlib.md5(existing_content.encode()).hexdigest() + if ( + existing_hash == content_hash + and self._metadata[context_id].is_structured == is_structured + ): + # Update access stats + self._metadata[context_id].update_access_stats() + return context_id + + # Generate a new ID if not found + context_id = str(uuid.uuid4()) + + self._contexts[context_id] = content_str + + # Set content type based on input type + if is_structured: + content_type = "application/json" + else: + content_type = ( + metadata.get("content_type", "text/plain") if metadata else "text/plain" + ) + + # Create and store metadata + meta = ContextMetadata( + content_type=content_type, + token_count=len(content_str) // 4, # This is a rough approximation + tags=metadata.get("tags", []) if metadata else [], + is_structured=is_structured, + ) + + # Generate a cache ID for Gemini caching + if metadata and "cache_id" in metadata: + meta.cache_id = metadata["cache_id"] + else: + meta.cache_id = f"context_{content_hash[:16]}" + + # Set cache expiration if provided + if metadata and "cache_ttl" in metadata: + ttl_seconds = metadata["cache_ttl"] + meta.cached_until = time.time() + ttl_seconds + + self._metadata[context_id] = meta + + 
self._manage_cache() + + return context_id + + def retrieve(self, context_id: str) -> Any: + """ + Retrieve context by its reference ID. + + Args: + context_id: The reference ID for the context + + Returns: + The context content (string or structured data depending on how it was stored) + """ + if context_id not in self._contexts: + raise KeyError(f"Context ID {context_id} not found") + + # Update access stats + self._metadata[context_id].update_access_stats() + + # Get the content and metadata + content = self._contexts[context_id] + metadata = self._metadata[context_id] + + # If the content is structured (JSON), parse it back + if metadata.is_structured: + try: + return json.loads(content) + except json.JSONDecodeError: + # Fallback to returning as string if JSON parsing fails + return content + + return content + + def get_metadata(self, context_id: str) -> ContextMetadata: + """Get metadata for a context.""" + if context_id not in self._metadata: + raise KeyError(f"Context ID {context_id} not found") + return self._metadata[context_id] + + def _manage_cache(self): + """Manage the cache size by removing least recently used contexts.""" + if len(self._contexts) <= self._lru_cache_size: + return + + # Sort by last accessed time + sorted_contexts = sorted( + self._metadata.items(), key=lambda x: x[1].last_accessed + ) + + # Remove oldest contexts until we're under the limit + contexts_to_remove = len(self._contexts) - self._lru_cache_size + for i in range(contexts_to_remove): + context_id = sorted_contexts[i][0] + del self._contexts[context_id] + del self._metadata[context_id] + + def get_cache_hint(self, context_id: str) -> Dict[str, Any]: + """ + Get a cache hint object for Gemini API calls. + + This allows the Gemini API to cache the context for reuse. + According to Gemini API docs, context caching can significantly + reduce costs when reusing the same context multiple times. + """ + if context_id not in self._metadata: + raise KeyError(f"Context ID {context_id} not found") + + metadata = self._metadata[context_id] + + # Create cache hint with recommended parameters + cache_hint = { + "cache_id": metadata.cache_id, + "cache_level": "HIGH", # Strong caching for this context + } + + # If we have a cached_until timestamp, add it + if metadata.cached_until: + now = time.time() + if metadata.cached_until > now: + # Still valid, calculate remaining TTL in seconds + cache_hint["ttl_seconds"] = int(metadata.cached_until - now) + + return cache_hint diff --git a/src/google/adk/sessions/large_context_state.py b/src/google/adk/sessions/large_context_state.py new file mode 100644 index 000000000..32a4f5cfd --- /dev/null +++ b/src/google/adk/sessions/large_context_state.py @@ -0,0 +1,137 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Enhanced State class for handling large context windows efficiently. + +This module extends the ADK State class to provide efficient handling of large context +windows (1M-2M tokens) using a reference-based approach. 
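+Contexts are stored once in a ContextReferenceStore and referenced from the state by ID, so serializing or copying the state stays cheap regardless of context size.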
+""" + +import json +from typing import Dict, Any, Optional, List + +from google.adk.sessions.state import State +from google.adk.sessions.context_reference_store import ContextReferenceStore + + +class LargeContextState(State): + """ + Enhanced State class for efficient handling of large contexts with Gemini. + + This class extends ADK's State to handle large contexts efficiently by: + - Storing references to contexts instead of the contexts themselves + - Providing methods to resolve references when needed + - Supporting Gemini's context caching feature for cost optimization + - Handling both text and structured contexts + """ + + def __init__( + self, + value: Dict[str, Any] = None, + delta: Dict[str, Any] = None, + context_store: Optional[ContextReferenceStore] = None, + ): + """ + + Args: + value: The current value of the state dict + delta: The delta change to the current value that hasn't been committed + context_store: Context reference store to use + """ + super().__init__(value=value or {}, delta=delta or {}) + self._context_store = context_store or ContextReferenceStore() + + def add_large_context( + self, + content: Any, + metadata: Optional[Dict[str, Any]] = None, + key: str = "context_ref", + ) -> str: + """ + Add large context to the state using reference-based storage. + + Args: + content: The context content to store (string or structured data) + metadata: Optional metadata about the context + key: The key to store the reference under in the state + + Returns: + The reference ID for the stored context + """ + context_id = self._context_store.store(content, metadata) + self[key] = context_id + return context_id + + def get_context(self, ref_key: str = "context_ref") -> Any: + """ + Retrieve context from a reference stored in the state. + + Args: + ref_key: The key where the context reference is stored + + Returns: + The context content + """ + if ref_key not in self: + raise KeyError(f"Context reference key '{ref_key}' not found in state") + + context_id = self[ref_key] + return self._context_store.retrieve(context_id) + + def with_cache_hint(self, ref_key: str = "context_ref") -> Dict[str, Any]: + """ + Get a cache hint object for Gemini API calls. + + This allows the Gemini API to cache the context for reuse. + According to Gemini API docs, context caching can significantly + reduce costs when reusing the same context multiple times. + + Args: + ref_key: The key where the context reference is stored + + Returns: + A cache hint object suitable for passing to Gemini API + """ + if ref_key not in self: + raise KeyError(f"Context reference key '{ref_key}' not found in state") + + context_id = self[ref_key] + return self._context_store.get_cache_hint(context_id) + + def store_structured_context( + self, + data: Dict[str, Any], + metadata: Optional[Dict[str, Any]] = None, + key: str = "structured_context_ref", + ) -> str: + """ + Store structured data (JSON/dict) in the context store. 
+ + Args: + data: The structured data to store + metadata: Optional metadata about the context + key: The key to store the reference under in the state + + Returns: + The reference ID for the stored context + """ + if metadata is None: + metadata = {} + + # Ensure we mark this as structured data if not already specified + if "content_type" not in metadata: + metadata["content_type"] = "application/json" + + return self.add_large_context(data, metadata, key) diff --git a/src/google/adk/sessions/session.py b/src/google/adk/sessions/session.py index aa9939991..833738e9d 100644 --- a/src/google/adk/sessions/session.py +++ b/src/google/adk/sessions/session.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Optional, Dict from pydantic import alias_generators from pydantic import BaseModel @@ -22,37 +22,52 @@ from ..events.event import Event +class ExitResponse(BaseModel): + """ + A response indicating the agent should exit. + + This is used to signal that the agent has completed its task and should exit. + + Attributes: + reason: The reason for exiting + data: Optional additional data to include with the exit response + """ + + reason: str + data: Optional[Dict[str, Any]] = None + + class Session(BaseModel): - """Represents a series of interactions between a user and agents. - - Attributes: - id: The unique identifier of the session. - app_name: The name of the app. - user_id: The id of the user. - state: The state of the session. - events: The events of the session, e.g. user input, model response, function - call/response, etc. - last_update_time: The last update time of the session. - """ - - model_config = ConfigDict( - extra='forbid', - arbitrary_types_allowed=True, - alias_generator=alias_generators.to_camel, - populate_by_name=True, - ) - """The pydantic model config.""" - - id: str - """The unique identifier of the session.""" - app_name: str - """The name of the app.""" - user_id: str - """The id of the user.""" - state: dict[str, Any] = Field(default_factory=dict) - """The state of the session.""" - events: list[Event] = Field(default_factory=list) - """The events of the session, e.g. user input, model response, function + """Represents a series of interactions between a user and agents. + + Attributes: + id: The unique identifier of the session. + app_name: The name of the app. + user_id: The id of the user. + state: The state of the session. + events: The events of the session, e.g. user input, model response, function + call/response, etc. + last_update_time: The last update time of the session. + """ + + model_config = ConfigDict( + extra="forbid", + arbitrary_types_allowed=True, + alias_generator=alias_generators.to_camel, + populate_by_name=True, + ) + """The pydantic model config.""" + + id: str + """The unique identifier of the session.""" + app_name: str + """The name of the app.""" + user_id: str + """The id of the user.""" + state: dict[str, Any] = Field(default_factory=dict) + """The state of the session.""" + events: list[Event] = Field(default_factory=list) + """The events of the session, e.g. 
user input, model response, function call/response, etc.""" - last_update_time: float = 0.0 - """The last update time of the session.""" + last_update_time: float = 0.0 + """The last update time of the session.""" diff --git a/src/google/adk/utils/__init__.py b/src/google/adk/utils/__init__.py index 0a2669d7a..921c48a1f 100644 --- a/src/google/adk/utils/__init__.py +++ b/src/google/adk/utils/__init__.py @@ -11,3 +11,47 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +"""Utils package for ADK.""" + +from google.adk.utils.adk_to_mcp_tool_type import adk_to_mcp_tool_type +from google.adk.utils.async_utils import gather_results +from google.adk.utils.date_utils import ( + from_rfc3339_datetime, + to_rfc3339_datetime, +) +from google.adk.utils.image_utils import get_base64_image_from_uri +from google.adk.utils.langgraph_utils import ( + LangGraphContextManager, + create_reference_aware_merge, +) +from google.adk.utils.multipart_utils import ( + create_multipart_message, + extract_boundary_from_content_type, + parse_multipart_message, +) +from google.adk.utils.structured_output_utils import ( + input_or_function_schema_to_signature, + to_function_schema, + typescript_schema_to_pydantic, +) +from google.adk.utils.truncate_utils import truncate_data +from google.adk.utils.uri_utils import uri_to_file_path + +__all__ = [ + "adk_to_mcp_tool_type", + "create_multipart_message", + "extract_boundary_from_content_type", + "from_rfc3339_datetime", + "gather_results", + "get_base64_image_from_uri", + "input_or_function_schema_to_signature", + "LangGraphContextManager", + "create_reference_aware_merge", + "parse_multipart_message", + "to_function_schema", + "to_rfc3339_datetime", + "truncate_data", + "typescript_schema_to_pydantic", + "uri_to_file_path", +] diff --git a/src/google/adk/utils/langgraph_utils.py b/src/google/adk/utils/langgraph_utils.py new file mode 100644 index 000000000..5e5a187b0 --- /dev/null +++ b/src/google/adk/utils/langgraph_utils.py @@ -0,0 +1,135 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +LangGraph Utilities for Context Management + +This module provides utilities for integrating ADK's context management with LangGraph. +It focuses on making it easy to use efficient context reference storage with LangGraph state. +""" + +from typing import Dict, Any, Optional, TypeVar, List, Callable + +from google.adk.sessions.context_reference_store import ContextReferenceStore + + +StateType = TypeVar("StateType", bound=Dict[str, Any]) + + +class LangGraphContextManager: + """ + Context manager for LangGraph applications. + + Provides methods to integrate ADK's context reference store with LangGraph state. 
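+    Typical usage: call add_to_state once with a large payload and let only the returned reference travel between graph nodes, resolving it with retrieve_from_state where the content is needed.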
+ """ + + def __init__(self, context_store: Optional[ContextReferenceStore] = None): + """ + + Args: + context_store: Context reference store to use + """ + self._context_store = context_store or ContextReferenceStore() + + def add_to_state( + self, + state: StateType, + content: Any, + ref_key: str = "context_ref", + metadata: Optional[Dict[str, Any]] = None, + ) -> StateType: + """ + Add content to context store and reference it in state. + + Args: + state: LangGraph state dict + content: The content to store + ref_key: Key to store the reference under + metadata: Optional metadata about the content + + Returns: + Updated state dict with reference added + """ + # Create a copy of the state to avoid modifying the original + new_state = state.copy() + + # Store content and get a reference ID + context_id = self._context_store.store(content, metadata) + + new_state[ref_key] = context_id # store the reference to the state + + return new_state + + def retrieve_from_state( + self, state: StateType, ref_key: str = "context_ref" + ) -> Any: + """ + Retrieve content from a reference in the state. + + Args: + state: LangGraph state dict + ref_key: Key where the reference is stored + + Returns: + The retrieved content + """ + if ref_key not in state: + raise KeyError(f"Context reference key '{ref_key}' not found in state") + + context_id = state[ref_key] + return self._context_store.retrieve(context_id) + + +def create_reference_aware_merge( + context_store: Optional[ContextReferenceStore] = None, +) -> Callable[[StateType, StateType], StateType]: + """ + Create a merge function for LangGraph that's aware of context references. + + This merge function handles special merging of reference keys, ensuring that + the reference itself is passed rather than trying to merge the content. + + Args: + context_store: Context reference store to use + + Returns: + A merge function that can be used with LangGraph's StateGraph + """ + store = context_store or ContextReferenceStore() + + def reference_aware_merge(left: StateType, right: StateType) -> StateType: + """ + Merge two state dicts with awareness of context references. + + Args: + left: First state dict + right: Second state dict + + Returns: + Merged state dict + """ + # Start with a copy of the left dict + result = left.copy() + + # Process all keys in the right dict + for key, value in right.items(): + + if key.endswith("_ref") and key in left: + result[key] = value + else: + result[key] = value + + return result + + return reference_aware_merge diff --git a/tests/unittests/sessions/conftest.py b/tests/unittests/sessions/conftest.py new file mode 100644 index 000000000..365d466f2 --- /dev/null +++ b/tests/unittests/sessions/conftest.py @@ -0,0 +1,68 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test fixtures for sessions tests.""" + +import sys +from unittest import mock + +# Create mocks for modules before they're imported +mock_event = mock.MagicMock() +mock_event.Event = mock.MagicMock() + +mock_adk_to_mcp_tool_type = mock.MagicMock() +mock_async_utils = mock.MagicMock() +mock_date_utils = mock.MagicMock() +mock_image_utils = mock.MagicMock() +mock_multipart_utils = mock.MagicMock() +mock_structured_output_utils = mock.MagicMock() +mock_truncate_utils = mock.MagicMock() +mock_uri_utils = mock.MagicMock() +mock_llm_response = mock.MagicMock() +mock_variant_utils = mock.MagicMock() +mock_variant_utils.GoogleLLMVariant = mock.MagicMock() + +# Apply mocks at the module level before any imports happen +sys.modules["google.adk.events.event"] = mock_event +sys.modules["google.adk.utils.adk_to_mcp_tool_type"] = mock_adk_to_mcp_tool_type +sys.modules["google.adk.utils.async_utils"] = mock_async_utils +sys.modules["google.adk.utils.date_utils"] = mock_date_utils +sys.modules["google.adk.utils.image_utils"] = mock_image_utils +sys.modules["google.adk.utils.multipart_utils"] = mock_multipart_utils +sys.modules["google.adk.utils.structured_output_utils"] = mock_structured_output_utils +sys.modules["google.adk.utils.truncate_utils"] = mock_truncate_utils +sys.modules["google.adk.utils.uri_utils"] = mock_uri_utils +sys.modules["google.adk.models.llm_response"] = mock_llm_response +sys.modules["google.adk.utils.variant_utils"] = mock_variant_utils + + +# Function to reset mocks after tests +def pytest_unconfigure(config): + """Reset any global state modified by the tests.""" + # Remove our mocked modules + for module_name in [ + "google.adk.events.event", + "google.adk.utils.adk_to_mcp_tool_type", + "google.adk.utils.async_utils", + "google.adk.utils.date_utils", + "google.adk.utils.image_utils", + "google.adk.utils.multipart_utils", + "google.adk.utils.structured_output_utils", + "google.adk.utils.truncate_utils", + "google.adk.utils.uri_utils", + "google.adk.models.llm_response", + "google.adk.utils.variant_utils", + ]: + if module_name in sys.modules: + del sys.modules[module_name] diff --git a/tests/unittests/sessions/test_large_context_state.py b/tests/unittests/sessions/test_large_context_state.py new file mode 100644 index 000000000..18991d800 --- /dev/null +++ b/tests/unittests/sessions/test_large_context_state.py @@ -0,0 +1,186 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for large context state and context reference store.""" + +import json +import time +import pytest +from typing import Dict, Any + +# Force the importing of conftest.py which sets up our mock modules +from . 
import conftest + +# Now import the modules we want to test +from google.adk.sessions.context_reference_store import ( + ContextReferenceStore, + ContextMetadata, +) +from google.adk.sessions.large_context_state import LargeContextState + + +class TestContextReferenceStore: + """Tests for ContextReferenceStore.""" + + def test_store_and_retrieve_text(self): + """Test storing and retrieving text content.""" + store = ContextReferenceStore() + content = ( + "This is a large text content" * 100 + ) # This is not large, but sufficient for testing + context_id = store.store(content) + + # Verify content can be retrieved + retrieved = store.retrieve(context_id) + assert content == retrieved + + # Verify metadata is created + metadata = store.get_metadata(context_id) + assert isinstance(metadata, ContextMetadata) + assert metadata.content_type == "text/plain" + assert not metadata.is_structured + + def test_store_and_retrieve_structured(self): + """Test storing and retrieving structured content.""" + store = ContextReferenceStore() + content = { + "title": "Test Document", + "sections": [ + {"id": "section-1", "title": "Section 1", "content": "Content 1"}, + {"id": "section-2", "title": "Section 2", "content": "Content 2"}, + ], + } + context_id = store.store(content) + + # Verify content can be retrieved + retrieved = store.retrieve(context_id) + assert content == retrieved + + # Verify metadata is created + metadata = store.get_metadata(context_id) + assert isinstance(metadata, ContextMetadata) + assert metadata.content_type == "application/json" + assert metadata.is_structured + + def test_duplicate_content_deduplication(self): + """Test that storing the same content twice returns the same ID.""" + store = ContextReferenceStore() + content = "This is a duplicate content" + + # Store the same content twice + id1 = store.store(content) + id2 = store.store(content) + + # Verify both IDs are the same + assert id1 == id2 + + def test_cache_hint(self): + """Test getting cache hints for stored content.""" + store = ContextReferenceStore() + + # Store content with a cache TTL + content = "Content with cache TTL" + metadata = {"cache_ttl": 3600} # 1 hour cache + context_id = store.store(content, metadata) + + # Get cache hint + cache_hint = store.get_cache_hint(context_id) + + # Verify cache hint has expected fields + assert "cache_id" in cache_hint + assert cache_hint["cache_level"] == "HIGH" + assert "ttl_seconds" in cache_hint + assert cache_hint["ttl_seconds"] <= 3600 + + def test_cache_management(self): + """Test cache size management.""" + # Create a store with small cache size + store = ContextReferenceStore(cache_size=2) + + # Store 3 different items, which should evict the first one + id1 = store.store("Content 1") + time.sleep(0.01) # Ensure different access times + id2 = store.store("Content 2") + time.sleep(0.01) + id3 = store.store("Content 3") + + # First content should be evicted + with pytest.raises(KeyError): + store.retrieve(id1) + + # The other two should still be accessible + assert "Content 2" == store.retrieve(id2) + assert "Content 3" == store.retrieve(id3) + + +class TestLargeContextState: + """Tests for LargeContextState.""" + + def test_add_and_get_context(self): + """Test adding and retrieving context.""" + state = LargeContextState() + content = "This is a test context" + + # Add context + ref_id = state.add_large_context(content) + + # Verify reference is stored in state + assert "context_ref" in state + assert state["context_ref"] == ref_id + + # Retrieve context + 
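# (get_context defaults to the "context_ref" key) +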
retrieved = state.get_context() + assert content == retrieved + + def test_add_and_get_structured_context(self): + """Test adding and retrieving structured context.""" + state = LargeContextState() + content = {"key": "value", "nested": {"subkey": "subvalue"}} + + # Add structured context + ref_id = state.store_structured_context(content) + + # Verify reference is stored in state + assert "structured_context_ref" in state + assert state["structured_context_ref"] == ref_id + + # Retrieve context + retrieved = state.get_context("structured_context_ref") + assert content == retrieved + + def test_with_cache_hint(self): + """Test getting cache hints from state.""" + state = LargeContextState() + content = "Content for caching" + metadata = {"cache_ttl": 1800} # 30 minutes + + # Add context with cache metadata + state.add_large_context(content, metadata) + + # Get cache hint + cache_hint = state.with_cache_hint() + + # Verify cache hint + assert "cache_id" in cache_hint + assert cache_hint["cache_level"] == "HIGH" + assert "ttl_seconds" in cache_hint + assert cache_hint["ttl_seconds"] <= 1800 + + def test_context_not_found(self): + """Test error handling when context is not found.""" + state = LargeContextState() + + # Attempt to retrieve non-existent context + with pytest.raises(KeyError): + state.get_context("nonexistent_ref") diff --git a/tests/unittests/utils/conftest.py b/tests/unittests/utils/conftest.py new file mode 100644 index 000000000..adbe53d2c --- /dev/null +++ b/tests/unittests/utils/conftest.py @@ -0,0 +1,68 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
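Reviewer note: the cache-hint tests above show the intended calling pattern. How the returned dict maps onto the Gemini SDK's context-caching parameters is not defined by this patch, so callers should treat it as opaque metadata; this sketch only exercises the API added here.

```python
from google.adk.sessions import LargeContextState

state = LargeContextState()
state.add_large_context("a very large corpus ...", metadata={"cache_ttl": 3600})

# e.g. {"cache_id": "context_<hash>", "cache_level": "HIGH", "ttl_seconds": 3599}
hint = state.with_cache_hint()
assert hint["cache_level"] == "HIGH"
```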
+ +"""Test fixtures for utils tests.""" + +import sys +from unittest import mock + +# Create mocks for modules before they're imported +mock_event = mock.MagicMock() +mock_event.Event = mock.MagicMock() + +mock_adk_to_mcp_tool_type = mock.MagicMock() +mock_async_utils = mock.MagicMock() +mock_date_utils = mock.MagicMock() +mock_image_utils = mock.MagicMock() +mock_multipart_utils = mock.MagicMock() +mock_structured_output_utils = mock.MagicMock() +mock_truncate_utils = mock.MagicMock() +mock_uri_utils = mock.MagicMock() +mock_llm_response = mock.MagicMock() +mock_variant_utils = mock.MagicMock() +mock_variant_utils.GoogleLLMVariant = mock.MagicMock() + +# Apply mocks at the module level before any imports happen +sys.modules["google.adk.events.event"] = mock_event +sys.modules["google.adk.utils.adk_to_mcp_tool_type"] = mock_adk_to_mcp_tool_type +sys.modules["google.adk.utils.async_utils"] = mock_async_utils +sys.modules["google.adk.utils.date_utils"] = mock_date_utils +sys.modules["google.adk.utils.image_utils"] = mock_image_utils +sys.modules["google.adk.utils.multipart_utils"] = mock_multipart_utils +sys.modules["google.adk.utils.structured_output_utils"] = mock_structured_output_utils +sys.modules["google.adk.utils.truncate_utils"] = mock_truncate_utils +sys.modules["google.adk.utils.uri_utils"] = mock_uri_utils +sys.modules["google.adk.models.llm_response"] = mock_llm_response +sys.modules["google.adk.utils.variant_utils"] = mock_variant_utils + + +# Function to reset mocks after tests +def pytest_unconfigure(config): + """Reset any global state modified by the tests.""" + # Remove our mocked modules + for module_name in [ + "google.adk.events.event", + "google.adk.utils.adk_to_mcp_tool_type", + "google.adk.utils.async_utils", + "google.adk.utils.date_utils", + "google.adk.utils.image_utils", + "google.adk.utils.multipart_utils", + "google.adk.utils.structured_output_utils", + "google.adk.utils.truncate_utils", + "google.adk.utils.uri_utils", + "google.adk.models.llm_response", + "google.adk.utils.variant_utils", + ]: + if module_name in sys.modules: + del sys.modules[module_name] diff --git a/tests/unittests/utils/test_langgraph_utils.py b/tests/unittests/utils/test_langgraph_utils.py new file mode 100644 index 000000000..16290ef5d --- /dev/null +++ b/tests/unittests/utils/test_langgraph_utils.py @@ -0,0 +1,131 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for LangGraph context management utilities.""" + +import pytest +from typing import Dict, Any, List + +# Force the importing of conftest.py which sets up our mock modules +from . 
import conftest + +# Now import the modules we want to test +from google.adk.sessions.context_reference_store import ContextReferenceStore +from google.adk.utils.langgraph_utils import ( + LangGraphContextManager, + create_reference_aware_merge, +) + + +class TestLangGraphContextManager: + """Tests for LangGraphContextManager.""" + + def test_add_to_state(self): + """Test adding content to state.""" + manager = LangGraphContextManager() + original_state = {"counter": 1, "messages": ["Hello"]} + + # Add content to state + content = "This is test content" + new_state = manager.add_to_state(original_state, content) + + # Verify state is updated with reference + assert "context_ref" in new_state + assert new_state["counter"] == 1 # Original data preserved + assert new_state["messages"] == ["Hello"] # Original data preserved + + # Verify original state is not modified + assert "context_ref" not in original_state + + def test_retrieve_from_state(self): + """Test retrieving content from state.""" + manager = LangGraphContextManager() + content = "This is retrievable content" + + # Add content to state + state = manager.add_to_state({}, content) + + # Retrieve content + retrieved = manager.retrieve_from_state(state) + assert content == retrieved + + # Test with custom key + custom_state = manager.add_to_state( + {}, "Custom key content", ref_key="custom_ref" + ) + custom_retrieved = manager.retrieve_from_state(custom_state, "custom_ref") + assert "Custom key content" == custom_retrieved + + def test_reference_aware_merge(self): + """Test reference-aware merge function.""" + # Create merge function + merge_fn = create_reference_aware_merge() + + # Create states with references + context_store = ContextReferenceStore() + ref1 = context_store.store("Content 1") + ref2 = context_store.store("Content 2") + + # Create states to merge + left: Dict[str, Any] = { + "context_ref": ref1, + "counter": 1, + "messages": ["First message"], + } + + right: Dict[str, Any] = { + "context_ref": ref2, # New reference that should replace the old one + "counter": 2, + "messages": ["Second message"], + } + + # Merge states + merged = merge_fn(left, right) + + # Verify merge results + assert merged["context_ref"] == ref2 # Right reference preferred + assert merged["counter"] == 2 # Right value preferred + assert merged["messages"] == ["Second message"] # Right value preferred + + def test_reference_aware_merge_partial_update(self): + """Test reference-aware merge with partial updates.""" + # Create merge function + merge_fn = create_reference_aware_merge() + + # Create states with references + context_store = ContextReferenceStore() + ref1 = context_store.store("Content 1") + + # Create states to merge + left: Dict[str, Any] = { + "context_ref": ref1, + "other_ref": "other-ref-1", + "counter": 1, + "messages": ["First message"], + } + + right: Dict[str, Any] = { + # Note: no context_ref in right state + "counter": 2, + # Note: no messages in right state + } + + # Merge states + merged = merge_fn(left, right) + + # Verify merge results + assert merged["context_ref"] == ref1 # Left reference preserved + assert merged["other_ref"] == "other-ref-1" # Left reference preserved + assert merged["counter"] == 2 # Right value used + assert merged["messages"] == ["First message"] # Left value preserved
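Reviewer note: finally, a small behavioural demo of the merge function these tests exercise. The reference IDs are placeholders rather than real store entries, which is safe because the merge only moves IDs.

```python
from google.adk.utils.langgraph_utils import create_reference_aware_merge

merge = create_reference_aware_merge()

left = {"context_ref": "ref-1", "counter": 1, "messages": ["first"]}
right = {"context_ref": "ref-2", "counter": 2}

# Right-hand values win; keys absent from the right are preserved from the
# left, and the referenced content itself is never touched.
assert merge(left, right) == {
    "context_ref": "ref-2",
    "counter": 2,
    "messages": ["first"],
}
```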