From 19856dfd8ad6e638bb3bf7a241925328718a820a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laura=20Kl=C3=BCnder?= Date: Fri, 16 Jun 2017 16:03:51 +0200 Subject: [PATCH] add created object to querysets (not yet fully working with select_related) --- src/c3nav/editor/models.py | 29 +++++--- src/c3nav/editor/urls.py | 10 +-- src/c3nav/editor/views.py | 6 ++ src/c3nav/editor/wrappers.py | 134 +++++++++++++++++++++++++---------- 4 files changed, 128 insertions(+), 51 deletions(-) diff --git a/src/c3nav/editor/models.py b/src/c3nav/editor/models.py index 8755a15e..4c122b44 100644 --- a/src/c3nav/editor/models.py +++ b/src/c3nav/editor/models.py @@ -74,26 +74,31 @@ class ChangeSet(models.Model): self.created_objects[model][change.created_object_id][name].remove(value) return + pk = change.existing_object_pk if change.action == 'update': - self.updated_existing.setdefault(model, {}).setdefault(change.obj_pk, {})[name] = value + self.updated_existing.setdefault(model, {}).setdefault(pk, {})[name] = value elif change.action == 'm2m_add': - m2m_remove_existing = self.m2m_remove_existing.get(model, {}).get(change.obj_pk, ()) + m2m_remove_existing = self.m2m_remove_existing.get(model, {}).get(pk, {}).get(name, ()) if value in m2m_remove_existing: m2m_remove_existing.remove(value) else: - self.m2m_add_existing.setdefault(model, {}).setdefault(change.obj_pk, set()).add(value) + self.m2m_add_existing.setdefault(model, {}).setdefault(pk, {}).setdefault(name, set()).add(value) elif change.action == 'm2m_remove': - m2m_add_existing = self.m2m_add_existing.get(model, {}).get(change.obj_pk, ()) + m2m_add_existing = self.m2m_add_existing.get(model, {}).get(pk, {}).get(name, ()) if value in m2m_add_existing: m2m_add_existing.remove(value) else: - self.m2m_remove_existing.setdefault(model, {}).setdefault(change.obj_pk, set()).add(value) + self.m2m_remove_existing.setdefault(model, {}).setdefault(pk, {}).setdefault(name, set()).add(value) def get_changed_values(self, model, name): r = tuple((pk, values[name]) for pk, values in self.updated_existing.get(model, {}).items() if name in values) return r - def get_created_object(self, model, pk, author=None): + def get_created_values(self, model, name): + r = tuple((pk, values[name]) for pk, values in self.updated_existing.get(model, {}).items() if name in values) + return r + + def get_created_object(self, model, pk, author=None, get_foreign_objects=False): if isinstance(pk, str): pk = int(pk[1:]) self.parse_changes() @@ -112,14 +117,22 @@ class ChangeSet(models.Model): continue if isinstance(class_value, ForwardManyToOneDescriptor): - setattr(obj, class_value.field.attname, pk) + field = class_value.field + setattr(obj, field.attname, value) if isinstance(pk, str): - setattr(obj, class_value.cache_name, self.get_created_object(class_value.field.model, pk)) + setattr(obj, class_value.cache_name, self.get_created_object(field.model, value)) + elif get_foreign_objects: + setattr(obj, class_value.cache_name, self.wrap(field.related_model.objects.get(pk=value))) continue setattr(obj, name, model._meta.get_field(name).to_python(value)) return self.wrap(obj, author=author) + def get_created_pks(self, model): + if issubclass(model, ModelWrapper): + model = model._obj + return set(self.created_objects.get(model, {}).keys()) + @property def cache_key(self): return str(self.pk)+'-'+str(self._last_change_pk) diff --git a/src/c3nav/editor/urls.py b/src/c3nav/editor/urls.py index 0020af83..6a6bc06e 100644 --- a/src/c3nav/editor/urls.py +++ b/src/c3nav/editor/urls.py @@ -10,7 +10,7 @@ def add_editor_urls(model_name, parent_model_name=None, with_list=True, explicit if parent_model_name: parent_model = apps.get_model('mapdata', parent_model_name) parent_model_name_plural = parent_model._meta.default_related_name - prefix = (parent_model_name_plural+r'/(?P<'+parent_model_name.lower()+'>[0-9]+)/')+model_name_plural + prefix = (parent_model_name_plural+r'/(?P<'+parent_model_name.lower()+'>c?[0-9]+)/')+model_name_plural else: prefix = model_name_plural @@ -22,7 +22,7 @@ def add_editor_urls(model_name, parent_model_name=None, with_list=True, explicit if with_list: result.append(url(r'^'+prefix+r'/$', list_objects, name=name_prefix+'list', kwargs=kwargs)) result.extend([ - url(r'^'+prefix+r'/(?P\d+)/'+explicit_edit+'$', edit, name=name_prefix+'edit', kwargs=kwargs), + url(r'^'+prefix+r'/(?Pc?\d+)/'+explicit_edit+'$', edit, name=name_prefix+'edit', kwargs=kwargs), url(r'^'+prefix+r'/create$', edit, name=name_prefix+'create', kwargs=kwargs), ]) return result @@ -30,9 +30,9 @@ def add_editor_urls(model_name, parent_model_name=None, with_list=True, explicit urlpatterns = [ url(r'^$', main_index, name='editor.index'), - url(r'^levels/(?P[0-9]+)/$', level_detail, name='editor.levels.detail'), - url(r'^levels/(?P[0-9]+)/spaces/(?P[0-9]+)/$', space_detail, name='editor.spaces.detail'), - url(r'^levels/(?P[0-9]+)/levels_on_top/create$', edit, name='editor.levels_on_top.create', + url(r'^levels/(?Pc?[0-9]+)/$', level_detail, name='editor.levels.detail'), + url(r'^levels/(?Pc?[0-9]+)/spaces/(?Pc?[0-9]+)/$', space_detail, name='editor.spaces.detail'), + url(r'^levels/(?Pc?[0-9]+)/levels_on_top/create$', edit, name='editor.levels_on_top.create', kwargs={'model': 'Level'}), url(r'^changesets/(?P[0-9]+)/$', changeset_detail, name='editor.changesets.detail'), ] diff --git a/src/c3nav/editor/views.py b/src/c3nav/editor/views.py index df9da090..4694dde8 100644 --- a/src/c3nav/editor/views.py +++ b/src/c3nav/editor/views.py @@ -72,6 +72,10 @@ def level_detail(request, pk): def space_detail(request, level, pk): Space = request.changeset.wrap('Space') space = get_object_or_404(Space.objects.select_related('level'), level__pk=level, pk=pk) + print('also!') + print(Space.objects.select_related('level').get(level__pk=level, pk=pk)) + print(space) + print('aha') return render(request, 'editor/space.html', { 'level': space.level, @@ -293,6 +297,8 @@ def list_objects(request, model=None, level=None, space=None, explicit_edit=Fals edit_url_name = request.resolver_match.url_name[:-4]+('detail' if explicit_edit else 'edit') for obj in queryset: reverse_kwargs['pk'] = obj.pk + print(reverse_kwargs) + print(reverse(edit_url_name, kwargs=reverse_kwargs)) obj.edit_url = reverse(edit_url_name, kwargs=reverse_kwargs) reverse_kwargs.pop('pk', None) diff --git a/src/c3nav/editor/wrappers.py b/src/c3nav/editor/wrappers.py index c9c7fbe2..40f1fc96 100644 --- a/src/c3nav/editor/wrappers.py +++ b/src/c3nav/editor/wrappers.py @@ -1,15 +1,18 @@ +import operator import typing from collections import deque +from functools import reduce from itertools import chain from django.db import models from django.db.models import Manager, Prefetch, Q from django.db.models.fields.related_descriptors import ForwardManyToOneDescriptor, ManyToManyDescriptor from django.db.models.query_utils import DeferredAttribute +from django.utils.functional import cached_property class BaseWrapper: - _not_wrapped = ('_changeset', '_author', '_obj', '_changes_qs', '_initial_values', '_wrap_instances') + _not_wrapped = ('_changeset', '_author', '_obj', '_created_pks', '_result', '_initial_values', '_wrap_instances') _allowed_callables = ('', ) def __init__(self, changeset, obj, author=None): @@ -206,22 +209,16 @@ class ModelInstanceWrapper(BaseWrapper): self._changeset.add_delete(self, author=author) -class ChangesQuerySet: - def __init__(self, changeset, model, author): - self._changeset = changeset - self._model = model - self._author = author - - class BaseQueryWrapper(BaseWrapper): _allowed_callables = ('_add_hints', '_next_is_sticky', 'get_prefetch_queryset') - def __init__(self, changeset, obj, author=None, changes_qs=None, wrap_instances=True): + def __init__(self, changeset, obj, author=None, created_pks=None, wrap_instances=True): super().__init__(changeset, obj, author) - if changes_qs is None: - changes_qs = ChangesQuerySet(changeset, obj.model, author) - self._changes_qs = changes_qs + if created_pks is None: + created_pks = self._changeset.get_created_pks(self._obj.model) + self._created_pks = created_pks self._wrap_instances = wrap_instances + self._result = None def get_queryset(self): return self._obj @@ -231,18 +228,18 @@ class BaseQueryWrapper(BaseWrapper): return super()._wrap_instance(instance) return instance - def _wrap_queryset(self, queryset, changes_qs=None, wrap_instances=None): - if changes_qs is None: - changes_qs = self._changes_qs + def _wrap_queryset(self, queryset, created_pks=None, wrap_instances=None): + if created_pks is None: + created_pks = self._created_pks if wrap_instances is None: wrap_instances = self._wrap_instances - return QuerySetWrapper(self._changeset, queryset, self._author, changes_qs, wrap_instances) + return QuerySetWrapper(self._changeset, queryset, self._author, created_pks, wrap_instances) def all(self): return self._wrap_queryset(self.get_queryset().all()) def none(self): - return self._wrap_queryset(self.get_queryset().none()) + return self._wrap_queryset(self.get_queryset().none(), created_pks=set()) def select_related(self, *args, **kwargs): return self._wrap_queryset(self.get_queryset().select_related(*args, **kwargs)) @@ -266,8 +263,8 @@ class BaseQueryWrapper(BaseWrapper): if len(results) == 1: return self._wrap_instance(results[0]) if results: - raise self._obj.model.DoesNotExist - raise self._obj.model.MultipleObjectsReturned + raise self._obj.model.MultipleObjectsReturned + raise self._obj.model.DoesNotExist def order_by(self, *args): return self._wrap_queryset(self.get_queryset().order_by(*args)) @@ -278,7 +275,21 @@ class BaseQueryWrapper(BaseWrapper): remove_pks = [] for pk, new_value in other_values: (add_pks if check(new_value) else remove_pks).append(pk) - return (q & ~Q(pk__in=remove_pks)) | Q(pk__in=add_pks) + created_pks = set() + for pk, values in self._changeset.created_objects.get(self._obj.model, {}).items(): + try: + if check(values[field_name]): + created_pks.add(pk) + continue + except AttributeError: + pass + if check(getattr(self._changeset.get_created_object(self._obj.model, pk), field_name)): + created_pks.add(pk) + return (q & ~Q(pk__in=remove_pks)) | Q(pk__in=add_pks), created_pks + + @staticmethod + def is_created_pk(pk): + return isinstance(pk, str) and pk.startswith('c') and pk[1:].isnumeric() def _filter_kwarg(self, filter_name, filter_value): print(filter_name, '=', filter_value, sep='') @@ -294,9 +305,12 @@ class BaseQueryWrapper(BaseWrapper): if field_name == 'pk' or field_name == self._obj.model._meta.pk.name: if not segments: - return q - else: - return q + if self.is_created_pk(filter_value): + return Q(pk__in=()), set([int(filter_value[1:])]) + return q, set() + elif segments == ['in']: + return (Q(pk__in=tuple(pk for pk in filter_value if not self.is_created_pk(pk))), + set(int(pk[1:]) for pk in filter_value if self.is_created_pk(pk))) if isinstance(class_value, ForwardManyToOneDescriptor): if not segments: @@ -318,13 +332,18 @@ class BaseQueryWrapper(BaseWrapper): filter_type = 'pk' if filter_type == 'pk' and segments == ['in']: - return self._filter_values(q, field_name, lambda val: val in filter_value) + q = Q(**{field_name+'__pk__in': tuple(pk for pk in filter_value if not self.is_created_pk(pk))}) + filter_value = tuple(str(pk) for pk in filter_value) + return self._filter_values(q, field_name, lambda val: str(val) in filter_value) if segments: raise NotImplementedError if filter_type == 'pk': - return self._filter_values(q, field_name, lambda val: val == filter_value) + if self.is_created_pk(filter_value): + q = Q(pk__in=()) + filter_value = str(filter_value) + return self._filter_values(q, field_name, lambda val: str(val) == filter_value) if filter_type == 'isnull': return self._filter_values(q, field_name, lambda val: (val is None) is filter_value) @@ -348,9 +367,23 @@ class BaseQueryWrapper(BaseWrapper): if filter_type == 'pk': if class_value.reverse: + # todo: implement this for created models model = class_value.field.model - return ((q & ~Q(pk__in=self._changeset.m2m_remove_existing.get(model, {}).get(filter_value, ()))) | - Q(pk__in=self._changeset.m2m_add_existing.get(model, {}).get(filter_value, ()))) + + if self.is_created_pk(filter_value): + filter_value = int(filter_value[1:]) + pks = tuple(self._changeset.created_objects[model][filter_value].get(field_name, ())) + return (Q(pk__in=(pk for pk in pks if not self.is_created_pk(pk))), + set(int(pk[1:]) for pk in pks if self.is_created_pk(pk))) + + def get_changeset_m2m(items): + return items.get(model, {}).get(filter_value, {}).get(field_name, ()) + + remove_pks = get_changeset_m2m(self._changeset.m2m_remove_existing) + add_pks = get_changeset_m2m(self._changeset.m2m_add_existing) + return (((q & ~Q(pk__in=(pk for pk in remove_pks if not self.is_created_pk(pk)))) | + Q(pk__in=(pk for pk in remove_pks if not self.is_created_pk(pk)))), + set(int(pk[1:]) for pk in add_pks if self.is_created_pk(pk))) raise NotImplementedError @@ -373,25 +406,36 @@ class BaseQueryWrapper(BaseWrapper): raise NotImplementedError('cannot filter %s by %s (%s)' % (self._obj.model, filter_name, class_value)) def _filter_q(self, q): - result = Q(*((self._filter_q(c) if isinstance(c, Q) else self._filter_kwarg(*c)) for c in q.children)) + filters, created_pks = zip(*((self._filter_q(c) if isinstance(c, Q) else self._filter_kwarg(*c)) + for c in q.children)) + result = Q(*filters) result.connector = q.connector result.negated = q.negated - return result - def _filter(self, *args, **kwargs): - return chain( + created_pks = reduce(operator.and_ if q.connector == 'AND' else operator.or_, created_pks) + if q.negated: + created_pks = self._changeset.get_created_pks(self._obj.model)-created_pks + return result, created_pks + + def _filter_or_exclude(self, negate, *args, **kwargs): + filters, created_pks = zip(*tuple(chain( tuple(self._filter_q(q) for q in args), tuple(self._filter_kwarg(name, value) for name, value in kwargs.items()) - ) + ))) + + created_pks = reduce(operator.and_, created_pks) + if negate: + created_pks = self._changeset.get_created_pks(self._obj.model) - created_pks + return self._wrap_queryset(self.get_queryset().filter(*filters), created_pks=(self._created_pks & created_pks)) def filter(self, *args, **kwargs): - return self._wrap_queryset(self.get_queryset().filter(*self._filter(*args, **kwargs))) + return self._filter_or_exclude(False, *args, **kwargs) def exclude(self, *args, **kwargs): - return self._wrap_queryset(self.get_queryset().exclude(*self._filter(*args, **kwargs))) + return self._filter_or_exclude(True, *args, **kwargs) def count(self): - return self.get_queryset().count() + return self.get_queryset().count()+len(tuple(self._get_created_objects())) def values_list(self, *args, flat=False): return self.get_queryset().values_list(*args, flat=flat) @@ -405,14 +449,28 @@ class BaseQueryWrapper(BaseWrapper): def using(self, alias): return self._wrap_queryset(self.get_queryset().using(alias)) + def _get_created_objects(self): + if not self._wrap_instances: + return () + return (self._changeset.get_created_object(self._obj.model, pk, get_foreign_objects=True) + for pk in sorted(self._created_pks)) + + @cached_property + def _cached_result(self): + return (tuple(self._wrap_instance(instance) for instance in self.get_queryset()) + + tuple(self._get_created_objects())) + def __iter__(self): - return iter((self._wrap_instance(instance) for instance in self.get_queryset())) + return iter(self._cached_result) def iterator(self): - return iter((self._wrap_instance(instance) for instance in self.get_queryset().iterator())) + return iter(chain( + (self._wrap_instance(instance) for instance in self.get_queryset().iterator()), + self._get_created_objects(), + )) def __len__(self): - return len(self.get_queryset()) + return len(self._cached_result) def create(self, *args, **kwargs): obj = self.model(*args, **kwargs)