team-3/src/c3nav/editor/operations.py

258 lines
9.5 KiB
Python
Raw Normal View History

import copy
import datetime
import json
from dataclasses import dataclass
from typing import TypeAlias, Any, Annotated, Literal, Union
from uuid import UUID, uuid4
from django.apps import apps
from django.core import serializers
from django.db.models import Model
from django.utils import timezone
from pydantic.config import ConfigDict
from pydantic.fields import Field
from pydantic.types import Discriminator
from c3nav.api.schema import BaseSchema
from c3nav.mapdata.fields import I18nField
from c3nav.mapdata.models import LocationSlug
FieldValuesDict: TypeAlias = dict[str, Any]
class ObjectReference(BaseSchema):
model_config = ConfigDict(frozen=True)
model: str
id: int
@classmethod
def from_instance(cls, instance: Model):
return cls(model=instance._meta.model_name, id=instance.pk)
class BaseOperation(BaseSchema):
obj: ObjectReference
uuid: UUID = Field(default_factory=uuid4)
datetime: Annotated[datetime.datetime, Field(default_factory=timezone.now)]
def apply(self, values: FieldValuesDict, instance: Model) -> Model:
raise NotImplementedError
class CreateObjectOperation(BaseOperation):
type: Literal["create"] = "create"
fields: FieldValuesDict
def apply_create(self) -> Model:
model = apps.get_model('mapdata', self.obj.model)
data = []
if issubclass(model, LocationSlug):
data.append({
"model": f"mapdata.locationslug",
"pk": self.obj.id,
"fields": {
"slug": self.fields.get("slug", None)
},
})
values = {key: val for key, val in self.fields.items() if key != "slug"}
else:
values = self.fields
data.append({
"model": f"mapdata.{self.obj.model}",
"pk": self.obj.id,
"fields": values,
})
instances = list(serializers.deserialize("json", json.dumps(data)))
for instance in instances:
instance.save(save_m2m=False)
return instances[-1].object
class UpdateObjectOperation(BaseOperation):
type: Literal["update"] = "update"
fields: FieldValuesDict
def apply(self, values: FieldValuesDict, instance: Model) -> Model:
model = apps.get_model('mapdata', self.obj.model)
for field_name, value in self.fields.items():
field = model._meta.get_field(field_name)
if isinstance(field, I18nField) and field_name in self.fields:
2024-08-26 19:50:03 +02:00
values[field_name] = {lang: val for lang, val in {**values[field_name], **value}.items()
if val is not None}
else:
values[field_name] = value
data = []
if issubclass(model, LocationSlug) and "slug" in values:
data.append({
"model": f"mapdata.locationslug",
"pk": self.obj.id,
"fields": {
"slug": values["slug"],
},
})
values = {key: val for key, val in values.items() if key != "slug"}
data.append({
"model": f"mapdata.{self.obj.model}",
"pk": self.obj.id,
"fields": values,
})
instances = list(serializers.deserialize("json", json.dumps(data)))
for instance in instances:
instance.save(save_m2m=False)
return instances[-1].object
class DeleteObjectOperation(BaseOperation):
type: Literal["delete"] = "delete"
def apply(self, values: FieldValuesDict, instance: Model) -> Model:
instance.delete()
return instance
class UpdateManyToManyOperation(BaseOperation):
type: Literal["m2m_add"] = "m2m_update"
field: str
add_values: set[int] = set()
remove_values: set[int] = set()
def apply(self, values: FieldValuesDict, instance: Model) -> Model:
values[self.field] = sorted((set(values[self.field]) | self.add_values) - self.remove_values)
field_manager = getattr(instance, self.field)
field_manager.add(*self.add_values)
field_manager.remove(*self.remove_values)
return instance
class ClearManyToManyOperation(BaseOperation):
type: Literal["m2m_clear"] = "m2m_clear"
field: str
def apply(self, values: FieldValuesDict, instance: Model) -> Model:
values[self.field] = []
getattr(instance, self.field).clear()
return instance
class RevertOperation(BaseOperation):
type: Literal["revert"] = "revert"
reverts: UUID
DatabaseOperation = Annotated[
Union[
CreateObjectOperation,
UpdateObjectOperation,
DeleteObjectOperation,
UpdateManyToManyOperation,
ClearManyToManyOperation,
],
Discriminator("type"),
]
2024-08-26 14:58:37 +02:00
class ChangedManyToMany(BaseSchema):
cleared: bool = False
added: list[str] = []
removed: list[str] = []
class ChangedObject(BaseSchema):
obj: ObjectReference
2024-08-26 20:29:56 +02:00
titles: dict[str, str] | None
2024-08-26 14:58:37 +02:00
created: bool = False
deleted: bool = False
fields: FieldValuesDict = {}
m2m_changes: dict[str, ChangedManyToMany] = {}
class CollectedChanges(BaseSchema):
uuid: UUID = Field(default_factory=uuid4)
2024-08-26 20:29:56 +02:00
prev_titles: dict[str, dict[int, dict[str, str] | None]] = {}
prev_values: dict[str, dict[int, FieldValuesDict]] = {}
operations: list[DatabaseOperation] = []
def prefetch(self) -> "CollectedChangesPrefetch":
ids_to_query: dict[str, set[int]] = {model_name: set(val.keys())
for model_name, val in self.prev_values.items()}
instances: dict[ObjectReference, Model] = {}
for model_name, ids in ids_to_query.items():
model = apps.get_model("mapdata", model_name)
instances.update(dict((ObjectReference(model=model_name, id=instance.pk), instance)
for instance in model.objects.filter(pk__in=ids)))
return CollectedChangesPrefetch(changes=self, instances=instances)
2024-08-26 14:58:37 +02:00
@property
def changed_objects(self) -> list[ChangedObject]:
objects = {}
reverted_uuids = frozenset(operation.reverts for operation in self.operationsy
if isinstance(operation, RevertOperation))
2024-08-26 14:58:37 +02:00
for operation in self.operations:
if operation.uuid in reverted_uuids:
continue
2024-08-26 14:58:37 +02:00
changed_object = objects.get(operation.obj, None)
if changed_object is None:
changed_object = ChangedObject(obj=operation.obj,
titles=self.prev_titles[operation.obj.model][operation.obj.id])
2024-08-26 14:58:37 +02:00
objects[operation.obj] = changed_object
if isinstance(operation, CreateObjectOperation):
changed_object.created = True
changed_object.fields.update(operation.fields)
elif isinstance(operation, UpdateObjectOperation):
model = apps.get_model('mapdata', operation.obj.model)
for field_name, value in operation.fields.items():
field = model._meta.get_field(field_name)
if isinstance(field, I18nField) and field_name in changed_object.fields:
2024-08-26 19:50:03 +02:00
changed_object.fields[field_name] = {
lang: val for lang, val in {**changed_object.fields[field_name], **value}.items()
}
else:
changed_object.fields[field_name] = value
2024-08-26 14:58:37 +02:00
elif isinstance(operation, DeleteObjectOperation):
changed_object.deleted = False
else:
changed_m2m = changed_object.m2m_changes.get(operation.field, None)
if changed_m2m is None:
changed_m2m = ChangedManyToMany()
changed_object.m2m_changes[operation.field] = changed_m2m
if isinstance(operation, ClearManyToManyOperation):
changed_m2m.cleared = True
changed_m2m.added = []
changed_m2m.removed = []
else:
changed_m2m.added = sorted((set(changed_m2m.added) | operation.add_values)
- operation.remove_values)
changed_m2m.removed = sorted((set(changed_m2m.removed) - operation.add_values)
| operation.remove_values)
return list(objects.values())
@dataclass
class CollectedChangesPrefetch:
changes: CollectedChanges
instances: dict[ObjectReference, Model]
def apply(self):
2024-08-26 15:01:40 +02:00
# todo: what if unique constraint error occurs?
prev_values = copy.deepcopy(self.changes.prev_values)
for operation in self.changes.operations:
if isinstance(operation, CreateObjectOperation):
self.instances[operation.obj] = operation.apply_create()
else:
in_prev_values = operation.obj.id in prev_values.get(operation.obj.model, {})
if not in_prev_values:
print('WARN WARN WARN')
values = prev_values.setdefault(operation.obj.model, {}).setdefault(operation.obj.id, {})
try:
instance = self.instances[operation.obj]
except KeyError:
if not in_prev_values:
instance = apps.get_model("mapdata", operation.obj.model).filter(pk=operation.obj.id).first()
else:
instance = None
if instance is not None:
operation.apply(values=values, instance=instance)