document wrappers

This commit is contained in:
Laura Klünder 2017-06-21 19:01:00 +02:00
parent 8306345856
commit da543f8ee1

View file

@ -15,6 +15,14 @@ from c3nav.editor.utils import is_created_pk
class BaseWrapper:
"""
Base Class for all wrappers.
Saves wrapped object along with the changeset and the author for new changes.
getattr, setattr and delattr will be forwarded to the object, exceptions are specified in _not_wrapped.
If the value of an attribute is a model, model instance, manager or queryset, it will be wrapped, to.
Callables will only be returned be getattr when they are inside _allowed_callables.
Callables in _wrapped_callables will be returned wrapped, so that their self if the wrapping instance.
"""
_not_wrapped = ('_changeset', '_author', '_obj', '_created_pks', '_result', '_extra', '_result_cache',
'_initial_values')
_allowed_callables = ()
@ -27,6 +35,9 @@ class BaseWrapper:
# noinspection PyUnresolvedReferences
def _wrap_model(self, model):
"""
Wrap a model, with same changeset and author as this wrapper.
"""
if isinstance(model, type) and issubclass(model, ModelInstanceWrapper):
model = model._parent
if isinstance(model, ModelWrapper):
@ -37,6 +48,9 @@ class BaseWrapper:
return ModelWrapper(self._changeset, model, self._author)
def _wrap_instance(self, instance):
"""
Wrap a model instance, with same changeset and author as this wrapper.
"""
if isinstance(instance, ModelInstanceWrapper):
if self._author == instance._author and self._changeset == instance._changeset:
return instance
@ -45,6 +59,10 @@ class BaseWrapper:
return self._wrap_model(type(instance)).create_wrapped_model_class()(self._changeset, instance, self._author)
def _wrap_manager(self, manager):
"""
Wrap a manager, with same changeset and author as this wrapper.
Detects RelatedManager or ManyRelatedmanager instances and chooses the Wrapper accordingly.
"""
assert isinstance(manager, Manager)
if hasattr(manager, 'through'):
return ManyRelatedManagerWrapper(self._changeset, manager, self._author)
@ -53,6 +71,9 @@ class BaseWrapper:
return ManagerWrapper(self._changeset, manager, self._author)
def _wrap_queryset(self, queryset):
"""
Wrap a queryset, with same changeset and author as this wrapper.
"""
return QuerySetWrapper(self._changeset, queryset, self._author)
def __getattr__(self, name):
@ -88,7 +109,12 @@ class BaseWrapper:
class ModelWrapper(BaseWrapper):
_allowed_callables = ('EditorForm',)
"""
Wraps a model class.
Can be compared to other wrapped or non-wrapped model classes.
Can be called (like a class) to get a wrapped model instance
that has the according ModelWrapper as its type / metaclass.
"""
_submodels_by_model = {}
def __eq__(self, other):
@ -96,12 +122,20 @@ class ModelWrapper(BaseWrapper):
return self._obj is other._obj
return self._obj is other
# noinspection PyPep8Naming
@cached_property
def EditorForm(self):
"""
Returns an editor form for this model.
"""
return create_editor_form(self._obj)
@classmethod
def get_submodels(cls, model):
def get_submodels(cls, model: models.Model):
"""
Get non-abstract submodels for a model including the model itself.
Result is cached.
"""
try:
return cls._submodels_by_model[model]
except KeyError:
@ -116,19 +150,31 @@ class ModelWrapper(BaseWrapper):
@cached_property
def _submodels(self):
"""
Get non-abstract submodels for this model including the model itself.
"""
return self.get_submodels(self._obj)
def create_wrapped_model_class(self) -> typing.Type['ModelInstanceWrapper']:
"""
Return a ModelInstanceWrapper that has a proxy to this instance as its type / metaclass. #voodoo
"""
# noinspection PyTypeChecker
return self.create_metaclass()(self._obj.__name__ + 'InstanceWrapper', (ModelInstanceWrapper,), {})
def __call__(self, **kwargs):
"""
Create a wrapped instance of this model. _wrap_instance will call create_wrapped_model_class().
"""
instance = self._wrap_instance(self._obj())
for name, value in kwargs.items():
setattr(instance, name, value)
return instance
def create_metaclass(self):
"""
Create the proxy metaclass for craeate_wrapped_model_class().
"""
parent = self
class ModelInstanceWrapperMeta(type):
@ -152,10 +198,20 @@ class ModelWrapper(BaseWrapper):
class ModelInstanceWrapper(BaseWrapper):
"""
Wraps a model instance. Don't use this directly, call a ModelWrapper instead / use ChangeSet.wrap().
Creates changes in changeset when save() is called.
Updates updated values on existing objects on init.
Can be compared to other wrapped or non-wrapped model instances.
"""
_allowed_callables = ('full_clean', '_perform_unique_checks', '_perform_date_checks')
_wrapped_callables = ('validate_unique', '_get_pk_val')
def __init__(self, *args, **kwargs):
"""
Get initial values of this instance, so we know what changed on save.
Updates values according to cangeset if this is an existing object.
"""
super().__init__(*args, **kwargs)
updates = self._changeset.updated_existing.get(type(self._obj), {}).get(self._obj.pk, {})
self._initial_values = {}
@ -199,6 +255,10 @@ class ModelInstanceWrapper(BaseWrapper):
return self.pk == other.pk
def __setattr__(self, name, value):
"""
We have to intercept here because RelatedFields won't accept
Wrapped model instances values, so we have to trick them.
"""
if name in self._not_wrapped:
return super().__setattr__(name, value)
class_value = getattr(type(self._obj), name, None)
@ -225,6 +285,9 @@ class ModelInstanceWrapper(BaseWrapper):
return [(self._wrap_model(model), unique) for model, unique in unique_checks], date_checks
def save(self, author=None):
"""
Create changes in changeset instead of saving.
"""
if author is None:
author = self._author
if self.pk is None:
@ -263,6 +326,10 @@ class ModelInstanceWrapper(BaseWrapper):
def get_queryset(func):
"""
Wraps methods of BaseQueryWrapper that manipulate a queryset.
If self is a Manager, not an object, preceed the method call with a filter call according to the manager.
"""
@wraps(func)
def wrapper(self, *args, **kwargs):
if hasattr(self, 'get_queryset'):
@ -272,6 +339,10 @@ def get_queryset(func):
def queryset_only(func):
"""
Wraps methods of BaseQueryWrapper that execute a queryset.
If self is a Manager, they throw an error, because you have to get a Queryset (e.g. using .all()) first.
"""
@wraps(func)
def wrapper(self, *args, **kwargs):
if hasattr(self, 'get_queryset'):
@ -281,6 +352,13 @@ def queryset_only(func):
class BaseQueryWrapper(BaseWrapper):
"""
Base class for everything that wraps a QuerySet or manager.
Don't use this directly, but via WrappedModel.objects or WrappedInstance.groups or similar.
Intercepts all query methods to exclude ids / include ids for each filter according to changeset changes.
Keeps track of which created objects the current filtering still applies to.
When evaluated, just does everything as if the queryset was applied to the databse.
"""
_allowed_callables = ('_add_hints', 'get_prefetch_queryset', '_apply_rel_filters')
def __init__(self, changeset, obj, author=None, created_pks=None, extra=()):
@ -291,13 +369,18 @@ class BaseQueryWrapper(BaseWrapper):
self._extra = extra
def _get_initial_created_pks(self):
"""
Get all created pks for this query's model an submodels.
"""
self.model.get_submodels(self.model._obj)
return reduce(operator.or_, (self._changeset.get_created_pks(model) for model in self.model._submodels))
def _wrap_instance(self, instance):
return super()._wrap_instance(instance)
def _wrap_queryset(self, queryset, created_pks=None, add_extra=()):
"""
Wraps a queryset, usually after manipulating the current one.
:param created_pks: set of created pks to be still in the next queryset (the same ones as this one by default)
:param add_extra: extra() calls that have been added to the query
"""
if created_pks is None:
created_pks = self._created_pks
if created_pks is False:
@ -318,6 +401,11 @@ class BaseQueryWrapper(BaseWrapper):
@get_queryset
def prefetch_related(self, *lookups):
"""
We split up all prefetch related lookups into one-level prefetches
and convert them into Prefetch() objects with custom querysets.
This makes sure that the prefetch also happens on the virtually modified database.
"""
lookups_splitted = tuple(tuple(lookup.split('__')) for lookup in lookups)
max_depth = max(len(lookup) for lookup in lookups_splitted)
lookups_by_depth = []
@ -362,9 +450,19 @@ class BaseQueryWrapper(BaseWrapper):
@get_queryset
def order_by(self, *args):
"""
Order by is not yet supported on created instances because this is not needed so far.
"""
return self._wrap_queryset(self._obj.order_by(*args))
def _filter_values(self, q, field_name, check):
"""
Filter by value.
:param q: base Q object to give to the database and to modify
:param field_name: name of the field whose value should be compared
:param check: comparision function that only gets the new value
:return: new Q object and set of matched existing pks
"""
other_values = ()
submodels = [model for model in self.model._submodels]
for model in submodels:
@ -382,6 +480,11 @@ class BaseQueryWrapper(BaseWrapper):
return (q & ~Q(pk__in=remove_pks)) | Q(pk__in=add_pks), created_pks
def _filter_kwarg(self, filter_name, filter_value):
"""
filter by kwarg.
The core filtering happens here, as also Q objects are just a collection / combination of kwarg filters.
:return: new Q object and set of matched existing pks
"""
# print(filter_name, '=', filter_value, sep='')
segments = filter_name.split('__')
@ -391,19 +494,25 @@ class BaseQueryWrapper(BaseWrapper):
except AttributeError:
raise ValueError('%s has no attribute %s' % (self._obj.model, field_name))
# create a base q that we'll modify later
q = Q(**{filter_name: filter_value})
# check if the filter begins with pk or the name of the primary key
if field_name == 'pk' or field_name == self._obj.model._meta.pk.name:
if not segments:
# if the check is just 'pk' or the name or the name of the primary key, return the mathing object
if is_created_pk(filter_value):
return Q(pk__in=()), set([int(filter_value[1:])])
return q, set()
elif segments == ['in']:
# if the check is 'pk__in' it's nearly as easy
return (Q(pk__in=tuple(pk for pk in filter_value if not is_created_pk(pk))),
set(int(pk[1:]) for pk in filter_value if is_created_pk(pk)))
# check if we are filtering by a foreign key field
if isinstance(class_value, ForwardManyToOneDescriptor):
if not segments:
# turn 'foreign_obj' into 'foreign_obj__pk' for later
filter_name = field_name + '__pk'
filter_value = filter_value.pk
segments = ['pk']
@ -412,6 +521,7 @@ class BaseQueryWrapper(BaseWrapper):
filter_type = segments.pop(0)
if not segments and filter_type == 'in':
# turn 'foreign_obj__in' into 'foreign_obj__pk' for later
filter_name = field_name+'__pk__in'
filter_value = tuple(obj.pk for obj in filter_value)
filter_type = 'pk'
@ -419,29 +529,36 @@ class BaseQueryWrapper(BaseWrapper):
q = Q(**{filter_name: filter_value})
if filter_type == class_value.field.model._meta.pk.name:
# turn <name of the primary key field> into pk for later
filter_type = 'pk'
if filter_type == 'pk' and segments == ['in']:
# foreign_obj__pk__in
q = Q(**{field_name+'__pk__in': tuple(pk for pk in filter_value if not is_created_pk(pk))})
filter_value = tuple(str(pk) for pk in filter_value)
return self._filter_values(q, field_name, lambda val: str(val) in filter_value)
if segments:
# wo don't do multi-level lookups
raise NotImplementedError
if filter_type == 'pk':
# foreign_obj__pk
if is_created_pk(filter_value):
q = Q(pk__in=())
filter_value = str(filter_value)
return self._filter_values(q, field_name, lambda val: str(val) == filter_value)
if filter_type == 'isnull':
# foreign_obj__isnull
return self._filter_values(q, field_name, lambda val: (val is None) is filter_value)
raise NotImplementedError
# check if we are filtering by a many to many field
if isinstance(class_value, ManyToManyDescriptor):
if not segments:
# turn 'm2m' into 'm2m__pk' for later
filter_name = field_name + '__pk'
filter_value = filter_value.pk
segments = ['pk']
@ -450,6 +567,7 @@ class BaseQueryWrapper(BaseWrapper):
filter_type = segments.pop(0)
if not segments and filter_type == 'in':
# turn 'm2m__in' into 'm2m__pk__in' for later
filter_name = field_name+'__pk__in'
filter_value = tuple(obj.pk for obj in filter_value)
filter_type = 'pk'
@ -457,10 +575,13 @@ class BaseQueryWrapper(BaseWrapper):
q = Q(**{filter_name: filter_value})
if filter_type == class_value.field.model._meta.pk.name:
# turn <name of the primary key field> into pk for later
filter_type = 'pk'
if filter_type == 'pk' and segments == ['in']:
# m2m__pk__in
if not class_value.reverse:
# we don't do this in reverse
raise NotImplementedError
# so... e.g. we want to get all groups that belong to one of the given spaces.
@ -495,9 +616,11 @@ class BaseQueryWrapper(BaseWrapper):
return q, created_pks
if segments:
# we don't to multi-level lookups
raise NotImplementedError
if filter_type == 'pk':
# m2m__pk
if class_value.reverse:
model = class_value.field.model
@ -516,23 +639,29 @@ class BaseQueryWrapper(BaseWrapper):
Q(pk__in=(pk for pk in add_pks if not is_created_pk(pk)))),
set(int(pk[1:]) for pk in add_pks if is_created_pk(pk)))
# sorry, no reverse lookup
raise NotImplementedError
raise NotImplementedError
# check if field is a deffered attribute, e.g. a CharField
if isinstance(class_value, DeferredAttribute):
if not segments:
# field=
return self._filter_values(q, field_name, lambda val: val == filter_value)
filter_type = segments.pop(0)
if segments:
# we don't to field__whatever__whatever
raise NotImplementedError
if filter_type == 'in':
# field__in
return self._filter_values(q, field_name, lambda val: val in filter_value)
if filter_type == 'lt':
# field__lt
return self._filter_values(q, field_name, lambda val: val < filter_value)
raise NotImplementedError
@ -540,6 +669,11 @@ class BaseQueryWrapper(BaseWrapper):
raise NotImplementedError('cannot filter %s by %s (%s)' % (self._obj.model, filter_name, class_value))
def _filter_q(self, q):
"""
filter by Q object.
Split it up into recursive _filter_q and _filter_kwarg calls and combine them again.
:return: new Q object and set of matched existing pks
"""
filters, created_pks = zip(*((self._filter_q(c) if isinstance(c, Q) else self._filter_kwarg(*c))
for c in q.children))
result = Q(*filters)
@ -598,6 +732,9 @@ class BaseQueryWrapper(BaseWrapper):
@get_queryset
def extra(self, select):
"""
We only support the kind of extra() call that a many to many prefetch_related does.
"""
for key in select.keys():
if not key.startswith('_prefetch_related_val'):
raise NotImplementedError('extra() calls are only supported for prefetch_related!')
@ -605,14 +742,23 @@ class BaseQueryWrapper(BaseWrapper):
@get_queryset
def _next_is_sticky(self):
"""
Needed by prefetch_related.
"""
return self._wrap_queryset(self._obj._next_is_sticky())
def _get_created_objects(self, get_foreign_objects=True):
"""
Get ModelInstanceWrapper instance for all matched created objects.
"""
return (self._changeset.get_created_object(self._obj.model, pk, get_foreign_objects=get_foreign_objects)
for pk in sorted(self._created_pks))
@queryset_only
def _get_cached_result(self):
"""
Get results, make sure prefetch is prefetching and so on.
"""
obj = self._obj
obj._prefetch_done = True
obj._fetch_all()
@ -625,6 +771,7 @@ class BaseQueryWrapper(BaseWrapper):
result += list(self._get_created_objects())
for extra in self._extra:
# implementing the extra() call for prefetch_related
ex = extra[22:]
for f in self._obj.model._meta.get_fields():
if isinstance(f, ManyToManyRel) and f.through._meta.get_field(f.field.m2m_field_name()).attname == ex:
@ -667,6 +814,8 @@ class BaseQueryWrapper(BaseWrapper):
@_result_cache.setter
def _result_cache(self, value):
# prefetch_related will try to set this property
# it has to overwrite our final result because it already contains the created objects
self.__dict__['_cached_result'] = value
@queryset_only
@ -696,19 +845,38 @@ class BaseQueryWrapper(BaseWrapper):
class ManagerWrapper(BaseQueryWrapper):
"""
Wraps a manager.
This class itself is used to wrap Model.objects managers.
"""
def get_queryset(self):
"""
make sure that the database does not return objects that have been deleted in this changeset
"""
qs = self._wrap_queryset(self._obj.model.objects.all())
return qs.exclude(pk__in=self._changeset.deleted_existing.get(self._obj.model, ()))
class RelatedManagerWrapper(ManagerWrapper):
"""
Wraps a related manager.
"""
def _get_cache_name(self):
"""
get cache name to fetch prefetch_related results
"""
return self._obj.field.related_query_name()
def get_queryset(self):
"""
filter queryset by related manager filters
"""
return super().get_queryset().filter(**self._obj.core_filters)
def all(self):
"""
get prefetched result if it exists
"""
try:
return self.instance._prefetched_objects_cache[self._get_cache_name()]
except(AttributeError, KeyError):
@ -723,6 +891,9 @@ class RelatedManagerWrapper(ManagerWrapper):
class ManyRelatedManagerWrapper(RelatedManagerWrapper):
"""
Wraps a many related manager (see RelatedManagerWrapper for details)
"""
def _check_through(self):
if not self._obj.through._meta.auto_created:
raise AttributeError('Cannot do this an a ManyToManyField which specifies an intermediary model.')
@ -768,6 +939,9 @@ class ManyRelatedManagerWrapper(RelatedManagerWrapper):
class QuerySetWrapper(BaseQueryWrapper):
"""
Wraps a queryset.
"""
@property
def _iterable_class(self):
return self._obj._iterable_class