From 683f80ef751f02dc8fd14313af6b4a37d9deac9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laura=20Kl=C3=BCnder?= Date: Mon, 4 Apr 2022 14:48:43 +0200 Subject: [PATCH] set up base websocket server for mesh communication --- src/c3nav/mesh/consumers.py | 21 +++++--- src/c3nav/mesh/messages.py | 101 ++++++++++++++++++++++++++++++++++++ src/c3nav/mesh/urls.py | 4 +- 3 files changed, 117 insertions(+), 9 deletions(-) create mode 100644 src/c3nav/mesh/messages.py diff --git a/src/c3nav/mesh/consumers.py b/src/c3nav/mesh/consumers.py index 9dd9b2eb..4ffe6e86 100644 --- a/src/c3nav/mesh/consumers.py +++ b/src/c3nav/mesh/consumers.py @@ -1,12 +1,19 @@ -from channels.generic.websocket import WebsocketConsumer +from channels.generic.websocket import AsyncWebsocketConsumer + +from c3nav.mesh import messages -class EchoConsumer(WebsocketConsumer): - def connect(self): - self.accept() +class MeshConsumer(AsyncWebsocketConsumer): + async def connect(self): + await self.accept() - def disconnect(self, close_code): + async def disconnect(self, close_code): pass - def receive(self, text_data): - self.send(text_data=text_data) + async def receive(self, text_data=None, bytes_data=None): + if bytes_data is None: + return + msg = messages.Message.decode(bytes_data) + print('Received message:', msg) + if isinstance(msg, messages.MeshSigninMessage): + await self.send(messages.MeshLayerAnnounceMessage(messages.NO_LAYER).encode()) diff --git a/src/c3nav/mesh/messages.py b/src/c3nav/mesh/messages.py new file mode 100644 index 00000000..7c46bbd0 --- /dev/null +++ b/src/c3nav/mesh/messages.py @@ -0,0 +1,101 @@ +import struct +from dataclasses import dataclass, field, fields + +NO_LAYER = 0xFF +MAC_FMT = '%02x:%02x:%02x:%02x:%02x:%02x' + + +class SimpleFormat: + def __init__(self, fmt): + self.fmt = fmt + self.size = struct.calcsize(fmt) + + def encode(self, value): + return struct.pack(self.fmt, value) + + def decode(self, data: bytes): + return struct.unpack(self.fmt, data[:self.size]), data[self.size:] + + +class VarStrFormat: + def encode(self, value: str) -> bytes: + return bytes((len(value)+1, )) + value.encode() + bytes((0, )) + + def decode(self, data: bytes): + return data[1:data[0]].decode(), data[data[0]+1:] + + +class MacAddressFormat: + def encode(self, value: str) -> bytes: + return bytes(int(value[i*3:i*3+2], 16) for i in range(6)) + + def decode(self, data: bytes): + return (MAC_FMT % tuple(data[:6])), data[6:] + + +class MacAddressesListFormat: + def encode(self, value: list[str]) -> bytes: + return bytes((len(value), )) + sum( + (bytes((int(mac[i*3:i*3+2], 16) for i in range(6))) for mac in value), + b'' + ) + + def decode(self, data: bytes): + return [MAC_FMT % tuple(data[1+6*i:1+6+6*i]) for i in range(data[0])], data[1+data[0]*6] + + +@dataclass +class Message: + msg_types = {} + + # noinspection PyMethodOverriding + def __init_subclass__(cls, /, msg_id=None, **kwargs): + super().__init_subclass__(**kwargs) + if msg_id: + cls.msg_id = msg_id + Message.msg_types[msg_id] = cls + + def encode(self): + data = bytes((self.msg_id, )) + for field_ in fields(self): + data += field_.metadata['format'].encode(getattr(self, field_.name)) + return data + + @classmethod + def decode(cls, data: bytes): + klass = cls.msg_types[data[0]] + data = data[1:] + values = {} + for field_ in fields(klass): + values[field_.name], data = field_.metadata['format'].decode(data) + return klass(**values) + + +@dataclass +class EchoMessage(Message, msg_id=0x01): + content: str = field(default='', metadata={'format': VarStrFormat()}) + + +@dataclass +class MeshSigninMessage(Message, msg_id=0x02): + mac_address: str = field(metadata={'format': MacAddressFormat()}) + + +@dataclass +class MeshLayerAnnounceMessage(Message, msg_id=0x03): + layer: int = field(metadata={'format': SimpleFormat('B')}) + + +@dataclass +class BaseMeshDestinationsMessage(Message): + mac_addresses: list[str] = field(default_factory=list, metadata={'format': MacAddressesListFormat()}) + + +@dataclass +class MeshAddDestinationsMessage(BaseMeshDestinationsMessage, msg_id=0x04): + pass + + +@dataclass +class MeshRemoveDestinationsMessage(BaseMeshDestinationsMessage, msg_id=0x05): + pass diff --git a/src/c3nav/mesh/urls.py b/src/c3nav/mesh/urls.py index dd2bc410..4d075a86 100644 --- a/src/c3nav/mesh/urls.py +++ b/src/c3nav/mesh/urls.py @@ -1,7 +1,7 @@ from django.urls import path -from c3nav.mesh.consumers import EchoConsumer +from c3nav.mesh.consumers import MeshConsumer websocket_urlpatterns = [ - path('ws', EchoConsumer.as_asgi()), + path('ws', MeshConsumer.as_asgi()), ]