Skip to content

Commit e8727db

Browse files
committed
[draft] Update recording oracle
1 parent 4fa8d23 commit e8727db

File tree

5 files changed

+314
-1062
lines changed

5 files changed

+314
-1062
lines changed

packages/examples/cvat/recording-oracle/src/core/annotation_meta.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
class JobMeta(BaseModel):
1010
job_id: int
11+
task_id: int
1112
annotation_filename: Path
1213
annotator_wallet_address: str
1314
assignment_id: str

packages/examples/cvat/recording-oracle/src/core/config.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,13 @@ def validate(cls) -> None:
223223
raise Exception(" ".join([ex_prefix, str(ex)]))
224224

225225

226+
class CvatConfig:
227+
cvat_url = os.environ.get("CVAT_URL", "http://localhost:8080")
228+
cvat_admin = os.environ.get("CVAT_ADMIN", "admin")
229+
cvat_admin_pass = os.environ.get("CVAT_ADMIN_PASS", "admin")
230+
cvat_org_slug = os.environ.get("CVAT_ORG_SLUG", "org1")
231+
232+
226233
class Config:
227234
port = int(os.environ.get("PORT", 8000))
228235
environment = os.environ.get("ENVIRONMENT", "development")
@@ -243,6 +250,7 @@ class Config:
243250
features = FeaturesConfig
244251
validation = ValidationConfig
245252
encryption_config = EncryptionConfig
253+
cvat_config = CvatConfig
246254

247255
@classmethod
248256
def validate(cls) -> None:
Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
import json
2+
import logging
3+
from datetime import timedelta
4+
from http import HTTPStatus
5+
from time import sleep
6+
from typing import Any
7+
8+
from cvat_sdk.api_client import ApiClient, Configuration, exceptions, models
9+
from cvat_sdk.core.helpers import get_paginated_collection
10+
11+
from src.core.config import Config
12+
from src.utils.time import utcnow
13+
14+
15+
def get_api_client() -> ApiClient:
16+
configuration = Configuration(
17+
host=Config.cvat_config.cvat_url,
18+
username=Config.cvat_config.cvat_admin,
19+
password=Config.cvat_config.cvat_admin_pass,
20+
)
21+
22+
api_client = ApiClient(configuration=configuration)
23+
api_client.set_default_header("X-organization", Config.cvat_config.cvat_org_slug)
24+
25+
return api_client
26+
27+
28+
def get_last_task_quality_report(task_id: int) -> models.QualityReport | None:
29+
with get_api_client() as api_client:
30+
paginated_result, _ = api_client.quality_api.list_reports(
31+
task_id=task_id,
32+
page_size=1,
33+
target="task",
34+
sort="-created_date",
35+
)
36+
assert len(paginated_result.results) <= 1
37+
return paginated_result.results[0] if paginated_result.results else None
38+
39+
40+
def compute_task_quality_report(
41+
task_id: int,
42+
*,
43+
max_waiting_time: int = 10 * 60,
44+
sleep_interval: float = 0.5,
45+
) -> models.QualityReport:
46+
logger = logging.getLogger("app")
47+
start_time = utcnow()
48+
49+
with get_api_client() as api_client:
50+
_, response = api_client.quality_api.create_report(
51+
quality_report_create_request={"task_id": task_id}, _parse_response=False
52+
)
53+
rq_id = json.loads(response.data).get("rq_id")
54+
assert rq_id, (
55+
"CVAT server hasn't returned rq_id in the response "
56+
f"when creating a task({task_id}) quality report"
57+
)
58+
59+
while utcnow() - start_time < timedelta(seconds=max_waiting_time):
60+
_, response = api_client.quality_api.create_report(
61+
rq_id=rq_id, _check_status=False, _parse_response=False
62+
)
63+
match response.status:
64+
case HTTPStatus.CREATED:
65+
report = models.QualityReport._from_openapi_data(**json.loads(response.data))
66+
if logger.isEnabledFor(logging.DEBUG):
67+
logger.debug(f"Created quality report: {report.id}")
68+
69+
return report
70+
case HTTPStatus.ACCEPTED:
71+
sleep(sleep_interval)
72+
continue
73+
case _:
74+
raise Exception(f"Unexpected response status: {response.status}")
75+
76+
raise Exception(f"Task({task_id}) quality report has not been created in time")
77+
78+
79+
def get_task_quality_report(
80+
task_id: int,
81+
*,
82+
max_waiting_time: int = 10 * 60,
83+
sleep_interval: float = 0.5,
84+
) -> models.QualityReport:
85+
logger = logging.getLogger("app")
86+
report = get_last_task_quality_report(task_id)
87+
if report and report.created_date > report.target_last_updated:
88+
if logger.isEnabledFor(logging.DEBUG):
89+
logger.debug(f"The latest task({task_id}) quality report({report.id}) is actual")
90+
return report
91+
92+
return compute_task_quality_report(
93+
task_id, max_waiting_time=max_waiting_time, sleep_interval=sleep_interval
94+
)
95+
96+
97+
def get_quality_report_data(report_id: int) -> dict[str, Any]:
98+
logger = logging.getLogger("app")
99+
with get_api_client() as api_client:
100+
try:
101+
_, response = api_client.quality_api.retrieve_report_data(
102+
report_id, _parse_response=False
103+
)
104+
report_data = json.loads(response.data)
105+
assert report_data
106+
return report_data
107+
108+
except exceptions.ApiException as e:
109+
logger.exception(f"Exception when calling QualityApi.retrieve_report_data: {e}\n")
110+
raise
111+
112+
113+
def get_job_validation_layout(job_id: int) -> models.JobValidationLayoutRead:
114+
logger = logging.getLogger("app")
115+
with get_api_client() as api_client:
116+
try:
117+
layout, _ = api_client.jobs_api.retrieve_validation_layout(job_id)
118+
return layout
119+
120+
except exceptions.ApiException as e:
121+
logger.exception(f"Exception when calling JobApi.retrieve_validation_layout: {e}\n")
122+
raise
123+
124+
125+
def get_task_validation_layout(task_id: int) -> models.TaskValidationLayoutRead:
126+
logger = logging.getLogger("app")
127+
with get_api_client() as api_client:
128+
try:
129+
layout, _ = api_client.tasks_api.retrieve_validation_layout(task_id)
130+
return layout
131+
132+
except exceptions.ApiException as e:
133+
logger.exception(f"Exception when calling TaskApi.retrieve_validation_layout: {e}\n")
134+
raise
135+
136+
137+
def get_jobs_quality_reports(parent_id: int) -> dict[int, models.QualityReport]:
138+
logger = logging.getLogger("app")
139+
with get_api_client() as api_client:
140+
try:
141+
reports: list[models.QualityReport] = get_paginated_collection(
142+
api_client.quality_api.list_reports_endpoint, parent_id=parent_id, target="job"
143+
)
144+
return {report.job_id: report for report in reports}
145+
146+
except exceptions.ApiException as e:
147+
logger.exception(f"Exception when calling QualityApi.list_reports: {e}\n")
148+
raise
149+
150+
151+
def get_quality_report_data(report_id: int) -> dict[str, Any]:
152+
with get_api_client() as api_client:
153+
_, response = api_client.quality_api.retrieve_report_data(report_id, _parse_response=False)
154+
assert response.status == HTTPStatus.OK
155+
return response.json()
156+
157+
158+
def shuffle_honeypots_in_jobs(job_ids: list[int] | int) -> None:
159+
logger = logging.getLogger("app")
160+
161+
if isinstance(job_ids, int):
162+
job_ids = [job_ids]
163+
164+
with get_api_client() as api_client:
165+
for job_id in job_ids:
166+
updated_validation_layout, _ = api_client.jobs_api.partial_update_validation_layout(
167+
job_id,
168+
patched_job_validation_layout_write_request=models.PatchedJobValidationLayoutWriteRequest(
169+
frame_selection_method="random_uniform",
170+
),
171+
)
172+
if logger.isEnabledFor(logging.DEBUG):
173+
logger.debug(
174+
f"Updated validation layout for the job {job_id}: {updated_validation_layout!s}"
175+
)
176+
177+
178+
def disable_validation_frames(task_id: int, *, frames_to_disable: list[int]) -> None:
179+
logger = logging.getLogger("app")
180+
with get_api_client() as api_client:
181+
task_validation_layout, _ = api_client.tasks_api.retrieve_validation_layout(task_id)
182+
disabled_frames = task_validation_layout.disabled_frames
183+
184+
# nothing to update
185+
if not (set(frames_to_disable) - set(disabled_frames)):
186+
if logger.isEnabledFor(logging.DEBUG):
187+
logger.debug(
188+
f"Validation frames {frames_to_disable!r} are already "
189+
f"disabled for the CVAT task {task_id}"
190+
)
191+
return
192+
193+
api_client.tasks_api.partial_update_validation_layout(
194+
task_id,
195+
patched_task_validation_layout_write_request=models.PatchedTaskValidationLayoutWriteRequest(
196+
disabled_frames=sorted(set(disabled_frames + frames_to_disable))
197+
),
198+
)
199+
logger.info(
200+
f"Validation frames {frames_to_disable!r} have been disabled "
201+
f"for the CVAT task {task_id}"
202+
)

0 commit comments

Comments
 (0)