55from collections .abc import Generator
66from contextlib import contextmanager
77from contextvars import ContextVar
8- from datetime import timedelta
8+ from datetime import datetime , timedelta , timezone
99from enum import Enum
1010from http import HTTPStatus
1111from io import BytesIO
12+ from pathlib import Path
1213from time import sleep
13- from typing import TYPE_CHECKING , Any
14+ from typing import Any
1415
1516from cvat_sdk import Client , make_client
1617from cvat_sdk .api_client import ApiClient , Configuration , exceptions , models
1718from cvat_sdk .api_client .api_client import Endpoint
1819from cvat_sdk .core .helpers import get_paginated_collection
20+ from cvat_sdk .core .uploading import AnnotationUploader
1921
2022from src .core .config import Config
2123from src .utils .enums import BetterEnumMeta
2224from src .utils .time import utcnow
2325
24- if TYPE_CHECKING :
25- from cvat_sdk .core .proxies .jobs import Job
26-
2726_NOTSET = object ()
2827
2928
@@ -609,9 +608,9 @@ def clear_job_annotations(job_id: int) -> None:
609608 raise
610609
611610
612- def setup_gt_job (task_id : int , filename : str , format_name : str ) -> None :
611+ def setup_gt_job (task_id : int , dataset_path : Path , format_name : str ) -> None :
613612 gt_job = get_gt_job (task_id )
614- upload_gt_annotations (gt_job .id , filename , format_name )
613+ upload_gt_annotations (gt_job .id , dataset_path , format_name = format_name )
615614 finish_gt_job (gt_job .id )
616615 settings = get_quality_control_settings (task_id )
617616 update_quality_control_settings (settings .id )
@@ -634,12 +633,68 @@ def get_gt_job(task_id: int) -> models.JobRead:
634633
635634def upload_gt_annotations (
636635 job_id : int ,
637- filename : str ,
636+ dataset_path : Path ,
637+ * ,
638638 format_name : str ,
639+ sleep_interval : int = 5 ,
640+ timeout : int | None = Config .features .default_import_timeout ,
639641) -> None :
642+ # FUTURE-TODO: use job.import_annotations when CVAT will support waiting timeout
643+ start_time = datetime .now (timezone .utc )
644+ logger = logging .getLogger ("app" )
645+
640646 with get_sdk_client () as client :
641- job : Job = client .jobs .retrieve (job_id )
642- job .import_annotations (format_name = format_name , filename = filename )
647+ uploader = AnnotationUploader (client )
648+ url = client .api_map .make_endpoint_url (
649+ client .api_client .jobs_api .create_annotations_endpoint .path , kwsub = {"id" : job_id }
650+ )
651+
652+ try :
653+ response = uploader .upload_file (
654+ url ,
655+ dataset_path ,
656+ query_params = {"format" : format_name , "filename" : dataset_path .name },
657+ meta = {"filename" : dataset_path .name },
658+ )
659+ except Exception as ex :
660+ logger .exception (f"Exception occurred while importing GT annotations: { ex } \n " )
661+ raise
662+
663+ request_id = json .loads (response .data ).get ("rq_id" )
664+ assert request_id , "CVAT server have not returned rq_id in the response."
665+
666+ while True :
667+ try :
668+ (request_details , _ ) = client .api_client .requests_api .retrieve (request_id )
669+ except exceptions .ApiException as ex :
670+ logger .exception (f"Exception occurred while importing GT annotations: { ex } \n " )
671+ raise
672+
673+ if (
674+ request_details .status .value
675+ == models .RequestStatus .allowed_values [("value" ,)]["FINISHED" ]
676+ ):
677+ break
678+
679+ if (
680+ request_details .status .value
681+ == models .RequestStatus .allowed_values [("value" ,)]["FAILED" ]
682+ ):
683+ raise Exception (
684+ "Annotations upload failed. "
685+ f"Previous status was: { request_details .status .value } ."
686+ )
687+
688+ if timeout is not None and timedelta (seconds = timeout ) < (utcnow () - start_time ):
689+ raise Exception (
690+ "Failed to upload the GT annotations to CVAT within the timeout interval. "
691+ f"Previous status was: { request_details .status .value } . "
692+ f"Timeout: { timeout } seconds."
693+ )
694+
695+ sleep (sleep_interval )
696+
697+ logger .info (f"GT annotations for the job { job_id } have been uploaded to CVAT." )
643698
644699
645700def get_quality_control_settings (task_id : int ) -> models .QualitySettings :
@@ -662,7 +717,7 @@ def get_quality_control_settings(task_id: int) -> models.QualitySettings:
662717def update_quality_control_settings (
663718 settings_id : int ,
664719 * ,
665- max_validations_per_job : int ,
720+ max_validations_per_job : int = Config . cvat_config . cvat_max_validation_checks ,
666721 target_metric : str = "accuracy" ,
667722 target_metric_threshold : float = Config .cvat_config .cvat_target_metric_threshold ,
668723 low_overlap_threshold : float = Config .cvat_config .cvat_low_overlap_threshold ,
0 commit comments