more mesh_msg export stuffs

This commit is contained in:
Laura Klünder 2023-10-05 20:55:36 +02:00
parent 7a13193acd
commit 16f47168a2
3 changed files with 85 additions and 42 deletions

View file

@ -3,7 +3,6 @@ import struct
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field, fields from dataclasses import dataclass, field, fields
from enum import IntEnum, unique from enum import IntEnum, unique
from itertools import chain
from typing import Self, Sequence, Any from typing import Self, Sequence, Any
from c3nav.mesh.utils import indent_c from c3nav.mesh.utils import indent_c
@ -315,21 +314,25 @@ class StructType:
return cls(**kwargs) return cls(**kwargs)
@classmethod @classmethod
def get_c_parts(cls, ignore_fields=None, no_empty=False, typedef=False): def get_c_parts(cls, ignore_fields=None, no_empty=False, typedef=False, union_only=False,
union_member_as_types=False):
ignore_fields = set() if not ignore_fields else set(ignore_fields) ignore_fields = set() if not ignore_fields else set(ignore_fields)
pre = ""
items = [] items = []
for field_ in fields(cls): for field_ in fields(cls):
if field_.name in ignore_fields: if field_.name in ignore_fields:
continue continue
name = field_.metadata.get("c_name", field_.name)
if "format" in field_.metadata: if "format" in field_.metadata:
items.append(( items.append((
field_.metadata["format"].get_c_code(field_.name), field_.metadata["format"].get_c_code(name),
field_.metadata.get("doc", None), field_.metadata.get("doc", None),
)), )),
elif issubclass(field_.type, StructType): elif issubclass(field_.type, StructType):
items.append(( items.append((
field_.type.get_c_code(field_.name, typedef=False), field_.type.get_c_code(name, typedef=False),
field_.metadata.get("doc", None), field_.metadata.get("doc", None),
)) ))
else: else:
@ -340,19 +343,39 @@ class StructType:
parent_fields = set(field_.name for field_ in fields(cls)) parent_fields = set(field_.name for field_ in fields(cls))
union_items = [] union_items = []
for key, option in cls._union_options[cls.union_type_field].items(): for key, option in cls._union_options[cls.union_type_field].items():
name = normalize_name(getattr(key, 'name', option.__name__)) base_name = normalize_name(getattr(key, 'name', option.__name__))
union_items.append( if union_member_as_types:
option.get_c_code(name, ignore_fields=(ignore_fields | parent_fields)) struct_name = cls.get_struct_name(base_name)
pre += option.get_c_code(
struct_name,
ignore_fields=(ignore_fields | parent_fields),
typedef=True
)+"\n\n"
union_items.append(
"%s %s;" % (struct_name, cls.get_variable_name(base_name)),
)
else:
union_items.append(
option.get_c_code(base_name, ignore_fields=(ignore_fields | parent_fields))
)
union_items.append(
"uint8_t bytes[%s];" % max(
(option.get_min_size() for option in cls._union_options[cls.union_type_field].values()),
default=0,
) )
items.append(( )
"union {\n"+indent_c("\n".join(union_items))+"\n}", union_code = "{\n"+indent_c("\n".join(union_items))+"\n}",
"")) if union_only:
return "typedef union __packed %s" % union_code, "";
else:
items.append(("union %s;" % union_code, ""))
elif union_only:
return "", ""
if no_empty and not items: if no_empty and not items:
return "", "" return "", ""
# todo: struct comment # todo: struct comment
pre = ""
if typedef: if typedef:
comment = cls.__doc__.strip() comment = cls.__doc__.strip()
if comment: if comment:
@ -372,23 +395,31 @@ class StructType:
return pre, "" return pre, ""
@classmethod @classmethod
def get_c_code(cls, name=None, ignore_fields=None, no_empty=False, typedef=True) -> str: def get_c_code(cls, name=None, ignore_fields=None, no_empty=False, typedef=True, union_only=False,
pre, post = cls.get_c_parts(ignore_fields=ignore_fields, no_empty=no_empty, typedef=typedef) union_member_as_types=False) -> str:
pre, post = cls.get_c_parts(ignore_fields=ignore_fields, no_empty=no_empty, typedef=typedef,
union_only=union_only, union_member_as_types=union_member_as_types,
)
if no_empty and not pre and not post: if no_empty and not pre and not post:
return "" return ""
return "%s %s%s;" % (pre, name, post) return "%s %s%s;" % (pre, name, post)
@classmethod @classmethod
def get_base_name(cls): def get_variable_name(cls, base_name):
return cls.__name__ return base_name
@classmethod @classmethod
def get_variable_name(cls): def get_struct_name(cls, base_name):
return cls.get_base_name() return "%s_t" % base_name
@classmethod @classmethod
def get_struct_name(cls): def get_min_size(cls) -> int:
return "%s_t" % cls.get_base_name() if cls.union_type_field:
return (
{f.name: field for f in fields()}[cls.union_type_field].metadata["format"].get_min_size() +
sum((option.get_min_size() for option in cls._union_options[cls.union_type_field].values()), start=0)
)
return sum((f.metadata.get("format", f.type).get_min_size() for f in fields(cls)), start=0)
class MacAddressFormat(FixedHexFormat): class MacAddressFormat(FixedHexFormat):
@ -426,6 +457,7 @@ class MultipinLedConfig(LedConfig, StructType, led_type=LedType.MULTIPIN):
gpio_blue: int = field(metadata={"format": SimpleFormat('B')}) gpio_blue: int = field(metadata={"format": SimpleFormat('B')})
@dataclass
class RangeItemType(StructType): class RangeItemType(StructType):
address: str = field(metadata={"format": MacAddressFormat()}) address: str = field(metadata={"format": MacAddressFormat()})
distance: int = field(metadata={"format": SimpleFormat('H')}) distance: int = field(metadata={"format": SimpleFormat('H')})

View file

@ -2,22 +2,27 @@ from dataclasses import fields
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from c3nav.mesh.dataformats import normalize_name from c3nav.mesh.dataformats import normalize_name, LedConfig
from c3nav.mesh.messages import MeshMessage from c3nav.mesh.messages import MeshMessage, MeshMessageType
from c3nav.mesh.utils import indent_c from c3nav.mesh.utils import indent_c
class Command(BaseCommand): class Command(BaseCommand):
help = 'export mesh message structs for c code' help = 'export mesh message structs for c code'
def shorten_name(self, name):
name = name.replace('config', 'cfg')
name = name.replace('position', 'pos')
name = name.replace('mesh_', '')
name = name.replace('firmware', 'fw')
return name
def handle(self, *args, **options): def handle(self, *args, **options):
done_struct_names = set() done_struct_names = set()
nodata = set() nodata = set()
struct_lines = {} struct_lines = {}
ignore_names = set(field_.name for field_ in fields(MeshMessage)) ignore_names = set(field_.name for field_ in fields(MeshMessage))
from pprint import pprint
pprint(MeshMessage.get_msg_types())
for msg_id, msg_type in MeshMessage.get_msg_types().items(): for msg_id, msg_type in MeshMessage.get_msg_types().items():
if msg_type.c_struct_name: if msg_type.c_struct_name:
if msg_type.c_struct_name in done_struct_names: if msg_type.c_struct_name in done_struct_names:
@ -25,16 +30,22 @@ class Command(BaseCommand):
done_struct_names.add(msg_type.c_struct_name) done_struct_names.add(msg_type.c_struct_name)
msg_type = MeshMessage.c_structs[msg_type.c_struct_name] msg_type = MeshMessage.c_structs[msg_type.c_struct_name]
name = "mesh_msg_%s_t" % ( base_name = (msg_type.c_struct_name or self.shorten_name(normalize_name(
msg_type.c_struct_name or normalize_name(getattr(msg_id, 'name', msg_type.__name__)) getattr(msg_id, 'name', msg_type.__name__)
) )))
name = "mesh_msg_%s_t" % base_name
if msg_id == MeshMessageType.CONFIG_LED:
msg_type = LedConfig
code = msg_type.get_c_code(name, ignore_fields=ignore_names, no_empty=True) code = msg_type.get_c_code(name, ignore_fields=ignore_names, no_empty=True)
if code: if code:
struct_lines[base_name] = "%s %s;" % (name, base_name.replace('_announce', ''))
print(code) print(code)
print() print()
else: else:
nodata.add(msg_type) nodata.add(msg_type)
return
print("/** union between all message data structs */") print("/** union between all message data structs */")
print("typedef union __packed {") print("typedef union __packed {")
for line in struct_lines.values(): for line in struct_lines.values():
@ -42,14 +53,17 @@ class Command(BaseCommand):
print("} mesh_msg_data_t;") print("} mesh_msg_data_t;")
print() print()
max_msg_type = max(MeshMessage.msg_types.keys()) max_msg_type = max(MeshMessage.get_msg_types().keys())
macro_data = [] macro_data = []
for i in range(((max_msg_type//16)+1)*16): for i in range(((max_msg_type//16)+1)*16):
msg_type = MeshMessage.msg_types.get(i, None) msg_type = MeshMessage.get_msg_types().get(i, None)
if msg_type: if msg_type:
name = (msg_type.c_struct_name or self.shorten_name(normalize_name(
getattr(msg_type.msg_id, 'name', msg_type.__name__)
)))
macro_data.append(( macro_data.append((
msg_type.get_c_enum_name()+',', msg_type.get_c_enum_name()+',',
("nodata" if msg_type in nodata else msg_type.get_c_struct_name())+',', ("nodata" if msg_type in nodata else name)+',',
msg_type.get_var_num(), msg_type.get_var_num(),
msg_type.__doc__.strip(), msg_type.__doc__.strip(),
)) ))

View file

@ -9,7 +9,8 @@ from asgiref.sync import async_to_sync
from c3nav.mesh.utils import get_mesh_comm_group, indent_c from c3nav.mesh.utils import get_mesh_comm_group, indent_c
from c3nav.mesh.dataformats import (BoolFormat, FixedStrFormat, FixedHexFormat, LedConfig, LedConfig, from c3nav.mesh.dataformats import (BoolFormat, FixedStrFormat, FixedHexFormat, LedConfig, LedConfig,
MacAddressesListFormat, MacAddressFormat, SimpleFormat, VarStrFormat, StructType) MacAddressesListFormat, MacAddressFormat, SimpleFormat, VarStrFormat, StructType,
VarArrayFormat, RangeItemType)
MESH_ROOT_ADDRESS = '00:00:00:00:00:00' MESH_ROOT_ADDRESS = '00:00:00:00:00:00'
MESH_PARENT_ADDRESS = '00:00:00:ff:ff:ff' MESH_PARENT_ADDRESS = '00:00:00:ff:ff:ff'
@ -120,15 +121,12 @@ class MeshMessage(StructType, union_type_field="msg_id"):
return sum((getattr(field.metadata["format"], "var_num", 0) for field in fields(cls)), start=0) return sum((getattr(field.metadata["format"], "var_num", 0) for field in fields(cls)), start=0)
@classmethod @classmethod
def get_c_struct_name(cls): def get_variable_name(cls, base_name):
return ( return cls.c_struct_name or base_name
cls.c_struct_name if cls.c_struct_name else
re.sub( @classmethod
r"([a-z])([A-Z])", def get_struct_name(cls, base_name):
r"\1_\2", return "mesh_msg_%s_t" % base_name
cls.__name__.removeprefix('Mesh').removesuffix('Message')
).lower().replace('config', 'cfg').replace('firmware', 'fw').replace('position', 'pos')
)
@classmethod @classmethod
def get_c_enum_name(cls): def get_c_enum_name(cls):
@ -316,5 +314,4 @@ class ConfigUplinkMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_UPLINK):
@dataclass @dataclass
class LocateReportRangeMessage(MeshMessage, msg_id=MeshMessageType.LOCATE_REPORT_RANGE): class LocateReportRangeMessage(MeshMessage, msg_id=MeshMessageType.LOCATE_REPORT_RANGE):
""" report distance to given nodes """ """ report distance to given nodes """
#ranges: dict[str, int] = ranges: dict[str, int] = field(metadata={"format": VarArrayFormat(RangeItemType)})
pass