Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions backend/openedx_ai_extensions/processors/llm/llm_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@

from openedx_ai_extensions.functions.decorators import AVAILABLE_TOOLS
from openedx_ai_extensions.processors.llm.litellm_base_processor import LitellmProcessor
from openedx_ai_extensions.processors.llm.providers import adapt_to_provider, after_tool_call_adaptations
from openedx_ai_extensions.processors.llm.providers import (
adapt_to_provider,
after_tool_call_adaptations,
provider_supports,
)
from openedx_ai_extensions.processors.llm.tool_executor import ToolExecutor
from openedx_ai_extensions.utils import STREAMING_FAILED_MESSAGE, normalize_input_to_text

Expand Down Expand Up @@ -173,8 +177,7 @@ def _call_responses_wrapper(self, params, initialize=False, system_role=None):
response_id = getattr(response, "id", None)
content = self._extract_response_content(response=response)

# Update session with response ID for threading
if response_id:
if response_id and provider_supports(self.provider, "server_side_thread_id"):
self.user_session.remote_response_id = response_id
self.user_session.save()

Expand Down
85 changes: 81 additions & 4 deletions backend/openedx_ai_extensions/processors/llm/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,36 @@
"""
Provider-specific quirks and adaptations for different LLM providers.
"""
import logging

logger = logging.getLogger(__name__)

_PROVIDER_CAPABILITIES = {
"openai": {
# Provider stores conversation history server-side and returns a response ID.
# Subsequent turns send only the new user message + that ID; the provider
# reconstructs context itself. Without this, full history is fetched from local
# storage and sent on every request.
# Affects: adapt_to_provider (sets previous_response_id, replaces input with new
# user message only), after_tool_call_adaptations (persists new response ID),
# _call_responses_wrapper (skips saving remote_response_id for providers without it).
"server_side_thread_id",
},
"anthropic": {
# Provider supports prompt caching via cache_control on content blocks. Cached
# prefixes are reused at ~10% of normal input token cost for a 5-minute window.
# Without this, every request pays full price for system context and history.
# Two breakpoints per request: last system message (stable course context) and last
# user message (becomes the lookback target for the next turn). See ADR 0010.
# Affects: adapt_to_provider (_apply_multi_turn_cache).
"multi_turn_cache",
},
}


def provider_supports(provider, capability):
"""Return True if the given provider supports the named capability."""
return capability in _PROVIDER_CAPABILITIES.get(provider, set())


# TODO: refactor this module to make it more extensible for future providers
Expand Down Expand Up @@ -31,8 +61,7 @@ def adapt_to_provider(
Returns:
dict: Modified parameters with provider-specific adaptations applied
"""
if provider == "openai":
# OpenAI supports threading via previous_response_id
if provider_supports(provider, "server_side_thread_id"):
if user_session and user_session.remote_response_id and input_data:
params["previous_response_id"] = user_session.remote_response_id
if "input" in params:
Expand All @@ -56,17 +85,65 @@ def adapt_to_provider(
elif "messages" in params:
params["messages"].append({"role": "user", "content": user_prompt})

if provider != "openai" and params.get("stream") and "input" in params:
if not provider_supports(provider, "server_side_thread_id") and params.get("stream") and "input" in params:
# Non-OpenAI providers: convert Responses API shape → Completion API
# shape so that completion() / _completion_with_tools() can be called
# directly, ensuring tool-call events are visible during streaming.
params["messages"] = params.pop("input")
for key in ("previous_response_id", "store", "truncation"):
params.pop(key, None)

if provider_supports(provider, "multi_turn_cache"):
key = "messages" if "messages" in params else "input"
if key in params:
params[key] = _apply_multi_turn_cache(params[key])
Comment thread
felipemontoya marked this conversation as resolved.

return params


def _apply_multi_turn_cache(messages):
"""
Add Anthropic style cache_control breakpoints to the last system and last user messages.

Two breakpoints are sufficient for any conversation length:
- Last system message: stable across all turns (course context never changes).
- Last user message: becomes the lookback target for the next turn. The 20-block
lookback window finds the previous turn's cache entry within 2 steps (one
assistant + one user block), so no additional breakpoints are needed regardless
of conversation length.

History is always stored as plain strings (get_full_message_history filters out
non-string content), so this transformation is request-only and never persisted.
"""
Comment thread
felipemontoya marked this conversation as resolved.
last_system_idx = None
last_user_idx = None
for i, msg in enumerate(messages):
role = msg.get("role")
if role == "system":
last_system_idx = i
elif role == "user":
last_user_idx = i

result = list(messages)
for idx in (last_system_idx, last_user_idx):
if idx is None:
continue
msg = result[idx]
content = msg.get("content", "")
if isinstance(content, str):
result[idx] = {
**msg,
"content": [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}],
}
Comment thread
felipemontoya marked this conversation as resolved.
else:
logger.warning(
"multi_turn_cache: skipping cache_control on role=%r message at index %d "
"— content is %s, not a plain string. Cache breakpoint will be missing for this turn.",
msg.get("role"), idx, type(content).__name__,
)
return result
Comment thread
felipemontoya marked this conversation as resolved.


def after_tool_call_adaptations(provider, params, data=None):
"""
Apply provider-specific modifications to API call parameters after tool calls.
Expand All @@ -80,7 +157,7 @@ def after_tool_call_adaptations(provider, params, data=None):
Returns:
dict: Modified parameters with provider-specific adaptations applied
"""
if provider == "openai":
if provider_supports(provider, "server_side_thread_id"):
if data and hasattr(data, "id"):
params["previous_response_id"] = data.id

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,19 @@

logger = logging.getLogger(__name__)

# Filter constants for _process_messages / get_full_message_history.
# Pass a subset (or frozenset()) to control what gets excluded from results.
#
# FILTER_SYSTEM — exclude messages with role == "system"
# FILTER_NON_STRING_CONTENT — exclude role-based messages whose content is
# not a plain non-empty string (e.g. block-format
# content from cache transforms, tool outputs)
#
FILTER_SYSTEM = "system"
FILTER_NON_STRING_CONTENT = "non_string_content"

_DEFAULT_FILTERS = frozenset({FILTER_SYSTEM, FILTER_NON_STRING_CONTENT})


class SubmissionProcessor:
"""Handles OpenEdX submission operations for chat history and persistence"""
Expand Down Expand Up @@ -44,6 +57,7 @@ def _process_messages(
current_messages_count=0,
use_max_context=True,
include_submission_id=False,
filters=_DEFAULT_FILTERS,
):
"""
Retrieve messages from submissions.
Expand All @@ -52,6 +66,9 @@ def _process_messages(

Args:
current_messages_count: Number of messages already loaded in the frontend
filters: Set of FILTER_* constants controlling which messages are excluded.
Defaults to _DEFAULT_FILTERS (excludes system messages and
non-string content). Pass frozenset() to return everything.

Returns:
tuple: (new_messages, has_more) where new_messages is a list of messages
Expand All @@ -66,10 +83,17 @@ def _process_messages(
submission_messages = json.loads(submission["answer"])
timestamp = str(submission.get("created_at") or submission.get("submitted_at") or "")
if submission_messages and isinstance(submission_messages, list):
# Remove system messages if present
submission_messages_copy = [
msg for msg in submission_messages if isinstance(msg, dict) and msg.get("role") != "system"
]
submission_messages_copy = []
for msg in submission_messages:
if not isinstance(msg, dict):
continue
if FILTER_SYSTEM in filters and msg.get("role") == "system":
continue
if FILTER_NON_STRING_CONTENT in filters and "role" in msg:
content = msg.get("content")
if not isinstance(content, str) or not content:
continue
submission_messages_copy.append(msg)
submission_uuid = submission.get("uuid", "")
for msg in submission_messages_copy:
msg["timestamp"] = timestamp
Expand Down Expand Up @@ -215,27 +239,25 @@ def get_submission(self):
)
return None

def get_full_message_history(self):
def get_full_message_history(self, filters=_DEFAULT_FILTERS):
"""
Retrieve the full message history for the current submission.

Args:
filters: Set of FILTER_* constants passed through to _process_messages.
Default behaviour (FILTER_SYSTEM + FILTER_NON_STRING_CONTENT)
returns only user/assistant messages with plain string content,
which is what the LLM input path requires.
Pass frozenset() to retrieve everything stored (system messages,
function calls, block-format content) for debug/admin views.
"""
if self.user_session.local_submission_id:
messages, _ = self._process_messages(use_max_context=False)
cleaned = []
messages, _ = self._process_messages(use_max_context=False, filters=filters)
for msg in messages:
if isinstance(msg, dict):
msg.pop("timestamp", None)
# Only validate content for role-based messages (user/assistant/system).
# Function call/output items use other fields (type, call_id, etc.) and
# legitimately have no content field — never filter those out.
if "role" in msg:
content = msg.get("content")
if not isinstance(content, str) or not content:
continue
cleaned.append(msg)
return cleaned
else:
return None
return messages
return None

def get_full_thread(self):
"""
Expand Down Expand Up @@ -273,7 +295,7 @@ def get_full_thread(self):
}
try:
messages, _ = self._process_messages(
use_max_context=False, include_submission_id=True
use_max_context=False, include_submission_id=True, filters=frozenset()
)
# Sort by timestamp to guarantee chronological order
messages.sort(key=lambda m: m.get("timestamp", ""))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import re

from openedx_ai_extensions.processors import LLMProcessor, OpenEdXProcessor
from openedx_ai_extensions.processors.llm.providers import provider_supports
from openedx_ai_extensions.utils import STREAMING_FAILED_MESSAGE, is_generator, normalize_input_to_text
from openedx_ai_extensions.xapi.constants import EVENT_NAME_WORKFLOW_INITIALIZED, EVENT_NAME_WORKFLOW_INTERACTED

Expand Down Expand Up @@ -104,9 +105,10 @@ def _stream_and_save_history(self, generator, input_data, # pylint: disable=too
messages.insert(0, {"role": "user", "content": user_text})

# Re-inject system messages if this was a new thread (and not OpenAI)
if self.llm_processor.get_provider() != "openai" and initial_system_msgs:
provider = self.llm_processor.get_provider()
if not provider_supports(provider, "server_side_thread_id") and initial_system_msgs:
for msg in initial_system_msgs:
messages.insert(0, {"role": msg["role"], "conte{}nt": msg["content"]})
messages.insert(0, {"role": msg["role"], "content": msg["content"]})

try:
submission_processor.update_chat_submission(messages)
Expand Down
Loading
Loading