1010from http import HTTPStatus
1111from io import BytesIO
1212from time import sleep
13- from typing import Any
13+ from typing import TYPE_CHECKING , Any
1414
15+ from cvat_sdk import Client , make_client
1516from cvat_sdk .api_client import ApiClient , Configuration , exceptions , models
1617from cvat_sdk .api_client .api_client import Endpoint
1718from cvat_sdk .core .helpers import get_paginated_collection
2021from src .utils .enums import BetterEnumMeta
2122from src .utils .time import utcnow
2223
24+ if TYPE_CHECKING :
25+ from cvat_sdk .core .proxies .jobs import Job
26+
2327_NOTSET = object ()
2428
2529
@@ -122,6 +126,16 @@ def get_api_client() -> ApiClient:
122126 return api_client
123127
124128
129+ def get_sdk_client () -> Client :
130+ client = make_client (
131+ host = Config .cvat_config .cvat_url ,
132+ credentials = (Config .cvat_config .cvat_admin , Config .cvat_config .cvat_admin_pass ),
133+ )
134+ client .organization_slug = Config .cvat_config .cvat_org_slug
135+
136+ return client
137+
138+
125139def create_cloudstorage (
126140 provider : str ,
127141 bucket_name : str ,
@@ -297,14 +311,19 @@ def create_cvat_webhook(project_id: int) -> models.WebhookRead:
297311 raise
298312
299313
300- def create_task (project_id : int , name : str ) -> models .TaskRead :
314+ def create_task (
315+ project_id : int ,
316+ name : str ,
317+ * ,
318+ segment_size : int = Config .cvat_config .cvat_task_segment_size ,
319+ ) -> models .TaskRead :
301320 logger = logging .getLogger ("app" )
302321 with get_api_client () as api_client :
303322 task_write_request = models .TaskWriteRequest (
304323 name = name ,
305324 project_id = project_id ,
306325 overlap = 0 ,
307- segment_size = Config . cvat_config . cvat_job_segment_size ,
326+ segment_size = segment_size ,
308327 )
309328 try :
310329 (task_info , response ) = api_client .tasks_api .create (task_write_request )
@@ -335,8 +354,14 @@ def put_task_data(
335354 * ,
336355 filenames : list [str ] | None = None ,
337356 sort_images : bool = True ,
357+ validation_params : dict [str , str | float | list [str ]] | None = None ,
338358) -> None :
339359 logger = logging .getLogger ("app" )
360+ sorting_method = (
361+ models .SortingMethod ("lexicographical" )
362+ if sort_images
363+ else models .SortingMethod ("predefined" )
364+ )
340365
341366 with get_api_client () as api_client :
342367 kwargs = {}
@@ -345,21 +370,42 @@ def put_task_data(
345370 else :
346371 kwargs ["filename_pattern" ] = "*"
347372
373+ if validation_params :
374+ logger .info (
375+ f"The { sorting_method } is ignored."
376+ 'Only "random" sorting can be used when validation parameters passed.'
377+ )
378+ sorting_method = models .SortingMethod ("random" )
379+
380+ gt_filenames = validation_params ["gt_filenames" ]
381+ if missed_filenames := set (gt_filenames ) - set (filenames ):
382+ filenames .extend (missed_filenames )
383+
384+ kwargs ["validation_params" ] = models .DataRequestValidationParams (
385+ mode = models .ValidationMode ("gt_pool" ),
386+ frames = gt_filenames ,
387+ frame_selection_method = models .FrameSelectionMethod ("manual" ),
388+ frames_per_job_count = validation_params .get (
389+ "gt_frames_per_job_count" ,
390+ Config .cvat_config .cvat_val_frames_per_job_count ,
391+ ),
392+ )
393+
348394 data_request = models .DataRequest (
349- chunk_size = Config .cvat_config .cvat_job_segment_size ,
395+ chunk_size = Config .cvat_config .cvat_task_segment_size ,
350396 cloud_storage_id = cloudstorage_id ,
351397 image_quality = Config .cvat_config .cvat_default_image_quality ,
352398 use_cache = True ,
353399 use_zip_chunks = True ,
354- sorting_method = "lexicographical" if sort_images else "predefined" ,
400+ sorting_method = sorting_method ,
355401 ** kwargs ,
356402 )
357403 try :
358404 (_ , response ) = api_client .tasks_api .create_data (task_id , data_request = data_request )
359405 return
360406
361407 except exceptions .ApiException as e :
362- logger .exception (f"Exception when calling ProjectsApi.put_task_data : { e } \n " )
408+ logger .exception (f"Exception when calling tasks_api.create_data : { e } \n " )
363409 raise
364410
365411
@@ -563,36 +609,138 @@ def clear_job_annotations(job_id: int) -> None:
563609 raise
564610
565611
566- def update_job_assignee (id : str , assignee_id : int | None ):
612+ def setup_gt_job (task_id : int , filename : str , format_name : str ) -> None :
613+ gt_job = get_gt_job (task_id )
614+ upload_gt_annotations (gt_job .id , filename , format_name )
615+ finish_gt_job (gt_job .id )
616+ settings = get_quality_control_settings (task_id )
617+ update_quality_control_settings (settings .id )
618+
619+
620+ def get_gt_job (task_id : int ) -> models .JobRead :
621+ logger = logging .getLogger ("app" )
622+
623+ with get_api_client () as api_client :
624+ try :
625+ (paginated_jobs , _ ) = api_client .jobs_api .list (task_id = task_id , type = "ground_truth" )
626+ assert (
627+ len (paginated_jobs ["results" ]) == 1
628+ ), f'CVAT returned { len (paginated_jobs ["results" ])} GT jobs'
629+ return paginated_jobs ["results" ][0 ]
630+ except (exceptions .ApiException , AssertionError ) as ex :
631+ logger .exception (f"Exception when calling JobsApi.list(): { ex } \n " )
632+ raise
633+
634+
635+ def upload_gt_annotations (
636+ job_id : int ,
637+ filename : str ,
638+ format_name : str ,
639+ ) -> None :
640+ with get_sdk_client () as client :
641+ job : Job = client .jobs .retrieve (job_id )
642+ job .import_annotations (format_name = format_name , filename = filename )
643+
644+
645+ def get_quality_control_settings (task_id : int ) -> models .QualitySettings :
567646 logger = logging .getLogger ("app" )
568647
569648 with get_api_client () as api_client :
570649 try :
571- api_client .jobs_api .partial_update (
572- id = id ,
573- patched_job_write_request = models .PatchedJobWriteRequest (assignee = assignee_id ),
650+ paginated_data , _ = api_client .quality_api .list_settings (task_id = task_id )
651+ assert len (paginated_data ["results" ]) == 1 , (
652+ f'CVAT returned { len (paginated_data ["results" ])} '
653+ "quality control settings associated with the task"
654+ )
655+ return paginated_data ["results" ][0 ]
656+
657+ except (exceptions .ApiException , AssertionError ) as e :
658+ logger .exception (f"Exception when calling QualityApi.list_settings(): { e } \n " )
659+ raise
660+
661+
662+ def update_quality_control_settings (
663+ settings_id : int ,
664+ * ,
665+ max_validations_per_job : int ,
666+ target_metric : str = "accuracy" ,
667+ target_metric_threshold : float = Config .cvat_config .cvat_target_metric_threshold ,
668+ low_overlap_threshold : float = Config .cvat_config .cvat_low_overlap_threshold ,
669+ iou_threshold : float = Config .cvat_config .cvat_iou_threshold ,
670+ oks_sigma : float = Config .cvat_config .cvat_oks_sigma ,
671+ ) -> None :
672+ logger = logging .getLogger ("app" )
673+
674+ with get_api_client () as api_client :
675+ try :
676+ api_client .quality_api .partial_update_settings (
677+ settings_id ,
678+ patched_quality_settings_request = models .PatchedQualitySettingsRequest (
679+ max_validations_per_job = max_validations_per_job ,
680+ target_metric = target_metric ,
681+ target_metric_threshold = target_metric_threshold ,
682+ iou_threshold = iou_threshold ,
683+ low_overlap_threshold = low_overlap_threshold ,
684+ oks_sigma = oks_sigma ,
685+ ),
574686 )
575687 except exceptions .ApiException as e :
576- logger .exception (f"Exception when calling JobsApi.partial_update (): { e } \n " )
688+ logger .exception (f"Exception when calling QualityApi.partial_update_settings (): { e } \n " )
577689 raise
578690
579691
580- def restart_job (id : str , * , assignee_id : int | None = None ):
692+ def _update_job (
693+ job_id : int ,
694+ * ,
695+ assignee_id : int | None | object = _NOTSET ,
696+ stage : models .JobStage | None = None ,
697+ state : models .OperationStatus | None = None ,
698+ ) -> None :
699+ to_update = {
700+ attr : value
701+ for attr , value in {
702+ "stage" : stage ,
703+ "state" : state ,
704+ }.items ()
705+ if value
706+ }
707+
708+ if assignee_id is not _NOTSET :
709+ to_update ["assignee" ] = assignee_id
710+
711+ assert to_update
712+
581713 logger = logging .getLogger ("app" )
582714
583715 with get_api_client () as api_client :
584716 try :
585717 api_client .jobs_api .partial_update (
586- id = id ,
587- patched_job_write_request = models .PatchedJobWriteRequest (
588- stage = "annotation" , state = "new" , assignee = assignee_id
589- ),
718+ job_id , patched_job_write_request = models .PatchedJobWriteRequest (** to_update )
590719 )
591720 except exceptions .ApiException as e :
592721 logger .exception (f"Exception when calling JobsApi.partial_update(): { e } \n " )
593722 raise
594723
595724
725+ def update_job_assignee (id : int , assignee_id : int | None ):
726+ _update_job (id , assignee_id = assignee_id )
727+
728+
729+ def restart_job (id : str , * , assignee_id : int | None = None ):
730+ _update_job (
731+ id ,
732+ stage = models .JobStage ("annotation" ),
733+ state = models .OperationStatus ("new" ),
734+ assignee_id = assignee_id ,
735+ )
736+
737+
738+ def finish_gt_job (job_id : int ) -> None :
739+ _update_job (
740+ job_id , stage = models .JobStage ("acceptance" ), state = models .OperationStatus ("completed" )
741+ )
742+
743+
596744def get_user_id (user_email : str ) -> int :
597745 logger = logging .getLogger ("app" )
598746
0 commit comments