big refactor of mesh message parsing etc

This commit is contained in:
Laura Klünder 2023-10-05 19:42:36 +02:00
parent 1b8d409839
commit 7a13193acd
3 changed files with 401 additions and 200 deletions

View file

@ -1,24 +1,65 @@
import re
import struct
from dataclasses import dataclass, field
from enum import IntEnum
from abc import ABC, abstractmethod
from dataclasses import dataclass, field, fields
from enum import IntEnum, unique
from itertools import chain
from typing import Self, Sequence, Any
from c3nav.mesh.utils import indent_c
MAC_FMT = '%02x:%02x:%02x:%02x:%02x:%02x'
class SimpleFormat:
class BaseFormat(ABC):
@abstractmethod
def encode(self, value):
pass
@classmethod
@abstractmethod
def decode(cls, data) -> tuple[Any, bytes]:
pass
def fromjson(self, data):
return data
def tojson(self, data):
return data
@abstractmethod
def get_min_size(self):
pass
@abstractmethod
def get_c_parts(self) -> tuple[str, str]:
pass
def get_c_code(self, name) -> str:
pre, post = self.get_c_parts()
return "%s %s%s;" % (pre, name, post)
class SimpleFormat(BaseFormat):
def __init__(self, fmt):
self.fmt = fmt
self.size = struct.calcsize(fmt)
def encode(self, value):
return struct.pack(self.fmt, value)
self.c_type = self.c_types[self.fmt[-1]]
self.num = int(self.fmt[:-1]) if len(self.fmt) > 1 else 1
def decode(self, data: bytes):
def encode(self, value):
return struct.pack(self.fmt, (value, ) if self.num == 1 else tuple(value))
def decode(self, data: bytes) -> tuple[Any, bytes]:
value = struct.unpack(self.fmt, data[:self.size])
if len(value) == 1:
value = value[0]
return value, data[self.size:]
def get_min_size(self):
return self.size
c_types = {
"B": "uint8_t",
"H": "uint16_t",
@ -26,185 +67,365 @@ class SimpleFormat:
"b": "int8_t",
"h": "int16_t",
"i": "int32_t",
"s": "char",
}
def get_c_struct(self, name):
c_type = self.c_types[self.fmt[-1]]
num = int(self.fmt[:-1]) if len(self.fmt) > 1 else 1
if num == 1:
return "%s %s;" % (c_type, name)
else:
return "%s %s[%d];" % (c_type, name, num)
def get_c_parts(self):
return self.c_type, ("" if self.num == 1 else ("[%d]" % self.num))
class FixedStrFormat:
class BoolFormat(SimpleFormat):
def __init__(self):
super().__init__('B')
def encode(self, value):
return super().encode(int(value))
def decode(self, data: bytes) -> tuple[bool, bytes]:
value, data = super().decode(data)
return bool(value), data
class FixedStrFormat(SimpleFormat):
def __init__(self, num):
self.num = num
super().__init__('%ds' % self.num)
def encode(self, value):
return struct.pack('%ss' % self.num, value.encode())
def encode(self, value: str):
return value.encode()[:self.num].ljust(self.num, bytes((0, ))),
def decode(self, data: bytes):
return struct.unpack('%ss' % self.num, data[:self.num])[0].rstrip(bytes((0, ))).decode(), data[self.num:]
def get_c_struct(self, name):
return "char %(name)s[%(length)d];" % {
"name": name,
"length": self.num,
}
def decode(self, data: bytes) -> tuple[str, bytes]:
return data[:self.num].rstrip(bytes((0,))).decode(), data[self.num:]
class BoolFormat:
def encode(self, value):
return struct.pack('B', int(value))
def decode(self, data: bytes):
return bool(struct.unpack('B', data[:1])[0]), data[1:]
def get_c_struct(self, name):
return "uint8_t %(name)s;" % {
"name": name,
}
class HexFormat:
class FixedHexFormat(SimpleFormat):
def __init__(self, num, sep=''):
self.num = num
self.sep = sep
super().__init__('%dB' % self.num)
def encode(self, value):
return struct.pack('%ss' % self.num, bytes.fromhex(value))
def encode(self, value: str):
return super().encode(tuple(bytes.fromhex(value)))
def decode(self, data: bytes):
return (
struct.unpack('%ss' % self.num, data[:self.num])[0].hex(*([self.sep] if self.sep else [])),
data[self.num:]
)
def get_c_struct(self, name):
return "uint8_t %(name)s[%(length)d];" % {
"name": name,
"length": self.num,
}
def decode(self, data: bytes) -> tuple[str, bytes]:
return self.sep.join(('%02x' % i) for i in data[:self.num]), data[self.num:]
class VarStrFormat:
var_num = 1
@abstractmethod
class BaseVarFormat(BaseFormat, ABC):
def __init__(self, num_fmt='B'):
self.num_fmt = num_fmt
self.num_size = struct.calcsize(self.num_fmt)
def get_min_size(self):
return self.num_size
def get_num_c_code(self):
return SimpleFormat(self.num_fmt).get_c_code("num")
class VarArrayFormat(BaseVarFormat):
def __init__(self, child_type, num_fmt='B'):
super().__init__(num_fmt=num_fmt)
self.child_type = child_type
self.child_size = self.child_type.get_min_size()
def encode(self, values: Sequence) -> bytes:
data = struct.pack(self.num_fmt, (len(values),))
for value in values:
data += self.child_type.encode(value)
return data
def decode(self, data: bytes) -> tuple[list[Any], bytes]:
num = struct.unpack(self.num_fmt, data[:self.num_size])[0]
return [
self.child_type.decode(data[i:i+self.child_size])
for i in range(self.num_size, self.num_size+num*self.child_size, self.child_size)
], data[self.num_size+num*self.child_size:]
def get_c_parts(self):
pre, post = self.child_type.get_c_parts()
return super().get_num_c_code()+"\n"+pre, "[0]"+post
class VarStrFormat(BaseVarFormat):
def encode(self, value: str) -> bytes:
return bytes((len(value)+1, )) + value.encode() + bytes((0, ))
return struct.pack(self.num_fmt, (len(str),))+value.encode()
def decode(self, data: bytes):
return data[1:data[0]].decode(), data[data[0]+1:]
def decode(self, data: bytes) -> tuple[str, bytes]:
num = struct.unpack(self.num_fmt, data[:self.num_size])[0]
return data[self.num_size:self.num_size+num].rstrip(bytes((0,))).decode(), data[self.num_size+num:]
def get_c_struct(self, name):
return "uint8_t num;\nchar %(name)s[0];" % {
"name": name,
def get_c_parts(self):
return super().get_num_c_code()+"\n"+"char", "[0]"
""" TPYES """
def normalize_name(name):
if '_' in name:
return name.lower()
return re.sub(
r"([a-z])([A-Z])",
r"\1_\2",
name
).lower()
@dataclass
class StructType:
_union_options = {}
union_type_field = None
# noinspection PyMethodOverriding
def __init_subclass__(cls, /, union_type_field=None, **kwargs):
cls.union_type_field = union_type_field
if union_type_field:
if union_type_field in cls._union_options:
raise TypeError('Duplicate union_type_field: %s', union_type_field)
cls._union_options[union_type_field] = {}
for key, values in cls._union_options.items():
value = kwargs.pop(key, None)
if value is not None:
if value in values:
raise TypeError('Duplicate %s: %s', (key, value))
values[value] = cls
setattr(cls, key, value)
super().__init_subclass__(**kwargs)
@classmethod
def encode(cls, instance) -> bytes:
data = bytes()
if cls.union_type_field and type(instance) is not cls:
if not isinstance(instance, cls):
raise ValueError('expected value of type %r, got %r' % (cls, instance))
for field_ in fields(instance):
if field_.name is cls.union_type_field:
data += field_.metadata["format"].encode(getattr(instance, field_.name))
break
else:
raise TypeError('couldn\'t find %s value' % cls.union_type_field)
data += instance.encode(instance)
return data
for field_ in fields(cls):
value = getattr(instance, field_.name)
if "format" in field_.metadata:
data += field_.metadata["format"].encode(value)
elif issubclass(field_.type, StructType):
if not isinstance(value, field_.type):
raise ValueError('expected value of type %r for %s.%s, got %r' %
(field_.type, cls.__name__, field_.name, value))
data += value.encode(value)
else:
raise TypeError('field %s.%s has no format and is no StructType' %
(cls.__class__.__name__, field_.name))
return data
@classmethod
def decode(cls, data: bytes) -> Self:
values = {}
for field_ in fields(cls):
if "format" in field_.metadata:
data = field_.metadata["format"].decode(data)
elif issubclass(field_.type, StructType):
data = field_.type.decode(data)
else:
raise TypeError('field %s.%s has no format and is no StructType' %
(cls.__name__, field_.name))
values[field_.name] = field_.metadata["format"].decode(data)
if cls.union_type_field:
try:
type_value = values[cls.union_type_field]
except KeyError:
raise TypeError('union_type_field %s.%s is missing' %
(cls.__name__, cls.union_type_field))
try:
klass = cls._union_options[type_value]
except KeyError:
raise TypeError('union_type_field %s.%s value %r no known' %
(cls.__name__, cls.union_type_field, type_value))
return klass.decode(data)
return cls(**values)
@classmethod
def tojson(cls, instance) -> dict:
result = {}
if cls.union_type_field and type(instance) is not cls:
if not isinstance(instance, cls):
raise ValueError('expected value of type %r, got %r' % (cls, instance))
for field_ in fields(instance):
if field_.name is cls.union_type_field:
result[field_.name] = field_.metadata["format"].encode(getattr(instance, field_.name))
break
else:
raise TypeError('couldn\'t find %s value' % cls.union_type_field)
result.update(instance.tojson(instance))
return result
for field_ in fields(cls):
value = getattr(instance, field_.name)
if "format" in field_.metadata:
result[field_.name] = field_.metadata["format"].tojson(value)
elif issubclass(field_.type, StructType):
if not isinstance(value, field_.type):
raise ValueError('expected value of type %r for %s.%s, got %r' %
(field_.type, cls.__name__, field_.name, value))
result[field_.name] = value.tojson(value)
else:
raise TypeError('field %s.%s has no format and is no StructType' %
(cls.__class__.__name__, field_.name))
return result
@classmethod
def fromjson(cls, data):
data = data.copy()
# todo: upgrade_json
kwargs = {}
for field_ in fields(cls):
if "format" in field_.metadata:
data = field_.metadata["format"].decode(data)
elif issubclass(field_.type, StructType):
data = field_.type.decode(data)
else:
raise TypeError('field %s.%s has no format and is no StructType' %
(cls.__name__, field_.name))
kwargs[field_.name], data = field_.metadata["format"].decode(data)
if cls.union_type_field:
try:
type_value = kwargs[cls.union_type_field]
except KeyError:
raise TypeError('union_type_field %s.%s is missing' %
(cls.__name__, cls.union_type_field))
try:
klass = cls._union_options[type_value]
except KeyError:
raise TypeError('union_type_field %s.%s value %r no known' %
(cls.__name__, cls.union_type_field, type_value))
return klass.fromjson(data)
return cls(**kwargs)
@classmethod
def get_c_parts(cls, ignore_fields=None, no_empty=False, typedef=False):
ignore_fields = set() if not ignore_fields else set(ignore_fields)
items = []
for field_ in fields(cls):
if field_.name in ignore_fields:
continue
if "format" in field_.metadata:
items.append((
field_.metadata["format"].get_c_code(field_.name),
field_.metadata.get("doc", None),
)),
elif issubclass(field_.type, StructType):
items.append((
field_.type.get_c_code(field_.name, typedef=False),
field_.metadata.get("doc", None),
))
else:
raise TypeError('field %s.%s has no format and is no StructType' %
(cls.__name__, field_.name))
if cls.union_type_field:
parent_fields = set(field_.name for field_ in fields(cls))
union_items = []
for key, option in cls._union_options[cls.union_type_field].items():
name = normalize_name(getattr(key, 'name', option.__name__))
union_items.append(
option.get_c_code(name, ignore_fields=(ignore_fields | parent_fields))
)
items.append((
"union {\n"+indent_c("\n".join(union_items))+"\n}",
""))
if no_empty and not items:
return "", ""
# todo: struct comment
pre = ""
if typedef:
comment = cls.__doc__.strip()
if comment:
pre += "/** %s */\n" % comment
pre += "typedef struct __packed "
else:
pre += "struct "
pre += "{\n%(elements)s\n}" % {
"elements": indent_c(
"\n".join(
code + ("" if not comment else (" /** %s */" % comment))
for code, comment in items
)
),
}
return pre, ""
@classmethod
def get_c_code(cls, name=None, ignore_fields=None, no_empty=False, typedef=True) -> str:
pre, post = cls.get_c_parts(ignore_fields=ignore_fields, no_empty=no_empty, typedef=typedef)
if no_empty and not pre and not post:
return ""
return "%s %s%s;" % (pre, name, post)
@classmethod
def get_base_name(cls):
return cls.__name__
@classmethod
def get_variable_name(cls):
return cls.get_base_name()
@classmethod
def get_struct_name(cls):
return "%s_t" % cls.get_base_name()
class MacAddressFormat:
def encode(self, value: str) -> bytes:
return bytes(int(value[i*3:i*3+2], 16) for i in range(6))
def decode(self, data: bytes):
return (MAC_FMT % tuple(data[:6])), data[6:]
def get_c_struct(self, name):
return "uint8_t %(name)s[6];" % {
"name": name,
}
class MacAddressFormat(FixedHexFormat):
def __init__(self):
super().__init__(num=6, sep=':')
class MacAddressesListFormat:
var_num = 6
def encode(self, value: list[str]) -> bytes:
return bytes((len(value), )) + sum(
(bytes((int(mac[i*3:i*3+2], 16) for i in range(6))) for mac in value),
b''
)
def decode(self, data: bytes):
return [MAC_FMT % tuple(data[1+6*i:1+6+6*i]) for i in range(data[0])], data[1+data[0]*6:]
def get_c_struct(self, name):
return "uint8_t num;\nuint8_t %(name)s[6][0];" % {
"name": name,
}
class MacAddressesListFormat(VarArrayFormat):
def __init__(self):
super().__init__(child_type=MacAddressFormat())
""" stuff """
@unique
class LedType(IntEnum):
SERIAL = 1
MULTIPIN = 2
@dataclass
class LedConfig:
led_type: LedType = field(init=False, repr=False)
ledconfig_types = {}
# noinspection PyMethodOverriding
def __init_subclass__(cls, /, led_type: LedType, **kwargs):
super().__init_subclass__(**kwargs)
cls.led_type = led_type
LedConfig.ledconfig_types[led_type] = cls
@classmethod
def fromjson(cls, data):
if data is None:
return None
return LedConfig.ledconfig_types[data.pop('led_type')](**data)
class LedConfig(StructType, union_type_field="led_type"):
led_type: LedType = field(init=False, repr=False, metadata={"format": SimpleFormat('B')})
@dataclass
class SerialLedConfig(LedConfig, led_type=LedType.SERIAL):
gpio: int
rmt: int
class SerialLedConfig(LedConfig, StructType, led_type=LedType.SERIAL):
gpio: int = field(metadata={"format": SimpleFormat('B')})
rmt: int = field(metadata={"format": SimpleFormat('B')})
@dataclass
class MultipinLedConfig(LedConfig, led_type=LedType.MULTIPIN):
gpio_red: int
gpio_green: int
gpio_blue: int
class MultipinLedConfig(LedConfig, StructType, led_type=LedType.MULTIPIN):
gpio_red: int = field(metadata={"format": SimpleFormat('B')})
gpio_green: int = field(metadata={"format": SimpleFormat('B')})
gpio_blue: int = field(metadata={"format": SimpleFormat('B')})
class LedConfigFormat:
def encode(self, value) -> bytes:
if value is None:
return struct.pack('BBBB', (0, 0, 0, 0))
if isinstance(value, SerialLedConfig):
return struct.pack('BBBB', (value.type_id, value.gpio, value.rmt, 0))
if isinstance(value, MultipinLedConfig):
return struct.pack('BBBB', (value.type_id, value.gpio_red, value.gpio_green, value.gpio_blue))
raise ValueError
def decode(self, data: bytes):
type_, *bytes_ = struct.unpack('BBBB', data)
if type_ == 0:
value = None
elif type_ == 1:
value = SerialLedConfig(gpio=bytes_[0], rmt=bytes_[1])
elif type_ == 2:
value = MultipinLedConfig(gpio_red=bytes_[0], gpio_green=bytes_[1], gpio_blue=bytes_[2])
else:
raise ValueError
return value, data[4:]
def get_c_struct(self, name):
return (
"uint8_t type;\n"
"union {\n"
" struct {\n"
" uint8_t gpio;\n"
" uint8_t rmt;\n"
" } serial;\n"
" struct {\n"
" uint8_t gpio_red;\n"
" uint8_t gpio_green;\n"
" uint8_t gpio_blue;\n"
" } multipin;\n"
" uint8_t bytes[3];\n"
"};"
)
class RangeItemType(StructType):
address: str = field(metadata={"format": MacAddressFormat()})
distance: int = field(metadata={"format": SimpleFormat('H')})

View file

@ -1,5 +1,8 @@
from dataclasses import fields
from django.core.management.base import BaseCommand
from c3nav.mesh.dataformats import normalize_name
from c3nav.mesh.messages import MeshMessage
from c3nav.mesh.utils import indent_c
@ -12,26 +15,26 @@ class Command(BaseCommand):
nodata = set()
struct_lines = {}
for msg_type in MeshMessage.msg_types.values():
ignore_names = set(field_.name for field_ in fields(MeshMessage))
from pprint import pprint
pprint(MeshMessage.get_msg_types())
for msg_id, msg_type in MeshMessage.get_msg_types().items():
if msg_type.c_struct_name:
if msg_type.c_struct_name in done_struct_names:
continue
done_struct_names.add(msg_type.c_struct_name)
msg_type = MeshMessage.c_structs[msg_type.c_struct_name]
code = msg_type.get_c_struct()
name = "mesh_msg_%s_t" % (
msg_type.c_struct_name or normalize_name(getattr(msg_id, 'name', msg_type.__name__))
)
code = msg_type.get_c_code(name, ignore_fields=ignore_names, no_empty=True)
if code:
struct_lines[msg_type.get_c_struct_name()] = (
"mesh_msg_%s_t %s;" % (
msg_type.get_c_struct_name(),
msg_type.get_c_struct_name().replace("_announce", ""),
)
)
print(code)
print()
else:
nodata.add(msg_type)
return
print("/** union between all message data structs */")
print("typedef union __packed {")
for line in struct_lines.values():

View file

@ -8,8 +8,8 @@ import channels
from asgiref.sync import async_to_sync
from c3nav.mesh.utils import get_mesh_comm_group, indent_c
from c3nav.mesh.dataformats import (BoolFormat, FixedStrFormat, HexFormat, LedConfig, LedConfigFormat,
MacAddressesListFormat, MacAddressFormat, SimpleFormat, VarStrFormat)
from c3nav.mesh.dataformats import (BoolFormat, FixedStrFormat, FixedHexFormat, LedConfig, LedConfig,
MacAddressesListFormat, MacAddressFormat, SimpleFormat, VarStrFormat, StructType)
MESH_ROOT_ADDRESS = '00:00:00:00:00:00'
MESH_PARENT_ADDRESS = '00:00:00:ff:ff:ff'
@ -37,6 +37,8 @@ class MeshMessageType(IntEnum):
CONFIG_LED = 0x13
CONFIG_UPLINK = 0x14
LOCATE_REPORT_RANGE = 0x20
M = TypeVar('M', bound='MeshMessage')
@ -48,22 +50,16 @@ class ChipType(IntEnum):
@dataclass
class MeshMessage:
class MeshMessage(StructType, union_type_field="msg_id"):
dst: str = field(metadata={"format": MacAddressFormat()})
src: str = field(metadata={"format": MacAddressFormat()})
msg_id: int = field(metadata={"format": SimpleFormat('B')}, init=False, repr=False)
msg_types = {}
c_structs = {}
c_struct_name = None
# noinspection PyMethodOverriding
def __init_subclass__(cls, /, msg_id=None, c_struct_name=None, **kwargs):
def __init_subclass__(cls, /, c_struct_name=None, **kwargs):
super().__init_subclass__(**kwargs)
if msg_id is not None:
cls.msg_id = msg_id
if msg_id in MeshMessage.msg_types:
raise TypeError('duplicate use of msg_id %d' % msg_id)
MeshMessage.msg_types[msg_id] = cls
if c_struct_name:
cls.c_struct_name = c_struct_name
if c_struct_name in MeshMessage.c_structs:
@ -117,40 +113,10 @@ class MeshMessage:
def get_additional_c_fields(self):
return ()
@classmethod
def get_c_struct(cls):
ignore_fields = cls.get_ignore_c_fields()
if cls != MeshMessage:
ignore_fields |= set(field.name for field in fields(MeshMessage))
items = tuple(
(
tuple(field.metadata["format"].get_c_struct(field.metadata.get("c_name", field.name)).split("\n")),
field.metadata.get("doc", None),
)
for field in fields(cls)
if field.name not in ignore_fields
)
if not items:
return ""
max_line_len = max(len(line) for line in chain(*(code for code, doc in items)))
msg_comment = cls.__doc__.strip()
return "%(comment)stypedef struct __packed {\n%(elements)s\n} %(name)s;" % {
"comment": ("/** %s */\n" % msg_comment) if msg_comment else "",
"elements": indent_c(
"\n".join(chain(*(
(code if not comment
else (code[:-1]+("%s /** %s */" % (code[-1].ljust(max_line_len), comment),)))
for code, comment in items
), cls.get_additional_c_fields()))
),
"name": "mesh_msg_%s_t" % cls.get_c_struct_name(),
}
@classmethod
def get_var_num(cls):
return 0
# todo: fix
return sum((getattr(field.metadata["format"], "var_num", 0) for field in fields(cls)), start=0)
@classmethod
@ -172,6 +138,10 @@ class MeshMessage:
cls.__name__.removeprefix('Mesh').removesuffix('Message')
).upper().replace('CONFIG', 'CFG').replace('FIRMWARE', 'FW').replace('POSITION', 'POS')
@classmethod
def get_msg_types(cls):
return cls._union_options["msg_id"]
@dataclass
class NoopMessage(MeshMessage, msg_id=MeshMessageType.NOOP):
@ -291,7 +261,7 @@ class ConfigFirmwareMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_FIRMWARE)
compile_time: str = field(metadata={"format": FixedStrFormat(16)})
compile_date: str = field(metadata={"format": FixedStrFormat(16)})
idf_version: str = field(metadata={"format": FixedStrFormat(32)})
app_elf_sha256: str = field(metadata={"format": HexFormat(32)})
app_elf_sha256: str = field(metadata={"format": FixedHexFormat(32)})
reserv2: list[int] = field(metadata={"format": SimpleFormat('20I')}, repr=False)
@classmethod
@ -327,7 +297,7 @@ class ConfigPositionMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_POSITION)
@dataclass
class ConfigLedMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_LED):
""" set/respond led config """
led_config: LedConfig = field(metadata={"format": LedConfigFormat()})
led_config: LedConfig = field(metadata={})
@dataclass
@ -341,3 +311,10 @@ class ConfigUplinkMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_UPLINK):
ssl: bool = field(metadata={"format": BoolFormat()})
host: str = field(metadata={"format": FixedStrFormat(64)})
port: int = field(metadata={"format": SimpleFormat('H')})
@dataclass
class LocateReportRangeMessage(MeshMessage, msg_id=MeshMessageType.LOCATE_REPORT_RANGE):
""" report distance to given nodes """
#ranges: dict[str, int] =
pass