Skip to content

Commit 1a730e9

Browse files
authored
[CVAT-M2] Review GT use (#1850)
* Exclude input GT files from the data for assignments * Use the original GT during merging final results for new task types * Fix gt preparation for skeletons_from_boxes * Fix debug message
1 parent db2bc21 commit 1a730e9

File tree

2 files changed

+82
-35
lines changed

2 files changed

+82
-35
lines changed

packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py

Lines changed: 73 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ def __init__(self, manifest: TaskManifest, escrow_address: str, chain_id: int):
138138
self._input_points_data: _MaybeUnset[bytes] = _unset
139139

140140
self._data_filenames: _MaybeUnset[Sequence[str]] = _unset
141+
self._input_gt_dataset: _MaybeUnset[dm.Dataset] = _unset
141142
self._gt_dataset: _MaybeUnset[dm.Dataset] = _unset
142143
self._points_dataset: _MaybeUnset[dm.Dataset] = _unset
143144

@@ -242,7 +243,7 @@ def _parse_dataset(self, annotation_file_data: bytes, dataset_format: str) -> dm
242243
def _parse_gt(self):
243244
assert self._input_gt_data is not _unset
244245

245-
self._gt_dataset = self._parse_dataset(
246+
self._input_gt_dataset = self._parse_dataset(
246247
self._input_gt_data,
247248
dataset_format=DM_GT_DATASET_FORMAT_MAPPING[self.manifest.annotation.type],
248249
)
@@ -257,7 +258,7 @@ def _parse_points(self):
257258
def _validate_gt_labels(self):
258259
gt_labels = set(
259260
label.name
260-
for label in self._gt_dataset.categories()[dm.AnnotationType.label]
261+
for label in self._input_gt_dataset.categories()[dm.AnnotationType.label]
261262
if not label.parent
262263
)
263264
manifest_labels = set(label.name for label in self.manifest.annotation.labels)
@@ -268,13 +269,13 @@ def _validate_gt_labels(self):
268269
)
269270
)
270271

271-
self._gt_dataset.transform(
272+
self._input_gt_dataset.transform(
272273
ProjectLabels, dst_labels=[label.name for label in self.manifest.annotation.labels]
273274
)
274-
self._gt_dataset.init_cache()
275+
self._input_gt_dataset.init_cache()
275276

276277
def _validate_gt_filenames(self):
277-
gt_filenames = set(s.id + s.media.ext for s in self._gt_dataset)
278+
gt_filenames = set(s.id + s.media.ext for s in self._input_gt_dataset)
278279

279280
known_data_filenames = set(self._data_filenames)
280281
matched_gt_filenames = gt_filenames.intersection(known_data_filenames)
@@ -295,12 +296,12 @@ def _validate_gt_filenames(self):
295296
)
296297

297298
def _validate_gt_annotations(self):
298-
label_cat: dm.LabelCategories = self._gt_dataset.categories()[dm.AnnotationType.label]
299+
label_cat: dm.LabelCategories = self._input_gt_dataset.categories()[dm.AnnotationType.label]
299300

300301
excluded_gt_info = _ExcludedAnnotationsInfo()
301302
excluded_samples = set()
302303
visited_ids = set()
303-
for gt_sample in self._gt_dataset:
304+
for gt_sample in self._input_gt_dataset:
304305
# Could fail on this as well
305306
img_h, img_w = gt_sample.media_as(dm.Image).size
306307

@@ -340,10 +341,10 @@ def _validate_gt_annotations(self):
340341
if not valid_boxes:
341342
excluded_samples.add((gt_sample.id, gt_sample.subset))
342343
else:
343-
self._gt_dataset.put(gt_sample.wrap(annotations=valid_boxes))
344+
self._input_gt_dataset.put(gt_sample.wrap(annotations=valid_boxes))
344345

345346
for excluded_sample in excluded_samples:
346-
self._gt_dataset.remove(*excluded_sample)
347+
self._input_gt_dataset.remove(*excluded_sample)
347348

348349
if excluded_gt_info.excluded_count:
349350
self.logger.warning(
@@ -369,7 +370,7 @@ def _validate_gt_annotations(self):
369370

370371
def _validate_gt(self):
371372
assert self._data_filenames is not _unset
372-
assert self._gt_dataset is not _unset
373+
assert self._input_gt_dataset is not _unset
373374

374375
self._validate_gt_filenames()
375376
self._validate_gt_labels()
@@ -643,10 +644,10 @@ def _find_good_gt_boxes(
643644

644645
assert self._data_filenames is not _unset
645646
assert self._points_dataset is not _unset
646-
assert self._gt_dataset is not _unset
647-
assert [label.name for label in self._gt_dataset.categories()[dm.AnnotationType.label]] == [
648-
label.name for label in self.manifest.annotation.labels
649-
]
647+
assert self._input_gt_dataset is not _unset
648+
assert [
649+
label.name for label in self._input_gt_dataset.categories()[dm.AnnotationType.label]
650+
] == [label.name for label in self.manifest.annotation.labels]
650651
assert [
651652
label.name
652653
for label in self._points_dataset.categories()[dm.AnnotationType.label]
@@ -656,17 +657,19 @@ def _find_good_gt_boxes(
656657
points_label_cat: dm.LabelCategories = self._points_dataset.categories()[
657658
dm.AnnotationType.label
658659
]
659-
gt_label_cat: dm.LabelCategories = self._gt_dataset.categories()[dm.AnnotationType.label]
660+
gt_label_cat: dm.LabelCategories = self._input_gt_dataset.categories()[
661+
dm.AnnotationType.label
662+
]
660663

661664
updated_gt_dataset = dm.Dataset(
662-
categories=self._gt_dataset.categories(), media_type=dm.Image
665+
categories=self._input_gt_dataset.categories(), media_type=dm.Image
663666
)
664667

665668
excluded_points_info = _ExcludedAnnotationsInfo() # local for the function
666669
excluded_gt_info = self._excluded_gt_info
667670
gt_count_per_class = {}
668671
bbox_point_mapping = {} # bbox id -> point id
669-
for gt_sample in self._gt_dataset:
672+
for gt_sample in self._input_gt_dataset:
670673
points_sample = self._points_dataset.get(gt_sample.id, gt_sample.subset)
671674
assert points_sample
672675

@@ -856,12 +859,24 @@ def _prepare_job_layout(self):
856859

857860
assert self._rois is not _unset
858861
assert self._bbox_point_mapping is not _unset
862+
assert self._input_gt_dataset is not _unset
863+
864+
# This list can be different from what is selected for validation
865+
input_gt_filenames = set(sample.media.path for sample in self._input_gt_dataset)
866+
original_image_id_to_filename = {
867+
sample.attributes["id"]: sample.media.path for sample in self._points_dataset
868+
}
869+
point_id_to_original_image_id = {roi.point_id: roi.original_image_key for roi in self._rois}
859870

860871
gt_point_ids = set(self._bbox_point_mapping.values())
861872
gt_filenames = [self._roi_filenames[point_id] for point_id in gt_point_ids]
862873

863874
data_filenames = [
864-
fn for point_id, fn in self._roi_filenames.items() if not point_id in gt_point_ids
875+
fn
876+
for point_id, fn in self._roi_filenames.items()
877+
if not point_id in gt_point_ids
878+
if not original_image_id_to_filename[point_id_to_original_image_id[point_id]]
879+
in input_gt_filenames
865880
]
866881
random.shuffle(data_filenames)
867882

@@ -874,6 +889,12 @@ def _prepare_job_layout(self):
874889

875890
self._job_layout = job_layout
876891

892+
self.logger.info(
893+
"Task creation for escrow '%s': will create %s assignments",
894+
self.escrow_address,
895+
len(job_layout),
896+
)
897+
877898
def _prepare_label_configuration(self):
878899
self._label_configuration = make_label_configuration(self.manifest)
879900

@@ -1145,6 +1166,7 @@ def __init__(self, manifest: TaskManifest, escrow_address: str, chain_id: int):
11451166
self._input_boxes_data: _MaybeUnset[bytes] = _unset
11461167

11471168
self._data_filenames: _MaybeUnset[Sequence[str]] = _unset
1169+
self._input_gt_dataset: _MaybeUnset[dm.Dataset] = _unset
11481170
self._gt_dataset: _MaybeUnset[dm.Dataset] = _unset
11491171
self._boxes_dataset: _MaybeUnset[dm.Dataset] = _unset
11501172

@@ -1237,7 +1259,7 @@ def _parse_dataset(self, annotation_file_data: bytes, dataset_format: str) -> dm
12371259
def _parse_gt(self):
12381260
assert self._input_gt_data is not _unset
12391261

1240-
self._gt_dataset = self._parse_dataset(
1262+
self._input_gt_dataset = self._parse_dataset(
12411263
self._input_gt_data,
12421264
dataset_format=DM_GT_DATASET_FORMAT_MAPPING[self.manifest.annotation.type],
12431265
)
@@ -1252,7 +1274,7 @@ def _parse_boxes(self):
12521274
def _validate_gt_labels(self):
12531275
gt_labels = set(
12541276
(label.name, label.parent)
1255-
for label in self._gt_dataset.categories()[dm.AnnotationType.label]
1277+
for label in self._input_gt_dataset.categories()[dm.AnnotationType.label]
12561278
)
12571279

12581280
manifest_labels = set()
@@ -1274,13 +1296,13 @@ def _validate_gt_labels(self):
12741296
)
12751297

12761298
# Reorder labels to match the manifest
1277-
self._gt_dataset.transform(
1299+
self._input_gt_dataset.transform(
12781300
ProjectLabels, dst_labels=[label.name for label in self.manifest.annotation.labels]
12791301
)
1280-
self._gt_dataset.init_cache()
1302+
self._input_gt_dataset.init_cache()
12811303

12821304
def _validate_gt_filenames(self):
1283-
gt_filenames = set(s.id + s.media.ext for s in self._gt_dataset)
1305+
gt_filenames = set(s.id + s.media.ext for s in self._input_gt_dataset)
12841306

12851307
known_data_filenames = set(self._data_filenames)
12861308
matched_gt_filenames = gt_filenames.intersection(known_data_filenames)
@@ -1316,12 +1338,12 @@ def _validate_skeleton(skeleton: dm.Skeleton, *, sample_bbox: dm.Bbox):
13161338
if not is_point_in_bbox(px, py, sample_bbox):
13171339
raise InvalidCoordinates("skeleton point is outside the image")
13181340

1319-
label_cat: dm.LabelCategories = self._gt_dataset.categories()[dm.AnnotationType.label]
1341+
label_cat: dm.LabelCategories = self._input_gt_dataset.categories()[dm.AnnotationType.label]
13201342

13211343
excluded_gt_info = _ExcludedAnnotationsInfo()
13221344
excluded_samples = set()
13231345
visited_ids = set()
1324-
for gt_sample in self._gt_dataset:
1346+
for gt_sample in self._input_gt_dataset:
13251347
# Could fail on this as well
13261348
img_h, img_w = gt_sample.media_as(dm.Image).size
13271349
sample_bbox = dm.Bbox(0, 0, w=img_w, h=img_h)
@@ -1363,14 +1385,14 @@ def _validate_skeleton(skeleton: dm.Skeleton, *, sample_bbox: dm.Bbox):
13631385
else:
13641386
# Skeleton boxes can be in the list as well with the same ids / groups
13651387
skeleton_ids = set(a.id for a in valid_skeletons) - {0}
1366-
self._gt_dataset.put(
1388+
self._input_gt_dataset.put(
13671389
gt_sample.wrap(
13681390
annotations=[a for a in gt_sample.annotations if a.id in skeleton_ids]
13691391
)
13701392
)
13711393

13721394
for excluded_sample in excluded_samples:
1373-
self._gt_dataset.remove(*excluded_sample)
1395+
self._input_gt_dataset.remove(*excluded_sample)
13741396

13751397
if excluded_gt_info.excluded_count:
13761398
self.logger.warning(
@@ -1396,7 +1418,7 @@ def _validate_skeleton(skeleton: dm.Skeleton, *, sample_bbox: dm.Bbox):
13961418

13971419
def _validate_gt(self):
13981420
assert self._data_filenames is not _unset
1399-
assert self._gt_dataset is not _unset
1421+
assert self._input_gt_dataset is not _unset
14001422

14011423
self._validate_gt_filenames()
14021424
self._validate_gt_labels()
@@ -1653,12 +1675,14 @@ def _find_good_gt_skeletons(
16531675
matched_skeletons.append(gt_skeleton)
16541676
skeleton_bbox_mapping[gt_skeleton_id] = matched_boxes[0].id
16551677

1678+
return matched_skeletons
1679+
16561680
assert self._data_filenames is not _unset
16571681
assert self._boxes_dataset is not _unset
1658-
assert self._gt_dataset is not _unset
1682+
assert self._input_gt_dataset is not _unset
16591683
assert [
16601684
label.name
1661-
for label in self._gt_dataset.categories()[dm.AnnotationType.label]
1685+
for label in self._input_gt_dataset.categories()[dm.AnnotationType.label]
16621686
if not label.parent
16631687
] == [label.name for label in self.manifest.annotation.labels]
16641688
assert [
@@ -1670,17 +1694,19 @@ def _find_good_gt_skeletons(
16701694
boxes_label_cat: dm.LabelCategories = self._boxes_dataset.categories()[
16711695
dm.AnnotationType.label
16721696
]
1673-
gt_label_cat: dm.LabelCategories = self._gt_dataset.categories()[dm.AnnotationType.label]
1697+
gt_label_cat: dm.LabelCategories = self._input_gt_dataset.categories()[
1698+
dm.AnnotationType.label
1699+
]
16741700

16751701
updated_gt_dataset = dm.Dataset(
1676-
categories=self._gt_dataset.categories(), media_type=dm.Image
1702+
categories=self._input_gt_dataset.categories(), media_type=dm.Image
16771703
)
16781704

16791705
excluded_boxes_info = _ExcludedAnnotationsInfo() # local for the function
16801706
excluded_gt_info = self._excluded_gt_info
16811707
gt_count_per_class = {}
16821708
skeleton_bbox_mapping = {} # skeleton id -> bbox id
1683-
for gt_sample in self._gt_dataset:
1709+
for gt_sample in self._input_gt_dataset:
16841710
boxes_sample = self._boxes_dataset.get(gt_sample.id, gt_sample.subset)
16851711
# Samples could be discarded, so we just skip them without an error
16861712
if not boxes_sample:
@@ -1799,6 +1825,13 @@ def _mangle_filenames(self):
17991825
def _prepare_job_params(self):
18001826
assert self._roi_infos is not _unset
18011827
assert self._skeleton_bbox_mapping is not _unset
1828+
assert self._input_gt_dataset is not _unset
1829+
1830+
# This list can be different from what is selected for validation
1831+
input_gt_filenames = set(sample.media.path for sample in self._input_gt_dataset)
1832+
image_id_to_filename = {
1833+
sample.attributes["id"]: sample.media.path for sample in self._boxes_dataset
1834+
}
18021835

18031836
# Make job layouts wrt. manifest params
18041837
# 1 job per task, 1 task for each point label
@@ -1824,6 +1857,7 @@ def _prepare_job_params(self):
18241857
for roi_info in self._roi_infos
18251858
if roi_info.bbox_label == label_id
18261859
if roi_info.bbox_id not in label_gt_roi_ids
1860+
if image_id_to_filename[roi_info.original_image_key] not in input_gt_filenames
18271861
]
18281862
random.shuffle(label_data_roi_ids)
18291863

@@ -1844,6 +1878,12 @@ def _prepare_job_params(self):
18441878

18451879
self._job_params = job_params
18461880

1881+
self.logger.info(
1882+
"Task creation for escrow '%s': will create %s assignments",
1883+
self.escrow_address,
1884+
sum(len(self.manifest.annotation.labels[jp.label_id].nodes) for jp in job_params),
1885+
)
1886+
18471887
def _prepare_job_labels(self):
18481888
self.point_labels = {}
18491889

packages/examples/cvat/recording-oracle/src/handlers/process_intermediate_results.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,10 @@ def _make_gt_dataset_for_job(self, job_id: int, job_dataset: dm.Dataset) -> dm.D
510510

511511
return job_gt_dataset
512512

513+
def _prepare_merged_dataset(self):
514+
super()._parse_gt() # We need to download the original GT dataset
515+
return super()._prepare_merged_dataset()
516+
513517

514518
class _SkeletonsFromBoxesValidator(_TaskValidatorWithPerJobGt):
515519
def __init__(self, *args, **kwargs):
@@ -845,6 +849,10 @@ def _update_gt_stats(
845849

846850
return updated_gt_stats
847851

852+
def _prepare_merged_dataset(self):
853+
super()._parse_gt() # We need to download the original GT dataset
854+
return super()._prepare_merged_dataset()
855+
848856

849857
def _compute_gt_stats_update(
850858
initial_gt_stats: _FailedGtAttempts, validation_gt_stats: _UpdatedFailedGtStats
@@ -894,8 +902,7 @@ def process_intermediate_results(
894902

895903
if logger.isEnabledFor(logging.DEBUG):
896904
logger.debug("process_intermediate_results for escrow %s", escrow_address)
897-
logger.debug("Task id %s", task_id)
898-
logger.debug("Task %s %s", task, getattr(task, "__dict__", None))
905+
logger.debug("Task id %s, %s", getattr(task, 'id', None), getattr(task, "__dict__", None))
899906

900907
initial_gt_stats = {
901908
gt_image_stat.gt_key: gt_image_stat.failed_attempts

0 commit comments

Comments
 (0)