@@ -138,6 +138,7 @@ def __init__(self, manifest: TaskManifest, escrow_address: str, chain_id: int):
138138 self ._input_points_data : _MaybeUnset [bytes ] = _unset
139139
140140 self ._data_filenames : _MaybeUnset [Sequence [str ]] = _unset
141+ self ._input_gt_dataset : _MaybeUnset [dm .Dataset ] = _unset
141142 self ._gt_dataset : _MaybeUnset [dm .Dataset ] = _unset
142143 self ._points_dataset : _MaybeUnset [dm .Dataset ] = _unset
143144
@@ -242,7 +243,7 @@ def _parse_dataset(self, annotation_file_data: bytes, dataset_format: str) -> dm
242243 def _parse_gt (self ):
243244 assert self ._input_gt_data is not _unset
244245
245- self ._gt_dataset = self ._parse_dataset (
246+ self ._input_gt_dataset = self ._parse_dataset (
246247 self ._input_gt_data ,
247248 dataset_format = DM_GT_DATASET_FORMAT_MAPPING [self .manifest .annotation .type ],
248249 )
@@ -257,7 +258,7 @@ def _parse_points(self):
257258 def _validate_gt_labels (self ):
258259 gt_labels = set (
259260 label .name
260- for label in self ._gt_dataset .categories ()[dm .AnnotationType .label ]
261+ for label in self ._input_gt_dataset .categories ()[dm .AnnotationType .label ]
261262 if not label .parent
262263 )
263264 manifest_labels = set (label .name for label in self .manifest .annotation .labels )
@@ -268,13 +269,13 @@ def _validate_gt_labels(self):
268269 )
269270 )
270271
271- self ._gt_dataset .transform (
272+ self ._input_gt_dataset .transform (
272273 ProjectLabels , dst_labels = [label .name for label in self .manifest .annotation .labels ]
273274 )
274- self ._gt_dataset .init_cache ()
275+ self ._input_gt_dataset .init_cache ()
275276
276277 def _validate_gt_filenames (self ):
277- gt_filenames = set (s .id + s .media .ext for s in self ._gt_dataset )
278+ gt_filenames = set (s .id + s .media .ext for s in self ._input_gt_dataset )
278279
279280 known_data_filenames = set (self ._data_filenames )
280281 matched_gt_filenames = gt_filenames .intersection (known_data_filenames )
@@ -295,12 +296,12 @@ def _validate_gt_filenames(self):
295296 )
296297
297298 def _validate_gt_annotations (self ):
298- label_cat : dm .LabelCategories = self ._gt_dataset .categories ()[dm .AnnotationType .label ]
299+ label_cat : dm .LabelCategories = self ._input_gt_dataset .categories ()[dm .AnnotationType .label ]
299300
300301 excluded_gt_info = _ExcludedAnnotationsInfo ()
301302 excluded_samples = set ()
302303 visited_ids = set ()
303- for gt_sample in self ._gt_dataset :
304+ for gt_sample in self ._input_gt_dataset :
304305 # Could fail on this as well
305306 img_h , img_w = gt_sample .media_as (dm .Image ).size
306307
@@ -340,10 +341,10 @@ def _validate_gt_annotations(self):
340341 if not valid_boxes :
341342 excluded_samples .add ((gt_sample .id , gt_sample .subset ))
342343 else :
343- self ._gt_dataset .put (gt_sample .wrap (annotations = valid_boxes ))
344+ self ._input_gt_dataset .put (gt_sample .wrap (annotations = valid_boxes ))
344345
345346 for excluded_sample in excluded_samples :
346- self ._gt_dataset .remove (* excluded_sample )
347+ self ._input_gt_dataset .remove (* excluded_sample )
347348
348349 if excluded_gt_info .excluded_count :
349350 self .logger .warning (
@@ -369,7 +370,7 @@ def _validate_gt_annotations(self):
369370
370371 def _validate_gt (self ):
371372 assert self ._data_filenames is not _unset
372- assert self ._gt_dataset is not _unset
373+ assert self ._input_gt_dataset is not _unset
373374
374375 self ._validate_gt_filenames ()
375376 self ._validate_gt_labels ()
@@ -643,10 +644,10 @@ def _find_good_gt_boxes(
643644
644645 assert self ._data_filenames is not _unset
645646 assert self ._points_dataset is not _unset
646- assert self ._gt_dataset is not _unset
647- assert [label . name for label in self . _gt_dataset . categories ()[ dm . AnnotationType . label ]] == [
648- label .name for label in self .manifest . annotation . labels
649- ]
647+ assert self ._input_gt_dataset is not _unset
648+ assert [
649+ label .name for label in self ._input_gt_dataset . categories ()[ dm . AnnotationType . label ]
650+ ] == [ label . name for label in self . manifest . annotation . labels ]
650651 assert [
651652 label .name
652653 for label in self ._points_dataset .categories ()[dm .AnnotationType .label ]
@@ -656,17 +657,19 @@ def _find_good_gt_boxes(
656657 points_label_cat : dm .LabelCategories = self ._points_dataset .categories ()[
657658 dm .AnnotationType .label
658659 ]
659- gt_label_cat : dm .LabelCategories = self ._gt_dataset .categories ()[dm .AnnotationType .label ]
660+ gt_label_cat : dm .LabelCategories = self ._input_gt_dataset .categories ()[
661+ dm .AnnotationType .label
662+ ]
660663
661664 updated_gt_dataset = dm .Dataset (
662- categories = self ._gt_dataset .categories (), media_type = dm .Image
665+ categories = self ._input_gt_dataset .categories (), media_type = dm .Image
663666 )
664667
665668 excluded_points_info = _ExcludedAnnotationsInfo () # local for the function
666669 excluded_gt_info = self ._excluded_gt_info
667670 gt_count_per_class = {}
668671 bbox_point_mapping = {} # bbox id -> point id
669- for gt_sample in self ._gt_dataset :
672+ for gt_sample in self ._input_gt_dataset :
670673 points_sample = self ._points_dataset .get (gt_sample .id , gt_sample .subset )
671674 assert points_sample
672675
@@ -856,12 +859,24 @@ def _prepare_job_layout(self):
856859
857860 assert self ._rois is not _unset
858861 assert self ._bbox_point_mapping is not _unset
862+ assert self ._input_gt_dataset is not _unset
863+
864+ # This list can be different from what is selected for validation
865+ input_gt_filenames = set (sample .media .path for sample in self ._input_gt_dataset )
866+ original_image_id_to_filename = {
867+ sample .attributes ["id" ]: sample .media .path for sample in self ._points_dataset
868+ }
869+ point_id_to_original_image_id = {roi .point_id : roi .original_image_key for roi in self ._rois }
859870
860871 gt_point_ids = set (self ._bbox_point_mapping .values ())
861872 gt_filenames = [self ._roi_filenames [point_id ] for point_id in gt_point_ids ]
862873
863874 data_filenames = [
864- fn for point_id , fn in self ._roi_filenames .items () if not point_id in gt_point_ids
875+ fn
876+ for point_id , fn in self ._roi_filenames .items ()
877+ if not point_id in gt_point_ids
878+ if not original_image_id_to_filename [point_id_to_original_image_id [point_id ]]
879+ in input_gt_filenames
865880 ]
866881 random .shuffle (data_filenames )
867882
@@ -874,6 +889,12 @@ def _prepare_job_layout(self):
874889
875890 self ._job_layout = job_layout
876891
892+ self .logger .info (
893+ "Task creation for escrow '%s': will create %s assignments" ,
894+ self .escrow_address ,
895+ len (job_layout ),
896+ )
897+
877898 def _prepare_label_configuration (self ):
878899 self ._label_configuration = make_label_configuration (self .manifest )
879900
@@ -1145,6 +1166,7 @@ def __init__(self, manifest: TaskManifest, escrow_address: str, chain_id: int):
11451166 self ._input_boxes_data : _MaybeUnset [bytes ] = _unset
11461167
11471168 self ._data_filenames : _MaybeUnset [Sequence [str ]] = _unset
1169+ self ._input_gt_dataset : _MaybeUnset [dm .Dataset ] = _unset
11481170 self ._gt_dataset : _MaybeUnset [dm .Dataset ] = _unset
11491171 self ._boxes_dataset : _MaybeUnset [dm .Dataset ] = _unset
11501172
@@ -1237,7 +1259,7 @@ def _parse_dataset(self, annotation_file_data: bytes, dataset_format: str) -> dm
12371259 def _parse_gt (self ):
12381260 assert self ._input_gt_data is not _unset
12391261
1240- self ._gt_dataset = self ._parse_dataset (
1262+ self ._input_gt_dataset = self ._parse_dataset (
12411263 self ._input_gt_data ,
12421264 dataset_format = DM_GT_DATASET_FORMAT_MAPPING [self .manifest .annotation .type ],
12431265 )
@@ -1252,7 +1274,7 @@ def _parse_boxes(self):
12521274 def _validate_gt_labels (self ):
12531275 gt_labels = set (
12541276 (label .name , label .parent )
1255- for label in self ._gt_dataset .categories ()[dm .AnnotationType .label ]
1277+ for label in self ._input_gt_dataset .categories ()[dm .AnnotationType .label ]
12561278 )
12571279
12581280 manifest_labels = set ()
@@ -1274,13 +1296,13 @@ def _validate_gt_labels(self):
12741296 )
12751297
12761298 # Reorder labels to match the manifest
1277- self ._gt_dataset .transform (
1299+ self ._input_gt_dataset .transform (
12781300 ProjectLabels , dst_labels = [label .name for label in self .manifest .annotation .labels ]
12791301 )
1280- self ._gt_dataset .init_cache ()
1302+ self ._input_gt_dataset .init_cache ()
12811303
12821304 def _validate_gt_filenames (self ):
1283- gt_filenames = set (s .id + s .media .ext for s in self ._gt_dataset )
1305+ gt_filenames = set (s .id + s .media .ext for s in self ._input_gt_dataset )
12841306
12851307 known_data_filenames = set (self ._data_filenames )
12861308 matched_gt_filenames = gt_filenames .intersection (known_data_filenames )
@@ -1316,12 +1338,12 @@ def _validate_skeleton(skeleton: dm.Skeleton, *, sample_bbox: dm.Bbox):
13161338 if not is_point_in_bbox (px , py , sample_bbox ):
13171339 raise InvalidCoordinates ("skeleton point is outside the image" )
13181340
1319- label_cat : dm .LabelCategories = self ._gt_dataset .categories ()[dm .AnnotationType .label ]
1341+ label_cat : dm .LabelCategories = self ._input_gt_dataset .categories ()[dm .AnnotationType .label ]
13201342
13211343 excluded_gt_info = _ExcludedAnnotationsInfo ()
13221344 excluded_samples = set ()
13231345 visited_ids = set ()
1324- for gt_sample in self ._gt_dataset :
1346+ for gt_sample in self ._input_gt_dataset :
13251347 # Could fail on this as well
13261348 img_h , img_w = gt_sample .media_as (dm .Image ).size
13271349 sample_bbox = dm .Bbox (0 , 0 , w = img_w , h = img_h )
@@ -1363,14 +1385,14 @@ def _validate_skeleton(skeleton: dm.Skeleton, *, sample_bbox: dm.Bbox):
13631385 else :
13641386 # Skeleton boxes can be in the list as well with the same ids / groups
13651387 skeleton_ids = set (a .id for a in valid_skeletons ) - {0 }
1366- self ._gt_dataset .put (
1388+ self ._input_gt_dataset .put (
13671389 gt_sample .wrap (
13681390 annotations = [a for a in gt_sample .annotations if a .id in skeleton_ids ]
13691391 )
13701392 )
13711393
13721394 for excluded_sample in excluded_samples :
1373- self ._gt_dataset .remove (* excluded_sample )
1395+ self ._input_gt_dataset .remove (* excluded_sample )
13741396
13751397 if excluded_gt_info .excluded_count :
13761398 self .logger .warning (
@@ -1396,7 +1418,7 @@ def _validate_skeleton(skeleton: dm.Skeleton, *, sample_bbox: dm.Bbox):
13961418
13971419 def _validate_gt (self ):
13981420 assert self ._data_filenames is not _unset
1399- assert self ._gt_dataset is not _unset
1421+ assert self ._input_gt_dataset is not _unset
14001422
14011423 self ._validate_gt_filenames ()
14021424 self ._validate_gt_labels ()
@@ -1653,12 +1675,14 @@ def _find_good_gt_skeletons(
16531675 matched_skeletons .append (gt_skeleton )
16541676 skeleton_bbox_mapping [gt_skeleton_id ] = matched_boxes [0 ].id
16551677
1678+ return matched_skeletons
1679+
16561680 assert self ._data_filenames is not _unset
16571681 assert self ._boxes_dataset is not _unset
1658- assert self ._gt_dataset is not _unset
1682+ assert self ._input_gt_dataset is not _unset
16591683 assert [
16601684 label .name
1661- for label in self ._gt_dataset .categories ()[dm .AnnotationType .label ]
1685+ for label in self ._input_gt_dataset .categories ()[dm .AnnotationType .label ]
16621686 if not label .parent
16631687 ] == [label .name for label in self .manifest .annotation .labels ]
16641688 assert [
@@ -1670,17 +1694,19 @@ def _find_good_gt_skeletons(
16701694 boxes_label_cat : dm .LabelCategories = self ._boxes_dataset .categories ()[
16711695 dm .AnnotationType .label
16721696 ]
1673- gt_label_cat : dm .LabelCategories = self ._gt_dataset .categories ()[dm .AnnotationType .label ]
1697+ gt_label_cat : dm .LabelCategories = self ._input_gt_dataset .categories ()[
1698+ dm .AnnotationType .label
1699+ ]
16741700
16751701 updated_gt_dataset = dm .Dataset (
1676- categories = self ._gt_dataset .categories (), media_type = dm .Image
1702+ categories = self ._input_gt_dataset .categories (), media_type = dm .Image
16771703 )
16781704
16791705 excluded_boxes_info = _ExcludedAnnotationsInfo () # local for the function
16801706 excluded_gt_info = self ._excluded_gt_info
16811707 gt_count_per_class = {}
16821708 skeleton_bbox_mapping = {} # skeleton id -> bbox id
1683- for gt_sample in self ._gt_dataset :
1709+ for gt_sample in self ._input_gt_dataset :
16841710 boxes_sample = self ._boxes_dataset .get (gt_sample .id , gt_sample .subset )
16851711 # Samples could be discarded, so we just skip them without an error
16861712 if not boxes_sample :
@@ -1799,6 +1825,13 @@ def _mangle_filenames(self):
17991825 def _prepare_job_params (self ):
18001826 assert self ._roi_infos is not _unset
18011827 assert self ._skeleton_bbox_mapping is not _unset
1828+ assert self ._input_gt_dataset is not _unset
1829+
1830+ # This list can be different from what is selected for validation
1831+ input_gt_filenames = set (sample .media .path for sample in self ._input_gt_dataset )
1832+ image_id_to_filename = {
1833+ sample .attributes ["id" ]: sample .media .path for sample in self ._boxes_dataset
1834+ }
18021835
18031836 # Make job layouts wrt. manifest params
18041837 # 1 job per task, 1 task for each point label
@@ -1824,6 +1857,7 @@ def _prepare_job_params(self):
18241857 for roi_info in self ._roi_infos
18251858 if roi_info .bbox_label == label_id
18261859 if roi_info .bbox_id not in label_gt_roi_ids
1860+ if image_id_to_filename [roi_info .original_image_key ] not in input_gt_filenames
18271861 ]
18281862 random .shuffle (label_data_roi_ids )
18291863
@@ -1844,6 +1878,12 @@ def _prepare_job_params(self):
18441878
18451879 self ._job_params = job_params
18461880
1881+ self .logger .info (
1882+ "Task creation for escrow '%s': will create %s assignments" ,
1883+ self .escrow_address ,
1884+ sum (len (self .manifest .annotation .labels [jp .label_id ].nodes ) for jp in job_params ),
1885+ )
1886+
18471887 def _prepare_job_labels (self ):
18481888 self .point_labels = {}
18491889
0 commit comments