Skip to content

Commit 54517d8

Browse files
committed
Touch parents updated_at when updating a child && update tests
1 parent 718a686 commit 54517d8

File tree

10 files changed

+511
-48
lines changed

10 files changed

+511
-48
lines changed

packages/examples/cvat/exchange-oracle/src/crons/cvat/state_trackers.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,11 @@ def track_completed_tasks(logger: logging.Logger, session: Session) -> None:
7777
completed_task_ids.append(task.cvat_id)
7878

7979
if completed_task_ids:
80+
# TODO: for_update
81+
cvat_service.touch_projects(
82+
session, {t.project.id for t in tasks if t.cvat_id in completed_task_ids}
83+
)
84+
8085
logger.info(
8186
"Found new completed tasks: {}".format(", ".join(str(t) for t in completed_task_ids))
8287
)
@@ -118,6 +123,16 @@ def track_assignments(logger: logging.Logger) -> None:
118123

119124
cvat_service.expire_assignment(session, assignment.id)
120125

126+
jobs_to_be_updated = [a.job for a in assignments]
127+
tasks_to_be_updated = [j.task for j in jobs_to_be_updated]
128+
projects_to_be_updated = [t.project for t in tasks_to_be_updated]
129+
cvat_service.touch_jobs(session, {job.id for job in jobs_to_be_updated})
130+
cvat_service.touch_tasks(session, {task.id for task in tasks_to_be_updated})
131+
cvat_service.touch_projects(session, {project.id for project in projects_to_be_updated})
132+
del jobs_to_be_updated
133+
del tasks_to_be_updated
134+
del projects_to_be_updated
135+
121136
with SessionLocal.begin() as session:
122137
assignments = cvat_service.get_active_assignments(
123138
session,
@@ -148,6 +163,14 @@ def track_assignments(logger: logging.Logger) -> None:
148163

149164
cvat_service.cancel_assignment(session, assignment.id)
150165

166+
# touch jobs/tasks/projects updated_at
167+
jobs_to_be_updated = [a.job for a in assignments]
168+
tasks_to_be_updated = [j.task for j in jobs_to_be_updated]
169+
projects_to_be_updated = [t.project for t in tasks_to_be_updated]
170+
cvat_service.touch_jobs(session, {job.id for job in jobs_to_be_updated})
171+
cvat_service.touch_tasks(session, {task.id for task in tasks_to_be_updated})
172+
cvat_service.touch_projects(session, {project.id for project in projects_to_be_updated})
173+
151174

152175
@cron_job
153176
def track_completed_escrows(logger: logging.Logger) -> None:
@@ -213,6 +236,7 @@ def track_task_creation(logger: logging.Logger, session: Session) -> None:
213236
)
214237

215238
completed.append(upload)
239+
upload.task.touch(session, touch_parent=True)
216240
except cvat_api.exceptions.ApiException as e:
217241
failed.append(upload)
218242

packages/examples/cvat/exchange-oracle/src/endpoints/exchange.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,9 @@ class JobsFilter(Filter):
5656

5757
class SortingFields(StrEnum, metaclass=BetterEnumMeta):
5858
created_at = auto()
59+
updated_at = auto()
5960
chain_id = auto()
6061
job_type = auto()
61-
reward_amount = auto()
6262

6363
sort: OptionalQuery[OrderingDirection] = OrderingDirection.asc
6464
default_sort_field: ClassVar[SortingFields] = SortingFields.created_at
@@ -227,7 +227,6 @@ class SortingFields(StrEnum, metaclass=BetterEnumMeta):
227227
chain_id = auto()
228228
job_type = auto()
229229
status = auto()
230-
reward_amount = auto()
231230
created_at = auto()
232231
expires_at = auto()
233232

packages/examples/cvat/exchange-oracle/src/endpoints/serializers.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from typing import Literal
33

44
from human_protocol_sdk.storage import StorageFileNotFoundError
5-
from sqlalchemy import func
65
from sqlalchemy.orm import Session
76

87
import src.services.cvat as cvat_service
@@ -54,11 +53,6 @@ def serialize_job(
5453
else:
5554
raise NotImplementedError(f"Unknown status {project.status}")
5655

57-
updated_at = (
58-
session.query(func.max(cvat_service.Job.updated_at))
59-
.filter(cvat_service.Job.cvat_project_id == project.cvat_id)
60-
.scalar()
61-
)
6256
return service_api.JobResponse(
6357
escrow_address=project.escrow_address,
6458
chain_id=project.chain_id,
@@ -70,7 +64,7 @@ def serialize_job(
7064
service_api.DEFAULT_TOKEN
7165
), # set a value to avoid being excluded by response_model_exclude_unset=True
7266
created_at=project.created_at,
73-
updated_at=updated_at,
67+
updated_at=project.updated_at,
7468
qualifications=manifest.qualifications,
7569
)
7670

packages/examples/cvat/exchange-oracle/src/handlers/cvat_events.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def handle_update_job_event(payload: dict) -> None:
6666
"Assignment is expired, rejecting the update"
6767
)
6868
cvat_service.expire_assignment(session, matching_assignment.id)
69+
matching_assignment.job.touch(session, touch_parent=True)
6970

7071
if matching_assignment.id == latest_assignment.id:
7172
cvat_api.update_job_assignee(job.cvat_id, assignee_id=None)
@@ -88,6 +89,7 @@ def handle_update_job_event(payload: dict) -> None:
8889
session, matching_assignment.id, completed_at=webhook_time
8990
)
9091
cvat_service.update_job_status(session, job.id, new_status)
92+
job.task.touch(session, touch_parent=True)
9193

9294
cvat_api.update_job_assignee(job.cvat_id, assignee_id=None)
9395

@@ -122,6 +124,12 @@ def handle_create_job_event(payload: dict) -> None:
122124
payload.job["project_id"],
123125
status=JobStatuses[payload.job["state"]],
124126
)
127+
cvat_service.touch_tasks(
128+
session, [payload.job["task_id"]], field=cvat_service.Task.cvat_id
129+
)
130+
cvat_service.touch_projects(
131+
session, [payload.job["project_id"]], field=cvat_service.Project.cvat_id
132+
)
125133

126134
try:
127135
projects = cvat_service.get_projects_by_cvat_ids(

packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,7 @@ def build(self):
358358
)
359359

360360
db_service.create_data_upload(session, cvat_task.id)
361+
db_service.touch_projects(session, [project_id])
361362

362363

363364
class BoxesFromPointsTaskBuilder:
@@ -1394,6 +1395,7 @@ def _create_on_cvat(self):
13941395
)
13951396

13961397
db_service.create_data_upload(session, cvat_task.id)
1398+
db_service.touch_projects(session, [project_id])
13971399

13981400
@classmethod
13991401
def _make_cloud_storage_client(cls, bucket_info: BucketAccessInfo) -> StorageClient:
@@ -2466,6 +2468,7 @@ def _job_params_label_key(ts):
24662468
)
24672469

24682470
db_service.create_data_upload(session, cvat_task.id)
2471+
db_service.touch_projects(session, [project_id])
24692472

24702473
@classmethod
24712474
def _make_cloud_storage_client(cls, bucket_info: BucketAccessInfo) -> StorageClient:

packages/examples/cvat/exchange-oracle/src/models/cvat.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from __future__ import annotations
33

44
from sqlalchemy import Column, DateTime, Enum, ForeignKey, Integer, String, UniqueConstraint
5-
from sqlalchemy.orm import Mapped, relationship
5+
from sqlalchemy.orm import Mapped, Session, relationship
66
from sqlalchemy.sql import func
77

88
from src.core.types import (
@@ -65,6 +65,9 @@ class Project(Base):
6565
def __repr__(self) -> str:
6666
return f"Project. id={self.id}"
6767

68+
def touch(self, session: Session) -> None:
69+
session.query(Project).filter(Project.id == self.id).update({Project.updated_at: utcnow()})
70+
6871

6972
class Task(Base):
7073
__tablename__ = "tasks"
@@ -92,6 +95,12 @@ class Task(Base):
9295
def __repr__(self) -> str:
9396
return f"Task. id={self.id}"
9497

98+
def touch(self, session: Session, *, touch_parent: bool = True) -> None:
99+
session.query(Task).filter(Task.id == self.id).update({Task.updated_at: utcnow()})
100+
101+
if touch_parent:
102+
self.project.touch(session)
103+
95104

96105
class EscrowCreation(Base):
97106
__tablename__ = "escrow_creations"
@@ -170,6 +179,13 @@ def latest_assignment(self) -> Assignment | None:
170179
def __repr__(self) -> str:
171180
return f"Job. id={self.id}"
172181

182+
def touch(self, session: Session, *, touch_parent: bool = True) -> None:
183+
# TODO: check .update({})
184+
session.query(Job).filter(Job.id == self.id).update({Job.updated_at: utcnow()})
185+
186+
if touch_parent:
187+
self.task.touch(session, touch_parent=touch_parent)
188+
173189

174190
class User(Base):
175191
__tablename__ = "users"

packages/examples/cvat/exchange-oracle/src/services/cvat.py

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import itertools
22
import uuid
3-
from collections.abc import Sequence
3+
from collections.abc import Iterable, Sequence
44
from datetime import datetime
5+
from itertools import islice
6+
from typing import Any
57

6-
from sqlalchemy import delete, insert, update
8+
from sqlalchemy import Column, delete, insert, update
79
from sqlalchemy.orm import Session
810

911
from src.core.types import AssignmentStatuses, JobStatuses, ProjectStatuses, TaskStatuses, TaskTypes
@@ -13,6 +15,13 @@
1315
from src.utils.time import utcnow
1416

1517

18+
def batched(iterable: Iterable, *, batch_size: int) -> Iterable[Any]:
19+
assert batch_size > 0
20+
iterator = iter(iterable)
21+
while batch := tuple(islice(iterator, batch_size)):
22+
yield batch
23+
24+
1625
# Project
1726
def create_project(
1827
session: Session,
@@ -225,6 +234,20 @@ def update_project_statuses_by_escrow_address(
225234
session.execute(statement)
226235

227236

237+
def touch_projects(
238+
session: Session,
239+
list_to_update: Iterable[str | int],
240+
*,
241+
field: Column = Project.id,
242+
batch_size: int = 1000,
243+
) -> None:
244+
assert field is Project.id or field is Project.cvat_id
245+
updated_at = utcnow()
246+
for batch in batched(list_to_update, batch_size=batch_size):
247+
stmt = update(Project).where(field.in_(batch)).values({Project.updated_at: updated_at})
248+
session.execute(stmt)
249+
250+
228251
def delete_project(session: Session, project_id: str) -> None:
229252
project = session.query(Project).filter_by(id=project_id).first()
230253
session.delete(project)
@@ -392,11 +415,25 @@ def get_tasks_by_status(
392415
return query.all()
393416

394417

395-
def update_task_status(session: Session, task_id: int, status: TaskStatuses) -> None:
418+
def update_task_status(session: Session, task_id: str, status: TaskStatuses) -> None:
396419
upd = update(Task).where(Task.id == task_id).values(status=status.value)
397420
session.execute(upd)
398421

399422

423+
def touch_tasks(
424+
session: Session,
425+
list_to_update: Iterable[str | int],
426+
*,
427+
field: Column = Task.id,
428+
batch_size: int = 1000,
429+
) -> None:
430+
assert field is Task.id or field is Task.cvat_id
431+
updated_at = utcnow()
432+
for batch in batched(list_to_update, batch_size=batch_size):
433+
stmt = update(Task).where(field.in_(batch)).values({Task.updated_at: updated_at})
434+
session.execute(stmt)
435+
436+
400437
def get_tasks_by_cvat_project_id(
401438
session: Session, cvat_project_id: int, *, for_update: bool | ForUpdateParams = False
402439
) -> list[Task]:
@@ -486,11 +523,25 @@ def get_jobs_by_cvat_id(
486523
)
487524

488525

489-
def update_job_status(session: Session, job_id: int, status: JobStatuses) -> None:
526+
def update_job_status(session: Session, job_id: str, status: JobStatuses) -> None:
490527
upd = update(Job).where(Job.id == job_id).values(status=status.value)
491528
session.execute(upd)
492529

493530

531+
def touch_jobs(
532+
session: Session,
533+
list_to_update: Iterable[str | int],
534+
*,
535+
field: Column = Job.id,
536+
batch_size: int = 1000,
537+
) -> None:
538+
assert field is Job.id or field is Job.cvat_id
539+
updated_at = utcnow()
540+
for batch in batched(list_to_update, batch_size=batch_size):
541+
stmt = update(Job).where(field.in_(batch)).values({Job.updated_at: updated_at})
542+
session.execute(stmt)
543+
544+
494545
def get_jobs_by_cvat_task_id(
495546
session: Session, cvat_task_id: int, *, for_update: bool | ForUpdateParams = False
496547
) -> list[Job]:

packages/examples/cvat/exchange-oracle/src/services/exchange.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ def create_assignment(escrow_address: str, chain_id: Networks, wallet_address: s
6262
if not unassigned_job:
6363
return None
6464

65+
cvat_service.get_task_by_id(
66+
session, unassigned_job.task.id, for_update=True
67+
) # lock the row
68+
6569
assignment_id = cvat_service.create_assignment(
6670
session,
6771
wallet_address=user.wallet_address,
@@ -70,6 +74,8 @@ def create_assignment(escrow_address: str, chain_id: Networks, wallet_address: s
7074
+ timedelta(seconds=get_default_assignment_timeout(TaskTypes(project.job_type))),
7175
)
7276

77+
unassigned_job.touch(session) # project|task|job rows are locked for update
78+
7379
with cvat_api.api_client_context(cvat_api.get_api_client()):
7480
cvat_api.clear_job_annotations(unassigned_job.cvat_id)
7581
cvat_api.restart_job(unassigned_job.cvat_id, assignee_id=user.cvat_id)
@@ -104,3 +110,12 @@ async def resign_assignment(assignment_id: str, wallet_address: str) -> None:
104110
raise NoAccessError
105111

106112
cvat_service.cancel_assignment(session, assignment_id)
113+
114+
job = assignment.job
115+
task = job.task
116+
project = task.project
117+
cvat_service.get_job_by_id(session, job.id, for_update=True) # lock the row
118+
cvat_service.get_task_by_id(session, task.id, for_update=True) # lock the row
119+
cvat_service.get_project_by_id(session, project.id, for_update=True) # lock the row
120+
121+
assignment.job.touch(session) # project|task rows are locked for update

0 commit comments

Comments
 (0)