import re
import struct
from abc import ABC, abstractmethod
from dataclasses import dataclass, field, fields
from enum import IntEnum, unique
from itertools import chain
from typing import Self, Sequence, Any

from c3nav.mesh.utils import indent_c

MAC_FMT = '%02x:%02x:%02x:%02x:%02x:%02x'


class BaseFormat(ABC):
    """
    Base class for the binary format of a single struct member.

    A format encodes a python value to bytes, decodes it back (returning the value together with the
    remaining, unconsumed bytes), converts it to/from a JSON-compatible value and renders the matching
    C member declaration.
    """

    @abstractmethod
    def encode(self, value) -> bytes:
        """Encode ``value`` to bytes."""
        pass

    @abstractmethod
    def decode(self, data: bytes) -> tuple[Any, bytes]:
        """Decode a value from the start of ``data`` and return ``(value, remaining_data)``."""
        pass

    def fromjson(self, data):
        """Convert a JSON-compatible value to the python value (identity by default)."""
        return data

    def tojson(self, data):
        """Convert the python value to a JSON-compatible value (identity by default)."""
        return data

    @abstractmethod
    def get_min_size(self) -> int:
        """Return the minimum number of bytes an encoded value occupies."""
        pass

    @abstractmethod
    def get_c_parts(self) -> tuple[str, str]:
        """Return the parts of the C declaration that go before and after the member name."""
        pass

    def get_c_code(self, name) -> str:
        """Return a complete C member declaration (including ``;``) for a member called ``name``."""
        pre, post = self.get_c_parts()
        return "%s %s%s;" % (pre, name, post)


class SimpleFormat(BaseFormat):
    """
    Format described by a single :mod:`struct` format string such as ``'B'``, ``'6B'`` or ``'10s'``.

    Single items are encoded/decoded as scalars, repeated numeric codes (``'6B'``) as sequences.
    """

    # struct format character -> C type
    c_types = {
        "B": "uint8_t",
        "H": "uint16_t",
        "I": "uint32_t",
        "b": "int8_t",
        "h": "int16_t",
        "i": "int32_t",
        "s": "char",
    }

    def __init__(self, fmt):
        self.fmt = fmt
        self.size = struct.calcsize(fmt)
        self.c_type = self.c_types[self.fmt[-1]]
        # repeat count, used as the C array length ('6B' -> uint8_t[6], '10s' -> char[10])
        self.num = int(self.fmt[:-1]) if len(self.fmt) > 1 else 1

    def encode(self, value):
        # struct.pack() wants every item as its own positional argument. A repeated 's' code is still a
        # single bytes item; only repeated numeric codes take a sequence of values.
        if self.num == 1 or self.fmt[-1] == "s":
            return struct.pack(self.fmt, value)
        return struct.pack(self.fmt, *value)

    def decode(self, data: bytes) -> tuple[Any, bytes]:
        value = struct.unpack(self.fmt, data[:self.size])
        if len(value) == 1:
            value = value[0]
        return value, data[self.size:]

    def get_min_size(self):
        return self.size

    def get_c_parts(self):
        return self.c_type, ("" if self.num == 1 else ("[%d]" % self.num))


class BoolFormat(SimpleFormat):
    """Boolean stored as a single ``uint8_t``."""

    def __init__(self):
        super().__init__('B')

    def encode(self, value):
        return super().encode(int(value))

    def decode(self, data: bytes) -> tuple[bool, bytes]:
        value, data = super().decode(data)
        return bool(value), data


class FixedStrFormat(SimpleFormat):
    """Fixed-length, NUL-padded UTF-8 string (C ``char[num]``)."""

    def __init__(self, num):
        self.num = num
        super().__init__('%ds' % self.num)

    def encode(self, value: str):
        encoded = value.encode()[:self.num]
        # truncation may have cut a multibyte character in half; drop the incomplete tail so decode() works
        encoded = encoded.decode(errors="ignore").encode()
        return encoded.ljust(self.num, bytes((0, )))

    def decode(self, data: bytes) -> tuple[str, bytes]:
        return data[:self.num].rstrip(bytes((0,))).decode(), data[self.num:]


class FixedHexFormat(SimpleFormat):
    """Fixed number of bytes, represented in python as a hex string with an optional separator."""

    def __init__(self, num, sep=''):
        self.num = num
        self.sep = sep
        super().__init__('%dB' % self.num)

    def encode(self, value: str):
        # bytes.fromhex() only skips whitespace, so a custom separator (like ':' in MACs) is removed first
        raw = bytes.fromhex(value.replace(self.sep, '') if self.sep else value)
        # struct.pack() raises struct.error if the value doesn't contain exactly num bytes
        return struct.pack(self.fmt, *raw)

    def decode(self, data: bytes) -> tuple[str, bytes]:
        return self.sep.join(('%02x' % i) for i in data[:self.num]), data[self.num:]


class BaseVarFormat(BaseFormat, ABC):
    """Base class for variable-length formats that are prefixed with their element count."""

    def __init__(self, num_fmt='B'):
        self.num_fmt = num_fmt
        self.num_size = struct.calcsize(self.num_fmt)

    def get_min_size(self):
        return self.num_size

    def get_num_c_code(self):
        return SimpleFormat(self.num_fmt).get_c_code("num")


class VarArrayFormat(BaseVarFormat):
    """Count-prefixed array of elements that all use ``child_type`` (a :class:`BaseFormat`)."""

    def __init__(self, child_type, num_fmt='B'):
        super().__init__(num_fmt=num_fmt)
        self.child_type = child_type
        self.child_size = self.child_type.get_min_size()

    def encode(self, values: Sequence) -> bytes:
        return struct.pack(self.num_fmt, len(values)) + b"".join(
            self.child_type.encode(value) for value in values
        )

    def decode(self, data: bytes) -> tuple[list[Any], bytes]:
        num = struct.unpack(self.num_fmt, data[:self.num_size])[0]
        data = data[self.num_size:]
        result = []
        # decode sequentially: each child decode returns the bytes it didn't consume
        for _ in range(num):
            value, data = self.child_type.decode(data)
            result.append(value)
        return result, data

    def fromjson(self, data):
        return [self.child_type.fromjson(item) for item in data]

    def tojson(self, data):
        return [self.child_type.tojson(item) for item in data]

    def get_c_parts(self):
        pre, post = self.child_type.get_c_parts()
        return super().get_num_c_code()+"\n"+pre, "[0]"+post


class VarStrFormat(BaseVarFormat):
    """Length-prefixed UTF-8 string; the prefix counts encoded bytes."""

    def encode(self, value: str) -> bytes:
        encoded = value.encode()
        return struct.pack(self.num_fmt, len(encoded)) + encoded

    def decode(self, data: bytes) -> tuple[str, bytes]:
        num = struct.unpack(self.num_fmt, data[:self.num_size])[0]
        return data[self.num_size:self.num_size+num].rstrip(bytes((0,))).decode(), data[self.num_size+num:]

    def get_c_parts(self):
        return super().get_num_c_code()+"\n"+"char", "[0]"


# struct types


def normalize_name(name):
    """
    Convert a name to lower snake_case, e.g. ``"SerialLedConfig"`` -> ``"serial_led_config"``.

    Names that already contain an underscore are only lowercased.
    """
    if '_' in name:
        return name.lower()
    return re.sub(r"([a-z])([A-Z])", r"\1_\2", name).lower()


@dataclass
class StructType:
    """
    Base class for dataclasses describing a binary struct.

    Every dataclass field either has a ``"format"`` (a BaseFormat instance) in its metadata or is annotated
    with another StructType subclass. A subclass created with ``union_type_field="<name>"`` is a tagged
    union: its subclasses pass ``<name>=<value>`` as class keyword to register as the option for that value.
    Encoded data (and the generated C code) contains the union parent's fields first, followed by the
    selected option's own fields.
    """
    # registry shared by all StructTypes: union type field name -> {type value: option class}
    _union_options = {}
    union_type_field = None

    # noinspection PyMethodOverriding
    def __init_subclass__(cls, /, union_type_field=None, **kwargs):
        cls.union_type_field = union_type_field
        if union_type_field:
            if union_type_field in cls._union_options:
                raise TypeError('Duplicate union_type_field: %s' % union_type_field)
            cls._union_options[union_type_field] = {}
        for key, values in cls._union_options.items():
            value = kwargs.pop(key, None)
            if value is not None:
                if value in values:
                    raise TypeError('Duplicate %s: %s' % (key, value))
                values[value] = cls
                # the option's type value becomes a class attribute, so instances report it
                setattr(cls, key, value)
        super().__init_subclass__(**kwargs)

    @classmethod
    def _get_field_handler(cls, field_):
        """Return the BaseFormat instance or StructType subclass responsible for ``field_``."""
        if "format" in field_.metadata:
            return field_.metadata["format"]
        if isinstance(field_.type, type) and issubclass(field_.type, StructType):
            return field_.type
        raise TypeError('field %s.%s has no format and is no StructType' % (cls.__name__, field_.name))

    @classmethod
    def _get_union_option(cls, values: dict) -> type["StructType"]:
        """Return the registered union option selected by the union type field value in ``values``."""
        try:
            type_value = values[cls.union_type_field]
        except KeyError:
            raise TypeError('union_type_field %s.%s is missing' % (cls.__name__, cls.union_type_field)) from None
        try:
            return cls._union_options[cls.union_type_field][type_value]
        except KeyError:
            raise TypeError('union_type_field %s.%s value %r not known'
                            % (cls.__name__, cls.union_type_field, type_value)) from None

    @classmethod
    def _from_values(cls, values: dict) -> Self:
        """Instantiate ``cls``; init=False fields (like union type fields) come from class attributes."""
        return cls(**{field_.name: values[field_.name]
                      for field_ in fields(cls) if field_.init and field_.name in values})

    @classmethod
    def _convert_fields(cls, instance, method: str, skip_fields=frozenset()) -> list[tuple[str, Any]]:
        """
        Convert the field values of ``instance`` using each handler's ``method`` ("encode" or "tojson").

        Returns ``(field name, converted value)`` pairs: the fields of ``cls`` not in ``skip_fields`` and,
        if ``cls`` is a union parent and ``instance`` is one of its options, the option's own fields.
        """
        dispatch = bool(cls.union_type_field) and type(instance) is not cls
        if dispatch and not isinstance(instance, cls):
            raise ValueError('expected value of type %r, got %r' % (cls, instance))
        result = []
        for field_ in fields(cls):
            if field_.name in skip_fields:
                continue
            value = getattr(instance, field_.name)
            handler = cls._get_field_handler(field_)
            if isinstance(handler, type) and not isinstance(value, handler):
                raise ValueError('expected value of type %r for %s.%s, got %r' %
                                 (field_.type, cls.__name__, field_.name, value))
            result.append((field_.name, getattr(handler, method)(value)))
        if dispatch:
            # the option must not repeat the fields its union parent already handled
            handled = skip_fields | {field_.name for field_ in fields(cls)}
            result.extend(type(instance)._convert_fields(instance, method, handled))
        return result

    @classmethod
    def encode(cls, instance) -> bytes:
        """Encode ``instance`` (an instance of ``cls`` or, for unions, of one of its options) to bytes."""
        return b"".join(value for name, value in cls._convert_fields(instance, "encode"))

    @classmethod
    def decode(cls, data: bytes) -> Self:
        """Decode an instance from the start of ``data``; union parents return the selected option."""
        instance, data = cls._decode(data, {})
        return instance

    @classmethod
    def _decode(cls, data: bytes, known_values: dict) -> tuple[Self, bytes]:
        """Decode the fields not in ``known_values``; return ``(instance, remaining_data)``."""
        values = dict(known_values)
        for field_ in fields(cls):
            if field_.name in values:
                continue
            handler = cls._get_field_handler(field_)
            if isinstance(handler, type):
                values[field_.name], data = handler._decode(data, {})
            else:
                values[field_.name], data = handler.decode(data)
        if cls.union_type_field:
            return cls._get_union_option(values)._decode(data, values)
        return cls._from_values(values), data

    @classmethod
    def tojson(cls, instance) -> dict:
        """Convert ``instance`` to a JSON-compatible dict."""
        return dict(cls._convert_fields(instance, "tojson"))

    @classmethod
    def fromjson(cls, data) -> Self:
        """Create an instance from a dict as produced by tojson(); union parents return the selected option."""
        # todo: upgrade_json
        return cls._fromjson(data, {})

    @classmethod
    def _fromjson(cls, data: dict, known_values: dict) -> Self:
        """Convert the fields not in ``known_values``; missing keys are left to dataclass defaults."""
        values = dict(known_values)
        for field_ in fields(cls):
            if field_.name in values:
                continue
            handler = cls._get_field_handler(field_)
            if field_.name in data:
                values[field_.name] = handler.fromjson(data[field_.name])
        if cls.union_type_field:
            return cls._get_union_option(values)._fromjson(data, values)
        return cls._from_values(values)

    @classmethod
    def get_c_parts(cls, ignore_fields=None, no_empty=False, typedef=False):
        """Return ``(pre, post)`` of the C struct declaration, skipping fields in ``ignore_fields``."""
        ignore_fields = set() if not ignore_fields else set(ignore_fields)
        items = []
        for field_ in fields(cls):
            if field_.name in ignore_fields:
                continue
            handler = cls._get_field_handler(field_)
            if isinstance(handler, type):
                code = handler.get_c_code(field_.name, typedef=False)
            else:
                code = handler.get_c_code(field_.name)
            items.append((code, field_.metadata.get("doc", None)))

        if cls.union_type_field:
            parent_fields = set(field_.name for field_ in fields(cls))
            union_items = []
            for key, option in cls._union_options[cls.union_type_field].items():
                name = normalize_name(getattr(key, 'name', option.__name__))
                # union members are plain nested structs; a typedef is not valid inside a union
                union_items.append(
                    option.get_c_code(name, ignore_fields=(ignore_fields | parent_fields), typedef=False)
                )
            items.append(("union {\n"+indent_c("\n".join(union_items))+"\n};", ""))

        if no_empty and not items:
            return "", ""

        # todo: struct comment
        pre = ""
        if typedef:
            comment = (cls.__doc__ or "").strip()
            if comment:
                pre += "/** %s */\n" % comment
            pre += "typedef struct __packed "
        else:
            pre += "struct "

        pre += "{\n%(elements)s\n}" % {
            "elements": indent_c(
                "\n".join(
                    code + ("" if not doc else (" /** %s */" % doc))
                    for code, doc in items
                )
            ),
        }
        return pre, ""

    @classmethod
    def get_c_code(cls, name=None, ignore_fields=None, no_empty=False, typedef=True) -> str:
        """Return the C declaration; without ``name`` the struct (typedef) or variable name is used."""
        pre, post = cls.get_c_parts(ignore_fields=ignore_fields, no_empty=no_empty, typedef=typedef)
        if no_empty and not pre and not post:
            return ""
        if name is None:
            name = cls.get_struct_name() if typedef else cls.get_variable_name()
        return "%s %s%s;" % (pre, name, post)

    @classmethod
    def get_base_name(cls):
        return cls.__name__

    @classmethod
    def get_variable_name(cls):
        return cls.get_base_name()

    @classmethod
    def get_struct_name(cls):
        return "%s_t" % cls.get_base_name()


class MacAddressFormat(FixedHexFormat):
    """MAC address as ``'aa:bb:cc:dd:ee:ff'`` string, stored as 6 bytes."""

    def __init__(self):
        super().__init__(num=6, sep=':')


class MacAddressesListFormat(VarArrayFormat):
    """Count-prefixed list of MAC addresses."""

    def __init__(self):
        super().__init__(child_type=MacAddressFormat())


# concrete data types
# (no docstrings on the struct classes below: StructType.get_c_parts() emits __doc__ into the C code)


@unique
class LedType(IntEnum):
    SERIAL = 1
    MULTIPIN = 2


@dataclass
class LedConfig(StructType, union_type_field="led_type"):
    led_type: LedType = field(init=False, repr=False, metadata={"format": SimpleFormat('B')})


@dataclass
class SerialLedConfig(LedConfig, StructType, led_type=LedType.SERIAL):
    gpio: int = field(metadata={"format": SimpleFormat('B')})
    rmt: int = field(metadata={"format": SimpleFormat('B')})


@dataclass
class MultipinLedConfig(LedConfig, StructType, led_type=LedType.MULTIPIN):
    gpio_red: int = field(metadata={"format": SimpleFormat('B')})
    gpio_green: int = field(metadata={"format": SimpleFormat('B')})
    gpio_blue: int = field(metadata={"format": SimpleFormat('B')})


# without @dataclass, fields() would not see address/distance at all
@dataclass
class RangeItemType(StructType):
    address: str = field(metadata={"format": MacAddressFormat()})
    distance: int = field(metadata={"format": SimpleFormat('H')})