From 0fd789173a01371151cb57a728624c1b32f5a786 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laura=20Kl=C3=BCnder?= Date: Tue, 27 Feb 2024 18:25:18 +0100 Subject: [PATCH] change all of the MeshMessages c-models from dataclasses to pydantic --- src/c3nav/api/utils.py | 49 +- src/c3nav/mesh/api.py | 10 +- src/c3nav/mesh/baseformats.py | 810 --------------- src/c3nav/mesh/cformats.py | 961 ++++++++++++++++++ src/c3nav/mesh/consumers.py | 135 +-- src/c3nav/mesh/dataformats.py | 285 ------ src/c3nav/mesh/forms.py | 79 +- .../management/commands/generate_c_types.py | 67 +- src/c3nav/mesh/messages.py | 630 ++++++------ src/c3nav/mesh/models.py | 4 +- src/c3nav/mesh/schemas.py | 284 ++++++ src/c3nav/mesh/utils.py | 2 +- src/c3nav/routing/api/positioning.py | 2 +- src/requirements/production.txt | 1 + 14 files changed, 1691 insertions(+), 1628 deletions(-) delete mode 100644 src/c3nav/mesh/baseformats.py create mode 100644 src/c3nav/mesh/cformats.py delete mode 100644 src/c3nav/mesh/dataformats.py create mode 100644 src/c3nav/mesh/schemas.py diff --git a/src/c3nav/api/utils.py b/src/c3nav/api/utils.py index bbe5d7a3..780666a7 100644 --- a/src/c3nav/api/utils.py +++ b/src/c3nav/api/utils.py @@ -1,52 +1,5 @@ -from typing import Annotated, Any, Type +from typing import Annotated import annotated_types -from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler -from pydantic.json_schema import JsonSchemaValue -from pydantic_core import CoreSchema, core_schema - - -def get_api_post_data(request): - is_json = request.META.get('CONTENT_TYPE').lower() == 'application/json' - if is_json: - try: - data = request.json_body - except AttributeError: - pass # todo fix this raise ParseError('Invalid JSON.') - return data - return request.POST - - -class EnumSchemaByNameMixin: - @classmethod - def __get_pydantic_json_schema__( - cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler - ) -> JsonSchemaValue: - json_schema = handler(core_schema) - json_schema = handler.resolve_ref_schema(json_schema) - json_schema["enum"] = [m.name for m in cls] - json_schema["type"] = "string" - return json_schema - - @classmethod - def __get_pydantic_core_schema__( - cls, source: Type[Any], handler: GetCoreSchemaHandler - ) -> core_schema.CoreSchema: - return core_schema.no_info_after_validator_function( - cls.validate, - core_schema.any_schema(), - serialization=core_schema.plain_serializer_function_ser_schema(lambda x: x.name), - ) - - @classmethod - def validate(cls, v: int): - if isinstance(v, cls): - return v - try: - return cls[v] - except KeyError: - pass - return cls(v) - NonEmptyStr = Annotated[str, annotated_types.MinLen(1)] diff --git a/src/c3nav/mesh/api.py b/src/c3nav/mesh/api.py index b3aa7846..0082ac17 100644 --- a/src/c3nav/mesh/api.py +++ b/src/c3nav/mesh/api.py @@ -12,8 +12,8 @@ from pydantic import PositiveInt, field_validator from c3nav.api.auth import APIKeyAuth, auth_permission_responses, auth_responses, validate_responses from c3nav.api.exceptions import API404, APIConflict, APIRequestValidationFailed from c3nav.api.schema import BaseSchema -from c3nav.mesh.dataformats import BoardType, ChipType, FirmwareImage -from c3nav.mesh.messages import MeshMessageType +from c3nav.mesh.schemas import BoardType, ChipType, FirmwareImage +from c3nav.mesh.messages import MeshMessageType, MeshMessage from c3nav.mesh.models import FirmwareBuild, FirmwareVersion, NodeMessage mesh_api_router = APIRouter(tags=["mesh"], auth=APIKeyAuth(permissions={"mesh_control"})) @@ -93,7 +93,7 @@ def firmware_by_id(request, firmware_id: int): @mesh_api_router.get('/firmwares/{firmware_id}/{variant}/image_data', summary="firmware image header", description="get firmware image header for specific firmware build", - response={200: FirmwareImage.schema, **API404.dict(), **auth_responses}, + response={200: FirmwareImage, **API404.dict(), **auth_responses}, openapi_extra={ "externalDocs": { 'description': 'esp-idf docs', @@ -105,7 +105,7 @@ def firmware_by_id(request, firmware_id: int): def firmware_build_image(request, firmware_id: int, variant: str): try: build = FirmwareBuild.objects.get(version_id=firmware_id, variant=variant) - return FirmwareImage.tojson(build.firmware_image) + return build.firmware_image.model_dump() except FirmwareVersion.DoesNotExist: raise API404("Firmware or firmware build not found") @@ -218,7 +218,7 @@ class NodeMessageSchema(BaseSchema): src_node: NodeAddress message_type: MeshMessageType datetime: datetime - data: dict + data: MeshMessage @staticmethod def resolve_src_node(obj): diff --git a/src/c3nav/mesh/baseformats.py b/src/c3nav/mesh/baseformats.py deleted file mode 100644 index 0b3f5de3..00000000 --- a/src/c3nav/mesh/baseformats.py +++ /dev/null @@ -1,810 +0,0 @@ -import re -import struct -import typing -from abc import ABC, abstractmethod -from dataclasses import Field, dataclass -from dataclasses import fields as dataclass_fields -from typing import Any, Self, Sequence - -from pydantic import create_model - -from c3nav.mesh.utils import indent_c - - -class BaseFormat(ABC): - - def get_var_num(self): - return 0 - - @abstractmethod - def encode(self, value): - pass - - @classmethod - @abstractmethod - def decode(cls, data) -> tuple[Any, bytes]: - pass - - def fromjson(self, data): - return data - - def tojson(self, data): - return data - - @abstractmethod - def get_min_size(self): - pass - - @abstractmethod - def get_max_size(self): - pass - - def get_size(self, calculate_max=False): - if calculate_max: - return self.get_max_size() - else: - return self.get_min_size() - - @abstractmethod - def get_c_parts(self) -> tuple[str, str]: - pass - - def get_c_code(self, name) -> str: - pre, post = self.get_c_parts() - return "%s %s%s;" % (pre, name, post) - - def set_field_type(self, field_type): - self.field_type = field_type - - def get_c_definitions(self) -> dict[str, str]: - return {} - - def get_typedef_name(self): - return '%s_t' % normalize_name(self.field_type.__name__) - - -class SimpleFormat(BaseFormat): - def __init__(self, fmt): - self.fmt = fmt - self.size = struct.calcsize(fmt) - - self.c_type = self.c_types[self.fmt[-1]] - self.num = int(self.fmt[:-1]) if len(self.fmt) > 1 else 1 - - def encode(self, value): - if self.num == 1: - return struct.pack(self.fmt, value) - return struct.pack(self.fmt, *value) - - def decode(self, data: bytes) -> tuple[Any, bytes]: - value = struct.unpack(self.fmt, data[:self.size]) - if len(value) == 1: - value = value[0] - return value, data[self.size:] - - def get_min_size(self): - return self.size - - def get_max_size(self): - return self.size - - c_types = { - "B": "uint8_t", - "H": "uint16_t", - "I": "uint32_t", - "Q": "uint64_t", - "b": "int8_t", - "h": "int16_t", - "i": "int32_t", - "q": "int64_t", - "s": "char", - } - - def get_c_parts(self): - return self.c_type, ("" if self.num == 1 else ("[%d]" % self.num)) - - -class SimpleConstFormat(SimpleFormat): - def __init__(self, fmt, const_value: int): - super().__init__(fmt) - self.const_value = const_value - - def decode(self, data: bytes) -> tuple[Any, bytes]: - value, out_data = super().decode(data) - if value != self.const_value: - raise ValueError('const_value is wrong') - return value, out_data - - -class EnumFormat(SimpleFormat): - def __init__(self, fmt="B", *, as_hex=False, c_definition=True): - super().__init__(fmt) - self.as_hex = as_hex - self.c_definition = c_definition - - def set_field_type(self, field_type): - super().set_field_type(field_type) - self.c_struct_name = normalize_name(field_type.__name__) + '_t' - - def decode(self, data: bytes) -> tuple[Any, bytes]: - value, out_data = super().decode(data) - return self.field_type(value), out_data - - def get_c_parts(self): - if not self.c_definition: - return super().get_c_parts() - return self.c_struct_name, "" - - def fromjson(self, data): - return self.field_type[data] - - def tojson(self, data): - return data.name - - def get_c_definitions(self) -> dict[str, str]: - if not self.c_definition: - return {} - prefix = normalize_name(self.field_type.__name__).upper() - options = [] - last_value = None - for item in self.field_type: - if last_value is not None and item.value != last_value + 1: - options.append('') - last_value = item.value - options.append("%(prefix)s_%(name)s = %(value)s," % { - "prefix": prefix, - "name": normalize_name(item.name).upper(), - "value": ("0x%02x" if self.as_hex else "%d") % item.value - }) - - return { - self.c_struct_name: "enum {\n%(options)s\n};\ntypedef uint8_t %(name)s;" % { - "options": indent_c("\n".join(options)), - "name": self.c_struct_name, - } - } - - -class TwoNibblesEnumFormat(SimpleFormat): - def __init__(self): - super().__init__('B') - - def decode(self, data: bytes) -> tuple[bool, bytes]: - fields = dataclass_fields(self.field_type) - value, data = super().decode(data) - return self.field_type(fields[0].type(value // 2 ** 4), fields[1].type(value // 2 ** 4)), data - - def encode(self, value): - fields = dataclass_fields(self.field_type) - return super().encode( - getattr(value, fields[0].name).value * 2 ** 4 + - getattr(value, fields[1].name).value * 2 ** 4 - ) - - def fromjson(self, data): - fields = dataclass_fields(self.field_type) - return self.field_type(*(field.type[data[field.name]] for field in fields)) - - def tojson(self, data): - fields = dataclass_fields(self.field_type) - return { - field.name: getattr(data, field.name).name for field in fields - } - - -class ChipRevFormat(SimpleFormat): - def __init__(self): - super().__init__('H') - - def decode(self, data: bytes) -> tuple[tuple[int, int], bytes]: - value, data = super().decode(data) - return (value // 100, value % 100), data - - def encode(self, value): - return value[0] * 100 + value[1] - - -class BoolFormat(SimpleFormat): - def __init__(self): - super().__init__('B') - - def encode(self, value): - return super().encode(int(value)) - - def decode(self, data: bytes) -> tuple[bool, bytes]: - value, data = super().decode(data) - if value > 1: - raise ValueError('Boolean value > 1') - return bool(value), data - - -class FixedStrFormat(SimpleFormat): - def __init__(self, num): - self.num = num - super().__init__('%ds' % self.num) - - def encode(self, value: str): - return value.encode()[:self.num].ljust(self.num, bytes((0,))), - - def decode(self, data: bytes) -> tuple[str, bytes]: - return data[:self.num].rstrip(bytes((0,))).decode(), data[self.num:] - - -class FixedHexFormat(SimpleFormat): - def __init__(self, num, sep=''): - self.num = num - self.sep = sep - super().__init__('%dB' % self.num) - - def encode(self, value: str): - return super().encode(tuple(bytes.fromhex(value.replace(':', '')))) - - def decode(self, data: bytes) -> tuple[str, bytes]: - return self.sep.join(('%02x' % i) for i in data[:self.num]), data[self.num:] - - -@abstractmethod -class BaseVarFormat(BaseFormat, ABC): - def __init__(self, max_num): - self.num_fmt = 'H' - self.num_size = struct.calcsize(self.num_fmt) - self.max_num = max_num - - def get_min_size(self): - return self.num_size - - def get_max_size(self): - return self.num_size + self.max_num * self.get_var_num() - - def get_num_c_code(self): - return SimpleFormat(self.num_fmt).get_c_code("num") - - -class VarArrayFormat(BaseVarFormat): - def __init__(self, child_type, max_num): - super().__init__(max_num=max_num) - self.child_type = child_type - self.child_size = self.child_type.get_min_size() - - def get_var_num(self): - return self.child_size - pass - - def encode(self, values: Sequence) -> bytes: - num = len(values) - if num > self.max_num: - raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}') - data = struct.pack(self.num_fmt, num) - for value in values: - data += self.child_type.encode(value) - return data - - def decode(self, data: bytes) -> tuple[list[Any], bytes]: - num = struct.unpack(self.num_fmt, data[:self.num_size])[0] - if num > self.max_num: - raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}') - data = data[self.num_size:] - result = [] - for i in range(num): - item, data = self.child_type.decode(data) - result.append(item) - return result, data - - def fromjson(self, data): - return [ - self.child_type.fromjson(item) for item in data - ] - - def tojson(self, data): - return [ - self.child_type.tojson(item) for item in data - ] - - def get_c_parts(self): - pre, post = self.child_type.get_c_parts() - return super().get_num_c_code() + "\n" + pre, "[0]" + post - - -class VarStrFormat(BaseVarFormat): - def __init__(self, max_len): - super().__init__(max_num=max_len) - - def get_var_num(self): - return 1 - - def encode(self, value: str) -> bytes: - num = len(value) - if num > self.max_num: - raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}') - return struct.pack(self.num_fmt, num) + value.encode() - - def decode(self, data: bytes) -> tuple[str, bytes]: - num = struct.unpack(self.num_fmt, data[:self.num_size])[0] - if num > self.max_num: - raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}') - return data[self.num_size:self.num_size + num].rstrip(bytes((0,))).decode(), data[self.num_size + num:] - - def get_c_parts(self): - return super().get_num_c_code() + "\n" + "char", "[0]" - - -class VarBytesFormat(BaseVarFormat): - def __init__(self, max_size): - super().__init__(max_num=max_size) - - def get_var_num(self): - return 1 - - def encode(self, value: bytes) -> bytes: - num = len(value) - if num > self.max_num: - raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}') - return struct.pack(self.num_fmt, num) + value - - def decode(self, data: bytes) -> tuple[bytes, bytes]: - num = struct.unpack(self.num_fmt, data[:self.num_size])[0] - if num > self.max_num: - raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}') - return data[self.num_size:self.num_size + num].rstrip(bytes((0,))), data[self.num_size + num:] - - def get_c_parts(self): - return super().get_num_c_code() + "\n" + "uint8_t", "[0]" - - -@dataclass -class StructType: - _union_options = {} - union_type_field = None - existing_c_struct = None - c_includes = set() - - @classmethod - def get_field_format(cls, attr_name): - attr = getattr(cls, attr_name, None) - fields = [f for f in dataclass_fields(cls) if f.name == attr_name] - if not fields: - raise TypeError(f"{cls}.{attr_name} not a field") - field = fields[0] - type_ = typing.get_type_hints(cls)[attr_name] - if "format" in field.metadata: - field_format = field.metadata["format"] - field_format.set_field_type(type_) - return field_format - - if issubclass(type_, StructType): - return type_ - raise TypeError('field %s.%s has no format and is no StructType' % - (cls.__class__.__name__, attr_name)) - - # noinspection PyMethodOverriding - def __init_subclass__(cls, /, union_type_field=None, existing_c_struct=None, c_includes=None, **kwargs): - cls.union_type_field = union_type_field - if c_includes is not None: - cls.c_includes |= set(c_includes) - if cls.existing_c_struct is not None: - # TODO: can we make it possible? does it even make sense? - raise TypeError('subclassing an external c struct is not possible') - cls.existing_c_struct = existing_c_struct - if union_type_field: - if union_type_field in cls._union_options: - raise TypeError('Duplicate union_type_field: %s', union_type_field) - cls._union_options[union_type_field] = {} - f = getattr(cls, union_type_field) - metadata = dict(f.metadata) - metadata['union_discriminator'] = True - f.metadata = metadata - f.repr = False - f.init = False - - for attr_name in cls.__dict__.keys(): - attr = getattr(cls, attr_name) - if isinstance(attr, Field): - metadata = dict(attr.metadata) - if "defining_class" not in metadata: - metadata["defining_class"] = cls - attr.metadata = metadata - - for key, values in cls._union_options.items(): - value = kwargs.pop(key, None) - if value is not None: - if value in values: - raise TypeError('Duplicate %s: %s', (key, value)) - values[value] = cls - setattr(cls, key, value) - - # pydantic model - cls._pydantic_fields = getattr(cls, '_pydantic_fields', {}).copy() - fields = [] - for field_ in dataclass_fields(cls): - fields.append((field_.name, field_.type, field_.metadata)) - for attr_name in tuple(cls.__annotations__.keys()): - attr = getattr(cls, attr_name, None) - metadata = attr.metadata if isinstance(attr, Field) else {} - try: - type_ = cls.__annotations__[attr_name] - except KeyError: - # print('nope', cls, attr_name) - continue - fields.append((attr_name, type_, metadata)) - for name, type_, metadata in fields: - try: - field_format = cls.get_field_format(name) - except TypeError: - # todo: in case of not a field, ignore it? - continue - if not (isinstance(field_format, type) and issubclass(field_format, StructType)): - cls._pydantic_fields[name] = (type_, ...) - else: - if metadata.get("json_embed"): - cls._pydantic_fields.update(type_._pydantic_fields) - else: - cls._pydantic_fields[name] = (type_.schema, ...) - cls.schema = create_model(cls.__name__ + 'Schema', **cls._pydantic_fields) - super().__init_subclass__(**kwargs) - - @classmethod - def get_var_num(cls): - return sum([cls.get_field_format(f.name).get_var_num() for f in dataclass_fields(cls)], start=0) - - @classmethod - def get_types(cls): - if not cls.union_type_field: - raise TypeError('Not a union class') - return cls._union_options[cls.union_type_field] - - @classmethod - def get_type(cls, type_id) -> Self: - if not cls.union_type_field: - raise TypeError('Not a union class') - return cls.get_types()[type_id] - - @classmethod - def encode(cls, instance, ignore_fields=()) -> bytes: - data = bytes() - if cls.union_type_field and type(instance) is not cls: - if not isinstance(instance, cls): - raise ValueError('expected value of type %r, got %r' % (cls, instance)) - - for field_ in dataclass_fields(cls): - data += cls.get_field_format(field_.name).encode(getattr(instance, field_.name)) - - # todo: better - data += instance.encode(instance, ignore_fields=set(f.name for f in dataclass_fields(cls))) - return data - - for field_ in dataclass_fields(cls): - if field_.name in ignore_fields: - continue - value = getattr(instance, field_.name) - field_format = cls.get_field_format(field_.name) - if not (isinstance(field_format, type) and issubclass(field_format, StructType)): - data += field_format.encode(value) - else: - if not isinstance(value, field_.type): - raise ValueError('expected value of type %r for %s.%s, got %r' % - (field_.type, cls.__name__, field_.name, value)) - data += value.encode(value) - return data - - @classmethod - def decode(cls, data: bytes) -> tuple[Self, bytes]: - orig_data = data - kwargs = {} - no_init_data = {} - for field_ in dataclass_fields(cls): - field_format = cls.get_field_format(field_.name) - if not (isinstance(field_format, type) and issubclass(field_format, StructType)): - value, data = field_format.decode(data) - else: - value, data = field_.type.decode(data) - if field_.init: - kwargs[field_.name] = value - else: - no_init_data[field_.name] = value - - if cls.union_type_field: - try: - type_value = no_init_data[cls.union_type_field] - except KeyError: - raise TypeError('union_type_field %s.%s is missing' % - (cls.__name__, cls.union_type_field)) - try: - klass = cls.get_type(type_value) - except KeyError: - raise TypeError('union_type_field %s.%s value %r no known' % - (cls.__name__, cls.union_type_field, type_value)) - return klass.decode(orig_data) - return cls(**kwargs), data - - @classmethod - def tojson(cls, instance) -> dict: - result = {} - - if cls.union_type_field and type(instance) is not cls: - if not isinstance(instance, cls): - raise ValueError('expected value of type %r, got %r' % (cls, instance)) - - for field_ in dataclass_fields(instance): - if field_.name is cls.union_type_field: - result[field_.name] = cls.get_field_format(field_.name).tojson(getattr(instance, field_.name)) - break - else: - raise TypeError('couldn\'t find %s value' % cls.union_type_field) - - result.update(instance.tojson(instance)) - return result - - for field_ in dataclass_fields(cls): - value = getattr(instance, field_.name) - field_format = cls.get_field_format(field_.name) - if not (isinstance(field_format, type) and issubclass(field_format, StructType)): - result[field_.name] = field_format.tojson(value) - else: - if not isinstance(value, field_.type): - raise ValueError('expected value of type %r for %s.%s, got %r' % - (field_.type, cls.__name__, field_.name, value)) - json_val = value.tojson(value) - if field_.metadata.get("json_embed"): - for k, v in json_val.items(): - result[k] = v - else: - result[field_.name] = value.tojson(value) - return result - - @classmethod - def upgrade_json(cls, data): - return data - - @classmethod - def fromjson(cls, data: dict) -> Self: - data = data.copy() - - # todo: upgrade_json - cls.upgrade_json(data) - - kwargs = {} - no_init_data = {} - for field_ in dataclass_fields(cls): - raw_value = data.get(field_.name, None) - field_format = cls.get_field_format(field_.name) - if not (isinstance(field_format, type) and issubclass(field_format, StructType)): - value = field_format.fromjson(raw_value) - else: - if field_.metadata.get("json_embed"): - value = field_.type.fromjson(data) - else: - value = field_.type.fromjson(raw_value) - if field_.init: - kwargs[field_.name] = value - else: - no_init_data[field_.name] = value - - if cls.union_type_field: - try: - type_value = no_init_data.pop(cls.union_type_field) - except KeyError: - raise TypeError('union_type_field %s.%s is missing' % - (cls.__name__, cls.union_type_field)) - try: - klass = cls.get_type(type_value) - except KeyError: - raise TypeError('union_type_field %s.%s value 0x%02x no known' % - (cls.__name__, cls.union_type_field, type_value)) - return klass.fromjson(data) - - return cls(**kwargs) - - @classmethod - def get_c_struct_items(cls, ignore_fields=None, no_empty=False, top_level=False, union_only=False, in_union=False): - ignore_fields = set() if not ignore_fields else set(ignore_fields) - - items = [] - - for field_ in dataclass_fields(cls): - if field_.name in ignore_fields: - continue - if in_union and field_.metadata["defining_class"] != cls: - continue - - name = field_.metadata.get("c_name", field_.name) - field_format = cls.get_field_format(field_.name) - if not (isinstance(field_format, type) and issubclass(field_format, StructType)): - if not field_.metadata.get("union_discriminator") or field_.metadata.get("defining_class") == cls: - items.append(( - ( - ("%(typedef_name)s %(name)s;" % { - "typedef_name": field_format.get_typedef_name(), - "name": name, - }) - if field_.metadata.get("as_definition") - else field_format.get_c_code(name) - ), - field_.metadata.get("doc", None), - )), - else: - if field_.metadata.get("c_embed"): - embedded_items = field_.type.get_c_struct_items(ignore_fields, no_empty, top_level, union_only) - items.extend(embedded_items) - else: - items.append(( - ( - ("%(typedef_name)s %(name)s;" % { - "typedef_name": field_.type.get_typedef_name(), - "name": name, - }) - if field_.metadata.get("as_definition") - else field_.type.get_c_code(name, typedef=False) - ), - field_.metadata.get("doc", None), - )) - - if cls.union_type_field: - if not union_only: - union_code = cls.get_c_union_code(ignore_fields) - items.append(("union __packed %s;" % union_code, "")) - - return items - - @classmethod - def get_c_union_size(cls): - return max( - (option.get_min_size(no_inherited_fields=True) for option in - cls._union_options[cls.union_type_field].values()), - default=0, - ) - - @classmethod - def get_c_definitions(cls) -> dict[str, str]: - definitions = {} - for field_ in dataclass_fields(cls): - field_format = cls.get_field_format(field_.name) - if not (isinstance(field_format, type) and issubclass(field_format, StructType)): - definitions.update(field_format.get_c_definitions()) - if field_.metadata.get("as_definition"): - typedef_name = field_format.get_typedef_name() - definitions[typedef_name] = 'typedef %(code)s %(name)s;' % { - "code": ''.join(field_format.get_c_parts()), - "name": typedef_name, - } - else: - definitions.update(field_.type.get_c_definitions()) - if field_.metadata.get("as_definition"): - typedef_name = field_.type.get_typedef_name() - definitions[typedef_name] = field_.type.get_c_code(name=typedef_name, typedef=True) - if cls.union_type_field: - for key, option in cls._union_options[cls.union_type_field].items(): - definitions.update(option.get_c_definitions()) - return definitions - - @classmethod - def get_c_union_code(cls, ignore_fields=None): - union_items = [] - for key, option in cls._union_options[cls.union_type_field].items(): - base_name = normalize_name(getattr(key, 'name', option.__name__)) - item_c_code = option.get_c_code( - base_name, ignore_fields=ignore_fields, typedef=False, in_union=True, no_empty=True - ) - if item_c_code: - union_items.append(item_c_code) - size = cls.get_c_union_size() - union_items.append( - "uint8_t bytes[%0d]; " % size - ) - return "{\n" + indent_c("\n".join(union_items)) + "\n}" - - @classmethod - def get_c_parts(cls, ignore_fields=None, no_empty=False, top_level=False, union_only=False, in_union=False): - if cls.existing_c_struct is not None: - return (cls.existing_c_struct, "") - - ignore_fields = set() if not ignore_fields else set(ignore_fields) - - if union_only: - if cls.union_type_field: - union_code = cls.get_c_union_code(ignore_fields) - return "typedef union __packed %s" % union_code, "" - else: - return "", "" - - pre = "" - - items = cls.get_c_struct_items(ignore_fields=ignore_fields, - no_empty=no_empty, - top_level=top_level, - union_only=union_only, - in_union=in_union) - - if no_empty and not items: - return "", "" - - # todo: struct comment - if top_level: - comment = cls.__doc__.strip() - if comment: - pre += "/** %s */\n" % comment - pre += "typedef struct __packed " - else: - pre += "struct __packed " - - pre += "{\n%(elements)s\n}" % { - "elements": indent_c( - "\n".join( - code + ("" if not comment else (" /** %s */" % comment)) - for code, comment in items - ) - ), - } - return pre, "" - - @classmethod - def get_c_code(cls, name=None, ignore_fields=None, no_empty=False, typedef=True, union_only=False, - in_union=False) -> str: - pre, post = cls.get_c_parts(ignore_fields=ignore_fields, - no_empty=no_empty, - top_level=typedef, - union_only=union_only, - in_union=in_union) - if no_empty and not pre and not post: - return "" - return "%s %s%s;" % (pre, name, post) - - @classmethod - def get_variable_name(cls, base_name): - return base_name - - @classmethod - def get_typedef_name(cls): - return "%s_t" % normalize_name(cls.__name__) - - @classmethod - def get_min_size(cls, no_inherited_fields=False) -> int: - if cls.union_type_field: - own_size = sum([cls.get_field_format(f.name).get_min_size() for f in dataclass_fields(cls)]) - union_size = max( - [0] + [option.get_min_size(True) for option in cls._union_options[cls.union_type_field].values()]) - return own_size + union_size - if no_inherited_fields: - relevant_fields = [f for f in dataclass_fields(cls) if f.metadata["defining_class"] == cls] - else: - relevant_fields = [f for f in dataclass_fields(cls) if not f.metadata.get("union_discriminator")] - return sum((cls.get_field_format(f.name).get_min_size() for f in relevant_fields), start=0) - - @classmethod - def get_size(cls, no_inherited_fields=False, calculate_max=False) -> int: - if cls.union_type_field: - own_size = sum( - [cls.get_field_format(f.name).get_size(calculate_max=calculate_max) for f in dataclass_fields(cls)]) - union_size = max( - [0] + [option.get_size(no_inherited_fields=True, calculate_max=calculate_max) for option in - cls._union_options[cls.union_type_field].values()]) - return own_size + union_size - if no_inherited_fields: - relevant_fields = [f for f in dataclass_fields(cls) if f.metadata["defining_class"] == cls] - else: - relevant_fields = [f for f in dataclass_fields(cls) if not f.metadata.get("union_discriminator")] - return sum((cls.get_field_format(f.name).get_size(calculate_max=calculate_max) for f in relevant_fields), - start=0) - - -def normalize_name(name): - if '_' in name: - name = name.lower() - else: - name = re.sub( - r"(([a-z])([A-Z]))|(([a-zA-Z])([A-Z][a-z]))", - r"\2\5_\3\6", - name - ).lower() - - name = re.sub( - r"(ota)([a-z])", - r"\1_\2", - name - ).lower() - - name = name.replace('config', 'cfg') - name = name.replace('position', 'pos') - name = name.replace('mesh_', '') - name = name.replace('firmware', 'fw') - name = name.replace('hardware', 'hw') - return name diff --git a/src/c3nav/mesh/cformats.py b/src/c3nav/mesh/cformats.py new file mode 100644 index 00000000..1cc2b2f0 --- /dev/null +++ b/src/c3nav/mesh/cformats.py @@ -0,0 +1,961 @@ +import re +import struct +import typing +from abc import ABC, abstractmethod +from collections import namedtuple +from contextlib import suppress +from dataclasses import dataclass +from dataclasses import fields as dataclass_fields +from enum import IntEnum, Enum +from typing import Any, Sequence, Self, Annotated, Literal, Union, Type, TypeVar, ClassVar + +from annotated_types import SLOTS, BaseMetadata, Ge +from pydantic.fields import Field, FieldInfo +from pydantic_extra_types.mac_address import MacAddress + +from c3nav.mesh.utils import indent_c + + +@dataclass(frozen=True, **SLOTS) +class VarLen(BaseMetadata): + var_len_name: str = "num" + + +@dataclass(frozen=True, **SLOTS) +class NoDef(BaseMetadata): + no_def: bool = True + + +@dataclass(frozen=True, **SLOTS) +class AsHex(BaseMetadata): + as_hex: bool = True + + +@dataclass(frozen=True, **SLOTS) +class LenBytes(BaseMetadata): + len_bytes: Annotated[int, Ge(1)] + + +@dataclass(frozen=True, **SLOTS) +class AsDefinition(BaseMetadata): + as_definition: bool = True + + +@dataclass(frozen=True, **SLOTS) +class CEmbed(BaseMetadata): + c_embed: bool = True + + +@dataclass(frozen=True, **SLOTS) +class CName(BaseMetadata): + c_name: str + + +@dataclass(frozen=True, **SLOTS) +class CDoc(BaseMetadata): + c_doc: str + + +@dataclass +class ExistingCStruct(): + name: str + includes: list[str] + + +class CEnum(str, Enum): + def __new__(cls, value, c_value): + obj = str.__new__(cls) + obj._value_ = value + obj.c_value = c_value + return obj + + def __hash__(self): + return hash(self.value) + + +def discriminator_value(**kwargs): + return type('DiscriminatorValue', (), { + # todo: make this so pydantic doesn't throw a warning + **{name: value for name, value in kwargs.items()}, + '__annotations__': { + name: Annotated[Literal[value], Field(init=False)] + for name, value in kwargs.items() + } + }) + + +class TwoNibblesEncodable: + pass + + +class SplitTypeHint(namedtuple("SplitTypeHint", ("base", "metadata"))): + @classmethod + def from_annotation(cls, type_hint) -> Self: + if typing.get_origin(type_hint) is Annotated: + field_infos = tuple(m for m in type_hint.__metadata__ if isinstance(m, FieldInfo)) + return cls( + base=typing.get_args(type_hint)[0], + metadata=( + *(m for m in type_hint.__metadata__), + *(tuple(field_infos[0].metadata) if field_infos else ()) + ) + ) + + if isinstance(type_hint, FieldInfo): + return cls( + base=type_hint.annotation, + metadata=tuple(type_hint.metadata) + ) + + return cls( + base=type_hint, + metadata=() + ) + + def get_len_metadata(self): + max_length = None + var_len_name = None + for m in self.metadata: + ml = getattr(m, 'max_length', None) + if ml is not None: + max_length = ml if max_length is None else min(max_length, ml) + + vl = getattr(m, 'var_len_name', None) + if vl is not None: + if var_len_name is not None: + raise ValueError('can\'t set variable length name twice') + var_len_name = vl + return max_length, var_len_name + + def get_min_max_metadata(self, default_min=-(2 ** 63), default_max=2 ** 63 - 1): + min_ = default_min + max_ = default_max + for m in self.metadata: + gt = getattr(m, 'gt', None) + if gt is not None: + min_ = max(min_, gt + 1) + ge = getattr(m, 'ge', None) + if ge is not None: + min_ = max(min_, ge) + lt = getattr(m, 'lt', None) + if lt is not None: + max_ = min(max_, lt - 1) + le = getattr(m, 'le', None) + if le is not None: + max_ = min(max_, le) + return min_, max_ + + +def normalize_name(name): + if '_' in name: + name = name.lower() + else: + name = re.sub( + r"(([a-z])([A-Z]))|(([a-zA-Z])([A-Z][a-z]))", + r"\2\5_\3\6", + name + ).lower() + + name = re.sub( + r"(ota)([a-z])", + r"\1_\2", + name + ).lower() + + name = name.replace('config', 'cfg') + name = name.replace('position', 'pos') + name = name.replace('mesh_', '') + name = name.replace('firmware', 'fw') + name = name.replace('hardware', 'hw') + return name + + +class CFormat(ABC): + # todo: make this some cool generic with a TypeVar + + def get_var_num(self): + return 0 + + @abstractmethod + def encode(self, value): + pass + + @classmethod + @abstractmethod + def decode(cls, data) -> tuple[Any, bytes]: + pass + + @abstractmethod + def get_min_size(self): + pass + + @abstractmethod + def get_max_size(self): + pass + + def get_size(self, calculate_max=False): + if calculate_max: + return self.get_max_size() + else: + return self.get_min_size() + + @abstractmethod + def get_c_parts(self) -> tuple[str, str]: + pass + + def get_c_code(self, name) -> str: + pre, post = self.get_c_parts() + return "%s %s%s;" % (pre, name, post) + + def get_c_definitions(self) -> dict[str, str]: + return {} + + def get_typedef_name(self): + raise TypeError('no typedef for %r' % self) + + def get_c_includes(self) -> set: + return set() + + @classmethod + def from_annotation(cls, annotation, attr_name=None) -> Self: + if cls is not CFormat: + raise TypeError('call on CFormat!') + return cls.from_split_type_hint(SplitTypeHint.from_annotation(annotation), attr_name=attr_name) + + @classmethod + def from_split_type_hint(cls, type_hint: SplitTypeHint, attr_name=None) -> Self: + if cls is not CFormat: + raise TypeError('call on CFormat!') + outer_type_hint = None + if typing.get_origin(type_hint.base) is list: + outer_type_hint = SplitTypeHint( + base=list, + metadata=type_hint.metadata + ) + type_hint = SplitTypeHint( + base=typing.get_args(type_hint.base)[0], + metadata=() + ) + if typing.get_origin(type_hint.base) is Annotated: + type_hint = SplitTypeHint( + base=typing.get_args(type_hint.base)[0], + metadata=tuple(type_hint.base.__metadata__) + ) + + field_format = None + + if typing.get_origin(type_hint.base) is Literal: + literal_val = typing.get_args(type_hint.base)[0] + if isinstance(literal_val, CEnum): + options = [v.c_value for v in type(literal_val)] + literal_val = literal_val.c_value + int_type = get_int_type( + *type_hint.get_min_max_metadata(default_min=min(options), default_max=max(options)) + ) + elif isinstance(literal_val, int): + int_type = get_int_type(literal_val, literal_val) + else: + raise ValueError() + if int_type is None: + raise ValueError('invalid range:', attr_name) + field_format = SimpleConstFormat(int_type, const_value=literal_val) + elif typing.get_origin(type_hint.base) is Union: + discriminator = None + for m in type_hint.metadata: + discriminator = getattr(m, 'discriminator', discriminator) + if discriminator is None: + raise ValueError('no discriminator') + discriminator_as_hex = any(getattr(m, "as_hex", False) for m in type_hint.metadata) + field_format = UnionFormat( + model_formats=[StructFormat(type_) for type_ in typing.get_args(type_hint.base)], + discriminator=discriminator, + discriminator_as_hex=discriminator_as_hex, + ) + elif type_hint.base is int: + int_type = get_int_type(*type_hint.get_min_max_metadata()) + if int_type is None: + raise ValueError('invalid range:', attr_name) + field_format = SimpleFormat(int_type) + elif type_hint.base is bool: + field_format = BoolFormat() + elif type_hint.base in (str, bytes): + as_hex = any(getattr(m, 'as_hex', False) for m in type_hint.metadata) + max_length, var_len_name = type_hint.get_len_metadata() + if max_length is None: + raise ValueError('missing str max_length:', attr_name) + + if type_hint.base is str: + if var_len_name is not None: + field_format = VarStrFormat(max_len=max_length) + else: + field_format = FixedHexFormat(max_length//2) if as_hex else FixedStrFormat(max_length) + else: + if var_len_name is None: + field_format = FixedBytesFormat(num=max_length) + else: + field_format = VarBytesFormat(max_size=max_length) + elif type_hint.base is MacAddress: + field_format = MacAddressFormat() + elif isinstance(type_hint.base, type) and issubclass(type_hint.base, CEnum): + no_def = any(getattr(m, 'no_def', False) for m in type_hint.metadata) + as_hex = any(getattr(m, 'as_hex', False) for m in type_hint.metadata) + len_bytes = None + for m in type_hint.metadata: + len_bytes = getattr(m, 'len_bytes', len_bytes) + + if len_bytes: + int_type = get_int_type(0, 2 ** (8 * len_bytes - 1)) + else: + options = [v.c_value for v in type_hint.base] + int_type = get_int_type(min(options), max(options)) + if int_type is None: + raise ValueError('invalid range:', attr_name) + field_format = EnumFormat(enum_cls=type_hint.base, fmt=int_type, as_hex=as_hex, c_definition=not no_def) + elif isinstance(type_hint.base, type) and issubclass(type_hint.base, TwoNibblesEncodable): + field_format = TwoNibblesEnumFormat(type_hint.base) + elif isinstance(type_hint.base, type) and typing.get_type_hints(type_hint.base): + field_format = StructFormat(model=type_hint.base) + + if field_format is None: + raise ValueError('Unknown type annotation for c structs', type_hint.base) + else: + if outer_type_hint is not None and outer_type_hint.base is list: + max_length, var_len_name = outer_type_hint.get_len_metadata() + if max_length is None: + raise ValueError('missing list max_length:', attr_name) + if var_len_name: + field_format = VarArrayFormat(field_format, max_num=max_length) + else: + raise ValueError('fixed-len list not implemented:', attr_name) + + return field_format + + +def get_int_type(min_: int, max_: int) -> str | None: + if min_ < 0: + if min_ < -(2 ** 63) or max_ > 2 ** 63 - 1: + return None + elif min_ < -(2 ** 31) or max_ > 2 ** 31 - 1: + return "q" + elif min_ < -(2 ** 15) or max_ > 2 ** 15 - 1: + return "i" + elif min_ < -(2 ** 7) or max_ > 2 ** 7 - 1: + return "h" + else: + return "b" + + if max_ > 2 ** 64 - 1: + return None + elif max_ > 2 ** 32 - 1: + return "Q" + elif max_ > 2 ** 16 - 1: + return "I" + elif max_ > 2 ** 8 - 1: + return "H" + else: + return "B" + + +class SimpleFormat(CFormat): + def __init__(self, fmt): + self.fmt = fmt + self.size = struct.calcsize(fmt) + + self.c_type = self.c_types[self.fmt[-1]] + self.num = int(self.fmt[:-1]) if len(self.fmt) > 1 else 1 + + def encode(self, value): + if self.num == 1: + return struct.pack(self.fmt, value) + return struct.pack(self.fmt, *value) + + def decode(self, data: bytes) -> tuple[Any, bytes]: + value = struct.unpack(self.fmt, data[:self.size]) + if len(value) == 1: + value = value[0] + return value, data[self.size:] + + def get_min_size(self): + return self.size + + def get_max_size(self): + return self.size + + c_types = { + "B": "uint8_t", + "H": "uint16_t", + "I": "uint32_t", + "Q": "uint64_t", + "b": "int8_t", + "h": "int16_t", + "i": "int32_t", + "q": "int64_t", + "s": "char", + } + + def get_c_parts(self): + return self.c_type, ("" if self.num == 1 else ("[%d]" % self.num)) + + +class SimpleConstFormat(SimpleFormat): + def __init__(self, fmt, const_value: int): + super().__init__(fmt) + self.const_value = const_value + + def decode(self, data: bytes) -> tuple[Any, bytes]: + value, out_data = super().decode(data) + if value != self.const_value: + raise ValueError('const_value is wrong') + return value, out_data + + +class EnumFormat(SimpleFormat): + def __init__(self, enum_cls: Type[CEnum], fmt="B", *, as_hex=False, c_definition=True): + super().__init__(fmt) + self.enum_cls = enum_cls + self.enum_lookup = {v.c_value: v for v in enum_cls} + if len(self.enum_cls) != len(self.enum_lookup): + raise ValueError + self.as_hex = as_hex + self.c_definition = c_definition + + self.c_struct_name = normalize_name(enum_cls.__name__) + '_t' + + def decode(self, data: bytes) -> tuple[Any, bytes]: + value, out_data = super().decode(data) + return self.enum_lookup[value], out_data + + def get_typedef_name(self): + return '%s_t' % normalize_name(self.enum_cls.__name__) + + def get_c_parts(self): + if not self.c_definition: + return super().get_c_parts() + return self.c_struct_name, "" + + def get_c_definitions(self) -> dict[str, str]: + if not self.c_definition: + return {} + prefix = normalize_name(self.enum_cls.__name__).upper() + options = [] + last_value = None + for item in self.enum_cls: + if last_value is not None and item.c_value != last_value + 1: + options.append('') + last_value = item.c_value + options.append("%(prefix)s_%(name)s = %(value)s," % { + "prefix": prefix, + "name": normalize_name(item.name).upper(), + "value": ("0x%02x" if self.as_hex else "%d") % item.c_value + }) + + return { + self.c_struct_name: "enum {\n%(options)s\n};\ntypedef uint8_t %(name)s;" % { + "options": indent_c("\n".join(options)), + "name": self.c_struct_name, + } + } + + +class TwoNibblesEnumFormat(SimpleFormat): + def __init__(self, data_cls): + self.data_cls = data_cls + super().__init__('B') + + def decode(self, data: bytes) -> tuple[bool, bytes]: + fields = dataclass_fields(self.data_cls) + value, data = super().decode(data) + return self.data_cls(fields[0].type(value // 2 ** 4), fields[1].type(value // 2 ** 4)), data + + def encode(self, value): + fields = dataclass_fields(self.data_cls) + return super().encode( + getattr(value, fields[0].name).value * 2 ** 4 + + getattr(value, fields[1].name).value * 2 ** 4 + ) + + +class BoolFormat(SimpleFormat): + def __init__(self): + super().__init__('B') + + def encode(self, value): + return super().encode(int(value)) + + def decode(self, data: bytes) -> tuple[bool, bytes]: + value, data = super().decode(data) + if value > 1: + raise ValueError('Boolean value > 1') + return bool(value), data + + +class FixedStrFormat(SimpleFormat): + def __init__(self, num): + self.num = num + super().__init__('%ds' % self.num) + + def encode(self, value: str): + return value.encode()[:self.num].ljust(self.num, bytes((0,))), + + def decode(self, data: bytes) -> tuple[str, bytes]: + return data[:self.num].rstrip(bytes((0,))).decode(), data[self.num:] + + +class FixedBytesFormat(SimpleFormat): + def __init__(self, num): + self.num = num + super().__init__('%dB' % self.num) + + def encode(self, value: str): + return super().encode(tuple(value)) + + def decode(self, data: bytes) -> tuple[bytes, bytes]: + return data[:self.num], data[self.num:] + + +class FixedHexFormat(SimpleFormat): + def __init__(self, num, sep=''): + self.num = num + self.sep = sep + super().__init__('%dB' % self.num) + + def encode(self, value: str): + return super().encode(tuple(bytes.fromhex(value.replace(':', '')))) + + def decode(self, data: bytes) -> tuple[str, bytes]: + return self.sep.join(('%02x' % i) for i in data[:self.num]), data[self.num:] + + +class MacAddressFormat(FixedHexFormat): + def __init__(self): + super().__init__(num=6, sep=':') + + +class BaseVarFormat(CFormat, ABC): + def __init__(self, max_num): + self.num_fmt = 'H' + self.num_size = struct.calcsize(self.num_fmt) + self.max_num = max_num + + def get_min_size(self): + return self.num_size + + def get_max_size(self): + return self.num_size + self.max_num * self.get_var_num() + + def get_num_c_code(self): + return SimpleFormat(self.num_fmt).get_c_code("num") + + +class VarArrayFormat(BaseVarFormat): + def __init__(self, child_type, max_num): + super().__init__(max_num=max_num) + self.child_type = child_type + self.child_size = self.child_type.get_min_size() + + def get_var_num(self): + return self.child_size + pass + + def encode(self, values: Sequence) -> bytes: + num = len(values) + if num > self.max_num: + raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}') + data = struct.pack(self.num_fmt, num) + for value in values: + data += self.child_type.encode(value) + return data + + def decode(self, data: bytes) -> tuple[list[Any], bytes]: + num = struct.unpack(self.num_fmt, data[:self.num_size])[0] + if num > self.max_num: + raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}') + data = data[self.num_size:] + result = [] + for i in range(num): + item, data = self.child_type.decode(data) + result.append(item) + return result, data + + def get_c_parts(self): + pre, post = self.child_type.get_c_parts() + return super().get_num_c_code() + "\n" + pre, "[0]" + post + + +class VarStrFormat(BaseVarFormat): + def __init__(self, max_len): + super().__init__(max_num=max_len) + + def get_var_num(self): + return 1 + + def encode(self, value: str) -> bytes: + num = len(value) + if num > self.max_num: + raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}') + return struct.pack(self.num_fmt, num) + value.encode() + + def decode(self, data: bytes) -> tuple[str, bytes]: + num = struct.unpack(self.num_fmt, data[:self.num_size])[0] + if num > self.max_num: + raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}') + return data[self.num_size:self.num_size + num].rstrip(bytes((0,))).decode(), data[self.num_size + num:] + + def get_c_parts(self): + return super().get_num_c_code() + "\n" + "char", "[0]" + + +class VarBytesFormat(BaseVarFormat): + def __init__(self, max_size): + super().__init__(max_num=max_size) + + def get_var_num(self): + return 1 + + def encode(self, value: bytes) -> bytes: + num = len(value) + if num > self.max_num: + raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}') + return struct.pack(self.num_fmt, num) + value + + def decode(self, data: bytes) -> tuple[bytes, bytes]: + num = struct.unpack(self.num_fmt, data[:self.num_size])[0] + if num > self.max_num: + raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}') + return data[self.num_size:self.num_size + num].rstrip(bytes((0,))), data[self.num_size + num:] + + def get_c_parts(self): + return super().get_num_c_code() + "\n" + "uint8_t", "[0]" + + +T = TypeVar('T') + + +class CFormatDecodeError(Exception): + pass + + +class StructFormat(CFormat): + _format_cache: dict[Type, dict[str, CFormat]] = {} + + def __new__(cls, model: Type[T]): + result = cls._format_cache.get(model, None) + if not result: + result = super().__new__(cls) + cls._format_cache.get(model, result) + return result + + def __init__(self, model: Type[T]): + self.model = model + + self._field_formats = {} + self._as_definition = set() + self._c_embed = set() + self._c_names = {} + self._c_docs = {} + self._no_init_data = set() + for name, type_hint in typing.get_type_hints(self.model, include_extras=True).items(): + if type_hint is ClassVar: + continue + type_hint = SplitTypeHint.from_annotation(type_hint) + + if any(getattr(m, "as_definition", False) for m in type_hint.metadata): + self._as_definition.add(name) + if any(getattr(m, "c_embed", False) for m in type_hint.metadata): + self._c_embed.add(name) + if not all(getattr(m, "init", True) for m in type_hint.metadata): + self._no_init_data.add(name) + for m in type_hint.metadata: + with suppress(AttributeError): + self._c_names[name] = m.c_name + with suppress(AttributeError): + self._c_docs[name] = m.c_doc + + self._field_formats[name] = CFormat.from_split_type_hint(type_hint, attr_name=name) + + def get_var_num(self): + return sum([field_format.get_var_num() for name, field_format in self._field_formats.items()], start=0) + + def encode(self, instance: T, ignore_fields=()) -> bytes: + data = bytes() + for name, field_format in self._field_formats.items(): + if name in ignore_fields: + continue + data += field_format.encode(getattr(instance, name)) + return data + + def decode(self, data: bytes) -> tuple[T, bytes]: + decoded = {} + for name, field_format in self._field_formats.items(): + try: + value, data = field_format.decode(data) + except (struct.error, UnicodeDecodeError, ValueError) as e: + raise CFormatDecodeError(f"failed to decode model={self.model}, field={name}, data={data}, e={e}") + if isinstance(value, CEnum): + value = value.value + if name not in self._no_init_data: + decoded[name] = value + return self.model.model_validate(decoded), data + + def get_min_size(self) -> int: + return sum(( + field_format.get_min_size() for field_format in self._field_formats.values() + ), start=0) + + def get_max_size(self) -> int: + raise ValueError + + def get_size(self, calculate_max=False): + return sum(( + field_format.get_size(calculate_max=calculate_max) for field_format in self._field_formats.values() + ), start=0) + + def get_c_struct_items(self, ignore_fields=None, no_empty=False, top_level=False): + ignore_fields = set() if not ignore_fields else set(ignore_fields) + + items = [] + + for name, field_format in self._field_formats.items(): + if name in ignore_fields: + continue + + c_name = self._c_names.get(name, name) + if not isinstance(field_format, (StructFormat, UnionFormat)): + items.append(( + ( + ("%(typedef_name)s %(name)s;" % { + "typedef_name": field_format.get_typedef_name(), + "name": c_name, + }) + if name in self._as_definition + else field_format.get_c_code(c_name) + ), + self._c_docs.get(name, None), + )), + else: + if name in self._c_embed: + embedded_items = field_format.get_c_struct_items(ignore_fields, no_empty, top_level) + items.extend(embedded_items) + else: + items.append(( + ( + ("%(typedef_name)s %(name)s;" % { + "typedef_name": field_format.get_typedef_name(), + "name": c_name, + }) + if name in self._as_definition + else field_format.get_c_code(c_name, typedef=False) + ), + self._c_docs.get(name, None), + )) + + return items + + def get_c_parts(self, ignore_fields=None, no_empty=False, top_level=False) -> tuple[str, str]: + with suppress(AttributeError): + return (self.model.existing_c_struct.name, "") + + ignore_fields = set() if not ignore_fields else set(ignore_fields) + + pre = "" + + items = self.get_c_struct_items(ignore_fields=ignore_fields, + no_empty=no_empty, + top_level=top_level) + + if no_empty and not items: + return "", "" + + if top_level: + comment = self.model.__doc__.strip() + if comment: + pre += "/** %s */\n" % comment + pre += "typedef struct __packed " + else: + pre += "struct __packed " + + pre += "{\n%(elements)s\n}" % { + "elements": indent_c( + "\n".join( + code + ("" if not comment else (" /** %s */" % comment)) + for code, comment in items + ) + ), + } + return pre, "" + + def get_c_code(self, name=None, ignore_fields=None, no_empty=False, typedef=True) -> str: + pre, post = self.get_c_parts(ignore_fields=ignore_fields, + no_empty=no_empty, + top_level=typedef) + if no_empty and not pre and not post: + return "" + return "%s %s%s;" % (pre, name, post) + + def get_c_definitions(self) -> dict[str, str]: + definitions = {} + for name, field_format in self._field_formats.items(): + definitions.update(field_format.get_c_definitions()) + if name in self._as_definition: + typedef_name = field_format.get_typedef_name() + if not isinstance(field_format, StructFormat): + definitions[typedef_name] = 'typedef %(code)s %(name)s;' % { + "code": ''.join(field_format.get_c_parts()), + "name": typedef_name, + } + else: + definitions[typedef_name] = field_format.get_c_code(name=typedef_name, typedef=True) + return definitions + + def get_typedef_name(self): + return "%s_t" % normalize_name(self.model.__name__) + + def get_c_includes(self) -> set: + result = set() + with suppress(AttributeError): + result.update(self.model.existing_c_struct.includes) + for field_format in self._field_formats.values(): + result.update(field_format.get_c_includes()) + return result + + +class UnionFormat(CFormat): + def __init__(self, model_formats: Sequence[StructFormat], discriminator: str, discriminator_as_hex: bool = False): + self.discriminator = discriminator + models = { + getattr(model_format.model, discriminator): model_format for model_format in model_formats + } + if len(models) != len(model_formats): + raise ValueError + types = set(type(value) for value in models.keys()) + if len(types) != 1: + raise ValueError + discriminator_annotation = tuple(types)[0] + if discriminator_as_hex: + discriminator_annotation = Annotated[discriminator_annotation, AsHex()] + self.discriminator_format = CFormat.from_annotation(discriminator_annotation) + self.key_to_name = {value.c_value: value.name for value in models.keys()} + self.models = {value.c_value: model_format for value, model_format in models.items()} + + def get_var_num(self): + return 0 # todo: is this always correct? + + def encode(self, instance) -> bytes: + discriminator_value = getattr(instance, self.discriminator) + try: + model_format = self.models[discriminator_value.c_value] + except KeyError: + raise ValueError('Unknown discriminator value for Union: %r' % discriminator_value) + if not isinstance(instance, model_format.model): + raise ValueError('Unknown value for Union discriminator %r: %r' % (discriminator_value, instance)) + return ( + self.discriminator_format.encode(discriminator_value.c_value) + + model_format.encode(instance, ignore_fields=(self.discriminator, )) + ) + + def decode(self, data: bytes) -> tuple[T, bytes]: + discriminator_value, remaining_data = self.discriminator_format.decode(data) + return self.models[discriminator_value.c_value].decode(data) + + def get_min_size(self) -> int: + return max([0] + [ + model_format.get_min_size() + for model_format in self.models.values() + ]) + + def get_max_size(self) -> int: + raise ValueError + + def get_size(self=False, calculate_max=False): + return max([0] + [ + field_format.get_size(calculate_max=calculate_max) + for field_format in self.models.values() + ]) + + def get_c_struct_items(self, ignore_fields=None, no_empty=False, top_level=False): + return [ + (self.discriminator_format.get_c_code(self.discriminator), None), + ("union __packed %s;" % self.get_c_union_code(), None), + ] + + def get_c_union_size(self): + return max( + (model_format.get_min_size() for model_format in self.models.values()), + default=0, + ) - self.discriminator_format.get_min_size() + + def get_c_union_code(self): + union_items = [] + for key, model_format in self.models.items(): + base_name = normalize_name(self.key_to_name[key]) + item_c_code = model_format.get_c_code( + base_name, ignore_fields=(self.discriminator, ), typedef=False, no_empty=True + ) + if item_c_code: + union_items.append(item_c_code) + size = self.get_c_union_size() + union_items.append( + "uint8_t bytes[%0d];" % size + ) + return "{\n" + indent_c("\n".join(union_items)) + "\n}" + + def get_c_parts(self, ignore_fields=None, no_empty=False, top_level=False) -> tuple[str, str]: + items = self.get_c_struct_items(no_empty=no_empty, + top_level=top_level) + + if no_empty and not items: + return "", "" + + if top_level: + pre = "typedef struct __packed " + else: + pre = "struct __packed " + + pre += "{\n%(elements)s\n}" % { + "elements": indent_c( + "\n".join( + code + ("" if not comment else (" /** %s */" % comment)) + for code, comment in items + ) + ), + } + return pre, "" + + def get_c_code(self, name=None, ignore_fields=None, no_empty=False, typedef=True,) -> str: + pre, post = self.get_c_parts(ignore_fields=ignore_fields, + no_empty=no_empty, + top_level=typedef) + if no_empty and not pre and not post: + return "" + return "%s %s%s;" % (pre, name, post) + + def get_c_definitions(self) -> dict[str, str]: + definitions = {} + definitions.update(self.discriminator_format.get_c_definitions()) + for model_format in self.models.values(): + definitions.update(model_format.get_c_definitions()) + return definitions + + def get_typedef_name(self): + names = [model_format.model.__name__ for model_format in self.models.values()] + min_len = min(len(name) for name in names) + longest_prefix = '' + longest_suffix = '' + for i in reversed(range(min_len)): + a = set(name[:i] for name in names) + if len(a) == 1: + longest_prefix = tuple(a)[0] + break + for i in reversed(range(min_len)): + a = set(name[-i:] for name in names) + if len(a) == 1: + longest_suffix = tuple(a)[0] + break + return "%s_t" % normalize_name(longest_prefix if len(longest_prefix) > len(longest_suffix) else longest_suffix) + + def get_c_includes(self) -> set: + result = set() + result.update(self.discriminator_format.get_c_includes()) + for model_format in self.models.values(): + result.update(model_format.get_c_includes()) + return result diff --git a/src/c3nav/mesh/consumers.py b/src/c3nav/mesh/consumers.py index 3b56a591..08d4e704 100644 --- a/src/c3nav/mesh/consumers.py +++ b/src/c3nav/mesh/consumers.py @@ -17,6 +17,7 @@ from django.utils import timezone from django.utils.crypto import constant_time_compare from c3nav.mesh import messages +from c3nav.mesh.cformats import CFormat from c3nav.mesh.messages import (MESH_BROADCAST_ADDRESS, MESH_NONE_ADDRESS, MESH_ROOT_ADDRESS, OTA_CHUNK_SIZE, MeshMessage, MeshMessageType, OTAApplyMessage, OTASettingMessage) from c3nav.mesh.models import MeshNode, MeshUplink, NodeMessage, OTARecipientStatus, OTAUpdate, OTAUpdateRecipient @@ -46,6 +47,7 @@ class NodeState: class MeshConsumer(AsyncWebsocketConsumer): + mesh_msg_format = CFormat.from_annotation(MeshMessage) def __init__(self): super().__init__() self.uplink = None @@ -56,6 +58,7 @@ class MeshConsumer(AsyncWebsocketConsumer): self.ota_send_task = None self.ota_chunks: dict[int, set[int]] = {} # keys are update IDs, values are a list of chunk IDs self.ota_chunks_available_condition = asyncio.Condition() + self.accepted = None async def connect(self): self.headers = dict(self.scope["headers"]) @@ -68,13 +71,16 @@ class MeshConsumer(AsyncWebsocketConsumer): self.ping_task = get_event_loop().create_task(self.ping_regularly()) self.check_node_state_task = get_event_loop().create_task(self.check_node_states()) self.ota_send_task = get_event_loop().create_task(self.ota_send()) + self.accepted = True async def disconnect(self, close_code): + if not self.accepted: + return self.ping_task.cancel() self.check_node_state_task.cancel() self.ota_send_task.cancel() - await self.log_text(self.uplink.node, "mesh websocket disconnected") if self.uplink is not None: + await self.log_text(self.uplink.node, "mesh websocket disconnected") # leave broadcast group await self.channel_layer.group_discard("mesh_comm_broadcast", self.channel_name) @@ -91,10 +97,10 @@ class MeshConsumer(AsyncWebsocketConsumer): end_reason=MeshUplink.EndReason.CLOSED ) - async def send_msg(self, msg, sender=None, exclude_uplink_address=None): + async def send_msg(self, msg: MeshMessage, sender=None, exclude_uplink_address=None): # print("sending", msg, MeshMessage.encode(msg).hex(' ', 1)) - # self.log_text(msg.dst, "sending %s" % msg) - await self.send(bytes_data=MeshMessage.encode(msg)) + # self.log_text(msg_envelope.dst, "sending %s" % msg) + await self.send(bytes_data=self.mesh_msg_format.encode(msg)) await self.channel_layer.group_send("mesh_msg_sent", { "type": "mesh.msg_sent", "timestamp": timezone.now().strftime("%d.%m.%y %H:%M:%S.%f"), @@ -113,18 +119,20 @@ class MeshConsumer(AsyncWebsocketConsumer): if bytes_data is None: return try: - msg, data = messages.MeshMessage.decode(bytes_data) + msg, data = self.mesh_msg_format.decode(bytes_data) + msg: MeshMessage except Exception: - print("Unable to decode: ") + print("Unable to decode: msg_type=", hex(bytes_data[12])) print(bytes_data) traceback.print_exc() return + msg.content = msg.content # print(msg) if msg.dst != messages.MESH_ROOT_ADDRESS and msg.dst != messages.MESH_PARENT_ADDRESS: # message not adressed to us, forward it - print('Received message for forwarding:', msg) + print('Received message for forwarding:', msg.content) if not self.uplink: await self.log_text(None, "received message not for us before sign in message, ignoring...") @@ -132,12 +140,12 @@ class MeshConsumer(AsyncWebsocketConsumer): return # trace messages collect node adresses before forwarding - if isinstance(msg, messages.MeshRouteTraceMessage): + if isinstance(msg.content, messages.MeshRouteTraceMessage): print('adding ourselves to trace message before forwarding') await self.log_text(MESH_ROOT_ADDRESS, "adding ourselves to trace message before forwarding") - msg.trace.append(MESH_ROOT_ADDRESS) + msg.content.trace.append(MESH_ROOT_ADDRESS) - result = await msg.send(exclude_uplink_address=self.uplink.node.address) + result = await msg.content.send(exclude_uplink_address=self.uplink.node.address) if not result: print('message had no route') @@ -154,7 +162,7 @@ class MeshConsumer(AsyncWebsocketConsumer): src_node, created = await MeshNode.objects.aget_or_create(address=msg.src) - if isinstance(msg, messages.MeshSigninMessage): + if isinstance(msg.content, messages.MeshSigninMessage): if not self.check_valid_address(msg.src): print('reject node with invalid address address') await self.close() @@ -172,10 +180,12 @@ class MeshConsumer(AsyncWebsocketConsumer): await self.log_received_message(src_node, msg) # inform signed in uplink node about its layer - await self.send_msg(messages.MeshLayerAnnounceMessage( + await self.send_msg(messages.MeshMessage( src=messages.MESH_ROOT_ADDRESS, dst=msg.src, - layer=messages.NO_LAYER + content=messages.MeshLayerAnnounceMessage( + layer=messages.NO_LAYER + ) )) # add signed in uplink node to broadcast group @@ -186,7 +196,7 @@ class MeshConsumer(AsyncWebsocketConsumer): # add this node as a destination that this uplink handles (duh) await self.add_dst_nodes(nodes=(src_node, )) - self.dst_nodes[msg.src].last_msg[MeshMessageType.MESH_SIGNIN] = msg + self.dst_nodes[msg.src].last_msg[MeshMessageType.MESH_SIGNIN] = msg.content return @@ -202,38 +212,42 @@ class MeshConsumer(AsyncWebsocketConsumer): except KeyError: print('unexpected message from', msg.src) return - node_status.last_msg[msg.msg_type] = msg + node_status.last_msg[msg.content.msg_type] = msg.content - if isinstance(msg, messages.MeshAddDestinationsMessage): - result = await self.add_dst_nodes(addresses=msg.addresses) + if isinstance(msg.content, messages.MeshAddDestinationsMessage): + result = await self.add_dst_nodes(addresses=msg.content.addresses) if not result: - print('disconnecting node that send invalid destinations', msg) + print('disconnecting node that send invalid destinations', msg.content) await self.close() - if isinstance(msg, messages.MeshRemoveDestinationsMessage): - await self.remove_dst_nodes(addresses=msg.addresses) + if isinstance(msg.content, messages.MeshRemoveDestinationsMessage): + await self.remove_dst_nodes(addresses=msg.content.addresses) - if isinstance(msg, messages.MeshRouteRequestMessage): - if msg.address == MESH_ROOT_ADDRESS: + if isinstance(msg.content, messages.MeshRouteRequestMessage): + if msg.content.address == MESH_ROOT_ADDRESS: await self.log_text(MESH_ROOT_ADDRESS, "route request about us, start a trace") - await self.send_msg(messages.MeshRouteTraceMessage( + await self.send_msg(messages.MeshMessage( src=MESH_ROOT_ADDRESS, dst=msg.src, - request_id=msg.request_id, - trace=[MESH_ROOT_ADDRESS], + content=messages.MeshRouteTraceMessage( + request_id=msg.content.request_id, + trace=[MESH_ROOT_ADDRESS], + ) )) else: await self.log_text(MESH_ROOT_ADDRESS, "route request about someone else, sending response") - self.open_requests.add(msg.request_id) - uplink = database_sync_to_async(MeshNode.get_node_and_uplink)(msg.address) - await self.send_msg(messages.MeshRouteResponseMessage( + self.open_requests.add(msg.content.request_id) + uplink = database_sync_to_async(MeshNode.get_node_and_uplink)(msg.content.address) + await self.send_msg(messages.MeshMessage( src=MESH_ROOT_ADDRESS, dst=msg.src, - request_id=msg.request_id, - route=uplink.node_id if uplink else MESH_NONE_ADDRESS, + content=messages.MeshRouteResponseMessage( + request_id=msg.content.request_id, + route=uplink.node_id if uplink else MESH_NONE_ADDRESS, + ) )) - if isinstance(msg, (messages.ConfigHardwareMessage, + if isinstance(msg.content, (messages.ConfigHardwareMessage, messages.ConfigFirmwareMessage, messages.ConfigBoardMessage)): if (node_status.waiting_for == NodeWaitingFor.CONFIG and @@ -243,16 +257,16 @@ class MeshConsumer(AsyncWebsocketConsumer): print('got all config, checking ota') await self.check_ota([msg.src], first_time=True) - if isinstance(msg, messages.OTAStatusMessage): - print('got OTA status', msg) - node_status.reported_ota_update = msg.update_id + if isinstance(msg.content, messages.OTAStatusMessage): + print('got OTA status', msg.content) + node_status.reported_ota_update = msg.content.update_id if node_status.waiting_for == NodeWaitingFor.OTA_START_STOP: update_id = node_status.ota_recipient.update_id if node_status.ota_recipient else 0 - if update_id == msg.update_id: + if update_id == msg.content.update_id: print('start/cancel confirmed!') node_status.waiting_for = NodeWaitingFor.NOTHING if update_id: - if msg.status.is_failed: + if msg.content.status.is_failed: print('ota failed') node_status.ota_recipient.status = OTARecipientStatus.FAILED await node_status.ota_recipient.send_status() @@ -263,14 +277,14 @@ class MeshConsumer(AsyncWebsocketConsumer): else: print('queue chunk sending') await self.ota_set_chunks(node_status.ota_recipient.update, - min_chunk=msg.next_expected_chunk) + min_chunk=msg.content.next_expected_chunk) - if isinstance(msg, messages.OTARequestFragmentsMessage): - print('got OTA fragment request', msg) + if isinstance(msg.content, messages.OTARequestFragmentsMessage): + print('got OTA fragment request', msg.content) desired_update_id = node_status.ota_recipient.update_id if node_status.ota_recipient else 0 - if desired_update_id and msg.update_id == desired_update_id: + if desired_update_id and msg.content.update_id == desired_update_id: print('queue requested chunk sending') - await self.ota_set_chunks(node_status.ota_recipient.update, chunks=set(msg.chunks)) + await self.ota_set_chunks(node_status.ota_recipient.update, chunks=set(msg.content.chunks)) @database_sync_to_async def create_uplink_in_database(self, address): @@ -334,7 +348,7 @@ class MeshConsumer(AsyncWebsocketConsumer): self.uplink.node.address, "we're the route for this message but it came from here so... no" ) return - await self.send_msg(MeshMessage.fromjson(data["msg"]), data["sender"]) + await self.send_msg(MeshMessage.model_validate(data["msg"]), data["sender"]) async def mesh_ota_recipients_changed(self, data): addresses = set(data["addresses"]) & set(self.dst_nodes.keys()) @@ -349,7 +363,7 @@ class MeshConsumer(AsyncWebsocketConsumer): """ async def log_received_message(self, src_node: MeshNode, msg: messages.MeshMessage): - as_json = MeshMessage.tojson(msg) + as_json = msg.model_dump() await self.channel_layer.group_send("mesh_msg_received", { "type": "mesh.msg_received", "timestamp": timezone.now().strftime("%d.%m.%y %H:%M:%S.%f"), @@ -360,7 +374,7 @@ class MeshConsumer(AsyncWebsocketConsumer): await NodeMessage.objects.acreate( uplink=self.uplink, src_node=src_node, - message_type=msg.msg_type.name, + message_type=msg.content.msg_type.name, data=as_json, ) @@ -435,29 +449,34 @@ class MeshConsumer(AsyncWebsocketConsumer): node_state.last_sent = timezone.now() print('request config dump, attempt #%d' % node_state.attempt) node_state.attempt += 1 - await self.send_msg(messages.ConfigDumpMessage( + await self.send_msg(messages.MeshMessage( src=MESH_ROOT_ADDRESS, dst=address, + content=messages.ConfigDumpMessage() )) case NodeWaitingFor.OTA_START_STOP: node_state.last_sent = timezone.now() if node_state.ota_recipient: print('starting ota, attempt #%d' % node_state.attempt) - await self.send_msg(messages.OTAStartMessage( + await self.send_msg(messages.MeshMessage( src=MESH_ROOT_ADDRESS, dst=address, - update_id=node_state.ota_recipient.update_id, # noqa - total_bytes=node_state.ota_recipient.update.build.binary.size, - auto_apply=False, - auto_reboot=False, + content=messages.OTAStartMessage( + update_id=node_state.ota_recipient.update_id, # noqa + total_bytes=node_state.ota_recipient.update.build.binary.size, + auto_apply=False, + auto_reboot=False, + ) )) else: print('canceling ota, attempt #%d' % node_state.attempt) - await self.send_msg(messages.OTAAbortMessage( + await self.send_msg(messages.MeshMessage( src=MESH_ROOT_ADDRESS, dst=address, - update_id=0, + content=messages.OTAAbortMessage( + update_id=0, + ) )) async def check_node_states(self): @@ -511,12 +530,14 @@ class MeshConsumer(AsyncWebsocketConsumer): with self.dst_nodes[recipients[0]].ota_recipient.update.build.binary.open('rb') as f: f.seek(chunk * OTA_CHUNK_SIZE) data = f.read(OTA_CHUNK_SIZE) - await self.send_msg(messages.OTAFragmentMessage( + await self.send_msg(messages.MeshMessage( src=MESH_ROOT_ADDRESS, dst=recipients[0] if len(recipients) == 1 else MESH_BROADCAST_ADDRESS, - update_id=update_id, - chunk=chunk, - data=data, + content=messages.OTAFragmentMessage( + update_id=update_id, + chunk=chunk, + data=data, + ) )) # wait a bit until we send more @@ -681,7 +702,7 @@ class MeshUIConsumer(AsyncJsonWebsocketConsumer): self.msg_received_filter = {"request_id": msg_to_send["msg_data"]["request_id"]} for recipient in msg_to_send["recipients"]: - await MeshMessage.fromjson({ + await MeshMessage.model_validate({ 'dst': recipient, **msg_to_send["msg_data"], }).send(sender=self.channel_name) diff --git a/src/c3nav/mesh/dataformats.py b/src/c3nav/mesh/dataformats.py deleted file mode 100644 index a2611254..00000000 --- a/src/c3nav/mesh/dataformats.py +++ /dev/null @@ -1,285 +0,0 @@ -import re -from dataclasses import dataclass, field -from enum import IntEnum, unique -from typing import BinaryIO, Self - -from c3nav.api.utils import EnumSchemaByNameMixin -from c3nav.mesh.baseformats import (BoolFormat, ChipRevFormat, EnumFormat, FixedHexFormat, FixedStrFormat, - SimpleConstFormat, SimpleFormat, StructType, TwoNibblesEnumFormat, VarArrayFormat) - - -class MacAddressFormat(FixedHexFormat): - def __init__(self): - super().__init__(num=6, sep=':') - - -class MacAddressesListFormat(VarArrayFormat): - def __init__(self, max_num): - super().__init__(child_type=MacAddressFormat(), max_num=max_num) - - -@unique -class LedType(IntEnum): - NONE = 0 - SERIAL = 1 - MULTIPIN = 2 - - @property - def pretty_name(self): - return self.name.lower() - - -@unique -class SerialLedType(IntEnum): - WS2812 = 1 - SK6812 = 2 - - -@dataclass -class LedConfig(StructType, union_type_field="led_type"): - """ - configuration for an optional connected status LED - """ - led_type: LedType = field(metadata={"format": EnumFormat(), "c_name": "type"}) - - -@dataclass -class NoLedConfig(LedConfig, led_type=LedType.NONE): - pass - - -@dataclass -class SerialLedConfig(LedConfig, led_type=LedType.SERIAL): - serial_led_type: SerialLedType = field(metadata={"format": EnumFormat(), "c_name": "type"}) - gpio: int = field(metadata={"format": SimpleFormat('B')}) - - -@dataclass -class MultipinLedConfig(LedConfig, led_type=LedType.MULTIPIN): - gpio_red: int = field(metadata={"format": SimpleFormat('B')}) - gpio_green: int = field(metadata={"format": SimpleFormat('B')}) - gpio_blue: int = field(metadata={"format": SimpleFormat('B')}) - - -@dataclass -class BoardSPIConfig(StructType): - """ - configuration for spi bus used for ETH or UWB - """ - gpio_miso: int = field(metadata={"format": SimpleFormat('B')}) - gpio_mosi: int = field(metadata={"format": SimpleFormat('B')}) - gpio_clk: int = field(metadata={"format": SimpleFormat('B')}) - - -@dataclass -class UWBConfig(StructType): - """ - configuration for the connection to the UWB module - """ - enable: bool = field(metadata={"format": BoolFormat()}) - gpio_cs: int = field(metadata={"format": SimpleFormat('B')}) - gpio_irq: int = field(metadata={"format": SimpleFormat('B')}) - gpio_rst: int = field(metadata={"format": SimpleFormat('B')}) - gpio_wakeup: int = field(metadata={"format": SimpleFormat('B')}) - gpio_exton: int = field(metadata={"format": SimpleFormat('B')}) - - -@dataclass -class UplinkEthConfig(StructType): - """ - configuration for the connection to the ETH module - """ - enable: bool = field(metadata={"format": BoolFormat()}) - gpio_cs: int = field(metadata={"format": SimpleFormat('B')}) - gpio_int: int = field(metadata={"format": SimpleFormat('B')}) - gpio_rst: int = field(metadata={"format": SimpleFormat('b')}) - - -@unique -class BoardType(EnumSchemaByNameMixin, IntEnum): - CUSTOM = 0x00 - - # devboards - ESP32_C3_DEVKIT_M_1 = 0x01 - ESP32_C3_32S = 2 - - # custom boards - C3NAV_UWB_BOARD = 0x10 - C3NAV_LOCATION_PCB_REV_0_1 = 0x11 - C3NAV_LOCATION_PCB_REV_0_2 = 0x12 - - @property - def pretty_name(self): - if self.name.startswith('ESP32'): - return self.name.replace('_', '-').replace('DEVKIT-', 'DevKit') - if self.name.startswith('C3NAV'): - name = self.name.replace('_', ' ').lower() - name = name.replace('uwb', 'UWB').replace('pcb', 'PCB') - name = re.sub(r'[0-9]+( [0-9+])+', lambda s: s[0].replace(' ', '.'), name) - name = re.sub(r'rev.*', lambda s: s[0].replace(' ', ''), name) - return name - return self.name - - -@dataclass -class BoardConfig(StructType, union_type_field="board"): - board: BoardType = field(metadata={"format": EnumFormat(as_hex=True)}) - - -@dataclass -class CustomBoardConfig(BoardConfig, board=BoardType.CUSTOM): - spi: BoardSPIConfig = field(metadata={"as_definition": True}) - uwb: UWBConfig = field(metadata={"as_definition": True}) - eth: UplinkEthConfig = field(metadata={"as_definition": True}) - led: LedConfig = field(metadata={"as_definition": True}) - - -@dataclass -class DevkitMBoardConfig(BoardConfig, board=BoardType.ESP32_C3_DEVKIT_M_1): - spi: BoardSPIConfig = field(metadata={"as_definition": True}) - uwb: UWBConfig = field(metadata={"as_definition": True}) - eth: UplinkEthConfig = field(metadata={"as_definition": True}) - - -@dataclass -class Esp32SBoardConfig(BoardConfig, board=BoardType.ESP32_C3_32S): - spi: BoardSPIConfig = field(metadata={"as_definition": True}) - uwb: UWBConfig = field(metadata={"as_definition": True}) - eth: UplinkEthConfig = field(metadata={"as_definition": True}) - - -@dataclass -class UwbBoardConfig(BoardConfig, board=BoardType.C3NAV_UWB_BOARD): - eth: UplinkEthConfig = field(metadata={"as_definition": True}) - - -@dataclass -class LocationPCBRev0Dot1BoardConfig(BoardConfig, board=BoardType.C3NAV_LOCATION_PCB_REV_0_1): - eth: UplinkEthConfig = field(metadata={"as_definition": True}) - - -@dataclass -class LocationPCBRev0Dot2BoardConfig(BoardConfig, board=BoardType.C3NAV_LOCATION_PCB_REV_0_2): - eth: UplinkEthConfig = field(metadata={"as_definition": True}) - - -@dataclass -class RangeResultItem(StructType): - peer: str = field(metadata={"format": MacAddressFormat()}) - rssi: int = field(metadata={"format": SimpleFormat('b')}) - distance: int = field(metadata={"format": SimpleFormat('h')}) - - -@dataclass -class RawFTMEntry(StructType): - dlog_token: int = field(metadata={"format": SimpleFormat('B')}) - rssi: int = field(metadata={"format": SimpleFormat('b')}) - rtt: int = field(metadata={"format": SimpleFormat('I')}) - t1: int = field(metadata={"format": SimpleFormat('Q')}) - t2: int = field(metadata={"format": SimpleFormat('Q')}) - t3: int = field(metadata={"format": SimpleFormat('Q')}) - t4: int = field(metadata={"format": SimpleFormat('Q')}) - - -@dataclass -class FirmwareAppDescription(StructType, existing_c_struct="esp_app_desc_t", c_includes=['']): - magic_word: int = field(metadata={"format": SimpleConstFormat('I', 0xAB_CD_54_32)}, repr=False) - secure_version: int = field(metadata={"format": SimpleFormat('I')}) - reserv1: list[int] = field(metadata={"format": SimpleFormat('2I')}, repr=False) - version: str = field(metadata={"format": FixedStrFormat(32)}) - project_name: str = field(metadata={"format": FixedStrFormat(32)}) - compile_time: str = field(metadata={"format": FixedStrFormat(16)}) - compile_date: str = field(metadata={"format": FixedStrFormat(16)}) - idf_version: str = field(metadata={"format": FixedStrFormat(32)}) - app_elf_sha256: str = field(metadata={"format": FixedHexFormat(32)}) - reserv2: list[int] = field(metadata={"format": SimpleFormat('20I')}, repr=False) - - -@unique -class SPIFlashMode(EnumSchemaByNameMixin, IntEnum): - QIO = 0 - QOUT = 1 - DIO = 2 - DOUT = 3 - - -@unique -class FlashSize(EnumSchemaByNameMixin, IntEnum): - SIZE_1MB = 0 - SIZE_2MB = 1 - SIZE_4MB = 2 - SIZE_8MB = 3 - SIZE_16MB = 4 - SIZE_32MB = 5 - SIZE_64MB = 6 - SIZE_128MB = 7 - - @property - def pretty_name(self): - return self.name.removeprefix('SIZE_') - - -@unique -class FlashFrequency(EnumSchemaByNameMixin, IntEnum): - FREQ_40MHZ = 0 - FREQ_26MHZ = 1 - FREQ_20MHZ = 2 - FREQ_80MHZ = 0xf - - @property - def pretty_name(self): - return self.name.removeprefix('FREQ_').replace('MHZ', 'Mhz') - - -@dataclass -class FlashSettings: - size: FlashSize - frequency: FlashFrequency - - @property - def display(self): - return f"{self.size.pretty_name} ({self.frequency.pretty_name})" - - -@unique -class ChipType(EnumSchemaByNameMixin, IntEnum): - ESP32_S2 = 2 - ESP32_C3 = 5 - - @property - def pretty_name(self): - return self.name.replace('_', '-') - - -@dataclass -class FirmwareImageFileHeader(StructType): - magic_word: int = field(metadata={"format": SimpleConstFormat('B', 0xE9)}, repr=False) - num_segments: int = field(metadata={"format": SimpleFormat('B')}) - spi_flash_mode: SPIFlashMode = field(metadata={"format": EnumFormat()}) - flash_stuff: FlashSettings = field(metadata={"format": TwoNibblesEnumFormat()}) - entry_point: int = field(metadata={"format": SimpleFormat('I')}) - - -@dataclass -class FirmwareImageExtendedFileHeader(StructType): - wp_pin: int = field(metadata={"format": SimpleFormat('B')}) - drive_settings: int = field(metadata={"format": SimpleFormat('3B')}) - chip: ChipType = field(metadata={"format": EnumFormat('H')}) - min_chip_rev_old: int = field(metadata={"format": SimpleFormat('B')}) - min_chip_rev: tuple[int, int] = field(metadata={"format": ChipRevFormat()}) - max_chip_rev: tuple[int, int] = field(metadata={"format": ChipRevFormat()}) - reserv: int = field(metadata={"format": SimpleFormat('I')}, repr=False) - hash_appended: bool = field(metadata={"format": BoolFormat()}) - - -@dataclass -class FirmwareImage(StructType): - header: FirmwareImageFileHeader - ext_header: FirmwareImageExtendedFileHeader - first_segment_headers: tuple[int, int] = field(metadata={"format": SimpleFormat('2I')}, repr=False) - app_desc: FirmwareAppDescription - - @classmethod - def from_file(cls, file: BinaryIO) -> Self: - result, data = cls.decode(file.read(FirmwareImage.get_min_size())) - return result diff --git a/src/c3nav/mesh/forms.py b/src/c3nav/mesh/forms.py index f72fe744..86c0793b 100644 --- a/src/c3nav/mesh/forms.py +++ b/src/c3nav/mesh/forms.py @@ -13,11 +13,15 @@ from django.db import transaction from django.forms import BooleanField, ChoiceField, Form, ModelMultipleChoiceField, MultipleChoiceField from django.http import Http404 from django.utils.translation import gettext_lazy as _ +from pydantic import ValidationError as PydanticValidationError +from pydantic.type_adapter import TypeAdapter -from c3nav.mesh.dataformats import BoardConfig, BoardType, LedType, SerialLedType -from c3nav.mesh.messages import MESH_BROADCAST_ADDRESS, MESH_ROOT_ADDRESS, MeshMessage, MeshMessageType +from c3nav.mesh.cformats import CFormat +from c3nav.mesh.messages import (MESH_BROADCAST_ADDRESS, MESH_ROOT_ADDRESS, MeshMessage, MeshMessageContent, + MeshMessageType) from c3nav.mesh.models import (FirmwareBuild, HardwareDescription, MeshNode, OTARecipientStatus, OTAUpdate, OTAUpdateRecipient) +from c3nav.mesh.schemas import BoardConfig, BoardType, LedType, SerialLedType from c3nav.mesh.utils import MESH_ALL_OTA_GROUP, group_msg_type_choices @@ -64,7 +68,7 @@ class MeshMessageForm(forms.Form): if cls.msg_type in MeshMessageForm.msg_types: raise TypeError('duplicate use of msg %s' % cls.msg_type) MeshMessageForm.msg_types[cls.msg_type] = cls - cls.msg_type_class = MeshMessage.get_type(cls.msg_type) + cls.msg_type_class = CFormat.from_annotation(MeshMessageContent).models.get(cls.msg_type.c_value).model @classmethod def get_form_for_type(cls, msg_type): @@ -83,9 +87,11 @@ class MeshMessageForm(forms.Form): raise Exception('nope') return { - 'msg_type': self.msg_type.name, - 'src': MESH_ROOT_ADDRESS, - **self.get_cleaned_msg_data(), + "src": MESH_ROOT_ADDRESS, + "content": { + "msg_type": self.msg_type.name, + **self.get_cleaned_msg_data(), + } } def get_recipients(self): @@ -96,7 +102,7 @@ class MeshMessageForm(forms.Form): recipients = self.get_recipients() for recipient in recipients: print('sending to ', recipient) - async_to_sync(MeshMessage.fromjson({ + async_to_sync(MeshMessage.model_validate({ 'dst': recipient, **msg_data, }).send)() @@ -173,8 +179,8 @@ class ConfigBoardMessageForm(MeshMessageForm): "prefix": "led_", "field": "board", "values": tuple( - cfg.board.name for cfg in BoardConfig._union_options["board"].values() - if "led" in cfg.__dataclass_fields__ + board_type.name for board_type in BoardType + if "led" in CFormat.from_annotation(BoardConfig).models[board_type.c_value]._field_formats ), }, { @@ -191,8 +197,8 @@ class ConfigBoardMessageForm(MeshMessageForm): "prefix": "uwb_", "field": "board", "values": tuple( - cfg.board.name for cfg in BoardConfig._union_options["board"].values() - if "uwb" in cfg.__dataclass_fields__ + board_type.name for board_type in BoardType + if "uwb" in CFormat.from_annotation(BoardConfig).models[board_type.c_value]._field_formats ), }, { @@ -204,10 +210,7 @@ class ConfigBoardMessageForm(MeshMessageForm): def clean(self): cleaned_data = super().clean() - - board_cfg = BoardConfig._union_options["board"][BoardType[cleaned_data["board"]]] - has_led = "led" in board_cfg.__dataclass_fields__ - has_uwb = "uwb" in board_cfg.__dataclass_fields__ + orig_cleaned_keys = set(cleaned_data.keys()) led_values = { "led_type": cleaned_data.pop("led_type"), @@ -217,43 +220,29 @@ class ConfigBoardMessageForm(MeshMessageForm): if name.startswith('led_') } } + if led_values: + cleaned_data["led"] = led_values + uwb_values = { name.removeprefix('uwb_'): cleaned_data.pop(name) for name in tuple(cleaned_data.keys()) if name.startswith('uwb_') } - - errors = {} - - if has_led: - prefix = led_values["led_type"].lower()+'_' - cleaned_data["led"] = { - "led_type": led_values["led_type"], - **{ - name.removeprefix(prefix): value - for name, value in led_values.items() - if name.startswith(prefix) - } - } - for key, value in tuple(cleaned_data["led"].items()): - if value is None: - field_name = f'led_{prefix}{key}' - if self.fields[field_name].min_value == -1: - cleaned_data[key] = -1 - else: - errors[field_name] = _('this field is required') - - if has_uwb: + if uwb_values: cleaned_data["uwb"] = uwb_values - for key, value in tuple(cleaned_data["uwb"].items()): - if value is None: - field_name = f'uwb_{key}' - if self.fields[field_name].min_value == -1 or not cleaned_data["uwb"]["enable"]: - cleaned_data[key] = -1 - else: - errors[field_name] = _('this field is required') - if errors: + try: + TypeAdapter(BoardConfig).validate_python(cleaned_data) + except PydanticValidationError as e: + from pprint import pprint + pprint(e.errors()) + errors = {} + for error in e.errors(): + loc = "_".join(s for s in error["loc"] if not s.isupper()) + if loc in orig_cleaned_keys: + errors.setdefault(loc, []).append(error["msg"]) + else: + errors.setdefault("__all__", []).append(f"{loc}: {error['msg']}") raise ValidationError(errors) return cleaned_data diff --git a/src/c3nav/mesh/management/commands/generate_c_types.py b/src/c3nav/mesh/management/commands/generate_c_types.py index 41f768b4..e4e1587b 100644 --- a/src/c3nav/mesh/management/commands/generate_c_types.py +++ b/src/c3nav/mesh/management/commands/generate_c_types.py @@ -1,52 +1,47 @@ -from dataclasses import fields - from django.core.management.base import BaseCommand -from c3nav.mesh.baseformats import StructType, normalize_name -from c3nav.mesh.messages import MeshMessage +from c3nav.mesh.cformats import UnionFormat, normalize_name, CFormat +from c3nav.mesh.messages import MeshMessageContent from c3nav.mesh.utils import indent_c class Command(BaseCommand): help = 'export mesh message structs for c code' + @staticmethod + def get_msg_c_enum_name(msg_type): + return normalize_name(msg_type.__name__.removeprefix('Mesh').removesuffix('Message')).upper() + def handle(self, *args, **options): - done_struct_names = set() nodata = set() struct_lines = {} struct_sizes = [] struct_max_sizes = [] done_definitions = set() - for include in StructType.c_includes: + mesh_msg_content_format = CFormat.from_annotation(MeshMessageContent) + if not isinstance(mesh_msg_content_format, UnionFormat): + raise Exception('wuah') + discriminator_size = mesh_msg_content_format.discriminator_format.get_size() + for include in mesh_msg_content_format.get_c_includes(): print(f'#include {include}') - ignore_names = set(field_.name for field_ in fields(MeshMessage)) - for msg_type, msg_class in MeshMessage.get_types().items(): - if msg_class.c_struct_name: - if msg_class.c_struct_name in done_struct_names: - continue - done_struct_names.add(msg_class.c_struct_name) - if MeshMessage.c_structs[msg_class.c_struct_name] != msg_class: - # the purpose of MeshMessage.c_structs is unclear, currently this never triggers - # todo get rid of the whole c_structs thing if it doesn't turn out to be useful for anything - raise ValueError('what happened?') - - base_name = (msg_class.c_struct_name or normalize_name( - getattr(msg_type, 'name', msg_class.__name__) - )) + for msg_type, msg_content_format in mesh_msg_content_format.models.items(): + base_name = normalize_name(mesh_msg_content_format.key_to_name[msg_type]) name = "mesh_msg_%s_t" % base_name - for definition_name, definition in msg_class.get_c_definitions().items(): + for definition_name, definition in msg_content_format.get_c_definitions().items(): if definition_name not in done_definitions: done_definitions.add(definition_name) print(definition) print() - code = msg_class.get_c_code(name, ignore_fields=ignore_names, no_empty=True) + code = msg_content_format.get_c_code(name, ignore_fields=('msg_type', ), no_empty=True) if code: - size = msg_class.get_size(no_inherited_fields=True, calculate_max=False) - max_size = msg_class.get_size(no_inherited_fields=True, calculate_max=True) + size = msg_content_format.get_size(calculate_max=False) + max_size = msg_content_format.get_size(calculate_max=True) + size -= discriminator_size + max_size -= discriminator_size struct_lines[base_name] = "%s %s;" % (name, base_name.replace('_announce', '')) struct_sizes.append(size) struct_max_sizes.append(max_size) @@ -55,13 +50,13 @@ class Command(BaseCommand): (name, size)) print() else: - nodata.add(msg_class) + nodata.add(msg_content_format.model) print("/** union between all message data structs */") print("typedef union __packed {") for line in struct_lines.values(): print(indent_c(line)) - print("} mesh_msg_data_t; ") + print("} mesh_msg_data_t;") print( "static_assert(sizeof(mesh_msg_data_t) == %d, \"size of generated message structs is calculated wrong\");" % max(struct_sizes) @@ -72,20 +67,18 @@ class Command(BaseCommand): print() - max_msg_type = max(MeshMessage.get_types().keys()) + max_msg_type = max(mesh_msg_content_format.models.keys()) macro_data = [] for i in range(((max_msg_type//16)+1)*16): - msg_class = MeshMessage.get_types().get(i, None) - if msg_class: - name = (msg_class.c_struct_name or normalize_name( - getattr(msg_class.msg_type, 'name', msg_class.__name__) - )) + msg_content_format = mesh_msg_content_format.models.get(i, None) + if msg_content_format: + name = normalize_name(mesh_msg_content_format.key_to_name[i]) macro_data.append(( - msg_class.get_c_enum_name(), - ("nodata" if msg_class in nodata else name), - msg_class.get_var_num(), - msg_class.get_size(no_inherited_fields=True, calculate_max=True), - msg_class.__doc__.strip(), + self.get_msg_c_enum_name(msg_content_format.model), + ("nodata" if msg_content_format.model in nodata else name), + msg_content_format.get_var_num(), # todo: uh? + msg_content_format.get_size(calculate_max=True) - discriminator_size, + msg_content_format.model.__doc__.strip(), )) else: macro_data.append(( diff --git a/src/c3nav/mesh/messages.py b/src/c3nav/mesh/messages.py index aee537a7..ce12abeb 100644 --- a/src/c3nav/mesh/messages.py +++ b/src/c3nav/mesh/messages.py @@ -1,15 +1,16 @@ -from dataclasses import dataclass, field -from enum import IntEnum, unique -from typing import TypeVar +from enum import unique +from typing import Annotated, Union import channels +from annotated_types import Ge, Le, Lt, MaxLen from channels.db import database_sync_to_async +from pydantic import PositiveInt +from pydantic.main import BaseModel +from pydantic.types import Discriminator, NonNegativeInt +from pydantic_extra_types.mac_address import MacAddress -from c3nav.api.utils import EnumSchemaByNameMixin -from c3nav.mesh.baseformats import (BoolFormat, EnumFormat, FixedStrFormat, SimpleFormat, StructType, VarArrayFormat, - VarBytesFormat, VarStrFormat, normalize_name) -from c3nav.mesh.dataformats import (BoardConfig, ChipType, FirmwareAppDescription, MacAddressesListFormat, - MacAddressFormat, RangeResultItem, RawFTMEntry) +from c3nav.mesh.cformats import CDoc, CEmbed, CName, LenBytes, NoDef, VarLen, discriminator_value, CEnum +from c3nav.mesh.schemas import BoardConfig, ChipType, FirmwareAppDescription, RangeResultItem, RawFTMEntry from c3nav.mesh.utils import MESH_ALL_UPLINKS_GROUP MESH_ROOT_ADDRESS = '00:00:00:00:00:00' @@ -23,45 +24,45 @@ OTA_CHUNK_SIZE = 512 @unique -class MeshMessageType(EnumSchemaByNameMixin, IntEnum): - NOOP = 0x00 +class MeshMessageType(CEnum): + NOOP = "NOOP", 0x00 - ECHO_REQUEST = 0x01 - ECHO_RESPONSE = 0x02 + ECHO_REQUEST = "ECHO_REQUEST", 0x01 + ECHO_RESPONSE = "ECHO_RESPONSE", 0x02 - MESH_SIGNIN = 0x03 - MESH_LAYER_ANNOUNCE = 0x04 - MESH_ADD_DESTINATIONS = 0x05 - MESH_REMOVE_DESTINATIONS = 0x06 - MESH_ROUTE_REQUEST = 0x07 - MESH_ROUTE_RESPONSE = 0x08 - MESH_ROUTE_TRACE = 0x09 - MESH_ROUTING_FAILED = 0x0a + MESH_SIGNIN = "MESH_SIGNIN", 0x03 + MESH_LAYER_ANNOUNCE = "MESH_LAYER_ANNOUNCE", 0x04 + MESH_ADD_DESTINATIONS = "MESH_ADD_DESTINATIONS", 0x05 + MESH_REMOVE_DESTINATIONS = "MESH_REMOVE_DESTINATIONS", 0x06 + MESH_ROUTE_REQUEST = "MESH_ROUTE_REQUEST", 0x07 + MESH_ROUTE_RESPONSE = "MESH_ROUTE_RESPONSE", 0x08 + MESH_ROUTE_TRACE = "MESH_ROUTE_TRACE", 0x09 + MESH_ROUTING_FAILED = "MESH_ROUTING_FAILED", 0x0a - CONFIG_DUMP = 0x10 - CONFIG_HARDWARE = 0x11 - CONFIG_BOARD = 0x12 - CONFIG_FIRMWARE = 0x13 - CONFIG_UPLINK = 0x14 - CONFIG_POSITION = 0x15 + CONFIG_DUMP = "CONFIG_DUMP", 0x10 + CONFIG_HARDWARE = "CONFIG_HARDWARE", 0x11 + CONFIG_BOARD = "CONFIG_BOARD", 0x12 + CONFIG_FIRMWARE = "CONFIG_FIRMWARE", 0x13 + CONFIG_UPLINK = "CONFIG_UPLINK", 0x14 + CONFIG_POSITION = "CONFIG_POSITION", 0x15 - OTA_STATUS = 0x20 - OTA_REQUEST_STATUS = 0x21 - OTA_START = 0x22 - OTA_URL = 0x23 - OTA_FRAGMENT = 0x24 - OTA_REQUEST_FRAGMENTS = 0x25 - OTA_SETTING = 0x26 - OTA_APPLY = 0x27 - OTA_ABORT = 0x28 + OTA_STATUS = "OTA_STATUS", 0x20 + OTA_REQUEST_STATUS = "OTA_REQUEST_STATUS", 0x21 + OTA_START = "OTA_START", 0x22 + OTA_URL = "OTA_URL", 0x23 + OTA_FRAGMENT = "OTA_FRAGMENT", 0x24 + OTA_REQUEST_FRAGMENTS = "OTA_REQUEST_FRAGMENTS", 0x25 + OTA_SETTING = "OTA_SETTING", 0x26 + OTA_APPLY = "OTA_APPLY", 0x27 + OTA_ABORT = "OTA_ABORT", 0x28 - LOCATE_REQUEST_RANGE = 0x30 - LOCATE_RANGE_RESULTS = 0x31 - LOCATE_RAW_FTM_RESULTS = 0x32 + LOCATE_REQUEST_RANGE = "LOCATE_REQUEST_RANGE", 0x30 + LOCATE_RANGE_RESULTS = "LOCATE_RANGE_RESULTS", 0x31 + LOCATE_RAW_FTM_RESULTS = "LOCATE_RAW_FTM_RESULTS", 0x32 - REBOOT = 0x40 + REBOOT = "REBOOT", 0x40 - REPORT_ERROR = 0x50 + REPORT_ERROR = "REPORT_ERROR", 0x50 @property def pretty_name(self): @@ -72,32 +73,266 @@ class MeshMessageType(EnumSchemaByNameMixin, IntEnum): return name -M = TypeVar('M', bound='MeshMessage') +class NoopMessage(discriminator_value(msg_type=MeshMessageType.NOOP), BaseModel): + """ noop """ + pass -@dataclass -class MeshMessage(StructType, union_type_field="msg_type"): - dst: str = field(metadata={"format": MacAddressFormat()}) - src: str = field(metadata={"format": MacAddressFormat()}) - msg_type: MeshMessageType = field(metadata={"format": EnumFormat('B', c_definition=False)}, init=False, repr=False) - c_structs = {} - c_struct_name = None +class EchoRequestMessage(discriminator_value(msg_type=MeshMessageType.ECHO_REQUEST), BaseModel): + """ repeat back string """ + content: Annotated[str, MaxLen(255), VarLen()] = "" - # noinspection PyMethodOverriding - def __init_subclass__(cls, /, c_struct_name=None, **kwargs): - super().__init_subclass__(**kwargs) - if c_struct_name: - cls.c_struct_name = c_struct_name - if c_struct_name in MeshMessage.c_structs: - raise TypeError('duplicate use of c_struct_name %s' % c_struct_name) - MeshMessage.c_structs[c_struct_name] = cls + +class EchoResponseMessage(discriminator_value(msg_type=MeshMessageType.ECHO_RESPONSE), BaseModel): + """ repeat back string """ + content: Annotated[str, MaxLen(255), VarLen()] = "" + + +class MeshSigninMessage(discriminator_value(msg_type=MeshMessageType.MESH_SIGNIN), BaseModel): + """ node says hello to upstream node """ + pass + + +class MeshLayerAnnounceMessage(discriminator_value(msg_type=MeshMessageType.MESH_LAYER_ANNOUNCE), BaseModel): + """ upstream node announces layer number """ + layer: Annotated[PositiveInt, Lt(2 ** 8), CDoc("mesh layer that the sending node is on")] + + +class MeshAddDestinationsMessage(discriminator_value(msg_type=MeshMessageType.MESH_ADD_DESTINATIONS), BaseModel): + """ downstream node announces served destination """ + addresses: Annotated[list[MacAddress], MaxLen(16), VarLen(), CDoc("adresses of the added destinations",)] + + +class MeshRemoveDestinationsMessage(discriminator_value(msg_type=MeshMessageType.MESH_REMOVE_DESTINATIONS), BaseModel): + """ downstream node announces no longer served destination """ + addresses: Annotated[list[MacAddress], MaxLen(16), VarLen(), CDoc("adresses of the removed destinations",)] + + +class MeshRouteRequestMessage(discriminator_value(msg_type=MeshMessageType.MESH_ROUTE_REQUEST), BaseModel): + """ request routing information for node """ + request_id: Annotated[PositiveInt, Lt(2**32)] + address: Annotated[MacAddress, CDoc("target address for the route")] + + +class MeshRouteResponseMessage(discriminator_value(msg_type=MeshMessageType.MESH_ROUTE_RESPONSE), BaseModel): + """ reporting the routing table entry to the given address """ + request_id: Annotated[PositiveInt, Lt(2**32)] + route: Annotated[MacAddress, CDoc("routing table entry or 00:00:00:00:00:00")] + + +class MeshRouteTraceMessage(discriminator_value(msg_type=MeshMessageType.MESH_ROUTE_TRACE), BaseModel): + """ special message, collects all hop adresses on its way """ + request_id: Annotated[PositiveInt, Lt(2**32)] + trace: Annotated[list[MacAddress], MaxLen(16), VarLen(), CDoc("addresses encountered by this message")] + + +class MeshRoutingFailedMessage(discriminator_value(msg_type=MeshMessageType.MESH_ROUTING_FAILED), BaseModel): + """ TODO description""" + address: MacAddress + + +class ConfigDumpMessage(discriminator_value(msg_type=MeshMessageType.CONFIG_DUMP), BaseModel): + """ request for the node to dump its config """ + pass + + +class ConfigHardwareMessage(discriminator_value(msg_type=MeshMessageType.CONFIG_HARDWARE), BaseModel): + """ respond hardware/chip info """ + chip: Annotated[ChipType, NoDef(), LenBytes(2), CName("chip_id")] + revision_major: Annotated[NonNegativeInt, Lt(2**8)] + revision_minor: Annotated[NonNegativeInt, Lt(2**8)] + + def get_chip_display(self): + return ChipType(self.chip).name.replace('_', '-') + + +class ConfigBoardMessage(discriminator_value(msg_type=MeshMessageType.CONFIG_BOARD), BaseModel): + """ set/respond board config """ + board_config: Annotated[BoardConfig, CEmbed] + + +class ConfigFirmwareMessage(discriminator_value(msg_type=MeshMessageType.CONFIG_FIRMWARE), BaseModel): + """ respond firmware info """ + app_desc: FirmwareAppDescription + + +class ConfigPositionMessage(discriminator_value(msg_type=MeshMessageType.CONFIG_POSITION), BaseModel): + """ set/respond position config """ + x_pos: Annotated[int, Ge(-2**31), Lt(2**31)] + y_pos: Annotated[int, Ge(-2**31), Lt(2**31)] + z_pos: Annotated[int, Ge(-2**15), Lt(2**15)] + + +class ConfigUplinkMessage(discriminator_value(msg_type=MeshMessageType.CONFIG_UPLINK), BaseModel): + """ set/respond uplink config """ + enabled: bool + ssid: Annotated[str, MaxLen(32)] + password: Annotated[str, MaxLen(64)] + channel: Annotated[PositiveInt, Le(15)] + udp: bool + ssl: bool + host: Annotated[str, MaxLen(64)] + port: Annotated[PositiveInt, Lt(2**16)] + + +@unique +class OTADeviceStatus(CEnum): + """ ota status, the ones >= 0x10 denote a permanent failure """ + NONE = "NONE", 0x00 + + STARTED = "STARTED", 0x01 + APPLIED = "APPLIED", 0x02 + + START_FAILED = "START_FAILED", 0x10 + WRITE_FAILED = "WRITE_FAILED", 0x12 + APPLY_FAILED = "APPLY_FAILED", 0x13 + ROLLED_BACK = "ROLLED_BACK", 0x14 + + @property + def pretty_name(self): + return self.name.replace('_', ' ').lower() + + @property + def is_failed(self): + return self >= self.START_FAILED + + +class OTAStatusMessage(discriminator_value(msg_type=MeshMessageType.OTA_STATUS), BaseModel): + """ report OTA status """ + update_id: Annotated[NonNegativeInt, Lt(2**32)] + received_bytes: Annotated[NonNegativeInt, Lt(2**32)] + next_expected_chunk: Annotated[NonNegativeInt, Lt(2**16)] + auto_apply: bool + auto_reboot: bool + status: OTADeviceStatus + + +class OTARequestStatusMessage(discriminator_value(msg_type=MeshMessageType.OTA_REQUEST_STATUS), BaseModel): + """ request OTA status """ + pass + + +class OTAStartMessage(discriminator_value(msg_type=MeshMessageType.OTA_START), BaseModel): + """ instruct node to start OTA """ + update_id: Annotated[PositiveInt, Lt(2**32)] + total_bytes: Annotated[PositiveInt, Lt(2**32)] + auto_apply: bool + auto_reboot: bool + + +class OTAURLMessage(discriminator_value(msg_type=MeshMessageType.OTA_URL), BaseModel): + """ supply download URL for OTA update and who to distribute it to """ + update_id: Annotated[PositiveInt, Lt(2**32)] + distribute_to: MacAddress + url: Annotated[str, MaxLen(255), VarLen()] + + +class OTAFragmentMessage(discriminator_value(msg_type=MeshMessageType.OTA_FRAGMENT), BaseModel): + """ supply OTA fragment """ + update_id: Annotated[PositiveInt, Lt(2**32)] + chunk: Annotated[PositiveInt, Lt(2**16)] + data: Annotated[bytes, MaxLen(OTA_CHUNK_SIZE), VarLen()] + + +class OTARequestFragmentsMessage(discriminator_value(msg_type=MeshMessageType.OTA_REQUEST_FRAGMENTS), BaseModel): + """ request missing fragments """ + update_id: Annotated[PositiveInt, Lt(2**32)] + chunks: Annotated[list[Annotated[PositiveInt, Lt(2**16)]], MaxLen(128), VarLen()] + + +class OTASettingMessage(discriminator_value(msg_type=MeshMessageType.OTA_SETTING), BaseModel): + """ configure whether to automatically apply and reboot when update is completed """ + update_id: Annotated[PositiveInt, Lt(2**32)] + auto_apply: bool + auto_reboot: bool + + +class OTAApplyMessage(discriminator_value(msg_type=MeshMessageType.OTA_APPLY), BaseModel): + """ apply OTA and optionally reboot """ + update_id: Annotated[PositiveInt, Lt(2**32)] + reboot: bool + + +class OTAAbortMessage(discriminator_value(msg_type=MeshMessageType.OTA_ABORT), BaseModel): + """ announcing OTA abort """ + update_id: Annotated[NonNegativeInt, Lt(2**32)] + + +class LocateRequestRangeMessage(discriminator_value(msg_type=MeshMessageType.LOCATE_REQUEST_RANGE), BaseModel): + """ request to report distance to all nearby nodes """ + pass + + +class LocateRangeResults(discriminator_value(msg_type=MeshMessageType.LOCATE_RANGE_RESULTS), BaseModel): + """ reports distance to given nodes """ + ranges: Annotated[list[RangeResultItem], MaxLen(16), VarLen()] + + +class LocateRawFTMResults(discriminator_value(msg_type=MeshMessageType.LOCATE_RAW_FTM_RESULTS), BaseModel): + """ reports distance to given nodes """ + peer: MacAddress + results: Annotated[list[RawFTMEntry], MaxLen(16), VarLen()] + + +class Reboot(discriminator_value(msg_type=MeshMessageType.REBOOT), BaseModel): + """ reboot the device """ + pass + + +class ReportError(discriminator_value(msg_type=MeshMessageType.REPORT_ERROR), BaseModel): + """ report a critical error to upstream """ + message: Annotated[str, MaxLen(255), VarLen()] + + +MeshMessageContent = Annotated[ + Union[ + NoopMessage, + EchoRequestMessage, + EchoResponseMessage, + MeshSigninMessage, + MeshLayerAnnounceMessage, + MeshAddDestinationsMessage, + MeshRemoveDestinationsMessage, + MeshRouteRequestMessage, + MeshRouteResponseMessage, + MeshRouteTraceMessage, + MeshRoutingFailedMessage, + ConfigDumpMessage, + ConfigHardwareMessage, + ConfigBoardMessage, + ConfigFirmwareMessage, + ConfigPositionMessage, + ConfigUplinkMessage, + OTAStatusMessage, + OTARequestStatusMessage, + OTAStartMessage, + OTAURLMessage, + OTAFragmentMessage, + OTARequestFragmentsMessage, + OTASettingMessage, + OTAApplyMessage, + OTAAbortMessage, + LocateRequestRangeMessage, + LocateRangeResults, + LocateRawFTMResults, + Reboot, + ReportError, + ], + Discriminator("msg_type") +] + + +class MeshMessage(BaseModel): + dst: MacAddress + src: MacAddress + content: MeshMessageContent async def send(self, sender=None, exclude_uplink_address=None) -> bool: data = { "type": "mesh.send", "sender": sender, "exclude_uplink_address": exclude_uplink_address, - "msg": MeshMessage.tojson(self), + "msg": self.model_dump(), } if self.dst in (MESH_CHILDREN_ADDRESS, MESH_BROADCAST_ADDRESS): @@ -110,283 +345,4 @@ class MeshMessage(StructType, union_type_field="msg_type"): return False if uplink.node_id == exclude_uplink_address: return False - await channels.layers.get_channel_layer().send(uplink.name, data) - - @classmethod - def get_ignore_c_fields(self): - return set() - - @classmethod - def get_additional_c_fields(self): - return () - - @classmethod - def get_variable_name(cls, base_name): - return cls.c_struct_name or base_name - - @classmethod - def get_c_enum_name(cls): - return normalize_name(cls.__name__.removeprefix('Mesh').removesuffix('Message')).upper() - - -@dataclass -class NoopMessage(MeshMessage, msg_type=MeshMessageType.NOOP): - """ noop """ - pass - - -@dataclass -class EchoRequestMessage(MeshMessage, msg_type=MeshMessageType.ECHO_REQUEST): - """ repeat back string """ - content: str = field(default='', metadata={'format': VarStrFormat(max_len=255)}) - - -@dataclass -class EchoResponseMessage(MeshMessage, msg_type=MeshMessageType.ECHO_RESPONSE): - """ repeat back string """ - content: str = field(default='', metadata={'format': VarStrFormat(max_len=255)}) - - -@dataclass -class MeshSigninMessage(MeshMessage, msg_type=MeshMessageType.MESH_SIGNIN): - """ node says hello to upstream node """ - pass - - -@dataclass -class MeshLayerAnnounceMessage(MeshMessage, msg_type=MeshMessageType.MESH_LAYER_ANNOUNCE): - """ upstream node announces layer number """ - layer: int = field(metadata={ - "format": SimpleFormat('B'), - "doc": "mesh layer that the sending node is on", - }) - - -@dataclass -class MeshAddDestinationsMessage(MeshMessage, msg_type=MeshMessageType.MESH_ADD_DESTINATIONS): - """ downstream node announces served destination """ - addresses: list[str] = field(default_factory=list, metadata={ - "format": MacAddressesListFormat(max_num=16), - "doc": "adresses of the added destinations", - }) - - -@dataclass -class MeshRemoveDestinationsMessage(MeshMessage, msg_type=MeshMessageType.MESH_REMOVE_DESTINATIONS): - """ downstream node announces no longer served destination """ - addresses: list[str] = field(default_factory=list, metadata={ - "format": MacAddressesListFormat(max_num=16), - "doc": "adresses of the removed destinations", - }) - - -@dataclass -class MeshRouteRequestMessage(MeshMessage, msg_type=MeshMessageType.MESH_ROUTE_REQUEST): - """ request routing information for node """ - request_id: int = field(metadata={"format": SimpleFormat('I')}) - address: str = field(metadata={ - "format": MacAddressFormat(), - "doc": "target address for the route" - }) - - -@dataclass -class MeshRouteResponseMessage(MeshMessage, msg_type=MeshMessageType.MESH_ROUTE_RESPONSE): - """ reporting the routing table entry to the given address """ - request_id: int = field(metadata={"format": SimpleFormat('I')}) - route: str = field(metadata={ - "format": MacAddressFormat(), - "doc": "routing table entry or 00:00:00:00:00:00" - }) - - -@dataclass -class MeshRouteTraceMessage(MeshMessage, msg_type=MeshMessageType.MESH_ROUTE_TRACE): - """ special message, collects all hop adresses on its way """ - request_id: int = field(metadata={"format": SimpleFormat('I')}) - trace: list[str] = field(default_factory=list, metadata={ - "format": MacAddressesListFormat(max_num=16), - "doc": "addresses encountered by this message", - }) - - -@dataclass -class MeshRoutingFailedMessage(MeshMessage, msg_type=MeshMessageType.MESH_ROUTING_FAILED): - """ TODO description""" - address: str = field(metadata={"format": MacAddressFormat()}) - - -@dataclass -class ConfigDumpMessage(MeshMessage, msg_type=MeshMessageType.CONFIG_DUMP): - """ request for the node to dump its config """ - pass - - -@dataclass -class ConfigHardwareMessage(MeshMessage, msg_type=MeshMessageType.CONFIG_HARDWARE): - """ respond hardware/chip info """ - chip: ChipType = field(metadata={ - "format": EnumFormat("H", c_definition=False), - "c_name": "chip_id", - }) - revision_major: int = field(metadata={"format": SimpleFormat('B')}) - revision_minor: int = field(metadata={"format": SimpleFormat('B')}) - - def get_chip_display(self): - return ChipType(self.chip).name.replace('_', '-') - - -@dataclass -class ConfigBoardMessage(MeshMessage, msg_type=MeshMessageType.CONFIG_BOARD): - """ set/respond board config """ - board_config: BoardConfig = field(metadata={"c_embed": True, "json_embed": True}) - - -@dataclass -class ConfigFirmwareMessage(MeshMessage, msg_type=MeshMessageType.CONFIG_FIRMWARE): - """ respond firmware info """ - app_desc: FirmwareAppDescription = field(metadata={'json_embed': True}) - - -@dataclass -class ConfigPositionMessage(MeshMessage, msg_type=MeshMessageType.CONFIG_POSITION): - """ set/respond position config """ - x_pos: int = field(metadata={"format": SimpleFormat('i')}) - y_pos: int = field(metadata={"format": SimpleFormat('i')}) - z_pos: int = field(metadata={"format": SimpleFormat('h')}) - - -@dataclass -class ConfigUplinkMessage(MeshMessage, msg_type=MeshMessageType.CONFIG_UPLINK): - """ set/respond uplink config """ - enabled: bool = field(metadata={"format": BoolFormat()}) - ssid: str = field(metadata={"format": FixedStrFormat(32)}) - password: str = field(metadata={"format": FixedStrFormat(64)}) - channel: int = field(metadata={"format": SimpleFormat('B')}) - udp: bool = field(metadata={"format": BoolFormat()}) - ssl: bool = field(metadata={"format": BoolFormat()}) - host: str = field(metadata={"format": FixedStrFormat(64)}) - port: int = field(metadata={"format": SimpleFormat('H')}) - - -@unique -class OTADeviceStatus(EnumSchemaByNameMixin, IntEnum): - """ ota status, the ones >= 0x10 denote a permanent failure """ - NONE = 0x00 - - STARTED = 0x01 - APPLIED = 0x02 - - START_FAILED = 0x10 - WRITE_FAILED = 0x12 - APPLY_FAILED = 0x13 - ROLLED_BACK = 0x14 - - @property - def pretty_name(self): - return self.name.replace('_', ' ').lower() - - @property - def is_failed(self): - return self >= self.START_FAILED - - -@dataclass -class OTAStatusMessage(MeshMessage, msg_type=MeshMessageType.OTA_STATUS): - """ report OTA status """ - update_id: int = field(metadata={"format": SimpleFormat('I')}) - received_bytes: int = field(metadata={"format": SimpleFormat('I')}) - next_expected_chunk: int = field(metadata={"format": SimpleFormat('H')}) - auto_apply: bool = field(metadata={"format": BoolFormat()}) - auto_reboot: bool = field(metadata={"format": BoolFormat()}) - status: OTADeviceStatus = field(metadata={"format": EnumFormat('B')}) - - -@dataclass -class OTARequestStatusMessage(MeshMessage, msg_type=MeshMessageType.OTA_REQUEST_STATUS): - """ request OTA status """ - pass - - -@dataclass -class OTAStartMessage(MeshMessage, msg_type=MeshMessageType.OTA_START): - """ instruct node to start OTA """ - update_id: int = field(metadata={"format": SimpleFormat('I')}) - total_bytes: int = field(metadata={"format": SimpleFormat('I')}) - auto_apply: bool = field(metadata={"format": BoolFormat()}) - auto_reboot: bool = field(metadata={"format": BoolFormat()}) - - -@dataclass -class OTAURLMessage(MeshMessage, msg_type=MeshMessageType.OTA_URL): - """ supply download URL for OTA update and who to distribute it to """ - update_id: int = field(metadata={"format": SimpleFormat('I')}) - distribute_to: str = field(metadata={"format": MacAddressFormat()}) - url: str = field(metadata={"format": VarStrFormat(max_len=255)}) - - -@dataclass -class OTAFragmentMessage(MeshMessage, msg_type=MeshMessageType.OTA_FRAGMENT): - """ supply OTA fragment """ - update_id: int = field(metadata={"format": SimpleFormat('I')}) - chunk: int = field(metadata={"format": SimpleFormat('H')}) - data: bytes = field(metadata={"format": VarBytesFormat(max_size=OTA_CHUNK_SIZE)}) - - -@dataclass -class OTARequestFragmentsMessage(MeshMessage, msg_type=MeshMessageType.OTA_REQUEST_FRAGMENTS): - """ request missing fragments """ - update_id: int = field(metadata={"format": SimpleFormat('I')}) - chunks: list[int] = field(metadata={"format": VarArrayFormat(SimpleFormat('H'), max_num=128)}) - - -@dataclass -class OTASettingMessage(MeshMessage, msg_type=MeshMessageType.OTA_SETTING): - """ configure whether to automatically apply and reboot when update is completed """ - update_id: int = field(metadata={"format": SimpleFormat('I')}) - auto_apply: bool = field(metadata={"format": BoolFormat()}) - auto_reboot: bool = field(metadata={"format": BoolFormat()}) - - -@dataclass -class OTAApplyMessage(MeshMessage, msg_type=MeshMessageType.OTA_APPLY): - """ apply OTA and optionally reboot """ - update_id: int = field(metadata={"format": SimpleFormat('I')}) - reboot: bool = field(metadata={"format": BoolFormat()}) - - -@dataclass -class OTAAbortMessage(MeshMessage, msg_type=MeshMessageType.OTA_ABORT): - """ announcing OTA abort """ - update_id: int = field(metadata={"format": SimpleFormat('I')}) - - -@dataclass -class LocateRequestRangeMessage(MeshMessage, msg_type=MeshMessageType.LOCATE_REQUEST_RANGE): - """ request to report distance to all nearby nodes """ - pass - - -@dataclass -class LocateRangeResults(MeshMessage, msg_type=MeshMessageType.LOCATE_RANGE_RESULTS): - """ reports distance to given nodes """ - ranges: list[RangeResultItem] = field(metadata={"format": VarArrayFormat(RangeResultItem, max_num=16)}) - - -@dataclass -class LocateRawFTMResults(MeshMessage, msg_type=MeshMessageType.LOCATE_RAW_FTM_RESULTS): - """ reports distance to given nodes """ - peer: str = field(metadata={"format": MacAddressFormat()}) - results: list[RawFTMEntry] = field(metadata={"format": VarArrayFormat(RawFTMEntry, max_num=16)}) - - -@dataclass -class Reboot(MeshMessage, msg_type=MeshMessageType.REBOOT): - """ reboot the device """ - pass - - -@dataclass -class ReportError(MeshMessage, msg_type=MeshMessageType.REPORT_ERROR): - """ report a critical error to upstream """ - message: str = field(metadata={"format": VarStrFormat(max_len=255)}) + await channels.layers.get_channel_layer().send(uplink.name, data) \ No newline at end of file diff --git a/src/c3nav/mesh/models.py b/src/c3nav/mesh/models.py index 6d420e98..0d67dc49 100644 --- a/src/c3nav/mesh/models.py +++ b/src/c3nav/mesh/models.py @@ -18,7 +18,7 @@ from django.utils.text import slugify from django.utils.translation import gettext_lazy as _ from c3nav.mapdata.models.geometry.space import RangingBeacon -from c3nav.mesh.dataformats import BoardType, ChipType, FirmwareImage +from c3nav.mesh.schemas import BoardType, ChipType, FirmwareImage from c3nav.mesh.messages import ConfigFirmwareMessage, ConfigHardwareMessage from c3nav.mesh.messages import MeshMessage as MeshMessage from c3nav.mesh.messages import MeshMessageType @@ -383,7 +383,7 @@ class NodeMessage(models.Model): @cached_property def parsed(self) -> Self: - return MeshMessage.fromjson(self.data) + return MeshMessage.model_validate(self.data) class FirmwareVersion(models.Model): diff --git a/src/c3nav/mesh/schemas.py b/src/c3nav/mesh/schemas.py new file mode 100644 index 00000000..7eeb08fc --- /dev/null +++ b/src/c3nav/mesh/schemas.py @@ -0,0 +1,284 @@ +import re +from dataclasses import dataclass, field +from enum import unique +from typing import Annotated, BinaryIO, ClassVar, Literal, Self, Union + +from annotated_types import Gt, Le, Lt, MaxLen, Ge +from pydantic import NegativeInt, PositiveInt +from pydantic.main import BaseModel +from pydantic.types import Discriminator, NonNegativeInt +from pydantic_extra_types.mac_address import MacAddress + +from c3nav.mesh.cformats import AsDefinition, AsHex, CName, ExistingCStruct, discriminator_value, \ + CEnum, TwoNibblesEncodable + + +@unique +class LedType(CEnum): + NONE = "NONE", 0 + SERIAL = "SERIAL", 1 + MULTIPIN = "MULTIPIN", 2 + + @property + def pretty_name(self): + return self.name.lower() + + +@unique +class SerialLedType(CEnum): + WS2812 = "WS2812", 1 + SK6812 = "SK6812", 2 + + +class NoLedConfig(discriminator_value(led_type=LedType.NONE), BaseModel): + pass + + +class SerialLedConfig(discriminator_value(led_type=LedType.SERIAL), BaseModel): + serial_led_type: Annotated[SerialLedType, CName("type")] + gpio: Annotated[PositiveInt, Lt(2**8)] + + +class MultipinLedConfig(discriminator_value(led_type=LedType.MULTIPIN), BaseModel): + gpio_red: Annotated[PositiveInt, Lt(2**8)] + gpio_green: Annotated[PositiveInt, Lt(2**8)] + gpio_blue: Annotated[PositiveInt, Lt(2**8)] + + +LedConfig = Annotated[ + Union[ + NoLedConfig, + SerialLedConfig, + MultipinLedConfig, + ], + Discriminator("led_type") +] + + +class BoardSPIConfig(BaseModel): + """ + configuration for spi bus used for ETH or UWB + """ + gpio_miso: Annotated[PositiveInt, Lt(2**8)] + gpio_mosi: Annotated[PositiveInt, Lt(2**8)] + gpio_clk: Annotated[PositiveInt, Lt(2**8)] + + +class UWBConfig(BaseModel): + """ + configuration for the connection to the UWB module + """ + enable: bool + gpio_cs: Annotated[PositiveInt, Lt(2**8)] + gpio_irq: Annotated[PositiveInt, Lt(2**8)] + gpio_rst: Annotated[PositiveInt, Lt(2**8)] + gpio_wakeup: Annotated[PositiveInt, Lt(2**8)] + gpio_exton: Annotated[PositiveInt, Lt(2**8)] + + +class UplinkEthConfig(BaseModel): + """ + configuration for the connection to the ETH module + """ + enable: bool + gpio_cs: Annotated[PositiveInt, Lt(2**8)] + gpio_int: Annotated[PositiveInt, Lt(2**8)] + gpio_rst: Annotated[int, Ge(-1), Lt(2**7)] + + +@unique +class BoardType(CEnum): + CUSTOM = "CUSTOM", 0x00 + + # devboards + ESP32_C3_DEVKIT_M_1 = "ESP32_C3_DEVKIT_M_1", 0x01 + ESP32_C3_32S = "ESP32_C3_32S", 0x02 + + # custom boards + C3NAV_UWB_BOARD = "C3NAV_UWB_BOARD", 0x10 + C3NAV_LOCATION_PCB_REV_0_1 = "C3NAV_LOCATION_PCB_REV_0_1", 0x11 + C3NAV_LOCATION_PCB_REV_0_2 = "C3NAV_LOCATION_PCB_REV_0_2", 0x12 + + @property + def pretty_name(self): + if self.name.startswith('ESP32'): + return self.name.replace('_', '-').replace('DEVKIT-', 'DevKit') + if self.name.startswith('C3NAV'): + name = self.name.replace('_', ' ').lower() + name = name.replace('uwb', 'UWB').replace('pcb', 'PCB') + name = re.sub(r'[0-9]+( [0-9+])+', lambda s: s[0].replace(' ', '.'), name) + name = re.sub(r'rev.*', lambda s: s[0].replace(' ', ''), name) + return name + return self.name + + +class CustomBoardConfig(discriminator_value(board=BoardType.CUSTOM), BaseModel): + spi: Annotated[BoardSPIConfig, AsDefinition()] + uwb: Annotated[UWBConfig, AsDefinition()] + eth: Annotated[UplinkEthConfig, AsDefinition()] + led: Annotated[LedConfig, AsDefinition()] + + +class DevkitMBoardConfig(discriminator_value(board=BoardType.ESP32_C3_DEVKIT_M_1), BaseModel): + spi: Annotated[BoardSPIConfig, AsDefinition()] + uwb: Annotated[UWBConfig, AsDefinition()] + eth: Annotated[UplinkEthConfig, AsDefinition()] + + +class Esp32SBoardConfig(discriminator_value(board=BoardType.ESP32_C3_32S), BaseModel): + spi: Annotated[BoardSPIConfig, AsDefinition()] + uwb: Annotated[UWBConfig, AsDefinition()] + eth: Annotated[UplinkEthConfig, AsDefinition()] + + +class UwbBoardConfig(discriminator_value(board=BoardType.C3NAV_UWB_BOARD), BaseModel): + eth: Annotated[UplinkEthConfig, AsDefinition()] + + +class LocationPCBRev0Dot1BoardConfig(discriminator_value(board=BoardType.C3NAV_LOCATION_PCB_REV_0_1), BaseModel): + eth: Annotated[UplinkEthConfig, AsDefinition()] + + +class LocationPCBRev0Dot2BoardConfig(discriminator_value(board=BoardType.C3NAV_LOCATION_PCB_REV_0_2), BaseModel): + eth: Annotated[UplinkEthConfig, AsDefinition()] + + +BoardConfig = Annotated[ + Union[ + CustomBoardConfig, + DevkitMBoardConfig, + Esp32SBoardConfig, + UwbBoardConfig, + LocationPCBRev0Dot1BoardConfig, + LocationPCBRev0Dot2BoardConfig, + ], + Discriminator("board"), + AsHex(), +] + + +class RangeResultItem(BaseModel): + peer: MacAddress + rssi: Annotated[NegativeInt, Gt(-100)] + distance: Annotated[int, Gt(-32000), Lt(32000)] + + +class RawFTMEntry(BaseModel): + dlog_token: Annotated[PositiveInt, Lt(255)] + rssi: Annotated[NegativeInt, Gt(-100)] + rtt: Annotated[PositiveInt, Lt(2**32)] + t1: Annotated[PositiveInt, Lt(2**64)] + t2: Annotated[PositiveInt, Lt(2**64)] + t3: Annotated[PositiveInt, Lt(2**64)] + t4: Annotated[PositiveInt, Lt(2**64)] + + +class FirmwareAppDescription(BaseModel): + existing_c_struct: ClassVar = ExistingCStruct(name="esp_app_desc_t", includes=['']) + + magic_word: Literal[0xAB_CD_54_32] = field(repr=False) + secure_version: Annotated[NonNegativeInt, Lt(2**32)] + reserv1: Annotated[str, MaxLen(8*2), AsHex()] = field(repr=False) + version: Annotated[str, MaxLen(32)] + project_name: Annotated[str, MaxLen(32)] + compile_time: Annotated[str, MaxLen(16)] + compile_date: Annotated[str, MaxLen(16)] + idf_version: Annotated[str, MaxLen(32)] + app_elf_sha256: Annotated[str, MaxLen(64), AsHex()] + reserv2: Annotated[str, MaxLen(20*4*2), AsHex()] = field(repr=False) + + +@unique +class SPIFlashMode(CEnum): + QIO = "QID", 0 + QOUT = "QOUT", 1 + DIO = "DIO", 2 + DOUT = "DOUT", 3 + + +@unique +class FlashSize(CEnum): + SIZE_1MB = "SIZE_1MB", 0 + SIZE_2MB = "SIZE_2MB", 1 + SIZE_4MB = "SIZE_4MB", 2 + SIZE_8MB = "SIZE_8MB", 3 + SIZE_16MB = "SIZE_16MB", 4 + SIZE_32MB = "SIZE_32MB", 5 + SIZE_64MB = "SIZE_64MB", 6 + SIZE_128MB = "SIZE_128MB", 7 + + @property + def pretty_name(self): + return self.name.removeprefix('SIZE_') + + +@unique +class FlashFrequency(CEnum): + FREQ_40MHZ = "FREQ_40MHZ", 0 + FREQ_26MHZ = "FREQ_26MHZ", 1 + FREQ_20MHZ = "FREQ_20MHZ", 2 + FREQ_80MHZ = "FREQ_80MHZ", 0xf + + @property + def pretty_name(self): + return self.name.removeprefix('FREQ_').replace('MHZ', 'Mhz') + + +@dataclass +class FlashSettings(TwoNibblesEncodable): + size: FlashSize + frequency: FlashFrequency + + @property + def display(self): + return f"{self.size.pretty_name} ({self.frequency.pretty_name})" + + +@unique +class ChipType(CEnum): + ESP32_S2 = "ESP32_S2", 2 + ESP32_C3 = "ESP32_C3", 5 + + @property + def pretty_name(self): + return self.name.replace('_', '-') + + +class FirmwareImageFileHeader(BaseModel): + magic_word: Literal[0xE9] = field(repr=False) + num_segments: Annotated[PositiveInt, Lt(2**8)] + spi_flash_mode: SPIFlashMode + flash_stuff: FlashSettings + entry_point: Annotated[PositiveInt, Lt(2**32)] + + +class FirmwareImageFileHeader(BaseModel): + major: int + minor: int + num_segments: Annotated[PositiveInt, Lt(2**8)] + spi_flash_mode: SPIFlashMode + flash_stuff: FlashSettings + entry_point: Annotated[PositiveInt, Lt(2**32)] + + +class FirmwareImageExtendedFileHeader(BaseModel): + wp_pin: Annotated[PositiveInt, Lt(2**8)] + drive_settings: Annotated[bytes, MaxLen(3)] + chip: Annotated[ChipType, Lt(2**16)] + min_chip_rev_old: int + min_chip_rev: Annotated[PositiveInt, Le(9999)] + max_chip_rev: Annotated[PositiveInt, Le(9999)] + reserv: Annotated[bytes, MaxLen(4)] = field(repr=False) + hash_appended: bool + + +class FirmwareImage(BaseModel): + header: FirmwareImageFileHeader + ext_header: FirmwareImageExtendedFileHeader + first_segment_headers: Annotated[bytes, MaxLen(2)] = field(repr=False) + app_desc: FirmwareAppDescription + + @classmethod + def from_file(cls, file: BinaryIO) -> Self: + result, data = cls.decode(file.read(FirmwareImage.get_min_size())) + return result diff --git a/src/c3nav/mesh/utils.py b/src/c3nav/mesh/utils.py index 8a21fdbe..b2e18b28 100644 --- a/src/c3nav/mesh/utils.py +++ b/src/c3nav/mesh/utils.py @@ -12,7 +12,7 @@ UPLINK_TIMEOUT = UPLINK_PING+5 def indent_c(code): - return " "+code.replace("\n", "\n ") + return " "+code.replace("\n", "\n ").replace("\n \n", "\n\n") def get_node_names(): diff --git a/src/c3nav/routing/api/positioning.py b/src/c3nav/routing/api/positioning.py index 8ddb8beb..bb0ad6d8 100644 --- a/src/c3nav/routing/api/positioning.py +++ b/src/c3nav/routing/api/positioning.py @@ -82,7 +82,7 @@ def locate_test(request): None ) return { - "ranges": msg.parsed.tojson(msg.parsed)["ranges"], + "ranges": msg.parsed.model_dump()["ranges"], "datetime": msg.datetime, "location": location.serialize(simple_geometry=True) if location else None } diff --git a/src/requirements/production.txt b/src/requirements/production.txt index 90d60f97..67e46de3 100644 --- a/src/requirements/production.txt +++ b/src/requirements/production.txt @@ -3,6 +3,7 @@ django-bootstrap3==23.6 django-compressor==4.4 csscompressor==0.9.5 django-ninja==1.1.0 +pydantic-extra-types==2.5.0 django-filter==23.5 django-environ==0.11.2 shapely==2.0.3