diff --git a/src/c3nav/editor/changes.py b/src/c3nav/editor/changes.py index 5f6e2a26..0f137d95 100644 --- a/src/c3nav/editor/changes.py +++ b/src/c3nav/editor/changes.py @@ -7,7 +7,7 @@ from typing import TypeAlias, Literal, Annotated, Union, Type, Any from django.core import serializers from django.db import transaction from django.db.models import Model -from django.db.models.fields.related import OneToOneField, ForeignKey +from django.db.models.fields.related import OneToOneField, ForeignKey, ManyToManyField from django.utils import timezone from pydantic import ConfigDict, Discriminator from pydantic.fields import Field @@ -88,7 +88,6 @@ ChangeSetChange = Annotated[ class ChangeSetChanges(BaseSchema): prev_reprs: dict[ObjectReference, str] = {} prev_values: dict[ObjectReference, FieldValuesDict] = {} - prev_m2m: dict[ObjectReference, dict[str, list[int]]] = {} changes: list[ChangeSetChange] = [] @@ -167,8 +166,48 @@ class ChangesetOverlayManager: from pprint import pprint pprint(self.changes) - def handle_m2m_changed(self, sender: Type[Model], instance: Model, **kwargs): - pass + def handle_m2m_changed(self, sender: Type[Model], instance: Model, action: str, model: Type[Model], + pk_set: set | None, reverse: bool, **kwargs): + if reverse: + raise NotImplementedError + + if action.startswith("pre_"): + return self.handle_pre_change_instance(sender=instance._meta.model, instance=instance) + + for field in instance._meta.get_fields(): + if isinstance(field, ManyToManyField): + break + else: + raise ValueError + + ref = ObjectReference.from_instance(instance) + pre_change_values = self.pre_change_values.pop(ref, None) + if pre_change_values: + self.changes.prev_values[ref] = pre_change_values + + match(action): + case "post_add": + self.changes.changes.append(AddManyToManyChange( + obj=ref, + field=field.name, + values=list(pk_set), + )) + + case "post_remove": + self.changes.changes.append(RemoveManyToManyChange( + obj=ref, + field=field.name, + values=list(pk_set), + )) + + case "post_clear": + self.changes.changes.append(ClearManyToManyChange( + obj=ref, + field=field.name, + )) + + from pprint import pprint + pprint(self.changes) def handle_pre_change_instance(sender: Type[Model], **kwargs):