From f0d4d122da78289a8ad99e4222818e96ac52bf14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laura=20Kl=C3=BCnder?= Date: Sat, 17 Jun 2017 21:50:15 +0200 Subject: [PATCH] fix changeset parsing on prefetch_related with manytomany --- src/c3nav/editor/api.py | 2 +- src/c3nav/editor/wrappers.py | 69 ++++++++++++++++++++++++++++++------ 2 files changed, 60 insertions(+), 11 deletions(-) diff --git a/src/c3nav/editor/api.py b/src/c3nav/editor/api.py index 6abf30d8..61d36ba0 100644 --- a/src/c3nav/editor/api.py +++ b/src/c3nav/editor/api.py @@ -76,7 +76,7 @@ class EditorViewSet(ViewSet): levels, levels_on_top, levels_under = self._get_levels_pk(request, level) # don't prefetch groups for now as changesets do not yet work with m2m-prefetches levels = Level.objects.filter(pk__in=levels).prefetch_related('buildings', 'spaces', 'doors', - 'spaces__holes', # 'spaces__groups', + 'spaces__holes', 'spaces__groups', 'spaces__columns') levels = {s.pk: s for s in levels} diff --git a/src/c3nav/editor/wrappers.py b/src/c3nav/editor/wrappers.py index fe2e635b..f5e75736 100644 --- a/src/c3nav/editor/wrappers.py +++ b/src/c3nav/editor/wrappers.py @@ -1,17 +1,18 @@ import operator import typing +from collections import OrderedDict from functools import reduce from itertools import chain from django.db import models -from django.db.models import Manager, Prefetch, Q +from django.db.models import Manager, ManyToManyRel, Prefetch, Q from django.db.models.fields.related_descriptors import ForwardManyToOneDescriptor, ManyToManyDescriptor from django.db.models.query_utils import DeferredAttribute from django.utils.functional import cached_property class BaseWrapper: - _not_wrapped = ('_changeset', '_author', '_obj', '_created_pks', '_result', '_initial_values') + _not_wrapped = ('_changeset', '_author', '_obj', '_created_pks', '_result', '_extra', '_initial_values') _allowed_callables = ('', ) def __init__(self, changeset, obj, author=None): @@ -210,14 +211,14 @@ class ModelInstanceWrapper(BaseWrapper): class BaseQueryWrapper(BaseWrapper): - _allowed_callables = ('_add_hints', '_next_is_sticky', 'get_prefetch_queryset') + _allowed_callables = ('_add_hints', ) - def __init__(self, changeset, obj, author=None, created_pks=None): + def __init__(self, changeset, obj, author=None, created_pks=None, extra=()): super().__init__(changeset, obj, author) if created_pks is None: created_pks = self._changeset.get_created_pks(self._obj.model) self._created_pks = created_pks - self._result = None + self._extra = extra def get_queryset(self): return self._obj @@ -225,12 +226,12 @@ class BaseQueryWrapper(BaseWrapper): def _wrap_instance(self, instance): return super()._wrap_instance(instance) - def _wrap_queryset(self, queryset, created_pks=None): + def _wrap_queryset(self, queryset, created_pks=None, add_extra=()): if created_pks is None: created_pks = self._created_pks if created_pks is False: created_pks = None - return QuerySetWrapper(self._changeset, queryset, self._author, created_pks) + return QuerySetWrapper(self._changeset, queryset, self._author, created_pks, self._extra+add_extra) def all(self): return self._wrap_queryset(self.get_queryset().all()) @@ -395,7 +396,7 @@ class BaseQueryWrapper(BaseWrapper): m2m_added = {pk: val[rel_name] for pk, val in self._changeset.m2m_added.get(model, {}).items() if pk in filter_value and rel_name in val} m2m_removed = {pk: val[rel_name] for pk, val in self._changeset.m2m_removed.get(model, {}).items() - if pk in filter_value and rel_name in val} # can only be existing spaces + if pk in filter_value and rel_name in val} # can only be existing spaces # directly lookup groups for spaces that had no groups removed q = Q(**{field_name+'__pk__in': filter_value_existing - set(m2m_removed.keys())}) @@ -501,14 +502,62 @@ class BaseQueryWrapper(BaseWrapper): def using(self, alias): return self._wrap_queryset(self.get_queryset().using(alias)) + def extra(self, select): + for key in select.keys(): + if not key.startswith('_prefetch_related_val'): + raise NotImplementedError('extra() calls are only supported for prefetch_related!') + return self._wrap_queryset(self.get_queryset().extra(select), add_extra=tuple(select.keys())) + + def _next_is_sticky(self): + return self._wrap_queryset(self.get_queryset()._next_is_sticky()) + def _get_created_objects(self): return (self._changeset.get_created_object(self._obj.model, pk, get_foreign_objects=True) for pk in sorted(self._created_pks)) @cached_property def _cached_result(self): - return (tuple(self._wrap_instance(instance) for instance in self.get_queryset()) + - tuple(self._get_created_objects())) + obj = self.get_queryset() + obj._prefetch_done = True + obj._fetch_all() + + result = [self._wrap_instance(instance) for instance in obj._result_cache] + list(self._get_created_objects()) + + for extra in self._extra: + ex = extra[22:] + for f in self._obj.model._meta.get_fields(): + if isinstance(f, ManyToManyRel) and f.through._meta.get_field(f.field.m2m_field_name()).attname == ex: + objs_by_pk = OrderedDict() + for instance in result: + objs_by_pk.setdefault(instance.pk, OrderedDict())[getattr(instance, extra)] = instance + + m2m_added = self._changeset.m2m_added.get(f.field.model, {}) + m2m_removed = self._changeset.m2m_removed.get(f.field.model, {}) + for related_pk, changes in m2m_added.items(): + for pk in changes.get(f.field.name, ()): + if related_pk not in objs_by_pk[pk]: + print('added', pk, 'to', related_pk) + new = self._wrap_instance(next(iter(objs_by_pk[pk].values()))._obj) + new.__dict__[extra] = related_pk + objs_by_pk[pk][related_pk] = new + + for related_pk, changes in m2m_removed.items(): + for pk in changes.get(f.field.name, ()): + if related_pk in objs_by_pk[pk]: + print('removed', pk, 'from', related_pk) + objs_by_pk[pk].pop(related_pk) + + result = list(chain(*(instances.values() for instances in objs_by_pk.values()))) + break + + obj._result_cache = result + obj._prefetch_done = False + obj._fetch_all() + return [self._wrap_instance(instance) for instance in obj._result_cache] + list(self._get_created_objects()) + + @property + def _results_cache(self): + return self._cached_result def __iter__(self): return iter(self._cached_result)