diff --git a/src/c3nav/mesh/consumers.py b/src/c3nav/mesh/consumers.py index 26c43b04..b710e1cb 100644 --- a/src/c3nav/mesh/consumers.py +++ b/src/c3nav/mesh/consumers.py @@ -10,11 +10,15 @@ from c3nav.mesh.models import MeshNode, NodeMessage class MeshConsumer(AsyncWebsocketConsumer): async def connect(self): print('connected!') + # todo: auth + self.node = None await self.accept() async def disconnect(self, close_code): print('disconnected!') - pass + if self.node is not None: + await self.channel_layer.group_discard('route_%s' % self.node.address.replace(':', ''), self.channel_name) + await self.channel_layer.group_discard('route_broadcast', self.channel_name) async def send_msg(self, msg): print('Sending message:', msg) @@ -35,8 +39,13 @@ class MeshConsumer(AsyncWebsocketConsumer): return print('Received message:', msg) - node = await self.log_received_message(msg) # noqa + if isinstance(msg, messages.MeshSigninMessage): + self.node, created = await self.get_node(msg.src) + if created: + print('New node signing in!') + print(self.node) + await self.log_received_message(msg) await self.send_msg(messages.MeshLayerAnnounceMessage( src=messages.ROOT_ADDRESS, dst=msg.src, @@ -46,12 +55,39 @@ class MeshConsumer(AsyncWebsocketConsumer): src=messages.ROOT_ADDRESS, dst=msg.src, )) + await self.channel_layer.group_add('route_%s' % self.node.address.replace(':', ''), self.channel_name) + await self.channel_layer.group_add('route_broadcast', self.channel_name) + await self.set_parent_of_nodes(None, (self.node, )) + await self.set_route_of_nodes(self.node, (self.node,)) + return + + if self.node is None: + print('Expected sign-in message, but got a different one!') + await self.close() + return + + await self.log_received_message(msg) @database_sync_to_async - def log_received_message(self, msg: messages.Message) -> MeshNode: - node, created = MeshNode.objects.get_or_create(address=msg.src) - return NodeMessage.objects.create( - node=node, + def get_node(self, address): + return MeshNode.objects.get_or_create(address=address) + + @database_sync_to_async + def log_received_message(self, msg: messages.Message): + NodeMessage.objects.create( + node=self.node, message_type=msg.msg_id, data=msg.tojson() ) + + @database_sync_to_async + def create_nodes(self, addresses): + MeshNode.objects.bulk_create(MeshNode(address=address) for address in addresses) + + @database_sync_to_async + def set_parent_of_nodes(self, parent_address, node_addresses): + MeshNode.objects.filter(address__in=node_addresses).update(parent_node_id=parent_address) + + @database_sync_to_async + def set_route_of_nodes(self, route_address, node_addresses): + MeshNode.objects.filter(address__in=node_addresses).update(route_id=route_address) diff --git a/src/c3nav/mesh/migrations/0001_initial.py b/src/c3nav/mesh/migrations/0001_initial.py index e601e6a1..ed614994 100644 --- a/src/c3nav/mesh/migrations/0001_initial.py +++ b/src/c3nav/mesh/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 4.0.3 on 2022-04-15 18:00 +# Generated by Django 4.0.3 on 2022-04-15 18:52 from django.db import migrations, models import django.db.models.deletion @@ -15,8 +15,7 @@ class Migration(migrations.Migration): migrations.CreateModel( name='MeshNode', fields=[ - ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), - ('address', models.CharField(max_length=17, unique=True, verbose_name='mac address')), + ('address', models.CharField(max_length=17, primary_key=True, serialize=False, verbose_name='mac address')), ('first_seen', models.DateTimeField(auto_now_add=True, verbose_name='first seen')), ('parent_node', models.ForeignKey(null=True, on_delete=django.db.models.deletion.PROTECT, related_name='child_nodes', to='mesh.meshnode', verbose_name='parent node')), ('route', models.ForeignKey(null=True, on_delete=django.db.models.deletion.PROTECT, related_name='routed_nodes', to='mesh.meshnode', verbose_name='route')), diff --git a/src/c3nav/mesh/models.py b/src/c3nav/mesh/models.py index bb413654..c3c8f520 100644 --- a/src/c3nav/mesh/models.py +++ b/src/c3nav/mesh/models.py @@ -10,7 +10,7 @@ class ChipID(models.IntegerChoices): class MeshNode(models.Model): - address = models.CharField(_('mac address'), max_length=17, unique=True) + address = models.CharField(_('mac address'), max_length=17, primary_key=True) first_seen = models.DateTimeField(_('first seen'), auto_now_add=True) parent_node = models.ForeignKey('MeshNode', models.PROTECT, null=True, related_name='child_nodes', verbose_name=_('parent node')) diff --git a/src/c3nav/settings.py b/src/c3nav/settings.py index c332a8d7..9c526a1d 100644 --- a/src/c3nav/settings.py +++ b/src/c3nav/settings.py @@ -291,6 +291,22 @@ ROOT_URLCONF = 'c3nav.urls' WSGI_APPLICATION = 'c3nav.wsgi.application' ASGI_APPLICATION = 'c3nav.asgi.application' +if HAS_REDIS: + CHANNEL_LAYERS = { + 'default': { + 'BACKEND': 'channels_redis.core.RedisChannelLayer', + 'CONFIG': { + "hosts": [config.get('redis', 'location')], + }, + }, + } +else: + CHANNEL_LAYERS = { + "default": { + "BACKEND": "channels.layers.InMemoryChannelLayer" + } + } + USE_I18N = True USE_L10N = True USE_TZ = True