diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 100e0ae..cd55816 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,21 +1,43 @@ repos: -- repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 - hooks: - - id: check-ast - - id: check-added-large-files - - id: check-case-conflict - - id: check-json - - id: check-merge-conflict - - id: check-symlinks - - id: check-toml - - id: end-of-file-fixer - - id: trailing-whitespace -- repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.5.5 - hooks: - # Run the linter. - - id: ruff - args: [ --fix ] - # Run the formatter. - - id: ruff-format +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: trailing-whitespace + exclude: tests/fixtures/ + - id: end-of-file-fixer + exclude: tests/fixtures/ + - id: check-yaml + exclude: tests/fixtures/ + - id: debug-statements + exclude: tests/fixtures/ + - id: requirements-txt-fixer + exclude: tests/fixtures/ +- repo: https://github.com/asottile/reorder_python_imports + rev: v3.14.0 + hooks: + - id: reorder-python-imports + args: [--py37-plus] + exclude: tests/fixtures/ +- repo: https://github.com/psf/black + rev: 23.9.1 + hooks: + - id: black + exclude: tests/fixtures/ + args: + - --line-length=120 +- repo: https://github.com/PyCQA/autoflake + rev: v2.3.1 + hooks: + - id: autoflake + args: [--remove-all-unused-imports, --in-place] + exclude: tests/fixtures/ +- repo: https://github.com/PyCQA/flake8 + rev: 7.1.1 + hooks: + - id: flake8 + exclude: tests/fixtures/ + args: + - --max-line-length=120 + - --extend-ignore=W503,E203 + additional_dependencies: + - flake8-string-format diff --git a/README.md b/README.md index 40bbb6e..fe0db29 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ # Requirements -* Python 3.8+ +* Python 3.9+ * Django 4.1+ We **highly recommend** and only officially support the latest patch release of @@ -128,7 +128,7 @@ serializers.py ```python from adrf.serializers import Serializer -from rest_framework import serializers +from adrf import serializers class AsyncSerializer(Serializer): username = serializers.CharField() diff --git a/adrf/decorators.py b/adrf/decorators.py index d59e2b5..5d4cdf4 100644 --- a/adrf/decorators.py +++ b/adrf/decorators.py @@ -22,20 +22,15 @@ def decorator(func): # WrappedAPIView.__doc__ = func.doc <--- Not possible to do this # api_view applied without (method_names) - assert not ( - isinstance(http_method_names, types.FunctionType) - ), "@api_view missing list of allowed HTTP methods" + assert not (isinstance(http_method_names, types.FunctionType)), "@api_view missing list of allowed HTTP methods" # api_view applied with eg. string instead of list of strings assert isinstance(http_method_names, (list, tuple)), ( - "@api_view expected a list of strings, received %s" - % type(http_method_names).__name__ + "@api_view expected a list of strings, received %s" % type(http_method_names).__name__ ) allowed_methods = set(http_method_names) | {"options"} - WrappedAPIView.http_method_names = [ - method.lower() for method in allowed_methods - ] + WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods] view_is_async = asyncio.iscoroutinefunction(func) @@ -55,25 +50,15 @@ def handler(self, *args, **kwargs): WrappedAPIView.__name__ = func.__name__ WrappedAPIView.__module__ = func.__module__ - WrappedAPIView.renderer_classes = getattr( - func, "renderer_classes", APIView.renderer_classes - ) + WrappedAPIView.renderer_classes = getattr(func, "renderer_classes", APIView.renderer_classes) - WrappedAPIView.parser_classes = getattr( - func, "parser_classes", APIView.parser_classes - ) + WrappedAPIView.parser_classes = getattr(func, "parser_classes", APIView.parser_classes) - WrappedAPIView.authentication_classes = getattr( - func, "authentication_classes", APIView.authentication_classes - ) + WrappedAPIView.authentication_classes = getattr(func, "authentication_classes", APIView.authentication_classes) - WrappedAPIView.throttle_classes = getattr( - func, "throttle_classes", APIView.throttle_classes - ) + WrappedAPIView.throttle_classes = getattr(func, "throttle_classes", APIView.throttle_classes) - WrappedAPIView.permission_classes = getattr( - func, "permission_classes", APIView.permission_classes - ) + WrappedAPIView.permission_classes = getattr(func, "permission_classes", APIView.permission_classes) WrappedAPIView.schema = getattr(func, "schema", APIView.schema) diff --git a/adrf/fields.py b/adrf/fields.py index 2893ff9..f41f9ae 100644 --- a/adrf/fields.py +++ b/adrf/fields.py @@ -1,7 +1,462 @@ -from rest_framework.serializers import SerializerMethodField as DRFSerializerMethodField +import inspect +from typing import Mapping +from typing import Union +from asgiref.sync import sync_to_async +from django.core.exceptions import ValidationError as DjangoValidationError +from rest_framework.exceptions import ValidationError +from rest_framework.fields import _UnvalidatedField +from rest_framework.fields import BooleanField as DRFBooleanField +from rest_framework.fields import BuiltinSignatureError +from rest_framework.fields import CharField as DRFCharField +from rest_framework.fields import ChoiceField as DRFChoiceField +from rest_framework.fields import DateField as DRFDateField +from rest_framework.fields import DateTimeField as DRFDateTimeField +from rest_framework.fields import DecimalField as DRFDecimalField +from rest_framework.fields import DictField as DRFDictField +from rest_framework.fields import DurationField as DRFDurationField +from rest_framework.fields import EmailField as DRFEmailField +from rest_framework.fields import empty +from rest_framework.fields import Field +from rest_framework.fields import FileField as DRFFileField +from rest_framework.fields import FilePathField as DRFFilePathField +from rest_framework.fields import FloatField as DRFFloatField +from rest_framework.fields import get_error_detail +from rest_framework.fields import HiddenField as DRFHiddenField +from rest_framework.fields import HStoreField as DRFHStoreField +from rest_framework.fields import ImageField as DRFImageField +from rest_framework.fields import IntegerField as DRFIntegerField +from rest_framework.fields import IPAddressField as DRFIPAddressForeign +from rest_framework.fields import JSONField as DRFJSONField +from rest_framework.fields import ListField as DRFListField +from rest_framework.fields import ModelField as DRFModelField +from rest_framework.fields import MultipleChoiceField as DRFMultipleChoiceField +from rest_framework.fields import ReadOnlyField as DRFReadOnlyField +from rest_framework.fields import RegexField as DRFRegexField +from rest_framework.fields import SerializerMethodField as DRFSerializerMethodField +from rest_framework.fields import SkipField +from rest_framework.fields import SlugField as DRFSlugField +from rest_framework.fields import TimeField as DRFTimeField +from rest_framework.fields import URLField as DRFURLField +from rest_framework.fields import UUIDField as DRFUUIDField +from rest_framework.utils import html -class SerializerMethodField(DRFSerializerMethodField): +from adrf.utils import aget_attribute + + +class AsyncFieldMixin: + async def avalidate_empty_values(self, data): + """ + Асинхронная версия validate_empty_values. + """ + if self.read_only: + return (True, await self.aget_default()) + + if data is empty: + if getattr(self.root, "partial", False): + raise SkipField() + if self.required: + self.fail("required") + return (True, await self.aget_default()) + + if data is None: + if not self.allow_null: + self.fail("null") + elif self.source == "*": + return (False, None) + return (True, None) + + return (False, data) + + async def arun_validation(self, data=empty): + """ + Асинхронная версия run_validation. + """ + is_empty_value, data = await self.avalidate_empty_values(data) + if is_empty_value: + return data + value = await self.ato_internal_value(data) + await self.arun_validators(value) + return value + + async def arun_validators(self, value): + """ + Асинхронная версия run_validators. + """ + errors = [] + for validator in self.validators: + try: + if getattr(validator, "requires_context", False): + if inspect.iscoroutinefunction(validator): + await validator(value, self) + else: + await sync_to_async(validator)(value, self) + else: + if inspect.iscoroutinefunction(validator): + await validator(value) + else: + await sync_to_async(validator)(value) + except ValidationError as exc: + if isinstance(exc.detail, dict): + raise + errors.extend(exc.detail) + except DjangoValidationError as exc: + errors.extend(get_error_detail(exc)) + if errors: + raise ValidationError(errors) + + async def ato_internal_value(self, data): + """ + Асинхронная версия to_internal_value. + Должна быть переопределена в дочерних классах. + """ + raise NotImplementedError( + "{cls}.ato_internal_value() must be implemented for field " + "{field_name}.".format( + cls=self.__class__.__name__, + field_name=self.field_name, + ) + ) + + async def ato_representation(self, value): + """ + Асинхронная версия to_representation. + Должна быть переопределена в дочерних классах. + """ + raise NotImplementedError( + "{cls}.ato_representation() must be implemented for field {field_name}.".format( + cls=self.__class__.__name__, + field_name=self.field_name, + ) + ) + + async def aget_default(self): + """ + Асинхронная версия get_default. + """ + if self.default is empty or getattr(self.root, "partial", False): + raise SkipField() + if callable(self.default): + if getattr(self.default, "requires_context", False): + if inspect.iscoroutinefunction(self.default): + return await self.default(self) + return self.default(self) + else: + if inspect.isfunction(self.default): + return await self.default(self) + return self.default() + return self.default + + +class AsyncField(Field, AsyncFieldMixin): + """ + Базовый класс для всех асинхронных полей. + Наследует синхронные методы из Field и асинхронные из AsyncFieldMixin. + """ + + async def aget_attribute(self, instance): + """ + Given the *outgoing* object instance, return the primitive value + that should be used for this field. + """ + try: + return await aget_attribute(instance, self.source_attrs) + except BuiltinSignatureError as exc: + msg = ( + "Field source for `{serializer}.{field}` maps to a built-in " + "function type and is invalid. Define a property or method on " + "the `{instance}` instance that wraps the call to the built-in " + "function.".format( + serializer=self.parent.__class__.__name__, + field=self.field_name, + instance=instance.__class__.__name__, + ) + ) + raise type(exc)(msg) + except (KeyError, AttributeError) as exc: + if self.default is not empty: + return self.get_default() + if self.allow_null: + return None + if not self.required: + raise SkipField() + msg = ( + "Got {exc_type} when attempting to get a value for field " + "`{field}` on serializer `{serializer}`.\nThe serializer " + "field might be named incorrectly and not match " + "any attribute or key on the `{instance}` instance.\n" + "Original exception text was: {exc}.".format( + exc_type=type(exc).__name__, + field=self.field_name, + serializer=self.parent.__class__.__name__, + instance=instance.__class__.__name__, + exc=exc, + ) + ) + raise type(exc)(msg) + + +class BooleanField(DRFBooleanField, AsyncField): + async def ato_internal_value(self, data): + return self.to_internal_value(data) + + async def ato_representation(self, value): + return self.to_representation(value) + + +class CharField(DRFCharField, AsyncField): + async def ato_internal_value(self, data): + return self.to_internal_value(data) + + async def ato_representation(self, value): + return self.to_representation(value) + + +class EmailField(DRFEmailField, CharField): + pass + + +class RegexField(DRFRegexField, CharField): + pass + + +class SlugField(DRFSlugField, CharField): + pass + + +class URLField(DRFURLField, CharField): + pass + + +class UUIDField(DRFUUIDField, AsyncField): + pass + + +class IPAddressField(DRFIPAddressForeign, AsyncField): + pass + + +class IntegerField(DRFIntegerField, AsyncField): + async def ato_internal_value(self, data): + return self.to_internal_value(data) + + async def ato_representation(self, value): + return self.to_representation(value) + + +class FloatField(DRFFloatField, AsyncField): + async def ato_internal_value(self, data): + return self.to_internal_value(data) + + async def ato_representation(self, value): + return self.to_representation(value) + + +class DecimalField(DRFDecimalField, AsyncField): + async def ato_internal_value(self, data): + return self.to_internal_value(data) + + async def ato_representation(self, value): + return self.to_representation(value) + + +class DateTimeField(DRFDateTimeField, AsyncField): + async def ato_internal_value(self, data): + return self.to_internal_value(data) + + async def ato_representation(self, value): + return self.to_representation(value) + + +class DateField(DRFDateField, AsyncField): + async def ato_internal_value(self, data): + return self.to_internal_value(data) + + async def ato_representation(self, value): + return self.to_representation(value) + + +class TimeField(DRFTimeField, AsyncField): + async def ato_internal_value(self, data): + return self.to_internal_value(data) + + async def ato_representation(self, value): + return self.to_representation(value) + + +class DurationField(DRFDurationField, AsyncField): + async def ato_internal_value(self, data): + return self.to_internal_value(data) + + async def ato_representation(self, value): + return self.to_representation(value) + + +class ChoiceField(DRFChoiceField, AsyncField): + async def ato_internal_value(self, data): + return self.to_internal_value(data) + + async def ato_representation(self, value): + return self.to_representation(value) + + +class MultipleChoiceField(DRFMultipleChoiceField, ChoiceField): + pass + + +class FilePathField(DRFFilePathField, ChoiceField): + pass + + +class FileField(DRFFileField, AsyncField): + async def ato_internal_value(self, data): + return self.to_internal_value(data) + + async def ato_representation(self, value): + return self.to_representation(value) + + +class ImageField(DRFImageField, FileField): + pass + + +class ListField(DRFListField, AsyncField): + child: Union[Field, AsyncField] = _UnvalidatedField() + + async def arun_validation(self, data=empty): + if html.is_html_input(data): + data = html.parse_html_list(data, default=[]) + if isinstance(data, (str, Mapping)) or not hasattr(data, "__iter__"): + self.fail("not_a_list", input_type=type(data).__name__) + if not self.allow_empty and len(data) == 0: + self.fail("empty") + return await self.arun_child_validation(data) + + async def arun_child_validation(self, data): + result = [] + errors = {} + + for idx, item in enumerate(data): + try: + if hasattr(self.child, "arun_validation"): + validated_item = await self.child.arun_validation(item) + else: + validated_item = self.child.run_validation(item) + result.append(validated_item) + except ValidationError as e: + errors[idx] = e.detail + except DjangoValidationError as e: + errors[idx] = get_error_detail(e) + + if not errors: + return result + raise ValidationError(errors) + + async def ato_representation(self, value): + result = [] + if hasattr(value, "__aiter__"): + async for item in value: + if item is not None: + if hasattr(self.child, "ato_representation"): + result.append(await self.child.ato_representation(item)) + else: + result.append(self.child.to_representation(item)) + else: + result.append(None) + else: + for item in value: + if item is not None: + if hasattr(self.child, "ato_representation"): + result.append(await self.child.ato_representation(item)) + else: + result.append(self.child.to_representation(item)) + else: + result.append(None) + return result + + +class DictField(DRFDictField, AsyncField): + child: Union[Field, AsyncField] = _UnvalidatedField() + + async def ato_internal_value(self, data): + """ + Dicts of native values <- Dicts of primitive datatypes. + """ + if html.is_html_input(data): + data = html.parse_html_dict(data) + if not isinstance(data, dict): + self.fail("not_a_dict", input_type=type(data).__name__) + if not self.allow_empty and len(data) == 0: + self.fail("empty") + + return await self.arun_child_validation(data) + + async def ato_representation(self, value: dict): + result = {} + for key, val in value.items(): + if hasattr(self.child, "ato_representation"): + result[key] = await self.child.ato_representation(val) + else: + result[key] = self.child.to_representation(val) + + return {str(key): self.child.ato_representation(val) if val is not None else None for key, val in value.items()} + + async def arun_child_validation(self, data): + result = {} + errors = {} + + for key, value in data.items(): + key = str(key) + + try: + if hasattr(self.child, "arun_validation"): + result[key] = await self.child.arun_validation(value) + else: + result[key] = self.child.run_validation(value) + except ValidationError as e: + errors[key] = e.detail + + if not errors: + return result + raise ValidationError(errors) + + +class HStoreField(DRFHStoreField, DictField): + pass + + +class JSONField(DRFJSONField, AsyncField): + async def ato_internal_value(self, data): + return self.to_internal_value(data) + + async def ato_representation(self, value): + return self.to_representation(value) + + +class ReadOnlyField(DRFReadOnlyField, AsyncField): + async def ato_representation(self, value): + return self.to_representation(value) + + +class HiddenField(DRFHiddenField, AsyncField): + async def ato_internal_value(self, data): + return self.to_internal_value(data) + + +class SerializerMethodField(DRFSerializerMethodField, AsyncField): async def ato_representation(self, attribute): method = getattr(self.parent, self.method_name) - return await method(attribute) + if inspect.iscoroutinefunction(method): + return await method(attribute) + return await sync_to_async(method)(attribute) + + +class ModelField(DRFModelField, AsyncField): + async def ato_internal_value(self, data): + return self.to_internal_value(data) + + async def ato_representation(self, value): + return self.to_representation(value) + + async def aget_attribute(self, instance): + return instance diff --git a/adrf/generics.py b/adrf/generics.py index d57b703..4b6a3a2 100644 --- a/adrf/generics.py +++ b/adrf/generics.py @@ -1,11 +1,13 @@ import asyncio from asgiref.sync import async_to_sync +from asgiref.sync import sync_to_async from django.http import Http404 from rest_framework.exceptions import ValidationError from rest_framework.generics import GenericAPIView as DRFGenericAPIView -from adrf import mixins, views +from adrf import mixins +from adrf import views from adrf.shortcuts import aget_object_or_404 as _aget_object_or_404 @@ -23,6 +25,12 @@ def aget_object_or_404(queryset, *filter_args, **filter_kwargs): class GenericAPIView(views.APIView, DRFGenericAPIView): """This generic API view supports async pagination.""" + async def aget_queryset(self): + return await sync_to_async(self.get_queryset)() + + async def afilter_queryset(self, queryset): + return await sync_to_async(self.filter_queryset)(queryset) + async def aget_object(self): """ Returns the object the view is displaying. @@ -31,7 +39,7 @@ async def aget_object(self): queryset lookups. Eg if objects are referenced using multiple keyword arguments in the url conf. """ - queryset = self.filter_queryset(self.get_queryset()) + queryset = await self.afilter_queryset(await self.aget_queryset()) # Perform the lookup filtering. lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field @@ -39,8 +47,7 @@ async def aget_object(self): assert lookup_url_kwarg in self.kwargs, ( "Expected view %s to be called with a URL keyword argument " 'named "%s". Fix your URL conf, or set the `.lookup_field` ' - "attribute on the view correctly." - % (self.__class__.__name__, lookup_url_kwarg) + "attribute on the view correctly." % (self.__class__.__name__, lookup_url_kwarg) ) filter_kwargs = {self.lookup_field: self.kwargs[lookup_url_kwarg]} @@ -58,9 +65,7 @@ def paginate_queryset(self, queryset): if self.paginator is None: return None if asyncio.iscoroutinefunction(self.paginator.paginate_queryset): - return async_to_sync(self.paginator.paginate_queryset)( - queryset, self.request, view=self - ) + return async_to_sync(self.paginator.paginate_queryset)(queryset, self.request, view=self) return self.paginator.paginate_queryset(queryset, self.request, view=self) def get_paginated_response(self, data): @@ -79,9 +84,7 @@ async def apaginate_queryset(self, queryset): if self.paginator is None: return None if asyncio.iscoroutinefunction(self.paginator.paginate_queryset): - return await self.paginator.paginate_queryset( - queryset, self.request, view=self - ) + return await self.paginator.paginate_queryset(queryset, self.request, view=self) return self.paginator.paginate_queryset(queryset, self.request, view=self) async def get_apaginated_response(self, data): @@ -158,9 +161,7 @@ async def post(self, request, *args, **kwargs): return await self.acreate(request, *args, **kwargs) -class RetrieveUpdateAPIView( - mixins.RetrieveModelMixin, mixins.UpdateModelMixin, GenericAPIView -): +class RetrieveUpdateAPIView(mixins.RetrieveModelMixin, mixins.UpdateModelMixin, GenericAPIView): """ Concrete view for retrieving, updating a model instance. """ @@ -175,9 +176,7 @@ async def patch(self, request, *args, **kwargs): return await self.partial_aupdate(request, *args, **kwargs) -class RetrieveDestroyAPIView( - mixins.RetrieveModelMixin, mixins.DestroyModelMixin, GenericAPIView -): +class RetrieveDestroyAPIView(mixins.RetrieveModelMixin, mixins.DestroyModelMixin, GenericAPIView): """ Concrete view for retrieving or deleting a model instance. """ diff --git a/adrf/mixins.py b/adrf/mixins.py index 5f4f3a1..c938370 100644 --- a/adrf/mixins.py +++ b/adrf/mixins.py @@ -1,5 +1,5 @@ -from asgiref.sync import sync_to_async -from rest_framework import mixins, status +from rest_framework import mixins +from rest_framework import status from rest_framework.response import Response @@ -15,7 +15,7 @@ class CreateModelMixin(mixins.CreateModelMixin): async def acreate(self, request, *args, **kwargs): serializer = self.get_serializer(data=request.data) - await sync_to_async(serializer.is_valid)(raise_exception=True) + await serializer.ais_valid(raise_exception=True) await self.perform_acreate(serializer) data = await get_data(serializer) headers = self.get_success_headers(data) @@ -31,7 +31,8 @@ class ListModelMixin(mixins.ListModelMixin): """ async def alist(self, *args, **kwargs): - queryset = self.filter_queryset(self.get_queryset()) + initial_queryset = await self.aget_queryset() + queryset = await self.afilter_queryset(initial_queryset) page = await self.apaginate_queryset(queryset) if page is not None: @@ -65,7 +66,7 @@ async def aupdate(self, request, *args, **kwargs): partial = kwargs.pop("partial", False) instance = await self.aget_object() serializer = self.get_serializer(instance, data=request.data, partial=partial) - await sync_to_async(serializer.is_valid)(raise_exception=True) + await serializer.ais_valid(raise_exception=True) await self.perform_aupdate(serializer) if getattr(instance, "_prefetched_objects_cache", None): diff --git a/adrf/relations.py b/adrf/relations.py new file mode 100644 index 0000000..e489b11 --- /dev/null +++ b/adrf/relations.py @@ -0,0 +1,304 @@ +import contextlib +import sys +from urllib import parse + +from async_property import async_property +from django.core.exceptions import ObjectDoesNotExist +from django.urls import get_script_prefix +from django.urls import resolve +from django.urls import Resolver404 +from django.utils.encoding import smart_str +from django.utils.encoding import uri_to_iri +from rest_framework.fields import SkipField +from rest_framework.relations import empty +from rest_framework.relations import HyperlinkedIdentityField as DRFHyperlinkedIdentityField +from rest_framework.relations import HyperlinkedRelatedField as DRFHyperlinkedRelatedField +from rest_framework.relations import is_simple_callable +from rest_framework.relations import iter_options +from rest_framework.relations import MANY_RELATION_KWARGS +from rest_framework.relations import ManyRelatedField as DRFManyRelatedField +from rest_framework.relations import ObjectTypeError +from rest_framework.relations import ObjectValueError +from rest_framework.relations import PKOnlyObject +from rest_framework.relations import PrimaryKeyRelatedField as DRFPrimaryKeyRelatedField +from rest_framework.relations import RelatedField as DRFRelatedField +from rest_framework.relations import SlugRelatedField as DRFSlugRelatedField +from rest_framework.relations import StringRelatedField as DRFStringRelatedField + +from adrf.fields import AsyncField +from adrf.utils import aget_attribute +from adrf.utils import async_attrgetter +from adrf.utils import is_async_callable + + +class RelatedField(DRFRelatedField, AsyncField): + queryset = None + html_cutoff = None + html_cutoff_text = None + + @classmethod + def many_init(cls, *args, **kwargs): + """ + This method handles creating a parent `ManyRelatedField` instance + when the `many=True` keyword argument is passed. + + Typically you won't need to override this method. + + Note that we're over-cautious in passing most arguments to both parent + and child classes in order to try to cover the general case. If you're + overriding this method you'll probably want something much simpler, eg: + + @classmethod + def many_init(cls, *args, **kwargs): + kwargs['child'] = cls() + return CustomManyRelatedField(*args, **kwargs) + """ + list_kwargs = {"child_relation": cls(*args, **kwargs)} + for key in kwargs: + if key in MANY_RELATION_KWARGS: + list_kwargs[key] = kwargs[key] + return ManyRelatedField(**list_kwargs) + + async def aget_attribute(self, instance): + if self.use_pk_only_optimization() and self.source_attrs: + # Optimized case, return a mock object only containing the pk attribute. + with contextlib.suppress(AttributeError): + attribute_instance = await aget_attribute(instance, self.source_attrs[:-1]) + value = attribute_instance.serializable_value(self.source_attrs[-1]) + if is_async_callable(value): + # Handle edge case where the relationship `source` argument + # points to a `get_relationship()` method on the model. + value = await value() + if is_simple_callable(value): + # Handle edge case where the relationship `source` argument + # points to a `get_relationship()` method on the model. + value = value() + + # Handle edge case where relationship `source` argument points + # to an instance instead of a pk (e.g., a `@property`). + value = getattr(value, "pk", value) + + return PKOnlyObject(pk=value) + # Standard case, return the object instance. + return await super().aget_attribute(instance) + + async def arun_validation(self, data=empty): + # We force empty strings to None values for relational fields. + if data == "": + data = None + return await super().arun_validation(data) + + async def aget_choices(self, cutoff=None): + queryset = self.get_queryset() + if queryset is None: + # Ensure that field.choices returns something sensible + # even when accessed with a read-only field. + return {} + + if cutoff is not None: + queryset = queryset[:cutoff] + + return {self.ato_representation(item): self.display_value(item) async for item in queryset} + + @async_property + async def choices(self): + return await self.aget_choices() + + @async_property + async def grouped_choices(self): + return await self.choices + + async def aiter_options(self): + choices = await self.aget_choices(cutoff=self.html_cutoff) + return iter_options(choices, cutoff=self.html_cutoff, cutoff_text=self.html_cutoff_text) + + +class StringRelatedField(DRFStringRelatedField, RelatedField): + async def ato_representation(self, value): + return self.to_representation(value) + + +class PrimaryKeyRelatedField(DRFPrimaryKeyRelatedField, RelatedField): + async def ato_internal_value(self, data): + if self.pk_field is not None: + if hasattr(self.pk_field, "ato_internal_value"): + data = await self.pk_field.ato_internal_value(data) + else: + data = self.pk_field.to_internal_value(data) + queryset = self.get_queryset() + try: + if isinstance(data, bool): + raise TypeError + return await queryset.aget(pk=data) + except ObjectDoesNotExist: + self.fail("does_not_exist", pk_value=data) + except (TypeError, ValueError): + self.fail("incorrect_type", data_type=type(data).__name__) + + async def ato_representation(self, value): + if self.pk_field is not None: + if hasattr(self.pk_field, "ato_representation"): + return await self.pk_field.ato_representation(value.pk) + return self.pk_field.to_representation(value.pk) + return value.pk + + +class HyperlinkedRelatedField(DRFHyperlinkedRelatedField, RelatedField): + async def aget_object(self, view_name, view_args, view_kwargs): + """ + Return the object corresponding to a matched URL. + + Takes the matched URL conf arguments, and should return an + object instance, or raise an `ObjectDoesNotExist` exception. + """ + lookup_value = view_kwargs[self.lookup_url_kwarg] + lookup_kwargs = {self.lookup_field: lookup_value} + queryset = self.get_queryset() + + try: + return await queryset.aget(**lookup_kwargs) + except ValueError: + exc = ObjectValueError(str(sys.exc_info()[1])) + raise exc.with_traceback(sys.exc_info()[2]) + except TypeError: + exc = ObjectTypeError(str(sys.exc_info()[1])) + raise exc.with_traceback(sys.exc_info()[2]) + + async def ato_internal_value(self, data): + request = self.context.get("request") + try: + http_prefix = data.startswith(("http:", "https:")) + except AttributeError: + self.fail("incorrect_type", data_type=type(data).__name__) + + if http_prefix: + # If needed convert absolute URLs to relative path + data = parse.urlparse(data).path + prefix = get_script_prefix() + if data.startswith(prefix): + data = "/" + data[len(prefix) :] + + data = uri_to_iri(parse.unquote(data)) + + try: + match = resolve(data) + except Resolver404: + self.fail("no_match") + + try: + expected_viewname = request.versioning_scheme.get_versioned_viewname(self.view_name, request) + except AttributeError: + expected_viewname = self.view_name + + if match.view_name != expected_viewname: + self.fail("incorrect_match") + + try: + return await self.aget_object(match.view_name, match.args, match.kwargs) + except (ObjectDoesNotExist, ObjectValueError, ObjectTypeError): + self.fail("does_not_exist") + + async def ato_representation(self, value): + return self.to_representation(value) + + +class HyperlinkedIdentityField(DRFHyperlinkedIdentityField, HyperlinkedRelatedField): + pass + + +class SlugRelatedField(DRFSlugRelatedField, RelatedField): + async def ato_internal_value(self, data): + queryset = self.get_queryset() + try: + return await queryset.aget(**{self.slug_field: data}) + except ObjectDoesNotExist: + self.fail("does_not_exist", slug_name=self.slug_field, value=smart_str(data)) + except (TypeError, ValueError): + self.fail("invalid") + + async def ato_representation(self, obj): + slug = self.slug_field + if "__" in slug: + # handling nested relationship if defined + slug = slug.replace("__", ".") + return await async_attrgetter(slug)(obj) + + +class ManyRelatedField(DRFManyRelatedField, AsyncField): + async def ato_internal_value(self, data): + if isinstance(data, str) or not hasattr(data, "__iter__"): + self.fail("not_a_list", input_type=type(data).__name__) + if not self.allow_empty and len(data) == 0: + self.fail("empty") + + result = [] + for item in data: + if hasattr(self.child_relation, "ato_internal_value"): + result.append(await self.child_relation.ato_internal_value(item)) + else: + result.append(self.child_relation.to_internal_value(item)) + return result + + async def aget_attribute(self, instance): + # Can't have any relationships if not created + if hasattr(instance, "pk") and instance.pk is None: + return [] + + try: + relationship = await aget_attribute(instance, self.source_attrs) + except (KeyError, AttributeError) as exc: + if self.default is not empty: + return self.get_default() + if self.allow_null: + return None + if not self.required: + raise SkipField() + msg = ( + "Got {exc_type} when attempting to get a value for field " + "`{field}` on serializer `{serializer}`.\nThe serializer " + "field might be named incorrectly and not match " + "any attribute or key on the `{instance}` instance.\n" + "Original exception text was: {exc}.".format( + exc_type=type(exc).__name__, + field=self.field_name, + serializer=self.parent.__class__.__name__, + instance=instance.__class__.__name__, + exc=exc, + ) + ) + raise type(exc)(msg) + + return relationship.all() if hasattr(relationship, "all") else relationship + + async def ato_representation(self, iterable): + result = [] + if hasattr(iterable, "__aiter__"): + async for value in iterable: + if hasattr(self.child_relation, "ato_representation"): + result.append(await self.child_relation.ato_representation(value)) + else: + result.append(self.child_relation.to_representation(value)) + else: + for value in iterable: + if hasattr(self.child_relation, "ato_representation"): + result.append(await self.child_relation.ato_representation(value)) + else: + result.append(self.child_relation.to_representation(value)) + return result + + async def aget_choices(self, cutoff=None): + if hasattr(self.child_relation, "aget_choices"): + return await self.child_relation.aget_choices(cutoff) + return self.child_relation.get_choices(cutoff) + + @async_property + async def choices(self): + return await self.aget_choices() + + @async_property + async def grouped_choices(self): + return await self.choices + + async def aiter_options(self): + choices = await self.get_choices(cutoff=self.html_cutoff) + return iter_options(choices, cutoff=self.html_cutoff, cutoff_text=self.html_cutoff_text) diff --git a/adrf/requests.py b/adrf/requests.py index 13f641e..292de6e 100644 --- a/adrf/requests.py +++ b/adrf/requests.py @@ -2,7 +2,8 @@ from asgiref.sync import async_to_sync from rest_framework import exceptions -from rest_framework.request import Request, wrap_attributeerrors +from rest_framework.request import Request +from rest_framework.request import wrap_attributeerrors class AsyncRequest(Request): diff --git a/adrf/routers.py b/adrf/routers.py index 388fad8..8631068 100644 --- a/adrf/routers.py +++ b/adrf/routers.py @@ -29,8 +29,7 @@ def get_method_map(self, viewset, method_map): bound_methods = {} if getattr(viewset, "view_is_async", False): method_map = { - method: self.sync_to_async_action_map.get(action, action) - for method, action in method_map.items() + method: self.sync_to_async_action_map.get(action, action) for method, action in method_map.items() } for method, action in method_map.items(): if hasattr(viewset, action): diff --git a/adrf/serializers.py b/adrf/serializers.py index 407deeb..5b416f0 100644 --- a/adrf/serializers.py +++ b/adrf/serializers.py @@ -4,20 +4,53 @@ from async_property import async_property from django.db import models +from rest_framework.compat import postgres_fields +from rest_framework.exceptions import ValidationError +from rest_framework.fields import html from rest_framework.fields import SkipField -from rest_framework.serializers import ( - LIST_SERIALIZER_KWARGS, - model_meta, - raise_errors_on_nested_writes, -) +from rest_framework.serializers import api_settings from rest_framework.serializers import BaseSerializer as DRFBaseSerializer +from rest_framework.serializers import DjangoValidationError +from rest_framework.serializers import get_error_detail +from rest_framework.serializers import LIST_SERIALIZER_KWARGS from rest_framework.serializers import ListSerializer as DRFListSerializer +from rest_framework.serializers import Mapping +from rest_framework.serializers import model_meta from rest_framework.serializers import ModelSerializer as DRFModelSerializer +from rest_framework.serializers import raise_errors_on_nested_writes from rest_framework.serializers import Serializer as DRFSerializer -from rest_framework.utils.serializer_helpers import ReturnDict, ReturnList - - -class BaseSerializer(DRFBaseSerializer): +from rest_framework.utils.serializer_helpers import ReturnDict +from rest_framework.utils.serializer_helpers import ReturnList + +from adrf.fields import AsyncField +from adrf.fields import BooleanField +from adrf.fields import CharField +from adrf.fields import ChoiceField +from adrf.fields import DateField +from adrf.fields import DateTimeField +from adrf.fields import DecimalField +from adrf.fields import DurationField +from adrf.fields import EmailField +from adrf.fields import FileField +from adrf.fields import FilePathField +from adrf.fields import FloatField +from adrf.fields import HStoreField +from adrf.fields import ImageField +from adrf.fields import IntegerField +from adrf.fields import IPAddressField +from adrf.fields import JSONField +from adrf.fields import ListField +from adrf.fields import ModelField +from adrf.fields import SlugField +from adrf.fields import TimeField +from adrf.fields import URLField +from adrf.fields import UUIDField +from adrf.relations import HyperlinkedIdentityField +from adrf.relations import PrimaryKeyRelatedField +from adrf.relations import SlugRelatedField + + +class BaseSerializer(DRFBaseSerializer, AsyncField): """ Base serializer class. """ @@ -37,13 +70,7 @@ def many_init(cls, *args, **kwargs): list_kwargs["max_length"] = max_length if min_length is not None: list_kwargs["min_length"] = min_length - list_kwargs.update( - { - key: value - for key, value in kwargs.items() - if key in LIST_SERIALIZER_KWARGS - } - ) + list_kwargs.update({key: value for key, value in kwargs.items() if key in LIST_SERIALIZER_KWARGS}) meta = getattr(cls, "Meta", None) list_serializer_class = getattr(meta, "list_serializer_class", ListSerializer) return list_serializer_class(*args, **list_kwargs) @@ -63,9 +90,7 @@ async def adata(self): if not hasattr(self, "_data"): if self.instance is not None and not getattr(self, "_errors", None): self._data = await self.ato_representation(self.instance) - elif hasattr(self, "_validated_data") and not getattr( - self, "_errors", None - ): + elif hasattr(self, "_validated_data") and not getattr(self, "_errors", None): self._data = await self.ato_representation(self.validated_data) else: self._data = self.get_initial() @@ -82,13 +107,9 @@ async def acreate(self, validated_data): raise NotImplementedError("`acreate()` must be implemented.") async def asave(self, **kwargs): - assert hasattr( - self, "_errors" - ), "You must call `.is_valid()` before calling `.asave()`." + assert hasattr(self, "_errors"), "You must call `.is_valid()` before calling `.asave()`." - assert ( - not self.errors - ), "You cannot call `.asave()` on a serializer with invalid data." + assert not self.errors, "You cannot call `.asave()` on a serializer with invalid data." # Guard against incorrect use of `serializer.asave(commit=False)` assert "commit" not in kwargs, ( @@ -109,17 +130,33 @@ async def asave(self, **kwargs): if self.instance is not None: self.instance = await self.aupdate(self.instance, validated_data) - assert ( - self.instance is not None - ), "`aupdate()` did not return an object instance." + assert self.instance is not None, "`aupdate()` did not return an object instance." else: self.instance = await self.acreate(validated_data) - assert ( - self.instance is not None - ), "`acreate()` did not return an object instance." + assert self.instance is not None, "`acreate()` did not return an object instance." return self.instance + async def ais_valid(self, *, raise_exception=False): + assert hasattr(self, "initial_data"), ( + "Cannot call `.is_valid()` as no `data=` keyword argument was " + "passed when instantiating the serializer instance." + ) + + if not hasattr(self, "_validated_data"): + try: + self._validated_data = await self.arun_validation(self.initial_data) + except ValidationError as exc: + self._validated_data = {} + self._errors = exc.detail + else: + self._errors = {} + + if self._errors and raise_exception: + raise ValidationError(self.errors) + + return not bool(self._errors) + class Serializer(BaseSerializer, DRFSerializer): @async_property @@ -132,6 +169,39 @@ async def adata(self): return ReturnDict(ret, serializer=self) + async def ato_internal_value(self, data): + """ + Dict of native values <- Dict of primitive datatypes. + """ + if not isinstance(data, Mapping): + message = self.error_messages["invalid"].format(datatype=type(data).__name__) + raise ValidationError({api_settings.NON_FIELD_ERRORS_KEY: [message]}, code="invalid") + + ret = {} + errors = {} + fields = self._writable_fields + + for field in fields: + validate_method = getattr(self, "validate_" + field.field_name, None) + primitive_value = field.get_value(data) + try: + validated_value = await field.arun_validation(primitive_value) + if validate_method is not None: + validated_value = validate_method(validated_value) + except ValidationError as exc: + errors[field.field_name] = exc.detail + except DjangoValidationError as exc: + errors[field.field_name] = get_error_detail(exc) + except SkipField: + pass + else: + self.set_value(ret, field.source_attrs, validated_value) + + if errors: + raise ValidationError(errors) + + return ret + async def ato_representation(self, instance): """ Object instance -> Dict of primitive datatypes. @@ -142,19 +212,15 @@ async def ato_representation(self, instance): for field in fields: try: - attribute = field.get_attribute(instance) + attribute = await field.aget_attribute(instance) except SkipField: continue - check_for_none = ( - attribute.pk if isinstance(attribute, models.Model) else attribute - ) + check_for_none = attribute.pk if isinstance(attribute, models.Model) else attribute if check_for_none is None: ret[field.field_name] = None else: - if asyncio.iscoroutinefunction( - getattr(field, "ato_representation", None) - ): + if asyncio.iscoroutinefunction(getattr(field, "ato_representation", None)): repr = await field.ato_representation(attribute) else: repr = field.to_representation(attribute) @@ -198,14 +264,10 @@ async def asave(self, **kwargs): if self.instance is not None: self.instance = await self.aupdate(self.instance, validated_data) - assert ( - self.instance is not None - ), "`aupdate()` did not return an object instance." + assert self.instance is not None, "`aupdate()` did not return an object instance." else: self.instance = await self.acreate(validated_data) - assert ( - self.instance is not None - ), "`acreate()` did not return an object instance." + assert self.instance is not None, "`acreate()` did not return an object instance." return self.instance @@ -226,8 +288,121 @@ async def adata(self): async def acreate(self, validated_data): return [await self.child.acreate(attrs) for attrs in validated_data] + async def ais_valid(self, *, raise_exception=False): + # This implementation is the same as the default, + # except that we use lists, rather than dicts, as the empty case. + assert hasattr(self, "initial_data"), ( + "Cannot call `.ais_valid()` as no `data=` keyword argument was " + "passed when instantiating the serializer instance." + ) + + if not hasattr(self, "_validated_data"): + try: + self._validated_data = await self.arun_validation(self.initial_data) + except ValidationError as exc: + self._validated_data = [] + self._errors = exc.detail + else: + self._errors = [] + + if self._errors and raise_exception: + raise ValidationError(self.errors) + + return not bool(self._errors) + + async def ato_internal_value(self, data): + """ + List of dicts of native values <- List of dicts of primitive datatypes. + """ + if html.is_html_input(data): + data = html.parse_html_list(data, default=[]) + + if not isinstance(data, list): + message = self.error_messages["not_a_list"].format(input_type=type(data).__name__) + raise ValidationError({api_settings.NON_FIELD_ERRORS_KEY: [message]}, code="not_a_list") + + if not self.allow_empty and len(data) == 0: + message = self.error_messages["empty"] + raise ValidationError({api_settings.NON_FIELD_ERRORS_KEY: [message]}, code="empty") + + if self.max_length is not None and len(data) > self.max_length: + message = self.error_messages["max_length"].format(max_length=self.max_length) + raise ValidationError({api_settings.NON_FIELD_ERRORS_KEY: [message]}, code="max_length") + + if self.min_length is not None and len(data) < self.min_length: + message = self.error_messages["min_length"].format(min_length=self.min_length) + raise ValidationError({api_settings.NON_FIELD_ERRORS_KEY: [message]}, code="min_length") + + ret = [] + errors = [] + + for item in data: + try: + validated = await self.arun_child_validation(item) + except ValidationError as exc: + errors.append(exc.detail) + else: + ret.append(validated) + errors.append({}) + + if any(errors): + raise ValidationError(errors) + + return ret + + async def arun_child_validation(self, data): + """ + Run validation on child serializer. + You may need to override this method to support multiple updates. For example: + + self.child.instance = self.instance.get(pk=data['id']) + self.child.initial_data = data + return super().run_child_validation(data) + """ + return await self.child.arun_validation(data) + + +class ModelSerializer(DRFModelSerializer, Serializer): + serializer_field_mapping = { + models.AutoField: IntegerField, + models.BigIntegerField: IntegerField, + models.BooleanField: BooleanField, + models.CharField: CharField, + models.CommaSeparatedIntegerField: CharField, + models.DateField: DateField, + models.DateTimeField: DateTimeField, + models.DecimalField: DecimalField, + models.DurationField: DurationField, + models.EmailField: EmailField, + models.Field: ModelField, + models.FileField: FileField, + models.FloatField: FloatField, + models.ImageField: ImageField, + models.IntegerField: IntegerField, + models.NullBooleanField: BooleanField, + models.PositiveIntegerField: IntegerField, + models.PositiveSmallIntegerField: IntegerField, + models.SlugField: SlugField, + models.SmallIntegerField: IntegerField, + models.TextField: CharField, + models.TimeField: TimeField, + models.URLField: URLField, + models.UUIDField: UUIDField, + models.GenericIPAddressField: IPAddressField, + models.FilePathField: FilePathField, + } + + if hasattr(models, "JSONField"): + serializer_field_mapping[models.JSONField] = JSONField + if postgres_fields: + serializer_field_mapping[postgres_fields.HStoreField] = HStoreField + serializer_field_mapping[postgres_fields.ArrayField] = ListField + serializer_field_mapping[postgres_fields.JSONField] = JSONField + serializer_related_field = PrimaryKeyRelatedField + serializer_related_to_field = SlugRelatedField + serializer_url_field = HyperlinkedIdentityField + serializer_choice_field = ChoiceField -class ModelSerializer(Serializer, DRFModelSerializer): async def acreate(self, validated_data): """ Create and return a new `Snippet` instance, given the validated data. diff --git a/adrf/shortcuts.py b/adrf/shortcuts.py index 7c27cab..35fb0f4 100644 --- a/adrf/shortcuts.py +++ b/adrf/shortcuts.py @@ -10,16 +10,11 @@ async def aget_object_or_404(klass, *args, **kwargs): """See get_object_or_404().""" queryset = _get_queryset(klass) if not hasattr(queryset, "aget"): - klass__name = ( - klass.__name__ if isinstance(klass, type) else klass.__class__.__name__ - ) + klass__name = klass.__name__ if isinstance(klass, type) else klass.__class__.__name__ raise ValueError( - "First argument to aget_object_or_404() must be a Model, Manager, or " - f"QuerySet, not '{klass__name}'." + "First argument to aget_object_or_404() must be a Model, Manager, or " f"QuerySet, not '{klass__name}'." ) try: return await queryset.aget(*args, **kwargs) except queryset.model.DoesNotExist: - raise Http404( - f"No {queryset.model._meta.object_name} matches the given query." - ) + raise Http404(f"No {queryset.model._meta.object_name} matches the given query.") diff --git a/adrf/test.py b/adrf/test.py index a8e93ad..73e7a6d 100644 --- a/adrf/test.py +++ b/adrf/test.py @@ -45,9 +45,7 @@ def _encode_data(self, data, format=None, content_type=None): if data is None: return ("", content_type) - assert ( - format is None or content_type is None - ), "You may not set both `format` and `content_type`." + assert format is None or content_type is None, "You may not set both `format` and `content_type`." if content_type: # Content type specified explicitly, treat data as a raw bytestring @@ -55,14 +53,11 @@ def _encode_data(self, data, format=None, content_type=None): else: format = format or self.default_format - + avalible = ", ".join(["'" + fmt + "'" for fmt in self.renderer_classes]) assert format in self.renderer_classes, ( - "Invalid format '{}'. Available formats are {}. " + f"Invalid format '{format}'. Available formats are {avalible}. " "Set TEST_REQUEST_RENDERER_CLASSES to enable " - "extra request formats.".format( - format, - ", ".join(["'" + fmt + "'" for fmt in self.renderer_classes]), - ) + "extra request formats." ) # Use format and render the data into a bytestring @@ -72,7 +67,7 @@ def _encode_data(self, data, format=None, content_type=None): # Determine the content-type header from the renderer content_type = renderer.media_type if renderer.charset: - content_type = "{}; charset={}".format(content_type, renderer.charset) + content_type = f"{content_type}; charset={renderer.charset}" # Coerce text to bytes if required. if isinstance(ret, str): @@ -173,33 +168,23 @@ def get(self, path, data=None, **extra): return response def post(self, path, data=None, format=None, content_type=None, **extra): - response = super().post( - path, data=data, format=format, content_type=content_type, **extra - ) + response = super().post(path, data=data, format=format, content_type=content_type, **extra) return response def put(self, path, data=None, format=None, content_type=None, **extra): - response = super().put( - path, data=data, format=format, content_type=content_type, **extra - ) + response = super().put(path, data=data, format=format, content_type=content_type, **extra) return response def patch(self, path, data=None, format=None, content_type=None, **extra): - response = super().patch( - path, data=data, format=format, content_type=content_type, **extra - ) + response = super().patch(path, data=data, format=format, content_type=content_type, **extra) return response def delete(self, path, data=None, format=None, content_type=None, **extra): - response = super().delete( - path, data=data, format=format, content_type=content_type, **extra - ) + response = super().delete(path, data=data, format=format, content_type=content_type, **extra) return response def options(self, path, data=None, format=None, content_type=None, **extra): - response = super().options( - path, data=data, format=format, content_type=content_type, **extra - ) + response = super().options(path, data=data, format=format, content_type=content_type, **extra) return response def logout(self): diff --git a/adrf/utils.py b/adrf/utils.py index 97c0447..cbc0ac8 100644 --- a/adrf/utils.py +++ b/adrf/utils.py @@ -1,4 +1,11 @@ import inspect +from collections.abc import Mapping + +from asgiref.sync import sync_to_async +from django.core.exceptions import ObjectDoesNotExist +from django.db.models import ForeignKey +from django.db.models import Model +from rest_framework.fields import is_simple_callable # NOTE This function was taken from the python library and modified @@ -15,10 +22,7 @@ def getmembers(object, predicate, exclude_names=[]): try: for base in object.__bases__: for k, v in base.__dict__.items(): - if ( - isinstance(v, inspect.types.DynamicClassAttribute) - and k not in exclude_names - ): + if isinstance(v, inspect.types.DynamicClassAttribute) and k not in exclude_names: names.append(k) except AttributeError: pass @@ -47,3 +51,109 @@ def getmembers(object, predicate, exclude_names=[]): processed.add(key) results.sort(key=lambda pair: pair[0]) return results + + +class async_attrgetter: + """ + Return a callable object that fetches the given attribute(s) from its operand asynchronously. + After f = async_attrgetter('name'), the call await f(r) returns r.name. + After g = async_attrgetter('name', 'date'), the call await g(r) returns (r.name, r.date). + After h = async_attrgetter('name.first', 'name.last'), the call await h(r) returns + (r.name.first, r.name.last). + """ + + __slots__ = ("_attrs", "_call") + + def __init__(self, attr, *attrs): + if not attrs: + if not isinstance(attr, str): + raise TypeError("attribute name must be a string") + self._attrs = (attr,) + names = attr.split(".") + + async def func(obj): + for name in names: + # Проверяем, является ли текущий атрибут внешним ключом + if ( + isinstance(obj, Model) + and name in [f.name for f in obj.__class__._meta.fields] + and isinstance(obj.__class__._meta.get_field(name), ForeignKey) + ): + # Если это внешний ключ, выполняем асинхронный запрос + obj = await obj.__class__.objects.select_related(name).aget(pk=obj.pk) + obj = getattr(obj, name) + return obj + + self._call = func + else: + self._attrs = (attr,) + attrs + getters = tuple(map(async_attrgetter, self._attrs)) + + async def func(obj): + return tuple(await getter(obj) for getter in getters) + + self._call = func + + async def __call__(self, obj): + return await self._call(obj) + + def __repr__(self): + return "%s.%s(%s)" % (self.__class__.__module__, self.__class__.__qualname__, ", ".join(map(repr, self._attrs))) + + def __reduce__(self): + return self.__class__, self._attrs + + +async def aget_attribute(instance, attrs): + """ + Similar to Python's built in `getattr(instance, attr)`, + but takes a list of nested attributes, instead of a single attribute. + + Also accepts either attribute lookup on objects or dictionary lookups. + This version is asynchronous and supports Django ORM's async methods. + """ + for attr in attrs: + try: + if isinstance(instance, Mapping): + instance = instance[attr] + else: + # Проверяем, является ли instance моделью Django + if isinstance(instance, Model): + # Получаем поле модели + field = instance.__class__._meta.get_field(attr) + # Если это внешний ключ, выполняем асинхронный запрос + if isinstance(field, ForeignKey): + instance = await instance.__class__.objects.select_related(attr).aget(pk=instance.pk) + # Получаем атрибут + instance = getattr(instance, attr) + except ObjectDoesNotExist: + return None + + # Если атрибут является callable-объектом, вызываем его + + try: + if is_async_callable(instance): + instance = await instance() + elif is_simple_callable(instance): + instance = await sync_to_async(instance)() + except (AttributeError, KeyError) as exc: + # Если вызов callable вызвал исключение, поднимаем ValueError + raise ValueError(f'Exception raised in callable attribute "{attr}"; original exception was: {exc}') + + return instance + + +def is_async_callable(obj): + """ + Проверяет, является ли объект асинхронно вызываемым. + Использует модуль inspect для более надёжной проверки. + """ + if not callable(obj): + return False + # Проверяем, является ли объект асинхронной функцией или методом + if inspect.iscoroutinefunction(obj): + return True + # Проверяем, является ли объект асинхронным callable объектом (например, объект с методом __call__) + if inspect.isawaitable(obj) and hasattr(obj, "__call__"): + return True + return False diff --git a/adrf/views.py b/adrf/views.py index 27fe8af..e9cf483 100755 --- a/adrf/views.py +++ b/adrf/views.py @@ -1,7 +1,9 @@ import asyncio -from typing import List, Optional +from typing import List +from typing import Optional -from asgiref.sync import async_to_sync, sync_to_async +from asgiref.sync import async_to_sync +from asgiref.sync import sync_to_async from rest_framework.permissions import BasePermission from rest_framework.request import Request from rest_framework.throttling import BaseThrottle @@ -27,9 +29,7 @@ def sync_dispatch(self, request, *args, **kwargs): # Get the appropriate handler method if request.method.lower() in self.http_method_names: - handler = getattr( - self, request.method.lower(), self.http_method_not_allowed - ) + handler = getattr(self, request.method.lower(), self.http_method_not_allowed) else: handler = self.http_method_not_allowed @@ -58,9 +58,7 @@ async def async_dispatch(self, request, *args, **kwargs): # Get the appropriate handler method if request.method.lower() in self.http_method_names: - handler = getattr( - self, request.method.lower(), self.http_method_not_allowed - ) + handler = getattr(self, request.method.lower(), self.http_method_not_allowed) else: handler = self.http_method_not_allowed @@ -119,9 +117,7 @@ def check_permissions(self, request: Request) -> None: if sync_permissions: self.check_sync_permissions(request, sync_permissions) - async def check_async_permissions( - self, request: AsyncRequest, permissions: List[BasePermission] - ) -> None: + async def check_async_permissions(self, request: AsyncRequest, permissions: List[BasePermission]) -> None: """ Check if the request should be permitted asynchronously. Raises an appropriate exception if the request is not permitted. @@ -142,9 +138,7 @@ async def check_async_permissions( code=getattr(has_permission, "code", None), ) - def check_sync_permissions( - self, request: Request, permissions: List[BasePermission] - ) -> None: + def check_sync_permissions(self, request: Request, permissions: List[BasePermission]) -> None: """ Check if the request should be permitted synchronously. Raises an appropriate exception if the request is not permitted. @@ -173,9 +167,7 @@ def check_object_permissions(self, request: Request, obj) -> None: sync_permissions.append(permission) if async_permissions: - async_to_sync(self.check_async_object_permissions)( - request, async_permissions, obj - ) + async_to_sync(self.check_async_object_permissions)(request, async_permissions, obj) if sync_permissions: self.check_sync_object_permissions(request, sync_permissions, obj) @@ -189,10 +181,7 @@ async def check_async_object_permissions( """ has_object_permissions = await asyncio.gather( - *[ - permission.has_object_permission(request, self, obj) - for permission in permissions - ], + *[permission.has_object_permission(request, self, obj) for permission in permissions], return_exceptions=True, ) @@ -206,9 +195,7 @@ async def check_async_object_permissions( code=getattr(has_object_permission, "code", None), ) - def check_sync_object_permissions( - self, request: Request, permissions: List[BasePermission], obj - ) -> None: + def check_sync_object_permissions(self, request: Request, permissions: List[BasePermission], obj) -> None: """ Check if the request should be permitted synchronously. Raises an appropriate exception if the request is not permitted. @@ -244,16 +231,12 @@ def check_throttles(self, request: Request) -> None: throttle_durations.extend(self.check_sync_throttles(request, sync_throttles)) - throttle_durations.extend( - async_to_sync(self.check_async_throttles)(request, async_throttles) - ) + throttle_durations.extend(async_to_sync(self.check_async_throttles)(request, async_throttles)) if throttle_durations: # Filter out `None` values which may happen in case of config / rate # changes, see #1438 - durations = [ - duration for duration in throttle_durations if duration is not None - ] + durations = [duration for duration in throttle_durations if duration is not None] duration = max(durations, default=None) self.throttled(request, duration) @@ -274,9 +257,7 @@ async def check_async_throttles( return throttle_durations - def check_sync_throttles( - self, request: Request, throttles: List[BaseThrottle] - ) -> List[Optional[float]]: + def check_sync_throttles(self, request: Request, throttles: List[BaseThrottle]) -> List[Optional[float]]: """ Check if the request should be throttled synchronously. Raises an appropriate exception if the request is throttled. diff --git a/adrf/viewsets.py b/adrf/viewsets.py index 3f5b8db..86fbfe1 100644 --- a/adrf/viewsets.py +++ b/adrf/viewsets.py @@ -64,15 +64,12 @@ def as_view(cls, actions=None, **initkwargs): "keyword argument to %s(). Don't do that." % (key, cls.__name__) ) if not hasattr(cls, key): - raise TypeError( - "%s() received an invalid keyword %r" % (cls.__name__, key) - ) + raise TypeError("%s() received an invalid keyword %r" % (cls.__name__, key)) # name and suffix are mutually exclusive if "name" in initkwargs and "suffix" in initkwargs: raise TypeError( - "%s() received both `name` and `suffix`, which are " - "mutually exclusive arguments." % (cls.__name__) + "%s() received both `name` and `suffix`, which are " "mutually exclusive arguments." % (cls.__name__) ) def view(request, *args, **kwargs): @@ -157,9 +154,7 @@ def view_is_async(cls): """ return any( asyncio.iscoroutinefunction(function) - for name, function in getmembers( - cls, inspect.iscoroutinefunction, exclude_names=["view_is_async"] - ) + for name, function in getmembers(cls, inspect.iscoroutinefunction, exclude_names=["view_is_async"]) if not name.startswith("__") and name not in cls._ASYNC_NON_DISPATCH_METHODS ) @@ -167,20 +162,18 @@ def view_is_async(cls): class GenericViewSet(ViewSet, GenericAPIView): _ASYNC_NON_DISPATCH_METHODS = ViewSet._ASYNC_NON_DISPATCH_METHODS + [ "aget_object", + "aget_queryset", + "afilter_queryset", "apaginate_queryset", "get_apaginated_response", ] -class ReadOnlyModelViewSet( - mixins.RetrieveModelMixin, mixins.ListModelMixin, GenericViewSet -): +class ReadOnlyModelViewSet(mixins.RetrieveModelMixin, mixins.ListModelMixin, GenericViewSet): """ A viewset that provides default asynchronous `list()` and `retrieve()` actions. """ - pass - class ModelViewSet( mixins.CreateModelMixin, @@ -194,5 +187,3 @@ class ModelViewSet( A viewset that provides default asynchronous `create()`, `retrieve()`, `update()`, `partial_update()`, `destroy()` and `list()` actions. """ - - pass diff --git a/poetry.lock b/poetry.lock index faa9da0..38fbcad 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "asgiref" @@ -28,34 +28,6 @@ files = [ {file = "async_property-0.2.2.tar.gz", hash = "sha256:17d9bd6ca67e27915a75d92549df64b5c7174e9dc806b30a3934dc4ff0506380"}, ] -[[package]] -name = "backports-zoneinfo" -version = "0.2.1" -description = "Backport of the standard library zoneinfo module" -optional = false -python-versions = ">=3.6" -files = [ - {file = "backports.zoneinfo-0.2.1-cp36-cp36m-macosx_10_14_x86_64.whl", hash = "sha256:da6013fd84a690242c310d77ddb8441a559e9cb3d3d59ebac9aca1a57b2e18bc"}, - {file = "backports.zoneinfo-0.2.1-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:89a48c0d158a3cc3f654da4c2de1ceba85263fafb861b98b59040a5086259722"}, - {file = "backports.zoneinfo-0.2.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:1c5742112073a563c81f786e77514969acb58649bcdf6cdf0b4ed31a348d4546"}, - {file = "backports.zoneinfo-0.2.1-cp36-cp36m-win32.whl", hash = "sha256:e8236383a20872c0cdf5a62b554b27538db7fa1bbec52429d8d106effbaeca08"}, - {file = "backports.zoneinfo-0.2.1-cp36-cp36m-win_amd64.whl", hash = "sha256:8439c030a11780786a2002261569bdf362264f605dfa4d65090b64b05c9f79a7"}, - {file = "backports.zoneinfo-0.2.1-cp37-cp37m-macosx_10_14_x86_64.whl", hash = "sha256:f04e857b59d9d1ccc39ce2da1021d196e47234873820cbeaad210724b1ee28ac"}, - {file = "backports.zoneinfo-0.2.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:17746bd546106fa389c51dbea67c8b7c8f0d14b5526a579ca6ccf5ed72c526cf"}, - {file = "backports.zoneinfo-0.2.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:5c144945a7752ca544b4b78c8c41544cdfaf9786f25fe5ffb10e838e19a27570"}, - {file = "backports.zoneinfo-0.2.1-cp37-cp37m-win32.whl", hash = "sha256:e55b384612d93be96506932a786bbcde5a2db7a9e6a4bb4bffe8b733f5b9036b"}, - {file = "backports.zoneinfo-0.2.1-cp37-cp37m-win_amd64.whl", hash = "sha256:a76b38c52400b762e48131494ba26be363491ac4f9a04c1b7e92483d169f6582"}, - {file = "backports.zoneinfo-0.2.1-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:8961c0f32cd0336fb8e8ead11a1f8cd99ec07145ec2931122faaac1c8f7fd987"}, - {file = "backports.zoneinfo-0.2.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:e81b76cace8eda1fca50e345242ba977f9be6ae3945af8d46326d776b4cf78d1"}, - {file = "backports.zoneinfo-0.2.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:7b0a64cda4145548fed9efc10322770f929b944ce5cee6c0dfe0c87bf4c0c8c9"}, - {file = "backports.zoneinfo-0.2.1-cp38-cp38-win32.whl", hash = "sha256:1b13e654a55cd45672cb54ed12148cd33628f672548f373963b0bff67b217328"}, - {file = "backports.zoneinfo-0.2.1-cp38-cp38-win_amd64.whl", hash = "sha256:4a0f800587060bf8880f954dbef70de6c11bbe59c673c3d818921f042f9954a6"}, - {file = "backports.zoneinfo-0.2.1.tar.gz", hash = "sha256:fadbfe37f74051d024037f223b8e001611eac868b5c5b06144ef4d8b799862f2"}, -] - -[package.extras] -tzdata = ["tzdata"] - [[package]] name = "cachetools" version = "5.4.0" @@ -67,6 +39,17 @@ files = [ {file = "cachetools-5.4.0.tar.gz", hash = "sha256:b8adc2e7c07f105ced7bc56dbb6dfbe7c4a00acce20e2227b3f355be89bc6827"}, ] +[[package]] +name = "cfgv" +version = "3.4.0" +description = "Validate configuration and produce human readable error messages." +optional = false +python-versions = ">=3.8" +files = [ + {file = "cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9"}, + {file = "cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560"}, +] + [[package]] name = "chardet" version = "5.2.0" @@ -180,7 +163,6 @@ files = [ [package.dependencies] asgiref = ">=3.6.0,<4" -"backports.zoneinfo" = {version = "*", markers = "python_version < \"3.9\""} sqlparse = ">=0.3.1" tzdata = {version = "*", markers = "sys_platform == \"win32\""} @@ -200,7 +182,6 @@ files = [ ] [package.dependencies] -"backports.zoneinfo" = {version = "*", markers = "python_version < \"3.9\""} django = ">=4.2" [[package]] @@ -247,6 +228,20 @@ docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1 testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)", "virtualenv (>=20.26.2)"] typing = ["typing-extensions (>=4.8)"] +[[package]] +name = "identify" +version = "2.6.7" +description = "File identification library for Python" +optional = false +python-versions = ">=3.9" +files = [ + {file = "identify-2.6.7-py2.py3-none-any.whl", hash = "sha256:155931cb617a401807b09ecec6635d6c692d180090a1cedca8ef7d58ba5b6aa0"}, + {file = "identify-2.6.7.tar.gz", hash = "sha256:3fa266b42eba321ee0b2bb0936a6a6b9e36a1351cbb69055b3082f4193035684"}, +] + +[package.extras] +license = ["ukkonen"] + [[package]] name = "iniconfig" version = "2.0.0" @@ -258,6 +253,17 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] +[[package]] +name = "nodeenv" +version = "1.9.1" +description = "Node.js virtual environment builder" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +files = [ + {file = "nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9"}, + {file = "nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f"}, +] + [[package]] name = "packaging" version = "24.1" @@ -300,6 +306,24 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "pre-commit" +version = "4.1.0" +description = "A framework for managing and maintaining multi-language pre-commit hooks." +optional = false +python-versions = ">=3.9" +files = [ + {file = "pre_commit-4.1.0-py2.py3-none-any.whl", hash = "sha256:d29e7cb346295bcc1cc75fc3e92e343495e3ea0196c9ec6ba53f49f10ab6ae7b"}, + {file = "pre_commit-4.1.0.tar.gz", hash = "sha256:ae3f018575a588e30dfddfab9a05448bfbd6b73d78709617b5a2b853549716d4"}, +] + +[package.dependencies] +cfgv = ">=2.0.0" +identify = ">=1.0.0" +nodeenv = ">=0.11.1" +pyyaml = ">=5.1" +virtualenv = ">=20.10.0" + [[package]] name = "pyproject-api" version = "1.7.1" @@ -391,6 +415,68 @@ files = [ [package.dependencies] six = ">=1.5" +[[package]] +name = "pyyaml" +version = "6.0.2" +description = "YAML parser and emitter for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "PyYAML-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0a9a2848a5b7feac301353437eb7d5957887edbf81d56e903999a75a3d743086"}, + {file = "PyYAML-6.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:29717114e51c84ddfba879543fb232a6ed60086602313ca38cce623c1d62cfbf"}, + {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8824b5a04a04a047e72eea5cec3bc266db09e35de6bdfe34c9436ac5ee27d237"}, + {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c36280e6fb8385e520936c3cb3b8042851904eba0e58d277dca80a5cfed590b"}, + {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec031d5d2feb36d1d1a24380e4db6d43695f3748343d99434e6f5f9156aaa2ed"}, + {file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:936d68689298c36b53b29f23c6dbb74de12b4ac12ca6cfe0e047bedceea56180"}, + {file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:23502f431948090f597378482b4812b0caae32c22213aecf3b55325e049a6c68"}, + {file = "PyYAML-6.0.2-cp310-cp310-win32.whl", hash = "sha256:2e99c6826ffa974fe6e27cdb5ed0021786b03fc98e5ee3c5bfe1fd5015f42b99"}, + {file = "PyYAML-6.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:a4d3091415f010369ae4ed1fc6b79def9416358877534caf6a0fdd2146c87a3e"}, + {file = "PyYAML-6.0.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cc1c1159b3d456576af7a3e4d1ba7e6924cb39de8f67111c735f6fc832082774"}, + {file = "PyYAML-6.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1e2120ef853f59c7419231f3bf4e7021f1b936f6ebd222406c3b60212205d2ee"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d225db5a45f21e78dd9358e58a98702a0302f2659a3c6cd320564b75b86f47c"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ac9328ec4831237bec75defaf839f7d4564be1e6b25ac710bd1a96321cc8317"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ad2a3decf9aaba3d29c8f537ac4b243e36bef957511b4766cb0057d32b0be85"}, + {file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ff3824dc5261f50c9b0dfb3be22b4567a6f938ccce4587b38952d85fd9e9afe4"}, + {file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:797b4f722ffa07cc8d62053e4cff1486fa6dc094105d13fea7b1de7d8bf71c9e"}, + {file = "PyYAML-6.0.2-cp311-cp311-win32.whl", hash = "sha256:11d8f3dd2b9c1207dcaf2ee0bbbfd5991f571186ec9cc78427ba5bd32afae4b5"}, + {file = "PyYAML-6.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:e10ce637b18caea04431ce14fabcf5c64a1c61ec9c56b071a4b7ca131ca52d44"}, + {file = "PyYAML-6.0.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c70c95198c015b85feafc136515252a261a84561b7b1d51e3384e0655ddf25ab"}, + {file = "PyYAML-6.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ce826d6ef20b1bc864f0a68340c8b3287705cae2f8b4b1d932177dcc76721725"}, + {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f71ea527786de97d1a0cc0eacd1defc0985dcf6b3f17bb77dcfc8c34bec4dc5"}, + {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9b22676e8097e9e22e36d6b7bda33190d0d400f345f23d4065d48f4ca7ae0425"}, + {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80bab7bfc629882493af4aa31a4cfa43a4c57c83813253626916b8c7ada83476"}, + {file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:0833f8694549e586547b576dcfaba4a6b55b9e96098b36cdc7ebefe667dfed48"}, + {file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8b9c7197f7cb2738065c481a0461e50ad02f18c78cd75775628afb4d7137fb3b"}, + {file = "PyYAML-6.0.2-cp312-cp312-win32.whl", hash = "sha256:ef6107725bd54b262d6dedcc2af448a266975032bc85ef0172c5f059da6325b4"}, + {file = "PyYAML-6.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8"}, + {file = "PyYAML-6.0.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:efdca5630322a10774e8e98e1af481aad470dd62c3170801852d752aa7a783ba"}, + {file = "PyYAML-6.0.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:50187695423ffe49e2deacb8cd10510bc361faac997de9efef88badc3bb9e2d1"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ffe8360bab4910ef1b9e87fb812d8bc0a308b0d0eef8c8f44e0254ab3b07133"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:17e311b6c678207928d649faa7cb0d7b4c26a0ba73d41e99c4fff6b6c3276484"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70b189594dbe54f75ab3a1acec5f1e3faa7e8cf2f1e08d9b561cb41b845f69d5"}, + {file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:41e4e3953a79407c794916fa277a82531dd93aad34e29c2a514c2c0c5fe971cc"}, + {file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:68ccc6023a3400877818152ad9a1033e3db8625d899c72eacb5a668902e4d652"}, + {file = "PyYAML-6.0.2-cp313-cp313-win32.whl", hash = "sha256:bc2fa7c6b47d6bc618dd7fb02ef6fdedb1090ec036abab80d4681424b84c1183"}, + {file = "PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563"}, + {file = "PyYAML-6.0.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:24471b829b3bf607e04e88d79542a9d48bb037c2267d7927a874e6c205ca7e9a"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7fded462629cfa4b685c5416b949ebad6cec74af5e2d42905d41e257e0869f5"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d84a1718ee396f54f3a086ea0a66d8e552b2ab2017ef8b420e92edbc841c352d"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9056c1ecd25795207ad294bcf39f2db3d845767be0ea6e6a34d856f006006083"}, + {file = "PyYAML-6.0.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:82d09873e40955485746739bcb8b4586983670466c23382c19cffecbf1fd8706"}, + {file = "PyYAML-6.0.2-cp38-cp38-win32.whl", hash = "sha256:43fa96a3ca0d6b1812e01ced1044a003533c47f6ee8aca31724f78e93ccc089a"}, + {file = "PyYAML-6.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:01179a4a8559ab5de078078f37e5c1a30d76bb88519906844fd7bdea1b7729ff"}, + {file = "PyYAML-6.0.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:688ba32a1cffef67fd2e9398a2efebaea461578b0923624778664cc1c914db5d"}, + {file = "PyYAML-6.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a8786accb172bd8afb8be14490a16625cbc387036876ab6ba70912730faf8e1f"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8e03406cac8513435335dbab54c0d385e4a49e4945d2909a581c83647ca0290"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f753120cb8181e736c57ef7636e83f31b9c0d1722c516f7e86cf15b7aa57ff12"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3b1fdb9dc17f5a7677423d508ab4f243a726dea51fa5e70992e59a7411c89d19"}, + {file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0b69e4ce7a131fe56b7e4d770c67429700908fc0752af059838b1cfb41960e4e"}, + {file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a9f8c2e67970f13b16084e04f134610fd1d374bf477b17ec1599185cf611d725"}, + {file = "PyYAML-6.0.2-cp39-cp39-win32.whl", hash = "sha256:6395c297d42274772abc367baaa79683958044e5d3835486c16da75d2a694631"}, + {file = "PyYAML-6.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:39693e1f8320ae4f43943590b49779ffb98acb81f788220ea932a6b6c51004d8"}, + {file = "pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e"}, +] + [[package]] name = "ruff" version = "0.5.5" @@ -526,5 +612,5 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess [metadata] lock-version = "2.0" -python-versions = ">=3.8" -content-hash = "dc9cc7e34b5f2fa502cb78551e5bc68afa150028e1182e428d5f170a702df497" +python-versions = ">=3.9" +content-hash = "7530d8d8f5ea490ce08760e4c767d69264e4b6d8a4eb8dd8858d420749372b2c" diff --git a/pyproject.toml b/pyproject.toml index 044d9fe..e5177b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "adrf" -version = "0.1.9" +version = "0.1.17" description = "Async support for Django REST framework" authors = ["Enrico Massa "] keywords = ["async", "django", "rest", "framework", "drf"] @@ -11,7 +11,7 @@ repository = "https://github.com/em1208/adrf" include = ["LICENSE"] [tool.poetry.dependencies] -python = ">=3.8" +python = ">=3.9" django = ">=4.1" djangorestframework = ">=3.14.0" async-property = ">=0.2.2" @@ -24,6 +24,9 @@ tox = "^4.16.0" faker = "^26.1.0" ruff = "^0.5.5" +[tool.poetry.group.dev.dependencies] +pre-commit = "^4.1.0" + [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 73ad047..4074777 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -1,10 +1,13 @@ import faker from django.contrib.auth.models import User from django.http import HttpResponse -from django.test import TestCase, override_settings -from rest_framework import permissions, status +from django.test import override_settings +from django.test import TestCase +from rest_framework import permissions +from rest_framework import status from rest_framework.authentication import BaseAuthentication -from rest_framework.decorators import authentication_classes, permission_classes +from rest_framework.decorators import authentication_classes +from rest_framework.decorators import permission_classes from rest_framework.exceptions import AuthenticationFailed from rest_framework.test import APIRequestFactory diff --git a/tests/test_generics.py b/tests/test_generics.py index 8e1ed83..23a996c 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -3,9 +3,10 @@ from rest_framework import status from rest_framework.test import APIRequestFactory -from adrf import generics, serializers - -from .models import Order, User +from .models import Order +from .models import User +from adrf import generics +from adrf import serializers factory = APIRequestFactory() diff --git a/tests/test_object_permissions.py b/tests/test_object_permissions.py index 8d049d0..b7e3089 100755 --- a/tests/test_object_permissions.py +++ b/tests/test_object_permissions.py @@ -1,6 +1,7 @@ from asgiref.sync import sync_to_async from django.http import HttpResponse -from django.test import TestCase, override_settings +from django.test import override_settings +from django.test import TestCase from rest_framework.permissions import BasePermission from rest_framework.test import APIRequestFactory @@ -59,17 +60,13 @@ class TestSyncObjectPermission(TestCase): async def test_sync_object_permission(self): request = factory.get("/sync/allow") - response = await ObjectPermissionTestView.as_view( - permission_classes=(SyncObjectPermission,) - )(request) + response = await ObjectPermissionTestView.as_view(permission_classes=(SyncObjectPermission,))(request) self.assertEqual(response.status_code, 200) async def test_sync_object_permission_reject(self): request = factory.get("/sync/reject") - response = await ObjectPermissionTestView.as_view( - permission_classes=(SyncObjectPermission,) - )(request) + response = await ObjectPermissionTestView.as_view(permission_classes=(SyncObjectPermission,))(request) self.assertEqual(response.status_code, 403) diff --git a/tests/test_permissions.py b/tests/test_permissions.py index 51ebfe7..80f0806 100644 --- a/tests/test_permissions.py +++ b/tests/test_permissions.py @@ -1,5 +1,6 @@ from django.http import HttpResponse -from django.test import TestCase, override_settings +from django.test import override_settings +from django.test import TestCase from rest_framework.permissions import BasePermission from rest_framework.test import APIRequestFactory diff --git a/tests/test_routers.py b/tests/test_routers.py index 6038587..6be7d97 100644 --- a/tests/test_routers.py +++ b/tests/test_routers.py @@ -1,4 +1,6 @@ -from django.test import Client, TestCase, override_settings +from django.test import Client +from django.test import override_settings +from django.test import TestCase from rest_framework.response import Response from rest_framework.viewsets import ModelViewSet as DRFModelViewSet @@ -68,9 +70,7 @@ def test_list(self): assert resp.data == {"method": "GET", "async": self.use_async} def test_create(self): - resp = self.client.post( - self.url, {"foo": "bar"}, content_type="application/json" - ) + resp = self.client.post(self.url, {"foo": "bar"}, content_type="application/json") assert resp.status_code == 200 assert resp.data == { "method": "POST", @@ -88,9 +88,7 @@ def test_retrieve(self): } def test_update(self): - resp = self.client.put( - self.detail_url, {"foo": "bar"}, content_type="application/json" - ) + resp = self.client.put(self.detail_url, {"foo": "bar"}, content_type="application/json") assert resp.status_code == 200 assert resp.data == { "method": "PUT", @@ -99,9 +97,7 @@ def test_update(self): } def test_partial_update(self): - resp = self.client.patch( - self.detail_url, {"foo": "bar"}, content_type="application/json" - ) + resp = self.client.patch(self.detail_url, {"foo": "bar"}, content_type="application/json") assert resp.status_code == 200 assert resp.data == { "method": "PATCH", diff --git a/tests/test_serializers.py b/tests/test_serializers.py index 2f0ffc3..55a7a18 100644 --- a/tests/test_serializers.py +++ b/tests/test_serializers.py @@ -2,13 +2,14 @@ from asgiref.sync import sync_to_async from django.test import TestCase -from rest_framework import serializers from rest_framework.test import APIRequestFactory +from .models import Order +from .models import User +from adrf import serializers from adrf.fields import SerializerMethodField -from adrf.serializers import ModelSerializer, Serializer - -from .models import Order, User +from adrf.serializers import ModelSerializer +from adrf.serializers import Serializer factory = APIRequestFactory() @@ -101,9 +102,7 @@ async def test_invalid_datatype(self): assert serializer.validated_data == {} assert await serializer.adata == {} - assert serializer.errors == { - "non_field_errors": ["Invalid data. Expected a dictionary, but got list."] - } + assert serializer.errors == {"non_field_errors": ["Invalid data. Expected a dictionary, but got list."]} async def test_partial_validation(self): data = { @@ -164,9 +163,7 @@ async def test_crud_serializer_update(self): assert serializer.is_valid() # Update the object - updated_object = await serializer.aupdate( - default_object, serializer.validated_data - ) + updated_object = await serializer.aupdate(default_object, serializer.validated_data) # Verify the object has been updated successfully assert isinstance(updated_object, MockObject) diff --git a/tests/test_shortcuts.py b/tests/test_shortcuts.py index 02701ff..45d1d55 100644 --- a/tests/test_shortcuts.py +++ b/tests/test_shortcuts.py @@ -1,9 +1,8 @@ from django.http import Http404 from django.test import TestCase -from adrf.shortcuts import aget_object_or_404 - from .models import User +from adrf.shortcuts import aget_object_or_404 class TestAGetObject(TestCase): diff --git a/tests/test_testmodule.py b/tests/test_testmodule.py index ee9539d..3da2a0b 100644 --- a/tests/test_testmodule.py +++ b/tests/test_testmodule.py @@ -1,11 +1,14 @@ from django.core.handlers.asgi import ASGIRequest -from django.test import TestCase, override_settings -from django.urls import path, reverse +from django.test import override_settings +from django.test import TestCase +from django.urls import path +from django.urls import reverse from rest_framework import status from rest_framework.response import Response from adrf.decorators import api_view -from adrf.test import AsyncAPIClient, AsyncAPIRequestFactory +from adrf.test import AsyncAPIClient +from adrf.test import AsyncAPIRequestFactory @api_view(["GET", "POST", "PUT", "PATCH"]) @@ -45,9 +48,7 @@ def setUp(self): def test_is_it_asgi(self): request = factory.request(path="/", method="GET") - assert isinstance( - request, ASGIRequest - ), f'Type of request is "{type(request).__name__}"' + assert isinstance(request, ASGIRequest), f'Type of request is "{type(request).__name__}"' def test_get_succeeds(self): request = factory.get("/") diff --git a/tests/test_throttling.py b/tests/test_throttling.py index 63aab9b..29d46a3 100644 --- a/tests/test_throttling.py +++ b/tests/test_throttling.py @@ -1,5 +1,6 @@ from django.http import HttpResponse -from django.test import TestCase, override_settings +from django.test import override_settings +from django.test import TestCase from rest_framework.test import APIRequestFactory from rest_framework.throttling import BaseThrottle diff --git a/tests/test_viewsets.py b/tests/test_viewsets.py index 90c260b..c2aec0a 100644 --- a/tests/test_viewsets.py +++ b/tests/test_viewsets.py @@ -6,8 +6,10 @@ from rest_framework.test import APIRequestFactory from adrf.serializers import ModelSerializer -from adrf.viewsets import ModelViewSet, ViewSet -from tests.test_views import JSON_ERROR, sanitise_json_error +from adrf.viewsets import ModelViewSet +from adrf.viewsets import ViewSet +from tests.test_views import JSON_ERROR +from tests.test_views import sanitise_json_error factory = APIRequestFactory() @@ -124,9 +126,7 @@ class UserViewSet(ModelViewSet): class ModelViewSetIntegrationTests(TestCase): def setUp(self): self.list_create = UserViewSet.as_view({"get": "alist", "post": "acreate"}) - self.retrieve_update = UserViewSet.as_view( - {"get": "aretrieve", "put": "aupdate"} - ) + self.retrieve_update = UserViewSet.as_view({"get": "aretrieve", "put": "aupdate"}) self.destroy = UserViewSet.as_view({"delete": "adestroy"}) def test_list_succeeds(self):