From 18144dde305d0789db662877b79c50468bb3b3a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laura=20Kl=C3=BCnder?= Date: Tue, 7 Nov 2023 14:25:54 +0100 Subject: [PATCH] our consumers are async now, hurray --- src/c3nav/mesh/consumers.py | 275 ++++++++++++++++++++---------------- src/c3nav/mesh/tasks.py | 11 -- 2 files changed, 155 insertions(+), 131 deletions(-) delete mode 100644 src/c3nav/mesh/tasks.py diff --git a/src/c3nav/mesh/consumers.py b/src/c3nav/mesh/consumers.py index 702628bc..827aa1f7 100644 --- a/src/c3nav/mesh/consumers.py +++ b/src/c3nav/mesh/consumers.py @@ -1,7 +1,10 @@ +import asyncio import traceback +from asyncio import get_event_loop from asgiref.sync import async_to_sync -from channels.generic.websocket import JsonWebsocketConsumer, WebsocketConsumer +from channels.db import database_sync_to_async +from channels.generic.websocket import AsyncJsonWebsocketConsumer, AsyncWebsocketConsumer from django.db import transaction from django.utils import timezone @@ -9,47 +12,51 @@ from c3nav.mesh import messages from c3nav.mesh.messages import (MESH_BROADCAST_ADDRESS, MESH_NONE_ADDRESS, MESH_ROOT_ADDRESS, MeshMessage, MeshMessageType) from c3nav.mesh.models import MeshNode, MeshUplink, NodeMessage -from c3nav.mesh.tasks import send_channel_msg from c3nav.mesh.utils import get_mesh_comm_group -# noinspection PyAttributeOutsideInit -class MeshConsumer(WebsocketConsumer): - def connect(self): - # todo: auth +class MeshConsumer(AsyncWebsocketConsumer): + def __init__(self): + super().__init__() self.uplink = None - self.log_text(None, "new mesh websocket connection") self.dst_nodes = set() self.open_requests = set() - self.accept() + self.ping_task = None - # todo: ping stuff + async def connect(self): + # todo: auth - def disconnect(self, close_code): - self.log_text(self.uplink.node, "mesh websocket disconnected") + await self.log_text(None, "new mesh websocket connection") + await self.accept() + self.ping_task = get_event_loop().create_task(self.ping_regularly()) + + async def disconnect(self, close_code): + self.ping_task.cancel() + await self.log_text(self.uplink.node, "mesh websocket disconnected") if self.uplink is not None: # leave broadcast group - async_to_sync(self.channel_layer.group_discard)( + await self.channel_layer.group_discard( get_mesh_comm_group(MESH_BROADCAST_ADDRESS), self.channel_name ) # remove all other destinations - self.remove_dst_nodes(self.dst_nodes) + await self.remove_dst_nodes(self.dst_nodes) # set end reason (unless we set it to replaced already) - MeshUplink.objects.filter( + # todo: make this better? idk + await MeshUplink.objects.filter( pk=self.uplink.pk, ).exclude( end_reason=MeshUplink.EndReason.REPLACED - ).update( + ).aupdate( end_reason=MeshUplink.EndReason.CLOSED ) - def send_msg(self, msg, sender=None, exclude_uplink_address=None): + async def send_msg(self, msg, sender=None, exclude_uplink_address=None): # print("sending", msg, MeshMessage.encode(msg).hex(' ', 1)) # self.log_text(msg.dst, "sending %s" % msg) - self.send(bytes_data=MeshMessage.encode(msg)) - async_to_sync(self.channel_layer.group_send)("mesh_msg_sent", { + await self.send(bytes_data=MeshMessage.encode(msg)) + await self.channel_layer.group_send("mesh_msg_sent", { "type": "mesh.msg_sent", "timestamp": timezone.now().strftime("%d.%m.%y %H:%M:%S.%f"), "channel": self.channel_name, @@ -59,7 +66,7 @@ class MeshConsumer(WebsocketConsumer): # "msg": msg.tojson(), # not doing this part for privacy reasons }) - def receive(self, text_data=None, bytes_data=None): + async def receive(self, text_data=None, bytes_data=None): if bytes_data is None: return try: @@ -75,14 +82,14 @@ class MeshConsumer(WebsocketConsumer): print('Received message for forwarding:', msg) if not self.uplink: - self.log_text(None, "received message not for us before sign in message, ignoring...") + await self.log_text(None, "received message not for us before sign in message, ignoring...") print('no sign in yet, ignoring') return # trace messages collect node adresses before forwarding if isinstance(msg, messages.MeshRouteTraceMessage): print('adding ourselves to trace message before forwarding') - self.log_text(MESH_ROOT_ADDRESS, "adding ourselves to trace message before forwarding") + await self.log_text(MESH_ROOT_ADDRESS, "adding ourselves to trace message before forwarding") msg.trace.append(MESH_ROOT_ADDRESS) msg.send(exclude_uplink_address=self.uplink.node.address) @@ -90,74 +97,60 @@ class MeshConsumer(WebsocketConsumer): # don't handle this message unless it's a broadcast message if msg.dst != messages.MESH_BROADCAST_ADDRESS: # don't handle this message unless it's a broadcast message - self.log_text(MESH_ROOT_ADDRESS, "received non-broadcast message not for us, forwarding...") + await self.log_text(MESH_ROOT_ADDRESS, "received non-broadcast message not for us, forwarding...") return print('it\'s a broadcast so it\'s also for us') - self.log_text(MESH_ROOT_ADDRESS, "received broadcast message, forwarding and handling...") + await self.log_text(MESH_ROOT_ADDRESS, "received broadcast message, forwarding and handling...") # print('Received message:', msg) - src_node, created = MeshNode.objects.get_or_create(address=msg.src) + src_node, created = await MeshNode.objects.aget_or_create(address=msg.src) if isinstance(msg, messages.MeshSigninMessage): - with transaction.atomic(): - # tatabase fumbling, lock the mesh node database row - locked_node = MeshNode.objects.select_for_update().get(address=msg.src) + await self.create_uplink_in_database(msg.src) - # close other uplinks in the database (they might add their own close reason in a bit) - locked_node.uplink_sessions.filter(end_reason__isnull=True).update( - end_reason=MeshUplink.EndReason.NEW_TIMEOUT - ) + # inform other uplinks to shut down + await self.channel_layer.group_send(get_mesh_comm_group(msg.src), { + "type": "mesh.uplink_consumer", + "name": self.channel_name, + }) - # create our own uplink in the database - self.uplink = MeshUplink.objects.create( - node=locked_node, - last_ping=timezone.now(), - name=self.channel_name, - ) + # log message, since we will not log it further down + await self.log_received_message(src_node, msg) - # inform other uplinks to shut down - async_to_sync(self.channel_layer.group_send)(get_mesh_comm_group(msg.src), { - "type": "mesh.uplink_consumer", - "name": self.channel_name, - }) + # inform signed in uplink node about its layer + await self.send_msg(messages.MeshLayerAnnounceMessage( + src=messages.MESH_ROOT_ADDRESS, + dst=msg.src, + layer=messages.NO_LAYER + )) - # log message, since we will not log it further down - self.log_received_message(src_node, msg) + # add signed in uplink node to broadcast group + await self.channel_layer.group_add( + get_mesh_comm_group(MESH_BROADCAST_ADDRESS), self.channel_name + ) - # inform signed in uplink node about its layer - self.send_msg(messages.MeshLayerAnnounceMessage( - src=messages.MESH_ROOT_ADDRESS, - dst=msg.src, - layer=messages.NO_LAYER - )) - - # add signed in uplink node to broadcast group - async_to_sync(self.channel_layer.group_add)( - get_mesh_comm_group(MESH_BROADCAST_ADDRESS), self.channel_name - ) - - # add this node as a destination that this uplink handles (duh) - self.add_dst_nodes(nodes=(src_node, )) + # add this node as a destination that this uplink handles (duh) + await self.add_dst_nodes(nodes=(src_node, )) return if self.uplink is None: print('Expected sign-in message, but got a different one!') - self.close() + await self.close() return - self.log_received_message(src_node, msg) + await self.log_received_message(src_node, msg) if isinstance(msg, messages.MeshAddDestinationsMessage): - self.add_dst_nodes(addresses=msg.addresses) + await self.add_dst_nodes(addresses=msg.addresses) if isinstance(msg, messages.MeshRemoveDestinationsMessage): - self.remove_dst_nodes(addresses=msg.addresses) + await self.remove_dst_nodes(addresses=msg.addresses) if isinstance(msg, messages.MeshRouteRequestMessage): if msg.address == MESH_ROOT_ADDRESS: - self.log_text(MESH_ROOT_ADDRESS, "route request about us, start a trace") + await self.log_text(MESH_ROOT_ADDRESS, "route request about us, start a trace") messages.MeshRouteTraceMessage( src=MESH_ROOT_ADDRESS, dst=msg.src, @@ -166,50 +159,85 @@ class MeshConsumer(WebsocketConsumer): ).send() else: # todo: find a way to send a "no route" message if there is no route - self.log_text(MESH_ROOT_ADDRESS, "requesting route response responsible uplink") + await self.log_text(MESH_ROOT_ADDRESS, "requesting route response responsible uplink") self.open_requests.add(msg.request_id) - async_to_sync(self.channel_layer.group_send)(get_mesh_comm_group(msg.address), { + await self.channel_layer.group_send(get_mesh_comm_group(msg.address), { "type": "mesh.send_route_response", "request_id": msg.request_id, "channel": self.channel_name, "dst": msg.src, }) - send_channel_msg.apply_async((self.channel_name, { + await self.delayed_group_send(5, self.channel_name, { "type": "mesh.no_route_response", "request_id": msg.request_id, "dst": msg.src, - }), countdown=5) + }) - def mesh_uplink_consumer(self, data): - # message handler: if we are not the given uplink, leave this group + @database_sync_to_async + def create_uplink_in_database(self, address): + with transaction.atomic(): + # tatabase fumbling, lock the mesh node database row + locked_node = MeshNode.objects.select_for_update().get(address=address) + + # close other uplinks in the database (they might add their own close reason in a bit) + locked_node.uplink_sessions.filter(end_reason__isnull=True).update( + end_reason=MeshUplink.EndReason.NEW_TIMEOUT + ) + + # create our own uplink in the database + self.uplink = MeshUplink.objects.create( + node=locked_node, + last_ping=timezone.now(), + name=self.channel_name, + ) + + async def ping_regularly(self): + while True: + await asyncio.sleep(5) + await MeshUplink.objects.filter(pk=self.uplink.pk).aupdate(last_ping=timezone.now()) + + async def delayed_group_send(self, delay: int, group: str, msg: dict): + await asyncio.sleep(delay) + await self.channel_layer.group_send(group, msg) + + """ + internal event handlers + """ + + async def mesh_uplink_consumer(self, data): + """ + message handler: if we are not the given uplink, leave this group + """ if data["name"] != self.channel_name: - self.log_text(self.uplink.node, "shutting down, uplink now served by new consumer") - MeshUplink.objects.filter(pk=self.uplink.pk,).update( + await self.log_text(self.uplink.node, "shutting down, uplink now served by new consumer") + await MeshUplink.objects.filter(pk=self.uplink.pk,).aupdate( end_reason=MeshUplink.EndReason.REPLACED ) - self.close() + await self.close() - def mesh_dst_node_uplink(self, data): - # message handler: if we are not the given uplink, leave this group + async def mesh_dst_node_uplink(self, data): + """ + message handler: if we are not the given uplink, leave this group + """ if data["uplink"] != self.uplink.node.address: - self.log_text(data["address"], "node now served by new consumer") - self.remove_dst_nodes((data["address"], )) + await self.log_text(data["address"], "node now served by new consumer") + await self.remove_dst_nodes((data["address"], )) - def mesh_send(self, data): + async def mesh_send(self, data): if self.uplink.node.address == data["exclude_uplink_address"]: if data["msg"]["dst"] == MESH_BROADCAST_ADDRESS: - self.log_text( + await self.log_text( self.uplink.node.address, "not forwarding this broadcast message via us since it came from here" ) else: - self.log_text( + await self.log_text( self.uplink.node.address, "we're the route for this message but it came from here so... no" ) return - self.send_msg(MeshMessage.fromjson(data["msg"]), data["sender"]) + await self.send_msg(MeshMessage.fromjson(data["msg"]), data["sender"]) - def mesh_send_route_response(self, data): - self.log_text(self.uplink.node.address, "we're the uplink for this address, sending route response...") + async def mesh_send_route_response(self, data): + await self.log_text(self.uplink.node.address, "we're the uplink for this address, sending route response...") messages.MeshRouteResponseMessage( src=MESH_ROOT_ADDRESS, dst=data["dst"], @@ -221,10 +249,10 @@ class MeshConsumer(WebsocketConsumer): "request_id": data["request_id"], }) - def mesh_route_response_sent(self, data): + async def mesh_route_response_sent(self, data): self.open_requests.discard(data["request_id"]) - def mesh_no_route_response(self, data): + async def mesh_no_route_response(self, data): print('no route response check') if data["request_id"] not in self.open_requests: print('a route was sent') @@ -237,25 +265,29 @@ class MeshConsumer(WebsocketConsumer): route=MESH_NONE_ADDRESS, ).send() - def log_received_message(self, src_node: MeshNode, msg: messages.MeshMessage): + """ + helper functions + """ + + async def log_received_message(self, src_node: MeshNode, msg: messages.MeshMessage): as_json = MeshMessage.tojson(msg) - async_to_sync(self.channel_layer.group_send)("mesh_msg_received", { + await self.channel_layer.group_send("mesh_msg_received", { "type": "mesh.msg_received", "timestamp": timezone.now().strftime("%d.%m.%y %H:%M:%S.%f"), "channel": self.channel_name, "uplink": self.uplink.node.address if self.uplink else None, "msg": as_json, }) - NodeMessage.objects.create( + await NodeMessage.objects.acreate( uplink=self.uplink, src_node=src_node, message_type=msg.msg_type.name, data=as_json, ) - def log_text(self, address, text): + async def log_text(self, address, text): address = getattr(address, 'address', address) - async_to_sync(self.channel_layer.group_send)("mesh_log", { + await self.channel_layer.group_send("mesh_log", { "type": "mesh.log_entry", "timestamp": timezone.now().strftime("%d.%m.%y %H:%M:%S.%f"), "channel": self.channel_name, @@ -265,7 +297,7 @@ class MeshConsumer(WebsocketConsumer): }) print("MESH %s: [%s] %s" % (self.uplink.node, address, text)) - def add_dst_nodes(self, nodes=None, addresses=None): + async def add_dst_nodes(self, nodes=None, addresses=None): nodes = list(nodes) if nodes else [] addresses = set(addresses) if addresses else set() @@ -273,7 +305,7 @@ class MeshConsumer(WebsocketConsumer): missing_addresses = addresses - set(node.address for node in nodes) if missing_addresses: - MeshNode.objects.bulk_create( + await MeshNode.objects.abulk_create( [MeshNode(address=address) for address in missing_addresses], ignore_conflicts=True ) @@ -282,25 +314,25 @@ class MeshConsumer(WebsocketConsumer): addresses |= missing_addresses for address in addresses: - self.log_text(address, "destination added") + await self.log_text(address, "destination added") # create group name for this address group = get_mesh_comm_group(address) # if we aren't handling this address yet, join the group if address not in self.dst_nodes: - async_to_sync(self.channel_layer.group_add)(group, self.channel_name) + await self.channel_layer.group_add(group, self.channel_name) self.dst_nodes.add(address) # tell other consumers to leave the group - async_to_sync(self.channel_layer.group_send)(group, { + await self.channel_layer.group_send(group, { "type": "mesh.dst_node_uplink", "node": address, "uplink": self.uplink.node.address }) # tell the node to dump its current information - self.send_msg( + await self.send_msg( messages.ConfigDumpMessage( src=messages.MESH_ROOT_ADDRESS, dst=address, @@ -308,43 +340,46 @@ class MeshConsumer(WebsocketConsumer): ) # add the stuff to the db as well - MeshNode.objects.filter(address__in=addresses).update( + await MeshNode.objects.filter(address__in=addresses).aupdate( uplink=self.uplink, last_signin=timezone.now(), ) - def remove_dst_nodes(self, addresses): + async def remove_dst_nodes(self, addresses): for address in tuple(addresses): - self.log_text(address, "destination removed") + await self.log_text(address, "destination removed") # create group name for this address group = get_mesh_comm_group(address) # leave the group if address in self.dst_nodes: - async_to_sync(self.channel_layer.group_discard)(group, self.channel_name) + await self.channel_layer.group_discard(group, self.channel_name) self.dst_nodes.discard(address) # add the stuff to the db as well # todo: shouldn't do this because of race condition? - MeshNode.objects.filter(address__in=addresses, uplink=self.uplink).update(uplink=None) + await MeshNode.objects.filter(address__in=addresses, uplink=self.uplink).aupdate(uplink=None) -class MeshUIConsumer(JsonWebsocketConsumer): - def connect(self): - # todo: auth - self.accept() +class MeshUIConsumer(AsyncJsonWebsocketConsumer): + def __init__(self): + super().__init__() self.msg_sent_filter = {} self.msg_received_filter = {} - def receive_json(self, content, **kwargs): + async def connect(self): + # todo: auth + await self.accept() + + async def receive_json(self, content, **kwargs): if content.get("subscribe", None) == "log": - async_to_sync(self.channel_layer.group_add)("mesh_log", self.channel_name) + await self.channel_layer.group_add("mesh_log", self.channel_name) if content.get("subscribe", None) == "msg_sent": - async_to_sync(self.channel_layer.group_add)("mesh_msg_sent", self.channel_name) + await self.channel_layer.group_add("mesh_msg_sent", self.channel_name) self.msg_sent_filter = dict(content.get("filter", {})) if content.get("subscribe", None) == "msg_received": - async_to_sync(self.channel_layer.group_add)("mesh_msg_sent", self.channel_name) + await self.channel_layer.group_add("mesh_msg_sent", self.channel_name) self.msg_received_filter = dict(content.get("filter", {})) if "send_msg" in content: msg_to_send = self.scope["session"].pop("mesh_msg_%s" % content["send_msg"], None) @@ -352,11 +387,11 @@ class MeshUIConsumer(JsonWebsocketConsumer): return self.scope["session"].save() - async_to_sync(self.channel_layer.group_add)("mesh_msg_sent", self.channel_name) + await self.channel_layer.group_add("mesh_msg_sent", self.channel_name) self.msg_sent_filter = {"sender": self.channel_name} if msg_to_send["msg_data"]["msg_type"] == MeshMessageType.MESH_ROUTE_REQUEST.name: - async_to_sync(self.channel_layer.group_add)("mesh_msg_received", self.channel_name) + await self.channel_layer.group_add("mesh_msg_received", self.channel_name) self.msg_received_filter = {"request_id": msg_to_send["msg_data"]["request_id"]} for recipient in msg_to_send["recipients"]: @@ -365,10 +400,10 @@ class MeshUIConsumer(JsonWebsocketConsumer): **msg_to_send["msg_data"], }).send(sender=self.channel_name) - def mesh_log_entry(self, data): - self.send_json(data) + async def mesh_log_entry(self, data): + await self.send_json(data) - def mesh_msg_sent(self, data): + async def mesh_msg_sent(self, data): for key, value in self.msg_sent_filter.items(): if isinstance(value, list): if data.get(key, None) not in value: @@ -376,9 +411,9 @@ class MeshUIConsumer(JsonWebsocketConsumer): else: if data.get(key, None) != value: return - self.send_json(data) + await self.send_json(data) - def mesh_msg_received(self, data): + async def mesh_msg_received(self, data): for key, filter_value in self.msg_received_filter.items(): value = data.get(key, data["msg"].get(key, None)) if isinstance(filter_value, list): @@ -387,9 +422,9 @@ class MeshUIConsumer(JsonWebsocketConsumer): else: if value != filter_value: return - self.send_json(data) + await self.send_json(data) - def disconnect(self, code): - async_to_sync(self.channel_layer.group_discard)("mesh_log", self.channel_name) - async_to_sync(self.channel_layer.group_discard)("mesh_msg_sent", self.channel_name) - async_to_sync(self.channel_layer.group_discard)("mesh_msg_received", self.channel_name) + async def disconnect(self, code): + await self.channel_layer.group_discard("mesh_log", self.channel_name) + await self.channel_layer.group_discard("mesh_msg_sent", self.channel_name) + await self.channel_layer.group_discard("mesh_msg_received", self.channel_name) diff --git a/src/c3nav/mesh/tasks.py b/src/c3nav/mesh/tasks.py deleted file mode 100644 index e2022b17..00000000 --- a/src/c3nav/mesh/tasks.py +++ /dev/null @@ -1,11 +0,0 @@ -import channels -from asgiref.sync import async_to_sync - -from c3nav.celery import app - - -@app.task(bind=True, max_retries=3) -def send_channel_msg(self, layer, msg): - # todo: this is… not ideal, is it? - print("task sending channel msg...") - async_to_sync(channels.layers.get_channel_layer().send)(layer, msg)