diff --git a/backend/openedx_ai_extensions/processors/llm/llm_processor.py b/backend/openedx_ai_extensions/processors/llm/llm_processor.py index f9bd1c48..c90b0b51 100644 --- a/backend/openedx_ai_extensions/processors/llm/llm_processor.py +++ b/backend/openedx_ai_extensions/processors/llm/llm_processor.py @@ -12,7 +12,11 @@ from openedx_ai_extensions.functions.decorators import AVAILABLE_TOOLS from openedx_ai_extensions.processors.llm.litellm_base_processor import LitellmProcessor -from openedx_ai_extensions.processors.llm.providers import adapt_to_provider, after_tool_call_adaptations +from openedx_ai_extensions.processors.llm.providers import ( + adapt_to_provider, + after_tool_call_adaptations, + provider_supports, +) from openedx_ai_extensions.processors.llm.tool_executor import ToolExecutor from openedx_ai_extensions.utils import STREAMING_FAILED_MESSAGE, normalize_input_to_text @@ -173,8 +177,7 @@ def _call_responses_wrapper(self, params, initialize=False, system_role=None): response_id = getattr(response, "id", None) content = self._extract_response_content(response=response) - # Update session with response ID for threading - if response_id: + if response_id and provider_supports(self.provider, "server_side_thread_id"): self.user_session.remote_response_id = response_id self.user_session.save() diff --git a/backend/openedx_ai_extensions/processors/llm/providers/__init__.py b/backend/openedx_ai_extensions/processors/llm/providers/__init__.py index c0f7352c..cbacdbd0 100644 --- a/backend/openedx_ai_extensions/processors/llm/providers/__init__.py +++ b/backend/openedx_ai_extensions/processors/llm/providers/__init__.py @@ -1,6 +1,36 @@ """ Provider-specific quirks and adaptations for different LLM providers. """ +import logging + +logger = logging.getLogger(__name__) + +_PROVIDER_CAPABILITIES = { + "openai": { + # Provider stores conversation history server-side and returns a response ID. + # Subsequent turns send only the new user message + that ID; the provider + # reconstructs context itself. Without this, full history is fetched from local + # storage and sent on every request. + # Affects: adapt_to_provider (sets previous_response_id, replaces input with new + # user message only), after_tool_call_adaptations (persists new response ID), + # _call_responses_wrapper (skips saving remote_response_id for providers without it). + "server_side_thread_id", + }, + "anthropic": { + # Provider supports prompt caching via cache_control on content blocks. Cached + # prefixes are reused at ~10% of normal input token cost for a 5-minute window. + # Without this, every request pays full price for system context and history. + # Two breakpoints per request: last system message (stable course context) and last + # user message (becomes the lookback target for the next turn). See ADR 0010. + # Affects: adapt_to_provider (_apply_multi_turn_cache). + "multi_turn_cache", + }, +} + + +def provider_supports(provider, capability): + """Return True if the given provider supports the named capability.""" + return capability in _PROVIDER_CAPABILITIES.get(provider, set()) # TODO: refactor this module to make it more extensible for future providers @@ -31,8 +61,7 @@ def adapt_to_provider( Returns: dict: Modified parameters with provider-specific adaptations applied """ - if provider == "openai": - # OpenAI supports threading via previous_response_id + if provider_supports(provider, "server_side_thread_id"): if user_session and user_session.remote_response_id and input_data: params["previous_response_id"] = user_session.remote_response_id if "input" in params: @@ -56,7 +85,7 @@ def adapt_to_provider( elif "messages" in params: params["messages"].append({"role": "user", "content": user_prompt}) - if provider != "openai" and params.get("stream") and "input" in params: + if not provider_supports(provider, "server_side_thread_id") and params.get("stream") and "input" in params: # Non-OpenAI providers: convert Responses API shape → Completion API # shape so that completion() / _completion_with_tools() can be called # directly, ensuring tool-call events are visible during streaming. @@ -64,9 +93,57 @@ def adapt_to_provider( for key in ("previous_response_id", "store", "truncation"): params.pop(key, None) + if provider_supports(provider, "multi_turn_cache"): + key = "messages" if "messages" in params else "input" + if key in params: + params[key] = _apply_multi_turn_cache(params[key]) + return params +def _apply_multi_turn_cache(messages): + """ + Add Anthropic style cache_control breakpoints to the last system and last user messages. + + Two breakpoints are sufficient for any conversation length: + - Last system message: stable across all turns (course context never changes). + - Last user message: becomes the lookback target for the next turn. The 20-block + lookback window finds the previous turn's cache entry within 2 steps (one + assistant + one user block), so no additional breakpoints are needed regardless + of conversation length. + + History is always stored as plain strings (get_full_message_history filters out + non-string content), so this transformation is request-only and never persisted. + """ + last_system_idx = None + last_user_idx = None + for i, msg in enumerate(messages): + role = msg.get("role") + if role == "system": + last_system_idx = i + elif role == "user": + last_user_idx = i + + result = list(messages) + for idx in (last_system_idx, last_user_idx): + if idx is None: + continue + msg = result[idx] + content = msg.get("content", "") + if isinstance(content, str): + result[idx] = { + **msg, + "content": [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}], + } + else: + logger.warning( + "multi_turn_cache: skipping cache_control on role=%r message at index %d " + "— content is %s, not a plain string. Cache breakpoint will be missing for this turn.", + msg.get("role"), idx, type(content).__name__, + ) + return result + + def after_tool_call_adaptations(provider, params, data=None): """ Apply provider-specific modifications to API call parameters after tool calls. @@ -80,7 +157,7 @@ def after_tool_call_adaptations(provider, params, data=None): Returns: dict: Modified parameters with provider-specific adaptations applied """ - if provider == "openai": + if provider_supports(provider, "server_side_thread_id"): if data and hasattr(data, "id"): params["previous_response_id"] = data.id diff --git a/backend/openedx_ai_extensions/processors/openedx/submission_processor.py b/backend/openedx_ai_extensions/processors/openedx/submission_processor.py index d30d5dc9..03c85f09 100644 --- a/backend/openedx_ai_extensions/processors/openedx/submission_processor.py +++ b/backend/openedx_ai_extensions/processors/openedx/submission_processor.py @@ -10,6 +10,19 @@ logger = logging.getLogger(__name__) +# Filter constants for _process_messages / get_full_message_history. +# Pass a subset (or frozenset()) to control what gets excluded from results. +# +# FILTER_SYSTEM — exclude messages with role == "system" +# FILTER_NON_STRING_CONTENT — exclude role-based messages whose content is +# not a plain non-empty string (e.g. block-format +# content from cache transforms, tool outputs) +# +FILTER_SYSTEM = "system" +FILTER_NON_STRING_CONTENT = "non_string_content" + +_DEFAULT_FILTERS = frozenset({FILTER_SYSTEM, FILTER_NON_STRING_CONTENT}) + class SubmissionProcessor: """Handles OpenEdX submission operations for chat history and persistence""" @@ -44,6 +57,7 @@ def _process_messages( current_messages_count=0, use_max_context=True, include_submission_id=False, + filters=_DEFAULT_FILTERS, ): """ Retrieve messages from submissions. @@ -52,6 +66,9 @@ def _process_messages( Args: current_messages_count: Number of messages already loaded in the frontend + filters: Set of FILTER_* constants controlling which messages are excluded. + Defaults to _DEFAULT_FILTERS (excludes system messages and + non-string content). Pass frozenset() to return everything. Returns: tuple: (new_messages, has_more) where new_messages is a list of messages @@ -66,10 +83,17 @@ def _process_messages( submission_messages = json.loads(submission["answer"]) timestamp = str(submission.get("created_at") or submission.get("submitted_at") or "") if submission_messages and isinstance(submission_messages, list): - # Remove system messages if present - submission_messages_copy = [ - msg for msg in submission_messages if isinstance(msg, dict) and msg.get("role") != "system" - ] + submission_messages_copy = [] + for msg in submission_messages: + if not isinstance(msg, dict): + continue + if FILTER_SYSTEM in filters and msg.get("role") == "system": + continue + if FILTER_NON_STRING_CONTENT in filters and "role" in msg: + content = msg.get("content") + if not isinstance(content, str) or not content: + continue + submission_messages_copy.append(msg) submission_uuid = submission.get("uuid", "") for msg in submission_messages_copy: msg["timestamp"] = timestamp @@ -215,27 +239,25 @@ def get_submission(self): ) return None - def get_full_message_history(self): + def get_full_message_history(self, filters=_DEFAULT_FILTERS): """ Retrieve the full message history for the current submission. + + Args: + filters: Set of FILTER_* constants passed through to _process_messages. + Default behaviour (FILTER_SYSTEM + FILTER_NON_STRING_CONTENT) + returns only user/assistant messages with plain string content, + which is what the LLM input path requires. + Pass frozenset() to retrieve everything stored (system messages, + function calls, block-format content) for debug/admin views. """ if self.user_session.local_submission_id: - messages, _ = self._process_messages(use_max_context=False) - cleaned = [] + messages, _ = self._process_messages(use_max_context=False, filters=filters) for msg in messages: if isinstance(msg, dict): msg.pop("timestamp", None) - # Only validate content for role-based messages (user/assistant/system). - # Function call/output items use other fields (type, call_id, etc.) and - # legitimately have no content field — never filter those out. - if "role" in msg: - content = msg.get("content") - if not isinstance(content, str) or not content: - continue - cleaned.append(msg) - return cleaned - else: - return None + return messages + return None def get_full_thread(self): """ @@ -273,7 +295,7 @@ def get_full_thread(self): } try: messages, _ = self._process_messages( - use_max_context=False, include_submission_id=True + use_max_context=False, include_submission_id=True, filters=frozenset() ) # Sort by timestamp to guarantee chronological order messages.sort(key=lambda m: m.get("timestamp", "")) diff --git a/backend/openedx_ai_extensions/workflows/orchestrators/threaded_orchestrator.py b/backend/openedx_ai_extensions/workflows/orchestrators/threaded_orchestrator.py index 3001625f..a40144ac 100644 --- a/backend/openedx_ai_extensions/workflows/orchestrators/threaded_orchestrator.py +++ b/backend/openedx_ai_extensions/workflows/orchestrators/threaded_orchestrator.py @@ -7,6 +7,7 @@ import re from openedx_ai_extensions.processors import LLMProcessor, OpenEdXProcessor +from openedx_ai_extensions.processors.llm.providers import provider_supports from openedx_ai_extensions.utils import STREAMING_FAILED_MESSAGE, is_generator, normalize_input_to_text from openedx_ai_extensions.xapi.constants import EVENT_NAME_WORKFLOW_INITIALIZED, EVENT_NAME_WORKFLOW_INTERACTED @@ -104,9 +105,10 @@ def _stream_and_save_history(self, generator, input_data, # pylint: disable=too messages.insert(0, {"role": "user", "content": user_text}) # Re-inject system messages if this was a new thread (and not OpenAI) - if self.llm_processor.get_provider() != "openai" and initial_system_msgs: + provider = self.llm_processor.get_provider() + if not provider_supports(provider, "server_side_thread_id") and initial_system_msgs: for msg in initial_system_msgs: - messages.insert(0, {"role": msg["role"], "conte{}nt": msg["content"]}) + messages.insert(0, {"role": msg["role"], "content": msg["content"]}) try: submission_processor.update_chat_submission(messages) diff --git a/backend/tests/test_providers.py b/backend/tests/test_providers.py new file mode 100644 index 00000000..1dd37d57 --- /dev/null +++ b/backend/tests/test_providers.py @@ -0,0 +1,277 @@ +""" +Tests for processors/llm/providers — provider capability registry, +adapt_to_provider, and _apply_multi_turn_cache. + +Focus: Anthropic multi-turn flow where the full message history is sent on +every request and cache_control breakpoints are applied to the last system +and last user messages. +""" +# pylint: disable=invalid-sequence-index + +from openedx_ai_extensions.processors.llm.providers import _apply_multi_turn_cache, adapt_to_provider, provider_supports + + +def _make_session(remote_response_id=None, local_submission_id=None): + """Return a minimal mock session object.""" + session = type("Session", (), { + "remote_response_id": remote_response_id, + "local_submission_id": local_submission_id, + "save": lambda self: None, + })() + return session + + +def _roles(messages): + """Return the list of roles from a message list.""" + return [m.get("role") for m in messages] + + +def _cache_controlled_indices(messages): + """Return indices of messages that carry cache_control.""" + result = [] + for i, msg in enumerate(messages): + content = msg.get("content") + if isinstance(content, list) and any("cache_control" in b for b in content): + result.append(i) + return result + + +class TestProviderSupports: + """Tests for provider_supports function.""" + + def test_openai_server_side_thread_id(self): + assert provider_supports("openai", "server_side_thread_id") is True + + def test_anthropic_multi_turn_cache(self): + assert provider_supports("anthropic", "multi_turn_cache") is True + + def test_anthropic_does_not_support_server_side_thread_id(self): + assert provider_supports("anthropic", "server_side_thread_id") is False + + def test_openai_does_not_support_multi_turn_cache(self): + assert provider_supports("openai", "multi_turn_cache") is False + + def test_unknown_provider_returns_false(self): + assert provider_supports("unknown_llm", "multi_turn_cache") is False + + def test_unknown_capability_returns_false(self): + assert provider_supports("anthropic", "nonexistent_capability") is False + + +class TestApplyMultiTurnCache: + """Tests for _apply_multi_turn_cache function.""" + + def _system(self, text): + return {"role": "system", "content": text} + + def _user(self, text): + return {"role": "user", "content": text} + + def _assistant(self, text): + return {"role": "assistant", "content": text} + + def test_marks_last_system_and_last_user(self): + messages = [ + self._system("You are a helpful assistant."), + self._system("Course context: unit 1 content."), + self._user("What is this about?"), + self._assistant("It is about unit 1."), + self._user("Tell me more."), + ] + result = _apply_multi_turn_cache(messages) + + # Only 2 breakpoints: last system (idx 1) and last user (idx 4) + assert _cache_controlled_indices(result) == [1, 4] + + def test_non_targeted_messages_are_unchanged(self): + messages = [ + self._system("System A."), + self._system("System B."), + self._user("Question 1?"), + self._assistant("Answer 1."), + self._user("Question 2?"), + ] + result = _apply_multi_turn_cache(messages) + + # First system, first user, and assistant are plain strings + assert isinstance(result[0]["content"], str) + assert isinstance(result[2]["content"], str) + assert isinstance(result[3]["content"], str) + + def test_content_wrapped_in_text_block(self): + messages = [ + self._system("Stable system prompt."), + self._user("Current question."), + ] + result = _apply_multi_turn_cache(messages) + + system_content = result[0]["content"] + user_content = result[1]["content"] + + assert isinstance(system_content, list) + assert system_content[0]["type"] == "text" + assert system_content[0]["text"] == "Stable system prompt." + assert system_content[0]["cache_control"] == {"type": "ephemeral"} + + assert isinstance(user_content, list) + assert user_content[0]["text"] == "Current question." + assert user_content[0]["cache_control"] == {"type": "ephemeral"} + + def test_original_message_dict_is_not_mutated(self): + original = {"role": "system", "content": "Unchanged."} + messages = [original, {"role": "user", "content": "Q?"}] + _apply_multi_turn_cache(messages) + assert original["content"] == "Unchanged." + + def test_no_user_message_only_marks_system(self): + messages = [ + self._system("System only."), + ] + result = _apply_multi_turn_cache(messages) + assert _cache_controlled_indices(result) == [0] + + def test_no_system_message_only_marks_user(self): + messages = [ + self._user("No system here."), + self._assistant("Reply."), + self._user("Follow-up."), + ] + result = _apply_multi_turn_cache(messages) + assert _cache_controlled_indices(result) == [2] + + def test_long_conversation_still_uses_two_breakpoints(self): + """The 2-breakpoint strategy must hold regardless of conversation length.""" + messages = [self._system("Ctx.")] + for i in range(10): + messages.append(self._user(f"Q{i}")) + messages.append(self._assistant(f"A{i}")) + messages.append(self._user("Final question.")) + + result = _apply_multi_turn_cache(messages) + assert len(_cache_controlled_indices(result)) == 2 + + def test_already_block_format_content_is_not_double_wrapped(self): + """If content is already a list (not a string), it is left as-is.""" + block_content = [{"type": "text", "text": "Already wrapped."}] + messages = [ + {"role": "system", "content": block_content}, + {"role": "user", "content": "Question."}, + ] + result = _apply_multi_turn_cache(messages) + # System was already a list — not re-wrapped + assert result[0]["content"] is block_content + # User (string) is wrapped normally + assert isinstance(result[1]["content"], list) + + +class TestAdaptToProviderAnthropic: + """ + Verify that adapt_to_provider correctly handles the Anthropic case: + - Full message history is sent on every request (no server-side threading) + - cache_control breakpoints are applied to last system + last user messages + - The input→messages conversion for streaming works alongside caching + """ + + def _base_params(self, stream=False): + return { + "stream": stream, + "input": [ + {"role": "system", "content": "You are a course assistant."}, + {"role": "system", "content": "Course context: chapter 3."}, + {"role": "user", "content": "Summarize chapter 3."}, + {"role": "assistant", "content": "Chapter 3 covers..."}, + {"role": "user", "content": "What are the key points?"}, + ], + } + + def test_cache_applied_to_last_system_and_last_user(self): + params = self._base_params() + result = adapt_to_provider("anthropic", params) + + msgs = result["input"] + cached = _cache_controlled_indices(msgs) + # Last system is index 1, last user is index 4 + assert cached == [1, 4] + + def test_all_five_messages_are_present(self): + """No messages are dropped — full history is sent.""" + params = self._base_params() + result = adapt_to_provider("anthropic", params) + assert len(result["input"]) == 5 + + def test_roles_preserved(self): + params = self._base_params() + result = adapt_to_provider("anthropic", params) + assert _roles(result["input"]) == [ + "system", "system", "user", "assistant", "user" + ] + + def test_assistant_message_not_cached(self): + params = self._base_params() + result = adapt_to_provider("anthropic", params) + assistant_msg = result["input"][3] + assert isinstance(assistant_msg["content"], str) + + def test_streaming_converts_input_to_messages_then_caches(self): + """For streaming, input is renamed to messages before caching runs.""" + params = self._base_params(stream=True) + result = adapt_to_provider("anthropic", params) + + assert "input" not in result + assert "messages" in result + msgs = result["messages"] + assert len(msgs) == 5 + assert _cache_controlled_indices(msgs) == [1, 4] + + def test_no_server_side_thread_id_set(self): + """Anthropic must never receive previous_response_id.""" + session = _make_session(remote_response_id="some-id") + params = self._base_params() + result = adapt_to_provider( + "anthropic", params, user_session=session, input_data="What are the key points?" + ) + assert "previous_response_id" not in result + + def test_dummy_user_message_injected_when_no_user_message(self): + """Anthropic requires at least one user message; a dummy is added if missing.""" + params = { + "stream": False, + "input": [ + {"role": "system", "content": "You are a course assistant."}, + {"role": "system", "content": "Course context."}, + ], + } + result = adapt_to_provider("anthropic", params, has_user_input=False) + roles = _roles(result["input"]) + assert "user" in roles + + def test_dummy_message_also_gets_cache_control(self): + """The injected dummy user message is the last user message, so it gets cached.""" + params = { + "stream": False, + "input": [ + {"role": "system", "content": "System prompt."}, + ], + } + result = adapt_to_provider("anthropic", params, has_user_input=False) + # The dummy user message should be the last and should carry cache_control + last = result["input"][-1] + assert last["role"] == "user" + assert isinstance(last["content"], list) + assert last["content"][0]["cache_control"] == {"type": "ephemeral"} + + +class TestAdaptToProviderOpenAIUnaffected: + """Tests for adapt_to_provider with OpenAI.""" + + def test_openai_input_not_cache_transformed(self): + params = { + "stream": False, + "input": [ + {"role": "system", "content": "Sys."}, + {"role": "user", "content": "Q?"}, + ], + } + result = adapt_to_provider("openai", params) + for msg in result["input"]: + assert isinstance(msg["content"], str) diff --git a/docs/decisions/0010-anthropic-multi-turn-prompt-caching.md b/docs/decisions/0010-anthropic-multi-turn-prompt-caching.md new file mode 100644 index 00000000..87b78f5f --- /dev/null +++ b/docs/decisions/0010-anthropic-multi-turn-prompt-caching.md @@ -0,0 +1,47 @@ +# 0010 - Anthropic multi-turn prompt caching via explicit cache breakpoints + +## Status + +Proposed + +## Context + +When using Anthropic as the LLM provider, conversation history is managed client-side: on every turn the full message list (system context + all prior turns + current user message) is sent to the API. The system context alone (course content extracted by `OpenEdXProcessor`) can be several thousand tokens. Sending it uncached on every turn is expensive and slow. + +Anthropic supports prompt caching via `cache_control` markers on individual content blocks. Cached prefixes are reused across requests for a 5-minute window (refreshed on each hit), at 10% of the normal input token price. Cache writes cost 25% more than base input tokens, so the break-even is any prefix read more than once within the TTL — which is guaranteed for every turn after the first in an active conversation. + +Two caching strategies exist: + +- **Automatic caching**: a single top-level `cache_control` field; Anthropic moves the breakpoint automatically. Simple, but requires the Anthropic API directly — support through LiteLLM's emulation layer is not guaranteed. +- **Explicit breakpoints**: `cache_control` placed on individual content blocks. Fully supported through LiteLLM's `completion()` call, which is the path used for Anthropic in this codebase. + +[Official docs](https://platform.claude.com/docs/en/build-with-claude/prompt-caching) + +The naive explicit approach (marking every user message) hits Anthropic's limit of 4 breakpoints per request once a conversation exceeds 3 turns. + +## Decision + +Use **two explicit cache breakpoints** per request, applied as a request-time transformation in `adapt_to_provider`: + +1. **Last system message** — the course context and instructions are identical on every turn; an explicit breakpoint here ensures a cache hit from the second turn onward regardless of conversation length. +2. **Last user message** (current turn) — this becomes the cache entry that the *next* turn's lookback window will find. Anthropic's lookback walks backward up to 20 blocks from the new breakpoint; since each turn adds exactly 2 blocks (one assistant + one user), the lookback always finds the previous cache entry within 2 steps. + +The transformation is encapsulated in `_apply_multi_turn_cache()` in `processors/llm/providers/__init__.py` and is gated by a `multi_turn_cache` entry in `_PROVIDER_CAPABILITIES["anthropic"]`, consistent with the `provider_supports()` pattern introduced for `server_side_thread_id`. + +The transformation converts the `content` field of the targeted messages from a plain string to Anthropic's block format: + +```json +{"type": "text", "text": "...", "cache_control": {"type": "ephemeral"}} +``` + +This is applied after all other `adapt_to_provider` transforms (user-message injection, streaming key conversion) so it operates on the final message list regardless of which key (`input` or `messages`) is in use. + +The transformation is **never persisted**. `get_full_message_history()` in the submission processor filters out any message whose `content` is not a plain string, so history always round-trips as plain strings and the block format is reconstructed fresh on each request. + +## Consequences + +- System context and growing conversation history are cached at Anthropic from the second turn onward, reducing input token costs and latency for active conversations. +- The 2-breakpoint strategy stays within Anthropic's 4-breakpoint limit for conversations of any length. +- Adding a new provider that supports prompt caching requires only adding `multi_turn_cache` to its `_PROVIDER_CAPABILITIES` entry; no other code changes are needed. +- The minimum cacheable prompt length for `claude-sonnet-4-6` is 1,024 tokens. Requests below this threshold are silently processed without caching; no error is returned. +- Cache hits and misses are visible in the Anthropic API response under `usage.cache_read_input_tokens` and `usage.cache_creation_input_tokens`.