Skip to content

Commit 1b5faa9

Browse files
authored
[CVAT] Refactor and fix honeypot rerolls (#2860)
* Refactor and fix honeypot rerolls
* Fix and satisfy linter
* Add test
* Fix and refactor tests
1 parent 430d1d3 commit 1b5faa9

File tree

3 files changed

+356
-79
lines changed

3 files changed

+356
-79
lines changed

packages/examples/cvat/recording-oracle/src/handlers/process_intermediate_results.py

Lines changed: 66 additions & 56 deletions
Original file line number | Diff line number | Diff line change
@@ -54,14 +54,12 @@
5454

5555
_RejectedJobs = dict[int, DatasetValidationError]
5656

57-
58-
_TaskIdToValidationLayout = dict[int, dict]
59-
_TaskIdToHoneypotsMapping = dict[int, dict]
60-
_TaskIdToSequenceOfFrameNames = dict[int, list[str]]
61-
6257
_HoneypotFrameId = int
6358
_ValidationFrameId = int
6459
_HoneypotFrameToValFrame = dict[_HoneypotFrameId, _ValidationFrameId]
60+
_TaskIdToValidationLayout = dict[int, cvat_api.models.ITaskValidationLayoutRead]
61+
_TaskIdToHoneypotsMapping = dict[int, _HoneypotFrameToValFrame]
62+
_TaskIdToFrameNames = dict[int, list[str]]
6563

6664

6765
@dataclass
@@ -72,7 +70,7 @@ class _ValidationResult:
7270
gt_stats: GtStats
7371
task_id_to_val_layout: _TaskIdToValidationLayout
7472
task_id_to_honeypots_mapping: _TaskIdToHoneypotsMapping
75-
task_id_to_sequence_of_frame_names: _TaskIdToSequenceOfFrameNames
73+
task_id_to_frame_names: _TaskIdToFrameNames
7674

7775

7876
T = TypeVar("T")
@@ -349,9 +347,7 @@ def validate(self) -> _ValidationResult:
349347
gt_stats=self._require_field(self._gt_stats),
350348
task_id_to_val_layout=self._require_field(self._task_id_to_val_layout),
351349
task_id_to_honeypots_mapping=self._require_field(self._task_id_to_honeypots_mapping),
352-
task_id_to_sequence_of_frame_names=self._require_field(
353-
self._task_id_to_sequence_of_frame_names
354-
),
350+
task_id_to_frame_names=self._require_field(self._task_id_to_sequence_of_frame_names),
355351
)
356352

357353

@@ -436,16 +432,16 @@ def process_intermediate_results( # noqa: PLR0912
436432
for rejected_job_id in rejected_job_ids:
437433
job_frame_range = job_id_to_frame_range[rejected_job_id]
438434
cvat_task_id = job_id_to_task_id[rejected_job_id]
439-
honeypots_mapping = validation_result.task_id_to_honeypots_mapping[cvat_task_id]
440-
job_honeypots = set(honeypots_mapping.keys()) & set(job_frame_range)
435+
task_honeypots_mapping = validation_result.task_id_to_honeypots_mapping[
436+
cvat_task_id
437+
]
438+
job_honeypots = sorted(set(task_honeypots_mapping.keys()) & set(job_frame_range))
441439
validation_frames = [
442440
val_frame
443-
for honeypot, val_frame in honeypots_mapping.items()
441+
for honeypot, val_frame in task_honeypots_mapping.items()
444442
if honeypot in job_honeypots
445443
]
446-
sorted_task_frame_names = validation_result.task_id_to_sequence_of_frame_names[
447-
cvat_task_id
448-
]
444+
sorted_task_frame_names = validation_result.task_id_to_frame_names[cvat_task_id]
449445

450446
for val_frame in validation_frames:
451447
val_frame_name = sorted_task_frame_names[val_frame]
@@ -458,84 +454,98 @@ def process_intermediate_results( # noqa: PLR0912
458454
val_frame
459455
)
460456

461-
for cvat_task_id, val_frame_ids in cvat_task_id_to_failed_val_frames.items():
457+
for cvat_task_id, task_bad_validation_frames in cvat_task_id_to_failed_val_frames.items():
462458
task_validation_layout = validation_result.task_id_to_val_layout[cvat_task_id]
463-
intersection = val_frame_ids & set(task_validation_layout.disabled_frames)
464459

465-
if intersection:
460+
task_disabled_bad_frames = (
461+
set(task_validation_layout.disabled_frames) & task_bad_validation_frames
462+
)
463+
if task_disabled_bad_frames:
466464
logger.error(
467465
"Logical error occurred while disabling validation frames "
468-
f"for the task({task_id}). Frames {intersection} "
466+
f"for the task({task_id}). Frames {task_disabled_bad_frames} "
469467
"are already disabled."
470468
)
471469

472-
updated_disable_frames = task_validation_layout.disabled_frames + list(val_frame_ids)
473-
non_disabled_frames = list(
474-
set(task_validation_layout.validation_frames) - set(updated_disable_frames)
470+
task_updated_disabled_frames = list(
471+
set(task_validation_layout.disabled_frames) | set(task_bad_validation_frames)
472+
)
473+
task_good_validation_frames = list(
474+
set(task_validation_layout.validation_frames) - set(task_updated_disabled_frames)
475475
)
476476

477-
if len(non_disabled_frames) < task_validation_layout.frames_per_job_count:
477+
if len(task_good_validation_frames) < task_validation_layout.frames_per_job_count:
478478
should_complete = True
479479
logger.info(
480480
f"Validation for escrow_address={escrow_address}: "
481481
"Too few validation frames left "
482482
f"(required: {task_validation_layout.frames_per_job_count}, "
483-
f"left: {len(non_disabled_frames)}) for the task({cvat_task_id}), "
483+
f"left: {len(task_good_validation_frames)}) for the task({cvat_task_id}), "
484484
"stopping annotation"
485485
)
486486
break
487487

488-
updated_task_honeypot_real_frames = task_validation_layout.honeypot_real_frames.copy()
488+
task_honeypot_to_index: dict[int, int] = {
489+
honeypot: i for i, honeypot in enumerate(task_validation_layout.honeypot_frames)
490+
} # honeypot -> list index
489491

490-
# validation frames may be repeated
491-
task_honeypot_real_frames_index: dict[int, list[int]] = {}
492-
for idx, f in enumerate(updated_task_honeypot_real_frames):
493-
task_honeypot_real_frames_index.setdefault(f, []).append(idx)
492+
task_honeypots_mapping = validation_result.task_id_to_honeypots_mapping[cvat_task_id]
494493

495-
rejected_jobs_for_task = [
494+
task_rejected_jobs = [
496495
j
497496
for j in unchecked_jobs_meta.jobs
498497
if j.job_id in rejected_job_ids and j.task_id == cvat_task_id
499498
]
500499

501-
for job in rejected_jobs_for_task:
500+
task_updated_honeypot_real_frames = task_validation_layout.honeypot_real_frames.copy()
501+
for job in task_rejected_jobs:
502502
job_frame_range = job.job_frame_range
503-
honeypots_mapping = validation_result.task_id_to_honeypots_mapping[cvat_task_id]
504-
job_honeypots = set(honeypots_mapping.keys()) & set(job_frame_range)
505-
506-
validation_frames_to_replace = [
507-
honeypots_mapping[honeypot]
508-
for honeypot in job_honeypots
509-
if honeypots_mapping[honeypot] in val_frame_ids
510-
]
511-
valid_used_validation_frames = [
512-
honeypots_mapping[honeypot]
513-
for honeypot in job_honeypots
514-
if honeypots_mapping[honeypot] not in validation_frames_to_replace
515-
]
503+
job_honeypots = sorted(set(task_honeypots_mapping.keys()) & set(job_frame_range))
504+
505+
job_honeypots_to_replace = []
506+
job_validation_frames_to_replace = []
507+
job_validation_frames_to_keep = []
508+
for honeypot in job_honeypots:
509+
validation_frame = task_honeypots_mapping[honeypot]
510+
if validation_frame in task_bad_validation_frames:
511+
job_honeypots_to_replace.append(honeypot)
512+
job_validation_frames_to_replace.append(validation_frame)
513+
else:
514+
job_validation_frames_to_keep.append(validation_frame)
516515

517516
# choose new unique validation frames for the job
518-
assert not (set(validation_frames_to_replace) & set(non_disabled_frames))
519-
available_validation_frames = list(
520-
set(non_disabled_frames) - set(valid_used_validation_frames)
517+
assert not (
518+
set(job_validation_frames_to_replace) & set(task_good_validation_frames)
521519
)
520+
job_available_validation_frames = list(
521+
set(task_good_validation_frames) - set(job_validation_frames_to_keep)
522+
)
523+
522524
rng = np.random.Generator(np.random.MT19937())
523-
new_validation_frames = rng.choice(
524-
available_validation_frames,
525+
new_job_validation_frames = rng.choice(
526+
job_available_validation_frames,
525527
replace=False,
526-
size=len(validation_frames_to_replace),
528+
size=len(job_validation_frames_to_replace),
527529
).tolist()
528530

529-
for prev_val_frame, new_val_frame in zip(
530-
validation_frames_to_replace, new_validation_frames, strict=True
531+
for honeypot, new_validation_frame in zip(
532+
job_honeypots_to_replace, new_job_validation_frames, strict=True
531533
):
532-
idx = task_honeypot_real_frames_index[prev_val_frame].pop(0)
533-
updated_task_honeypot_real_frames[idx] = new_val_frame
534+
honeypot_index = task_honeypot_to_index[honeypot]
535+
task_updated_honeypot_real_frames[honeypot_index] = new_validation_frame
536+
537+
# Make sure honeypots do not repeat in jobs
538+
assert len(
539+
{
540+
task_updated_honeypot_real_frames[task_honeypot_to_index[honeypot]]
541+
for honeypot in job_honeypots
542+
}
543+
) == len(job_honeypots)
534544

535545
cvat_api.update_task_validation_layout(
536546
cvat_task_id,
537-
disabled_frames=updated_disable_frames,
538-
honeypot_real_frames=updated_task_honeypot_real_frames,
547+
disabled_frames=task_updated_disabled_frames,
548+
honeypot_real_frames=task_updated_honeypot_real_frames,
539549
)
540550

541551
if logger.isEnabledFor(logging.DEBUG):

0 commit comments

Comments (0)