Skip to content

Commit a933ea6

Browse files
committed
[Recording oracle] Apply comments
1 parent 1963855 commit a933ea6

File tree

6 files changed

+72
-42
lines changed

6 files changed

+72
-42
lines changed

packages/examples/cvat/recording-oracle/alembic/versions/9d4367899f90_recreate_gt_stats.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,12 @@ def upgrade() -> None:
2525
op.create_table(
2626
"gt_stats",
2727
sa.Column("task_id", sa.String(), nullable=False),
28-
sa.Column("cvat_task_id", sa.Integer(), nullable=False),
29-
sa.Column("gt_frame_id", sa.Integer(), nullable=False),
28+
sa.Column("gt_frame_name", sa.String(), nullable=False),
3029
sa.Column("failed_attempts", sa.Integer(), nullable=False),
3130
sa.Column("accepted_attempts", sa.Integer(), nullable=False),
3231
sa.Column("accumulated_quality", sa.Float(), nullable=False),
3332
sa.ForeignKeyConstraint(["task_id"], ["tasks.id"], ondelete="CASCADE"),
34-
sa.PrimaryKeyConstraint("task_id", "gt_frame_id"),
33+
sa.PrimaryKeyConstraint("task_id", "gt_frame_name"),
3534
)
3635
# ### end Alembic commands ###
3736

packages/examples/cvat/recording-oracle/src/cvat/api_calls.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,19 @@ def get_quality_report_data(report_id: int) -> QualityReportData:
111111
raise
112112

113113

114+
def get_jobs_quality_reports(parent_id: int) -> list[models.QualityReport]:
115+
logger = logging.getLogger("app")
116+
with get_api_client() as api_client:
117+
try:
118+
return get_paginated_collection(
119+
api_client.quality_api.list_reports_endpoint, parent_id=parent_id, target="job"
120+
)
121+
122+
except exceptions.ApiException as e:
123+
logger.exception(f"Exception when calling QualityApi.list_reports: {e}\n")
124+
raise
125+
126+
114127
def get_task_validation_layout(task_id: int) -> models.TaskValidationLayoutRead:
115128
logger = logging.getLogger("app")
116129
with get_api_client() as api_client:
@@ -127,19 +140,6 @@ def get_task_validation_layout(task_id: int) -> models.TaskValidationLayoutRead:
127140
raise
128141

129142

130-
def get_jobs_quality_reports(parent_id: int) -> list[models.QualityReport]:
131-
logger = logging.getLogger("app")
132-
with get_api_client() as api_client:
133-
try:
134-
return get_paginated_collection(
135-
api_client.quality_api.list_reports_endpoint, parent_id=parent_id, target="job"
136-
)
137-
138-
except exceptions.ApiException as e:
139-
logger.exception(f"Exception when calling QualityApi.list_reports: {e}\n")
140-
raise
141-
142-
143143
def update_task_validation_layout(
144144
task_id: int,
145145
*,
@@ -168,3 +168,15 @@ def update_task_validation_layout(
168168

169169
if logger.isEnabledFor(logging.DEBUG):
170170
logger.debug(f"Validation layout: {validation_layout}")
171+
172+
173+
def get_task_data_meta(task_id: int) -> models.DataMetaRead:
174+
logger = logging.getLogger("app")
175+
with get_api_client() as api_client:
176+
try:
177+
data_meta, _ = api_client.tasks_api.retrieve_data_meta(task_id)
178+
return data_meta
179+
180+
except exceptions.ApiException as ex:
181+
logger.exception(f"Exception when calling TaskApi.retrieve_data_meta: {ex}\n")
182+
raise
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
from pydantic import BaseModel
22

33

4-
class AnnotationInfo(BaseModel):
4+
class _AnnotationInfo(BaseModel):
55
accuracy: float | int
66

77

8-
class FrameResult(BaseModel):
8+
class _FrameResult(BaseModel):
99
conflicts: list[dict]
10-
annotations: AnnotationInfo
10+
annotations: _AnnotationInfo
1111

1212

1313
class QualityReportData(BaseModel):
14-
frame_results: dict[str, FrameResult]
14+
frame_results: dict[str, _FrameResult]

packages/examples/cvat/recording-oracle/src/handlers/process_intermediate_results.py

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858

5959
_TaskIdToValidationLayout = dict[int, dict]
6060
_TaskIdToHoneypotsMapping = dict[int, dict]
61+
_TaskIdToSequenceOfFrameNames = dict[int, list[str]]
6162

6263
_HoneypotFrameId = int
6364
_ValidationFrameId = int
@@ -72,6 +73,7 @@ class _ValidationResult:
7273
gt_stats: GtStats
7374
task_id_to_val_layout: _TaskIdToValidationLayout
7475
task_id_to_honeypots_mapping: _TaskIdToHoneypotsMapping
76+
task_id_to_sequence_of_frame_names: _TaskIdToSequenceOfFrameNames
7577

7678

7779
T = TypeVar("T")
@@ -141,23 +143,33 @@ def _validate_jobs(self):
141143
task_id_to_val_layout: dict[int, cvat_api.models.TaskValidationLayoutRead] = {}
142144
task_id_to_honeypots_mapping: dict[int, _HoneypotFrameToValFrame] = {}
143145

146+
# store sequence of frame names for each task
147+
# task honeypot with frame index matches the sequence[index]
148+
task_id_to_sequence_of_frame_names: dict[int, list[str]] = {}
149+
144150
min_quality = manifest.validation.min_quality
145151

146152
job_id_to_quality_report: dict[int, cvat_api.models.QualityReport] = {}
147153

148154
for cvat_task_id in cvat_task_ids:
155+
# obtain quality report details
149156
task_quality_report = cvat_api.get_task_quality_report(cvat_task_id)
150157
task_quality_report_data = cvat_api.get_quality_report_data(task_quality_report.id)
151158
task_id_to_quality_report_data[cvat_task_id] = task_quality_report_data
152159

160+
# obtain task validation layout and define honeypots mapping
153161
task_val_layout = cvat_api.get_task_validation_layout(cvat_task_id)
154162
honeypot_frame_to_real = {
155163
f: task_val_layout.honeypot_real_frames[idx]
156164
for idx, f in enumerate(task_val_layout.honeypot_frames)
157165
}
158166
task_id_to_val_layout[cvat_task_id] = task_val_layout
159167
task_id_to_honeypots_mapping[cvat_task_id] = honeypot_frame_to_real
168+
task_id_to_sequence_of_frame_names[cvat_task_id] = [
169+
frame.name for frame in cvat_api.get_task_data_meta(cvat_task_id).frames
170+
]
160171

172+
# obtain quality reports for each job from the task
161173
job_id_to_quality_report.update(
162174
{
163175
quality_report.job_id: quality_report
@@ -172,28 +184,28 @@ def _validate_jobs(self):
172184

173185
# assess quality of the job's honeypots
174186
task_quality_report_data = task_id_to_quality_report_data[cvat_task_id]
187+
sorted_task_frame_names = task_id_to_sequence_of_frame_names[cvat_task_id]
175188
task_honeypots = {int(frame) for frame in task_quality_report_data.frame_results}
176189
honeypots_mapping = task_id_to_honeypots_mapping[cvat_task_id]
177190

178191
for honeypot in task_honeypots & set(job_meta.job_frame_range):
179192
val_frame = honeypots_mapping[honeypot]
193+
val_frame_name = sorted_task_frame_names[val_frame]
180194

181195
result = task_quality_report_data.frame_results[str(honeypot)]
182-
self._gt_stats.setdefault((cvat_task_id, val_frame), ValidationFrameStats())
183-
self._gt_stats[
184-
(cvat_task_id, val_frame)
185-
].accumulated_quality += result.annotations.accuracy
196+
self._gt_stats.setdefault(val_frame_name, ValidationFrameStats())
197+
self._gt_stats[val_frame_name].accumulated_quality += result.annotations.accuracy
186198

187199
if result.annotations.accuracy < min_quality:
188-
self._gt_stats[(cvat_task_id, val_frame)].failed_attempts += 1
200+
self._gt_stats[val_frame_name].failed_attempts += 1
189201
else:
190-
self._gt_stats[(cvat_task_id, val_frame)].accepted_attempts += 1
202+
self._gt_stats[val_frame_name].accepted_attempts += 1
191203

192204
# assess job quality
193205
job_quality_report = job_id_to_quality_report[cvat_job_id]
194206

195207
accuracy = job_quality_report.summary.accuracy
196-
if isinstance(accuracy, int):
208+
if not job_quality_report.summary.gt_count:
197209
assert accuracy == 0
198210
job_results[cvat_job_id] = self.UNKNOWN_QUALITY
199211
rejected_jobs[cvat_job_id] = TooFewGtError
@@ -208,6 +220,7 @@ def _validate_jobs(self):
208220
self._rejected_jobs = rejected_jobs
209221
self._task_id_to_val_layout = task_id_to_val_layout
210222
self._task_id_to_honeypots_mapping = task_id_to_honeypots_mapping
223+
self._task_id_to_sequence_of_frame_names = task_id_to_sequence_of_frame_names
211224

212225
def _restore_original_image_paths(self, merged_dataset: dm.Dataset) -> dm.Dataset:
213226
class RemoveCommonPrefix(dm.ItemTransform):
@@ -320,6 +333,9 @@ def validate(self) -> _ValidationResult:
320333
gt_stats=self._require_field(self._gt_stats),
321334
task_id_to_val_layout=self._require_field(self._task_id_to_val_layout),
322335
task_id_to_honeypots_mapping=self._require_field(self._task_id_to_honeypots_mapping),
336+
task_id_to_sequence_of_frame_names=self._require_field(
337+
self._task_id_to_sequence_of_frame_names
338+
),
323339
)
324340

325341

@@ -359,7 +375,7 @@ def process_intermediate_results( # noqa: PLR0912
359375
logger.debug("Task id %s, %s", getattr(task, "id", None), getattr(task, "__dict__", None))
360376

361377
gt_stats = {
362-
(gt_image_stat.cvat_task_id, gt_image_stat.gt_frame_id): ValidationFrameStats(
378+
gt_image_stat.gt_frame_name: ValidationFrameStats(
363379
failed_attempts=gt_image_stat.failed_attempts,
364380
accepted_attempts=gt_image_stat.accepted_attempts,
365381
accumulated_quality=gt_image_stat.accumulated_quality,
@@ -391,9 +407,8 @@ def process_intermediate_results( # noqa: PLR0912
391407

392408
gt_stats = validation_result.gt_stats
393409
if gt_stats:
394-
cvat_task_id_to_failed_val_frames: dict[
395-
int, set[int]
396-
] = {} # cvat_task_id: {val_frame_id, ...}
410+
# cvat_task_id: {val_frame_id, ...}
411+
cvat_task_id_to_failed_val_frames: dict[int, set[int]] = {}
397412
rejected_job_ids = rejected_jobs.keys()
398413

399414
if rejected_job_ids:
@@ -411,9 +426,13 @@ def process_intermediate_results( # noqa: PLR0912
411426
for honeypot, val_frame in honeypots_mapping.items()
412427
if honeypot in job_honeypots
413428
]
429+
sorted_task_frame_names = validation_result.task_id_to_sequence_of_frame_names[
430+
cvat_task_id
431+
]
414432

415433
for val_frame in validation_frames:
416-
val_frame_stats = gt_stats[(cvat_task_id, val_frame)]
434+
val_frame_name = sorted_task_frame_names[val_frame]
435+
val_frame_stats = gt_stats[val_frame_name]
417436
if (
418437
val_frame_stats.failed_attempts >= Config.validation.gt_ban_threshold
419438
and not val_frame_stats.accepted_attempts
@@ -439,9 +458,11 @@ def process_intermediate_results( # noqa: PLR0912
439458
)
440459

441460
updated_task_honeypot_real_frames = task_validation_layout.honeypot_real_frames.copy()
442-
task_honeypot_real_frames_index = {
443-
f: idx for idx, f in enumerate(updated_task_honeypot_real_frames)
444-
}
461+
462+
# validation frames may be repeated
463+
task_honeypot_real_frames_index: dict[int, list[int]] = {}
464+
for idx, f in enumerate(updated_task_honeypot_real_frames):
465+
task_honeypot_real_frames_index.setdefault(f, []).append(idx)
445466

446467
rejected_jobs_for_task = [
447468
j
@@ -478,7 +499,7 @@ def process_intermediate_results( # noqa: PLR0912
478499
for prev_val_frame, new_val_frame in zip(
479500
validation_frames_to_replace, new_validation_frames, strict=True
480501
):
481-
idx = task_honeypot_real_frames_index[prev_val_frame]
502+
idx = task_honeypot_real_frames_index[prev_val_frame].pop(0)
482503
updated_task_honeypot_real_frames[idx] = new_val_frame
483504
except ValueError as ex:
484505
logger.exception(

packages/examples/cvat/recording-oracle/src/models/validation.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,7 @@ class GtStats(Base):
5757
task_id = Column(
5858
String, ForeignKey("tasks.id", ondelete="CASCADE"), primary_key=True, nullable=False
5959
)
60-
cvat_task_id = Column(Integer, nullable=False)
61-
gt_frame_id = Column(Integer, primary_key=True, nullable=False)
60+
gt_frame_name = Column(String, primary_key=True, nullable=False)
6261

6362
failed_attempts = Column(Integer, default=0, nullable=False)
6463
accepted_attempts = Column(Integer, default=0, nullable=False)

packages/examples/cvat/recording-oracle/src/services/validation.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def get_task_gt_stats(
134134

135135

136136
def update_gt_stats(
137-
session: Session, task_id: str, updated_gt_stats: dict[tuple[int, int], ValidationFrameStats]
137+
session: Session, task_id: str, updated_gt_stats: dict[str, ValidationFrameStats]
138138
):
139139
# Read more about upsert:
140140
# https://docs.sqlalchemy.org/en/20/orm/queryguide/dml.html#orm-upsert-statements
@@ -152,13 +152,12 @@ def update_gt_stats(
152152
[
153153
{
154154
"task_id": task_id,
155-
"cvat_task_id": cvat_task_id,
156-
"gt_frame_id": gt_frame_id,
155+
"gt_frame_name": gt_frame_name,
157156
"failed_attempts": val_frame_stats.failed_attempts,
158157
"accepted_attempts": val_frame_stats.accepted_attempts,
159158
"accumulated_quality": val_frame_stats.accumulated_quality,
160159
}
161-
for (cvat_task_id, gt_frame_id), val_frame_stats in updated_gt_stats.items()
160+
for gt_frame_name, val_frame_stats in updated_gt_stats.items()
162161
],
163162
)
164163
statement = statement.on_conflict_do_update(

0 commit comments

Comments
 (0)