Truncation of message histories and individual messages to prevent context window overflows #246

Open · wants to merge 1 commit into base: main
2 changes: 2 additions & 0 deletions src/mcp_agent/core/agent_types.py
@@ -31,6 +31,8 @@ class AgentConfig(BaseModel):
servers: List[str] = Field(default_factory=list)
model: str | None = None
use_history: bool = True
max_context_length_per_message: int = 100_000
max_context_length_total: int = 1_000_000
default_request_params: RequestParams | None = None
human_input: bool = False
agent_type: AgentType = AgentType.BASIC
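The two new fields cap, respectively, the text length of a single message content piece and the combined text length of the whole history (both in characters, as measured by the truncation logic later in this diff). A minimal construction sketch (hedged: only the fields visible in this hunk are shown, and `name` is assumed to be one of `AgentConfig`'s other fields):

```python
from mcp_agent.core.agent_types import AgentConfig

# Hypothetical values; both limits count characters of text content only.
config = AgentConfig(
    name="summarizer",  # assumed field, not shown in this hunk
    max_context_length_per_message=50_000,   # clip any single text piece to 50k chars
    max_context_length_total=400_000,        # evict oldest messages past 400k chars total
)
```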
18 changes: 18 additions & 0 deletions src/mcp_agent/core/direct_decorators.py
@@ -88,6 +88,8 @@ def _decorator_impl(
servers: List[str] = [],
model: Optional[str] = None,
use_history: bool = True,
max_context_length_per_message: int = 100_000,
max_context_length_total: int = 1_000_000,
request_params: RequestParams | None = None,
human_input: bool = False,
default: bool = False,
@@ -103,6 +105,8 @@
servers: List of server names the agent should connect to
model: Model specification string
use_history: Whether to maintain conversation history
max_context_length_per_message: Maximum text length (in characters) for a single message content piece
max_context_length_total: Maximum combined text length (in characters) of the conversation history
request_params: Additional request parameters for the LLM
human_input: Whether to enable human input capabilities
default: Whether to mark this as the default agent
@@ -133,6 +137,8 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
servers=servers,
model=model,
use_history=use_history,
max_context_length_per_message=max_context_length_per_message,
max_context_length_total=max_context_length_total,
human_input=human_input,
default=default,
)
@@ -175,6 +181,8 @@ def agent(
servers: List[str] = [],
model: Optional[str] = None,
use_history: bool = True,
max_context_length_per_message: int = 100_000,
max_context_length_total: int = 1_000_000,
request_params: RequestParams | None = None,
human_input: bool = False,
default: bool = False,
@@ -189,6 +197,8 @@
servers: List of server names the agent should connect to
model: Model specification string
use_history: Whether to maintain conversation history
max_context_length_per_message: Maximum text length (in characters) for a single message content piece
max_context_length_total: Maximum combined text length (in characters) of the conversation history
request_params: Additional request parameters for the LLM
human_input: Whether to enable human input capabilities
default: Whether to mark this as the default agent
@@ -206,6 +216,8 @@
servers=servers,
model=model,
use_history=use_history,
max_context_length_per_message=max_context_length_per_message,
max_context_length_total=max_context_length_total,
request_params=request_params,
human_input=human_input,
default=default,
@@ -222,6 +234,8 @@ def custom(
servers: List[str] = [],
model: Optional[str] = None,
use_history: bool = True,
max_context_length_per_message: int = 100_000,
max_context_length_total: int = 1_000_000,
request_params: RequestParams | None = None,
human_input: bool = False,
default: bool = False,
@@ -236,6 +250,8 @@ def custom(
servers: List of server names the agent should connect to
model: Model specification string
use_history: Whether to maintain conversation history
max_context_length_per_message: Maximum text length (in characters) for a single message content piece
max_context_length_total: Maximum combined text length (in characters) of the conversation history
request_params: Additional request parameters for the LLM
human_input: Whether to enable human input capabilities

@@ -252,6 +268,8 @@ def custom(
servers=servers,
model=model,
use_history=use_history,
max_context_length_per_message=max_context_length_per_message,
max_context_length_total=max_context_length_total,
request_params=request_params,
human_input=human_input,
agent_class=cls,
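Taken together, these decorator changes let callers set both limits at declaration time. A usage sketch (hedged: the `FastAgent` entry point, its import path, and the send-message call are assumed from the surrounding project, not shown in this diff):

```python
import asyncio

from mcp_agent.core.fastagent import FastAgent  # import path assumed

fast = FastAgent("truncation-demo")

@fast.agent(
    name="digester",
    instruction="Summarize very long documents.",
    max_context_length_per_message=50_000,  # per-piece cap, in characters
    max_context_length_total=400_000,       # whole-history cap, in characters
)
async def main() -> None:
    async with fast.run() as agent:
        await agent.digester.send("Summarize: ...")

asyncio.run(main())
```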
92 changes: 92 additions & 0 deletions src/mcp_agent/llm/augmented_llm.py
@@ -94,6 +94,8 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
PARAM_METADATA = "metadata"
PARAM_USE_HISTORY = "use_history"
PARAM_MAX_ITERATIONS = "max_iterations"
PARAM_MAX_CONTEXT_LENGTH_TOTAL = "max_context_length_total"
PARAM_MAX_CONTEXT_LENGTH_PER_MESSAGE = "max_context_length_per_message"
PARAM_TEMPLATE_VARS = "template_vars"
# Base set of fields that should always be excluded
BASE_EXCLUDE_FIELDS = {PARAM_METADATA}
@@ -642,3 +644,93 @@ def _api_key(self):

assert self.provider
return ProviderKeyManager.get_api_key(self.provider.value, self.context.config)

def _truncate_message_history(self) -> None:
    """
    Truncates the agent's message history so it does not exceed the configured
    maximum context lengths. This method is multimodal-aware and only counts
    text content toward the limits.
    """
    self.logger.debug(f"[_truncate_message_history] Truncating message history for {self.name} LLM.")

    # Debug: log this LLM's attributes, masking anything sensitive.
    for key, value in vars(self).items():
        if any(marker in key for marker in ("secret", "key", "password", "token", "env")):
            value = "******"  # Mask sensitive information
        self.logger.debug(f"[_truncate_message_history] {key}: {value}")

    # Pull out the configurable limits for clarity
    max_len_per_message = self.default_request_params.max_context_length_per_message
    max_total_len = self.default_request_params.max_context_length_total

    if not self._message_history:
        self.logger.debug("[_truncate_message_history] Message history is empty, nothing to truncate.")
        return

    if not max_total_len and not max_len_per_message:
        self.logger.debug("[_truncate_message_history] No truncation limits set, skipping truncation.")
        return

    # Step 1: Truncate individual message content pieces
    if max_len_per_message is not None:
        self.logger.debug(
            f"[_truncate_message_history] Truncating individual message content pieces to max length {max_len_per_message}."
        )
        for message in self._message_history:
            for i, content_piece in enumerate(message.content):
                # Only truncate TextContent; leave images and other non-text content untouched.
                if isinstance(content_piece, TextContent) and content_piece.text:
                    if len(content_piece.text) > max_len_per_message:
                        original_length = len(content_piece.text)
                        truncated_text = content_piece.text[:max_len_per_message]

                        # Re-assign the modified content piece
                        message.content[i] = TextContent(type="text", text=truncated_text)

                        self.logger.debug(
                            f"[_truncate_message_history] Truncated content piece: original_length={original_length}, "
                            f"truncated_length={len(truncated_text)}"
                        )

    # Step 2: Remove the oldest messages while the total context length is exceeded
    if max_total_len is not None:
        self.logger.debug(
            f"[_truncate_message_history] Checking total context length against max limit {max_total_len}."
        )
        while self._message_history:
            # Calculate the total text length, counting only TextContent
            current_total_length = 0
            for message in self._message_history:
                for content_piece in message.content:
                    if isinstance(content_piece, TextContent) and content_piece.text:
                        current_total_length += len(content_piece.text)

            # If we are within the limit, we're done
            if current_total_length <= max_total_len:
                self.logger.info(
                    f"[_truncate_message_history] Total history truncation complete: final_length={current_total_length}, "
                    f"final_message_count={len(self._message_history)}"
                )
                break

            # Over the limit: remove the oldest message and re-check
            self.logger.debug(
                f"[_truncate_message_history] Total context length ({current_total_length}) > max ({max_total_len}). "
                f"Removing oldest message. History size: {len(self._message_history)}"
            )
            self._message_history.pop(0)
        else:
            # This while-else runs only if the loop exits without a break,
            # i.e. the history became empty during truncation.
            self.logger.info("[_truncate_message_history] History became empty during total length truncation.")
5 changes: 5 additions & 0 deletions src/mcp_agent/llm/model_factory.py
@@ -24,6 +24,9 @@

# from mcp_agent.workflows.llm.augmented_llm_deepseek import DeekSeekAugmentedLLM

from mcp_agent.logging.logger import get_logger

logger = get_logger(__name__)

# Type alias for LLM classes
LLMClass = Union[
@@ -206,6 +209,8 @@ def parse_model_string(cls, model_string: str) -> ModelConfig:
f"TensorZero provider requires a function name after the provider "
f"(e.g., tensorzero.my-function), got: {model_string}"
)

logger.debug(f"ModelConfig will be created: {model_name_str} with provider {provider}")

return ModelConfig(
provider=provider, model_name=model_name_str, reasoning_effort=reasoning_effort
102 changes: 102 additions & 0 deletions src/mcp_agent/llm/providers/augmented_llm_google_native.py
@@ -252,6 +252,14 @@ async def _google_completion(
initial_history_length = len(conversation_history)

for i in range(request_params.max_iterations):
    self.logger.debug(f"[_google_completion] Iteration {i + 1} of {request_params.max_iterations}")

    # Truncate the conversation history before each request so it stays within the configured limits.
    conversation_history = self._truncate_message_history(conversation_history)

    # 1. Get available tools
    aggregator_response = await self.aggregator.list_tools()
    available_tools = self._converter.convert_to_google_tools(
@@ -474,3 +482,97 @@ async def post_tool_call(
"""
# Currently a pass-through, can add Google-specific logic if needed
return result

def _truncate_message_history(
    self, conversation_history: List[types.Content]
) -> List[types.Content]:
    """
    Google-specific message history truncation.

    Unlike the base class version, this operates on (and returns) a list of
    native google.genai.types.Content messages instead of mutating
    self._message_history in place.
    """
    self.logger.debug(
        f"[_truncate_message_history] Starting truncation with {len(conversation_history)} messages."
    )

    if not conversation_history:
        self.logger.debug("[_truncate_message_history] conversation_history is empty, nothing to truncate.")
        return conversation_history

    # Content/Part objects are treated as immutable, so it's safer to build new lists.
    truncated_history = list(conversation_history)

    # Debug: log this LLM's attributes, masking anything sensitive.
    for key, value in vars(self).items():
        if any(marker in key for marker in ("secret", "key", "password", "token", "env")):
            value = "******"  # Mask sensitive information
        self.logger.debug(f"[_truncate_message_history] {key}: {value}")

    # Step 1: Truncate individual message parts
    if self.max_context_length_per_message is not None:
        self.logger.debug(
            f"[_truncate_message_history] Truncating individual message parts to max length: "
            f"{self.max_context_length_per_message}, {len(truncated_history)} messages"
        )
        # Build a new list to hold the potentially modified messages
        temp_history = []
        for message in truncated_history:
            new_parts = []
            for part in message.parts:
                # Only truncate parts that carry text; keep all other parts as-is.
                if hasattr(part, "text") and part.text and len(part.text) > self.max_context_length_per_message:
                    truncated_text = part.text[: self.max_context_length_per_message]
                    # Create a new Part with the truncated text
                    new_parts.append(types.Part(text=truncated_text))
                    self.logger.debug(
                        f"[_truncate_message_history] Truncated part: original_length={len(part.text)}, "
                        f"truncated_length={len(truncated_text)}"
                    )
                else:
                    new_parts.append(part)
            # Create a new Content object with the new parts and the original role
            temp_history.append(types.Content(parts=new_parts, role=message.role))
        truncated_history = temp_history

    # Step 2: Remove the oldest messages while the total context length is exceeded
    if self.max_context_length_total is not None:
        self.logger.debug(
            f"[_truncate_message_history] Truncating total context length to max: "
            f"{self.max_context_length_total}, {len(truncated_history)} messages"
        )
        while True:
            current_total_length = sum(
                len(part.text)
                for message in truncated_history
                for part in message.parts
                if hasattr(part, "text") and part.text
            )

            if current_total_length > self.max_context_length_total and truncated_history:
                self.logger.debug(
                    f"[_truncate_message_history] Total context length ({current_total_length}) > "
                    f"max ({self.max_context_length_total}). Removing oldest message. "
                    f"History size: {len(truncated_history)}"
                )
                truncated_history.pop(0)  # Remove the oldest message
            else:
                self.logger.info(
                    f"[_truncate_message_history] Total history truncation complete: "
                    f"final_length={current_total_length}, final_message_count={len(truncated_history)}"
                )
                break

    return truncated_history
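Because `google.genai` `Content` and `Part` objects are treated as immutable here, truncation rebuilds each message instead of editing it in place. A minimal sketch of that rebuild step (assumes the `google-genai` package; constructor usage mirrors the diff above):

```python
from google.genai import types

def clip_text_parts(message: types.Content, max_len: int) -> types.Content:
    """Return a copy of `message` whose text parts are clipped to max_len characters."""
    new_parts = [
        types.Part(text=part.text[:max_len])
        if getattr(part, "text", None) and len(part.text) > max_len
        else part  # non-text or already-short parts pass through unchanged
        for part in message.parts
    ]
    return types.Content(parts=new_parts, role=message.role)

# Example: a 12-character text part clipped to 5 characters.
msg = types.Content(parts=[types.Part(text="hello, world")], role="user")
assert clip_text_parts(msg, 5).parts[0].text == "hello"
```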