From e0c9f36c668e8f37df6695112451a26413b25a1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laura=20Kl=C3=BCnder?= Date: Wed, 4 Oct 2023 22:25:15 +0200 Subject: [PATCH] auto-generate code for mesh_msg.h --- src/c3nav/mesh/consumers.py | 2 +- src/c3nav/mesh/dataformats.py | 70 ++++++ src/c3nav/mesh/forms.py | 1 - src/c3nav/mesh/management/__init__.py | 0 .../mesh/management/commands/__init__.py | 0 .../mesh/management/commands/mesh_msg_c.py | 78 +++++++ src/c3nav/mesh/messages.py | 215 ++++++++++++++---- src/c3nav/{control/views => mesh}/utils.py | 4 + 8 files changed, 325 insertions(+), 45 deletions(-) create mode 100644 src/c3nav/mesh/management/__init__.py create mode 100644 src/c3nav/mesh/management/commands/__init__.py create mode 100644 src/c3nav/mesh/management/commands/mesh_msg_c.py rename src/c3nav/{control/views => mesh}/utils.py (56%) diff --git a/src/c3nav/mesh/consumers.py b/src/c3nav/mesh/consumers.py index 2989c65e..1d3b9fe0 100644 --- a/src/c3nav/mesh/consumers.py +++ b/src/c3nav/mesh/consumers.py @@ -4,7 +4,7 @@ from asgiref.sync import async_to_sync from channels.generic.websocket import WebsocketConsumer from django.utils import timezone -from c3nav.control.views.utils import get_mesh_comm_group +from c3nav.mesh.utils import get_mesh_comm_group from c3nav.mesh import messages from c3nav.mesh.messages import MeshMessage, BROADCAST_ADDRESS from c3nav.mesh.models import MeshNode, NodeMessage diff --git a/src/c3nav/mesh/dataformats.py b/src/c3nav/mesh/dataformats.py index 40516950..0a3e8183 100644 --- a/src/c3nav/mesh/dataformats.py +++ b/src/c3nav/mesh/dataformats.py @@ -19,6 +19,23 @@ class SimpleFormat: value = value[0] return value, data[self.size:] + c_types = { + "B": "uint8_t", + "H": "uint16_t", + "I": "uint32_t", + "b": "int8_t", + "h": "int16_t", + "i": "int32_t", + } + + def get_c_struct(self, name): + c_type = self.c_types[self.fmt[-1]] + num = int(self.fmt[:-1]) if len(self.fmt) > 1 else 1 + if num == 1: + return "%s %s;" % (c_type, name) + else: + return "%s %s[%d];" % (c_type, name, num) + class FixedStrFormat: def __init__(self, num): @@ -30,6 +47,12 @@ class FixedStrFormat: def decode(self, data: bytes): return struct.unpack('%ss' % self.num, data[:self.num])[0].rstrip(bytes((0, ))).decode(), data[self.num:] + def get_c_struct(self, name): + return "char %(name)s[%(length)d];" % { + "name": name, + "length": self.num, + } + class BoolFormat: def encode(self, value): @@ -38,6 +61,11 @@ class BoolFormat: def decode(self, data: bytes): return bool(struct.unpack('B', data[:1])[0]), data[1:] + def get_c_struct(self, name): + return "uint8_t %(name)s;" % { + "name": name, + } + class HexFormat: def __init__(self, num, sep=''): @@ -53,14 +81,27 @@ class HexFormat: data[self.num:] ) + def get_c_struct(self, name): + return "uint8_t %(name)s[%(length)d];" % { + "name": name, + "length": self.num, + } + class VarStrFormat: + var_num = 1 + def encode(self, value: str) -> bytes: return bytes((len(value)+1, )) + value.encode() + bytes((0, )) def decode(self, data: bytes): return data[1:data[0]].decode(), data[data[0]+1:] + def get_c_struct(self, name): + return "uint8_t num;\nchar %(name)s[0];" % { + "name": name, + } + class MacAddressFormat: def encode(self, value: str) -> bytes: @@ -69,8 +110,15 @@ class MacAddressFormat: def decode(self, data: bytes): return (MAC_FMT % tuple(data[:6])), data[6:] + def get_c_struct(self, name): + return "uint8_t %(name)s[6];" % { + "name": name, + } + class MacAddressesListFormat: + var_num = 6 + def encode(self, value: list[str]) -> bytes: return bytes((len(value), )) + sum( (bytes((int(mac[i*3:i*3+2], 16) for i in range(6))) for mac in value), @@ -80,6 +128,11 @@ class MacAddressesListFormat: def decode(self, data: bytes): return [MAC_FMT % tuple(data[1+6*i:1+6+6*i]) for i in range(data[0])], data[1+data[0]*6:] + def get_c_struct(self, name): + return "uint8_t num;\nuint8_t %(name)s[6][0];" % { + "name": name, + } + class LedType(IntEnum): SERIAL = 1 @@ -138,3 +191,20 @@ class LedConfigFormat: else: raise ValueError return value, data[4:] + + def get_c_struct(self, name): + return ( + "uint8_t type;\n" + "union {\n" + " struct {\n" + " uint8_t gpio;\n" + " uint8_t rmt;\n" + " } serial;\n" + " struct {\n" + " uint8_t gpio_red;\n" + " uint8_t gpio_green;\n" + " uint8_t gpio_blue;\n" + " } multipin;\n" + " uint8_t bytes[3];\n" + "};" + ) diff --git a/src/c3nav/mesh/forms.py b/src/c3nav/mesh/forms.py index 6a386565..dc62d826 100644 --- a/src/c3nav/mesh/forms.py +++ b/src/c3nav/mesh/forms.py @@ -166,4 +166,3 @@ class ConfigPositionMessageForm(MeshMessageForm): x_pos = forms.IntegerField(min_value=0, max_value=2**16-1, label=_('X')) y_pos = forms.IntegerField(min_value=0, max_value=2 ** 16 - 1, label=_('Y')) z_pos = forms.IntegerField(min_value=0, max_value=2 ** 16 - 1, label=_('Z')) - diff --git a/src/c3nav/mesh/management/__init__.py b/src/c3nav/mesh/management/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/c3nav/mesh/management/commands/__init__.py b/src/c3nav/mesh/management/commands/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/c3nav/mesh/management/commands/mesh_msg_c.py b/src/c3nav/mesh/management/commands/mesh_msg_c.py new file mode 100644 index 00000000..36e8f6d6 --- /dev/null +++ b/src/c3nav/mesh/management/commands/mesh_msg_c.py @@ -0,0 +1,78 @@ +from django.core.management.base import BaseCommand + +from c3nav.mesh.messages import MeshMessage +from c3nav.mesh.utils import indent_c + + +class Command(BaseCommand): + help = 'export mesh message structs for c code' + + def handle(self, *args, **options): + done_struct_names = set() + nodata = set() + struct_lines = { + "num": "uint8_t num; /** defined here for convenience. subtypes that use it will define it themselves */" + } + + for msg_type in MeshMessage.msg_types.values(): + if msg_type.c_struct_name: + if msg_type.c_struct_name in done_struct_names: + continue + done_struct_names.add(msg_type.c_struct_name) + msg_type = MeshMessage.c_structs[msg_type.c_struct_name] + + code = msg_type.get_c_struct() + if code: + struct_lines[msg_type.get_c_struct_name()] = ( + "mesh_msg_%s_t %s;" % ( + msg_type.get_c_struct_name(), + msg_type.get_c_struct_name().replace("_announce", ""), + ) + ) + print(code) + print() + else: + nodata.add(msg_type) + + print("/** union between all message data structs */") + print("typedef union __packed {") + for line in struct_lines.values(): + print(indent_c(line)) + print("} mesh_msg_data_t;") + print() + + max_msg_type = max(MeshMessage.msg_types.keys()) + macro_data = [] + for i in range(((max_msg_type//16)+1)*16): + msg_type = MeshMessage.msg_types.get(i, None) + if msg_type: + macro_data.append(( + msg_type.get_c_enum_name()+',', + ("nodata" if msg_type in nodata else msg_type.get_c_struct_name())+',', + msg_type.get_var_num(), + msg_type.__doc__.strip(), + )) + else: + macro_data.append(( + "RESERVED_%02X," % i, + "nodata,", + 0, + "", + )) + + max0 = max(len(d[0]) for d in macro_data) + max1 = max(len(d[1]) for d in macro_data) + max2 = max(len(str(d[2])) for d in macro_data) + lines = [] + for i, (macro_name, struct_name, num_len, comment) in enumerate(macro_data): + lines.append(indent_c( + "FN(%s %s %s) /** 0x%02X %s*/" % ( + macro_name.ljust(max0), + struct_name.ljust(max1), + str(num_len).rjust(max2), + i, + comment+(" " if comment else ""), + ) + )) + print("#define FOR_ALL_MESH_MSG_TYPES(FN) \\") + print(" \\\n".join(lines)) diff --git a/src/c3nav/mesh/messages.py b/src/c3nav/mesh/messages.py index d28136e0..24b7f597 100644 --- a/src/c3nav/mesh/messages.py +++ b/src/c3nav/mesh/messages.py @@ -1,11 +1,13 @@ +import re from dataclasses import asdict, dataclass, field, fields, is_dataclass from enum import IntEnum, unique +from itertools import chain from typing import TypeVar import channels from asgiref.sync import async_to_sync -from c3nav.control.views.utils import get_mesh_comm_group +from c3nav.mesh.utils import get_mesh_comm_group, indent_c from c3nav.mesh.dataformats import (BoolFormat, FixedStrFormat, HexFormat, LedConfig, LedConfigFormat, MacAddressesListFormat, MacAddressFormat, SimpleFormat, VarStrFormat) @@ -17,6 +19,8 @@ NO_LAYER = 0xFF @unique class MeshMessageType(IntEnum): + NOOP = 0x00 + ECHO_REQUEST = 0x01 ECHO_RESPONSE = 0x02 @@ -32,7 +36,7 @@ class MeshMessageType(IntEnum): CONFIG_UPLINK = 0x14 -M = TypeVar('M', bound='Message') +M = TypeVar('M', bound='MeshMessage') @unique @@ -43,24 +47,31 @@ class ChipType(IntEnum): @dataclass class MeshMessage: - dst: str = field(metadata={'format': MacAddressFormat()}) - src: str = field(metadata={'format': MacAddressFormat()}) - msg_id: int = field(metadata={'format': SimpleFormat('B')}, init=False, repr=False) + dst: str = field(metadata={"format": MacAddressFormat()}) + src: str = field(metadata={"format": MacAddressFormat()}) + msg_id: int = field(metadata={"format": SimpleFormat('B')}, init=False, repr=False) msg_types = {} + c_structs = {} + c_struct_name = None # noinspection PyMethodOverriding - def __init_subclass__(cls, /, msg_id=None, **kwargs): + def __init_subclass__(cls, /, msg_id=None, c_struct_name=None, **kwargs): super().__init_subclass__(**kwargs) - if msg_id: + if msg_id is not None: cls.msg_id = msg_id if msg_id in MeshMessage.msg_types: raise TypeError('duplicate use of msg_id %d' % msg_id) MeshMessage.msg_types[msg_id] = cls + if c_struct_name: + cls.c_struct_name = c_struct_name + if c_struct_name in MeshMessage.c_structs: + raise TypeError('duplicate use of c_struct_name %s' % c_struct_name) + MeshMessage.c_structs[c_struct_name] = cls def encode(self): data = bytes() for field_ in fields(self): - data += field_.metadata['format'].encode(getattr(self, field_.name)) + data += field_.metadata["format"].encode(getattr(self, field_.name)) return data @classmethod @@ -68,7 +79,7 @@ class MeshMessage: klass = cls.msg_types[data[12]] values = {} for field_ in fields(klass): - values[field_.name], data = field_.metadata['format'].decode(data) + values[field_.name], data = field_.metadata["format"].decode(data) values.pop('msg_id') return klass(**values) @@ -90,56 +101,160 @@ class MeshMessage: "msg": self.tojson() }) + @classmethod + def get_ignore_c_fields(self): + return set() -@dataclass -class EchoRequestMessage(MeshMessage, msg_id=MeshMessageType.ECHO_REQUEST): - content: str = field(default='', metadata={'format': VarStrFormat()}) + @classmethod + def get_additional_c_fields(self): + return () + + @classmethod + def get_c_struct(cls): + ignore_fields = cls.get_ignore_c_fields() + if cls != MeshMessage: + ignore_fields |= set(field.name for field in fields(MeshMessage)) + + items = tuple( + ( + tuple(field.metadata["format"].get_c_struct(field.metadata.get("c_name", field.name)).split("\n")), + field.metadata.get("doc", None), + ) + for field in fields(cls) + if field.name not in ignore_fields + ) + if not items: + return "" + max_line_len = max(len(line) for line in chain(*(code for code, doc in items))) + + msg_comment = cls.__doc__.strip() + + return "%(comment)stypedef struct __packed {\n%(elements)s\n} %(name)s;" % { + "comment": ("/** %s */\n" % msg_comment) if msg_comment else "", + "elements": indent_c( + "\n".join(chain(*( + (code if not comment + else (code[:-1]+("%s /** %s */" % (code[-1].ljust(max_line_len), comment),))) + for code, comment in items + ), cls.get_additional_c_fields())) + ), + "name": "mesh_msg_%s_t" % cls.get_c_struct_name(), + } + + @classmethod + def get_var_num(cls): + return sum((getattr(field.metadata["format"], "var_num", 0) for field in fields(cls)), start=0) + + @classmethod + def get_c_struct_name(cls): + return ( + cls.c_struct_name if cls.c_struct_name else + re.sub( + r"([a-z])([A-Z])", + r"\1_\2", + cls.__name__.removeprefix('Mesh').removesuffix('Message') + ).lower().replace('config', 'cfg').replace('firmware', 'fw').replace('position', 'pos') + ) + + @classmethod + def get_c_enum_name(cls): + return re.sub( + r"([a-z])([A-Z])", + r"\1_\2", + cls.__name__.removeprefix('Mesh').removesuffix('Message') + ).upper().replace('CONFIG', 'CFG').replace('FIRMWARE', 'FW').replace('POSITION', 'POS') @dataclass -class EchoResponseMessage(MeshMessage, msg_id=MeshMessageType.ECHO_RESPONSE): - content: str = field(default='', metadata={'format': VarStrFormat()}) +class NoopMessage(MeshMessage, msg_id=MeshMessageType.NOOP): + """ noop """ + pass + + +@dataclass +class BaseEchoMessage(MeshMessage, c_struct_name="echo"): + """ repeat back string """ + content: str = field(default='', metadata={ + "format": VarStrFormat(), + "doc": "string to echo", + "c_name": "str", + }) + + +@dataclass +class EchoRequestMessage(BaseEchoMessage, msg_id=MeshMessageType.ECHO_REQUEST): + """ repeat back string """ + pass + + +@dataclass +class EchoResponseMessage(BaseEchoMessage, msg_id=MeshMessageType.ECHO_RESPONSE): + """ repeat back string """ + pass @dataclass class MeshSigninMessage(MeshMessage, msg_id=MeshMessageType.MESH_SIGNIN): + """ node says hello to upstream node """ pass @dataclass class MeshLayerAnnounceMessage(MeshMessage, msg_id=MeshMessageType.MESH_LAYER_ANNOUNCE): - layer: int = field(metadata={'format': SimpleFormat('B')}) + """ upstream node announces layer number """ + layer: int = field(metadata={ + "format": SimpleFormat('B'), + "doc": "mesh layer that the sending node is on", + }) @dataclass -class MeshAddDestinationsMessage(MeshMessage, msg_id=MeshMessageType.MESH_ADD_DESTINATIONS): - mac_addresses: list[str] = field(default_factory=list, metadata={'format': MacAddressesListFormat()}) +class BaseDestinationsMessage(MeshMessage, c_struct_name="destinations"): + """ downstream node announces served/no longer served destination """ + mac_addresses: list[str] = field(default_factory=list, metadata={ + "format": MacAddressesListFormat(), + "doc": "mac adresses of the destinations", + "c_name": "mac", + }) @dataclass -class MeshRemoveDestinationsMessage(MeshMessage, msg_id=MeshMessageType.MESH_REMOVE_DESTINATIONS): - mac_addresses: list[str] = field(default_factory=list, metadata={'format': MacAddressesListFormat()}) +class MeshAddDestinationsMessage(BaseDestinationsMessage, msg_id=MeshMessageType.MESH_ADD_DESTINATIONS): + """ downstream node announces served destination """ + pass + + +@dataclass +class MeshRemoveDestinationsMessage(BaseDestinationsMessage, msg_id=MeshMessageType.MESH_REMOVE_DESTINATIONS): + """ downstream node announces no longer served destination """ + pass @dataclass class ConfigDumpMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_DUMP): + """ request for the node to dump its config """ pass @dataclass class ConfigFirmwareMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_FIRMWARE): - chip: int = field(metadata={'format': SimpleFormat('H')}) - revision: int = field(metadata={'format': SimpleFormat('2B')}) - magic_word: int = field(metadata={'format': SimpleFormat('I')}, repr=False) - secure_version: int = field(metadata={'format': SimpleFormat('I')}) - reserv1: list[int] = field(metadata={'format': SimpleFormat('2I')}, repr=False) - version: str = field(metadata={'format': FixedStrFormat(32)}) - project_name: str = field(metadata={'format': FixedStrFormat(32)}) - compile_time: str = field(metadata={'format': FixedStrFormat(16)}) - compile_date: str = field(metadata={'format': FixedStrFormat(16)}) - idf_version: str = field(metadata={'format': FixedStrFormat(32)}) - app_elf_sha256: str = field(metadata={'format': HexFormat(32)}) - reserv2: list[int] = field(metadata={'format': SimpleFormat('20I')}, repr=False) + """ respond firmware info """ + chip: int = field(metadata={ + "format": SimpleFormat('H'), + "c_name": "chip_id", + }) + revision_major: int = field(metadata={"format": SimpleFormat('B')}) + revision_minor: int = field(metadata={"format": SimpleFormat('B')}) + magic_word: int = field(metadata={"format": SimpleFormat('I')}, repr=False) + secure_version: int = field(metadata={"format": SimpleFormat('I')}) + reserv1: list[int] = field(metadata={"format": SimpleFormat('2I')}, repr=False) + version: str = field(metadata={"format": FixedStrFormat(32)}) + project_name: str = field(metadata={"format": FixedStrFormat(32)}) + compile_time: str = field(metadata={"format": FixedStrFormat(16)}) + compile_date: str = field(metadata={"format": FixedStrFormat(16)}) + idf_version: str = field(metadata={"format": FixedStrFormat(32)}) + app_elf_sha256: str = field(metadata={"format": HexFormat(32)}) + reserv2: list[int] = field(metadata={"format": SimpleFormat('20I')}, repr=False) def to_model_data(self): return { @@ -153,26 +268,40 @@ class ConfigFirmwareMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_FIRMWARE) def get_chip_display(self): return ChipType(self.chip).name.replace('_', '-') + @classmethod + def get_ignore_c_fields(self): + return { + "magic_word", "secure_version", "reserv1", "version", "project_name", + "compile_time", "compile_date", "idf_version", "app_elf_sha256", "reserv2" + } + + @classmethod + def get_additional_c_fields(self): + return ("esp_app_desc_t app_desc;", ) + @dataclass class ConfigPositionMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_POSITION): - x_pos: int = field(metadata={'format': SimpleFormat('I')}) - y_pos: int = field(metadata={'format': SimpleFormat('I')}) - z_pos: int = field(metadata={'format': SimpleFormat('H')}) + """ set/respond position config """ + x_pos: int = field(metadata={"format": SimpleFormat('i')}) + y_pos: int = field(metadata={"format": SimpleFormat('i')}) + z_pos: int = field(metadata={"format": SimpleFormat('h')}) @dataclass class ConfigLedMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_LED): - led_config: LedConfig = field(metadata={'format': LedConfigFormat()}) + """ set/respond led config """ + led_config: LedConfig = field(metadata={"format": LedConfigFormat()}) @dataclass class ConfigUplinkMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_UPLINK): - enabled: bool = field(metadata={'format': BoolFormat()}) - ssid: str = field(metadata={'format': FixedStrFormat(32)}) - password: str = field(metadata={'format': FixedStrFormat(64)}) - channel: int = field(metadata={'format': SimpleFormat('B')}) - udp: bool = field(metadata={'format': BoolFormat()}) - ssl: bool = field(metadata={'format': BoolFormat()}) - host: str = field(metadata={'format': FixedStrFormat(64)}) - port: int = field(metadata={'format': SimpleFormat('H')}) + """ set/respond uplink config """ + enabled: bool = field(metadata={"format": BoolFormat()}) + ssid: str = field(metadata={"format": FixedStrFormat(32)}) + password: str = field(metadata={"format": FixedStrFormat(64)}) + channel: int = field(metadata={"format": SimpleFormat('B')}) + udp: bool = field(metadata={"format": BoolFormat()}) + ssl: bool = field(metadata={"format": BoolFormat()}) + host: str = field(metadata={"format": FixedStrFormat(64)}) + port: int = field(metadata={"format": SimpleFormat('H')}) diff --git a/src/c3nav/control/views/utils.py b/src/c3nav/mesh/utils.py similarity index 56% rename from src/c3nav/control/views/utils.py rename to src/c3nav/mesh/utils.py index 061759b3..1408f6a5 100644 --- a/src/c3nav/control/views/utils.py +++ b/src/c3nav/mesh/utils.py @@ -1,2 +1,6 @@ def get_mesh_comm_group(address): return 'mesh_comm_%s' % address.replace(':', '-') + + +def indent_c(code): + return " "+code.replace("\n", "\n ")