Skip to content

Commit d2f3b0b

Browse files
committed
Merge remote-tracking branch 'upstream/mk/update_cvat_oracles' into zm/change_point_validation
2 parents 05b5442 + f6daa82 commit d2f3b0b

File tree

10 files changed

+383
-1115
lines changed

10 files changed

+383
-1115
lines changed

packages/examples/cvat/exchange-oracle/src/core/annotation_meta.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
class JobMeta(BaseModel):
1010
job_id: int
11+
task_id: int
1112
annotation_filename: Path
1213
annotator_wallet_address: str
1314
assignment_id: str

packages/examples/cvat/exchange-oracle/src/core/config.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,9 @@ class CvatConfig:
180180
cvat_task_creation_check_interval = int(os.environ.get("CVAT_TASK_CREATION_CHECK_INTERVAL", 5))
181181

182182
# quality control settings
183-
cvat_val_frames_per_job_count = int(os.environ.get("CVAT_VAL_FRAMES_PER_JOB_COUNT", 2))
184183
cvat_max_validation_checks = int(os.environ.get("CVAT_MAX_VALIDATION_CHECKS", 3))
184+
"Maximum number of attempts to run a validation check on a job after completing annotation"
185+
185186
cvat_iou_threshold = float(os.environ.get("CVAT_IOU_THRESHOLD", 0.8))
186187
cvat_oks_sigma = float(os.environ.get("CVAT_OKS_SIGMA", 0.1))
187188

@@ -232,7 +233,7 @@ class FeaturesConfig:
232233
default_export_timeout = int(os.environ.get("DEFAULT_EXPORT_TIMEOUT", 60))
233234
"Timeout, in seconds, for annotations or dataset export waiting"
234235

235-
default_import_timeout = int(os.environ.get("DEFAULT_IMPORT_TIMEOUT", 60))
236+
default_import_timeout = int(os.environ.get("DEFAULT_IMPORT_TIMEOUT", 60 * 60))
236237
"Timeout, in seconds, for waiting on GT annotations import"
237238

238239
request_logging_enabled = to_bool(os.getenv("REQUEST_LOGGING_ENABLED", "0"))

packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@
2626
_NOTSET = object()
2727

2828

29+
class CVATException(Exception):
30+
"""Indicates that CVAT API returned unexpected response"""
31+
32+
2933
def _request_annotations(endpoint: Endpoint, cvat_id: int, format_name: str) -> bool:
3034
"""
3135
Requests annotations export.
@@ -314,7 +318,7 @@ def create_task(
314318
project_id: int,
315319
name: str,
316320
*,
317-
segment_size: int = Config.cvat_config.cvat_task_segment_size,
321+
segment_size: int,
318322
) -> models.TaskRead:
319323
logger = logging.getLogger("app")
320324
with get_api_client() as api_client:
@@ -325,7 +329,7 @@ def create_task(
325329
segment_size=segment_size,
326330
)
327331
try:
328-
(task_info, response) = api_client.tasks_api.create(task_write_request)
332+
(task_info, _) = api_client.tasks_api.create(task_write_request)
329333
return task_info
330334

331335
except exceptions.ApiException as e:
@@ -351,6 +355,7 @@ def put_task_data(
351355
task_id: int,
352356
cloudstorage_id: int,
353357
*,
358+
chunk_size: int,
354359
filenames: list[str] | None = None,
355360
sort_images: bool = True,
356361
validation_params: dict[str, str | float | list[str]] | None = None,
@@ -384,14 +389,11 @@ def put_task_data(
384389
mode=models.ValidationMode("gt_pool"),
385390
frames=gt_filenames,
386391
frame_selection_method=models.FrameSelectionMethod("manual"),
387-
frames_per_job_count=validation_params.get(
388-
"gt_frames_per_job_count",
389-
Config.cvat_config.cvat_val_frames_per_job_count,
390-
),
392+
frames_per_job_count=validation_params["gt_frames_per_job_count"],
391393
)
392394

393395
data_request = models.DataRequest(
394-
chunk_size=Config.cvat_config.cvat_task_segment_size,
396+
chunk_size=chunk_size,
395397
cloud_storage_id=cloudstorage_id,
396398
image_quality=Config.cvat_config.cvat_default_image_quality,
397399
use_cache=True,
@@ -614,9 +616,11 @@ def get_gt_job(task_id: int) -> models.JobRead:
614616
with get_api_client() as api_client:
615617
try:
616618
(paginated_jobs, _) = api_client.jobs_api.list(task_id=task_id, type="ground_truth")
617-
assert (
618-
len(paginated_jobs["results"]) == 1
619-
), f'CVAT returned {len(paginated_jobs["results"])} GT jobs'
619+
if (gt_jobs_count := len(paginated_jobs["results"])) != 1:
620+
raise CVATException(
621+
f"CVAT returned {gt_jobs_count} GT jobs for the task({task_id})"
622+
)
623+
620624
return paginated_jobs["results"][0]
621625
except (exceptions.ApiException, AssertionError) as ex:
622626
logger.exception(f"Exception when calling JobsApi.list(): {ex}\n")
@@ -631,7 +635,7 @@ def upload_gt_annotations(
631635
sleep_interval: int = 5,
632636
timeout: int | None = Config.features.default_import_timeout,
633637
) -> None:
634-
# FUTURE-TODO: use job.import_annotations when CVAT will support waiting timeout
638+
# FUTURE-TODO: use job.import_annotations when CVAT supports a waiting timeout
635639
start_time = datetime.now(timezone.utc)
636640
logger = logging.getLogger("app")
637641

@@ -653,7 +657,11 @@ def upload_gt_annotations(
653657
raise
654658

655659
request_id = json.loads(response.data).get("rq_id")
656-
assert request_id, "CVAT server have not returned rq_id in the response."
660+
if not request_id:
661+
raise CVATException(
662+
"CVAT server has not returned rq_id in the response when "
663+
f"uploading GT annotations to the {job_id} job"
664+
)
657665

658666
while True:
659667
try:
@@ -695,23 +703,24 @@ def get_quality_control_settings(task_id: int) -> models.QualitySettings:
695703
with get_api_client() as api_client:
696704
try:
697705
paginated_data, _ = api_client.quality_api.list_settings(task_id=task_id)
698-
assert len(paginated_data["results"]) == 1, (
699-
f'CVAT returned {len(paginated_data["results"])}'
700-
"quality control settings associated with the task"
701-
)
706+
if (settings_count := paginated_data["results"]) != 1:
707+
raise CVATException(
708+
f"CVAT returned {settings_count}"
709+
f"quality control settings associated with the task({task_id})"
710+
)
702711
return paginated_data["results"][0]
703712

704-
except (exceptions.ApiException, AssertionError) as e:
705-
logger.exception(f"Exception when calling QualityApi.list_settings(): {e}\n")
713+
except exceptions.ApiException as ex:
714+
logger.exception(f"Exception when calling QualityApi.list_settings(): {ex}\n")
706715
raise
707716

708717

709718
def update_quality_control_settings(
710719
settings_id: int,
711720
*,
712721
target_metric_threshold: float,
713-
max_validations_per_job: int = Config.cvat_config.cvat_max_validation_checks,
714722
target_metric: str = "accuracy",
723+
max_validations_per_job: int = Config.cvat_config.cvat_max_validation_checks,
715724
iou_threshold: float = Config.cvat_config.cvat_iou_threshold,
716725
oks_sigma: float = Config.cvat_config.cvat_oks_sigma,
717726
) -> None:

packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py

Lines changed: 36 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -211,10 +211,10 @@ def _setup_gt_job(self, task_id: int, dataset_path: Path, format_name: str) -> N
211211
cvat_api.upload_gt_annotations(gt_job.id, dataset_path, format_name=format_name)
212212
cvat_api.finish_gt_job(gt_job.id)
213213

214-
def _setup_quality_settings(self, task_id: int, *, quality_threshold: float) -> None:
214+
def _setup_quality_settings(self, task_id: int, **overrides) -> None:
215215
settings = cvat_api.get_quality_control_settings(task_id)
216216
cvat_api.update_quality_control_settings(
217-
settings.id, target_metric_threshold=quality_threshold
217+
settings.id, target_metric_threshold=self.manifest.validation.min_quality, **overrides
218218
)
219219

220220
@abstractmethod
@@ -292,7 +292,7 @@ def _get_gt_filenames(
292292

293293
return list(matched_gt_filenames)
294294

295-
def split_dataset_per_task(
295+
def _split_dataset_per_task(
296296
self,
297297
data_filenames: list[str],
298298
*,
@@ -324,7 +324,6 @@ def build(self):
324324
# Create task configuration
325325
gt_filenames = self._get_gt_filenames(gt_dataset, data_filenames, manifest=manifest)
326326
data_to_be_annotated = [f for f in data_filenames if f not in set(gt_filenames)]
327-
segment_size = manifest.annotation.job_size or Config.cvat_config.cvat_task_segment_size
328327
label_configuration = make_label_configuration(manifest)
329328

330329
self._upload_task_meta(gt_dataset)
@@ -346,7 +345,7 @@ def build(self):
346345
cvat_webhook = cvat_api.create_cvat_webhook(cvat_project.id)
347346

348347
with SessionLocal.begin() as session:
349-
segment_size = manifest.annotation.job_size or Config.cvat_config.cvat_task_segment_size
348+
segment_size = manifest.annotation.job_size
350349
total_jobs = math.ceil(len(data_to_be_annotated) / segment_size)
351350

352351
self.logger.info(
@@ -376,7 +375,7 @@ def build(self):
376375
db_service.get_project_by_id(session, project_id, for_update=True) # lock the row
377376
db_service.add_project_images(session, cvat_project.id, data_filenames)
378377

379-
for data_subset in self.split_dataset_per_task(
378+
for data_subset in self._split_dataset_per_task(
380379
data_to_be_annotated,
381380
subset_size=Config.cvat_config.cvat_max_jobs_per_task * segment_size,
382381
):
@@ -388,13 +387,14 @@ def build(self):
388387
)
389388
db_service.get_task_by_id(session, task_id, for_update=True) # lock the row
390389

391-
# Actual task creation in CVAT takes some time, so it's done in an async process.
392390
# The task is fully created once 'update:task' or 'update:job' webhook is received.
393391
cvat_api.put_task_data(
394392
cvat_task.id,
395393
cloud_storage.id,
396394
filenames=data_subset,
397395
sort_images=False,
396+
# use the same value for the chunk size as for the job size
397+
chunk_size=segment_size,
398398
validation_params={
399399
"gt_filenames": gt_filenames, # include whole GT dataset into each task
400400
"gt_frames_per_job_count": manifest.validation.val_size,
@@ -453,15 +453,12 @@ def _setup_gt_job_for_cvat_task(
453453
task_id=task_id, gt_dataset=gt_dataset, dm_export_format="datumaro"
454454
)
455455

456-
def _setup_quality_settings(self, task_id) -> None:
456+
def _setup_quality_settings(self, task_id: int, **overrides) -> None:
457457
assert self._mean_gt_bbox_radius_estimation is not _unset
458458

459-
settings = cvat_api.get_quality_control_settings(task_id)
460-
cvat_api.update_quality_control_settings(
461-
settings.id,
462-
target_metric_threshold=self.manifest.validation.min_quality,
463-
oks_sigma=self._mean_gt_bbox_radius_estimation,
464-
)
459+
values = { "oks_sigma": self._mean_gt_bbox_radius_estimation }
460+
values.update(overrides)
461+
super()._setup_quality_settings(task_id, **values)
465462

466463

467464
class BoxesFromPointsTaskBuilder(_TaskBuilderBase):
@@ -476,7 +473,7 @@ def __init__(self, manifest: TaskManifest, escrow_address: str, chain_id: int) -
476473
self._input_gt_dataset: _MaybeUnset[dm.Dataset] = _unset
477474
self._gt_dataset: _MaybeUnset[dm.Dataset] = _unset
478475
self._gt_roi_dataset: _MaybeUnset[dm.Dataset] = _unset
479-
self._gt_filenames: _MaybeUnset[dm.Dataset] = _unset
476+
self._gt_filenames: _MaybeUnset[Sequence[str]] = _unset
480477
self._points_dataset: _MaybeUnset[dm.Dataset] = _unset
481478

482479
self._bbox_point_mapping: _MaybeUnset[boxes_from_points_task.BboxPointMapping] = _unset
@@ -1372,7 +1369,7 @@ def _roi_key(e):
13721369
roi_bytes,
13731370
)
13741371

1375-
def split_dataset_per_task(
1372+
def _split_dataset_per_task(
13761373
self,
13771374
data_filenames: list[str],
13781375
*,
@@ -1435,9 +1432,7 @@ def _create_on_cvat(self):
14351432
cvat_webhook = cvat_api.create_cvat_webhook(cvat_project.id)
14361433

14371434
with SessionLocal.begin() as session:
1438-
segment_size = (
1439-
self.manifest.annotation.job_size or Config.cvat_config.cvat_task_segment_size
1440-
)
1435+
segment_size = self.manifest.annotation.job_size
14411436
total_jobs = math.ceil(len(self._data_filenames_to_be_annotated) / segment_size)
14421437
self.logger.info(
14431438
"Task creation for escrow '%s': will create %s assignments",
@@ -1475,7 +1470,7 @@ def _create_on_cvat(self):
14751470
],
14761471
)
14771472

1478-
for data_subset in self.split_dataset_per_task(
1473+
for data_subset in self._split_dataset_per_task(
14791474
self._data_filenames_to_be_annotated,
14801475
subset_size=Config.cvat_config.cvat_max_jobs_per_task * segment_size,
14811476
):
@@ -1495,13 +1490,13 @@ def _create_on_cvat(self):
14951490
for fn in self._gt_filenames
14961491
]
14971492

1498-
# FUTURE-FIXME:
1499-
# Actual task creation in CVAT takes some time, so it's done in an async process.
15001493
cvat_api.put_task_data(
15011494
cvat_task.id,
15021495
cvat_cloud_storage.id,
15031496
filenames=filenames,
15041497
sort_images=False,
1498+
# use the same value for the chunk size as for the job size
1499+
chunk_size=segment_size,
15051500
validation_params={
15061501
"gt_filenames": gt_filenames,
15071502
"gt_frames_per_job_count": self.manifest.validation.val_size,
@@ -1568,6 +1563,10 @@ def __init__(self, manifest: TaskManifest, escrow_address: str, chain_id: int) -
15681563
self._excluded_gt_info: _MaybeUnset[_ExcludedAnnotationsInfo] = _unset
15691564
self._excluded_boxes_info: _MaybeUnset[_ExcludedAnnotationsInfo] = _unset
15701565

1566+
# Configuration / constants
1567+
self.job_size_mult = skeletons_from_boxes_task.DEFAULT_ASSIGNMENT_SIZE_MULTIPLIER
1568+
"Job size multiplier"
1569+
15711570
# TODO: consider WebP if produced files are too big
15721571
self.roi_file_ext = ".png" # supposed to be lossless and reasonably compressing
15731572
"File extension for RoI images, with leading dot (.) included"
@@ -2218,6 +2217,13 @@ def _mangle_filenames(self):
22182217
roi_info.bbox_id: str(uuid.uuid4()) + self.roi_file_ext for roi_info in self._roi_infos
22192218
}
22202219

2220+
@property
2221+
def _task_segment_size(self):
2222+
# Unlike other task types, here we use a grid of RoIs,
2223+
# so the absolute job size numbers from manifest are multiplied by the job size multiplier.
2224+
# Then, we add a percent of job tiles for validation, keeping the requested ratio.
2225+
return self.manifest.annotation.job_size * self.job_size_mult
2226+
22212227
def _prepare_task_params(self):
22222228
assert self._roi_infos is not _unset
22232229
assert self._skeleton_bbox_mapping is not _unset
@@ -2233,9 +2239,8 @@ def _prepare_task_params(self):
22332239

22342240
roi_info_by_id = {roi_info.bbox_id: roi_info for roi_info in self._roi_infos}
22352241
self._roi_info_by_id = roi_info_by_id
2236-
segment_size = (
2237-
self.manifest.annotation.job_size or Config.cvat_config.cvat_task_segment_size
2238-
)
2242+
2243+
segment_size = self._task_segment_size
22392244

22402245
for label_id, _ in enumerate(self.manifest.annotation.labels):
22412246
label_gt_roi_ids = set(
@@ -2458,6 +2463,7 @@ def _task_params_label_key(ts):
24582463
label_specs_by_skeleton = {
24592464
skeleton_label_id: [
24602465
{
2466+
# why not just use skeleton node?
24612467
"name": self.point_labels[(skeleton_label.name, skeleton_point)],
24622468
"type": "points",
24632469
}
@@ -2474,9 +2480,7 @@ def _task_params_label_key(ts):
24742480
_params["bucket_host"] = "http://minio:9010"
24752481
cvat_cloud_storage = cvat_api.create_cloudstorage(**_params)
24762482

2477-
segment_size = (
2478-
self.manifest.annotation.job_size or Config.cvat_config.cvat_task_segment_size
2479-
)
2483+
segment_size = self._task_segment_size
24802484

24812485
total_jobs = sum(
24822486
len(self.manifest.annotation.labels[tp.label_id].nodes)
@@ -2620,19 +2624,19 @@ def _task_params_label_key(ts):
26202624
)
26212625
db_service.get_task_by_id(session, task_id, for_update=True) # lock the row
26222626

2623-
# FUTURE-FIXME: now we must wait for the task to be created to set up GT
2624-
# Actual task creation in CVAT takes some time,
2625-
# so it's done in an async process.
26262627
# The task is fully created once 'update:task' or 'update:job'
26272628
# webhook is received.
26282629
cvat_api.put_task_data(
26292630
cvat_task.id,
26302631
cvat_cloud_storage.id,
26312632
filenames=point_label_filenames + gt_point_label_filenames,
26322633
sort_images=False,
2634+
# use the same value for the chunk size as for the job size
2635+
chunk_size=segment_size,
26332636
validation_params={
26342637
"gt_filenames": gt_point_label_filenames,
2635-
"gt_frames_per_job_count": self.manifest.validation.val_size,
2638+
"gt_frames_per_job_count": self.manifest.validation.val_size
2639+
* self.job_size_mult,
26362640
},
26372641
)
26382642

packages/examples/cvat/exchange-oracle/src/handlers/job_export.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def prepare_annotation_metafile(
5656
annotation_filename=job_annotations[job.cvat_id].filename,
5757
annotator_wallet_address=job.latest_assignment.user_wallet_address,
5858
assignment_id=job.latest_assignment.id,
59+
task_id=job.cvat_task_id,
5960
)
6061
for job in jobs
6162
]

packages/examples/cvat/recording-oracle/src/core/annotation_meta.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
class JobMeta(BaseModel):
1010
job_id: int
11+
task_id: int
1112
annotation_filename: Path
1213
annotator_wallet_address: str
1314
assignment_id: str

0 commit comments

Comments
 (0)