Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/sdk/python/agent_protocol/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ async def execute_agent_task_step(
step = await _step_handler(step)

step.status = Status.completed
await Agent.db.update_step(task_id, step)
return step


Expand Down
61 changes: 48 additions & 13 deletions packages/sdk/python/agent_protocol/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,15 @@ async def get_step(self, task_id: str, step_id: str) -> Step:
async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact:
raise NotImplementedError

async def update_task(self, task: Task) -> Task:
raise NotImplementedError

async def update_step(self, task_id: str, step: Step) -> Step:
raise NotImplementedError

async def update_artifact(self, task_id: str, artifact: Artifact) -> Artifact:
raise NotImplementedError

async def list_tasks(self) -> List[Task]:
raise NotImplementedError

Expand Down Expand Up @@ -99,6 +108,18 @@ async def create_task(
self._tasks[task_id] = task
return task

async def get_task(self, task_id: str) -> Task:
task = self._tasks.get(task_id, None)
if not task:
raise NotFoundException("Task", task_id)
return task

async def update_task(self, task: Task) -> Task:
if self._tasks.get(task.task_id, None) is None:
raise NotFoundException("Task", task.task_id)
self._tasks[task.task_id] = task
return task

async def create_step(
self,
task_id: str,
Expand All @@ -123,27 +144,22 @@ async def create_step(
task.steps.append(step)
return step

async def get_task(self, task_id: str) -> Task:
task = self._tasks.get(task_id, None)
if not task:
raise NotFoundException("Task", task_id)
return task

async def get_step(self, task_id: str, step_id: str) -> Step:
task = await self.get_task(task_id)
step = next(filter(lambda s: s.task_id == task_id, task.steps), None)
if not step:
raise NotFoundException("Step", step_id)
return step

async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact:
async def update_step(self, task_id: str, step: Step) -> Step:
task = await self.get_task(task_id)
artifact = next(
filter(lambda a: a.artifact_id == artifact_id, task.artifacts), None
)
if not artifact:
raise NotFoundException("Artifact", artifact_id)
return artifact

for i, s in enumerate(task.steps):
if s.step_id == step.step_id:
task.steps[i] = step
return step

raise NotFoundException("Step", step.step_id)

async def create_artifact(
self,
Expand All @@ -169,6 +185,25 @@ async def create_artifact(

return artifact

async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact:
task = await self.get_task(task_id)
artifact = next(
filter(lambda a: a.artifact_id == artifact_id, task.artifacts), None
)
if not artifact:
raise NotFoundException("Artifact", artifact_id)
return artifact

async def update_artifact(self, task_id: str, artifact: Artifact) -> Artifact:
task = await self.get_task(task_id)

for i, a in enumerate(task.artifacts):
if a.artifact_id == artifact.artifact_id:
task.artifacts[i] = artifact
return artifact

raise NotFoundException("Artifact", artifact.artifact_id)

async def list_tasks(self) -> List[Task]:
return [task for task in self._tasks.values()]

Expand Down