From df880ca43b59c945960e1220809d354e5f268e93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laura=20Kl=C3=BCnder?= Date: Sun, 18 Jun 2017 00:45:06 +0200 Subject: [PATCH] always get queryset on manager (for filtering) --- src/c3nav/editor/wrappers.py | 75 ++++++++++++++++++++++++++---------- 1 file changed, 55 insertions(+), 20 deletions(-) diff --git a/src/c3nav/editor/wrappers.py b/src/c3nav/editor/wrappers.py index 49fcda8a..4b84b58f 100644 --- a/src/c3nav/editor/wrappers.py +++ b/src/c3nav/editor/wrappers.py @@ -1,7 +1,7 @@ import operator import typing from collections import OrderedDict -from functools import reduce +from functools import reduce, wraps from itertools import chain from django.db import models @@ -210,6 +210,24 @@ class ModelInstanceWrapper(BaseWrapper): self._changeset.add_delete(self, author=author) +def get_queryset(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + if hasattr(self, 'get_queryset'): + return getattr(self.get_queryset(), func.__name__)(*args, **kwargs) + return func(self, *args, **kwargs) + return wrapper + + +def queryset_only(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + if hasattr(self, 'get_queryset'): + raise TypeError + return func(self, *args, **kwargs) + return wrapper + + class BaseQueryWrapper(BaseWrapper): _allowed_callables = ('_add_hints', 'get_prefetch_queryset', '_apply_rel_filters') @@ -220,9 +238,6 @@ class BaseQueryWrapper(BaseWrapper): self._created_pks = created_pks self._extra = extra - def get_queryset(self): - return self._obj - def _wrap_instance(self, instance): return super()._wrap_instance(instance) @@ -233,15 +248,19 @@ class BaseQueryWrapper(BaseWrapper): created_pks = None return QuerySetWrapper(self._changeset, queryset, self._author, created_pks, self._extra+add_extra) + @get_queryset def all(self): - return self._wrap_queryset(self.get_queryset().all()) + return self._wrap_queryset(self._obj.all()) + @get_queryset def none(self): - return self._wrap_queryset(self.get_queryset().none(), ()) + return self._wrap_queryset(self._obj.none(), ()) + @get_queryset def select_related(self, *args, **kwargs): - return self._wrap_queryset(self.get_queryset().select_related(*args, **kwargs)) + return self._wrap_queryset(self._obj.select_related(*args, **kwargs)) + @get_queryset def prefetch_related(self, *lookups): lookups_splitted = tuple(tuple(lookup.split('__')) for lookup in lookups) max_depth = max(len(lookup) for lookup in lookups_splitted) @@ -250,7 +269,7 @@ class BaseQueryWrapper(BaseWrapper): lookups_by_depth.append(set(tuple(lookup[:i+1] for lookup in lookups_splitted if len(lookup) > i))) lookup_models = {(): self._obj.model} - lookup_querysets = {(): self.get_queryset()} + lookup_querysets = {(): self._obj} for depth_lookups in lookups_by_depth: for lookup in depth_lookups: model = lookup_models[lookup[:-1]]._meta.get_field(lookup[-1]).related_model @@ -266,10 +285,11 @@ class BaseQueryWrapper(BaseWrapper): return self._wrap_queryset(lookup_querysets[()]) def _clone(self, **kwargs): - clone = self._wrap_queryset(self.get_queryset()) + clone = self._wrap_queryset(self._obj) clone._obj.__dict__.update(kwargs) return clone + @get_queryset def get(self, *args, **kwargs): results = tuple(self.filter(*args, **kwargs)) if len(results) == 1: @@ -278,8 +298,9 @@ class BaseQueryWrapper(BaseWrapper): raise self._obj.model.MultipleObjectsReturned raise self._obj.model.DoesNotExist + @get_queryset def order_by(self, *args): - return self._wrap_queryset(self.get_queryset().order_by(*args)) + return self._wrap_queryset(self._obj.order_by(*args)) def _filter_values(self, q, field_name, check): other_values = self._changeset.get_changed_values(self._obj.model, field_name) @@ -478,51 +499,61 @@ class BaseQueryWrapper(BaseWrapper): created_pks = reduce(operator.and_, created_pks) if negate: + filters = (~Q(*filters), ) created_pks = self._changeset.get_created_pks(self._obj.model) - created_pks - return self._wrap_queryset(self.get_queryset().filter(*filters), created_pks=(self._created_pks & created_pks)) + return self._wrap_queryset(self._obj.filter(*filters), created_pks=(self._created_pks & created_pks)) + @get_queryset def filter(self, *args, **kwargs): return self._filter_or_exclude(False, *args, **kwargs) + @get_queryset def exclude(self, *args, **kwargs): return self._filter_or_exclude(True, *args, **kwargs) + @get_queryset def count(self): - return self.get_queryset().count()+len(tuple(self._get_created_objects())) + return self._obj.count()+len(tuple(self._get_created_objects())) + @get_queryset def values_list(self, *args, flat=False): own_values = (tuple(getattr(obj, arg) for arg in args) for obj in self._get_created_objects()) if flat: own_values = (v[0] for v in own_values) return chain( - self.get_queryset().values_list(*args, flat=flat), + self._obj.values_list(*args, flat=flat), own_values, ) + @get_queryset def first(self): - first = self.get_queryset().first() + first = self._obj.first() if first is not None: first = self._wrap_instance(first) return first + @get_queryset def using(self, alias): - return self._wrap_queryset(self.get_queryset().using(alias)) + return self._wrap_queryset(self._obj.using(alias)) + @get_queryset def extra(self, select): for key in select.keys(): if not key.startswith('_prefetch_related_val'): raise NotImplementedError('extra() calls are only supported for prefetch_related!') - return self._wrap_queryset(self.get_queryset().extra(select), add_extra=tuple(select.keys())) + return self._wrap_queryset(self._obj.extra(select), add_extra=tuple(select.keys())) + @get_queryset def _next_is_sticky(self): - return self._wrap_queryset(self.get_queryset()._next_is_sticky()) + return self._wrap_queryset(self._obj._next_is_sticky()) def _get_created_objects(self): return (self._changeset.get_created_object(self._obj.model, pk, get_foreign_objects=True) for pk in sorted(self._created_pks)) + @queryset_only def _get_cached_result(self): - obj = self.get_queryset() + obj = self._obj obj._prefetch_done = True obj._fetch_all() @@ -574,15 +605,18 @@ class BaseQueryWrapper(BaseWrapper): def _results_cache(self): return self._cached_result + @queryset_only def __iter__(self): return iter(self._cached_result) + @queryset_only def iterator(self): return iter(chain( - (self._wrap_instance(instance) for instance in self.get_queryset().iterator()), + (self._wrap_instance(instance) for instance in self._obj.iterator()), self._get_created_objects(), )) + @queryset_only def __len__(self): return len(self._cached_result) @@ -594,7 +628,8 @@ class BaseQueryWrapper(BaseWrapper): class ManagerWrapper(BaseQueryWrapper): def get_queryset(self): - return super().get_queryset().exclude(pk__in=self._changeset.deleted_existing.get(self._obj.model, ())) + qs = self._wrap_queryset(self._obj.model.objects.all()) + return qs.exclude(pk__in=self._changeset.deleted_existing.get(self._obj.model, ())) class RelatedManagerWrapper(ManagerWrapper):