Skip to content

Commit 878d76f

Browse files
authored
[CVAT] Oracle API updates (#2812)
* Hide escrows in internal states from /job output * Update final assignments on escrow finish
1 parent ed6f783 commit 878d76f

File tree

7 files changed

+171
-41
lines changed

7 files changed

+171
-41
lines changed

packages/examples/cvat/exchange-oracle/src/crons/webhooks/recording_oracle.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,12 @@ def handle_recording_oracle_event(webhook: Webhook, *, db_session: Session, logg
8282

8383
cvat_db_service.update_project_status(db_session, project.id, new_status)
8484

85+
cvat_db_service.touch_final_assignments(
86+
db_session,
87+
cvat_project_ids=[p.cvat_id for p in projects_chunk],
88+
touch_parents=True,
89+
)
90+
8591
cvat_db_service.update_escrow_validation(
8692
db_session,
8793
webhook.escrow_address,

packages/examples/cvat/exchange-oracle/src/endpoints/exchange.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from src.endpoints.pagination import Page, paginate
2828
from src.endpoints.serializers import (
2929
ASSIGNMENT_PROJECT_VALIDATION_STATUSES,
30-
PROJECT_ACTIVE_STATUSES,
3130
PROJECT_COMPLETED_STATUSES,
3231
serialize_assignment,
3332
serialize_job,
@@ -117,8 +116,16 @@ async def list_jobs(
117116

118117
query = select(cvat_service.Project)
119118

119+
# These states are internal, they should not be visible through the API
120120
query = query.filter(
121-
cvat_service.Project.status.not_in([ProjectStatuses.creation, ProjectStatuses.deleted])
121+
cvat_service.Project.status.not_in(
122+
[
123+
ProjectStatuses.creation,
124+
ProjectStatuses.deleted,
125+
ProjectStatuses.completed,
126+
ProjectStatuses.validation,
127+
]
128+
)
122129
)
123130

124131
# We need only high-level jobs (i.e. escrows) without project details
@@ -145,7 +152,7 @@ async def list_jobs(
145152
if status:
146153
match status:
147154
case JobStatuses.active:
148-
query = query.filter(cvat_service.Project.status.in_(PROJECT_ACTIVE_STATUSES))
155+
query = query.filter(cvat_service.Project.status == ProjectStatuses.annotation)
149156
case JobStatuses.canceled:
150157
query = query.filter(
151158
cvat_service.Project.status == cvat_service.ProjectStatuses.canceled

packages/examples/cvat/exchange-oracle/src/endpoints/serializers.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,6 @@
1212
from src.schemas import exchange as service_api
1313
from src.utils.assignments import compose_assignment_url, parse_manifest
1414

15-
PROJECT_ACTIVE_STATUSES = {
16-
ProjectStatuses.annotation,
17-
ProjectStatuses.completed,
18-
ProjectStatuses.validation,
19-
}
2015
PROJECT_COMPLETED_STATUSES = {
2116
ProjectStatuses.recorded,
2217
ProjectStatuses.deleted,
@@ -50,12 +45,12 @@ def serialize_job(
5045

5146
if project.status == ProjectStatuses.canceled:
5247
api_status = service_api.JobStatuses.canceled
53-
elif project.status in PROJECT_ACTIVE_STATUSES:
48+
elif project.status == ProjectStatuses.annotation:
5449
api_status = service_api.JobStatuses.active
5550
elif project.status in PROJECT_COMPLETED_STATUSES:
5651
api_status = service_api.JobStatuses.completed
5752
else:
58-
raise NotImplementedError(f"Unknown status {project.status}")
53+
raise AssertionError(f"Unexpected project status '{project.status}'")
5954

6055
return service_api.JobResponse(
6156
escrow_address=project.escrow_address,

packages/examples/cvat/exchange-oracle/src/models/cvat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def __repr__(self) -> str:
222222
return f"User. wallet_address={self.wallet_address} cvat_id={self.cvat_id}"
223223

224224

225-
class Assignment(BaseUUID):
225+
class Assignment(ChildOf[Job]):
226226
__tablename__ = "assignments"
227227
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
228228
updated_at = Column(

packages/examples/cvat/exchange-oracle/src/services/cvat.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -979,7 +979,19 @@ def touch(
979979
time = utcnow()
980980

981981
session.execute(update(cls).where(cls.id.in_(ids)).values({cls.updated_at: time}))
982-
while touch_parents and issubclass(cls, ChildOf):
982+
983+
if touch_parents:
984+
touch_parent_objects(session, cls, ids, time=time)
985+
986+
987+
def touch_parent_objects(
988+
session: Session,
989+
cls: type["Base"],
990+
ids: list[str],
991+
*,
992+
time: datetime | None = None,
993+
):
994+
while issubclass(cls, ChildOf):
983995
parent_cls = cls.parent_cls
984996
foreign_key_column = next(iter(cls.parent.property.local_columns))
985997
parent_id_column = next(iter(foreign_key_column.foreign_keys)).column
@@ -997,3 +1009,49 @@ def touch(
9971009
)
9981010
ids = session.execute(parent_update_stmt).scalars().all()
9991011
cls = parent_cls
1012+
1013+
1014+
def touch_final_assignments(
1015+
session: Session,
1016+
cvat_project_ids: list[int],
1017+
*,
1018+
touch_parents: bool = True,
1019+
time: datetime | None = None,
1020+
) -> None:
1021+
if time is None:
1022+
time = utcnow()
1023+
1024+
last_assignment_time_per_job_id_subquery = (
1025+
select(
1026+
Assignment.cvat_job_id.label("cvat_job_id"),
1027+
func.max(Assignment.created_at).label("max_created_at"),
1028+
)
1029+
.join(Job)
1030+
.where(
1031+
Assignment.status == AssignmentStatuses.completed,
1032+
Job.cvat_project_id.in_(cvat_project_ids),
1033+
)
1034+
.order_by(Assignment.cvat_job_id)
1035+
.group_by(Assignment.cvat_job_id)
1036+
.subquery()
1037+
)
1038+
1039+
last_assignment_ids_query = (
1040+
select(Assignment.id)
1041+
.join(
1042+
last_assignment_time_per_job_id_subquery,
1043+
Assignment.cvat_job_id == last_assignment_time_per_job_id_subquery.c.cvat_job_id,
1044+
isouter=True,
1045+
)
1046+
.where(Assignment.created_at == last_assignment_time_per_job_id_subquery.c.max_created_at)
1047+
)
1048+
1049+
ids = session.execute(
1050+
update(Assignment)
1051+
.where(Assignment.id.in_(last_assignment_ids_query))
1052+
.values({Assignment.updated_at: time})
1053+
.returning(Assignment.id)
1054+
)
1055+
1056+
if touch_parents:
1057+
touch_parent_objects(session, Assignment, ids.scalars().all(), time=time)

packages/examples/cvat/exchange-oracle/tests/api/test_exchange_api.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -406,9 +406,7 @@ def test_can_list_jobs_200_with_filters(client: TestClient, session: Session):
406406
ProjectStatuses.validation,
407407
ProjectStatuses.canceled,
408408
ProjectStatuses.recorded,
409-
# TODO: ProjectStatuses.deleted,
410-
# raise NotImplementedError(f"Unknown status {project.status}")
411-
# NotImplementedError: Unknown status deleted
409+
ProjectStatuses.deleted,
412410
]
413411
):
414412
cvat_project = create_project(
@@ -437,7 +435,20 @@ def test_can_list_jobs_200_with_filters(client: TestClient, session: Session):
437435
session.add(assignment)
438436
assignments.append(assignment)
439437
cvat.touch(session, Job, [cvat_job.id])
440-
session.commit() # imitate different created_dates
438+
session.commit() # TODO: imitate different created_dates
439+
440+
visible_projects_ids = set(
441+
p.cvat_id
442+
for p in cvat_projects
443+
if p.status
444+
not in [
445+
ProjectStatuses.creation,
446+
ProjectStatuses.completed,
447+
ProjectStatuses.validation,
448+
ProjectStatuses.deleted,
449+
]
450+
)
451+
visible_projects_count = len(visible_projects_ids)
441452

442453
middle_init_time = utcnow()
443454

@@ -460,23 +471,28 @@ def test_can_list_jobs_200_with_filters(client: TestClient, session: Session):
460471
"status": (
461472
(
462473
APIJobStatuses.active.value,
463-
3,
464-
), # ProjectStatuses::annotation,completed,validation
465-
(APIJobStatuses.completed.value, 1),
466-
(APIJobStatuses.canceled.value, 1),
474+
1,
475+
# ProjectStatuses.annotation
476+
# completed, validation are internal, so hidden
477+
),
478+
(APIJobStatuses.completed.value, 1), # ProjectStatuses.recorded
479+
(APIJobStatuses.canceled.value, 1), # ProjectStatuses.canceled
467480
),
468-
"chain_id": ((cvat_projects[0].chain_id, len(cvat_projects)),),
481+
"chain_id": ((cvat_projects[0].chain_id, visible_projects_count),),
469482
"job_type": (
470-
(cvat_projects[0].job_type, len(cvat_projects)),
483+
(cvat_projects[0].job_type, visible_projects_count),
471484
(TaskTypes.image_boxes_from_points.value, 0),
472485
),
473486
"created_after": (
474-
(str(pre_init_time - timedelta(minutes=1)), len(cvat_projects)),
487+
(str(pre_init_time - timedelta(minutes=1)), visible_projects_count),
475488
(str(post_init_time), 0),
476489
),
477490
"updated_after": (
478-
(str(pre_init_time - timedelta(minutes=1)), len(cvat_projects)),
479-
(str(middle_init_time), len(updated_cvat_project_ids)),
491+
(str(pre_init_time - timedelta(minutes=1)), visible_projects_count),
492+
(
493+
str(middle_init_time),
494+
len(visible_projects_ids.intersection(updated_cvat_project_ids)),
495+
),
480496
(str(post_init_time + timedelta(minutes=1)), 0),
481497
),
482498
}.items():

packages/examples/cvat/exchange-oracle/tests/integration/cron/test_process_recording_oracle_webhooks.py

Lines changed: 64 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from src.utils.time import utcnow
3030

3131
from tests.utils.constants import DEFAULT_MANIFEST_URL, RECORDING_ORACLE_ADDRESS
32+
from tests.utils.db_helper import create_project_task_and_job
3233

3334
escrow_address = "0x86e83d346041E8806e352681f3F14549C0d2BC67"
3435
chain_id = Networks.localhost.value
@@ -42,22 +43,59 @@ def tearDown(self):
4243
self.session.close()
4344

4445
def test_process_incoming_recording_oracle_webhooks_job_completed_type(self):
45-
project_id = str(uuid.uuid4())
46-
cvat_project = Project(
47-
id=project_id,
46+
user = User(
47+
wallet_address="sample_wallet",
4848
cvat_id=1,
49-
cvat_cloudstorage_id=1,
50-
status=ProjectStatuses.validation.value,
51-
job_type=TaskTypes.image_label_binary.value,
52-
escrow_address=escrow_address,
53-
chain_id=Networks.localhost.value,
54-
bucket_url="https://test.storage.googleapis.com/",
49+
cvat_email="[email protected]",
50+
)
51+
self.session.add(user)
52+
53+
project_creation_time = utcnow() - timedelta(days=1)
54+
project_completion_time = project_creation_time + timedelta(hours=10)
55+
56+
cvat_project, cvat_task, cvat_job = create_project_task_and_job(
57+
self.session, escrow_address, 1
5558
)
59+
cvat_project.status = ProjectStatuses.validation
60+
cvat_project.created_at = project_creation_time
61+
cvat_project.updated_at = project_completion_time
5662
self.session.add(cvat_project)
5763

58-
webhok_id = str(uuid.uuid4())
64+
cvat_task.status = TaskStatuses.completed
65+
cvat_task.created_at = project_creation_time
66+
cvat_task.updated_at = project_completion_time
67+
self.session.add(cvat_task)
68+
69+
cvat_job.status = JobStatuses.completed
70+
cvat_job.created_at = project_creation_time
71+
cvat_job.updated_at = project_completion_time
72+
self.session.add(cvat_job)
73+
74+
cvat_assignment1 = Assignment(
75+
id=str(uuid.uuid4()),
76+
created_at=project_creation_time,
77+
updated_at=project_creation_time + timedelta(hours=1),
78+
expires_at=project_creation_time + timedelta(hours=1),
79+
user_wallet_address=user.wallet_address,
80+
cvat_job_id=cvat_job.cvat_id,
81+
status=AssignmentStatuses.expired,
82+
)
83+
self.session.add(cvat_assignment1)
84+
85+
cvat_assignment2 = Assignment(
86+
id=str(uuid.uuid4()),
87+
created_at=project_creation_time + timedelta(minutes=5),
88+
updated_at=project_completion_time,
89+
expires_at=project_creation_time + timedelta(minutes=5) + timedelta(hours=1),
90+
user_wallet_address=user.wallet_address,
91+
cvat_job_id=cvat_job.cvat_id,
92+
status=AssignmentStatuses.completed,
93+
)
94+
self.session.add(cvat_assignment2)
95+
96+
webhook_id = str(uuid.uuid4())
5997
webhook = Webhook(
60-
id=webhok_id,
98+
id=webhook_id,
6199
signature="signature",
62100
escrow_address=escrow_address,
63101
chain_id=chain_id,
@@ -72,15 +110,25 @@ def test_process_incoming_recording_oracle_webhooks_job_completed_type(self):
72110

73111
process_incoming_recording_oracle_webhooks()
74112

75-
updated_webhook = (
76-
self.session.execute(select(Webhook).where(Webhook.id == webhok_id)).scalars().first()
77-
)
78-
113+
updated_webhook = self.session.query(Webhook).get(webhook_id)
79114
assert updated_webhook.status == OracleWebhookStatuses.completed.value
80115
assert updated_webhook.attempts == 1
81-
db_project = self.session.query(Project).filter_by(id=project_id).first()
82116

117+
db_project = self.session.query(Project).get(cvat_project.id)
83118
assert db_project.status == ProjectStatuses.recorded.value
119+
assert db_project.updated_at > project_completion_time
120+
121+
db_task = self.session.query(Task).get(cvat_task.id)
122+
assert db_task.updated_at > project_completion_time
123+
124+
db_job = self.session.query(Job).get(cvat_job.id)
125+
assert db_job.updated_at > project_completion_time
126+
127+
db_assignment1 = self.session.query(Assignment).get(cvat_assignment1.id)
128+
assert db_assignment1.updated_at < project_completion_time
129+
130+
db_assignment2 = self.session.query(Assignment).get(cvat_assignment2.id)
131+
assert db_assignment2.updated_at > project_completion_time
84132

85133
def test_process_incoming_recording_oracle_webhooks_job_completed_type_invalid_project_status(
86134
self,

0 commit comments

Comments
 (0)