find all references and more as_operations code

This commit is contained in:
Laura Klünder 2024-11-05 13:33:16 +01:00
parent 54920a2f54
commit ce9c87ae4c
3 changed files with 104 additions and 27 deletions

View file

@ -1,23 +1,24 @@
import operator import operator
from functools import reduce from functools import reduce
from itertools import chain from itertools import chain
from typing import Type, Any, Optional, Annotated, Union from typing import Type, Any, Union
from django.apps import apps from django.apps import apps
from django.db.models import Model, OneToOneField, ForeignKey from django.db.models import Model, Q
from django.db.models.fields.reverse_related import ManyToOneRel, OneToOneRel
from pydantic.config import ConfigDict from pydantic.config import ConfigDict
from c3nav.api.schema import BaseSchema from c3nav.api.schema import BaseSchema
from c3nav.editor.operations import DatabaseOperationCollection, CreateObjectOperation, UpdateObjectOperation, \ from c3nav.editor.operations import DatabaseOperationCollection, CreateObjectOperation, UpdateObjectOperation, \
DeleteObjectOperation, ClearManyToManyOperation, FieldValuesDict, ObjectReference, PreviousObjectCollection, \ DeleteObjectOperation, ClearManyToManyOperation, FieldValuesDict, ObjectReference, PreviousObjectCollection, \
DatabaseOperation DatabaseOperation, ObjectID, FieldName, ModelName
from c3nav.mapdata.fields import I18nField from c3nav.mapdata.fields import I18nField
class ChangedManyToMany(BaseSchema): class ChangedManyToMany(BaseSchema):
cleared: bool = False cleared: bool = False
added: list[int] = [] added: list[ObjectID] = []
removed: list[int] = [] removed: list[ObjectID] = []
class ChangedObject(BaseSchema): class ChangedObject(BaseSchema):
@ -26,7 +27,7 @@ class ChangedObject(BaseSchema):
created: bool = False created: bool = False
deleted: bool = False deleted: bool = False
fields: FieldValuesDict = {} fields: FieldValuesDict = {}
m2m_changes: dict[str, ChangedManyToMany] = {} m2m_changes: dict[FieldName, ChangedManyToMany] = {}
class OperationDependencyObjectExists(BaseSchema): class OperationDependencyObjectExists(BaseSchema):
@ -38,7 +39,7 @@ class OperationDependencyUniqueValue(BaseSchema):
model_config = ConfigDict(frozen=True) model_config = ConfigDict(frozen=True)
model: str model: str
field: str field: FieldName
value: Any value: Any
nullable: bool nullable: bool
@ -60,6 +61,10 @@ class SingleOperationWithDependencies(BaseSchema):
operation: DatabaseOperation operation: DatabaseOperation
dependencies: set[OperationDependency] = set() dependencies: set[OperationDependency] = set()
@property
def main_operation(self) -> DatabaseOperation:
return self.operation
class MergableOperationsWithDependencies(BaseSchema): class MergableOperationsWithDependencies(BaseSchema):
children: list[SingleOperationWithDependencies] children: list[SingleOperationWithDependencies]
@ -68,6 +73,10 @@ class MergableOperationsWithDependencies(BaseSchema):
def dependencies(self) -> set[OperationDependency]: def dependencies(self) -> set[OperationDependency]:
return reduce(operator.or_, (c.dependencies for c in self.children), set()) return reduce(operator.or_, (c.dependencies for c in self.children), set())
@property
def main_operation(self) -> DatabaseOperation:
return self.children[0].operation
OperationWithDependencies = Union[ OperationWithDependencies = Union[
SingleOperationWithDependencies, SingleOperationWithDependencies,
@ -75,6 +84,14 @@ OperationWithDependencies = Union[
] ]
class FoundObjectReference(BaseSchema):
model_config = ConfigDict(frozen=True)
obj: ObjectReference
field: FieldName
on_delete: str
class DummyValue: class DummyValue:
pass pass
@ -86,7 +103,7 @@ class ChangedObjectCollection(BaseSchema):
Iterable as a list of ChangedObject instances. Iterable as a list of ChangedObject instances.
""" """
prev: PreviousObjectCollection = PreviousObjectCollection() prev: PreviousObjectCollection = PreviousObjectCollection()
objects: dict[str, dict[int, ChangedObject]] = {} objects: dict[ModelName, dict[ObjectID, ChangedObject]] = {}
def __iter__(self): def __iter__(self):
yield from chain(*(objects.values() for model, objects in self.objects.items())) yield from chain(*(objects.values() for model, objects in self.objects.items()))
@ -137,12 +154,12 @@ class ChangedObjectCollection(BaseSchema):
| operation.remove_values) | operation.remove_values)
def clean_and_complete_prev(self): def clean_and_complete_prev(self):
ids: dict[str, set[int]] = {} ids: dict[ModelName, set[ObjectID]] = {}
for model_name, changed_objects in self.objects.items(): for model_name, changed_objects in self.objects.items():
ids.setdefault(model_name, set()).update(set(changed_objects.keys())) ids.setdefault(model_name, set()).update(set(changed_objects.keys()))
model = apps.get_model("mapdata", model_name) model = apps.get_model("mapdata", model_name)
relations: dict[str, Type[Model]] = {field.name: field.related_model relations: dict[FieldName, Type[Model]] = {field.name: field.related_model
for field in model.get_fields() if field.is_relation} for field in model.get_fields() if field.is_relation}
for obj in changed_objects.values(): for obj in changed_objects.values():
for field_name, value in obj.fields.items(): for field_name, value in obj.fields.items():
related_model = relations.get(field_name, None) related_model = relations.get(field_name, None)
@ -225,6 +242,67 @@ class ChangedObjectCollection(BaseSchema):
from pprint import pprint from pprint import pprint
pprint(operations_with_dependencies) pprint(operations_with_dependencies)
# time to check which stuff cannot be done
objects_to_delete: dict[ModelName, set[ObjectID]] = {} # objects that will be deleted [find references!]
objects_to_exist_before: dict[ModelName, set[ObjectID]] = {} # objects that need to exist before [won't be created!]
objects_to_create: dict[ModelName, set[ObjectID]] = {} # objects that will be created [needed to create the previous var]
for operation in operations_with_dependencies:
main_operation = operation.main_operation
if isinstance(main_operation, DeleteObjectOperation):
objects_to_delete.setdefault(main_operation.obj.model, set()).add(main_operation.obj.id)
objects_to_exist_before.setdefault(main_operation.obj.model, set()).add(main_operation.obj.id)
if isinstance(main_operation, UpdateObjectOperation):
objects_to_exist_before.setdefault(main_operation.obj.model, set()).add(main_operation.obj.id)
else:
objects_to_create.setdefault(main_operation.obj.model, set()).add(main_operation.obj.id)
for dependency in operation.dependencies:
if isinstance(dependency, OperationDependencyObjectExists):
objects_to_exist_before.setdefault(dependency.obj.model, set()).add(dependency.obj.id)
# objects that we create do not need to exist before
for model, ids in objects_to_create.items():
objects_to_exist_before.get(model, set()).difference_update(ids)
# let's find which objects that need to exist before actually exist
objects_exist_before: dict[ModelName, dict[ObjectID, bool]] = {}
for model, ids in objects_to_exist_before.items():
model_cls = apps.get_model('mapdata', model)
ids_found = set(model_cls.objects.filter(pk__in=ids).values_list('pk', flat=True))
objects_exist_before[model] = {id_: (id_ in ids_found) for id_ in ids}
# let's find which protected references objects we want to delete have
potential_fields: dict[ModelName, dict[FieldName, dict[ModelName, set[ObjectID]]]] = {}
for model, ids in objects_to_exist_before.items():
for field in apps.get_model('mapdata', model)._meta.get_fields():
if isinstance(field, (ManyToOneRel, OneToOneRel)) or field.model._meta.app_label != "mapdata":
continue
potential_fields.setdefault(field.related_model._meta.model_name,
{}).setdefault(field.field.attname, {})[model] = ids
# collect all references
found_obj_references: dict[ModelName, dict[ObjectID, set[FoundObjectReference]]] = {}
for model, fields in potential_fields.items():
model_cls = apps.get_model('mapdata', model)
q = Q()
targets_reverse: dict[FieldName, dict[ObjectID, ModelName]] = {}
for field_name, targets in fields.items():
ids = reduce(operator.or_, targets.values(), set())
q |= Q(**{f'{field_name}__in': ids})
targets_reverse[field_name] = dict(chain(*(((id_, target_model) for id_, in target_ids)
for target_model, target_ids in targets)))
for result in model_cls.objects.filter(q).values("id", *fields.keys()):
source_ref = ObjectReference(model=model, id=result.pop("id"))
for field, target_id in result.items():
target_model = targets_reverse[field][target_id]
found_obj_references.setdefault(target_model, {}).setdefault(target_id, set()).add(
FoundObjectReference(obj=source_ref, field=field,
on_delete=model_cls._meta.get_field(field).on_delete.__name__)
)
# todo: continue here # todo: continue here
return DatabaseOperationCollection() return DatabaseOperationCollection()

View file

@ -1,23 +1,22 @@
import datetime
import json import json
from dataclasses import dataclass from dataclasses import dataclass
from typing import Annotated, Literal, Union, TypeAlias, Any, Self from typing import Annotated, Literal, Union, TypeAlias, Any, Self, Iterator
from uuid import UUID, uuid4
from django.apps import apps from django.apps import apps
from django.core import serializers from django.core import serializers
from django.db.models import Model from django.db.models import Model
from django.utils import timezone
from pydantic import ConfigDict from pydantic import ConfigDict
from pydantic.fields import Field
from pydantic.types import Discriminator from pydantic.types import Discriminator
from c3nav.api.schema import BaseSchema from c3nav.api.schema import BaseSchema
from c3nav.mapdata.fields import I18nField from c3nav.mapdata.fields import I18nField
from c3nav.mapdata.models import LocationSlug from c3nav.mapdata.models import LocationSlug
ModelName: TypeAlias = str
ObjectID: TypeAlias = int
FieldName: TypeAlias = str
FieldValuesDict: TypeAlias = dict[str, Any] FieldValuesDict: TypeAlias = dict[FieldName, Any]
class ObjectReference(BaseSchema): class ObjectReference(BaseSchema):
@ -25,8 +24,8 @@ class ObjectReference(BaseSchema):
Reference to an object based on model name and ID. Reference to an object based on model name and ID.
""" """
model_config = ConfigDict(frozen=True) model_config = ConfigDict(frozen=True)
model: str model: ModelName
id: int id: ObjectID
@classmethod @classmethod
def from_instance(cls, instance: Model): def from_instance(cls, instance: Model):
@ -42,12 +41,12 @@ class PreviousObject(BaseSchema):
class PreviousObjectCollection(BaseSchema): class PreviousObjectCollection(BaseSchema):
objects: dict[str, dict[int, PreviousObject]] = {} objects: dict[ModelName, dict[ObjectID, PreviousObject]] = {}
def get(self, ref: ObjectReference) -> PreviousObject | None: def get(self, ref: ObjectReference) -> PreviousObject | None:
return self.objects.get(ref.model, {}).get(ref.id, None) return self.objects.get(ref.model, {}).get(ref.id, None)
def get_ids(self) -> dict[str, set[int]]: def get_ids(self) -> dict[ModelName, set[ObjectID]]:
""" """
:return: all referenced IDs sorted by model :return: all referenced IDs sorted by model
""" """
@ -155,9 +154,9 @@ class DeleteObjectOperation(BaseOperation):
class UpdateManyToManyOperation(BaseOperation): class UpdateManyToManyOperation(BaseOperation):
type: Literal["m2m_add"] = "m2m_update" type: Literal["m2m_add"] = "m2m_update"
field: str field: FieldName
add_values: set[int] = set() add_values: set[ObjectID] = set()
remove_values: set[int] = set() remove_values: set[ObjectID] = set()
def apply(self, values: FieldValuesDict, instance: Model) -> Model: def apply(self, values: FieldValuesDict, instance: Model) -> Model:
values[self.field] = sorted((set(values[self.field]) | self.add_values) - self.remove_values) values[self.field] = sorted((set(values[self.field]) | self.add_values) - self.remove_values)
@ -169,7 +168,7 @@ class UpdateManyToManyOperation(BaseOperation):
class ClearManyToManyOperation(BaseOperation): class ClearManyToManyOperation(BaseOperation):
type: Literal["m2m_clear"] = "m2m_clear" type: Literal["m2m_clear"] = "m2m_clear"
field: str field: FieldName
def apply(self, values: FieldValuesDict, instance: Model) -> Model: def apply(self, values: FieldValuesDict, instance: Model) -> Model:
values[self.field] = [] values[self.field] = []
@ -198,7 +197,7 @@ class DatabaseOperationCollection(BaseSchema):
prev: PreviousObjectCollection = PreviousObjectCollection() prev: PreviousObjectCollection = PreviousObjectCollection()
_operations: list[DatabaseOperation] = [] _operations: list[DatabaseOperation] = []
def __iter__(self): def __iter__(self) -> Iterator[DatabaseOperation]:
yield from self._operations yield from self._operations
def __len__(self): def __len__(self):

View file

@ -177,7 +177,7 @@ class Location(LocationSlug, AccessRestrictionMixin, TitledMixin, models.Model):
class SpecificLocation(Location, models.Model): class SpecificLocation(Location, models.Model):
groups = models.ManyToManyField('mapdata.LocationGroup', verbose_name=_('Location Groups'), blank=True) groups = models.ManyToManyField('mapdata.LocationGroup', verbose_name__=_('Location Groups'), blank=True)
label_settings = models.ForeignKey('mapdata.LabelSettings', null=True, blank=True, on_delete=models.PROTECT, label_settings = models.ForeignKey('mapdata.LabelSettings', null=True, blank=True, on_delete=models.PROTECT,
verbose_name=_('label settings')) verbose_name=_('label settings'))
label_override = I18nField(_('Label override'), plural_name='label_overrides', blank=True, fallback_any=True) label_override = I18nField(_('Label override'), plural_name='label_overrides', blank=True, fallback_any=True)