diff --git a/src/c3nav/mesh/baseformats.py b/src/c3nav/mesh/baseformats.py index 33571dd8..a001624c 100644 --- a/src/c3nav/mesh/baseformats.py +++ b/src/c3nav/mesh/baseformats.py @@ -420,7 +420,7 @@ class StructType: return data @classmethod - def fromjson(cls, data: dict): + def fromjson(cls, data: dict) -> Self: data = data.copy() # todo: upgrade_json diff --git a/src/c3nav/mesh/consumers.py b/src/c3nav/mesh/consumers.py index 353f7d42..706cfc06 100644 --- a/src/c3nav/mesh/consumers.py +++ b/src/c3nav/mesh/consumers.py @@ -1,6 +1,7 @@ import asyncio import traceback from asyncio import get_event_loop +from functools import cached_property from asgiref.sync import async_to_sync from channels.db import database_sync_to_async @@ -12,7 +13,7 @@ from c3nav.mesh import messages from c3nav.mesh.messages import (MESH_BROADCAST_ADDRESS, MESH_NONE_ADDRESS, MESH_ROOT_ADDRESS, MeshMessage, MeshMessageType) from c3nav.mesh.models import MeshNode, MeshUplink, NodeMessage -from c3nav.mesh.utils import get_mesh_comm_group +from c3nav.mesh.utils import MESH_ALL_UPLINKS_GROUP, UPLINK_PING, get_mesh_uplink_group class MeshConsumer(AsyncWebsocketConsumer): @@ -35,9 +36,7 @@ class MeshConsumer(AsyncWebsocketConsumer): await self.log_text(self.uplink.node, "mesh websocket disconnected") if self.uplink is not None: # leave broadcast group - await self.channel_layer.group_discard( - get_mesh_comm_group(MESH_BROADCAST_ADDRESS), self.channel_name - ) + await self.channel_layer.group_discard("mesh_comm_broadcast", self.channel_name) # remove all other destinations await self.remove_dst_nodes(self.dst_nodes) @@ -66,6 +65,10 @@ class MeshConsumer(AsyncWebsocketConsumer): # "msg": msg.tojson(), # not doing this part for privacy reasons }) + @cached_property + def same_uplinks_group(self): + return 'mesh_uplink_%s' % self.uplink.node.address.replace(':', '-') + async def receive(self, text_data=None, bytes_data=None): if bytes_data is None: return @@ -92,7 +95,10 @@ class MeshConsumer(AsyncWebsocketConsumer): await self.log_text(MESH_ROOT_ADDRESS, "adding ourselves to trace message before forwarding") msg.trace.append(MESH_ROOT_ADDRESS) - msg.send(exclude_uplink_address=self.uplink.node.address) + result = await msg.send(exclude_uplink_address=self.uplink.node.address) + + if not result: + print('message had no route') # don't handle this message unless it's a broadcast message if msg.dst != messages.MESH_BROADCAST_ADDRESS: @@ -110,7 +116,7 @@ class MeshConsumer(AsyncWebsocketConsumer): await self.create_uplink_in_database(msg.src) # inform other uplinks to shut down - await self.channel_layer.group_send(get_mesh_comm_group(msg.src), { + await self.channel_layer.group_send(get_mesh_uplink_group(msg.src), { "type": "mesh.uplink_consumer", "name": self.channel_name, }) @@ -126,9 +132,7 @@ class MeshConsumer(AsyncWebsocketConsumer): )) # add signed in uplink node to broadcast group - await self.channel_layer.group_add( - get_mesh_comm_group(MESH_BROADCAST_ADDRESS), self.channel_name - ) + await self.channel_layer.group_add(MESH_ALL_UPLINKS_GROUP, self.channel_name) # add this node as a destination that this uplink handles (duh) await self.add_dst_nodes(nodes=(src_node, )) @@ -151,27 +155,22 @@ class MeshConsumer(AsyncWebsocketConsumer): if isinstance(msg, messages.MeshRouteRequestMessage): if msg.address == MESH_ROOT_ADDRESS: await self.log_text(MESH_ROOT_ADDRESS, "route request about us, start a trace") - messages.MeshRouteTraceMessage( + await self.send_msg(messages.MeshRouteTraceMessage( src=MESH_ROOT_ADDRESS, dst=msg.src, request_id=msg.request_id, trace=[MESH_ROOT_ADDRESS], - ).send() + )) else: - # todo: find a way to send a "no route" message if there is no route - await self.log_text(MESH_ROOT_ADDRESS, "requesting route response responsible uplink") + await self.log_text(MESH_ROOT_ADDRESS, "route request about someone else, sending response") self.open_requests.add(msg.request_id) - await self.channel_layer.group_send(get_mesh_comm_group(msg.address), { - "type": "mesh.send_route_response", - "request_id": msg.request_id, - "channel": self.channel_name, - "dst": msg.src, - }) - await self.delayed_group_send(5, self.channel_name, { - "type": "mesh.no_route_response", - "request_id": msg.request_id, - "dst": msg.src, - }) + uplink = database_sync_to_async(MeshNode.get_node_and_uplink)(msg.address) + await self.send_msg(messages.MeshRouteResponseMessage( + src=MESH_ROOT_ADDRESS, + dst=msg.src, + request_id=msg.request_id, + route=uplink.node_id if uplink else MESH_NONE_ADDRESS, + )) @database_sync_to_async def create_uplink_in_database(self, address): @@ -193,7 +192,7 @@ class MeshConsumer(AsyncWebsocketConsumer): async def ping_regularly(self): while True: - await asyncio.sleep(5) + await asyncio.sleep(UPLINK_PING) await MeshUplink.objects.filter(pk=self.uplink.pk).aupdate(last_ping=timezone.now()) async def delayed_group_send(self, delay: int, group: str, msg: dict): @@ -219,9 +218,10 @@ class MeshConsumer(AsyncWebsocketConsumer): """ message handler: if we are not the given uplink, leave this group """ - if data["uplink"] != self.uplink.node.address: + if data["uplink"] != self.channel_name: await self.log_text(data["address"], "node now served by new consumer") - await self.remove_dst_nodes((data["address"], )) + # going the short way cause the other consumer will already have done database stuff + self.dst_nodes.discard(data["address"]) async def mesh_send(self, data): if self.uplink.node.address == data["exclude_uplink_address"]: @@ -236,35 +236,6 @@ class MeshConsumer(AsyncWebsocketConsumer): return await self.send_msg(MeshMessage.fromjson(data["msg"]), data["sender"]) - async def mesh_send_route_response(self, data): - await self.log_text(self.uplink.node.address, "we're the uplink for this address, sending route response...") - messages.MeshRouteResponseMessage( - src=MESH_ROOT_ADDRESS, - dst=data["dst"], - request_id=data["request_id"], - route=self.uplink.node.address, - ).send() - async_to_sync(self.channel_layer.send)(data["channel"], { - "type": "mesh.route_response_sent", - "request_id": data["request_id"], - }) - - async def mesh_route_response_sent(self, data): - self.open_requests.discard(data["request_id"]) - - async def mesh_no_route_response(self, data): - print('no route response check') - if data["request_id"] not in self.open_requests: - print('a route was sent') - return - print('sending no route') - messages.MeshRouteResponseMessage( - src=MESH_ROOT_ADDRESS, - dst=data["dst"], - request_id=data["request_id"], - route=MESH_NONE_ADDRESS, - ).send() - """ helper functions """ @@ -331,26 +302,22 @@ class MeshConsumer(AsyncWebsocketConsumer): def _add_destination(self, address): with transaction.atomic(): node = MeshNode.objects.select_for_update().get(address=address) - - # create group name for this address - group = get_mesh_comm_group(address) - - # if we aren't handling this address yet, join the group - if address not in self.dst_nodes: - async_to_sync(self.channel_layer.group_add)(group, self.channel_name) - self.dst_nodes.add(address) - - # tell other consumers to leave the group - async_to_sync(self.channel_layer.group_send)(group, { - "type": "mesh.dst_node_uplink", - "node": address, - "uplink": self.uplink.node.address - }) - + # update database node.uplink = self.uplink, node.last_signin = timezone.now() node.save() + # tell other consumers that it's us now + async_to_sync(self.channel_layer.group_send)(MESH_ALL_UPLINKS_GROUP, { + "type": "mesh.dst_node_uplink", + "node": address, + "uplink": self.channel_name + }) + + # if we aren't handling this address yet, write it down + if address not in self.dst_nodes: + self.dst_nodes.add(address) + async def remove_dst_nodes(self, addresses): for address in tuple(addresses): await self.log_text(address, "destination removed") @@ -360,19 +327,18 @@ class MeshConsumer(AsyncWebsocketConsumer): @database_sync_to_async def _remove_destination(self, address): with transaction.atomic(): - node = MeshNode.objects.select_for_update().get(address=address) + try: + node = MeshNode.objects.select_for_update().get(address=address, uplink=self.uplink) + except MeshNode.DoesNotExist: + pass + else: + node.uplink = None + node.save() - # create group name for this address - group = get_mesh_comm_group(address) - - # leave the group + # no longer serving this node if address in self.dst_nodes: - async_to_sync(self.channel_layer.group_discard)(group, self.channel_name) self.dst_nodes.discard(address) - node.uplink = None - node.save() - class MeshUIConsumer(AsyncJsonWebsocketConsumer): def __init__(self): @@ -407,7 +373,7 @@ class MeshUIConsumer(AsyncJsonWebsocketConsumer): self.msg_received_filter = {"request_id": msg_to_send["msg_data"]["request_id"]} for recipient in msg_to_send["recipients"]: - MeshMessage.fromjson({ + await MeshMessage.fromjson({ 'dst': recipient, **msg_to_send["msg_data"], }).send(sender=self.channel_name) diff --git a/src/c3nav/mesh/forms.py b/src/c3nav/mesh/forms.py index 3cdb239a..69d720ce 100644 --- a/src/c3nav/mesh/forms.py +++ b/src/c3nav/mesh/forms.py @@ -1,5 +1,6 @@ import time +from asgiref.sync import async_to_sync from django import forms from django.core.exceptions import ValidationError from django.http import Http404 @@ -85,10 +86,10 @@ class MeshMessageForm(forms.Form): recipients = self.get_recipients() for recipient in recipients: print('sending to ', recipient) - MeshMessage.fromjson({ + async_to_sync(MeshMessage.fromjson({ 'dst': recipient, **msg_data, - }).send() + }).send)() class MeshRouteRequestForm(MeshMessageForm): diff --git a/src/c3nav/mesh/messages.py b/src/c3nav/mesh/messages.py index 09aac4f7..a25c2712 100644 --- a/src/c3nav/mesh/messages.py +++ b/src/c3nav/mesh/messages.py @@ -3,13 +3,13 @@ from enum import IntEnum, unique from typing import TypeVar import channels -from asgiref.sync import async_to_sync +from channels.db import database_sync_to_async from c3nav.mesh.baseformats import (BoolFormat, EnumFormat, FixedStrFormat, SimpleFormat, StructType, VarArrayFormat, VarBytesFormat, VarStrFormat, normalize_name) from c3nav.mesh.dataformats import (BoardConfig, FirmwareAppDescription, MacAddressesListFormat, MacAddressFormat, RangeResultItem, RawFTMEntry) -from c3nav.mesh.utils import get_mesh_comm_group +from c3nav.mesh.utils import MESH_ALL_UPLINKS_GROUP MESH_ROOT_ADDRESS = '00:00:00:00:00:00' MESH_NONE_ADDRESS = '00:00:00:00:00:00' @@ -94,13 +94,25 @@ class MeshMessage(StructType, union_type_field="msg_type"): raise TypeError('duplicate use of c_struct_name %s' % c_struct_name) MeshMessage.c_structs[c_struct_name] = cls - def send(self, sender=None, exclude_uplink_address=None): - async_to_sync(channels.layers.get_channel_layer().group_send)(get_mesh_comm_group(self.dst), { + async def send(self, sender=None, exclude_uplink_address=None) -> bool: + data = { "type": "mesh.send", "sender": sender, "exclude_uplink_address": exclude_uplink_address, "msg": MeshMessage.tojson(self), - }) + } + + if self.dst in (MESH_CHILDREN_ADDRESS, MESH_BROADCAST_ADDRESS): + await channels.layers.get_channel_layer().group_send(MESH_ALL_UPLINKS_GROUP, data) + return True + + from c3nav.mesh.models import MeshNode + uplink = database_sync_to_async(MeshNode.get_node_and_uplink)(self.dst) + if not uplink: + return False + if uplink.node_id == exclude_uplink_address: + return False + await channels.layers.get_channel_layer().send(uplink.name, data) @classmethod def get_ignore_c_fields(self): diff --git a/src/c3nav/mesh/models.py b/src/c3nav/mesh/models.py index ac731a9b..733972c1 100644 --- a/src/c3nav/mesh/models.py +++ b/src/c3nav/mesh/models.py @@ -1,6 +1,6 @@ from collections import UserDict, namedtuple from dataclasses import dataclass -from datetime import datetime +from datetime import datetime, timedelta from functools import cached_property from operator import attrgetter from typing import Any, Mapping, Optional, Self @@ -8,6 +8,7 @@ from typing import Any, Mapping, Optional, Self from django.contrib.auth import get_user_model from django.db import NotSupportedError, models from django.db.models import Q, UniqueConstraint +from django.utils import timezone from django.utils.text import slugify from django.utils.translation import gettext_lazy as _ @@ -15,6 +16,7 @@ from c3nav.mesh.dataformats import BoardType from c3nav.mesh.messages import ChipType, ConfigFirmwareMessage, ConfigHardwareMessage from c3nav.mesh.messages import MeshMessage as MeshMessage from c3nav.mesh.messages import MeshMessageType +from c3nav.mesh.utils import UPLINK_TIMEOUT FirmwareLookup = namedtuple('FirmwareLookup', ('sha256_hash', 'chip', 'project_name', 'version', 'idf_version')) @@ -200,6 +202,21 @@ class MeshNode(models.Model): def board(self) -> ChipType: return self.last_messages[MeshMessageType.CONFIG_BOARD].parsed.board_config.board + def get_uplink(self) -> Optional["MeshUplink"]: + if self.uplink_id is None: + return None + if self.uplink.last_ping + timedelta(seconds=UPLINK_TIMEOUT) < timezone.now(): + return None + return self.uplink + + @classmethod + def get_node_and_uplink(self, address) -> Optional["MeshUplink"]: + try: + dst_node = MeshNode.objects.select_related('uplink').get(address=address) + except MeshNode.DoesNotExist: + return False + return dst_node.get_uplink() + class MeshUplink(models.Model): """ diff --git a/src/c3nav/mesh/utils.py b/src/c3nav/mesh/utils.py index a0784f98..bd232905 100644 --- a/src/c3nav/mesh/utils.py +++ b/src/c3nav/mesh/utils.py @@ -1,8 +1,13 @@ from operator import attrgetter -def get_mesh_comm_group(address): - return 'mesh_comm_%s' % address.replace(':', '-') +def get_mesh_uplink_group(address): + return 'mesh_uplink_%s' % address.replace(':', '-') + + +MESH_ALL_UPLINKS_GROUP = "mesh_uplink_all" +UPLINK_PING = 5 +UPLINK_TIMEOUT = UPLINK_PING+5 def indent_c(code):