From 907cc43214a25cad0044424a83a038defc8ab3c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laura=20Kl=C3=BCnder?= Date: Wed, 4 Dec 2024 13:01:02 +0100 Subject: [PATCH] serialize CustomLocation --- src/c3nav/api/schema.py | 22 ++++++++++++++++++++-- src/c3nav/mapdata/schemas/models.py | 15 ++++++++++----- src/c3nav/mapdata/utils/locations.py | 9 +++++++++ 3 files changed, 39 insertions(+), 7 deletions(-) diff --git a/src/c3nav/api/schema.py b/src/c3nav/api/schema.py index e15ec589..1d1cd78e 100644 --- a/src/c3nav/api/schema.py +++ b/src/c3nav/api/schema.py @@ -1,5 +1,5 @@ from contextlib import suppress -from dataclasses import dataclass +from dataclasses import dataclass, is_dataclass from types import NoneType from typing import Annotated, Any, Literal, Union, ClassVar @@ -27,7 +27,8 @@ def make_serializable(values: Any): for key, val in values.items() } if isinstance(values, (list, tuple, set, frozenset)): - if values and isinstance(next(iter(values)), Model): + from c3nav.routing.router import BaseRouterProxy + if values and isinstance(next(iter(values)), (Model, BaseRouterProxy)): return type(values)(val.pk for val in values) return type(values)(make_serializable(val) for val in values) if isinstance(values, Promise): @@ -35,6 +36,18 @@ def make_serializable(values: Any): return values +@dataclass +class DataclassForwarder: + obj: Any + overrides: dict + + def __getattr__(self, key): + # noinspection PyUnusedLocal + with suppress(KeyError): + return make_serializable(self.overrides[key]) + return make_serializable(getattr(self.obj, key)) + + @dataclass class ModelDataForwarder: obj: Model @@ -67,6 +80,11 @@ class BaseSchema(Schema): obj=values, overrides=cls.get_overrides(values), ) + elif is_dataclass(values): + converted = DataclassForwarder( + obj=values, + overrides=cls.get_overrides(values), + ) else: converted = make_serializable(values) return handler(converted) diff --git a/src/c3nav/mapdata/schemas/models.py b/src/c3nav/mapdata/schemas/models.py index 126509b3..81547a8a 100644 --- a/src/c3nav/mapdata/schemas/models.py +++ b/src/c3nav/mapdata/schemas/models.py @@ -504,14 +504,19 @@ class CustomLocationSchema(BaseSchema): title="ground altitude", description="ground altitude (in the map-wide coordinate system)" ) - geometry: Union[ - PointSchema, - Annotated[None, APIField(title="null", description="geometry excluded from endpoint")] - ] = APIField( - None, + geometry: PointSchema = APIField( description="point geometry for this custom location", ) + @classmethod + def get_overrides(cls, value): + return { + "id": value.pk, + "space": value.space.pk if value.space else None, + "level": value.level.pk, + "geometry": value.serialized_geometry, + } + class TrackablePositionSchema(BaseSchema): """ diff --git a/src/c3nav/mapdata/utils/locations.py b/src/c3nav/mapdata/utils/locations.py index 08717560..acc13ab5 100644 --- a/src/c3nav/mapdata/utils/locations.py +++ b/src/c3nav/mapdata/utils/locations.py @@ -337,6 +337,15 @@ class CustomLocation: return result + @property + def point(self): + return (self.level.pk, self.x, self.y) + + @property + def bounds(self): + return ((int(math.floor(self.x)), int(math.floor(self.y))), + (int(math.ceil(self.x)), int(math.ceil(self.y)))) + def details_display(self, **kwargs): result = { 'id': self.pk,