big refactor of mesh message parsing etc
This commit is contained in:
parent
1b8d409839
commit
7a13193acd
3 changed files with 401 additions and 200 deletions
|
@ -1,24 +1,65 @@
|
||||||
|
import re
|
||||||
import struct
|
import struct
|
||||||
from dataclasses import dataclass, field
|
from abc import ABC, abstractmethod
|
||||||
from enum import IntEnum
|
from dataclasses import dataclass, field, fields
|
||||||
|
from enum import IntEnum, unique
|
||||||
|
from itertools import chain
|
||||||
|
from typing import Self, Sequence, Any
|
||||||
|
|
||||||
|
from c3nav.mesh.utils import indent_c
|
||||||
|
|
||||||
MAC_FMT = '%02x:%02x:%02x:%02x:%02x:%02x'
|
MAC_FMT = '%02x:%02x:%02x:%02x:%02x:%02x'
|
||||||
|
|
||||||
|
|
||||||
class SimpleFormat:
|
class BaseFormat(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def encode(self, value):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
|
def decode(cls, data) -> tuple[Any, bytes]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def fromjson(self, data):
|
||||||
|
return data
|
||||||
|
|
||||||
|
def tojson(self, data):
|
||||||
|
return data
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_min_size(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_c_parts(self) -> tuple[str, str]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_c_code(self, name) -> str:
|
||||||
|
pre, post = self.get_c_parts()
|
||||||
|
return "%s %s%s;" % (pre, name, post)
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleFormat(BaseFormat):
|
||||||
def __init__(self, fmt):
|
def __init__(self, fmt):
|
||||||
self.fmt = fmt
|
self.fmt = fmt
|
||||||
self.size = struct.calcsize(fmt)
|
self.size = struct.calcsize(fmt)
|
||||||
|
|
||||||
def encode(self, value):
|
self.c_type = self.c_types[self.fmt[-1]]
|
||||||
return struct.pack(self.fmt, value)
|
self.num = int(self.fmt[:-1]) if len(self.fmt) > 1 else 1
|
||||||
|
|
||||||
def decode(self, data: bytes):
|
def encode(self, value):
|
||||||
|
return struct.pack(self.fmt, (value, ) if self.num == 1 else tuple(value))
|
||||||
|
|
||||||
|
def decode(self, data: bytes) -> tuple[Any, bytes]:
|
||||||
value = struct.unpack(self.fmt, data[:self.size])
|
value = struct.unpack(self.fmt, data[:self.size])
|
||||||
if len(value) == 1:
|
if len(value) == 1:
|
||||||
value = value[0]
|
value = value[0]
|
||||||
return value, data[self.size:]
|
return value, data[self.size:]
|
||||||
|
|
||||||
|
def get_min_size(self):
|
||||||
|
return self.size
|
||||||
|
|
||||||
c_types = {
|
c_types = {
|
||||||
"B": "uint8_t",
|
"B": "uint8_t",
|
||||||
"H": "uint16_t",
|
"H": "uint16_t",
|
||||||
|
@ -26,185 +67,365 @@ class SimpleFormat:
|
||||||
"b": "int8_t",
|
"b": "int8_t",
|
||||||
"h": "int16_t",
|
"h": "int16_t",
|
||||||
"i": "int32_t",
|
"i": "int32_t",
|
||||||
|
"s": "char",
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_c_struct(self, name):
|
def get_c_parts(self):
|
||||||
c_type = self.c_types[self.fmt[-1]]
|
return self.c_type, ("" if self.num == 1 else ("[%d]" % self.num))
|
||||||
num = int(self.fmt[:-1]) if len(self.fmt) > 1 else 1
|
|
||||||
if num == 1:
|
|
||||||
return "%s %s;" % (c_type, name)
|
|
||||||
else:
|
|
||||||
return "%s %s[%d];" % (c_type, name, num)
|
|
||||||
|
|
||||||
|
|
||||||
class FixedStrFormat:
|
class BoolFormat(SimpleFormat):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__('B')
|
||||||
|
|
||||||
|
def encode(self, value):
|
||||||
|
return super().encode(int(value))
|
||||||
|
|
||||||
|
def decode(self, data: bytes) -> tuple[bool, bytes]:
|
||||||
|
value, data = super().decode(data)
|
||||||
|
return bool(value), data
|
||||||
|
|
||||||
|
|
||||||
|
class FixedStrFormat(SimpleFormat):
|
||||||
def __init__(self, num):
|
def __init__(self, num):
|
||||||
self.num = num
|
self.num = num
|
||||||
|
super().__init__('%ds' % self.num)
|
||||||
|
|
||||||
def encode(self, value):
|
def encode(self, value: str):
|
||||||
return struct.pack('%ss' % self.num, value.encode())
|
return value.encode()[:self.num].ljust(self.num, bytes((0, ))),
|
||||||
|
|
||||||
def decode(self, data: bytes):
|
def decode(self, data: bytes) -> tuple[str, bytes]:
|
||||||
return struct.unpack('%ss' % self.num, data[:self.num])[0].rstrip(bytes((0, ))).decode(), data[self.num:]
|
return data[:self.num].rstrip(bytes((0,))).decode(), data[self.num:]
|
||||||
|
|
||||||
def get_c_struct(self, name):
|
|
||||||
return "char %(name)s[%(length)d];" % {
|
|
||||||
"name": name,
|
|
||||||
"length": self.num,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class BoolFormat:
|
class FixedHexFormat(SimpleFormat):
|
||||||
def encode(self, value):
|
|
||||||
return struct.pack('B', int(value))
|
|
||||||
|
|
||||||
def decode(self, data: bytes):
|
|
||||||
return bool(struct.unpack('B', data[:1])[0]), data[1:]
|
|
||||||
|
|
||||||
def get_c_struct(self, name):
|
|
||||||
return "uint8_t %(name)s;" % {
|
|
||||||
"name": name,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class HexFormat:
|
|
||||||
def __init__(self, num, sep=''):
|
def __init__(self, num, sep=''):
|
||||||
self.num = num
|
self.num = num
|
||||||
self.sep = sep
|
self.sep = sep
|
||||||
|
super().__init__('%dB' % self.num)
|
||||||
|
|
||||||
def encode(self, value):
|
def encode(self, value: str):
|
||||||
return struct.pack('%ss' % self.num, bytes.fromhex(value))
|
return super().encode(tuple(bytes.fromhex(value)))
|
||||||
|
|
||||||
def decode(self, data: bytes):
|
def decode(self, data: bytes) -> tuple[str, bytes]:
|
||||||
return (
|
return self.sep.join(('%02x' % i) for i in data[:self.num]), data[self.num:]
|
||||||
struct.unpack('%ss' % self.num, data[:self.num])[0].hex(*([self.sep] if self.sep else [])),
|
|
||||||
data[self.num:]
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_c_struct(self, name):
|
|
||||||
return "uint8_t %(name)s[%(length)d];" % {
|
|
||||||
"name": name,
|
|
||||||
"length": self.num,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class VarStrFormat:
|
@abstractmethod
|
||||||
var_num = 1
|
class BaseVarFormat(BaseFormat, ABC):
|
||||||
|
def __init__(self, num_fmt='B'):
|
||||||
|
self.num_fmt = num_fmt
|
||||||
|
self.num_size = struct.calcsize(self.num_fmt)
|
||||||
|
|
||||||
|
def get_min_size(self):
|
||||||
|
return self.num_size
|
||||||
|
|
||||||
|
def get_num_c_code(self):
|
||||||
|
return SimpleFormat(self.num_fmt).get_c_code("num")
|
||||||
|
|
||||||
|
|
||||||
|
class VarArrayFormat(BaseVarFormat):
|
||||||
|
def __init__(self, child_type, num_fmt='B'):
|
||||||
|
super().__init__(num_fmt=num_fmt)
|
||||||
|
self.child_type = child_type
|
||||||
|
self.child_size = self.child_type.get_min_size()
|
||||||
|
|
||||||
|
def encode(self, values: Sequence) -> bytes:
|
||||||
|
data = struct.pack(self.num_fmt, (len(values),))
|
||||||
|
for value in values:
|
||||||
|
data += self.child_type.encode(value)
|
||||||
|
return data
|
||||||
|
|
||||||
|
def decode(self, data: bytes) -> tuple[list[Any], bytes]:
|
||||||
|
num = struct.unpack(self.num_fmt, data[:self.num_size])[0]
|
||||||
|
return [
|
||||||
|
self.child_type.decode(data[i:i+self.child_size])
|
||||||
|
for i in range(self.num_size, self.num_size+num*self.child_size, self.child_size)
|
||||||
|
], data[self.num_size+num*self.child_size:]
|
||||||
|
|
||||||
|
def get_c_parts(self):
|
||||||
|
pre, post = self.child_type.get_c_parts()
|
||||||
|
return super().get_num_c_code()+"\n"+pre, "[0]"+post
|
||||||
|
|
||||||
|
|
||||||
|
class VarStrFormat(BaseVarFormat):
|
||||||
def encode(self, value: str) -> bytes:
|
def encode(self, value: str) -> bytes:
|
||||||
return bytes((len(value)+1, )) + value.encode() + bytes((0, ))
|
return struct.pack(self.num_fmt, (len(str),))+value.encode()
|
||||||
|
|
||||||
def decode(self, data: bytes):
|
def decode(self, data: bytes) -> tuple[str, bytes]:
|
||||||
return data[1:data[0]].decode(), data[data[0]+1:]
|
num = struct.unpack(self.num_fmt, data[:self.num_size])[0]
|
||||||
|
return data[self.num_size:self.num_size+num].rstrip(bytes((0,))).decode(), data[self.num_size+num:]
|
||||||
|
|
||||||
def get_c_struct(self, name):
|
def get_c_parts(self):
|
||||||
return "uint8_t num;\nchar %(name)s[0];" % {
|
return super().get_num_c_code()+"\n"+"char", "[0]"
|
||||||
"name": name,
|
|
||||||
|
|
||||||
|
""" TPYES """
|
||||||
|
|
||||||
|
def normalize_name(name):
|
||||||
|
if '_' in name:
|
||||||
|
return name.lower()
|
||||||
|
return re.sub(
|
||||||
|
r"([a-z])([A-Z])",
|
||||||
|
r"\1_\2",
|
||||||
|
name
|
||||||
|
).lower()
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StructType:
|
||||||
|
_union_options = {}
|
||||||
|
union_type_field = None
|
||||||
|
|
||||||
|
# noinspection PyMethodOverriding
|
||||||
|
def __init_subclass__(cls, /, union_type_field=None, **kwargs):
|
||||||
|
cls.union_type_field = union_type_field
|
||||||
|
if union_type_field:
|
||||||
|
if union_type_field in cls._union_options:
|
||||||
|
raise TypeError('Duplicate union_type_field: %s', union_type_field)
|
||||||
|
cls._union_options[union_type_field] = {}
|
||||||
|
for key, values in cls._union_options.items():
|
||||||
|
value = kwargs.pop(key, None)
|
||||||
|
if value is not None:
|
||||||
|
if value in values:
|
||||||
|
raise TypeError('Duplicate %s: %s', (key, value))
|
||||||
|
values[value] = cls
|
||||||
|
setattr(cls, key, value)
|
||||||
|
super().__init_subclass__(**kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def encode(cls, instance) -> bytes:
|
||||||
|
data = bytes()
|
||||||
|
if cls.union_type_field and type(instance) is not cls:
|
||||||
|
if not isinstance(instance, cls):
|
||||||
|
raise ValueError('expected value of type %r, got %r' % (cls, instance))
|
||||||
|
|
||||||
|
for field_ in fields(instance):
|
||||||
|
if field_.name is cls.union_type_field:
|
||||||
|
data += field_.metadata["format"].encode(getattr(instance, field_.name))
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
raise TypeError('couldn\'t find %s value' % cls.union_type_field)
|
||||||
|
|
||||||
|
data += instance.encode(instance)
|
||||||
|
return data
|
||||||
|
|
||||||
|
for field_ in fields(cls):
|
||||||
|
value = getattr(instance, field_.name)
|
||||||
|
if "format" in field_.metadata:
|
||||||
|
data += field_.metadata["format"].encode(value)
|
||||||
|
elif issubclass(field_.type, StructType):
|
||||||
|
if not isinstance(value, field_.type):
|
||||||
|
raise ValueError('expected value of type %r for %s.%s, got %r' %
|
||||||
|
(field_.type, cls.__name__, field_.name, value))
|
||||||
|
data += value.encode(value)
|
||||||
|
else:
|
||||||
|
raise TypeError('field %s.%s has no format and is no StructType' %
|
||||||
|
(cls.__class__.__name__, field_.name))
|
||||||
|
return data
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def decode(cls, data: bytes) -> Self:
|
||||||
|
values = {}
|
||||||
|
for field_ in fields(cls):
|
||||||
|
if "format" in field_.metadata:
|
||||||
|
data = field_.metadata["format"].decode(data)
|
||||||
|
elif issubclass(field_.type, StructType):
|
||||||
|
data = field_.type.decode(data)
|
||||||
|
else:
|
||||||
|
raise TypeError('field %s.%s has no format and is no StructType' %
|
||||||
|
(cls.__name__, field_.name))
|
||||||
|
values[field_.name] = field_.metadata["format"].decode(data)
|
||||||
|
|
||||||
|
if cls.union_type_field:
|
||||||
|
try:
|
||||||
|
type_value = values[cls.union_type_field]
|
||||||
|
except KeyError:
|
||||||
|
raise TypeError('union_type_field %s.%s is missing' %
|
||||||
|
(cls.__name__, cls.union_type_field))
|
||||||
|
try:
|
||||||
|
klass = cls._union_options[type_value]
|
||||||
|
except KeyError:
|
||||||
|
raise TypeError('union_type_field %s.%s value %r no known' %
|
||||||
|
(cls.__name__, cls.union_type_field, type_value))
|
||||||
|
return klass.decode(data)
|
||||||
|
return cls(**values)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tojson(cls, instance) -> dict:
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
if cls.union_type_field and type(instance) is not cls:
|
||||||
|
if not isinstance(instance, cls):
|
||||||
|
raise ValueError('expected value of type %r, got %r' % (cls, instance))
|
||||||
|
|
||||||
|
for field_ in fields(instance):
|
||||||
|
if field_.name is cls.union_type_field:
|
||||||
|
result[field_.name] = field_.metadata["format"].encode(getattr(instance, field_.name))
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
raise TypeError('couldn\'t find %s value' % cls.union_type_field)
|
||||||
|
|
||||||
|
result.update(instance.tojson(instance))
|
||||||
|
return result
|
||||||
|
|
||||||
|
for field_ in fields(cls):
|
||||||
|
value = getattr(instance, field_.name)
|
||||||
|
if "format" in field_.metadata:
|
||||||
|
result[field_.name] = field_.metadata["format"].tojson(value)
|
||||||
|
elif issubclass(field_.type, StructType):
|
||||||
|
if not isinstance(value, field_.type):
|
||||||
|
raise ValueError('expected value of type %r for %s.%s, got %r' %
|
||||||
|
(field_.type, cls.__name__, field_.name, value))
|
||||||
|
result[field_.name] = value.tojson(value)
|
||||||
|
else:
|
||||||
|
raise TypeError('field %s.%s has no format and is no StructType' %
|
||||||
|
(cls.__class__.__name__, field_.name))
|
||||||
|
return result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def fromjson(cls, data):
|
||||||
|
data = data.copy()
|
||||||
|
|
||||||
|
# todo: upgrade_json
|
||||||
|
|
||||||
|
kwargs = {}
|
||||||
|
for field_ in fields(cls):
|
||||||
|
if "format" in field_.metadata:
|
||||||
|
data = field_.metadata["format"].decode(data)
|
||||||
|
elif issubclass(field_.type, StructType):
|
||||||
|
data = field_.type.decode(data)
|
||||||
|
else:
|
||||||
|
raise TypeError('field %s.%s has no format and is no StructType' %
|
||||||
|
(cls.__name__, field_.name))
|
||||||
|
kwargs[field_.name], data = field_.metadata["format"].decode(data)
|
||||||
|
|
||||||
|
if cls.union_type_field:
|
||||||
|
try:
|
||||||
|
type_value = kwargs[cls.union_type_field]
|
||||||
|
except KeyError:
|
||||||
|
raise TypeError('union_type_field %s.%s is missing' %
|
||||||
|
(cls.__name__, cls.union_type_field))
|
||||||
|
try:
|
||||||
|
klass = cls._union_options[type_value]
|
||||||
|
except KeyError:
|
||||||
|
raise TypeError('union_type_field %s.%s value %r no known' %
|
||||||
|
(cls.__name__, cls.union_type_field, type_value))
|
||||||
|
return klass.fromjson(data)
|
||||||
|
|
||||||
|
return cls(**kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_c_parts(cls, ignore_fields=None, no_empty=False, typedef=False):
|
||||||
|
ignore_fields = set() if not ignore_fields else set(ignore_fields)
|
||||||
|
|
||||||
|
items = []
|
||||||
|
for field_ in fields(cls):
|
||||||
|
if field_.name in ignore_fields:
|
||||||
|
continue
|
||||||
|
if "format" in field_.metadata:
|
||||||
|
items.append((
|
||||||
|
field_.metadata["format"].get_c_code(field_.name),
|
||||||
|
field_.metadata.get("doc", None),
|
||||||
|
)),
|
||||||
|
elif issubclass(field_.type, StructType):
|
||||||
|
items.append((
|
||||||
|
field_.type.get_c_code(field_.name, typedef=False),
|
||||||
|
field_.metadata.get("doc", None),
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
raise TypeError('field %s.%s has no format and is no StructType' %
|
||||||
|
(cls.__name__, field_.name))
|
||||||
|
|
||||||
|
if cls.union_type_field:
|
||||||
|
parent_fields = set(field_.name for field_ in fields(cls))
|
||||||
|
union_items = []
|
||||||
|
for key, option in cls._union_options[cls.union_type_field].items():
|
||||||
|
name = normalize_name(getattr(key, 'name', option.__name__))
|
||||||
|
union_items.append(
|
||||||
|
option.get_c_code(name, ignore_fields=(ignore_fields | parent_fields))
|
||||||
|
)
|
||||||
|
items.append((
|
||||||
|
"union {\n"+indent_c("\n".join(union_items))+"\n}",
|
||||||
|
""))
|
||||||
|
|
||||||
|
if no_empty and not items:
|
||||||
|
return "", ""
|
||||||
|
|
||||||
|
# todo: struct comment
|
||||||
|
pre = ""
|
||||||
|
if typedef:
|
||||||
|
comment = cls.__doc__.strip()
|
||||||
|
if comment:
|
||||||
|
pre += "/** %s */\n" % comment
|
||||||
|
pre += "typedef struct __packed "
|
||||||
|
else:
|
||||||
|
pre += "struct "
|
||||||
|
|
||||||
|
pre += "{\n%(elements)s\n}" % {
|
||||||
|
"elements": indent_c(
|
||||||
|
"\n".join(
|
||||||
|
code + ("" if not comment else (" /** %s */" % comment))
|
||||||
|
for code, comment in items
|
||||||
|
)
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
return pre, ""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_c_code(cls, name=None, ignore_fields=None, no_empty=False, typedef=True) -> str:
|
||||||
|
pre, post = cls.get_c_parts(ignore_fields=ignore_fields, no_empty=no_empty, typedef=typedef)
|
||||||
|
if no_empty and not pre and not post:
|
||||||
|
return ""
|
||||||
|
return "%s %s%s;" % (pre, name, post)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_base_name(cls):
|
||||||
|
return cls.__name__
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_variable_name(cls):
|
||||||
|
return cls.get_base_name()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_struct_name(cls):
|
||||||
|
return "%s_t" % cls.get_base_name()
|
||||||
|
|
||||||
|
|
||||||
class MacAddressFormat:
|
class MacAddressFormat(FixedHexFormat):
|
||||||
def encode(self, value: str) -> bytes:
|
def __init__(self):
|
||||||
return bytes(int(value[i*3:i*3+2], 16) for i in range(6))
|
super().__init__(num=6, sep=':')
|
||||||
|
|
||||||
def decode(self, data: bytes):
|
|
||||||
return (MAC_FMT % tuple(data[:6])), data[6:]
|
|
||||||
|
|
||||||
def get_c_struct(self, name):
|
|
||||||
return "uint8_t %(name)s[6];" % {
|
|
||||||
"name": name,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class MacAddressesListFormat:
|
class MacAddressesListFormat(VarArrayFormat):
|
||||||
var_num = 6
|
def __init__(self):
|
||||||
|
super().__init__(child_type=MacAddressFormat())
|
||||||
def encode(self, value: list[str]) -> bytes:
|
|
||||||
return bytes((len(value), )) + sum(
|
|
||||||
(bytes((int(mac[i*3:i*3+2], 16) for i in range(6))) for mac in value),
|
|
||||||
b''
|
|
||||||
)
|
|
||||||
|
|
||||||
def decode(self, data: bytes):
|
|
||||||
return [MAC_FMT % tuple(data[1+6*i:1+6+6*i]) for i in range(data[0])], data[1+data[0]*6:]
|
|
||||||
|
|
||||||
def get_c_struct(self, name):
|
|
||||||
return "uint8_t num;\nuint8_t %(name)s[6][0];" % {
|
|
||||||
"name": name,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
""" stuff """
|
||||||
|
|
||||||
|
@unique
|
||||||
class LedType(IntEnum):
|
class LedType(IntEnum):
|
||||||
SERIAL = 1
|
SERIAL = 1
|
||||||
MULTIPIN = 2
|
MULTIPIN = 2
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LedConfig:
|
class LedConfig(StructType, union_type_field="led_type"):
|
||||||
led_type: LedType = field(init=False, repr=False)
|
led_type: LedType = field(init=False, repr=False, metadata={"format": SimpleFormat('B')})
|
||||||
ledconfig_types = {}
|
|
||||||
|
|
||||||
# noinspection PyMethodOverriding
|
|
||||||
def __init_subclass__(cls, /, led_type: LedType, **kwargs):
|
|
||||||
super().__init_subclass__(**kwargs)
|
|
||||||
cls.led_type = led_type
|
|
||||||
LedConfig.ledconfig_types[led_type] = cls
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def fromjson(cls, data):
|
|
||||||
if data is None:
|
|
||||||
return None
|
|
||||||
return LedConfig.ledconfig_types[data.pop('led_type')](**data)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SerialLedConfig(LedConfig, led_type=LedType.SERIAL):
|
class SerialLedConfig(LedConfig, StructType, led_type=LedType.SERIAL):
|
||||||
gpio: int
|
gpio: int = field(metadata={"format": SimpleFormat('B')})
|
||||||
rmt: int
|
rmt: int = field(metadata={"format": SimpleFormat('B')})
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MultipinLedConfig(LedConfig, led_type=LedType.MULTIPIN):
|
class MultipinLedConfig(LedConfig, StructType, led_type=LedType.MULTIPIN):
|
||||||
gpio_red: int
|
gpio_red: int = field(metadata={"format": SimpleFormat('B')})
|
||||||
gpio_green: int
|
gpio_green: int = field(metadata={"format": SimpleFormat('B')})
|
||||||
gpio_blue: int
|
gpio_blue: int = field(metadata={"format": SimpleFormat('B')})
|
||||||
|
|
||||||
|
|
||||||
class LedConfigFormat:
|
class RangeItemType(StructType):
|
||||||
def encode(self, value) -> bytes:
|
address: str = field(metadata={"format": MacAddressFormat()})
|
||||||
if value is None:
|
distance: int = field(metadata={"format": SimpleFormat('H')})
|
||||||
return struct.pack('BBBB', (0, 0, 0, 0))
|
|
||||||
if isinstance(value, SerialLedConfig):
|
|
||||||
return struct.pack('BBBB', (value.type_id, value.gpio, value.rmt, 0))
|
|
||||||
if isinstance(value, MultipinLedConfig):
|
|
||||||
return struct.pack('BBBB', (value.type_id, value.gpio_red, value.gpio_green, value.gpio_blue))
|
|
||||||
raise ValueError
|
|
||||||
|
|
||||||
def decode(self, data: bytes):
|
|
||||||
type_, *bytes_ = struct.unpack('BBBB', data)
|
|
||||||
if type_ == 0:
|
|
||||||
value = None
|
|
||||||
elif type_ == 1:
|
|
||||||
value = SerialLedConfig(gpio=bytes_[0], rmt=bytes_[1])
|
|
||||||
elif type_ == 2:
|
|
||||||
value = MultipinLedConfig(gpio_red=bytes_[0], gpio_green=bytes_[1], gpio_blue=bytes_[2])
|
|
||||||
else:
|
|
||||||
raise ValueError
|
|
||||||
return value, data[4:]
|
|
||||||
|
|
||||||
def get_c_struct(self, name):
|
|
||||||
return (
|
|
||||||
"uint8_t type;\n"
|
|
||||||
"union {\n"
|
|
||||||
" struct {\n"
|
|
||||||
" uint8_t gpio;\n"
|
|
||||||
" uint8_t rmt;\n"
|
|
||||||
" } serial;\n"
|
|
||||||
" struct {\n"
|
|
||||||
" uint8_t gpio_red;\n"
|
|
||||||
" uint8_t gpio_green;\n"
|
|
||||||
" uint8_t gpio_blue;\n"
|
|
||||||
" } multipin;\n"
|
|
||||||
" uint8_t bytes[3];\n"
|
|
||||||
"};"
|
|
||||||
)
|
|
||||||
|
|
|
@ -1,5 +1,8 @@
|
||||||
|
from dataclasses import fields
|
||||||
|
|
||||||
from django.core.management.base import BaseCommand
|
from django.core.management.base import BaseCommand
|
||||||
|
|
||||||
|
from c3nav.mesh.dataformats import normalize_name
|
||||||
from c3nav.mesh.messages import MeshMessage
|
from c3nav.mesh.messages import MeshMessage
|
||||||
from c3nav.mesh.utils import indent_c
|
from c3nav.mesh.utils import indent_c
|
||||||
|
|
||||||
|
@ -12,26 +15,26 @@ class Command(BaseCommand):
|
||||||
nodata = set()
|
nodata = set()
|
||||||
struct_lines = {}
|
struct_lines = {}
|
||||||
|
|
||||||
for msg_type in MeshMessage.msg_types.values():
|
ignore_names = set(field_.name for field_ in fields(MeshMessage))
|
||||||
|
from pprint import pprint
|
||||||
|
pprint(MeshMessage.get_msg_types())
|
||||||
|
for msg_id, msg_type in MeshMessage.get_msg_types().items():
|
||||||
if msg_type.c_struct_name:
|
if msg_type.c_struct_name:
|
||||||
if msg_type.c_struct_name in done_struct_names:
|
if msg_type.c_struct_name in done_struct_names:
|
||||||
continue
|
continue
|
||||||
done_struct_names.add(msg_type.c_struct_name)
|
done_struct_names.add(msg_type.c_struct_name)
|
||||||
msg_type = MeshMessage.c_structs[msg_type.c_struct_name]
|
msg_type = MeshMessage.c_structs[msg_type.c_struct_name]
|
||||||
|
|
||||||
code = msg_type.get_c_struct()
|
name = "mesh_msg_%s_t" % (
|
||||||
|
msg_type.c_struct_name or normalize_name(getattr(msg_id, 'name', msg_type.__name__))
|
||||||
|
)
|
||||||
|
code = msg_type.get_c_code(name, ignore_fields=ignore_names, no_empty=True)
|
||||||
if code:
|
if code:
|
||||||
struct_lines[msg_type.get_c_struct_name()] = (
|
|
||||||
"mesh_msg_%s_t %s;" % (
|
|
||||||
msg_type.get_c_struct_name(),
|
|
||||||
msg_type.get_c_struct_name().replace("_announce", ""),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
print(code)
|
print(code)
|
||||||
print()
|
print()
|
||||||
else:
|
else:
|
||||||
nodata.add(msg_type)
|
nodata.add(msg_type)
|
||||||
|
return
|
||||||
print("/** union between all message data structs */")
|
print("/** union between all message data structs */")
|
||||||
print("typedef union __packed {")
|
print("typedef union __packed {")
|
||||||
for line in struct_lines.values():
|
for line in struct_lines.values():
|
||||||
|
|
|
@ -8,8 +8,8 @@ import channels
|
||||||
from asgiref.sync import async_to_sync
|
from asgiref.sync import async_to_sync
|
||||||
|
|
||||||
from c3nav.mesh.utils import get_mesh_comm_group, indent_c
|
from c3nav.mesh.utils import get_mesh_comm_group, indent_c
|
||||||
from c3nav.mesh.dataformats import (BoolFormat, FixedStrFormat, HexFormat, LedConfig, LedConfigFormat,
|
from c3nav.mesh.dataformats import (BoolFormat, FixedStrFormat, FixedHexFormat, LedConfig, LedConfig,
|
||||||
MacAddressesListFormat, MacAddressFormat, SimpleFormat, VarStrFormat)
|
MacAddressesListFormat, MacAddressFormat, SimpleFormat, VarStrFormat, StructType)
|
||||||
|
|
||||||
MESH_ROOT_ADDRESS = '00:00:00:00:00:00'
|
MESH_ROOT_ADDRESS = '00:00:00:00:00:00'
|
||||||
MESH_PARENT_ADDRESS = '00:00:00:ff:ff:ff'
|
MESH_PARENT_ADDRESS = '00:00:00:ff:ff:ff'
|
||||||
|
@ -37,6 +37,8 @@ class MeshMessageType(IntEnum):
|
||||||
CONFIG_LED = 0x13
|
CONFIG_LED = 0x13
|
||||||
CONFIG_UPLINK = 0x14
|
CONFIG_UPLINK = 0x14
|
||||||
|
|
||||||
|
LOCATE_REPORT_RANGE = 0x20
|
||||||
|
|
||||||
|
|
||||||
M = TypeVar('M', bound='MeshMessage')
|
M = TypeVar('M', bound='MeshMessage')
|
||||||
|
|
||||||
|
@ -48,22 +50,16 @@ class ChipType(IntEnum):
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MeshMessage:
|
class MeshMessage(StructType, union_type_field="msg_id"):
|
||||||
dst: str = field(metadata={"format": MacAddressFormat()})
|
dst: str = field(metadata={"format": MacAddressFormat()})
|
||||||
src: str = field(metadata={"format": MacAddressFormat()})
|
src: str = field(metadata={"format": MacAddressFormat()})
|
||||||
msg_id: int = field(metadata={"format": SimpleFormat('B')}, init=False, repr=False)
|
msg_id: int = field(metadata={"format": SimpleFormat('B')}, init=False, repr=False)
|
||||||
msg_types = {}
|
|
||||||
c_structs = {}
|
c_structs = {}
|
||||||
c_struct_name = None
|
c_struct_name = None
|
||||||
|
|
||||||
# noinspection PyMethodOverriding
|
# noinspection PyMethodOverriding
|
||||||
def __init_subclass__(cls, /, msg_id=None, c_struct_name=None, **kwargs):
|
def __init_subclass__(cls, /, c_struct_name=None, **kwargs):
|
||||||
super().__init_subclass__(**kwargs)
|
super().__init_subclass__(**kwargs)
|
||||||
if msg_id is not None:
|
|
||||||
cls.msg_id = msg_id
|
|
||||||
if msg_id in MeshMessage.msg_types:
|
|
||||||
raise TypeError('duplicate use of msg_id %d' % msg_id)
|
|
||||||
MeshMessage.msg_types[msg_id] = cls
|
|
||||||
if c_struct_name:
|
if c_struct_name:
|
||||||
cls.c_struct_name = c_struct_name
|
cls.c_struct_name = c_struct_name
|
||||||
if c_struct_name in MeshMessage.c_structs:
|
if c_struct_name in MeshMessage.c_structs:
|
||||||
|
@ -117,40 +113,10 @@ class MeshMessage:
|
||||||
def get_additional_c_fields(self):
|
def get_additional_c_fields(self):
|
||||||
return ()
|
return ()
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_c_struct(cls):
|
|
||||||
ignore_fields = cls.get_ignore_c_fields()
|
|
||||||
if cls != MeshMessage:
|
|
||||||
ignore_fields |= set(field.name for field in fields(MeshMessage))
|
|
||||||
|
|
||||||
items = tuple(
|
|
||||||
(
|
|
||||||
tuple(field.metadata["format"].get_c_struct(field.metadata.get("c_name", field.name)).split("\n")),
|
|
||||||
field.metadata.get("doc", None),
|
|
||||||
)
|
|
||||||
for field in fields(cls)
|
|
||||||
if field.name not in ignore_fields
|
|
||||||
)
|
|
||||||
if not items:
|
|
||||||
return ""
|
|
||||||
max_line_len = max(len(line) for line in chain(*(code for code, doc in items)))
|
|
||||||
|
|
||||||
msg_comment = cls.__doc__.strip()
|
|
||||||
|
|
||||||
return "%(comment)stypedef struct __packed {\n%(elements)s\n} %(name)s;" % {
|
|
||||||
"comment": ("/** %s */\n" % msg_comment) if msg_comment else "",
|
|
||||||
"elements": indent_c(
|
|
||||||
"\n".join(chain(*(
|
|
||||||
(code if not comment
|
|
||||||
else (code[:-1]+("%s /** %s */" % (code[-1].ljust(max_line_len), comment),)))
|
|
||||||
for code, comment in items
|
|
||||||
), cls.get_additional_c_fields()))
|
|
||||||
),
|
|
||||||
"name": "mesh_msg_%s_t" % cls.get_c_struct_name(),
|
|
||||||
}
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_var_num(cls):
|
def get_var_num(cls):
|
||||||
|
return 0
|
||||||
|
# todo: fix
|
||||||
return sum((getattr(field.metadata["format"], "var_num", 0) for field in fields(cls)), start=0)
|
return sum((getattr(field.metadata["format"], "var_num", 0) for field in fields(cls)), start=0)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -172,6 +138,10 @@ class MeshMessage:
|
||||||
cls.__name__.removeprefix('Mesh').removesuffix('Message')
|
cls.__name__.removeprefix('Mesh').removesuffix('Message')
|
||||||
).upper().replace('CONFIG', 'CFG').replace('FIRMWARE', 'FW').replace('POSITION', 'POS')
|
).upper().replace('CONFIG', 'CFG').replace('FIRMWARE', 'FW').replace('POSITION', 'POS')
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_msg_types(cls):
|
||||||
|
return cls._union_options["msg_id"]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class NoopMessage(MeshMessage, msg_id=MeshMessageType.NOOP):
|
class NoopMessage(MeshMessage, msg_id=MeshMessageType.NOOP):
|
||||||
|
@ -291,7 +261,7 @@ class ConfigFirmwareMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_FIRMWARE)
|
||||||
compile_time: str = field(metadata={"format": FixedStrFormat(16)})
|
compile_time: str = field(metadata={"format": FixedStrFormat(16)})
|
||||||
compile_date: str = field(metadata={"format": FixedStrFormat(16)})
|
compile_date: str = field(metadata={"format": FixedStrFormat(16)})
|
||||||
idf_version: str = field(metadata={"format": FixedStrFormat(32)})
|
idf_version: str = field(metadata={"format": FixedStrFormat(32)})
|
||||||
app_elf_sha256: str = field(metadata={"format": HexFormat(32)})
|
app_elf_sha256: str = field(metadata={"format": FixedHexFormat(32)})
|
||||||
reserv2: list[int] = field(metadata={"format": SimpleFormat('20I')}, repr=False)
|
reserv2: list[int] = field(metadata={"format": SimpleFormat('20I')}, repr=False)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -327,7 +297,7 @@ class ConfigPositionMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_POSITION)
|
||||||
@dataclass
|
@dataclass
|
||||||
class ConfigLedMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_LED):
|
class ConfigLedMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_LED):
|
||||||
""" set/respond led config """
|
""" set/respond led config """
|
||||||
led_config: LedConfig = field(metadata={"format": LedConfigFormat()})
|
led_config: LedConfig = field(metadata={})
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -341,3 +311,10 @@ class ConfigUplinkMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_UPLINK):
|
||||||
ssl: bool = field(metadata={"format": BoolFormat()})
|
ssl: bool = field(metadata={"format": BoolFormat()})
|
||||||
host: str = field(metadata={"format": FixedStrFormat(64)})
|
host: str = field(metadata={"format": FixedStrFormat(64)})
|
||||||
port: int = field(metadata={"format": SimpleFormat('H')})
|
port: int = field(metadata={"format": SimpleFormat('H')})
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LocateReportRangeMessage(MeshMessage, msg_id=MeshMessageType.LOCATE_REPORT_RANGE):
|
||||||
|
""" report distance to given nodes """
|
||||||
|
#ranges: dict[str, int] =
|
||||||
|
pass
|
Loading…
Add table
Add a link
Reference in a new issue