|
| 1 | +import json |
| 2 | +import logging |
| 3 | +from datetime import timedelta |
| 4 | +from http import HTTPStatus |
| 5 | +from time import sleep |
| 6 | +from typing import Any |
| 7 | + |
| 8 | +from cvat_sdk.api_client import ApiClient, Configuration, exceptions, models |
| 9 | +from cvat_sdk.core.helpers import get_paginated_collection |
| 10 | + |
| 11 | +from src.core.config import Config |
| 12 | +from src.utils.time import utcnow |
| 13 | + |
| 14 | + |
| 15 | +def get_api_client() -> ApiClient: |
| 16 | + configuration = Configuration( |
| 17 | + host=Config.cvat_config.cvat_url, |
| 18 | + username=Config.cvat_config.cvat_admin, |
| 19 | + password=Config.cvat_config.cvat_admin_pass, |
| 20 | + ) |
| 21 | + |
| 22 | + api_client = ApiClient(configuration=configuration) |
| 23 | + api_client.set_default_header("X-organization", Config.cvat_config.cvat_org_slug) |
| 24 | + |
| 25 | + return api_client |
| 26 | + |
| 27 | + |
| 28 | +def get_last_task_quality_report(task_id: int) -> models.QualityReport | None: |
| 29 | + with get_api_client() as api_client: |
| 30 | + paginated_result, _ = api_client.quality_api.list_reports( |
| 31 | + task_id=task_id, |
| 32 | + page_size=1, |
| 33 | + target="task", |
| 34 | + sort="-created_date", |
| 35 | + ) |
| 36 | + assert len(paginated_result.results) <= 1 |
| 37 | + return paginated_result.results[0] if paginated_result.results else None |
| 38 | + |
| 39 | + |
| 40 | +def compute_task_quality_report( |
| 41 | + task_id: int, |
| 42 | + *, |
| 43 | + max_waiting_time: int = 10 * 60, |
| 44 | + sleep_interval: float = 0.5, |
| 45 | +) -> models.QualityReport: |
| 46 | + logger = logging.getLogger("app") |
| 47 | + start_time = utcnow() |
| 48 | + |
| 49 | + with get_api_client() as api_client: |
| 50 | + _, response = api_client.quality_api.create_report( |
| 51 | + quality_report_create_request={"task_id": task_id}, _parse_response=False |
| 52 | + ) |
| 53 | + rq_id = json.loads(response.data).get("rq_id") |
| 54 | + assert rq_id, ( |
| 55 | + "CVAT server hasn't returned rq_id in the response " |
| 56 | + f"when creating a task({task_id}) quality report" |
| 57 | + ) |
| 58 | + |
| 59 | + while utcnow() - start_time < timedelta(seconds=max_waiting_time): |
| 60 | + _, response = api_client.quality_api.create_report( |
| 61 | + rq_id=rq_id, _check_status=False, _parse_response=False |
| 62 | + ) |
| 63 | + match response.status: |
| 64 | + case HTTPStatus.CREATED: |
| 65 | + report = models.QualityReport._from_openapi_data(**json.loads(response.data)) |
| 66 | + if logger.isEnabledFor(logging.DEBUG): |
| 67 | + logger.debug(f"Created quality report: {report.id}") |
| 68 | + |
| 69 | + return report |
| 70 | + case HTTPStatus.ACCEPTED: |
| 71 | + sleep(sleep_interval) |
| 72 | + continue |
| 73 | + case _: |
| 74 | + raise Exception(f"Unexpected response status: {response.status}") |
| 75 | + |
| 76 | + raise Exception(f"Task({task_id}) quality report has not been created in time") |
| 77 | + |
| 78 | + |
| 79 | +def get_task_quality_report( |
| 80 | + task_id: int, |
| 81 | + *, |
| 82 | + max_waiting_time: int = 10 * 60, |
| 83 | + sleep_interval: float = 0.5, |
| 84 | +) -> models.QualityReport: |
| 85 | + logger = logging.getLogger("app") |
| 86 | + report = get_last_task_quality_report(task_id) |
| 87 | + if report and report.created_date > report.target_last_updated: |
| 88 | + if logger.isEnabledFor(logging.DEBUG): |
| 89 | + logger.debug(f"The latest task({task_id}) quality report({report.id}) is actual") |
| 90 | + return report |
| 91 | + |
| 92 | + return compute_task_quality_report( |
| 93 | + task_id, max_waiting_time=max_waiting_time, sleep_interval=sleep_interval |
| 94 | + ) |
| 95 | + |
| 96 | + |
| 97 | +def get_quality_report_data(report_id: int) -> dict[str, Any]: |
| 98 | + logger = logging.getLogger("app") |
| 99 | + with get_api_client() as api_client: |
| 100 | + try: |
| 101 | + _, response = api_client.quality_api.retrieve_report_data( |
| 102 | + report_id, _parse_response=False |
| 103 | + ) |
| 104 | + report_data = json.loads(response.data) |
| 105 | + assert report_data |
| 106 | + return report_data |
| 107 | + |
| 108 | + except exceptions.ApiException as e: |
| 109 | + logger.exception(f"Exception when calling QualityApi.retrieve_report_data: {e}\n") |
| 110 | + raise |
| 111 | + |
| 112 | + |
| 113 | +def get_job_validation_layout(job_id: int) -> models.JobValidationLayoutRead: |
| 114 | + logger = logging.getLogger("app") |
| 115 | + with get_api_client() as api_client: |
| 116 | + try: |
| 117 | + layout, _ = api_client.jobs_api.retrieve_validation_layout(job_id) |
| 118 | + return layout |
| 119 | + |
| 120 | + except exceptions.ApiException as e: |
| 121 | + logger.exception(f"Exception when calling JobApi.retrieve_validation_layout: {e}\n") |
| 122 | + raise |
| 123 | + |
| 124 | + |
| 125 | +def get_task_validation_layout(task_id: int) -> models.TaskValidationLayoutRead: |
| 126 | + logger = logging.getLogger("app") |
| 127 | + with get_api_client() as api_client: |
| 128 | + try: |
| 129 | + layout, _ = api_client.tasks_api.retrieve_validation_layout(task_id) |
| 130 | + return layout |
| 131 | + |
| 132 | + except exceptions.ApiException as e: |
| 133 | + logger.exception(f"Exception when calling TaskApi.retrieve_validation_layout: {e}\n") |
| 134 | + raise |
| 135 | + |
| 136 | + |
| 137 | +def get_jobs_quality_reports(parent_id: int) -> dict[int, models.QualityReport]: |
| 138 | + logger = logging.getLogger("app") |
| 139 | + with get_api_client() as api_client: |
| 140 | + try: |
| 141 | + reports: list[models.QualityReport] = get_paginated_collection( |
| 142 | + api_client.quality_api.list_reports_endpoint, parent_id=parent_id, target="job" |
| 143 | + ) |
| 144 | + return {report.job_id: report for report in reports} |
| 145 | + |
| 146 | + except exceptions.ApiException as e: |
| 147 | + logger.exception(f"Exception when calling QualityApi.list_reports: {e}\n") |
| 148 | + raise |
| 149 | + |
| 150 | + |
| 151 | +def get_quality_report_data(report_id: int) -> dict[str, Any]: |
| 152 | + with get_api_client() as api_client: |
| 153 | + _, response = api_client.quality_api.retrieve_report_data(report_id, _parse_response=False) |
| 154 | + assert response.status == HTTPStatus.OK |
| 155 | + return response.json() |
| 156 | + |
| 157 | + |
| 158 | +def shuffle_honeypots_in_jobs(job_ids: list[int] | int) -> None: |
| 159 | + logger = logging.getLogger("app") |
| 160 | + |
| 161 | + if isinstance(job_ids, int): |
| 162 | + job_ids = [job_ids] |
| 163 | + |
| 164 | + with get_api_client() as api_client: |
| 165 | + for job_id in job_ids: |
| 166 | + updated_validation_layout, _ = api_client.jobs_api.partial_update_validation_layout( |
| 167 | + job_id, |
| 168 | + patched_job_validation_layout_write_request=models.PatchedJobValidationLayoutWriteRequest( |
| 169 | + frame_selection_method="random_uniform", |
| 170 | + ), |
| 171 | + ) |
| 172 | + if logger.isEnabledFor(logging.DEBUG): |
| 173 | + logger.debug( |
| 174 | + f"Updated validation layout for the job {job_id}: {updated_validation_layout!s}" |
| 175 | + ) |
| 176 | + |
| 177 | + |
| 178 | +def disable_validation_frames(task_id: int, *, frames_to_disable: list[int]) -> None: |
| 179 | + logger = logging.getLogger("app") |
| 180 | + with get_api_client() as api_client: |
| 181 | + task_validation_layout, _ = api_client.tasks_api.retrieve_validation_layout(task_id) |
| 182 | + disabled_frames = task_validation_layout.disabled_frames |
| 183 | + |
| 184 | + # nothing to update |
| 185 | + if not (set(frames_to_disable) - set(disabled_frames)): |
| 186 | + if logger.isEnabledFor(logging.DEBUG): |
| 187 | + logger.debug( |
| 188 | + f"Validation frames {frames_to_disable!r} are already " |
| 189 | + f"disabled for the CVAT task {task_id}" |
| 190 | + ) |
| 191 | + return |
| 192 | + |
| 193 | + api_client.tasks_api.partial_update_validation_layout( |
| 194 | + task_id, |
| 195 | + patched_task_validation_layout_write_request=models.PatchedTaskValidationLayoutWriteRequest( |
| 196 | + disabled_frames=sorted(set(disabled_frames + frames_to_disable)) |
| 197 | + ), |
| 198 | + ) |
| 199 | + logger.info( |
| 200 | + f"Validation frames {frames_to_disable!r} have been disabled " |
| 201 | + f"for the CVAT task {task_id}" |
| 202 | + ) |
0 commit comments