Skip to content

Commit 4fa8d23

Browse files
committed
Add wait timeout when importing GT annotations
1 parent eee13bc commit 4fa8d23

File tree

1 file changed

+66
-11
lines changed

1 file changed

+66
-11
lines changed

packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py

Lines changed: 66 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,24 @@
55
from collections.abc import Generator
66
from contextlib import contextmanager
77
from contextvars import ContextVar
8-
from datetime import timedelta
8+
from datetime import datetime, timedelta, timezone
99
from enum import Enum
1010
from http import HTTPStatus
1111
from io import BytesIO
12+
from pathlib import Path
1213
from time import sleep
13-
from typing import TYPE_CHECKING, Any
14+
from typing import Any
1415

1516
from cvat_sdk import Client, make_client
1617
from cvat_sdk.api_client import ApiClient, Configuration, exceptions, models
1718
from cvat_sdk.api_client.api_client import Endpoint
1819
from cvat_sdk.core.helpers import get_paginated_collection
20+
from cvat_sdk.core.uploading import AnnotationUploader
1921

2022
from src.core.config import Config
2123
from src.utils.enums import BetterEnumMeta
2224
from src.utils.time import utcnow
2325

24-
if TYPE_CHECKING:
25-
from cvat_sdk.core.proxies.jobs import Job
26-
2726
_NOTSET = object()
2827

2928

@@ -609,9 +608,9 @@ def clear_job_annotations(job_id: int) -> None:
609608
raise
610609

611610

612-
def setup_gt_job(task_id: int, filename: str, format_name: str) -> None:
611+
def setup_gt_job(task_id: int, dataset_path: Path, format_name: str) -> None:
613612
gt_job = get_gt_job(task_id)
614-
upload_gt_annotations(gt_job.id, filename, format_name)
613+
upload_gt_annotations(gt_job.id, dataset_path, format_name=format_name)
615614
finish_gt_job(gt_job.id)
616615
settings = get_quality_control_settings(task_id)
617616
update_quality_control_settings(settings.id)
@@ -634,12 +633,68 @@ def get_gt_job(task_id: int) -> models.JobRead:
634633

635634
def upload_gt_annotations(
636635
job_id: int,
637-
filename: str,
636+
dataset_path: Path,
637+
*,
638638
format_name: str,
639+
sleep_interval: int = 5,
640+
timeout: int | None = Config.features.default_import_timeout,
639641
) -> None:
642+
# FUTURE-TODO: use job.import_annotations when CVAT will support waiting timeout
643+
start_time = datetime.now(timezone.utc)
644+
logger = logging.getLogger("app")
645+
640646
with get_sdk_client() as client:
641-
job: Job = client.jobs.retrieve(job_id)
642-
job.import_annotations(format_name=format_name, filename=filename)
647+
uploader = AnnotationUploader(client)
648+
url = client.api_map.make_endpoint_url(
649+
client.api_client.jobs_api.create_annotations_endpoint.path, kwsub={"id": job_id}
650+
)
651+
652+
try:
653+
response = uploader.upload_file(
654+
url,
655+
dataset_path,
656+
query_params={"format": format_name, "filename": dataset_path.name},
657+
meta={"filename": dataset_path.name},
658+
)
659+
except Exception as ex:
660+
logger.exception(f"Exception occurred while importing GT annotations: {ex}\n")
661+
raise
662+
663+
request_id = json.loads(response.data).get("rq_id")
664+
assert request_id, "CVAT server have not returned rq_id in the response."
665+
666+
while True:
667+
try:
668+
(request_details, _) = client.api_client.requests_api.retrieve(request_id)
669+
except exceptions.ApiException as ex:
670+
logger.exception(f"Exception occurred while importing GT annotations: {ex}\n")
671+
raise
672+
673+
if (
674+
request_details.status.value
675+
== models.RequestStatus.allowed_values[("value",)]["FINISHED"]
676+
):
677+
break
678+
679+
if (
680+
request_details.status.value
681+
== models.RequestStatus.allowed_values[("value",)]["FAILED"]
682+
):
683+
raise Exception(
684+
"Annotations upload failed. "
685+
f"Previous status was: {request_details.status.value}."
686+
)
687+
688+
if timeout is not None and timedelta(seconds=timeout) < (utcnow() - start_time):
689+
raise Exception(
690+
"Failed to upload the GT annotations to CVAT within the timeout interval. "
691+
f"Previous status was: {request_details.status.value}. "
692+
f"Timeout: {timeout} seconds."
693+
)
694+
695+
sleep(sleep_interval)
696+
697+
logger.info(f"GT annotations for the job {job_id} have been uploaded to CVAT.")
643698

644699

645700
def get_quality_control_settings(task_id: int) -> models.QualitySettings:
@@ -662,7 +717,7 @@ def get_quality_control_settings(task_id: int) -> models.QualitySettings:
662717
def update_quality_control_settings(
663718
settings_id: int,
664719
*,
665-
max_validations_per_job: int,
720+
max_validations_per_job: int = Config.cvat_config.cvat_max_validation_checks,
666721
target_metric: str = "accuracy",
667722
target_metric_threshold: float = Config.cvat_config.cvat_target_metric_threshold,
668723
low_overlap_threshold: float = Config.cvat_config.cvat_low_overlap_threshold,

0 commit comments

Comments
 (0)