diff --git a/src/c3nav/mesh/consumers.py b/src/c3nav/mesh/consumers.py index 33d112f9..dfa2daa0 100644 --- a/src/c3nav/mesh/consumers.py +++ b/src/c3nav/mesh/consumers.py @@ -150,6 +150,11 @@ class MeshConsumer(AsyncWebsocketConsumer): src_node, created = await MeshNode.objects.aget_or_create(address=msg.src) if isinstance(msg, messages.MeshSigninMessage): + if not self.check_valid_address(msg.src): + print('reject node with invalid address address') + await self.close() + return + await self.create_uplink_in_database(msg.src) # inform other uplinks to shut down @@ -188,7 +193,10 @@ class MeshConsumer(AsyncWebsocketConsumer): node_status.last_msg[msg.msg_type] = msg if isinstance(msg, messages.MeshAddDestinationsMessage): - await self.add_dst_nodes(addresses=msg.addresses) + result = await self.add_dst_nodes(addresses=msg.addresses) + if not result: + print('disconnecting node that send invalid destinations', msg) + await self.close() if isinstance(msg, messages.MeshRemoveDestinationsMessage): await self.remove_dst_nodes(addresses=msg.addresses) @@ -486,10 +494,17 @@ class MeshConsumer(AsyncWebsocketConsumer): """ routing """ + @staticmethod + def check_valid_address(address): + return not (address.startswith('00:00:00') or address.startswith('ff:ff:ff')) + async def add_dst_nodes(self, nodes=None, addresses=None): nodes = list(nodes) if nodes else [] addresses = set(addresses) if addresses else set() + if not all(self.check_valid_address(a) for a in addresses): + return False + node_addresses = set(node.address for node in nodes) missing_addresses = addresses - set(node.address for node in nodes) @@ -513,6 +528,7 @@ class MeshConsumer(AsyncWebsocketConsumer): self.dst_nodes[address] = NodeState() await self.node_resend_ask(address) + return True @database_sync_to_async def _add_destination(self, address):