diff --git a/src/c3nav/mesh/baseformats.py b/src/c3nav/mesh/baseformats.py new file mode 100644 index 00000000..312dcfe0 --- /dev/null +++ b/src/c3nav/mesh/baseformats.py @@ -0,0 +1,477 @@ +import re +import struct +from abc import ABC, abstractmethod +from dataclasses import dataclass, fields, field, Field +from typing import Any, Sequence, Self + +from c3nav.mesh.utils import indent_c + + +class BaseFormat(ABC): + + + def get_var_num(self): + return 0 + + @abstractmethod + def encode(self, value): + pass + + @classmethod + @abstractmethod + def decode(cls, data) -> tuple[Any, bytes]: + pass + + def fromjson(self, data): + return data + + def tojson(self, data): + return data + + @abstractmethod + def get_min_size(self): + pass + + @abstractmethod + def get_c_parts(self) -> tuple[str, str]: + pass + + def get_c_code(self, name) -> str: + pre, post = self.get_c_parts() + return "%s %s%s;" % (pre, name, post) + + +class SimpleFormat(BaseFormat): + def __init__(self, fmt): + self.fmt = fmt + self.size = struct.calcsize(fmt) + + self.c_type = self.c_types[self.fmt[-1]] + self.num = int(self.fmt[:-1]) if len(self.fmt) > 1 else 1 + + def encode(self, value): + return struct.pack(self.fmt, (value,) if self.num == 1 else tuple(value)) + + def decode(self, data: bytes) -> tuple[Any, bytes]: + value = struct.unpack(self.fmt, data[:self.size]) + if len(value) == 1: + value = value[0] + return value, data[self.size:] + + def get_min_size(self): + return self.size + + c_types = { + "B": "uint8_t", + "H": "uint16_t", + "I": "uint32_t", + "b": "int8_t", + "h": "int16_t", + "i": "int32_t", + "s": "char", + } + + def get_c_parts(self): + return self.c_type, ("" if self.num == 1 else ("[%d]" % self.num)) + + +class BoolFormat(SimpleFormat): + def __init__(self): + super().__init__('B') + + def encode(self, value): + return super().encode(int(value)) + + def decode(self, data: bytes) -> tuple[bool, bytes]: + value, data = super().decode(data) + return bool(value), data + + +class FixedStrFormat(SimpleFormat): + def __init__(self, num): + self.num = num + super().__init__('%ds' % self.num) + + def encode(self, value: str): + return value.encode()[:self.num].ljust(self.num, bytes((0,))), + + def decode(self, data: bytes) -> tuple[str, bytes]: + return data[:self.num].rstrip(bytes((0,))).decode(), data[self.num:] + + +class FixedHexFormat(SimpleFormat): + def __init__(self, num, sep=''): + self.num = num + self.sep = sep + super().__init__('%dB' % self.num) + + def encode(self, value: str): + return super().encode(tuple(bytes.fromhex(value))) + + def decode(self, data: bytes) -> tuple[str, bytes]: + return self.sep.join(('%02x' % i) for i in data[:self.num]), data[self.num:] + + +@abstractmethod +class BaseVarFormat(BaseFormat, ABC): + def __init__(self, num_fmt='B'): + self.num_fmt = num_fmt + self.num_size = struct.calcsize(self.num_fmt) + + def get_min_size(self): + return self.num_size + + def get_num_c_code(self): + return SimpleFormat(self.num_fmt).get_c_code("num") + + +class VarArrayFormat(BaseVarFormat): + def __init__(self, child_type, num_fmt='B'): + super().__init__(num_fmt=num_fmt) + self.child_type = child_type + self.child_size = self.child_type.get_min_size() + + def get_var_num(self): + return self.child_size + pass + + def encode(self, values: Sequence) -> bytes: + data = struct.pack(self.num_fmt, (len(values),)) + for value in values: + data += self.child_type.encode(value) + return data + + def decode(self, data: bytes) -> tuple[list[Any], bytes]: + num = struct.unpack(self.num_fmt, data[:self.num_size])[0] + return [ + self.child_type.decode(data[i:i + self.child_size]) + for i in range(self.num_size, self.num_size + num * self.child_size, self.child_size) + ], data[self.num_size + num * self.child_size:] + + def get_c_parts(self): + pre, post = self.child_type.get_c_parts() + return super().get_num_c_code() + "\n" + pre, "[0]" + post + + +class VarStrFormat(BaseVarFormat): + + def get_var_num(self): + return 1 + + def encode(self, value: str) -> bytes: + return struct.pack(self.num_fmt, (len(str),)) + value.encode() + + def decode(self, data: bytes) -> tuple[str, bytes]: + num = struct.unpack(self.num_fmt, data[:self.num_size])[0] + return data[self.num_size:self.num_size + num].rstrip(bytes((0,))).decode(), data[self.num_size + num:] + + def get_c_parts(self): + return super().get_num_c_code() + "\n" + "char", "[0]" + + +@dataclass +class StructType: + _union_options = {} + union_type_field = None + + # noinspection PyMethodOverriding + def __init_subclass__(cls, /, union_type_field=None, no_c_type=False, **kwargs): + cls.union_type_field = union_type_field + if union_type_field: + if union_type_field in cls._union_options: + raise TypeError('Duplicate union_type_field: %s', union_type_field) + cls._union_options[union_type_field] = {} + f = getattr(cls, union_type_field) + metadata = dict(f.metadata) + metadata['union_discriminator'] = True + f.metadata = metadata + f.repr = False + f.init = False + + for attr_name in cls.__dict__.keys(): + attr = getattr(cls, attr_name) + if isinstance(attr, Field): + metadata = dict(attr.metadata) + if "defining_class" not in metadata: + metadata["defining_class"] = cls + attr.metadata = metadata + + for key, values in cls._union_options.items(): + value = kwargs.pop(key, None) + if value is not None: + if value in values: + raise TypeError('Duplicate %s: %s', (key, value)) + values[value] = cls + setattr(cls, key, value) + super().__init_subclass__(**kwargs) + + @classmethod + def get_var_num(cls): + return sum([f.metadata.get("format", f.type).get_var_num() for f in fields(cls)], start=0) + + @classmethod + def encode(cls, instance) -> bytes: + data = bytes() + if cls.union_type_field and type(instance) is not cls: + if not isinstance(instance, cls): + raise ValueError('expected value of type %r, got %r' % (cls, instance)) + + for field_ in fields(instance): + if field_.name is cls.union_type_field: + data += field_.metadata["format"].encode(getattr(instance, field_.name)) + break + else: + raise TypeError('couldn\'t find %s value' % cls.union_type_field) + + data += instance.encode(instance) + return data + + for field_ in fields(cls): + value = getattr(instance, field_.name) + if "format" in field_.metadata: + data += field_.metadata["format"].encode(value) + elif issubclass(field_.type, StructType): + if not isinstance(value, field_.type): + raise ValueError('expected value of type %r for %s.%s, got %r' % + (field_.type, cls.__name__, field_.name, value)) + data += value.encode(value) + else: + raise TypeError('field %s.%s has no format and is no StructType' % + (cls.__class__.__name__, field_.name)) + return data + + @classmethod + def decode(cls, data: bytes) -> Self: + values = {} + for field_ in fields(cls): + if "format" in field_.metadata: + data = field_.metadata["format"].decode(data) + elif issubclass(field_.type, StructType): + data = field_.type.decode(data) + else: + raise TypeError('field %s.%s has no format and is no StructType' % + (cls.__name__, field_.name)) + values[field_.name] = field_.metadata["format"].decode(data) + + if cls.union_type_field: + try: + type_value = values[cls.union_type_field] + except KeyError: + raise TypeError('union_type_field %s.%s is missing' % + (cls.__name__, cls.union_type_field)) + try: + klass = cls._union_options[type_value] + except KeyError: + raise TypeError('union_type_field %s.%s value %r no known' % + (cls.__name__, cls.union_type_field, type_value)) + return klass.decode(data) + return cls(**values) + + @classmethod + def tojson(cls, instance) -> dict: + result = {} + + if cls.union_type_field and type(instance) is not cls: + if not isinstance(instance, cls): + raise ValueError('expected value of type %r, got %r' % (cls, instance)) + + for field_ in fields(instance): + if field_.name is cls.union_type_field: + result[field_.name] = field_.metadata["format"].encode(getattr(instance, field_.name)) + break + else: + raise TypeError('couldn\'t find %s value' % cls.union_type_field) + + result.update(instance.tojson(instance)) + return result + + for field_ in fields(cls): + value = getattr(instance, field_.name) + if "format" in field_.metadata: + result[field_.name] = field_.metadata["format"].tojson(value) + elif issubclass(field_.type, StructType): + if not isinstance(value, field_.type): + raise ValueError('expected value of type %r for %s.%s, got %r' % + (field_.type, cls.__name__, field_.name, value)) + result[field_.name] = value.tojson(value) + else: + raise TypeError('field %s.%s has no format and is no StructType' % + (cls.__class__.__name__, field_.name)) + return result + + @classmethod + def fromjson(cls, data): + data = data.copy() + + # todo: upgrade_json + + kwargs = {} + for field_ in fields(cls): + if "format" in field_.metadata: + data = field_.metadata["format"].decode(data) + elif issubclass(field_.type, StructType): + data = field_.type.decode(data) + else: + raise TypeError('field %s.%s has no format and is no StructType' % + (cls.__name__, field_.name)) + kwargs[field_.name], data = field_.metadata["format"].decode(data) + + if cls.union_type_field: + try: + type_value = kwargs[cls.union_type_field] + except KeyError: + raise TypeError('union_type_field %s.%s is missing' % + (cls.__name__, cls.union_type_field)) + try: + klass = cls._union_options[type_value] + except KeyError: + raise TypeError('union_type_field %s.%s value %r no known' % + (cls.__name__, cls.union_type_field, type_value)) + return klass.fromjson(data) + + return cls(**kwargs) + + @classmethod + def get_c_struct_items(cls, ignore_fields=None, no_empty=False, top_level=False, union_only=False, in_union=False): + ignore_fields = set() if not ignore_fields else set(ignore_fields) + + items = [] + + for field_ in fields(cls): + if field_.name in ignore_fields: + continue + if in_union and field_.metadata["defining_class"] != cls: + continue + + name = field_.metadata.get("c_name", field_.name) + if "format" in field_.metadata: + if not field_.metadata.get("union_discriminator") or field_.metadata.get("defining_class") == cls: + items.append(( + field_.metadata["format"].get_c_code(name), + field_.metadata.get("doc", None), + )), + elif issubclass(field_.type, StructType): + if field_.metadata.get("c_embed"): + embedded_items = field_.type.get_c_struct_items(ignore_fields, no_empty, top_level, union_only) + items.extend(embedded_items) + else: + items.append(( + field_.type.get_c_code(name, typedef=False), + field_.metadata.get("doc", None), + )) + else: + raise TypeError('field %s.%s has no format and is no StructType' % + (cls.__name__, field_.name)) + + if cls.union_type_field: + if not union_only: + union_code = cls.get_c_union_code(ignore_fields) + items.append(("union __packed %s;" % union_code, "")) + + return items + + @classmethod + def get_c_union_size(cls): + return max( + (option.get_min_size(no_inherited_fields=True) for option in + cls._union_options[cls.union_type_field].values()), + default=0, + ) + + @classmethod + def get_c_union_code(cls, ignore_fields=None): + union_items = [] + for key, option in cls._union_options[cls.union_type_field].items(): + base_name = normalize_name(getattr(key, 'name', option.__name__)) + union_items.append( + option.get_c_code(base_name, ignore_fields=ignore_fields, typedef=False, in_union=True) + ) + size = cls.get_c_union_size() + union_items.append( + "uint8_t bytes[%0d]; " % size + ) + return "{\n" + indent_c("\n".join(union_items)) + "\n}" + + @classmethod + def get_c_parts(cls, ignore_fields=None, no_empty=False, top_level=False, union_only=False, in_union=False): + ignore_fields = set() if not ignore_fields else set(ignore_fields) + + if union_only: + if cls.union_type_field: + union_code = cls.get_c_union_code(ignore_fields) + return "typedef union __packed %s" % union_code, "" + else: + return "", "" + + pre = "" + + items = cls.get_c_struct_items(ignore_fields=ignore_fields, + no_empty=no_empty, + top_level=top_level, + union_only=union_only, + in_union=in_union) + + if no_empty and not items: + return "", "" + + # todo: struct comment + if top_level: + comment = cls.__doc__.strip() + if comment: + pre += "/** %s */\n" % comment + pre += "typedef struct __packed " + else: + pre += "struct __packed " + + pre += "{\n%(elements)s\n}" % { + "elements": indent_c( + "\n".join( + code + ("" if not comment else (" /** %s */" % comment)) + for code, comment in items + ) + ), + } + return pre, "" + + @classmethod + def get_c_code(cls, name=None, ignore_fields=None, no_empty=False, typedef=True, union_only=False, + in_union=False) -> str: + pre, post = cls.get_c_parts(ignore_fields=ignore_fields, + no_empty=no_empty, + top_level=typedef, + union_only=union_only, + in_union=in_union) + if no_empty and not pre and not post: + return "" + return "%s %s%s;" % (pre, name, post) + + @classmethod + def get_variable_name(cls, base_name): + return base_name + + @classmethod + def get_struct_name(cls, base_name): + return "%s_t" % base_name + + @classmethod + def get_min_size(cls, no_inherited_fields=False) -> int: + if cls.union_type_field: + own_size = sum([f.metadata.get("format", f.type).get_min_size() for f in fields(cls)]) + union_size = max( + [0] + [option.get_min_size(True) for option in cls._union_options[cls.union_type_field].values()]) + return own_size + union_size + if no_inherited_fields: + relevant_fields = [f for f in fields(cls) if f.metadata["defining_class"] == cls] + else: + relevant_fields = [f for f in fields(cls) if not f.metadata.get("union_discriminator")] + return sum((f.metadata.get("format", f.type).get_min_size() for f in relevant_fields), start=0) + + +def normalize_name(name): + if '_' in name: + return name.lower() + return re.sub( + r"([a-z])([A-Z])", + r"\1_\2", + name + ).lower() diff --git a/src/c3nav/mesh/dataformats.py b/src/c3nav/mesh/dataformats.py index f4c47ee0..90de1d37 100644 --- a/src/c3nav/mesh/dataformats.py +++ b/src/c3nav/mesh/dataformats.py @@ -462,26 +462,27 @@ class MacAddressesListFormat(VarArrayFormat): super().__init__(child_type=MacAddressFormat()) -""" stuff """ @unique class LedType(IntEnum): SERIAL = 1 MULTIPIN = 2 + @dataclass class LedConfig(StructType, union_type_field="led_type"): - led_type: LedType = field(init=False, repr=False, metadata={"format": SimpleFormat('B')}) + led_type: LedType = field(metadata={"format": SimpleFormat('B')}) + leds_are_cool: int = field(metadata={"format": SimpleFormat('B')}) @dataclass -class SerialLedConfig(LedConfig, StructType, led_type=LedType.SERIAL): +class SerialLedConfig(LedConfig, led_type=LedType.SERIAL): gpio: int = field(metadata={"format": SimpleFormat('B')}) rmt: int = field(metadata={"format": SimpleFormat('B')}) @dataclass -class MultipinLedConfig(LedConfig, StructType, led_type=LedType.MULTIPIN): +class MultipinLedConfig(LedConfig, led_type=LedType.MULTIPIN): gpio_red: int = field(metadata={"format": SimpleFormat('B')}) gpio_green: int = field(metadata={"format": SimpleFormat('B')}) gpio_blue: int = field(metadata={"format": SimpleFormat('B')}) @@ -491,3 +492,17 @@ class MultipinLedConfig(LedConfig, StructType, led_type=LedType.MULTIPIN): class RangeItemType(StructType): address: str = field(metadata={"format": MacAddressFormat()}) distance: int = field(metadata={"format": SimpleFormat('H')}) + + +@dataclass +class FirmwareAppDescription(StructType, no_c_type=True): + magic_word: int = field(metadata={"format": SimpleFormat('I')}, repr=False) + secure_version: int = field(metadata={"format": SimpleFormat('I')}) + reserv1: list[int] = field(metadata={"format": SimpleFormat('2I')}, repr=False) + version: str = field(metadata={"format": FixedStrFormat(32)}) + project_name: str = field(metadata={"format": FixedStrFormat(32)}) + compile_time: str = field(metadata={"format": FixedStrFormat(16)}) + compile_date: str = field(metadata={"format": FixedStrFormat(16)}) + idf_version: str = field(metadata={"format": FixedStrFormat(32)}) + app_elf_sha256: str = field(metadata={"format": FixedHexFormat(32)}) + reserv2: list[int] = field(metadata={"format": SimpleFormat('20I')}, repr=False) \ No newline at end of file diff --git a/src/c3nav/mesh/management/commands/mesh_msg_c.py b/src/c3nav/mesh/management/commands/mesh_msg_c.py index c8175edd..28ab7d68 100644 --- a/src/c3nav/mesh/management/commands/mesh_msg_c.py +++ b/src/c3nav/mesh/management/commands/mesh_msg_c.py @@ -2,7 +2,8 @@ from dataclasses import fields from django.core.management.base import BaseCommand -from c3nav.mesh.dataformats import normalize_name, LedConfig +from c3nav.mesh.dataformats import LedConfig +from c3nav.mesh.baseformats import normalize_name from c3nav.mesh.messages import MeshMessage, MeshMessageType from c3nav.mesh.utils import indent_c @@ -17,10 +18,11 @@ class Command(BaseCommand): name = name.replace('firmware', 'fw') return name - def handle(self, *args, **options): + def handle(self, *args, **options): done_struct_names = set() nodata = set() struct_lines = {} + struct_sizes = [] ignore_names = set(field_.name for field_ in fields(MeshMessage)) for msg_id, msg_type in MeshMessage.get_types().items(): @@ -35,13 +37,13 @@ class Command(BaseCommand): ))) name = "mesh_msg_%s_t" % base_name - if msg_id == MeshMessageType.CONFIG_LED: - msg_type = LedConfig - code = msg_type.get_c_code(name, ignore_fields=ignore_names, no_empty=True) if code: + size = msg_type.get_min_size(no_inherited_fields=True) struct_lines[base_name] = "%s %s;" % (name, base_name.replace('_announce', '')) + struct_sizes.append(size) print(code) + print("static_assert(sizeof(%s) == %d, \"size of generated message structs is calculated wrong\");" % (name, size)) print() else: nodata.add(msg_type) @@ -50,8 +52,8 @@ class Command(BaseCommand): print("typedef union __packed {") for line in struct_lines.values(): print(indent_c(line)) - print("} mesh_msg_data_t;") - print() + print("} mesh_msg_data_t; ") + print("static_assert(sizeof(mesh_msg_data_t) == %d, \"size of generated message structs is calculated wrong\");" % max(struct_sizes)) max_msg_type = max(MeshMessage.get_types().keys()) macro_data = [] diff --git a/src/c3nav/mesh/messages.py b/src/c3nav/mesh/messages.py index f4987dff..fa394067 100644 --- a/src/c3nav/mesh/messages.py +++ b/src/c3nav/mesh/messages.py @@ -8,9 +8,10 @@ import channels from asgiref.sync import async_to_sync from c3nav.mesh.utils import get_mesh_comm_group, indent_c -from c3nav.mesh.dataformats import (BoolFormat, FixedStrFormat, FixedHexFormat, LedConfig, LedConfig, - MacAddressesListFormat, MacAddressFormat, SimpleFormat, VarStrFormat, StructType, - VarArrayFormat, RangeItemType) +from c3nav.mesh.dataformats import (LedConfig, LedConfig, + MacAddressesListFormat, MacAddressFormat, RangeItemType, FirmwareAppDescription) +from c3nav.mesh.baseformats import SimpleFormat, BoolFormat, FixedStrFormat, FixedHexFormat, VarArrayFormat, \ + VarStrFormat, StructType MESH_ROOT_ADDRESS = '00:00:00:00:00:00' MESH_PARENT_ADDRESS = '00:00:00:ff:ff:ff' @@ -31,6 +32,7 @@ class MeshMessageType(IntEnum): MESH_ROUTE_REQUEST = 0x07 MESH_ROUTE_RESPONSE = 0x08 MESH_ROUTE_TRACE = 0x09 + MESH_ROUTING_FAILED = 0x0a CONFIG_DUMP = 0x10 CONFIG_FIRMWARE = 0x11 @@ -83,12 +85,6 @@ class MeshMessage(StructType, union_type_field="msg_id"): def get_additional_c_fields(self): return () - @classmethod - def get_var_num(cls): - return 0 - # todo: fix - return sum((getattr(field.metadata["format"], "var_num", 0) for field in fields(cls)), start=0) - @classmethod def get_variable_name(cls, base_name): return cls.c_struct_name or base_name @@ -201,6 +197,12 @@ class MeshRouteTraceMessage(MeshMessage, msg_id=MeshMessageType.MESH_ROUTE_TRACE }) +@dataclass +class MeshRoutingFailedMessage(MeshMessage, msg_id=MeshMessageType.MESH_ROUTING_FAILED): + """ TODO description""" + address: str = field(metadata={"format": MacAddressFormat()}) + + @dataclass class ConfigDumpMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_DUMP): """ request for the node to dump its config """ @@ -216,16 +218,7 @@ class ConfigFirmwareMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_FIRMWARE) }) revision_major: int = field(metadata={"format": SimpleFormat('B')}) revision_minor: int = field(metadata={"format": SimpleFormat('B')}) - magic_word: int = field(metadata={"format": SimpleFormat('I')}, repr=False) - secure_version: int = field(metadata={"format": SimpleFormat('I')}) - reserv1: list[int] = field(metadata={"format": SimpleFormat('2I')}, repr=False) - version: str = field(metadata={"format": FixedStrFormat(32)}) - project_name: str = field(metadata={"format": FixedStrFormat(32)}) - compile_time: str = field(metadata={"format": FixedStrFormat(16)}) - compile_date: str = field(metadata={"format": FixedStrFormat(16)}) - idf_version: str = field(metadata={"format": FixedStrFormat(32)}) - app_elf_sha256: str = field(metadata={"format": FixedHexFormat(32)}) - reserv2: list[int] = field(metadata={"format": SimpleFormat('20I')}, repr=False) + app_desc: FirmwareAppDescription = field(metadata={'json_embed': True}) @classmethod def upgrade_json(cls, data): @@ -259,7 +252,7 @@ class ConfigPositionMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_POSITION) @dataclass class ConfigLedMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_LED): """ set/respond led config """ - led_config: LedConfig = field(metadata={}) + led_config: LedConfig = field(metadata={"c_embed": True}) @dataclass