diff --git a/src/c3nav/editor/changes.py b/src/c3nav/editor/changes.py new file mode 100644 index 00000000..3ae4628e --- /dev/null +++ b/src/c3nav/editor/changes.py @@ -0,0 +1,82 @@ +from itertools import chain + +from django.apps import apps + +from c3nav.api.schema import BaseSchema +from c3nav.editor.operations import DatabaseOperationCollection, CreateObjectOperation, UpdateObjectOperation, \ + DeleteObjectOperation, ClearManyToManyOperation, FieldValuesDict, ObjectReference, PreviousObjectCollection +from c3nav.mapdata.fields import I18nField + + +class ChangedManyToMany(BaseSchema): + cleared: bool = False + added: list[str] = [] + removed: list[str] = [] + + +class ChangedObject(BaseSchema): + obj: ObjectReference + titles: dict[str, str] | None + created: bool = False + deleted: bool = False + fields: FieldValuesDict = {} + m2m_changes: dict[str, ChangedManyToMany] = {} + + +class ChangedObjectCollection(BaseSchema): + """ + A collection of ChangedObject instances, sorted by model and id. + Also stores a PreviousObjectCollection for comparison with the current state. + Iterable as a list of ChangedObject instances. + """ + prev: PreviousObjectCollection = PreviousObjectCollection() + objects: dict[str, dict[int, ChangedObject]] = {} + + def __iter__(self): + yield from chain(*(objects.keys() for model, objects in self.objects.items())) + + def add_operations(self, operations: DatabaseOperationCollection): + """ + Add the given operations, creating/updating changed objects to represent the resulting state. + """ + # todo: merge prev + for operation in operations.operations: + changed_object = self.objects.setdefault(operation.obj.model, {}).get(operation.obj.id, None) + if changed_object is None: + changed_object = ChangedObject(obj=operation.obj, + titles=self.prev.get(operation.obj).titles) + self.objects[operation.obj.model][operation.obj.id] = changed_object + if isinstance(operation, CreateObjectOperation): + changed_object.created = True + changed_object.fields.update(operation.fields) + elif isinstance(operation, UpdateObjectOperation): + model = apps.get_model('mapdata', operation.obj.model) + for field_name, value in operation.fields.items(): + field = model._meta.get_field(field_name) + if isinstance(field, I18nField) and field_name in changed_object.fields: + changed_object.fields[field_name] = { + lang: val for lang, val in {**changed_object.fields[field_name], **value}.items() + } + else: + changed_object.fields[field_name] = value + elif isinstance(operation, DeleteObjectOperation): + changed_object.deleted = False + else: + changed_m2m = changed_object.m2m_changes.get(operation.field, None) + if changed_m2m is None: + changed_m2m = ChangedManyToMany() + changed_object.m2m_changes[operation.field] = changed_m2m + if isinstance(operation, ClearManyToManyOperation): + changed_m2m.cleared = True + changed_m2m.added = [] + changed_m2m.removed = [] + else: + changed_m2m.added = sorted((set(changed_m2m.added) | operation.add_values) + - operation.remove_values) + changed_m2m.removed = sorted((set(changed_m2m.removed) - operation.add_values) + | operation.remove_values) + + @property + def as_operations(self) -> DatabaseOperationCollection: + + pass # todo: implement \ No newline at end of file diff --git a/src/c3nav/editor/migrations/0004_changeset_rewrite_2024.py b/src/c3nav/editor/migrations/0004_changeset_rewrite_2024.py index 2e3ab351..7d217349 100644 --- a/src/c3nav/editor/migrations/0004_changeset_rewrite_2024.py +++ b/src/c3nav/editor/migrations/0004_changeset_rewrite_2024.py @@ -1,6 +1,6 @@ -# Generated by Django 5.0.8 on 2024-08-26 09:46 +# Generated by Django 5.0.8 on 2024-09-26 10:06 -import c3nav.editor.operations +import c3nav.editor.changes import django.core.serializers.json import django_pydantic_field.fields from django.db import migrations @@ -20,7 +20,7 @@ class Migration(migrations.Migration): migrations.AddField( model_name='changeset', name='changes', - field=django_pydantic_field.fields.PydanticSchemaField(config=None, default=c3nav.editor.operations.CollectedOperations, encoder=django.core.serializers.json.DjangoJSONEncoder, schema=c3nav.editor.operations.CollectedOperations), + field=django_pydantic_field.fields.PydanticSchemaField(config=None, default=c3nav.editor.changes.ChangedObjectCollection, encoder=django.core.serializers.json.DjangoJSONEncoder, schema=c3nav.editor.changes.ChangedObjectCollection), ), migrations.DeleteModel( name='ChangedObject', diff --git a/src/c3nav/editor/models/changeset.py b/src/c3nav/editor/models/changeset.py index 9f6146a8..bf0f9bc7 100644 --- a/src/c3nav/editor/models/changeset.py +++ b/src/c3nav/editor/models/changeset.py @@ -12,7 +12,7 @@ from django.utils.translation import gettext_lazy as _ from django.utils.translation import ngettext_lazy from django_pydantic_field import SchemaField -from c3nav.editor.operations import CollectedOperations +from c3nav.editor.changes import ChangedObjectCollection from c3nav.editor.tasks import send_changeset_proposed_notification from c3nav.mapdata.models import LocationSlug, MapUpdate from c3nav.mapdata.models.locations import LocationRedirect @@ -43,7 +43,7 @@ class ChangeSet(models.Model): related_name='assigned_changesets', verbose_name=_('assigned to')) map_update = models.OneToOneField(MapUpdate, null=True, related_name='changeset', verbose_name=_('map update'), on_delete=models.PROTECT) - changes: CollectedOperations = SchemaField(schema=CollectedOperations, default=CollectedOperations) + changes: ChangedObjectCollection = SchemaField(schema=ChangedObjectCollection, default=ChangedObjectCollection) class Meta: verbose_name = _('Change Set') diff --git a/src/c3nav/editor/operations.py b/src/c3nav/editor/operations.py index 3a7ad05c..5a86692d 100644 --- a/src/c3nav/editor/operations.py +++ b/src/c3nav/editor/operations.py @@ -1,15 +1,14 @@ -import copy import datetime import json from dataclasses import dataclass -from typing import TypeAlias, Any, Annotated, Literal, Union +from typing import Annotated, Literal, Union, TypeAlias, Any from uuid import UUID, uuid4 from django.apps import apps from django.core import serializers from django.db.models import Model from django.utils import timezone -from pydantic.config import ConfigDict +from pydantic import ConfigDict from pydantic.fields import Field from pydantic.types import Discriminator @@ -17,10 +16,14 @@ from c3nav.api.schema import BaseSchema from c3nav.mapdata.fields import I18nField from c3nav.mapdata.models import LocationSlug + FieldValuesDict: TypeAlias = dict[str, Any] class ObjectReference(BaseSchema): + """ + Reference to an object based on model name and ID. + """ model_config = ConfigDict(frozen=True) model: str id: int @@ -31,25 +34,42 @@ class ObjectReference(BaseSchema): class PreviousObject(BaseSchema): + """ + Represents the previous state of an objects, consisting of its values and its (multi-language) titles + """ titles: dict[str, str] | None values: FieldValuesDict -class PreviousObjects(BaseSchema): +class PreviousObjectCollection(BaseSchema): objects: dict[str, dict[int, PreviousObject]] = {} def get(self, ref: ObjectReference) -> PreviousObject | None: return self.objects.get(ref.model, {}).get(ref.id, None) + def get_ids(self) -> dict[str, set[int]]: + """ + :return: all referenced IDs sorted by model + """ + return {model: set(objs.keys()) for model, objs in self.objects.items()} + + def get_instances(self) -> dict[ObjectReference, Model]: + """ + :return: all reference objects as fetched from the databse right now + """ + instances: dict[ObjectReference, Model] = {} + for model_name, ids in self.get_ids().items(): + model = apps.get_model("mapdata", model_name) + instances.update(dict((ObjectReference(model=model_name, id=instance.pk), instance) + for instance in model.objects.filter(pk__in=ids))) + return instances + def set(self, ref: ObjectReference, values: FieldValuesDict, titles: dict | None): self.objects.setdefault(ref.model, {})[ref.id] = PreviousObject( values=values, titles=titles, ) - def get_ids(self) -> dict[str, set[int]]: - return {model: set(objs.keys()) for model, objs in self.objects.items()} - class BaseOperation(BaseSchema): obj: ObjectReference @@ -155,11 +175,6 @@ class ClearManyToManyOperation(BaseOperation): return instance -class RevertOperation(BaseOperation): - type: Literal["revert"] = "revert" - reverts: UUID - - DatabaseOperation = Annotated[ Union[ CreateObjectOperation, @@ -172,94 +187,34 @@ DatabaseOperation = Annotated[ ] -class ChangedManyToMany(BaseSchema): - cleared: bool = False - added: list[str] = [] - removed: list[str] = [] - - -class ChangedObject(BaseSchema): - obj: ObjectReference - titles: dict[str, str] | None - created: bool = False - deleted: bool = False - fields: FieldValuesDict = {} - m2m_changes: dict[str, ChangedManyToMany] = {} - - -class CollectedOperations(BaseSchema): - uuid: UUID = Field(default_factory=uuid4) - prev: PreviousObjects = PreviousObjects() +class DatabaseOperationCollection(BaseSchema): + """ + A collection of database operations, sorted by model and id. + Also stores a PreviousObjectCollection for comparison with the current state. + Iterable as a list of DatabaseOperation instances. + """ + prev: PreviousObjectCollection = PreviousObjectCollection() operations: list[DatabaseOperation] = [] - def prefetch(self) -> "CollectedChangesPrefetch": - ids_to_query: dict[str, set[int]] = self.prev.get_ids() + def __iter__(self): + yield from self.operations - instances: dict[ObjectReference, Model] = {} - for model_name, ids in ids_to_query.items(): - model = apps.get_model("mapdata", model_name) - instances.update(dict((ObjectReference(model=model_name, id=instance.pk), instance) - for instance in model.objects.filter(pk__in=ids))) - - return CollectedChangesPrefetch(changes=self, instances=instances) - - @property - def changed_objects(self) -> list[ChangedObject]: - objects = {} - reverted_uuids = frozenset(operation.reverts for operation in self.operationsy - if isinstance(operation, RevertOperation)) - for operation in self.operations: - if operation.uuid in reverted_uuids: - continue - changed_object = objects.get(operation.obj, None) - if changed_object is None: - changed_object = ChangedObject(obj=operation.obj, - titles=self.prev_titles[operation.obj.model][operation.obj.id]) - objects[operation.obj] = changed_object - if isinstance(operation, CreateObjectOperation): - changed_object.created = True - changed_object.fields.update(operation.fields) - elif isinstance(operation, UpdateObjectOperation): - model = apps.get_model('mapdata', operation.obj.model) - for field_name, value in operation.fields.items(): - field = model._meta.get_field(field_name) - if isinstance(field, I18nField) and field_name in changed_object.fields: - changed_object.fields[field_name] = { - lang: val for lang, val in {**changed_object.fields[field_name], **value}.items() - } - else: - changed_object.fields[field_name] = value - elif isinstance(operation, DeleteObjectOperation): - changed_object.deleted = False - else: - changed_m2m = changed_object.m2m_changes.get(operation.field, None) - if changed_m2m is None: - changed_m2m = ChangedManyToMany() - changed_object.m2m_changes[operation.field] = changed_m2m - if isinstance(operation, ClearManyToManyOperation): - changed_m2m.cleared = True - changed_m2m.added = [] - changed_m2m.removed = [] - else: - changed_m2m.added = sorted((set(changed_m2m.added) | operation.add_values) - - operation.remove_values) - changed_m2m.removed = sorted((set(changed_m2m.removed) - operation.add_values) - | operation.remove_values) - return list(objects.values()) + def prefetch(self) -> "PrefetchedDatabaseOperationCollection": + return PrefetchedDatabaseOperationCollection(operations=self, instances=self.prev.get_instances()) @dataclass -class CollectedChangesPrefetch: - changes: CollectedOperations +class PrefetchedDatabaseOperationCollection: + operations: DatabaseOperationCollection instances: dict[ObjectReference, Model] def apply(self): # todo: what if unique constraint error occurs? - for operation in self.changes.operations: + for operation in self.operations.operations: if isinstance(operation, CreateObjectOperation): self.instances[operation.obj] = operation.apply_create() else: - prev_obj = self.changes.prev.get(operation.obj) + prev_obj = self.operations.prev.get(operation.obj) if prev_obj is None: print('WARN WARN WARN') values = prev_obj.values @@ -271,5 +226,4 @@ class CollectedChangesPrefetch: else: instance = None if instance is not None: - operation.apply(values=values, instance=instance) - + operation.apply(values=values, instance=instance) \ No newline at end of file diff --git a/src/c3nav/editor/overlay.py b/src/c3nav/editor/overlay.py index ee846d85..881adc71 100644 --- a/src/c3nav/editor/overlay.py +++ b/src/c3nav/editor/overlay.py @@ -1,3 +1,4 @@ +import copy import json from contextlib import contextmanager from dataclasses import dataclass, field @@ -8,8 +9,10 @@ from django.db import transaction from django.db.models import Model from django.db.models.fields.related import ManyToManyField -from c3nav.editor.operations import DatabaseOperation, ObjectReference, FieldValuesDict, CreateObjectOperation, \ - UpdateObjectOperation, DeleteObjectOperation, ClearManyToManyOperation, UpdateManyToManyOperation, CollectedOperations +from c3nav.editor.changes import ChangedObjectCollection +from c3nav.editor.operations import DatabaseOperation, CreateObjectOperation, \ + UpdateObjectOperation, DeleteObjectOperation, ClearManyToManyOperation, UpdateManyToManyOperation, \ + DatabaseOperationCollection, FieldValuesDict, ObjectReference, PreviousObjectCollection from c3nav.mapdata.fields import I18nField from c3nav.mapdata.models import LocationSlug @@ -28,21 +31,31 @@ class InterceptAbortTransaction(Exception): @dataclass class DatabaseOverlayManager: - changes: CollectedOperations - new_operations: list[DatabaseOperation] = field(default_factory=list) - pre_change_values: dict[ObjectReference, FieldValuesDict] = field(default_factory=dict) + """ + This class handles the currently active database overlay and will apply and/or intercept changes. + """ + prev: PreviousObjectCollection = PreviousObjectCollection() + operations: list[DatabaseOperation] = field(default_factory=list) + pre_change_values: dict[ObjectReference, FieldValuesDict] = field(default_factory=dict, init=False, repr=False) @classmethod @contextmanager - def enable(cls, changes: CollectedOperations | None, commit: bool): - if getattr(overlay_state, 'manager', None) is not None: - raise TypeError - if changes is None: - changes = CollectedOperations() + def enable(cls, operations: DatabaseOperationCollection | None = None, commit: bool = False): + """ + Context manager to enable the database overlay, optionally pre-applying the given changes. + Only one overlay can be active at the same type, or else you get a TypeError. + + :param operations: what operations to pre-apply + :param commit: whether to actually commit operations to the database or revert them at the end + """ + if getattr(overlay_state, "manager", None) is not None: + raise TypeError("Only one overlay can be active at the same time") + if operations is None: + operations = DatabaseOperationCollection() try: with transaction.atomic(): - manager = DatabaseOverlayManager(changes) - manager.changes.prefetch().apply() + manager = DatabaseOverlayManager(prev=copy.deepcopy(operations.prev)) + operations.prefetch().apply() overlay_state.manager = manager yield manager if not commit: @@ -52,9 +65,6 @@ class DatabaseOverlayManager: finally: overlay_state.manager = None - def save_new_operations(self): - self.changes.operations.extend(self.new_operations) - @staticmethod def get_model_field_values(instance: Model) -> FieldValuesDict: values = json.loads(serializers.serialize("json", [instance]))[0]["fields"] @@ -66,14 +76,14 @@ class DatabaseOverlayManager: ref = ObjectReference.from_instance(instance) pre_change_values = self.pre_change_values.pop(ref) - self.changes.prev.set(ref, values=pre_change_values, titles=getattr(instance, 'titles', None)) + self.operations.prev.set(ref, values=pre_change_values, titles=getattr(instance, 'titles', None)) return ref, pre_change_values def handle_pre_change_instance(self, instance: Model, **kwargs): if instance.pk is None: return ref = ObjectReference.from_instance(instance) - if ref not in self.pre_change_values and self.changes.prev.get(ref) is None: + if ref not in self.pre_change_values and self.operations.prev.get(ref) is None: self.pre_change_values[ref] = self.get_model_field_values( instance._meta.model.objects.get(pk=instance.pk) ) @@ -84,7 +94,7 @@ class DatabaseOverlayManager: ref, pre_change_values = self.get_ref_and_pre_change_values(instance) if created: - self.new_operations.append(CreateObjectOperation(obj=ref, fields=field_values)) + self.operations.append(CreateObjectOperation(obj=ref, fields=field_values)) return if update_fields: @@ -105,11 +115,11 @@ class DatabaseOverlayManager: diff_val[lang] = after_val.get(lang, None) field_values[field_name] = diff_val - self.new_operations.append(UpdateObjectOperation(obj=ref, fields=field_values)) + self.operations.append(UpdateObjectOperation(obj=ref, fields=field_values)) def handle_post_delete(self, instance: Model, **kwargs): ref, pre_change_values = self.get_ref_and_pre_change_values(instance) - self.new_operations.append(DeleteObjectOperation(obj=ref)) + self.operations.append(DeleteObjectOperation(obj=ref)) def handle_m2m_changed(self, sender: Type[Model], instance: Model, action: str, model: Type[Model], pk_set: set | None, reverse: bool, **kwargs): @@ -128,11 +138,11 @@ class DatabaseOverlayManager: ref, pre_change_values = self.get_ref_and_pre_change_values(instance) if action == "post_clear": - self.new_operations.append(ClearManyToManyOperation(obj=ref, field=field.name)) + self.operations.append(ClearManyToManyOperation(obj=ref, field=field.name)) return - if self.new_operations: - last_change = self.new_operations[-1] + if self.operations: + last_change = self.operations[-1] if isinstance(last_change, UpdateManyToManyOperation) and last_change == ref and last_change == field.name: if action == "post_add": last_change.add_values.update(pk_set) @@ -143,9 +153,9 @@ class DatabaseOverlayManager: return if action == "post_add": - self.new_operations.append(UpdateManyToManyOperation(obj=ref, field=field.name, add_values=list(pk_set))) + self.operations.append(UpdateManyToManyOperation(obj=ref, field=field.name, add_values=list(pk_set))) else: - self.new_operations.append(UpdateManyToManyOperation(obj=ref, field=field.name, remove_values=list(pk_set))) + self.operations.append(UpdateManyToManyOperation(obj=ref, field=field.name, remove_values=list(pk_set))) def handle_pre_change_instance(sender: Type[Model], **kwargs): diff --git a/src/c3nav/editor/views/base.py b/src/c3nav/editor/views/base.py index c6509bda..6151e2bf 100644 --- a/src/c3nav/editor/views/base.py +++ b/src/c3nav/editor/views/base.py @@ -48,7 +48,7 @@ def accesses_mapdata(func): if request.changeset.direct_editing: with (MapUpdate.lock() if writable_method else noctx()): changed_geometries.reset() - with DatabaseOverlayManager.enable(changes=None, commit=writable_method) as manager: + with DatabaseOverlayManager.enable(operations=None, commit=writable_method) as manager: result = func(request, *args, **kwargs) if manager.new_operations: if writable_method: @@ -57,10 +57,11 @@ def accesses_mapdata(func): raise ValueError # todo: good error message, but this shouldn't happen else: with maybe_lock_changeset_to_edit(request=request): - with DatabaseOverlayManager.enable(changes=request.changeset.changes, commit=False) as manager: + operations = request.changeset.changes.as_operations # todo: cache this + with DatabaseOverlayManager.enable(operations=operations, commit=False) as manager: result = func(request, *args, **kwargs) - if manager.new_operations: - manager.save_new_operations() + if manager.operations: + request.changeset.changes.add_operations(manager.operations) request.changeset.save() update = request.changeset.updates.create(user=request.user, objects_changed=True) request.changeset.last_update = update diff --git a/src/c3nav/editor/views/changes.py b/src/c3nav/editor/views/changes.py index 3281003e..48918f35 100644 --- a/src/c3nav/editor/views/changes.py +++ b/src/c3nav/editor/views/changes.py @@ -193,7 +193,7 @@ def changeset_detail(request, pk): added_redirects = {} removed_redirects = {} - for changed_object in changeset.changes.changed_objects: + for changed_object in changeset.changes: if changed_object.obj.model == "locationredirect": if changed_object.created and not changed_object.deleted: added_redirects.setdefault(changed_object.fields["target"], set()).add(changed_object.fields["slug"]) @@ -205,7 +205,7 @@ def changeset_detail(request, pk): current_lang = get_language() - for changed_object in changeset.changes.changed_objects: + for changed_object in changeset.changes: model = apps.get_model("mapdata", changed_object.obj.model) if model == LocationRedirect: continue @@ -218,7 +218,7 @@ def changeset_detail(request, pk): else: title = next(iter(changed_object.titles.values())) - prev_values = changeset.changes.prev.get(changed_object.obj).values + prev_values = changeset.operations.prev.get(changed_object.obj).values edit_url = None if not changed_object.deleted: diff --git a/src/c3nav/editor/views/edit.py b/src/c3nav/editor/views/edit.py index 0bba20f7..4fb017f9 100644 --- a/src/c3nav/editor/views/edit.py +++ b/src/c3nav/editor/views/edit.py @@ -128,7 +128,7 @@ def space_detail(request, level, pk): def get_changeset_exceeded(request): - return request.user_permissions.max_changeset_changes <= len(request.changeset.changes.operations) + return request.user_permissions.max_changeset_changes <= len(request.changeset.operations.operations) @etag(editor_etag_func)