diff --git a/src/c3nav/api/exceptions.py b/src/c3nav/api/exceptions.py index f620bb8b..2c81b8e5 100644 --- a/src/c3nav/api/exceptions.py +++ b/src/c3nav/api/exceptions.py @@ -1,3 +1,6 @@ +from c3nav.api.schema import APIErrorSchema + + class CustomAPIException(Exception): status_code = 400 detail = "" @@ -9,10 +12,9 @@ class CustomAPIException(Exception): def get_response(self, api, request): return api.create_response(request, {"detail": self.detail}, status=self.status_code) - -class API404(CustomAPIException): - status_code = 404 - detail = "Object not found." + @classmethod + def dict(cls): + return {cls.status_code: APIErrorSchema} class APIUnauthorized(CustomAPIException): @@ -29,3 +31,17 @@ class APIPermissionDenied(CustomAPIException): status_code = 403 detail = "Permission denied." + +class API404(CustomAPIException): + status_code = 404 + detail = "Object not found." + + +class APIConflict(CustomAPIException): + status_code = 409 + detail = "Conflict" + + +class APIRequestValidationFailed(CustomAPIException): + status_code = 422 + detail = "Bad request body." diff --git a/src/c3nav/api/newauth.py b/src/c3nav/api/newauth.py index 774acccc..42f8c1fd 100644 --- a/src/c3nav/api/newauth.py +++ b/src/c3nav/api/newauth.py @@ -1,5 +1,7 @@ +from collections import namedtuple from importlib import import_module +from django.contrib.auth import get_user as auth_get_user from django.contrib.auth.models import AnonymousUser from django.db.models import Q from ninja.security import HttpBearer @@ -9,9 +11,7 @@ from c3nav.api.exceptions import APIPermissionDenied, APITokenInvalid from c3nav.api.schema import APIErrorSchema from c3nav.control.models import UserPermissions - -class InvalidToken(Exception): - pass +FakeRequest = namedtuple('FakeRequest', ('session', )) class BearerAuth(HttpBearer): @@ -28,7 +28,8 @@ class BearerAuth(HttpBearer): elif token.startswith("session:"): session = self.SessionStore(token.removeprefix("session:")) # todo: ApiTokenInvalid? - return session.user + user = auth_get_user(FakeRequest(session=session)) + return user elif token.startswith("secret:"): try: user_perms = UserPermissions.objects.filter( @@ -51,6 +52,5 @@ class BearerAuth(HttpBearer): return user -auth_responses = {401: APIErrorSchema} -auth_permission_responses = {401: APIErrorSchema, 403: APIErrorSchema} - +auth_responses = {400: APIErrorSchema, 401: APIErrorSchema} +auth_permission_responses = {400: APIErrorSchema, 401: APIErrorSchema, 403: APIErrorSchema} diff --git a/src/c3nav/mesh/newapi.py b/src/c3nav/mesh/newapi.py index 04aef977..f12b7448 100644 --- a/src/c3nav/mesh/newapi.py +++ b/src/c3nav/mesh/newapi.py @@ -1,14 +1,13 @@ -import base64 from datetime import datetime -from django.db import transaction +from django.db import IntegrityError, transaction from ninja import Field as APIField -from ninja import ModelSchema from ninja import Router as APIRouter from ninja import Schema, UploadedFile from ninja.pagination import paginate from pydantic import validator +from c3nav.api.exceptions import APIConflict, APIRequestValidationFailed from c3nav.api.newauth import BearerAuth, auth_permission_responses, auth_responses from c3nav.mesh.dataformats import BoardType from c3nav.mesh.messages import ChipType @@ -30,6 +29,11 @@ class FirmwareBuildSchema(Schema): # todo: do this in model? idk return ChipType(obj.chip) + @staticmethod + def resolve_boards(obj): + print(obj.boards) + return obj.boards + class FirmwareSchema(Schema): id: int @@ -66,52 +70,20 @@ def firmware_detail(request, firmware_id: int): return 404, {"detail": "firmware not found"} -class Base64Bytes(bytes): - @classmethod - def __get_validators__(cls): - # one or more validators may be yielded which will be called in the - # order to validate the input, each validator will receive as an input - # the value returned from the previous validator - yield cls.validate - - @classmethod - def __modify_schema__(cls, field_schema): - # __modify_schema__ should mutate the dict it receives in place, - # the returned value will be ignored - field_schema.update( - # simplified regex here for brevity, see the wikipedia link above - pattern='^[A-Z]{1,2}[0-9][A-Z0-9]? ?[0-9][A-Z]{2}$', - # some example postcodes - examples=['SP11 9DG', 'w1j7bu'], - ) - - @classmethod - def validate(cls, v): - if not isinstance(v, str): - raise TypeError('string required') - return cls(base64.b64decode(v.encode("ascii"))) - - def __repr__(self): - return f'PostCode({super().__repr__()})' - - class UploadFirmwareBuildSchema(Schema): variant: str = APIField(..., example="c3uart") chip: ChipType = APIField(..., example=ChipType.ESP32_C3.name) sha256_hash: str = APIField(..., regex=r"^[0-9a-f]{64}$") boards: list[BoardType] = APIField(..., example=[BoardType.C3NAV_LOCATION_PCB_REV_0_2.name, ]) - binary: bytes = APIField(..., example="base64", contentEncoding="base64") - - @validator('binary') - def get_binary_base64(cls, binary): - return base64.b64decode(binary.encode()) + project_description: dict = APIField(..., title='project_description.json contents') + uploaded_filename: str = APIField(..., example="firmware.bin") class UploadFirmwareSchema(Schema): project_name: str = APIField(..., example="c3nav_positioning") version: str = APIField(..., example="499837d-dirty") idf_version: str = APIField(..., example="v5.1-476-g3187b8b326") - builds: list[UploadFirmwareBuildSchema] = APIField(..., min_items=1) + builds: list[UploadFirmwareBuildSchema] = APIField(..., min_items=1, unique_items=True) @validator('builds') def builds_variants_must_be_unique(cls, builds): @@ -121,6 +93,51 @@ class UploadFirmwareSchema(Schema): @api_router.post('/firmwares/upload', summary="Upload firmware", auth=BearerAuth(superuser=True), - response={200: FirmwareSchema, **auth_permission_responses}) -def firmware_upload(request, firmware_data: UploadFirmwareSchema): - raise NotImplementedError + description="your OpenAPI viewer might not show it: firmware_data is UploadFirmwareSchema as json", + response={200: FirmwareSchema, **auth_permission_responses, **APIConflict.dict()}) +def firmware_upload(request, firmware_data: UploadFirmwareSchema, binary_files: list[UploadedFile]): + binary_files_by_name = {binary_file.name: binary_file for binary_file in binary_files} + if len([binary_file.name for binary_file in binary_files]) != len(binary_files_by_name): + raise APIRequestValidationFailed("Filenames of uploaded binary files must be unique.") + + build_filenames = [build_data.uploaded_filename for build_data in firmware_data.builds] + if len(build_filenames) != len(set(build_filenames)): + raise APIRequestValidationFailed("Builds need to refer to different unique binary file names.") + + if set(binary_files_by_name) != set(build_filenames): + raise APIRequestValidationFailed("All uploaded binary files need to be refered to by one build.") + + try: + with transaction.atomic(): + version = FirmwareVersion.objects.create( + project_name=firmware_data.project_name, + version=firmware_data.version, + idf_version=firmware_data.idf_version, + uploader=request.auth, + ) + + for build_data in firmware_data.builds: + # if bin_file.size > 4 * 1024 * 1024: + # raise ValueError # todo: better error + + # h = hashlib.sha256() + # h.update(build_data.binary) + # sha256_bin_file = h.hexdigest() # todo: verify sha256 correctly + # + # if sha256_bin_file != build_data.sha256_hash: + # raise ValueError + + build = version.builds.create( + variant=build_data.variant, + chip=build_data.chip, + sha256_hash=build_data.sha256_hash, + project_description=build_data.project_description, + binary=binary_files_by_name[build_data.uploaded_filename], + ) + + for board in build_data.boards: + build.firmwarebuildboard_set.create(board=board) + except IntegrityError: + raise APIConflict('Firmware version already exists.') + + return version