auto-generate code for mesh_msg.h

This commit is contained in:
Laura Klünder 2023-10-04 22:25:15 +02:00
parent e0961c16c1
commit e0c9f36c66
8 changed files with 325 additions and 45 deletions

View file

@ -4,7 +4,7 @@ from asgiref.sync import async_to_sync
from channels.generic.websocket import WebsocketConsumer from channels.generic.websocket import WebsocketConsumer
from django.utils import timezone from django.utils import timezone
from c3nav.control.views.utils import get_mesh_comm_group from c3nav.mesh.utils import get_mesh_comm_group
from c3nav.mesh import messages from c3nav.mesh import messages
from c3nav.mesh.messages import MeshMessage, BROADCAST_ADDRESS from c3nav.mesh.messages import MeshMessage, BROADCAST_ADDRESS
from c3nav.mesh.models import MeshNode, NodeMessage from c3nav.mesh.models import MeshNode, NodeMessage

View file

@ -19,6 +19,23 @@ class SimpleFormat:
value = value[0] value = value[0]
return value, data[self.size:] return value, data[self.size:]
c_types = {
"B": "uint8_t",
"H": "uint16_t",
"I": "uint32_t",
"b": "int8_t",
"h": "int16_t",
"i": "int32_t",
}
def get_c_struct(self, name):
c_type = self.c_types[self.fmt[-1]]
num = int(self.fmt[:-1]) if len(self.fmt) > 1 else 1
if num == 1:
return "%s %s;" % (c_type, name)
else:
return "%s %s[%d];" % (c_type, name, num)
class FixedStrFormat: class FixedStrFormat:
def __init__(self, num): def __init__(self, num):
@ -30,6 +47,12 @@ class FixedStrFormat:
def decode(self, data: bytes): def decode(self, data: bytes):
return struct.unpack('%ss' % self.num, data[:self.num])[0].rstrip(bytes((0, ))).decode(), data[self.num:] return struct.unpack('%ss' % self.num, data[:self.num])[0].rstrip(bytes((0, ))).decode(), data[self.num:]
def get_c_struct(self, name):
return "char %(name)s[%(length)d];" % {
"name": name,
"length": self.num,
}
class BoolFormat: class BoolFormat:
def encode(self, value): def encode(self, value):
@ -38,6 +61,11 @@ class BoolFormat:
def decode(self, data: bytes): def decode(self, data: bytes):
return bool(struct.unpack('B', data[:1])[0]), data[1:] return bool(struct.unpack('B', data[:1])[0]), data[1:]
def get_c_struct(self, name):
return "uint8_t %(name)s;" % {
"name": name,
}
class HexFormat: class HexFormat:
def __init__(self, num, sep=''): def __init__(self, num, sep=''):
@ -53,14 +81,27 @@ class HexFormat:
data[self.num:] data[self.num:]
) )
def get_c_struct(self, name):
return "uint8_t %(name)s[%(length)d];" % {
"name": name,
"length": self.num,
}
class VarStrFormat: class VarStrFormat:
var_num = 1
def encode(self, value: str) -> bytes: def encode(self, value: str) -> bytes:
return bytes((len(value)+1, )) + value.encode() + bytes((0, )) return bytes((len(value)+1, )) + value.encode() + bytes((0, ))
def decode(self, data: bytes): def decode(self, data: bytes):
return data[1:data[0]].decode(), data[data[0]+1:] return data[1:data[0]].decode(), data[data[0]+1:]
def get_c_struct(self, name):
return "uint8_t num;\nchar %(name)s[0];" % {
"name": name,
}
class MacAddressFormat: class MacAddressFormat:
def encode(self, value: str) -> bytes: def encode(self, value: str) -> bytes:
@ -69,8 +110,15 @@ class MacAddressFormat:
def decode(self, data: bytes): def decode(self, data: bytes):
return (MAC_FMT % tuple(data[:6])), data[6:] return (MAC_FMT % tuple(data[:6])), data[6:]
def get_c_struct(self, name):
return "uint8_t %(name)s[6];" % {
"name": name,
}
class MacAddressesListFormat: class MacAddressesListFormat:
var_num = 6
def encode(self, value: list[str]) -> bytes: def encode(self, value: list[str]) -> bytes:
return bytes((len(value), )) + sum( return bytes((len(value), )) + sum(
(bytes((int(mac[i*3:i*3+2], 16) for i in range(6))) for mac in value), (bytes((int(mac[i*3:i*3+2], 16) for i in range(6))) for mac in value),
@ -80,6 +128,11 @@ class MacAddressesListFormat:
def decode(self, data: bytes): def decode(self, data: bytes):
return [MAC_FMT % tuple(data[1+6*i:1+6+6*i]) for i in range(data[0])], data[1+data[0]*6:] return [MAC_FMT % tuple(data[1+6*i:1+6+6*i]) for i in range(data[0])], data[1+data[0]*6:]
def get_c_struct(self, name):
return "uint8_t num;\nuint8_t %(name)s[6][0];" % {
"name": name,
}
class LedType(IntEnum): class LedType(IntEnum):
SERIAL = 1 SERIAL = 1
@ -138,3 +191,20 @@ class LedConfigFormat:
else: else:
raise ValueError raise ValueError
return value, data[4:] return value, data[4:]
def get_c_struct(self, name):
return (
"uint8_t type;\n"
"union {\n"
" struct {\n"
" uint8_t gpio;\n"
" uint8_t rmt;\n"
" } serial;\n"
" struct {\n"
" uint8_t gpio_red;\n"
" uint8_t gpio_green;\n"
" uint8_t gpio_blue;\n"
" } multipin;\n"
" uint8_t bytes[3];\n"
"};"
)

View file

@ -166,4 +166,3 @@ class ConfigPositionMessageForm(MeshMessageForm):
x_pos = forms.IntegerField(min_value=0, max_value=2**16-1, label=_('X')) x_pos = forms.IntegerField(min_value=0, max_value=2**16-1, label=_('X'))
y_pos = forms.IntegerField(min_value=0, max_value=2 ** 16 - 1, label=_('Y')) y_pos = forms.IntegerField(min_value=0, max_value=2 ** 16 - 1, label=_('Y'))
z_pos = forms.IntegerField(min_value=0, max_value=2 ** 16 - 1, label=_('Z')) z_pos = forms.IntegerField(min_value=0, max_value=2 ** 16 - 1, label=_('Z'))

View file

View file

@ -0,0 +1,78 @@
from django.core.management.base import BaseCommand
from c3nav.mesh.messages import MeshMessage
from c3nav.mesh.utils import indent_c
class Command(BaseCommand):
help = 'export mesh message structs for c code'
def handle(self, *args, **options):
done_struct_names = set()
nodata = set()
struct_lines = {
"num": "uint8_t num; /** defined here for convenience. subtypes that use it will define it themselves */"
}
for msg_type in MeshMessage.msg_types.values():
if msg_type.c_struct_name:
if msg_type.c_struct_name in done_struct_names:
continue
done_struct_names.add(msg_type.c_struct_name)
msg_type = MeshMessage.c_structs[msg_type.c_struct_name]
code = msg_type.get_c_struct()
if code:
struct_lines[msg_type.get_c_struct_name()] = (
"mesh_msg_%s_t %s;" % (
msg_type.get_c_struct_name(),
msg_type.get_c_struct_name().replace("_announce", ""),
)
)
print(code)
print()
else:
nodata.add(msg_type)
print("/** union between all message data structs */")
print("typedef union __packed {")
for line in struct_lines.values():
print(indent_c(line))
print("} mesh_msg_data_t;")
print()
max_msg_type = max(MeshMessage.msg_types.keys())
macro_data = []
for i in range(((max_msg_type//16)+1)*16):
msg_type = MeshMessage.msg_types.get(i, None)
if msg_type:
macro_data.append((
msg_type.get_c_enum_name()+',',
("nodata" if msg_type in nodata else msg_type.get_c_struct_name())+',',
msg_type.get_var_num(),
msg_type.__doc__.strip(),
))
else:
macro_data.append((
"RESERVED_%02X," % i,
"nodata,",
0,
"",
))
max0 = max(len(d[0]) for d in macro_data)
max1 = max(len(d[1]) for d in macro_data)
max2 = max(len(str(d[2])) for d in macro_data)
lines = []
for i, (macro_name, struct_name, num_len, comment) in enumerate(macro_data):
lines.append(indent_c(
"FN(%s %s %s) /** 0x%02X %s*/" % (
macro_name.ljust(max0),
struct_name.ljust(max1),
str(num_len).rjust(max2),
i,
comment+(" " if comment else ""),
)
))
print("#define FOR_ALL_MESH_MSG_TYPES(FN) \\")
print(" \\\n".join(lines))

View file

@ -1,11 +1,13 @@
import re
from dataclasses import asdict, dataclass, field, fields, is_dataclass from dataclasses import asdict, dataclass, field, fields, is_dataclass
from enum import IntEnum, unique from enum import IntEnum, unique
from itertools import chain
from typing import TypeVar from typing import TypeVar
import channels import channels
from asgiref.sync import async_to_sync from asgiref.sync import async_to_sync
from c3nav.control.views.utils import get_mesh_comm_group from c3nav.mesh.utils import get_mesh_comm_group, indent_c
from c3nav.mesh.dataformats import (BoolFormat, FixedStrFormat, HexFormat, LedConfig, LedConfigFormat, from c3nav.mesh.dataformats import (BoolFormat, FixedStrFormat, HexFormat, LedConfig, LedConfigFormat,
MacAddressesListFormat, MacAddressFormat, SimpleFormat, VarStrFormat) MacAddressesListFormat, MacAddressFormat, SimpleFormat, VarStrFormat)
@ -17,6 +19,8 @@ NO_LAYER = 0xFF
@unique @unique
class MeshMessageType(IntEnum): class MeshMessageType(IntEnum):
NOOP = 0x00
ECHO_REQUEST = 0x01 ECHO_REQUEST = 0x01
ECHO_RESPONSE = 0x02 ECHO_RESPONSE = 0x02
@ -32,7 +36,7 @@ class MeshMessageType(IntEnum):
CONFIG_UPLINK = 0x14 CONFIG_UPLINK = 0x14
M = TypeVar('M', bound='Message') M = TypeVar('M', bound='MeshMessage')
@unique @unique
@ -43,24 +47,31 @@ class ChipType(IntEnum):
@dataclass @dataclass
class MeshMessage: class MeshMessage:
dst: str = field(metadata={'format': MacAddressFormat()}) dst: str = field(metadata={"format": MacAddressFormat()})
src: str = field(metadata={'format': MacAddressFormat()}) src: str = field(metadata={"format": MacAddressFormat()})
msg_id: int = field(metadata={'format': SimpleFormat('B')}, init=False, repr=False) msg_id: int = field(metadata={"format": SimpleFormat('B')}, init=False, repr=False)
msg_types = {} msg_types = {}
c_structs = {}
c_struct_name = None
# noinspection PyMethodOverriding # noinspection PyMethodOverriding
def __init_subclass__(cls, /, msg_id=None, **kwargs): def __init_subclass__(cls, /, msg_id=None, c_struct_name=None, **kwargs):
super().__init_subclass__(**kwargs) super().__init_subclass__(**kwargs)
if msg_id: if msg_id is not None:
cls.msg_id = msg_id cls.msg_id = msg_id
if msg_id in MeshMessage.msg_types: if msg_id in MeshMessage.msg_types:
raise TypeError('duplicate use of msg_id %d' % msg_id) raise TypeError('duplicate use of msg_id %d' % msg_id)
MeshMessage.msg_types[msg_id] = cls MeshMessage.msg_types[msg_id] = cls
if c_struct_name:
cls.c_struct_name = c_struct_name
if c_struct_name in MeshMessage.c_structs:
raise TypeError('duplicate use of c_struct_name %s' % c_struct_name)
MeshMessage.c_structs[c_struct_name] = cls
def encode(self): def encode(self):
data = bytes() data = bytes()
for field_ in fields(self): for field_ in fields(self):
data += field_.metadata['format'].encode(getattr(self, field_.name)) data += field_.metadata["format"].encode(getattr(self, field_.name))
return data return data
@classmethod @classmethod
@ -68,7 +79,7 @@ class MeshMessage:
klass = cls.msg_types[data[12]] klass = cls.msg_types[data[12]]
values = {} values = {}
for field_ in fields(klass): for field_ in fields(klass):
values[field_.name], data = field_.metadata['format'].decode(data) values[field_.name], data = field_.metadata["format"].decode(data)
values.pop('msg_id') values.pop('msg_id')
return klass(**values) return klass(**values)
@ -90,56 +101,160 @@ class MeshMessage:
"msg": self.tojson() "msg": self.tojson()
}) })
@classmethod
def get_ignore_c_fields(self):
return set()
@dataclass @classmethod
class EchoRequestMessage(MeshMessage, msg_id=MeshMessageType.ECHO_REQUEST): def get_additional_c_fields(self):
content: str = field(default='', metadata={'format': VarStrFormat()}) return ()
@classmethod
def get_c_struct(cls):
ignore_fields = cls.get_ignore_c_fields()
if cls != MeshMessage:
ignore_fields |= set(field.name for field in fields(MeshMessage))
items = tuple(
(
tuple(field.metadata["format"].get_c_struct(field.metadata.get("c_name", field.name)).split("\n")),
field.metadata.get("doc", None),
)
for field in fields(cls)
if field.name not in ignore_fields
)
if not items:
return ""
max_line_len = max(len(line) for line in chain(*(code for code, doc in items)))
msg_comment = cls.__doc__.strip()
return "%(comment)stypedef struct __packed {\n%(elements)s\n} %(name)s;" % {
"comment": ("/** %s */\n" % msg_comment) if msg_comment else "",
"elements": indent_c(
"\n".join(chain(*(
(code if not comment
else (code[:-1]+("%s /** %s */" % (code[-1].ljust(max_line_len), comment),)))
for code, comment in items
), cls.get_additional_c_fields()))
),
"name": "mesh_msg_%s_t" % cls.get_c_struct_name(),
}
@classmethod
def get_var_num(cls):
return sum((getattr(field.metadata["format"], "var_num", 0) for field in fields(cls)), start=0)
@classmethod
def get_c_struct_name(cls):
return (
cls.c_struct_name if cls.c_struct_name else
re.sub(
r"([a-z])([A-Z])",
r"\1_\2",
cls.__name__.removeprefix('Mesh').removesuffix('Message')
).lower().replace('config', 'cfg').replace('firmware', 'fw').replace('position', 'pos')
)
@classmethod
def get_c_enum_name(cls):
return re.sub(
r"([a-z])([A-Z])",
r"\1_\2",
cls.__name__.removeprefix('Mesh').removesuffix('Message')
).upper().replace('CONFIG', 'CFG').replace('FIRMWARE', 'FW').replace('POSITION', 'POS')
@dataclass @dataclass
class EchoResponseMessage(MeshMessage, msg_id=MeshMessageType.ECHO_RESPONSE): class NoopMessage(MeshMessage, msg_id=MeshMessageType.NOOP):
content: str = field(default='', metadata={'format': VarStrFormat()}) """ noop """
pass
@dataclass
class BaseEchoMessage(MeshMessage, c_struct_name="echo"):
""" repeat back string """
content: str = field(default='', metadata={
"format": VarStrFormat(),
"doc": "string to echo",
"c_name": "str",
})
@dataclass
class EchoRequestMessage(BaseEchoMessage, msg_id=MeshMessageType.ECHO_REQUEST):
""" repeat back string """
pass
@dataclass
class EchoResponseMessage(BaseEchoMessage, msg_id=MeshMessageType.ECHO_RESPONSE):
""" repeat back string """
pass
@dataclass @dataclass
class MeshSigninMessage(MeshMessage, msg_id=MeshMessageType.MESH_SIGNIN): class MeshSigninMessage(MeshMessage, msg_id=MeshMessageType.MESH_SIGNIN):
""" node says hello to upstream node """
pass pass
@dataclass @dataclass
class MeshLayerAnnounceMessage(MeshMessage, msg_id=MeshMessageType.MESH_LAYER_ANNOUNCE): class MeshLayerAnnounceMessage(MeshMessage, msg_id=MeshMessageType.MESH_LAYER_ANNOUNCE):
layer: int = field(metadata={'format': SimpleFormat('B')}) """ upstream node announces layer number """
layer: int = field(metadata={
"format": SimpleFormat('B'),
"doc": "mesh layer that the sending node is on",
})
@dataclass @dataclass
class MeshAddDestinationsMessage(MeshMessage, msg_id=MeshMessageType.MESH_ADD_DESTINATIONS): class BaseDestinationsMessage(MeshMessage, c_struct_name="destinations"):
mac_addresses: list[str] = field(default_factory=list, metadata={'format': MacAddressesListFormat()}) """ downstream node announces served/no longer served destination """
mac_addresses: list[str] = field(default_factory=list, metadata={
"format": MacAddressesListFormat(),
"doc": "mac adresses of the destinations",
"c_name": "mac",
})
@dataclass @dataclass
class MeshRemoveDestinationsMessage(MeshMessage, msg_id=MeshMessageType.MESH_REMOVE_DESTINATIONS): class MeshAddDestinationsMessage(BaseDestinationsMessage, msg_id=MeshMessageType.MESH_ADD_DESTINATIONS):
mac_addresses: list[str] = field(default_factory=list, metadata={'format': MacAddressesListFormat()}) """ downstream node announces served destination """
pass
@dataclass
class MeshRemoveDestinationsMessage(BaseDestinationsMessage, msg_id=MeshMessageType.MESH_REMOVE_DESTINATIONS):
""" downstream node announces no longer served destination """
pass
@dataclass @dataclass
class ConfigDumpMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_DUMP): class ConfigDumpMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_DUMP):
""" request for the node to dump its config """
pass pass
@dataclass @dataclass
class ConfigFirmwareMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_FIRMWARE): class ConfigFirmwareMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_FIRMWARE):
chip: int = field(metadata={'format': SimpleFormat('H')}) """ respond firmware info """
revision: int = field(metadata={'format': SimpleFormat('2B')}) chip: int = field(metadata={
magic_word: int = field(metadata={'format': SimpleFormat('I')}, repr=False) "format": SimpleFormat('H'),
secure_version: int = field(metadata={'format': SimpleFormat('I')}) "c_name": "chip_id",
reserv1: list[int] = field(metadata={'format': SimpleFormat('2I')}, repr=False) })
version: str = field(metadata={'format': FixedStrFormat(32)}) revision_major: int = field(metadata={"format": SimpleFormat('B')})
project_name: str = field(metadata={'format': FixedStrFormat(32)}) revision_minor: int = field(metadata={"format": SimpleFormat('B')})
compile_time: str = field(metadata={'format': FixedStrFormat(16)}) magic_word: int = field(metadata={"format": SimpleFormat('I')}, repr=False)
compile_date: str = field(metadata={'format': FixedStrFormat(16)}) secure_version: int = field(metadata={"format": SimpleFormat('I')})
idf_version: str = field(metadata={'format': FixedStrFormat(32)}) reserv1: list[int] = field(metadata={"format": SimpleFormat('2I')}, repr=False)
app_elf_sha256: str = field(metadata={'format': HexFormat(32)}) version: str = field(metadata={"format": FixedStrFormat(32)})
reserv2: list[int] = field(metadata={'format': SimpleFormat('20I')}, repr=False) project_name: str = field(metadata={"format": FixedStrFormat(32)})
compile_time: str = field(metadata={"format": FixedStrFormat(16)})
compile_date: str = field(metadata={"format": FixedStrFormat(16)})
idf_version: str = field(metadata={"format": FixedStrFormat(32)})
app_elf_sha256: str = field(metadata={"format": HexFormat(32)})
reserv2: list[int] = field(metadata={"format": SimpleFormat('20I')}, repr=False)
def to_model_data(self): def to_model_data(self):
return { return {
@ -153,26 +268,40 @@ class ConfigFirmwareMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_FIRMWARE)
def get_chip_display(self): def get_chip_display(self):
return ChipType(self.chip).name.replace('_', '-') return ChipType(self.chip).name.replace('_', '-')
@classmethod
def get_ignore_c_fields(self):
return {
"magic_word", "secure_version", "reserv1", "version", "project_name",
"compile_time", "compile_date", "idf_version", "app_elf_sha256", "reserv2"
}
@classmethod
def get_additional_c_fields(self):
return ("esp_app_desc_t app_desc;", )
@dataclass @dataclass
class ConfigPositionMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_POSITION): class ConfigPositionMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_POSITION):
x_pos: int = field(metadata={'format': SimpleFormat('I')}) """ set/respond position config """
y_pos: int = field(metadata={'format': SimpleFormat('I')}) x_pos: int = field(metadata={"format": SimpleFormat('i')})
z_pos: int = field(metadata={'format': SimpleFormat('H')}) y_pos: int = field(metadata={"format": SimpleFormat('i')})
z_pos: int = field(metadata={"format": SimpleFormat('h')})
@dataclass @dataclass
class ConfigLedMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_LED): class ConfigLedMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_LED):
led_config: LedConfig = field(metadata={'format': LedConfigFormat()}) """ set/respond led config """
led_config: LedConfig = field(metadata={"format": LedConfigFormat()})
@dataclass @dataclass
class ConfigUplinkMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_UPLINK): class ConfigUplinkMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_UPLINK):
enabled: bool = field(metadata={'format': BoolFormat()}) """ set/respond uplink config """
ssid: str = field(metadata={'format': FixedStrFormat(32)}) enabled: bool = field(metadata={"format": BoolFormat()})
password: str = field(metadata={'format': FixedStrFormat(64)}) ssid: str = field(metadata={"format": FixedStrFormat(32)})
channel: int = field(metadata={'format': SimpleFormat('B')}) password: str = field(metadata={"format": FixedStrFormat(64)})
udp: bool = field(metadata={'format': BoolFormat()}) channel: int = field(metadata={"format": SimpleFormat('B')})
ssl: bool = field(metadata={'format': BoolFormat()}) udp: bool = field(metadata={"format": BoolFormat()})
host: str = field(metadata={'format': FixedStrFormat(64)}) ssl: bool = field(metadata={"format": BoolFormat()})
port: int = field(metadata={'format': SimpleFormat('H')}) host: str = field(metadata={"format": FixedStrFormat(64)})
port: int = field(metadata={"format": SimpleFormat('H')})

View file

@ -1,2 +1,6 @@
def get_mesh_comm_group(address): def get_mesh_comm_group(address):
return 'mesh_comm_%s' % address.replace(':', '-') return 'mesh_comm_%s' % address.replace(':', '-')
def indent_c(code):
return " "+code.replace("\n", "\n ")