diff --git a/src/c3nav/editor/changes.py b/src/c3nav/editor/changes.py index 0f137d95..611e2853 100644 --- a/src/c3nav/editor/changes.py +++ b/src/c3nav/editor/changes.py @@ -55,16 +55,11 @@ class DeleteObjectChange(BaseChange): type: Literal["delete"] = "delete" -class AddManyToManyChange(BaseSchema): - type: Literal["m2m_add"] = "m2m_add" +class UpdateManyToManyChange(BaseSchema): + type: Literal["m2m_add"] = "m2m_update" field: str - values: list[int] - - -class RemoveManyToManyChange(BaseSchema): - type: Literal["m2m_remove"] = "m2m_remove" - field: str - values: list[int] + add_values: set[int] = set() + remove_values: set[int] = set() class ClearManyToManyChange(BaseSchema): @@ -77,8 +72,7 @@ ChangeSetChange = Annotated[ CreateObjectChange, UpdateObjectChange, DeleteObjectChange, - AddManyToManyChange, - RemoveManyToManyChange, + UpdateManyToManyChange, ClearManyToManyChange, ], Discriminator("type"), @@ -116,13 +110,24 @@ def enable_changeset_overlay(changeset): @dataclass class ChangesetOverlayManager: changes: ChangeSetChanges - new_changes: bool = False + new_changes: list[ChangeSetChange] = field(default_factory=list) pre_change_values: dict[ObjectReference, FieldValuesDict] = field(default_factory=dict) - def get_model_field_values(self, instance: Model) -> FieldValuesDict: + @staticmethod + def get_model_field_values(instance: Model) -> FieldValuesDict: return json.loads(serializers.serialize("json", [instance]))[0]["fields"] - def handle_pre_change_instance(self, sender: Type[Model], instance: Model, **kwargs): + def get_ref_and_pre_change_values(self, instance: Model) -> ObjectReference: + ref = ObjectReference.from_instance(instance) + + pre_change_values = self.pre_change_values.pop(ref, None) + if pre_change_values: + self.changes.prev_values[ref] = pre_change_values + self.changes.prev_reprs[ref] = str(instance) + + return ref + + def handle_pre_change_instance(self, instance: Model, **kwargs): if instance.pk is None: return ref = ObjectReference.from_instance(instance) @@ -131,40 +136,23 @@ class ChangesetOverlayManager: instance._meta.model.objects.get(pk=instance.pk) ) - def handle_post_save(self, sender: Type[Model], instance: Model, created: bool, - update_fields: set | None, **kwargs): + def handle_post_save(self, instance: Model, created: bool, update_fields: set | None, **kwargs): field_values = self.get_model_field_values(instance) - ref = ObjectReference.from_instance(instance) + ref = self.get_ref_and_pre_change_values(instance) if created: - self.changes.changes.append(CreateObjectChange(obj=ref, fields=field_values)) - from pprint import pprint - pprint(self.changes) + self.new_changes.append(CreateObjectChange(obj=ref, fields=field_values)) return if update_fields: field_values = {name: value for name, value in field_values.items() if name in update_fields} - pre_change_values = self.pre_change_values.pop(ref, None) - if pre_change_values: - self.changes.prev_values[ref] = pre_change_values - self.changes.prev_reprs[ref] = str(instance) - self.changes.changes.append(UpdateObjectChange(obj=ref, fields=field_values)) - from pprint import pprint - pprint(self.changes) + self.new_changes.append(UpdateObjectChange(obj=ref, fields=field_values)) - def handle_post_delete(self, sender: Type[Model], instance: Model, **kwargs): - ref = ObjectReference.from_instance(instance) - pre_change_values = self.pre_change_values.pop(ref, None) - if pre_change_values: - self.changes.prev_values[ref] = pre_change_values - self.changes.prev_reprs[ref] = str(instance) - self.changes.changes.append(DeleteObjectChange( - obj=ref, - )) - from pprint import pprint - pprint(self.changes) + def handle_post_delete(self, instance: Model, **kwargs): + ref = self.get_ref_and_pre_change_values(instance) + self.new_changes.append(DeleteObjectChange(obj=ref)) def handle_m2m_changed(self, sender: Type[Model], instance: Model, action: str, model: Type[Model], pk_set: set | None, reverse: bool, **kwargs): @@ -176,38 +164,33 @@ class ChangesetOverlayManager: for field in instance._meta.get_fields(): if isinstance(field, ManyToManyField): + # todo: actually identify field!! + raise NotImplementedError break else: raise ValueError - ref = ObjectReference.from_instance(instance) - pre_change_values = self.pre_change_values.pop(ref, None) - if pre_change_values: - self.changes.prev_values[ref] = pre_change_values + ref = self.get_ref_and_pre_change_values(instance) - match(action): - case "post_add": - self.changes.changes.append(AddManyToManyChange( - obj=ref, - field=field.name, - values=list(pk_set), - )) + if action == "post_clear": + self.new_changes.append(ClearManyToManyChange(obj=ref, field=field.name)) + return - case "post_remove": - self.changes.changes.append(RemoveManyToManyChange( - obj=ref, - field=field.name, - values=list(pk_set), - )) + if self.new_changes: + last_change = self.new_changes[-1] + if isinstance(last_change, UpdateManyToManyChange) and last_change == ref and last_change == field.name: + if action == "post_add": + last_change.add_values.update(pk_set) + last_change.remove_values.difference_update(pk_set) + else: + last_change.add_values.difference_update(pk_set) + last_change.remove_values.update(pk_set) + return - case "post_clear": - self.changes.changes.append(ClearManyToManyChange( - obj=ref, - field=field.name, - )) - - from pprint import pprint - pprint(self.changes) + if action == "post_add": + self.new_changes.append(UpdateManyToManyChange(obj=ref, field=field.name, add_values=list(pk_set))) + else: + self.new_changes.append(UpdateManyToManyChange(obj=ref, field=field.name, remove_values=list(pk_set))) def handle_pre_change_instance(sender: Type[Model], **kwargs):