fix m2m prefetch_related with newly created objects

This commit is contained in:
Laura Klünder 2017-06-17 22:11:20 +02:00
parent f0d4d122da
commit ec18984875

View file

@ -515,8 +515,7 @@ class BaseQueryWrapper(BaseWrapper):
return (self._changeset.get_created_object(self._obj.model, pk, get_foreign_objects=True)
for pk in sorted(self._created_pks))
@cached_property
def _cached_result(self):
def _get_cached_result(self):
obj = self.get_queryset()
obj._prefetch_done = True
obj._fetch_all()
@ -529,13 +528,13 @@ class BaseQueryWrapper(BaseWrapper):
if isinstance(f, ManyToManyRel) and f.through._meta.get_field(f.field.m2m_field_name()).attname == ex:
objs_by_pk = OrderedDict()
for instance in result:
objs_by_pk.setdefault(instance.pk, OrderedDict())[getattr(instance, extra)] = instance
objs_by_pk.setdefault(instance.pk, OrderedDict())[getattr(instance, extra, None)] = instance
m2m_added = self._changeset.m2m_added.get(f.field.model, {})
m2m_removed = self._changeset.m2m_removed.get(f.field.model, {})
for related_pk, changes in m2m_added.items():
for pk in changes.get(f.field.name, ()):
if related_pk not in objs_by_pk[pk]:
if pk in objs_by_pk and related_pk not in objs_by_pk[pk]:
print('added', pk, 'to', related_pk)
new = self._wrap_instance(next(iter(objs_by_pk[pk].values()))._obj)
new.__dict__[extra] = related_pk
@ -543,23 +542,31 @@ class BaseQueryWrapper(BaseWrapper):
for related_pk, changes in m2m_removed.items():
for pk in changes.get(f.field.name, ()):
if related_pk in objs_by_pk[pk]:
if pk in objs_by_pk and related_pk in objs_by_pk[pk]:
print('removed', pk, 'from', related_pk)
objs_by_pk[pk].pop(related_pk)
for pk, instances in objs_by_pk.items():
instances.pop(None, None)
result = list(chain(*(instances.values() for instances in objs_by_pk.values())))
break
obj._result_cache = result
obj._prefetch_done = False
obj._fetch_all()
return [self._wrap_instance(instance) for instance in obj._result_cache] + list(self._get_created_objects())
return result
@cached_property
def _cached_result(self):
return self._get_cached_result()
@property
def _results_cache(self):
return self._cached_result
def __iter__(self):
self._get_cached_result()
return iter(self._cached_result)
def iterator(self):