team-3/src/c3nav/editor/operations.py
2024-12-05 19:04:29 +01:00

266 lines
No EOL
9 KiB
Python

import json
from dataclasses import dataclass
from typing import Annotated, Literal, Union, TypeAlias, Any, Self, Iterator
from django.apps import apps
from django.core import serializers
from django.db.models import Model
from pydantic import ConfigDict
from pydantic.types import Discriminator
from c3nav.api.schema import BaseSchema
from c3nav.mapdata.fields import I18nField
from c3nav.mapdata.models import LocationSlug
# Name of a model, as stored in ObjectReference.model and passed to apps.get_model("mapdata", ...).
ModelName: TypeAlias = str
# Identifier of a single object; ObjectReference.from_instance fills it from instance.pk.
ObjectID: TypeAlias = int
# Name of a field on a model.
FieldName: TypeAlias = str
# Field values of one object, keyed by field name.
FieldValuesDict: TypeAlias = dict[FieldName, Any]
class ObjectReference(BaseSchema):
    """
    Identifies a single object by the name of its model and its ID.

    The schema is configured as frozen; references are used as dictionary keys elsewhere in this module.
    """
    model_config = ConfigDict(frozen=True)

    model: ModelName
    id: ObjectID

    @classmethod
    def from_instance(cls, instance: Model):
        """
        :param instance: the model instance to point at
        :return: a reference built from the instance's model name and primary key
        """
        meta = instance._meta
        return cls(model=meta.model_name, id=instance.pk)
class PreviousObject(BaseSchema):
    """
    Snapshot of the previous state of an object: its field values and,
    where available, its (multi-language) titles.
    """
    # titles may be None when no titles were recorded
    titles: dict[str, str] | None
    values: FieldValuesDict
class PreviousObjectCollection(BaseSchema):
    """
    Previous states of objects, grouped first by model name and then by object ID.
    """
    objects: dict[ModelName, dict[ObjectID, PreviousObject]] = {}

    def get(self, ref: ObjectReference) -> PreviousObject | None:
        """
        :return: the stored previous state of the referenced object, or None if nothing is stored for it
        """
        per_model = self.objects.get(ref.model, {})
        return per_model.get(ref.id, None)

    def get_ids(self) -> dict[ModelName, set[ObjectID]]:
        """
        :return: all referenced IDs grouped by model
        """
        ids_by_model = {}
        for model_name, previous_by_id in self.objects.items():
            ids_by_model[model_name] = set(previous_by_id.keys())
        return ids_by_model

    def get_instances(self) -> dict[ObjectReference, Model]:
        """
        :return: all referenced objects as fetched from the database right now
        """
        found: dict[ObjectReference, Model] = {}
        for model_name, ids in self.get_ids().items():
            model_cls = apps.get_model("mapdata", model_name)
            # collect per model first, then merge, so one model's result is added as a whole
            fetched = {}
            for instance in model_cls.objects.filter(pk__in=ids):
                fetched[ObjectReference(model=model_name, id=instance.pk)] = instance
            found.update(fetched)
        return found

    def set(self, ref: ObjectReference, values: FieldValuesDict, titles: dict | None):
        """
        Store the previous state of the referenced object, replacing any state already stored for it.
        """
        # build the entry first so that nothing is touched if it fails to validate
        entry = PreviousObject(values=values, titles=titles)
        self.objects.setdefault(ref.model, {})[ref.id] = entry

    def add_other(self, other: Self):
        """
        Merge another collection into this one.

        For objects present in both collections the entry already stored here is kept.
        """
        for key in set(self.objects.keys()) | set(other.objects.keys()):
            merged = dict(other.objects.get(key, {}))
            # entries of this collection are applied last, so they win over those of `other`
            merged.update(self.objects.get(key, {}))
            self.objects[key] = merged
class BaseOperation(BaseSchema):
    """
    Common base of all operations that target one object, identified by `obj`.
    """
    obj: ObjectReference

    def apply(self, values: FieldValuesDict, instance: Model) -> Model:
        """
        Apply this operation; has to be implemented by subclasses.

        :param values: field values of the object, as known before this operation
        :param instance: the model instance to operate on
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError()
class CreateObjectOperation(BaseOperation):
    """
    Creates the object referenced by `obj` with the given field values.
    """
    type: Literal["create"] = "create"
    fields: FieldValuesDict

    def get_data(self):
        """
        Build the input for Django's deserializer.

        :return: a list of serialized-object dicts; for LocationSlug subclasses the slug is moved into a
                 separate ``mapdata.locationslug`` entry that comes before the entry for the object itself
        """
        model = apps.get_model('mapdata', self.obj.model)

        if not issubclass(model, LocationSlug):
            return [{
                "model": f"mapdata.{self.obj.model}",
                "pk": self.obj.id,
                "fields": self.fields,
            }]

        slug_entry = {
            "model": "mapdata.locationslug",
            "pk": self.obj.id,
            "fields": {
                "slug": self.fields.get("slug"),
            },
        }
        object_entry = {
            "model": f"mapdata.{self.obj.model}",
            "pk": self.obj.id,
            "fields": {name: value for name, value in self.fields.items() if name != "slug"},
        }
        return [slug_entry, object_entry]

    def apply_create(self) -> dict[ObjectReference, Model]:
        """
        Deserialize and save the object.

        :return: mapping of this operation's reference to the saved instance
        """
        payload = json.dumps(self.get_data())
        # deserialize everything before saving anything
        deserialized = list(serializers.deserialize("json", payload))
        for entry in deserialized:
            # .object. to make sure our own .save() function is called!
            entry.object.save()
        # the entry for the object itself is always the last one
        return {self.obj: deserialized[-1].object}
class CreateMultipleObjectsOperation(BaseSchema):
    """
    Creates several objects at once, described by a list of CreateObjectOperation.
    """
    type: Literal["create_multiple"] = "create_multiple"
    objects: list[CreateObjectOperation] = []

    def apply_create(self) -> dict[ObjectReference, Model]:
        """
        Deserialize and save all objects.

        :return: mapping of each created object's reference to its saved model instance
        """
        # for every reference, the index of its own entry in `data`
        # (an object may contribute more than one entry; its own one is always the last)
        indexes: dict[ObjectReference, int] = {}
        data = []
        for obj in self.objects:
            data.extend(obj.get_data())
            indexes[obj.obj] = len(data) - 1

        instances = list(serializers.deserialize("json", json.dumps(data)))
        # todo: actually do a create_multiple!, let's not forget about register_changed_geometries etc
        for instance in instances:
            # .object. to make sure our own .save() function is called!
            instance.object.save()

        # unwrap the deserializer wrapper so callers get model instances,
        # the same as CreateObjectOperation.apply_create returns
        return {ref: instances[i].object for ref, i in indexes.items()}
# todo: delete multiple objects
class UpdateObjectOperation(BaseOperation):
    """
    Updates fields of the object referenced by `obj`.
    """
    type: Literal["update"] = "update"
    fields: FieldValuesDict

    def apply(self, values: FieldValuesDict, instance: Model) -> Model:
        """
        Merge the updated fields into `values`, then deserialize and save the resulting object.

        :param values: field values of the object before this operation; updated in place
        :param instance: not used here, the object is rebuilt from `values`
        :return: the saved model instance
        """
        model = apps.get_model('mapdata', self.obj.model)
        for field_name, value in self.fields.items():
            field = model._meta.get_field(field_name)
            if isinstance(field, I18nField):
                # update per language on top of the previous translations (if any);
                # a language set to None is removed
                merged = {**(values.get(field_name) or {}), **value}
                values[field_name] = {lang: val for lang, val in merged.items() if val is not None}
            else:
                values[field_name] = value

        data = []
        if issubclass(model, LocationSlug) and "slug" in values:
            # the slug goes into a separate locationslug entry that comes before the object itself
            data.append({
                "model": "mapdata.locationslug",
                "pk": self.obj.id,
                "fields": {
                    "slug": values["slug"],
                },
            })
            # only rebinds the local name, the caller's dict keeps its "slug" entry
            values = {key: val for key, val in values.items() if key != "slug"}
        data.append({
            "model": f"mapdata.{self.obj.model}",
            "pk": self.obj.id,
            "fields": values,
        })

        instances = list(serializers.deserialize("json", json.dumps(data)))
        for deserialized in instances:
            deserialized.object.save()
        return instances[-1].object
class DeleteObjectOperation(BaseOperation):
    """
    Deletes the object referenced by `obj`.
    """
    type: Literal["delete"] = "delete"

    def apply(self, values: FieldValuesDict, instance: Model) -> Model:
        """
        Delete the given instance; `values` is not used.

        :return: the instance that was passed in
        """
        target = instance
        target.delete()
        return target
class UpdateManyToManyOperation(BaseOperation):
    """
    Adds values to and/or removes values from a many-to-many field of the object referenced by `obj`.
    """
    # "m2m_update" is the tag this operation is created with by default, so it has to be an allowed value,
    # otherwise data produced by this class can not be parsed back through the "type" discriminator.
    # "m2m_add" is still accepted so that data using that tag keeps working.
    type: Literal["m2m_update", "m2m_add"] = "m2m_update"
    field: FieldName
    add_values: set[ObjectID] = set()
    remove_values: set[ObjectID] = set()

    def apply(self, values: FieldValuesDict, instance: Model) -> Model:
        """
        Update both the tracked values and the database relation.

        :param values: field values of the object before this operation; the entry for `field` is replaced
                       by the sorted list of IDs after adding and removing
        :param instance: the model instance whose relation is changed
        :return: the instance that was passed in
        """
        values[self.field] = sorted((set(values[self.field]) | self.add_values) - self.remove_values)
        field_manager = getattr(instance, self.field)
        field_manager.add(*self.add_values)
        field_manager.remove(*self.remove_values)
        return instance
class ClearManyToManyOperation(BaseOperation):
    """
    Removes all values from a many-to-many field of the object referenced by `obj`.
    """
    type: Literal["m2m_clear"] = "m2m_clear"
    field: FieldName

    def apply(self, values: FieldValuesDict, instance: Model) -> Model:
        """
        Empty both the tracked values for `field` and the database relation.

        :return: the instance that was passed in
        """
        values[self.field] = []
        related_manager = getattr(instance, self.field)
        related_manager.clear()
        return instance
# Union of all operation types, discriminated by the value of their "type" field.
DatabaseOperation = Annotated[
    Union[
        CreateObjectOperation,
        CreateMultipleObjectsOperation,
        UpdateObjectOperation,
        DeleteObjectOperation,
        UpdateManyToManyOperation,
        ClearManyToManyOperation,
    ],
    Discriminator("type"),
]
class DatabaseOperationCollection(BaseSchema):
    """
    An ordered list of database operations together with a PreviousObjectCollection
    that allows comparing against the current state.

    Iterating over the collection yields its DatabaseOperation instances.
    """
    prev: PreviousObjectCollection = PreviousObjectCollection()
    operations: list[DatabaseOperation] = []

    def __iter__(self) -> Iterator[DatabaseOperation]:
        for operation in self.operations:
            yield operation

    def __len__(self):
        operations = self.operations
        return len(operations)

    def extend(self, items: list[DatabaseOperation]):
        """
        Append all given operations, in order.
        """
        operations = self.operations
        operations.extend(items)

    def append(self, item: DatabaseOperation):
        """
        Append a single operation.
        """
        operations = self.operations
        operations.append(item)

    def prefetch(self) -> "PrefetchedDatabaseOperationCollection":
        """
        :return: this collection bundled with the instances of all objects in `prev`, fetched now
        """
        fetched = self.prev.get_instances()
        return PrefetchedDatabaseOperationCollection(operations=self, instances=fetched)
@dataclass
class PrefetchedDatabaseOperationCollection:
    """
    A DatabaseOperationCollection together with the model instances its operations work on.
    """
    operations: DatabaseOperationCollection
    # instances by reference; extended with newly created objects while applying
    instances: dict[ObjectReference, Model]

    def apply(self):
        """
        Apply all operations in order.

        Create operations add their new instances to `instances`. Every other operation is applied to the
        instance of its object, and skipped if no such instance can be found.
        """
        # todo: what if unique constraint error occurs?
        for operation in self.operations:
            if isinstance(operation, (CreateObjectOperation, CreateMultipleObjectsOperation)):
                self.instances.update(operation.apply_create())
                continue

            prev_obj = self.operations.prev.get(operation.obj)
            if prev_obj is None:
                print('WARN WARN WARN')
                # No previous state recorded: go on with empty values instead of failing on None.
                # NOTE(review): operations that rebuild the object from `values` then only see the fields
                # they set themselves - confirm that this is acceptable for such objects.
                values = {}
            else:
                values = prev_obj.values

            try:
                instance = self.instances[operation.obj]
            except KeyError:
                if prev_obj is None:
                    # not prefetched because there is no previous state, so look it up now
                    model = apps.get_model("mapdata", operation.obj.model)
                    instance = model.objects.filter(pk=operation.obj.id).first()
                else:
                    # there is a previous state, but the object was not found when prefetching
                    instance = None

            if instance is not None:
                operation.apply(values=values, instance=instance)