Skip to content

Commit e2cd19e

Browse files
committed
Use the original GT for final annotation merging
1 parent a097c3e commit e2cd19e

File tree

1 file changed

+40
-58
lines changed

1 file changed

+40
-58
lines changed

packages/examples/cvat/recording-oracle/src/handlers/process_intermediate_results.py

Lines changed: 40 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -11,24 +11,19 @@
1111
import datumaro as dm
1212
import numpy as np
1313

14-
import src.core.tasks.boxes_from_points as boxes_from_points_task
15-
import src.core.tasks.points as points_task
16-
import src.core.tasks.simple as simple_task
17-
import src.core.tasks.skeletons_from_boxes as skeletons_from_boxes_task
1814
import src.cvat.api_calls as cvat_api
1915
import src.services.validation as db_service
2016
from src.core.annotation_meta import AnnotationMeta
2117
from src.core.config import Config
2218
from src.core.gt_stats import GtStats, ValidationFrameStats
23-
from src.core.storage import compose_data_bucket_filename
2419
from src.core.types import TaskTypes
2520
from src.core.validation_errors import DatasetValidationError, LowAccuracyError, TooFewGtError
2621
from src.core.validation_meta import JobMeta, ResultMeta, ValidationMeta
2722
from src.core.validation_results import ValidationFailure, ValidationSuccess
2823
from src.db.utils import ForUpdateParams
2924
from src.services.cloud import make_client as make_cloud_client
3025
from src.services.cloud.utils import BucketAccessInfo
31-
from src.utils.annotations import ProjectLabels, flatten_points
26+
from src.utils.annotations import ProjectLabels
3227
from src.utils.zip_archive import extract_zip_archive, write_dir_to_zip_archive
3328

3429
if TYPE_CHECKING:
@@ -108,7 +103,7 @@ def __init__(
108103
self._rejected_jobs: _RejectedJobs | None = None
109104

110105
self._temp_dir: Path | None = None
111-
self._gt_dataset: dm.Dataset | None = None
106+
self._input_gt_dataset: dm.Dataset | None = None
112107
self._meta: AnnotationMeta = meta
113108

114109
def _require_field(self, field: T | None) -> T:
@@ -118,44 +113,26 @@ def _require_field(self, field: T | None) -> T:
118113
def _gt_key_to_sample_id(self, gt_key: str) -> str:
119114
return gt_key
120115

121-
def _get_meta_layout_and_serializer(self):
122-
if self.manifest.annotation.type == TaskTypes.image_boxes:
123-
return (
124-
simple_task.TaskMetaLayout(),
125-
simple_task.TaskMetaSerializer(),
126-
)
127-
if self.manifest.annotation.type == TaskTypes.image_points:
128-
return (
129-
points_task.TaskMetaLayout(),
130-
points_task.TaskMetaSerializer(),
131-
)
132-
if self.manifest.annotation.type == TaskTypes.image_boxes_from_points:
133-
return (
134-
boxes_from_points_task.TaskMetaLayout(),
135-
boxes_from_points_task.TaskMetaSerializer(),
136-
)
137-
if self.manifest.annotation.type == TaskTypes.image_skeletons_from_boxes:
138-
return (
139-
skeletons_from_boxes_task.TaskMetaLayout(),
140-
skeletons_from_boxes_task.TaskMetaSerializer(),
116+
def _parse_gt_dataset(self, gt_file_data: bytes) -> dm.Dataset:
117+
with TemporaryDirectory() as gt_temp_dir:
118+
gt_filename = os.path.join(gt_temp_dir, "gt_annotations.json")
119+
with open(gt_filename, "wb") as f:
120+
f.write(gt_file_data)
121+
122+
gt_dataset = dm.Dataset.import_from(
123+
gt_filename,
124+
format=DM_GT_DATASET_FORMAT_MAPPING[self.manifest.annotation.type],
141125
)
142-
raise AssertionError(f"Unknown task type {self.manifest.annotation.type}")
143126

144-
def _parse_gt(self):
145-
layout, serializer = self._get_meta_layout_and_serializer()
127+
gt_dataset.init_cache()
146128

147-
exchange_oracle_data_bucket = BucketAccessInfo.parse_obj(
148-
Config.exchange_oracle_storage_config
149-
)
150-
storage_client = make_cloud_client(exchange_oracle_data_bucket)
129+
return gt_dataset
151130

152-
self._gt_dataset = serializer.parse_gt_annotations(
153-
storage_client.download_file(
154-
compose_data_bucket_filename(
155-
self.escrow_address, self.chain_id, layout.GT_FILENAME
156-
),
157-
)
158-
)
131+
def _load_gt_dataset(self):
132+
input_gt_bucket = BucketAccessInfo.parse_obj(self.manifest.validation.gt_url)
133+
gt_bucket_client = make_cloud_client(input_gt_bucket)
134+
gt_data = gt_bucket_client.download_file(input_gt_bucket.path)
135+
self._input_gt_dataset = self._parse_gt_dataset(gt_data)
159136

160137
def _validate_jobs(self):
161138
manifest = self._require_field(self.manifest)
@@ -272,7 +249,7 @@ def _prepare_merged_dataset(self):
272249
tempdir = self._require_field(self._temp_dir)
273250
manifest = self._require_field(self.manifest)
274251
merged_annotations = self._require_field(self._merged_annotations)
275-
gt_dataset = self._require_field(self._gt_dataset)
252+
input_gt_dataset = self._require_field(self._input_gt_dataset)
276253

277254
merged_dataset_path = tempdir / "merged"
278255
merged_dataset_format = DM_DATASET_FORMAT_MAPPING[manifest.annotation.type]
@@ -281,7 +258,7 @@ def _prepare_merged_dataset(self):
281258
merged_dataset = dm.Dataset.import_from(
282259
os.fspath(merged_dataset_path), format=merged_dataset_format
283260
)
284-
self._put_gt_into_merged_dataset(gt_dataset, merged_dataset, manifest=manifest)
261+
self._put_gt_into_merged_dataset(input_gt_dataset, merged_dataset, manifest=manifest)
285262
self._restore_original_image_paths(merged_dataset)
286263

287264
updated_merged_dataset_path = tempdir / "merged_updated"
@@ -297,60 +274,65 @@ def _prepare_merged_dataset(self):
297274

298275
@classmethod
299276
def _put_gt_into_merged_dataset(
300-
cls, gt_dataset: dm.Dataset, merged_dataset: dm.Dataset, *, manifest: TaskManifest
277+
cls, input_gt_dataset: dm.Dataset, merged_dataset: dm.Dataset, *, manifest: TaskManifest
301278
) -> None:
302279
"""
303280
Updates the merged dataset inplace, writing GT annotations corresponding to the task type.
304281
"""
305282

306283
match manifest.annotation.type:
307284
case TaskTypes.image_boxes.value:
308-
merged_dataset.update(gt_dataset)
285+
merged_dataset.update(input_gt_dataset)
309286
case TaskTypes.image_points.value:
310287
merged_label_cat: dm.LabelCategories = merged_dataset.categories()[
311288
dm.AnnotationType.label
312289
]
290+
291+
# we support no more than 1 label so far
292+
assert len(manifest.annotation.labels) == 1
293+
313294
skeleton_label_id = next(
314295
i for i, label in enumerate(merged_label_cat) if not label.parent
315296
)
316297
point_label_id = next(i for i, label in enumerate(merged_label_cat) if label.parent)
317298

318-
for sample in gt_dataset:
299+
for sample in input_gt_dataset:
319300
annotations = [
320301
dm.Skeleton(
321302
elements=[
303+
# Put a point in the center of each GT bbox
304+
# Not ideal, but it's the target for now
322305
dm.Points(
323-
point.points,
306+
[bbox.x + bbox.w / 2, bbox.y + bbox.h / 2],
324307
label=point_label_id,
325-
attributes=point.attributes,
308+
attributes=bbox.attributes,
326309
)
327310
],
328311
label=skeleton_label_id,
329312
)
330-
for point in flatten_points(
331-
[p for p in sample.annotations if isinstance(p, dm.Points)]
332-
)
313+
for bbox in sample.annotations
314+
if isinstance(bbox, dm.Bbox)
333315
]
334316
merged_dataset.put(sample.wrap(annotations=annotations))
335317
case TaskTypes.image_label_binary.value:
336-
merged_dataset.update(gt_dataset)
318+
merged_dataset.update(input_gt_dataset)
337319
case TaskTypes.image_boxes_from_points:
338-
merged_dataset.update(gt_dataset)
320+
merged_dataset.update(input_gt_dataset)
339321
case TaskTypes.image_skeletons_from_boxes:
340-
# The original behavior is broken for skeletons
341-
gt_dataset = dm.Dataset(gt_dataset)
342-
gt_dataset = gt_dataset.transform(
322+
# The original behavior of project_labels is broken for skeletons
323+
input_gt_dataset = dm.Dataset(input_gt_dataset)
324+
input_gt_dataset = input_gt_dataset.transform(
343325
ProjectLabels, dst_labels=merged_dataset.categories()[dm.AnnotationType.label]
344326
)
345-
merged_dataset.update(gt_dataset)
327+
merged_dataset.update(input_gt_dataset)
346328
case _:
347329
raise AssertionError(f"Unknown task type {manifest.annotation.type}")
348330

349331
def validate(self) -> _ValidationResult:
350332
with TemporaryDirectory() as tempdir:
351333
self._temp_dir = Path(tempdir)
352334

353-
self._parse_gt()
335+
self._load_gt_dataset()
354336
self._validate_jobs()
355337
self._prepare_merged_dataset()
356338

0 commit comments

Comments
 (0)