auto-generate code for mesh_msg.h

This commit is contained in:
Laura Klünder 2023-10-04 22:25:15 +02:00
parent e0961c16c1
commit e0c9f36c66
8 changed files with 325 additions and 45 deletions

View file

@ -4,7 +4,7 @@ from asgiref.sync import async_to_sync
from channels.generic.websocket import WebsocketConsumer
from django.utils import timezone
from c3nav.control.views.utils import get_mesh_comm_group
from c3nav.mesh.utils import get_mesh_comm_group
from c3nav.mesh import messages
from c3nav.mesh.messages import MeshMessage, BROADCAST_ADDRESS
from c3nav.mesh.models import MeshNode, NodeMessage

View file

@ -19,6 +19,23 @@ class SimpleFormat:
value = value[0]
return value, data[self.size:]
c_types = {
"B": "uint8_t",
"H": "uint16_t",
"I": "uint32_t",
"b": "int8_t",
"h": "int16_t",
"i": "int32_t",
}
def get_c_struct(self, name):
c_type = self.c_types[self.fmt[-1]]
num = int(self.fmt[:-1]) if len(self.fmt) > 1 else 1
if num == 1:
return "%s %s;" % (c_type, name)
else:
return "%s %s[%d];" % (c_type, name, num)
class FixedStrFormat:
def __init__(self, num):
@ -30,6 +47,12 @@ class FixedStrFormat:
def decode(self, data: bytes):
return struct.unpack('%ss' % self.num, data[:self.num])[0].rstrip(bytes((0, ))).decode(), data[self.num:]
def get_c_struct(self, name):
return "char %(name)s[%(length)d];" % {
"name": name,
"length": self.num,
}
class BoolFormat:
def encode(self, value):
@ -38,6 +61,11 @@ class BoolFormat:
def decode(self, data: bytes):
return bool(struct.unpack('B', data[:1])[0]), data[1:]
def get_c_struct(self, name):
return "uint8_t %(name)s;" % {
"name": name,
}
class HexFormat:
def __init__(self, num, sep=''):
@ -53,14 +81,27 @@ class HexFormat:
data[self.num:]
)
def get_c_struct(self, name):
return "uint8_t %(name)s[%(length)d];" % {
"name": name,
"length": self.num,
}
class VarStrFormat:
var_num = 1
def encode(self, value: str) -> bytes:
return bytes((len(value)+1, )) + value.encode() + bytes((0, ))
def decode(self, data: bytes):
return data[1:data[0]].decode(), data[data[0]+1:]
def get_c_struct(self, name):
return "uint8_t num;\nchar %(name)s[0];" % {
"name": name,
}
class MacAddressFormat:
def encode(self, value: str) -> bytes:
@ -69,8 +110,15 @@ class MacAddressFormat:
def decode(self, data: bytes):
return (MAC_FMT % tuple(data[:6])), data[6:]
def get_c_struct(self, name):
return "uint8_t %(name)s[6];" % {
"name": name,
}
class MacAddressesListFormat:
var_num = 6
def encode(self, value: list[str]) -> bytes:
return bytes((len(value), )) + sum(
(bytes((int(mac[i*3:i*3+2], 16) for i in range(6))) for mac in value),
@ -80,6 +128,11 @@ class MacAddressesListFormat:
def decode(self, data: bytes):
return [MAC_FMT % tuple(data[1+6*i:1+6+6*i]) for i in range(data[0])], data[1+data[0]*6:]
def get_c_struct(self, name):
return "uint8_t num;\nuint8_t %(name)s[6][0];" % {
"name": name,
}
class LedType(IntEnum):
SERIAL = 1
@ -138,3 +191,20 @@ class LedConfigFormat:
else:
raise ValueError
return value, data[4:]
def get_c_struct(self, name):
return (
"uint8_t type;\n"
"union {\n"
" struct {\n"
" uint8_t gpio;\n"
" uint8_t rmt;\n"
" } serial;\n"
" struct {\n"
" uint8_t gpio_red;\n"
" uint8_t gpio_green;\n"
" uint8_t gpio_blue;\n"
" } multipin;\n"
" uint8_t bytes[3];\n"
"};"
)

View file

@ -166,4 +166,3 @@ class ConfigPositionMessageForm(MeshMessageForm):
x_pos = forms.IntegerField(min_value=0, max_value=2**16-1, label=_('X'))
y_pos = forms.IntegerField(min_value=0, max_value=2 ** 16 - 1, label=_('Y'))
z_pos = forms.IntegerField(min_value=0, max_value=2 ** 16 - 1, label=_('Z'))

View file

View file

@ -0,0 +1,78 @@
from django.core.management.base import BaseCommand
from c3nav.mesh.messages import MeshMessage
from c3nav.mesh.utils import indent_c
class Command(BaseCommand):
help = 'export mesh message structs for c code'
def handle(self, *args, **options):
done_struct_names = set()
nodata = set()
struct_lines = {
"num": "uint8_t num; /** defined here for convenience. subtypes that use it will define it themselves */"
}
for msg_type in MeshMessage.msg_types.values():
if msg_type.c_struct_name:
if msg_type.c_struct_name in done_struct_names:
continue
done_struct_names.add(msg_type.c_struct_name)
msg_type = MeshMessage.c_structs[msg_type.c_struct_name]
code = msg_type.get_c_struct()
if code:
struct_lines[msg_type.get_c_struct_name()] = (
"mesh_msg_%s_t %s;" % (
msg_type.get_c_struct_name(),
msg_type.get_c_struct_name().replace("_announce", ""),
)
)
print(code)
print()
else:
nodata.add(msg_type)
print("/** union between all message data structs */")
print("typedef union __packed {")
for line in struct_lines.values():
print(indent_c(line))
print("} mesh_msg_data_t;")
print()
max_msg_type = max(MeshMessage.msg_types.keys())
macro_data = []
for i in range(((max_msg_type//16)+1)*16):
msg_type = MeshMessage.msg_types.get(i, None)
if msg_type:
macro_data.append((
msg_type.get_c_enum_name()+',',
("nodata" if msg_type in nodata else msg_type.get_c_struct_name())+',',
msg_type.get_var_num(),
msg_type.__doc__.strip(),
))
else:
macro_data.append((
"RESERVED_%02X," % i,
"nodata,",
0,
"",
))
max0 = max(len(d[0]) for d in macro_data)
max1 = max(len(d[1]) for d in macro_data)
max2 = max(len(str(d[2])) for d in macro_data)
lines = []
for i, (macro_name, struct_name, num_len, comment) in enumerate(macro_data):
lines.append(indent_c(
"FN(%s %s %s) /** 0x%02X %s*/" % (
macro_name.ljust(max0),
struct_name.ljust(max1),
str(num_len).rjust(max2),
i,
comment+(" " if comment else ""),
)
))
print("#define FOR_ALL_MESH_MSG_TYPES(FN) \\")
print(" \\\n".join(lines))

View file

@ -1,11 +1,13 @@
import re
from dataclasses import asdict, dataclass, field, fields, is_dataclass
from enum import IntEnum, unique
from itertools import chain
from typing import TypeVar
import channels
from asgiref.sync import async_to_sync
from c3nav.control.views.utils import get_mesh_comm_group
from c3nav.mesh.utils import get_mesh_comm_group, indent_c
from c3nav.mesh.dataformats import (BoolFormat, FixedStrFormat, HexFormat, LedConfig, LedConfigFormat,
MacAddressesListFormat, MacAddressFormat, SimpleFormat, VarStrFormat)
@ -17,6 +19,8 @@ NO_LAYER = 0xFF
@unique
class MeshMessageType(IntEnum):
NOOP = 0x00
ECHO_REQUEST = 0x01
ECHO_RESPONSE = 0x02
@ -32,7 +36,7 @@ class MeshMessageType(IntEnum):
CONFIG_UPLINK = 0x14
M = TypeVar('M', bound='Message')
M = TypeVar('M', bound='MeshMessage')
@unique
@ -43,24 +47,31 @@ class ChipType(IntEnum):
@dataclass
class MeshMessage:
dst: str = field(metadata={'format': MacAddressFormat()})
src: str = field(metadata={'format': MacAddressFormat()})
msg_id: int = field(metadata={'format': SimpleFormat('B')}, init=False, repr=False)
dst: str = field(metadata={"format": MacAddressFormat()})
src: str = field(metadata={"format": MacAddressFormat()})
msg_id: int = field(metadata={"format": SimpleFormat('B')}, init=False, repr=False)
msg_types = {}
c_structs = {}
c_struct_name = None
# noinspection PyMethodOverriding
def __init_subclass__(cls, /, msg_id=None, **kwargs):
def __init_subclass__(cls, /, msg_id=None, c_struct_name=None, **kwargs):
super().__init_subclass__(**kwargs)
if msg_id:
if msg_id is not None:
cls.msg_id = msg_id
if msg_id in MeshMessage.msg_types:
raise TypeError('duplicate use of msg_id %d' % msg_id)
MeshMessage.msg_types[msg_id] = cls
if c_struct_name:
cls.c_struct_name = c_struct_name
if c_struct_name in MeshMessage.c_structs:
raise TypeError('duplicate use of c_struct_name %s' % c_struct_name)
MeshMessage.c_structs[c_struct_name] = cls
def encode(self):
data = bytes()
for field_ in fields(self):
data += field_.metadata['format'].encode(getattr(self, field_.name))
data += field_.metadata["format"].encode(getattr(self, field_.name))
return data
@classmethod
@ -68,7 +79,7 @@ class MeshMessage:
klass = cls.msg_types[data[12]]
values = {}
for field_ in fields(klass):
values[field_.name], data = field_.metadata['format'].decode(data)
values[field_.name], data = field_.metadata["format"].decode(data)
values.pop('msg_id')
return klass(**values)
@ -90,56 +101,160 @@ class MeshMessage:
"msg": self.tojson()
})
@classmethod
def get_ignore_c_fields(self):
return set()
@dataclass
class EchoRequestMessage(MeshMessage, msg_id=MeshMessageType.ECHO_REQUEST):
content: str = field(default='', metadata={'format': VarStrFormat()})
@classmethod
def get_additional_c_fields(self):
return ()
@classmethod
def get_c_struct(cls):
ignore_fields = cls.get_ignore_c_fields()
if cls != MeshMessage:
ignore_fields |= set(field.name for field in fields(MeshMessage))
items = tuple(
(
tuple(field.metadata["format"].get_c_struct(field.metadata.get("c_name", field.name)).split("\n")),
field.metadata.get("doc", None),
)
for field in fields(cls)
if field.name not in ignore_fields
)
if not items:
return ""
max_line_len = max(len(line) for line in chain(*(code for code, doc in items)))
msg_comment = cls.__doc__.strip()
return "%(comment)stypedef struct __packed {\n%(elements)s\n} %(name)s;" % {
"comment": ("/** %s */\n" % msg_comment) if msg_comment else "",
"elements": indent_c(
"\n".join(chain(*(
(code if not comment
else (code[:-1]+("%s /** %s */" % (code[-1].ljust(max_line_len), comment),)))
for code, comment in items
), cls.get_additional_c_fields()))
),
"name": "mesh_msg_%s_t" % cls.get_c_struct_name(),
}
@classmethod
def get_var_num(cls):
return sum((getattr(field.metadata["format"], "var_num", 0) for field in fields(cls)), start=0)
@classmethod
def get_c_struct_name(cls):
return (
cls.c_struct_name if cls.c_struct_name else
re.sub(
r"([a-z])([A-Z])",
r"\1_\2",
cls.__name__.removeprefix('Mesh').removesuffix('Message')
).lower().replace('config', 'cfg').replace('firmware', 'fw').replace('position', 'pos')
)
@classmethod
def get_c_enum_name(cls):
return re.sub(
r"([a-z])([A-Z])",
r"\1_\2",
cls.__name__.removeprefix('Mesh').removesuffix('Message')
).upper().replace('CONFIG', 'CFG').replace('FIRMWARE', 'FW').replace('POSITION', 'POS')
@dataclass
class EchoResponseMessage(MeshMessage, msg_id=MeshMessageType.ECHO_RESPONSE):
content: str = field(default='', metadata={'format': VarStrFormat()})
class NoopMessage(MeshMessage, msg_id=MeshMessageType.NOOP):
""" noop """
pass
@dataclass
class BaseEchoMessage(MeshMessage, c_struct_name="echo"):
""" repeat back string """
content: str = field(default='', metadata={
"format": VarStrFormat(),
"doc": "string to echo",
"c_name": "str",
})
@dataclass
class EchoRequestMessage(BaseEchoMessage, msg_id=MeshMessageType.ECHO_REQUEST):
""" repeat back string """
pass
@dataclass
class EchoResponseMessage(BaseEchoMessage, msg_id=MeshMessageType.ECHO_RESPONSE):
""" repeat back string """
pass
@dataclass
class MeshSigninMessage(MeshMessage, msg_id=MeshMessageType.MESH_SIGNIN):
""" node says hello to upstream node """
pass
@dataclass
class MeshLayerAnnounceMessage(MeshMessage, msg_id=MeshMessageType.MESH_LAYER_ANNOUNCE):
layer: int = field(metadata={'format': SimpleFormat('B')})
""" upstream node announces layer number """
layer: int = field(metadata={
"format": SimpleFormat('B'),
"doc": "mesh layer that the sending node is on",
})
@dataclass
class MeshAddDestinationsMessage(MeshMessage, msg_id=MeshMessageType.MESH_ADD_DESTINATIONS):
mac_addresses: list[str] = field(default_factory=list, metadata={'format': MacAddressesListFormat()})
class BaseDestinationsMessage(MeshMessage, c_struct_name="destinations"):
""" downstream node announces served/no longer served destination """
mac_addresses: list[str] = field(default_factory=list, metadata={
"format": MacAddressesListFormat(),
"doc": "mac adresses of the destinations",
"c_name": "mac",
})
@dataclass
class MeshRemoveDestinationsMessage(MeshMessage, msg_id=MeshMessageType.MESH_REMOVE_DESTINATIONS):
mac_addresses: list[str] = field(default_factory=list, metadata={'format': MacAddressesListFormat()})
class MeshAddDestinationsMessage(BaseDestinationsMessage, msg_id=MeshMessageType.MESH_ADD_DESTINATIONS):
""" downstream node announces served destination """
pass
@dataclass
class MeshRemoveDestinationsMessage(BaseDestinationsMessage, msg_id=MeshMessageType.MESH_REMOVE_DESTINATIONS):
""" downstream node announces no longer served destination """
pass
@dataclass
class ConfigDumpMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_DUMP):
""" request for the node to dump its config """
pass
@dataclass
class ConfigFirmwareMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_FIRMWARE):
chip: int = field(metadata={'format': SimpleFormat('H')})
revision: int = field(metadata={'format': SimpleFormat('2B')})
magic_word: int = field(metadata={'format': SimpleFormat('I')}, repr=False)
secure_version: int = field(metadata={'format': SimpleFormat('I')})
reserv1: list[int] = field(metadata={'format': SimpleFormat('2I')}, repr=False)
version: str = field(metadata={'format': FixedStrFormat(32)})
project_name: str = field(metadata={'format': FixedStrFormat(32)})
compile_time: str = field(metadata={'format': FixedStrFormat(16)})
compile_date: str = field(metadata={'format': FixedStrFormat(16)})
idf_version: str = field(metadata={'format': FixedStrFormat(32)})
app_elf_sha256: str = field(metadata={'format': HexFormat(32)})
reserv2: list[int] = field(metadata={'format': SimpleFormat('20I')}, repr=False)
""" respond firmware info """
chip: int = field(metadata={
"format": SimpleFormat('H'),
"c_name": "chip_id",
})
revision_major: int = field(metadata={"format": SimpleFormat('B')})
revision_minor: int = field(metadata={"format": SimpleFormat('B')})
magic_word: int = field(metadata={"format": SimpleFormat('I')}, repr=False)
secure_version: int = field(metadata={"format": SimpleFormat('I')})
reserv1: list[int] = field(metadata={"format": SimpleFormat('2I')}, repr=False)
version: str = field(metadata={"format": FixedStrFormat(32)})
project_name: str = field(metadata={"format": FixedStrFormat(32)})
compile_time: str = field(metadata={"format": FixedStrFormat(16)})
compile_date: str = field(metadata={"format": FixedStrFormat(16)})
idf_version: str = field(metadata={"format": FixedStrFormat(32)})
app_elf_sha256: str = field(metadata={"format": HexFormat(32)})
reserv2: list[int] = field(metadata={"format": SimpleFormat('20I')}, repr=False)
def to_model_data(self):
return {
@ -153,26 +268,40 @@ class ConfigFirmwareMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_FIRMWARE)
def get_chip_display(self):
return ChipType(self.chip).name.replace('_', '-')
@classmethod
def get_ignore_c_fields(self):
return {
"magic_word", "secure_version", "reserv1", "version", "project_name",
"compile_time", "compile_date", "idf_version", "app_elf_sha256", "reserv2"
}
@classmethod
def get_additional_c_fields(self):
return ("esp_app_desc_t app_desc;", )
@dataclass
class ConfigPositionMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_POSITION):
x_pos: int = field(metadata={'format': SimpleFormat('I')})
y_pos: int = field(metadata={'format': SimpleFormat('I')})
z_pos: int = field(metadata={'format': SimpleFormat('H')})
""" set/respond position config """
x_pos: int = field(metadata={"format": SimpleFormat('i')})
y_pos: int = field(metadata={"format": SimpleFormat('i')})
z_pos: int = field(metadata={"format": SimpleFormat('h')})
@dataclass
class ConfigLedMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_LED):
led_config: LedConfig = field(metadata={'format': LedConfigFormat()})
""" set/respond led config """
led_config: LedConfig = field(metadata={"format": LedConfigFormat()})
@dataclass
class ConfigUplinkMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_UPLINK):
enabled: bool = field(metadata={'format': BoolFormat()})
ssid: str = field(metadata={'format': FixedStrFormat(32)})
password: str = field(metadata={'format': FixedStrFormat(64)})
channel: int = field(metadata={'format': SimpleFormat('B')})
udp: bool = field(metadata={'format': BoolFormat()})
ssl: bool = field(metadata={'format': BoolFormat()})
host: str = field(metadata={'format': FixedStrFormat(64)})
port: int = field(metadata={'format': SimpleFormat('H')})
""" set/respond uplink config """
enabled: bool = field(metadata={"format": BoolFormat()})
ssid: str = field(metadata={"format": FixedStrFormat(32)})
password: str = field(metadata={"format": FixedStrFormat(64)})
channel: int = field(metadata={"format": SimpleFormat('B')})
udp: bool = field(metadata={"format": BoolFormat()})
ssl: bool = field(metadata={"format": BoolFormat()})
host: str = field(metadata={"format": FixedStrFormat(64)})
port: int = field(metadata={"format": SimpleFormat('H')})

View file

@ -1,2 +1,6 @@
def get_mesh_comm_group(address):
return 'mesh_comm_%s' % address.replace(':', '-')
def indent_c(code):
return " "+code.replace("\n", "\n ")