mid-work commit from train

This commit is contained in:
Laura Klünder 2024-09-26 13:19:29 +02:00
parent 55b8e6e78c
commit 174866c2fd
8 changed files with 174 additions and 127 deletions

View file

@ -0,0 +1,82 @@
from itertools import chain
from django.apps import apps
from c3nav.api.schema import BaseSchema
from c3nav.editor.operations import DatabaseOperationCollection, CreateObjectOperation, UpdateObjectOperation, \
DeleteObjectOperation, ClearManyToManyOperation, FieldValuesDict, ObjectReference, PreviousObjectCollection
from c3nav.mapdata.fields import I18nField
class ChangedManyToMany(BaseSchema):
cleared: bool = False
added: list[str] = []
removed: list[str] = []
class ChangedObject(BaseSchema):
obj: ObjectReference
titles: dict[str, str] | None
created: bool = False
deleted: bool = False
fields: FieldValuesDict = {}
m2m_changes: dict[str, ChangedManyToMany] = {}
class ChangedObjectCollection(BaseSchema):
"""
A collection of ChangedObject instances, sorted by model and id.
Also stores a PreviousObjectCollection for comparison with the current state.
Iterable as a list of ChangedObject instances.
"""
prev: PreviousObjectCollection = PreviousObjectCollection()
objects: dict[str, dict[int, ChangedObject]] = {}
def __iter__(self):
yield from chain(*(objects.keys() for model, objects in self.objects.items()))
def add_operations(self, operations: DatabaseOperationCollection):
"""
Add the given operations, creating/updating changed objects to represent the resulting state.
"""
# todo: merge prev
for operation in operations.operations:
changed_object = self.objects.setdefault(operation.obj.model, {}).get(operation.obj.id, None)
if changed_object is None:
changed_object = ChangedObject(obj=operation.obj,
titles=self.prev.get(operation.obj).titles)
self.objects[operation.obj.model][operation.obj.id] = changed_object
if isinstance(operation, CreateObjectOperation):
changed_object.created = True
changed_object.fields.update(operation.fields)
elif isinstance(operation, UpdateObjectOperation):
model = apps.get_model('mapdata', operation.obj.model)
for field_name, value in operation.fields.items():
field = model._meta.get_field(field_name)
if isinstance(field, I18nField) and field_name in changed_object.fields:
changed_object.fields[field_name] = {
lang: val for lang, val in {**changed_object.fields[field_name], **value}.items()
}
else:
changed_object.fields[field_name] = value
elif isinstance(operation, DeleteObjectOperation):
changed_object.deleted = False
else:
changed_m2m = changed_object.m2m_changes.get(operation.field, None)
if changed_m2m is None:
changed_m2m = ChangedManyToMany()
changed_object.m2m_changes[operation.field] = changed_m2m
if isinstance(operation, ClearManyToManyOperation):
changed_m2m.cleared = True
changed_m2m.added = []
changed_m2m.removed = []
else:
changed_m2m.added = sorted((set(changed_m2m.added) | operation.add_values)
- operation.remove_values)
changed_m2m.removed = sorted((set(changed_m2m.removed) - operation.add_values)
| operation.remove_values)
@property
def as_operations(self) -> DatabaseOperationCollection:
pass # todo: implement

View file

@ -1,6 +1,6 @@
# Generated by Django 5.0.8 on 2024-08-26 09:46 # Generated by Django 5.0.8 on 2024-09-26 10:06
import c3nav.editor.operations import c3nav.editor.changes
import django.core.serializers.json import django.core.serializers.json
import django_pydantic_field.fields import django_pydantic_field.fields
from django.db import migrations from django.db import migrations
@ -20,7 +20,7 @@ class Migration(migrations.Migration):
migrations.AddField( migrations.AddField(
model_name='changeset', model_name='changeset',
name='changes', name='changes',
field=django_pydantic_field.fields.PydanticSchemaField(config=None, default=c3nav.editor.operations.CollectedOperations, encoder=django.core.serializers.json.DjangoJSONEncoder, schema=c3nav.editor.operations.CollectedOperations), field=django_pydantic_field.fields.PydanticSchemaField(config=None, default=c3nav.editor.changes.ChangedObjectCollection, encoder=django.core.serializers.json.DjangoJSONEncoder, schema=c3nav.editor.changes.ChangedObjectCollection),
), ),
migrations.DeleteModel( migrations.DeleteModel(
name='ChangedObject', name='ChangedObject',

View file

@ -12,7 +12,7 @@ from django.utils.translation import gettext_lazy as _
from django.utils.translation import ngettext_lazy from django.utils.translation import ngettext_lazy
from django_pydantic_field import SchemaField from django_pydantic_field import SchemaField
from c3nav.editor.operations import CollectedOperations from c3nav.editor.changes import ChangedObjectCollection
from c3nav.editor.tasks import send_changeset_proposed_notification from c3nav.editor.tasks import send_changeset_proposed_notification
from c3nav.mapdata.models import LocationSlug, MapUpdate from c3nav.mapdata.models import LocationSlug, MapUpdate
from c3nav.mapdata.models.locations import LocationRedirect from c3nav.mapdata.models.locations import LocationRedirect
@ -43,7 +43,7 @@ class ChangeSet(models.Model):
related_name='assigned_changesets', verbose_name=_('assigned to')) related_name='assigned_changesets', verbose_name=_('assigned to'))
map_update = models.OneToOneField(MapUpdate, null=True, related_name='changeset', map_update = models.OneToOneField(MapUpdate, null=True, related_name='changeset',
verbose_name=_('map update'), on_delete=models.PROTECT) verbose_name=_('map update'), on_delete=models.PROTECT)
changes: CollectedOperations = SchemaField(schema=CollectedOperations, default=CollectedOperations) changes: ChangedObjectCollection = SchemaField(schema=ChangedObjectCollection, default=ChangedObjectCollection)
class Meta: class Meta:
verbose_name = _('Change Set') verbose_name = _('Change Set')

View file

@ -1,15 +1,14 @@
import copy
import datetime import datetime
import json import json
from dataclasses import dataclass from dataclasses import dataclass
from typing import TypeAlias, Any, Annotated, Literal, Union from typing import Annotated, Literal, Union, TypeAlias, Any
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from django.apps import apps from django.apps import apps
from django.core import serializers from django.core import serializers
from django.db.models import Model from django.db.models import Model
from django.utils import timezone from django.utils import timezone
from pydantic.config import ConfigDict from pydantic import ConfigDict
from pydantic.fields import Field from pydantic.fields import Field
from pydantic.types import Discriminator from pydantic.types import Discriminator
@ -17,10 +16,14 @@ from c3nav.api.schema import BaseSchema
from c3nav.mapdata.fields import I18nField from c3nav.mapdata.fields import I18nField
from c3nav.mapdata.models import LocationSlug from c3nav.mapdata.models import LocationSlug
FieldValuesDict: TypeAlias = dict[str, Any] FieldValuesDict: TypeAlias = dict[str, Any]
class ObjectReference(BaseSchema): class ObjectReference(BaseSchema):
"""
Reference to an object based on model name and ID.
"""
model_config = ConfigDict(frozen=True) model_config = ConfigDict(frozen=True)
model: str model: str
id: int id: int
@ -31,25 +34,42 @@ class ObjectReference(BaseSchema):
class PreviousObject(BaseSchema): class PreviousObject(BaseSchema):
"""
Represents the previous state of an objects, consisting of its values and its (multi-language) titles
"""
titles: dict[str, str] | None titles: dict[str, str] | None
values: FieldValuesDict values: FieldValuesDict
class PreviousObjects(BaseSchema): class PreviousObjectCollection(BaseSchema):
objects: dict[str, dict[int, PreviousObject]] = {} objects: dict[str, dict[int, PreviousObject]] = {}
def get(self, ref: ObjectReference) -> PreviousObject | None: def get(self, ref: ObjectReference) -> PreviousObject | None:
return self.objects.get(ref.model, {}).get(ref.id, None) return self.objects.get(ref.model, {}).get(ref.id, None)
def get_ids(self) -> dict[str, set[int]]:
"""
:return: all referenced IDs sorted by model
"""
return {model: set(objs.keys()) for model, objs in self.objects.items()}
def get_instances(self) -> dict[ObjectReference, Model]:
"""
:return: all reference objects as fetched from the databse right now
"""
instances: dict[ObjectReference, Model] = {}
for model_name, ids in self.get_ids().items():
model = apps.get_model("mapdata", model_name)
instances.update(dict((ObjectReference(model=model_name, id=instance.pk), instance)
for instance in model.objects.filter(pk__in=ids)))
return instances
def set(self, ref: ObjectReference, values: FieldValuesDict, titles: dict | None): def set(self, ref: ObjectReference, values: FieldValuesDict, titles: dict | None):
self.objects.setdefault(ref.model, {})[ref.id] = PreviousObject( self.objects.setdefault(ref.model, {})[ref.id] = PreviousObject(
values=values, values=values,
titles=titles, titles=titles,
) )
def get_ids(self) -> dict[str, set[int]]:
return {model: set(objs.keys()) for model, objs in self.objects.items()}
class BaseOperation(BaseSchema): class BaseOperation(BaseSchema):
obj: ObjectReference obj: ObjectReference
@ -155,11 +175,6 @@ class ClearManyToManyOperation(BaseOperation):
return instance return instance
class RevertOperation(BaseOperation):
type: Literal["revert"] = "revert"
reverts: UUID
DatabaseOperation = Annotated[ DatabaseOperation = Annotated[
Union[ Union[
CreateObjectOperation, CreateObjectOperation,
@ -172,94 +187,34 @@ DatabaseOperation = Annotated[
] ]
class ChangedManyToMany(BaseSchema): class DatabaseOperationCollection(BaseSchema):
cleared: bool = False """
added: list[str] = [] A collection of database operations, sorted by model and id.
removed: list[str] = [] Also stores a PreviousObjectCollection for comparison with the current state.
Iterable as a list of DatabaseOperation instances.
"""
class ChangedObject(BaseSchema): prev: PreviousObjectCollection = PreviousObjectCollection()
obj: ObjectReference
titles: dict[str, str] | None
created: bool = False
deleted: bool = False
fields: FieldValuesDict = {}
m2m_changes: dict[str, ChangedManyToMany] = {}
class CollectedOperations(BaseSchema):
uuid: UUID = Field(default_factory=uuid4)
prev: PreviousObjects = PreviousObjects()
operations: list[DatabaseOperation] = [] operations: list[DatabaseOperation] = []
def prefetch(self) -> "CollectedChangesPrefetch": def __iter__(self):
ids_to_query: dict[str, set[int]] = self.prev.get_ids() yield from self.operations
instances: dict[ObjectReference, Model] = {} def prefetch(self) -> "PrefetchedDatabaseOperationCollection":
for model_name, ids in ids_to_query.items(): return PrefetchedDatabaseOperationCollection(operations=self, instances=self.prev.get_instances())
model = apps.get_model("mapdata", model_name)
instances.update(dict((ObjectReference(model=model_name, id=instance.pk), instance)
for instance in model.objects.filter(pk__in=ids)))
return CollectedChangesPrefetch(changes=self, instances=instances)
@property
def changed_objects(self) -> list[ChangedObject]:
objects = {}
reverted_uuids = frozenset(operation.reverts for operation in self.operationsy
if isinstance(operation, RevertOperation))
for operation in self.operations:
if operation.uuid in reverted_uuids:
continue
changed_object = objects.get(operation.obj, None)
if changed_object is None:
changed_object = ChangedObject(obj=operation.obj,
titles=self.prev_titles[operation.obj.model][operation.obj.id])
objects[operation.obj] = changed_object
if isinstance(operation, CreateObjectOperation):
changed_object.created = True
changed_object.fields.update(operation.fields)
elif isinstance(operation, UpdateObjectOperation):
model = apps.get_model('mapdata', operation.obj.model)
for field_name, value in operation.fields.items():
field = model._meta.get_field(field_name)
if isinstance(field, I18nField) and field_name in changed_object.fields:
changed_object.fields[field_name] = {
lang: val for lang, val in {**changed_object.fields[field_name], **value}.items()
}
else:
changed_object.fields[field_name] = value
elif isinstance(operation, DeleteObjectOperation):
changed_object.deleted = False
else:
changed_m2m = changed_object.m2m_changes.get(operation.field, None)
if changed_m2m is None:
changed_m2m = ChangedManyToMany()
changed_object.m2m_changes[operation.field] = changed_m2m
if isinstance(operation, ClearManyToManyOperation):
changed_m2m.cleared = True
changed_m2m.added = []
changed_m2m.removed = []
else:
changed_m2m.added = sorted((set(changed_m2m.added) | operation.add_values)
- operation.remove_values)
changed_m2m.removed = sorted((set(changed_m2m.removed) - operation.add_values)
| operation.remove_values)
return list(objects.values())
@dataclass @dataclass
class CollectedChangesPrefetch: class PrefetchedDatabaseOperationCollection:
changes: CollectedOperations operations: DatabaseOperationCollection
instances: dict[ObjectReference, Model] instances: dict[ObjectReference, Model]
def apply(self): def apply(self):
# todo: what if unique constraint error occurs? # todo: what if unique constraint error occurs?
for operation in self.changes.operations: for operation in self.operations.operations:
if isinstance(operation, CreateObjectOperation): if isinstance(operation, CreateObjectOperation):
self.instances[operation.obj] = operation.apply_create() self.instances[operation.obj] = operation.apply_create()
else: else:
prev_obj = self.changes.prev.get(operation.obj) prev_obj = self.operations.prev.get(operation.obj)
if prev_obj is None: if prev_obj is None:
print('WARN WARN WARN') print('WARN WARN WARN')
values = prev_obj.values values = prev_obj.values
@ -271,5 +226,4 @@ class CollectedChangesPrefetch:
else: else:
instance = None instance = None
if instance is not None: if instance is not None:
operation.apply(values=values, instance=instance) operation.apply(values=values, instance=instance)

View file

@ -1,3 +1,4 @@
import copy
import json import json
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass, field from dataclasses import dataclass, field
@ -8,8 +9,10 @@ from django.db import transaction
from django.db.models import Model from django.db.models import Model
from django.db.models.fields.related import ManyToManyField from django.db.models.fields.related import ManyToManyField
from c3nav.editor.operations import DatabaseOperation, ObjectReference, FieldValuesDict, CreateObjectOperation, \ from c3nav.editor.changes import ChangedObjectCollection
UpdateObjectOperation, DeleteObjectOperation, ClearManyToManyOperation, UpdateManyToManyOperation, CollectedOperations from c3nav.editor.operations import DatabaseOperation, CreateObjectOperation, \
UpdateObjectOperation, DeleteObjectOperation, ClearManyToManyOperation, UpdateManyToManyOperation, \
DatabaseOperationCollection, FieldValuesDict, ObjectReference, PreviousObjectCollection
from c3nav.mapdata.fields import I18nField from c3nav.mapdata.fields import I18nField
from c3nav.mapdata.models import LocationSlug from c3nav.mapdata.models import LocationSlug
@ -28,21 +31,31 @@ class InterceptAbortTransaction(Exception):
@dataclass @dataclass
class DatabaseOverlayManager: class DatabaseOverlayManager:
changes: CollectedOperations """
new_operations: list[DatabaseOperation] = field(default_factory=list) This class handles the currently active database overlay and will apply and/or intercept changes.
pre_change_values: dict[ObjectReference, FieldValuesDict] = field(default_factory=dict) """
prev: PreviousObjectCollection = PreviousObjectCollection()
operations: list[DatabaseOperation] = field(default_factory=list)
pre_change_values: dict[ObjectReference, FieldValuesDict] = field(default_factory=dict, init=False, repr=False)
@classmethod @classmethod
@contextmanager @contextmanager
def enable(cls, changes: CollectedOperations | None, commit: bool): def enable(cls, operations: DatabaseOperationCollection | None = None, commit: bool = False):
if getattr(overlay_state, 'manager', None) is not None: """
raise TypeError Context manager to enable the database overlay, optionally pre-applying the given changes.
if changes is None: Only one overlay can be active at the same type, or else you get a TypeError.
changes = CollectedOperations()
:param operations: what operations to pre-apply
:param commit: whether to actually commit operations to the database or revert them at the end
"""
if getattr(overlay_state, "manager", None) is not None:
raise TypeError("Only one overlay can be active at the same time")
if operations is None:
operations = DatabaseOperationCollection()
try: try:
with transaction.atomic(): with transaction.atomic():
manager = DatabaseOverlayManager(changes) manager = DatabaseOverlayManager(prev=copy.deepcopy(operations.prev))
manager.changes.prefetch().apply() operations.prefetch().apply()
overlay_state.manager = manager overlay_state.manager = manager
yield manager yield manager
if not commit: if not commit:
@ -52,9 +65,6 @@ class DatabaseOverlayManager:
finally: finally:
overlay_state.manager = None overlay_state.manager = None
def save_new_operations(self):
self.changes.operations.extend(self.new_operations)
@staticmethod @staticmethod
def get_model_field_values(instance: Model) -> FieldValuesDict: def get_model_field_values(instance: Model) -> FieldValuesDict:
values = json.loads(serializers.serialize("json", [instance]))[0]["fields"] values = json.loads(serializers.serialize("json", [instance]))[0]["fields"]
@ -66,14 +76,14 @@ class DatabaseOverlayManager:
ref = ObjectReference.from_instance(instance) ref = ObjectReference.from_instance(instance)
pre_change_values = self.pre_change_values.pop(ref) pre_change_values = self.pre_change_values.pop(ref)
self.changes.prev.set(ref, values=pre_change_values, titles=getattr(instance, 'titles', None)) self.operations.prev.set(ref, values=pre_change_values, titles=getattr(instance, 'titles', None))
return ref, pre_change_values return ref, pre_change_values
def handle_pre_change_instance(self, instance: Model, **kwargs): def handle_pre_change_instance(self, instance: Model, **kwargs):
if instance.pk is None: if instance.pk is None:
return return
ref = ObjectReference.from_instance(instance) ref = ObjectReference.from_instance(instance)
if ref not in self.pre_change_values and self.changes.prev.get(ref) is None: if ref not in self.pre_change_values and self.operations.prev.get(ref) is None:
self.pre_change_values[ref] = self.get_model_field_values( self.pre_change_values[ref] = self.get_model_field_values(
instance._meta.model.objects.get(pk=instance.pk) instance._meta.model.objects.get(pk=instance.pk)
) )
@ -84,7 +94,7 @@ class DatabaseOverlayManager:
ref, pre_change_values = self.get_ref_and_pre_change_values(instance) ref, pre_change_values = self.get_ref_and_pre_change_values(instance)
if created: if created:
self.new_operations.append(CreateObjectOperation(obj=ref, fields=field_values)) self.operations.append(CreateObjectOperation(obj=ref, fields=field_values))
return return
if update_fields: if update_fields:
@ -105,11 +115,11 @@ class DatabaseOverlayManager:
diff_val[lang] = after_val.get(lang, None) diff_val[lang] = after_val.get(lang, None)
field_values[field_name] = diff_val field_values[field_name] = diff_val
self.new_operations.append(UpdateObjectOperation(obj=ref, fields=field_values)) self.operations.append(UpdateObjectOperation(obj=ref, fields=field_values))
def handle_post_delete(self, instance: Model, **kwargs): def handle_post_delete(self, instance: Model, **kwargs):
ref, pre_change_values = self.get_ref_and_pre_change_values(instance) ref, pre_change_values = self.get_ref_and_pre_change_values(instance)
self.new_operations.append(DeleteObjectOperation(obj=ref)) self.operations.append(DeleteObjectOperation(obj=ref))
def handle_m2m_changed(self, sender: Type[Model], instance: Model, action: str, model: Type[Model], def handle_m2m_changed(self, sender: Type[Model], instance: Model, action: str, model: Type[Model],
pk_set: set | None, reverse: bool, **kwargs): pk_set: set | None, reverse: bool, **kwargs):
@ -128,11 +138,11 @@ class DatabaseOverlayManager:
ref, pre_change_values = self.get_ref_and_pre_change_values(instance) ref, pre_change_values = self.get_ref_and_pre_change_values(instance)
if action == "post_clear": if action == "post_clear":
self.new_operations.append(ClearManyToManyOperation(obj=ref, field=field.name)) self.operations.append(ClearManyToManyOperation(obj=ref, field=field.name))
return return
if self.new_operations: if self.operations:
last_change = self.new_operations[-1] last_change = self.operations[-1]
if isinstance(last_change, UpdateManyToManyOperation) and last_change == ref and last_change == field.name: if isinstance(last_change, UpdateManyToManyOperation) and last_change == ref and last_change == field.name:
if action == "post_add": if action == "post_add":
last_change.add_values.update(pk_set) last_change.add_values.update(pk_set)
@ -143,9 +153,9 @@ class DatabaseOverlayManager:
return return
if action == "post_add": if action == "post_add":
self.new_operations.append(UpdateManyToManyOperation(obj=ref, field=field.name, add_values=list(pk_set))) self.operations.append(UpdateManyToManyOperation(obj=ref, field=field.name, add_values=list(pk_set)))
else: else:
self.new_operations.append(UpdateManyToManyOperation(obj=ref, field=field.name, remove_values=list(pk_set))) self.operations.append(UpdateManyToManyOperation(obj=ref, field=field.name, remove_values=list(pk_set)))
def handle_pre_change_instance(sender: Type[Model], **kwargs): def handle_pre_change_instance(sender: Type[Model], **kwargs):

View file

@ -48,7 +48,7 @@ def accesses_mapdata(func):
if request.changeset.direct_editing: if request.changeset.direct_editing:
with (MapUpdate.lock() if writable_method else noctx()): with (MapUpdate.lock() if writable_method else noctx()):
changed_geometries.reset() changed_geometries.reset()
with DatabaseOverlayManager.enable(changes=None, commit=writable_method) as manager: with DatabaseOverlayManager.enable(operations=None, commit=writable_method) as manager:
result = func(request, *args, **kwargs) result = func(request, *args, **kwargs)
if manager.new_operations: if manager.new_operations:
if writable_method: if writable_method:
@ -57,10 +57,11 @@ def accesses_mapdata(func):
raise ValueError # todo: good error message, but this shouldn't happen raise ValueError # todo: good error message, but this shouldn't happen
else: else:
with maybe_lock_changeset_to_edit(request=request): with maybe_lock_changeset_to_edit(request=request):
with DatabaseOverlayManager.enable(changes=request.changeset.changes, commit=False) as manager: operations = request.changeset.changes.as_operations # todo: cache this
with DatabaseOverlayManager.enable(operations=operations, commit=False) as manager:
result = func(request, *args, **kwargs) result = func(request, *args, **kwargs)
if manager.new_operations: if manager.operations:
manager.save_new_operations() request.changeset.changes.add_operations(manager.operations)
request.changeset.save() request.changeset.save()
update = request.changeset.updates.create(user=request.user, objects_changed=True) update = request.changeset.updates.create(user=request.user, objects_changed=True)
request.changeset.last_update = update request.changeset.last_update = update

View file

@ -193,7 +193,7 @@ def changeset_detail(request, pk):
added_redirects = {} added_redirects = {}
removed_redirects = {} removed_redirects = {}
for changed_object in changeset.changes.changed_objects: for changed_object in changeset.changes:
if changed_object.obj.model == "locationredirect": if changed_object.obj.model == "locationredirect":
if changed_object.created and not changed_object.deleted: if changed_object.created and not changed_object.deleted:
added_redirects.setdefault(changed_object.fields["target"], set()).add(changed_object.fields["slug"]) added_redirects.setdefault(changed_object.fields["target"], set()).add(changed_object.fields["slug"])
@ -205,7 +205,7 @@ def changeset_detail(request, pk):
current_lang = get_language() current_lang = get_language()
for changed_object in changeset.changes.changed_objects: for changed_object in changeset.changes:
model = apps.get_model("mapdata", changed_object.obj.model) model = apps.get_model("mapdata", changed_object.obj.model)
if model == LocationRedirect: if model == LocationRedirect:
continue continue
@ -218,7 +218,7 @@ def changeset_detail(request, pk):
else: else:
title = next(iter(changed_object.titles.values())) title = next(iter(changed_object.titles.values()))
prev_values = changeset.changes.prev.get(changed_object.obj).values prev_values = changeset.operations.prev.get(changed_object.obj).values
edit_url = None edit_url = None
if not changed_object.deleted: if not changed_object.deleted:

View file

@ -128,7 +128,7 @@ def space_detail(request, level, pk):
def get_changeset_exceeded(request): def get_changeset_exceeded(request):
return request.user_permissions.max_changeset_changes <= len(request.changeset.changes.operations) return request.user_permissions.max_changeset_changes <= len(request.changeset.operations.operations)
@etag(editor_etag_func) @etag(editor_etag_func)