mid-work commit from train

This commit is contained in:
Laura Klünder 2024-09-26 13:19:29 +02:00
parent 55b8e6e78c
commit 174866c2fd
8 changed files with 174 additions and 127 deletions

View file

@ -0,0 +1,82 @@
from itertools import chain
from django.apps import apps
from c3nav.api.schema import BaseSchema
from c3nav.editor.operations import DatabaseOperationCollection, CreateObjectOperation, UpdateObjectOperation, \
DeleteObjectOperation, ClearManyToManyOperation, FieldValuesDict, ObjectReference, PreviousObjectCollection
from c3nav.mapdata.fields import I18nField
class ChangedManyToMany(BaseSchema):
cleared: bool = False
added: list[str] = []
removed: list[str] = []
class ChangedObject(BaseSchema):
obj: ObjectReference
titles: dict[str, str] | None
created: bool = False
deleted: bool = False
fields: FieldValuesDict = {}
m2m_changes: dict[str, ChangedManyToMany] = {}
class ChangedObjectCollection(BaseSchema):
"""
A collection of ChangedObject instances, sorted by model and id.
Also stores a PreviousObjectCollection for comparison with the current state.
Iterable as a list of ChangedObject instances.
"""
prev: PreviousObjectCollection = PreviousObjectCollection()
objects: dict[str, dict[int, ChangedObject]] = {}
def __iter__(self):
yield from chain(*(objects.keys() for model, objects in self.objects.items()))
def add_operations(self, operations: DatabaseOperationCollection):
"""
Add the given operations, creating/updating changed objects to represent the resulting state.
"""
# todo: merge prev
for operation in operations.operations:
changed_object = self.objects.setdefault(operation.obj.model, {}).get(operation.obj.id, None)
if changed_object is None:
changed_object = ChangedObject(obj=operation.obj,
titles=self.prev.get(operation.obj).titles)
self.objects[operation.obj.model][operation.obj.id] = changed_object
if isinstance(operation, CreateObjectOperation):
changed_object.created = True
changed_object.fields.update(operation.fields)
elif isinstance(operation, UpdateObjectOperation):
model = apps.get_model('mapdata', operation.obj.model)
for field_name, value in operation.fields.items():
field = model._meta.get_field(field_name)
if isinstance(field, I18nField) and field_name in changed_object.fields:
changed_object.fields[field_name] = {
lang: val for lang, val in {**changed_object.fields[field_name], **value}.items()
}
else:
changed_object.fields[field_name] = value
elif isinstance(operation, DeleteObjectOperation):
changed_object.deleted = False
else:
changed_m2m = changed_object.m2m_changes.get(operation.field, None)
if changed_m2m is None:
changed_m2m = ChangedManyToMany()
changed_object.m2m_changes[operation.field] = changed_m2m
if isinstance(operation, ClearManyToManyOperation):
changed_m2m.cleared = True
changed_m2m.added = []
changed_m2m.removed = []
else:
changed_m2m.added = sorted((set(changed_m2m.added) | operation.add_values)
- operation.remove_values)
changed_m2m.removed = sorted((set(changed_m2m.removed) - operation.add_values)
| operation.remove_values)
@property
def as_operations(self) -> DatabaseOperationCollection:
pass # todo: implement

View file

@ -1,6 +1,6 @@
# Generated by Django 5.0.8 on 2024-08-26 09:46
# Generated by Django 5.0.8 on 2024-09-26 10:06
import c3nav.editor.operations
import c3nav.editor.changes
import django.core.serializers.json
import django_pydantic_field.fields
from django.db import migrations
@ -20,7 +20,7 @@ class Migration(migrations.Migration):
migrations.AddField(
model_name='changeset',
name='changes',
field=django_pydantic_field.fields.PydanticSchemaField(config=None, default=c3nav.editor.operations.CollectedOperations, encoder=django.core.serializers.json.DjangoJSONEncoder, schema=c3nav.editor.operations.CollectedOperations),
field=django_pydantic_field.fields.PydanticSchemaField(config=None, default=c3nav.editor.changes.ChangedObjectCollection, encoder=django.core.serializers.json.DjangoJSONEncoder, schema=c3nav.editor.changes.ChangedObjectCollection),
),
migrations.DeleteModel(
name='ChangedObject',

View file

@ -12,7 +12,7 @@ from django.utils.translation import gettext_lazy as _
from django.utils.translation import ngettext_lazy
from django_pydantic_field import SchemaField
from c3nav.editor.operations import CollectedOperations
from c3nav.editor.changes import ChangedObjectCollection
from c3nav.editor.tasks import send_changeset_proposed_notification
from c3nav.mapdata.models import LocationSlug, MapUpdate
from c3nav.mapdata.models.locations import LocationRedirect
@ -43,7 +43,7 @@ class ChangeSet(models.Model):
related_name='assigned_changesets', verbose_name=_('assigned to'))
map_update = models.OneToOneField(MapUpdate, null=True, related_name='changeset',
verbose_name=_('map update'), on_delete=models.PROTECT)
changes: CollectedOperations = SchemaField(schema=CollectedOperations, default=CollectedOperations)
changes: ChangedObjectCollection = SchemaField(schema=ChangedObjectCollection, default=ChangedObjectCollection)
class Meta:
verbose_name = _('Change Set')

View file

@ -1,15 +1,14 @@
import copy
import datetime
import json
from dataclasses import dataclass
from typing import TypeAlias, Any, Annotated, Literal, Union
from typing import Annotated, Literal, Union, TypeAlias, Any
from uuid import UUID, uuid4
from django.apps import apps
from django.core import serializers
from django.db.models import Model
from django.utils import timezone
from pydantic.config import ConfigDict
from pydantic import ConfigDict
from pydantic.fields import Field
from pydantic.types import Discriminator
@ -17,10 +16,14 @@ from c3nav.api.schema import BaseSchema
from c3nav.mapdata.fields import I18nField
from c3nav.mapdata.models import LocationSlug
FieldValuesDict: TypeAlias = dict[str, Any]
class ObjectReference(BaseSchema):
"""
Reference to an object based on model name and ID.
"""
model_config = ConfigDict(frozen=True)
model: str
id: int
@ -31,25 +34,42 @@ class ObjectReference(BaseSchema):
class PreviousObject(BaseSchema):
"""
Represents the previous state of an objects, consisting of its values and its (multi-language) titles
"""
titles: dict[str, str] | None
values: FieldValuesDict
class PreviousObjects(BaseSchema):
class PreviousObjectCollection(BaseSchema):
objects: dict[str, dict[int, PreviousObject]] = {}
def get(self, ref: ObjectReference) -> PreviousObject | None:
return self.objects.get(ref.model, {}).get(ref.id, None)
def get_ids(self) -> dict[str, set[int]]:
"""
:return: all referenced IDs sorted by model
"""
return {model: set(objs.keys()) for model, objs in self.objects.items()}
def get_instances(self) -> dict[ObjectReference, Model]:
"""
:return: all reference objects as fetched from the databse right now
"""
instances: dict[ObjectReference, Model] = {}
for model_name, ids in self.get_ids().items():
model = apps.get_model("mapdata", model_name)
instances.update(dict((ObjectReference(model=model_name, id=instance.pk), instance)
for instance in model.objects.filter(pk__in=ids)))
return instances
def set(self, ref: ObjectReference, values: FieldValuesDict, titles: dict | None):
self.objects.setdefault(ref.model, {})[ref.id] = PreviousObject(
values=values,
titles=titles,
)
def get_ids(self) -> dict[str, set[int]]:
return {model: set(objs.keys()) for model, objs in self.objects.items()}
class BaseOperation(BaseSchema):
obj: ObjectReference
@ -155,11 +175,6 @@ class ClearManyToManyOperation(BaseOperation):
return instance
class RevertOperation(BaseOperation):
type: Literal["revert"] = "revert"
reverts: UUID
DatabaseOperation = Annotated[
Union[
CreateObjectOperation,
@ -172,94 +187,34 @@ DatabaseOperation = Annotated[
]
class ChangedManyToMany(BaseSchema):
cleared: bool = False
added: list[str] = []
removed: list[str] = []
class ChangedObject(BaseSchema):
obj: ObjectReference
titles: dict[str, str] | None
created: bool = False
deleted: bool = False
fields: FieldValuesDict = {}
m2m_changes: dict[str, ChangedManyToMany] = {}
class CollectedOperations(BaseSchema):
uuid: UUID = Field(default_factory=uuid4)
prev: PreviousObjects = PreviousObjects()
class DatabaseOperationCollection(BaseSchema):
"""
A collection of database operations, sorted by model and id.
Also stores a PreviousObjectCollection for comparison with the current state.
Iterable as a list of DatabaseOperation instances.
"""
prev: PreviousObjectCollection = PreviousObjectCollection()
operations: list[DatabaseOperation] = []
def prefetch(self) -> "CollectedChangesPrefetch":
ids_to_query: dict[str, set[int]] = self.prev.get_ids()
def __iter__(self):
yield from self.operations
instances: dict[ObjectReference, Model] = {}
for model_name, ids in ids_to_query.items():
model = apps.get_model("mapdata", model_name)
instances.update(dict((ObjectReference(model=model_name, id=instance.pk), instance)
for instance in model.objects.filter(pk__in=ids)))
return CollectedChangesPrefetch(changes=self, instances=instances)
@property
def changed_objects(self) -> list[ChangedObject]:
objects = {}
reverted_uuids = frozenset(operation.reverts for operation in self.operationsy
if isinstance(operation, RevertOperation))
for operation in self.operations:
if operation.uuid in reverted_uuids:
continue
changed_object = objects.get(operation.obj, None)
if changed_object is None:
changed_object = ChangedObject(obj=operation.obj,
titles=self.prev_titles[operation.obj.model][operation.obj.id])
objects[operation.obj] = changed_object
if isinstance(operation, CreateObjectOperation):
changed_object.created = True
changed_object.fields.update(operation.fields)
elif isinstance(operation, UpdateObjectOperation):
model = apps.get_model('mapdata', operation.obj.model)
for field_name, value in operation.fields.items():
field = model._meta.get_field(field_name)
if isinstance(field, I18nField) and field_name in changed_object.fields:
changed_object.fields[field_name] = {
lang: val for lang, val in {**changed_object.fields[field_name], **value}.items()
}
else:
changed_object.fields[field_name] = value
elif isinstance(operation, DeleteObjectOperation):
changed_object.deleted = False
else:
changed_m2m = changed_object.m2m_changes.get(operation.field, None)
if changed_m2m is None:
changed_m2m = ChangedManyToMany()
changed_object.m2m_changes[operation.field] = changed_m2m
if isinstance(operation, ClearManyToManyOperation):
changed_m2m.cleared = True
changed_m2m.added = []
changed_m2m.removed = []
else:
changed_m2m.added = sorted((set(changed_m2m.added) | operation.add_values)
- operation.remove_values)
changed_m2m.removed = sorted((set(changed_m2m.removed) - operation.add_values)
| operation.remove_values)
return list(objects.values())
def prefetch(self) -> "PrefetchedDatabaseOperationCollection":
return PrefetchedDatabaseOperationCollection(operations=self, instances=self.prev.get_instances())
@dataclass
class CollectedChangesPrefetch:
changes: CollectedOperations
class PrefetchedDatabaseOperationCollection:
operations: DatabaseOperationCollection
instances: dict[ObjectReference, Model]
def apply(self):
# todo: what if unique constraint error occurs?
for operation in self.changes.operations:
for operation in self.operations.operations:
if isinstance(operation, CreateObjectOperation):
self.instances[operation.obj] = operation.apply_create()
else:
prev_obj = self.changes.prev.get(operation.obj)
prev_obj = self.operations.prev.get(operation.obj)
if prev_obj is None:
print('WARN WARN WARN')
values = prev_obj.values
@ -271,5 +226,4 @@ class CollectedChangesPrefetch:
else:
instance = None
if instance is not None:
operation.apply(values=values, instance=instance)
operation.apply(values=values, instance=instance)

View file

@ -1,3 +1,4 @@
import copy
import json
from contextlib import contextmanager
from dataclasses import dataclass, field
@ -8,8 +9,10 @@ from django.db import transaction
from django.db.models import Model
from django.db.models.fields.related import ManyToManyField
from c3nav.editor.operations import DatabaseOperation, ObjectReference, FieldValuesDict, CreateObjectOperation, \
UpdateObjectOperation, DeleteObjectOperation, ClearManyToManyOperation, UpdateManyToManyOperation, CollectedOperations
from c3nav.editor.changes import ChangedObjectCollection
from c3nav.editor.operations import DatabaseOperation, CreateObjectOperation, \
UpdateObjectOperation, DeleteObjectOperation, ClearManyToManyOperation, UpdateManyToManyOperation, \
DatabaseOperationCollection, FieldValuesDict, ObjectReference, PreviousObjectCollection
from c3nav.mapdata.fields import I18nField
from c3nav.mapdata.models import LocationSlug
@ -28,21 +31,31 @@ class InterceptAbortTransaction(Exception):
@dataclass
class DatabaseOverlayManager:
changes: CollectedOperations
new_operations: list[DatabaseOperation] = field(default_factory=list)
pre_change_values: dict[ObjectReference, FieldValuesDict] = field(default_factory=dict)
"""
This class handles the currently active database overlay and will apply and/or intercept changes.
"""
prev: PreviousObjectCollection = PreviousObjectCollection()
operations: list[DatabaseOperation] = field(default_factory=list)
pre_change_values: dict[ObjectReference, FieldValuesDict] = field(default_factory=dict, init=False, repr=False)
@classmethod
@contextmanager
def enable(cls, changes: CollectedOperations | None, commit: bool):
if getattr(overlay_state, 'manager', None) is not None:
raise TypeError
if changes is None:
changes = CollectedOperations()
def enable(cls, operations: DatabaseOperationCollection | None = None, commit: bool = False):
"""
Context manager to enable the database overlay, optionally pre-applying the given changes.
Only one overlay can be active at the same type, or else you get a TypeError.
:param operations: what operations to pre-apply
:param commit: whether to actually commit operations to the database or revert them at the end
"""
if getattr(overlay_state, "manager", None) is not None:
raise TypeError("Only one overlay can be active at the same time")
if operations is None:
operations = DatabaseOperationCollection()
try:
with transaction.atomic():
manager = DatabaseOverlayManager(changes)
manager.changes.prefetch().apply()
manager = DatabaseOverlayManager(prev=copy.deepcopy(operations.prev))
operations.prefetch().apply()
overlay_state.manager = manager
yield manager
if not commit:
@ -52,9 +65,6 @@ class DatabaseOverlayManager:
finally:
overlay_state.manager = None
def save_new_operations(self):
self.changes.operations.extend(self.new_operations)
@staticmethod
def get_model_field_values(instance: Model) -> FieldValuesDict:
values = json.loads(serializers.serialize("json", [instance]))[0]["fields"]
@ -66,14 +76,14 @@ class DatabaseOverlayManager:
ref = ObjectReference.from_instance(instance)
pre_change_values = self.pre_change_values.pop(ref)
self.changes.prev.set(ref, values=pre_change_values, titles=getattr(instance, 'titles', None))
self.operations.prev.set(ref, values=pre_change_values, titles=getattr(instance, 'titles', None))
return ref, pre_change_values
def handle_pre_change_instance(self, instance: Model, **kwargs):
if instance.pk is None:
return
ref = ObjectReference.from_instance(instance)
if ref not in self.pre_change_values and self.changes.prev.get(ref) is None:
if ref not in self.pre_change_values and self.operations.prev.get(ref) is None:
self.pre_change_values[ref] = self.get_model_field_values(
instance._meta.model.objects.get(pk=instance.pk)
)
@ -84,7 +94,7 @@ class DatabaseOverlayManager:
ref, pre_change_values = self.get_ref_and_pre_change_values(instance)
if created:
self.new_operations.append(CreateObjectOperation(obj=ref, fields=field_values))
self.operations.append(CreateObjectOperation(obj=ref, fields=field_values))
return
if update_fields:
@ -105,11 +115,11 @@ class DatabaseOverlayManager:
diff_val[lang] = after_val.get(lang, None)
field_values[field_name] = diff_val
self.new_operations.append(UpdateObjectOperation(obj=ref, fields=field_values))
self.operations.append(UpdateObjectOperation(obj=ref, fields=field_values))
def handle_post_delete(self, instance: Model, **kwargs):
ref, pre_change_values = self.get_ref_and_pre_change_values(instance)
self.new_operations.append(DeleteObjectOperation(obj=ref))
self.operations.append(DeleteObjectOperation(obj=ref))
def handle_m2m_changed(self, sender: Type[Model], instance: Model, action: str, model: Type[Model],
pk_set: set | None, reverse: bool, **kwargs):
@ -128,11 +138,11 @@ class DatabaseOverlayManager:
ref, pre_change_values = self.get_ref_and_pre_change_values(instance)
if action == "post_clear":
self.new_operations.append(ClearManyToManyOperation(obj=ref, field=field.name))
self.operations.append(ClearManyToManyOperation(obj=ref, field=field.name))
return
if self.new_operations:
last_change = self.new_operations[-1]
if self.operations:
last_change = self.operations[-1]
if isinstance(last_change, UpdateManyToManyOperation) and last_change == ref and last_change == field.name:
if action == "post_add":
last_change.add_values.update(pk_set)
@ -143,9 +153,9 @@ class DatabaseOverlayManager:
return
if action == "post_add":
self.new_operations.append(UpdateManyToManyOperation(obj=ref, field=field.name, add_values=list(pk_set)))
self.operations.append(UpdateManyToManyOperation(obj=ref, field=field.name, add_values=list(pk_set)))
else:
self.new_operations.append(UpdateManyToManyOperation(obj=ref, field=field.name, remove_values=list(pk_set)))
self.operations.append(UpdateManyToManyOperation(obj=ref, field=field.name, remove_values=list(pk_set)))
def handle_pre_change_instance(sender: Type[Model], **kwargs):

View file

@ -48,7 +48,7 @@ def accesses_mapdata(func):
if request.changeset.direct_editing:
with (MapUpdate.lock() if writable_method else noctx()):
changed_geometries.reset()
with DatabaseOverlayManager.enable(changes=None, commit=writable_method) as manager:
with DatabaseOverlayManager.enable(operations=None, commit=writable_method) as manager:
result = func(request, *args, **kwargs)
if manager.new_operations:
if writable_method:
@ -57,10 +57,11 @@ def accesses_mapdata(func):
raise ValueError # todo: good error message, but this shouldn't happen
else:
with maybe_lock_changeset_to_edit(request=request):
with DatabaseOverlayManager.enable(changes=request.changeset.changes, commit=False) as manager:
operations = request.changeset.changes.as_operations # todo: cache this
with DatabaseOverlayManager.enable(operations=operations, commit=False) as manager:
result = func(request, *args, **kwargs)
if manager.new_operations:
manager.save_new_operations()
if manager.operations:
request.changeset.changes.add_operations(manager.operations)
request.changeset.save()
update = request.changeset.updates.create(user=request.user, objects_changed=True)
request.changeset.last_update = update

View file

@ -193,7 +193,7 @@ def changeset_detail(request, pk):
added_redirects = {}
removed_redirects = {}
for changed_object in changeset.changes.changed_objects:
for changed_object in changeset.changes:
if changed_object.obj.model == "locationredirect":
if changed_object.created and not changed_object.deleted:
added_redirects.setdefault(changed_object.fields["target"], set()).add(changed_object.fields["slug"])
@ -205,7 +205,7 @@ def changeset_detail(request, pk):
current_lang = get_language()
for changed_object in changeset.changes.changed_objects:
for changed_object in changeset.changes:
model = apps.get_model("mapdata", changed_object.obj.model)
if model == LocationRedirect:
continue
@ -218,7 +218,7 @@ def changeset_detail(request, pk):
else:
title = next(iter(changed_object.titles.values()))
prev_values = changeset.changes.prev.get(changed_object.obj).values
prev_values = changeset.operations.prev.get(changed_object.obj).values
edit_url = None
if not changed_object.deleted:

View file

@ -128,7 +128,7 @@ def space_detail(request, level, pk):
def get_changeset_exceeded(request):
return request.user_permissions.max_changeset_changes <= len(request.changeset.changes.operations)
return request.user_permissions.max_changeset_changes <= len(request.changeset.operations.operations)
@etag(editor_etag_func)