Skip to content

Commit fa5cf7a

Browse files
committed
[Rec oracle] Small fixes
1 parent e4ce52d commit fa5cf7a

File tree

3 files changed

+51
-43
lines changed

3 files changed

+51
-43
lines changed

packages/examples/cvat/recording-oracle/src/cvat/api_calls.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,8 @@ def get_task_quality_report(
100100

101101
if (
102102
report
103-
# retrieving the task details to check if the latest quality report is actual
104-
# or not should be more effective than recreating quality report each time
103+
# retrieving the task details to check if the latest quality report is up-to-date
104+
# or not should be more effective than recreating the quality report each time
105105
and get_task(task_id).updated_date <= report.target_last_updated
106106
):
107107
if logger.isEnabledFor(logging.DEBUG):
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
from pydantic import BaseModel
22

33

4-
class _AnnotationInfo(BaseModel):
4+
class AnnotationInfo(BaseModel):
55
accuracy: float | int
66

77

8-
class _FrameResult(BaseModel):
8+
class FrameResult(BaseModel):
99
conflicts: list[dict]
10-
annotations: _AnnotationInfo
10+
annotations: AnnotationInfo
1111

1212

1313
class QualityReportData(BaseModel):
14-
frame_results: dict[str, _FrameResult]
14+
frame_results: dict[str, FrameResult]

packages/examples/cvat/recording-oracle/src/handlers/process_intermediate_results.py

Lines changed: 45 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,14 @@ def _validate_jobs(self):
188188
task_honeypots = {int(frame) for frame in task_quality_report_data.frame_results}
189189
honeypots_mapping = task_id_to_honeypots_mapping[cvat_task_id]
190190

191-
for honeypot in task_honeypots & set(job_meta.job_frame_range):
191+
job_honeypots = task_honeypots & set(job_meta.job_frame_range)
192+
193+
if not job_honeypots:
194+
job_results[cvat_job_id] = self.UNKNOWN_QUALITY
195+
rejected_jobs[cvat_job_id] = TooFewGtError
196+
continue
197+
198+
for honeypot in job_honeypots:
192199
val_frame = honeypots_mapping[honeypot]
193200
val_frame_name = sorted_task_frame_names[val_frame]
194201

@@ -205,11 +212,6 @@ def _validate_jobs(self):
205212
job_quality_report = job_id_to_quality_report[cvat_job_id]
206213

207214
accuracy = job_quality_report.summary.accuracy
208-
if not job_quality_report.summary.gt_count:
209-
assert accuracy == 0
210-
job_results[cvat_job_id] = self.UNKNOWN_QUALITY
211-
rejected_jobs[cvat_job_id] = TooFewGtError
212-
continue
213215

214216
job_results[cvat_job_id] = accuracy
215217

@@ -349,6 +351,8 @@ def process_intermediate_results( # noqa: PLR0912
349351
manifest: TaskManifest,
350352
logger: logging.Logger,
351353
) -> ValidationSuccess | ValidationFailure:
354+
should_complete = False
355+
352356
task = db_service.get_task_by_escrow_address(
353357
session,
354358
escrow_address,
@@ -457,6 +461,17 @@ def process_intermediate_results( # noqa: PLR0912
457461
set(task_validation_layout.validation_frames) - set(updated_disable_frames)
458462
)
459463

464+
if len(non_disabled_frames) < task_validation_layout.frames_per_job_count:
465+
should_complete = True
466+
logger.info(
467+
f"Validation for escrow_address={escrow_address}: "
468+
"Too few validation frames left "
469+
f"(required: {task_validation_layout.frames_per_job_count}, "
470+
f"left: {len(non_disabled_frames)}) for the task({cvat_task_id}), "
471+
"stopping annotation"
472+
)
473+
break
474+
460475
updated_task_honeypot_real_frames = task_validation_layout.honeypot_real_frames.copy()
461476

462477
# validation frames may be repeated
@@ -480,40 +495,35 @@ def process_intermediate_results( # noqa: PLR0912
480495
for honeypot in job_honeypots
481496
if honeypots_mapping[honeypot] in val_frame_ids
482497
]
498+
valid_used_validation_frames = [
499+
honeypots_mapping[honeypot]
500+
for honeypot in job_honeypots
501+
if honeypots_mapping[honeypot] not in validation_frames_to_replace
502+
]
483503

484504
# choose new unique validation frames for the job
505+
assert not (set(validation_frames_to_replace) & set(non_disabled_frames))
485506
available_validation_frames = list(
486-
set(non_disabled_frames) - set(validation_frames_to_replace)
507+
set(non_disabled_frames) - set(valid_used_validation_frames)
487508
)
488509
rng = np.random.Generator(np.random.MT19937())
489-
try:
490-
new_validation_frames = rng.choice(
491-
available_validation_frames,
492-
replace=False,
493-
size=len(validation_frames_to_replace),
494-
).tolist()
495-
496-
for prev_val_frame, new_val_frame in zip(
497-
validation_frames_to_replace, new_validation_frames, strict=True
498-
):
499-
idx = task_honeypot_real_frames_index[prev_val_frame].pop(0)
500-
updated_task_honeypot_real_frames[idx] = new_val_frame
501-
except ValueError as ex:
502-
logger.exception(
503-
"Exception occurred while generating new validation frames for "
504-
f"the job {job.job_id}. Tried to replace validation frames "
505-
f"{validation_frames_to_replace!r} by new "
506-
f"{len(validation_frames_to_replace)} frames. \nDetails: {ex!s}"
507-
)
508-
509-
if set(updated_disable_frames) != set(task_validation_layout.validation_frames):
510-
cvat_api.update_task_validation_layout(
511-
cvat_task_id,
512-
disabled_frames=updated_disable_frames,
513-
honeypot_real_frames=updated_task_honeypot_real_frames,
514-
)
515-
else:
516-
logger.error("All validation frames were banned. Honeypots will not be shuffled")
510+
new_validation_frames = rng.choice(
511+
available_validation_frames,
512+
replace=False,
513+
size=len(validation_frames_to_replace),
514+
).tolist()
515+
516+
for prev_val_frame, new_val_frame in zip(
517+
validation_frames_to_replace, new_validation_frames, strict=True
518+
):
519+
idx = task_honeypot_real_frames_index[prev_val_frame].pop(0)
520+
updated_task_honeypot_real_frames[idx] = new_val_frame
521+
522+
cvat_api.update_task_validation_layout(
523+
cvat_task_id,
524+
disabled_frames=updated_disable_frames,
525+
honeypot_real_frames=updated_task_honeypot_real_frames,
526+
)
517527

518528
if logger.isEnabledFor(logging.DEBUG):
519529
logger.debug("Updating GT stats: %s", gt_stats)
@@ -546,8 +556,6 @@ def process_intermediate_results( # noqa: PLR0912
546556

547557
task_jobs = task.jobs
548558

549-
should_complete = False
550-
551559
if Config.validation.max_escrow_iterations > 0:
552560
escrow_iteration = task.iteration
553561
if escrow_iteration and Config.validation.max_escrow_iterations <= escrow_iteration:

0 commit comments

Comments
 (0)