move querset_only methods to QuerySetWrapper

This commit is contained in:
Laura Klünder 2017-06-21 19:06:36 +02:00
parent da543f8ee1
commit 9159e8f6b9

View file

@ -338,19 +338,6 @@ def get_queryset(func):
return wrapper
def queryset_only(func):
"""
Wraps methods of BaseQueryWrapper that execute a queryset.
If self is a Manager, they throw an error, because you have to get a Queryset (e.g. using .all()) first.
"""
@wraps(func)
def wrapper(self, *args, **kwargs):
if hasattr(self, 'get_queryset'):
raise TypeError
return func(self, *args, **kwargs)
return wrapper
class BaseQueryWrapper(BaseWrapper):
"""
Base class for everything that wraps a QuerySet or manager.
@ -747,102 +734,11 @@ class BaseQueryWrapper(BaseWrapper):
"""
return self._wrap_queryset(self._obj._next_is_sticky())
def _get_created_objects(self, get_foreign_objects=True):
"""
Get ModelInstanceWrapper instance for all matched created objects.
"""
return (self._changeset.get_created_object(self._obj.model, pk, get_foreign_objects=get_foreign_objects)
for pk in sorted(self._created_pks))
@queryset_only
def _get_cached_result(self):
"""
Get results, make sure prefetch is prefetching and so on.
"""
obj = self._obj
obj._prefetch_done = True
obj._fetch_all()
result = [self._wrap_instance(instance) for instance in obj._result_cache]
obj._result_cache = result
obj._prefetch_done = False
obj._fetch_all()
result += list(self._get_created_objects())
for extra in self._extra:
# implementing the extra() call for prefetch_related
ex = extra[22:]
for f in self._obj.model._meta.get_fields():
if isinstance(f, ManyToManyRel) and f.through._meta.get_field(f.field.m2m_field_name()).attname == ex:
objs_by_pk = OrderedDict()
for instance in result:
objs_by_pk.setdefault(instance.pk, OrderedDict())[getattr(instance, extra, None)] = instance
m2m_added = self._changeset.m2m_added.get(f.field.model, {})
m2m_removed = self._changeset.m2m_removed.get(f.field.model, {})
for related_pk, changes in m2m_added.items():
for pk in changes.get(f.field.name, ()):
if pk in objs_by_pk and related_pk not in objs_by_pk[pk]:
new = self._wrap_instance(next(iter(objs_by_pk[pk].values()))._obj)
new.__dict__[extra] = related_pk
objs_by_pk[pk][related_pk] = new
for related_pk, changes in m2m_removed.items():
for pk in changes.get(f.field.name, ()):
if pk in objs_by_pk and related_pk in objs_by_pk[pk]:
objs_by_pk[pk].pop(related_pk)
for pk, instances in objs_by_pk.items():
instances.pop(None, None)
result = list(chain(*(instances.values() for instances in objs_by_pk.values())))
break
else:
raise NotImplementedError('Cannot do extra() for '+extra)
obj._result_cache = result
return result
@cached_property
def _cached_result(self):
return self._get_cached_result()
@property
def _result_cache(self):
return self._cached_result
@_result_cache.setter
def _result_cache(self, value):
# prefetch_related will try to set this property
# it has to overwrite our final result because it already contains the created objects
self.__dict__['_cached_result'] = value
@queryset_only
def __iter__(self):
return iter(self._cached_result)
@queryset_only
def iterator(self):
return iter(chain(
(self._wrap_instance(instance) for instance in self._obj.iterator()),
self._get_created_objects(),
))
@queryset_only
def __len__(self):
return len(self._cached_result)
def create(self, *args, **kwargs):
obj = self.model(*args, **kwargs)
obj.save()
return obj
@get_queryset
def delete(self):
for obj in self:
obj.delete()
class ManagerWrapper(BaseQueryWrapper):
"""
@ -856,6 +752,9 @@ class ManagerWrapper(BaseQueryWrapper):
qs = self._wrap_queryset(self._obj.model.objects.all())
return qs.exclude(pk__in=self._changeset.deleted_existing.get(self._obj.model, ()))
def delete(self):
self.get_queryset().delete()
class RelatedManagerWrapper(ManagerWrapper):
"""
@ -942,6 +841,92 @@ class QuerySetWrapper(BaseQueryWrapper):
"""
Wraps a queryset.
"""
def _get_created_objects(self, get_foreign_objects=True):
"""
Get ModelInstanceWrapper instance for all matched created objects.
"""
return (self._changeset.get_created_object(self._obj.model, pk, get_foreign_objects=get_foreign_objects)
for pk in sorted(self._created_pks))
def _get_cached_result(self):
"""
Get results, make sure prefetch is prefetching and so on.
"""
obj = self._obj
obj._prefetch_done = True
obj._fetch_all()
result = [self._wrap_instance(instance) for instance in obj._result_cache]
obj._result_cache = result
obj._prefetch_done = False
obj._fetch_all()
result += list(self._get_created_objects())
for extra in self._extra:
# implementing the extra() call for prefetch_related
ex = extra[22:]
for f in self._obj.model._meta.get_fields():
if isinstance(f, ManyToManyRel) and f.through._meta.get_field(f.field.m2m_field_name()).attname == ex:
objs_by_pk = OrderedDict()
for instance in result:
objs_by_pk.setdefault(instance.pk, OrderedDict())[getattr(instance, extra, None)] = instance
m2m_added = self._changeset.m2m_added.get(f.field.model, {})
m2m_removed = self._changeset.m2m_removed.get(f.field.model, {})
for related_pk, changes in m2m_added.items():
for pk in changes.get(f.field.name, ()):
if pk in objs_by_pk and related_pk not in objs_by_pk[pk]:
new = self._wrap_instance(next(iter(objs_by_pk[pk].values()))._obj)
new.__dict__[extra] = related_pk
objs_by_pk[pk][related_pk] = new
for related_pk, changes in m2m_removed.items():
for pk in changes.get(f.field.name, ()):
if pk in objs_by_pk and related_pk in objs_by_pk[pk]:
objs_by_pk[pk].pop(related_pk)
for pk, instances in objs_by_pk.items():
instances.pop(None, None)
result = list(chain(*(instances.values() for instances in objs_by_pk.values())))
break
else:
raise NotImplementedError('Cannot do extra() for ' + extra)
obj._result_cache = result
return result
@cached_property
def _cached_result(self):
return self._get_cached_result()
@property
def _result_cache(self):
return self._cached_result
@_result_cache.setter
def _result_cache(self, value):
# prefetch_related will try to set this property
# it has to overwrite our final result because it already contains the created objects
self.__dict__['_cached_result'] = value
def __iter__(self):
return iter(self._cached_result)
def iterator(self):
return iter(chain(
(self._wrap_instance(instance) for instance in self._obj.iterator()),
self._get_created_objects(),
))
def __len__(self):
return len(self._cached_result)
def delete(self):
for obj in self:
obj.delete()
@property
def _iterable_class(self):
return self._obj._iterable_class