team-3/src/c3nav/mesh/consumers.py

103 lines
3.7 KiB
Python
Raw Normal View History

2022-04-15 20:02:42 +02:00
import traceback
from channels.db import database_sync_to_async
from channels.generic.websocket import AsyncWebsocketConsumer
from c3nav.mesh import messages
2022-04-15 20:02:42 +02:00
from c3nav.mesh.models import MeshNode, NodeMessage
class MeshConsumer(AsyncWebsocketConsumer):
    """Websocket consumer that receives binary messages from a mesh node.

    A connection starts without an associated node (``self.node is None``).
    The first message addressed to this endpoint must be a
    ``MeshSigninMessage``; it binds the connection to a ``MeshNode`` row,
    joins the node's routing groups and records routing state in the
    database. Any other message received before sign-in closes the
    connection.
    """

    @staticmethod
    def _route_group_name(address):
        """Return the channel-layer group name used for routing to *address*.

        Colons are stripped from the address before it is embedded in the
        group name.
        """
        return 'route_%s' % address.replace(':', '')

    async def connect(self):
        """Accept the websocket; no node is bound until it signs in."""
        print('connected!')
        # todo: auth
        self.node = None
        await self.accept()

    async def disconnect(self, close_code):
        """Drop routing state and group memberships of the signed-in node.

        Does nothing beyond logging when no node ever signed in on this
        connection.
        """
        print('disconnected!')
        if self.node is not None:
            await self.remove_route(self.node.address)
            await self.channel_layer.group_discard(
                self._route_group_name(self.node.address), self.channel_name
            )
            await self.channel_layer.group_discard('route_broadcast', self.channel_name)

    async def send_msg(self, msg):
        """Encode *msg* and send it to the peer as a binary frame."""
        print('Sending message:', msg)
        await self.send(bytes_data=msg.encode())

    async def receive(self, text_data=None, bytes_data=None):
        """Decode and handle one incoming websocket frame.

        Text frames are ignored. Frames that fail to decode are logged via
        a printed traceback and dropped. Messages whose destination is
        neither ``ROOT_ADDRESS`` nor ``PARENT_ADDRESS`` are only printed
        (forwarding is not implemented yet).
        """
        if bytes_data is None:
            return

        try:
            msg = messages.Message.decode(bytes_data)
        except Exception:
            # A malformed frame must not kill the connection; log and drop it.
            traceback.print_exc()
            return

        if msg.dst not in (messages.ROOT_ADDRESS, messages.PARENT_ADDRESS):
            print('Received message for forwarding:', msg)
            # todo: this message isn't for us, forward it
            return

        print('Received message:', msg)

        if isinstance(msg, messages.MeshSigninMessage):
            await self._handle_signin(msg)
            return

        if self.node is None:
            print('Expected sign-in message, but got a different one!')
            await self.close()
            return

        await self.log_received_message(msg)

    async def _handle_signin(self, msg):
        """Bind this connection to the node that sent the sign-in *msg*.

        Looks up (or creates) the node, logs the message, replies with a
        layer announcement and a config dump request, joins the routing
        groups and stores the node's parent/route in the database.
        """
        self.node, created = await self.get_node(msg.src)
        if created:
            print('New node signing in!')
            print(self.node)

        await self.log_received_message(msg)

        await self.send_msg(messages.MeshLayerAnnounceMessage(
            src=messages.ROOT_ADDRESS,
            dst=msg.src,
            layer=messages.NO_LAYER
        ))
        await self.send_msg(messages.ConfigDumpMessage(
            src=messages.ROOT_ADDRESS,
            dst=msg.src,
        ))

        await self.channel_layer.group_add(
            self._route_group_name(self.node.address), self.channel_name
        )
        await self.channel_layer.group_add('route_broadcast', self.channel_name)

        # The signing-in node has no parent, and its route is its own address.
        await self.set_parent_of_nodes(None, (self.node.address, ))
        await self.add_route_to_nodes(self.node.address, (self.node.address,))

    @database_sync_to_async
    def get_node(self, address):
        """Return ``(node, created)`` for the ``MeshNode`` with *address*."""
        return MeshNode.objects.get_or_create(address=address)

    @database_sync_to_async
    def log_received_message(self, msg: messages.Message):
        """Store *msg* as a ``NodeMessage`` attached to the current node."""
        NodeMessage.objects.create(
            node=self.node,
            message_type=msg.msg_id,
            data=msg.tojson()
        )

    @database_sync_to_async
    def create_nodes(self, addresses):
        """Bulk-create one ``MeshNode`` per entry in *addresses*."""
        MeshNode.objects.bulk_create(MeshNode(address=address) for address in addresses)

    @database_sync_to_async
    def set_parent_of_nodes(self, parent_address, node_addresses):
        """Set ``parent_node_id`` to *parent_address* for the given nodes."""
        MeshNode.objects.filter(address__in=node_addresses).update(parent_node_id=parent_address)

    @database_sync_to_async
    def add_route_to_nodes(self, route_address, node_addresses):
        """Set ``route_id`` to *route_address* for the given nodes."""
        MeshNode.objects.filter(address__in=node_addresses).update(route_id=route_address)

    @database_sync_to_async
    def remove_route(self, route_address):
        """Clear ``route_id`` on every node currently routed via *route_address*."""
        MeshNode.objects.filter(route_id=route_address).update(route_id=None)

    @database_sync_to_async
    def remove_route_to_nodes(self, route_address, node_addresses):
        """Clear ``route_id`` on the given nodes if routed via *route_address*."""
        MeshNode.objects.filter(address__in=node_addresses, route_id=route_address).update(route_id=None)