Skip to content

Commit 3ac3fbb

Browse files
zhiltsov-maxflopez7
authored andcommitted
[CVAT] Draw roi point along with bbox in skeleton tasks (#3356)
* Draw roi point along with bbox in skeleton tasks * Relax filtering for overlapping gt skeletons (#3358)
1 parent a298352 commit 3ac3fbb

File tree

3 files changed

+168
-25
lines changed

3 files changed

+168
-25
lines changed

packages/examples/cvat/exchange-oracle/src/core/tasks/skeletons_from_boxes.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ class RoiInfo:
2323
bbox_y: int
2424
bbox_label: int
2525

26+
point_x: int
27+
point_y: int
28+
2629
# RoI is centered on the bbox center
2730
# Coordinates can be out of image boundaries.
2831
# In this case RoI includes extra margins to be centered on bbox center
@@ -117,7 +120,10 @@ def parse_skeleton_bbox_mapping(self, skeleton_bbox_mapping_data: bytes) -> Skel
117120
return {int(k): int(v) for k, v in parse_json(skeleton_bbox_mapping_data).items()}
118121

119122
def parse_roi_info(self, rois_info_data: bytes) -> RoiInfos:
120-
return [RoiInfo(**roi_info) for roi_info in parse_json(rois_info_data)]
123+
return [
124+
RoiInfo(**{"point_x": 0, "point_y": 0, **roi_info})
125+
for roi_info in parse_json(rois_info_data)
126+
]
121127

122128
def parse_roi_filenames(self, roi_filenames_data: bytes) -> RoiFilenames:
123129
return {int(k): v for k, v in parse_json(roi_filenames_data).items()}

packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py

Lines changed: 154 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import datumaro as dm
2121
import numpy as np
2222
from datumaro.util import filter_dict, take_by
23-
from datumaro.util.annotation_util import BboxCoords, bbox_iou
23+
from datumaro.util.annotation_util import BboxCoords, bbox_iou, find_instances
2424
from datumaro.util.image import IMAGE_EXTENSIONS, decode_image, encode_image
2525

2626
import src.core.tasks.boxes_from_points as boxes_from_points_task
@@ -1709,13 +1709,18 @@ def __init__(self, manifest: TaskManifest, escrow_address: str, chain_id: int) -
17091709
)
17101710
"Minimum absolute ROI size, (w, h)"
17111711

1712-
self.boxes_format = "coco_instances"
1712+
self.boxes_format = "coco_person_keypoints"
17131713

17141714
self.embed_bbox_in_roi_image = True
17151715
"Put a bbox into the extracted skeleton RoI images"
17161716

17171717
self.embed_tile_border = True
17181718

1719+
self.embedded_point_radius = 15
1720+
self.min_embedded_point_radius_percent = 0.005
1721+
self.max_embedded_point_radius_percent = 0.01
1722+
self.embedded_point_color = (0, 255, 255)
1723+
17191724
self.roi_embedded_bbox_color = (0, 255, 255) # BGR
17201725
self.roi_background_color = (245, 240, 242) # BGR - CVAT background color
17211726

@@ -1729,6 +1734,9 @@ def __init__(self, manifest: TaskManifest, escrow_address: str, chain_id: int) -
17291734
GT annotations or samples for successful job launch
17301735
"""
17311736

1737+
self.gt_id_attribute = "object_id"
1738+
"An additional way to match GT skeletons with input boxes"
1739+
17321740
# TODO: probably, need to also add an absolute number of minimum GT RoIs per class
17331741

17341742
def _download_input_data(self):
@@ -1948,7 +1956,7 @@ def _validate_boxes_filenames(self):
19481956
)
19491957
)
19501958

1951-
def _validate_boxes_annotations(self):
1959+
def _validate_boxes_annotations(self): # noqa: PLR0912
19521960
# Convert possible polygons and masks into boxes
19531961
self._boxes_dataset.transform(InstanceSegmentsToBbox)
19541962
self._boxes_dataset.init_cache()
@@ -1962,22 +1970,87 @@ def _validate_boxes_annotations(self):
19621970
# Could fail on this as well
19631971
image_h, image_w = sample.media_as(dm.Image).size
19641972

1965-
sample_boxes = [a for a in sample.annotations if isinstance(a, dm.Bbox)]
1966-
valid_boxes = []
1967-
for bbox in sample_boxes:
1968-
if not (
1969-
(0 <= int(bbox.x) < int(bbox.x + bbox.w) <= image_w)
1970-
and (0 <= int(bbox.y) < int(bbox.y + bbox.h) <= image_h)
1971-
):
1973+
valid_instances: list[tuple[dm.Bbox, dm.Points]] = []
1974+
instances = find_instances(
1975+
[a for a in sample.annotations if isinstance(a, dm.Bbox | dm.Skeleton)]
1976+
)
1977+
for instance_anns in instances:
1978+
if len(instance_anns) != 2:
1979+
excluded_boxes_info.add_message(
1980+
"Sample '{}': object #{} ({}) skipped - unexpected group size ({})".format(
1981+
sample.id,
1982+
instance_anns[0].id,
1983+
label_cat[instance_anns[0].label].name,
1984+
len(instance_anns),
1985+
),
1986+
sample_id=sample.id,
1987+
sample_subset=sample.subset,
1988+
)
1989+
continue
1990+
1991+
bbox = next((a for a in instance_anns if isinstance(a, dm.Bbox)), None)
1992+
if not bbox:
1993+
excluded_boxes_info.add_message(
1994+
"Sample '{}': object #{} ({}) skipped - no matching bbox".format(
1995+
sample.id, instance_anns[0].id, label_cat[instance_anns[0].label].name
1996+
),
1997+
sample_id=sample.id,
1998+
sample_subset=sample.subset,
1999+
)
2000+
continue
2001+
2002+
skeleton = next((a for a in instance_anns if isinstance(a, dm.Skeleton)), None)
2003+
if not skeleton:
2004+
excluded_boxes_info.add_message(
2005+
"Sample '{}': object #{} ({}) skipped - no matching skeleton".format(
2006+
sample.id, instance_anns[0].id, label_cat[instance_anns[0].label].name
2007+
),
2008+
sample_id=sample.id,
2009+
sample_subset=sample.subset,
2010+
)
2011+
continue
2012+
2013+
if len(skeleton.elements) != 1 or len(skeleton.elements[0].points) != 2:
2014+
excluded_boxes_info.add_message(
2015+
"Sample '{}': object #{} ({}) skipped - invalid skeleton points".format(
2016+
sample.id, skeleton.id, label_cat[skeleton.label].name
2017+
),
2018+
sample_id=sample.id,
2019+
sample_subset=sample.subset,
2020+
)
2021+
continue
2022+
2023+
point = skeleton.elements[0]
2024+
if not is_point_in_bbox(point.points[0], point.points[1], (0, 0, image_w, image_h)):
19722025
excluded_boxes_info.add_message(
1973-
"Sample '{}': bbox #{} ({}) skipped - invalid coordinates".format(
2026+
"Sample '{}': object #{} ({}) skipped - invalid point coordinates".format(
2027+
sample.id, skeleton.id, label_cat[skeleton.label].name
2028+
),
2029+
sample_id=sample.id,
2030+
sample_subset=sample.subset,
2031+
)
2032+
continue
2033+
2034+
if not is_point_in_bbox(int(bbox.x), int(bbox.y), (0, 0, image_w, image_h)):
2035+
excluded_boxes_info.add_message(
2036+
"Sample '{}': object #{} ({}) skipped - invalid bbox coordinates".format(
19742037
sample.id, bbox.id, label_cat[bbox.label].name
19752038
),
19762039
sample_id=sample.id,
19772040
sample_subset=sample.subset,
19782041
)
19792042
continue
19802043

2044+
if not is_point_in_bbox(point.points[0], point.points[1], bbox):
2045+
excluded_boxes_info.add_message(
2046+
"Sample '{}': object #{} ({}) skipped - point is outside the bbox".format(
2047+
sample.id, skeleton.id, label_cat[skeleton.label].name
2048+
),
2049+
sample_id=sample.id,
2050+
sample_subset=sample.subset,
2051+
)
2052+
continue
2053+
19812054
if bbox.id in visited_ids:
19822055
excluded_boxes_info.add_message(
19832056
"Sample '{}': bbox #{} ({}) skipped - repeated annotation id {}".format(
@@ -1988,14 +2061,18 @@ def _validate_boxes_annotations(self):
19882061
)
19892062
continue
19902063

1991-
valid_boxes.append(bbox)
2064+
valid_instances.append(
2065+
(bbox, point.wrap(group=bbox.group, id=bbox.id, attributes=bbox.attributes))
2066+
)
19922067
visited_ids.add(bbox.id)
19932068

1994-
excluded_boxes_info.excluded_count += len(sample_boxes) - len(valid_boxes)
1995-
excluded_boxes_info.total_count += len(sample_boxes)
2069+
excluded_boxes_info.excluded_count += len(instances) - len(valid_instances)
2070+
excluded_boxes_info.total_count += len(instances)
19962071

1997-
if len(valid_boxes) != len(sample.annotations):
1998-
self._boxes_dataset.put(sample.wrap(annotations=valid_boxes))
2072+
if len(valid_instances) != len(sample.annotations):
2073+
self._boxes_dataset.put(
2074+
sample.wrap(annotations=list(chain.from_iterable(valid_instances)))
2075+
)
19992076

20002077
if excluded_boxes_info.excluded_count > ceil(
20012078
excluded_boxes_info.total_count * self.max_discarded_threshold
@@ -2066,8 +2143,14 @@ def _find_unambiguous_matches(
20662143
input_boxes: list[dm.Bbox],
20672144
gt_skeletons: list[dm.Skeleton],
20682145
*,
2146+
input_points: list[dm.Points],
20692147
gt_annotations: list[dm.Annotation],
20702148
) -> list[tuple[dm.Bbox, dm.Skeleton]]:
2149+
bbox_point_mapping: dict[int, dm.Points] = {
2150+
bbox.id: next(p for p in input_points if p.group == bbox.group)
2151+
for bbox in input_boxes
2152+
}
2153+
20712154
matches = [
20722155
[
20732156
(input_bbox.label == gt_skeleton.label)
@@ -2077,6 +2160,18 @@ def _find_unambiguous_matches(
20772160
self._get_skeleton_bbox(gt_skeleton, gt_annotations),
20782161
)
20792162
)
2163+
and (input_point := bbox_point_mapping[input_bbox.id])
2164+
and is_point_in_bbox(
2165+
input_point.points[0],
2166+
input_point.points[1],
2167+
self._get_skeleton_bbox(gt_skeleton, gt_annotations),
2168+
)
2169+
and (
2170+
# a way to customize matching if the default method is too rough
2171+
not (bbox_id := input_bbox.attributes.get(self.gt_id_attribute))
2172+
or not (skeleton_id := gt_skeleton.attributes.get(self.gt_id_attribute))
2173+
or bbox_id == skeleton_id
2174+
)
20802175
for gt_skeleton in gt_skeletons
20812176
]
20822177
for input_bbox in input_boxes
@@ -2167,10 +2262,11 @@ def _find_good_gt_skeletons(
21672262
input_boxes: list[dm.Bbox],
21682263
gt_skeletons: list[dm.Skeleton],
21692264
*,
2265+
input_points: list[dm.Points],
21702266
gt_annotations: list[dm.Annotation],
21712267
) -> list[dm.Skeleton]:
21722268
matches = _find_unambiguous_matches(
2173-
input_boxes, gt_skeletons, gt_annotations=gt_annotations
2269+
input_boxes, gt_skeletons, input_points=input_points, gt_annotations=gt_annotations
21742270
)
21752271

21762272
matched_skeletons = []
@@ -2221,13 +2317,18 @@ def _find_good_gt_skeletons(
22212317

22222318
gt_skeletons = [a for a in gt_sample.annotations if isinstance(a, dm.Skeleton)]
22232319
input_boxes = [a for a in boxes_sample.annotations if isinstance(a, dm.Bbox)]
2320+
input_points = [a for a in boxes_sample.annotations if isinstance(a, dm.Points)]
2321+
assert len(input_boxes) == len(input_points)
22242322

22252323
# Samples without boxes are allowed, so we just skip them without an error
22262324
if not gt_skeletons:
22272325
continue
22282326

22292327
matched_skeletons = _find_good_gt_skeletons(
2230-
input_boxes, gt_skeletons, gt_annotations=gt_sample.annotations
2328+
input_boxes,
2329+
gt_skeletons,
2330+
input_points=input_points,
2331+
gt_annotations=gt_sample.annotations,
22312332
)
22322333
if not matched_skeletons:
22332334
continue
@@ -2294,9 +2395,10 @@ def _prepare_roi_infos(self):
22942395

22952396
rois: list[skeletons_from_boxes_task.RoiInfo] = []
22962397
for sample in self._boxes_dataset:
2297-
for bbox in sample.annotations:
2298-
if not isinstance(bbox, dm.Bbox):
2299-
continue
2398+
instances = find_instances(sample.annotations)
2399+
for instance_anns in instances:
2400+
bbox = next(a for a in instance_anns if isinstance(a, dm.Bbox))
2401+
point = next(a for a in instance_anns if isinstance(a, dm.Points))
23002402

23012403
# RoI is centered on bbox center
23022404
original_bbox_cx = int(bbox.x + bbox.w / 2)
@@ -2320,6 +2422,8 @@ def _prepare_roi_infos(self):
23202422
bbox_label=bbox.label,
23212423
bbox_x=new_bbox_x,
23222424
bbox_y=new_bbox_y,
2425+
point_x=point.points[0] - roi_x,
2426+
point_y=point.points[1] - roi_y,
23232427
roi_x=roi_x,
23242428
roi_y=roi_y,
23252429
roi_w=roi_w,
@@ -2511,6 +2615,32 @@ def _draw_roi_bbox(self, roi_image: np.ndarray, bbox: dm.Bbox) -> np.ndarray:
25112615
cv2.LINE_4,
25122616
)
25132617

2618+
def _draw_roi_point(self, roi_image: np.ndarray, point: tuple[float, float]) -> np.ndarray:
2619+
roi_r = (roi_image.shape[0] ** 2 + roi_image.shape[1] ** 2) ** 0.5 / 2
2620+
radius = int(
2621+
min(
2622+
self.max_embedded_point_radius_percent * roi_r,
2623+
max(self.embedded_point_radius, self.min_embedded_point_radius_percent * roi_r),
2624+
)
2625+
)
2626+
2627+
roi_image = cv2.circle(
2628+
roi_image,
2629+
tuple(map(int, (point[0], point[1]))),
2630+
radius + 1,
2631+
(255, 255, 255),
2632+
-1,
2633+
cv2.LINE_4,
2634+
)
2635+
return cv2.circle(
2636+
roi_image,
2637+
tuple(map(int, (point[0], point[1]))),
2638+
radius,
2639+
self.embedded_point_color,
2640+
-1,
2641+
cv2.LINE_4,
2642+
)
2643+
25142644
def _extract_and_upload_rois(self):
25152645
assert self._roi_filenames is not _unset
25162646
assert self._roi_infos is not _unset
@@ -2564,6 +2694,9 @@ def process_file(filename: str, image_pixels: np.ndarray):
25642694

25652695
if self.embed_bbox_in_roi_image:
25662696
roi_pixels = self._draw_roi_bbox(roi_pixels, bbox_by_id[roi_info.bbox_id])
2697+
roi_pixels = self._draw_roi_point(
2698+
roi_pixels, (roi_info.point_x, roi_info.point_y)
2699+
)
25672700

25682701
filename = self._roi_filenames[roi_info.bbox_id]
25692702
roi_bytes = encode_image(roi_pixels, os.path.splitext(filename)[-1])

packages/examples/cvat/exchange-oracle/src/utils/annotations.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import datumaro as dm
99
import numpy as np
1010
from datumaro.util import filter_dict, mask_tools
11-
from datumaro.util.annotation_util import find_group_leader, find_instances, max_bbox
11+
from datumaro.util.annotation_util import BboxCoords, find_group_leader, find_instances, max_bbox
1212
from defusedxml import ElementTree
1313

1414

@@ -343,8 +343,12 @@ def transform_item(self, item):
343343
return item.wrap(annotations=annotations)
344344

345345

346-
def is_point_in_bbox(px: float, py: float, bbox: dm.Bbox) -> bool:
347-
return (bbox.x <= px <= bbox.x + bbox.w) and (bbox.y <= py <= bbox.y + bbox.h)
346+
def is_point_in_bbox(px: float, py: float, bbox: dm.Bbox | BboxCoords) -> bool:
347+
if isinstance(bbox, dm.Bbox):
348+
bbox = bbox.get_bbox()
349+
350+
x, y, w, h = bbox
351+
return (x <= px <= x + w) and (y <= py <= y + h)
348352

349353

350354
class InstanceSegmentsToBbox(dm.ItemTransform):

0 commit comments

Comments
 (0)