diff --git a/packages/examples/cvat/exchange-oracle/src/endpoints/authentication.py b/packages/examples/cvat/exchange-oracle/src/endpoints/authentication.py index 38ebcca10b..54d8039f4d 100644 --- a/packages/examples/cvat/exchange-oracle/src/endpoints/authentication.py +++ b/packages/examples/cvat/exchange-oracle/src/endpoints/authentication.py @@ -37,6 +37,10 @@ class AuthorizationData(BaseModel): email: str +class AssignmentAuthorizationData(AuthorizationData): + qualifications: list[str] + + AuthDataT = TypeVar("AuthDataT", bound=AuthorizationData) diff --git a/packages/examples/cvat/exchange-oracle/src/endpoints/exchange.py b/packages/examples/cvat/exchange-oracle/src/endpoints/exchange.py index 60c876c20f..e75b14ddaf 100644 --- a/packages/examples/cvat/exchange-oracle/src/endpoints/exchange.py +++ b/packages/examples/cvat/exchange-oracle/src/endpoints/exchange.py @@ -18,6 +18,7 @@ from src.db import SessionLocal from src.db import engine as db_engine from src.endpoints.authentication import ( + AssignmentAuthorizationData, AuthorizationData, AuthorizationParam, JobListAuthorizationData, @@ -395,17 +396,24 @@ def _page_serializer( description="Start an assignment within the task for the annotator", ) async def create_assignment( - data: AssignmentRequest, token: Annotated[AuthorizationData, AuthorizationParam] + data: AssignmentRequest, + token: Annotated[ + AssignmentAuthorizationData, make_auth_dependency(AssignmentAuthorizationData) + ], ) -> AssignmentResponse: try: assignment_id = oracle_service.create_assignment( escrow_address=data.escrow_address, chain_id=data.chain_id, wallet_address=token.wallet_address, + qualifications=token.qualifications, ) except oracle_service.UserHasUnfinishedAssignmentError as e: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e + except oracle_service.UserQualificationError as e: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=str(e)) from e + if not assignment_id: raise HTTPException( status_code=HTTPStatus.BAD_REQUEST, diff --git a/packages/examples/cvat/exchange-oracle/src/services/cvat.py b/packages/examples/cvat/exchange-oracle/src/services/cvat.py index 02abe39bbd..083b1e8efc 100644 --- a/packages/examples/cvat/exchange-oracle/src/services/cvat.py +++ b/packages/examples/cvat/exchange-oracle/src/services/cvat.py @@ -973,7 +973,7 @@ def get_user_assignments_in_cvat_projects( def has_active_user_assignments( session: Session, - wallet_address: int, + wallet_address: str, escrow_address: str, chain_id: int, ) -> bool: diff --git a/packages/examples/cvat/exchange-oracle/src/services/exchange.py b/packages/examples/cvat/exchange-oracle/src/services/exchange.py index 7d514597a0..5d6d1b3dda 100644 --- a/packages/examples/cvat/exchange-oracle/src/services/exchange.py +++ b/packages/examples/cvat/exchange-oracle/src/services/exchange.py @@ -3,11 +3,12 @@ import src.cvat.api_calls as cvat_api import src.services.cvat as cvat_service +from src.chain.escrow import get_escrow_manifest from src.core.types import JobStatuses, Networks, ProjectStatuses, TaskTypes from src.db import SessionLocal from src.db.utils import ForUpdateParams from src.models.cvat import Job -from src.utils.assignments import get_default_assignment_timeout +from src.utils.assignments import get_default_assignment_timeout, parse_manifest from src.utils.requests import get_or_404 from src.utils.time import utcnow @@ -20,7 +21,14 @@ def __str__(self) -> str: ) -def create_assignment(escrow_address: str, chain_id: Networks, wallet_address: str) -> str | None: +class UserQualificationError(Exception): + def __str__(self) -> str: + return "User doesn't have required qualifications." + + +def create_assignment( + escrow_address: str, chain_id: Networks, wallet_address: str, qualifications: list[str] +) -> str | None: with SessionLocal.begin() as session: user = get_or_404( cvat_service.get_user_by_id(session, wallet_address, for_update=True), @@ -28,6 +36,11 @@ def create_assignment(escrow_address: str, chain_id: Networks, wallet_address: s object_type_name="user", ) + manifest = parse_manifest(get_escrow_manifest(chain_id, escrow_address)) + + if not all(q in qualifications for q in manifest.qualifications): + raise UserQualificationError + if cvat_service.has_active_user_assignments( session, wallet_address=wallet_address, diff --git a/packages/examples/cvat/exchange-oracle/tests/api/test_exchange_api.py b/packages/examples/cvat/exchange-oracle/tests/api/test_exchange_api.py index 72b3300a13..71c812bc52 100644 --- a/packages/examples/cvat/exchange-oracle/tests/api/test_exchange_api.py +++ b/packages/examples/cvat/exchange-oracle/tests/api/test_exchange_api.py @@ -57,11 +57,15 @@ def generate_jwt_token( *, wallet_address: str | None = None, email: str = cvat_email, + qualifications: list[str] | None = None, private_key: str = PRIVATE_KEY, ) -> str: + if qualifications is None: + qualifications = [] data = { **({"wallet_address": wallet_address} if wallet_address else {"role": "human_app"}), "email": email, + "qualifications": qualifications, } return jwt.encode(data, private_key, algorithm="ES256") @@ -871,14 +875,16 @@ def test_can_create_assignment_200(client: TestClient, session: Session) -> None with ( open("tests/utils/manifest.json") as data, - patch("src.endpoints.serializers.get_escrow_manifest") as mock_get_manifest, + patch("src.endpoints.serializers.get_escrow_manifest") as mock_serializer_get_manifest, + patch("src.services.exchange.get_escrow_manifest") as mock_exchange_get_manifest, patch( "src.endpoints.serializers.get_escrow_fund_token_symbol" ) as mock_get_escrow_fund_token_symbol, patch("src.services.exchange.cvat_api") as cvat_api, ): manifest = json.load(data) - mock_get_manifest.return_value = manifest + mock_serializer_get_manifest.return_value = manifest + mock_exchange_get_manifest.return_value = manifest mock_get_escrow_fund_token_symbol.return_value = "HMT" assert {cvat_project.updated_at, cvat_task.updated_at, cvat_job.updated_at} == {None} @@ -972,17 +978,23 @@ def test_cannot_create_assignment_400_when_has_unfinished_assignments( session.commit() - response = client.post( - "/assignment", - headers=get_auth_header(), - json={ - "escrow_address": cvat_project.escrow_address, - "chain_id": cvat_project.chain_id, - }, - ) + with ( + open("tests/utils/manifest.json") as data, + patch("src.services.exchange.get_escrow_manifest") as mock_get_manifest, + ): + manifest = json.load(data) + mock_get_manifest.return_value = manifest + response = client.post( + "/assignment", + headers=get_auth_header(), + json={ + "escrow_address": cvat_project.escrow_address, + "chain_id": cvat_project.chain_id, + }, + ) - assert response.status_code == 400 - assert "There are unfinished assignments in this escrow" in response.text + assert response.status_code == 400 + assert "There are unfinished assignments in this escrow" in response.text def test_can_list_assignments_200(client: TestClient, session: Session) -> None: @@ -1501,14 +1513,16 @@ def test_can_list_jobs_200_check_updated_at(client: TestClient, session: Session with ( open("tests/utils/manifest.json") as data, - patch("src.endpoints.serializers.get_escrow_manifest") as mock_get_manifest, + patch("src.endpoints.serializers.get_escrow_manifest") as mock_serializer_get_manifest, + patch("src.services.exchange.get_escrow_manifest") as mock_exchange_get_manifest, patch( "src.endpoints.serializers.get_escrow_fund_token_symbol" ) as mock_get_escrow_fund_token_symbol, patch("src.services.exchange.cvat_api"), ): manifest = json.load(data) - mock_get_manifest.return_value = manifest + mock_serializer_get_manifest.return_value = manifest + mock_exchange_get_manifest.return_value = manifest mock_get_escrow_fund_token_symbol.return_value = "HMT" # create assignment in each job diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/services/test_exchange.py b/packages/examples/cvat/exchange-oracle/tests/integration/services/test_exchange.py index 590ff6db5e..e41b5cca89 100644 --- a/packages/examples/cvat/exchange-oracle/tests/integration/services/test_exchange.py +++ b/packages/examples/cvat/exchange-oracle/tests/integration/services/test_exchange.py @@ -97,9 +97,15 @@ def test_create_assignment(self): self.session.commit() - with patch("src.services.exchange.cvat_api"): + with ( + open("tests/utils/manifest.json") as data, + patch("src.services.exchange.get_escrow_manifest") as mock_get_manifest, + patch("src.services.exchange.cvat_api"), + ): + manifest = json.load(data) + mock_get_manifest.return_value = manifest assignment_id = create_assignment( - cvat_project.escrow_address, Networks(cvat_project.chain_id), user_address + cvat_project.escrow_address, Networks(cvat_project.chain_id), user_address, [] ) assignment = self.session.query(Assignment).filter_by(id=assignment_id).first() @@ -146,9 +152,15 @@ def test_create_assignment_many_jobs_1_completed(self): self.session.commit() - with patch("src.services.exchange.cvat_api"): + with ( + open("tests/utils/manifest.json") as data, + patch("src.services.exchange.get_escrow_manifest") as mock_get_manifest, + patch("src.services.exchange.cvat_api"), + ): + manifest = json.load(data) + mock_get_manifest.return_value = manifest assignment_id = create_assignment( - cvat_project.escrow_address, Networks(cvat_project.chain_id), user_address + cvat_project.escrow_address, Networks(cvat_project.chain_id), user_address, [] ) assignment = self.session.query(Assignment).filter_by(id=assignment_id).first() @@ -166,6 +178,7 @@ def test_create_assignment_invalid_user_address(self): cvat_project_1.escrow_address, Networks(cvat_project_1.chain_id), "invalid_address", + [], ) def test_create_assignment_invalid_project(self): @@ -178,8 +191,82 @@ def test_create_assignment_invalid_project(self): self.session.add(user) self.session.commit() - with pytest.raises(HTTPException, match="Can't find job"): - create_assignment("1", Networks.localhost, user_address) + with ( + open("tests/utils/manifest.json") as data, + patch("src.services.exchange.get_escrow_manifest") as mock_get_manifest, + ): + manifest = json.load(data) + mock_get_manifest.return_value = manifest + with pytest.raises(HTTPException, match="Can't find job"): + create_assignment("1", Networks.localhost, user_address, []) + + def test_create_assignment_no_required_qualifications(self): + user_address = WALLET_ADDRESS1 + user = User( + wallet_address=user_address, + cvat_email="test@hmt.ai", + cvat_id=1, + ) + self.session.add(user) + self.session.commit() + + with ( + open("tests/utils/manifest.json") as data, + patch("src.services.exchange.get_escrow_manifest") as mock_get_manifest, + ): + manifest = json.load(data) + manifest["qualifications"] = ["random_qualification"] + mock_get_manifest.return_value = manifest + with pytest.raises(Exception, match="User doesn't have required qualifications."): + create_assignment(ESCROW_ADDRESS, Networks.localhost, user_address, []) + + def test_create_assignment_with_required_qualifications(self): + cvat_project, cvat_task, cvat_job = create_project_task_and_job( + self.session, ESCROW_ADDRESS, 1 + ) + initial_job_updated_at = cvat_job.updated_at + initial_task_updated_at = cvat_task.updated_at + initial_project_updated_at = cvat_project.updated_at + + user_address = WALLET_ADDRESS1 + user = User( + wallet_address=user_address, + cvat_email="test@hmt.ai", + cvat_id=1, + ) + self.session.add(user) + + self.session.commit() + + with ( + open("tests/utils/manifest.json") as data, + patch("src.services.exchange.get_escrow_manifest") as mock_get_manifest, + patch("src.services.exchange.cvat_api"), + ): + manifest = json.load(data) + manifest["qualifications"] = ["test", "test2"] + mock_get_manifest.return_value = manifest + assignment_id = create_assignment( + cvat_project.escrow_address, + Networks(cvat_project.chain_id), + user_address, + ["test", "test2", "test3"], + ) + + assignment = self.session.query(Assignment).filter_by(id=assignment_id).first() + + assert assignment.cvat_job_id == cvat_job.cvat_id + assert assignment.user_wallet_address == user_address + assert assignment.status == AssignmentStatuses.created + + self.session.refresh(cvat_job) + assert cvat_job.updated_at != initial_job_updated_at + + self.session.refresh(cvat_task) + assert cvat_task.updated_at != initial_task_updated_at + + self.session.refresh(cvat_project) + assert cvat_project.updated_at != initial_project_updated_at def test_create_assignment_unfinished_assignment(self): _, _, cvat_job = create_project_task_and_job(self.session, ESCROW_ADDRESS, 1) @@ -205,10 +292,14 @@ def test_create_assignment_unfinished_assignment(self): self.session.commit() with ( + open("tests/utils/manifest.json") as data, + patch("src.services.exchange.get_escrow_manifest") as mock_get_manifest, patch("src.services.exchange.cvat_api"), - pytest.raises(Exception, match="unfinished assignment"), ): - create_assignment(ESCROW_ADDRESS, Networks.localhost, user_address) + manifest = json.load(data) + mock_get_manifest.return_value = manifest + with pytest.raises(Exception, match="unfinished assignment"): + create_assignment(ESCROW_ADDRESS, Networks.localhost, user_address, []) def test_create_assignment_has_expired_assignment_and_available_jobs(self): escrow_address = ESCROW_ADDRESS @@ -238,8 +329,16 @@ def test_create_assignment_has_expired_assignment_and_available_jobs(self): self.session.commit() - with patch("src.services.exchange.cvat_api"): - new_assignment_id = create_assignment(escrow_address, Networks.localhost, user_address) + with ( + open("tests/utils/manifest.json") as data, + patch("src.services.exchange.get_escrow_manifest") as mock_get_manifest, + patch("src.services.exchange.cvat_api"), + ): + manifest = json.load(data) + mock_get_manifest.return_value = manifest + new_assignment_id = create_assignment( + escrow_address, Networks.localhost, user_address, [] + ) new_assignment = self.session.query(Assignment).filter_by(id=new_assignment_id).first() assert new_assignment.cvat_job_id == cvat_job2.cvat_id # job1 was attempted already @@ -280,9 +379,15 @@ def test_create_assignment_no_available_jobs_completed_assignment(self): self.session.commit() - with patch("src.services.exchange.cvat_api"): + with ( + open("tests/utils/manifest.json") as data, + patch("src.services.exchange.get_escrow_manifest") as mock_get_manifest, + patch("src.services.exchange.cvat_api"), + ): + manifest = json.load(data) + mock_get_manifest.return_value = manifest assignment_id = create_assignment( - cvat_project.escrow_address, Networks(cvat_project.chain_id), user_address2 + cvat_project.escrow_address, Networks(cvat_project.chain_id), user_address2, [] ) assert assignment_id == None @@ -319,9 +424,15 @@ def test_create_assignment_no_available_jobs_active_foreign_assignment(self): self.session.commit() - with patch("src.services.exchange.cvat_api"): + with ( + open("tests/utils/manifest.json") as data, + patch("src.services.exchange.get_escrow_manifest") as mock_get_manifest, + patch("src.services.exchange.cvat_api"), + ): + manifest = json.load(data) + mock_get_manifest.return_value = manifest assignment_id = create_assignment( - cvat_project.escrow_address, Networks(cvat_project.chain_id), user_address2 + cvat_project.escrow_address, Networks(cvat_project.chain_id), user_address2, [] ) assert assignment_id == None @@ -351,11 +462,18 @@ def test_create_assignment_wont_reassign_job_to_previous_user(self): self.session.commit() - with patch("src.services.exchange.cvat_api"): + with ( + open("tests/utils/manifest.json") as data, + patch("src.services.exchange.get_escrow_manifest") as mock_get_manifest, + patch("src.services.exchange.cvat_api"), + ): + manifest = json.load(data) + mock_get_manifest.return_value = manifest assignment_id = create_assignment( cvat_project_1.escrow_address, Networks(cvat_project_1.chain_id), user.wallet_address, + [], ) assert assignment_id is None @@ -391,11 +509,18 @@ def test_create_assignment_can_assign_job_to_new_user(self): self.session.commit() - with patch("src.services.exchange.cvat_api"): + with ( + open("tests/utils/manifest.json") as data, + patch("src.services.exchange.get_escrow_manifest") as mock_get_manifest, + patch("src.services.exchange.cvat_api"), + ): + manifest = json.load(data) + mock_get_manifest.return_value = manifest assignment_id = create_assignment( cvat_project_1.escrow_address, Networks(cvat_project_1.chain_id), new_user.wallet_address, + [], ) assignment = self.session.get(Assignment, assignment_id) diff --git a/packages/examples/cvat/exchange-oracle/tests/utils/manifest.json b/packages/examples/cvat/exchange-oracle/tests/utils/manifest.json index 2577f20c38..950b15cd3f 100644 --- a/packages/examples/cvat/exchange-oracle/tests/utils/manifest.json +++ b/packages/examples/cvat/exchange-oracle/tests/utils/manifest.json @@ -15,5 +15,6 @@ "val_size": 2, "gt_url": "https://test.storage.googleapis.com" }, - "job_bounty": "5.001123929619726" + "job_bounty": "5.001123929619726", + "qualifications": [] }