add created object to querysets (not yet fully working with select_related)

This commit is contained in:
Laura Klünder 2017-06-16 16:03:51 +02:00
parent 1d9564568b
commit 19856dfd8a
4 changed files with 128 additions and 51 deletions

View file

@ -74,26 +74,31 @@ class ChangeSet(models.Model):
self.created_objects[model][change.created_object_id][name].remove(value) self.created_objects[model][change.created_object_id][name].remove(value)
return return
pk = change.existing_object_pk
if change.action == 'update': if change.action == 'update':
self.updated_existing.setdefault(model, {}).setdefault(change.obj_pk, {})[name] = value self.updated_existing.setdefault(model, {}).setdefault(pk, {})[name] = value
elif change.action == 'm2m_add': elif change.action == 'm2m_add':
m2m_remove_existing = self.m2m_remove_existing.get(model, {}).get(change.obj_pk, ()) m2m_remove_existing = self.m2m_remove_existing.get(model, {}).get(pk, {}).get(name, ())
if value in m2m_remove_existing: if value in m2m_remove_existing:
m2m_remove_existing.remove(value) m2m_remove_existing.remove(value)
else: else:
self.m2m_add_existing.setdefault(model, {}).setdefault(change.obj_pk, set()).add(value) self.m2m_add_existing.setdefault(model, {}).setdefault(pk, {}).setdefault(name, set()).add(value)
elif change.action == 'm2m_remove': elif change.action == 'm2m_remove':
m2m_add_existing = self.m2m_add_existing.get(model, {}).get(change.obj_pk, ()) m2m_add_existing = self.m2m_add_existing.get(model, {}).get(pk, {}).get(name, ())
if value in m2m_add_existing: if value in m2m_add_existing:
m2m_add_existing.remove(value) m2m_add_existing.remove(value)
else: else:
self.m2m_remove_existing.setdefault(model, {}).setdefault(change.obj_pk, set()).add(value) self.m2m_remove_existing.setdefault(model, {}).setdefault(pk, {}).setdefault(name, set()).add(value)
def get_changed_values(self, model, name): def get_changed_values(self, model, name):
r = tuple((pk, values[name]) for pk, values in self.updated_existing.get(model, {}).items() if name in values) r = tuple((pk, values[name]) for pk, values in self.updated_existing.get(model, {}).items() if name in values)
return r return r
def get_created_object(self, model, pk, author=None): def get_created_values(self, model, name):
r = tuple((pk, values[name]) for pk, values in self.updated_existing.get(model, {}).items() if name in values)
return r
def get_created_object(self, model, pk, author=None, get_foreign_objects=False):
if isinstance(pk, str): if isinstance(pk, str):
pk = int(pk[1:]) pk = int(pk[1:])
self.parse_changes() self.parse_changes()
@ -112,14 +117,22 @@ class ChangeSet(models.Model):
continue continue
if isinstance(class_value, ForwardManyToOneDescriptor): if isinstance(class_value, ForwardManyToOneDescriptor):
setattr(obj, class_value.field.attname, pk) field = class_value.field
setattr(obj, field.attname, value)
if isinstance(pk, str): if isinstance(pk, str):
setattr(obj, class_value.cache_name, self.get_created_object(class_value.field.model, pk)) setattr(obj, class_value.cache_name, self.get_created_object(field.model, value))
elif get_foreign_objects:
setattr(obj, class_value.cache_name, self.wrap(field.related_model.objects.get(pk=value)))
continue continue
setattr(obj, name, model._meta.get_field(name).to_python(value)) setattr(obj, name, model._meta.get_field(name).to_python(value))
return self.wrap(obj, author=author) return self.wrap(obj, author=author)
def get_created_pks(self, model):
if issubclass(model, ModelWrapper):
model = model._obj
return set(self.created_objects.get(model, {}).keys())
@property @property
def cache_key(self): def cache_key(self):
return str(self.pk)+'-'+str(self._last_change_pk) return str(self.pk)+'-'+str(self._last_change_pk)

View file

@ -10,7 +10,7 @@ def add_editor_urls(model_name, parent_model_name=None, with_list=True, explicit
if parent_model_name: if parent_model_name:
parent_model = apps.get_model('mapdata', parent_model_name) parent_model = apps.get_model('mapdata', parent_model_name)
parent_model_name_plural = parent_model._meta.default_related_name parent_model_name_plural = parent_model._meta.default_related_name
prefix = (parent_model_name_plural+r'/(?P<'+parent_model_name.lower()+'>[0-9]+)/')+model_name_plural prefix = (parent_model_name_plural+r'/(?P<'+parent_model_name.lower()+'>c?[0-9]+)/')+model_name_plural
else: else:
prefix = model_name_plural prefix = model_name_plural
@ -22,7 +22,7 @@ def add_editor_urls(model_name, parent_model_name=None, with_list=True, explicit
if with_list: if with_list:
result.append(url(r'^'+prefix+r'/$', list_objects, name=name_prefix+'list', kwargs=kwargs)) result.append(url(r'^'+prefix+r'/$', list_objects, name=name_prefix+'list', kwargs=kwargs))
result.extend([ result.extend([
url(r'^'+prefix+r'/(?P<pk>\d+)/'+explicit_edit+'$', edit, name=name_prefix+'edit', kwargs=kwargs), url(r'^'+prefix+r'/(?P<pk>c?\d+)/'+explicit_edit+'$', edit, name=name_prefix+'edit', kwargs=kwargs),
url(r'^'+prefix+r'/create$', edit, name=name_prefix+'create', kwargs=kwargs), url(r'^'+prefix+r'/create$', edit, name=name_prefix+'create', kwargs=kwargs),
]) ])
return result return result
@ -30,9 +30,9 @@ def add_editor_urls(model_name, parent_model_name=None, with_list=True, explicit
urlpatterns = [ urlpatterns = [
url(r'^$', main_index, name='editor.index'), url(r'^$', main_index, name='editor.index'),
url(r'^levels/(?P<pk>[0-9]+)/$', level_detail, name='editor.levels.detail'), url(r'^levels/(?P<pk>c?[0-9]+)/$', level_detail, name='editor.levels.detail'),
url(r'^levels/(?P<level>[0-9]+)/spaces/(?P<pk>[0-9]+)/$', space_detail, name='editor.spaces.detail'), url(r'^levels/(?P<level>c?[0-9]+)/spaces/(?P<pk>c?[0-9]+)/$', space_detail, name='editor.spaces.detail'),
url(r'^levels/(?P<on_top_of>[0-9]+)/levels_on_top/create$', edit, name='editor.levels_on_top.create', url(r'^levels/(?P<on_top_of>c?[0-9]+)/levels_on_top/create$', edit, name='editor.levels_on_top.create',
kwargs={'model': 'Level'}), kwargs={'model': 'Level'}),
url(r'^changesets/(?P<pk>[0-9]+)/$', changeset_detail, name='editor.changesets.detail'), url(r'^changesets/(?P<pk>[0-9]+)/$', changeset_detail, name='editor.changesets.detail'),
] ]

View file

@ -72,6 +72,10 @@ def level_detail(request, pk):
def space_detail(request, level, pk): def space_detail(request, level, pk):
Space = request.changeset.wrap('Space') Space = request.changeset.wrap('Space')
space = get_object_or_404(Space.objects.select_related('level'), level__pk=level, pk=pk) space = get_object_or_404(Space.objects.select_related('level'), level__pk=level, pk=pk)
print('also!')
print(Space.objects.select_related('level').get(level__pk=level, pk=pk))
print(space)
print('aha')
return render(request, 'editor/space.html', { return render(request, 'editor/space.html', {
'level': space.level, 'level': space.level,
@ -293,6 +297,8 @@ def list_objects(request, model=None, level=None, space=None, explicit_edit=Fals
edit_url_name = request.resolver_match.url_name[:-4]+('detail' if explicit_edit else 'edit') edit_url_name = request.resolver_match.url_name[:-4]+('detail' if explicit_edit else 'edit')
for obj in queryset: for obj in queryset:
reverse_kwargs['pk'] = obj.pk reverse_kwargs['pk'] = obj.pk
print(reverse_kwargs)
print(reverse(edit_url_name, kwargs=reverse_kwargs))
obj.edit_url = reverse(edit_url_name, kwargs=reverse_kwargs) obj.edit_url = reverse(edit_url_name, kwargs=reverse_kwargs)
reverse_kwargs.pop('pk', None) reverse_kwargs.pop('pk', None)

View file

@ -1,15 +1,18 @@
import operator
import typing import typing
from collections import deque from collections import deque
from functools import reduce
from itertools import chain from itertools import chain
from django.db import models from django.db import models
from django.db.models import Manager, Prefetch, Q from django.db.models import Manager, Prefetch, Q
from django.db.models.fields.related_descriptors import ForwardManyToOneDescriptor, ManyToManyDescriptor from django.db.models.fields.related_descriptors import ForwardManyToOneDescriptor, ManyToManyDescriptor
from django.db.models.query_utils import DeferredAttribute from django.db.models.query_utils import DeferredAttribute
from django.utils.functional import cached_property
class BaseWrapper: class BaseWrapper:
_not_wrapped = ('_changeset', '_author', '_obj', '_changes_qs', '_initial_values', '_wrap_instances') _not_wrapped = ('_changeset', '_author', '_obj', '_created_pks', '_result', '_initial_values', '_wrap_instances')
_allowed_callables = ('', ) _allowed_callables = ('', )
def __init__(self, changeset, obj, author=None): def __init__(self, changeset, obj, author=None):
@ -206,22 +209,16 @@ class ModelInstanceWrapper(BaseWrapper):
self._changeset.add_delete(self, author=author) self._changeset.add_delete(self, author=author)
class ChangesQuerySet:
def __init__(self, changeset, model, author):
self._changeset = changeset
self._model = model
self._author = author
class BaseQueryWrapper(BaseWrapper): class BaseQueryWrapper(BaseWrapper):
_allowed_callables = ('_add_hints', '_next_is_sticky', 'get_prefetch_queryset') _allowed_callables = ('_add_hints', '_next_is_sticky', 'get_prefetch_queryset')
def __init__(self, changeset, obj, author=None, changes_qs=None, wrap_instances=True): def __init__(self, changeset, obj, author=None, created_pks=None, wrap_instances=True):
super().__init__(changeset, obj, author) super().__init__(changeset, obj, author)
if changes_qs is None: if created_pks is None:
changes_qs = ChangesQuerySet(changeset, obj.model, author) created_pks = self._changeset.get_created_pks(self._obj.model)
self._changes_qs = changes_qs self._created_pks = created_pks
self._wrap_instances = wrap_instances self._wrap_instances = wrap_instances
self._result = None
def get_queryset(self): def get_queryset(self):
return self._obj return self._obj
@ -231,18 +228,18 @@ class BaseQueryWrapper(BaseWrapper):
return super()._wrap_instance(instance) return super()._wrap_instance(instance)
return instance return instance
def _wrap_queryset(self, queryset, changes_qs=None, wrap_instances=None): def _wrap_queryset(self, queryset, created_pks=None, wrap_instances=None):
if changes_qs is None: if created_pks is None:
changes_qs = self._changes_qs created_pks = self._created_pks
if wrap_instances is None: if wrap_instances is None:
wrap_instances = self._wrap_instances wrap_instances = self._wrap_instances
return QuerySetWrapper(self._changeset, queryset, self._author, changes_qs, wrap_instances) return QuerySetWrapper(self._changeset, queryset, self._author, created_pks, wrap_instances)
def all(self): def all(self):
return self._wrap_queryset(self.get_queryset().all()) return self._wrap_queryset(self.get_queryset().all())
def none(self): def none(self):
return self._wrap_queryset(self.get_queryset().none()) return self._wrap_queryset(self.get_queryset().none(), created_pks=set())
def select_related(self, *args, **kwargs): def select_related(self, *args, **kwargs):
return self._wrap_queryset(self.get_queryset().select_related(*args, **kwargs)) return self._wrap_queryset(self.get_queryset().select_related(*args, **kwargs))
@ -266,8 +263,8 @@ class BaseQueryWrapper(BaseWrapper):
if len(results) == 1: if len(results) == 1:
return self._wrap_instance(results[0]) return self._wrap_instance(results[0])
if results: if results:
raise self._obj.model.DoesNotExist raise self._obj.model.MultipleObjectsReturned
raise self._obj.model.MultipleObjectsReturned raise self._obj.model.DoesNotExist
def order_by(self, *args): def order_by(self, *args):
return self._wrap_queryset(self.get_queryset().order_by(*args)) return self._wrap_queryset(self.get_queryset().order_by(*args))
@ -278,7 +275,21 @@ class BaseQueryWrapper(BaseWrapper):
remove_pks = [] remove_pks = []
for pk, new_value in other_values: for pk, new_value in other_values:
(add_pks if check(new_value) else remove_pks).append(pk) (add_pks if check(new_value) else remove_pks).append(pk)
return (q & ~Q(pk__in=remove_pks)) | Q(pk__in=add_pks) created_pks = set()
for pk, values in self._changeset.created_objects.get(self._obj.model, {}).items():
try:
if check(values[field_name]):
created_pks.add(pk)
continue
except AttributeError:
pass
if check(getattr(self._changeset.get_created_object(self._obj.model, pk), field_name)):
created_pks.add(pk)
return (q & ~Q(pk__in=remove_pks)) | Q(pk__in=add_pks), created_pks
@staticmethod
def is_created_pk(pk):
return isinstance(pk, str) and pk.startswith('c') and pk[1:].isnumeric()
def _filter_kwarg(self, filter_name, filter_value): def _filter_kwarg(self, filter_name, filter_value):
print(filter_name, '=', filter_value, sep='') print(filter_name, '=', filter_value, sep='')
@ -294,9 +305,12 @@ class BaseQueryWrapper(BaseWrapper):
if field_name == 'pk' or field_name == self._obj.model._meta.pk.name: if field_name == 'pk' or field_name == self._obj.model._meta.pk.name:
if not segments: if not segments:
return q if self.is_created_pk(filter_value):
else: return Q(pk__in=()), set([int(filter_value[1:])])
return q return q, set()
elif segments == ['in']:
return (Q(pk__in=tuple(pk for pk in filter_value if not self.is_created_pk(pk))),
set(int(pk[1:]) for pk in filter_value if self.is_created_pk(pk)))
if isinstance(class_value, ForwardManyToOneDescriptor): if isinstance(class_value, ForwardManyToOneDescriptor):
if not segments: if not segments:
@ -318,13 +332,18 @@ class BaseQueryWrapper(BaseWrapper):
filter_type = 'pk' filter_type = 'pk'
if filter_type == 'pk' and segments == ['in']: if filter_type == 'pk' and segments == ['in']:
return self._filter_values(q, field_name, lambda val: val in filter_value) q = Q(**{field_name+'__pk__in': tuple(pk for pk in filter_value if not self.is_created_pk(pk))})
filter_value = tuple(str(pk) for pk in filter_value)
return self._filter_values(q, field_name, lambda val: str(val) in filter_value)
if segments: if segments:
raise NotImplementedError raise NotImplementedError
if filter_type == 'pk': if filter_type == 'pk':
return self._filter_values(q, field_name, lambda val: val == filter_value) if self.is_created_pk(filter_value):
q = Q(pk__in=())
filter_value = str(filter_value)
return self._filter_values(q, field_name, lambda val: str(val) == filter_value)
if filter_type == 'isnull': if filter_type == 'isnull':
return self._filter_values(q, field_name, lambda val: (val is None) is filter_value) return self._filter_values(q, field_name, lambda val: (val is None) is filter_value)
@ -348,9 +367,23 @@ class BaseQueryWrapper(BaseWrapper):
if filter_type == 'pk': if filter_type == 'pk':
if class_value.reverse: if class_value.reverse:
# todo: implement this for created models
model = class_value.field.model model = class_value.field.model
return ((q & ~Q(pk__in=self._changeset.m2m_remove_existing.get(model, {}).get(filter_value, ()))) |
Q(pk__in=self._changeset.m2m_add_existing.get(model, {}).get(filter_value, ()))) if self.is_created_pk(filter_value):
filter_value = int(filter_value[1:])
pks = tuple(self._changeset.created_objects[model][filter_value].get(field_name, ()))
return (Q(pk__in=(pk for pk in pks if not self.is_created_pk(pk))),
set(int(pk[1:]) for pk in pks if self.is_created_pk(pk)))
def get_changeset_m2m(items):
return items.get(model, {}).get(filter_value, {}).get(field_name, ())
remove_pks = get_changeset_m2m(self._changeset.m2m_remove_existing)
add_pks = get_changeset_m2m(self._changeset.m2m_add_existing)
return (((q & ~Q(pk__in=(pk for pk in remove_pks if not self.is_created_pk(pk)))) |
Q(pk__in=(pk for pk in remove_pks if not self.is_created_pk(pk)))),
set(int(pk[1:]) for pk in add_pks if self.is_created_pk(pk)))
raise NotImplementedError raise NotImplementedError
@ -373,25 +406,36 @@ class BaseQueryWrapper(BaseWrapper):
raise NotImplementedError('cannot filter %s by %s (%s)' % (self._obj.model, filter_name, class_value)) raise NotImplementedError('cannot filter %s by %s (%s)' % (self._obj.model, filter_name, class_value))
def _filter_q(self, q): def _filter_q(self, q):
result = Q(*((self._filter_q(c) if isinstance(c, Q) else self._filter_kwarg(*c)) for c in q.children)) filters, created_pks = zip(*((self._filter_q(c) if isinstance(c, Q) else self._filter_kwarg(*c))
for c in q.children))
result = Q(*filters)
result.connector = q.connector result.connector = q.connector
result.negated = q.negated result.negated = q.negated
return result
def _filter(self, *args, **kwargs): created_pks = reduce(operator.and_ if q.connector == 'AND' else operator.or_, created_pks)
return chain( if q.negated:
created_pks = self._changeset.get_created_pks(self._obj.model)-created_pks
return result, created_pks
def _filter_or_exclude(self, negate, *args, **kwargs):
filters, created_pks = zip(*tuple(chain(
tuple(self._filter_q(q) for q in args), tuple(self._filter_q(q) for q in args),
tuple(self._filter_kwarg(name, value) for name, value in kwargs.items()) tuple(self._filter_kwarg(name, value) for name, value in kwargs.items())
) )))
created_pks = reduce(operator.and_, created_pks)
if negate:
created_pks = self._changeset.get_created_pks(self._obj.model) - created_pks
return self._wrap_queryset(self.get_queryset().filter(*filters), created_pks=(self._created_pks & created_pks))
def filter(self, *args, **kwargs): def filter(self, *args, **kwargs):
return self._wrap_queryset(self.get_queryset().filter(*self._filter(*args, **kwargs))) return self._filter_or_exclude(False, *args, **kwargs)
def exclude(self, *args, **kwargs): def exclude(self, *args, **kwargs):
return self._wrap_queryset(self.get_queryset().exclude(*self._filter(*args, **kwargs))) return self._filter_or_exclude(True, *args, **kwargs)
def count(self): def count(self):
return self.get_queryset().count() return self.get_queryset().count()+len(tuple(self._get_created_objects()))
def values_list(self, *args, flat=False): def values_list(self, *args, flat=False):
return self.get_queryset().values_list(*args, flat=flat) return self.get_queryset().values_list(*args, flat=flat)
@ -405,14 +449,28 @@ class BaseQueryWrapper(BaseWrapper):
def using(self, alias): def using(self, alias):
return self._wrap_queryset(self.get_queryset().using(alias)) return self._wrap_queryset(self.get_queryset().using(alias))
def _get_created_objects(self):
if not self._wrap_instances:
return ()
return (self._changeset.get_created_object(self._obj.model, pk, get_foreign_objects=True)
for pk in sorted(self._created_pks))
@cached_property
def _cached_result(self):
return (tuple(self._wrap_instance(instance) for instance in self.get_queryset()) +
tuple(self._get_created_objects()))
def __iter__(self): def __iter__(self):
return iter((self._wrap_instance(instance) for instance in self.get_queryset())) return iter(self._cached_result)
def iterator(self): def iterator(self):
return iter((self._wrap_instance(instance) for instance in self.get_queryset().iterator())) return iter(chain(
(self._wrap_instance(instance) for instance in self.get_queryset().iterator()),
self._get_created_objects(),
))
def __len__(self): def __len__(self):
return len(self.get_queryset()) return len(self._cached_result)
def create(self, *args, **kwargs): def create(self, *args, **kwargs):
obj = self.model(*args, **kwargs) obj = self.model(*args, **kwargs)