team-3/src/c3nav/editor/changes.py

241 lines
8 KiB
Python
Raw Normal View History

2024-08-22 22:49:07 +02:00
import datetime
import json
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import TypeAlias, Literal, Annotated, Union, Type, Any
from django.core import serializers
from django.db import transaction
from django.db.models import Model
from django.db.models.fields.related import OneToOneField, ForeignKey
from django.utils import timezone
from pydantic import ConfigDict, Discriminator
from pydantic.fields import Field
from c3nav.api.schema import BaseSchema
try:
from asgiref.local import Local as LocalContext
except ImportError:
from threading import local as LocalContext
# Serialized field values of one model instance, keyed by field name
# (as produced by ChangesetOverlayManager.get_model_field_values).
FieldValuesDict: TypeAlias = dict[str, Any]
# Object ID as stored in a changeset: positive = ID of an existing object,
# negative = temporary ID of an object created within the changeset.
ExistingOrCreatedID: TypeAlias = int  # negative = temporary ID of created object
class ObjectReference(BaseSchema):
    """Hashable reference to one model instance: its model name plus its ID.

    The schema is frozen so that references can be used as dict keys
    (e.g. in ``ChangeSetChanges.prev_values``). Objects created within a
    changeset are referenced by a negative temporary ID.
    """
    model_config = ConfigDict(frozen=True)

    model: str
    id: ExistingOrCreatedID

    @classmethod
    def simple_from_instance(cls, instance: Model):
        """Build a reference from *instance*'s model name and current primary key.

        The primary key is taken as-is: it is NOT translated into the negative
        temporary ID of an object that was created within the changeset.
        Callers needing that translation have to perform it themselves.
        """
        opts = instance._meta
        return cls(id=instance.pk, model=opts.model_name)
class BaseChange(BaseSchema):
    """Common base for changes that apply to a single referenced object."""
    # the object this change applies to
    obj: ObjectReference
    # time the change object was created; defaults to timezone.now() at construction
    datetime: Annotated[datetime.datetime, Field(default_factory=timezone.now)]
class CreateObjectChange(BaseChange):
    """Creation of ``obj`` (referenced by a negative temporary ID) with the given field values."""
    type: Literal["create"] = "create"
    fields: FieldValuesDict
class UpdateObjectChange(BaseChange):
    """Update of ``obj``; ``fields`` holds the new values of the fields that were written."""
    type: Literal["update"] = "update"
    fields: FieldValuesDict
class DeleteObjectChange(BaseChange):
    """Deletion of ``obj``."""
    type: Literal["delete"] = "delete"
# The many-to-many changes derive from BaseChange (not plain BaseSchema) so that,
# like every other change, they carry the ``obj`` whose relation was modified
# and the time the change was recorded; without ``obj`` they could not be applied.
class AddManyToManyChange(BaseChange):
    """IDs added to the many-to-many field ``field`` of ``obj``."""
    type: Literal["m2m_add"] = "m2m_add"
    field: str
    # NOTE(review): presumably primary keys of the related objects — confirm once recorded
    values: list[int]


class RemoveManyToManyChange(BaseChange):
    """IDs removed from the many-to-many field ``field`` of ``obj``."""
    type: Literal["m2m_remove"] = "m2m_remove"
    field: str
    # NOTE(review): presumably primary keys of the related objects — confirm once recorded
    values: list[int]


class ClearManyToManyChange(BaseChange):
    """All entries removed from the many-to-many field ``field`` of ``obj``."""
    type: Literal["m2m_clear"] = "m2m_clear"
    field: str
# Any single change stored in a changeset. Pydantic selects the concrete
# change class from the value of its ``type`` field (discriminated union).
ChangeSetChange = Annotated[
    Union[
        CreateObjectChange,
        UpdateObjectChange,
        DeleteObjectChange,
        AddManyToManyChange,
        RemoveManyToManyChange,
        ClearManyToManyChange,
    ],
    Discriminator("type"),
]
class ChangeSetChanges(BaseSchema):
    """The ordered list of changes in a changeset plus bookkeeping about touched objects.

    NOTE(review): the dict keys are ObjectReference models, not strings — confirm
    these dicts serialize to/from JSON as intended.
    """
    # string representations of objects touched by the changeset
    prev_reprs: dict[ObjectReference, str] = {}
    # field values of objects as they were before the changeset first changed them
    prev_values: dict[ObjectReference, FieldValuesDict] = {}
    # many-to-many values of objects before they were changed
    prev_m2m: dict[ObjectReference, dict[str, list[int]]] = {}
    # all recorded changes, in the order they happened
    changes: list[ChangeSetChange] = []
# Context-local storage for the currently active ChangesetOverlayManager
# (attribute ``manager``); asgiref's Local when available, else threading.local.
overlay_state = LocalContext()
class InterceptAbortTransaction(Exception):
    """Raised inside the overlay's atomic block to force a rollback; caught right after."""
@contextmanager
def enable_changeset_overlay(changeset):
    """Run the enclosed block with a changeset overlay active, then roll it all back.

    A ChangesetOverlayManager for ``changeset.changes`` is stored as
    ``overlay_state.manager`` so the module-level ``handle_*`` functions forward
    model signals to it. Every database write made inside the block happens in a
    transaction that is always rolled back. Exceptions raised by the block
    propagate (atomic() rolls back in that case too); the manager is reset to
    None on every exit path.
    """
    try:
        with transaction.atomic():
            manager = ChangesetOverlayManager(changeset.changes)
            overlay_state.manager = manager
            # TODO: apply changes so far
            yield
            # Raising inside atomic() makes Django roll back everything done in the
            # block; the exception is swallowed by the except clause below.
            raise InterceptAbortTransaction
    except InterceptAbortTransaction:
        pass
    finally:
        # always deactivate the overlay, even when the block raised
        overlay_state.manager = None
@dataclass
class ChangesetOverlayManager:
    """Record changes to mapdata objects into a changeset while an overlay is active.

    The module-level signal handlers forward model signals to the manager stored
    in ``overlay_state``. Each handler appends the matching change to
    ``changes.changes``. The pre-change state of touched objects is kept so it
    can be stored in ``changes.prev_values`` / ``changes.prev_reprs``.
    """

    # the changeset changes this manager appends to
    changes: ChangeSetChanges
    # NOTE(review): never updated in this file — presumably meant to flag recorded changes
    new_changes: bool = False
    # field values captured before an object is saved/deleted, moved into
    # changes.prev_values once the change has actually been recorded
    pre_change_values: dict[ObjectReference, FieldValuesDict] = field(default_factory=dict)
    # maps negative IDs of created objects to the ID during the current transaction
    mapped_ids: dict[ObjectReference, int] = field(default_factory=dict)
    # maps IDs as used during the current transaction to the negative IDs
    reverse_mapped_ids: dict[ObjectReference, int] = field(default_factory=dict)
    # str() of objects captured together with pre_change_values
    pre_change_reprs: dict[ObjectReference, str] = field(default_factory=dict)

    def ref_lookup(self, ref: ObjectReference) -> ObjectReference:
        """Translate a changeset reference (maybe a negative temp ID) to the ID used in this transaction."""
        local_value = self.mapped_ids.get(ref, None)
        return ref if local_value is None else ObjectReference(model=ref.model, id=local_value)

    def reverse_ref_lookup(self, ref: ObjectReference) -> ObjectReference:
        """Translate an ID used in this transaction back to the changeset's negative temp ID, if any."""
        created_value = self.reverse_mapped_ids.get(ref, None)
        return ref if created_value is None else ObjectReference(model=ref.model, id=created_value)

    def get_model_field_values(self, instance: Model) -> FieldValuesDict:
        """Serialize *instance*'s fields to a JSON-compatible dict keyed by field name.

        Foreign key / one-to-one values that point to objects created within the
        changeset are replaced by those objects' negative temporary IDs.
        """
        values = json.loads(serializers.serialize("json", [instance]))[0]["fields"]
        for model_field in instance._meta.get_fields():
            if model_field.name not in values:
                continue
            if isinstance(model_field, (OneToOneField, ForeignKey)):
                value = values[model_field.name]
                if value is not None:
                    # the referenced ID belongs to the *related* model, not to the
                    # model the field is defined on
                    values[model_field.name] = self.reverse_ref_lookup(
                        ObjectReference(model=model_field.related_model._meta.model_name, id=value)
                    ).id
        return values

    def _store_pre_change_values(self, ref: ObjectReference, instance: Model):
        """Move the captured pre-change state of *ref* (if any) into the changeset."""
        pre_change_values = self.pre_change_values.pop(ref, None)
        pre_change_repr = self.pre_change_reprs.pop(ref, None)
        if pre_change_values:
            self.changes.prev_values[ref] = pre_change_values
            self.changes.prev_reprs[ref] = str(instance) if pre_change_repr is None else pre_change_repr

    def handle_pre_change_instance(self, sender: Type[Model], instance: Model, **kwargs):
        """Capture the database state of *instance* before it is changed.

        Skipped for unsaved instances, for objects created within this changeset
        (they have no previous state) and for objects whose previous state is
        already known.
        """
        if instance.pk is None:
            return
        ref = ObjectReference.simple_from_instance(instance)
        if ref in self.reverse_mapped_ids:
            return
        if ref in self.pre_change_values or ref in self.changes.prev_values:
            return
        # A pk may be set on an object that is not in the database yet (explicitly
        # assigned pk on create); there is no previous state to capture then.
        db_instance = instance._meta.model.objects.filter(pk=instance.pk).first()
        if db_instance is None:
            return
        self.pre_change_values[ref] = self.get_model_field_values(db_instance)
        self.pre_change_reprs[ref] = str(db_instance)

    def handle_post_save(self, sender: Type[Model], instance: Model, created: bool,
                         update_fields: set | None, **kwargs):
        """Record a create or update change for the just-saved *instance*."""
        field_values = self.get_model_field_values(instance)

        if created:
            # New objects get a fresh negative temporary ID, one below the lowest
            # one already used by a create change in this changeset.
            created_id = min(
                (change.obj.id for change in self.changes.changes
                 if isinstance(change, CreateObjectChange)),
                default=0,
            ) - 1
            model_name = instance._meta.model_name
            self.mapped_ids[ObjectReference(model=model_name, id=created_id)] = instance.pk
            self.reverse_mapped_ids[ObjectReference(model=model_name, id=instance.pk)] = created_id
            self.changes.changes.append(CreateObjectChange(
                obj=ObjectReference(model=model_name, id=created_id),
                fields=field_values,
            ))
            return

        if update_fields:
            # update_fields may name fields by attname (e.g. "level_id"); normalize to
            # field names, which is how the serialized values are keyed
            written = {instance._meta.get_field(name).name for name in update_fields}
            field_values = {name: value for name, value in field_values.items() if name in written}

        ref = self.reverse_ref_lookup(ObjectReference.simple_from_instance(instance))
        self._store_pre_change_values(ref, instance)
        self.changes.changes.append(UpdateObjectChange(
            obj=ref,
            fields=field_values,
        ))

    def handle_post_delete(self, sender: Type[Model], instance: Model, **kwargs):
        """Record a delete change for the just-deleted *instance*."""
        ref = self.reverse_ref_lookup(ObjectReference.simple_from_instance(instance))
        self._store_pre_change_values(ref, instance)
        self.changes.changes.append(DeleteObjectChange(
            obj=ref,
        ))

    def handle_m2m_changed(self, sender: Type[Model], instance: Model, **kwargs):
        """Placeholder: many-to-many changes are not recorded yet."""
        # TODO: record AddManyToManyChange / RemoveManyToManyChange / ClearManyToManyChange
def _get_active_manager(sender: Type[Model]):
    """Return the overlay manager that should handle a signal sent by *sender*, or None.

    None is returned for models outside the ``mapdata`` app and when no overlay
    is active in the current context (``overlay_state.manager`` unset or None).
    """
    if sender._meta.app_label != 'mapdata':
        return None
    return getattr(overlay_state, 'manager', None)


def handle_pre_change_instance(sender: Type[Model], **kwargs):
    """Forward a pre-change model signal to the active overlay manager, if any."""
    # NOTE(review): presumably connected to pre_save/pre_delete elsewhere — confirm
    manager = _get_active_manager(sender)
    if manager:
        manager.handle_pre_change_instance(sender=sender, **kwargs)


def handle_post_save(sender: Type[Model], **kwargs):
    """Forward a post-save model signal to the active overlay manager, if any."""
    manager = _get_active_manager(sender)
    if manager:
        manager.handle_post_save(sender=sender, **kwargs)


def handle_post_delete(sender: Type[Model], **kwargs):
    """Forward a post-delete model signal to the active overlay manager, if any."""
    manager = _get_active_manager(sender)
    if manager:
        manager.handle_post_delete(sender=sender, **kwargs)


def handle_m2m_changed(sender: Type[Model], **kwargs):
    """Forward an m2m-changed model signal to the active overlay manager, if any."""
    manager = _get_active_manager(sender)
    if manager:
        manager.handle_m2m_changed(sender=sender, **kwargs)