BaseQueryWrapper.prefetch_related(): support Prefetch objects

This commit is contained in:
Laura Klünder 2017-07-13 21:45:57 +02:00
parent 9fff9e5dcb
commit c208e97ff1

View file

@ -307,6 +307,13 @@ class BaseQueryWrapper(BaseWrapper):
and convert them into Prefetch() objects with custom querysets. and convert them into Prefetch() objects with custom querysets.
This makes sure that the prefetch also happens on the virtually modified database. This makes sure that the prefetch also happens on the virtually modified database.
""" """
lookups_qs = {tuple(lookup.prefetch_through.split('__')): lookup.queryset for lookup in lookups
if isinstance(lookup, Prefetch) and lookup.queryset}
for qs in lookups_qs.values():
if not isinstance(qs, QuerySetWrapper):
raise TypeError('Prefetch object queryset needs to be wrapped!')
lookups = tuple((lookup.prefetch_through if isinstance(lookup, Prefetch) else lookup) for lookup in lookups)
lookups_splitted = tuple(tuple(lookup.split('__')) for lookup in lookups) lookups_splitted = tuple(tuple(lookup.split('__')) for lookup in lookups)
max_depth = max(len(lookup) for lookup in lookups_splitted) max_depth = max(len(lookup) for lookup in lookups_splitted)
lookups_by_depth = [] lookups_by_depth = []
@ -314,20 +321,21 @@ class BaseQueryWrapper(BaseWrapper):
lookups_by_depth.append(set(tuple(lookup[:i+1] for lookup in lookups_splitted if len(lookup) > i))) lookups_by_depth.append(set(tuple(lookup[:i+1] for lookup in lookups_splitted if len(lookup) > i)))
lookup_models = {(): self._obj.model} lookup_models = {(): self._obj.model}
lookup_querysets = {(): self._obj} lookup_querysets = {(): self.all()}
for depth_lookups in lookups_by_depth: for depth_lookups in lookups_by_depth:
for lookup in depth_lookups: for lookup in depth_lookups:
model = lookup_models[lookup[:-1]]._meta.get_field(lookup[-1]).related_model model = lookup_models[lookup[:-1]]._meta.get_field(lookup[-1]).related_model
lookup_models[lookup] = model lookup_models[lookup] = model
lookup_querysets[lookup] = self._wrap_model(model).objects.all()._obj lookup_querysets[lookup] = lookups_qs.get(lookup, self._wrap_model(model).objects.all())
for depth_lookups in reversed(lookups_by_depth): for depth_lookups in reversed(lookups_by_depth):
for lookup in depth_lookups: for lookup in depth_lookups:
qs = self._wrap_queryset(lookup_querysets[lookup], created_pks=False) qs = lookup_querysets[lookup]
prefetch = Prefetch(lookup[-1], qs) prefetch = Prefetch(lookup[-1], qs)
lookup_querysets[lookup[:-1]] = lookup_querysets[lookup[:-1]].prefetch_related(prefetch) lower_qs = lookup_querysets[lookup[:-1]]
lower_qs._obj = lower_qs._obj.prefetch_related(prefetch)
return self._wrap_queryset(lookup_querysets[()]) return lookup_querysets[()]
def _clone(self, **kwargs): def _clone(self, **kwargs):
clone = self._wrap_queryset(self._obj) clone = self._wrap_queryset(self._obj)