Skip to content

Commit eee13bc

Browse files
committed
Update exchange oracle
1 parent ce35f40 commit eee13bc

File tree

4 files changed

+467
-123
lines changed

4 files changed

+467
-123
lines changed

packages/examples/cvat/exchange-oracle/src/core/config.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,14 +167,24 @@ class CronConfig:
167167

168168

169169
class CvatConfig:
170+
# TODO: it looks odd to use cvat_ prefix in class attributes inside CvatConfig
170171
cvat_url = os.environ.get("CVAT_URL", "http://localhost:8080")
171172
cvat_admin = os.environ.get("CVAT_ADMIN", "admin")
172173
cvat_admin_pass = os.environ.get("CVAT_ADMIN_PASS", "admin")
173174
cvat_org_slug = os.environ.get("CVAT_ORG_SLUG", "")
174175

175176
cvat_job_overlap = int(os.environ.get("CVAT_JOB_OVERLAP", 0))
176-
cvat_job_segment_size = int(os.environ.get("CVAT_JOB_SEGMENT_SIZE", 150))
177+
cvat_task_segment_size = int(os.environ.get("CVAT_TASK_SEGMENT_SIZE", 150))
177178
cvat_default_image_quality = int(os.environ.get("CVAT_DEFAULT_IMAGE_QUALITY", 70))
179+
cvat_max_jobs_per_task = int(os.environ.get("CVAT_MAX_JOBS_PER_TASK", 10 * 1000))
180+
181+
# quality control settings
182+
cvat_val_frames_per_job_count = int(os.environ.get("CVAT_VAL_FRAMES_PER_JOB_COUNT", 2))
183+
cvat_max_validation_checks = int(os.environ.get("CVAT_MAX_VALIDATION_CHECKS", 3))
184+
cvat_iou_threshold = float(os.environ.get("CVAT_IOU_THRESHOLD", 0.5))
185+
cvat_low_overlap_threshold = float(os.environ.get("CVAT_LOW_OVERLAP_THRESHOLD", 0.8))
186+
cvat_target_metric_threshold = cvat_low_overlap_threshold
187+
cvat_oks_sigma = float(os.environ.get("CVAT_OKS_SIGMA", 0.1))
178188

179189
cvat_incoming_webhooks_url = os.environ.get("CVAT_INCOMING_WEBHOOKS_URL")
180190
cvat_webhook_secret = os.environ.get("CVAT_WEBHOOK_SECRET", "thisisasamplesecret")
@@ -223,6 +233,9 @@ class FeaturesConfig:
223233
default_export_timeout = int(os.environ.get("DEFAULT_EXPORT_TIMEOUT", 60))
224234
"Timeout, in seconds, for annotations or dataset export waiting"
225235

236+
default_import_timeout = int(os.environ.get("DEFAULT_IMPORT_TIMEOUT", 60))
237+
"Timeout, in seconds, for waiting on GT annotations import"
238+
226239
request_logging_enabled = to_bool(os.getenv("REQUEST_LOGGING_ENABLED", "0"))
227240
"Allow to log request details for each request"
228241

packages/examples/cvat/exchange-oracle/src/core/tasks/skeletons_from_boxes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
@frozen(kw_only=True)
1919
class RoiInfo:
2020
original_image_key: int
21+
original_image_id: str
2122
bbox_id: int
2223
bbox_x: int
2324
bbox_y: int

packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py

Lines changed: 164 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010
from http import HTTPStatus
1111
from io import BytesIO
1212
from time import sleep
13-
from typing import Any
13+
from typing import TYPE_CHECKING, Any
1414

15+
from cvat_sdk import Client, make_client
1516
from cvat_sdk.api_client import ApiClient, Configuration, exceptions, models
1617
from cvat_sdk.api_client.api_client import Endpoint
1718
from cvat_sdk.core.helpers import get_paginated_collection
@@ -20,6 +21,9 @@
2021
from src.utils.enums import BetterEnumMeta
2122
from src.utils.time import utcnow
2223

24+
if TYPE_CHECKING:
25+
from cvat_sdk.core.proxies.jobs import Job
26+
2327
_NOTSET = object()
2428

2529

@@ -122,6 +126,16 @@ def get_api_client() -> ApiClient:
122126
return api_client
123127

124128

129+
def get_sdk_client() -> Client:
130+
client = make_client(
131+
host=Config.cvat_config.cvat_url,
132+
credentials=(Config.cvat_config.cvat_admin, Config.cvat_config.cvat_admin_pass),
133+
)
134+
client.organization_slug = Config.cvat_config.cvat_org_slug
135+
136+
return client
137+
138+
125139
def create_cloudstorage(
126140
provider: str,
127141
bucket_name: str,
@@ -297,14 +311,19 @@ def create_cvat_webhook(project_id: int) -> models.WebhookRead:
297311
raise
298312

299313

300-
def create_task(project_id: int, name: str) -> models.TaskRead:
314+
def create_task(
315+
project_id: int,
316+
name: str,
317+
*,
318+
segment_size: int = Config.cvat_config.cvat_task_segment_size,
319+
) -> models.TaskRead:
301320
logger = logging.getLogger("app")
302321
with get_api_client() as api_client:
303322
task_write_request = models.TaskWriteRequest(
304323
name=name,
305324
project_id=project_id,
306325
overlap=0,
307-
segment_size=Config.cvat_config.cvat_job_segment_size,
326+
segment_size=segment_size,
308327
)
309328
try:
310329
(task_info, response) = api_client.tasks_api.create(task_write_request)
@@ -335,8 +354,14 @@ def put_task_data(
335354
*,
336355
filenames: list[str] | None = None,
337356
sort_images: bool = True,
357+
validation_params: dict[str, str | float | list[str]] | None = None,
338358
) -> None:
339359
logger = logging.getLogger("app")
360+
sorting_method = (
361+
models.SortingMethod("lexicographical")
362+
if sort_images
363+
else models.SortingMethod("predefined")
364+
)
340365

341366
with get_api_client() as api_client:
342367
kwargs = {}
@@ -345,21 +370,42 @@ def put_task_data(
345370
else:
346371
kwargs["filename_pattern"] = "*"
347372

373+
if validation_params:
374+
logger.info(
375+
f"The {sorting_method} is ignored."
376+
'Only "random" sorting can be used when validation parameters passed.'
377+
)
378+
sorting_method = models.SortingMethod("random")
379+
380+
gt_filenames = validation_params["gt_filenames"]
381+
if missed_filenames := set(gt_filenames) - set(filenames):
382+
filenames.extend(missed_filenames)
383+
384+
kwargs["validation_params"] = models.DataRequestValidationParams(
385+
mode=models.ValidationMode("gt_pool"),
386+
frames=gt_filenames,
387+
frame_selection_method=models.FrameSelectionMethod("manual"),
388+
frames_per_job_count=validation_params.get(
389+
"gt_frames_per_job_count",
390+
Config.cvat_config.cvat_val_frames_per_job_count,
391+
),
392+
)
393+
348394
data_request = models.DataRequest(
349-
chunk_size=Config.cvat_config.cvat_job_segment_size,
395+
chunk_size=Config.cvat_config.cvat_task_segment_size,
350396
cloud_storage_id=cloudstorage_id,
351397
image_quality=Config.cvat_config.cvat_default_image_quality,
352398
use_cache=True,
353399
use_zip_chunks=True,
354-
sorting_method="lexicographical" if sort_images else "predefined",
400+
sorting_method=sorting_method,
355401
**kwargs,
356402
)
357403
try:
358404
(_, response) = api_client.tasks_api.create_data(task_id, data_request=data_request)
359405
return
360406

361407
except exceptions.ApiException as e:
362-
logger.exception(f"Exception when calling ProjectsApi.put_task_data: {e}\n")
408+
logger.exception(f"Exception when calling tasks_api.create_data: {e}\n")
363409
raise
364410

365411

@@ -563,36 +609,138 @@ def clear_job_annotations(job_id: int) -> None:
563609
raise
564610

565611

566-
def update_job_assignee(id: str, assignee_id: int | None):
612+
def setup_gt_job(task_id: int, filename: str, format_name: str) -> None:
613+
gt_job = get_gt_job(task_id)
614+
upload_gt_annotations(gt_job.id, filename, format_name)
615+
finish_gt_job(gt_job.id)
616+
settings = get_quality_control_settings(task_id)
617+
update_quality_control_settings(settings.id)
618+
619+
620+
def get_gt_job(task_id: int) -> models.JobRead:
621+
logger = logging.getLogger("app")
622+
623+
with get_api_client() as api_client:
624+
try:
625+
(paginated_jobs, _) = api_client.jobs_api.list(task_id=task_id, type="ground_truth")
626+
assert (
627+
len(paginated_jobs["results"]) == 1
628+
), f'CVAT returned {len(paginated_jobs["results"])} GT jobs'
629+
return paginated_jobs["results"][0]
630+
except (exceptions.ApiException, AssertionError) as ex:
631+
logger.exception(f"Exception when calling JobsApi.list(): {ex}\n")
632+
raise
633+
634+
635+
def upload_gt_annotations(
636+
job_id: int,
637+
filename: str,
638+
format_name: str,
639+
) -> None:
640+
with get_sdk_client() as client:
641+
job: Job = client.jobs.retrieve(job_id)
642+
job.import_annotations(format_name=format_name, filename=filename)
643+
644+
645+
def get_quality_control_settings(task_id: int) -> models.QualitySettings:
567646
logger = logging.getLogger("app")
568647

569648
with get_api_client() as api_client:
570649
try:
571-
api_client.jobs_api.partial_update(
572-
id=id,
573-
patched_job_write_request=models.PatchedJobWriteRequest(assignee=assignee_id),
650+
paginated_data, _ = api_client.quality_api.list_settings(task_id=task_id)
651+
assert len(paginated_data["results"]) == 1, (
652+
f'CVAT returned {len(paginated_data["results"])}'
653+
"quality control settings associated with the task"
654+
)
655+
return paginated_data["results"][0]
656+
657+
except (exceptions.ApiException, AssertionError) as e:
658+
logger.exception(f"Exception when calling QualityApi.list_settings(): {e}\n")
659+
raise
660+
661+
662+
def update_quality_control_settings(
663+
settings_id: int,
664+
*,
665+
max_validations_per_job: int,
666+
target_metric: str = "accuracy",
667+
target_metric_threshold: float = Config.cvat_config.cvat_target_metric_threshold,
668+
low_overlap_threshold: float = Config.cvat_config.cvat_low_overlap_threshold,
669+
iou_threshold: float = Config.cvat_config.cvat_iou_threshold,
670+
oks_sigma: float = Config.cvat_config.cvat_oks_sigma,
671+
) -> None:
672+
logger = logging.getLogger("app")
673+
674+
with get_api_client() as api_client:
675+
try:
676+
api_client.quality_api.partial_update_settings(
677+
settings_id,
678+
patched_quality_settings_request=models.PatchedQualitySettingsRequest(
679+
max_validations_per_job=max_validations_per_job,
680+
target_metric=target_metric,
681+
target_metric_threshold=target_metric_threshold,
682+
iou_threshold=iou_threshold,
683+
low_overlap_threshold=low_overlap_threshold,
684+
oks_sigma=oks_sigma,
685+
),
574686
)
575687
except exceptions.ApiException as e:
576-
logger.exception(f"Exception when calling JobsApi.partial_update(): {e}\n")
688+
logger.exception(f"Exception when calling QualityApi.partial_update_settings(): {e}\n")
577689
raise
578690

579691

580-
def restart_job(id: str, *, assignee_id: int | None = None):
692+
def _update_job(
693+
job_id: int,
694+
*,
695+
assignee_id: int | None | object = _NOTSET,
696+
stage: models.JobStage | None = None,
697+
state: models.OperationStatus | None = None,
698+
) -> None:
699+
to_update = {
700+
attr: value
701+
for attr, value in {
702+
"stage": stage,
703+
"state": state,
704+
}.items()
705+
if value
706+
}
707+
708+
if assignee_id is not _NOTSET:
709+
to_update["assignee"] = assignee_id
710+
711+
assert to_update
712+
581713
logger = logging.getLogger("app")
582714

583715
with get_api_client() as api_client:
584716
try:
585717
api_client.jobs_api.partial_update(
586-
id=id,
587-
patched_job_write_request=models.PatchedJobWriteRequest(
588-
stage="annotation", state="new", assignee=assignee_id
589-
),
718+
job_id, patched_job_write_request=models.PatchedJobWriteRequest(**to_update)
590719
)
591720
except exceptions.ApiException as e:
592721
logger.exception(f"Exception when calling JobsApi.partial_update(): {e}\n")
593722
raise
594723

595724

725+
def update_job_assignee(id: int, assignee_id: int | None):
726+
_update_job(id, assignee_id=assignee_id)
727+
728+
729+
def restart_job(id: str, *, assignee_id: int | None = None):
730+
_update_job(
731+
id,
732+
stage=models.JobStage("annotation"),
733+
state=models.OperationStatus("new"),
734+
assignee_id=assignee_id,
735+
)
736+
737+
738+
def finish_gt_job(job_id: int) -> None:
739+
_update_job(
740+
job_id, stage=models.JobStage("acceptance"), state=models.OperationStatus("completed")
741+
)
742+
743+
596744
def get_user_id(user_email: str) -> int:
597745
logger = logging.getLogger("app")
598746

0 commit comments

Comments
 (0)