From 38baebb536c7138c1d7e918d202cb6cc7dc3f936 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laura=20Kl=C3=BCnder?= Date: Tue, 13 Jun 2017 22:07:36 +0200 Subject: [PATCH] don't break prefetch_related --- src/c3nav/editor/wrappers.py | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/src/c3nav/editor/wrappers.py b/src/c3nav/editor/wrappers.py index 7a49f53b..a4bbf768 100644 --- a/src/c3nav/editor/wrappers.py +++ b/src/c3nav/editor/wrappers.py @@ -24,6 +24,8 @@ class BaseWrapper: assert isinstance(manager, Manager) if hasattr(manager, 'through'): return ManyRelatedManagerWrapper(self._changeset, manager, self._author) + if hasattr(manager, 'instance'): + return RelatedManagerWrapper(self._changeset, manager, self._author) return ManagerWrapper(self._changeset, manager, self._author) def _wrap_queryset(self, queryset): @@ -257,11 +259,28 @@ class ManagerWrapper(BaseQueryWrapper): return self._wrap_queryset(self._obj.exclude(pk__in=self._changeset.deleted_existing.get(self._obj.model, ()))) -class ManyRelatedManagerWrapper(ManagerWrapper): +class RelatedManagerWrapper(ManagerWrapper): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _get_cache_name(self): + return self._obj.field.related_query_name() + + def all(self): + result = self.instance._prefetched_objects_cache.get(self._get_cache_name(), None) + if result is not None: + return result + super().all() + + +class ManyRelatedManagerWrapper(RelatedManagerWrapper): def _check_through(self): if not self._obj.through._meta.auto_created: raise AttributeError('Cannot do this an a ManyToManyField which specifies an intermediary model.') + def _get_cache_name(self): + return self._obj.prefetch_cache_name + def set(self, objs, author=None): if author is None: author = self._author @@ -278,7 +297,7 @@ class ManyRelatedManagerWrapper(ManagerWrapper): for obj in objs: pk = (obj.pk if isinstance(obj, self._obj.model) else obj) - self._changeset.add_m2m_add(self._obj.instance, name=self.prefetch_cache_name, value=pk, author=author) + self._changeset.add_m2m_add(self._obj.instance, name=self._get_cache_name(), value=pk, author=author) def remove(self, *objs, author=None): if author is None: @@ -286,11 +305,7 @@ class ManyRelatedManagerWrapper(ManagerWrapper): for obj in objs: pk = (obj.pk if isinstance(obj, self._obj.model) else obj) - self._changeset.add_m2m_remove(self._obj.instance, name=self.prefetch_cache_name, value=pk, author=author) - - def get_prefetch_queryset(self, instances, queryset=None): - value = self._obj.get_prefetch_queryset(instances, queryset=queryset) - return value + self._changeset.add_m2m_remove(self._obj.instance, name=self._get_cache_name(), value=pk, author=author) class QuerySetWrapper(BaseQueryWrapper):