mesh routing now fully database-bound

This commit is contained in:
Laura Klünder 2023-11-07 16:35:46 +01:00
parent ee539f678a
commit 66810b20b0
6 changed files with 94 additions and 93 deletions

View file

@ -420,7 +420,7 @@ class StructType:
return data return data
@classmethod @classmethod
def fromjson(cls, data: dict): def fromjson(cls, data: dict) -> Self:
data = data.copy() data = data.copy()
# todo: upgrade_json # todo: upgrade_json

View file

@ -1,6 +1,7 @@
import asyncio import asyncio
import traceback import traceback
from asyncio import get_event_loop from asyncio import get_event_loop
from functools import cached_property
from asgiref.sync import async_to_sync from asgiref.sync import async_to_sync
from channels.db import database_sync_to_async from channels.db import database_sync_to_async
@ -12,7 +13,7 @@ from c3nav.mesh import messages
from c3nav.mesh.messages import (MESH_BROADCAST_ADDRESS, MESH_NONE_ADDRESS, MESH_ROOT_ADDRESS, MeshMessage, from c3nav.mesh.messages import (MESH_BROADCAST_ADDRESS, MESH_NONE_ADDRESS, MESH_ROOT_ADDRESS, MeshMessage,
MeshMessageType) MeshMessageType)
from c3nav.mesh.models import MeshNode, MeshUplink, NodeMessage from c3nav.mesh.models import MeshNode, MeshUplink, NodeMessage
from c3nav.mesh.utils import get_mesh_comm_group from c3nav.mesh.utils import MESH_ALL_UPLINKS_GROUP, UPLINK_PING, get_mesh_uplink_group
class MeshConsumer(AsyncWebsocketConsumer): class MeshConsumer(AsyncWebsocketConsumer):
@ -35,9 +36,7 @@ class MeshConsumer(AsyncWebsocketConsumer):
await self.log_text(self.uplink.node, "mesh websocket disconnected") await self.log_text(self.uplink.node, "mesh websocket disconnected")
if self.uplink is not None: if self.uplink is not None:
# leave broadcast group # leave broadcast group
await self.channel_layer.group_discard( await self.channel_layer.group_discard("mesh_comm_broadcast", self.channel_name)
get_mesh_comm_group(MESH_BROADCAST_ADDRESS), self.channel_name
)
# remove all other destinations # remove all other destinations
await self.remove_dst_nodes(self.dst_nodes) await self.remove_dst_nodes(self.dst_nodes)
@ -66,6 +65,10 @@ class MeshConsumer(AsyncWebsocketConsumer):
# "msg": msg.tojson(), # not doing this part for privacy reasons # "msg": msg.tojson(), # not doing this part for privacy reasons
}) })
@cached_property
def same_uplinks_group(self):
return 'mesh_uplink_%s' % self.uplink.node.address.replace(':', '-')
async def receive(self, text_data=None, bytes_data=None): async def receive(self, text_data=None, bytes_data=None):
if bytes_data is None: if bytes_data is None:
return return
@ -92,7 +95,10 @@ class MeshConsumer(AsyncWebsocketConsumer):
await self.log_text(MESH_ROOT_ADDRESS, "adding ourselves to trace message before forwarding") await self.log_text(MESH_ROOT_ADDRESS, "adding ourselves to trace message before forwarding")
msg.trace.append(MESH_ROOT_ADDRESS) msg.trace.append(MESH_ROOT_ADDRESS)
msg.send(exclude_uplink_address=self.uplink.node.address) result = await msg.send(exclude_uplink_address=self.uplink.node.address)
if not result:
print('message had no route')
# don't handle this message unless it's a broadcast message # don't handle this message unless it's a broadcast message
if msg.dst != messages.MESH_BROADCAST_ADDRESS: if msg.dst != messages.MESH_BROADCAST_ADDRESS:
@ -110,7 +116,7 @@ class MeshConsumer(AsyncWebsocketConsumer):
await self.create_uplink_in_database(msg.src) await self.create_uplink_in_database(msg.src)
# inform other uplinks to shut down # inform other uplinks to shut down
await self.channel_layer.group_send(get_mesh_comm_group(msg.src), { await self.channel_layer.group_send(get_mesh_uplink_group(msg.src), {
"type": "mesh.uplink_consumer", "type": "mesh.uplink_consumer",
"name": self.channel_name, "name": self.channel_name,
}) })
@ -126,9 +132,7 @@ class MeshConsumer(AsyncWebsocketConsumer):
)) ))
# add signed in uplink node to broadcast group # add signed in uplink node to broadcast group
await self.channel_layer.group_add( await self.channel_layer.group_add(MESH_ALL_UPLINKS_GROUP, self.channel_name)
get_mesh_comm_group(MESH_BROADCAST_ADDRESS), self.channel_name
)
# add this node as a destination that this uplink handles (duh) # add this node as a destination that this uplink handles (duh)
await self.add_dst_nodes(nodes=(src_node, )) await self.add_dst_nodes(nodes=(src_node, ))
@ -151,27 +155,22 @@ class MeshConsumer(AsyncWebsocketConsumer):
if isinstance(msg, messages.MeshRouteRequestMessage): if isinstance(msg, messages.MeshRouteRequestMessage):
if msg.address == MESH_ROOT_ADDRESS: if msg.address == MESH_ROOT_ADDRESS:
await self.log_text(MESH_ROOT_ADDRESS, "route request about us, start a trace") await self.log_text(MESH_ROOT_ADDRESS, "route request about us, start a trace")
messages.MeshRouteTraceMessage( await self.send_msg(messages.MeshRouteTraceMessage(
src=MESH_ROOT_ADDRESS, src=MESH_ROOT_ADDRESS,
dst=msg.src, dst=msg.src,
request_id=msg.request_id, request_id=msg.request_id,
trace=[MESH_ROOT_ADDRESS], trace=[MESH_ROOT_ADDRESS],
).send() ))
else: else:
# todo: find a way to send a "no route" message if there is no route await self.log_text(MESH_ROOT_ADDRESS, "route request about someone else, sending response")
await self.log_text(MESH_ROOT_ADDRESS, "requesting route response responsible uplink")
self.open_requests.add(msg.request_id) self.open_requests.add(msg.request_id)
await self.channel_layer.group_send(get_mesh_comm_group(msg.address), { uplink = database_sync_to_async(MeshNode.get_node_and_uplink)(msg.address)
"type": "mesh.send_route_response", await self.send_msg(messages.MeshRouteResponseMessage(
"request_id": msg.request_id, src=MESH_ROOT_ADDRESS,
"channel": self.channel_name, dst=msg.src,
"dst": msg.src, request_id=msg.request_id,
}) route=uplink.node_id if uplink else MESH_NONE_ADDRESS,
await self.delayed_group_send(5, self.channel_name, { ))
"type": "mesh.no_route_response",
"request_id": msg.request_id,
"dst": msg.src,
})
@database_sync_to_async @database_sync_to_async
def create_uplink_in_database(self, address): def create_uplink_in_database(self, address):
@ -193,7 +192,7 @@ class MeshConsumer(AsyncWebsocketConsumer):
async def ping_regularly(self): async def ping_regularly(self):
while True: while True:
await asyncio.sleep(5) await asyncio.sleep(UPLINK_PING)
await MeshUplink.objects.filter(pk=self.uplink.pk).aupdate(last_ping=timezone.now()) await MeshUplink.objects.filter(pk=self.uplink.pk).aupdate(last_ping=timezone.now())
async def delayed_group_send(self, delay: int, group: str, msg: dict): async def delayed_group_send(self, delay: int, group: str, msg: dict):
@ -219,9 +218,10 @@ class MeshConsumer(AsyncWebsocketConsumer):
""" """
message handler: if we are not the given uplink, leave this group message handler: if we are not the given uplink, leave this group
""" """
if data["uplink"] != self.uplink.node.address: if data["uplink"] != self.channel_name:
await self.log_text(data["address"], "node now served by new consumer") await self.log_text(data["address"], "node now served by new consumer")
await self.remove_dst_nodes((data["address"], )) # going the short way cause the other consumer will already have done database stuff
self.dst_nodes.discard(data["address"])
async def mesh_send(self, data): async def mesh_send(self, data):
if self.uplink.node.address == data["exclude_uplink_address"]: if self.uplink.node.address == data["exclude_uplink_address"]:
@ -236,35 +236,6 @@ class MeshConsumer(AsyncWebsocketConsumer):
return return
await self.send_msg(MeshMessage.fromjson(data["msg"]), data["sender"]) await self.send_msg(MeshMessage.fromjson(data["msg"]), data["sender"])
async def mesh_send_route_response(self, data):
await self.log_text(self.uplink.node.address, "we're the uplink for this address, sending route response...")
messages.MeshRouteResponseMessage(
src=MESH_ROOT_ADDRESS,
dst=data["dst"],
request_id=data["request_id"],
route=self.uplink.node.address,
).send()
async_to_sync(self.channel_layer.send)(data["channel"], {
"type": "mesh.route_response_sent",
"request_id": data["request_id"],
})
async def mesh_route_response_sent(self, data):
self.open_requests.discard(data["request_id"])
async def mesh_no_route_response(self, data):
print('no route response check')
if data["request_id"] not in self.open_requests:
print('a route was sent')
return
print('sending no route')
messages.MeshRouteResponseMessage(
src=MESH_ROOT_ADDRESS,
dst=data["dst"],
request_id=data["request_id"],
route=MESH_NONE_ADDRESS,
).send()
""" """
helper functions helper functions
""" """
@ -331,26 +302,22 @@ class MeshConsumer(AsyncWebsocketConsumer):
def _add_destination(self, address): def _add_destination(self, address):
with transaction.atomic(): with transaction.atomic():
node = MeshNode.objects.select_for_update().get(address=address) node = MeshNode.objects.select_for_update().get(address=address)
# update database
# create group name for this address
group = get_mesh_comm_group(address)
# if we aren't handling this address yet, join the group
if address not in self.dst_nodes:
async_to_sync(self.channel_layer.group_add)(group, self.channel_name)
self.dst_nodes.add(address)
# tell other consumers to leave the group
async_to_sync(self.channel_layer.group_send)(group, {
"type": "mesh.dst_node_uplink",
"node": address,
"uplink": self.uplink.node.address
})
node.uplink = self.uplink, node.uplink = self.uplink,
node.last_signin = timezone.now() node.last_signin = timezone.now()
node.save() node.save()
# tell other consumers that it's us now
async_to_sync(self.channel_layer.group_send)(MESH_ALL_UPLINKS_GROUP, {
"type": "mesh.dst_node_uplink",
"node": address,
"uplink": self.channel_name
})
# if we aren't handling this address yet, write it down
if address not in self.dst_nodes:
self.dst_nodes.add(address)
async def remove_dst_nodes(self, addresses): async def remove_dst_nodes(self, addresses):
for address in tuple(addresses): for address in tuple(addresses):
await self.log_text(address, "destination removed") await self.log_text(address, "destination removed")
@ -360,19 +327,18 @@ class MeshConsumer(AsyncWebsocketConsumer):
@database_sync_to_async @database_sync_to_async
def _remove_destination(self, address): def _remove_destination(self, address):
with transaction.atomic(): with transaction.atomic():
node = MeshNode.objects.select_for_update().get(address=address) try:
node = MeshNode.objects.select_for_update().get(address=address, uplink=self.uplink)
except MeshNode.DoesNotExist:
pass
else:
node.uplink = None
node.save()
# create group name for this address # no longer serving this node
group = get_mesh_comm_group(address)
# leave the group
if address in self.dst_nodes: if address in self.dst_nodes:
async_to_sync(self.channel_layer.group_discard)(group, self.channel_name)
self.dst_nodes.discard(address) self.dst_nodes.discard(address)
node.uplink = None
node.save()
class MeshUIConsumer(AsyncJsonWebsocketConsumer): class MeshUIConsumer(AsyncJsonWebsocketConsumer):
def __init__(self): def __init__(self):
@ -407,7 +373,7 @@ class MeshUIConsumer(AsyncJsonWebsocketConsumer):
self.msg_received_filter = {"request_id": msg_to_send["msg_data"]["request_id"]} self.msg_received_filter = {"request_id": msg_to_send["msg_data"]["request_id"]}
for recipient in msg_to_send["recipients"]: for recipient in msg_to_send["recipients"]:
MeshMessage.fromjson({ await MeshMessage.fromjson({
'dst': recipient, 'dst': recipient,
**msg_to_send["msg_data"], **msg_to_send["msg_data"],
}).send(sender=self.channel_name) }).send(sender=self.channel_name)

View file

@ -1,5 +1,6 @@
import time import time
from asgiref.sync import async_to_sync
from django import forms from django import forms
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.http import Http404 from django.http import Http404
@ -85,10 +86,10 @@ class MeshMessageForm(forms.Form):
recipients = self.get_recipients() recipients = self.get_recipients()
for recipient in recipients: for recipient in recipients:
print('sending to ', recipient) print('sending to ', recipient)
MeshMessage.fromjson({ async_to_sync(MeshMessage.fromjson({
'dst': recipient, 'dst': recipient,
**msg_data, **msg_data,
}).send() }).send)()
class MeshRouteRequestForm(MeshMessageForm): class MeshRouteRequestForm(MeshMessageForm):

View file

@ -3,13 +3,13 @@ from enum import IntEnum, unique
from typing import TypeVar from typing import TypeVar
import channels import channels
from asgiref.sync import async_to_sync from channels.db import database_sync_to_async
from c3nav.mesh.baseformats import (BoolFormat, EnumFormat, FixedStrFormat, SimpleFormat, StructType, VarArrayFormat, from c3nav.mesh.baseformats import (BoolFormat, EnumFormat, FixedStrFormat, SimpleFormat, StructType, VarArrayFormat,
VarBytesFormat, VarStrFormat, normalize_name) VarBytesFormat, VarStrFormat, normalize_name)
from c3nav.mesh.dataformats import (BoardConfig, FirmwareAppDescription, MacAddressesListFormat, MacAddressFormat, from c3nav.mesh.dataformats import (BoardConfig, FirmwareAppDescription, MacAddressesListFormat, MacAddressFormat,
RangeResultItem, RawFTMEntry) RangeResultItem, RawFTMEntry)
from c3nav.mesh.utils import get_mesh_comm_group from c3nav.mesh.utils import MESH_ALL_UPLINKS_GROUP
MESH_ROOT_ADDRESS = '00:00:00:00:00:00' MESH_ROOT_ADDRESS = '00:00:00:00:00:00'
MESH_NONE_ADDRESS = '00:00:00:00:00:00' MESH_NONE_ADDRESS = '00:00:00:00:00:00'
@ -94,13 +94,25 @@ class MeshMessage(StructType, union_type_field="msg_type"):
raise TypeError('duplicate use of c_struct_name %s' % c_struct_name) raise TypeError('duplicate use of c_struct_name %s' % c_struct_name)
MeshMessage.c_structs[c_struct_name] = cls MeshMessage.c_structs[c_struct_name] = cls
def send(self, sender=None, exclude_uplink_address=None): async def send(self, sender=None, exclude_uplink_address=None) -> bool:
async_to_sync(channels.layers.get_channel_layer().group_send)(get_mesh_comm_group(self.dst), { data = {
"type": "mesh.send", "type": "mesh.send",
"sender": sender, "sender": sender,
"exclude_uplink_address": exclude_uplink_address, "exclude_uplink_address": exclude_uplink_address,
"msg": MeshMessage.tojson(self), "msg": MeshMessage.tojson(self),
}) }
if self.dst in (MESH_CHILDREN_ADDRESS, MESH_BROADCAST_ADDRESS):
await channels.layers.get_channel_layer().group_send(MESH_ALL_UPLINKS_GROUP, data)
return True
from c3nav.mesh.models import MeshNode
uplink = database_sync_to_async(MeshNode.get_node_and_uplink)(self.dst)
if not uplink:
return False
if uplink.node_id == exclude_uplink_address:
return False
await channels.layers.get_channel_layer().send(uplink.name, data)
@classmethod @classmethod
def get_ignore_c_fields(self): def get_ignore_c_fields(self):

View file

@ -1,6 +1,6 @@
from collections import UserDict, namedtuple from collections import UserDict, namedtuple
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime, timedelta
from functools import cached_property from functools import cached_property
from operator import attrgetter from operator import attrgetter
from typing import Any, Mapping, Optional, Self from typing import Any, Mapping, Optional, Self
@ -8,6 +8,7 @@ from typing import Any, Mapping, Optional, Self
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.db import NotSupportedError, models from django.db import NotSupportedError, models
from django.db.models import Q, UniqueConstraint from django.db.models import Q, UniqueConstraint
from django.utils import timezone
from django.utils.text import slugify from django.utils.text import slugify
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
@ -15,6 +16,7 @@ from c3nav.mesh.dataformats import BoardType
from c3nav.mesh.messages import ChipType, ConfigFirmwareMessage, ConfigHardwareMessage from c3nav.mesh.messages import ChipType, ConfigFirmwareMessage, ConfigHardwareMessage
from c3nav.mesh.messages import MeshMessage as MeshMessage from c3nav.mesh.messages import MeshMessage as MeshMessage
from c3nav.mesh.messages import MeshMessageType from c3nav.mesh.messages import MeshMessageType
from c3nav.mesh.utils import UPLINK_TIMEOUT
FirmwareLookup = namedtuple('FirmwareLookup', ('sha256_hash', 'chip', 'project_name', 'version', 'idf_version')) FirmwareLookup = namedtuple('FirmwareLookup', ('sha256_hash', 'chip', 'project_name', 'version', 'idf_version'))
@ -200,6 +202,21 @@ class MeshNode(models.Model):
def board(self) -> ChipType: def board(self) -> ChipType:
return self.last_messages[MeshMessageType.CONFIG_BOARD].parsed.board_config.board return self.last_messages[MeshMessageType.CONFIG_BOARD].parsed.board_config.board
def get_uplink(self) -> Optional["MeshUplink"]:
if self.uplink_id is None:
return None
if self.uplink.last_ping + timedelta(seconds=UPLINK_TIMEOUT) < timezone.now():
return None
return self.uplink
@classmethod
def get_node_and_uplink(self, address) -> Optional["MeshUplink"]:
try:
dst_node = MeshNode.objects.select_related('uplink').get(address=address)
except MeshNode.DoesNotExist:
return False
return dst_node.get_uplink()
class MeshUplink(models.Model): class MeshUplink(models.Model):
""" """

View file

@ -1,8 +1,13 @@
from operator import attrgetter from operator import attrgetter
def get_mesh_comm_group(address): def get_mesh_uplink_group(address):
return 'mesh_comm_%s' % address.replace(':', '-') return 'mesh_uplink_%s' % address.replace(':', '-')
MESH_ALL_UPLINKS_GROUP = "mesh_uplink_all"
UPLINK_PING = 5
UPLINK_TIMEOUT = UPLINK_PING+5
def indent_c(code): def indent_c(code):