Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ class AuthorizationData(BaseModel):
email: str


class AssignmentAuthorizationData(AuthorizationData):
qualifications: list[str]


AuthDataT = TypeVar("AuthDataT", bound=AuthorizationData)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from src.db import SessionLocal
from src.db import engine as db_engine
from src.endpoints.authentication import (
AssignmentAuthorizationData,
AuthorizationData,
AuthorizationParam,
JobListAuthorizationData,
Expand Down Expand Up @@ -395,17 +396,24 @@ def _page_serializer(
description="Start an assignment within the task for the annotator",
)
async def create_assignment(
data: AssignmentRequest, token: Annotated[AuthorizationData, AuthorizationParam]
data: AssignmentRequest,
token: Annotated[
AssignmentAuthorizationData, make_auth_dependency(AssignmentAuthorizationData)
],
) -> AssignmentResponse:
try:
assignment_id = oracle_service.create_assignment(
escrow_address=data.escrow_address,
chain_id=data.chain_id,
wallet_address=token.wallet_address,
qualifications=token.qualifications,
)
except oracle_service.UserHasUnfinishedAssignmentError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e

except oracle_service.UserQualificationError as e:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=str(e)) from e

if not assignment_id:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -973,7 +973,7 @@ def get_user_assignments_in_cvat_projects(

def has_active_user_assignments(
session: Session,
wallet_address: int,
wallet_address: str,
escrow_address: str,
chain_id: int,
) -> bool:
Expand Down
17 changes: 15 additions & 2 deletions packages/examples/cvat/exchange-oracle/src/services/exchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@

import src.cvat.api_calls as cvat_api
import src.services.cvat as cvat_service
from src.chain.escrow import get_escrow_manifest
from src.core.types import JobStatuses, Networks, ProjectStatuses, TaskTypes
from src.db import SessionLocal
from src.db.utils import ForUpdateParams
from src.models.cvat import Job
from src.utils.assignments import get_default_assignment_timeout
from src.utils.assignments import get_default_assignment_timeout, parse_manifest
from src.utils.requests import get_or_404
from src.utils.time import utcnow

Expand All @@ -20,14 +21,26 @@ def __str__(self) -> str:
)


def create_assignment(escrow_address: str, chain_id: Networks, wallet_address: str) -> str | None:
class UserQualificationError(Exception):
def __str__(self) -> str:
return "User doesn't have required qualifications."


def create_assignment(
escrow_address: str, chain_id: Networks, wallet_address: str, qualifications: list[str]
) -> str | None:
with SessionLocal.begin() as session:
user = get_or_404(
cvat_service.get_user_by_id(session, wallet_address, for_update=True),
wallet_address,
object_type_name="user",
)

manifest = parse_manifest(get_escrow_manifest(chain_id, escrow_address))

if not all(q in qualifications for q in manifest.qualifications):
raise UserQualificationError

if cvat_service.has_active_user_assignments(
session,
wallet_address=wallet_address,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,15 @@ def generate_jwt_token(
*,
wallet_address: str | None = None,
email: str = cvat_email,
qualifications: list[str] | None = None,
private_key: str = PRIVATE_KEY,
) -> str:
if qualifications is None:
qualifications = []
data = {
**({"wallet_address": wallet_address} if wallet_address else {"role": "human_app"}),
"email": email,
"qualifications": qualifications,
}

return jwt.encode(data, private_key, algorithm="ES256")
Expand Down Expand Up @@ -871,14 +875,16 @@ def test_can_create_assignment_200(client: TestClient, session: Session) -> None

with (
open("tests/utils/manifest.json") as data,
patch("src.endpoints.serializers.get_escrow_manifest") as mock_get_manifest,
patch("src.endpoints.serializers.get_escrow_manifest") as mock_serializer_get_manifest,
patch("src.services.exchange.get_escrow_manifest") as mock_exchange_get_manifest,
patch(
"src.endpoints.serializers.get_escrow_fund_token_symbol"
) as mock_get_escrow_fund_token_symbol,
patch("src.services.exchange.cvat_api") as cvat_api,
):
manifest = json.load(data)
mock_get_manifest.return_value = manifest
mock_serializer_get_manifest.return_value = manifest
mock_exchange_get_manifest.return_value = manifest
mock_get_escrow_fund_token_symbol.return_value = "HMT"

assert {cvat_project.updated_at, cvat_task.updated_at, cvat_job.updated_at} == {None}
Expand Down Expand Up @@ -972,17 +978,23 @@ def test_cannot_create_assignment_400_when_has_unfinished_assignments(

session.commit()

response = client.post(
"/assignment",
headers=get_auth_header(),
json={
"escrow_address": cvat_project.escrow_address,
"chain_id": cvat_project.chain_id,
},
)
with (
open("tests/utils/manifest.json") as data,
patch("src.services.exchange.get_escrow_manifest") as mock_get_manifest,
):
manifest = json.load(data)
mock_get_manifest.return_value = manifest
response = client.post(
"/assignment",
headers=get_auth_header(),
json={
"escrow_address": cvat_project.escrow_address,
"chain_id": cvat_project.chain_id,
},
)

assert response.status_code == 400
assert "There are unfinished assignments in this escrow" in response.text
assert response.status_code == 400
assert "There are unfinished assignments in this escrow" in response.text


def test_can_list_assignments_200(client: TestClient, session: Session) -> None:
Expand Down Expand Up @@ -1501,14 +1513,16 @@ def test_can_list_jobs_200_check_updated_at(client: TestClient, session: Session

with (
open("tests/utils/manifest.json") as data,
patch("src.endpoints.serializers.get_escrow_manifest") as mock_get_manifest,
patch("src.endpoints.serializers.get_escrow_manifest") as mock_serializer_get_manifest,
patch("src.services.exchange.get_escrow_manifest") as mock_exchange_get_manifest,
patch(
"src.endpoints.serializers.get_escrow_fund_token_symbol"
) as mock_get_escrow_fund_token_symbol,
patch("src.services.exchange.cvat_api"),
):
manifest = json.load(data)
mock_get_manifest.return_value = manifest
mock_serializer_get_manifest.return_value = manifest
mock_exchange_get_manifest.return_value = manifest
mock_get_escrow_fund_token_symbol.return_value = "HMT"

# create assignment in each job
Expand Down
Loading