add created object to querysets (not yet fully working with select_related)

This commit is contained in:
Laura Klünder 2017-06-16 16:03:51 +02:00
parent 1d9564568b
commit 19856dfd8a
4 changed files with 128 additions and 51 deletions

View file

@ -74,26 +74,31 @@ class ChangeSet(models.Model):
self.created_objects[model][change.created_object_id][name].remove(value)
return
pk = change.existing_object_pk
if change.action == 'update':
self.updated_existing.setdefault(model, {}).setdefault(change.obj_pk, {})[name] = value
self.updated_existing.setdefault(model, {}).setdefault(pk, {})[name] = value
elif change.action == 'm2m_add':
m2m_remove_existing = self.m2m_remove_existing.get(model, {}).get(change.obj_pk, ())
m2m_remove_existing = self.m2m_remove_existing.get(model, {}).get(pk, {}).get(name, ())
if value in m2m_remove_existing:
m2m_remove_existing.remove(value)
else:
self.m2m_add_existing.setdefault(model, {}).setdefault(change.obj_pk, set()).add(value)
self.m2m_add_existing.setdefault(model, {}).setdefault(pk, {}).setdefault(name, set()).add(value)
elif change.action == 'm2m_remove':
m2m_add_existing = self.m2m_add_existing.get(model, {}).get(change.obj_pk, ())
m2m_add_existing = self.m2m_add_existing.get(model, {}).get(pk, {}).get(name, ())
if value in m2m_add_existing:
m2m_add_existing.remove(value)
else:
self.m2m_remove_existing.setdefault(model, {}).setdefault(change.obj_pk, set()).add(value)
self.m2m_remove_existing.setdefault(model, {}).setdefault(pk, {}).setdefault(name, set()).add(value)
def get_changed_values(self, model, name):
r = tuple((pk, values[name]) for pk, values in self.updated_existing.get(model, {}).items() if name in values)
return r
def get_created_object(self, model, pk, author=None):
def get_created_values(self, model, name):
r = tuple((pk, values[name]) for pk, values in self.updated_existing.get(model, {}).items() if name in values)
return r
def get_created_object(self, model, pk, author=None, get_foreign_objects=False):
if isinstance(pk, str):
pk = int(pk[1:])
self.parse_changes()
@ -112,14 +117,22 @@ class ChangeSet(models.Model):
continue
if isinstance(class_value, ForwardManyToOneDescriptor):
setattr(obj, class_value.field.attname, pk)
field = class_value.field
setattr(obj, field.attname, value)
if isinstance(pk, str):
setattr(obj, class_value.cache_name, self.get_created_object(class_value.field.model, pk))
setattr(obj, class_value.cache_name, self.get_created_object(field.model, value))
elif get_foreign_objects:
setattr(obj, class_value.cache_name, self.wrap(field.related_model.objects.get(pk=value)))
continue
setattr(obj, name, model._meta.get_field(name).to_python(value))
return self.wrap(obj, author=author)
def get_created_pks(self, model):
if issubclass(model, ModelWrapper):
model = model._obj
return set(self.created_objects.get(model, {}).keys())
@property
def cache_key(self):
return str(self.pk)+'-'+str(self._last_change_pk)

View file

@ -10,7 +10,7 @@ def add_editor_urls(model_name, parent_model_name=None, with_list=True, explicit
if parent_model_name:
parent_model = apps.get_model('mapdata', parent_model_name)
parent_model_name_plural = parent_model._meta.default_related_name
prefix = (parent_model_name_plural+r'/(?P<'+parent_model_name.lower()+'>[0-9]+)/')+model_name_plural
prefix = (parent_model_name_plural+r'/(?P<'+parent_model_name.lower()+'>c?[0-9]+)/')+model_name_plural
else:
prefix = model_name_plural
@ -22,7 +22,7 @@ def add_editor_urls(model_name, parent_model_name=None, with_list=True, explicit
if with_list:
result.append(url(r'^'+prefix+r'/$', list_objects, name=name_prefix+'list', kwargs=kwargs))
result.extend([
url(r'^'+prefix+r'/(?P<pk>\d+)/'+explicit_edit+'$', edit, name=name_prefix+'edit', kwargs=kwargs),
url(r'^'+prefix+r'/(?P<pk>c?\d+)/'+explicit_edit+'$', edit, name=name_prefix+'edit', kwargs=kwargs),
url(r'^'+prefix+r'/create$', edit, name=name_prefix+'create', kwargs=kwargs),
])
return result
@ -30,9 +30,9 @@ def add_editor_urls(model_name, parent_model_name=None, with_list=True, explicit
urlpatterns = [
url(r'^$', main_index, name='editor.index'),
url(r'^levels/(?P<pk>[0-9]+)/$', level_detail, name='editor.levels.detail'),
url(r'^levels/(?P<level>[0-9]+)/spaces/(?P<pk>[0-9]+)/$', space_detail, name='editor.spaces.detail'),
url(r'^levels/(?P<on_top_of>[0-9]+)/levels_on_top/create$', edit, name='editor.levels_on_top.create',
url(r'^levels/(?P<pk>c?[0-9]+)/$', level_detail, name='editor.levels.detail'),
url(r'^levels/(?P<level>c?[0-9]+)/spaces/(?P<pk>c?[0-9]+)/$', space_detail, name='editor.spaces.detail'),
url(r'^levels/(?P<on_top_of>c?[0-9]+)/levels_on_top/create$', edit, name='editor.levels_on_top.create',
kwargs={'model': 'Level'}),
url(r'^changesets/(?P<pk>[0-9]+)/$', changeset_detail, name='editor.changesets.detail'),
]

View file

@ -72,6 +72,10 @@ def level_detail(request, pk):
def space_detail(request, level, pk):
Space = request.changeset.wrap('Space')
space = get_object_or_404(Space.objects.select_related('level'), level__pk=level, pk=pk)
print('also!')
print(Space.objects.select_related('level').get(level__pk=level, pk=pk))
print(space)
print('aha')
return render(request, 'editor/space.html', {
'level': space.level,
@ -293,6 +297,8 @@ def list_objects(request, model=None, level=None, space=None, explicit_edit=Fals
edit_url_name = request.resolver_match.url_name[:-4]+('detail' if explicit_edit else 'edit')
for obj in queryset:
reverse_kwargs['pk'] = obj.pk
print(reverse_kwargs)
print(reverse(edit_url_name, kwargs=reverse_kwargs))
obj.edit_url = reverse(edit_url_name, kwargs=reverse_kwargs)
reverse_kwargs.pop('pk', None)

View file

@ -1,15 +1,18 @@
import operator
import typing
from collections import deque
from functools import reduce
from itertools import chain
from django.db import models
from django.db.models import Manager, Prefetch, Q
from django.db.models.fields.related_descriptors import ForwardManyToOneDescriptor, ManyToManyDescriptor
from django.db.models.query_utils import DeferredAttribute
from django.utils.functional import cached_property
class BaseWrapper:
_not_wrapped = ('_changeset', '_author', '_obj', '_changes_qs', '_initial_values', '_wrap_instances')
_not_wrapped = ('_changeset', '_author', '_obj', '_created_pks', '_result', '_initial_values', '_wrap_instances')
_allowed_callables = ('', )
def __init__(self, changeset, obj, author=None):
@ -206,22 +209,16 @@ class ModelInstanceWrapper(BaseWrapper):
self._changeset.add_delete(self, author=author)
class ChangesQuerySet:
def __init__(self, changeset, model, author):
self._changeset = changeset
self._model = model
self._author = author
class BaseQueryWrapper(BaseWrapper):
_allowed_callables = ('_add_hints', '_next_is_sticky', 'get_prefetch_queryset')
def __init__(self, changeset, obj, author=None, changes_qs=None, wrap_instances=True):
def __init__(self, changeset, obj, author=None, created_pks=None, wrap_instances=True):
super().__init__(changeset, obj, author)
if changes_qs is None:
changes_qs = ChangesQuerySet(changeset, obj.model, author)
self._changes_qs = changes_qs
if created_pks is None:
created_pks = self._changeset.get_created_pks(self._obj.model)
self._created_pks = created_pks
self._wrap_instances = wrap_instances
self._result = None
def get_queryset(self):
return self._obj
@ -231,18 +228,18 @@ class BaseQueryWrapper(BaseWrapper):
return super()._wrap_instance(instance)
return instance
def _wrap_queryset(self, queryset, changes_qs=None, wrap_instances=None):
if changes_qs is None:
changes_qs = self._changes_qs
def _wrap_queryset(self, queryset, created_pks=None, wrap_instances=None):
if created_pks is None:
created_pks = self._created_pks
if wrap_instances is None:
wrap_instances = self._wrap_instances
return QuerySetWrapper(self._changeset, queryset, self._author, changes_qs, wrap_instances)
return QuerySetWrapper(self._changeset, queryset, self._author, created_pks, wrap_instances)
def all(self):
return self._wrap_queryset(self.get_queryset().all())
def none(self):
return self._wrap_queryset(self.get_queryset().none())
return self._wrap_queryset(self.get_queryset().none(), created_pks=set())
def select_related(self, *args, **kwargs):
return self._wrap_queryset(self.get_queryset().select_related(*args, **kwargs))
@ -266,8 +263,8 @@ class BaseQueryWrapper(BaseWrapper):
if len(results) == 1:
return self._wrap_instance(results[0])
if results:
raise self._obj.model.DoesNotExist
raise self._obj.model.MultipleObjectsReturned
raise self._obj.model.MultipleObjectsReturned
raise self._obj.model.DoesNotExist
def order_by(self, *args):
return self._wrap_queryset(self.get_queryset().order_by(*args))
@ -278,7 +275,21 @@ class BaseQueryWrapper(BaseWrapper):
remove_pks = []
for pk, new_value in other_values:
(add_pks if check(new_value) else remove_pks).append(pk)
return (q & ~Q(pk__in=remove_pks)) | Q(pk__in=add_pks)
created_pks = set()
for pk, values in self._changeset.created_objects.get(self._obj.model, {}).items():
try:
if check(values[field_name]):
created_pks.add(pk)
continue
except AttributeError:
pass
if check(getattr(self._changeset.get_created_object(self._obj.model, pk), field_name)):
created_pks.add(pk)
return (q & ~Q(pk__in=remove_pks)) | Q(pk__in=add_pks), created_pks
@staticmethod
def is_created_pk(pk):
return isinstance(pk, str) and pk.startswith('c') and pk[1:].isnumeric()
def _filter_kwarg(self, filter_name, filter_value):
print(filter_name, '=', filter_value, sep='')
@ -294,9 +305,12 @@ class BaseQueryWrapper(BaseWrapper):
if field_name == 'pk' or field_name == self._obj.model._meta.pk.name:
if not segments:
return q
else:
return q
if self.is_created_pk(filter_value):
return Q(pk__in=()), set([int(filter_value[1:])])
return q, set()
elif segments == ['in']:
return (Q(pk__in=tuple(pk for pk in filter_value if not self.is_created_pk(pk))),
set(int(pk[1:]) for pk in filter_value if self.is_created_pk(pk)))
if isinstance(class_value, ForwardManyToOneDescriptor):
if not segments:
@ -318,13 +332,18 @@ class BaseQueryWrapper(BaseWrapper):
filter_type = 'pk'
if filter_type == 'pk' and segments == ['in']:
return self._filter_values(q, field_name, lambda val: val in filter_value)
q = Q(**{field_name+'__pk__in': tuple(pk for pk in filter_value if not self.is_created_pk(pk))})
filter_value = tuple(str(pk) for pk in filter_value)
return self._filter_values(q, field_name, lambda val: str(val) in filter_value)
if segments:
raise NotImplementedError
if filter_type == 'pk':
return self._filter_values(q, field_name, lambda val: val == filter_value)
if self.is_created_pk(filter_value):
q = Q(pk__in=())
filter_value = str(filter_value)
return self._filter_values(q, field_name, lambda val: str(val) == filter_value)
if filter_type == 'isnull':
return self._filter_values(q, field_name, lambda val: (val is None) is filter_value)
@ -348,9 +367,23 @@ class BaseQueryWrapper(BaseWrapper):
if filter_type == 'pk':
if class_value.reverse:
# todo: implement this for created models
model = class_value.field.model
return ((q & ~Q(pk__in=self._changeset.m2m_remove_existing.get(model, {}).get(filter_value, ()))) |
Q(pk__in=self._changeset.m2m_add_existing.get(model, {}).get(filter_value, ())))
if self.is_created_pk(filter_value):
filter_value = int(filter_value[1:])
pks = tuple(self._changeset.created_objects[model][filter_value].get(field_name, ()))
return (Q(pk__in=(pk for pk in pks if not self.is_created_pk(pk))),
set(int(pk[1:]) for pk in pks if self.is_created_pk(pk)))
def get_changeset_m2m(items):
return items.get(model, {}).get(filter_value, {}).get(field_name, ())
remove_pks = get_changeset_m2m(self._changeset.m2m_remove_existing)
add_pks = get_changeset_m2m(self._changeset.m2m_add_existing)
return (((q & ~Q(pk__in=(pk for pk in remove_pks if not self.is_created_pk(pk)))) |
Q(pk__in=(pk for pk in remove_pks if not self.is_created_pk(pk)))),
set(int(pk[1:]) for pk in add_pks if self.is_created_pk(pk)))
raise NotImplementedError
@ -373,25 +406,36 @@ class BaseQueryWrapper(BaseWrapper):
raise NotImplementedError('cannot filter %s by %s (%s)' % (self._obj.model, filter_name, class_value))
def _filter_q(self, q):
result = Q(*((self._filter_q(c) if isinstance(c, Q) else self._filter_kwarg(*c)) for c in q.children))
filters, created_pks = zip(*((self._filter_q(c) if isinstance(c, Q) else self._filter_kwarg(*c))
for c in q.children))
result = Q(*filters)
result.connector = q.connector
result.negated = q.negated
return result
def _filter(self, *args, **kwargs):
return chain(
created_pks = reduce(operator.and_ if q.connector == 'AND' else operator.or_, created_pks)
if q.negated:
created_pks = self._changeset.get_created_pks(self._obj.model)-created_pks
return result, created_pks
def _filter_or_exclude(self, negate, *args, **kwargs):
filters, created_pks = zip(*tuple(chain(
tuple(self._filter_q(q) for q in args),
tuple(self._filter_kwarg(name, value) for name, value in kwargs.items())
)
)))
created_pks = reduce(operator.and_, created_pks)
if negate:
created_pks = self._changeset.get_created_pks(self._obj.model) - created_pks
return self._wrap_queryset(self.get_queryset().filter(*filters), created_pks=(self._created_pks & created_pks))
def filter(self, *args, **kwargs):
return self._wrap_queryset(self.get_queryset().filter(*self._filter(*args, **kwargs)))
return self._filter_or_exclude(False, *args, **kwargs)
def exclude(self, *args, **kwargs):
return self._wrap_queryset(self.get_queryset().exclude(*self._filter(*args, **kwargs)))
return self._filter_or_exclude(True, *args, **kwargs)
def count(self):
return self.get_queryset().count()
return self.get_queryset().count()+len(tuple(self._get_created_objects()))
def values_list(self, *args, flat=False):
return self.get_queryset().values_list(*args, flat=flat)
@ -405,14 +449,28 @@ class BaseQueryWrapper(BaseWrapper):
def using(self, alias):
return self._wrap_queryset(self.get_queryset().using(alias))
def _get_created_objects(self):
if not self._wrap_instances:
return ()
return (self._changeset.get_created_object(self._obj.model, pk, get_foreign_objects=True)
for pk in sorted(self._created_pks))
@cached_property
def _cached_result(self):
return (tuple(self._wrap_instance(instance) for instance in self.get_queryset()) +
tuple(self._get_created_objects()))
def __iter__(self):
return iter((self._wrap_instance(instance) for instance in self.get_queryset()))
return iter(self._cached_result)
def iterator(self):
return iter((self._wrap_instance(instance) for instance in self.get_queryset().iterator()))
return iter(chain(
(self._wrap_instance(instance) for instance in self.get_queryset().iterator()),
self._get_created_objects(),
))
def __len__(self):
return len(self.get_queryset())
return len(self._cached_result)
def create(self, *args, **kwargs):
obj = self.model(*args, **kwargs)