fix changeset parsing on prefetch_related with manytomany

This commit is contained in:
Laura Klünder 2017-06-17 21:50:15 +02:00
parent cb4867614c
commit f0d4d122da
2 changed files with 60 additions and 11 deletions

View file

@ -76,7 +76,7 @@ class EditorViewSet(ViewSet):
levels, levels_on_top, levels_under = self._get_levels_pk(request, level) levels, levels_on_top, levels_under = self._get_levels_pk(request, level)
# don't prefetch groups for now as changesets do not yet work with m2m-prefetches # don't prefetch groups for now as changesets do not yet work with m2m-prefetches
levels = Level.objects.filter(pk__in=levels).prefetch_related('buildings', 'spaces', 'doors', levels = Level.objects.filter(pk__in=levels).prefetch_related('buildings', 'spaces', 'doors',
'spaces__holes', # 'spaces__groups', 'spaces__holes', 'spaces__groups',
'spaces__columns') 'spaces__columns')
levels = {s.pk: s for s in levels} levels = {s.pk: s for s in levels}

View file

@ -1,17 +1,18 @@
import operator import operator
import typing import typing
from collections import OrderedDict
from functools import reduce from functools import reduce
from itertools import chain from itertools import chain
from django.db import models from django.db import models
from django.db.models import Manager, Prefetch, Q from django.db.models import Manager, ManyToManyRel, Prefetch, Q
from django.db.models.fields.related_descriptors import ForwardManyToOneDescriptor, ManyToManyDescriptor from django.db.models.fields.related_descriptors import ForwardManyToOneDescriptor, ManyToManyDescriptor
from django.db.models.query_utils import DeferredAttribute from django.db.models.query_utils import DeferredAttribute
from django.utils.functional import cached_property from django.utils.functional import cached_property
class BaseWrapper: class BaseWrapper:
_not_wrapped = ('_changeset', '_author', '_obj', '_created_pks', '_result', '_initial_values') _not_wrapped = ('_changeset', '_author', '_obj', '_created_pks', '_result', '_extra', '_initial_values')
_allowed_callables = ('', ) _allowed_callables = ('', )
def __init__(self, changeset, obj, author=None): def __init__(self, changeset, obj, author=None):
@ -210,14 +211,14 @@ class ModelInstanceWrapper(BaseWrapper):
class BaseQueryWrapper(BaseWrapper): class BaseQueryWrapper(BaseWrapper):
_allowed_callables = ('_add_hints', '_next_is_sticky', 'get_prefetch_queryset') _allowed_callables = ('_add_hints', )
def __init__(self, changeset, obj, author=None, created_pks=None): def __init__(self, changeset, obj, author=None, created_pks=None, extra=()):
super().__init__(changeset, obj, author) super().__init__(changeset, obj, author)
if created_pks is None: if created_pks is None:
created_pks = self._changeset.get_created_pks(self._obj.model) created_pks = self._changeset.get_created_pks(self._obj.model)
self._created_pks = created_pks self._created_pks = created_pks
self._result = None self._extra = extra
def get_queryset(self): def get_queryset(self):
return self._obj return self._obj
@ -225,12 +226,12 @@ class BaseQueryWrapper(BaseWrapper):
def _wrap_instance(self, instance): def _wrap_instance(self, instance):
return super()._wrap_instance(instance) return super()._wrap_instance(instance)
def _wrap_queryset(self, queryset, created_pks=None): def _wrap_queryset(self, queryset, created_pks=None, add_extra=()):
if created_pks is None: if created_pks is None:
created_pks = self._created_pks created_pks = self._created_pks
if created_pks is False: if created_pks is False:
created_pks = None created_pks = None
return QuerySetWrapper(self._changeset, queryset, self._author, created_pks) return QuerySetWrapper(self._changeset, queryset, self._author, created_pks, self._extra+add_extra)
def all(self): def all(self):
return self._wrap_queryset(self.get_queryset().all()) return self._wrap_queryset(self.get_queryset().all())
@ -501,14 +502,62 @@ class BaseQueryWrapper(BaseWrapper):
def using(self, alias): def using(self, alias):
return self._wrap_queryset(self.get_queryset().using(alias)) return self._wrap_queryset(self.get_queryset().using(alias))
def extra(self, select):
for key in select.keys():
if not key.startswith('_prefetch_related_val'):
raise NotImplementedError('extra() calls are only supported for prefetch_related!')
return self._wrap_queryset(self.get_queryset().extra(select), add_extra=tuple(select.keys()))
def _next_is_sticky(self):
return self._wrap_queryset(self.get_queryset()._next_is_sticky())
def _get_created_objects(self): def _get_created_objects(self):
return (self._changeset.get_created_object(self._obj.model, pk, get_foreign_objects=True) return (self._changeset.get_created_object(self._obj.model, pk, get_foreign_objects=True)
for pk in sorted(self._created_pks)) for pk in sorted(self._created_pks))
@cached_property @cached_property
def _cached_result(self): def _cached_result(self):
return (tuple(self._wrap_instance(instance) for instance in self.get_queryset()) + obj = self.get_queryset()
tuple(self._get_created_objects())) obj._prefetch_done = True
obj._fetch_all()
result = [self._wrap_instance(instance) for instance in obj._result_cache] + list(self._get_created_objects())
for extra in self._extra:
ex = extra[22:]
for f in self._obj.model._meta.get_fields():
if isinstance(f, ManyToManyRel) and f.through._meta.get_field(f.field.m2m_field_name()).attname == ex:
objs_by_pk = OrderedDict()
for instance in result:
objs_by_pk.setdefault(instance.pk, OrderedDict())[getattr(instance, extra)] = instance
m2m_added = self._changeset.m2m_added.get(f.field.model, {})
m2m_removed = self._changeset.m2m_removed.get(f.field.model, {})
for related_pk, changes in m2m_added.items():
for pk in changes.get(f.field.name, ()):
if related_pk not in objs_by_pk[pk]:
print('added', pk, 'to', related_pk)
new = self._wrap_instance(next(iter(objs_by_pk[pk].values()))._obj)
new.__dict__[extra] = related_pk
objs_by_pk[pk][related_pk] = new
for related_pk, changes in m2m_removed.items():
for pk in changes.get(f.field.name, ()):
if related_pk in objs_by_pk[pk]:
print('removed', pk, 'from', related_pk)
objs_by_pk[pk].pop(related_pk)
result = list(chain(*(instances.values() for instances in objs_by_pk.values())))
break
obj._result_cache = result
obj._prefetch_done = False
obj._fetch_all()
return [self._wrap_instance(instance) for instance in obj._result_cache] + list(self._get_created_objects())
@property
def _results_cache(self):
return self._cached_result
def __iter__(self): def __iter__(self):
return iter(self._cached_result) return iter(self._cached_result)