From 4d3f54bbe8ec21d1b5593917d6b726ceeff4c157 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laura=20Kl=C3=BCnder?= Date: Tue, 3 Oct 2023 17:23:29 +0200 Subject: [PATCH] more mesh communication implementation, lots of fixes and stuff --- src/c3nav/control/forms.py | 2 +- .../templates/control/mesh_messages.html | 6 +- .../control/templates/control/mesh_nodes.html | 27 +++- src/c3nav/control/views/mesh.py | 8 +- src/c3nav/mesh/consumers.py | 140 +++++++++--------- src/c3nav/mesh/messages.py | 9 ++ ...4_relay_vs_src_node_and_remove_firmware.py | 70 +++++++++ src/c3nav/mesh/models.py | 92 ++++++++++-- 8 files changed, 260 insertions(+), 94 deletions(-) create mode 100644 src/c3nav/mesh/migrations/0004_relay_vs_src_node_and_remove_firmware.py diff --git a/src/c3nav/control/forms.py b/src/c3nav/control/forms.py index b1aa5e8e..4273aafd 100644 --- a/src/c3nav/control/forms.py +++ b/src/c3nav/control/forms.py @@ -306,7 +306,7 @@ class MeshMessageFilerForm(Form): required=False, label=_('message types'), ) - nodes = ModelMultipleChoiceField( + src_nodes = ModelMultipleChoiceField( queryset=MeshNode.objects.all(), required=False, label=_('nodes'), diff --git a/src/c3nav/control/templates/control/mesh_messages.html b/src/c3nav/control/templates/control/mesh_messages.html index 948a1ac7..0f978313 100644 --- a/src/c3nav/control/templates/control/mesh_messages.html +++ b/src/c3nav/control/templates/control/mesh_messages.html @@ -10,7 +10,7 @@ {{ form.message_types }}
- {{ form.nodes }} + {{ form.src_nodes }}
@@ -26,11 +26,12 @@ {% trans 'Node' %} {% trans 'Type' %} {% trans 'Data' %} + {% trans 'Uplink' %} {% for msg in mesh_messages %} {{ msg.datetime }} - {{ msg.node }} + {{ msg.src_node }} {{ msg.get_message_type_display }} @@ -42,6 +43,7 @@ {% endif %} {% endfor %} + {{ msg.uplink_node }} {% endfor %} diff --git a/src/c3nav/control/templates/control/mesh_nodes.html b/src/c3nav/control/templates/control/mesh_nodes.html index 0796916b..2f7eda76 100644 --- a/src/c3nav/control/templates/control/mesh_nodes.html +++ b/src/c3nav/control/templates/control/mesh_nodes.html @@ -6,8 +6,8 @@ {% block subcontent %} - - + + @@ -16,17 +16,30 @@ {% for node in nodes %} - - + + + - + {% endfor %}
{% trans 'Address' %}{% trans 'Name' %}{% trans 'Node' %}{% trans 'Status' %} {% trans 'Chip' %} {% trans 'Firmware' %} {% trans 'Last msg' %}
{{ node.address }}{{ node.name }} - {{ node.firmware.get_chip_display }} + {% if node.route %} + {% trans "online" %} + {% else %} + {% trans "offline" %} + {% endif %} + {{ node }} + {{ node.last_messages.CONFIG_FIRMWARE.parsed.get_chip_display }} + rev{{ node.last_messages.CONFIG_FIRMWARE.parsed.revision|join:"." }} - {{ node.firmware.version }} (IDF {{ node.firmware.idf_version }}) + {{ node.last_messages.CONFIG_FIRMWARE.parsed.version }} + (IDF {{ node.last_messages.CONFIG_FIRMWARE.parsed.idf_version }}) + + {% blocktrans trimmed with timesince=node.last_msg|timesince %} + {{ timesince }} ago + {% endblocktrans %} {{ node.last_msg }} {{ node.parent }} {{ node.route }}{{ node.last_messages.CONFIG_FIRMWARE.data }}
diff --git a/src/c3nav/control/views/mesh.py b/src/c3nav/control/views/mesh.py index 7640b19a..04b8891d 100644 --- a/src/c3nav/control/views/mesh.py +++ b/src/c3nav/control/views/mesh.py @@ -1,6 +1,5 @@ from django.db.models import Max from django.views.generic import ListView -from django.views.generic.edit import FormMixin from c3nav.control.forms import MeshMessageFilerForm from c3nav.control.views.base import ControlPanelMixin @@ -14,7 +13,7 @@ class MeshNodeListView(ControlPanelMixin, ListView): context_object_name = "nodes" def get_queryset(self): - return super().get_queryset().annotate(last_msg=Max('received_messages__datetime')) + return super().get_queryset().annotate(last_msg=Max('received_messages__datetime')).prefetch_last_messages() class MeshMessageListView(ControlPanelMixin, ListView): @@ -31,8 +30,8 @@ class MeshMessageListView(ControlPanelMixin, ListView): if self.form.is_valid(): if self.form.cleaned_data['message_types']: qs = qs.filter(message_type__in=self.form.cleaned_data['message_types']) - if self.form.cleaned_data['nodes']: - qs = qs.filter(node__in=self.form.cleaned_data['nodes']) + if self.form.cleaned_data['src_nodes']: + qs = qs.filter(src_node__in=self.form.cleaned_data['src_nodes']) return qs @@ -47,4 +46,3 @@ class MeshMessageListView(ControlPanelMixin, ListView): 'form_data': form_data.urlencode(), }) return ctx - diff --git a/src/c3nav/mesh/consumers.py b/src/c3nav/mesh/consumers.py index e5dda401..987f021c 100644 --- a/src/c3nav/mesh/consumers.py +++ b/src/c3nav/mesh/consumers.py @@ -1,31 +1,33 @@ import traceback -from channels.db import database_sync_to_async -from channels.generic.websocket import AsyncWebsocketConsumer +from asgiref.sync import async_to_sync +from channels.generic.websocket import WebsocketConsumer from c3nav.mesh import messages -from c3nav.mesh.models import MeshNode, NodeMessage, Firmware +from c3nav.mesh.models import MeshNode, NodeMessage -class MeshConsumer(AsyncWebsocketConsumer): - async def connect(self): +# noinspection PyAttributeOutsideInit +class MeshConsumer(WebsocketConsumer): + def connect(self): print('connected!') # todo: auth - self.node = None - await self.accept() + self.uplink_node = None + self.dst_nodes = set() + self.accept() - async def disconnect(self, close_code): + def disconnect(self, close_code): print('disconnected!') - if self.node is not None: - await self.remove_route(self.node.address) - await self.channel_layer.group_discard('route_%s' % self.node.address.replace(':', ''), self.channel_name) - await self.channel_layer.group_discard('route_broadcast', self.channel_name) + if self.uplink_node is not None: + self.remove_route(self.uplink_node) + self.channel_layer.group_discard('route_%s' % self.node.address.replace(':', ''), self.channel_name) + self.channel_layer.group_discard('route_broadcast', self.channel_name) - async def send_msg(self, msg): + def send_msg(self, msg): print('Sending message:', msg) - await self.send(bytes_data=msg.encode()) + self.send(bytes_data=msg.encode()) - async def receive(self, text_data=None, bytes_data=None): + def receive(self, text_data=None, bytes_data=None): if bytes_data is None: return try: @@ -41,77 +43,81 @@ class MeshConsumer(AsyncWebsocketConsumer): print('Received message:', msg) + src_node, created = MeshNode.objects.get_or_create(address=msg.src) + if isinstance(msg, messages.MeshSigninMessage): - self.node, created = await self.get_node(msg.src) - if created: - print('New node signing in!') - print(self.node) - await self.log_received_message(msg) - await self.send_msg(messages.MeshLayerAnnounceMessage( + self.uplink_node = src_node + # log message, since we will not log it further down + self.log_received_message(src_node, msg) + + # inform signed in uplink node about its layer + self.send_msg(messages.MeshLayerAnnounceMessage( src=messages.ROOT_ADDRESS, dst=msg.src, layer=messages.NO_LAYER )) - await self.send_msg(messages.ConfigDumpMessage( - src=messages.ROOT_ADDRESS, - dst=msg.src, - )) - await self.channel_layer.group_add('route_%s' % self.node.address.replace(':', ''), self.channel_name) - await self.channel_layer.group_add('route_broadcast', self.channel_name) - await self.set_parent_of_nodes(None, (self.node.address, )) - await self.add_route_to_nodes(self.node.address, (self.node.address,)) + + # add signed in uplink node to broadcast route + async_to_sync(self.channel_layer.group_add)('mesh_broadcast', self.channel_name) + + # add this node as a destination that this uplink handles (duh) + self.add_dst_nodes((src_node.address, )) + return - if self.node is None: + if self.uplink_node is None: print('Expected sign-in message, but got a different one!') - await self.close() + self.close() return - await self.log_received_message(msg) + self.log_received_message(src_node, msg) - if isinstance(msg, messages.ConfigFirmwareMessage): - await self._handle_config_firmware_msg(msg) - return + def uplink_change(self, data): + # message handler: if we are not the given uplink, leave this group + if data["uplink"] != self.uplink_node.address: + group = self.group_name_for_node(data["address"]) + print('leaving uplink group...') + async_to_sync(self.channel_layer.group_discard)(group, self.channel_name) - @database_sync_to_async - def _handle_config_firmware_msg(self, msg): - self.firmware, created = Firmware.objects.get_or_create(**msg.to_model_data()) - self.node.firmware = self.firmware - self.node.save() - - @database_sync_to_async - def get_node(self, address): - return MeshNode.objects.get_or_create(address=address) - - @database_sync_to_async - def log_received_message(self, msg: messages.Message): + def log_received_message(self, src_node: MeshNode, msg: messages.Message): NodeMessage.objects.create( - node=self.node, + uplink_node=self.uplink_node, + src_node=src_node, message_type=msg.msg_id, data=msg.tojson() ) - @database_sync_to_async - def create_nodes(self, addresses): - MeshNode.objects.bulk_create(MeshNode(address=address) for address in addresses) + def add_dst_nodes(self, addresses): + # add ourselves to this one + for address in addresses: + # create group name for this address + group = self.group_name_for_node(address) - @database_sync_to_async - def set_parent_of_nodes(self, parent_address, node_addresses): - MeshNode.objects.filter(address__in=node_addresses).update(parent_node_id=parent_address) + # if we aren't handling this address yet, join the group + if address not in self.dst_nodes: + async_to_sync(self.channel_layer.group_add)(group, self.channel_name) + self.dst_nodes.add(address) - @database_sync_to_async - def add_route_to_nodes(self, route_address, node_addresses): - MeshNode.objects.filter(address__in=node_addresses).update(route_id=route_address) + # tell other consumers to leave the group + async_to_sync(self.channel_layer.group_send)(group, { + "type": "uplink_change", + "node": address, + "uplink": self.uplink_node.address + }) + + # tell the node to dump its current information + self.send_msg( + messages.ConfigDumpMessage( + src=messages.ROOT_ADDRESS, + dst=address, + ) + ) + + # add the stuff to the db as well + MeshNode.objects.filter(address__in=addresses).update(route_id=self.uplink_node.address) + + def group_name_for_node(self, address): + return 'mesh_%s' % address.replace(':', '-') - @database_sync_to_async def remove_route(self, route_address): MeshNode.objects.filter(route_id=route_address).update(route_id=None) - - @database_sync_to_async - def remove_route_to_nodes(self, route_address, node_addresses): - MeshNode.objects.filter(address__in=node_addresses, route_id=route_address).update(route_id=None) - - @database_sync_to_async - def set_node_firmware(self, firmware): - self.node.firmware = firmware - self.node.save() \ No newline at end of file diff --git a/src/c3nav/mesh/messages.py b/src/c3nav/mesh/messages.py index 0953d5e2..11126fdb 100644 --- a/src/c3nav/mesh/messages.py +++ b/src/c3nav/mesh/messages.py @@ -30,6 +30,12 @@ class MessageType(IntEnum): M = TypeVar('M', bound='Message') +@unique +class ChipType(IntEnum): + ESP32_S2 = 2 + ESP32_C3 = 5 + + @dataclass class Message: dst: str = field(metadata={'format': MacAddressFormat()}) @@ -134,6 +140,9 @@ class ConfigFirmwareMessage(Message, msg_id=MessageType.CONFIG_FIRMWARE): 'sha256_hash': self.app_elf_sha256, } + def get_chip_display(self): + return ChipType(self.chip).name.replace('_', '-') + @dataclass class ConfigPositionMessage(Message, msg_id=MessageType.CONFIG_POSITION): diff --git a/src/c3nav/mesh/migrations/0004_relay_vs_src_node_and_remove_firmware.py b/src/c3nav/mesh/migrations/0004_relay_vs_src_node_and_remove_firmware.py new file mode 100644 index 00000000..61c1dafc --- /dev/null +++ b/src/c3nav/mesh/migrations/0004_relay_vs_src_node_and_remove_firmware.py @@ -0,0 +1,70 @@ +# Generated by Django 4.2.1 on 2023-10-03 13:42 + +from django.db import migrations, models +import django.db.models.deletion + + +def forwards_func(apps, schema_editor): + NodeMessage = apps.get_model("mesh", "NodeMessage") + MeshNode = apps.get_model("mesh", "MeshNode") + + NodeMessage.objects.filter(uplink_node=None).delete() + + nodes = {node.address: node for node in MeshNode.objects.all()} + + for msg in NodeMessage.objects.all(): + if msg.data["src"] not in nodes: + nodes[msg.data["src"]] = MeshNode.objects.create(address=msg.data["src"]) + msg.src_node = nodes[msg.data["src"]] + msg.save() + + +class Migration(migrations.Migration): + + dependencies = [ + ('mesh', '0003_meshnode_name'), + ] + + operations = [ + migrations.RenameField( + model_name='nodemessage', + old_name='node', + new_name='uplink_node', + ), + migrations.RemoveField( + model_name='meshnode', + name='firmware', + ), + migrations.RemoveField( + model_name='meshnode', + name='parent_node', + ), + + migrations.AddField( + model_name='nodemessage', + name='src_node', + field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.PROTECT, + related_name='received_messages', to='mesh.meshnode', verbose_name='node'), + ), + migrations.AlterField( + model_name='nodemessage', + name='uplink_node', + field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.PROTECT, + related_name='relayed_messages', to='mesh.meshnode', verbose_name='uplink node'), + ), + + migrations.RunPython(forwards_func, migrations.RunPython.noop), + + migrations.AlterField( + model_name='nodemessage', + name='src_node', + field=models.ForeignKey(on_delete=django.db.models.deletion.PROTECT, related_name='received_messages', + to='mesh.meshnode', verbose_name='node'), + ), + migrations.AlterField( + model_name='nodemessage', + name='uplink_node', + field=models.ForeignKey(on_delete=django.db.models.deletion.PROTECT, related_name='relayed_messages', + to='mesh.meshnode', verbose_name='uplink node'), + ), + ] diff --git a/src/c3nav/mesh/models.py b/src/c3nav/mesh/models.py index 5f57ffa1..5365edcf 100644 --- a/src/c3nav/mesh/models.py +++ b/src/c3nav/mesh/models.py @@ -1,35 +1,98 @@ -from django.db import models +from collections import UserDict +from functools import cached_property + +from django.db import models, NotSupportedError from django.utils.translation import gettext_lazy as _ -from c3nav.mesh.messages import MessageType +from c3nav.mesh.messages import MessageType, ChipType, Message as MeshMessage -class ChipID(models.IntegerChoices): - ESP32S2 = 2, 'ESP32-S2' - ESP32C3 = 5, 'ESP32-C3' +class MeshNodeQuerySet(models.QuerySet): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._prefetch_last_messages = set() + self._prefetch_last_messages_done = False + + def _clone(self): + clone = super()._clone() + clone._prefetch_last_messages = self._prefetch_last_messages + return clone + + def prefetch_last_messages(self, *types: MessageType): + clone = self._chain() + clone._prefetch_last_messages |= ( + set(types) if types else set(msgtype.value for msgtype in MessageType) + ) + return clone + + def _fetch_all(self): + super()._fetch_all() + if self._prefetch_last_messages and not self._prefetch_last_messages_done: + nodes = {node.pk: node for node in self._result_cache} + try: + for message in NodeMessage.objects.order_by('-datetime', '-pk').filter( + message_type__in=self._prefetch_last_messages, + node__in=nodes.keys(), + ).distinct('message_type', 'node'): + nodes[message.node].last_messages[message.message_type] = message + except NotSupportedError: + pass + print(tuple(nodes.values())[0].last_messages[MessageType.MESH_SIGNIN]) + + +class LastMessagesByTypeLookup(UserDict): + def __init__(self, node): + super().__init__() + self.node = node + + def _get_key(self, item): + if isinstance(item, MessageType): + return item + if isinstance(item, str): + try: + return getattr(MessageType, item) + except AttributeError: + pass + return MessageType(item) + + def __getitem__(self, key): + key = self._get_key(key) + try: + return self.data[key] + except KeyError: + pass + msg = self.node.received_messages.filter(message_type=key).order_by('-datetime', '-pk').first() + self.data[key] = msg + return msg + + def __setitem__(self, key, item): + self.data[self._get_key(key)] = item class MeshNode(models.Model): address = models.CharField(_('mac address'), max_length=17, primary_key=True) name = models.CharField(_('name'), max_length=32, null=True, blank=True) first_seen = models.DateTimeField(_('first seen'), auto_now_add=True) - parent_node = models.ForeignKey('MeshNode', models.PROTECT, null=True, - related_name='child_nodes', verbose_name=_('parent node')) route = models.ForeignKey('MeshNode', models.PROTECT, null=True, related_name='routed_nodes', verbose_name=_('route')) - firmware = models.ForeignKey('Firmware', models.PROTECT, null=True, - related_name='installed_on', verbose_name=_('firmware')) + objects = models.Manager.from_queryset(MeshNodeQuerySet)() def __str__(self): if self.name: return '%s (%s)' % (self.address, self.name) return self.address + @cached_property + def last_messages(self): + return LastMessagesByTypeLookup(self) + class NodeMessage(models.Model): MESSAGE_TYPES = [(msgtype.value, msgtype.name) for msgtype in MessageType] - node = models.ForeignKey('MeshNode', models.PROTECT, null=True, - related_name='received_messages', verbose_name=_('node')) + src_node = models.ForeignKey('MeshNode', models.PROTECT, + related_name='received_messages', verbose_name=_('node')) + uplink_node = models.ForeignKey('MeshNode', models.PROTECT, + related_name='relayed_messages', verbose_name=_('uplink node')) datetime = models.DateTimeField(_('datetime'), db_index=True, auto_now_add=True) message_type = models.SmallIntegerField(_('message type'), db_index=True, choices=MESSAGE_TYPES) data = models.JSONField(_('message data')) @@ -37,9 +100,14 @@ class NodeMessage(models.Model): def __str__(self): return '(#%d) %s at %s' % (self.pk, self.get_message_type_display(), self.datetime) + @cached_property + def parsed(self): + return MeshMessage.fromjson(self.data) + class Firmware(models.Model): - chip = models.SmallIntegerField(_('chip'), db_index=True, choices=ChipID.choices) + CHIPS = [(msgtype.value, msgtype.name.replace('_', '-')) for msgtype in ChipType] + chip = models.SmallIntegerField(_('chip'), db_index=True, choices=CHIPS) project_name = models.CharField(_('project name'), max_length=32) version = models.CharField(_('firmware version'), max_length=32) idf_version = models.CharField(_('IDF version'), max_length=32)