diff --git a/src/c3nav/editor/models/change.py b/src/c3nav/editor/models/change.py index 5766406b..cfdcbc8d 100644 --- a/src/c3nav/editor/models/change.py +++ b/src/c3nav/editor/models/change.py @@ -9,7 +9,8 @@ from django.db.models import Q from django.utils.translation import ugettext_lazy as _ from c3nav.editor.models import ChangeSet -from c3nav.editor.wrappers import ModelInstanceWrapper, is_created_pk +from c3nav.editor.utils import is_created_pk +from c3nav.editor.wrappers import ModelInstanceWrapper class Change(models.Model): diff --git a/src/c3nav/editor/models/changeset.py b/src/c3nav/editor/models/changeset.py index 95224a4a..baa0a606 100644 --- a/src/c3nav/editor/models/changeset.py +++ b/src/c3nav/editor/models/changeset.py @@ -12,7 +12,8 @@ from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ungettext_lazy from c3nav.editor.models.change import Change -from c3nav.editor.wrappers import ModelWrapper, is_created_pk +from c3nav.editor.utils import is_created_pk +from c3nav.editor.wrappers import ModelWrapper from c3nav.mapdata.models import LocationSlug from c3nav.mapdata.models.locations import LocationRedirect diff --git a/src/c3nav/editor/utils.py b/src/c3nav/editor/utils.py new file mode 100644 index 00000000..515099a6 --- /dev/null +++ b/src/c3nav/editor/utils.py @@ -0,0 +1,2 @@ +def is_created_pk(pk): + return isinstance(pk, str) and pk.startswith('c') and pk[1:].isnumeric() diff --git a/src/c3nav/editor/views/changes.py b/src/c3nav/editor/views/changes.py index 308f9de7..177e2514 100644 --- a/src/c3nav/editor/views/changes.py +++ b/src/c3nav/editor/views/changes.py @@ -8,8 +8,8 @@ from django.utils.formats import date_format from django.utils.translation import ugettext_lazy as _ from c3nav.editor.models import ChangeSet +from c3nav.editor.utils import is_created_pk from c3nav.editor.views.base import sidebar_view -from c3nav.editor.wrappers import is_created_pk from c3nav.mapdata.models.locations import LocationRedirect, LocationSlug diff --git a/src/c3nav/editor/wrappers/__init__.py b/src/c3nav/editor/wrappers/__init__.py new file mode 100644 index 00000000..e359b631 --- /dev/null +++ b/src/c3nav/editor/wrappers/__init__.py @@ -0,0 +1,2 @@ +from c3nav.editor.wrappers.instance import ModelInstanceWrapper # noqa +from c3nav.editor.wrappers.model import ModelWrapper # noqa diff --git a/src/c3nav/editor/wrappers/base.py b/src/c3nav/editor/wrappers/base.py new file mode 100644 index 00000000..f4025d04 --- /dev/null +++ b/src/c3nav/editor/wrappers/base.py @@ -0,0 +1,81 @@ +from functools import wraps + +from django.db import models +from django.db.models import Manager + +from c3nav.editor.wrappers import ModelInstanceWrapper, ModelWrapper +from c3nav.editor.wrappers.manager import ManagerWrapper, ManyRelatedManagerWrapper, RelatedManagerWrapper +from c3nav.editor.wrappers.query import QuerySetWrapper + + +class BaseWrapper: + _not_wrapped = ('_changeset', '_author', '_obj', '_created_pks', '_result', '_extra', '_result_cache', + '_initial_values') + _allowed_callables = () + _wrapped_callables = () + + def __init__(self, changeset, obj, author=None): + self._changeset = changeset + self._author = author + self._obj = obj + + # noinspection PyUnresolvedReferences + def _wrap_model(self, model): + if isinstance(model, type) and issubclass(model, ModelInstanceWrapper): + model = model._parent + if isinstance(model, ModelWrapper): + if self._author == model._author and self._changeset == model._changeset: + return model + model = model._obj + assert issubclass(model, models.Model) + return ModelWrapper(self._changeset, model, self._author) + + def _wrap_instance(self, instance): + if isinstance(instance, ModelInstanceWrapper): + if self._author == instance._author and self._changeset == instance._changeset: + return instance + instance = instance._obj + assert isinstance(instance, models.Model) + return self._wrap_model(type(instance)).create_wrapped_model_class()(self._changeset, instance, self._author) + + def _wrap_manager(self, manager): + assert isinstance(manager, Manager) + if hasattr(manager, 'through'): + return ManyRelatedManagerWrapper(self._changeset, manager, self._author) + if hasattr(manager, 'instance'): + return RelatedManagerWrapper(self._changeset, manager, self._author) + return ManagerWrapper(self._changeset, manager, self._author) + + def _wrap_queryset(self, queryset): + return QuerySetWrapper(self._changeset, queryset, self._author) + + def __getattr__(self, name): + value = getattr(self._obj, name) + if isinstance(value, Manager): + value = self._wrap_manager(value) + elif isinstance(value, type) and issubclass(value, models.Model) and value._meta.app_label == 'mapdata': + value = self._wrap_model(value) + elif isinstance(value, models.Model) and value._meta.app_label == 'mapdata': + value = self._wrap_instance(value) + elif isinstance(value, type) and issubclass(value, Exception): + pass + elif callable(value) and name not in self._allowed_callables: + if name in self._wrapped_callables: + func = getattr(self._obj.__class__, name) + + @wraps(func) + def wrapper(*args, **kwargs): + return func(self, *args, **kwargs) + return wrapper + if isinstance(self, ModelInstanceWrapper) and not hasattr(models.Model, name): + return value + raise TypeError('Can not call %s.%s wrapped!' % (type(self), name)) + return value + + def __setattr__(self, name, value): + if name in self._not_wrapped: + return super().__setattr__(name, value) + return setattr(self._obj, name, value) + + def __delattr__(self, name): + return delattr(self._obj, name) diff --git a/src/c3nav/editor/wrappers/instance.py b/src/c3nav/editor/wrappers/instance.py new file mode 100644 index 00000000..42e12318 --- /dev/null +++ b/src/c3nav/editor/wrappers/instance.py @@ -0,0 +1,117 @@ +from django.db import models +from django.db.models import Field +from django.db.models.fields.related_descriptors import ForwardManyToOneDescriptor + +from c3nav.editor.utils import is_created_pk +from c3nav.editor.wrappers import BaseWrapper + + +class ModelInstanceWrapper(BaseWrapper): + _allowed_callables = ('full_clean', '_perform_unique_checks', '_perform_date_checks') + _wrapped_callables = ('validate_unique', '_get_pk_val') + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + updates = self._changeset.updated_existing.get(type(self._obj), {}).get(self._obj.pk, {}) + self._initial_values = {} + for field in self._obj._meta.get_fields(): + if not isinstance(field, Field): + continue + if field.related_model is None: + if field.primary_key: + continue + + if field.name == 'titles': + for name, value in updates.items(): + if not name.startswith('title_'): + continue + if not value: + self._obj.titles.pop(name[6:], None) + else: + self._obj.titles[name[6:]] = value + elif field.name in updates: + setattr(self._obj, field.name, field.to_python(updates[field.name])) + self._initial_values[field] = getattr(self._obj, field.name) + elif (field.many_to_one or field.one_to_one) and not field.primary_key: + if field.name in updates: + value_pk = updates[field.name] + class_value = getattr(type(self._obj), field.name, None) + if is_created_pk(value_pk): + obj = self._wrap_model(field.model).get(pk=value_pk) + setattr(self._obj, class_value.cache_name, obj) + setattr(self._obj, field.attname, obj.pk) + else: + delattr(self._obj, class_value.cache_name) + setattr(self._obj, field.attname, value_pk) + self._initial_values[field] = getattr(self._obj, field.attname) + + def __eq__(self, other): + if isinstance(other, BaseWrapper): + if type(self._obj) is not type(other._obj): # noqa + return False + elif type(self._obj) is not type(other): + return False + return self.pk == other.pk + + def __setattr__(self, name, value): + if name in self._not_wrapped: + return super().__setattr__(name, value) + class_value = getattr(type(self._obj), name, None) + if isinstance(class_value, ForwardManyToOneDescriptor) and value is not None: + if isinstance(value, models.Model): + value = self._wrap_instance(value) + if not isinstance(value, ModelInstanceWrapper): + raise ValueError('value has to be None or ModelInstanceWrapper') + setattr(self._obj, name, value._obj) + setattr(self._obj, class_value.cache_name, value) + return + super().__setattr__(name, value) + + def __repr__(self): + cls_name = self._obj.__class__.__name__ + if self.pk is None: + return '<%s (unsaved) with Changeset #%d>' % (cls_name, self._changeset.pk) + elif is_created_pk(self.pk): + return '<%s #%s (created) from Changeset #%d>' % (cls_name, self.pk, self._changeset.pk) + return '<%s #%d (existing) with Changeset #%d>' % (cls_name, self.pk, self._changeset.pk) + + def _get_unique_checks(self, exclude=None): + unique_checks, date_checks = self._obj.__class__._get_unique_checks(self, exclude=exclude) + return [(self._wrap_model(model), unique) for model, unique in unique_checks], date_checks + + def save(self, author=None): + if author is None: + author = self._author + if self.pk is None: + self._changeset.add_create(self, author=author) + for field, initial_value in self._initial_values.items(): + class_value = getattr(type(self._obj), field.name, None) + if isinstance(class_value, ForwardManyToOneDescriptor): + try: + new_value = getattr(self._obj, class_value.cache_name) + except AttributeError: + new_value = getattr(self._obj, field.attname) + else: + new_value = None if new_value is None else new_value.pk + + if new_value != initial_value: + self._changeset.add_update(self, name=field.name, value=new_value, author=author) + continue + + new_value = getattr(self._obj, field.name) + if new_value == initial_value: + continue + + if field.name == 'titles': + for lang in (set(initial_value.keys()) | set(new_value.keys())): + new_title = new_value.get(lang, '') + if new_title != initial_value.get(lang, ''): + self._changeset.add_update(self, name='title_'+lang, value=new_title, author=author) + continue + + self._changeset.add_update(self, name=field.name, value=field.get_prep_value(new_value), author=author) + + def delete(self, author=None): + if author is None: + author = self._author + self._changeset.add_delete(self, author=author) diff --git a/src/c3nav/editor/wrappers/manager.py b/src/c3nav/editor/wrappers/manager.py new file mode 100644 index 00000000..b263e0d1 --- /dev/null +++ b/src/c3nav/editor/wrappers/manager.py @@ -0,0 +1,73 @@ +from c3nav.editor.wrappers import BaseQueryWrapper + + +class ManagerWrapper(BaseQueryWrapper): + def get_queryset(self): + qs = self._wrap_queryset(self._obj.model.objects.all()) + return qs.exclude(pk__in=self._changeset.deleted_existing.get(self._obj.model, ())) + + +class RelatedManagerWrapper(ManagerWrapper): + def _get_cache_name(self): + return self._obj.field.related_query_name() + + def get_queryset(self): + return super().get_queryset().filter(**self._obj.core_filters) + + def all(self): + try: + return self.instance._prefetched_objects_cache[self._get_cache_name()] + except(AttributeError, KeyError): + pass + return super().all() + + def create(self, *args, **kwargs): + if self.instance.pk is None: + raise TypeError + kwargs[self._obj.field.name] = self.instance + super().create(*args, **kwargs) + + +class ManyRelatedManagerWrapper(RelatedManagerWrapper): + def _check_through(self): + if not self._obj.through._meta.auto_created: + raise AttributeError('Cannot do this an a ManyToManyField which specifies an intermediary model.') + + def _get_cache_name(self): + return self._obj.prefetch_cache_name + + def set(self, objs, author=None): + if author is None: + author = self._author + + old_ids = set(self.values_list('pk', flat=True)) + new_ids = set(obj.pk for obj in objs) + + self.remove(*(old_ids - new_ids), author=author) + self.add(*(new_ids - old_ids), author=author) + + def add(self, *objs, author=None): + if author is None: + author = self._author + + for obj in objs: + pk = (obj.pk if isinstance(obj, self._obj.model) else obj) + self._changeset.add_m2m_add(self._obj.instance, name=self._get_cache_name(), value=pk, author=author) + + def remove(self, *objs, author=None): + if author is None: + author = self._author + + for obj in objs: + pk = (obj.pk if isinstance(obj, self._obj.model) else obj) + self._changeset.add_m2m_remove(self._obj.instance, name=self._get_cache_name(), value=pk, author=author) + + def all(self): + try: + return self.instance._prefetched_objects_cache[self._get_cache_name()] + except(AttributeError, KeyError): + pass + return super().all() + + def create(self, *args, **kwargs): + raise NotImplementedError diff --git a/src/c3nav/editor/wrappers/model.py b/src/c3nav/editor/wrappers/model.py new file mode 100644 index 00000000..e78fe0b5 --- /dev/null +++ b/src/c3nav/editor/wrappers/model.py @@ -0,0 +1,71 @@ +import typing +from itertools import chain + +from django.utils.functional import cached_property + +from c3nav.editor.forms import create_editor_form +from c3nav.editor.wrappers import BaseWrapper, ModelInstanceWrapper + + +class ModelWrapper(BaseWrapper): + _allowed_callables = ('EditorForm',) + _submodels_by_model = {} + + def __eq__(self, other): + if type(other) == ModelWrapper: + return self._obj is other._obj + return self._obj is other + + @cached_property + def EditorForm(self): + return create_editor_form(self._obj) + + @classmethod + def get_submodels(cls, model): + try: + return cls._submodels_by_model[model] + except KeyError: + pass + all_models = model.__subclasses__() + result = [] + if not model._meta.abstract: + result.append(model) + result.extend(chain(*(cls.get_submodels(model) for model in all_models))) + cls._submodels_by_model[model] = result + return result + + @cached_property + def _submodels(self): + return self.get_submodels(self._obj) + + def create_wrapped_model_class(self) -> typing.Type['ModelInstanceWrapper']: + # noinspection PyTypeChecker + return self.create_metaclass()(self._obj.__name__ + 'InstanceWrapper', (ModelInstanceWrapper,), {}) + + def __call__(self, **kwargs): + instance = self._wrap_instance(self._obj()) + for name, value in kwargs.items(): + setattr(instance, name, value) + return instance + + def create_metaclass(self): + parent = self + + class ModelInstanceWrapperMeta(type): + _parent = parent + + def __getattr__(self, name): + return getattr(parent, name) + + def __setattr__(self, name, value): + setattr(parent, name, value) + + def __delattr__(self, name): + delattr(parent, name) + + ModelInstanceWrapperMeta.__name__ = self._obj.__name__+'InstanceWrapperMeta' + + return ModelInstanceWrapperMeta + + def __repr__(self): + return '' diff --git a/src/c3nav/editor/wrappers.py b/src/c3nav/editor/wrappers/query.py similarity index 57% rename from src/c3nav/editor/wrappers.py rename to src/c3nav/editor/wrappers/query.py index 94046bbc..414838ec 100644 --- a/src/c3nav/editor/wrappers.py +++ b/src/c3nav/editor/wrappers/query.py @@ -1,268 +1,14 @@ import operator -import typing from collections import OrderedDict from functools import reduce, wraps from itertools import chain -from django.db import models -from django.db.models import Field, Manager, ManyToManyRel, Prefetch, Q +from django.db.models import DeferredAttribute, ManyToManyRel, Prefetch, Q from django.db.models.fields.related_descriptors import ForwardManyToOneDescriptor, ManyToManyDescriptor -from django.db.models.query_utils import DeferredAttribute from django.utils.functional import cached_property -from c3nav.editor.forms import create_editor_form - - -def is_created_pk(pk): - return isinstance(pk, str) and pk.startswith('c') and pk[1:].isnumeric() - - -class BaseWrapper: - _not_wrapped = ('_changeset', '_author', '_obj', '_created_pks', '_result', '_extra', '_result_cache', - '_initial_values') - _allowed_callables = () - _wrapped_callables = () - - def __init__(self, changeset, obj, author=None): - self._changeset = changeset - self._author = author - self._obj = obj - - # noinspection PyUnresolvedReferences - def _wrap_model(self, model): - if isinstance(model, type) and issubclass(model, ModelInstanceWrapper): - model = model._parent - if isinstance(model, ModelWrapper): - if self._author == model._author and self._changeset == model._changeset: - return model - model = model._obj - assert issubclass(model, models.Model) - return ModelWrapper(self._changeset, model, self._author) - - def _wrap_instance(self, instance): - if isinstance(instance, ModelInstanceWrapper): - if self._author == instance._author and self._changeset == instance._changeset: - return instance - instance = instance._obj - assert isinstance(instance, models.Model) - return self._wrap_model(type(instance)).create_wrapped_model_class()(self._changeset, instance, self._author) - - def _wrap_manager(self, manager): - assert isinstance(manager, Manager) - if hasattr(manager, 'through'): - return ManyRelatedManagerWrapper(self._changeset, manager, self._author) - if hasattr(manager, 'instance'): - return RelatedManagerWrapper(self._changeset, manager, self._author) - return ManagerWrapper(self._changeset, manager, self._author) - - def _wrap_queryset(self, queryset): - return QuerySetWrapper(self._changeset, queryset, self._author) - - def __getattr__(self, name): - value = getattr(self._obj, name) - if isinstance(value, Manager): - value = self._wrap_manager(value) - elif isinstance(value, type) and issubclass(value, models.Model) and value._meta.app_label == 'mapdata': - value = self._wrap_model(value) - elif isinstance(value, models.Model) and value._meta.app_label == 'mapdata': - value = self._wrap_instance(value) - elif isinstance(value, type) and issubclass(value, Exception): - pass - elif callable(value) and name not in self._allowed_callables: - if name in self._wrapped_callables: - func = getattr(self._obj.__class__, name) - - @wraps(func) - def wrapper(*args, **kwargs): - return func(self, *args, **kwargs) - return wrapper - if isinstance(self, ModelInstanceWrapper) and not hasattr(models.Model, name): - return value - raise TypeError('Can not call %s.%s wrapped!' % (type(self), name)) - return value - - def __setattr__(self, name, value): - if name in self._not_wrapped: - return super().__setattr__(name, value) - return setattr(self._obj, name, value) - - def __delattr__(self, name): - return delattr(self._obj, name) - - -class ModelWrapper(BaseWrapper): - _allowed_callables = ('EditorForm',) - _submodels_by_model = {} - - def __eq__(self, other): - if type(other) == ModelWrapper: - return self._obj is other._obj - return self._obj is other - - @cached_property - def EditorForm(self): - return create_editor_form(self._obj) - - @classmethod - def get_submodels(cls, model): - try: - return cls._submodels_by_model[model] - except KeyError: - pass - all_models = model.__subclasses__() - result = [] - if not model._meta.abstract: - result.append(model) - result.extend(chain(*(cls.get_submodels(model) for model in all_models))) - cls._submodels_by_model[model] = result - return result - - @cached_property - def _submodels(self): - return self.get_submodels(self._obj) - - def create_wrapped_model_class(self) -> typing.Type['ModelInstanceWrapper']: - # noinspection PyTypeChecker - return self.create_metaclass()(self._obj.__name__ + 'InstanceWrapper', (ModelInstanceWrapper,), {}) - - def __call__(self, **kwargs): - instance = self._wrap_instance(self._obj()) - for name, value in kwargs.items(): - setattr(instance, name, value) - return instance - - def create_metaclass(self): - parent = self - - class ModelInstanceWrapperMeta(type): - _parent = parent - - def __getattr__(self, name): - return getattr(parent, name) - - def __setattr__(self, name, value): - setattr(parent, name, value) - - def __delattr__(self, name): - delattr(parent, name) - - ModelInstanceWrapperMeta.__name__ = self._obj.__name__+'InstanceWrapperMeta' - - return ModelInstanceWrapperMeta - - def __repr__(self): - return '' - - -class ModelInstanceWrapper(BaseWrapper): - _allowed_callables = ('full_clean', '_perform_unique_checks', '_perform_date_checks') - _wrapped_callables = ('validate_unique', '_get_pk_val') - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - updates = self._changeset.updated_existing.get(type(self._obj), {}).get(self._obj.pk, {}) - self._initial_values = {} - for field in self._obj._meta.get_fields(): - if not isinstance(field, Field): - continue - if field.related_model is None: - if field.primary_key: - continue - - if field.name == 'titles': - for name, value in updates.items(): - if not name.startswith('title_'): - continue - if not value: - self._obj.titles.pop(name[6:], None) - else: - self._obj.titles[name[6:]] = value - elif field.name in updates: - setattr(self._obj, field.name, field.to_python(updates[field.name])) - self._initial_values[field] = getattr(self._obj, field.name) - elif (field.many_to_one or field.one_to_one) and not field.primary_key: - if field.name in updates: - value_pk = updates[field.name] - class_value = getattr(type(self._obj), field.name, None) - if is_created_pk(value_pk): - obj = self._wrap_model(field.model).get(pk=value_pk) - setattr(self._obj, class_value.cache_name, obj) - setattr(self._obj, field.attname, obj.pk) - else: - delattr(self._obj, class_value.cache_name) - setattr(self._obj, field.attname, value_pk) - self._initial_values[field] = getattr(self._obj, field.attname) - - def __eq__(self, other): - if isinstance(other, BaseWrapper): - if type(self._obj) is not type(other._obj): # noqa - return False - elif type(self._obj) is not type(other): - return False - return self.pk == other.pk - - def __setattr__(self, name, value): - if name in self._not_wrapped: - return super().__setattr__(name, value) - class_value = getattr(type(self._obj), name, None) - if isinstance(class_value, ForwardManyToOneDescriptor) and value is not None: - if isinstance(value, models.Model): - value = self._wrap_instance(value) - if not isinstance(value, ModelInstanceWrapper): - raise ValueError('value has to be None or ModelInstanceWrapper') - setattr(self._obj, name, value._obj) - setattr(self._obj, class_value.cache_name, value) - return - super().__setattr__(name, value) - - def __repr__(self): - cls_name = self._obj.__class__.__name__ - if self.pk is None: - return '<%s (unsaved) with Changeset #%d>' % (cls_name, self._changeset.pk) - elif is_created_pk(self.pk): - return '<%s #%s (created) from Changeset #%d>' % (cls_name, self.pk, self._changeset.pk) - return '<%s #%d (existing) with Changeset #%d>' % (cls_name, self.pk, self._changeset.pk) - - def _get_unique_checks(self, exclude=None): - unique_checks, date_checks = self._obj.__class__._get_unique_checks(self, exclude=exclude) - return [(self._wrap_model(model), unique) for model, unique in unique_checks], date_checks - - def save(self, author=None): - if author is None: - author = self._author - if self.pk is None: - self._changeset.add_create(self, author=author) - for field, initial_value in self._initial_values.items(): - class_value = getattr(type(self._obj), field.name, None) - if isinstance(class_value, ForwardManyToOneDescriptor): - try: - new_value = getattr(self._obj, class_value.cache_name) - except AttributeError: - new_value = getattr(self._obj, field.attname) - else: - new_value = None if new_value is None else new_value.pk - - if new_value != initial_value: - self._changeset.add_update(self, name=field.name, value=new_value, author=author) - continue - - new_value = getattr(self._obj, field.name) - if new_value == initial_value: - continue - - if field.name == 'titles': - for lang in (set(initial_value.keys()) | set(new_value.keys())): - new_title = new_value.get(lang, '') - if new_title != initial_value.get(lang, ''): - self._changeset.add_update(self, name='title_'+lang, value=new_title, author=author) - continue - - self._changeset.add_update(self, name=field.name, value=field.get_prep_value(new_value), author=author) - - def delete(self, author=None): - if author is None: - author = self._author - self._changeset.add_delete(self, author=author) +from c3nav.editor.utils import is_created_pk +from c3nav.editor.wrappers import BaseWrapper def get_queryset(func): @@ -698,78 +444,6 @@ class BaseQueryWrapper(BaseWrapper): obj.delete() -class ManagerWrapper(BaseQueryWrapper): - def get_queryset(self): - qs = self._wrap_queryset(self._obj.model.objects.all()) - return qs.exclude(pk__in=self._changeset.deleted_existing.get(self._obj.model, ())) - - -class RelatedManagerWrapper(ManagerWrapper): - def _get_cache_name(self): - return self._obj.field.related_query_name() - - def get_queryset(self): - return super().get_queryset().filter(**self._obj.core_filters) - - def all(self): - try: - return self.instance._prefetched_objects_cache[self._get_cache_name()] - except(AttributeError, KeyError): - pass - return super().all() - - def create(self, *args, **kwargs): - if self.instance.pk is None: - raise TypeError - kwargs[self._obj.field.name] = self.instance - super().create(*args, **kwargs) - - -class ManyRelatedManagerWrapper(RelatedManagerWrapper): - def _check_through(self): - if not self._obj.through._meta.auto_created: - raise AttributeError('Cannot do this an a ManyToManyField which specifies an intermediary model.') - - def _get_cache_name(self): - return self._obj.prefetch_cache_name - - def set(self, objs, author=None): - if author is None: - author = self._author - - old_ids = set(self.values_list('pk', flat=True)) - new_ids = set(obj.pk for obj in objs) - - self.remove(*(old_ids - new_ids), author=author) - self.add(*(new_ids - old_ids), author=author) - - def add(self, *objs, author=None): - if author is None: - author = self._author - - for obj in objs: - pk = (obj.pk if isinstance(obj, self._obj.model) else obj) - self._changeset.add_m2m_add(self._obj.instance, name=self._get_cache_name(), value=pk, author=author) - - def remove(self, *objs, author=None): - if author is None: - author = self._author - - for obj in objs: - pk = (obj.pk if isinstance(obj, self._obj.model) else obj) - self._changeset.add_m2m_remove(self._obj.instance, name=self._get_cache_name(), value=pk, author=author) - - def all(self): - try: - return self.instance._prefetched_objects_cache[self._get_cache_name()] - except(AttributeError, KeyError): - pass - return super().all() - - def create(self, *args, **kwargs): - raise NotImplementedError - - class QuerySetWrapper(BaseQueryWrapper): @property def _iterable_class(self):