team-3/src/c3nav/mapdata/api/mapdata.py

301 lines
10 KiB
Python
Raw Normal View History

2024-12-02 23:40:41 +01:00
from dataclasses import dataclass
from typing import Optional, Sequence, Type, Callable, Any
from django.db.models import Model
from ninja import Query
from ninja import Router as APIRouter
2024-12-02 23:40:41 +01:00
from pydantic import PositiveInt
2023-12-03 21:55:08 +01:00
from c3nav.api.auth import auth_responses, validate_responses
from c3nav.api.exceptions import API404
2024-12-02 23:40:41 +01:00
from c3nav.api.schema import BaseSchema
from c3nav.mapdata.api.base import api_etag, optimize_query, can_access_geometry
2023-11-19 16:36:46 +01:00
from c3nav.mapdata.models import (Area, Building, Door, Hole, Level, LocationGroup, LocationGroupCategory, Source,
2024-11-21 11:56:31 +01:00
Space, Stair, DataOverlay, DataOverlayFeature)
from c3nav.mapdata.models.access import AccessRestriction, AccessRestrictionGroup
2023-11-19 16:36:46 +01:00
from c3nav.mapdata.models.geometry.space import (POI, Column, CrossDescription, LeaveDescription, LineObstacle,
Obstacle, Ramp)
from c3nav.mapdata.models.locations import DynamicLocation, LabelSettings, LocationRedirect
from c3nav.mapdata.schemas.filters import (ByCategoryFilter, ByGroupFilter, ByOnTopOfFilter, FilterSchema,
LevelGeometryFilter, SpaceGeometryFilter, BySpaceFilter, ByOverlayFilter)
from c3nav.mapdata.schemas.model_base import schema_description, LabelSettingsSchema
from c3nav.mapdata.schemas.models import (AccessRestrictionGroupSchema, AccessRestrictionSchema, AreaSchema,
2023-11-23 23:22:30 +01:00
BuildingSchema, ColumnSchema, CrossDescriptionSchema, DoorSchema,
DynamicLocationSchema, HoleSchema, LeaveDescriptionSchema, LevelSchema,
LineObstacleSchema, LocationGroupCategorySchema, LocationGroupSchema,
2024-11-21 11:56:31 +01:00
ObstacleSchema, POISchema, RampSchema, SourceSchema, SpaceSchema, StairSchema,
DataOverlaySchema, DataOverlayFeatureSchema, LocationRedirectSchema)
mapdata_api_router = APIRouter(tags=["mapdata"])  # router for all mapdata endpoints; populated by MapdataAPIBuilder at the end of this module
def mapdata_list_endpoint(request,
                          model: Type[Model],
                          filters: Optional[FilterSchema] = None,
                          order_by: Sequence[str] = ('pk',)):
    """Return the objects of ``model`` visible to the request, filtered and ordered.

    :param request: the current request; handed to ``model.qs_for_request`` (if the model
                    defines it), to the filter schema and to ``can_access_geometry``.
    :param model: the mapdata model class to list.
    :param filters: optional filter schema; it is validated against the request first and
                    then applied to the queryset.
    :param order_by: field names passed to ``QuerySet.order_by`` (default: primary key).
    :return: a list of model instances. Instances whose geometry the request may not
             access get ``_hide_geometry = True`` set on them.
    """
    if filters:
        filters.validate(request)

    # prefer the model's own permission-aware queryset when it provides one
    if hasattr(model, 'qs_for_request'):
        base_queryset = model.qs_for_request(request)
    else:
        base_queryset = model.objects.all()
    queryset = optimize_query(base_queryset)

    if filters:
        queryset = filters.filter_qs(request, queryset)

    objects = list(queryset.order_by(*order_by))

    for item in objects:
        if can_access_geometry(request, item):
            continue
        item._hide_geometry = True
    return objects
def mapdata_retrieve_endpoint(request, model: Type[Model], **lookups):
    """Return the single object of ``model`` matching ``lookups`` that is visible to the request.

    :param request: the current request; handed to ``model.qs_for_request`` (if the model
                    defines it) and to ``can_access_geometry``.
    :param model: the mapdata model class to look up.
    :param lookups: keyword lookups passed to ``QuerySet.get`` (e.g. ``pk=...``).
    :return: the model instance. If the request may not access its geometry, it is flagged
             with ``_hide_geometry = True``, exactly like ``mapdata_list_endpoint`` does.
             The model field itself is not overwritten.
    :raises API404: if no visible object matches ``lookups``.
    """
    queryset = optimize_query(
        model.qs_for_request(request) if hasattr(model, 'qs_for_request') else model.objects.all()
    )
    # keep the try minimal: only the lookup itself signals "not found"
    try:
        obj = queryset.get(**lookups)
    except model.DoesNotExist:
        raise API404("%s not found" % model.__name__.lower()) from None
    if not can_access_geometry(request, obj):
        # same mechanism as the list endpoint (which serializes with the same schema);
        # presumably the response schema reads this flag - confirm in the schema mixins
        obj._hide_geometry = True
    return obj
@dataclass
class MapdataEndpoint:
    """Declarative description of one mapdata model exposed through the API.

    ``model`` is the Django model that is queried, ``schema`` is the response schema, and
    ``filters`` is an optional filter schema used by the generated list endpoint.
    """
    model: Type[Model]
    schema: Type[BaseSchema]
    filters: Type[FilterSchema] | None = None

    @property
    def model_name(self):
        """The model's ``_meta.model_name``; used for route summaries, IDs and view names."""
        options = self.model._meta
        return options.model_name

    @property
    def model_name_plural(self):
        """The model's ``_meta.default_related_name``; used as the URL path segment."""
        options = self.model._meta
        return options.default_related_name
@dataclass
class MapdataAPIBuilder:
    """Registers a list endpoint and a by-ID endpoint on ``router`` for each MapdataEndpoint.

    django-ninja derives a view's request parameters from its function signature. So for
    every endpoint, a small wrapper function with exactly the wanted signature is generated
    at runtime (see ``_make_endpoint``). That wrapper forwards to ``mapdata_list_endpoint``
    or ``mapdata_retrieve_endpoint``.
    """
    router: APIRouter

    def build_all_endpoints(self, endpoints: dict[str, list[MapdataEndpoint]]):
        """Register the endpoints for every definition in ``endpoints`` (tag -> definitions)."""
        # separate loop variable instead of rebinding the ``endpoints`` parameter
        for tag, tag_endpoints in endpoints.items():
            for endpoint in tag_endpoints:
                self.add_endpoints(endpoint, tag=tag)

    def add_endpoints(self, endpoint: MapdataEndpoint, tag: str):
        """Register both the list and the by-ID endpoint for one definition."""
        self.add_list_endpoint(endpoint, tag=tag)
        self.add_by_id_endpoint(endpoint, tag=tag)

    def common_params(self, endpoint: MapdataEndpoint) -> dict[str, Any]:
        """Return the parameters shared by all generated views.

        Maps each parameter name to the source text of its type hint (``None`` = no
        annotation). A fresh dict is returned on every call, so callers may extend it.
        """
        return {"request": None}

    def _make_endpoint(self, view_params: dict[str, str], call_func: Callable,
                       add_call_params: dict[str, str] | None = None,
                       namespace: dict[str, Any] | None = None) -> Callable:
        """Generate a view function with the given signature that forwards to ``call_func``.

        :param view_params: parameter name -> type-hint source (falsy = unannotated), in
                            signature order.
        :param call_func: function the generated view calls, using keyword arguments only.
        :param add_call_params: extra keyword arguments for ``call_func``, mapping argument
                                name -> expression source. A view parameter whose name is
                                used as one of these expressions is passed only through that
                                expression, not under its own name (e.g. ``pk=level_id``).
        :param namespace: extra names available to the type hints and expressions. This lets
                          callers bind objects directly instead of relying on them being
                          module globals under their ``__name__``.
        :return: the generated function (named ``gen_func``).
        """
        if add_call_params is None:
            add_call_params = {}
        consumed = set(add_call_params.values())
        call_params = (
            *(f"{name}={name}" for name in view_params if name not in consumed),
            *(f"{name}={value}" for name, value in add_call_params.items()),
        )
        signature = ", ".join((f"{name}: {hint}" if hint else name) for name, hint in view_params.items())
        # within this module the code is assembled only from identifiers/hints derived from
        # the endpoint definitions, never from request data
        method_code = "\n".join((
            f"def gen_func({signature}):",
            f"    return call_func({', '.join(call_params)})",
        ))
        g = {
            **globals(),
            **(namespace or {}),
            "call_func": call_func,
        }
        exec(method_code, g)
        return g["gen_func"]  # noqa

    def add_list_endpoint(self, endpoint: MapdataEndpoint, tag: str):
        """Register ``GET /<model_name_plural>/`` returning all visible objects of the model."""
        view_params = self.common_params(endpoint)
        # bind the actual classes instead of resolving them by __name__ from module globals
        namespace = {"_model": endpoint.model}
        if endpoint.filters:
            namespace.update({"Query": Query, "_filters": endpoint.filters})
            # filter fields are read from the query string
            view_params["filters"] = "Query[_filters]"
        list_func = self._make_endpoint(
            view_params=view_params,
            call_func=mapdata_list_endpoint,
            add_call_params={"model": "_model"},
            namespace=namespace,
        )
        list_func.__name__ = f"{endpoint.model_name}_list"
        self.router.get(f"/{endpoint.model_name_plural}/", summary=f"{endpoint.model_name} list",
                        tags=[f"mapdata-{tag}"], description=schema_description(endpoint.schema),
                        response={200: list[endpoint.schema],
                                  **(validate_responses if endpoint.filters else {}),
                                  **auth_responses})(
            api_etag()(list_func)
        )

    def add_by_id_endpoint(self, endpoint: MapdataEndpoint, tag: str):
        """Register ``GET /<model_name_plural>/{<model_name>_id}/`` returning one object."""
        view_params = self.common_params(endpoint)
        id_field = f"{endpoint.model_name}_id"
        view_params[id_field] = "PositiveInt"
        by_id_func = self._make_endpoint(
            view_params=view_params,
            call_func=mapdata_retrieve_endpoint,
            # the path parameter is forwarded as the primary-key lookup
            add_call_params={"model": "_model", "pk": id_field},
            namespace={"_model": endpoint.model, "PositiveInt": PositiveInt},
        )
        by_id_func.__name__ = f"{endpoint.model_name}_by_id"
        self.router.get(f'/{endpoint.model_name_plural}/{{{id_field}}}/', summary=f"{endpoint.model_name} by ID",
                        tags=[f"mapdata-{tag}"], description=schema_description(endpoint.schema),
                        response={200: endpoint.schema, **API404.dict(), **auth_responses})(
            api_etag()(by_id_func)
        )
class LevelFilters(ByGroupFilter, ByOnTopOfFilter):
    # filter schema of the level list endpoint: combines the group and on-top-of filters
    ...
class SpaceFilters(ByGroupFilter, LevelGeometryFilter):
    # filter schema of the space list endpoint: combines the group and level-geometry filters
    ...
class AreaFilters(ByGroupFilter, SpaceGeometryFilter):
    # filter schema of the area list endpoint: combines the group and space-geometry filters
    ...
# Endpoint definitions grouped by tag; the endpoints of each group are registered under the
# OpenAPI tag "mapdata-<tag>", in the order listed here.
mapdata_endpoints: dict[str, list[MapdataEndpoint]] = {
    "root": [
        MapdataEndpoint(model=Level, schema=LevelSchema, filters=LevelFilters),
        MapdataEndpoint(model=LocationGroup, schema=LocationGroupSchema, filters=ByCategoryFilter),
        MapdataEndpoint(model=LocationGroupCategory, schema=LocationGroupCategorySchema),
        MapdataEndpoint(model=LocationRedirect, schema=LocationRedirectSchema),
        MapdataEndpoint(model=Source, schema=SourceSchema),
        MapdataEndpoint(model=AccessRestriction, schema=AccessRestrictionSchema),
        MapdataEndpoint(model=AccessRestrictionGroup, schema=AccessRestrictionGroupSchema),
        MapdataEndpoint(model=DynamicLocation, schema=DynamicLocationSchema),
        MapdataEndpoint(model=LabelSettings, schema=LabelSettingsSchema),
        MapdataEndpoint(model=DataOverlay, schema=DataOverlaySchema),
        MapdataEndpoint(model=DataOverlayFeature, schema=DataOverlayFeatureSchema, filters=ByOverlayFilter),
    ],
    "level": [
        MapdataEndpoint(model=Building, schema=BuildingSchema, filters=LevelGeometryFilter),
        MapdataEndpoint(model=Space, schema=SpaceSchema, filters=SpaceFilters),
        MapdataEndpoint(model=Door, schema=DoorSchema, filters=LevelGeometryFilter),
    ],
    "space": [
        MapdataEndpoint(model=Hole, schema=HoleSchema, filters=SpaceGeometryFilter),
        MapdataEndpoint(model=Area, schema=AreaSchema, filters=AreaFilters),
        MapdataEndpoint(model=Stair, schema=StairSchema, filters=SpaceGeometryFilter),
        MapdataEndpoint(model=Ramp, schema=RampSchema, filters=SpaceGeometryFilter),
        MapdataEndpoint(model=Obstacle, schema=ObstacleSchema, filters=SpaceGeometryFilter),
        MapdataEndpoint(model=LineObstacle, schema=LineObstacleSchema, filters=SpaceGeometryFilter),
        MapdataEndpoint(model=Column, schema=ColumnSchema, filters=SpaceGeometryFilter),
        MapdataEndpoint(model=POI, schema=POISchema, filters=SpaceGeometryFilter),
        MapdataEndpoint(model=LeaveDescription, schema=LeaveDescriptionSchema, filters=BySpaceFilter),
        MapdataEndpoint(model=CrossDescription, schema=CrossDescriptionSchema, filters=BySpaceFilter),
    ],
}

# register a list endpoint and a by-ID endpoint for every definition above on the mapdata router
MapdataAPIBuilder(router=mapdata_api_router).build_all_endpoints(mapdata_endpoints)