From 14e39b23778cb94f6429d159e62fb4a41db9f7e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laura=20Kl=C3=BCnder?= Date: Fri, 17 Nov 2023 18:56:47 +0100 Subject: [PATCH] update firmware api with more endpoints and parse firmware image headers --- src/c3nav/mesh/api.py | 2 +- src/c3nav/mesh/baseformats.py | 111 +++++++++++++++++++++++++++++----- src/c3nav/mesh/dataformats.py | 77 ++++++++++++++++++++++- src/c3nav/mesh/messages.py | 13 +--- src/c3nav/mesh/models.py | 9 ++- src/c3nav/mesh/newapi.py | 40 ++++++++---- 6 files changed, 209 insertions(+), 43 deletions(-) diff --git a/src/c3nav/mesh/api.py b/src/c3nav/mesh/api.py index a9404631..ed658d1b 100644 --- a/src/c3nav/mesh/api.py +++ b/src/c3nav/mesh/api.py @@ -10,7 +10,7 @@ from rest_framework.status import HTTP_201_CREATED from rest_framework.viewsets import ReadOnlyModelViewSet from c3nav.control.models import UserPermissions -from c3nav.mesh.messages import ChipType +from c3nav.mesh.dataformats import ChipType from c3nav.mesh.models import FirmwareVersion diff --git a/src/c3nav/mesh/baseformats.py b/src/c3nav/mesh/baseformats.py index a001624c..b64d6a1a 100644 --- a/src/c3nav/mesh/baseformats.py +++ b/src/c3nav/mesh/baseformats.py @@ -1,9 +1,11 @@ import re import struct from abc import ABC, abstractmethod -from dataclasses import Field, dataclass, fields +from dataclasses import Field, dataclass, fields as dataclass_fields from typing import Any, Self, Sequence +from pydantic import create_model + from c3nav.mesh.utils import indent_c @@ -87,6 +89,18 @@ class SimpleFormat(BaseFormat): return self.c_type, ("" if self.num == 1 else ("[%d]" % self.num)) +class SimpleConstFormat(SimpleFormat): + def __init__(self, fmt, const_value: int): + super().__init__(fmt) + self.const_value = const_value + + def decode(self, data: bytes) -> tuple[Any, bytes]: + value, out_data = super().decode(data) + if value != self.const_value: + raise ValueError('const_value is wrong') + return value, out_data + + class EnumFormat(SimpleFormat): def __init__(self, fmt="B", *, as_hex=False, c_definition=True): super().__init__(fmt) @@ -136,6 +150,45 @@ class EnumFormat(SimpleFormat): } +class TwoNibblesEnumFormat(SimpleFormat): + def __init__(self): + super().__init__('B') + + def decode(self, data: bytes) -> tuple[bool, bytes]: + fields = dataclass_fields(self.field_type) + value, data = super().decode(data) + return self.field_type(fields[0].type(value//2**4), fields[1].type(value//2**4)), data + + def encode(self, value): + fields = dataclass_fields(self.field_type) + return super().encode( + getattr(value, fields[0].name).value * 2**4 + + getattr(value, fields[1].name).value * 2**4 + ) + + def fromjson(self, data): + fields = dataclass_fields(self.field_type) + return self.field_type(*(field.type[data[field.name]] for field in fields)) + + def tojson(self, data): + fields = dataclass_fields(self.field_type) + return { + field.name: getattr(data, field.name).name for field in fields + } + + +class ChipRevFormat(SimpleFormat): + def __init__(self): + super().__init__('H') + + def decode(self, data: bytes) -> tuple[bool, bytes]: + value, data = super().decode(data) + return (value // 100, value % 100), data + + def encode(self, value): + return value[0]*100 + value[1] + + class BoolFormat(SimpleFormat): def __init__(self): super().__init__('B') @@ -145,6 +198,8 @@ class BoolFormat(SimpleFormat): def decode(self, data: bytes) -> tuple[bool, bytes]: value, data = super().decode(data) + if value > 1: + raise ValueError('Boolean value > 1') return bool(value), data @@ -298,11 +353,39 @@ class StructType: raise TypeError('Duplicate %s: %s', (key, value)) values[value] = cls setattr(cls, key, value) + + # pydantic model + cls._pydantic_fields = getattr(cls, '_pydantic_fields', {}).copy() + fields = [] + for field_ in dataclass_fields(cls): + fields.append((field_.name, field_.type, field_.metadata)) + for attr_name in tuple(cls.__annotations__.keys()): + attr = getattr(cls, attr_name, None) + metadata = attr.metadata if isinstance(attr, Field) else {} + try: + type_ = cls.__annotations__[attr_name] + except KeyError: + # print('nope', cls, attr_name) + continue + fields.append((attr_name, type_, metadata)) + for name, type_, metadata in fields: + if metadata.get("format", None): + cls._pydantic_fields[name] = (type_, ...) + elif issubclass(type_, StructType): + if metadata.get("json_embed"): + cls._pydantic_fields.update(type_._pydantic_fields) + else: + cls._pydantic_fields[name] = (type_.schema, ...) + else: + raise TypeError('field %s.%s has no format and is no StructType' % + (cls.__class__.__name__, name)) + print(cls.__name__, cls._pydantic_fields) + cls.schema = create_model(cls.__name__+'Schema', **cls._pydantic_fields) super().__init_subclass__(**kwargs) @classmethod def get_var_num(cls): - return sum([f.metadata.get("format", f.type).get_var_num() for f in fields(cls)], start=0) + return sum([f.metadata.get("format", f.type).get_var_num() for f in dataclass_fields(cls)], start=0) @classmethod def get_types(cls): @@ -323,14 +406,14 @@ class StructType: if not isinstance(instance, cls): raise ValueError('expected value of type %r, got %r' % (cls, instance)) - for field_ in fields(cls): + for field_ in dataclass_fields(cls): data += field_.metadata["format"].encode(getattr(instance, field_.name)) # todo: better - data += instance.encode(instance, ignore_fields=set(f.name for f in fields(cls))) + data += instance.encode(instance, ignore_fields=set(f.name for f in dataclass_fields(cls))) return data - for field_ in fields(cls): + for field_ in dataclass_fields(cls): if field_.name in ignore_fields: continue value = getattr(instance, field_.name) @@ -351,7 +434,7 @@ class StructType: orig_data = data kwargs = {} no_init_data = {} - for field_ in fields(cls): + for field_ in dataclass_fields(cls): if "format" in field_.metadata: value, data = field_.metadata["format"].decode(data) elif issubclass(field_.type, StructType): @@ -386,7 +469,7 @@ class StructType: if not isinstance(instance, cls): raise ValueError('expected value of type %r, got %r' % (cls, instance)) - for field_ in fields(instance): + for field_ in dataclass_fields(instance): if field_.name is cls.union_type_field: result[field_.name] = field_.metadata["format"].tojson(getattr(instance, field_.name)) break @@ -396,7 +479,7 @@ class StructType: result.update(instance.tojson(instance)) return result - for field_ in fields(cls): + for field_ in dataclass_fields(cls): value = getattr(instance, field_.name) if "format" in field_.metadata: result[field_.name] = field_.metadata["format"].tojson(value) @@ -428,7 +511,7 @@ class StructType: kwargs = {} no_init_data = {} - for field_ in fields(cls): + for field_ in dataclass_fields(cls): raw_value = data.get(field_.name, None) if "format" in field_.metadata: value = field_.metadata["format"].fromjson(raw_value) @@ -466,7 +549,7 @@ class StructType: items = [] - for field_ in fields(cls): + for field_ in dataclass_fields(cls): if field_.name in ignore_fields: continue if in_union and field_.metadata["defining_class"] != cls: @@ -524,7 +607,7 @@ class StructType: @classmethod def get_c_definitions(cls) -> dict[str, str]: definitions = {} - for field_ in fields(cls): + for field_ in dataclass_fields(cls): if "format" in field_.metadata: definitions.update(field_.metadata["format"].get_c_definitions()) if field_.metadata.get("as_definition"): @@ -629,14 +712,14 @@ class StructType: @classmethod def get_min_size(cls, no_inherited_fields=False) -> int: if cls.union_type_field: - own_size = sum([f.metadata.get("format", f.type).get_min_size() for f in fields(cls)]) + own_size = sum([f.metadata.get("format", f.type).get_min_size() for f in dataclass_fields(cls)]) union_size = max( [0] + [option.get_min_size(True) for option in cls._union_options[cls.union_type_field].values()]) return own_size + union_size if no_inherited_fields: - relevant_fields = [f for f in fields(cls) if f.metadata["defining_class"] == cls] + relevant_fields = [f for f in dataclass_fields(cls) if f.metadata["defining_class"] == cls] else: - relevant_fields = [f for f in fields(cls) if not f.metadata.get("union_discriminator")] + relevant_fields = [f for f in dataclass_fields(cls) if not f.metadata.get("union_discriminator")] return sum((f.metadata.get("format", f.type).get_min_size() for f in relevant_fields), start=0) diff --git a/src/c3nav/mesh/dataformats.py b/src/c3nav/mesh/dataformats.py index 8ed2ec4c..bf4f05a8 100644 --- a/src/c3nav/mesh/dataformats.py +++ b/src/c3nav/mesh/dataformats.py @@ -4,7 +4,7 @@ from enum import IntEnum, unique from c3nav.api.utils import EnumSchemaByNameMixin from c3nav.mesh.baseformats import (BoolFormat, EnumFormat, FixedHexFormat, FixedStrFormat, SimpleFormat, StructType, - VarArrayFormat) + VarArrayFormat, TwoNibblesEnumFormat, ChipRevFormat, SimpleConstFormat) class MacAddressFormat(FixedHexFormat): @@ -158,7 +158,7 @@ class RawFTMEntry(StructType): @dataclass class FirmwareAppDescription(StructType, existing_c_struct="esp_app_desc_t"): - magic_word: int = field(metadata={"format": SimpleFormat('I')}, repr=False) + magic_word: int = field(metadata={"format": SimpleConstFormat('I', 0xAB_CD_54_32)}, repr=False) secure_version: int = field(metadata={"format": SimpleFormat('I')}) reserv1: list[int] = field(metadata={"format": SimpleFormat('2I')}, repr=False) version: str = field(metadata={"format": FixedStrFormat(32)}) @@ -168,3 +168,76 @@ class FirmwareAppDescription(StructType, existing_c_struct="esp_app_desc_t"): idf_version: str = field(metadata={"format": FixedStrFormat(32)}) app_elf_sha256: str = field(metadata={"format": FixedHexFormat(32)}) reserv2: list[int] = field(metadata={"format": SimpleFormat('20I')}, repr=False) + + +@unique +class SPIFlashMode(EnumSchemaByNameMixin, IntEnum): + QIO = 0 + QOUT = 1 + DIO = 2 + DOUT = 3 + + +@unique +class FlashSize(EnumSchemaByNameMixin, IntEnum): + SIZE_1MB = 0 + SIZE_2MB = 1 + SIZE_4MB = 2 + SIZE_8MB = 3 + SIZE_16MB = 4 + SIZE_32MB = 5 + SIZE_64MB = 6 + SIZE_128MB = 7 + + +@unique +class FlashFrequency(EnumSchemaByNameMixin, IntEnum): + FREQ_40MHZ = 0 + FREQ_26MHZ = 1 + FREQ_20MHZ = 2 + FREQ_80MHZ = 0xf + + +@dataclass +class FlashSettings: + size: FlashSize + frequency: FlashFrequency + + +@unique +class ChipType(EnumSchemaByNameMixin, IntEnum): + ESP32_S2 = 2 + ESP32_C3 = 5 + + @property + def pretty_name(self): + return self.name.replace('_', '-') + + +@dataclass +class FirmwareImageFileHeader(StructType): + magic_word: int = field(metadata={"format": SimpleConstFormat('B', 0xE9)}, repr=False) + num_segments: int = field(metadata={"format": SimpleFormat('B')}) + spi_flash_mode: SPIFlashMode = field(metadata={"format": EnumFormat()}) + flash_stuff: FlashSettings = field(metadata={"format": TwoNibblesEnumFormat()}) + entry_point: int = field(metadata={"format": SimpleFormat('I')}) + + +@dataclass +class FirmwareImageExtendedFileHeader(StructType): + wp_pin: int = field(metadata={"format": SimpleFormat('B')}) + drive_settings: int = field(metadata={"format": SimpleFormat('3B')}) + chip_id: ChipType = field(metadata={"format": EnumFormat('H')}) + min_chip_rev_old: int = field(metadata={"format": SimpleFormat('B')}) + min_chip_rev: tuple[int, int] = field(metadata={"format": ChipRevFormat()}) + max_chip_rev: tuple[int, int] = field(metadata={"format": ChipRevFormat()}) + reserv: int = field(metadata={"format": SimpleFormat('I')}, repr=False) + hash_appended: bool = field(metadata={"format": BoolFormat()}) + + +@dataclass +class FirmwareImage(StructType): + header: FirmwareImageFileHeader + ext_header: FirmwareImageExtendedFileHeader + first_segment_headers: tuple[int, int] = field(metadata={"format": SimpleFormat('2I')}, repr=False) + app_desc: FirmwareAppDescription diff --git a/src/c3nav/mesh/messages.py b/src/c3nav/mesh/messages.py index 92f74ceb..8c9d4f64 100644 --- a/src/c3nav/mesh/messages.py +++ b/src/c3nav/mesh/messages.py @@ -5,11 +5,10 @@ from typing import TypeVar import channels from channels.db import database_sync_to_async -from c3nav.api.utils import EnumSchemaByNameMixin from c3nav.mesh.baseformats import (BoolFormat, EnumFormat, FixedStrFormat, SimpleFormat, StructType, VarArrayFormat, VarBytesFormat, VarStrFormat, normalize_name) from c3nav.mesh.dataformats import (BoardConfig, FirmwareAppDescription, MacAddressesListFormat, MacAddressFormat, - RangeResultItem, RawFTMEntry) + RangeResultItem, RawFTMEntry, ChipType) from c3nav.mesh.utils import MESH_ALL_UPLINKS_GROUP MESH_ROOT_ADDRESS = '00:00:00:00:00:00' @@ -68,16 +67,6 @@ class MeshMessageType(IntEnum): M = TypeVar('M', bound='MeshMessage') -@unique -class ChipType(EnumSchemaByNameMixin, IntEnum): - ESP32_S2 = 2 - ESP32_C3 = 5 - - @property - def pretty_name(self): - return self.name.replace('_', '-') - - @dataclass class MeshMessage(StructType, union_type_field="msg_type"): dst: str = field(metadata={"format": MacAddressFormat()}) diff --git a/src/c3nav/mesh/models.py b/src/c3nav/mesh/models.py index 2cad8b99..ca62b90b 100644 --- a/src/c3nav/mesh/models.py +++ b/src/c3nav/mesh/models.py @@ -14,8 +14,8 @@ from django.utils import timezone from django.utils.text import slugify from django.utils.translation import gettext_lazy as _ -from c3nav.mesh.dataformats import BoardType -from c3nav.mesh.messages import ChipType, ConfigFirmwareMessage, ConfigHardwareMessage +from c3nav.mesh.dataformats import BoardType, ChipType, FirmwareImage +from c3nav.mesh.messages import ConfigFirmwareMessage, ConfigHardwareMessage from c3nav.mesh.messages import MeshMessage as MeshMessage from c3nav.mesh.messages import MeshMessageType from c3nav.mesh.utils import UPLINK_TIMEOUT @@ -410,6 +410,11 @@ class FirmwareBuild(models.Model): for board in self.boards ] + @cached_property + def firmware_image(self) -> FirmwareImage: + firmware_image, remaining = FirmwareImage.decode(self.binary.open('rb').read()[:FirmwareImage.get_min_size()]) + return firmware_image + class FirmwareBuildBoard(models.Model): BOARDS = [(boardtype.name, boardtype.pretty_name) for boardtype in BoardType] diff --git a/src/c3nav/mesh/newapi.py b/src/c3nav/mesh/newapi.py index f12b7448..e63e323b 100644 --- a/src/c3nav/mesh/newapi.py +++ b/src/c3nav/mesh/newapi.py @@ -7,11 +7,10 @@ from ninja import Schema, UploadedFile from ninja.pagination import paginate from pydantic import validator -from c3nav.api.exceptions import APIConflict, APIRequestValidationFailed +from c3nav.api.exceptions import APIConflict, APIRequestValidationFailed, API404 from c3nav.api.newauth import BearerAuth, auth_permission_responses, auth_responses -from c3nav.mesh.dataformats import BoardType -from c3nav.mesh.messages import ChipType -from c3nav.mesh.models import FirmwareVersion +from c3nav.mesh.dataformats import BoardType, FirmwareImage, ChipType +from c3nav.mesh.models import FirmwareVersion, FirmwareBuild api_router = APIRouter(tags=["mesh"]) @@ -22,18 +21,14 @@ class FirmwareBuildSchema(Schema): chip: ChipType = APIField(..., example=ChipType.ESP32_C3.name) sha256_hash: str = APIField(..., regex=r"^[0-9a-f]{64}$") url: str = APIField(..., alias="binary", example="/media/firmware/012345/firmware.bin") - boards: list[BoardType] = APIField(..., example=[BoardType.C3NAV_LOCATION_PCB_REV_0_2.name, ]) + # todo: should not be none, but parse errors + boards: list[BoardType] = APIField(None, example=[BoardType.C3NAV_LOCATION_PCB_REV_0_2.name, ]) @staticmethod def resolve_chip(obj): # todo: do this in model? idk return ChipType(obj.chip) - @staticmethod - def resolve_boards(obj): - print(obj.boards) - return obj.boards - class FirmwareSchema(Schema): id: int @@ -62,12 +57,33 @@ def firmware_list(request): @api_router.get('/firmwares/{firmware_id}/', summary="Get specific firmware", - response={200: FirmwareSchema, **auth_responses}) + response={200: FirmwareSchema, **API404.dict(), **auth_responses}) def firmware_detail(request, firmware_id: int): try: return FirmwareVersion.objects.get(id=firmware_id) except FirmwareVersion.DoesNotExist: - return 404, {"detail": "firmware not found"} + raise API404("Firmware not found") + + +@api_router.get('/firmwares/{firmware_id}/{variant}/image_data', + summary="Get header data of firmware build image", + response={200: FirmwareImage.schema, **API404.dict(), **auth_responses}) +def firmware_build_image(request, firmware_id: int, variant: str): + try: + build = FirmwareBuild.objects.get(version_id=firmware_id, variant=variant) + return FirmwareImage.tojson(build.firmware_image) + except FirmwareVersion.DoesNotExist: + raise API404("Firmware or firmware build not found") + + +@api_router.get('/firmwares/{firmware_id}/{variant}/project_description', + summary="Get project description of firmware build", + response={200: dict, **API404.dict(), **auth_responses}) +def firmware_project_description(request, firmware_id: int, variant: str): + try: + return FirmwareBuild.objects.get(version_id=firmware_id, variant=variant).firmware_description + except FirmwareVersion.DoesNotExist: + raise API404("Firmware or firmware build not found") class UploadFirmwareBuildSchema(Schema):