improve c type generation

This commit is contained in:
Gwendolyn 2023-10-06 02:10:17 +02:00
parent 59ebdd74bb
commit 6ed57d99d2
4 changed files with 518 additions and 31 deletions

View file

@ -0,0 +1,477 @@
import re
import struct
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields, field, Field
from typing import Any, Sequence, Self
from c3nav.mesh.utils import indent_c
class BaseFormat(ABC):
def get_var_num(self):
return 0
@abstractmethod
def encode(self, value):
pass
@classmethod
@abstractmethod
def decode(cls, data) -> tuple[Any, bytes]:
pass
def fromjson(self, data):
return data
def tojson(self, data):
return data
@abstractmethod
def get_min_size(self):
pass
@abstractmethod
def get_c_parts(self) -> tuple[str, str]:
pass
def get_c_code(self, name) -> str:
pre, post = self.get_c_parts()
return "%s %s%s;" % (pre, name, post)
class SimpleFormat(BaseFormat):
def __init__(self, fmt):
self.fmt = fmt
self.size = struct.calcsize(fmt)
self.c_type = self.c_types[self.fmt[-1]]
self.num = int(self.fmt[:-1]) if len(self.fmt) > 1 else 1
def encode(self, value):
return struct.pack(self.fmt, (value,) if self.num == 1 else tuple(value))
def decode(self, data: bytes) -> tuple[Any, bytes]:
value = struct.unpack(self.fmt, data[:self.size])
if len(value) == 1:
value = value[0]
return value, data[self.size:]
def get_min_size(self):
return self.size
c_types = {
"B": "uint8_t",
"H": "uint16_t",
"I": "uint32_t",
"b": "int8_t",
"h": "int16_t",
"i": "int32_t",
"s": "char",
}
def get_c_parts(self):
return self.c_type, ("" if self.num == 1 else ("[%d]" % self.num))
class BoolFormat(SimpleFormat):
def __init__(self):
super().__init__('B')
def encode(self, value):
return super().encode(int(value))
def decode(self, data: bytes) -> tuple[bool, bytes]:
value, data = super().decode(data)
return bool(value), data
class FixedStrFormat(SimpleFormat):
def __init__(self, num):
self.num = num
super().__init__('%ds' % self.num)
def encode(self, value: str):
return value.encode()[:self.num].ljust(self.num, bytes((0,))),
def decode(self, data: bytes) -> tuple[str, bytes]:
return data[:self.num].rstrip(bytes((0,))).decode(), data[self.num:]
class FixedHexFormat(SimpleFormat):
def __init__(self, num, sep=''):
self.num = num
self.sep = sep
super().__init__('%dB' % self.num)
def encode(self, value: str):
return super().encode(tuple(bytes.fromhex(value)))
def decode(self, data: bytes) -> tuple[str, bytes]:
return self.sep.join(('%02x' % i) for i in data[:self.num]), data[self.num:]
@abstractmethod
class BaseVarFormat(BaseFormat, ABC):
def __init__(self, num_fmt='B'):
self.num_fmt = num_fmt
self.num_size = struct.calcsize(self.num_fmt)
def get_min_size(self):
return self.num_size
def get_num_c_code(self):
return SimpleFormat(self.num_fmt).get_c_code("num")
class VarArrayFormat(BaseVarFormat):
def __init__(self, child_type, num_fmt='B'):
super().__init__(num_fmt=num_fmt)
self.child_type = child_type
self.child_size = self.child_type.get_min_size()
def get_var_num(self):
return self.child_size
pass
def encode(self, values: Sequence) -> bytes:
data = struct.pack(self.num_fmt, (len(values),))
for value in values:
data += self.child_type.encode(value)
return data
def decode(self, data: bytes) -> tuple[list[Any], bytes]:
num = struct.unpack(self.num_fmt, data[:self.num_size])[0]
return [
self.child_type.decode(data[i:i + self.child_size])
for i in range(self.num_size, self.num_size + num * self.child_size, self.child_size)
], data[self.num_size + num * self.child_size:]
def get_c_parts(self):
pre, post = self.child_type.get_c_parts()
return super().get_num_c_code() + "\n" + pre, "[0]" + post
class VarStrFormat(BaseVarFormat):
def get_var_num(self):
return 1
def encode(self, value: str) -> bytes:
return struct.pack(self.num_fmt, (len(str),)) + value.encode()
def decode(self, data: bytes) -> tuple[str, bytes]:
num = struct.unpack(self.num_fmt, data[:self.num_size])[0]
return data[self.num_size:self.num_size + num].rstrip(bytes((0,))).decode(), data[self.num_size + num:]
def get_c_parts(self):
return super().get_num_c_code() + "\n" + "char", "[0]"
@dataclass
class StructType:
_union_options = {}
union_type_field = None
# noinspection PyMethodOverriding
def __init_subclass__(cls, /, union_type_field=None, no_c_type=False, **kwargs):
cls.union_type_field = union_type_field
if union_type_field:
if union_type_field in cls._union_options:
raise TypeError('Duplicate union_type_field: %s', union_type_field)
cls._union_options[union_type_field] = {}
f = getattr(cls, union_type_field)
metadata = dict(f.metadata)
metadata['union_discriminator'] = True
f.metadata = metadata
f.repr = False
f.init = False
for attr_name in cls.__dict__.keys():
attr = getattr(cls, attr_name)
if isinstance(attr, Field):
metadata = dict(attr.metadata)
if "defining_class" not in metadata:
metadata["defining_class"] = cls
attr.metadata = metadata
for key, values in cls._union_options.items():
value = kwargs.pop(key, None)
if value is not None:
if value in values:
raise TypeError('Duplicate %s: %s', (key, value))
values[value] = cls
setattr(cls, key, value)
super().__init_subclass__(**kwargs)
@classmethod
def get_var_num(cls):
return sum([f.metadata.get("format", f.type).get_var_num() for f in fields(cls)], start=0)
@classmethod
def encode(cls, instance) -> bytes:
data = bytes()
if cls.union_type_field and type(instance) is not cls:
if not isinstance(instance, cls):
raise ValueError('expected value of type %r, got %r' % (cls, instance))
for field_ in fields(instance):
if field_.name is cls.union_type_field:
data += field_.metadata["format"].encode(getattr(instance, field_.name))
break
else:
raise TypeError('couldn\'t find %s value' % cls.union_type_field)
data += instance.encode(instance)
return data
for field_ in fields(cls):
value = getattr(instance, field_.name)
if "format" in field_.metadata:
data += field_.metadata["format"].encode(value)
elif issubclass(field_.type, StructType):
if not isinstance(value, field_.type):
raise ValueError('expected value of type %r for %s.%s, got %r' %
(field_.type, cls.__name__, field_.name, value))
data += value.encode(value)
else:
raise TypeError('field %s.%s has no format and is no StructType' %
(cls.__class__.__name__, field_.name))
return data
@classmethod
def decode(cls, data: bytes) -> Self:
values = {}
for field_ in fields(cls):
if "format" in field_.metadata:
data = field_.metadata["format"].decode(data)
elif issubclass(field_.type, StructType):
data = field_.type.decode(data)
else:
raise TypeError('field %s.%s has no format and is no StructType' %
(cls.__name__, field_.name))
values[field_.name] = field_.metadata["format"].decode(data)
if cls.union_type_field:
try:
type_value = values[cls.union_type_field]
except KeyError:
raise TypeError('union_type_field %s.%s is missing' %
(cls.__name__, cls.union_type_field))
try:
klass = cls._union_options[type_value]
except KeyError:
raise TypeError('union_type_field %s.%s value %r no known' %
(cls.__name__, cls.union_type_field, type_value))
return klass.decode(data)
return cls(**values)
@classmethod
def tojson(cls, instance) -> dict:
result = {}
if cls.union_type_field and type(instance) is not cls:
if not isinstance(instance, cls):
raise ValueError('expected value of type %r, got %r' % (cls, instance))
for field_ in fields(instance):
if field_.name is cls.union_type_field:
result[field_.name] = field_.metadata["format"].encode(getattr(instance, field_.name))
break
else:
raise TypeError('couldn\'t find %s value' % cls.union_type_field)
result.update(instance.tojson(instance))
return result
for field_ in fields(cls):
value = getattr(instance, field_.name)
if "format" in field_.metadata:
result[field_.name] = field_.metadata["format"].tojson(value)
elif issubclass(field_.type, StructType):
if not isinstance(value, field_.type):
raise ValueError('expected value of type %r for %s.%s, got %r' %
(field_.type, cls.__name__, field_.name, value))
result[field_.name] = value.tojson(value)
else:
raise TypeError('field %s.%s has no format and is no StructType' %
(cls.__class__.__name__, field_.name))
return result
@classmethod
def fromjson(cls, data):
data = data.copy()
# todo: upgrade_json
kwargs = {}
for field_ in fields(cls):
if "format" in field_.metadata:
data = field_.metadata["format"].decode(data)
elif issubclass(field_.type, StructType):
data = field_.type.decode(data)
else:
raise TypeError('field %s.%s has no format and is no StructType' %
(cls.__name__, field_.name))
kwargs[field_.name], data = field_.metadata["format"].decode(data)
if cls.union_type_field:
try:
type_value = kwargs[cls.union_type_field]
except KeyError:
raise TypeError('union_type_field %s.%s is missing' %
(cls.__name__, cls.union_type_field))
try:
klass = cls._union_options[type_value]
except KeyError:
raise TypeError('union_type_field %s.%s value %r no known' %
(cls.__name__, cls.union_type_field, type_value))
return klass.fromjson(data)
return cls(**kwargs)
@classmethod
def get_c_struct_items(cls, ignore_fields=None, no_empty=False, top_level=False, union_only=False, in_union=False):
ignore_fields = set() if not ignore_fields else set(ignore_fields)
items = []
for field_ in fields(cls):
if field_.name in ignore_fields:
continue
if in_union and field_.metadata["defining_class"] != cls:
continue
name = field_.metadata.get("c_name", field_.name)
if "format" in field_.metadata:
if not field_.metadata.get("union_discriminator") or field_.metadata.get("defining_class") == cls:
items.append((
field_.metadata["format"].get_c_code(name),
field_.metadata.get("doc", None),
)),
elif issubclass(field_.type, StructType):
if field_.metadata.get("c_embed"):
embedded_items = field_.type.get_c_struct_items(ignore_fields, no_empty, top_level, union_only)
items.extend(embedded_items)
else:
items.append((
field_.type.get_c_code(name, typedef=False),
field_.metadata.get("doc", None),
))
else:
raise TypeError('field %s.%s has no format and is no StructType' %
(cls.__name__, field_.name))
if cls.union_type_field:
if not union_only:
union_code = cls.get_c_union_code(ignore_fields)
items.append(("union __packed %s;" % union_code, ""))
return items
@classmethod
def get_c_union_size(cls):
return max(
(option.get_min_size(no_inherited_fields=True) for option in
cls._union_options[cls.union_type_field].values()),
default=0,
)
@classmethod
def get_c_union_code(cls, ignore_fields=None):
union_items = []
for key, option in cls._union_options[cls.union_type_field].items():
base_name = normalize_name(getattr(key, 'name', option.__name__))
union_items.append(
option.get_c_code(base_name, ignore_fields=ignore_fields, typedef=False, in_union=True)
)
size = cls.get_c_union_size()
union_items.append(
"uint8_t bytes[%0d]; " % size
)
return "{\n" + indent_c("\n".join(union_items)) + "\n}"
@classmethod
def get_c_parts(cls, ignore_fields=None, no_empty=False, top_level=False, union_only=False, in_union=False):
ignore_fields = set() if not ignore_fields else set(ignore_fields)
if union_only:
if cls.union_type_field:
union_code = cls.get_c_union_code(ignore_fields)
return "typedef union __packed %s" % union_code, ""
else:
return "", ""
pre = ""
items = cls.get_c_struct_items(ignore_fields=ignore_fields,
no_empty=no_empty,
top_level=top_level,
union_only=union_only,
in_union=in_union)
if no_empty and not items:
return "", ""
# todo: struct comment
if top_level:
comment = cls.__doc__.strip()
if comment:
pre += "/** %s */\n" % comment
pre += "typedef struct __packed "
else:
pre += "struct __packed "
pre += "{\n%(elements)s\n}" % {
"elements": indent_c(
"\n".join(
code + ("" if not comment else (" /** %s */" % comment))
for code, comment in items
)
),
}
return pre, ""
@classmethod
def get_c_code(cls, name=None, ignore_fields=None, no_empty=False, typedef=True, union_only=False,
in_union=False) -> str:
pre, post = cls.get_c_parts(ignore_fields=ignore_fields,
no_empty=no_empty,
top_level=typedef,
union_only=union_only,
in_union=in_union)
if no_empty and not pre and not post:
return ""
return "%s %s%s;" % (pre, name, post)
@classmethod
def get_variable_name(cls, base_name):
return base_name
@classmethod
def get_struct_name(cls, base_name):
return "%s_t" % base_name
@classmethod
def get_min_size(cls, no_inherited_fields=False) -> int:
if cls.union_type_field:
own_size = sum([f.metadata.get("format", f.type).get_min_size() for f in fields(cls)])
union_size = max(
[0] + [option.get_min_size(True) for option in cls._union_options[cls.union_type_field].values()])
return own_size + union_size
if no_inherited_fields:
relevant_fields = [f for f in fields(cls) if f.metadata["defining_class"] == cls]
else:
relevant_fields = [f for f in fields(cls) if not f.metadata.get("union_discriminator")]
return sum((f.metadata.get("format", f.type).get_min_size() for f in relevant_fields), start=0)
def normalize_name(name):
if '_' in name:
return name.lower()
return re.sub(
r"([a-z])([A-Z])",
r"\1_\2",
name
).lower()

View file

@ -462,26 +462,27 @@ class MacAddressesListFormat(VarArrayFormat):
super().__init__(child_type=MacAddressFormat())
""" stuff """
@unique
class LedType(IntEnum):
SERIAL = 1
MULTIPIN = 2
@dataclass
class LedConfig(StructType, union_type_field="led_type"):
led_type: LedType = field(init=False, repr=False, metadata={"format": SimpleFormat('B')})
led_type: LedType = field(metadata={"format": SimpleFormat('B')})
leds_are_cool: int = field(metadata={"format": SimpleFormat('B')})
@dataclass
class SerialLedConfig(LedConfig, StructType, led_type=LedType.SERIAL):
class SerialLedConfig(LedConfig, led_type=LedType.SERIAL):
gpio: int = field(metadata={"format": SimpleFormat('B')})
rmt: int = field(metadata={"format": SimpleFormat('B')})
@dataclass
class MultipinLedConfig(LedConfig, StructType, led_type=LedType.MULTIPIN):
class MultipinLedConfig(LedConfig, led_type=LedType.MULTIPIN):
gpio_red: int = field(metadata={"format": SimpleFormat('B')})
gpio_green: int = field(metadata={"format": SimpleFormat('B')})
gpio_blue: int = field(metadata={"format": SimpleFormat('B')})
@ -491,3 +492,17 @@ class MultipinLedConfig(LedConfig, StructType, led_type=LedType.MULTIPIN):
class RangeItemType(StructType):
address: str = field(metadata={"format": MacAddressFormat()})
distance: int = field(metadata={"format": SimpleFormat('H')})
@dataclass
class FirmwareAppDescription(StructType, no_c_type=True):
magic_word: int = field(metadata={"format": SimpleFormat('I')}, repr=False)
secure_version: int = field(metadata={"format": SimpleFormat('I')})
reserv1: list[int] = field(metadata={"format": SimpleFormat('2I')}, repr=False)
version: str = field(metadata={"format": FixedStrFormat(32)})
project_name: str = field(metadata={"format": FixedStrFormat(32)})
compile_time: str = field(metadata={"format": FixedStrFormat(16)})
compile_date: str = field(metadata={"format": FixedStrFormat(16)})
idf_version: str = field(metadata={"format": FixedStrFormat(32)})
app_elf_sha256: str = field(metadata={"format": FixedHexFormat(32)})
reserv2: list[int] = field(metadata={"format": SimpleFormat('20I')}, repr=False)

View file

@ -2,7 +2,8 @@ from dataclasses import fields
from django.core.management.base import BaseCommand
from c3nav.mesh.dataformats import normalize_name, LedConfig
from c3nav.mesh.dataformats import LedConfig
from c3nav.mesh.baseformats import normalize_name
from c3nav.mesh.messages import MeshMessage, MeshMessageType
from c3nav.mesh.utils import indent_c
@ -21,6 +22,7 @@ class Command(BaseCommand):
done_struct_names = set()
nodata = set()
struct_lines = {}
struct_sizes = []
ignore_names = set(field_.name for field_ in fields(MeshMessage))
for msg_id, msg_type in MeshMessage.get_types().items():
@ -35,13 +37,13 @@ class Command(BaseCommand):
)))
name = "mesh_msg_%s_t" % base_name
if msg_id == MeshMessageType.CONFIG_LED:
msg_type = LedConfig
code = msg_type.get_c_code(name, ignore_fields=ignore_names, no_empty=True)
if code:
size = msg_type.get_min_size(no_inherited_fields=True)
struct_lines[base_name] = "%s %s;" % (name, base_name.replace('_announce', ''))
struct_sizes.append(size)
print(code)
print("static_assert(sizeof(%s) == %d, \"size of generated message structs is calculated wrong\");" % (name, size))
print()
else:
nodata.add(msg_type)
@ -51,7 +53,7 @@ class Command(BaseCommand):
for line in struct_lines.values():
print(indent_c(line))
print("} mesh_msg_data_t; ")
print()
print("static_assert(sizeof(mesh_msg_data_t) == %d, \"size of generated message structs is calculated wrong\");" % max(struct_sizes))
max_msg_type = max(MeshMessage.get_types().keys())
macro_data = []

View file

@ -8,9 +8,10 @@ import channels
from asgiref.sync import async_to_sync
from c3nav.mesh.utils import get_mesh_comm_group, indent_c
from c3nav.mesh.dataformats import (BoolFormat, FixedStrFormat, FixedHexFormat, LedConfig, LedConfig,
MacAddressesListFormat, MacAddressFormat, SimpleFormat, VarStrFormat, StructType,
VarArrayFormat, RangeItemType)
from c3nav.mesh.dataformats import (LedConfig, LedConfig,
MacAddressesListFormat, MacAddressFormat, RangeItemType, FirmwareAppDescription)
from c3nav.mesh.baseformats import SimpleFormat, BoolFormat, FixedStrFormat, FixedHexFormat, VarArrayFormat, \
VarStrFormat, StructType
MESH_ROOT_ADDRESS = '00:00:00:00:00:00'
MESH_PARENT_ADDRESS = '00:00:00:ff:ff:ff'
@ -31,6 +32,7 @@ class MeshMessageType(IntEnum):
MESH_ROUTE_REQUEST = 0x07
MESH_ROUTE_RESPONSE = 0x08
MESH_ROUTE_TRACE = 0x09
MESH_ROUTING_FAILED = 0x0a
CONFIG_DUMP = 0x10
CONFIG_FIRMWARE = 0x11
@ -83,12 +85,6 @@ class MeshMessage(StructType, union_type_field="msg_id"):
def get_additional_c_fields(self):
return ()
@classmethod
def get_var_num(cls):
return 0
# todo: fix
return sum((getattr(field.metadata["format"], "var_num", 0) for field in fields(cls)), start=0)
@classmethod
def get_variable_name(cls, base_name):
return cls.c_struct_name or base_name
@ -201,6 +197,12 @@ class MeshRouteTraceMessage(MeshMessage, msg_id=MeshMessageType.MESH_ROUTE_TRACE
})
@dataclass
class MeshRoutingFailedMessage(MeshMessage, msg_id=MeshMessageType.MESH_ROUTING_FAILED):
""" TODO description"""
address: str = field(metadata={"format": MacAddressFormat()})
@dataclass
class ConfigDumpMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_DUMP):
""" request for the node to dump its config """
@ -216,16 +218,7 @@ class ConfigFirmwareMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_FIRMWARE)
})
revision_major: int = field(metadata={"format": SimpleFormat('B')})
revision_minor: int = field(metadata={"format": SimpleFormat('B')})
magic_word: int = field(metadata={"format": SimpleFormat('I')}, repr=False)
secure_version: int = field(metadata={"format": SimpleFormat('I')})
reserv1: list[int] = field(metadata={"format": SimpleFormat('2I')}, repr=False)
version: str = field(metadata={"format": FixedStrFormat(32)})
project_name: str = field(metadata={"format": FixedStrFormat(32)})
compile_time: str = field(metadata={"format": FixedStrFormat(16)})
compile_date: str = field(metadata={"format": FixedStrFormat(16)})
idf_version: str = field(metadata={"format": FixedStrFormat(32)})
app_elf_sha256: str = field(metadata={"format": FixedHexFormat(32)})
reserv2: list[int] = field(metadata={"format": SimpleFormat('20I')}, repr=False)
app_desc: FirmwareAppDescription = field(metadata={'json_embed': True})
@classmethod
def upgrade_json(cls, data):
@ -259,7 +252,7 @@ class ConfigPositionMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_POSITION)
@dataclass
class ConfigLedMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_LED):
""" set/respond led config """
led_config: LedConfig = field(metadata={})
led_config: LedConfig = field(metadata={"c_embed": True})
@dataclass