From b2aa76ba2ddceb2e6c2a8a41cea863ad6e3bc85b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laura=20Kl=C3=BCnder?= Date: Sat, 18 Nov 2023 21:29:35 +0100 Subject: [PATCH] update django-ninja, including pydantic v2 and add provisional level api --- src/c3nav/api/newauth.py | 5 +- src/c3nav/api/schema.py | 10 + src/c3nav/api/urls.py | 15 +- src/c3nav/api/utils.py | 45 ++-- .../management/commands/importgeojson.py | 199 ++++++++++++++++++ src/c3nav/mapdata/models/base.py | 2 +- src/c3nav/mapdata/newapi/__init__.py | 0 src/c3nav/mapdata/newapi/endpoints.py | 39 ++++ src/c3nav/mapdata/schemas/__init__.py | 0 src/c3nav/mapdata/schemas/filters.py | 16 ++ src/c3nav/mapdata/schemas/model_base.py | 122 +++++++++++ src/c3nav/mapdata/schemas/models.py | 31 +++ src/c3nav/mapdata/schemas/responses.py | 6 + src/c3nav/mesh/baseformats.py | 4 +- src/c3nav/mesh/dataformats.py | 4 +- src/c3nav/mesh/messages.py | 4 +- src/c3nav/mesh/models.py | 4 +- src/c3nav/mesh/newapi.py | 61 +++--- src/c3nav/routing/api.py | 2 +- src/requirements/production.txt | 2 +- 20 files changed, 510 insertions(+), 61 deletions(-) create mode 100644 src/c3nav/mapdata/management/commands/importgeojson.py create mode 100644 src/c3nav/mapdata/newapi/__init__.py create mode 100644 src/c3nav/mapdata/newapi/endpoints.py create mode 100644 src/c3nav/mapdata/schemas/__init__.py create mode 100644 src/c3nav/mapdata/schemas/filters.py create mode 100644 src/c3nav/mapdata/schemas/model_base.py create mode 100644 src/c3nav/mapdata/schemas/models.py create mode 100644 src/c3nav/mapdata/schemas/responses.py diff --git a/src/c3nav/api/newauth.py b/src/c3nav/api/newauth.py index 42f8c1fd..6c04692b 100644 --- a/src/c3nav/api/newauth.py +++ b/src/c3nav/api/newauth.py @@ -52,5 +52,6 @@ class BearerAuth(HttpBearer): return user -auth_responses = {400: APIErrorSchema, 401: APIErrorSchema} -auth_permission_responses = {400: APIErrorSchema, 401: APIErrorSchema, 403: APIErrorSchema} +validate_responses = {422: APIErrorSchema,} +auth_responses = {401: APIErrorSchema,} +auth_permission_responses = {401: APIErrorSchema, 403: APIErrorSchema,} diff --git a/src/c3nav/api/schema.py b/src/c3nav/api/schema.py index a3ba27ab..e6e7ca2c 100644 --- a/src/c3nav/api/schema.py +++ b/src/c3nav/api/schema.py @@ -1,5 +1,15 @@ +from typing import Literal + from ninja import Schema +from pydantic import Field as APIField class APIErrorSchema(Schema): detail: str + + +class PolygonSchema(Schema): + type: Literal["Polygon"] + coordinates: list[list[tuple[float, float]]] = APIField( + example=[[1.5, 1.5], [1.5, 2.5], [2.5, 2.5], [2.5, 2.5]] + ) diff --git a/src/c3nav/api/urls.py b/src/c3nav/api/urls.py index 8f7c83ab..56d1d879 100644 --- a/src/c3nav/api/urls.py +++ b/src/c3nav/api/urls.py @@ -2,9 +2,11 @@ import inspect import re from collections import OrderedDict +from django.conf import settings from django.urls import include, path, re_path from django.utils.functional import cached_property from ninja import NinjaAPI +from ninja.schema import NinjaGenerateJsonSchema from rest_framework.generics import GenericAPIView from rest_framework.response import Response from rest_framework.routers import SimpleRouter @@ -20,16 +22,17 @@ from c3nav.mapdata.api import (AccessRestrictionGroupViewSet, AccessRestrictionV LocationBySlugViewSet, LocationGroupCategoryViewSet, LocationGroupViewSet, LocationViewSet, MapViewSet, ObstacleViewSet, POIViewSet, RampViewSet, SourceViewSet, SpaceViewSet, StairViewSet, UpdatesViewSet) +from c3nav.mapdata.newapi.endpoints import map_api_router from c3nav.mapdata.utils.user import can_access_editor from c3nav.mesh.api import FirmwareViewSet -from c3nav.mesh.newapi import api_router as mesh_api_router +from c3nav.mesh.newapi import mesh_api_router from c3nav.routing.api import RoutingViewSet ninja_api = NinjaAPI( title="c3nav API", version="v2", docs_url="/", - auth=BearerAuth(), + auth=(None if settings.DEBUG else BearerAuth()), ) @@ -38,7 +41,15 @@ def on_invalid_token(request, exc): return ninja_api.create_response(request, {"detail": exc.detail}, status=exc.status_code) +# ugly hack: remove schema from the end of definition names +orig_normalize_name = NinjaGenerateJsonSchema.normalize_name +def wrap_normalize_name(self, name: str): # noqa + return orig_normalize_name(self, name).removesuffix('Schema') +NinjaGenerateJsonSchema.normalize_name = wrap_normalize_name # noqa + + ninja_api.add_router("/auth/", auth_api_router) +ninja_api.add_router("/map/", map_api_router) ninja_api.add_router("/mesh/", mesh_api_router) router = SimpleRouter() diff --git a/src/c3nav/api/utils.py b/src/c3nav/api/utils.py index edf90ceb..973d2866 100644 --- a/src/c3nav/api/utils.py +++ b/src/c3nav/api/utils.py @@ -1,7 +1,9 @@ -from enum import EnumMeta -from typing import Any, Callable, Iterator, Optional, cast +from typing import Annotated, Any, Type -from pydantic.fields import ModelField +import annotated_types +from pydantic import AfterValidator, GetCoreSchemaHandler, GetJsonSchemaHandler, PlainSerializer +from pydantic.json_schema import JsonSchemaValue, WithJsonSchema +from pydantic_core import CoreSchema, core_schema from rest_framework.exceptions import ParseError @@ -18,23 +20,34 @@ def get_api_post_data(request): class EnumSchemaByNameMixin: @classmethod - def __modify_schema__(cls, field_schema: dict[str, Any], field: Optional[ModelField]) -> None: - if field is None: - return - field_schema["enum"] = list(cast(EnumMeta, field.type_).__members__.keys()) - field_schema["type"] = "string" + def __get_pydantic_json_schema__( + cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler + ) -> JsonSchemaValue: + json_schema = handler(core_schema) + json_schema = handler.resolve_ref_schema(json_schema) + json_schema["enum"] = [m.name for m in cls] + json_schema["type"] = "string" + return json_schema @classmethod - def _validate(cls, v: Any, field: ModelField) -> Any: + def __get_pydantic_core_schema__( + cls, source: Type[Any], handler: GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + return core_schema.no_info_after_validator_function( + cls.validate, + core_schema.any_schema(), + serialization=core_schema.plain_serializer_function_ser_schema(lambda x: x.name), + ) + + @classmethod + def validate(cls, v: int): if isinstance(v, cls): - # it's already the object, so it's going to json, return string - return v.name + return v try: - # it's a string, so it's coming from json, return object return cls[v] except KeyError: - raise ValueError(f"Invalid value for {cls}: `{v}`") + pass + return cls(v) - @classmethod - def __get_validators__(cls) -> Iterator[Callable[..., Any]]: - yield cls._validate + +NonEmptyStr = Annotated[str, annotated_types.MinLen(1)] diff --git a/src/c3nav/mapdata/management/commands/importgeojson.py b/src/c3nav/mapdata/management/commands/importgeojson.py new file mode 100644 index 00000000..b12e06da --- /dev/null +++ b/src/c3nav/mapdata/management/commands/importgeojson.py @@ -0,0 +1,199 @@ +import argparse +import logging +import re +from xml.etree import ElementTree + +from django.core.management.base import BaseCommand, CommandError +from django.utils.translation import gettext_lazy as _ +from shapely.affinity import scale, translate +from shapely.geometry import Polygon + +from c3nav.mapdata.models import Area, MapUpdate, Obstacle, Space +from c3nav.mapdata.utils.cache.changes import changed_geometries + + +class Command(BaseCommand): + help = 'render the map' + + @staticmethod + def space_value(value): + try: + space = Space.objects.get(pk=value) + except Space.DoesNotExist: + raise argparse.ArgumentTypeError( + _('unknown space') + ) + return space + + def add_arguments(self, parser): + parser.add_argument('svgfile', type=argparse.FileType('r'), help=_('svg file to import')) + parser.add_argument('name', type=str, help=_('name of the import')) + parser.add_argument('--type', type=str, required=True, choices=('buildings', 'areas', 'obstacles'), + help=_('type of objects to create')) + parser.add_argument('--space', type=self.space_value, required=True, + help=_('space to add the objects to')) + parser.add_argument('--minx', type=float, required=True, + help=_('minimum x coordinate, everthing left of it will be cropped')) + parser.add_argument('--miny', type=float, required=True, + help=_('minimum y coordinate, everthing below it will be cropped')) + parser.add_argument('--maxx', type=float, required=True, + help=_('maximum x coordinate, everthing right of it will be cropped')) + parser.add_argument('--maxy', type=float, required=True, + help=_('maximum y coordinate, everthing above it will be cropped')) + + @staticmethod + def parse_svg_data(data): + first = False + + last_point = (0, 0) + last_end_point = None + + done_subpaths = [] + current_subpath = [] + while data: + data = data.lstrip().replace(',', ' ') + command = data[0] + if first and command not in 'Mm': + raise ValueError('path data has to start with moveto command.') + data = data[1:].lstrip() + first = False + + numbers = [] + while True: + match = re.match(r'^-?[0-9]+(\.[0-9]+)?(e-?[0-9]+)?', data) + if match is None: + break + numbers.append(float(match.group(0))) + data = data[len(match.group(0)):].lstrip() + + relative = command.islower() + if command in 'Mm': + if not len(numbers) or len(numbers) % 2: + raise ValueError('Invalid number of arguments for moveto command!') + numbers = iter(numbers) + first = True + for x, y in zip(numbers, numbers): + if relative: + x, y = last_point[0] + x, last_point[1] + y + if first: + first = False + if current_subpath: + done_subpaths.append(current_subpath) + last_end_point = current_subpath[-1] + current_subpath = [] + current_subpath.append((x, y)) + last_point = (x, y) + + elif command in 'Ll': + if not len(numbers) or len(numbers) % 2: + raise ValueError('Invalid number of arguments for lineto command!') + numbers = iter(numbers) + for x, y in zip(numbers, numbers): + if relative: + x, y = last_point[0] + x, last_point[1] + y + if not current_subpath: + current_subpath.append(last_end_point) + current_subpath.append((x, y)) + last_point = (x, y) + + elif command in 'Hh': + if not len(numbers): + raise ValueError('Invalid number of arguments for horizontal lineto command!') + y = last_point[1] + for x in numbers: + if relative: + x = last_point[0] + x + if not current_subpath: + current_subpath.append(last_end_point) + current_subpath.append((x, y)) + last_point = (x, y) + + elif command in 'Vv': + if not len(numbers): + raise ValueError('Invalid number of arguments for vertical lineto command!') + x = last_point[0] + for y in numbers: + if relative: + y = last_point[1] + y + if not current_subpath: + current_subpath.append(last_end_point) + current_subpath.append((x, y)) + last_point = (x, y) + + elif command in 'Zz': + if numbers: + raise ValueError('Invalid number of arguments for closepath command!') + current_subpath.append(current_subpath[0]) + done_subpaths.append(current_subpath) + last_end_point = current_subpath[-1] + current_subpath = [] + + else: + raise ValueError('unknown svg command: ' + command) + + if current_subpath: + done_subpaths.append(current_subpath) + return done_subpaths + + def handle(self, *args, **options): + minx = options['minx'] + miny = options['miny'] + maxx = options['maxx'] + maxy = options['maxy'] + + if minx >= maxx: + raise CommandError(_('minx has to be lower than maxx')) + if miny >= maxy: + raise CommandError(_('miny has to be lower than maxy')) + + width = maxx-minx + height = maxy-miny + + model = {'areas': Area, 'obstacles': Obstacle}[options['type']] + + namespaces = {'svg': 'http://www.w3.org/2000/svg'} + + svg = ElementTree.fromstring(options['svgfile'].read()) + svg_width = float(svg.attrib['width']) + svg_height = float(svg.attrib['height']) + svg_viewbox = svg.attrib.get('viewBox') + + if svg_viewbox: + offset_x, offset_y, svg_width, svg_height = [float(i) for i in svg_viewbox.split(' ')] + else: + offset_x, offset_y = 0, 0 + + for element in svg.findall('.//svg:clipPath/..', namespaces): + for clippath in element.findall('./svg:clipPath', namespaces): + element.remove(clippath) + + for element in svg.findall('.//svg:symbol/..', namespaces): + for clippath in element.findall('./svg:symbol', namespaces): + element.remove(clippath) + + if svg.findall('.//*[@transform]'): + raise CommandError(_('svg contains transform attributes. Use inkscape apply transforms.')) + + if model.objects.filter(space=options['space'], import_tag=options['name']).exists(): + raise CommandError(_('objects with this import tag already exist in this space.')) + + with MapUpdate.lock(): + changed_geometries.reset() + for path in svg.findall('.//svg:path', namespaces): + for polygon in self.parse_svg_data(path.attrib['d']): + if len(polygon) < 3: + continue + polygon = Polygon(polygon) + polygon = translate(polygon, xoff=-offset_x, yoff=-offset_y) + polygon = scale(polygon, xfact=1, yfact=-1, origin=(0, svg_height/2)) + polygon = scale(polygon, xfact=width / svg_width, yfact=height / svg_height, origin=(0, 0)) + polygon = translate(polygon, xoff=minx, yoff=miny) + obj = model(geometry=polygon, space=options['space'], import_tag=options['name']) + obj.save() + MapUpdate.objects.create(type='importsvg') + + logger = logging.getLogger('c3nav') + logger.info('Imported, map update created.') + logger.info('Next step: go into the shell and edit them using ' + '%s.objects.filter(space_id=%r, import_tag=%r)' % + (model.__name__, options['space'].pk, options['name'])) diff --git a/src/c3nav/mapdata/models/base.py b/src/c3nav/mapdata/models/base.py index 66b5ecac..24764dc2 100644 --- a/src/c3nav/mapdata/models/base.py +++ b/src/c3nav/mapdata/models/base.py @@ -74,7 +74,7 @@ class TitledMixin(SerializableMixin, models.Model): result = super()._serialize(detailed=detailed, **kwargs) if detailed: result['titles'] = self.titles - result['title'] = self.title + result['title'] = str(self.title) return result @property diff --git a/src/c3nav/mapdata/newapi/__init__.py b/src/c3nav/mapdata/newapi/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/c3nav/mapdata/newapi/endpoints.py b/src/c3nav/mapdata/newapi/endpoints.py new file mode 100644 index 00000000..64a400b0 --- /dev/null +++ b/src/c3nav/mapdata/newapi/endpoints.py @@ -0,0 +1,39 @@ +from ninja import Query +from ninja import Router as APIRouter +from ninja.pagination import paginate + +from c3nav.api.exceptions import API404 +from c3nav.api.newauth import auth_responses +from c3nav.mapdata.models import Level, Source +from c3nav.mapdata.schemas.filters import LevelFilters +from c3nav.mapdata.schemas.models import LevelSchema +from c3nav.mapdata.schemas.responses import BoundsSchema + +map_api_router = APIRouter(tags=["map"]) + + +@map_api_router.get('/bounds/', summary="Get map boundaries", + response={200: BoundsSchema, **auth_responses}) +def bounds(request): + return { + "bounds": Source.max_bounds(), + } + + +@map_api_router.get('/levels/', response=list[LevelSchema], + summary="List available levels") +@paginate +def levels_list(request, filters: Query[LevelFilters]): + # todo: access, caching, filtering, etc + return Level.objects.all() + + +@map_api_router.get('/levels/{level_id}/', response=LevelSchema, + summary="List available levels") +def level_detail(request, level_id: int): + # todo: access, caching, filtering, etc + try: + level = Level.objects.get(id=level_id) + except Level.DoesNotExist: + raise API404("level not found") + return level diff --git a/src/c3nav/mapdata/schemas/__init__.py b/src/c3nav/mapdata/schemas/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/c3nav/mapdata/schemas/filters.py b/src/c3nav/mapdata/schemas/filters.py new file mode 100644 index 00000000..152f3ea4 --- /dev/null +++ b/src/c3nav/mapdata/schemas/filters.py @@ -0,0 +1,16 @@ +from typing import Literal, Optional + +from ninja import Schema +from pydantic import Field as APIField + + +class LevelFilters(Schema): + on_top_of: Optional[Literal["null"] | int] = APIField( + None, + name='filter by on top of level ID (or "null")', + description='if set, only levels on top of this level (or "null" for no level) will be shown' + ) + group: Optional[int] = APIField( + None, + name="filter by location group" + ) diff --git a/src/c3nav/mapdata/schemas/model_base.py b/src/c3nav/mapdata/schemas/model_base.py new file mode 100644 index 00000000..b05301d6 --- /dev/null +++ b/src/c3nav/mapdata/schemas/model_base.py @@ -0,0 +1,122 @@ +from functools import cached_property +from typing import Annotated, Any, Optional + +import annotated_types +from ninja import Schema +from pydantic import Field as APIField +from pydantic import PositiveInt, model_validator +from pydantic.functional_validators import ModelWrapValidatorHandler +from pydantic.utils import GetterDict +from pydantic_core.core_schema import ValidationInfo + +from c3nav.api.utils import NonEmptyStr + + +class DjangoModelSchema(Schema): + id: PositiveInt = APIField( + title="ID", + ) + + +class SerializableSchema(Schema): + @model_validator(mode="wrap") # noqa + @classmethod + def _run_root_validator(cls, values: Any, handler: ModelWrapValidatorHandler[Schema], info: ValidationInfo) -> Any: + """ overwriting this, we need to call serialize to get the correct data """ + values = values.serialize() + return handler(values) + + +class LocationSlugSchema(Schema): + slug: NonEmptyStr = APIField( + title="location slug", + description="a slug is a unique way to refer to a location across all location types. " + "locations can have a human-readable slug. " + "if it doesn't, this field holds a slug generated based from the location type and ID. " + "this slug will work even if a human-readable slug is defined later. " + "even dynamic locations like coordinates have a slug.", + ) + + +class AccessRestrictionSchema(Schema): + access_restriction: Optional[PositiveInt] = APIField( + default=None, + title="access restriction ID", + ) + + +class TitledSchema(Schema): + titles: dict[NonEmptyStr, NonEmptyStr] = APIField( + title="title (all languages)", + description="property names are the ISO-language code. languages may be missing.", + example={ + "en": "Title", + "de": "Titel", + } + ) + title: NonEmptyStr = APIField( + title="title (preferred language)", + description="preferred language based on the Accept-Language header." + ) + + +class LocationSchema(AccessRestrictionSchema, TitledSchema, LocationSlugSchema, SerializableSchema): + subtitle: NonEmptyStr = APIField( + title="subtitle (preferred language)", + description="an automatically generated short description for this location. " + "preferred language based on the Accept-Language header." + ) + icon: Optional[NonEmptyStr] = APIField( + default=None, + title="icon name", + description="any material design icon name" + ) + can_search: bool = APIField( + title="can be searched", + ) + can_describe: bool = APIField( + title="can describe locations", + ) + # todo: add_search + + +class LabelSettingsSchema(TitledSchema, Schema): + min_zoom: float = APIField( + title="min zoom", + ) + max_zoom: float = APIField( + title="max zoom", + ) + font_size: PositiveInt = APIField( + title="font size", + ) + + +class SpecificLocationSchema(LocationSchema, DjangoModelSchema): + grid_square: Optional[NonEmptyStr] = APIField( + default=None, + title="grid square", + description="if a grid is defined and this location is within it", + ) + groups: dict[NonEmptyStr, list[PositiveInt] | Optional[PositiveInt]] = APIField( + title="location groups", + description="grouped by location group categories. " + "property names are the names of location groupes. " + "property values are integer, None or a list of integers, see example." + "see location group category endpoint for currently available possibilities." + "categories may be missing if no groups apply.", + example={ + "category_with_single_true": 5, + "other_category_with_single_true": None, + "categoryother_category_with_single_false": [1, 3, 7], + } + ) + label_settings: Optional[LabelSettingsSchema] = APIField( + default=None, + title="label settings", + ) + label_override: Optional[NonEmptyStr] = APIField( + default=None, + title="label override (preferred language)", + description="preferred language based on the Accept-Language header." + ) diff --git a/src/c3nav/mapdata/schemas/models.py b/src/c3nav/mapdata/schemas/models.py new file mode 100644 index 00000000..d511948e --- /dev/null +++ b/src/c3nav/mapdata/schemas/models.py @@ -0,0 +1,31 @@ +from typing import Optional + +from ninja import Schema +from pydantic import Field as APIField +from pydantic import PositiveFloat, PositiveInt + +from c3nav.api.utils import NonEmptyStr +from c3nav.mapdata.schemas.model_base import SpecificLocationSchema + + +class LevelSchema(SpecificLocationSchema): + short_label: NonEmptyStr = APIField( + title="short label (for level selector)", + description="unique among levels", + ) + on_top_of: Optional[PositiveInt] = APIField( + title="on top of level ID", + description="if set, this is not a main level, but it's on top of this other level" + ) + base_altitude: float = APIField( + title="base/default altitude", + ) + default_height: PositiveFloat = APIField( + title="default ceiling height", + ) + door_height: PositiveFloat = APIField( + title="door height", + ) + + class Config(Schema.Config): + title = "Level" diff --git a/src/c3nav/mapdata/schemas/responses.py b/src/c3nav/mapdata/schemas/responses.py new file mode 100644 index 00000000..e75c237b --- /dev/null +++ b/src/c3nav/mapdata/schemas/responses.py @@ -0,0 +1,6 @@ +from ninja import Schema +from pydantic import Field as APIField + + +class BoundsSchema(Schema): + bounds: tuple[tuple[float, float], tuple[float, float]] = APIField(..., example=((-10, -20), (20, 30))) diff --git a/src/c3nav/mesh/baseformats.py b/src/c3nav/mesh/baseformats.py index b64d6a1a..879ca138 100644 --- a/src/c3nav/mesh/baseformats.py +++ b/src/c3nav/mesh/baseformats.py @@ -1,7 +1,8 @@ import re import struct from abc import ABC, abstractmethod -from dataclasses import Field, dataclass, fields as dataclass_fields +from dataclasses import Field, dataclass +from dataclasses import fields as dataclass_fields from typing import Any, Self, Sequence from pydantic import create_model @@ -379,7 +380,6 @@ class StructType: else: raise TypeError('field %s.%s has no format and is no StructType' % (cls.__class__.__name__, name)) - print(cls.__name__, cls._pydantic_fields) cls.schema = create_model(cls.__name__+'Schema', **cls._pydantic_fields) super().__init_subclass__(**kwargs) diff --git a/src/c3nav/mesh/dataformats.py b/src/c3nav/mesh/dataformats.py index 5194ea7d..70db55b4 100644 --- a/src/c3nav/mesh/dataformats.py +++ b/src/c3nav/mesh/dataformats.py @@ -4,8 +4,8 @@ from enum import IntEnum, unique from typing import BinaryIO, Self from c3nav.api.utils import EnumSchemaByNameMixin -from c3nav.mesh.baseformats import (BoolFormat, EnumFormat, FixedHexFormat, FixedStrFormat, SimpleFormat, StructType, - VarArrayFormat, TwoNibblesEnumFormat, ChipRevFormat, SimpleConstFormat) +from c3nav.mesh.baseformats import (BoolFormat, ChipRevFormat, EnumFormat, FixedHexFormat, FixedStrFormat, + SimpleConstFormat, SimpleFormat, StructType, TwoNibblesEnumFormat, VarArrayFormat) class MacAddressFormat(FixedHexFormat): diff --git a/src/c3nav/mesh/messages.py b/src/c3nav/mesh/messages.py index 8c9d4f64..fbbdbaf9 100644 --- a/src/c3nav/mesh/messages.py +++ b/src/c3nav/mesh/messages.py @@ -7,8 +7,8 @@ from channels.db import database_sync_to_async from c3nav.mesh.baseformats import (BoolFormat, EnumFormat, FixedStrFormat, SimpleFormat, StructType, VarArrayFormat, VarBytesFormat, VarStrFormat, normalize_name) -from c3nav.mesh.dataformats import (BoardConfig, FirmwareAppDescription, MacAddressesListFormat, MacAddressFormat, - RangeResultItem, RawFTMEntry, ChipType) +from c3nav.mesh.dataformats import (BoardConfig, ChipType, FirmwareAppDescription, MacAddressesListFormat, + MacAddressFormat, RangeResultItem, RawFTMEntry) from c3nav.mesh.utils import MESH_ALL_UPLINKS_GROUP MESH_ROOT_ADDRESS = '00:00:00:00:00:00' diff --git a/src/c3nav/mesh/models.py b/src/c3nav/mesh/models.py index 62a793b8..0a501589 100644 --- a/src/c3nav/mesh/models.py +++ b/src/c3nav/mesh/models.py @@ -371,10 +371,10 @@ class FirmwareBuild(models.Model): unique_together = [ ('version', 'variant'), ] - @property def boards(self): - return {BoardType[board.board] for board in self.firmwarebuildboard_set.all()} + return {BoardType[board.board] for board in self.firmwarebuildboard_set.all() + if board.board in BoardType._member_names_} @property def chip_type(self) -> ChipType: diff --git a/src/c3nav/mesh/newapi.py b/src/c3nav/mesh/newapi.py index 55427715..2405e7cf 100644 --- a/src/c3nav/mesh/newapi.py +++ b/src/c3nav/mesh/newapi.py @@ -1,44 +1,43 @@ from datetime import datetime +from pathlib import Path from django.db import IntegrityError, transaction from ninja import Field as APIField from ninja import Router as APIRouter from ninja import Schema, UploadedFile from ninja.pagination import paginate -from pydantic import validator +from pydantic import PositiveInt, field_validator -from c3nav.api.exceptions import APIConflict, APIRequestValidationFailed, API404 -from c3nav.api.newauth import BearerAuth, auth_permission_responses, auth_responses -from c3nav.mesh.dataformats import BoardType, FirmwareImage, ChipType -from c3nav.mesh.models import FirmwareVersion, FirmwareBuild +from c3nav.api.exceptions import API404, APIConflict, APIRequestValidationFailed +from c3nav.api.newauth import BearerAuth, auth_permission_responses, auth_responses, validate_responses +from c3nav.mesh.dataformats import BoardType, ChipType, FirmwareImage +from c3nav.mesh.models import FirmwareBuild, FirmwareVersion -api_router = APIRouter(tags=["mesh"]) +mesh_api_router = APIRouter(tags=["mesh"]) class FirmwareBuildSchema(Schema): - id: int + id: PositiveInt variant: str = APIField(..., example="c3uart") chip: ChipType = APIField(..., example=ChipType.ESP32_C3.name) - sha256_hash: str = APIField(..., regex=r"^[0-9a-f]{64}$") - url: str = APIField(..., alias="binary", example="/media/firmware/012345/firmware.bin") + sha256_hash: str = APIField(..., pattern=r"^[0-9a-f]{64}$") + url: str = APIField(..., alias="binary", example="/media/firmware/012345/firmware.bin") # todo: downlaod differently? # todo: should not be none, but parse errors boards: list[BoardType] = APIField(None, example=[BoardType.C3NAV_LOCATION_PCB_REV_0_2.name, ]) - @staticmethod - def resolve_chip(obj): - # todo: do this in model? idk - return ChipType(obj.chip) + class Config(Schema.Config): + pass class FirmwareSchema(Schema): - id: int + id: PositiveInt project_name: str = APIField(..., example="c3nav_positioning") version: str = APIField(..., example="499837d-dirty") idf_version: str = APIField(..., example="v5.1-476-g3187b8b326") created: datetime builds: list[FirmwareBuildSchema] - @validator('builds') + @field_validator('builds') def builds_variants_must_be_unique(cls, builds): if len(set(build.variant for build in builds)) != len(builds): raise ValueError("builds must have unique variant identifiers") @@ -49,15 +48,15 @@ class Error(Schema): detail: str -@api_router.get('/firmwares/', summary="List available firmwares", - response={200: list[FirmwareSchema], **auth_responses}) +@mesh_api_router.get('/firmwares/', summary="List available firmwares", + response={200: list[FirmwareSchema], **validate_responses, **auth_responses}) @paginate def firmware_list(request): return FirmwareVersion.objects.all() -@api_router.get('/firmwares/{firmware_id}/', summary="Get specific firmware", - response={200: FirmwareSchema, **API404.dict(), **auth_responses}) +@mesh_api_router.get('/firmwares/{firmware_id}/', summary="Get specific firmware", + response={200: FirmwareSchema, **API404.dict(), **auth_responses}) def firmware_detail(request, firmware_id: int): try: return FirmwareVersion.objects.get(id=firmware_id) @@ -65,9 +64,9 @@ def firmware_detail(request, firmware_id: int): raise API404("Firmware not found") -@api_router.get('/firmwares/{firmware_id}/{variant}/image_data', - summary="Get header data of firmware build image", - response={200: FirmwareImage.schema, **API404.dict(), **auth_responses}) +@mesh_api_router.get('/firmwares/{firmware_id}/{variant}/image_data', + summary="Get header data of firmware build image", + response={200: FirmwareImage.schema, **API404.dict(), **auth_responses}) def firmware_build_image(request, firmware_id: int, variant: str): try: build = FirmwareBuild.objects.get(version_id=firmware_id, variant=variant) @@ -76,9 +75,9 @@ def firmware_build_image(request, firmware_id: int, variant: str): raise API404("Firmware or firmware build not found") -@api_router.get('/firmwares/{firmware_id}/{variant}/project_description', - summary="Get project description of firmware build", - response={200: dict, **API404.dict(), **auth_responses}) +@mesh_api_router.get('/firmwares/{firmware_id}/{variant}/project_description', + summary="Get project description of firmware build", + response={200: dict, **API404.dict(), **auth_responses}) def firmware_project_description(request, firmware_id: int, variant: str): try: return FirmwareBuild.objects.get(version_id=firmware_id, variant=variant).firmware_description @@ -97,18 +96,20 @@ class UploadFirmwareSchema(Schema): project_name: str = APIField(..., example="c3nav_positioning") version: str = APIField(..., example="499837d-dirty") idf_version: str = APIField(..., example="v5.1-476-g3187b8b326") - builds: list[UploadFirmwareBuildSchema] = APIField(..., min_items=1, unique_items=True) + builds: list[UploadFirmwareBuildSchema] = APIField(..., min_items=1) - @validator('builds') + @field_validator('builds') def builds_variants_must_be_unique(cls, builds): if len(set(build.variant for build in builds)) != len(builds): raise ValueError("builds must have unique variant identifiers") return builds -@api_router.post('/firmwares/upload', summary="Upload firmware", auth=BearerAuth(superuser=True), - description="your OpenAPI viewer might not show it: firmware_data is UploadFirmwareSchema as json", - response={200: FirmwareSchema, **auth_permission_responses, **APIConflict.dict()}) +@mesh_api_router.post( + '/firmwares/upload', summary="Upload firmware", auth=BearerAuth(superuser=True), + description="your OpenAPI viewer might not show it: firmware_data is UploadFirmwareSchema as json", + response={200: FirmwareSchema, **validate_responses, **auth_permission_responses, **APIConflict.dict()} +) def firmware_upload(request, firmware_data: UploadFirmwareSchema, binary_files: list[UploadedFile]): binary_files_by_name = {binary_file.name: binary_file for binary_file in binary_files} if len([binary_file.name for binary_file in binary_files]) != len(binary_files_by_name): diff --git a/src/c3nav/routing/api.py b/src/c3nav/routing/api.py index 17f7d300..fc4fc431 100644 --- a/src/c3nav/routing/api.py +++ b/src/c3nav/routing/api.py @@ -172,8 +172,8 @@ class RoutingViewSet(ViewSet): @action(detail=False) def locate_test(self, request): - from c3nav.mesh.models import MeshNode from c3nav.mesh.messages import MeshMessageType + from c3nav.mesh.models import MeshNode try: node = MeshNode.objects.prefetch_last_messages(MeshMessageType.LOCATE_RANGE_RESULTS).get( address="d4:f9:8d:2d:0d:f1" diff --git a/src/requirements/production.txt b/src/requirements/production.txt index 5a0a87ac..8a8f9670 100644 --- a/src/requirements/production.txt +++ b/src/requirements/production.txt @@ -3,7 +3,7 @@ django-bootstrap3==23.1 django-compressor==4.3.1 csscompressor==0.9.5 djangorestframework==3.14.0 -django-ninja==0.22.2 +django-ninja==1.0.1 django-filter==23.2 shapely==2.0.1 pybind11==2.10.4