Skip to content

Commit 7aee0c3

Browse files
authored
Feat/context windows (#247)
* initial usage tracking * add tracking for gemini * add last turn display, summary on shutdown * add /usage command * dedup usage report * model database v1 * db updates * update model database * tests/fixtures/linter * revert example
1 parent 26c6c7d commit 7aee0c3

22 files changed

+1618
-46
lines changed

.vscode/settings.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
"fastagent.secrets.yaml"
1313
]
1414
},
15-
"editor.fontFamily": "BlexMono Nerd Font",
1615
"python.testing.pytestArgs": ["tests"],
1716
"python.testing.unittestEnabled": false,
1817
"python.testing.pytestEnabled": true,
19-
"python.analysis.typeCheckingMode": "standard"
18+
"python.analysis.typeCheckingMode": "standard",
19+
"python.analysis.nodeExecutable": "auto"
2020
}

src/mcp_agent/agents/base_agent.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
HUMAN_INPUT_TOOL_NAME = "__human_input__"
5959
if TYPE_CHECKING:
6060
from mcp_agent.context import Context
61+
from mcp_agent.llm.usage_tracking import UsageAccumulator
6162

6263

6364
DEFAULT_CAPABILITIES = AgentCapabilities(
@@ -698,3 +699,15 @@ def message_history(self) -> List[PromptMessageMultipart]:
698699
if self._llm:
699700
return self._llm.message_history
700701
return []
702+
703+
@property
704+
def usage_accumulator(self) -> Optional["UsageAccumulator"]:
705+
"""
706+
Return the usage accumulator for tracking token usage across turns.
707+
708+
Returns:
709+
UsageAccumulator object if LLM is attached, None otherwise
710+
"""
711+
if self._llm:
712+
return self._llm.usage_accumulator
713+
return None

src/mcp_agent/core/agent_app.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66

77
from deprecated import deprecated
88
from mcp.types import PromptMessage
9+
from rich import print as rich_print
910

1011
from mcp_agent.agents.agent import Agent
1112
from mcp_agent.core.interactive_prompt import InteractivePrompt
1213
from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
14+
from mcp_agent.progress_display import progress_display
1315

1416

1517
class AgentApp:
@@ -272,7 +274,12 @@ async def interactive(self, agent_name: str | None = None, default_prompt: str =
272274

273275
# Define the wrapper for send function
274276
async def send_wrapper(message, agent_name):
275-
return await self.send(message, agent_name)
277+
result = await self.send(message, agent_name)
278+
279+
# Show usage info after each turn if progress display is enabled
280+
self._show_turn_usage(agent_name)
281+
282+
return result
276283

277284
# Start the prompt loop with the agent name (not the agent object)
278285
return await prompt.prompt_loop(
@@ -282,3 +289,36 @@ async def send_wrapper(message, agent_name):
282289
prompt_provider=self, # Pass self as the prompt provider
283290
default=default_prompt,
284291
)
292+
293+
def _show_turn_usage(self, agent_name: str) -> None:
294+
"""Show subtle usage information after each turn."""
295+
agent = self._agents.get(agent_name)
296+
if not agent or not agent.usage_accumulator:
297+
return
298+
299+
# Get the last turn's usage (if any)
300+
turns = agent.usage_accumulator.turns
301+
if not turns:
302+
return
303+
304+
last_turn = turns[-1]
305+
input_tokens = last_turn.input_tokens
306+
output_tokens = last_turn.output_tokens
307+
308+
# Build cache indicators with bright colors
309+
cache_indicators = ""
310+
if last_turn.cache_usage.cache_write_tokens > 0:
311+
cache_indicators += "[bright_yellow]^[/bright_yellow]"
312+
if last_turn.cache_usage.cache_read_tokens > 0 or last_turn.cache_usage.cache_hit_tokens > 0:
313+
cache_indicators += "[bright_green]*[/bright_green]"
314+
315+
# Build context percentage - get from accumulator, not individual turn
316+
context_info = ""
317+
context_percentage = agent.usage_accumulator.context_usage_percentage
318+
if context_percentage is not None:
319+
context_info = f" ({context_percentage:.1f}%)"
320+
321+
# Show subtle usage line - pause progress display to ensure visibility
322+
with progress_display.paused():
323+
cache_suffix = f" {cache_indicators}" if cache_indicators else ""
324+
rich_print(f"[dim]Last turn: {input_tokens:,} Input, {output_tokens:,} Output{context_info}[/dim]{cache_suffix}")

src/mcp_agent/core/enhanced_prompt.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def __init__(
5858
"prompts": "List and select MCP prompts", # Changed description
5959
"prompt": "Apply a specific prompt by name (/prompt <name>)", # New command
6060
"agents": "List available agents",
61+
"usage": "Show current usage statistics",
6162
"clear": "Clear the screen",
6263
"STOP": "Stop this prompting session and move to next workflow step",
6364
"EXIT": "Exit fast-agent, terminating any running workflows",
@@ -67,6 +68,7 @@ def __init__(
6768
self.commands.pop("agents")
6869
self.commands.pop("prompts") # Remove prompts command in human input mode
6970
self.commands.pop("prompt", None) # Remove prompt command in human input mode
71+
self.commands.pop("usage", None) # Remove usage command in human input mode
7072
self.agent_types = agent_types or {}
7173

7274
def get_completions(self, document, complete_event):
@@ -390,6 +392,8 @@ def pre_process_input(text):
390392
return "CLEAR"
391393
elif cmd == "agents":
392394
return "LIST_AGENTS"
395+
elif cmd == "usage":
396+
return "SHOW_USAGE"
393397
elif cmd == "prompts":
394398
# Return a dictionary with select_prompt action instead of a string
395399
# This way it will match what the command handler expects
@@ -566,6 +570,7 @@ async def handle_special_commands(command, agent_app=None):
566570
rich_print(" /agents - List available agents")
567571
rich_print(" /prompts - List and select MCP prompts")
568572
rich_print(" /prompt <name> - Apply a specific prompt by name")
573+
rich_print(" /usage - Show current usage statistics")
569574
rich_print(" @agent_name - Switch to agent")
570575
rich_print(" STOP - Return control back to the workflow")
571576
rich_print(" EXIT - Exit fast-agent, terminating any running workflows")
@@ -594,6 +599,10 @@ async def handle_special_commands(command, agent_app=None):
594599
rich_print("[yellow]No agents available[/yellow]")
595600
return True
596601

602+
elif command == "SHOW_USAGE":
603+
# Return a dictionary to signal that usage should be shown
604+
return {"show_usage": True}
605+
597606
elif command == "SELECT_PROMPT" or (
598607
isinstance(command, str) and command.startswith("SELECT_PROMPT:")
599608
):

src/mcp_agent/core/fastagent.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
ServerConfigError,
5555
ServerInitializationError,
5656
)
57+
from mcp_agent.core.usage_display import display_usage_report
5758
from mcp_agent.core.validation import (
5859
validate_server_references,
5960
validate_workflow_references,
@@ -392,22 +393,29 @@ def model_factory_func(model=None, request_params=None):
392393

393394
yield wrapper
394395

396+
except PromptExitError as e:
397+
# User requested exit - not an error, show usage report
398+
self._handle_error(e)
399+
raise SystemExit(0)
395400
except (
396401
ServerConfigError,
397402
ProviderKeyError,
398403
AgentConfigError,
399404
ServerInitializationError,
400405
ModelConfigError,
401406
CircularDependencyError,
402-
PromptExitError,
403407
) as e:
404408
had_error = True
405409
self._handle_error(e)
406410
raise SystemExit(1)
407411

408412
finally:
409-
# Clean up any active agents
413+
# Print usage report before cleanup (show for user exits too)
410414
if active_agents and not had_error:
415+
self._print_usage_report(active_agents)
416+
417+
# Clean up any active agents (always cleanup, even on errors)
418+
if active_agents:
411419
for agent in active_agents.values():
412420
try:
413421
await agent.shutdown()
@@ -472,6 +480,10 @@ def _handle_error(self, e: Exception, error_type: Optional[str] = None) -> None:
472480
else:
473481
handle_error(e, error_type or "Error", "An unexpected error occurred.")
474482

483+
def _print_usage_report(self, active_agents: dict) -> None:
484+
"""Print a formatted table of token usage for all agents."""
485+
display_usage_report(active_agents, show_if_progress_disabled=False, subdued_colors=True)
486+
475487
async def start_server(
476488
self,
477489
transport: str = "sse",

src/mcp_agent/core/interactive_prompt.py

Lines changed: 59 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,22 +28,34 @@
2828
get_selection_input,
2929
handle_special_commands,
3030
)
31+
from mcp_agent.core.usage_display import collect_agents_from_provider, display_usage_report
3132
from mcp_agent.mcp.mcp_aggregator import SEP # Import SEP once at the top
3233
from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
3334
from mcp_agent.progress_display import progress_display
3435

3536
# Type alias for the send function
3637
SendFunc = Callable[[Union[str, PromptMessage, PromptMessageMultipart], str], Awaitable[str]]
3738

39+
# Type alias for the agent getter function
40+
AgentGetter = Callable[[str], Optional[object]]
41+
3842

3943
class PromptProvider(Protocol):
4044
"""Protocol for objects that can provide prompt functionality."""
41-
42-
async def list_prompts(self, server_name: Optional[str] = None, agent_name: Optional[str] = None) -> Mapping[str, List[Prompt]]:
45+
46+
async def list_prompts(
47+
self, server_name: Optional[str] = None, agent_name: Optional[str] = None
48+
) -> Mapping[str, List[Prompt]]:
4349
"""List available prompts."""
4450
...
45-
46-
async def apply_prompt(self, prompt_name: str, arguments: Optional[Dict[str, str]] = None, agent_name: Optional[str] = None, **kwargs) -> str:
51+
52+
async def apply_prompt(
53+
self,
54+
prompt_name: str,
55+
arguments: Optional[Dict[str, str]] = None,
56+
agent_name: Optional[str] = None,
57+
**kwargs,
58+
) -> str:
4759
"""Apply a prompt."""
4860
...
4961

@@ -160,17 +172,23 @@ async def prompt_loop(
160172
await self._list_prompts(prompt_provider, agent)
161173
else:
162174
# Use the name-based selection
163-
await self._select_prompt(
164-
prompt_provider, agent, prompt_name
165-
)
175+
await self._select_prompt(prompt_provider, agent, prompt_name)
176+
continue
177+
elif "show_usage" in command_result:
178+
# Handle usage display
179+
await self._show_usage(prompt_provider, agent)
166180
continue
167181

168182
# Skip further processing if:
169183
# 1. The command was handled (command_result is truthy)
170184
# 2. The original input was a dictionary (special command like /prompt)
171185
# 3. The command result itself is a dictionary (special command handling result)
172186
# This fixes the issue where /prompt without arguments gets sent to the LLM
173-
if command_result or isinstance(user_input, dict) or isinstance(command_result, dict):
187+
if (
188+
command_result
189+
or isinstance(user_input, dict)
190+
or isinstance(command_result, dict)
191+
):
174192
continue
175193

176194
if user_input.upper() == "STOP":
@@ -183,7 +201,9 @@ async def prompt_loop(
183201

184202
return result
185203

186-
async def _get_all_prompts(self, prompt_provider: PromptProvider, agent_name: Optional[str] = None):
204+
async def _get_all_prompts(
205+
self, prompt_provider: PromptProvider, agent_name: Optional[str] = None
206+
):
187207
"""
188208
Get a list of all available prompts.
189209
@@ -196,8 +216,10 @@ async def _get_all_prompts(self, prompt_provider: PromptProvider, agent_name: Op
196216
"""
197217
try:
198218
# Call list_prompts on the provider
199-
prompt_servers = await prompt_provider.list_prompts(server_name=None, agent_name=agent_name)
200-
219+
prompt_servers = await prompt_provider.list_prompts(
220+
server_name=None, agent_name=agent_name
221+
)
222+
201223
all_prompts = []
202224

203225
# Process the returned prompt servers
@@ -326,9 +348,11 @@ async def _select_prompt(
326348
try:
327349
# Get all available prompts directly from the prompt provider
328350
rich_print(f"\n[bold]Fetching prompts for agent [cyan]{agent_name}[/cyan]...[/bold]")
329-
351+
330352
# Call list_prompts on the provider
331-
prompt_servers = await prompt_provider.list_prompts(server_name=None, agent_name=agent_name)
353+
prompt_servers = await prompt_provider.list_prompts(
354+
server_name=None, agent_name=agent_name
355+
)
332356

333357
if not prompt_servers:
334358
rich_print("[yellow]No prompts available for this agent[/yellow]")
@@ -557,3 +581,25 @@ async def _select_prompt(
557581

558582
rich_print(f"[red]Error selecting or applying prompt: {e}[/red]")
559583
rich_print(f"[dim]{traceback.format_exc()}[/dim]")
584+
585+
async def _show_usage(self, prompt_provider: PromptProvider, agent_name: str) -> None:
586+
"""
587+
Show usage statistics for the current agent(s) in a colorful table format.
588+
589+
Args:
590+
prompt_provider: Provider that has access to agents
591+
agent_name: Name of the current agent
592+
"""
593+
try:
594+
# Collect all agents from the prompt provider
595+
agents_to_show = collect_agents_from_provider(prompt_provider, agent_name)
596+
597+
if not agents_to_show:
598+
rich_print("[yellow]No usage data available[/yellow]")
599+
return
600+
601+
# Use the shared display utility
602+
display_usage_report(agents_to_show, show_if_progress_disabled=True)
603+
604+
except Exception as e:
605+
rich_print(f"[red]Error showing usage: {e}[/red]")

0 commit comments

Comments
 (0)