update firmware api with more endpoints and parse firmware image headers

This commit is contained in:
Laura Klünder 2023-11-17 18:56:47 +01:00
parent 2d97f9bb87
commit 14e39b2377
6 changed files with 209 additions and 43 deletions

View file

@ -10,7 +10,7 @@ from rest_framework.status import HTTP_201_CREATED
from rest_framework.viewsets import ReadOnlyModelViewSet from rest_framework.viewsets import ReadOnlyModelViewSet
from c3nav.control.models import UserPermissions from c3nav.control.models import UserPermissions
from c3nav.mesh.messages import ChipType from c3nav.mesh.dataformats import ChipType
from c3nav.mesh.models import FirmwareVersion from c3nav.mesh.models import FirmwareVersion

View file

@ -1,9 +1,11 @@
import re import re
import struct import struct
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import Field, dataclass, fields from dataclasses import Field, dataclass, fields as dataclass_fields
from typing import Any, Self, Sequence from typing import Any, Self, Sequence
from pydantic import create_model
from c3nav.mesh.utils import indent_c from c3nav.mesh.utils import indent_c
@ -87,6 +89,18 @@ class SimpleFormat(BaseFormat):
return self.c_type, ("" if self.num == 1 else ("[%d]" % self.num)) return self.c_type, ("" if self.num == 1 else ("[%d]" % self.num))
class SimpleConstFormat(SimpleFormat):
def __init__(self, fmt, const_value: int):
super().__init__(fmt)
self.const_value = const_value
def decode(self, data: bytes) -> tuple[Any, bytes]:
value, out_data = super().decode(data)
if value != self.const_value:
raise ValueError('const_value is wrong')
return value, out_data
class EnumFormat(SimpleFormat): class EnumFormat(SimpleFormat):
def __init__(self, fmt="B", *, as_hex=False, c_definition=True): def __init__(self, fmt="B", *, as_hex=False, c_definition=True):
super().__init__(fmt) super().__init__(fmt)
@ -136,6 +150,45 @@ class EnumFormat(SimpleFormat):
} }
class TwoNibblesEnumFormat(SimpleFormat):
def __init__(self):
super().__init__('B')
def decode(self, data: bytes) -> tuple[bool, bytes]:
fields = dataclass_fields(self.field_type)
value, data = super().decode(data)
return self.field_type(fields[0].type(value//2**4), fields[1].type(value//2**4)), data
def encode(self, value):
fields = dataclass_fields(self.field_type)
return super().encode(
getattr(value, fields[0].name).value * 2**4 +
getattr(value, fields[1].name).value * 2**4
)
def fromjson(self, data):
fields = dataclass_fields(self.field_type)
return self.field_type(*(field.type[data[field.name]] for field in fields))
def tojson(self, data):
fields = dataclass_fields(self.field_type)
return {
field.name: getattr(data, field.name).name for field in fields
}
class ChipRevFormat(SimpleFormat):
def __init__(self):
super().__init__('H')
def decode(self, data: bytes) -> tuple[bool, bytes]:
value, data = super().decode(data)
return (value // 100, value % 100), data
def encode(self, value):
return value[0]*100 + value[1]
class BoolFormat(SimpleFormat): class BoolFormat(SimpleFormat):
def __init__(self): def __init__(self):
super().__init__('B') super().__init__('B')
@ -145,6 +198,8 @@ class BoolFormat(SimpleFormat):
def decode(self, data: bytes) -> tuple[bool, bytes]: def decode(self, data: bytes) -> tuple[bool, bytes]:
value, data = super().decode(data) value, data = super().decode(data)
if value > 1:
raise ValueError('Boolean value > 1')
return bool(value), data return bool(value), data
@ -298,11 +353,39 @@ class StructType:
raise TypeError('Duplicate %s: %s', (key, value)) raise TypeError('Duplicate %s: %s', (key, value))
values[value] = cls values[value] = cls
setattr(cls, key, value) setattr(cls, key, value)
# pydantic model
cls._pydantic_fields = getattr(cls, '_pydantic_fields', {}).copy()
fields = []
for field_ in dataclass_fields(cls):
fields.append((field_.name, field_.type, field_.metadata))
for attr_name in tuple(cls.__annotations__.keys()):
attr = getattr(cls, attr_name, None)
metadata = attr.metadata if isinstance(attr, Field) else {}
try:
type_ = cls.__annotations__[attr_name]
except KeyError:
# print('nope', cls, attr_name)
continue
fields.append((attr_name, type_, metadata))
for name, type_, metadata in fields:
if metadata.get("format", None):
cls._pydantic_fields[name] = (type_, ...)
elif issubclass(type_, StructType):
if metadata.get("json_embed"):
cls._pydantic_fields.update(type_._pydantic_fields)
else:
cls._pydantic_fields[name] = (type_.schema, ...)
else:
raise TypeError('field %s.%s has no format and is no StructType' %
(cls.__class__.__name__, name))
print(cls.__name__, cls._pydantic_fields)
cls.schema = create_model(cls.__name__+'Schema', **cls._pydantic_fields)
super().__init_subclass__(**kwargs) super().__init_subclass__(**kwargs)
@classmethod @classmethod
def get_var_num(cls): def get_var_num(cls):
return sum([f.metadata.get("format", f.type).get_var_num() for f in fields(cls)], start=0) return sum([f.metadata.get("format", f.type).get_var_num() for f in dataclass_fields(cls)], start=0)
@classmethod @classmethod
def get_types(cls): def get_types(cls):
@ -323,14 +406,14 @@ class StructType:
if not isinstance(instance, cls): if not isinstance(instance, cls):
raise ValueError('expected value of type %r, got %r' % (cls, instance)) raise ValueError('expected value of type %r, got %r' % (cls, instance))
for field_ in fields(cls): for field_ in dataclass_fields(cls):
data += field_.metadata["format"].encode(getattr(instance, field_.name)) data += field_.metadata["format"].encode(getattr(instance, field_.name))
# todo: better # todo: better
data += instance.encode(instance, ignore_fields=set(f.name for f in fields(cls))) data += instance.encode(instance, ignore_fields=set(f.name for f in dataclass_fields(cls)))
return data return data
for field_ in fields(cls): for field_ in dataclass_fields(cls):
if field_.name in ignore_fields: if field_.name in ignore_fields:
continue continue
value = getattr(instance, field_.name) value = getattr(instance, field_.name)
@ -351,7 +434,7 @@ class StructType:
orig_data = data orig_data = data
kwargs = {} kwargs = {}
no_init_data = {} no_init_data = {}
for field_ in fields(cls): for field_ in dataclass_fields(cls):
if "format" in field_.metadata: if "format" in field_.metadata:
value, data = field_.metadata["format"].decode(data) value, data = field_.metadata["format"].decode(data)
elif issubclass(field_.type, StructType): elif issubclass(field_.type, StructType):
@ -386,7 +469,7 @@ class StructType:
if not isinstance(instance, cls): if not isinstance(instance, cls):
raise ValueError('expected value of type %r, got %r' % (cls, instance)) raise ValueError('expected value of type %r, got %r' % (cls, instance))
for field_ in fields(instance): for field_ in dataclass_fields(instance):
if field_.name is cls.union_type_field: if field_.name is cls.union_type_field:
result[field_.name] = field_.metadata["format"].tojson(getattr(instance, field_.name)) result[field_.name] = field_.metadata["format"].tojson(getattr(instance, field_.name))
break break
@ -396,7 +479,7 @@ class StructType:
result.update(instance.tojson(instance)) result.update(instance.tojson(instance))
return result return result
for field_ in fields(cls): for field_ in dataclass_fields(cls):
value = getattr(instance, field_.name) value = getattr(instance, field_.name)
if "format" in field_.metadata: if "format" in field_.metadata:
result[field_.name] = field_.metadata["format"].tojson(value) result[field_.name] = field_.metadata["format"].tojson(value)
@ -428,7 +511,7 @@ class StructType:
kwargs = {} kwargs = {}
no_init_data = {} no_init_data = {}
for field_ in fields(cls): for field_ in dataclass_fields(cls):
raw_value = data.get(field_.name, None) raw_value = data.get(field_.name, None)
if "format" in field_.metadata: if "format" in field_.metadata:
value = field_.metadata["format"].fromjson(raw_value) value = field_.metadata["format"].fromjson(raw_value)
@ -466,7 +549,7 @@ class StructType:
items = [] items = []
for field_ in fields(cls): for field_ in dataclass_fields(cls):
if field_.name in ignore_fields: if field_.name in ignore_fields:
continue continue
if in_union and field_.metadata["defining_class"] != cls: if in_union and field_.metadata["defining_class"] != cls:
@ -524,7 +607,7 @@ class StructType:
@classmethod @classmethod
def get_c_definitions(cls) -> dict[str, str]: def get_c_definitions(cls) -> dict[str, str]:
definitions = {} definitions = {}
for field_ in fields(cls): for field_ in dataclass_fields(cls):
if "format" in field_.metadata: if "format" in field_.metadata:
definitions.update(field_.metadata["format"].get_c_definitions()) definitions.update(field_.metadata["format"].get_c_definitions())
if field_.metadata.get("as_definition"): if field_.metadata.get("as_definition"):
@ -629,14 +712,14 @@ class StructType:
@classmethod @classmethod
def get_min_size(cls, no_inherited_fields=False) -> int: def get_min_size(cls, no_inherited_fields=False) -> int:
if cls.union_type_field: if cls.union_type_field:
own_size = sum([f.metadata.get("format", f.type).get_min_size() for f in fields(cls)]) own_size = sum([f.metadata.get("format", f.type).get_min_size() for f in dataclass_fields(cls)])
union_size = max( union_size = max(
[0] + [option.get_min_size(True) for option in cls._union_options[cls.union_type_field].values()]) [0] + [option.get_min_size(True) for option in cls._union_options[cls.union_type_field].values()])
return own_size + union_size return own_size + union_size
if no_inherited_fields: if no_inherited_fields:
relevant_fields = [f for f in fields(cls) if f.metadata["defining_class"] == cls] relevant_fields = [f for f in dataclass_fields(cls) if f.metadata["defining_class"] == cls]
else: else:
relevant_fields = [f for f in fields(cls) if not f.metadata.get("union_discriminator")] relevant_fields = [f for f in dataclass_fields(cls) if not f.metadata.get("union_discriminator")]
return sum((f.metadata.get("format", f.type).get_min_size() for f in relevant_fields), start=0) return sum((f.metadata.get("format", f.type).get_min_size() for f in relevant_fields), start=0)

View file

@ -4,7 +4,7 @@ from enum import IntEnum, unique
from c3nav.api.utils import EnumSchemaByNameMixin from c3nav.api.utils import EnumSchemaByNameMixin
from c3nav.mesh.baseformats import (BoolFormat, EnumFormat, FixedHexFormat, FixedStrFormat, SimpleFormat, StructType, from c3nav.mesh.baseformats import (BoolFormat, EnumFormat, FixedHexFormat, FixedStrFormat, SimpleFormat, StructType,
VarArrayFormat) VarArrayFormat, TwoNibblesEnumFormat, ChipRevFormat, SimpleConstFormat)
class MacAddressFormat(FixedHexFormat): class MacAddressFormat(FixedHexFormat):
@ -158,7 +158,7 @@ class RawFTMEntry(StructType):
@dataclass @dataclass
class FirmwareAppDescription(StructType, existing_c_struct="esp_app_desc_t"): class FirmwareAppDescription(StructType, existing_c_struct="esp_app_desc_t"):
magic_word: int = field(metadata={"format": SimpleFormat('I')}, repr=False) magic_word: int = field(metadata={"format": SimpleConstFormat('I', 0xAB_CD_54_32)}, repr=False)
secure_version: int = field(metadata={"format": SimpleFormat('I')}) secure_version: int = field(metadata={"format": SimpleFormat('I')})
reserv1: list[int] = field(metadata={"format": SimpleFormat('2I')}, repr=False) reserv1: list[int] = field(metadata={"format": SimpleFormat('2I')}, repr=False)
version: str = field(metadata={"format": FixedStrFormat(32)}) version: str = field(metadata={"format": FixedStrFormat(32)})
@ -168,3 +168,76 @@ class FirmwareAppDescription(StructType, existing_c_struct="esp_app_desc_t"):
idf_version: str = field(metadata={"format": FixedStrFormat(32)}) idf_version: str = field(metadata={"format": FixedStrFormat(32)})
app_elf_sha256: str = field(metadata={"format": FixedHexFormat(32)}) app_elf_sha256: str = field(metadata={"format": FixedHexFormat(32)})
reserv2: list[int] = field(metadata={"format": SimpleFormat('20I')}, repr=False) reserv2: list[int] = field(metadata={"format": SimpleFormat('20I')}, repr=False)
@unique
class SPIFlashMode(EnumSchemaByNameMixin, IntEnum):
QIO = 0
QOUT = 1
DIO = 2
DOUT = 3
@unique
class FlashSize(EnumSchemaByNameMixin, IntEnum):
SIZE_1MB = 0
SIZE_2MB = 1
SIZE_4MB = 2
SIZE_8MB = 3
SIZE_16MB = 4
SIZE_32MB = 5
SIZE_64MB = 6
SIZE_128MB = 7
@unique
class FlashFrequency(EnumSchemaByNameMixin, IntEnum):
FREQ_40MHZ = 0
FREQ_26MHZ = 1
FREQ_20MHZ = 2
FREQ_80MHZ = 0xf
@dataclass
class FlashSettings:
size: FlashSize
frequency: FlashFrequency
@unique
class ChipType(EnumSchemaByNameMixin, IntEnum):
ESP32_S2 = 2
ESP32_C3 = 5
@property
def pretty_name(self):
return self.name.replace('_', '-')
@dataclass
class FirmwareImageFileHeader(StructType):
magic_word: int = field(metadata={"format": SimpleConstFormat('B', 0xE9)}, repr=False)
num_segments: int = field(metadata={"format": SimpleFormat('B')})
spi_flash_mode: SPIFlashMode = field(metadata={"format": EnumFormat()})
flash_stuff: FlashSettings = field(metadata={"format": TwoNibblesEnumFormat()})
entry_point: int = field(metadata={"format": SimpleFormat('I')})
@dataclass
class FirmwareImageExtendedFileHeader(StructType):
wp_pin: int = field(metadata={"format": SimpleFormat('B')})
drive_settings: int = field(metadata={"format": SimpleFormat('3B')})
chip_id: ChipType = field(metadata={"format": EnumFormat('H')})
min_chip_rev_old: int = field(metadata={"format": SimpleFormat('B')})
min_chip_rev: tuple[int, int] = field(metadata={"format": ChipRevFormat()})
max_chip_rev: tuple[int, int] = field(metadata={"format": ChipRevFormat()})
reserv: int = field(metadata={"format": SimpleFormat('I')}, repr=False)
hash_appended: bool = field(metadata={"format": BoolFormat()})
@dataclass
class FirmwareImage(StructType):
header: FirmwareImageFileHeader
ext_header: FirmwareImageExtendedFileHeader
first_segment_headers: tuple[int, int] = field(metadata={"format": SimpleFormat('2I')}, repr=False)
app_desc: FirmwareAppDescription

View file

@ -5,11 +5,10 @@ from typing import TypeVar
import channels import channels
from channels.db import database_sync_to_async from channels.db import database_sync_to_async
from c3nav.api.utils import EnumSchemaByNameMixin
from c3nav.mesh.baseformats import (BoolFormat, EnumFormat, FixedStrFormat, SimpleFormat, StructType, VarArrayFormat, from c3nav.mesh.baseformats import (BoolFormat, EnumFormat, FixedStrFormat, SimpleFormat, StructType, VarArrayFormat,
VarBytesFormat, VarStrFormat, normalize_name) VarBytesFormat, VarStrFormat, normalize_name)
from c3nav.mesh.dataformats import (BoardConfig, FirmwareAppDescription, MacAddressesListFormat, MacAddressFormat, from c3nav.mesh.dataformats import (BoardConfig, FirmwareAppDescription, MacAddressesListFormat, MacAddressFormat,
RangeResultItem, RawFTMEntry) RangeResultItem, RawFTMEntry, ChipType)
from c3nav.mesh.utils import MESH_ALL_UPLINKS_GROUP from c3nav.mesh.utils import MESH_ALL_UPLINKS_GROUP
MESH_ROOT_ADDRESS = '00:00:00:00:00:00' MESH_ROOT_ADDRESS = '00:00:00:00:00:00'
@ -68,16 +67,6 @@ class MeshMessageType(IntEnum):
M = TypeVar('M', bound='MeshMessage') M = TypeVar('M', bound='MeshMessage')
@unique
class ChipType(EnumSchemaByNameMixin, IntEnum):
ESP32_S2 = 2
ESP32_C3 = 5
@property
def pretty_name(self):
return self.name.replace('_', '-')
@dataclass @dataclass
class MeshMessage(StructType, union_type_field="msg_type"): class MeshMessage(StructType, union_type_field="msg_type"):
dst: str = field(metadata={"format": MacAddressFormat()}) dst: str = field(metadata={"format": MacAddressFormat()})

View file

@ -14,8 +14,8 @@ from django.utils import timezone
from django.utils.text import slugify from django.utils.text import slugify
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from c3nav.mesh.dataformats import BoardType from c3nav.mesh.dataformats import BoardType, ChipType, FirmwareImage
from c3nav.mesh.messages import ChipType, ConfigFirmwareMessage, ConfigHardwareMessage from c3nav.mesh.messages import ConfigFirmwareMessage, ConfigHardwareMessage
from c3nav.mesh.messages import MeshMessage as MeshMessage from c3nav.mesh.messages import MeshMessage as MeshMessage
from c3nav.mesh.messages import MeshMessageType from c3nav.mesh.messages import MeshMessageType
from c3nav.mesh.utils import UPLINK_TIMEOUT from c3nav.mesh.utils import UPLINK_TIMEOUT
@ -410,6 +410,11 @@ class FirmwareBuild(models.Model):
for board in self.boards for board in self.boards
] ]
@cached_property
def firmware_image(self) -> FirmwareImage:
firmware_image, remaining = FirmwareImage.decode(self.binary.open('rb').read()[:FirmwareImage.get_min_size()])
return firmware_image
class FirmwareBuildBoard(models.Model): class FirmwareBuildBoard(models.Model):
BOARDS = [(boardtype.name, boardtype.pretty_name) for boardtype in BoardType] BOARDS = [(boardtype.name, boardtype.pretty_name) for boardtype in BoardType]

View file

@ -7,11 +7,10 @@ from ninja import Schema, UploadedFile
from ninja.pagination import paginate from ninja.pagination import paginate
from pydantic import validator from pydantic import validator
from c3nav.api.exceptions import APIConflict, APIRequestValidationFailed from c3nav.api.exceptions import APIConflict, APIRequestValidationFailed, API404
from c3nav.api.newauth import BearerAuth, auth_permission_responses, auth_responses from c3nav.api.newauth import BearerAuth, auth_permission_responses, auth_responses
from c3nav.mesh.dataformats import BoardType from c3nav.mesh.dataformats import BoardType, FirmwareImage, ChipType
from c3nav.mesh.messages import ChipType from c3nav.mesh.models import FirmwareVersion, FirmwareBuild
from c3nav.mesh.models import FirmwareVersion
api_router = APIRouter(tags=["mesh"]) api_router = APIRouter(tags=["mesh"])
@ -22,18 +21,14 @@ class FirmwareBuildSchema(Schema):
chip: ChipType = APIField(..., example=ChipType.ESP32_C3.name) chip: ChipType = APIField(..., example=ChipType.ESP32_C3.name)
sha256_hash: str = APIField(..., regex=r"^[0-9a-f]{64}$") sha256_hash: str = APIField(..., regex=r"^[0-9a-f]{64}$")
url: str = APIField(..., alias="binary", example="/media/firmware/012345/firmware.bin") url: str = APIField(..., alias="binary", example="/media/firmware/012345/firmware.bin")
boards: list[BoardType] = APIField(..., example=[BoardType.C3NAV_LOCATION_PCB_REV_0_2.name, ]) # todo: should not be none, but parse errors
boards: list[BoardType] = APIField(None, example=[BoardType.C3NAV_LOCATION_PCB_REV_0_2.name, ])
@staticmethod @staticmethod
def resolve_chip(obj): def resolve_chip(obj):
# todo: do this in model? idk # todo: do this in model? idk
return ChipType(obj.chip) return ChipType(obj.chip)
@staticmethod
def resolve_boards(obj):
print(obj.boards)
return obj.boards
class FirmwareSchema(Schema): class FirmwareSchema(Schema):
id: int id: int
@ -62,12 +57,33 @@ def firmware_list(request):
@api_router.get('/firmwares/{firmware_id}/', summary="Get specific firmware", @api_router.get('/firmwares/{firmware_id}/', summary="Get specific firmware",
response={200: FirmwareSchema, **auth_responses}) response={200: FirmwareSchema, **API404.dict(), **auth_responses})
def firmware_detail(request, firmware_id: int): def firmware_detail(request, firmware_id: int):
try: try:
return FirmwareVersion.objects.get(id=firmware_id) return FirmwareVersion.objects.get(id=firmware_id)
except FirmwareVersion.DoesNotExist: except FirmwareVersion.DoesNotExist:
return 404, {"detail": "firmware not found"} raise API404("Firmware not found")
@api_router.get('/firmwares/{firmware_id}/{variant}/image_data',
summary="Get header data of firmware build image",
response={200: FirmwareImage.schema, **API404.dict(), **auth_responses})
def firmware_build_image(request, firmware_id: int, variant: str):
try:
build = FirmwareBuild.objects.get(version_id=firmware_id, variant=variant)
return FirmwareImage.tojson(build.firmware_image)
except FirmwareVersion.DoesNotExist:
raise API404("Firmware or firmware build not found")
@api_router.get('/firmwares/{firmware_id}/{variant}/project_description',
summary="Get project description of firmware build",
response={200: dict, **API404.dict(), **auth_responses})
def firmware_project_description(request, firmware_id: int, variant: str):
try:
return FirmwareBuild.objects.get(version_id=firmware_id, variant=variant).firmware_description
except FirmwareVersion.DoesNotExist:
raise API404("Firmware or firmware build not found")
class UploadFirmwareBuildSchema(Schema): class UploadFirmwareBuildSchema(Schema):