2024-08-22 22:49:07 +02:00
|
|
|
import json
|
|
|
|
from contextlib import contextmanager
|
|
|
|
from dataclasses import dataclass, field
|
2024-08-26 11:49:59 +02:00
|
|
|
from typing import Type
|
2024-08-22 22:49:07 +02:00
|
|
|
|
|
|
|
from django.core import serializers
|
|
|
|
from django.db import transaction
|
|
|
|
from django.db.models import Model
|
2024-08-26 11:49:59 +02:00
|
|
|
from django.db.models.fields.related import ManyToManyField
|
2024-08-22 22:49:07 +02:00
|
|
|
|
2024-08-26 11:49:59 +02:00
|
|
|
from c3nav.editor.operations import DatabaseOperation, ObjectReference, FieldValuesDict, CreateObjectOperation, \
|
|
|
|
UpdateObjectOperation, DeleteObjectOperation, ClearManyToManyOperation, UpdateManyToManyOperation, CollectedChanges
|
2024-08-26 16:45:16 +02:00
|
|
|
from c3nav.mapdata.fields import I18nField
|
2024-08-26 20:23:32 +02:00
|
|
|
from c3nav.mapdata.models import LocationSlug
|
2024-08-22 22:49:07 +02:00
|
|
|
|
|
|
|
try:
|
|
|
|
from asgiref.local import Local as LocalContext
|
|
|
|
except ImportError:
|
|
|
|
from threading import local as LocalContext
|
|
|
|
|
|
|
|
|
|
|
|
# Context-local holder for the currently active DatabaseOverlayManager, stored as the ``manager``
# attribute (set and reset by DatabaseOverlayManager.enable). Uses asgiref's Local if available,
# otherwise falls back to threading.local.
overlay_state = LocalContext()
|
|
|
|
|
|
|
|
|
|
|
|
class InterceptAbortTransaction(Exception):
    """Raised inside ``DatabaseOverlayManager.enable`` to roll back its transaction instead of committing it."""
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class DatabaseOverlayManager:
    """
    Record changes to ``mapdata`` models as overlay operations.

    While a manager is active (see :meth:`enable`), the module-level ``handle_*`` signal receivers forward
    model change signals to it. It turns them into ``DatabaseOperation`` objects collected in
    :attr:`new_operations`.
    """

    # changes that get applied when the overlay is enabled; recorded operations can be appended to them
    # via save_new_operations()
    changes: CollectedChanges
    # operations recorded while this manager was active
    new_operations: list[DatabaseOperation] = field(default_factory=list)
    # database field values of objects right before they got changed, kept until the matching
    # post-change signal has been handled
    pre_change_values: dict[ObjectReference, FieldValuesDict] = field(default_factory=dict)

    @classmethod
    @contextmanager
    def enable(cls, changes: CollectedChanges | None, commit: bool):
        """
        Activate an overlay manager for the duration of the ``with`` block.

        The given changes are prefetched and applied inside a database transaction before the manager
        becomes active, so applying them is not recorded again. Changes made inside the block are
        recorded by the yielded manager.

        :param changes: previously collected changes to apply first, or ``None`` to start empty
        :param commit: if false, the transaction is rolled back once the block finishes
        :raises TypeError: if an overlay manager is already active in this context
        """
        if getattr(overlay_state, 'manager', None) is not None:
            raise TypeError('an overlay manager is already active')
        if changes is None:
            changes = CollectedChanges()
        try:
            with transaction.atomic():
                manager = cls(changes)
                # apply before activating the manager so these changes are not recorded a second time
                manager.changes.prefetch().apply()
                overlay_state.manager = manager
                yield manager
                if not commit:
                    # leave the atomic block via an exception so everything done inside is rolled back
                    raise InterceptAbortTransaction
        except InterceptAbortTransaction:
            pass
        finally:
            overlay_state.manager = None

    def save_new_operations(self):
        """Append all operations recorded so far to ``self.changes.operations``."""
        # NOTE(review): new_operations is not cleared here, so calling this twice appends duplicates
        self.changes.operations.extend(self.new_operations)

    @staticmethod
    def get_model_field_values(instance: Model) -> FieldValuesDict:
        """
        Return the field values of ``instance`` as produced by Django's JSON serializer.

        Only the serializer's ``fields`` mapping is used, so the primary key is not included. For
        ``LocationSlug`` subclasses the ``slug`` is added explicitly. Presumably this is because it lives
        on the parent model and is therefore not among the serialized fields.
        """
        values = json.loads(serializers.serialize("json", [instance]))[0]["fields"]
        if issubclass(instance._meta.model, LocationSlug):
            values["slug"] = instance.slug
        return values

    def get_ref_and_pre_change_values(self, instance: Model) -> tuple[ObjectReference, FieldValuesDict | None]:
        """
        Return the reference for ``instance`` and the field values remembered for it before the change.

        The remembered values are consumed. If there were any, they are stored in
        ``self.changes.prev_values``. The instance's ``titles`` attribute (``None`` if it has none) is
        stored in ``self.changes.prev_titles``. The returned values are ``None`` if nothing was remembered.
        """
        ref = ObjectReference.from_instance(instance)
        pre_change_values = self.pre_change_values.pop(ref, None)
        if pre_change_values:
            self.changes.prev_values.setdefault(ref.model, {})[ref.id] = pre_change_values
            # NOTE(review): in post_save, ``instance`` already holds the new values, so these titles may
            # already be the updated ones — confirm this is intended
            self.changes.prev_titles.setdefault(ref.model, {})[ref.id] = getattr(instance, 'titles', None)
        return ref, pre_change_values

    def handle_pre_change_instance(self, instance: Model, **kwargs):
        """
        Remember the current database values of ``instance`` before it gets changed.

        Does nothing in three cases:
        - the instance has no primary key;
        - values for this object are already pending or already recorded in ``self.changes.prev_values``;
        - the object does not exist in the database yet.
        """
        if instance.pk is None:
            return
        ref = ObjectReference.from_instance(instance)
        if ref in self.pre_change_values or ref.id in self.changes.prev_values.get(ref.model, {}):
            return
        model = instance._meta.model
        try:
            db_instance = model.objects.get(pk=instance.pk)
        except model.DoesNotExist:
            # a new object saved with an explicitly set primary key: there are no previous values
            return
        self.pre_change_values[ref] = self.get_model_field_values(db_instance)

    def handle_post_save(self, instance: Model, created: bool, update_fields: set | None, **kwargs):
        """
        Record a create or update operation for a saved instance.

        For created instances, all serialized field values are recorded. For updates, the values are first
        limited to ``update_fields`` (if given). If the pre-change values are known, they are further
        limited to the fields that actually changed. For ``I18nField`` fields, only the languages whose
        value changed are kept.
        """
        field_values = self.get_model_field_values(instance)
        ref, pre_change_values = self.get_ref_and_pre_change_values(instance)

        if created:
            self.new_operations.append(CreateObjectOperation(obj=ref, fields=field_values))
            return

        if update_fields:
            field_values = {name: value for name, value in field_values.items() if name in update_fields}

        if pre_change_values is not None:
            # only keep fields whose value differs from the value before the change
            field_values = {name: value for name, value in field_values.items() if value != pre_change_values[name]}

            # special diffing within the i18n fields: only keep the languages that changed
            # (a language that was removed is recorded with the value None)
            for field_name in tuple(field_values):
                if isinstance(instance._meta.get_field(field_name), I18nField):
                    before_val = pre_change_values[field_name]
                    after_val = field_values[field_name]
                    field_values[field_name] = {
                        lang: after_val.get(lang, None)
                        for lang in (set(before_val) | set(after_val))
                        if before_val.get(lang, None) != after_val.get(lang, None)
                    }

        self.new_operations.append(UpdateObjectOperation(obj=ref, fields=field_values))

    def handle_post_delete(self, instance: Model, **kwargs):
        """Record a delete operation for ``instance``."""
        # called for its side effect of moving the remembered pre-change values into self.changes
        ref, _ = self.get_ref_and_pre_change_values(instance)
        self.new_operations.append(DeleteObjectOperation(obj=ref))

    def handle_m2m_changed(self, sender: Type[Model], instance: Model, action: str, model: Type[Model],
                           pk_set: set | None, reverse: bool, **kwargs):
        """
        Record a many-to-many change on ``instance``.

        Actions are handled as follows:
        - ``pre_*`` actions only remember the instance's current values.
        - ``post_clear`` records a clear operation.
        - Any other action is treated as an add (``post_add``) or a remove. It is merged into the previous
          operation if that was an update of the same many-to-many field of the same object; otherwise a
          new update operation is recorded.

        :raises NotImplementedError: for changes reported from the reverse side of the relation
        :raises ValueError: if no many-to-many field of ``instance`` uses ``sender`` as its through model
        """
        if reverse:
            raise NotImplementedError

        if action.startswith("pre_"):
            return self.handle_pre_change_instance(sender=instance._meta.model, instance=instance)

        # find the many-to-many field this change belongs to via its through model
        for m2m_field in instance._meta.get_fields():
            if isinstance(m2m_field, ManyToManyField) and m2m_field.remote_field.through == sender:
                break
        else:
            raise ValueError

        ref, _ = self.get_ref_and_pre_change_values(instance)

        if action == "post_clear":
            self.new_operations.append(ClearManyToManyOperation(obj=ref, field=m2m_field.name))
            return

        if self.new_operations:
            last_change = self.new_operations[-1]
            # merge consecutive add/remove changes of the same field of the same object into one operation
            # (assumes add_values/remove_values are sets on UpdateManyToManyOperation — TODO confirm)
            if (isinstance(last_change, UpdateManyToManyOperation)
                    and last_change.obj == ref and last_change.field == m2m_field.name):
                if action == "post_add":
                    last_change.add_values.update(pk_set)
                    last_change.remove_values.difference_update(pk_set)
                else:
                    last_change.add_values.difference_update(pk_set)
                    last_change.remove_values.update(pk_set)
                return

        if action == "post_add":
            self.new_operations.append(UpdateManyToManyOperation(obj=ref, field=m2m_field.name,
                                                                 add_values=list(pk_set)))
        else:
            self.new_operations.append(UpdateManyToManyOperation(obj=ref, field=m2m_field.name,
                                                                 remove_values=list(pk_set)))
|
2024-08-22 22:49:07 +02:00
|
|
|
|
|
|
|
|
|
|
|
def handle_pre_change_instance(sender: Type[Model], **kwargs):
    """
    Signal receiver: forward a pre-change notification for a ``mapdata`` model to the active overlay manager.

    Models from other apps are ignored, as are signals sent while no overlay manager is active.
    """
    if sender._meta.app_label == 'mapdata':
        active_manager: DatabaseOverlayManager = getattr(overlay_state, 'manager', None)
        if active_manager:
            active_manager.handle_pre_change_instance(sender=sender, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
def handle_post_save(sender: Type[Model], **kwargs):
    """
    Signal receiver: forward a post-save notification for a ``mapdata`` model to the active overlay manager.

    Models from other apps are ignored, as are signals sent while no overlay manager is active.
    """
    if sender._meta.app_label == 'mapdata':
        active_manager: DatabaseOverlayManager = getattr(overlay_state, 'manager', None)
        if active_manager:
            active_manager.handle_post_save(sender=sender, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
def handle_post_delete(sender: Type[Model], **kwargs):
    """
    Signal receiver: forward a post-delete notification for a ``mapdata`` model to the active overlay manager.

    Models from other apps are ignored, as are signals sent while no overlay manager is active.
    """
    if sender._meta.app_label == 'mapdata':
        active_manager: DatabaseOverlayManager = getattr(overlay_state, 'manager', None)
        if active_manager:
            active_manager.handle_post_delete(sender=sender, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
def handle_m2m_changed(sender: Type[Model], **kwargs):
    """
    Signal receiver: forward a many-to-many change to the active overlay manager.

    ``sender`` is checked for the ``mapdata`` app label; signals for other apps, and signals sent while no
    overlay manager is active, are ignored.
    """
    if sender._meta.app_label == 'mapdata':
        active_manager: DatabaseOverlayManager = getattr(overlay_state, 'manager', None)
        if active_manager:
            active_manager.handle_m2m_changed(sender=sender, **kwargs)
|