team-3/src/c3nav/editor/overlay.py

184 lines
7.2 KiB
Python
Raw Normal View History

2024-08-22 22:49:07 +02:00
import json
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Type
2024-08-22 22:49:07 +02:00
from django.core import serializers
from django.db import transaction
from django.db.models import Model
from django.db.models.fields.related import ManyToManyField
2024-08-22 22:49:07 +02:00
from c3nav.editor.operations import DatabaseOperation, ObjectReference, FieldValuesDict, CreateObjectOperation, \
UpdateObjectOperation, DeleteObjectOperation, ClearManyToManyOperation, UpdateManyToManyOperation, CollectedChanges
from c3nav.mapdata.fields import I18nField
from c3nav.mapdata.models import LocationSlug
2024-08-22 22:49:07 +02:00
try:
    # NOTE(review): asgiref's Local is presumably preferred because it also keeps state separate
    # per async context, not only per thread - confirm against asgiref docs
    from asgiref.local import Local as LocalContext
except ImportError:
    # fallback: plain thread-local storage
    from threading import local as LocalContext

# Context-local holder for the currently active DatabaseOverlayManager, kept in the ``manager``
# attribute; an unset attribute or None means that no overlay is active.
overlay_state = LocalContext()
class InterceptAbortTransaction(Exception):
    """Raised inside the overlay transaction to force a rollback when changes must not be committed."""
@dataclass
class DatabaseOverlayManager:
    """
    Record changes made to ``mapdata`` models while a database overlay is active.

    ``enable`` applies an existing ``CollectedChanges`` inside a database transaction and registers the
    manager in ``overlay_state``. From then on, the ``handle_*`` methods (which the module-level
    ``handle_*`` functions forward to) append a ``DatabaseOperation`` to ``new_operations`` for every
    save, delete and many-to-many change. Unless ``commit`` is set, the transaction is rolled back when
    the overlay is left.
    """
    # the changeset that is applied when the overlay is enabled
    changes: CollectedChanges
    # operations recorded while the overlay was active, in the order they happened
    new_operations: list[DatabaseOperation] = field(default_factory=list)
    # field values captured right before an object is changed; consumed once the change is handled
    pre_change_values: dict[ObjectReference, FieldValuesDict] = field(default_factory=dict)

    @classmethod
    @contextmanager
    def enable(cls, changes: CollectedChanges | None, commit: bool):
        """
        Activate an overlay for the current context and yield its manager.

        :param changes: previously collected changes to apply first, or None for an empty changeset
        :param commit: if False, everything done inside the overlay is rolled back when it is left
        :raises TypeError: if an overlay is already active in this context (nesting is not supported)
        """
        if getattr(overlay_state, 'manager', None) is not None:
            raise TypeError('a database overlay is already active in this context')
        if changes is None:
            changes = CollectedChanges()
        try:
            with transaction.atomic():
                manager = cls(changes)
                # apply the existing changes before registering the manager, so they are not recorded again
                manager.changes.prefetch().apply()
                overlay_state.manager = manager
                yield manager
                if not commit:
                    # leaving atomic() with an exception rolls back the whole transaction
                    raise InterceptAbortTransaction
        except InterceptAbortTransaction:
            pass
        finally:
            overlay_state.manager = None

    def save_new_operations(self):
        """
        Append the recorded operations to ``changes.operations``.

        ``new_operations`` is not cleared, so calling this twice appends the same operations twice.
        """
        self.changes.operations.extend(self.new_operations)

    @staticmethod
    def get_model_field_values(instance: Model) -> FieldValuesDict:
        """
        Return the instance's field values as produced by Django's JSON serializer.

        The values are round-tripped through ``json`` so they are plain JSON-compatible data.
        """
        values = json.loads(serializers.serialize("json", [instance]))[0]["fields"]
        if issubclass(instance._meta.model, LocationSlug):
            # NOTE(review): presumably added explicitly because the serializer does not include the slug
            # field inherited from LocationSlug - confirm
            values["slug"] = instance.slug
        return values

    def get_ref_and_pre_change_values(self, instance: Model) -> tuple[ObjectReference, FieldValuesDict]:
        """
        Pop the captured pre-change values of ``instance`` and record them in ``changes``.

        If values were captured, they are stored in ``changes.prev_values``, and the instance's ``titles``
        attribute (or None) is stored in ``changes.prev_titles``.

        :return: the instance's reference and its captured pre-change values (None if nothing was captured)
        """
        ref = ObjectReference.from_instance(instance)
        pre_change_values = self.pre_change_values.pop(ref, None)
        if pre_change_values:
            self.changes.prev_values.setdefault(ref.model, {})[ref.id] = pre_change_values
            self.changes.prev_titles.setdefault(ref.model, {})[ref.id] = getattr(instance, 'titles', None)
        return ref, pre_change_values

    def handle_pre_change_instance(self, instance: Model, **kwargs):
        """
        Capture an existing object's current database values right before it is changed.

        Values are captured only once per object. Nothing is captured if values are already pending in
        ``pre_change_values`` or already recorded in ``changes.prev_values``. Instances without a pk, or
        whose pk is not in the database yet, are ignored.
        """
        if instance.pk is None:
            return
        ref = ObjectReference.from_instance(instance)
        if ref in self.pre_change_values or ref.id in self.changes.prev_values.get(ref.model, {}):
            return
        model = instance._meta.model
        try:
            # re-fetch from the database: the in-memory instance may already carry the new values
            current = model.objects.get(pk=instance.pk)
        except model.DoesNotExist:
            # the pk was assigned before the first save, so there is no previous state to capture
            return
        self.pre_change_values[ref] = self.get_model_field_values(current)

    def handle_post_save(self, instance: Model, created: bool, update_fields: set | None, **kwargs):
        """
        Record a create or update operation for a just-saved instance.

        New objects get a ``CreateObjectOperation`` with all field values. For updates, the values are
        restricted to ``update_fields`` (if given). When pre-change values were captured, they are further
        restricted to fields whose value changed, and ``I18nField`` values keep only the changed languages.
        """
        field_values = self.get_model_field_values(instance)
        ref, pre_change_values = self.get_ref_and_pre_change_values(instance)

        if created:
            self.new_operations.append(CreateObjectOperation(obj=ref, fields=field_values))
            return

        if update_fields:
            field_values = {name: value for name, value in field_values.items() if name in update_fields}

        if pre_change_values is not None:
            field_values = {name: value for name, value in field_values.items()
                            if value != pre_change_values[name]}

            # special diffing within the i18n fields: keep only languages whose value changed
            # (a language that was removed ends up with the value None)
            for field_name in tuple(field_values):
                if isinstance(instance._meta.get_field(field_name), I18nField):
                    before_val = pre_change_values[field_name]
                    after_val = field_values[field_name]
                    diff_val = {}
                    for lang in (set(before_val) | set(after_val)):
                        if before_val.get(lang, None) != after_val.get(lang, None):
                            diff_val[lang] = after_val.get(lang, None)
                    field_values[field_name] = diff_val

        self.new_operations.append(UpdateObjectOperation(obj=ref, fields=field_values))

    def handle_post_delete(self, instance: Model, **kwargs):
        """Record a delete operation for ``instance``; its captured pre-change values are moved to ``changes``."""
        ref, _ = self.get_ref_and_pre_change_values(instance)
        self.new_operations.append(DeleteObjectOperation(obj=ref))

    def handle_m2m_changed(self, sender: Type[Model], instance: Model, action: str, model: Type[Model],
                           pk_set: set | None, reverse: bool, **kwargs):
        """
        Record a change of a many-to-many relation on ``instance``.

        ``pre_*`` actions only capture the instance's current values. ``post_clear`` records a
        ``ClearManyToManyOperation``. Other post actions record an ``UpdateManyToManyOperation``. If the
        previous operation already updates the same relation of the same object, the change is merged
        into it instead.

        :param sender: the intermediate ("through") model of the changed relation
        :raises NotImplementedError: for changes made from the reverse side of the relation
        :raises ValueError: if ``sender`` is not the through model of a ManyToManyField on the instance
        """
        if reverse:
            raise NotImplementedError
        if action.startswith("pre_"):
            self.handle_pre_change_instance(sender=instance._meta.model, instance=instance)
            return

        # find the forward ManyToManyField whose through model sent this signal
        for m2m_field in instance._meta.get_fields():
            if isinstance(m2m_field, ManyToManyField) and m2m_field.remote_field.through == sender:
                break
        else:
            raise ValueError(f'{sender!r} is not the through model of a ManyToManyField '
                             f'on {instance._meta.model!r}')

        ref, _ = self.get_ref_and_pre_change_values(instance)

        if action == "post_clear":
            self.new_operations.append(ClearManyToManyOperation(obj=ref, field=m2m_field.name))
            return

        if self.new_operations:
            last_change = self.new_operations[-1]
            # compare the operation's target object and field (not the operation object itself)
            if (isinstance(last_change, UpdateManyToManyOperation)
                    and last_change.obj == ref and last_change.field == m2m_field.name):
                # NOTE(review): assumes add_values/remove_values are sets (the list passed on creation is
                # presumably coerced by the operation model) - confirm
                if action == "post_add":
                    last_change.add_values.update(pk_set)
                    last_change.remove_values.difference_update(pk_set)
                else:
                    last_change.add_values.difference_update(pk_set)
                    last_change.remove_values.update(pk_set)
                return

        if action == "post_add":
            self.new_operations.append(
                UpdateManyToManyOperation(obj=ref, field=m2m_field.name, add_values=list(pk_set))
            )
        else:
            self.new_operations.append(
                UpdateManyToManyOperation(obj=ref, field=m2m_field.name, remove_values=list(pk_set))
            )
2024-08-22 22:49:07 +02:00
def handle_pre_change_instance(sender: Type[Model], **kwargs):
    """
    Forward a pre-change notification for a ``mapdata`` model to the active overlay manager.

    Takes the ``(sender, **kwargs)`` shape of a Django signal receiver. Does nothing for models of
    other apps, or when no overlay is active in the current context.
    """
    is_mapdata_model = sender._meta.app_label == 'mapdata'
    if is_mapdata_model:
        active_manager: DatabaseOverlayManager | None = getattr(overlay_state, 'manager', None)
        if active_manager:
            active_manager.handle_pre_change_instance(sender=sender, **kwargs)
def handle_post_save(sender: Type[Model], **kwargs):
    """
    Forward a post-save notification for a ``mapdata`` model to the active overlay manager.

    Models of other apps are ignored, as is everything while no overlay is active.
    """
    if sender._meta.app_label == 'mapdata' and (active_manager := getattr(overlay_state, 'manager', None)):
        active_manager.handle_post_save(sender=sender, **kwargs)
def handle_post_delete(sender: Type[Model], **kwargs):
    """
    Forward a post-delete notification for a ``mapdata`` model to the active overlay manager.

    Returns early for models of other apps and when no overlay is active in the current context.
    """
    if sender._meta.app_label != 'mapdata':
        return None
    active_manager: DatabaseOverlayManager | None = getattr(overlay_state, 'manager', None)
    if not active_manager:
        return None
    active_manager.handle_post_delete(sender=sender, **kwargs)
def handle_m2m_changed(sender: Type[Model], **kwargs):
    """
    Forward a many-to-many change notification to the active overlay manager.

    ``sender`` is the relation's through model. The notification is forwarded only if that model
    belongs to the ``mapdata`` app and an overlay is active in the current context.
    """
    if sender._meta.app_label == 'mapdata':
        if active_manager := getattr(overlay_state, 'manager', None):
            active_manager.handle_m2m_changed(sender=sender, **kwargs)