filtering==, validate_unique and created model inheritance

This commit is contained in:
Laura Klünder 2017-06-18 04:40:37 +02:00
parent 2611e20284
commit 93c33ce605
2 changed files with 77 additions and 13 deletions

View file

@ -1,6 +1,7 @@
import json import json
import typing import typing
from collections import OrderedDict from collections import OrderedDict
from itertools import chain
from django.apps import apps from django.apps import apps
from django.conf import settings from django.conf import settings
@ -106,7 +107,8 @@ class ChangeSet(models.Model):
model = model._obj model = model._obj
obj = model() obj = model()
obj.pk = 'c'+str(pk) obj.pk = 'c'+str(pk)
for name, value in self.created_objects[model][pk].items(): for name, value in chain(*(self.created_objects.get(submodel, {}).get(pk, {}).items()
for submodel in ModelWrapper.get_submodels(model))):
if name.startswith('title_'): if name.startswith('title_'):
obj.titles[name[6:]] = value obj.titles[name[6:]] = value
continue continue
@ -126,6 +128,9 @@ class ChangeSet(models.Model):
continue continue
setattr(obj, name, model._meta.get_field(name).to_python(value)) setattr(obj, name, model._meta.get_field(name).to_python(value))
break
else:
raise model.DoesNotExist
return self.wrap(obj, author=author) return self.wrap(obj, author=author)
def get_created_pks(self, model): def get_created_pks(self, model):

View file

@ -5,7 +5,7 @@ from functools import reduce, wraps
from itertools import chain from itertools import chain
from django.db import models from django.db import models
from django.db.models import Manager, ManyToManyRel, Prefetch, Q from django.db.models import Field, Manager, ManyToManyRel, Prefetch, Q
from django.db.models.fields.related_descriptors import ForwardManyToOneDescriptor, ManyToManyDescriptor from django.db.models.fields.related_descriptors import ForwardManyToOneDescriptor, ManyToManyDescriptor
from django.db.models.query_utils import DeferredAttribute from django.db.models.query_utils import DeferredAttribute
from django.utils.functional import cached_property from django.utils.functional import cached_property
@ -17,7 +17,8 @@ def is_created_pk(pk):
class BaseWrapper: class BaseWrapper:
_not_wrapped = ('_changeset', '_author', '_obj', '_created_pks', '_result', '_extra', '_initial_values') _not_wrapped = ('_changeset', '_author', '_obj', '_created_pks', '_result', '_extra', '_initial_values')
_allowed_callables = ('', ) _allowed_callables = ()
_wrapped_callables = ()
def __init__(self, changeset, obj, author=None): def __init__(self, changeset, obj, author=None):
self._changeset = changeset self._changeset = changeset
@ -25,6 +26,12 @@ class BaseWrapper:
self._obj = obj self._obj = obj
def _wrap_model(self, model): def _wrap_model(self, model):
if isinstance(model, type) and issubclass(model, ModelInstanceWrapper):
model = model._parent
if isinstance(model, ModelWrapper):
if self._author == model._author and self._changeset == model._changeset:
return model
model = model._obj
assert issubclass(model, models.Model) assert issubclass(model, models.Model)
return ModelWrapper(self._changeset, model, self._author) return ModelWrapper(self._changeset, model, self._author)
@ -58,6 +65,13 @@ class BaseWrapper:
elif isinstance(value, type) and issubclass(value, Exception): elif isinstance(value, type) and issubclass(value, Exception):
pass pass
elif callable(value) and name not in self._allowed_callables: elif callable(value) and name not in self._allowed_callables:
if name in self._wrapped_callables:
func = getattr(self._obj.__class__, name)
@wraps(func)
def wrapper(*args, **kwargs):
return func(self, *args, **kwargs)
return wrapper
if isinstance(self, ModelInstanceWrapper) and not hasattr(models.Model, name): if isinstance(self, ModelInstanceWrapper) and not hasattr(models.Model, name):
return value return value
raise TypeError('Can not call %s.%s wrapped!' % (type(self), name)) raise TypeError('Can not call %s.%s wrapped!' % (type(self), name))
@ -74,12 +88,31 @@ class BaseWrapper:
class ModelWrapper(BaseWrapper): class ModelWrapper(BaseWrapper):
_allowed_callables = ('EditorForm',) _allowed_callables = ('EditorForm',)
_submodels_by_model = {}
def __eq__(self, other): def __eq__(self, other):
if type(other) == ModelWrapper: if type(other) == ModelWrapper:
return self._obj is other._obj return self._obj is other._obj
return self._obj is other return self._obj is other
@classmethod
def get_submodels(cls, model):
try:
return cls._submodels_by_model[cls._obj]
except:
pass
all_models = model.__subclasses__()
models = []
if not model._meta.abstract:
models.append(model)
models.extend(chain(*(cls.get_submodels(model) for model in all_models)))
cls._submodels_by_model[model] = models
return models
@cached_property
def _submodels(self):
return self.get_submodels(self._obj)
def create_wrapped_model_class(self) -> typing.Type['ModelInstanceWrapper']: def create_wrapped_model_class(self) -> typing.Type['ModelInstanceWrapper']:
# noinspection PyTypeChecker # noinspection PyTypeChecker
return self.create_metaclass()(self._obj.__name__ + 'InstanceWrapper', (ModelInstanceWrapper,), {}) return self.create_metaclass()(self._obj.__name__ + 'InstanceWrapper', (ModelInstanceWrapper,), {})
@ -94,6 +127,8 @@ class ModelWrapper(BaseWrapper):
parent = self parent = self
class ModelInstanceWrapperMeta(type): class ModelInstanceWrapperMeta(type):
_parent = parent
def __getattr__(self, name): def __getattr__(self, name):
return getattr(parent, name) return getattr(parent, name)
@ -107,15 +142,21 @@ class ModelWrapper(BaseWrapper):
return ModelInstanceWrapperMeta return ModelInstanceWrapperMeta
def __repr__(self):
return '<ModelWrapper '+repr(self._obj.__name__)+'>'
class ModelInstanceWrapper(BaseWrapper): class ModelInstanceWrapper(BaseWrapper):
_allowed_callables = ('full_clean', 'validate_unique') _allowed_callables = ('full_clean', '_perform_unique_checks', '_perform_date_checks')
_wrapped_callables = ('validate_unique', '_get_pk_val')
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
updates = self._changeset.updated_existing.get(type(self._obj), {}).get(self._obj.pk, {}) updates = self._changeset.updated_existing.get(type(self._obj), {}).get(self._obj.pk, {})
self._initial_values = {} self._initial_values = {}
for field in self._obj._meta.get_fields(): for field in self._obj._meta.get_fields():
if not isinstance(field, Field):
continue
if field.related_model is None: if field.related_model is None:
if field.primary_key: if field.primary_key:
continue continue
@ -174,6 +215,10 @@ class ModelInstanceWrapper(BaseWrapper):
return '<%s #%s (created) from Changeset #%d>' % (cls_name, self.pk, self._changeset.pk) return '<%s #%s (created) from Changeset #%d>' % (cls_name, self.pk, self._changeset.pk)
return '<%s #%d (existing) with Changeset #%d>' % (cls_name, self.pk, self._changeset.pk) return '<%s #%d (existing) with Changeset #%d>' % (cls_name, self.pk, self._changeset.pk)
def _get_unique_checks(self, exclude=None):
unique_checks, date_checks = self._obj.__class__._get_unique_checks(self, exclude=exclude)
return [(self._wrap_model(model), unique) for model, unique in unique_checks], date_checks
def save(self, author=None): def save(self, author=None):
if author is None: if author is None:
author = self._author author = self._author
@ -236,10 +281,14 @@ class BaseQueryWrapper(BaseWrapper):
def __init__(self, changeset, obj, author=None, created_pks=None, extra=()): def __init__(self, changeset, obj, author=None, created_pks=None, extra=()):
super().__init__(changeset, obj, author) super().__init__(changeset, obj, author)
if created_pks is None: if created_pks is None:
created_pks = self._changeset.get_created_pks(self._obj.model) created_pks = self._get_initial_created_pks()
self._created_pks = created_pks self._created_pks = created_pks
self._extra = extra self._extra = extra
def _get_initial_created_pks(self):
self.model.get_submodels(self.model._obj)
return reduce(operator.or_, (self._changeset.get_created_pks(model) for model in self.model._submodels))
def _wrap_instance(self, instance): def _wrap_instance(self, instance):
return super()._wrap_instance(instance) return super()._wrap_instance(instance)
@ -300,18 +349,27 @@ class BaseQueryWrapper(BaseWrapper):
raise self._obj.model.MultipleObjectsReturned raise self._obj.model.MultipleObjectsReturned
raise self._obj.model.DoesNotExist raise self._obj.model.DoesNotExist
@get_queryset
def exists(self, *args, **kwargs):
if self._created_pks:
return True
return self._obj.exists()
@get_queryset @get_queryset
def order_by(self, *args): def order_by(self, *args):
return self._wrap_queryset(self._obj.order_by(*args)) return self._wrap_queryset(self._obj.order_by(*args))
def _filter_values(self, q, field_name, check): def _filter_values(self, q, field_name, check):
other_values = self._changeset.get_changed_values(self._obj.model, field_name) other_values = ()
models = [model for model in self.model._submodels]
for model in models:
other_values += self._changeset.get_changed_values(model, field_name)
add_pks = [] add_pks = []
remove_pks = [] remove_pks = []
for pk, new_value in other_values: for pk, new_value in other_values:
(add_pks if check(new_value) else remove_pks).append(pk) (add_pks if check(new_value) else remove_pks).append(pk)
created_pks = set() created_pks = set()
for pk, values in self._changeset.created_objects.get(self._obj.model, {}).items(): for pk, values in chain(*(self._changeset.created_objects.get(model, {}).items() for model in models)):
try: try:
if check(values[field_name]): if check(values[field_name]):
created_pks.add(pk) created_pks.add(pk)
@ -320,6 +378,7 @@ class BaseQueryWrapper(BaseWrapper):
pass pass
if check(getattr(self._changeset.get_created_object(self._obj.model, pk), field_name)): if check(getattr(self._changeset.get_created_object(self._obj.model, pk), field_name)):
created_pks.add(pk) created_pks.add(pk)
return (q & ~Q(pk__in=remove_pks)) | Q(pk__in=add_pks), created_pks return (q & ~Q(pk__in=remove_pks)) | Q(pk__in=add_pks), created_pks
def _filter_kwarg(self, filter_name, filter_value): def _filter_kwarg(self, filter_name, filter_value):
@ -463,7 +522,7 @@ class BaseQueryWrapper(BaseWrapper):
if isinstance(class_value, DeferredAttribute): if isinstance(class_value, DeferredAttribute):
if not segments: if not segments:
raise NotImplementedError return self._filter_values(q, field_name, lambda val: val == filter_value)
filter_type = segments.pop(0) filter_type = segments.pop(0)
@ -486,7 +545,7 @@ class BaseQueryWrapper(BaseWrapper):
created_pks = reduce(operator.and_ if q.connector == 'AND' else operator.or_, created_pks) created_pks = reduce(operator.and_ if q.connector == 'AND' else operator.or_, created_pks)
if q.negated: if q.negated:
created_pks = self._changeset.get_created_pks(self._obj.model)-created_pks created_pks = self._get_initial_created_pks()-created_pks
return result, created_pks return result, created_pks
def _filter_or_exclude(self, negate, *args, **kwargs): def _filter_or_exclude(self, negate, *args, **kwargs):
@ -498,7 +557,7 @@ class BaseQueryWrapper(BaseWrapper):
created_pks = reduce(operator.and_, created_pks) created_pks = reduce(operator.and_, created_pks)
if negate: if negate:
filters = (~Q(*filters), ) filters = (~Q(*filters), )
created_pks = self._changeset.get_created_pks(self._obj.model) - created_pks created_pks = self._get_initial_created_pks()-created_pks
return self._wrap_queryset(self._obj.filter(*filters), created_pks=(self._created_pks & created_pks)) return self._wrap_queryset(self._obj.filter(*filters), created_pks=(self._created_pks & created_pks))
@get_queryset @get_queryset
@ -511,7 +570,7 @@ class BaseQueryWrapper(BaseWrapper):
@get_queryset @get_queryset
def count(self): def count(self):
return self._obj.count()+len(tuple(self._get_created_objects())) return self._obj.count()+len(tuple(self._get_created_objects(get_foreign_objects=False)))
@get_queryset @get_queryset
def values_list(self, *args, flat=False): def values_list(self, *args, flat=False):
@ -545,8 +604,8 @@ class BaseQueryWrapper(BaseWrapper):
def _next_is_sticky(self): def _next_is_sticky(self):
return self._wrap_queryset(self._obj._next_is_sticky()) return self._wrap_queryset(self._obj._next_is_sticky())
def _get_created_objects(self): def _get_created_objects(self, get_foreign_objects=True):
return (self._changeset.get_created_object(self._obj.model, pk, get_foreign_objects=True) return (self._changeset.get_created_object(self._obj.model, pk, get_foreign_objects=get_foreign_objects)
for pk in sorted(self._created_pks)) for pk in sorted(self._created_pks))
@queryset_only @queryset_only