change all of the MeshMessages c-models from dataclasses to pydantic

parent bd1a143d31
commit 0fd789173a

14 changed files with 1691 additions and 1628 deletions

@@ -1,52 +1,5 @@
from typing import Annotated, Any, Type
from typing import Annotated

import annotated_types
from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import CoreSchema, core_schema


def get_api_post_data(request):
    is_json = request.META.get('CONTENT_TYPE').lower() == 'application/json'
    if is_json:
        try:
            data = request.json_body
        except AttributeError:
            pass  # todo fix this raise ParseError('Invalid JSON.')
        return data
    return request.POST


class EnumSchemaByNameMixin:
    @classmethod
    def __get_pydantic_json_schema__(
            cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler
    ) -> JsonSchemaValue:
        json_schema = handler(core_schema)
        json_schema = handler.resolve_ref_schema(json_schema)
        json_schema["enum"] = [m.name for m in cls]
        json_schema["type"] = "string"
        return json_schema

    @classmethod
    def __get_pydantic_core_schema__(
            cls, source: Type[Any], handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        return core_schema.no_info_after_validator_function(
            cls.validate,
            core_schema.any_schema(),
            serialization=core_schema.plain_serializer_function_ser_schema(lambda x: x.name),
        )

    @classmethod
    def validate(cls, v: int):
        if isinstance(v, cls):
            return v
        try:
            return cls[v]
        except KeyError:
            pass
        return cls(v)


NonEmptyStr = Annotated[str, annotated_types.MinLen(1)]
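A minimal usage sketch for the mixin above (Color and Item are hypothetical examples, not part of the codebase): an IntEnum that mixes in EnumSchemaByNameMixin is validated from its member name and serialized back to that name by pydantic.

from enum import IntEnum
from pydantic import BaseModel

class Color(EnumSchemaByNameMixin, IntEnum):
    RED = 1
    GREEN = 2

class Item(BaseModel):
    color: Color

assert Item(color="RED").color is Color.RED               # looked up via cls[v]
assert Item(color=2).color is Color.GREEN                  # falls back to cls(v)
assert Item(color=Color.RED).model_dump_json() == '{"color":"RED"}'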
@@ -12,8 +12,8 @@ from pydantic import PositiveInt, field_validator
from c3nav.api.auth import APIKeyAuth, auth_permission_responses, auth_responses, validate_responses
from c3nav.api.exceptions import API404, APIConflict, APIRequestValidationFailed
from c3nav.api.schema import BaseSchema
from c3nav.mesh.dataformats import BoardType, ChipType, FirmwareImage
from c3nav.mesh.messages import MeshMessageType
from c3nav.mesh.schemas import BoardType, ChipType, FirmwareImage
from c3nav.mesh.messages import MeshMessageType, MeshMessage
from c3nav.mesh.models import FirmwareBuild, FirmwareVersion, NodeMessage

mesh_api_router = APIRouter(tags=["mesh"], auth=APIKeyAuth(permissions={"mesh_control"}))

@@ -93,7 +93,7 @@ def firmware_by_id(request, firmware_id: int):
@mesh_api_router.get('/firmwares/{firmware_id}/{variant}/image_data',
                     summary="firmware image header",
                     description="get firmware image header for specific firmware build",
                     response={200: FirmwareImage.schema, **API404.dict(), **auth_responses},
                     response={200: FirmwareImage, **API404.dict(), **auth_responses},
                     openapi_extra={
                         "externalDocs": {
                             'description': 'esp-idf docs',

@@ -105,7 +105,7 @@ def firmware_by_id(request, firmware_id: int):
def firmware_build_image(request, firmware_id: int, variant: str):
    try:
        build = FirmwareBuild.objects.get(version_id=firmware_id, variant=variant)
        return FirmwareImage.tojson(build.firmware_image)
        return build.firmware_image.model_dump()
    except FirmwareVersion.DoesNotExist:
        raise API404("Firmware or firmware build not found")

@@ -218,7 +218,7 @@ class NodeMessageSchema(BaseSchema):
    src_node: NodeAddress
    message_type: MeshMessageType
    datetime: datetime
    data: dict
    data: MeshMessage

    @staticmethod
    def resolve_src_node(obj):
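A minimal sketch of what the switch to pydantic buys the endpoint above (ExampleImage is a hypothetical stand-in for FirmwareImage): serialization becomes a method on the instance instead of the old classmethod-style FirmwareImage.tojson(...).

from pydantic import BaseModel

class ExampleImage(BaseModel):      # hypothetical stand-in for FirmwareImage
    magic_word: int
    project_name: str

img = ExampleImage(magic_word=0xE9, project_name="c3nav")
assert img.model_dump() == {"magic_word": 0xE9, "project_name": "c3nav"}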
@@ -1,810 +0,0 @@
|
|||
import re
|
||||
import struct
|
||||
import typing
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import Field, dataclass
|
||||
from dataclasses import fields as dataclass_fields
|
||||
from typing import Any, Self, Sequence
|
||||
|
||||
from pydantic import create_model
|
||||
|
||||
from c3nav.mesh.utils import indent_c
|
||||
|
||||
|
||||
class BaseFormat(ABC):
|
||||
|
||||
def get_var_num(self):
|
||||
return 0
|
||||
|
||||
@abstractmethod
|
||||
def encode(self, value):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def decode(cls, data) -> tuple[Any, bytes]:
|
||||
pass
|
||||
|
||||
def fromjson(self, data):
|
||||
return data
|
||||
|
||||
def tojson(self, data):
|
||||
return data
|
||||
|
||||
@abstractmethod
|
||||
def get_min_size(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_max_size(self):
|
||||
pass
|
||||
|
||||
def get_size(self, calculate_max=False):
|
||||
if calculate_max:
|
||||
return self.get_max_size()
|
||||
else:
|
||||
return self.get_min_size()
|
||||
|
||||
@abstractmethod
|
||||
def get_c_parts(self) -> tuple[str, str]:
|
||||
pass
|
||||
|
||||
def get_c_code(self, name) -> str:
|
||||
pre, post = self.get_c_parts()
|
||||
return "%s %s%s;" % (pre, name, post)
|
||||
|
||||
def set_field_type(self, field_type):
|
||||
self.field_type = field_type
|
||||
|
||||
def get_c_definitions(self) -> dict[str, str]:
|
||||
return {}
|
||||
|
||||
def get_typedef_name(self):
|
||||
return '%s_t' % normalize_name(self.field_type.__name__)
|
||||
|
||||
|
||||
class SimpleFormat(BaseFormat):
|
||||
def __init__(self, fmt):
|
||||
self.fmt = fmt
|
||||
self.size = struct.calcsize(fmt)
|
||||
|
||||
self.c_type = self.c_types[self.fmt[-1]]
|
||||
self.num = int(self.fmt[:-1]) if len(self.fmt) > 1 else 1
|
||||
|
||||
def encode(self, value):
|
||||
if self.num == 1:
|
||||
return struct.pack(self.fmt, value)
|
||||
return struct.pack(self.fmt, *value)
|
||||
|
||||
def decode(self, data: bytes) -> tuple[Any, bytes]:
|
||||
value = struct.unpack(self.fmt, data[:self.size])
|
||||
if len(value) == 1:
|
||||
value = value[0]
|
||||
return value, data[self.size:]
|
||||
|
||||
def get_min_size(self):
|
||||
return self.size
|
||||
|
||||
def get_max_size(self):
|
||||
return self.size
|
||||
|
||||
c_types = {
|
||||
"B": "uint8_t",
|
||||
"H": "uint16_t",
|
||||
"I": "uint32_t",
|
||||
"Q": "uint64_t",
|
||||
"b": "int8_t",
|
||||
"h": "int16_t",
|
||||
"i": "int32_t",
|
||||
"q": "int64_t",
|
||||
"s": "char",
|
||||
}
|
||||
|
||||
def get_c_parts(self):
|
||||
return self.c_type, ("" if self.num == 1 else ("[%d]" % self.num))
|
||||
|
||||
|
||||
class SimpleConstFormat(SimpleFormat):
|
||||
def __init__(self, fmt, const_value: int):
|
||||
super().__init__(fmt)
|
||||
self.const_value = const_value
|
||||
|
||||
def decode(self, data: bytes) -> tuple[Any, bytes]:
|
||||
value, out_data = super().decode(data)
|
||||
if value != self.const_value:
|
||||
raise ValueError('const_value is wrong')
|
||||
return value, out_data
|
||||
|
||||
|
||||
class EnumFormat(SimpleFormat):
|
||||
def __init__(self, fmt="B", *, as_hex=False, c_definition=True):
|
||||
super().__init__(fmt)
|
||||
self.as_hex = as_hex
|
||||
self.c_definition = c_definition
|
||||
|
||||
def set_field_type(self, field_type):
|
||||
super().set_field_type(field_type)
|
||||
self.c_struct_name = normalize_name(field_type.__name__) + '_t'
|
||||
|
||||
def decode(self, data: bytes) -> tuple[Any, bytes]:
|
||||
value, out_data = super().decode(data)
|
||||
return self.field_type(value), out_data
|
||||
|
||||
def get_c_parts(self):
|
||||
if not self.c_definition:
|
||||
return super().get_c_parts()
|
||||
return self.c_struct_name, ""
|
||||
|
||||
def fromjson(self, data):
|
||||
return self.field_type[data]
|
||||
|
||||
def tojson(self, data):
|
||||
return data.name
|
||||
|
||||
def get_c_definitions(self) -> dict[str, str]:
|
||||
if not self.c_definition:
|
||||
return {}
|
||||
prefix = normalize_name(self.field_type.__name__).upper()
|
||||
options = []
|
||||
last_value = None
|
||||
for item in self.field_type:
|
||||
if last_value is not None and item.value != last_value + 1:
|
||||
options.append('')
|
||||
last_value = item.value
|
||||
options.append("%(prefix)s_%(name)s = %(value)s," % {
|
||||
"prefix": prefix,
|
||||
"name": normalize_name(item.name).upper(),
|
||||
"value": ("0x%02x" if self.as_hex else "%d") % item.value
|
||||
})
|
||||
|
||||
return {
|
||||
self.c_struct_name: "enum {\n%(options)s\n};\ntypedef uint8_t %(name)s;" % {
|
||||
"options": indent_c("\n".join(options)),
|
||||
"name": self.c_struct_name,
|
||||
}
|
||||
}
|
||||
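A rough illustration of the C definitions generated above (ExampleType is hypothetical; the exact indentation depends on indent_c):

from enum import IntEnum

class ExampleType(IntEnum):
    FOO = 1
    BAR = 2

fmt = EnumFormat()                  # defaults to a uint8_t ("B") field
fmt.set_field_type(ExampleType)
print(fmt.get_c_definitions()["example_type_t"])
# expected output, roughly:
# enum {
#     EXAMPLE_TYPE_FOO = 1,
#     EXAMPLE_TYPE_BAR = 2,
# };
# typedef uint8_t example_type_t;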
|
||||
|
||||
class TwoNibblesEnumFormat(SimpleFormat):
|
||||
def __init__(self):
|
||||
super().__init__('B')
|
||||
|
||||
def decode(self, data: bytes) -> tuple[bool, bytes]:
|
||||
fields = dataclass_fields(self.field_type)
|
||||
value, data = super().decode(data)
|
||||
return self.field_type(fields[0].type(value // 2 ** 4), fields[1].type(value // 2 ** 4)), data
|
||||
|
||||
def encode(self, value):
|
||||
fields = dataclass_fields(self.field_type)
|
||||
return super().encode(
|
||||
getattr(value, fields[0].name).value * 2 ** 4 +
|
||||
getattr(value, fields[1].name).value * 2 ** 4
|
||||
)
|
||||
|
||||
def fromjson(self, data):
|
||||
fields = dataclass_fields(self.field_type)
|
||||
return self.field_type(*(field.type[data[field.name]] for field in fields))
|
||||
|
||||
def tojson(self, data):
|
||||
fields = dataclass_fields(self.field_type)
|
||||
return {
|
||||
field.name: getattr(data, field.name).name for field in fields
|
||||
}
|
||||
|
||||
|
||||
class ChipRevFormat(SimpleFormat):
|
||||
def __init__(self):
|
||||
super().__init__('H')
|
||||
|
||||
def decode(self, data: bytes) -> tuple[tuple[int, int], bytes]:
|
||||
value, data = super().decode(data)
|
||||
return (value // 100, value % 100), data
|
||||
|
||||
def encode(self, value):
|
||||
return value[0] * 100 + value[1]
|
||||
|
||||
|
||||
class BoolFormat(SimpleFormat):
|
||||
def __init__(self):
|
||||
super().__init__('B')
|
||||
|
||||
def encode(self, value):
|
||||
return super().encode(int(value))
|
||||
|
||||
def decode(self, data: bytes) -> tuple[bool, bytes]:
|
||||
value, data = super().decode(data)
|
||||
if value > 1:
|
||||
raise ValueError('Boolean value > 1')
|
||||
return bool(value), data
|
||||
|
||||
|
||||
class FixedStrFormat(SimpleFormat):
|
||||
def __init__(self, num):
|
||||
self.num = num
|
||||
super().__init__('%ds' % self.num)
|
||||
|
||||
def encode(self, value: str):
|
||||
return value.encode()[:self.num].ljust(self.num, bytes((0,))),
|
||||
|
||||
def decode(self, data: bytes) -> tuple[str, bytes]:
|
||||
return data[:self.num].rstrip(bytes((0,))).decode(), data[self.num:]
|
||||
|
||||
|
||||
class FixedHexFormat(SimpleFormat):
|
||||
def __init__(self, num, sep=''):
|
||||
self.num = num
|
||||
self.sep = sep
|
||||
super().__init__('%dB' % self.num)
|
||||
|
||||
def encode(self, value: str):
|
||||
return super().encode(tuple(bytes.fromhex(value.replace(':', ''))))
|
||||
|
||||
def decode(self, data: bytes) -> tuple[str, bytes]:
|
||||
return self.sep.join(('%02x' % i) for i in data[:self.num]), data[self.num:]
|
||||
|
||||
|
||||
@abstractmethod
|
||||
class BaseVarFormat(BaseFormat, ABC):
|
||||
def __init__(self, max_num):
|
||||
self.num_fmt = 'H'
|
||||
self.num_size = struct.calcsize(self.num_fmt)
|
||||
self.max_num = max_num
|
||||
|
||||
def get_min_size(self):
|
||||
return self.num_size
|
||||
|
||||
def get_max_size(self):
|
||||
return self.num_size + self.max_num * self.get_var_num()
|
||||
|
||||
def get_num_c_code(self):
|
||||
return SimpleFormat(self.num_fmt).get_c_code("num")
|
||||
|
||||
|
||||
class VarArrayFormat(BaseVarFormat):
|
||||
def __init__(self, child_type, max_num):
|
||||
super().__init__(max_num=max_num)
|
||||
self.child_type = child_type
|
||||
self.child_size = self.child_type.get_min_size()
|
||||
|
||||
def get_var_num(self):
|
||||
return self.child_size
|
||||
pass
|
||||
|
||||
def encode(self, values: Sequence) -> bytes:
|
||||
num = len(values)
|
||||
if num > self.max_num:
|
||||
raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}')
|
||||
data = struct.pack(self.num_fmt, num)
|
||||
for value in values:
|
||||
data += self.child_type.encode(value)
|
||||
return data
|
||||
|
||||
def decode(self, data: bytes) -> tuple[list[Any], bytes]:
|
||||
num = struct.unpack(self.num_fmt, data[:self.num_size])[0]
|
||||
if num > self.max_num:
|
||||
raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}')
|
||||
data = data[self.num_size:]
|
||||
result = []
|
||||
for i in range(num):
|
||||
item, data = self.child_type.decode(data)
|
||||
result.append(item)
|
||||
return result, data
|
||||
|
||||
def fromjson(self, data):
|
||||
return [
|
||||
self.child_type.fromjson(item) for item in data
|
||||
]
|
||||
|
||||
def tojson(self, data):
|
||||
return [
|
||||
self.child_type.tojson(item) for item in data
|
||||
]
|
||||
|
||||
def get_c_parts(self):
|
||||
pre, post = self.child_type.get_c_parts()
|
||||
return super().get_num_c_code() + "\n" + pre, "[0]" + post
|
||||
|
||||
|
||||
class VarStrFormat(BaseVarFormat):
|
||||
def __init__(self, max_len):
|
||||
super().__init__(max_num=max_len)
|
||||
|
||||
def get_var_num(self):
|
||||
return 1
|
||||
|
||||
def encode(self, value: str) -> bytes:
|
||||
num = len(value)
|
||||
if num > self.max_num:
|
||||
raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}')
|
||||
return struct.pack(self.num_fmt, num) + value.encode()
|
||||
|
||||
def decode(self, data: bytes) -> tuple[str, bytes]:
|
||||
num = struct.unpack(self.num_fmt, data[:self.num_size])[0]
|
||||
if num > self.max_num:
|
||||
raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}')
|
||||
return data[self.num_size:self.num_size + num].rstrip(bytes((0,))).decode(), data[self.num_size + num:]
|
||||
|
||||
def get_c_parts(self):
|
||||
return super().get_num_c_code() + "\n" + "char", "[0]"
|
||||
|
||||
|
||||
class VarBytesFormat(BaseVarFormat):
|
||||
def __init__(self, max_size):
|
||||
super().__init__(max_num=max_size)
|
||||
|
||||
def get_var_num(self):
|
||||
return 1
|
||||
|
||||
def encode(self, value: bytes) -> bytes:
|
||||
num = len(value)
|
||||
if num > self.max_num:
|
||||
raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}')
|
||||
return struct.pack(self.num_fmt, num) + value
|
||||
|
||||
def decode(self, data: bytes) -> tuple[bytes, bytes]:
|
||||
num = struct.unpack(self.num_fmt, data[:self.num_size])[0]
|
||||
if num > self.max_num:
|
||||
raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}')
|
||||
return data[self.num_size:self.num_size + num].rstrip(bytes((0,))), data[self.num_size + num:]
|
||||
|
||||
def get_c_parts(self):
|
||||
return super().get_num_c_code() + "\n" + "uint8_t", "[0]"
|
||||
|
||||
|
||||
@dataclass
|
||||
class StructType:
|
||||
_union_options = {}
|
||||
union_type_field = None
|
||||
existing_c_struct = None
|
||||
c_includes = set()
|
||||
|
||||
@classmethod
|
||||
def get_field_format(cls, attr_name):
|
||||
attr = getattr(cls, attr_name, None)
|
||||
fields = [f for f in dataclass_fields(cls) if f.name == attr_name]
|
||||
if not fields:
|
||||
raise TypeError(f"{cls}.{attr_name} not a field")
|
||||
field = fields[0]
|
||||
type_ = typing.get_type_hints(cls)[attr_name]
|
||||
if "format" in field.metadata:
|
||||
field_format = field.metadata["format"]
|
||||
field_format.set_field_type(type_)
|
||||
return field_format
|
||||
|
||||
if issubclass(type_, StructType):
|
||||
return type_
|
||||
raise TypeError('field %s.%s has no format and is no StructType' %
|
||||
(cls.__class__.__name__, attr_name))
|
||||
|
||||
# noinspection PyMethodOverriding
|
||||
def __init_subclass__(cls, /, union_type_field=None, existing_c_struct=None, c_includes=None, **kwargs):
|
||||
cls.union_type_field = union_type_field
|
||||
if c_includes is not None:
|
||||
cls.c_includes |= set(c_includes)
|
||||
if cls.existing_c_struct is not None:
|
||||
# TODO: can we make it possible? does it even make sense?
|
||||
raise TypeError('subclassing an external c struct is not possible')
|
||||
cls.existing_c_struct = existing_c_struct
|
||||
if union_type_field:
|
||||
if union_type_field in cls._union_options:
|
||||
raise TypeError('Duplicate union_type_field: %s', union_type_field)
|
||||
cls._union_options[union_type_field] = {}
|
||||
f = getattr(cls, union_type_field)
|
||||
metadata = dict(f.metadata)
|
||||
metadata['union_discriminator'] = True
|
||||
f.metadata = metadata
|
||||
f.repr = False
|
||||
f.init = False
|
||||
|
||||
for attr_name in cls.__dict__.keys():
|
||||
attr = getattr(cls, attr_name)
|
||||
if isinstance(attr, Field):
|
||||
metadata = dict(attr.metadata)
|
||||
if "defining_class" not in metadata:
|
||||
metadata["defining_class"] = cls
|
||||
attr.metadata = metadata
|
||||
|
||||
for key, values in cls._union_options.items():
|
||||
value = kwargs.pop(key, None)
|
||||
if value is not None:
|
||||
if value in values:
|
||||
raise TypeError('Duplicate %s: %s', (key, value))
|
||||
values[value] = cls
|
||||
setattr(cls, key, value)
|
||||
|
||||
# pydantic model
|
||||
cls._pydantic_fields = getattr(cls, '_pydantic_fields', {}).copy()
|
||||
fields = []
|
||||
for field_ in dataclass_fields(cls):
|
||||
fields.append((field_.name, field_.type, field_.metadata))
|
||||
for attr_name in tuple(cls.__annotations__.keys()):
|
||||
attr = getattr(cls, attr_name, None)
|
||||
metadata = attr.metadata if isinstance(attr, Field) else {}
|
||||
try:
|
||||
type_ = cls.__annotations__[attr_name]
|
||||
except KeyError:
|
||||
# print('nope', cls, attr_name)
|
||||
continue
|
||||
fields.append((attr_name, type_, metadata))
|
||||
for name, type_, metadata in fields:
|
||||
try:
|
||||
field_format = cls.get_field_format(name)
|
||||
except TypeError:
|
||||
# todo: in case of not a field, ignore it?
|
||||
continue
|
||||
if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
|
||||
cls._pydantic_fields[name] = (type_, ...)
|
||||
else:
|
||||
if metadata.get("json_embed"):
|
||||
cls._pydantic_fields.update(type_._pydantic_fields)
|
||||
else:
|
||||
cls._pydantic_fields[name] = (type_.schema, ...)
|
||||
cls.schema = create_model(cls.__name__ + 'Schema', **cls._pydantic_fields)
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_var_num(cls):
|
||||
return sum([cls.get_field_format(f.name).get_var_num() for f in dataclass_fields(cls)], start=0)
|
||||
|
||||
@classmethod
|
||||
def get_types(cls):
|
||||
if not cls.union_type_field:
|
||||
raise TypeError('Not a union class')
|
||||
return cls._union_options[cls.union_type_field]
|
||||
|
||||
@classmethod
|
||||
def get_type(cls, type_id) -> Self:
|
||||
if not cls.union_type_field:
|
||||
raise TypeError('Not a union class')
|
||||
return cls.get_types()[type_id]
|
||||
|
||||
@classmethod
|
||||
def encode(cls, instance, ignore_fields=()) -> bytes:
|
||||
data = bytes()
|
||||
if cls.union_type_field and type(instance) is not cls:
|
||||
if not isinstance(instance, cls):
|
||||
raise ValueError('expected value of type %r, got %r' % (cls, instance))
|
||||
|
||||
for field_ in dataclass_fields(cls):
|
||||
data += cls.get_field_format(field_.name).encode(getattr(instance, field_.name))
|
||||
|
||||
# todo: better
|
||||
data += instance.encode(instance, ignore_fields=set(f.name for f in dataclass_fields(cls)))
|
||||
return data
|
||||
|
||||
for field_ in dataclass_fields(cls):
|
||||
if field_.name in ignore_fields:
|
||||
continue
|
||||
value = getattr(instance, field_.name)
|
||||
field_format = cls.get_field_format(field_.name)
|
||||
if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
|
||||
data += field_format.encode(value)
|
||||
else:
|
||||
if not isinstance(value, field_.type):
|
||||
raise ValueError('expected value of type %r for %s.%s, got %r' %
|
||||
(field_.type, cls.__name__, field_.name, value))
|
||||
data += value.encode(value)
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def decode(cls, data: bytes) -> tuple[Self, bytes]:
|
||||
orig_data = data
|
||||
kwargs = {}
|
||||
no_init_data = {}
|
||||
for field_ in dataclass_fields(cls):
|
||||
field_format = cls.get_field_format(field_.name)
|
||||
if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
|
||||
value, data = field_format.decode(data)
|
||||
else:
|
||||
value, data = field_.type.decode(data)
|
||||
if field_.init:
|
||||
kwargs[field_.name] = value
|
||||
else:
|
||||
no_init_data[field_.name] = value
|
||||
|
||||
if cls.union_type_field:
|
||||
try:
|
||||
type_value = no_init_data[cls.union_type_field]
|
||||
except KeyError:
|
||||
raise TypeError('union_type_field %s.%s is missing' %
|
||||
(cls.__name__, cls.union_type_field))
|
||||
try:
|
||||
klass = cls.get_type(type_value)
|
||||
except KeyError:
|
||||
raise TypeError('union_type_field %s.%s value %r not known' %
|
||||
(cls.__name__, cls.union_type_field, type_value))
|
||||
return klass.decode(orig_data)
|
||||
return cls(**kwargs), data
|
||||
|
||||
@classmethod
|
||||
def tojson(cls, instance) -> dict:
|
||||
result = {}
|
||||
|
||||
if cls.union_type_field and type(instance) is not cls:
|
||||
if not isinstance(instance, cls):
|
||||
raise ValueError('expected value of type %r, got %r' % (cls, instance))
|
||||
|
||||
for field_ in dataclass_fields(instance):
|
||||
if field_.name is cls.union_type_field:
|
||||
result[field_.name] = cls.get_field_format(field_.name).tojson(getattr(instance, field_.name))
|
||||
break
|
||||
else:
|
||||
raise TypeError('couldn\'t find %s value' % cls.union_type_field)
|
||||
|
||||
result.update(instance.tojson(instance))
|
||||
return result
|
||||
|
||||
for field_ in dataclass_fields(cls):
|
||||
value = getattr(instance, field_.name)
|
||||
field_format = cls.get_field_format(field_.name)
|
||||
if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
|
||||
result[field_.name] = field_format.tojson(value)
|
||||
else:
|
||||
if not isinstance(value, field_.type):
|
||||
raise ValueError('expected value of type %r for %s.%s, got %r' %
|
||||
(field_.type, cls.__name__, field_.name, value))
|
||||
json_val = value.tojson(value)
|
||||
if field_.metadata.get("json_embed"):
|
||||
for k, v in json_val.items():
|
||||
result[k] = v
|
||||
else:
|
||||
result[field_.name] = value.tojson(value)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def upgrade_json(cls, data):
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def fromjson(cls, data: dict) -> Self:
|
||||
data = data.copy()
|
||||
|
||||
# todo: upgrade_json
|
||||
cls.upgrade_json(data)
|
||||
|
||||
kwargs = {}
|
||||
no_init_data = {}
|
||||
for field_ in dataclass_fields(cls):
|
||||
raw_value = data.get(field_.name, None)
|
||||
field_format = cls.get_field_format(field_.name)
|
||||
if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
|
||||
value = field_format.fromjson(raw_value)
|
||||
else:
|
||||
if field_.metadata.get("json_embed"):
|
||||
value = field_.type.fromjson(data)
|
||||
else:
|
||||
value = field_.type.fromjson(raw_value)
|
||||
if field_.init:
|
||||
kwargs[field_.name] = value
|
||||
else:
|
||||
no_init_data[field_.name] = value
|
||||
|
||||
if cls.union_type_field:
|
||||
try:
|
||||
type_value = no_init_data.pop(cls.union_type_field)
|
||||
except KeyError:
|
||||
raise TypeError('union_type_field %s.%s is missing' %
|
||||
(cls.__name__, cls.union_type_field))
|
||||
try:
|
||||
klass = cls.get_type(type_value)
|
||||
except KeyError:
|
||||
raise TypeError('union_type_field %s.%s value 0x%02x not known' %
|
||||
(cls.__name__, cls.union_type_field, type_value))
|
||||
return klass.fromjson(data)
|
||||
|
||||
return cls(**kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_c_struct_items(cls, ignore_fields=None, no_empty=False, top_level=False, union_only=False, in_union=False):
|
||||
ignore_fields = set() if not ignore_fields else set(ignore_fields)
|
||||
|
||||
items = []
|
||||
|
||||
for field_ in dataclass_fields(cls):
|
||||
if field_.name in ignore_fields:
|
||||
continue
|
||||
if in_union and field_.metadata["defining_class"] != cls:
|
||||
continue
|
||||
|
||||
name = field_.metadata.get("c_name", field_.name)
|
||||
field_format = cls.get_field_format(field_.name)
|
||||
if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
|
||||
if not field_.metadata.get("union_discriminator") or field_.metadata.get("defining_class") == cls:
|
||||
items.append((
|
||||
(
|
||||
("%(typedef_name)s %(name)s;" % {
|
||||
"typedef_name": field_format.get_typedef_name(),
|
||||
"name": name,
|
||||
})
|
||||
if field_.metadata.get("as_definition")
|
||||
else field_format.get_c_code(name)
|
||||
),
|
||||
field_.metadata.get("doc", None),
|
||||
)),
|
||||
else:
|
||||
if field_.metadata.get("c_embed"):
|
||||
embedded_items = field_.type.get_c_struct_items(ignore_fields, no_empty, top_level, union_only)
|
||||
items.extend(embedded_items)
|
||||
else:
|
||||
items.append((
|
||||
(
|
||||
("%(typedef_name)s %(name)s;" % {
|
||||
"typedef_name": field_.type.get_typedef_name(),
|
||||
"name": name,
|
||||
})
|
||||
if field_.metadata.get("as_definition")
|
||||
else field_.type.get_c_code(name, typedef=False)
|
||||
),
|
||||
field_.metadata.get("doc", None),
|
||||
))
|
||||
|
||||
if cls.union_type_field:
|
||||
if not union_only:
|
||||
union_code = cls.get_c_union_code(ignore_fields)
|
||||
items.append(("union __packed %s;" % union_code, ""))
|
||||
|
||||
return items
|
||||
|
||||
@classmethod
|
||||
def get_c_union_size(cls):
|
||||
return max(
|
||||
(option.get_min_size(no_inherited_fields=True) for option in
|
||||
cls._union_options[cls.union_type_field].values()),
|
||||
default=0,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_c_definitions(cls) -> dict[str, str]:
|
||||
definitions = {}
|
||||
for field_ in dataclass_fields(cls):
|
||||
field_format = cls.get_field_format(field_.name)
|
||||
if not (isinstance(field_format, type) and issubclass(field_format, StructType)):
|
||||
definitions.update(field_format.get_c_definitions())
|
||||
if field_.metadata.get("as_definition"):
|
||||
typedef_name = field_format.get_typedef_name()
|
||||
definitions[typedef_name] = 'typedef %(code)s %(name)s;' % {
|
||||
"code": ''.join(field_format.get_c_parts()),
|
||||
"name": typedef_name,
|
||||
}
|
||||
else:
|
||||
definitions.update(field_.type.get_c_definitions())
|
||||
if field_.metadata.get("as_definition"):
|
||||
typedef_name = field_.type.get_typedef_name()
|
||||
definitions[typedef_name] = field_.type.get_c_code(name=typedef_name, typedef=True)
|
||||
if cls.union_type_field:
|
||||
for key, option in cls._union_options[cls.union_type_field].items():
|
||||
definitions.update(option.get_c_definitions())
|
||||
return definitions
|
||||
|
||||
@classmethod
|
||||
def get_c_union_code(cls, ignore_fields=None):
|
||||
union_items = []
|
||||
for key, option in cls._union_options[cls.union_type_field].items():
|
||||
base_name = normalize_name(getattr(key, 'name', option.__name__))
|
||||
item_c_code = option.get_c_code(
|
||||
base_name, ignore_fields=ignore_fields, typedef=False, in_union=True, no_empty=True
|
||||
)
|
||||
if item_c_code:
|
||||
union_items.append(item_c_code)
|
||||
size = cls.get_c_union_size()
|
||||
union_items.append(
|
||||
"uint8_t bytes[%0d]; " % size
|
||||
)
|
||||
return "{\n" + indent_c("\n".join(union_items)) + "\n}"
|
||||
|
||||
@classmethod
|
||||
def get_c_parts(cls, ignore_fields=None, no_empty=False, top_level=False, union_only=False, in_union=False):
|
||||
if cls.existing_c_struct is not None:
|
||||
return (cls.existing_c_struct, "")
|
||||
|
||||
ignore_fields = set() if not ignore_fields else set(ignore_fields)
|
||||
|
||||
if union_only:
|
||||
if cls.union_type_field:
|
||||
union_code = cls.get_c_union_code(ignore_fields)
|
||||
return "typedef union __packed %s" % union_code, ""
|
||||
else:
|
||||
return "", ""
|
||||
|
||||
pre = ""
|
||||
|
||||
items = cls.get_c_struct_items(ignore_fields=ignore_fields,
|
||||
no_empty=no_empty,
|
||||
top_level=top_level,
|
||||
union_only=union_only,
|
||||
in_union=in_union)
|
||||
|
||||
if no_empty and not items:
|
||||
return "", ""
|
||||
|
||||
# todo: struct comment
|
||||
if top_level:
|
||||
comment = cls.__doc__.strip()
|
||||
if comment:
|
||||
pre += "/** %s */\n" % comment
|
||||
pre += "typedef struct __packed "
|
||||
else:
|
||||
pre += "struct __packed "
|
||||
|
||||
pre += "{\n%(elements)s\n}" % {
|
||||
"elements": indent_c(
|
||||
"\n".join(
|
||||
code + ("" if not comment else (" /** %s */" % comment))
|
||||
for code, comment in items
|
||||
)
|
||||
),
|
||||
}
|
||||
return pre, ""
|
||||
|
||||
@classmethod
|
||||
def get_c_code(cls, name=None, ignore_fields=None, no_empty=False, typedef=True, union_only=False,
|
||||
in_union=False) -> str:
|
||||
pre, post = cls.get_c_parts(ignore_fields=ignore_fields,
|
||||
no_empty=no_empty,
|
||||
top_level=typedef,
|
||||
union_only=union_only,
|
||||
in_union=in_union)
|
||||
if no_empty and not pre and not post:
|
||||
return ""
|
||||
return "%s %s%s;" % (pre, name, post)
|
||||
|
||||
@classmethod
|
||||
def get_variable_name(cls, base_name):
|
||||
return base_name
|
||||
|
||||
@classmethod
|
||||
def get_typedef_name(cls):
|
||||
return "%s_t" % normalize_name(cls.__name__)
|
||||
|
||||
@classmethod
|
||||
def get_min_size(cls, no_inherited_fields=False) -> int:
|
||||
if cls.union_type_field:
|
||||
own_size = sum([cls.get_field_format(f.name).get_min_size() for f in dataclass_fields(cls)])
|
||||
union_size = max(
|
||||
[0] + [option.get_min_size(True) for option in cls._union_options[cls.union_type_field].values()])
|
||||
return own_size + union_size
|
||||
if no_inherited_fields:
|
||||
relevant_fields = [f for f in dataclass_fields(cls) if f.metadata["defining_class"] == cls]
|
||||
else:
|
||||
relevant_fields = [f for f in dataclass_fields(cls) if not f.metadata.get("union_discriminator")]
|
||||
return sum((cls.get_field_format(f.name).get_min_size() for f in relevant_fields), start=0)
|
||||
|
||||
@classmethod
|
||||
def get_size(cls, no_inherited_fields=False, calculate_max=False) -> int:
|
||||
if cls.union_type_field:
|
||||
own_size = sum(
|
||||
[cls.get_field_format(f.name).get_size(calculate_max=calculate_max) for f in dataclass_fields(cls)])
|
||||
union_size = max(
|
||||
[0] + [option.get_size(no_inherited_fields=True, calculate_max=calculate_max) for option in
|
||||
cls._union_options[cls.union_type_field].values()])
|
||||
return own_size + union_size
|
||||
if no_inherited_fields:
|
||||
relevant_fields = [f for f in dataclass_fields(cls) if f.metadata["defining_class"] == cls]
|
||||
else:
|
||||
relevant_fields = [f for f in dataclass_fields(cls) if not f.metadata.get("union_discriminator")]
|
||||
return sum((cls.get_field_format(f.name).get_size(calculate_max=calculate_max) for f in relevant_fields),
|
||||
start=0)
|
||||
|
||||
|
||||
def normalize_name(name):
|
||||
if '_' in name:
|
||||
name = name.lower()
|
||||
else:
|
||||
name = re.sub(
|
||||
r"(([a-z])([A-Z]))|(([a-zA-Z])([A-Z][a-z]))",
|
||||
r"\2\5_\3\6",
|
||||
name
|
||||
).lower()
|
||||
|
||||
name = re.sub(
|
||||
r"(ota)([a-z])",
|
||||
r"\1_\2",
|
||||
name
|
||||
).lower()
|
||||
|
||||
name = name.replace('config', 'cfg')
|
||||
name = name.replace('position', 'pos')
|
||||
name = name.replace('mesh_', '')
|
||||
name = name.replace('firmware', 'fw')
|
||||
name = name.replace('hardware', 'hw')
|
||||
return name
|
961  src/c3nav/mesh/cformats.py  Normal file

@@ -0,0 +1,961 @@
|
|||
import re
|
||||
import struct
|
||||
import typing
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import namedtuple
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import fields as dataclass_fields
|
||||
from enum import IntEnum, Enum
|
||||
from typing import Any, Sequence, Self, Annotated, Literal, Union, Type, TypeVar, ClassVar
|
||||
|
||||
from annotated_types import SLOTS, BaseMetadata, Ge
|
||||
from pydantic.fields import Field, FieldInfo
|
||||
from pydantic_extra_types.mac_address import MacAddress
|
||||
|
||||
from c3nav.mesh.utils import indent_c
|
||||
|
||||
|
||||
@dataclass(frozen=True, **SLOTS)
|
||||
class VarLen(BaseMetadata):
|
||||
var_len_name: str = "num"
|
||||
|
||||
|
||||
@dataclass(frozen=True, **SLOTS)
|
||||
class NoDef(BaseMetadata):
|
||||
no_def: bool = True
|
||||
|
||||
|
||||
@dataclass(frozen=True, **SLOTS)
|
||||
class AsHex(BaseMetadata):
|
||||
as_hex: bool = True
|
||||
|
||||
|
||||
@dataclass(frozen=True, **SLOTS)
|
||||
class LenBytes(BaseMetadata):
|
||||
len_bytes: Annotated[int, Ge(1)]
|
||||
|
||||
|
||||
@dataclass(frozen=True, **SLOTS)
|
||||
class AsDefinition(BaseMetadata):
|
||||
as_definition: bool = True
|
||||
|
||||
|
||||
@dataclass(frozen=True, **SLOTS)
|
||||
class CEmbed(BaseMetadata):
|
||||
c_embed: bool = True
|
||||
|
||||
|
||||
@dataclass(frozen=True, **SLOTS)
|
||||
class CName(BaseMetadata):
|
||||
c_name: str
|
||||
|
||||
|
||||
@dataclass(frozen=True, **SLOTS)
|
||||
class CDoc(BaseMetadata):
|
||||
c_doc: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExistingCStruct():
|
||||
name: str
|
||||
includes: list[str]
|
||||
|
||||
|
||||
class CEnum(str, Enum):
|
||||
def __new__(cls, value, c_value):
|
||||
obj = str.__new__(cls)
|
||||
obj._value_ = value
|
||||
obj.c_value = c_value
|
||||
return obj
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.value)
|
||||
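A small sketch of the CEnum convention above (ExampleChipType and its values are hypothetical): each member carries a string value for pydantic/JSON plus a separate integer c_value for the wire/C side.

class ExampleChipType(CEnum):
    ESP32 = "ESP32", 1              # JSON value "ESP32", encoded on the wire as 1
    ESP32_C3 = "ESP32_C3", 5

assert ExampleChipType.ESP32_C3.value == "ESP32_C3"
assert ExampleChipType.ESP32_C3.c_value == 5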
|
||||
|
||||
def discriminator_value(**kwargs):
|
||||
return type('DiscriminatorValue', (), {
|
||||
# todo: make this so pydantic doesn't throw a warning
|
||||
**{name: value for name, value in kwargs.items()},
|
||||
'__annotations__': {
|
||||
name: Annotated[Literal[value], Field(init=False)]
|
||||
for name, value in kwargs.items()
|
||||
}
|
||||
})
|
||||
|
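Roughly what the factory above returns (msg_type and "ping" are hypothetical names): a tiny mixin class that pins a Literal-typed discriminator field with a default value, so message models that inherit it can be told apart in a pydantic discriminated union.

PingDiscriminator = discriminator_value(msg_type="ping")
assert PingDiscriminator.msg_type == "ping"
assert "msg_type" in PingDiscriminator.__annotations__    # Annotated[Literal["ping"], Field(init=False)]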
||||
|
||||
class TwoNibblesEncodable:
|
||||
pass
|
||||
|
||||
|
||||
class SplitTypeHint(namedtuple("SplitTypeHint", ("base", "metadata"))):
|
||||
@classmethod
|
||||
def from_annotation(cls, type_hint) -> Self:
|
||||
if typing.get_origin(type_hint) is Annotated:
|
||||
field_infos = tuple(m for m in type_hint.__metadata__ if isinstance(m, FieldInfo))
|
||||
return cls(
|
||||
base=typing.get_args(type_hint)[0],
|
||||
metadata=(
|
||||
*(m for m in type_hint.__metadata__),
|
||||
*(tuple(field_infos[0].metadata) if field_infos else ())
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(type_hint, FieldInfo):
|
||||
return cls(
|
||||
base=type_hint.annotation,
|
||||
metadata=tuple(type_hint.metadata)
|
||||
)
|
||||
|
||||
return cls(
|
||||
base=type_hint,
|
||||
metadata=()
|
||||
)
|
||||
|
||||
def get_len_metadata(self):
|
||||
max_length = None
|
||||
var_len_name = None
|
||||
for m in self.metadata:
|
||||
ml = getattr(m, 'max_length', None)
|
||||
if ml is not None:
|
||||
max_length = ml if max_length is None else min(max_length, ml)
|
||||
|
||||
vl = getattr(m, 'var_len_name', None)
|
||||
if vl is not None:
|
||||
if var_len_name is not None:
|
||||
raise ValueError('can\'t set variable length name twice')
|
||||
var_len_name = vl
|
||||
return max_length, var_len_name
|
||||
|
||||
def get_min_max_metadata(self, default_min=-(2 ** 63), default_max=2 ** 63 - 1):
|
||||
min_ = default_min
|
||||
max_ = default_max
|
||||
for m in self.metadata:
|
||||
gt = getattr(m, 'gt', None)
|
||||
if gt is not None:
|
||||
min_ = max(min_, gt + 1)
|
||||
ge = getattr(m, 'ge', None)
|
||||
if ge is not None:
|
||||
min_ = max(min_, ge)
|
||||
lt = getattr(m, 'lt', None)
|
||||
if lt is not None:
|
||||
max_ = min(max_, lt - 1)
|
||||
le = getattr(m, 'le', None)
|
||||
if le is not None:
|
||||
max_ = min(max_, le)
|
||||
return min_, max_
|
||||
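A quick sketch of how annotated bounds are collected (the annotation is hypothetical; the result follows from the code above):

from typing import Annotated
from annotated_types import Ge, Lt

hint = SplitTypeHint.from_annotation(Annotated[int, Ge(0), Lt(256)])
assert hint.get_min_max_metadata() == (0, 255)    # fits a uint8_t, i.e. struct code "B"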
|
||||
|
||||
def normalize_name(name):
|
||||
if '_' in name:
|
||||
name = name.lower()
|
||||
else:
|
||||
name = re.sub(
|
||||
r"(([a-z])([A-Z]))|(([a-zA-Z])([A-Z][a-z]))",
|
||||
r"\2\5_\3\6",
|
||||
name
|
||||
).lower()
|
||||
|
||||
name = re.sub(
|
||||
r"(ota)([a-z])",
|
||||
r"\1_\2",
|
||||
name
|
||||
).lower()
|
||||
|
||||
name = name.replace('config', 'cfg')
|
||||
name = name.replace('position', 'pos')
|
||||
name = name.replace('mesh_', '')
|
||||
name = name.replace('firmware', 'fw')
|
||||
name = name.replace('hardware', 'hw')
|
||||
return name
|
||||
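For reference, a few traced examples of the name mangling above (the class names are only illustrations):

assert normalize_name("FirmwareImage") == "fw_image"
assert normalize_name("MeshOtaStatusMessage") == "ota_status_message"
assert normalize_name("ConfigPosition") == "cfg_pos"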
|
||||
|
||||
class CFormat(ABC):
|
||||
# todo: make this some cool generic with a TypeVar
|
||||
|
||||
def get_var_num(self):
|
||||
return 0
|
||||
|
||||
@abstractmethod
|
||||
def encode(self, value):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def decode(cls, data) -> tuple[Any, bytes]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_min_size(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_max_size(self):
|
||||
pass
|
||||
|
||||
def get_size(self, calculate_max=False):
|
||||
if calculate_max:
|
||||
return self.get_max_size()
|
||||
else:
|
||||
return self.get_min_size()
|
||||
|
||||
@abstractmethod
|
||||
def get_c_parts(self) -> tuple[str, str]:
|
||||
pass
|
||||
|
||||
def get_c_code(self, name) -> str:
|
||||
pre, post = self.get_c_parts()
|
||||
return "%s %s%s;" % (pre, name, post)
|
||||
|
||||
def get_c_definitions(self) -> dict[str, str]:
|
||||
return {}
|
||||
|
||||
def get_typedef_name(self):
|
||||
raise TypeError('no typedef for %r' % self)
|
||||
|
||||
def get_c_includes(self) -> set:
|
||||
return set()
|
||||
|
||||
@classmethod
|
||||
def from_annotation(cls, annotation, attr_name=None) -> Self:
|
||||
if cls is not CFormat:
|
||||
raise TypeError('call on CFormat!')
|
||||
return cls.from_split_type_hint(SplitTypeHint.from_annotation(annotation), attr_name=attr_name)
|
||||
|
||||
@classmethod
|
||||
def from_split_type_hint(cls, type_hint: SplitTypeHint, attr_name=None) -> Self:
|
||||
if cls is not CFormat:
|
||||
raise TypeError('call on CFormat!')
|
||||
outer_type_hint = None
|
||||
if typing.get_origin(type_hint.base) is list:
|
||||
outer_type_hint = SplitTypeHint(
|
||||
base=list,
|
||||
metadata=type_hint.metadata
|
||||
)
|
||||
type_hint = SplitTypeHint(
|
||||
base=typing.get_args(type_hint.base)[0],
|
||||
metadata=()
|
||||
)
|
||||
if typing.get_origin(type_hint.base) is Annotated:
|
||||
type_hint = SplitTypeHint(
|
||||
base=typing.get_args(type_hint.base)[0],
|
||||
metadata=tuple(type_hint.base.__metadata__)
|
||||
)
|
||||
|
||||
field_format = None
|
||||
|
||||
if typing.get_origin(type_hint.base) is Literal:
|
||||
literal_val = typing.get_args(type_hint.base)[0]
|
||||
if isinstance(literal_val, CEnum):
|
||||
options = [v.c_value for v in type(literal_val)]
|
||||
literal_val = literal_val.c_value
|
||||
int_type = get_int_type(
|
||||
*type_hint.get_min_max_metadata(default_min=min(options), default_max=max(options))
|
||||
)
|
||||
elif isinstance(literal_val, int):
|
||||
int_type = get_int_type(literal_val, literal_val)
|
||||
else:
|
||||
raise ValueError()
|
||||
if int_type is None:
|
||||
raise ValueError('invalid range:', attr_name)
|
||||
field_format = SimpleConstFormat(int_type, const_value=literal_val)
|
||||
elif typing.get_origin(type_hint.base) is Union:
|
||||
discriminator = None
|
||||
for m in type_hint.metadata:
|
||||
discriminator = getattr(m, 'discriminator', discriminator)
|
||||
if discriminator is None:
|
||||
raise ValueError('no discriminator')
|
||||
discriminator_as_hex = any(getattr(m, "as_hex", False) for m in type_hint.metadata)
|
||||
field_format = UnionFormat(
|
||||
model_formats=[StructFormat(type_) for type_ in typing.get_args(type_hint.base)],
|
||||
discriminator=discriminator,
|
||||
discriminator_as_hex=discriminator_as_hex,
|
||||
)
|
||||
elif type_hint.base is int:
|
||||
int_type = get_int_type(*type_hint.get_min_max_metadata())
|
||||
if int_type is None:
|
||||
raise ValueError('invalid range:', attr_name)
|
||||
field_format = SimpleFormat(int_type)
|
||||
elif type_hint.base is bool:
|
||||
field_format = BoolFormat()
|
||||
elif type_hint.base in (str, bytes):
|
||||
as_hex = any(getattr(m, 'as_hex', False) for m in type_hint.metadata)
|
||||
max_length, var_len_name = type_hint.get_len_metadata()
|
||||
if max_length is None:
|
||||
raise ValueError('missing str max_length:', attr_name)
|
||||
|
||||
if type_hint.base is str:
|
||||
if var_len_name is not None:
|
||||
field_format = VarStrFormat(max_len=max_length)
|
||||
else:
|
||||
field_format = FixedHexFormat(max_length//2) if as_hex else FixedStrFormat(max_length)
|
||||
else:
|
||||
if var_len_name is None:
|
||||
field_format = FixedBytesFormat(num=max_length)
|
||||
else:
|
||||
field_format = VarBytesFormat(max_size=max_length)
|
||||
elif type_hint.base is MacAddress:
|
||||
field_format = MacAddressFormat()
|
||||
elif isinstance(type_hint.base, type) and issubclass(type_hint.base, CEnum):
|
||||
no_def = any(getattr(m, 'no_def', False) for m in type_hint.metadata)
|
||||
as_hex = any(getattr(m, 'as_hex', False) for m in type_hint.metadata)
|
||||
len_bytes = None
|
||||
for m in type_hint.metadata:
|
||||
len_bytes = getattr(m, 'len_bytes', len_bytes)
|
||||
|
||||
if len_bytes:
|
||||
int_type = get_int_type(0, 2 ** (8 * len_bytes - 1))
|
||||
else:
|
||||
options = [v.c_value for v in type_hint.base]
|
||||
int_type = get_int_type(min(options), max(options))
|
||||
if int_type is None:
|
||||
raise ValueError('invalid range:', attr_name)
|
||||
field_format = EnumFormat(enum_cls=type_hint.base, fmt=int_type, as_hex=as_hex, c_definition=not no_def)
|
||||
elif isinstance(type_hint.base, type) and issubclass(type_hint.base, TwoNibblesEncodable):
|
||||
field_format = TwoNibblesEnumFormat(type_hint.base)
|
||||
elif isinstance(type_hint.base, type) and typing.get_type_hints(type_hint.base):
|
||||
field_format = StructFormat(model=type_hint.base)
|
||||
|
||||
if field_format is None:
|
||||
raise ValueError('Unknown type annotation for c structs', type_hint.base)
|
||||
else:
|
||||
if outer_type_hint is not None and outer_type_hint.base is list:
|
||||
max_length, var_len_name = outer_type_hint.get_len_metadata()
|
||||
if max_length is None:
|
||||
raise ValueError('missing list max_length:', attr_name)
|
||||
if var_len_name:
|
||||
field_format = VarArrayFormat(field_format, max_num=max_length)
|
||||
else:
|
||||
raise ValueError('fixed-len list not implemented:', attr_name)
|
||||
|
||||
return field_format
|
||||
|
||||
|
||||
def get_int_type(min_: int, max_: int) -> str | None:
|
||||
if min_ < 0:
|
||||
if min_ < -(2 ** 63) or max_ > 2 ** 63 - 1:
|
||||
return None
|
||||
elif min_ < -(2 ** 31) or max_ > 2 ** 31 - 1:
|
||||
return "q"
|
||||
elif min_ < -(2 ** 15) or max_ > 2 ** 15 - 1:
|
||||
return "i"
|
||||
elif min_ < -(2 ** 7) or max_ > 2 ** 7 - 1:
|
||||
return "h"
|
||||
else:
|
||||
return "b"
|
||||
|
||||
if max_ > 2 ** 64 - 1:
|
||||
return None
|
||||
elif max_ > 2 ** 32 - 1:
|
||||
return "Q"
|
||||
elif max_ > 2 ** 16 - 1:
|
||||
return "I"
|
||||
elif max_ > 2 ** 8 - 1:
|
||||
return "H"
|
||||
else:
|
||||
return "B"
|
||||
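Sanity checks for the range-to-struct-code mapping above (the values follow directly from the branches):

assert get_int_type(0, 255) == "B"           # uint8_t
assert get_int_type(0, 65535) == "H"         # uint16_t
assert get_int_type(-1, 100) == "b"          # int8_t
assert get_int_type(-40000, 40000) == "i"    # int32_t
assert get_int_type(0, 2 ** 64) is None      # does not fit any fixed-width type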
|
||||
|
||||
class SimpleFormat(CFormat):
|
||||
def __init__(self, fmt):
|
||||
self.fmt = fmt
|
||||
self.size = struct.calcsize(fmt)
|
||||
|
||||
self.c_type = self.c_types[self.fmt[-1]]
|
||||
self.num = int(self.fmt[:-1]) if len(self.fmt) > 1 else 1
|
||||
|
||||
def encode(self, value):
|
||||
if self.num == 1:
|
||||
return struct.pack(self.fmt, value)
|
||||
return struct.pack(self.fmt, *value)
|
||||
|
||||
def decode(self, data: bytes) -> tuple[Any, bytes]:
|
||||
value = struct.unpack(self.fmt, data[:self.size])
|
||||
if len(value) == 1:
|
||||
value = value[0]
|
||||
return value, data[self.size:]
|
||||
|
||||
def get_min_size(self):
|
||||
return self.size
|
||||
|
||||
def get_max_size(self):
|
||||
return self.size
|
||||
|
||||
c_types = {
|
||||
"B": "uint8_t",
|
||||
"H": "uint16_t",
|
||||
"I": "uint32_t",
|
||||
"Q": "uint64_t",
|
||||
"b": "int8_t",
|
||||
"h": "int16_t",
|
||||
"i": "int32_t",
|
||||
"q": "int64_t",
|
||||
"s": "char",
|
||||
}
|
||||
|
||||
def get_c_parts(self):
|
||||
return self.c_type, ("" if self.num == 1 else ("[%d]" % self.num))
|
||||
|
||||
|
||||
class SimpleConstFormat(SimpleFormat):
|
||||
def __init__(self, fmt, const_value: int):
|
||||
super().__init__(fmt)
|
||||
self.const_value = const_value
|
||||
|
||||
def decode(self, data: bytes) -> tuple[Any, bytes]:
|
||||
value, out_data = super().decode(data)
|
||||
if value != self.const_value:
|
||||
raise ValueError('const_value is wrong')
|
||||
return value, out_data
|
||||
|
||||
|
||||
class EnumFormat(SimpleFormat):
|
||||
def __init__(self, enum_cls: Type[CEnum], fmt="B", *, as_hex=False, c_definition=True):
|
||||
super().__init__(fmt)
|
||||
self.enum_cls = enum_cls
|
||||
self.enum_lookup = {v.c_value: v for v in enum_cls}
|
||||
if len(self.enum_cls) != len(self.enum_lookup):
|
||||
raise ValueError
|
||||
self.as_hex = as_hex
|
||||
self.c_definition = c_definition
|
||||
|
||||
self.c_struct_name = normalize_name(enum_cls.__name__) + '_t'
|
||||
|
||||
def decode(self, data: bytes) -> tuple[Any, bytes]:
|
||||
value, out_data = super().decode(data)
|
||||
return self.enum_lookup[value], out_data
|
||||
|
||||
def get_typedef_name(self):
|
||||
return '%s_t' % normalize_name(self.enum_cls.__name__)
|
||||
|
||||
def get_c_parts(self):
|
||||
if not self.c_definition:
|
||||
return super().get_c_parts()
|
||||
return self.c_struct_name, ""
|
||||
|
||||
def get_c_definitions(self) -> dict[str, str]:
|
||||
if not self.c_definition:
|
||||
return {}
|
||||
prefix = normalize_name(self.enum_cls.__name__).upper()
|
||||
options = []
|
||||
last_value = None
|
||||
for item in self.enum_cls:
|
||||
if last_value is not None and item.c_value != last_value + 1:
|
||||
options.append('')
|
||||
last_value = item.c_value
|
||||
options.append("%(prefix)s_%(name)s = %(value)s," % {
|
||||
"prefix": prefix,
|
||||
"name": normalize_name(item.name).upper(),
|
||||
"value": ("0x%02x" if self.as_hex else "%d") % item.c_value
|
||||
})
|
||||
|
||||
return {
|
||||
self.c_struct_name: "enum {\n%(options)s\n};\ntypedef uint8_t %(name)s;" % {
|
||||
"options": indent_c("\n".join(options)),
|
||||
"name": self.c_struct_name,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class TwoNibblesEnumFormat(SimpleFormat):
|
||||
def __init__(self, data_cls):
|
||||
self.data_cls = data_cls
|
||||
super().__init__('B')
|
||||
|
||||
def decode(self, data: bytes) -> tuple[bool, bytes]:
|
||||
fields = dataclass_fields(self.data_cls)
|
||||
value, data = super().decode(data)
|
||||
return self.data_cls(fields[0].type(value // 2 ** 4), fields[1].type(value // 2 ** 4)), data
|
||||
|
||||
def encode(self, value):
|
||||
fields = dataclass_fields(self.data_cls)
|
||||
return super().encode(
|
||||
getattr(value, fields[0].name).value * 2 ** 4 +
|
||||
getattr(value, fields[1].name).value * 2 ** 4
|
||||
)
|
||||
|
||||
|
||||
class BoolFormat(SimpleFormat):
|
||||
def __init__(self):
|
||||
super().__init__('B')
|
||||
|
||||
def encode(self, value):
|
||||
return super().encode(int(value))
|
||||
|
||||
def decode(self, data: bytes) -> tuple[bool, bytes]:
|
||||
value, data = super().decode(data)
|
||||
if value > 1:
|
||||
raise ValueError('Boolean value > 1')
|
||||
return bool(value), data
|
||||
|
||||
|
||||
class FixedStrFormat(SimpleFormat):
|
||||
def __init__(self, num):
|
||||
self.num = num
|
||||
super().__init__('%ds' % self.num)
|
||||
|
||||
def encode(self, value: str):
|
||||
return value.encode()[:self.num].ljust(self.num, bytes((0,))),
|
||||
|
||||
def decode(self, data: bytes) -> tuple[str, bytes]:
|
||||
return data[:self.num].rstrip(bytes((0,))).decode(), data[self.num:]
|
||||
|
||||
|
||||
class FixedBytesFormat(SimpleFormat):
|
||||
def __init__(self, num):
|
||||
self.num = num
|
||||
super().__init__('%dB' % self.num)
|
||||
|
||||
def encode(self, value: str):
|
||||
return super().encode(tuple(value))
|
||||
|
||||
def decode(self, data: bytes) -> tuple[bytes, bytes]:
|
||||
return data[:self.num], data[self.num:]
|
||||
|
||||
|
||||
class FixedHexFormat(SimpleFormat):
|
||||
def __init__(self, num, sep=''):
|
||||
self.num = num
|
||||
self.sep = sep
|
||||
super().__init__('%dB' % self.num)
|
||||
|
||||
def encode(self, value: str):
|
||||
return super().encode(tuple(bytes.fromhex(value.replace(':', ''))))
|
||||
|
||||
def decode(self, data: bytes) -> tuple[str, bytes]:
|
||||
return self.sep.join(('%02x' % i) for i in data[:self.num]), data[self.num:]
|
||||
|
||||
|
||||
class MacAddressFormat(FixedHexFormat):
|
||||
def __init__(self):
|
||||
super().__init__(num=6, sep=':')
|
||||
|
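A round-trip sketch for MacAddressFormat as defined above (the address is arbitrary):

fmt = MacAddressFormat()
raw = fmt.encode("d4:f9:8d:01:02:03")
assert raw == bytes([0xd4, 0xf9, 0x8d, 0x01, 0x02, 0x03])
assert fmt.decode(raw + b"rest") == ("d4:f9:8d:01:02:03", b"rest")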
||||
|
||||
class BaseVarFormat(CFormat, ABC):
|
||||
def __init__(self, max_num):
|
||||
self.num_fmt = 'H'
|
||||
self.num_size = struct.calcsize(self.num_fmt)
|
||||
self.max_num = max_num
|
||||
|
||||
def get_min_size(self):
|
||||
return self.num_size
|
||||
|
||||
def get_max_size(self):
|
||||
return self.num_size + self.max_num * self.get_var_num()
|
||||
|
||||
def get_num_c_code(self):
|
||||
return SimpleFormat(self.num_fmt).get_c_code("num")
|
||||
|
||||
|
||||
class VarArrayFormat(BaseVarFormat):
|
||||
def __init__(self, child_type, max_num):
|
||||
super().__init__(max_num=max_num)
|
||||
self.child_type = child_type
|
||||
self.child_size = self.child_type.get_min_size()
|
||||
|
||||
def get_var_num(self):
|
||||
return self.child_size
|
||||
pass
|
||||
|
||||
def encode(self, values: Sequence) -> bytes:
|
||||
num = len(values)
|
||||
if num > self.max_num:
|
||||
raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}')
|
||||
data = struct.pack(self.num_fmt, num)
|
||||
for value in values:
|
||||
data += self.child_type.encode(value)
|
||||
return data
|
||||
|
||||
def decode(self, data: bytes) -> tuple[list[Any], bytes]:
|
||||
num = struct.unpack(self.num_fmt, data[:self.num_size])[0]
|
||||
if num > self.max_num:
|
||||
raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}')
|
||||
data = data[self.num_size:]
|
||||
result = []
|
||||
for i in range(num):
|
||||
item, data = self.child_type.decode(data)
|
||||
result.append(item)
|
||||
return result, data
|
||||
|
||||
def get_c_parts(self):
|
||||
pre, post = self.child_type.get_c_parts()
|
||||
return super().get_num_c_code() + "\n" + pre, "[0]" + post
|
||||
|
||||
|
||||
class VarStrFormat(BaseVarFormat):
|
||||
def __init__(self, max_len):
|
||||
super().__init__(max_num=max_len)
|
||||
|
||||
def get_var_num(self):
|
||||
return 1
|
||||
|
||||
def encode(self, value: str) -> bytes:
|
||||
num = len(value)
|
||||
if num > self.max_num:
|
||||
raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}')
|
||||
return struct.pack(self.num_fmt, num) + value.encode()
|
||||
|
||||
def decode(self, data: bytes) -> tuple[str, bytes]:
|
||||
num = struct.unpack(self.num_fmt, data[:self.num_size])[0]
|
||||
if num > self.max_num:
|
||||
raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}')
|
||||
return data[self.num_size:self.num_size + num].rstrip(bytes((0,))).decode(), data[self.num_size + num:]
|
||||
|
||||
def get_c_parts(self):
|
||||
return super().get_num_c_code() + "\n" + "char", "[0]"
|
||||
|
||||
|
||||
class VarBytesFormat(BaseVarFormat):
|
||||
def __init__(self, max_size):
|
||||
super().__init__(max_num=max_size)
|
||||
|
||||
def get_var_num(self):
|
||||
return 1
|
||||
|
||||
def encode(self, value: bytes) -> bytes:
|
||||
num = len(value)
|
||||
if num > self.max_num:
|
||||
raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}')
|
||||
return struct.pack(self.num_fmt, num) + value
|
||||
|
||||
def decode(self, data: bytes) -> tuple[bytes, bytes]:
|
||||
num = struct.unpack(self.num_fmt, data[:self.num_size])[0]
|
||||
if num > self.max_num:
|
||||
raise ValueError(f'too many elements, got {num} but maximum is {self.max_num}')
|
||||
return data[self.num_size:self.num_size + num].rstrip(bytes((0,))), data[self.num_size + num:]
|
||||
|
||||
def get_c_parts(self):
|
||||
return super().get_num_c_code() + "\n" + "uint8_t", "[0]"
|
||||
|
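A round-trip sketch for the length-prefixed bytes format above: encode() writes a native-endian uint16 count ('H') followed by the payload, and decode() reads it back and returns the remaining buffer.

fmt = VarBytesFormat(max_size=16)
value, rest = fmt.decode(fmt.encode(b"abc") + b"XY")
assert (value, rest) == (b"abc", b"XY")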
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
class CFormatDecodeError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class StructFormat(CFormat):
|
||||
_format_cache: dict[Type, dict[str, CFormat]] = {}
|
||||
|
||||
def __new__(cls, model: Type[T]):
|
||||
result = cls._format_cache.get(model, None)
|
||||
if not result:
|
||||
result = super().__new__(cls)
|
||||
cls._format_cache.get(model, result)
|
||||
return result
|
||||
|
||||
def __init__(self, model: Type[T]):
|
||||
self.model = model
|
||||
|
||||
self._field_formats = {}
|
||||
self._as_definition = set()
|
||||
self._c_embed = set()
|
||||
self._c_names = {}
|
||||
self._c_docs = {}
|
||||
self._no_init_data = set()
|
||||
for name, type_hint in typing.get_type_hints(self.model, include_extras=True).items():
|
||||
if type_hint is ClassVar:
|
||||
continue
|
||||
type_hint = SplitTypeHint.from_annotation(type_hint)
|
||||
|
||||
if any(getattr(m, "as_definition", False) for m in type_hint.metadata):
|
||||
self._as_definition.add(name)
|
||||
if any(getattr(m, "c_embed", False) for m in type_hint.metadata):
|
||||
self._c_embed.add(name)
|
||||
if not all(getattr(m, "init", True) for m in type_hint.metadata):
|
||||
self._no_init_data.add(name)
|
||||
for m in type_hint.metadata:
|
||||
with suppress(AttributeError):
|
||||
self._c_names[name] = m.c_name
|
||||
with suppress(AttributeError):
|
||||
self._c_docs[name] = m.c_doc
|
||||
|
||||
self._field_formats[name] = CFormat.from_split_type_hint(type_hint, attr_name=name)
|
||||
|
||||
def get_var_num(self):
|
||||
return sum([field_format.get_var_num() for name, field_format in self._field_formats.items()], start=0)
|
||||
|
||||
def encode(self, instance: T, ignore_fields=()) -> bytes:
|
||||
data = bytes()
|
||||
for name, field_format in self._field_formats.items():
|
||||
if name in ignore_fields:
|
||||
continue
|
||||
data += field_format.encode(getattr(instance, name))
|
||||
return data
|
||||
|
||||
def decode(self, data: bytes) -> tuple[T, bytes]:
|
||||
decoded = {}
|
||||
for name, field_format in self._field_formats.items():
|
||||
try:
|
||||
value, data = field_format.decode(data)
|
||||
except (struct.error, UnicodeDecodeError, ValueError) as e:
|
||||
raise CFormatDecodeError(f"failed to decode model={self.model}, field={name}, data={data}, e={e}")
|
||||
if isinstance(value, CEnum):
|
||||
value = value.value
|
||||
if name not in self._no_init_data:
|
||||
decoded[name] = value
|
||||
return self.model.model_validate(decoded), data
|
||||
|
||||
def get_min_size(self) -> int:
|
||||
return sum((
|
||||
field_format.get_min_size() for field_format in self._field_formats.values()
|
||||
), start=0)
|
||||
|
||||
def get_max_size(self) -> int:
|
||||
raise ValueError
|
||||
|
||||
def get_size(self, calculate_max=False):
|
||||
return sum((
|
||||
field_format.get_size(calculate_max=calculate_max) for field_format in self._field_formats.values()
|
||||
), start=0)
|
||||
|
||||
def get_c_struct_items(self, ignore_fields=None, no_empty=False, top_level=False):
|
||||
ignore_fields = set() if not ignore_fields else set(ignore_fields)
|
||||
|
||||
items = []
|
||||
|
||||
for name, field_format in self._field_formats.items():
|
||||
if name in ignore_fields:
|
||||
continue
|
||||
|
||||
c_name = self._c_names.get(name, name)
|
||||
if not isinstance(field_format, (StructFormat, UnionFormat)):
|
||||
items.append((
|
||||
(
|
||||
("%(typedef_name)s %(name)s;" % {
|
||||
"typedef_name": field_format.get_typedef_name(),
|
||||
"name": c_name,
|
||||
})
|
||||
if name in self._as_definition
|
||||
else field_format.get_c_code(c_name)
|
||||
),
|
||||
self._c_docs.get(name, None),
|
||||
)),
|
||||
else:
|
||||
if name in self._c_embed:
|
||||
embedded_items = field_format.get_c_struct_items(ignore_fields, no_empty, top_level)
|
||||
items.extend(embedded_items)
|
||||
else:
|
||||
items.append((
|
||||
(
|
||||
("%(typedef_name)s %(name)s;" % {
|
||||
"typedef_name": field_format.get_typedef_name(),
|
||||
"name": c_name,
|
||||
})
|
||||
if name in self._as_definition
|
||||
else field_format.get_c_code(c_name, typedef=False)
|
||||
),
|
||||
self._c_docs.get(name, None),
|
||||
))
|
||||
|
||||
return items
|
||||
|
||||
def get_c_parts(self, ignore_fields=None, no_empty=False, top_level=False) -> tuple[str, str]:
|
||||
with suppress(AttributeError):
|
||||
return (self.model.existing_c_struct.name, "")
|
||||
|
||||
ignore_fields = set() if not ignore_fields else set(ignore_fields)
|
||||
|
||||
pre = ""
|
||||
|
||||
items = self.get_c_struct_items(ignore_fields=ignore_fields,
|
||||
no_empty=no_empty,
|
||||
top_level=top_level)
|
||||
|
||||
if no_empty and not items:
|
||||
return "", ""
|
||||
|
||||
if top_level:
|
||||
comment = self.model.__doc__.strip()
|
||||
if comment:
|
||||
pre += "/** %s */\n" % comment
|
||||
pre += "typedef struct __packed "
|
||||
else:
|
||||
pre += "struct __packed "
|
||||
|
||||
pre += "{\n%(elements)s\n}" % {
|
||||
"elements": indent_c(
|
||||
"\n".join(
|
||||
code + ("" if not comment else (" /** %s */" % comment))
|
||||
for code, comment in items
|
||||
)
|
||||
),
|
||||
}
|
||||
return pre, ""
|
||||
|
||||
def get_c_code(self, name=None, ignore_fields=None, no_empty=False, typedef=True) -> str:
|
||||
pre, post = self.get_c_parts(ignore_fields=ignore_fields,
|
||||
no_empty=no_empty,
|
||||
top_level=typedef)
|
||||
if no_empty and not pre and not post:
|
||||
return ""
|
||||
return "%s %s%s;" % (pre, name, post)
|
||||
|
||||
def get_c_definitions(self) -> dict[str, str]:
|
||||
definitions = {}
|
||||
for name, field_format in self._field_formats.items():
|
||||
definitions.update(field_format.get_c_definitions())
|
||||
if name in self._as_definition:
|
||||
typedef_name = field_format.get_typedef_name()
|
||||
if not isinstance(field_format, StructFormat):
|
||||
definitions[typedef_name] = 'typedef %(code)s %(name)s;' % {
|
||||
"code": ''.join(field_format.get_c_parts()),
|
||||
"name": typedef_name,
|
||||
}
|
||||
else:
|
||||
definitions[typedef_name] = field_format.get_c_code(name=typedef_name, typedef=True)
|
||||
return definitions
|
||||
|
||||
def get_typedef_name(self):
|
||||
return "%s_t" % normalize_name(self.model.__name__)
|
||||
|
||||
def get_c_includes(self) -> set:
|
||||
result = set()
|
||||
with suppress(AttributeError):
|
||||
result.update(self.model.existing_c_struct.includes)
|
||||
for field_format in self._field_formats.values():
|
||||
result.update(field_format.get_c_includes())
|
||||
return result
|
||||
|
||||
|
||||
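For reference, a minimal sketch of how a StructFormat derived from a pydantic model is meant to round-trip values. That CFormat.from_annotation() resolves a BaseModel annotation to the StructFormat above is taken from this commit; the example model and the exact byte layout are illustrative assumptions, not part of the change:

from typing import Annotated

from annotated_types import Lt
from pydantic import BaseModel
from pydantic.types import NonNegativeInt

from c3nav.mesh.cformats import CFormat


class PinPair(BaseModel):
    """ hypothetical two-byte struct, not part of this commit """
    gpio_a: Annotated[NonNegativeInt, Lt(2**8)]
    gpio_b: Annotated[NonNegativeInt, Lt(2**8)]


fmt = CFormat.from_annotation(PinPair)           # expected to resolve to a StructFormat
raw = fmt.encode(PinPair(gpio_a=4, gpio_b=5))    # presumably packs two uint8 values
obj, rest = fmt.decode(raw)                      # decode() returns (model, remaining bytes)
assert obj == PinPair(gpio_a=4, gpio_b=5) and rest == b""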
class UnionFormat(CFormat):
|
||||
def __init__(self, model_formats: Sequence[StructFormat], discriminator: str, discriminator_as_hex: bool = False):
|
||||
self.discriminator = discriminator
|
||||
models = {
|
||||
getattr(model_format.model, discriminator): model_format for model_format in model_formats
|
||||
}
|
||||
if len(models) != len(model_formats):
|
||||
raise ValueError
|
||||
types = set(type(value) for value in models.keys())
|
||||
if len(types) != 1:
|
||||
raise ValueError
|
||||
discriminator_annotation = tuple(types)[0]
|
||||
if discriminator_as_hex:
|
||||
discriminator_annotation = Annotated[discriminator_annotation, AsHex()]
|
||||
self.discriminator_format = CFormat.from_annotation(discriminator_annotation)
|
||||
self.key_to_name = {value.c_value: value.name for value in models.keys()}
|
||||
self.models = {value.c_value: model_format for value, model_format in models.items()}
|
||||
|
||||
def get_var_num(self):
|
||||
return 0 # todo: is this always correct?
|
||||
|
||||
def encode(self, instance) -> bytes:
|
||||
discriminator_value = getattr(instance, self.discriminator)
|
||||
try:
|
||||
model_format = self.models[discriminator_value.c_value]
|
||||
except KeyError:
|
||||
raise ValueError('Unknown discriminator value for Union: %r' % discriminator_value)
|
||||
if not isinstance(instance, model_format.model):
|
||||
raise ValueError('Unknown value for Union discriminator %r: %r' % (discriminator_value, instance))
|
||||
return (
|
||||
self.discriminator_format.encode(discriminator_value.c_value)
|
||||
+ model_format.encode(instance, ignore_fields=(self.discriminator, ))
|
||||
)
|
||||
|
||||
def decode(self, data: bytes) -> tuple[T, bytes]:
|
||||
discriminator_value, remaining_data = self.discriminator_format.decode(data)
|
||||
return self.models[discriminator_value.c_value].decode(data)
|
||||
|
||||
def get_min_size(self) -> int:
|
||||
return max([0] + [
|
||||
model_format.get_min_size()
|
||||
for model_format in self.models.values()
|
||||
])
|
||||
|
||||
def get_max_size(self) -> int:
|
||||
raise ValueError
|
||||
|
||||
def get_size(self, calculate_max=False):
|
||||
return max([0] + [
|
||||
field_format.get_size(calculate_max=calculate_max)
|
||||
for field_format in self.models.values()
|
||||
])
|
||||
|
||||
def get_c_struct_items(self, ignore_fields=None, no_empty=False, top_level=False):
|
||||
return [
|
||||
(self.discriminator_format.get_c_code(self.discriminator), None),
|
||||
("union __packed %s;" % self.get_c_union_code(), None),
|
||||
]
|
||||
|
||||
def get_c_union_size(self):
|
||||
return max(
|
||||
(model_format.get_min_size() for model_format in self.models.values()),
|
||||
default=0,
|
||||
) - self.discriminator_format.get_min_size()
|
||||
|
||||
def get_c_union_code(self):
|
||||
union_items = []
|
||||
for key, model_format in self.models.items():
|
||||
base_name = normalize_name(self.key_to_name[key])
|
||||
item_c_code = model_format.get_c_code(
|
||||
base_name, ignore_fields=(self.discriminator, ), typedef=False, no_empty=True
|
||||
)
|
||||
if item_c_code:
|
||||
union_items.append(item_c_code)
|
||||
size = self.get_c_union_size()
|
||||
union_items.append(
|
||||
"uint8_t bytes[%0d];" % size
|
||||
)
|
||||
return "{\n" + indent_c("\n".join(union_items)) + "\n}"
|
||||
|
||||
def get_c_parts(self, ignore_fields=None, no_empty=False, top_level=False) -> tuple[str, str]:
|
||||
items = self.get_c_struct_items(no_empty=no_empty,
|
||||
top_level=top_level)
|
||||
|
||||
if no_empty and not items:
|
||||
return "", ""
|
||||
|
||||
if top_level:
|
||||
pre = "typedef struct __packed "
|
||||
else:
|
||||
pre = "struct __packed "
|
||||
|
||||
pre += "{\n%(elements)s\n}" % {
|
||||
"elements": indent_c(
|
||||
"\n".join(
|
||||
code + ("" if not comment else (" /** %s */" % comment))
|
||||
for code, comment in items
|
||||
)
|
||||
),
|
||||
}
|
||||
return pre, ""
|
||||
|
||||
def get_c_code(self, name=None, ignore_fields=None, no_empty=False, typedef=True,) -> str:
|
||||
pre, post = self.get_c_parts(ignore_fields=ignore_fields,
|
||||
no_empty=no_empty,
|
||||
top_level=typedef)
|
||||
if no_empty and not pre and not post:
|
||||
return ""
|
||||
return "%s %s%s;" % (pre, name, post)
|
||||
|
||||
def get_c_definitions(self) -> dict[str, str]:
|
||||
definitions = {}
|
||||
definitions.update(self.discriminator_format.get_c_definitions())
|
||||
for model_format in self.models.values():
|
||||
definitions.update(model_format.get_c_definitions())
|
||||
return definitions
|
||||
|
||||
def get_typedef_name(self):
|
||||
names = [model_format.model.__name__ for model_format in self.models.values()]
|
||||
min_len = min(len(name) for name in names)
|
||||
longest_prefix = ''
|
||||
longest_suffix = ''
|
||||
for i in reversed(range(min_len)):
|
||||
a = set(name[:i] for name in names)
|
||||
if len(a) == 1:
|
||||
longest_prefix = tuple(a)[0]
|
||||
break
|
||||
for i in reversed(range(min_len)):
|
||||
a = set(name[-i:] for name in names)
|
||||
if len(a) == 1:
|
||||
longest_suffix = tuple(a)[0]
|
||||
break
|
||||
return "%s_t" % normalize_name(longest_prefix if len(longest_prefix) > len(longest_suffix) else longest_suffix)
|
||||
|
||||
def get_c_includes(self) -> set:
|
||||
result = set()
|
||||
result.update(self.discriminator_format.get_c_includes())
|
||||
for model_format in self.models.values():
|
||||
result.update(model_format.get_c_includes())
|
||||
return result
|
|
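As a usage sketch of UnionFormat: the MeshMessageContent union added in messages.py further down should resolve to a UnionFormat, with the msg_type discriminator encoded first and the matching per-message struct selected again on decode. The concrete message and byte layout here are illustrative assumptions:

from c3nav.mesh.cformats import CFormat
from c3nav.mesh.messages import EchoRequestMessage, MeshMessageContent, MeshMessageType

content_format = CFormat.from_annotation(MeshMessageContent)   # expected to be a UnionFormat
payload = content_format.encode(EchoRequestMessage(content="ping"))
decoded, rest = content_format.decode(payload)                 # discriminator selects the struct
assert decoded.msg_type == MeshMessageType.ECHO_REQUEST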
@@ -17,6 +17,7 @@ from django.utils import timezone
|
|||
from django.utils.crypto import constant_time_compare
|
||||
|
||||
from c3nav.mesh import messages
|
||||
from c3nav.mesh.cformats import CFormat
|
||||
from c3nav.mesh.messages import (MESH_BROADCAST_ADDRESS, MESH_NONE_ADDRESS, MESH_ROOT_ADDRESS, OTA_CHUNK_SIZE,
|
||||
MeshMessage, MeshMessageType, OTAApplyMessage, OTASettingMessage)
|
||||
from c3nav.mesh.models import MeshNode, MeshUplink, NodeMessage, OTARecipientStatus, OTAUpdate, OTAUpdateRecipient
|
||||
|
@@ -46,6 +47,7 @@ class NodeState:
|
|||
|
||||
|
||||
class MeshConsumer(AsyncWebsocketConsumer):
|
||||
mesh_msg_format = CFormat.from_annotation(MeshMessage)
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.uplink = None
|
||||
|
@@ -56,6 +58,7 @@ class MeshConsumer(AsyncWebsocketConsumer):
|
|||
self.ota_send_task = None
|
||||
self.ota_chunks: dict[int, set[int]] = {} # keys are update IDs, values are a list of chunk IDs
|
||||
self.ota_chunks_available_condition = asyncio.Condition()
|
||||
self.accepted = None
|
||||
|
||||
async def connect(self):
|
||||
self.headers = dict(self.scope["headers"])
|
||||
|
@@ -68,13 +71,16 @@ class MeshConsumer(AsyncWebsocketConsumer):
|
|||
self.ping_task = get_event_loop().create_task(self.ping_regularly())
|
||||
self.check_node_state_task = get_event_loop().create_task(self.check_node_states())
|
||||
self.ota_send_task = get_event_loop().create_task(self.ota_send())
|
||||
self.accepted = True
|
||||
|
||||
async def disconnect(self, close_code):
|
||||
if not self.accepted:
|
||||
return
|
||||
self.ping_task.cancel()
|
||||
self.check_node_state_task.cancel()
|
||||
self.ota_send_task.cancel()
|
||||
await self.log_text(self.uplink.node, "mesh websocket disconnected")
|
||||
if self.uplink is not None:
|
||||
await self.log_text(self.uplink.node, "mesh websocket disconnected")
|
||||
# leave broadcast group
|
||||
await self.channel_layer.group_discard("mesh_comm_broadcast", self.channel_name)
|
||||
|
||||
|
@@ -91,10 +97,10 @@ class MeshConsumer(AsyncWebsocketConsumer):
|
|||
end_reason=MeshUplink.EndReason.CLOSED
|
||||
)
|
||||
|
||||
async def send_msg(self, msg, sender=None, exclude_uplink_address=None):
|
||||
async def send_msg(self, msg: MeshMessage, sender=None, exclude_uplink_address=None):
|
||||
# print("sending", msg, MeshMessage.encode(msg).hex(' ', 1))
|
||||
# self.log_text(msg.dst, "sending %s" % msg)
|
||||
await self.send(bytes_data=MeshMessage.encode(msg))
|
||||
# self.log_text(msg_envelope.dst, "sending %s" % msg)
|
||||
await self.send(bytes_data=self.mesh_msg_format.encode(msg))
|
||||
await self.channel_layer.group_send("mesh_msg_sent", {
|
||||
"type": "mesh.msg_sent",
|
||||
"timestamp": timezone.now().strftime("%d.%m.%y %H:%M:%S.%f"),
|
||||
|
@@ -113,18 +119,20 @@ class MeshConsumer(AsyncWebsocketConsumer):
|
|||
if bytes_data is None:
|
||||
return
|
||||
try:
|
||||
msg, data = messages.MeshMessage.decode(bytes_data)
|
||||
msg, data = self.mesh_msg_format.decode(bytes_data)
|
||||
msg: MeshMessage
|
||||
except Exception:
|
||||
print("Unable to decode: ")
|
||||
print("Unable to decode: msg_type=", hex(bytes_data[12]))
|
||||
print(bytes_data)
|
||||
traceback.print_exc()
|
||||
return
|
||||
msg.content = msg.content
|
||||
|
||||
# print(msg)
|
||||
|
||||
if msg.dst != messages.MESH_ROOT_ADDRESS and msg.dst != messages.MESH_PARENT_ADDRESS:
|
||||
# message not addressed to us, forward it
|
||||
print('Received message for forwarding:', msg)
|
||||
print('Received message for forwarding:', msg.content)
|
||||
|
||||
if not self.uplink:
|
||||
await self.log_text(None, "received message not for us before sign in message, ignoring...")
|
||||
|
@@ -132,12 +140,12 @@ class MeshConsumer(AsyncWebsocketConsumer):
|
|||
return
|
||||
|
||||
# trace messages collect node addresses before forwarding
|
||||
if isinstance(msg, messages.MeshRouteTraceMessage):
|
||||
if isinstance(msg.content, messages.MeshRouteTraceMessage):
|
||||
print('adding ourselves to trace message before forwarding')
|
||||
await self.log_text(MESH_ROOT_ADDRESS, "adding ourselves to trace message before forwarding")
|
||||
msg.trace.append(MESH_ROOT_ADDRESS)
|
||||
msg.content.trace.append(MESH_ROOT_ADDRESS)
|
||||
|
||||
result = await msg.send(exclude_uplink_address=self.uplink.node.address)
|
||||
result = await msg.content.send(exclude_uplink_address=self.uplink.node.address)
|
||||
|
||||
if not result:
|
||||
print('message had no route')
|
||||
|
@@ -154,7 +162,7 @@ class MeshConsumer(AsyncWebsocketConsumer):
|
|||
|
||||
src_node, created = await MeshNode.objects.aget_or_create(address=msg.src)
|
||||
|
||||
if isinstance(msg, messages.MeshSigninMessage):
|
||||
if isinstance(msg.content, messages.MeshSigninMessage):
|
||||
if not self.check_valid_address(msg.src):
|
||||
print('reject node with invalid address')
|
||||
await self.close()
|
||||
|
@@ -172,10 +180,12 @@ class MeshConsumer(AsyncWebsocketConsumer):
|
|||
await self.log_received_message(src_node, msg)
|
||||
|
||||
# inform signed in uplink node about its layer
|
||||
await self.send_msg(messages.MeshLayerAnnounceMessage(
|
||||
await self.send_msg(messages.MeshMessage(
|
||||
src=messages.MESH_ROOT_ADDRESS,
|
||||
dst=msg.src,
|
||||
layer=messages.NO_LAYER
|
||||
content=messages.MeshLayerAnnounceMessage(
|
||||
layer=messages.NO_LAYER
|
||||
)
|
||||
))
|
||||
|
||||
# add signed in uplink node to broadcast group
|
||||
|
@@ -186,7 +196,7 @@ class MeshConsumer(AsyncWebsocketConsumer):
|
|||
|
||||
# add this node as a destination that this uplink handles (duh)
|
||||
await self.add_dst_nodes(nodes=(src_node, ))
|
||||
self.dst_nodes[msg.src].last_msg[MeshMessageType.MESH_SIGNIN] = msg
|
||||
self.dst_nodes[msg.src].last_msg[MeshMessageType.MESH_SIGNIN] = msg.content
|
||||
|
||||
return
|
||||
|
||||
|
@@ -202,38 +212,42 @@ class MeshConsumer(AsyncWebsocketConsumer):
|
|||
except KeyError:
|
||||
print('unexpected message from', msg.src)
|
||||
return
|
||||
node_status.last_msg[msg.msg_type] = msg
|
||||
node_status.last_msg[msg.content.msg_type] = msg.content
|
||||
|
||||
if isinstance(msg, messages.MeshAddDestinationsMessage):
|
||||
result = await self.add_dst_nodes(addresses=msg.addresses)
|
||||
if isinstance(msg.content, messages.MeshAddDestinationsMessage):
|
||||
result = await self.add_dst_nodes(addresses=msg.content.addresses)
|
||||
if not result:
|
||||
print('disconnecting node that sent invalid destinations', msg)
|
||||
print('disconnecting node that sent invalid destinations', msg.content)
|
||||
await self.close()
|
||||
|
||||
if isinstance(msg, messages.MeshRemoveDestinationsMessage):
|
||||
await self.remove_dst_nodes(addresses=msg.addresses)
|
||||
if isinstance(msg.content, messages.MeshRemoveDestinationsMessage):
|
||||
await self.remove_dst_nodes(addresses=msg.content.addresses)
|
||||
|
||||
if isinstance(msg, messages.MeshRouteRequestMessage):
|
||||
if msg.address == MESH_ROOT_ADDRESS:
|
||||
if isinstance(msg.content, messages.MeshRouteRequestMessage):
|
||||
if msg.content.address == MESH_ROOT_ADDRESS:
|
||||
await self.log_text(MESH_ROOT_ADDRESS, "route request about us, start a trace")
|
||||
await self.send_msg(messages.MeshRouteTraceMessage(
|
||||
await self.send_msg(messages.MeshMessage(
|
||||
src=MESH_ROOT_ADDRESS,
|
||||
dst=msg.src,
|
||||
request_id=msg.request_id,
|
||||
trace=[MESH_ROOT_ADDRESS],
|
||||
content=messages.MeshRouteTraceMessage(
|
||||
request_id=msg.content.request_id,
|
||||
trace=[MESH_ROOT_ADDRESS],
|
||||
)
|
||||
))
|
||||
else:
|
||||
await self.log_text(MESH_ROOT_ADDRESS, "route request about someone else, sending response")
|
||||
self.open_requests.add(msg.request_id)
|
||||
uplink = database_sync_to_async(MeshNode.get_node_and_uplink)(msg.address)
|
||||
await self.send_msg(messages.MeshRouteResponseMessage(
|
||||
self.open_requests.add(msg.content.request_id)
|
||||
uplink = database_sync_to_async(MeshNode.get_node_and_uplink)(msg.content.address)
|
||||
await self.send_msg(messages.MeshMessage(
|
||||
src=MESH_ROOT_ADDRESS,
|
||||
dst=msg.src,
|
||||
request_id=msg.request_id,
|
||||
route=uplink.node_id if uplink else MESH_NONE_ADDRESS,
|
||||
content=messages.MeshRouteResponseMessage(
|
||||
request_id=msg.content.request_id,
|
||||
route=uplink.node_id if uplink else MESH_NONE_ADDRESS,
|
||||
)
|
||||
))
|
||||
|
||||
if isinstance(msg, (messages.ConfigHardwareMessage,
|
||||
if isinstance(msg.content, (messages.ConfigHardwareMessage,
|
||||
messages.ConfigFirmwareMessage,
|
||||
messages.ConfigBoardMessage)):
|
||||
if (node_status.waiting_for == NodeWaitingFor.CONFIG and
|
||||
|
@@ -243,16 +257,16 @@ class MeshConsumer(AsyncWebsocketConsumer):
|
|||
print('got all config, checking ota')
|
||||
await self.check_ota([msg.src], first_time=True)
|
||||
|
||||
if isinstance(msg, messages.OTAStatusMessage):
|
||||
print('got OTA status', msg)
|
||||
node_status.reported_ota_update = msg.update_id
|
||||
if isinstance(msg.content, messages.OTAStatusMessage):
|
||||
print('got OTA status', msg.content)
|
||||
node_status.reported_ota_update = msg.content.update_id
|
||||
if node_status.waiting_for == NodeWaitingFor.OTA_START_STOP:
|
||||
update_id = node_status.ota_recipient.update_id if node_status.ota_recipient else 0
|
||||
if update_id == msg.update_id:
|
||||
if update_id == msg.content.update_id:
|
||||
print('start/cancel confirmed!')
|
||||
node_status.waiting_for = NodeWaitingFor.NOTHING
|
||||
if update_id:
|
||||
if msg.status.is_failed:
|
||||
if msg.content.status.is_failed:
|
||||
print('ota failed')
|
||||
node_status.ota_recipient.status = OTARecipientStatus.FAILED
|
||||
await node_status.ota_recipient.send_status()
|
||||
|
@@ -263,14 +277,14 @@ class MeshConsumer(AsyncWebsocketConsumer):
|
|||
else:
|
||||
print('queue chunk sending')
|
||||
await self.ota_set_chunks(node_status.ota_recipient.update,
|
||||
min_chunk=msg.next_expected_chunk)
|
||||
min_chunk=msg.content.next_expected_chunk)
|
||||
|
||||
if isinstance(msg, messages.OTARequestFragmentsMessage):
|
||||
print('got OTA fragment request', msg)
|
||||
if isinstance(msg.content, messages.OTARequestFragmentsMessage):
|
||||
print('got OTA fragment request', msg.content)
|
||||
desired_update_id = node_status.ota_recipient.update_id if node_status.ota_recipient else 0
|
||||
if desired_update_id and msg.update_id == desired_update_id:
|
||||
if desired_update_id and msg.content.update_id == desired_update_id:
|
||||
print('queue requested chunk sending')
|
||||
await self.ota_set_chunks(node_status.ota_recipient.update, chunks=set(msg.chunks))
|
||||
await self.ota_set_chunks(node_status.ota_recipient.update, chunks=set(msg.content.chunks))
|
||||
|
||||
@database_sync_to_async
|
||||
def create_uplink_in_database(self, address):
|
||||
|
@@ -334,7 +348,7 @@ class MeshConsumer(AsyncWebsocketConsumer):
|
|||
self.uplink.node.address, "we're the route for this message but it came from here so... no"
|
||||
)
|
||||
return
|
||||
await self.send_msg(MeshMessage.fromjson(data["msg"]), data["sender"])
|
||||
await self.send_msg(MeshMessage.model_validate(data["msg"]), data["sender"])
|
||||
|
||||
async def mesh_ota_recipients_changed(self, data):
|
||||
addresses = set(data["addresses"]) & set(self.dst_nodes.keys())
|
||||
|
@@ -349,7 +363,7 @@ class MeshConsumer(AsyncWebsocketConsumer):
|
|||
"""
|
||||
|
||||
async def log_received_message(self, src_node: MeshNode, msg: messages.MeshMessage):
|
||||
as_json = MeshMessage.tojson(msg)
|
||||
as_json = msg.model_dump()
|
||||
await self.channel_layer.group_send("mesh_msg_received", {
|
||||
"type": "mesh.msg_received",
|
||||
"timestamp": timezone.now().strftime("%d.%m.%y %H:%M:%S.%f"),
|
||||
|
@@ -360,7 +374,7 @@ class MeshConsumer(AsyncWebsocketConsumer):
|
|||
await NodeMessage.objects.acreate(
|
||||
uplink=self.uplink,
|
||||
src_node=src_node,
|
||||
message_type=msg.msg_type.name,
|
||||
message_type=msg.content.msg_type.name,
|
||||
data=as_json,
|
||||
)
|
||||
|
||||
|
@@ -435,29 +449,34 @@ class MeshConsumer(AsyncWebsocketConsumer):
|
|||
node_state.last_sent = timezone.now()
|
||||
print('request config dump, attempt #%d' % node_state.attempt)
|
||||
node_state.attempt += 1
|
||||
await self.send_msg(messages.ConfigDumpMessage(
|
||||
await self.send_msg(messages.MeshMessage(
|
||||
src=MESH_ROOT_ADDRESS,
|
||||
dst=address,
|
||||
content=messages.ConfigDumpMessage()
|
||||
))
|
||||
|
||||
case NodeWaitingFor.OTA_START_STOP:
|
||||
node_state.last_sent = timezone.now()
|
||||
if node_state.ota_recipient:
|
||||
print('starting ota, attempt #%d' % node_state.attempt)
|
||||
await self.send_msg(messages.OTAStartMessage(
|
||||
await self.send_msg(messages.MeshMessage(
|
||||
src=MESH_ROOT_ADDRESS,
|
||||
dst=address,
|
||||
update_id=node_state.ota_recipient.update_id, # noqa
|
||||
total_bytes=node_state.ota_recipient.update.build.binary.size,
|
||||
auto_apply=False,
|
||||
auto_reboot=False,
|
||||
content=messages.OTAStartMessage(
|
||||
update_id=node_state.ota_recipient.update_id, # noqa
|
||||
total_bytes=node_state.ota_recipient.update.build.binary.size,
|
||||
auto_apply=False,
|
||||
auto_reboot=False,
|
||||
)
|
||||
))
|
||||
else:
|
||||
print('canceling ota, attempt #%d' % node_state.attempt)
|
||||
await self.send_msg(messages.OTAAbortMessage(
|
||||
await self.send_msg(messages.MeshMessage(
|
||||
src=MESH_ROOT_ADDRESS,
|
||||
dst=address,
|
||||
update_id=0,
|
||||
content=messages.OTAAbortMessage(
|
||||
update_id=0,
|
||||
)
|
||||
))
|
||||
|
||||
async def check_node_states(self):
|
||||
|
@@ -511,12 +530,14 @@ class MeshConsumer(AsyncWebsocketConsumer):
|
|||
with self.dst_nodes[recipients[0]].ota_recipient.update.build.binary.open('rb') as f:
|
||||
f.seek(chunk * OTA_CHUNK_SIZE)
|
||||
data = f.read(OTA_CHUNK_SIZE)
|
||||
await self.send_msg(messages.OTAFragmentMessage(
|
||||
await self.send_msg(messages.MeshMessage(
|
||||
src=MESH_ROOT_ADDRESS,
|
||||
dst=recipients[0] if len(recipients) == 1 else MESH_BROADCAST_ADDRESS,
|
||||
update_id=update_id,
|
||||
chunk=chunk,
|
||||
data=data,
|
||||
content=messages.OTAFragmentMessage(
|
||||
update_id=update_id,
|
||||
chunk=chunk,
|
||||
data=data,
|
||||
)
|
||||
))
|
||||
|
||||
# wait a bit until we send more
|
||||
|
@@ -681,7 +702,7 @@ class MeshUIConsumer(AsyncJsonWebsocketConsumer):
|
|||
self.msg_received_filter = {"request_id": msg_to_send["msg_data"]["request_id"]}
|
||||
|
||||
for recipient in msg_to_send["recipients"]:
|
||||
await MeshMessage.fromjson({
|
||||
await MeshMessage.model_validate({
|
||||
'dst': recipient,
|
||||
**msg_to_send["msg_data"],
|
||||
}).send(sender=self.channel_name)
|
||||
|
|
|
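The envelope handed to MeshMessage.model_validate() here now nests the payload under content, with the msg_type discriminator inside it. A sketch with illustrative addresses and payload values:

envelope = {
    "src": "00:00:00:00:00:00",        # MESH_ROOT_ADDRESS
    "dst": "d4:f9:8d:00:00:01",        # example recipient address
    "content": {
        "msg_type": "ECHO_REQUEST",    # discriminator, by enum member name
        "content": "ping",
    },
}
msg = MeshMessage.model_validate(envelope)
# msg.content is then an EchoRequestMessage instance thanks to Discriminator("msg_type")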
@@ -1,285 +0,0 @@
|
|||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from enum import IntEnum, unique
|
||||
from typing import BinaryIO, Self
|
||||
|
||||
from c3nav.api.utils import EnumSchemaByNameMixin
|
||||
from c3nav.mesh.baseformats import (BoolFormat, ChipRevFormat, EnumFormat, FixedHexFormat, FixedStrFormat,
|
||||
SimpleConstFormat, SimpleFormat, StructType, TwoNibblesEnumFormat, VarArrayFormat)
|
||||
|
||||
|
||||
class MacAddressFormat(FixedHexFormat):
|
||||
def __init__(self):
|
||||
super().__init__(num=6, sep=':')
|
||||
|
||||
|
||||
class MacAddressesListFormat(VarArrayFormat):
|
||||
def __init__(self, max_num):
|
||||
super().__init__(child_type=MacAddressFormat(), max_num=max_num)
|
||||
|
||||
|
||||
@unique
|
||||
class LedType(IntEnum):
|
||||
NONE = 0
|
||||
SERIAL = 1
|
||||
MULTIPIN = 2
|
||||
|
||||
@property
|
||||
def pretty_name(self):
|
||||
return self.name.lower()
|
||||
|
||||
|
||||
@unique
|
||||
class SerialLedType(IntEnum):
|
||||
WS2812 = 1
|
||||
SK6812 = 2
|
||||
|
||||
|
||||
@dataclass
|
||||
class LedConfig(StructType, union_type_field="led_type"):
|
||||
"""
|
||||
configuration for an optional connected status LED
|
||||
"""
|
||||
led_type: LedType = field(metadata={"format": EnumFormat(), "c_name": "type"})
|
||||
|
||||
|
||||
@dataclass
|
||||
class NoLedConfig(LedConfig, led_type=LedType.NONE):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class SerialLedConfig(LedConfig, led_type=LedType.SERIAL):
|
||||
serial_led_type: SerialLedType = field(metadata={"format": EnumFormat(), "c_name": "type"})
|
||||
gpio: int = field(metadata={"format": SimpleFormat('B')})
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultipinLedConfig(LedConfig, led_type=LedType.MULTIPIN):
|
||||
gpio_red: int = field(metadata={"format": SimpleFormat('B')})
|
||||
gpio_green: int = field(metadata={"format": SimpleFormat('B')})
|
||||
gpio_blue: int = field(metadata={"format": SimpleFormat('B')})
|
||||
|
||||
|
||||
@dataclass
|
||||
class BoardSPIConfig(StructType):
|
||||
"""
|
||||
configuration for spi bus used for ETH or UWB
|
||||
"""
|
||||
gpio_miso: int = field(metadata={"format": SimpleFormat('B')})
|
||||
gpio_mosi: int = field(metadata={"format": SimpleFormat('B')})
|
||||
gpio_clk: int = field(metadata={"format": SimpleFormat('B')})
|
||||
|
||||
|
||||
@dataclass
|
||||
class UWBConfig(StructType):
|
||||
"""
|
||||
configuration for the connection to the UWB module
|
||||
"""
|
||||
enable: bool = field(metadata={"format": BoolFormat()})
|
||||
gpio_cs: int = field(metadata={"format": SimpleFormat('B')})
|
||||
gpio_irq: int = field(metadata={"format": SimpleFormat('B')})
|
||||
gpio_rst: int = field(metadata={"format": SimpleFormat('B')})
|
||||
gpio_wakeup: int = field(metadata={"format": SimpleFormat('B')})
|
||||
gpio_exton: int = field(metadata={"format": SimpleFormat('B')})
|
||||
|
||||
|
||||
@dataclass
|
||||
class UplinkEthConfig(StructType):
|
||||
"""
|
||||
configuration for the connection to the ETH module
|
||||
"""
|
||||
enable: bool = field(metadata={"format": BoolFormat()})
|
||||
gpio_cs: int = field(metadata={"format": SimpleFormat('B')})
|
||||
gpio_int: int = field(metadata={"format": SimpleFormat('B')})
|
||||
gpio_rst: int = field(metadata={"format": SimpleFormat('b')})
|
||||
|
||||
|
||||
@unique
|
||||
class BoardType(EnumSchemaByNameMixin, IntEnum):
|
||||
CUSTOM = 0x00
|
||||
|
||||
# devboards
|
||||
ESP32_C3_DEVKIT_M_1 = 0x01
|
||||
ESP32_C3_32S = 2
|
||||
|
||||
# custom boards
|
||||
C3NAV_UWB_BOARD = 0x10
|
||||
C3NAV_LOCATION_PCB_REV_0_1 = 0x11
|
||||
C3NAV_LOCATION_PCB_REV_0_2 = 0x12
|
||||
|
||||
@property
|
||||
def pretty_name(self):
|
||||
if self.name.startswith('ESP32'):
|
||||
return self.name.replace('_', '-').replace('DEVKIT-', 'DevKit')
|
||||
if self.name.startswith('C3NAV'):
|
||||
name = self.name.replace('_', ' ').lower()
|
||||
name = name.replace('uwb', 'UWB').replace('pcb', 'PCB')
|
||||
name = re.sub(r'[0-9]+( [0-9+])+', lambda s: s[0].replace(' ', '.'), name)
|
||||
name = re.sub(r'rev.*', lambda s: s[0].replace(' ', ''), name)
|
||||
return name
|
||||
return self.name
|
||||
|
||||
|
||||
@dataclass
|
||||
class BoardConfig(StructType, union_type_field="board"):
|
||||
board: BoardType = field(metadata={"format": EnumFormat(as_hex=True)})
|
||||
|
||||
|
||||
@dataclass
|
||||
class CustomBoardConfig(BoardConfig, board=BoardType.CUSTOM):
|
||||
spi: BoardSPIConfig = field(metadata={"as_definition": True})
|
||||
uwb: UWBConfig = field(metadata={"as_definition": True})
|
||||
eth: UplinkEthConfig = field(metadata={"as_definition": True})
|
||||
led: LedConfig = field(metadata={"as_definition": True})
|
||||
|
||||
|
||||
@dataclass
|
||||
class DevkitMBoardConfig(BoardConfig, board=BoardType.ESP32_C3_DEVKIT_M_1):
|
||||
spi: BoardSPIConfig = field(metadata={"as_definition": True})
|
||||
uwb: UWBConfig = field(metadata={"as_definition": True})
|
||||
eth: UplinkEthConfig = field(metadata={"as_definition": True})
|
||||
|
||||
|
||||
@dataclass
|
||||
class Esp32SBoardConfig(BoardConfig, board=BoardType.ESP32_C3_32S):
|
||||
spi: BoardSPIConfig = field(metadata={"as_definition": True})
|
||||
uwb: UWBConfig = field(metadata={"as_definition": True})
|
||||
eth: UplinkEthConfig = field(metadata={"as_definition": True})
|
||||
|
||||
|
||||
@dataclass
|
||||
class UwbBoardConfig(BoardConfig, board=BoardType.C3NAV_UWB_BOARD):
|
||||
eth: UplinkEthConfig = field(metadata={"as_definition": True})
|
||||
|
||||
|
||||
@dataclass
|
||||
class LocationPCBRev0Dot1BoardConfig(BoardConfig, board=BoardType.C3NAV_LOCATION_PCB_REV_0_1):
|
||||
eth: UplinkEthConfig = field(metadata={"as_definition": True})
|
||||
|
||||
|
||||
@dataclass
|
||||
class LocationPCBRev0Dot2BoardConfig(BoardConfig, board=BoardType.C3NAV_LOCATION_PCB_REV_0_2):
|
||||
eth: UplinkEthConfig = field(metadata={"as_definition": True})
|
||||
|
||||
|
||||
@dataclass
|
||||
class RangeResultItem(StructType):
|
||||
peer: str = field(metadata={"format": MacAddressFormat()})
|
||||
rssi: int = field(metadata={"format": SimpleFormat('b')})
|
||||
distance: int = field(metadata={"format": SimpleFormat('h')})
|
||||
|
||||
|
||||
@dataclass
|
||||
class RawFTMEntry(StructType):
|
||||
dlog_token: int = field(metadata={"format": SimpleFormat('B')})
|
||||
rssi: int = field(metadata={"format": SimpleFormat('b')})
|
||||
rtt: int = field(metadata={"format": SimpleFormat('I')})
|
||||
t1: int = field(metadata={"format": SimpleFormat('Q')})
|
||||
t2: int = field(metadata={"format": SimpleFormat('Q')})
|
||||
t3: int = field(metadata={"format": SimpleFormat('Q')})
|
||||
t4: int = field(metadata={"format": SimpleFormat('Q')})
|
||||
|
||||
|
||||
@dataclass
|
||||
class FirmwareAppDescription(StructType, existing_c_struct="esp_app_desc_t", c_includes=['<esp_app_desc.h>']):
|
||||
magic_word: int = field(metadata={"format": SimpleConstFormat('I', 0xAB_CD_54_32)}, repr=False)
|
||||
secure_version: int = field(metadata={"format": SimpleFormat('I')})
|
||||
reserv1: list[int] = field(metadata={"format": SimpleFormat('2I')}, repr=False)
|
||||
version: str = field(metadata={"format": FixedStrFormat(32)})
|
||||
project_name: str = field(metadata={"format": FixedStrFormat(32)})
|
||||
compile_time: str = field(metadata={"format": FixedStrFormat(16)})
|
||||
compile_date: str = field(metadata={"format": FixedStrFormat(16)})
|
||||
idf_version: str = field(metadata={"format": FixedStrFormat(32)})
|
||||
app_elf_sha256: str = field(metadata={"format": FixedHexFormat(32)})
|
||||
reserv2: list[int] = field(metadata={"format": SimpleFormat('20I')}, repr=False)
|
||||
|
||||
|
||||
@unique
|
||||
class SPIFlashMode(EnumSchemaByNameMixin, IntEnum):
|
||||
QIO = 0
|
||||
QOUT = 1
|
||||
DIO = 2
|
||||
DOUT = 3
|
||||
|
||||
|
||||
@unique
|
||||
class FlashSize(EnumSchemaByNameMixin, IntEnum):
|
||||
SIZE_1MB = 0
|
||||
SIZE_2MB = 1
|
||||
SIZE_4MB = 2
|
||||
SIZE_8MB = 3
|
||||
SIZE_16MB = 4
|
||||
SIZE_32MB = 5
|
||||
SIZE_64MB = 6
|
||||
SIZE_128MB = 7
|
||||
|
||||
@property
|
||||
def pretty_name(self):
|
||||
return self.name.removeprefix('SIZE_')
|
||||
|
||||
|
||||
@unique
|
||||
class FlashFrequency(EnumSchemaByNameMixin, IntEnum):
|
||||
FREQ_40MHZ = 0
|
||||
FREQ_26MHZ = 1
|
||||
FREQ_20MHZ = 2
|
||||
FREQ_80MHZ = 0xf
|
||||
|
||||
@property
|
||||
def pretty_name(self):
|
||||
return self.name.removeprefix('FREQ_').replace('MHZ', 'Mhz')
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashSettings:
|
||||
size: FlashSize
|
||||
frequency: FlashFrequency
|
||||
|
||||
@property
|
||||
def display(self):
|
||||
return f"{self.size.pretty_name} ({self.frequency.pretty_name})"
|
||||
|
||||
|
||||
@unique
|
||||
class ChipType(EnumSchemaByNameMixin, IntEnum):
|
||||
ESP32_S2 = 2
|
||||
ESP32_C3 = 5
|
||||
|
||||
@property
|
||||
def pretty_name(self):
|
||||
return self.name.replace('_', '-')
|
||||
|
||||
|
||||
@dataclass
|
||||
class FirmwareImageFileHeader(StructType):
|
||||
magic_word: int = field(metadata={"format": SimpleConstFormat('B', 0xE9)}, repr=False)
|
||||
num_segments: int = field(metadata={"format": SimpleFormat('B')})
|
||||
spi_flash_mode: SPIFlashMode = field(metadata={"format": EnumFormat()})
|
||||
flash_stuff: FlashSettings = field(metadata={"format": TwoNibblesEnumFormat()})
|
||||
entry_point: int = field(metadata={"format": SimpleFormat('I')})
|
||||
|
||||
|
||||
@dataclass
|
||||
class FirmwareImageExtendedFileHeader(StructType):
|
||||
wp_pin: int = field(metadata={"format": SimpleFormat('B')})
|
||||
drive_settings: int = field(metadata={"format": SimpleFormat('3B')})
|
||||
chip: ChipType = field(metadata={"format": EnumFormat('H')})
|
||||
min_chip_rev_old: int = field(metadata={"format": SimpleFormat('B')})
|
||||
min_chip_rev: tuple[int, int] = field(metadata={"format": ChipRevFormat()})
|
||||
max_chip_rev: tuple[int, int] = field(metadata={"format": ChipRevFormat()})
|
||||
reserv: int = field(metadata={"format": SimpleFormat('I')}, repr=False)
|
||||
hash_appended: bool = field(metadata={"format": BoolFormat()})
|
||||
|
||||
|
||||
@dataclass
|
||||
class FirmwareImage(StructType):
|
||||
header: FirmwareImageFileHeader
|
||||
ext_header: FirmwareImageExtendedFileHeader
|
||||
first_segment_headers: tuple[int, int] = field(metadata={"format": SimpleFormat('2I')}, repr=False)
|
||||
app_desc: FirmwareAppDescription
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, file: BinaryIO) -> Self:
|
||||
result, data = cls.decode(file.read(FirmwareImage.get_min_size()))
|
||||
return result
|
|
@@ -13,11 +13,15 @@ from django.db import transaction
|
|||
from django.forms import BooleanField, ChoiceField, Form, ModelMultipleChoiceField, MultipleChoiceField
|
||||
from django.http import Http404
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from pydantic import ValidationError as PydanticValidationError
|
||||
from pydantic.type_adapter import TypeAdapter
|
||||
|
||||
from c3nav.mesh.dataformats import BoardConfig, BoardType, LedType, SerialLedType
|
||||
from c3nav.mesh.messages import MESH_BROADCAST_ADDRESS, MESH_ROOT_ADDRESS, MeshMessage, MeshMessageType
|
||||
from c3nav.mesh.cformats import CFormat
|
||||
from c3nav.mesh.messages import (MESH_BROADCAST_ADDRESS, MESH_ROOT_ADDRESS, MeshMessage, MeshMessageContent,
|
||||
MeshMessageType)
|
||||
from c3nav.mesh.models import (FirmwareBuild, HardwareDescription, MeshNode, OTARecipientStatus, OTAUpdate,
|
||||
OTAUpdateRecipient)
|
||||
from c3nav.mesh.schemas import BoardConfig, BoardType, LedType, SerialLedType
|
||||
from c3nav.mesh.utils import MESH_ALL_OTA_GROUP, group_msg_type_choices
|
||||
|
||||
|
||||
|
@@ -64,7 +68,7 @@ class MeshMessageForm(forms.Form):
|
|||
if cls.msg_type in MeshMessageForm.msg_types:
|
||||
raise TypeError('duplicate use of msg %s' % cls.msg_type)
|
||||
MeshMessageForm.msg_types[cls.msg_type] = cls
|
||||
cls.msg_type_class = MeshMessage.get_type(cls.msg_type)
|
||||
cls.msg_type_class = CFormat.from_annotation(MeshMessageContent).models.get(cls.msg_type.c_value).model
|
||||
|
||||
@classmethod
|
||||
def get_form_for_type(cls, msg_type):
|
||||
|
@@ -83,9 +87,11 @@ class MeshMessageForm(forms.Form):
|
|||
raise Exception('nope')
|
||||
|
||||
return {
|
||||
'msg_type': self.msg_type.name,
|
||||
'src': MESH_ROOT_ADDRESS,
|
||||
**self.get_cleaned_msg_data(),
|
||||
"src": MESH_ROOT_ADDRESS,
|
||||
"content": {
|
||||
"msg_type": self.msg_type.name,
|
||||
**self.get_cleaned_msg_data(),
|
||||
}
|
||||
}
|
||||
|
||||
def get_recipients(self):
|
||||
|
@@ -96,7 +102,7 @@ class MeshMessageForm(forms.Form):
|
|||
recipients = self.get_recipients()
|
||||
for recipient in recipients:
|
||||
print('sending to ', recipient)
|
||||
async_to_sync(MeshMessage.fromjson({
|
||||
async_to_sync(MeshMessage.model_validate({
|
||||
'dst': recipient,
|
||||
**msg_data,
|
||||
}).send)()
|
||||
|
@@ -173,8 +179,8 @@ class ConfigBoardMessageForm(MeshMessageForm):
|
|||
"prefix": "led_",
|
||||
"field": "board",
|
||||
"values": tuple(
|
||||
cfg.board.name for cfg in BoardConfig._union_options["board"].values()
|
||||
if "led" in cfg.__dataclass_fields__
|
||||
board_type.name for board_type in BoardType
|
||||
if "led" in CFormat.from_annotation(BoardConfig).models[board_type.c_value]._field_formats
|
||||
),
|
||||
},
|
||||
{
|
||||
|
@@ -191,8 +197,8 @@ class ConfigBoardMessageForm(MeshMessageForm):
|
|||
"prefix": "uwb_",
|
||||
"field": "board",
|
||||
"values": tuple(
|
||||
cfg.board.name for cfg in BoardConfig._union_options["board"].values()
|
||||
if "uwb" in cfg.__dataclass_fields__
|
||||
board_type.name for board_type in BoardType
|
||||
if "uwb" in CFormat.from_annotation(BoardConfig).models[board_type.c_value]._field_formats
|
||||
),
|
||||
},
|
||||
{
|
||||
|
@@ -204,10 +210,7 @@ class ConfigBoardMessageForm(MeshMessageForm):
|
|||
|
||||
def clean(self):
|
||||
cleaned_data = super().clean()
|
||||
|
||||
board_cfg = BoardConfig._union_options["board"][BoardType[cleaned_data["board"]]]
|
||||
has_led = "led" in board_cfg.__dataclass_fields__
|
||||
has_uwb = "uwb" in board_cfg.__dataclass_fields__
|
||||
orig_cleaned_keys = set(cleaned_data.keys())
|
||||
|
||||
led_values = {
|
||||
"led_type": cleaned_data.pop("led_type"),
|
||||
|
@@ -217,43 +220,29 @@ class ConfigBoardMessageForm(MeshMessageForm):
|
|||
if name.startswith('led_')
|
||||
}
|
||||
}
|
||||
if led_values:
|
||||
cleaned_data["led"] = led_values
|
||||
|
||||
uwb_values = {
|
||||
name.removeprefix('uwb_'): cleaned_data.pop(name)
|
||||
for name in tuple(cleaned_data.keys())
|
||||
if name.startswith('uwb_')
|
||||
}
|
||||
|
||||
errors = {}
|
||||
|
||||
if has_led:
|
||||
prefix = led_values["led_type"].lower()+'_'
|
||||
cleaned_data["led"] = {
|
||||
"led_type": led_values["led_type"],
|
||||
**{
|
||||
name.removeprefix(prefix): value
|
||||
for name, value in led_values.items()
|
||||
if name.startswith(prefix)
|
||||
}
|
||||
}
|
||||
for key, value in tuple(cleaned_data["led"].items()):
|
||||
if value is None:
|
||||
field_name = f'led_{prefix}{key}'
|
||||
if self.fields[field_name].min_value == -1:
|
||||
cleaned_data[key] = -1
|
||||
else:
|
||||
errors[field_name] = _('this field is required')
|
||||
|
||||
if has_uwb:
|
||||
if uwb_values:
|
||||
cleaned_data["uwb"] = uwb_values
|
||||
for key, value in tuple(cleaned_data["uwb"].items()):
|
||||
if value is None:
|
||||
field_name = f'uwb_{key}'
|
||||
if self.fields[field_name].min_value == -1 or not cleaned_data["uwb"]["enable"]:
|
||||
cleaned_data[key] = -1
|
||||
else:
|
||||
errors[field_name] = _('this field is required')
|
||||
|
||||
if errors:
|
||||
try:
|
||||
TypeAdapter(BoardConfig).validate_python(cleaned_data)
|
||||
except PydanticValidationError as e:
|
||||
from pprint import pprint
|
||||
pprint(e.errors())
|
||||
errors = {}
|
||||
for error in e.errors():
|
||||
loc = "_".join(s for s in error["loc"] if not s.isupper())
|
||||
if loc in orig_cleaned_keys:
|
||||
errors.setdefault(loc, []).append(error["msg"])
|
||||
else:
|
||||
errors.setdefault("__all__", []).append(f"{loc}: {error['msg']}")
|
||||
raise ValidationError(errors)
|
||||
|
||||
return cleaned_data
|
||||
|
|
|
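The loc handling above maps pydantic error locations back onto the prefixed form fields. For a hypothetical BoardConfig error (the exact loc tuple is illustrative) it works like this:

# hypothetical pydantic error entry for a missing UWB pin on a devkit board
error = {"loc": ("ESP32_C3_DEVKIT_M_1", "uwb", "gpio_irq"), "msg": "Field required"}

# the upper-cased union tag is dropped, the remaining segments joined with "_"
loc = "_".join(s for s in error["loc"] if not s.isupper())
assert loc == "uwb_gpio_irq"   # matches the uwb_-prefixed form field name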
@@ -1,52 +1,47 @@
|
|||
from dataclasses import fields
|
||||
|
||||
from django.core.management.base import BaseCommand
|
||||
|
||||
from c3nav.mesh.baseformats import StructType, normalize_name
|
||||
from c3nav.mesh.messages import MeshMessage
|
||||
from c3nav.mesh.cformats import UnionFormat, normalize_name, CFormat
|
||||
from c3nav.mesh.messages import MeshMessageContent
|
||||
from c3nav.mesh.utils import indent_c
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
help = 'export mesh message structs for c code'
|
||||
|
||||
@staticmethod
|
||||
def get_msg_c_enum_name(msg_type):
|
||||
return normalize_name(msg_type.__name__.removeprefix('Mesh').removesuffix('Message')).upper()
|
||||
|
||||
def handle(self, *args, **options):
|
||||
done_struct_names = set()
|
||||
nodata = set()
|
||||
struct_lines = {}
|
||||
struct_sizes = []
|
||||
struct_max_sizes = []
|
||||
done_definitions = set()
|
||||
|
||||
for include in StructType.c_includes:
|
||||
mesh_msg_content_format = CFormat.from_annotation(MeshMessageContent)
|
||||
if not isinstance(mesh_msg_content_format, UnionFormat):
|
||||
raise Exception('wuah')
|
||||
discriminator_size = mesh_msg_content_format.discriminator_format.get_size()
|
||||
for include in mesh_msg_content_format.get_c_includes():
|
||||
print(f'#include {include}')
|
||||
|
||||
ignore_names = set(field_.name for field_ in fields(MeshMessage))
|
||||
for msg_type, msg_class in MeshMessage.get_types().items():
|
||||
if msg_class.c_struct_name:
|
||||
if msg_class.c_struct_name in done_struct_names:
|
||||
continue
|
||||
done_struct_names.add(msg_class.c_struct_name)
|
||||
if MeshMessage.c_structs[msg_class.c_struct_name] != msg_class:
|
||||
# the purpose of MeshMessage.c_structs is unclear, currently this never triggers
|
||||
# todo get rid of the whole c_structs thing if it doesn't turn out to be useful for anything
|
||||
raise ValueError('what happened?')
|
||||
|
||||
base_name = (msg_class.c_struct_name or normalize_name(
|
||||
getattr(msg_type, 'name', msg_class.__name__)
|
||||
))
|
||||
for msg_type, msg_content_format in mesh_msg_content_format.models.items():
|
||||
base_name = normalize_name(mesh_msg_content_format.key_to_name[msg_type])
|
||||
name = "mesh_msg_%s_t" % base_name
|
||||
|
||||
for definition_name, definition in msg_class.get_c_definitions().items():
|
||||
for definition_name, definition in msg_content_format.get_c_definitions().items():
|
||||
if definition_name not in done_definitions:
|
||||
done_definitions.add(definition_name)
|
||||
print(definition)
|
||||
print()
|
||||
|
||||
code = msg_class.get_c_code(name, ignore_fields=ignore_names, no_empty=True)
|
||||
code = msg_content_format.get_c_code(name, ignore_fields=('msg_type', ), no_empty=True)
|
||||
if code:
|
||||
size = msg_class.get_size(no_inherited_fields=True, calculate_max=False)
|
||||
max_size = msg_class.get_size(no_inherited_fields=True, calculate_max=True)
|
||||
size = msg_content_format.get_size(calculate_max=False)
|
||||
max_size = msg_content_format.get_size(calculate_max=True)
|
||||
size -= discriminator_size
|
||||
max_size -= discriminator_size
|
||||
struct_lines[base_name] = "%s %s;" % (name, base_name.replace('_announce', ''))
|
||||
struct_sizes.append(size)
|
||||
struct_max_sizes.append(max_size)
|
||||
|
@@ -55,13 +50,13 @@ class Command(BaseCommand):
|
|||
(name, size))
|
||||
print()
|
||||
else:
|
||||
nodata.add(msg_class)
|
||||
nodata.add(msg_content_format.model)
|
||||
|
||||
print("/** union between all message data structs */")
|
||||
print("typedef union __packed {")
|
||||
for line in struct_lines.values():
|
||||
print(indent_c(line))
|
||||
print("} mesh_msg_data_t; ")
|
||||
print("} mesh_msg_data_t;")
|
||||
print(
|
||||
"static_assert(sizeof(mesh_msg_data_t) == %d, \"size of generated message structs is calculated wrong\");"
|
||||
% max(struct_sizes)
|
||||
|
@@ -72,20 +67,18 @@ class Command(BaseCommand):
|
|||
|
||||
print()
|
||||
|
||||
max_msg_type = max(MeshMessage.get_types().keys())
|
||||
max_msg_type = max(mesh_msg_content_format.models.keys())
|
||||
macro_data = []
|
||||
for i in range(((max_msg_type//16)+1)*16):
|
||||
msg_class = MeshMessage.get_types().get(i, None)
|
||||
if msg_class:
|
||||
name = (msg_class.c_struct_name or normalize_name(
|
||||
getattr(msg_class.msg_type, 'name', msg_class.__name__)
|
||||
))
|
||||
msg_content_format = mesh_msg_content_format.models.get(i, None)
|
||||
if msg_content_format:
|
||||
name = normalize_name(mesh_msg_content_format.key_to_name[i])
|
||||
macro_data.append((
|
||||
msg_class.get_c_enum_name(),
|
||||
("nodata" if msg_class in nodata else name),
|
||||
msg_class.get_var_num(),
|
||||
msg_class.get_size(no_inherited_fields=True, calculate_max=True),
|
||||
msg_class.__doc__.strip(),
|
||||
self.get_msg_c_enum_name(msg_content_format.model),
|
||||
("nodata" if msg_content_format.model in nodata else name),
|
||||
msg_content_format.get_var_num(), # todo: uh?
|
||||
msg_content_format.get_size(calculate_max=True) - discriminator_size,
|
||||
msg_content_format.model.__doc__.strip(),
|
||||
))
|
||||
else:
|
||||
macro_data.append((
|
||||
|
|
|
@@ -1,15 +1,16 @@
|
|||
from dataclasses import dataclass, field
|
||||
from enum import IntEnum, unique
|
||||
from typing import TypeVar
|
||||
from enum import unique
|
||||
from typing import Annotated, Union
|
||||
|
||||
import channels
|
||||
from annotated_types import Ge, Le, Lt, MaxLen
|
||||
from channels.db import database_sync_to_async
|
||||
from pydantic import PositiveInt
|
||||
from pydantic.main import BaseModel
|
||||
from pydantic.types import Discriminator, NonNegativeInt
|
||||
from pydantic_extra_types.mac_address import MacAddress
|
||||
|
||||
from c3nav.api.utils import EnumSchemaByNameMixin
|
||||
from c3nav.mesh.baseformats import (BoolFormat, EnumFormat, FixedStrFormat, SimpleFormat, StructType, VarArrayFormat,
|
||||
VarBytesFormat, VarStrFormat, normalize_name)
|
||||
from c3nav.mesh.dataformats import (BoardConfig, ChipType, FirmwareAppDescription, MacAddressesListFormat,
|
||||
MacAddressFormat, RangeResultItem, RawFTMEntry)
|
||||
from c3nav.mesh.cformats import CDoc, CEmbed, CName, LenBytes, NoDef, VarLen, discriminator_value, CEnum
|
||||
from c3nav.mesh.schemas import BoardConfig, ChipType, FirmwareAppDescription, RangeResultItem, RawFTMEntry
|
||||
from c3nav.mesh.utils import MESH_ALL_UPLINKS_GROUP
|
||||
|
||||
MESH_ROOT_ADDRESS = '00:00:00:00:00:00'
|
||||
|
@@ -23,45 +24,45 @@ OTA_CHUNK_SIZE = 512
|
|||
|
||||
|
||||
@unique
|
||||
class MeshMessageType(EnumSchemaByNameMixin, IntEnum):
|
||||
NOOP = 0x00
|
||||
class MeshMessageType(CEnum):
|
||||
NOOP = "NOOP", 0x00
|
||||
|
||||
ECHO_REQUEST = 0x01
|
||||
ECHO_RESPONSE = 0x02
|
||||
ECHO_REQUEST = "ECHO_REQUEST", 0x01
|
||||
ECHO_RESPONSE = "ECHO_RESPONSE", 0x02
|
||||
|
||||
MESH_SIGNIN = 0x03
|
||||
MESH_LAYER_ANNOUNCE = 0x04
|
||||
MESH_ADD_DESTINATIONS = 0x05
|
||||
MESH_REMOVE_DESTINATIONS = 0x06
|
||||
MESH_ROUTE_REQUEST = 0x07
|
||||
MESH_ROUTE_RESPONSE = 0x08
|
||||
MESH_ROUTE_TRACE = 0x09
|
||||
MESH_ROUTING_FAILED = 0x0a
|
||||
MESH_SIGNIN = "MESH_SIGNIN", 0x03
|
||||
MESH_LAYER_ANNOUNCE = "MESH_LAYER_ANNOUNCE", 0x04
|
||||
MESH_ADD_DESTINATIONS = "MESH_ADD_DESTINATIONS", 0x05
|
||||
MESH_REMOVE_DESTINATIONS = "MESH_REMOVE_DESTINATIONS", 0x06
|
||||
MESH_ROUTE_REQUEST = "MESH_ROUTE_REQUEST", 0x07
|
||||
MESH_ROUTE_RESPONSE = "MESH_ROUTE_RESPONSE", 0x08
|
||||
MESH_ROUTE_TRACE = "MESH_ROUTE_TRACE", 0x09
|
||||
MESH_ROUTING_FAILED = "MESH_ROUTING_FAILED", 0x0a
|
||||
|
||||
CONFIG_DUMP = 0x10
|
||||
CONFIG_HARDWARE = 0x11
|
||||
CONFIG_BOARD = 0x12
|
||||
CONFIG_FIRMWARE = 0x13
|
||||
CONFIG_UPLINK = 0x14
|
||||
CONFIG_POSITION = 0x15
|
||||
CONFIG_DUMP = "CONFIG_DUMP", 0x10
|
||||
CONFIG_HARDWARE = "CONFIG_HARDWARE", 0x11
|
||||
CONFIG_BOARD = "CONFIG_BOARD", 0x12
|
||||
CONFIG_FIRMWARE = "CONFIG_FIRMWARE", 0x13
|
||||
CONFIG_UPLINK = "CONFIG_UPLINK", 0x14
|
||||
CONFIG_POSITION = "CONFIG_POSITION", 0x15
|
||||
|
||||
OTA_STATUS = 0x20
|
||||
OTA_REQUEST_STATUS = 0x21
|
||||
OTA_START = 0x22
|
||||
OTA_URL = 0x23
|
||||
OTA_FRAGMENT = 0x24
|
||||
OTA_REQUEST_FRAGMENTS = 0x25
|
||||
OTA_SETTING = 0x26
|
||||
OTA_APPLY = 0x27
|
||||
OTA_ABORT = 0x28
|
||||
OTA_STATUS = "OTA_STATUS", 0x20
|
||||
OTA_REQUEST_STATUS = "OTA_REQUEST_STATUS", 0x21
|
||||
OTA_START = "OTA_START", 0x22
|
||||
OTA_URL = "OTA_URL", 0x23
|
||||
OTA_FRAGMENT = "OTA_FRAGMENT", 0x24
|
||||
OTA_REQUEST_FRAGMENTS = "OTA_REQUEST_FRAGMENTS", 0x25
|
||||
OTA_SETTING = "OTA_SETTING", 0x26
|
||||
OTA_APPLY = "OTA_APPLY", 0x27
|
||||
OTA_ABORT = "OTA_ABORT", 0x28
|
||||
|
||||
LOCATE_REQUEST_RANGE = 0x30
|
||||
LOCATE_RANGE_RESULTS = 0x31
|
||||
LOCATE_RAW_FTM_RESULTS = 0x32
|
||||
LOCATE_REQUEST_RANGE = "LOCATE_REQUEST_RANGE", 0x30
|
||||
LOCATE_RANGE_RESULTS = "LOCATE_RANGE_RESULTS", 0x31
|
||||
LOCATE_RAW_FTM_RESULTS = "LOCATE_RAW_FTM_RESULTS", 0x32
|
||||
|
||||
REBOOT = 0x40
|
||||
REBOOT = "REBOOT", 0x40
|
||||
|
||||
REPORT_ERROR = 0x50
|
||||
REPORT_ERROR = "REPORT_ERROR", 0x50
|
||||
|
||||
@property
|
||||
def pretty_name(self):
|
||||
|
@@ -72,32 +73,266 @@ class MeshMessageType(EnumSchemaByNameMixin, IntEnum):
|
|||
return name
|
||||
|
||||
|
||||
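The new CEnum members above pair the serialized name with the numeric wire value. Assuming the accessors used elsewhere in this commit (.c_value appears throughout), that gives roughly:

mt = MeshMessageType.CONFIG_BOARD
# mt.name    == "CONFIG_BOARD"   (name used for JSON / form data)
# mt.c_value == 0x12             (numeric value written into the packed C structs)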
M = TypeVar('M', bound='MeshMessage')
|
||||
class NoopMessage(discriminator_value(msg_type=MeshMessageType.NOOP), BaseModel):
|
||||
""" noop """
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class MeshMessage(StructType, union_type_field="msg_type"):
|
||||
dst: str = field(metadata={"format": MacAddressFormat()})
|
||||
src: str = field(metadata={"format": MacAddressFormat()})
|
||||
msg_type: MeshMessageType = field(metadata={"format": EnumFormat('B', c_definition=False)}, init=False, repr=False)
|
||||
c_structs = {}
|
||||
c_struct_name = None
|
||||
class EchoRequestMessage(discriminator_value(msg_type=MeshMessageType.ECHO_REQUEST), BaseModel):
|
||||
""" repeat back string """
|
||||
content: Annotated[str, MaxLen(255), VarLen()] = ""
|
||||
|
||||
# noinspection PyMethodOverriding
|
||||
def __init_subclass__(cls, /, c_struct_name=None, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
if c_struct_name:
|
||||
cls.c_struct_name = c_struct_name
|
||||
if c_struct_name in MeshMessage.c_structs:
|
||||
raise TypeError('duplicate use of c_struct_name %s' % c_struct_name)
|
||||
MeshMessage.c_structs[c_struct_name] = cls
|
||||
|
||||
class EchoResponseMessage(discriminator_value(msg_type=MeshMessageType.ECHO_RESPONSE), BaseModel):
|
||||
""" repeat back string """
|
||||
content: Annotated[str, MaxLen(255), VarLen()] = ""
|
||||
|
||||
|
||||
class MeshSigninMessage(discriminator_value(msg_type=MeshMessageType.MESH_SIGNIN), BaseModel):
|
||||
""" node says hello to upstream node """
|
||||
pass
|
||||
|
||||
|
||||
class MeshLayerAnnounceMessage(discriminator_value(msg_type=MeshMessageType.MESH_LAYER_ANNOUNCE), BaseModel):
|
||||
""" upstream node announces layer number """
|
||||
layer: Annotated[PositiveInt, Lt(2 ** 8), CDoc("mesh layer that the sending node is on")]
|
||||
|
||||
|
||||
class MeshAddDestinationsMessage(discriminator_value(msg_type=MeshMessageType.MESH_ADD_DESTINATIONS), BaseModel):
|
||||
""" downstream node announces served destination """
|
||||
addresses: Annotated[list[MacAddress], MaxLen(16), VarLen(), CDoc("adresses of the added destinations",)]
|
||||
|
||||
|
||||
class MeshRemoveDestinationsMessage(discriminator_value(msg_type=MeshMessageType.MESH_REMOVE_DESTINATIONS), BaseModel):
|
||||
""" downstream node announces no longer served destination """
|
||||
addresses: Annotated[list[MacAddress], MaxLen(16), VarLen(), CDoc("adresses of the removed destinations",)]
|
||||
|
||||
|
||||
class MeshRouteRequestMessage(discriminator_value(msg_type=MeshMessageType.MESH_ROUTE_REQUEST), BaseModel):
|
||||
""" request routing information for node """
|
||||
request_id: Annotated[PositiveInt, Lt(2**32)]
|
||||
address: Annotated[MacAddress, CDoc("target address for the route")]
|
||||
|
||||
|
||||
class MeshRouteResponseMessage(discriminator_value(msg_type=MeshMessageType.MESH_ROUTE_RESPONSE), BaseModel):
|
||||
""" reporting the routing table entry to the given address """
|
||||
request_id: Annotated[PositiveInt, Lt(2**32)]
|
||||
route: Annotated[MacAddress, CDoc("routing table entry or 00:00:00:00:00:00")]
|
||||
|
||||
|
||||
class MeshRouteTraceMessage(discriminator_value(msg_type=MeshMessageType.MESH_ROUTE_TRACE), BaseModel):
|
||||
""" special message, collects all hop adresses on its way """
|
||||
request_id: Annotated[PositiveInt, Lt(2**32)]
|
||||
trace: Annotated[list[MacAddress], MaxLen(16), VarLen(), CDoc("addresses encountered by this message")]
|
||||
|
||||
|
||||
class MeshRoutingFailedMessage(discriminator_value(msg_type=MeshMessageType.MESH_ROUTING_FAILED), BaseModel):
|
||||
""" TODO description"""
|
||||
address: MacAddress
|
||||
|
||||
|
||||
class ConfigDumpMessage(discriminator_value(msg_type=MeshMessageType.CONFIG_DUMP), BaseModel):
|
||||
""" request for the node to dump its config """
|
||||
pass
|
||||
|
||||
|
||||
class ConfigHardwareMessage(discriminator_value(msg_type=MeshMessageType.CONFIG_HARDWARE), BaseModel):
|
||||
""" respond hardware/chip info """
|
||||
chip: Annotated[ChipType, NoDef(), LenBytes(2), CName("chip_id")]
|
||||
revision_major: Annotated[NonNegativeInt, Lt(2**8)]
|
||||
revision_minor: Annotated[NonNegativeInt, Lt(2**8)]
|
||||
|
||||
def get_chip_display(self):
|
||||
return ChipType(self.chip).name.replace('_', '-')
|
||||
|
||||
|
||||
class ConfigBoardMessage(discriminator_value(msg_type=MeshMessageType.CONFIG_BOARD), BaseModel):
|
||||
""" set/respond board config """
|
||||
board_config: Annotated[BoardConfig, CEmbed]
|
||||
|
||||
|
||||
class ConfigFirmwareMessage(discriminator_value(msg_type=MeshMessageType.CONFIG_FIRMWARE), BaseModel):
|
||||
""" respond firmware info """
|
||||
app_desc: FirmwareAppDescription
|
||||
|
||||
|
||||
class ConfigPositionMessage(discriminator_value(msg_type=MeshMessageType.CONFIG_POSITION), BaseModel):
|
||||
""" set/respond position config """
|
||||
x_pos: Annotated[int, Ge(-2**31), Lt(2**31)]
|
||||
y_pos: Annotated[int, Ge(-2**31), Lt(2**31)]
|
||||
z_pos: Annotated[int, Ge(-2**15), Lt(2**15)]
|
||||
|
||||
|
||||
class ConfigUplinkMessage(discriminator_value(msg_type=MeshMessageType.CONFIG_UPLINK), BaseModel):
|
||||
""" set/respond uplink config """
|
||||
enabled: bool
|
||||
ssid: Annotated[str, MaxLen(32)]
|
||||
password: Annotated[str, MaxLen(64)]
|
||||
channel: Annotated[PositiveInt, Le(15)]
|
||||
udp: bool
|
||||
ssl: bool
|
||||
host: Annotated[str, MaxLen(64)]
|
||||
    port: Annotated[PositiveInt, Lt(2**16)]


@unique
class OTADeviceStatus(CEnum):
    """ ota status, the ones >= 0x10 denote a permanent failure """
    NONE = "NONE", 0x00

    STARTED = "STARTED", 0x01
    APPLIED = "APPLIED", 0x02

    START_FAILED = "START_FAILED", 0x10
    WRITE_FAILED = "WRITE_FAILED", 0x12
    APPLY_FAILED = "APPLY_FAILED", 0x13
    ROLLED_BACK = "ROLLED_BACK", 0x14

    @property
    def pretty_name(self):
        return self.name.replace('_', ' ').lower()

    @property
    def is_failed(self):
        return self >= self.START_FAILED


class OTAStatusMessage(discriminator_value(msg_type=MeshMessageType.OTA_STATUS), BaseModel):
    """ report OTA status """
    update_id: Annotated[NonNegativeInt, Lt(2**32)]
    received_bytes: Annotated[NonNegativeInt, Lt(2**32)]
    next_expected_chunk: Annotated[NonNegativeInt, Lt(2**16)]
    auto_apply: bool
    auto_reboot: bool
    status: OTADeviceStatus


class OTARequestStatusMessage(discriminator_value(msg_type=MeshMessageType.OTA_REQUEST_STATUS), BaseModel):
    """ request OTA status """
    pass


class OTAStartMessage(discriminator_value(msg_type=MeshMessageType.OTA_START), BaseModel):
    """ instruct node to start OTA """
    update_id: Annotated[PositiveInt, Lt(2**32)]
    total_bytes: Annotated[PositiveInt, Lt(2**32)]
    auto_apply: bool
    auto_reboot: bool


class OTAURLMessage(discriminator_value(msg_type=MeshMessageType.OTA_URL), BaseModel):
    """ supply download URL for OTA update and who to distribute it to """
    update_id: Annotated[PositiveInt, Lt(2**32)]
    distribute_to: MacAddress
    url: Annotated[str, MaxLen(255), VarLen()]


class OTAFragmentMessage(discriminator_value(msg_type=MeshMessageType.OTA_FRAGMENT), BaseModel):
    """ supply OTA fragment """
    update_id: Annotated[PositiveInt, Lt(2**32)]
    chunk: Annotated[PositiveInt, Lt(2**16)]
    data: Annotated[bytes, MaxLen(OTA_CHUNK_SIZE), VarLen()]


class OTARequestFragmentsMessage(discriminator_value(msg_type=MeshMessageType.OTA_REQUEST_FRAGMENTS), BaseModel):
    """ request missing fragments """
    update_id: Annotated[PositiveInt, Lt(2**32)]
    chunks: Annotated[list[Annotated[PositiveInt, Lt(2**16)]], MaxLen(128), VarLen()]


class OTASettingMessage(discriminator_value(msg_type=MeshMessageType.OTA_SETTING), BaseModel):
    """ configure whether to automatically apply and reboot when update is completed """
    update_id: Annotated[PositiveInt, Lt(2**32)]
    auto_apply: bool
    auto_reboot: bool


class OTAApplyMessage(discriminator_value(msg_type=MeshMessageType.OTA_APPLY), BaseModel):
    """ apply OTA and optionally reboot """
    update_id: Annotated[PositiveInt, Lt(2**32)]
    reboot: bool


class OTAAbortMessage(discriminator_value(msg_type=MeshMessageType.OTA_ABORT), BaseModel):
    """ announcing OTA abort """
    update_id: Annotated[NonNegativeInt, Lt(2**32)]


class LocateRequestRangeMessage(discriminator_value(msg_type=MeshMessageType.LOCATE_REQUEST_RANGE), BaseModel):
    """ request to report distance to all nearby nodes """
    pass


class LocateRangeResults(discriminator_value(msg_type=MeshMessageType.LOCATE_RANGE_RESULTS), BaseModel):
    """ reports distance to given nodes """
    ranges: Annotated[list[RangeResultItem], MaxLen(16), VarLen()]


class LocateRawFTMResults(discriminator_value(msg_type=MeshMessageType.LOCATE_RAW_FTM_RESULTS), BaseModel):
    """ reports distance to given nodes """
    peer: MacAddress
    results: Annotated[list[RawFTMEntry], MaxLen(16), VarLen()]


class Reboot(discriminator_value(msg_type=MeshMessageType.REBOOT), BaseModel):
    """ reboot the device """
    pass


class ReportError(discriminator_value(msg_type=MeshMessageType.REPORT_ERROR), BaseModel):
    """ report a critical error to upstream """
    message: Annotated[str, MaxLen(255), VarLen()]


MeshMessageContent = Annotated[
    Union[
        NoopMessage,
        EchoRequestMessage,
        EchoResponseMessage,
        MeshSigninMessage,
        MeshLayerAnnounceMessage,
        MeshAddDestinationsMessage,
        MeshRemoveDestinationsMessage,
        MeshRouteRequestMessage,
        MeshRouteResponseMessage,
        MeshRouteTraceMessage,
        MeshRoutingFailedMessage,
        ConfigDumpMessage,
        ConfigHardwareMessage,
        ConfigBoardMessage,
        ConfigFirmwareMessage,
        ConfigPositionMessage,
        ConfigUplinkMessage,
        OTAStatusMessage,
        OTARequestStatusMessage,
        OTAStartMessage,
        OTAURLMessage,
        OTAFragmentMessage,
        OTARequestFragmentsMessage,
        OTASettingMessage,
        OTAApplyMessage,
        OTAAbortMessage,
        LocateRequestRangeMessage,
        LocateRangeResults,
        LocateRawFTMResults,
        Reboot,
        ReportError,
    ],
    Discriminator("msg_type")
]
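
For orientation, here is a minimal self-contained sketch of the tagged-union pattern that MeshMessageContent uses; the class names and string tags below are illustrative only and deliberately avoid the project's CEnum/VarLen helpers:

# standalone sketch, assuming plain string tags instead of the real MeshMessageType enum
from typing import Annotated, Literal, Union
from pydantic import BaseModel, Field

class Noop(BaseModel):
    msg_type: Literal["NOOP"] = "NOOP"

class EchoRequest(BaseModel):
    msg_type: Literal["ECHO_REQUEST"] = "ECHO_REQUEST"
    content: str = ""

Content = Annotated[Union[Noop, EchoRequest], Field(discriminator="msg_type")]

class Envelope(BaseModel):
    dst: str  # stand-in for MacAddress
    src: str
    content: Content

msg = Envelope.model_validate({
    "dst": "ff:ff:ff:ff:ff:ff",
    "src": "d4:f9:8d:00:00:01",
    "content": {"msg_type": "ECHO_REQUEST", "content": "hi"},
})
assert isinstance(msg.content, EchoRequest)
assert Envelope.model_validate(msg.model_dump()) == msg  # lossless round trip

Validation dispatches on the "msg_type" tag, so adding a new message class only requires giving it a unique tag and listing it in the union.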


class MeshMessage(BaseModel):
    dst: MacAddress
    src: MacAddress
    content: MeshMessageContent

    async def send(self, sender=None, exclude_uplink_address=None) -> bool:
        data = {
            "type": "mesh.send",
            "sender": sender,
            "exclude_uplink_address": exclude_uplink_address,
            "msg": MeshMessage.tojson(self),
            "msg": self.model_dump(),
        }

        if self.dst in (MESH_CHILDREN_ADDRESS, MESH_BROADCAST_ADDRESS):

@ -110,283 +345,4 @@ class MeshMessage(StructType, union_type_field="msg_type"):
                return False
            if uplink.node_id == exclude_uplink_address:
                return False
            await channels.layers.get_channel_layer().send(uplink.name, data)

    @classmethod
    def get_ignore_c_fields(self):
        return set()

    @classmethod
    def get_additional_c_fields(self):
        return ()

    @classmethod
    def get_variable_name(cls, base_name):
        return cls.c_struct_name or base_name

    @classmethod
    def get_c_enum_name(cls):
        return normalize_name(cls.__name__.removeprefix('Mesh').removesuffix('Message')).upper()


@dataclass
class NoopMessage(MeshMessage, msg_type=MeshMessageType.NOOP):
    """ noop """
    pass


@dataclass
class EchoRequestMessage(MeshMessage, msg_type=MeshMessageType.ECHO_REQUEST):
    """ repeat back string """
    content: str = field(default='', metadata={'format': VarStrFormat(max_len=255)})


@dataclass
class EchoResponseMessage(MeshMessage, msg_type=MeshMessageType.ECHO_RESPONSE):
    """ repeat back string """
    content: str = field(default='', metadata={'format': VarStrFormat(max_len=255)})


@dataclass
class MeshSigninMessage(MeshMessage, msg_type=MeshMessageType.MESH_SIGNIN):
    """ node says hello to upstream node """
    pass


@dataclass
class MeshLayerAnnounceMessage(MeshMessage, msg_type=MeshMessageType.MESH_LAYER_ANNOUNCE):
    """ upstream node announces layer number """
    layer: int = field(metadata={
        "format": SimpleFormat('B'),
        "doc": "mesh layer that the sending node is on",
    })


@dataclass
class MeshAddDestinationsMessage(MeshMessage, msg_type=MeshMessageType.MESH_ADD_DESTINATIONS):
    """ downstream node announces served destination """
    addresses: list[str] = field(default_factory=list, metadata={
        "format": MacAddressesListFormat(max_num=16),
        "doc": "adresses of the added destinations",
    })


@dataclass
class MeshRemoveDestinationsMessage(MeshMessage, msg_type=MeshMessageType.MESH_REMOVE_DESTINATIONS):
    """ downstream node announces no longer served destination """
    addresses: list[str] = field(default_factory=list, metadata={
        "format": MacAddressesListFormat(max_num=16),
        "doc": "adresses of the removed destinations",
    })


@dataclass
class MeshRouteRequestMessage(MeshMessage, msg_type=MeshMessageType.MESH_ROUTE_REQUEST):
    """ request routing information for node """
    request_id: int = field(metadata={"format": SimpleFormat('I')})
    address: str = field(metadata={
        "format": MacAddressFormat(),
        "doc": "target address for the route"
    })


@dataclass
class MeshRouteResponseMessage(MeshMessage, msg_type=MeshMessageType.MESH_ROUTE_RESPONSE):
    """ reporting the routing table entry to the given address """
    request_id: int = field(metadata={"format": SimpleFormat('I')})
    route: str = field(metadata={
        "format": MacAddressFormat(),
        "doc": "routing table entry or 00:00:00:00:00:00"
    })


@dataclass
class MeshRouteTraceMessage(MeshMessage, msg_type=MeshMessageType.MESH_ROUTE_TRACE):
    """ special message, collects all hop adresses on its way """
    request_id: int = field(metadata={"format": SimpleFormat('I')})
    trace: list[str] = field(default_factory=list, metadata={
        "format": MacAddressesListFormat(max_num=16),
        "doc": "addresses encountered by this message",
    })


@dataclass
class MeshRoutingFailedMessage(MeshMessage, msg_type=MeshMessageType.MESH_ROUTING_FAILED):
    """ TODO description"""
    address: str = field(metadata={"format": MacAddressFormat()})


@dataclass
class ConfigDumpMessage(MeshMessage, msg_type=MeshMessageType.CONFIG_DUMP):
    """ request for the node to dump its config """
    pass


@dataclass
class ConfigHardwareMessage(MeshMessage, msg_type=MeshMessageType.CONFIG_HARDWARE):
    """ respond hardware/chip info """
    chip: ChipType = field(metadata={
        "format": EnumFormat("H", c_definition=False),
        "c_name": "chip_id",
    })
    revision_major: int = field(metadata={"format": SimpleFormat('B')})
    revision_minor: int = field(metadata={"format": SimpleFormat('B')})

    def get_chip_display(self):
        return ChipType(self.chip).name.replace('_', '-')


@dataclass
class ConfigBoardMessage(MeshMessage, msg_type=MeshMessageType.CONFIG_BOARD):
    """ set/respond board config """
    board_config: BoardConfig = field(metadata={"c_embed": True, "json_embed": True})


@dataclass
class ConfigFirmwareMessage(MeshMessage, msg_type=MeshMessageType.CONFIG_FIRMWARE):
    """ respond firmware info """
    app_desc: FirmwareAppDescription = field(metadata={'json_embed': True})


@dataclass
class ConfigPositionMessage(MeshMessage, msg_type=MeshMessageType.CONFIG_POSITION):
    """ set/respond position config """
    x_pos: int = field(metadata={"format": SimpleFormat('i')})
    y_pos: int = field(metadata={"format": SimpleFormat('i')})
    z_pos: int = field(metadata={"format": SimpleFormat('h')})


@dataclass
class ConfigUplinkMessage(MeshMessage, msg_type=MeshMessageType.CONFIG_UPLINK):
    """ set/respond uplink config """
    enabled: bool = field(metadata={"format": BoolFormat()})
    ssid: str = field(metadata={"format": FixedStrFormat(32)})
    password: str = field(metadata={"format": FixedStrFormat(64)})
    channel: int = field(metadata={"format": SimpleFormat('B')})
    udp: bool = field(metadata={"format": BoolFormat()})
    ssl: bool = field(metadata={"format": BoolFormat()})
    host: str = field(metadata={"format": FixedStrFormat(64)})
    port: int = field(metadata={"format": SimpleFormat('H')})


@unique
class OTADeviceStatus(EnumSchemaByNameMixin, IntEnum):
    """ ota status, the ones >= 0x10 denote a permanent failure """
    NONE = 0x00

    STARTED = 0x01
    APPLIED = 0x02

    START_FAILED = 0x10
    WRITE_FAILED = 0x12
    APPLY_FAILED = 0x13
    ROLLED_BACK = 0x14

    @property
    def pretty_name(self):
        return self.name.replace('_', ' ').lower()

    @property
    def is_failed(self):
        return self >= self.START_FAILED


@dataclass
class OTAStatusMessage(MeshMessage, msg_type=MeshMessageType.OTA_STATUS):
    """ report OTA status """
    update_id: int = field(metadata={"format": SimpleFormat('I')})
    received_bytes: int = field(metadata={"format": SimpleFormat('I')})
    next_expected_chunk: int = field(metadata={"format": SimpleFormat('H')})
    auto_apply: bool = field(metadata={"format": BoolFormat()})
    auto_reboot: bool = field(metadata={"format": BoolFormat()})
    status: OTADeviceStatus = field(metadata={"format": EnumFormat('B')})


@dataclass
class OTARequestStatusMessage(MeshMessage, msg_type=MeshMessageType.OTA_REQUEST_STATUS):
    """ request OTA status """
    pass


@dataclass
class OTAStartMessage(MeshMessage, msg_type=MeshMessageType.OTA_START):
    """ instruct node to start OTA """
    update_id: int = field(metadata={"format": SimpleFormat('I')})
    total_bytes: int = field(metadata={"format": SimpleFormat('I')})
    auto_apply: bool = field(metadata={"format": BoolFormat()})
    auto_reboot: bool = field(metadata={"format": BoolFormat()})


@dataclass
class OTAURLMessage(MeshMessage, msg_type=MeshMessageType.OTA_URL):
    """ supply download URL for OTA update and who to distribute it to """
    update_id: int = field(metadata={"format": SimpleFormat('I')})
    distribute_to: str = field(metadata={"format": MacAddressFormat()})
    url: str = field(metadata={"format": VarStrFormat(max_len=255)})


@dataclass
class OTAFragmentMessage(MeshMessage, msg_type=MeshMessageType.OTA_FRAGMENT):
    """ supply OTA fragment """
    update_id: int = field(metadata={"format": SimpleFormat('I')})
    chunk: int = field(metadata={"format": SimpleFormat('H')})
    data: bytes = field(metadata={"format": VarBytesFormat(max_size=OTA_CHUNK_SIZE)})


@dataclass
class OTARequestFragmentsMessage(MeshMessage, msg_type=MeshMessageType.OTA_REQUEST_FRAGMENTS):
    """ request missing fragments """
    update_id: int = field(metadata={"format": SimpleFormat('I')})
    chunks: list[int] = field(metadata={"format": VarArrayFormat(SimpleFormat('H'), max_num=128)})


@dataclass
class OTASettingMessage(MeshMessage, msg_type=MeshMessageType.OTA_SETTING):
    """ configure whether to automatically apply and reboot when update is completed """
    update_id: int = field(metadata={"format": SimpleFormat('I')})
    auto_apply: bool = field(metadata={"format": BoolFormat()})
    auto_reboot: bool = field(metadata={"format": BoolFormat()})


@dataclass
class OTAApplyMessage(MeshMessage, msg_type=MeshMessageType.OTA_APPLY):
    """ apply OTA and optionally reboot """
    update_id: int = field(metadata={"format": SimpleFormat('I')})
    reboot: bool = field(metadata={"format": BoolFormat()})


@dataclass
class OTAAbortMessage(MeshMessage, msg_type=MeshMessageType.OTA_ABORT):
    """ announcing OTA abort """
    update_id: int = field(metadata={"format": SimpleFormat('I')})


@dataclass
class LocateRequestRangeMessage(MeshMessage, msg_type=MeshMessageType.LOCATE_REQUEST_RANGE):
    """ request to report distance to all nearby nodes """
    pass


@dataclass
class LocateRangeResults(MeshMessage, msg_type=MeshMessageType.LOCATE_RANGE_RESULTS):
    """ reports distance to given nodes """
    ranges: list[RangeResultItem] = field(metadata={"format": VarArrayFormat(RangeResultItem, max_num=16)})


@dataclass
class LocateRawFTMResults(MeshMessage, msg_type=MeshMessageType.LOCATE_RAW_FTM_RESULTS):
    """ reports distance to given nodes """
    peer: str = field(metadata={"format": MacAddressFormat()})
    results: list[RawFTMEntry] = field(metadata={"format": VarArrayFormat(RawFTMEntry, max_num=16)})


@dataclass
class Reboot(MeshMessage, msg_type=MeshMessageType.REBOOT):
    """ reboot the device """
    pass


@dataclass
class ReportError(MeshMessage, msg_type=MeshMessageType.REPORT_ERROR):
    """ report a critical error to upstream """
    message: str = field(metadata={"format": VarStrFormat(max_len=255)})
        await channels.layers.get_channel_layer().send(uplink.name, data)

@ -18,7 +18,7 @@ from django.utils.text import slugify
from django.utils.translation import gettext_lazy as _

from c3nav.mapdata.models.geometry.space import RangingBeacon
from c3nav.mesh.dataformats import BoardType, ChipType, FirmwareImage
from c3nav.mesh.schemas import BoardType, ChipType, FirmwareImage
from c3nav.mesh.messages import ConfigFirmwareMessage, ConfigHardwareMessage
from c3nav.mesh.messages import MeshMessage as MeshMessage
from c3nav.mesh.messages import MeshMessageType

@ -383,7 +383,7 @@ class NodeMessage(models.Model):

    @cached_property
    def parsed(self) -> Self:
        return MeshMessage.fromjson(self.data)
        return MeshMessage.model_validate(self.data)


class FirmwareVersion(models.Model):

284
src/c3nav/mesh/schemas.py
Normal file
@ -0,0 +1,284 @@
import re
from dataclasses import dataclass, field
from enum import unique
from typing import Annotated, BinaryIO, ClassVar, Literal, Self, Union

from annotated_types import Gt, Le, Lt, MaxLen, Ge
from pydantic import NegativeInt, PositiveInt
from pydantic.main import BaseModel
from pydantic.types import Discriminator, NonNegativeInt
from pydantic_extra_types.mac_address import MacAddress

from c3nav.mesh.cformats import AsDefinition, AsHex, CName, ExistingCStruct, discriminator_value, \
    CEnum, TwoNibblesEncodable


@unique
class LedType(CEnum):
    NONE = "NONE", 0
    SERIAL = "SERIAL", 1
    MULTIPIN = "MULTIPIN", 2

    @property
    def pretty_name(self):
        return self.name.lower()


@unique
class SerialLedType(CEnum):
    WS2812 = "WS2812", 1
    SK6812 = "SK6812", 2


class NoLedConfig(discriminator_value(led_type=LedType.NONE), BaseModel):
    pass


class SerialLedConfig(discriminator_value(led_type=LedType.SERIAL), BaseModel):
    serial_led_type: Annotated[SerialLedType, CName("type")]
    gpio: Annotated[PositiveInt, Lt(2**8)]


class MultipinLedConfig(discriminator_value(led_type=LedType.MULTIPIN), BaseModel):
    gpio_red: Annotated[PositiveInt, Lt(2**8)]
    gpio_green: Annotated[PositiveInt, Lt(2**8)]
    gpio_blue: Annotated[PositiveInt, Lt(2**8)]


LedConfig = Annotated[
    Union[
        NoLedConfig,
        SerialLedConfig,
        MultipinLedConfig,
    ],
    Discriminator("led_type")
]


class BoardSPIConfig(BaseModel):
    """
    configuration for spi bus used for ETH or UWB
    """
    gpio_miso: Annotated[PositiveInt, Lt(2**8)]
    gpio_mosi: Annotated[PositiveInt, Lt(2**8)]
    gpio_clk: Annotated[PositiveInt, Lt(2**8)]


class UWBConfig(BaseModel):
    """
    configuration for the connection to the UWB module
    """
    enable: bool
    gpio_cs: Annotated[PositiveInt, Lt(2**8)]
    gpio_irq: Annotated[PositiveInt, Lt(2**8)]
    gpio_rst: Annotated[PositiveInt, Lt(2**8)]
    gpio_wakeup: Annotated[PositiveInt, Lt(2**8)]
    gpio_exton: Annotated[PositiveInt, Lt(2**8)]


class UplinkEthConfig(BaseModel):
    """
    configuration for the connection to the ETH module
    """
    enable: bool
    gpio_cs: Annotated[PositiveInt, Lt(2**8)]
    gpio_int: Annotated[PositiveInt, Lt(2**8)]
    gpio_rst: Annotated[int, Ge(-1), Lt(2**7)]


@unique
class BoardType(CEnum):
    CUSTOM = "CUSTOM", 0x00

    # devboards
    ESP32_C3_DEVKIT_M_1 = "ESP32_C3_DEVKIT_M_1", 0x01
    ESP32_C3_32S = "ESP32_C3_32S", 0x02

    # custom boards
    C3NAV_UWB_BOARD = "C3NAV_UWB_BOARD", 0x10
    C3NAV_LOCATION_PCB_REV_0_1 = "C3NAV_LOCATION_PCB_REV_0_1", 0x11
    C3NAV_LOCATION_PCB_REV_0_2 = "C3NAV_LOCATION_PCB_REV_0_2", 0x12

    @property
    def pretty_name(self):
        if self.name.startswith('ESP32'):
            return self.name.replace('_', '-').replace('DEVKIT-', 'DevKit')
        if self.name.startswith('C3NAV'):
            name = self.name.replace('_', ' ').lower()
            name = name.replace('uwb', 'UWB').replace('pcb', 'PCB')
            name = re.sub(r'[0-9]+( [0-9+])+', lambda s: s[0].replace(' ', '.'), name)
            name = re.sub(r'rev.*', lambda s: s[0].replace(' ', ''), name)
            return name
        return self.name


class CustomBoardConfig(discriminator_value(board=BoardType.CUSTOM), BaseModel):
    spi: Annotated[BoardSPIConfig, AsDefinition()]
    uwb: Annotated[UWBConfig, AsDefinition()]
    eth: Annotated[UplinkEthConfig, AsDefinition()]
    led: Annotated[LedConfig, AsDefinition()]


class DevkitMBoardConfig(discriminator_value(board=BoardType.ESP32_C3_DEVKIT_M_1), BaseModel):
    spi: Annotated[BoardSPIConfig, AsDefinition()]
    uwb: Annotated[UWBConfig, AsDefinition()]
    eth: Annotated[UplinkEthConfig, AsDefinition()]


class Esp32SBoardConfig(discriminator_value(board=BoardType.ESP32_C3_32S), BaseModel):
    spi: Annotated[BoardSPIConfig, AsDefinition()]
    uwb: Annotated[UWBConfig, AsDefinition()]
    eth: Annotated[UplinkEthConfig, AsDefinition()]


class UwbBoardConfig(discriminator_value(board=BoardType.C3NAV_UWB_BOARD), BaseModel):
    eth: Annotated[UplinkEthConfig, AsDefinition()]


class LocationPCBRev0Dot1BoardConfig(discriminator_value(board=BoardType.C3NAV_LOCATION_PCB_REV_0_1), BaseModel):
    eth: Annotated[UplinkEthConfig, AsDefinition()]


class LocationPCBRev0Dot2BoardConfig(discriminator_value(board=BoardType.C3NAV_LOCATION_PCB_REV_0_2), BaseModel):
    eth: Annotated[UplinkEthConfig, AsDefinition()]


BoardConfig = Annotated[
    Union[
        CustomBoardConfig,
        DevkitMBoardConfig,
        Esp32SBoardConfig,
        UwbBoardConfig,
        LocationPCBRev0Dot1BoardConfig,
        LocationPCBRev0Dot2BoardConfig,
    ],
    Discriminator("board"),
    AsHex(),
]
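
A standalone sketch of how a discriminated union like this can be validated on its own via a TypeAdapter; the class names and tag strings below are illustrative and sidestep the project's CEnum/AsDefinition/AsHex helpers:

from typing import Annotated, Literal, Union
from pydantic import BaseModel, Field, TypeAdapter

class CustomBoard(BaseModel):
    board: Literal["CUSTOM"] = "CUSTOM"
    led_gpio: int

class UwbBoard(BaseModel):
    board: Literal["C3NAV_UWB_BOARD"] = "C3NAV_UWB_BOARD"
    eth_cs_gpio: int

Board = Annotated[Union[CustomBoard, UwbBoard], Field(discriminator="board")]

cfg = TypeAdapter(Board).validate_python({"board": "CUSTOM", "led_gpio": 8})
assert isinstance(cfg, CustomBoard)  # the "board" tag picks the right model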


class RangeResultItem(BaseModel):
    peer: MacAddress
    rssi: Annotated[NegativeInt, Gt(-100)]
    distance: Annotated[int, Gt(-32000), Lt(32000)]


class RawFTMEntry(BaseModel):
    dlog_token: Annotated[PositiveInt, Lt(255)]
    rssi: Annotated[NegativeInt, Gt(-100)]
    rtt: Annotated[PositiveInt, Lt(2**32)]
    t1: Annotated[PositiveInt, Lt(2**64)]
    t2: Annotated[PositiveInt, Lt(2**64)]
    t3: Annotated[PositiveInt, Lt(2**64)]
    t4: Annotated[PositiveInt, Lt(2**64)]


class FirmwareAppDescription(BaseModel):
    existing_c_struct: ClassVar = ExistingCStruct(name="esp_app_desc_t", includes=['<esp_app_desc.h>'])

    magic_word: Literal[0xAB_CD_54_32] = field(repr=False)
    secure_version: Annotated[NonNegativeInt, Lt(2**32)]
    reserv1: Annotated[str, MaxLen(8*2), AsHex()] = field(repr=False)
    version: Annotated[str, MaxLen(32)]
    project_name: Annotated[str, MaxLen(32)]
    compile_time: Annotated[str, MaxLen(16)]
    compile_date: Annotated[str, MaxLen(16)]
    idf_version: Annotated[str, MaxLen(32)]
    app_elf_sha256: Annotated[str, MaxLen(64), AsHex()]
    reserv2: Annotated[str, MaxLen(20*4*2), AsHex()] = field(repr=False)


@unique
class SPIFlashMode(CEnum):
    QIO = "QID", 0
    QOUT = "QOUT", 1
    DIO = "DIO", 2
    DOUT = "DOUT", 3


@unique
class FlashSize(CEnum):
    SIZE_1MB = "SIZE_1MB", 0
    SIZE_2MB = "SIZE_2MB", 1
    SIZE_4MB = "SIZE_4MB", 2
    SIZE_8MB = "SIZE_8MB", 3
    SIZE_16MB = "SIZE_16MB", 4
    SIZE_32MB = "SIZE_32MB", 5
    SIZE_64MB = "SIZE_64MB", 6
    SIZE_128MB = "SIZE_128MB", 7

    @property
    def pretty_name(self):
        return self.name.removeprefix('SIZE_')


@unique
class FlashFrequency(CEnum):
    FREQ_40MHZ = "FREQ_40MHZ", 0
    FREQ_26MHZ = "FREQ_26MHZ", 1
    FREQ_20MHZ = "FREQ_20MHZ", 2
    FREQ_80MHZ = "FREQ_80MHZ", 0xf

    @property
    def pretty_name(self):
        return self.name.removeprefix('FREQ_').replace('MHZ', 'Mhz')


@dataclass
class FlashSettings(TwoNibblesEncodable):
    size: FlashSize
    frequency: FlashFrequency

    @property
    def display(self):
        return f"{self.size.pretty_name} ({self.frequency.pretty_name})"


@unique
class ChipType(CEnum):
    ESP32_S2 = "ESP32_S2", 2
    ESP32_C3 = "ESP32_C3", 5

    @property
    def pretty_name(self):
        return self.name.replace('_', '-')


class FirmwareImageFileHeader(BaseModel):
    magic_word: Literal[0xE9] = field(repr=False)
    num_segments: Annotated[PositiveInt, Lt(2**8)]
    spi_flash_mode: SPIFlashMode
    flash_stuff: FlashSettings
    entry_point: Annotated[PositiveInt, Lt(2**32)]


class FirmwareImageFileHeader(BaseModel):
    major: int
    minor: int
    num_segments: Annotated[PositiveInt, Lt(2**8)]
    spi_flash_mode: SPIFlashMode
    flash_stuff: FlashSettings
    entry_point: Annotated[PositiveInt, Lt(2**32)]


class FirmwareImageExtendedFileHeader(BaseModel):
    wp_pin: Annotated[PositiveInt, Lt(2**8)]
    drive_settings: Annotated[bytes, MaxLen(3)]
    chip: Annotated[ChipType, Lt(2**16)]
    min_chip_rev_old: int
    min_chip_rev: Annotated[PositiveInt, Le(9999)]
    max_chip_rev: Annotated[PositiveInt, Le(9999)]
    reserv: Annotated[bytes, MaxLen(4)] = field(repr=False)
    hash_appended: bool


class FirmwareImage(BaseModel):
    header: FirmwareImageFileHeader
    ext_header: FirmwareImageExtendedFileHeader
    first_segment_headers: Annotated[bytes, MaxLen(2)] = field(repr=False)
    app_desc: FirmwareAppDescription

    @classmethod
    def from_file(cls, file: BinaryIO) -> Self:
        result, data = cls.decode(file.read(FirmwareImage.get_min_size()))
        return result

@ -12,7 +12,7 @@ UPLINK_TIMEOUT = UPLINK_PING+5


def indent_c(code):
    return " "+code.replace("\n", "\n ")
    return " "+code.replace("\n", "\n ").replace("\n \n", "\n\n")


def get_node_names():


@ -82,7 +82,7 @@ def locate_test(request):
        None
    )
    return {
        "ranges": msg.parsed.tojson(msg.parsed)["ranges"],
        "ranges": msg.parsed.model_dump()["ranges"],
        "datetime": msg.datetime,
        "location": location.serialize(simple_geometry=True) if location else None
    }


@ -3,6 +3,7 @@ django-bootstrap3==23.6
django-compressor==4.4
csscompressor==0.9.5
django-ninja==1.1.0
pydantic-extra-types==2.5.0
django-filter==23.5
django-environ==0.11.2
shapely==2.0.3
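
The new pydantic-extra-types dependency provides the MacAddress field type used by the mesh schemas above; a minimal standalone usage sketch (the Node model here is hypothetical):

from pydantic import BaseModel, ValidationError
from pydantic_extra_types.mac_address import MacAddress

class Node(BaseModel):
    address: MacAddress

Node(address="d4:f9:8d:00:00:01")   # well-formed addresses are accepted
try:
    Node(address="not-a-mac")
except ValidationError:
    pass                            # malformed addresses are rejected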