diff --git a/packages/sdk/python/agent_protocol/agent.py b/packages/sdk/python/agent_protocol/agent.py index be0fd146..d0ead730 100644 --- a/packages/sdk/python/agent_protocol/agent.py +++ b/packages/sdk/python/agent_protocol/agent.py @@ -131,6 +131,7 @@ async def execute_agent_task_step( step = await _step_handler(step) step.status = Status.completed + await Agent.db.update_step(task_id, step) return step diff --git a/packages/sdk/python/agent_protocol/db.py b/packages/sdk/python/agent_protocol/db.py index 24efb732..4e94a96e 100644 --- a/packages/sdk/python/agent_protocol/db.py +++ b/packages/sdk/python/agent_protocol/db.py @@ -65,6 +65,15 @@ async def get_step(self, task_id: str, step_id: str) -> Step: async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact: raise NotImplementedError + async def update_task(self, task: Task) -> Task: + raise NotImplementedError + + async def update_step(self, task_id: str, step: Step) -> Step: + raise NotImplementedError + + async def update_artifact(self, task_id: str, artifact: Artifact) -> Artifact: + raise NotImplementedError + async def list_tasks(self) -> List[Task]: raise NotImplementedError @@ -99,6 +108,18 @@ async def create_task( self._tasks[task_id] = task return task + async def get_task(self, task_id: str) -> Task: + task = self._tasks.get(task_id, None) + if not task: + raise NotFoundException("Task", task_id) + return task + + async def update_task(self, task: Task) -> Task: + if self._tasks.get(task.task_id, None) is None: + raise NotFoundException("Task", task.task_id) + self._tasks[task.task_id] = task + return task + async def create_step( self, task_id: str, @@ -123,12 +144,6 @@ async def create_step( task.steps.append(step) return step - async def get_task(self, task_id: str) -> Task: - task = self._tasks.get(task_id, None) - if not task: - raise NotFoundException("Task", task_id) - return task - async def get_step(self, task_id: str, step_id: str) -> Step: task = await self.get_task(task_id) step = next(filter(lambda s: s.task_id == task_id, task.steps), None) @@ -136,14 +151,15 @@ async def get_step(self, task_id: str, step_id: str) -> Step: raise NotFoundException("Step", step_id) return step - async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact: + async def update_step(self, task_id: str, step: Step) -> Step: task = await self.get_task(task_id) - artifact = next( - filter(lambda a: a.artifact_id == artifact_id, task.artifacts), None - ) - if not artifact: - raise NotFoundException("Artifact", artifact_id) - return artifact + + for i, s in enumerate(task.steps): + if s.step_id == step.step_id: + task.steps[i] = step + return step + + raise NotFoundException("Step", step.step_id) async def create_artifact( self, @@ -169,6 +185,25 @@ async def create_artifact( return artifact + async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact: + task = await self.get_task(task_id) + artifact = next( + filter(lambda a: a.artifact_id == artifact_id, task.artifacts), None + ) + if not artifact: + raise NotFoundException("Artifact", artifact_id) + return artifact + + async def update_artifact(self, task_id: str, artifact: Artifact) -> Artifact: + task = await self.get_task(task_id) + + for i, a in enumerate(task.artifacts): + if a.artifact_id == artifact.artifact_id: + task.artifacts[i] = artifact + return artifact + + raise NotFoundException("Artifact", artifact.artifact_id) + async def list_tasks(self) -> List[Task]: return [task for task in self._tasks.values()]