From 52404c1cf13e95d0c6dde78a6870c66c03eacb0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laura=20Kl=C3=BCnder?= Date: Thu, 15 Jun 2017 17:36:35 +0200 Subject: [PATCH] query should only be evaluated at the end --- src/c3nav/editor/wrappers.py | 77 +++++++++++++++++++++++++++++------- 1 file changed, 63 insertions(+), 14 deletions(-) diff --git a/src/c3nav/editor/wrappers.py b/src/c3nav/editor/wrappers.py index c6bdb478..6226ab73 100644 --- a/src/c3nav/editor/wrappers.py +++ b/src/c3nav/editor/wrappers.py @@ -1,5 +1,6 @@ import typing from collections import deque +from functools import wraps from itertools import chain from django.db import models @@ -9,7 +10,7 @@ from django.db.models.query_utils import DeferredAttribute class BaseWrapper: - _not_wrapped = ('_changeset', '_author', '_obj', '_changes_qs', '_initial_values', '_wrap_instances') + _not_wrapped = ('_changeset', '_author', '_obj', '_commands', '_wrap_instances', '_initial_values') _allowed_callables = ('', ) def __init__(self, changeset, obj, author=None): @@ -206,47 +207,84 @@ class ModelInstanceWrapper(BaseWrapper): self._changeset.add_delete(self, author=author) -class ChangesQuerySet: - def __init__(self, changeset, model, author): - self._changeset = changeset - self._model = model - self._author = author +class ModifiesQueryDecorator: + def __init__(self, test_call=False): + self.test_call = test_call + + def __call__(self, f): + @wraps(f) + def wrapper(qs, *args, execute=False, **kwargs): + if execute: + return f(qs, *args, **kwargs) + if self.test_call: + f(qs, *args, test_call=True, **kwargs) + return qs._wrap_queryset(qs.get_queryset(), add_command=(f.__name__, args, kwargs)) + + return wrapper + + +modifies_query = ModifiesQueryDecorator + + +class ExecutesQueryDecorator: + def __init__(self, test_call=False): + self.test_call = test_call + + def __call__(self, f): + @wraps(f) + def wrapper(qs, *args, **kwargs): + qs = qs.execute_commands() + return f(qs, *args, **kwargs) + return wrapper + + +executes_query = ExecutesQueryDecorator class BaseQueryWrapper(BaseWrapper): _allowed_callables = ('_add_hints', '_next_is_sticky', 'get_prefetch_queryset') - def __init__(self, changeset, obj, author=None, changes_qs=None, wrap_instances=True): + def __init__(self, changeset, obj, author=None, commands=(), wrap_instances=True): super().__init__(changeset, obj, author) - if changes_qs is None: - changes_qs = ChangesQuerySet(changeset, obj.model, author) - self._changes_qs = changes_qs self._wrap_instances = wrap_instances + self._commands = commands def get_queryset(self): return self._obj + def execute_commands(self): + result = self + for name, args, kwargs in self._commands: + result = getattr(result, name)(*args, execute=True, **kwargs) + result._commands = () + return result + def _wrap_instance(self, instance): if self._wrap_instances: return super()._wrap_instance(instance) return instance - def _wrap_queryset(self, queryset, changes_qs=None, wrap_instances=None): - if changes_qs is None: - changes_qs = self._changes_qs + def _wrap_queryset(self, queryset, add_command=None, wrap_instances=None): if wrap_instances is None: wrap_instances = self._wrap_instances - return QuerySetWrapper(self._changeset, queryset, self._author, changes_qs, wrap_instances) + commands = self._commands + if add_command is not None: + commands += (add_command, ) + return QuerySetWrapper(self._changeset, queryset, self._author, commands, wrap_instances) + @modifies_query() def all(self): return self._wrap_queryset(self.get_queryset().all()) + @modifies_query() def none(self): return self._wrap_queryset(self.get_queryset().none()) + @modifies_query() def select_related(self, *args, **kwargs): return self._wrap_queryset(self.get_queryset().select_related(*args, **kwargs)) + @modifies_query() def prefetch_related(self, *lookups): new_lookups = deque() for lookup in lookups: @@ -261,6 +299,7 @@ class BaseQueryWrapper(BaseWrapper): new_lookups.append(Prefetch(lookup, qs)) return self._wrap_queryset(self.get_queryset().prefetch_related(*new_lookups)) + @executes_query() def get(self, *args, **kwargs): results = tuple(self.filter(*args, **kwargs)) if len(results) == 1: @@ -269,6 +308,7 @@ class BaseQueryWrapper(BaseWrapper): raise self._obj.model.DoesNotExist raise self._obj.model.MultipleObjectsReturned + @modifies_query() def order_by(self, *args): return self._wrap_queryset(self.get_queryset().order_by(*args)) @@ -384,33 +424,42 @@ class BaseQueryWrapper(BaseWrapper): tuple(self._filter_kwarg(name, value) for name, value in kwargs.items()) ) + @modifies_query() def filter(self, *args, **kwargs): return self._wrap_queryset(self.get_queryset().filter(*self._filter(*args, **kwargs))) + @modifies_query() def exclude(self, *args, **kwargs): return self._wrap_queryset(self.get_queryset().exclude(*self._filter(*args, **kwargs))) + @executes_query() def count(self): return self.get_queryset().count() + @executes_query() def values_list(self, *args, flat=False): return self.get_queryset().values_list(*args, flat=flat) + @executes_query() def first(self): first = self.get_queryset().first() if first is not None: first = self._wrap_instance(first) return first + @modifies_query() def using(self, alias): return self._wrap_queryset(self.get_queryset().using(alias)) + @executes_query() def __iter__(self): return iter((self._wrap_instance(instance) for instance in self.get_queryset())) + @executes_query() def iterator(self): return iter((self._wrap_instance(instance) for instance in self.get_queryset().iterator())) + @executes_query() def __len__(self): return len(self.get_queryset())