Skip to content

Commit a96d1e7

Browse files
committed
[Recording oracle] Clean up the code
1 parent d70a63b commit a96d1e7

File tree

2 files changed

+70
-67
lines changed

2 files changed

+70
-67
lines changed

packages/examples/cvat/recording-oracle/src/core/annotation_meta.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,4 @@ class AnnotationMeta(BaseModel):
2525
jobs: list[JobMeta]
2626

2727
def skip_jobs(self, job_ids: list[int]):
28-
# self.jobs = [
29-
# job for job in self.jobs if job.job_id not in job_ids
30-
# ]
3128
return AnnotationMeta(jobs=[job for job in self.jobs if job.job_id not in job_ids])

packages/examples/cvat/recording-oracle/src/handlers/process_intermediate_results.py

Lines changed: 70 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@ class _ValidationResult:
7272
gt_stats: GtStats
7373
task_id_to_val_layout: _TaskIdToValidationLayout
7474
task_id_to_honeypots_mapping: _TaskIdToHoneypotsMapping
75-
job_id_to_task_id: dict[int, int]
7675

7776

7877
T = TypeVar("T")
@@ -137,9 +136,7 @@ def _validate_jobs(self):
137136
rejected_jobs: _RejectedJobs = {}
138137

139138
cvat_task_ids = {job_meta.task_id for job_meta in meta.jobs}
140-
job_id_to_task_id = {job_meta.job_id: job_meta.task_id for job_meta in meta.jobs}
141139

142-
task_id_to_quality_report: dict[int, cvat_api.models.QualityReport] = {}
143140
task_id_to_quality_report_data: dict[int, QualityReportData] = {}
144141
task_id_to_val_layout: dict[int, cvat_api.models.TaskValidationLayoutRead] = {}
145142
task_id_to_honeypots_mapping: dict[int, _HoneypotFrameToValFrame] = {}
@@ -150,8 +147,6 @@ def _validate_jobs(self):
150147

151148
for cvat_task_id in cvat_task_ids:
152149
task_quality_report = cvat_api.get_task_quality_report(cvat_task_id)
153-
task_id_to_quality_report[cvat_task_id] = task_quality_report
154-
155150
task_quality_report_data = cvat_api.get_quality_report_data(task_quality_report.id)
156151
task_id_to_quality_report_data[cvat_task_id] = task_quality_report_data
157152

@@ -170,28 +165,29 @@ def _validate_jobs(self):
170165
}
171166
)
172167

173-
# accepted jobs from the previous epoch are not included
168+
# accepted jobs from the previous epochs are not included
174169
for job_meta in meta.jobs:
175170
cvat_job_id = job_meta.job_id
176-
cvat_task_id = job_id_to_task_id[cvat_job_id]
171+
cvat_task_id = job_meta.task_id
177172

178-
# assess quality of the job honeypots
173+
# assess quality of the job's honeypots
179174
task_quality_report_data = task_id_to_quality_report_data[cvat_task_id]
180175
task_honeypots = {int(frame) for frame in task_quality_report_data.frame_results}
181176
honeypots_mapping = task_id_to_honeypots_mapping[cvat_task_id]
177+
182178
for honeypot in task_honeypots & set(job_meta.job_frame_range):
183-
val_frame_id = honeypots_mapping[honeypot]
179+
val_frame = honeypots_mapping[honeypot]
184180

185181
result = task_quality_report_data.frame_results[str(honeypot)]
186-
self._gt_stats.setdefault((cvat_task_id, val_frame_id), ValidationFrameStats())
182+
self._gt_stats.setdefault((cvat_task_id, val_frame), ValidationFrameStats())
187183
self._gt_stats[
188-
(cvat_task_id, val_frame_id)
184+
(cvat_task_id, val_frame)
189185
].accumulated_quality += result.annotations.accuracy
190186

191187
if result.annotations.accuracy < min_quality:
192-
self._gt_stats[(cvat_task_id, val_frame_id)].failed_attempts += 1
188+
self._gt_stats[(cvat_task_id, val_frame)].failed_attempts += 1
193189
else:
194-
self._gt_stats[(cvat_task_id, val_frame_id)].accepted_attempts += 1
190+
self._gt_stats[(cvat_task_id, val_frame)].accepted_attempts += 1
195191

196192
# assess job quality
197193
job_quality_report = job_id_to_quality_report[cvat_job_id]
@@ -205,16 +201,13 @@ def _validate_jobs(self):
205201

206202
job_results[cvat_job_id] = accuracy
207203

208-
min_quality = manifest.validation.min_quality
209-
210204
if accuracy < min_quality:
211205
rejected_jobs[cvat_job_id] = LowAccuracyError()
212206

213207
self._job_results = job_results
214208
self._rejected_jobs = rejected_jobs
215209
self._task_id_to_val_layout = task_id_to_val_layout
216210
self._task_id_to_honeypots_mapping = task_id_to_honeypots_mapping
217-
self._job_id_to_task_id = job_id_to_task_id
218211

219212
def _restore_original_image_paths(self, merged_dataset: dm.Dataset) -> dm.Dataset:
220213
class RemoveCommonPrefix(dm.ItemTransform):
@@ -327,7 +320,6 @@ def validate(self) -> _ValidationResult:
327320
gt_stats=self._require_field(self._gt_stats),
328321
task_id_to_val_layout=self._require_field(self._task_id_to_val_layout),
329322
task_id_to_honeypots_mapping=self._require_field(self._task_id_to_honeypots_mapping),
330-
job_id_to_task_id=self._require_field(self._job_id_to_task_id),
331323
)
332324

333325

@@ -355,13 +347,12 @@ def process_intermediate_results( # noqa: PLR0912
355347
for validation_result in db_service.get_task_validation_results(session, task.id)
356348
if validation_result.annotation_quality >= manifest.validation.min_quality
357349
]
358-
full_meta = meta
359-
meta = meta.skip_jobs(accepted_cvat_job_ids)
360-
# meta.skip_jobs(accepted_cvat_job_ids)
350+
unchecked_jobs_meta = meta.skip_jobs(accepted_cvat_job_ids)
361351
else:
362352
# Recording Oracle task represents all CVAT tasks related to the escrow
363353
task_id = db_service.create_task(session, escrow_address=escrow_address, chain_id=chain_id)
364354
task = db_service.get_task_by_id(session, task_id, for_update=True)
355+
unchecked_jobs_meta = meta
365356

366357
if logger.isEnabledFor(logging.DEBUG):
367358
logger.debug("process_intermediate_results for escrow %s", escrow_address)
@@ -381,7 +372,7 @@ def process_intermediate_results( # noqa: PLR0912
381372
chain_id=chain_id,
382373
manifest=manifest,
383374
merged_annotations=merged_annotations,
384-
meta=meta,
375+
meta=unchecked_jobs_meta,
385376
gt_stats=gt_stats,
386377
)
387378

@@ -400,18 +391,19 @@ def process_intermediate_results( # noqa: PLR0912
400391

401392
gt_stats = validation_result.gt_stats
402393
if gt_stats:
403-
cvat_task_id_to_failed_val_frames = {} # cvat_task_id: [val_frame_id, ...]
394+
cvat_task_id_to_failed_val_frames: dict[
395+
int, set[int]
396+
] = {} # cvat_task_id: {val_frame_id, ...}
404397
rejected_job_ids = rejected_jobs.keys()
405398

406399
if rejected_job_ids:
407-
job_id_to_frame_range = {
408-
job_meta.job_id: job_meta.job_frame_range for job_meta in meta.jobs
409-
}
400+
job_id_to_task_id = {j.job_id: j.task_id for j in unchecked_jobs_meta.jobs}
401+
job_id_to_frame_range = {j.job_id: j.job_frame_range for j in unchecked_jobs_meta.jobs}
410402

411-
# define which validation frames should be disabled
403+
# find validation frames to be disabled
412404
for rejected_job_id in rejected_job_ids:
413405
job_frame_range = job_id_to_frame_range[rejected_job_id]
414-
cvat_task_id = validation_result.job_id_to_task_id[rejected_job_id]
406+
cvat_task_id = job_id_to_task_id[rejected_job_id]
415407
honeypots_mapping = validation_result.task_id_to_honeypots_mapping[cvat_task_id]
416408
job_honeypots = set(honeypots_mapping.keys()) & set(job_frame_range)
417409
validation_frames = [
@@ -426,79 +418,93 @@ def process_intermediate_results( # noqa: PLR0912
426418
val_frame_stats.failed_attempts >= Config.validation.gt_ban_threshold
427419
and not val_frame_stats.accepted_attempts
428420
):
429-
cvat_task_id_to_failed_val_frames.setdefault(cvat_task_id, []).append(
421+
cvat_task_id_to_failed_val_frames.setdefault(cvat_task_id, set()).add(
430422
val_frame
431423
)
432424

433425
for cvat_task_id, val_frame_ids in cvat_task_id_to_failed_val_frames.items():
434426
task_validation_layout = validation_result.task_id_to_val_layout[cvat_task_id]
435-
intersection = set(val_frame_ids) & set(task_validation_layout.disabled_frames)
427+
intersection = val_frame_ids & set(task_validation_layout.disabled_frames)
436428

437429
if intersection:
438-
logger.error(f"Unexpected case: frames {intersection} were disabled earlier")
430+
logger.error(
431+
"Logical error occurred while disabling validation frames "
432+
f"for the task({task_id}). Frames {intersection} "
433+
"are already disabled."
434+
)
439435

440-
updated_disable_frames = task_validation_layout.disabled_frames + val_frame_ids
441-
not_disabled_frames = list(
436+
updated_disable_frames = task_validation_layout.disabled_frames + list(val_frame_ids)
437+
non_disabled_frames = list(
442438
set(task_validation_layout.validation_frames) - set(updated_disable_frames)
443439
)
440+
444441
updated_task_honeypot_real_frames = task_validation_layout.honeypot_real_frames.copy()
445442
task_honeypot_real_frames_index = {
446443
f: idx for idx, f in enumerate(updated_task_honeypot_real_frames)
447444
}
448445

449-
jobs_ = [
450-
job_meta
451-
for job_meta in meta.jobs
452-
if job_meta.job_id in rejected_job_ids and job_meta.task_id == cvat_task_id
446+
rejected_jobs_for_task = [
447+
j
448+
for j in unchecked_jobs_meta.jobs
449+
if j.job_id in rejected_job_ids and j.task_id == cvat_task_id
453450
]
454451

455-
for job in jobs_:
452+
for job in rejected_jobs_for_task:
456453
job_frame_range = job.job_frame_range
457454
honeypots_mapping = validation_result.task_id_to_honeypots_mapping[cvat_task_id]
458455
job_honeypots = set(honeypots_mapping.keys()) & set(job_frame_range)
459-
job_validation_frames = [
460-
val_frame
461-
for honeypot, val_frame in honeypots_mapping.items()
462-
if honeypot in job_honeypots
463-
]
456+
464457
validation_frames_to_replace = [
465-
f for f in job_validation_frames if f in val_frame_ids
458+
honeypots_mapping[honeypot]
459+
for honeypot in job_honeypots
460+
if honeypots_mapping[honeypot] in val_frame_ids
466461
]
467462

468463
# choose new unique validation frames for the job
464+
available_validation_frames = list(
465+
set(non_disabled_frames) - set(validation_frames_to_replace)
466+
)
469467
rng = np.random.default_rng()
470-
new_validation_frames = [
471-
int(f)
472-
for f in rng.choice(
473-
list(set(not_disabled_frames) - set(job_validation_frames)),
474-
replace=False,
475-
size=len(validation_frames_to_replace),
468+
try:
469+
new_validation_frames = [
470+
int(f)
471+
for f in rng.choice(
472+
available_validation_frames,
473+
replace=False,
474+
size=len(validation_frames_to_replace),
475+
)
476+
]
477+
478+
for prev_val_frame, new_val_frame in zip(
479+
validation_frames_to_replace, new_validation_frames, strict=True
480+
):
481+
idx = task_honeypot_real_frames_index[prev_val_frame]
482+
updated_task_honeypot_real_frames[idx] = new_val_frame
483+
except ValueError as ex:
484+
logger.exception(
485+
"Exception occurred while generating new validation frames for "
486+
f"the job {job.job_id}. Tried to replace validation frames "
487+
f"{validation_frames_to_replace!r} by new "
488+
f"{len(validation_frames_to_replace)} frames. \nDetails: {ex!s}"
476489
)
477-
]
478-
for prev_val_frame, new_val_frame in zip(
479-
validation_frames_to_replace, new_validation_frames, strict=False
480-
):
481-
idx = task_honeypot_real_frames_index[prev_val_frame]
482-
updated_task_honeypot_real_frames[idx] = new_val_frame
483490

484-
if set(updated_disable_frames) == set(task_validation_layout.validation_frames):
491+
if set(updated_disable_frames) != set(task_validation_layout.validation_frames):
492+
cvat_api.update_task_validation_layout(
493+
cvat_task_id,
494+
disabled_frames=updated_disable_frames,
495+
honeypot_real_frames=updated_task_honeypot_real_frames,
496+
)
497+
else:
485498
logger.error("All validation frames were banned. Honeypots will not be shuffled")
486499

487-
cvat_api.update_task_validation_layout(
488-
cvat_task_id,
489-
disabled_frames=updated_disable_frames,
490-
honeypot_real_frames=updated_task_honeypot_real_frames,
491-
)
492-
493500
if logger.isEnabledFor(logging.DEBUG):
494501
logger.debug("Updating GT stats: %s", gt_stats)
495502

496503
db_service.update_gt_stats(session, task.id, gt_stats)
497504

498505
job_final_result_ids: dict[int, str] = {}
499506

500-
# # # meta.jobs does not include jobs that have already been accepted on the previous epochs
501-
for job_meta in full_meta.jobs:
507+
for job_meta in meta.jobs:
502508
job = db_service.get_job_by_cvat_id(session, job_meta.job_id)
503509
if not job:
504510
job_id = db_service.create_job(session, task_id=task.id, job_cvat_id=job_meta.job_id)

0 commit comments

Comments
 (0)