Skip to content

Commit a6fc693

Browse files
ferrantsrpkilby
authored andcommitted
Support to_field_name for related filters (philipn#294)
1 parent 604327f commit a6fc693

File tree

5 files changed

+103
-34
lines changed

5 files changed

+103
-34
lines changed

rest_framework_filters/filterset.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import copy
22
from collections import OrderedDict
33

4-
from django.db.models import Subquery
54
from django.db.models.constants import LOOKUP_SEP
65
from django_filters import filterset, rest_framework
76
from django_filters.utils import get_model_field
@@ -300,9 +299,13 @@ def filter_related_filtersets(self, queryset):
300299
if not any(value.startswith(prefix) for value in self.data):
301300
continue
302301

302+
field = self.filters[related_name].field
303+
to_field_name = getattr(field, 'to_field_name', 'pk') or 'pk'
304+
303305
field_name = self.filters[related_name].field_name
304306
lookup_expr = LOOKUP_SEP.join([field_name, 'in'])
305-
subquery = Subquery(related_filterset.qs.values('pk'))
307+
308+
subquery = related_filterset.qs.values(to_field_name)
306309
queryset = queryset.filter(**{lookup_expr: subquery})
307310

308311
return queryset

tests/test_filtering.py

Lines changed: 41 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44
from rest_framework_filters import FilterSet, filters
55

66
from .testapp.filters import (
7-
CFilter, CoverFilter, NoteFilter, NoteFilterWithAlias, NoteFilterWithRelatedAlias,
8-
PageFilter, PersonFilter, PostFilter, UserFilter,
7+
AccountFilter, CFilter, CoverFilter, CustomerFilter, NoteFilter, NoteFilterWithAlias,
8+
NoteFilterWithRelatedAlias, PageFilter, PersonFilter, PostFilter, UserFilter,
9+
)
10+
from .testapp.models import (
11+
A, Account, B, C, Cover, Customer, Note, Page, Person, Post, Tag, User,
912
)
10-
from .testapp.models import A, B, C, Cover, Note, Page, Person, Post, Tag, User
1113

1214

1315
class LocalTagFilter(FilterSet):
@@ -81,54 +83,47 @@ class RelatedFilterTests(TestCase):
8183

8284
@classmethod
8385
def setUpTestData(cls):
84-
#######################
85-
# Create users
86-
#######################
86+
########################################################################
87+
# Create users #########################################################
8788
user1 = User.objects.create(username="user1", email="[email protected]")
8889
user2 = User.objects.create(username="user2", email="[email protected]")
8990

90-
#######################
91-
# Create notes
92-
#######################
91+
########################################################################
92+
# Create notes #########################################################
9393
note1 = Note.objects.create(title="Test 1", content="Test content 1", author=user1)
9494
note2 = Note.objects.create(title="Test 2", content="Test content 2", author=user1)
9595
Note.objects.create(title="Hello Test 3", content="Test content 3", author=user1)
9696
note4 = Note.objects.create(title="Hello Test 4", content="Test content 4", author=user2)
9797

98-
#######################
99-
# Create posts
100-
#######################
98+
########################################################################
99+
# Create posts #########################################################
101100
post1 = Post.objects.create(note=note1, content="Test content in post 1")
102101
Post.objects.create(note=note2, content="Test content in post 2")
103102
post3 = Post.objects.create(note=note4, content="Test content in post 3")
104103

105-
#######################
106-
# Create covers
107-
#######################
104+
########################################################################
105+
# Create covers ########################################################
108106
Cover.objects.create(post=post1, comment="Cover 1")
109107
Cover.objects.create(post=post3, comment="Cover 2")
110108

111-
#######################
112-
# Create pages
113-
#######################
109+
########################################################################
110+
# Create pages #########################################################
114111
Page.objects.create(title="First page", content="First first.")
115112
Page.objects.create(title="Second page", content="Second second.", previous_page_id=1)
116113
Page.objects.create(title="Third page", content="Third third.", previous_page_id=2)
117114
Page.objects.create(title="Fourth page", content="Fourth fourth.", previous_page_id=3)
118115

119-
################################
120-
# ManyToMany
121-
################################
116+
########################################################################
117+
# ManyToMany ###########################################################
122118
t1 = Tag.objects.create(name="park")
123119
Tag.objects.create(name="lake")
124120
t3 = Tag.objects.create(name="house")
125121

126122
post1.tags.set([t1, t3])
127123
post3.tags.set([t3])
128124

129-
################################
130-
# Recursive relations
131-
################################
125+
########################################################################
126+
# Recursive relations ##################################################
132127
a = A.objects.create(title="A1")
133128
b = B.objects.create(name="B1")
134129
c = C.objects.create(title="C1")
@@ -146,6 +141,17 @@ def setUpTestData(cls):
146141
john = Person.objects.create(name="John")
147142
Person.objects.create(name="Mark", best_friend=john)
148143

144+
########################################################################
145+
# to_field relations ###################################################
146+
c1 = Customer.objects.create(name='Bob Jones', ssn='111111111', dob='1990-01-01')
147+
c2 = Customer.objects.create(name='Sue Jones', ssn='222222222', dob='1990-01-01')
148+
149+
Account.objects.create(customer=c1, type='c', name='Bank 1 checking')
150+
Account.objects.create(customer=c1, type='s', name='Bank 1 savings')
151+
Account.objects.create(customer=c2, type='c', name='Bank 1 checking 1')
152+
Account.objects.create(customer=c2, type='c', name='Bank 1 checking 2')
153+
Account.objects.create(customer=c2, type='s', name='Bank 2 savings')
154+
149155
def test_relatedfilter(self):
150156
# Test that the default exact filter works
151157
GET = {'author': User.objects.get(username='user2').pk}
@@ -399,6 +405,17 @@ def test_empty_param_name(self):
399405
f = NoteFilter(GET, queryset=Note.objects.all())
400406
self.assertEqual(len(list(f.qs)), 1)
401407

408+
def test_to_field_forwards_relation(self):
409+
GET = {'customer__name': 'Bob Jones'}
410+
f = AccountFilter(GET)
411+
self.assertEqual(len(list(f.qs)), 2)
412+
413+
def test_to_field_reverse_relation(self):
414+
# Note: pending #99, this query should ideally return 2 distinct results
415+
GET = {'accounts__type': 'c'}
416+
f = CustomerFilter(GET)
417+
self.assertEqual(len(list(f.qs)), 3)
418+
402419

403420
class AnnotationTests(TestCase):
404421
# TODO: these tests should somehow assert that the annotation method is

tests/testapp/filters.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
from rest_framework_filters.filters import AutoFilter, RelatedFilter
66
from rest_framework_filters.filterset import FilterSet
77

8-
from .models import A, B, Blog, C, Cover, Note, Page, Person, Post, Tag, User
8+
from .models import (
9+
A, Account, B, Blog, C, Cover, Customer, Note, Page, Person, Post, Tag, User,
10+
)
911

1012

1113
class DFUserFilter(django_filters.FilterSet):
@@ -101,9 +103,8 @@ class Meta:
101103
fields = []
102104

103105

104-
#############################################################
105-
# Aliased parameter names
106-
#############################################################
106+
################################################################################
107+
# Aliased parameter names ######################################################
107108
class UserFilterWithAlias(FilterSet):
108109
name = filters.CharFilter(field_name='username')
109110

@@ -129,9 +130,8 @@ class Meta:
129130
fields = []
130131

131132

132-
#############################################################
133-
# Recursive filtersets
134-
#############################################################
133+
################################################################################
134+
# Recursive filtersets #########################################################
135135
class AFilter(FilterSet):
136136
title = filters.CharFilter(field_name='title')
137137
b = RelatedFilter('BFilter', field_name='b', queryset=B.objects.all())
@@ -170,3 +170,21 @@ class PersonFilter(FilterSet):
170170
class Meta:
171171
model = Person
172172
fields = []
173+
174+
175+
################################################################################
176+
# `to_field` filtersets ########################################################
177+
class CustomerFilter(FilterSet):
178+
accounts = RelatedFilter('AccountFilter', field_name='account', queryset=Account.objects.all())
179+
180+
class Meta:
181+
model = Customer
182+
fields = ['name', 'ssn', 'dob', 'accounts']
183+
184+
185+
class AccountFilter(FilterSet):
186+
customer = RelatedFilter('CustomerFilter', to_field_name='ssn', queryset=Customer.objects.all())
187+
188+
class Meta:
189+
model = Account
190+
fields = ['customer', 'type', 'name']

tests/testapp/migrations/0001_initial.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Generated by Django 2.2.2 on 2019-06-05 21:58
1+
# Generated by Django 2.2.2 on 2019-08-14 20:04
22

33
from django.conf import settings
44
from django.db import migrations, models
@@ -28,6 +28,15 @@ class Migration(migrations.Migration):
2828
('name', models.CharField(max_length=100)),
2929
],
3030
),
31+
migrations.CreateModel(
32+
name='Customer',
33+
fields=[
34+
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
35+
('name', models.CharField(max_length=80)),
36+
('ssn', models.CharField(max_length=9, unique=True)),
37+
('dob', models.DateField()),
38+
],
39+
),
3140
migrations.CreateModel(
3241
name='Note',
3342
fields=[
@@ -101,6 +110,15 @@ class Migration(migrations.Migration):
101110
('c', models.ForeignKey(null=True, on_delete=django.db.models.deletion.CASCADE, to='testapp.C')),
102111
],
103112
),
113+
migrations.CreateModel(
114+
name='Account',
115+
fields=[
116+
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
117+
('type', models.CharField(choices=[('c', 'Checking'), ('s', 'Savings')], max_length=1)),
118+
('name', models.CharField(max_length=80)),
119+
('customer', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='testapp.Customer', to_field='ssn')),
120+
],
121+
),
104122
migrations.AddField(
105123
model_name='a',
106124
name='b',

tests/testapp/models.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,16 @@ class Person(models.Model):
7373
date_joined = models.DateField(auto_now_add=True)
7474
time_joined = models.TimeField(auto_now_add=True)
7575
datetime_joined = models.DateTimeField(auto_now_add=True)
76+
77+
78+
# Models using `to_field`
79+
class Customer(models.Model):
80+
name = models.CharField(max_length=80)
81+
ssn = models.CharField(max_length=9, unique=True)
82+
dob = models.DateField()
83+
84+
85+
class Account(models.Model):
86+
customer = models.ForeignKey(Customer, to_field='ssn', on_delete=models.CASCADE)
87+
type = models.CharField(max_length=1, choices=[('c', 'Checking'), ('s', 'Savings')])
88+
name = models.CharField(max_length=80)

0 commit comments

Comments
 (0)