update firmware api with more endpoints and parse firmware image headers

This commit is contained in:
Laura Klünder 2023-11-17 18:56:47 +01:00
parent 2d97f9bb87
commit 14e39b2377
6 changed files with 209 additions and 43 deletions

View file

@ -10,7 +10,7 @@ from rest_framework.status import HTTP_201_CREATED
from rest_framework.viewsets import ReadOnlyModelViewSet
from c3nav.control.models import UserPermissions
from c3nav.mesh.messages import ChipType
from c3nav.mesh.dataformats import ChipType
from c3nav.mesh.models import FirmwareVersion

View file

@ -1,9 +1,11 @@
import re
import struct
from abc import ABC, abstractmethod
from dataclasses import Field, dataclass, fields
from dataclasses import Field, dataclass, fields as dataclass_fields
from typing import Any, Self, Sequence
from pydantic import create_model
from c3nav.mesh.utils import indent_c
@ -87,6 +89,18 @@ class SimpleFormat(BaseFormat):
return self.c_type, ("" if self.num == 1 else ("[%d]" % self.num))
class SimpleConstFormat(SimpleFormat):
def __init__(self, fmt, const_value: int):
super().__init__(fmt)
self.const_value = const_value
def decode(self, data: bytes) -> tuple[Any, bytes]:
value, out_data = super().decode(data)
if value != self.const_value:
raise ValueError('const_value is wrong')
return value, out_data
class EnumFormat(SimpleFormat):
def __init__(self, fmt="B", *, as_hex=False, c_definition=True):
super().__init__(fmt)
@ -136,6 +150,45 @@ class EnumFormat(SimpleFormat):
}
class TwoNibblesEnumFormat(SimpleFormat):
def __init__(self):
super().__init__('B')
def decode(self, data: bytes) -> tuple[bool, bytes]:
fields = dataclass_fields(self.field_type)
value, data = super().decode(data)
return self.field_type(fields[0].type(value//2**4), fields[1].type(value//2**4)), data
def encode(self, value):
fields = dataclass_fields(self.field_type)
return super().encode(
getattr(value, fields[0].name).value * 2**4 +
getattr(value, fields[1].name).value * 2**4
)
def fromjson(self, data):
fields = dataclass_fields(self.field_type)
return self.field_type(*(field.type[data[field.name]] for field in fields))
def tojson(self, data):
fields = dataclass_fields(self.field_type)
return {
field.name: getattr(data, field.name).name for field in fields
}
class ChipRevFormat(SimpleFormat):
def __init__(self):
super().__init__('H')
def decode(self, data: bytes) -> tuple[bool, bytes]:
value, data = super().decode(data)
return (value // 100, value % 100), data
def encode(self, value):
return value[0]*100 + value[1]
class BoolFormat(SimpleFormat):
def __init__(self):
super().__init__('B')
@ -145,6 +198,8 @@ class BoolFormat(SimpleFormat):
def decode(self, data: bytes) -> tuple[bool, bytes]:
value, data = super().decode(data)
if value > 1:
raise ValueError('Boolean value > 1')
return bool(value), data
@ -298,11 +353,39 @@ class StructType:
raise TypeError('Duplicate %s: %s', (key, value))
values[value] = cls
setattr(cls, key, value)
# pydantic model
cls._pydantic_fields = getattr(cls, '_pydantic_fields', {}).copy()
fields = []
for field_ in dataclass_fields(cls):
fields.append((field_.name, field_.type, field_.metadata))
for attr_name in tuple(cls.__annotations__.keys()):
attr = getattr(cls, attr_name, None)
metadata = attr.metadata if isinstance(attr, Field) else {}
try:
type_ = cls.__annotations__[attr_name]
except KeyError:
# print('nope', cls, attr_name)
continue
fields.append((attr_name, type_, metadata))
for name, type_, metadata in fields:
if metadata.get("format", None):
cls._pydantic_fields[name] = (type_, ...)
elif issubclass(type_, StructType):
if metadata.get("json_embed"):
cls._pydantic_fields.update(type_._pydantic_fields)
else:
cls._pydantic_fields[name] = (type_.schema, ...)
else:
raise TypeError('field %s.%s has no format and is no StructType' %
(cls.__class__.__name__, name))
print(cls.__name__, cls._pydantic_fields)
cls.schema = create_model(cls.__name__+'Schema', **cls._pydantic_fields)
super().__init_subclass__(**kwargs)
@classmethod
def get_var_num(cls):
return sum([f.metadata.get("format", f.type).get_var_num() for f in fields(cls)], start=0)
return sum([f.metadata.get("format", f.type).get_var_num() for f in dataclass_fields(cls)], start=0)
@classmethod
def get_types(cls):
@ -323,14 +406,14 @@ class StructType:
if not isinstance(instance, cls):
raise ValueError('expected value of type %r, got %r' % (cls, instance))
for field_ in fields(cls):
for field_ in dataclass_fields(cls):
data += field_.metadata["format"].encode(getattr(instance, field_.name))
# todo: better
data += instance.encode(instance, ignore_fields=set(f.name for f in fields(cls)))
data += instance.encode(instance, ignore_fields=set(f.name for f in dataclass_fields(cls)))
return data
for field_ in fields(cls):
for field_ in dataclass_fields(cls):
if field_.name in ignore_fields:
continue
value = getattr(instance, field_.name)
@ -351,7 +434,7 @@ class StructType:
orig_data = data
kwargs = {}
no_init_data = {}
for field_ in fields(cls):
for field_ in dataclass_fields(cls):
if "format" in field_.metadata:
value, data = field_.metadata["format"].decode(data)
elif issubclass(field_.type, StructType):
@ -386,7 +469,7 @@ class StructType:
if not isinstance(instance, cls):
raise ValueError('expected value of type %r, got %r' % (cls, instance))
for field_ in fields(instance):
for field_ in dataclass_fields(instance):
if field_.name is cls.union_type_field:
result[field_.name] = field_.metadata["format"].tojson(getattr(instance, field_.name))
break
@ -396,7 +479,7 @@ class StructType:
result.update(instance.tojson(instance))
return result
for field_ in fields(cls):
for field_ in dataclass_fields(cls):
value = getattr(instance, field_.name)
if "format" in field_.metadata:
result[field_.name] = field_.metadata["format"].tojson(value)
@ -428,7 +511,7 @@ class StructType:
kwargs = {}
no_init_data = {}
for field_ in fields(cls):
for field_ in dataclass_fields(cls):
raw_value = data.get(field_.name, None)
if "format" in field_.metadata:
value = field_.metadata["format"].fromjson(raw_value)
@ -466,7 +549,7 @@ class StructType:
items = []
for field_ in fields(cls):
for field_ in dataclass_fields(cls):
if field_.name in ignore_fields:
continue
if in_union and field_.metadata["defining_class"] != cls:
@ -524,7 +607,7 @@ class StructType:
@classmethod
def get_c_definitions(cls) -> dict[str, str]:
definitions = {}
for field_ in fields(cls):
for field_ in dataclass_fields(cls):
if "format" in field_.metadata:
definitions.update(field_.metadata["format"].get_c_definitions())
if field_.metadata.get("as_definition"):
@ -629,14 +712,14 @@ class StructType:
@classmethod
def get_min_size(cls, no_inherited_fields=False) -> int:
if cls.union_type_field:
own_size = sum([f.metadata.get("format", f.type).get_min_size() for f in fields(cls)])
own_size = sum([f.metadata.get("format", f.type).get_min_size() for f in dataclass_fields(cls)])
union_size = max(
[0] + [option.get_min_size(True) for option in cls._union_options[cls.union_type_field].values()])
return own_size + union_size
if no_inherited_fields:
relevant_fields = [f for f in fields(cls) if f.metadata["defining_class"] == cls]
relevant_fields = [f for f in dataclass_fields(cls) if f.metadata["defining_class"] == cls]
else:
relevant_fields = [f for f in fields(cls) if not f.metadata.get("union_discriminator")]
relevant_fields = [f for f in dataclass_fields(cls) if not f.metadata.get("union_discriminator")]
return sum((f.metadata.get("format", f.type).get_min_size() for f in relevant_fields), start=0)

View file

@ -4,7 +4,7 @@ from enum import IntEnum, unique
from c3nav.api.utils import EnumSchemaByNameMixin
from c3nav.mesh.baseformats import (BoolFormat, EnumFormat, FixedHexFormat, FixedStrFormat, SimpleFormat, StructType,
VarArrayFormat)
VarArrayFormat, TwoNibblesEnumFormat, ChipRevFormat, SimpleConstFormat)
class MacAddressFormat(FixedHexFormat):
@ -158,7 +158,7 @@ class RawFTMEntry(StructType):
@dataclass
class FirmwareAppDescription(StructType, existing_c_struct="esp_app_desc_t"):
magic_word: int = field(metadata={"format": SimpleFormat('I')}, repr=False)
magic_word: int = field(metadata={"format": SimpleConstFormat('I', 0xAB_CD_54_32)}, repr=False)
secure_version: int = field(metadata={"format": SimpleFormat('I')})
reserv1: list[int] = field(metadata={"format": SimpleFormat('2I')}, repr=False)
version: str = field(metadata={"format": FixedStrFormat(32)})
@ -168,3 +168,76 @@ class FirmwareAppDescription(StructType, existing_c_struct="esp_app_desc_t"):
idf_version: str = field(metadata={"format": FixedStrFormat(32)})
app_elf_sha256: str = field(metadata={"format": FixedHexFormat(32)})
reserv2: list[int] = field(metadata={"format": SimpleFormat('20I')}, repr=False)
@unique
class SPIFlashMode(EnumSchemaByNameMixin, IntEnum):
QIO = 0
QOUT = 1
DIO = 2
DOUT = 3
@unique
class FlashSize(EnumSchemaByNameMixin, IntEnum):
SIZE_1MB = 0
SIZE_2MB = 1
SIZE_4MB = 2
SIZE_8MB = 3
SIZE_16MB = 4
SIZE_32MB = 5
SIZE_64MB = 6
SIZE_128MB = 7
@unique
class FlashFrequency(EnumSchemaByNameMixin, IntEnum):
FREQ_40MHZ = 0
FREQ_26MHZ = 1
FREQ_20MHZ = 2
FREQ_80MHZ = 0xf
@dataclass
class FlashSettings:
size: FlashSize
frequency: FlashFrequency
@unique
class ChipType(EnumSchemaByNameMixin, IntEnum):
ESP32_S2 = 2
ESP32_C3 = 5
@property
def pretty_name(self):
return self.name.replace('_', '-')
@dataclass
class FirmwareImageFileHeader(StructType):
magic_word: int = field(metadata={"format": SimpleConstFormat('B', 0xE9)}, repr=False)
num_segments: int = field(metadata={"format": SimpleFormat('B')})
spi_flash_mode: SPIFlashMode = field(metadata={"format": EnumFormat()})
flash_stuff: FlashSettings = field(metadata={"format": TwoNibblesEnumFormat()})
entry_point: int = field(metadata={"format": SimpleFormat('I')})
@dataclass
class FirmwareImageExtendedFileHeader(StructType):
wp_pin: int = field(metadata={"format": SimpleFormat('B')})
drive_settings: int = field(metadata={"format": SimpleFormat('3B')})
chip_id: ChipType = field(metadata={"format": EnumFormat('H')})
min_chip_rev_old: int = field(metadata={"format": SimpleFormat('B')})
min_chip_rev: tuple[int, int] = field(metadata={"format": ChipRevFormat()})
max_chip_rev: tuple[int, int] = field(metadata={"format": ChipRevFormat()})
reserv: int = field(metadata={"format": SimpleFormat('I')}, repr=False)
hash_appended: bool = field(metadata={"format": BoolFormat()})
@dataclass
class FirmwareImage(StructType):
header: FirmwareImageFileHeader
ext_header: FirmwareImageExtendedFileHeader
first_segment_headers: tuple[int, int] = field(metadata={"format": SimpleFormat('2I')}, repr=False)
app_desc: FirmwareAppDescription

View file

@ -5,11 +5,10 @@ from typing import TypeVar
import channels
from channels.db import database_sync_to_async
from c3nav.api.utils import EnumSchemaByNameMixin
from c3nav.mesh.baseformats import (BoolFormat, EnumFormat, FixedStrFormat, SimpleFormat, StructType, VarArrayFormat,
VarBytesFormat, VarStrFormat, normalize_name)
from c3nav.mesh.dataformats import (BoardConfig, FirmwareAppDescription, MacAddressesListFormat, MacAddressFormat,
RangeResultItem, RawFTMEntry)
RangeResultItem, RawFTMEntry, ChipType)
from c3nav.mesh.utils import MESH_ALL_UPLINKS_GROUP
MESH_ROOT_ADDRESS = '00:00:00:00:00:00'
@ -68,16 +67,6 @@ class MeshMessageType(IntEnum):
M = TypeVar('M', bound='MeshMessage')
@unique
class ChipType(EnumSchemaByNameMixin, IntEnum):
ESP32_S2 = 2
ESP32_C3 = 5
@property
def pretty_name(self):
return self.name.replace('_', '-')
@dataclass
class MeshMessage(StructType, union_type_field="msg_type"):
dst: str = field(metadata={"format": MacAddressFormat()})

View file

@ -14,8 +14,8 @@ from django.utils import timezone
from django.utils.text import slugify
from django.utils.translation import gettext_lazy as _
from c3nav.mesh.dataformats import BoardType
from c3nav.mesh.messages import ChipType, ConfigFirmwareMessage, ConfigHardwareMessage
from c3nav.mesh.dataformats import BoardType, ChipType, FirmwareImage
from c3nav.mesh.messages import ConfigFirmwareMessage, ConfigHardwareMessage
from c3nav.mesh.messages import MeshMessage as MeshMessage
from c3nav.mesh.messages import MeshMessageType
from c3nav.mesh.utils import UPLINK_TIMEOUT
@ -410,6 +410,11 @@ class FirmwareBuild(models.Model):
for board in self.boards
]
@cached_property
def firmware_image(self) -> FirmwareImage:
firmware_image, remaining = FirmwareImage.decode(self.binary.open('rb').read()[:FirmwareImage.get_min_size()])
return firmware_image
class FirmwareBuildBoard(models.Model):
BOARDS = [(boardtype.name, boardtype.pretty_name) for boardtype in BoardType]

View file

@ -7,11 +7,10 @@ from ninja import Schema, UploadedFile
from ninja.pagination import paginate
from pydantic import validator
from c3nav.api.exceptions import APIConflict, APIRequestValidationFailed
from c3nav.api.exceptions import APIConflict, APIRequestValidationFailed, API404
from c3nav.api.newauth import BearerAuth, auth_permission_responses, auth_responses
from c3nav.mesh.dataformats import BoardType
from c3nav.mesh.messages import ChipType
from c3nav.mesh.models import FirmwareVersion
from c3nav.mesh.dataformats import BoardType, FirmwareImage, ChipType
from c3nav.mesh.models import FirmwareVersion, FirmwareBuild
api_router = APIRouter(tags=["mesh"])
@ -22,18 +21,14 @@ class FirmwareBuildSchema(Schema):
chip: ChipType = APIField(..., example=ChipType.ESP32_C3.name)
sha256_hash: str = APIField(..., regex=r"^[0-9a-f]{64}$")
url: str = APIField(..., alias="binary", example="/media/firmware/012345/firmware.bin")
boards: list[BoardType] = APIField(..., example=[BoardType.C3NAV_LOCATION_PCB_REV_0_2.name, ])
# todo: should not be none, but parse errors
boards: list[BoardType] = APIField(None, example=[BoardType.C3NAV_LOCATION_PCB_REV_0_2.name, ])
@staticmethod
def resolve_chip(obj):
# todo: do this in model? idk
return ChipType(obj.chip)
@staticmethod
def resolve_boards(obj):
print(obj.boards)
return obj.boards
class FirmwareSchema(Schema):
id: int
@ -62,12 +57,33 @@ def firmware_list(request):
@api_router.get('/firmwares/{firmware_id}/', summary="Get specific firmware",
response={200: FirmwareSchema, **auth_responses})
response={200: FirmwareSchema, **API404.dict(), **auth_responses})
def firmware_detail(request, firmware_id: int):
try:
return FirmwareVersion.objects.get(id=firmware_id)
except FirmwareVersion.DoesNotExist:
return 404, {"detail": "firmware not found"}
raise API404("Firmware not found")
@api_router.get('/firmwares/{firmware_id}/{variant}/image_data',
summary="Get header data of firmware build image",
response={200: FirmwareImage.schema, **API404.dict(), **auth_responses})
def firmware_build_image(request, firmware_id: int, variant: str):
try:
build = FirmwareBuild.objects.get(version_id=firmware_id, variant=variant)
return FirmwareImage.tojson(build.firmware_image)
except FirmwareVersion.DoesNotExist:
raise API404("Firmware or firmware build not found")
@api_router.get('/firmwares/{firmware_id}/{variant}/project_description',
summary="Get project description of firmware build",
response={200: dict, **API404.dict(), **auth_responses})
def firmware_project_description(request, firmware_id: int, variant: str):
try:
return FirmwareBuild.objects.get(version_id=firmware_id, variant=variant).firmware_description
except FirmwareVersion.DoesNotExist:
raise API404("Firmware or firmware build not found")
class UploadFirmwareBuildSchema(Schema):