team-3/src/c3nav/editor/overlay.py

197 lines
7.9 KiB
Python
Raw Normal View History

2024-08-22 22:49:07 +02:00
import json
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Type
2024-08-22 22:49:07 +02:00
from django.core import serializers
from django.db import transaction
from django.db.models import Model
from django.db.models.fields.related import ManyToManyField
2024-08-22 22:49:07 +02:00
from c3nav.editor.operations import CreateObjectOperation, \
2024-09-26 13:19:29 +02:00
UpdateObjectOperation, DeleteObjectOperation, ClearManyToManyOperation, UpdateManyToManyOperation, \
DatabaseOperationCollection, FieldValuesDict, ObjectReference
from c3nav.mapdata.fields import I18nField
from c3nav.mapdata.models import LocationSlug
2024-08-22 22:49:07 +02:00
# Prefer asgiref's Local (presumably so the overlay state is also isolated between async contexts,
# not just threads); fall back to a plain thread-local when asgiref is not installed.
try:
    from asgiref.local import Local as LocalContext
except ImportError:  # asgiref unavailable
    from threading import local as LocalContext

# Per-context holder for the currently active DatabaseOverlayManager, stored as attribute ``manager``
# (None or missing when no overlay is active).
overlay_state = LocalContext()
class InterceptAbortTransaction(Exception):
    """
    Raised on purpose inside the overlay's atomic block to roll the transaction back.

    ``DatabaseOverlayManager.enable`` raises it when ``commit`` is false and catches it right away.
    """
@dataclass
class DatabaseOverlayManager:
    """
    This class handles the currently active database overlay and will apply and/or intercept changes.

    While an overlay is enabled (see :meth:`enable`), changes to model instances are reported to the
    active manager. The manager records them as operations in :attr:`operations`. Everything happens
    inside one database transaction, which is rolled back at the end unless ``commit=True`` is passed.
    """
    # operations recorded while this overlay is active; ``operations.prev`` keeps the field values
    # objects had before their first recorded change
    operations: DatabaseOperationCollection = field(default_factory=DatabaseOperationCollection)

    # values captured by handle_pre_change_instance, waiting to be moved into ``operations.prev``
    # by the matching post-change handler
    pre_change_values: dict[ObjectReference, FieldValuesDict] = field(default_factory=dict, init=False, repr=False)

    @classmethod
    @contextmanager
    def enable(cls, operations: DatabaseOperationCollection | None = None, commit: bool = False):
        """
        Context manager to enable the database overlay, optionally pre-applying the given changes.

        Only one overlay can be active at the same time, or else you get a TypeError.

        The given operations are applied *before* the new manager is installed, so they are not
        recorded again. Changes made inside the ``with`` body are recorded in the yielded manager.

        :param operations: what operations to pre-apply
        :param commit: whether to actually commit operations to the database or revert them at the end
        :yields: the active :class:`DatabaseOverlayManager`
        :raises TypeError: if an overlay is already active in this context
        """
        if getattr(overlay_state, "manager", None) is not None:
            raise TypeError("Only one overlay can be active at the same time")
        if operations is None:
            operations = DatabaseOperationCollection()
        try:
            with transaction.atomic():
                # share the pre-change snapshots of the pre-applied operations, so objects touched by
                # them keep their original "before" values
                manager = DatabaseOverlayManager(operations=DatabaseOperationCollection(prev=operations.prev))
                operations.prefetch().apply()
                overlay_state.manager = manager
                yield manager
                if not commit:
                    # an exception escaping atomic() rolls the whole transaction back
                    raise InterceptAbortTransaction
        except InterceptAbortTransaction:
            pass
        finally:
            overlay_state.manager = None

    @staticmethod
    def get_model_field_values(instance: Model) -> FieldValuesDict:
        """
        Return the field values of ``instance`` in their JSON-serialized form.

        For ``LocationSlug`` subclasses the ``slug`` is added explicitly. Presumably this is needed
        because Django's serializer only emits a model's own (local) fields, not the inherited ones.
        """
        values = json.loads(serializers.serialize("json", [instance]))[0]["fields"]
        if issubclass(instance._meta.model, LocationSlug):
            values["slug"] = instance.slug
        return values

    def get_ref_and_pre_change_values(self, instance: Model) -> tuple[ObjectReference, FieldValuesDict]:
        """
        Return the reference of ``instance`` and its field values from before its first change.

        On the first change to an object, the values captured by :meth:`handle_pre_change_instance`
        are moved into ``self.operations.prev`` (together with the instance's ``titles``, if any).
        Later calls reuse that stored snapshot.

        :raises KeyError: if the object has no snapshot yet and no pre-change values were captured
        """
        ref = ObjectReference.from_instance(instance)
        prev = self.operations.prev.get(ref)
        if prev is not None:
            return ref, prev.values
        pre_change_values = self.pre_change_values.pop(ref)
        self.operations.prev.set(ref, values=pre_change_values, titles=getattr(instance, 'titles', None))
        return ref, pre_change_values

    def handle_pre_change_instance(self, instance: Model, **kwargs):
        """
        Capture the current database state of ``instance`` before it gets changed.

        Unsaved instances (no pk) are ignored. So are objects whose pre-change state is already known.
        """
        if instance.pk is None:
            return
        ref = ObjectReference.from_instance(instance)
        if ref in self.pre_change_values or self.operations.prev.get(ref) is not None:
            return
        model = instance._meta.model
        try:
            db_instance = model.objects.get(pk=instance.pk)
        except model.DoesNotExist:
            # a new object with a preassigned pk has no previous state to capture
            # (Django reports saving such an object as a creation)
            return
        self.pre_change_values[ref] = self.get_model_field_values(db_instance)

    def handle_post_save(self, instance: Model, created: bool, update_fields: set | None, **kwargs):
        """
        Record a create or update operation for a saved instance.

        For updates, only fields that differ from the pre-change snapshot are recorded, limited to
        ``update_fields`` if that is given. For ``I18nField`` fields, only the changed languages are
        recorded, with ``None`` for languages that were removed.
        """
        field_values = self.get_model_field_values(instance)

        if created:
            ref = ObjectReference.from_instance(instance)
            self.operations.append(CreateObjectOperation(obj=ref, fields=field_values))
            return

        ref, pre_change_values = self.get_ref_and_pre_change_values(instance)

        if update_fields:
            field_values = {name: value for name, value in field_values.items() if name in update_fields}

        if pre_change_values is not None:
            field_values = {name: value for name, value in field_values.items()
                            if value != pre_change_values[name]}
            # special diffing within the i18n fields: keep only the languages that actually changed
            for field_name in tuple(field_values):
                if isinstance(instance._meta.get_field(field_name), I18nField):
                    before_val = pre_change_values[field_name]
                    after_val = field_values[field_name]
                    field_values[field_name] = {
                        lang: after_val.get(lang)
                        for lang in set(before_val) | set(after_val)
                        if before_val.get(lang) != after_val.get(lang)
                    }

        self.operations.append(UpdateObjectOperation(obj=ref, fields=field_values))

    def handle_post_delete(self, instance: Model, **kwargs):
        """Record a delete operation for ``instance``. Plain ``LocationSlug`` rows are skipped."""
        # not isinstance() cause it would match submodels
        if instance._meta.model is LocationSlug:
            return
        # called so the pre-delete snapshot is stored in operations.prev; the values are not needed here
        ref, _ = self.get_ref_and_pre_change_values(instance)
        self.operations.append(DeleteObjectOperation(obj=ref))

    def handle_m2m_changed(self, sender: Type[Model], instance: Model, action: str, model: Type[Model],
                           pk_set: set | None, reverse: bool, **kwargs):
        """
        Record a change to a forward many-to-many relation of ``instance``.

        ``pre_*`` actions only capture the pre-change snapshot, and ``post_clear`` records a clear
        operation. ``post_add``/``post_remove`` are merged into the last operation if that is an
        ``UpdateManyToManyOperation`` for the same object and field; otherwise a new one is appended.

        :raises NotImplementedError: for changes sent from the reverse side of the relation
        :raises ValueError: if ``sender`` is not the through model of a many-to-many field of ``instance``
        """
        if reverse:
            raise NotImplementedError("reverse many-to-many changes are not supported by the overlay")
        if action.startswith("pre_"):
            return self.handle_pre_change_instance(sender=instance._meta.model, instance=instance)

        for m2m_field in instance._meta.get_fields():
            if isinstance(m2m_field, ManyToManyField) and m2m_field.remote_field.through == sender:
                break
        else:
            raise ValueError(f"{sender!r} is not the through model of a many-to-many field "
                             f"on {instance._meta.model!r}")

        ref, _ = self.get_ref_and_pre_change_values(instance)

        if action == "post_clear":
            self.operations.append(ClearManyToManyOperation(obj=ref, field=m2m_field.name))
            return

        if self.operations:
            last_change = self.operations[-1]
            # merge consecutive add/remove changes to the same relation into a single operation
            if (isinstance(last_change, UpdateManyToManyOperation)
                    and last_change.obj == ref and last_change.field == m2m_field.name):
                # NOTE(review): relies on add_values/remove_values being sets (as the set methods imply)
                if action == "post_add":
                    last_change.add_values.update(pk_set)
                    last_change.remove_values.difference_update(pk_set)
                else:
                    last_change.add_values.difference_update(pk_set)
                    last_change.remove_values.update(pk_set)
                return

        if action == "post_add":
            self.operations.append(UpdateManyToManyOperation(obj=ref, field=m2m_field.name,
                                                             add_values=list(pk_set)))
        else:
            self.operations.append(UpdateManyToManyOperation(obj=ref, field=m2m_field.name,
                                                             remove_values=list(pk_set)))
def handle_pre_change_instance(sender: Type[Model], **kwargs):
    """
    Signal receiver: forward a pre-change signal to the active overlay manager.

    Only models of the ``mapdata`` app are forwarded, and nothing happens when no overlay is active.
    """
    if sender._meta.app_label == 'mapdata':
        active_manager: DatabaseOverlayManager = getattr(overlay_state, 'manager', None)
        if active_manager:
            active_manager.handle_pre_change_instance(sender=sender, **kwargs)
def handle_post_save(sender: Type[Model], **kwargs):
    """
    Signal receiver: forward a post-save signal of a ``mapdata`` model to the active overlay manager.

    Does nothing for other apps or when no overlay is active.
    """
    is_mapdata_model = sender._meta.app_label == 'mapdata'
    if not is_mapdata_model:
        return
    active_manager: DatabaseOverlayManager = getattr(overlay_state, 'manager', None)
    if not active_manager:
        return
    active_manager.handle_post_save(sender=sender, **kwargs)
def handle_post_delete(sender: Type[Model], **kwargs):
    """
    Signal receiver: forward a post-delete signal of a ``mapdata`` model to the active overlay manager.

    Signals from other apps, or with no active overlay, are ignored.
    """
    if sender._meta.app_label == 'mapdata':
        current: DatabaseOverlayManager = getattr(overlay_state, 'manager', None)
        if current:
            current.handle_post_delete(sender=sender, **kwargs)
def handle_m2m_changed(sender: Type[Model], **kwargs):
    """
    Signal receiver: forward a many-to-many change to the active overlay manager.

    ``sender`` must belong to the ``mapdata`` app. Nothing happens when no overlay is active.
    """
    if sender._meta.app_label != 'mapdata':
        return None
    current: DatabaseOverlayManager = getattr(overlay_state, 'manager', None)
    if current:
        current.handle_m2m_changed(sender=sender, **kwargs)
    return None