Skip to content

Commit 24bbc5a

Browse files
committed
[Recording oracle] Apply comments && small fixes && remove unused code
1 parent 30f84dc commit 24bbc5a

File tree

5 files changed

+58
-79
lines changed

5 files changed

+58
-79
lines changed

packages/examples/cvat/recording-oracle/src/cvat/api_calls.py

Lines changed: 39 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def get_last_task_quality_report(task_id: int) -> models.QualityReport | None:
4040
def compute_task_quality_report(
4141
task_id: int,
4242
*,
43-
max_waiting_time: int = 10 * 60,
43+
max_waiting_time: int = 60 * 60,
4444
sleep_interval: float = 0.5,
4545
) -> models.QualityReport:
4646
logger = logging.getLogger("app")
@@ -86,7 +86,10 @@ def get_task_quality_report(
8686
report = get_last_task_quality_report(task_id)
8787
if report and report.created_date > report.target_last_updated:
8888
if logger.isEnabledFor(logging.DEBUG):
89-
logger.debug(f"The latest task({task_id}) quality report({report.id}) is actual")
89+
logger.debug(
90+
f"The latest task({task_id}) quality report({report.id}) is actual. "
91+
"Do not recreate it."
92+
)
9093
return report
9194

9295
return compute_task_quality_report(
@@ -103,20 +106,8 @@ def get_quality_report_data(report_id: int) -> QualityReportData:
103106
)
104107
return QualityReportData(**response.json())
105108

106-
except exceptions.ApiException as e:
107-
logger.exception(f"Exception when calling QualityApi.retrieve_report_data: {e}\n")
108-
raise
109-
110-
111-
def get_job_validation_layout(job_id: int) -> models.JobValidationLayoutRead:
112-
logger = logging.getLogger("app")
113-
with get_api_client() as api_client:
114-
try:
115-
layout, _ = api_client.jobs_api.retrieve_validation_layout(job_id)
116-
return layout
117-
118-
except exceptions.ApiException as e:
119-
logger.exception(f"Exception when calling JobApi.retrieve_validation_layout: {e}\n")
109+
except exceptions.ApiException as ex:
110+
logger.exception(f"Exception when calling QualityApi.retrieve_report_data: {ex}\n")
120111
raise
121112

122113

@@ -125,18 +116,21 @@ def get_task_validation_layout(task_id: int) -> models.TaskValidationLayoutRead:
125116
with get_api_client() as api_client:
126117
try:
127118
layout, _ = api_client.tasks_api.retrieve_validation_layout(task_id)
119+
120+
if logger.isEnabledFor(logging.DEBUG):
121+
logger.debug(f"Retrieved validation layout: {layout}")
122+
128123
return layout
129124

130-
except exceptions.ApiException as e:
131-
logger.exception(f"Exception when calling TaskApi.retrieve_validation_layout: {e}\n")
125+
except exceptions.ApiException as ex:
126+
logger.exception(f"Exception when calling TaskApi.retrieve_validation_layout: {ex}\n")
132127
raise
133128

134129

135130
def get_jobs_quality_reports(parent_id: int) -> dict[int, models.QualityReport]:
136131
logger = logging.getLogger("app")
137132
with get_api_client() as api_client:
138133
try:
139-
# TODO: optimize
140134
reports: list[models.QualityReport] = get_paginated_collection(
141135
api_client.quality_api.list_reports_endpoint, parent_id=parent_id, target="job"
142136
)
@@ -148,16 +142,33 @@ def get_jobs_quality_reports(parent_id: int) -> dict[int, models.QualityReport]:
148142

149143

150144
def update_task_validation_layout(
151-
task_id: int, *, disabled_frames: list[int], honeypot_real_frames: list[int]
145+
task_id: int,
146+
*,
147+
disabled_frames: list[int],
148+
shuffle_honeypots: bool = True,
152149
) -> None:
153150
logger = logging.getLogger("app")
151+
params = {
152+
"disabled_frames": disabled_frames,
153+
}
154+
if shuffle_honeypots:
155+
params["frame_selection_method"] = models.FrameSelectionMethod("random_uniform")
156+
154157
with get_api_client() as api_client:
155-
api_client.tasks_api.partial_update_validation_layout(
156-
task_id,
157-
patched_task_validation_layout_write_request=models.PatchedTaskValidationLayoutWriteRequest(
158-
frame_selection_method="manual",
159-
disabled_frames=disabled_frames,
160-
honeypot_real_frames=honeypot_real_frames,
161-
),
162-
)
158+
try:
159+
validation_layout, _ = api_client.tasks_api.partial_update_validation_layout(
160+
task_id,
161+
patched_task_validation_layout_write_request=models.PatchedTaskValidationLayoutWriteRequest(
162+
**params
163+
),
164+
)
165+
except exceptions.ApiException as ex:
166+
logger.exception(
167+
f"Exception when calling TasksApi.partial_update_validation_layout: {ex}\n"
168+
)
169+
raise
170+
163171
logger.info(f"Validation layout for the task {task_id} has been updated.")
172+
173+
if logger.isEnabledFor(logging.DEBUG):
174+
logger.debug(f"Validation layout: {validation_layout}")

packages/examples/cvat/recording-oracle/src/cvat/interface.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,11 @@
44
class AnnotationInfo(BaseModel):
55
accuracy: float | int
66

7-
class Config:
8-
extra = "allow"
9-
107

118
class FrameResult(BaseModel):
129
conflicts: list[dict]
1310
annotations: AnnotationInfo
1411

15-
class Config:
16-
extra = "allow"
17-
1812

1913
class QualityReportData(BaseModel):
2014
frame_results: dict[str, FrameResult]
21-
22-
class Config:
23-
extra = "allow"

packages/examples/cvat/recording-oracle/src/handlers/process_intermediate_results.py

Lines changed: 14 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@ def __init__(
9494
chain_id: int,
9595
manifest: TaskManifest,
9696
*,
97-
job_annotations: dict[int, io.IOBase],
9897
merged_annotations: io.IOBase,
9998
meta: AnnotationMeta,
10099
gt_stats: _GtStats | None = None,
@@ -104,7 +103,6 @@ def __init__(
104103
self.manifest = manifest
105104

106105
self._initial_gt_stats: _GtStats = gt_stats or {}
107-
self._job_annotations: dict[int, io.IOBase] = job_annotations
108106
self._merged_annotations: io.IOBase = merged_annotations
109107

110108
self._updated_merged_dataset_archive: io.IOBase | None = None
@@ -147,16 +145,16 @@ def _load_job_dataset(self, job_id: int, job_dataset_path: Path) -> dm.Dataset:
147145
)
148146

149147
def _validate_jobs(self):
150-
tempdir = self._require_field(self._temp_dir)
151148
manifest = self._require_field(self.manifest)
152-
job_annotations = self._require_field(self._job_annotations)
149+
meta = self._require_field(self._meta)
153150

154151
job_results: _JobResults = {}
155152
rejected_jobs: _RejectedJobs = {}
156153
self._updated_gt_stats = {}
157154

158-
cvat_task_ids = {job_meta.task_id for job_meta in self._meta.jobs}
159-
job_id_to_task_id = {job_meta.job_id: job_meta.task_id for job_meta in self._meta.jobs}
155+
cvat_task_ids = {job_meta.task_id for job_meta in meta.jobs}
156+
cvat_job_ids = {job_meta.job_id for job_meta in meta.jobs}
157+
job_id_to_task_id = {job_meta.job_id: job_meta.task_id for job_meta in meta.jobs}
160158

161159
task_id_to_quality_report: dict[int, dict] = {}
162160
task_id_to_quality_report_data: dict[int, dict] = {}
@@ -192,12 +190,9 @@ def _validate_jobs(self):
192190

193191
job_id_to_quality_report = cvat_api.get_jobs_quality_reports(task_quality_report.id)
194192

195-
for cvat_job_id, job_annotations_file in job_annotations.items():
193+
for cvat_job_id in cvat_job_ids:
196194
cvat_task_id = job_id_to_task_id[cvat_job_id]
197195

198-
job_dataset_path = tempdir / str(cvat_job_id)
199-
extract_zip_archive(job_annotations_file, job_dataset_path)
200-
201196
job_quality_report = job_id_to_quality_report[cvat_job_id]
202197

203198
accuracy = job_quality_report.summary.accuracy
@@ -340,7 +335,6 @@ def process_intermediate_results( # noqa: PLR0912
340335
escrow_address: str,
341336
chain_id: int,
342337
meta: AnnotationMeta,
343-
job_annotations: dict[int, io.RawIOBase],
344338
merged_annotations: io.RawIOBase,
345339
manifest: TaskManifest,
346340
logger: logging.Logger,
@@ -353,6 +347,7 @@ def process_intermediate_results( # noqa: PLR0912
353347
), # should not happen, but waiting should not block processing
354348
)
355349
if not task:
350+
# Recording Oracle task represents all CVAT tasks related with the escrow
356351
task_id = db_service.create_task(session, escrow_address=escrow_address, chain_id=chain_id)
357352
task = db_service.get_task_by_id(session, task_id, for_update=True)
358353

@@ -369,7 +364,6 @@ def process_intermediate_results( # noqa: PLR0912
369364
escrow_address=escrow_address,
370365
chain_id=chain_id,
371366
manifest=manifest,
372-
job_annotations=job_annotations,
373367
merged_annotations=merged_annotations,
374368
meta=meta,
375369
gt_stats=initial_gt_stats,
@@ -398,26 +392,21 @@ def process_intermediate_results( # noqa: PLR0912
398392
for cvat_task_id, val_frame_ids in cvat_task_id_to_failed_val_frames.items():
399393
task_validation_layout = validation_result.task_id_to_val_layout[cvat_task_id]
400394
intersection = set(val_frame_ids) & set(task_validation_layout.disabled_frames)
395+
401396
if intersection:
402397
logger.error(f"Unexpected case: frames {intersection} were disabled earlier")
403398

404-
updated_disable_frames = task_validation_layout.disabled_frames + val_frame_ids
405-
not_disabled_frames = list(
406-
set(task_validation_layout.validation_frames) - set(updated_disable_frames)
407-
)
399+
upd_disabled_frames = task_validation_layout.disabled_frames + val_frame_ids
408400

409-
rng = np.random.default_rng()
410-
upd_honeypot_real_frames = [
411-
frame
412-
if frame not in updated_disable_frames
413-
else int(rng.choice(not_disabled_frames))
414-
for frame in task_validation_layout.honeypot_real_frames
415-
]
401+
shuffle_honeypots = True
402+
if set(upd_disabled_frames) == set(task_validation_layout.validation_frames):
403+
logger.error("All validation frames were banned. Honeypots will not be shuffled")
404+
shuffle_honeypots = False
416405

417406
cvat_api.update_task_validation_layout(
418407
cvat_task_id,
419-
disabled_frames=updated_disable_frames,
420-
honeypot_real_frames=upd_honeypot_real_frames,
408+
disabled_frames=upd_disabled_frames,
409+
shuffle_honeypots=shuffle_honeypots,
421410
)
422411

423412
if logger.isEnabledFor(logging.DEBUG):

packages/examples/cvat/recording-oracle/src/handlers/validation.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ def __init__(
4848
self.data_bucket = BucketAccessInfo.parse_obj(Config.exchange_oracle_storage_config)
4949

5050
self.annotation_meta: annotation.AnnotationMeta | None = None
51-
self.job_annotations: dict[int, bytes] | None = None
5251
self.merged_annotations: bytes | None = None
5352

5453
def set_logger(self, logger: Logger):
@@ -70,23 +69,15 @@ def _download_annotations(self):
7069

7170
data_bucket_client = make_cloud_client(self.data_bucket)
7271

73-
job_annotations = {}
74-
for job_meta in self.annotation_meta.jobs:
75-
job_filename = compose_annotation_results_bucket_filename(
76-
self.escrow_address,
77-
self.chain_id,
78-
job_meta.annotation_filename,
79-
)
80-
job_annotations[job_meta.job_id] = data_bucket_client.download_file(job_filename)
81-
82-
excor_merged_annotation_path = compose_annotation_results_bucket_filename(
72+
exchange_oracle_merged_annotation_path = compose_annotation_results_bucket_filename(
8373
self.escrow_address,
8474
self.chain_id,
8575
annotation.RESULTING_ANNOTATIONS_FILE,
8676
)
87-
merged_annotations = data_bucket_client.download_file(excor_merged_annotation_path)
77+
merged_annotations = data_bucket_client.download_file(
78+
exchange_oracle_merged_annotation_path
79+
)
8880

89-
self.job_annotations = job_annotations
9081
self.merged_annotations = merged_annotations
9182

9283
def _download_results(self):
@@ -97,7 +88,6 @@ def _download_results(self):
9788

9889
def _process_annotation_results(self) -> ValidationResult:
9990
assert self.annotation_meta is not None
100-
assert self.job_annotations is not None
10191
assert self.merged_annotations is not None
10292

10393
# TODO: refactor further
@@ -106,7 +96,6 @@ def _process_annotation_results(self) -> ValidationResult:
10696
escrow_address=self.escrow_address,
10797
chain_id=self.chain_id,
10898
meta=self.annotation_meta,
109-
job_annotations={k: io.BytesIO(v) for k, v in self.job_annotations.items()},
11099
merged_annotations=io.BytesIO(self.merged_annotations),
111100
manifest=self.manifest,
112101
logger=self.logger,
@@ -201,7 +190,6 @@ def _handle_validation_result(self, validation_result: ValidationResult):
201190
OracleWebhookTypes.exchange_oracle,
202191
event=RecordingOracleEvent_SubmissionRejected(
203192
# TODO: send all assignments, handle rejection reason in Exchange Oracle
204-
# change validation frames in these jobs once possible
205193
assignments=[
206194
RecordingOracleEvent_SubmissionRejected.RejectedAssignmentInfo(
207195
assignment_id=job_id_to_assignment_id[rejected_job_id],

packages/examples/cvat/recording-oracle/src/services/validation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def get_task_gt_stats(
132132
)
133133

134134

135-
def update_gt_stats(session: Session, task_id: int, values: dict[tuple[int, int], int]):
135+
def update_gt_stats(session: Session, task_id: str, values: dict[tuple[int, int], int]):
136136
# Read more about upsert:
137137
# https://docs.sqlalchemy.org/en/20/orm/queryguide/dml.html#orm-upsert-statements
138138

0 commit comments

Comments
 (0)