Skip to content

Commit 1cb7e44

Browse files
committed
Allow manifest changes in recording oracle
1 parent ab36747 commit 1cb7e44

File tree

5 files changed

+262
-18
lines changed

5 files changed

+262
-18
lines changed

packages/examples/cvat/recording-oracle/src/core/annotation_meta.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
1-
from collections.abc import Iterator
2-
from pathlib import Path
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
34

5+
from pathlib import Path
46
from pydantic import BaseModel
57

8+
if TYPE_CHECKING:
9+
from collections.abc import Collection, Iterator
10+
611
ANNOTATION_RESULTS_METAFILE_NAME = "annotation_meta.json"
712
RESULTING_ANNOTATIONS_FILE = "resulting_annotations.zip"
813

@@ -24,5 +29,7 @@ def job_frame_range(self) -> Iterator[int]:
2429
class AnnotationMeta(BaseModel):
2530
jobs: list[JobMeta]
2631

27-
def skip_jobs(self, job_ids: list[int]):
28-
return AnnotationMeta(jobs=[job for job in self.jobs if job.job_id not in job_ids])
32+
def skip_assignments(self, assignment_ids: Collection[int]) -> AnnotationMeta:
33+
return AnnotationMeta(
34+
jobs=[job for job in self.jobs if job.assignment_id not in assignment_ids]
35+
)

packages/examples/cvat/recording-oracle/src/handlers/process_intermediate_results.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -375,13 +375,12 @@ def process_intermediate_results( # noqa: PLR0912
375375
), # should not happen, but waiting should not block processing
376376
)
377377
if task:
378-
# skip jobs that have already been accepted on the previous epochs
379-
accepted_cvat_job_ids = [
380-
validation_result.job.cvat_id
378+
# Skip assignments that were validated earlier
379+
validated_assignment_ids = {
380+
validation_result.assignment_id
381381
for validation_result in db_service.get_task_validation_results(session, task.id)
382-
if validation_result.annotation_quality >= manifest.validation.min_quality
383-
]
384-
unchecked_jobs_meta = meta.skip_jobs(accepted_cvat_job_ids)
382+
}
383+
unchecked_jobs_meta = meta.skip_assignments(validated_assignment_ids)
385384
else:
386385
# A Recording Oracle task represents all CVAT tasks related to the escrow
387386
task_id = db_service.create_task(session, escrow_address=escrow_address, chain_id=chain_id)
@@ -566,6 +565,16 @@ def process_intermediate_results( # noqa: PLR0912
566565
else:
567566
assignment_validation_result_id = assignment_validation_result.id
568567

568+
# We consider only the last assignment as final even if there were assignments with higher
569+
# quality scores. The reason for this is that during escrow annotation there are various
570+
# task changes possible, for instance:
571+
# - GT can be changed in the middle of the task annotation
572+
# - manifest can be updated with different quality parameters
573+
# etc. Under the current system requirements, such changes occur mostly in development
574+
# or testing conditions, but supporting them is likely to become
575+
# a normal requirement in the future.
576+
# Therefore, we use the following rule: only the last job assignment is considered
577+
# the final annotation result, regardless of the assignment quality.
569578
job_final_result_ids[job.id] = assignment_validation_result_id
570579

571580
task_jobs = task.jobs

packages/examples/cvat/recording-oracle/tests/conftest.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33
import pytest
44
from fastapi.testclient import TestClient
5+
from sqlalchemy.orm import Session
56

67
from src import app
7-
from src.db import Base, engine
8+
from src.db import Base, SessionLocal, engine
89

910

1011
@pytest.fixture(autouse=True)
@@ -17,3 +18,14 @@ def db():
1718
def client() -> Generator:
1819
with TestClient(app) as c:
1920
yield c
21+
22+
23+
@pytest.fixture
24+
def session() -> Generator[Session, None, None]:
25+
session = SessionLocal()
26+
27+
try:
28+
yield session
29+
finally:
30+
session.rollback()
31+
session.close()

packages/examples/cvat/recording-oracle/tests/integration/services/test_validation_service.py

Lines changed: 221 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,21 @@
1+
import io
12
import random
23
import unittest
34
import uuid
5+
from contextlib import ExitStack
6+
from logging import Logger
7+
from types import SimpleNamespace
8+
from unittest import mock
49

10+
from sqlalchemy.orm import Session
11+
12+
from src.core.annotation_meta import AnnotationMeta, JobMeta
13+
from src.core.manifest import TaskManifest, parse_manifest
514
from src.core.types import Networks
15+
from src.core.validation_results import ValidationFailure, ValidationSuccess
16+
from src.cvat import api_calls as cvat_api
617
from src.db import SessionLocal
18+
from src.handlers.process_intermediate_results import process_intermediate_results
719
from src.services.validation import (
820
create_job,
921
create_task,
@@ -16,19 +28,19 @@
1628
get_validation_result_by_assignment_id,
1729
)
1830

31+
from tests.utils.constants import ESCROW_ADDRESS, WALLET_ADDRESS1
32+
1933

2034
class ServiceIntegrationTest(unittest.TestCase):
2135
def setUp(self):
2236
random.seed(42)
2337
self.session = SessionLocal()
24-
self.escrow_address = "0x" + "".join([str(random.randint(0, 9)) for _ in range(40)])
38+
self.escrow_address = ESCROW_ADDRESS
2539
self.chain_id = Networks.localhost
2640
self.cvat_id = 0
27-
self.annotator_wallet_address = "0x" + "".join(
28-
[str(random.randint(0, 9)) for _ in range(40)]
29-
)
41+
self.annotator_wallet_address = WALLET_ADDRESS1
3042
self.annotation_quality = 0.9
31-
self.assigment_id = str(uuid.uuid4())
43+
self.assignment_id = str(uuid.uuid4())
3244

3345
def tearDown(self):
3446
self.session.close()
@@ -67,12 +79,214 @@ def test_create_and_get_validation_result(self):
6779
job_id,
6880
self.annotator_wallet_address,
6981
self.annotation_quality,
70-
self.assigment_id,
82+
self.assignment_id,
7183
)
7284

73-
vr = get_validation_result_by_assignment_id(self.session, self.assigment_id)
85+
vr = get_validation_result_by_assignment_id(self.session, self.assignment_id)
7486
assert vr.id == vr_id
7587

7688
vrs = get_task_validation_results(self.session, task_id)
7789
assert len(vrs) == 1
7890
assert vrs[0] == vr
91+
92+
93+
class TestManifestChange:
94+
@staticmethod
95+
def _generate_manifest(*, min_quality: float = 0.8) -> TaskManifest:
96+
data = {
97+
"data": {"data_url": "http://localhost:9010/datasets/sample"},
98+
"annotation": {
99+
"labels": [{"name": "person"}],
100+
"description": "",
101+
"user_guide": "",
102+
"type": "image_points",
103+
"job_size": 10,
104+
},
105+
"validation": {
106+
"min_quality": min_quality,
107+
"val_size": 2,
108+
"gt_url": "http://localhost:9010/datasets/sample/annotations/sample_gt.json",
109+
},
110+
"job_bounty": "0.0001",
111+
}
112+
113+
return parse_manifest(data)
114+
115+
def test_can_handle_lowered_quality_requirements_in_manifest(self, session: Session):
116+
escrow_address = ESCROW_ADDRESS
117+
chain_id = Networks.localhost
118+
119+
min_quality1 = 0.8
120+
min_quality2 = 0.5
121+
frame_count = 10
122+
123+
manifest = self._generate_manifest(min_quality=min_quality1)
124+
125+
cvat_task_id = 1
126+
cvat_job_id = 1
127+
annotator1 = WALLET_ADDRESS1
128+
129+
assignment1_id = f"0x{0:040d}"
130+
assignment1_quality = 0.7
131+
132+
assignment2_id = f"0x{1:040d}"
133+
assignment2_quality = 0.6
134+
135+
# create a validation input
136+
with ExitStack() as common_lock_es:
137+
logger = mock.Mock(Logger)
138+
139+
mock_make_cloud_client = common_lock_es.enter_context(
140+
mock.patch("src.handlers.process_intermediate_results.make_cloud_client")
141+
)
142+
mock_make_cloud_client.return_value.download_file = mock.Mock(return_value=b"")
143+
144+
mock_get_task_validation_layout = common_lock_es.enter_context(
145+
mock.patch(
146+
"src.handlers.process_intermediate_results.cvat_api.get_task_validation_layout"
147+
)
148+
)
149+
mock_get_task_validation_layout.return_value = mock.Mock(
150+
cvat_api.models.ITaskValidationLayoutRead,
151+
honeypot_frames=[0, 1],
152+
honeypot_real_frames=[0, 1],
153+
)
154+
155+
mock_get_task_data_meta = common_lock_es.enter_context(
156+
mock.patch("src.handlers.process_intermediate_results.cvat_api.get_task_data_meta")
157+
)
158+
mock_get_task_data_meta.return_value = mock.Mock(
159+
cvat_api.models.IDataMetaRead,
160+
frames=[SimpleNamespace(name=f"frame_{i}.jpg") for i in range(frame_count)],
161+
)
162+
163+
common_lock_es.enter_context(
164+
mock.patch("src.handlers.process_intermediate_results.dm.Dataset.import_from")
165+
)
166+
common_lock_es.enter_context(
167+
mock.patch("src.handlers.process_intermediate_results.extract_zip_archive")
168+
)
169+
common_lock_es.enter_context(
170+
mock.patch("src.handlers.process_intermediate_results.write_dir_to_zip_archive")
171+
)
172+
173+
def patched_prepare_merged_dataset(self):
174+
self._updated_merged_dataset_archive = io.BytesIO()
175+
176+
common_lock_es.enter_context(
177+
mock.patch(
178+
"src.handlers.process_intermediate_results._TaskValidator._prepare_merged_dataset",
179+
patched_prepare_merged_dataset,
180+
)
181+
)
182+
183+
annotation_meta = AnnotationMeta(
184+
jobs=[
185+
JobMeta(
186+
job_id=cvat_job_id,
187+
task_id=cvat_task_id,
188+
annotation_filename="",
189+
annotator_wallet_address=annotator1,
190+
assignment_id=assignment1_id,
191+
start_frame=0,
192+
stop_frame=manifest.annotation.job_size + manifest.validation.val_size,
193+
)
194+
]
195+
)
196+
197+
with (
198+
mock.patch(
199+
"src.handlers.process_intermediate_results.cvat_api.get_task_quality_report"
200+
) as mock_get_task_quality_report,
201+
mock.patch(
202+
"src.handlers.process_intermediate_results.cvat_api.get_quality_report_data"
203+
) as mock_get_quality_report_data,
204+
mock.patch(
205+
"src.handlers.process_intermediate_results.cvat_api.get_jobs_quality_reports"
206+
) as mock_get_jobs_quality_reports,
207+
):
208+
mock_get_task_quality_report.return_value = mock.Mock(
209+
cvat_api.models.IQualityReport, id=1
210+
)
211+
mock_get_quality_report_data.return_value = mock.Mock(
212+
cvat_api.QualityReportData,
213+
frame_results={
214+
"0": mock.Mock(annotations=mock.Mock(accuracy=assignment1_quality)),
215+
"1": mock.Mock(annotations=mock.Mock(accuracy=assignment1_quality)),
216+
},
217+
)
218+
mock_get_jobs_quality_reports.return_value = [
219+
mock.Mock(
220+
cvat_api.models.IQualityReport,
221+
job_id=1,
222+
summary=mock.Mock(accuracy=assignment1_quality),
223+
),
224+
]
225+
226+
vr1 = process_intermediate_results(
227+
session,
228+
escrow_address=escrow_address,
229+
chain_id=chain_id,
230+
meta=annotation_meta,
231+
merged_annotations=io.BytesIO(),
232+
manifest=manifest,
233+
logger=logger,
234+
)
235+
236+
assert isinstance(vr1, ValidationFailure)
237+
assert len(vr1.rejected_jobs) == 1
238+
239+
manifest.validation.min_quality = min_quality2
240+
241+
annotation_meta.jobs[0].assignment_id = assignment2_id
242+
243+
with (
244+
mock.patch(
245+
"src.handlers.process_intermediate_results.cvat_api.get_task_quality_report"
246+
) as mock_get_task_quality_report,
247+
mock.patch(
248+
"src.handlers.process_intermediate_results.cvat_api.get_quality_report_data"
249+
) as mock_get_quality_report_data,
250+
mock.patch(
251+
"src.handlers.process_intermediate_results.cvat_api.get_jobs_quality_reports"
252+
) as mock_get_jobs_quality_reports,
253+
):
254+
mock_get_task_quality_report.return_value = mock.Mock(
255+
cvat_api.models.IQualityReport, id=2
256+
)
257+
mock_get_quality_report_data.return_value = mock.Mock(
258+
cvat_api.QualityReportData,
259+
frame_results={
260+
"0": mock.Mock(annotations=mock.Mock(accuracy=assignment2_quality)),
261+
"1": mock.Mock(annotations=mock.Mock(accuracy=assignment2_quality)),
262+
},
263+
)
264+
mock_get_jobs_quality_reports.return_value = [
265+
mock.Mock(
266+
cvat_api.models.IQualityReport,
267+
job_id=1,
268+
summary=mock.Mock(accuracy=assignment2_quality),
269+
),
270+
]
271+
272+
vr2 = process_intermediate_results(
273+
session,
274+
escrow_address=escrow_address,
275+
chain_id=chain_id,
276+
meta=annotation_meta,
277+
merged_annotations=io.BytesIO(),
278+
manifest=manifest,
279+
logger=logger,
280+
)
281+
282+
assert isinstance(vr2, ValidationSuccess)
283+
assert vr2.job_results[cvat_job_id] == assignment2_quality
284+
285+
assert len(vr2.validation_meta.jobs) == 1
286+
assert len(vr2.validation_meta.results) == 2
287+
assert (
288+
vr2.validation_meta.results[
289+
vr2.validation_meta.jobs[0].final_result_id
290+
].annotation_quality
291+
== assignment2_quality
292+
)

packages/examples/cvat/recording-oracle/tests/utils/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,5 @@
9898
-----END PGP PUBLIC KEY BLOCK-----
9999
"""
100100
)
101+
102+
WALLET_ADDRESS1 = "0x5aAeb6053F3E94C9b9A09f33669435E7Ef1BeAed"

0 commit comments

Comments
 (0)