Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions app/celery/tasks/agent_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,21 @@ async def run_agent():
logger.info(
f"Agent execution cancelled: {conversation_id}:{run_id}"
)

# Flush any buffered AI response chunks before cancelling
try:
message_id = service.history_manager.flush_message_buffer(
conversation_id, MessageType.AI_GENERATED
)
if message_id:
logger.debug(
f"Flushed partial AI response (message_id: {message_id}) for cancelled task: {conversation_id}:{run_id}"
)
except Exception as e:
logger.warning(
f"Failed to flush message buffer on cancellation: {str(e)}"
)
# Continue with cancellation even if flush fails
redis_manager.publish_event(
conversation_id,
run_id,
Expand Down Expand Up @@ -161,6 +176,9 @@ async def run_agent():
f"Background agent execution cancelled: {conversation_id}:{run_id}"
)

# Return the completion status so on_success can check if it was cancelled
return completed

except Exception as e:
logger.error(
f"Background agent execution failed: {conversation_id}:{run_id}: {str(e)}",
Expand Down Expand Up @@ -215,6 +233,7 @@ async def run_regeneration():
ConversationStore,
)
from app.modules.conversations.message.message_store import MessageStore
from app.modules.conversations.message.message_model import MessageType

# Use BaseTask's context manager to get a fresh, non-pooled async session
# This avoids asyncpg Future binding issues across tasks sharing the same event loop
Expand Down Expand Up @@ -261,6 +280,22 @@ async def run_regeneration():
logger.info(
f"Regenerate execution cancelled: {conversation_id}:{run_id}"
)

# Flush any buffered AI response chunks before cancelling
try:
message_id = service.history_manager.flush_message_buffer(
conversation_id, MessageType.AI_GENERATED
)
if message_id:
logger.debug(
f"Flushed partial AI response (message_id: {message_id}) for cancelled regenerate: {conversation_id}:{run_id}"
)
except Exception as e:
logger.warning(
f"Failed to flush message buffer on cancellation: {str(e)}"
)
# Continue with cancellation even if flush fails

redis_manager.publish_event(
conversation_id,
run_id,
Expand Down Expand Up @@ -332,6 +367,9 @@ async def run_regeneration():
f"Background regenerate execution cancelled: {conversation_id}:{run_id}"
)

# Return the completion status so on_success can check if it was cancelled
return completed

except Exception as e:
logger.error(
f"Background regenerate execution failed: {conversation_id}:{run_id}: {str(e)}",
Expand Down
12 changes: 7 additions & 5 deletions app/celery/tasks/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,13 @@ def run_async(self, coro):
return loop.run_until_complete(coro)

def on_success(self, retval, task_id, args, kwargs):
    """Called on successful task execution.

    Tasks return ``False`` when they were cancelled mid-run, so the log
    line distinguishes a cooperative cancellation from a normal
    completion. The DB session is released back to the pool either way.
    """
    try:
        # `retval is False` (identity, not truthiness) so a task that
        # returns None still logs as a normal completion.
        status = "cancelled" if retval is False else "completed successfully"
        logger.info(f"Task {task_id} {status}")
    finally:
        if self._db:
            self._db.close()  # Returns to pool
            self._db = None

def on_failure(self, exc, task_id, args, kwargs, einfo):
"""Called on task failure."""
Expand Down
79 changes: 71 additions & 8 deletions app/modules/conversations/conversation/conversation_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@
from app.modules.intelligence.prompts.prompt_service import PromptService
from app.modules.intelligence.tools.tool_service import ToolService
from app.modules.media.media_service import MediaService
from app.modules.conversations.session.session_service import SessionService
from app.modules.conversations.utils.redis_streaming import RedisStreamManager
from app.celery.celery_app import celery_app
from .conversation_store import ConversationStore, StoreError
from ..message.message_store import MessageStore

Expand Down Expand Up @@ -82,6 +85,8 @@ def __init__(
agent_service: AgentsService,
custom_agent_service: CustomAgentService,
media_service: MediaService,
session_service: SessionService = None,
redis_manager: RedisStreamManager = None,
):
self.db = db
self.user_id = user_id
Expand All @@ -96,6 +101,10 @@ def __init__(
self.agent_service = agent_service
self.custom_agent_service = custom_agent_service
self.media_service = media_service
# Dependency injection for stop_generation
self.session_service = session_service or SessionService()
self.redis_manager = redis_manager or RedisStreamManager()
self.celery_app = celery_app

@classmethod
def create(
Expand All @@ -116,6 +125,8 @@ def create(
)
custom_agent_service = CustomAgentService(db, provider_service, tool_service)
media_service = MediaService(db)
session_service = SessionService()
redis_manager = RedisStreamManager()

return cls(
db,
Expand All @@ -131,6 +142,8 @@ def create(
agent_service,
custom_agent_service,
media_service,
session_service,
redis_manager,
)

async def check_conversation_access(
Expand Down Expand Up @@ -1064,21 +1077,71 @@ async def stop_generation(
f"Attempting to stop generation for conversation {conversation_id}, run_id: {run_id}"
)

# If run_id not provided, try to find active session
if not run_id:
return {
"status": "error",
"message": "run_id required for stopping background generation",
}
from app.modules.conversations.conversation.conversation_schema import (
ActiveSessionErrorResponse,
)

active_session = self.session_service.get_active_session(conversation_id)

if isinstance(active_session, ActiveSessionErrorResponse):
# No active session found - this is okay, just return success
# The session might have already completed or been cleared
logger.info(
f"No active session found for conversation {conversation_id} - already stopped or never started"
)
return {
"status": "success",
"message": "No active session to stop",
}

run_id = active_session.sessionId
logger.info(
f"Found active session {run_id} for conversation {conversation_id}"
)

# Set cancellation flag in Redis for background task to check
from app.modules.conversations.utils.redis_streaming import RedisStreamManager
self.redis_manager.set_cancellation(conversation_id, run_id)

# Retrieve and revoke the Celery task
task_id = self.redis_manager.get_task_id(conversation_id, run_id)

if task_id:
try:
# Revoke the task - this works for both queued and running tasks:
# - For queued tasks: Prevents them from starting execution
# - For running tasks: Sends SIGTERM to terminate them
# terminate=True ensures both cases are handled
self.celery_app.control.revoke(
task_id, terminate=True, signal="SIGTERM"
)
logger.info(
f"Revoked Celery task {task_id} for {conversation_id}:{run_id} (works for both queued and running tasks)"
)
except Exception as e:
logger.warning(f"Failed to revoke Celery task {task_id}: {str(e)}")
# Continue anyway - cancellation flag is set
else:
logger.info(
f"No task ID found for {conversation_id}:{run_id} - task may have already completed or been revoked"
)

redis_manager = RedisStreamManager()
redis_manager.set_cancellation(conversation_id, run_id)
# Always clear the session - publish end event and update status
# This ensures clients know the session is stopped and prevents stale sessions
# This is important even if there's no task_id - it clears any stale session data
# This will also handle the case where stop is called with a stale session_id
try:
self.redis_manager.clear_session(conversation_id, run_id)
except Exception as e:
logger.warning(
f"Failed to clear session for {conversation_id}:{run_id}: {str(e)}"
)
# Continue anyway - the important part (revocation) is done

return {
"status": "success",
"message": "Cancellation signal sent to background task",
"message": "Cancellation signal sent and task revoked",
}

async def rename_conversation(
Expand Down
16 changes: 14 additions & 2 deletions app/modules/conversations/conversations_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ async def post_message(
)

# Start background task
execute_agent_background.delay(
task_result = execute_agent_background.delay(
conversation_id=conversation_id,
run_id=run_id,
user_id=user_id,
Expand All @@ -365,6 +365,12 @@ async def post_message(
attachment_ids=attachment_ids or [],
)

# Store the Celery task ID for later revocation
redis_manager.set_task_id(conversation_id, run_id, task_result.id)
logger.info(
f"Started agent task {task_result.id} for {conversation_id}:{run_id}"
)

# Wait for background task to start (with health check)
# Increased timeout to 30 seconds to handle queued tasks
task_started = redis_manager.wait_for_task_start(
Expand Down Expand Up @@ -484,14 +490,20 @@ async def regenerate_last_message(
},
)

execute_regenerate_background.delay(
task_result = execute_regenerate_background.delay(
conversation_id=conversation_id,
run_id=run_id,
user_id=user_id,
node_ids=request.node_ids or [],
attachment_ids=attachment_ids,
)

# Store the Celery task ID for later revocation
redis_manager.set_task_id(conversation_id, run_id, task_result.id)
logger.info(
f"Started regenerate task {task_result.id} for {conversation_id}:{run_id}"
)

# Wait for background task to start (with health check)
# Increased timeout to 30 seconds to handle queued tasks
task_started = redis_manager.wait_for_task_start(
Expand Down
35 changes: 35 additions & 0 deletions app/modules/conversations/utils/redis_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,41 @@ def get_task_status(self, conversation_id: str, run_id: str) -> Optional[str]:
status = self.redis_client.get(status_key)
return status.decode() if status else None

def set_task_id(
    self, conversation_id: str, run_id: str, task_id: str, ttl_seconds: int = 600
) -> None:
    """Store the Celery task ID for this conversation/run.

    The ID is looked up later by ``get_task_id`` so the task can be
    revoked when the user stops generation.

    Args:
        conversation_id: Conversation the task belongs to.
        run_id: Streaming run identifier.
        task_id: Celery task id returned by ``.delay()``.
        ttl_seconds: Key expiry in seconds (default 600, i.e. 10
            minutes) so stale ids are cleaned up automatically if the
            task never finishes its own cleanup.
    """
    task_id_key = f"task:id:{conversation_id}:{run_id}"
    self.redis_client.set(task_id_key, task_id, ex=ttl_seconds)
    logger.debug(f"Stored task ID {task_id} for {conversation_id}:{run_id}")

def get_task_id(self, conversation_id: str, run_id: str) -> Optional[str]:
    """Look up the Celery task ID stored for this conversation/run.

    Returns the decoded task id, or ``None`` when no id was stored or
    the key has expired.
    """
    raw = self.redis_client.get(f"task:id:{conversation_id}:{run_id}")
    if not raw:
        return None
    return raw.decode()

def clear_session(self, conversation_id: str, run_id: str) -> None:
    """Tear down a streaming session that is being stopped.

    Publishes a terminal "end" event with a cancelled status so
    subscribed clients stop waiting, then marks the task status as
    cancelled. Best-effort: failures are logged, never raised.
    """
    try:
        cancelled_payload = {
            "status": "cancelled",
            "message": "Generation stopped by user",
        }
        self.publish_event(conversation_id, run_id, "end", cancelled_payload)
        self.set_task_status(conversation_id, run_id, "cancelled")
        logger.info(f"Cleared session for {conversation_id}:{run_id}")
    except Exception as e:
        logger.error(
            f"Failed to clear session for {conversation_id}:{run_id}: {str(e)}"
        )

def wait_for_task_start(
self, conversation_id: str, run_id: str, timeout: int = 10
) -> bool:
Expand Down