Skip to content

Commit 6027079

Browse files
stop functionality & partial db flush (#489)
* stop functionality & partial db flush * chore: Auto-fix pre-commit issues --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent ce3a204 commit 6027079

File tree

5 files changed

+165
-15
lines changed

5 files changed

+165
-15
lines changed

app/celery/tasks/agent_tasks.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,21 @@ async def run_agent():
9999
logger.info(
100100
f"Agent execution cancelled: {conversation_id}:{run_id}"
101101
)
102+
103+
# Flush any buffered AI response chunks before cancelling
104+
try:
105+
message_id = service.history_manager.flush_message_buffer(
106+
conversation_id, MessageType.AI_GENERATED
107+
)
108+
if message_id:
109+
logger.debug(
110+
f"Flushed partial AI response (message_id: {message_id}) for cancelled task: {conversation_id}:{run_id}"
111+
)
112+
except Exception as e:
113+
logger.warning(
114+
f"Failed to flush message buffer on cancellation: {str(e)}"
115+
)
116+
# Continue with cancellation even if flush fails
102117
redis_manager.publish_event(
103118
conversation_id,
104119
run_id,
@@ -161,6 +176,9 @@ async def run_agent():
161176
f"Background agent execution cancelled: {conversation_id}:{run_id}"
162177
)
163178

179+
# Return the completion status so on_success can check if it was cancelled
180+
return completed
181+
164182
except Exception as e:
165183
logger.error(
166184
f"Background agent execution failed: {conversation_id}:{run_id}: {str(e)}",
@@ -215,6 +233,7 @@ async def run_regeneration():
215233
ConversationStore,
216234
)
217235
from app.modules.conversations.message.message_store import MessageStore
236+
from app.modules.conversations.message.message_model import MessageType
218237

219238
# Use BaseTask's context manager to get a fresh, non-pooled async session
220239
# This avoids asyncpg Future binding issues across tasks sharing the same event loop
@@ -261,6 +280,22 @@ async def run_regeneration():
261280
logger.info(
262281
f"Regenerate execution cancelled: {conversation_id}:{run_id}"
263282
)
283+
284+
# Flush any buffered AI response chunks before cancelling
285+
try:
286+
message_id = service.history_manager.flush_message_buffer(
287+
conversation_id, MessageType.AI_GENERATED
288+
)
289+
if message_id:
290+
logger.debug(
291+
f"Flushed partial AI response (message_id: {message_id}) for cancelled regenerate: {conversation_id}:{run_id}"
292+
)
293+
except Exception as e:
294+
logger.warning(
295+
f"Failed to flush message buffer on cancellation: {str(e)}"
296+
)
297+
# Continue with cancellation even if flush fails
298+
264299
redis_manager.publish_event(
265300
conversation_id,
266301
run_id,
@@ -332,6 +367,9 @@ async def run_regeneration():
332367
f"Background regenerate execution cancelled: {conversation_id}:{run_id}"
333368
)
334369

370+
# Return the completion status so on_success can check if it was cancelled
371+
return completed
372+
335373
except Exception as e:
336374
logger.error(
337375
f"Background regenerate execution failed: {conversation_id}:{run_id}: {str(e)}",

app/celery/tasks/base_task.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,13 @@ def run_async(self, coro):
8282
return loop.run_until_complete(coro)
8383

8484
def on_success(self, retval, task_id, args, kwargs):
    """Called on successful task execution.

    Logs how the task ended and releases the database session held by the
    task instance.

    Args:
        retval: Value returned by the task body. A literal ``False`` is
            logged as a cancelled run; any other value is logged as a
            successful completion.
        task_id: Celery task identifier, used only in the log message.
        args: Positional arguments the task was called with (unused).
        kwargs: Keyword arguments the task was called with (unused).
    """
    try:
        status = "cancelled" if retval is False else "completed successfully"
        logger.info(f"Task {task_id} {status}")
    finally:
        # Drop the reference *before* closing so that a failing close()
        # cannot leave a stale session attached to the task instance.
        db = self._db
        if db:
            self._db = None
            db.close()
9092

9193
def on_failure(self, exc, task_id, args, kwargs, einfo):
9294
"""Called on task failure."""

app/modules/conversations/conversation/conversation_service.py

Lines changed: 71 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@
4040
from app.modules.intelligence.prompts.prompt_service import PromptService
4141
from app.modules.intelligence.tools.tool_service import ToolService
4242
from app.modules.media.media_service import MediaService
43+
from app.modules.conversations.session.session_service import SessionService
44+
from app.modules.conversations.utils.redis_streaming import RedisStreamManager
45+
from app.celery.celery_app import celery_app
4346
from .conversation_store import ConversationStore, StoreError
4447
from ..message.message_store import MessageStore
4548

@@ -82,6 +85,8 @@ def __init__(
8285
agent_service: AgentsService,
8386
custom_agent_service: CustomAgentService,
8487
media_service: MediaService,
88+
session_service: SessionService = None,
89+
redis_manager: RedisStreamManager = None,
8590
):
8691
self.db = db
8792
self.user_id = user_id
@@ -96,6 +101,10 @@ def __init__(
96101
self.agent_service = agent_service
97102
self.custom_agent_service = custom_agent_service
98103
self.media_service = media_service
104+
# Dependency injection for stop_generation
105+
self.session_service = session_service or SessionService()
106+
self.redis_manager = redis_manager or RedisStreamManager()
107+
self.celery_app = celery_app
99108

100109
@classmethod
101110
def create(
@@ -116,6 +125,8 @@ def create(
116125
)
117126
custom_agent_service = CustomAgentService(db, provider_service, tool_service)
118127
media_service = MediaService(db)
128+
session_service = SessionService()
129+
redis_manager = RedisStreamManager()
119130

120131
return cls(
121132
db,
@@ -131,6 +142,8 @@ def create(
131142
agent_service,
132143
custom_agent_service,
133144
media_service,
145+
session_service,
146+
redis_manager,
134147
)
135148

136149
async def check_conversation_access(
@@ -1064,21 +1077,71 @@ async def stop_generation(
10641077
f"Attempting to stop generation for conversation {conversation_id}, run_id: {run_id}"
10651078
)
10661079

1080+
# If run_id not provided, try to find active session
10671081
if not run_id:
1068-
return {
1069-
"status": "error",
1070-
"message": "run_id required for stopping background generation",
1071-
}
1082+
from app.modules.conversations.conversation.conversation_schema import (
1083+
ActiveSessionErrorResponse,
1084+
)
1085+
1086+
active_session = self.session_service.get_active_session(conversation_id)
1087+
1088+
if isinstance(active_session, ActiveSessionErrorResponse):
1089+
# No active session found - this is okay, just return success
1090+
# The session might have already completed or been cleared
1091+
logger.info(
1092+
f"No active session found for conversation {conversation_id} - already stopped or never started"
1093+
)
1094+
return {
1095+
"status": "success",
1096+
"message": "No active session to stop",
1097+
}
1098+
1099+
run_id = active_session.sessionId
1100+
logger.info(
1101+
f"Found active session {run_id} for conversation {conversation_id}"
1102+
)
10721103

10731104
# Set cancellation flag in Redis for background task to check
1074-
from app.modules.conversations.utils.redis_streaming import RedisStreamManager
1105+
self.redis_manager.set_cancellation(conversation_id, run_id)
1106+
1107+
# Retrieve and revoke the Celery task
1108+
task_id = self.redis_manager.get_task_id(conversation_id, run_id)
1109+
1110+
if task_id:
1111+
try:
1112+
# Revoke the task - this works for both queued and running tasks:
1113+
# - For queued tasks: Prevents them from starting execution
1114+
# - For running tasks: Sends SIGTERM to terminate them
1115+
# terminate=True ensures both cases are handled
1116+
self.celery_app.control.revoke(
1117+
task_id, terminate=True, signal="SIGTERM"
1118+
)
1119+
logger.info(
1120+
f"Revoked Celery task {task_id} for {conversation_id}:{run_id} (works for both queued and running tasks)"
1121+
)
1122+
except Exception as e:
1123+
logger.warning(f"Failed to revoke Celery task {task_id}: {str(e)}")
1124+
# Continue anyway - cancellation flag is set
1125+
else:
1126+
logger.info(
1127+
f"No task ID found for {conversation_id}:{run_id} - task may have already completed or been revoked"
1128+
)
10751129

1076-
redis_manager = RedisStreamManager()
1077-
redis_manager.set_cancellation(conversation_id, run_id)
1130+
# Always clear the session - publish end event and update status
1131+
# This ensures clients know the session is stopped and prevents stale sessions
1132+
# This is important even if there's no task_id - it clears any stale session data
1133+
# This will also handle the case where stop is called with a stale session_id
1134+
try:
1135+
self.redis_manager.clear_session(conversation_id, run_id)
1136+
except Exception as e:
1137+
logger.warning(
1138+
f"Failed to clear session for {conversation_id}:{run_id}: {str(e)}"
1139+
)
1140+
# Continue anyway - the important part (revocation) is done
10781141

10791142
return {
10801143
"status": "success",
1081-
"message": "Cancellation signal sent to background task",
1144+
"message": "Cancellation signal sent and task revoked",
10821145
}
10831146

10841147
async def rename_conversation(

app/modules/conversations/conversations_router.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ async def post_message(
355355
)
356356

357357
# Start background task
358-
execute_agent_background.delay(
358+
task_result = execute_agent_background.delay(
359359
conversation_id=conversation_id,
360360
run_id=run_id,
361361
user_id=user_id,
@@ -365,6 +365,12 @@ async def post_message(
365365
attachment_ids=attachment_ids or [],
366366
)
367367

368+
# Store the Celery task ID for later revocation
369+
redis_manager.set_task_id(conversation_id, run_id, task_result.id)
370+
logger.info(
371+
f"Started agent task {task_result.id} for {conversation_id}:{run_id}"
372+
)
373+
368374
# Wait for background task to start (with health check)
369375
# Increased timeout to 30 seconds to handle queued tasks
370376
task_started = redis_manager.wait_for_task_start(
@@ -484,14 +490,20 @@ async def regenerate_last_message(
484490
},
485491
)
486492

487-
execute_regenerate_background.delay(
493+
task_result = execute_regenerate_background.delay(
488494
conversation_id=conversation_id,
489495
run_id=run_id,
490496
user_id=user_id,
491497
node_ids=request.node_ids or [],
492498
attachment_ids=attachment_ids,
493499
)
494500

501+
# Store the Celery task ID for later revocation
502+
redis_manager.set_task_id(conversation_id, run_id, task_result.id)
503+
logger.info(
504+
f"Started regenerate task {task_result.id} for {conversation_id}:{run_id}"
505+
)
506+
495507
# Wait for background task to start (with health check)
496508
# Increased timeout to 30 seconds to handle queued tasks
497509
task_started = redis_manager.wait_for_task_start(

app/modules/conversations/utils/redis_streaming.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,41 @@ def get_task_status(self, conversation_id: str, run_id: str) -> Optional[str]:
197197
status = self.redis_client.get(status_key)
198198
return status.decode() if status else None
199199

200+
def set_task_id(
    self,
    conversation_id: str,
    run_id: str,
    task_id: str,
    ttl_seconds: int = 600,
) -> None:
    """Store the Celery task ID for this conversation/run.

    The ID is written under ``task:id:{conversation_id}:{run_id}`` so it can
    be looked up again with ``get_task_id``.

    Args:
        conversation_id: Conversation the run belongs to.
        run_id: Identifier of the run within the conversation.
        task_id: Celery task ID to remember.
        ttl_seconds: Expiry of the key in seconds. Defaults to 600
            (10 minutes); callers with longer-running tasks can raise it so
            the ID is still available when a stop is requested.
    """
    task_id_key = f"task:id:{conversation_id}:{run_id}"
    self.redis_client.set(task_id_key, task_id, ex=ttl_seconds)
    logger.debug(f"Stored task ID {task_id} for {conversation_id}:{run_id}")
205+
206+
def get_task_id(self, conversation_id: str, run_id: str) -> Optional[str]:
    """Get the Celery task ID for this conversation/run.

    Reads the key ``task:id:{conversation_id}:{run_id}``.

    Args:
        conversation_id: Conversation the run belongs to.
        run_id: Identifier of the run within the conversation.

    Returns:
        The stored task ID as ``str``, or ``None`` when the key is missing
        (or holds an empty value).
    """
    task_id_key = f"task:id:{conversation_id}:{run_id}"
    task_id = self.redis_client.get(task_id_key)
    if not task_id:
        return None
    # The client returns bytes by default, but str when it was created with
    # decode_responses=True; accept both instead of failing on .decode().
    return task_id.decode() if isinstance(task_id, bytes) else task_id
211+
212+
def clear_session(self, conversation_id: str, run_id: str) -> None:
    """Mark a run as stopped: publish an ``end`` event and set its status.

    The two steps are attempted independently, so a failure to publish the
    event does not prevent the task status from being set to ``cancelled``.
    Errors are logged and never propagated to the caller.

    Args:
        conversation_id: Conversation the run belongs to.
        run_id: Identifier of the run being stopped.
    """
    succeeded = True

    # Publish an end event with cancelled status so clients know to stop
    try:
        self.publish_event(
            conversation_id,
            run_id,
            "end",
            {
                "status": "cancelled",
                "message": "Generation stopped by user",
            },
        )
    except Exception as e:
        succeeded = False
        logger.error(
            f"Failed to clear session for {conversation_id}:{run_id}: {str(e)}"
        )

    # Set task status to cancelled, even if the end event could not be sent
    try:
        self.set_task_status(conversation_id, run_id, "cancelled")
    except Exception as e:
        succeeded = False
        logger.error(
            f"Failed to clear session for {conversation_id}:{run_id}: {str(e)}"
        )

    if succeeded:
        logger.info(f"Cleared session for {conversation_id}:{run_id}")
234+
200235
def wait_for_task_start(
201236
self, conversation_id: str, run_id: str, timeout: int = 10
202237
) -> bool:

0 commit comments

Comments
 (0)