find all references and more as_operations code
This commit is contained in:
parent
54920a2f54
commit
ce9c87ae4c
3 changed files with 104 additions and 27 deletions
|
@ -1,23 +1,24 @@
|
|||
import operator
|
||||
from functools import reduce
|
||||
from itertools import chain
|
||||
from typing import Type, Any, Optional, Annotated, Union
|
||||
from typing import Type, Any, Union
|
||||
|
||||
from django.apps import apps
|
||||
from django.db.models import Model, OneToOneField, ForeignKey
|
||||
from django.db.models import Model, Q
|
||||
from django.db.models.fields.reverse_related import ManyToOneRel, OneToOneRel
|
||||
from pydantic.config import ConfigDict
|
||||
|
||||
from c3nav.api.schema import BaseSchema
|
||||
from c3nav.editor.operations import DatabaseOperationCollection, CreateObjectOperation, UpdateObjectOperation, \
|
||||
DeleteObjectOperation, ClearManyToManyOperation, FieldValuesDict, ObjectReference, PreviousObjectCollection, \
|
||||
DatabaseOperation
|
||||
DatabaseOperation, ObjectID, FieldName, ModelName
|
||||
from c3nav.mapdata.fields import I18nField
|
||||
|
||||
|
||||
class ChangedManyToMany(BaseSchema):
|
||||
cleared: bool = False
|
||||
added: list[int] = []
|
||||
removed: list[int] = []
|
||||
added: list[ObjectID] = []
|
||||
removed: list[ObjectID] = []
|
||||
|
||||
|
||||
class ChangedObject(BaseSchema):
|
||||
|
@ -26,7 +27,7 @@ class ChangedObject(BaseSchema):
|
|||
created: bool = False
|
||||
deleted: bool = False
|
||||
fields: FieldValuesDict = {}
|
||||
m2m_changes: dict[str, ChangedManyToMany] = {}
|
||||
m2m_changes: dict[FieldName, ChangedManyToMany] = {}
|
||||
|
||||
|
||||
class OperationDependencyObjectExists(BaseSchema):
|
||||
|
@ -38,7 +39,7 @@ class OperationDependencyUniqueValue(BaseSchema):
|
|||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
model: str
|
||||
field: str
|
||||
field: FieldName
|
||||
value: Any
|
||||
nullable: bool
|
||||
|
||||
|
@ -60,6 +61,10 @@ class SingleOperationWithDependencies(BaseSchema):
|
|||
operation: DatabaseOperation
|
||||
dependencies: set[OperationDependency] = set()
|
||||
|
||||
@property
|
||||
def main_operation(self) -> DatabaseOperation:
|
||||
return self.operation
|
||||
|
||||
|
||||
class MergableOperationsWithDependencies(BaseSchema):
|
||||
children: list[SingleOperationWithDependencies]
|
||||
|
@ -68,6 +73,10 @@ class MergableOperationsWithDependencies(BaseSchema):
|
|||
def dependencies(self) -> set[OperationDependency]:
|
||||
return reduce(operator.or_, (c.dependencies for c in self.children), set())
|
||||
|
||||
@property
|
||||
def main_operation(self) -> DatabaseOperation:
|
||||
return self.children[0].operation
|
||||
|
||||
|
||||
OperationWithDependencies = Union[
|
||||
SingleOperationWithDependencies,
|
||||
|
@ -75,6 +84,14 @@ OperationWithDependencies = Union[
|
|||
]
|
||||
|
||||
|
||||
class FoundObjectReference(BaseSchema):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
obj: ObjectReference
|
||||
field: FieldName
|
||||
on_delete: str
|
||||
|
||||
|
||||
class DummyValue:
|
||||
pass
|
||||
|
||||
|
@ -86,7 +103,7 @@ class ChangedObjectCollection(BaseSchema):
|
|||
Iterable as a list of ChangedObject instances.
|
||||
"""
|
||||
prev: PreviousObjectCollection = PreviousObjectCollection()
|
||||
objects: dict[str, dict[int, ChangedObject]] = {}
|
||||
objects: dict[ModelName, dict[ObjectID, ChangedObject]] = {}
|
||||
|
||||
def __iter__(self):
|
||||
yield from chain(*(objects.values() for model, objects in self.objects.items()))
|
||||
|
@ -137,11 +154,11 @@ class ChangedObjectCollection(BaseSchema):
|
|||
| operation.remove_values)
|
||||
|
||||
def clean_and_complete_prev(self):
|
||||
ids: dict[str, set[int]] = {}
|
||||
ids: dict[ModelName, set[ObjectID]] = {}
|
||||
for model_name, changed_objects in self.objects.items():
|
||||
ids.setdefault(model_name, set()).update(set(changed_objects.keys()))
|
||||
model = apps.get_model("mapdata", model_name)
|
||||
relations: dict[str, Type[Model]] = {field.name: field.related_model
|
||||
relations: dict[FieldName, Type[Model]] = {field.name: field.related_model
|
||||
for field in model.get_fields() if field.is_relation}
|
||||
for obj in changed_objects.values():
|
||||
for field_name, value in obj.fields.items():
|
||||
|
@ -225,6 +242,67 @@ class ChangedObjectCollection(BaseSchema):
|
|||
from pprint import pprint
|
||||
pprint(operations_with_dependencies)
|
||||
|
||||
# time to check which stuff cannot be done
|
||||
objects_to_delete: dict[ModelName, set[ObjectID]] = {} # objects that will be deleted [find references!]
|
||||
objects_to_exist_before: dict[ModelName, set[ObjectID]] = {} # objects that need to exist before [won't be created!]
|
||||
objects_to_create: dict[ModelName, set[ObjectID]] = {} # objects that will be created [needed to create the previous var]
|
||||
for operation in operations_with_dependencies:
|
||||
main_operation = operation.main_operation
|
||||
if isinstance(main_operation, DeleteObjectOperation):
|
||||
objects_to_delete.setdefault(main_operation.obj.model, set()).add(main_operation.obj.id)
|
||||
objects_to_exist_before.setdefault(main_operation.obj.model, set()).add(main_operation.obj.id)
|
||||
|
||||
if isinstance(main_operation, UpdateObjectOperation):
|
||||
objects_to_exist_before.setdefault(main_operation.obj.model, set()).add(main_operation.obj.id)
|
||||
else:
|
||||
objects_to_create.setdefault(main_operation.obj.model, set()).add(main_operation.obj.id)
|
||||
|
||||
for dependency in operation.dependencies:
|
||||
if isinstance(dependency, OperationDependencyObjectExists):
|
||||
objects_to_exist_before.setdefault(dependency.obj.model, set()).add(dependency.obj.id)
|
||||
|
||||
# objects that we create do not need to exist before
|
||||
for model, ids in objects_to_create.items():
|
||||
objects_to_exist_before.get(model, set()).difference_update(ids)
|
||||
|
||||
# let's find which objects that need to exist before actually exist
|
||||
objects_exist_before: dict[ModelName, dict[ObjectID, bool]] = {}
|
||||
for model, ids in objects_to_exist_before.items():
|
||||
model_cls = apps.get_model('mapdata', model)
|
||||
ids_found = set(model_cls.objects.filter(pk__in=ids).values_list('pk', flat=True))
|
||||
objects_exist_before[model] = {id_: (id_ in ids_found) for id_ in ids}
|
||||
|
||||
# let's find which protected references objects we want to delete have
|
||||
potential_fields: dict[ModelName, dict[FieldName, dict[ModelName, set[ObjectID]]]] = {}
|
||||
for model, ids in objects_to_exist_before.items():
|
||||
for field in apps.get_model('mapdata', model)._meta.get_fields():
|
||||
if isinstance(field, (ManyToOneRel, OneToOneRel)) or field.model._meta.app_label != "mapdata":
|
||||
continue
|
||||
potential_fields.setdefault(field.related_model._meta.model_name,
|
||||
{}).setdefault(field.field.attname, {})[model] = ids
|
||||
|
||||
# collect all references
|
||||
found_obj_references: dict[ModelName, dict[ObjectID, set[FoundObjectReference]]] = {}
|
||||
for model, fields in potential_fields.items():
|
||||
model_cls = apps.get_model('mapdata', model)
|
||||
q = Q()
|
||||
targets_reverse: dict[FieldName, dict[ObjectID, ModelName]] = {}
|
||||
for field_name, targets in fields.items():
|
||||
ids = reduce(operator.or_, targets.values(), set())
|
||||
q |= Q(**{f'{field_name}__in': ids})
|
||||
targets_reverse[field_name] = dict(chain(*(((id_, target_model) for id_, in target_ids)
|
||||
for target_model, target_ids in targets)))
|
||||
for result in model_cls.objects.filter(q).values("id", *fields.keys()):
|
||||
source_ref = ObjectReference(model=model, id=result.pop("id"))
|
||||
for field, target_id in result.items():
|
||||
target_model = targets_reverse[field][target_id]
|
||||
found_obj_references.setdefault(target_model, {}).setdefault(target_id, set()).add(
|
||||
FoundObjectReference(obj=source_ref, field=field,
|
||||
on_delete=model_cls._meta.get_field(field).on_delete.__name__)
|
||||
)
|
||||
|
||||
|
||||
|
||||
# todo: continue here
|
||||
|
||||
return DatabaseOperationCollection()
|
||||
|
|
|
@ -1,23 +1,22 @@
|
|||
import datetime
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, Literal, Union, TypeAlias, Any, Self
|
||||
from uuid import UUID, uuid4
|
||||
from typing import Annotated, Literal, Union, TypeAlias, Any, Self, Iterator
|
||||
|
||||
from django.apps import apps
|
||||
from django.core import serializers
|
||||
from django.db.models import Model
|
||||
from django.utils import timezone
|
||||
from pydantic import ConfigDict
|
||||
from pydantic.fields import Field
|
||||
from pydantic.types import Discriminator
|
||||
|
||||
from c3nav.api.schema import BaseSchema
|
||||
from c3nav.mapdata.fields import I18nField
|
||||
from c3nav.mapdata.models import LocationSlug
|
||||
|
||||
ModelName: TypeAlias = str
|
||||
ObjectID: TypeAlias = int
|
||||
FieldName: TypeAlias = str
|
||||
|
||||
FieldValuesDict: TypeAlias = dict[str, Any]
|
||||
FieldValuesDict: TypeAlias = dict[FieldName, Any]
|
||||
|
||||
|
||||
class ObjectReference(BaseSchema):
|
||||
|
@ -25,8 +24,8 @@ class ObjectReference(BaseSchema):
|
|||
Reference to an object based on model name and ID.
|
||||
"""
|
||||
model_config = ConfigDict(frozen=True)
|
||||
model: str
|
||||
id: int
|
||||
model: ModelName
|
||||
id: ObjectID
|
||||
|
||||
@classmethod
|
||||
def from_instance(cls, instance: Model):
|
||||
|
@ -42,12 +41,12 @@ class PreviousObject(BaseSchema):
|
|||
|
||||
|
||||
class PreviousObjectCollection(BaseSchema):
|
||||
objects: dict[str, dict[int, PreviousObject]] = {}
|
||||
objects: dict[ModelName, dict[ObjectID, PreviousObject]] = {}
|
||||
|
||||
def get(self, ref: ObjectReference) -> PreviousObject | None:
|
||||
return self.objects.get(ref.model, {}).get(ref.id, None)
|
||||
|
||||
def get_ids(self) -> dict[str, set[int]]:
|
||||
def get_ids(self) -> dict[ModelName, set[ObjectID]]:
|
||||
"""
|
||||
:return: all referenced IDs sorted by model
|
||||
"""
|
||||
|
@ -155,9 +154,9 @@ class DeleteObjectOperation(BaseOperation):
|
|||
|
||||
class UpdateManyToManyOperation(BaseOperation):
|
||||
type: Literal["m2m_add"] = "m2m_update"
|
||||
field: str
|
||||
add_values: set[int] = set()
|
||||
remove_values: set[int] = set()
|
||||
field: FieldName
|
||||
add_values: set[ObjectID] = set()
|
||||
remove_values: set[ObjectID] = set()
|
||||
|
||||
def apply(self, values: FieldValuesDict, instance: Model) -> Model:
|
||||
values[self.field] = sorted((set(values[self.field]) | self.add_values) - self.remove_values)
|
||||
|
@ -169,7 +168,7 @@ class UpdateManyToManyOperation(BaseOperation):
|
|||
|
||||
class ClearManyToManyOperation(BaseOperation):
|
||||
type: Literal["m2m_clear"] = "m2m_clear"
|
||||
field: str
|
||||
field: FieldName
|
||||
|
||||
def apply(self, values: FieldValuesDict, instance: Model) -> Model:
|
||||
values[self.field] = []
|
||||
|
@ -198,7 +197,7 @@ class DatabaseOperationCollection(BaseSchema):
|
|||
prev: PreviousObjectCollection = PreviousObjectCollection()
|
||||
_operations: list[DatabaseOperation] = []
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> Iterator[DatabaseOperation]:
|
||||
yield from self._operations
|
||||
|
||||
def __len__(self):
|
||||
|
|
|
@ -177,7 +177,7 @@ class Location(LocationSlug, AccessRestrictionMixin, TitledMixin, models.Model):
|
|||
|
||||
|
||||
class SpecificLocation(Location, models.Model):
|
||||
groups = models.ManyToManyField('mapdata.LocationGroup', verbose_name=_('Location Groups'), blank=True)
|
||||
groups = models.ManyToManyField('mapdata.LocationGroup', verbose_name__=_('Location Groups'), blank=True)
|
||||
label_settings = models.ForeignKey('mapdata.LabelSettings', null=True, blank=True, on_delete=models.PROTECT,
|
||||
verbose_name=_('label settings'))
|
||||
label_override = I18nField(_('Label override'), plural_name='label_overrides', blank=True, fallback_any=True)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue