reimplement wrapped prefetch_related

This commit is contained in:
Laura Klünder 2017-06-16 18:19:52 +02:00
parent 19856dfd8a
commit 8fef2a81a9
5 changed files with 35 additions and 20 deletions

View file

@ -88,6 +88,9 @@ class EditorViewSet(ViewSet):
self._get_level_geometries(level),
*(self._get_level_geometries(s) for s in levels_on_top)
)
results = tuple(results)
for result in results:
print(type(result).__name__)
return Response([obj.to_geojson() for obj in results])
elif space is not None:

View file

@ -121,7 +121,7 @@ class ChangeSet(models.Model):
setattr(obj, field.attname, value)
if isinstance(pk, str):
setattr(obj, class_value.cache_name, self.get_created_object(field.model, value))
elif get_foreign_objects:
elif get_foreign_objects or True:
setattr(obj, class_value.cache_name, self.wrap(field.related_model.objects.get(pk=value)))
continue

View file

@ -1,6 +1,5 @@
import operator
import typing
from collections import deque
from functools import reduce
from itertools import chain
@ -231,6 +230,8 @@ class BaseQueryWrapper(BaseWrapper):
def _wrap_queryset(self, queryset, created_pks=None, wrap_instances=None):
if created_pks is None:
created_pks = self._created_pks
if created_pks is False:
created_pks = None
if wrap_instances is None:
wrap_instances = self._wrap_instances
return QuerySetWrapper(self._changeset, queryset, self._author, created_pks, wrap_instances)
@ -239,24 +240,38 @@ class BaseQueryWrapper(BaseWrapper):
return self._wrap_queryset(self.get_queryset().all())
def none(self):
return self._wrap_queryset(self.get_queryset().none(), created_pks=set())
return self._wrap_queryset(self.get_queryset().none())
def select_related(self, *args, **kwargs):
return self._wrap_queryset(self.get_queryset().select_related(*args, **kwargs))
def prefetch_related(self, *lookups):
new_lookups = deque()
for lookup in lookups:
if not isinstance(lookup, str):
new_lookups.append(lookup)
continue
model = self._obj.model
for name in lookup.split('__'):
model = model._meta.get_field(name).related_model
qs = self._wrap_model(model).objects.all()
qs._wrap_instances = False
new_lookups.append(Prefetch(lookup, qs))
return self._wrap_queryset(self.get_queryset().prefetch_related(*new_lookups))
lookups_splitted = tuple(tuple(lookup.split('__')) for lookup in lookups)
max_depth = max(len(lookup) for lookup in lookups_splitted)
lookups_by_depth = []
for i in range(max_depth):
lookups_by_depth.append(set(tuple(lookup[:i+1] for lookup in lookups_splitted if len(lookup) > i)))
lookup_models = {(): self._obj.model}
lookup_querysets = {(): self.get_queryset()}
for depth_lookups in lookups_by_depth:
for lookup in depth_lookups:
model = lookup_models[lookup[:-1]]._meta.get_field(lookup[-1]).related_model
lookup_models[lookup] = model
lookup_querysets[lookup] = self._wrap_model(model).objects.all()._obj
for depth_lookups in reversed(lookups_by_depth):
for lookup in depth_lookups:
qs = self._wrap_queryset(lookup_querysets[lookup], wrap_instances=True, created_pks=False)
prefetch = Prefetch(lookup[-1], qs)
lookup_querysets[lookup[:-1]] = lookup_querysets[lookup[:-1]].prefetch_related(prefetch)
return self._wrap_queryset(lookup_querysets[()])
def _clone(self, **kwargs):
clone = self._wrap_queryset(self.get_queryset())
clone._obj.__dict__.update(kwargs)
return clone
def get(self, *args, **kwargs):
results = tuple(self.filter(*args, **kwargs))
@ -367,7 +382,6 @@ class BaseQueryWrapper(BaseWrapper):
if filter_type == 'pk':
if class_value.reverse:
# todo: implement this for created models
model = class_value.field.model
if self.is_created_pk(filter_value):
@ -450,8 +464,6 @@ class BaseQueryWrapper(BaseWrapper):
return self._wrap_queryset(self.get_queryset().using(alias))
def _get_created_objects(self):
if not self._wrap_instances:
return ()
return (self._changeset.get_created_object(self._obj.model, pk, get_foreign_objects=True)
for pk in sorted(self._created_pks))

View file

@ -26,7 +26,7 @@ class SerializableMixin(models.Model):
result = OrderedDict()
if include_type:
result['type'] = self.__class__.__name__.lower()
result['id'] = self.id
result['id'] = self.pk
return result

View file

@ -20,7 +20,7 @@ class GeometryMixin(EditorFormMixin):
def get_geojson_properties(self) -> dict:
result = OrderedDict((
('type', self.__class__.__name__.lower()),
('id', self.id),
('id', self.pk),
))
if getattr(self, 'bounds', False):
result['bounds'] = True