mesh routing now fully database-bound
This commit is contained in:
parent
ee539f678a
commit
66810b20b0
6 changed files with 94 additions and 93 deletions
|
@ -420,7 +420,7 @@ class StructType:
|
|||
return data
|
||||
|
||||
@classmethod
|
||||
def fromjson(cls, data: dict):
|
||||
def fromjson(cls, data: dict) -> Self:
|
||||
data = data.copy()
|
||||
|
||||
# todo: upgrade_json
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import asyncio
|
||||
import traceback
|
||||
from asyncio import get_event_loop
|
||||
from functools import cached_property
|
||||
|
||||
from asgiref.sync import async_to_sync
|
||||
from channels.db import database_sync_to_async
|
||||
|
@ -12,7 +13,7 @@ from c3nav.mesh import messages
|
|||
from c3nav.mesh.messages import (MESH_BROADCAST_ADDRESS, MESH_NONE_ADDRESS, MESH_ROOT_ADDRESS, MeshMessage,
|
||||
MeshMessageType)
|
||||
from c3nav.mesh.models import MeshNode, MeshUplink, NodeMessage
|
||||
from c3nav.mesh.utils import get_mesh_comm_group
|
||||
from c3nav.mesh.utils import MESH_ALL_UPLINKS_GROUP, UPLINK_PING, get_mesh_uplink_group
|
||||
|
||||
|
||||
class MeshConsumer(AsyncWebsocketConsumer):
|
||||
|
@ -35,9 +36,7 @@ class MeshConsumer(AsyncWebsocketConsumer):
|
|||
await self.log_text(self.uplink.node, "mesh websocket disconnected")
|
||||
if self.uplink is not None:
|
||||
# leave broadcast group
|
||||
await self.channel_layer.group_discard(
|
||||
get_mesh_comm_group(MESH_BROADCAST_ADDRESS), self.channel_name
|
||||
)
|
||||
await self.channel_layer.group_discard("mesh_comm_broadcast", self.channel_name)
|
||||
|
||||
# remove all other destinations
|
||||
await self.remove_dst_nodes(self.dst_nodes)
|
||||
|
@ -66,6 +65,10 @@ class MeshConsumer(AsyncWebsocketConsumer):
|
|||
# "msg": msg.tojson(), # not doing this part for privacy reasons
|
||||
})
|
||||
|
||||
@cached_property
|
||||
def same_uplinks_group(self):
|
||||
return 'mesh_uplink_%s' % self.uplink.node.address.replace(':', '-')
|
||||
|
||||
async def receive(self, text_data=None, bytes_data=None):
|
||||
if bytes_data is None:
|
||||
return
|
||||
|
@ -92,7 +95,10 @@ class MeshConsumer(AsyncWebsocketConsumer):
|
|||
await self.log_text(MESH_ROOT_ADDRESS, "adding ourselves to trace message before forwarding")
|
||||
msg.trace.append(MESH_ROOT_ADDRESS)
|
||||
|
||||
msg.send(exclude_uplink_address=self.uplink.node.address)
|
||||
result = await msg.send(exclude_uplink_address=self.uplink.node.address)
|
||||
|
||||
if not result:
|
||||
print('message had no route')
|
||||
|
||||
# don't handle this message unless it's a broadcast message
|
||||
if msg.dst != messages.MESH_BROADCAST_ADDRESS:
|
||||
|
@ -110,7 +116,7 @@ class MeshConsumer(AsyncWebsocketConsumer):
|
|||
await self.create_uplink_in_database(msg.src)
|
||||
|
||||
# inform other uplinks to shut down
|
||||
await self.channel_layer.group_send(get_mesh_comm_group(msg.src), {
|
||||
await self.channel_layer.group_send(get_mesh_uplink_group(msg.src), {
|
||||
"type": "mesh.uplink_consumer",
|
||||
"name": self.channel_name,
|
||||
})
|
||||
|
@ -126,9 +132,7 @@ class MeshConsumer(AsyncWebsocketConsumer):
|
|||
))
|
||||
|
||||
# add signed in uplink node to broadcast group
|
||||
await self.channel_layer.group_add(
|
||||
get_mesh_comm_group(MESH_BROADCAST_ADDRESS), self.channel_name
|
||||
)
|
||||
await self.channel_layer.group_add(MESH_ALL_UPLINKS_GROUP, self.channel_name)
|
||||
|
||||
# add this node as a destination that this uplink handles (duh)
|
||||
await self.add_dst_nodes(nodes=(src_node, ))
|
||||
|
@ -151,27 +155,22 @@ class MeshConsumer(AsyncWebsocketConsumer):
|
|||
if isinstance(msg, messages.MeshRouteRequestMessage):
|
||||
if msg.address == MESH_ROOT_ADDRESS:
|
||||
await self.log_text(MESH_ROOT_ADDRESS, "route request about us, start a trace")
|
||||
messages.MeshRouteTraceMessage(
|
||||
await self.send_msg(messages.MeshRouteTraceMessage(
|
||||
src=MESH_ROOT_ADDRESS,
|
||||
dst=msg.src,
|
||||
request_id=msg.request_id,
|
||||
trace=[MESH_ROOT_ADDRESS],
|
||||
).send()
|
||||
))
|
||||
else:
|
||||
# todo: find a way to send a "no route" message if there is no route
|
||||
await self.log_text(MESH_ROOT_ADDRESS, "requesting route response responsible uplink")
|
||||
await self.log_text(MESH_ROOT_ADDRESS, "route request about someone else, sending response")
|
||||
self.open_requests.add(msg.request_id)
|
||||
await self.channel_layer.group_send(get_mesh_comm_group(msg.address), {
|
||||
"type": "mesh.send_route_response",
|
||||
"request_id": msg.request_id,
|
||||
"channel": self.channel_name,
|
||||
"dst": msg.src,
|
||||
})
|
||||
await self.delayed_group_send(5, self.channel_name, {
|
||||
"type": "mesh.no_route_response",
|
||||
"request_id": msg.request_id,
|
||||
"dst": msg.src,
|
||||
})
|
||||
uplink = database_sync_to_async(MeshNode.get_node_and_uplink)(msg.address)
|
||||
await self.send_msg(messages.MeshRouteResponseMessage(
|
||||
src=MESH_ROOT_ADDRESS,
|
||||
dst=msg.src,
|
||||
request_id=msg.request_id,
|
||||
route=uplink.node_id if uplink else MESH_NONE_ADDRESS,
|
||||
))
|
||||
|
||||
@database_sync_to_async
|
||||
def create_uplink_in_database(self, address):
|
||||
|
@ -193,7 +192,7 @@ class MeshConsumer(AsyncWebsocketConsumer):
|
|||
|
||||
async def ping_regularly(self):
|
||||
while True:
|
||||
await asyncio.sleep(5)
|
||||
await asyncio.sleep(UPLINK_PING)
|
||||
await MeshUplink.objects.filter(pk=self.uplink.pk).aupdate(last_ping=timezone.now())
|
||||
|
||||
async def delayed_group_send(self, delay: int, group: str, msg: dict):
|
||||
|
@ -219,9 +218,10 @@ class MeshConsumer(AsyncWebsocketConsumer):
|
|||
"""
|
||||
message handler: if we are not the given uplink, leave this group
|
||||
"""
|
||||
if data["uplink"] != self.uplink.node.address:
|
||||
if data["uplink"] != self.channel_name:
|
||||
await self.log_text(data["address"], "node now served by new consumer")
|
||||
await self.remove_dst_nodes((data["address"], ))
|
||||
# going the short way cause the other consumer will already have done database stuff
|
||||
self.dst_nodes.discard(data["address"])
|
||||
|
||||
async def mesh_send(self, data):
|
||||
if self.uplink.node.address == data["exclude_uplink_address"]:
|
||||
|
@ -236,35 +236,6 @@ class MeshConsumer(AsyncWebsocketConsumer):
|
|||
return
|
||||
await self.send_msg(MeshMessage.fromjson(data["msg"]), data["sender"])
|
||||
|
||||
async def mesh_send_route_response(self, data):
|
||||
await self.log_text(self.uplink.node.address, "we're the uplink for this address, sending route response...")
|
||||
messages.MeshRouteResponseMessage(
|
||||
src=MESH_ROOT_ADDRESS,
|
||||
dst=data["dst"],
|
||||
request_id=data["request_id"],
|
||||
route=self.uplink.node.address,
|
||||
).send()
|
||||
async_to_sync(self.channel_layer.send)(data["channel"], {
|
||||
"type": "mesh.route_response_sent",
|
||||
"request_id": data["request_id"],
|
||||
})
|
||||
|
||||
async def mesh_route_response_sent(self, data):
|
||||
self.open_requests.discard(data["request_id"])
|
||||
|
||||
async def mesh_no_route_response(self, data):
|
||||
print('no route response check')
|
||||
if data["request_id"] not in self.open_requests:
|
||||
print('a route was sent')
|
||||
return
|
||||
print('sending no route')
|
||||
messages.MeshRouteResponseMessage(
|
||||
src=MESH_ROOT_ADDRESS,
|
||||
dst=data["dst"],
|
||||
request_id=data["request_id"],
|
||||
route=MESH_NONE_ADDRESS,
|
||||
).send()
|
||||
|
||||
"""
|
||||
helper functions
|
||||
"""
|
||||
|
@ -331,26 +302,22 @@ class MeshConsumer(AsyncWebsocketConsumer):
|
|||
def _add_destination(self, address):
|
||||
with transaction.atomic():
|
||||
node = MeshNode.objects.select_for_update().get(address=address)
|
||||
|
||||
# create group name for this address
|
||||
group = get_mesh_comm_group(address)
|
||||
|
||||
# if we aren't handling this address yet, join the group
|
||||
if address not in self.dst_nodes:
|
||||
async_to_sync(self.channel_layer.group_add)(group, self.channel_name)
|
||||
self.dst_nodes.add(address)
|
||||
|
||||
# tell other consumers to leave the group
|
||||
async_to_sync(self.channel_layer.group_send)(group, {
|
||||
"type": "mesh.dst_node_uplink",
|
||||
"node": address,
|
||||
"uplink": self.uplink.node.address
|
||||
})
|
||||
|
||||
# update database
|
||||
node.uplink = self.uplink,
|
||||
node.last_signin = timezone.now()
|
||||
node.save()
|
||||
|
||||
# tell other consumers that it's us now
|
||||
async_to_sync(self.channel_layer.group_send)(MESH_ALL_UPLINKS_GROUP, {
|
||||
"type": "mesh.dst_node_uplink",
|
||||
"node": address,
|
||||
"uplink": self.channel_name
|
||||
})
|
||||
|
||||
# if we aren't handling this address yet, write it down
|
||||
if address not in self.dst_nodes:
|
||||
self.dst_nodes.add(address)
|
||||
|
||||
async def remove_dst_nodes(self, addresses):
|
||||
for address in tuple(addresses):
|
||||
await self.log_text(address, "destination removed")
|
||||
|
@ -360,19 +327,18 @@ class MeshConsumer(AsyncWebsocketConsumer):
|
|||
@database_sync_to_async
|
||||
def _remove_destination(self, address):
|
||||
with transaction.atomic():
|
||||
node = MeshNode.objects.select_for_update().get(address=address)
|
||||
try:
|
||||
node = MeshNode.objects.select_for_update().get(address=address, uplink=self.uplink)
|
||||
except MeshNode.DoesNotExist:
|
||||
pass
|
||||
else:
|
||||
node.uplink = None
|
||||
node.save()
|
||||
|
||||
# create group name for this address
|
||||
group = get_mesh_comm_group(address)
|
||||
|
||||
# leave the group
|
||||
# no longer serving this node
|
||||
if address in self.dst_nodes:
|
||||
async_to_sync(self.channel_layer.group_discard)(group, self.channel_name)
|
||||
self.dst_nodes.discard(address)
|
||||
|
||||
node.uplink = None
|
||||
node.save()
|
||||
|
||||
|
||||
class MeshUIConsumer(AsyncJsonWebsocketConsumer):
|
||||
def __init__(self):
|
||||
|
@ -407,7 +373,7 @@ class MeshUIConsumer(AsyncJsonWebsocketConsumer):
|
|||
self.msg_received_filter = {"request_id": msg_to_send["msg_data"]["request_id"]}
|
||||
|
||||
for recipient in msg_to_send["recipients"]:
|
||||
MeshMessage.fromjson({
|
||||
await MeshMessage.fromjson({
|
||||
'dst': recipient,
|
||||
**msg_to_send["msg_data"],
|
||||
}).send(sender=self.channel_name)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import time
|
||||
|
||||
from asgiref.sync import async_to_sync
|
||||
from django import forms
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.http import Http404
|
||||
|
@ -85,10 +86,10 @@ class MeshMessageForm(forms.Form):
|
|||
recipients = self.get_recipients()
|
||||
for recipient in recipients:
|
||||
print('sending to ', recipient)
|
||||
MeshMessage.fromjson({
|
||||
async_to_sync(MeshMessage.fromjson({
|
||||
'dst': recipient,
|
||||
**msg_data,
|
||||
}).send()
|
||||
}).send)()
|
||||
|
||||
|
||||
class MeshRouteRequestForm(MeshMessageForm):
|
||||
|
|
|
@ -3,13 +3,13 @@ from enum import IntEnum, unique
|
|||
from typing import TypeVar
|
||||
|
||||
import channels
|
||||
from asgiref.sync import async_to_sync
|
||||
from channels.db import database_sync_to_async
|
||||
|
||||
from c3nav.mesh.baseformats import (BoolFormat, EnumFormat, FixedStrFormat, SimpleFormat, StructType, VarArrayFormat,
|
||||
VarBytesFormat, VarStrFormat, normalize_name)
|
||||
from c3nav.mesh.dataformats import (BoardConfig, FirmwareAppDescription, MacAddressesListFormat, MacAddressFormat,
|
||||
RangeResultItem, RawFTMEntry)
|
||||
from c3nav.mesh.utils import get_mesh_comm_group
|
||||
from c3nav.mesh.utils import MESH_ALL_UPLINKS_GROUP
|
||||
|
||||
MESH_ROOT_ADDRESS = '00:00:00:00:00:00'
|
||||
MESH_NONE_ADDRESS = '00:00:00:00:00:00'
|
||||
|
@ -94,13 +94,25 @@ class MeshMessage(StructType, union_type_field="msg_type"):
|
|||
raise TypeError('duplicate use of c_struct_name %s' % c_struct_name)
|
||||
MeshMessage.c_structs[c_struct_name] = cls
|
||||
|
||||
def send(self, sender=None, exclude_uplink_address=None):
|
||||
async_to_sync(channels.layers.get_channel_layer().group_send)(get_mesh_comm_group(self.dst), {
|
||||
async def send(self, sender=None, exclude_uplink_address=None) -> bool:
|
||||
data = {
|
||||
"type": "mesh.send",
|
||||
"sender": sender,
|
||||
"exclude_uplink_address": exclude_uplink_address,
|
||||
"msg": MeshMessage.tojson(self),
|
||||
})
|
||||
}
|
||||
|
||||
if self.dst in (MESH_CHILDREN_ADDRESS, MESH_BROADCAST_ADDRESS):
|
||||
await channels.layers.get_channel_layer().group_send(MESH_ALL_UPLINKS_GROUP, data)
|
||||
return True
|
||||
|
||||
from c3nav.mesh.models import MeshNode
|
||||
uplink = database_sync_to_async(MeshNode.get_node_and_uplink)(self.dst)
|
||||
if not uplink:
|
||||
return False
|
||||
if uplink.node_id == exclude_uplink_address:
|
||||
return False
|
||||
await channels.layers.get_channel_layer().send(uplink.name, data)
|
||||
|
||||
@classmethod
|
||||
def get_ignore_c_fields(self):
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from collections import UserDict, namedtuple
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from functools import cached_property
|
||||
from operator import attrgetter
|
||||
from typing import Any, Mapping, Optional, Self
|
||||
|
@ -8,6 +8,7 @@ from typing import Any, Mapping, Optional, Self
|
|||
from django.contrib.auth import get_user_model
|
||||
from django.db import NotSupportedError, models
|
||||
from django.db.models import Q, UniqueConstraint
|
||||
from django.utils import timezone
|
||||
from django.utils.text import slugify
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
@ -15,6 +16,7 @@ from c3nav.mesh.dataformats import BoardType
|
|||
from c3nav.mesh.messages import ChipType, ConfigFirmwareMessage, ConfigHardwareMessage
|
||||
from c3nav.mesh.messages import MeshMessage as MeshMessage
|
||||
from c3nav.mesh.messages import MeshMessageType
|
||||
from c3nav.mesh.utils import UPLINK_TIMEOUT
|
||||
|
||||
FirmwareLookup = namedtuple('FirmwareLookup', ('sha256_hash', 'chip', 'project_name', 'version', 'idf_version'))
|
||||
|
||||
|
@ -200,6 +202,21 @@ class MeshNode(models.Model):
|
|||
def board(self) -> ChipType:
|
||||
return self.last_messages[MeshMessageType.CONFIG_BOARD].parsed.board_config.board
|
||||
|
||||
def get_uplink(self) -> Optional["MeshUplink"]:
|
||||
if self.uplink_id is None:
|
||||
return None
|
||||
if self.uplink.last_ping + timedelta(seconds=UPLINK_TIMEOUT) < timezone.now():
|
||||
return None
|
||||
return self.uplink
|
||||
|
||||
@classmethod
|
||||
def get_node_and_uplink(self, address) -> Optional["MeshUplink"]:
|
||||
try:
|
||||
dst_node = MeshNode.objects.select_related('uplink').get(address=address)
|
||||
except MeshNode.DoesNotExist:
|
||||
return False
|
||||
return dst_node.get_uplink()
|
||||
|
||||
|
||||
class MeshUplink(models.Model):
|
||||
"""
|
||||
|
|
|
@ -1,8 +1,13 @@
|
|||
from operator import attrgetter
|
||||
|
||||
|
||||
def get_mesh_comm_group(address):
|
||||
return 'mesh_comm_%s' % address.replace(':', '-')
|
||||
def get_mesh_uplink_group(address):
|
||||
return 'mesh_uplink_%s' % address.replace(':', '-')
|
||||
|
||||
|
||||
MESH_ALL_UPLINKS_GROUP = "mesh_uplink_all"
|
||||
UPLINK_PING = 5
|
||||
UPLINK_TIMEOUT = UPLINK_PING+5
|
||||
|
||||
|
||||
def indent_c(code):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue