diff --git a/src/c3nav/editor/changes.py b/src/c3nav/editor/changes.py index 8f684767..bb22ecd4 100644 --- a/src/c3nav/editor/changes.py +++ b/src/c3nav/editor/changes.py @@ -1,23 +1,24 @@ import operator from functools import reduce from itertools import chain -from typing import Type, Any, Optional, Annotated, Union +from typing import Type, Any, Union from django.apps import apps -from django.db.models import Model, OneToOneField, ForeignKey +from django.db.models import Model, Q +from django.db.models.fields.reverse_related import ManyToOneRel, OneToOneRel from pydantic.config import ConfigDict from c3nav.api.schema import BaseSchema from c3nav.editor.operations import DatabaseOperationCollection, CreateObjectOperation, UpdateObjectOperation, \ DeleteObjectOperation, ClearManyToManyOperation, FieldValuesDict, ObjectReference, PreviousObjectCollection, \ - DatabaseOperation + DatabaseOperation, ObjectID, FieldName, ModelName from c3nav.mapdata.fields import I18nField class ChangedManyToMany(BaseSchema): cleared: bool = False - added: list[int] = [] - removed: list[int] = [] + added: list[ObjectID] = [] + removed: list[ObjectID] = [] class ChangedObject(BaseSchema): @@ -26,7 +27,7 @@ class ChangedObject(BaseSchema): created: bool = False deleted: bool = False fields: FieldValuesDict = {} - m2m_changes: dict[str, ChangedManyToMany] = {} + m2m_changes: dict[FieldName, ChangedManyToMany] = {} class OperationDependencyObjectExists(BaseSchema): @@ -38,7 +39,7 @@ class OperationDependencyUniqueValue(BaseSchema): model_config = ConfigDict(frozen=True) model: str - field: str + field: FieldName value: Any nullable: bool @@ -60,6 +61,10 @@ class SingleOperationWithDependencies(BaseSchema): operation: DatabaseOperation dependencies: set[OperationDependency] = set() + @property + def main_operation(self) -> DatabaseOperation: + return self.operation + class MergableOperationsWithDependencies(BaseSchema): children: list[SingleOperationWithDependencies] @@ -68,6 +73,10 @@ class MergableOperationsWithDependencies(BaseSchema): def dependencies(self) -> set[OperationDependency]: return reduce(operator.or_, (c.dependencies for c in self.children), set()) + @property + def main_operation(self) -> DatabaseOperation: + return self.children[0].operation + OperationWithDependencies = Union[ SingleOperationWithDependencies, @@ -75,6 +84,14 @@ OperationWithDependencies = Union[ ] +class FoundObjectReference(BaseSchema): + model_config = ConfigDict(frozen=True) + + obj: ObjectReference + field: FieldName + on_delete: str + + class DummyValue: pass @@ -86,7 +103,7 @@ class ChangedObjectCollection(BaseSchema): Iterable as a list of ChangedObject instances. """ prev: PreviousObjectCollection = PreviousObjectCollection() - objects: dict[str, dict[int, ChangedObject]] = {} + objects: dict[ModelName, dict[ObjectID, ChangedObject]] = {} def __iter__(self): yield from chain(*(objects.values() for model, objects in self.objects.items())) @@ -137,12 +154,12 @@ class ChangedObjectCollection(BaseSchema): | operation.remove_values) def clean_and_complete_prev(self): - ids: dict[str, set[int]] = {} + ids: dict[ModelName, set[ObjectID]] = {} for model_name, changed_objects in self.objects.items(): ids.setdefault(model_name, set()).update(set(changed_objects.keys())) model = apps.get_model("mapdata", model_name) - relations: dict[str, Type[Model]] = {field.name: field.related_model - for field in model.get_fields() if field.is_relation} + relations: dict[FieldName, Type[Model]] = {field.name: field.related_model + for field in model.get_fields() if field.is_relation} for obj in changed_objects.values(): for field_name, value in obj.fields.items(): related_model = relations.get(field_name, None) @@ -225,6 +242,67 @@ class ChangedObjectCollection(BaseSchema): from pprint import pprint pprint(operations_with_dependencies) + # time to check which stuff cannot be done + objects_to_delete: dict[ModelName, set[ObjectID]] = {} # objects that will be deleted [find references!] + objects_to_exist_before: dict[ModelName, set[ObjectID]] = {} # objects that need to exist before [won't be created!] + objects_to_create: dict[ModelName, set[ObjectID]] = {} # objects that will be created [needed to create the previous var] + for operation in operations_with_dependencies: + main_operation = operation.main_operation + if isinstance(main_operation, DeleteObjectOperation): + objects_to_delete.setdefault(main_operation.obj.model, set()).add(main_operation.obj.id) + objects_to_exist_before.setdefault(main_operation.obj.model, set()).add(main_operation.obj.id) + + if isinstance(main_operation, UpdateObjectOperation): + objects_to_exist_before.setdefault(main_operation.obj.model, set()).add(main_operation.obj.id) + else: + objects_to_create.setdefault(main_operation.obj.model, set()).add(main_operation.obj.id) + + for dependency in operation.dependencies: + if isinstance(dependency, OperationDependencyObjectExists): + objects_to_exist_before.setdefault(dependency.obj.model, set()).add(dependency.obj.id) + + # objects that we create do not need to exist before + for model, ids in objects_to_create.items(): + objects_to_exist_before.get(model, set()).difference_update(ids) + + # let's find which objects that need to exist before actually exist + objects_exist_before: dict[ModelName, dict[ObjectID, bool]] = {} + for model, ids in objects_to_exist_before.items(): + model_cls = apps.get_model('mapdata', model) + ids_found = set(model_cls.objects.filter(pk__in=ids).values_list('pk', flat=True)) + objects_exist_before[model] = {id_: (id_ in ids_found) for id_ in ids} + + # let's find which protected references objects we want to delete have + potential_fields: dict[ModelName, dict[FieldName, dict[ModelName, set[ObjectID]]]] = {} + for model, ids in objects_to_exist_before.items(): + for field in apps.get_model('mapdata', model)._meta.get_fields(): + if isinstance(field, (ManyToOneRel, OneToOneRel)) or field.model._meta.app_label != "mapdata": + continue + potential_fields.setdefault(field.related_model._meta.model_name, + {}).setdefault(field.field.attname, {})[model] = ids + + # collect all references + found_obj_references: dict[ModelName, dict[ObjectID, set[FoundObjectReference]]] = {} + for model, fields in potential_fields.items(): + model_cls = apps.get_model('mapdata', model) + q = Q() + targets_reverse: dict[FieldName, dict[ObjectID, ModelName]] = {} + for field_name, targets in fields.items(): + ids = reduce(operator.or_, targets.values(), set()) + q |= Q(**{f'{field_name}__in': ids}) + targets_reverse[field_name] = dict(chain(*(((id_, target_model) for id_, in target_ids) + for target_model, target_ids in targets))) + for result in model_cls.objects.filter(q).values("id", *fields.keys()): + source_ref = ObjectReference(model=model, id=result.pop("id")) + for field, target_id in result.items(): + target_model = targets_reverse[field][target_id] + found_obj_references.setdefault(target_model, {}).setdefault(target_id, set()).add( + FoundObjectReference(obj=source_ref, field=field, + on_delete=model_cls._meta.get_field(field).on_delete.__name__) + ) + + + # todo: continue here return DatabaseOperationCollection() diff --git a/src/c3nav/editor/operations.py b/src/c3nav/editor/operations.py index 2e26fc1d..d3ead736 100644 --- a/src/c3nav/editor/operations.py +++ b/src/c3nav/editor/operations.py @@ -1,23 +1,22 @@ -import datetime import json from dataclasses import dataclass -from typing import Annotated, Literal, Union, TypeAlias, Any, Self -from uuid import UUID, uuid4 +from typing import Annotated, Literal, Union, TypeAlias, Any, Self, Iterator from django.apps import apps from django.core import serializers from django.db.models import Model -from django.utils import timezone from pydantic import ConfigDict -from pydantic.fields import Field from pydantic.types import Discriminator from c3nav.api.schema import BaseSchema from c3nav.mapdata.fields import I18nField from c3nav.mapdata.models import LocationSlug +ModelName: TypeAlias = str +ObjectID: TypeAlias = int +FieldName: TypeAlias = str -FieldValuesDict: TypeAlias = dict[str, Any] +FieldValuesDict: TypeAlias = dict[FieldName, Any] class ObjectReference(BaseSchema): @@ -25,8 +24,8 @@ class ObjectReference(BaseSchema): Reference to an object based on model name and ID. """ model_config = ConfigDict(frozen=True) - model: str - id: int + model: ModelName + id: ObjectID @classmethod def from_instance(cls, instance: Model): @@ -42,12 +41,12 @@ class PreviousObject(BaseSchema): class PreviousObjectCollection(BaseSchema): - objects: dict[str, dict[int, PreviousObject]] = {} + objects: dict[ModelName, dict[ObjectID, PreviousObject]] = {} def get(self, ref: ObjectReference) -> PreviousObject | None: return self.objects.get(ref.model, {}).get(ref.id, None) - def get_ids(self) -> dict[str, set[int]]: + def get_ids(self) -> dict[ModelName, set[ObjectID]]: """ :return: all referenced IDs sorted by model """ @@ -155,9 +154,9 @@ class DeleteObjectOperation(BaseOperation): class UpdateManyToManyOperation(BaseOperation): type: Literal["m2m_add"] = "m2m_update" - field: str - add_values: set[int] = set() - remove_values: set[int] = set() + field: FieldName + add_values: set[ObjectID] = set() + remove_values: set[ObjectID] = set() def apply(self, values: FieldValuesDict, instance: Model) -> Model: values[self.field] = sorted((set(values[self.field]) | self.add_values) - self.remove_values) @@ -169,7 +168,7 @@ class UpdateManyToManyOperation(BaseOperation): class ClearManyToManyOperation(BaseOperation): type: Literal["m2m_clear"] = "m2m_clear" - field: str + field: FieldName def apply(self, values: FieldValuesDict, instance: Model) -> Model: values[self.field] = [] @@ -198,7 +197,7 @@ class DatabaseOperationCollection(BaseSchema): prev: PreviousObjectCollection = PreviousObjectCollection() _operations: list[DatabaseOperation] = [] - def __iter__(self): + def __iter__(self) -> Iterator[DatabaseOperation]: yield from self._operations def __len__(self): diff --git a/src/c3nav/mapdata/models/locations.py b/src/c3nav/mapdata/models/locations.py index 29dbea94..69f019e2 100644 --- a/src/c3nav/mapdata/models/locations.py +++ b/src/c3nav/mapdata/models/locations.py @@ -177,7 +177,7 @@ class Location(LocationSlug, AccessRestrictionMixin, TitledMixin, models.Model): class SpecificLocation(Location, models.Model): - groups = models.ManyToManyField('mapdata.LocationGroup', verbose_name=_('Location Groups'), blank=True) + groups = models.ManyToManyField('mapdata.LocationGroup', verbose_name__=_('Location Groups'), blank=True) label_settings = models.ForeignKey('mapdata.LabelSettings', null=True, blank=True, on_delete=models.PROTECT, verbose_name=_('label settings')) label_override = I18nField(_('Label override'), plural_name='label_overrides', blank=True, fallback_any=True)