Skip to content

Commit bb717df

Browse files
committed
[Recording Oracle] Add initial polygon task type support
1 parent c846252 commit bb717df

File tree

3 files changed

+102
-3
lines changed

3 files changed

+102
-3
lines changed
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import datetime
2+
3+
import uvicorn
4+
5+
from src import app
6+
from src.chain.kvstore import register_in_kvstore
7+
from src.core.config import Config
8+
9+
LOCAL_MANIFEST_FILES = set()
10+
11+
12+
def apply_local_development_patches():
13+
"""
14+
Applies local development patches to bypass direct source code modifications:
15+
- Overrides `EscrowUtils.get_escrow` to retrieve local escrow data for specific addresses,
16+
using mock data if the address corresponds to a local manifest.
17+
- Updates local manifest files from cloud storage.
18+
- Overrides `validate_address` to disable address validation.
19+
- Replaces `validate_oracle_webhook_signature` with a lenient version for oracle signature validation in development.
20+
"""
21+
from human_protocol_sdk.constants import ChainId
22+
from human_protocol_sdk.escrow import EscrowData, EscrowUtils
23+
24+
def get_local_escrow(chain_id: int, escrow_address: str) -> EscrowData:
25+
possible_manifest_name = escrow_address.split(":")[0]
26+
if possible_manifest_name in LOCAL_MANIFEST_FILES:
27+
return EscrowData(
28+
chain_id=ChainId(chain_id),
29+
id="test",
30+
address=escrow_address,
31+
amount_paid=10,
32+
balance=10,
33+
count=1,
34+
factory_address="",
35+
launcher="",
36+
status="Pending",
37+
token="HMT",
38+
total_funded_amount=10,
39+
created_at=datetime.datetime(2023, 1, 1),
40+
manifest_url=f"http://127.0.0.1:9010/manifests/{possible_manifest_name}",
41+
)
42+
return original_get_escrow(ChainId(chain_id), escrow_address)
43+
44+
original_get_escrow = EscrowUtils.get_escrow
45+
EscrowUtils.get_escrow = get_local_escrow
46+
47+
from src.services import cloud
48+
from src.services.cloud import BucketAccessInfo
49+
50+
manifests = cloud.make_client(BucketAccessInfo.parse_obj(Config.storage_config)).list_files(
51+
bucket="manifests"
52+
)
53+
LOCAL_MANIFEST_FILES.update(manifests)
54+
55+
import src.schemas.webhook
56+
from src.core.types import OracleWebhookTypes
57+
58+
src.schemas.webhook.validate_address = lambda x: x
59+
60+
async def lenient_validate_oracle_webhook_signature(request, signature, webhook):
61+
from src.validators.signature import validate_oracle_webhook_signature
62+
63+
try:
64+
return OracleWebhookTypes(signature.split(":")[0])
65+
except (ValueError, TypeError):
66+
return await validate_oracle_webhook_signature(request, signature, webhook)
67+
68+
import src.endpoints.webhook
69+
src.endpoints.webhook.validate_oracle_webhook_signature = (
70+
lenient_validate_oracle_webhook_signature
71+
)
72+
import logging
73+
74+
logging.warning("Local development patches applied.")
75+
76+
77+
if __name__ == "__main__":
78+
is_dev = Config.environment == "development"
79+
if is_dev:
80+
apply_local_development_patches()
81+
82+
Config.validate()
83+
register_in_kvstore()
84+
85+
uvicorn.run(
86+
app="src:app",
87+
host="0.0.0.0", # noqa: S104
88+
port=int(Config.port),
89+
workers=Config.workers_amount,
90+
)

packages/examples/cvat/recording-oracle/src/core/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ class Networks(int, Enum):
1313
class TaskTypes(str, Enum, metaclass=BetterEnumMeta):
1414
image_label_binary = "IMAGE_LABEL_BINARY"
1515
image_points = "IMAGE_POINTS"
16+
image_polygons = "IMAGE_POLYGONS"
1617
image_boxes = "IMAGE_BOXES"
1718
image_boxes_from_points = "IMAGE_BOXES_FROM_POINTS"
1819
image_skeletons_from_boxes = "IMAGE_SKELETONS_FROM_BOXES"

packages/examples/cvat/recording-oracle/src/handlers/process_intermediate_results.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242

4343
DM_DATASET_FORMAT_MAPPING = {
4444
TaskTypes.image_label_binary: "cvat_images",
45+
TaskTypes.image_polygons: "coco_instances",
4546
TaskTypes.image_points: "coco_person_keypoints",
4647
TaskTypes.image_boxes: "coco_instances",
4748
TaskTypes.image_boxes_from_points: "coco_instances",
@@ -51,6 +52,7 @@
5152
DM_GT_DATASET_FORMAT_MAPPING = {
5253
TaskTypes.image_label_binary: "cvat_images",
5354
TaskTypes.image_points: "coco_instances", # we compare points against boxes
55+
TaskTypes.image_polygons: "coco_instances",
5456
TaskTypes.image_boxes: "coco_instances",
5557
TaskTypes.image_boxes_from_points: "coco_instances",
5658
TaskTypes.image_skeletons_from_boxes: "coco_person_keypoints",
@@ -60,6 +62,7 @@
6062
DATASET_COMPARATOR_TYPE_MAP: dict[TaskTypes, type[DatasetComparator]] = {
6163
# TaskType.image_label_binary: TagDatasetComparator, # TODO: implement if support is needed
6264
TaskTypes.image_boxes: BboxDatasetComparator,
65+
TaskTypes.image_polygons: BboxDatasetComparator,
6366
TaskTypes.image_points: PointsDatasetComparator,
6467
TaskTypes.image_boxes_from_points: BboxDatasetComparator,
6568
TaskTypes.image_skeletons_from_boxes: SkeletonDatasetComparator,
@@ -199,7 +202,7 @@ def _validate_jobs(self):
199202
try:
200203
job_mean_accuracy = comparator.compare(gt_dataset, job_dataset)
201204
except TooFewGtError as e:
202-
job_results[job_cvat_id] = self.self.UNKNOWN_QUALITY
205+
job_results[job_cvat_id] = self.UNKNOWN_QUALITY
203206
rejected_jobs[job_cvat_id] = e
204207
continue
205208

@@ -271,7 +274,7 @@ def _put_gt_into_merged_dataset(
271274
"""
272275

273276
match manifest.annotation.type:
274-
case TaskTypes.image_boxes.value:
277+
case TaskTypes.image_boxes.value | TaskTypes.image_polygons.value:
275278
merged_dataset.update(gt_dataset)
276279
case TaskTypes.image_points.value:
277280
merged_label_cat: dm.LabelCategories = merged_dataset.categories()[
@@ -929,7 +932,12 @@ def process_intermediate_results( # noqa: PLR0912
929932
# actually validate jobs
930933

931934
task_type = manifest.annotation.type
932-
if task_type in [TaskTypes.image_label_binary, TaskTypes.image_boxes, TaskTypes.image_points]:
935+
if task_type in [
936+
TaskTypes.image_label_binary,
937+
TaskTypes.image_boxes,
938+
TaskTypes.image_polygons,
939+
TaskTypes.image_points,
940+
]:
933941
validator_type = _TaskValidator
934942
elif task_type == TaskTypes.image_boxes_from_points:
935943
validator_type = _BoxesFromPointsValidator

0 commit comments

Comments
 (0)