From 3609420d8e1fa2612d30378619ab12be9f285815 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laura=20Kl=C3=BCnder?= Date: Tue, 13 Jun 2017 17:03:16 +0200 Subject: [PATCH] implement BaseQueryWrapper.prefetch_related --- src/c3nav/editor/wrappers.py | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/src/c3nav/editor/wrappers.py b/src/c3nav/editor/wrappers.py index ee6ed6bf..7e9ebe54 100644 --- a/src/c3nav/editor/wrappers.py +++ b/src/c3nav/editor/wrappers.py @@ -1,5 +1,7 @@ +from collections import deque + from django.db import models -from django.db.models import Manager +from django.db.models import Manager, Prefetch from django.db.models.fields.related_descriptors import ForwardManyToOneDescriptor @@ -163,6 +165,8 @@ class ChangesQuerySet: class BaseQueryWrapper(BaseWrapper): + _allowed_callables = ('_add_hints', '_next_is_sticky') + def __init__(self, changeset, obj, author=None, changes_qs=None): super().__init__(changeset, obj, author) if changes_qs is None: @@ -183,8 +187,17 @@ class BaseQueryWrapper(BaseWrapper): def select_related(self, *args, **kwargs): return self._wrap_queryset(self._obj.select_related(*args, **kwargs)) - def prefetch_related(self, *args, **kwargs): - return self._wrap_queryset(self._obj.prefetch_related(*args, **kwargs)) + def prefetch_related(self, *lookups): + new_lookups = deque() + for lookup in lookups: + if not isinstance(lookup, str): + new_lookups.append(lookup) + continue + model = self._obj.model + for name in lookup.split('__'): + model = model._meta.get_field(name).related_model + new_lookups.append(Prefetch(lookup, self._wrap_model(model).objects.all())) + return self._wrap_queryset(self._obj.prefetch_related(*new_lookups)) def get(self, **kwargs): return self._wrap_instance(self._obj.get(**kwargs)) @@ -212,6 +225,9 @@ class BaseQueryWrapper(BaseWrapper): first = self._wrap_instance(first) return first + def using(self, alias): + return self._wrap_queryset(self._obj.using(alias)) + def __iter__(self): return iter([instance for instance in self._obj]) @@ -260,4 +276,6 @@ class ManyRelatedManagerWrapper(ManagerWrapper): class QuerySetWrapper(BaseQueryWrapper): - pass + @property + def _iterable_class(self): + return self._obj._iterable_class