From da5ff59c96adb64e239158c2f28d61586edca207 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laura=20Kl=C3=BCnder?= Date: Fri, 6 Oct 2023 01:06:30 +0200 Subject: [PATCH] new parsing works --- .../templates/control/mesh_messages.html | 4 +- src/c3nav/control/views/mesh.py | 17 ++- src/c3nav/mesh/consumers.py | 6 +- src/c3nav/mesh/dataformats.py | 100 ++++++++++++------ src/c3nav/mesh/forms.py | 3 +- .../mesh/management/commands/mesh_msg_c.py | 6 +- src/c3nav/mesh/messages.py | 39 +------ src/c3nav/mesh/models.py | 5 +- 8 files changed, 91 insertions(+), 89 deletions(-) diff --git a/src/c3nav/control/templates/control/mesh_messages.html b/src/c3nav/control/templates/control/mesh_messages.html index d7abf8eb..4a70ad2c 100644 --- a/src/c3nav/control/templates/control/mesh_messages.html +++ b/src/c3nav/control/templates/control/mesh_messages.html @@ -59,9 +59,9 @@ X={{ msg.parsed.x_pos }}, Y={{ msg.parsed.y_pos }}, Z={{ msg.parsed.z_pos }} {% elif msg.get_message_type_display == "MESH_ADD_DESTINATIONS" or msg.get_message_type_display == "MESH_REMOVE_DESTINATIONS" %} - mac adresses:
+ adresses:
diff --git a/src/c3nav/control/views/mesh.py b/src/c3nav/control/views/mesh.py index 7ad13b32..ed469207 100644 --- a/src/c3nav/control/views/mesh.py +++ b/src/c3nav/control/views/mesh.py @@ -1,3 +1,4 @@ +from functools import cached_property from uuid import uuid4 from django.contrib import messages @@ -12,7 +13,7 @@ from django.views.generic import ListView, DetailView, FormView, UpdateView, Tem from c3nav.control.forms import MeshMessageFilterForm from c3nav.control.views.base import ControlPanelMixin from c3nav.mesh.forms import MeshMessageForm, MeshNodeForm -from c3nav.mesh.messages import MeshMessageType +from c3nav.mesh.messages import MeshMessageType, MeshMessage from c3nav.mesh.models import MeshNode, NodeMessage from c3nav.mesh.utils import get_node_names @@ -99,9 +100,13 @@ class MeshMessageListView(ControlPanelMixin, ListView): class MeshMessageSendView(ControlPanelMixin, FormView): template_name = "control/mesh_message_send.html" + @cached_property + def msg_type(self): + return MeshMessageType[self.kwargs['msg_type']] + def get_form_class(self): try: - return MeshMessageForm.get_form_for_type(MeshMessageType[self.kwargs['msg_type']]) + return MeshMessageForm.get_form_for_type(self.msg_type) except KeyError: raise Http404('unknown message type') @@ -112,13 +117,15 @@ class MeshMessageSendView(ControlPanelMixin, FormView): } def get_initial(self): - if 'recipient' in self.kwargs and self.kwargs['msg_type'].startswith('CONFIG_'): + if 'recipient' in self.kwargs and self.msg_type.name.startswith('CONFIG_'): try: node = MeshNode.objects.get(address=self.kwargs['recipient']) except MeshNode.DoesNotExist: pass else: - return node.last_messages[self.kwargs['msg_type']].parsed.tojson() + return MeshMessage.get_type(self.msg_type).tojson( + node.last_messages[self.msg_type].parsed + ) return {} def get_success_url(self): @@ -155,7 +162,7 @@ class MeshMessageSendingView(ControlPanelMixin, TemplateView): "node_names": node_names, "send_uuid": uuid, **data, - "node_name": node_names[data["msg_data"]["address"]], + "node_name": node_names.get(data["msg_data"].get("address"), ""), "recipients": [(address, node_names[address]) for address in data["recipients"]], "msg_type": MeshMessageType(data["msg_data"]["msg_id"]).name, } diff --git a/src/c3nav/mesh/consumers.py b/src/c3nav/mesh/consumers.py index ce6e0a90..f14e8fd7 100644 --- a/src/c3nav/mesh/consumers.py +++ b/src/c3nav/mesh/consumers.py @@ -33,7 +33,7 @@ class MeshConsumer(WebsocketConsumer): def send_msg(self, msg, sender=None): # print("sending", msg) # self.log_text(msg.dst, "sending %s" % msg) - self.send(bytes_data=msg.encode()) + self.send(bytes_data=MeshMessage.encode(msg)) async_to_sync(self.channel_layer.group_send)("mesh_msg_sent", { "type": "mesh.msg_sent", "timestamp": timezone.now().strftime("%d.%m.%y %H:%M:%S.%f"), @@ -48,7 +48,7 @@ class MeshConsumer(WebsocketConsumer): if bytes_data is None: return try: - msg = messages.MeshMessage.decode(bytes_data) + msg, data = messages.MeshMessage.decode(bytes_data) except Exception: traceback.print_exc() return @@ -119,7 +119,7 @@ class MeshConsumer(WebsocketConsumer): self.send_msg(MeshMessage.fromjson(data["msg"]), data["sender"]) def log_received_message(self, src_node: MeshNode, msg: messages.MeshMessage): - as_json = msg.tojson() + as_json = MeshMessage.tojson(msg) async_to_sync(self.channel_layer.group_send)("mesh_msg_received", { "type": "mesh.msg_received", "timestamp": timezone.now().strftime("%d.%m.%y %H:%M:%S.%f"), diff --git a/src/c3nav/mesh/dataformats.py b/src/c3nav/mesh/dataformats.py index f497f15c..4cd07806 100644 --- a/src/c3nav/mesh/dataformats.py +++ b/src/c3nav/mesh/dataformats.py @@ -2,7 +2,7 @@ import re import struct from abc import ABC, abstractmethod from dataclasses import dataclass, field, fields -from enum import IntEnum, unique +from enum import IntEnum, unique, Enum from typing import Self, Sequence, Any from c3nav.mesh.utils import indent_c @@ -17,7 +17,7 @@ class BaseFormat(ABC): @classmethod @abstractmethod - def decode(cls, data) -> tuple[Any, bytes]: + def decode(cls, data: bytes) -> tuple[Any, bytes]: pass def fromjson(self, data): @@ -48,7 +48,9 @@ class SimpleFormat(BaseFormat): self.num = int(self.fmt[:-1]) if len(self.fmt) > 1 else 1 def encode(self, value): - return struct.pack(self.fmt, (value, ) if self.num == 1 else tuple(value)) + if self.num == 1: + return struct.pack(self.fmt, value) + return struct.pack(self.fmt, *value) def decode(self, data: bytes) -> tuple[Any, bytes]: value = struct.unpack(self.fmt, data[:self.size]) @@ -104,7 +106,7 @@ class FixedHexFormat(SimpleFormat): super().__init__('%dB' % self.num) def encode(self, value: str): - return super().encode(tuple(bytes.fromhex(value))) + return super().encode(tuple(bytes.fromhex(value.replace(':', '')))) def decode(self, data: bytes) -> tuple[str, bytes]: return self.sep.join(('%02x' % i) for i in data[:self.num]), data[self.num:] @@ -137,10 +139,12 @@ class VarArrayFormat(BaseVarFormat): def decode(self, data: bytes) -> tuple[list[Any], bytes]: num = struct.unpack(self.num_fmt, data[:self.num_size])[0] - return [ - self.child_type.decode(data[i:i+self.child_size]) - for i in range(self.num_size, self.num_size+num*self.child_size, self.child_size) - ], data[self.num_size+num*self.child_size:] + data = data[self.num_size:] + result = [] + for i in range(num): + item, data = self.child_type.decode(data) + result.append(item) + return result, data def get_c_parts(self): pre, post = self.child_type.get_c_parts() @@ -192,23 +196,34 @@ class StructType: super().__init_subclass__(**kwargs) @classmethod - def encode(cls, instance) -> bytes: + def get_types(cls): + if not cls.union_type_field: + raise TypeError('Not a union class') + return cls._union_options[cls.union_type_field] + + @classmethod + def get_type(cls, type_id) -> Self: + if not cls.union_type_field: + raise TypeError('Not a union class') + return cls.get_types()[type_id] + + @classmethod + def encode(cls, instance, ignore_fields=()) -> bytes: data = bytes() if cls.union_type_field and type(instance) is not cls: if not isinstance(instance, cls): raise ValueError('expected value of type %r, got %r' % (cls, instance)) - for field_ in fields(instance): - if field_.name is cls.union_type_field: - data += field_.metadata["format"].encode(getattr(instance, field_.name)) - break - else: - raise TypeError('couldn\'t find %s value' % cls.union_type_field) + for field_ in fields(cls): + data += field_.metadata["format"].encode(getattr(instance, field_.name)) - data += instance.encode(instance) + # todo: better + data += instance.encode(instance, ignore_fields=set(f.name for f in fields(cls))) return data for field_ in fields(cls): + if field_.name in ignore_fields: + continue value = getattr(instance, field_.name) if "format" in field_.metadata: data += field_.metadata["format"].encode(value) @@ -224,30 +239,35 @@ class StructType: @classmethod def decode(cls, data: bytes) -> Self: - values = {} + orig_data = data + kwargs = {} + no_init_data = {} for field_ in fields(cls): if "format" in field_.metadata: - data = field_.metadata["format"].decode(data) + value, data = field_.metadata["format"].decode(data) elif issubclass(field_.type, StructType): - data = field_.type.decode(data) + value, data = field_.type.decode(data) else: raise TypeError('field %s.%s has no format and is no StructType' % (cls.__name__, field_.name)) - values[field_.name] = field_.metadata["format"].decode(data) + if field_.init: + kwargs[field_.name] = value + else: + no_init_data[field_.name] = value if cls.union_type_field: try: - type_value = values[cls.union_type_field] + type_value = no_init_data[cls.union_type_field] except KeyError: raise TypeError('union_type_field %s.%s is missing' % (cls.__name__, cls.union_type_field)) try: - klass = cls._union_options[type_value] + klass = cls.get_type(type_value) except KeyError: raise TypeError('union_type_field %s.%s value %r no known' % (cls.__name__, cls.union_type_field, type_value)) - return klass.decode(data) - return cls(**values) + return klass.decode(orig_data) + return cls(**kwargs), data @classmethod def tojson(cls, instance) -> dict: @@ -259,7 +279,7 @@ class StructType: for field_ in fields(instance): if field_.name is cls.union_type_field: - result[field_.name] = field_.metadata["format"].encode(getattr(instance, field_.name)) + result[field_.name] = field_.metadata["format"].tojson(getattr(instance, field_.name)) break else: raise TypeError('couldn\'t find %s value' % cls.union_type_field) @@ -282,32 +302,42 @@ class StructType: return result @classmethod - def fromjson(cls, data): + def upgrade_json(cls, data): + return data + + @classmethod + def fromjson(cls, data: dict): data = data.copy() # todo: upgrade_json + cls.upgrade_json(data) kwargs = {} + no_init_data = {} for field_ in fields(cls): + raw_value = data.get(field_.name, None) if "format" in field_.metadata: - data = field_.metadata["format"].decode(data) + value = field_.metadata["format"].fromjson(raw_value) elif issubclass(field_.type, StructType): - data = field_.type.decode(data) + value = field_.type.fromjson(raw_value) else: raise TypeError('field %s.%s has no format and is no StructType' % (cls.__name__, field_.name)) - kwargs[field_.name], data = field_.metadata["format"].decode(data) + if field_.init: + kwargs[field_.name] = value + else: + no_init_data[field_.name] = value if cls.union_type_field: try: - type_value = kwargs[cls.union_type_field] + type_value = no_init_data.pop(cls.union_type_field) except KeyError: raise TypeError('union_type_field %s.%s is missing' % (cls.__name__, cls.union_type_field)) try: - klass = cls._union_options[type_value] + klass = cls.get_type(type_value) except KeyError: - raise TypeError('union_type_field %s.%s value %r no known' % + raise TypeError('union_type_field %s.%s value 0x%02x no known' % (cls.__name__, cls.union_type_field, type_value)) return klass.fromjson(data) @@ -342,7 +372,7 @@ class StructType: if cls.union_type_field: parent_fields = set(field_.name for field_ in fields(cls)) union_items = [] - for key, option in cls._union_options[cls.union_type_field].items(): + for key, option in cls.get_types().items(): base_name = normalize_name(getattr(key, 'name', option.__name__)) if union_member_as_types: struct_name = cls.get_struct_name(base_name) @@ -360,7 +390,7 @@ class StructType: ) union_items.append( "uint8_t bytes[%s];" % max( - (option.get_min_size() for option in cls._union_options[cls.union_type_field].values()), + (option.get_min_size() for option in cls.get_types().values()), default=0, ) ) @@ -417,7 +447,7 @@ class StructType: if cls.union_type_field: return ( {f.name: field for f in fields()}[cls.union_type_field].metadata["format"].get_min_size() + - sum((option.get_min_size() for option in cls._union_options[cls.union_type_field].values()), start=0) + sum((option.get_min_size() for option in cls.get_types().values()), start=0) ) return sum((f.metadata.get("format", f.type).get_min_size() for f in fields(cls)), start=0) diff --git a/src/c3nav/mesh/forms.py b/src/c3nav/mesh/forms.py index 2a94ebf1..0f5e0281 100644 --- a/src/c3nav/mesh/forms.py +++ b/src/c3nav/mesh/forms.py @@ -54,6 +54,7 @@ class MeshMessageForm(forms.Form): if cls.msg_type in MeshMessageForm.msg_types: raise TypeError('duplicate use of msg %s' % cls.msg_type) MeshMessageForm.msg_types[cls.msg_type] = cls + cls.msg_type_class = MeshMessage.get_type(cls.msg_type) @classmethod def get_form_for_type(cls, msg_type): @@ -94,7 +95,7 @@ class MeshMessageForm(forms.Form): class MeshRouteRequestForm(MeshMessageForm): msg_type = MeshMessageType.MESH_ROUTE_REQUEST - address = forms.ChoiceField(choices=()) + address = forms.ChoiceField(choices=(), required=True) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/c3nav/mesh/management/commands/mesh_msg_c.py b/src/c3nav/mesh/management/commands/mesh_msg_c.py index ce06f89f..c8175edd 100644 --- a/src/c3nav/mesh/management/commands/mesh_msg_c.py +++ b/src/c3nav/mesh/management/commands/mesh_msg_c.py @@ -23,7 +23,7 @@ class Command(BaseCommand): struct_lines = {} ignore_names = set(field_.name for field_ in fields(MeshMessage)) - for msg_id, msg_type in MeshMessage.get_msg_types().items(): + for msg_id, msg_type in MeshMessage.get_types().items(): if msg_type.c_struct_name: if msg_type.c_struct_name in done_struct_names: continue @@ -53,10 +53,10 @@ class Command(BaseCommand): print("} mesh_msg_data_t;") print() - max_msg_type = max(MeshMessage.get_msg_types().keys()) + max_msg_type = max(MeshMessage.get_types().keys()) macro_data = [] for i in range(((max_msg_type//16)+1)*16): - msg_type = MeshMessage.get_msg_types().get(i, None) + msg_type = MeshMessage.get_types().get(i, None) if msg_type: name = (msg_type.c_struct_name or self.shorten_name(normalize_name( getattr(msg_type.msg_id, 'name', msg_type.__name__) diff --git a/src/c3nav/mesh/messages.py b/src/c3nav/mesh/messages.py index a01b8a1d..42e8cc7e 100644 --- a/src/c3nav/mesh/messages.py +++ b/src/c3nav/mesh/messages.py @@ -67,43 +67,11 @@ class MeshMessage(StructType, union_type_field="msg_id"): raise TypeError('duplicate use of c_struct_name %s' % c_struct_name) MeshMessage.c_structs[c_struct_name] = cls - def encode(self): - data = bytes() - for field_ in fields(self): - data += field_.metadata["format"].encode(getattr(self, field_.name)) - return data - - @classmethod - def decode(cls, data: bytes) -> M: - klass = cls.msg_types[data[12]] - values = {} - for field_ in fields(klass): - values[field_.name], data = field_.metadata["format"].decode(data) - values.pop('msg_id') - return klass(**values) - - def tojson(self): - return asdict(self) - - @classmethod - def fromjson(cls, data) -> M: - kwargs = data.copy() - klass = cls.msg_types[kwargs.pop('msg_id')] - kwargs = klass.upgrade_json(kwargs) - for field_ in fields(klass): - if is_dataclass(field_.type): - kwargs[field_.name] = field_.type.fromjson(kwargs[field_.name]) - return klass(**kwargs) - - @classmethod - def upgrade_json(cls, data): - return data - def send(self, sender=None): async_to_sync(channels.layers.get_channel_layer().group_send)(get_mesh_comm_group(self.dst), { "type": "mesh.send", "sender": sender, - "msg": self.tojson() + "msg": MeshMessage.tojson(self), }) @classmethod @@ -136,10 +104,6 @@ class MeshMessage(StructType, union_type_field="msg_id"): cls.__name__.removeprefix('Mesh').removesuffix('Message') ).upper().replace('CONFIG', 'CFG').replace('FIRMWARE', 'FW').replace('POSITION', 'POS') - @classmethod - def get_msg_types(cls): - return cls._union_options["msg_id"] - @dataclass class NoopMessage(MeshMessage, msg_id=MeshMessageType.NOOP): @@ -264,7 +228,6 @@ class ConfigFirmwareMessage(MeshMessage, msg_id=MeshMessageType.CONFIG_FIRMWARE) @classmethod def upgrade_json(cls, data): - data = data.copy() # todo: deepcopy? if 'revision' in data: data['revision_major'], data['revision_minor'] = data.pop('revision') return data diff --git a/src/c3nav/mesh/models.py b/src/c3nav/mesh/models.py index a27ebcae..9702a99a 100644 --- a/src/c3nav/mesh/models.py +++ b/src/c3nav/mesh/models.py @@ -1,6 +1,7 @@ from collections import UserDict from functools import cached_property from operator import attrgetter +from typing import Mapping, Self, Any from django.db import models, NotSupportedError from django.utils.translation import gettext_lazy as _ @@ -89,7 +90,7 @@ class MeshNode(models.Model): return self.address @cached_property - def last_messages(self): + def last_messages(self) -> Mapping[Any, Self]: return LastMessagesByTypeLookup(self) @@ -107,7 +108,7 @@ class NodeMessage(models.Model): return '(#%d) %s at %s' % (self.pk, self.get_message_type_display(), self.datetime) @cached_property - def parsed(self): + def parsed(self) -> dict: return MeshMessage.fromjson(self.data)