Skip to content

Commit 916cb0d

Browse files
committed
Update recording oracle
1 parent 95a80c1 commit 916cb0d

File tree

8 files changed

+206
-88
lines changed

8 files changed

+206
-88
lines changed

packages/examples/cvat/recording-oracle/alembic/versions/9d4367899f90_recreate_gt_stats.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,19 @@ def upgrade() -> None:
2121
# ### commands auto generated by Alembic - please adjust! ###
2222
op.drop_index("ix_gt_stats_gt_key", table_name="gt_stats")
2323
op.drop_table("gt_stats")
24-
# ### end Alembic commands ###
2524

2625
op.create_table(
2726
"gt_stats",
28-
sa.Column("task_id", sa.VARCHAR(), autoincrement=False, nullable=False),
29-
sa.Column("cvat_task_id", sa.INTEGER(), autoincrement=False, nullable=False),
30-
sa.Column("gt_frame_id", sa.INTEGER(), autoincrement=False, nullable=False),
31-
sa.Column("failed_attempts", sa.INTEGER(), autoincrement=False, nullable=False),
32-
sa.ForeignKeyConstraint(
33-
["task_id"], ["tasks.id"], name="gt_stats_task_id_fkey", ondelete="CASCADE"
34-
),
35-
sa.PrimaryKeyConstraint("task_id", "gt_frame_id", name="gt_stats_pkey"),
27+
sa.Column("task_id", sa.String(), nullable=False),
28+
sa.Column("cvat_task_id", sa.Integer(), nullable=False),
29+
sa.Column("gt_frame_id", sa.Integer(), nullable=False),
30+
sa.Column("failed_attempts", sa.Integer(), nullable=False),
31+
sa.Column("accepted_attempts", sa.Integer(), nullable=False),
32+
sa.Column("accumulated_quality", sa.Float(), nullable=False),
33+
sa.ForeignKeyConstraint(["task_id"], ["tasks.id"], ondelete="CASCADE"),
34+
sa.PrimaryKeyConstraint("task_id", "gt_frame_id"),
3635
)
36+
# ### end Alembic commands ###
3737

3838

3939
def downgrade() -> None:

packages/examples/cvat/recording-oracle/src/.env.template

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,5 +78,7 @@ PGP_PRIVATE_KEY=
7878
PGP_PASSPHRASE=
7979
PGP_PUBLIC_KEY_URL=
8080

81+
CVAT_URL=
8182
CVAT_ADMIN=
82-
CVAT_ADMIN_PASS=
83+
CVAT_ADMIN_PASS=
84+
CVAT_ORG_SLUG=

packages/examples/cvat/recording-oracle/src/core/annotation_meta.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections.abc import Iterator
12
from pathlib import Path
23

34
from pydantic import BaseModel
@@ -12,7 +13,19 @@ class JobMeta(BaseModel):
1213
annotation_filename: Path
1314
annotator_wallet_address: str
1415
assignment_id: str
16+
start_frame: int
17+
stop_frame: int
18+
19+
@property
20+
def job_frame_range(self) -> Iterator[int]:
21+
return range(self.start_frame, self.stop_frame + 1)
1522

1623

1724
class AnnotationMeta(BaseModel):
1825
jobs: list[JobMeta]
26+
27+
def skip_jobs(self, job_ids: list[int]):
28+
# self.jobs = [
29+
# job for job in self.jobs if job.job_id not in job_ids
30+
# ]
31+
return AnnotationMeta(jobs=[job for job in self.jobs if job.job_id not in job_ids])
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from attrs import define
2+
3+
4+
@define(kw_only=True)
5+
class ValidationFrameStats:
6+
accumulated_quality: float = 0.0
7+
failed_attempts: int = 0
8+
accepted_attempts: int = 0
9+
10+
@property
11+
def average_quality(self) -> float:
12+
return self.accumulated_quality / ((self.failed_attempts + self.accepted_attempts) or 1)
13+
14+
15+
_TaskIdValFrameIdPair = tuple[int, int]
16+
17+
GtStats = dict[_TaskIdValFrameIdPair, ValidationFrameStats]

packages/examples/cvat/recording-oracle/src/cvat/api_calls.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -127,14 +127,13 @@ def get_task_validation_layout(task_id: int) -> models.TaskValidationLayoutRead:
127127
raise
128128

129129

130-
def get_jobs_quality_reports(parent_id: int) -> dict[int, models.QualityReport]:
130+
def get_jobs_quality_reports(parent_id: int) -> list[models.QualityReport]:
131131
logger = logging.getLogger("app")
132132
with get_api_client() as api_client:
133133
try:
134-
reports: list[models.QualityReport] = get_paginated_collection(
134+
return get_paginated_collection(
135135
api_client.quality_api.list_reports_endpoint, parent_id=parent_id, target="job"
136136
)
137-
return {report.job_id: report for report in reports}
138137

139138
except exceptions.ApiException as e:
140139
logger.exception(f"Exception when calling QualityApi.list_reports: {e}\n")
@@ -145,21 +144,18 @@ def update_task_validation_layout(
145144
task_id: int,
146145
*,
147146
disabled_frames: list[int],
148-
shuffle_honeypots: bool = True,
147+
honeypot_real_frames: list[int],
149148
) -> None:
150149
logger = logging.getLogger("app")
151-
params = {
152-
"disabled_frames": disabled_frames,
153-
}
154-
if shuffle_honeypots:
155-
params["frame_selection_method"] = models.FrameSelectionMethod("random_uniform")
156150

157151
with get_api_client() as api_client:
158152
try:
159153
validation_layout, _ = api_client.tasks_api.partial_update_validation_layout(
160154
task_id,
161155
patched_task_validation_layout_write_request=models.PatchedTaskValidationLayoutWriteRequest(
162-
**params
156+
disabled_frames=disabled_frames,
157+
honeypot_real_frames=honeypot_real_frames,
158+
frame_selection_method=models.FrameSelectionMethod("manual"),
163159
),
164160
)
165161
except exceptions.ApiException as ex:

0 commit comments

Comments
 (0)