more mesh communication implementation, lots of fixes and stuff

This commit is contained in:
Laura Klünder 2023-10-03 17:23:29 +02:00
parent df6efbc8d5
commit 4d3f54bbe8
8 changed files with 260 additions and 94 deletions

View file

@ -306,7 +306,7 @@ class MeshMessageFilerForm(Form):
required=False, required=False,
label=_('message types'), label=_('message types'),
) )
nodes = ModelMultipleChoiceField( src_nodes = ModelMultipleChoiceField(
queryset=MeshNode.objects.all(), queryset=MeshNode.objects.all(),
required=False, required=False,
label=_('nodes'), label=_('nodes'),

View file

@ -10,7 +10,7 @@
{{ form.message_types }} {{ form.message_types }}
</div> </div>
<div class="field"> <div class="field">
{{ form.nodes }} {{ form.src_nodes }}
</div> </div>
<div class="field"> <div class="field">
<button type="submit">Filer</button> <button type="submit">Filer</button>
@ -26,11 +26,12 @@
<th>{% trans 'Node' %}</th> <th>{% trans 'Node' %}</th>
<th>{% trans 'Type' %}</th> <th>{% trans 'Type' %}</th>
<th>{% trans 'Data' %}</th> <th>{% trans 'Data' %}</th>
<th>{% trans 'Uplink' %}</th>
</tr> </tr>
{% for msg in mesh_messages %} {% for msg in mesh_messages %}
<tr> <tr>
<td>{{ msg.datetime }}</td> <td>{{ msg.datetime }}</td>
<td>{{ msg.node }}</td> <td>{{ msg.src_node }}</td>
<td>{{ msg.get_message_type_display }}</td> <td>{{ msg.get_message_type_display }}</td>
<td> <td>
@ -42,6 +43,7 @@
{% endif %} {% endif %}
{% endfor %} {% endfor %}
</td> </td>
<td>{{ msg.uplink_node }}</td>
</tr> </tr>
{% endfor %} {% endfor %}
</table> </table>

View file

@ -6,8 +6,8 @@
{% block subcontent %} {% block subcontent %}
<table> <table>
<tr> <tr>
<th>{% trans 'Address' %}</th> <th>{% trans 'Node' %}</th>
<th>{% trans 'Name' %}</th> <th>{% trans 'Status' %}</th>
<th>{% trans 'Chip' %}</th> <th>{% trans 'Chip' %}</th>
<th>{% trans 'Firmware' %}</th> <th>{% trans 'Firmware' %}</th>
<th>{% trans 'Last msg' %}</th> <th>{% trans 'Last msg' %}</th>
@ -16,17 +16,30 @@
</tr> </tr>
{% for node in nodes %} {% for node in nodes %}
<tr> <tr>
<td>{{ node.address }}</td>
<td>{{ node.name }}</td>
<td> <td>
{{ node.firmware.get_chip_display }} {% if node.route %}
<span style="color: green;">{% trans "online" %}</span>
{% else %}
<span style="color: red;">{% trans "offline" %}</span>
{% endif %}
</td>
<td>{{ node }}</td>
<td>
{{ node.last_messages.CONFIG_FIRMWARE.parsed.get_chip_display }}
rev{{ node.last_messages.CONFIG_FIRMWARE.parsed.revision|join:"." }}
</td> </td>
<td> <td>
{{ node.firmware.version }} (IDF {{ node.firmware.idf_version }}) {{ node.last_messages.CONFIG_FIRMWARE.parsed.version }}
(IDF {{ node.last_messages.CONFIG_FIRMWARE.parsed.idf_version }})
</td>
<td>
{% blocktrans trimmed with timesince=node.last_msg|timesince %}
{{ timesince }} ago
{% endblocktrans %}
</td> </td>
<td>{{ node.last_msg }}</td>
<td>{{ node.parent }}</td> <td>{{ node.parent }}</td>
<td>{{ node.route }}</td> <td>{{ node.route }}</td>
<td>{{ node.last_messages.CONFIG_FIRMWARE.data }}</td>
</tr> </tr>
{% endfor %} {% endfor %}
</table> </table>

View file

@ -1,6 +1,5 @@
from django.db.models import Max from django.db.models import Max
from django.views.generic import ListView from django.views.generic import ListView
from django.views.generic.edit import FormMixin
from c3nav.control.forms import MeshMessageFilerForm from c3nav.control.forms import MeshMessageFilerForm
from c3nav.control.views.base import ControlPanelMixin from c3nav.control.views.base import ControlPanelMixin
@ -14,7 +13,7 @@ class MeshNodeListView(ControlPanelMixin, ListView):
context_object_name = "nodes" context_object_name = "nodes"
def get_queryset(self): def get_queryset(self):
return super().get_queryset().annotate(last_msg=Max('received_messages__datetime')) return super().get_queryset().annotate(last_msg=Max('received_messages__datetime')).prefetch_last_messages()
class MeshMessageListView(ControlPanelMixin, ListView): class MeshMessageListView(ControlPanelMixin, ListView):
@ -31,8 +30,8 @@ class MeshMessageListView(ControlPanelMixin, ListView):
if self.form.is_valid(): if self.form.is_valid():
if self.form.cleaned_data['message_types']: if self.form.cleaned_data['message_types']:
qs = qs.filter(message_type__in=self.form.cleaned_data['message_types']) qs = qs.filter(message_type__in=self.form.cleaned_data['message_types'])
if self.form.cleaned_data['nodes']: if self.form.cleaned_data['src_nodes']:
qs = qs.filter(node__in=self.form.cleaned_data['nodes']) qs = qs.filter(src_node__in=self.form.cleaned_data['src_nodes'])
return qs return qs
@ -47,4 +46,3 @@ class MeshMessageListView(ControlPanelMixin, ListView):
'form_data': form_data.urlencode(), 'form_data': form_data.urlencode(),
}) })
return ctx return ctx

View file

@ -1,31 +1,33 @@
import traceback import traceback
from channels.db import database_sync_to_async from asgiref.sync import async_to_sync
from channels.generic.websocket import AsyncWebsocketConsumer from channels.generic.websocket import WebsocketConsumer
from c3nav.mesh import messages from c3nav.mesh import messages
from c3nav.mesh.models import MeshNode, NodeMessage, Firmware from c3nav.mesh.models import MeshNode, NodeMessage
class MeshConsumer(AsyncWebsocketConsumer): # noinspection PyAttributeOutsideInit
async def connect(self): class MeshConsumer(WebsocketConsumer):
def connect(self):
print('connected!') print('connected!')
# todo: auth # todo: auth
self.node = None self.uplink_node = None
await self.accept() self.dst_nodes = set()
self.accept()
async def disconnect(self, close_code): def disconnect(self, close_code):
print('disconnected!') print('disconnected!')
if self.node is not None: if self.uplink_node is not None:
await self.remove_route(self.node.address) self.remove_route(self.uplink_node)
await self.channel_layer.group_discard('route_%s' % self.node.address.replace(':', ''), self.channel_name) self.channel_layer.group_discard('route_%s' % self.node.address.replace(':', ''), self.channel_name)
await self.channel_layer.group_discard('route_broadcast', self.channel_name) self.channel_layer.group_discard('route_broadcast', self.channel_name)
async def send_msg(self, msg): def send_msg(self, msg):
print('Sending message:', msg) print('Sending message:', msg)
await self.send(bytes_data=msg.encode()) self.send(bytes_data=msg.encode())
async def receive(self, text_data=None, bytes_data=None): def receive(self, text_data=None, bytes_data=None):
if bytes_data is None: if bytes_data is None:
return return
try: try:
@ -41,77 +43,81 @@ class MeshConsumer(AsyncWebsocketConsumer):
print('Received message:', msg) print('Received message:', msg)
src_node, created = MeshNode.objects.get_or_create(address=msg.src)
if isinstance(msg, messages.MeshSigninMessage): if isinstance(msg, messages.MeshSigninMessage):
self.node, created = await self.get_node(msg.src) self.uplink_node = src_node
if created: # log message, since we will not log it further down
print('New node signing in!') self.log_received_message(src_node, msg)
print(self.node)
await self.log_received_message(msg) # inform signed in uplink node about its layer
await self.send_msg(messages.MeshLayerAnnounceMessage( self.send_msg(messages.MeshLayerAnnounceMessage(
src=messages.ROOT_ADDRESS, src=messages.ROOT_ADDRESS,
dst=msg.src, dst=msg.src,
layer=messages.NO_LAYER layer=messages.NO_LAYER
)) ))
await self.send_msg(messages.ConfigDumpMessage(
src=messages.ROOT_ADDRESS, # add signed in uplink node to broadcast route
dst=msg.src, async_to_sync(self.channel_layer.group_add)('mesh_broadcast', self.channel_name)
))
await self.channel_layer.group_add('route_%s' % self.node.address.replace(':', ''), self.channel_name) # add this node as a destination that this uplink handles (duh)
await self.channel_layer.group_add('route_broadcast', self.channel_name) self.add_dst_nodes((src_node.address, ))
await self.set_parent_of_nodes(None, (self.node.address, ))
await self.add_route_to_nodes(self.node.address, (self.node.address,))
return return
if self.node is None: if self.uplink_node is None:
print('Expected sign-in message, but got a different one!') print('Expected sign-in message, but got a different one!')
await self.close() self.close()
return return
await self.log_received_message(msg) self.log_received_message(src_node, msg)
if isinstance(msg, messages.ConfigFirmwareMessage): def uplink_change(self, data):
await self._handle_config_firmware_msg(msg) # message handler: if we are not the given uplink, leave this group
return if data["uplink"] != self.uplink_node.address:
group = self.group_name_for_node(data["address"])
print('leaving uplink group...')
async_to_sync(self.channel_layer.group_discard)(group, self.channel_name)
@database_sync_to_async def log_received_message(self, src_node: MeshNode, msg: messages.Message):
def _handle_config_firmware_msg(self, msg):
self.firmware, created = Firmware.objects.get_or_create(**msg.to_model_data())
self.node.firmware = self.firmware
self.node.save()
@database_sync_to_async
def get_node(self, address):
return MeshNode.objects.get_or_create(address=address)
@database_sync_to_async
def log_received_message(self, msg: messages.Message):
NodeMessage.objects.create( NodeMessage.objects.create(
node=self.node, uplink_node=self.uplink_node,
src_node=src_node,
message_type=msg.msg_id, message_type=msg.msg_id,
data=msg.tojson() data=msg.tojson()
) )
@database_sync_to_async def add_dst_nodes(self, addresses):
def create_nodes(self, addresses): # add ourselves to this one
MeshNode.objects.bulk_create(MeshNode(address=address) for address in addresses) for address in addresses:
# create group name for this address
group = self.group_name_for_node(address)
@database_sync_to_async # if we aren't handling this address yet, join the group
def set_parent_of_nodes(self, parent_address, node_addresses): if address not in self.dst_nodes:
MeshNode.objects.filter(address__in=node_addresses).update(parent_node_id=parent_address) async_to_sync(self.channel_layer.group_add)(group, self.channel_name)
self.dst_nodes.add(address)
@database_sync_to_async # tell other consumers to leave the group
def add_route_to_nodes(self, route_address, node_addresses): async_to_sync(self.channel_layer.group_send)(group, {
MeshNode.objects.filter(address__in=node_addresses).update(route_id=route_address) "type": "uplink_change",
"node": address,
"uplink": self.uplink_node.address
})
# tell the node to dump its current information
self.send_msg(
messages.ConfigDumpMessage(
src=messages.ROOT_ADDRESS,
dst=address,
)
)
# add the stuff to the db as well
MeshNode.objects.filter(address__in=addresses).update(route_id=self.uplink_node.address)
def group_name_for_node(self, address):
return 'mesh_%s' % address.replace(':', '-')
@database_sync_to_async
def remove_route(self, route_address): def remove_route(self, route_address):
MeshNode.objects.filter(route_id=route_address).update(route_id=None) MeshNode.objects.filter(route_id=route_address).update(route_id=None)
@database_sync_to_async
def remove_route_to_nodes(self, route_address, node_addresses):
MeshNode.objects.filter(address__in=node_addresses, route_id=route_address).update(route_id=None)
@database_sync_to_async
def set_node_firmware(self, firmware):
self.node.firmware = firmware
self.node.save()

View file

@ -30,6 +30,12 @@ class MessageType(IntEnum):
M = TypeVar('M', bound='Message') M = TypeVar('M', bound='Message')
@unique
class ChipType(IntEnum):
ESP32_S2 = 2
ESP32_C3 = 5
@dataclass @dataclass
class Message: class Message:
dst: str = field(metadata={'format': MacAddressFormat()}) dst: str = field(metadata={'format': MacAddressFormat()})
@ -134,6 +140,9 @@ class ConfigFirmwareMessage(Message, msg_id=MessageType.CONFIG_FIRMWARE):
'sha256_hash': self.app_elf_sha256, 'sha256_hash': self.app_elf_sha256,
} }
def get_chip_display(self):
return ChipType(self.chip).name.replace('_', '-')
@dataclass @dataclass
class ConfigPositionMessage(Message, msg_id=MessageType.CONFIG_POSITION): class ConfigPositionMessage(Message, msg_id=MessageType.CONFIG_POSITION):

View file

@ -0,0 +1,70 @@
# Generated by Django 4.2.1 on 2023-10-03 13:42
from django.db import migrations, models
import django.db.models.deletion
def forwards_func(apps, schema_editor):
NodeMessage = apps.get_model("mesh", "NodeMessage")
MeshNode = apps.get_model("mesh", "MeshNode")
NodeMessage.objects.filter(uplink_node=None).delete()
nodes = {node.address: node for node in MeshNode.objects.all()}
for msg in NodeMessage.objects.all():
if msg.data["src"] not in nodes:
nodes[msg.data["src"]] = MeshNode.objects.create(address=msg.data["src"])
msg.src_node = nodes[msg.data["src"]]
msg.save()
class Migration(migrations.Migration):
dependencies = [
('mesh', '0003_meshnode_name'),
]
operations = [
migrations.RenameField(
model_name='nodemessage',
old_name='node',
new_name='uplink_node',
),
migrations.RemoveField(
model_name='meshnode',
name='firmware',
),
migrations.RemoveField(
model_name='meshnode',
name='parent_node',
),
migrations.AddField(
model_name='nodemessage',
name='src_node',
field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.PROTECT,
related_name='received_messages', to='mesh.meshnode', verbose_name='node'),
),
migrations.AlterField(
model_name='nodemessage',
name='uplink_node',
field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.PROTECT,
related_name='relayed_messages', to='mesh.meshnode', verbose_name='uplink node'),
),
migrations.RunPython(forwards_func, migrations.RunPython.noop),
migrations.AlterField(
model_name='nodemessage',
name='src_node',
field=models.ForeignKey(on_delete=django.db.models.deletion.PROTECT, related_name='received_messages',
to='mesh.meshnode', verbose_name='node'),
),
migrations.AlterField(
model_name='nodemessage',
name='uplink_node',
field=models.ForeignKey(on_delete=django.db.models.deletion.PROTECT, related_name='relayed_messages',
to='mesh.meshnode', verbose_name='uplink node'),
),
]

View file

@ -1,35 +1,98 @@
from django.db import models from collections import UserDict
from functools import cached_property
from django.db import models, NotSupportedError
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from c3nav.mesh.messages import MessageType from c3nav.mesh.messages import MessageType, ChipType, Message as MeshMessage
class ChipID(models.IntegerChoices): class MeshNodeQuerySet(models.QuerySet):
ESP32S2 = 2, 'ESP32-S2' def __init__(self, *args, **kwargs):
ESP32C3 = 5, 'ESP32-C3' super().__init__(*args, **kwargs)
self._prefetch_last_messages = set()
self._prefetch_last_messages_done = False
def _clone(self):
clone = super()._clone()
clone._prefetch_last_messages = self._prefetch_last_messages
return clone
def prefetch_last_messages(self, *types: MessageType):
clone = self._chain()
clone._prefetch_last_messages |= (
set(types) if types else set(msgtype.value for msgtype in MessageType)
)
return clone
def _fetch_all(self):
super()._fetch_all()
if self._prefetch_last_messages and not self._prefetch_last_messages_done:
nodes = {node.pk: node for node in self._result_cache}
try:
for message in NodeMessage.objects.order_by('-datetime', '-pk').filter(
message_type__in=self._prefetch_last_messages,
node__in=nodes.keys(),
).distinct('message_type', 'node'):
nodes[message.node].last_messages[message.message_type] = message
except NotSupportedError:
pass
print(tuple(nodes.values())[0].last_messages[MessageType.MESH_SIGNIN])
class LastMessagesByTypeLookup(UserDict):
def __init__(self, node):
super().__init__()
self.node = node
def _get_key(self, item):
if isinstance(item, MessageType):
return item
if isinstance(item, str):
try:
return getattr(MessageType, item)
except AttributeError:
pass
return MessageType(item)
def __getitem__(self, key):
key = self._get_key(key)
try:
return self.data[key]
except KeyError:
pass
msg = self.node.received_messages.filter(message_type=key).order_by('-datetime', '-pk').first()
self.data[key] = msg
return msg
def __setitem__(self, key, item):
self.data[self._get_key(key)] = item
class MeshNode(models.Model): class MeshNode(models.Model):
address = models.CharField(_('mac address'), max_length=17, primary_key=True) address = models.CharField(_('mac address'), max_length=17, primary_key=True)
name = models.CharField(_('name'), max_length=32, null=True, blank=True) name = models.CharField(_('name'), max_length=32, null=True, blank=True)
first_seen = models.DateTimeField(_('first seen'), auto_now_add=True) first_seen = models.DateTimeField(_('first seen'), auto_now_add=True)
parent_node = models.ForeignKey('MeshNode', models.PROTECT, null=True,
related_name='child_nodes', verbose_name=_('parent node'))
route = models.ForeignKey('MeshNode', models.PROTECT, null=True, route = models.ForeignKey('MeshNode', models.PROTECT, null=True,
related_name='routed_nodes', verbose_name=_('route')) related_name='routed_nodes', verbose_name=_('route'))
firmware = models.ForeignKey('Firmware', models.PROTECT, null=True, objects = models.Manager.from_queryset(MeshNodeQuerySet)()
related_name='installed_on', verbose_name=_('firmware'))
def __str__(self): def __str__(self):
if self.name: if self.name:
return '%s (%s)' % (self.address, self.name) return '%s (%s)' % (self.address, self.name)
return self.address return self.address
@cached_property
def last_messages(self):
return LastMessagesByTypeLookup(self)
class NodeMessage(models.Model): class NodeMessage(models.Model):
MESSAGE_TYPES = [(msgtype.value, msgtype.name) for msgtype in MessageType] MESSAGE_TYPES = [(msgtype.value, msgtype.name) for msgtype in MessageType]
node = models.ForeignKey('MeshNode', models.PROTECT, null=True, src_node = models.ForeignKey('MeshNode', models.PROTECT,
related_name='received_messages', verbose_name=_('node')) related_name='received_messages', verbose_name=_('node'))
uplink_node = models.ForeignKey('MeshNode', models.PROTECT,
related_name='relayed_messages', verbose_name=_('uplink node'))
datetime = models.DateTimeField(_('datetime'), db_index=True, auto_now_add=True) datetime = models.DateTimeField(_('datetime'), db_index=True, auto_now_add=True)
message_type = models.SmallIntegerField(_('message type'), db_index=True, choices=MESSAGE_TYPES) message_type = models.SmallIntegerField(_('message type'), db_index=True, choices=MESSAGE_TYPES)
data = models.JSONField(_('message data')) data = models.JSONField(_('message data'))
@ -37,9 +100,14 @@ class NodeMessage(models.Model):
def __str__(self): def __str__(self):
return '(#%d) %s at %s' % (self.pk, self.get_message_type_display(), self.datetime) return '(#%d) %s at %s' % (self.pk, self.get_message_type_display(), self.datetime)
@cached_property
def parsed(self):
return MeshMessage.fromjson(self.data)
class Firmware(models.Model): class Firmware(models.Model):
chip = models.SmallIntegerField(_('chip'), db_index=True, choices=ChipID.choices) CHIPS = [(msgtype.value, msgtype.name.replace('_', '-')) for msgtype in ChipType]
chip = models.SmallIntegerField(_('chip'), db_index=True, choices=CHIPS)
project_name = models.CharField(_('project name'), max_length=32) project_name = models.CharField(_('project name'), max_length=32)
version = models.CharField(_('firmware version'), max_length=32) version = models.CharField(_('firmware version'), max_length=32)
idf_version = models.CharField(_('IDF version'), max_length=32) idf_version = models.CharField(_('IDF version'), max_length=32)