From 93c33ce60583a8cd7ef289e845f286b7314099fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laura=20Kl=C3=BCnder?= Date: Sun, 18 Jun 2017 04:40:37 +0200 Subject: [PATCH] filtering==, validate_unique and created model inheritance --- src/c3nav/editor/models.py | 7 ++- src/c3nav/editor/wrappers.py | 83 ++++++++++++++++++++++++++++++------ 2 files changed, 77 insertions(+), 13 deletions(-) diff --git a/src/c3nav/editor/models.py b/src/c3nav/editor/models.py index 4f03c636..9b327666 100644 --- a/src/c3nav/editor/models.py +++ b/src/c3nav/editor/models.py @@ -1,6 +1,7 @@ import json import typing from collections import OrderedDict +from itertools import chain from django.apps import apps from django.conf import settings @@ -106,7 +107,8 @@ class ChangeSet(models.Model): model = model._obj obj = model() obj.pk = 'c'+str(pk) - for name, value in self.created_objects[model][pk].items(): + for name, value in chain(*(self.created_objects.get(submodel, {}).get(pk, {}).items() + for submodel in ModelWrapper.get_submodels(model))): if name.startswith('title_'): obj.titles[name[6:]] = value continue @@ -126,6 +128,9 @@ class ChangeSet(models.Model): continue setattr(obj, name, model._meta.get_field(name).to_python(value)) + break + else: + raise model.DoesNotExist return self.wrap(obj, author=author) def get_created_pks(self, model): diff --git a/src/c3nav/editor/wrappers.py b/src/c3nav/editor/wrappers.py index 0b4e2cbf..831bca33 100644 --- a/src/c3nav/editor/wrappers.py +++ b/src/c3nav/editor/wrappers.py @@ -5,7 +5,7 @@ from functools import reduce, wraps from itertools import chain from django.db import models -from django.db.models import Manager, ManyToManyRel, Prefetch, Q +from django.db.models import Field, Manager, ManyToManyRel, Prefetch, Q from django.db.models.fields.related_descriptors import ForwardManyToOneDescriptor, ManyToManyDescriptor from django.db.models.query_utils import DeferredAttribute from django.utils.functional import cached_property @@ -17,7 +17,8 @@ def is_created_pk(pk): class BaseWrapper: _not_wrapped = ('_changeset', '_author', '_obj', '_created_pks', '_result', '_extra', '_initial_values') - _allowed_callables = ('', ) + _allowed_callables = () + _wrapped_callables = () def __init__(self, changeset, obj, author=None): self._changeset = changeset @@ -25,6 +26,12 @@ class BaseWrapper: self._obj = obj def _wrap_model(self, model): + if isinstance(model, type) and issubclass(model, ModelInstanceWrapper): + model = model._parent + if isinstance(model, ModelWrapper): + if self._author == model._author and self._changeset == model._changeset: + return model + model = model._obj assert issubclass(model, models.Model) return ModelWrapper(self._changeset, model, self._author) @@ -58,6 +65,13 @@ class BaseWrapper: elif isinstance(value, type) and issubclass(value, Exception): pass elif callable(value) and name not in self._allowed_callables: + if name in self._wrapped_callables: + func = getattr(self._obj.__class__, name) + + @wraps(func) + def wrapper(*args, **kwargs): + return func(self, *args, **kwargs) + return wrapper if isinstance(self, ModelInstanceWrapper) and not hasattr(models.Model, name): return value raise TypeError('Can not call %s.%s wrapped!' % (type(self), name)) @@ -74,12 +88,31 @@ class BaseWrapper: class ModelWrapper(BaseWrapper): _allowed_callables = ('EditorForm',) + _submodels_by_model = {} def __eq__(self, other): if type(other) == ModelWrapper: return self._obj is other._obj return self._obj is other + @classmethod + def get_submodels(cls, model): + try: + return cls._submodels_by_model[cls._obj] + except: + pass + all_models = model.__subclasses__() + models = [] + if not model._meta.abstract: + models.append(model) + models.extend(chain(*(cls.get_submodels(model) for model in all_models))) + cls._submodels_by_model[model] = models + return models + + @cached_property + def _submodels(self): + return self.get_submodels(self._obj) + def create_wrapped_model_class(self) -> typing.Type['ModelInstanceWrapper']: # noinspection PyTypeChecker return self.create_metaclass()(self._obj.__name__ + 'InstanceWrapper', (ModelInstanceWrapper,), {}) @@ -94,6 +127,8 @@ class ModelWrapper(BaseWrapper): parent = self class ModelInstanceWrapperMeta(type): + _parent = parent + def __getattr__(self, name): return getattr(parent, name) @@ -107,15 +142,21 @@ class ModelWrapper(BaseWrapper): return ModelInstanceWrapperMeta + def __repr__(self): + return '' + class ModelInstanceWrapper(BaseWrapper): - _allowed_callables = ('full_clean', 'validate_unique') + _allowed_callables = ('full_clean', '_perform_unique_checks', '_perform_date_checks') + _wrapped_callables = ('validate_unique', '_get_pk_val') def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) updates = self._changeset.updated_existing.get(type(self._obj), {}).get(self._obj.pk, {}) self._initial_values = {} for field in self._obj._meta.get_fields(): + if not isinstance(field, Field): + continue if field.related_model is None: if field.primary_key: continue @@ -174,6 +215,10 @@ class ModelInstanceWrapper(BaseWrapper): return '<%s #%s (created) from Changeset #%d>' % (cls_name, self.pk, self._changeset.pk) return '<%s #%d (existing) with Changeset #%d>' % (cls_name, self.pk, self._changeset.pk) + def _get_unique_checks(self, exclude=None): + unique_checks, date_checks = self._obj.__class__._get_unique_checks(self, exclude=exclude) + return [(self._wrap_model(model), unique) for model, unique in unique_checks], date_checks + def save(self, author=None): if author is None: author = self._author @@ -236,10 +281,14 @@ class BaseQueryWrapper(BaseWrapper): def __init__(self, changeset, obj, author=None, created_pks=None, extra=()): super().__init__(changeset, obj, author) if created_pks is None: - created_pks = self._changeset.get_created_pks(self._obj.model) + created_pks = self._get_initial_created_pks() self._created_pks = created_pks self._extra = extra + def _get_initial_created_pks(self): + self.model.get_submodels(self.model._obj) + return reduce(operator.or_, (self._changeset.get_created_pks(model) for model in self.model._submodels)) + def _wrap_instance(self, instance): return super()._wrap_instance(instance) @@ -300,18 +349,27 @@ class BaseQueryWrapper(BaseWrapper): raise self._obj.model.MultipleObjectsReturned raise self._obj.model.DoesNotExist + @get_queryset + def exists(self, *args, **kwargs): + if self._created_pks: + return True + return self._obj.exists() + @get_queryset def order_by(self, *args): return self._wrap_queryset(self._obj.order_by(*args)) def _filter_values(self, q, field_name, check): - other_values = self._changeset.get_changed_values(self._obj.model, field_name) + other_values = () + models = [model for model in self.model._submodels] + for model in models: + other_values += self._changeset.get_changed_values(model, field_name) add_pks = [] remove_pks = [] for pk, new_value in other_values: (add_pks if check(new_value) else remove_pks).append(pk) created_pks = set() - for pk, values in self._changeset.created_objects.get(self._obj.model, {}).items(): + for pk, values in chain(*(self._changeset.created_objects.get(model, {}).items() for model in models)): try: if check(values[field_name]): created_pks.add(pk) @@ -320,6 +378,7 @@ class BaseQueryWrapper(BaseWrapper): pass if check(getattr(self._changeset.get_created_object(self._obj.model, pk), field_name)): created_pks.add(pk) + return (q & ~Q(pk__in=remove_pks)) | Q(pk__in=add_pks), created_pks def _filter_kwarg(self, filter_name, filter_value): @@ -463,7 +522,7 @@ class BaseQueryWrapper(BaseWrapper): if isinstance(class_value, DeferredAttribute): if not segments: - raise NotImplementedError + return self._filter_values(q, field_name, lambda val: val == filter_value) filter_type = segments.pop(0) @@ -486,7 +545,7 @@ class BaseQueryWrapper(BaseWrapper): created_pks = reduce(operator.and_ if q.connector == 'AND' else operator.or_, created_pks) if q.negated: - created_pks = self._changeset.get_created_pks(self._obj.model)-created_pks + created_pks = self._get_initial_created_pks()-created_pks return result, created_pks def _filter_or_exclude(self, negate, *args, **kwargs): @@ -498,7 +557,7 @@ class BaseQueryWrapper(BaseWrapper): created_pks = reduce(operator.and_, created_pks) if negate: filters = (~Q(*filters), ) - created_pks = self._changeset.get_created_pks(self._obj.model) - created_pks + created_pks = self._get_initial_created_pks()-created_pks return self._wrap_queryset(self._obj.filter(*filters), created_pks=(self._created_pks & created_pks)) @get_queryset @@ -511,7 +570,7 @@ class BaseQueryWrapper(BaseWrapper): @get_queryset def count(self): - return self._obj.count()+len(tuple(self._get_created_objects())) + return self._obj.count()+len(tuple(self._get_created_objects(get_foreign_objects=False))) @get_queryset def values_list(self, *args, flat=False): @@ -545,8 +604,8 @@ class BaseQueryWrapper(BaseWrapper): def _next_is_sticky(self): return self._wrap_queryset(self._obj._next_is_sticky()) - def _get_created_objects(self): - return (self._changeset.get_created_object(self._obj.model, pk, get_foreign_objects=True) + def _get_created_objects(self, get_foreign_objects=True): + return (self._changeset.get_created_object(self._obj.model, pk, get_foreign_objects=get_foreign_objects) for pk in sorted(self._created_pks)) @queryset_only