fix changeset parsing on prefetch_related with manytomany

This commit is contained in:
Laura Klünder 2017-06-17 21:50:15 +02:00
parent cb4867614c
commit f0d4d122da
2 changed files with 60 additions and 11 deletions

View file

@ -76,7 +76,7 @@ class EditorViewSet(ViewSet):
levels, levels_on_top, levels_under = self._get_levels_pk(request, level)
# don't prefetch groups for now as changesets do not yet work with m2m-prefetches
levels = Level.objects.filter(pk__in=levels).prefetch_related('buildings', 'spaces', 'doors',
'spaces__holes', # 'spaces__groups',
'spaces__holes', 'spaces__groups',
'spaces__columns')
levels = {s.pk: s for s in levels}

View file

@ -1,17 +1,18 @@
import operator
import typing
from collections import OrderedDict
from functools import reduce
from itertools import chain
from django.db import models
from django.db.models import Manager, Prefetch, Q
from django.db.models import Manager, ManyToManyRel, Prefetch, Q
from django.db.models.fields.related_descriptors import ForwardManyToOneDescriptor, ManyToManyDescriptor
from django.db.models.query_utils import DeferredAttribute
from django.utils.functional import cached_property
class BaseWrapper:
_not_wrapped = ('_changeset', '_author', '_obj', '_created_pks', '_result', '_initial_values')
_not_wrapped = ('_changeset', '_author', '_obj', '_created_pks', '_result', '_extra', '_initial_values')
_allowed_callables = ('', )
def __init__(self, changeset, obj, author=None):
@ -210,14 +211,14 @@ class ModelInstanceWrapper(BaseWrapper):
class BaseQueryWrapper(BaseWrapper):
_allowed_callables = ('_add_hints', '_next_is_sticky', 'get_prefetch_queryset')
_allowed_callables = ('_add_hints', )
def __init__(self, changeset, obj, author=None, created_pks=None):
def __init__(self, changeset, obj, author=None, created_pks=None, extra=()):
super().__init__(changeset, obj, author)
if created_pks is None:
created_pks = self._changeset.get_created_pks(self._obj.model)
self._created_pks = created_pks
self._result = None
self._extra = extra
def get_queryset(self):
return self._obj
@ -225,12 +226,12 @@ class BaseQueryWrapper(BaseWrapper):
def _wrap_instance(self, instance):
return super()._wrap_instance(instance)
def _wrap_queryset(self, queryset, created_pks=None):
def _wrap_queryset(self, queryset, created_pks=None, add_extra=()):
if created_pks is None:
created_pks = self._created_pks
if created_pks is False:
created_pks = None
return QuerySetWrapper(self._changeset, queryset, self._author, created_pks)
return QuerySetWrapper(self._changeset, queryset, self._author, created_pks, self._extra+add_extra)
def all(self):
return self._wrap_queryset(self.get_queryset().all())
@ -395,7 +396,7 @@ class BaseQueryWrapper(BaseWrapper):
m2m_added = {pk: val[rel_name] for pk, val in self._changeset.m2m_added.get(model, {}).items()
if pk in filter_value and rel_name in val}
m2m_removed = {pk: val[rel_name] for pk, val in self._changeset.m2m_removed.get(model, {}).items()
if pk in filter_value and rel_name in val} # can only be existing spaces
if pk in filter_value and rel_name in val} # can only be existing spaces
# directly lookup groups for spaces that had no groups removed
q = Q(**{field_name+'__pk__in': filter_value_existing - set(m2m_removed.keys())})
@ -501,14 +502,62 @@ class BaseQueryWrapper(BaseWrapper):
def using(self, alias):
return self._wrap_queryset(self.get_queryset().using(alias))
def extra(self, select):
for key in select.keys():
if not key.startswith('_prefetch_related_val'):
raise NotImplementedError('extra() calls are only supported for prefetch_related!')
return self._wrap_queryset(self.get_queryset().extra(select), add_extra=tuple(select.keys()))
def _next_is_sticky(self):
return self._wrap_queryset(self.get_queryset()._next_is_sticky())
def _get_created_objects(self):
return (self._changeset.get_created_object(self._obj.model, pk, get_foreign_objects=True)
for pk in sorted(self._created_pks))
@cached_property
def _cached_result(self):
return (tuple(self._wrap_instance(instance) for instance in self.get_queryset()) +
tuple(self._get_created_objects()))
obj = self.get_queryset()
obj._prefetch_done = True
obj._fetch_all()
result = [self._wrap_instance(instance) for instance in obj._result_cache] + list(self._get_created_objects())
for extra in self._extra:
ex = extra[22:]
for f in self._obj.model._meta.get_fields():
if isinstance(f, ManyToManyRel) and f.through._meta.get_field(f.field.m2m_field_name()).attname == ex:
objs_by_pk = OrderedDict()
for instance in result:
objs_by_pk.setdefault(instance.pk, OrderedDict())[getattr(instance, extra)] = instance
m2m_added = self._changeset.m2m_added.get(f.field.model, {})
m2m_removed = self._changeset.m2m_removed.get(f.field.model, {})
for related_pk, changes in m2m_added.items():
for pk in changes.get(f.field.name, ()):
if related_pk not in objs_by_pk[pk]:
print('added', pk, 'to', related_pk)
new = self._wrap_instance(next(iter(objs_by_pk[pk].values()))._obj)
new.__dict__[extra] = related_pk
objs_by_pk[pk][related_pk] = new
for related_pk, changes in m2m_removed.items():
for pk in changes.get(f.field.name, ()):
if related_pk in objs_by_pk[pk]:
print('removed', pk, 'from', related_pk)
objs_by_pk[pk].pop(related_pk)
result = list(chain(*(instances.values() for instances in objs_by_pk.values())))
break
obj._result_cache = result
obj._prefetch_done = False
obj._fetch_all()
return [self._wrap_instance(instance) for instance in obj._result_cache] + list(self._get_created_objects())
@property
def _results_cache(self):
return self._cached_result
def __iter__(self):
return iter(self._cached_result)