Skip to content

Commit cd97e8b

Browse files
authored
[CVAT] Oracle fixes (#3224)
* Fix invalid project data urls in ExO db * Fix gt bbox shift in boxes_from_points tasks * Fix gt annotation replacement in the annotation results
1 parent c9ae4e0 commit cd97e8b

File tree

4 files changed

+39
-34
lines changed

4 files changed

+39
-34
lines changed

packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,13 @@
3030
import src.services.cvat as db_service
3131
from src.chain.escrow import get_escrow_manifest
3232
from src.core.config import Config
33-
from src.core.storage import compose_data_bucket_filename
33+
from src.core.storage import compose_data_bucket_filename, compose_data_bucket_prefix
3434
from src.core.types import CvatLabelTypes, TaskStatuses, TaskTypes
3535
from src.db import SessionLocal
3636
from src.log import ROOT_LOGGER_NAME
3737
from src.models.cvat import Project
3838
from src.services.cloud import CloudProviders, StorageClient
39-
from src.services.cloud.utils import BucketAccessInfo, compose_bucket_url
39+
from src.services.cloud.utils import BucketAccessInfo
4040
from src.utils.annotations import InstanceSegmentsToBbox, ProjectLabels, is_point_in_bbox
4141
from src.utils.assignments import parse_manifest
4242
from src.utils.logging import NullLogger, format_sequence, get_function_logger
@@ -396,11 +396,7 @@ def build(self):
396396
manifest.annotation.type,
397397
escrow_address,
398398
chain_id,
399-
compose_bucket_url(
400-
data_bucket.bucket_name,
401-
bucket_host=data_bucket.host_url,
402-
provider=data_bucket.provider,
403-
),
399+
data_bucket.to_url(),
404400
cvat_webhook_id=cvat_webhook.id,
405401
)
406402

@@ -1469,6 +1465,10 @@ def _prepare_gt_roi_dataset(self):
14691465
categories=self._gt_dataset.categories(), media_type=dm.Image
14701466
)
14711467

1468+
roi_info_by_point_id: dict[int, skeletons_from_boxes_task.RoiInfo] = {
1469+
roi_info.point_id: roi_info for roi_info in self._rois
1470+
}
1471+
14721472
for sample in self._gt_dataset:
14731473
for gt_bbox in sample.annotations:
14741474
assert isinstance(gt_bbox, dm.Bbox)
@@ -1478,10 +1478,15 @@ def _prepare_gt_roi_dataset(self):
14781478
self.escrow_address, self.chain_id, self._roi_filenames[point_id]
14791479
)
14801480

1481+
# update gt bbox coordinates to match RoI shift
1482+
roi_info = roi_info_by_point_id[point_id]
1483+
new_x = gt_bbox.points[0] - roi_info.roi_x
1484+
new_y = gt_bbox.points[1] - roi_info.roi_y
1485+
14811486
self._gt_roi_dataset.put(
14821487
sample.wrap(
14831488
id=os.path.splitext(gt_roi_filename)[0],
1484-
annotations=[gt_bbox],
1489+
annotations=[gt_bbox.wrap(x=new_x, y=new_y)],
14851490
media=dm.Image(path=gt_roi_filename, size=sample.media_as(dm.Image).size),
14861491
attributes=filter_dict(sample.attributes, exclude_keys=["id"]),
14871492
)
@@ -1495,7 +1500,6 @@ def _create_on_cvat(self):
14951500
assert self._label_configuration is not _unset
14961501
assert self._gt_roi_dataset is not _unset
14971502

1498-
input_data_bucket = BucketAccessInfo.parse_obj(self.manifest.data.data_url)
14991503
oracle_bucket = self.oracle_data_bucket
15001504

15011505
# Register cloud storage on CVAT to pass user dataset
@@ -1535,11 +1539,9 @@ def _create_on_cvat(self):
15351539
self.manifest.annotation.type,
15361540
self.escrow_address,
15371541
self.chain_id,
1538-
compose_bucket_url(
1539-
input_data_bucket.bucket_name,
1540-
bucket_host=input_data_bucket.host_url,
1541-
provider=input_data_bucket.provider,
1542-
),
1542+
oracle_bucket.to_url().rstrip("/")
1543+
+ "/"
1544+
+ compose_data_bucket_prefix(self.escrow_address, self.chain_id),
15431545
cvat_webhook_id=cvat_webhook.id,
15441546
)
15451547
db_service.get_project_by_id(session, project_id, for_update=True) # lock the row
@@ -2635,7 +2637,6 @@ def _task_params_label_key(ts):
26352637
for skeleton_label_id, skeleton_label in enumerate(self.manifest.annotation.labels)
26362638
}
26372639

2638-
input_data_bucket = BucketAccessInfo.parse_obj(self.manifest.data.data_url)
26392640
oracle_bucket = self.oracle_data_bucket
26402641

26412642
# Register cloud storage on CVAT to pass user dataset
@@ -2714,11 +2715,9 @@ def _task_params_label_key(ts):
27142715
self.manifest.annotation.type,
27152716
self.escrow_address,
27162717
self.chain_id,
2717-
compose_bucket_url(
2718-
input_data_bucket.bucket_name,
2719-
bucket_host=input_data_bucket.host_url,
2720-
provider=input_data_bucket.provider,
2721-
),
2718+
oracle_bucket.to_url().rstrip("/")
2719+
+ "/"
2720+
+ compose_data_bucket_prefix(self.escrow_address, self.chain_id),
27222721
cvat_webhook_id=cvat_webhook.id,
27232722
)
27242723
created_projects.append(project_id)

packages/examples/cvat/exchange-oracle/src/services/cloud/types.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from urllib.parse import urlparse
99

1010
import pydantic
11+
from httpx import URL
1112

1213
from src.core import manifest
1314
from src.core.config import Config, StorageConfig
@@ -113,7 +114,7 @@ def from_url(cls, url: str) -> BucketAccessInfo:
113114
)
114115
elif Config.features.enable_custom_cloud_host:
115116
# Check if netloc is an ip address
116-
# or localhost with port (or its /etc/hosts aliast, e.g. minio:9000)
117+
# or localhost with port (or its /etc/hosts alias, e.g. minio:9000)
117118
if is_ipv4(parsed_url.netloc) or re.fullmatch(r"\w+:\d{4}", parsed_url.netloc):
118119
host = parsed_url.netloc
119120
bucket_name, path = parsed_url.path.lstrip("/").split("/", maxsplit=1)
@@ -190,3 +191,18 @@ def parse_obj(
190191
return cls.from_storage_config(data)
191192

192193
raise TypeError(f"Unsupported data type ({type(data)}) was provided")
194+
195+
def to_url(self) -> str:
196+
url = URL(self.host_url)
197+
198+
if Config.features.enable_custom_cloud_host and (
199+
not url.host.endswith(DEFAULT_S3_HOST) and not url.host.endswith(DEFAULT_GCS_HOST)
200+
):
201+
url = url.copy_with(path="/".join(["", self.bucket_name, url.path.lstrip("/")]))
202+
else:
203+
url = url.copy_with(host=f"{self.bucket_name}.{url.host}")
204+
205+
if self.path:
206+
url = url.copy_with(path="/".join(["", url.path.lstrip("/"), self.path.lstrip("/")]))
207+
208+
return str(url)

packages/examples/cvat/exchange-oracle/src/services/cloud/utils.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,9 @@
11
from src.services.cloud.client import StorageClient
2-
from src.services.cloud.gcs import DEFAULT_GCS_HOST, GcsClient
3-
from src.services.cloud.s3 import DEFAULT_S3_HOST, S3Client
2+
from src.services.cloud.gcs import GcsClient
3+
from src.services.cloud.s3 import S3Client
44
from src.services.cloud.types import BucketAccessInfo, CloudProviders
55

66

7-
def compose_bucket_url(
8-
bucket_name: str, provider: CloudProviders, *, bucket_host: str | None = None
9-
) -> str:
10-
match provider:
11-
case CloudProviders.aws:
12-
return f"https://{bucket_name}.{bucket_host or DEFAULT_S3_HOST}/"
13-
case CloudProviders.gcs:
14-
return f"https://{bucket_name}.{bucket_host or DEFAULT_GCS_HOST}/"
15-
16-
177
def make_client(
188
bucket_info: BucketAccessInfo,
199
) -> StorageClient:

packages/examples/cvat/recording-oracle/src/handlers/process_intermediate_results.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ def transform_item(self, item: dm.DatasetItem) -> dm.DatasetItem:
266266
item = item.wrap(id=item.id[len(self._prefix) :])
267267
return item
268268

269-
prefix = BucketAccessInfo.parse_obj(self.manifest.data.data_url).path.lstrip("/\\") + "/"
269+
prefix = BucketAccessInfo.parse_obj(self.manifest.data.data_url).path.strip("/\\") + "/"
270270

271271
# Remove prefixes if it can be done safely
272272
sample_ids = {sample.id for sample in merged_dataset}

0 commit comments

Comments
 (0)