Skip to content

Commit cd9289f

Browse files
authored
[CVAT] Add qualification validation for POST /assignment (#3565)
1 parent 2dbb5da commit cd9289f

File tree

7 files changed

+200
-35
lines changed

7 files changed

+200
-35
lines changed

packages/examples/cvat/exchange-oracle/src/endpoints/authentication.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ class AuthorizationData(BaseModel):
3737
email: str
3838

3939

40+
class AssignmentAuthorizationData(AuthorizationData):
41+
qualifications: list[str]
42+
43+
4044
AuthDataT = TypeVar("AuthDataT", bound=AuthorizationData)
4145

4246

packages/examples/cvat/exchange-oracle/src/endpoints/exchange.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from src.db import SessionLocal
1919
from src.db import engine as db_engine
2020
from src.endpoints.authentication import (
21+
AssignmentAuthorizationData,
2122
AuthorizationData,
2223
AuthorizationParam,
2324
JobListAuthorizationData,
@@ -395,17 +396,24 @@ def _page_serializer(
395396
description="Start an assignment within the task for the annotator",
396397
)
397398
async def create_assignment(
398-
data: AssignmentRequest, token: Annotated[AuthorizationData, AuthorizationParam]
399+
data: AssignmentRequest,
400+
token: Annotated[
401+
AssignmentAuthorizationData, make_auth_dependency(AssignmentAuthorizationData)
402+
],
399403
) -> AssignmentResponse:
400404
try:
401405
assignment_id = oracle_service.create_assignment(
402406
escrow_address=data.escrow_address,
403407
chain_id=data.chain_id,
404408
wallet_address=token.wallet_address,
409+
qualifications=token.qualifications,
405410
)
406411
except oracle_service.UserHasUnfinishedAssignmentError as e:
407412
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
408413

414+
except oracle_service.UserQualificationError as e:
415+
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=str(e)) from e
416+
409417
if not assignment_id:
410418
raise HTTPException(
411419
status_code=HTTPStatus.BAD_REQUEST,

packages/examples/cvat/exchange-oracle/src/services/cvat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -973,7 +973,7 @@ def get_user_assignments_in_cvat_projects(
973973

974974
def has_active_user_assignments(
975975
session: Session,
976-
wallet_address: int,
976+
wallet_address: str,
977977
escrow_address: str,
978978
chain_id: int,
979979
) -> bool:

packages/examples/cvat/exchange-oracle/src/services/exchange.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33

44
import src.cvat.api_calls as cvat_api
55
import src.services.cvat as cvat_service
6+
from src.chain.escrow import get_escrow_manifest
67
from src.core.types import JobStatuses, Networks, ProjectStatuses, TaskTypes
78
from src.db import SessionLocal
89
from src.db.utils import ForUpdateParams
910
from src.models.cvat import Job
10-
from src.utils.assignments import get_default_assignment_timeout
11+
from src.utils.assignments import get_default_assignment_timeout, parse_manifest
1112
from src.utils.requests import get_or_404
1213
from src.utils.time import utcnow
1314

@@ -20,14 +21,26 @@ def __str__(self) -> str:
2021
)
2122

2223

23-
def create_assignment(escrow_address: str, chain_id: Networks, wallet_address: str) -> str | None:
24+
class UserQualificationError(Exception):
25+
def __str__(self) -> str:
26+
return "User doesn't have required qualifications."
27+
28+
29+
def create_assignment(
30+
escrow_address: str, chain_id: Networks, wallet_address: str, qualifications: list[str]
31+
) -> str | None:
2432
with SessionLocal.begin() as session:
2533
user = get_or_404(
2634
cvat_service.get_user_by_id(session, wallet_address, for_update=True),
2735
wallet_address,
2836
object_type_name="user",
2937
)
3038

39+
manifest = parse_manifest(get_escrow_manifest(chain_id, escrow_address))
40+
41+
if not all(q in qualifications for q in manifest.qualifications):
42+
raise UserQualificationError
43+
3144
if cvat_service.has_active_user_assignments(
3245
session,
3346
wallet_address=wallet_address,

packages/examples/cvat/exchange-oracle/tests/api/test_exchange_api.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,15 @@ def generate_jwt_token(
5757
*,
5858
wallet_address: str | None = None,
5959
email: str = cvat_email,
60+
qualifications: list[str] | None = None,
6061
private_key: str = PRIVATE_KEY,
6162
) -> str:
63+
if qualifications is None:
64+
qualifications = []
6265
data = {
6366
**({"wallet_address": wallet_address} if wallet_address else {"role": "human_app"}),
6467
"email": email,
68+
"qualifications": qualifications,
6569
}
6670

6771
return jwt.encode(data, private_key, algorithm="ES256")
@@ -871,14 +875,16 @@ def test_can_create_assignment_200(client: TestClient, session: Session) -> None
871875

872876
with (
873877
open("tests/utils/manifest.json") as data,
874-
patch("src.endpoints.serializers.get_escrow_manifest") as mock_get_manifest,
878+
patch("src.endpoints.serializers.get_escrow_manifest") as mock_serializer_get_manifest,
879+
patch("src.services.exchange.get_escrow_manifest") as mock_exchange_get_manifest,
875880
patch(
876881
"src.endpoints.serializers.get_escrow_fund_token_symbol"
877882
) as mock_get_escrow_fund_token_symbol,
878883
patch("src.services.exchange.cvat_api") as cvat_api,
879884
):
880885
manifest = json.load(data)
881-
mock_get_manifest.return_value = manifest
886+
mock_serializer_get_manifest.return_value = manifest
887+
mock_exchange_get_manifest.return_value = manifest
882888
mock_get_escrow_fund_token_symbol.return_value = "HMT"
883889

884890
assert {cvat_project.updated_at, cvat_task.updated_at, cvat_job.updated_at} == {None}
@@ -972,17 +978,23 @@ def test_cannot_create_assignment_400_when_has_unfinished_assignments(
972978

973979
session.commit()
974980

975-
response = client.post(
976-
"/assignment",
977-
headers=get_auth_header(),
978-
json={
979-
"escrow_address": cvat_project.escrow_address,
980-
"chain_id": cvat_project.chain_id,
981-
},
982-
)
981+
with (
982+
open("tests/utils/manifest.json") as data,
983+
patch("src.services.exchange.get_escrow_manifest") as mock_get_manifest,
984+
):
985+
manifest = json.load(data)
986+
mock_get_manifest.return_value = manifest
987+
response = client.post(
988+
"/assignment",
989+
headers=get_auth_header(),
990+
json={
991+
"escrow_address": cvat_project.escrow_address,
992+
"chain_id": cvat_project.chain_id,
993+
},
994+
)
983995

984-
assert response.status_code == 400
985-
assert "There are unfinished assignments in this escrow" in response.text
996+
assert response.status_code == 400
997+
assert "There are unfinished assignments in this escrow" in response.text
986998

987999

9881000
def test_can_list_assignments_200(client: TestClient, session: Session) -> None:
@@ -1501,14 +1513,16 @@ def test_can_list_jobs_200_check_updated_at(client: TestClient, session: Session
15011513

15021514
with (
15031515
open("tests/utils/manifest.json") as data,
1504-
patch("src.endpoints.serializers.get_escrow_manifest") as mock_get_manifest,
1516+
patch("src.endpoints.serializers.get_escrow_manifest") as mock_serializer_get_manifest,
1517+
patch("src.services.exchange.get_escrow_manifest") as mock_exchange_get_manifest,
15051518
patch(
15061519
"src.endpoints.serializers.get_escrow_fund_token_symbol"
15071520
) as mock_get_escrow_fund_token_symbol,
15081521
patch("src.services.exchange.cvat_api"),
15091522
):
15101523
manifest = json.load(data)
1511-
mock_get_manifest.return_value = manifest
1524+
mock_serializer_get_manifest.return_value = manifest
1525+
mock_exchange_get_manifest.return_value = manifest
15121526
mock_get_escrow_fund_token_symbol.return_value = "HMT"
15131527

15141528
# create assignment in each job

0 commit comments

Comments
 (0)