diff --git a/src/c3nav/editor/operations.py b/src/c3nav/editor/operations.py index c8b34cfd..34485735 100644 --- a/src/c3nav/editor/operations.py +++ b/src/c3nav/editor/operations.py @@ -30,6 +30,27 @@ class ObjectReference(BaseSchema): return cls(model=instance._meta.model_name, id=instance.pk) +class PreviousObject(BaseSchema): + titles: dict[str, str] | None + values: FieldValuesDict + + +class PreviousObjects(BaseSchema): + objects: dict[str, dict[int, PreviousObject]] = {} + + def get(self, ref: ObjectReference) -> PreviousObject | None: + return self.objects.get(ref.model, {}).get(ref.id, None) + + def set(self, ref: ObjectReference, values: FieldValuesDict, titles: dict | None): + self.objects.setdefault(ref.model, {})[ref.id] = PreviousObject( + values=values, + titles=titles, + ) + + def get_ids(self) -> dict[str, set[int]]: + return {model: set(objs.keys()) for model, objs in self.objects.items()} + + class BaseOperation(BaseSchema): obj: ObjectReference uuid: UUID = Field(default_factory=uuid4) @@ -168,13 +189,11 @@ class ChangedObject(BaseSchema): class CollectedChanges(BaseSchema): uuid: UUID = Field(default_factory=uuid4) - prev_titles: dict[str, dict[int, dict[str, str] | None]] = {} - prev_values: dict[str, dict[int, FieldValuesDict]] = {} + prev: PreviousObjects = PreviousObjects() operations: list[DatabaseOperation] = [] def prefetch(self) -> "CollectedChangesPrefetch": - ids_to_query: dict[str, set[int]] = {model_name: set(val.keys()) - for model_name, val in self.prev_values.items()} + ids_to_query: dict[str, set[int]] = self.prev.get_ids() instances: dict[ObjectReference, Model] = {} for model_name, ids in ids_to_query.items(): @@ -236,19 +255,18 @@ class CollectedChangesPrefetch: def apply(self): # todo: what if unique constraint error occurs? - prev_values = copy.deepcopy(self.changes.prev_values) for operation in self.changes.operations: if isinstance(operation, CreateObjectOperation): self.instances[operation.obj] = operation.apply_create() else: - in_prev_values = operation.obj.id in prev_values.get(operation.obj.model, {}) - if not in_prev_values: + prev_obj = self.changes.prev.get(operation.obj) + if prev_obj is None: print('WARN WARN WARN') - values = prev_values.setdefault(operation.obj.model, {}).setdefault(operation.obj.id, {}) + values = prev_obj.values try: instance = self.instances[operation.obj] except KeyError: - if not in_prev_values: + if prev_obj is None: instance = apps.get_model("mapdata", operation.obj.model).filter(pk=operation.obj.id).first() else: instance = None diff --git a/src/c3nav/editor/overlay.py b/src/c3nav/editor/overlay.py index 8b4a9a56..a964ff13 100644 --- a/src/c3nav/editor/overlay.py +++ b/src/c3nav/editor/overlay.py @@ -65,18 +65,15 @@ class DatabaseOverlayManager: def get_ref_and_pre_change_values(self, instance: Model) -> tuple[ObjectReference, FieldValuesDict]: ref = ObjectReference.from_instance(instance) - pre_change_values = self.pre_change_values.pop(ref, None) - if pre_change_values: - self.changes.prev_values.setdefault(ref.model, {})[ref.id] = pre_change_values - self.changes.prev_titles.setdefault(ref.model, {})[ref.id] = getattr(instance, 'titles', None) - + pre_change_values = self.pre_change_values.pop(ref) + self.changes.prev.set(ref, values=pre_change_values, titles=getattr(instance, 'titles', None)) return ref, pre_change_values def handle_pre_change_instance(self, instance: Model, **kwargs): if instance.pk is None: return ref = ObjectReference.from_instance(instance) - if ref not in self.pre_change_values and ref.id not in self.changes.prev_values.get(ref.model, {}): + if ref not in self.pre_change_values and self.changes.prev.get(ref) is None: self.pre_change_values[ref] = self.get_model_field_values( instance._meta.model.objects.get(pk=instance.pk) ) diff --git a/src/c3nav/editor/views/changes.py b/src/c3nav/editor/views/changes.py index 509f8d72..3281003e 100644 --- a/src/c3nav/editor/views/changes.py +++ b/src/c3nav/editor/views/changes.py @@ -198,7 +198,7 @@ def changeset_detail(request, pk): if changed_object.created and not changed_object.deleted: added_redirects.setdefault(changed_object.fields["target"], set()).add(changed_object.fields["slug"]) elif changed_object.deleted: - orig_values = changeset.changes.prev_values["locationredirect"][changed_object.obj.id] + orig_values = changeset.changes.prev.get(changed_object.obj).values removed_redirects.setdefault(orig_values["target"], set()).add(orig_values["slug"]) else: raise ValueError # dafuq? not possibile through the editor @@ -218,7 +218,7 @@ def changeset_detail(request, pk): else: title = next(iter(changed_object.titles.values())) - prev_values = changeset.changes.prev_values[changed_object.obj.model][changed_object.obj.id] + prev_values = changeset.changes.prev.get(changed_object.obj).values edit_url = None if not changed_object.deleted: