diff --git a/apps/commons/views.py b/apps/commons/views.py index 647ecb42..00cd5f14 100644 --- a/apps/commons/views.py +++ b/apps/commons/views.py @@ -1,9 +1,18 @@ +from django.db.models import QuerySet from django.shortcuts import get_object_or_404 from rest_framework import mixins, viewsets +from rest_framework.permissions import IsAuthenticated, IsAuthenticatedOrReadOnly from rest_framework.response import Response from rest_framework.settings import api_settings +from apps.accounts.models import ProjectUser +from apps.accounts.permissions import HasBasePermission +from apps.commons.permissions import IsOwner, ReadOnly, WillBeOwner +from apps.commons.utils import map_action_to_permission from apps.organizations.models import Organization +from apps.organizations.permissions import HasOrganizationPermission +from apps.projects.models import Project +from apps.projects.permissions import HasProjectPermission, ProjectIsNotLocked from .mixins import HasMultipleIDs @@ -143,10 +152,252 @@ def get_paginated_list(self, queryset): return Response(serializer.data) -class NestedOrganizationViewMixins: - def initial(self, request, *args, **kwargs): - self.organization = get_object_or_404( - Organization, code=kwargs["organization_code"] +class OrganizationRelatedViewset(viewsets.GenericViewSet): + organization_code_url_kwarg: str = "organization_code" + queryset_organization_field: str = "organization" + + read_only_permissions: bool = True + permissions_app_label: str = "" + permissions_base_codename: str = "" + + def get_permissions(self): + if self.permissions_base_codename and self.permissions_app_label: + codename = map_action_to_permission( + self.action, self.permissions_base_codename + ) + if self.read_only_permissions: + return [ + IsAuthenticatedOrReadOnly, + ReadOnly + | HasBasePermission(codename, self.permissions_app_label) + | HasOrganizationPermission(codename), + ] + return [ + IsAuthenticated, + HasBasePermission(codename, self.permissions_app_label) + | HasOrganizationPermission(codename), + ] + return super().get_permissions() + + def organization_filter_queryset(self, queryset: "QuerySet") -> "QuerySet": + """ + Filter the given queryset by the organization specified in the URL. + """ + return queryset.filter(**{self.queryset_organization_field: self.organization}) + + def get_queryset(self): + """ + Return the queryset for this viewset, filtered by the organization specified + in the URL. + """ + return self.organization_filter_queryset(super().get_queryset()) + + def get_serializer_context(self): + return { + **super().get_serializer_context(), + "organization": self.organization, + } + + @property + def organization(self) -> Organization: + if not hasattr(self, "_organization"): + if self.organization_code_url_kwarg not in self.kwargs: + raise ValueError( + f"URL kwarg '{self.organization_code_url_kwarg}' is required for a" + f" viewset based on OrganizationRelatedViewset." + ) + self._organization = get_object_or_404( + Organization, code=self.kwargs[self.organization_code_url_kwarg] + ) + return self._organization + + +class ProjectRelatedViewset(MultipleIDViewsetMixin, OrganizationRelatedViewset): + """ + A viewset for models relared to a project. + + This viewset should only be accessed through a URL containing the `project_id` and + `organization_code` kwargs. + e.g. `/v1/organizations/{organization_code}/projects/{project_id}/my_model/` + + The viewset automatically handles filtering using the request user's permissions, + and it provides the project in the serializer context. + + Attributes : + ------------ + organization_code_url_kwarg: str (default: "organization_code") + The name of the URL kwarg containing the organization code. + project_id_url_kwarg: str (default: "project_id") + The name of the URL kwarg containing the project id. + queryset_organization_field: str (default: "project__organizations") + The name of the field to use for filtering the queryset by organization. + queryset_project_field: str (default: "project") + The name of the field to use for filtering the queryset by project. + read_only_permissions: bool (default: True) + Whether the viewset should use read-only permissions. This is useful when the + read permissions are handled at the instance level. + block_if_project_is_locked: bool (default: True) + Whether to block all actions if the project is locked. + permissions_app_label: str (default: "") + The app label to use in the default permissions check + permissions_base_codename: str (default: "") + The base codename to use for generating the permissions to check. If not set, + the `permissions_codename` attribute will be used as the codename for all actions. + permissions_codename: str (default: "change_project") + The codename to use for the default permissions check if`permissions_base_codename` + is not set. This can be used if the same permission is used for all actions. + multiple_lookup_fields: list of tuple[HasMultipleIDs, str] (default: []) + Inherited from MultipleIDViewsetMixin. A list of tuples containing a model that + inherits from HasMultipleIDs and the name of the URL kwarg containing the id to + transform into the main id. + """ + + project_id_url_kwarg: str = "project_id" + queryset_organization_field: str = "project__organizations" + queryset_project_field: str = "project" + + read_only_permissions: bool = True + block_if_project_is_locked: bool = True + permissions_app_label: str = "projects" + permissions_base_codename: str = "" + permissions_codename: str = "change_project" + + multiple_lookup_fields = [ + (Project, "project_id"), + ] + + def get_permissions(self): + if self.permissions_base_codename: + codename = map_action_to_permission( + self.action, self.permissions_base_codename + ) + else: + codename = self.permissions_codename + if codename and self.permissions_app_label: + if self.read_only_permissions: + permissions = [ + IsAuthenticatedOrReadOnly, + ReadOnly + | HasBasePermission(codename, self.permissions_app_label) + | HasOrganizationPermission(codename) + | HasProjectPermission(codename), + ] + else: + permissions = [ + IsAuthenticated, + HasBasePermission(codename, self.permissions_app_label) + | HasOrganizationPermission(codename) + | HasProjectPermission(codename), + ] + if self.block_if_project_is_locked: + permissions.insert(1, ProjectIsNotLocked) + return permissions + return super().get_permissions() + + def project_filter_queryset(self, queryset: "QuerySet") -> "QuerySet": + """ + Filter the given queryset by the project specified in the URL. + """ + return self.request.user.get_project_related_queryset( + queryset.filter(**{self.queryset_project_field: self.project}), + self.queryset_project_field, ) - super().initial(request, *args, **kwargs) + def get_queryset(self): + """ + Return the queryset for this viewset, filtered by the project and the + organization specified in the URL. + """ + return self.project_filter_queryset(super().get_queryset()) + + def get_serializer_context(self): + return { + **super().get_serializer_context(), + "project": self.project, + } + + @property + def project(self) -> Project: + if not hasattr(self, "_project"): + if self.project_id_url_kwarg not in self.kwargs: + raise ValueError( + f"URL kwarg '{self.project_id_url_kwarg}' is required for a" + f" viewset based on ProjectRelatedViewset." + ) + self._project = get_object_or_404( + Project, id=self.kwargs[self.project_id_url_kwarg] + ) + return self._project + + +class UserRelatedViewset(OrganizationRelatedViewset): + user_id_url_kwarg: str = "user_id" + queryset_organization_field: str = "user__groups__organizations" + queryset_user_field: str = "user" + + read_only_permissions: bool = True + permissions_app_label: str = "accounts" + permissions_base_codename: str = "" + permissions_codename: str = "change_projectuser" + + def get_permissions(self): + if self.permissions_base_codename: + codename = map_action_to_permission( + self.action, self.permissions_base_codename + ) + else: + codename = self.permissions_codename + if codename and self.permissions_app_label: + if self.read_only_permissions: + return [ + IsAuthenticatedOrReadOnly, + ReadOnly + | IsOwner + | WillBeOwner + | HasBasePermission(codename, self.permissions_app_label) + | HasOrganizationPermission(codename), + ] + return [ + IsAuthenticated, + IsOwner + | WillBeOwner + | HasBasePermission(codename, self.permissions_app_label) + | HasOrganizationPermission(codename), + ] + return super().get_permissions() + + def user_filter_queryset(self, queryset: "QuerySet") -> "QuerySet": + """ + Filter the given queryset by the user specified in the URL and by the read + permimssions given to the request user. + """ + return self.request.user.get_user_related_queryset( + queryset.filter(**{self.queryset_user_field: self.user}), + self.queryset_user_field, + ) + + def get_queryset(self): + """ + Return the queryset for this viewset, filtered by the user specified in the URL + and by the read permimssions given to the request user. + """ + return self.user_filter_queryset(super().get_queryset()) + + def get_serializer_context(self): + return { + **super().get_serializer_context(), + "user": self.user, + } + + @property + def user(self) -> ProjectUser: + if not hasattr(self, "_user"): + if self.user_id_url_kwarg not in self.kwargs: + raise ValueError( + f"URL kwarg '{self.user_id_url_kwarg}' is required for a" + f" viewset based on UserRelatedViewset." + ) + self._user = get_object_or_404( + ProjectUser, id=self.kwargs[self.user_id_url_kwarg] + ) + return self._user diff --git a/apps/projects/permissions.py b/apps/projects/permissions.py index af778b15..2e7dd01e 100644 --- a/apps/projects/permissions.py +++ b/apps/projects/permissions.py @@ -76,7 +76,7 @@ def has_object_permission( if not project: project = self.get_related_project(request, view) if project and app: - request.user.has_perm(f"{app}.{codename}", project) + return request.user.has_perm(f"{app}.{codename}", project) if project: return request.user.has_perm(codename, project) return False diff --git a/apps/projects/tests/views/test_project.py b/apps/projects/tests/views/test_project.py index b9b0eb36..5c248ec6 100644 --- a/apps/projects/tests/views/test_project.py +++ b/apps/projects/tests/views/test_project.py @@ -91,7 +91,9 @@ def test_create_project(self, role, expected_code): "owner_groups": [pg.id for pg in self.owner_groups], }, } - response = self.client.post(reverse("Project-list"), data=payload) + response = self.client.post( + reverse("Project-list", args=(self.organization.code,)), data=payload + ) self.assertEqual(response.status_code, expected_code) if expected_code == status.HTTP_201_CREATED: content = response.json() @@ -195,7 +197,8 @@ def test_update_project(self, role, expected_code): "template_id": self.template.id, } response = self.client.patch( - reverse("Project-detail", args=(self.project.id,)), data=payload + reverse("Project-detail", args=(self.organization.code, self.project.id)), + data=payload, ) self.assertEqual(response.status_code, expected_code) if expected_code == status.HTTP_200_OK: @@ -249,7 +252,8 @@ def test_update_project_only_reviewer_can_update(self, role, expected_code): "publication_status": Project.PublicationStatus.PUBLIC, } response = self.client.patch( - reverse("Project-detail", args=(project.id,)), data=payload + reverse("Project-detail", args=(self.organization.code, project.id)), + data=payload, ) self.assertEqual(response.status_code, expected_code) content = response.json() @@ -294,7 +298,9 @@ def test_delete_project(self, role, expected_code): project = ProjectFactory(organizations=[self.organization]) user = self.get_parameterized_test_user(role, instances=[project]) self.client.force_authenticate(user) - response = self.client.delete(reverse("Project-detail", args=(project.id,))) + response = self.client.delete( + reverse("Project-detail", args=(self.organization.code, project.id)) + ) self.assertEqual(response.status_code, expected_code) if expected_code == status.HTTP_204_NO_CONTENT: project.refresh_from_db() @@ -349,7 +355,7 @@ def test_add_project_member(self, role, expected_code): "reviewer_groups": [pg.id for pg in self.reviewer_groups], } response = self.client.post( - reverse("Project-add-member", args=(project.id,)), + reverse("Project-add-member", args=(self.organization.code, project.id)), data=payload, ) self.assertEqual(response.status_code, expected_code) @@ -402,7 +408,7 @@ def test_remove_project_member(self, role, expected_code): ], } response = self.client.post( - reverse("Project-remove-member", args=(project.id,)), + reverse("Project-remove-member", args=(self.organization.code, project.id)), data=payload, ) self.assertEqual(response.status_code, expected_code) @@ -427,7 +433,7 @@ def test_remove_project_member_self(self): project.members.add(to_delete) self.client.force_authenticate(to_delete) response = self.client.delete( - reverse("Project-remove-self", args=(project.id,)) + reverse("Project-remove-self", args=(self.organization.code, project.id)) ) self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) self.assertNotIn(to_delete, project.members.all()) @@ -594,12 +600,14 @@ def test_duplicate_project(self, role, expected_code): user = self.get_parameterized_test_user(role, instances=[self.project]) self.client.force_authenticate(user) response = self.client.post( - reverse("Project-duplicate", args=(self.project.id,)) + reverse("Project-duplicate", args=(self.organization.code, self.project.id)) ) self.assertEqual(response.status_code, expected_code) if expected_code == status.HTTP_201_CREATED: initial_response = self.client.get( - reverse("Project-detail", args=(self.project.id,)) + reverse( + "Project-detail", args=(self.organization.code, self.project.id) + ) ) self.assertEqual(initial_response.status_code, status.HTTP_200_OK) duplicated_project = response.json() @@ -633,7 +641,9 @@ def test_lock_project(self, role, expected_code): project = ProjectFactory(organizations=[self.organization], is_locked=False) user = self.get_parameterized_test_user(role, instances=[project]) self.client.force_authenticate(user) - response = self.client.post(reverse("Project-lock", args=(project.id,))) + response = self.client.post( + reverse("Project-lock", args=(self.organization.code, project.id)) + ) self.assertEqual(response.status_code, expected_code) if expected_code == status.HTTP_200_OK: project.refresh_from_db() @@ -656,7 +666,9 @@ def test_unlock_project(self, role, expected_code): project = ProjectFactory(organizations=[self.organization], is_locked=True) user = self.get_parameterized_test_user(role, instances=[project]) self.client.force_authenticate(user) - response = self.client.post(reverse("Project-unlock", args=(project.id,))) + response = self.client.post( + reverse("Project-unlock", args=(self.organization.code, project.id)) + ) self.assertEqual(response.status_code, expected_code) if expected_code == status.HTTP_200_OK: project.refresh_from_db() @@ -667,12 +679,10 @@ class FilterSearchOrderProjectTestCase(JwtAPITestCase): @classmethod def setUpTestData(cls): super().setUpTestData() - cls.organization_1 = OrganizationFactory() - cls.organization_2 = OrganizationFactory() - cls.organization_3 = OrganizationFactory(parent=cls.organization_1) - cls.category_1 = ProjectCategoryFactory(organization=cls.organization_1) - cls.category_2 = ProjectCategoryFactory(organization=cls.organization_2) - cls.category_3 = ProjectCategoryFactory(organization=cls.organization_3) + cls.organization = OrganizationFactory() + cls.category_1 = ProjectCategoryFactory(organization=cls.organization) + cls.category_2 = ProjectCategoryFactory(organization=cls.organization) + cls.category_3 = ProjectCategoryFactory(organization=cls.organization) cls.tag_1 = TagFactory() cls.tag_2 = TagFactory() cls.tag_3 = TagFactory() @@ -681,21 +691,21 @@ def setUpTestData(cls): cls.date_3 = make_aware(datetime.datetime(2022, 1, 1)) cls.project_1 = ProjectFactory( - organizations=[cls.organization_1], + organizations=[cls.organization], categories=[cls.category_1], language="fr", sdgs=[1, 2], life_status=Project.LifeStatus.TO_REVIEW, ) cls.project_2 = ProjectFactory( - organizations=[cls.organization_2], + organizations=[cls.organization], categories=[cls.category_2], language="en", sdgs=[2, 3], life_status=Project.LifeStatus.RUNNING, ) cls.project_3 = ProjectFactory( - organizations=[cls.organization_3], + organizations=[cls.organization], categories=[cls.category_3], language="en", sdgs=[3, 4], @@ -713,9 +723,9 @@ def setUpTestData(cls): cls.user_3 = UserFactory( groups=[cls.project_2.get_owners(), cls.project_3.get_owners()] ) - cls.people_group_1 = PeopleGroupFactory(organization=cls.organization_1) - cls.people_group_2 = PeopleGroupFactory(organization=cls.organization_2) - cls.people_group_3 = PeopleGroupFactory(organization=cls.organization_3) + cls.people_group_1 = PeopleGroupFactory(organization=cls.organization) + cls.people_group_2 = PeopleGroupFactory(organization=cls.organization) + cls.people_group_3 = PeopleGroupFactory(organization=cls.organization) cls.project_1.owner_groups.add(cls.people_group_1) cls.project_2.reviewer_groups.add(cls.people_group_1) cls.project_2.member_groups.add(cls.people_group_2) @@ -742,7 +752,7 @@ def setUp(self) -> None: def test_filter_by_category(self): response = self.client.get( - reverse("Project-list") + reverse("Project-list", args=(self.organization.code,)) + f"?categories={self.category_1.id},{self.category_2.id}" ) self.assertEqual(response.status_code, status.HTTP_200_OK, response.content) @@ -753,23 +763,13 @@ def test_filter_by_category(self): {self.project_1.id, self.project_2.id}, ) - def test_filter_by_organization_code(self): + def test_filter_by_language(self): response = self.client.get( - reverse("Project-list") + f"?organizations={self.organization_1.code}" + reverse("Project-list", args=(self.organization.code,)) + "?languages=en" ) self.assertEqual(response.status_code, status.HTTP_200_OK, response.content) content = response.json() self.assertEqual(content["count"], 2) - self.assertEqual( - {p["id"] for p in content["results"]}, - {self.project_1.id, self.project_3.id}, - ) - - def test_filter_by_language(self): - response = self.client.get(reverse("Project-list") + "?languages=en") - self.assertEqual(response.status_code, status.HTTP_200_OK, response.content) - content = response.json() - self.assertEqual(content["count"], 2) self.assertEqual( {p["id"] for p in content["results"]}, {self.project_2.id, self.project_3.id}, @@ -777,7 +777,8 @@ def test_filter_by_language(self): def test_filter_by_members(self): response = self.client.get( - reverse("Project-list") + f"?members={self.user_2.id},{self.user_3.id}" + reverse("Project-list", args=(self.organization.code,)) + + f"?members={self.user_2.id},{self.user_3.id}" ) self.assertEqual(response.status_code, status.HTTP_200_OK, response.content) content = response.json() @@ -789,7 +790,7 @@ def test_filter_by_members(self): def test_filter_by_group_members(self): response = self.client.get( - reverse("Project-list") + reverse("Project-list", args=(self.organization.code,)) + f"?group_members={self.people_group_2.id},{self.people_group_3.id}" ) self.assertEqual(response.status_code, status.HTTP_200_OK, response.content) @@ -801,7 +802,9 @@ def test_filter_by_group_members(self): ) def test_filter_by_sdgs(self): - response = self.client.get(reverse("Project-list") + "?sdgs=1,4,7") + response = self.client.get( + reverse("Project-list", args=(self.organization.code,)) + "?sdgs=1,4,7" + ) self.assertEqual(response.status_code, status.HTTP_200_OK, response.content) content = response.json() self.assertEqual(content["count"], 2) @@ -812,7 +815,8 @@ def test_filter_by_sdgs(self): def test_filter_by_tags(self): response = self.client.get( - reverse("Project-list") + f"?tags={self.tag_1.id},{self.tag_2.id}" + reverse("Project-list", args=(self.organization.code,)) + + f"?tags={self.tag_1.id},{self.tag_2.id}" ) self.assertEqual(response.status_code, status.HTTP_200_OK, response.json()) content = response.json() @@ -824,7 +828,7 @@ def test_filter_by_tags(self): def test_filter_by_member_role(self): response = self.client.get( - reverse("Project-list") + reverse("Project-list", args=(self.organization.code,)) + f"?members={self.user_1.id},{self.user_2.id}" + f"&member_role={GroupData.Role.OWNERS},{GroupData.Role.MEMBERS}" ) @@ -838,7 +842,7 @@ def test_filter_by_member_role(self): def test_filter_by_group_role(self): response = self.client.get( - reverse("Project-list") + reverse("Project-list", args=(self.organization.code,)) + f"?group_members={self.people_group_1.id},{self.people_group_2.id}" + f"&group_role={GroupData.Role.OWNER_GROUPS},{GroupData.Role.MEMBER_GROUPS}" ) @@ -852,7 +856,7 @@ def test_filter_by_group_role(self): def test_filter_by_life_status(self): response = self.client.get( - reverse("Project-list") + reverse("Project-list", args=(self.organization.code,)) + f"?life_status={Project.LifeStatus.RUNNING},{Project.LifeStatus.COMPLETED}" ) self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -864,7 +868,10 @@ def test_filter_by_life_status(self): ) def test_filter_by_creation_year(self): - response = self.client.get(reverse("Project-list") + "?creation_year=2020,2021") + response = self.client.get( + reverse("Project-list", args=(self.organization.code,)) + + "?creation_year=2020,2021" + ) self.assertEqual(response.status_code, status.HTTP_200_OK, response.content) content = response.json() self.assertEqual(content["count"], 2) @@ -875,7 +882,8 @@ def test_filter_by_creation_year(self): def test_filter_by_ids_and_slugs(self): response = self.client.get( - reverse("Project-list") + f"?ids={self.project_1.id},{self.project_2.slug}" + reverse("Project-list", args=(self.organization.code,)) + + f"?ids={self.project_1.id},{self.project_2.slug}" ) self.assertEqual(response.status_code, status.HTTP_200_OK, response.content) content = response.json() @@ -886,7 +894,10 @@ def test_filter_by_ids_and_slugs(self): ) def test_order_by_created_date(self): - response = self.client.get(reverse("Project-list") + "?ordering=created_at") + response = self.client.get( + reverse("Project-list", args=(self.organization.code,)) + + "?ordering=created_at" + ) self.assertEqual(response.status_code, status.HTTP_200_OK, response.content) content = response.json() self.assertEqual(content["count"], 3) @@ -898,7 +909,10 @@ def test_order_by_created_date(self): ) def test_order_by_created_date_reverse(self): - response = self.client.get(reverse("Project-list") + "?ordering=-created_at") + response = self.client.get( + reverse("Project-list", args=(self.organization.code,)) + + "?ordering=-created_at" + ) self.assertEqual(response.status_code, status.HTTP_200_OK, response.content) content = response.json() self.assertEqual(content["count"], 3) @@ -910,7 +924,10 @@ def test_order_by_created_date_reverse(self): ) def test_order_by_updated_date(self): - response = self.client.get(reverse("Project-list") + "?ordering=updated_at") + response = self.client.get( + reverse("Project-list", args=(self.organization.code,)) + + "?ordering=updated_at" + ) self.assertEqual(response.status_code, status.HTTP_200_OK, response.content) content = response.json() self.assertEqual(content["count"], 3) @@ -922,7 +939,10 @@ def test_order_by_updated_date(self): ) def test_order_by_updated_date_reverse(self): - response = self.client.get(reverse("Project-list") + "?ordering=-updated_at") + response = self.client.get( + reverse("Project-list", args=(self.organization.code,)) + + "?ordering=-updated_at" + ) self.assertEqual(response.status_code, status.HTTP_200_OK, response.content) content = response.json() self.assertEqual(content["count"], 3) @@ -949,7 +969,8 @@ def test_update_without_organization(self): project = ProjectFactory(organizations=[self.organization]) payload = {"organizations_codes": []} response = self.client.patch( - reverse("Project-detail", args=(project.id,)), data=payload + reverse("Project-detail", args=(self.organization.code, project.id)), + data=payload, ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertApiValidationError( @@ -969,7 +990,8 @@ def test_remove_last_member(self): "users": [owner.id], } response = self.client.post( - reverse("Project-remove-member", args=(project.id,)), data=payload + reverse("Project-remove-member", args=(self.organization.code, project.id)), + data=payload, ) self.assertEqual( response.status_code, status.HTTP_400_BAD_REQUEST, response.content @@ -986,7 +1008,9 @@ def test_create_project_in_organization_with_no_rights(self): "organizations_codes": [self.organization.code, self.organization_2.code], "project_categories_ids": [self.category.id, self.category_2.id], } - response = self.client.post(reverse("Project-list"), data=payload) + response = self.client.post( + reverse("Project-list", args=(self.organization.code,)), data=payload + ) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertApiPermissionError( response, "You do not have the rights to add a project in this organization" @@ -1000,7 +1024,8 @@ def test_add_project_to_organization_with_no_rights(self): "organizations_codes": [self.organization.code, self.organization_2.code], } response = self.client.patch( - reverse("Project-detail", args=(project.id,)), data=payload + reverse("Project-detail", args=(self.organization.code, project.id)), + data=payload, ) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertApiPermissionError( @@ -1015,7 +1040,8 @@ def test_update_project_with_two_organizations(self): "title": faker.sentence(), } response = self.client.patch( - reverse("Project-detail", args=(project.id,)), data=payload + reverse("Project-detail", args=(self.organization.code, project.id)), + data=payload, ) self.assertEqual(response.status_code, status.HTTP_200_OK) content = response.json() @@ -1065,7 +1091,8 @@ def test_outdated_slug(self): # Check that the slug is updated and the old one is stored in outdated_slugs payload = {"title": title_b} response = self.client.patch( - reverse("Project-detail", args=(project.id,)), data=payload + reverse("Project-detail", args=(self.organization.code, project.id)), + data=payload, ) self.assertEqual(response.status_code, status.HTTP_200_OK) project.refresh_from_db() @@ -1075,7 +1102,8 @@ def test_outdated_slug(self): # Check that multiple_slug is correctly updated payload = {"title": title_c} response = self.client.patch( - reverse("Project-detail", args=(project.id,)), data=payload + reverse("Project-detail", args=(self.organization.code, project.id)), + data=payload, ) self.assertEqual(response.status_code, status.HTTP_200_OK) project.refresh_from_db() @@ -1085,7 +1113,8 @@ def test_outdated_slug(self): # Check that outdated_slugs are reused if relevant payload = {"title": title_b} response = self.client.patch( - reverse("Project-detail", args=(project.id,)), data=payload + reverse("Project-detail", args=(self.organization.code, project.id)), + data=payload, ) self.assertEqual(response.status_code, status.HTTP_200_OK) project.refresh_from_db() @@ -1100,7 +1129,9 @@ def test_outdated_slug(self): "title": title_a, "purpose": faker.sentence(), } - response = self.client.post(reverse("Project-list"), data=payload) + response = self.client.post( + reverse("Project-list", args=(self.organization.code,)), data=payload + ) self.assertEqual(response.status_code, status.HTTP_201_CREATED) content = response.json() self.assertEqual(content["slug"], "title-a-1") @@ -1113,7 +1144,9 @@ def test_outdated_slug(self): "title": title_b, "purpose": faker.sentence(), } - response = self.client.post(reverse("Project-list"), data=payload) + response = self.client.post( + reverse("Project-list", args=(self.organization.code,)), data=payload + ) self.assertEqual(response.status_code, status.HTTP_201_CREATED) content = response.json() self.assertEqual(content["slug"], "title-b-1") @@ -1135,7 +1168,8 @@ def test_change_member_role(self): GroupData.Role.OWNERS: [user.id], } response = self.client.post( - reverse("Project-add-member", args=(project.id,)), data=payload + reverse("Project-add-member", args=(self.organization.code, project.id)), + data=payload, ) self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) self.assertIn(user, project.owners.all()) @@ -1146,12 +1180,16 @@ def test_is_followed_get(self): user = self.superadmin self.client.force_authenticate(user) - response = self.client.get(reverse("Project-detail", args=(project.id,))) + response = self.client.get( + reverse("Project-detail", args=(self.organization.code, project.id)) + ) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertFalse(response.json()["is_followed"]["is_followed"]) follow = FollowFactory(follower=user, project=project) - response = self.client.get(reverse("Project-detail", args=(project.id,))) + response = self.client.get( + reverse("Project-detail", args=(self.organization.code, project.id)) + ) self.assertEqual(response.status_code, status.HTTP_200_OK) content = response.json() self.assertTrue(content["is_followed"]["is_followed"]) @@ -1164,7 +1202,9 @@ def test_is_followed_list(self): follow_2 = FollowFactory(follower=user, project=projects[1]) self.client.force_authenticate(user) - response = self.client.get(reverse("Project-list")) + response = self.client.get( + reverse("Project-list", args=(self.organization.code,)) + ) self.assertEqual(response.status_code, status.HTTP_200_OK) content = response.json()["results"] self.assertSetEqual( @@ -1197,7 +1237,8 @@ def test_add_reviewer_to_public_project(self): GroupData.Role.REVIEWERS: [reviewer.id], } response = self.client.post( - reverse("Project-add-member", args=(project.id,)), data=payload + reverse("Project-add-member", args=(self.organization.code, project.id)), + data=payload, ) self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) project.refresh_from_db() @@ -1215,7 +1256,8 @@ def test_add_reviewer_to_reviewed_public_project(self): GroupData.Role.REVIEWERS: [reviewer.id], } response = self.client.post( - reverse("Project-add-member", args=(project.id,)), data=payload + reverse("Project-add-member", args=(self.organization.code, project.id)), + data=payload, ) self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) project.refresh_from_db() diff --git a/apps/projects/tests/views/test_project_header.py b/apps/projects/tests/views/test_project_header.py index 94a8c6c2..13242992 100644 --- a/apps/projects/tests/views/test_project_header.py +++ b/apps/projects/tests/views/test_project_header.py @@ -48,7 +48,13 @@ def test_create_project_header(self, role, expected_code): "natural_ratio": faker.pyfloat(min_value=1.0, max_value=2.0), } response = self.client.post( - reverse("Project-header-list", args=(self.project.id,)), + reverse( + "Project-header-list", + args=( + self.organization.code, + self.project.id, + ), + ), data=payload, format="multipart", ) @@ -103,7 +109,11 @@ def test_update_project_header(self, role, expected_code): response = self.client.patch( reverse( "Project-header-detail", - args=(self.project.id, self.project.header_image.id), + args=( + self.organization.code, + self.project.id, + self.project.header_image.id, + ), ), data=payload, format="multipart", @@ -152,7 +162,7 @@ def test_delete_project_header(self, role, expected_code): response = self.client.delete( reverse( "Project-header-detail", - args=(project.id, project.header_image.id), + args=(self.organization.code, project.id, project.header_image.id), ), ) self.assertEqual(response.status_code, expected_code) diff --git a/apps/projects/urls.py b/apps/projects/urls.py index 889fcc31..0f1c0a5b 100644 --- a/apps/projects/urls.py +++ b/apps/projects/urls.py @@ -1,7 +1,11 @@ from rest_framework.routers import DefaultRouter from apps.announcements.views import AnnouncementViewSet -from apps.commons.urls import project_router_register +from apps.commons.urls import ( + organization_project_router_register, + organization_router_register, + project_router_register, +) from apps.feedbacks.views import ( CommentImagesView, CommentViewSet, @@ -30,7 +34,8 @@ router = DefaultRouter() router.register(r"location", ReadLocationViewSet, basename="Read-location") -router.register(r"project", ProjectViewSet, basename="Project") + +organization_router_register(router, r"project", ProjectViewSet, basename="Project") project_router_register( router, @@ -57,7 +62,9 @@ router, r"announcement", AnnouncementViewSet, basename="Announcement" ) project_router_register(router, r"image", ProjectImagesView, basename="Project-images") -project_router_register(router, r"header", ProjectHeaderView, basename="Project-header") +organization_project_router_register( + router, r"header", ProjectHeaderView, basename="Project-header" +) project_router_register( router, r"project-message", ProjectMessageViewSet, basename="ProjectMessage" ) diff --git a/apps/projects/views.py b/apps/projects/views.py index 5e41ada8..d5e60b5e 100644 --- a/apps/projects/views.py +++ b/apps/projects/views.py @@ -22,7 +22,11 @@ from apps.commons.cache import clear_cache_with_key, redis_cache_view from apps.commons.permissions import IsOwner, ReadOnly from apps.commons.utils import map_action_to_permission -from apps.commons.views import MultipleIDViewsetMixin +from apps.commons.views import ( + MultipleIDViewsetMixin, + OrganizationRelatedViewset, + ProjectRelatedViewset, +) from apps.files.models import Image from apps.files.views import ImageStorageView from apps.notifications.tasks import ( @@ -74,18 +78,23 @@ ) -class ProjectViewSet(MultipleIDViewsetMixin, viewsets.ModelViewSet): +class ProjectViewSet( + MultipleIDViewsetMixin, OrganizationRelatedViewset, viewsets.ModelViewSet +): """Main endpoints for projects.""" class InfoDetails(enum.Enum): SUMMARY = "summary" + queryset = Project.objects.all() serializer_class = ProjectSerializer filter_backends = [DjangoFilterBackend, OrderingFilter] filterset_class = ProjectFilter ordering_fields = ["created_at", "updated_at"] lookup_field = "id" lookup_value_regex = "[^/]+" + + queryset_organization_field = "organizations" multiple_lookup_fields = [ (Project, "id"), ] @@ -105,7 +114,7 @@ def get_permissions(self): def get_queryset(self) -> QuerySet: return ( - self.request.user.get_project_queryset() + self.organization_filter_queryset(self.request.user.get_project_queryset()) .select_related("header_image") .prefetch_related( "categories", @@ -130,10 +139,6 @@ def get_serializer_class(self): return ProjectLightSerializer return self.serializer_class - def get_serializer_context(self): - """Adds request to the serializer's context.""" - return {"request": self.request} - def perform_create(self, serializer: ProjectSerializer): project = serializer.save() project.setup_permissions(self.request.user) @@ -397,24 +402,10 @@ def similar(self, request, *args, **kwargs): return Response(ProjectLightSerializer(queryset, many=True).data) -class ProjectHeaderView(MultipleIDViewsetMixin, ImageStorageView): - permission_classes = [ - IsAuthenticatedOrReadOnly, - ProjectIsNotLocked, - ReadOnly - | IsOwner - | HasBasePermission("change_project", "projects") - | HasOrganizationPermission("change_project") - | HasProjectPermission("change_project"), - ] - multiple_lookup_fields = [ - (Project, "project_id"), - ] - - def get_queryset(self): - if "project_id" in self.kwargs: - return Image.objects.filter(project_header__id=self.kwargs["project_id"]) - return Image.objects.none() +class ProjectHeaderView(ProjectRelatedViewset, ImageStorageView): + queryset = Image.objects.all() + queryset_organization_field: str = "project_header__organizations" + queryset_project_field: str = "project_header" @staticmethod def upload_to(instance, filename) -> str: diff --git a/services/crisalid/views.py b/services/crisalid/views.py index 95debbed..e7a34ada 100644 --- a/services/crisalid/views.py +++ b/services/crisalid/views.py @@ -15,7 +15,7 @@ from rest_framework import viewsets from rest_framework.decorators import action -from apps.commons.views import NestedOrganizationViewMixins +from apps.commons.views import OrganizationRelatedViewset from services.crisalid import relators from services.crisalid.models import ( Document, @@ -82,7 +82,7 @@ ), ) class AbstractDocumentViewSet( - NestedOrganizationViewMixins, + OrganizationRelatedViewset, NestedResearcherViewMixins, viewsets.ReadOnlyModelViewSet, ): @@ -292,7 +292,7 @@ class ConferenceViewSet(AbstractDocumentViewSet): ], ), ) -class ResearcherViewSet(NestedOrganizationViewMixins, viewsets.ReadOnlyModelViewSet): +class ResearcherViewSet(OrganizationRelatedViewset, viewsets.ReadOnlyModelViewSet): serializer_class = ResearcherSerializer filter_backends = (DjangoFilterBackend,) filterset_fields = ("user_id", "id")