mid-work commit from train

This commit is contained in:
Laura Klünder 2024-09-26 13:19:29 +02:00
parent 55b8e6e78c
commit 174866c2fd
8 changed files with 174 additions and 127 deletions

View file

@ -1,15 +1,14 @@
import copy
import datetime
import json
from dataclasses import dataclass
from typing import TypeAlias, Any, Annotated, Literal, Union
from typing import Annotated, Literal, Union, TypeAlias, Any
from uuid import UUID, uuid4
from django.apps import apps
from django.core import serializers
from django.db.models import Model
from django.utils import timezone
from pydantic.config import ConfigDict
from pydantic import ConfigDict
from pydantic.fields import Field
from pydantic.types import Discriminator
@ -17,10 +16,14 @@ from c3nav.api.schema import BaseSchema
from c3nav.mapdata.fields import I18nField
from c3nav.mapdata.models import LocationSlug
FieldValuesDict: TypeAlias = dict[str, Any]
class ObjectReference(BaseSchema):
"""
Reference to an object based on model name and ID.
"""
model_config = ConfigDict(frozen=True)
model: str
id: int
@ -31,25 +34,42 @@ class ObjectReference(BaseSchema):
class PreviousObject(BaseSchema):
"""
Represents the previous state of an objects, consisting of its values and its (multi-language) titles
"""
titles: dict[str, str] | None
values: FieldValuesDict
class PreviousObjects(BaseSchema):
class PreviousObjectCollection(BaseSchema):
objects: dict[str, dict[int, PreviousObject]] = {}
def get(self, ref: ObjectReference) -> PreviousObject | None:
return self.objects.get(ref.model, {}).get(ref.id, None)
def get_ids(self) -> dict[str, set[int]]:
"""
:return: all referenced IDs sorted by model
"""
return {model: set(objs.keys()) for model, objs in self.objects.items()}
def get_instances(self) -> dict[ObjectReference, Model]:
"""
:return: all reference objects as fetched from the databse right now
"""
instances: dict[ObjectReference, Model] = {}
for model_name, ids in self.get_ids().items():
model = apps.get_model("mapdata", model_name)
instances.update(dict((ObjectReference(model=model_name, id=instance.pk), instance)
for instance in model.objects.filter(pk__in=ids)))
return instances
def set(self, ref: ObjectReference, values: FieldValuesDict, titles: dict | None):
self.objects.setdefault(ref.model, {})[ref.id] = PreviousObject(
values=values,
titles=titles,
)
def get_ids(self) -> dict[str, set[int]]:
return {model: set(objs.keys()) for model, objs in self.objects.items()}
class BaseOperation(BaseSchema):
obj: ObjectReference
@ -155,11 +175,6 @@ class ClearManyToManyOperation(BaseOperation):
return instance
class RevertOperation(BaseOperation):
type: Literal["revert"] = "revert"
reverts: UUID
DatabaseOperation = Annotated[
Union[
CreateObjectOperation,
@ -172,94 +187,34 @@ DatabaseOperation = Annotated[
]
class ChangedManyToMany(BaseSchema):
cleared: bool = False
added: list[str] = []
removed: list[str] = []
class ChangedObject(BaseSchema):
obj: ObjectReference
titles: dict[str, str] | None
created: bool = False
deleted: bool = False
fields: FieldValuesDict = {}
m2m_changes: dict[str, ChangedManyToMany] = {}
class CollectedOperations(BaseSchema):
uuid: UUID = Field(default_factory=uuid4)
prev: PreviousObjects = PreviousObjects()
class DatabaseOperationCollection(BaseSchema):
"""
A collection of database operations, sorted by model and id.
Also stores a PreviousObjectCollection for comparison with the current state.
Iterable as a list of DatabaseOperation instances.
"""
prev: PreviousObjectCollection = PreviousObjectCollection()
operations: list[DatabaseOperation] = []
def prefetch(self) -> "CollectedChangesPrefetch":
ids_to_query: dict[str, set[int]] = self.prev.get_ids()
def __iter__(self):
yield from self.operations
instances: dict[ObjectReference, Model] = {}
for model_name, ids in ids_to_query.items():
model = apps.get_model("mapdata", model_name)
instances.update(dict((ObjectReference(model=model_name, id=instance.pk), instance)
for instance in model.objects.filter(pk__in=ids)))
return CollectedChangesPrefetch(changes=self, instances=instances)
@property
def changed_objects(self) -> list[ChangedObject]:
objects = {}
reverted_uuids = frozenset(operation.reverts for operation in self.operationsy
if isinstance(operation, RevertOperation))
for operation in self.operations:
if operation.uuid in reverted_uuids:
continue
changed_object = objects.get(operation.obj, None)
if changed_object is None:
changed_object = ChangedObject(obj=operation.obj,
titles=self.prev_titles[operation.obj.model][operation.obj.id])
objects[operation.obj] = changed_object
if isinstance(operation, CreateObjectOperation):
changed_object.created = True
changed_object.fields.update(operation.fields)
elif isinstance(operation, UpdateObjectOperation):
model = apps.get_model('mapdata', operation.obj.model)
for field_name, value in operation.fields.items():
field = model._meta.get_field(field_name)
if isinstance(field, I18nField) and field_name in changed_object.fields:
changed_object.fields[field_name] = {
lang: val for lang, val in {**changed_object.fields[field_name], **value}.items()
}
else:
changed_object.fields[field_name] = value
elif isinstance(operation, DeleteObjectOperation):
changed_object.deleted = False
else:
changed_m2m = changed_object.m2m_changes.get(operation.field, None)
if changed_m2m is None:
changed_m2m = ChangedManyToMany()
changed_object.m2m_changes[operation.field] = changed_m2m
if isinstance(operation, ClearManyToManyOperation):
changed_m2m.cleared = True
changed_m2m.added = []
changed_m2m.removed = []
else:
changed_m2m.added = sorted((set(changed_m2m.added) | operation.add_values)
- operation.remove_values)
changed_m2m.removed = sorted((set(changed_m2m.removed) - operation.add_values)
| operation.remove_values)
return list(objects.values())
def prefetch(self) -> "PrefetchedDatabaseOperationCollection":
return PrefetchedDatabaseOperationCollection(operations=self, instances=self.prev.get_instances())
@dataclass
class CollectedChangesPrefetch:
changes: CollectedOperations
class PrefetchedDatabaseOperationCollection:
operations: DatabaseOperationCollection
instances: dict[ObjectReference, Model]
def apply(self):
# todo: what if unique constraint error occurs?
for operation in self.changes.operations:
for operation in self.operations.operations:
if isinstance(operation, CreateObjectOperation):
self.instances[operation.obj] = operation.apply_create()
else:
prev_obj = self.changes.prev.get(operation.obj)
prev_obj = self.operations.prev.get(operation.obj)
if prev_obj is None:
print('WARN WARN WARN')
values = prev_obj.values
@ -271,5 +226,4 @@ class CollectedChangesPrefetch:
else:
instance = None
if instance is not None:
operation.apply(values=values, instance=instance)
operation.apply(values=values, instance=instance)