always get queryset on manager (for filtering)
This commit is contained in:
parent
fb9479783f
commit
df880ca43b
1 changed files with 55 additions and 20 deletions
|
@ -1,7 +1,7 @@
|
|||
import operator
|
||||
import typing
|
||||
from collections import OrderedDict
|
||||
from functools import reduce
|
||||
from functools import reduce, wraps
|
||||
from itertools import chain
|
||||
|
||||
from django.db import models
|
||||
|
@ -210,6 +210,24 @@ class ModelInstanceWrapper(BaseWrapper):
|
|||
self._changeset.add_delete(self, author=author)
|
||||
|
||||
|
||||
def get_queryset(func):
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if hasattr(self, 'get_queryset'):
|
||||
return getattr(self.get_queryset(), func.__name__)(*args, **kwargs)
|
||||
return func(self, *args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
|
||||
def queryset_only(func):
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if hasattr(self, 'get_queryset'):
|
||||
raise TypeError
|
||||
return func(self, *args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
|
||||
class BaseQueryWrapper(BaseWrapper):
|
||||
_allowed_callables = ('_add_hints', 'get_prefetch_queryset', '_apply_rel_filters')
|
||||
|
||||
|
@ -220,9 +238,6 @@ class BaseQueryWrapper(BaseWrapper):
|
|||
self._created_pks = created_pks
|
||||
self._extra = extra
|
||||
|
||||
def get_queryset(self):
|
||||
return self._obj
|
||||
|
||||
def _wrap_instance(self, instance):
|
||||
return super()._wrap_instance(instance)
|
||||
|
||||
|
@ -233,15 +248,19 @@ class BaseQueryWrapper(BaseWrapper):
|
|||
created_pks = None
|
||||
return QuerySetWrapper(self._changeset, queryset, self._author, created_pks, self._extra+add_extra)
|
||||
|
||||
@get_queryset
|
||||
def all(self):
|
||||
return self._wrap_queryset(self.get_queryset().all())
|
||||
return self._wrap_queryset(self._obj.all())
|
||||
|
||||
@get_queryset
|
||||
def none(self):
|
||||
return self._wrap_queryset(self.get_queryset().none(), ())
|
||||
return self._wrap_queryset(self._obj.none(), ())
|
||||
|
||||
@get_queryset
|
||||
def select_related(self, *args, **kwargs):
|
||||
return self._wrap_queryset(self.get_queryset().select_related(*args, **kwargs))
|
||||
return self._wrap_queryset(self._obj.select_related(*args, **kwargs))
|
||||
|
||||
@get_queryset
|
||||
def prefetch_related(self, *lookups):
|
||||
lookups_splitted = tuple(tuple(lookup.split('__')) for lookup in lookups)
|
||||
max_depth = max(len(lookup) for lookup in lookups_splitted)
|
||||
|
@ -250,7 +269,7 @@ class BaseQueryWrapper(BaseWrapper):
|
|||
lookups_by_depth.append(set(tuple(lookup[:i+1] for lookup in lookups_splitted if len(lookup) > i)))
|
||||
|
||||
lookup_models = {(): self._obj.model}
|
||||
lookup_querysets = {(): self.get_queryset()}
|
||||
lookup_querysets = {(): self._obj}
|
||||
for depth_lookups in lookups_by_depth:
|
||||
for lookup in depth_lookups:
|
||||
model = lookup_models[lookup[:-1]]._meta.get_field(lookup[-1]).related_model
|
||||
|
@ -266,10 +285,11 @@ class BaseQueryWrapper(BaseWrapper):
|
|||
return self._wrap_queryset(lookup_querysets[()])
|
||||
|
||||
def _clone(self, **kwargs):
|
||||
clone = self._wrap_queryset(self.get_queryset())
|
||||
clone = self._wrap_queryset(self._obj)
|
||||
clone._obj.__dict__.update(kwargs)
|
||||
return clone
|
||||
|
||||
@get_queryset
|
||||
def get(self, *args, **kwargs):
|
||||
results = tuple(self.filter(*args, **kwargs))
|
||||
if len(results) == 1:
|
||||
|
@ -278,8 +298,9 @@ class BaseQueryWrapper(BaseWrapper):
|
|||
raise self._obj.model.MultipleObjectsReturned
|
||||
raise self._obj.model.DoesNotExist
|
||||
|
||||
@get_queryset
|
||||
def order_by(self, *args):
|
||||
return self._wrap_queryset(self.get_queryset().order_by(*args))
|
||||
return self._wrap_queryset(self._obj.order_by(*args))
|
||||
|
||||
def _filter_values(self, q, field_name, check):
|
||||
other_values = self._changeset.get_changed_values(self._obj.model, field_name)
|
||||
|
@ -478,51 +499,61 @@ class BaseQueryWrapper(BaseWrapper):
|
|||
|
||||
created_pks = reduce(operator.and_, created_pks)
|
||||
if negate:
|
||||
filters = (~Q(*filters), )
|
||||
created_pks = self._changeset.get_created_pks(self._obj.model) - created_pks
|
||||
return self._wrap_queryset(self.get_queryset().filter(*filters), created_pks=(self._created_pks & created_pks))
|
||||
return self._wrap_queryset(self._obj.filter(*filters), created_pks=(self._created_pks & created_pks))
|
||||
|
||||
@get_queryset
|
||||
def filter(self, *args, **kwargs):
|
||||
return self._filter_or_exclude(False, *args, **kwargs)
|
||||
|
||||
@get_queryset
|
||||
def exclude(self, *args, **kwargs):
|
||||
return self._filter_or_exclude(True, *args, **kwargs)
|
||||
|
||||
@get_queryset
|
||||
def count(self):
|
||||
return self.get_queryset().count()+len(tuple(self._get_created_objects()))
|
||||
return self._obj.count()+len(tuple(self._get_created_objects()))
|
||||
|
||||
@get_queryset
|
||||
def values_list(self, *args, flat=False):
|
||||
own_values = (tuple(getattr(obj, arg) for arg in args) for obj in self._get_created_objects())
|
||||
if flat:
|
||||
own_values = (v[0] for v in own_values)
|
||||
return chain(
|
||||
self.get_queryset().values_list(*args, flat=flat),
|
||||
self._obj.values_list(*args, flat=flat),
|
||||
own_values,
|
||||
)
|
||||
|
||||
@get_queryset
|
||||
def first(self):
|
||||
first = self.get_queryset().first()
|
||||
first = self._obj.first()
|
||||
if first is not None:
|
||||
first = self._wrap_instance(first)
|
||||
return first
|
||||
|
||||
@get_queryset
|
||||
def using(self, alias):
|
||||
return self._wrap_queryset(self.get_queryset().using(alias))
|
||||
return self._wrap_queryset(self._obj.using(alias))
|
||||
|
||||
@get_queryset
|
||||
def extra(self, select):
|
||||
for key in select.keys():
|
||||
if not key.startswith('_prefetch_related_val'):
|
||||
raise NotImplementedError('extra() calls are only supported for prefetch_related!')
|
||||
return self._wrap_queryset(self.get_queryset().extra(select), add_extra=tuple(select.keys()))
|
||||
return self._wrap_queryset(self._obj.extra(select), add_extra=tuple(select.keys()))
|
||||
|
||||
@get_queryset
|
||||
def _next_is_sticky(self):
|
||||
return self._wrap_queryset(self.get_queryset()._next_is_sticky())
|
||||
return self._wrap_queryset(self._obj._next_is_sticky())
|
||||
|
||||
def _get_created_objects(self):
|
||||
return (self._changeset.get_created_object(self._obj.model, pk, get_foreign_objects=True)
|
||||
for pk in sorted(self._created_pks))
|
||||
|
||||
@queryset_only
|
||||
def _get_cached_result(self):
|
||||
obj = self.get_queryset()
|
||||
obj = self._obj
|
||||
obj._prefetch_done = True
|
||||
obj._fetch_all()
|
||||
|
||||
|
@ -574,15 +605,18 @@ class BaseQueryWrapper(BaseWrapper):
|
|||
def _results_cache(self):
|
||||
return self._cached_result
|
||||
|
||||
@queryset_only
|
||||
def __iter__(self):
|
||||
return iter(self._cached_result)
|
||||
|
||||
@queryset_only
|
||||
def iterator(self):
|
||||
return iter(chain(
|
||||
(self._wrap_instance(instance) for instance in self.get_queryset().iterator()),
|
||||
(self._wrap_instance(instance) for instance in self._obj.iterator()),
|
||||
self._get_created_objects(),
|
||||
))
|
||||
|
||||
@queryset_only
|
||||
def __len__(self):
|
||||
return len(self._cached_result)
|
||||
|
||||
|
@ -594,7 +628,8 @@ class BaseQueryWrapper(BaseWrapper):
|
|||
|
||||
class ManagerWrapper(BaseQueryWrapper):
|
||||
def get_queryset(self):
|
||||
return super().get_queryset().exclude(pk__in=self._changeset.deleted_existing.get(self._obj.model, ()))
|
||||
qs = self._wrap_queryset(self._obj.model.objects.all())
|
||||
return qs.exclude(pk__in=self._changeset.deleted_existing.get(self._obj.model, ()))
|
||||
|
||||
|
||||
class RelatedManagerWrapper(ManagerWrapper):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue