Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

django-filters support #64

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ from drf_multiple_model.views import ObjectMultipleModelAPIView
* pagination
* Filtering -- either per queryset or on all querysets
* custom model labeling

* django-filters support
*
For full configuration options, filtering tools, and more, see [the documentation](https://django-rest-multiple-models.readthedocs.org/en/latest/).

# Basic Usage
Expand All @@ -68,11 +69,21 @@ class PlaySerializer(serializers.ModelSerializer):
class Meta:
model = Play
fields = ('genre','title','pages')


class PlayFilter(django_filters.FilterSet):
class Meta:
model = Play
fields = ['genre', 'pages']

class PoemSerializer(serializers.ModelSerializer):
class Meta:
model = Poem
fields = ('title','stanzas')

class PoemFilter(django_filters.FilterSet):
class Meta:
model = Poem
fields = ['style', 'lines', 'stanzas']
```

Then you might use the `ObjectMultipleModelAPIView` as follows:
Expand All @@ -83,8 +94,8 @@ from drf_multiple_model.views import ObjectMultipleModelAPIView

class TextAPIView(ObjectMultipleModelAPIView):
querylist = [
{'queryset': Play.objects.all(), 'serializer_class': PlaySerializer},
{'queryset': Poem.objects.filter(style='Sonnet'), 'serializer_class': PoemSerializer},
{'queryset': Play.objects.all(), 'serializer_class': PlaySerializer, 'filterset_class':PlayFilter},
{'queryset': Poem.objects.filter(style='Sonnet'), 'serializer_class': PoemSerializer,'filterset_class':PoemFilter},
....
]
```
Expand Down
114 changes: 71 additions & 43 deletions drf_multiple_model/mixins.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import warnings
from copy import deepcopy

from django.core.exceptions import ValidationError
from django.db.models.query import QuerySet
Expand All @@ -9,20 +10,20 @@ class BaseMultipleModelMixin(object):
"""
Base class that holds functions need for all MultipleModelMixins/Views
"""

querylist = None

# Keys required for every item in a querylist
required_keys = ['queryset', 'serializer_class']
required_keys = ["queryset", "serializer_class"]

# default pagination state. Gets overridden if pagination is active
is_paginated = False
default_filterset_class = None

def get_querylist(self):
assert self.querylist is not None, (
'{} should either include a `querylist` attribute, '
'or override the `get_querylist()` method.'.format(
self.__class__.__name__
)
"{} should either include a `querylist` attribute, "
"or override the `get_querylist()` method.".format(self.__class__.__name__)
)

return self.querylist
Expand All @@ -36,8 +37,8 @@ def check_query_data(self, query_data):
for key in self.required_keys:
if key not in query_data:
raise ValidationError(
'All items in the {} querylist attribute should contain a '
'`{}` key'.format(self.__class__.__name__, key)
"All items in the {} querylist attribute should contain a "
"`{}` key".format(self.__class__.__name__, key)
)

def load_queryset(self, query_data, request, *args, **kwargs):
Expand All @@ -46,17 +47,17 @@ def load_queryset(self, query_data, request, *args, **kwargs):
built-in rest_framework filters and custom filters passed into
the querylist
"""
queryset = query_data.get('queryset', [])
queryset = query_data.get("queryset", [])
filterset_class = query_data.get("filterset_class", None)

if isinstance(queryset, QuerySet):
# Ensure queryset is re-evaluated on each request.
queryset = queryset.all()

# run rest_framework filters
queryset = self.filter_queryset(queryset)
queryset = self.filter_queryset_custom(queryset, filterset_class)

# run custom filters
filter_fn = query_data.get('filter_fn', None)
filter_fn = query_data.get("filter_fn", None)
if filter_fn is not None:
queryset = filter_fn(queryset, request, *args, **kwargs)

Expand All @@ -65,15 +66,32 @@ def load_queryset(self, query_data, request, *args, **kwargs):

return page if page is not None else queryset

def filter_queryset_custom(self, queryset, filterset_class=None):

old_filterset_class = getattr(self, "filterset_class", None)
for backend in list(self.filter_backends):

try:
from django_filters.rest_framework import DjangoFilterBackend

if issubclass(backend, DjangoFilterBackend):
self.filterset_class = filterset_class
except ImportError:
pass

queryset = backend().filter_queryset(self.request, queryset, self)
self.filterset_class = old_filterset_class
return queryset

def get_empty_results(self):
"""
Because the base result type is different depending on the return structure
(e.g. list for flat, dict for object), `get_result_type` initials the
`results` variable to the proper type
"""
assert self.result_type is not None, (
'{} must specify a `result_type` value or overwrite the '
'`get_empty_result` method.'.format(self.__class__.__name__)
"{} must specify a `result_type` value or overwrite the "
"`get_empty_result` method.".format(self.__class__.__name__)
)

return self.result_type()
Expand All @@ -84,10 +102,8 @@ def add_to_results(self, data, label, results):
data from this queryset/serializer combo
"""
raise NotImplementedError(
'{} must specify how to add data to the running results tally '
'by overriding the `add_to_results` method.'.format(
self.__class__.__name__
)
"{} must specify how to add data to the running results tally "
"by overriding the `add_to_results` method.".format(self.__class__.__name__)
)

def format_results(self, results, request):
Expand All @@ -109,7 +125,9 @@ def list(self, request, *args, **kwargs):

# Run the paired serializer
context = self.get_serializer_context()
data = query_data['serializer_class'](queryset, many=True, context=context).data
data = query_data["serializer_class"](
queryset, many=True, context=context
).data

label = self.get_label(queryset, query_data)

Expand Down Expand Up @@ -156,6 +174,7 @@ class FlatMultipleModelMixin(BaseMultipleModelMixin):
...
]
"""

# Optional keyword to sort flat lasts by given attribute
# note that the attribute must by shared by ALL models
sorting_field = None
Expand All @@ -166,14 +185,16 @@ class FlatMultipleModelMixin(BaseMultipleModelMixin):
# Django-like model lookups are supported via '__', but you have to be sure that all querysets will return results
# with corresponding structure.
sorting_fields_map = {}
sorting_parameter_name = 'o'
sorting_parameter_name = "o"

# Flag to append the particular django model being used to the data
add_model_type = True

result_type = list

_list_attribute_error = 'Invalid sorting field. Corresponding data item is a list: {}'
_list_attribute_error = (
"Invalid sorting field. Corresponding data item is a list: {}"
)

def initial(self, request, *args, **kwargs):
"""
Expand All @@ -182,13 +203,15 @@ def initial(self, request, *args, **kwargs):
after original `initial` has been ran in order to make sure that view has all its properties set up.
"""
super(FlatMultipleModelMixin, self).initial(request, *args, **kwargs)
assert not (self.sorting_field and self.sorting_fields), \
'{} should either define ``sorting_field`` or ``sorting_fields`` property, not both.' \
.format(self.__class__.__name__)
assert not (
self.sorting_field and self.sorting_fields
), "{} should either define ``sorting_field`` or ``sorting_fields`` property, not both.".format(
self.__class__.__name__
)
if self.sorting_field:
warnings.warn(
'``sorting_field`` property is pending its deprecation. Use ``sorting_fields`` instead.',
DeprecationWarning
"``sorting_field`` property is pending its deprecation. Use ``sorting_fields`` instead.",
DeprecationWarning,
)
self.sorting_fields = [self.sorting_field]
self._sorting_fields = self.sorting_fields
Expand All @@ -198,13 +221,13 @@ def get_label(self, queryset, query_data):
Gets option label for each datum. Can be used for type identification
of individual serialized objects
"""
if query_data.get('label', False):
return query_data['label']
if query_data.get("label", False):
return query_data["label"]
elif self.add_model_type:
try:
return queryset.model.__name__
except AttributeError:
return query_data['queryset'].model.__name__
return query_data["queryset"].model.__name__

def add_to_results(self, data, label, results):
"""
Expand All @@ -213,7 +236,7 @@ def add_to_results(self, data, label, results):
"""
for datum in data:
if label is not None:
datum.update({'type': label})
datum.update({"type": label})

results.append(datum)

Expand All @@ -227,9 +250,9 @@ def format_results(self, results, request):
if self._sorting_fields:
results = self.sort_results(results)

if request.accepted_renderer.format == 'html':
if request.accepted_renderer.format == "html":
# Makes the the results available to the template context by transforming to a dict
results = {'data': results}
results = {"data": results}

return results

Expand All @@ -240,8 +263,8 @@ def _sort_by(self, datum, param, path=None):
if not path:
path = []
try:
if '__' in param:
root, new_param = param.split('__')
if "__" in param:
root, new_param = param.split("__")
path.append(root)
return self._sort_by(datum[root], param=new_param, path=path)
else:
Expand All @@ -252,9 +275,9 @@ def _sort_by(self, datum, param, path=None):
raise ValidationError(self._list_attribute_error.format(param))
return data
except TypeError:
raise ValidationError(self._list_attribute_error.format('.'.join(path)))
raise ValidationError(self._list_attribute_error.format(".".join(path)))
except KeyError:
raise ValidationError('Invalid sorting field: {}'.format('.'.join(path)))
raise ValidationError("Invalid sorting field: {}".format(".".join(path)))

def prepare_sorting_fields(self):
"""
Expand All @@ -264,22 +287,26 @@ def prepare_sorting_fields(self):
if self.sorting_parameter_name in self.request.query_params:
# Extract sorting parameter from query string
self._sorting_fields = [
_.strip() for _ in self.request.query_params.get(self.sorting_parameter_name).split(',')
_.strip()
for _ in self.request.query_params.get(
self.sorting_parameter_name
).split(",")
]

if self._sorting_fields:
# Create a list of sorting parameters. Each parameter is a tuple: (field:str, descending:bool)
self._sorting_fields = [
(self.sorting_fields_map.get(field.lstrip('-'), field.lstrip('-')), field[0] == '-')
(
self.sorting_fields_map.get(field.lstrip("-"), field.lstrip("-")),
field[0] == "-",
)
for field in self._sorting_fields
]

def sort_results(self, results):
for field, descending in reversed(self._sorting_fields):
results = sorted(
results,
reverse=descending,
key=lambda x: self._sort_by(x, field)
results, reverse=descending, key=lambda x: self._sort_by(x, field)
)
return results

Expand Down Expand Up @@ -314,6 +341,7 @@ class ObjectMultipleModelMixin(BaseMultipleModelMixin):
...
}
"""

result_type = dict

def add_to_results(self, data, label, results):
Expand All @@ -326,10 +354,10 @@ def get_label(self, queryset, query_data):
Gets option label for each datum. Can be used for type identification
of individual serialized objects
"""
if query_data.get('label', False):
return query_data['label']
if query_data.get("label", False):
return query_data["label"]

try:
return queryset.model.__name__
except AttributeError:
return query_data['queryset'].model.__name__
return query_data["queryset"].model.__name__