1111import datumaro as dm
1212import numpy as np
1313
14- import src .core .tasks .boxes_from_points as boxes_from_points_task
15- import src .core .tasks .points as points_task
16- import src .core .tasks .simple as simple_task
17- import src .core .tasks .skeletons_from_boxes as skeletons_from_boxes_task
1814import src .cvat .api_calls as cvat_api
1915import src .services .validation as db_service
2016from src .core .annotation_meta import AnnotationMeta
2117from src .core .config import Config
2218from src .core .gt_stats import GtStats , ValidationFrameStats
23- from src .core .storage import compose_data_bucket_filename
2419from src .core .types import TaskTypes
2520from src .core .validation_errors import DatasetValidationError , LowAccuracyError , TooFewGtError
2621from src .core .validation_meta import JobMeta , ResultMeta , ValidationMeta
2722from src .core .validation_results import ValidationFailure , ValidationSuccess
2823from src .db .utils import ForUpdateParams
2924from src .services .cloud import make_client as make_cloud_client
3025from src .services .cloud .utils import BucketAccessInfo
31- from src .utils .annotations import ProjectLabels , flatten_points
26+ from src .utils .annotations import ProjectLabels
3227from src .utils .zip_archive import extract_zip_archive , write_dir_to_zip_archive
3328
3429if TYPE_CHECKING :
@@ -108,7 +103,7 @@ def __init__(
108103 self ._rejected_jobs : _RejectedJobs | None = None
109104
110105 self ._temp_dir : Path | None = None
111- self ._gt_dataset : dm .Dataset | None = None
106+ self ._input_gt_dataset : dm .Dataset | None = None
112107 self ._meta : AnnotationMeta = meta
113108
114109 def _require_field (self , field : T | None ) -> T :
@@ -118,44 +113,26 @@ def _require_field(self, field: T | None) -> T:
118113 def _gt_key_to_sample_id (self , gt_key : str ) -> str :
119114 return gt_key
120115
121- def _get_meta_layout_and_serializer (self ):
122- if self .manifest .annotation .type == TaskTypes .image_boxes :
123- return (
124- simple_task .TaskMetaLayout (),
125- simple_task .TaskMetaSerializer (),
126- )
127- if self .manifest .annotation .type == TaskTypes .image_points :
128- return (
129- points_task .TaskMetaLayout (),
130- points_task .TaskMetaSerializer (),
131- )
132- if self .manifest .annotation .type == TaskTypes .image_boxes_from_points :
133- return (
134- boxes_from_points_task .TaskMetaLayout (),
135- boxes_from_points_task .TaskMetaSerializer (),
136- )
137- if self .manifest .annotation .type == TaskTypes .image_skeletons_from_boxes :
138- return (
139- skeletons_from_boxes_task .TaskMetaLayout (),
140- skeletons_from_boxes_task .TaskMetaSerializer (),
116+ def _parse_gt_dataset (self , gt_file_data : bytes ) -> dm .Dataset :
117+ with TemporaryDirectory () as gt_temp_dir :
118+ gt_filename = os .path .join (gt_temp_dir , "gt_annotations.json" )
119+ with open (gt_filename , "wb" ) as f :
120+ f .write (gt_file_data )
121+
122+ gt_dataset = dm .Dataset .import_from (
123+ gt_filename ,
124+ format = DM_GT_DATASET_FORMAT_MAPPING [self .manifest .annotation .type ],
141125 )
142- raise AssertionError (f"Unknown task type { self .manifest .annotation .type } " )
143126
144- def _parse_gt (self ):
145- layout , serializer = self ._get_meta_layout_and_serializer ()
127+ gt_dataset .init_cache ()
146128
147- exchange_oracle_data_bucket = BucketAccessInfo .parse_obj (
148- Config .exchange_oracle_storage_config
149- )
150- storage_client = make_cloud_client (exchange_oracle_data_bucket )
129+ return gt_dataset
151130
152- self ._gt_dataset = serializer .parse_gt_annotations (
153- storage_client .download_file (
154- compose_data_bucket_filename (
155- self .escrow_address , self .chain_id , layout .GT_FILENAME
156- ),
157- )
158- )
131+ def _load_gt_dataset (self ):
132+ input_gt_bucket = BucketAccessInfo .parse_obj (self .manifest .validation .gt_url )
133+ gt_bucket_client = make_cloud_client (input_gt_bucket )
134+ gt_data = gt_bucket_client .download_file (input_gt_bucket .path )
135+ self ._input_gt_dataset = self ._parse_gt_dataset (gt_data )
159136
160137 def _validate_jobs (self ):
161138 manifest = self ._require_field (self .manifest )
@@ -272,7 +249,7 @@ def _prepare_merged_dataset(self):
272249 tempdir = self ._require_field (self ._temp_dir )
273250 manifest = self ._require_field (self .manifest )
274251 merged_annotations = self ._require_field (self ._merged_annotations )
275- gt_dataset = self ._require_field (self ._gt_dataset )
252+ input_gt_dataset = self ._require_field (self ._input_gt_dataset )
276253
277254 merged_dataset_path = tempdir / "merged"
278255 merged_dataset_format = DM_DATASET_FORMAT_MAPPING [manifest .annotation .type ]
@@ -281,7 +258,7 @@ def _prepare_merged_dataset(self):
281258 merged_dataset = dm .Dataset .import_from (
282259 os .fspath (merged_dataset_path ), format = merged_dataset_format
283260 )
284- self ._put_gt_into_merged_dataset (gt_dataset , merged_dataset , manifest = manifest )
261+ self ._put_gt_into_merged_dataset (input_gt_dataset , merged_dataset , manifest = manifest )
285262 self ._restore_original_image_paths (merged_dataset )
286263
287264 updated_merged_dataset_path = tempdir / "merged_updated"
@@ -297,60 +274,65 @@ def _prepare_merged_dataset(self):
297274
298275 @classmethod
299276 def _put_gt_into_merged_dataset (
300- cls , gt_dataset : dm .Dataset , merged_dataset : dm .Dataset , * , manifest : TaskManifest
277+ cls , input_gt_dataset : dm .Dataset , merged_dataset : dm .Dataset , * , manifest : TaskManifest
301278 ) -> None :
302279 """
303280 Updates the merged dataset inplace, writing GT annotations corresponding to the task type.
304281 """
305282
306283 match manifest .annotation .type :
307284 case TaskTypes .image_boxes .value :
308- merged_dataset .update (gt_dataset )
285+ merged_dataset .update (input_gt_dataset )
309286 case TaskTypes .image_points .value :
310287 merged_label_cat : dm .LabelCategories = merged_dataset .categories ()[
311288 dm .AnnotationType .label
312289 ]
290+
291+ # we support no more than 1 label so far
292+ assert len (manifest .annotation .labels ) == 1
293+
313294 skeleton_label_id = next (
314295 i for i , label in enumerate (merged_label_cat ) if not label .parent
315296 )
316297 point_label_id = next (i for i , label in enumerate (merged_label_cat ) if label .parent )
317298
318- for sample in gt_dataset :
299+ for sample in input_gt_dataset :
319300 annotations = [
320301 dm .Skeleton (
321302 elements = [
303+ # Put a point in the center of each GT bbox
304+ # Not ideal, but it's the target for now
322305 dm .Points (
323- point . points ,
306+ [ bbox . x + bbox . w / 2 , bbox . y + bbox . h / 2 ] ,
324307 label = point_label_id ,
325- attributes = point .attributes ,
308+ attributes = bbox .attributes ,
326309 )
327310 ],
328311 label = skeleton_label_id ,
329312 )
330- for point in flatten_points (
331- [p for p in sample .annotations if isinstance (p , dm .Points )]
332- )
313+ for bbox in sample .annotations
314+ if isinstance (bbox , dm .Bbox )
333315 ]
334316 merged_dataset .put (sample .wrap (annotations = annotations ))
335317 case TaskTypes .image_label_binary .value :
336- merged_dataset .update (gt_dataset )
318+ merged_dataset .update (input_gt_dataset )
337319 case TaskTypes .image_boxes_from_points :
338- merged_dataset .update (gt_dataset )
320+ merged_dataset .update (input_gt_dataset )
339321 case TaskTypes .image_skeletons_from_boxes :
340- # The original behavior is broken for skeletons
341- gt_dataset = dm .Dataset (gt_dataset )
342- gt_dataset = gt_dataset .transform (
322+ # The original behavior of project_labels is broken for skeletons
323+ input_gt_dataset = dm .Dataset (input_gt_dataset )
324+ input_gt_dataset = input_gt_dataset .transform (
343325 ProjectLabels , dst_labels = merged_dataset .categories ()[dm .AnnotationType .label ]
344326 )
345- merged_dataset .update (gt_dataset )
327+ merged_dataset .update (input_gt_dataset )
346328 case _:
347329 raise AssertionError (f"Unknown task type { manifest .annotation .type } " )
348330
349331 def validate (self ) -> _ValidationResult :
350332 with TemporaryDirectory () as tempdir :
351333 self ._temp_dir = Path (tempdir )
352334
353- self ._parse_gt ()
335+ self ._load_gt_dataset ()
354336 self ._validate_jobs ()
355337 self ._prepare_merged_dataset ()
356338
0 commit comments