diff --git a/src/c3nav/editor/apps.py b/src/c3nav/editor/apps.py index a7c1e22b..b89ae379 100644 --- a/src/c3nav/editor/apps.py +++ b/src/c3nav/editor/apps.py @@ -1,15 +1,18 @@ from django.apps import AppConfig from django.contrib.auth import user_logged_in -from django.db.models.signals import m2m_changed, post_delete, post_save +from django.db.models.signals import m2m_changed, post_delete, post_save, pre_save, pre_delete + +from c3nav.editor import changes class EditorConfig(AppConfig): name = 'c3nav.editor' def ready(self): - from c3nav.editor.models import ChangeSet from c3nav.editor.signals import set_changeset_author_on_login - post_save.connect(ChangeSet.object_changed_handler) - post_delete.connect(ChangeSet.object_changed_handler) - m2m_changed.connect(ChangeSet.object_changed_handler) + pre_save.connect(changes.handle_pre_change_instance) + pre_delete.connect(changes.handle_pre_change_instance) + post_save.connect(changes.handle_post_save) + post_delete.connect(changes.handle_post_delete) + m2m_changed.connect(changes.handle_m2m_changed) user_logged_in.connect(set_changeset_author_on_login) diff --git a/src/c3nav/editor/changes.py b/src/c3nav/editor/changes.py new file mode 100644 index 00000000..b70af8c8 --- /dev/null +++ b/src/c3nav/editor/changes.py @@ -0,0 +1,240 @@ +import datetime +import json +from contextlib import contextmanager +from dataclasses import dataclass, field +from typing import TypeAlias, Literal, Annotated, Union, Type, Any + +from django.core import serializers +from django.db import transaction +from django.db.models import Model +from django.db.models.fields.related import OneToOneField, ForeignKey +from django.utils import timezone +from pydantic import ConfigDict, Discriminator +from pydantic.fields import Field + +from c3nav.api.schema import BaseSchema + +try: + from asgiref.local import Local as LocalContext +except ImportError: + from threading import local as LocalContext + + +FieldValuesDict: TypeAlias = dict[str, Any] +ExistingOrCreatedID: TypeAlias = int # negative = temporary ID of created object + + +class ObjectReference(BaseSchema): + model_config = ConfigDict(frozen=True) + model: str + id: ExistingOrCreatedID + + @classmethod + def simple_from_instance(cls, instance: Model): + """ + This method will not convert the ID yet! + """ + return cls(model=instance._meta.model_name, id=instance.pk) + + +class BaseChange(BaseSchema): + obj: ObjectReference + datetime: Annotated[datetime.datetime, Field(default_factory=timezone.now)] + + +class CreateObjectChange(BaseChange): + type: Literal["create"] = "create" + fields: FieldValuesDict + + +class UpdateObjectChange(BaseChange): + type: Literal["update"] = "update" + fields: FieldValuesDict + + +class DeleteObjectChange(BaseChange): + type: Literal["delete"] = "delete" + + +class AddManyToManyChange(BaseSchema): + type: Literal["m2m_add"] = "m2m_add" + field: str + values: list[int] + + +class RemoveManyToManyChange(BaseSchema): + type: Literal["m2m_remove"] = "m2m_remove" + field: str + values: list[int] + + +class ClearManyToManyChange(BaseSchema): + type: Literal["m2m_clear"] = "m2m_clear" + field: str + + +ChangeSetChange = Annotated[ + Union[ + CreateObjectChange, + UpdateObjectChange, + DeleteObjectChange, + AddManyToManyChange, + RemoveManyToManyChange, + ClearManyToManyChange, + ], + Discriminator("type"), +] + + +class ChangeSetChanges(BaseSchema): + prev_reprs: dict[ObjectReference, str] = {} + prev_values: dict[ObjectReference, FieldValuesDict] = {} + prev_m2m: dict[ObjectReference, dict[str, list[int]]] = {} + changes: list[ChangeSetChange] = [] + + +overlay_state = LocalContext() + + +class InterceptAbortTransaction(Exception): + pass + + +@contextmanager +def enable_changeset_overlay(changeset): + try: + with transaction.atomic(): + manager = ChangesetOverlayManager(changeset.changes) + overlay_state.manager = manager + # todo: apply changes so far + yield + raise InterceptAbortTransaction + except InterceptAbortTransaction: + pass + finally: + overlay_state.manager = None + + +@dataclass +class ChangesetOverlayManager: + changes: ChangeSetChanges + new_changes: bool = False + pre_change_values: dict[ObjectReference, FieldValuesDict] = field(default_factory=dict) + + # maps negative IDs of created objects to the ID during the current transaction + mapped_ids: dict[ObjectReference, int] = field(default_factory=dict) + + # maps IDs as used during the current transaction to the negative IDs + reverse_mapped_ids: dict[ObjectReference, int] = field(default_factory=dict) + + def ref_lookup(self, ref: ObjectReference): + local_value = self.mapped_ids.get(ref, None) + return ref if local_value is None else ObjectReference(model=ref.model, id=local_value) + + def reverse_ref_lookup(self, ref: ObjectReference): + created_value = self.reverse_mapped_ids.get(ref, None) + return ref if created_value is None else ObjectReference(model=ref.model, id=created_value) + + def get_model_field_values(self, instance: Model) -> FieldValuesDict: + values = json.loads(serializers.serialize("json", [instance]))[0]["fields"] + for field in instance._meta.get_fields(): + if field.name not in values: + continue + if isinstance(field, (OneToOneField, ForeignKey)): + value = values[field.name] + if value is not None: + values[field.name] = self.reverse_ref_lookup( + ObjectReference(model=field.model._meta.model_name, id=value) + ).id + return values + + def handle_pre_change_instance(self, sender: Type[Model], instance: Model, **kwargs): + if instance.pk is None: + return + ref = ObjectReference.simple_from_instance(instance) + if ref in self.reverse_mapped_ids: + return + if ref not in self.pre_change_values and ref not in self.changes.prev_values: + self.pre_change_values[ref] = self.get_model_field_values( + instance._meta.model.objects.get(pk=instance.pk) + ) + + def handle_post_save(self, sender: Type[Model], instance: Model, created: bool, + update_fields: set | None, **kwargs): + field_values = self.get_model_field_values(instance) + + if created: + created_id = min([change.obj.id for change in self.changes + if isinstance(change, CreateObjectChange)], default=0)-1 + model_name = instance._meta.model_name + self.mapped_ids[ObjectReference(model=model_name, id=created_id)] = instance.pk + self.reverse_mapped_ids[ObjectReference(model=model_name, id=instance.pk)] = created_id + self.changes.changes.append(CreateObjectChange( + obj=ObjectReference(model=instance._meta.model_name, id=created_id), + fields=field_values + )) + from pprint import pprint + pprint(self.changes) + return + + if update_fields: + field_values = {name: value for name, value in field_values.items() if name in update_fields} + + ref = self.reverse_ref_lookup(ObjectReference.simple_from_instance(instance)) + pre_change_values = self.pre_change_values.pop(ref, None) + if pre_change_values: + self.changes.prev_values[ref] = pre_change_values + self.changes.prev_reprs[ref] = str(instance) + self.changes.changes.append(UpdateObjectChange( + obj=ref, + fields=field_values + )) + from pprint import pprint + pprint(self.changes) + + def handle_post_delete(self, sender: Type[Model], instance: Model, **kwargs): + ref = self.reverse_ref_lookup(ObjectReference.simple_from_instance(instance)) + pre_change_values = self.pre_change_values.pop(ref, None) + if pre_change_values: + self.changes.prev_values[ref] = pre_change_values + self.changes.prev_reprs[ref] = str(instance) + self.changes.changes.append(DeleteObjectChange( + obj=ref, + )) + from pprint import pprint + pprint(self.changes) + + def handle_m2m_changed(self, sender: Type[Model], instance: Model, **kwargs): + pass + + +def handle_pre_change_instance(sender: Type[Model], **kwargs): + if sender._meta.app_label != 'mapdata': + return + manager: ChangesetOverlayManager = getattr(overlay_state, 'manager', None) + if manager: + manager.handle_pre_change_instance(sender=sender, **kwargs) + + +def handle_post_save(sender: Type[Model], **kwargs): + if sender._meta.app_label != 'mapdata': + return + manager: ChangesetOverlayManager = getattr(overlay_state, 'manager', None) + if manager: + manager.handle_post_save(sender=sender, **kwargs) + + +def handle_post_delete(sender: Type[Model], **kwargs): + if sender._meta.app_label != 'mapdata': + return + manager: ChangesetOverlayManager = getattr(overlay_state, 'manager', None) + if manager: + manager.handle_post_delete(sender=sender, **kwargs) + + +def handle_m2m_changed(sender: Type[Model], **kwargs): + if sender._meta.app_label != 'mapdata': + return + manager: ChangesetOverlayManager = getattr(overlay_state, 'manager', None) + if manager: + manager.handle_m2m_changed(sender=sender, **kwargs) diff --git a/src/c3nav/editor/intercept.py b/src/c3nav/editor/intercept.py deleted file mode 100644 index acf88ac1..00000000 --- a/src/c3nav/editor/intercept.py +++ /dev/null @@ -1,29 +0,0 @@ -from contextlib import contextmanager -from functools import wraps - -from django.db import transaction - -from c3nav.editor.models import ChangeSet - -try: - from asgiref.local import Local as LocalContext -except ImportError: - from threading import local as LocalContext - - -intercept = LocalContext() - - -class InterceptAbortTransaction(Exception): - pass - - -@contextmanager -def enable_changeset_overlay(changeset): - try: - with transaction.atomic(): - # todo: apply changes so far - yield - raise InterceptAbortTransaction - except InterceptAbortTransaction: - pass diff --git a/src/c3nav/editor/migrations/0004_changeset_rewrite_2024.py b/src/c3nav/editor/migrations/0004_changeset_rewrite_2024.py index 32c96a62..31cbbe85 100644 --- a/src/c3nav/editor/migrations/0004_changeset_rewrite_2024.py +++ b/src/c3nav/editor/migrations/0004_changeset_rewrite_2024.py @@ -1,10 +1,12 @@ # Generated by Django 5.0.8 on 2024-08-22 17:03 - +import c3nav.editor.changes import c3nav.editor.models.changeset import django.core.serializers.json import django_pydantic_field.fields from django.db import migrations +import c3nav.editor.changes + class Migration(migrations.Migration): @@ -20,7 +22,7 @@ class Migration(migrations.Migration): migrations.AddField( model_name='changeset', name='changes', - field=django_pydantic_field.fields.PydanticSchemaField(config=None, default=c3nav.editor.models.changeset.ChangeSetChanges, encoder=django.core.serializers.json.DjangoJSONEncoder, schema=c3nav.editor.models.changeset.ChangeSetChanges), + field=django_pydantic_field.fields.PydanticSchemaField(config=None, default=c3nav.editor.changes.ChangeSetChanges, encoder=django.core.serializers.json.DjangoJSONEncoder, schema=c3nav.editor.changes.ChangeSetChanges), ), migrations.DeleteModel( name='ChangedObject', diff --git a/src/c3nav/editor/models/changeset.py b/src/c3nav/editor/models/changeset.py index b7f530be..06bd68b4 100644 --- a/src/c3nav/editor/models/changeset.py +++ b/src/c3nav/editor/models/changeset.py @@ -1,8 +1,5 @@ -import datetime from collections import OrderedDict from contextlib import contextmanager -from enum import StrEnum -from typing import Literal, TypeAlias, Union, Annotated from django.apps import apps from django.conf import settings @@ -15,92 +12,13 @@ from django.utils.timezone import make_naive from django.utils.translation import gettext_lazy as _ from django.utils.translation import ngettext_lazy from django_pydantic_field import SchemaField -from pydantic.config import ConfigDict -from pydantic.fields import Field -from pydantic.types import Discriminator -from c3nav.api.schema import BaseSchema +from c3nav.editor.changes import ChangeSetChanges from c3nav.editor.tasks import send_changeset_proposed_notification from c3nav.mapdata.models import LocationSlug, MapUpdate from c3nav.mapdata.models.locations import LocationRedirect from c3nav.mapdata.utils.cache.changes import changed_geometries -FieldValuesDict: TypeAlias = dict[int, str] -ExistingOrCreatedID: TypeAlias = int # negative = temporary ID of created object - - -class ObjectReferenceType(StrEnum): - EXISTING = "existing" - CREATED = "created" - - -class ObjectReference(BaseSchema): - model_config = ConfigDict(frozen=True) - model: str - id: ExistingOrCreatedID - - -class BaseChange(BaseSchema): - obj: ObjectReference - datetime: datetime.datetime - - -class CreateObjectChange(BaseChange): - type: Literal["create"] - fields: FieldValuesDict - - -class UpdateObjectChange(BaseChange): - type: Literal["update"] - fields: FieldValuesDict - - -class DeleteObjectChange(BaseChange): - type: Literal["delete"] - - -class AddManyToManyChange(BaseSchema): - type: Literal["m2m_add"] - field: str - values: list[int] - - -class RemoveManyToManyChange(BaseSchema): - type: Literal["m2m_remove"] - field: str - values: list[int] - - -class ClearManyToManyChange(BaseSchema): - type: Literal["m2m_clear"] - field: str - - -ChangeSetChange = Annotated[ - Union[ - CreateObjectChange, - UpdateObjectChange, - DeleteObjectChange, - AddManyToManyChange, - RemoveManyToManyChange, - ClearManyToManyChange, - ], - Discriminator("type"), -] - - -class ChangeSetChanges(BaseSchema): - prev_reprs: dict[ObjectReference, str] = {} - prev_values: dict[ObjectReference, FieldValuesDict] = {} - prev_m2m: dict[ObjectReference, dict[str, list[int]]] = {} - changes: list[ChangeSetChange] = [] - - # maps negative IDs of created objects to the ID during the current transaction - mapped_ids: Annotated[dict[ObjectReference, int], Field(exclude=True)] = {} - - # maps IDs as used during the current transaction to the negative IDs - reverse_mapped_ids: Annotated[dict[ObjectReference, int], Field(exclude=True)] = {} - class ChangeSet(models.Model): STATES = ( diff --git a/src/c3nav/editor/views/base.py b/src/c3nav/editor/views/base.py index 24f18e02..50d67f41 100644 --- a/src/c3nav/editor/views/base.py +++ b/src/c3nav/editor/views/base.py @@ -15,7 +15,7 @@ from django.utils.cache import patch_vary_headers from django.utils.translation import get_language from django.utils.translation import gettext_lazy as _ -from c3nav.editor.intercept import enable_changeset_overlay +from c3nav.editor.changes import enable_changeset_overlay from c3nav.editor.models import ChangeSet from c3nav.mapdata.models.access import AccessPermission from c3nav.mapdata.models.base import SerializableMixin