always get queryset on manager (for filtering)

This commit is contained in:
Laura Klünder 2017-06-18 00:45:06 +02:00
parent fb9479783f
commit df880ca43b

View file

@ -1,7 +1,7 @@
import operator import operator
import typing import typing
from collections import OrderedDict from collections import OrderedDict
from functools import reduce from functools import reduce, wraps
from itertools import chain from itertools import chain
from django.db import models from django.db import models
@ -210,6 +210,24 @@ class ModelInstanceWrapper(BaseWrapper):
self._changeset.add_delete(self, author=author) self._changeset.add_delete(self, author=author)
def get_queryset(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
if hasattr(self, 'get_queryset'):
return getattr(self.get_queryset(), func.__name__)(*args, **kwargs)
return func(self, *args, **kwargs)
return wrapper
def queryset_only(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
if hasattr(self, 'get_queryset'):
raise TypeError
return func(self, *args, **kwargs)
return wrapper
class BaseQueryWrapper(BaseWrapper): class BaseQueryWrapper(BaseWrapper):
_allowed_callables = ('_add_hints', 'get_prefetch_queryset', '_apply_rel_filters') _allowed_callables = ('_add_hints', 'get_prefetch_queryset', '_apply_rel_filters')
@ -220,9 +238,6 @@ class BaseQueryWrapper(BaseWrapper):
self._created_pks = created_pks self._created_pks = created_pks
self._extra = extra self._extra = extra
def get_queryset(self):
return self._obj
def _wrap_instance(self, instance): def _wrap_instance(self, instance):
return super()._wrap_instance(instance) return super()._wrap_instance(instance)
@ -233,15 +248,19 @@ class BaseQueryWrapper(BaseWrapper):
created_pks = None created_pks = None
return QuerySetWrapper(self._changeset, queryset, self._author, created_pks, self._extra+add_extra) return QuerySetWrapper(self._changeset, queryset, self._author, created_pks, self._extra+add_extra)
@get_queryset
def all(self): def all(self):
return self._wrap_queryset(self.get_queryset().all()) return self._wrap_queryset(self._obj.all())
@get_queryset
def none(self): def none(self):
return self._wrap_queryset(self.get_queryset().none(), ()) return self._wrap_queryset(self._obj.none(), ())
@get_queryset
def select_related(self, *args, **kwargs): def select_related(self, *args, **kwargs):
return self._wrap_queryset(self.get_queryset().select_related(*args, **kwargs)) return self._wrap_queryset(self._obj.select_related(*args, **kwargs))
@get_queryset
def prefetch_related(self, *lookups): def prefetch_related(self, *lookups):
lookups_splitted = tuple(tuple(lookup.split('__')) for lookup in lookups) lookups_splitted = tuple(tuple(lookup.split('__')) for lookup in lookups)
max_depth = max(len(lookup) for lookup in lookups_splitted) max_depth = max(len(lookup) for lookup in lookups_splitted)
@ -250,7 +269,7 @@ class BaseQueryWrapper(BaseWrapper):
lookups_by_depth.append(set(tuple(lookup[:i+1] for lookup in lookups_splitted if len(lookup) > i))) lookups_by_depth.append(set(tuple(lookup[:i+1] for lookup in lookups_splitted if len(lookup) > i)))
lookup_models = {(): self._obj.model} lookup_models = {(): self._obj.model}
lookup_querysets = {(): self.get_queryset()} lookup_querysets = {(): self._obj}
for depth_lookups in lookups_by_depth: for depth_lookups in lookups_by_depth:
for lookup in depth_lookups: for lookup in depth_lookups:
model = lookup_models[lookup[:-1]]._meta.get_field(lookup[-1]).related_model model = lookup_models[lookup[:-1]]._meta.get_field(lookup[-1]).related_model
@ -266,10 +285,11 @@ class BaseQueryWrapper(BaseWrapper):
return self._wrap_queryset(lookup_querysets[()]) return self._wrap_queryset(lookup_querysets[()])
def _clone(self, **kwargs): def _clone(self, **kwargs):
clone = self._wrap_queryset(self.get_queryset()) clone = self._wrap_queryset(self._obj)
clone._obj.__dict__.update(kwargs) clone._obj.__dict__.update(kwargs)
return clone return clone
@get_queryset
def get(self, *args, **kwargs): def get(self, *args, **kwargs):
results = tuple(self.filter(*args, **kwargs)) results = tuple(self.filter(*args, **kwargs))
if len(results) == 1: if len(results) == 1:
@ -278,8 +298,9 @@ class BaseQueryWrapper(BaseWrapper):
raise self._obj.model.MultipleObjectsReturned raise self._obj.model.MultipleObjectsReturned
raise self._obj.model.DoesNotExist raise self._obj.model.DoesNotExist
@get_queryset
def order_by(self, *args): def order_by(self, *args):
return self._wrap_queryset(self.get_queryset().order_by(*args)) return self._wrap_queryset(self._obj.order_by(*args))
def _filter_values(self, q, field_name, check): def _filter_values(self, q, field_name, check):
other_values = self._changeset.get_changed_values(self._obj.model, field_name) other_values = self._changeset.get_changed_values(self._obj.model, field_name)
@ -478,51 +499,61 @@ class BaseQueryWrapper(BaseWrapper):
created_pks = reduce(operator.and_, created_pks) created_pks = reduce(operator.and_, created_pks)
if negate: if negate:
filters = (~Q(*filters), )
created_pks = self._changeset.get_created_pks(self._obj.model) - created_pks created_pks = self._changeset.get_created_pks(self._obj.model) - created_pks
return self._wrap_queryset(self.get_queryset().filter(*filters), created_pks=(self._created_pks & created_pks)) return self._wrap_queryset(self._obj.filter(*filters), created_pks=(self._created_pks & created_pks))
@get_queryset
def filter(self, *args, **kwargs): def filter(self, *args, **kwargs):
return self._filter_or_exclude(False, *args, **kwargs) return self._filter_or_exclude(False, *args, **kwargs)
@get_queryset
def exclude(self, *args, **kwargs): def exclude(self, *args, **kwargs):
return self._filter_or_exclude(True, *args, **kwargs) return self._filter_or_exclude(True, *args, **kwargs)
@get_queryset
def count(self): def count(self):
return self.get_queryset().count()+len(tuple(self._get_created_objects())) return self._obj.count()+len(tuple(self._get_created_objects()))
@get_queryset
def values_list(self, *args, flat=False): def values_list(self, *args, flat=False):
own_values = (tuple(getattr(obj, arg) for arg in args) for obj in self._get_created_objects()) own_values = (tuple(getattr(obj, arg) for arg in args) for obj in self._get_created_objects())
if flat: if flat:
own_values = (v[0] for v in own_values) own_values = (v[0] for v in own_values)
return chain( return chain(
self.get_queryset().values_list(*args, flat=flat), self._obj.values_list(*args, flat=flat),
own_values, own_values,
) )
@get_queryset
def first(self): def first(self):
first = self.get_queryset().first() first = self._obj.first()
if first is not None: if first is not None:
first = self._wrap_instance(first) first = self._wrap_instance(first)
return first return first
@get_queryset
def using(self, alias): def using(self, alias):
return self._wrap_queryset(self.get_queryset().using(alias)) return self._wrap_queryset(self._obj.using(alias))
@get_queryset
def extra(self, select): def extra(self, select):
for key in select.keys(): for key in select.keys():
if not key.startswith('_prefetch_related_val'): if not key.startswith('_prefetch_related_val'):
raise NotImplementedError('extra() calls are only supported for prefetch_related!') raise NotImplementedError('extra() calls are only supported for prefetch_related!')
return self._wrap_queryset(self.get_queryset().extra(select), add_extra=tuple(select.keys())) return self._wrap_queryset(self._obj.extra(select), add_extra=tuple(select.keys()))
@get_queryset
def _next_is_sticky(self): def _next_is_sticky(self):
return self._wrap_queryset(self.get_queryset()._next_is_sticky()) return self._wrap_queryset(self._obj._next_is_sticky())
def _get_created_objects(self): def _get_created_objects(self):
return (self._changeset.get_created_object(self._obj.model, pk, get_foreign_objects=True) return (self._changeset.get_created_object(self._obj.model, pk, get_foreign_objects=True)
for pk in sorted(self._created_pks)) for pk in sorted(self._created_pks))
@queryset_only
def _get_cached_result(self): def _get_cached_result(self):
obj = self.get_queryset() obj = self._obj
obj._prefetch_done = True obj._prefetch_done = True
obj._fetch_all() obj._fetch_all()
@ -574,15 +605,18 @@ class BaseQueryWrapper(BaseWrapper):
def _results_cache(self): def _results_cache(self):
return self._cached_result return self._cached_result
@queryset_only
def __iter__(self): def __iter__(self):
return iter(self._cached_result) return iter(self._cached_result)
@queryset_only
def iterator(self): def iterator(self):
return iter(chain( return iter(chain(
(self._wrap_instance(instance) for instance in self.get_queryset().iterator()), (self._wrap_instance(instance) for instance in self._obj.iterator()),
self._get_created_objects(), self._get_created_objects(),
)) ))
@queryset_only
def __len__(self): def __len__(self):
return len(self._cached_result) return len(self._cached_result)
@ -594,7 +628,8 @@ class BaseQueryWrapper(BaseWrapper):
class ManagerWrapper(BaseQueryWrapper): class ManagerWrapper(BaseQueryWrapper):
def get_queryset(self): def get_queryset(self):
return super().get_queryset().exclude(pk__in=self._changeset.deleted_existing.get(self._obj.model, ())) qs = self._wrap_queryset(self._obj.model.objects.all())
return qs.exclude(pk__in=self._changeset.deleted_existing.get(self._obj.model, ()))
class RelatedManagerWrapper(ManagerWrapper): class RelatedManagerWrapper(ManagerWrapper):