query should only be evaluated at the end

This commit is contained in:
Laura Klünder 2017-06-15 17:36:35 +02:00
parent 13bd8f24af
commit 52404c1cf1

View file

@ -1,5 +1,6 @@
import typing import typing
from collections import deque from collections import deque
from functools import wraps
from itertools import chain from itertools import chain
from django.db import models from django.db import models
@ -9,7 +10,7 @@ from django.db.models.query_utils import DeferredAttribute
class BaseWrapper: class BaseWrapper:
_not_wrapped = ('_changeset', '_author', '_obj', '_changes_qs', '_initial_values', '_wrap_instances') _not_wrapped = ('_changeset', '_author', '_obj', '_commands', '_wrap_instances', '_initial_values')
_allowed_callables = ('', ) _allowed_callables = ('', )
def __init__(self, changeset, obj, author=None): def __init__(self, changeset, obj, author=None):
@ -206,47 +207,84 @@ class ModelInstanceWrapper(BaseWrapper):
self._changeset.add_delete(self, author=author) self._changeset.add_delete(self, author=author)
class ChangesQuerySet: class ModifiesQueryDecorator:
def __init__(self, changeset, model, author): def __init__(self, test_call=False):
self._changeset = changeset self.test_call = test_call
self._model = model
self._author = author def __call__(self, f):
@wraps(f)
def wrapper(qs, *args, execute=False, **kwargs):
if execute:
return f(qs, *args, **kwargs)
if self.test_call:
f(qs, *args, test_call=True, **kwargs)
return qs._wrap_queryset(qs.get_queryset(), add_command=(f.__name__, args, kwargs))
return wrapper
modifies_query = ModifiesQueryDecorator
class ExecutesQueryDecorator:
def __init__(self, test_call=False):
self.test_call = test_call
def __call__(self, f):
@wraps(f)
def wrapper(qs, *args, **kwargs):
qs = qs.execute_commands()
return f(qs, *args, **kwargs)
return wrapper
executes_query = ExecutesQueryDecorator
class BaseQueryWrapper(BaseWrapper): class BaseQueryWrapper(BaseWrapper):
_allowed_callables = ('_add_hints', '_next_is_sticky', 'get_prefetch_queryset') _allowed_callables = ('_add_hints', '_next_is_sticky', 'get_prefetch_queryset')
def __init__(self, changeset, obj, author=None, changes_qs=None, wrap_instances=True): def __init__(self, changeset, obj, author=None, commands=(), wrap_instances=True):
super().__init__(changeset, obj, author) super().__init__(changeset, obj, author)
if changes_qs is None:
changes_qs = ChangesQuerySet(changeset, obj.model, author)
self._changes_qs = changes_qs
self._wrap_instances = wrap_instances self._wrap_instances = wrap_instances
self._commands = commands
def get_queryset(self): def get_queryset(self):
return self._obj return self._obj
def execute_commands(self):
result = self
for name, args, kwargs in self._commands:
result = getattr(result, name)(*args, execute=True, **kwargs)
result._commands = ()
return result
def _wrap_instance(self, instance): def _wrap_instance(self, instance):
if self._wrap_instances: if self._wrap_instances:
return super()._wrap_instance(instance) return super()._wrap_instance(instance)
return instance return instance
def _wrap_queryset(self, queryset, changes_qs=None, wrap_instances=None): def _wrap_queryset(self, queryset, add_command=None, wrap_instances=None):
if changes_qs is None:
changes_qs = self._changes_qs
if wrap_instances is None: if wrap_instances is None:
wrap_instances = self._wrap_instances wrap_instances = self._wrap_instances
return QuerySetWrapper(self._changeset, queryset, self._author, changes_qs, wrap_instances) commands = self._commands
if add_command is not None:
commands += (add_command, )
return QuerySetWrapper(self._changeset, queryset, self._author, commands, wrap_instances)
@modifies_query()
def all(self): def all(self):
return self._wrap_queryset(self.get_queryset().all()) return self._wrap_queryset(self.get_queryset().all())
@modifies_query()
def none(self): def none(self):
return self._wrap_queryset(self.get_queryset().none()) return self._wrap_queryset(self.get_queryset().none())
@modifies_query()
def select_related(self, *args, **kwargs): def select_related(self, *args, **kwargs):
return self._wrap_queryset(self.get_queryset().select_related(*args, **kwargs)) return self._wrap_queryset(self.get_queryset().select_related(*args, **kwargs))
@modifies_query()
def prefetch_related(self, *lookups): def prefetch_related(self, *lookups):
new_lookups = deque() new_lookups = deque()
for lookup in lookups: for lookup in lookups:
@ -261,6 +299,7 @@ class BaseQueryWrapper(BaseWrapper):
new_lookups.append(Prefetch(lookup, qs)) new_lookups.append(Prefetch(lookup, qs))
return self._wrap_queryset(self.get_queryset().prefetch_related(*new_lookups)) return self._wrap_queryset(self.get_queryset().prefetch_related(*new_lookups))
@executes_query()
def get(self, *args, **kwargs): def get(self, *args, **kwargs):
results = tuple(self.filter(*args, **kwargs)) results = tuple(self.filter(*args, **kwargs))
if len(results) == 1: if len(results) == 1:
@ -269,6 +308,7 @@ class BaseQueryWrapper(BaseWrapper):
raise self._obj.model.DoesNotExist raise self._obj.model.DoesNotExist
raise self._obj.model.MultipleObjectsReturned raise self._obj.model.MultipleObjectsReturned
@modifies_query()
def order_by(self, *args): def order_by(self, *args):
return self._wrap_queryset(self.get_queryset().order_by(*args)) return self._wrap_queryset(self.get_queryset().order_by(*args))
@ -384,33 +424,42 @@ class BaseQueryWrapper(BaseWrapper):
tuple(self._filter_kwarg(name, value) for name, value in kwargs.items()) tuple(self._filter_kwarg(name, value) for name, value in kwargs.items())
) )
@modifies_query()
def filter(self, *args, **kwargs): def filter(self, *args, **kwargs):
return self._wrap_queryset(self.get_queryset().filter(*self._filter(*args, **kwargs))) return self._wrap_queryset(self.get_queryset().filter(*self._filter(*args, **kwargs)))
@modifies_query()
def exclude(self, *args, **kwargs): def exclude(self, *args, **kwargs):
return self._wrap_queryset(self.get_queryset().exclude(*self._filter(*args, **kwargs))) return self._wrap_queryset(self.get_queryset().exclude(*self._filter(*args, **kwargs)))
@executes_query()
def count(self): def count(self):
return self.get_queryset().count() return self.get_queryset().count()
@executes_query()
def values_list(self, *args, flat=False): def values_list(self, *args, flat=False):
return self.get_queryset().values_list(*args, flat=flat) return self.get_queryset().values_list(*args, flat=flat)
@executes_query()
def first(self): def first(self):
first = self.get_queryset().first() first = self.get_queryset().first()
if first is not None: if first is not None:
first = self._wrap_instance(first) first = self._wrap_instance(first)
return first return first
@modifies_query()
def using(self, alias): def using(self, alias):
return self._wrap_queryset(self.get_queryset().using(alias)) return self._wrap_queryset(self.get_queryset().using(alias))
@executes_query()
def __iter__(self): def __iter__(self):
return iter((self._wrap_instance(instance) for instance in self.get_queryset())) return iter((self._wrap_instance(instance) for instance in self.get_queryset()))
@executes_query()
def iterator(self): def iterator(self):
return iter((self._wrap_instance(instance) for instance in self.get_queryset().iterator())) return iter((self._wrap_instance(instance) for instance in self.get_queryset().iterator()))
@executes_query()
def __len__(self): def __len__(self):
return len(self.get_queryset()) return len(self.get_queryset())