Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 126 additions & 2 deletions adrf/fields.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,131 @@
from rest_framework.serializers import SerializerMethodField as DRFSerializerMethodField
from rest_framework import serializers as drf_serializers
import asyncio


class SerializerMethodField(DRFSerializerMethodField):
class SerializerMethodField(drf_serializers.SerializerMethodField):
async def ato_representation(self, attribute):
method = getattr(self.parent, self.method_name)
return await method(attribute)


class AsyncFieldMixin:
async def ato_representation(self, value):
if asyncio.iscoroutine(value):
value = await value
return super().to_representation(value)


class IntegerField(AsyncFieldMixin, drf_serializers.IntegerField):
pass


class BooleanField(AsyncFieldMixin, drf_serializers.BooleanField):
pass


class CharField(AsyncFieldMixin, drf_serializers.CharField):
pass


class ChoiceField(AsyncFieldMixin, drf_serializers.ChoiceField):
pass


class DateField(AsyncFieldMixin, drf_serializers.DateField):
pass


class DateTimeField(AsyncFieldMixin, drf_serializers.DateTimeField):
pass


class DecimalField(AsyncFieldMixin, drf_serializers.DecimalField):
pass


class DictField(AsyncFieldMixin, drf_serializers.DictField):
pass


class DurationField(AsyncFieldMixin, drf_serializers.DurationField):
pass


class EmailField(AsyncFieldMixin, drf_serializers.EmailField):
pass


class Field(AsyncFieldMixin, drf_serializers.Field):
pass


class FileField(AsyncFieldMixin, drf_serializers.FileField):
pass


class FilePathField(AsyncFieldMixin, drf_serializers.FilePathField):
pass


class FloatField(AsyncFieldMixin, drf_serializers.FloatField):
pass


class HiddenField(AsyncFieldMixin, drf_serializers.HiddenField):
pass


class HStoreField(AsyncFieldMixin, drf_serializers.HStoreField):
pass


class IPAddressField(AsyncFieldMixin, drf_serializers.IPAddressField):
pass


class ImageField(AsyncFieldMixin, drf_serializers.ImageField):
pass


class IntegerField(AsyncFieldMixin, drf_serializers.IntegerField):
pass


class JSONField(AsyncFieldMixin, drf_serializers.JSONField):
pass


class ListField(AsyncFieldMixin, drf_serializers.ListField):
pass


class ModelField(AsyncFieldMixin, drf_serializers.ModelField):
pass


class MultipleChoiceField(AsyncFieldMixin, drf_serializers.MultipleChoiceField):
pass


class ReadOnlyField(AsyncFieldMixin, drf_serializers.ReadOnlyField):
pass


class RegexField(AsyncFieldMixin, drf_serializers.RegexField):
pass


class SlugField(AsyncFieldMixin, drf_serializers.SlugField):
pass


class TimeField(AsyncFieldMixin, drf_serializers.TimeField):
pass


class URLField(AsyncFieldMixin, drf_serializers.URLField):
pass


class UUIDField(AsyncFieldMixin, drf_serializers.UUIDField):
pass
66 changes: 65 additions & 1 deletion adrf/serializers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import traceback
from collections import OrderedDict
import uuid

from asgiref.sync import sync_to_async
from async_property import async_property
Expand All @@ -17,6 +18,37 @@
from rest_framework.serializers import ModelSerializer as DRFModelSerializer
from rest_framework.serializers import Serializer as DRFSerializer
from rest_framework.utils.serializer_helpers import ReturnDict, ReturnList
from adrf.fields import ( # NOQA # isort:skip
BooleanField,
CharField,
ChoiceField,
DateField,
DateTimeField,
DecimalField,
DictField,
DurationField,
EmailField,
Field,
FileField,
FilePathField,
FloatField,
HiddenField,
HStoreField,
IPAddressField,
ImageField,
IntegerField,
JSONField,
ListField,
ModelField,
MultipleChoiceField,
ReadOnlyField,
RegexField,
SerializerMethodField,
SlugField,
TimeField,
URLField,
UUIDField,
)


class BaseSerializer(DRFBaseSerializer):
Expand Down Expand Up @@ -122,6 +154,23 @@ async def asave(self, **kwargs):

return self.instance

async def aget_attribute(self, instance):
resolved_attrs = []
source_attrs_copy = self.source_attrs.copy()
try:
for attr in self.source_attrs:
if asyncio.iscoroutine(getattr(instance, attr, None)):
awaited_attr_name = f"_{attr}_{uuid.uuid4()}" # We use uuid to not hit existing model field
setattr(instance, awaited_attr_name, await getattr(instance, attr))
resolved_attrs.append(awaited_attr_name)
else:
resolved_attrs.append(attr)
self.source_attrs = resolved_attrs
attribute = self.get_attribute(instance)
finally:
self.source_attrs = source_attrs_copy
return attribute


class Serializer(BaseSerializer, DRFSerializer):
@async_property
Expand All @@ -144,7 +193,12 @@ async def ato_representation(self, instance):

for field in fields:
try:
attribute = await sync_to_async(field.get_attribute)(instance)
if asyncio.iscoroutinefunction(
getattr(field, "aget_attribute", None)
):
attribute = await field.aget_attribute(instance)
else:
attribute = await sync_to_async(field.get_attribute)(instance)
except SkipField:
continue

Expand Down Expand Up @@ -299,3 +353,13 @@ async def aupdate(self, instance, validated_data):
await field.aset(value)

return instance


def build_property_field(self, field_name, model_class):
"""
Handle async properties without defined Field.
"""
_, field_kwargs = super().build_property_field(field_name, model_class)
field_class = ReadOnlyField

return field_class, field_kwargs
26 changes: 26 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,29 @@ class ModelA(models.Model):

class ModelB(models.Model):
fielda = models.ForeignKey(ModelA, on_delete=models.CASCADE)


class Parent(models.Model):
name = models.CharField(default="foo")
description = models.CharField(default="bar")

@property
async def custom_name(self):
return self.name

@property
async def custom_description(self):
return self.description


class Child(models.Model):
name = models.CharField(default="foo")
parent = models.ForeignKey(Parent, on_delete=models.CASCADE)

@property
async def custom_name(self):
return self.name

@property
async def custom_parent(self):
return self.parent
42 changes: 41 additions & 1 deletion tests/test_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from adrf.fields import SerializerMethodField
from adrf.serializers import ModelSerializer, Serializer

from .models import ModelA, ModelB, Order, User
from .models import ModelA, ModelB, Order, User, Parent, Child

factory = APIRequestFactory()

Expand Down Expand Up @@ -380,3 +380,43 @@ async def test_order_serializer_for_depth_gt_0(self):
assert await order_serializer.adata == data
print("HELLO", await order_serializer.adata, await user_serializer.adata)
assert (await order_serializer.adata)["user"] == await user_serializer.adata


class TestModelAsyncPropertySerializer(TestCase):
def setUp(self):
class ParentSerializer(aserializers.ModelSerializer):
custom_name = aserializers.CharField()

class Meta:
model = Parent
fields = ["name", "custom_name", "custom_description"]


class ChildSerializer(aserializers.ModelSerializer):
custom_parent = TestSerializer()

class Meta:
model = Child
fields = ["custom_name", "custom_parent"]

self.parent_serializer = ParentSerializer
self.child_serializer = ChildSerializer

async def test_default_field_returns_value(self):
parent = Parent()
serializer = self.parent_serializer(instance=parent)
data = await serializer.adata
assert data["custom_description"] == await parent.custom_description

async def test_provided_field_returns_value(self):
parent = Parent()
serializer = self.parent_serializer(instance=parent)
data = await serializer.adata
assert data["custom_name"] == await parent.custom_name

async def test_nested_serializer_returns_value(self):
parent = Parent()
child = Child(parent=parent)
serializer = self.child_serializer(instance=child)
data = await serializer.adata
assert data["custom_parent"]["custom_name"] == await parent.custom_name