mesh routing now fully database-bound

This commit is contained in:
Laura Klünder 2023-11-07 16:35:46 +01:00
parent ee539f678a
commit 66810b20b0
6 changed files with 94 additions and 93 deletions

View file

@ -420,7 +420,7 @@ class StructType:
return data
@classmethod
def fromjson(cls, data: dict):
def fromjson(cls, data: dict) -> Self:
data = data.copy()
# todo: upgrade_json

View file

@ -1,6 +1,7 @@
import asyncio
import traceback
from asyncio import get_event_loop
from functools import cached_property
from asgiref.sync import async_to_sync
from channels.db import database_sync_to_async
@ -12,7 +13,7 @@ from c3nav.mesh import messages
from c3nav.mesh.messages import (MESH_BROADCAST_ADDRESS, MESH_NONE_ADDRESS, MESH_ROOT_ADDRESS, MeshMessage,
MeshMessageType)
from c3nav.mesh.models import MeshNode, MeshUplink, NodeMessage
from c3nav.mesh.utils import get_mesh_comm_group
from c3nav.mesh.utils import MESH_ALL_UPLINKS_GROUP, UPLINK_PING, get_mesh_uplink_group
class MeshConsumer(AsyncWebsocketConsumer):
@ -35,9 +36,7 @@ class MeshConsumer(AsyncWebsocketConsumer):
await self.log_text(self.uplink.node, "mesh websocket disconnected")
if self.uplink is not None:
# leave broadcast group
await self.channel_layer.group_discard(
get_mesh_comm_group(MESH_BROADCAST_ADDRESS), self.channel_name
)
await self.channel_layer.group_discard("mesh_comm_broadcast", self.channel_name)
# remove all other destinations
await self.remove_dst_nodes(self.dst_nodes)
@ -66,6 +65,10 @@ class MeshConsumer(AsyncWebsocketConsumer):
# "msg": msg.tojson(), # not doing this part for privacy reasons
})
@cached_property
def same_uplinks_group(self):
return 'mesh_uplink_%s' % self.uplink.node.address.replace(':', '-')
async def receive(self, text_data=None, bytes_data=None):
if bytes_data is None:
return
@ -92,7 +95,10 @@ class MeshConsumer(AsyncWebsocketConsumer):
await self.log_text(MESH_ROOT_ADDRESS, "adding ourselves to trace message before forwarding")
msg.trace.append(MESH_ROOT_ADDRESS)
msg.send(exclude_uplink_address=self.uplink.node.address)
result = await msg.send(exclude_uplink_address=self.uplink.node.address)
if not result:
print('message had no route')
# don't handle this message unless it's a broadcast message
if msg.dst != messages.MESH_BROADCAST_ADDRESS:
@ -110,7 +116,7 @@ class MeshConsumer(AsyncWebsocketConsumer):
await self.create_uplink_in_database(msg.src)
# inform other uplinks to shut down
await self.channel_layer.group_send(get_mesh_comm_group(msg.src), {
await self.channel_layer.group_send(get_mesh_uplink_group(msg.src), {
"type": "mesh.uplink_consumer",
"name": self.channel_name,
})
@ -126,9 +132,7 @@ class MeshConsumer(AsyncWebsocketConsumer):
))
# add signed in uplink node to broadcast group
await self.channel_layer.group_add(
get_mesh_comm_group(MESH_BROADCAST_ADDRESS), self.channel_name
)
await self.channel_layer.group_add(MESH_ALL_UPLINKS_GROUP, self.channel_name)
# add this node as a destination that this uplink handles (duh)
await self.add_dst_nodes(nodes=(src_node, ))
@ -151,27 +155,22 @@ class MeshConsumer(AsyncWebsocketConsumer):
if isinstance(msg, messages.MeshRouteRequestMessage):
if msg.address == MESH_ROOT_ADDRESS:
await self.log_text(MESH_ROOT_ADDRESS, "route request about us, start a trace")
messages.MeshRouteTraceMessage(
await self.send_msg(messages.MeshRouteTraceMessage(
src=MESH_ROOT_ADDRESS,
dst=msg.src,
request_id=msg.request_id,
trace=[MESH_ROOT_ADDRESS],
).send()
))
else:
# todo: find a way to send a "no route" message if there is no route
await self.log_text(MESH_ROOT_ADDRESS, "requesting route response responsible uplink")
await self.log_text(MESH_ROOT_ADDRESS, "route request about someone else, sending response")
self.open_requests.add(msg.request_id)
await self.channel_layer.group_send(get_mesh_comm_group(msg.address), {
"type": "mesh.send_route_response",
"request_id": msg.request_id,
"channel": self.channel_name,
"dst": msg.src,
})
await self.delayed_group_send(5, self.channel_name, {
"type": "mesh.no_route_response",
"request_id": msg.request_id,
"dst": msg.src,
})
uplink = database_sync_to_async(MeshNode.get_node_and_uplink)(msg.address)
await self.send_msg(messages.MeshRouteResponseMessage(
src=MESH_ROOT_ADDRESS,
dst=msg.src,
request_id=msg.request_id,
route=uplink.node_id if uplink else MESH_NONE_ADDRESS,
))
@database_sync_to_async
def create_uplink_in_database(self, address):
@ -193,7 +192,7 @@ class MeshConsumer(AsyncWebsocketConsumer):
async def ping_regularly(self):
while True:
await asyncio.sleep(5)
await asyncio.sleep(UPLINK_PING)
await MeshUplink.objects.filter(pk=self.uplink.pk).aupdate(last_ping=timezone.now())
async def delayed_group_send(self, delay: int, group: str, msg: dict):
@ -219,9 +218,10 @@ class MeshConsumer(AsyncWebsocketConsumer):
"""
message handler: if we are not the given uplink, leave this group
"""
if data["uplink"] != self.uplink.node.address:
if data["uplink"] != self.channel_name:
await self.log_text(data["address"], "node now served by new consumer")
await self.remove_dst_nodes((data["address"], ))
# going the short way cause the other consumer will already have done database stuff
self.dst_nodes.discard(data["address"])
async def mesh_send(self, data):
if self.uplink.node.address == data["exclude_uplink_address"]:
@ -236,35 +236,6 @@ class MeshConsumer(AsyncWebsocketConsumer):
return
await self.send_msg(MeshMessage.fromjson(data["msg"]), data["sender"])
async def mesh_send_route_response(self, data):
await self.log_text(self.uplink.node.address, "we're the uplink for this address, sending route response...")
messages.MeshRouteResponseMessage(
src=MESH_ROOT_ADDRESS,
dst=data["dst"],
request_id=data["request_id"],
route=self.uplink.node.address,
).send()
async_to_sync(self.channel_layer.send)(data["channel"], {
"type": "mesh.route_response_sent",
"request_id": data["request_id"],
})
async def mesh_route_response_sent(self, data):
self.open_requests.discard(data["request_id"])
async def mesh_no_route_response(self, data):
print('no route response check')
if data["request_id"] not in self.open_requests:
print('a route was sent')
return
print('sending no route')
messages.MeshRouteResponseMessage(
src=MESH_ROOT_ADDRESS,
dst=data["dst"],
request_id=data["request_id"],
route=MESH_NONE_ADDRESS,
).send()
"""
helper functions
"""
@ -331,26 +302,22 @@ class MeshConsumer(AsyncWebsocketConsumer):
def _add_destination(self, address):
with transaction.atomic():
node = MeshNode.objects.select_for_update().get(address=address)
# create group name for this address
group = get_mesh_comm_group(address)
# if we aren't handling this address yet, join the group
if address not in self.dst_nodes:
async_to_sync(self.channel_layer.group_add)(group, self.channel_name)
self.dst_nodes.add(address)
# tell other consumers to leave the group
async_to_sync(self.channel_layer.group_send)(group, {
"type": "mesh.dst_node_uplink",
"node": address,
"uplink": self.uplink.node.address
})
# update database
node.uplink = self.uplink,
node.last_signin = timezone.now()
node.save()
# tell other consumers that it's us now
async_to_sync(self.channel_layer.group_send)(MESH_ALL_UPLINKS_GROUP, {
"type": "mesh.dst_node_uplink",
"node": address,
"uplink": self.channel_name
})
# if we aren't handling this address yet, write it down
if address not in self.dst_nodes:
self.dst_nodes.add(address)
async def remove_dst_nodes(self, addresses):
for address in tuple(addresses):
await self.log_text(address, "destination removed")
@ -360,19 +327,18 @@ class MeshConsumer(AsyncWebsocketConsumer):
@database_sync_to_async
def _remove_destination(self, address):
with transaction.atomic():
node = MeshNode.objects.select_for_update().get(address=address)
try:
node = MeshNode.objects.select_for_update().get(address=address, uplink=self.uplink)
except MeshNode.DoesNotExist:
pass
else:
node.uplink = None
node.save()
# create group name for this address
group = get_mesh_comm_group(address)
# leave the group
# no longer serving this node
if address in self.dst_nodes:
async_to_sync(self.channel_layer.group_discard)(group, self.channel_name)
self.dst_nodes.discard(address)
node.uplink = None
node.save()
class MeshUIConsumer(AsyncJsonWebsocketConsumer):
def __init__(self):
@ -407,7 +373,7 @@ class MeshUIConsumer(AsyncJsonWebsocketConsumer):
self.msg_received_filter = {"request_id": msg_to_send["msg_data"]["request_id"]}
for recipient in msg_to_send["recipients"]:
MeshMessage.fromjson({
await MeshMessage.fromjson({
'dst': recipient,
**msg_to_send["msg_data"],
}).send(sender=self.channel_name)

View file

@ -1,5 +1,6 @@
import time
from asgiref.sync import async_to_sync
from django import forms
from django.core.exceptions import ValidationError
from django.http import Http404
@ -85,10 +86,10 @@ class MeshMessageForm(forms.Form):
recipients = self.get_recipients()
for recipient in recipients:
print('sending to ', recipient)
MeshMessage.fromjson({
async_to_sync(MeshMessage.fromjson({
'dst': recipient,
**msg_data,
}).send()
}).send)()
class MeshRouteRequestForm(MeshMessageForm):

View file

@ -3,13 +3,13 @@ from enum import IntEnum, unique
from typing import TypeVar
import channels
from asgiref.sync import async_to_sync
from channels.db import database_sync_to_async
from c3nav.mesh.baseformats import (BoolFormat, EnumFormat, FixedStrFormat, SimpleFormat, StructType, VarArrayFormat,
VarBytesFormat, VarStrFormat, normalize_name)
from c3nav.mesh.dataformats import (BoardConfig, FirmwareAppDescription, MacAddressesListFormat, MacAddressFormat,
RangeResultItem, RawFTMEntry)
from c3nav.mesh.utils import get_mesh_comm_group
from c3nav.mesh.utils import MESH_ALL_UPLINKS_GROUP
MESH_ROOT_ADDRESS = '00:00:00:00:00:00'
MESH_NONE_ADDRESS = '00:00:00:00:00:00'
@ -94,13 +94,25 @@ class MeshMessage(StructType, union_type_field="msg_type"):
raise TypeError('duplicate use of c_struct_name %s' % c_struct_name)
MeshMessage.c_structs[c_struct_name] = cls
def send(self, sender=None, exclude_uplink_address=None):
async_to_sync(channels.layers.get_channel_layer().group_send)(get_mesh_comm_group(self.dst), {
async def send(self, sender=None, exclude_uplink_address=None) -> bool:
data = {
"type": "mesh.send",
"sender": sender,
"exclude_uplink_address": exclude_uplink_address,
"msg": MeshMessage.tojson(self),
})
}
if self.dst in (MESH_CHILDREN_ADDRESS, MESH_BROADCAST_ADDRESS):
await channels.layers.get_channel_layer().group_send(MESH_ALL_UPLINKS_GROUP, data)
return True
from c3nav.mesh.models import MeshNode
uplink = database_sync_to_async(MeshNode.get_node_and_uplink)(self.dst)
if not uplink:
return False
if uplink.node_id == exclude_uplink_address:
return False
await channels.layers.get_channel_layer().send(uplink.name, data)
@classmethod
def get_ignore_c_fields(self):

View file

@ -1,6 +1,6 @@
from collections import UserDict, namedtuple
from dataclasses import dataclass
from datetime import datetime
from datetime import datetime, timedelta
from functools import cached_property
from operator import attrgetter
from typing import Any, Mapping, Optional, Self
@ -8,6 +8,7 @@ from typing import Any, Mapping, Optional, Self
from django.contrib.auth import get_user_model
from django.db import NotSupportedError, models
from django.db.models import Q, UniqueConstraint
from django.utils import timezone
from django.utils.text import slugify
from django.utils.translation import gettext_lazy as _
@ -15,6 +16,7 @@ from c3nav.mesh.dataformats import BoardType
from c3nav.mesh.messages import ChipType, ConfigFirmwareMessage, ConfigHardwareMessage
from c3nav.mesh.messages import MeshMessage as MeshMessage
from c3nav.mesh.messages import MeshMessageType
from c3nav.mesh.utils import UPLINK_TIMEOUT
FirmwareLookup = namedtuple('FirmwareLookup', ('sha256_hash', 'chip', 'project_name', 'version', 'idf_version'))
@ -200,6 +202,21 @@ class MeshNode(models.Model):
def board(self) -> ChipType:
return self.last_messages[MeshMessageType.CONFIG_BOARD].parsed.board_config.board
def get_uplink(self) -> Optional["MeshUplink"]:
if self.uplink_id is None:
return None
if self.uplink.last_ping + timedelta(seconds=UPLINK_TIMEOUT) < timezone.now():
return None
return self.uplink
@classmethod
def get_node_and_uplink(self, address) -> Optional["MeshUplink"]:
try:
dst_node = MeshNode.objects.select_related('uplink').get(address=address)
except MeshNode.DoesNotExist:
return False
return dst_node.get_uplink()
class MeshUplink(models.Model):
"""

View file

@ -1,8 +1,13 @@
from operator import attrgetter
def get_mesh_comm_group(address):
return 'mesh_comm_%s' % address.replace(':', '-')
def get_mesh_uplink_group(address):
return 'mesh_uplink_%s' % address.replace(':', '-')
MESH_ALL_UPLINKS_GROUP = "mesh_uplink_all"
UPLINK_PING = 5
UPLINK_TIMEOUT = UPLINK_PING+5
def indent_c(code):