filtering==, validate_unique and created model inheritance

This commit is contained in:
Laura Klünder 2017-06-18 04:40:37 +02:00
parent 2611e20284
commit 93c33ce605
2 changed files with 77 additions and 13 deletions

View file

@ -1,6 +1,7 @@
import json
import typing
from collections import OrderedDict
from itertools import chain
from django.apps import apps
from django.conf import settings
@ -106,7 +107,8 @@ class ChangeSet(models.Model):
model = model._obj
obj = model()
obj.pk = 'c'+str(pk)
for name, value in self.created_objects[model][pk].items():
for name, value in chain(*(self.created_objects.get(submodel, {}).get(pk, {}).items()
for submodel in ModelWrapper.get_submodels(model))):
if name.startswith('title_'):
obj.titles[name[6:]] = value
continue
@ -126,6 +128,9 @@ class ChangeSet(models.Model):
continue
setattr(obj, name, model._meta.get_field(name).to_python(value))
break
else:
raise model.DoesNotExist
return self.wrap(obj, author=author)
def get_created_pks(self, model):

View file

@ -5,7 +5,7 @@ from functools import reduce, wraps
from itertools import chain
from django.db import models
from django.db.models import Manager, ManyToManyRel, Prefetch, Q
from django.db.models import Field, Manager, ManyToManyRel, Prefetch, Q
from django.db.models.fields.related_descriptors import ForwardManyToOneDescriptor, ManyToManyDescriptor
from django.db.models.query_utils import DeferredAttribute
from django.utils.functional import cached_property
@ -17,7 +17,8 @@ def is_created_pk(pk):
class BaseWrapper:
_not_wrapped = ('_changeset', '_author', '_obj', '_created_pks', '_result', '_extra', '_initial_values')
_allowed_callables = ('', )
_allowed_callables = ()
_wrapped_callables = ()
def __init__(self, changeset, obj, author=None):
self._changeset = changeset
@ -25,6 +26,12 @@ class BaseWrapper:
self._obj = obj
def _wrap_model(self, model):
if isinstance(model, type) and issubclass(model, ModelInstanceWrapper):
model = model._parent
if isinstance(model, ModelWrapper):
if self._author == model._author and self._changeset == model._changeset:
return model
model = model._obj
assert issubclass(model, models.Model)
return ModelWrapper(self._changeset, model, self._author)
@ -58,6 +65,13 @@ class BaseWrapper:
elif isinstance(value, type) and issubclass(value, Exception):
pass
elif callable(value) and name not in self._allowed_callables:
if name in self._wrapped_callables:
func = getattr(self._obj.__class__, name)
@wraps(func)
def wrapper(*args, **kwargs):
return func(self, *args, **kwargs)
return wrapper
if isinstance(self, ModelInstanceWrapper) and not hasattr(models.Model, name):
return value
raise TypeError('Can not call %s.%s wrapped!' % (type(self), name))
@ -74,12 +88,31 @@ class BaseWrapper:
class ModelWrapper(BaseWrapper):
_allowed_callables = ('EditorForm',)
_submodels_by_model = {}
def __eq__(self, other):
if type(other) == ModelWrapper:
return self._obj is other._obj
return self._obj is other
@classmethod
def get_submodels(cls, model):
try:
return cls._submodels_by_model[cls._obj]
except:
pass
all_models = model.__subclasses__()
models = []
if not model._meta.abstract:
models.append(model)
models.extend(chain(*(cls.get_submodels(model) for model in all_models)))
cls._submodels_by_model[model] = models
return models
@cached_property
def _submodels(self):
return self.get_submodels(self._obj)
def create_wrapped_model_class(self) -> typing.Type['ModelInstanceWrapper']:
# noinspection PyTypeChecker
return self.create_metaclass()(self._obj.__name__ + 'InstanceWrapper', (ModelInstanceWrapper,), {})
@ -94,6 +127,8 @@ class ModelWrapper(BaseWrapper):
parent = self
class ModelInstanceWrapperMeta(type):
_parent = parent
def __getattr__(self, name):
return getattr(parent, name)
@ -107,15 +142,21 @@ class ModelWrapper(BaseWrapper):
return ModelInstanceWrapperMeta
def __repr__(self):
return '<ModelWrapper '+repr(self._obj.__name__)+'>'
class ModelInstanceWrapper(BaseWrapper):
_allowed_callables = ('full_clean', 'validate_unique')
_allowed_callables = ('full_clean', '_perform_unique_checks', '_perform_date_checks')
_wrapped_callables = ('validate_unique', '_get_pk_val')
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
updates = self._changeset.updated_existing.get(type(self._obj), {}).get(self._obj.pk, {})
self._initial_values = {}
for field in self._obj._meta.get_fields():
if not isinstance(field, Field):
continue
if field.related_model is None:
if field.primary_key:
continue
@ -174,6 +215,10 @@ class ModelInstanceWrapper(BaseWrapper):
return '<%s #%s (created) from Changeset #%d>' % (cls_name, self.pk, self._changeset.pk)
return '<%s #%d (existing) with Changeset #%d>' % (cls_name, self.pk, self._changeset.pk)
def _get_unique_checks(self, exclude=None):
unique_checks, date_checks = self._obj.__class__._get_unique_checks(self, exclude=exclude)
return [(self._wrap_model(model), unique) for model, unique in unique_checks], date_checks
def save(self, author=None):
if author is None:
author = self._author
@ -236,10 +281,14 @@ class BaseQueryWrapper(BaseWrapper):
def __init__(self, changeset, obj, author=None, created_pks=None, extra=()):
super().__init__(changeset, obj, author)
if created_pks is None:
created_pks = self._changeset.get_created_pks(self._obj.model)
created_pks = self._get_initial_created_pks()
self._created_pks = created_pks
self._extra = extra
def _get_initial_created_pks(self):
self.model.get_submodels(self.model._obj)
return reduce(operator.or_, (self._changeset.get_created_pks(model) for model in self.model._submodels))
def _wrap_instance(self, instance):
return super()._wrap_instance(instance)
@ -300,18 +349,27 @@ class BaseQueryWrapper(BaseWrapper):
raise self._obj.model.MultipleObjectsReturned
raise self._obj.model.DoesNotExist
@get_queryset
def exists(self, *args, **kwargs):
if self._created_pks:
return True
return self._obj.exists()
@get_queryset
def order_by(self, *args):
return self._wrap_queryset(self._obj.order_by(*args))
def _filter_values(self, q, field_name, check):
other_values = self._changeset.get_changed_values(self._obj.model, field_name)
other_values = ()
models = [model for model in self.model._submodels]
for model in models:
other_values += self._changeset.get_changed_values(model, field_name)
add_pks = []
remove_pks = []
for pk, new_value in other_values:
(add_pks if check(new_value) else remove_pks).append(pk)
created_pks = set()
for pk, values in self._changeset.created_objects.get(self._obj.model, {}).items():
for pk, values in chain(*(self._changeset.created_objects.get(model, {}).items() for model in models)):
try:
if check(values[field_name]):
created_pks.add(pk)
@ -320,6 +378,7 @@ class BaseQueryWrapper(BaseWrapper):
pass
if check(getattr(self._changeset.get_created_object(self._obj.model, pk), field_name)):
created_pks.add(pk)
return (q & ~Q(pk__in=remove_pks)) | Q(pk__in=add_pks), created_pks
def _filter_kwarg(self, filter_name, filter_value):
@ -463,7 +522,7 @@ class BaseQueryWrapper(BaseWrapper):
if isinstance(class_value, DeferredAttribute):
if not segments:
raise NotImplementedError
return self._filter_values(q, field_name, lambda val: val == filter_value)
filter_type = segments.pop(0)
@ -486,7 +545,7 @@ class BaseQueryWrapper(BaseWrapper):
created_pks = reduce(operator.and_ if q.connector == 'AND' else operator.or_, created_pks)
if q.negated:
created_pks = self._changeset.get_created_pks(self._obj.model)-created_pks
created_pks = self._get_initial_created_pks()-created_pks
return result, created_pks
def _filter_or_exclude(self, negate, *args, **kwargs):
@ -498,7 +557,7 @@ class BaseQueryWrapper(BaseWrapper):
created_pks = reduce(operator.and_, created_pks)
if negate:
filters = (~Q(*filters), )
created_pks = self._changeset.get_created_pks(self._obj.model) - created_pks
created_pks = self._get_initial_created_pks()-created_pks
return self._wrap_queryset(self._obj.filter(*filters), created_pks=(self._created_pks & created_pks))
@get_queryset
@ -511,7 +570,7 @@ class BaseQueryWrapper(BaseWrapper):
@get_queryset
def count(self):
return self._obj.count()+len(tuple(self._get_created_objects()))
return self._obj.count()+len(tuple(self._get_created_objects(get_foreign_objects=False)))
@get_queryset
def values_list(self, *args, flat=False):
@ -545,8 +604,8 @@ class BaseQueryWrapper(BaseWrapper):
def _next_is_sticky(self):
return self._wrap_queryset(self._obj._next_is_sticky())
def _get_created_objects(self):
return (self._changeset.get_created_object(self._obj.model, pk, get_foreign_objects=True)
def _get_created_objects(self, get_foreign_objects=True):
return (self._changeset.get_created_object(self._obj.model, pk, get_foreign_objects=get_foreign_objects)
for pk in sorted(self._created_pks))
@queryset_only