176 lines
No EOL
4.8 KiB
Python
176 lines
No EOL
4.8 KiB
Python
from contextlib import suppress
|
|
from dataclasses import dataclass
|
|
from types import NoneType
|
|
from typing import Annotated, Any, Literal, Union, ClassVar
|
|
|
|
from django.core.exceptions import FieldDoesNotExist
|
|
from django.db.models import Model, ManyToManyField
|
|
from django.utils.functional import Promise
|
|
from ninja import Schema
|
|
from pydantic import Discriminator
|
|
from pydantic import Field as APIField
|
|
from pydantic import model_validator
|
|
from pydantic.functional_validators import ModelWrapValidatorHandler
|
|
from pydantic_core.core_schema import ValidationInfo
|
|
|
|
from c3nav.api.utils import NonEmptyStr
|
|
|
|
|
|
def make_serializable(values: Any):
|
|
if isinstance(values, Schema):
|
|
return values
|
|
if isinstance(values, (str, bool, int, float, complex, NoneType)):
|
|
return values
|
|
if isinstance(values, dict):
|
|
return {
|
|
key: make_serializable(val)
|
|
for key, val in values.items()
|
|
}
|
|
if isinstance(values, (list, tuple, set, frozenset)):
|
|
if values and isinstance(next(iter(values)), Model):
|
|
return type(values)(val.pk for val in values)
|
|
return type(values)(make_serializable(val) for val in values)
|
|
if isinstance(values, Promise):
|
|
return str(values)
|
|
return values
|
|
|
|
|
|
@dataclass
|
|
class ModelDataForwarder:
|
|
obj: Model
|
|
overrides: dict
|
|
|
|
def __getattr__(self, key):
|
|
# noinspection PyUnusedLocal
|
|
with suppress(KeyError):
|
|
return make_serializable(self.overrides[key])
|
|
with suppress(FieldDoesNotExist):
|
|
field = self.obj._meta.get_field(key)
|
|
if field.is_relation:
|
|
if field.many_to_many:
|
|
return [obj.pk for obj in getattr(self.obj, key).all()]
|
|
return make_serializable(getattr(self.obj, field.attname))
|
|
return make_serializable(getattr(self.obj, key))
|
|
|
|
|
|
class BaseSchema(Schema):
|
|
orig_keys: ClassVar[frozenset[str]] = frozenset()
|
|
|
|
@model_validator(mode="wrap") # noqa
|
|
@classmethod
|
|
def _run_root_validator(cls, values: Any, handler: ModelWrapValidatorHandler[Schema], info: ValidationInfo) -> Any:
|
|
""" overwriting this, we need to call serialize to get the correct data """
|
|
if hasattr(values, 'serialize') and callable(values.serialize) and not getattr(values, 'new_serialize', False):
|
|
converted = make_serializable(values.serialize())
|
|
elif isinstance(values, Model):
|
|
converted = ModelDataForwarder(
|
|
obj=values,
|
|
overrides=cls.get_overrides(values),
|
|
)
|
|
else:
|
|
converted = make_serializable(values)
|
|
return handler(converted)
|
|
|
|
@classmethod
|
|
def get_overrides(cls, value: Model) -> dict:
|
|
return {}
|
|
|
|
|
|
class APIErrorSchema(BaseSchema):
|
|
"""
|
|
An error has occured with this request
|
|
"""
|
|
detail: NonEmptyStr = APIField(
|
|
description="A human-readable error description"
|
|
)
|
|
|
|
|
|
class PolygonSchema(BaseSchema):
|
|
"""
|
|
A GeoJSON Polygon
|
|
"""
|
|
type: Literal["Polygon"]
|
|
coordinates: list[list[tuple[float, float]]] = APIField(
|
|
example=[[[1.5, 1.5], [1.5, 2.5], [2.5, 2.5], [2.5, 2.5]]]
|
|
)
|
|
|
|
class Config(Schema.Config):
|
|
title = "GeoJSON Polygon"
|
|
|
|
|
|
class MultiPolygonSchema(BaseSchema):
|
|
"""
|
|
A GeoJSON MultiPolygon
|
|
"""
|
|
type: Literal["MultiPolygon"]
|
|
coordinates: list[list[list[tuple[float, float]]]] = APIField(
|
|
example=[[[[1.5, 1.5], [1.5, 2.5], [2.5, 2.5], [2.5, 2.5]]]]
|
|
)
|
|
|
|
class Config(Schema.Config):
|
|
title = "GeoJSON Polygon"
|
|
|
|
|
|
class LineStringSchema(BaseSchema):
|
|
"""
|
|
A GeoJSON LineString
|
|
"""
|
|
type: Literal["LineString"]
|
|
coordinates: list[tuple[float, float]] = APIField(
|
|
example=[[1.5, 1.5], [2.5, 2.5], [5, 8.7]]
|
|
)
|
|
|
|
class Config(Schema.Config):
|
|
title = "GeoJSON LineString"
|
|
|
|
|
|
class LineSchema(BaseSchema):
|
|
"""
|
|
A GeoJSON LineString with only two points
|
|
"""
|
|
type: Literal["LineString"]
|
|
coordinates: tuple[tuple[float, float], tuple[float, float]] = APIField(
|
|
example=[[1.5, 1.5], [5, 8.7]]
|
|
)
|
|
|
|
class Config(Schema.Config):
|
|
title = "GeoJSON LineString (only two points)"
|
|
|
|
|
|
class PointSchema(BaseSchema):
|
|
"""
|
|
A GeoJSON Point
|
|
"""
|
|
type: Literal["Point"]
|
|
coordinates: tuple[float, float] = APIField(
|
|
example=[1, 2.5]
|
|
)
|
|
|
|
class Config(Schema.Config):
|
|
title = "GeoJSON Point"
|
|
|
|
|
|
GeometrySchema = Annotated[
|
|
Union[
|
|
PolygonSchema,
|
|
LineStringSchema,
|
|
PointSchema,
|
|
MultiPolygonSchema,
|
|
],
|
|
Discriminator("type"),
|
|
]
|
|
|
|
|
|
class AnyGeometrySchema(BaseSchema):
|
|
"""
|
|
A GeoJSON Geometry
|
|
"""
|
|
type: NonEmptyStr
|
|
coordinates: Any
|
|
|
|
|
|
class StatsSchema(BaseSchema):
|
|
users_total: int
|
|
reports_total: int
|
|
reports_today: int
|
|
reports_open: int |