more mesh communication implementation, lots of fixes and stuff

This commit is contained in:
Laura Klünder 2023-10-03 17:23:29 +02:00
parent df6efbc8d5
commit 4d3f54bbe8
8 changed files with 260 additions and 94 deletions

View file

@ -306,7 +306,7 @@ class MeshMessageFilerForm(Form):
required=False,
label=_('message types'),
)
nodes = ModelMultipleChoiceField(
src_nodes = ModelMultipleChoiceField(
queryset=MeshNode.objects.all(),
required=False,
label=_('nodes'),

View file

@ -10,7 +10,7 @@
{{ form.message_types }}
</div>
<div class="field">
{{ form.nodes }}
{{ form.src_nodes }}
</div>
<div class="field">
<button type="submit">Filer</button>
@ -26,11 +26,12 @@
<th>{% trans 'Node' %}</th>
<th>{% trans 'Type' %}</th>
<th>{% trans 'Data' %}</th>
<th>{% trans 'Uplink' %}</th>
</tr>
{% for msg in mesh_messages %}
<tr>
<td>{{ msg.datetime }}</td>
<td>{{ msg.node }}</td>
<td>{{ msg.src_node }}</td>
<td>{{ msg.get_message_type_display }}</td>
<td>
@ -42,6 +43,7 @@
{% endif %}
{% endfor %}
</td>
<td>{{ msg.uplink_node }}</td>
</tr>
{% endfor %}
</table>

View file

@ -6,8 +6,8 @@
{% block subcontent %}
<table>
<tr>
<th>{% trans 'Address' %}</th>
<th>{% trans 'Name' %}</th>
<th>{% trans 'Node' %}</th>
<th>{% trans 'Status' %}</th>
<th>{% trans 'Chip' %}</th>
<th>{% trans 'Firmware' %}</th>
<th>{% trans 'Last msg' %}</th>
@ -16,17 +16,30 @@
</tr>
{% for node in nodes %}
<tr>
<td>{{ node.address }}</td>
<td>{{ node.name }}</td>
<td>
{{ node.firmware.get_chip_display }}
{% if node.route %}
<span style="color: green;">{% trans "online" %}</span>
{% else %}
<span style="color: red;">{% trans "offline" %}</span>
{% endif %}
</td>
<td>{{ node }}</td>
<td>
{{ node.last_messages.CONFIG_FIRMWARE.parsed.get_chip_display }}
rev{{ node.last_messages.CONFIG_FIRMWARE.parsed.revision|join:"." }}
</td>
<td>
{{ node.firmware.version }} (IDF {{ node.firmware.idf_version }})
{{ node.last_messages.CONFIG_FIRMWARE.parsed.version }}
(IDF {{ node.last_messages.CONFIG_FIRMWARE.parsed.idf_version }})
</td>
<td>
{% blocktrans trimmed with timesince=node.last_msg|timesince %}
{{ timesince }} ago
{% endblocktrans %}
</td>
<td>{{ node.last_msg }}</td>
<td>{{ node.parent }}</td>
<td>{{ node.route }}</td>
<td>{{ node.last_messages.CONFIG_FIRMWARE.data }}</td>
</tr>
{% endfor %}
</table>

View file

@ -1,6 +1,5 @@
from django.db.models import Max
from django.views.generic import ListView
from django.views.generic.edit import FormMixin
from c3nav.control.forms import MeshMessageFilerForm
from c3nav.control.views.base import ControlPanelMixin
@ -14,7 +13,7 @@ class MeshNodeListView(ControlPanelMixin, ListView):
context_object_name = "nodes"
def get_queryset(self):
return super().get_queryset().annotate(last_msg=Max('received_messages__datetime'))
return super().get_queryset().annotate(last_msg=Max('received_messages__datetime')).prefetch_last_messages()
class MeshMessageListView(ControlPanelMixin, ListView):
@ -31,8 +30,8 @@ class MeshMessageListView(ControlPanelMixin, ListView):
if self.form.is_valid():
if self.form.cleaned_data['message_types']:
qs = qs.filter(message_type__in=self.form.cleaned_data['message_types'])
if self.form.cleaned_data['nodes']:
qs = qs.filter(node__in=self.form.cleaned_data['nodes'])
if self.form.cleaned_data['src_nodes']:
qs = qs.filter(src_node__in=self.form.cleaned_data['src_nodes'])
return qs
@ -47,4 +46,3 @@ class MeshMessageListView(ControlPanelMixin, ListView):
'form_data': form_data.urlencode(),
})
return ctx

View file

@ -1,31 +1,33 @@
import traceback
from channels.db import database_sync_to_async
from channels.generic.websocket import AsyncWebsocketConsumer
from asgiref.sync import async_to_sync
from channels.generic.websocket import WebsocketConsumer
from c3nav.mesh import messages
from c3nav.mesh.models import MeshNode, NodeMessage, Firmware
from c3nav.mesh.models import MeshNode, NodeMessage
class MeshConsumer(AsyncWebsocketConsumer):
async def connect(self):
# noinspection PyAttributeOutsideInit
class MeshConsumer(WebsocketConsumer):
def connect(self):
print('connected!')
# todo: auth
self.node = None
await self.accept()
self.uplink_node = None
self.dst_nodes = set()
self.accept()
async def disconnect(self, close_code):
def disconnect(self, close_code):
print('disconnected!')
if self.node is not None:
await self.remove_route(self.node.address)
await self.channel_layer.group_discard('route_%s' % self.node.address.replace(':', ''), self.channel_name)
await self.channel_layer.group_discard('route_broadcast', self.channel_name)
if self.uplink_node is not None:
self.remove_route(self.uplink_node)
self.channel_layer.group_discard('route_%s' % self.node.address.replace(':', ''), self.channel_name)
self.channel_layer.group_discard('route_broadcast', self.channel_name)
async def send_msg(self, msg):
def send_msg(self, msg):
print('Sending message:', msg)
await self.send(bytes_data=msg.encode())
self.send(bytes_data=msg.encode())
async def receive(self, text_data=None, bytes_data=None):
def receive(self, text_data=None, bytes_data=None):
if bytes_data is None:
return
try:
@ -41,77 +43,81 @@ class MeshConsumer(AsyncWebsocketConsumer):
print('Received message:', msg)
src_node, created = MeshNode.objects.get_or_create(address=msg.src)
if isinstance(msg, messages.MeshSigninMessage):
self.node, created = await self.get_node(msg.src)
if created:
print('New node signing in!')
print(self.node)
await self.log_received_message(msg)
await self.send_msg(messages.MeshLayerAnnounceMessage(
self.uplink_node = src_node
# log message, since we will not log it further down
self.log_received_message(src_node, msg)
# inform signed in uplink node about its layer
self.send_msg(messages.MeshLayerAnnounceMessage(
src=messages.ROOT_ADDRESS,
dst=msg.src,
layer=messages.NO_LAYER
))
await self.send_msg(messages.ConfigDumpMessage(
src=messages.ROOT_ADDRESS,
dst=msg.src,
))
await self.channel_layer.group_add('route_%s' % self.node.address.replace(':', ''), self.channel_name)
await self.channel_layer.group_add('route_broadcast', self.channel_name)
await self.set_parent_of_nodes(None, (self.node.address, ))
await self.add_route_to_nodes(self.node.address, (self.node.address,))
# add signed in uplink node to broadcast route
async_to_sync(self.channel_layer.group_add)('mesh_broadcast', self.channel_name)
# add this node as a destination that this uplink handles (duh)
self.add_dst_nodes((src_node.address, ))
return
if self.node is None:
if self.uplink_node is None:
print('Expected sign-in message, but got a different one!')
await self.close()
self.close()
return
await self.log_received_message(msg)
self.log_received_message(src_node, msg)
if isinstance(msg, messages.ConfigFirmwareMessage):
await self._handle_config_firmware_msg(msg)
return
def uplink_change(self, data):
# message handler: if we are not the given uplink, leave this group
if data["uplink"] != self.uplink_node.address:
group = self.group_name_for_node(data["address"])
print('leaving uplink group...')
async_to_sync(self.channel_layer.group_discard)(group, self.channel_name)
@database_sync_to_async
def _handle_config_firmware_msg(self, msg):
self.firmware, created = Firmware.objects.get_or_create(**msg.to_model_data())
self.node.firmware = self.firmware
self.node.save()
@database_sync_to_async
def get_node(self, address):
return MeshNode.objects.get_or_create(address=address)
@database_sync_to_async
def log_received_message(self, msg: messages.Message):
def log_received_message(self, src_node: MeshNode, msg: messages.Message):
NodeMessage.objects.create(
node=self.node,
uplink_node=self.uplink_node,
src_node=src_node,
message_type=msg.msg_id,
data=msg.tojson()
)
@database_sync_to_async
def create_nodes(self, addresses):
MeshNode.objects.bulk_create(MeshNode(address=address) for address in addresses)
def add_dst_nodes(self, addresses):
# add ourselves to this one
for address in addresses:
# create group name for this address
group = self.group_name_for_node(address)
@database_sync_to_async
def set_parent_of_nodes(self, parent_address, node_addresses):
MeshNode.objects.filter(address__in=node_addresses).update(parent_node_id=parent_address)
# if we aren't handling this address yet, join the group
if address not in self.dst_nodes:
async_to_sync(self.channel_layer.group_add)(group, self.channel_name)
self.dst_nodes.add(address)
@database_sync_to_async
def add_route_to_nodes(self, route_address, node_addresses):
MeshNode.objects.filter(address__in=node_addresses).update(route_id=route_address)
# tell other consumers to leave the group
async_to_sync(self.channel_layer.group_send)(group, {
"type": "uplink_change",
"node": address,
"uplink": self.uplink_node.address
})
# tell the node to dump its current information
self.send_msg(
messages.ConfigDumpMessage(
src=messages.ROOT_ADDRESS,
dst=address,
)
)
# add the stuff to the db as well
MeshNode.objects.filter(address__in=addresses).update(route_id=self.uplink_node.address)
def group_name_for_node(self, address):
return 'mesh_%s' % address.replace(':', '-')
@database_sync_to_async
def remove_route(self, route_address):
MeshNode.objects.filter(route_id=route_address).update(route_id=None)
@database_sync_to_async
def remove_route_to_nodes(self, route_address, node_addresses):
MeshNode.objects.filter(address__in=node_addresses, route_id=route_address).update(route_id=None)
@database_sync_to_async
def set_node_firmware(self, firmware):
self.node.firmware = firmware
self.node.save()

View file

@ -30,6 +30,12 @@ class MessageType(IntEnum):
M = TypeVar('M', bound='Message')
@unique
class ChipType(IntEnum):
ESP32_S2 = 2
ESP32_C3 = 5
@dataclass
class Message:
dst: str = field(metadata={'format': MacAddressFormat()})
@ -134,6 +140,9 @@ class ConfigFirmwareMessage(Message, msg_id=MessageType.CONFIG_FIRMWARE):
'sha256_hash': self.app_elf_sha256,
}
def get_chip_display(self):
return ChipType(self.chip).name.replace('_', '-')
@dataclass
class ConfigPositionMessage(Message, msg_id=MessageType.CONFIG_POSITION):

View file

@ -0,0 +1,70 @@
# Generated by Django 4.2.1 on 2023-10-03 13:42
from django.db import migrations, models
import django.db.models.deletion
def forwards_func(apps, schema_editor):
NodeMessage = apps.get_model("mesh", "NodeMessage")
MeshNode = apps.get_model("mesh", "MeshNode")
NodeMessage.objects.filter(uplink_node=None).delete()
nodes = {node.address: node for node in MeshNode.objects.all()}
for msg in NodeMessage.objects.all():
if msg.data["src"] not in nodes:
nodes[msg.data["src"]] = MeshNode.objects.create(address=msg.data["src"])
msg.src_node = nodes[msg.data["src"]]
msg.save()
class Migration(migrations.Migration):
dependencies = [
('mesh', '0003_meshnode_name'),
]
operations = [
migrations.RenameField(
model_name='nodemessage',
old_name='node',
new_name='uplink_node',
),
migrations.RemoveField(
model_name='meshnode',
name='firmware',
),
migrations.RemoveField(
model_name='meshnode',
name='parent_node',
),
migrations.AddField(
model_name='nodemessage',
name='src_node',
field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.PROTECT,
related_name='received_messages', to='mesh.meshnode', verbose_name='node'),
),
migrations.AlterField(
model_name='nodemessage',
name='uplink_node',
field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.PROTECT,
related_name='relayed_messages', to='mesh.meshnode', verbose_name='uplink node'),
),
migrations.RunPython(forwards_func, migrations.RunPython.noop),
migrations.AlterField(
model_name='nodemessage',
name='src_node',
field=models.ForeignKey(on_delete=django.db.models.deletion.PROTECT, related_name='received_messages',
to='mesh.meshnode', verbose_name='node'),
),
migrations.AlterField(
model_name='nodemessage',
name='uplink_node',
field=models.ForeignKey(on_delete=django.db.models.deletion.PROTECT, related_name='relayed_messages',
to='mesh.meshnode', verbose_name='uplink node'),
),
]

View file

@ -1,35 +1,98 @@
from django.db import models
from collections import UserDict
from functools import cached_property
from django.db import models, NotSupportedError
from django.utils.translation import gettext_lazy as _
from c3nav.mesh.messages import MessageType
from c3nav.mesh.messages import MessageType, ChipType, Message as MeshMessage
class ChipID(models.IntegerChoices):
ESP32S2 = 2, 'ESP32-S2'
ESP32C3 = 5, 'ESP32-C3'
class MeshNodeQuerySet(models.QuerySet):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._prefetch_last_messages = set()
self._prefetch_last_messages_done = False
def _clone(self):
clone = super()._clone()
clone._prefetch_last_messages = self._prefetch_last_messages
return clone
def prefetch_last_messages(self, *types: MessageType):
clone = self._chain()
clone._prefetch_last_messages |= (
set(types) if types else set(msgtype.value for msgtype in MessageType)
)
return clone
def _fetch_all(self):
super()._fetch_all()
if self._prefetch_last_messages and not self._prefetch_last_messages_done:
nodes = {node.pk: node for node in self._result_cache}
try:
for message in NodeMessage.objects.order_by('-datetime', '-pk').filter(
message_type__in=self._prefetch_last_messages,
node__in=nodes.keys(),
).distinct('message_type', 'node'):
nodes[message.node].last_messages[message.message_type] = message
except NotSupportedError:
pass
print(tuple(nodes.values())[0].last_messages[MessageType.MESH_SIGNIN])
class LastMessagesByTypeLookup(UserDict):
def __init__(self, node):
super().__init__()
self.node = node
def _get_key(self, item):
if isinstance(item, MessageType):
return item
if isinstance(item, str):
try:
return getattr(MessageType, item)
except AttributeError:
pass
return MessageType(item)
def __getitem__(self, key):
key = self._get_key(key)
try:
return self.data[key]
except KeyError:
pass
msg = self.node.received_messages.filter(message_type=key).order_by('-datetime', '-pk').first()
self.data[key] = msg
return msg
def __setitem__(self, key, item):
self.data[self._get_key(key)] = item
class MeshNode(models.Model):
address = models.CharField(_('mac address'), max_length=17, primary_key=True)
name = models.CharField(_('name'), max_length=32, null=True, blank=True)
first_seen = models.DateTimeField(_('first seen'), auto_now_add=True)
parent_node = models.ForeignKey('MeshNode', models.PROTECT, null=True,
related_name='child_nodes', verbose_name=_('parent node'))
route = models.ForeignKey('MeshNode', models.PROTECT, null=True,
related_name='routed_nodes', verbose_name=_('route'))
firmware = models.ForeignKey('Firmware', models.PROTECT, null=True,
related_name='installed_on', verbose_name=_('firmware'))
objects = models.Manager.from_queryset(MeshNodeQuerySet)()
def __str__(self):
if self.name:
return '%s (%s)' % (self.address, self.name)
return self.address
@cached_property
def last_messages(self):
return LastMessagesByTypeLookup(self)
class NodeMessage(models.Model):
MESSAGE_TYPES = [(msgtype.value, msgtype.name) for msgtype in MessageType]
node = models.ForeignKey('MeshNode', models.PROTECT, null=True,
related_name='received_messages', verbose_name=_('node'))
src_node = models.ForeignKey('MeshNode', models.PROTECT,
related_name='received_messages', verbose_name=_('node'))
uplink_node = models.ForeignKey('MeshNode', models.PROTECT,
related_name='relayed_messages', verbose_name=_('uplink node'))
datetime = models.DateTimeField(_('datetime'), db_index=True, auto_now_add=True)
message_type = models.SmallIntegerField(_('message type'), db_index=True, choices=MESSAGE_TYPES)
data = models.JSONField(_('message data'))
@ -37,9 +100,14 @@ class NodeMessage(models.Model):
def __str__(self):
return '(#%d) %s at %s' % (self.pk, self.get_message_type_display(), self.datetime)
@cached_property
def parsed(self):
return MeshMessage.fromjson(self.data)
class Firmware(models.Model):
chip = models.SmallIntegerField(_('chip'), db_index=True, choices=ChipID.choices)
CHIPS = [(msgtype.value, msgtype.name.replace('_', '-')) for msgtype in ChipType]
chip = models.SmallIntegerField(_('chip'), db_index=True, choices=CHIPS)
project_name = models.CharField(_('project name'), max_length=32)
version = models.CharField(_('firmware version'), max_length=32)
idf_version = models.CharField(_('IDF version'), max_length=32)