diff --git a/src/c3nav/editor/api/geometries.py b/src/c3nav/editor/api/geometries.py index 050cc48a..3bf35b4f 100644 --- a/src/c3nav/editor/api/geometries.py +++ b/src/c3nav/editor/api/geometries.py @@ -8,6 +8,9 @@ from shapely.ops import unary_union from c3nav.api.exceptions import API404, APIPermissionDenied from c3nav.editor.utils import LevelChildEditUtils, SpaceChildEditUtils +from c3nav.mapdata.models import Level, Space, GraphNode, Door, LocationGroup, Building, GraphEdge, DataOverlayFeature +from c3nav.mapdata.models.geometry.space import Column, Hole, AltitudeMarker, BeaconMeasurement, RangingBeacon, Area, \ + POI from c3nav.mapdata.utils.geometry import unwrap_geom @@ -58,10 +61,6 @@ def _get_geometries_for_one_level(level): return results -if TYPE_CHECKING: - from c3nav.mapdata.models import Level - - @dataclass(slots=True) class LevelsForLevel: levels: Sequence[int] # IDs of all levels to render for this level, in order, including the level itself @@ -69,9 +68,8 @@ class LevelsForLevel: levels_under: Sequence[int] # IDs of the level below this level plus levels on top of it (on_top_of field) @classmethod - def for_level(cls, request, level: "Level", special_if_on_top=False): # add typing + def for_level(cls, request, level: Level, special_if_on_top=False): # add typing # noinspection PyPep8Naming - Level = request.changeset.wrap_model('Level') levels_under = () levels_on_top = () lower_level = level.lower(Level).first() @@ -114,18 +112,6 @@ def conditional_geojson(obj, update_cache_key_match): # noinspection PyPep8Naming def get_level_geometries_result(request, level_id: int, update_cache_key: str, update_cache_key_match: True): - Level = request.changeset.wrap_model('Level') - Space = request.changeset.wrap_model('Space') - Column = request.changeset.wrap_model('Column') - Hole = request.changeset.wrap_model('Hole') - AltitudeMarker = request.changeset.wrap_model('AltitudeMarker') - Building = request.changeset.wrap_model('Building') - Door = request.changeset.wrap_model('Door') - LocationGroup = request.changeset.wrap_model('LocationGroup') - BeaconMeasurement = request.changeset.wrap_model('BeaconMeasurement') - RangingBeacon = request.changeset.wrap_model('RangingBeacon') - DataOverlayFeature = request.changeset.wrap_model('DataOverlayFeature') - try: level = Level.objects.filter(Level.q_for_request(request)).get(pk=level_id) except Level.DoesNotExist: @@ -138,7 +124,7 @@ def get_level_geometries_result(request, level_id: int, update_cache_key: str, u levels_for_level = LevelsForLevel.for_level(request, level) # don't prefetch groups for now as changesets do not yet work with m2m-prefetches levels = Level.objects.filter(pk__in=levels_for_level.levels).filter(Level.q_for_request(request)) - graphnodes_qs = request.changeset.wrap_model('GraphNode').objects.all() + graphnodes_qs = GraphNode.objects.all() levels = levels.prefetch_related( Prefetch('spaces', Space.objects.filter(Space.q_for_request(request)).only( 'geometry', 'level', 'outside' @@ -170,7 +156,7 @@ def get_level_geometries_result(request, level_id: int, update_cache_key: str, u for space in chain(*(level.spaces.all() for level in levels.values()))))) graphnodes_lookup = {node.pk: node for node in graphnodes} - graphedges = request.changeset.wrap_model('GraphEdge').objects.all() + graphedges = GraphEdge.objects.all() graphedges = graphedges.filter(Q(from_node__in=graphnodes) | Q(to_node__in=graphnodes)) graphedges = graphedges.select_related('waytype', 'from_node', 'to_node') @@ -202,12 +188,6 @@ def get_level_geometries_result(request, level_id: int, update_cache_key: str, u def get_space_geometries_result(request, space_id: int, update_cache_key: str, update_cache_key_match: bool): - Space = request.changeset.wrap_model('Space') - Area = request.changeset.wrap_model('Area') - POI = request.changeset.wrap_model('POI') - Door = request.changeset.wrap_model('Door') - LocationGroup = request.changeset.wrap_model('LocationGroup') - space_q_for_request = Space.q_for_request(request) qs = Space.objects.filter(space_q_for_request) @@ -271,12 +251,12 @@ def get_space_geometries_result(request, space_id: int, update_cache_key: str, u # todo: permissions if request.user_permissions.can_access_base_mapdata: - graph_nodes = request.changeset.wrap_model('GraphNode').objects.all() + graph_nodes = GraphNode.objects.all() graph_nodes = graph_nodes.filter((Q(space__in=all_other_spaces)) | Q(space__pk=space.pk)) space_graph_nodes = tuple(node for node in graph_nodes if node.space_id == space.pk) - graph_edges = request.changeset.wrap_model('GraphEdge').objects.all() + graph_edges = GraphEdge.objects.all() space_graphnodes_ids = tuple(node.pk for node in space_graph_nodes) graph_edges = graph_edges.filter(Q(from_node__pk__in=space_graphnodes_ids) | Q(to_node__pk__in=space_graphnodes_ids)) diff --git a/src/c3nav/editor/forms.py b/src/c3nav/editor/forms.py index f39ef208..1dfe5ddc 100644 --- a/src/c3nav/editor/forms.py +++ b/src/c3nav/editor/forms.py @@ -21,8 +21,11 @@ from shapely.geometry.geo import mapping from c3nav.editor.models import ChangeSet, ChangeSetUpdate from c3nav.mapdata.fields import GeometryField from c3nav.mapdata.forms import I18nModelFormMixin -from c3nav.mapdata.models import GraphEdge -from c3nav.mapdata.models.access import AccessPermission +from c3nav.mapdata.models import GraphEdge, LocationGroup, Source, LocationGroupCategory, GraphNode, Space, \ + LocationSlug, WayType +from c3nav.mapdata.models.access import AccessPermission, AccessRestrictionGroup, AccessRestriction +from c3nav.mapdata.models.geometry.space import ObstacleGroup +from c3nav.mapdata.models.theme import ThemeLocationGroupBackgroundColor, ThemeObstacleGroupBackgroundColor from c3nav.routing.schemas import LocateRequestWifiPeerSchema @@ -32,11 +35,6 @@ class EditorFormBase(I18nModelFormMixin, ModelForm): super().__init__(*args, **kwargs) creating = not self.instance.pk - LocationGroup = request.changeset.wrap_model('LocationGroup') - ThemeLocationGroupBackgroundColor = request.changeset.wrap_model('ThemeLocationGroupBackgroundColor') - ThemeObstacleGroupBackgroundColor = request.changeset.wrap_model('ThemeObstacleGroupBackgroundColor') - ObstacleGroup = request.changeset.wrap_model('ObstacleGroup') - if self._meta.model.__name__ == 'Theme': if creating: locationgroup_theme_colors = {} @@ -130,8 +128,6 @@ class EditorFormBase(I18nModelFormMixin, ModelForm): ) if self._meta.model.__name__ == 'Source' and self.request.user.is_superuser: - Source = self.request.changeset.wrap_model('Source') - sources = {s['name']: s for s in Source.objects.all().values('name', 'access_restriction_id', 'left', 'bottom', 'right', 'top')} used_names = set(sources.keys()) @@ -176,14 +172,10 @@ class EditorFormBase(I18nModelFormMixin, ModelForm): self.fields.move_to_end('name', last=False) if self._meta.model.__name__ == 'AccessRestriction': - AccessRestrictionGroup = self.request.changeset.wrap_model('AccessRestrictionGroup') - self.fields['groups'].label_from_instance = lambda obj: obj.title self.fields['groups'].queryset = AccessRestrictionGroup.qs_for_request(self.request) elif 'groups' in self.fields: - LocationGroupCategory = self.request.changeset.wrap_model('LocationGroupCategory') - kwargs = {'allow_'+self._meta.model._meta.default_related_name: True} categories = LocationGroupCategory.objects.filter(**kwargs).prefetch_related('groups') if self.instance.pk: @@ -228,8 +220,6 @@ class EditorFormBase(I18nModelFormMixin, ModelForm): self.fields['label_settings'].label_from_instance = attrgetter('title') if 'access_restriction' in self.fields: - AccessRestriction = self.request.changeset.wrap_model('AccessRestriction') - self.fields['access_restriction'].label_from_instance = lambda obj: obj.title self.fields['access_restriction'].queryset = AccessRestriction.qs_for_request(self.request).order_by( "titles__"+get_language(), "titles__en" @@ -240,11 +230,6 @@ class EditorFormBase(I18nModelFormMixin, ModelForm): self.fields['base_mapdata_accessible'].disabled = True if space_id and 'target_space' in self.fields: - Space = self.request.changeset.wrap_model('Space') - - GraphNode = self.request.changeset.wrap_model('GraphNode') - GraphEdge = self.request.changeset.wrap_model('GraphEdge') - cache_key = 'editor:neighbor_spaces:%s:%s%d' % ( self.request.changeset.raw_cache_key_by_changes, AccessPermission.cache_key_for_request(request, with_update=False), @@ -311,7 +296,6 @@ class EditorFormBase(I18nModelFormMixin, ModelForm): self.fields['slug'].run_validators(slug) model_slug_field.run_validators(slug) - LocationSlug = self.request.changeset.wrap_model('LocationSlug') qs = LocationSlug.objects.filter(slug__in=self.add_redirect_slugs) if 'slug' in self.cleaned_data and self.cleaned_data['slug'] in self.add_redirect_slugs: @@ -380,11 +364,6 @@ class EditorFormBase(I18nModelFormMixin, ModelForm): groups = tuple((int(val) if val.isdigit() else val) for val in groups) self.instance.groups.set(groups) - LocationGroup = self.request.changeset.wrap_model('LocationGroup') - ThemeLocationGroupBackgroundColor = self.request.changeset.wrap_model('ThemeLocationGroupBackgroundColor') - ThemeObstacleGroupBackgroundColor = self.request.changeset.wrap_model('ThemeObstacleGroupBackgroundColor') - ObstacleGroup = self.request.changeset.wrap_model('ObstacleGroup') - if self._meta.model.__name__ == 'Theme': locationgroup_colors = {theme_location_group.location_group_id: theme_location_group for theme_location_group in self.instance.location_groups.all()} @@ -454,6 +433,15 @@ def create_editor_form(editor_model): return EditorForm +editor_form_cache = {} +def get_editor_form(model): + form = editor_form_cache.get(model, None) + if form is None: + form = create_editor_form(model) + editor_form_cache[model] = form + return form + + class ChangeSetForm(ModelForm): class Meta: model = ChangeSet @@ -480,12 +468,10 @@ class GraphEdgeSettingsForm(ModelForm): self.request = request super().__init__(*args, **kwargs) - WayType = self.request.changeset.wrap_model('WayType') self.fields['waytype'].label_from_instance = lambda obj: obj.title self.fields['waytype'].queryset = WayType.objects.all() self.fields['waytype'].to_field_name = None - AccessRestriction = self.request.changeset.wrap_model('AccessRestriction') self.fields['access_restriction'].label_from_instance = lambda obj: obj.title self.fields['access_restriction'].queryset = AccessRestriction.qs_for_request(self.request) @@ -495,7 +481,6 @@ class GraphEditorActionForm(Form): self.request = request super().__init__(*args, **kwargs) - GraphNode = self.request.changeset.wrap_model('GraphNode') graph_node_qs = GraphNode.objects.all() self.fields['active_node'] = ModelChoiceField(graph_node_qs, widget=HiddenInput(), required=False) self.fields['clicked_node'] = ModelChoiceField(graph_node_qs, widget=HiddenInput(), required=False) @@ -503,7 +488,6 @@ class GraphEditorActionForm(Form): if allow_clicked_position: self.fields['clicked_position'] = JSONField(widget=HiddenInput(), required=False) - Space = self.request.changeset.wrap_model('Space') space_qs = Space.objects.all() self.fields['goto_space'] = ModelChoiceField(space_qs, widget=HiddenInput(), required=False) diff --git a/src/c3nav/editor/models/changedobject.py b/src/c3nav/editor/models/changedobject.py index bde6dfaa..13467a91 100644 --- a/src/c3nav/editor/models/changedobject.py +++ b/src/c3nav/editor/models/changedobject.py @@ -9,7 +9,7 @@ from django.db import models from django.db.models import CharField, DecimalField, Field, TextField from django.utils.translation import gettext_lazy as _ -from c3nav.editor.wrappers import ModelInstanceWrapper, is_created_pk +from c3nav.editor.wrappers import is_created_pk from c3nav.mapdata.fields import I18nField from c3nav.mapdata.models.locations import LocationRedirect @@ -52,402 +52,6 @@ class ChangedObject(models.Model): unique_together = ('changeset', 'content_type', 'existing_object_pk') ordering = ['created', 'pk'] - def __init__(self, *args, model_class=None, **kwargs): - super().__init__(*args, **kwargs) - self._set_object = None - self._m2m_added_cache = {name: set(values) for name, values in self.m2m_added.items()} - self._m2m_removed_cache = {name: set(values) for name, values in self.m2m_removed.items()} - if model_class is not None: - self.model_class = model_class - for field in self.model_class._meta.get_fields(): - if field.name in self.updated_fields and isinstance(field, DecimalField): - self.updated_fields[field.name] = Decimal(self.updated_fields[field.name]) - - @property - def model_class(self) -> typing.Optional[typing.Type[models.Model]]: - return self.content_type.model_class() - - @model_class.setter - def model_class(self, value: typing.Optional[typing.Type[models.Model]]): - self.content_type = ContentType.objects.get_for_model(value) - - @property - def obj_pk(self) -> typing.Union[int, str]: - if not self.is_created: - return self.existing_object_pk - return 'c'+str(self.pk) - - @property - def obj(self) -> ModelInstanceWrapper: - return self.get_obj(get_foreign_objects=True) - - @property - def is_created(self): - return self.existing_object_pk is None - - def get_obj(self, get_foreign_objects=False) -> ModelInstanceWrapper: - model = self.model_class - - if not self.is_created: - if self._set_object is None: - try: - obj = model.objects.get(pk=self.existing_object_pk) - except model.DoesNotExist: - obj = model(pk=self.existing_object_pk) - self._set_object = self.changeset.wrap_instance(obj) - - # noinspection PyTypeChecker - return self._set_object - - pk = self.obj_pk - - obj = model() - obj.pk = pk - if model._meta.pk.is_relation: - setattr(obj, model._meta.pk.related_model._meta.pk.attname, pk) - obj._state.adding = False - return self.changeset.wrap_instance(obj) - - def add_relevant_object_pks(self, object_pks, many=True): - object_pks.setdefault(self.model_class, set()).add(self.obj_pk) - for name, value in self.updated_fields.items(): - if '__i18n__' in name: - continue - field = self.model_class._meta.get_field(name) - if field.is_relation: - object_pks.setdefault(field.related_model, set()).add(value) - - if many: - for name, value in chain(self._m2m_added_cache.items(), self._m2m_removed_cache.items()): - field = self.model_class._meta.get_field(name) - object_pks.setdefault(field.related_model, set()).update(value) - - def update_changeset_cache(self): - if self.pk is None: - return - - model = self.model_class - pk = self.obj_pk - - self.changeset.changed_objects.setdefault(model, {})[pk] = self - - if self.is_created: - if not self.deleted: - self.changeset.created_objects.setdefault(model, {})[pk] = self.updated_fields - else: - if not self.deleted: - self.changeset.updated_existing.setdefault(model, {})[pk] = self.updated_fields - self.changeset.deleted_existing.setdefault(model, set()).discard(pk) - else: - self.changeset.updated_existing.setdefault(model, {}).pop(pk, None) - self.changeset.deleted_existing.setdefault(model, set()).add(pk) - - if not self.deleted: - self.changeset.m2m_added.setdefault(model, {})[pk] = self._m2m_added_cache - self.changeset.m2m_removed.setdefault(model, {})[pk] = self._m2m_removed_cache - else: - self.changeset.m2m_added.get(model, {}).pop(pk, None) - self.changeset.m2m_removed.get(model, {}).pop(pk, None) - - def apply_to_instance(self, instance: ModelInstanceWrapper, created_pks=None): - for name, value in self.updated_fields.items(): - if '__i18n__' in name: - name, i18n, lang = name.split('__') - field = instance._meta.get_field(name) - if not value: - getattr(instance, field.attname).pop(lang, None) - else: - getattr(instance, field.attname)[lang] = value - continue - - field = instance._meta.get_field(name) - if not field.is_relation: - setattr(instance, field.name, field.to_python(value)) - elif field.many_to_one or field.one_to_one: - if is_created_pk(value): - if created_pks is None: - try: - obj = self.changeset.get_created_object(field.related_model, value, allow_deleted=True) - except field.related_model.DoesNotExist: - pass - else: - setattr(instance, field.get_cache_name(), obj) - else: - try: - delattr(instance, field.get_cache_name()) - except AttributeError: - pass - try: - value = created_pks[field.related_model][value] - except KeyError: - raise ApplyToInstanceError - else: - try: - delattr(instance, field.get_cache_name()) - except AttributeError: - pass - setattr(instance, field.attname, value) - else: - raise NotImplementedError - - def clean_updated_fields(self, objects=None): - if self.is_created: - current_obj = self.model_class() - elif objects is not None: - current_obj = objects[self.model_class][self.existing_object_pk] - else: - current_obj = self.model_class.objects.get(pk=self.existing_object_pk) - - delete_fields = set() - for name, new_value in self.updated_fields.items(): - if '__i18n__' in name: - orig_name, i18n, lang = name.split('__') - field = self.model_class._meta.get_field(orig_name) - current_value = getattr(current_obj, field.attname).get(lang, '') - else: - field = self.model_class._meta.get_field(name) - - if not field.is_relation: - current_value = field.get_prep_value(getattr(current_obj, field.name)) - elif field.many_to_one or field.one_to_one: - current_value = getattr(current_obj, field.attname) - else: - raise NotImplementedError - - if current_value == new_value: - delete_fields.add(name) - - self.updated_fields = {name: value for name, value in self.updated_fields.items() if name not in delete_fields} - return delete_fields - - def handle_deleted_object_pks(self, deleted_object_pks): - if self.obj_pk in deleted_object_pks[self.model_class]: - self.delete() - return False - - for name, value in self.updated_fields.items(): - if '__i18n__' in name: - continue - field = self.model_class._meta.get_field(name) - if field.is_relation: - if value in deleted_object_pks[field.related_model]: - deleted_object_pks[self.model_class].add(self.obj_pk) - self.delete() - return False - - changed = False - for name, value in chain(self._m2m_added_cache.items(), self._m2m_removed_cache.items()): - field = self.model_class._meta.get_field(name) - if deleted_object_pks[field.related_model] & value: - value.difference_update(deleted_object_pks[field.related_model]) - changed = True - - return changed - - def save_instance(self, instance): - old_updated_fields = self.updated_fields - self.updated_fields = {} - - if instance.pk is None and self.model_class == LocationRedirect and not is_created_pk(instance.target_id): - obj = LocationRedirect.objects.filter(pk__in=self.changeset.deleted_existing.get(LocationRedirect, ()), - slug=instance.slug, target_id=instance.target_id).first() - if obj is not None: - self.changeset.get_changed_object(obj).restore() - return - - for field in self.model_class._meta.get_fields(): - if not isinstance(field, Field) or field.primary_key: - continue - - elif not field.is_relation: - value = getattr(instance, field.attname) - if isinstance(field, I18nField): - for lang, subvalue in value.items(): - self.updated_fields['%s__i18n__%s' % (field.name, lang)] = subvalue - elif isinstance(field, (CharField, TextField)): - self.updated_fields[field.name] = None if field.null and not value else field.get_prep_value(value) - else: - self.updated_fields[field.name] = field.get_prep_value(value) - elif field.many_to_one or field.one_to_one: - try: - value = getattr(instance, field.get_cache_name()) - except AttributeError: - value = getattr(instance, field.attname) - else: - value = None if value is None else value.pk - self.updated_fields[field.name] = value - - self.clean_updated_fields() - for name, value in self.updated_fields.items(): - if old_updated_fields.get(name, None) != value: - self.changeset._object_changed = True - break - self.save() - if instance.pk is None and self.pk is not None: - instance.pk = self.obj_pk - - def can_delete(self): - for field in self.model_class._meta.get_fields(): - if not field.one_to_many: - continue - related_model = field.related_model - if related_model._meta.app_label != 'mapdata': - continue - if related_model.__name__ in ('AccessPermission', 'Report'): - continue - kwargs = {field.field.name+'__pk': self.obj_pk} - if self.changeset.wrap_model(related_model).objects.filter(**kwargs).exists(): - return False - return True - - def get_unique_collisions(self, max_one=False): - result = set() - if not self.deleted: - return result - uniques = tuple(self.model_class._meta.unique_together) - uniques += tuple((field.name, ) - for field in self.model_class._meta.get_fields() - if field.related_model is None and field.unique and not field.primary_key) - for unique in uniques: - names = tuple((name if self.model_class._meta.get_field(name).related_model is None else name+'__pk') - for name in unique) - values = tuple(getattr(self.obj, self.model_class._meta.get_field(name).attname) for name in unique) - if None in values: - continue - if self.changeset.wrap_model(self.model_class).objects.filter(**dict(zip(names, values))).exists(): - result |= set(unique) - if result and max_one: - return result - return result - - def get_missing_dependencies(self, force_query=False, max_one=False): - result = set() - if not self.deleted: - return result - for field in self.model_class._meta.get_fields(): - if not field.many_to_one: - continue - if field.name not in self.updated_fields: - continue - related_model = field.related_model - if related_model._meta.app_label != 'mapdata': - continue - - pk = self.updated_fields[field.name] - - if force_query: - # query here to avoid a race condition - related_content_type = ContentType.objects.get_for_model(related_model) - qs = self.changeset.changed_objects_set.filter(content_type=related_content_type) - if is_created_pk(pk): - if not qs.filter(pk=int(pk[1:]), deleted=False).exists(): - result.add(field.name) - else: - if qs.filter(existing_object_pk=pk, deleted=True).exists(): - result.add(field.name) - else: - if is_created_pk(pk): - if pk not in self.changeset.created_objects.get(related_model, ()): - result.add(field.name) - else: - if pk in self.changeset.deleted_existing.get(related_model, ()): - result.add(field.name) - - if result and max_one: - return result - - return result - - def mark_deleted(self): - if not self.can_delete(): - return False - self.changeset._object_changed = True - self.deleted = True - self.save() - return True - - def clean_m2m(self, objects): - current_obj = objects[self.model_class][self.obj_pk] - changed = False - for name in set(self._m2m_added_cache.keys()) | set(self._m2m_removed_cache.keys()): - changed = changed or self.m2m_set(name, obj=self.changeset.wrap_instance(current_obj)) - return changed - - def m2m_set(self, name, set_pks=None, obj=None): - if obj is not None: - pks = set(related_obj.pk for related_obj in getattr(obj, name).all()) - elif not self.is_created: - field = self.model_class._meta.get_field(name) - rel_name = field.remote_field.related_name - pks = set(field.related_model.objects.filter(**{rel_name+'__pk': self.obj_pk}).values_list('pk', flat=True)) - else: - pks = set() - - m2m_added_before = self._m2m_added_cache.get(name, set()) - m2m_removed_before = self._m2m_removed_cache.get(name, set()) - - if set_pks is None: - self._m2m_added_cache.get(name, set()).difference_update(pks) - self._m2m_removed_cache.get(name, set()).intersection_update(pks) - else: - self._m2m_added_cache[name] = set_pks - pks - self._m2m_removed_cache[name] = pks - set_pks - - if not self._m2m_added_cache.get(name, set()): - self._m2m_added_cache.pop(name, None) - if not self._m2m_removed_cache.get(name, set()): - self._m2m_removed_cache.pop(name, None) - - if (m2m_added_before != self._m2m_added_cache.get(name, set()) or - m2m_removed_before != self._m2m_removed_cache.get(name, set())): - self.changeset._object_changed = True - self.save() - return True - return False - - def m2m_add(self, name, pks: set): - self._m2m_added_cache.setdefault(name, set()).update(pks) - self._m2m_removed_cache.setdefault(name, set()).difference_update(pks) - self.m2m_set(name) - - def m2m_remove(self, name, pks: set): - self._m2m_removed_cache.setdefault(name, set()).update(pks) - self._m2m_added_cache.setdefault(name, set()).difference_update(pks) - self.m2m_set(name) - - def restore(self): - if self.deleted is False: - return - if self.get_missing_dependencies(force_query=True, max_one=True) or self.get_unique_collisions(max_one=True): - raise PermissionError - self.deleted = False - self.save(standalone=True) - - @property - def does_something(self): - return (self.updated_fields or self._m2m_added_cache or self._m2m_removed_cache or self.is_created or - (not self.is_created and self.deleted)) - - def save(self, *args, standalone=False, **kwargs): - self.m2m_added = {name: tuple(values) for name, values in self._m2m_added_cache.items()} - self.m2m_removed = {name: tuple(values) for name, values in self._m2m_removed_cache.items()} - if not self.does_something: - if self.pk: - self.delete() - else: - self.changeset._object_changed = True - if not standalone and self.changeset.pk is None: - self.changeset.save() - self.changeset = self.changeset - if self.does_something: - super().save(*args, **kwargs) - if not standalone and not self.changeset.fill_changes_cache(): - self.update_changeset_cache() - - def delete(self, **kwargs): - self.changeset._object_changed = True - super().delete(**kwargs) - def __repr__(self): return '' % (str(self.pk), str(self.changeset_id)) diff --git a/src/c3nav/editor/models/changeset.py b/src/c3nav/editor/models/changeset.py index 0857ca43..fc87f9c6 100644 --- a/src/c3nav/editor/models/changeset.py +++ b/src/c3nav/editor/models/changeset.py @@ -19,7 +19,7 @@ from django.utils.translation import ngettext_lazy from c3nav.editor.models.changedobject import ApplyToInstanceError, ChangedObject, NoopChangedObject from c3nav.editor.tasks import send_changeset_proposed_notification -from c3nav.editor.wrappers import ModelInstanceWrapper, ModelWrapper, is_created_pk +from c3nav.editor.wrappers import is_created_pk from c3nav.mapdata.models import LocationSlug, MapUpdate from c3nav.mapdata.models.locations import LocationRedirect from c3nav.mapdata.utils.cache.changes import changed_geometries @@ -139,27 +139,6 @@ class ChangeSet(models.Model): """ Wrap Objects """ - def wrap_model(self, model): - if isinstance(model, str): - model = apps.get_model('mapdata', model) - assert isinstance(model, type) and issubclass(model, models.Model) - if self.direct_editing: - model.EditorForm = ModelWrapper(self, model).EditorForm - return model - return self._get_wrapped_model(model) - - def _get_wrapped_model(self, model): - wrapped = self._wrapped_model_cache.get(model, None) - if wrapped is None: - wrapped = ModelWrapper(self, model) - self._wrapped_model_cache[model] = wrapped - return wrapped - - def wrap_instance(self, instance): - assert isinstance(instance, models.Model) - if self.direct_editing: - return instance - return self.wrap_model(instance.__class__).wrapped_model_class(self, instance) def relevant_changed_objects(self) -> typing.Iterable[ChangedObject]: return self.changed_objects_set.exclude(existing_object_pk__isnull=True, deleted=True) @@ -298,133 +277,8 @@ class ChangeSet(models.Model): Analyse Changes """ def get_objects(self, many=True, changed_objects=None, prefetch_related=()): - if changed_objects is None: - if self.changed_objects is None: - raise TypeError - changed_objects = self.iter_changed_objects() - - # collect pks of relevant objects - object_pks = {} - for change in changed_objects: - change.add_relevant_object_pks(object_pks, many=many) - - # create dummy objects for deleted ones - objects = {} - for model, pks in object_pks.items(): - objects[model] = {pk: model(pk=pk) for pk in pks} - - slug_submodels = tuple(model for model in object_pks.keys() - if model is not LocationSlug and issubclass(model, LocationSlug)) - if slug_submodels: - object_pks[LocationSlug] = reduce(operator.or_, (object_pks[model] for model in slug_submodels)) - for model in slug_submodels: - object_pks.pop(model) - - # retrieve relevant objects - for model, pks in object_pks.items(): - if not pks: - continue - created_pks = set(pk for pk in pks if is_created_pk(pk)) - existing_pks = pks - created_pks - model_objects = {} - if existing_pks: - qs = model.objects - if model is LocationSlug: - qs = qs.select_related_target() - qs = qs.filter(pk__in=existing_pks) - for prefetch in prefetch_related: - try: - model._meta.get_field(prefetch) - except FieldDoesNotExist: - pass - else: - qs = qs.prefetch_related(prefetch) - for obj in qs: - if model == LocationSlug: - obj = obj.get_child() - model_objects[obj.pk] = obj - if created_pks: - for pk in created_pks: - model_objects[pk] = self.get_created_object(model, pk, allow_deleted=True)._obj - objects[model] = model_objects - - # add LocationSlug objects as their correct model - for pk, obj in objects.get(LocationSlug, {}).items(): - objects.setdefault(obj.__class__, {})[pk] = obj - - for pk, obj in objects.get(LocationRedirect, {}).items(): - try: - target = obj.target.get_child(obj.target) - except FieldDoesNotExist: - # todo: fix this - continue - # todo: why is it sometimes wrapped and sometimes not? - objects.setdefault(LocationSlug, {})[target.pk] = getattr(target, '_obj', target) - objects.setdefault(target.__class__, {})[target.pk] = getattr(target, '_obj', target) - - return objects - - def get_changed_values(self, model: models.Model, name: str) -> tuple: - """ - Get all changes values for a specific field on existing models - :param model: model class - :param name: field name - :return: returns a dictionary with primary keys as keys and new values as values - """ - r = tuple((pk, values[name]) for pk, values in self.updated_existing.get(model, {}).items() if name in values) - return r - - def get_changed_object(self, obj, allow_noop=False) -> typing.Union[ChangedObject, typing.Type[NoopChangedObject]]: - if isinstance(obj, ModelInstanceWrapper): - obj = obj._obj - model = obj.__class__ - pk = obj.pk - if pk is None: - return ChangedObject(changeset=self, model_class=model) - - self.fill_changes_cache() - - objects = tuple(obj for obj in ((submodel, self.changed_objects.get(submodel, {}).get(pk, None)) - for submodel in get_submodels(model)) if obj[1] is not None) - if len(objects) > 1: - raise model.MultipleObjectsReturned - if objects: - return objects[0][1] - - if is_created_pk(pk): - raise model.DoesNotExist - - if allow_noop: - return NoopChangedObject - - return ChangedObject(changeset=self, model_class=model, existing_object_pk=pk) - - def get_created_object(self, model, pk, get_foreign_objects=False, allow_deleted=False): - """ - Gets a created model instance. - :param model: model class - :param pk: primary key - :param get_foreign_objects: whether to fetch foreign objects and not just set their id to field.attname - :param allow_deleted: return created objects that have already been deleted (needs get_history=True) - :return: a wrapped model instance - """ - self.fill_changes_cache() - if issubclass(model, ModelWrapper): - model = model._obj - - obj = self.get_changed_object(model(pk=pk)) - if obj.deleted and not allow_deleted: - raise model.DoesNotExist - return obj.get_obj(get_foreign_objects=get_foreign_objects) - - def get_created_pks(self, model) -> set: - """ - Returns a set with the primary keys of created objects from this model - """ - self.fill_changes_cache() - if issubclass(model, ModelWrapper): - model = model._obj - return set(self.created_objects.get(model, {}).keys()) + # todo: reimplement, maybe + pass """ Permissions @@ -677,85 +531,7 @@ class ChangeSet(models.Model): def apply(self, user): with MapUpdate.lock(): - changed_geometries.reset() - - self._clean_changes() - changed_objects = self.relevant_changed_objects() - created_objects = [] - existing_objects = [] - for changed_object in changed_objects: - (created_objects if changed_object.is_created else existing_objects).append(changed_object) - - objects = self.get_objects(changed_objects=changed_objects) - - # remove slugs on all changed existing objects - slugs_updated = set(changed_object.obj_pk for changed_object in existing_objects - if (issubclass(changed_object.model_class, LocationSlug) and - 'slug' in changed_object.updated_fields)) - LocationSlug.objects.filter(pk__in=slugs_updated).update(slug=None) - - redirects_deleted = set(changed_object.obj_pk for changed_object in existing_objects - if (issubclass(changed_object.model_class, LocationRedirect) and - changed_object.deleted)) - LocationRedirect.objects.filter(pk__in=redirects_deleted).delete() - - # create created objects - created_pks = {} - objects_to_create = set(created_objects) - while objects_to_create: - created_in_last_run = set() - for created_object in objects_to_create: - model = created_object.model_class - pk = created_object.obj_pk - - # lets try to create this object - obj = model() - try: - created_object.apply_to_instance(obj, created_pks=created_pks) - except ApplyToInstanceError: - continue - - obj.save() - created_in_last_run.add(created_object) - created_pks.setdefault(model, {})[pk] = obj.pk - objects.setdefault(model, {})[pk] = obj - if issubclass(model, LocationSlug): - # todo: make this generic - created_pks.setdefault(LocationSlug, {})[pk] = obj.pk - objects.setdefault(LocationSlug, {})[pk] = obj - - objects_to_create -= created_in_last_run - - # update existing objects - for existing_object in existing_objects: - if existing_object.deleted: - continue - model = existing_object.model_class - pk = existing_object.obj_pk - - obj = objects[model][pk] - existing_object.apply_to_instance(obj, created_pks=created_pks) - obj.save() - - # delete existing objects - for existing_object in existing_objects: - if not existing_object.deleted and not issubclass(existing_object.model_class, LocationRedirect): - continue - model = existing_object.model_class - pk = existing_object.obj_pk - - obj = objects[model][pk] - obj.delete() - - # update m2m - for changed_object in changed_objects: - obj = objects[changed_object.model_class][changed_object.obj_pk] - for mode, updates in (('remove', changed_object.m2m_removed), ('add', changed_object.m2m_added)): - for name, pks in updates.items(): - field = changed_object.model_class._meta.get_field(name) - pks = tuple(objects[field.related_model][pk].pk for pk in pks) - getattr(getattr(obj, name), mode)(*pks) - + # todo: reimplement update = self.updates.create(user=user, state='applied') map_update = MapUpdate.objects.create(user=user, type='changeset') self.state = 'applied' diff --git a/src/c3nav/editor/views/base.py b/src/c3nav/editor/views/base.py index 0ba12001..8943e7ee 100644 --- a/src/c3nav/editor/views/base.py +++ b/src/c3nav/editor/views/base.py @@ -16,7 +16,6 @@ from django.utils.translation import get_language from django.utils.translation import gettext_lazy as _ from c3nav.editor.models import ChangeSet -from c3nav.editor.wrappers import QuerySetWrapper from c3nav.mapdata.models.access import AccessPermission from c3nav.mapdata.models.base import SerializableMixin from c3nav.mapdata.utils.user import can_access_editor @@ -233,7 +232,7 @@ class APIHybridTemplateContextResponse(APIHybridResponse): def _maybe_serialize_value(self, value): if isinstance(value, SerializableMixin): value = value.serialize(geometry=False, detailed=False) - elif isinstance(value, (QuerySet, QuerySetWrapper)) and issubclass(value.model, SerializableMixin): + elif isinstance(value, QuerySet) and issubclass(value.model, SerializableMixin): value = [item.serialize(geometry=False, detailed=False) for item in value] return value diff --git a/src/c3nav/editor/views/changes.py b/src/c3nav/editor/views/changes.py index 72cf0f18..11378eef 100644 --- a/src/c3nav/editor/views/changes.py +++ b/src/c3nav/editor/views/changes.py @@ -12,7 +12,7 @@ from django.utils.text import format_lazy from django.utils.translation import get_language_info from django.utils.translation import gettext_lazy as _ -from c3nav.editor.forms import ChangeSetForm, RejectForm +from c3nav.editor.forms import ChangeSetForm, RejectForm, get_editor_form from c3nav.editor.models import ChangeSet from c3nav.editor.views.base import sidebar_view from c3nav.editor.wrappers import is_created_pk @@ -265,7 +265,7 @@ def changeset_detail(request, pk): } changed_objects_data.append(changed_object_data) - form_fields = changeset.wrap_model(type(obj)).EditorForm._meta.fields + form_fields = get_editor_form(model)._meta.fields if changed_object.is_created: changes.append({ diff --git a/src/c3nav/editor/views/edit.py b/src/c3nav/editor/views/edit.py index 2f8b1baf..4cd9ca3c 100644 --- a/src/c3nav/editor/views/edit.py +++ b/src/c3nav/editor/views/edit.py @@ -2,6 +2,7 @@ import mimetypes import typing from contextlib import suppress +from django.apps import apps from django.conf import settings from django.contrib import messages from django.core.cache import cache @@ -14,17 +15,19 @@ from django.urls import reverse from django.utils.translation import gettext_lazy as _ from django.views.decorators.http import etag -from c3nav.editor.forms import GraphEdgeSettingsForm, GraphEditorActionForm +from c3nav.editor.forms import GraphEdgeSettingsForm, GraphEditorActionForm, get_editor_form from c3nav.editor.utils import DefaultEditUtils, LevelChildEditUtils, SpaceChildEditUtils from c3nav.editor.views.base import (APIHybridError, APIHybridFormTemplateResponse, APIHybridLoginRequiredResponse, APIHybridMessageRedirectResponse, APIHybridTemplateContextResponse, editor_etag_func, sidebar_view) +from c3nav.mapdata.models import Level, Space, LocationGroupCategory, GraphNode, GraphEdge from c3nav.mapdata.models.access import AccessPermission from c3nav.mapdata.utils.user import can_access_editor def child_model(request, model: typing.Union[str, models.Model], kwargs=None, parent=None): - model = request.changeset.wrap_model(model) + if isinstance(model, str): + model = apps.get_model(app_label="mapdata", model_name=model) related_name = model._meta.default_related_name if parent is not None: qs = getattr(parent, related_name) @@ -43,7 +46,6 @@ def child_model(request, model: typing.Union[str, models.Model], kwargs=None, pa @etag(editor_etag_func) @sidebar_view(api_hybrid=True) def main_index(request): - Level = request.changeset.wrap_model('Level') return APIHybridTemplateContextResponse('editor/index.html', { 'levels': Level.objects.filter(Level.q_for_request(request), on_top_of__isnull=True), 'can_create_level': (request.user_permissions.can_access_base_mapdata and @@ -68,7 +70,6 @@ def main_index(request): @etag(editor_etag_func) @sidebar_view(api_hybrid=True) def level_detail(request, pk): - Level = request.changeset.wrap_model('Level') qs = Level.objects.filter(Level.q_for_request(request)) level = get_object_or_404(qs.select_related('on_top_of').prefetch_related('levels_on_top'), pk=pk) @@ -97,9 +98,6 @@ def level_detail(request, pk): @etag(editor_etag_func) @sidebar_view(api_hybrid=True) def space_detail(request, level, pk): - Level = request.changeset.wrap_model('Level') - Space = request.changeset.wrap_model('Space') - # todo: HOW TO GET DATA qs = Space.objects.filter(Space.q_for_request(request)) space = get_object_or_404(qs.select_related('level'), level__pk=level, pk=pk) @@ -133,17 +131,16 @@ def get_changeset_exceeded(request): @etag(editor_etag_func) @sidebar_view(api_hybrid=True) def edit(request, pk=None, model=None, level=None, space=None, on_top_of=None, explicit_edit=False): + if isinstance(model, str): + model = apps.get_model(app_label="mapdata", model_name=model) + changeset_exceeded = get_changeset_exceeded(request) model_changes = {} if changeset_exceeded: model_changes = request.changeset.get_changed_objects_by_model(model) - model = request.changeset.wrap_model(model) related_name = model._meta.default_related_name - Level = request.changeset.wrap_model('Level') - Space = request.changeset.wrap_model('Space') - can_edit_changeset = request.changeset.can_edit(request) obj = None @@ -343,9 +340,9 @@ def edit(request, pk=None, model=None, level=None, space=None, on_top_of=None, e json_body = getattr(request, 'json_body', None) data = json_body if json_body is not None else request.POST - form = model.EditorForm(instance=model() if new else obj, data=data, is_json=json_body is not None, - request=request, space_id=space_id, - geometry_editable=edit_utils.can_access_child_base_mapdata) + form = get_editor_form(model)(instance=model() if new else obj, data=data, is_json=json_body is not None, + request=request, space_id=space_id, + geometry_editable=edit_utils.can_access_child_base_mapdata) if form.is_valid(): # Update/create objects obj = form.save(commit=False) @@ -383,8 +380,8 @@ def edit(request, pk=None, model=None, level=None, space=None, on_top_of=None, e error = APIHybridError(status_code=403, message=_('You can not edit changes on this changeset.')) else: - form = model.EditorForm(instance=obj, request=request, space_id=space_id, - geometry_editable=edit_utils.can_access_child_base_mapdata) + form = get_editor_form(model)(instance=obj, request=request, space_id=space_id, + geometry_editable=edit_utils.can_access_child_base_mapdata) ctx.update({ 'form': form, @@ -400,7 +397,6 @@ def get_visible_spaces(request): ) visible_spaces = cache.get(cache_key, None) if visible_spaces is None: - Space = request.changeset.wrap_model('Space') visible_spaces = tuple(Space.qs_for_request(request).values_list('pk', flat=True)) cache.set(cache_key, visible_spaces, 900) return visible_spaces @@ -419,15 +415,13 @@ def get_visible_spaces_kwargs(model, request): @etag(editor_etag_func) @sidebar_view(api_hybrid=True) def list_objects(request, model=None, level=None, space=None, explicit_edit=False): + if isinstance(model, str): + model = apps.get_model(app_label="mapdata", model_name=model) + resolver_match = getattr(request, 'sub_resolver_match', request.resolver_match) if not resolver_match.url_name.endswith('.list'): raise ValueError('url_name does not end with .list') - model = request.changeset.wrap_model(model) - - Level = request.changeset.wrap_model('Level') - Space = request.changeset.wrap_model('Space') - can_edit = request.changeset.can_edit(request) ctx = { @@ -524,7 +518,6 @@ def list_objects(request, model=None, level=None, space=None, explicit_edit=Fals reverse_kwargs.pop('pk', None) if model.__name__ == 'LocationGroup': - LocationGroupCategory = request.changeset.wrap_model('LocationGroupCategory') grouped_objects = tuple( { 'title': category.title_plural, @@ -593,11 +586,6 @@ def graph_edit(request, level=None, space=None): if not request.user_permissions.can_access_base_mapdata: raise PermissionDenied - Level = request.changeset.wrap_model('Level') - Space = request.changeset.wrap_model('Space') - GraphNode = request.changeset.wrap_model('GraphNode') - GraphEdge = request.changeset.wrap_model('GraphEdge') - can_edit = request.changeset.can_edit(request) ctx = { diff --git a/src/c3nav/editor/wrappers.py b/src/c3nav/editor/wrappers.py index d83dc905..50be21bb 100644 --- a/src/c3nav/editor/wrappers.py +++ b/src/c3nav/editor/wrappers.py @@ -1,1011 +1,3 @@ -import base64 -import operator -import typing -from collections import OrderedDict -from functools import reduce, wraps -from itertools import chain - -from django.core.cache import cache -from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist -from django.db import models -from django.db.models import Manager, ManyToManyRel, Prefetch, Q -from django.db.models.fields.related_descriptors import ReverseOneToOneDescriptor -from django.utils.functional import cached_property - -from c3nav.mapdata.utils.models import get_submodels - - -class BaseWrapper: - """ - Base Class for all wrappers. Saves wrapped objects along with the changeset. - getattr, setattr and delattr will be forwarded to the object, exceptions are specified in _not_wrapped. - If the value of an attribute is a model, model instance, manager or queryset, it will be wrapped, to. - Callables will only be passed through for models. Implement other passthroughs yourself. - """ - _not_wrapped = frozenset(('_changeset', '_obj', '_created_pks', '_result', '_extra', '_result_cache', - '_affected_by_changeset')) - - def __init__(self, changeset, obj): - self._changeset = changeset - self._obj = obj - - # noinspection PyUnresolvedReferences - def _wrap_model(self, model): - """ - Wrap a model with same changeset as this wrapper. - """ - if isinstance(model, type) and issubclass(model, ModelInstanceWrapper): - model = model._parent - if isinstance(model, ModelWrapper): - if self._changeset == model._changeset: - return model - model = model._obj - assert issubclass(model, models.Model) - return self._changeset._get_wrapped_model(model) - - def _wrap_instance(self, instance): - """ - Wrap a model instance with same changeset as this wrapper. - """ - if isinstance(instance, ModelInstanceWrapper): - if self._changeset == instance._changeset: - return instance - instance = instance._obj - assert isinstance(instance, models.Model) - return self._wrap_model(instance.__class__).wrapped_model_class(self._changeset, instance) - - def _wrap_manager(self, manager): - """ - Wrap a manager with same changeset as this wrapper. - Detects RelatedManager or ManyRelatedmanager instances and chooses the Wrapper accordingly. - """ - assert isinstance(manager, Manager) - if hasattr(manager, 'through'): - return ManyRelatedManagerWrapper(self._changeset, manager) - if hasattr(manager, 'instance'): - return RelatedManagerWrapper(self._changeset, manager) - return ManagerWrapper(self._changeset, manager) - - def _wrap_queryset(self, queryset): - """ - Wrap a queryset with same changeset as this wrapper. - """ - return QuerySetWrapper(self._changeset, queryset) - - def __getattr__(self, name): - value = getattr(self._obj, name) - if isinstance(value, Manager): - value = self._wrap_manager(value) - elif isinstance(value, type) and issubclass(value, models.Model) and value._meta.app_label == 'mapdata': - value = self._wrap_model(value) - elif isinstance(value, models.Model) and value._meta.app_label == 'mapdata': - value = self._wrap_instance(value) - elif isinstance(value, type) and issubclass(value, Exception): - pass - elif callable(value): - if isinstance(self, (ModelInstanceWrapper, ModelWrapper)) and not hasattr(models.Model, name): - return value - raise TypeError('Can not call %s.%s wrapped!' % (type(self), name)) - return value - - def __setattr__(self, name, value): - if name in self._not_wrapped: - return super().__setattr__(name, value) - return setattr(self._obj, name, value) - - def __delattr__(self, name): - return delattr(self._obj, name) - - -class ModelWrapper(BaseWrapper): - """ - Wraps a model class. - Can be compared to other wrapped or non-wrapped model classes. - Can be called (like a class) to get a wrapped model instance - that has the according ModelWrapper as its type / metaclass. - """ - def __eq__(self, other): - if type(other) == ModelWrapper: - return self._obj is other._obj - return self._obj is other - - # noinspection PyPep8Naming - @cached_property - def EditorForm(self): - """ - Returns an editor form for this model. - """ - from c3nav.editor.forms import create_editor_form - return create_editor_form(self._obj) - - @cached_property - def _submodels(self): - """ - Get non-abstract submodels for this model including the model itself. - """ - return get_submodels(self._obj) - - @cached_property - def wrapped_model_class(self) -> typing.Type['ModelInstanceWrapper']: - return self.create_wrapped_model_class() - - def create_wrapped_model_class(self) -> typing.Type['ModelInstanceWrapper']: - """ - Return a ModelInstanceWrapper that has a proxy to this instance as its type / metaclass. #voodoo - """ - # noinspection PyTypeChecker - return self.metaclass(self._obj.__name__ + 'InstanceWrapper', (ModelInstanceWrapper,), {}) - - def __call__(self, **kwargs): - """ - Create a wrapped instance of this model. _wrap_instance will call create_wrapped_model_class(). - """ - instance = self._wrap_instance(self._obj()) - for name, value in kwargs.items(): - setattr(instance, name, value) - return instance - - @cached_property - def metaclass(self): - return self.create_metaclass() - - def create_metaclass(self): - """ - Create the proxy metaclass for craeate_wrapped_model_class(). - """ - parent = self - - class ModelInstanceWrapperMeta(type): - _parent = parent - - def __getattr__(self, name): - return getattr(parent, name) - - def __setattr__(self, name, value): - setattr(parent, name, value) - - def __delattr__(self, name): - delattr(parent, name) - - ModelInstanceWrapperMeta.__name__ = self._obj.__name__+'InstanceWrapperMeta' - - return ModelInstanceWrapperMeta - - def __repr__(self): - return '' - - def is_created_pk(pk): return isinstance(pk, str) and pk.startswith('c') and pk[1:].isnumeric() - -class ModelInstanceWrapper(BaseWrapper): - """ - Wraps a model instance. Don't use this directly, call a ModelWrapper instead / use ChangeSet.wrap(). - Creates changes in changeset when save() is called. - Updates updated values on existing objects on init. - Can be compared to other wrapped or non-wrapped model instances. - """ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._affected_by_changeset = False - if self._obj.pk is not None: - changed_object = self._changeset.get_changed_object(self._obj, allow_noop=True) - self._affected_by_changeset = changed_object.pk is not None - changed_object.apply_to_instance(self) - - def __eq__(self, other): - if isinstance(other, BaseWrapper): - if type(self._obj) is not type(other._obj): # noqa - return False - elif type(self._obj) is not type(other): - return False - return self.pk == other.pk - - def __getattr__(self, name): - descriptor = getattr(self._obj.__class__, name, None) - if isinstance(descriptor, ReverseOneToOneDescriptor): - try: - rel_obj = descriptor.related.get_cached_value(self._obj) - except KeyError: - related_pk = self._obj._get_pk_val() - if related_pk is None: - rel_obj = None - else: - related_model = self._wrap_model(descriptor.related.related_model) - filter_args = descriptor.related.field.get_forward_related_filter(self._obj) - try: - rel_obj = related_model.objects.get(**filter_args) - except related_model.DoesNotExist: - rel_obj = None - else: - descriptor.related.field.set_cached_value(rel_obj, self._obj) - descriptor.related.set_cached_value(self._obj, rel_obj) - return super().__getattr__(name) - - def __setattr__(self, name, value): - """ - We have to intercept here because RelatedFields won't accept - wrapped model instances values, so we have to trick them. - """ - if name in self._not_wrapped: - return super().__setattr__(name, value) - try: - field = self._obj._meta.get_field(name) - except FieldDoesNotExist: - pass - else: - if field.many_to_one and name != field.attname and value is not None: - if isinstance(value, models.Model): - value = self._wrap_instance(value) - if not isinstance(value, ModelInstanceWrapper): - raise ValueError('value has to be None or ModelInstanceWrapper') - setattr(self._obj, name, value._obj) - field.set_cached_value(self._obj, value) - return - super().__setattr__(name, value) - - def __repr__(self): - cls_name = self._obj.__class__.__name__ - if self.pk is None: - return '<%s (unsaved) with Changeset #%s>' % (cls_name, self._changeset.pk) - elif is_created_pk(self.pk): - return '<%s #%s (created) from Changeset #%s>' % (cls_name, self.pk, self._changeset.pk) - return '<%s #%d (existing) with Changeset #%s>' % (cls_name, self.pk, self._changeset.pk) - - def _get_unique_checks(self, exclude=None): - unique_checks, date_checks = self._obj.__class__._get_unique_checks(self, exclude=exclude) - return [(self._wrap_model(model), unique) for model, unique in unique_checks], date_checks - - def full_clean(self, *args, **kwargs): - return self._obj.full_clean(*args, **kwargs) - - def _perform_unique_checks(self, *args, **kwargs): - return self._obj._perform_unique_checks(*args, **kwargs) - - def _perform_date_checks(self, *args, **kwargs): - return self._obj._perform_date_checks(*args, **kwargs) - - def validate_unique(self, *args, **kwargs): - return self._obj.__class__.validate_unique(self, *args, **kwargs) - - def _get_pk_val(self, *args, **kwargs): - return self._obj.__class__._get_pk_val(self, *args, **kwargs) - - def save(self): - """ - Create changes in changeset instead of saving. - """ - self._changeset.get_changed_object(self._obj).save_instance(self) - - def delete(self): - self._changeset.get_changed_object(self._obj).mark_deleted() - - -def get_queryset(func): - """ - Wraps methods of BaseQueryWrapper that manipulate a queryset. - If self is a Manager, not an object, preceed the method call with a filter call according to the manager. - """ - @wraps(func) - def wrapper(self, *args, **kwargs): - if hasattr(self, 'get_queryset'): - return getattr(self.get_queryset(), func.__name__)(*args, **kwargs) - return func(self, *args, **kwargs) - return wrapper - - -class BaseQueryWrapper(BaseWrapper): - """ - Base class for everything that wraps a QuerySet or manager. - Don't use this directly, but via WrappedModel.objects or WrappedInstance.groups or similar. - Intercepts all query methods to exclude ids / include ids for each filter according to changeset changes. - Keeps track of which created objects the current filtering still applies to. - When evaluated, just does everything as if the queryset was applied to the databse. - """ - def __init__(self, changeset, obj, created_pks=None, extra=()): - super().__init__(changeset, obj) - if created_pks is None: - created_pks = self._get_initial_created_pks() - self._created_pks = created_pks - self._extra = extra - - @property - def model(self): - return self._wrap_model(self._obj.model) - - def _get_initial_created_pks(self): - """ - Get all created pks for this query's model an submodels. - """ - return reduce(operator.or_, (self._changeset.get_created_pks(model) for model in self.model._submodels)) - - def _wrap_queryset(self, queryset, created_pks=None, add_extra=()): - """ - Wraps a queryset, usually after manipulating the current one. - :param created_pks: set of created pks to be still in the next queryset (the same ones as this one by default) - :param add_extra: extra() calls that have been added to the query - """ - if created_pks is None: - created_pks = self._created_pks - if created_pks is False: - created_pks = None - return QuerySetWrapper(self._changeset, queryset, created_pks, self._extra+add_extra) - - @get_queryset - def all(self): - return self._wrap_queryset(self._obj.all()) - - @get_queryset - def none(self): - return self._wrap_queryset(self._obj.none(), ()) - - @get_queryset - def select_related(self, *args, **kwargs): - return self._wrap_queryset(self._obj.select_related(*args, **kwargs)) - - @get_queryset - def prefetch_related(self, *lookups): - """ - We split up all prefetch related lookups into one-level prefetches - and convert them into Prefetch() objects with custom querysets. - This makes sure that the prefetch also happens on the virtually modified database. - """ - lookups_qs = {tuple(lookup.prefetch_through.split('__')): lookup.queryset for lookup in lookups - if isinstance(lookup, Prefetch) and lookup.queryset is not None} - for qs in lookups_qs.values(): - if not isinstance(qs, QuerySetWrapper): - raise TypeError('Prefetch object queryset needs to be wrapped!') - lookups = tuple((lookup.prefetch_through if isinstance(lookup, Prefetch) else lookup) for lookup in lookups) - - lookups_splitted = tuple(tuple(lookup.split('__')) for lookup in lookups) - max_depth = max(len(lookup) for lookup in lookups_splitted) - lookups_by_depth = [] - for i in range(max_depth): - lookups_by_depth.append(set(tuple(lookup[:i+1] for lookup in lookups_splitted if len(lookup) > i))) - - lookup_models = {(): self._obj.model} - lookup_querysets = {(): self.all()} - for depth_lookups in lookups_by_depth: - for lookup in depth_lookups: - model = lookup_models[lookup[:-1]]._meta.get_field(lookup[-1]).related_model - lookup_models[lookup] = model - lookup_querysets[lookup] = lookups_qs.get(lookup, self._wrap_model(model).objects.all()) - - for depth_lookups in reversed(lookups_by_depth): - for lookup in depth_lookups: - qs = lookup_querysets[lookup] - prefetch = Prefetch(lookup[-1], qs) - lower_qs = lookup_querysets[lookup[:-1]] - lower_qs._obj = lower_qs._obj.prefetch_related(prefetch) - - return lookup_querysets[()] - - def _chain(self, **kwargs): - """ - Return a copy of the current QuerySet that's ready for another - operation. - """ - obj = self._clone() - obj._obj.__dict__.update(kwargs) - return obj - - def _clone(self, **kwargs): - clone = self._wrap_queryset(self._obj) - clone._obj.__dict__.update(kwargs) - return clone - - @get_queryset - def get(self, *args, **kwargs): - results = tuple(self.filter(*args, **kwargs)) - if len(results) == 1: - return self._wrap_instance(results[0]) - if results: - raise self._obj.model.MultipleObjectsReturned - raise self._obj.model.DoesNotExist - - @get_queryset - def exists(self, *args, **kwargs): - if self._created_pks: - return True - return self._obj.exists() - - def only(self, *fields): - return self._wrap_queryset(self._obj.only(*fields)) - - def defer(self, *fields): - return self._wrap_queryset(self._obj.defer(*fields)) - - @get_queryset - def order_by(self, *args): - """ - Order only supported for numeric fields for now - """ - return self._wrap_queryset(self._obj.order_by(*args)) - - def _filter_values(self, q, field_name, check): - """ - Filter by value. - :param q: base Q object to give to the database and to modify - :param field_name: name of the field whose value should be compared - :param check: comparision function that only gets the new value - :return: new Q object and set of matched existing pks - """ - other_values = () - submodels = [model for model in self.model._submodels] - for model in submodels: - other_values += self._changeset.get_changed_values(model, field_name) - add_pks = [] - remove_pks = [] - for pk, new_value in other_values: - (add_pks if check(new_value) else remove_pks).append(pk) - created_pks = set() - for model in submodels: - for pk, values in self._changeset.created_objects.get(model, {}).items(): - field_name = getattr(model._meta.get_field(field_name), 'attname', field_name) - try: - if check(getattr(self._changeset.get_created_object(self._obj.model, pk), field_name)): - created_pks.add(pk) - except ObjectDoesNotExist: - pass - - return (q & ~Q(pk__in=remove_pks)) | Q(pk__in=add_pks), created_pks - - def _filter_kwarg(self, filter_name, filter_value): - """ - filter by kwarg. - The core filtering happens here, as also Q objects are just a collection / combination of kwarg filters. - :return: new Q object and set of matched existing pks - """ - # print(filter_name, '=', filter_value, sep='') - - segments = filter_name.split('__') - field_name = segments.pop(0) - model = self._obj.model - if field_name == 'pk': - field = model._meta.pk - else: - field = model._meta.get_field(field_name) - - # create a base q that we'll modify later - q = Q(**{filter_name: filter_value}) - - # check if the filter begins with pk or the name of the primary key - if field_name == 'pk' or field_name == model._meta.pk.name: - if not segments: - # if the check is just 'pk' or the name or the name of the primary key, return the mathing object - if is_created_pk(filter_value): - return Q(pk__in=()), {filter_value} - if filter_value is None or int(filter_value) in self._changeset.deleted_existing.get(model, ()): - return Q(pk__in=()), set() - return q, set() - elif segments == ['in']: - # if the check is 'pk__in' it's nearly as easy - return (Q(pk__in=tuple(pk for pk in filter_value if not is_created_pk(pk))), - set(pk for pk in filter_value if is_created_pk(pk))) - - # check if we are filtering by a foreign key field - if field.many_to_one or field.one_to_one: - rel_model = field.related_model - - if field_name == field.attname: - # turn 'foreign_obj_id' into 'foreign_obj__pk' for later - segments.insert(0, 'pk') - filter_name = field.name + '__' + '__'.join(segments) - q = Q(**{filter_name: filter_value}) - - if not segments: - # turn 'foreign_obj' into 'foreign_obj__pk' for later - filter_name = field_name + '__pk' - filter_value = filter_value.pk - segments = ['pk'] - q = Q(**{filter_name: filter_value}) - - filter_type = segments.pop(0) - - if not segments and filter_type == 'in': - # turn 'foreign_obj__in' into 'foreign_obj__pk' for later - filter_name = field_name+'__pk__in' - filter_value = tuple(obj.pk for obj in filter_value) - filter_type = 'pk' - segments = ['in'] - q = Q(**{filter_name: filter_value}) - - if filter_type == field.related_model._meta.pk.name: - # turn into pk for later - filter_type = 'pk' - - if filter_type == 'pk' and segments == ['in']: - # foreign_obj__pk__in - filter_value = (pk for pk in filter_value if pk is not None) - filter_value = tuple(pk for pk in filter_value - if is_created_pk(pk) or - int(pk) not in self._changeset.deleted_existing.get(rel_model, ())) - existing_pks = tuple(pk for pk in filter_value if not is_created_pk(pk)) - q = Q(**{field_name+'__pk__in': existing_pks}) - filter_value = tuple(str(pk) for pk in filter_value) - return self._filter_values(q, field_name, lambda val: str(val) in filter_value) - - if not segments: - if filter_type == 'pk': - # foreign_obj__pk - if is_created_pk(filter_value): - q = Q(pk__in=()) - else: - deleted_existing = self._changeset.deleted_existing.get(rel_model, ()) - if filter_value is None or int(filter_value) in deleted_existing: - return Q(pk__in=()), set() - filter_value = str(filter_value) - return self._filter_values(q, field_name, lambda val: str(val) == filter_value) - - if filter_type == 'isnull': - # foreign_obj__isnull - return self._filter_values(q, field_name, lambda val: (val is None) is filter_value) - - # so… is this a multi-level-lookup? - try: - rel_model._meta.get_field(filter_type) - except Exception: - raise NotImplementedError('Unsupported lookup or %s has no field "%s".' % (rel_model, filter_type)) - - # multi-level-lookup - subkwargs = {'__'.join([filter_type] + segments): filter_value} - cache_key = '%s:multilevellookup:%s:%s:%s' % ( - self._changeset.cache_key_by_changes, - rel_model.__name__, - next(iter(subkwargs.keys())), - base64.b64encode(repr(filter_value).encode()).decode() - ) - pk_values = cache.get(cache_key, None) - if pk_values is None: - pk_values = self._changeset.wrap_model(rel_model).objects.filter( - **subkwargs - ).values_list('pk', flat=True) - cache.set(cache_key, pk_values, 300) - q = Q(**{field_name + '__pk__in': tuple(pk for pk in pk_values if not is_created_pk(pk))}) - pk_values = set(str(pk) for pk in pk_values) - return self._filter_values(q, field_name, lambda val: str(val) in pk_values) - - # check if we are filtering by a many to many field - if field.many_to_many: - if not segments: - # turn 'm2m' into 'm2m__pk' for later - filter_name = field_name + '__pk' - filter_value = None if filter_value is None else filter_value.pk - segments = ['pk'] - q = Q(**{filter_name: filter_value}) - - filter_type = segments.pop(0) - - if not segments and filter_type == 'in': - # turn 'm2m__in' into 'm2m__pk__in' for later - filter_name = field_name+'__pk__in' - filter_value = tuple(obj.pk for obj in filter_value) - filter_type = 'pk' - segments = ['in'] - q = Q(**{filter_name: filter_value}) - - if filter_type == field.related_model._meta.pk.name: - # turn into pk for later - filter_type = 'pk' - - if filter_type == 'pk' and segments == ['in']: - # m2m__pk__in - if field.concrete: - # we don't do this in reverse - raise NotImplementedError - - # so... e.g. we want to get all groups that belong to one of the given spaces. - # field_name would be "spaces" - rel_model = field.related_model # space - rel_name = field.field.name # groups - filter_value = set(filter_value) # space pks - filter_value_existing = set(pk for pk in filter_value if not is_created_pk(pk)) - - # lets removeall spaces that have been deleted - filter_value = (pk for pk in filter_value if pk is not None) - filter_value = tuple(pk for pk in filter_value - if is_created_pk(pk) or - int(pk) not in self._changeset.deleted_existing.get(rel_model, ())) - - # get spaces that we are interested about that had groups added or removed - m2m_added = {pk: val[rel_name] for pk, val in self._changeset.m2m_added.get(rel_model, {}).items() - if pk in filter_value and rel_name in val} - m2m_removed = {pk: val[rel_name] for pk, val in self._changeset.m2m_removed.get(rel_model, {}).items() - if pk in filter_value and rel_name in val} # can only be existing spaces - - # directly lookup groups for spaces that had no groups removed - q = Q(**{field_name+'__pk__in': filter_value_existing - set(m2m_removed.keys())}) - - # lookup groups for spaces that had groups removed - for pk, values in m2m_removed.items(): - q |= Q(Q(**{field_name+'__pk': pk}) & ~Q(pk__in=values)) - - # get pk of groups that were added to any of the spaces - r_added_pks = reduce(operator.or_, m2m_added.values(), set()) - - # lookup existing groups that were added to any of the spaces - q |= Q(pk__in=tuple(pk for pk in r_added_pks if not is_created_pk(pk))) - - # get created groups that were added to any of the spaces - created_pks = set(pk for pk in r_added_pks if is_created_pk(pk)) - - return q, created_pks - - if segments: - # we don't to multi-level lookups - raise NotImplementedError - - if filter_type == 'pk': - # m2m__pk - if not field.concrete: - rel_model = field.related_model - - def get_changeset_m2m(items): - return items.get(rel_model, {}).get(filter_value, {}).get(field.field.name, ()) - - remove_pks = get_changeset_m2m(self._changeset.m2m_removed) - add_pks = get_changeset_m2m(self._changeset.m2m_added) - - if is_created_pk(filter_value): - pks = add_pks - return (Q(pk__in=(pk for pk in pks if not is_created_pk(pk))), - set(pk for pk in pks if is_created_pk(pk))) - - if filter_value is None or int(filter_value) in self._changeset.deleted_existing.get(rel_model, ()): - return Q(pk__in=()), set() - - return (((q & ~Q(pk__in=(pk for pk in remove_pks if not is_created_pk(pk)))) | - Q(pk__in=(pk for pk in add_pks if not is_created_pk(pk)))), - set(pk for pk in add_pks if is_created_pk(pk))) - - # sorry, no reverse lookup - raise NotImplementedError - - raise NotImplementedError - - # check if field is a deffered attribute, e.g. a CharField - if not field.is_relation: - if not segments: - # field= - return self._filter_values(q, field_name, lambda val: val == filter_value) - - filter_type = segments.pop(0) - - if not filter_type: - raise ValueError('Invalid filter: '+filter_name) - - if segments: - # we don't to field__whatever__whatever - raise NotImplementedError - - if filter_type == 'in': - # field__in - return self._filter_values(q, field_name, lambda val: val in filter_value) - - if filter_type == 'lt': - # field__lt - return self._filter_values(q, field_name, lambda val: val < filter_value) - - if filter_type == 'isnull': - # field__isnull - return self._filter_values(q, field_name, lambda val: (val is None) is filter_value) - - raise NotImplementedError - - raise NotImplementedError('cannot filter %s by %s (%s)' % (model, filter_name, field)) - - def _filter_q(self, q): - """ - filter by Q object. - Split it up into recursive _filter_q and _filter_kwarg calls and combine them again. - :return: new Q object and set of matched existing pks - """ - if not q.children: - return q, self._get_initial_created_pks() - filters, created_pks = zip(*((self._filter_q(c) if isinstance(c, Q) else self._filter_kwarg(*c)) - for c in q.children)) - result = Q(*filters) - result.connector = q.connector - result.negated = q.negated - - created_pks = reduce(operator.and_ if q.connector == 'AND' else operator.or_, created_pks) - if q.negated: - created_pks = self._get_initial_created_pks()-created_pks - return result, created_pks - - def _filter_or_exclude(self, negate, *args, **kwargs): - if not args and not kwargs: - return self._wrap_queryset(self._obj.filter()) - filters, created_pks = zip(*tuple(chain( - tuple(self._filter_q(q) for q in args), - tuple(self._filter_kwarg(name, value) for name, value in kwargs.items()) - ))) - - created_pks = reduce(operator.and_, created_pks) - if negate: - filters = (~Q(*filters), ) - created_pks = self._get_initial_created_pks()-created_pks - return self._wrap_queryset(self._obj.filter(*filters), created_pks=(self._created_pks & created_pks)) - - @get_queryset - def filter(self, *args, **kwargs): - return self._filter_or_exclude(False, *args, **kwargs) - - @get_queryset - def exclude(self, *args, **kwargs): - return self._filter_or_exclude(True, *args, **kwargs) - - @get_queryset - def count(self): - return self._obj.count()+len(tuple(self._get_created_objects(get_foreign_objects=False))) - - @get_queryset - def values_list(self, *args, flat=False): - own_values = (tuple(getattr(obj, arg) for arg in args) for obj in self._get_created_objects()) - if flat: - own_values = (v[0] for v in own_values) - return tuple(chain( - self._obj.values_list(*args, flat=flat), - own_values, - )) - - @get_queryset - def first(self): - if self._created_pks: - return next(self._get_created_objects()) - first = self._obj.first() - if first is not None: - first = self._wrap_instance(first) - return first - - @get_queryset - def using(self, alias): - return self._wrap_queryset(self._obj.using(alias)) - - @get_queryset - def extra(self, select): - """ - We only support the kind of extra() call that a many to many prefetch_related does. - """ - for key in select.keys(): - if not key.startswith('_prefetch_related_val'): - raise NotImplementedError('extra() calls are only supported for prefetch_related!') - return self._wrap_queryset(self._obj.extra(select), add_extra=tuple(select.keys())) - - @get_queryset - def _next_is_sticky(self): - """ - Needed by prefetch_related. - """ - return self._wrap_queryset(self._obj._next_is_sticky()) - - def _add_hints(self, *args, **kwargs): - return self._obj._add_hints(*args, **kwargs) - - def get_prefetch_queryset(self, *args, **kwargs): - return self._obj.get_prefetch_queryset(*args, **kwargs) - - def get_prefetch_querysets(self, *args, **kwargs): - return self._obj.get_prefetch_querysets(*args, **kwargs) - - def _apply_rel_filters(self, *args, **kwargs): - return self._obj._apply_rel_filters(*args, **kwargs) - - def create(self, *args, **kwargs): - obj = self.model(*args, **kwargs) - obj.save() - return obj - - -class ManagerWrapper(BaseQueryWrapper): - """ - Wraps a manager. - This class itself is used to wrap Model.objects managers. - """ - def get_queryset(self): - """ - make sure that the database does not return objects that have been deleted in this changeset - """ - qs = self._wrap_queryset(self._obj.model.objects.all()) - return qs.exclude(pk__in=tuple(chain(*(self._changeset.deleted_existing.get(submodel, ()) - for submodel in get_submodels(self._obj.model))))) - - def delete(self): - self.get_queryset().delete() - - -class RelatedManagerWrapper(ManagerWrapper): - """ - Wraps a related manager. - """ - def _get_cache_name(self): - """ - get cache name to fetch prefetch_related results - """ - return self._obj.field.related_query_name() - - def get_queryset(self): - """ - filter queryset by related manager filters - """ - return super().get_queryset().filter(**self._obj.core_filters) - - def all(self): - """ - get prefetched result if it exists - """ - try: - return self.instance._prefetched_objects_cache[self._get_cache_name()] - except (AttributeError, KeyError): - pass - return super().all() - - def create(self, *args, **kwargs): - if self.instance.pk is None: - raise TypeError - kwargs[self._obj.field.name] = self.instance - super().create(*args, **kwargs) - - -class ManyRelatedManagerWrapper(RelatedManagerWrapper): - """ - Wraps a many related manager (see RelatedManagerWrapper for details) - """ - def _check_through(self): - if not self._obj.through._meta.auto_created: - raise AttributeError('Cannot do this an a ManyToManyField which specifies an intermediary model.') - - def _get_cache_name(self): - return self._obj.prefetch_cache_name - - def set(self, objs): - if self._obj.reverse: - raise NotImplementedError - pks = set((obj.pk if isinstance(obj, models.Model) else obj) for obj in objs) - self._changeset.get_changed_object(self._obj.instance).m2m_set(self._get_cache_name(), pks) - - def add(self, *objs): - if self._obj.reverse: - raise NotImplementedError - pks = set((obj.pk if isinstance(obj, self._obj.model) else obj) for obj in objs) - self._changeset.get_changed_object(self._obj.instance).m2m_add(self._get_cache_name(), pks) - - def remove(self, *objs): - if self._obj.reverse: - raise NotImplementedError - pks = set((obj.pk if isinstance(obj, self._obj.model) else obj) for obj in objs) - self._changeset.get_changed_object(self._obj.instance).m2m_remove(self._get_cache_name(), pks) - - def all(self): - try: - return self.instance._prefetched_objects_cache[self._get_cache_name()] - except (AttributeError, KeyError): - pass - return super().all() - - def create(self, *args, **kwargs): - raise NotImplementedError - - -class QuerySetWrapper(BaseQueryWrapper): - """ - Wraps a queryset. - """ - def _get_created_objects(self, get_foreign_objects=True): - """ - Get ModelInstanceWrapper instance for all matched created objects. - """ - return (self._changeset.get_created_object(self._obj.model, pk, get_foreign_objects=get_foreign_objects) - for pk in sorted(self._created_pks)) - - def _ordering_key_func(self, ordering): - def key_func(obj): - result = [] - for field in ordering: - fact = -1 if field[0] == '-' else 1 - field = field.lstrip('-') - - field_split = field.split('__') - field = field_split.pop() - final_obj = obj - for subfield in field_split: - final_obj = getattr(final_obj, subfield) - - val = getattr(obj, field) - if field in ('id', 'pk'): - if isinstance(val, int): - result.extend((1*fact, val*fact)) - else: - result.extend((2*fact, int(val[1:])*fact)) - else: - result.append(val * fact) - return tuple(result) - return key_func - - def _get_cached_result(self): - """ - Get results, make sure prefetch is prefetching and so on. - """ - obj = self._obj - obj._prefetch_done = True - obj._fetch_all() - - result = [self._wrap_instance(instance) for instance in obj._result_cache] - obj._result_cache = result - obj._prefetch_done = False - obj._fetch_all() - - result += list(self._get_created_objects()) - - ordering = self._obj.query.order_by - if ordering: - result = sorted(result, key=self._ordering_key_func(ordering)) - - for extra in self._extra: - # implementing the extra() call for prefetch_related - ex = extra[22:] - for f in self._obj.model._meta.get_fields(): - if isinstance(f, ManyToManyRel) and f.through._meta.get_field(f.field.m2m_field_name()).attname == ex: - objs_by_pk = OrderedDict() - for instance in result: - objs_by_pk.setdefault(instance.pk, OrderedDict())[getattr(instance, extra, None)] = instance - - m2m_added = self._changeset.m2m_added.get(f.field.model, {}) - m2m_removed = self._changeset.m2m_removed.get(f.field.model, {}) - for related_pk, changes in m2m_added.items(): - for pk in changes.get(f.field.name, ()): - if pk in objs_by_pk and related_pk not in objs_by_pk[pk]: - new = self._wrap_instance(next(iter(objs_by_pk[pk].values()))._obj) - new.__dict__[extra] = related_pk - objs_by_pk[pk][related_pk] = new - - for related_pk, changes in m2m_removed.items(): - for pk in changes.get(f.field.name, ()): - if pk in objs_by_pk and related_pk in objs_by_pk[pk]: - objs_by_pk[pk].pop(related_pk) - - for pk, instances in objs_by_pk.items(): - instances.pop(None, None) - - result = list(chain(*(instances.values() for instances in objs_by_pk.values()))) - break - else: - raise NotImplementedError('Cannot do extra() for ' + extra) - - obj._result_cache = result - return result - - @cached_property - def _cached_result(self): - return self._get_cached_result() - - @property - def _result_cache(self): - return self._cached_result - - @_result_cache.setter - def _result_cache(self, value): - # prefetch_related will try to set this property - # it has to overwrite our final result because it already contains the created objects - self.__dict__['_cached_result'] = value - - def __iter__(self): - return iter(self._cached_result) - - def iterator(self): - return iter(chain( - (self._wrap_instance(instance) for instance in self._obj.iterator()), - self._get_created_objects(), - )) - - def __len__(self): - return len(self._cached_result) - - def delete(self): - for obj in self: - obj.delete() - - @property - def _iterable_class(self): - return self._obj._iterable_class