From 67969951ed452ad8645035e7dea5618b45582342 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laura=20Kl=C3=BCnder?= Date: Fri, 15 Apr 2022 20:02:42 +0200 Subject: [PATCH] handle incoming messages from mesh --- src/c3nav/mesh/consumers.py | 30 +++- src/c3nav/mesh/dataformats.py | 151 ++++++++++++++++++ src/c3nav/mesh/messages.py | 179 ++++++---------------- src/c3nav/mesh/migrations/0001_initial.py | 51 ++++++ src/c3nav/mesh/models.py | 48 ++++++ 5 files changed, 327 insertions(+), 132 deletions(-) create mode 100644 src/c3nav/mesh/dataformats.py create mode 100644 src/c3nav/mesh/migrations/0001_initial.py diff --git a/src/c3nav/mesh/consumers.py b/src/c3nav/mesh/consumers.py index e4a1af62..26c43b04 100644 --- a/src/c3nav/mesh/consumers.py +++ b/src/c3nav/mesh/consumers.py @@ -1,6 +1,10 @@ +import traceback + +from channels.db import database_sync_to_async from channels.generic.websocket import AsyncWebsocketConsumer from c3nav.mesh import messages +from c3nav.mesh.models import MeshNode, NodeMessage class MeshConsumer(AsyncWebsocketConsumer): @@ -19,15 +23,35 @@ class MeshConsumer(AsyncWebsocketConsumer): async def receive(self, text_data=None, bytes_data=None): if bytes_data is None: return - msg = messages.Message.decode(bytes_data) + try: + msg = messages.Message.decode(bytes_data) + except Exception: + traceback.print_exc() + return + + if msg.dst != messages.ROOT_ADDRESS and msg.dst != messages.PARENT_ADDRESS: + print('Received message for forwarding:', msg) + # todo: this message isn't for us, forward it + return + print('Received message:', msg) + node = await self.log_received_message(msg) # noqa if isinstance(msg, messages.MeshSigninMessage): await self.send_msg(messages.MeshLayerAnnounceMessage( - src='00:00:00:00:00:00', + src=messages.ROOT_ADDRESS, dst=msg.src, layer=messages.NO_LAYER )) await self.send_msg(messages.ConfigDumpMessage( - src='00:00:00:00:00:00', + src=messages.ROOT_ADDRESS, dst=msg.src, )) + + @database_sync_to_async + def log_received_message(self, msg: messages.Message) -> MeshNode: + node, created = MeshNode.objects.get_or_create(address=msg.src) + return NodeMessage.objects.create( + node=node, + message_type=msg.msg_id, + data=msg.tojson() + ) diff --git a/src/c3nav/mesh/dataformats.py b/src/c3nav/mesh/dataformats.py new file mode 100644 index 00000000..cab3fa9d --- /dev/null +++ b/src/c3nav/mesh/dataformats.py @@ -0,0 +1,151 @@ +import struct +from dataclasses import dataclass, field +from enum import IntEnum + +MAC_FMT = '%02x:%02x:%02x:%02x:%02x:%02x' + + +class SimpleFormat: + def __init__(self, fmt): + self.fmt = fmt + self.size = struct.calcsize(fmt) + + def encode(self, value): + return struct.pack(self.fmt, value) + + def decode(self, data: bytes): + value = struct.unpack(self.fmt, data[:self.size]) + if len(value) == 1: + value = value[0] + return value, data[self.size:] + + +class FixedStrFormat: + def __init__(self, num): + self.num = num + + def encode(self, value): + return struct.pack('%ss' % self.num, value) + + def decode(self, data: bytes): + return struct.unpack('%ss' % self.num, data[:self.num])[0].rstrip(bytes((0, ))).decode(), data[self.num:] + + +class BoolFormat: + def encode(self, value): + return struct.pack('B', (int(value), )) + + def decode(self, data: bytes): + return bool(struct.unpack('B', data[:1])[0]), data[1:] + + +class HexFormat: + def __init__(self, num, sep=''): + self.num = num + self.sep = sep + + def encode(self, value): + return struct.pack('%ss' % self.num, bytes.fromhex(value)) + + def decode(self, data: bytes): + return ( + struct.unpack('%ss' % self.num, data[:self.num])[0].hex(*([self.sep] if self.sep else [])), + data[self.num:] + ) + + +class VarStrFormat: + def encode(self, value: str) -> bytes: + return bytes((len(value)+1, )) + value.encode() + bytes((0, )) + + def decode(self, data: bytes): + return data[1:data[0]].decode(), data[data[0]+1:] + + +class MacAddressFormat: + def encode(self, value: str) -> bytes: + return bytes(int(value[i*3:i*3+2], 16) for i in range(6)) + + def decode(self, data: bytes): + return (MAC_FMT % tuple(data[:6])), data[6:] + + +class MacAddressesListFormat: + def encode(self, value: list[str]) -> bytes: + return bytes((len(value), )) + sum( + (bytes((int(mac[i*3:i*3+2], 16) for i in range(6))) for mac in value), + b'' + ) + + def decode(self, data: bytes): + return [MAC_FMT % tuple(data[1+6*i:1+6+6*i]) for i in range(data[0])], data[1+data[0]*6:] + + +class LedType(IntEnum): + SERIAL = 1 + MULTIPIN = 2 + + +@dataclass +class LedConfig: + led_type: LedType = field(init=False, repr=False) + ledconfig_types = {} + + # noinspection PyMethodOverriding + def __init_subclass__(cls, /, led_type: LedType, **kwargs): + super().__init_subclass__(**kwargs) + cls.led_type = led_type + LedConfig.ledconfig_types[led_type] = cls + + @classmethod + def fromjson(cls, data): + if data is None: + return None + return LedConfig.ledconfig_types[data.pop('led_type')](**data) + + +@dataclass +class SerialLedConfig(LedConfig, led_type=LedType.SERIAL): + gpio: int + rmt: int + + +@dataclass +class MultipinLedConfig(LedConfig, led_type=LedType.MULTIPIN): + gpio_red: int + gpio_green: int + gpio_blue: int + + +class LedConfigFormat: + def encode(self, value) -> bytes: + if value is None: + return struct.pack('BBBB', (0, 0, 0, 0)) + if isinstance(value, SerialLedConfig): + return struct.pack('BBBB', (value.type_id, value.gpio, value.rmt, 0)) + if isinstance(value, MultipinLedConfig): + return struct.pack('BBBB', (value.type_id, value.gpio_red, value.gpio_green, value.gpio_blue)) + raise ValueError + + def decode(self, data: bytes): + type_, *bytes_ = struct.unpack('BBBB', data) + if type_ == 0: + value = None + elif type_ == 1: + value = SerialLedConfig(gpio=bytes_[0], rmt=bytes_[1]) + elif type_ == 2: + value = MultipinLedConfig(gpio_red=bytes_[0], gpio_green=bytes_[1], gpio_blue=bytes_[2]) + else: + raise ValueError + return value, data[4:] + + def from_json(self, value): + if value is None: + return None + + type_ = value.pop('type') + if type_ == 'serial': + return SerialLedConfig(**value) + elif type_ == 'multipin': + return MultipinLedConfig(**value) + raise ValueError diff --git a/src/c3nav/mesh/messages.py b/src/c3nav/mesh/messages.py index 66aee846..235f10f8 100644 --- a/src/c3nav/mesh/messages.py +++ b/src/c3nav/mesh/messages.py @@ -1,133 +1,40 @@ -import struct -from dataclasses import dataclass, field, fields +from dataclasses import asdict, dataclass, field, fields, is_dataclass +from enum import IntEnum, unique +from typing import TypeVar +from c3nav.mesh.dataformats import (BoolFormat, FixedStrFormat, HexFormat, LedConfig, LedConfigFormat, + MacAddressesListFormat, MacAddressFormat, SimpleFormat, VarStrFormat) + +ROOT_ADDRESS = '00:00:00:00:00:00' +PARENT_ADDRESS = '00:00:00:ff:ff:ff' NO_LAYER = 0xFF -MAC_FMT = '%02x:%02x:%02x:%02x:%02x:%02x' -class SimpleFormat: - def __init__(self, fmt): - self.fmt = fmt - self.size = struct.calcsize(fmt) +@unique +class MessageType(IntEnum): + ECHO_REQUEST = 0x01 + ECHO_RESPONSE = 0x02 - def encode(self, value): - return struct.pack(self.fmt, value) + MESH_SIGNIN = 0x03 + MESH_LAYER_ANNOUNCE = 0x04 + MESH_ADD_DESTINATIONS = 0x05 + MESH_REMOVE_DESTINATIONS = 0x06 - def decode(self, data: bytes): - value = struct.unpack(self.fmt, data[:self.size]) - if len(value) == 1: - value = value[0] - return value, data[self.size:] + CONFIG_DUMP = 0x10 + CONFIG_FIRMWARE = 0x11 + CONFIG_POSITION = 0x12 + CONFIG_LED = 0x13 + CONFIG_UPLINK = 0x14 -class FixedStrFormat: - def __init__(self, num): - self.num = num - - def encode(self, value): - return struct.pack('%ss' % self.num, value) - - def decode(self, data: bytes): - return struct.unpack('%ss' % self.num, data[:self.num])[0].rstrip(bytes((0, ))).decode(), data[self.num:] - - -class BoolFormat: - def encode(self, value): - return struct.pack('B', (int(value), )) - - def decode(self, data: bytes): - return bool(struct.unpack('B', data[:1])[0]), data[1:] - - -class HexFormat: - def __init__(self, num, sep=''): - self.num = num - self.sep = sep - - def encode(self, value): - return struct.pack('%ss' % self.num, bytes.fromhex(value)) - - def decode(self, data: bytes): - return ( - struct.unpack('%ss' % self.num, data[:self.num])[0].hex(*([self.sep] if self.sep else [])), - data[self.num:] - ) - - -class VarStrFormat: - def encode(self, value: str) -> bytes: - return bytes((len(value)+1, )) + value.encode() + bytes((0, )) - - def decode(self, data: bytes): - return data[1:data[0]].decode(), data[data[0]+1:] - - -class MacAddressFormat: - def encode(self, value: str) -> bytes: - return bytes(int(value[i*3:i*3+2], 16) for i in range(6)) - - def decode(self, data: bytes): - return (MAC_FMT % tuple(data[:6])), data[6:] - - -class MacAddressesListFormat: - def encode(self, value: list[str]) -> bytes: - return bytes((len(value), )) + sum( - (bytes((int(mac[i*3:i*3+2], 16) for i in range(6))) for mac in value), - b'' - ) - - def decode(self, data: bytes): - return [MAC_FMT % tuple(data[1+6*i:1+6+6*i]) for i in range(data[0])], data[1+data[0]*6:] - - -class LedConfig: - pass - - -@dataclass -class SerialLedConfig(LedConfig): - type = 1 - gpio: int - rmt: int - - -@dataclass -class MultipinLedConfig(LedConfig): - type = 2 - gpio_red: int - gpio_green: int - gpio_blue: int - - -class LedConfigFormat: - def encode(self, value) -> bytes: - if value is None: - return struct.pack('BBBB', (0, 0, 0, 0)) - if isinstance(value, SerialLedConfig): - return struct.pack('BBBB', (value.type, value.gpio, value.rmt, 0)) - if isinstance(value, MultipinLedConfig): - return struct.pack('BBBB', (value.type, value.gpio_red, value.gpio_green, value.gpio_blue)) - raise ValueError - - def decode(self, data: bytes): - type_, *bytes_ = struct.unpack('BBBB', data) - if type_ == 0: - value = None - elif type_ == 1: - value = SerialLedConfig(gpio=bytes_[0], rmt=bytes_[1]) - elif type_ == 2: - value = MultipinLedConfig(gpio_red=bytes_[0], gpio_green=bytes_[1], gpio_blue=bytes_[2]) - else: - raise ValueError - return value, data[4:] +M = TypeVar('M', bound='Message') @dataclass class Message: dst: str = field(metadata={'format': MacAddressFormat()}) src: str = field(metadata={'format': MacAddressFormat()}) - msg_id: int = field(metadata={'format': SimpleFormat('B')}, init=False, repr=True) + msg_id: int = field(metadata={'format': SimpleFormat('B')}, init=False, repr=False) msg_types = {} # noinspection PyMethodOverriding @@ -135,6 +42,8 @@ class Message: super().__init_subclass__(**kwargs) if msg_id: cls.msg_id = msg_id + if msg_id in Message.msg_types: + raise TypeError('duplicate use of msg_id %d' % msg_id) Message.msg_types[msg_id] = cls def encode(self): @@ -144,8 +53,8 @@ class Message: return data @classmethod - def decode(cls, data: bytes): - print('decode', data.hex(' ')) + def decode(cls, data: bytes) -> M: + # print('decode', data.hex(' ')) klass = cls.msg_types[data[12]] values = {} for field_ in fields(klass): @@ -153,44 +62,56 @@ class Message: values.pop('msg_id') return klass(**values) + def tojson(self): + return asdict(self) + + @classmethod + def fromjson(cls, data) -> M: + kwargs = data.copy() + klass = cls.msg_types[kwargs.pop('msg_id')] + for field_ in fields(klass): + if is_dataclass(field_.type): + kwargs[field_.name] = field_.type.fromjson(kwargs[field_.name]) + return klass(**kwargs) + @dataclass -class EchoRequestMessage(Message, msg_id=0x01): +class EchoRequestMessage(Message, msg_id=MessageType.ECHO_REQUEST): content: str = field(default='', metadata={'format': VarStrFormat()}) @dataclass -class EchoResponseMessage(Message, msg_id=0x02): +class EchoResponseMessage(Message, msg_id=MessageType.ECHO_RESPONSE): content: str = field(default='', metadata={'format': VarStrFormat()}) @dataclass -class MeshSigninMessage(Message, msg_id=0x03): +class MeshSigninMessage(Message, msg_id=MessageType.MESH_SIGNIN): pass @dataclass -class MeshLayerAnnounceMessage(Message, msg_id=0x04): +class MeshLayerAnnounceMessage(Message, msg_id=MessageType.MESH_LAYER_ANNOUNCE): layer: int = field(metadata={'format': SimpleFormat('B')}) @dataclass -class MeshAddDestinationsMessage(Message, msg_id=0x05): +class MeshAddDestinationsMessage(Message, msg_id=MessageType.MESH_ADD_DESTINATIONS): mac_addresses: list[str] = field(default_factory=list, metadata={'format': MacAddressesListFormat()}) @dataclass -class MeshRemoveDestinationsMessage(Message, msg_id=0x06): +class MeshRemoveDestinationsMessage(Message, msg_id=MessageType.MESH_REMOVE_DESTINATIONS): mac_addresses: list[str] = field(default_factory=list, metadata={'format': MacAddressesListFormat()}) @dataclass -class ConfigDumpMessage(Message, msg_id=0x10): +class ConfigDumpMessage(Message, msg_id=MessageType.CONFIG_DUMP): pass @dataclass -class ConfigFirmwareMessage(Message, msg_id=0x11): +class ConfigFirmwareMessage(Message, msg_id=MessageType.CONFIG_FIRMWARE): magic_word: int = field(metadata={'format': SimpleFormat('I')}, repr=False) secure_version: int = field(metadata={'format': SimpleFormat('I')}) reserv1: list[int] = field(metadata={'format': SimpleFormat('2I')}, repr=False) @@ -204,19 +125,19 @@ class ConfigFirmwareMessage(Message, msg_id=0x11): @dataclass -class ConfigPositionMessage(Message, msg_id=0x12): +class ConfigPositionMessage(Message, msg_id=MessageType.CONFIG_POSITION): x_pos: int = field(metadata={'format': SimpleFormat('I')}) y_pos: int = field(metadata={'format': SimpleFormat('I')}) z_pos: int = field(metadata={'format': SimpleFormat('H')}) @dataclass -class ConfigLedMessage(Message, msg_id=0x13): +class ConfigLedMessage(Message, msg_id=MessageType.CONFIG_LED): led_config: LedConfig = field(metadata={'format': LedConfigFormat()}) @dataclass -class ConfigUplinkMessage(Message, msg_id=0x14): +class ConfigUplinkMessage(Message, msg_id=MessageType.CONFIG_UPLINK): enabled: bool = field(metadata={'format': BoolFormat()}) ssid: str = field(metadata={'format': FixedStrFormat(32)}) password: str = field(metadata={'format': FixedStrFormat(64)}) diff --git a/src/c3nav/mesh/migrations/0001_initial.py b/src/c3nav/mesh/migrations/0001_initial.py new file mode 100644 index 00000000..e601e6a1 --- /dev/null +++ b/src/c3nav/mesh/migrations/0001_initial.py @@ -0,0 +1,51 @@ +# Generated by Django 4.0.3 on 2022-04-15 18:00 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ] + + operations = [ + migrations.CreateModel( + name='MeshNode', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('address', models.CharField(max_length=17, unique=True, verbose_name='mac address')), + ('first_seen', models.DateTimeField(auto_now_add=True, verbose_name='first seen')), + ('parent_node', models.ForeignKey(null=True, on_delete=django.db.models.deletion.PROTECT, related_name='child_nodes', to='mesh.meshnode', verbose_name='parent node')), + ('route', models.ForeignKey(null=True, on_delete=django.db.models.deletion.PROTECT, related_name='routed_nodes', to='mesh.meshnode', verbose_name='route')), + ], + ), + migrations.CreateModel( + name='NodeMessage', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('datetime', models.DateTimeField(auto_now_add=True, db_index=True, verbose_name='datetime')), + ('message_type', models.SmallIntegerField(choices=[(1, 'ECHO_REQUEST'), (2, 'ECHO_RESPONSE'), (3, 'MESH_SIGNIN'), (4, 'MESH_LAYER_ANNOUNCE'), (5, 'MESH_ADD_DESTINATIONS'), (6, 'MESH_REMOVE_DESTINATIONS'), (16, 'CONFIG_DUMP'), (17, 'CONFIG_FIRMWARE'), (18, 'CONFIG_POSITION'), (19, 'CONFIG_LED'), (20, 'CONFIG_UPLINK')], db_index=True, verbose_name='message type')), + ('data', models.JSONField(verbose_name='message data')), + ('node', models.ForeignKey(null=True, on_delete=django.db.models.deletion.PROTECT, related_name='received_messages', to='mesh.meshnode', verbose_name='node')), + ], + ), + migrations.CreateModel( + name='Firmware', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('chip', models.SmallIntegerField(choices=[(2, 'ESP32-S2'), (5, 'ESP32-C3')], db_index=True, verbose_name='chip')), + ('project_name', models.CharField(max_length=32, verbose_name='project name')), + ('version', models.CharField(max_length=32, verbose_name='firmware version')), + ('idf_version', models.CharField(max_length=32, verbose_name='IDF version')), + ('compile_time', models.DateTimeField(verbose_name='compile time')), + ('sha256_hash', models.CharField(max_length=64, unique=True, verbose_name='SHA256 hash')), + ('binary', models.FileField(null=True, upload_to='', verbose_name='firmware file')), + ], + options={ + 'unique_together': {('chip', 'project_name', 'version', 'idf_version', 'compile_time', 'sha256_hash')}, + }, + ), + ] diff --git a/src/c3nav/mesh/models.py b/src/c3nav/mesh/models.py index e69de29b..bb413654 100644 --- a/src/c3nav/mesh/models.py +++ b/src/c3nav/mesh/models.py @@ -0,0 +1,48 @@ +from django.db import models +from django.utils.translation import gettext_lazy as _ + +from c3nav.mesh.messages import MessageType + + +class ChipID(models.IntegerChoices): + ESP32S2 = 2, 'ESP32-S2' + ESP32C3 = 5, 'ESP32-C3' + + +class MeshNode(models.Model): + address = models.CharField(_('mac address'), max_length=17, unique=True) + first_seen = models.DateTimeField(_('first seen'), auto_now_add=True) + parent_node = models.ForeignKey('MeshNode', models.PROTECT, null=True, + related_name='child_nodes', verbose_name=_('parent node')) + route = models.ForeignKey('MeshNode', models.PROTECT, null=True, + related_name='routed_nodes', verbose_name=_('route')) + + def __str__(self): + return self.address + + +class NodeMessage(models.Model): + MESSAGE_TYPES = [(msgtype.value, msgtype.name) for msgtype in MessageType] + node = models.ForeignKey('MeshNode', models.PROTECT, null=True, + related_name='received_messages', verbose_name=_('node')) + datetime = models.DateTimeField(_('datetime'), db_index=True, auto_now_add=True) + message_type = models.SmallIntegerField(_('message type'), db_index=True, choices=MESSAGE_TYPES) + data = models.JSONField(_('message data')) + + def __str__(self): + return '(#%d) %s at %s' % (self.pk, self.get_message_type_display(), self.datetime) + + +class Firmware(models.Model): + chip = models.SmallIntegerField(_('chip'), db_index=True, choices=ChipID.choices) + project_name = models.CharField(_('project name'), max_length=32) + version = models.CharField(_('firmware version'), max_length=32) + idf_version = models.CharField(_('IDF version'), max_length=32) + compile_time = models.DateTimeField(_('compile time')) + sha256_hash = models.CharField(_('SHA256 hash'), unique=True, max_length=64) + binary = models.FileField(_('firmware file'), null=True) + + class Meta: + unique_together = [ + ('chip', 'project_name', 'version', 'idf_version', 'compile_time', 'sha256_hash'), + ]