From 7348da419ec136f9311f0a6c591d2e604c00c31d Mon Sep 17 00:00:00 2001 From: Kiran Gopinathan Date: Fri, 13 Mar 2026 11:09:33 -0400 Subject: [PATCH 1/3] enable prompt caching for agent calls --- effectful/handlers/llm/completions.py | 48 ++++++++- tests/test_handlers_llm_provider.py | 146 ++++++++++++++++++++++++++ tests/test_handlers_llm_template.py | 4 +- 3 files changed, 196 insertions(+), 2 deletions(-) diff --git a/effectful/handlers/llm/completions.py b/effectful/handlers/llm/completions.py index fc6ca47a..3a2cfd87 100644 --- a/effectful/handlers/llm/completions.py +++ b/effectful/handlers/llm/completions.py @@ -159,6 +159,34 @@ def to_feedback_message(self, include_traceback: bool) -> Message: type MessageResult[T] = tuple[Message, typing.Sequence[DecodedToolCall], T | None] +CACHE_CONTROL_EPHEMERAL = {"type": "ephemeral"} + + +def _add_cache_control_to_history( + history: collections.OrderedDict[str, "Message"], +) -> None: + """Add cache_control to the last user/tool message in an agent's history. + + This enables prompt caching on providers that support it (e.g. Anthropic). + Providers that don't support it (e.g. OpenAI) have cache_control stripped + by litellm's request transformation, so this is always safe to apply. + + Mutates the history OrderedDict in place. 
+ """ + if not history: + return + key = next(reversed(history)) + msg = history[key] + if msg["role"] not in ("user", "tool"): + return + content = msg.get("content") + if isinstance(content, list) and content: + last_block = content[-1] + if isinstance(last_block, dict) and "cache_control" not in last_block: + new_content = list(content) + new_content[-1] = {**last_block, "cache_control": CACHE_CONTROL_EPHEMERAL} + history[key] = typing.cast(Message, {**msg, "content": new_content}) + @Operation.define @functools.wraps(litellm.completion) @@ -326,7 +354,18 @@ def flush_text() -> None: def call_system(template: Template) -> Message: """Get system instruction message(s) to prepend to all LLM prompts.""" system_prompt = template.__system_prompt__ or DEFAULT_SYSTEM_PROMPT - message = _make_message(dict(role="system", content=system_prompt)) + message = _make_message( + dict( + role="system", + content=[ + { + "type": "text", + "text": system_prompt, + "cache_control": {"type": "ephemeral"}, + } + ], + ) + ) try: history: collections.OrderedDict[str, Message] = _get_history() if any(m["role"] == "system" for m in history.values()): @@ -467,6 +506,7 @@ def _call[**P, T]( history: collections.OrderedDict[str, Message] = getattr( template, "__history__", collections.OrderedDict() ) # type: ignore + is_agent = hasattr(template, "__history__") history_copy = history.copy() with handler({_get_history: lambda: history_copy}): @@ -474,6 +514,12 @@ def _call[**P, T]( message: Message = call_user(template.__prompt_template__, env) + # For agents with persistent history, add cache_control to the + # last user message so the growing prefix gets cached on providers + # that support it (Anthropic). litellm strips it for OpenAI. 
+ if is_agent: + _add_cache_control_to_history(history_copy) + # loop based on: https://cookbook.openai.com/examples/reasoning_function_calls tool_calls: list[DecodedToolCall] = [] result: T | None = None diff --git a/tests/test_handlers_llm_provider.py b/tests/test_handlers_llm_provider.py index eec2fc78..eb9e3115 100644 --- a/tests/test_handlers_llm_provider.py +++ b/tests/test_handlers_llm_provider.py @@ -2196,3 +2196,149 @@ def _completion(self, model, messages=None, **kwargs): assert messages[0]["role"] == "system", ( "System message should be the first message in history" ) + + +# ============================================================================ +# Prompt Caching Tests +# ============================================================================ + + +def _has_cache_control(msg: dict) -> bool: + """Check if a message dict contains cache_control in any content block.""" + content = msg.get("content") + if isinstance(content, list): + return any(isinstance(b, dict) and "cache_control" in b for b in content) + return False + + +class CachingAgent(Agent): + """A test agent with persistent history.""" + + @Template.define + def ask(self, question: str) -> str: + """You are a helpful assistant. Answer concisely: {question}""" + raise NotHandled + + +class TestPromptCaching: + """Tests that cache_control is present in messages sent to litellm.""" + + def test_system_message_has_cache_control(self): + """System message should include cache_control for prompt caching.""" + capture = MockCompletionHandler([make_text_response("42")]) + provider = LiteLLMProvider(model="test") + + with handler(provider), handler(capture): + simple_prompt("test") + + msgs = capture.received_messages[0] + system_msgs = [m for m in msgs if m["role"] == "system"] + assert len(system_msgs) == 1 + assert _has_cache_control(system_msgs[0]), ( + f"System message should have cache_control. 
Got: {system_msgs[0]}" + ) + + def test_agent_user_message_has_cache_control(self): + """Agent calls should add cache_control to the last user message.""" + capture = MockCompletionHandler([make_text_response("42")]) + provider = LiteLLMProvider(model="test") + agent = CachingAgent() + + with handler(provider), handler(capture): + agent.ask("What is 2+2?") + + msgs = capture.received_messages[0] + user_msgs = [m for m in msgs if m["role"] == "user"] + assert len(user_msgs) == 1 + content = user_msgs[0]["content"] + assert isinstance(content, list) + assert "cache_control" in content[-1], ( + f"Agent user message should have cache_control. Got: {content[-1]}" + ) + + def test_non_agent_user_message_no_cache_control(self): + """Non-agent calls should NOT add cache_control to user messages.""" + capture = MockCompletionHandler([make_text_response("42")]) + provider = LiteLLMProvider(model="test") + + with handler(provider), handler(capture): + simple_prompt("test") + + msgs = capture.received_messages[0] + user_msgs = [m for m in msgs if m["role"] == "user"] + content = user_msgs[0]["content"] + assert isinstance(content, list) + assert "cache_control" not in content[-1], ( + "Non-agent user messages should NOT have cache_control" + ) + + def test_cache_control_format_is_ephemeral(self): + """cache_control should use the ephemeral type.""" + capture = MockCompletionHandler([make_text_response("42")]) + provider = LiteLLMProvider(model="test") + + with handler(provider), handler(capture): + simple_prompt("test") + + for msg in capture.received_messages[0]: + content = msg.get("content") + if isinstance(content, list): + for block in content: + if isinstance(block, dict) and "cache_control" in block: + assert block["cache_control"] == {"type": "ephemeral"} + + def test_litellm_strips_cache_control_for_openai(self): + """Verify litellm strips cache_control when transforming for OpenAI.""" + from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig + + msgs 
= [ + { + "role": "system", + "content": [ + { + "type": "text", + "text": "Hi.", + "cache_control": {"type": "ephemeral"}, + } + ], + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Hi", + "cache_control": {"type": "ephemeral"}, + } + ], + }, + ] + config = OpenAIGPTConfig() + transformed = config.transform_request( + model="gpt-4o", + messages=msgs, + optional_params={}, + litellm_params={}, + headers={}, + ) + for msg in transformed["messages"]: + content = msg.get("content") + if isinstance(content, list): + for block in content: + assert "cache_control" not in block + + @requires_openai + def test_openai_accepts_cache_control_via_litellm(self): + """OpenAI works fine with cache_control (litellm strips it).""" + provider = LiteLLMProvider(model="gpt-4o-mini") + with handler(provider): + result = simple_prompt("math") + assert isinstance(result, str) + + @requires_anthropic + def test_anthropic_accepts_cache_control(self): + """Anthropic should accept messages with cache_control.""" + provider = LiteLLMProvider(model="claude-opus-4-6", max_tokens=20) + with handler(provider): + result = simple_prompt("math") + assert isinstance(result, str) diff --git a/tests/test_handlers_llm_template.py b/tests/test_handlers_llm_template.py index 7c3bd5bc..7e792a98 100644 --- a/tests/test_handlers_llm_template.py +++ b/tests/test_handlers_llm_template.py @@ -464,7 +464,9 @@ def standalone(topic: str) -> str: standalone("fish") assert_single_system_message_first(mock.received_messages[0]) - assert mock.received_messages[0][0]["content"] == DEFAULT_SYSTEM_PROMPT + content = mock.received_messages[0][0]["content"] + # System message content is now a list of blocks with cache_control + assert content[0]["text"] == DEFAULT_SYSTEM_PROMPT class TestAgentDocstringFallback: From 369d85bff7bf4b7a308d64fe34c908c67880b1e0 Mon Sep 17 00:00:00 2001 From: Kiran Gopinathan Date: Fri, 13 Mar 2026 11:51:11 -0400 Subject: [PATCH 2/3] assistant messages caching? 
--- effectful/handlers/llm/completions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/effectful/handlers/llm/completions.py b/effectful/handlers/llm/completions.py index 3a2cfd87..cce19ca2 100644 --- a/effectful/handlers/llm/completions.py +++ b/effectful/handlers/llm/completions.py @@ -177,7 +177,7 @@ def _add_cache_control_to_history( return key = next(reversed(history)) msg = history[key] - if msg["role"] not in ("user", "tool"): + if msg["role"] not in ("user", "tool", "assistant"): return content = msg.get("content") if isinstance(content, list) and content: From 716d87323a12c9ff94a0afc256c6803c8cdfca17 Mon Sep 17 00:00:00 2001 From: Kiran Gopinathan Date: Fri, 13 Mar 2026 12:11:18 -0400 Subject: [PATCH 3/3] add cache_control to every user/tool/assistant message in agent history --- effectful/handlers/llm/completions.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/effectful/handlers/llm/completions.py b/effectful/handlers/llm/completions.py index cce19ca2..58decb10 100644 --- a/effectful/handlers/llm/completions.py +++ b/effectful/handlers/llm/completions.py @@ -175,17 +175,20 @@ def _add_cache_control_to_history( """ if not history: return - key = next(reversed(history)) - msg = history[key] - if msg["role"] not in ("user", "tool", "assistant"): - return - content = msg.get("content") - if isinstance(content, list) and content: - last_block = content[-1] - if isinstance(last_block, dict) and "cache_control" not in last_block: - new_content = list(content) - new_content[-1] = {**last_block, "cache_control": CACHE_CONTROL_EPHEMERAL} - history[key] = typing.cast(Message, {**msg, "content": new_content}) + for key in history: + msg = history[key] + if msg["role"] not in ("user", "tool", "assistant"): + continue + content = msg.get("content") + if isinstance(content, list) and content: + last_block = content[-1] + if isinstance(last_block, dict) and "cache_control" not in last_block: + new_content = list(content) + new_content[-1] = { +
**last_block, + "cache_control": CACHE_CONTROL_EPHEMERAL, + } + history[key] = typing.cast(Message, {**msg, "content": new_content}) @Operation.define