change all of the MeshMessages c-models from dataclasses to pydantic

This commit is contained in:
Laura Klünder 2024-02-27 18:25:18 +01:00
parent bd1a143d31
commit 0fd789173a
14 changed files with 1691 additions and 1628 deletions

View file

@ -1,52 +1,5 @@
from typing import Annotated, Any, Type
from typing import Annotated
import annotated_types
from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import CoreSchema, core_schema
def get_api_post_data(request):
is_json = request.META.get('CONTENT_TYPE').lower() == 'application/json'
if is_json:
try:
data = request.json_body
except AttributeError:
pass # todo fix this raise ParseError('Invalid JSON.')
return data
return request.POST
class EnumSchemaByNameMixin:
@classmethod
def __get_pydantic_json_schema__(
cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
json_schema = handler(core_schema)
json_schema = handler.resolve_ref_schema(json_schema)
json_schema["enum"] = [m.name for m in cls]
json_schema["type"] = "string"
return json_schema
@classmethod
def __get_pydantic_core_schema__(
cls, source: Type[Any], handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
return core_schema.no_info_after_validator_function(
cls.validate,
core_schema.any_schema(),
serialization=core_schema.plain_serializer_function_ser_schema(lambda x: x.name),
)
@classmethod
def validate(cls, v: int):
if isinstance(v, cls):
return v
try:
return cls[v]
except KeyError:
pass
return cls(v)
NonEmptyStr = Annotated[str, annotated_types.MinLen(1)]

View file

@ -12,8 +12,8 @@ from pydantic import PositiveInt, field_validator
from c3nav.api.auth import APIKeyAuth, auth_permission_responses, auth_responses, validate_responses
from c3nav.api.exceptions import API404, APIConflict, APIRequestValidationFailed
from c3nav.api.schema import BaseSchema
from c3nav.mesh.dataformats import BoardType, ChipType, FirmwareImage
from c3nav.mesh.messages import MeshMessageType
from c3nav.mesh.schemas import BoardType, ChipType, FirmwareImage
from c3nav.mesh.messages import MeshMessageType, MeshMessage
from c3nav.mesh.models import FirmwareBuild, FirmwareVersion, NodeMessage
mesh_api_router = APIRouter(tags=["mesh"], auth=APIKeyAuth(permissions={"mesh_control"}))
@ -93,7 +93,7 @@ def firmware_by_id(request, firmware_id: int):
@mesh_api_router.get('/firmwares/{firmware_id}/{variant}/image_data',
summary="firmware image header",
description="get firmware image header for specific firmware build",
response={200: FirmwareImage.schema, **API404.dict(), **auth_responses},
response={200: FirmwareImage, **API404.dict(), **auth_responses},
openapi_extra={
"externalDocs": {
'description': 'esp-idf docs',
@ -105,7 +105,7 @@ def firmware_by_id(request, firmware_id: int):
def firmware_build_image(request, firmware_id: int, variant: str):
try:
build = FirmwareBuild.objects.get(version_id=firmware_id, variant=variant)
return FirmwareImage.tojson(build.firmware_image)
return build.firmware_image.model_dump()
except FirmwareVersion.DoesNotExist:
raise API404("Firmware or firmware build not found")
@ -218,7 +218,7 @@ class NodeMessageSchema(BaseSchema):
src_node: NodeAddress
message_type: MeshMessageType
datetime: datetime
data: dict
data: MeshMessage
@staticmethod
def resolve_src_node(obj):

View file

@ -1,810 +0,0 @@
import re
import struct
import typing
from abc import ABC, abstractmethod
from dataclasses import Field, dataclass
from dataclasses import fields as dataclass_fields
from typing import Any, Self, Sequence
from pydantic import create_model
from c3nav.mesh.utils import indent_c
class BaseFormat(ABC):
def get_var_num(self):
return 0
@abstractmethod
def encode(self, value):
pass
@classmethod
@abstractmethod
def decode(cls, data) -> tuple[Any, bytes]:
pass
def fromjson(self, data):
return data
def tojson(self, data):
return data
@abstractmethod
def get_min_size(self):
pass
@abstractmethod
def get_max_size(self):
pass
def get_size(self, calculate_max=False):
if calculate_max:
return self.get_max_size()
else:
return self.get_min_size()
@abstractmethod
def get_c_parts(self) -> tuple[str, str]:
pass
def get_c_code(self, name) -> str:
pre, post = self.get_c_parts()
return "%s %s%s;" % (pre, name, post)
def set_field_type(self, field_type):
self.field_type = field_type
def get_c_definitions(self) -> dict[str, str]:
return {}
def get_typedef_name(self):
return '%s_t' % normalize_name(self.field_type.__name__)
class SimpleFormat(BaseFormat):
def __init__(self, fmt):
self.fmt = fmt
self.size = struct.calcsize(fmt)
self.c_type = self.c_types[self.fmt[-1]]
self.num = int(self.fmt[:-1]) if len(self.fmt) > 1 else 1
def encode(self, value):
if self.num == 1:
return struct.pack(self.fmt, value)
return struct.pack(self.fmt, *value)
def decode(self, data: bytes) -> tuple[Any, bytes]:
value = struct.unpack(self.fmt, data[:self.size])
if len(value) == 1:
value = value[0]
return value, data[self.size:]
def get_min_size(self):
return self.size
def get_max_size(self):
return self.size
c_types = {
"B": "uint8_t",
"H": "uint16_t",
"I": "uint32_t",
"Q": "uint64_t",
"b": "int8_t",
"h": "int16_t",
"i": "int32_t",
"q": "int64_t",
"s": "char",
}
def get_c_parts(self):
return self.c_type, ("" if self.num == 1 else ("[%d]" % self.num))
class SimpleConstFormat(SimpleFormat):
def __init__(self, fmt, const_value: int):
super().__init__(fmt)
self.const_value = const_value
def decode(self, data: bytes) -> tuple[Any, bytes]:
value, out_data = super().decode(data)
if value != self.const_value:
raise ValueError('const_value is wrong')
return value, out_data
class EnumFormat(SimpleFormat):
def __init__(self, fmt="B", *, as_hex=False, c_definition=True):
super().__init__(fmt)
self.as_hex = as_hex
self.c_definition = c_definition
def set_field_type(self, field_type):
super().set_field_type(field_type)
self.c_struct_name = normalize_name(field_type.__name__) + '_t'
def decode(self, data: bytes) -> tuple[Any, bytes]:
value, out_data = super().decode(data)
return self.field_type(value), out_data
def get_c_parts(self):
if not self.c_definition:
return super().get_c_parts()
return self.c_struct_name, ""
def fromjson(self, data):
return self.field_type[data]
def tojson(self, data):
return data.name
def get_c_definitions(self) -> dict[str, str]:
if not self.c_definition:
return {}
prefix = normalize_name(self.field_type.__name__).upper()
options = []
last_value = None
for item in self.field_type:
if last_value is not None and item.value != last_value + 1:
options.append('')
last_value = item.value
options.append("%(prefix)s_%(name)s = %(value)s," % {
"prefix": prefix,
"name": normalize_name(item.name).upper(),
"value": ("0x%02x" if self.as_hex else "%d") % item.value
})
return {
self.c_struct_name: "enum {\n%(options)s\n};\ntypedef uint8_t %(name)s;" % {
"options": indent_c("\n".join(options)),
"name": self.c_struct_name,
}
}
class TwoNibblesEnumFormat(SimpleFormat):
def __init__(self):
super().__init__('B')
def decode(self, data: bytes) -> tuple[bool, bytes]:
fields = dataclass_fields(self.field_type)
value, data = super().decode(data)
return self.field_type(fields[0].type(value // 2 ** 4), fields[1].type(value // 2 ** 4)), data
def encode(self, value):
fields = dataclass_fields(self.field_type)
return super().encode(
getattr(value, fields[0].name).value * 2 ** 4 +
getattr(value, fields[1].name).value * 2 ** 4
)
def fromjson(self, data):
fields = dataclass_fields(self.field_type)
return self.field_type(*(field.type[data[field.name]] for field in fields))
def tojson(self, data):
fields = dataclass_fields(self.field_type)
return {
field.name: getattr(data, field.name).name for field in fields
}
class ChipRevFormat(SimpleFormat):
def __init__(self):
super().__init__('H')
def decode(self, data: bytes) -> tuple[tuple[int, int], bytes]:
value, data = super().decode(data)
return (value // 100, value % 100), data
def encode(self, value):
return value[0] * 100 + value[1]
class BoolFormat(SimpleFormat):
def __init__(self):
super().__init__('B')
def encode(self, value):
return super().encode(int(value))
def decode(self, data: bytes) -> tuple[bool, bytes]:
value, data = super().decode(data)
if value > 1:
raise ValueError('Boolean value > 1')
return bool(value), data
class FixedStrFormat(SimpleFormat):
def __init__(self, num):
self.num = num
super().__init__('%ds' % self.num)
def encode(self, value: str):
return value.encode()[:self.num].ljust(self.num, bytes((0,))),
def decode(self, data: bytes) -> tuple[str, bytes]:
return data[:self.num].rstrip(bytes((0,))).decode(), data[self.num:]
class FixedHexFormat(SimpleFormat):
def __init__(self, num, sep=''):
self.num = num
self.sep = sep
super().__init__('%dB' % self.num)
def encode(self, value: str):
return super().encode(tuple(bytes.fromhex(value.replace(':', ''))))
def decode(self, data: bytes) -> tuple[str, bytes]:
return self.sep.join(('%02x' % i) for i in data[:self.num]), data[self.num:]
@abstractmethod
class BaseVarFormat(BaseFormat, ABC):
def __init__(self, max_num):
self.num_fmt = 'H'
self.num_size = struct.calcsize(self.num_fmt)
self.max_num = max_num
def get_min_size(self):
return self.num_size
def get_max_size(self):
return self.num_size + self.max_num * self.get_var_num()
def get_num_c_code(self):
return SimpleFormat(self.num_fmt).get_c_code("num")
class VarArrayFormat(BaseVarFormat):
def __init__(self, child_type, max_num):
super().__init__(max_num=max_num)
self.child_type = child_type
self.child_size = self.child_type.get_min_size()
def get_var_num(self):
return self.child_size
pass
def encode(self, values: Sequence) -> bytes:
num = len(values)
if num > self.max_num:
raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}')
data = struct.pack(self.num_fmt, num)
for value in values:
data += self.child_type.encode(value)
return data
def decode(self, data: bytes) -> tuple[list[Any], bytes]:
num = struct.unpack(self.num_fmt, data[:self.num_size])[0]
if num > self.max_num:
raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}')
data = data[self.num_size:]
result = []
for i in range(num):
item, data = self.child_type.decode(data)
result.append(item)
return result, data
def fromjson(self, data):
return [
self.child_type.fromjson(item) for item in data
]
def tojson(self, data):
return [
self.child_type.tojson(item) for item in data
]
def get_c_parts(self):
pre, post = self.child_type.get_c_parts()
return super().get_num_c_code() + "\n" + pre, "[0]" + post
class VarStrFormat(BaseVarFormat):
def __init__(self, max_len):
super().__init__(max_num=max_len)
def get_var_num(self):
return 1
def encode(self, value: str) -> bytes:
num = len(value)
if num > self.max_num:
raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}')
return struct.pack(self.num_fmt, num) + value.encode()
def decode(self, data: bytes) -> tuple[str, bytes]:
num = struct.unpack(self.num_fmt, data[:self.num_size])[0]
if num > self.max_num:
raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}')
return data[self.num_size:self.num_size + num].rstrip(bytes((0,))).decode(), data[self.num_size + num:]
def get_c_parts(self):
return super().get_num_c_code() + "\n" + "char", "[0]"
class VarBytesFormat(BaseVarFormat):
def __init__(self, max_size):
super().__init__(max_num=max_size)
def get_var_num(self):
return 1
def encode(self, value: bytes) -> bytes:
num = len(value)
if num > self.max_num:
raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}')
return struct.pack(self.num_fmt, num) + value
def decode(self, data: bytes) -> tuple[bytes, bytes]:
num = struct.unpack(self.num_fmt, data[:self.num_size])[0]
if num > self.max_num:
raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}')
return data[self.num_size:self.num_size + num].rstrip(bytes((0,))), data[self.num_size + num:]
def get_c_parts(self):
return super().get_num_c_code() + "\n" + "uint8_t", "[0]"
@dataclass
class StructType:
_union_options = {}
union_type_field = None
existing_c_struct = None
c_includes = set()
@classmethod
def get_field_format(cls, attr_name):
attr = getattr(cls, attr_name, None)
fields = [f for f in dataclass_fields(cls) if f.name == attr_name]
if not fields:
raise TypeError(f"{cls}.{attr_name} not a field")
field = fields[0]
type_ = typing.get_type_hints(cls)[attr_name]
if "format" in field.metadata:
field_format = field.metadata["format"]
field_format.set_field_type(type_)
return field_format
if issubclass(type_, StructType):
return type_
raise TypeError('field %s.%s has no format and is no StructType' %
(cls.__class__.__name__, attr_name))
# noinspection PyMethodOverriding
def __init_subclass__(cls, /, union_type_field=None, existing_c_struct=None, c_includes=None, **kwargs):
cls.union_type_field = union_type_field
if c_includes is not None:
cls.c_includes |= set(c_includes)
if cls.existing_c_struct is not None:
# TODO: can we make it possible? does it even make sense?
raise TypeError('subclassing an external c struct is not possible')
cls.existing_c_struct = existing_c_struct
if union_type_field:
if union_type_field in cls._union_options:
raise TypeError('Duplicate union_type_field: %s', union_type_field)
cls._union_options[union_type_field] = {}
f = getattr(cls, union_type_field)
metadata = dict(f.metadata)
metadata['union_discriminator'] = True
f.metadata = metadata
f.repr = False
f.init = False
for attr_name in cls.__dict__.keys():
attr = getattr(cls, attr_name)
if isinstance(attr, Field):
metadata = dict(attr.metadata)
if "defining_class" not in metadata:
metadata["defining_class"] = cls
attr.metadata = metadata
for key, values in cls._union_options.items():
value = kwargs.pop(key, None)
if value is not None:
if value in values:
raise TypeError('Duplicate %s: %s', (key, value))
values[value] = cls
setattr(cls, key, value)
# pydantic model
cls._pydantic_fields = getattr(cls, '_pydantic_fields', {}).copy()
fields = []
for field_ in dataclass_fields(cls):
fields.append((field_.name, field_.type, field_.metadata))
for attr_name in tuple(cls.__annotations__.keys()):
attr = getattr(cls, attr_name, None)
metadata = attr.metadata if isinstance(attr, Field) else {}
try:
type_ = cls.__annotations__[attr_name]
except KeyError:
# print('nope', cls, attr_name)
continue
fields.append((attr_name, type_, metadata))
for name, type_, metadata in fields:
try:
field_format = cls.get_field_format(name)
except TypeError:
# todo: in case of not a field, ignore it?
continue
if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
cls._pydantic_fields[name] = (type_, ...)
else:
if metadata.get("json_embed"):
cls._pydantic_fields.update(type_._pydantic_fields)
else:
cls._pydantic_fields[name] = (type_.schema, ...)
cls.schema = create_model(cls.__name__ + 'Schema', **cls._pydantic_fields)
super().__init_subclass__(**kwargs)
@classmethod
def get_var_num(cls):
return sum([cls.get_field_format(f.name).get_var_num() for f in dataclass_fields(cls)], start=0)
@classmethod
def get_types(cls):
if not cls.union_type_field:
raise TypeError('Not a union class')
return cls._union_options[cls.union_type_field]
@classmethod
def get_type(cls, type_id) -> Self:
if not cls.union_type_field:
raise TypeError('Not a union class')
return cls.get_types()[type_id]
@classmethod
def encode(cls, instance, ignore_fields=()) -> bytes:
data = bytes()
if cls.union_type_field and type(instance) is not cls:
if not isinstance(instance, cls):
raise ValueError('expected value of type %r, got %r' % (cls, instance))
for field_ in dataclass_fields(cls):
data += cls.get_field_format(field_.name).encode(getattr(instance, field_.name))
# todo: better
data += instance.encode(instance, ignore_fields=set(f.name for f in dataclass_fields(cls)))
return data
for field_ in dataclass_fields(cls):
if field_.name in ignore_fields:
continue
value = getattr(instance, field_.name)
field_format = cls.get_field_format(field_.name)
if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
data += field_format.encode(value)
else:
if not isinstance(value, field_.type):
raise ValueError('expected value of type %r for %s.%s, got %r' %
(field_.type, cls.__name__, field_.name, value))
data += value.encode(value)
return data
@classmethod
def decode(cls, data: bytes) -> tuple[Self, bytes]:
orig_data = data
kwargs = {}
no_init_data = {}
for field_ in dataclass_fields(cls):
field_format = cls.get_field_format(field_.name)
if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
value, data = field_format.decode(data)
else:
value, data = field_.type.decode(data)
if field_.init:
kwargs[field_.name] = value
else:
no_init_data[field_.name] = value
if cls.union_type_field:
try:
type_value = no_init_data[cls.union_type_field]
except KeyError:
raise TypeError('union_type_field %s.%s is missing' %
(cls.__name__, cls.union_type_field))
try:
klass = cls.get_type(type_value)
except KeyError:
raise TypeError('union_type_field %s.%s value %r no known' %
(cls.__name__, cls.union_type_field, type_value))
return klass.decode(orig_data)
return cls(**kwargs), data
@classmethod
def tojson(cls, instance) -> dict:
result = {}
if cls.union_type_field and type(instance) is not cls:
if not isinstance(instance, cls):
raise ValueError('expected value of type %r, got %r' % (cls, instance))
for field_ in dataclass_fields(instance):
if field_.name is cls.union_type_field:
result[field_.name] = cls.get_field_format(field_.name).tojson(getattr(instance, field_.name))
break
else:
raise TypeError('couldn\'t find %s value' % cls.union_type_field)
result.update(instance.tojson(instance))
return result
for field_ in dataclass_fields(cls):
value = getattr(instance, field_.name)
field_format = cls.get_field_format(field_.name)
if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
result[field_.name] = field_format.tojson(value)
else:
if not isinstance(value, field_.type):
raise ValueError('expected value of type %r for %s.%s, got %r' %
(field_.type, cls.__name__, field_.name, value))
json_val = value.tojson(value)
if field_.metadata.get("json_embed"):
for k, v in json_val.items():
result[k] = v
else:
result[field_.name] = value.tojson(value)
return result
@classmethod
def upgrade_json(cls, data):
return data
@classmethod
def fromjson(cls, data: dict) -> Self:
data = data.copy()
# todo: upgrade_json
cls.upgrade_json(data)
kwargs = {}
no_init_data = {}
for field_ in dataclass_fields(cls):
raw_value = data.get(field_.name, None)
field_format = cls.get_field_format(field_.name)
if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
value = field_format.fromjson(raw_value)
else:
if field_.metadata.get("json_embed"):
value = field_.type.fromjson(data)
else:
value = field_.type.fromjson(raw_value)
if field_.init:
kwargs[field_.name] = value
else:
no_init_data[field_.name] = value
if cls.union_type_field:
try:
type_value = no_init_data.pop(cls.union_type_field)
except KeyError:
raise TypeError('union_type_field %s.%s is missing' %
(cls.__name__, cls.union_type_field))
try:
klass = cls.get_type(type_value)
except KeyError:
raise TypeError('union_type_field %s.%s value 0x%02x no known' %
(cls.__name__, cls.union_type_field, type_value))
return klass.fromjson(data)
return cls(**kwargs)
@classmethod
def get_c_struct_items(cls, ignore_fields=None, no_empty=False, top_level=False, union_only=False, in_union=False):
ignore_fields = set() if not ignore_fields else set(ignore_fields)
items = []
for field_ in dataclass_fields(cls):
if field_.name in ignore_fields:
continue
if in_union and field_.metadata["defining_class"] != cls:
continue
name = field_.metadata.get("c_name", field_.name)
field_format = cls.get_field_format(field_.name)
if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
if not field_.metadata.get("union_discriminator") or field_.metadata.get("defining_class") == cls:
items.append((
(
("%(typedef_name)s %(name)s;" % {
"typedef_name": field_format.get_typedef_name(),
"name": name,
})
if field_.metadata.get("as_definition")
else field_format.get_c_code(name)
),
field_.metadata.get("doc", None),
)),
else:
if field_.metadata.get("c_embed"):
embedded_items = field_.type.get_c_struct_items(ignore_fields, no_empty, top_level, union_only)
items.extend(embedded_items)
else:
items.append((
(
("%(typedef_name)s %(name)s;" % {
"typedef_name": field_.type.get_typedef_name(),
"name": name,
})
if field_.metadata.get("as_definition")
else field_.type.get_c_code(name, typedef=False)
),
field_.metadata.get("doc", None),
))
if cls.union_type_field:
if not union_only:
union_code = cls.get_c_union_code(ignore_fields)
items.append(("union __packed %s;" % union_code, ""))
return items
@classmethod
def get_c_union_size(cls):
return max(
(option.get_min_size(no_inherited_fields=True) for option in
cls._union_options[cls.union_type_field].values()),
default=0,
)
@classmethod
def get_c_definitions(cls) -> dict[str, str]:
definitions = {}
for field_ in dataclass_fields(cls):
field_format = cls.get_field_format(field_.name)
if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
definitions.update(field_format.get_c_definitions())
if field_.metadata.get("as_definition"):
typedef_name = field_format.get_typedef_name()
definitions[typedef_name] = 'typedef %(code)s %(name)s;' % {
"code": ''.join(field_format.get_c_parts()),
"name": typedef_name,
}
else:
definitions.update(field_.type.get_c_definitions())
if field_.metadata.get("as_definition"):
typedef_name = field_.type.get_typedef_name()
definitions[typedef_name] = field_.type.get_c_code(name=typedef_name, typedef=True)
if cls.union_type_field:
for key, option in cls._union_options[cls.union_type_field].items():
definitions.update(option.get_c_definitions())
return definitions
@classmethod
def get_c_union_code(cls, ignore_fields=None):
union_items = []
for key, option in cls._union_options[cls.union_type_field].items():
base_name = normalize_name(getattr(key, 'name', option.__name__))
item_c_code = option.get_c_code(
base_name, ignore_fields=ignore_fields, typedef=False, in_union=True, no_empty=True
)
if item_c_code:
union_items.append(item_c_code)
size = cls.get_c_union_size()
union_items.append(
"uint8_t bytes[%0d]; " % size
)
return "{\n" + indent_c("\n".join(union_items)) + "\n}"
@classmethod
def get_c_parts(cls, ignore_fields=None, no_empty=False, top_level=False, union_only=False, in_union=False):
if cls.existing_c_struct is not None:
return (cls.existing_c_struct, "")
ignore_fields = set() if not ignore_fields else set(ignore_fields)
if union_only:
if cls.union_type_field:
union_code = cls.get_c_union_code(ignore_fields)
return "typedef union __packed %s" % union_code, ""
else:
return "", ""
pre = ""
items = cls.get_c_struct_items(ignore_fields=ignore_fields,
no_empty=no_empty,
top_level=top_level,
union_only=union_only,
in_union=in_union)
if no_empty and not items:
return "", ""
# todo: struct comment
if top_level:
comment = cls.__doc__.strip()
if comment:
pre += "/** %s */\n" % comment
pre += "typedef struct __packed "
else:
pre += "struct __packed "
pre += "{\n%(elements)s\n}" % {
"elements": indent_c(
"\n".join(
code + ("" if not comment else (" /** %s */" % comment))
for code, comment in items
)
),
}
return pre, ""
@classmethod
def get_c_code(cls, name=None, ignore_fields=None, no_empty=False, typedef=True, union_only=False,
in_union=False) -> str:
pre, post = cls.get_c_parts(ignore_fields=ignore_fields,
no_empty=no_empty,
top_level=typedef,
union_only=union_only,
in_union=in_union)
if no_empty and not pre and not post:
return ""
return "%s %s%s;" % (pre, name, post)
@classmethod
def get_variable_name(cls, base_name):
return base_name
@classmethod
def get_typedef_name(cls):
return "%s_t" % normalize_name(cls.__name__)
@classmethod
def get_min_size(cls, no_inherited_fields=False) -> int:
if cls.union_type_field:
own_size = sum([cls.get_field_format(f.name).get_min_size() for f in dataclass_fields(cls)])
union_size = max(
[0] + [option.get_min_size(True) for option in cls._union_options[cls.union_type_field].values()])
return own_size + union_size
if no_inherited_fields:
relevant_fields = [f for f in dataclass_fields(cls) if f.metadata["defining_class"] == cls]
else:
relevant_fields = [f for f in dataclass_fields(cls) if not f.metadata.get("union_discriminator")]
return sum((cls.get_field_format(f.name).get_min_size() for f in relevant_fields), start=0)
@classmethod
def get_size(cls, no_inherited_fields=False, calculate_max=False) -> int:
if cls.union_type_field:
own_size = sum(
[cls.get_field_format(f.name).get_size(calculate_max=calculate_max) for f in dataclass_fields(cls)])
union_size = max(
[0] + [option.get_size(no_inherited_fields=True, calculate_max=calculate_max) for option in
cls._union_options[cls.union_type_field].values()])
return own_size + union_size
if no_inherited_fields:
relevant_fields = [f for f in dataclass_fields(cls) if f.metadata["defining_class"] == cls]
else:
relevant_fields = [f for f in dataclass_fields(cls) if not f.metadata.get("union_discriminator")]
return sum((cls.get_field_format(f.name).get_size(calculate_max=calculate_max) for f in relevant_fields),
start=0)
def normalize_name(name):
if '_' in name:
name = name.lower()
else:
name = re.sub(
r"(([a-z])([A-Z]))|(([a-zA-Z])([A-Z][a-z]))",
r"\2\5_\3\6",
name
).lower()
name = re.sub(
r"(ota)([a-z])",
r"\1_\2",
name
).lower()
name = name.replace('config', 'cfg')
name = name.replace('position', 'pos')
name = name.replace('mesh_', '')
name = name.replace('firmware', 'fw')
name = name.replace('hardware', 'hw')
return name

961
src/c3nav/mesh/cformats.py Normal file
View file

@ -0,0 +1,961 @@
import re
import struct
import typing
from abc import ABC, abstractmethod
from collections import namedtuple
from contextlib import suppress
from dataclasses import dataclass
from dataclasses import fields as dataclass_fields
from enum import IntEnum, Enum
from typing import Any, Sequence, Self, Annotated, Literal, Union, Type, TypeVar, ClassVar
from annotated_types import SLOTS, BaseMetadata, Ge
from pydantic.fields import Field, FieldInfo
from pydantic_extra_types.mac_address import MacAddress
from c3nav.mesh.utils import indent_c
@dataclass(frozen=True, **SLOTS)
class VarLen(BaseMetadata):
var_len_name: str = "num"
@dataclass(frozen=True, **SLOTS)
class NoDef(BaseMetadata):
no_def: bool = True
@dataclass(frozen=True, **SLOTS)
class AsHex(BaseMetadata):
as_hex: bool = True
@dataclass(frozen=True, **SLOTS)
class LenBytes(BaseMetadata):
len_bytes: Annotated[int, Ge(1)]
@dataclass(frozen=True, **SLOTS)
class AsDefinition(BaseMetadata):
as_definition: bool = True
@dataclass(frozen=True, **SLOTS)
class CEmbed(BaseMetadata):
c_embed: bool = True
@dataclass(frozen=True, **SLOTS)
class CName(BaseMetadata):
c_name: str
@dataclass(frozen=True, **SLOTS)
class CDoc(BaseMetadata):
c_doc: str
@dataclass
class ExistingCStruct():
name: str
includes: list[str]
class CEnum(str, Enum):
def __new__(cls, value, c_value):
obj = str.__new__(cls)
obj._value_ = value
obj.c_value = c_value
return obj
def __hash__(self):
return hash(self.value)
def discriminator_value(**kwargs):
return type('DiscriminatorValue', (), {
# todo: make this so pydantic doesn't throw a warning
**{name: value for name, value in kwargs.items()},
'__annotations__': {
name: Annotated[Literal[value], Field(init=False)]
for name, value in kwargs.items()
}
})
class TwoNibblesEncodable:
pass
class SplitTypeHint(namedtuple("SplitTypeHint", ("base", "metadata"))):
@classmethod
def from_annotation(cls, type_hint) -> Self:
if typing.get_origin(type_hint) is Annotated:
field_infos = tuple(m for m in type_hint.__metadata__ if isinstance(m, FieldInfo))
return cls(
base=typing.get_args(type_hint)[0],
metadata=(
*(m for m in type_hint.__metadata__),
*(tuple(field_infos[0].metadata) if field_infos else ())
)
)
if isinstance(type_hint, FieldInfo):
return cls(
base=type_hint.annotation,
metadata=tuple(type_hint.metadata)
)
return cls(
base=type_hint,
metadata=()
)
def get_len_metadata(self):
max_length = None
var_len_name = None
for m in self.metadata:
ml = getattr(m, 'max_length', None)
if ml is not None:
max_length = ml if max_length is None else min(max_length, ml)
vl = getattr(m, 'var_len_name', None)
if vl is not None:
if var_len_name is not None:
raise ValueError('can\'t set variable length name twice')
var_len_name = vl
return max_length, var_len_name
def get_min_max_metadata(self, default_min=-(2 ** 63), default_max=2 ** 63 - 1):
min_ = default_min
max_ = default_max
for m in self.metadata:
gt = getattr(m, 'gt', None)
if gt is not None:
min_ = max(min_, gt + 1)
ge = getattr(m, 'ge', None)
if ge is not None:
min_ = max(min_, ge)
lt = getattr(m, 'lt', None)
if lt is not None:
max_ = min(max_, lt - 1)
le = getattr(m, 'le', None)
if le is not None:
max_ = min(max_, le)
return min_, max_
def normalize_name(name):
if '_' in name:
name = name.lower()
else:
name = re.sub(
r"(([a-z])([A-Z]))|(([a-zA-Z])([A-Z][a-z]))",
r"\2\5_\3\6",
name
).lower()
name = re.sub(
r"(ota)([a-z])",
r"\1_\2",
name
).lower()
name = name.replace('config', 'cfg')
name = name.replace('position', 'pos')
name = name.replace('mesh_', '')
name = name.replace('firmware', 'fw')
name = name.replace('hardware', 'hw')
return name
class CFormat(ABC):
# todo: make this some cool generic with a TypeVar
def get_var_num(self):
return 0
@abstractmethod
def encode(self, value):
pass
@classmethod
@abstractmethod
def decode(cls, data) -> tuple[Any, bytes]:
pass
@abstractmethod
def get_min_size(self):
pass
@abstractmethod
def get_max_size(self):
pass
def get_size(self, calculate_max=False):
if calculate_max:
return self.get_max_size()
else:
return self.get_min_size()
@abstractmethod
def get_c_parts(self) -> tuple[str, str]:
pass
def get_c_code(self, name) -> str:
pre, post = self.get_c_parts()
return "%s %s%s;" % (pre, name, post)
def get_c_definitions(self) -> dict[str, str]:
return {}
def get_typedef_name(self):
raise TypeError('no typedef for %r' % self)
def get_c_includes(self) -> set:
return set()
@classmethod
def from_annotation(cls, annotation, attr_name=None) -> Self:
if cls is not CFormat:
raise TypeError('call on CFormat!')
return cls.from_split_type_hint(SplitTypeHint.from_annotation(annotation), attr_name=attr_name)
@classmethod
def from_split_type_hint(cls, type_hint: SplitTypeHint, attr_name=None) -> Self:
if cls is not CFormat:
raise TypeError('call on CFormat!')
outer_type_hint = None
if typing.get_origin(type_hint.base) is list:
outer_type_hint = SplitTypeHint(
base=list,
metadata=type_hint.metadata
)
type_hint = SplitTypeHint(
base=typing.get_args(type_hint.base)[0],
metadata=()
)
if typing.get_origin(type_hint.base) is Annotated:
type_hint = SplitTypeHint(
base=typing.get_args(type_hint.base)[0],
metadata=tuple(type_hint.base.__metadata__)
)
field_format = None
if typing.get_origin(type_hint.base) is Literal:
literal_val = typing.get_args(type_hint.base)[0]
if isinstance(literal_val, CEnum):
options = [v.c_value for v in type(literal_val)]
literal_val = literal_val.c_value
int_type = get_int_type(
*type_hint.get_min_max_metadata(default_min=min(options), default_max=max(options))
)
elif isinstance(literal_val, int):
int_type = get_int_type(literal_val, literal_val)
else:
raise ValueError()
if int_type is None:
raise ValueError('invalid range:', attr_name)
field_format = SimpleConstFormat(int_type, const_value=literal_val)
elif typing.get_origin(type_hint.base) is Union:
discriminator = None
for m in type_hint.metadata:
discriminator = getattr(m, 'discriminator', discriminator)
if discriminator is None:
raise ValueError('no discriminator')
discriminator_as_hex = any(getattr(m, "as_hex", False) for m in type_hint.metadata)
field_format = UnionFormat(
model_formats=[StructFormat(type_) for type_ in typing.get_args(type_hint.base)],
discriminator=discriminator,
discriminator_as_hex=discriminator_as_hex,
)
elif type_hint.base is int:
int_type = get_int_type(*type_hint.get_min_max_metadata())
if int_type is None:
raise ValueError('invalid range:', attr_name)
field_format = SimpleFormat(int_type)
elif type_hint.base is bool:
field_format = BoolFormat()
elif type_hint.base in (str, bytes):
as_hex = any(getattr(m, 'as_hex', False) for m in type_hint.metadata)
max_length, var_len_name = type_hint.get_len_metadata()
if max_length is None:
raise ValueError('missing str max_length:', attr_name)
if type_hint.base is str:
if var_len_name is not None:
field_format = VarStrFormat(max_len=max_length)
else:
field_format = FixedHexFormat(max_length//2) if as_hex else FixedStrFormat(max_length)
else:
if var_len_name is None:
field_format = FixedBytesFormat(num=max_length)
else:
field_format = VarBytesFormat(max_size=max_length)
elif type_hint.base is MacAddress:
field_format = MacAddressFormat()
elif isinstance(type_hint.base, type) and issubclass(type_hint.base, CEnum):
no_def = any(getattr(m, 'no_def', False) for m in type_hint.metadata)
as_hex = any(getattr(m, 'as_hex', False) for m in type_hint.metadata)
len_bytes = None
for m in type_hint.metadata:
len_bytes = getattr(m, 'len_bytes', len_bytes)
if len_bytes:
int_type = get_int_type(0, 2 ** (8 * len_bytes - 1))
else:
options = [v.c_value for v in type_hint.base]
int_type = get_int_type(min(options), max(options))
if int_type is None:
raise ValueError('invalid range:', attr_name)
field_format = EnumFormat(enum_cls=type_hint.base, fmt=int_type, as_hex=as_hex, c_definition=not no_def)
elif isinstance(type_hint.base, type) and issubclass(type_hint.base, TwoNibblesEncodable):
field_format = TwoNibblesEnumFormat(type_hint.base)
elif isinstance(type_hint.base, type) and typing.get_type_hints(type_hint.base):
field_format = StructFormat(model=type_hint.base)
if field_format is None:
raise ValueError('Unknown type annotation for c structs', type_hint.base)
else:
if outer_type_hint is not None and outer_type_hint.base is list:
max_length, var_len_name = outer_type_hint.get_len_metadata()
if max_length is None:
raise ValueError('missing list max_length:', attr_name)
if var_len_name:
field_format = VarArrayFormat(field_format, max_num=max_length)
else:
raise ValueError('fixed-len list not implemented:', attr_name)
return field_format
def get_int_type(min_: int, max_: int) -> str | None:
if min_ < 0:
if min_ < -(2 ** 63) or max_ > 2 ** 63 - 1:
return None
elif min_ < -(2 ** 31) or max_ > 2 ** 31 - 1:
return "q"
elif min_ < -(2 ** 15) or max_ > 2 ** 15 - 1:
return "i"
elif min_ < -(2 ** 7) or max_ > 2 ** 7 - 1:
return "h"
else:
return "b"
if max_ > 2 ** 64 - 1:
return None
elif max_ > 2 ** 32 - 1:
return "Q"
elif max_ > 2 ** 16 - 1:
return "I"
elif max_ > 2 ** 8 - 1:
return "H"
else:
return "B"
class SimpleFormat(CFormat):
def __init__(self, fmt):
self.fmt = fmt
self.size = struct.calcsize(fmt)
self.c_type = self.c_types[self.fmt[-1]]
self.num = int(self.fmt[:-1]) if len(self.fmt) > 1 else 1
def encode(self, value):
if self.num == 1:
return struct.pack(self.fmt, value)
return struct.pack(self.fmt, *value)
def decode(self, data: bytes) -> tuple[Any, bytes]:
value = struct.unpack(self.fmt, data[:self.size])
if len(value) == 1:
value = value[0]
return value, data[self.size:]
def get_min_size(self):
return self.size
def get_max_size(self):
return self.size
c_types = {
"B": "uint8_t",
"H": "uint16_t",
"I": "uint32_t",
"Q": "uint64_t",
"b": "int8_t",
"h": "int16_t",
"i": "int32_t",
"q": "int64_t",
"s": "char",
}
def get_c_parts(self):
return self.c_type, ("" if self.num == 1 else ("[%d]" % self.num))
class SimpleConstFormat(SimpleFormat):
def __init__(self, fmt, const_value: int):
super().__init__(fmt)
self.const_value = const_value
def decode(self, data: bytes) -> tuple[Any, bytes]:
value, out_data = super().decode(data)
if value != self.const_value:
raise ValueError('const_value is wrong')
return value, out_data
class EnumFormat(SimpleFormat):
def __init__(self, enum_cls: Type[CEnum], fmt="B", *, as_hex=False, c_definition=True):
super().__init__(fmt)
self.enum_cls = enum_cls
self.enum_lookup = {v.c_value: v for v in enum_cls}
if len(self.enum_cls) != len(self.enum_lookup):
raise ValueError
self.as_hex = as_hex
self.c_definition = c_definition
self.c_struct_name = normalize_name(enum_cls.__name__) + '_t'
def decode(self, data: bytes) -> tuple[Any, bytes]:
value, out_data = super().decode(data)
return self.enum_lookup[value], out_data
def get_typedef_name(self):
return '%s_t' % normalize_name(self.enum_cls.__name__)
def get_c_parts(self):
if not self.c_definition:
return super().get_c_parts()
return self.c_struct_name, ""
def get_c_definitions(self) -> dict[str, str]:
if not self.c_definition:
return {}
prefix = normalize_name(self.enum_cls.__name__).upper()
options = []
last_value = None
for item in self.enum_cls:
if last_value is not None and item.c_value != last_value + 1:
options.append('')
last_value = item.c_value
options.append("%(prefix)s_%(name)s = %(value)s," % {
"prefix": prefix,
"name": normalize_name(item.name).upper(),
"value": ("0x%02x" if self.as_hex else "%d") % item.c_value
})
return {
self.c_struct_name: "enum {\n%(options)s\n};\ntypedef uint8_t %(name)s;" % {
"options": indent_c("\n".join(options)),
"name": self.c_struct_name,
}
}
class TwoNibblesEnumFormat(SimpleFormat):
def __init__(self, data_cls):
self.data_cls = data_cls
super().__init__('B')
def decode(self, data: bytes) -> tuple[bool, bytes]:
fields = dataclass_fields(self.data_cls)
value, data = super().decode(data)
return self.data_cls(fields[0].type(value // 2 ** 4), fields[1].type(value // 2 ** 4)), data
def encode(self, value):
fields = dataclass_fields(self.data_cls)
return super().encode(
getattr(value, fields[0].name).value * 2 ** 4 +
getattr(value, fields[1].name).value * 2 ** 4
)
class BoolFormat(SimpleFormat):
def __init__(self):
super().__init__('B')
def encode(self, value):
return super().encode(int(value))
def decode(self, data: bytes) -> tuple[bool, bytes]:
value, data = super().decode(data)
if value > 1:
raise ValueError('Boolean value > 1')
return bool(value), data
class FixedStrFormat(SimpleFormat):
def __init__(self, num):
self.num = num
super().__init__('%ds' % self.num)
def encode(self, value: str):
return value.encode()[:self.num].ljust(self.num, bytes((0,))),
def decode(self, data: bytes) -> tuple[str, bytes]:
return data[:self.num].rstrip(bytes((0,))).decode(), data[self.num:]
class FixedBytesFormat(SimpleFormat):
def __init__(self, num):
self.num = num
super().__init__('%dB' % self.num)
def encode(self, value: str):
return super().encode(tuple(value))
def decode(self, data: bytes) -> tuple[bytes, bytes]:
return data[:self.num], data[self.num:]
class FixedHexFormat(SimpleFormat):
def __init__(self, num, sep=''):
self.num = num
self.sep = sep
super().__init__('%dB' % self.num)
def encode(self, value: str):
return super().encode(tuple(bytes.fromhex(value.replace(':', ''))))
def decode(self, data: bytes) -> tuple[str, bytes]:
return self.sep.join(('%02x' % i) for i in data[:self.num]), data[self.num:]
class MacAddressFormat(FixedHexFormat):
def __init__(self):
super().__init__(num=6, sep=':')
class BaseVarFormat(CFormat, ABC):
def __init__(self, max_num):
self.num_fmt = 'H'
self.num_size = struct.calcsize(self.num_fmt)
self.max_num = max_num
def get_min_size(self):
return self.num_size
def get_max_size(self):
return self.num_size + self.max_num * self.get_var_num()
def get_num_c_code(self):
return SimpleFormat(self.num_fmt).get_c_code("num")
class VarArrayFormat(BaseVarFormat):
def __init__(self, child_type, max_num):
super().__init__(max_num=max_num)
self.child_type = child_type
self.child_size = self.child_type.get_min_size()
def get_var_num(self):
return self.child_size
pass
def encode(self, values: Sequence) -> bytes:
num = len(values)
if num > self.max_num:
raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}')
data = struct.pack(self.num_fmt, num)
for value in values:
data += self.child_type.encode(value)
return data
def decode(self, data: bytes) -> tuple[list[Any], bytes]:
num = struct.unpack(self.num_fmt, data[:self.num_size])[0]
if num > self.max_num:
raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}')
data = data[self.num_size:]
result = []
for i in range(num):
item, data = self.child_type.decode(data)
result.append(item)
return result, data
def get_c_parts(self):
pre, post = self.child_type.get_c_parts()
return super().get_num_c_code() + "\n" + pre, "[0]" + post
class VarStrFormat(BaseVarFormat):
def __init__(self, max_len):
super().__init__(max_num=max_len)
def get_var_num(self):
return 1
def encode(self, value: str) -> bytes:
num = len(value)
if num > self.max_num:
raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}')
return struct.pack(self.num_fmt, num) + value.encode()
def decode(self, data: bytes) -> tuple[str, bytes]:
num = struct.unpack(self.num_fmt, data[:self.num_size])[0]
if num > self.max_num:
raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}')
return data[self.num_size:self.num_size + num].rstrip(bytes((0,))).decode(), data[self.num_size + num:]
def get_c_parts(self):
return super().get_num_c_code() + "\n" + "char", "[0]"
class VarBytesFormat(BaseVarFormat):
def __init__(self, max_size):
super().__init__(max_num=max_size)
def get_var_num(self):
return 1
def encode(self, value: bytes) -> bytes:
num = len(value)
if num > self.max_num:
raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}')
return struct.pack(self.num_fmt, num) + value
def decode(self, data: bytes) -> tuple[bytes, bytes]:
num = struct.unpack(self.num_fmt, data[:self.num_size])[0]
if num > self.max_num:
raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}')
return data[self.num_size:self.num_size + num].rstrip(bytes((0,))), data[self.num_size + num:]
def get_c_parts(self):
return super().get_num_c_code() + "\n" + "uint8_t", "[0]"
T = TypeVar('T')
class CFormatDecodeError(Exception):
pass
class StructFormat(CFormat):
_format_cache: dict[Type, dict[str, CFormat]] = {}
def __new__(cls, model: Type[T]):
result = cls._format_cache.get(model, None)
if not result:
result = super().__new__(cls)
cls._format_cache.get(model, result)
return result
def __init__(self, model: Type[T]):
self.model = model
self._field_formats = {}
self._as_definition = set()
self._c_embed = set()
self._c_names = {}
self._c_docs = {}
self._no_init_data = set()
for name, type_hint in typing.get_type_hints(self.model, include_extras=True).items():
if type_hint is ClassVar:
continue
type_hint = SplitTypeHint.from_annotation(type_hint)
if any(getattr(m, "as_definition", False) for m in type_hint.metadata):
self._as_definition.add(name)
if any(getattr(m, "c_embed", False) for m in type_hint.metadata):
self._c_embed.add(name)
if not all(getattr(m, "init", True) for m in type_hint.metadata):
self._no_init_data.add(name)
for m in type_hint.metadata:
with suppress(AttributeError):
self._c_names[name] = m.c_name
with suppress(AttributeError):
self._c_docs[name] = m.c_doc
self._field_formats[name] = CFormat.from_split_type_hint(type_hint, attr_name=name)
def get_var_num(self):
return sum([field_format.get_var_num() for name, field_format in self._field_formats.items()], start=0)
def encode(self, instance: T, ignore_fields=()) -> bytes:
data = bytes()
for name, field_format in self._field_formats.items():
if name in ignore_fields:
continue
data += field_format.encode(getattr(instance, name))
return data
def decode(self, data: bytes) -> tuple[T, bytes]:
decoded = {}
for name, field_format in self._field_formats.items():
try:
value, data = field_format.decode(data)
except (struct.error, UnicodeDecodeError, ValueError) as e:
raise CFormatDecodeError(f"failed to decode model={self.model}, field={name}, data={data}, e={e}")
if isinstance(value, CEnum):
value = value.value
if name not in self._no_init_data:
decoded[name] = value
return self.model.model_validate(decoded), data
def get_min_size(self) -> int:
return sum((
field_format.get_min_size() for field_format in self._field_formats.values()
), start=0)
def get_max_size(self) -> int:
raise ValueError
def get_size(self, calculate_max=False):
return sum((
field_format.get_size(calculate_max=calculate_max) for field_format in self._field_formats.values()
), start=0)
def get_c_struct_items(self, ignore_fields=None, no_empty=False, top_level=False):
ignore_fields = set() if not ignore_fields else set(ignore_fields)
items = []
for name, field_format in self._field_formats.items():
if name in ignore_fields:
continue
c_name = self._c_names.get(name, name)
if not isinstance(field_format, (StructFormat, UnionFormat)):
items.append((
(
("%(typedef_name)s %(name)s;" % {
"typedef_name": field_format.get_typedef_name(),
"name": c_name,
})
if name in self._as_definition
else field_format.get_c_code(c_name)
),
self._c_docs.get(name, None),
)),
else:
if name in self._c_embed:
embedded_items = field_format.get_c_struct_items(ignore_fields, no_empty, top_level)
items.extend(embedded_items)
else:
items.append((
(
("%(typedef_name)s %(name)s;" % {
"typedef_name": field_format.get_typedef_name(),
"name": c_name,
})
if name in self._as_definition
else field_format.get_c_code(c_name, typedef=False)
),
self._c_docs.get(name, None),
))
return items
def get_c_parts(self, ignore_fields=None, no_empty=False, top_level=False) -> tuple[str, str]:
with suppress(AttributeError):
return (self.model.existing_c_struct.name, "")
ignore_fields = set() if not ignore_fields else set(ignore_fields)
pre = ""
items = self.get_c_struct_items(ignore_fields=ignore_fields,
no_empty=no_empty,
top_level=top_level)
if no_empty and not items:
return "", ""
if top_level:
comment = self.model.__doc__.strip()
if comment:
pre += "/** %s */\n" % comment
pre += "typedef struct __packed "
else:
pre += "struct __packed "
pre += "{\n%(elements)s\n}" % {
"elements": indent_c(
"\n".join(
code + ("" if not comment else (" /** %s */" % comment))
for code, comment in items
)
),
}
return pre, ""
def get_c_code(self, name=None, ignore_fields=None, no_empty=False, typedef=True) -> str:
pre, post = self.get_c_parts(ignore_fields=ignore_fields,
no_empty=no_empty,
top_level=typedef)
if no_empty and not pre and not post:
return ""
return "%s %s%s;" % (pre, name, post)
def get_c_definitions(self) -> dict[str, str]:
definitions = {}
for name, field_format in self._field_formats.items():
definitions.update(field_format.get_c_definitions())
if name in self._as_definition:
typedef_name = field_format.get_typedef_name()
if not isinstance(field_format, StructFormat):
definitions[typedef_name] = 'typedef %(code)s %(name)s;' % {
"code": ''.join(field_format.get_c_parts()),
"name": typedef_name,
}
else:
definitions[typedef_name] = field_format.get_c_code(name=typedef_name, typedef=True)
return definitions
def get_typedef_name(self):
return "%s_t" % normalize_name(self.model.__name__)
def get_c_includes(self) -> set:
result = set()
with suppress(AttributeError):
result.update(self.model.existing_c_struct.includes)
for field_format in self._field_formats.values():
result.update(field_format.get_c_includes())
return result
class UnionFormat(CFormat):
def __init__(self, model_formats: Sequence[StructFormat], discriminator: str, discriminator_as_hex: bool = False):
self.discriminator = discriminator
models = {
getattr(model_format.model, discriminator): model_format for model_format in model_formats
}
if len(models) != len(model_formats):
raise ValueError
types = set(type(value) for value in models.keys())
if len(types) != 1:
raise ValueError
discriminator_annotation = tuple(types)[0]
if discriminator_as_hex:
discriminator_annotation = Annotated[discriminator_annotation, AsHex()]
self.discriminator_format = CFormat.from_annotation(discriminator_annotation)
self.key_to_name = {value.c_value: value.name for value in models.keys()}
self.models = {value.c_value: model_format for value, model_format in models.items()}
def get_var_num(self):
return 0 # todo: is this always correct?
def encode(self, instance) -> bytes:
discriminator_value = getattr(instance, self.discriminator)
try:
model_format = self.models[discriminator_value.c_value]
except KeyError:
raise ValueError('Unknown discriminator value for Union: %r' % discriminator_value)
if not isinstance(instance, model_format.model):
raise ValueError('Unknown value for Union discriminator %r: %r' % (discriminator_value, instance))
return (
self.discriminator_format.encode(discriminator_value.c_value)
+ model_format.encode(instance, ignore_fields=(self.discriminator, ))
)
def decode(self, data: bytes) -> tuple[T, bytes]:
discriminator_value, remaining_data = self.discriminator_format.decode(data)
return self.models[discriminator_value.c_value].decode(data)
def get_min_size(self) -> int:
return max([0] + [
model_format.get_min_size()
for model_format in self.models.values()
])
def get_max_size(self) -> int:
raise ValueError
def get_size(self=False, calculate_max=False):
return max([0] + [
field_format.get_size(calculate_max=calculate_max)
for field_format in self.models.values()
])
def get_c_struct_items(self, ignore_fields=None, no_empty=False, top_level=False):
return [
(self.discriminator_format.get_c_code(self.discriminator), None),
("union __packed %s;" % self.get_c_union_code(), None),
]
def get_c_union_size(self):
return max(
(model_format.get_min_size() for model_format in self.models.values()),
default=0,
) - self.discriminator_format.get_min_size()
def get_c_union_code(self):
union_items = []
for key, model_format in self.models.items():
base_name = normalize_name(self.key_to_name[key])
item_c_code = model_format.get_c_code(
base_name, ignore_fields=(self.discriminator, ), typedef=False, no_empty=True
)
if item_c_code:
union_items.append(item_c_code)
size = self.get_c_union_size()
union_items.append(
"uint8_t bytes[%0d];" % size
)
return "{\n" + indent_c("\n".join(union_items)) + "\n}"
def get_c_parts(self, ignore_fields=None, no_empty=False, top_level=False) -> tuple[str, str]:
items = self.get_c_struct_items(no_empty=no_empty,
top_level=top_level)
if no_empty and not items:
return "", ""
if top_level:
pre = "typedef struct __packed "
else:
pre = "struct __packed "
pre += "{\n%(elements)s\n}" % {
"elements": indent_c(
"\n".join(
code + ("" if not comment else (" /** %s */" % comment))
for code, comment in items
)
),
}
return pre, ""
def get_c_code(self, name=None, ignore_fields=None, no_empty=False, typedef=True,) -> str:
pre, post = self.get_c_parts(ignore_fields=ignore_fields,
no_empty=no_empty,
top_level=typedef)
if no_empty and not pre and not post:
return ""
return "%s %s%s;" % (pre, name, post)
def get_c_definitions(self) -> dict[str, str]:
definitions = {}
definitions.update(self.discriminator_format.get_c_definitions())
for model_format in self.models.values():
definitions.update(model_format.get_c_definitions())
return definitions
def get_typedef_name(self):
names = [model_format.model.__name__ for model_format in self.models.values()]
min_len = min(len(name) for name in names)
longest_prefix = ''
longest_suffix = ''
for i in reversed(range(min_len)):
a = set(name[:i] for name in names)
if len(a) == 1:
longest_prefix = tuple(a)[0]
break
for i in reversed(range(min_len)):
a = set(name[-i:] for name in names)
if len(a) == 1:
longest_suffix = tuple(a)[0]
break
return "%s_t" % normalize_name(longest_prefix if len(longest_prefix) > len(longest_suffix) else longest_suffix)
def get_c_includes(self) -> set:
result = set()
result.update(self.discriminator_format.get_c_includes())
for model_format in self.models.values():
result.update(model_format.get_c_includes())
return result

View file

@ -17,6 +17,7 @@ from django.utils import timezone
from django.utils.crypto import constant_time_compare
from c3nav.mesh import messages
from c3nav.mesh.cformats import CFormat
from c3nav.mesh.messages import (MESH_BROADCAST_ADDRESS, MESH_NONE_ADDRESS, MESH_ROOT_ADDRESS, OTA_CHUNK_SIZE,
MeshMessage, MeshMessageType, OTAApplyMessage, OTASettingMessage)
from c3nav.mesh.models import MeshNode, MeshUplink, NodeMessage, OTARecipientStatus, OTAUpdate, OTAUpdateRecipient
@ -46,6 +47,7 @@ class NodeState:
class MeshConsumer(AsyncWebsocketConsumer):
mesh_msg_format = CFormat.from_annotation(MeshMessage)
def __init__(self):
super().__init__()
self.uplink = None
@ -56,6 +58,7 @@ class MeshConsumer(AsyncWebsocketConsumer):
self.ota_send_task = None
self.ota_chunks: dict[int, set[int]] = {} # keys are update IDs, values are a list of chunk IDs
self.ota_chunks_available_condition = asyncio.Condition()
self.accepted = None
async def connect(self):
self.headers = dict(self.scope["headers"])
@ -68,13 +71,16 @@ class MeshConsumer(AsyncWebsocketConsumer):
self.ping_task = get_event_loop().create_task(self.ping_regularly())
self.check_node_state_task = get_event_loop().create_task(self.check_node_states())
self.ota_send_task = get_event_loop().create_task(self.ota_send())
self.accepted = True
async def disconnect(self, close_code):
if not self.accepted:
return
self.ping_task.cancel()
self.check_node_state_task.cancel()
self.ota_send_task.cancel()
await self.log_text(self.uplink.node, "mesh websocket disconnected")
if self.uplink is not None:
await self.log_text(self.uplink.node, "mesh websocket disconnected")
# leave broadcast group
await self.channel_layer.group_discard("mesh_comm_broadcast", self.channel_name)
@ -91,10 +97,10 @@ class MeshConsumer(AsyncWebsocketConsumer):
end_reason=MeshUplink.EndReason.CLOSED
)
async def send_msg(self, msg, sender=None, exclude_uplink_address=None):
async def send_msg(self, msg: MeshMessage, sender=None, exclude_uplink_address=None):
# print("sending", msg, MeshMessage.encode(msg).hex(' ', 1))
# self.log_text(msg.dst, "sending %s" % msg)
await self.send(bytes_data=MeshMessage.encode(msg))
# self.log_text(msg_envelope.dst, "sending %s" % msg)
await self.send(bytes_data=self.mesh_msg_format.encode(msg))
await self.channel_layer.group_send("mesh_msg_sent", {
"type": "mesh.msg_sent",
"timestamp": timezone.now().strftime("%d.%m.%y %H:%M:%S.%f"),
@ -113,18 +119,20 @@ class MeshConsumer(AsyncWebsocketConsumer):
if bytes_data is None:
return
try:
msg, data = messages.MeshMessage.decode(bytes_data)
msg, data = self.mesh_msg_format.decode(bytes_data)
msg: MeshMessage
except Exception:
print("Unable to decode: ")
print("Unable to decode: msg_type=", hex(bytes_data[12]))
print(bytes_data)
traceback.print_exc()
return
msg.content = msg.content
# print(msg)
if msg.dst != messages.MESH_ROOT_ADDRESS and msg.dst != messages.MESH_PARENT_ADDRESS:
# message not adressed to us, forward it
print('Received message for forwarding:', msg)
print('Received message for forwarding:', msg.content)
if not self.uplink:
await self.log_text(None, "received message not for us before sign in message, ignoring...")
@ -132,12 +140,12 @@ class MeshConsumer(AsyncWebsocketConsumer):
return
# trace messages collect node adresses before forwarding
if isinstance(msg, messages.MeshRouteTraceMessage):
if isinstance(msg.content, messages.MeshRouteTraceMessage):
print('adding ourselves to trace message before forwarding')
await self.log_text(MESH_ROOT_ADDRESS, "adding ourselves to trace message before forwarding")
msg.trace.append(MESH_ROOT_ADDRESS)
msg.content.trace.append(MESH_ROOT_ADDRESS)
result = await msg.send(exclude_uplink_address=self.uplink.node.address)
result = await msg.content.send(exclude_uplink_address=self.uplink.node.address)
if not result:
print('message had no route')
@ -154,7 +162,7 @@ class MeshConsumer(AsyncWebsocketConsumer):
src_node, created = await MeshNode.objects.aget_or_create(address=msg.src)
if isinstance(msg, messages.MeshSigninMessage):
if isinstance(msg.content, messages.MeshSigninMessage):
if not self.check_valid_address(msg.src):
print('reject node with invalid address address')
await self.close()
@ -172,10 +180,12 @@ class MeshConsumer(AsyncWebsocketConsumer):
await self.log_received_message(src_node, msg)
# inform signed in uplink node about its layer
await self.send_msg(messages.MeshLayerAnnounceMessage(
await self.send_msg(messages.MeshMessage(
src=messages.MESH_ROOT_ADDRESS,
dst=msg.src,
content=messages.MeshLayerAnnounceMessage(
layer=messages.NO_LAYER
)
))
# add signed in uplink node to broadcast group
@ -186,7 +196,7 @@ class MeshConsumer(AsyncWebsocketConsumer):
# add this node as a destination that this uplink handles (duh)
await self.add_dst_nodes(nodes=(src_node, ))
self.dst_nodes[msg.src].last_msg[MeshMessageType.MESH_SIGNIN] = msg
self.dst_nodes[msg.src].last_msg[MeshMessageType.MESH_SIGNIN] = msg.content
return
@ -202,38 +212,42 @@ class MeshConsumer(AsyncWebsocketConsumer):
except KeyError:
print('unexpected message from', msg.src)
return
node_status.last_msg[msg.msg_type] = msg
node_status.last_msg[msg.content.msg_type] = msg.content
if isinstance(msg, messages.MeshAddDestinationsMessage):
result = await self.add_dst_nodes(addresses=msg.addresses)
if isinstance(msg.content, messages.MeshAddDestinationsMessage):
result = await self.add_dst_nodes(addresses=msg.content.addresses)
if not result:
print('disconnecting node that send invalid destinations', msg)
print('disconnecting node that send invalid destinations', msg.content)
await self.close()
if isinstance(msg, messages.MeshRemoveDestinationsMessage):
await self.remove_dst_nodes(addresses=msg.addresses)
if isinstance(msg.content, messages.MeshRemoveDestinationsMessage):
await self.remove_dst_nodes(addresses=msg.content.addresses)
if isinstance(msg, messages.MeshRouteRequestMessage):
if msg.address == MESH_ROOT_ADDRESS:
if isinstance(msg.content, messages.MeshRouteRequestMessage):
if msg.content.address == MESH_ROOT_ADDRESS:
await self.log_text(MESH_ROOT_ADDRESS, "route request about us, start a trace")
await self.send_msg(messages.MeshRouteTraceMessage(
await self.send_msg(messages.MeshMessage(
src=MESH_ROOT_ADDRESS,
dst=msg.src,
request_id=msg.request_id,
content=messages.MeshRouteTraceMessage(
request_id=msg.content.request_id,
trace=[MESH_ROOT_ADDRESS],
)
))
else:
await self.log_text(MESH_ROOT_ADDRESS, "route request about someone else, sending response")
self.open_requests.add(msg.request_id)
uplink = database_sync_to_async(MeshNode.get_node_and_uplink)(msg.address)
await self.send_msg(messages.MeshRouteResponseMessage(
self.open_requests.add(msg.content.request_id)
uplink = database_sync_to_async(MeshNode.get_node_and_uplink)(msg.content.address)
await self.send_msg(messages.MeshMessage(
src=MESH_ROOT_ADDRESS,
dst=msg.src,
request_id=msg.request_id,
content=messages.MeshRouteResponseMessage(
request_id=msg.content.request_id,
route=uplink.node_id if uplink else MESH_NONE_ADDRESS,
)
))
if isinstance(msg, (messages.ConfigHardwareMessage,
if isinstance(msg.content, (messages.ConfigHardwareMessage,
messages.ConfigFirmwareMessage,
messages.ConfigBoardMessage)):
if (node_status.waiting_for == NodeWaitingFor.CONFIG and
@ -243,16 +257,16 @@ class MeshConsumer(AsyncWebsocketConsumer):
print('got all config, checking ota')
await self.check_ota([msg.src], first_time=True)
if isinstance(msg, messages.OTAStatusMessage):
print('got OTA status', msg)
node_status.reported_ota_update = msg.update_id
if isinstance(msg.content, messages.OTAStatusMessage):
print('got OTA status', msg.content)
node_status.reported_ota_update = msg.content.update_id
if node_status.waiting_for == NodeWaitingFor.OTA_START_STOP:
update_id = node_status.ota_recipient.update_id if node_status.ota_recipient else 0
if update_id == msg.update_id:
if update_id == msg.content.update_id:
print('start/cancel confirmed!')
node_status.waiting_for = NodeWaitingFor.NOTHING
if update_id:
if msg.status.is_failed:
if msg.content.status.is_failed:
print('ota failed')
node_status.ota_recipient.status = OTARecipientStatus.FAILED
await node_status.ota_recipient.send_status()
@ -263,14 +277,14 @@ class MeshConsumer(AsyncWebsocketConsumer):
else:
print('queue chunk sending')
await self.ota_set_chunks(node_status.ota_recipient.update,
min_chunk=msg.next_expected_chunk)
min_chunk=msg.content.next_expected_chunk)
if isinstance(msg, messages.OTARequestFragmentsMessage):
print('got OTA fragment request', msg)
if isinstance(msg.content, messages.OTARequestFragmentsMessage):
print('got OTA fragment request', msg.content)
desired_update_id = node_status.ota_recipient.update_id if node_status.ota_recipient else 0
if desired_update_id and msg.update_id == desired_update_id:
if desired_update_id and msg.content.update_id == desired_update_id:
print('queue requested chunk sending')
await self.ota_set_chunks(node_status.ota_recipient.update, chunks=set(msg.chunks))
await self.ota_set_chunks(node_status.ota_recipient.update, chunks=set(msg.content.chunks))
@database_sync_to_async
def create_uplink_in_database(self, address):
@ -334,7 +348,7 @@ class MeshConsumer(AsyncWebsocketConsumer):
self.uplink.node.address, "we're the route for this message but it came from here so... no"
)
return
await self.send_msg(MeshMessage.fromjson(data["msg"]), data["sender"])
await self.send_msg(MeshMessage.model_validate(data["msg"]), data["sender"])
async def mesh_ota_recipients_changed(self, data):
addresses = set(data["addresses"]) & set(self.dst_nodes.keys())
@ -349,7 +363,7 @@ class MeshConsumer(AsyncWebsocketConsumer):
"""
async def log_received_message(self, src_node: MeshNode, msg: messages.MeshMessage):
as_json = MeshMessage.tojson(msg)
as_json = msg.model_dump()
await self.channel_layer.group_send("mesh_msg_received", {
"type": "mesh.msg_received",
"timestamp": timezone.now().strftime("%d.%m.%y %H:%M:%S.%f"),
@ -360,7 +374,7 @@ class MeshConsumer(AsyncWebsocketConsumer):
await NodeMessage.objects.acreate(
uplink=self.uplink,
src_node=src_node,
message_type=msg.msg_type.name,
message_type=msg.content.msg_type.name,
data=as_json,
)
@ -435,29 +449,34 @@ class MeshConsumer(AsyncWebsocketConsumer):
node_state.last_sent = timezone.now()
print('request config dump, attempt #%d' % node_state.attempt)
node_state.attempt += 1
await self.send_msg(messages.ConfigDumpMessage(
await self.send_msg(messages.MeshMessage(
src=MESH_ROOT_ADDRESS,
dst=address,
content=messages.ConfigDumpMessage()
))
case NodeWaitingFor.OTA_START_STOP:
node_state.last_sent = timezone.now()
if node_state.ota_recipient:
print('starting ota, attempt #%d' % node_state.attempt)
await self.send_msg(messages.OTAStartMessage(
await self.send_msg(messages.MeshMessage(
src=MESH_ROOT_ADDRESS,
dst=address,
content=messages.OTAStartMessage(
update_id=node_state.ota_recipient.update_id, # noqa
total_bytes=node_state.ota_recipient.update.build.binary.size,
auto_apply=False,
auto_reboot=False,
)
))
else:
print('canceling ota, attempt #%d' % node_state.attempt)
await self.send_msg(messages.OTAAbortMessage(
await self.send_msg(messages.MeshMessage(
src=MESH_ROOT_ADDRESS,
dst=address,
content=messages.OTAAbortMessage(
update_id=0,
)
))
async def check_node_states(self):
@ -511,12 +530,14 @@ class MeshConsumer(AsyncWebsocketConsumer):
with self.dst_nodes[recipients[0]].ota_recipient.update.build.binary.open('rb') as f:
f.seek(chunk * OTA_CHUNK_SIZE)
data = f.read(OTA_CHUNK_SIZE)
await self.send_msg(messages.OTAFragmentMessage(
await self.send_msg(messages.MeshMessage(
src=MESH_ROOT_ADDRESS,
dst=recipients[0] if len(recipients) == 1 else MESH_BROADCAST_ADDRESS,
content=messages.OTAFragmentMessage(
update_id=update_id,
chunk=chunk,
data=data,
)
))
# wait a bit until we send more
@ -681,7 +702,7 @@ class MeshUIConsumer(AsyncJsonWebsocketConsumer):
self.msg_received_filter = {"request_id": msg_to_send["msg_data"]["request_id"]}
for recipient in msg_to_send["recipients"]:
await MeshMessage.fromjson({
await MeshMessage.model_validate({
'dst': recipient,
**msg_to_send["msg_data"],
}).send(sender=self.channel_name)

View file

@ -1,285 +0,0 @@
import re
from dataclasses import dataclass, field
from enum import IntEnum, unique
from typing import BinaryIO, Self
from c3nav.api.utils import EnumSchemaByNameMixin
from c3nav.mesh.baseformats import (BoolFormat, ChipRevFormat, EnumFormat, FixedHexFormat, FixedStrFormat,
SimpleConstFormat, SimpleFormat, StructType, TwoNibblesEnumFormat, VarArrayFormat)
class MacAddressFormat(FixedHexFormat):
def __init__(self):
super().__init__(num=6, sep=':')
class MacAddressesListFormat(VarArrayFormat):
def __init__(self, max_num):
super().__init__(child_type=MacAddressFormat(), max_num=max_num)
@unique
class LedType(IntEnum):
NONE = 0
SERIAL = 1
MULTIPIN = 2
@property
def pretty_name(self):
return self.name.lower()
@unique
class SerialLedType(IntEnum):
WS2812 = 1
SK6812 = 2
@dataclass
class LedConfig(StructType, union_type_field="led_type"):
"""
configuration for an optional connected status LED
"""
led_type: LedType = field(metadata={"format": EnumFormat(), "c_name": "type"})
@dataclass
class NoLedConfig(LedConfig, led_type=LedType.NONE):
pass
@dataclass
class SerialLedConfig(LedConfig, led_type=LedType.SERIAL):
serial_led_type: SerialLedType = field(metadata={"format": EnumFormat(), "c_name": "type"})
gpio: int = field(metadata={"format": SimpleFormat('B')})
@dataclass
class MultipinLedConfig(LedConfig, led_type=LedType.MULTIPIN):
gpio_red: int = field(metadata={"format": SimpleFormat('B')})
gpio_green: int = field(metadata={"format": SimpleFormat('B')})
gpio_blue: int = field(metadata={"format": SimpleFormat('B')})
@dataclass
class BoardSPIConfig(StructType):
"""
configuration for spi bus used for ETH or UWB
"""
gpio_miso: int = field(metadata={"format": SimpleFormat('B')})
gpio_mosi: int = field(metadata={"format": SimpleFormat('B')})
gpio_clk: int = field(metadata={"format": SimpleFormat('B')})
@dataclass
class UWBConfig(StructType):
"""
configuration for the connection to the UWB module
"""
enable: bool = field(metadata={"format": BoolFormat()})
gpio_cs: int = field(metadata={"format": SimpleFormat('B')})
gpio_irq: int = field(metadata={"format": SimpleFormat('B')})
gpio_rst: int = field(metadata={"format": SimpleFormat('B')})
gpio_wakeup: int = field(metadata={"format": SimpleFormat('B')})
gpio_exton: int = field(metadata={"format": SimpleFormat('B')})
@dataclass
class UplinkEthConfig(StructType):
"""
configuration for the connection to the ETH module
"""
enable: bool = field(metadata={"format": BoolFormat()})
gpio_cs: int = field(metadata={"format": SimpleFormat('B')})
gpio_int: int = field(metadata={"format": SimpleFormat('B')})
gpio_rst: int = field(metadata={"format": SimpleFormat('b')})
@unique
class BoardType(EnumSchemaByNameMixin, IntEnum):
CUSTOM = 0x00
# devboards
ESP32_C3_DEVKIT_M_1 = 0x01
ESP32_C3_32S = 2
# custom boards
C3NAV_UWB_BOARD = 0x10
C3NAV_LOCATION_PCB_REV_0_1 = 0x11
C3NAV_LOCATION_PCB_REV_0_2 = 0x12
@property
def pretty_name(self):
if self.name.startswith('ESP32'):
return self.name.replace('_', '-').replace('DEVKIT-', 'DevKit')
if self.name.startswith('C3NAV'):
name = self.name.replace('_', ' ').lower()
name = name.replace('uwb', 'UWB').replace('pcb', 'PCB')
name = re.sub(r'[0-9]+( [0-9+])+', lambda s: s[0].replace(' ', '.'), name)
name = re.sub(r'rev.*', lambda s: s[0].replace(' ', ''), name)
return name
return self.name
@dataclass
class BoardConfig(StructType, union_type_field="board"):
board: BoardType = field(metadata={"format": EnumFormat(as_hex=True)})
@dataclass
class CustomBoardConfig(BoardConfig, board=BoardType.CUSTOM):
spi: BoardSPIConfig = field(metadata={"as_definition": True})
uwb: UWBConfig = field(metadata={"as_definition": True})
eth: UplinkEthConfig = field(metadata={"as_definition": True})
led: LedConfig = field(metadata={"as_definition": True})
@dataclass
class DevkitMBoardConfig(BoardConfig, board=BoardType.ESP32_C3_DEVKIT_M_1):
spi: BoardSPIConfig = field(metadata={"as_definition": True})
uwb: UWBConfig = field(metadata={"as_definition": True})
eth: UplinkEthConfig = field(metadata={"as_definition": True})
@dataclass
class Esp32SBoardConfig(BoardConfig, board=BoardType.ESP32_C3_32S):
spi: BoardSPIConfig = field(metadata={"as_definition": True})
uwb: UWBConfig = field(metadata={"as_definition": True})
eth: UplinkEthConfig = field(metadata={"as_definition": True})
@dataclass
class UwbBoardConfig(BoardConfig, board=BoardType.C3NAV_UWB_BOARD):
eth: UplinkEthConfig = field(metadata={"as_definition": True})
@dataclass
class LocationPCBRev0Dot1BoardConfig(BoardConfig, board=BoardType.C3NAV_LOCATION_PCB_REV_0_1):
eth: UplinkEthConfig = field(metadata={"as_definition": True})
@dataclass
class LocationPCBRev0Dot2BoardConfig(BoardConfig, board=BoardType.C3NAV_LOCATION_PCB_REV_0_2):
eth: UplinkEthConfig = field(metadata={"as_definition": True})
@dataclass
class RangeResultItem(StructType):
peer: str = field(metadata={"format": MacAddressFormat()})
rssi: int = field(metadata={"format": SimpleFormat('b')})
distance: int = field(metadata={"format": SimpleFormat('h')})
@dataclass
class RawFTMEntry(StructType):
dlog_token: int = field(metadata={"format": SimpleFormat('B')})
rssi: int = field(metadata={"format": SimpleFormat('b')})
rtt: int = field(metadata={"format": SimpleFormat('I')})
t1: int = field(metadata={"format": SimpleFormat('Q')})
t2: int = field(metadata={"format": SimpleFormat('Q')})
t3: int = field(metadata={"format": SimpleFormat('Q')})
t4: int = field(metadata={"format": SimpleFormat('Q')})
@dataclass
class FirmwareAppDescription(StructType, existing_c_struct="esp_app_desc_t", c_includes=['<esp_app_desc.h>']):
magic_word: int = field(metadata={"format": SimpleConstFormat('I', 0xAB_CD_54_32)}, repr=False)
secure_version: int = field(metadata={"format": SimpleFormat('I')})
reserv1: list[int] = field(metadata={"format": SimpleFormat('2I')}, repr=False)
version: str = field(metadata={"format": FixedStrFormat(32)})
project_name: str = field(metadata={"format": FixedStrFormat(32)})
compile_time: str = field(metadata={"format": FixedStrFormat(16)})
compile_date: str = field(metadata={"format": FixedStrFormat(16)})
idf_version: str = field(metadata={"format": FixedStrFormat(32)})
app_elf_sha256: str = field(metadata={"format": FixedHexFormat(32)})
reserv2: list[int] = field(metadata={"format": SimpleFormat('20I')}, repr=False)
@unique
class SPIFlashMode(EnumSchemaByNameMixin, IntEnum):
QIO = 0
QOUT = 1
DIO = 2
DOUT = 3
@unique
class FlashSize(EnumSchemaByNameMixin, IntEnum):
SIZE_1MB = 0
SIZE_2MB = 1
SIZE_4MB = 2
SIZE_8MB = 3
SIZE_16MB = 4
SIZE_32MB = 5
SIZE_64MB = 6
SIZE_128MB = 7
@property
def pretty_name(self):
return self.name.removeprefix('SIZE_')
@unique
class FlashFrequency(EnumSchemaByNameMixin, IntEnum):
FREQ_40MHZ = 0
FREQ_26MHZ = 1
FREQ_20MHZ = 2
FREQ_80MHZ = 0xf
@property
def pretty_name(self):
return self.name.removeprefix('FREQ_').replace('MHZ', 'Mhz')
@dataclass
class FlashSettings:
size: FlashSize
frequency: FlashFrequency
@property
def display(self):
return f"{self.size.pretty_name} ({self.frequency.pretty_name})"
@unique
class ChipType(EnumSchemaByNameMixin, IntEnum):
ESP32_S2 = 2
ESP32_C3 = 5
@property
def pretty_name(self):
return self.name.replace('_', '-')
@dataclass
class FirmwareImageFileHeader(StructType):
magic_word: int = field(metadata={"format": SimpleConstFormat('B', 0xE9)}, repr=False)
num_segments: int = field(metadata={"format": SimpleFormat('B')})
spi_flash_mode: SPIFlashMode = field(metadata={"format": EnumFormat()})
flash_stuff: FlashSettings = field(metadata={"format": TwoNibblesEnumFormat()})
entry_point: int = field(metadata={"format": SimpleFormat('I')})
@dataclass
class FirmwareImageExtendedFileHeader(StructType):
wp_pin: int = field(metadata={"format": SimpleFormat('B')})
drive_settings: int = field(metadata={"format": SimpleFormat('3B')})
chip: ChipType = field(metadata={"format": EnumFormat('H')})
min_chip_rev_old: int = field(metadata={"format": SimpleFormat('B')})
min_chip_rev: tuple[int, int] = field(metadata={"format": ChipRevFormat()})
max_chip_rev: tuple[int, int] = field(metadata={"format": ChipRevFormat()})
reserv: int = field(metadata={"format": SimpleFormat('I')}, repr=False)
hash_appended: bool = field(metadata={"format": BoolFormat()})
@dataclass
class FirmwareImage(StructType):
header: FirmwareImageFileHeader
ext_header: FirmwareImageExtendedFileHeader
first_segment_headers: tuple[int, int] = field(metadata={"format": SimpleFormat('2I')}, repr=False)
app_desc: FirmwareAppDescription
@classmethod
def from_file(cls, file: BinaryIO) -> Self:
result, data = cls.decode(file.read(FirmwareImage.get_min_size()))
return result

View file

@ -13,11 +13,15 @@ from django.db import transaction
from django.forms import BooleanField, ChoiceField, Form, ModelMultipleChoiceField, MultipleChoiceField
from django.http import Http404
from django.utils.translation import gettext_lazy as _
from pydantic import ValidationError as PydanticValidationError
from pydantic.type_adapter import TypeAdapter
from c3nav.mesh.dataformats import BoardConfig, BoardType, LedType, SerialLedType
from c3nav.mesh.messages import MESH_BROADCAST_ADDRESS, MESH_ROOT_ADDRESS, MeshMessage, MeshMessageType
from c3nav.mesh.cformats import CFormat
from c3nav.mesh.messages import (MESH_BROADCAST_ADDRESS, MESH_ROOT_ADDRESS, MeshMessage, MeshMessageContent,
MeshMessageType)
from c3nav.mesh.models import (FirmwareBuild, HardwareDescription, MeshNode, OTARecipientStatus, OTAUpdate,
OTAUpdateRecipient)
from c3nav.mesh.schemas import BoardConfig, BoardType, LedType, SerialLedType
from c3nav.mesh.utils import MESH_ALL_OTA_GROUP, group_msg_type_choices
@ -64,7 +68,7 @@ class MeshMessageForm(forms.Form):
if cls.msg_type in MeshMessageForm.msg_types:
raise TypeError('duplicate use of msg %s' % cls.msg_type)
MeshMessageForm.msg_types[cls.msg_type] = cls
cls.msg_type_class = MeshMessage.get_type(cls.msg_type)
cls.msg_type_class = CFormat.from_annotation(MeshMessageContent).models.get(cls.msg_type.c_value).model
@classmethod
def get_form_for_type(cls, msg_type):
@ -83,10 +87,12 @@ class MeshMessageForm(forms.Form):
raise Exception('nope')
return {
'msg_type': self.msg_type.name,
'src': MESH_ROOT_ADDRESS,
"src": MESH_ROOT_ADDRESS,
"content": {
"msg_type": self.msg_type.name,
**self.get_cleaned_msg_data(),
}
}
def get_recipients(self):
return [self.recipient] if self.recipient else self.cleaned_data['recipients']
@ -96,7 +102,7 @@ class MeshMessageForm(forms.Form):
recipients = self.get_recipients()
for recipient in recipients:
print('sending to ', recipient)
async_to_sync(MeshMessage.fromjson({
async_to_sync(MeshMessage.model_validate({
'dst': recipient,
**msg_data,
}).send)()
@ -173,8 +179,8 @@ class ConfigBoardMessageForm(MeshMessageForm):
"prefix": "led_",
"field": "board",
"values": tuple(
cfg.board.name for cfg in BoardConfig._union_options["board"].values()
if "led" in cfg.__dataclass_fields__
board_type.name for board_type in BoardType
if "led" in CFormat.from_annotation(BoardConfig).models[board_type.c_value]._field_formats
),
},
{
@ -191,8 +197,8 @@ class ConfigBoardMessageForm(MeshMessageForm):
"prefix": "uwb_",
"field": "board",
"values": tuple(
cfg.board.name for cfg in BoardConfig._union_options["board"].values()
if "uwb" in cfg.__dataclass_fields__
board_type.name for board_type in BoardType
if "uwb" in CFormat.from_annotation(BoardConfig).models[board_type.c_value]._field_formats
),
},
{
@ -204,10 +210,7 @@ class ConfigBoardMessageForm(MeshMessageForm):
def clean(self):
cleaned_data = super().clean()
board_cfg = BoardConfig._union_options["board"][BoardType[cleaned_data["board"]]]
has_led = "led" in board_cfg.__dataclass_fields__
has_uwb = "uwb" in board_cfg.__dataclass_fields__
orig_cleaned_keys = set(cleaned_data.keys())
led_values = {
"led_type": cleaned_data.pop("led_type"),
@ -217,43 +220,29 @@ class ConfigBoardMessageForm(MeshMessageForm):
if name.startswith('led_')
}
}
if led_values:
cleaned_data["led"] = led_values
uwb_values = {
name.removeprefix('uwb_'): cleaned_data.pop(name)
for name in tuple(cleaned_data.keys())
if name.startswith('uwb_')
}
errors = {}
if has_led:
prefix = led_values["led_type"].lower()+'_'
cleaned_data["led"] = {
"led_type": led_values["led_type"],
**{
name.removeprefix(prefix): value
for name, value in led_values.items()
if name.startswith(prefix)
}
}
for key, value in tuple(cleaned_data["led"].items()):
if value is None:
field_name = f'led_{prefix}{key}'
if self.fields[field_name].min_value == -1:
cleaned_data[key] = -1
else:
errors[field_name] = _('this field is required')
if has_uwb:
if uwb_values:
cleaned_data["uwb"] = uwb_values
for key, value in tuple(cleaned_data["uwb"].items()):
if value is None:
field_name = f'uwb_{key}'
if self.fields[field_name].min_value == -1 or not cleaned_data["uwb"]["enable"]:
cleaned_data[key] = -1
else:
errors[field_name] = _('this field is required')
if errors:
try:
TypeAdapter(BoardConfig).validate_python(cleaned_data)
except PydanticValidationError as e:
from pprint import pprint
pprint(e.errors())
errors = {}
for error in e.errors():
loc = "_".join(s for s in error["loc"] if not s.isupper())
if loc in orig_cleaned_keys:
errors.setdefault(loc, []).append(error["msg"])
else:
errors.setdefault("__all__", []).append(f"{loc}: {error['msg']}")
raise ValidationError(errors)
return cleaned_data

View file

@ -1,52 +1,47 @@
from dataclasses import fields
from django.core.management.base import BaseCommand
from c3nav.mesh.baseformats import StructType, normalize_name
from c3nav.mesh.messages import MeshMessage
from c3nav.mesh.cformats import UnionFormat, normalize_name, CFormat
from c3nav.mesh.messages import MeshMessageContent
from c3nav.mesh.utils import indent_c
class Command(BaseCommand):
help = 'export mesh message structs for c code'
@staticmethod
def get_msg_c_enum_name(msg_type):
return normalize_name(msg_type.__name__.removeprefix('Mesh').removesuffix('Message')).upper()
def handle(self, *args, **options):
done_struct_names = set()
nodata = set()
struct_lines = {}
struct_sizes = []
struct_max_sizes = []
done_definitions = set()
for include in StructType.c_includes:
mesh_msg_content_format = CFormat.from_annotation(MeshMessageContent)
if not isinstance(mesh_msg_content_format, UnionFormat):
raise Exception('wuah')
discriminator_size = mesh_msg_content_format.discriminator_format.get_size()
for include in mesh_msg_content_format.get_c_includes():
print(f'#include {include}')
ignore_names = set(field_.name for field_ in fields(MeshMessage))
for msg_type, msg_class in MeshMessage.get_types().items():
if msg_class.c_struct_name:
if msg_class.c_struct_name in done_struct_names:
continue
done_struct_names.add(msg_class.c_struct_name)
if MeshMessage.c_structs[msg_class.c_struct_name] != msg_class:
# the purpose of MeshMessage.c_structs is unclear, currently this never triggers
# todo get rid of the whole c_structs thing if it doesn't turn out to be useful for anything
raise ValueError('what happened?')
base_name = (msg_class.c_struct_name or normalize_name(
getattr(msg_type, 'name', msg_class.__name__)
))
for msg_type, msg_content_format in mesh_msg_content_format.models.items():
base_name = normalize_name(mesh_msg_content_format.key_to_name[msg_type])
name = "mesh_msg_%s_t" % base_name
for definition_name, definition in msg_class.get_c_definitions().items():
for definition_name, definition in msg_content_format.get_c_definitions().items():
if definition_name not in done_definitions:
done_definitions.add(definition_name)
print(definition)
print()
code = msg_class.get_c_code(name, ignore_fields=ignore_names, no_empty=True)
code = msg_content_format.get_c_code(name, ignore_fields=('msg_type', ), no_empty=True)
if code:
size = msg_class.get_size(no_inherited_fields=True, calculate_max=False)
max_size = msg_class.get_size(no_inherited_fields=True, calculate_max=True)
size = msg_content_format.get_size(calculate_max=False)
max_size = msg_content_format.get_size(calculate_max=True)
size -= discriminator_size
max_size -= discriminator_size
struct_lines[base_name] = "%s %s;" % (name, base_name.replace('_announce', ''))
struct_sizes.append(size)
struct_max_sizes.append(max_size)
@ -55,7 +50,7 @@ class Command(BaseCommand):
(name, size))
print()
else:
nodata.add(msg_class)
nodata.add(msg_content_format.model)
print("/** union between all message data structs */")
print("typedef union __packed {")
@ -72,20 +67,18 @@ class Command(BaseCommand):
print()
max_msg_type = max(MeshMessage.get_types().keys())
max_msg_type = max(mesh_msg_content_format.models.keys())
macro_data = []
for i in range(((max_msg_type//16)+1)*16):
msg_class = MeshMessage.get_types().get(i, None)
if msg_class:
name = (msg_class.c_struct_name or normalize_name(
getattr(msg_class.msg_type, 'name', msg_class.__name__)
))
msg_content_format = mesh_msg_content_format.models.get(i, None)
if msg_content_format:
name = normalize_name(mesh_msg_content_format.key_to_name[i])
macro_data.append((
msg_class.get_c_enum_name(),
("nodata" if msg_class in nodata else name),
msg_class.get_var_num(),
msg_class.get_size(no_inherited_fields=True, calculate_max=True),
msg_class.__doc__.strip(),
self.get_msg_c_enum_name(msg_content_format.model),
("nodata" if msg_content_format.model in nodata else name),
msg_content_format.get_var_num(), # todo: uh?
msg_content_format.get_size(calculate_max=True) - discriminator_size,
msg_content_format.model.__doc__.strip(),
))
else:
macro_data.append((

View file

@ -1,15 +1,16 @@
from dataclasses import dataclass, field
from enum import IntEnum, unique
from typing import TypeVar
from enum import unique
from typing import Annotated, Union
import channels
from annotated_types import Ge, Le, Lt, MaxLen
from channels.db import database_sync_to_async
from pydantic import PositiveInt
from pydantic.main import BaseModel
from pydantic.types import Discriminator, NonNegativeInt
from pydantic_extra_types.mac_address import MacAddress
from c3nav.api.utils import EnumSchemaByNameMixin
from c3nav.mesh.baseformats import (BoolFormat, EnumFormat, FixedStrFormat, SimpleFormat, StructType, VarArrayFormat,
VarBytesFormat, VarStrFormat, normalize_name)
from c3nav.mesh.dataformats import (BoardConfig, ChipType, FirmwareAppDescription, MacAddressesListFormat,
MacAddressFormat, RangeResultItem, RawFTMEntry)
from c3nav.mesh.cformats import CDoc, CEmbed, CName, LenBytes, NoDef, VarLen, discriminator_value, CEnum
from c3nav.mesh.schemas import BoardConfig, ChipType, FirmwareAppDescription, RangeResultItem, RawFTMEntry
from c3nav.mesh.utils import MESH_ALL_UPLINKS_GROUP
MESH_ROOT_ADDRESS = '00:00:00:00:00:00'
@ -23,45 +24,45 @@ OTA_CHUNK_SIZE = 512
@unique
class MeshMessageType(EnumSchemaByNameMixin, IntEnum):
NOOP = 0x00
class MeshMessageType(CEnum):
NOOP = "NOOP", 0x00
ECHO_REQUEST = 0x01
ECHO_RESPONSE = 0x02
ECHO_REQUEST = "ECHO_REQUEST", 0x01
ECHO_RESPONSE = "ECHO_RESPONSE", 0x02
MESH_SIGNIN = 0x03
MESH_LAYER_ANNOUNCE = 0x04
MESH_ADD_DESTINATIONS = 0x05
MESH_REMOVE_DESTINATIONS = 0x06
MESH_ROUTE_REQUEST = 0x07
MESH_ROUTE_RESPONSE = 0x08
MESH_ROUTE_TRACE = 0x09
MESH_ROUTING_FAILED = 0x0a
MESH_SIGNIN = "MESH_SIGNIN", 0x03
MESH_LAYER_ANNOUNCE = "MESH_LAYER_ANNOUNCE", 0x04
MESH_ADD_DESTINATIONS = "MESH_ADD_DESTINATIONS", 0x05
MESH_REMOVE_DESTINATIONS = "MESH_REMOVE_DESTINATIONS", 0x06
MESH_ROUTE_REQUEST = "MESH_ROUTE_REQUEST", 0x07
MESH_ROUTE_RESPONSE = "MESH_ROUTE_RESPONSE", 0x08
MESH_ROUTE_TRACE = "MESH_ROUTE_TRACE", 0x09
MESH_ROUTING_FAILED = "MESH_ROUTING_FAILED", 0x0a
CONFIG_DUMP = 0x10
CONFIG_HARDWARE = 0x11
CONFIG_BOARD = 0x12
CONFIG_FIRMWARE = 0x13
CONFIG_UPLINK = 0x14
CONFIG_POSITION = 0x15
CONFIG_DUMP = "CONFIG_DUMP", 0x10
CONFIG_HARDWARE = "CONFIG_HARDWARE", 0x11
CONFIG_BOARD = "CONFIG_BOARD", 0x12
CONFIG_FIRMWARE = "CONFIG_FIRMWARE", 0x13
CONFIG_UPLINK = "CONFIG_UPLINK", 0x14
CONFIG_POSITION = "CONFIG_POSITION", 0x15
OTA_STATUS = 0x20
OTA_REQUEST_STATUS = 0x21
OTA_START = 0x22
OTA_URL = 0x23
OTA_FRAGMENT = 0x24
OTA_REQUEST_FRAGMENTS = 0x25
OTA_SETTING = 0x26
OTA_APPLY = 0x27
OTA_ABORT = 0x28
OTA_STATUS = "OTA_STATUS", 0x20
OTA_REQUEST_STATUS = "OTA_REQUEST_STATUS", 0x21
OTA_START = "OTA_START", 0x22
OTA_URL = "OTA_URL", 0x23
OTA_FRAGMENT = "OTA_FRAGMENT", 0x24
OTA_REQUEST_FRAGMENTS = "OTA_REQUEST_FRAGMENTS", 0x25
OTA_SETTING = "OTA_SETTING", 0x26
OTA_APPLY = "OTA_APPLY", 0x27
OTA_ABORT = "OTA_ABORT", 0x28
LOCATE_REQUEST_RANGE = 0x30
LOCATE_RANGE_RESULTS = 0x31
LOCATE_RAW_FTM_RESULTS = 0x32
LOCATE_REQUEST_RANGE = "LOCATE_REQUEST_RANGE", 0x30
LOCATE_RANGE_RESULTS = "LOCATE_RANGE_RESULTS", 0x31
LOCATE_RAW_FTM_RESULTS = "LOCATE_RAW_FTM_RESULTS", 0x32
REBOOT = 0x40
REBOOT = "REBOOT", 0x40
REPORT_ERROR = 0x50
REPORT_ERROR = "REPORT_ERROR", 0x50
@property
def pretty_name(self):
@ -72,32 +73,266 @@ class MeshMessageType(EnumSchemaByNameMixin, IntEnum):
return name
M = TypeVar('M', bound='MeshMessage')
class NoopMessage(discriminator_value(msg_type=MeshMessageType.NOOP), BaseModel):
""" noop """
pass
@dataclass
class MeshMessage(StructType, union_type_field="msg_type"):
dst: str = field(metadata={"format": MacAddressFormat()})
src: str = field(metadata={"format": MacAddressFormat()})
msg_type: MeshMessageType = field(metadata={"format": EnumFormat('B', c_definition=False)}, init=False, repr=False)
c_structs = {}
c_struct_name = None
class EchoRequestMessage(discriminator_value(msg_type=MeshMessageType.ECHO_REQUEST), BaseModel):
""" repeat back string """
content: Annotated[str, MaxLen(255), VarLen()] = ""
# noinspection PyMethodOverriding
def __init_subclass__(cls, /, c_struct_name=None, **kwargs):
super().__init_subclass__(**kwargs)
if c_struct_name:
cls.c_struct_name = c_struct_name
if c_struct_name in MeshMessage.c_structs:
raise TypeError('duplicate use of c_struct_name %s' % c_struct_name)
MeshMessage.c_structs[c_struct_name] = cls
class EchoResponseMessage(discriminator_value(msg_type=MeshMessageType.ECHO_RESPONSE), BaseModel):
""" repeat back string """
content: Annotated[str, MaxLen(255), VarLen()] = ""
class MeshSigninMessage(discriminator_value(msg_type=MeshMessageType.MESH_SIGNIN), BaseModel):
""" node says hello to upstream node """
pass
class MeshLayerAnnounceMessage(discriminator_value(msg_type=MeshMessageType.MESH_LAYER_ANNOUNCE), BaseModel):
""" upstream node announces layer number """
layer: Annotated[PositiveInt, Lt(2 ** 8), CDoc("mesh layer that the sending node is on")]
class MeshAddDestinationsMessage(discriminator_value(msg_type=MeshMessageType.MESH_ADD_DESTINATIONS), BaseModel):
""" downstream node announces served destination """
addresses: Annotated[list[MacAddress], MaxLen(16), VarLen(), CDoc("adresses of the added destinations",)]
class MeshRemoveDestinationsMessage(discriminator_value(msg_type=MeshMessageType.MESH_REMOVE_DESTINATIONS), BaseModel):
""" downstream node announces no longer served destination """
addresses: Annotated[list[MacAddress], MaxLen(16), VarLen(), CDoc("adresses of the removed destinations",)]
class MeshRouteRequestMessage(discriminator_value(msg_type=MeshMessageType.MESH_ROUTE_REQUEST), BaseModel):
""" request routing information for node """
request_id: Annotated[PositiveInt, Lt(2**32)]
address: Annotated[MacAddress, CDoc("target address for the route")]
class MeshRouteResponseMessage(discriminator_value(msg_type=MeshMessageType.MESH_ROUTE_RESPONSE), BaseModel):
""" reporting the routing table entry to the given address """
request_id: Annotated[PositiveInt, Lt(2**32)]
route: Annotated[MacAddress, CDoc("routing table entry or 00:00:00:00:00:00")]
class MeshRouteTraceMessage(discriminator_value(msg_type=MeshMessageType.MESH_ROUTE_TRACE), BaseModel):
""" special message, collects all hop adresses on its way """
request_id: Annotated[PositiveInt, Lt(2**32)]
trace: Annotated[list[MacAddress], MaxLen(16), VarLen(), CDoc("addresses encountered by this message")]
class MeshRoutingFailedMessage(discriminator_value(msg_type=MeshMessageType.MESH_ROUTING_FAILED), BaseModel):
""" TODO description"""
address: MacAddress
class ConfigDumpMessage(discriminator_value(msg_type=MeshMessageType.CONFIG_DUMP), BaseModel):
""" request for the node to dump its config """
pass
class ConfigHardwareMessage(discriminator_value(msg_type=MeshMessageType.CONFIG_HARDWARE), BaseModel):
""" respond hardware/chip info """
chip: Annotated[ChipType, NoDef(), LenBytes(2), CName("chip_id")]
revision_major: Annotated[NonNegativeInt, Lt(2**8)]
revision_minor: Annotated[NonNegativeInt, Lt(2**8)]
def get_chip_display(self):
return ChipType(self.chip).name.replace('_', '-')
class ConfigBoardMessage(discriminator_value(msg_type=MeshMessageType.CONFIG_BOARD), BaseModel):
""" set/respond board config """
board_config: Annotated[BoardConfig, CEmbed]
class ConfigFirmwareMessage(discriminator_value(msg_type=MeshMessageType.CONFIG_FIRMWARE), BaseModel):
""" respond firmware info """
app_desc: FirmwareAppDescription
class ConfigPositionMessage(discriminator_value(msg_type=MeshMessageType.CONFIG_POSITION), BaseModel):
""" set/respond position config """
x_pos: Annotated[int, Ge(-2**31), Lt(2**31)]
y_pos: Annotated[int, Ge(-2**31), Lt(2**31)]
z_pos: Annotated[int, Ge(-2**15), Lt(2**15)]
class ConfigUplinkMessage(discriminator_value(msg_type=MeshMessageType.CONFIG_UPLINK), BaseModel):
""" set/respond uplink config """
enabled: bool
ssid: Annotated[str, MaxLen(32)]
password: Annotated[str, MaxLen(64)]
channel: Annotated[PositiveInt, Le(15)]
udp: bool
ssl: bool
host: Annotated[str, MaxLen(64)]
port: Annotated[PositiveInt, Lt(2**16)]
@unique
class OTADeviceStatus(CEnum):
""" ota status, the ones >= 0x10 denote a permanent failure """
NONE = "NONE", 0x00
STARTED = "STARTED", 0x01
APPLIED = "APPLIED", 0x02
START_FAILED = "START_FAILED", 0x10
WRITE_FAILED = "WRITE_FAILED", 0x12
APPLY_FAILED = "APPLY_FAILED", 0x13
ROLLED_BACK = "ROLLED_BACK", 0x14
@property
def pretty_name(self):
return self.name.replace('_', ' ').lower()
@property
def is_failed(self):
return self >= self.START_FAILED
class OTAStatusMessage(discriminator_value(msg_type=MeshMessageType.OTA_STATUS), BaseModel):
""" report OTA status """
update_id: Annotated[NonNegativeInt, Lt(2**32)]
received_bytes: Annotated[NonNegativeInt, Lt(2**32)]
next_expected_chunk: Annotated[NonNegativeInt, Lt(2**16)]
auto_apply: bool
auto_reboot: bool
status: OTADeviceStatus
class OTARequestStatusMessage(discriminator_value(msg_type=MeshMessageType.OTA_REQUEST_STATUS), BaseModel):
""" request OTA status """
pass
class OTAStartMessage(discriminator_value(msg_type=MeshMessageType.OTA_START), BaseModel):
""" instruct node to start OTA """
update_id: Annotated[PositiveInt, Lt(2**32)]
total_bytes: Annotated[PositiveInt, Lt(2**32)]
auto_apply: bool
auto_reboot: bool
class OTAURLMessage(discriminator_value(msg_type=MeshMessageType.OTA_URL), BaseModel):
""" supply download URL for OTA update and who to distribute it to """
update_id: Annotated[PositiveInt, Lt(2**32)]
distribute_to: MacAddress
url: Annotated[str, MaxLen(255), VarLen()]
class OTAFragmentMessage(discriminator_value(msg_type=MeshMessageType.OTA_FRAGMENT), BaseModel):
""" supply OTA fragment """
update_id: Annotated[PositiveInt, Lt(2**32)]
chunk: Annotated[PositiveInt, Lt(2**16)]
data: Annotated[bytes, MaxLen(OTA_CHUNK_SIZE), VarLen()]
class OTARequestFragmentsMessage(discriminator_value(msg_type=MeshMessageType.OTA_REQUEST_FRAGMENTS), BaseModel):
""" request missing fragments """
update_id: Annotated[PositiveInt, Lt(2**32)]
chunks: Annotated[list[Annotated[PositiveInt, Lt(2**16)]], MaxLen(128), VarLen()]
class OTASettingMessage(discriminator_value(msg_type=MeshMessageType.OTA_SETTING), BaseModel):
""" configure whether to automatically apply and reboot when update is completed """
update_id: Annotated[PositiveInt, Lt(2**32)]
auto_apply: bool
auto_reboot: bool
class OTAApplyMessage(discriminator_value(msg_type=MeshMessageType.OTA_APPLY), BaseModel):
""" apply OTA and optionally reboot """
update_id: Annotated[PositiveInt, Lt(2**32)]
reboot: bool
class OTAAbortMessage(discriminator_value(msg_type=MeshMessageType.OTA_ABORT), BaseModel):
""" announcing OTA abort """
update_id: Annotated[NonNegativeInt, Lt(2**32)]
class LocateRequestRangeMessage(discriminator_value(msg_type=MeshMessageType.LOCATE_REQUEST_RANGE), BaseModel):
""" request to report distance to all nearby nodes """
pass
class LocateRangeResults(discriminator_value(msg_type=MeshMessageType.LOCATE_RANGE_RESULTS), BaseModel):
""" reports distance to given nodes """
ranges: Annotated[list[RangeResultItem], MaxLen(16), VarLen()]
class LocateRawFTMResults(discriminator_value(msg_type=MeshMessageType.LOCATE_RAW_FTM_RESULTS), BaseModel):
""" reports distance to given nodes """
peer: MacAddress
results: Annotated[list[RawFTMEntry], MaxLen(16), VarLen()]
class Reboot(discriminator_value(msg_type=MeshMessageType.REBOOT), BaseModel):
""" reboot the device """
pass
class ReportError(discriminator_value(msg_type=MeshMessageType.REPORT_ERROR), BaseModel):
""" report a critical error to upstream """
message: Annotated[str, MaxLen(255), VarLen()]
MeshMessageContent = Annotated[
Union[
NoopMessage,
EchoRequestMessage,
EchoResponseMessage,
MeshSigninMessage,
MeshLayerAnnounceMessage,
MeshAddDestinationsMessage,
MeshRemoveDestinationsMessage,
MeshRouteRequestMessage,
MeshRouteResponseMessage,
MeshRouteTraceMessage,
MeshRoutingFailedMessage,
ConfigDumpMessage,
ConfigHardwareMessage,
ConfigBoardMessage,
ConfigFirmwareMessage,
ConfigPositionMessage,
ConfigUplinkMessage,
OTAStatusMessage,
OTARequestStatusMessage,
OTAStartMessage,
OTAURLMessage,
OTAFragmentMessage,
OTARequestFragmentsMessage,
OTASettingMessage,
OTAApplyMessage,
OTAAbortMessage,
LocateRequestRangeMessage,
LocateRangeResults,
LocateRawFTMResults,
Reboot,
ReportError,
],
Discriminator("msg_type")
]
class MeshMessage(BaseModel):
dst: MacAddress
src: MacAddress
content: MeshMessageContent
async def send(self, sender=None, exclude_uplink_address=None) -> bool:
data = {
"type": "mesh.send",
"sender": sender,
"exclude_uplink_address": exclude_uplink_address,
"msg": MeshMessage.tojson(self),
"msg": self.model_dump(),
}
if self.dst in (MESH_CHILDREN_ADDRESS, MESH_BROADCAST_ADDRESS):
@ -111,282 +346,3 @@ class MeshMessage(StructType, union_type_field="msg_type"):
if uplink.node_id == exclude_uplink_address:
return False
await channels.layers.get_channel_layer().send(uplink.name, data)
@classmethod
def get_ignore_c_fields(self):
return set()
@classmethod
def get_additional_c_fields(self):
return ()
@classmethod
def get_variable_name(cls, base_name):
return cls.c_struct_name or base_name
@classmethod
def get_c_enum_name(cls):
return normalize_name(cls.__name__.removeprefix('Mesh').removesuffix('Message')).upper()
@dataclass
class NoopMessage(MeshMessage, msg_type=MeshMessageType.NOOP):
""" noop """
pass
@dataclass
class EchoRequestMessage(MeshMessage, msg_type=MeshMessageType.ECHO_REQUEST):
""" repeat back string """
content: str = field(default='', metadata={'format': VarStrFormat(max_len=255)})
@dataclass
class EchoResponseMessage(MeshMessage, msg_type=MeshMessageType.ECHO_RESPONSE):
""" repeat back string """
content: str = field(default='', metadata={'format': VarStrFormat(max_len=255)})
@dataclass
class MeshSigninMessage(MeshMessage, msg_type=MeshMessageType.MESH_SIGNIN):
""" node says hello to upstream node """
pass
@dataclass
class MeshLayerAnnounceMessage(MeshMessage, msg_type=MeshMessageType.MESH_LAYER_ANNOUNCE):
""" upstream node announces layer number """
layer: int = field(metadata={
"format": SimpleFormat('B'),
"doc": "mesh layer that the sending node is on",
})
@dataclass
class MeshAddDestinationsMessage(MeshMessage, msg_type=MeshMessageType.MESH_ADD_DESTINATIONS):
""" downstream node announces served destination """
addresses: list[str] = field(default_factory=list, metadata={
"format": MacAddressesListFormat(max_num=16),
"doc": "adresses of the added destinations",
})
@dataclass
class MeshRemoveDestinationsMessage(MeshMessage, msg_type=MeshMessageType.MESH_REMOVE_DESTINATIONS):
""" downstream node announces no longer served destination """
addresses: list[str] = field(default_factory=list, metadata={
"format": MacAddressesListFormat(max_num=16),
"doc": "adresses of the removed destinations",
})
@dataclass
class MeshRouteRequestMessage(MeshMessage, msg_type=MeshMessageType.MESH_ROUTE_REQUEST):
""" request routing information for node """
request_id: int = field(metadata={"format": SimpleFormat('I')})
address: str = field(metadata={
"format": MacAddressFormat(),
"doc": "target address for the route"
})
@dataclass
class MeshRouteResponseMessage(MeshMessage, msg_type=MeshMessageType.MESH_ROUTE_RESPONSE):
""" reporting the routing table entry to the given address """
request_id: int = field(metadata={"format": SimpleFormat('I')})
route: str = field(metadata={
"format": MacAddressFormat(),
"doc": "routing table entry or 00:00:00:00:00:00"
})
@dataclass
class MeshRouteTraceMessage(MeshMessage, msg_type=MeshMessageType.MESH_ROUTE_TRACE):
""" special message, collects all hop adresses on its way """
request_id: int = field(metadata={"format": SimpleFormat('I')})
trace: list[str] = field(default_factory=list, metadata={
"format": MacAddressesListFormat(max_num=16),
"doc": "addresses encountered by this message",
})
@dataclass
class MeshRoutingFailedMessage(MeshMessage, msg_type=MeshMessageType.MESH_ROUTING_FAILED):
""" TODO description"""
address: str = field(metadata={"format": MacAddressFormat()})
@dataclass
class ConfigDumpMessage(MeshMessage, msg_type=MeshMessageType.CONFIG_DUMP):
""" request for the node to dump its config """
pass
@dataclass
class ConfigHardwareMessage(MeshMessage, msg_type=MeshMessageType.CONFIG_HARDWARE):
""" respond hardware/chip info """
chip: ChipType = field(metadata={
"format": EnumFormat("H", c_definition=False),
"c_name": "chip_id",
})
revision_major: int = field(metadata={"format": SimpleFormat('B')})
revision_minor: int = field(metadata={"format": SimpleFormat('B')})
def get_chip_display(self):
return ChipType(self.chip).name.replace('_', '-')
@dataclass
class ConfigBoardMessage(MeshMessage, msg_type=MeshMessageType.CONFIG_BOARD):
""" set/respond board config """
board_config: BoardConfig = field(metadata={"c_embed": True, "json_embed": True})
@dataclass
class ConfigFirmwareMessage(MeshMessage, msg_type=MeshMessageType.CONFIG_FIRMWARE):
""" respond firmware info """
app_desc: FirmwareAppDescription = field(metadata={'json_embed': True})
@dataclass
class ConfigPositionMessage(MeshMessage, msg_type=MeshMessageType.CONFIG_POSITION):
""" set/respond position config """
x_pos: int = field(metadata={"format": SimpleFormat('i')})
y_pos: int = field(metadata={"format": SimpleFormat('i')})
z_pos: int = field(metadata={"format": SimpleFormat('h')})
@dataclass
class ConfigUplinkMessage(MeshMessage, msg_type=MeshMessageType.CONFIG_UPLINK):
""" set/respond uplink config """
enabled: bool = field(metadata={"format": BoolFormat()})
ssid: str = field(metadata={"format": FixedStrFormat(32)})
password: str = field(metadata={"format": FixedStrFormat(64)})
channel: int = field(metadata={"format": SimpleFormat('B')})
udp: bool = field(metadata={"format": BoolFormat()})
ssl: bool = field(metadata={"format": BoolFormat()})
host: str = field(metadata={"format": FixedStrFormat(64)})
port: int = field(metadata={"format": SimpleFormat('H')})
@unique
class OTADeviceStatus(EnumSchemaByNameMixin, IntEnum):
""" ota status, the ones >= 0x10 denote a permanent failure """
NONE = 0x00
STARTED = 0x01
APPLIED = 0x02
START_FAILED = 0x10
WRITE_FAILED = 0x12
APPLY_FAILED = 0x13
ROLLED_BACK = 0x14
@property
def pretty_name(self):
return self.name.replace('_', ' ').lower()
@property
def is_failed(self):
return self >= self.START_FAILED
@dataclass
class OTAStatusMessage(MeshMessage, msg_type=MeshMessageType.OTA_STATUS):
""" report OTA status """
update_id: int = field(metadata={"format": SimpleFormat('I')})
received_bytes: int = field(metadata={"format": SimpleFormat('I')})
next_expected_chunk: int = field(metadata={"format": SimpleFormat('H')})
auto_apply: bool = field(metadata={"format": BoolFormat()})
auto_reboot: bool = field(metadata={"format": BoolFormat()})
status: OTADeviceStatus = field(metadata={"format": EnumFormat('B')})
@dataclass
class OTARequestStatusMessage(MeshMessage, msg_type=MeshMessageType.OTA_REQUEST_STATUS):
""" request OTA status """
pass
@dataclass
class OTAStartMessage(MeshMessage, msg_type=MeshMessageType.OTA_START):
""" instruct node to start OTA """
update_id: int = field(metadata={"format": SimpleFormat('I')})
total_bytes: int = field(metadata={"format": SimpleFormat('I')})
auto_apply: bool = field(metadata={"format": BoolFormat()})
auto_reboot: bool = field(metadata={"format": BoolFormat()})
@dataclass
class OTAURLMessage(MeshMessage, msg_type=MeshMessageType.OTA_URL):
""" supply download URL for OTA update and who to distribute it to """
update_id: int = field(metadata={"format": SimpleFormat('I')})
distribute_to: str = field(metadata={"format": MacAddressFormat()})
url: str = field(metadata={"format": VarStrFormat(max_len=255)})
@dataclass
class OTAFragmentMessage(MeshMessage, msg_type=MeshMessageType.OTA_FRAGMENT):
""" supply OTA fragment """
update_id: int = field(metadata={"format": SimpleFormat('I')})
chunk: int = field(metadata={"format": SimpleFormat('H')})
data: bytes = field(metadata={"format": VarBytesFormat(max_size=OTA_CHUNK_SIZE)})
@dataclass
class OTARequestFragmentsMessage(MeshMessage, msg_type=MeshMessageType.OTA_REQUEST_FRAGMENTS):
""" request missing fragments """
update_id: int = field(metadata={"format": SimpleFormat('I')})
chunks: list[int] = field(metadata={"format": VarArrayFormat(SimpleFormat('H'), max_num=128)})
@dataclass
class OTASettingMessage(MeshMessage, msg_type=MeshMessageType.OTA_SETTING):
""" configure whether to automatically apply and reboot when update is completed """
update_id: int = field(metadata={"format": SimpleFormat('I')})
auto_apply: bool = field(metadata={"format": BoolFormat()})
auto_reboot: bool = field(metadata={"format": BoolFormat()})
@dataclass
class OTAApplyMessage(MeshMessage, msg_type=MeshMessageType.OTA_APPLY):
""" apply OTA and optionally reboot """
update_id: int = field(metadata={"format": SimpleFormat('I')})
reboot: bool = field(metadata={"format": BoolFormat()})
@dataclass
class OTAAbortMessage(MeshMessage, msg_type=MeshMessageType.OTA_ABORT):
""" announcing OTA abort """
update_id: int = field(metadata={"format": SimpleFormat('I')})
@dataclass
class LocateRequestRangeMessage(MeshMessage, msg_type=MeshMessageType.LOCATE_REQUEST_RANGE):
""" request to report distance to all nearby nodes """
pass
@dataclass
class LocateRangeResults(MeshMessage, msg_type=MeshMessageType.LOCATE_RANGE_RESULTS):
""" reports distance to given nodes """
ranges: list[RangeResultItem] = field(metadata={"format": VarArrayFormat(RangeResultItem, max_num=16)})
@dataclass
class LocateRawFTMResults(MeshMessage, msg_type=MeshMessageType.LOCATE_RAW_FTM_RESULTS):
""" reports distance to given nodes """
peer: str = field(metadata={"format": MacAddressFormat()})
results: list[RawFTMEntry] = field(metadata={"format": VarArrayFormat(RawFTMEntry, max_num=16)})
@dataclass
class Reboot(MeshMessage, msg_type=MeshMessageType.REBOOT):
""" reboot the device """
pass
@dataclass
class ReportError(MeshMessage, msg_type=MeshMessageType.REPORT_ERROR):
""" report a critical error to upstream """
message: str = field(metadata={"format": VarStrFormat(max_len=255)})

View file

@ -18,7 +18,7 @@ from django.utils.text import slugify
from django.utils.translation import gettext_lazy as _
from c3nav.mapdata.models.geometry.space import RangingBeacon
from c3nav.mesh.dataformats import BoardType, ChipType, FirmwareImage
from c3nav.mesh.schemas import BoardType, ChipType, FirmwareImage
from c3nav.mesh.messages import ConfigFirmwareMessage, ConfigHardwareMessage
from c3nav.mesh.messages import MeshMessage as MeshMessage
from c3nav.mesh.messages import MeshMessageType
@ -383,7 +383,7 @@ class NodeMessage(models.Model):
@cached_property
def parsed(self) -> Self:
return MeshMessage.fromjson(self.data)
return MeshMessage.model_validate(self.data)
class FirmwareVersion(models.Model):

284
src/c3nav/mesh/schemas.py Normal file
View file

@ -0,0 +1,284 @@
import re
from dataclasses import dataclass, field
from enum import unique
from typing import Annotated, BinaryIO, ClassVar, Literal, Self, Union
from annotated_types import Gt, Le, Lt, MaxLen, Ge
from pydantic import NegativeInt, PositiveInt
from pydantic.main import BaseModel
from pydantic.types import Discriminator, NonNegativeInt
from pydantic_extra_types.mac_address import MacAddress
from c3nav.mesh.cformats import AsDefinition, AsHex, CName, ExistingCStruct, discriminator_value, \
CEnum, TwoNibblesEncodable
@unique
class LedType(CEnum):
NONE = "NONE", 0
SERIAL = "SERIAL", 1
MULTIPIN = "MULTIPIN", 2
@property
def pretty_name(self):
return self.name.lower()
@unique
class SerialLedType(CEnum):
WS2812 = "WS2812", 1
SK6812 = "SK6812", 2
class NoLedConfig(discriminator_value(led_type=LedType.NONE), BaseModel):
pass
class SerialLedConfig(discriminator_value(led_type=LedType.SERIAL), BaseModel):
serial_led_type: Annotated[SerialLedType, CName("type")]
gpio: Annotated[PositiveInt, Lt(2**8)]
class MultipinLedConfig(discriminator_value(led_type=LedType.MULTIPIN), BaseModel):
gpio_red: Annotated[PositiveInt, Lt(2**8)]
gpio_green: Annotated[PositiveInt, Lt(2**8)]
gpio_blue: Annotated[PositiveInt, Lt(2**8)]
LedConfig = Annotated[
Union[
NoLedConfig,
SerialLedConfig,
MultipinLedConfig,
],
Discriminator("led_type")
]
class BoardSPIConfig(BaseModel):
"""
configuration for spi bus used for ETH or UWB
"""
gpio_miso: Annotated[PositiveInt, Lt(2**8)]
gpio_mosi: Annotated[PositiveInt, Lt(2**8)]
gpio_clk: Annotated[PositiveInt, Lt(2**8)]
class UWBConfig(BaseModel):
"""
configuration for the connection to the UWB module
"""
enable: bool
gpio_cs: Annotated[PositiveInt, Lt(2**8)]
gpio_irq: Annotated[PositiveInt, Lt(2**8)]
gpio_rst: Annotated[PositiveInt, Lt(2**8)]
gpio_wakeup: Annotated[PositiveInt, Lt(2**8)]
gpio_exton: Annotated[PositiveInt, Lt(2**8)]
class UplinkEthConfig(BaseModel):
"""
configuration for the connection to the ETH module
"""
enable: bool
gpio_cs: Annotated[PositiveInt, Lt(2**8)]
gpio_int: Annotated[PositiveInt, Lt(2**8)]
gpio_rst: Annotated[int, Ge(-1), Lt(2**7)]
@unique
class BoardType(CEnum):
CUSTOM = "CUSTOM", 0x00
# devboards
ESP32_C3_DEVKIT_M_1 = "ESP32_C3_DEVKIT_M_1", 0x01
ESP32_C3_32S = "ESP32_C3_32S", 0x02
# custom boards
C3NAV_UWB_BOARD = "C3NAV_UWB_BOARD", 0x10
C3NAV_LOCATION_PCB_REV_0_1 = "C3NAV_LOCATION_PCB_REV_0_1", 0x11
C3NAV_LOCATION_PCB_REV_0_2 = "C3NAV_LOCATION_PCB_REV_0_2", 0x12
@property
def pretty_name(self):
if self.name.startswith('ESP32'):
return self.name.replace('_', '-').replace('DEVKIT-', 'DevKit')
if self.name.startswith('C3NAV'):
name = self.name.replace('_', ' ').lower()
name = name.replace('uwb', 'UWB').replace('pcb', 'PCB')
name = re.sub(r'[0-9]+( [0-9+])+', lambda s: s[0].replace(' ', '.'), name)
name = re.sub(r'rev.*', lambda s: s[0].replace(' ', ''), name)
return name
return self.name
class CustomBoardConfig(discriminator_value(board=BoardType.CUSTOM), BaseModel):
spi: Annotated[BoardSPIConfig, AsDefinition()]
uwb: Annotated[UWBConfig, AsDefinition()]
eth: Annotated[UplinkEthConfig, AsDefinition()]
led: Annotated[LedConfig, AsDefinition()]
class DevkitMBoardConfig(discriminator_value(board=BoardType.ESP32_C3_DEVKIT_M_1), BaseModel):
spi: Annotated[BoardSPIConfig, AsDefinition()]
uwb: Annotated[UWBConfig, AsDefinition()]
eth: Annotated[UplinkEthConfig, AsDefinition()]
class Esp32SBoardConfig(discriminator_value(board=BoardType.ESP32_C3_32S), BaseModel):
spi: Annotated[BoardSPIConfig, AsDefinition()]
uwb: Annotated[UWBConfig, AsDefinition()]
eth: Annotated[UplinkEthConfig, AsDefinition()]
class UwbBoardConfig(discriminator_value(board=BoardType.C3NAV_UWB_BOARD), BaseModel):
eth: Annotated[UplinkEthConfig, AsDefinition()]
class LocationPCBRev0Dot1BoardConfig(discriminator_value(board=BoardType.C3NAV_LOCATION_PCB_REV_0_1), BaseModel):
eth: Annotated[UplinkEthConfig, AsDefinition()]
class LocationPCBRev0Dot2BoardConfig(discriminator_value(board=BoardType.C3NAV_LOCATION_PCB_REV_0_2), BaseModel):
eth: Annotated[UplinkEthConfig, AsDefinition()]
BoardConfig = Annotated[
Union[
CustomBoardConfig,
DevkitMBoardConfig,
Esp32SBoardConfig,
UwbBoardConfig,
LocationPCBRev0Dot1BoardConfig,
LocationPCBRev0Dot2BoardConfig,
],
Discriminator("board"),
AsHex(),
]
class RangeResultItem(BaseModel):
peer: MacAddress
rssi: Annotated[NegativeInt, Gt(-100)]
distance: Annotated[int, Gt(-32000), Lt(32000)]
class RawFTMEntry(BaseModel):
dlog_token: Annotated[PositiveInt, Lt(255)]
rssi: Annotated[NegativeInt, Gt(-100)]
rtt: Annotated[PositiveInt, Lt(2**32)]
t1: Annotated[PositiveInt, Lt(2**64)]
t2: Annotated[PositiveInt, Lt(2**64)]
t3: Annotated[PositiveInt, Lt(2**64)]
t4: Annotated[PositiveInt, Lt(2**64)]
class FirmwareAppDescription(BaseModel):
existing_c_struct: ClassVar = ExistingCStruct(name="esp_app_desc_t", includes=['<esp_app_desc.h>'])
magic_word: Literal[0xAB_CD_54_32] = field(repr=False)
secure_version: Annotated[NonNegativeInt, Lt(2**32)]
reserv1: Annotated[str, MaxLen(8*2), AsHex()] = field(repr=False)
version: Annotated[str, MaxLen(32)]
project_name: Annotated[str, MaxLen(32)]
compile_time: Annotated[str, MaxLen(16)]
compile_date: Annotated[str, MaxLen(16)]
idf_version: Annotated[str, MaxLen(32)]
app_elf_sha256: Annotated[str, MaxLen(64), AsHex()]
reserv2: Annotated[str, MaxLen(20*4*2), AsHex()] = field(repr=False)
@unique
class SPIFlashMode(CEnum):
QIO = "QID", 0
QOUT = "QOUT", 1
DIO = "DIO", 2
DOUT = "DOUT", 3
@unique
class FlashSize(CEnum):
SIZE_1MB = "SIZE_1MB", 0
SIZE_2MB = "SIZE_2MB", 1
SIZE_4MB = "SIZE_4MB", 2
SIZE_8MB = "SIZE_8MB", 3
SIZE_16MB = "SIZE_16MB", 4
SIZE_32MB = "SIZE_32MB", 5
SIZE_64MB = "SIZE_64MB", 6
SIZE_128MB = "SIZE_128MB", 7
@property
def pretty_name(self):
return self.name.removeprefix('SIZE_')
@unique
class FlashFrequency(CEnum):
FREQ_40MHZ = "FREQ_40MHZ", 0
FREQ_26MHZ = "FREQ_26MHZ", 1
FREQ_20MHZ = "FREQ_20MHZ", 2
FREQ_80MHZ = "FREQ_80MHZ", 0xf
@property
def pretty_name(self):
return self.name.removeprefix('FREQ_').replace('MHZ', 'Mhz')
@dataclass
class FlashSettings(TwoNibblesEncodable):
size: FlashSize
frequency: FlashFrequency
@property
def display(self):
return f"{self.size.pretty_name} ({self.frequency.pretty_name})"
@unique
class ChipType(CEnum):
ESP32_S2 = "ESP32_S2", 2
ESP32_C3 = "ESP32_C3", 5
@property
def pretty_name(self):
return self.name.replace('_', '-')
class FirmwareImageFileHeader(BaseModel):
magic_word: Literal[0xE9] = field(repr=False)
num_segments: Annotated[PositiveInt, Lt(2**8)]
spi_flash_mode: SPIFlashMode
flash_stuff: FlashSettings
entry_point: Annotated[PositiveInt, Lt(2**32)]
class FirmwareImageFileHeader(BaseModel):
major: int
minor: int
num_segments: Annotated[PositiveInt, Lt(2**8)]
spi_flash_mode: SPIFlashMode
flash_stuff: FlashSettings
entry_point: Annotated[PositiveInt, Lt(2**32)]
class FirmwareImageExtendedFileHeader(BaseModel):
wp_pin: Annotated[PositiveInt, Lt(2**8)]
drive_settings: Annotated[bytes, MaxLen(3)]
chip: Annotated[ChipType, Lt(2**16)]
min_chip_rev_old: int
min_chip_rev: Annotated[PositiveInt, Le(9999)]
max_chip_rev: Annotated[PositiveInt, Le(9999)]
reserv: Annotated[bytes, MaxLen(4)] = field(repr=False)
hash_appended: bool
class FirmwareImage(BaseModel):
header: FirmwareImageFileHeader
ext_header: FirmwareImageExtendedFileHeader
first_segment_headers: Annotated[bytes, MaxLen(2)] = field(repr=False)
app_desc: FirmwareAppDescription
@classmethod
def from_file(cls, file: BinaryIO) -> Self:
result, data = cls.decode(file.read(FirmwareImage.get_min_size()))
return result

View file

@ -12,7 +12,7 @@ UPLINK_TIMEOUT = UPLINK_PING+5
def indent_c(code):
return " "+code.replace("\n", "\n ")
return " "+code.replace("\n", "\n ").replace("\n \n", "\n\n")
def get_node_names():

View file

@ -82,7 +82,7 @@ def locate_test(request):
None
)
return {
"ranges": msg.parsed.tojson(msg.parsed)["ranges"],
"ranges": msg.parsed.model_dump()["ranges"],
"datetime": msg.datetime,
"location": location.serialize(simple_geometry=True) if location else None
}

View file

@ -3,6 +3,7 @@ django-bootstrap3==23.6
django-compressor==4.4
csscompressor==0.9.5
django-ninja==1.1.0
pydantic-extra-types==2.5.0
django-filter==23.5
django-environ==0.11.2
shapely==2.0.3