From 502e251b5481c91d1a39e480c07e35212624a931 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laura=20Kl=C3=BCnder?= Date: Fri, 20 Oct 2023 16:22:32 +0200 Subject: [PATCH] fix generated c code --- src/c3nav/mesh/baseformats.py | 22 +++++++++- src/c3nav/mesh/messages.py | 80 ++++++++++++++++++++++++++++++++--- 2 files changed, 95 insertions(+), 7 deletions(-) diff --git a/src/c3nav/mesh/baseformats.py b/src/c3nav/mesh/baseformats.py index 638e18ae..e30e7069 100644 --- a/src/c3nav/mesh/baseformats.py +++ b/src/c3nav/mesh/baseformats.py @@ -86,9 +86,10 @@ class SimpleFormat(BaseFormat): class EnumFormat(SimpleFormat): - def __init__(self, fmt="B", *, as_hex=False): + def __init__(self, fmt="B", *, as_hex=False, c_definition=True): super().__init__(fmt) self.as_hex = as_hex + self.c_definition = c_definition def set_field_type(self, field_type): super().set_field_type(field_type) @@ -99,6 +100,8 @@ class EnumFormat(SimpleFormat): return self.field_type(value), out_data def get_c_parts(self): + if not self.c_definition: + return super().get_c_parts() return self.c_struct_name, "" def fromjson(self, data): @@ -108,6 +111,8 @@ class EnumFormat(SimpleFormat): return data.name def get_c_definitions(self) -> dict[str, str]: + if not self.c_definition: + return {} prefix = normalize_name(self.field_type.__name__).upper() options = [] last_value = None @@ -224,6 +229,21 @@ class VarStrFormat(BaseVarFormat): return super().get_num_c_code() + "\n" + "char", "[0]" +class VarBytesFormat(BaseVarFormat): + def get_var_num(self): + return 1 + + def encode(self, value: bytes) -> bytes: + return struct.pack(self.num_fmt, len(value)) + value + + def decode(self, data: bytes) -> tuple[bytes, bytes]: + num = struct.unpack(self.num_fmt, data[:self.num_size])[0] + return data[self.num_size:self.num_size + num].rstrip(bytes((0,))), data[self.num_size + num:] + + def get_c_parts(self): + return super().get_num_c_code() + "\n" + "uint8_t", "[0]" + + @dataclass class StructType: _union_options = {} diff --git a/src/c3nav/mesh/messages.py b/src/c3nav/mesh/messages.py index 23069922..3a3be1f7 100644 --- a/src/c3nav/mesh/messages.py +++ b/src/c3nav/mesh/messages.py @@ -5,8 +5,8 @@ from typing import TypeVar import channels from asgiref.sync import async_to_sync -from c3nav.mesh.baseformats import (BoolFormat, FixedStrFormat, SimpleFormat, StructType, VarArrayFormat, VarStrFormat, - normalize_name, EnumFormat) +from c3nav.mesh.baseformats import (BoolFormat, EnumFormat, FixedStrFormat, SimpleFormat, StructType, VarArrayFormat, + VarBytesFormat, VarStrFormat, normalize_name) from c3nav.mesh.dataformats import (BoardConfig, FirmwareAppDescription, MacAddressesListFormat, MacAddressFormat, RangeItemType) from c3nav.mesh.utils import get_mesh_comm_group @@ -42,8 +42,17 @@ class MeshMessageType(IntEnum): CONFIG_UPLINK = 0x14 CONFIG_POSITION = 0x15 - LOCATE_REQUEST_RANGE = 0x20 - LOCATE_RANGE_RESULTS = 0x21 + OTA_STATUS = 0x20 + OTA_REQUEST_STATUS = 0x21 + OTA_START = 0x22 + OTA_URL = 0x23 + OTA_FRAGMENT = 0x24 + OTA_REQUEST_FRAGMENT = 0x25 + OTA_APPLY = 0x26 + OTA_REBOOT = 0x27 + + LOCATE_REQUEST_RANGE = 0x30 + LOCATE_RANGE_RESULTS = 0x31 @property def pretty_name(self): @@ -71,7 +80,7 @@ class ChipType(IntEnum): class MeshMessage(StructType, union_type_field="msg_type"): dst: str = field(metadata={"format": MacAddressFormat()}) src: str = field(metadata={"format": MacAddressFormat()}) - msg_type: MeshMessageType = field(metadata={"format": EnumFormat('B')}, init=False, repr=False) + msg_type: MeshMessageType = field(metadata={"format": EnumFormat('B', c_definition=False)}, init=False, repr=False) c_structs = {} c_struct_name = None @@ -220,7 +229,7 @@ class ConfigDumpMessage(MeshMessage, msg_type=MeshMessageType.CONFIG_DUMP): class ConfigHardwareMessage(MeshMessage, msg_type=MeshMessageType.CONFIG_HARDWARE): """ respond hardware/chip info """ chip: ChipType = field(metadata={ - "format": EnumFormat("H"), + "format": EnumFormat("H", c_definition=False), "c_name": "chip_id", }) revision_major: int = field(metadata={"format": SimpleFormat('B')}) @@ -263,6 +272,65 @@ class ConfigUplinkMessage(MeshMessage, msg_type=MeshMessageType.CONFIG_UPLINK): port: int = field(metadata={"format": SimpleFormat('H')}) +@dataclass +class OTAStatusMessage(MeshMessage, msg_type=MeshMessageType.OTA_STATUS): + """ report OTA status """ + source: bool = field(metadata={"format": BoolFormat()}) + update_id: int = field(metadata={"format": SimpleFormat('I')}) + total_bytes: int = field(metadata={"format": SimpleFormat('I')}) + received_bytes: int = field(metadata={"format": SimpleFormat('I')}) + auto_apply: int = field(metadata={"format": SimpleFormat('I')}) + app_desc: FirmwareAppDescription = field() + + +@dataclass +class OTARequestStatusMessage(MeshMessage, msg_type=MeshMessageType.OTA_REQUEST_STATUS): + """ request OTA status """ + + +@dataclass +class OTAStartMessage(MeshMessage, msg_type=MeshMessageType.OTA_START): + """ instruct node to start OTA """ + update_id: int = field(metadata={"format": SimpleFormat('I')}) + total_bytes: int = field(metadata={"format": SimpleFormat('I')}) + + +@dataclass +class OTAURLMessage(MeshMessage, msg_type=MeshMessageType.OTA_URL): + """ supply download URL for OTA update and who to distribute it to """ + update_id: int = field(metadata={"format": SimpleFormat('I')}) + distribute_to: str = field(metadata={"format": MacAddressFormat()}) + url: str = field(metadata={"format": VarStrFormat()}) + + +@dataclass +class OTAFragmentMessage(MeshMessage, msg_type=MeshMessageType.OTA_FRAGMENT): + """ supply OTA fragment """ + update_id: int = field(metadata={"format": SimpleFormat('I')}) + offset_bytes: int = field(metadata={"format": SimpleFormat('I')}) + data: str = field(metadata={"format": VarBytesFormat()}) + + +@dataclass +class OTAFRequestMessage(MeshMessage, msg_type=MeshMessageType.OTA_REQUEST_FRAGMENT): + """ request fragment after we haven't gottan one for a while """ + update_id: int = field(metadata={"format": SimpleFormat('I')}) + offset_bytes: int = field(metadata={"format": SimpleFormat('I')}) + + +@dataclass +class OTAApplyMessage(MeshMessage, msg_type=MeshMessageType.OTA_APPLY): + """ apply OTA, optionally instruct to apply it when done """ + update_id: int = field(metadata={"format": SimpleFormat('I')}) + when_done: bool = field(metadata={"format": BoolFormat()}) + + +@dataclass +class OTARebootMessage(MeshMessage, msg_type=MeshMessageType.OTA_REBOOT): + """ announcing OTA reboot """ + update_id: int = field(metadata={"format": SimpleFormat('I')}) + + @dataclass class LocateRequestRangeMessage(MeshMessage, msg_type=MeshMessageType.LOCATE_REQUEST_RANGE): """ request to report distance to all nearby nodes """