diff --git a/src/c3nav/control/forms.py b/src/c3nav/control/forms.py
index b1aa5e8e..4273aafd 100644
--- a/src/c3nav/control/forms.py
+++ b/src/c3nav/control/forms.py
@@ -306,7 +306,7 @@ class MeshMessageFilerForm(Form):
required=False,
label=_('message types'),
)
- nodes = ModelMultipleChoiceField(
+ src_nodes = ModelMultipleChoiceField(
queryset=MeshNode.objects.all(),
required=False,
label=_('nodes'),
diff --git a/src/c3nav/control/templates/control/mesh_messages.html b/src/c3nav/control/templates/control/mesh_messages.html
index 948a1ac7..0f978313 100644
--- a/src/c3nav/control/templates/control/mesh_messages.html
+++ b/src/c3nav/control/templates/control/mesh_messages.html
@@ -10,7 +10,7 @@
{{ form.message_types }}
- {{ form.nodes }}
+ {{ form.src_nodes }}
@@ -26,11 +26,12 @@
{% trans 'Node' %} |
{% trans 'Type' %} |
{% trans 'Data' %} |
+
{% trans 'Uplink' %} |
{% for msg in mesh_messages %}
{{ msg.datetime }} |
- {{ msg.node }} |
+ {{ msg.src_node }} |
{{ msg.get_message_type_display }} |
@@ -42,6 +43,7 @@
{% endif %}
{% endfor %}
|
+ {{ msg.uplink_node }} |
{% endfor %}
diff --git a/src/c3nav/control/templates/control/mesh_nodes.html b/src/c3nav/control/templates/control/mesh_nodes.html
index 0796916b..2f7eda76 100644
--- a/src/c3nav/control/templates/control/mesh_nodes.html
+++ b/src/c3nav/control/templates/control/mesh_nodes.html
@@ -6,8 +6,8 @@
{% block subcontent %}
- {% trans 'Address' %} |
- {% trans 'Name' %} |
+ {% trans 'Node' %} |
+ {% trans 'Status' %} |
{% trans 'Chip' %} |
{% trans 'Firmware' %} |
{% trans 'Last msg' %} |
@@ -16,17 +16,30 @@
{% for node in nodes %}
- {{ node.address }} |
- {{ node.name }} |
- {{ node.firmware.get_chip_display }}
+ {% if node.route %}
+ {% trans "online" %}
+ {% else %}
+ {% trans "offline" %}
+ {% endif %}
+ |
+ {{ node }} |
+
+ {{ node.last_messages.CONFIG_FIRMWARE.parsed.get_chip_display }}
+ rev{{ node.last_messages.CONFIG_FIRMWARE.parsed.revision|join:"." }}
|
- {{ node.firmware.version }} (IDF {{ node.firmware.idf_version }})
+ {{ node.last_messages.CONFIG_FIRMWARE.parsed.version }}
+ (IDF {{ node.last_messages.CONFIG_FIRMWARE.parsed.idf_version }})
+ |
+
+ {% blocktrans trimmed with timesince=node.last_msg|timesince %}
+ {{ timesince }} ago
+ {% endblocktrans %}
|
- {{ node.last_msg }} |
{{ node.parent }} |
{{ node.route }} |
+ {{ node.last_messages.CONFIG_FIRMWARE.data }} |
{% endfor %}
diff --git a/src/c3nav/control/views/mesh.py b/src/c3nav/control/views/mesh.py
index 7640b19a..04b8891d 100644
--- a/src/c3nav/control/views/mesh.py
+++ b/src/c3nav/control/views/mesh.py
@@ -1,6 +1,5 @@
from django.db.models import Max
from django.views.generic import ListView
-from django.views.generic.edit import FormMixin
from c3nav.control.forms import MeshMessageFilerForm
from c3nav.control.views.base import ControlPanelMixin
@@ -14,7 +13,7 @@ class MeshNodeListView(ControlPanelMixin, ListView):
context_object_name = "nodes"
def get_queryset(self):
- return super().get_queryset().annotate(last_msg=Max('received_messages__datetime'))
+ return super().get_queryset().annotate(last_msg=Max('received_messages__datetime')).prefetch_last_messages()
class MeshMessageListView(ControlPanelMixin, ListView):
@@ -31,8 +30,8 @@ class MeshMessageListView(ControlPanelMixin, ListView):
if self.form.is_valid():
if self.form.cleaned_data['message_types']:
qs = qs.filter(message_type__in=self.form.cleaned_data['message_types'])
- if self.form.cleaned_data['nodes']:
- qs = qs.filter(node__in=self.form.cleaned_data['nodes'])
+ if self.form.cleaned_data['src_nodes']:
+ qs = qs.filter(src_node__in=self.form.cleaned_data['src_nodes'])
return qs
@@ -47,4 +46,3 @@ class MeshMessageListView(ControlPanelMixin, ListView):
'form_data': form_data.urlencode(),
})
return ctx
-
diff --git a/src/c3nav/mesh/consumers.py b/src/c3nav/mesh/consumers.py
index e5dda401..987f021c 100644
--- a/src/c3nav/mesh/consumers.py
+++ b/src/c3nav/mesh/consumers.py
@@ -1,31 +1,33 @@
import traceback
-from channels.db import database_sync_to_async
-from channels.generic.websocket import AsyncWebsocketConsumer
+from asgiref.sync import async_to_sync
+from channels.generic.websocket import WebsocketConsumer
from c3nav.mesh import messages
-from c3nav.mesh.models import MeshNode, NodeMessage, Firmware
+from c3nav.mesh.models import MeshNode, NodeMessage
-class MeshConsumer(AsyncWebsocketConsumer):
- async def connect(self):
+# noinspection PyAttributeOutsideInit
+class MeshConsumer(WebsocketConsumer):
+ def connect(self):
print('connected!')
# todo: auth
- self.node = None
- await self.accept()
+ self.uplink_node = None
+ self.dst_nodes = set()
+ self.accept()
- async def disconnect(self, close_code):
+ def disconnect(self, close_code):
print('disconnected!')
- if self.node is not None:
- await self.remove_route(self.node.address)
- await self.channel_layer.group_discard('route_%s' % self.node.address.replace(':', ''), self.channel_name)
- await self.channel_layer.group_discard('route_broadcast', self.channel_name)
+ if self.uplink_node is not None:
+ self.remove_route(self.uplink_node)
+ self.channel_layer.group_discard('route_%s' % self.node.address.replace(':', ''), self.channel_name)
+ self.channel_layer.group_discard('route_broadcast', self.channel_name)
- async def send_msg(self, msg):
+ def send_msg(self, msg):
print('Sending message:', msg)
- await self.send(bytes_data=msg.encode())
+ self.send(bytes_data=msg.encode())
- async def receive(self, text_data=None, bytes_data=None):
+ def receive(self, text_data=None, bytes_data=None):
if bytes_data is None:
return
try:
@@ -41,77 +43,81 @@ class MeshConsumer(AsyncWebsocketConsumer):
print('Received message:', msg)
+ src_node, created = MeshNode.objects.get_or_create(address=msg.src)
+
if isinstance(msg, messages.MeshSigninMessage):
- self.node, created = await self.get_node(msg.src)
- if created:
- print('New node signing in!')
- print(self.node)
- await self.log_received_message(msg)
- await self.send_msg(messages.MeshLayerAnnounceMessage(
+ self.uplink_node = src_node
+ # log message, since we will not log it further down
+ self.log_received_message(src_node, msg)
+
+ # inform signed in uplink node about its layer
+ self.send_msg(messages.MeshLayerAnnounceMessage(
src=messages.ROOT_ADDRESS,
dst=msg.src,
layer=messages.NO_LAYER
))
- await self.send_msg(messages.ConfigDumpMessage(
- src=messages.ROOT_ADDRESS,
- dst=msg.src,
- ))
- await self.channel_layer.group_add('route_%s' % self.node.address.replace(':', ''), self.channel_name)
- await self.channel_layer.group_add('route_broadcast', self.channel_name)
- await self.set_parent_of_nodes(None, (self.node.address, ))
- await self.add_route_to_nodes(self.node.address, (self.node.address,))
+
+ # add signed in uplink node to broadcast route
+ async_to_sync(self.channel_layer.group_add)('mesh_broadcast', self.channel_name)
+
+ # add this node as a destination that this uplink handles (duh)
+ self.add_dst_nodes((src_node.address, ))
+
return
- if self.node is None:
+ if self.uplink_node is None:
print('Expected sign-in message, but got a different one!')
- await self.close()
+ self.close()
return
- await self.log_received_message(msg)
+ self.log_received_message(src_node, msg)
- if isinstance(msg, messages.ConfigFirmwareMessage):
- await self._handle_config_firmware_msg(msg)
- return
+ def uplink_change(self, data):
+ # message handler: if we are not the given uplink, leave this group
+ if data["uplink"] != self.uplink_node.address:
+ group = self.group_name_for_node(data["address"])
+ print('leaving uplink group...')
+ async_to_sync(self.channel_layer.group_discard)(group, self.channel_name)
- @database_sync_to_async
- def _handle_config_firmware_msg(self, msg):
- self.firmware, created = Firmware.objects.get_or_create(**msg.to_model_data())
- self.node.firmware = self.firmware
- self.node.save()
-
- @database_sync_to_async
- def get_node(self, address):
- return MeshNode.objects.get_or_create(address=address)
-
- @database_sync_to_async
- def log_received_message(self, msg: messages.Message):
+ def log_received_message(self, src_node: MeshNode, msg: messages.Message):
NodeMessage.objects.create(
- node=self.node,
+ uplink_node=self.uplink_node,
+ src_node=src_node,
message_type=msg.msg_id,
data=msg.tojson()
)
- @database_sync_to_async
- def create_nodes(self, addresses):
- MeshNode.objects.bulk_create(MeshNode(address=address) for address in addresses)
+ def add_dst_nodes(self, addresses):
+ # add ourselves to this one
+ for address in addresses:
+ # create group name for this address
+ group = self.group_name_for_node(address)
- @database_sync_to_async
- def set_parent_of_nodes(self, parent_address, node_addresses):
- MeshNode.objects.filter(address__in=node_addresses).update(parent_node_id=parent_address)
+ # if we aren't handling this address yet, join the group
+ if address not in self.dst_nodes:
+ async_to_sync(self.channel_layer.group_add)(group, self.channel_name)
+ self.dst_nodes.add(address)
- @database_sync_to_async
- def add_route_to_nodes(self, route_address, node_addresses):
- MeshNode.objects.filter(address__in=node_addresses).update(route_id=route_address)
+ # tell other consumers to leave the group
+ async_to_sync(self.channel_layer.group_send)(group, {
+ "type": "uplink_change",
+ "node": address,
+ "uplink": self.uplink_node.address
+ })
+
+ # tell the node to dump its current information
+ self.send_msg(
+ messages.ConfigDumpMessage(
+ src=messages.ROOT_ADDRESS,
+ dst=address,
+ )
+ )
+
+ # add the stuff to the db as well
+ MeshNode.objects.filter(address__in=addresses).update(route_id=self.uplink_node.address)
+
+ def group_name_for_node(self, address):
+ return 'mesh_%s' % address.replace(':', '-')
- @database_sync_to_async
def remove_route(self, route_address):
MeshNode.objects.filter(route_id=route_address).update(route_id=None)
-
- @database_sync_to_async
- def remove_route_to_nodes(self, route_address, node_addresses):
- MeshNode.objects.filter(address__in=node_addresses, route_id=route_address).update(route_id=None)
-
- @database_sync_to_async
- def set_node_firmware(self, firmware):
- self.node.firmware = firmware
- self.node.save()
\ No newline at end of file
diff --git a/src/c3nav/mesh/messages.py b/src/c3nav/mesh/messages.py
index 0953d5e2..11126fdb 100644
--- a/src/c3nav/mesh/messages.py
+++ b/src/c3nav/mesh/messages.py
@@ -30,6 +30,12 @@ class MessageType(IntEnum):
M = TypeVar('M', bound='Message')
+@unique
+class ChipType(IntEnum):
+ ESP32_S2 = 2
+ ESP32_C3 = 5
+
+
@dataclass
class Message:
dst: str = field(metadata={'format': MacAddressFormat()})
@@ -134,6 +140,9 @@ class ConfigFirmwareMessage(Message, msg_id=MessageType.CONFIG_FIRMWARE):
'sha256_hash': self.app_elf_sha256,
}
+ def get_chip_display(self):
+ return ChipType(self.chip).name.replace('_', '-')
+
@dataclass
class ConfigPositionMessage(Message, msg_id=MessageType.CONFIG_POSITION):
diff --git a/src/c3nav/mesh/migrations/0004_relay_vs_src_node_and_remove_firmware.py b/src/c3nav/mesh/migrations/0004_relay_vs_src_node_and_remove_firmware.py
new file mode 100644
index 00000000..61c1dafc
--- /dev/null
+++ b/src/c3nav/mesh/migrations/0004_relay_vs_src_node_and_remove_firmware.py
@@ -0,0 +1,70 @@
+# Generated by Django 4.2.1 on 2023-10-03 13:42
+
+from django.db import migrations, models
+import django.db.models.deletion
+
+
+def forwards_func(apps, schema_editor):
+ NodeMessage = apps.get_model("mesh", "NodeMessage")
+ MeshNode = apps.get_model("mesh", "MeshNode")
+
+ NodeMessage.objects.filter(uplink_node=None).delete()
+
+ nodes = {node.address: node for node in MeshNode.objects.all()}
+
+ for msg in NodeMessage.objects.all():
+ if msg.data["src"] not in nodes:
+ nodes[msg.data["src"]] = MeshNode.objects.create(address=msg.data["src"])
+ msg.src_node = nodes[msg.data["src"]]
+ msg.save()
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('mesh', '0003_meshnode_name'),
+ ]
+
+ operations = [
+ migrations.RenameField(
+ model_name='nodemessage',
+ old_name='node',
+ new_name='uplink_node',
+ ),
+ migrations.RemoveField(
+ model_name='meshnode',
+ name='firmware',
+ ),
+ migrations.RemoveField(
+ model_name='meshnode',
+ name='parent_node',
+ ),
+
+ migrations.AddField(
+ model_name='nodemessage',
+ name='src_node',
+ field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.PROTECT,
+ related_name='received_messages', to='mesh.meshnode', verbose_name='node'),
+ ),
+ migrations.AlterField(
+ model_name='nodemessage',
+ name='uplink_node',
+ field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.PROTECT,
+ related_name='relayed_messages', to='mesh.meshnode', verbose_name='uplink node'),
+ ),
+
+ migrations.RunPython(forwards_func, migrations.RunPython.noop),
+
+ migrations.AlterField(
+ model_name='nodemessage',
+ name='src_node',
+ field=models.ForeignKey(on_delete=django.db.models.deletion.PROTECT, related_name='received_messages',
+ to='mesh.meshnode', verbose_name='node'),
+ ),
+ migrations.AlterField(
+ model_name='nodemessage',
+ name='uplink_node',
+ field=models.ForeignKey(on_delete=django.db.models.deletion.PROTECT, related_name='relayed_messages',
+ to='mesh.meshnode', verbose_name='uplink node'),
+ ),
+ ]
diff --git a/src/c3nav/mesh/models.py b/src/c3nav/mesh/models.py
index 5f57ffa1..5365edcf 100644
--- a/src/c3nav/mesh/models.py
+++ b/src/c3nav/mesh/models.py
@@ -1,35 +1,98 @@
-from django.db import models
+from collections import UserDict
+from functools import cached_property
+
+from django.db import models, NotSupportedError
from django.utils.translation import gettext_lazy as _
-from c3nav.mesh.messages import MessageType
+from c3nav.mesh.messages import MessageType, ChipType, Message as MeshMessage
-class ChipID(models.IntegerChoices):
- ESP32S2 = 2, 'ESP32-S2'
- ESP32C3 = 5, 'ESP32-C3'
+class MeshNodeQuerySet(models.QuerySet):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._prefetch_last_messages = set()
+ self._prefetch_last_messages_done = False
+
+ def _clone(self):
+ clone = super()._clone()
+ clone._prefetch_last_messages = self._prefetch_last_messages
+ return clone
+
+ def prefetch_last_messages(self, *types: MessageType):
+ clone = self._chain()
+ clone._prefetch_last_messages |= (
+ set(types) if types else set(msgtype.value for msgtype in MessageType)
+ )
+ return clone
+
+ def _fetch_all(self):
+ super()._fetch_all()
+ if self._prefetch_last_messages and not self._prefetch_last_messages_done:
+ nodes = {node.pk: node for node in self._result_cache}
+ try:
+ for message in NodeMessage.objects.order_by('-datetime', '-pk').filter(
+ message_type__in=self._prefetch_last_messages,
+ node__in=nodes.keys(),
+ ).distinct('message_type', 'node'):
+ nodes[message.node].last_messages[message.message_type] = message
+ except NotSupportedError:
+ pass
+ print(tuple(nodes.values())[0].last_messages[MessageType.MESH_SIGNIN])
+
+
+class LastMessagesByTypeLookup(UserDict):
+ def __init__(self, node):
+ super().__init__()
+ self.node = node
+
+ def _get_key(self, item):
+ if isinstance(item, MessageType):
+ return item
+ if isinstance(item, str):
+ try:
+ return getattr(MessageType, item)
+ except AttributeError:
+ pass
+ return MessageType(item)
+
+ def __getitem__(self, key):
+ key = self._get_key(key)
+ try:
+ return self.data[key]
+ except KeyError:
+ pass
+ msg = self.node.received_messages.filter(message_type=key).order_by('-datetime', '-pk').first()
+ self.data[key] = msg
+ return msg
+
+ def __setitem__(self, key, item):
+ self.data[self._get_key(key)] = item
class MeshNode(models.Model):
address = models.CharField(_('mac address'), max_length=17, primary_key=True)
name = models.CharField(_('name'), max_length=32, null=True, blank=True)
first_seen = models.DateTimeField(_('first seen'), auto_now_add=True)
- parent_node = models.ForeignKey('MeshNode', models.PROTECT, null=True,
- related_name='child_nodes', verbose_name=_('parent node'))
route = models.ForeignKey('MeshNode', models.PROTECT, null=True,
related_name='routed_nodes', verbose_name=_('route'))
- firmware = models.ForeignKey('Firmware', models.PROTECT, null=True,
- related_name='installed_on', verbose_name=_('firmware'))
+ objects = models.Manager.from_queryset(MeshNodeQuerySet)()
def __str__(self):
if self.name:
return '%s (%s)' % (self.address, self.name)
return self.address
+ @cached_property
+ def last_messages(self):
+ return LastMessagesByTypeLookup(self)
+
class NodeMessage(models.Model):
MESSAGE_TYPES = [(msgtype.value, msgtype.name) for msgtype in MessageType]
- node = models.ForeignKey('MeshNode', models.PROTECT, null=True,
- related_name='received_messages', verbose_name=_('node'))
+ src_node = models.ForeignKey('MeshNode', models.PROTECT,
+ related_name='received_messages', verbose_name=_('node'))
+ uplink_node = models.ForeignKey('MeshNode', models.PROTECT,
+ related_name='relayed_messages', verbose_name=_('uplink node'))
datetime = models.DateTimeField(_('datetime'), db_index=True, auto_now_add=True)
message_type = models.SmallIntegerField(_('message type'), db_index=True, choices=MESSAGE_TYPES)
data = models.JSONField(_('message data'))
@@ -37,9 +100,14 @@ class NodeMessage(models.Model):
def __str__(self):
return '(#%d) %s at %s' % (self.pk, self.get_message_type_display(), self.datetime)
+ @cached_property
+ def parsed(self):
+ return MeshMessage.fromjson(self.data)
+
class Firmware(models.Model):
- chip = models.SmallIntegerField(_('chip'), db_index=True, choices=ChipID.choices)
+ CHIPS = [(msgtype.value, msgtype.name.replace('_', '-')) for msgtype in ChipType]
+ chip = models.SmallIntegerField(_('chip'), db_index=True, choices=CHIPS)
project_name = models.CharField(_('project name'), max_length=32)
version = models.CharField(_('firmware version'), max_length=32)
idf_version = models.CharField(_('IDF version'), max_length=32)