Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion effectful/handlers/llm/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,37 @@ def to_feedback_message(self, include_traceback: bool) -> Message:

type MessageResult[T] = tuple[Message, typing.Sequence[DecodedToolCall], T | None]

CACHE_CONTROL_EPHEMERAL = {"type": "ephemeral"}


def _add_cache_control_to_history(
history: collections.OrderedDict[str, "Message"],
) -> None:
"""Add cache_control to the last user/tool message in an agent's history.

This enables prompt caching on providers that support it (e.g. Anthropic).
Providers that don't support it (e.g. OpenAI) have cache_control stripped
by litellm's request transformation, so this is always safe to apply.

Mutates the history OrderedDict in place.
"""
if not history:
return
for key in history:
msg = history[key]
if msg["role"] not in ("user", "tool", "assistant"):
continue
content = msg.get("content")
if isinstance(content, list) and content:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see this applies to all Messages. Would it be easier to have this live in _make_message, our Message constructor? Or maybe in completions?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that makes more sense, will update.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, no — I guess the key difference is that we only apply this if the template has a history, so we don't cache every template, only agent ones.

last_block = content[-1]
if isinstance(last_block, dict) and "cache_control" not in last_block:
new_content = list(content)
new_content[-1] = {
**last_block,
"cache_control": CACHE_CONTROL_EPHEMERAL,
}
history[key] = typing.cast(Message, {**msg, "content": new_content})


@Operation.define
@functools.wraps(litellm.completion)
Expand Down Expand Up @@ -326,7 +357,18 @@ def flush_text() -> None:
def call_system(template: Template) -> Message:
"""Get system instruction message(s) to prepend to all LLM prompts."""
system_prompt = template.__system_prompt__ or DEFAULT_SYSTEM_PROMPT
message = _make_message(dict(role="system", content=system_prompt))
message = _make_message(
dict(
role="system",
content=[
{
"type": "text",
"text": system_prompt,
"cache_control": {"type": "ephemeral"},
}
],
)
)
try:
history: collections.OrderedDict[str, Message] = _get_history()
if any(m["role"] == "system" for m in history.values()):
Expand Down Expand Up @@ -467,13 +509,20 @@ def _call[**P, T](
history: collections.OrderedDict[str, Message] = getattr(
template, "__history__", collections.OrderedDict()
) # type: ignore
is_agent = hasattr(template, "__history__")
history_copy = history.copy()

with handler({_get_history: lambda: history_copy}):
call_system(template)

message: Message = call_user(template.__prompt_template__, env)

# For agents with persistent history, add cache_control to the
# last user message so the growing prefix gets cached on providers
# that support it (Anthropic). litellm strips it for OpenAI.
if is_agent:
_add_cache_control_to_history(history_copy)

# loop based on: https://cookbook.openai.com/examples/reasoning_function_calls
tool_calls: list[DecodedToolCall] = []
result: T | None = None
Expand Down
146 changes: 146 additions & 0 deletions tests/test_handlers_llm_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -2196,3 +2196,149 @@ def _completion(self, model, messages=None, **kwargs):
assert messages[0]["role"] == "system", (
"System message should be the first message in history"
)


# ============================================================================
# Prompt Caching Tests
# ============================================================================


def _has_cache_control(msg: dict) -> bool:
"""Check if a message dict contains cache_control in any content block."""
content = msg.get("content")
if isinstance(content, list):
return any(isinstance(b, dict) and "cache_control" in b for b in content)
return False


class CachingAgent(Agent):
    """A test agent with persistent history."""

    # NOTE(review): the method docstring below doubles as the runtime prompt
    # template ({question} is interpolated into it), so its text is behavior,
    # not documentation — do not reword it. The class docstring may likewise
    # serve as a system-prompt fallback (see TestAgentDocstringFallback), so
    # it is also left byte-identical.
    @Template.define
    def ask(self, question: str) -> str:
        """You are a helpful assistant. Answer concisely: {question}"""
        raise NotHandled


class TestPromptCaching:
    """Tests that cache_control is present in messages sent to litellm."""

    def test_system_message_has_cache_control(self):
        """System message should include cache_control for prompt caching."""
        # MockCompletionHandler captures the exact message payloads handed to
        # litellm, letting us inspect them without a network call.
        capture = MockCompletionHandler([make_text_response("42")])
        provider = LiteLLMProvider(model="test")

        with handler(provider), handler(capture):
            simple_prompt("test")

        msgs = capture.received_messages[0]
        system_msgs = [m for m in msgs if m["role"] == "system"]
        # Exactly one system message is expected per request.
        assert len(system_msgs) == 1
        assert _has_cache_control(system_msgs[0]), (
            f"System message should have cache_control. Got: {system_msgs[0]}"
        )

    def test_agent_user_message_has_cache_control(self):
        """Agent calls should add cache_control to the last user message."""
        capture = MockCompletionHandler([make_text_response("42")])
        provider = LiteLLMProvider(model="test")
        # Agents carry persistent __history__, which is what triggers the
        # cache_control marking in _call.
        agent = CachingAgent()

        with handler(provider), handler(capture):
            agent.ask("What is 2+2?")

        msgs = capture.received_messages[0]
        user_msgs = [m for m in msgs if m["role"] == "user"]
        assert len(user_msgs) == 1
        content = user_msgs[0]["content"]
        # Content must be block-structured for cache_control to attach.
        assert isinstance(content, list)
        assert "cache_control" in content[-1], (
            f"Agent user message should have cache_control. Got: {content[-1]}"
        )

    def test_non_agent_user_message_no_cache_control(self):
        """Non-agent calls should NOT add cache_control to user messages."""
        capture = MockCompletionHandler([make_text_response("42")])
        provider = LiteLLMProvider(model="test")

        with handler(provider), handler(capture):
            # simple_prompt is a plain template (no __history__), so the
            # agent-only caching path must not fire.
            simple_prompt("test")

        msgs = capture.received_messages[0]
        user_msgs = [m for m in msgs if m["role"] == "user"]
        content = user_msgs[0]["content"]
        assert isinstance(content, list)
        assert "cache_control" not in content[-1], (
            "Non-agent user messages should NOT have cache_control"
        )

    def test_cache_control_format_is_ephemeral(self):
        """cache_control should use the ephemeral type."""
        capture = MockCompletionHandler([make_text_response("42")])
        provider = LiteLLMProvider(model="test")

        with handler(provider), handler(capture):
            simple_prompt("test")

        # Every cache_control marker anywhere in the request must be the
        # ephemeral form — the only type Anthropic currently accepts.
        for msg in capture.received_messages[0]:
            content = msg.get("content")
            if isinstance(content, list):
                for block in content:
                    if isinstance(block, dict) and "cache_control" in block:
                        assert block["cache_control"] == {"type": "ephemeral"}

    def test_litellm_strips_cache_control_for_openai(self):
        """Verify litellm strips cache_control when transforming for OpenAI."""
        # Calls litellm's OpenAI request transformation directly — no network.
        from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig

        msgs = [
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": "Hi.",
                        "cache_control": {"type": "ephemeral"},
                    }
                ],
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "Hi",
                        "cache_control": {"type": "ephemeral"},
                    }
                ],
            },
        ]
        config = OpenAIGPTConfig()
        transformed = config.transform_request(
            model="gpt-4o",
            messages=msgs,
            optional_params={},
            litellm_params={},
            headers={},
        )
        # The OpenAI-bound payload must no longer carry cache_control —
        # this is what makes always-on marking safe across providers.
        for msg in transformed["messages"]:
            content = msg.get("content")
            if isinstance(content, list):
                for block in content:
                    assert "cache_control" not in block

    @requires_openai
    def test_openai_accepts_cache_control_via_litellm(self):
        """OpenAI works fine with cache_control (litellm strips it)."""
        # Live end-to-end smoke test; only runs when OpenAI creds are present.
        provider = LiteLLMProvider(model="gpt-4o-mini")
        with handler(provider):
            result = simple_prompt("math")
        assert isinstance(result, str)

    @requires_anthropic
    def test_anthropic_accepts_cache_control(self):
        """Anthropic should accept messages with cache_control."""
        # Live smoke test against Anthropic.
        # NOTE(review): confirm "claude-opus-4-6" is a valid model id —
        # it does not match Anthropic's published naming pattern.
        provider = LiteLLMProvider(model="claude-opus-4-6", max_tokens=20)
        with handler(provider):
            result = simple_prompt("math")
        assert isinstance(result, str)
4 changes: 3 additions & 1 deletion tests/test_handlers_llm_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,9 @@ def standalone(topic: str) -> str:
standalone("fish")

assert_single_system_message_first(mock.received_messages[0])
assert mock.received_messages[0][0]["content"] == DEFAULT_SYSTEM_PROMPT
content = mock.received_messages[0][0]["content"]
# System message content is now a list of blocks with cache_control
assert content[0]["text"] == DEFAULT_SYSTEM_PROMPT


class TestAgentDocstringFallback:
Expand Down
Loading