always get queryset on manager (for filtering)
This commit is contained in:
parent
fb9479783f
commit
df880ca43b
1 changed files with 55 additions and 20 deletions
|
@ -1,7 +1,7 @@
|
||||||
import operator
|
import operator
|
||||||
import typing
|
import typing
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from functools import reduce
|
from functools import reduce, wraps
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
|
|
||||||
from django.db import models
|
from django.db import models
|
||||||
|
@ -210,6 +210,24 @@ class ModelInstanceWrapper(BaseWrapper):
|
||||||
self._changeset.add_delete(self, author=author)
|
self._changeset.add_delete(self, author=author)
|
||||||
|
|
||||||
|
|
||||||
|
def get_queryset(func):
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(self, *args, **kwargs):
|
||||||
|
if hasattr(self, 'get_queryset'):
|
||||||
|
return getattr(self.get_queryset(), func.__name__)(*args, **kwargs)
|
||||||
|
return func(self, *args, **kwargs)
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def queryset_only(func):
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(self, *args, **kwargs):
|
||||||
|
if hasattr(self, 'get_queryset'):
|
||||||
|
raise TypeError
|
||||||
|
return func(self, *args, **kwargs)
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
class BaseQueryWrapper(BaseWrapper):
|
class BaseQueryWrapper(BaseWrapper):
|
||||||
_allowed_callables = ('_add_hints', 'get_prefetch_queryset', '_apply_rel_filters')
|
_allowed_callables = ('_add_hints', 'get_prefetch_queryset', '_apply_rel_filters')
|
||||||
|
|
||||||
|
@ -220,9 +238,6 @@ class BaseQueryWrapper(BaseWrapper):
|
||||||
self._created_pks = created_pks
|
self._created_pks = created_pks
|
||||||
self._extra = extra
|
self._extra = extra
|
||||||
|
|
||||||
def get_queryset(self):
|
|
||||||
return self._obj
|
|
||||||
|
|
||||||
def _wrap_instance(self, instance):
|
def _wrap_instance(self, instance):
|
||||||
return super()._wrap_instance(instance)
|
return super()._wrap_instance(instance)
|
||||||
|
|
||||||
|
@ -233,15 +248,19 @@ class BaseQueryWrapper(BaseWrapper):
|
||||||
created_pks = None
|
created_pks = None
|
||||||
return QuerySetWrapper(self._changeset, queryset, self._author, created_pks, self._extra+add_extra)
|
return QuerySetWrapper(self._changeset, queryset, self._author, created_pks, self._extra+add_extra)
|
||||||
|
|
||||||
|
@get_queryset
|
||||||
def all(self):
|
def all(self):
|
||||||
return self._wrap_queryset(self.get_queryset().all())
|
return self._wrap_queryset(self._obj.all())
|
||||||
|
|
||||||
|
@get_queryset
|
||||||
def none(self):
|
def none(self):
|
||||||
return self._wrap_queryset(self.get_queryset().none(), ())
|
return self._wrap_queryset(self._obj.none(), ())
|
||||||
|
|
||||||
|
@get_queryset
|
||||||
def select_related(self, *args, **kwargs):
|
def select_related(self, *args, **kwargs):
|
||||||
return self._wrap_queryset(self.get_queryset().select_related(*args, **kwargs))
|
return self._wrap_queryset(self._obj.select_related(*args, **kwargs))
|
||||||
|
|
||||||
|
@get_queryset
|
||||||
def prefetch_related(self, *lookups):
|
def prefetch_related(self, *lookups):
|
||||||
lookups_splitted = tuple(tuple(lookup.split('__')) for lookup in lookups)
|
lookups_splitted = tuple(tuple(lookup.split('__')) for lookup in lookups)
|
||||||
max_depth = max(len(lookup) for lookup in lookups_splitted)
|
max_depth = max(len(lookup) for lookup in lookups_splitted)
|
||||||
|
@ -250,7 +269,7 @@ class BaseQueryWrapper(BaseWrapper):
|
||||||
lookups_by_depth.append(set(tuple(lookup[:i+1] for lookup in lookups_splitted if len(lookup) > i)))
|
lookups_by_depth.append(set(tuple(lookup[:i+1] for lookup in lookups_splitted if len(lookup) > i)))
|
||||||
|
|
||||||
lookup_models = {(): self._obj.model}
|
lookup_models = {(): self._obj.model}
|
||||||
lookup_querysets = {(): self.get_queryset()}
|
lookup_querysets = {(): self._obj}
|
||||||
for depth_lookups in lookups_by_depth:
|
for depth_lookups in lookups_by_depth:
|
||||||
for lookup in depth_lookups:
|
for lookup in depth_lookups:
|
||||||
model = lookup_models[lookup[:-1]]._meta.get_field(lookup[-1]).related_model
|
model = lookup_models[lookup[:-1]]._meta.get_field(lookup[-1]).related_model
|
||||||
|
@ -266,10 +285,11 @@ class BaseQueryWrapper(BaseWrapper):
|
||||||
return self._wrap_queryset(lookup_querysets[()])
|
return self._wrap_queryset(lookup_querysets[()])
|
||||||
|
|
||||||
def _clone(self, **kwargs):
|
def _clone(self, **kwargs):
|
||||||
clone = self._wrap_queryset(self.get_queryset())
|
clone = self._wrap_queryset(self._obj)
|
||||||
clone._obj.__dict__.update(kwargs)
|
clone._obj.__dict__.update(kwargs)
|
||||||
return clone
|
return clone
|
||||||
|
|
||||||
|
@get_queryset
|
||||||
def get(self, *args, **kwargs):
|
def get(self, *args, **kwargs):
|
||||||
results = tuple(self.filter(*args, **kwargs))
|
results = tuple(self.filter(*args, **kwargs))
|
||||||
if len(results) == 1:
|
if len(results) == 1:
|
||||||
|
@ -278,8 +298,9 @@ class BaseQueryWrapper(BaseWrapper):
|
||||||
raise self._obj.model.MultipleObjectsReturned
|
raise self._obj.model.MultipleObjectsReturned
|
||||||
raise self._obj.model.DoesNotExist
|
raise self._obj.model.DoesNotExist
|
||||||
|
|
||||||
|
@get_queryset
|
||||||
def order_by(self, *args):
|
def order_by(self, *args):
|
||||||
return self._wrap_queryset(self.get_queryset().order_by(*args))
|
return self._wrap_queryset(self._obj.order_by(*args))
|
||||||
|
|
||||||
def _filter_values(self, q, field_name, check):
|
def _filter_values(self, q, field_name, check):
|
||||||
other_values = self._changeset.get_changed_values(self._obj.model, field_name)
|
other_values = self._changeset.get_changed_values(self._obj.model, field_name)
|
||||||
|
@ -478,51 +499,61 @@ class BaseQueryWrapper(BaseWrapper):
|
||||||
|
|
||||||
created_pks = reduce(operator.and_, created_pks)
|
created_pks = reduce(operator.and_, created_pks)
|
||||||
if negate:
|
if negate:
|
||||||
|
filters = (~Q(*filters), )
|
||||||
created_pks = self._changeset.get_created_pks(self._obj.model) - created_pks
|
created_pks = self._changeset.get_created_pks(self._obj.model) - created_pks
|
||||||
return self._wrap_queryset(self.get_queryset().filter(*filters), created_pks=(self._created_pks & created_pks))
|
return self._wrap_queryset(self._obj.filter(*filters), created_pks=(self._created_pks & created_pks))
|
||||||
|
|
||||||
|
@get_queryset
|
||||||
def filter(self, *args, **kwargs):
|
def filter(self, *args, **kwargs):
|
||||||
return self._filter_or_exclude(False, *args, **kwargs)
|
return self._filter_or_exclude(False, *args, **kwargs)
|
||||||
|
|
||||||
|
@get_queryset
|
||||||
def exclude(self, *args, **kwargs):
|
def exclude(self, *args, **kwargs):
|
||||||
return self._filter_or_exclude(True, *args, **kwargs)
|
return self._filter_or_exclude(True, *args, **kwargs)
|
||||||
|
|
||||||
|
@get_queryset
|
||||||
def count(self):
|
def count(self):
|
||||||
return self.get_queryset().count()+len(tuple(self._get_created_objects()))
|
return self._obj.count()+len(tuple(self._get_created_objects()))
|
||||||
|
|
||||||
|
@get_queryset
|
||||||
def values_list(self, *args, flat=False):
|
def values_list(self, *args, flat=False):
|
||||||
own_values = (tuple(getattr(obj, arg) for arg in args) for obj in self._get_created_objects())
|
own_values = (tuple(getattr(obj, arg) for arg in args) for obj in self._get_created_objects())
|
||||||
if flat:
|
if flat:
|
||||||
own_values = (v[0] for v in own_values)
|
own_values = (v[0] for v in own_values)
|
||||||
return chain(
|
return chain(
|
||||||
self.get_queryset().values_list(*args, flat=flat),
|
self._obj.values_list(*args, flat=flat),
|
||||||
own_values,
|
own_values,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@get_queryset
|
||||||
def first(self):
|
def first(self):
|
||||||
first = self.get_queryset().first()
|
first = self._obj.first()
|
||||||
if first is not None:
|
if first is not None:
|
||||||
first = self._wrap_instance(first)
|
first = self._wrap_instance(first)
|
||||||
return first
|
return first
|
||||||
|
|
||||||
|
@get_queryset
|
||||||
def using(self, alias):
|
def using(self, alias):
|
||||||
return self._wrap_queryset(self.get_queryset().using(alias))
|
return self._wrap_queryset(self._obj.using(alias))
|
||||||
|
|
||||||
|
@get_queryset
|
||||||
def extra(self, select):
|
def extra(self, select):
|
||||||
for key in select.keys():
|
for key in select.keys():
|
||||||
if not key.startswith('_prefetch_related_val'):
|
if not key.startswith('_prefetch_related_val'):
|
||||||
raise NotImplementedError('extra() calls are only supported for prefetch_related!')
|
raise NotImplementedError('extra() calls are only supported for prefetch_related!')
|
||||||
return self._wrap_queryset(self.get_queryset().extra(select), add_extra=tuple(select.keys()))
|
return self._wrap_queryset(self._obj.extra(select), add_extra=tuple(select.keys()))
|
||||||
|
|
||||||
|
@get_queryset
|
||||||
def _next_is_sticky(self):
|
def _next_is_sticky(self):
|
||||||
return self._wrap_queryset(self.get_queryset()._next_is_sticky())
|
return self._wrap_queryset(self._obj._next_is_sticky())
|
||||||
|
|
||||||
def _get_created_objects(self):
|
def _get_created_objects(self):
|
||||||
return (self._changeset.get_created_object(self._obj.model, pk, get_foreign_objects=True)
|
return (self._changeset.get_created_object(self._obj.model, pk, get_foreign_objects=True)
|
||||||
for pk in sorted(self._created_pks))
|
for pk in sorted(self._created_pks))
|
||||||
|
|
||||||
|
@queryset_only
|
||||||
def _get_cached_result(self):
|
def _get_cached_result(self):
|
||||||
obj = self.get_queryset()
|
obj = self._obj
|
||||||
obj._prefetch_done = True
|
obj._prefetch_done = True
|
||||||
obj._fetch_all()
|
obj._fetch_all()
|
||||||
|
|
||||||
|
@ -574,15 +605,18 @@ class BaseQueryWrapper(BaseWrapper):
|
||||||
def _results_cache(self):
|
def _results_cache(self):
|
||||||
return self._cached_result
|
return self._cached_result
|
||||||
|
|
||||||
|
@queryset_only
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
return iter(self._cached_result)
|
return iter(self._cached_result)
|
||||||
|
|
||||||
|
@queryset_only
|
||||||
def iterator(self):
|
def iterator(self):
|
||||||
return iter(chain(
|
return iter(chain(
|
||||||
(self._wrap_instance(instance) for instance in self.get_queryset().iterator()),
|
(self._wrap_instance(instance) for instance in self._obj.iterator()),
|
||||||
self._get_created_objects(),
|
self._get_created_objects(),
|
||||||
))
|
))
|
||||||
|
|
||||||
|
@queryset_only
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self._cached_result)
|
return len(self._cached_result)
|
||||||
|
|
||||||
|
@ -594,7 +628,8 @@ class BaseQueryWrapper(BaseWrapper):
|
||||||
|
|
||||||
class ManagerWrapper(BaseQueryWrapper):
|
class ManagerWrapper(BaseQueryWrapper):
|
||||||
def get_queryset(self):
|
def get_queryset(self):
|
||||||
return super().get_queryset().exclude(pk__in=self._changeset.deleted_existing.get(self._obj.model, ()))
|
qs = self._wrap_queryset(self._obj.model.objects.all())
|
||||||
|
return qs.exclude(pk__in=self._changeset.deleted_existing.get(self._obj.model, ()))
|
||||||
|
|
||||||
|
|
||||||
class RelatedManagerWrapper(ManagerWrapper):
|
class RelatedManagerWrapper(ManagerWrapper):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue