team-3/src/c3nav/editor/changes.py

204 lines
6.1 KiB
Python
Raw Normal View History

2024-08-22 22:49:07 +02:00
import datetime
import json
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import TypeAlias, Literal, Annotated, Union, Type, Any
from django.core import serializers
from django.db import transaction
from django.db.models import Model
from django.db.models.fields.related import OneToOneField, ForeignKey
from django.utils import timezone
from pydantic import ConfigDict, Discriminator
from pydantic.fields import Field
from c3nav.api.schema import BaseSchema
try:
from asgiref.local import Local as LocalContext
except ImportError:
from threading import local as LocalContext
FieldValuesDict: TypeAlias = dict[str, Any]
class ObjectReference(BaseSchema):
model_config = ConfigDict(frozen=True)
model: str
id: int
2024-08-22 22:49:07 +02:00
@classmethod
def from_instance(cls, instance: Model):
2024-08-22 22:49:07 +02:00
"""
This method will not convert the ID yet!
"""
return cls(model=instance._meta.model_name, id=instance.pk)
class BaseChange(BaseSchema):
obj: ObjectReference
datetime: Annotated[datetime.datetime, Field(default_factory=timezone.now)]
class CreateObjectChange(BaseChange):
type: Literal["create"] = "create"
fields: FieldValuesDict
class UpdateObjectChange(BaseChange):
type: Literal["update"] = "update"
fields: FieldValuesDict
class DeleteObjectChange(BaseChange):
type: Literal["delete"] = "delete"
class AddManyToManyChange(BaseSchema):
type: Literal["m2m_add"] = "m2m_add"
field: str
values: list[int]
class RemoveManyToManyChange(BaseSchema):
type: Literal["m2m_remove"] = "m2m_remove"
field: str
values: list[int]
class ClearManyToManyChange(BaseSchema):
type: Literal["m2m_clear"] = "m2m_clear"
field: str
ChangeSetChange = Annotated[
Union[
CreateObjectChange,
UpdateObjectChange,
DeleteObjectChange,
AddManyToManyChange,
RemoveManyToManyChange,
ClearManyToManyChange,
],
Discriminator("type"),
]
class ChangeSetChanges(BaseSchema):
prev_reprs: dict[ObjectReference, str] = {}
prev_values: dict[ObjectReference, FieldValuesDict] = {}
prev_m2m: dict[ObjectReference, dict[str, list[int]]] = {}
changes: list[ChangeSetChange] = []
overlay_state = LocalContext()
class InterceptAbortTransaction(Exception):
pass
@contextmanager
def enable_changeset_overlay(changeset):
try:
with transaction.atomic():
manager = ChangesetOverlayManager(changeset.changes)
overlay_state.manager = manager
# todo: apply changes so far
yield
raise InterceptAbortTransaction
except InterceptAbortTransaction:
pass
finally:
overlay_state.manager = None
@dataclass
class ChangesetOverlayManager:
changes: ChangeSetChanges
new_changes: bool = False
pre_change_values: dict[ObjectReference, FieldValuesDict] = field(default_factory=dict)
def get_model_field_values(self, instance: Model) -> FieldValuesDict:
return json.loads(serializers.serialize("json", [instance]))[0]["fields"]
2024-08-22 22:49:07 +02:00
def handle_pre_change_instance(self, sender: Type[Model], instance: Model, **kwargs):
if instance.pk is None:
return
ref = ObjectReference.from_instance(instance)
2024-08-22 22:49:07 +02:00
if ref not in self.pre_change_values and ref not in self.changes.prev_values:
self.pre_change_values[ref] = self.get_model_field_values(
instance._meta.model.objects.get(pk=instance.pk)
)
def handle_post_save(self, sender: Type[Model], instance: Model, created: bool,
update_fields: set | None, **kwargs):
field_values = self.get_model_field_values(instance)
ref = ObjectReference.from_instance(instance)
2024-08-22 22:49:07 +02:00
if created:
self.changes.changes.append(CreateObjectChange(obj=ref, fields=field_values))
2024-08-22 22:49:07 +02:00
from pprint import pprint
pprint(self.changes)
return
if update_fields:
field_values = {name: value for name, value in field_values.items() if name in update_fields}
pre_change_values = self.pre_change_values.pop(ref, None)
if pre_change_values:
self.changes.prev_values[ref] = pre_change_values
self.changes.prev_reprs[ref] = str(instance)
self.changes.changes.append(UpdateObjectChange(obj=ref, fields=field_values))
2024-08-22 22:49:07 +02:00
from pprint import pprint
pprint(self.changes)
def handle_post_delete(self, sender: Type[Model], instance: Model, **kwargs):
ref = ObjectReference.from_instance(instance)
2024-08-22 22:49:07 +02:00
pre_change_values = self.pre_change_values.pop(ref, None)
if pre_change_values:
self.changes.prev_values[ref] = pre_change_values
self.changes.prev_reprs[ref] = str(instance)
self.changes.changes.append(DeleteObjectChange(
obj=ref,
))
from pprint import pprint
pprint(self.changes)
def handle_m2m_changed(self, sender: Type[Model], instance: Model, **kwargs):
pass
def handle_pre_change_instance(sender: Type[Model], **kwargs):
if sender._meta.app_label != 'mapdata':
return
manager: ChangesetOverlayManager = getattr(overlay_state, 'manager', None)
if manager:
manager.handle_pre_change_instance(sender=sender, **kwargs)
def handle_post_save(sender: Type[Model], **kwargs):
if sender._meta.app_label != 'mapdata':
return
manager: ChangesetOverlayManager = getattr(overlay_state, 'manager', None)
if manager:
manager.handle_post_save(sender=sender, **kwargs)
def handle_post_delete(sender: Type[Model], **kwargs):
if sender._meta.app_label != 'mapdata':
return
manager: ChangesetOverlayManager = getattr(overlay_state, 'manager', None)
if manager:
manager.handle_post_delete(sender=sender, **kwargs)
def handle_m2m_changed(sender: Type[Model], **kwargs):
if sender._meta.app_label != 'mapdata':
return
manager: ChangesetOverlayManager = getattr(overlay_state, 'manager', None)
if manager:
manager.handle_m2m_changed(sender=sender, **kwargs)