Skip to content

Commit 470bf0d

Browse files
authored
[CVAT] Allow manifest changes in recording oracle (#2850)
* Allow manifest changes in recording oracle
* Fix linter
* Fix test
* Remove .env from test dockerfile
* Ignore nonexistent .env
1 parent f745ece commit 470bf0d

File tree

7 files changed

+270
-18
lines changed

7 files changed

+270
-18
lines changed

packages/examples/cvat/exchange-oracle/dockerfiles/test.Dockerfile

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,6 @@ RUN poetry config virtualenvs.create false \
1414

1515
COPY . .
1616

17+
RUN rm -f ./src/.env
18+
1719
CMD ["pytest", "-W", "ignore::DeprecationWarning", "-v"]

packages/examples/cvat/recording-oracle/dockerfiles/test.Dockerfile

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,6 @@ RUN poetry config virtualenvs.create false \
1414

1515
COPY . .
1616

17+
RUN rm -f ./src/.env
18+
1719
CMD ["pytest", "-W", "ignore::DeprecationWarning", "-W", "ignore::RuntimeWarning", "-W", "ignore::UserWarning", "-v"]

packages/examples/cvat/recording-oracle/src/core/annotation_meta.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
1-
from collections.abc import Iterator
2-
from pathlib import Path
1+
from __future__ import annotations
2+
3+
from pathlib import Path # noqa: TCH003
4+
from typing import TYPE_CHECKING
35

46
from pydantic import BaseModel
57

8+
if TYPE_CHECKING:
9+
from collections.abc import Collection, Iterator
10+
611
ANNOTATION_RESULTS_METAFILE_NAME = "annotation_meta.json"
712
RESULTING_ANNOTATIONS_FILE = "resulting_annotations.zip"
813

@@ -24,5 +29,7 @@ def job_frame_range(self) -> Iterator[int]:
2429
class AnnotationMeta(BaseModel):
2530
jobs: list[JobMeta]
2631

27-
def skip_jobs(self, job_ids: list[int]):
28-
return AnnotationMeta(jobs=[job for job in self.jobs if job.job_id not in job_ids])
32+
def skip_assignments(self, assignment_ids: Collection[int]) -> AnnotationMeta:
33+
return AnnotationMeta(
34+
jobs=[job for job in self.jobs if job.assignment_id not in assignment_ids]
35+
)

packages/examples/cvat/recording-oracle/src/handlers/process_intermediate_results.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -375,13 +375,12 @@ def process_intermediate_results( # noqa: PLR0912
375375
), # should not happen, but waiting should not block processing
376376
)
377377
if task:
378-
# skip jobs that have already been accepted on the previous epochs
379-
accepted_cvat_job_ids = [
380-
validation_result.job.cvat_id
378+
# Skip assignments that were validated earlier
379+
validated_assignment_ids = {
380+
validation_result.assignment_id
381381
for validation_result in db_service.get_task_validation_results(session, task.id)
382-
if validation_result.annotation_quality >= manifest.validation.min_quality
383-
]
384-
unchecked_jobs_meta = meta.skip_jobs(accepted_cvat_job_ids)
382+
}
383+
unchecked_jobs_meta = meta.skip_assignments(validated_assignment_ids)
385384
else:
386385
# Recording Oracle task represents all CVAT tasks related to the escrow
387386
task_id = db_service.create_task(session, escrow_address=escrow_address, chain_id=chain_id)
@@ -566,6 +565,16 @@ def process_intermediate_results( # noqa: PLR0912
566565
else:
567566
assignment_validation_result_id = assignment_validation_result.id
568567

568+
# We consider only the last assignment as final, even if earlier assignments had a higher
569+
# quality score. The reason is that various task changes are possible while an escrow
570+
# is being annotated, for instance:
571+
# - GT can be changed in the middle of the task annotation
572+
# - manifest can be updated with different quality parameters
573+
# etc. So far, under the current system requirements, such changes occur mostly
574+
# in development or testing conditions, but they are likely to become
575+
# a normal requirement in the future.
576+
# Therefore, we use the following logic: only the last job assignment can be considered
577+
# a final annotation result, regardless of the assignment quality.
569578
job_final_result_ids[job.id] = assignment_validation_result_id
570579

571580
task_jobs = task.jobs

packages/examples/cvat/recording-oracle/tests/conftest.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33
import pytest
44
from fastapi.testclient import TestClient
5+
from sqlalchemy.orm import Session
56

67
from src import app
7-
from src.db import Base, engine
8+
from src.db import Base, SessionLocal, engine
89

910

1011
@pytest.fixture(autouse=True)
@@ -17,3 +18,14 @@ def db():
1718
def client() -> Generator:
1819
with TestClient(app) as c:
1920
yield c
21+
22+
23+
@pytest.fixture
24+
def session() -> Generator[Session, None, None]:
25+
session = SessionLocal()
26+
27+
try:
28+
yield session
29+
finally:
30+
session.rollback()
31+
session.close()

packages/examples/cvat/recording-oracle/tests/integration/services/test_validation_service.py

Lines changed: 225 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,21 @@
1+
import io
12
import random
23
import unittest
34
import uuid
5+
from contextlib import ExitStack
6+
from logging import Logger
7+
from types import SimpleNamespace
8+
from unittest import mock
49

10+
from sqlalchemy.orm import Session
11+
12+
from src.core.annotation_meta import AnnotationMeta, JobMeta
13+
from src.core.manifest import TaskManifest, parse_manifest
514
from src.core.types import Networks
15+
from src.core.validation_results import ValidationFailure, ValidationSuccess
16+
from src.cvat import api_calls as cvat_api
617
from src.db import SessionLocal
18+
from src.handlers.process_intermediate_results import process_intermediate_results
719
from src.services.validation import (
820
create_job,
921
create_task,
@@ -16,19 +28,19 @@
1628
get_validation_result_by_assignment_id,
1729
)
1830

31+
from tests.utils.constants import ESCROW_ADDRESS, WALLET_ADDRESS1
32+
1933

2034
class ServiceIntegrationTest(unittest.TestCase):
2135
def setUp(self):
2236
random.seed(42)
2337
self.session = SessionLocal()
24-
self.escrow_address = "0x" + "".join([str(random.randint(0, 9)) for _ in range(40)])
38+
self.escrow_address = ESCROW_ADDRESS
2539
self.chain_id = Networks.localhost
2640
self.cvat_id = 0
27-
self.annotator_wallet_address = "0x" + "".join(
28-
[str(random.randint(0, 9)) for _ in range(40)]
29-
)
41+
self.annotator_wallet_address = WALLET_ADDRESS1
3042
self.annotation_quality = 0.9
31-
self.assigment_id = str(uuid.uuid4())
43+
self.assignment_id = str(uuid.uuid4())
3244

3345
def tearDown(self):
3446
self.session.close()
@@ -67,12 +79,218 @@ def test_create_and_get_validation_result(self):
6779
job_id,
6880
self.annotator_wallet_address,
6981
self.annotation_quality,
70-
self.assigment_id,
82+
self.assignment_id,
7183
)
7284

73-
vr = get_validation_result_by_assignment_id(self.session, self.assigment_id)
85+
vr = get_validation_result_by_assignment_id(self.session, self.assignment_id)
7486
assert vr.id == vr_id
7587

7688
vrs = get_task_validation_results(self.session, task_id)
7789
assert len(vrs) == 1
7890
assert vrs[0] == vr
91+
92+
93+
class TestManifestChange:
94+
@staticmethod
95+
def _generate_manifest(*, min_quality: float = 0.8) -> TaskManifest:
96+
data = {
97+
"data": {"data_url": "http://localhost:9010/datasets/sample"},
98+
"annotation": {
99+
"labels": [{"name": "person"}],
100+
"description": "",
101+
"user_guide": "",
102+
"type": "image_points",
103+
"job_size": 10,
104+
},
105+
"validation": {
106+
"min_quality": min_quality,
107+
"val_size": 2,
108+
"gt_url": "http://localhost:9010/datasets/sample/annotations/sample_gt.json",
109+
},
110+
"job_bounty": "0.0001",
111+
}
112+
113+
return parse_manifest(data)
114+
115+
def test_can_handle_lowered_quality_requirements_in_manifest(self, session: Session):
116+
escrow_address = ESCROW_ADDRESS
117+
chain_id = Networks.localhost
118+
119+
min_quality1 = 0.8
120+
min_quality2 = 0.5
121+
frame_count = 10
122+
123+
manifest = self._generate_manifest(min_quality=min_quality1)
124+
125+
cvat_task_id = 1
126+
cvat_job_id = 1
127+
annotator1 = WALLET_ADDRESS1
128+
129+
assignment1_id = f"0x{0:040d}"
130+
assignment1_quality = 0.7
131+
132+
assignment2_id = f"0x{1:040d}"
133+
assignment2_quality = 0.6
134+
135+
# create a validation input
136+
with ExitStack() as common_lock_es:
137+
logger = mock.Mock(Logger)
138+
139+
mock_make_cloud_client = common_lock_es.enter_context(
140+
mock.patch("src.handlers.process_intermediate_results.make_cloud_client")
141+
)
142+
mock_make_cloud_client.return_value.download_file = mock.Mock(return_value=b"")
143+
144+
common_lock_es.enter_context(
145+
mock.patch("src.handlers.process_intermediate_results.BucketAccessInfo.parse_obj")
146+
)
147+
148+
mock_get_task_validation_layout = common_lock_es.enter_context(
149+
mock.patch(
150+
"src.handlers.process_intermediate_results.cvat_api.get_task_validation_layout"
151+
)
152+
)
153+
mock_get_task_validation_layout.return_value = mock.Mock(
154+
cvat_api.models.ITaskValidationLayoutRead,
155+
honeypot_frames=[0, 1],
156+
honeypot_real_frames=[0, 1],
157+
)
158+
159+
mock_get_task_data_meta = common_lock_es.enter_context(
160+
mock.patch("src.handlers.process_intermediate_results.cvat_api.get_task_data_meta")
161+
)
162+
mock_get_task_data_meta.return_value = mock.Mock(
163+
cvat_api.models.IDataMetaRead,
164+
frames=[SimpleNamespace(name=f"frame_{i}.jpg") for i in range(frame_count)],
165+
)
166+
167+
common_lock_es.enter_context(
168+
mock.patch("src.handlers.process_intermediate_results.dm.Dataset.import_from")
169+
)
170+
common_lock_es.enter_context(
171+
mock.patch("src.handlers.process_intermediate_results.extract_zip_archive")
172+
)
173+
common_lock_es.enter_context(
174+
mock.patch("src.handlers.process_intermediate_results.write_dir_to_zip_archive")
175+
)
176+
177+
def patched_prepare_merged_dataset(self):
178+
self._updated_merged_dataset_archive = io.BytesIO()
179+
180+
common_lock_es.enter_context(
181+
mock.patch(
182+
"src.handlers.process_intermediate_results._TaskValidator._prepare_merged_dataset",
183+
patched_prepare_merged_dataset,
184+
)
185+
)
186+
187+
annotation_meta = AnnotationMeta(
188+
jobs=[
189+
JobMeta(
190+
job_id=cvat_job_id,
191+
task_id=cvat_task_id,
192+
annotation_filename="",
193+
annotator_wallet_address=annotator1,
194+
assignment_id=assignment1_id,
195+
start_frame=0,
196+
stop_frame=manifest.annotation.job_size + manifest.validation.val_size,
197+
)
198+
]
199+
)
200+
201+
with (
202+
mock.patch(
203+
"src.handlers.process_intermediate_results.cvat_api.get_task_quality_report"
204+
) as mock_get_task_quality_report,
205+
mock.patch(
206+
"src.handlers.process_intermediate_results.cvat_api.get_quality_report_data"
207+
) as mock_get_quality_report_data,
208+
mock.patch(
209+
"src.handlers.process_intermediate_results.cvat_api.get_jobs_quality_reports"
210+
) as mock_get_jobs_quality_reports,
211+
):
212+
mock_get_task_quality_report.return_value = mock.Mock(
213+
cvat_api.models.IQualityReport, id=1
214+
)
215+
mock_get_quality_report_data.return_value = mock.Mock(
216+
cvat_api.QualityReportData,
217+
frame_results={
218+
"0": mock.Mock(annotations=mock.Mock(accuracy=assignment1_quality)),
219+
"1": mock.Mock(annotations=mock.Mock(accuracy=assignment1_quality)),
220+
},
221+
)
222+
mock_get_jobs_quality_reports.return_value = [
223+
mock.Mock(
224+
cvat_api.models.IQualityReport,
225+
job_id=1,
226+
summary=mock.Mock(accuracy=assignment1_quality),
227+
),
228+
]
229+
230+
vr1 = process_intermediate_results(
231+
session,
232+
escrow_address=escrow_address,
233+
chain_id=chain_id,
234+
meta=annotation_meta,
235+
merged_annotations=io.BytesIO(),
236+
manifest=manifest,
237+
logger=logger,
238+
)
239+
240+
assert isinstance(vr1, ValidationFailure)
241+
assert len(vr1.rejected_jobs) == 1
242+
243+
manifest.validation.min_quality = min_quality2
244+
245+
annotation_meta.jobs[0].assignment_id = assignment2_id
246+
247+
with (
248+
mock.patch(
249+
"src.handlers.process_intermediate_results.cvat_api.get_task_quality_report"
250+
) as mock_get_task_quality_report,
251+
mock.patch(
252+
"src.handlers.process_intermediate_results.cvat_api.get_quality_report_data"
253+
) as mock_get_quality_report_data,
254+
mock.patch(
255+
"src.handlers.process_intermediate_results.cvat_api.get_jobs_quality_reports"
256+
) as mock_get_jobs_quality_reports,
257+
):
258+
mock_get_task_quality_report.return_value = mock.Mock(
259+
cvat_api.models.IQualityReport, id=2
260+
)
261+
mock_get_quality_report_data.return_value = mock.Mock(
262+
cvat_api.QualityReportData,
263+
frame_results={
264+
"0": mock.Mock(annotations=mock.Mock(accuracy=assignment2_quality)),
265+
"1": mock.Mock(annotations=mock.Mock(accuracy=assignment2_quality)),
266+
},
267+
)
268+
mock_get_jobs_quality_reports.return_value = [
269+
mock.Mock(
270+
cvat_api.models.IQualityReport,
271+
job_id=1,
272+
summary=mock.Mock(accuracy=assignment2_quality),
273+
),
274+
]
275+
276+
vr2 = process_intermediate_results(
277+
session,
278+
escrow_address=escrow_address,
279+
chain_id=chain_id,
280+
meta=annotation_meta,
281+
merged_annotations=io.BytesIO(),
282+
manifest=manifest,
283+
logger=logger,
284+
)
285+
286+
assert isinstance(vr2, ValidationSuccess)
287+
assert vr2.job_results[cvat_job_id] == assignment2_quality
288+
289+
assert len(vr2.validation_meta.jobs) == 1
290+
assert len(vr2.validation_meta.results) == 2
291+
assert (
292+
vr2.validation_meta.results[
293+
vr2.validation_meta.jobs[0].final_result_id
294+
].annotation_quality
295+
== assignment2_quality
296+
)

packages/examples/cvat/recording-oracle/tests/utils/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,5 @@
9898
-----END PGP PUBLIC KEY BLOCK-----
9999
"""
100100
)
101+
102+
WALLET_ADDRESS1 = "0x5aAeb6053F3E94C9b9A09f33669435E7Ef1BeAed"

0 commit comments

Comments
 (0)