add firmware upload API
This commit is contained in:
parent
44b6cc61e6
commit
2d97f9bb87
3 changed files with 85 additions and 52 deletions
|
@ -1,3 +1,6 @@
|
||||||
|
from c3nav.api.schema import APIErrorSchema
|
||||||
|
|
||||||
|
|
||||||
class CustomAPIException(Exception):
|
class CustomAPIException(Exception):
|
||||||
status_code = 400
|
status_code = 400
|
||||||
detail = ""
|
detail = ""
|
||||||
|
@ -9,10 +12,9 @@ class CustomAPIException(Exception):
|
||||||
def get_response(self, api, request):
|
def get_response(self, api, request):
|
||||||
return api.create_response(request, {"detail": self.detail}, status=self.status_code)
|
return api.create_response(request, {"detail": self.detail}, status=self.status_code)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
class API404(CustomAPIException):
|
def dict(cls):
|
||||||
status_code = 404
|
return {cls.status_code: APIErrorSchema}
|
||||||
detail = "Object not found."
|
|
||||||
|
|
||||||
|
|
||||||
class APIUnauthorized(CustomAPIException):
|
class APIUnauthorized(CustomAPIException):
|
||||||
|
@ -29,3 +31,17 @@ class APIPermissionDenied(CustomAPIException):
|
||||||
status_code = 403
|
status_code = 403
|
||||||
detail = "Permission denied."
|
detail = "Permission denied."
|
||||||
|
|
||||||
|
|
||||||
|
class API404(CustomAPIException):
|
||||||
|
status_code = 404
|
||||||
|
detail = "Object not found."
|
||||||
|
|
||||||
|
|
||||||
|
class APIConflict(CustomAPIException):
|
||||||
|
status_code = 409
|
||||||
|
detail = "Conflict"
|
||||||
|
|
||||||
|
|
||||||
|
class APIRequestValidationFailed(CustomAPIException):
|
||||||
|
status_code = 422
|
||||||
|
detail = "Bad request body."
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
|
from collections import namedtuple
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
|
|
||||||
|
from django.contrib.auth import get_user as auth_get_user
|
||||||
from django.contrib.auth.models import AnonymousUser
|
from django.contrib.auth.models import AnonymousUser
|
||||||
from django.db.models import Q
|
from django.db.models import Q
|
||||||
from ninja.security import HttpBearer
|
from ninja.security import HttpBearer
|
||||||
|
@ -9,9 +11,7 @@ from c3nav.api.exceptions import APIPermissionDenied, APITokenInvalid
|
||||||
from c3nav.api.schema import APIErrorSchema
|
from c3nav.api.schema import APIErrorSchema
|
||||||
from c3nav.control.models import UserPermissions
|
from c3nav.control.models import UserPermissions
|
||||||
|
|
||||||
|
FakeRequest = namedtuple('FakeRequest', ('session', ))
|
||||||
class InvalidToken(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class BearerAuth(HttpBearer):
|
class BearerAuth(HttpBearer):
|
||||||
|
@ -28,7 +28,8 @@ class BearerAuth(HttpBearer):
|
||||||
elif token.startswith("session:"):
|
elif token.startswith("session:"):
|
||||||
session = self.SessionStore(token.removeprefix("session:"))
|
session = self.SessionStore(token.removeprefix("session:"))
|
||||||
# todo: ApiTokenInvalid?
|
# todo: ApiTokenInvalid?
|
||||||
return session.user
|
user = auth_get_user(FakeRequest(session=session))
|
||||||
|
return user
|
||||||
elif token.startswith("secret:"):
|
elif token.startswith("secret:"):
|
||||||
try:
|
try:
|
||||||
user_perms = UserPermissions.objects.filter(
|
user_perms = UserPermissions.objects.filter(
|
||||||
|
@ -51,6 +52,5 @@ class BearerAuth(HttpBearer):
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
auth_responses = {401: APIErrorSchema}
|
auth_responses = {400: APIErrorSchema, 401: APIErrorSchema}
|
||||||
auth_permission_responses = {401: APIErrorSchema, 403: APIErrorSchema}
|
auth_permission_responses = {400: APIErrorSchema, 401: APIErrorSchema, 403: APIErrorSchema}
|
||||||
|
|
||||||
|
|
|
@ -1,14 +1,13 @@
|
||||||
import base64
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from django.db import transaction
|
from django.db import IntegrityError, transaction
|
||||||
from ninja import Field as APIField
|
from ninja import Field as APIField
|
||||||
from ninja import ModelSchema
|
|
||||||
from ninja import Router as APIRouter
|
from ninja import Router as APIRouter
|
||||||
from ninja import Schema, UploadedFile
|
from ninja import Schema, UploadedFile
|
||||||
from ninja.pagination import paginate
|
from ninja.pagination import paginate
|
||||||
from pydantic import validator
|
from pydantic import validator
|
||||||
|
|
||||||
|
from c3nav.api.exceptions import APIConflict, APIRequestValidationFailed
|
||||||
from c3nav.api.newauth import BearerAuth, auth_permission_responses, auth_responses
|
from c3nav.api.newauth import BearerAuth, auth_permission_responses, auth_responses
|
||||||
from c3nav.mesh.dataformats import BoardType
|
from c3nav.mesh.dataformats import BoardType
|
||||||
from c3nav.mesh.messages import ChipType
|
from c3nav.mesh.messages import ChipType
|
||||||
|
@ -30,6 +29,11 @@ class FirmwareBuildSchema(Schema):
|
||||||
# todo: do this in model? idk
|
# todo: do this in model? idk
|
||||||
return ChipType(obj.chip)
|
return ChipType(obj.chip)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def resolve_boards(obj):
|
||||||
|
print(obj.boards)
|
||||||
|
return obj.boards
|
||||||
|
|
||||||
|
|
||||||
class FirmwareSchema(Schema):
|
class FirmwareSchema(Schema):
|
||||||
id: int
|
id: int
|
||||||
|
@ -66,52 +70,20 @@ def firmware_detail(request, firmware_id: int):
|
||||||
return 404, {"detail": "firmware not found"}
|
return 404, {"detail": "firmware not found"}
|
||||||
|
|
||||||
|
|
||||||
class Base64Bytes(bytes):
|
|
||||||
@classmethod
|
|
||||||
def __get_validators__(cls):
|
|
||||||
# one or more validators may be yielded which will be called in the
|
|
||||||
# order to validate the input, each validator will receive as an input
|
|
||||||
# the value returned from the previous validator
|
|
||||||
yield cls.validate
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def __modify_schema__(cls, field_schema):
|
|
||||||
# __modify_schema__ should mutate the dict it receives in place,
|
|
||||||
# the returned value will be ignored
|
|
||||||
field_schema.update(
|
|
||||||
# simplified regex here for brevity, see the wikipedia link above
|
|
||||||
pattern='^[A-Z]{1,2}[0-9][A-Z0-9]? ?[0-9][A-Z]{2}$',
|
|
||||||
# some example postcodes
|
|
||||||
examples=['SP11 9DG', 'w1j7bu'],
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def validate(cls, v):
|
|
||||||
if not isinstance(v, str):
|
|
||||||
raise TypeError('string required')
|
|
||||||
return cls(base64.b64decode(v.encode("ascii")))
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f'PostCode({super().__repr__()})'
|
|
||||||
|
|
||||||
|
|
||||||
class UploadFirmwareBuildSchema(Schema):
|
class UploadFirmwareBuildSchema(Schema):
|
||||||
variant: str = APIField(..., example="c3uart")
|
variant: str = APIField(..., example="c3uart")
|
||||||
chip: ChipType = APIField(..., example=ChipType.ESP32_C3.name)
|
chip: ChipType = APIField(..., example=ChipType.ESP32_C3.name)
|
||||||
sha256_hash: str = APIField(..., regex=r"^[0-9a-f]{64}$")
|
sha256_hash: str = APIField(..., regex=r"^[0-9a-f]{64}$")
|
||||||
boards: list[BoardType] = APIField(..., example=[BoardType.C3NAV_LOCATION_PCB_REV_0_2.name, ])
|
boards: list[BoardType] = APIField(..., example=[BoardType.C3NAV_LOCATION_PCB_REV_0_2.name, ])
|
||||||
binary: bytes = APIField(..., example="base64", contentEncoding="base64")
|
project_description: dict = APIField(..., title='project_description.json contents')
|
||||||
|
uploaded_filename: str = APIField(..., example="firmware.bin")
|
||||||
@validator('binary')
|
|
||||||
def get_binary_base64(cls, binary):
|
|
||||||
return base64.b64decode(binary.encode())
|
|
||||||
|
|
||||||
|
|
||||||
class UploadFirmwareSchema(Schema):
|
class UploadFirmwareSchema(Schema):
|
||||||
project_name: str = APIField(..., example="c3nav_positioning")
|
project_name: str = APIField(..., example="c3nav_positioning")
|
||||||
version: str = APIField(..., example="499837d-dirty")
|
version: str = APIField(..., example="499837d-dirty")
|
||||||
idf_version: str = APIField(..., example="v5.1-476-g3187b8b326")
|
idf_version: str = APIField(..., example="v5.1-476-g3187b8b326")
|
||||||
builds: list[UploadFirmwareBuildSchema] = APIField(..., min_items=1)
|
builds: list[UploadFirmwareBuildSchema] = APIField(..., min_items=1, unique_items=True)
|
||||||
|
|
||||||
@validator('builds')
|
@validator('builds')
|
||||||
def builds_variants_must_be_unique(cls, builds):
|
def builds_variants_must_be_unique(cls, builds):
|
||||||
|
@ -121,6 +93,51 @@ class UploadFirmwareSchema(Schema):
|
||||||
|
|
||||||
|
|
||||||
@api_router.post('/firmwares/upload', summary="Upload firmware", auth=BearerAuth(superuser=True),
|
@api_router.post('/firmwares/upload', summary="Upload firmware", auth=BearerAuth(superuser=True),
|
||||||
response={200: FirmwareSchema, **auth_permission_responses})
|
description="your OpenAPI viewer might not show it: firmware_data is UploadFirmwareSchema as json",
|
||||||
def firmware_upload(request, firmware_data: UploadFirmwareSchema):
|
response={200: FirmwareSchema, **auth_permission_responses, **APIConflict.dict()})
|
||||||
raise NotImplementedError
|
def firmware_upload(request, firmware_data: UploadFirmwareSchema, binary_files: list[UploadedFile]):
|
||||||
|
binary_files_by_name = {binary_file.name: binary_file for binary_file in binary_files}
|
||||||
|
if len([binary_file.name for binary_file in binary_files]) != len(binary_files_by_name):
|
||||||
|
raise APIRequestValidationFailed("Filenames of uploaded binary files must be unique.")
|
||||||
|
|
||||||
|
build_filenames = [build_data.uploaded_filename for build_data in firmware_data.builds]
|
||||||
|
if len(build_filenames) != len(set(build_filenames)):
|
||||||
|
raise APIRequestValidationFailed("Builds need to refer to different unique binary file names.")
|
||||||
|
|
||||||
|
if set(binary_files_by_name) != set(build_filenames):
|
||||||
|
raise APIRequestValidationFailed("All uploaded binary files need to be refered to by one build.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
with transaction.atomic():
|
||||||
|
version = FirmwareVersion.objects.create(
|
||||||
|
project_name=firmware_data.project_name,
|
||||||
|
version=firmware_data.version,
|
||||||
|
idf_version=firmware_data.idf_version,
|
||||||
|
uploader=request.auth,
|
||||||
|
)
|
||||||
|
|
||||||
|
for build_data in firmware_data.builds:
|
||||||
|
# if bin_file.size > 4 * 1024 * 1024:
|
||||||
|
# raise ValueError # todo: better error
|
||||||
|
|
||||||
|
# h = hashlib.sha256()
|
||||||
|
# h.update(build_data.binary)
|
||||||
|
# sha256_bin_file = h.hexdigest() # todo: verify sha256 correctly
|
||||||
|
#
|
||||||
|
# if sha256_bin_file != build_data.sha256_hash:
|
||||||
|
# raise ValueError
|
||||||
|
|
||||||
|
build = version.builds.create(
|
||||||
|
variant=build_data.variant,
|
||||||
|
chip=build_data.chip,
|
||||||
|
sha256_hash=build_data.sha256_hash,
|
||||||
|
project_description=build_data.project_description,
|
||||||
|
binary=binary_files_by_name[build_data.uploaded_filename],
|
||||||
|
)
|
||||||
|
|
||||||
|
for board in build_data.boards:
|
||||||
|
build.firmwarebuildboard_set.create(board=board)
|
||||||
|
except IntegrityError:
|
||||||
|
raise APIConflict('Firmware version already exists.')
|
||||||
|
|
||||||
|
return version
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue