more mesh_msg export stuffs
This commit is contained in:
parent
7a13193acd
commit
16f47168a2
3 changed files with 85 additions and 42 deletions
|
@ -3,7 +3,6 @@ import struct
|
|||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field, fields
|
||||
from enum import IntEnum, unique
|
||||
from itertools import chain
|
||||
from typing import Self, Sequence, Any
|
||||
|
||||
from c3nav.mesh.utils import indent_c
|
||||
|
@ -315,21 +314,25 @@ class StructType:
|
|||
return cls(**kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_c_parts(cls, ignore_fields=None, no_empty=False, typedef=False):
|
||||
def get_c_parts(cls, ignore_fields=None, no_empty=False, typedef=False, union_only=False,
|
||||
union_member_as_types=False):
|
||||
ignore_fields = set() if not ignore_fields else set(ignore_fields)
|
||||
|
||||
pre = ""
|
||||
|
||||
items = []
|
||||
for field_ in fields(cls):
|
||||
if field_.name in ignore_fields:
|
||||
continue
|
||||
name = field_.metadata.get("c_name", field_.name)
|
||||
if "format" in field_.metadata:
|
||||
items.append((
|
||||
field_.metadata["format"].get_c_code(field_.name),
|
||||
field_.metadata["format"].get_c_code(name),
|
||||
field_.metadata.get("doc", None),
|
||||
)),
|
||||
elif issubclass(field_.type, StructType):
|
||||
items.append((
|
||||
field_.type.get_c_code(field_.name, typedef=False),
|
||||
field_.type.get_c_code(name, typedef=False),
|
||||
field_.metadata.get("doc", None),
|
||||
))
|
||||
else:
|
||||
|
@ -340,19 +343,39 @@ class StructType:
|
|||
parent_fields = set(field_.name for field_ in fields(cls))
|
||||
union_items = []
|
||||
for key, option in cls._union_options[cls.union_type_field].items():
|
||||
name = normalize_name(getattr(key, 'name', option.__name__))
|
||||
union_items.append(
|
||||
option.get_c_code(name, ignore_fields=(ignore_fields | parent_fields))
|
||||
base_name = normalize_name(getattr(key, 'name', option.__name__))
|
||||
if union_member_as_types:
|
||||
struct_name = cls.get_struct_name(base_name)
|
||||
pre += option.get_c_code(
|
||||
struct_name,
|
||||
ignore_fields=(ignore_fields | parent_fields),
|
||||
typedef=True
|
||||
)+"\n\n"
|
||||
union_items.append(
|
||||
"%s %s;" % (struct_name, cls.get_variable_name(base_name)),
|
||||
)
|
||||
else:
|
||||
union_items.append(
|
||||
option.get_c_code(base_name, ignore_fields=(ignore_fields | parent_fields))
|
||||
)
|
||||
union_items.append(
|
||||
"uint8_t bytes[%s];" % max(
|
||||
(option.get_min_size() for option in cls._union_options[cls.union_type_field].values()),
|
||||
default=0,
|
||||
)
|
||||
items.append((
|
||||
"union {\n"+indent_c("\n".join(union_items))+"\n}",
|
||||
""))
|
||||
)
|
||||
union_code = "{\n"+indent_c("\n".join(union_items))+"\n}",
|
||||
if union_only:
|
||||
return "typedef union __packed %s" % union_code, "";
|
||||
else:
|
||||
items.append(("union %s;" % union_code, ""))
|
||||
elif union_only:
|
||||
return "", ""
|
||||
|
||||
if no_empty and not items:
|
||||
return "", ""
|
||||
|
||||
# todo: struct comment
|
||||
pre = ""
|
||||
if typedef:
|
||||
comment = cls.__doc__.strip()
|
||||
if comment:
|
||||
|
@ -372,23 +395,31 @@ class StructType:
|
|||
return pre, ""
|
||||
|
||||
@classmethod
|
||||
def get_c_code(cls, name=None, ignore_fields=None, no_empty=False, typedef=True) -> str:
|
||||
pre, post = cls.get_c_parts(ignore_fields=ignore_fields, no_empty=no_empty, typedef=typedef)
|
||||
def get_c_code(cls, name=None, ignore_fields=None, no_empty=False, typedef=True, union_only=False,
|
||||
union_member_as_types=False) -> str:
|
||||
pre, post = cls.get_c_parts(ignore_fields=ignore_fields, no_empty=no_empty, typedef=typedef,
|
||||
union_only=union_only, union_member_as_types=union_member_as_types,
|
||||
)
|
||||
if no_empty and not pre and not post:
|
||||
return ""
|
||||
return "%s %s%s;" % (pre, name, post)
|
||||
|
||||
@classmethod
|
||||
def get_base_name(cls):
|
||||
return cls.__name__
|
||||
def get_variable_name(cls, base_name):
|
||||
return base_name
|
||||
|
||||
@classmethod
|
||||
def get_variable_name(cls):
|
||||
return cls.get_base_name()
|
||||
def get_struct_name(cls, base_name):
|
||||
return "%s_t" % base_name
|
||||
|
||||
@classmethod
|
||||
def get_struct_name(cls):
|
||||
return "%s_t" % cls.get_base_name()
|
||||
def get_min_size(cls) -> int:
|
||||
if cls.union_type_field:
|
||||
return (
|
||||
{f.name: field for f in fields()}[cls.union_type_field].metadata["format"].get_min_size() +
|
||||
sum((option.get_min_size() for option in cls._union_options[cls.union_type_field].values()), start=0)
|
||||
)
|
||||
return sum((f.metadata.get("format", f.type).get_min_size() for f in fields(cls)), start=0)
|
||||
|
||||
|
||||
class MacAddressFormat(FixedHexFormat):
|
||||
|
@ -426,6 +457,7 @@ class MultipinLedConfig(LedConfig, StructType, led_type=LedType.MULTIPIN):
|
|||
gpio_blue: int = field(metadata={"format": SimpleFormat('B')})
|
||||
|
||||
|
||||
@dataclass
|
||||
class RangeItemType(StructType):
|
||||
address: str = field(metadata={"format": MacAddressFormat()})
|
||||
distance: int = field(metadata={"format": SimpleFormat('H')})
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue