intercept save and delete

This commit is contained in:
Laura Klünder 2024-08-22 22:49:07 +02:00
parent 431372c4e4
commit 71789e7f41
6 changed files with 254 additions and 120 deletions

View file

@ -1,15 +1,18 @@
from django.apps import AppConfig
from django.contrib.auth import user_logged_in
from django.db.models.signals import m2m_changed, post_delete, post_save
from django.db.models.signals import m2m_changed, post_delete, post_save, pre_save, pre_delete
from c3nav.editor import changes
class EditorConfig(AppConfig):
name = 'c3nav.editor'
def ready(self):
from c3nav.editor.models import ChangeSet
from c3nav.editor.signals import set_changeset_author_on_login
post_save.connect(ChangeSet.object_changed_handler)
post_delete.connect(ChangeSet.object_changed_handler)
m2m_changed.connect(ChangeSet.object_changed_handler)
pre_save.connect(changes.handle_pre_change_instance)
pre_delete.connect(changes.handle_pre_change_instance)
post_save.connect(changes.handle_post_save)
post_delete.connect(changes.handle_post_delete)
m2m_changed.connect(changes.handle_m2m_changed)
user_logged_in.connect(set_changeset_author_on_login)

240
src/c3nav/editor/changes.py Normal file
View file

@ -0,0 +1,240 @@
import datetime
import json
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import TypeAlias, Literal, Annotated, Union, Type, Any
from django.core import serializers
from django.db import transaction
from django.db.models import Model
from django.db.models.fields.related import OneToOneField, ForeignKey
from django.utils import timezone
from pydantic import ConfigDict, Discriminator
from pydantic.fields import Field
from c3nav.api.schema import BaseSchema
try:
from asgiref.local import Local as LocalContext
except ImportError:
from threading import local as LocalContext
FieldValuesDict: TypeAlias = dict[str, Any]
ExistingOrCreatedID: TypeAlias = int # negative = temporary ID of created object
class ObjectReference(BaseSchema):
model_config = ConfigDict(frozen=True)
model: str
id: ExistingOrCreatedID
@classmethod
def simple_from_instance(cls, instance: Model):
"""
This method will not convert the ID yet!
"""
return cls(model=instance._meta.model_name, id=instance.pk)
class BaseChange(BaseSchema):
obj: ObjectReference
datetime: Annotated[datetime.datetime, Field(default_factory=timezone.now)]
class CreateObjectChange(BaseChange):
type: Literal["create"] = "create"
fields: FieldValuesDict
class UpdateObjectChange(BaseChange):
type: Literal["update"] = "update"
fields: FieldValuesDict
class DeleteObjectChange(BaseChange):
type: Literal["delete"] = "delete"
class AddManyToManyChange(BaseSchema):
type: Literal["m2m_add"] = "m2m_add"
field: str
values: list[int]
class RemoveManyToManyChange(BaseSchema):
type: Literal["m2m_remove"] = "m2m_remove"
field: str
values: list[int]
class ClearManyToManyChange(BaseSchema):
type: Literal["m2m_clear"] = "m2m_clear"
field: str
ChangeSetChange = Annotated[
Union[
CreateObjectChange,
UpdateObjectChange,
DeleteObjectChange,
AddManyToManyChange,
RemoveManyToManyChange,
ClearManyToManyChange,
],
Discriminator("type"),
]
class ChangeSetChanges(BaseSchema):
prev_reprs: dict[ObjectReference, str] = {}
prev_values: dict[ObjectReference, FieldValuesDict] = {}
prev_m2m: dict[ObjectReference, dict[str, list[int]]] = {}
changes: list[ChangeSetChange] = []
overlay_state = LocalContext()
class InterceptAbortTransaction(Exception):
pass
@contextmanager
def enable_changeset_overlay(changeset):
try:
with transaction.atomic():
manager = ChangesetOverlayManager(changeset.changes)
overlay_state.manager = manager
# todo: apply changes so far
yield
raise InterceptAbortTransaction
except InterceptAbortTransaction:
pass
finally:
overlay_state.manager = None
@dataclass
class ChangesetOverlayManager:
changes: ChangeSetChanges
new_changes: bool = False
pre_change_values: dict[ObjectReference, FieldValuesDict] = field(default_factory=dict)
# maps negative IDs of created objects to the ID during the current transaction
mapped_ids: dict[ObjectReference, int] = field(default_factory=dict)
# maps IDs as used during the current transaction to the negative IDs
reverse_mapped_ids: dict[ObjectReference, int] = field(default_factory=dict)
def ref_lookup(self, ref: ObjectReference):
local_value = self.mapped_ids.get(ref, None)
return ref if local_value is None else ObjectReference(model=ref.model, id=local_value)
def reverse_ref_lookup(self, ref: ObjectReference):
created_value = self.reverse_mapped_ids.get(ref, None)
return ref if created_value is None else ObjectReference(model=ref.model, id=created_value)
def get_model_field_values(self, instance: Model) -> FieldValuesDict:
values = json.loads(serializers.serialize("json", [instance]))[0]["fields"]
for field in instance._meta.get_fields():
if field.name not in values:
continue
if isinstance(field, (OneToOneField, ForeignKey)):
value = values[field.name]
if value is not None:
values[field.name] = self.reverse_ref_lookup(
ObjectReference(model=field.model._meta.model_name, id=value)
).id
return values
def handle_pre_change_instance(self, sender: Type[Model], instance: Model, **kwargs):
if instance.pk is None:
return
ref = ObjectReference.simple_from_instance(instance)
if ref in self.reverse_mapped_ids:
return
if ref not in self.pre_change_values and ref not in self.changes.prev_values:
self.pre_change_values[ref] = self.get_model_field_values(
instance._meta.model.objects.get(pk=instance.pk)
)
def handle_post_save(self, sender: Type[Model], instance: Model, created: bool,
update_fields: set | None, **kwargs):
field_values = self.get_model_field_values(instance)
if created:
created_id = min([change.obj.id for change in self.changes
if isinstance(change, CreateObjectChange)], default=0)-1
model_name = instance._meta.model_name
self.mapped_ids[ObjectReference(model=model_name, id=created_id)] = instance.pk
self.reverse_mapped_ids[ObjectReference(model=model_name, id=instance.pk)] = created_id
self.changes.changes.append(CreateObjectChange(
obj=ObjectReference(model=instance._meta.model_name, id=created_id),
fields=field_values
))
from pprint import pprint
pprint(self.changes)
return
if update_fields:
field_values = {name: value for name, value in field_values.items() if name in update_fields}
ref = self.reverse_ref_lookup(ObjectReference.simple_from_instance(instance))
pre_change_values = self.pre_change_values.pop(ref, None)
if pre_change_values:
self.changes.prev_values[ref] = pre_change_values
self.changes.prev_reprs[ref] = str(instance)
self.changes.changes.append(UpdateObjectChange(
obj=ref,
fields=field_values
))
from pprint import pprint
pprint(self.changes)
def handle_post_delete(self, sender: Type[Model], instance: Model, **kwargs):
ref = self.reverse_ref_lookup(ObjectReference.simple_from_instance(instance))
pre_change_values = self.pre_change_values.pop(ref, None)
if pre_change_values:
self.changes.prev_values[ref] = pre_change_values
self.changes.prev_reprs[ref] = str(instance)
self.changes.changes.append(DeleteObjectChange(
obj=ref,
))
from pprint import pprint
pprint(self.changes)
def handle_m2m_changed(self, sender: Type[Model], instance: Model, **kwargs):
pass
def handle_pre_change_instance(sender: Type[Model], **kwargs):
if sender._meta.app_label != 'mapdata':
return
manager: ChangesetOverlayManager = getattr(overlay_state, 'manager', None)
if manager:
manager.handle_pre_change_instance(sender=sender, **kwargs)
def handle_post_save(sender: Type[Model], **kwargs):
if sender._meta.app_label != 'mapdata':
return
manager: ChangesetOverlayManager = getattr(overlay_state, 'manager', None)
if manager:
manager.handle_post_save(sender=sender, **kwargs)
def handle_post_delete(sender: Type[Model], **kwargs):
if sender._meta.app_label != 'mapdata':
return
manager: ChangesetOverlayManager = getattr(overlay_state, 'manager', None)
if manager:
manager.handle_post_delete(sender=sender, **kwargs)
def handle_m2m_changed(sender: Type[Model], **kwargs):
if sender._meta.app_label != 'mapdata':
return
manager: ChangesetOverlayManager = getattr(overlay_state, 'manager', None)
if manager:
manager.handle_m2m_changed(sender=sender, **kwargs)

View file

@ -1,29 +0,0 @@
from contextlib import contextmanager
from functools import wraps
from django.db import transaction
from c3nav.editor.models import ChangeSet
try:
from asgiref.local import Local as LocalContext
except ImportError:
from threading import local as LocalContext
intercept = LocalContext()
class InterceptAbortTransaction(Exception):
pass
@contextmanager
def enable_changeset_overlay(changeset):
try:
with transaction.atomic():
# todo: apply changes so far
yield
raise InterceptAbortTransaction
except InterceptAbortTransaction:
pass

View file

@ -1,10 +1,12 @@
# Generated by Django 5.0.8 on 2024-08-22 17:03
import c3nav.editor.changes
import c3nav.editor.models.changeset
import django.core.serializers.json
import django_pydantic_field.fields
from django.db import migrations
import c3nav.editor.changes
class Migration(migrations.Migration):
@ -20,7 +22,7 @@ class Migration(migrations.Migration):
migrations.AddField(
model_name='changeset',
name='changes',
field=django_pydantic_field.fields.PydanticSchemaField(config=None, default=c3nav.editor.models.changeset.ChangeSetChanges, encoder=django.core.serializers.json.DjangoJSONEncoder, schema=c3nav.editor.models.changeset.ChangeSetChanges),
field=django_pydantic_field.fields.PydanticSchemaField(config=None, default=c3nav.editor.changes.ChangeSetChanges, encoder=django.core.serializers.json.DjangoJSONEncoder, schema=c3nav.editor.changes.ChangeSetChanges),
),
migrations.DeleteModel(
name='ChangedObject',

View file

@ -1,8 +1,5 @@
import datetime
from collections import OrderedDict
from contextlib import contextmanager
from enum import StrEnum
from typing import Literal, TypeAlias, Union, Annotated
from django.apps import apps
from django.conf import settings
@ -15,92 +12,13 @@ from django.utils.timezone import make_naive
from django.utils.translation import gettext_lazy as _
from django.utils.translation import ngettext_lazy
from django_pydantic_field import SchemaField
from pydantic.config import ConfigDict
from pydantic.fields import Field
from pydantic.types import Discriminator
from c3nav.api.schema import BaseSchema
from c3nav.editor.changes import ChangeSetChanges
from c3nav.editor.tasks import send_changeset_proposed_notification
from c3nav.mapdata.models import LocationSlug, MapUpdate
from c3nav.mapdata.models.locations import LocationRedirect
from c3nav.mapdata.utils.cache.changes import changed_geometries
FieldValuesDict: TypeAlias = dict[int, str]
ExistingOrCreatedID: TypeAlias = int # negative = temporary ID of created object
class ObjectReferenceType(StrEnum):
EXISTING = "existing"
CREATED = "created"
class ObjectReference(BaseSchema):
model_config = ConfigDict(frozen=True)
model: str
id: ExistingOrCreatedID
class BaseChange(BaseSchema):
obj: ObjectReference
datetime: datetime.datetime
class CreateObjectChange(BaseChange):
type: Literal["create"]
fields: FieldValuesDict
class UpdateObjectChange(BaseChange):
type: Literal["update"]
fields: FieldValuesDict
class DeleteObjectChange(BaseChange):
type: Literal["delete"]
class AddManyToManyChange(BaseSchema):
type: Literal["m2m_add"]
field: str
values: list[int]
class RemoveManyToManyChange(BaseSchema):
type: Literal["m2m_remove"]
field: str
values: list[int]
class ClearManyToManyChange(BaseSchema):
type: Literal["m2m_clear"]
field: str
ChangeSetChange = Annotated[
Union[
CreateObjectChange,
UpdateObjectChange,
DeleteObjectChange,
AddManyToManyChange,
RemoveManyToManyChange,
ClearManyToManyChange,
],
Discriminator("type"),
]
class ChangeSetChanges(BaseSchema):
prev_reprs: dict[ObjectReference, str] = {}
prev_values: dict[ObjectReference, FieldValuesDict] = {}
prev_m2m: dict[ObjectReference, dict[str, list[int]]] = {}
changes: list[ChangeSetChange] = []
# maps negative IDs of created objects to the ID during the current transaction
mapped_ids: Annotated[dict[ObjectReference, int], Field(exclude=True)] = {}
# maps IDs as used during the current transaction to the negative IDs
reverse_mapped_ids: Annotated[dict[ObjectReference, int], Field(exclude=True)] = {}
class ChangeSet(models.Model):
STATES = (

View file

@ -15,7 +15,7 @@ from django.utils.cache import patch_vary_headers
from django.utils.translation import get_language
from django.utils.translation import gettext_lazy as _
from c3nav.editor.intercept import enable_changeset_overlay
from c3nav.editor.changes import enable_changeset_overlay
from c3nav.editor.models import ChangeSet
from c3nav.mapdata.models.access import AccessPermission
from c3nav.mapdata.models.base import SerializableMixin