more mesh_msg export stuffs

This commit is contained in:
Laura Klünder 2023-10-05 20:55:36 +02:00
parent 7a13193acd
commit 16f47168a2
3 changed files with 85 additions and 42 deletions

View file

@ -3,7 +3,6 @@ import struct
from abc import ABC, abstractmethod
from dataclasses import dataclass, field, fields
from enum import IntEnum, unique
from itertools import chain
from typing import Self, Sequence, Any
from c3nav.mesh.utils import indent_c
@ -315,21 +314,25 @@ class StructType:
return cls(**kwargs)
@classmethod
def get_c_parts(cls, ignore_fields=None, no_empty=False, typedef=False):
def get_c_parts(cls, ignore_fields=None, no_empty=False, typedef=False, union_only=False,
union_member_as_types=False):
ignore_fields = set() if not ignore_fields else set(ignore_fields)
pre = ""
items = []
for field_ in fields(cls):
if field_.name in ignore_fields:
continue
name = field_.metadata.get("c_name", field_.name)
if "format" in field_.metadata:
items.append((
field_.metadata["format"].get_c_code(field_.name),
field_.metadata["format"].get_c_code(name),
field_.metadata.get("doc", None),
)),
elif issubclass(field_.type, StructType):
items.append((
field_.type.get_c_code(field_.name, typedef=False),
field_.type.get_c_code(name, typedef=False),
field_.metadata.get("doc", None),
))
else:
@ -340,19 +343,39 @@ class StructType:
parent_fields = set(field_.name for field_ in fields(cls))
union_items = []
for key, option in cls._union_options[cls.union_type_field].items():
name = normalize_name(getattr(key, 'name', option.__name__))
union_items.append(
option.get_c_code(name, ignore_fields=(ignore_fields | parent_fields))
base_name = normalize_name(getattr(key, 'name', option.__name__))
if union_member_as_types:
struct_name = cls.get_struct_name(base_name)
pre += option.get_c_code(
struct_name,
ignore_fields=(ignore_fields | parent_fields),
typedef=True
)+"\n\n"
union_items.append(
"%s %s;" % (struct_name, cls.get_variable_name(base_name)),
)
else:
union_items.append(
option.get_c_code(base_name, ignore_fields=(ignore_fields | parent_fields))
)
union_items.append(
"uint8_t bytes[%s];" % max(
(option.get_min_size() for option in cls._union_options[cls.union_type_field].values()),
default=0,
)
items.append((
"union {\n"+indent_c("\n".join(union_items))+"\n}",
""))
)
union_code = "{\n"+indent_c("\n".join(union_items))+"\n}",
if union_only:
return "typedef union __packed %s" % union_code, "";
else:
items.append(("union %s;" % union_code, ""))
elif union_only:
return "", ""
if no_empty and not items:
return "", ""
# todo: struct comment
pre = ""
if typedef:
comment = cls.__doc__.strip()
if comment:
@ -372,23 +395,31 @@ class StructType:
return pre, ""
@classmethod
def get_c_code(cls, name=None, ignore_fields=None, no_empty=False, typedef=True) -> str:
pre, post = cls.get_c_parts(ignore_fields=ignore_fields, no_empty=no_empty, typedef=typedef)
def get_c_code(cls, name=None, ignore_fields=None, no_empty=False, typedef=True, union_only=False,
union_member_as_types=False) -> str:
pre, post = cls.get_c_parts(ignore_fields=ignore_fields, no_empty=no_empty, typedef=typedef,
union_only=union_only, union_member_as_types=union_member_as_types,
)
if no_empty and not pre and not post:
return ""
return "%s %s%s;" % (pre, name, post)
@classmethod
def get_base_name(cls):
return cls.__name__
def get_variable_name(cls, base_name):
return base_name
@classmethod
def get_variable_name(cls):
return cls.get_base_name()
def get_struct_name(cls, base_name):
return "%s_t" % base_name
@classmethod
def get_struct_name(cls):
return "%s_t" % cls.get_base_name()
def get_min_size(cls) -> int:
if cls.union_type_field:
return (
{f.name: field for f in fields()}[cls.union_type_field].metadata["format"].get_min_size() +
sum((option.get_min_size() for option in cls._union_options[cls.union_type_field].values()), start=0)
)
return sum((f.metadata.get("format", f.type).get_min_size() for f in fields(cls)), start=0)
class MacAddressFormat(FixedHexFormat):
@ -426,6 +457,7 @@ class MultipinLedConfig(LedConfig, StructType, led_type=LedType.MULTIPIN):
gpio_blue: int = field(metadata={"format": SimpleFormat('B')})
@dataclass
class RangeItemType(StructType):
address: str = field(metadata={"format": MacAddressFormat()})
distance: int = field(metadata={"format": SimpleFormat('H')})

View file

@ -2,22 +2,27 @@ from dataclasses import fields
from django.core.management.base import BaseCommand
from c3nav.mesh.dataformats import normalize_name
from c3nav.mesh.messages import MeshMessage
from c3nav.mesh.dataformats import normalize_name, LedConfig
from c3nav.mesh.messages import MeshMessage, MeshMessageType
from c3nav.mesh.utils import indent_c
class Command(BaseCommand):
help = 'export mesh message structs for c code'
def shorten_name(self, name):
name = name.replace('config', 'cfg')
name = name.replace('position', 'pos')
name = name.replace('mesh_', '')
name = name.replace('firmware', 'fw')
return name
def handle(self, *args, **options):
done_struct_names = set()
nodata = set()
struct_lines = {}
ignore_names = set(field_.name for field_ in fields(MeshMessage))
from pprint import pprint
pprint(MeshMessage.get_msg_types())
for msg_id, msg_type in MeshMessage.get_msg_types().items():
if msg_type.c_struct_name:
if msg_type.c_struct_name in done_struct_names:
@ -25,16 +30,22 @@ class Command(BaseCommand):
done_struct_names.add(msg_type.c_struct_name)
msg_type = MeshMessage.c_structs[msg_type.c_struct_name]
name = "mesh_msg_%s_t" % (
msg_type.c_struct_name or normalize_name(getattr(msg_id, 'name', msg_type.__name__))
)
base_name = (msg_type.c_struct_name or self.shorten_name(normalize_name(
getattr(msg_id, 'name', msg_type.__name__)
)))
name = "mesh_msg_%s_t" % base_name
if msg_id == MeshMessageType.CONFIG_LED:
msg_type = LedConfig
code = msg_type.get_c_code(name, ignore_fields=ignore_names, no_empty=True)
if code:
struct_lines[base_name] = "%s %s;" % (name, base_name.replace('_announce', ''))
print(code)
print()
else:
nodata.add(msg_type)
return
print("/** union between all message data structs */")
print("typedef union __packed {")
for line in struct_lines.values():
@ -42,14 +53,17 @@ class Command(BaseCommand):
print("} mesh_msg_data_t;")
print()
max_msg_type = max(MeshMessage.msg_types.keys())
max_msg_type = max(MeshMessage.get_msg_types().keys())
macro_data = []
for i in range(((max_msg_type//16)+1)*16):
msg_type = MeshMessage.msg_types.get(i, None)
msg_type = MeshMessage.get_msg_types().get(i, None)
if msg_type:
name = (msg_type.c_struct_name or self.shorten_name(normalize_name(
getattr(msg_type.msg_id, 'name', msg_type.__name__)
)))
macro_data.append((
msg_type.get_c_enum_name()+',',
("nodata" if msg_type in nodata else msg_type.get_c_struct_name())+',',
("nodata" if msg_type in nodata else name)+',',
msg_type.get_var_num(),
msg_type.__doc__.strip(),
))

View file

@ -9,7 +9,8 @@ from asgiref.sync import async_to_sync
from c3nav.mesh.utils import get_mesh_comm_group, indent_c
from c3nav.mesh.dataformats import (BoolFormat, FixedStrFormat, FixedHexFormat, LedConfig, LedConfig,
MacAddressesListFormat, MacAddressFormat, SimpleFormat, VarStrFormat, StructType)
MacAddressesListFormat, MacAddressFormat, SimpleFormat, VarStrFormat, StructType,
VarArrayFormat, RangeItemType)
MESH_ROOT_ADDRESS = '00:00:00:00:00:00'
MESH_PARENT_ADDRESS = '00:00:00:ff:ff:ff'
@ -120,15 +121,12 @@ class MeshMessage(StructType, union_type_field="msg_id"):
return sum((getattr(field.metadata["format"], "var_num", 0) for field in fields(cls)), start=0)
@classmethod
def get_c_struct_name(cls):
return (
cls.c_struct_name if cls.c_struct_name else
re.sub(
r"([a-z])([A-Z])",
r"\1_\2",
cls.__name__.removeprefix('Mesh').removesuffix('Message')
).lower().replace('config', 'cfg').replace('firmware', 'fw').replace('position', 'pos')
)
def get_variable_name(cls, base_name):
return cls.c_struct_name or base_name
@classmethod
def get_struct_name(cls, base_name):
return "mesh_msg_%s_t" % base_name
@classmethod
def get_c_enum_name(cls):
@ -316,5 +314,4 @@ class ConfigUplinkMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_UPLINK):
@dataclass
class LocateReportRangeMessage(MeshMessage, msg_id=MeshMessageType.LOCATE_REPORT_RANGE):
""" report distance to given nodes """
#ranges: dict[str, int] =
pass
ranges: dict[str, int] = field(metadata={"format": VarArrayFormat(RangeItemType)})