import json
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Type

from django.core import serializers
from django.db import transaction
from django.db.models import Model
from django.db.models.fields.related import ManyToManyField

from c3nav.editor.operations import DatabaseOperation, ObjectReference, FieldValuesDict, CreateObjectOperation, \
    UpdateObjectOperation, DeleteObjectOperation, ClearManyToManyOperation, UpdateManyToManyOperation, CollectedOperations
from c3nav.mapdata.fields import I18nField
from c3nav.mapdata.models import LocationSlug

try:
    from asgiref.local import Local as LocalContext
except ImportError:
    from threading import local as LocalContext

# Context-local storage; its "manager" attribute holds the currently enabled DatabaseOverlayManager (or None).
# asgiref's Local is used when available, otherwise threading.local.
overlay_state = LocalContext()


class InterceptAbortTransaction(Exception):
    """Raised inside ``DatabaseOverlayManager.enable`` to make ``transaction.atomic()`` roll back on exit."""
    pass


@dataclass
class DatabaseOverlayManager:
    """
    Record changes to mapdata models as editor operations while enabled.

    Activate a manager with :meth:`enable`. While it is active, the module-level handler functions in this
    module forward mapdata model changes to it, and it appends the matching operations to ``new_operations``.
    :meth:`save_new_operations` copies those into ``changes.operations``.
    """
    # Operations applied when the overlay is enabled; save_new_operations() extends changes.operations.
    changes: CollectedOperations
    # Operations recorded from model changes while this manager is active.
    new_operations: list[DatabaseOperation] = field(default_factory=list)
    # Field values captured before a change, keyed by object reference, until the matching post-change handler.
    pre_change_values: dict[ObjectReference, FieldValuesDict] = field(default_factory=dict)

    @classmethod
    @contextmanager
    def enable(cls, changes: CollectedOperations | None, commit: bool):
        """
        Apply ``changes`` inside a transaction and yield a manager that records further changes.

        ``changes`` (an empty ``CollectedOperations`` if None) is prefetched and applied inside
        ``transaction.atomic()``; afterwards the new manager is stored in ``overlay_state`` so subsequent
        mapdata changes are recorded by it. If ``commit`` is false, the transaction is rolled back when the
        ``with`` body finishes. The manager is always unregistered on exit.

        :raises TypeError: if an overlay manager is already enabled in this context.
        """
        if getattr(overlay_state, 'manager', None) is not None:
            raise TypeError('a database overlay is already enabled in this context')
        if changes is None:
            changes = CollectedOperations()
        try:
            with transaction.atomic():
                manager = cls(changes)
                # Apply existing operations before registering the manager, so they are not re-recorded.
                manager.changes.prefetch().apply()
                overlay_state.manager = manager
                yield manager
                if not commit:
                    # Raising out of atomic() rolls the transaction back; the exception is swallowed below.
                    raise InterceptAbortTransaction
        except InterceptAbortTransaction:
            pass
        finally:
            overlay_state.manager = None

    def save_new_operations(self):
        """Append all operations recorded by this manager to ``changes.operations``."""
        self.changes.operations.extend(self.new_operations)

    @staticmethod
    def get_model_field_values(instance: Model) -> FieldValuesDict:
        """
        Return the field values of ``instance`` as produced by Django's JSON serializer.

        The ``"fields"`` entry of the serialized object is returned; for ``LocationSlug`` subclasses the
        instance's ``slug`` is added under ``"slug"``.
        """
        values = json.loads(serializers.serialize("json", [instance]))[0]["fields"]
        if issubclass(instance._meta.model, LocationSlug):
            values["slug"] = instance.slug
        return values

    def get_ref_and_pre_change_values(self, instance: Model) -> tuple[ObjectReference, FieldValuesDict | None]:
        """
        Return the reference of ``instance`` and the field values captured for it before the current change.

        The captured values are consumed. If values were captured, they are recorded in ``changes.prev``
        together with the instance's ``titles`` attribute (if any). The returned values are None when nothing
        was captured, e.g. for an instance that had no pk yet (a newly created object) or one whose previous
        values were already recorded in ``changes.prev``.
        """
        ref = ObjectReference.from_instance(instance)

        # handle_pre_change_instance deliberately captures nothing in some cases, so a missing entry is expected.
        pre_change_values = self.pre_change_values.pop(ref, None)
        if pre_change_values is not None:
            self.changes.prev.set(ref, values=pre_change_values, titles=getattr(instance, 'titles', None))
        return ref, pre_change_values

    def handle_pre_change_instance(self, instance: Model, **kwargs):
        """
        Capture the current database values of ``instance`` before it gets changed.

        Nothing is captured for instances without a pk, or if values for this reference were already captured
        or already exist in ``changes.prev``. The values are reloaded from the database by pk.
        """
        if instance.pk is None:
            return
        ref = ObjectReference.from_instance(instance)
        if ref not in self.pre_change_values and self.changes.prev.get(ref) is None:
            self.pre_change_values[ref] = self.get_model_field_values(
                instance._meta.model.objects.get(pk=instance.pk)
            )

    def handle_post_save(self, instance: Model, created: bool, update_fields: set | None, **kwargs):
        """
        Record a create or update operation for a saved instance.

        For created instances, all serialized field values are recorded. For updates, values are limited to
        ``update_fields`` (if given) and, when pre-change values were captured, to fields whose value changed.
        I18n fields are further reduced to the languages whose value changed.
        """
        field_values = self.get_model_field_values(instance)

        ref, pre_change_values = self.get_ref_and_pre_change_values(instance)

        if created:
            self.new_operations.append(CreateObjectOperation(obj=ref, fields=field_values))
            return

        if update_fields:
            field_values = {name: value for name, value in field_values.items() if name in update_fields}

        if pre_change_values is not None:
            field_values = {name: value for name, value in field_values.items() if value != pre_change_values[name]}
            # special diffing within the i18n fields: keep only languages whose value changed
            for field_name in tuple(field_values):
                if not isinstance(instance._meta.get_field(field_name), I18nField):
                    continue
                before_val = pre_change_values[field_name]
                after_val = field_values[field_name]
                field_values[field_name] = {
                    lang: after_val.get(lang)
                    for lang in set(before_val) | set(after_val)
                    if before_val.get(lang) != after_val.get(lang)
                }

        self.new_operations.append(UpdateObjectOperation(obj=ref, fields=field_values))

    def handle_post_delete(self, instance: Model, **kwargs):
        """Record a delete operation for ``instance``."""
        # Called for its side effect of recording captured pre-change values in changes.prev.
        ref, _ = self.get_ref_and_pre_change_values(instance)
        self.new_operations.append(DeleteObjectOperation(obj=ref))

    def handle_m2m_changed(self, sender: Type[Model], instance: Model, action: str, model: Type[Model],
                           pk_set: set | None, reverse: bool, **kwargs):
        """
        Record a many-to-many change of ``instance``.

        ``pre_*`` actions only capture the instance's pre-change values. ``post_clear`` records a clear
        operation. Other post actions record an add (``post_add``) or remove of ``pk_set``, merged into the
        previous operation when that one updates the same field of the same object.

        :raises NotImplementedError: for changes sent from the reverse side of the relation.
        :raises ValueError: if no many-to-many field of ``instance`` uses ``sender`` as its through model.
        """
        if reverse:
            raise NotImplementedError('reverse many-to-many changes are not supported')

        if action.startswith("pre_"):
            self.handle_pre_change_instance(sender=instance._meta.model, instance=instance)
            return

        for m2m_field in instance._meta.get_fields():
            if isinstance(m2m_field, ManyToManyField) and m2m_field.remote_field.through == sender:
                break
        else:
            raise ValueError('no many-to-many field uses the given through model')

        ref, _ = self.get_ref_and_pre_change_values(instance)

        if action == "post_clear":
            self.new_operations.append(ClearManyToManyOperation(obj=ref, field=m2m_field.name))
            return

        if self.new_operations:
            last_change = self.new_operations[-1]
            # Merge with the previous operation if it updates the same m2m field of the same object.
            # NOTE(review): relies on add_values/remove_values being sets (see the set methods used below),
            # presumably coerced from the lists passed to the constructor — confirm in the operation model.
            if (isinstance(last_change, UpdateManyToManyOperation)
                    and last_change.obj == ref and last_change.field == m2m_field.name):
                if action == "post_add":
                    last_change.add_values.update(pk_set)
                    last_change.remove_values.difference_update(pk_set)
                else:
                    last_change.add_values.difference_update(pk_set)
                    last_change.remove_values.update(pk_set)
                return

        if action == "post_add":
            self.new_operations.append(
                UpdateManyToManyOperation(obj=ref, field=m2m_field.name, add_values=list(pk_set))
            )
        else:
            self.new_operations.append(
                UpdateManyToManyOperation(obj=ref, field=m2m_field.name, remove_values=list(pk_set))
            )


def _get_active_manager(sender: Type[Model]) -> DatabaseOverlayManager | None:
    """Return the enabled overlay manager if ``sender`` belongs to the mapdata app, otherwise None."""
    if sender._meta.app_label != 'mapdata':
        return None
    return getattr(overlay_state, 'manager', None)


def handle_pre_change_instance(sender: Type[Model], **kwargs):
    """Forward a pre-change notification for a mapdata model to the enabled overlay manager, if any."""
    manager = _get_active_manager(sender)
    if manager is not None:
        manager.handle_pre_change_instance(sender=sender, **kwargs)


def handle_post_save(sender: Type[Model], **kwargs):
    """Forward a post_save notification for a mapdata model to the enabled overlay manager, if any."""
    manager = _get_active_manager(sender)
    if manager is not None:
        manager.handle_post_save(sender=sender, **kwargs)


def handle_post_delete(sender: Type[Model], **kwargs):
    """Forward a post_delete notification for a mapdata model to the enabled overlay manager, if any."""
    manager = _get_active_manager(sender)
    if manager is not None:
        manager.handle_post_delete(sender=sender, **kwargs)


def handle_m2m_changed(sender: Type[Model], **kwargs):
    """Forward an m2m_changed notification (sender is the through model) to the enabled overlay manager."""
    manager = _get_active_manager(sender)
    if manager is not None:
        manager.handle_m2m_changed(sender=sender, **kwargs)