2020import datumaro as dm
2121import numpy as np
2222from datumaro .util import filter_dict , take_by
23- from datumaro .util .annotation_util import BboxCoords , bbox_iou
23+ from datumaro .util .annotation_util import BboxCoords , bbox_iou , find_instances
2424from datumaro .util .image import IMAGE_EXTENSIONS , decode_image , encode_image
2525
2626import src .core .tasks .boxes_from_points as boxes_from_points_task
@@ -1709,13 +1709,18 @@ def __init__(self, manifest: TaskManifest, escrow_address: str, chain_id: int) -
17091709 )
17101710 "Minimum absolute ROI size, (w, h)"
17111711
1712- self .boxes_format = "coco_instances "
1712+ self .boxes_format = "coco_person_keypoints "
17131713
17141714 self .embed_bbox_in_roi_image = True
17151715 "Put a bbox into the extracted skeleton RoI images"
17161716
17171717 self .embed_tile_border = True
17181718
1719+ self .embedded_point_radius = 15
1720+ self .min_embedded_point_radius_percent = 0.005
1721+ self .max_embedded_point_radius_percent = 0.01
1722+ self .embedded_point_color = (0 , 255 , 255 )
1723+
17191724 self .roi_embedded_bbox_color = (0 , 255 , 255 ) # BGR
17201725 self .roi_background_color = (245 , 240 , 242 ) # BGR - CVAT background color
17211726
@@ -1729,6 +1734,9 @@ def __init__(self, manifest: TaskManifest, escrow_address: str, chain_id: int) -
17291734 GT annotations or samples for successful job launch
17301735 """
17311736
1737+ self .gt_id_attribute = "object_id"
1738+ "An additional way to match GT skeletons with input boxes"
1739+
17321740 # TODO: probably, need to also add an absolute number of minimum GT RoIs per class
17331741
17341742 def _download_input_data (self ):
@@ -1948,7 +1956,7 @@ def _validate_boxes_filenames(self):
19481956 )
19491957 )
19501958
1951- def _validate_boxes_annotations (self ):
1959+ def _validate_boxes_annotations (self ): # noqa: PLR0912
19521960 # Convert possible polygons and masks into boxes
19531961 self ._boxes_dataset .transform (InstanceSegmentsToBbox )
19541962 self ._boxes_dataset .init_cache ()
@@ -1962,22 +1970,87 @@ def _validate_boxes_annotations(self):
19621970 # Could fail on this as well
19631971 image_h , image_w = sample .media_as (dm .Image ).size
19641972
1965- sample_boxes = [a for a in sample .annotations if isinstance (a , dm .Bbox )]
1966- valid_boxes = []
1967- for bbox in sample_boxes :
1968- if not (
1969- (0 <= int (bbox .x ) < int (bbox .x + bbox .w ) <= image_w )
1970- and (0 <= int (bbox .y ) < int (bbox .y + bbox .h ) <= image_h )
1971- ):
1973+ valid_instances : list [tuple [dm .Bbox , dm .Points ]] = []
1974+ instances = find_instances (
1975+ [a for a in sample .annotations if isinstance (a , dm .Bbox | dm .Skeleton )]
1976+ )
1977+ for instance_anns in instances :
1978+ if len (instance_anns ) != 2 :
1979+ excluded_boxes_info .add_message (
1980+ "Sample '{}': object #{} ({}) skipped - unexpected group size ({})" .format (
1981+ sample .id ,
1982+ instance_anns [0 ].id ,
1983+ label_cat [instance_anns [0 ].label ].name ,
1984+ len (instance_anns ),
1985+ ),
1986+ sample_id = sample .id ,
1987+ sample_subset = sample .subset ,
1988+ )
1989+ continue
1990+
1991+ bbox = next ((a for a in instance_anns if isinstance (a , dm .Bbox )), None )
1992+ if not bbox :
1993+ excluded_boxes_info .add_message (
1994+ "Sample '{}': object #{} ({}) skipped - no matching bbox" .format (
1995+ sample .id , instance_anns [0 ].id , label_cat [instance_anns [0 ].label ].name
1996+ ),
1997+ sample_id = sample .id ,
1998+ sample_subset = sample .subset ,
1999+ )
2000+ continue
2001+
2002+ skeleton = next ((a for a in instance_anns if isinstance (a , dm .Skeleton )), None )
2003+ if not skeleton :
2004+ excluded_boxes_info .add_message (
2005+ "Sample '{}': object #{} ({}) skipped - no matching skeleton" .format (
2006+ sample .id , instance_anns [0 ].id , label_cat [instance_anns [0 ].label ].name
2007+ ),
2008+ sample_id = sample .id ,
2009+ sample_subset = sample .subset ,
2010+ )
2011+ continue
2012+
2013+ if len (skeleton .elements ) != 1 or len (skeleton .elements [0 ].points ) != 2 :
2014+ excluded_boxes_info .add_message (
2015+ "Sample '{}': object #{} ({}) skipped - invalid skeleton points" .format (
2016+ sample .id , skeleton .id , label_cat [skeleton .label ].name
2017+ ),
2018+ sample_id = sample .id ,
2019+ sample_subset = sample .subset ,
2020+ )
2021+ continue
2022+
2023+ point = skeleton .elements [0 ]
2024+ if not is_point_in_bbox (point .points [0 ], point .points [1 ], (0 , 0 , image_w , image_h )):
19722025 excluded_boxes_info .add_message (
1973- "Sample '{}': bbox #{} ({}) skipped - invalid coordinates" .format (
2026+ "Sample '{}': object #{} ({}) skipped - invalid point coordinates" .format (
2027+ sample .id , skeleton .id , label_cat [skeleton .label ].name
2028+ ),
2029+ sample_id = sample .id ,
2030+ sample_subset = sample .subset ,
2031+ )
2032+ continue
2033+
2034+ if not is_point_in_bbox (int (bbox .x ), int (bbox .y ), (0 , 0 , image_w , image_h )):
2035+ excluded_boxes_info .add_message (
2036+ "Sample '{}': object #{} ({}) skipped - invalid bbox coordinates" .format (
19742037 sample .id , bbox .id , label_cat [bbox .label ].name
19752038 ),
19762039 sample_id = sample .id ,
19772040 sample_subset = sample .subset ,
19782041 )
19792042 continue
19802043
2044+ if not is_point_in_bbox (point .points [0 ], point .points [1 ], bbox ):
2045+ excluded_boxes_info .add_message (
2046+ "Sample '{}': object #{} ({}) skipped - point is outside the bbox" .format (
2047+ sample .id , skeleton .id , label_cat [skeleton .label ].name
2048+ ),
2049+ sample_id = sample .id ,
2050+ sample_subset = sample .subset ,
2051+ )
2052+ continue
2053+
19812054 if bbox .id in visited_ids :
19822055 excluded_boxes_info .add_message (
19832056 "Sample '{}': bbox #{} ({}) skipped - repeated annotation id {}" .format (
@@ -1988,14 +2061,18 @@ def _validate_boxes_annotations(self):
19882061 )
19892062 continue
19902063
1991- valid_boxes .append (bbox )
2064+ valid_instances .append (
2065+ (bbox , point .wrap (group = bbox .group , id = bbox .id , attributes = bbox .attributes ))
2066+ )
19922067 visited_ids .add (bbox .id )
19932068
1994- excluded_boxes_info .excluded_count += len (sample_boxes ) - len (valid_boxes )
1995- excluded_boxes_info .total_count += len (sample_boxes )
2069+ excluded_boxes_info .excluded_count += len (instances ) - len (valid_instances )
2070+ excluded_boxes_info .total_count += len (instances )
19962071
1997- if len (valid_boxes ) != len (sample .annotations ):
1998- self ._boxes_dataset .put (sample .wrap (annotations = valid_boxes ))
2072+ if len (valid_instances ) != len (sample .annotations ):
2073+ self ._boxes_dataset .put (
2074+ sample .wrap (annotations = list (chain .from_iterable (valid_instances )))
2075+ )
19992076
20002077 if excluded_boxes_info .excluded_count > ceil (
20012078 excluded_boxes_info .total_count * self .max_discarded_threshold
@@ -2066,8 +2143,14 @@ def _find_unambiguous_matches(
20662143 input_boxes : list [dm .Bbox ],
20672144 gt_skeletons : list [dm .Skeleton ],
20682145 * ,
2146+ input_points : list [dm .Points ],
20692147 gt_annotations : list [dm .Annotation ],
20702148 ) -> list [tuple [dm .Bbox , dm .Skeleton ]]:
2149+ bbox_point_mapping : dict [int , dm .Points ] = {
2150+ bbox .id : next (p for p in input_points if p .group == bbox .group )
2151+ for bbox in input_boxes
2152+ }
2153+
20712154 matches = [
20722155 [
20732156 (input_bbox .label == gt_skeleton .label )
@@ -2077,6 +2160,18 @@ def _find_unambiguous_matches(
20772160 self ._get_skeleton_bbox (gt_skeleton , gt_annotations ),
20782161 )
20792162 )
2163+ and (input_point := bbox_point_mapping [input_bbox .id ])
2164+ and is_point_in_bbox (
2165+ input_point .points [0 ],
2166+ input_point .points [1 ],
2167+ self ._get_skeleton_bbox (gt_skeleton , gt_annotations ),
2168+ )
2169+ and (
2170+ # a way to customize matching if the default method is too rough
2171+ not (bbox_id := input_bbox .attributes .get (self .gt_id_attribute ))
2172+ or not (skeleton_id := gt_skeleton .attributes .get (self .gt_id_attribute ))
2173+ or bbox_id == skeleton_id
2174+ )
20802175 for gt_skeleton in gt_skeletons
20812176 ]
20822177 for input_bbox in input_boxes
@@ -2167,10 +2262,11 @@ def _find_good_gt_skeletons(
21672262 input_boxes : list [dm .Bbox ],
21682263 gt_skeletons : list [dm .Skeleton ],
21692264 * ,
2265+ input_points : list [dm .Points ],
21702266 gt_annotations : list [dm .Annotation ],
21712267 ) -> list [dm .Skeleton ]:
21722268 matches = _find_unambiguous_matches (
2173- input_boxes , gt_skeletons , gt_annotations = gt_annotations
2269+ input_boxes , gt_skeletons , input_points = input_points , gt_annotations = gt_annotations
21742270 )
21752271
21762272 matched_skeletons = []
@@ -2221,13 +2317,18 @@ def _find_good_gt_skeletons(
22212317
22222318 gt_skeletons = [a for a in gt_sample .annotations if isinstance (a , dm .Skeleton )]
22232319 input_boxes = [a for a in boxes_sample .annotations if isinstance (a , dm .Bbox )]
2320+ input_points = [a for a in boxes_sample .annotations if isinstance (a , dm .Points )]
2321+ assert len (input_boxes ) == len (input_points )
22242322
22252323 # Samples without boxes are allowed, so we just skip them without an error
22262324 if not gt_skeletons :
22272325 continue
22282326
22292327 matched_skeletons = _find_good_gt_skeletons (
2230- input_boxes , gt_skeletons , gt_annotations = gt_sample .annotations
2328+ input_boxes ,
2329+ gt_skeletons ,
2330+ input_points = input_points ,
2331+ gt_annotations = gt_sample .annotations ,
22312332 )
22322333 if not matched_skeletons :
22332334 continue
@@ -2294,9 +2395,10 @@ def _prepare_roi_infos(self):
22942395
22952396 rois : list [skeletons_from_boxes_task .RoiInfo ] = []
22962397 for sample in self ._boxes_dataset :
2297- for bbox in sample .annotations :
2298- if not isinstance (bbox , dm .Bbox ):
2299- continue
2398+ instances = find_instances (sample .annotations )
2399+ for instance_anns in instances :
2400+ bbox = next (a for a in instance_anns if isinstance (a , dm .Bbox ))
2401+ point = next (a for a in instance_anns if isinstance (a , dm .Points ))
23002402
23012403 # RoI is centered on bbox center
23022404 original_bbox_cx = int (bbox .x + bbox .w / 2 )
@@ -2320,6 +2422,8 @@ def _prepare_roi_infos(self):
23202422 bbox_label = bbox .label ,
23212423 bbox_x = new_bbox_x ,
23222424 bbox_y = new_bbox_y ,
2425+ point_x = point .points [0 ] - roi_x ,
2426+ point_y = point .points [1 ] - roi_y ,
23232427 roi_x = roi_x ,
23242428 roi_y = roi_y ,
23252429 roi_w = roi_w ,
@@ -2511,6 +2615,32 @@ def _draw_roi_bbox(self, roi_image: np.ndarray, bbox: dm.Bbox) -> np.ndarray:
25112615 cv2 .LINE_4 ,
25122616 )
25132617
2618+ def _draw_roi_point (self , roi_image : np .ndarray , point : tuple [float , float ]) -> np .ndarray :
2619+ roi_r = (roi_image .shape [0 ] ** 2 + roi_image .shape [1 ] ** 2 ) ** 0.5 / 2
2620+ radius = int (
2621+ min (
2622+ self .max_embedded_point_radius_percent * roi_r ,
2623+ max (self .embedded_point_radius , self .min_embedded_point_radius_percent * roi_r ),
2624+ )
2625+ )
2626+
2627+ roi_image = cv2 .circle (
2628+ roi_image ,
2629+ tuple (map (int , (point [0 ], point [1 ]))),
2630+ radius + 1 ,
2631+ (255 , 255 , 255 ),
2632+ - 1 ,
2633+ cv2 .LINE_4 ,
2634+ )
2635+ return cv2 .circle (
2636+ roi_image ,
2637+ tuple (map (int , (point [0 ], point [1 ]))),
2638+ radius ,
2639+ self .embedded_point_color ,
2640+ - 1 ,
2641+ cv2 .LINE_4 ,
2642+ )
2643+
25142644 def _extract_and_upload_rois (self ):
25152645 assert self ._roi_filenames is not _unset
25162646 assert self ._roi_infos is not _unset
@@ -2564,6 +2694,9 @@ def process_file(filename: str, image_pixels: np.ndarray):
25642694
25652695 if self .embed_bbox_in_roi_image :
25662696 roi_pixels = self ._draw_roi_bbox (roi_pixels , bbox_by_id [roi_info .bbox_id ])
2697+ roi_pixels = self ._draw_roi_point (
2698+ roi_pixels , (roi_info .point_x , roi_info .point_y )
2699+ )
25672700
25682701 filename = self ._roi_filenames [roi_info .bbox_id ]
25692702 roi_bytes = encode_image (roi_pixels , os .path .splitext (filename )[- 1 ])
0 commit comments