diff --git a/src/c3nav/mesh/dataformats.py b/src/c3nav/mesh/dataformats.py index 3e5a3146..f497f15c 100644 --- a/src/c3nav/mesh/dataformats.py +++ b/src/c3nav/mesh/dataformats.py @@ -3,7 +3,6 @@ import struct from abc import ABC, abstractmethod from dataclasses import dataclass, field, fields from enum import IntEnum, unique -from itertools import chain from typing import Self, Sequence, Any from c3nav.mesh.utils import indent_c @@ -315,21 +314,25 @@ class StructType: return cls(**kwargs) @classmethod - def get_c_parts(cls, ignore_fields=None, no_empty=False, typedef=False): + def get_c_parts(cls, ignore_fields=None, no_empty=False, typedef=False, union_only=False, + union_member_as_types=False): ignore_fields = set() if not ignore_fields else set(ignore_fields) + pre = "" + items = [] for field_ in fields(cls): if field_.name in ignore_fields: continue + name = field_.metadata.get("c_name", field_.name) if "format" in field_.metadata: items.append(( - field_.metadata["format"].get_c_code(field_.name), + field_.metadata["format"].get_c_code(name), field_.metadata.get("doc", None), )), elif issubclass(field_.type, StructType): items.append(( - field_.type.get_c_code(field_.name, typedef=False), + field_.type.get_c_code(name, typedef=False), field_.metadata.get("doc", None), )) else: @@ -340,19 +343,39 @@ class StructType: parent_fields = set(field_.name for field_ in fields(cls)) union_items = [] for key, option in cls._union_options[cls.union_type_field].items(): - name = normalize_name(getattr(key, 'name', option.__name__)) - union_items.append( - option.get_c_code(name, ignore_fields=(ignore_fields | parent_fields)) + base_name = normalize_name(getattr(key, 'name', option.__name__)) + if union_member_as_types: + struct_name = cls.get_struct_name(base_name) + pre += option.get_c_code( + struct_name, + ignore_fields=(ignore_fields | parent_fields), + typedef=True + )+"\n\n" + union_items.append( + "%s %s;" % (struct_name, cls.get_variable_name(base_name)), + ) + else: + union_items.append( + option.get_c_code(base_name, ignore_fields=(ignore_fields | parent_fields)) + ) + union_items.append( + "uint8_t bytes[%s];" % max( + (option.get_min_size() for option in cls._union_options[cls.union_type_field].values()), + default=0, ) - items.append(( - "union {\n"+indent_c("\n".join(union_items))+"\n}", - "")) + ) + union_code = "{\n"+indent_c("\n".join(union_items))+"\n}", + if union_only: + return "typedef union __packed %s" % union_code, ""; + else: + items.append(("union %s;" % union_code, "")) + elif union_only: + return "", "" if no_empty and not items: return "", "" # todo: struct comment - pre = "" if typedef: comment = cls.__doc__.strip() if comment: @@ -372,23 +395,31 @@ class StructType: return pre, "" @classmethod - def get_c_code(cls, name=None, ignore_fields=None, no_empty=False, typedef=True) -> str: - pre, post = cls.get_c_parts(ignore_fields=ignore_fields, no_empty=no_empty, typedef=typedef) + def get_c_code(cls, name=None, ignore_fields=None, no_empty=False, typedef=True, union_only=False, + union_member_as_types=False) -> str: + pre, post = cls.get_c_parts(ignore_fields=ignore_fields, no_empty=no_empty, typedef=typedef, + union_only=union_only, union_member_as_types=union_member_as_types, + ) if no_empty and not pre and not post: return "" return "%s %s%s;" % (pre, name, post) @classmethod - def get_base_name(cls): - return cls.__name__ + def get_variable_name(cls, base_name): + return base_name @classmethod - def get_variable_name(cls): - return cls.get_base_name() + def get_struct_name(cls, base_name): + return "%s_t" % base_name @classmethod - def get_struct_name(cls): - return "%s_t" % cls.get_base_name() + def get_min_size(cls) -> int: + if cls.union_type_field: + return ( + {f.name: field for f in fields()}[cls.union_type_field].metadata["format"].get_min_size() + + sum((option.get_min_size() for option in cls._union_options[cls.union_type_field].values()), start=0) + ) + return sum((f.metadata.get("format", f.type).get_min_size() for f in fields(cls)), start=0) class MacAddressFormat(FixedHexFormat): @@ -426,6 +457,7 @@ class MultipinLedConfig(LedConfig, StructType, led_type=LedType.MULTIPIN): gpio_blue: int = field(metadata={"format": SimpleFormat('B')}) +@dataclass class RangeItemType(StructType): address: str = field(metadata={"format": MacAddressFormat()}) distance: int = field(metadata={"format": SimpleFormat('H')}) diff --git a/src/c3nav/mesh/management/commands/mesh_msg_c.py b/src/c3nav/mesh/management/commands/mesh_msg_c.py index 034a0bcc..ce06f89f 100644 --- a/src/c3nav/mesh/management/commands/mesh_msg_c.py +++ b/src/c3nav/mesh/management/commands/mesh_msg_c.py @@ -2,22 +2,27 @@ from dataclasses import fields from django.core.management.base import BaseCommand -from c3nav.mesh.dataformats import normalize_name -from c3nav.mesh.messages import MeshMessage +from c3nav.mesh.dataformats import normalize_name, LedConfig +from c3nav.mesh.messages import MeshMessage, MeshMessageType from c3nav.mesh.utils import indent_c class Command(BaseCommand): help = 'export mesh message structs for c code' + def shorten_name(self, name): + name = name.replace('config', 'cfg') + name = name.replace('position', 'pos') + name = name.replace('mesh_', '') + name = name.replace('firmware', 'fw') + return name + def handle(self, *args, **options): done_struct_names = set() nodata = set() struct_lines = {} ignore_names = set(field_.name for field_ in fields(MeshMessage)) - from pprint import pprint - pprint(MeshMessage.get_msg_types()) for msg_id, msg_type in MeshMessage.get_msg_types().items(): if msg_type.c_struct_name: if msg_type.c_struct_name in done_struct_names: @@ -25,16 +30,22 @@ class Command(BaseCommand): done_struct_names.add(msg_type.c_struct_name) msg_type = MeshMessage.c_structs[msg_type.c_struct_name] - name = "mesh_msg_%s_t" % ( - msg_type.c_struct_name or normalize_name(getattr(msg_id, 'name', msg_type.__name__)) - ) + base_name = (msg_type.c_struct_name or self.shorten_name(normalize_name( + getattr(msg_id, 'name', msg_type.__name__) + ))) + name = "mesh_msg_%s_t" % base_name + + if msg_id == MeshMessageType.CONFIG_LED: + msg_type = LedConfig + code = msg_type.get_c_code(name, ignore_fields=ignore_names, no_empty=True) if code: + struct_lines[base_name] = "%s %s;" % (name, base_name.replace('_announce', '')) print(code) print() else: nodata.add(msg_type) - return + print("/** union between all message data structs */") print("typedef union __packed {") for line in struct_lines.values(): @@ -42,14 +53,17 @@ class Command(BaseCommand): print("} mesh_msg_data_t;") print() - max_msg_type = max(MeshMessage.msg_types.keys()) + max_msg_type = max(MeshMessage.get_msg_types().keys()) macro_data = [] for i in range(((max_msg_type//16)+1)*16): - msg_type = MeshMessage.msg_types.get(i, None) + msg_type = MeshMessage.get_msg_types().get(i, None) if msg_type: + name = (msg_type.c_struct_name or self.shorten_name(normalize_name( + getattr(msg_type.msg_id, 'name', msg_type.__name__) + ))) macro_data.append(( msg_type.get_c_enum_name()+',', - ("nodata" if msg_type in nodata else msg_type.get_c_struct_name())+',', + ("nodata" if msg_type in nodata else name)+',', msg_type.get_var_num(), msg_type.__doc__.strip(), )) diff --git a/src/c3nav/mesh/messages.py b/src/c3nav/mesh/messages.py index 1a2ba074..a01b8a1d 100644 --- a/src/c3nav/mesh/messages.py +++ b/src/c3nav/mesh/messages.py @@ -9,7 +9,8 @@ from asgiref.sync import async_to_sync from c3nav.mesh.utils import get_mesh_comm_group, indent_c from c3nav.mesh.dataformats import (BoolFormat, FixedStrFormat, FixedHexFormat, LedConfig, LedConfig, - MacAddressesListFormat, MacAddressFormat, SimpleFormat, VarStrFormat, StructType) + MacAddressesListFormat, MacAddressFormat, SimpleFormat, VarStrFormat, StructType, + VarArrayFormat, RangeItemType) MESH_ROOT_ADDRESS = '00:00:00:00:00:00' MESH_PARENT_ADDRESS = '00:00:00:ff:ff:ff' @@ -120,15 +121,12 @@ class MeshMessage(StructType, union_type_field="msg_id"): return sum((getattr(field.metadata["format"], "var_num", 0) for field in fields(cls)), start=0) @classmethod - def get_c_struct_name(cls): - return ( - cls.c_struct_name if cls.c_struct_name else - re.sub( - r"([a-z])([A-Z])", - r"\1_\2", - cls.__name__.removeprefix('Mesh').removesuffix('Message') - ).lower().replace('config', 'cfg').replace('firmware', 'fw').replace('position', 'pos') - ) + def get_variable_name(cls, base_name): + return cls.c_struct_name or base_name + + @classmethod + def get_struct_name(cls, base_name): + return "mesh_msg_%s_t" % base_name @classmethod def get_c_enum_name(cls): @@ -316,5 +314,4 @@ class ConfigUplinkMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_UPLINK): @dataclass class LocateReportRangeMessage(MeshMessage, msg_id=MeshMessageType.LOCATE_REPORT_RANGE): """ report distance to given nodes """ - #ranges: dict[str, int] = - pass \ No newline at end of file + ranges: dict[str, int] = field(metadata={"format": VarArrayFormat(RangeItemType)}) \ No newline at end of file