always get queryset on manager (for filtering)

This commit is contained in:
Laura Klünder 2017-06-18 00:45:06 +02:00
parent fb9479783f
commit df880ca43b

View file

@ -1,7 +1,7 @@
import operator
import typing
from collections import OrderedDict
from functools import reduce
from functools import reduce, wraps
from itertools import chain
from django.db import models
@ -210,6 +210,24 @@ class ModelInstanceWrapper(BaseWrapper):
self._changeset.add_delete(self, author=author)
def get_queryset(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
if hasattr(self, 'get_queryset'):
return getattr(self.get_queryset(), func.__name__)(*args, **kwargs)
return func(self, *args, **kwargs)
return wrapper
def queryset_only(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
if hasattr(self, 'get_queryset'):
raise TypeError
return func(self, *args, **kwargs)
return wrapper
class BaseQueryWrapper(BaseWrapper):
_allowed_callables = ('_add_hints', 'get_prefetch_queryset', '_apply_rel_filters')
@ -220,9 +238,6 @@ class BaseQueryWrapper(BaseWrapper):
self._created_pks = created_pks
self._extra = extra
def get_queryset(self):
return self._obj
def _wrap_instance(self, instance):
return super()._wrap_instance(instance)
@ -233,15 +248,19 @@ class BaseQueryWrapper(BaseWrapper):
created_pks = None
return QuerySetWrapper(self._changeset, queryset, self._author, created_pks, self._extra+add_extra)
@get_queryset
def all(self):
return self._wrap_queryset(self.get_queryset().all())
return self._wrap_queryset(self._obj.all())
@get_queryset
def none(self):
return self._wrap_queryset(self.get_queryset().none(), ())
return self._wrap_queryset(self._obj.none(), ())
@get_queryset
def select_related(self, *args, **kwargs):
return self._wrap_queryset(self.get_queryset().select_related(*args, **kwargs))
return self._wrap_queryset(self._obj.select_related(*args, **kwargs))
@get_queryset
def prefetch_related(self, *lookups):
lookups_splitted = tuple(tuple(lookup.split('__')) for lookup in lookups)
max_depth = max(len(lookup) for lookup in lookups_splitted)
@ -250,7 +269,7 @@ class BaseQueryWrapper(BaseWrapper):
lookups_by_depth.append(set(tuple(lookup[:i+1] for lookup in lookups_splitted if len(lookup) > i)))
lookup_models = {(): self._obj.model}
lookup_querysets = {(): self.get_queryset()}
lookup_querysets = {(): self._obj}
for depth_lookups in lookups_by_depth:
for lookup in depth_lookups:
model = lookup_models[lookup[:-1]]._meta.get_field(lookup[-1]).related_model
@ -266,10 +285,11 @@ class BaseQueryWrapper(BaseWrapper):
return self._wrap_queryset(lookup_querysets[()])
def _clone(self, **kwargs):
clone = self._wrap_queryset(self.get_queryset())
clone = self._wrap_queryset(self._obj)
clone._obj.__dict__.update(kwargs)
return clone
@get_queryset
def get(self, *args, **kwargs):
results = tuple(self.filter(*args, **kwargs))
if len(results) == 1:
@ -278,8 +298,9 @@ class BaseQueryWrapper(BaseWrapper):
raise self._obj.model.MultipleObjectsReturned
raise self._obj.model.DoesNotExist
@get_queryset
def order_by(self, *args):
return self._wrap_queryset(self.get_queryset().order_by(*args))
return self._wrap_queryset(self._obj.order_by(*args))
def _filter_values(self, q, field_name, check):
other_values = self._changeset.get_changed_values(self._obj.model, field_name)
@ -478,51 +499,61 @@ class BaseQueryWrapper(BaseWrapper):
created_pks = reduce(operator.and_, created_pks)
if negate:
filters = (~Q(*filters), )
created_pks = self._changeset.get_created_pks(self._obj.model) - created_pks
return self._wrap_queryset(self.get_queryset().filter(*filters), created_pks=(self._created_pks & created_pks))
return self._wrap_queryset(self._obj.filter(*filters), created_pks=(self._created_pks & created_pks))
@get_queryset
def filter(self, *args, **kwargs):
return self._filter_or_exclude(False, *args, **kwargs)
@get_queryset
def exclude(self, *args, **kwargs):
return self._filter_or_exclude(True, *args, **kwargs)
@get_queryset
def count(self):
return self.get_queryset().count()+len(tuple(self._get_created_objects()))
return self._obj.count()+len(tuple(self._get_created_objects()))
@get_queryset
def values_list(self, *args, flat=False):
own_values = (tuple(getattr(obj, arg) for arg in args) for obj in self._get_created_objects())
if flat:
own_values = (v[0] for v in own_values)
return chain(
self.get_queryset().values_list(*args, flat=flat),
self._obj.values_list(*args, flat=flat),
own_values,
)
@get_queryset
def first(self):
first = self.get_queryset().first()
first = self._obj.first()
if first is not None:
first = self._wrap_instance(first)
return first
@get_queryset
def using(self, alias):
return self._wrap_queryset(self.get_queryset().using(alias))
return self._wrap_queryset(self._obj.using(alias))
@get_queryset
def extra(self, select):
for key in select.keys():
if not key.startswith('_prefetch_related_val'):
raise NotImplementedError('extra() calls are only supported for prefetch_related!')
return self._wrap_queryset(self.get_queryset().extra(select), add_extra=tuple(select.keys()))
return self._wrap_queryset(self._obj.extra(select), add_extra=tuple(select.keys()))
@get_queryset
def _next_is_sticky(self):
return self._wrap_queryset(self.get_queryset()._next_is_sticky())
return self._wrap_queryset(self._obj._next_is_sticky())
def _get_created_objects(self):
return (self._changeset.get_created_object(self._obj.model, pk, get_foreign_objects=True)
for pk in sorted(self._created_pks))
@queryset_only
def _get_cached_result(self):
obj = self.get_queryset()
obj = self._obj
obj._prefetch_done = True
obj._fetch_all()
@ -574,15 +605,18 @@ class BaseQueryWrapper(BaseWrapper):
def _results_cache(self):
return self._cached_result
@queryset_only
def __iter__(self):
return iter(self._cached_result)
@queryset_only
def iterator(self):
return iter(chain(
(self._wrap_instance(instance) for instance in self.get_queryset().iterator()),
(self._wrap_instance(instance) for instance in self._obj.iterator()),
self._get_created_objects(),
))
@queryset_only
def __len__(self):
return len(self._cached_result)
@ -594,7 +628,8 @@ class BaseQueryWrapper(BaseWrapper):
class ManagerWrapper(BaseQueryWrapper):
def get_queryset(self):
return super().get_queryset().exclude(pk__in=self._changeset.deleted_existing.get(self._obj.model, ()))
qs = self._wrap_queryset(self._obj.model.objects.all())
return qs.exclude(pk__in=self._changeset.deleted_existing.get(self._obj.model, ()))
class RelatedManagerWrapper(ManagerWrapper):