diff --git a/src/c3nav/asgi.py b/src/c3nav/asgi.py index 528accc4..5a723042 100644 --- a/src/c3nav/asgi.py +++ b/src/c3nav/asgi.py @@ -3,8 +3,10 @@ from contextlib import suppress from channels.auth import AuthMiddlewareStack from channels.routing import ProtocolTypeRouter, URLRouter +from channels.security.websocket import AllowedHostsOriginValidator from django.core.asgi import get_asgi_application +from c3nav.control.middleware import UserPermissionsChannelMiddleware from c3nav.urls import websocket_urlpatterns os.environ.setdefault("DJANGO_SETTINGS_MODULE", "c3nav.settings") @@ -12,8 +14,12 @@ django_asgi = get_asgi_application() application = ProtocolTypeRouter({ "http": django_asgi, - "websocket": AuthMiddlewareStack( - URLRouter(websocket_urlpatterns) + "websocket": AllowedHostsOriginValidator( + AuthMiddlewareStack( + UserPermissionsChannelMiddleware( + URLRouter(websocket_urlpatterns), + ), + ), ), }) @@ -30,7 +36,4 @@ with suppress(ImportError): Mount(settings.STATIC_URL, app=StaticFiles(directory=settings.STATIC_ROOT), name='static'), Mount('/', app=django_asgi), ]), - "websocket": AuthMiddlewareStack( - URLRouter(websocket_urlpatterns) - ), }) diff --git a/src/c3nav/control/middleware.py b/src/c3nav/control/middleware.py index 03e257c1..13d9d4d2 100644 --- a/src/c3nav/control/middleware.py +++ b/src/c3nav/control/middleware.py @@ -1,8 +1,15 @@ -from django.utils.functional import SimpleLazyObject, lazy +from channels.db import database_sync_to_async +from channels.middleware import BaseMiddleware as BaseChannelsMiddleware +from django.utils.functional import LazyObject, SimpleLazyObject, lazy from c3nav.control.models import UserPermissions, UserSpaceAccess +class UserPermissionsLazyObject(LazyObject): + def _setup(self): + raise ValueError("Accessing scope user before it is ready.") + + class UserPermissionsMiddleware: """ This middleware adds request.user_permissions to get the UserPermissions for the current request/user. @@ -32,3 +39,12 @@ class UserPermissionsMiddleware: request.user_permissions = SimpleLazyObject(lambda: self.get_user_permissions(request)) request.user_space_accesses = lazy(self.get_user_space_accesses, dict)(request) return self.get_response(request) + + +class UserPermissionsChannelMiddleware(BaseChannelsMiddleware): + async def __call__(self, scope, receive, send): + # todo: this doesn't seem to actually be lazy. and scope["user"] isn't either? + scope["user_permissions"] = UserPermissionsLazyObject() + scope["user_permissions"]._wrapped = await database_sync_to_async(UserPermissions.get_for_user)(scope["user"]) + + return await super().__call__(scope, receive, send) diff --git a/src/c3nav/control/models.py b/src/c3nav/control/models.py index 8fb006d3..54414b76 100644 --- a/src/c3nav/control/models.py +++ b/src/c3nav/control/models.py @@ -1,5 +1,5 @@ from contextlib import contextmanager -from typing import Dict +from typing import Dict, Self from django.conf import settings from django.contrib.auth.models import User @@ -73,7 +73,7 @@ class UserPermissions(models.Model): yield @classmethod - def get_for_user(cls, user, force=False) -> 'UserPermissions': + def get_for_user(cls, user, force=False) -> Self: if not user.is_authenticated: return cls() cache_key = cls.get_cache_key(user.pk) diff --git a/src/c3nav/mesh/consumers.py b/src/c3nav/mesh/consumers.py index 706cfc06..a1e020d6 100644 --- a/src/c3nav/mesh/consumers.py +++ b/src/c3nav/mesh/consumers.py @@ -5,6 +5,7 @@ from functools import cached_property from asgiref.sync import async_to_sync from channels.db import database_sync_to_async +from channels.exceptions import DenyConnection from channels.generic.websocket import AsyncJsonWebsocketConsumer, AsyncWebsocketConsumer from django.db import transaction from django.utils import timezone @@ -347,7 +348,8 @@ class MeshUIConsumer(AsyncJsonWebsocketConsumer): self.msg_received_filter = {} async def connect(self): - # todo: auth + if not self.scope["user_permisions"].mesh_control: + raise DenyConnection await self.accept() async def receive_json(self, content, **kwargs):