From 35a87384245ac5a4dcf159f9c6b6a47cd27a7808 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laura=20Kl=C3=BCnder?= Date: Tue, 3 Dec 2024 14:18:16 +0100 Subject: [PATCH] implement new serialize for level --- src/c3nav/api/schema.py | 73 ++++++++++++++++++------- src/c3nav/mapdata/models/level.py | 15 ++--- src/c3nav/mapdata/models/locations.py | 30 ++++++---- src/c3nav/mapdata/schemas/model_base.py | 19 ++++++- src/c3nav/site/static/site/js/c3nav.js | 10 ++-- 5 files changed, 98 insertions(+), 49 deletions(-) diff --git a/src/c3nav/api/schema.py b/src/c3nav/api/schema.py index 22e407ba..a756982d 100644 --- a/src/c3nav/api/schema.py +++ b/src/c3nav/api/schema.py @@ -1,6 +1,10 @@ +from contextlib import suppress +from dataclasses import dataclass from types import NoneType -from typing import Annotated, Any, Literal, Union +from typing import Annotated, Any, Literal, Union, ClassVar +from django.core.exceptions import FieldDoesNotExist +from django.db.models import Model, ManyToManyField from django.utils.functional import Promise from ninja import Schema from pydantic import Discriminator @@ -12,31 +16,62 @@ from pydantic_core.core_schema import ValidationInfo from c3nav.api.utils import NonEmptyStr +def make_serializable(values: Any): + if isinstance(values, Schema): + return values + if isinstance(values, (str, bool, int, float, complex, NoneType)): + return values + if isinstance(values, dict): + return { + key: make_serializable(val) + for key, val in values.items() + } + if isinstance(values, (list, tuple, set, frozenset)): + return type(values)(make_serializable(val) for val in values) + if isinstance(values, Promise): + return str(values) + return values + + +@dataclass +class ModelDataForwarder: + obj: Model + overrides: dict + + def __getattr__(self, key): + # noinspection PyUnusedLocal + with suppress(KeyError): + return make_serializable(self.overrides[key]) + with suppress(FieldDoesNotExist): + field = self.obj._meta.get_field(key) + if field.is_relation: + if field.many_to_many: + return [obj.pk for obj in getattr(self.obj, key).all()] + return make_serializable(getattr(self.obj, field.attname)) + return make_serializable(getattr(self.obj, key)) + + class BaseSchema(Schema): + orig_keys: ClassVar[frozenset[str]] = frozenset() + @model_validator(mode="wrap") # noqa @classmethod def _run_root_validator(cls, values: Any, handler: ModelWrapValidatorHandler[Schema], info: ValidationInfo) -> Any: """ overwriting this, we need to call serialize to get the correct data """ - return handler(cls.convert(values)) + if hasattr(values, 'serialize') and callable(values.serialize) and not getattr(values, 'new_serialize', False): + converted = make_serializable(values.serialize()) + elif isinstance(values, Model): + converted = ModelDataForwarder( + obj=values, + overrides=cls.get_overrides(values), + ) + else: + converted = make_serializable(values) + return handler(converted) @classmethod - def convert(cls, values: Any): - if isinstance(values, Schema): - return values - if isinstance(values, (str, bool, int, float, complex, NoneType)): - return values - if isinstance(values, dict): - return { - key: cls.convert(val) - for key, val in values.items() - } - if isinstance(values, (list, tuple, set, frozenset)): - return type(values)(cls.convert(val) for val in values) - if isinstance(values, Promise): - return str(values) - if hasattr(values, 'serialize') and callable(values.serialize): - return cls.convert(values.serialize()) - return values + def get_overrides(cls, value: Model) -> dict: + return {} class APIErrorSchema(BaseSchema): diff --git a/src/c3nav/mapdata/models/level.py b/src/c3nav/mapdata/models/level.py index eb89aa54..c0c3f824 100644 --- a/src/c3nav/mapdata/models/level.py +++ b/src/c3nav/mapdata/models/level.py @@ -14,8 +14,12 @@ from c3nav.mapdata.models.locations import SpecificLocation class Level(SpecificLocation, models.Model): """ - A map level + A physical level of the map, containing building, spaces, doors… + + A level is a specific location, and can therefore be routed to and from, as well as belong to location groups. """ + new_serialize = True + base_altitude = models.DecimalField(_('base altitude'), null=False, unique=True, max_digits=6, decimal_places=2) default_height = models.DecimalField(_('default space height'), max_digits=6, decimal_places=2, default=3.0, validators=[MinValueValidator(Decimal('0'))]) @@ -68,15 +72,6 @@ class Level(SpecificLocation, models.Model): def primary_level_pk(self): return self.pk if self.on_top_of_id is None else self.on_top_of_id - def _serialize(self, level=True, **kwargs): - result = super()._serialize(**kwargs) - result['short_label'] = self.short_label - result['on_top_of'] = self.on_top_of_id - result['base_altitude'] = float(str(self.base_altitude)) - result['default_height'] = float(str(self.default_height)) - result['door_height'] = float(str(self.door_height)) - return result - def details_display(self, editor_url=True, **kwargs): result = super().details_display(**kwargs) result['display'].insert(3, (_('short label'), self.short_label)) diff --git a/src/c3nav/mapdata/models/locations.py b/src/c3nav/mapdata/models/locations.py index 6ae91932..f258a29a 100644 --- a/src/c3nav/mapdata/models/locations.py +++ b/src/c3nav/mapdata/models/locations.py @@ -113,7 +113,7 @@ class Location(LocationSlug, AccessRestrictionMixin, TitledMixin, models.Model): class Meta: abstract = True - def serialize(self, detailed=True, describe_only=False, **kwargs): + def serialize(self, detailed=True, **kwargs): result = super().serialize(detailed=detailed, **kwargs) if not detailed: fields = ('id', 'type', 'slug', 'title', 'subtitle', 'icon', 'point', 'bounds', 'grid_square', @@ -195,23 +195,19 @@ class SpecificLocation(Location, models.Model): if grid_square is not None: result['grid_square'] = grid_square or None if detailed: - groups = {} - for group in self.groups.all(): - groups.setdefault(group.category, []).append(group.pk) - groups = {category.name: (items[0] if items else None) if category.single else items - for category, items in groups.items() - if getattr(category, 'allow_'+self.__class__._meta.default_related_name)} - result['groups'] = groups + result['groups'] = self.groups_by_category - label_settings = self.get_label_settings() - if label_settings: - result['label_settings'] = label_settings.serialize(detailed=False) + result["label_settings"] = self.label_settings_id + effective_label_settings = self.effective_label_settings + if effective_label_settings: + result['effective_label_settings'] = effective_label_settings.serialize(detailed=False) if self.label_overrides: # todo: what if only one language is set? result['label_override'] = self.label_override return result - def get_label_settings(self): + @property + def effective_label_settings(self): if self.label_settings: return self.label_settings for group in self.groups.all(): @@ -219,6 +215,16 @@ class SpecificLocation(Location, models.Model): return group.label_settings return None + @property + def groups_by_category(self): + groups_by_category = {} + for group in self.groups.all(): + groups_by_category.setdefault(group.category, []).append(group.pk) + groups_by_category = {category.name: (items[0] if items else None) if category.single else items + for category, items in groups_by_category.items() + if getattr(category, 'allow_' + self.__class__._meta.default_related_name)} + return groups_by_category + def details_display(self, **kwargs): result = super().details_display(**kwargs) diff --git a/src/c3nav/mapdata/schemas/model_base.py b/src/c3nav/mapdata/schemas/model_base.py index f54cc3b6..568d4b6b 100644 --- a/src/c3nav/mapdata/schemas/model_base.py +++ b/src/c3nav/mapdata/schemas/model_base.py @@ -144,7 +144,12 @@ class SpecificLocationSchema(LocationSchema): description="grid cell(s) that this location is in, if a grid is defined and the location is within it", example="C3", ) - groups: dict[ + groups: list[PositiveInt] = APIField( + title="location groups", + description="location group(s) that this specific location belongs to.", + example=[5, 1, 3, 7], + ) + groups_by_category: dict[ Annotated[NonEmptyStr, APIField(title="location group category name")], Union[ Annotated[list[PositiveInt], APIField( @@ -163,7 +168,7 @@ class SpecificLocationSchema(LocationSchema): )], ] ] = APIField( - title="location groups", + title="location groups by category", description="location group(s) that this specific location belongs to, grouped by categories.\n\n" "keys are location group category names. see location group category endpoint for details.\n\n" "categories may be missing if no groups apply.", @@ -173,7 +178,15 @@ class SpecificLocationSchema(LocationSchema): "category_with_single_false": [1, 3, 7], } ) - label_settings: Union[ + label_settings: Optional[PositiveInt] = APIField( + default=None, + title="label settings", + description=( + schema_description(LabelSettingsSchema) + + "\n\nif not set, label settings of location groups might be used" + ) + ) + effective_label_settings: Union[ Annotated[LabelSettingsSchema, APIField( title="label settings", description="label settings to use", diff --git a/src/c3nav/site/static/site/js/c3nav.js b/src/c3nav/site/static/site/js/c3nav.js index a1b46569..713eb7f5 100644 --- a/src/c3nav/site/static/site/js/c3nav.js +++ b/src/c3nav/site/static/site/js/c3nav.js @@ -152,8 +152,8 @@ c3nav = { }); }, _sort_labels: function (a, b) { - var result = (a[0].label_settings.min_zoom || -10) - (b[0].label_settings.min_zoom || -10); - if (result === 0) result = b[0].label_settings.font_size - a[0].label_settings.font_size; + var result = (a[0].effective_label_settings.min_zoom || -10) - (b[0].effective_label_settings.min_zoom || -10); + if (result === 0) result = b[0].effective_label_settings.font_size - a[0].effective_label_settings.font_size; return result; }, _last_time_searchable_locations_loaded: null, @@ -406,12 +406,12 @@ c3nav = { for (var item of labels) { location = item[0]; label = item[1]; - if (zoom < (location.label_settings.min_zoom || -10)) { + if (zoom < (location.effective_label_settings.min_zoom || -10)) { // since the labels are sorted by min_zoom, we can just leave here break; } if (bounds.contains(label.getLatLng())) { - if ((location.label_settings.max_zoom || 10) > zoom) { + if ((location.effective_label_settings.max_zoom || 10) > zoom) { c3nav._labelLayer._maybeAddLayerToRBush(label); } else { valid_upper.unshift(label); @@ -1008,7 +1008,7 @@ c3nav = { new_text[i] = new_text[i].trim(); } var html = $('
').append($('').html(' ' + new_text.join(' 
 ') + ' ')); - html.css('font-size', location.label_settings.font_size + 'px'); + html.css('font-size', location.effective_label_settings.font_size + 'px'); return L.marker(L.GeoJSON.coordsToLatLng(location.point.slice(1)), { icon: L.divIcon({ html: html[0].outerHTML,