Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,12 @@

_RejectedJobs = dict[int, DatasetValidationError]


_TaskIdToValidationLayout = dict[int, dict]
_TaskIdToHoneypotsMapping = dict[int, dict]
_TaskIdToSequenceOfFrameNames = dict[int, list[str]]

# Frame id of a honeypot inside a task. These ids are intersected with job frame
# ranges below, so they presumably share the task's frame numbering - TODO confirm.
_HoneypotFrameId = int
# Frame id of a validation frame; used below as an index into the task's list of
# frame names.
_ValidationFrameId = int
# For a single task: honeypot frame id -> id of the validation frame it refers to.
_HoneypotFrameToValFrame = dict[_HoneypotFrameId, _ValidationFrameId]
# CVAT task id -> validation layout of that task as read from the CVAT API
# (the code below reads its disabled_frames, validation_frames,
# frames_per_job_count, honeypot_frames and honeypot_real_frames attributes).
_TaskIdToValidationLayout = dict[int, cvat_api.models.ITaskValidationLayoutRead]
# CVAT task id -> honeypot-to-validation-frame mapping of that task.
_TaskIdToHoneypotsMapping = dict[int, _HoneypotFrameToValFrame]
# CVAT task id -> frame names of that task, indexable by validation frame id.
_TaskIdToFrameNames = dict[int, list[str]]


@dataclass
Expand All @@ -72,7 +70,7 @@ class _ValidationResult:
gt_stats: GtStats
task_id_to_val_layout: _TaskIdToValidationLayout
task_id_to_honeypots_mapping: _TaskIdToHoneypotsMapping
task_id_to_sequence_of_frame_names: _TaskIdToSequenceOfFrameNames
task_id_to_frame_names: _TaskIdToFrameNames


T = TypeVar("T")
Expand Down Expand Up @@ -349,9 +347,7 @@ def validate(self) -> _ValidationResult:
gt_stats=self._require_field(self._gt_stats),
task_id_to_val_layout=self._require_field(self._task_id_to_val_layout),
task_id_to_honeypots_mapping=self._require_field(self._task_id_to_honeypots_mapping),
task_id_to_sequence_of_frame_names=self._require_field(
self._task_id_to_sequence_of_frame_names
),
task_id_to_frame_names=self._require_field(self._task_id_to_sequence_of_frame_names),
)


Expand Down Expand Up @@ -436,16 +432,16 @@ def process_intermediate_results( # noqa: PLR0912
for rejected_job_id in rejected_job_ids:
job_frame_range = job_id_to_frame_range[rejected_job_id]
cvat_task_id = job_id_to_task_id[rejected_job_id]
honeypots_mapping = validation_result.task_id_to_honeypots_mapping[cvat_task_id]
job_honeypots = set(honeypots_mapping.keys()) & set(job_frame_range)
task_honeypots_mapping = validation_result.task_id_to_honeypots_mapping[
cvat_task_id
]
job_honeypots = sorted(set(task_honeypots_mapping.keys()) & set(job_frame_range))
validation_frames = [
val_frame
for honeypot, val_frame in honeypots_mapping.items()
for honeypot, val_frame in task_honeypots_mapping.items()
if honeypot in job_honeypots
]
sorted_task_frame_names = validation_result.task_id_to_sequence_of_frame_names[
cvat_task_id
]
sorted_task_frame_names = validation_result.task_id_to_frame_names[cvat_task_id]

for val_frame in validation_frames:
val_frame_name = sorted_task_frame_names[val_frame]
Expand All @@ -458,84 +454,98 @@ def process_intermediate_results( # noqa: PLR0912
val_frame
)

for cvat_task_id, val_frame_ids in cvat_task_id_to_failed_val_frames.items():
for cvat_task_id, task_bad_validation_frames in cvat_task_id_to_failed_val_frames.items():
task_validation_layout = validation_result.task_id_to_val_layout[cvat_task_id]
intersection = val_frame_ids & set(task_validation_layout.disabled_frames)

if intersection:
task_disabled_bad_frames = (
set(task_validation_layout.disabled_frames) & task_bad_validation_frames
)
if task_disabled_bad_frames:
logger.error(
"Logical error occurred while disabling validation frames "
f"for the task({task_id}). Frames {intersection} "
f"for the task({task_id}). Frames {task_disabled_bad_frames} "
"are already disabled."
)

updated_disable_frames = task_validation_layout.disabled_frames + list(val_frame_ids)
non_disabled_frames = list(
set(task_validation_layout.validation_frames) - set(updated_disable_frames)
task_updated_disabled_frames = list(
set(task_validation_layout.disabled_frames) | set(task_bad_validation_frames)
)
task_good_validation_frames = list(
set(task_validation_layout.validation_frames) - set(task_updated_disabled_frames)
)

if len(non_disabled_frames) < task_validation_layout.frames_per_job_count:
if len(task_good_validation_frames) < task_validation_layout.frames_per_job_count:
should_complete = True
logger.info(
f"Validation for escrow_address={escrow_address}: "
"Too few validation frames left "
f"(required: {task_validation_layout.frames_per_job_count}, "
f"left: {len(non_disabled_frames)}) for the task({cvat_task_id}), "
f"left: {len(task_good_validation_frames)}) for the task({cvat_task_id}), "
"stopping annotation"
)
break

updated_task_honeypot_real_frames = task_validation_layout.honeypot_real_frames.copy()
task_honeypot_to_index: dict[int, int] = {
honeypot: i for i, honeypot in enumerate(task_validation_layout.honeypot_frames)
} # honeypot -> list index

# validation frames may be repeated
task_honeypot_real_frames_index: dict[int, list[int]] = {}
for idx, f in enumerate(updated_task_honeypot_real_frames):
task_honeypot_real_frames_index.setdefault(f, []).append(idx)
task_honeypots_mapping = validation_result.task_id_to_honeypots_mapping[cvat_task_id]

rejected_jobs_for_task = [
task_rejected_jobs = [
j
for j in unchecked_jobs_meta.jobs
if j.job_id in rejected_job_ids and j.task_id == cvat_task_id
]

for job in rejected_jobs_for_task:
task_updated_honeypot_real_frames = task_validation_layout.honeypot_real_frames.copy()
for job in task_rejected_jobs:
job_frame_range = job.job_frame_range
honeypots_mapping = validation_result.task_id_to_honeypots_mapping[cvat_task_id]
job_honeypots = set(honeypots_mapping.keys()) & set(job_frame_range)

validation_frames_to_replace = [
honeypots_mapping[honeypot]
for honeypot in job_honeypots
if honeypots_mapping[honeypot] in val_frame_ids
]
valid_used_validation_frames = [
honeypots_mapping[honeypot]
for honeypot in job_honeypots
if honeypots_mapping[honeypot] not in validation_frames_to_replace
]
job_honeypots = sorted(set(task_honeypots_mapping.keys()) & set(job_frame_range))

job_honeypots_to_replace = []
job_validation_frames_to_replace = []
job_validation_frames_to_keep = []
for honeypot in job_honeypots:
validation_frame = task_honeypots_mapping[honeypot]
if validation_frame in task_bad_validation_frames:
job_honeypots_to_replace.append(honeypot)
job_validation_frames_to_replace.append(validation_frame)
else:
job_validation_frames_to_keep.append(validation_frame)

# choose new unique validation frames for the job
assert not (set(validation_frames_to_replace) & set(non_disabled_frames))
available_validation_frames = list(
set(non_disabled_frames) - set(valid_used_validation_frames)
assert not (
set(job_validation_frames_to_replace) & set(task_good_validation_frames)
)
job_available_validation_frames = list(
set(task_good_validation_frames) - set(job_validation_frames_to_keep)
)

rng = np.random.Generator(np.random.MT19937())
new_validation_frames = rng.choice(
available_validation_frames,
new_job_validation_frames = rng.choice(
job_available_validation_frames,
replace=False,
size=len(validation_frames_to_replace),
size=len(job_validation_frames_to_replace),
).tolist()

for prev_val_frame, new_val_frame in zip(
validation_frames_to_replace, new_validation_frames, strict=True
for honeypot, new_validation_frame in zip(
job_honeypots_to_replace, new_job_validation_frames, strict=True
):
idx = task_honeypot_real_frames_index[prev_val_frame].pop(0)
updated_task_honeypot_real_frames[idx] = new_val_frame
honeypot_index = task_honeypot_to_index[honeypot]
task_updated_honeypot_real_frames[honeypot_index] = new_validation_frame

# Make sure honeypots do not repeat in jobs
assert len(
{
task_updated_honeypot_real_frames[task_honeypot_to_index[honeypot]]
for honeypot in job_honeypots
}
) == len(job_honeypots)

cvat_api.update_task_validation_layout(
cvat_task_id,
disabled_frames=updated_disable_frames,
honeypot_real_frames=updated_task_honeypot_real_frames,
disabled_frames=task_updated_disabled_frames,
honeypot_real_frames=task_updated_honeypot_real_frames,
)

if logger.isEnabledFor(logging.DEBUG):
Expand Down
Loading