2024-08-22 22:49:07 +02:00
|
|
|
import json
|
|
|
|
from contextlib import contextmanager
|
|
|
|
from dataclasses import dataclass, field
|
2024-08-26 11:49:59 +02:00
|
|
|
from typing import Type
|
2024-08-22 22:49:07 +02:00
|
|
|
|
|
|
|
from django.core import serializers
|
|
|
|
from django.db import transaction
|
|
|
|
from django.db.models import Model
|
2024-08-26 11:49:59 +02:00
|
|
|
from django.db.models.fields.related import ManyToManyField
|
2024-08-22 22:49:07 +02:00
|
|
|
|
2024-11-21 19:22:39 +01:00
|
|
|
from c3nav.editor.operations import CreateObjectOperation, \
|
2024-09-26 13:19:29 +02:00
|
|
|
UpdateObjectOperation, DeleteObjectOperation, ClearManyToManyOperation, UpdateManyToManyOperation, \
|
2024-11-21 19:22:39 +01:00
|
|
|
DatabaseOperationCollection, FieldValuesDict, ObjectReference
|
2024-08-26 16:45:16 +02:00
|
|
|
from c3nav.mapdata.fields import I18nField
|
2024-08-26 20:23:32 +02:00
|
|
|
from c3nav.mapdata.models import LocationSlug
|
2024-08-22 22:49:07 +02:00
|
|
|
|
|
|
|
# Per-context storage for the currently active database overlay: prefer asgiref's Local when it is
# importable, otherwise fall back to the standard library's threading.local.
try:
    from asgiref.local import Local as LocalContext
except ImportError:
    from threading import local as LocalContext

# Holds the active DatabaseOverlayManager in its ``manager`` attribute; the attribute is unset or None
# while no overlay is active (DatabaseOverlayManager.enable sets it and resets it to None on exit).
overlay_state = LocalContext()
|
|
|
|
|
|
|
|
|
|
|
|
class InterceptAbortTransaction(Exception):
    """
    Internal control-flow exception used to roll back an overlay transaction.

    DatabaseOverlayManager.enable raises it inside its ``transaction.atomic()`` block when changes
    should not be committed, and swallows it right outside that block.
    """
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class DatabaseOverlayManager:
    """
    This class handles the currently active database overlay and will apply and/or intercept changes.

    While an overlay is active (see :meth:`enable`), the module-level ``handle_*`` functions forward model
    events of the ``mapdata`` app to this manager. The manager records them as operations in
    :attr:`operations`.
    """
    # operations recorded while this overlay is active (the pre-applied operations are not included)
    operations: DatabaseOperationCollection = field(default_factory=DatabaseOperationCollection)
    # field values captured by handle_pre_change_instance, keyed by object reference; an entry is moved into
    # ``operations.prev`` the first time the matching post-change handler runs
    pre_change_values: dict[ObjectReference, FieldValuesDict] = field(default_factory=dict, init=False, repr=False)

    @classmethod
    @contextmanager
    def enable(cls, operations: DatabaseOperationCollection | None = None, commit: bool = False):
        """
        Context manager to enable the database overlay, optionally pre-applying the given changes.
        Only one overlay can be active at the same time, or else you get a TypeError.

        Everything happens inside a single ``transaction.atomic()`` block. Unless ``commit`` is True, that
        transaction is rolled back when the context exits, so neither the pre-applied operations nor the
        changes made inside the context persist.

        :param operations: what operations to pre-apply
        :param commit: whether to actually commit operations to the database or revert them at the end
        :raises TypeError: if an overlay is already active in the current context
        """
        if getattr(overlay_state, "manager", None) is not None:
            raise TypeError("Only one overlay can be active at the same time")
        if operations is None:
            operations = DatabaseOperationCollection()

        try:
            with transaction.atomic():
                # the new manager starts with an empty operation list but shares the prev state of the
                # pre-applied operations
                manager = cls(operations=DatabaseOperationCollection(prev=operations.prev))
                # applied before the manager is registered, so these changes are not intercepted again
                operations.prefetch().apply()
                overlay_state.manager = manager
                yield manager
                if not commit:
                    # raising out of atomic() rolls the transaction back; the exception is swallowed below
                    raise InterceptAbortTransaction
        except InterceptAbortTransaction:
            pass
        finally:
            overlay_state.manager = None

    @staticmethod
    def get_model_field_values(instance: Model) -> FieldValuesDict:
        """
        Return the field values of *instance* as produced by Django's JSON serializer.

        For LocationSlug subclasses, the ``slug`` value is added explicitly.
        """
        values = json.loads(serializers.serialize("json", [instance]))[0]["fields"]
        if issubclass(instance._meta.model, LocationSlug):
            values["slug"] = instance.slug
        return values

    def get_ref_and_pre_change_values(self, instance: Model) -> tuple[ObjectReference, FieldValuesDict]:
        """
        Return the reference of *instance* and the field values it had before it was first changed.

        On the first change of an object within this overlay, the values captured by
        :meth:`handle_pre_change_instance` are moved into ``operations.prev`` (together with the instance's
        ``titles``, if it has any). Later calls reuse the values stored there.

        NOTE(review): raises KeyError if no pre-change values were captured for a not-yet-known object.
        """
        ref = ObjectReference.from_instance(instance)

        prev = self.operations.prev.get(ref)
        if prev is None:
            pre_change_values = self.pre_change_values.pop(ref)
            self.operations.prev.set(ref, values=pre_change_values, titles=getattr(instance, 'titles', None))
        else:
            pre_change_values = prev.values
        return ref, pre_change_values

    def handle_pre_change_instance(self, instance: Model, **kwargs):
        """
        Capture the current database state of *instance* before it gets changed.

        Nothing is captured for unsaved instances, for objects whose values are already known (pending or
        in ``operations.prev``), or for objects that do not exist in the database yet.
        """
        if instance.pk is None:
            return
        ref = ObjectReference.from_instance(instance)
        if ref not in self.pre_change_values and self.operations.prev.get(ref) is None:
            model = instance._meta.model
            try:
                current = model.objects.get(pk=instance.pk)
            except model.DoesNotExist:
                # new object with a preassigned pk: there is no previous state, and the subsequent
                # post-save (created=True) records it as a creation
                return
            self.pre_change_values[ref] = self.get_model_field_values(current)

    def handle_post_save(self, instance: Model, created: bool, update_fields: set | None, **kwargs):
        """
        Record a save of *instance* as a create or an update operation.

        Updates only contain the fields that differ from the pre-change values, further restricted to
        ``update_fields`` if given. For I18nFields, only the languages whose value changed are kept.
        """
        field_values = self.get_model_field_values(instance)

        if created:
            ref = ObjectReference.from_instance(instance)
            self.operations.append(CreateObjectOperation(obj=ref, fields=field_values))
            return

        ref, pre_change_values = self.get_ref_and_pre_change_values(instance)

        if update_fields:
            field_values = {name: value for name, value in field_values.items() if name in update_fields}

        if pre_change_values is not None:
            field_values = {name: value for name, value in field_values.items()
                            if value != pre_change_values[name]}

            # special diffing within the i18n fields: keep only changed languages
            # (a removed language is recorded with the value None)
            for field_name in tuple(field_values):
                if isinstance(instance._meta.get_field(field_name), I18nField):
                    before_val = pre_change_values[field_name]
                    after_val = field_values[field_name]
                    field_values[field_name] = {
                        lang: after_val.get(lang, None)
                        for lang in (set(before_val) | set(after_val))
                        if before_val.get(lang, None) != after_val.get(lang, None)
                    }

        self.operations.append(UpdateObjectOperation(obj=ref, fields=field_values))

    def handle_post_delete(self, instance: Model, **kwargs):
        """Record the deletion of *instance* as a delete operation."""
        # not isinstance() cause it would match submodels
        if instance._meta.model is LocationSlug:
            return
        # called for its side effect: makes sure the object's previous values end up in operations.prev
        ref, _ = self.get_ref_and_pre_change_values(instance)
        self.operations.append(DeleteObjectOperation(obj=ref))

    def handle_m2m_changed(self, sender: Type[Model], instance: Model, action: str, model: Type[Model],
                           pk_set: set | None, reverse: bool, **kwargs):
        """
        Record a many-to-many change on *instance* (forward side only).

        ``pre_*`` actions capture the instance's previous state. ``post_clear`` is recorded as a clear
        operation. ``post_add`` and the other post actions are recorded as an add or a remove update; they
        are merged into the previous operation if that operation updates the same relation of the same
        object.

        :raises NotImplementedError: for changes sent from the reverse side of the relation
        :raises ValueError: if no ManyToManyField of the instance uses *sender* as its through model
        """
        if reverse:
            raise NotImplementedError("reverse many-to-many changes are not supported by the overlay")

        if action.startswith("pre_"):
            return self.handle_pre_change_instance(sender=instance._meta.model, instance=instance)

        # find the m2m field whose through model sent this signal
        for field in instance._meta.get_fields():
            if isinstance(field, ManyToManyField) and field.remote_field.through == sender:
                break
        else:
            raise ValueError(f"no ManyToManyField of {instance._meta.model!r} uses through model {sender!r}")

        ref, _ = self.get_ref_and_pre_change_values(instance)

        if action == "post_clear":
            self.operations.append(ClearManyToManyOperation(obj=ref, field=field.name))
            return

        if self.operations:
            last_change = self.operations[-1]
            # merge consecutive updates of the same relation into a single operation
            if (isinstance(last_change, UpdateManyToManyOperation)
                    and last_change.obj == ref and last_change.field == field.name):
                if action == "post_add":
                    last_change.add_values.update(pk_set)
                    last_change.remove_values.difference_update(pk_set)
                else:
                    last_change.add_values.difference_update(pk_set)
                    last_change.remove_values.update(pk_set)
                return

        if action == "post_add":
            self.operations.append(UpdateManyToManyOperation(obj=ref, field=field.name, add_values=list(pk_set)))
        else:
            self.operations.append(UpdateManyToManyOperation(obj=ref, field=field.name, remove_values=list(pk_set)))
|
2024-08-22 22:49:07 +02:00
|
|
|
|
|
|
|
|
|
|
|
def _get_active_manager(sender: Type[Model]) -> "DatabaseOverlayManager | None":
    """Return the active overlay manager if *sender* belongs to the ``mapdata`` app, otherwise None."""
    if sender._meta.app_label != 'mapdata':
        return None
    return getattr(overlay_state, 'manager', None)


def handle_pre_change_instance(sender: Type[Model], **kwargs):
    """Forward a pre-change event for a ``mapdata`` model to the active overlay manager, if any."""
    manager = _get_active_manager(sender)
    if manager is not None:
        manager.handle_pre_change_instance(sender=sender, **kwargs)


def handle_post_save(sender: Type[Model], **kwargs):
    """Forward a post-save event for a ``mapdata`` model to the active overlay manager, if any."""
    manager = _get_active_manager(sender)
    if manager is not None:
        manager.handle_post_save(sender=sender, **kwargs)


def handle_post_delete(sender: Type[Model], **kwargs):
    """Forward a post-delete event for a ``mapdata`` model to the active overlay manager, if any."""
    manager = _get_active_manager(sender)
    if manager is not None:
        manager.handle_post_delete(sender=sender, **kwargs)


def handle_m2m_changed(sender: Type[Model], **kwargs):
    """Forward an m2m-changed event whose sender is a ``mapdata`` model to the active overlay manager, if any."""
    manager = _get_active_manager(sender)
    if manager is not None:
        manager.handle_m2m_changed(sender=sender, **kwargs)
|