serialize CustomLocation

This commit is contained in:
Laura Klünder 2024-12-04 13:01:02 +01:00
parent 00ec22c334
commit 907cc43214
3 changed files with 39 additions and 7 deletions

View file

@ -1,5 +1,5 @@
from contextlib import suppress from contextlib import suppress
from dataclasses import dataclass from dataclasses import dataclass, is_dataclass
from types import NoneType from types import NoneType
from typing import Annotated, Any, Literal, Union, ClassVar from typing import Annotated, Any, Literal, Union, ClassVar
@ -27,7 +27,8 @@ def make_serializable(values: Any):
for key, val in values.items() for key, val in values.items()
} }
if isinstance(values, (list, tuple, set, frozenset)): if isinstance(values, (list, tuple, set, frozenset)):
if values and isinstance(next(iter(values)), Model): from c3nav.routing.router import BaseRouterProxy
if values and isinstance(next(iter(values)), (Model, BaseRouterProxy)):
return type(values)(val.pk for val in values) return type(values)(val.pk for val in values)
return type(values)(make_serializable(val) for val in values) return type(values)(make_serializable(val) for val in values)
if isinstance(values, Promise): if isinstance(values, Promise):
@ -35,6 +36,18 @@ def make_serializable(values: Any):
return values return values
@dataclass
class DataclassForwarder:
obj: Any
overrides: dict
def __getattr__(self, key):
# noinspection PyUnusedLocal
with suppress(KeyError):
return make_serializable(self.overrides[key])
return make_serializable(getattr(self.obj, key))
@dataclass @dataclass
class ModelDataForwarder: class ModelDataForwarder:
obj: Model obj: Model
@ -67,6 +80,11 @@ class BaseSchema(Schema):
obj=values, obj=values,
overrides=cls.get_overrides(values), overrides=cls.get_overrides(values),
) )
elif is_dataclass(values):
converted = DataclassForwarder(
obj=values,
overrides=cls.get_overrides(values),
)
else: else:
converted = make_serializable(values) converted = make_serializable(values)
return handler(converted) return handler(converted)

View file

@ -504,14 +504,19 @@ class CustomLocationSchema(BaseSchema):
title="ground altitude", title="ground altitude",
description="ground altitude (in the map-wide coordinate system)" description="ground altitude (in the map-wide coordinate system)"
) )
geometry: Union[ geometry: PointSchema = APIField(
PointSchema,
Annotated[None, APIField(title="null", description="geometry excluded from endpoint")]
] = APIField(
None,
description="point geometry for this custom location", description="point geometry for this custom location",
) )
@classmethod
def get_overrides(cls, value):
return {
"id": value.pk,
"space": value.space.pk if value.space else None,
"level": value.level.pk,
"geometry": value.serialized_geometry,
}
class TrackablePositionSchema(BaseSchema): class TrackablePositionSchema(BaseSchema):
""" """

View file

@ -337,6 +337,15 @@ class CustomLocation:
return result return result
@property
def point(self):
return (self.level.pk, self.x, self.y)
@property
def bounds(self):
return ((int(math.floor(self.x)), int(math.floor(self.y))),
(int(math.ceil(self.x)), int(math.ceil(self.y))))
def details_display(self, **kwargs): def details_display(self, **kwargs):
result = { result = {
'id': self.pk, 'id': self.pk,