don't break prefetch_related

This commit is contained in:
Laura Klünder 2017-06-13 22:07:36 +02:00
parent 3e36f5b7a3
commit 38baebb536

View file

@ -24,6 +24,8 @@ class BaseWrapper:
assert isinstance(manager, Manager)
if hasattr(manager, 'through'):
return ManyRelatedManagerWrapper(self._changeset, manager, self._author)
if hasattr(manager, 'instance'):
return RelatedManagerWrapper(self._changeset, manager, self._author)
return ManagerWrapper(self._changeset, manager, self._author)
def _wrap_queryset(self, queryset):
@ -257,11 +259,28 @@ class ManagerWrapper(BaseQueryWrapper):
return self._wrap_queryset(self._obj.exclude(pk__in=self._changeset.deleted_existing.get(self._obj.model, ())))
class ManyRelatedManagerWrapper(ManagerWrapper):
class RelatedManagerWrapper(ManagerWrapper):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def _get_cache_name(self):
return self._obj.field.related_query_name()
def all(self):
result = self.instance._prefetched_objects_cache.get(self._get_cache_name(), None)
if result is not None:
return result
super().all()
class ManyRelatedManagerWrapper(RelatedManagerWrapper):
def _check_through(self):
if not self._obj.through._meta.auto_created:
raise AttributeError('Cannot do this an a ManyToManyField which specifies an intermediary model.')
def _get_cache_name(self):
return self._obj.prefetch_cache_name
def set(self, objs, author=None):
if author is None:
author = self._author
@ -278,7 +297,7 @@ class ManyRelatedManagerWrapper(ManagerWrapper):
for obj in objs:
pk = (obj.pk if isinstance(obj, self._obj.model) else obj)
self._changeset.add_m2m_add(self._obj.instance, name=self.prefetch_cache_name, value=pk, author=author)
self._changeset.add_m2m_add(self._obj.instance, name=self._get_cache_name(), value=pk, author=author)
def remove(self, *objs, author=None):
if author is None:
@ -286,11 +305,7 @@ class ManyRelatedManagerWrapper(ManagerWrapper):
for obj in objs:
pk = (obj.pk if isinstance(obj, self._obj.model) else obj)
self._changeset.add_m2m_remove(self._obj.instance, name=self.prefetch_cache_name, value=pk, author=author)
def get_prefetch_queryset(self, instances, queryset=None):
value = self._obj.get_prefetch_queryset(instances, queryset=queryset)
return value
self._changeset.add_m2m_remove(self._obj.instance, name=self._get_cache_name(), value=pk, author=author)
class QuerySetWrapper(BaseQueryWrapper):