diff --git a/app/celery/tasks/agent_tasks.py b/app/celery/tasks/agent_tasks.py
index 217aa043..d7c77f53 100644
--- a/app/celery/tasks/agent_tasks.py
+++ b/app/celery/tasks/agent_tasks.py
@@ -99,6 +99,21 @@ async def run_agent():
                 logger.info(
                     f"Agent execution cancelled: {conversation_id}:{run_id}"
                 )
+
+                # Flush any buffered AI response chunks before cancelling
+                try:
+                    message_id = service.history_manager.flush_message_buffer(
+                        conversation_id, MessageType.AI_GENERATED
+                    )
+                    if message_id:
+                        logger.debug(
+                            f"Flushed partial AI response (message_id: {message_id}) for cancelled task: {conversation_id}:{run_id}"
+                        )
+                except Exception as e:
+                    logger.warning(
+                        f"Failed to flush message buffer on cancellation: {str(e)}"
+                    )
+                    # Continue with cancellation even if flush fails
                 redis_manager.publish_event(
                     conversation_id,
                     run_id,
@@ -161,6 +176,9 @@ async def run_agent():
                 f"Background agent execution cancelled: {conversation_id}:{run_id}"
             )
 
+            # Return the completion status so on_success can check if it was cancelled
+            return completed
+
         except Exception as e:
             logger.error(
                 f"Background agent execution failed: {conversation_id}:{run_id}: {str(e)}",
@@ -215,6 +233,7 @@ async def run_regeneration():
             ConversationStore,
         )
         from app.modules.conversations.message.message_store import MessageStore
+        from app.modules.conversations.message.message_model import MessageType
 
         # Use BaseTask's context manager to get a fresh, non-pooled async session
         # This avoids asyncpg Future binding issues across tasks sharing the same event loop
@@ -261,6 +280,22 @@ async def run_regeneration():
                 logger.info(
                     f"Regenerate execution cancelled: {conversation_id}:{run_id}"
                 )
+
+                # Flush any buffered AI response chunks before cancelling
+                try:
+                    message_id = service.history_manager.flush_message_buffer(
+                        conversation_id, MessageType.AI_GENERATED
+                    )
+                    if message_id:
+                        logger.debug(
+                            f"Flushed partial AI response (message_id: {message_id}) for cancelled regenerate: {conversation_id}:{run_id}"
+                        )
+                except Exception as e:
+                    logger.warning(
+                        f"Failed to flush message buffer on cancellation: {str(e)}"
+                    )
+                    # Continue with cancellation even if flush fails
+
                 redis_manager.publish_event(
                     conversation_id,
                     run_id,
@@ -332,6 +367,9 @@ async def run_regeneration():
                 f"Background regenerate execution cancelled: {conversation_id}:{run_id}"
             )
 
+            # Return the completion status so on_success can check if it was cancelled
+            return completed
+
         except Exception as e:
             logger.error(
                 f"Background regenerate execution failed: {conversation_id}:{run_id}: {str(e)}",
diff --git a/app/celery/tasks/base_task.py b/app/celery/tasks/base_task.py
index 425a4c41..fb6eae5a 100644
--- a/app/celery/tasks/base_task.py
+++ b/app/celery/tasks/base_task.py
@@ -82,11 +82,13 @@ def run_async(self, coro):
         return loop.run_until_complete(coro)
 
     def on_success(self, retval, task_id, args, kwargs):
-        """Called on successful task execution."""
-        logger.info(f"Task {task_id} completed successfully")
-        if self._db:
-            self._db.close()
-            self._db = None
+        try:
+            status = "cancelled" if retval is False else "completed successfully"
+            logger.info(f"Task {task_id} {status}")
+        finally:
+            if self._db:
+                self._db.close()  # Returns to pool
+                self._db = None
 
     def on_failure(self, exc, task_id, args, kwargs, einfo):
         """Called on task failure."""
diff --git a/app/modules/conversations/conversation/conversation_service.py b/app/modules/conversations/conversation/conversation_service.py
index 5cbbc2c9..95922116 100644
--- a/app/modules/conversations/conversation/conversation_service.py
+++ b/app/modules/conversations/conversation/conversation_service.py
@@ -40,6 +40,9 @@
 from app.modules.intelligence.prompts.prompt_service import PromptService
 from app.modules.intelligence.tools.tool_service import ToolService
 from app.modules.media.media_service import MediaService
+from app.modules.conversations.session.session_service import SessionService
+from app.modules.conversations.utils.redis_streaming import RedisStreamManager
+from app.celery.celery_app import celery_app
 
 from .conversation_store import ConversationStore, StoreError
 from ..message.message_store import MessageStore
@@ -82,6 +85,8 @@ def __init__(
         agent_service: AgentsService,
         custom_agent_service: CustomAgentService,
         media_service: MediaService,
+        session_service: SessionService = None,
+        redis_manager: RedisStreamManager = None,
     ):
         self.db = db
         self.user_id = user_id
@@ -96,6 +101,10 @@ def __init__(
         self.agent_service = agent_service
         self.custom_agent_service = custom_agent_service
         self.media_service = media_service
+        # Dependency injection for stop_generation
+        self.session_service = session_service or SessionService()
+        self.redis_manager = redis_manager or RedisStreamManager()
+        self.celery_app = celery_app
 
     @classmethod
     def create(
@@ -116,6 +125,8 @@ def create(
         )
         custom_agent_service = CustomAgentService(db, provider_service, tool_service)
         media_service = MediaService(db)
+        session_service = SessionService()
+        redis_manager = RedisStreamManager()
 
         return cls(
             db,
@@ -131,6 +142,8 @@ def create(
             agent_service,
             custom_agent_service,
             media_service,
+            session_service,
+            redis_manager,
         )
 
     async def check_conversation_access(
@@ -1064,21 +1077,71 @@ async def stop_generation(
             f"Attempting to stop generation for conversation {conversation_id}, run_id: {run_id}"
         )
 
+        # If run_id not provided, try to find active session
         if not run_id:
-            return {
-                "status": "error",
-                "message": "run_id required for stopping background generation",
-            }
+            from app.modules.conversations.conversation.conversation_schema import (
+                ActiveSessionErrorResponse,
+            )
+
+            active_session = self.session_service.get_active_session(conversation_id)
+
+            if isinstance(active_session, ActiveSessionErrorResponse):
+                # No active session found - this is okay, just return success
+                # The session might have already completed or been cleared
+                logger.info(
+                    f"No active session found for conversation {conversation_id} - already stopped or never started"
+                )
+                return {
+                    "status": "success",
+                    "message": "No active session to stop",
+                }
+
+            run_id = active_session.sessionId
+            logger.info(
+                f"Found active session {run_id} for conversation {conversation_id}"
+            )
 
         # Set cancellation flag in Redis for background task to check
-        from app.modules.conversations.utils.redis_streaming import RedisStreamManager
+        self.redis_manager.set_cancellation(conversation_id, run_id)
+
+        # Retrieve and revoke the Celery task
+        task_id = self.redis_manager.get_task_id(conversation_id, run_id)
+
+        if task_id:
+            try:
+                # Revoke the task - this works for both queued and running tasks:
+                # - For queued tasks: Prevents them from starting execution
+                # - For running tasks: Sends SIGTERM to terminate them
+                # terminate=True ensures both cases are handled
+                self.celery_app.control.revoke(
+                    task_id, terminate=True, signal="SIGTERM"
+                )
+                logger.info(
+                    f"Revoked Celery task {task_id} for {conversation_id}:{run_id} (works for both queued and running tasks)"
+                )
+            except Exception as e:
+                logger.warning(f"Failed to revoke Celery task {task_id}: {str(e)}")
+                # Continue anyway - cancellation flag is set
+        else:
+            logger.info(
+                f"No task ID found for {conversation_id}:{run_id} - task may have already completed or been revoked"
+            )
 
-        redis_manager = RedisStreamManager()
-        redis_manager.set_cancellation(conversation_id, run_id)
+        # Always clear the session - publish end event and update status
+        # This ensures clients know the session is stopped and prevents stale sessions
+        # This is important even if there's no task_id - it clears any stale session data
+        # This will also handle the case where stop is called with a stale session_id
+        try:
+            self.redis_manager.clear_session(conversation_id, run_id)
+        except Exception as e:
+            logger.warning(
+                f"Failed to clear session for {conversation_id}:{run_id}: {str(e)}"
+            )
+            # Continue anyway - the important part (revocation) is done
 
         return {
             "status": "success",
-            "message": "Cancellation signal sent to background task",
+            "message": "Cancellation signal sent and task revoked",
         }
 
     async def rename_conversation(
diff --git a/app/modules/conversations/conversations_router.py b/app/modules/conversations/conversations_router.py
index 8a1f7ce8..123c896d 100644
--- a/app/modules/conversations/conversations_router.py
+++ b/app/modules/conversations/conversations_router.py
@@ -355,7 +355,7 @@ async def post_message(
     )
 
     # Start background task
-    execute_agent_background.delay(
+    task_result = execute_agent_background.delay(
         conversation_id=conversation_id,
         run_id=run_id,
         user_id=user_id,
@@ -365,6 +365,12 @@ async def post_message(
         attachment_ids=attachment_ids or [],
     )
 
+    # Store the Celery task ID for later revocation
+    redis_manager.set_task_id(conversation_id, run_id, task_result.id)
+    logger.info(
+        f"Started agent task {task_result.id} for {conversation_id}:{run_id}"
+    )
+
     # Wait for background task to start (with health check)
     # Increased timeout to 30 seconds to handle queued tasks
     task_started = redis_manager.wait_for_task_start(
@@ -484,7 +490,7 @@ async def regenerate_last_message(
         },
     )
 
-    execute_regenerate_background.delay(
+    task_result = execute_regenerate_background.delay(
         conversation_id=conversation_id,
         run_id=run_id,
         user_id=user_id,
@@ -492,6 +498,12 @@ async def regenerate_last_message(
         attachment_ids=attachment_ids,
     )
 
+    # Store the Celery task ID for later revocation
+    redis_manager.set_task_id(conversation_id, run_id, task_result.id)
+    logger.info(
+        f"Started regenerate task {task_result.id} for {conversation_id}:{run_id}"
+    )
+
     # Wait for background task to start (with health check)
     # Increased timeout to 30 seconds to handle queued tasks
     task_started = redis_manager.wait_for_task_start(
diff --git a/app/modules/conversations/utils/redis_streaming.py b/app/modules/conversations/utils/redis_streaming.py
index 0f9ca9e4..c5274309 100644
--- a/app/modules/conversations/utils/redis_streaming.py
+++ b/app/modules/conversations/utils/redis_streaming.py
@@ -197,6 +197,41 @@ def get_task_status(self, conversation_id: str, run_id: str) -> Optional[str]:
         status = self.redis_client.get(status_key)
         return status.decode() if status else None
 
+    def set_task_id(self, conversation_id: str, run_id: str, task_id: str) -> None:
+        """Store Celery task ID for this conversation/run"""
+        task_id_key = f"task:id:{conversation_id}:{run_id}"
+        self.redis_client.set(task_id_key, task_id, ex=600)  # 10 minute expiry
+        logger.debug(f"Stored task ID {task_id} for {conversation_id}:{run_id}")
+
+    def get_task_id(self, conversation_id: str, run_id: str) -> Optional[str]:
+        """Get Celery task ID for this conversation/run"""
+        task_id_key = f"task:id:{conversation_id}:{run_id}"
+        task_id = self.redis_client.get(task_id_key)
+        return task_id.decode() if task_id else None
+
+    def clear_session(self, conversation_id: str, run_id: str) -> None:
+        """Clear session data when stopping - publishes end event and cleans up"""
+        try:
+            # Publish an end event with cancelled status so clients know to stop
+            self.publish_event(
+                conversation_id,
+                run_id,
+                "end",
+                {
+                    "status": "cancelled",
+                    "message": "Generation stopped by user",
+                },
+            )
+
+            # Set task status to cancelled
+            self.set_task_status(conversation_id, run_id, "cancelled")
+
+            logger.info(f"Cleared session for {conversation_id}:{run_id}")
+        except Exception as e:
+            logger.error(
+                f"Failed to clear session for {conversation_id}:{run_id}: {str(e)}"
+            )
+
     def wait_for_task_start(
         self, conversation_id: str, run_id: str, timeout: int = 10
     ) -> bool: