find all references and more as_operations code

This commit is contained in:
Laura Klünder 2024-11-05 13:33:16 +01:00
parent 54920a2f54
commit ce9c87ae4c
3 changed files with 104 additions and 27 deletions

View file

@ -1,23 +1,24 @@
import operator
from functools import reduce
from itertools import chain
from typing import Type, Any, Optional, Annotated, Union
from typing import Type, Any, Union
from django.apps import apps
from django.db.models import Model, OneToOneField, ForeignKey
from django.db.models import Model, Q
from django.db.models.fields.reverse_related import ManyToOneRel, OneToOneRel
from pydantic.config import ConfigDict
from c3nav.api.schema import BaseSchema
from c3nav.editor.operations import DatabaseOperationCollection, CreateObjectOperation, UpdateObjectOperation, \
DeleteObjectOperation, ClearManyToManyOperation, FieldValuesDict, ObjectReference, PreviousObjectCollection, \
DatabaseOperation
DatabaseOperation, ObjectID, FieldName, ModelName
from c3nav.mapdata.fields import I18nField
class ChangedManyToMany(BaseSchema):
cleared: bool = False
added: list[int] = []
removed: list[int] = []
added: list[ObjectID] = []
removed: list[ObjectID] = []
class ChangedObject(BaseSchema):
@ -26,7 +27,7 @@ class ChangedObject(BaseSchema):
created: bool = False
deleted: bool = False
fields: FieldValuesDict = {}
m2m_changes: dict[str, ChangedManyToMany] = {}
m2m_changes: dict[FieldName, ChangedManyToMany] = {}
class OperationDependencyObjectExists(BaseSchema):
@ -38,7 +39,7 @@ class OperationDependencyUniqueValue(BaseSchema):
model_config = ConfigDict(frozen=True)
model: str
field: str
field: FieldName
value: Any
nullable: bool
@ -60,6 +61,10 @@ class SingleOperationWithDependencies(BaseSchema):
operation: DatabaseOperation
dependencies: set[OperationDependency] = set()
@property
def main_operation(self) -> DatabaseOperation:
return self.operation
class MergableOperationsWithDependencies(BaseSchema):
children: list[SingleOperationWithDependencies]
@ -68,6 +73,10 @@ class MergableOperationsWithDependencies(BaseSchema):
def dependencies(self) -> set[OperationDependency]:
return reduce(operator.or_, (c.dependencies for c in self.children), set())
@property
def main_operation(self) -> DatabaseOperation:
return self.children[0].operation
OperationWithDependencies = Union[
SingleOperationWithDependencies,
@ -75,6 +84,14 @@ OperationWithDependencies = Union[
]
class FoundObjectReference(BaseSchema):
model_config = ConfigDict(frozen=True)
obj: ObjectReference
field: FieldName
on_delete: str
class DummyValue:
pass
@ -86,7 +103,7 @@ class ChangedObjectCollection(BaseSchema):
Iterable as a list of ChangedObject instances.
"""
prev: PreviousObjectCollection = PreviousObjectCollection()
objects: dict[str, dict[int, ChangedObject]] = {}
objects: dict[ModelName, dict[ObjectID, ChangedObject]] = {}
def __iter__(self):
yield from chain(*(objects.values() for model, objects in self.objects.items()))
@ -137,11 +154,11 @@ class ChangedObjectCollection(BaseSchema):
| operation.remove_values)
def clean_and_complete_prev(self):
ids: dict[str, set[int]] = {}
ids: dict[ModelName, set[ObjectID]] = {}
for model_name, changed_objects in self.objects.items():
ids.setdefault(model_name, set()).update(set(changed_objects.keys()))
model = apps.get_model("mapdata", model_name)
relations: dict[str, Type[Model]] = {field.name: field.related_model
relations: dict[FieldName, Type[Model]] = {field.name: field.related_model
for field in model.get_fields() if field.is_relation}
for obj in changed_objects.values():
for field_name, value in obj.fields.items():
@ -225,6 +242,67 @@ class ChangedObjectCollection(BaseSchema):
from pprint import pprint
pprint(operations_with_dependencies)
# time to check which stuff cannot be done
objects_to_delete: dict[ModelName, set[ObjectID]] = {} # objects that will be deleted [find references!]
objects_to_exist_before: dict[ModelName, set[ObjectID]] = {} # objects that need to exist before [won't be created!]
objects_to_create: dict[ModelName, set[ObjectID]] = {} # objects that will be created [needed to create the previous var]
for operation in operations_with_dependencies:
main_operation = operation.main_operation
if isinstance(main_operation, DeleteObjectOperation):
objects_to_delete.setdefault(main_operation.obj.model, set()).add(main_operation.obj.id)
objects_to_exist_before.setdefault(main_operation.obj.model, set()).add(main_operation.obj.id)
if isinstance(main_operation, UpdateObjectOperation):
objects_to_exist_before.setdefault(main_operation.obj.model, set()).add(main_operation.obj.id)
else:
objects_to_create.setdefault(main_operation.obj.model, set()).add(main_operation.obj.id)
for dependency in operation.dependencies:
if isinstance(dependency, OperationDependencyObjectExists):
objects_to_exist_before.setdefault(dependency.obj.model, set()).add(dependency.obj.id)
# objects that we create do not need to exist before
for model, ids in objects_to_create.items():
objects_to_exist_before.get(model, set()).difference_update(ids)
# let's find which objects that need to exist before actually exist
objects_exist_before: dict[ModelName, dict[ObjectID, bool]] = {}
for model, ids in objects_to_exist_before.items():
model_cls = apps.get_model('mapdata', model)
ids_found = set(model_cls.objects.filter(pk__in=ids).values_list('pk', flat=True))
objects_exist_before[model] = {id_: (id_ in ids_found) for id_ in ids}
# let's find which protected references objects we want to delete have
potential_fields: dict[ModelName, dict[FieldName, dict[ModelName, set[ObjectID]]]] = {}
for model, ids in objects_to_exist_before.items():
for field in apps.get_model('mapdata', model)._meta.get_fields():
if isinstance(field, (ManyToOneRel, OneToOneRel)) or field.model._meta.app_label != "mapdata":
continue
potential_fields.setdefault(field.related_model._meta.model_name,
{}).setdefault(field.field.attname, {})[model] = ids
# collect all references
found_obj_references: dict[ModelName, dict[ObjectID, set[FoundObjectReference]]] = {}
for model, fields in potential_fields.items():
model_cls = apps.get_model('mapdata', model)
q = Q()
targets_reverse: dict[FieldName, dict[ObjectID, ModelName]] = {}
for field_name, targets in fields.items():
ids = reduce(operator.or_, targets.values(), set())
q |= Q(**{f'{field_name}__in': ids})
targets_reverse[field_name] = dict(chain(*(((id_, target_model) for id_, in target_ids)
for target_model, target_ids in targets)))
for result in model_cls.objects.filter(q).values("id", *fields.keys()):
source_ref = ObjectReference(model=model, id=result.pop("id"))
for field, target_id in result.items():
target_model = targets_reverse[field][target_id]
found_obj_references.setdefault(target_model, {}).setdefault(target_id, set()).add(
FoundObjectReference(obj=source_ref, field=field,
on_delete=model_cls._meta.get_field(field).on_delete.__name__)
)
# todo: continue here
return DatabaseOperationCollection()

View file

@ -1,23 +1,22 @@
import datetime
import json
from dataclasses import dataclass
from typing import Annotated, Literal, Union, TypeAlias, Any, Self
from uuid import UUID, uuid4
from typing import Annotated, Literal, Union, TypeAlias, Any, Self, Iterator
from django.apps import apps
from django.core import serializers
from django.db.models import Model
from django.utils import timezone
from pydantic import ConfigDict
from pydantic.fields import Field
from pydantic.types import Discriminator
from c3nav.api.schema import BaseSchema
from c3nav.mapdata.fields import I18nField
from c3nav.mapdata.models import LocationSlug
ModelName: TypeAlias = str
ObjectID: TypeAlias = int
FieldName: TypeAlias = str
FieldValuesDict: TypeAlias = dict[str, Any]
FieldValuesDict: TypeAlias = dict[FieldName, Any]
class ObjectReference(BaseSchema):
@ -25,8 +24,8 @@ class ObjectReference(BaseSchema):
Reference to an object based on model name and ID.
"""
model_config = ConfigDict(frozen=True)
model: str
id: int
model: ModelName
id: ObjectID
@classmethod
def from_instance(cls, instance: Model):
@ -42,12 +41,12 @@ class PreviousObject(BaseSchema):
class PreviousObjectCollection(BaseSchema):
objects: dict[str, dict[int, PreviousObject]] = {}
objects: dict[ModelName, dict[ObjectID, PreviousObject]] = {}
def get(self, ref: ObjectReference) -> PreviousObject | None:
return self.objects.get(ref.model, {}).get(ref.id, None)
def get_ids(self) -> dict[str, set[int]]:
def get_ids(self) -> dict[ModelName, set[ObjectID]]:
"""
:return: all referenced IDs sorted by model
"""
@ -155,9 +154,9 @@ class DeleteObjectOperation(BaseOperation):
class UpdateManyToManyOperation(BaseOperation):
type: Literal["m2m_add"] = "m2m_update"
field: str
add_values: set[int] = set()
remove_values: set[int] = set()
field: FieldName
add_values: set[ObjectID] = set()
remove_values: set[ObjectID] = set()
def apply(self, values: FieldValuesDict, instance: Model) -> Model:
values[self.field] = sorted((set(values[self.field]) | self.add_values) - self.remove_values)
@ -169,7 +168,7 @@ class UpdateManyToManyOperation(BaseOperation):
class ClearManyToManyOperation(BaseOperation):
type: Literal["m2m_clear"] = "m2m_clear"
field: str
field: FieldName
def apply(self, values: FieldValuesDict, instance: Model) -> Model:
values[self.field] = []
@ -198,7 +197,7 @@ class DatabaseOperationCollection(BaseSchema):
prev: PreviousObjectCollection = PreviousObjectCollection()
_operations: list[DatabaseOperation] = []
def __iter__(self):
def __iter__(self) -> Iterator[DatabaseOperation]:
yield from self._operations
def __len__(self):

View file

@ -177,7 +177,7 @@ class Location(LocationSlug, AccessRestrictionMixin, TitledMixin, models.Model):
class SpecificLocation(Location, models.Model):
groups = models.ManyToManyField('mapdata.LocationGroup', verbose_name=_('Location Groups'), blank=True)
groups = models.ManyToManyField('mapdata.LocationGroup', verbose_name__=_('Location Groups'), blank=True)
label_settings = models.ForeignKey('mapdata.LabelSettings', null=True, blank=True, on_delete=models.PROTECT,
verbose_name=_('label settings'))
label_override = I18nField(_('Label override'), plural_name='label_overrides', blank=True, fallback_any=True)